diff options
author | ziganshinmr <ziganshinmr@yandex-team.com> | 2024-12-11 16:43:17 +0300 |
---|---|---|
committer | ziganshinmr <ziganshinmr@yandex-team.com> | 2024-12-11 17:43:57 +0300 |
commit | e5fb86995e297dee00808170bd7b6fede4c4172c (patch) | |
tree | 240ec00d10136ac47eab15965e26bc45baa60e73 /yql/essentials/minikql/mkql_program_builder.cpp | |
parent | 0ff827ab7bda8ca3b0aa91a1321eb58adedf6f97 (diff) | |
download | ydb-e5fb86995e297dee00808170bd7b6fede4c4172c.tar.gz |
BlockMapJoinCore computation node
commit_hash:7eaad4219a36a3a486c82cdbf82e7630e59e67f9
Diffstat (limited to 'yql/essentials/minikql/mkql_program_builder.cpp')
-rw-r--r-- | yql/essentials/minikql/mkql_program_builder.cpp | 61 |
1 files changed, 39 insertions, 22 deletions
diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp index 0a4cd30828..691a642814 100644 --- a/yql/essentials/minikql/mkql_program_builder.cpp +++ b/yql/essentials/minikql/mkql_program_builder.cpp @@ -257,16 +257,6 @@ static std::vector<TType*> ValidateBlockItems(const TArrayRef<TType* const>& wid return items; } -std::vector<TType*> ValidateBlockStreamType(const TType* streamType, bool unwrap = true) { - const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType)); - return ValidateBlockItems(wideComponents, unwrap); -} - -std::vector<TType*> ValidateBlockFlowType(const TType* flowType, bool unwrap = true) { - const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType)); - return ValidateBlockItems(wideComponents, unwrap); -} - } // namespace std::string_view ScriptTypeAsStr(EScriptType type) { @@ -331,6 +321,16 @@ void EnsureDataOrOptionalOfData(TRuntimeNode node) { ->GetItemType()->IsData(), "Expected data or optional of data"); } +std::vector<TType*> ValidateBlockStreamType(const TType* streamType, bool unwrap) { + const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType)); + return ValidateBlockItems(wideComponents, unwrap); +} + +std::vector<TType*> ValidateBlockFlowType(const TType* flowType, bool unwrap) { + const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType)); + return ValidateBlockItems(wideComponents, unwrap); +} + TProgramBuilder::TProgramBuilder(const TTypeEnvironment& env, const IFunctionRegistry& functionRegistry, bool voidWithEffects) : TTypeBuilder(env) , FunctionRegistry(functionRegistry) @@ -5827,7 +5827,7 @@ TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode stream, std::optio return FromFlow(BuildBlockCombineHashed(__func__, ToFlow(stream), filterColumn, keys, aggs, flowReturnType)); } else { return BuildBlockCombineHashed(__func__, stream, filterColumn, keys, aggs, returnType); - } + } } TRuntimeNode TProgramBuilder::BuildBlockMergeFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys, @@ -5968,22 +5968,22 @@ TRuntimeNode TProgramBuilder::ScalarApply(const TArrayRef<const TRuntimeNode>& a return TRuntimeNode(builder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode stream, TRuntimeNode dict, - EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns, - const TArrayRef<const ui32>& leftKeyDrops, TType* returnType +TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode leftStream, TRuntimeNode rightStream, EJoinKind joinKind, + const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftKeyDrops, + const TArrayRef<const ui32>& rightKeyColumns, const TArrayRef<const ui32>& rightKeyDrops, bool rightAny, TType* returnType ) { - if constexpr (RuntimeVersion < 51U) { + if constexpr (RuntimeVersion < 53U) { THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; } MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly, "Unsupported join kind"); MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified"); - const THashSet<ui32> leftKeySet(leftKeyColumns.cbegin(), leftKeyColumns.cend()); - for (const auto& drop : leftKeyDrops) { - MKQL_ENSURE(leftKeySet.contains(drop), - "Only key columns has to be specified in drop column set"); - } + MKQL_ENSURE(leftKeyColumns.size() == rightKeyColumns.size(), "Key column count mismatch"); + + ValidateBlockStreamType(leftStream.GetStaticType()); + ValidateBlockStreamType(rightStream.GetStaticType()); + ValidateBlockStreamType(returnType); TRuntimeNode::TList leftKeyColumnsNodes; leftKeyColumnsNodes.reserve(leftKeyColumns.size()); @@ -5999,12 +5999,29 @@ TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode stream, TRuntimeNode return NewDataLiteral(idx); }); + TRuntimeNode::TList rightKeyColumnsNodes; + rightKeyColumnsNodes.reserve(rightKeyColumns.size()); + std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(), + std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) { + return NewDataLiteral(idx); + }); + + TRuntimeNode::TList rightKeyDropsNodes; + rightKeyDropsNodes.reserve(leftKeyDrops.size()); + std::transform(rightKeyDrops.cbegin(), rightKeyDrops.cend(), + std::back_inserter(rightKeyDropsNodes), [this](const ui32 idx) { + return NewDataLiteral(idx); + }); + TCallableBuilder callableBuilder(Env, __func__, returnType); - callableBuilder.Add(stream); - callableBuilder.Add(dict); + callableBuilder.Add(leftStream); + callableBuilder.Add(rightStream); callableBuilder.Add(NewDataLiteral((ui32)joinKind)); callableBuilder.Add(NewTuple(leftKeyColumnsNodes)); callableBuilder.Add(NewTuple(leftKeyDropsNodes)); + callableBuilder.Add(NewTuple(rightKeyColumnsNodes)); + callableBuilder.Add(NewTuple(rightKeyDropsNodes)); + callableBuilder.Add(NewDataLiteral((bool)rightAny)); return TRuntimeNode(callableBuilder.Build(), false); } |