diff options
author | whcrc <whcrc@yandex-team.ru> | 2022-05-05 16:01:42 +0300 |
---|---|---|
committer | whcrc <whcrc@yandex-team.ru> | 2022-05-05 16:01:42 +0300 |
commit | f8831bfbf91cd9e5f400dbc59372651bf756d87a (patch) | |
tree | 1b7c4db893a22c91e59ce6f8f226b7b19ce162b6 | |
parent | b3353587c534c52ab70463fa77473a36586dadee (diff) | |
download | ydb-f8831bfbf91cd9e5f400dbc59372651bf756d87a.tar.gz |
YQL-14403: map join, support complex keys
ref:8e217cad01482d5c28fc719148a7625fdd436264
-rw-r--r-- | ydb/core/kqp/ut/kqp_join_ut.cpp | 29 | ||||
-rw-r--r-- | ydb/library/yql/dq/opt/dq_opt_peephole.cpp | 94 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp | 9 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp | 75 |
4 files changed, 193 insertions, 14 deletions
diff --git a/ydb/core/kqp/ut/kqp_join_ut.cpp b/ydb/core/kqp/ut/kqp_join_ut.cpp index 4c238373b20..fb11f968509 100644 --- a/ydb/core/kqp/ut/kqp_join_ut.cpp +++ b/ydb/core/kqp/ut/kqp_join_ut.cpp @@ -1055,6 +1055,35 @@ Y_UNIT_TEST_SUITE(KqpJoin) { CompareYson(R"([[5u]])", FormatResultSetYson(result.GetResultSet(0))); } + Y_UNIT_TEST_NEW_ENGINE(JoinLeftPureInnerConverted) { + TKikimrRunner kikimr; + auto db = kikimr.GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + CreateSampleTables(session); + + auto params = db.GetParamsBuilder() + .AddParam("$rows") + .BeginList() + .AddListItem() + .BeginStruct() + .AddMember("Key").Uint8(1) + .EndStruct() + .EndList() + .Build() + .Build(); + auto result = session.ExecuteDataQuery(Q1_(R"( + DECLARE $rows AS List<Struct<Key: Uint8>>; + + SELECT COUNT(*) + FROM AS_TABLE($rows) AS tl + INNER JOIN `/Root/Join1_1` AS tr + ON tl.Key = tr.Key; -- Uint8 = Int32 + )"), TTxControl::BeginTx().CommitTx(), params).GetValueSync(); + + UNIT_ASSERT_VALUES_EQUAL_C(result.GetStatus(), EStatus::SUCCESS, result.GetIssues().ToString()); + CompareYson(R"([[1u]])", FormatResultSetYson(result.GetResultSet(0))); + } + Y_UNIT_TEST_NEW_ENGINE(JoinLeftPureFull) { TKikimrRunner kikimr; auto db = kikimr.GetTableClient(); diff --git a/ydb/library/yql/dq/opt/dq_opt_peephole.cpp b/ydb/library/yql/dq/opt/dq_opt_peephole.cpp index 87ac170eeac..3d4a4352249 100644 --- a/ydb/library/yql/dq/opt/dq_opt_peephole.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_peephole.cpp @@ -120,6 +120,68 @@ TExprNode::TPtr BuildDictKeySelector(TExprContext& ctx, TPositionHandle pos, con .Done().Ptr(); } +TExprNode::TPtr AddConvertedKeys(TExprNode::TPtr list, TExprContext& ctx, TExprNode::TListType& leftKeyColumnNodes, const TTypeAnnotationNode::TListType& keyTypes, const TStructExprType* origItemType) { + std::vector<std::pair<TString, std::pair<TString, const TTypeAnnotationNode*>>> columnsToConvert; + for (auto i = 0U; i < leftKeyColumnNodes.size(); i++) { + const auto origName = TString(leftKeyColumnNodes[i]->Content()); + auto itemType= origItemType->FindItemType(origName); + if (itemType->Equals(*keyTypes[i])) { + continue; + } + const auto newName = TStringBuilder() << origName << "_map_join_core_key_converted_" << i << "_"; + columnsToConvert.emplace_back(origName, std::pair<TString, const TTypeAnnotationNode*>{newName, keyTypes[i]}); + leftKeyColumnNodes[i] = ctx.NewAtom(leftKeyColumnNodes[i]->Pos(), newName); + } + const auto pos = list->Pos(); + return ctx.Builder(pos) + .Callable("Map") + .Add(0, std::move(list)) + .Lambda(1) + .Param("dict") + .Callable("FlattenMembers") + .List(0) + .Atom(0, "") + .Arg(1, "dict") + .Seal() + .List(1) + .Atom(0, "") + .Callable(1, "AsStruct") + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + auto i = 0U; + for (const auto& [oldName, newCol]: columnsToConvert) { + parent.List(i) + .Atom(0, newCol.first) + .Callable(1, "StrictCast") + .Callable(0, "Member") + .Arg(0, "dict") + .Atom(1, oldName) + .Seal() + .Add(1, ExpandType(pos, *newCol.second, ctx)) + .Seal() + .Seal(); + i++; + } + return parent; + }) + .Seal() + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); +} + +TExprNode::TListType OriginalJoinOutputMembers(const TDqPhyMapJoin& mapJoin, TExprContext& ctx) { + const auto origItemType = mapJoin.Ref().GetTypeAnn()->GetKind() == ETypeAnnotationKind::List ? + mapJoin.Ref().GetTypeAnn()->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>() : + mapJoin.Ref().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TStructExprType>(); + TExprNode::TListType structMembers; + structMembers.reserve(origItemType->GetItems().size()); + for (const auto& item: origItemType->GetItems()) { + structMembers.push_back(ctx.NewAtom(mapJoin.Pos(), item->GetName())); + } + return structMembers; +} } // anonymous namespace end /** @@ -130,6 +192,7 @@ TExprNode::TPtr BuildDictKeySelector(TExprContext& ctx, TPositionHandle pos, con * - Explicitly convert right input to the dict * - Use quite pretty trick: do `MapJoinCore` in `FlatMap`-lambda * (rely on the fact that there will be only one element in the `FlatMap`-stream) + * - Align key types using `StrictCast`, use internal columns to store converted left keys */ TExprBase DqPeepholeRewriteMapJoin(const TExprBase& node, TExprContext& ctx) { if (!node.Maybe<TDqPhyMapJoin>()) { @@ -166,12 +229,14 @@ TExprBase DqPeepholeRewriteMapJoin(const TExprBase& node, TExprContext& ctx) { rightPayloads.emplace_back(*it); } + TTypeAnnotationNode::TListType keyTypesLeft(keyWidth); TTypeAnnotationNode::TListType keyTypes(keyWidth); for (auto i = 0U; i < keyTypes.size(); ++i) { const auto keyTypeLeft = itemTypeLeft->FindItemType(leftKeyColumnNodes[i]->Content()); const auto keyTypeRight = itemTypeRight->FindItemType(rightKeyColumnNodes[i]->Content()); bool optKey = false; keyTypes[i] = JoinDryKeyType(keyTypeLeft, keyTypeRight, optKey, ctx); + keyTypesLeft[i] = optKey ? ctx.MakeType<TOptionalExprType>(keyTypes[i]) : keyTypes[i]; if (!keyTypes[i]) keyTypes.clear(); } @@ -221,20 +286,25 @@ TExprBase DqPeepholeRewriteMapJoin(const TExprBase& node, TExprContext& ctx) { const bool payloads = !rightPayloads.empty(); rightInput = MakeDictForJoin<true>(PrepareListForJoin(std::move(rightInput), keyTypes, rightKeyColumnNodes, rightPayloads, payloads, false, true, ctx), payloads, withRightSide, ctx); - - return Build<TCoFlatMap>(ctx, pos) - .Input(std::move(rightInput)) - .Lambda() - .Args({"dict"}) - .Body<TCoMapJoinCore>() - .LeftInput(std::move(leftInput)) - .RightDict("dict") - .JoinKind(mapJoin.JoinType()) - .LeftKeysColumns(ctx.NewList(pos, std::move(leftKeyColumnNodes))) - .LeftRenames(ctx.NewList(pos, std::move(leftRenames))) - .RightRenames(ctx.NewList(pos, std::move(rightRenames))) + leftInput = AddConvertedKeys(std::move(leftInput), ctx, leftKeyColumnNodes, keyTypesLeft, itemTypeLeft); + return Build<TCoExtractMembers>(ctx, pos) + .Input<TCoFlatMap>() + .Input(std::move(rightInput)) + .Lambda() + .Args({"dict"}) + .Body<TCoMapJoinCore>() + .LeftInput(std::move(leftInput)) + .RightDict("dict") + .JoinKind(mapJoin.JoinType()) + .LeftKeysColumns(ctx.NewList(pos, std::move(leftKeyColumnNodes))) + .LeftRenames(ctx.NewList(pos, std::move(leftRenames))) + .RightRenames(ctx.NewList(pos, std::move(rightRenames))) .Build() .Build() + .Build() + .Members() + .Add(OriginalJoinOutputMembers(mapJoin, ctx)) + .Build() .Done(); } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp index 613958f74b2..457092900c6 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_map_join.cpp @@ -1796,7 +1796,6 @@ IComputationNode* WrapMapJoinCore(TCallable& callable, const TComputationNodeFac const auto dictNode = callable.GetInput(1); const auto dictType = AS_TYPE(TDictType, dictNode); const auto dictKeyType = dictType->GetKeyType(); - const bool isTupleKey = dictKeyType->IsTuple(); const auto joinKindNode = callable.GetInput(2); const auto rawKind = AS_VALUE(TDataLiteral, joinKindNode)->AsValue().Get<ui32>(); const auto kind = GetJoinKind(rawKind); @@ -1806,6 +1805,7 @@ IComputationNode* WrapMapJoinCore(TCallable& callable, const TComputationNodeFac AS_TYPE(TFlowType, callable.GetType()->GetReturnType())->GetItemType(): AS_TYPE(TStreamType, callable.GetType()->GetReturnType())->GetItemType(); const auto leftKeyColumnsNode = AS_VALUE(TTupleLiteral, callable.GetInput(3)); + const bool isTupleKey = leftKeyColumnsNode->GetValuesCount() > 1; const auto leftRenamesNode = AS_VALUE(TTupleLiteral, callable.GetInput(4)); const auto rightRenamesNode = AS_VALUE(TTupleLiteral, callable.GetInput(5)); @@ -1832,10 +1832,15 @@ IComputationNode* WrapMapJoinCore(TCallable& callable, const TComputationNodeFac const auto leftColumnType = leftItemType->IsTuple() ? AS_TYPE(TTupleType, leftItemType)->GetElementType(leftKeyColumns[i]): AS_TYPE(TStructType, leftItemType)->GetMemberType(leftKeyColumns[i]); + const auto rightType = isTupleKey ? AS_TYPE(TTupleType, dictKeyType)->GetElementType(i) : dictKeyType; + bool isOptional; + if (UnpackOptional(leftColumnType, isOptional)->IsSameType(*rightType)) { + continue; + } bool isLeftOptional; const auto leftDataType = UnpackOptionalData(leftColumnType, isLeftOptional); bool isRightOptional; - const auto rightDataType = UnpackOptionalData(isTupleKey ? AS_TYPE(TTupleType, dictKeyType)->GetElementType(i) : dictKeyType, isRightOptional); + const auto rightDataType = UnpackOptionalData(rightType, isRightOptional); if (leftDataType->GetSchemeType() != rightDataType->GetSchemeType()) { // find a converter const std::array<TArgType, 2U> argsTypes = {{{rightDataType->GetSchemeType(), isLeftOptional}, {leftDataType->GetSchemeType(), isLeftOptional}}}; diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp index 437567fd6dc..6adb2643c62 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_map_join_ut.cpp @@ -5,6 +5,81 @@ namespace NKikimr { namespace NMiniKQL { Y_UNIT_TEST_SUITE(TMiniKQLMapJoinCoreTest) { + Y_UNIT_TEST_LLVM(TestInnerOnTuple) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto optionalUi64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id, true); + const auto tupleType = pb.NewTupleType({optionalUi64Type, optionalUi64Type}); + const auto emptyOptionalUi64 = pb.NewEmptyOptional(optionalUi64Type); + + const auto key1 = pb.NewTuple(tupleType, { + pb.NewOptional(pb.NewDataLiteral<ui64>(1)), + pb.NewOptional(pb.NewDataLiteral<ui64>(1)), + }); + const auto key2 = pb.NewTuple(tupleType, { + pb.NewOptional(pb.NewDataLiteral<ui64>(2)), + pb.NewOptional(pb.NewDataLiteral<ui64>(2)), + }); + const auto key3 = pb.NewTuple(tupleType, { + pb.NewOptional(pb.NewDataLiteral<ui64>(3)), + emptyOptionalUi64, + }); + const auto key4 = pb.NewTuple(tupleType, { + pb.NewOptional(pb.NewDataLiteral<ui64>(4)), + pb.NewOptional(pb.NewDataLiteral<ui64>(4)), + }); + const auto payload1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("A"); + const auto payload2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("B"); + const auto payload3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("C"); + const auto payload4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("X"); + const auto payload5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("Y"); + const auto payload6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("Z"); + + const auto structType = pb.NewStructType({ + {"Key", tupleType}, + {"Payload", pb.NewDataType(NUdf::TDataType<char*>::Id)} + }); + + const auto list1 = pb.NewList(structType, { + pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key1), "Payload", payload1), + pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key2), "Payload", payload2), + pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key3), "Payload", payload3) + }); + + const auto list2 = pb.NewList(structType, { + pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key2), "Payload", payload4), + pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key3), "Payload", payload5), + pb.AddMember(pb.AddMember(pb.NewEmptyStruct(), "Key", key4), "Payload", payload6) + }); + const auto dict2 = pb.ToSortedDict(list2, false, + [&](TRuntimeNode item) { + return pb.Member(item, "Key"); + }, + [&](TRuntimeNode item) { + return pb.AddMember(pb.NewEmptyStruct(), "Payload", pb.Member(item, "Payload")); + }); + + const auto resultType = pb.NewFlowType(pb.NewStructType({ + {"Left", pb.NewDataType(NUdf::TDataType<char*>::Id)}, + {"Right", pb.NewDataType(NUdf::TDataType<char*>::Id)}, + })); + + const auto pgmReturn = pb.Collect(pb.MapJoinCore(pb.ToFlow(list1), dict2, EJoinKind::Inner, {0U}, {1U, 0U}, {0U, 1U}, resultType)); + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue tuple; + + UNIT_ASSERT(iterator.Next(tuple)); + UNBOXED_VALUE_STR_EQUAL(tuple.GetElement(0), "B"); + UNBOXED_VALUE_STR_EQUAL(tuple.GetElement(1), "X"); + UNIT_ASSERT(iterator.Next(tuple)); + UNBOXED_VALUE_STR_EQUAL(tuple.GetElement(0), "C"); + UNBOXED_VALUE_STR_EQUAL(tuple.GetElement(1), "Y"); + UNIT_ASSERT(!iterator.Next(tuple)); + UNIT_ASSERT(!iterator.Next(tuple)); + } + Y_UNIT_TEST_LLVM(TestInner) { for (ui32 pass = 0; pass < 1; ++pass) { TSetup<LLVM> setup; |