aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-09-21 13:16:59 +0300
committervvvv <vvvv@ydb.tech>2022-09-21 13:16:59 +0300
commit8a7a2f9431cada8190c4dc97519442db3fe686f8 (patch)
tree6735a66d0c4afb9b5e4f468c419ff27e720b303a
parentdbe687629ccb40b695d7cf47fa98269a94ced0cd (diff)
downloadydb-8a7a2f9431cada8190c4dc97519442db3fe686f8.tar.gz
bitcast mode for some operations
-rw-r--r--ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp53
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_blocks.cpp42
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_blocks.h1
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_core.cpp1
-rw-r--r--ydb/library/yql/core/yql_arrow_resolver.h2
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions.cpp15
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions.h2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp163
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_func.h1
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp96
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp1
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.cpp11
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.h1
-rw-r--r--ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp21
-rw-r--r--ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp6
15 files changed, 329 insertions, 87 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
index 0525349a78..170a016c82 100644
--- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
+++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
@@ -34,7 +34,12 @@ using TPeepHoleOptimizerMap = std::unordered_map<std::string_view, TPeepHoleOpti
using TExtPeepHoleOptimizerPtr = TExprNode::TPtr (*const)(const TExprNode::TPtr&, TExprContext&, TTypeAnnotationContext& types);
using TExtPeepHoleOptimizerMap = std::unordered_map<std::string_view, TExtPeepHoleOptimizerPtr>;
-using TBlockFuncMap = std::unordered_map<std::string_view, std::string_view>;
+struct TBlockFuncRule {
+ std::string_view Name;
+ bool BitcastToReturnType = false;
+};
+
+using TBlockFuncMap = std::unordered_map<std::string_view, TBlockFuncRule>;
TExprNode::TPtr MakeNothing(TPositionHandle pos, const TTypeAnnotationNode& type, TExprContext& ctx) {
return ctx.NewCallable(pos, "Nothing", {ExpandType(pos, *ctx.MakeType<TOptionalExprType>(&type), ctx)});
@@ -4304,8 +4309,8 @@ TExprNode::TPtr OptimizeWideChopper(const TExprNode::TPtr& node, TExprContext& c
struct TBlockRules {
static constexpr std::initializer_list<TBlockFuncMap::value_type> FuncsInit = {
- {"+", "add" },
- {"Not", "invert" },
+ {"+", {"add", true} },
+ {"Not", {"invert", false} },
};
TBlockRules()
@@ -4377,7 +4382,7 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext&
}
TExprNode::TListType funcArgs;
- funcArgs.push_back(ctx.NewAtom(node->Pos(), fit->second));
+ funcArgs.push_back(ctx.NewAtom(node->Pos(), fit->second.Name));
for (const auto& child : node->Children()) {
if (child->IsComplete()) {
funcArgs.push_back(ctx.NewCallable(node->Pos(), "AsScalar", { child }));
@@ -4401,11 +4406,47 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext&
}
}
- YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), fit->second, argTypes, outType, ctx));
- if (!outType) {
+ YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), fit->second.Name, argTypes, outType, ctx));
+ if (!outType && !fit->second.BitcastToReturnType) {
return true;
}
+ if (!outType) {
+ argTypes.clear();
+ for (const auto& child : node->Children()) {
+ if (child->IsComplete()) {
+ argTypes.push_back(ctx.MakeType<TScalarExprType>(node->GetTypeAnn()));
+ } else {
+ argTypes.push_back(ctx.MakeType<TBlockExprType>(node->GetTypeAnn()));
+ }
+ }
+
+ YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), fit->second.Name, argTypes, outType, ctx));
+ if (!outType) {
+ return true;
+ }
+
+ auto typeNode = ExpandType(node->Pos(), *node->GetTypeAnn(), ctx);
+ for (ui32 i = 1; i < funcArgs.size(); ++i) {
+ if (IsSameAnnotation(*node->Child(i-1)->GetTypeAnn(), *node->GetTypeAnn())) {
+ continue;
+ }
+
+ if (node->Child(i-1)->IsComplete()) {
+ funcArgs[i] = ctx.Builder(node->Pos())
+ .Callable("AsScalar")
+ .Callable(0, "BitCast")
+ .Add(0, funcArgs[i]->HeadPtr())
+ .Add(1, typeNode)
+ .Seal()
+ .Seal()
+ .Build();
+ } else {
+ funcArgs[i] = ctx.NewCallable(node->Pos(), "BlockBitCast", { funcArgs[i], typeNode });
+ }
+ }
+ }
+
bool isScalar;
if (!IsSameAnnotation(*node->GetTypeAnn(), *GetBlockItemType(*outType, isScalar))) {
return true;
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
index b729736598..6cabcb5198 100644
--- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
@@ -68,5 +68,47 @@ IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprN
return IGraphTransformer::TStatus::Ok;
}
+IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) {
+ Y_UNUSED(output);
+ if (!EnsureArgsCount(*input, 2U, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!EnsureBlockOrScalarType(*input->Child(0), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (auto status = EnsureTypeRewrite(input->ChildRef(1), ctx.Expr); status != IGraphTransformer::TStatus::Ok) {
+ return status;
+ }
+
+ if (!ctx.Types.ArrowResolver) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Arrow resolver isn't available"));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ bool isScalar;
+ auto inputType = GetBlockItemType(*input->Child(0)->GetTypeAnn(), isScalar);
+
+ auto outputType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ bool has = false;
+ if (!ctx.Types.ArrowResolver->HasCast(ctx.Expr.GetPosition(input->Pos()), inputType, outputType, has, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!has) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "No such cast"));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (isScalar) {
+ input->SetTypeAnn(ctx.Expr.MakeType<TScalarExprType>(outputType));
+ } else {
+ input->SetTypeAnn(ctx.Expr.MakeType<TBlockExprType>(outputType));
+ }
+
+ return IGraphTransformer::TStatus::Ok;
+}
+
} // namespace NTypeAnnImpl
}
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.h b/ydb/library/yql/core/type_ann/type_ann_blocks.h
index 20b35026cb..b80ca7ac17 100644
--- a/ydb/library/yql/core/type_ann/type_ann_blocks.h
+++ b/ydb/library/yql/core/type_ann/type_ann_blocks.h
@@ -10,6 +10,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus AsScalarWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
+ IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
} // namespace NTypeAnnImpl
} // namespace NYql
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 60137cba35..f4bc1e3078 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -11533,6 +11533,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["WideFromBlocks"] = &WideFromBlocksWrapper;
Functions["AsScalar"] = &AsScalarWrapper;
ExtFunctions["BlockFunc"] = &BlockFuncWrapper;
+ ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper;
Functions["AsRange"] = &AsRangeWrapper;
Functions["RangeCreate"] = &RangeCreateWrapper;
diff --git a/ydb/library/yql/core/yql_arrow_resolver.h b/ydb/library/yql/core/yql_arrow_resolver.h
index a0514fbff6..87622c4e58 100644
--- a/ydb/library/yql/core/yql_arrow_resolver.h
+++ b/ydb/library/yql/core/yql_arrow_resolver.h
@@ -12,6 +12,8 @@ public:
virtual bool LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes,
const TTypeAnnotationNode*& returnType, TExprContext& ctx) const = 0;
+
+ virtual bool HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, bool& has, TExprContext& ctx) const = 0;
};
}
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
index 73b82e7cd9..226108956d 100644
--- a/ydb/library/yql/minikql/arrow/mkql_functions.cpp
+++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
@@ -6,6 +6,7 @@
#include <arrow/visitor.h>
#include <arrow/compute/registry.h>
#include <arrow/compute/function.h>
+#include <arrow/compute/cast.h>
namespace NKikimr::NMiniKQL {
@@ -203,4 +204,18 @@ bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TTy
return true;
}
+bool HasArrowCast(TType* from, TType* to) {
+ bool isOptional;
+ std::shared_ptr<arrow::DataType> fromArrowType, toArrowType;
+ if (!ConvertArrowType(from, isOptional, fromArrowType)) {
+ return false;
+ }
+
+ if (!ConvertArrowType(to, isOptional, toArrowType)) {
+ return false;
+ }
+
+ return arrow::compute::CanCast(*fromArrowType, *toArrowType);
+}
+
}
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.h b/ydb/library/yql/minikql/arrow/mkql_functions.h
index 329456ecab..5dffb01a59 100644
--- a/ydb/library/yql/minikql/arrow/mkql_functions.h
+++ b/ydb/library/yql/minikql/arrow/mkql_functions.h
@@ -8,5 +8,5 @@ namespace NKikimr::NMiniKQL {
bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env);
bool ConvertInputArrowType(TType* blockType, bool& isOptional, arrow::ValueDescr& descr);
bool ConvertArrowType(TType* itemType, bool& isOptional, std::shared_ptr<arrow::DataType>& type);
-
+bool HasArrowCast(TType* from, TType* to);
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
index 2994b92b36..4fe047e43f 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
@@ -8,6 +8,7 @@
#include <ydb/library/yql/minikql/arrow/mkql_functions.h>
#include <arrow/array/builder_primitive.h>
+#include <arrow/compute/cast.h>
#include <arrow/compute/exec_internal.h>
#include <arrow/compute/function.h>
#include <arrow/compute/kernel.h>
@@ -19,6 +20,56 @@ namespace NMiniKQL {
namespace {
+arrow::ValueDescr ToValueDescr(TType* type) {
+ bool isOptional;
+ arrow::ValueDescr ret;
+ MKQL_ENSURE(ConvertInputArrowType(type, isOptional, ret), "can't get arrow type");
+ return ret;
+}
+
+std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) {
+ std::vector<arrow::ValueDescr> res;
+ res.reserve(types.size());
+ for (const auto& type : types) {
+ res.emplace_back(ToValueDescr(type));
+ }
+
+ return res;
+}
+
+const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function& function, const std::vector<arrow::ValueDescr>& args) {
+ const auto kernel = ARROW_RESULT(function.DispatchExact(args));
+ return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
+}
+
+struct TState : public TComputationValue<TState> {
+ using TComputationValue::TComputationValue;
+
+ TState(TMemoryUsageInfo* memInfo, const arrow::compute::Function& function, const arrow::compute::FunctionOptions* options,
+ const arrow::compute::ScalarKernel& kernel,
+ arrow::compute::FunctionRegistry& registry, const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx)
+ : TComputationValue(memInfo)
+ , ExecContext(&ctx.ArrowMemoryPool, nullptr, &registry)
+ , KernelContext(&ExecContext)
+ , Executor(arrow::compute::detail::KernelExecutor::MakeScalar())
+ {
+ if (kernel.init) {
+ State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ KernelContext.SetState(State.get());
+ }
+
+ ARROW_OK(Executor->Init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ Values.reserve(argsValuesDescr.size());
+ }
+
+ arrow::compute::ExecContext ExecContext;
+ arrow::compute::KernelContext KernelContext;
+ std::unique_ptr<arrow::compute::KernelState> State;
+ std::unique_ptr<arrow::compute::detail::KernelExecutor> Executor;
+
+ std::vector<arrow::Datum> Values;
+};
+
class TBlockFuncWrapper : public TMutableComputationNode<TBlockFuncWrapper> {
public:
TBlockFuncWrapper(TComputationMutables& mutables,
@@ -66,60 +117,74 @@ private:
return *function;
}
- static const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function& function, const std::vector<arrow::ValueDescr>& args) {
- const auto kernel = ARROW_RESULT(function.DispatchExact(args));
- return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
+ TState& GetState(TComputationContext& ctx) const {
+ auto& result = ctx.MutableValues[StateIndex];
+ if (!result.HasValue()) {
+ result = ctx.HolderFactory.Create<TState>(Function, Function.default_options(), Kernel, FunctionRegistry, ArgsValuesDescr, ctx);
+ }
+
+ return *static_cast<TState*>(result.AsBoxed().Get());
}
- static arrow::ValueDescr ToValueDescr(TType* type) {
- bool isOptional;
- arrow::ValueDescr ret;
- MKQL_ENSURE(ConvertInputArrowType(type, isOptional, ret), "can't get arrow type");
- return ret;
+private:
+ const ui32 StateIndex;
+ const TString FuncName;
+ const TVector<IComputationNode*> ArgsNodes;
+ const TVector<TType*> ArgsTypes;
+
+ const std::vector<arrow::ValueDescr> ArgsValuesDescr;
+ arrow::compute::FunctionRegistry& FunctionRegistry;
+ const arrow::compute::Function& Function;
+ const arrow::compute::ScalarKernel& Kernel;
+};
+
+class TBlockBitCastWrapper : public TMutableComputationNode<TBlockBitCastWrapper> {
+public:
+ TBlockBitCastWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* argType, TType* to)
+ : TMutableComputationNode(mutables)
+ , StateIndex(mutables.CurValueIndex++)
+ , Arg(arg)
+ , ArgsValuesDescr({ ToValueDescr(argType) })
+ , FunctionRegistry(*arrow::compute::GetFunctionRegistry())
+ , Function(ResolveFunction(FunctionRegistry, to))
+ , Kernel(ResolveKernel(Function, ArgsValuesDescr))
+ {
}
- static std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) {
- std::vector<arrow::ValueDescr> res;
- res.reserve(types.size());
- for (const auto& type : types) {
- res.emplace_back(ToValueDescr(type));
- }
+ NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
+ auto& state = GetState(ctx);
+
+ state.Values.clear();
+ state.Values.emplace_back(TArrowBlock::From(Arg->GetValue(ctx)).GetDatum());
+ Y_VERIFY_DEBUG(ArgsValuesDescr[0] == state.Values.back().descr());
- return res;
+ auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
+ ARROW_OK(state.Executor->Execute(state.Values, listener.get()));
+ auto output = state.Executor->WrapResults(state.Values, listener->values());
+ return ctx.HolderFactory.CreateArrowBlock(std::move(output));
}
- struct TState : public TComputationValue<TState> {
- using TComputationValue::TComputationValue;
-
- TState(TMemoryUsageInfo* memInfo, const arrow::compute::Function& function, const arrow::compute::ScalarKernel& kernel,
- arrow::compute::FunctionRegistry& registry, const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx)
- : TComputationValue(memInfo)
- , ExecContext(&ctx.ArrowMemoryPool, nullptr, &registry)
- , KernelContext(&ExecContext)
- , Executor(arrow::compute::detail::KernelExecutor::MakeScalar())
- {
- auto options = function.default_options();
- if (kernel.init) {
- State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
- KernelContext.SetState(State.get());
- }
-
- ARROW_OK(Executor->Init(&KernelContext, { &kernel, argsValuesDescr, options }));
- Values.reserve(argsValuesDescr.size());
- }
+private:
+ void RegisterDependencies() const final {
+ this->DependsOn(Arg);
+ }
- arrow::compute::ExecContext ExecContext;
- arrow::compute::KernelContext KernelContext;
- std::unique_ptr<arrow::compute::KernelState> State;
- std::unique_ptr<arrow::compute::detail::KernelExecutor> Executor;
+ static const arrow::compute::Function& ResolveFunction(const arrow::compute::FunctionRegistry& registry, TType* to) {
+ bool isOptional;
+ std::shared_ptr<arrow::DataType> type;
+ MKQL_ENSURE(ConvertArrowType(to, isOptional, type), "can't get arrow type");
- std::vector<arrow::Datum> Values;
- };
+ auto function = ARROW_RESULT(arrow::compute::GetCastFunction(type));
+ MKQL_ENSURE(function != nullptr, "missing function");
+ MKQL_ENSURE(function->kind() == arrow::compute::Function::SCALAR, "expected SCALAR function");
+ return *function;
+ }
TState& GetState(TComputationContext& ctx) const {
auto& result = ctx.MutableValues[StateIndex];
if (!result.HasValue()) {
- result = ctx.HolderFactory.Create<TState>(Function, Kernel, FunctionRegistry, ArgsValuesDescr, ctx);
+ arrow::compute::CastOptions options(false);
+ result = ctx.HolderFactory.Create<TState>(Function, (const arrow::compute::FunctionOptions*)&options, Kernel, FunctionRegistry, ArgsValuesDescr, ctx);
}
return *static_cast<TState*>(result.AsBoxed().Get());
@@ -127,10 +192,7 @@ private:
private:
const ui32 StateIndex;
- const TString FuncName;
- const TVector<IComputationNode*> ArgsNodes;
- const TVector<TType*> ArgsTypes;
-
+ IComputationNode* Arg;
const std::vector<arrow::ValueDescr> ArgsValuesDescr;
arrow::compute::FunctionRegistry& FunctionRegistry;
const arrow::compute::Function& Function;
@@ -158,5 +220,16 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto
);
}
+IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
+ MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
+ auto argNode = LocateNode(ctx.NodeLocator, callable, 0);
+ MKQL_ENSURE(callable.GetInput(1).GetStaticType()->IsType(), "Expected type");
+ return new TBlockBitCastWrapper(ctx.Mutables,
+ argNode,
+ callable.GetType()->GetArgumentType(0),
+ static_cast<TType*>(callable.GetInput(1).GetNode())
+ );
+}
+
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h
index 0bbd881fdf..8cd15bca62 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.h
@@ -6,6 +6,7 @@ namespace NKikimr {
namespace NMiniKQL {
IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx);
+IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFactoryContext& ctx);
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
index 524fe761a3..ee68f268c3 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
@@ -449,43 +449,70 @@ private:
const size_t Width_;
};
-arrow::Datum ExtractLiteral(TRuntimeNode n) {
- if (n.GetStaticType()->IsOptional()) {
- const auto* dataLiteral = AS_VALUE(TOptionalLiteral, n);
- if (!dataLiteral->HasItem()) {
- bool isOptional;
- auto unpacked = UnpackOptionalData(dataLiteral->GetType(), isOptional);
- std::shared_ptr<arrow::DataType> type;
- MKQL_ENSURE(ConvertArrowType(unpacked, isOptional, type), "Unsupported type of literal");
- return arrow::MakeNullScalar(type);
+class TAsScalarWrapper : public TMutableComputationNode<TAsScalarWrapper> {
+public:
+ TAsScalarWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* type)
+ : TMutableComputationNode(mutables)
+ , Arg_(arg)
+ {
+ bool isOptional;
+ auto unpacked = UnpackOptionalData(type, isOptional);
+ MKQL_ENSURE(ConvertArrowType(unpacked, isOptional, Type_), "Unsupported type of scalar");
+ Slot_ = *unpacked->GetDataSlot();
+ }
+
+ NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
+ auto value = Arg_->GetValue(ctx);
+ arrow::Datum result;
+ if (!value) {
+ result = arrow::MakeNullScalar(Type_);
+ } else {
+ switch (Slot_) {
+ case NUdf::EDataSlot::Bool:
+ result = arrow::Datum(static_cast<bool>(value.Get<bool>()));
+ break;
+ case NUdf::EDataSlot::Int8:
+ result = arrow::Datum(static_cast<int8_t>(value.Get<i8>()));
+ break;
+ case NUdf::EDataSlot::Uint8:
+ result = arrow::Datum(static_cast<uint8_t>(value.Get<ui8>()));
+ break;
+ case NUdf::EDataSlot::Int16:
+ result = arrow::Datum(static_cast<int16_t>(value.Get<i16>()));
+ break;
+ case NUdf::EDataSlot::Uint16:
+ result = arrow::Datum(static_cast<uint16_t>(value.Get<ui16>()));
+ break;
+ case NUdf::EDataSlot::Int32:
+ result = arrow::Datum(static_cast<int32_t>(value.Get<i32>()));
+ break;
+ case NUdf::EDataSlot::Uint32:
+ result = arrow::Datum(static_cast<uint32_t>(value.Get<ui32>()));
+ break;
+ case NUdf::EDataSlot::Int64:
+ result = arrow::Datum(static_cast<int64_t>(value.Get<i64>()));
+ break;
+ case NUdf::EDataSlot::Uint64:
+ result = arrow::Datum(static_cast<uint64_t>(value.Get<ui64>()));
+ break;
+ default:
+ MKQL_ENSURE(false, "Unsupported data slot");
+ }
}
- n = dataLiteral->GetItem();
+
+ return ctx.HolderFactory.CreateArrowBlock(std::move(result));
}
- const auto* dataLiteral = AS_VALUE(TDataLiteral, n);
- switch (*dataLiteral->GetType()->GetDataSlot()) {
- case NUdf::EDataSlot::Bool:
- return arrow::Datum(static_cast<bool>(dataLiteral->AsValue().Get<bool>()));
- case NUdf::EDataSlot::Int8:
- return arrow::Datum(static_cast<int8_t>(dataLiteral->AsValue().Get<i8>()));
- case NUdf::EDataSlot::Uint8:
- return arrow::Datum(static_cast<uint8_t>(dataLiteral->AsValue().Get<ui8>()));
- case NUdf::EDataSlot::Int16:
- return arrow::Datum(static_cast<int16_t>(dataLiteral->AsValue().Get<i16>()));
- case NUdf::EDataSlot::Uint16:
- return arrow::Datum(static_cast<uint16_t>(dataLiteral->AsValue().Get<ui16>()));
- case NUdf::EDataSlot::Int32:
- return arrow::Datum(static_cast<int32_t>(dataLiteral->AsValue().Get<i32>()));
- case NUdf::EDataSlot::Uint32:
- return arrow::Datum(static_cast<uint32_t>(dataLiteral->AsValue().Get<ui32>()));
- case NUdf::EDataSlot::Int64:
- return arrow::Datum(static_cast<int64_t>(dataLiteral->AsValue().Get<i64>()));
- case NUdf::EDataSlot::Uint64:
- return arrow::Datum(static_cast<uint64_t>(dataLiteral->AsValue().Get<ui64>()));
- default:
- MKQL_ENSURE(false, "Unsupported data slot");
+private:
+ void RegisterDependencies() const final {
+ DependsOn(Arg_);
}
-}
+
+private:
+ IComputationNode* const Arg_;
+ std::shared_ptr<arrow::DataType> Type_;
+ NUdf::EDataSlot Slot_;
+};
}
@@ -545,8 +572,7 @@ IComputationNode* WrapWideFromBlocks(TCallable& callable, const TComputationNode
IComputationNode* WrapAsScalar(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
- auto value = ExtractLiteral(callable.GetInput(0U));
- return ctx.NodeFactory.CreateImmutableNode(ctx.HolderFactory.CreateArrowBlock(std::move(value)));
+ return new TAsScalarWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0), callable.GetInput(0).GetStaticType());
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
index c3d3f5c1be..22cb1b5e80 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
@@ -264,6 +264,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller {
{"ToBlocks", &WrapToBlocks},
{"WideToBlocks", &WrapWideToBlocks},
{"BlockFunc", &WrapBlockFunc},
+ {"BlockBitCast", &WrapBlockBitCast},
{"FromBlocks", &WrapFromBlocks},
{"WideFromBlocks", &WrapWideFromBlocks},
{"AsScalar", &WrapAsScalar},
diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp
index 808b607b07..c083bf7f9f 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_program_builder.cpp
@@ -5182,6 +5182,17 @@ TRuntimeNode TProgramBuilder::BlockFunc(const std::string_view& funcName, TType*
return TRuntimeNode(builder.Build(), false);
}
+TRuntimeNode TProgramBuilder::BlockBitCast(TRuntimeNode value, TType* targetType) {
+ MKQL_ENSURE(value.GetStaticType()->IsBlock(), "Expected Block type");
+
+ auto returnType = TBlockType::Create(targetType, AS_TYPE(TBlockType, value.GetStaticType())->GetShape(), Env);
+ TCallableBuilder builder(Env, __func__, returnType);
+ builder.Add(value);
+ builder.Add(TRuntimeNode(targetType, true));
+
+ return TRuntimeNode(builder.Build(), false);
+}
+
bool CanExportType(TType* type, const TTypeEnvironment& env) {
if (type->GetKind() == TType::EKind::Type) {
return false; // Type of Type
diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h
index d514255b2d..8b42022443 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.h
+++ b/ydb/library/yql/minikql/mkql_program_builder.h
@@ -241,6 +241,7 @@ public:
TRuntimeNode AsScalar(TRuntimeNode value);
TRuntimeNode BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args);
+ TRuntimeNode BlockBitCast(TRuntimeNode value, TType* targetType);
// udfs
TRuntimeNode Udf(
diff --git a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp
index 3cb21a0e57..40fb07c48c 100644
--- a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp
+++ b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp
@@ -44,6 +44,27 @@ private:
}
}
+ bool HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, bool& has, TExprContext& ctx) const override {
+ try {
+ has = false;
+ TScopedAlloc alloc;
+ TTypeEnvironment env(alloc);
+ TProgramBuilder pgmBuilder(env, FunctionRegistry_);
+ TNullOutput null;
+ auto mkqlFromType = NCommon::BuildType(*from, pgmBuilder, null);
+ auto mkqlToType = NCommon::BuildType(*to, pgmBuilder, null);
+ if (!HasArrowCast(mkqlFromType, mkqlToType)) {
+ return true;
+ }
+
+ has = true;
+ return true;
+ } catch (const std::exception& e) {
+ ctx.AddError(TIssue(pos, e.what()));
+ return false;
+ }
+ }
+
private:
const IFunctionRegistry& FunctionRegistry_;
};
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index b242d345e3..757dbfafb4 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -2382,6 +2382,12 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
return ctx.ProgramBuilder.BlockFunc(node.Child(0)->Content(), returnType, args);
});
+ AddCallable("BlockBitCast", [](const TExprNode& node, TMkqlBuildContext& ctx) {
+ auto arg = MkqlBuildExpr(*node.Child(0), ctx);
+ auto targetType = BuildType(node, *node.Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder);
+ return ctx.ProgramBuilder.BlockBitCast(arg, targetType);
+ });
+
AddCallable("PgArray", [](const TExprNode& node, TMkqlBuildContext& ctx) {
std::vector<TRuntimeNode> args;
args.reserve(node.ChildrenSize());