diff options
author | vvvv <vvvv@yandex-team.ru> | 2022-03-02 01:55:04 +0300 |
---|---|---|
committer | vvvv <vvvv@yandex-team.ru> | 2022-03-02 01:55:04 +0300 |
commit | 26286f616cee657612a9d820be6da2cdbd4de0ef (patch) | |
tree | 07cb52f2682e92fd23a00bca9553c57481cf76a8 | |
parent | 2ba098f9946ccfd045204904c5a6c51d240462e0 (diff) | |
download | ydb-26286f616cee657612a9d820be6da2cdbd4de0ef.tar.gz |
YQL-13710 support for PgOper
ref:fcfb05393bd7e73bbb1087d25bad8d863c14c6fd
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_core.cpp | 63 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/catalog.cpp | 120 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/catalog.h | 6 | ||||
-rw-r--r-- | ydb/library/yql/providers/common/mkql/CMakeLists.txt | 1 | ||||
-rw-r--r-- | ydb/library/yql/providers/common/mkql/ya.make | 1 | ||||
-rw-r--r-- | ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp | 17 | ||||
-rw-r--r-- | ydb/library/yql/sql/pg/pg_sql.cpp | 36 |
7 files changed, 214 insertions, 30 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 30bb3b63e4..356414974e 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -8774,7 +8774,6 @@ template <NKikimr::NUdf::EDataSlot DataSlot> }; IGraphTransformer::TStatus PgCallWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { - Y_UNUSED(output); bool isResolved = input->Content() == "PgResolvedCall"; if (!EnsureMinArgsCount(*input, isResolved ? 2 : 1, ctx.Expr)) { return IGraphTransformer::TStatus::Error; @@ -8896,6 +8895,66 @@ template <NKikimr::NUdf::EDataSlot DataSlot> } } + IGraphTransformer::TStatus PgOpWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { + bool isResolved = input->IsCallable("PgResolvedOp"); + if (!EnsureMinArgsCount(*input, isResolved ? 3 : 2, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureMaxArgsCount(*input, isResolved ? 4 : 3, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureAtom(input->Head(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto name = input->Head().Content(); + if (isResolved) { + if (!EnsureAtom(*input->Child(1), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } + + TVector<ui32> argTypes; + for (ui32 i = isResolved ? 2 : 1; i < input->ChildrenSize(); ++i) { + auto type = input->Child(i)->GetTypeAnn(); + if (type->GetKind() == ETypeAnnotationKind::Null) { + argTypes.push_back(0); + continue; + } + + if (type->GetKind() != ETypeAnnotationKind::Pg) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Expected PG type for argument " << (i - (isResolved ? 2 : 1) + 1) << ", but got: " << type->GetKind() << " for function: " << name)); + return IGraphTransformer::TStatus::Error; + } + + argTypes.push_back(type->Cast<TPgExprType>()->GetId()); + } + + if (isResolved) { + auto operId = FromString<ui32>(input->Child(1)->Content()); + const auto& oper = NPg::LookupOper(operId, argTypes); + if (oper.Name != name) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Mismatch of resolved operator name, expected: " << name << ", but got:" << oper.Name)); + return IGraphTransformer::TStatus::Error; + } + + auto result = ctx.Expr.MakeType<TPgExprType>(oper.ResultType); + input->SetTypeAnn(result); + return IGraphTransformer::TStatus::Ok; + } else { + const auto& oper = NPg::LookupOper(TString(name), argTypes); + auto children = input->ChildrenList(); + auto idNode = ctx.Expr.NewAtom(input->Pos(), ToString(oper.OperId)); + children.insert(children.begin() + 1, idNode); + output = ctx.Expr.NewCallable(input->Pos(), "PgResolvedOp", std::move(children)); + return IGraphTransformer::TStatus::Repeat; + } + } + IGraphTransformer::TStatus PgWindowCallWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { Y_UNUSED(output); if (!EnsureMinArgsCount(*input, 2, ctx.Expr)) { @@ -13320,6 +13379,8 @@ template <NKikimr::NUdf::EDataSlot DataSlot> ExtFunctions["PgCall"] = &PgCallWrapper; ExtFunctions["PgResolvedCall"] = &PgCallWrapper; + ExtFunctions["PgOp"] = &PgOpWrapper; + ExtFunctions["PgResolvedOp"] = &PgOpWrapper; ExtFunctions["PgSelect"] = &PgSelectWrapper; ExtFunctions["PgSetItem"] = &PgSetItemWrapper; ExtFunctions["TablePath"] = &TablePathWrapper; diff --git a/ydb/library/yql/parser/pg_catalog/catalog.cpp b/ydb/library/yql/parser/pg_catalog/catalog.cpp index 79af8b74af..0a40b02065 100644 --- a/ydb/library/yql/parser/pg_catalog/catalog.cpp +++ b/ydb/library/yql/parser/pg_catalog/catalog.cpp @@ -105,9 +105,11 @@ public: class TOperatorsParser : public TParser { public: - TOperatorsParser(TOperators& operators, const THashMap<TString, ui32>& typeByName) + TOperatorsParser(TOperators& operators, const THashMap<TString, ui32>& typeByName, + const THashMap<TString, TVector<ui32>>& procByName) : Operators(operators) , TypeByName(typeByName) + , ProcByName(procByName) {} void OnKey(const TString& key, const TString& value) override { @@ -138,20 +140,34 @@ public: Y_ENSURE(typeIdPtr); LastOperator.ResultType = *typeIdPtr; } else if (key == "oprcode") { - LastOperator.Code = value; + auto procIdPtr = ProcByName.FindPtr(value); + // skip operator if proc isn't buildin, e.g. path_contain_pt + if (!procIdPtr) { + IsSupported = false; + return; + } + + Y_ENSURE(procIdPtr->size() == 1); + LastOperator.ProcId = procIdPtr->at(0); } } void OnFinish() override { - Y_ENSURE(!LastOperator.Name.empty()); - Operators[LastOperator.OperId] = LastOperator; + if (IsSupported) { + Y_ENSURE(!LastOperator.Name.empty()); + Operators[LastOperator.OperId] = LastOperator; + } + LastOperator = TOperDesc(); + IsSupported = true; } private: TOperators& Operators; const THashMap<TString, ui32>& TypeByName; + const THashMap<TString, TVector<ui32>>& ProcByName; TOperDesc LastOperator; + bool IsSupported = true; }; class TProcsParser : public TParser { @@ -370,9 +386,10 @@ private: bool IsSupported = true; }; -TOperators ParseOperators(const TString& dat, const THashMap<TString, ui32>& typeByName) { +TOperators ParseOperators(const TString& dat, const THashMap<TString, ui32>& typeByName, + const THashMap<TString, TVector<ui32>>& procByName) { TOperators ret; - TOperatorsParser parser(ret, typeByName); + TOperatorsParser parser(ret, typeByName, procByName); parser.Do(dat); return ret; } @@ -433,7 +450,6 @@ struct TCatalog { typePtr->ElementTypeId = *elemTypePtr; } - Operators = ParseOperators(opData, TypeByName); Procs = ParseProcs(procData, TypeByName); for (const auto& [k, v]: Procs) { @@ -457,6 +473,11 @@ struct TCatalog { for (const auto&[k, v] : Casts) { Y_ENSURE(CastsByDir.insert(std::make_pair(std::make_pair(v.SourceId, v.TargetId), k)).second); } + + Operators = ParseOperators(opData, TypeByName, ProcByName); + for (const auto&[k, v] : Operators) { + OperatorsByName[v.Name].push_back(k); + } } static const TCatalog& Instance() { @@ -470,9 +491,10 @@ struct TCatalog { THashMap<TString, TVector<ui32>> ProcByName; THashMap<TString, ui32> TypeByName; THashMap<std::pair<ui32, ui32>, ui32> CastsByDir; + THashMap<TString, TVector<ui32>> OperatorsByName; }; -bool ValidateArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) { +bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) { if (argTypeIds.size() != d.ArgTypes.size()) { return false; } @@ -499,8 +521,8 @@ const TProcDesc& LookupProc(ui32 procId, const TVector<ui32>& argTypeIds) { throw yexception() << "No such proc: " << procId; } - if (!ValidateArgs(*procPtr, argTypeIds)) { - throw yexception() << "Unable to find an overload for with oid " << procId << " with given argument types"; + if (!ValidateProcArgs(*procPtr, argTypeIds)) { + throw yexception() << "Unable to find an overload for proc with oid " << procId << " with given argument types"; } return *procPtr; @@ -510,20 +532,20 @@ const TProcDesc& LookupProc(const TString& name, const TVector<ui32>& argTypeIds const auto& catalog = TCatalog::Instance(); auto procIdPtr = catalog.ProcByName.FindPtr(name); if (!procIdPtr) { - throw yexception() << "No such function: " << name; + throw yexception() << "No such proc: " << name; } for (const auto& id : *procIdPtr) { const auto& d = catalog.Procs.FindPtr(id); Y_ENSURE(d); - if (!ValidateArgs(*d, argTypeIds)) { + if (!ValidateProcArgs(*d, argTypeIds)) { continue; } return *d; } - throw yexception() << "Unable to find an overload for function " << name << " with given argument types"; + throw yexception() << "Unable to find an overload for proc " << name << " with given argument types"; } const TProcDesc& LookupProc(ui32 procId) { @@ -585,4 +607,76 @@ const TCastDesc& LookupCast(ui32 castId) { return *castPtr; } +bool ValidateOperArgs(const TOperDesc& d, const TVector<ui32>& argTypeIds) { + ui32 size = d.Kind == EOperKind::Binary ? 2 : 1; + if (argTypeIds.size() != size) { + return false; + } + + bool found = true; + for (size_t i = 0; i < argTypeIds.size(); ++i) { + if (argTypeIds[i] == 0) { + continue; // NULL + } + + ui32 expectedArgType; + if (d.Kind == EOperKind::RightUnary || (d.Kind == EOperKind::Binary && i == 0)) { + expectedArgType = d.LeftType; + } else { + expectedArgType = d.RightType; + } + + if (argTypeIds[i] != expectedArgType) { + found = false; + break; + } + } + + return found; +} + +const TOperDesc& LookupOper(const TString& name, const TVector<ui32>& argTypeIds) { + const auto& catalog = TCatalog::Instance(); + auto operIdPtr = catalog.OperatorsByName.FindPtr(name); + if (!operIdPtr) { + throw yexception() << "No such operator: " << name; + } + + for (const auto& id : *operIdPtr) { + const auto& d = catalog.Operators.FindPtr(id); + Y_ENSURE(d); + if (!ValidateOperArgs(*d, argTypeIds)) { + continue; + } + + return *d; + } + + throw yexception() << "Unable to find an overload for operator " << name << " with given argument types"; +} + +const TOperDesc& LookupOper(ui32 operId, const TVector<ui32>& argTypeIds) { + const auto& catalog = TCatalog::Instance(); + auto operPtr = catalog.Operators.FindPtr(operId); + if (!operPtr) { + throw yexception() << "No such oper: " << operId; + } + + if (!ValidateOperArgs(*operPtr, argTypeIds)) { + throw yexception() << "Unable to find an overload for operator with oid " << operId << " with given argument types"; + } + + return *operPtr; +} + +const TOperDesc& LookupOper(ui32 operId) { + const auto& catalog = TCatalog::Instance(); + auto operPtr = catalog.Operators.FindPtr(operId); + if (!operPtr) { + throw yexception() << "No such oper: " << operId; + } + + return *operPtr; +} + } diff --git a/ydb/library/yql/parser/pg_catalog/catalog.h b/ydb/library/yql/parser/pg_catalog/catalog.h index ab6930767a..81289d5318 100644 --- a/ydb/library/yql/parser/pg_catalog/catalog.h +++ b/ydb/library/yql/parser/pg_catalog/catalog.h @@ -18,7 +18,7 @@ struct TOperDesc { ui32 LeftType = 0; ui32 RightType = 0; ui32 ResultType = 0; - TString Code; + ui32 ProcId = 0; }; struct TProcDesc { @@ -67,4 +67,8 @@ bool HasCast(ui32 sourceId, ui32 targetId); const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId); const TCastDesc& LookupCast(ui32 castId); +const TOperDesc& LookupOper(const TString& name, const TVector<ui32>& argTypeIds); +const TOperDesc& LookupOper(ui32 operId, const TVector<ui32>& argTypeIds); +const TOperDesc& LookupOper(ui32 operId); + } diff --git a/ydb/library/yql/providers/common/mkql/CMakeLists.txt b/ydb/library/yql/providers/common/mkql/CMakeLists.txt index 782f5594fa..6d49dca7d1 100644 --- a/ydb/library/yql/providers/common/mkql/CMakeLists.txt +++ b/ydb/library/yql/providers/common/mkql/CMakeLists.txt @@ -24,6 +24,7 @@ target_link_libraries(providers-common-mkql PUBLIC yql-core-expr_nodes common-schema-expr providers-dq-expr_nodes + yql-parser-pg_catalog ) target_sources(providers-common-mkql PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/common/mkql/parser.cpp diff --git a/ydb/library/yql/providers/common/mkql/ya.make b/ydb/library/yql/providers/common/mkql/ya.make index 8de943fad2..30bd0aeb16 100644 --- a/ydb/library/yql/providers/common/mkql/ya.make +++ b/ydb/library/yql/providers/common/mkql/ya.make @@ -26,6 +26,7 @@ PEERDIR( ydb/library/yql/core/expr_nodes ydb/library/yql/providers/common/schema/expr ydb/library/yql/providers/dq/expr_nodes + ydb/library/yql/parser/pg_catalog ) YQL_LAST_ABI_VERSION() diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index d10332a053..ae1e5ca8d1 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -13,6 +13,7 @@ #include <ydb/library/yql/minikql/mkql_runtime_version.h> #include <ydb/library/yql/minikql/mkql_type_ops.h> #include <ydb/library/yql/public/decimal/yql_decimal.h> +#include <ydb/library/yql/parser/pg_catalog/catalog.h> #include <util/stream/null.h> @@ -2245,7 +2246,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { auto name = node.Head().Content(); auto id = FromString<ui32>(node.Child(1)->Content()); std::vector<TRuntimeNode> args; - args.reserve(node.ChildrenSize() - 1); + args.reserve(node.ChildrenSize() - 2); for (ui32 i = 2; i < node.ChildrenSize(); ++i) { args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); } @@ -2254,6 +2255,20 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return ctx.ProgramBuilder.PgResolvedCall(name, id, args, returnType); }); + AddCallable("PgResolvedOp", [](const TExprNode& node, TMkqlBuildContext& ctx) { + auto operId = FromString<ui32>(node.Child(1)->Content()); + auto procId = NPg::LookupOper(operId).ProcId; + auto procName = NPg::LookupProc(procId).Name; + std::vector<TRuntimeNode> args; + args.reserve(node.ChildrenSize() - 2); + for (ui32 i = 2; i < node.ChildrenSize(); ++i) { + args.push_back(MkqlBuildExpr(*node.Child(i), ctx)); + } + + auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + return ctx.ProgramBuilder.PgResolvedCall(procName, procId, args, returnType); + }); + AddCallable("PgCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { auto input = MkqlBuildExpr(*node.Child(1), ctx); auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); diff --git a/ydb/library/yql/sql/pg/pg_sql.cpp b/ydb/library/yql/sql/pg/pg_sql.cpp index 3b5919bd68..5213712462 100644 --- a/ydb/library/yql/sql/pg/pg_sql.cpp +++ b/ydb/library/yql/sql/pg/pg_sql.cpp @@ -1455,32 +1455,40 @@ public: } if (!value->lexpr) { - auto opIt = UnaryOpTranslation.find(op); - if (opIt == UnaryOpTranslation.end()) { - AddError(TStringBuilder() << "Unsupported unary op: " << op); - return nullptr; - } - auto rhs = ParseExpr(value->rexpr, settings); if (!rhs) { return nullptr; } - return L(A(opIt->second), rhs); - } else { - auto opIt = BinaryOpTranslation.find(op); - if (opIt == BinaryOpTranslation.end()) { - AddError(TStringBuilder() << "Unsupported binary op: " << op); - return nullptr; - } + if (Settings.PgTypes) { + return L(A("PgOp"), QA(op), rhs); + } else { + auto opIt = UnaryOpTranslation.find(op); + if (opIt == UnaryOpTranslation.end()) { + AddError(TStringBuilder() << "Unsupported unary op: " << op); + return nullptr; + } + return L(A(opIt->second), rhs); + } + } else { auto lhs = ParseExpr(value->lexpr, settings); auto rhs = ParseExpr(value->rexpr, settings); if (!lhs || !rhs) { return nullptr; } - return L(A(opIt->second), lhs, rhs); + if (Settings.PgTypes) { + return L(A("PgOp"), QA(op), lhs, rhs); + } else { + auto opIt = BinaryOpTranslation.find(op); + if (opIt == BinaryOpTranslation.end()) { + AddError(TStringBuilder() << "Unsupported binary op: " << op); + return nullptr; + } + + return L(A(opIt->second), lhs, rhs); + } } } |