author | aneporada <aneporada@ydb.tech> | 2023-01-14 15:35:20 +0300
---|---|---
committer | aneporada <aneporada@ydb.tech> | 2023-01-14 15:35:20 +0300
commit | d95441a3c516b3781878af847751cf9d143af91d |
tree | 8871fb639ae8a7856e00538214f7099c29f85ae9 |
parent | 3657be5988251fc9074ba5b86b62bfa985ff4643 |
download | ydb-d95441a3c516b3781878af847751cf9d143af91d.tar.gz |
Implement BlockIf
23 files changed, 585 insertions, 223 deletions
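The new kernel in ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp (TIfBlockExec, shown in the diff below) merges the then/else values per row without branching: the boolean predicate is widened to an all-ones or all-zeros 64-bit mask and the two halves of the block item are blended with bitwise operations. The sketch below restates that trick as a self-contained example; Item and SelectItem are illustrative names, not the actual YDB types.

```cpp
// Minimal standalone sketch (not the YDB code): the branch-free select used by
// TIfBlockExec. A bool predicate is widened to an all-ones/all-zeros mask and
// a 128-bit item is merged from the two branches with bitwise ops.
#include <cassert>
#include <cstdint>

struct Item {            // stand-in for TBlockItem's two 64-bit halves
    uint64_t Low;
    uint64_t High;
};

Item SelectItem(bool pred, Item thenItem, Item elseItem) {
    const uint64_t mask = -static_cast<uint64_t>(pred);  // true -> 0xFFFF..., false -> 0
    return Item{
        (thenItem.Low  & mask) | (elseItem.Low  & ~mask),
        (thenItem.High & mask) | (elseItem.High & ~mask),
    };
}

int main() {
    Item a{1, 2}, b{3, 4};
    assert(SelectItem(true,  a, b).Low == 1);   // picks the "then" item
    assert(SelectItem(false, a, b).High == 4);  // picks the "else" item
}
```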
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 63a58e9fdb..e377d1592b 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -4855,7 +4855,7 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo TExprNode::TListType funcArgs; std::string_view arrowFunctionName; - if (node->IsCallable({"And", "Or", "Xor", "Not", "Coalesce"})) + if (node->IsCallable({"And", "Or", "Xor", "Not", "Coalesce", "If"})) { for (auto& child : node->ChildrenList()) { if (child->IsComplete()) { @@ -4868,9 +4868,8 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo } TString blockFuncName = TString("Block") + node->Content(); - if (funcArgs.size() > 2) { + if (node->IsCallable({"And", "Or", "Xor"}) && funcArgs.size() > 2) { // Split original argument list by pairs (since the order is not important balanced tree is used) - // this is only supported by And/Or/Xor rewrites[node.Get()] = SplitByPairs(node->Pos(), blockFuncName, funcArgs, 0, funcArgs.size(), ctx); } else { rewrites[node.Get()] = ctx.NewCallable(node->Pos(), blockFuncName, std::move(funcArgs)); diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index 8800570e23..9984ed3766 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -189,6 +189,49 @@ IGraphTransformer::TStatus BlockLogicalWrapper(const TExprNode::TPtr& input, TEx return IGraphTransformer::TStatus::Ok; } +IGraphTransformer::TStatus BlockIfWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); + if (!EnsureArgsCount(*input, 3U, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto pred = input->Child(0); + auto thenNode = input->Child(1); + auto elseNode = input->Child(2); + + if (!EnsureBlockOrScalarType(*pred, ctx.Expr) || + !EnsureBlockOrScalarType(*thenNode, ctx.Expr) || + !EnsureBlockOrScalarType(*elseNode, ctx.Expr)) + { + return IGraphTransformer::TStatus::Error; + } + + bool predIsScalar; + const TTypeAnnotationNode* predItemType = GetBlockItemType(*pred->GetTypeAnn(), predIsScalar); + if (!EnsureSpecificDataType(pred->Pos(), *predItemType, NUdf::EDataSlot::Bool, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + bool thenIsScalar; + const TTypeAnnotationNode* thenItemType = GetBlockItemType(*thenNode->GetTypeAnn(), thenIsScalar); + + bool elseIsScalar; + const TTypeAnnotationNode* elseItemType = GetBlockItemType(*elseNode->GetTypeAnn(), elseIsScalar); + + if (!IsSameAnnotation(*thenItemType, *elseItemType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() << + "Mismatch item types: then branch is " << *thenItemType << ", else branch is " << *elseItemType)); + return IGraphTransformer::TStatus::Error; + } + + if (predIsScalar && thenIsScalar && elseIsScalar) { + input->SetTypeAnn(ctx.Expr.MakeType<TScalarExprType>(thenItemType)); + } else { + input->SetTypeAnn(ctx.Expr.MakeType<TBlockExprType>(thenItemType)); + } + return IGraphTransformer::TStatus::Ok; +} + IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { Y_UNUSED(output); if (!EnsureMinArgsCount(*input, 2U, ctx.Expr)) { diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.h 
b/ydb/library/yql/core/type_ann/type_ann_blocks.h index 9e2364dacf..07fed8ae83 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.h +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.h @@ -13,6 +13,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus BlockExpandChunkedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus BlockCoalesceWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus BlockLogicalWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus BlockIfWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus BlockCombineAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 11a259c61d..d54270154a 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11840,6 +11840,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["BlockOr"] = &BlockLogicalWrapper; Functions["BlockXor"] = &BlockLogicalWrapper; Functions["BlockNot"] = &BlockLogicalWrapper; + Functions["BlockIf"] = &BlockIfWrapper; ExtFunctions["BlockFunc"] = &BlockFuncWrapper; ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper; diff --git a/ydb/library/yql/minikql/arrow/arrow_util.cpp b/ydb/library/yql/minikql/arrow/arrow_util.cpp index 37ded548e4..ff8015734c 100644 --- a/ydb/library/yql/minikql/arrow/arrow_util.cpp +++ b/ydb/library/yql/minikql/arrow/arrow_util.cpp @@ -1,5 +1,9 @@ #include "arrow_util.h" #include "mkql_bit_utils.h" + +#include <arrow/array/array_base.h> +#include <arrow/chunked_array.h> + #include <ydb/library/yql/minikql/mkql_node_builder.h> #include <util/system/yassert.h> @@ -67,4 +71,30 @@ std::shared_ptr<arrow::Buffer> MakeDenseBitmap(const ui8* srcSparse, size_t len, return bitmap; } +void ForEachArrayData(const arrow::Datum& datum, const std::function<void(const std::shared_ptr<arrow::ArrayData>&)>& func) { + MKQL_ENSURE(datum.is_arraylike(), "Expected array"); + if (datum.is_array()) { + func(datum.array()); + } else { + for (auto& chunk : datum.chunks()) { + func(chunk->data()); + } + } +} + +arrow::Datum MakeArray(const TVector<std::shared_ptr<arrow::ArrayData>>& chunks) { + MKQL_ENSURE(!chunks.empty(), "Expected non empty chunks"); + arrow::ArrayVector resultChunks; + for (auto& chunk : chunks) { + resultChunks.push_back(arrow::Datum(chunk).make_array()); + } + + if (resultChunks.size() > 1) { + auto type = resultChunks.front()->type(); + auto chunked = ARROW_RESULT(arrow::ChunkedArray::Make(std::move(resultChunks), type)); + return arrow::Datum(chunked); + } + return arrow::Datum(resultChunks.front()); +} + } diff --git a/ydb/library/yql/minikql/arrow/arrow_util.h b/ydb/library/yql/minikql/arrow/arrow_util.h index 005b3f2938..86e0890732 100644 --- a/ydb/library/yql/minikql/arrow/arrow_util.h +++ b/ydb/library/yql/minikql/arrow/arrow_util.h @@ -28,6 +28,9 @@ inline arrow::internal::Bitmap GetBitmap(const arrow::ArrayData& arr, int index) return arrow::internal::Bitmap{ arr.buffers[index], arr.offset, arr.length }; } +void 
ForEachArrayData(const arrow::Datum& datum, const std::function<void(const std::shared_ptr<arrow::ArrayData>&)>& func); +arrow::Datum MakeArray(const TVector<std::shared_ptr<arrow::ArrayData>>& chunks); + template <typename T> T GetPrimitiveScalarValue(const arrow::Scalar& scalar) { return *static_cast<const T*>(dynamic_cast<const arrow::internal::PrimitiveScalarBase&>(scalar).data()); diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt index 2e6b5588c3..6cc5135bdb 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt @@ -42,6 +42,8 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_coalesce.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_logical.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt index 303a6c6b5e..87be655fd0 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt @@ -43,6 +43,8 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_coalesce.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_logical.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt index 303a6c6b5e..87be655fd0 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt @@ -43,6 +43,8 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_coalesce.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_logical.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp index a94d01c8b0..08fdd91649 100644 --- 
a/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp @@ -17,29 +17,6 @@ namespace NMiniKQL { namespace { -bool AlwaysUseChunks(const TType* type) { - if (type->IsOptional()) { - return AlwaysUseChunks(AS_TYPE(TOptionalType, type)->GetItemType()); - } - - if (type->IsTuple()) { - auto tupleType = AS_TYPE(TTupleType, type); - for (ui32 i = 0; i < tupleType->GetElementsCount(); ++i) { - if (AlwaysUseChunks(tupleType->GetElementType(i))) { - return true; - } - } - return false; - } - - if (type->IsData()) { - auto slot = *AS_TYPE(TDataType, type)->GetDataSlot(); - return (GetDataTypeInfo(slot).Features & NYql::NUdf::EDataTypeFeatures::StringType) != 0u; - } - - MKQL_ENSURE(false, "Unsupported type"); -} - std::shared_ptr<arrow::DataType> GetArrowType(TType* type) { std::shared_ptr<arrow::DataType> result; Y_VERIFY(ConvertArrowType(type, result)); @@ -101,21 +78,14 @@ public: CurrLen += popCount; } - NUdf::TUnboxedValuePod Build(TComputationContext& ctx, bool finish) final { + arrow::Datum Build(bool finish) final { auto tree = BuildTree(finish); - arrow::ArrayVector chunks; + TVector<std::shared_ptr<arrow::ArrayData>> chunks; while (size_t size = CalcSliceSize(*tree)) { - std::shared_ptr<arrow::ArrayData> data = Slice(*tree, size); - chunks.push_back(arrow::Datum(data).make_array()); + chunks.push_back(Slice(*tree, size)); } - Y_VERIFY(!chunks.empty()); - - if (chunks.size() > 1 || AlwaysUseChunks(Type)) { - auto chunked = ARROW_RESULT(arrow::ChunkedArray::Make(std::move(chunks), GetArrowType(Type))); - return ctx.HolderFactory.CreateArrowBlock(std::move(chunked)); - } - return ctx.HolderFactory.CreateArrowBlock(chunks.front()); + return MakeArray(chunks); } TBlockArrayTree::Ptr BuildTree(bool finish) { diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.h index 4769c3d59f..3a937d6dd7 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.h @@ -31,7 +31,7 @@ public: virtual void Add(NUdf::TUnboxedValuePod value) = 0; virtual void Add(TBlockItem value) = 0; virtual void AddMany(const arrow::ArrayData& array, size_t popCount, const ui8* sparseBitmap, size_t bitmapSize) = 0; - virtual NUdf::TUnboxedValuePod Build(TComputationContext& ctx, bool finish) = 0; + virtual arrow::Datum Build(bool finish) = 0; }; std::unique_ptr<IBlockBuilder> MakeBlockBuilder(TType* type, arrow::MemoryPool& pool, size_t maxBlockLength); diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp index 68801afd6d..34a5d098a8 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp @@ -315,7 +315,7 @@ private: for (ui32 i = 0, outIndex = 0; i < Width_; ++i) { bool isScalar = Types_[i]->GetShape() == TBlockType::EShape::Scalar; if (i != BitmapIndex_ && output[outIndex]) { - *output[outIndex] = isScalar ? s.InputValues_[i] : s.Builders_[i]->Build(ctx, s.Finish_); + *output[outIndex] = isScalar ? 
s.InputValues_[i] : ctx.HolderFactory.CreateArrowBlock(s.Builders_[i]->Build(s.Finish_)); } if (i != BitmapIndex_) { outIndex++; diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp index d8b4aa8b43..d692a30464 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp @@ -1,47 +1,18 @@ #include "mkql_block_func.h" +#include "mkql_block_impl.h" #include <ydb/library/yql/minikql/arrow/arrow_defs.h> -#include <ydb/library/yql/minikql/arrow/mkql_functions.h> -#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> -#include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h> #include <ydb/library/yql/minikql/mkql_node_builder.h> #include <ydb/library/yql/minikql/mkql_node_cast.h> #include <ydb/library/yql/minikql/mkql_type_builder.h> -#include <arrow/array/builder_primitive.h> #include <arrow/compute/cast.h> -#include <arrow/compute/exec_internal.h> -#include <arrow/compute/function.h> -#include <arrow/compute/kernel.h> -#include <arrow/compute/registry.h> -#include <arrow/util/bit_util.h> namespace NKikimr { namespace NMiniKQL { namespace { -arrow::ValueDescr ToValueDescr(TType* type) { - arrow::ValueDescr ret; - MKQL_ENSURE(ConvertInputArrowType(type, ret), "can't get arrow type"); - return ret; -} - -std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) { - std::vector<arrow::ValueDescr> res; - res.reserve(types.size()); - for (const auto& type : types) { - res.emplace_back(ToValueDescr(type)); - } - - return res; -} - -const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function& function, const std::vector<arrow::ValueDescr>& args) { - const auto kernel = ARROW_RESULT(function.DispatchExact(args)); - return *static_cast<const arrow::compute::ScalarKernel*>(kernel); -} - const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes, TType* returnType) { std::vector<NUdf::TDataTypeId> argTypes; for (const auto& t : inputTypes) { @@ -64,166 +35,31 @@ const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TSt return *kernel; } -struct TState : public TComputationValue<TState> { - using TComputationValue::TComputationValue; - - TState(TMemoryUsageInfo* memInfo, const arrow::compute::FunctionOptions* options, - const arrow::compute::ScalarKernel& kernel, - const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx) - : TComputationValue(memInfo) - , Options(options) - , ExecContext(&ctx.ArrowMemoryPool, nullptr, nullptr) - , KernelContext(&ExecContext) - { - if (kernel.init) { - State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options })); - KernelContext.SetState(State.get()); - } - - Values.reserve(argsValuesDescr.size()); - } - - const arrow::compute::FunctionOptions* Options; - arrow::compute::ExecContext ExecContext; - arrow::compute::KernelContext KernelContext; - std::unique_ptr<arrow::compute::KernelState> State; - - std::vector<arrow::Datum> Values; -}; - -class TBlockFuncWrapper : public TMutableComputationNode<TBlockFuncWrapper> { -public: - TBlockFuncWrapper(TComputationMutables& mutables, - const IBuiltinFunctionRegistry& builtins, - const TString& funcName, - TVector<IComputationNode*>&& argsNodes, - TVector<TType*>&& argsTypes, - TType* returnType) - : TMutableComputationNode(mutables) - , 
StateIndex(mutables.CurValueIndex++) - , FuncName(funcName) - , ArgsNodes(std::move(argsNodes)) - , ArgsTypes(std::move(argsTypes)) - , ArgsValuesDescr(ToValueDescr(ArgsTypes)) - , Kernel(ResolveKernel(builtins, FuncName, ArgsTypes, returnType)) - { - } - - NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - auto& state = GetState(ctx); - - state.Values.clear(); - for (ui32 i = 0; i < ArgsNodes.size(); ++i) { - state.Values.emplace_back(TArrowBlock::From(ArgsNodes[i]->GetValue(ctx)).GetDatum()); - Y_VERIFY_DEBUG(ArgsValuesDescr[i] == state.Values.back().descr()); - } - - auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>(); - auto executor = arrow::compute::detail::KernelExecutor::MakeScalar(); - ARROW_OK(executor->Init(&state.KernelContext, { &Kernel.GetArrowKernel(), ArgsValuesDescr, state.Options })); - ARROW_OK(executor->Execute(state.Values, listener.get())); - auto output = executor->WrapResults(state.Values, listener->values()); - return ctx.HolderFactory.CreateArrowBlock(std::move(output)); - } - -private: - void RegisterDependencies() const final { - for (const auto& arg : ArgsNodes) { - this->DependsOn(arg); - } - } - - static const arrow::compute::Function& ResolveFunction(const arrow::compute::FunctionRegistry& registry, const TString& funcName) { - auto function = ARROW_RESULT(registry.GetFunction(funcName)); - MKQL_ENSURE(function != nullptr, "missing function"); - MKQL_ENSURE(function->kind() == arrow::compute::Function::SCALAR, "expected SCALAR function"); - return *function; - } - - TState& GetState(TComputationContext& ctx) const { - auto& result = ctx.MutableValues[StateIndex]; - if (!result.HasValue()) { - result = ctx.HolderFactory.Create<TState>(Kernel.Family.FunctionOptions, Kernel.GetArrowKernel(), ArgsValuesDescr, ctx); - } - - return *static_cast<TState*>(result.AsBoxed().Get()); - } - -private: - const ui32 StateIndex; - const TString FuncName; - const TVector<IComputationNode*> ArgsNodes; - const TVector<TType*> ArgsTypes; - - const std::vector<arrow::ValueDescr> ArgsValuesDescr; - const TKernel& Kernel; -}; - -class TBlockBitCastWrapper : public TMutableComputationNode<TBlockBitCastWrapper> { +class TBlockBitCastWrapper : public TBlockFuncNode { public: - TBlockBitCastWrapper(TComputationMutables& mutables, - IComputationNode* arg, - TType* argType, - TType* to) - : TMutableComputationNode(mutables) - , StateIndex(mutables.CurValueIndex++) - , Arg(arg) - , ArgsValuesDescr({ ToValueDescr(argType) }) - , Function(ResolveFunction(to)) - , Kernel(ResolveKernel(Function, ArgsValuesDescr)) + TBlockBitCastWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* argType, TType* to) + : TBlockFuncNode(mutables, { arg }, { argType }, ResolveKernel(argType, to), {}, &CastOptions) , CastOptions(false) { } - - NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - auto& state = GetState(ctx); - - state.Values.clear(); - state.Values.emplace_back(TArrowBlock::From(Arg->GetValue(ctx)).GetDatum()); - Y_VERIFY_DEBUG(ArgsValuesDescr[0] == state.Values.back().descr()); - - auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>(); - auto executor = arrow::compute::detail::KernelExecutor::MakeScalar(); - ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, state.Options })); - ARROW_OK(executor->Execute(state.Values, listener.get())); - auto output = executor->WrapResults(state.Values, listener->values()); - return ctx.HolderFactory.CreateArrowBlock(std::move(output)); - 
} - private: - void RegisterDependencies() const final { - this->DependsOn(Arg); - } - - static const arrow::compute::Function& ResolveFunction(TType* to) { + static const arrow::compute::ScalarKernel& ResolveKernel(TType* from, TType* to) { std::shared_ptr<arrow::DataType> type; MKQL_ENSURE(ConvertArrowType(to, type), "can't get arrow type"); auto function = ARROW_RESULT(arrow::compute::GetCastFunction(type)); MKQL_ENSURE(function != nullptr, "missing function"); MKQL_ENSURE(function->kind() == arrow::compute::Function::SCALAR, "expected SCALAR function"); - return *function; - } - - TState& GetState(TComputationContext& ctx) const { - auto& result = ctx.MutableValues[StateIndex]; - if (!result.HasValue()) { - result = ctx.HolderFactory.Create<TState>((const arrow::compute::FunctionOptions*)&CastOptions, Kernel, ArgsValuesDescr, ctx); - } - return *static_cast<TState*>(result.AsBoxed().Get()); + std::vector<arrow::ValueDescr> args = { ToValueDescr(from) }; + const auto kernel = ARROW_RESULT(function->DispatchExact(args)); + return *static_cast<const arrow::compute::ScalarKernel*>(kernel); } -private: - const ui32 StateIndex; - IComputationNode* Arg; - const std::vector<arrow::ValueDescr> ArgsValuesDescr; - const arrow::compute::Function& Function; - const arrow::compute::ScalarKernel& Kernel; - arrow::compute::CastOptions CastOptions; + const arrow::compute::CastOptions CastOptions; }; -} +} // namespace IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() >= 1, "Expected at least 1 arg"); @@ -237,13 +73,8 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto argsTypes.push_back(callableType->GetArgumentType(i)); } - return new TBlockFuncWrapper(ctx.Mutables, - *ctx.FunctionRegistry.GetBuiltins(), - funcName, - std::move(argsNodes), - std::move(argsTypes), - callableType->GetReturnType() - ); + const TKernel& kernel = ResolveKernel(*ctx.FunctionRegistry.GetBuiltins(), funcName, argsTypes, callableType->GetReturnType()); + return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, kernel.GetArrowKernel(), {}, kernel.Family.FunctionOptions); } IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFactoryContext& ctx) { diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp new file mode 100644 index 0000000000..52084e3ae6 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp @@ -0,0 +1,198 @@ +#include "mkql_block_if.h" +#include "mkql_block_impl.h" +#include "mkql_block_reader.h" +#include "mkql_block_builder.h" + +#include <ydb/library/yql/minikql/arrow/arrow_defs.h> +#include <ydb/library/yql/minikql/arrow/arrow_util.h> +#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> +#include <ydb/library/yql/minikql/mkql_node_cast.h> + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +class TBlockIfScalarWrapper : public TMutableComputationNode<TBlockIfScalarWrapper> { +public: + TBlockIfScalarWrapper(TComputationMutables& mutables, IComputationNode* pred, IComputationNode* thenNode, IComputationNode* elseNode, TType* resultType, + bool thenIsScalar, bool elseIsScalar) + : TMutableComputationNode(mutables) + , Pred(pred) + , Then(thenNode) + , Else(elseNode) + , Type(resultType) + , ThenIsScalar(thenIsScalar) + , ElseIsScalar(elseIsScalar) + { + } + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + auto 
predValue = Pred->GetValue(ctx); + + const bool predScalarValue = GetPrimitiveScalarValue<bool>(*TArrowBlock::From(predValue).GetDatum().scalar()); + auto result = predScalarValue ? Then->GetValue(ctx) : Else->GetValue(ctx); + + if (ThenIsScalar == ElseIsScalar || (predScalarValue ? !ThenIsScalar : !ElseIsScalar)) { + // can return result as-is + return result.Release(); + } + + auto other = predScalarValue ? Else->GetValue(ctx) : Then->GetValue(ctx); + const auto& otherDatum = TArrowBlock::From(other).GetDatum(); + MKQL_ENSURE(otherDatum.is_arraylike(), "Expecting array"); + + std::shared_ptr<arrow::Scalar> resultScalar = TArrowBlock::From(result).GetDatum().scalar(); + + TVector<std::shared_ptr<arrow::ArrayData>> resultArrays; + ForEachArrayData(otherDatum, [&](const std::shared_ptr<arrow::ArrayData>& otherData) { + auto chunk = MakeArrayFromScalar(*resultScalar, otherData->length, Type, ctx.ArrowMemoryPool); + ForEachArrayData(chunk, [&](const auto& array) { + resultArrays.push_back(array); + }); + }); + return ctx.HolderFactory.CreateArrowBlock(MakeArray(resultArrays)); + } +private: + void RegisterDependencies() const final { + DependsOn(Pred); + DependsOn(Then); + DependsOn(Else); + } + + IComputationNode* const Pred; + IComputationNode* const Then; + IComputationNode* const Else; + TType* const Type; + const bool ThenIsScalar; + const bool ElseIsScalar; +}; + +template<bool ThenIsScalar, bool ElseIsScalar> +class TIfBlockExec { +public: + explicit TIfBlockExec(TType* type) + : Type(type) + , ThenReader(MakeBlockReader(type)) + , ElseReader(MakeBlockReader(type)) + { + } + + arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const { + arrow::Datum predDatum = batch.values[0]; + arrow::Datum thenDatum = batch.values[1]; + arrow::Datum elseDatum = batch.values[2]; + + TBlockItem thenItem; + const arrow::ArrayData* thenArray = nullptr; + if constexpr(ThenIsScalar) { + thenItem = ThenReader->GetScalarItem(*thenDatum.scalar()); + } else { + MKQL_ENSURE(thenDatum.is_array(), "Expecting array"); + thenArray = thenDatum.array().get(); + } + + TBlockItem elseItem; + const arrow::ArrayData* elseArray = nullptr; + if constexpr(ElseIsScalar) { + elseItem = ElseReader->GetScalarItem(*elseDatum.scalar()); + } else { + MKQL_ENSURE(elseDatum.is_array(), "Expecting array"); + elseArray = elseDatum.array().get(); + } + + MKQL_ENSURE(predDatum.is_array(), "Expecting array"); + const std::shared_ptr<arrow::ArrayData>& pred = predDatum.array(); + + const size_t len = pred->length; + auto builder = MakeBlockBuilder(Type, *ctx->memory_pool(), len); + const ui8* predValues = pred->GetValues<uint8_t>(1); + for (size_t i = 0; i < len; ++i) { + if constexpr (!ThenIsScalar) { + thenItem = ThenReader->GetItem(*thenArray, i); + } + if constexpr (!ElseIsScalar) { + elseItem = ElseReader->GetItem(*elseArray, i); + } + + ui64 mask = -ui64(predValues[i]); + + TBlockItem result; + ui64 low = (thenItem.Low() & mask) | (elseItem.Low() & ~mask); + ui64 high = (thenItem.High() & mask) | (elseItem.High() & ~mask); + builder->Add(TBlockItem{low, high}); + } + *res = builder->Build(true); + return arrow::Status::OK(); + } + +private: + const std::unique_ptr<IBlockReader> ThenReader; + const std::unique_ptr<IBlockReader> ElseReader; + TType* const Type; +}; + + +template<bool ThenIsScalar, bool ElseIsScalar> +std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockIfKernel(const TVector<TType*>& argTypes, TType* resultType) { + using TExec = 
TIfBlockExec<ThenIsScalar, ElseIsScalar>; + + auto exec = std::make_shared<TExec>(AS_TYPE(TBlockType, resultType)->GetItemType()); + auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType), + [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) { + return exec->Exec(ctx, batch, res); + }); + + kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; + return kernel; +} + +} // namespace + +IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args"); + + auto pred = callable.GetInput(0); + auto thenNode = callable.GetInput(1); + auto elseNode = callable.GetInput(2); + + auto predType = AS_TYPE(TBlockType, pred.GetStaticType()); + MKQL_ENSURE(AS_TYPE(TDataType, predType->GetItemType())->GetSchemeType() == NUdf::TDataType<bool>::Id, + "Expected bool as first argument"); + + auto thenType = AS_TYPE(TBlockType, thenNode.GetStaticType()); + auto elseType = AS_TYPE(TBlockType, elseNode.GetStaticType()); + MKQL_ENSURE(thenType->GetItemType()->IsSameType(*elseType->GetItemType()), "Different return types in branches."); + + auto predCompute = LocateNode(ctx.NodeLocator, callable, 0); + auto thenCompute = LocateNode(ctx.NodeLocator, callable, 1); + auto elseCompute = LocateNode(ctx.NodeLocator, callable, 2); + + bool predIsScalar = predType->GetShape() == TBlockType::EShape::Scalar; + bool thenIsScalar = thenType->GetShape() == TBlockType::EShape::Scalar; + bool elseIsScalar = elseType->GetShape() == TBlockType::EShape::Scalar; + + if (predIsScalar) { + return new TBlockIfScalarWrapper(ctx.Mutables, predCompute, thenCompute, elseCompute, thenType->GetItemType(), + thenIsScalar, elseIsScalar); + } + + TVector<IComputationNode*> argsNodes = { predCompute, thenCompute, elseCompute }; + TVector<TType*> argsTypes = { predType, thenType, elseType }; + + std::shared_ptr<arrow::compute::ScalarKernel> kernel; + if (thenIsScalar && elseIsScalar) { + kernel = MakeBlockIfKernel<true, true>(argsTypes, thenType); + } else if (thenIsScalar && !elseIsScalar) { + kernel = MakeBlockIfKernel<true, false>(argsTypes, thenType); + } else if (!thenIsScalar && elseIsScalar) { + kernel = MakeBlockIfKernel<false, true>(argsTypes, thenType); + } else { + kernel = MakeBlockIfKernel<false, false>(argsTypes, thenType); + } + + return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel); +} + +} +} diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_if.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.h new file mode 100644 index 0000000000..62fc88c2a2 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.h @@ -0,0 +1,10 @@ +#pragma once +#include <ydb/library/yql/minikql/computation/mkql_computation_node.h> + +namespace NKikimr { +namespace NMiniKQL { + +IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactoryContext& ctx); + +} +} diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp new file mode 100644 index 0000000000..f0b709d949 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp @@ -0,0 +1,183 @@ +#include "mkql_block_impl.h" +#include "mkql_block_builder.h" +#include "mkql_block_reader.h" + +#include <ydb/library/yql/minikql/arrow/mkql_functions.h> +#include <ydb/library/yql/minikql/mkql_node_builder.h> +#include 
<ydb/library/yql/minikql/arrow/arrow_util.h> + +#include <arrow/compute/exec_internal.h> + +namespace NKikimr::NMiniKQL { + +namespace { + +class TArgsDechunker { +public: + explicit TArgsDechunker(std::vector<arrow::Datum>&& args) + : Args(std::move(args)) + , Arrays(Args.size()) + { + for (size_t i = 0; i < Args.size(); ++i) { + if (Args[i].is_arraylike()) { + ForEachArrayData(Args[i], [&](const auto& data) { + Arrays[i].push_back(data); + }); + } + } + } + + bool Next(std::vector<arrow::Datum>& chunk) { + if (Finish) { + return false; + } + + size_t minSize = Max<size_t>(); + bool haveData = false; + chunk.resize(Args.size()); + for (size_t i = 0; i < Args.size(); ++i) { + if (Args[i].is_scalar()) { + chunk[i] = Args[i]; + continue; + } + while (!Arrays[i].empty() && Arrays[i].front()->length == 0) { + Arrays[i].pop_front(); + } + if (!Arrays[i].empty()) { + haveData = true; + minSize = std::min<size_t>(minSize, Arrays[i].front()->length); + } else { + minSize = 0; + } + } + + MKQL_ENSURE(!haveData || minSize > 0, "Block length mismatch"); + if (!haveData) { + Finish = true; + return false; + } + + for (size_t i = 0; i < Args.size(); ++i) { + if (!Args[i].is_scalar()) { + MKQL_ENSURE(!Arrays[i].empty(), "Block length mismatch"); + chunk[i] = arrow::Datum(Chop(Arrays[i].front(), minSize)); + } + } + return true; + } +private: + const std::vector<arrow::Datum> Args; + std::vector<std::deque<std::shared_ptr<arrow::ArrayData>>> Arrays; + bool Finish = false; +}; + +std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) { + std::vector<arrow::ValueDescr> res; + res.reserve(types.size()); + for (const auto& type : types) { + res.emplace_back(ToValueDescr(type)); + } + + return res; +} + +} // namespace + +arrow::Datum MakeArrayFromScalar(const arrow::Scalar& scalar, size_t len, TType* type, arrow::MemoryPool& pool) { + MKQL_ENSURE(len > 0, "Invalid block size"); + auto reader = MakeBlockReader(type); + auto builder = MakeBlockBuilder(type, pool, len); + + auto scalarItem = reader->GetScalarItem(scalar); + for (size_t i = 0; i < len; ++i) { + builder->Add(scalarItem); + } + + return builder->Build(true); +} + +arrow::ValueDescr ToValueDescr(TType* type) { + arrow::ValueDescr ret; + MKQL_ENSURE(ConvertInputArrowType(type, ret), "can't get arrow type"); + return ret; +} + +std::vector<arrow::compute::InputType> ConvertToInputTypes(const TVector<TType*>& argTypes) { + std::vector<arrow::compute::InputType> result; + result.reserve(argTypes.size()); + for (auto& type : argTypes) { + result.emplace_back(ToValueDescr(type)); + } + return result; +} + +arrow::compute::OutputType ConvertToOutputType(TType* output) { + return arrow::compute::OutputType(ToValueDescr(output)); +} + +TBlockFuncNode::TBlockFuncNode(TComputationMutables& mutables, TVector<IComputationNode*>&& argsNodes, + const TVector<TType*>& argsTypes, const arrow::compute::ScalarKernel& kernel, + std::shared_ptr<arrow::compute::ScalarKernel> kernelHolder, + const arrow::compute::FunctionOptions* functionOptions) + : TMutableComputationNode(mutables) + , StateIndex(mutables.CurValueIndex++) + , ArgsNodes(std::move(argsNodes)) + , ArgsValuesDescr(ToValueDescr(argsTypes)) + , Kernel(kernel) + , KernelHolder(std::move(kernelHolder)) + , Options(functionOptions) + , ScalarOutput(GetResultShape(argsTypes) == TBlockType::EShape::Scalar) +{ +} + +NUdf::TUnboxedValuePod TBlockFuncNode::DoCalculate(TComputationContext& ctx) const { + auto& state = GetState(ctx); + + std::vector<arrow::Datum> argDatums; + for (ui32 i = 0; i < 
ArgsNodes.size(); ++i) { + argDatums.emplace_back(TArrowBlock::From(ArgsNodes[i]->GetValue(ctx)).GetDatum()); + Y_VERIFY_DEBUG(ArgsValuesDescr[i] == argDatums.back().descr()); + } + + auto executor = arrow::compute::detail::KernelExecutor::MakeScalar(); + ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, Options })); + + if (ScalarOutput) { + auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>(); + ARROW_OK(executor->Execute(argDatums, listener.get())); + auto output = executor->WrapResults(argDatums, listener->values()); + return ctx.HolderFactory.CreateArrowBlock(std::move(output)); + } + + TArgsDechunker dechunker(std::move(argDatums)); + std::vector<arrow::Datum> chunk; + TVector<std::shared_ptr<arrow::ArrayData>> arrays; + + while (dechunker.Next(chunk)) { + arrow::compute::detail::DatumAccumulator listener; + ARROW_OK(executor->Execute(chunk, &listener)); + auto output = executor->WrapResults(chunk, listener.values()); + + ForEachArrayData(output, [&](const auto& arr) { arrays.push_back(arr); }); + } + + return ctx.HolderFactory.CreateArrowBlock(MakeArray(arrays)); +} + + +void TBlockFuncNode::RegisterDependencies() const { + for (const auto& arg : ArgsNodes) { + DependsOn(arg); + } +} + +TBlockFuncNode::TState& TBlockFuncNode::GetState(TComputationContext& ctx) const { + auto& result = ctx.MutableValues[StateIndex]; + if (!result.HasValue()) { + result = ctx.HolderFactory.Create<TState>(Options, Kernel, ArgsValuesDescr, ctx); + } + + return *static_cast<TState*>(result.AsBoxed().Get()); +} + +} diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h index 58885a2b1e..3e82ae1e61 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h @@ -6,9 +6,61 @@ #include <ydb/library/yql/minikql/arrow/arrow_util.h> #include <arrow/array.h> +#include <arrow/scalar.h> +#include <arrow/datum.h> +#include <arrow/compute/kernel.h> namespace NKikimr::NMiniKQL { +arrow::Datum MakeArrayFromScalar(const arrow::Scalar& scalar, size_t len, TType* type, arrow::MemoryPool& pool); + +arrow::ValueDescr ToValueDescr(TType* type); + +std::vector<arrow::compute::InputType> ConvertToInputTypes(const TVector<TType*>& argTypes); +arrow::compute::OutputType ConvertToOutputType(TType* output); + +class TBlockFuncNode : public TMutableComputationNode<TBlockFuncNode> { +public: + TBlockFuncNode(TComputationMutables& mutables, TVector<IComputationNode*>&& argsNodes, + const TVector<TType*>& argsTypes, const arrow::compute::ScalarKernel& kernel, + std::shared_ptr<arrow::compute::ScalarKernel> kernelHolder = {}, + const arrow::compute::FunctionOptions* functionOptions = nullptr); + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const; +private: + struct TState : public TComputationValue<TState> { + using TComputationValue::TComputationValue; + + TState(TMemoryUsageInfo* memInfo, const arrow::compute::FunctionOptions* options, + const arrow::compute::ScalarKernel& kernel, const std::vector<arrow::ValueDescr>& argsValuesDescr, + TComputationContext& ctx) + : TComputationValue(memInfo) + , ExecContext(&ctx.ArrowMemoryPool, nullptr, nullptr) + , KernelContext(&ExecContext) + { + if (kernel.init) { + State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options })); + KernelContext.SetState(State.get()); + } + } + + arrow::compute::ExecContext ExecContext; + arrow::compute::KernelContext KernelContext; + 
std::unique_ptr<arrow::compute::KernelState> State; + }; + + void RegisterDependencies() const final; + TState& GetState(TComputationContext& ctx) const; +private: + const ui32 StateIndex; + const TVector<IComputationNode*> ArgsNodes; + const std::vector<arrow::ValueDescr> ArgsValuesDescr; + const arrow::compute::ScalarKernel& Kernel; + const std::shared_ptr<arrow::compute::ScalarKernel> KernelHolder; + const arrow::compute::FunctionOptions* const Options; + const bool ScalarOutput; +}; + template <typename TDerived> class TStatefulWideFlowBlockComputationNode: public TWideFlowBaseComputationNode<TDerived> { diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_item.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_item.h index 84924977f0..37c357fcc9 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_item.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_item.h @@ -36,6 +36,19 @@ public: Raw.Simple.Meta = static_cast<ui8>(EMarkers::Present); } + inline TBlockItem(ui64 low, ui64 high) { + Raw.Halfs[0] = low; + Raw.Halfs[1] = high; + } + + inline ui64 Low() const { + return Raw.Halfs[0]; + } + + inline ui64 High() const { + return Raw.Halfs[1]; + } + template <typename T, typename = std::enable_if_t<NYql::NUdf::TPrimitiveDataType<T>::Result>> inline T As() const; diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp index 789cb67df2..2e02a6a9c6 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp @@ -43,7 +43,7 @@ public: builder->Add(result); } - return builder->Build(ctx, true); + return ctx.HolderFactory.CreateArrowBlock(builder->Build(true)); } private: @@ -98,7 +98,7 @@ public: for (size_t i = 0; i < Width_; ++i) { if (auto* out = output[i]; out != nullptr) { - *out = s.Builders_[i]->Build(ctx, s.IsFinished_); + *out = ctx.HolderFactory.CreateArrowBlock(s.Builders_[i]->Build(s.IsFinished_)); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index e29c051c97..ddeecfa047 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -8,6 +8,7 @@ #include "mkql_blocks.h" #include "mkql_block_agg.h" #include "mkql_block_coalesce.h" +#include "mkql_block_if.h" #include "mkql_block_logical.h" #include "mkql_block_compress.h" #include "mkql_block_skiptake.h" @@ -276,6 +277,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"WideTakeBlocks", &WrapWideTakeBlocks}, {"AsScalar", &WrapAsScalar}, {"BlockCoalesce", &WrapBlockCoalesce}, + {"BlockIf", &WrapBlockIf}, {"BlockAnd", &WrapBlockAnd}, {"BlockOr", &WrapBlockOr}, {"BlockXor", &WrapBlockXor}, diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 6eb2f14802..a4f25520f0 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -5312,6 +5312,23 @@ TRuntimeNode TProgramBuilder::PgInternal0(TType* returnType) { return TRuntimeNode(callableBuilder.Build(), false); } +TRuntimeNode TProgramBuilder::BlockIf(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch) { + const auto conditionType = AS_TYPE(TBlockType, condition.GetStaticType()); + MKQL_ENSURE(AS_TYPE(TDataType, conditionType->GetItemType())->GetSchemeType() == NUdf::TDataType<bool>::Id, + "Expected bool as first argument"); + + const auto thenType = 
AS_TYPE(TBlockType, thenBranch.GetStaticType()); + const auto elseType = AS_TYPE(TBlockType, elseBranch.GetStaticType()); + MKQL_ENSURE(thenType->GetItemType()->IsSameType(*elseType->GetItemType()), "Different return types in branches."); + + auto returnType = NewBlockType(thenType->GetItemType(), GetResultShape({conditionType, thenType, elseType})); + TCallableBuilder callableBuilder(Env, __func__, returnType); + callableBuilder.Add(condition); + callableBuilder.Add(thenBranch); + callableBuilder.Add(elseBranch); + return TRuntimeNode(callableBuilder.Build(), false); +} + TRuntimeNode TProgramBuilder::BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args) { for (const auto& arg : args) { MKQL_ENSURE(arg.GetStaticType()->IsBlock(), "Expected Block type"); diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index 99b4233201..0e2076e798 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -257,6 +257,7 @@ public: TRuntimeNode BlockOr(TRuntimeNode first, TRuntimeNode second); TRuntimeNode BlockXor(TRuntimeNode first, TRuntimeNode second); + TRuntimeNode BlockIf(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch); TRuntimeNode BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args); TRuntimeNode BlockBitCast(TRuntimeNode value, TType* targetType); TRuntimeNode BlockCombineAll(TRuntimeNode flow, std::optional<ui32> filterColumn, diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index 7531252eaa..c0f1254586 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -552,7 +552,9 @@ TMkqlCommonCallableCompiler::TShared::TShared() { {"ListFromRange", &TProgramBuilder::ListFromRange}, - {"PreserveStream", &TProgramBuilder::PreserveStream} + {"PreserveStream", &TProgramBuilder::PreserveStream}, + + {"BlockIf", &TProgramBuilder::BlockIf}, }); AddSimpleCallables({ |
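For array (non-scalar) arguments, the shared TBlockFuncNode added in mkql_block_impl.cpp runs the kernel chunk by chunk: TArgsDechunker aligns inputs whose chunked arrays are split at different offsets by repeatedly slicing every argument to the shortest remaining front chunk. The sketch below restates that alignment loop in simplified form; Chunk and Dechunk are illustrative names, and the real code operates on arrow::ArrayData, carries scalar arguments through unchanged, and enforces that total lengths match.

```cpp
// Standalone sketch of the chunk-alignment idea behind TArgsDechunker:
// each step takes the minimal remaining front-chunk length across all
// array arguments and consumes that many rows from every argument.
#include <algorithm>
#include <cstdint>
#include <deque>
#include <iostream>
#include <vector>

using Chunk = std::vector<int>;

// Emits aligned batches of equal length until the chunk queues are drained.
void Dechunk(std::vector<std::deque<Chunk>> args) {
    for (;;) {
        size_t minLen = SIZE_MAX;
        for (auto& q : args) {
            while (!q.empty() && q.front().empty()) q.pop_front();  // skip consumed chunks
            if (q.empty()) return;                                   // an input is exhausted
            minLen = std::min(minLen, q.front().size());
        }
        std::cout << "batch of length " << minLen << "\n";
        for (auto& q : args) {
            Chunk& front = q.front();
            // "Chop": consume the first minLen rows of the front chunk.
            front.erase(front.begin(), front.begin() + minLen);
        }
    }
}

int main() {
    // Two array arguments with incompatible chunking: lengths {3,2} vs {1,4}.
    Dechunk({{Chunk(3, 0), Chunk(2, 0)}, {Chunk(1, 0), Chunk(4, 0)}});
    // Prints batches of length 1, 2, 2 (5 rows per argument in total).
}
```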