aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/mkql_program_builder.cpp
diff options
context:
space:
mode:
authorziganshinmr <ziganshinmr@yandex-team.com>2024-12-11 16:43:17 +0300
committerziganshinmr <ziganshinmr@yandex-team.com>2024-12-11 17:43:57 +0300
commite5fb86995e297dee00808170bd7b6fede4c4172c (patch)
tree240ec00d10136ac47eab15965e26bc45baa60e73 /yql/essentials/minikql/mkql_program_builder.cpp
parent0ff827ab7bda8ca3b0aa91a1321eb58adedf6f97 (diff)
downloadydb-e5fb86995e297dee00808170bd7b6fede4c4172c.tar.gz
BlockMapJoinCore computation node
commit_hash:7eaad4219a36a3a486c82cdbf82e7630e59e67f9
Diffstat (limited to 'yql/essentials/minikql/mkql_program_builder.cpp')
-rw-r--r--yql/essentials/minikql/mkql_program_builder.cpp61
1 files changed, 39 insertions, 22 deletions
diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp
index 0a4cd30828..691a642814 100644
--- a/yql/essentials/minikql/mkql_program_builder.cpp
+++ b/yql/essentials/minikql/mkql_program_builder.cpp
@@ -257,16 +257,6 @@ static std::vector<TType*> ValidateBlockItems(const TArrayRef<TType* const>& wid
return items;
}
-std::vector<TType*> ValidateBlockStreamType(const TType* streamType, bool unwrap = true) {
- const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType));
- return ValidateBlockItems(wideComponents, unwrap);
-}
-
-std::vector<TType*> ValidateBlockFlowType(const TType* flowType, bool unwrap = true) {
- const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType));
- return ValidateBlockItems(wideComponents, unwrap);
-}
-
} // namespace
std::string_view ScriptTypeAsStr(EScriptType type) {
@@ -331,6 +321,16 @@ void EnsureDataOrOptionalOfData(TRuntimeNode node) {
->GetItemType()->IsData(), "Expected data or optional of data");
}
+std::vector<TType*> ValidateBlockStreamType(const TType* streamType, bool unwrap) {
+ const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType));
+ return ValidateBlockItems(wideComponents, unwrap);
+}
+
+std::vector<TType*> ValidateBlockFlowType(const TType* flowType, bool unwrap) {
+ const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType));
+ return ValidateBlockItems(wideComponents, unwrap);
+}
+
TProgramBuilder::TProgramBuilder(const TTypeEnvironment& env, const IFunctionRegistry& functionRegistry, bool voidWithEffects)
: TTypeBuilder(env)
, FunctionRegistry(functionRegistry)
@@ -5827,7 +5827,7 @@ TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode stream, std::optio
return FromFlow(BuildBlockCombineHashed(__func__, ToFlow(stream), filterColumn, keys, aggs, flowReturnType));
} else {
return BuildBlockCombineHashed(__func__, stream, filterColumn, keys, aggs, returnType);
- }
+ }
}
TRuntimeNode TProgramBuilder::BuildBlockMergeFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys,
@@ -5968,22 +5968,22 @@ TRuntimeNode TProgramBuilder::ScalarApply(const TArrayRef<const TRuntimeNode>& a
return TRuntimeNode(builder.Build(), false);
}
-TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode stream, TRuntimeNode dict,
- EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns,
- const TArrayRef<const ui32>& leftKeyDrops, TType* returnType
+TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode leftStream, TRuntimeNode rightStream, EJoinKind joinKind,
+ const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftKeyDrops,
+ const TArrayRef<const ui32>& rightKeyColumns, const TArrayRef<const ui32>& rightKeyDrops, bool rightAny, TType* returnType
) {
- if constexpr (RuntimeVersion < 51U) {
+ if constexpr (RuntimeVersion < 53U) {
THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
}
MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left ||
joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly,
"Unsupported join kind");
MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
- const THashSet<ui32> leftKeySet(leftKeyColumns.cbegin(), leftKeyColumns.cend());
- for (const auto& drop : leftKeyDrops) {
- MKQL_ENSURE(leftKeySet.contains(drop),
- "Only key columns has to be specified in drop column set");
- }
+ MKQL_ENSURE(leftKeyColumns.size() == rightKeyColumns.size(), "Key column count mismatch");
+
+ ValidateBlockStreamType(leftStream.GetStaticType());
+ ValidateBlockStreamType(rightStream.GetStaticType());
+ ValidateBlockStreamType(returnType);
TRuntimeNode::TList leftKeyColumnsNodes;
leftKeyColumnsNodes.reserve(leftKeyColumns.size());
@@ -5999,12 +5999,29 @@ TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode stream, TRuntimeNode
return NewDataLiteral(idx);
});
+ TRuntimeNode::TList rightKeyColumnsNodes;
+ rightKeyColumnsNodes.reserve(rightKeyColumns.size());
+ std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(),
+ std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) {
+ return NewDataLiteral(idx);
+ });
+
+ TRuntimeNode::TList rightKeyDropsNodes;
+ rightKeyDropsNodes.reserve(leftKeyDrops.size());
+ std::transform(rightKeyDrops.cbegin(), rightKeyDrops.cend(),
+ std::back_inserter(rightKeyDropsNodes), [this](const ui32 idx) {
+ return NewDataLiteral(idx);
+ });
+
TCallableBuilder callableBuilder(Env, __func__, returnType);
- callableBuilder.Add(stream);
- callableBuilder.Add(dict);
+ callableBuilder.Add(leftStream);
+ callableBuilder.Add(rightStream);
callableBuilder.Add(NewDataLiteral((ui32)joinKind));
callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
callableBuilder.Add(NewTuple(leftKeyDropsNodes));
+ callableBuilder.Add(NewTuple(rightKeyColumnsNodes));
+ callableBuilder.Add(NewTuple(rightKeyDropsNodes));
+ callableBuilder.Add(NewDataLiteral((bool)rightAny));
return TRuntimeNode(callableBuilder.Build(), false);
}