aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author whcrc <whcrc@yandex-team.ru> 2022-05-05 16:01:42 +0300
committer whcrc <whcrc@yandex-team.ru> 2022-05-05 16:01:42 +0300
commit f8831bfbf91cd9e5f400dbc59372651bf756d87a (patch)
tree 1b7c4db893a22c91e59ce6f8f226b7b19ce162b6
parent b3353587c534c52ab70463fa77473a36586dadee (diff)
download ydb-f8831bfbf91cd9e5f400dbc59372651bf756d87a.tar.gz
YQL-14403: map join, support complex keys
ref:8e217cad01482d5c28fc719148a7625fdd436264
-rw-r--r-- ydb/core/kqp/ut/kqp_join_ut.cpp 29
-rw-r--r-- ydb/library/yql/dq/opt/dq_opt_peephole.cpp 94
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp 9
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp 75
4 files changed, 193 insertions, 14 deletions
diff --git a/ydb/core/kqp/ut/kqp_join_ut.cpp b/ydb/core/kqp/ut/kqp_join_ut.cpp
index 4c238373b20..fb11f968509 100644
--- a/ydb/core/kqp/ut/kqp_join_ut.cpp
+++ b/ydb/core/kqp/ut/kqp_join_ut.cpp
@@ -1055,6 +1055,35 @@ Y_UNIT_TEST_SUITE(KqpJoin) {
CompareYson(R"([[5u]])", FormatResultSetYson(result.GetResultSet(0)));
}
+ Y_UNIT_TEST_NEW_ENGINE(JoinLeftPureInnerConverted) { // INNER JOIN of a parameter-supplied left side with a table when the key types differ (Uint8 vs Int32, per the comment inside the query).
+ TKikimrRunner kikimr;
+ auto db = kikimr.GetTableClient();
+ auto session = db.CreateSession().GetValueSync().GetSession();
+ CreateSampleTables(session); // presumably creates `/Root/Join1_1`, which the query joins against -- confirm in the helper
+
+ auto params = db.GetParamsBuilder() // left side of the join: $rows = [ {Key: Uint8(1)} ]
+ .AddParam("$rows")
+ .BeginList()
+ .AddListItem()
+ .BeginStruct()
+ .AddMember("Key").Uint8(1)
+ .EndStruct()
+ .EndList()
+ .Build() // presumably finishes the "$rows" value and returns to the params builder -- confirm
+ .Build(); // presumably produces the final parameters object -- confirm
+ auto result = session.ExecuteDataQuery(Q1_(R"(
+ DECLARE $rows AS List<Struct<Key: Uint8>>;
+
+ SELECT COUNT(*)
+ FROM AS_TABLE($rows) AS tl
+ INNER JOIN `/Root/Join1_1` AS tr
+ ON tl.Key = tr.Key; -- Uint8 = Int32
+ )"), TTxControl::BeginTx().CommitTx(), params).GetValueSync();
+
+ UNIT_ASSERT_VALUES_EQUAL_C(result.GetStatus(), EStatus::SUCCESS, result.GetIssues().ToString()); // the query must succeed despite the key type mismatch
+ CompareYson(R"([[1u]])", FormatResultSetYson(result.GetResultSet(0))); // COUNT(*) is expected to be 1: the single left row finds a match
+ }
+
Y_UNIT_TEST_NEW_ENGINE(JoinLeftPureFull) {
TKikimrRunner kikimr;
auto db = kikimr.GetTableClient();
diff --git a/ydb/library/yql/dq/opt/dq_opt_peephole.cpp b/ydb/library/yql/dq/opt/dq_opt_peephole.cpp
index 87ac170eeac..3d4a4352249 100644
--- a/ydb/library/yql/dq/opt/dq_opt_peephole.cpp
+++ b/ydb/library/yql/dq/opt/dq_opt_peephole.cpp
@@ -120,6 +120,68 @@ TExprNode::TPtr BuildDictKeySelector(TExprContext& ctx, TPositionHandle pos, con
.Done().Ptr();
}
+/**
+ * Adds `StrictCast`-converted copies of the left join key columns whose type differs
+ * from the expected key type.
+ *
+ * For every index i the type of column `leftKeyColumnNodes[i]` in `origItemType` is
+ * compared with `keyTypes[i]`. On mismatch:
+ *  - a new member named "<origName>_map_join_core_key_converted_<i>_" holding
+ *    `StrictCast(Member(row, origName), keyTypes[i])` is appended to each row;
+ *  - `leftKeyColumnNodes[i]` is replaced in place by an atom carrying the new name.
+ *
+ * @param list               left join input; ownership of the pointer is taken.
+ * @param ctx                expression context used to create new nodes.
+ * @param leftKeyColumnNodes in/out list of key column atoms (see above).
+ * @param keyTypes           expected key type per key column; indexed in parallel with
+ *                           `leftKeyColumnNodes`, so it must be at least as long.
+ * @param origItemType       struct type in which the key columns are looked up; every key
+ *                           column is expected to be found (the lookup result is
+ *                           dereferenced unchecked).
+ * @return `list` unchanged when no key column needs conversion, otherwise a `Map` over
+ *         `list` whose rows carry the extra converted members.
+ */
+TExprNode::TPtr AddConvertedKeys(TExprNode::TPtr list, TExprContext& ctx, TExprNode::TListType& leftKeyColumnNodes, const TTypeAnnotationNode::TListType& keyTypes, const TStructExprType* origItemType) {
+    // (original column name) -> (converted column name, target type)
+    std::vector<std::pair<TString, std::pair<TString, const TTypeAnnotationNode*>>> columnsToConvert;
+    for (auto i = 0U; i < leftKeyColumnNodes.size(); i++) {
+        const auto origName = TString(leftKeyColumnNodes[i]->Content());
+        const auto itemType = origItemType->FindItemType(origName);
+        if (itemType->Equals(*keyTypes[i])) {
+            continue;
+        }
+        // The index is part of the name, so the same column used for two keys gets two distinct members.
+        const auto newName = TStringBuilder() << origName << "_map_join_core_key_converted_" << i << "_";
+        columnsToConvert.emplace_back(origName, std::pair<TString, const TTypeAnnotationNode*>{newName, keyTypes[i]});
+        leftKeyColumnNodes[i] = ctx.NewAtom(leftKeyColumnNodes[i]->Pos(), newName);
+    }
+    if (columnsToConvert.empty()) {
+        // All key columns already have the expected types: do not wrap the input
+        // into a Map that would only rebuild every row unchanged.
+        return list;
+    }
+    const auto pos = list->Pos();
+    return ctx.Builder(pos)
+        .Callable("Map")
+            .Add(0, std::move(list))
+            .Lambda(1)
+                .Param("dict")
+                // Merge the original row with a struct of converted keys; both use an empty member prefix.
+                .Callable("FlattenMembers")
+                    .List(0)
+                        .Atom(0, "")
+                        .Arg(1, "dict")
+                    .Seal()
+                    .List(1)
+                        .Atom(0, "")
+                        .Callable(1, "AsStruct")
+                            .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
+                                auto i = 0U;
+                                for (const auto& [oldName, newCol]: columnsToConvert) {
+                                    parent.List(i)
+                                        .Atom(0, newCol.first)
+                                        .Callable(1, "StrictCast")
+                                            .Callable(0, "Member")
+                                                .Arg(0, "dict")
+                                                .Atom(1, oldName)
+                                            .Seal()
+                                            .Add(1, ExpandType(pos, *newCol.second, ctx))
+                                        .Seal()
+                                    .Seal();
+                                    i++;
+                                }
+                                return parent;
+                            })
+                        .Seal()
+                    .Seal()
+                .Seal()
+            .Seal()
+        .Seal()
+        .Build();
+}
+
+// Returns one atom per member of the map join's annotated output row type, in the order
+// the struct type lists them. The join's type annotation is treated as a List when its
+// kind says so and as a Flow otherwise; the item type is cast to a struct in both cases.
+TExprNode::TListType OriginalJoinOutputMembers(const TDqPhyMapJoin& mapJoin, TExprContext& ctx) {
+    const auto joinType = mapJoin.Ref().GetTypeAnn();
+    const TStructExprType* rowType = nullptr;
+    if (joinType->GetKind() == ETypeAnnotationKind::List) {
+        rowType = joinType->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>();
+    } else {
+        rowType = joinType->Cast<TFlowExprType>()->GetItemType()->Cast<TStructExprType>();
+    }
+
+    const auto& fields = rowType->GetItems();
+    TExprNode::TListType memberAtoms;
+    memberAtoms.reserve(fields.size());
+    for (const auto& field: fields) {
+        memberAtoms.push_back(ctx.NewAtom(mapJoin.Pos(), field->GetName()));
+    }
+    return memberAtoms;
+}
} // anonymous namespace end
/**
@@ -130,6 +192,7 @@ TExprNode::TPtr BuildDictKeySelector(TExprContext& ctx, TPositionHandle pos, con
* - Explicitly convert right input to the dict
* - Use quite pretty trick: do `MapJoinCore` in `FlatMap`-lambda
* (rely on the fact that there will be only one element in the `FlatMap`-stream)
+ * - Align key types using `StrictCast`, use internal columns to store converted left keys
*/
TExprBase DqPeepholeRewriteMapJoin(const TExprBase& node, TExprContext& ctx) {
if (!node.Maybe<TDqPhyMapJoin>()) {
@@ -166,12 +229,14 @@ TExprBase DqPeepholeRewriteMapJoin(const TExprBase& node, TExprContext& ctx) {
rightPayloads.emplace_back(*it);
}
+ TTypeAnnotationNode::TListType keyTypesLeft(keyWidth);
TTypeAnnotationNode::TListType keyTypes(keyWidth);
for (auto i = 0U; i < keyTypes.size(); ++i) {
const auto keyTypeLeft = itemTypeLeft->FindItemType(leftKeyColumnNodes[i]->Content());
const auto keyTypeRight = itemTypeRight->FindItemType(rightKeyColumnNodes[i]->Content());
bool optKey = false;
keyTypes[i] = JoinDryKeyType(keyTypeLeft, keyTypeRight, optKey, ctx);
+ keyTypesLeft[i] = optKey ? ctx.MakeType<TOptionalExprType>(keyTypes[i]) : keyTypes[i];
if (!keyTypes[i])
keyTypes.clear();
}
@@ -221,20 +286,25 @@ TExprBase DqPeepholeRewriteMapJoin(const TExprBase& node, TExprContext& ctx) {
const bool payloads = !rightPayloads.empty();
rightInput = MakeDictForJoin<true>(PrepareListForJoin(std::move(rightInput), keyTypes, rightKeyColumnNodes, rightPayloads, payloads, false, true, ctx), payloads, withRightSide, ctx);
-
- return Build<TCoFlatMap>(ctx, pos)
- .Input(std::move(rightInput))
- .Lambda()
- .Args({"dict"})
- .Body<TCoMapJoinCore>()
- .LeftInput(std::move(leftInput))
- .RightDict("dict")
- .JoinKind(mapJoin.JoinType())
- .LeftKeysColumns(ctx.NewList(pos, std::move(leftKeyColumnNodes)))
- .LeftRenames(ctx.NewList(pos, std::move(leftRenames)))
- .RightRenames(ctx.NewList(pos, std::move(rightRenames)))
+ leftInput = AddConvertedKeys(std::move(leftInput), ctx, leftKeyColumnNodes, keyTypesLeft, itemTypeLeft);
+ return Build<TCoExtractMembers>(ctx, pos)
+ .Input<TCoFlatMap>()
+ .Input(std::move(rightInput))
+ .Lambda()
+ .Args({"dict"})
+ .Body<TCoMapJoinCore>()
+ .LeftInput(std::move(leftInput))
+ .RightDict("dict")
+ .JoinKind(mapJoin.JoinType())
+ .LeftKeysColumns(ctx.NewList(pos, std::move(leftKeyColumnNodes)))
+ .LeftRenames(ctx.NewList(pos, std::move(leftRenames)))
+ .RightRenames(ctx.NewList(pos, std::move(rightRenames)))
.Build()
.Build()
+ .Build()
+ .Members()
+ .Add(OriginalJoinOutputMembers(mapJoin, ctx))
+ .Build()
.Done();
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp
index 613958f74b2..457092900c6 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp
@@ -1796,7 +1796,6 @@ IComputationNode* WrapMapJoinCore(TCallable& callable, const TComputationNodeFac
const auto dictNode = callable.GetInput(1);
const auto dictType = AS_TYPE(TDictType, dictNode);
const auto dictKeyType = dictType->GetKeyType();
- const bool isTupleKey = dictKeyType->IsTuple();
const auto joinKindNode = callable.GetInput(2);
const auto rawKind = AS_VALUE(TDataLiteral, joinKindNode)->AsValue().Get<ui32>();
const auto kind = GetJoinKind(rawKind);
@@ -1806,6 +1805,7 @@ IComputationNode* WrapMapJoinCore(TCallable& callable, const TComputationNodeFac
AS_TYPE(TFlowType, callable.GetType()->GetReturnType())->GetItemType():
AS_TYPE(TStreamType, callable.GetType()->GetReturnType())->GetItemType();
const auto leftKeyColumnsNode = AS_VALUE(TTupleLiteral, callable.GetInput(3));
+ const bool isTupleKey = leftKeyColumnsNode->GetValuesCount() > 1;
const auto leftRenamesNode = AS_VALUE(TTupleLiteral, callable.GetInput(4));
const auto rightRenamesNode = AS_VALUE(TTupleLiteral, callable.GetInput(5));
@@ -1832,10 +1832,15 @@ IComputationNode* WrapMapJoinCore(TCallable& callable, const TComputationNodeFac
const auto leftColumnType = leftItemType->IsTuple() ?
AS_TYPE(TTupleType, leftItemType)->GetElementType(leftKeyColumns[i]):
AS_TYPE(TStructType, leftItemType)->GetMemberType(leftKeyColumns[i]);
+ const auto rightType = isTupleKey ? AS_TYPE(TTupleType, dictKeyType)->GetElementType(i) : dictKeyType;
+ bool isOptional;
+ if (UnpackOptional(leftColumnType, isOptional)->IsSameType(*rightType)) {
+ continue;
+ }
bool isLeftOptional;
const auto leftDataType = UnpackOptionalData(leftColumnType, isLeftOptional);
bool isRightOptional;
- const auto rightDataType = UnpackOptionalData(isTupleKey ? AS_TYPE(TTupleType, dictKeyType)->GetElementType(i) : dictKeyType, isRightOptional);
+ const auto rightDataType = UnpackOptionalData(rightType, isRightOptional);
if (leftDataType->GetSchemeType() != rightDataType->GetSchemeType()) {
// find a converter
const std::array<TArgType, 2U> argsTypes = {{{rightDataType->GetSchemeType(), isLeftOptional}, {leftDataType->GetSchemeType(), isLeftOptional}}};
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp
index 437567fd6dc..6adb2643c62 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp
@@ -5,6 +5,81 @@ namespace NKikimr {
namespace NMiniKQL {
Y_UNIT_TEST_SUITE(TMiniKQLMapJoinCoreTest) {
+ Y_UNIT_TEST_LLVM(TestInnerOnTuple) { // Inner MapJoinCore where the join key column is itself a tuple of two optional ui64 values.
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto optionalUi64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id, true); // `true` presumably requests the optional variant (matches the variable name) -- confirm
+ const auto tupleType = pb.NewTupleType({optionalUi64Type, optionalUi64Type});
+ const auto emptyOptionalUi64 = pb.NewEmptyOptional(optionalUi64Type);
+
+ const auto key1 = pb.NewTuple(tupleType, { // key1 = (1, 1)
+ pb.NewOptional(pb.NewDataLiteral<ui64>(1)),
+ pb.NewOptional(pb.NewDataLiteral<ui64>(1)),
+ });
+ const auto key2 = pb.NewTuple(tupleType, { // key2 = (2, 2)
+ pb.NewOptional(pb.NewDataLiteral<ui64>(2)),
+ pb.NewOptional(pb.NewDataLiteral<ui64>(2)),
+ });
+ const auto key3 = pb.NewTuple(tupleType, { // key3 = (3, <empty optional>)
+ pb.NewOptional(pb.NewDataLiteral<ui64>(3)),
+ emptyOptionalUi64,
+ });
+ const auto key4 = pb.NewTuple(tupleType, { // key4 = (4, 4)
+ pb.NewOptional(pb.NewDataLiteral<ui64>(4)),
+ pb.NewOptional(pb.NewDataLiteral<ui64>(4)),
+ });
+ const auto payload1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("A");
+ const auto payload2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("B");
+ const auto payload3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("C");
+ const auto payload4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("X");
+ const auto payload5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("Y");
+ const auto payload6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("Z");
+
+ const auto structType = pb.NewStructType({ // row type shared by both inputs: {Key: tuple, Payload: string}
+ {"Key", tupleType},
+ {"Payload", pb.NewDataType(NUdf::TDataType<char*>::Id)}
+ });
+
+ const auto list1 = pb.NewList(structType, { // left rows: key1 -> "A", key2 -> "B", key3 -> "C"
+ pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key1), "Payload", payload1),
+ pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key2), "Payload", payload2),
+ pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key3), "Payload", payload3)
+ });
+
+ const auto list2 = pb.NewList(structType, { // right rows: key2 -> "X", key3 -> "Y", key4 -> "Z"
+ pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key2), "Payload", payload4),
+ pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key3), "Payload", payload5),
+ pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key4), "Payload", payload6)
+ });
+ const auto dict2 = pb.ToSortedDict(list2, false, // right side as a dict: Key tuple -> struct{Payload}
+ [&](TRuntimeNode item) {
+ return pb.Member(item, "Key");
+ },
+ [&](TRuntimeNode item) {
+ return pb.AddMember(pb.NewEmptyStruct(), "Payload", pb.Member(item, "Payload"));
+ });
+
+ const auto resultType = pb.NewFlowType(pb.NewStructType({ // joined row: {Left: string, Right: string}
+ {"Left", pb.NewDataType(NUdf::TDataType<char*>::Id)},
+ {"Right", pb.NewDataType(NUdf::TDataType<char*>::Id)},
+ }));
+
+ const auto pgmReturn = pb.Collect(pb.MapJoinCore(pb.ToFlow(list1), dict2, EJoinKind::Inner, {0U}, {1U, 0U}, {0U, 1U}, resultType)); // {0U}: single left key column; {1U, 0U} / {0U, 1U}: presumably left/right rename lists as (source index, output index) pairs -- confirm against MapJoinCore
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue().GetListIterator();
+ NUdf::TUnboxedValue tuple;
+
+ UNIT_ASSERT(iterator.Next(tuple)); // first expected row pairs the key2 payloads
+ UNBOXED_VALUE_STR_EQUAL(tuple.GetElement(0), "B");
+ UNBOXED_VALUE_STR_EQUAL(tuple.GetElement(1), "X");
+ UNIT_ASSERT(iterator.Next(tuple)); // second expected row pairs the key3 payloads, i.e. the key with an empty second component is expected to match
+ UNBOXED_VALUE_STR_EQUAL(tuple.GetElement(0), "C");
+ UNBOXED_VALUE_STR_EQUAL(tuple.GetElement(1), "Y");
+ UNIT_ASSERT(!iterator.Next(tuple)); // no further rows: key1 and key4 exist on one side only
+ UNIT_ASSERT(!iterator.Next(tuple)); // a repeated call must keep reporting the end
+ }
+
Y_UNIT_TEST_LLVM(TestInner) {
for (ui32 pass = 0; pass < 1; ++pass) {
TSetup<LLVM> setup;