diff options
author | vvvv <vvvv@yandex-team.ru> | 2022-03-15 17:33:07 +0300 |
---|---|---|
committer | vvvv <vvvv@yandex-team.ru> | 2022-03-15 17:33:07 +0300 |
commit | be1e84d0d989bf7eeeade81d38de61370f9db02e (patch) | |
tree | b67eac1b2ed2d1ad3286cd625b595d233df4b383 | |
parent | d594eb68063ba68b202dda9ab4a3f91bda0383f7 (diff) | |
download | ydb-be1e84d0d989bf7eeeade81d38de61370f9db02e.tar.gz |
YQL-13710 FromPg, supported of pgbool predicate in where/having
ref:2f6a9a935e2fc2a33c9a28878d842c5197e7db00
-rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_simple1.cpp | 22 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_core.cpp | 72 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_type_annotation.cpp | 11 | ||||
-rw-r--r-- | ydb/library/yql/minikql/mkql_program_builder.cpp | 20 | ||||
-rw-r--r-- | ydb/library/yql/minikql/mkql_program_builder.h | 3 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_wrapper/comp_factory.cpp | 95 | ||||
-rw-r--r-- | ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp | 6 |
7 files changed, 204 insertions, 25 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_simple1.cpp b/ydb/library/yql/core/common_opt/yql_co_simple1.cpp index 11899138d86..1fc8746898f 100644 --- a/ydb/library/yql/core/common_opt/yql_co_simple1.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_simple1.cpp @@ -6641,8 +6641,15 @@ void RegisterCoSimpleCallables1(TCallableOptimizerMap& map) { .Add(0, list) .Lambda(1) .Param("row") - .Apply(filter->Tail().TailPtr()) - .With(0, "row") + .Callable("Coalesce") + .Callable(0, "FromPg") + .Apply(0, filter->Tail().TailPtr()) + .With(0, "row") + .Seal() + .Seal() + .Callable(1, "Bool") + .Atom(0, "0") + .Seal() .Seal() .Seal() .Seal() @@ -6785,8 +6792,15 @@ void RegisterCoSimpleCallables1(TCallableOptimizerMap& map) { .Add(0, list) .Lambda(1) .Param("row") - .Apply(havingLambda) - .With(0, "row") + .Callable("Coalesce") + .Callable(0, "FromPg") + .Apply(0, havingLambda) + .With(0, "row") + .Seal() + .Seal() + .Callable(1, "Bool") + .Atom(0, "0") + .Seal() .Seal() .Seal() .Seal() diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 1e020f2d03f..822785ce0ae 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -4207,7 +4207,7 @@ namespace NTypeAnnImpl { auto leftItemType = leftType; if (leftType->GetKind() == ETypeAnnotationKind::Optional) { leftItemType = leftType->Cast<TOptionalExprType>()->GetItemType(); - } else { + } else if (leftType->GetKind() != ETypeAnnotationKind::Pg) { output = input->HeadPtr(); return IGraphTransformer::TStatus::Repeat; } @@ -4765,8 +4765,8 @@ template <NKikimr::NUdf::EDataSlot DataSlot> return IGraphTransformer::TStatus::Repeat; } - if (type->GetKind() != ETypeAnnotationKind::Optional) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Head().Pos()), TStringBuilder() << "Expected optional or Null type, but got: " + if (type->GetKind() != ETypeAnnotationKind::Optional && type->GetKind() != ETypeAnnotationKind::Pg) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Head().Pos()), TStringBuilder() << "Expected optional, pg type or Null type, but got: " << *type)); return IGraphTransformer::TStatus::Error; } @@ -8937,6 +8937,49 @@ template <NKikimr::NUdf::EDataSlot DataSlot> } } + IGraphTransformer::TStatus FromPgWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + if (!EnsureArgsCount(*input, 1, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureComputable(input->Head(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (input->Head().GetTypeAnn()->GetKind() != ETypeAnnotationKind::Pg) { + output = input->HeadPtr(); + return IGraphTransformer::TStatus::Repeat; + } + + auto name = input->Head().GetTypeAnn()->Cast<TPgExprType>()->GetName(); + const TDataExprType* dataType; + if (name == "bool") { + dataType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Bool); + } else if (name == "int2") { + dataType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Int16); + } else if (name == "int4") { + dataType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Int32); + } else if (name == "int8") { + dataType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Int64); + } else if (name == "float4") { + dataType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Float); + } else if (name == "float8") { + dataType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Double); + } else if (name == "text" || name == "varchar" || name == "cstring") { + dataType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Utf8); + } else if (name == "bytea") { + dataType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::String); + } else { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Unsupported type: " << name)); + return IGraphTransformer::TStatus::Error; + } + + auto result = ctx.Expr.MakeType<TOptionalExprType>(dataType); + input->SetTypeAnn(result); + return IGraphTransformer::TStatus::Ok; + } + IGraphTransformer::TStatus PgOpWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { bool isResolved = input->IsCallable("PgResolvedOp"); if (!EnsureMinArgsCount(*input, isResolved ? 3 : 2, ctx.Expr)) { @@ -10307,16 +10350,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> return IGraphTransformer::TStatus::Error; } - auto predicate = ctx.Expr.Builder(data.Pos()) - .Callable("Coalesce") - .Add(0, newRoot) - .Callable(1, "Bool") - .Atom(0, "0") - .Seal() - .Seal() - .Build(); - - auto newLambda = ctx.Expr.NewLambda(data.Pos(), std::move(arguments), std::move(predicate)); + auto newLambda = ctx.Expr.NewLambda(data.Pos(), std::move(arguments), std::move(newRoot)); auto newChildren = data.ChildrenList(); newChildren[0] = typeNode; @@ -10326,7 +10360,16 @@ template <NKikimr::NUdf::EDataSlot DataSlot> output = ctx.Expr.ChangeChild(*input, 0, std::move(newSettings)); return IGraphTransformer::TStatus::Repeat; } else { - if (!EnsureSpecificDataType(data, EDataSlot::Bool, ctx.Expr)) { + if (data.GetTypeAnn() && data.GetTypeAnn()->GetKind() == ETypeAnnotationKind::Null) { + // nothing to do + } else if (data.GetTypeAnn() && data.GetTypeAnn()->GetKind() == ETypeAnnotationKind::Pg) { + auto name = data.GetTypeAnn()->Cast<TPgExprType>()->GetName(); + if (name != "bool") { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(data.Pos()), TStringBuilder() << + "Expected bool type, but got: " << name)); + return IGraphTransformer::TStatus::Error; + } + } else if (!EnsureSpecificDataType(data, EDataSlot::Bool, ctx.Expr, true)) { return IGraphTransformer::TStatus::Error; } } @@ -13431,6 +13474,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["RoundDown"] = &RoundWrapper; Functions["NextValue"] = &NextValueWrapper; + Functions["FromPg"] = &FromPgWrapper; ExtFunctions["PgCall"] = &PgCallWrapper; ExtFunctions["PgResolvedCall"] = &PgCallWrapper; ExtFunctions["PgOp"] = &PgOpWrapper; diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp index 7cccc910e99..9765f53de72 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.cpp +++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp @@ -102,6 +102,13 @@ IGraphTransformer::TStatus TryConvertToImpl(TExprContext& ctx, TExprNode::TPtr& } } + if (expectedType.GetKind() == ETypeAnnotationKind::Pg) { + if (IsNull(sourceType)) { + node = ctx.NewCallable(node->Pos(), "Nothing", { ExpandType(node->Pos(), expectedType, ctx) }); + return IGraphTransformer::TStatus::Repeat; + } + } + if (expectedType.GetKind() == ETypeAnnotationKind::Optional) { auto nextType = expectedType.Cast<TOptionalExprType>()->GetItemType(); auto originalNode = node; @@ -3587,7 +3594,7 @@ IGraphTransformer::TStatus SilentInferCommonType(TExprNode::TPtr& node1, const T } if (IsNull(type1)) { - if (type2.GetKind() == ETypeAnnotationKind::Optional) { + if (type2.GetKind() == ETypeAnnotationKind::Optional || type2.GetKind() == ETypeAnnotationKind::Pg) { node1 = ctx.NewCallable(node1->Pos(), "Nothing", { ExpandType(node2->Pos(), type2, ctx) }); commonType = &type2; return IGraphTransformer::TStatus::Repeat; @@ -3601,7 +3608,7 @@ IGraphTransformer::TStatus SilentInferCommonType(TExprNode::TPtr& node1, const T } if (IsNull(type2)) { - if (type1.GetKind() == ETypeAnnotationKind::Optional) { + if (type1.GetKind() == ETypeAnnotationKind::Optional || type1.GetKind() == ETypeAnnotationKind::Pg) { node2 = ctx.NewCallable(node2->Pos(), "Nothing", { ExpandType(node1->Pos(), type1, ctx) }); commonType = &type1; return IGraphTransformer::TStatus::Repeat; diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 3984f3ed4bf..7912a1f634d 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -1945,10 +1945,14 @@ TRuntimeNode TProgramBuilder::NewEmptyListOfVoid() { return TRuntimeNode(Env.GetListOfVoid(), true); } -TRuntimeNode TProgramBuilder::NewEmptyOptional(TType* optionalType) { - MKQL_ENSURE(optionalType->IsOptional(), "Expected optional type"); +TRuntimeNode TProgramBuilder::NewEmptyOptional(TType* optionalOrPgType) { + MKQL_ENSURE(optionalOrPgType->IsOptional() || optionalOrPgType->IsPg(), "Expected optional or pg type"); - return TRuntimeNode(TOptionalLiteral::Create(static_cast<TOptionalType*>(optionalType), Env), true); + if (optionalOrPgType->IsOptional()) { + return TRuntimeNode(TOptionalLiteral::Create(static_cast<TOptionalType*>(optionalOrPgType), Env), true); + } + + return PgCast(NewNull(), optionalOrPgType); } TRuntimeNode TProgramBuilder::NewEmptyOptionalDataLiteral(NUdf::TDataTypeId schemeType) { @@ -5052,6 +5056,16 @@ TRuntimeNode TProgramBuilder::PgCast(TRuntimeNode input, TType* returnType) { return TRuntimeNode(callableBuilder.Build(), false); } +TRuntimeNode TProgramBuilder::FromPg(TRuntimeNode input, TType* returnType) { + if constexpr (RuntimeVersion < 30U) { + THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; + } + + TCallableBuilder callableBuilder(Env, __func__, returnType); + callableBuilder.Add(input); + return TRuntimeNode(callableBuilder.Build(), false); +} + bool CanExportType(TType* type, const TTypeEnvironment& env) { if (type->GetKind() == TType::EKind::Type) { return false; // Type of Type diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index 599ece65800..d5337ac1001 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -156,7 +156,7 @@ public: TRuntimeNode NewDecimalLiteral(NYql::NDecimal::TInt128 data, ui8 precision, ui8 scale) const; TType* NewOptionalType(TType* itemType); - TRuntimeNode NewEmptyOptional(TType* optionalType); + TRuntimeNode NewEmptyOptional(TType* optionalOrPgType); TRuntimeNode NewEmptyOptionalDataLiteral(NUdf::TDataTypeId schemeType); TRuntimeNode NewOptional(TRuntimeNode data); TRuntimeNode NewOptional(TType* optionalType, TRuntimeNode data); @@ -626,6 +626,7 @@ public: TRuntimeNode PgConst(TPgType* pgType, const std::string_view& value); TRuntimeNode PgResolvedCall(const std::string_view& name, ui32 id, const TArrayRef<const TRuntimeNode>& args, TType* returnType); TRuntimeNode PgCast(TRuntimeNode input, TType* returnType); + TRuntimeNode FromPg(TRuntimeNode input, TType* returnType); protected: TRuntimeNode Invoke(const std::string_view& funcName, TType* resultType, const TArrayRef<const TRuntimeNode>& args); diff --git a/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp b/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp index bae3680e349..4acb4ee8998 100644 --- a/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp +++ b/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp @@ -5,6 +5,7 @@ #include <ydb/library/yql/minikql/mkql_node_cast.h> #include <ydb/library/yql/minikql/mkql_alloc.h> #include <ydb/library/yql/minikql/mkql_node_builder.h> +#include <ydb/library/yql/minikql/mkql_string_util.h> #include <ydb/library/yql/providers/common/codec/yql_pg_codec.h> #include <ydb/library/yql/parser/pg_catalog/catalog.h> #include <ydb/library/yql/core/yql_pg_utils.h> @@ -417,7 +418,7 @@ public: , SourceId(sourceId) , TargetId(targetId) , Arg(arg) - , SourceTypeDesc(NPg::LookupType(SourceId)) + , SourceTypeDesc(SourceId ? NPg::LookupType(SourceId) : NPg::TTypeDesc()) , TargetTypeDesc(NPg::LookupType(targetId)) { TypeIOParam = MakeTypeIOParam(TargetTypeDesc); @@ -612,6 +613,69 @@ private: ui32 TypeIOParam = 0; }; +template <NUdf::EDataSlot Slot, bool IsCString> +class TFromPg : public TMutableComputationNode<TFromPg<Slot, IsCString>> { + typedef TMutableComputationNode<TFromPg<Slot, IsCString>> TBaseComputation; +public: + TFromPg(TComputationMutables& mutables, IComputationNode* arg) + : TBaseComputation(mutables) + , Arg(arg) + { + } + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { + auto value = Arg->GetValue(compCtx); + if (!value) { + return value.Release(); + } + + switch (Slot) + { + case NUdf::EDataSlot::Bool: + return NUdf::TUnboxedValuePod(DatumGetBool(ScalarDatumFromPod(value))); + case NUdf::EDataSlot::Int16: + return NUdf::TUnboxedValuePod(DatumGetInt16(ScalarDatumFromPod(value))); + case NUdf::EDataSlot::Int32: + return NUdf::TUnboxedValuePod(DatumGetInt32(ScalarDatumFromPod(value))); + case NUdf::EDataSlot::Int64: + return NUdf::TUnboxedValuePod(DatumGetInt64(ScalarDatumFromPod(value))); + case NUdf::EDataSlot::Float: + return NUdf::TUnboxedValuePod(DatumGetFloat4(ScalarDatumFromPod(value))); + case NUdf::EDataSlot::Double: + return NUdf::TUnboxedValuePod(DatumGetFloat8(ScalarDatumFromPod(value))); + case NUdf::EDataSlot::String: + case NUdf::EDataSlot::Utf8: + if (value.IsEmbedded() || value.IsString()) { + return value.Release(); + } + + if (IsCString) { + auto x = (const char*)value.AsBoxed().Get() + PallocHdrSize; + return MakeString(TStringBuf(x)); + } else { + auto x = (const text*)((const char*)value.AsBoxed().Get() + PallocHdrSize); + ui32 len = VARSIZE_ANY_EXHDR(x); + TString s; + if (len) { + s = TString::Uninitialized(len); + text_to_cstring_buffer(x, s.begin(), len + 1); + } + + return MakeString(s); + } + default: + Y_UNREACHABLE(); + } + } + +private: + void RegisterDependencies() const final { + this->DependsOn(Arg); + } + + IComputationNode* const Arg; +}; + TComputationNodeFactory GetPgFactory() { return [] (TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* { pg_thread_init(); @@ -655,6 +719,35 @@ TComputationNodeFactory GetPgFactory() { return new TPgCast(ctx.Mutables, sourceId, targetId, arg); } + if (name == "FromPg") { + auto arg = LocateNode(ctx.NodeLocator, callable, 0); + auto inputType = callable.GetInput(0).GetStaticType(); + ui32 sourceId = AS_TYPE(TPgType, inputType)->GetTypeId(); + switch (sourceId) { + case BOOLOID: + return new TFromPg<NUdf::EDataSlot::Bool, false>(ctx.Mutables, arg); + case INT2OID: + return new TFromPg<NUdf::EDataSlot::Int16, false>(ctx.Mutables, arg); + case INT4OID: + return new TFromPg<NUdf::EDataSlot::Int32, false>(ctx.Mutables, arg); + case INT8OID: + return new TFromPg<NUdf::EDataSlot::Int64, false>(ctx.Mutables, arg); + case FLOAT4OID: + return new TFromPg<NUdf::EDataSlot::Float, false>(ctx.Mutables, arg); + case FLOAT8OID: + return new TFromPg<NUdf::EDataSlot::Double, false>(ctx.Mutables, arg); + case TEXTOID: + case VARCHAROID: + return new TFromPg<NUdf::EDataSlot::Utf8, false>(ctx.Mutables, arg); + case BYTEAOID: + return new TFromPg<NUdf::EDataSlot::String, false>(ctx.Mutables, arg); + case CSTRINGOID: + return new TFromPg<NUdf::EDataSlot::Utf8, true>(ctx.Mutables, arg); + default: + ythrow yexception() << "Unsupported type: " << NPg::LookupType(sourceId).Name; + } + } + return nullptr; }; } diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index ae1e5ca8d1b..9198afeb17a 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -2275,6 +2275,12 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return ctx.ProgramBuilder.PgCast(input, returnType); }); + AddCallable("FromPg", [](const TExprNode& node, TMkqlBuildContext& ctx) { + auto input = MkqlBuildExpr(*node.Child(0), ctx); + auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + return ctx.ProgramBuilder.FromPg(input, returnType); + }); + AddCallable("QueueCreate", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto initCapacity = MkqlBuildExpr(*node.Child(1), ctx); const auto initSize = MkqlBuildExpr(*node.Child(2), ctx); |