summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <[email protected]>2022-03-01 21:47:34 +0300
committervvvv <[email protected]>2022-03-01 21:47:34 +0300
commit8de79fac61fafe1e9e559da116135cca3f5846d1 (patch)
tree2d394ce07f4e26dc91d6a38832697815c19d804d
parentc71dfbc7a152b6f951bca559dcfa23af6d3952a3 (diff)
YQL-13710 runtime support for PgCast
ref:d8d2e1117419626450d0f8932e59d76e203b79de
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_core.cpp36
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.cpp10
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.h1
-rw-r--r--ydb/library/yql/parser/pg_catalog/catalog.cpp71
-rw-r--r--ydb/library/yql/parser/pg_catalog/catalog.h5
-rw-r--r--ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp6
-rw-r--r--ydb/library/yql/sql/pg/pg_sql.cpp32
7 files changed, 136 insertions, 25 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 47c2c076dbd..30bb3b63e45 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -9427,6 +9427,41 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
return IGraphTransformer::TStatus::Ok;
}
+ IGraphTransformer::TStatus PgCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ Y_UNUSED(output);
+ if (!EnsureArgsCount(*input, 2, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!EnsureAtom(*input->Child(0), ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto targetTypeId = NPg::LookupType(TString(input->Child(0)->Content())).TypeId;
+
+ auto type = input->Tail().GetTypeAnn();
+ ui32 inputTypeId = 0;
+ if (type->GetKind() != ETypeAnnotationKind::Null) {
+ if (type->GetKind() != ETypeAnnotationKind::Pg) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+ TStringBuilder() << "Expected PG type for cast argument, but got: " << type->GetKind()));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ inputTypeId = type->Cast<TPgExprType>()->GetId();
+ }
+
+ if (inputTypeId != 0 && inputTypeId != targetTypeId) {
+ if (NPg::LookupType(inputTypeId).Category != 'S' &&
+ NPg::LookupType(targetTypeId).Category != 'S') {
+ Y_UNUSED(NPg::LookupCast(inputTypeId, targetTypeId));
+ }
+ }
+
+ input->SetTypeAnn(ctx.Expr.MakeType<TPgExprType>(targetTypeId));
+ return IGraphTransformer::TStatus::Ok;
+ }
+
IGraphTransformer::TStatus PgTypeWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Y_UNUSED(output);
if (!EnsureArgsCount(*input, 1, ctx.Expr)) {
@@ -13156,6 +13191,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["PgAnonWindow"] = &PgAnonWindowWrapper;
Functions["PgConst"] = &PgConstWrapper;
Functions["PgType"] = &PgTypeWrapper;
+ Functions["PgCast"] = &PgCastWrapper;
Functions["AutoDemuxList"] = &AutoDemuxListWrapper;
Functions["AggrCountInit"] = &AggrCountInitWrapper;
Functions["AggrCountUpdate"] = &AggrCountUpdateWrapper;
diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp
index ba55d920722..3984f3ed4bf 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_program_builder.cpp
@@ -5042,6 +5042,16 @@ TRuntimeNode TProgramBuilder::PgResolvedCall(const std::string_view& name, ui32
return TRuntimeNode(callableBuilder.Build(), false);
}
+TRuntimeNode TProgramBuilder::PgCast(TRuntimeNode input, TType* returnType) {
+ if constexpr (RuntimeVersion < 30U) {
+ THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
+ }
+
+ TCallableBuilder callableBuilder(Env, __func__, returnType);
+ callableBuilder.Add(input);
+ return TRuntimeNode(callableBuilder.Build(), false);
+}
+
bool CanExportType(TType* type, const TTypeEnvironment& env) {
if (type->GetKind() == TType::EKind::Type) {
return false; // Type of Type
diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h
index ed9333378c4..599ece65800 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.h
+++ b/ydb/library/yql/minikql/mkql_program_builder.h
@@ -625,6 +625,7 @@ public:
TRuntimeNode PgConst(TPgType* pgType, const std::string_view& value);
TRuntimeNode PgResolvedCall(const std::string_view& name, ui32 id, const TArrayRef<const TRuntimeNode>& args, TType* returnType);
+ TRuntimeNode PgCast(TRuntimeNode input, TType* returnType);
protected:
TRuntimeNode Invoke(const std::string_view& funcName, TType* resultType, const TArrayRef<const TRuntimeNode>& args);
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.cpp b/ydb/library/yql/parser/pg_catalog/catalog.cpp
index fe1a164ca9d..79af8b74af8 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.cpp
+++ b/ydb/library/yql/parser/pg_catalog/catalog.cpp
@@ -209,11 +209,17 @@ private:
bool IsSupported = true;
};
+struct TLazyTypeInfo {
+ TString ElementType;
+ TString InFunc;
+ TString OutFunc;
+};
+
class TTypesParser : public TParser {
public:
- TTypesParser(TTypes& types, THashMap<ui32, TString>& elementTypes)
+ TTypesParser(TTypes& types, THashMap<ui32, TLazyTypeInfo>& lazyInfos)
: Types(types)
- , ElementTypes(elementTypes)
+ , LazyInfos(lazyInfos)
{}
void OnKey(const TString& key, const TString& value) override {
@@ -223,8 +229,23 @@ public:
LastType.ArrayTypeId = FromString<ui32>(value);
} else if (key == "typname") {
LastType.Name = value;
+ } else if (key == "typcategory") {
+ Y_ENSURE(value.size() == 1);
+ LastType.Category = value[0];
+ } else if (key == "typlen") {
+ if (value == "NAMEDATALEN") {
+ LastType.TypeLen = 64;
+ } else if (value == "SIZEOF_POINTER") {
+ LastType.TypeLen = 8;
+ } else {
+ LastType.TypeLen = FromString<i32>(value);
+ }
} else if (key == "typelem") {
- LastElementType = value; // resolve later
+ LastLazyTypeInfo.ElementType = value; // resolve later
+ } else if (key == "typinput") {
+ LastLazyTypeInfo.InFunc = value; // resolve later
+ } else if (key == "typoutput") {
+ LastLazyTypeInfo.OutFunc = value; // resolve later
} else if (key == "typbyval") {
if (value == "f") {
LastType.PassByValue = false;
@@ -243,19 +264,17 @@ public:
Types[LastType.ArrayTypeId] = LastType;
}
- if (LastElementType) {
- ElementTypes[LastType.TypeId] = LastElementType;
- }
+ LazyInfos[LastType.TypeId] = LastLazyTypeInfo;
LastType = TTypeDesc();
- LastElementType = TString();
+ LastLazyTypeInfo = TLazyTypeInfo();
}
private:
TTypes& Types;
- THashMap<ui32, TString>& ElementTypes;
+ THashMap<ui32, TLazyTypeInfo>& LazyInfos;
TTypeDesc LastType;
- TString LastElementType;
+ TLazyTypeInfo LastLazyTypeInfo;
};
class TCastsParser : public TParser {
@@ -365,9 +384,9 @@ TProcs ParseProcs(const TString& dat, const THashMap<TString, ui32>& typeByName)
return ret;
}
-TTypes ParseTypes(const TString& dat, THashMap<ui32, TString>& elementTypes) {
+TTypes ParseTypes(const TString& dat, THashMap<ui32, TLazyTypeInfo>& lazyInfos) {
TTypes ret;
- TTypesParser parser(ret, elementTypes);
+ TTypesParser parser(ret, lazyInfos);
parser.Do(dat);
return ret;
}
@@ -390,8 +409,8 @@ struct TCatalog {
Y_ENSURE(NResource::FindExact("pg_proc.dat", &procData));
TString castData;
Y_ENSURE(NResource::FindExact("pg_cast.dat", &castData));
- THashMap<ui32, TString> elementTypes;
- Types = ParseTypes(typeData, elementTypes);
+ THashMap<ui32, TLazyTypeInfo> lazyTypeInfos;
+ Types = ParseTypes(typeData, lazyTypeInfos);
for (const auto& [k, v] : Types) {
if (k == v.TypeId) {
Y_ENSURE(TypeByName.insert(std::make_pair(v.Name, k)).second);
@@ -402,8 +421,12 @@ struct TCatalog {
}
}
- for (const auto& [k, v]: elementTypes) {
- auto elemTypePtr = TypeByName.FindPtr(v);
+ for (const auto& [k, v]: lazyTypeInfos) {
+ if (!v.ElementType) {
+ continue;
+ }
+
+ auto elemTypePtr = TypeByName.FindPtr(v.ElementType);
Y_ENSURE(elemTypePtr);
auto typePtr = Types.FindPtr(k);
Y_ENSURE(typePtr);
@@ -417,6 +440,19 @@ struct TCatalog {
ProcByName[v.Name].push_back(k);
}
+ for (const auto&[k, v] : lazyTypeInfos) {
+ auto inFuncIdPtr = ProcByName.FindPtr(v.InFunc);
+ Y_ENSURE(inFuncIdPtr);
+ Y_ENSURE(inFuncIdPtr->size() == 1);
+ auto outFuncIdPtr = ProcByName.FindPtr(v.OutFunc);
+ Y_ENSURE(outFuncIdPtr);
+ Y_ENSURE(outFuncIdPtr->size() == 1);
+ auto typePtr = Types.FindPtr(k);
+ Y_ENSURE(typePtr);
+ typePtr->InFuncId = inFuncIdPtr->at(0);
+ typePtr->OutFuncId = outFuncIdPtr->at(0);
+ }
+
Casts = ParseCasts(castData, TypeByName, ProcByName, Procs);
for (const auto&[k, v] : Casts) {
Y_ENSURE(CastsByDir.insert(std::make_pair(std::make_pair(v.SourceId, v.TargetId), k)).second);
@@ -522,6 +558,11 @@ const TTypeDesc& LookupType(ui32 typeId) {
return *typePtr;
}
+bool HasCast(ui32 sourceId, ui32 targetId) {
+ const auto& catalog = TCatalog::Instance();
+ return catalog.CastsByDir.contains(std::make_pair(sourceId, targetId));
+}
+
const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId) {
const auto& catalog = TCatalog::Instance();
auto castByDirPtr = catalog.CastsByDir.FindPtr(std::make_pair(sourceId, targetId));
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.h b/ydb/library/yql/parser/pg_catalog/catalog.h
index 9ae80704fdb..ab6930767af 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.h
+++ b/ydb/library/yql/parser/pg_catalog/catalog.h
@@ -36,6 +36,10 @@ struct TTypeDesc {
TString Name;
ui32 ElementTypeId = 0;
bool PassByValue = false;
+ char Category = '\0';
+ ui32 InFuncId = 0;
+ ui32 OutFuncId = 0;
+ i32 TypeLen = 0;
};
enum class ECastMethod {
@@ -59,6 +63,7 @@ const TProcDesc& LookupProc(ui32 procId);
const TTypeDesc& LookupType(const TString& name);
const TTypeDesc& LookupType(ui32 typeId);
+bool HasCast(ui32 sourceId, ui32 targetId);
const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId);
const TCastDesc& LookupCast(ui32 castId);
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index 9e22d7771cd..d10332a053a 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -2254,6 +2254,12 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
return ctx.ProgramBuilder.PgResolvedCall(name, id, args, returnType);
});
+ AddCallable("PgCast", [](const TExprNode& node, TMkqlBuildContext& ctx) {
+ auto input = MkqlBuildExpr(*node.Child(1), ctx);
+ auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder);
+ return ctx.ProgramBuilder.PgCast(input, returnType);
+ });
+
AddCallable("QueueCreate", [](const TExprNode& node, TMkqlBuildContext& ctx) {
const auto initCapacity = MkqlBuildExpr(*node.Child(1), ctx);
const auto initSize = MkqlBuildExpr(*node.Child(2), ctx);
diff --git a/ydb/library/yql/sql/pg/pg_sql.cpp b/ydb/library/yql/sql/pg/pg_sql.cpp
index d8ec0f33bb3..3b5919bd687 100644
--- a/ydb/library/yql/sql/pg/pg_sql.cpp
+++ b/ydb/library/yql/sql/pg/pg_sql.cpp
@@ -802,7 +802,7 @@ public:
return ParseColumnRef(CAST_NODE(ColumnRef, node));
}
case T_TypeCast: {
- return ParseTypeCast(CAST_NODE(TypeCast, node));
+ return ParseTypeCast(CAST_NODE(TypeCast, node), settings);
}
case T_BoolExpr: {
return ParseBoolExpr(CAST_NODE(BoolExpr, node), settings);
@@ -976,7 +976,7 @@ public:
return VL(args.data(), args.size());
}
- TAstNode* ParseTypeCast(const TypeCast* value) {
+ TAstNode* ParseTypeCast(const TypeCast* value, const TExprSettings& settings) {
if (!value->arg) {
AddError("Expected arg");
return nullptr;
@@ -989,19 +989,21 @@ public:
auto arg = value->arg;
auto typeName = value->typeName;
- if (NodeTag(arg) == T_A_Const &&
- (NodeTag(CAST_NODE(A_Const, arg)->val) == T_String ||
- NodeTag(CAST_NODE(A_Const, arg)->val) == T_Null) &&
- typeName->typeOid == 0 &&
+ auto supportedTypeName = typeName->typeOid == 0 &&
!typeName->setof &&
!typeName->pct_type &&
ListLength(typeName->typmods) == 0 &&
ListLength(typeName->arrayBounds) == 0 &&
(ListLength(typeName->names) == 2 &&
- NodeTag(ListNodeNth(typeName->names, 0)) == T_String &&
- !StrCompare(StrVal(ListNodeNth(typeName->names, 0)), "pg_catalog") || ListLength(typeName->names) == 1) &&
+ NodeTag(ListNodeNth(typeName->names, 0)) == T_String &&
+ !StrCompare(StrVal(ListNodeNth(typeName->names, 0)), "pg_catalog") || ListLength(typeName->names) == 1) &&
NodeTag(ListNodeNth(typeName->names, ListLength(typeName->names) - 1)) == T_String &&
- typeName->typemod == -1) {
+ typeName->typemod == -1;
+
+ if (NodeTag(arg) == T_A_Const &&
+ (NodeTag(CAST_NODE(A_Const, arg)->val) == T_String ||
+ NodeTag(CAST_NODE(A_Const, arg)->val) == T_Null) &&
+ supportedTypeName) {
TStringBuf targetType = StrVal(ListNodeNth(typeName->names, ListLength(typeName->names) - 1));
if (NodeTag(CAST_NODE(A_Const, arg)->val) == T_String && targetType == "bool") {
auto str = StrVal(CAST_NODE(A_Const, arg)->val);
@@ -1019,7 +1021,7 @@ public:
}
}
- if (NodeTag(CAST_NODE(A_Const, arg)->val) == T_Null) {
+ if (!Settings.PgTypes && NodeTag(CAST_NODE(A_Const, arg)->val) == T_Null) {
TString yqlType;
if (targetType == "bool") {
yqlType = "Bool";
@@ -1037,6 +1039,16 @@ public:
}
}
+ if (Settings.PgTypes && supportedTypeName) {
+ TStringBuf targetType = StrVal(ListNodeNth(typeName->names, ListLength(typeName->names) - 1));
+ auto input = ParseExpr(arg, settings);
+ if (!input) {
+ return nullptr;
+ }
+
+ return L(A("PgCast"), QA(TString(targetType)), input);
+ }
+
AddError("Unsupported form of type cast");
return nullptr;
}