diff options
| author | aneporada <[email protected]> | 2022-12-28 18:57:51 +0300 |
|---|---|---|
| committer | aneporada <[email protected]> | 2022-12-28 18:57:51 +0300 |
| commit | bbacf69a08eb0e0462465b1d8dacf55b7594da32 (patch) | |
| tree | d7f67a608e39709d043fa4cbb79ad836179ae76f | |
| parent | 48aaf60ffad50a3cd1dd55d80d8d2fde282f9f26 (diff) | |
Improve IArrowResolver - take into account return type when resolving arrow kernels
13 files changed, 196 insertions, 178 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 4b7612492ab..c35eeb5785f 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -4469,9 +4469,9 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo allInputTypes.push_back(i); } - bool supportedInputTypes = false; - YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(lambda->Pos()), allInputTypes, supportedInputTypes, ctx)); - if (!supportedInputTypes) { + auto resolveStatus = types.ArrowResolver->AreTypesSupported(ctx.GetPosition(lambda->Pos()), allInputTypes, ctx); + YQL_ENSURE(resolveStatus != IArrowResolver::ERROR); + if (resolveStatus != IArrowResolver::OK) { return false; } @@ -4539,52 +4539,30 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo ++newNodes; return true; } - if (node->IsCallable("Apply") && node->Head().IsCallable("Udf")) { + const bool isUdf = node->IsCallable("Apply") && node->Head().IsCallable("Udf"); + if (isUdf) { auto func = node->Head().Head().Content(); if (!func.StartsWith("ClickHouse.")) { return true; } + } + { TVector<const TTypeAnnotationNode*> allTypes; allTypes.push_back(node->GetTypeAnn()); - for (ui32 i = 1; i < node->ChildrenSize(); ++i) { + for (ui32 i = isUdf ? 1 : 0; i < node->ChildrenSize(); ++i) { allTypes.push_back(node->Child(i)->GetTypeAnn()); } - bool supported = false; - YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Pos()), allTypes, supported, ctx)); - if (!supported) { - return true; - } - - funcArgs.push_back(nullptr); - } else { - auto fit = funcs.find(node->Content()); - if (fit == funcs.end()) { + auto resolveStatus = types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Pos()), allTypes, ctx); + YQL_ENSURE(resolveStatus != IArrowResolver::ERROR); + if (resolveStatus != IArrowResolver::OK) { return true; } - - arrowFunctionName = fit->second.Name; - funcArgs.push_back(ctx.NewAtom(node->Pos(), arrowFunctionName)); } - for (ui32 i = arrowFunctionName.empty() ? 1 : 0; i < node->ChildrenSize(); ++i) { - auto child = node->Child(i); - if (child->IsComplete()) { - funcArgs.push_back(ctx.NewCallable(node->Pos(), "AsScalar", { node->ChildPtr(i) })); - } else { - auto rit = rewrites.find(child); - if (rit == rewrites.end()) { - return true; - } - - funcArgs.push_back(rit->second); - } - } - - const TTypeAnnotationNode* outType = nullptr; TVector<const TTypeAnnotationNode*> argTypes; - for (ui32 i = arrowFunctionName.empty() ? 1 : 0; i < node->ChildrenSize(); ++i) { + for (ui32 i = isUdf ? 1 : 0; i < node->ChildrenSize(); ++i) { auto child = node->Child(i); if (child->IsComplete()) { argTypes.push_back(ctx.MakeType<TScalarExprType>(child->GetTypeAnn())); @@ -4595,20 +4573,15 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo } } - if (!arrowFunctionName.empty()) { - YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), arrowFunctionName, argTypes, outType, ctx)); - if (!outType) { - return true; - } - - bool isScalar; - if (!IsSameAnnotation(*node->GetTypeAnn(), *GetBlockItemType(*outType, isScalar))) { - return true; - } - - rewrites[node.Get()] = ctx.NewCallable(node->Pos(), "BlockFunc", std::move(funcArgs)); + const TTypeAnnotationNode* outType = node->GetTypeAnn(); + if (outType->HasFixedSizeRepr()) { + outType = ctx.MakeType<TBlockExprType>(outType); } else { - funcArgs[0] = ctx.Builder(node->Head().Pos()) + outType = ctx.MakeType<TChunkedBlockExprType>(outType); + } + + if (isUdf) { + funcArgs.push_back(ctx.Builder(node->Head().Pos()) .Callable("Udf") .Add(0, node->Head().ChildPtr(0)) .Add(1, node->Head().ChildPtr(1)) @@ -4634,11 +4607,39 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo .Seal() .Add(3, node->Head().ChildPtr(3)) .Seal() - .Build(); + .Build()); + } else { + auto fit = funcs.find(node->Content()); + if (fit == funcs.end()) { + return true; + } + + arrowFunctionName = fit->second.Name; + funcArgs.push_back(ctx.NewAtom(node->Pos(), arrowFunctionName)); + + auto resolveStatus = types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), arrowFunctionName, argTypes, outType, ctx); + YQL_ENSURE(resolveStatus != IArrowResolver::ERROR); + if (resolveStatus != IArrowResolver::OK) { + return true; + } + funcArgs.push_back(ExpandType(node->Pos(), *outType, ctx)); + } + + for (ui32 i = isUdf ? 1 : 0; i < node->ChildrenSize(); ++i) { + auto child = node->Child(i); + if (child->IsComplete()) { + funcArgs.push_back(ctx.NewCallable(node->Pos(), "AsScalar", { node->ChildPtr(i) })); + } else { + auto rit = rewrites.find(child); + if (rit == rewrites.end()) { + return true; + } - rewrites[node.Get()] = ctx.NewCallable(node->Pos(), "Apply", std::move(funcArgs)); + funcArgs.push_back(rit->second); + } } + rewrites[node.Get()] = ctx.NewCallable(node->Pos(), isUdf ? "Apply" : "BlockFunc", std::move(funcArgs)); ++newNodes; return true; }); @@ -4664,9 +4665,9 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo if (lambda->ChildPtr(i)->IsComplete()) { TVector<const TTypeAnnotationNode*> allTypes; allTypes.push_back(lambda->ChildPtr(i)->GetTypeAnn()); - bool supported = false; - YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(lambda->Pos()), allTypes, supported, ctx)); - if (supported) { + auto resolveStatus = types.ArrowResolver->AreTypesSupported(ctx.GetPosition(lambda->Pos()), allTypes, ctx); + YQL_ENSURE(resolveStatus != IArrowResolver::ERROR); + if (resolveStatus == IArrowResolver::OK) { rewrites[lambda->Child(i)] = ctx.NewCallable(lambda->Pos(), "AsScalar", { lambda->ChildPtr(i) }); ++newNodes; } @@ -4893,11 +4894,10 @@ TExprNode::TPtr OptimizeSkipTakeToBlocks(const TExprNode::TPtr& node, TExprConte return node; } - bool supported = false; - YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Head().Pos()), - TVector<const TTypeAnnotationNode*>(allTypes.begin(), allTypes.end()), - supported, ctx)); - if (!supported) { + auto resolveStatus = types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Head().Pos()), + TVector<const TTypeAnnotationNode*>(allTypes.begin(), allTypes.end()), ctx); + YQL_ENSURE(resolveStatus != IArrowResolver::ERROR); + if (resolveStatus != IArrowResolver::OK) { return node; } diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index 3ad8cdf39e9..6b8df2433d4 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -200,7 +200,7 @@ IGraphTransformer::TStatus BlockLogicalWrapper(const TExprNode::TPtr& input, TEx IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { Y_UNUSED(output); - if (!EnsureMinArgsCount(*input, 1U, ctx.Expr)) { + if (!EnsureMinArgsCount(*input, 2U, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -209,34 +209,25 @@ IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprN } auto name = input->Child(0)->Content(); - - for (ui32 i = 1; i < input->ChildrenSize(); ++i) { - if (!EnsureBlockOrScalarType(*input->Child(i), ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - } - - if (!ctx.Types.ArrowResolver) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Arrow resolver isn't available")); - return IGraphTransformer::TStatus::Error; - } - - const TTypeAnnotationNode* outType = nullptr; - TVector<const TTypeAnnotationNode*> argTypes; - for (ui32 i = 1; i < input->ChildrenSize(); ++i) { - argTypes.push_back(input->Child(i)->GetTypeAnn()); + Y_UNUSED(name); + if (auto status = EnsureTypeRewrite(input->ChildRef(1), ctx.Expr); status != IGraphTransformer::TStatus::Ok) { + return status; } - if (!ctx.Types.ArrowResolver->LoadFunctionMetadata(ctx.Expr.GetPosition(input->Pos()), name, argTypes, outType, ctx.Expr)) { + auto returnType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); + const bool allowChunked = true; + if (!EnsureBlockOrScalarType(input->Child(1)->Pos(), *returnType, ctx.Expr, allowChunked)) { return IGraphTransformer::TStatus::Error; } - if (!outType) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() << "No such Arrow function: " << name)); - return IGraphTransformer::TStatus::Error; + for (ui32 i = 2; i < input->ChildrenSize(); ++i) { + if (!EnsureBlockOrScalarType(*input->Child(i), ctx.Expr, allowChunked)) { + return IGraphTransformer::TStatus::Error; + } } - input->SetTypeAnn(outType); + // TODO: more validation + input->SetTypeAnn(returnType); return IGraphTransformer::TStatus::Ok; } @@ -261,24 +252,24 @@ IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TEx bool isScalar; auto inputType = GetBlockItemType(*input->Child(0)->GetTypeAnn(), isScalar); - auto outputType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(); - bool has = false; - if (!ctx.Types.ArrowResolver->HasCast(ctx.Expr.GetPosition(input->Pos()), inputType, outputType, has, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - if (!has) { + auto castStatus = ctx.Types.ArrowResolver->HasCast(ctx.Expr.GetPosition(input->Pos()), inputType, outputType, ctx.Expr); + if (castStatus == IArrowResolver::ERROR) { + return IGraphTransformer::TStatus::Error; + } else if (castStatus == IArrowResolver::NOT_FOUND) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "No such cast")); return IGraphTransformer::TStatus::Error; } if (isScalar) { input->SetTypeAnn(ctx.Expr.MakeType<TScalarExprType>(outputType)); - } else { + } else if (outputType->HasFixedSizeRepr()) { input->SetTypeAnn(ctx.Expr.MakeType<TBlockExprType>(outputType)); + } else { + input->SetTypeAnn(ctx.Expr.MakeType<TChunkedBlockExprType>(outputType)); } - + return IGraphTransformer::TStatus::Ok; } diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp index ab0fb4bef94..d1c9be8a6d4 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.cpp +++ b/ydb/library/yql/core/yql_aggregate_expander.cpp @@ -491,9 +491,9 @@ TExprNode::TPtr TAggregateExpander::GetFinalAggStateExtractor(ui32 i) { return Ctx.Builder(Node->Pos()) .Lambda() .Param("item") - .Callable("Member")
- .Arg(0, "item")
- .Add(1, columnNames[i])
+ .Callable("Member") + .Arg(0, "item") + .Add(1, columnNames[i]) .Seal() .Seal() .Build(); @@ -538,9 +538,9 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea extractorRoots.push_back(extractorArgs[*rowIndex]); } - bool supported = false; - YQL_ENSURE(TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allKeyTypes, supported, Ctx)); - if (!supported) { + auto resolveStatus = TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allKeyTypes, Ctx); + YQL_ENSURE(resolveStatus != IArrowResolver::ERROR); + if (resolveStatus != IArrowResolver::OK) { return nullptr; } @@ -563,9 +563,10 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea TVector<const TTypeAnnotationNode*> allTypes; allTypes.push_back(root->GetTypeAnn()); - bool supported = false; - YQL_ENSURE(TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allTypes, supported, Ctx)); - if (!supported) { + + auto resolveStatus = TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allKeyTypes, Ctx); + YQL_ENSURE(resolveStatus != IArrowResolver::ERROR); + if (resolveStatus != IArrowResolver::OK) { return nullptr; } @@ -735,9 +736,9 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const .Do(GetPartialAggArgExtractor(i, true)) .Done() .With(1) - .Callable("Member")
- .Arg(0, "state")
- .Add(1, columnNames[i])
+ .Callable("Member") + .Arg(0, "state") + .Add(1, columnNames[i]) .Seal() .Done() .Seal() @@ -750,9 +751,9 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const .Do(GetPartialAggArgExtractor(i, true)) .Done() .With(1) - .Callable("Member")
- .Arg(0, "state")
- .Add(1, columnNames[i])
+ .Callable("Member") + .Arg(0, "state") + .Add(1, columnNames[i]) .Seal() .Done() .With(2) diff --git a/ydb/library/yql/core/yql_arrow_resolver.h b/ydb/library/yql/core/yql_arrow_resolver.h index 05e3eb3ad4c..d30f6239ca3 100644 --- a/ydb/library/yql/core/yql_arrow_resolver.h +++ b/ydb/library/yql/core/yql_arrow_resolver.h @@ -8,14 +8,20 @@ class IArrowResolver : public TThrRefBase { public: using TPtr = TIntrusiveConstPtr<IArrowResolver>; + enum EStatus { + OK, + NOT_FOUND, + ERROR, + }; + virtual ~IArrowResolver() = default; - virtual bool LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes, - const TTypeAnnotationNode*& returnType, TExprContext& ctx) const = 0; + virtual EStatus LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes, + const TTypeAnnotationNode* returnType, TExprContext& ctx) const = 0; - virtual bool HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, bool& has, TExprContext& ctx) const = 0; + virtual EStatus HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, TExprContext& ctx) const = 0; - virtual bool AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, bool& supported, TExprContext& ctx) const = 0; + virtual EStatus AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, TExprContext& ctx) const = 0; }; } diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp index 798e3bf02b9..2f9250f88a0 100644 --- a/ydb/library/yql/minikql/arrow/mkql_functions.cpp +++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp @@ -122,36 +122,56 @@ bool ConvertOutputArrowType(const arrow::compute::OutputType& outType, const std } } -bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env, const IBuiltinFunctionRegistry& registry) { +bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType* outputType, const IBuiltinFunctionRegistry& registry) { bool hasOptionals = false; bool many = false; std::vector<NUdf::TDataTypeId> argTypes; for (const auto& t : inputTypes) { - bool isOptional; auto asBlockType = AS_TYPE(TBlockType, t); if (asBlockType->GetShape() == TBlockType::EShape::Many) { many = true; } - auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional); + bool isOptional; + auto baseType = UnpackOptional(asBlockType->GetItemType(), isOptional); + if (!baseType->IsData()) { + return false; + } + hasOptionals = hasOptionals || isOptional; - argTypes.push_back(dataType->GetSchemeType()); + argTypes.push_back(AS_TYPE(TDataType, baseType)->GetSchemeType()); } - auto kernel = registry.FindKernel(name, argTypes.data(), argTypes.size()); + NUdf::TDataTypeId returnType; + bool returnIsOptional; + { + auto asBlockType = AS_TYPE(TBlockType, outputType); + MKQL_ENSURE(many ^ (asBlockType->GetShape() == TBlockType::EShape::Scalar), "Output shape is inconsistent with input shapes"); + auto baseType = UnpackOptional(asBlockType->GetItemType(), returnIsOptional); + if (!baseType->IsData()) { + return false; + } + returnType = AS_TYPE(TDataType, baseType)->GetSchemeType(); + } + + auto kernel = registry.FindKernel(name, argTypes.data(), argTypes.size(), returnType); if (!kernel) { return false; } - outputType = TDataType::Create(kernel->ReturnType, env); - if (kernel->Family.NullMode != TKernelFamily::ENullMode::AlwaysNotNull) { - if (hasOptionals || kernel->Family.NullMode == TKernelFamily::ENullMode::AlwaysNull) { - outputType = TOptionalType::Create(outputType, env); - } + bool match = false; + switch (kernel->Family.NullMode) { + case TKernelFamily::ENullMode::Default: + match = returnIsOptional == hasOptionals; + break; + case TKernelFamily::ENullMode::AlwaysNull: + match = returnIsOptional; + break; + case TKernelFamily::ENullMode::AlwaysNotNull: + match = !returnIsOptional; + break; } - - outputType = TBlockType::Create(outputType, many ? TBlockType::EShape::Many : TBlockType::EShape::Scalar, env); - return true; + return match; } bool HasArrowCast(TType* from, TType* to) { diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.h b/ydb/library/yql/minikql/arrow/mkql_functions.h index 267c79090e3..f33443debe0 100644 --- a/ydb/library/yql/minikql/arrow/mkql_functions.h +++ b/ydb/library/yql/minikql/arrow/mkql_functions.h @@ -7,7 +7,7 @@ namespace NKikimr::NMiniKQL { class IBuiltinFunctionRegistry; -bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env, const IBuiltinFunctionRegistry& registry); +bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType* outputType, const IBuiltinFunctionRegistry& registry); bool ConvertInputArrowType(TType* blockType, bool& isOptional, arrow::ValueDescr& descr); bool HasArrowCast(TType* from, TType* to); } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp index 81ee6bb78de..7d13b9e8288 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp @@ -43,7 +43,7 @@ const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function return *static_cast<const arrow::compute::ScalarKernel*>(kernel); } -const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes) { +const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes, TType* returnType) { std::vector<NUdf::TDataTypeId> argTypes; for (const auto& t : inputTypes) { auto asBlockType = AS_TYPE(TBlockType, t); @@ -52,7 +52,15 @@ const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TSt argTypes.push_back(dataType->GetSchemeType()); } - auto kernel = builtins.FindKernel(funcName, argTypes.data(), argTypes.size()); + NUdf::TDataTypeId returnTypeId; + { + auto asBlockType = AS_TYPE(TBlockType, returnType); + bool isOptional; + auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional); + returnTypeId = dataType->GetSchemeType(); + } + + auto kernel = builtins.FindKernel(funcName, argTypes.data(), argTypes.size(), returnTypeId); MKQL_ENSURE(kernel, "Can't find kernel for " << funcName); return *kernel; } @@ -68,11 +76,11 @@ struct TState : public TComputationValue<TState> { , ExecContext(&ctx.ArrowMemoryPool, nullptr, nullptr) , KernelContext(&ExecContext) { - if (kernel.init) {
- State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
- KernelContext.SetState(State.get());
- }
-
+ if (kernel.init) { + State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options })); + KernelContext.SetState(State.get()); + } + Values.reserve(argsValuesDescr.size()); } @@ -90,14 +98,15 @@ public: const IBuiltinFunctionRegistry& builtins, const TString& funcName, TVector<IComputationNode*>&& argsNodes, - TVector<TType*>&& argsTypes) + TVector<TType*>&& argsTypes, + TType* returnType) : TMutableComputationNode(mutables) , StateIndex(mutables.CurValueIndex++) , FuncName(funcName) , ArgsNodes(std::move(argsNodes)) , ArgsTypes(std::move(argsTypes)) , ArgsValuesDescr(ToValueDescr(ArgsTypes)) - , Kernel(ResolveKernel(builtins, FuncName, ArgsTypes)) + , Kernel(ResolveKernel(builtins, FuncName, ArgsTypes, returnType)) { } @@ -112,8 +121,8 @@ public: auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>(); auto executor = arrow::compute::detail::KernelExecutor::MakeScalar(); - ARROW_OK(executor->Init(&state.KernelContext, { &Kernel.GetArrowKernel(), ArgsValuesDescr, state.Options }));
- ARROW_OK(executor->Execute(state.Values, listener.get()));
+ ARROW_OK(executor->Init(&state.KernelContext, { &Kernel.GetArrowKernel(), ArgsValuesDescr, state.Options })); + ARROW_OK(executor->Execute(state.Values, listener.get())); auto output = executor->WrapResults(state.Values, listener->values()); return ctx.HolderFactory.CreateArrowBlock(std::move(output)); } @@ -176,8 +185,8 @@ public: auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>(); auto executor = arrow::compute::detail::KernelExecutor::MakeScalar(); - ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, state.Options }));
- ARROW_OK(executor->Execute(state.Values, listener.get()));
+ ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, state.Options })); + ARROW_OK(executor->Execute(state.Values, listener.get())); auto output = executor->WrapResults(state.Values, listener->values()); return ctx.HolderFactory.CreateArrowBlock(std::move(output)); } @@ -234,7 +243,8 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto *ctx.FunctionRegistry.GetBuiltins(), funcName, std::move(argsNodes), - std::move(argsTypes) + std::move(argsTypes), + callableType->GetReturnType() ); } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp index 2e5052ce150..9a7a300d118 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp @@ -220,7 +220,7 @@ private: const TDescriptionList& FindCandidates(const std::string_view& name) const; - const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const final; + const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const final; void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) final; @@ -357,13 +357,13 @@ void TBuiltinFunctionRegistry::PrintInfoTo(IOutputStream& out) const } } -const TKernel* TBuiltinFunctionRegistry::FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const { +const TKernel* TBuiltinFunctionRegistry::FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const { auto fit = KernelFamilyMap.find(TString(name)); if (fit == KernelFamilyMap.end()) { return nullptr; } - return fit->second->FindKernel(argTypes, argTypesCount); + return fit->second->FindKernel(argTypes, argTypesCount, returnType); } void TBuiltinFunctionRegistry::RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) { diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h index bef5393b263..4d7f6add989 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h @@ -1188,7 +1188,7 @@ void AddBinaryKernel(TKernelFamilyBase& owner) { arrow::compute::ScalarKernel k({ GetPrimitiveInputArrowType<TInput1>(), GetPrimitiveInputArrowType<TInput2>() }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::Exec); k.null_handling = owner.NullMode == TKernelFamily::ENullMode::Default ? arrow::compute::NullHandling::INTERSECTION : arrow::compute::NullHandling::COMPUTED_PREALLOCATE; - owner.KernelMap.emplace(argTypes, std::make_unique<TPlainKernel>(owner, argTypes, returnType, k)); + owner.KernelMap.emplace(std::make_pair(argTypes, returnType), std::make_unique<TPlainKernel>(owner, argTypes, returnType, k)); } template<typename TInput1, typename TInput2, @@ -1204,7 +1204,7 @@ void AddBinaryPredicateKernel(TKernelFamilyBase& owner) { arrow::compute::ScalarKernel k({ GetPrimitiveInputArrowType<TInput1>(), GetPrimitiveInputArrowType<TInput2>() }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::Exec); k.null_handling = owner.NullMode == TKernelFamily::ENullMode::Default ? arrow::compute::NullHandling::INTERSECTION : arrow::compute::NullHandling::COMPUTED_PREALLOCATE; - owner.KernelMap.emplace(argTypes, std::make_unique<TPlainKernel>(owner, argTypes, returnType, k)); + owner.KernelMap.emplace(std::make_pair(argTypes, returnType), std::make_unique<TPlainKernel>(owner, argTypes, returnType, k)); } template<template<typename, typename, typename> class TFunc> diff --git a/ydb/library/yql/minikql/mkql_function_metadata.h b/ydb/library/yql/minikql/mkql_function_metadata.h index 3042922133f..251710244e4 100644 --- a/ydb/library/yql/minikql/mkql_function_metadata.h +++ b/ydb/library/yql/minikql/mkql_function_metadata.h @@ -69,7 +69,7 @@ public: {} virtual ~TKernelFamily() = default; - virtual const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const = 0; + virtual const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const = 0; }; class TKernel { @@ -90,18 +90,20 @@ public: virtual ~TKernel() = default; }; +using TKernelMapKey = std::pair<std::vector<NUdf::TDataTypeId>, NUdf::TDataTypeId>; struct TTypeHasher { - std::size_t operator()(const std::vector<NUdf::TDataTypeId>& s) const noexcept { + std::size_t operator()(const TKernelMapKey& s) const noexcept { size_t r = 0; - for (const auto& x : s) { + for (const auto& x : s.first) { r = CombineHashes<size_t>(r, x); } + r = CombineHashes<size_t>(r, s.second); return r; } }; -using TKernelMap = std::unordered_map<std::vector<NUdf::TDataTypeId>, std::unique_ptr<TKernel>, TTypeHasher>; +using TKernelMap = std::unordered_map<TKernelMapKey, std::unique_ptr<TKernel>, TTypeHasher>; using TKernelFamilyMap = std::unordered_map<TString, std::unique_ptr<TKernelFamily>>; @@ -112,9 +114,9 @@ public: : TKernelFamily(nullMode, functionOptions) {} - const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const final { - std::vector<NUdf::TDataTypeId> key(argTypes, argTypes + argTypesCount); - auto it = KernelMap.find(key); + const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const final { + std::vector<NUdf::TDataTypeId> args(argTypes, argTypes + argTypesCount); + auto it = KernelMap.find({args, returnType}); if (it == KernelMap.end()) { return nullptr; } @@ -144,7 +146,7 @@ public: virtual TFunctionDescriptor GetBuiltin(const std::string_view& name, const std::pair<NUdf::TDataTypeId, bool>* argTypes, size_t argTypesCount) const = 0; - virtual const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const = 0; + virtual const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const = 0; virtual void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) = 0; }; diff --git a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp index e15b992de9f..c8b10f48c49 100644 --- a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp +++ b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp @@ -19,57 +19,45 @@ public: {} private: - bool LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes, - const TTypeAnnotationNode*& returnType, TExprContext& ctx) const override { + EStatus LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes, + const TTypeAnnotationNode* returnType, TExprContext& ctx) const override + { try { - returnType = nullptr; TScopedAlloc alloc(__LOCATION__); TTypeEnvironment env(alloc); TProgramBuilder pgmBuilder(env, FunctionRegistry_); - TType* mkqlOutputType; + TNullOutput null; TVector<TType*> mkqlInputTypes; for (const auto& type : argTypes) { - TNullOutput null; auto mkqlType = NCommon::BuildType(*type, pgmBuilder, null); mkqlInputTypes.emplace_back(mkqlType); } - - if (!FindArrowFunction(name, mkqlInputTypes, mkqlOutputType, env, *FunctionRegistry_.GetBuiltins())) { - return true; - } - - returnType = NCommon::ConvertMiniKQLType(pos, mkqlOutputType, ctx); - return true; + TType* mkqlOutputType = NCommon::BuildType(*returnType, pgmBuilder, null); + bool found = FindArrowFunction(name, mkqlInputTypes, mkqlOutputType, *FunctionRegistry_.GetBuiltins()); + return found ? EStatus::OK : EStatus::NOT_FOUND; } catch (const std::exception& e) { ctx.AddError(TIssue(pos, e.what())); - return false; + return EStatus::ERROR; } } - bool HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, bool& has, TExprContext& ctx) const override { + EStatus HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, TExprContext& ctx) const override { try { - has = false; TScopedAlloc alloc(__LOCATION__); TTypeEnvironment env(alloc); TProgramBuilder pgmBuilder(env, FunctionRegistry_); TNullOutput null; auto mkqlFromType = NCommon::BuildType(*from, pgmBuilder, null); auto mkqlToType = NCommon::BuildType(*to, pgmBuilder, null); - if (!HasArrowCast(mkqlFromType, mkqlToType)) { - return true; - } - - has = true; - return true; + return HasArrowCast(mkqlFromType, mkqlToType) ? EStatus::OK : EStatus::NOT_FOUND; } catch (const std::exception& e) { ctx.AddError(TIssue(pos, e.what())); - return false; + return EStatus::ERROR; } } - bool AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, bool& supported, TExprContext& ctx) const override { + EStatus AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, TExprContext& ctx) const override { try { - supported = false; TScopedAlloc alloc(__LOCATION__); TTypeEnvironment env(alloc); TProgramBuilder pgmBuilder(env, FunctionRegistry_); @@ -79,15 +67,13 @@ private: bool isOptional; std::shared_ptr<arrow::DataType> arrowType; if (!ConvertArrowType(mkqlType, isOptional, arrowType)) { - return true; + return EStatus::NOT_FOUND; } } - - supported = true; - return true; + return EStatus::OK; } catch (const std::exception& e) { ctx.AddError(TIssue(pos, e.what())); - return false; + return EStatus::ERROR; } } diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index cec8f7e8633..a6c89236af8 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -2390,7 +2390,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("BlockFunc", [](const TExprNode& node, TMkqlBuildContext& ctx) { TVector<TRuntimeNode> args; - for (ui32 i = 1; i < node.ChildrenSize(); ++i) { + for (ui32 i = 2; i < node.ChildrenSize(); ++i) { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } diff --git a/ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp b/ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp index 57ff41b96c6..bd610534c15 100644 --- a/ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp +++ b/ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp @@ -172,7 +172,9 @@ public: allTypes.push_back(x->GetItemType()); } - YQL_ENSURE(State_->Types->ArrowResolver->AreTypesSupported(ctx.GetPosition(read->Pos()), allTypes, supportedArrowTypes, ctx)); + auto resolveStatus = State_->Types->ArrowResolver->AreTypesSupported(ctx.GetPosition(read->Pos()), allTypes, ctx); + YQL_ENSURE(resolveStatus != IArrowResolver::ERROR); + supportedArrowTypes = resolveStatus == IArrowResolver::OK; } return Build<TDqSourceWrap>(ctx, read->Pos()) |
