diff options
author | udovichenko-r <udovichenko-r@yandex-team.com> | 2024-12-05 16:42:04 +0300 |
---|---|---|
committer | udovichenko-r <udovichenko-r@yandex-team.com> | 2024-12-05 17:07:33 +0300 |
commit | ac81b4ff5a8fd2fd4c1fafa78ac07defb71478dc (patch) | |
tree | 159a2723ccc1de82b274fa36603ac6cb723bfad3 /yql/essentials/providers/common/mkql | |
parent | 8aee4675a8b41c39f833026f579f9828deee361a (diff) | |
download | ydb-ac81b4ff5a8fd2fd4c1fafa78ac07defb71478dc.tar.gz |
[mkql] Use type memoization in compiler
YQL-19355
commit_hash:ff1684e65f8ec72be0a1407be5d31bb3fb0465a7
Diffstat (limited to 'yql/essentials/providers/common/mkql')
4 files changed, 130 insertions, 84 deletions
diff --git a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp index d8ad662c9b..fb2a2a29c1 100644 --- a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp +++ b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp @@ -170,7 +170,7 @@ TMkqlBuildContext* GetContextForMemoize(const TExprNode& node, TMkqlBuildContext const TRuntimeNode& CheckTypeAndMemoize(const TExprNode& node, TMkqlBuildContext& ctx, const TRuntimeNode& runtime) { if (node.GetTypeAnn()) { TNullOutput null; - if (const auto type = BuildType(*node.GetTypeAnn(), ctx.ProgramBuilder, null)) { + if (const auto type = BuildType(*node.GetTypeAnn(), ctx.ProgramBuilder, *ctx.TypeMemoization, null)) { if (!type->IsSameType(*runtime.GetStaticType())) { ythrow TNodeException(node) << "Expected: " << *type << " type, but got: " << *runtime.GetStaticType() << "."; } @@ -661,7 +661,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable({"RoundUp", "RoundDown"}, [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto dstType = BuildType(node.Tail(), *node.Tail().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto dstType = ctx.BuildType(node.Tail(), *node.Tail().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); return ctx.ProgramBuilder.Round(node.Content(), arg, dstType); }); @@ -1020,7 +1020,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("CurrentActorId", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); return TRuntimeNode(call.Build(), false); }); @@ -1337,7 +1337,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("Struct", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto structType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto structType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto verifiedStructType = AS_TYPE(TStructType, structType); std::vector<std::pair<std::string_view, TRuntimeNode>> members; members.reserve(verifiedStructType->GetMembersCount()); @@ -1356,7 +1356,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("List", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto listType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto listType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto itemType = AS_TYPE(TListType, listType)->GetItemType(); const auto& items = GetArgumentsFrom<1U>(node, ctx); return ctx.ProgramBuilder.NewList(itemType, items); @@ -1364,48 +1364,48 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("FromString", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.FromString(arg, type); }); AddCallable("StrictFromString", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.StrictFromString(arg, type); }); AddCallable("FromBytes", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node, *node.GetTypeAnn()->Cast<TOptionalExprType>()->GetItemType(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()->Cast<TOptionalExprType>()->GetItemType()); return ctx.ProgramBuilder.FromBytes(arg, type); }); AddCallable("Convert", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.Convert(arg, type); }); AddCallable("ToIntegral", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.ToIntegral(arg, type); }); AddCallable("UnsafeTimestampCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.Convert(arg, type); }); AddCallable("SafeCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); return ctx.ProgramBuilder.Cast(arg, type); }); AddCallable("Default", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.Default(type); }); @@ -1427,7 +1427,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("EmptyFrom", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node.Head(), *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.GetTypeAnn()); switch (node.GetTypeAnn()->GetKind()) { case ETypeAnnotationKind::Flow: case ETypeAnnotationKind::Stream: @@ -1444,18 +1444,18 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("Nothing", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto optType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto optType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); return ctx.ProgramBuilder.NewEmptyOptional(optType); }); AddCallable("Unpickle", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto serialized = MkqlBuildExpr(node.Tail(), ctx); return ctx.ProgramBuilder.Unpickle(type, serialized); }); AddCallable("Optional", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto optType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto optType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto arg = MkqlBuildExpr(node.Tail(), ctx); return ctx.ProgramBuilder.NewOptional(optType, arg); }); @@ -1467,7 +1467,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("EmptyIterator", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto streamType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto streamType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); return ctx.ProgramBuilder.EmptyIterator(streamType); }); @@ -1495,11 +1495,11 @@ TMkqlCommonCallableCompiler::TShared::TShared() { } offset += outputStreams; - input.InputType = BuildType(lambdaArg, *lambdaArg.GetTypeAnn(), ctx.ProgramBuilder); + input.InputType = ctx.BuildType(lambdaArg, *lambdaArg.GetTypeAnn()); inputs.emplace_back(input); } - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.Switch(stream, inputs, [&](ui32 index, TRuntimeNode item) -> TRuntimeNode { return MkqlBuildLambda(*node.Child(2 + 2 * index + 1), ctx, {item}); }, memoryLimitBytes, returnType); @@ -1664,7 +1664,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { ythrow TNodeException(node) << "Wrong MapJoinCore input item type: " << inputItemType; } - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.MapJoinCore(list, dict, joinKind, leftKeyColumns, leftRenames, rightRenames, returnType); }); @@ -1715,7 +1715,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { anyJoinSettings = EAnyJoinSettings::Left == anyJoinSettings ? EAnyJoinSettings::Both : EAnyJoinSettings::Right; }); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return selfJoin ? ctx.ProgramBuilder.GraceSelfJoin(flowLeft, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings) @@ -1800,7 +1800,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { } } - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.CommonJoinCore(list, joinKind, leftColumns, rightColumns, requiredColumns, keyColumns, memLimit, sortedTableOrder, anyJoinSettings, tableIndexFieldPos, returnType); }); @@ -2084,7 +2084,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("Dict", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto listType = BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + const auto listType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType()); const auto dictType = AS_TYPE(TDictType, listType); std::vector<std::pair<TRuntimeNode, TRuntimeNode>> items; @@ -2099,7 +2099,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("Variant", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto varType = node.Child(2)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TVariantExprType>(); - const auto type = BuildType(*node.Child(2), *varType, ctx.ProgramBuilder); + const auto type = ctx.BuildType(*node.Child(2), *varType); const auto item = MkqlBuildExpr(node.Head(), ctx); return varType->GetUnderlyingType()->GetKind() == ETypeAnnotationKind::Tuple ? @@ -2168,17 +2168,17 @@ TMkqlCommonCallableCompiler::TShared::TShared() { "EmptyListType", "EmptyDictType"}, [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return TRuntimeNode(type, true); }); AddCallable("ParseType", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return TRuntimeNode(type, true); }); AddCallable("TypeOf", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return TRuntimeNode(type, true); }); @@ -2203,14 +2203,14 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("SourceOf", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto type = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto type = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.SourceOf(type); }); AddCallable("TypeHandle", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto type = node.Head().GetTypeAnn()->Cast<TTypeExprType>()->GetType(); const auto yson = WriteTypeToYson(type); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::Yson>(yson)); return TRuntimeNode(call.Build(), false); @@ -2220,7 +2220,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { const auto type = node.Head().GetTypeAnn(); const auto yson = WriteTypeToYson(type); const auto& args = GetAllArguments(node, ctx); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(pos.File)); @@ -2244,7 +2244,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { "SerializeCode", }, [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto& args = GetAllArguments(node, ctx); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); for (auto arg : args) { call.Add(arg); @@ -2291,7 +2291,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { "FuncCode", }, [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto& args = GetAllArguments(node, ctx); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(pos.File)); @@ -2307,7 +2307,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("LambdaCode", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto lambda = node.Child(node.ChildrenSize() - 1); - const auto retType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto retType = ctx.BuildType(node, *node.GetTypeAnn()); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), retType); call.Add(ctx.ProgramBuilder.NewDataLiteral<NUdf::EDataSlot::String>(pos.File)); @@ -2325,7 +2325,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { TMkqlBuildContext::TArgumentsMap innerArguments; innerArguments.reserve(lambda->Head().ChildrenSize()); lambda->Head().ForEachChild([&](const TExprNode& argNode) { - const auto argType = BuildType(argNode, *argNode.GetTypeAnn(), ctx.ProgramBuilder); + const auto argType = ctx.BuildType(argNode, *argNode.GetTypeAnn()); const auto arg = ctx.ProgramBuilder.Arg(argType); innerArguments.emplace(&argNode, arg); call.Add(arg); @@ -2380,7 +2380,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable({ "AsTagged","Untag" }, [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(node.Head(), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.Nop(input, returnType); }); @@ -2453,9 +2453,9 @@ TMkqlCommonCallableCompiler::TShared::TShared() { YQL_ENSURE(node.ChildrenSize() == 8); std::string_view function = node.Head().Content(); const auto runConfig = MkqlBuildExpr(*node.Child(1), ctx); - const auto userType = BuildType(*node.Child(2), *node.Child(2)->GetTypeAnn(), ctx.ProgramBuilder); + const auto userType = ctx.BuildType(*node.Child(2), *node.Child(2)->GetTypeAnn()); const auto typeConfig = node.Child(3)->Content(); - const auto callableType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto callableType = ctx.BuildType(node, *node.GetTypeAnn()); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); return ctx.ProgramBuilder.TypedUdf(function, callableType, runConfig, userType, typeConfig, pos.File, pos.Row, pos.Column); @@ -2471,7 +2471,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { std::string_view funcName = node.Child(1)->Content(); const auto typeNode = node.Child(2); - const auto funcType = BuildType(*typeNode, *typeNode->GetTypeAnn(), ctx.ProgramBuilder); + const auto funcType = ctx.BuildType(*typeNode, *typeNode->GetTypeAnn()); const auto script = MkqlBuildExpr(*node.Child(3), ctx); const auto pos = ctx.ExprCtx.GetPosition(node.Pos()); return ctx.ProgramBuilder.ScriptUdf(node.Head().Content(), funcName, funcType, script, @@ -2534,7 +2534,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { continue; } - auto mkqlType = BuildType(node, *callableType->GetArguments()[i].Type, ctx.ProgramBuilder); + auto mkqlType = ctx.BuildType(node, *callableType->GetArguments()[i].Type); arg = ctx.ProgramBuilder.NewEmptyOptional(mkqlType); } @@ -2546,7 +2546,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("Callable", [](const TExprNode& node, TMkqlBuildContext& ctx) { - const auto callableType = BuildType(node.Head(), *node.Head().GetTypeAnn(), ctx.ProgramBuilder); + const auto callableType = ctx.BuildType(node.Head(), *node.Head().GetTypeAnn()); return ctx.ProgramBuilder.Callable(callableType, [&](const TArrayRef<const TRuntimeNode>& args) { const auto& lambda = node.Tail(); TMkqlBuildContext::TArgumentsMap innerArguments; @@ -2560,7 +2560,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("PgConst", [](const TExprNode& node, TMkqlBuildContext& ctx) { - auto type = AS_TYPE(TPgType, BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder)); + auto type = AS_TYPE(TPgType, ctx.BuildType(node, *node.GetTypeAnn())); TRuntimeNode typeMod; if (node.ChildrenSize() >= 3) { typeMod = MkqlBuildExpr(*node.Child(2), ctx); @@ -2581,7 +2581,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("PgInternal0", [](const TExprNode& node, TMkqlBuildContext& ctx) { - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgInternal0(returnType); }); @@ -2601,7 +2601,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { } } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgResolvedCall(node.IsCallable("PgResolvedCallCtx"), name, id, args, returnType, rangeFunction); }); @@ -2615,7 +2615,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgResolvedCall(false, procName, procId, args, returnType, false); }); @@ -2628,7 +2628,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockPgResolvedCall(name, id, args, returnType); }); @@ -2642,14 +2642,14 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockPgResolvedCall(procName, procId, args, returnType); }); AddCallable("PgCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); TRuntimeNode typeMod; if (node.ChildrenSize() >= 3) { typeMod = MkqlBuildExpr(*node.Child(2), ctx); @@ -2685,25 +2685,25 @@ TMkqlCommonCallableCompiler::TShared::TShared() { AddCallable("FromPg", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.FromPg(input, returnType); }); AddCallable("ToPg", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.ToPg(input, returnType); }); AddCallable("BlockFromPg", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockFromPg(input, returnType); }); AddCallable("BlockToPg", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(0), ctx); - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockToPg(input, returnType); }); @@ -2726,7 +2726,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); AddCallable("PgTableContent", [](const TExprNode& node, TMkqlBuildContext& ctx) { - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgTableContent( node.Child(0)->Content(), node.Child(1)->Content(), @@ -2754,13 +2754,13 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockFunc(node.Child(0)->Content(), returnType, args); }); AddCallable("BlockBitCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto arg = MkqlBuildExpr(*node.Child(0), ctx); - auto targetType = BuildType(node, *node.Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType(), ctx.ProgramBuilder); + auto targetType = ctx.BuildType(node, *node.Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()); return ctx.ProgramBuilder.BlockBitCast(arg, targetType); }); @@ -2811,7 +2811,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { aggs.push_back(info); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockCombineAll(arg, filterColumn, aggs, returnType); }); @@ -2838,7 +2838,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { aggs.push_back(info); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockCombineHashed(arg, filterColumn, keys, aggs, returnType); }); @@ -2860,7 +2860,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { aggs.push_back(info); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockMergeFinalizeHashed(arg, keys, aggs, returnType); }); @@ -2891,7 +2891,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { } } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.BlockMergeManyFinalizeHashed(arg, keys, aggs, streamIndex, streams, returnType); }); @@ -2913,7 +2913,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } - auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.PgArray(args, returnType); }); @@ -2921,7 +2921,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { const auto initCapacity = MkqlBuildExpr(*node.Child(1), ctx); const auto initSize = MkqlBuildExpr(*node.Child(2), ctx); const auto& args = GetArgumentsFrom<3U>(node, ctx); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.QueueCreate(initCapacity, initSize, args, returnType); }); @@ -2929,7 +2929,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { const auto resource = MkqlBuildExpr(node.Head(), ctx); const auto index = MkqlBuildExpr(*node.Child(1), ctx); const auto& args = GetArgumentsFrom<2U>(node, ctx); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.QueuePeek(resource, index, args, returnType); }); @@ -2938,13 +2938,13 @@ TMkqlCommonCallableCompiler::TShared::TShared() { const auto begin = MkqlBuildExpr(*node.Child(1), ctx); const auto end = MkqlBuildExpr(*node.Child(2), ctx); const auto& args = GetArgumentsFrom<3U>(node, ctx); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.QueueRange(resource, begin, end, args, returnType); }); AddCallable("Seq", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto& args = GetArgumentsFrom<0U>(node, ctx); - const auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + const auto returnType = ctx.BuildType(node, *node.GetTypeAnn()); return ctx.ProgramBuilder.Seq(args, returnType); }); diff --git a/yql/essentials/providers/common/mkql/yql_provider_mkql.h b/yql/essentials/providers/common/mkql/yql_provider_mkql.h index c65494cecb..1cdcfd8262 100644 --- a/yql/essentials/providers/common/mkql/yql_provider_mkql.h +++ b/yql/essentials/providers/common/mkql/yql_provider_mkql.h @@ -1,5 +1,7 @@ #pragma once +#include "yql_type_mkql.h" + #include <yql/essentials/ast/yql_expr.h> #include <yql/essentials/minikql/mkql_node.h> #include <yql/essentials/minikql/mkql_program_builder.h> @@ -17,16 +19,19 @@ struct TMkqlBuildContext { NKikimr::NMiniKQL::TProgramBuilder& ProgramBuilder; TExprContext& ExprCtx; TMemoizedNodesMap Memoization; + TMemoizedTypesMap TypeMemoizationHolder; + TMemoizedTypesMap* TypeMemoization; TMkqlBuildContext *const ParentCtx = nullptr; const size_t Level = 0ULL; const ui64 LambdaId = 0ULL; NKikimr::NMiniKQL::TRuntimeNode Parameters; - TMkqlBuildContext(const IMkqlCallableCompiler& mkqlCompiler, NKikimr::NMiniKQL::TProgramBuilder& builder, TExprContext& exprCtx, ui64 lambdaId = 0ULL, TArgumentsMap&& args = {}) + TMkqlBuildContext(const IMkqlCallableCompiler& mkqlCompiler, NKikimr::NMiniKQL::TProgramBuilder& builder, TExprContext& exprCtx, ui64 lambdaId = 0ULL, TArgumentsMap&& args = {}, TMemoizedTypesMap* typeMemoization = nullptr) : MkqlCompiler(mkqlCompiler) , ProgramBuilder(builder) , ExprCtx(exprCtx) , Memoization(std::move(args)) + , TypeMemoization(typeMemoization ? typeMemoization : &TypeMemoizationHolder) , LambdaId(lambdaId) {} @@ -35,11 +40,16 @@ struct TMkqlBuildContext { , ProgramBuilder(parent.ProgramBuilder) , ExprCtx(parent.ExprCtx) , Memoization(std::move(args)) + , TypeMemoization(parent.TypeMemoization) , ParentCtx(&parent) , Level(parent.Level + 1U) , LambdaId(lambdaId) , Parameters(parent.Parameters) {} + + NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation) { + return NYql::NCommon::BuildType(owner, annotation, ProgramBuilder, *TypeMemoization); + } }; class IMkqlCallableCompiler : public TThrRefBase { diff --git a/yql/essentials/providers/common/mkql/yql_type_mkql.cpp b/yql/essentials/providers/common/mkql/yql_type_mkql.cpp index 0aaf365468..cb001c847b 100644 --- a/yql/essentials/providers/common/mkql/yql_type_mkql.cpp +++ b/yql/essentials/providers/common/mkql/yql_type_mkql.cpp @@ -16,7 +16,9 @@ namespace NYql { namespace NCommon { -NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, IOutputStream& err) { +namespace { + +NKikimr::NMiniKQL::TType* BuildTypeImpl(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization, IOutputStream& err) { switch (annotation.GetKind()) { case ETypeAnnotationKind::Data: { auto data = annotation.Cast<TDataExprType>(); @@ -42,7 +44,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const members.reserve(structObj->GetItems().size()); for (auto& item : structObj->GetItems()) { - auto itemType = BuildType(*item->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*item->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -53,7 +55,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::List: { auto list = annotation.Cast<TListExprType>(); - auto itemType = BuildType(*list->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*list->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -62,7 +64,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Optional: { auto optional = annotation.Cast<TOptionalExprType>(); - auto itemType = BuildType(*optional->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*optional->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -74,7 +76,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const TVector<NKikimr::NMiniKQL::TType*> elements; elements.reserve(tuple->GetItems().size()); for (auto& child : tuple->GetItems()) { - elements.push_back(BuildType(*child, typeBuilder, err)); + elements.push_back(BuildType(*child, typeBuilder, memoization, err)); if (!elements.back()) { return nullptr; } @@ -97,8 +99,8 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Dict: { auto dictType = annotation.Cast<TDictExprType>(); - auto keyType = BuildType(*dictType->GetKeyType(), typeBuilder, err); - auto payloadType = BuildType(*dictType->GetPayloadType(), typeBuilder, err); + auto keyType = BuildType(*dictType->GetKeyType(), typeBuilder, memoization, err); + auto payloadType = BuildType(*dictType->GetPayloadType(), typeBuilder, memoization, err); if (!keyType || !payloadType) { return nullptr; } @@ -107,7 +109,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Type: { auto type = annotation.Cast<TTypeExprType>()->GetType(); - return BuildType(*type, typeBuilder, err); + return BuildType(*type, typeBuilder, memoization, err); } case ETypeAnnotationKind::Void: { @@ -120,10 +122,10 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Callable: { auto callable = annotation.Cast<TCallableExprType>(); - auto returnType = BuildType(*callable->GetReturnType(), typeBuilder, err); + auto returnType = BuildType(*callable->GetReturnType(), typeBuilder, memoization, err); NKikimr::NMiniKQL::TCallableTypeBuilder callableTypeBuilder(typeBuilder.GetTypeEnvironment(), "", returnType); for (auto& child : callable->GetArguments()) { - callableTypeBuilder.Add(BuildType(*child.Type, typeBuilder, err)); + callableTypeBuilder.Add(BuildType(*child.Type, typeBuilder, memoization, err)); if (!child.Name.empty()) { callableTypeBuilder.SetArgumentName(child.Name); } @@ -150,13 +152,13 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Tagged: { auto tagged = annotation.Cast<TTaggedExprType>(); - auto base = BuildType(*tagged->GetBaseType(), typeBuilder, err); + auto base = BuildType(*tagged->GetBaseType(), typeBuilder, memoization, err); return typeBuilder.NewTaggedType(base, tagged->GetTag()); } case ETypeAnnotationKind::Variant: { auto var = annotation.Cast<TVariantExprType>(); - auto underlyingType = BuildType(*var->GetUnderlyingType(), typeBuilder, err); + auto underlyingType = BuildType(*var->GetUnderlyingType(), typeBuilder, memoization, err); if (!underlyingType) { return nullptr; } @@ -165,7 +167,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Stream: { auto stream = annotation.Cast<TStreamExprType>(); - auto itemType = BuildType(*stream->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*stream->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -174,7 +176,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Flow: { auto flow = annotation.Cast<TFlowExprType>(); - auto itemType = BuildType(*flow->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*flow->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -201,7 +203,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Block: { auto block = annotation.Cast<TBlockExprType>(); - auto itemType = BuildType(*block->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*block->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -210,7 +212,7 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::Scalar: { auto scalar = annotation.Cast<TScalarExprType>(); - auto itemType = BuildType(*scalar->GetItemType(), typeBuilder, err); + auto itemType = BuildType(*scalar->GetItemType(), typeBuilder, memoization, err); if (!itemType) { return nullptr; } @@ -221,21 +223,47 @@ NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const case ETypeAnnotationKind::World: case ETypeAnnotationKind::Error: case ETypeAnnotationKind::LastType: + err << "Can't build mkql type for " << annotation.GetKind(); return nullptr; } } -NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder) { +} // unnamed + +NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization, IOutputStream& err) { + if (const auto knownType = memoization.find(&annotation); memoization.cend() != knownType) { + return knownType->second; + } + + return memoization.emplace(&annotation, BuildTypeImpl(annotation, typeBuilder, memoization, err)).first->second; +} + +NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization) { TStringStream err; - auto type = BuildType(annotation, typeBuilder, err); + auto type = BuildType(annotation, typeBuilder, memoization, err); if (!type) { ythrow TNodeException(pos) << err.Str(); } return type; } +NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization) { + return BuildType(owner.Pos(), annotation, typeBuilder, memoization); +} + +NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, IOutputStream& err) { + TMemoizedTypesMap memoization; + return BuildType(annotation, typeBuilder, memoization, err); +} + +NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder) { + TMemoizedTypesMap memoization; + return BuildType(pos, annotation, typeBuilder, memoization); +} + NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder) { - return BuildType(owner.Pos(), annotation, typeBuilder); + TMemoizedTypesMap memoization; + return BuildType(owner, annotation, typeBuilder, memoization); } const TTypeAnnotationNode* ConvertMiniKQLType(TPosition position, NKikimr::NMiniKQL::TType* type, TExprContext& ctx) { diff --git a/yql/essentials/providers/common/mkql/yql_type_mkql.h b/yql/essentials/providers/common/mkql/yql_type_mkql.h index 834b202848..3718e4be8f 100644 --- a/yql/essentials/providers/common/mkql/yql_type_mkql.h +++ b/yql/essentials/providers/common/mkql/yql_type_mkql.h @@ -4,6 +4,8 @@ #include <yql/essentials/ast/yql_pos_handle.h> #include <yql/essentials/minikql/mkql_type_builder.h> +#include <unordered_map> + namespace NKikimr { namespace NMiniKQL { @@ -17,6 +19,12 @@ class TProgramBuilder; namespace NYql { namespace NCommon { +using TMemoizedTypesMap = std::unordered_map<const TTypeAnnotationNode*, NKikimr::NMiniKQL::TType*>; + +NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization, IOutputStream& err); +NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization); +NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, TMemoizedTypesMap& memoization); + NKikimr::NMiniKQL::TType* BuildType(const TTypeAnnotationNode& annotation, const NKikimr::NMiniKQL::TTypeBuilder& typeBuilder, IOutputStream& err); NKikimr::NMiniKQL::TType* BuildType(TPositionHandle pos, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder); NKikimr::NMiniKQL::TType* BuildType(const TExprNode& owner, const TTypeAnnotationNode& annotation, NKikimr::NMiniKQL::TTypeBuilder& typeBuilder); |