diff options
author | udovichenko-r <udovichenko-r@yandex-team.com> | 2024-12-05 16:42:04 +0300 |
---|---|---|
committer | udovichenko-r <udovichenko-r@yandex-team.com> | 2024-12-05 17:07:33 +0300 |
commit | ac81b4ff5a8fd2fd4c1fafa78ac07defb71478dc (patch) | |
tree | 159a2723ccc1de82b274fa36603ac6cb723bfad3 | |
parent | 8aee4675a8b41c39f833026f579f9828deee361a (diff) | |
download | ydb-ac81b4ff5a8fd2fd4c1fafa78ac07defb71478dc.tar.gz |
[mkql] Use type memoization in compiler
YQL-19355
commit_hash:ff1684e65f8ec72be0a1407be5d31bb3fb0465a7
13 files changed, 182 insertions, 130 deletions
diff --git a/yql/essentials/core/arrow_kernels/request/request.cpp b/yql/essentials/core/arrow_kernels/request/request.cpp index 6249012269..b51adda710 100644 --- a/yql/essentials/core/arrow_kernels/request/request.cpp +++ b/yql/essentials/core/arrow_kernels/request/request.cpp @@ -1,5 +1,4 @@ #include "request.h" -#include <yql/essentials/providers/common/mkql/yql_type_mkql.h> #include <yql/essentials/providers/common/mkql/yql_provider_mkql.h> #include <yql/essentials/minikql/mkql_node_cast.h> #include <yql/essentials/minikql/mkql_node_serialization.h> @@ -246,18 +245,13 @@ TRuntimeNode TKernelRequestBuilder::MakeArg(const TTypeAnnotationNode* type) { } TBlockType* TKernelRequestBuilder::MakeType(const TTypeAnnotationNode* type) { - auto [it, inserted] = CachedTypes_.emplace(type, nullptr); - if (!inserted) { - return it->second; - } - TStringStream err; - const auto ret = NCommon::BuildType(*type, Pb_, err); + const auto ret = NCommon::BuildType(*type, Pb_, TypesMemoization_, err); if (!ret) { ythrow yexception() << err.Str(); } - return it->second = AS_TYPE(TBlockType, ret); + return AS_TYPE(TBlockType, ret); } } diff --git a/yql/essentials/core/arrow_kernels/request/request.h b/yql/essentials/core/arrow_kernels/request/request.h index bfed8d5dd7..9dad9b6830 100644 --- a/yql/essentials/core/arrow_kernels/request/request.h +++ b/yql/essentials/core/arrow_kernels/request/request.h @@ -1,6 +1,7 @@ #pragma once #include <yql/essentials/ast/yql_expr.h> #include <yql/essentials/minikql/mkql_program_builder.h> +#include <yql/essentials/providers/common/mkql/yql_type_mkql.h> #include <unordered_map> @@ -63,7 +64,7 @@ private: NKikimr::NMiniKQL::TProgramBuilder Pb_; std::vector<NKikimr::NMiniKQL::TRuntimeNode> Items_; std::vector<NKikimr::NMiniKQL::TRuntimeNode> ArgsItems_; - std::unordered_map<const TTypeAnnotationNode*, NKikimr::NMiniKQL::TBlockType*> CachedTypes_; + NCommon::TMemoizedTypesMap TypesMemoization_; std::unordered_map<const TTypeAnnotationNode*, NKikimr::NMiniKQL::TRuntimeNode> CachedArgs_; }; diff --git a/yql/essentials/core/services/yql_eval_params.cpp b/yql/essentials/core/services/yql_eval_params.cpp index 8002becd8e..23b716d0ec 100644 --- a/yql/essentials/core/services/yql_eval_params.cpp +++ b/yql/essentials/core/services/yql_eval_params.cpp @@ -27,6 +27,7 @@ bool BuildParameterValuesAsNodes(const THashMap<TStringBuf, const TTypeAnnotatio TTypeEnvironment env(alloc); TMemoryUsageInfo memInfo("Parameters"); THolderFactory holderFactory(alloc.Ref(), memInfo); + NCommon::TMemoizedTypesMap typesMemoization; bool isOk = true; auto& paramDataMap = paramData.AsMap(); for (auto& p : paramTypes) { @@ -34,7 +35,7 @@ bool BuildParameterValuesAsNodes(const THashMap<TStringBuf, const TTypeAnnotatio TStringStream err; TProgramBuilder pgmBuilder(env, functionRegistry); - TType* mkqlType = NCommon::BuildType(*p.second, pgmBuilder, err); + TType* mkqlType = NCommon::BuildType(*p.second, pgmBuilder, typesMemoization, err); if (!mkqlType) { ctx.AddError(TIssue({}, TStringBuilder() << "Failed to process type for parameter: " << name << ", reason: " << err.Str())); isOk = false; diff --git a/yql/essentials/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp b/yql/essentials/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp index 5aacd3a669..58bf38ca45 100644 --- a/yql/essentials/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp +++ b/yql/essentials/providers/common/arrow_resolve/yql_simple_arrow_resolver.cpp @@ -28,12 +28,14 @@ private: TTypeBuilder typeBuilder(env); TNullOutput null; TVector<TType*> mkqlInputTypes; + NCommon::TMemoizedTypesMap typeMemoization; for (const auto& type : argTypes) { - auto mkqlType = NCommon::BuildType(*type, typeBuilder, null); + auto mkqlType = NCommon::BuildType(*type, typeBuilder, typeMemoization, null); YQL_ENSURE(mkqlType, "Failed to convert type " << *type << " to MKQL type"); mkqlInputTypes.emplace_back(mkqlType); } - TType* mkqlOutputType = NCommon::BuildType(*returnType, typeBuilder, null); + TType* mkqlOutputType = NCommon::BuildType(*returnType, typeBuilder, typeMemoization, null); + YQL_ENSURE(mkqlOutputType, "Failed to convert type " << *returnType << " to MKQL type"); bool found = FindArrowFunction(name, mkqlInputTypes, mkqlOutputType, *FunctionRegistry_.GetBuiltins()); return found ? EStatus::OK : EStatus::NOT_FOUND; } catch (const std::exception& e) { @@ -47,9 +49,12 @@ private: TScopedAlloc alloc(__LOCATION__); TTypeEnvironment env(alloc); TTypeBuilder typeBuilder(env); + NCommon::TMemoizedTypesMap typeMemoization; TNullOutput null; - auto mkqlFromType = NCommon::BuildType(*from, typeBuilder, null); - auto mkqlToType = NCommon::BuildType(*to, typeBuilder, null); + auto mkqlFromType = NCommon::BuildType(*from, typeBuilder, typeMemoization, null); + YQL_ENSURE(mkqlFromType, "Failed to convert type " << *from << " to MKQL type"); + auto mkqlToType = NCommon::BuildType(*to, typeBuilder, typeMemoization, null); + YQL_ENSURE(mkqlToType, "Failed to convert type " << *to << " to MKQL type"); return HasArrowCast(mkqlFromType, mkqlToType) ? EStatus::OK : EStatus::NOT_FOUND; } catch (const std::exception& e) { ctx.AddError(TIssue(pos, e.what())); @@ -64,8 +69,10 @@ private: TScopedAlloc alloc(__LOCATION__); TTypeEnvironment env(alloc); TTypeBuilder typeBuilder(env); + NCommon::TMemoizedTypesMap typeMemoization; + TNullOutput null; - bool allOk = true; + bool allOk = true; TArrowConvertFailedCallback cb; if (onUnsupported) { cb = [&](TType* failed) { @@ -81,8 +88,8 @@ private: for (const auto& type : types) { YQL_ENSURE(type); - TNullOutput null; - auto mkqlType = NCommon::BuildType(*type, typeBuilder, null); + auto mkqlType = NCommon::BuildType(*type, typeBuilder, typeMemoization, null); + YQL_ENSURE(mkqlType); std::shared_ptr<arrow::DataType> arrowType; if (!ConvertArrowType(mkqlType, arrowType, cb)) { allOk = false; diff --git a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp index d8ad662c9b..fb2a2a29c1 100644 --- a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp +++ b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp @@ -170,7 +170,7 @@ TMkqlBuildContext* GetContextForMemoize(const TExprNode& node, TMkqlBuildContext const TRuntimeNode& CheckTypeAndMemoize(const TExprNode& node, TMkqlBuildContext& ctx, const TRuntimeNode& runtime) { if (node.GetTypeAnn()) { TNullOutput null; - if (const auto type = BuildType(*node.GetTypeAnn(), ctx.ProgramBuilder, null)) { + if (const auto type = BuildType(*node.GetTypeAnn(), ctx.ProgramBuilder, *ctx.TypeMemoization, null)) { if (!type->IsSameType(*runtime.GetStaticType())) { ythrow TNodeException(node) << "Expected: " << *type << " type, but got: " << *runtime.GetStaticType() << "."; } @@ -661,7 +661,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable({"RoundUp", "RoundDown"}, [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto dstType = BuildType(node.Tail(), *node.Tail().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto dstType = ctx.BuildType(node.Tail(), *node.Tail().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); return ctx.ProgramBuilder.Round(node.Content(), arg, dstType); }); @@ -1020,7 +1020,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("CurrentActorId", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); return TRuntimeNode(call.Build(), false); }); @@ -1337,7 +1337,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("Struct", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto structType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto structType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto verifiedStructType = AS_TYPE(TStructType, structType); std::vector<std::pair<std::string_view, TRuntimeNode>> members; members.reserve(verifiedStructType->GetMembersCount()); @@ -1356,7 +1356,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("List", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto listType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto listType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto itemType = AS_TYPE(TListType, listType)->GetItemType(); const auto& items = GetArgumentsFrom<1U>(node, ctx); return ctx.ProgramBuilder.NewList(itemType, items); @@ -1364,48 +1364,48 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("FromString", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.FromString(arg, type); }); AddCallable("StrictFromString", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.StrictFromString(arg, type); }); AddCallable("FromBytes", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node, *node.GetTypeAnn()->Cast<TOptionalExprType>()->GetItemType(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()->Cast<TOptionalExprType>()->GetItemType()); return ctx.ProgramBuilder.FromBytes(arg, type); }); AddCallable("Convert", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.Convert(arg, type); }); AddCallable("ToIntegral", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.ToIntegral(arg, type); }); AddCallable("UnsafeTimestampCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.Convert(arg, type); }); AddCallable("SafeCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.Cast(arg, type); }); AddCallable("Default", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.Default(type); }); @@ -1427,7 +1427,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("EmptyFrom", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); switch (node.GetTypeAnn()->GetKind()) { case ETypeAnnotationKind::Flow: case ETypeAnnotationKind::Stream: @@ -1444,18 +1444,18 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("Nothing", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto optType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto optType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); return ctx.ProgramBuilder.NewEmptyOptional(optType); }); AddCallable("Unpickle", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto serialized = MkqlBuildExpr(node.Tail(), ctx); return ctx.ProgramBuilder.Unpickle(type, serialized); }); AddCallable("Optional", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto optType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto optType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto arg = MkqlBuildExpr(node.Tail(), ctx); return ctx.ProgramBuilder.NewOptional(optType, arg); }); @@ -1467,7 +1467,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("EmptyIterator", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto streamType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto streamType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); return ctx.ProgramBuilder.EmptyIterator(streamType); }); @@ -1495,11 +1495,11 @@ TMkqlCommonCallableCompiler::TShared::TShared() { } offset += outputStreams; - input.InputType = BuildType(lambdaArg, *lambdaArg.GetTypeAnn(), ctx.ProgramBuilder); + input.InputType = ctx.BuildType(lambdaArg, *lambdaArg.GetTypeAnn()); inputs.emplace_back(input); } - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.Switch(stream, inputs, [&](ui32 index, TRuntimeNode item) -> TRuntimeNode { return MkqlBuildLambda(*node.Child(2 + 2 * index + 1), ctx, {item}); }, memoryLimitBytes, returnType); @@ -1664,7 +1664,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { ythrow TNodeException(node) << "Wrong MapJoinCore input item type: " << inputItemType; } - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.MapJoinCore(list, dict, joinKind, leftKeyColumns, leftRenames, rightRenames, returnType); }); @@ -1715,7 +1715,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { anyJoinSettings = EAnyJoinSettings::Left == anyJoinSettings ? EAnyJoinSettings::Both : EAnyJoinSettings::Right; }); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return selfJoin ? ctx.ProgramBuilder.GraceSelfJoin(flowLeft, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings) @@ -1800,7 +1800,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { } } - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.CommonJoinCore(list, joinKind, leftColumns, rightColumns, requiredColumns, keyColumns, memLimit, sortedTableOrder, anyJoinSettings, tableIndexFieldPos, returnType); }); @@ -2084,7 +2084,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("Dict", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto listType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto listType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto dictType = AS_TYPE(TDictType, listType); std::vector<std::pair<TRuntimeNode, TRuntimeNode>> items; @@ -2099,7 +2099,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("Variant", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto varType = node.Child(2)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TVariantExprType>(); - const auto type = BuildType(*node.Child(2), *varType, ctx.ProgramBuilder); + const auto type = ctx.BuildType(*node.Child(2), *varType); const auto item = MkqlBuildExpr(node.Head(), ctx); return varType->GetUnderlyingType()->GetKind() == ETypeAnnotationKind::Tuple ? @@ -2168,17 +2168,17 @@ TMkqlCommonCallableCompiler::TShared::TShared() { "EmptyListType", "EmptyDictType"}, [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return TRuntimeNode(type, true); }); AddCallable("ParseType", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return TRuntimeNode(type, true); }); AddCallable("TypeOf", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return TRuntimeNode(type, true); }); @@ -2203,14 +2203,14 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("SourceOf", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.SourceOf(type); }); AddCallable("TypeHandle", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto type = node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(); const auto yson = WriteTypeToYson(type); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::Yson>(yson)); return TRuntimeNode(call.Build(), false); @@ -2220,7 +2220,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { const auto type = node.Head().GetTypeAnn(); const auto yson = WriteTypeToYson(type); const auto& args = GetAllArguments(node, ctx); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(pos.File)); @@ -2244,7 +2244,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { "SerializeCode", }, [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto& args = GetAllArguments(node, ctx); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); for (auto arg : args) { call.Add(arg); @@ -2291,7 +2291,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { "FuncCode", }, [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto& args = GetAllArguments(node, ctx); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(pos.File)); @@ -2307,7 +2307,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("LambdaCode", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto lambda = node.Child(node.ChildrenSize() - 1); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(pos.File)); @@ -2325,7 +2325,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { TMkqlBuildContext::TArgumentsMap innerArguments; innerArguments.reserve(lambda->Head().ChildrenSize()); lambda->Head().ForEachChild([&](const TExprNode& argNode) { - const auto argType = BuildType(argNode, *argNode.GetTypeAnn(), ctx.ProgramBuilder); + const auto argType = ctx.BuildType(argNode, *argNode.GetTypeAnn()); const auto arg = ctx.ProgramBuilder.Arg(argType); innerArguments.emplace(&argNode, arg); call.Add(arg); @@ -2380,7 +2380,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable({ "AsTagged","Untag" }, [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(node.Head(), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.Nop(input, returnType); }); @@ -2453,9 +2453,9 @@ TMkqlCommonCallableCompiler::TShared::TShared() { YQL_ENSURE(node.ChildrenSize() == 8); std::string_view function = node.Head().Content(); const auto runConfig = MkqlBuildExpr(*node.Child(1), ctx); - const auto userType = BuildType(*node.Child(2), *node.Child(2)->GetTypeAnn(), ctx.ProgramBuilder); + const auto userType = ctx.BuildType(*node.Child(2), *node.Child(2)->GetTypeAnn()); const auto typeConfig = node.Child(3)->Content(); - const auto callableType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto callableType = ctx.BuildType(node, *node.GetTypeAnn()); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); return ctx.ProgramBuilder.TypedUdf(function, callableType, runConfig, userType, typeConfig, pos.File, pos.Row, pos.Column); @@ -2471,7 +2471,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { std::string_view funcName = node.Child(1)->Content(); const auto typeNode = node.Child(2); - const auto funcType = BuildType(*typeNode, *typeNode->GetTypeAnn(), ctx.ProgramBuilder); + const auto funcType = ctx.BuildType(*typeNode, *typeNode->GetTypeAnn()); const auto script = MkqlBuildExpr(*node.Child(3), ctx); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); return ctx.ProgramBuilder.ScriptUdf(node.Head().Content(), funcName, funcType, script, @@ -2534,7 +2534,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { continue; } - auto mkqlType = BuildType(node, *callableType->GetArguments()[i].Type, ctx.ProgramBuilder); + auto mkqlType = ctx.BuildType(node, *callableType->GetArguments()[i].Type); arg = ctx.ProgramBuilder.NewEmptyOptional(mkqlType); } @@ -2546,7 +2546,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("Callable", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto callableType = BuildType(node.Head(), *node.Head().GetTypeAnn(), ctx.ProgramBuilder); + const auto callableType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()); return ctx.ProgramBuilder.Callable(callableType, [&](const TArrayRef<const TRuntimeNode>& args) { const auto& lambda = node.Tail(); TMkqlBuildContext::TArgumentsMap innerArguments; @@ -2560,7 +2560,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("PgConst", [](const TExprNode& node, TMkqlBuildContext& ctx) { - auto type = AS_TYPE(TPgType, BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder)); + auto type = AS_TYPE(TPgType, ctx.BuildType(node, *node.GetTypeAnn())); TRuntimeNode typeMod; if (node.ChildrenSize() >= 3) { typeMod = MkqlBuildExpr(*node.Child(2), ctx); @@ -2581,7 +2581,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("PgInternal0", [](const TExprNode& node, TMkqlBuildContext& ctx) { - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgInternal0(returnType); }); @@ -2601,7 +2601,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { } } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgResolvedCall(node.IsCallable("PgResolvedCallCtx"), name, id, args, returnType, rangeFunction); }); @@ -2615,7 +2615,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgResolvedCall(false, procName, procId, args, returnType, false); }); @@ -2628,7 +2628,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockPgResolvedCall(name, id, args, returnType); }); @@ -2642,14 +2642,14 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockPgResolvedCall(procName, procId, args, returnType); }); AddCallable("PgCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); TRuntimeNode typeMod; if (node.ChildrenSize() >= 3) { typeMod = MkqlBuildExpr(*node.Child(2), ctx); @@ -2685,25 +2685,25 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("FromPg", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.FromPg(input, returnType); }); AddCallable("ToPg", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.ToPg(input, returnType); }); AddCallable("BlockFromPg", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockFromPg(input, returnType); }); AddCallable("BlockToPg", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockToPg(input, returnType); }); @@ -2726,7 +2726,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("PgTableContent", [](const TExprNode& node, TMkqlBuildContext& ctx) { - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgTableContent( node.Child(0)->Content(), node.Child(1)->Content(), @@ -2754,13 +2754,13 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockFunc(node.Child(0)->Content(), returnType, args); }); AddCallable("BlockBitCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto arg = MkqlBuildExpr(*node.Child(0), ctx); - auto targetType = BuildType(node, *node.Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + auto targetType = ctx.BuildType(node, *node.Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()); return ctx.ProgramBuilder.BlockBitCast(arg, targetType); }); @@ -2811,7 +2811,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { aggs.push_back(info); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockCombineAll(arg, filterColumn, aggs, returnType); }); @@ -2838,7 +2838,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { aggs.push_back(info); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockCombineHashed(arg, filterColumn, keys, aggs, returnType); }); @@ -2860,7 +2860,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { aggs.push_back(info); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockMergeFinalizeHashed(arg, keys, aggs, returnType); }); @@ -2891,7 +2891,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { } } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockMergeManyFinalizeHashed(arg, keys, aggs, streamIndex, streams, returnType); }); @@ -2913,7 +2913,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgArray(args, returnType); }); @@ -2921,7 +2921,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { const auto initCapacity = MkqlBuildExpr(*node.Child(1), ctx); const auto initSize = MkqlBuildExpr(*node.Child(2), ctx); const auto& args = GetArgumentsFrom<3U>(node, ctx); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.QueueCreate(initCapacity, initSize, args, returnType); }); @@ -2929,7 +2929,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { const auto resource = MkqlBuildExpr(node.Head(), ctx); const auto index = MkqlBuildExpr(*node.Child(1), ctx); const auto& args = GetArgumentsFrom<2U>(node, ctx); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.QueuePeek(resource, index, args, returnType); }); @@ -2938,13 +2938,13 @@ TMkqlCommonCallableCompiler::TShared::TShared() { const auto begin = MkqlBuildExpr(*node.Child(1), ctx); const auto end = MkqlBuildExpr(*node.Child(2), ctx); const auto& args = GetArgumentsFrom<3U>(node, ctx); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.QueueRange(resource, begin, end, args, returnType); }); AddCallable("Seq", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto& args = GetArgumentsFrom<0U>(node, ctx); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.Seq(args, returnType); }); diff --git a/yql/essentials/providers/common/mkql/yql_provider_mkql.h b/yql/essentials/providers/common/mkql/yql_provider_mkql.h index c65494cecb..1cdcfd8262 100644 --- a/yql/essentials/providers/common/mkql/yql_provider_mkql.h +++ b/yql/essentials/providers/common/mkql/yql_provider_mkql.h @@ -1,5 +1,7 @@ #pragma once +#include "yql_type_mkql.h" + #include <yql/essentials/ast/yql_expr.h> #include <yql/essentials/minikql/mkql_node.h> #include <yql/essentials/minikql/mkql_program_builder.h> @@ -17,16 +19,19 @@ struct TMkqlBuildContext { NKikimr::NMiniKQL::TProgramBuilder& ProgramBuilder; TExprContext& ExprCtx; TMemoizedNodesMap Memoization; + TMemoizedTypesMap TypeMemoizationHolder; + TMemoizedTypesMap* TypeMemoization; TMkqlBuildContext *const ParentCtx = nullptr; const size_t Level = 0ULL; const ui64 LambdaId = 0ULL; NKikimr::NMiniKQL::TRuntimeNode Parameters; - TMkqlBuildContext(const IMkqlCallableCompiler& mkqlCompiler, NKikimr::NMiniKQL::TProgramBuilder& builder, TExprContext& exprCtx, ui64 lambdaId = 0ULL, TArgumentsMap&& args = {}) + TMkqlBuildContext(const IMkqlCallableCompiler& mkqlCompiler, NKikimr::NMiniKQL::TProgramBuilder& builder, TExprContext& exprCtx, ui64 lambdaId = 0ULL, TArgumentsMap&& args = {}, TMemoizedTypesMap* typeMemoization = nullptr) : MkqlCompiler(mkqlCompiler) , ProgramBuilder(builder) , ExprCtx(exprCtx) , Memoization(std::move(args)) + , TypeMemoization(typeMemoization ? typeMemoization : &TypeMemoizationHolder) , LambdaId(lambdaId) {} @@ -35,11 +40,16 @@ struct TMkqlBuildContext { , ProgramBuilder(parent.ProgramBuilder) , ExprCtx(parent.ExprCtx) , Memoization(std::move(args)) + , TypeMemoization(parent.TypeMemoization) , ParentCtx(&parent) , Level(parent.Level + 1U) , LambdaId(lambdaId) , Parameters(parent.Parameters) {} + + NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation) { + return NYql::NCommon::BuildType(owner, annotation, ProgramBuilder, *TypeMemoization); + } }; class IMkqlCallableCompiler : public TThrRefBase { diff --git a/yql/essentials/providers/common/mkql/yql_type_mkql.cpp b/yql/essentials/providers/common/mkql/yql_type_mkql.cpp index 0aaf365468..cb001c847b 100644 --- a/yql/essentials/providers/common/mkql/yql_type_mkql.cpp +++ b/yql/essentials/providers/common/mkql/yql_type_mkql.cpp @@ -16,7 +16,9 @@ namespace NYql { namespace NCommon { -NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, IOutputStream& err) { +namespace { + +NKikimr::NMiniKQL::TType* BuildTypeImpl(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization, IOutputStream& err) { switch (annotation.GetKind()) { case ETypeAnnotationKind::Data: { auto data = annotation.Cast<TDataExprType>(); @@ -42,7 +44,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const members.reserve(structObj->GetItems().size()); for (auto& item : structObj->GetItems()) { - auto itemType = BuildType(*item->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*item->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -53,7 +55,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::List: { auto list = annotation.Cast<TListExprType>(); - auto itemType = BuildType(*list->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*list->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -62,7 +64,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Optional: { auto optional = annotation.Cast<TOptionalExprType>(); - auto itemType = BuildType(*optional->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*optional->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -74,7 +76,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const TVector<NKikimr::NMiniKQL::TType*> elements; elements.reserve(tuple->GetItems().size()); for (auto& child : tuple->GetItems()) { - elements.push_back(BuildType(*child, typeBuilder, err)); + elements.push_back(BuildType(*child, typeBuilder, memoization, err)); if (!elements.back()) { return nullptr; } @@ -97,8 +99,8 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Dict: { auto dictType = annotation.Cast<TDictExprType>(); - auto keyType = BuildType(*dictType->GetKeyType(), typeBuilder, err); - auto payloadType = BuildType(*dictType->GetPayloadType(), typeBuilder, err); + auto keyType = BuildType(*dictType->GetKeyType(), typeBuilder, memoization, err); + auto payloadType = BuildType(*dictType->GetPayloadType(), typeBuilder, memoization, err); if (!keyType || !payloadType) { return nullptr; } @@ -107,7 +109,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Type: { auto type = annotation.Cast<TTypeExprType>()->GetType(); - return BuildType(*type, typeBuilder, err); + return BuildType(*type, typeBuilder, memoization, err); } case ETypeAnnotationKind::Void: { @@ -120,10 +122,10 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Callable: { auto callable = annotation.Cast<TCallableExprType>(); - auto returnType = BuildType(*callable->GetReturnType(), typeBuilder, err); + auto returnType = BuildType(*callable->GetReturnType(), typeBuilder, memoization, err); NKikimr::NMiniKQL::TCallableTypeBuilder callableTypeBuilder(typeBuilder.GetTypeEnvironment(), "", returnType); for (auto& child : callable->GetArguments()) { - callableTypeBuilder.Add(BuildType(*child.Type, typeBuilder, err)); + callableTypeBuilder.Add(BuildType(*child.Type, typeBuilder, memoization, err)); if (!child.Name.empty()) { callableTypeBuilder.SetArgumentName(child.Name); } @@ -150,13 +152,13 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Tagged: { auto tagged = annotation.Cast<TTaggedExprType>(); - auto base = BuildType(*tagged->GetBaseType(), typeBuilder, err); + auto base = BuildType(*tagged->GetBaseType(), typeBuilder, memoization, err); return typeBuilder.NewTaggedType(base, tagged->GetTag()); } case ETypeAnnotationKind::Variant: { auto var = annotation.Cast<TVariantExprType>(); - auto underlyingType = BuildType(*var->GetUnderlyingType(), typeBuilder, err); + auto underlyingType = BuildType(*var->GetUnderlyingType(), typeBuilder, memoization, err); if (!underlyingType) { return nullptr; } @@ -165,7 +167,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Stream: { auto stream = annotation.Cast<TStreamExprType>(); - auto itemType = BuildType(*stream->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*stream->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -174,7 +176,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Flow: { auto flow = annotation.Cast<TFlowExprType>(); - auto itemType = BuildType(*flow->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*flow->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -201,7 +203,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Block: { auto block = annotation.Cast<TBlockExprType>(); - auto itemType = BuildType(*block->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*block->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -210,7 +212,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Scalar: { auto scalar = annotation.Cast<TScalarExprType>(); - auto itemType = BuildType(*scalar->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*scalar->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -221,21 +223,47 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::World: case ETypeAnnotationKind::Error: case ETypeAnnotationKind::LastType: + err << "Can't build mkql type for " << annotation.GetKind(); return nullptr; } } -NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder) { +} // unnamed + +NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization, IOutputStream& err) { + if (const auto knownType = memoization.find(&annotation); memoization.cend() != knownType) { + return knownType->second; + } + + return memoization.emplace(&annotation, BuildTypeImpl(annotation, typeBuilder, memoization, err)).first->second; +} + +NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization) { TStringStream err; - auto type = BuildType(annotation, typeBuilder, err); + auto type = BuildType(annotation, typeBuilder, memoization, err); if (!type) { ythrow TNodeException(pos) << err.Str(); } return type; } +NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization) { + return BuildType(owner.Pos(), annotation, typeBuilder, memoization); +} + +NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, IOutputStream& err) { + TMemoizedTypesMap memoization; + return BuildType(annotation, typeBuilder, memoization, err); +} + +NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder) { + TMemoizedTypesMap memoization; + return BuildType(pos, annotation, typeBuilder, memoization); +} + NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder) { - return BuildType(owner.Pos(), annotation, typeBuilder); + TMemoizedTypesMap memoization; + return BuildType(owner, annotation, typeBuilder, memoization); } const TTypeAnnotationNode* ConvertMiniKQLType(TPosition position, NKikimr::NMiniKQL::TType* type, TExprContext& ctx) { diff --git a/yql/essentials/providers/common/mkql/yql_type_mkql.h b/yql/essentials/providers/common/mkql/yql_type_mkql.h index 834b202848..3718e4be8f 100644 --- a/yql/essentials/providers/common/mkql/yql_type_mkql.h +++ b/yql/essentials/providers/common/mkql/yql_type_mkql.h @@ -4,6 +4,8 @@ #include <yql/essentials/ast/yql_pos_handle.h> #include <yql/essentials/minikql/mkql_type_builder.h> +#include <unordered_map> + namespace NKikimr { namespace NMiniKQL { @@ -17,6 +19,12 @@ class TProgramBuilder; namespace NYql { namespace NCommon { +using TMemoizedTypesMap = std::unordered_map<const TTypeAnnotationNode*, NKikimr::NMiniKQL::TType*>; + +NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization, IOutputStream& err); +NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization); +NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization); + NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, IOutputStream& err); NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder); NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder); diff --git a/yql/essentials/public/purecalc/common/compile_mkql.cpp b/yql/essentials/public/purecalc/common/compile_mkql.cpp index 743447ada9..7bb95e2075 100644 --- a/yql/essentials/public/purecalc/common/compile_mkql.cpp +++ b/yql/essentials/public/purecalc/common/compile_mkql.cpp @@ -18,7 +18,7 @@ NCommon::IMkqlCallableCompiler::TCompiler MakeSelfCallableCompiler() { MKQL_ENSURE(argument->IsAtom(), "Self argument must be atom"); ui32 inputIndex = 0; MKQL_ENSURE(TryFromString(argument->Content(), inputIndex), "Self argument must be UI32"); - auto type = NCommon::BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto type = ctx.BuildType(node, *node.GetTypeAnn()); NKikimr::NMiniKQL::TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), type); call.Add(ctx.ProgramBuilder.NewDataLiteral<ui32>(inputIndex)); return NKikimr::NMiniKQL::TRuntimeNode(call.Build(), false); @@ -93,7 +93,7 @@ NCommon::IMkqlCallableCompiler::TCompiler MakeFolderPathCallableCompiler(const T } NKikimr::NMiniKQL::TRuntimeNode CompileMkql(const TExprNode::TPtr& exprRoot, TExprContext& exprCtx, - const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, const NKikimr::NMiniKQL::TTypeEnvironment& env, const TUserDataTable& userData) + const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, const NKikimr::NMiniKQL::TTypeEnvironment& env, const TUserDataTable& userData, NCommon::TMemoizedTypesMap* typeMemoization) { NCommon::TMkqlCommonCallableCompiler compiler; @@ -106,7 +106,7 @@ NKikimr::NMiniKQL::TRuntimeNode CompileMkql(const TExprNode::TPtr& exprRoot, TEx // Prepare build context NKikimr::NMiniKQL::TProgramBuilder pgmBuilder(env, funcRegistry); - NCommon::TMkqlBuildContext buildCtx(compiler, pgmBuilder, exprCtx); + NCommon::TMkqlBuildContext buildCtx(compiler, pgmBuilder, exprCtx, /*lambdaId*/0, /*args*/{}, typeMemoization); // Build the root MKQL node diff --git a/yql/essentials/public/purecalc/common/compile_mkql.h b/yql/essentials/public/purecalc/common/compile_mkql.h index 0b6c16aef5..caba3baa2b 100644 --- a/yql/essentials/public/purecalc/common/compile_mkql.h +++ b/yql/essentials/public/purecalc/common/compile_mkql.h @@ -1,17 +1,20 @@ #pragma once +#include <yql/essentials/providers/common/mkql/yql_type_mkql.h> #include <yql/essentials/public/purecalc/common/interface.h> #include <yql/essentials/minikql/mkql_node.h> #include <yql/essentials/ast/yql_expr.h> #include <yql/essentials/core/yql_user_data.h> -namespace NYql { - namespace NPureCalc { - /** - * Compile expr to mkql byte-code - */ - NKikimr::NMiniKQL::TRuntimeNode CompileMkql(const TExprNode::TPtr& exprRoot, TExprContext& exprCtx, - const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, const NKikimr::NMiniKQL::TTypeEnvironment& env, const TUserDataTable& userData); - } +namespace NYql::NPureCalc { + + /** + * Compile expr to mkql byte-code + */ + + NKikimr::NMiniKQL::TRuntimeNode CompileMkql(const TExprNode::TPtr& exprRoot, TExprContext& exprCtx, + const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, const NKikimr::NMiniKQL::TTypeEnvironment& env, const TUserDataTable& userData, + NCommon::TMemoizedTypesMap* typeMemoization = nullptr); + } diff --git a/yql/essentials/public/purecalc/common/worker.cpp b/yql/essentials/public/purecalc/common/worker.cpp index f670458c72..58cbf23c92 100644 --- a/yql/essentials/public/purecalc/common/worker.cpp +++ b/yql/essentials/public/purecalc/common/worker.cpp @@ -57,10 +57,10 @@ TWorkerGraph::TWorkerGraph( , NativeYtTypeFlags_(nativeYtTypeFlags) { // Build the root MKQL node - + NCommon::TMemoizedTypesMap typeMemoization; NKikimr::NMiniKQL::TRuntimeNode rootNode; if (exprRoot) { - rootNode = CompileMkql(exprRoot, exprCtx, FuncRegistry_, Env_, userData); + rootNode = CompileMkql(exprRoot, exprCtx, FuncRegistry_, Env_, userData, &typeMemoization); } else { rootNode = NKikimr::NMiniKQL::DeserializeRuntimeNode(serializedProgram, Env_); } @@ -79,12 +79,12 @@ TWorkerGraph::TWorkerGraph( NKikimr::NMiniKQL::TProgramBuilder pgmBuilder(Env_, FuncRegistry_); for (ui32 i = 0; i < inputsCount; ++i) { - const auto* type = static_cast<NKikimr::NMiniKQL::TStructType*>(NCommon::BuildType(TPositionHandle(), *inputTypes[i], pgmBuilder)); + const auto* type = static_cast<NKikimr::NMiniKQL::TStructType*>(NCommon::BuildType(TPositionHandle(), *inputTypes[i], pgmBuilder, typeMemoization)); const auto* originalType = type; - const auto* rawType = static_cast<NKikimr::NMiniKQL::TStructType*>(NCommon::BuildType(TPositionHandle(), *rawInputTypes[i], pgmBuilder)); + const auto* rawType = static_cast<NKikimr::NMiniKQL::TStructType*>(NCommon::BuildType(TPositionHandle(), *rawInputTypes[i], pgmBuilder, typeMemoization)); if (inputTypes[i] != originalInputTypes[i]) { YQL_ENSURE(inputTypes[i]->GetSize() >= originalInputTypes[i]->GetSize()); - originalType = static_cast<NKikimr::NMiniKQL::TStructType*>(NCommon::BuildType(TPositionHandle(), *originalInputTypes[i], pgmBuilder)); + originalType = static_cast<NKikimr::NMiniKQL::TStructType*>(NCommon::BuildType(TPositionHandle(), *originalInputTypes[i], pgmBuilder, typeMemoization)); } InputTypes_.push_back(type); @@ -93,10 +93,10 @@ TWorkerGraph::TWorkerGraph( } if (outputType) { - OutputType_ = NCommon::BuildType(TPositionHandle(), *outputType, pgmBuilder); + OutputType_ = NCommon::BuildType(TPositionHandle(), *outputType, pgmBuilder, typeMemoization); } if (rawOutputType) { - RawOutputType_ = NCommon::BuildType(TPositionHandle(), *rawOutputType, pgmBuilder); + RawOutputType_ = NCommon::BuildType(TPositionHandle(), *rawOutputType, pgmBuilder, typeMemoization); } if (!exprRoot) { diff --git a/yt/yql/providers/yt/gateway/file/yql_yt_file_mkql_compiler.cpp b/yt/yql/providers/yt/gateway/file/yql_yt_file_mkql_compiler.cpp index b4ff5c5ca4..f6535f7909 100644 --- a/yt/yql/providers/yt/gateway/file/yql_yt_file_mkql_compiler.cpp +++ b/yt/yql/providers/yt/gateway/file/yql_yt_file_mkql_compiler.cpp @@ -373,7 +373,7 @@ TRuntimeNode ToList(TRuntimeNode list, NCommon::TMkqlBuildContext& ctx) { TType* BuildInputType(TYtSectionList input, NCommon::TMkqlBuildContext& ctx) { TVector<TType*> items; for (auto section: input) { - items.push_back(NCommon::BuildType(input.Ref(), *section.Ref().GetTypeAnn()->Cast<TListExprType>()->GetItemType(), ctx.ProgramBuilder)); + items.push_back(ctx.BuildType(input.Ref(), *section.Ref().GetTypeAnn()->Cast<TListExprType>()->GetItemType())); } return items.size() == 1 ? items.front() @@ -383,7 +383,7 @@ TType* BuildInputType(TYtSectionList input, NCommon::TMkqlBuildContext& ctx) { TType* BuildOutputType(TYtOutSection output, NCommon::TMkqlBuildContext& ctx) { TVector<TType*> items; for (auto table: output) { - items.push_back(NCommon::BuildType(output.Ref(), *table.Ref().GetTypeAnn()->Cast<TListExprType>()->GetItemType(), ctx.ProgramBuilder)); + items.push_back(ctx.BuildType(output.Ref(), *table.Ref().GetTypeAnn()->Cast<TListExprType>()->GetItemType())); } return items.size() == 1 ? items.front() diff --git a/yt/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp b/yt/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp index 902af3f9b9..3ad3f683d9 100644 --- a/yt/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp +++ b/yt/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp @@ -153,7 +153,7 @@ TRuntimeNode BuildTableContentCall(TStringBuf callName, } secType = ctx.ProgramBuilder.NewStructType(secType, key, - NCommon::BuildType(section.Ref(), *keyType, ctx.ProgramBuilder)); + ctx.BuildType(section.Ref(), *keyType)); rebuildType = true; } @@ -455,7 +455,7 @@ void RegisterYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler) { [](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) { TYtTableContent tableContent(&node); if (node.GetConstraint<TEmptyConstraintNode>()) { - const auto itemType = NCommon::BuildType(node, GetSeqItemType(*node.GetTypeAnn()), ctx.ProgramBuilder); + const auto itemType = ctx.BuildType(node, GetSeqItemType(*node.GetTypeAnn())); return ctx.ProgramBuilder.NewEmptyList(itemType); } TMaybe<ui64> itemsCount; @@ -469,12 +469,12 @@ void RegisterYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler) { if (auto maybeRead = tableContent.Input().Maybe<TYtReadTable>()) { auto read = maybeRead.Cast(); return BuildTableContentCall(name, - NCommon::BuildType(node, *node.GetTypeAnn()->Cast<TListExprType>()->GetItemType(), ctx.ProgramBuilder), + ctx.BuildType(node, *node.GetTypeAnn()->Cast<TListExprType>()->GetItemType()), read.DataSource().Cluster().Value(), read.Input().Ref(), itemsCount, ctx, true); } else { auto output = tableContent.Input().Cast<TYtOutput>(); return BuildTableContentCall(name, - NCommon::BuildType(node, *node.GetTypeAnn()->Cast<TListExprType>()->GetItemType(), ctx.ProgramBuilder), + ctx.BuildType(node, *node.GetTypeAnn()->Cast<TListExprType>()->GetItemType()), GetOutputOp(output).DataSink().Cluster().Value(), output.Ref(), itemsCount, ctx, true); } }); @@ -498,12 +498,12 @@ void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, con if (const auto& wrapper = TDqReadBlockWideWrap(&node); wrapper.Input().Maybe<TYtReadTable>().IsValid()) { const auto ytRead = wrapper.Input().Cast<TYtReadTable>(); const auto readType = ytRead.Ref().GetTypeAnn()->Cast<TTupleExprType>()->GetItems().back(); - const auto inputItemType = NCommon::BuildType(wrapper.Input().Ref(), GetSeqItemType(*readType), ctx.ProgramBuilder); + const auto inputItemType = ctx.BuildType(wrapper.Input().Ref(), GetSeqItemType(*readType)); const auto cluster = ytRead.DataSource().Cluster().StringValue(); const bool useRPCReaderDefault = DEFAULT_USE_RPC_READER_IN_DQ || state->Types->BlockEngineMode != EBlockEngineMode::Disable; size_t inflight = state->Configuration->UseRPCReaderInDQ.Get(cluster).GetOrElse(useRPCReaderDefault) ? state->Configuration->DQRPCReaderInflight.Get(cluster).GetOrElse(DEFAULT_RPC_READER_INFLIGHT) : 0; size_t timeout = state->Configuration->DQRPCReaderTimeout.Get(cluster).GetOrElse(DEFAULT_RPC_READER_TIMEOUT).MilliSeconds(); - const auto outputType = NCommon::BuildType(wrapper.Ref(), *wrapper.Ref().GetTypeAnn(), ctx.ProgramBuilder); + const auto outputType = ctx.BuildType(wrapper.Ref(), *wrapper.Ref().GetTypeAnn()); TString tokenName; if (auto secureParams = wrapper.Token()) { tokenName = secureParams.Cast().Name().StringValue(); @@ -529,11 +529,11 @@ void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, con if (const auto& wrapper = TDqReadWideWrap(&node); wrapper.Input().Maybe<TYtReadTable>().IsValid()) { const auto ytRead = wrapper.Input().Cast<TYtReadTable>(); const auto readType = ytRead.Ref().GetTypeAnn()->Cast<TTupleExprType>()->GetItems().back(); - const auto inputItemType = NCommon::BuildType(wrapper.Input().Ref(), GetSeqItemType(*readType), ctx.ProgramBuilder); + const auto inputItemType = ctx.BuildType(wrapper.Input().Ref(), GetSeqItemType(*readType)); const auto cluster = ytRead.DataSource().Cluster().StringValue(); size_t isRPC = state->Configuration->UseRPCReaderInDQ.Get(cluster).GetOrElse(DEFAULT_USE_RPC_READER_IN_DQ) ? state->Configuration->DQRPCReaderInflight.Get(cluster).GetOrElse(DEFAULT_RPC_READER_INFLIGHT) : 0; size_t timeout = state->Configuration->DQRPCReaderTimeout.Get(cluster).GetOrElse(DEFAULT_RPC_READER_TIMEOUT).MilliSeconds(); - const auto outputType = NCommon::BuildType(wrapper.Ref(), *wrapper.Ref().GetTypeAnn(), ctx.ProgramBuilder); + const auto outputType = ctx.BuildType(wrapper.Ref(), *wrapper.Ref().GetTypeAnn()); TString tokenName; if (auto secureParams = wrapper.Token()) { tokenName = secureParams.Cast().Name().StringValue(); @@ -556,7 +556,7 @@ void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, con compiler.AddCallable(TYtDqWideWrite::CallableName(), [](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) { const auto write = TYtDqWideWrite(&node); - const auto outType = NCommon::BuildType(write.Ref(), *write.Ref().GetTypeAnn(), ctx.ProgramBuilder); + const auto outType = ctx.BuildType(write.Ref(), *write.Ref().GetTypeAnn()); const auto arg = MkqlBuildExpr(write.Input().Ref(), ctx); TString server{GetSetting(write.Settings().Ref(), "server")->Child(1)->Content()}; |