diff options
author | vvvv <vvvv@ydb.tech> | 2022-08-16 10:09:31 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-08-16 10:09:31 +0300 |
commit | 090d6a73ed082496eddf79943b652ba49b932617 (patch) | |
tree | 29f156c83cc42901dbee0eeab483df5957261708 | |
parent | b711c1843fb1cb35d8013ee74ecfa1ab47421fb9 (diff) | |
download | ydb-090d6a73ed082496eddf79943b652ba49b932617.tar.gz |
set combiners
-rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_pgselect.cpp | 312 | ||||
-rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_pgselect.h | 5 | ||||
-rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_simple1.cpp | 45 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_core.cpp | 1 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.cpp | 80 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.h | 1 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_pg.cpp | 5 |
7 files changed, 418 insertions, 31 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp index a3bdf910765..a234f791e5f 100644 --- a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp @@ -689,6 +689,38 @@ TExprNode::TPtr BuildFilter(TPositionHandle pos, const TExprNode::TPtr& list, co .Build(); } +TExprNode::TPtr NormalizeColumnOrder(const TExprNode::TPtr& node, const TColumnOrder& sourceColumnOrder, + const TColumnOrder& targetColumnOrder, TExprContext& ctx) { + if (sourceColumnOrder == targetColumnOrder) { + return node; + } + + YQL_ENSURE(sourceColumnOrder.size() == targetColumnOrder.size()); + return ctx.Builder(node->Pos()) + .Callable("OrderedMap") + .Add(0, node) + .Lambda(1) + .Param("row") + .Callable("AsStruct") + .Do([&](TExprNodeBuilder &parent) -> TExprNodeBuilder & { + for (size_t i = 0; i < sourceColumnOrder.size(); ++i) { + parent + .List(i) + .Atom(0, targetColumnOrder[i]) + .Callable(1, "Member") + .Arg(0, "row") + .Atom(1, sourceColumnOrder[i]) + .Seal() + .Seal(); + } + return parent; + }) + .Seal() + .Seal() + .Seal() + .Build(); +} + TExprNode::TPtr ExpandPositionalUnionAll(const TExprNode& node, const TVector<TColumnOrder>& columnOrders, TExprNode::TListType children, TExprContext& ctx, TOptimizeContext& optCtx) { auto targetColumnOrder = optCtx.Types->LookupColumnOrder(node); @@ -697,34 +729,7 @@ TExprNode::TPtr ExpandPositionalUnionAll(const TExprNode& node, const TVector<TC for (ui32 childIndex = 0; childIndex < children.size(); ++childIndex) { const auto& childColumnOrder = columnOrders[childIndex]; auto& child = children[childIndex]; - if (childColumnOrder == *targetColumnOrder) { - continue; - } - - YQL_ENSURE(childColumnOrder.size() == targetColumnOrder->size()); - child = ctx.Builder(child->Pos()) - .Callable("OrderedMap") - .Add(0, child) - .Lambda(1) - .Param("row") - .Callable("AsStruct") - .Do([&](TExprNodeBuilder &parent) -> TExprNodeBuilder & { - for (size_t i = 0; i < childColumnOrder.size(); ++i) { - parent - .List(i) - .Atom(0, child->Pos(), (*targetColumnOrder)[i]) - .Callable(1, "Member") - .Arg(0, "row") - .Atom(1, childColumnOrder[i]) - .Seal() - .Seal(); - } - return parent; - }) - .Seal() - .Seal() - .Seal() - .Build(); + child = NormalizeColumnOrder(child, childColumnOrder, *targetColumnOrder, ctx); } auto res = ctx.NewCallable(node.Pos(), "UnionAll", std::move(children)); @@ -2479,6 +2484,217 @@ TExprNode::TPtr JoinOuter(TPositionHandle pos, TExprNode::TPtr list, return list; } +TExprNode::TPtr CombineSetItems(TPositionHandle pos, const TExprNode::TPtr& left, const TExprNode::TPtr& right, const TStringBuf& op, TExprContext& ctx) { + if (op == "union_all") { + return ctx.NewCallable(pos, "UnionAll", { left, right }); + } + + auto leftSide = ctx.Builder(pos) + .Callable("OrderedMap") + .Add(0, left) + .Lambda(1) + .Param("row") + .Callable("AddMember") + .Callable(0, "AddMember") + .Arg(0, "row") + .Atom(1, "_yql_count_right") + .Callable(2, "Null") + .Seal() + .Seal() + .Atom(1, "_yql_count_left") + .Callable(2, "Uint32") + .Atom(0, "1") + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + + auto rightSide = ctx.Builder(pos) + .Callable("OrderedMap") + .Add(0, right) + .Lambda(1) + .Param("row") + .Callable("AddMember") + .Callable(0, "AddMember") + .Arg(0, "row") + .Atom(1, "_yql_count_right") + .Callable(2, "Uint32") + .Atom(0, "1") + .Seal() + .Seal() + .Atom(1, "_yql_count_left") + .Callable(2, "Null") + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + + auto both = ctx.NewCallable(pos, "UnionAll", { leftSide, rightSide }); + auto aggregated = ctx.Builder(pos) + .Callable("CountedAggregateAll") + .Add(0, both) + .List(1) + .Atom(0, "_yql_count_left") + .Atom(1, "_yql_count_right") + .Seal() + .Seal() + .Build(); + + TExprNode::TPtr ret; + auto zero = ctx.Builder(pos) + .Callable("Uint64") + .Atom(0, "0") + .Seal() + .Build(); + + if (!op.EndsWith("_all")) { + if (op.StartsWith("union")) { + ret = ctx.Builder(pos) + .Callable("OrderedFilter") + .Add(0, aggregated) + .Lambda(1) + .Param("row") + .Callable("Or") + .Callable(0, ">") + .Callable(0, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_left") + .Seal() + .Add(1, zero) + .Seal() + .Callable(1, ">") + .Callable(0, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_right") + .Seal() + .Add(1, zero) + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + } else if (op.StartsWith("intersect")) { + ret = ctx.Builder(pos) + .Callable("OrderedFilter") + .Add(0, aggregated) + .Lambda(1) + .Param("row") + .Callable("And") + .Callable(0, ">") + .Callable(0, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_left") + .Seal() + .Add(1, zero) + .Seal() + .Callable(1, ">") + .Callable(0, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_right") + .Seal() + .Add(1, zero) + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + } else { + YQL_ENSURE(op.StartsWith("except")); + ret = ctx.Builder(pos) + .Callable("OrderedFilter") + .Add(0, aggregated) + .Lambda(1) + .Param("row") + .Callable(0, ">") + .Callable(0, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_left") + .Seal() + .Callable(1, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_right") + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + } + } else { + YQL_ENSURE(!op.StartsWith("union")); + if (op.StartsWith("intersect")) { + ret = ctx.Builder(pos) + .Callable("OrderedFlatMap") + .Add(0, aggregated) + .Lambda(1) + .Param("row") + .Callable("Replicate") + .Arg(0, "row") + .Callable(1, "Min") + .Callable(0, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_left") + .Seal() + .Callable(1, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_right") + .Seal() + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + } else { + YQL_ENSURE(op.StartsWith("except")); + ret = ctx.Builder(pos) + .Callable("OrderedFlatMap") + .Add(0, aggregated) + .Lambda(1) + .Param("row") + .Callable("Replicate") + .Arg(0, "row") + .Callable(1, "-") + .Callable(0, "Max") + .Callable(0, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_left") + .Seal() + .Callable(1, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_right") + .Seal() + .Seal() + .Callable(1, "Member") + .Arg(0, "row") + .Atom(1, "_yql_count_right") + .Seal() + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + } + } + + ret = ctx.Builder(pos) + .Callable("OrderedMap") + .Add(0, ret) + .Lambda(1) + .Param("row") + .Callable("RemoveMembers") + .Arg(0, "row") + .List(1) + .Atom(0, "_yql_count_left") + .Atom(1, "_yql_count_right") + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + return ret; +} + TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx, TMaybe<ui32> subLinkId, const TExprNode::TListType& outerInputs, const TVector<TString>& outerInputAliases) { auto order = optCtx.Types->LookupColumnOrder(*node); @@ -2489,6 +2705,9 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct } auto setItems = GetSetting(node->Head(), "set_items"); + auto setOps = GetSetting(node->Head(), "set_ops"); + YQL_ENSURE(setItems); + YQL_ENSURE(setOps); const bool onlyOneSetItem = (setItems->Tail().ChildrenSize() == 1); TExprNode::TListType setItemNodes; @@ -2642,7 +2861,42 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct if (onlyOneSetItem == 1) { list = setItemNodes.front(); } else { - list = ExpandPositionalUnionAll(*node, columnOrders, setItemNodes, ctx, optCtx); + bool hasNonUnionAll = false; + for (const auto& x : setOps->Tail().Children()) { + if (x->Content() != "push" && x->Content() != "union_all") { + hasNonUnionAll = true; + break; + } + } + + if (hasNonUnionAll) { + TExprNode::TListType stack; + auto targetColumnOrder = optCtx.Types->LookupColumnOrder(*node); + YQL_ENSURE(targetColumnOrder); + + ui32 inputIndex = 0; + for (const auto& x : setOps->Tail().Children()) { + if (x->Content() == "push") { + YQL_ENSURE(inputIndex < setItemNodes.size()); + stack.push_back(NormalizeColumnOrder(setItemNodes[inputIndex], columnOrders[inputIndex], *targetColumnOrder, ctx)); + ++inputIndex; + continue; + } + + YQL_ENSURE(stack.size() >= 2); + auto left = stack[stack.size() - 2]; + auto right = stack[stack.size() - 1]; + stack.pop_back(); + stack.pop_back(); + auto combined = CombineSetItems(node->Pos(), left, right, x->Content(), ctx); + stack.push_back(combined); + } + + YQL_ENSURE(stack.size() == 1); + list = KeepColumnOrder(stack.front(), *node, ctx, *optCtx.Types); + } else { + list = ExpandPositionalUnionAll(*node, columnOrders, setItemNodes, ctx, optCtx); + } } auto finalSort = GetSetting(node->Head(), "sort"); diff --git a/ydb/library/yql/core/common_opt/yql_co_pgselect.h b/ydb/library/yql/core/common_opt/yql_co_pgselect.h index 599ce578394..e3f2b30eed0 100644 --- a/ydb/library/yql/core/common_opt/yql_co_pgselect.h +++ b/ydb/library/yql/core/common_opt/yql_co_pgselect.h @@ -9,9 +9,12 @@ TExprNode::TPtr ExpandPgSelect(const TExprNode::TPtr& node, TExprContext& ctx, T TExprNode::TPtr ExpandPgSelectSublink(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx, ui32 subLinkId, const TExprNode::TListType& outerInputs, const TVector<TString>& outerInputAliases); -TExprNode::TPtr ExpandPositionalUnionAll(const TExprNode& node, const TVector<TColumnOrder>& columnOrders, +TExprNode::TPtr ExpandPositionalUnionAll(const TExprNode& input, const TVector<TColumnOrder>& columnOrders, TExprNode::TListType children, TExprContext& ctx, TOptimizeContext& optCtx); +TExprNode::TPtr NormalizeColumnOrder(const TExprNode::TPtr& node, const TColumnOrder& sourceColumnOrder, + const TColumnOrder& targetColumnOrder, TExprContext& ctx); + TExprNode::TPtr ExpandPgLike(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx); TExprNode::TPtr ExpandPgIn(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx); diff --git a/ydb/library/yql/core/common_opt/yql_co_simple1.cpp b/ydb/library/yql/core/common_opt/yql_co_simple1.cpp index e421d2ebcdf..43be4a2bb8f 100644 --- a/ydb/library/yql/core/common_opt/yql_co_simple1.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_simple1.cpp @@ -6273,6 +6273,51 @@ void RegisterCoSimpleCallables1(TCallableOptimizerMap& map) { return KeepColumnOrder(res, *node, ctx, *optCtx.Types); }; + map["CountedAggregateAll"] = [](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) { + YQL_CLOG(DEBUG, Core) << "Expand " << node->Content(); + auto itemType = GetItemType(*node->Head().GetTypeAnn()); + auto inputTypeNode = ExpandType(node->Pos(), *itemType, ctx); + THashSet<TStringBuf> countedColumns; + for (const auto& child : node->Tail().Children()) { + countedColumns.insert(child->Content()); + } + + TExprNode::TListType keys; + TExprNode::TListType payloads; + for (auto i : itemType->Cast<TStructExprType>()->GetItems()) { + if (!countedColumns.contains(i->GetName())) { + keys.push_back(ctx.NewAtom(node->Pos(), i->GetName())); + } else { + payloads.push_back(ctx.Builder(node->Pos()) + .List() + .Atom(0, i->GetName()) + .Callable(1, "AggApply") + .Atom(0, "count") + .Add(1, inputTypeNode) + .Lambda(2) + .Param("row") + .Callable("Member") + .Arg(0, "row") + .Atom(1, i->GetName()) + .Seal() + .Seal() + .Seal() + .Seal() + .Build()); + } + } + + auto emptyTuple = ctx.NewList(node->Pos(), {}); + auto res = ctx.NewCallable(node->Pos(), "Aggregate", { + node->HeadPtr(), + ctx.NewList(node->Pos(), std::move(keys)), + ctx.NewList(node->Pos(), std::move(payloads)), + emptyTuple + }); + + return KeepColumnOrder(res, *node, ctx, *optCtx.Types); + }; + map["Mux"] = [](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) { if (node->Head().IsList()) { TExprNodeList children = node->Head().ChildrenList(); diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 4c7a5edefee..9a0d9671702 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11317,6 +11317,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["MultiAggregate"] = &MultiAggregateWrapper; Functions["Aggregate"] = &AggregateWrapper; Functions["SqlAggregateAll"] = &SqlAggregateAllWrapper; + Functions["CountedAggregateAll"] = &CountedAggregateAllWrapper; Functions["AggApply"] = &AggApplyWrapper; Functions["WinOnRows"] = &WinOnWrapper; Functions["WinOnGroups"] = &WinOnWrapper; diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index ea686d5bab4..2e07fa8399a 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -4858,6 +4858,86 @@ namespace { return IGraphTransformer::TStatus::Ok; } + IGraphTransformer::TStatus CountedAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + if (!EnsureArgsCount(*input, 2, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (IsEmptyList(input->Head())) { + output = input->HeadPtr(); + return IGraphTransformer::TStatus::Repeat; + } + + bool isStream; + if (!EnsureSeqType(input->Head(), ctx.Expr, &isStream)) { + return IGraphTransformer::TStatus::Error; + } + + auto inputItemType = isStream + ? input->Head().GetTypeAnn()->Cast<TStreamExprType>()->GetItemType() + : input->Head().GetTypeAnn()->Cast<TListExprType>()->GetItemType(); + + if (!EnsureStructType(input->Head().Pos(), *inputItemType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + THashSet<TStringBuf> countedColumns; + auto inputStructType = inputItemType->Cast<TStructExprType>(); + if (!EnsureTuple(input->Tail(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + for (auto child : input->Tail().Children()) { + if (!EnsureAtom(*child, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!inputStructType->FindItem(child->Content())) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Unknown counted member: " << child->Content())); + return IGraphTransformer::TStatus::Error; + } + + if (countedColumns.contains(child->Content())) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Duplicated counted member: " << child->Content())); + return IGraphTransformer::TStatus::Error; + } + + countedColumns.insert(child->Content()); + } + + TVector<const TItemExprType*> retItems; + for (auto& item : inputStructType->GetItems()) { + auto columnName = item->GetName(); + auto columnType = item->GetItemType(); + if (countedColumns.contains(columnName)) { + if (!EnsureComputableType(input->Pos(), *columnType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + retItems.push_back(ctx.Expr.MakeType<TItemExprType>(columnName, ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64))); + continue; + } + + if (!columnType->IsHashable() || !columnType->IsEquatable()) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), + TStringBuilder() << "Expected hashable and equatable type for key column: " << columnName << ", but got: " << *columnType)); + return IGraphTransformer::TStatus::Error; + } + + retItems.push_back(item); + } + + auto retStruct = ctx.Expr.MakeType<TStructExprType>(retItems); + if (isStream) { + input->SetTypeAnn(ctx.Expr.MakeType<TStreamExprType>(retStruct)); + } else { + input->SetTypeAnn(ctx.Expr.MakeType<TListExprType>(retStruct)); + } + return IGraphTransformer::TStatus::Ok; + } + IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { Y_UNUSED(output); if (!EnsureArgsCount(*input, 3, ctx.Expr)) { diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h index 840f6477742..c6f1ccf8cc8 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.h +++ b/ydb/library/yql/core/type_ann/type_ann_list.h @@ -94,6 +94,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus MultiAggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus AggregateWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus SqlAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus CountedAggregateAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus AggApplyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus FilterNullMembersWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus SkipNullMembersWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); diff --git a/ydb/library/yql/core/type_ann/type_ann_pg.cpp b/ydb/library/yql/core/type_ann/type_ann_pg.cpp index d07c5254567..47f153aebf1 100644 --- a/ydb/library/yql/core/type_ann/type_ann_pg.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_pg.cpp @@ -3192,7 +3192,10 @@ IGraphTransformer::TStatus PgSelectWrapper(const TExprNode::TPtr& input, TExprNo return IGraphTransformer::TStatus::Error; } - if (child->Content() != "push" && child->Content() != "union_all") { + if (child->Content() != "push" && child->Content() != "union_all" && + child->Content() != "union" && child->Content() != "except_all" && + child->Content() != "except" && child->Content() != "intersect_all" && + child->Content() != "intersect") { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Pos()), TStringBuilder() << "Unexpected operation: " << child->Content())); return IGraphTransformer::TStatus::Error; |