aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@yandex-team.ru>2022-03-02 01:55:04 +0300
committervvvv <vvvv@yandex-team.ru>2022-03-02 01:55:04 +0300
commit26286f616cee657612a9d820be6da2cdbd4de0ef (patch)
tree07cb52f2682e92fd23a00bca9553c57481cf76a8
parent2ba098f9946ccfd045204904c5a6c51d240462e0 (diff)
downloadydb-26286f616cee657612a9d820be6da2cdbd4de0ef.tar.gz
YQL-13710 support for PgOper
ref:fcfb05393bd7e73bbb1087d25bad8d863c14c6fd
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_core.cpp63
-rw-r--r--ydb/library/yql/parser/pg_catalog/catalog.cpp120
-rw-r--r--ydb/library/yql/parser/pg_catalog/catalog.h6
-rw-r--r--ydb/library/yql/providers/common/mkql/CMakeLists.txt1
-rw-r--r--ydb/library/yql/providers/common/mkql/ya.make1
-rw-r--r--ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp17
-rw-r--r--ydb/library/yql/sql/pg/pg_sql.cpp36
7 files changed, 214 insertions, 30 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 30bb3b63e4..356414974e 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -8774,7 +8774,6 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
};
IGraphTransformer::TStatus PgCallWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) {
- Y_UNUSED(output);
bool isResolved = input->Content() == "PgResolvedCall";
if (!EnsureMinArgsCount(*input, isResolved ? 2 : 1, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
@@ -8896,6 +8895,66 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
}
}
+ IGraphTransformer::TStatus PgOpWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) {
+ bool isResolved = input->IsCallable("PgResolvedOp");
+ if (!EnsureMinArgsCount(*input, isResolved ? 3 : 2, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!EnsureMaxArgsCount(*input, isResolved ? 4 : 3, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!EnsureAtom(input->Head(), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto name = input->Head().Content();
+ if (isResolved) {
+ if (!EnsureAtom(*input->Child(1), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+ }
+
+ TVector<ui32> argTypes;
+ for (ui32 i = isResolved ? 2 : 1; i < input->ChildrenSize(); ++i) {
+ auto type = input->Child(i)->GetTypeAnn();
+ if (type->GetKind() == ETypeAnnotationKind::Null) {
+ argTypes.push_back(0);
+ continue;
+ }
+
+ if (type->GetKind() != ETypeAnnotationKind::Pg) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Expected PG type for argument " << (i - (isResolved ? 2 : 1) + 1) << ", but got: " << type->GetKind() << " for function: " << name));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ argTypes.push_back(type->Cast<TPgExprType>()->GetId());
+ }
+
+ if (isResolved) {
+ auto operId = FromString<ui32>(input->Child(1)->Content());
+ const auto& oper = NPg::LookupOper(operId, argTypes);
+ if (oper.Name != name) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Mismatch of resolved operator name, expected: " << name << ", but got:" << oper.Name));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto result = ctx.Expr.MakeType<TPgExprType>(oper.ResultType);
+ input->SetTypeAnn(result);
+ return IGraphTransformer::TStatus::Ok;
+ } else {
+ const auto& oper = NPg::LookupOper(TString(name), argTypes);
+ auto children = input->ChildrenList();
+ auto idNode = ctx.Expr.NewAtom(input->Pos(), ToString(oper.OperId));
+ children.insert(children.begin() + 1, idNode);
+ output = ctx.Expr.NewCallable(input->Pos(), "PgResolvedOp", std::move(children));
+ return IGraphTransformer::TStatus::Repeat;
+ }
+ }
+
IGraphTransformer::TStatus PgWindowCallWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Y_UNUSED(output);
if (!EnsureMinArgsCount(*input, 2, ctx.Expr)) {
@@ -13320,6 +13379,8 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
ExtFunctions["PgCall"] = &PgCallWrapper;
ExtFunctions["PgResolvedCall"] = &PgCallWrapper;
+ ExtFunctions["PgOp"] = &PgOpWrapper;
+ ExtFunctions["PgResolvedOp"] = &PgOpWrapper;
ExtFunctions["PgSelect"] = &PgSelectWrapper;
ExtFunctions["PgSetItem"] = &PgSetItemWrapper;
ExtFunctions["TablePath"] = &TablePathWrapper;
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.cpp b/ydb/library/yql/parser/pg_catalog/catalog.cpp
index 79af8b74af..0a40b02065 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.cpp
+++ b/ydb/library/yql/parser/pg_catalog/catalog.cpp
@@ -105,9 +105,11 @@ public:
class TOperatorsParser : public TParser {
public:
- TOperatorsParser(TOperators& operators, const THashMap<TString, ui32>& typeByName)
+ TOperatorsParser(TOperators& operators, const THashMap<TString, ui32>& typeByName,
+ const THashMap<TString, TVector<ui32>>& procByName)
: Operators(operators)
, TypeByName(typeByName)
+ , ProcByName(procByName)
{}
void OnKey(const TString& key, const TString& value) override {
@@ -138,20 +140,34 @@ public:
Y_ENSURE(typeIdPtr);
LastOperator.ResultType = *typeIdPtr;
} else if (key == "oprcode") {
- LastOperator.Code = value;
+ auto procIdPtr = ProcByName.FindPtr(value);
+ // skip operator if proc isn't buildin, e.g. path_contain_pt
+ if (!procIdPtr) {
+ IsSupported = false;
+ return;
+ }
+
+ Y_ENSURE(procIdPtr->size() == 1);
+ LastOperator.ProcId = procIdPtr->at(0);
}
}
void OnFinish() override {
- Y_ENSURE(!LastOperator.Name.empty());
- Operators[LastOperator.OperId] = LastOperator;
+ if (IsSupported) {
+ Y_ENSURE(!LastOperator.Name.empty());
+ Operators[LastOperator.OperId] = LastOperator;
+ }
+
LastOperator = TOperDesc();
+ IsSupported = true;
}
private:
TOperators& Operators;
const THashMap<TString, ui32>& TypeByName;
+ const THashMap<TString, TVector<ui32>>& ProcByName;
TOperDesc LastOperator;
+ bool IsSupported = true;
};
class TProcsParser : public TParser {
@@ -370,9 +386,10 @@ private:
bool IsSupported = true;
};
-TOperators ParseOperators(const TString& dat, const THashMap<TString, ui32>& typeByName) {
+TOperators ParseOperators(const TString& dat, const THashMap<TString, ui32>& typeByName,
+ const THashMap<TString, TVector<ui32>>& procByName) {
TOperators ret;
- TOperatorsParser parser(ret, typeByName);
+ TOperatorsParser parser(ret, typeByName, procByName);
parser.Do(dat);
return ret;
}
@@ -433,7 +450,6 @@ struct TCatalog {
typePtr->ElementTypeId = *elemTypePtr;
}
- Operators = ParseOperators(opData, TypeByName);
Procs = ParseProcs(procData, TypeByName);
for (const auto& [k, v]: Procs) {
@@ -457,6 +473,11 @@ struct TCatalog {
for (const auto&[k, v] : Casts) {
Y_ENSURE(CastsByDir.insert(std::make_pair(std::make_pair(v.SourceId, v.TargetId), k)).second);
}
+
+ Operators = ParseOperators(opData, TypeByName, ProcByName);
+ for (const auto&[k, v] : Operators) {
+ OperatorsByName[v.Name].push_back(k);
+ }
}
static const TCatalog& Instance() {
@@ -470,9 +491,10 @@ struct TCatalog {
THashMap<TString, TVector<ui32>> ProcByName;
THashMap<TString, ui32> TypeByName;
THashMap<std::pair<ui32, ui32>, ui32> CastsByDir;
+ THashMap<TString, TVector<ui32>> OperatorsByName;
};
-bool ValidateArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
+bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
if (argTypeIds.size() != d.ArgTypes.size()) {
return false;
}
@@ -499,8 +521,8 @@ const TProcDesc& LookupProc(ui32 procId, const TVector<ui32>& argTypeIds) {
throw yexception() << "No such proc: " << procId;
}
- if (!ValidateArgs(*procPtr, argTypeIds)) {
- throw yexception() << "Unable to find an overload for with oid " << procId << " with given argument types";
+ if (!ValidateProcArgs(*procPtr, argTypeIds)) {
+ throw yexception() << "Unable to find an overload for proc with oid " << procId << " with given argument types";
}
return *procPtr;
@@ -510,20 +532,20 @@ const TProcDesc& LookupProc(const TString& name, const TVector<ui32>& argTypeIds
const auto& catalog = TCatalog::Instance();
auto procIdPtr = catalog.ProcByName.FindPtr(name);
if (!procIdPtr) {
- throw yexception() << "No such function: " << name;
+ throw yexception() << "No such proc: " << name;
}
for (const auto& id : *procIdPtr) {
const auto& d = catalog.Procs.FindPtr(id);
Y_ENSURE(d);
- if (!ValidateArgs(*d, argTypeIds)) {
+ if (!ValidateProcArgs(*d, argTypeIds)) {
continue;
}
return *d;
}
- throw yexception() << "Unable to find an overload for function " << name << " with given argument types";
+ throw yexception() << "Unable to find an overload for proc " << name << " with given argument types";
}
const TProcDesc& LookupProc(ui32 procId) {
@@ -585,4 +607,76 @@ const TCastDesc& LookupCast(ui32 castId) {
return *castPtr;
}
+bool ValidateOperArgs(const TOperDesc& d, const TVector<ui32>& argTypeIds) {
+ ui32 size = d.Kind == EOperKind::Binary ? 2 : 1;
+ if (argTypeIds.size() != size) {
+ return false;
+ }
+
+ bool found = true;
+ for (size_t i = 0; i < argTypeIds.size(); ++i) {
+ if (argTypeIds[i] == 0) {
+ continue; // NULL
+ }
+
+ ui32 expectedArgType;
+ if (d.Kind == EOperKind::RightUnary || (d.Kind == EOperKind::Binary && i == 0)) {
+ expectedArgType = d.LeftType;
+ } else {
+ expectedArgType = d.RightType;
+ }
+
+ if (argTypeIds[i] != expectedArgType) {
+ found = false;
+ break;
+ }
+ }
+
+ return found;
+}
+
+const TOperDesc& LookupOper(const TString& name, const TVector<ui32>& argTypeIds) {
+ const auto& catalog = TCatalog::Instance();
+ auto operIdPtr = catalog.OperatorsByName.FindPtr(name);
+ if (!operIdPtr) {
+ throw yexception() << "No such operator: " << name;
+ }
+
+ for (const auto& id : *operIdPtr) {
+ const auto& d = catalog.Operators.FindPtr(id);
+ Y_ENSURE(d);
+ if (!ValidateOperArgs(*d, argTypeIds)) {
+ continue;
+ }
+
+ return *d;
+ }
+
+ throw yexception() << "Unable to find an overload for operator " << name << " with given argument types";
+}
+
+const TOperDesc& LookupOper(ui32 operId, const TVector<ui32>& argTypeIds) {
+ const auto& catalog = TCatalog::Instance();
+ auto operPtr = catalog.Operators.FindPtr(operId);
+ if (!operPtr) {
+ throw yexception() << "No such oper: " << operId;
+ }
+
+ if (!ValidateOperArgs(*operPtr, argTypeIds)) {
+ throw yexception() << "Unable to find an overload for operator with oid " << operId << " with given argument types";
+ }
+
+ return *operPtr;
+}
+
+const TOperDesc& LookupOper(ui32 operId) {
+ const auto& catalog = TCatalog::Instance();
+ auto operPtr = catalog.Operators.FindPtr(operId);
+ if (!operPtr) {
+ throw yexception() << "No such oper: " << operId;
+ }
+
+ return *operPtr;
+}
+
}
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.h b/ydb/library/yql/parser/pg_catalog/catalog.h
index ab6930767a..81289d5318 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.h
+++ b/ydb/library/yql/parser/pg_catalog/catalog.h
@@ -18,7 +18,7 @@ struct TOperDesc {
ui32 LeftType = 0;
ui32 RightType = 0;
ui32 ResultType = 0;
- TString Code;
+ ui32 ProcId = 0;
};
struct TProcDesc {
@@ -67,4 +67,8 @@ bool HasCast(ui32 sourceId, ui32 targetId);
const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId);
const TCastDesc& LookupCast(ui32 castId);
+const TOperDesc& LookupOper(const TString& name, const TVector<ui32>& argTypeIds);
+const TOperDesc& LookupOper(ui32 operId, const TVector<ui32>& argTypeIds);
+const TOperDesc& LookupOper(ui32 operId);
+
}
diff --git a/ydb/library/yql/providers/common/mkql/CMakeLists.txt b/ydb/library/yql/providers/common/mkql/CMakeLists.txt
index 782f5594fa..6d49dca7d1 100644
--- a/ydb/library/yql/providers/common/mkql/CMakeLists.txt
+++ b/ydb/library/yql/providers/common/mkql/CMakeLists.txt
@@ -24,6 +24,7 @@ target_link_libraries(providers-common-mkql PUBLIC
yql-core-expr_nodes
common-schema-expr
providers-dq-expr_nodes
+ yql-parser-pg_catalog
)
target_sources(providers-common-mkql PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/common/mkql/parser.cpp
diff --git a/ydb/library/yql/providers/common/mkql/ya.make b/ydb/library/yql/providers/common/mkql/ya.make
index 8de943fad2..30bd0aeb16 100644
--- a/ydb/library/yql/providers/common/mkql/ya.make
+++ b/ydb/library/yql/providers/common/mkql/ya.make
@@ -26,6 +26,7 @@ PEERDIR(
ydb/library/yql/core/expr_nodes
ydb/library/yql/providers/common/schema/expr
ydb/library/yql/providers/dq/expr_nodes
+ ydb/library/yql/parser/pg_catalog
)
YQL_LAST_ABI_VERSION()
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index d10332a053..ae1e5ca8d1 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -13,6 +13,7 @@
#include <ydb/library/yql/minikql/mkql_runtime_version.h>
#include <ydb/library/yql/minikql/mkql_type_ops.h>
#include <ydb/library/yql/public/decimal/yql_decimal.h>
+#include <ydb/library/yql/parser/pg_catalog/catalog.h>
#include <util/stream/null.h>
@@ -2245,7 +2246,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
auto name = node.Head().Content();
auto id = FromString<ui32>(node.Child(1)->Content());
std::vector<TRuntimeNode> args;
- args.reserve(node.ChildrenSize() - 1);
+ args.reserve(node.ChildrenSize() - 2);
for (ui32 i = 2; i < node.ChildrenSize(); ++i) {
args.push_back(MkqlBuildExpr(*node.Child(i), ctx));
}
@@ -2254,6 +2255,20 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
return ctx.ProgramBuilder.PgResolvedCall(name, id, args, returnType);
});
+ AddCallable("PgResolvedOp", [](const TExprNode& node, TMkqlBuildContext& ctx) {
+ auto operId = FromString<ui32>(node.Child(1)->Content());
+ auto procId = NPg::LookupOper(operId).ProcId;
+ auto procName = NPg::LookupProc(procId).Name;
+ std::vector<TRuntimeNode> args;
+ args.reserve(node.ChildrenSize() - 2);
+ for (ui32 i = 2; i < node.ChildrenSize(); ++i) {
+ args.push_back(MkqlBuildExpr(*node.Child(i), ctx));
+ }
+
+ auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder);
+ return ctx.ProgramBuilder.PgResolvedCall(procName, procId, args, returnType);
+ });
+
AddCallable("PgCast", [](const TExprNode& node, TMkqlBuildContext& ctx) {
auto input = MkqlBuildExpr(*node.Child(1), ctx);
auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder);
diff --git a/ydb/library/yql/sql/pg/pg_sql.cpp b/ydb/library/yql/sql/pg/pg_sql.cpp
index 3b5919bd68..5213712462 100644
--- a/ydb/library/yql/sql/pg/pg_sql.cpp
+++ b/ydb/library/yql/sql/pg/pg_sql.cpp
@@ -1455,32 +1455,40 @@ public:
}
if (!value->lexpr) {
- auto opIt = UnaryOpTranslation.find(op);
- if (opIt == UnaryOpTranslation.end()) {
- AddError(TStringBuilder() << "Unsupported unary op: " << op);
- return nullptr;
- }
-
auto rhs = ParseExpr(value->rexpr, settings);
if (!rhs) {
return nullptr;
}
- return L(A(opIt->second), rhs);
- } else {
- auto opIt = BinaryOpTranslation.find(op);
- if (opIt == BinaryOpTranslation.end()) {
- AddError(TStringBuilder() << "Unsupported binary op: " << op);
- return nullptr;
- }
+ if (Settings.PgTypes) {
+ return L(A("PgOp"), QA(op), rhs);
+ } else {
+ auto opIt = UnaryOpTranslation.find(op);
+ if (opIt == UnaryOpTranslation.end()) {
+ AddError(TStringBuilder() << "Unsupported unary op: " << op);
+ return nullptr;
+ }
+ return L(A(opIt->second), rhs);
+ }
+ } else {
auto lhs = ParseExpr(value->lexpr, settings);
auto rhs = ParseExpr(value->rexpr, settings);
if (!lhs || !rhs) {
return nullptr;
}
- return L(A(opIt->second), lhs, rhs);
+ if (Settings.PgTypes) {
+ return L(A("PgOp"), QA(op), lhs, rhs);
+ } else {
+ auto opIt = BinaryOpTranslation.find(op);
+ if (opIt == BinaryOpTranslation.end()) {
+ AddError(TStringBuilder() << "Unsupported binary op: " << op);
+ return nullptr;
+ }
+
+ return L(A(opIt->second), lhs, rhs);
+ }
}
}