aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-09-02 10:21:42 +0300
committervvvv <vvvv@ydb.tech>2022-09-02 10:21:42 +0300
commit541ed5eac5b51c78bb18304e014a6c38f626a9d0 (patch)
tree22315c98b52445d29a8f4a2ba3b6a49194cf9bfb
parentd4a0ee7cb80faa47064f5fa2493d0f0e2b6de508 (diff)
downloadydb-541ed5eac5b51c78bb18304e014a6c38f626a9d0.tar.gz
grouping sets implementation
-rw-r--r--ydb/library/yql/core/common_opt/yql_co_pgselect.cpp83
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_pg.cpp111
2 files changed, 149 insertions, 45 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
index 600633243d0..3ed7ec37864 100644
--- a/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
+++ b/ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
@@ -1796,34 +1796,27 @@ TExprNode::TPtr BuildGroup(TPositionHandle pos, TExprNode::TPtr list,
}
auto payloadsNode = ctx.NewList(pos, std::move(payloadItems));
- TExprNode::TListType keysItems;
+ TExprNode::TListType extKeysItems;
if (finalExtTypes) {
for (const auto& x : finalExtTypes->Tail().Children()) {
auto type = x->Tail().GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TStructExprType>();
for (const auto& i : type->GetItems()) {
- keysItems.push_back(ctx.NewAtom(pos, NTypeAnnImpl::MakeAliasedColumn(x->Head().Content(), i->GetName())));
+ extKeysItems.push_back(ctx.NewAtom(pos, NTypeAnnImpl::MakeAliasedColumn(x->Head().Content(), i->GetName())));
}
}
}
+ TExprNode::TListType groupKeysItems;
if (groupExprs->Tail().ChildrenSize()) {
auto arg = ctx.NewArgument(pos, "row");
auto arguments = ctx.NewArguments(pos, { arg });
TExprNode::TListType newColumns;
- for (ui32 i = 0; i < groupSets->Tail().ChildrenSize(); ++i) {
- auto set = groupSets->Tail().Child(i);
- YQL_ENSURE(set->ChildrenSize() == 1);
- auto first = set->HeadPtr();
- YQL_ENSURE(first->ChildrenSize() == 1);
- auto second = first->HeadPtr();
- YQL_ENSURE(second->IsAtom());
- ui32 index = FromString<ui32>(second->Content());
- YQL_ENSURE(index < groupExprs->Tail().ChildrenSize());
- const auto& group = groupExprs->Tail().Child(index);
+ for (ui32 i = 0; i < groupExprs->Tail().ChildrenSize(); ++i) {
+ const auto& group = groupExprs->Tail().Child(i);
const auto& lambda = group->Tail();
auto name = "_yql_agg_key_" + ToString(i);
- keysItems.push_back(ctx.NewAtom(pos, name));
+ groupKeysItems.push_back(ctx.NewAtom(pos, name));
newColumns.push_back(ctx.Builder(pos)
.List()
.Atom(0, name)
@@ -1858,17 +1851,63 @@ TExprNode::TPtr BuildGroup(TPositionHandle pos, TExprNode::TPtr list,
.Build();
}
- auto keysNode = ctx.NewList(pos, std::move(keysItems));
+ TVector<ui32> currentSetIndices, setCounts;
+ currentSetIndices.resize(groupSets->Tail().ChildrenSize());
+ for (ui32 i = 0; i < groupSets->Tail().ChildrenSize(); ++i) {
+ auto set = groupSets->Tail().Child(i);
+ YQL_ENSURE(set->ChildrenSize() >= 1);
+ setCounts.push_back(set->ChildrenSize());
+ }
- return ctx.Builder(pos)
- .Callable("Aggregate")
- .Add(0, list)
- .Add(1, keysNode)
- .Add(2, payloadsNode)
- .List(3) // options
+ TExprNode::TListType unionAllItems;
+ for (;;) {
+ TExprNode::TListType keysItems = extKeysItems;
+ // calculate grouping set keys for current position
+ TSet<ui32> currentKeys;
+ for (ui32 i = 0; i < currentSetIndices.size(); ++i) {
+ const auto& set = groupSets->Tail().Child(i)->Child(currentSetIndices[i]);
+ YQL_ENSURE(set->IsList());
+ for (const auto& atom : set->Children()) {
+ YQL_ENSURE(atom->IsAtom());
+ currentKeys.insert(FromString<ui32>(atom->Content()));
+ }
+ }
+
+ for (auto keyIndex : currentKeys) {
+ YQL_ENSURE(keyIndex < groupKeysItems.size());
+ keysItems.push_back(groupKeysItems[keyIndex]);
+ }
+
+ auto keysNode = ctx.NewList(pos, std::move(keysItems));
+ auto aggregate = ctx.Builder(pos)
+ .Callable("Aggregate")
+ .Add(0, list)
+ .Add(1, keysNode)
+ .Add(2, payloadsNode)
+ .List(3) // options
+ .Seal()
.Seal()
- .Seal()
- .Build();
+ .Build();
+
+ unionAllItems.push_back(aggregate);
+ // shift iterator
+ ui32 i = 0;
+ while (i < currentSetIndices.size()) {
+ ++currentSetIndices[i];
+ if (currentSetIndices[i] < setCounts[i]) {
+ break;
+ }
+
+ currentSetIndices[i] = 0;
+ ++i;
+ }
+
+ if (i == currentSetIndices.size()) {
+ break;
+ }
+ }
+
+ return ctx.NewCallable(pos, "UnionAll", std::move(unionAllItems));
}
TExprNode::TPtr BuildHaving(TPositionHandle pos, TExprNode::TPtr list, const TExprNode::TPtr& having,
diff --git a/ydb/library/yql/core/type_ann/type_ann_pg.cpp b/ydb/library/yql/core/type_ann/type_ann_pg.cpp
index 78d6fa125dd..bc36a0077f9 100644
--- a/ydb/library/yql/core/type_ann/type_ann_pg.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_pg.cpp
@@ -2196,45 +2196,110 @@ TExprNode::TPtr ReplaceGroupByExpr(const TExprNode::TPtr& root, const TExprNode&
return ret;
}
+ui32 RegisterGroupExpression(const TExprNode::TPtr& root, const TExprNode::TPtr& args, const TExprNode::TPtr& group,
+ THashMap<ui64, TVector<ui32>>& hashes, TExprNode::TListType& groupExprsItems, TExprContext& ctx) {
+ TNodeMap<ui64> visitedHashes;
+ visitedHashes[&args->Head()] = 0;
+ auto hash = CalculateExprHash(*root, visitedHashes);
+ auto it = hashes.find(hash);
+ if (it != hashes.end()) {
+ for (auto i : it->second) {
+ TNodeSet visitedNodes;
+ if (ExprNodesEquals(*root, groupExprsItems[i]->Tail().Tail(), visitedNodes)) {
+ return i;
+ }
+ }
+ }
+
+ auto index = groupExprsItems.size();
+ hashes[hash].push_back(index);
+ auto newLambda = ctx.Builder(group->Pos())
+ .Lambda()
+ .Param("row")
+ .ApplyPartial(args, root)
+ .With(0, "row")
+ .Seal()
+ .Seal()
+ .Build();
+
+ newLambda->Head().Head().SetArgIndex(0);
+ auto newExpr = ctx.ChangeChild(*group, 1, std::move(newLambda));
+ groupExprsItems.push_back(newExpr);
+ return index;
+}
+
bool BuildGroupingSets(const TExprNode& data, TExprNode::TPtr& groupSets, TExprNode::TPtr& groupExprs, TExprContext& ctx) {
TExprNode::TListType groupSetsItems, groupExprsItems;
THashMap<ui64, TVector<ui32>> hashes;
for (const auto& child : data.Children()) {
const auto& lambda = child->Tail();
+ TExprNode::TPtr sets;
if (lambda.Tail().IsCallable("PgGroupingSet")) {
- YQL_ENSURE(false);
- } else {
- TNodeMap<ui64> visited;
- visited[&lambda.Head().Head()] = 0;
- auto hash = CalculateExprHash(lambda.Tail(), visited);
- ui32 index;
- bool found = false;
- auto it = hashes.find(hash);
- if (it != hashes.end()) {
- for (auto i : it->second) {
- TNodeSet visited;
- if (ExprNodesEquals(lambda.Tail(), groupExprsItems[i]->Tail().Tail(), visited)) {
- index = i;
- found = true;
- break;
+ const auto& gs = lambda.Tail();
+ auto kind = gs.Head().Content();
+ if (kind == "cube" || kind == "rollup") {
+ TExprNode::TListType indices;
+ for (const auto& expr : gs.Tail().Children()) {
+ auto index = RegisterGroupExpression(expr, lambda.HeadPtr(), child, hashes, groupExprsItems, ctx);
+ indices.push_back(ctx.NewAtom(expr->Pos(), ToString(index)));
+ }
+
+ TExprNode::TListType setsItems;
+ if (kind == "rollup") {
+ // generate N+1 sets
+ for (ui32 i = 0; i <= indices.size(); ++i) {
+ TExprNode::TListType oneSetItems;
+ for (ui32 j = 0; j < i; ++j) {
+ oneSetItems.push_back(indices[j]);
+ }
+
+ setsItems.push_back(ctx.NewList(data.Pos(), std::move(oneSetItems)));
+ }
+ } else {
+ // generate 2**N sets
+ YQL_ENSURE(indices.size() <= 5, "Too many CUBE components");
+ ui32 count = (1u << indices.size());
+ for (ui32 i = 0; i < count; ++i) {
+ TExprNode::TListType oneSetItems;
+ for (ui32 j = 0; j < indices.size(); ++j) {
+ if ((1u << j) & i) {
+ oneSetItems.push_back(indices[j]);
+ }
+ }
+
+ setsItems.push_back(ctx.NewList(data.Pos(), std::move(oneSetItems)));
}
}
- }
- if (!found) {
- index = groupExprsItems.size();
- hashes[hash].push_back(index);
- groupExprsItems.push_back(child);
- }
+ sets = ctx.NewList(data.Pos(), std::move(setsItems));
+ } else {
+ YQL_ENSURE(kind == "sets");
+ TExprNode::TListType setsItems;
+ for (ui32 setIndex = 1; setIndex < gs.ChildrenSize(); ++setIndex) {
+ const auto& g = gs.Child(setIndex);
+ TExprNode::TListType oneSetItems;
+ for (const auto& expr : g->Children()) {
+ auto index = RegisterGroupExpression(expr, lambda.HeadPtr(), child, hashes, groupExprsItems, ctx);
+ oneSetItems.push_back(ctx.NewAtom(expr->Pos(), ToString(index)));
+ }
+
+ setsItems.push_back(ctx.NewList(data.Pos(), std::move(oneSetItems)));
+ }
- groupSetsItems.push_back(ctx.Builder(data.Pos())
+ sets = ctx.NewList(data.Pos(), std::move(setsItems));
+ }
+ } else {
+ auto index = RegisterGroupExpression(lambda.TailPtr(), lambda.HeadPtr(), child, hashes, groupExprsItems, ctx);
+ sets = ctx.Builder(data.Pos())
.List()
.List(0)
.Atom(0, ToString(index))
.Seal()
.Seal()
- .Build());
+ .Build();
}
+
+ groupSetsItems.push_back(sets);
}
groupSets = ctx.NewList(data.Pos(), std::move(groupSetsItems));