diff options
author | vvvv <vvvv@yandex-team.ru> | 2022-04-13 19:41:17 +0300 |
---|---|---|
committer | vvvv <vvvv@yandex-team.ru> | 2022-04-13 19:41:17 +0300 |
commit | 02760ddd8f167f6c627a0a502252a493c0326800 (patch) | |
tree | 75866fc830f3875ac09ff596b3a4d251ed869cdc | |
parent | 657422034bdec0d2c81859a1e6ddd09dec316930 (diff) | |
download | ydb-02760ddd8f167f6c627a0a502252a493c0326800.tar.gz |
YQL-13710 range functions (multi)
ref:1cfc6a770263056c9ded411ebb52d3a5f10967e3
-rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_pgselect.cpp | 33 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_pg.cpp | 10 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/catalog.cpp | 20 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/catalog.h | 2 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_wrapper/comp_factory.cpp | 254 | ||||
-rw-r--r-- | ydb/library/yql/sql/pg/pg_sql.cpp | 13 |
6 files changed, 294 insertions, 38 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp index 7997f4fc31..ae5a0e36b4 100644 --- a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp @@ -198,16 +198,33 @@ TExprNode::TListType BuildCleanedColumns(TPositionHandle pos, const TExprNode::T Y_ENSURE(!alias.empty()); Y_ENSURE(columns.ChildrenSize() == 0 || columns.ChildrenSize() == 1); auto memberName = (columns.ChildrenSize() == 1) ? columns.Head().Content() : alias; - list = ctx.Builder(pos) - .Callable("AsList") - .Callable(0, "AsStruct") - .List(0) - .Atom(0, memberName) - .Add(1, list) + if (list->GetTypeAnn()->GetKind() == ETypeAnnotationKind::List) { + list = ctx.Builder(pos) + .Callable("OrderedMap") + .Add(0, list) + .Lambda(1) + .Param("item") + .Callable("AsStruct") + .List(0) + .Atom(0, memberName) + .Arg(1, "item") + .Seal() + .Seal() .Seal() .Seal() - .Seal() - .Build(); + .Build(); + } else { + list = ctx.Builder(pos) + .Callable("AsList") + .Callable(0, "AsStruct") + .List(0) + .Atom(0, memberName) + .Add(1, list) + .Seal() + .Seal() + .Seal() + .Build(); + } } auto cleaned = ctx.Builder(pos) diff --git a/ydb/library/yql/core/type_ann/type_ann_pg.cpp b/ydb/library/yql/core/type_ann/type_ann_pg.cpp index 5a8db57033..128503e749 100644 --- a/ydb/library/yql/core/type_ann/type_ann_pg.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_pg.cpp @@ -84,7 +84,11 @@ IGraphTransformer::TStatus PgCallWrapper(const TExprNode::TPtr& input, TExprNode return IGraphTransformer::TStatus::Error; } - auto result = ctx.Expr.MakeType<TPgExprType>(proc.ResultType); + const TTypeAnnotationNode* result = ctx.Expr.MakeType<TPgExprType>(proc.ResultType); + if (proc.ReturnSet) { + result = ctx.Expr.MakeType<TListExprType>(result); + } + input->SetTypeAnn(result); return IGraphTransformer::TStatus::Ok; } else { @@ -1779,6 +1783,10 @@ IGraphTransformer::TStatus PgSetItemWrapper(const TExprNode::TPtr& 
input, TExprN auto memberName = (p->Child(2)->ChildrenSize() == 1) ? p->Child(2)->Head().Content() : alias; TVector<const TItemExprType*> items; auto itemType = p->Head().GetTypeAnn(); + if (itemType->GetKind() == ETypeAnnotationKind::List) { + itemType = itemType->Cast<TListExprType>()->GetItemType(); + } + items.push_back(ctx.Expr.MakeType<TItemExprType>(memberName, itemType)); inputStructType = ctx.Expr.MakeType<TStructExprType>(items); columnOrder = TColumnOrder({ TString(memberName) }); diff --git a/ydb/library/yql/parser/pg_catalog/catalog.cpp b/ydb/library/yql/parser/pg_catalog/catalog.cpp index ed377ed59e..d3d9cd5c07 100644 --- a/ydb/library/yql/parser/pg_catalog/catalog.cpp +++ b/ydb/library/yql/parser/pg_catalog/catalog.cpp @@ -291,7 +291,7 @@ public: } else if (key == "proisstrict") { LastProc.IsStrict = (value == "t"); } else if (key == "proretset") { - IsSupported = false; + LastProc.ReturnSet = (value == "t"); } } @@ -1120,6 +1120,24 @@ const TProcDesc& LookupProc(ui32 procId) { return *procPtr; } +bool HasReturnSetProc(const TStringBuf& name) { + const auto& catalog = TCatalog::Instance(); + auto procIdPtr = catalog.ProcByName.FindPtr(name); + if (!procIdPtr) { + return false; + } + + for (const auto& id : *procIdPtr) { + const auto& d = catalog.Procs.FindPtr(id); + Y_ENSURE(d); + if (d->ReturnSet) { + return true; + } + } + + return false; +} + bool HasType(const TStringBuf& name) { const auto& catalog = TCatalog::Instance(); return catalog.TypeByName.contains(name); diff --git a/ydb/library/yql/parser/pg_catalog/catalog.h b/ydb/library/yql/parser/pg_catalog/catalog.h index 2c9f4e729f..00bb643f9e 100644 --- a/ydb/library/yql/parser/pg_catalog/catalog.h +++ b/ydb/library/yql/parser/pg_catalog/catalog.h @@ -35,6 +35,7 @@ struct TProcDesc { ui32 ResultType = 0; bool IsStrict = true; EProcKind Kind = EProcKind::Function; + bool ReturnSet = false; }; struct TTypeDesc { @@ -135,6 +136,7 @@ enum class EHashAmProcNum { const TProcDesc& LookupProc(const 
TString& name, const TVector<ui32>& argTypeIds); const TProcDesc& LookupProc(ui32 procId, const TVector<ui32>& argTypeIds); const TProcDesc& LookupProc(ui32 procId); +bool HasReturnSetProc(const TStringBuf& name); bool HasType(const TStringBuf& name); const TTypeDesc& LookupType(const TString& name); diff --git a/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp b/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp index 013b93bb2d..3864271766 100644 --- a/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp +++ b/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp @@ -2,6 +2,7 @@ #include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_pack_impl.h> +#include <ydb/library/yql/minikql/computation/mkql_custom_list.h> #include <ydb/library/yql/minikql/computation/presort_impl.h> #include <ydb/library/yql/minikql/mkql_node_cast.h> #include <ydb/library/yql/minikql/mkql_alloc.h> @@ -26,8 +27,10 @@ extern "C" { #include "utils/builtins.h" #include "utils/memutils.h" #include "nodes/execnodes.h" +#include "executor/executor.h" #include "lib/stringinfo.h" #include "thread_inits.h" + #undef Abs #undef Min #undef Max @@ -253,7 +256,8 @@ public: NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { LOCAL_FCINFO(callInfo, 3); Zero(*callInfo); - callInfo->flinfo = const_cast<FmgrInfo*>(&FInfo); + FmgrInfo copyFmgrInfo = FInfo; + callInfo->flinfo = &copyFmgrInfo; callInfo->nargs = 3; callInfo->fncollation = DEFAULT_COLLATION_OID; callInfo->isnull = false; @@ -312,6 +316,7 @@ class TFunctionCallInfo { public: TFunctionCallInfo(ui32 numArgs, const FmgrInfo* finfo) : NumArgs(numArgs) + , CopyFmgrInfo(*finfo) { if (!finfo->fn_addr) { return; @@ -321,7 +326,7 @@ public: Ptr = MKQLAllocWithSize(MemSize); auto& callInfo = Ref(); Zero(callInfo); - callInfo.flinfo = const_cast<FmgrInfo*>(finfo); + callInfo.flinfo 
= &CopyFmgrInfo; // client may mutate fn_extra callInfo.nargs = NumArgs; callInfo.fncollation = DEFAULT_COLLATION_OID; } @@ -344,16 +349,55 @@ private: const ui32 NumArgs = 0; ui32 MemSize = 0; void* Ptr = nullptr; + FmgrInfo CopyFmgrInfo; }; -class TPgResolvedCall : public TMutableComputationNode<TPgResolvedCall> { - typedef TMutableComputationNode<TPgResolvedCall> TBaseComputation; +class TReturnSetInfo { public: - TPgResolvedCall(TComputationMutables& mutables, bool useContext, const std::string_view& name, ui32 id, - TComputationNodePtrVector&& argNodes, TVector<TType*>&& argTypes) + TReturnSetInfo() { + Ptr = MKQLAllocWithSize(sizeof(ReturnSetInfo)); + Zero(Ref()); + Ref().type = T_ReturnSetInfo; + } + + ~TReturnSetInfo() { + MKQLFreeWithSize(Ptr, sizeof(ReturnSetInfo)); + } + + ReturnSetInfo& Ref() { + return *static_cast<ReturnSetInfo*>(Ptr); + } + +private: + void* Ptr = nullptr; +}; + +class TExprContextHolder { +public: + TExprContextHolder() { + Ptr = CreateStandaloneExprContext(); + } + + ExprContext& Ref() { + return *Ptr; + } + + ~TExprContextHolder() { + FreeExprContext(Ptr, true); + } + +private: + ExprContext* Ptr; +}; + + +template <typename TDerived> +class TPgResolvedCallBase : public TMutableComputationNode<TDerived> { + typedef TMutableComputationNode<TDerived> TBaseComputation; +public: + TPgResolvedCallBase(TComputationMutables& mutables, const std::string_view& name, ui32 id, + TComputationNodePtrVector&& argNodes, TVector<TType*>&& argTypes, bool isList) : TBaseComputation(mutables) - , StateIndex(mutables.CurValueIndex++) - , UseContext(useContext) , Name(name) , Id(id) , ArgNodes(std::move(argNodes)) @@ -364,7 +408,7 @@ public: Zero(FInfo); Y_ENSURE(Id); fmgr_info(Id, &FInfo); - Y_ENSURE(!FInfo.fn_retset); + Y_ENSURE(FInfo.fn_retset == isList); Y_ENSURE(FInfo.fn_addr); Y_ENSURE(FInfo.fn_nargs == ArgNodes.size()); ArgDesc.reserve(ProcDesc.ArgTypes.size()); @@ -389,6 +433,35 @@ public: Y_ENSURE(ArgDesc.size() == ArgNodes.size()); } 
+private: + void RegisterDependencies() const final { + for (const auto node : ArgNodes) { + this->DependsOn(node); + } + } + +protected: + const std::string_view Name; + const ui32 Id; + FmgrInfo FInfo; + const NPg::TProcDesc ProcDesc; + const NPg::TTypeDesc RetTypeDesc; + const TComputationNodePtrVector ArgNodes; + const TVector<TType*> ArgTypes; + TVector<NPg::TTypeDesc> ArgDesc; +}; + +class TPgResolvedCall : public TPgResolvedCallBase<TPgResolvedCall> { + typedef TPgResolvedCallBase<TPgResolvedCall> TBaseComputation; +public: + TPgResolvedCall(TComputationMutables& mutables, bool useContext, const std::string_view& name, ui32 id, + TComputationNodePtrVector&& argNodes, TVector<TType*>&& argTypes) + : TBaseComputation(mutables, name, id, std::move(argNodes), std::move(argTypes), false) + , StateIndex(mutables.CurValueIndex++) + , UseContext(useContext) + { + } + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { auto& state = GetState(compCtx); auto& callInfo = state.CallInfo.Ref(); @@ -451,12 +524,6 @@ public: } private: - void RegisterDependencies() const final { - for (const auto node : ArgNodes) { - DependsOn(node); - } - } - struct TState : public TComputationValue<TState> { TState(TMemoryUsageInfo* memInfo, ui32 numArgs, const FmgrInfo* finfo) : TComputationValue(memInfo) @@ -478,14 +545,139 @@ private: const ui32 StateIndex; const bool UseContext; - const std::string_view Name; - const ui32 Id; - FmgrInfo FInfo; - const NPg::TProcDesc ProcDesc; - const NPg::TTypeDesc RetTypeDesc; - const TComputationNodePtrVector ArgNodes; - const TVector<TType*> ArgTypes; - TVector<NPg::TTypeDesc> ArgDesc; +}; + +class TPgResolvedMultiCall : public TPgResolvedCallBase<TPgResolvedMultiCall> { + typedef TPgResolvedCallBase<TPgResolvedMultiCall> TBaseComputation; +private: + class TListValue : public TCustomListValue { + public: + class TIterator : public TComputationValue<TIterator> { + public: + TIterator(TMemoryUsageInfo* memInfo, const 
std::string_view& name, const TUnboxedValueVector& args, + const TVector<NPg::TTypeDesc>& argDesc, const NPg::TTypeDesc& retTypeDesc, const FmgrInfo* fInfo) + : TComputationValue<TIterator>(memInfo) + , Name(name) + , Args(args) + , ArgDesc(argDesc) + , RetTypeDesc(retTypeDesc) + , CallInfo(argDesc.size(), fInfo) + { + auto& callInfo = CallInfo.Ref(); + callInfo.resultinfo = (fmNodePtr)&RSInfo.Ref(); + ((ReturnSetInfo*)callInfo.resultinfo)->econtext = &ExprContextHolder.Ref(); + for (ui32 i = 0; i < args.size(); ++i) { + const auto& value = args[i]; + NullableDatum argDatum = { 0, false }; + if (!value) { + argDatum.isnull = true; + } else { + argDatum.value = ArgDesc[i].PassByValue ? + ScalarDatumFromPod(value) : + PointerDatumFromPod(value); + } + + callInfo.args[i] = argDatum; + } + } + + ~TIterator() { + } + + private: + bool Next(NUdf::TUnboxedValue& value) final { + if (IsFinished) { + return false; + } + + auto& callInfo = CallInfo.Ref(); + PG_TRY(); + { + callInfo.isnull = false; + auto ret = callInfo.flinfo->fn_addr(&callInfo); + if (RSInfo.Ref().isDone == ExprEndResult) { + IsFinished = true; + return false; + } + + if (callInfo.isnull) { + value = NUdf::TUnboxedValuePod(); + } else if (RetTypeDesc.PassByValue) { + value = ScalarDatumToPod(ret); + } else if (TVPtrHolder::IsBoxedVPtr(ret)) { + // returned one of arguments + value = OwnedPointerDatumToPod(ret); + } else { + value = PointerDatumToPod(ret); + } + + return true; + } + PG_CATCH(); + { + auto error_data = CopyErrorData(); + TStringBuilder errMsg; + errMsg << "Error in function: " << Name << ", reason: " << error_data->message; + FreeErrorData(error_data); + FlushErrorState(); + UdfTerminate(errMsg.c_str()); + } + PG_END_TRY(); + } + + const std::string_view Name; + TUnboxedValueVector Args; + const TVector<NPg::TTypeDesc>& ArgDesc; + const NPg::TTypeDesc& RetTypeDesc; + TExprContextHolder ExprContextHolder; + TFunctionCallInfo CallInfo; + TReturnSetInfo RSInfo; + bool IsFinished = false; + }; + 
+ TListValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, + const std::string_view& name, TUnboxedValueVector&& args, const TVector<NPg::TTypeDesc>& argDesc, + const NPg::TTypeDesc& retTypeDesc, const FmgrInfo* fInfo) + : TCustomListValue(memInfo) + , CompCtx(compCtx) + , Name(name) + , Args(args) + , ArgDesc(argDesc) + , RetTypeDesc(retTypeDesc) + , FInfo(fInfo) + { + } + + private: + NUdf::TUnboxedValue GetListIterator() const final { + return CompCtx.HolderFactory.Create<TIterator>(Name, Args, ArgDesc, RetTypeDesc, FInfo); + } + + TComputationContext& CompCtx; + const std::string_view Name; + TUnboxedValueVector Args; + const TVector<NPg::TTypeDesc>& ArgDesc; + const NPg::TTypeDesc& RetTypeDesc; + const FmgrInfo* FInfo; + }; + +public: + TPgResolvedMultiCall(TComputationMutables& mutables, const std::string_view& name, ui32 id, + TComputationNodePtrVector&& argNodes, TVector<TType*>&& argTypes) + : TBaseComputation(mutables, name, id, std::move(argNodes), std::move(argTypes), true) + { + } + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { + TUnboxedValueVector args; + args.reserve(ArgNodes.size()); + for (ui32 i = 0; i < ArgNodes.size(); ++i) { + auto value = ArgNodes[i]->GetValue(compCtx); + args.push_back(value); + } + + return compCtx.HolderFactory.Create<TListValue>(compCtx, Name, std::move(args), ArgDesc, RetTypeDesc, &FInfo); + } }; class TPgCast : public TMutableComputationNode<TPgCast> { @@ -824,7 +1016,13 @@ TComputationNodeFactory GetPgFactory() { argTypes.emplace_back(callable.GetInput(i).GetStaticType()); } - return new TPgResolvedCall(ctx.Mutables, useContext, name, id, std::move(argNodes), std::move(argTypes)); + const bool isList = callable.GetType()->GetReturnType()->IsList(); + if (isList) { + YQL_ENSURE(!useContext); + return new TPgResolvedMultiCall(ctx.Mutables, name, id, std::move(argNodes), std::move(argTypes)); + } else { + return new TPgResolvedCall(ctx.Mutables, useContext, name, id, 
std::move(argNodes), std::move(argTypes)); + } + } if (name == "PgCast") { @@ -1793,7 +1991,7 @@ public: ui64 Hash(NUdf::TUnboxedValuePod lhs) const override { LOCAL_FCINFO(callInfo, 1); Zero(*callInfo); - callInfo->flinfo = const_cast<FmgrInfo*>(&FInfoHash); + callInfo->flinfo = const_cast<FmgrInfo*>(&FInfoHash); // don't copy because IHash isn't threadsafe callInfo->nargs = 1; callInfo->fncollation = DEFAULT_COLLATION_OID; callInfo->isnull = false; @@ -1846,7 +2044,7 @@ public: bool Less(NUdf::TUnboxedValuePod lhs, NUdf::TUnboxedValuePod rhs) const override { LOCAL_FCINFO(callInfo, 2); Zero(*callInfo); - callInfo->flinfo = const_cast<FmgrInfo*>(&FInfoLess); + callInfo->flinfo = const_cast<FmgrInfo*>(&FInfoLess); // don't copy because ICompare isn't threadsafe callInfo->nargs = 2; callInfo->fncollation = DEFAULT_COLLATION_OID; callInfo->isnull = false; @@ -1877,7 +2075,7 @@ public: int Compare(NUdf::TUnboxedValuePod lhs, NUdf::TUnboxedValuePod rhs) const override { LOCAL_FCINFO(callInfo, 2); Zero(*callInfo); - callInfo->flinfo = const_cast<FmgrInfo*>(&FInfoCompare); + callInfo->flinfo = const_cast<FmgrInfo*>(&FInfoCompare); // don't copy because ICompare isn't threadsafe callInfo->nargs = 2; callInfo->fncollation = DEFAULT_COLLATION_OID; callInfo->isnull = false; @@ -1934,7 +2132,7 @@ public: bool Equals(NUdf::TUnboxedValuePod lhs, NUdf::TUnboxedValuePod rhs) const override { LOCAL_FCINFO(callInfo, 2); Zero(*callInfo); - callInfo->flinfo = const_cast<FmgrInfo*>(&FInfoEquate); + callInfo->flinfo = const_cast<FmgrInfo*>(&FInfoEquate); // don't copy because IEquate isn't threadsafe callInfo->nargs = 2; callInfo->fncollation = DEFAULT_COLLATION_OID; callInfo->isnull = false; diff --git a/ydb/library/yql/sql/pg/pg_sql.cpp b/ydb/library/yql/sql/pg/pg_sql.cpp index c6822ddfc3..2d0dffc1ab 100644 --- a/ydb/library/yql/sql/pg/pg_sql.cpp +++ b/ydb/library/yql/sql/pg/pg_sql.cpp @@ -123,6 +123,7 @@ public: bool AllowColumns = false; bool AllowAggregates = false; bool 
AllowOver = false; + bool AllowReturnSet = false; TVector<TAstNode*>* WindowItems = nullptr; TString Scope; }; @@ -874,6 +875,11 @@ public: TString alias; TVector<TString> colnames; + if (!value->alias) { + AddError("RangeFunction: expected alias"); + return {}; + } + if (!ParseAlias(value->alias, alias, colnames)) { return {}; } @@ -892,6 +898,7 @@ public: TExprSettings settings; settings.AllowColumns = false; + settings.AllowReturnSet = true; settings.Scope = "RANGE FUNCTION"; auto func = ParseExpr(ListNodeNth(lst, 0), settings); if (!func) { @@ -1058,12 +1065,18 @@ public: auto name = names.back(); const bool isAggregateFunc = NYql::NPg::HasAggregation(name); + const bool hasReturnSet = NYql::NPg::HasReturnSetProc(name); if (isAggregateFunc && !settings.AllowAggregates) { AddError(TStringBuilder() << "Aggregate functions are not allowed in: " << settings.Scope); return nullptr; } + if (hasReturnSet && !settings.AllowReturnSet) { + AddError(TStringBuilder() << "Generator functions are not allowed in: " << settings.Scope); + return nullptr; + } + TVector<TAstNode*> args; TString callable; if (window) { |