diff options
author | vvvv <[email protected]> | 2025-05-14 00:34:38 +0300 |
---|---|---|
committer | vvvv <[email protected]> | 2025-05-14 00:47:09 +0300 |
commit | 3797381bcf6a3df56ff84fb4a6aff8708846ad48 (patch) | |
tree | 82385321011fb05b9adbfec90549446fb4520c41 | |
parent | a3ebca5b6aa6e745e856eb735ee81906b0165391 (diff) |
YQL-19943 supported custom callable typecheck inside evaluation pipeline
commit_hash:70c07242fe9b14ecf016dac8184fca9f7f01d56f
7 files changed, 55 insertions, 16 deletions
diff --git a/yql/essentials/core/services/yql_eval_expr.cpp b/yql/essentials/core/services/yql_eval_expr.cpp index bc39a9e9c83..fa3417e40d2 100644 --- a/yql/essentials/core/services/yql_eval_expr.cpp +++ b/yql/essentials/core/services/yql_eval_expr.cpp @@ -376,7 +376,8 @@ TExprNode::TPtr QuoteCode(const TExprNode::TPtr& node, TExprContext& ctx, TNodeO } IGraphTransformer::TStatus EvaluateExpression(const TExprNode::TPtr& input, TExprNode::TPtr& output, - TTypeAnnotationContext& types, TExprContext& ctx, const IFunctionRegistry& functionRegistry, IGraphTransformer* calcTransfomer) { + TTypeAnnotationContext& types, TExprContext& ctx, const IFunctionRegistry& functionRegistry, + IGraphTransformer* calcTransfomer, TTypeAnnCallableFactory typeAnnCallableFactory) { output = input; if (ctx.Step.IsDone(TExprStep::ExprEval)) return IGraphTransformer::TStatus::Ok; @@ -391,7 +392,7 @@ IGraphTransformer::TStatus EvaluateExpression(const TExprNode::TPtr& input, TExp bool isOptionalAtom = false; bool isTypePipeline = false; bool isCodePipeline = false; - TTransformationPipeline pipeline(&types); + TTransformationPipeline pipeline(&types, typeAnnCallableFactory); pipeline.AddServiceTransformers(); pipeline.AddPreTypeAnnotation(); pipeline.AddExpressionEvaluation(functionRegistry); diff --git a/yql/essentials/core/services/yql_eval_expr.h b/yql/essentials/core/services/yql_eval_expr.h index 93f46778922..d61b8acf73d 100644 --- a/yql/essentials/core/services/yql_eval_expr.h +++ b/yql/essentials/core/services/yql_eval_expr.h @@ -16,6 +16,7 @@ class IFunctionRegistry; namespace NYql { IGraphTransformer::TStatus EvaluateExpression(const TExprNode::TPtr& input, TExprNode::TPtr& output, TTypeAnnotationContext& types, TExprContext& ctx, - const NKikimr::NMiniKQL::IFunctionRegistry& functionRegistry, IGraphTransformer* calcTransfomer = nullptr); + const NKikimr::NMiniKQL::IFunctionRegistry& functionRegistry, + IGraphTransformer* calcTransfomer = nullptr, TTypeAnnCallableFactory typeAnnCallableFactory = {}); } diff --git a/yql/essentials/core/services/yql_transform_pipeline.cpp b/yql/essentials/core/services/yql_transform_pipeline.cpp index c6bb45a2189..bd7a64673f6 100644 --- a/yql/essentials/core/services/yql_transform_pipeline.cpp +++ b/yql/essentials/core/services/yql_transform_pipeline.cpp @@ -19,8 +19,13 @@ namespace NYql { -TTransformationPipeline::TTransformationPipeline(TIntrusivePtr<TTypeAnnotationContext> ctx) +TTransformationPipeline::TTransformationPipeline( + TIntrusivePtr<TTypeAnnotationContext> ctx, + TTypeAnnCallableFactory typeAnnCallableFactory) : TypeAnnotationContext_(ctx) + , TypeAnnCallableFactory_(typeAnnCallableFactory ? typeAnnCallableFactory : [ctx = ctx.Get()](){ + return CreateExtCallableTypeAnnotationTransformer(*ctx); + }) {} TTransformationPipeline& TTransformationPipeline::Add(TAutoPtr<IGraphTransformer> transformer, const TString& stageName, @@ -58,9 +63,10 @@ TTransformationPipeline& TTransformationPipeline::AddExpressionEvaluation(const IGraphTransformer* calcTransfomer, EYqlIssueCode issueCode) { auto& typeCtx = *TypeAnnotationContext_; auto& funcReg = functionRegistry; + auto typeAnnCallableFactory = TypeAnnCallableFactory_; Transformers_.push_back(TTransformStage(CreateFunctorTransformer( - [&typeCtx, &funcReg, calcTransfomer](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { - return EvaluateExpression(input, output, typeCtx, ctx, funcReg, calcTransfomer); + [&typeCtx, &funcReg, calcTransfomer, typeAnnCallableFactory](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { + return EvaluateExpression(input, output, typeCtx, ctx, funcReg, calcTransfomer, typeAnnCallableFactory); }), "EvaluateExpression", issueCode)); return *this; @@ -138,10 +144,6 @@ TTransformationPipeline& TTransformationPipeline::AddPostTypeAnnotation(bool for } TTransformationPipeline& TTransformationPipeline::AddCommonOptimization(bool forPeephole, EYqlIssueCode issueCode) { - // auto instantCallableTransformer = - // CreateExtCallableTypeAnnotationTransformer(*TypeAnnotationContext_, true); - // TypeAnnotationContext_->CustomInstantTypeTransformer = - // CreateTypeAnnotationTransformer(instantCallableTransformer, *TypeAnnotationContext_); Transformers_.push_back(TTransformStage( CreateCommonOptTransformer(forPeephole, TypeAnnotationContext_.Get()), "CommonOptimization", @@ -271,7 +273,7 @@ TTransformationPipeline& TTransformationPipeline::AddTypeAnnotationTransformer( } TTransformationPipeline& TTransformationPipeline::AddTypeAnnotationTransformerWithMode(EYqlIssueCode issueCode, ETypeCheckMode mode) { - auto callableTransformer = CreateExtCallableTypeAnnotationTransformer(*TypeAnnotationContext_); + auto callableTransformer = TypeAnnCallableFactory_(); AddTypeAnnotationTransformer(callableTransformer, issueCode, mode); return *this; } @@ -279,11 +281,11 @@ TTransformationPipeline& TTransformationPipeline::AddTypeAnnotationTransformerWi TTransformationPipeline& TTransformationPipeline::AddTypeAnnotationTransformer(EYqlIssueCode issueCode, bool twoStages) { if (twoStages) { - std::shared_ptr<IGraphTransformer> callableTransformer(CreateExtCallableTypeAnnotationTransformer(*TypeAnnotationContext_).Release()); + std::shared_ptr<IGraphTransformer> callableTransformer(TypeAnnCallableFactory_().Release()); AddTypeAnnotationTransformer(MakeSharedTransformerProxy(callableTransformer), issueCode, ETypeCheckMode::Initial); AddTypeAnnotationTransformer(MakeSharedTransformerProxy(callableTransformer), issueCode, ETypeCheckMode::Repeat); } else { - auto callableTransformer = CreateExtCallableTypeAnnotationTransformer(*TypeAnnotationContext_); + auto callableTransformer = TypeAnnCallableFactory_(); AddTypeAnnotationTransformer(callableTransformer, issueCode, ETypeCheckMode::Single); } diff --git a/yql/essentials/core/services/yql_transform_pipeline.h b/yql/essentials/core/services/yql_transform_pipeline.h index 83250336db0..3e6dbdb4cf1 100644 --- a/yql/essentials/core/services/yql_transform_pipeline.h +++ b/yql/essentials/core/services/yql_transform_pipeline.h @@ -21,7 +21,8 @@ namespace NYql { class TTransformationPipeline { public: - TTransformationPipeline(TIntrusivePtr<TTypeAnnotationContext> ctx); + TTransformationPipeline(TIntrusivePtr<TTypeAnnotationContext> ctx, + TTypeAnnCallableFactory typeAnnCallableFactory = {}); TTransformationPipeline& AddServiceTransformers(EYqlIssueCode issueCode = TIssuesIds::CORE_GC); TTransformationPipeline& AddParametersEvaluation(const NKikimr::NMiniKQL::IFunctionRegistry& functionRegistry, EYqlIssueCode issueCode = TIssuesIds::CORE_PARAM_EVALUATION); @@ -59,6 +60,7 @@ public: private: TIntrusivePtr<TTypeAnnotationContext> TypeAnnotationContext_; + TTypeAnnCallableFactory TypeAnnCallableFactory_; TVector<TTransformStage> Transformers_; }; diff --git a/yql/essentials/core/yql_type_annotation.h b/yql/essentials/core/yql_type_annotation.h index 66148f8caf2..3f47e54ae04 100644 --- a/yql/essentials/core/yql_type_annotation.h +++ b/yql/essentials/core/yql_type_annotation.h @@ -27,10 +27,13 @@ #include <util/generic/vector.h> #include <util/digest/city.h> +#include <functional> #include <vector> namespace NYql { +using TTypeAnnCallableFactory = std::function<TAutoPtr<IGraphTransformer>()>; + class IUrlLoader : public TThrRefBase { public: ~IUrlLoader() = default; diff --git a/yql/essentials/public/purecalc/common/worker_factory.cpp b/yql/essentials/public/purecalc/common/worker_factory.cpp index cf6b79d4a82..58af0959c23 100644 --- a/yql/essentials/public/purecalc/common/worker_factory.cpp +++ b/yql/essentials/public/purecalc/common/worker_factory.cpp @@ -320,7 +320,11 @@ TExprNode::TPtr TWorkerFactory<TBase>::Compile( ? PurecalcBlockInputCallableName : PurecalcInputCallableName); - TTransformationPipeline pipeline(typeContext); + TTypeAnnCallableFactory typeAnnCallableFactory = [&]() { + return MakeTypeAnnotationTransformer(typeContext, InputTypes_, RawInputTypes_, processorMode, selfName); + }; + + TTransformationPipeline pipeline(typeContext, typeAnnCallableFactory); pipeline.Add(MakeTableReadsReplacer(InputTypes_, UseSystemColumns_, processorMode, selfName), "ReplaceTableReads", EYqlIssueCode::TIssuesIds_EIssueCode_DEFAULT_ERROR, @@ -329,7 +333,7 @@ TExprNode::TPtr TWorkerFactory<TBase>::Compile( pipeline.AddPreTypeAnnotation(); pipeline.AddExpressionEvaluation(*FuncRegistry_, calcTransformer.Get()); pipeline.AddIOAnnotation(); - pipeline.AddTypeAnnotationTransformer(MakeTypeAnnotationTransformer(typeContext, InputTypes_, RawInputTypes_, processorMode, selfName)); + pipeline.AddTypeAnnotationTransformer(); pipeline.AddPostTypeAnnotation(); pipeline.Add(CreateFunctorTransformer( [&](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { diff --git a/yql/essentials/public/purecalc/ut/test_eval.cpp b/yql/essentials/public/purecalc/ut/test_eval.cpp index 38ad7cc952d..cbb5641ddd3 100644 --- a/yql/essentials/public/purecalc/ut/test_eval.cpp +++ b/yql/essentials/public/purecalc/ut/test_eval.cpp @@ -27,4 +27,30 @@ Y_UNIT_TEST_SUITE(TestEval) { UNIT_ASSERT_EQUAL(message->GetX(), "foobar"); UNIT_ASSERT(!stream->Fetch()); } + + Y_UNIT_TEST(TestSelfType) { + using namespace NYql::NPureCalc; + + auto options = TProgramFactoryOptions(); + auto factory = MakeProgramFactory(options); + + try { + auto program = factory->MakePullListProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + "$input = PROCESS Input;select unwrap(cast(FormatType(EvaluateType(TypeHandle(TypeOf($input)))) AS Utf8)) AS X", + ETranslationMode::SQL + ); + + auto stream = program->Apply(EmptyStream<NPureCalcProto::TStringMessage*>()); + + NPureCalcProto::TStringMessage* message; + + UNIT_ASSERT(message = stream->Fetch()); + UNIT_ASSERT_VALUES_EQUAL(message->GetX(), "List<Struct<'X':Utf8>>"); + UNIT_ASSERT(!stream->Fetch()); + } catch (const TCompileError& e) { + UNIT_FAIL(e.GetIssues()); + } + } } |