summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoraneporada <[email protected]>2022-12-28 18:57:51 +0300
committeraneporada <[email protected]>2022-12-28 18:57:51 +0300
commitbbacf69a08eb0e0462465b1d8dacf55b7594da32 (patch)
treed7f67a608e39709d043fa4cbb79ad836179ae76f
parent48aaf60ffad50a3cd1dd55d80d8d2fde282f9f26 (diff)
Improve IArrowResolver - take into account return type when resolving arrow kernels
-rw-r--r--ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp114
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_blocks.cpp51
-rw-r--r--ydb/library/yql/core/yql_aggregate_expander.cpp31
-rw-r--r--ydb/library/yql/core/yql_arrow_resolver.h14
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions.cpp46
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions.h2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp38
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp6
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h4
-rw-r--r--ydb/library/yql/minikql/mkql_function_metadata.h18
-rw-r--r--ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp44
-rw-r--r--ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp2
-rw-r--r--ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp4
13 files changed, 196 insertions, 178 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
index 4b7612492ab..c35eeb5785f 100644
--- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
+++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
@@ -4469,9 +4469,9 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo
allInputTypes.push_back(i);
}
- bool supportedInputTypes = false;
- YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(lambda->Pos()), allInputTypes, supportedInputTypes, ctx));
- if (!supportedInputTypes) {
+ auto resolveStatus = types.ArrowResolver->AreTypesSupported(ctx.GetPosition(lambda->Pos()), allInputTypes, ctx);
+ YQL_ENSURE(resolveStatus != IArrowResolver::ERROR);
+ if (resolveStatus != IArrowResolver::OK) {
return false;
}
@@ -4539,52 +4539,30 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo
++newNodes;
return true;
}
- if (node->IsCallable("Apply") && node->Head().IsCallable("Udf")) {
+ const bool isUdf = node->IsCallable("Apply") && node->Head().IsCallable("Udf");
+ if (isUdf) {
auto func = node->Head().Head().Content();
if (!func.StartsWith("ClickHouse.")) {
return true;
}
+ }
+ {
TVector<const TTypeAnnotationNode*> allTypes;
allTypes.push_back(node->GetTypeAnn());
- for (ui32 i = 1; i < node->ChildrenSize(); ++i) {
+ for (ui32 i = isUdf ? 1 : 0; i < node->ChildrenSize(); ++i) {
allTypes.push_back(node->Child(i)->GetTypeAnn());
}
- bool supported = false;
- YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Pos()), allTypes, supported, ctx));
- if (!supported) {
- return true;
- }
-
- funcArgs.push_back(nullptr);
- } else {
- auto fit = funcs.find(node->Content());
- if (fit == funcs.end()) {
+ auto resolveStatus = types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Pos()), allTypes, ctx);
+ YQL_ENSURE(resolveStatus != IArrowResolver::ERROR);
+ if (resolveStatus != IArrowResolver::OK) {
return true;
}
-
- arrowFunctionName = fit->second.Name;
- funcArgs.push_back(ctx.NewAtom(node->Pos(), arrowFunctionName));
}
- for (ui32 i = arrowFunctionName.empty() ? 1 : 0; i < node->ChildrenSize(); ++i) {
- auto child = node->Child(i);
- if (child->IsComplete()) {
- funcArgs.push_back(ctx.NewCallable(node->Pos(), "AsScalar", { node->ChildPtr(i) }));
- } else {
- auto rit = rewrites.find(child);
- if (rit == rewrites.end()) {
- return true;
- }
-
- funcArgs.push_back(rit->second);
- }
- }
-
- const TTypeAnnotationNode* outType = nullptr;
TVector<const TTypeAnnotationNode*> argTypes;
- for (ui32 i = arrowFunctionName.empty() ? 1 : 0; i < node->ChildrenSize(); ++i) {
+ for (ui32 i = isUdf ? 1 : 0; i < node->ChildrenSize(); ++i) {
auto child = node->Child(i);
if (child->IsComplete()) {
argTypes.push_back(ctx.MakeType<TScalarExprType>(child->GetTypeAnn()));
@@ -4595,20 +4573,15 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo
}
}
- if (!arrowFunctionName.empty()) {
- YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), arrowFunctionName, argTypes, outType, ctx));
- if (!outType) {
- return true;
- }
-
- bool isScalar;
- if (!IsSameAnnotation(*node->GetTypeAnn(), *GetBlockItemType(*outType, isScalar))) {
- return true;
- }
-
- rewrites[node.Get()] = ctx.NewCallable(node->Pos(), "BlockFunc", std::move(funcArgs));
+ const TTypeAnnotationNode* outType = node->GetTypeAnn();
+ if (outType->HasFixedSizeRepr()) {
+ outType = ctx.MakeType<TBlockExprType>(outType);
} else {
- funcArgs[0] = ctx.Builder(node->Head().Pos())
+ outType = ctx.MakeType<TChunkedBlockExprType>(outType);
+ }
+
+ if (isUdf) {
+ funcArgs.push_back(ctx.Builder(node->Head().Pos())
.Callable("Udf")
.Add(0, node->Head().ChildPtr(0))
.Add(1, node->Head().ChildPtr(1))
@@ -4634,11 +4607,39 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo
.Seal()
.Add(3, node->Head().ChildPtr(3))
.Seal()
- .Build();
+ .Build());
+ } else {
+ auto fit = funcs.find(node->Content());
+ if (fit == funcs.end()) {
+ return true;
+ }
+
+ arrowFunctionName = fit->second.Name;
+ funcArgs.push_back(ctx.NewAtom(node->Pos(), arrowFunctionName));
+
+ auto resolveStatus = types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), arrowFunctionName, argTypes, outType, ctx);
+ YQL_ENSURE(resolveStatus != IArrowResolver::ERROR);
+ if (resolveStatus != IArrowResolver::OK) {
+ return true;
+ }
+ funcArgs.push_back(ExpandType(node->Pos(), *outType, ctx));
+ }
+
+ for (ui32 i = isUdf ? 1 : 0; i < node->ChildrenSize(); ++i) {
+ auto child = node->Child(i);
+ if (child->IsComplete()) {
+ funcArgs.push_back(ctx.NewCallable(node->Pos(), "AsScalar", { node->ChildPtr(i) }));
+ } else {
+ auto rit = rewrites.find(child);
+ if (rit == rewrites.end()) {
+ return true;
+ }
- rewrites[node.Get()] = ctx.NewCallable(node->Pos(), "Apply", std::move(funcArgs));
+ funcArgs.push_back(rit->second);
+ }
}
+ rewrites[node.Get()] = ctx.NewCallable(node->Pos(), isUdf ? "Apply" : "BlockFunc", std::move(funcArgs));
++newNodes;
return true;
});
@@ -4664,9 +4665,9 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo
if (lambda->ChildPtr(i)->IsComplete()) {
TVector<const TTypeAnnotationNode*> allTypes;
allTypes.push_back(lambda->ChildPtr(i)->GetTypeAnn());
- bool supported = false;
- YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(lambda->Pos()), allTypes, supported, ctx));
- if (supported) {
+ auto resolveStatus = types.ArrowResolver->AreTypesSupported(ctx.GetPosition(lambda->Pos()), allTypes, ctx);
+ YQL_ENSURE(resolveStatus != IArrowResolver::ERROR);
+ if (resolveStatus == IArrowResolver::OK) {
rewrites[lambda->Child(i)] = ctx.NewCallable(lambda->Pos(), "AsScalar", { lambda->ChildPtr(i) });
++newNodes;
}
@@ -4893,11 +4894,10 @@ TExprNode::TPtr OptimizeSkipTakeToBlocks(const TExprNode::TPtr& node, TExprConte
return node;
}
- bool supported = false;
- YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Head().Pos()),
- TVector<const TTypeAnnotationNode*>(allTypes.begin(), allTypes.end()),
- supported, ctx));
- if (!supported) {
+ auto resolveStatus = types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Head().Pos()),
+ TVector<const TTypeAnnotationNode*>(allTypes.begin(), allTypes.end()), ctx);
+ YQL_ENSURE(resolveStatus != IArrowResolver::ERROR);
+ if (resolveStatus != IArrowResolver::OK) {
return node;
}
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
index 3ad8cdf39e9..6b8df2433d4 100644
--- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
@@ -200,7 +200,7 @@ IGraphTransformer::TStatus BlockLogicalWrapper(const TExprNode::TPtr& input, TEx
IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) {
Y_UNUSED(output);
- if (!EnsureMinArgsCount(*input, 1U, ctx.Expr)) {
+ if (!EnsureMinArgsCount(*input, 2U, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
@@ -209,34 +209,25 @@ IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprN
}
auto name = input->Child(0)->Content();
-
- for (ui32 i = 1; i < input->ChildrenSize(); ++i) {
- if (!EnsureBlockOrScalarType(*input->Child(i), ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
- }
-
- if (!ctx.Types.ArrowResolver) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Arrow resolver isn't available"));
- return IGraphTransformer::TStatus::Error;
- }
-
- const TTypeAnnotationNode* outType = nullptr;
- TVector<const TTypeAnnotationNode*> argTypes;
- for (ui32 i = 1; i < input->ChildrenSize(); ++i) {
- argTypes.push_back(input->Child(i)->GetTypeAnn());
+ Y_UNUSED(name);
+ if (auto status = EnsureTypeRewrite(input->ChildRef(1), ctx.Expr); status != IGraphTransformer::TStatus::Ok) {
+ return status;
}
- if (!ctx.Types.ArrowResolver->LoadFunctionMetadata(ctx.Expr.GetPosition(input->Pos()), name, argTypes, outType, ctx.Expr)) {
+ auto returnType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
+ const bool allowChunked = true;
+ if (!EnsureBlockOrScalarType(input->Child(1)->Pos(), *returnType, ctx.Expr, allowChunked)) {
return IGraphTransformer::TStatus::Error;
}
- if (!outType) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() << "No such Arrow function: " << name));
- return IGraphTransformer::TStatus::Error;
+ for (ui32 i = 2; i < input->ChildrenSize(); ++i) {
+ if (!EnsureBlockOrScalarType(*input->Child(i), ctx.Expr, allowChunked)) {
+ return IGraphTransformer::TStatus::Error;
+ }
}
- input->SetTypeAnn(outType);
+ // TODO: more validation
+ input->SetTypeAnn(returnType);
return IGraphTransformer::TStatus::Ok;
}
@@ -261,24 +252,24 @@ IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TEx
bool isScalar;
auto inputType = GetBlockItemType(*input->Child(0)->GetTypeAnn(), isScalar);
-
auto outputType = input->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType();
- bool has = false;
- if (!ctx.Types.ArrowResolver->HasCast(ctx.Expr.GetPosition(input->Pos()), inputType, outputType, has, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
- if (!has) {
+ auto castStatus = ctx.Types.ArrowResolver->HasCast(ctx.Expr.GetPosition(input->Pos()), inputType, outputType, ctx.Expr);
+ if (castStatus == IArrowResolver::ERROR) {
+ return IGraphTransformer::TStatus::Error;
+ } else if (castStatus == IArrowResolver::NOT_FOUND) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "No such cast"));
return IGraphTransformer::TStatus::Error;
}
if (isScalar) {
input->SetTypeAnn(ctx.Expr.MakeType<TScalarExprType>(outputType));
- } else {
+ } else if (outputType->HasFixedSizeRepr()) {
input->SetTypeAnn(ctx.Expr.MakeType<TBlockExprType>(outputType));
+ } else {
+ input->SetTypeAnn(ctx.Expr.MakeType<TChunkedBlockExprType>(outputType));
}
-
+
return IGraphTransformer::TStatus::Ok;
}
diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp
index ab0fb4bef94..d1c9be8a6d4 100644
--- a/ydb/library/yql/core/yql_aggregate_expander.cpp
+++ b/ydb/library/yql/core/yql_aggregate_expander.cpp
@@ -491,9 +491,9 @@ TExprNode::TPtr TAggregateExpander::GetFinalAggStateExtractor(ui32 i) {
return Ctx.Builder(Node->Pos())
.Lambda()
.Param("item")
- .Callable("Member")
- .Arg(0, "item")
- .Add(1, columnNames[i])
+ .Callable("Member")
+ .Arg(0, "item")
+ .Add(1, columnNames[i])
.Seal()
.Seal()
.Build();
@@ -538,9 +538,9 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea
extractorRoots.push_back(extractorArgs[*rowIndex]);
}
- bool supported = false;
- YQL_ENSURE(TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allKeyTypes, supported, Ctx));
- if (!supported) {
+ auto resolveStatus = TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allKeyTypes, Ctx);
+ YQL_ENSURE(resolveStatus != IArrowResolver::ERROR);
+ if (resolveStatus != IArrowResolver::OK) {
return nullptr;
}
@@ -563,9 +563,10 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea
TVector<const TTypeAnnotationNode*> allTypes;
allTypes.push_back(root->GetTypeAnn());
- bool supported = false;
- YQL_ENSURE(TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allTypes, supported, Ctx));
- if (!supported) {
+
+ auto resolveStatus = TypesCtx.ArrowResolver->AreTypesSupported(Ctx.GetPosition(Node->Pos()), allKeyTypes, Ctx);
+ YQL_ENSURE(resolveStatus != IArrowResolver::ERROR);
+ if (resolveStatus != IArrowResolver::OK) {
return nullptr;
}
@@ -735,9 +736,9 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const
.Do(GetPartialAggArgExtractor(i, true))
.Done()
.With(1)
- .Callable("Member")
- .Arg(0, "state")
- .Add(1, columnNames[i])
+ .Callable("Member")
+ .Arg(0, "state")
+ .Add(1, columnNames[i])
.Seal()
.Done()
.Seal()
@@ -750,9 +751,9 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const
.Do(GetPartialAggArgExtractor(i, true))
.Done()
.With(1)
- .Callable("Member")
- .Arg(0, "state")
- .Add(1, columnNames[i])
+ .Callable("Member")
+ .Arg(0, "state")
+ .Add(1, columnNames[i])
.Seal()
.Done()
.With(2)
diff --git a/ydb/library/yql/core/yql_arrow_resolver.h b/ydb/library/yql/core/yql_arrow_resolver.h
index 05e3eb3ad4c..d30f6239ca3 100644
--- a/ydb/library/yql/core/yql_arrow_resolver.h
+++ b/ydb/library/yql/core/yql_arrow_resolver.h
@@ -8,14 +8,20 @@ class IArrowResolver : public TThrRefBase {
public:
using TPtr = TIntrusiveConstPtr<IArrowResolver>;
+ enum EStatus {
+ OK,
+ NOT_FOUND,
+ ERROR,
+ };
+
virtual ~IArrowResolver() = default;
- virtual bool LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes,
- const TTypeAnnotationNode*& returnType, TExprContext& ctx) const = 0;
+ virtual EStatus LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes,
+ const TTypeAnnotationNode* returnType, TExprContext& ctx) const = 0;
- virtual bool HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, bool& has, TExprContext& ctx) const = 0;
+ virtual EStatus HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, TExprContext& ctx) const = 0;
- virtual bool AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, bool& supported, TExprContext& ctx) const = 0;
+ virtual EStatus AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, TExprContext& ctx) const = 0;
};
}
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
index 798e3bf02b9..2f9250f88a0 100644
--- a/ydb/library/yql/minikql/arrow/mkql_functions.cpp
+++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
@@ -122,36 +122,56 @@ bool ConvertOutputArrowType(const arrow::compute::OutputType& outType, const std
}
}
-bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env, const IBuiltinFunctionRegistry& registry) {
+bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType* outputType, const IBuiltinFunctionRegistry& registry) {
bool hasOptionals = false;
bool many = false;
std::vector<NUdf::TDataTypeId> argTypes;
for (const auto& t : inputTypes) {
- bool isOptional;
auto asBlockType = AS_TYPE(TBlockType, t);
if (asBlockType->GetShape() == TBlockType::EShape::Many) {
many = true;
}
- auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional);
+ bool isOptional;
+ auto baseType = UnpackOptional(asBlockType->GetItemType(), isOptional);
+ if (!baseType->IsData()) {
+ return false;
+ }
+
hasOptionals = hasOptionals || isOptional;
- argTypes.push_back(dataType->GetSchemeType());
+ argTypes.push_back(AS_TYPE(TDataType, baseType)->GetSchemeType());
}
- auto kernel = registry.FindKernel(name, argTypes.data(), argTypes.size());
+ NUdf::TDataTypeId returnType;
+ bool returnIsOptional;
+ {
+ auto asBlockType = AS_TYPE(TBlockType, outputType);
+ MKQL_ENSURE(many ^ (asBlockType->GetShape() == TBlockType::EShape::Scalar), "Output shape is inconsistent with input shapes");
+ auto baseType = UnpackOptional(asBlockType->GetItemType(), returnIsOptional);
+ if (!baseType->IsData()) {
+ return false;
+ }
+ returnType = AS_TYPE(TDataType, baseType)->GetSchemeType();
+ }
+
+ auto kernel = registry.FindKernel(name, argTypes.data(), argTypes.size(), returnType);
if (!kernel) {
return false;
}
- outputType = TDataType::Create(kernel->ReturnType, env);
- if (kernel->Family.NullMode != TKernelFamily::ENullMode::AlwaysNotNull) {
- if (hasOptionals || kernel->Family.NullMode == TKernelFamily::ENullMode::AlwaysNull) {
- outputType = TOptionalType::Create(outputType, env);
- }
+ bool match = false;
+ switch (kernel->Family.NullMode) {
+ case TKernelFamily::ENullMode::Default:
+ match = returnIsOptional == hasOptionals;
+ break;
+ case TKernelFamily::ENullMode::AlwaysNull:
+ match = returnIsOptional;
+ break;
+ case TKernelFamily::ENullMode::AlwaysNotNull:
+ match = !returnIsOptional;
+ break;
}
-
- outputType = TBlockType::Create(outputType, many ? TBlockType::EShape::Many : TBlockType::EShape::Scalar, env);
- return true;
+ return match;
}
bool HasArrowCast(TType* from, TType* to) {
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.h b/ydb/library/yql/minikql/arrow/mkql_functions.h
index 267c79090e3..f33443debe0 100644
--- a/ydb/library/yql/minikql/arrow/mkql_functions.h
+++ b/ydb/library/yql/minikql/arrow/mkql_functions.h
@@ -7,7 +7,7 @@ namespace NKikimr::NMiniKQL {
class IBuiltinFunctionRegistry;
-bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env, const IBuiltinFunctionRegistry& registry);
+bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType* outputType, const IBuiltinFunctionRegistry& registry);
bool ConvertInputArrowType(TType* blockType, bool& isOptional, arrow::ValueDescr& descr);
bool HasArrowCast(TType* from, TType* to);
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
index 81ee6bb78de..7d13b9e8288 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
@@ -43,7 +43,7 @@ const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function
return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
}
-const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes) {
+const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes, TType* returnType) {
std::vector<NUdf::TDataTypeId> argTypes;
for (const auto& t : inputTypes) {
auto asBlockType = AS_TYPE(TBlockType, t);
@@ -52,7 +52,15 @@ const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TSt
argTypes.push_back(dataType->GetSchemeType());
}
- auto kernel = builtins.FindKernel(funcName, argTypes.data(), argTypes.size());
+ NUdf::TDataTypeId returnTypeId;
+ {
+ auto asBlockType = AS_TYPE(TBlockType, returnType);
+ bool isOptional;
+ auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional);
+ returnTypeId = dataType->GetSchemeType();
+ }
+
+ auto kernel = builtins.FindKernel(funcName, argTypes.data(), argTypes.size(), returnTypeId);
MKQL_ENSURE(kernel, "Can't find kernel for " << funcName);
return *kernel;
}
@@ -68,11 +76,11 @@ struct TState : public TComputationValue<TState> {
, ExecContext(&ctx.ArrowMemoryPool, nullptr, nullptr)
, KernelContext(&ExecContext)
{
- if (kernel.init) {
- State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
- KernelContext.SetState(State.get());
- }
-
+ if (kernel.init) {
+ State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ KernelContext.SetState(State.get());
+ }
+
Values.reserve(argsValuesDescr.size());
}
@@ -90,14 +98,15 @@ public:
const IBuiltinFunctionRegistry& builtins,
const TString& funcName,
TVector<IComputationNode*>&& argsNodes,
- TVector<TType*>&& argsTypes)
+ TVector<TType*>&& argsTypes,
+ TType* returnType)
: TMutableComputationNode(mutables)
, StateIndex(mutables.CurValueIndex++)
, FuncName(funcName)
, ArgsNodes(std::move(argsNodes))
, ArgsTypes(std::move(argsTypes))
, ArgsValuesDescr(ToValueDescr(ArgsTypes))
- , Kernel(ResolveKernel(builtins, FuncName, ArgsTypes))
+ , Kernel(ResolveKernel(builtins, FuncName, ArgsTypes, returnType))
{
}
@@ -112,8 +121,8 @@ public:
auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
- ARROW_OK(executor->Init(&state.KernelContext, { &Kernel.GetArrowKernel(), ArgsValuesDescr, state.Options }));
- ARROW_OK(executor->Execute(state.Values, listener.get()));
+ ARROW_OK(executor->Init(&state.KernelContext, { &Kernel.GetArrowKernel(), ArgsValuesDescr, state.Options }));
+ ARROW_OK(executor->Execute(state.Values, listener.get()));
auto output = executor->WrapResults(state.Values, listener->values());
return ctx.HolderFactory.CreateArrowBlock(std::move(output));
}
@@ -176,8 +185,8 @@ public:
auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
- ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, state.Options }));
- ARROW_OK(executor->Execute(state.Values, listener.get()));
+ ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, state.Options }));
+ ARROW_OK(executor->Execute(state.Values, listener.get()));
auto output = executor->WrapResults(state.Values, listener->values());
return ctx.HolderFactory.CreateArrowBlock(std::move(output));
}
@@ -234,7 +243,8 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto
*ctx.FunctionRegistry.GetBuiltins(),
funcName,
std::move(argsNodes),
- std::move(argsTypes)
+ std::move(argsTypes),
+ callableType->GetReturnType()
);
}
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp
index 2e5052ce150..9a7a300d118 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp
@@ -220,7 +220,7 @@ private:
const TDescriptionList& FindCandidates(const std::string_view& name) const;
- const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const final;
+ const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const final;
void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) final;
@@ -357,13 +357,13 @@ void TBuiltinFunctionRegistry::PrintInfoTo(IOutputStream& out) const
}
}
-const TKernel* TBuiltinFunctionRegistry::FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const {
+const TKernel* TBuiltinFunctionRegistry::FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const {
auto fit = KernelFamilyMap.find(TString(name));
if (fit == KernelFamilyMap.end()) {
return nullptr;
}
- return fit->second->FindKernel(argTypes, argTypesCount);
+ return fit->second->FindKernel(argTypes, argTypesCount, returnType);
}
void TBuiltinFunctionRegistry::RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) {
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h
index bef5393b263..4d7f6add989 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h
@@ -1188,7 +1188,7 @@ void AddBinaryKernel(TKernelFamilyBase& owner) {
arrow::compute::ScalarKernel k({ GetPrimitiveInputArrowType<TInput1>(), GetPrimitiveInputArrowType<TInput2>() }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::Exec);
k.null_handling = owner.NullMode == TKernelFamily::ENullMode::Default ? arrow::compute::NullHandling::INTERSECTION : arrow::compute::NullHandling::COMPUTED_PREALLOCATE;
- owner.KernelMap.emplace(argTypes, std::make_unique<TPlainKernel>(owner, argTypes, returnType, k));
+ owner.KernelMap.emplace(std::make_pair(argTypes, returnType), std::make_unique<TPlainKernel>(owner, argTypes, returnType, k));
}
template<typename TInput1, typename TInput2,
@@ -1204,7 +1204,7 @@ void AddBinaryPredicateKernel(TKernelFamilyBase& owner) {
arrow::compute::ScalarKernel k({ GetPrimitiveInputArrowType<TInput1>(), GetPrimitiveInputArrowType<TInput2>() }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::Exec);
k.null_handling = owner.NullMode == TKernelFamily::ENullMode::Default ? arrow::compute::NullHandling::INTERSECTION : arrow::compute::NullHandling::COMPUTED_PREALLOCATE;
- owner.KernelMap.emplace(argTypes, std::make_unique<TPlainKernel>(owner, argTypes, returnType, k));
+ owner.KernelMap.emplace(std::make_pair(argTypes, returnType), std::make_unique<TPlainKernel>(owner, argTypes, returnType, k));
}
template<template<typename, typename, typename> class TFunc>
diff --git a/ydb/library/yql/minikql/mkql_function_metadata.h b/ydb/library/yql/minikql/mkql_function_metadata.h
index 3042922133f..251710244e4 100644
--- a/ydb/library/yql/minikql/mkql_function_metadata.h
+++ b/ydb/library/yql/minikql/mkql_function_metadata.h
@@ -69,7 +69,7 @@ public:
{}
virtual ~TKernelFamily() = default;
- virtual const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const = 0;
+ virtual const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const = 0;
};
class TKernel {
@@ -90,18 +90,20 @@ public:
virtual ~TKernel() = default;
};
+using TKernelMapKey = std::pair<std::vector<NUdf::TDataTypeId>, NUdf::TDataTypeId>;
struct TTypeHasher {
- std::size_t operator()(const std::vector<NUdf::TDataTypeId>& s) const noexcept {
+ std::size_t operator()(const TKernelMapKey& s) const noexcept {
size_t r = 0;
- for (const auto& x : s) {
+ for (const auto& x : s.first) {
r = CombineHashes<size_t>(r, x);
}
+ r = CombineHashes<size_t>(r, s.second);
return r;
}
};
-using TKernelMap = std::unordered_map<std::vector<NUdf::TDataTypeId>, std::unique_ptr<TKernel>, TTypeHasher>;
+using TKernelMap = std::unordered_map<TKernelMapKey, std::unique_ptr<TKernel>, TTypeHasher>;
using TKernelFamilyMap = std::unordered_map<TString, std::unique_ptr<TKernelFamily>>;
@@ -112,9 +114,9 @@ public:
: TKernelFamily(nullMode, functionOptions)
{}
- const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const final {
- std::vector<NUdf::TDataTypeId> key(argTypes, argTypes + argTypesCount);
- auto it = KernelMap.find(key);
+ const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const final {
+ std::vector<NUdf::TDataTypeId> args(argTypes, argTypes + argTypesCount);
+ auto it = KernelMap.find({args, returnType});
if (it == KernelMap.end()) {
return nullptr;
}
@@ -144,7 +146,7 @@ public:
virtual TFunctionDescriptor GetBuiltin(const std::string_view& name, const std::pair<NUdf::TDataTypeId, bool>* argTypes, size_t argTypesCount) const = 0;
- virtual const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const = 0;
+ virtual const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const = 0;
virtual void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) = 0;
};
diff --git a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp
index e15b992de9f..c8b10f48c49 100644
--- a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp
+++ b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp
@@ -19,57 +19,45 @@ public:
{}
private:
- bool LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes,
- const TTypeAnnotationNode*& returnType, TExprContext& ctx) const override {
+ EStatus LoadFunctionMetadata(const TPosition& pos, TStringBuf name, const TVector<const TTypeAnnotationNode*>& argTypes,
+ const TTypeAnnotationNode* returnType, TExprContext& ctx) const override
+ {
try {
- returnType = nullptr;
TScopedAlloc alloc(__LOCATION__);
TTypeEnvironment env(alloc);
TProgramBuilder pgmBuilder(env, FunctionRegistry_);
- TType* mkqlOutputType;
+ TNullOutput null;
TVector<TType*> mkqlInputTypes;
for (const auto& type : argTypes) {
- TNullOutput null;
auto mkqlType = NCommon::BuildType(*type, pgmBuilder, null);
mkqlInputTypes.emplace_back(mkqlType);
}
-
- if (!FindArrowFunction(name, mkqlInputTypes, mkqlOutputType, env, *FunctionRegistry_.GetBuiltins())) {
- return true;
- }
-
- returnType = NCommon::ConvertMiniKQLType(pos, mkqlOutputType, ctx);
- return true;
+ TType* mkqlOutputType = NCommon::BuildType(*returnType, pgmBuilder, null);
+ bool found = FindArrowFunction(name, mkqlInputTypes, mkqlOutputType, *FunctionRegistry_.GetBuiltins());
+ return found ? EStatus::OK : EStatus::NOT_FOUND;
} catch (const std::exception& e) {
ctx.AddError(TIssue(pos, e.what()));
- return false;
+ return EStatus::ERROR;
}
}
- bool HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, bool& has, TExprContext& ctx) const override {
+ EStatus HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, TExprContext& ctx) const override {
try {
- has = false;
TScopedAlloc alloc(__LOCATION__);
TTypeEnvironment env(alloc);
TProgramBuilder pgmBuilder(env, FunctionRegistry_);
TNullOutput null;
auto mkqlFromType = NCommon::BuildType(*from, pgmBuilder, null);
auto mkqlToType = NCommon::BuildType(*to, pgmBuilder, null);
- if (!HasArrowCast(mkqlFromType, mkqlToType)) {
- return true;
- }
-
- has = true;
- return true;
+ return HasArrowCast(mkqlFromType, mkqlToType) ? EStatus::OK : EStatus::NOT_FOUND;
} catch (const std::exception& e) {
ctx.AddError(TIssue(pos, e.what()));
- return false;
+ return EStatus::ERROR;
}
}
- bool AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, bool& supported, TExprContext& ctx) const override {
+ EStatus AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, TExprContext& ctx) const override {
try {
- supported = false;
TScopedAlloc alloc(__LOCATION__);
TTypeEnvironment env(alloc);
TProgramBuilder pgmBuilder(env, FunctionRegistry_);
@@ -79,15 +67,13 @@ private:
bool isOptional;
std::shared_ptr<arrow::DataType> arrowType;
if (!ConvertArrowType(mkqlType, isOptional, arrowType)) {
- return true;
+ return EStatus::NOT_FOUND;
}
}
-
- supported = true;
- return true;
+ return EStatus::OK;
} catch (const std::exception& e) {
ctx.AddError(TIssue(pos, e.what()));
- return false;
+ return EStatus::ERROR;
}
}
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index cec8f7e8633..a6c89236af8 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -2390,7 +2390,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
AddCallable("BlockFunc", [](const TExprNode& node, TMkqlBuildContext& ctx) {
TVector<TRuntimeNode> args;
- for (ui32 i = 1; i < node.ChildrenSize(); ++i) {
+ for (ui32 i = 2; i < node.ChildrenSize(); ++i) {
args.push_back(MkqlBuildExpr(*node.Child(i), ctx));
}
diff --git a/ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp b/ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp
index 57ff41b96c6..bd610534c15 100644
--- a/ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp
+++ b/ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp
@@ -172,7 +172,9 @@ public:
allTypes.push_back(x->GetItemType());
}
- YQL_ENSURE(State_->Types->ArrowResolver->AreTypesSupported(ctx.GetPosition(read->Pos()), allTypes, supportedArrowTypes, ctx));
+ auto resolveStatus = State_->Types->ArrowResolver->AreTypesSupported(ctx.GetPosition(read->Pos()), allTypes, ctx);
+ YQL_ENSURE(resolveStatus != IArrowResolver::ERROR);
+ supportedArrowTypes = resolveStatus == IArrowResolver::OK;
}
return Build<TDqSourceWrap>(ctx, read->Pos())