diff options
author | vvvv <vvvv@ydb.tech> | 2022-09-02 10:21:42 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-09-02 10:21:42 +0300 |
commit | 541ed5eac5b51c78bb18304e014a6c38f626a9d0 (patch) | |
tree | 22315c98b52445d29a8f4a2ba3b6a49194cf9bfb | |
parent | d4a0ee7cb80faa47064f5fa2493d0f0e2b6de508 (diff) | |
download | ydb-541ed5eac5b51c78bb18304e014a6c38f626a9d0.tar.gz |
grouping sets implementation
-rw-r--r-- | ydb/library/yql/core/common_opt/yql_co_pgselect.cpp | 83 | ||||
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_pg.cpp | 111 |
2 files changed, 149 insertions, 45 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp index 600633243d0..3ed7ec37864 100644 --- a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp +++ b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp @@ -1796,34 +1796,27 @@ TExprNode::TPtr BuildGroup(TPositionHandle pos, TExprNode::TPtr list, } auto payloadsNode = ctx.NewList(pos, std::move(payloadItems)); - TExprNode::TListType keysItems; + TExprNode::TListType extKeysItems; if (finalExtTypes) { for (const auto& x : finalExtTypes->Tail().Children()) { auto type = x->Tail().GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TStructExprType>(); for (const auto& i : type->GetItems()) { - keysItems.push_back(ctx.NewAtom(pos, NTypeAnnImpl::MakeAliasedColumn(x->Head().Content(), i->GetName()))); + extKeysItems.push_back(ctx.NewAtom(pos, NTypeAnnImpl::MakeAliasedColumn(x->Head().Content(), i->GetName()))); } } } + TExprNode::TListType groupKeysItems; if (groupExprs->Tail().ChildrenSize()) { auto arg = ctx.NewArgument(pos, "row"); auto arguments = ctx.NewArguments(pos, { arg }); TExprNode::TListType newColumns; - for (ui32 i = 0; i < groupSets->Tail().ChildrenSize(); ++i) { - auto set = groupSets->Tail().Child(i); - YQL_ENSURE(set->ChildrenSize() == 1); - auto first = set->HeadPtr(); - YQL_ENSURE(first->ChildrenSize() == 1); - auto second = first->HeadPtr(); - YQL_ENSURE(second->IsAtom()); - ui32 index = FromString<ui32>(second->Content()); - YQL_ENSURE(index < groupExprs->Tail().ChildrenSize()); - const auto& group = groupExprs->Tail().Child(index); + for (ui32 i = 0; i < groupExprs->Tail().ChildrenSize(); ++i) { + const auto& group = groupExprs->Tail().Child(i); const auto& lambda = group->Tail(); auto name = "_yql_agg_key_" + ToString(i); - keysItems.push_back(ctx.NewAtom(pos, name)); + groupKeysItems.push_back(ctx.NewAtom(pos, name)); newColumns.push_back(ctx.Builder(pos) .List() .Atom(0, name) @@ -1858,17 +1851,63 @@ TExprNode::TPtr BuildGroup(TPositionHandle pos, TExprNode::TPtr list, .Build(); } - auto keysNode = ctx.NewList(pos, std::move(keysItems)); + TVector<ui32> currentSetIndices, setCounts; + currentSetIndices.resize(groupSets->Tail().ChildrenSize()); + for (ui32 i = 0; i < groupSets->Tail().ChildrenSize(); ++i) { + auto set = groupSets->Tail().Child(i); + YQL_ENSURE(set->ChildrenSize() >= 1); + setCounts.push_back(set->ChildrenSize()); + } - return ctx.Builder(pos) - .Callable("Aggregate") - .Add(0, list) - .Add(1, keysNode) - .Add(2, payloadsNode) - .List(3) // options + TExprNode::TListType unionAllItems; + for (;;) { + TExprNode::TListType keysItems = extKeysItems; + // calculate grouping set keys for current position + TSet<ui32> currentKeys; + for (ui32 i = 0; i < currentSetIndices.size(); ++i) { + const auto& set = groupSets->Tail().Child(i)->Child(currentSetIndices[i]); + YQL_ENSURE(set->IsList()); + for (const auto& atom : set->Children()) { + YQL_ENSURE(atom->IsAtom()); + currentKeys.insert(FromString<ui32>(atom->Content())); + } + } + + for (auto keyIndex : currentKeys) { + YQL_ENSURE(keyIndex < groupKeysItems.size()); + keysItems.push_back(groupKeysItems[keyIndex]); + } + + auto keysNode = ctx.NewList(pos, std::move(keysItems)); + auto aggregate = ctx.Builder(pos) + .Callable("Aggregate") + .Add(0, list) + .Add(1, keysNode) + .Add(2, payloadsNode) + .List(3) // options + .Seal() .Seal() - .Seal() - .Build(); + .Build(); + + unionAllItems.push_back(aggregate); + // shift iterator + ui32 i = 0; + while (i < currentSetIndices.size()) { + ++currentSetIndices[i]; + if (currentSetIndices[i] < setCounts[i]) { + break; + } + + currentSetIndices[i] = 0; + ++i; + } + + if (i == currentSetIndices.size()) { + break; + } + } + + return ctx.NewCallable(pos, "UnionAll", std::move(unionAllItems)); } TExprNode::TPtr BuildHaving(TPositionHandle pos, TExprNode::TPtr list, const TExprNode::TPtr& having, diff --git a/ydb/library/yql/core/type_ann/type_ann_pg.cpp b/ydb/library/yql/core/type_ann/type_ann_pg.cpp index 78d6fa125dd..bc36a0077f9 100644 --- a/ydb/library/yql/core/type_ann/type_ann_pg.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_pg.cpp @@ -2196,45 +2196,110 @@ TExprNode::TPtr ReplaceGroupByExpr(const TExprNode::TPtr& root, const TExprNode& return ret; } +ui32 RegisterGroupExpression(const TExprNode::TPtr& root, const TExprNode::TPtr& args, const TExprNode::TPtr& group, + THashMap<ui64, TVector<ui32>>& hashes, TExprNode::TListType& groupExprsItems, TExprContext& ctx) { + TNodeMap<ui64> visitedHashes; + visitedHashes[&args->Head()] = 0; + auto hash = CalculateExprHash(*root, visitedHashes); + auto it = hashes.find(hash); + if (it != hashes.end()) { + for (auto i : it->second) { + TNodeSet visitedNodes; + if (ExprNodesEquals(*root, groupExprsItems[i]->Tail().Tail(), visitedNodes)) { + return i; + } + } + } + + auto index = groupExprsItems.size(); + hashes[hash].push_back(index); + auto newLambda = ctx.Builder(group->Pos()) + .Lambda() + .Param("row") + .ApplyPartial(args, root) + .With(0, "row") + .Seal() + .Seal() + .Build(); + + newLambda->Head().Head().SetArgIndex(0); + auto newExpr = ctx.ChangeChild(*group, 1, std::move(newLambda)); + groupExprsItems.push_back(newExpr); + return index; +} + bool BuildGroupingSets(const TExprNode& data, TExprNode::TPtr& groupSets, TExprNode::TPtr& groupExprs, TExprContext& ctx) { TExprNode::TListType groupSetsItems, groupExprsItems; THashMap<ui64, TVector<ui32>> hashes; for (const auto& child : data.Children()) { const auto& lambda = child->Tail(); + TExprNode::TPtr sets; if (lambda.Tail().IsCallable("PgGroupingSet")) { - YQL_ENSURE(false); - } else { - TNodeMap<ui64> visited; - visited[&lambda.Head().Head()] = 0; - auto hash = CalculateExprHash(lambda.Tail(), visited); - ui32 index; - bool found = false; - auto it = hashes.find(hash); - if (it != hashes.end()) { - for (auto i : it->second) { - TNodeSet visited; - if (ExprNodesEquals(lambda.Tail(), groupExprsItems[i]->Tail().Tail(), visited)) { - index = i; - found = true; - break; + const auto& gs = lambda.Tail(); + auto kind = gs.Head().Content(); + if (kind == "cube" || kind == "rollup") { + TExprNode::TListType indices; + for (const auto& expr : gs.Tail().Children()) { + auto index = RegisterGroupExpression(expr, lambda.HeadPtr(), child, hashes, groupExprsItems, ctx); + indices.push_back(ctx.NewAtom(expr->Pos(), ToString(index))); + } + + TExprNode::TListType setsItems; + if (kind == "rollup") { + // generate N+1 sets + for (ui32 i = 0; i <= indices.size(); ++i) { + TExprNode::TListType oneSetItems; + for (ui32 j = 0; j < i; ++j) { + oneSetItems.push_back(indices[j]); + } + + setsItems.push_back(ctx.NewList(data.Pos(), std::move(oneSetItems))); + } + } else { + // generate 2**N sets + YQL_ENSURE(indices.size() <= 5, "Too many CUBE components"); + ui32 count = (1u << indices.size()); + for (ui32 i = 0; i < count; ++i) { + TExprNode::TListType oneSetItems; + for (ui32 j = 0; j < indices.size(); ++j) { + if ((1u << j) & i) { + oneSetItems.push_back(indices[j]); + } + } + + setsItems.push_back(ctx.NewList(data.Pos(), std::move(oneSetItems))); } } - } - if (!found) { - index = groupExprsItems.size(); - hashes[hash].push_back(index); - groupExprsItems.push_back(child); - } + sets = ctx.NewList(data.Pos(), std::move(setsItems)); + } else { + YQL_ENSURE(kind == "sets"); + TExprNode::TListType setsItems; + for (ui32 setIndex = 1; setIndex < gs.ChildrenSize(); ++setIndex) { + const auto& g = gs.Child(setIndex); + TExprNode::TListType oneSetItems; + for (const auto& expr : g->Children()) { + auto index = RegisterGroupExpression(expr, lambda.HeadPtr(), child, hashes, groupExprsItems, ctx); + oneSetItems.push_back(ctx.NewAtom(expr->Pos(), ToString(index))); + } + + setsItems.push_back(ctx.NewList(data.Pos(), std::move(oneSetItems))); + } - groupSetsItems.push_back(ctx.Builder(data.Pos()) + sets = ctx.NewList(data.Pos(), std::move(setsItems)); + } + } else { + auto index = RegisterGroupExpression(lambda.TailPtr(), lambda.HeadPtr(), child, hashes, groupExprsItems, ctx); + sets = ctx.Builder(data.Pos()) .List() .List(0) .Atom(0, ToString(index)) .Seal() .Seal() - .Build()); + .Build(); } + + groupSetsItems.push_back(sets); } groupSets = ctx.NewList(data.Pos(), std::move(groupSetsItems)); |