diff options
author | vvvv <vvvv@ydb.tech> | 2022-09-22 21:12:15 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-09-22 21:12:15 +0300 |
commit | ba2a2a5b4931c54baa7a03996fe95e54bb11b62f (patch) | |
tree | a74dd0af53268317ae44f9179a9150bbab3acf2f | |
parent | ff33b35bbfa9b0e1a50980a487a23e4fae515cd5 (diff) | |
download | ydb-ba2a2a5b4931c54baa7a03996fe95e54bb11b62f.tar.gz |
use arrow ABI to call ClickHouse functions, move type operations into ITypeInfoHelper
11 files changed, 208 insertions, 88 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 170a016c82..7620d43bd2 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -4376,18 +4376,45 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext& return true; } - auto fit = funcs.find(node->Content()); - if (fit == funcs.end()) { - return true; + TExprNode::TListType funcArgs; + std::string_view arrowFunctionName; + bool bitcastToReturnType = false; + if (node->IsCallable("Apply") && node->Head().IsCallable("Udf")) { + auto func = node->Head().Head().Content(); + if (!func.StartsWith("ClickHouse.")) { + return true; + } + + TVector<const TTypeAnnotationNode*> allTypes; + allTypes.push_back(node->GetTypeAnn()); + for (ui32 i = 1; i < node->ChildrenSize(); ++i) { + allTypes.push_back(node->Child(i)->GetTypeAnn()); + } + + bool supported = false; + YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Pos()), allTypes, supported, ctx)); + if (!supported) { + return true; + } + + funcArgs.push_back(nullptr); + } else { + auto fit = funcs.find(node->Content()); + if (fit == funcs.end()) { + return true; + } + + arrowFunctionName = fit->second.Name; + bitcastToReturnType = fit->second.BitcastToReturnType; + funcArgs.push_back(ctx.NewAtom(node->Pos(), arrowFunctionName)); } - TExprNode::TListType funcArgs; - funcArgs.push_back(ctx.NewAtom(node->Pos(), fit->second.Name)); - for (const auto& child : node->Children()) { + for (ui32 i = arrowFunctionName.empty() ? 1 : 0; i < node->ChildrenSize(); ++i) { + auto child = node->Child(i); if (child->IsComplete()) { - funcArgs.push_back(ctx.NewCallable(node->Pos(), "AsScalar", { child })); + funcArgs.push_back(ctx.NewCallable(node->Pos(), "AsScalar", { node->ChildPtr(i) })); } else { - auto rit = rewrites.find(child.Get()); + auto rit = rewrites.find(child); if (rit == rewrites.end()) { return true; } @@ -4398,7 +4425,8 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext& const TTypeAnnotationNode* outType = nullptr; TVector<const TTypeAnnotationNode*> argTypes; - for (const auto& child : node->Children()) { + for (ui32 i = arrowFunctionName.empty() ? 1 : 0; i < node->ChildrenSize(); ++i) { + auto child = node->Child(i); if (child->IsComplete()) { argTypes.push_back(ctx.MakeType<TScalarExprType>(child->GetTypeAnn())); } else { @@ -4406,53 +4434,85 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext& } } - YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), fit->second.Name, argTypes, outType, ctx)); - if (!outType && !fit->second.BitcastToReturnType) { - return true; - } - - if (!outType) { - argTypes.clear(); - for (const auto& child : node->Children()) { - if (child->IsComplete()) { - argTypes.push_back(ctx.MakeType<TScalarExprType>(node->GetTypeAnn())); - } else { - argTypes.push_back(ctx.MakeType<TBlockExprType>(node->GetTypeAnn())); - } + if (!arrowFunctionName.empty()) { + YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), arrowFunctionName, argTypes, outType, ctx)); + if (!outType && !bitcastToReturnType) { + return true; } - YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), fit->second.Name, argTypes, outType, ctx)); if (!outType) { - return true; - } + argTypes.clear(); + for (const auto& child : node->Children()) { + if (child->IsComplete()) { + argTypes.push_back(ctx.MakeType<TScalarExprType>(node->GetTypeAnn())); + } else { + argTypes.push_back(ctx.MakeType<TBlockExprType>(node->GetTypeAnn())); + } + } - auto typeNode = ExpandType(node->Pos(), *node->GetTypeAnn(), ctx); - for (ui32 i = 1; i < funcArgs.size(); ++i) { - if (IsSameAnnotation(*node->Child(i-1)->GetTypeAnn(), *node->GetTypeAnn())) { - continue; + YQL_ENSURE(types.ArrowResolver->LoadFunctionMetadata(ctx.GetPosition(node->Pos()), arrowFunctionName, argTypes, outType, ctx)); + if (!outType) { + return true; } - if (node->Child(i-1)->IsComplete()) { - funcArgs[i] = ctx.Builder(node->Pos()) - .Callable("AsScalar") - .Callable(0, "BitCast") - .Add(0, funcArgs[i]->HeadPtr()) - .Add(1, typeNode) + auto typeNode = ExpandType(node->Pos(), *node->GetTypeAnn(), ctx); + for (ui32 i = 1; i < funcArgs.size(); ++i) { + if (IsSameAnnotation(*node->Child(i-1)->GetTypeAnn(), *node->GetTypeAnn())) { + continue; + } + + if (node->Child(i-1)->IsComplete()) { + funcArgs[i] = ctx.Builder(node->Pos()) + .Callable("AsScalar") + .Callable(0, "BitCast") + .Add(0, funcArgs[i]->HeadPtr()) + .Add(1, typeNode) + .Seal() .Seal() - .Seal() - .Build(); - } else { - funcArgs[i] = ctx.NewCallable(node->Pos(), "BlockBitCast", { funcArgs[i], typeNode }); + .Build(); + } else { + funcArgs[i] = ctx.NewCallable(node->Pos(), "BlockBitCast", { funcArgs[i], typeNode }); + } } } - } - bool isScalar; - if (!IsSameAnnotation(*node->GetTypeAnn(), *GetBlockItemType(*outType, isScalar))) { - return true; + bool isScalar; + if (!IsSameAnnotation(*node->GetTypeAnn(), *GetBlockItemType(*outType, isScalar))) { + return true; + } + + rewrites[node.Get()] = ctx.NewCallable(node->Pos(), "BlockFunc", std::move(funcArgs)); + } else { + funcArgs[0] = ctx.Builder(node->Head().Pos()) + .Callable("Udf") + .Add(0, node->Head().ChildPtr(0)) + .Add(1, node->Head().ChildPtr(1)) + .Callable(2, "TupleType") + .Callable(0, "TupleType") + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + for (ui32 i = 1; i < node->ChildrenSize(); ++i) { + auto child = node->Child(i); + auto originalTypeNode = node->Head().Child(2)->Head().Child(i - 1); + parent.Callable(i - 1, child->IsComplete() ? "ScalarType" : "BlockType") + .Add(0, originalTypeNode) + .Seal(); + } + + return parent; + }) + .Seal() + .Callable(1, "StructType") + .Seal() + .Callable(2, "TupleType") + .Seal() + .Seal() + .Add(3, node->Head().ChildPtr(3)) + .Seal() + .Build(); + + rewrites[node.Get()] = ctx.NewCallable(node->Pos(), "Apply", std::move(funcArgs)); } - rewrites[node.Get()] = ctx.NewCallable(node->Pos(), "BlockFunc", std::move(funcArgs)); ++newNodes; return true; }); diff --git a/ydb/library/yql/core/yql_arrow_resolver.h b/ydb/library/yql/core/yql_arrow_resolver.h index 87622c4e58..05e3eb3ad4 100644 --- a/ydb/library/yql/core/yql_arrow_resolver.h +++ b/ydb/library/yql/core/yql_arrow_resolver.h @@ -14,6 +14,8 @@ public: const TTypeAnnotationNode*& returnType, TExprContext& ctx) const = 0; virtual bool HasCast(const TPosition& pos, const TTypeAnnotationNode* from, const TTypeAnnotationNode* to, bool& has, TExprContext& ctx) const = 0; + + virtual bool AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, bool& supported, TExprContext& ctx) const = 0; }; } diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp index 26c28fe8a9..38a2b73a51 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.cpp +++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp @@ -4888,6 +4888,20 @@ TExprNode::TPtr ExpandTypeNoCache(TPositionHandle position, const TTypeAnnotatio return ret; } + case ETypeAnnotationKind::Block: + { + auto ret = ctx.NewCallable(position, "BlockType", + { ExpandType(position, *type.Cast<TBlockExprType>()->GetItemType(), ctx) }); + return ret; + } + + case ETypeAnnotationKind::Scalar: + { + auto ret = ctx.NewCallable(position, "ScalarType", + { ExpandType(position, *type.Cast<TScalarExprType>()->GetItemType(), ctx) }); + return ret; + } + default: YQL_ENSURE(false, "Unsupported kind: " << (ui32)type.GetKind()); } diff --git a/ydb/library/yql/minikql/computation/mkql_validate.cpp b/ydb/library/yql/minikql/computation/mkql_validate.cpp index b238b28aee..71d5e08c93 100644 --- a/ydb/library/yql/minikql/computation/mkql_validate.cpp +++ b/ydb/library/yql/minikql/computation/mkql_validate.cpp @@ -473,6 +473,7 @@ NUdf::TUnboxedValue TValidate<TValidateErrorPolicy, TValidateMode>::Value(const case TType::EKind::Stream: case TType::EKind::Variant: + case TType::EKind::Block: // TODO validate it break; diff --git a/ydb/library/yql/minikql/computation/mkql_value_builder_ut.cpp b/ydb/library/yql/minikql/computation/mkql_value_builder_ut.cpp index 9282538948..344e2fbe03 100644 --- a/ydb/library/yql/minikql/computation/mkql_value_builder_ut.cpp +++ b/ydb/library/yql/minikql/computation/mkql_value_builder_ut.cpp @@ -262,7 +262,7 @@ private: void TestArrowBlock() { auto type = FunctionTypeInfoBuilder.SimpleType<ui64>(); - auto atype = FunctionTypeInfoBuilder.MakeArrowType(type); + auto atype = TypeInfoHelper->MakeArrowType(type); { arrow::Datum d1(std::make_shared<arrow::UInt64Scalar>(123)); diff --git a/ydb/library/yql/minikql/mkql_type_builder.cpp b/ydb/library/yql/minikql/mkql_type_builder.cpp index 426f4ff6ec..7304949836 100644 --- a/ydb/library/yql/minikql/mkql_type_builder.cpp +++ b/ydb/library/yql/minikql/mkql_type_builder.cpp @@ -1443,24 +1443,10 @@ NUdf::IBlockTypeBuilder::TPtr TFunctionTypeInfoBuilder::Block(bool isScalar) con return new TBlockTypeBuilder(*this, isScalar); } -NUdf::IArrowType::TPtr TFunctionTypeInfoBuilder::MakeArrowType(const NUdf::TType* type) const { - bool isOptional; - std::shared_ptr<arrow::DataType> arrowType; - if (!ConvertArrowType(const_cast<TType*>(static_cast<const TType*>(type)), isOptional, arrowType)) { - return nullptr; - } - - return new TArrowType(arrowType); +void TFunctionTypeInfoBuilder::Unused1() { } -NUdf::IArrowType::TPtr TFunctionTypeInfoBuilder::ImportArrowType(ArrowSchema* schema) const { - auto res = arrow::ImportType(schema); - auto status = res.status(); - if (!status.ok()) { - UdfTerminate(status.ToString().c_str()); - } - - return new TArrowType(std::move(res).ValueOrDie()); +void TFunctionTypeInfoBuilder::Unused2() { } bool TFunctionTypeInfoBuilder::GetSecureParam(NUdf::TStringRef key, NUdf::TStringRef& value) const { @@ -1783,6 +1769,26 @@ const NYql::NUdf::TPgTypeDescription* TTypeInfoHelper::FindPgTypeDescription(ui3 return HugeSingleton<TPgTypeIndex>()->Resolve(typeId); } +NUdf::IArrowType::TPtr TTypeInfoHelper::MakeArrowType(const NUdf::TType* type) const { + bool isOptional; + std::shared_ptr<arrow::DataType> arrowType; + if (!ConvertArrowType(const_cast<TType*>(static_cast<const TType*>(type)), isOptional, arrowType)) { + return nullptr; + } + + return new TArrowType(arrowType); +} + +NUdf::IArrowType::TPtr TTypeInfoHelper::ImportArrowType(ArrowSchema* schema) const { + auto res = arrow::ImportType(schema); + auto status = res.status(); + if (!status.ok()) { + UdfTerminate(status.ToString().c_str()); + } + + return new TArrowType(std::move(res).ValueOrDie()); +} + void TTypeInfoHelper::DoData(const NMiniKQL::TDataType* dt, NUdf::ITypeVisitor* v) { const auto typeId = dt->GetSchemeType(); v->OnDataType(typeId); diff --git a/ydb/library/yql/minikql/mkql_type_builder.h b/ydb/library/yql/minikql/mkql_type_builder.h index 68fd699b2d..54aabcb507 100644 --- a/ydb/library/yql/minikql/mkql_type_builder.h +++ b/ydb/library/yql/minikql/mkql_type_builder.h @@ -139,8 +139,8 @@ public: NUdf::TType* Tagged(const NUdf::TType* baseType, const NUdf::TStringRef& tag) const override; NUdf::TType* Pg(ui32 typeId) const override; NUdf::IBlockTypeBuilder::TPtr Block(bool isScalar) const override; - NUdf::IArrowType::TPtr MakeArrowType(const NUdf::TType* type) const override; - NUdf::IArrowType::TPtr ImportArrowType(ArrowSchema* schema) const override; + void Unused1() override; + void Unused2() override; bool GetSecureParam(NUdf::TStringRef key, NUdf::TStringRef& value) const override; @@ -172,6 +172,8 @@ public: void VisitType(const NUdf::TType* type, NUdf::ITypeVisitor* visitor) const override; bool IsSameType(const NUdf::TType* type1, const NUdf::TType* type2) const override; const NYql::NUdf::TPgTypeDescription* FindPgTypeDescription(ui32 typeId) const override; + NUdf::IArrowType::TPtr MakeArrowType(const NUdf::TType* type) const override; + NUdf::IArrowType::TPtr ImportArrowType(ArrowSchema* schema) const override; private: static void DoData(const NMiniKQL::TDataType* dt, NUdf::ITypeVisitor* v); diff --git a/ydb/library/yql/minikql/mkql_type_builder_ut.cpp b/ydb/library/yql/minikql/mkql_type_builder_ut.cpp index cc18d49835..7e53503109 100644 --- a/ydb/library/yql/minikql/mkql_type_builder_ut.cpp +++ b/ydb/library/yql/minikql/mkql_type_builder_ut.cpp @@ -339,12 +339,12 @@ private: void TestArrowType() { auto type = FunctionTypeInfoBuilder.SimpleType<ui64>(); - auto atype1 = FunctionTypeInfoBuilder.MakeArrowType(type); + auto atype1 = TypeInfoHelper->MakeArrowType(type); UNIT_ASSERT(atype1); UNIT_ASSERT_VALUES_EQUAL(static_cast<TArrowType*>(atype1.Get())->GetType()->ToString(), std::string("uint64")); ArrowSchema s; atype1->Export(&s); - auto atype2 = FunctionTypeInfoBuilder.ImportArrowType(&s); + auto atype2 = TypeInfoHelper->ImportArrowType(&s); UNIT_ASSERT_VALUES_EQUAL(static_cast<TArrowType*>(atype2.Get())->GetType()->ToString(), std::string("uint64")); } }; diff --git a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp index 40fb07c48c..e7154eab56 100644 --- a/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp +++ b/ydb/library/yql/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp @@ -2,6 +2,7 @@ #include <ydb/library/yql/minikql/arrow/mkql_functions.h> #include <ydb/library/yql/minikql/mkql_program_builder.h> +#include <ydb/library/yql/minikql/mkql_type_builder.h> #include <ydb/library/yql/providers/common/mkql/yql_type_mkql.h> #include <util/stream/null.h> @@ -65,6 +66,30 @@ private: } } + bool AreTypesSupported(const TPosition& pos, const TVector<const TTypeAnnotationNode*>& types, bool& supported, TExprContext& ctx) const override { + try { + supported = false; + TScopedAlloc alloc; + TTypeEnvironment env(alloc); + TProgramBuilder pgmBuilder(env, FunctionRegistry_); + for (const auto& type : types) { + TNullOutput null; + auto mkqlType = NCommon::BuildType(*type, pgmBuilder, null); + bool isOptional; + std::shared_ptr<arrow::DataType> arrowType; + if (!ConvertArrowType(mkqlType, isOptional, arrowType)) { + return true; + } + } + + supported = true; + return true; + } catch (const std::exception& e) { + ctx.AddError(TIssue(pos, e.what())); + return false; + } + } + private: const IFunctionRegistry& FunctionRegistry_; }; diff --git a/ydb/library/yql/public/udf/udf_type_builder.h b/ydb/library/yql/public/udf/udf_type_builder.h index 7355e380a0..dc9825f24c 100644 --- a/ydb/library/yql/public/udf/udf_type_builder.h +++ b/ydb/library/yql/public/udf/udf_type_builder.h @@ -9,8 +9,6 @@ #include <type_traits> -struct ArrowSchema; - namespace NYql { namespace NUdf { @@ -630,29 +628,11 @@ public: #endif #if UDF_ABI_COMPATIBILITY_VERSION_CURRENT >= UDF_ABI_COMPATIBILITY_VERSION(2, 26) - -////////////////////////////////////////////////////////////////////////////// -// IArrowType -////////////////////////////////////////////////////////////////////////////// -class IArrowType -{ -public: - using TPtr = TUniquePtr<IArrowType>; - - virtual ~IArrowType() = default; - - virtual void Export(ArrowSchema* out) const = 0; -}; - -UDF_ASSERT_TYPE_SIZE(IArrowType, 8); - class IFunctionTypeInfoBuilder14: public IFunctionTypeInfoBuilder13 { public: virtual IBlockTypeBuilder::TPtr Block(bool isScalar) const = 0; - // returns nullptr if type isn't supported - virtual IArrowType::TPtr MakeArrowType(const TType* type) const = 0; - // The given ArrowSchema struct is released, even if this function fails. - virtual IArrowType::TPtr ImportArrowType(ArrowSchema* schema) const = 0; + virtual void Unused1() = 0; + virtual void Unused2() = 0; }; #endif diff --git a/ydb/library/yql/public/udf/udf_types.h b/ydb/library/yql/public/udf/udf_types.h index 1da1cd8844..5fee361e3d 100644 --- a/ydb/library/yql/public/udf/udf_types.h +++ b/ydb/library/yql/public/udf/udf_types.h @@ -6,6 +6,8 @@ #include "udf_type_size_check.h" #include "udf_version.h" +struct ArrowSchema; + namespace NYql { namespace NUdf { @@ -276,7 +278,35 @@ public: virtual const TPgTypeDescription* FindPgTypeDescription(ui32 typeId) const = 0; }; -#if UDF_ABI_COMPATIBILITY_VERSION_CURRENT >= UDF_ABI_COMPATIBILITY_VERSION(2, 25) +////////////////////////////////////////////////////////////////////////////// +// IArrowType +////////////////////////////////////////////////////////////////////////////// +class IArrowType +{ +public: + using TPtr = TUniquePtr<IArrowType>; + + virtual ~IArrowType() = default; + + virtual void Export(ArrowSchema* out) const = 0; +}; + +UDF_ASSERT_TYPE_SIZE(IArrowType, 8); + +class ITypeInfoHelper3 : public ITypeInfoHelper2 { +public: + using TPtr = TRefCountedPtr<ITypeInfoHelper3>; + +public: + // returns nullptr if type isn't supported + virtual IArrowType::TPtr MakeArrowType(const TType* type) const = 0; + // The given ArrowSchema struct is released, even if this function fails. + virtual IArrowType::TPtr ImportArrowType(ArrowSchema* schema) const = 0; +}; + +#if UDF_ABI_COMPATIBILITY_VERSION_CURRENT >= UDF_ABI_COMPATIBILITY_VERSION(2, 26) +using ITypeInfoHelper = ITypeInfoHelper3; +#elif UDF_ABI_COMPATIBILITY_VERSION_CURRENT >= UDF_ABI_COMPATIBILITY_VERSION(2, 25) using ITypeInfoHelper = ITypeInfoHelper2; #else using ITypeInfoHelper = ITypeInfoHelper1; |