diff options
author | vvvv <vvvv@ydb.tech> | 2022-09-21 13:16:59 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-09-21 13:16:59 +0300 |
commit | 8a7a2f9431cada8190c4dc97519442db3fe686f8 (patch) | |
tree | 6735a66d0c4afb9b5e4f468c419ff27e720b303a | |
parent | dbe687629ccb40b695d7cf47fa98269a94ced0cd (diff) | |
download | ydb-8a7a2f9431cada8190c4dc97519442db3fe686f8.tar.gz |
bitcast mode for some operations
15 files changed, 329 insertions, 87 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 0525349a78..170a016c82 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -34,7 +34,12 @@ using TPeepHoleOptimizerMap = std::unordered_map<std::string_view, TPeepHoleOpti using TExtPeepHoleOptimizerPtr = TExprNode::TPtr (*const)(const TExprNode::TPtr&, TExprContext&, TTypeAnnotationContext& types); using TExtPeepHoleOptimizerMap = std::unordered_map<std::string_view, TExtPeepHoleOptimizerPtr>; -using TBlockFuncMap = std::unordered_map<std::string_view, std::string_view>; +struct TBlockFuncRule { + std::string_view Name; + bool BitcastToReturnType = false; +}; + +using TBlockFuncMap = std::unordered_map<std::string_view, TBlockFuncRule>; TExprNode::TPtr MakeNothing(TPositionHandle pos, const TTypeAnnotationNode& type, TExprContext& ctx) { return ctx.NewCallable(pos, "Nothing", {ExpandType(pos, *ctx.MakeType<TOptionalExprType>(&type), ctx)}); @@ -4304,8 +4309,8 @@ TExprNode::TPtr OptimizeWideChopper(const TExprNode::TPtr& node, TExprContext& c struct TBlockRules { static constexpr std::initializer_list<TBlockFuncMap::value_type> FuncsInit = { - {"+", "add" }, - {"Not", "invert" }, + {"+", {"add", true} }, + {"Not", {"invert", false} }, }; TBlockRules() @@ -4377,7 +4382,7 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext& } TExprNode::TListType funcArgs; - funcArgs.push_back(ctx.NewAtom(node->Pos(), fit->second)); + funcArgs.push_back(ctx.NewAtom(node->Pos(), fit->second.Name)); for (const auto& child : node->Children()) { if (child->IsComplete()) { funcArgs.push_back(ctx.NewCallable(node->Pos(), "AsScalar", { child })); @@ -4401,11 +4406,47 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext& } } - YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), fit->second, argTypes, outType, ctx)); - if (!outType) { + YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), fit->second.Name, argTypes, outType, ctx)); + if (!outType && !fit->second.BitcastToReturnType) { return true; } + if (!outType) { + argTypes.clear(); + for (const auto& child : node->Children()) { + if (child->IsComplete()) { + argTypes.push_back(ctx.MakeType<TScalarExprType>(node->GetTypeAnn())); + } else { + argTypes.push_back(ctx.MakeType<TBlockExprType>(node->GetTypeAnn())); + } + } + + YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), fit->second.Name, argTypes, outType, ctx)); + if (!outType) { + return true; + } + + auto typeNode = ExpandType(node->Pos(), *node->GetTypeAnn(), ctx); + for (ui32 i = 1; i < funcArgs.size(); ++i) { + if (IsSameAnnotation(*node->Child(i-1)->GetTypeAnn(), *node->GetTypeAnn())) { + continue; + } + + if (node->Child(i-1)->IsComplete()) { + funcArgs[i] = ctx.Builder(node->Pos()) + .Callable("AsScalar") + .Callable(0, "BitCast") + .Add(0, funcArgs[i]->HeadPtr()) + .Add(1, typeNode) + .Seal() + .Seal() + .Build(); + } else { + funcArgs[i] = ctx.NewCallable(node->Pos(), "BlockBitCast", { funcArgs[i], typeNode }); + } + } + } + bool isScalar; if (!IsSameAnnotation(*node->GetTypeAnn(), *GetBlockItemType(*outType, isScalar))) { return true; diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index b729736598..6cabcb5198 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -68,5 +68,47 @@ IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprN return IGraphTransformer::TStatus::Ok; } +IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { + Y_UNUSED(output); + if (!EnsureArgsCount(*input, 2U, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureBlockOrScalarType(*input->Child(0), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (auto status = EnsureTypeRewrite(input->ChildRef(1), ctx.Expr); status != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (!ctx.Types.ArrowResolver) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Arrow resolver isn't available")); + return IGraphTransformer::TStatus::Error; + } + + bool isScalar; + auto inputType = GetBlockItemType(*input->Child(0)->GetTypeAnn(), isScalar); + + auto outputType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + bool has = false; + if (!ctx.Types.ArrowResolver->HasCast(ctx.Expr.GetPosition(input->Pos()), inputType, outputType, has, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!has) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "No such cast")); + return IGraphTransformer::TStatus::Error; + } + + if (isScalar) { + input->SetTypeAnn(ctx.Expr.MakeType<TScalarExprType>(outputType)); + } else { + input->SetTypeAnn(ctx.Expr.MakeType<TBlockExprType>(outputType)); + } + + return IGraphTransformer::TStatus::Ok; +} + } // namespace NTypeAnnImpl } diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.h b/ydb/library/yql/core/type_ann/type_ann_blocks.h index 20b35026cb..b80ca7ac17 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.h +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.h @@ -10,6 +10,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus AsScalarWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); + IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); } // namespace NTypeAnnImpl } // namespace NYql diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 60137cba35..f4bc1e3078 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11533,6 +11533,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["WideFromBlocks"] = &WideFromBlocksWrapper; Functions["AsScalar"] = &AsScalarWrapper; ExtFunctions["BlockFunc"] = &BlockFuncWrapper; + ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper; Functions["AsRange"] = &AsRangeWrapper; Functions["RangeCreate"] = &RangeCreateWrapper; diff --git a/ydb/library/yql/core/yql_arrow_resolver.h b/ydb/library/yql/core/yql_arrow_resolver.h index a0514fbff6..87622c4e58 100644 --- a/ydb/library/yql/core/yql_arrow_resolver.h +++ b/ydb/library/yql/core/yql_arrow_resolver.h @@ -12,6 +12,8 @@ public: virtual bool LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes, const TTypeAnnotationNode*& returnType, TExprContext& ctx) const = 0; + + virtual bool HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, bool& has, TExprContext& ctx) const = 0; }; } diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp index 73b82e7cd9..226108956d 100644 --- a/ydb/library/yql/minikql/arrow/mkql_functions.cpp +++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp @@ -6,6 +6,7 @@ #include <arrow/visitor.h> #include <arrow/compute/registry.h> #include <arrow/compute/function.h> +#include <arrow/compute/cast.h> namespace NKikimr::NMiniKQL { @@ -203,4 +204,18 @@ bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TTy return true; } +bool HasArrowCast(TType* from, TType* to) { + bool isOptional; + std::shared_ptr<arrow::DataType> fromArrowType, toArrowType; + if (!ConvertArrowType(from, isOptional, fromArrowType)) { + return false; + } + + if (!ConvertArrowType(to, isOptional, toArrowType)) { + return false; + } + + return arrow::compute::CanCast(*fromArrowType, *toArrowType); +} + } diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.h b/ydb/library/yql/minikql/arrow/mkql_functions.h index 329456ecab..5dffb01a59 100644 --- a/ydb/library/yql/minikql/arrow/mkql_functions.h +++ b/ydb/library/yql/minikql/arrow/mkql_functions.h @@ -8,5 +8,5 @@ namespace NKikimr::NMiniKQL { bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env); bool ConvertInputArrowType(TType* blockType, bool& isOptional, arrow::ValueDescr& descr); bool ConvertArrowType(TType* itemType, bool& isOptional, std::shared_ptr<arrow::DataType>& type); - +bool HasArrowCast(TType* from, TType* to); } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp index 2994b92b36..4fe047e43f 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp @@ -8,6 +8,7 @@ #include <ydb/library/yql/minikql/arrow/mkql_functions.h> #include <arrow/array/builder_primitive.h> +#include <arrow/compute/cast.h> #include <arrow/compute/exec_internal.h> #include <arrow/compute/function.h> #include <arrow/compute/kernel.h> @@ -19,6 +20,56 @@ namespace NMiniKQL { namespace { +arrow::ValueDescr ToValueDescr(TType* type) { + bool isOptional; + arrow::ValueDescr ret; + MKQL_ENSURE(ConvertInputArrowType(type, isOptional, ret), "can't get arrow type"); + return ret; +} + +std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) { + std::vector<arrow::ValueDescr> res; + res.reserve(types.size()); + for (const auto& type : types) { + res.emplace_back(ToValueDescr(type)); + } + + return res; +} + +const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function& function, const std::vector<arrow::ValueDescr>& args) { + const auto kernel = ARROW_RESULT(function.DispatchExact(args)); + return *static_cast<const arrow::compute::ScalarKernel*>(kernel); +} + +struct TState : public TComputationValue<TState> { + using TComputationValue::TComputationValue; + + TState(TMemoryUsageInfo* memInfo, const arrow::compute::Function& function, const arrow::compute::FunctionOptions* options, + const arrow::compute::ScalarKernel& kernel, + arrow::compute::FunctionRegistry& registry, const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx) + : TComputationValue(memInfo) + , ExecContext(&ctx.ArrowMemoryPool, nullptr, ®istry) + , KernelContext(&ExecContext) + , Executor(arrow::compute::detail::KernelExecutor::MakeScalar()) + { + if (kernel.init) {
+ State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ KernelContext.SetState(State.get());
+ }
+
+ ARROW_OK(Executor->Init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ Values.reserve(argsValuesDescr.size()); + } + + arrow::compute::ExecContext ExecContext; + arrow::compute::KernelContext KernelContext; + std::unique_ptr<arrow::compute::KernelState> State; + std::unique_ptr<arrow::compute::detail::KernelExecutor> Executor; + + std::vector<arrow::Datum> Values; +}; + class TBlockFuncWrapper : public TMutableComputationNode<TBlockFuncWrapper> { public: TBlockFuncWrapper(TComputationMutables& mutables, @@ -66,60 +117,74 @@ private: return *function; } - static const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function& function, const std::vector<arrow::ValueDescr>& args) { - const auto kernel = ARROW_RESULT(function.DispatchExact(args)); - return *static_cast<const arrow::compute::ScalarKernel*>(kernel); + TState& GetState(TComputationContext& ctx) const { + auto& result = ctx.MutableValues[StateIndex]; + if (!result.HasValue()) { + result = ctx.HolderFactory.Create<TState>(Function, Function.default_options(), Kernel, FunctionRegistry, ArgsValuesDescr, ctx); + } + + return *static_cast<TState*>(result.AsBoxed().Get()); } - static arrow::ValueDescr ToValueDescr(TType* type) { - bool isOptional; - arrow::ValueDescr ret; - MKQL_ENSURE(ConvertInputArrowType(type, isOptional, ret), "can't get arrow type"); - return ret; +private: + const ui32 StateIndex; + const TString FuncName; + const TVector<IComputationNode*> ArgsNodes; + const TVector<TType*> ArgsTypes; + + const std::vector<arrow::ValueDescr> ArgsValuesDescr; + arrow::compute::FunctionRegistry& FunctionRegistry; + const arrow::compute::Function& Function; + const arrow::compute::ScalarKernel& Kernel; +}; + +class TBlockBitCastWrapper : public TMutableComputationNode<TBlockBitCastWrapper> { +public: + TBlockBitCastWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* argType, TType* to) + : TMutableComputationNode(mutables) + , StateIndex(mutables.CurValueIndex++) + , Arg(arg) + , ArgsValuesDescr({ ToValueDescr(argType) }) + , FunctionRegistry(*arrow::compute::GetFunctionRegistry()) + , Function(ResolveFunction(FunctionRegistry, to)) + , Kernel(ResolveKernel(Function, ArgsValuesDescr)) + { } - static std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) { - std::vector<arrow::ValueDescr> res; - res.reserve(types.size()); - for (const auto& type : types) { - res.emplace_back(ToValueDescr(type)); - } + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + auto& state = GetState(ctx); + + state.Values.clear(); + state.Values.emplace_back(TArrowBlock::From(Arg->GetValue(ctx)).GetDatum()); + Y_VERIFY_DEBUG(ArgsValuesDescr[0] == state.Values.back().descr()); - return res; + auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
+ ARROW_OK(state.Executor->Execute(state.Values, listener.get()));
+ auto output = state.Executor->WrapResults(state.Values, listener->values()); + return ctx.HolderFactory.CreateArrowBlock(std::move(output)); } - struct TState : public TComputationValue<TState> { - using TComputationValue::TComputationValue; - - TState(TMemoryUsageInfo* memInfo, const arrow::compute::Function& function, const arrow::compute::ScalarKernel& kernel, - arrow::compute::FunctionRegistry& registry, const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx) - : TComputationValue(memInfo) - , ExecContext(&ctx.ArrowMemoryPool, nullptr, ®istry) - , KernelContext(&ExecContext) - , Executor(arrow::compute::detail::KernelExecutor::MakeScalar()) - { - auto options = function.default_options(); - if (kernel.init) {
- State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
- KernelContext.SetState(State.get());
- }
-
- ARROW_OK(Executor->Init(&KernelContext, { &kernel, argsValuesDescr, options }));
- Values.reserve(argsValuesDescr.size()); - } +private: + void RegisterDependencies() const final { + this->DependsOn(Arg); + } - arrow::compute::ExecContext ExecContext; - arrow::compute::KernelContext KernelContext; - std::unique_ptr<arrow::compute::KernelState> State; - std::unique_ptr<arrow::compute::detail::KernelExecutor> Executor; + static const arrow::compute::Function& ResolveFunction(const arrow::compute::FunctionRegistry& registry, TType* to) { + bool isOptional; + std::shared_ptr<arrow::DataType> type; + MKQL_ENSURE(ConvertArrowType(to, isOptional, type), "can't get arrow type"); - std::vector<arrow::Datum> Values; - }; + auto function = ARROW_RESULT(arrow::compute::GetCastFunction(type)); + MKQL_ENSURE(function != nullptr, "missing function"); + MKQL_ENSURE(function->kind() == arrow::compute::Function::SCALAR, "expected SCALAR function"); + return *function; + } TState& GetState(TComputationContext& ctx) const { auto& result = ctx.MutableValues[StateIndex]; if (!result.HasValue()) { - result = ctx.HolderFactory.Create<TState>(Function, Kernel, FunctionRegistry, ArgsValuesDescr, ctx); + arrow::compute::CastOptions options(false); + result = ctx.HolderFactory.Create<TState>(Function, (const arrow::compute::FunctionOptions*)&options, Kernel, FunctionRegistry, ArgsValuesDescr, ctx); } return *static_cast<TState*>(result.AsBoxed().Get()); @@ -127,10 +192,7 @@ private: private: const ui32 StateIndex; - const TString FuncName; - const TVector<IComputationNode*> ArgsNodes; - const TVector<TType*> ArgsTypes; - + IComputationNode* Arg; const std::vector<arrow::ValueDescr> ArgsValuesDescr; arrow::compute::FunctionRegistry& FunctionRegistry; const arrow::compute::Function& Function; @@ -158,5 +220,16 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto ); } +IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args"); + auto argNode = LocateNode(ctx.NodeLocator, callable, 0); + MKQL_ENSURE(callable.GetInput(1).GetStaticType()->IsType(), "Expected type"); + return new TBlockBitCastWrapper(ctx.Mutables, + argNode, + callable.GetType()->GetArgumentType(0), + static_cast<TType*>(callable.GetInput(1).GetNode()) + ); +} + } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h index 0bbd881fdf..8cd15bca62 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h @@ -6,6 +6,7 @@ namespace NKikimr { namespace NMiniKQL { IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx); +IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFactoryContext& ctx); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp index 524fe761a3..ee68f268c3 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp @@ -449,43 +449,70 @@ private: const size_t Width_; }; -arrow::Datum ExtractLiteral(TRuntimeNode n) { - if (n.GetStaticType()->IsOptional()) { - const auto* dataLiteral = AS_VALUE(TOptionalLiteral, n); - if (!dataLiteral->HasItem()) { - bool isOptional; - auto unpacked = UnpackOptionalData(dataLiteral->GetType(), isOptional); - std::shared_ptr<arrow::DataType> type; - MKQL_ENSURE(ConvertArrowType(unpacked, isOptional, type), "Unsupported type of literal"); - return arrow::MakeNullScalar(type); +class TAsScalarWrapper : public TMutableComputationNode<TAsScalarWrapper> { +public: + TAsScalarWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* type) + : TMutableComputationNode(mutables) + , Arg_(arg) + { + bool isOptional; + auto unpacked = UnpackOptionalData(type, isOptional); + MKQL_ENSURE(ConvertArrowType(unpacked, isOptional, Type_), "Unsupported type of scalar"); + Slot_ = *unpacked->GetDataSlot(); + } + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + auto value = Arg_->GetValue(ctx); + arrow::Datum result; + if (!value) { + result = arrow::MakeNullScalar(Type_); + } else { + switch (Slot_) { + case NUdf::EDataSlot::Bool: + result = arrow::Datum(static_cast<bool>(value.Get<bool>())); + break; + case NUdf::EDataSlot::Int8: + result = arrow::Datum(static_cast<int8_t>(value.Get<i8>())); + break; + case NUdf::EDataSlot::Uint8: + result = arrow::Datum(static_cast<uint8_t>(value.Get<ui8>())); + break; + case NUdf::EDataSlot::Int16: + result = arrow::Datum(static_cast<int16_t>(value.Get<i16>())); + break; + case NUdf::EDataSlot::Uint16: + result = arrow::Datum(static_cast<uint16_t>(value.Get<ui16>())); + break; + case NUdf::EDataSlot::Int32: + result = arrow::Datum(static_cast<int32_t>(value.Get<i32>())); + break; + case NUdf::EDataSlot::Uint32: + result = arrow::Datum(static_cast<uint32_t>(value.Get<ui32>())); + break; + case NUdf::EDataSlot::Int64: + result = arrow::Datum(static_cast<int64_t>(value.Get<i64>())); + break; + case NUdf::EDataSlot::Uint64: + result = arrow::Datum(static_cast<uint64_t>(value.Get<ui64>())); + break; + default: + MKQL_ENSURE(false, "Unsupported data slot"); + } } - n = dataLiteral->GetItem(); + + return ctx.HolderFactory.CreateArrowBlock(std::move(result)); } - const auto* dataLiteral = AS_VALUE(TDataLiteral, n); - switch (*dataLiteral->GetType()->GetDataSlot()) { - case NUdf::EDataSlot::Bool: - return arrow::Datum(static_cast<bool>(dataLiteral->AsValue().Get<bool>())); - case NUdf::EDataSlot::Int8: - return arrow::Datum(static_cast<int8_t>(dataLiteral->AsValue().Get<i8>())); - case NUdf::EDataSlot::Uint8: - return arrow::Datum(static_cast<uint8_t>(dataLiteral->AsValue().Get<ui8>())); - case NUdf::EDataSlot::Int16: - return arrow::Datum(static_cast<int16_t>(dataLiteral->AsValue().Get<i16>())); - case NUdf::EDataSlot::Uint16: - return arrow::Datum(static_cast<uint16_t>(dataLiteral->AsValue().Get<ui16>())); - case NUdf::EDataSlot::Int32: - return arrow::Datum(static_cast<int32_t>(dataLiteral->AsValue().Get<i32>())); - case NUdf::EDataSlot::Uint32: - return arrow::Datum(static_cast<uint32_t>(dataLiteral->AsValue().Get<ui32>())); - case NUdf::EDataSlot::Int64: - return arrow::Datum(static_cast<int64_t>(dataLiteral->AsValue().Get<i64>())); - case NUdf::EDataSlot::Uint64: - return arrow::Datum(static_cast<uint64_t>(dataLiteral->AsValue().Get<ui64>())); - default: - MKQL_ENSURE(false, "Unsupported data slot"); +private: + void RegisterDependencies() const final { + DependsOn(Arg_); } -} + +private: + IComputationNode* const Arg_; + std::shared_ptr<arrow::DataType> Type_; + NUdf::EDataSlot Slot_; +}; } @@ -545,8 +572,7 @@ IComputationNode* WrapWideFromBlocks(TCallable& callable, const TComputationNode IComputationNode* WrapAsScalar(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount()); - auto value = ExtractLiteral(callable.GetInput(0U)); - return ctx.NodeFactory.CreateImmutableNode(ctx.HolderFactory.CreateArrowBlock(std::move(value))); + return new TAsScalarWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0), callable.GetInput(0).GetStaticType()); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index c3d3f5c1be..22cb1b5e80 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -264,6 +264,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"ToBlocks", &WrapToBlocks}, {"WideToBlocks", &WrapWideToBlocks}, {"BlockFunc", &WrapBlockFunc}, + {"BlockBitCast", &WrapBlockBitCast}, {"FromBlocks", &WrapFromBlocks}, {"WideFromBlocks", &WrapWideFromBlocks}, {"AsScalar", &WrapAsScalar}, diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 808b607b07..c083bf7f9f 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -5182,6 +5182,17 @@ TRuntimeNode TProgramBuilder::BlockFunc(const std::string_view& funcName, TType* return TRuntimeNode(builder.Build(), false); } +TRuntimeNode TProgramBuilder::BlockBitCast(TRuntimeNode value, TType* targetType) { + MKQL_ENSURE(value.GetStaticType()->IsBlock(), "Expected Block type"); + + auto returnType = TBlockType::Create(targetType, AS_TYPE(TBlockType, value.GetStaticType())->GetShape(), Env); + TCallableBuilder builder(Env, __func__, returnType); + builder.Add(value); + builder.Add(TRuntimeNode(targetType, true)); + + return TRuntimeNode(builder.Build(), false); +} + bool CanExportType(TType* type, const TTypeEnvironment& env) { if (type->GetKind() == TType::EKind::Type) { return false; // Type of Type diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index d514255b2d..8b42022443 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -241,6 +241,7 @@ public: TRuntimeNode AsScalar(TRuntimeNode value); TRuntimeNode BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args); + TRuntimeNode BlockBitCast(TRuntimeNode value, TType* targetType); // udfs TRuntimeNode Udf( diff --git a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp index 3cb21a0e57..40fb07c48c 100644 --- a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp +++ b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp @@ -44,6 +44,27 @@ private: } } + bool HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, bool& has, TExprContext& ctx) const override { + try { + has = false; + TScopedAlloc alloc; + TTypeEnvironment env(alloc); + TProgramBuilder pgmBuilder(env, FunctionRegistry_); + TNullOutput null; + auto mkqlFromType = NCommon::BuildType(*from, pgmBuilder, null); + auto mkqlToType = NCommon::BuildType(*to, pgmBuilder, null); + if (!HasArrowCast(mkqlFromType, mkqlToType)) { + return true; + } + + has = true; + return true; + } catch (const std::exception& e) { + ctx.AddError(TIssue(pos, e.what())); + return false; + } + } + private: const IFunctionRegistry& FunctionRegistry_; }; diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index b242d345e3..757dbfafb4 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -2382,6 +2382,12 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return ctx.ProgramBuilder.BlockFunc(node.Child(0)->Content(), returnType, args); }); + AddCallable("BlockBitCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { + auto arg = MkqlBuildExpr(*node.Child(0), ctx); + auto targetType = BuildType(node, *node.Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + return ctx.ProgramBuilder.BlockBitCast(arg, targetType); + }); + AddCallable("PgArray", [](const TExprNode& node, TMkqlBuildContext& ctx) { std::vector<TRuntimeNode> args; args.reserve(node.ChildrenSize()); |