diff options
author | vvvv <vvvv@ydb.tech> | 2022-09-14 13:48:02 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-09-14 13:48:02 +0300 |
commit | d9f598eabeb666bcb34ea08f36d39503dc39a6de (patch) | |
tree | 76ec2078d5bbf02fa1c64f718c43063689814d03 | |
parent | 5756a87bc4295b39c99409b169f53cdcf1004122 (diff) | |
download | ydb-d9f598eabeb666bcb34ea08f36d39503dc39a6de.tar.gz |
generalization to BlockFunc (above computation)
-rw-r--r-- | ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp | 6 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_blocks.cpp | 22 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_blocks.h | 2 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_core.cpp | 2 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/CMakeLists.txt | 2 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp (renamed from ydb/library/yql/minikql/comp_nodes/mkql_block_add.cpp) | 27 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_block_func.h (renamed from ydb/library/yql/minikql/comp_nodes/mkql_block_add.h) | 2 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp | 4 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp | 24 | ||||
-rw-r--r-- | ydb/library/yql/minikql/mkql_program_builder.cpp | 38 | ||||
-rw-r--r-- | ydb/library/yql/minikql/mkql_program_builder.h | 2 | ||||
-rw-r--r-- | ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp | 12 | ||||
-rw-r--r-- | ydb/library/yql/providers/common/mkql/yql_type_mkql.cpp | 29 |
13 files changed, 107 insertions, 65 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 7fbc0f8bec..365c38a494 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -4297,10 +4297,10 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext& auto status = OptimizeExpr(lambda, lambda, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { if (node->IsCallable("+")) { - auto ret = ctx.RenameNode(*node, "BlockAdd"); - for (ui32 index = 0; index < ret->ChildrenSize(); ++index) { + auto ret = ctx.NewCallable(node->Pos(), "BlockFunc", { ctx.NewAtom(node->Pos(), "add"), node->ChildPtr(0), node->ChildPtr(1) }); + for (ui32 index = 0; index < node->ChildrenSize(); ++index) { if (node->Child(index)->IsComplete()) { - ret->ChildRef(index) = ctx.NewCallable(node->Pos(), "AsScalar", { node->ChildPtr(index) }); + ret->ChildRef(index + 1) = ctx.NewCallable(node->Pos(), "AsScalar", { node->ChildPtr(index) }); } } diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index a01110d80d..75c1a99053 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -26,13 +26,13 @@ IGraphTransformer::TStatus AsScalarWrapper(const TExprNode::TPtr& input, TExprNo return IGraphTransformer::TStatus::Ok; } -IGraphTransformer::TStatus BlockAddWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { +IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { Y_UNUSED(output); - if (!EnsureArgsCount(*input, 2U, ctx.Expr)) { + if (!EnsureArgsCount(*input, 3U, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - if (!EnsureBlockOrScalarType(*input->Child(0), ctx.Expr)) { + if (!EnsureAtom(*input->Child(0), ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -40,9 +40,13 @@ IGraphTransformer::TStatus BlockAddWrapper(const TExprNode::TPtr& input, TExprNo return IGraphTransformer::TStatus::Error; } + if (!EnsureBlockOrScalarType(*input->Child(2), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + bool scalarLeft, scalarRight; - auto leftItemType = GetBlockItemType(*input->Child(0)->GetTypeAnn(), scalarLeft); - auto rightItemType = GetBlockItemType(*input->Child(1)->GetTypeAnn(), scalarRight); + auto leftItemType = GetBlockItemType(*input->Child(1)->GetTypeAnn(), scalarLeft); + auto rightItemType = GetBlockItemType(*input->Child(2)->GetTypeAnn(), scalarRight); if (scalarLeft && scalarRight) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "At least one input should be a block")); return IGraphTransformer::TStatus::Error; @@ -50,21 +54,21 @@ IGraphTransformer::TStatus BlockAddWrapper(const TExprNode::TPtr& input, TExprNo bool isOptional1; const TDataExprType* dataType1; - if (!EnsureDataOrOptionalOfData(input->Child(0)->Pos(), leftItemType, isOptional1, dataType1, ctx.Expr)) { + if (!EnsureDataOrOptionalOfData(input->Child(1)->Pos(), leftItemType, isOptional1, dataType1, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - if (!EnsureSpecificDataType(input->Child(0)->Pos(), *dataType1, EDataSlot::Uint64, ctx.Expr)) { + if (!EnsureSpecificDataType(input->Child(1)->Pos(), *dataType1, EDataSlot::Uint64, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } bool isOptional2; const TDataExprType* dataType2; - if (!EnsureDataOrOptionalOfData(input->Child(1)->Pos(), rightItemType, isOptional2, dataType2, ctx.Expr)) { + if (!EnsureDataOrOptionalOfData(input->Child(2)->Pos(), rightItemType, isOptional2, dataType2, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - if (!EnsureSpecificDataType(input->Child(1)->Pos(), *dataType2, EDataSlot::Uint64, ctx.Expr)) { + if (!EnsureSpecificDataType(input->Child(2)->Pos(), *dataType2, EDataSlot::Uint64, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.h b/ydb/library/yql/core/type_ann/type_ann_blocks.h index 3ac99ffbe0..9af184a1ec 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.h +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.h @@ -9,7 +9,7 @@ namespace NYql { namespace NTypeAnnImpl { IGraphTransformer::TStatus AsScalarWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); - IGraphTransformer::TStatus BlockAddWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); } // namespace NTypeAnnImpl } // namespace NYql diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index ff859ae3b9..dabc56eee1 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11531,7 +11531,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["WideToBlocks"] = &WideToBlocksWrapper; Functions["WideFromBlocks"] = &WideFromBlocksWrapper; Functions["AsScalar"] = &AsScalarWrapper; - Functions["BlockAdd"] = &BlockAddWrapper; + Functions["BlockFunc"] = &BlockFuncWrapper; Functions["AsRange"] = &AsRangeWrapper; Functions["RangeCreate"] = &RangeCreateWrapper; diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt index 99274a6b31..1a50d6ec5d 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt @@ -33,7 +33,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_aggrcount.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_append.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_apply.cpp - ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_add.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_callable.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_chain_map.cpp diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_add.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp index 07b837747a..917678ada8 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_add.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp @@ -1,4 +1,4 @@ -#include "mkql_block_add.h" +#include "mkql_block_func.h" #include <ydb/library/yql/minikql/arrow/arrow_defs.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> @@ -18,21 +18,23 @@ namespace NMiniKQL { namespace { -class TBlockAddWrapper : public TMutableComputationNode<TBlockAddWrapper> { +class TBlockFuncWrapper : public TMutableComputationNode<TBlockFuncWrapper> { public: - TBlockAddWrapper(TComputationMutables& mutables, + TBlockFuncWrapper(TComputationMutables& mutables, + const TString& funcName, IComputationNode* leftArg, IComputationNode* rightArg, TType* leftArgType, TType* rightArgType, TType* outputType) : TMutableComputationNode(mutables) + , FuncName(funcName) , LeftArg(leftArg) , RightArg(rightArg) , LeftValueDesc(ToValueDescr(leftArgType)) , RightValueDesc(ToValueDescr(rightArgType)) , OutputValueDescr(ToValueDescr(outputType)) - , Kernel(ResolveKernel(LeftValueDesc, RightValueDesc)) + , Kernel(ResolveKernel(FuncName, LeftValueDesc, RightValueDesc)) , OutputTypeBitWidth(static_cast<const arrow::FixedWidthType&>(*OutputValueDescr.type).bit_width()) , FunctionRegistry(*arrow::compute::GetFunctionRegistry()) { @@ -93,12 +95,13 @@ private: this->DependsOn(RightArg); } - static const arrow::compute::ScalarKernel& ResolveKernel(const arrow::ValueDescr& leftArg, + static const arrow::compute::ScalarKernel& ResolveKernel(const TString& funcName, + const arrow::ValueDescr& leftArg, const arrow::ValueDescr& rightArg) { auto* functionRegistry = arrow::compute::GetFunctionRegistry(); Y_VERIFY_DEBUG(functionRegistry != nullptr); - auto function = ARROW_RESULT(functionRegistry->GetFunction("add")); + auto function = ARROW_RESULT(functionRegistry->GetFunction(funcName)); Y_VERIFY_DEBUG(function != nullptr); Y_VERIFY_DEBUG(function->kind() == arrow::compute::Function::SCALAR); @@ -126,6 +129,7 @@ private: } private: + const TString FuncName; IComputationNode* LeftArg; IComputationNode* RightArg; const arrow::ValueDescr LeftValueDesc; @@ -138,13 +142,16 @@ private: } -IComputationNode* WrapBlockAdd(TCallable& callable, const TComputationNodeFactoryContext& ctx) { +IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx) { const auto* callableType = callable.GetType(); - return new TBlockAddWrapper(ctx.Mutables, - LocateNode(ctx.NodeLocator, callable, 0), + const auto funcNameData = AS_VALUE(TDataLiteral, callable.GetInput(0)); + const auto funcName = TString(funcNameData->AsValue().AsStringRef()); + return new TBlockFuncWrapper(ctx.Mutables, + funcName, LocateNode(ctx.NodeLocator, callable, 1), - callableType->GetArgumentType(0), + LocateNode(ctx.NodeLocator, callable, 2), callableType->GetArgumentType(1), + callableType->GetArgumentType(2), callableType->GetReturnType()); } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_add.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h index bc4aa3b370..0bbd881fdf 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_add.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h @@ -5,7 +5,7 @@ namespace NKikimr { namespace NMiniKQL { -IComputationNode* WrapBlockAdd(TCallable& callable, const TComputationNodeFactoryContext& ctx); +IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index 028a0b8abb..c3d3f5c1be 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -4,7 +4,7 @@ #include "mkql_aggrcount.h" #include "mkql_append.h" #include "mkql_apply.h" -#include "mkql_block_add.h" +#include "mkql_block_func.h" #include "mkql_blocks.h" #include "mkql_callable.h" #include "mkql_chain_map.h" @@ -263,7 +263,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"FromFlow", &WrapFromFlow}, {"ToBlocks", &WrapToBlocks}, {"WideToBlocks", &WrapWideToBlocks}, - {"BlockAdd", &WrapBlockAdd}, + {"BlockFunc", &WrapBlockFunc}, {"FromBlocks", &WrapFromBlocks}, {"WideFromBlocks", &WrapWideFromBlocks}, {"AsScalar", &WrapAsScalar}, diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp index 350c2c0893..739356325e 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp @@ -101,12 +101,13 @@ Y_UNIT_TEST(TestScalar) { UNIT_ASSERT_VALUES_EQUAL(TArrowBlock::From(value).GetDatum().scalar_as<arrow::UInt64Scalar>().value, testValue); } -Y_UNIT_TEST(TestBlockAdd) { +Y_UNIT_TEST(TestBlockFunc) { TSetup<false> setup; TProgramBuilder& pb = *setup.PgmBuilder; const auto ui64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id); const auto tupleType = pb.NewTupleType({ui64Type, ui64Type}); + const auto ui64BlockType = pb.NewBlockType(ui64Type, TBlockType::EShape::Many); const auto data1 = pb.NewTuple(tupleType, {pb.NewDataLiteral<ui64>(1), pb.NewDataLiteral<ui64>(10)}); const auto data2 = pb.NewTuple(tupleType, {pb.NewDataLiteral<ui64>(2), pb.NewDataLiteral<ui64>(20)}); @@ -120,7 +121,7 @@ Y_UNIT_TEST(TestBlockAdd) { }); const auto wideBlocksFlow = pb.WideToBlocks(wideFlow); const auto sumWideFlow = pb.WideMap(wideBlocksFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode::TList { - return {pb.BlockAdd(items[0], items[1])}; + return {pb.BlockFunc("add", ui64BlockType, {items[0], items[1]})}; }); const auto sumNarrowFlow = pb.NarrowMap(sumWideFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return items[0]; @@ -143,13 +144,14 @@ Y_UNIT_TEST(TestBlockAdd) { UNIT_ASSERT(!iterator.Next(item)); } -Y_UNIT_TEST(TestBlockAddWithNullables) { +Y_UNIT_TEST(TestBlockFuncWithNullables) { TSetup<false> setup; TProgramBuilder& pb = *setup.PgmBuilder; const auto optionalUi64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id, true); const auto tupleType = pb.NewTupleType({optionalUi64Type, optionalUi64Type}); const auto emptyOptionalUi64 = pb.NewEmptyOptional(optionalUi64Type); + const auto ui64OptBlockType = pb.NewBlockType(optionalUi64Type, TBlockType::EShape::Many); const auto data1 = pb.NewTuple(tupleType, { pb.NewOptional(pb.NewDataLiteral<ui64>(1)), @@ -176,7 +178,7 @@ Y_UNIT_TEST(TestBlockAddWithNullables) { }); const auto wideBlocksFlow = pb.WideToBlocks(wideFlow); const auto sumWideFlow = pb.WideMap(wideBlocksFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode::TList { - return {pb.BlockAdd(items[0], items[1])}; + return {pb.BlockFunc("add", ui64OptBlockType, {items[0], items[1]})}; }); const auto sumNarrowFlow = pb.NarrowMap(sumWideFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return items[0]; @@ -201,11 +203,12 @@ Y_UNIT_TEST(TestBlockAddWithNullables) { UNIT_ASSERT(!iterator.Next(item)); } -Y_UNIT_TEST(TestBlockAddWithNullableScalar) { +Y_UNIT_TEST(TestBlockFuncWithNullableScalar) { TSetup<false> setup; TProgramBuilder& pb = *setup.PgmBuilder; const auto optionalUi64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id, true); + const auto ui64OptBlockType = pb.NewBlockType(optionalUi64Type, TBlockType::EShape::Many); const auto emptyOptionalUi64 = pb.NewEmptyOptional(optionalUi64Type); const auto list = pb.NewList(optionalUi64Type, { @@ -226,7 +229,7 @@ Y_UNIT_TEST(TestBlockAddWithNullableScalar) { { const auto scalar = pb.AsScalar(emptyOptionalUi64); auto iterator = map([&](TRuntimeNode item) -> TRuntimeNode { - return {pb.BlockAdd(scalar, item)}; + return {pb.BlockFunc("add", ui64OptBlockType, {scalar, item})}; }); NUdf::TUnboxedValue item; @@ -245,7 +248,7 @@ Y_UNIT_TEST(TestBlockAddWithNullableScalar) { { const auto scalar = pb.AsScalar(emptyOptionalUi64); auto iterator = map([&](TRuntimeNode item) -> TRuntimeNode { - return {pb.BlockAdd(item, scalar)}; + return {pb.BlockFunc("add", ui64OptBlockType, {item, scalar})}; }); NUdf::TUnboxedValue item; @@ -264,7 +267,7 @@ Y_UNIT_TEST(TestBlockAddWithNullableScalar) { { const auto scalar = pb.AsScalar(pb.NewDataLiteral<ui64>(100)); auto iterator = map([&](TRuntimeNode item) -> TRuntimeNode { - return {pb.BlockAdd(item, scalar)}; + return {pb.BlockFunc("add", ui64OptBlockType, {item, scalar})}; }); NUdf::TUnboxedValue item; @@ -281,11 +284,12 @@ Y_UNIT_TEST(TestBlockAddWithNullableScalar) { } } -Y_UNIT_TEST(TestBlockAddWithScalar) { +Y_UNIT_TEST(TestBlockFuncWithScalar) { TSetup<false> setup; TProgramBuilder& pb = *setup.PgmBuilder; const auto ui64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id); + const auto ui64BlockType = pb.NewBlockType(ui64Type, TBlockType::EShape::Many); const auto data1 = pb.NewDataLiteral<ui64>(10); const auto data2 = pb.NewDataLiteral<ui64>(20); @@ -297,7 +301,7 @@ Y_UNIT_TEST(TestBlockAddWithScalar) { const auto flow = pb.ToFlow(list); const auto blocksFlow = pb.ToBlocks(flow); const auto sumBlocksFlow = pb.Map(blocksFlow, [&](TRuntimeNode item) -> TRuntimeNode { - return {pb.BlockAdd(leftScalar, {pb.BlockAdd(item, rightScalar )})}; + return {pb.BlockFunc("add", ui64BlockType, { leftScalar, {pb.BlockFunc("add", ui64BlockType, { item, rightScalar } )}})}; }); const auto pgmReturn = pb.Collect(pb.FromBlocks(sumBlocksFlow)); diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 9f3f4b985b..4dfedcc84c 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -1471,30 +1471,6 @@ TRuntimeNode TProgramBuilder::AsScalar(TRuntimeNode value) { return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockAdd(TRuntimeNode arg1, TRuntimeNode arg2) { - bool arg1Optional; - auto* arg1BlockType = AS_TYPE(TBlockType, arg1.GetStaticType()); - auto* arg1Type = UnpackOptionalData(arg1BlockType->GetItemType(), arg1Optional); - - bool arg2Optional; - auto* arg2BlockType = AS_TYPE(TBlockType, arg2.GetStaticType()); - auto* arg2Type = UnpackOptionalData(arg2BlockType->GetItemType(), arg2Optional); - - MKQL_ENSURE(arg1BlockType->GetShape() != TBlockType::EShape::Scalar || - arg2BlockType->GetShape() != TBlockType::EShape::Scalar, - "At least one EShape::Many block expected"); - - auto* resultDataType = BuildArithmeticCommonType(arg1Type, arg2Type); - if (arg1Optional || arg2Optional) { - resultDataType = NewOptionalType(resultDataType); - } - auto* callableType = TCallableBuilder(Env, __func__, NewBlockType(resultDataType, TBlockType::EShape::Many)) - .Add(arg1) - .Add(arg2) - .Build(); - return TRuntimeNode(callableType, false); -} - TRuntimeNode TProgramBuilder::ListFromRange(TRuntimeNode start, TRuntimeNode end, TRuntimeNode step) { MKQL_ENSURE(start.GetStaticType()->IsData(), "Expected data"); MKQL_ENSURE(end.GetStaticType()->IsSameType(*start.GetStaticType()), "Mismatch type"); @@ -5194,6 +5170,20 @@ TRuntimeNode TProgramBuilder::PgInternal0(TType* returnType) { return TRuntimeNode(callableBuilder.Build(), false); } +TRuntimeNode TProgramBuilder::BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args) { + for (const auto& arg : args) { + MKQL_ENSURE(arg.GetStaticType()->IsBlock(), "Expected Block type"); + } + + TCallableBuilder builder(Env, __func__, returnType); + builder.Add(NewDataLiteral<NUdf::EDataSlot::String>(funcName)); + for (const auto& arg : args) { + builder.Add(arg); + } + + return TRuntimeNode(builder.Build(), false); +} + bool CanExportType(TType* type, const TTypeEnvironment& env) { if (type->GetKind() == TType::EKind::Type) { return false; // Type of Type diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index 435b32aed8..72112c9d1e 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -240,7 +240,7 @@ public: TRuntimeNode WideFromBlocks(TRuntimeNode flow); TRuntimeNode AsScalar(TRuntimeNode value); - TRuntimeNode BlockAdd(TRuntimeNode data1, TRuntimeNode data2); + TRuntimeNode BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args); // udfs TRuntimeNode Udf( diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index ea62a2b299..ce81e57817 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -467,8 +467,6 @@ TMkqlCommonCallableCompiler::TShared::TShared() { {"Div", &TProgramBuilder::Div}, {"Mod", &TProgramBuilder::Mod}, - {"BlockAdd", &TProgramBuilder::BlockAdd}, - {"DecimalMul", &TProgramBuilder::DecimalMul}, {"DecimalDiv", &TProgramBuilder::DecimalDiv}, {"DecimalMod", &TProgramBuilder::DecimalMod}, @@ -2336,6 +2334,16 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return ctx.ProgramBuilder.WithContext(input, node.Child(1)->Content()); }); + AddCallable("BlockFunc", [](const TExprNode& node, TMkqlBuildContext& ctx) { + TVector<TRuntimeNode> args; + for (ui32 i = 1; i < node.ChildrenSize(); ++i) { + args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); + } + + auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + return ctx.ProgramBuilder.BlockFunc(node.Child(0)->Content(), returnType, args); + }); + AddCallable("PgArray", [](const TExprNode& node, TMkqlBuildContext& ctx) { std::vector<TRuntimeNode> args; args.reserve(node.ChildrenSize()); diff --git a/ydb/library/yql/providers/common/mkql/yql_type_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_type_mkql.cpp index 7cfcf006c6..cbd56a87b2 100644 --- a/ydb/library/yql/providers/common/mkql/yql_type_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_type_mkql.cpp @@ -203,6 +203,24 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, NKiki return pgmBuilder.GetTypeEnvironment().GetTypeOfEmptyDict(); } + case ETypeAnnotationKind::Block: { + auto block = annotation.Cast<TBlockExprType>(); + auto itemType = BuildType(*block->GetItemType(), pgmBuilder, err, withTagged); + if (!itemType) { + return nullptr; + } + return pgmBuilder.NewBlockType(itemType, NKikimr::NMiniKQL::TBlockType::EShape::Many); + } + + case ETypeAnnotationKind::Scalar: { + auto scalar = annotation.Cast<TScalarExprType>(); + auto itemType = BuildType(*scalar->GetItemType(), pgmBuilder, err, withTagged); + if (!itemType) { + return nullptr; + } + return pgmBuilder.NewBlockType(itemType, NKikimr::NMiniKQL::TBlockType::EShape::Scalar); + } + default: return nullptr; } @@ -365,6 +383,17 @@ const TTypeAnnotationNode* ConvertMiniKQLType(TPosition position, NKikimr::NMini return ctx.MakeType<TTaggedExprType>(baseType, taggedType->GetTag()); } + case TType::EKind::Block: + { + auto blockType = static_cast<TBlockType*>(type); + auto itemType = ConvertMiniKQLType(position, blockType->GetItemType(), ctx); + if (blockType->GetShape() == NKikimr::NMiniKQL::TBlockType::EShape::Many) { + return ctx.MakeType<TBlockExprType>(itemType); + } else { + return ctx.MakeType<TScalarExprType>(itemType); + } + } + default: YQL_ENSURE(false, "Unknown kind"); } |