diff options
author | vvvv <vvvv@ydb.tech> | 2023-03-22 11:52:37 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2023-03-22 11:52:37 +0300 |
commit | 15168eacfd3bc1a3ddeea8b110d15f39d8314ce0 (patch) | |
tree | 2028aaa30df33666c868e8acd08413258993c1eb | |
parent | 144d85a4b3417951a8565c3a55e630cf71ef8d00 (diff) | |
download | ydb-15168eacfd3bc1a3ddeea8b110d15f39d8314ce0.tar.gz |
YQL-9511 block version of Nth/AsTuple
15 files changed, 379 insertions, 5 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 1c9ab1d94c..f337e9650f 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -4944,16 +4944,22 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo return true; } - if (!node->IsCallable()) { + if (!node->IsList() && !node->IsCallable()) { + return true; + } + + if (node->IsList() && !node->GetTypeAnn()->IsComputable()) { return true; } TExprNode::TListType funcArgs; std::string_view arrowFunctionName; - if (node->IsCallable({"And", "Or", "Xor", "Not", "Coalesce", "If", "Just"})) + if (node->IsList() || node->IsCallable({"And", "Or", "Xor", "Not", "Coalesce", "If", "Just", "Nth"})) { for (auto& child : node->ChildrenList()) { - if (child->IsComplete()) { + if (!child->GetTypeAnn()->IsComputable()) { + funcArgs.push_back(child); + } else if (child->IsComplete()) { funcArgs.push_back(ctx.NewCallable(node->Pos(), "AsScalar", { child })); } else if (auto rit = rewrites.find(child.Get()); rit != rewrites.end()) { funcArgs.push_back(rit->second); @@ -4962,7 +4968,7 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo } } - TString blockFuncName = TString("Block") + node->Content(); + TString blockFuncName = TString("Block") + (node->IsList() ? "AsTuple" : node->Content()); if (node->IsCallable({"And", "Or", "Xor"}) && funcArgs.size() > 2) { // Split original argument list by pairs (since the order is not important balanced tree is used) rewrites[node.Get()] = SplitByPairs(node->Pos(), blockFuncName, funcArgs, 0, funcArgs.size(), ctx); diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index 7cde42795e..f89e3a260f 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -256,6 +256,109 @@ IGraphTransformer::TStatus BlockJustWrapper(const TExprNode::TPtr& input, TExprN return IGraphTransformer::TStatus::Ok; } +IGraphTransformer::TStatus BlockAsTupleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); + if (!EnsureMinArgsCount(*input, 1, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + TTypeAnnotationNode::TListType items; + bool onlyScalars = true; + for (const auto& child : input->Children()) { + if (!EnsureBlockOrScalarType(*child, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + bool isScalar; + const TTypeAnnotationNode* blockItemType = GetBlockItemType(*child->GetTypeAnn(), isScalar); + onlyScalars = onlyScalars && isScalar; + items.push_back(blockItemType); + } + + const TTypeAnnotationNode* resultType = ctx.Expr.MakeType<TTupleExprType>(items); + if (onlyScalars) { + resultType = ctx.Expr.MakeType<TScalarExprType>(resultType); + } else { + resultType = ctx.Expr.MakeType<TBlockExprType>(resultType); + } + + input->SetTypeAnn(resultType); + return IGraphTransformer::TStatus::Ok; +} + +IGraphTransformer::TStatus BlockNthWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); + if (!EnsureArgsCount(*input, 2, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto child = input->Child(0); + if (!EnsureBlockOrScalarType(*child, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + bool isScalar; + const TTypeAnnotationNode* blockItemType = GetBlockItemType(*child->GetTypeAnn(), isScalar); + const TTypeAnnotationNode* resultType; + if (IsNull(*blockItemType)) { + resultType = blockItemType; + } else { + const TTupleExprType* tupleType; + bool isOptional; + if (blockItemType->GetKind() == ETypeAnnotationKind::Optional) { + auto itemType = blockItemType->Cast<TOptionalExprType>()->GetItemType(); + if (!EnsureTupleType(child->Pos(), *itemType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + tupleType = itemType->Cast<TTupleExprType>(); + isOptional = true; + } + else { + if (!EnsureTupleType(child->Pos(), *blockItemType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + tupleType = blockItemType->Cast<TTupleExprType>(); + isOptional = false; + } + + if (!EnsureAtom(input->Tail(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + ui32 index = 0; + if (!TryFromString(input->Tail().Content(), index)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() << "Failed to convert to integer: " << input->Tail().Content())); + return IGraphTransformer::TStatus::Error; + } + + if (index >= tupleType->GetSize()) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() << "Index out of range. Index: " << + index << ", size: " << tupleType->GetSize())); + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureComputableType(input->Head().Pos(), *tupleType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + resultType = tupleType->GetItems()[index]; + if (isOptional && !resultType->IsOptionalOrNull()) { + resultType = ctx.Expr.MakeType<TOptionalExprType>(resultType); + } + } + + if (isScalar) { + resultType = ctx.Expr.MakeType<TScalarExprType>(resultType); + } else { + resultType = ctx.Expr.MakeType<TBlockExprType>(resultType); + } + + input->SetTypeAnn(resultType); + return IGraphTransformer::TStatus::Ok; +} + IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { Y_UNUSED(output); if (!EnsureMinArgsCount(*input, 2U, ctx.Expr)) { diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.h b/ydb/library/yql/core/type_ann/type_ann_blocks.h index 0c32ae6bb4..35f0c0d026 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.h +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.h @@ -15,6 +15,8 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus BlockLogicalWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus BlockIfWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus BlockJustWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus BlockAsTupleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus BlockNthWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus BlockCombineAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 2b3cdcd66e..5b4c37bdb4 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -1652,7 +1652,7 @@ namespace NTypeAnnImpl { } input->SetTypeAnn(tupleType->GetItems()[index]); - if (isOptional && input->GetTypeAnn()->GetKind() != ETypeAnnotationKind::Optional && input->GetTypeAnn()->GetKind() != ETypeAnnotationKind::Null) { + if (isOptional && !input->GetTypeAnn()->IsOptionalOrNull()) { input->SetTypeAnn(ctx.Expr.MakeType<TOptionalExprType>(input->GetTypeAnn())); } @@ -11883,6 +11883,8 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["BlockNot"] = &BlockLogicalWrapper; Functions["BlockIf"] = &BlockIfWrapper; Functions["BlockJust"] = &BlockJustWrapper; + Functions["BlockAsTuple"] = &BlockAsTupleWrapper; + Functions["BlockNth"] = &BlockNthWrapper; ExtFunctions["BlockFunc"] = &BlockFuncWrapper; ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper; diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin-x86_64.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin-x86_64.txt index 52aec2119d..1e05662e53 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin-x86_64.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin-x86_64.txt @@ -53,6 +53,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_reader.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_top.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_callable.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_chain_map.cpp diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt index 5819bbc6d3..078a9ba89d 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt @@ -54,6 +54,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_reader.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_top.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_callable.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_chain_map.cpp diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-x86_64.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-x86_64.txt index 5819bbc6d3..078a9ba89d 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-x86_64.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-x86_64.txt @@ -54,6 +54,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_reader.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_top.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_callable.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_chain_map.cpp diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.windows-x86_64.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.windows-x86_64.txt index 52aec2119d..1e05662e53 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.windows-x86_64.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.windows-x86_64.txt @@ -53,6 +53,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_reader.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_top.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_callable.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_chain_map.cpp diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp new file mode 100644 index 0000000000..1a97f96a20 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp @@ -0,0 +1,182 @@ +#include "mkql_block_tuple.h" +#include "mkql_block_impl.h" + +#include <ydb/library/yql/minikql/arrow/arrow_defs.h> +#include <ydb/library/yql/minikql/arrow/arrow_util.h> +#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> +#include <ydb/library/yql/minikql/mkql_node_cast.h> +#include <ydb/library/yql/minikql/mkql_node_builder.h> + +#include <arrow/util/bitmap_ops.h> + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +class TBlockAsTupleExec { +public: + TBlockAsTupleExec(const TVector<TType*>& argTypes, const std::shared_ptr<arrow::DataType>& returnArrowType) + : ArgTypes(argTypes) + , ReturnArrowType(returnArrowType) + {} + + arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const { + bool allScalars = true; + size_t length = 0; + for (const auto& x : batch.values) { + if (!x.is_scalar()) { + allScalars = false; + length = x.array()->length; + break; + } + } + + if (allScalars) { + // return scalar too + std::vector<std::shared_ptr<arrow::Scalar>> arrowValue; + for (const auto& x : batch.values) { + arrowValue.emplace_back(x.scalar()); + } + + *res = arrow::Datum(std::make_shared<arrow::StructScalar>(arrowValue, ReturnArrowType)); + return arrow::Status::OK(); + } + + auto newArrayData = arrow::ArrayData::Make(ReturnArrowType, length, { nullptr }, 0, 0); + MKQL_ENSURE(ArgTypes.size() == batch.values.size(), "Mismatch batch columns"); + for (ui32 i = 0; i < batch.values.size(); ++i) { + const auto& datum = batch.values[i]; + if (datum.is_scalar()) { + // expand scalar to array + auto expandedArray = MakeArrayFromScalar(*datum.scalar(), length, AS_TYPE(TBlockType, ArgTypes[i])->GetItemType(), *ctx->memory_pool()); + newArrayData->child_data.push_back(expandedArray.array()); + } else { + newArrayData->child_data.push_back(datum.array()); + } + } + + *res = arrow::Datum(newArrayData); + return arrow::Status::OK(); + } + +private: + const TVector<TType*> ArgTypes; + const std::shared_ptr<arrow::DataType> ReturnArrowType; +}; + +class TBlockNthExec { +public: + TBlockNthExec(const std::shared_ptr<arrow::DataType>& returnArrowType, ui32 index, bool isOptional, bool needExternalOptional) + : ReturnArrowType(returnArrowType) + , Index(index) + , IsOptional(isOptional) + , NeedExternalOptional(needExternalOptional) + {} + + arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const { + arrow::Datum inputDatum = batch.values[0]; + if (inputDatum.is_scalar()) { + if (inputDatum.scalar()->is_valid) { + const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*inputDatum.scalar()); + *res = arrow::Datum(structScalar.value[Index]); + } else { + *res = arrow::Datum(arrow::MakeNullScalar(ReturnArrowType)); + } + } else { + const auto& array = inputDatum.array(); + auto child = array->child_data[Index]; + if (NeedExternalOptional) { + auto newArrayData = arrow::ArrayData::Make(ReturnArrowType, array->length, { array->buffers[0] }); + newArrayData->child_data.push_back(child); + *res = arrow::Datum(newArrayData); + } else if (!IsOptional || !array->buffers[0]) { + *res = arrow::Datum(child); + } else { + auto newArrayData = child->Copy(); + if (!newArrayData->buffers[0]) { + newArrayData->buffers[0] = array->buffers[0]; + } else { + auto buffer = AllocateBitmapWithReserve(array->length + array->offset, ctx->memory_pool()); + arrow::internal::BitmapAnd(child->GetValues<uint8_t>(0, 0), child->offset, array->GetValues<uint8_t>(0, 0), array->offset, array->length, array->offset, buffer->mutable_data()); + newArrayData->buffers[0] = buffer; + } + + newArrayData->SetNullCount(arrow::kUnknownNullCount); + *res = arrow::Datum(newArrayData); + } + } + + return arrow::Status::OK(); + } + +private: + const std::shared_ptr<arrow::DataType> ReturnArrowType; + const ui32 Index; + const bool IsOptional; + const bool NeedExternalOptional; +}; + +std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockAsTupleKernel(const TVector<TType*>& argTypes, TType* resultType) { + std::shared_ptr<arrow::DataType> returnArrowType; + MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type"); + auto exec = std::make_shared<TBlockAsTupleExec>(argTypes, returnArrowType); + auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType), + [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) { + return exec->Exec(ctx, batch, res); + }); + + kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; + return kernel; +} + +std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockNthKernel(const TVector<TType*>& argTypes, TType* resultType, ui32 index, + bool isOptional, bool needExternalOptional) { + std::shared_ptr<arrow::DataType> returnArrowType; + MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type"); + auto exec = std::make_shared<TBlockNthExec>(returnArrowType, index, isOptional, needExternalOptional); + auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType), + [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) { + return exec->Exec(ctx, batch, res); + }); + + kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; + return kernel; +} + +} // namespace + +IComputationNode* WrapBlockAsTuple(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + TVector<IComputationNode*> argsNodes; + TVector<TType*> argsTypes; + for (ui32 i = 0; i < callable.GetInputsCount(); ++i) { + argsNodes.push_back(LocateNode(ctx.NodeLocator, callable, i)); + argsTypes.push_back(callable.GetInput(i).GetStaticType()); + } + + auto kernel = MakeBlockAsTupleKernel(argsTypes, callable.GetType()->GetReturnType()); + return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel); +} + +IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 2U, "Expected two args."); + auto input = callable.GetInput(0U); + auto blockType = AS_TYPE(TBlockType, input.GetStaticType()); + bool isOptional; + auto tupleType = AS_TYPE(TTupleType, UnpackOptional(blockType->GetItemType(), isOptional)); + auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1U)); + auto index = indexData->AsValue().Get<ui32>(); + MKQL_ENSURE(index < tupleType->GetElementsCount(), "Bad tuple index"); + auto childType = tupleType->GetElementType(index); + bool needExternalOptional = isOptional && childType->IsVariant(); + + auto tuple = LocateNode(ctx.NodeLocator, callable, 0); + + TVector<IComputationNode*> argsNodes = { tuple }; + TVector<TType*> argsTypes = { blockType }; + auto kernel = MakeBlockNthKernel(argsTypes, callable.GetType()->GetReturnType(), index, isOptional, needExternalOptional); + return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel); +} + +} +} diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.h new file mode 100644 index 0000000000..06e60f6e28 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.h @@ -0,0 +1,11 @@ +#pragma once +#include <ydb/library/yql/minikql/computation/mkql_computation_node.h> + +namespace NKikimr { +namespace NMiniKQL { + +IComputationNode* WrapBlockAsTuple(TCallable& callable, const TComputationNodeFactoryContext& ctx); +IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactoryContext& ctx); + +} +} diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index 7b51fb5398..53296f977b 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -14,6 +14,7 @@ #include "mkql_block_compress.h" #include "mkql_block_skiptake.h" #include "mkql_block_top.h" +#include "mkql_block_tuple.h" #include "mkql_callable.h" #include "mkql_chain_map.h" #include "mkql_chain1_map.h" @@ -291,6 +292,8 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"BlockNot", &WrapBlockNot}, {"BlockJust", &WrapBlockJust}, {"BlockCompress", &WrapBlockCompress}, + {"BlockAsTuple", &WrapBlockAsTuple}, + {"BlockNth", &WrapBlockNth}, {"BlockExpandChunked", &WrapBlockExpandChunked}, {"BlockCombineAll", &WrapBlockCombineAll}, {"BlockCombineHashed", &WrapBlockCombineHashed}, diff --git a/ydb/library/yql/minikql/comp_nodes/ya.make b/ydb/library/yql/minikql/comp_nodes/ya.make index 5bd47ac0d4..34648333d5 100644 --- a/ydb/library/yql/minikql/comp_nodes/ya.make +++ b/ydb/library/yql/minikql/comp_nodes/ya.make @@ -41,6 +41,8 @@ SRCS( mkql_block_skiptake.h mkql_block_top.cpp mkql_block_top.h + mkql_block_tuple.cpp + mkql_block_tuple.h mkql_blocks.cpp mkql_blocks.h mkql_callable.cpp diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index ee264e8a17..84a9071c86 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -1553,6 +1553,48 @@ TRuntimeNode TProgramBuilder::BlockCoalesce(TRuntimeNode first, TRuntimeNode sec return TRuntimeNode(callableBuilder.Build(), false); } +TRuntimeNode TProgramBuilder::BlockNth(TRuntimeNode tuple, ui32 index) { + auto blockType = AS_TYPE(TBlockType, tuple.GetStaticType()); + bool isOptional; + const auto type = AS_TYPE(TTupleType, UnpackOptional(blockType->GetItemType(), isOptional)); + + MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index << + " is not less than " << type->GetElementsCount()); + auto itemType = type->GetElementType(index); + if (isOptional && !itemType->IsOptional() && !itemType->IsNull() && !itemType->IsPg()) { + itemType = TOptionalType::Create(itemType, Env); + } + + auto returnType = NewBlockType(itemType, blockType->GetShape()); + TCallableBuilder callableBuilder(Env, __func__, returnType); + callableBuilder.Add(tuple); + callableBuilder.Add(NewDataLiteral<ui32>(index)); + return TRuntimeNode(callableBuilder.Build(), false); +} + +TRuntimeNode TProgramBuilder::BlockAsTuple(const TArrayRef<const TRuntimeNode>& args) { + MKQL_ENSURE(!args.empty(), "Expected at least one argument"); + + TBlockType::EShape resultShape = TBlockType::EShape::Scalar; + TVector<TType*> types; + for (const auto& x : args) { + auto blockType = AS_TYPE(TBlockType, x.GetStaticType()); + types.push_back(blockType->GetItemType()); + if (blockType->GetShape() == TBlockType::EShape::Many) { + resultShape = TBlockType::EShape::Many; + } + } + + auto tupleType = NewTupleType(types); + auto returnType = NewBlockType(tupleType, resultShape); + TCallableBuilder callableBuilder(Env, __func__, returnType); + for (const auto& x : args) { + callableBuilder.Add(x); + } + + return TRuntimeNode(callableBuilder.Build(), false); +} + TRuntimeNode TProgramBuilder::BlockNot(TRuntimeNode data) { auto dataType = AS_TYPE(TBlockType, data.GetStaticType()); diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index 5c6348bd2b..a54bcf14b9 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -257,6 +257,8 @@ public: TRuntimeNode BlockCompress(TRuntimeNode flow, ui32 bitmapIndex); TRuntimeNode BlockExpandChunked(TRuntimeNode flow); TRuntimeNode BlockCoalesce(TRuntimeNode first, TRuntimeNode second); + TRuntimeNode BlockNth(TRuntimeNode tuple, ui32 index); + TRuntimeNode BlockAsTuple(const TArrayRef<const TRuntimeNode>& args); //-- logical functions TRuntimeNode BlockNot(TRuntimeNode data); diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index 4847165b44..7ec16e840a 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -2464,6 +2464,21 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return ctx.ProgramBuilder.BlockBitCast(arg, targetType); }); + AddCallable("BlockNth", [](const TExprNode& node, TMkqlBuildContext& ctx) { + const auto tupleObj = MkqlBuildExpr(node.Head(), ctx); + const auto index = FromString<ui32>(node.Tail().Content()); + return ctx.ProgramBuilder.BlockNth(tupleObj, index); + }); + + AddCallable("BlockAsTuple", [](const TExprNode& node, TMkqlBuildContext& ctx) { + TVector<TRuntimeNode> args; + for (const auto& x : node.Children()) { + args.push_back(MkqlBuildExpr(*x, ctx)); + } + + return ctx.ProgramBuilder.BlockAsTuple(args); + }); + AddCallable("BlockCombineAll", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto arg = MkqlBuildExpr(*node.Child(0), ctx); std::optional<ui32> filterColumn; |