diff options
author | imunkin <imunkin@yandex-team.com> | 2025-05-28 16:10:50 +0300 |
---|---|---|
committer | imunkin <imunkin@yandex-team.com> | 2025-05-28 16:42:48 +0300 |
commit | 2dcd562c7f9b65363b283af921529a7cc7aeeac6 (patch) | |
tree | e9f83450a2af6452e83b0ce60672dc37f3a5a001 | |
parent | af61751739c03fb682da5db22a7e84ec5f17be42 (diff) | |
download | ydb-2dcd562c7f9b65363b283af921529a7cc7aeeac6.tar.gz |
YQL-19967: Introduce TExtendedArgsWrapper helper
commit_hash:8aa01a548ffd87f8f1f6aa6df7eeddb66dad1a27
-rw-r--r-- | yql/essentials/minikql/comp_nodes/mkql_udf.cpp | 75 | ||||
-rw-r--r-- | yql/essentials/minikql/comp_nodes/ut/mkql_udf_ut.cpp | 256 |
2 files changed, 314 insertions, 17 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_udf.cpp b/yql/essentials/minikql/comp_nodes/mkql_udf.cpp index e792b327382..8230a759d6b 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_udf.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_udf.cpp @@ -37,12 +37,14 @@ public: TString&& typeConfig, NUdf::TSourcePosition pos, const TCallableType* callableType, + const TCallableType* functionType, TType* userType) : TBaseComputation(mutables, EValueRepresentation::Boxed) , FunctionName(std::move(functionName)) , TypeConfig(std::move(typeConfig)) , Pos(pos) , CallableType(callableType) + , FunctionType(functionType) , UserType(userType) { this->Stateless = false; @@ -65,16 +67,55 @@ public: } NUdf::TUnboxedValue udf(NUdf::TUnboxedValuePod(funcInfo.Implementation.Release())); - TValidate<TValidatePolicy,TValidateMode>::WrapCallable(CallableType, udf, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">"); + TValidate<TValidatePolicy,TValidateMode>::WrapCallable(FunctionType, udf, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">"); + ExtendArgs(udf, CallableType, funcInfo.FunctionType); return udf.Release(); } private: + // xXX: This class implements the wrapper to properly handle + // the case when the signature of the emitted callable (i.e. + // callable type) requires less arguments than the actual + // function (i.e. function type). It wraps the unboxed value + // with the resolved UDF to introduce the bridge in the + // Run chain, preparing the valid argument vector for the + // chosen UDF implementation. + class TExtendedArgsWrapper: public NUdf::TBoxedValue { + public: + TExtendedArgsWrapper(NUdf::TUnboxedValue&& callable, size_t usedArgs, size_t requiredArgs) + : Callable_(callable) + , UsedArgs_(usedArgs) + , RequiredArgs_(requiredArgs) + {}; + + private: + NUdf::TUnboxedValue Run(const NUdf::IValueBuilder* valueBuilder, const NUdf::TUnboxedValuePod* args) const final { + NStackArray::TStackArray<NUdf::TUnboxedValue> values(ALLOC_ON_STACK(NUdf::TUnboxedValue, RequiredArgs_)); + for (size_t i = 0; i < UsedArgs_; i++) { + values[i] = args[i]; + } + return Callable_.Run(valueBuilder, values.data()); + } + + const NUdf::TUnboxedValue Callable_; + const size_t UsedArgs_; + const size_t RequiredArgs_; + }; + + void ExtendArgs(NUdf::TUnboxedValue& callable, const TCallableType* callableType, const TCallableType* functionType) const { + const auto callableArgc = callableType->GetArgumentsCount(); + const auto functionArgc = functionType->GetArgumentsCount(); + if (callableArgc < functionArgc) { + callable = NUdf::TUnboxedValuePod(new TExtendedArgsWrapper(std::move(callable), callableArgc, functionArgc)); + } + } + void RegisterDependencies() const final {} const TString FunctionName; const TString TypeConfig; const NUdf::TSourcePosition Pos; const TCallableType *const CallableType; + const TCallableType *const FunctionType; TType *const UserType; }; @@ -90,12 +131,13 @@ public: TString&& typeConfig, NUdf::TSourcePosition pos, const TCallableType* callableType, + const TCallableType* functionType, TType* userType, TString&& moduleIRUniqID, TString&& moduleIR, TString&& fuctioNameIR, NUdf::TUniquePtr<NUdf::IBoxedValue>&& impl) - : TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, userType) + : TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, functionType, userType) , ModuleIRUniqID(std::move(moduleIRUniqID)) , ModuleIR(std::move(moduleIR)) , IRFunctionName(std::move(fuctioNameIR)) @@ -138,7 +180,7 @@ public: NUdf::TSourcePosition pos, IComputationNode* runConfigNode, ui32 runConfigArgs, - const TCallableType* callableType, + const TCallableType* functionType, TType* userType) : TBaseComputation(mutables, EValueRepresentation::Boxed) , FunctionName(std::move(functionName)) @@ -146,7 +188,7 @@ public: , Pos(pos) , RunConfigNode(runConfigNode) , RunConfigArgs(runConfigArgs) - , CallableType(callableType) + , FunctionType(functionType) , UserType(userType) , UdfIndex(mutables.CurValueIndex++) { @@ -238,7 +280,7 @@ private: } void Wrap(NUdf::TUnboxedValue& callable) const { - TValidate<TValidatePolicy,TValidateMode>::WrapCallable(CallableType, callable, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">"); + TValidate<TValidatePolicy,TValidateMode>::WrapCallable(FunctionType, callable, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">"); } void RegisterDependencies() const final { @@ -250,7 +292,7 @@ private: const NUdf::TSourcePosition Pos; IComputationNode* const RunConfigNode; const ui32 RunConfigArgs; - const TCallableType* CallableType; + const TCallableType* FunctionType; TType* const UserType; const ui32 UdfIndex; }; @@ -317,6 +359,8 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont << status.GetError()).c_str()); } + const auto callableFuncType = AS_TYPE(TCallableType, funcInfo.FunctionType); + const auto callableNodeType = AS_TYPE(TCallableType, callable.GetType()->GetReturnType()); const auto runConfigFuncType = funcInfo.RunConfigType; const auto runConfigNodeType = runCfgNode.GetStaticType(); @@ -338,9 +382,6 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont << TruncateTypeDiff(diff)).c_str()); } - const auto callableFuncType = AS_TYPE(TCallableType, funcInfo.FunctionType); - const auto callableNodeType = AS_TYPE(TCallableType, callable.GetType()->GetReturnType()); - const auto callableType = runConfigNodeType->IsVoid() ? callableNodeType : callableFuncType; const auto runConfigType = runConfigNodeType->IsVoid() @@ -396,13 +437,13 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont const auto runConfigCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode()); const auto runConfigArgs = funcInfo.FunctionType->GetArgumentsCount(); return runConfigNodeType->IsVoid() - ? CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType) - : CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, funcInfo.FunctionType, userType); + ? CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, callableNodeType, callableFuncType, userType) + : CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runConfigCompNode, runConfigArgs, callableFuncType, userType); } - if (!funcInfo.FunctionType->IsConvertableTo(*callable.GetType()->GetReturnType(), true)) { - TString diff = TStringBuilder() << "type mismatch, expected return type: " << PrintNode(callable.GetType()->GetReturnType(), true) << - ", actual:" << PrintNode(funcInfo.FunctionType, true); + if (!callableFuncType->IsConvertableTo(*callableNodeType, true)) { + TString diff = TStringBuilder() << "type mismatch, expected return type: " << PrintNode(callableNodeType, true) << + ", actual:" << PrintNode(callableFuncType, true); UdfTerminate((TStringBuilder() << pos << " UDF Function '" << funcName << "' " << TruncateTypeDiff(diff)).c_str()); } @@ -413,15 +454,15 @@ IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryCont if (runConfigFuncType->IsVoid()) { if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName) { return new TUdfRunCodegeneratorNode( - ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType, + ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, callableNodeType, callableFuncType, userType, std::move(funcInfo.ModuleIRUniqID), std::move(funcInfo.ModuleIR), std::move(funcInfo.IRFunctionName), std::move(funcInfo.Implementation) ); } - return CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType); + return CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, callableNodeType, callableFuncType, userType); } const auto runCfgCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode()); - return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, 1U, funcInfo.FunctionType, userType); + return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, 1U, callableFuncType, userType); } IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFactoryContext& ctx) { diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_udf_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_udf_ut.cpp index 39e96baa5d1..265ad3bd059 100644 --- a/yql/essentials/minikql/comp_nodes/ut/mkql_udf_ut.cpp +++ b/yql/essentials/minikql/comp_nodes/ut/mkql_udf_ut.cpp @@ -127,6 +127,79 @@ private: SIMPLE_MODULE(TRunConfigUTModule, TRunConfig) SIMPLE_MODULE(TCurryingUTModule, TCurrying) + +SIMPLE_STRICT_UDF(TTest, char*(char*, char*, char*)) { + TStringStream concat; + concat << args[0].AsStringRef() << " " + << args[1].AsStringRef() << " " + << args[2].AsStringRef(); + return valueBuilder->NewString(NYql::NUdf::TStringRef(concat.Data(), + concat.Size())); +} + +template<bool Old> +class TNewTest : public NYql::NUdf::TBoxedValue { +public: + explicit TNewTest(NYql::NUdf::TSourcePosition pos) + : Pos_(pos) + {} + + static const NYql::NUdf::TStringRef& Name() { + static auto name = NYql::NUdf::TStringRef::Of("Test"); + return name; + } + + static bool DeclareSignature(const NYql::NUdf::TStringRef& name, + NYql::NUdf::TType*, + NYql::NUdf::IFunctionTypeInfoBuilder& builder, + bool typesOnly) + { + if (Name() != name) { + return false; + } + + if (Old && typesOnly) { + builder.SimpleSignature<char*(char*, char*, char*)>(); + return true; + } + + builder.SimpleSignature<char*(char*, char*, char*, NYql::NUdf::TOptional<char*>)>() + .OptionalArgs(1); + if (!typesOnly) { + builder.Implementation(new TNewTest(builder.GetSourcePosition())); + } + + return true; + } + + NYql::NUdf::TUnboxedValue Run(const NYql::NUdf::IValueBuilder* valueBuilder, + const NYql::NUdf::TUnboxedValuePod* args) + const override try { + TStringStream concat; + concat << args[0].AsStringRef() << " " + << args[1].AsStringRef() << " "; + if (args[3]) { + concat << args[3].AsStringRef() << " "; + } + concat << args[2].AsStringRef(); + return valueBuilder->NewString(NYql::NUdf::TStringRef(concat.Data(), + concat.Size())); + } catch (const std::exception& e) { + UdfTerminate((TStringBuilder() << Pos_ << " " << e.what()).data()); + } + +private: + const NYql::NUdf::TSourcePosition Pos_; +}; + +// XXX: "Old" UDF is declared via SIMPLE_UDF helper, so it has to +// use the *actual* function name as a class name. Furthermore, +// the UDF, declared by SIMPLE_UDF has to provide the same +// semantics as TNewTest<true>. +SIMPLE_MODULE(TOldUTModule, TTest) +SIMPLE_MODULE(TIncrementalUTModule, TNewTest<true>) +SIMPLE_MODULE(TNewUTModule, TNewTest<false>) + Y_UNIT_TEST_SUITE(TMiniKQLUdfTest) { Y_UNIT_TEST_LLVM(RunconfigToCurrying) { // Create the test setup, using TRunConfig implementation @@ -224,6 +297,189 @@ Y_UNIT_TEST_SUITE(TMiniKQLUdfTest) { UNIT_ASSERT_STRINGS_EQUAL(TStringBuf(result.AsStringRef()), "Canary is alive"); UNIT_ASSERT(!iterator.Next(result)); } + + Y_UNIT_TEST_LLVM(OldToIncremental) { + // Create the test setup, using the old implementation for + // TestModule.Test UDF. + TVector<TUdfModuleInfo> compileModules; + compileModules.emplace_back( + TUdfModuleInfo{"", "TestModule", new TOldUTModule()} + ); + TSetup<LLVM> compileSetup(GetTestFactory(), std::move(compileModules)); + TProgramBuilder& pb = *compileSetup.PgmBuilder; + + // Build the graph, using the old setup. + const auto strType = pb.NewDataType(NUdf::TDataType<char*>::Id); + const auto arg1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("Canary"); + const auto arg2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("is"); + const auto arg3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("alive"); + + const auto udf = pb.Udf("TestModule.Test"); + const auto argsType = pb.NewTupleType({strType, strType, strType}); + const auto argList = pb.NewList(argsType, {pb.NewTuple({arg1, arg2, arg3})}); + const auto pgmReturn = pb.Map(argList, [&pb, udf](const TRuntimeNode args) { + return pb.Apply(udf, {pb.Nth(args, 0), pb.Nth(args, 1), pb.Nth(args, 2)}); + }); + + // Create the test setup, using the incremental + // implementation for TestModule.Test UDF. + TVector<TUdfModuleInfo> runModules; + runModules.emplace_back( + TUdfModuleInfo{"", "TestModule", new TIncrementalUTModule()} + ); + TSetup<LLVM> runSetup(GetTestFactory(), std::move(runModules)); + // Move the graph from the one setup to another as a + // serialized bytecode sequence. + const auto bytecode = SerializeRuntimeNode(pgmReturn, *compileSetup.Env); + const auto root = DeserializeRuntimeNode(bytecode, *runSetup.Env); + + // Run the graph, using the incremental setup. + const auto graph = runSetup.BuildGraph(root); + const auto iterator = graph->GetValue().GetListIterator(); + + NUdf::TUnboxedValue result; + UNIT_ASSERT(iterator.Next(result)); + UNIT_ASSERT_STRINGS_EQUAL(TStringBuf(result.AsStringRef()), "Canary is alive"); + UNIT_ASSERT(!iterator.Next(result)); + } + + Y_UNIT_TEST_LLVM(IncrementalToOld) { + // Create the test setup, using the incremental + // implementation for TestModule.Test UDF. + TVector<TUdfModuleInfo> compileModules; + compileModules.emplace_back( + TUdfModuleInfo{"", "TestModule", new TIncrementalUTModule()} + ); + TSetup<LLVM> compileSetup(GetTestFactory(), std::move(compileModules)); + TProgramBuilder& pb = *compileSetup.PgmBuilder; + + // Build the graph, using the incremental setup. + const auto strType = pb.NewDataType(NUdf::TDataType<char*>::Id); + const auto arg1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("Canary"); + const auto arg2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("is"); + const auto arg3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("alive"); + + const auto udf = pb.Udf("TestModule.Test"); + const auto argsType = pb.NewTupleType({strType, strType, strType}); + const auto argList = pb.NewList(argsType, {pb.NewTuple({arg1, arg2, arg3})}); + const auto pgmReturn = pb.Map(argList, [&pb, udf](const TRuntimeNode args) { + return pb.Apply(udf, {pb.Nth(args, 0), pb.Nth(args, 1), pb.Nth(args, 2)}); + }); + + // Create the test setup, using the old implementation for + // TestModule.Test UDF. + TVector<TUdfModuleInfo> runModules; + runModules.emplace_back( + TUdfModuleInfo{"", "TestModule", new TOldUTModule()} + ); + TSetup<LLVM> runSetup(GetTestFactory(), std::move(runModules)); + // Move the graph from the one setup to another as a + // serialized bytecode sequence. + const auto bytecode = SerializeRuntimeNode(pgmReturn, *compileSetup.Env); + const auto root = DeserializeRuntimeNode(bytecode, *runSetup.Env); + + // Run the graph, using the old setup. + const auto graph = runSetup.BuildGraph(root); + const auto iterator = graph->GetValue().GetListIterator(); + + NUdf::TUnboxedValue result; + UNIT_ASSERT(iterator.Next(result)); + UNIT_ASSERT_STRINGS_EQUAL(TStringBuf(result.AsStringRef()), "Canary is alive"); + UNIT_ASSERT(!iterator.Next(result)); + } + + Y_UNIT_TEST_LLVM(IncrementalToNew) { + // Create the test setup, using the incremental + // implementation for TestModule.Test UDF. + TVector<TUdfModuleInfo> compileModules; + compileModules.emplace_back( + TUdfModuleInfo{"", "TestModule", new TIncrementalUTModule()} + ); + TSetup<LLVM> compileSetup(GetTestFactory(), std::move(compileModules)); + TProgramBuilder& pb = *compileSetup.PgmBuilder; + + // Build the graph, using the incremental setup. + const auto strType = pb.NewDataType(NUdf::TDataType<char*>::Id); + const auto arg1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("Canary"); + const auto arg2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("is"); + const auto arg3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("alive"); + + const auto udf = pb.Udf("TestModule.Test"); + const auto argsType = pb.NewTupleType({strType, strType, strType}); + const auto argList = pb.NewList(argsType, {pb.NewTuple({arg1, arg2, arg3})}); + const auto pgmReturn = pb.Map(argList, [&pb, udf](const TRuntimeNode args) { + return pb.Apply(udf, {pb.Nth(args, 0), pb.Nth(args, 1), pb.Nth(args, 2)}); + }); + + // Create the test setup, using the new implementation for + // TestModule.Test UDF. + TVector<TUdfModuleInfo> runModules; + runModules.emplace_back( + TUdfModuleInfo{"", "TestModule", new TNewUTModule()} + ); + TSetup<LLVM> runSetup(GetTestFactory(), std::move(runModules)); + // Move the graph from the one setup to another as a + // serialized bytecode sequence. + const auto bytecode = SerializeRuntimeNode(pgmReturn, *compileSetup.Env); + const auto root = DeserializeRuntimeNode(bytecode, *runSetup.Env); + + // Run the graph, using the new setup. + const auto graph = runSetup.BuildGraph(root); + const auto iterator = graph->GetValue().GetListIterator(); + + NUdf::TUnboxedValue result; + UNIT_ASSERT(iterator.Next(result)); + UNIT_ASSERT_STRINGS_EQUAL(TStringBuf(result.AsStringRef()), "Canary is alive"); + UNIT_ASSERT(!iterator.Next(result)); + } + + Y_UNIT_TEST_LLVM(NewToIncremental) { + // Create the test setup, using the new implementation for + // TestModule.Test UDF. + TVector<TUdfModuleInfo> compileModules; + compileModules.emplace_back( + TUdfModuleInfo{"", "TestModule", new TNewUTModule()} + ); + TSetup<LLVM> compileSetup(GetTestFactory(), std::move(compileModules)); + TProgramBuilder& pb = *compileSetup.PgmBuilder; + + // Build the graph, using the new setup. + const auto strType = pb.NewDataType(NUdf::TDataType<char*>::Id); + const auto optType = pb.NewOptionalType(strType); + const auto arg1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("Canary"); + const auto arg2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("is"); + const auto arg3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("alive"); + const auto arg4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("still"); + const auto opt4 = pb.NewOptional(arg4); + + const auto udf = pb.Udf("TestModule.Test"); + const auto argsType = pb.NewTupleType({strType, strType, strType, optType}); + const auto argList = pb.NewList(argsType, {pb.NewTuple({arg1, arg2, arg3, opt4})}); + const auto pgmReturn = pb.Map(argList, [&pb, udf](const TRuntimeNode args) { + return pb.Apply(udf, {pb.Nth(args, 0), pb.Nth(args, 1), pb.Nth(args, 2), pb.Nth(args, 3)}); + }); + + // Create the test setup, using the incremental + // implementation for TestModule.Test UDF. + TVector<TUdfModuleInfo> runModules; + runModules.emplace_back( + TUdfModuleInfo{"", "TestModule", new TIncrementalUTModule()} + ); + TSetup<LLVM> runSetup(GetTestFactory(), std::move(runModules)); + // Move the graph from the one setup to another as a + // serialized bytecode sequence. + const auto bytecode = SerializeRuntimeNode(pgmReturn, *compileSetup.Env); + const auto root = DeserializeRuntimeNode(bytecode, *runSetup.Env); + + // Run the graph, using the incremental setup. + const auto graph = runSetup.BuildGraph(root); + const auto iterator = graph->GetValue().GetListIterator(); + + NUdf::TUnboxedValue result; + UNIT_ASSERT(iterator.Next(result)); + UNIT_ASSERT_STRINGS_EQUAL(TStringBuf(result.AsStringRef()), "Canary is still alive"); + UNIT_ASSERT(!iterator.Next(result)); + } } // Y_UNIT_TEST_SUITE } // namespace NMiniKQL |