diff options
author | aneporada <aneporada@yandex-team.com> | 2024-11-12 14:35:35 +0300 |
---|---|---|
committer | aneporada <aneporada@yandex-team.com> | 2024-11-12 14:49:16 +0300 |
commit | 0f8074b32931e95bc77e99fc0cc06079449a2f03 (patch) | |
tree | addc4b14e22412b769270d75018927719ee8e564 | |
parent | 39b2d17a032a25ea4334f40208471552f1e3561b (diff) | |
download | ydb-0f8074b32931e95bc77e99fc0cc06079449a2f03.tar.gz |
Merge PR #10707: Fixed: Make block combine use stream instead of flow
commit_hash:946462d1ea7e74758c7d6f86cc30cd674dc2195e
-rw-r--r-- | yql/essentials/core/type_ann/type_ann_blocks.cpp | 14 | ||||
-rw-r--r-- | yql/essentials/core/yql_aggregate_expander.cpp | 57 | ||||
-rw-r--r-- | yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp | 1648 | ||||
-rw-r--r-- | yql/essentials/minikql/mkql_program_builder.cpp | 104 | ||||
-rw-r--r-- | yql/essentials/minikql/mkql_program_builder.h | 9 |
5 files changed, 1140 insertions, 692 deletions
diff --git a/yql/essentials/core/type_ann/type_ann_blocks.cpp b/yql/essentials/core/type_ann/type_ann_blocks.cpp index a429ee7ba6..8dcfe1a180 100644 --- a/yql/essentials/core/type_ann/type_ann_blocks.cpp +++ b/yql/essentials/core/type_ann/type_ann_blocks.cpp @@ -791,7 +791,7 @@ IGraphTransformer::TStatus BlockCombineAllWrapper(const TExprNode::TPtr& input, } TTypeAnnotationNode::TListType blockItemTypes; - if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -817,7 +817,7 @@ IGraphTransformer::TStatus BlockCombineAllWrapper(const TExprNode::TPtr& input, } auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType); - input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType)); + input->SetTypeAnn(ctx.Expr.MakeType<TStreamExprType>(outputItemType)); return IGraphTransformer::TStatus::Ok; } @@ -828,7 +828,7 @@ IGraphTransformer::TStatus BlockCombineHashedWrapper(const TExprNode::TPtr& inpu } TTypeAnnotationNode::TListType blockItemTypes; - if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -867,7 +867,7 @@ IGraphTransformer::TStatus BlockCombineHashedWrapper(const TExprNode::TPtr& inpu retMultiType.push_back(ctx.Expr.MakeType<TScalarExprType>(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64))); auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType); - input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType)); + input->SetTypeAnn(ctx.Expr.MakeType<TStreamExprType>(outputItemType)); return IGraphTransformer::TStatus::Ok; } @@ -879,7 +879,7 @@ IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr } TTypeAnnotationNode::TListType blockItemTypes; - if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } YQL_ENSURE(blockItemTypes.size() > 0); @@ -917,7 +917,7 @@ IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr } // disallow any scalar columns except for streamIndex column - auto itemTypes = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems(); + auto itemTypes = input->Head().GetTypeAnn()->Cast<TStreamExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems(); for (ui32 i = 0; i + 1 < itemTypes.size(); ++i) { bool isScalar = itemTypes[i]->GetKind() == ETypeAnnotationKind::Scalar; if (isScalar && i != streamIndex) { @@ -929,7 +929,7 @@ IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr retMultiType.push_back(ctx.Expr.MakeType<TScalarExprType>(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64))); auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType); - input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType)); + input->SetTypeAnn(ctx.Expr.MakeType<TStreamExprType>(outputItemType)); return IGraphTransformer::TStatus::Ok; } diff --git a/yql/essentials/core/yql_aggregate_expander.cpp b/yql/essentials/core/yql_aggregate_expander.cpp index c7068727ef..e268d022e7 100644 --- a/yql/essentials/core/yql_aggregate_expander.cpp +++ b/yql/essentials/core/yql_aggregate_expander.cpp @@ -699,7 +699,8 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() { } else { stream = AggList; } - auto blocks = MakeInputBlocks(stream, keyIdxs, outputColumns, aggs, false, false); + + TExprNode::TPtr blocks = MakeInputBlocks(stream, keyIdxs, outputColumns, aggs, false, false); if (!blocks) { return nullptr; } @@ -708,22 +709,30 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockCombineAllOrHashed() { if (hashed) { aggWideFlow = Ctx.Builder(Node->Pos()) .Callable("WideFromBlocks") - .Callable(0, "BlockCombineHashed") - .Add(0, blocks) - .Callable(1, "Void") + .Callable(0, "ToFlow") + .Callable(0, "BlockCombineHashed") + .Callable(0, "FromFlow") + .Add(0, blocks) + .Seal() + .Callable(1, "Void") + .Seal() + .Add(2, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) + .Add(3, Ctx.NewList(Node->Pos(), std::move(aggs))) .Seal() - .Add(2, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) - .Add(3, Ctx.NewList(Node->Pos(), std::move(aggs))) .Seal() .Seal() .Build(); } else { aggWideFlow = Ctx.Builder(Node->Pos()) - .Callable("BlockCombineAll") - .Add(0, blocks) - .Callable(1, "Void") + .Callable("ToFlow") + .Callable(0, "BlockCombineAll") + .Callable(0, "FromFlow") + .Add(0, blocks) + .Seal() + .Callable(1, "Void") + .Seal() + .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) .Seal() - .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) .Seal() .Build(); } @@ -2891,10 +2900,14 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockMergeFinalizeHashed() { TExprNode::TPtr aggBlocks; if (!isMany) { aggBlocks = Ctx.Builder(Node->Pos()) - .Callable("BlockMergeFinalizeHashed") - .Add(0, blocks) - .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) - .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) + .Callable("ToFlow") + .Callable(0, "BlockMergeFinalizeHashed") + .Callable(0, "FromFlow") + .Add(0, blocks) + .Seal() + .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) + .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) + .Seal() .Seal() .Build(); } else { @@ -2902,12 +2915,16 @@ TExprNode::TPtr TAggregateExpander::TryGenerateBlockMergeFinalizeHashed() { YQL_ENSURE(manyStreamsSetting, "Missing many_streams setting"); aggBlocks = Ctx.Builder(Node->Pos()) - .Callable("BlockMergeManyFinalizeHashed") - .Add(0, blocks) - .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) - .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) - .Atom(3, ToString(streamIdxColumn)) - .Add(4, manyStreamsSetting->TailPtr()) + .Callable("ToFlow") + .Callable(0, "BlockMergeManyFinalizeHashed") + .Callable(0, "FromFlow") + .Add(0, blocks) + .Seal() + .Add(1, Ctx.NewList(Node->Pos(), std::move(keyIdxs))) + .Add(2, Ctx.NewList(Node->Pos(), std::move(aggs))) + .Atom(3, ToString(streamIdxColumn)) + .Add(4, manyStreamsSetting->TailPtr()) + .Seal() .Seal() .Build(); } diff --git a/yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp b/yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp index 2c22e4eec6..225c5249ca 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp @@ -442,9 +442,20 @@ size_t GetBitmapPopCount(const std::shared_ptr<arrow::ArrayData>& arr) { return GetSparseBitmapPopCount(src, len); } +TArrayRef<TType *const> GetWideComponents(TType* type) { + if (type->IsFlow()) { + const auto outputFlowType = AS_TYPE(TFlowType, type); + return GetWideComponents(outputFlowType); + } + if (type->IsStream()) { + const auto outputStreamType = AS_TYPE(TStreamType, type); + return GetWideComponents(outputStreamType); + } + MKQL_ENSURE(false, "Expect either flow or stream"); +} + size_t CalcMaxBlockLenForOutput(TType* out) { - const auto outputType = AS_TYPE(TFlowType, out); - const auto wideComponents = GetWideComponents(outputType); + const auto wideComponents = GetWideComponents(out); MKQL_ENSURE(wideComponents.size() > 0, "Expecting at least one output column"); size_t maxBlockItemSize = 0; @@ -604,11 +615,99 @@ protected: #endif }; -class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>, + +struct TBlockCombineAllState : public TComputationValue<TBlockCombineAllState> { + NUdf::TUnboxedValue* Pointer_ = nullptr; + bool IsFinished_ = false; + bool HasValues_ = false; + TUnboxedValueVector Values_; + std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_; + std::vector<char> AggStates_; + const std::optional<ui32> FilterColumn_; + const size_t Width_; + + TBlockCombineAllState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx) + : TComputationValue(memInfo) + , Values_(std::max(width, params.size())) + , FilterColumn_(filterColumn) + , Width_(width) + { + Pointer_ = Values_.data(); + + ui32 totalStateSize = 0; + for (const auto& p : params) { + Aggs_.emplace_back(p.Prepared_->Make(ctx)); + MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch"); + totalStateSize += Aggs_.back()->StateSize; + } + + AggStates_.resize(totalStateSize); + char* ptr = AggStates_.data(); + for (const auto& agg : Aggs_) { + agg->InitState(ptr); + ptr += agg->StateSize; + } + } + + void ProcessInput() { + const ui64 batchLength = TArrowBlock::From(Values_[Width_ - 1U]).GetDatum().scalar_as<arrow::UInt64Scalar>().value; + if (!batchLength) { + return; + } + + std::optional<ui64> filtered; + if (FilterColumn_) { + const auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum(); + if (filterDatum.is_scalar()) { + if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) { + return; + } + } else { + const ui64 popCount = GetBitmapPopCount(filterDatum.array()); + if (popCount == 0) { + return; + } + + if (popCount < batchLength) { + filtered = popCount; + } + } + } + + HasValues_ = true; + char* ptr = AggStates_.data(); + for (size_t i = 0; i < Aggs_.size(); ++i) { + Aggs_[i]->AddMany(ptr, Values_.data(), batchLength, filtered); + ptr += Aggs_[i]->StateSize; + } + } + + bool MakeOutput() { + IsFinished_ = true; + if (!HasValues_) + return false; + + char* ptr = AggStates_.data(); + for (size_t i = 0; i < Aggs_.size(); ++i) { + Values_[i] = Aggs_[i]->FinishOne(ptr); + Aggs_[i]->DestroyState(ptr); + ptr += Aggs_[i]->StateSize; + } + return true; + } + + NUdf::TUnboxedValuePod Get(size_t index) const { + return Values_[index]; + } +}; + +class TBlockCombineAllWrapperFromFlow : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapperFromFlow>, protected TBlockCombineAllWrapperCodegenBase { -using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>; +using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapperFromFlow>; + +using TState = TBlockCombineAllState; public: - TBlockCombineAllWrapper(TComputationMutables& mutables, + TBlockCombineAllWrapperFromFlow(TComputationMutables& mutables, IComputationWideFlowNode* flow, std::optional<ui32> filterColumn, size_t width, @@ -655,95 +754,11 @@ public: #ifndef MKQL_DISABLE_CODEGEN ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, AggsParams_.size(), - GetMethodPtr(&TState::Get), GetMethodPtr(&TBlockCombineAllWrapper::MakeState), + GetMethodPtr(&TState::Get), GetMethodPtr(&TBlockCombineAllWrapperFromFlow::MakeState), GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::MakeOutput)); } #endif private: - struct TState : public TComputationValue<TState> { - NUdf::TUnboxedValue* Pointer_ = nullptr; - bool IsFinished_ = false; - bool HasValues_ = false; - TUnboxedValueVector Values_; - std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_; - std::vector<char> AggStates_; - const std::optional<ui32> FilterColumn_; - const size_t Width_; - - TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx) - : TComputationValue(memInfo) - , Values_(std::max(width, params.size())) - , FilterColumn_(filterColumn) - , Width_(width) - { - Pointer_ = Values_.data(); - - ui32 totalStateSize = 0; - for (const auto& p : params) { - Aggs_.emplace_back(p.Prepared_->Make(ctx)); - MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch"); - totalStateSize += Aggs_.back()->StateSize; - } - - AggStates_.resize(totalStateSize); - char* ptr = AggStates_.data(); - for (const auto& agg : Aggs_) { - agg->InitState(ptr); - ptr += agg->StateSize; - } - } - - void ProcessInput() { - const ui64 batchLength = TArrowBlock::From(Values_[Width_ - 1U]).GetDatum().scalar_as<arrow::UInt64Scalar>().value; - if (!batchLength) { - return; - } - - std::optional<ui64> filtered; - if (FilterColumn_) { - const auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum(); - if (filterDatum.is_scalar()) { - if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) { - return; - } - } else { - const ui64 popCount = GetBitmapPopCount(filterDatum.array()); - if (popCount == 0) { - return; - } - - if (popCount < batchLength) { - filtered = popCount; - } - } - } - - HasValues_ = true; - char* ptr = AggStates_.data(); - for (size_t i = 0; i < Aggs_.size(); ++i) { - Aggs_[i]->AddMany(ptr, Values_.data(), batchLength, filtered); - ptr += Aggs_[i]->StateSize; - } - } - - bool MakeOutput() { - IsFinished_ = true; - if (!HasValues_) - return false; - - char* ptr = AggStates_.data(); - for (size_t i = 0; i < Aggs_.size(); ++i) { - Values_[i] = Aggs_[i]->FinishOne(ptr); - Aggs_[i]->DestroyState(ptr); - ptr += Aggs_[i]->StateSize; - } - return true; - } - - NUdf::TUnboxedValuePod Get(size_t index) const { - return Values_[index]; - } - }; void RegisterDependencies() const final { FlowDependsOn(Flow_); } @@ -773,6 +788,89 @@ private: const size_t WideFieldsIndex_; }; +class TBlockCombineAllWrapperFromStream : public TMutableComputationNode<TBlockCombineAllWrapperFromStream> { +using TBaseComputation = TMutableComputationNode<TBlockCombineAllWrapperFromStream>; + +using TState = TBlockCombineAllState; +public: + TBlockCombineAllWrapperFromStream(TComputationMutables& mutables, + IComputationNode* stream, + std::optional<ui32> filterColumn, + size_t width, + std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams) + : TBaseComputation(mutables, EValueRepresentation::Boxed) + , Stream_(stream) + , FilterColumn_(filterColumn) + , Width_(width) + , AggsParams_(std::move(aggsParams)) + , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) + { + MKQL_ENSURE(Width_ > 0, "Missing block length column"); + } + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const + { + const auto state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx); + return ctx.HolderFactory.Create<TStreamValue>(std::move(state), std::move(Stream_->GetValue(ctx))); + } + +private: + class TStreamValue : public TComputationValue<TStreamValue> { + using TBase = TComputationValue<TStreamValue>; + public: + TStreamValue(TMemoryUsageInfo* memInfo, NUdf::TUnboxedValue&& state, NUdf::TUnboxedValue&& stream) + : TBase(memInfo) + , State_(state) + , Stream_(stream) + { + } + + private: + NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) { + TState& state = *static_cast<TState*>(State_.AsBoxed().Get()); + auto* inputFields = state.Values_.data(); + const size_t inputWidth = state.Width_; + + if (state.IsFinished_) + return NUdf::EFetchStatus::Finish; + + while (true) { + switch (Stream_.WideFetch(inputFields, inputWidth)) { + case NUdf::EFetchStatus::Yield: + return NUdf::EFetchStatus::Yield; + case NUdf::EFetchStatus::Ok: + state.ProcessInput(); + continue; + case NUdf::EFetchStatus::Finish: + break; + } + if (state.MakeOutput()) { + for (size_t i = 0; i < width; ++i) { + output[i] = state.Get(i); + } + return NUdf::EFetchStatus::Ok; + } + return NUdf::EFetchStatus::Finish; + } + } + private: + NUdf::TUnboxedValue State_; + NUdf::TUnboxedValue Stream_; + }; + +private: + void RegisterDependencies() const final { + DependsOn(Stream_); + } + +private: + IComputationNode *const Stream_; + const std::optional<ui32> FilterColumn_; + const size_t Width_; + const std::vector<TAggParams<IBlockAggregatorCombineAll>> AggsParams_; + const size_t WideFieldsIndex_; +}; + template <typename T> T MakeKey(TStringBuf s, ui32 keyLength) { Y_UNUSED(keyLength); @@ -1050,585 +1148,594 @@ protected: }; template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived> -class THashedWrapperBase : public TStatefulWideFlowCodegeneratorNode<TDerived>, - protected THashedWrapperCodegenBase -{ - using TComputationBase = TStatefulWideFlowCodegeneratorNode<TDerived>; +struct THashedWrapperBaseState : public TBlockState { +private: static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value; public: - THashedWrapperBase(TComputationMutables& mutables, - IComputationWideFlowNode* flow, - std::optional<ui32> filterColumn, - size_t width, - const std::vector<TKeyParams>& keys, - size_t maxBlockLen, - ui32 keyLength, - std::vector<TAggParams<TAggregator>>&& aggsParams, - ui32 streamIndex, - std::vector<std::vector<ui32>>&& streams) - : TComputationBase(mutables, flow, EValueRepresentation::Boxed) - , Flow_(flow) + bool WritingOutput_ = false; + bool IsFinished_ = false; + + const std::optional<ui32> FilterColumn_; + const std::vector<TKeyParams> Keys_; + const std::vector<TAggParams<TAggregator>>& AggsParams_; + const ui32 KeyLength_; + const ui32 StreamIndex_; + const std::vector<std::vector<ui32>> Streams_; + const size_t MaxBlockLen_; + const size_t Width_; + const size_t OutputWidth_; + + template<typename TKeyType> + struct THashSettings { + static constexpr bool CacheHash = std::is_same_v<TKeyType, TSSOKey>; + }; + using TDynMapImpl = TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>; + using TSetImpl = THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>; + using TFixedMapImpl = TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>; + + ui64 BatchNum_ = 0; + TUnboxedValueVector Values_; + std::vector<std::unique_ptr<TAggregator>> Aggs_; + std::vector<ui32> AggStateOffsets_; + TUnboxedValueVector UnwrappedValues_; + std::vector<std::unique_ptr<IBlockReader>> Readers_; + std::vector<std::unique_ptr<IArrayBuilder>> Builders_; + std::vector<std::unique_ptr<IAggColumnBuilder>> AggBuilders_; + bool HasValues_ = false; + ui32 TotalStateSize_ = 0; + size_t OutputBlockSize_ = 0; + std::unique_ptr<TDynMapImpl> HashMap_; + typename TDynMapImpl::const_iterator HashMapIt_; + std::unique_ptr<TSetImpl> HashSet_; + typename TSetImpl::const_iterator HashSetIt_; + std::unique_ptr<TFixedMapImpl> HashFixedMap_; + typename TFixedMapImpl::const_iterator HashFixedMapIt_; + TPagedArena Arena_; + + THashedWrapperBaseState(TMemoryUsageInfo* memInfo, ui32 keyLength, ui32 streamIndex, size_t width, size_t outputWidth, std::optional<ui32> filterColumn, const std::vector<TAggParams<TAggregator>>& params, + const std::vector<std::vector<ui32>>& streams, const std::vector<TKeyParams>& keys, size_t maxBlockLen, TComputationContext& ctx) + : TBlockState(memInfo, outputWidth) , FilterColumn_(filterColumn) - , Width_(width) - , OutputWidth_(keys.size() + aggsParams.size() + 1) - , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) , Keys_(keys) - , MaxBlockLen_(maxBlockLen) - , AggsParams_(std::move(aggsParams)) + , AggsParams_(params) , KeyLength_(keyLength) , StreamIndex_(streamIndex) - , Streams_(std::move(streams)) + , Streams_(streams) + , MaxBlockLen_(maxBlockLen) + , Width_(width) + , OutputWidth_(outputWidth) + , Values_(width) + , UnwrappedValues_(width) + , Readers_(keys.size()) + , Builders_(keys.size()) + , Arena_(TlsAllocState) { - MKQL_ENSURE(Width_ > 0, "Missing block length column"); - if constexpr (UseFilter) { - MKQL_ENSURE(filterColumn, "Missing filter column"); - MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize"); - } else { - MKQL_ENSURE(!filterColumn, "Unexpected filter column"); + Pointer_ = Values_.data(); + for (size_t i = 0; i < Keys_.size(); ++i) { + auto itemType = AS_TYPE(TBlockType, Keys_[i].Type)->GetItemType(); + Readers_[i] = NYql::NUdf::MakeBlockReader(TTypeInfoHelper(), itemType); + Builders_[i] = NYql::NUdf::MakeArrayBuilder(TTypeInfoHelper(), itemType, ctx.ArrowMemoryPool, MaxBlockLen_, &ctx.Builder->GetPgBuilder()); } - } - - EFetchResult DoCalculate(NUdf::TUnboxedValue& state, - TComputationContext& ctx, - NUdf::TUnboxedValue*const* output) const - { - auto& s = GetState(state, ctx); - if (!s.Count) { - if (s.IsFinished_) - return EFetchResult::Finish; - - while (!s.WritingOutput_) { - const auto fields = ctx.WideFields.data() + WideFieldsIndex_; - s.Values_.assign(s.Values_.size(), NUdf::TUnboxedValuePod()); - switch (Flow_->FetchValues(ctx, fields)) { - case EFetchResult::Yield: - return EFetchResult::Yield; - case EFetchResult::One: - s.ProcessInput(ctx.HolderFactory); - continue; - case EFetchResult::Finish: - break; - } - if (s.Finish()) - break; - else - return EFetchResult::Finish; - } + if constexpr (Many) { + TotalStateSize_ += Streams_.size(); + } - if (!s.FillOutput(ctx.HolderFactory)) - return EFetchResult::Finish; + for (const auto& p : AggsParams_) { + Aggs_.emplace_back(p.Prepared_->Make(ctx)); + MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch"); + AggStateOffsets_.emplace_back(TotalStateSize_); + TotalStateSize_ += Aggs_.back()->StateSize; } - const auto sliceSize = s.Slice(); - for (size_t i = 0; i < OutputWidth_; ++i) { - if (const auto out = output[i]) { - *out = s.Get(sliceSize, ctx.HolderFactory, i); + auto equal = MakeEqual<TKey>(KeyLength_); + auto hasher = MakeHash<TKey>(KeyLength_); + if constexpr (UseSet) { + MKQL_ENSURE(params.empty(), "Only keys are supported"); + HashSet_ = std::make_unique<THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(hasher, equal); + } else { + if (!InlineAggState) { + HashFixedMap_ = std::make_unique<TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(hasher, equal); + } else { + HashMap_ = std::make_unique<TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(TotalStateSize_, hasher, equal); } } - return EFetchResult::One; } -#ifndef MKQL_DISABLE_CODEGEN - ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { - return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, OutputWidth_, - GetMethodPtr(&TState::Get), GetMethodPtr(&THashedWrapperBase::MakeState), - GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::Finish), - GetMethodPtr(&TState::FillOutput), GetMethodPtr(&TState::Slice)); - } -#endif -private: - struct TState : public TBlockState { - bool WritingOutput_ = false; - bool IsFinished_ = false; - - const std::optional<ui32> FilterColumn_; - const std::vector<TKeyParams> Keys_; - const std::vector<TAggParams<TAggregator>>& AggsParams_; - const ui32 KeyLength_; - const ui32 StreamIndex_; - const std::vector<std::vector<ui32>> Streams_; - const size_t MaxBlockLen_; - - template<typename TKeyType> - struct THashSettings { - static constexpr bool CacheHash = std::is_same_v<TKeyType, TSSOKey>; - }; - using TDynMapImpl = TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>; - using TSetImpl = THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>; - using TFixedMapImpl = TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>; - - ui64 BatchNum_ = 0; - TUnboxedValueVector Values_; - std::vector<std::unique_ptr<TAggregator>> Aggs_; - std::vector<ui32> AggStateOffsets_; - TUnboxedValueVector UnwrappedValues_; - std::vector<std::unique_ptr<IBlockReader>> Readers_; - std::vector<std::unique_ptr<IArrayBuilder>> Builders_; - std::vector<std::unique_ptr<IAggColumnBuilder>> AggBuilders_; - bool HasValues_ = false; - ui32 TotalStateSize_ = 0; - size_t OutputBlockSize_ = 0; - std::unique_ptr<TDynMapImpl> HashMap_; - typename TDynMapImpl::const_iterator HashMapIt_; - std::unique_ptr<TSetImpl> HashSet_; - typename TSetImpl::const_iterator HashSetIt_; - std::unique_ptr<TFixedMapImpl> HashFixedMap_; - typename TFixedMapImpl::const_iterator HashFixedMapIt_; - TPagedArena Arena_; - - TState(TMemoryUsageInfo* memInfo, ui32 keyLength, ui32 streamIndex, size_t width, size_t outputWidth, std::optional<ui32> filterColumn, const std::vector<TAggParams<TAggregator>>& params, - const std::vector<std::vector<ui32>>& streams, const std::vector<TKeyParams>& keys, size_t maxBlockLen, TComputationContext& ctx) - : TBlockState(memInfo, outputWidth) - , FilterColumn_(filterColumn) - , Keys_(keys) - , AggsParams_(params) - , KeyLength_(keyLength) - , StreamIndex_(streamIndex) - , Streams_(streams) - , MaxBlockLen_(maxBlockLen) - , Values_(width) - , UnwrappedValues_(width) - , Readers_(keys.size()) - , Builders_(keys.size()) - , Arena_(TlsAllocState) - { - Pointer_ = Values_.data(); - for (size_t i = 0; i < Keys_.size(); ++i) { - auto itemType = AS_TYPE(TBlockType, Keys_[i].Type)->GetItemType(); - Readers_[i] = NYql::NUdf::MakeBlockReader(TTypeInfoHelper(), itemType); - Builders_[i] = NYql::NUdf::MakeArrayBuilder(TTypeInfoHelper(), itemType, ctx.ArrowMemoryPool, MaxBlockLen_, &ctx.Builder->GetPgBuilder()); - } - - if constexpr (Many) { - TotalStateSize_ += Streams_.size(); - } - for (const auto& p : AggsParams_) { - Aggs_.emplace_back(p.Prepared_->Make(ctx)); - MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch"); - AggStateOffsets_.emplace_back(TotalStateSize_); - TotalStateSize_ += Aggs_.back()->StateSize; - } + void ProcessInput(const THolderFactory& holderFactory) { + ++BatchNum_; + const auto batchLength = TArrowBlock::From(Values_.back()).GetDatum().scalar_as<arrow::UInt64Scalar>().value; + if (!batchLength) { + return; + } - auto equal = MakeEqual<TKey>(KeyLength_); - auto hasher = MakeHash<TKey>(KeyLength_); - if constexpr (UseSet) { - MKQL_ENSURE(params.empty(), "Only keys are supported"); - HashSet_ = std::make_unique<THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(hasher, equal); + const ui8* filterBitmap = nullptr; + if constexpr (UseFilter) { + auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum(); + if (filterDatum.is_scalar()) { + if (!filterDatum.template scalar_as<arrow::UInt8Scalar>().value) { + return; + } } else { - if (!InlineAggState) { - HashFixedMap_ = std::make_unique<TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(hasher, equal); - } else { - HashMap_ = std::make_unique<TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(TotalStateSize_, hasher, equal); + const auto& arr = filterDatum.array(); + filterBitmap = arr->template GetValues<ui8>(1); + ui64 popCount = GetBitmapPopCount(arr); + if (popCount == 0) { + return; } } } - void ProcessInput(const THolderFactory& holderFactory) { - ++BatchNum_; - const auto batchLength = TArrowBlock::From(Values_.back()).GetDatum().scalar_as<arrow::UInt64Scalar>().value; - if (!batchLength) { - return; + const ui32* streamIndexData = nullptr; + TMaybe<ui32> streamIndexScalar; + if constexpr (Many) { + auto streamIndexDatum = TArrowBlock::From(Values_[StreamIndex_]).GetDatum(); + if (streamIndexDatum.is_scalar()) { + streamIndexScalar = streamIndexDatum.template scalar_as<arrow::UInt32Scalar>().value; + } else { + MKQL_ENSURE(streamIndexDatum.is_array(), "Expected array"); + streamIndexData = streamIndexDatum.array()->template GetValues<ui32>(1); } - - const ui8* filterBitmap = nullptr; - if constexpr (UseFilter) { - auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum(); - if (filterDatum.is_scalar()) { - if (!filterDatum.template scalar_as<arrow::UInt8Scalar>().value) { - return; - } - } else { - const auto& arr = filterDatum.array(); - filterBitmap = arr->template GetValues<ui8>(1); - ui64 popCount = GetBitmapPopCount(arr); - if (popCount == 0) { - return; - } - } + UnwrappedValues_ = Values_; + for (const auto& p : AggsParams_) { + const auto& columnDatum = TArrowBlock::From(UnwrappedValues_[p.Column_]).GetDatum(); + MKQL_ENSURE(columnDatum.is_array(), "Expected array"); + UnwrappedValues_[p.Column_] = holderFactory.CreateArrowBlock(Unwrap(*columnDatum.array(), p.StateType_)); } + } - const ui32* streamIndexData = nullptr; - TMaybe<ui32> streamIndexScalar; - if constexpr (Many) { - auto streamIndexDatum = TArrowBlock::From(Values_[StreamIndex_]).GetDatum(); - if (streamIndexDatum.is_scalar()) { - streamIndexScalar = streamIndexDatum.template scalar_as<arrow::UInt32Scalar>().value; - } else { - MKQL_ENSURE(streamIndexDatum.is_array(), "Expected array"); - streamIndexData = streamIndexDatum.array()->template GetValues<ui32>(1); - } - UnwrappedValues_ = Values_; - for (const auto& p : AggsParams_) { - const auto& columnDatum = TArrowBlock::From(UnwrappedValues_[p.Column_]).GetDatum(); - MKQL_ENSURE(columnDatum.is_array(), "Expected array"); - UnwrappedValues_[p.Column_] = holderFactory.CreateArrowBlock(Unwrap(*columnDatum.array(), p.StateType_)); - } - } + HasValues_ = true; + std::vector<arrow::Datum> keysDatum; + keysDatum.reserve(Keys_.size()); + for (ui32 i = 0; i < Keys_.size(); ++i) { + keysDatum.emplace_back(TArrowBlock::From(Values_[Keys_[i].Index]).GetDatum()); + } - HasValues_ = true; - std::vector<arrow::Datum> keysDatum; - keysDatum.reserve(Keys_.size()); - for (ui32 i = 0; i < Keys_.size(); ++i) { - keysDatum.emplace_back(TArrowBlock::From(Values_[Keys_[i].Index]).GetDatum()); - } + std::array<TOutputBuffer, PrefetchBatchSize> out; + for (ui32 i = 0; i < PrefetchBatchSize; ++i) { + out[i].Resize(sizeof(TKey)); + } - std::array<TOutputBuffer, PrefetchBatchSize> out; - for (ui32 i = 0; i < PrefetchBatchSize; ++i) { - out[i].Resize(sizeof(TKey)); + std::array<TRobinHoodBatchRequestItem<TKey>, PrefetchBatchSize> insertBatch; + std::array<ui64, PrefetchBatchSize> insertBatchRows; + std::array<char*, PrefetchBatchSize> insertBatchPayloads; + std::array<bool, PrefetchBatchSize> insertBatchIsNew; + ui32 insertBatchLen = 0; + + const auto processInsertBatch = [&]() { + for (ui32 i = 0; i < insertBatchLen; ++i) { + auto& r = insertBatch[i]; + TStringBuf str = out[i].Finish(); + TKey key = MakeKey<TKey>(str, KeyLength_); + r.ConstructKey(key); } - std::array<TRobinHoodBatchRequestItem<TKey>, PrefetchBatchSize> insertBatch; - std::array<ui64, PrefetchBatchSize> insertBatchRows; - std::array<char*, PrefetchBatchSize> insertBatchPayloads; - std::array<bool, PrefetchBatchSize> insertBatchIsNew; - ui32 insertBatchLen = 0; - - const auto processInsertBatch = [&]() { - for (ui32 i = 0; i < insertBatchLen; ++i) { - auto& r = insertBatch[i]; - TStringBuf str = out[i].Finish(); - TKey key = MakeKey<TKey>(str, KeyLength_); - r.ConstructKey(key); + if constexpr (UseSet) { + HashSet_->BatchInsert({insertBatch.data(), insertBatchLen},[&](size_t index, typename THashedWrapperBaseState::TSetImpl::iterator iter, bool isNew) { + Y_UNUSED(index); + if (isNew) { + if constexpr (std::is_same<TKey, TSSOKey>::value || std::is_same<TKey, TExternalFixedSizeKey>::value) { + MoveKeyToArena(HashSet_->GetKey(iter), Arena_, KeyLength_); + } + } + }); + } else { + using THashTable = std::conditional_t<InlineAggState, typename THashedWrapperBaseState::TDynMapImpl, typename THashedWrapperBaseState::TFixedMapImpl>; + THashTable* hash; + if constexpr (!InlineAggState) { + hash = HashFixedMap_.get(); + } else { + hash = HashMap_.get(); } - if constexpr (UseSet) { - HashSet_->BatchInsert({insertBatch.data(), insertBatchLen},[&](size_t index, typename TState::TSetImpl::iterator iter, bool isNew) { - Y_UNUSED(index); - if (isNew) { - if constexpr (std::is_same<TKey, TSSOKey>::value || std::is_same<TKey, TExternalFixedSizeKey>::value) { - MoveKeyToArena(HashSet_->GetKey(iter), Arena_, KeyLength_); - } + hash->BatchInsert({insertBatch.data(), insertBatchLen}, [&](size_t index, typename THashTable::iterator iter, bool isNew) { + if (isNew) { + if constexpr (std::is_same<TKey, TSSOKey>::value || std::is_same<TKey, TExternalFixedSizeKey>::value) { + MoveKeyToArena(hash->GetKey(iter), Arena_, KeyLength_); } - }); - } else { - using THashTable = std::conditional_t<InlineAggState, typename TState::TDynMapImpl, typename TState::TFixedMapImpl>; - THashTable* hash; - if constexpr (!InlineAggState) { - hash = HashFixedMap_.get(); - } else { - hash = HashMap_.get(); } - hash->BatchInsert({insertBatch.data(), insertBatchLen}, [&](size_t index, typename THashTable::iterator iter, bool isNew) { + if constexpr (UseArena) { + // prefetch payloads only + auto payload = hash->GetPayload(iter); + char* ptr; if (isNew) { - if constexpr (std::is_same<TKey, TSSOKey>::value || std::is_same<TKey, TExternalFixedSizeKey>::value) { - MoveKeyToArena(hash->GetKey(iter), Arena_, KeyLength_); - } - } - - if constexpr (UseArena) { - // prefetch payloads only - auto payload = hash->GetPayload(iter); - char* ptr; - if (isNew) { - ptr = (char*)Arena_.Alloc(TotalStateSize_); - *(char**)payload = ptr; - } else { - ptr = *(char**)payload; - } - - insertBatchIsNew[index] = isNew; - insertBatchPayloads[index] = ptr; - NYql::PrefetchForWrite(ptr); + ptr = (char*)Arena_.Alloc(TotalStateSize_); + *(char**)payload = ptr; } else { - // process insert - auto payload = (char*)hash->GetPayload(iter); - auto row = insertBatchRows[index]; - ui32 streamIndex = 0; - if constexpr (Many) { - streamIndex = streamIndexScalar ? *streamIndexScalar : streamIndexData[row]; - } - - Insert(row, payload, isNew, streamIndex); + ptr = *(char**)payload; } - }); - if constexpr (UseArena) { - for (ui32 i = 0; i < insertBatchLen; ++i) { - auto row = insertBatchRows[i]; - ui32 streamIndex = 0; - if constexpr (Many) { - streamIndex = streamIndexScalar ? *streamIndexScalar : streamIndexData[row]; - } - - bool isNew = insertBatchIsNew[i]; - char* payload = insertBatchPayloads[i]; - Insert(row, payload, isNew, streamIndex); + insertBatchIsNew[index] = isNew; + insertBatchPayloads[index] = ptr; + NYql::PrefetchForWrite(ptr); + } else { + // process insert + auto payload = (char*)hash->GetPayload(iter); + auto row = insertBatchRows[index]; + ui32 streamIndex = 0; + if constexpr (Many) { + streamIndex = streamIndexScalar ? *streamIndexScalar : streamIndexData[row]; } + + Insert(row, payload, isNew, streamIndex); } - } - }; + }); - for (ui64 row = 0; row < batchLength; ++row) { - if constexpr (UseFilter) { - if (filterBitmap && !filterBitmap[row]) { - continue; + if constexpr (UseArena) { + for (ui32 i = 0; i < insertBatchLen; ++i) { + auto row = insertBatchRows[i]; + ui32 streamIndex = 0; + if constexpr (Many) { + streamIndex = streamIndexScalar ? *streamIndexScalar : streamIndexData[row]; + } + + bool isNew = insertBatchIsNew[i]; + char* payload = insertBatchPayloads[i]; + Insert(row, payload, isNew, streamIndex); } } + } + }; - // encode key - out[insertBatchLen].Rewind(); - for (ui32 i = 0; i < keysDatum.size(); ++i) { - if (keysDatum[i].is_scalar()) { - // TODO: more efficient code when grouping by scalar - Readers_[i]->SaveScalarItem(*keysDatum[i].scalar(), out[insertBatchLen]); - } else { - Readers_[i]->SaveItem(*keysDatum[i].array(), row, out[insertBatchLen]); - } + for (ui64 row = 0; row < batchLength; ++row) { + if constexpr (UseFilter) { + if (filterBitmap && !filterBitmap[row]) { + continue; } + } - insertBatchRows[insertBatchLen] = row; - ++insertBatchLen; - if (insertBatchLen == PrefetchBatchSize) { - processInsertBatch(); - insertBatchLen = 0; + // encode key + out[insertBatchLen].Rewind(); + for (ui32 i = 0; i < keysDatum.size(); ++i) { + if (keysDatum[i].is_scalar()) { + // TODO: more efficient code when grouping by scalar + Readers_[i]->SaveScalarItem(*keysDatum[i].scalar(), out[insertBatchLen]); + } else { + Readers_[i]->SaveItem(*keysDatum[i].array(), row, out[insertBatchLen]); } } - processInsertBatch(); + insertBatchRows[insertBatchLen] = row; + ++insertBatchLen; + if (insertBatchLen == PrefetchBatchSize) { + processInsertBatch(); + insertBatchLen = 0; + } } - bool Finish() { - if (!HasValues_) { - IsFinished_ = true; - return false; - } + processInsertBatch(); + } - WritingOutput_ = true; - OutputBlockSize_ = 0; - PrepareAggBuilders(); + bool Finish() { + if (!HasValues_) { + IsFinished_ = true; + return false; + } - if constexpr (UseSet) { - HashSetIt_ = HashSet_->Begin(); + WritingOutput_ = true; + OutputBlockSize_ = 0; + PrepareAggBuilders(); + + if constexpr (UseSet) { + HashSetIt_ = HashSet_->Begin(); + } else { + if constexpr (!InlineAggState) { + HashFixedMapIt_ = HashFixedMap_->Begin(); } else { - if constexpr (!InlineAggState) { - HashFixedMapIt_ = HashFixedMap_->Begin(); - } else { - HashMapIt_ = HashMap_->Begin(); - } + HashMapIt_ = HashMap_->Begin(); } - return true; } + return true; + } - bool FillOutput(const THolderFactory& holderFactory) { - bool exit = false; - while (WritingOutput_) { - if constexpr (UseSet) { - for (;!exit && HashSetIt_ != HashSet_->End(); HashSet_->Advance(HashSetIt_)) { - if (!HashSet_->IsValid(HashSetIt_)) { - continue; - } - - if (OutputBlockSize_ == MaxBlockLen_) { - Flush(false, holderFactory); - //return EFetchResult::One; - exit = true; - break; - } - - const TKey& key = HashSet_->GetKey(HashSetIt_); - TInputBuffer in(GetKeyView<TKey>(key, KeyLength_)); - for (auto& kb : Builders_) { - kb->Add(in); - } - ++OutputBlockSize_; + bool FillOutput(const THolderFactory& holderFactory) { + bool exit = false; + while (WritingOutput_) { + if constexpr (UseSet) { + for (;!exit && HashSetIt_ != HashSet_->End(); HashSet_->Advance(HashSetIt_)) { + if (!HashSet_->IsValid(HashSetIt_)) { + continue; } - break; - } else { - const bool done = InlineAggState ? - Iterate(*HashMap_, HashMapIt_) : - Iterate(*HashFixedMap_, HashFixedMapIt_); - if (done) { + + if (OutputBlockSize_ == MaxBlockLen_) { + Flush(false, holderFactory); + //return EFetchResult::One; + exit = true; break; } - Flush(false, holderFactory); - exit = true; + + const TKey& key = HashSet_->GetKey(HashSetIt_); + TInputBuffer in(GetKeyView<TKey>(key, KeyLength_)); + for (auto& kb : Builders_) { + kb->Add(in); + } + ++OutputBlockSize_; + } + break; + } else { + const bool done = InlineAggState ? + Iterate(*HashMap_, HashMapIt_) : + Iterate(*HashFixedMap_, HashFixedMapIt_); + if (done) { break; } + Flush(false, holderFactory); + exit = true; + break; } + } - if (!exit) { - IsFinished_ = true; - WritingOutput_ = false; - if (!OutputBlockSize_) - return false; - Flush(true, holderFactory); - } - - FillArrays(); - return true; + if (!exit) { + IsFinished_ = true; + WritingOutput_ = false; + if (!OutputBlockSize_) + return false; + Flush(true, holderFactory); } - private: - void PrepareAggBuilders() { - if constexpr (!UseSet) { - AggBuilders_.clear(); - AggBuilders_.reserve(Aggs_.size()); - for (const auto& a : Aggs_) { - if constexpr (Finalize) { - AggBuilders_.emplace_back(a->MakeResultBuilder(MaxBlockLen_)); - } else { - AggBuilders_.emplace_back(a->MakeStateBuilder(MaxBlockLen_)); - } + + FillArrays(); + return true; + } +private: + void PrepareAggBuilders() { + if constexpr (!UseSet) { + AggBuilders_.clear(); + AggBuilders_.reserve(Aggs_.size()); + for (const auto& a : Aggs_) { + if constexpr (Finalize) { + AggBuilders_.emplace_back(a->MakeResultBuilder(MaxBlockLen_)); + } else { + AggBuilders_.emplace_back(a->MakeStateBuilder(MaxBlockLen_)); } } } + } - void Flush(bool final, const THolderFactory& holderFactory) { - if (!OutputBlockSize_) { - return; - } + void Flush(bool final, const THolderFactory& holderFactory) { + if (!OutputBlockSize_) { + return; + } - for (size_t i = 0; i < Builders_.size(); ++i) { - Values[i] = holderFactory.CreateArrowBlock(Builders_[i]->Build(final)); - } + for (size_t i = 0; i < Builders_.size(); ++i) { + Values[i] = holderFactory.CreateArrowBlock(Builders_[i]->Build(final)); + } - if constexpr (!UseSet) { - for (size_t i = 0; i < Aggs_.size(); ++i) { - Values[Builders_.size() + i] = AggBuilders_[i]->Build(); - } - if (!final) { - PrepareAggBuilders(); - } + if constexpr (!UseSet) { + for (size_t i = 0; i < Aggs_.size(); ++i) { + Values[Builders_.size() + i] = AggBuilders_[i]->Build(); + } + if (!final) { + PrepareAggBuilders(); } - - Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(OutputBlockSize_))); - OutputBlockSize_ = 0; } - void Insert(ui64 row, char* payload, bool isNew, ui32 currentStreamIndex) const { - char* ptr = payload; + Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(OutputBlockSize_))); + OutputBlockSize_ = 0; + } - if (isNew) { - if constexpr (Many) { - static_assert(Finalize); - MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); - memset(ptr, 0, Streams_.size()); - ptr[currentStreamIndex] = 1; + void Insert(ui64 row, char* payload, bool isNew, ui32 currentStreamIndex) const { + char* ptr = payload; - for (auto i : Streams_[currentStreamIndex]) { + if (isNew) { + if constexpr (Many) { + static_assert(Finalize); + MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); + memset(ptr, 0, Streams_.size()); + ptr[currentStreamIndex] = 1; - Aggs_[i]->LoadState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row); - } - } else { - for (size_t i = 0; i < Aggs_.size(); ++i) { - if constexpr (Finalize) { - Aggs_[i]->LoadState(ptr, BatchNum_, Values_.data(), row); - } else { - Aggs_[i]->InitKey(ptr, BatchNum_, Values_.data(), row); - } + for (auto i : Streams_[currentStreamIndex]) { - ptr += Aggs_[i]->StateSize; - } + Aggs_[i]->LoadState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row); } } else { - if constexpr (Many) { - static_assert(Finalize); - MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); - - bool isNewStream = !ptr[currentStreamIndex]; - ptr[currentStreamIndex] = 1; - - for (auto i : Streams_[currentStreamIndex]) { - - if (isNewStream) { - Aggs_[i]->LoadState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row); - } else { - Aggs_[i]->UpdateState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row); - } + for (size_t i = 0; i < Aggs_.size(); ++i) { + if constexpr (Finalize) { + Aggs_[i]->LoadState(ptr, BatchNum_, Values_.data(), row); + } else { + Aggs_[i]->InitKey(ptr, BatchNum_, Values_.data(), row); } - } else { - for (size_t i = 0; i < Aggs_.size(); ++i) { - if constexpr (Finalize) { - Aggs_[i]->UpdateState(ptr, BatchNum_, Values_.data(), row); - } else { - Aggs_[i]->UpdateKey(ptr, BatchNum_, Values_.data(), row); - } - ptr += Aggs_[i]->StateSize; - } + ptr += Aggs_[i]->StateSize; } } - } + } else { + if constexpr (Many) { + static_assert(Finalize); + MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); - template <typename THash> - bool Iterate(THash& hash, typename THash::const_iterator& iter) { - MKQL_ENSURE(WritingOutput_, "Supposed to be called at the end"); - std::array<typename THash::const_iterator, PrefetchBatchSize> iters; - ui32 itersLen = 0; - auto iterateBatch = [&]() { - for (ui32 i = 0; i < itersLen; ++i) { - auto iter = iters[i]; - const TKey& key = hash.GetKey(iter); - auto payload = (char*)hash.GetPayload(iter); - char* ptr; - if constexpr (UseArena) { - ptr = *(char**)payload; + bool isNewStream = !ptr[currentStreamIndex]; + ptr[currentStreamIndex] = 1; + + for (auto i : Streams_[currentStreamIndex]) { + + if (isNewStream) { + Aggs_[i]->LoadState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row); } else { - ptr = payload; + Aggs_[i]->UpdateState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row); } - - TInputBuffer in(GetKeyView<TKey>(key, KeyLength_)); - for (auto& kb : Builders_) { - kb->Add(in); + } + } else { + for (size_t i = 0; i < Aggs_.size(); ++i) { + if constexpr (Finalize) { + Aggs_[i]->UpdateState(ptr, BatchNum_, Values_.data(), row); + } else { + Aggs_[i]->UpdateKey(ptr, BatchNum_, Values_.data(), row); } - if constexpr (Many) { - for (ui32 i = 0; i < Streams_.size(); ++i) { - MKQL_ENSURE(ptr[i], "Missing partial aggregation state for stream #" << i); - } + ptr += Aggs_[i]->StateSize; + } + } + } + } - ptr += Streams_.size(); - } + template <typename THash> + bool Iterate(THash& hash, typename THash::const_iterator& iter) { + MKQL_ENSURE(WritingOutput_, "Supposed to be called at the end"); + std::array<typename THash::const_iterator, PrefetchBatchSize> iters; + ui32 itersLen = 0; + auto iterateBatch = [&]() { + for (ui32 i = 0; i < itersLen; ++i) { + auto iter = iters[i]; + const TKey& key = hash.GetKey(iter); + auto payload = (char*)hash.GetPayload(iter); + char* ptr; + if constexpr (UseArena) { + ptr = *(char**)payload; + } else { + ptr = payload; + } - for (size_t i = 0; i < Aggs_.size(); ++i) { - AggBuilders_[i]->Add(ptr); - Aggs_[i]->DestroyState(ptr); + TInputBuffer in(GetKeyView<TKey>(key, KeyLength_)); + for (auto& kb : Builders_) { + kb->Add(in); + } - ptr += Aggs_[i]->StateSize; + if constexpr (Many) { + for (ui32 i = 0; i < Streams_.size(); ++i) { + MKQL_ENSURE(ptr[i], "Missing partial aggregation state for stream #" << i); } - } - }; - for (; iter != hash.End(); hash.Advance(iter)) { - if (!hash.IsValid(iter)) { - continue; + ptr += Streams_.size(); } - if (OutputBlockSize_ == MaxBlockLen_) { - iterateBatch(); - return false; - } + for (size_t i = 0; i < Aggs_.size(); ++i) { + AggBuilders_[i]->Add(ptr); + Aggs_[i]->DestroyState(ptr); - if (itersLen == iters.size()) { - iterateBatch(); - itersLen = 0; + ptr += Aggs_[i]->StateSize; } + } + }; - iters[itersLen] = iter; - ++itersLen; - ++OutputBlockSize_; - if constexpr (UseArena) { - auto payload = (char*)hash.GetPayload(iter); - auto ptr = *(char**)payload; - NYql::PrefetchForWrite(ptr); + for (; iter != hash.End(); hash.Advance(iter)) { + if (!hash.IsValid(iter)) { + continue; + } + + if (OutputBlockSize_ == MaxBlockLen_) { + iterateBatch(); + return false; + } + + if (itersLen == iters.size()) { + iterateBatch(); + itersLen = 0; + } + + iters[itersLen] = iter; + ++itersLen; + ++OutputBlockSize_; + if constexpr (UseArena) { + auto payload = (char*)hash.GetPayload(iter); + auto ptr = *(char**)payload; + NYql::PrefetchForWrite(ptr); + } + + if constexpr (std::is_same<TKey, TSSOKey>::value) { + const auto& key = hash.GetKey(iter); + if (!key.IsInplace()) { + NYql::PrefetchForRead(key.AsView().Data()); } + } else if constexpr (std::is_same<TKey, TExternalFixedSizeKey>::value) { + const auto& key = hash.GetKey(iter); + NYql::PrefetchForRead(key.Data); + } + } - if constexpr (std::is_same<TKey, TSSOKey>::value) { - const auto& key = hash.GetKey(iter); - if (!key.IsInplace()) { - NYql::PrefetchForRead(key.AsView().Data()); - } - } else if constexpr (std::is_same<TKey, TExternalFixedSizeKey>::value) { - const auto& key = hash.GetKey(iter); - NYql::PrefetchForRead(key.Data); + iterateBatch(); + return true; + } +}; + +template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived> +class THashedWrapperBaseFromFlow : public TStatefulWideFlowCodegeneratorNode<TDerived>, + protected THashedWrapperCodegenBase +{ + using TComputationBase = TStatefulWideFlowCodegeneratorNode<TDerived>; + + using TState = THashedWrapperBaseState<TKey, TAggregator, TFixedAggState, UseSet, UseFilter, Finalize, Many, TDerived>; +public: + THashedWrapperBaseFromFlow(TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional<ui32> filterColumn, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<TAggregator>>&& aggsParams, + ui32 streamIndex, + std::vector<std::vector<ui32>>&& streams) + : TComputationBase(mutables, flow, EValueRepresentation::Boxed) + , Flow_(flow) + , FilterColumn_(filterColumn) + , Width_(width) + , OutputWidth_(keys.size() + aggsParams.size() + 1) + , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) + , Keys_(keys) + , MaxBlockLen_(maxBlockLen) + , AggsParams_(std::move(aggsParams)) + , KeyLength_(keyLength) + , StreamIndex_(streamIndex) + , Streams_(std::move(streams)) + { + MKQL_ENSURE(Width_ > 0, "Missing block length column"); + if constexpr (UseFilter) { + MKQL_ENSURE(filterColumn, "Missing filter column"); + MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize"); + } else { + MKQL_ENSURE(!filterColumn, "Unexpected filter column"); + } + } + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, + TComputationContext& ctx, + NUdf::TUnboxedValue*const* output) const + { + auto& s = GetState(state, ctx); + if (!s.Count) { + if (s.IsFinished_) + return EFetchResult::Finish; + + while (!s.WritingOutput_) { + const auto fields = ctx.WideFields.data() + WideFieldsIndex_; + s.Values_.assign(s.Values_.size(), NUdf::TUnboxedValuePod()); + switch (Flow_->FetchValues(ctx, fields)) { + case EFetchResult::Yield: + return EFetchResult::Yield; + case EFetchResult::One: + s.ProcessInput(ctx.HolderFactory); + continue; + case EFetchResult::Finish: + break; } + + if (s.Finish()) + break; + else + return EFetchResult::Finish; } - iterateBatch(); - return true; + if (!s.FillOutput(ctx.HolderFactory)) + return EFetchResult::Finish; } - }; + + const auto sliceSize = s.Slice(); + for (size_t i = 0; i < OutputWidth_; ++i) { + if (const auto out = output[i]) { + *out = s.Get(sliceSize, ctx.HolderFactory, i); + } + } + return EFetchResult::One; + } +#ifndef MKQL_DISABLE_CODEGEN + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, OutputWidth_, + GetMethodPtr(&TState::Get), GetMethodPtr(&THashedWrapperBaseFromFlow::MakeState), + GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::Finish), + GetMethodPtr(&TState::FillOutput), GetMethodPtr(&TState::Slice)); + } +#endif private: void RegisterDependencies() const final { this->FlowDependsOn(Flow_); @@ -1665,11 +1772,136 @@ private: const std::vector<std::vector<ui32>> Streams_; }; + +template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived> +class THashedWrapperBaseFromStream : public TMutableComputationNode<TDerived>, + protected THashedWrapperCodegenBase +{ + using TComputationBase = TMutableComputationNode<TDerived>; + + using TState = THashedWrapperBaseState<TKey, TAggregator, TFixedAggState, UseSet, UseFilter, Finalize, Many, TDerived>; +public: + THashedWrapperBaseFromStream(TComputationMutables& mutables, + IComputationNode* stream, + std::optional<ui32> filterColumn, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<TAggregator>>&& aggsParams, + ui32 streamIndex, + std::vector<std::vector<ui32>>&& streams) + : TComputationBase(mutables, EValueRepresentation::Boxed) + , Stream_(stream) + , FilterColumn_(filterColumn) + , Width_(width) + , OutputWidth_(keys.size() + aggsParams.size() + 1) + , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) + , Keys_(keys) + , MaxBlockLen_(maxBlockLen) + , AggsParams_(std::move(aggsParams)) + , KeyLength_(keyLength) + , StreamIndex_(streamIndex) + , Streams_(std::move(streams)) + { + MKQL_ENSURE(Width_ > 0, "Missing block length column"); + if constexpr (UseFilter) { + MKQL_ENSURE(filterColumn, "Missing filter column"); + MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize"); + } else { + MKQL_ENSURE(!filterColumn, "Unexpected filter column"); + } + } + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const + { + const auto state = ctx.HolderFactory.Create<TState>(KeyLength_, StreamIndex_, Width_, OutputWidth_, FilterColumn_, AggsParams_, Streams_, Keys_, MaxBlockLen_, ctx); + return ctx.HolderFactory.Create<TStreamValue>(ctx.HolderFactory, std::move(state), std::move(Stream_->GetValue(ctx))); + } +private: + class TStreamValue : public TComputationValue<TStreamValue> { + using TBase = TComputationValue<TStreamValue>; + public: + TStreamValue(TMemoryUsageInfo* memInfo, const THolderFactory& holderFactory, + NUdf::TUnboxedValue&& state, NUdf::TUnboxedValue&& stream) + : TBase(memInfo) + , State_(state) + , Stream_(stream) + , HolderFactory_(holderFactory) + { + } + + private: + NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) { + TState& state = *static_cast<TState*>(State_.AsBoxed().Get()); + auto* inputFields = state.Values_.data(); + const size_t inputWidth = state.Width_; + const size_t outputWidth = state.OutputWidth_; + MKQL_ENSURE(outputWidth == width, "The given width doesn't equal to the result type size"); + + if (!state.Count) { + if (state.IsFinished_) + return NUdf::EFetchStatus::Finish; + + while (!state.WritingOutput_) { + switch (Stream_.WideFetch(inputFields, inputWidth)) { + case NUdf::EFetchStatus::Yield: + return NUdf::EFetchStatus::Yield; + case NUdf::EFetchStatus::Ok: + state.ProcessInput(HolderFactory_); + continue; + case NUdf::EFetchStatus::Finish: + break; + } + + if (state.Finish()) + break; + else + return NUdf::EFetchStatus::Finish; + } + + if (!state.FillOutput(HolderFactory_)) + return NUdf::EFetchStatus::Finish; + } + + const auto sliceSize = state.Slice(); + for (size_t i = 0; i < outputWidth; ++i) { + output[i] = state.Get(sliceSize, HolderFactory_, i); + } + return NUdf::EFetchStatus::Ok; + } + private: + NUdf::TUnboxedValue State_; + NUdf::TUnboxedValue Stream_; + const THolderFactory& HolderFactory_; + }; +private: + void RegisterDependencies() const final { + this->DependsOn(Stream_); + } + + IComputationNode *const Stream_; + const std::optional<ui32> FilterColumn_; + const size_t Width_; + const size_t OutputWidth_; + const size_t WideFieldsIndex_; + const std::vector<TKeyParams> Keys_; + const size_t MaxBlockLen_; + const std::vector<TAggParams<TAggregator>> AggsParams_; + const ui32 KeyLength_; + const ui32 StreamIndex_; + const std::vector<std::vector<ui32>> Streams_; +}; + +template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter, typename TInputNode> +class TBlockCombineHashedWrapper {}; + template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter> -class TBlockCombineHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>> { +class TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationWideFlowNode> + : public THashedWrapperBaseFromFlow<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationWideFlowNode>> { public: - using TSelf = TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>; - using TBase = THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TSelf>; + using TSelf = TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationWideFlowNode>; + using TBase = THashedWrapperBaseFromFlow<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TSelf>; TBlockCombineHashedWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, @@ -1683,11 +1915,34 @@ public: {} }; +template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter> +class TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationNode> + : public THashedWrapperBaseFromStream<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationNode>> { +public: + using TSelf = TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter, IComputationNode>; + using TBase = THashedWrapperBaseFromStream<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TSelf>; + + TBlockCombineHashedWrapper(TComputationMutables& mutables, + IComputationNode* stream, + std::optional<ui32> filterColumn, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) + : TBase(mutables, stream, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams), 0, {}) + {} +}; + +template <typename TKey, typename TFixedAggState, bool UseSet, typename TInputNode> +class TBlockMergeFinalizeHashedWrapper {}; + template <typename TKey, typename TFixedAggState, bool UseSet> -class TBlockMergeFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>> { +class TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationWideFlowNode> + : public THashedWrapperBaseFromFlow<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationWideFlowNode>> { public: - using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>; - using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TSelf>; + using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationWideFlowNode>; + using TBase = THashedWrapperBaseFromFlow<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TSelf>; TBlockMergeFinalizeHashedWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, @@ -1700,11 +1955,33 @@ public: {} }; +template <typename TKey, typename TFixedAggState, bool UseSet> +class TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationNode> + : public THashedWrapperBaseFromStream<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationNode>> { +public: + using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet, IComputationNode>; + using TBase = THashedWrapperBaseFromStream<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TSelf>; + + TBlockMergeFinalizeHashedWrapper(TComputationMutables& mutables, + IComputationNode* stream, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) + : TBase(mutables, stream, {}, width, keys, maxBlockLen, keyLength, std::move(aggsParams), 0, {}) + {} +}; + +template <typename TKey, typename TFixedAggState, typename TInputNode> +class TBlockMergeManyFinalizeHashedWrapper {}; + template <typename TKey, typename TFixedAggState> -class TBlockMergeManyFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState>> { +class TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationWideFlowNode> + : public THashedWrapperBaseFromFlow<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationWideFlowNode>> { public: - using TSelf = TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState>; - using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TSelf>; + using TSelf = TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationWideFlowNode>; + using TBase = THashedWrapperBaseFromFlow<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TSelf>; TBlockMergeManyFinalizeHashedWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, @@ -1718,6 +1995,25 @@ public: {} }; +template <typename TKey, typename TFixedAggState> +class TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationNode> + : public THashedWrapperBaseFromStream<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationNode>> { +public: + using TSelf = TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState, IComputationNode>; + using TBase = THashedWrapperBaseFromStream<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TSelf>; + + TBlockMergeManyFinalizeHashedWrapper(TComputationMutables& mutables, + IComputationNode* stream, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams, + ui32 streamIndex, std::vector<std::vector<ui32>>&& streams) + : TBase(mutables, stream, {}, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)) + {} +}; + template <typename TAggregator> std::unique_ptr<IPreparedBlockAggregator<TAggregator>> PrepareBlockAggregator(const IBlockAggregatorFactory& factory, TTupleType* tupleType, @@ -1824,117 +2120,117 @@ ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional< return totalStateSize; } -template <bool UseSet, bool UseFilter, typename TKey> +template <bool UseSet, bool UseFilter, typename TKey, typename TInputNode> IComputationNode* MakeBlockCombineHashedWrapper( ui32 keyLength, ui32 totalStateSize, TComputationMutables& mutables, - IComputationWideFlowNode* flow, + TInputNode* streamOrFlow, std::optional<ui32> filterColumn, size_t width, const std::vector<TKeyParams>& keys, size_t maxBlockLen, std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) { if (totalStateSize <= sizeof(TState8)) { - return new TBlockCombineHashedWrapper<TKey, TState8, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + return new TBlockCombineHashedWrapper<TKey, TState8, UseSet, UseFilter, TInputNode>(mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); } if (totalStateSize <= sizeof(TState16)) { - return new TBlockCombineHashedWrapper<TKey, TState16, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + return new TBlockCombineHashedWrapper<TKey, TState16, UseSet, UseFilter, TInputNode>(mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); } - return new TBlockCombineHashedWrapper<TKey, TStateArena, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + return new TBlockCombineHashedWrapper<TKey, TStateArena, UseSet, UseFilter, TInputNode>(mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); } -template <bool UseSet, bool UseFilter> +template <bool UseSet, bool UseFilter, typename TInputNode> IComputationNode* MakeBlockCombineHashedWrapper( TMaybe<ui32> totalKeysSize, bool isFixed, ui32 totalStateSize, TComputationMutables& mutables, - IComputationWideFlowNode* flow, + TInputNode* streamOrFlow, std::optional<ui32> filterColumn, size_t width, const std::vector<TKeyParams>& keys, size_t maxBlockLen, std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) { if (totalKeysSize && *totalKeysSize <= sizeof(ui32)) { - return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui32>(*totalKeysSize, totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui32>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); } if (totalKeysSize && *totalKeysSize <= sizeof(ui64)) { - return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui64>(*totalKeysSize, totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui64>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); } if (totalKeysSize && *totalKeysSize <= sizeof(TKey16)) { - return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TKey16>(*totalKeysSize, totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TKey16>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); } if (totalKeysSize && isFixed) { - return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TExternalFixedSizeKey>(*totalKeysSize, totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TExternalFixedSizeKey>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); } - return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TSSOKey>(Max<ui32>(), totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TSSOKey>(Max<ui32>(), totalStateSize, mutables, streamOrFlow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); } -template <typename TKey, bool UseSet> +template <typename TKey, bool UseSet, typename TInputNode> IComputationNode* MakeBlockMergeFinalizeHashedWrapper( ui32 keyLength, ui32 totalStateSize, TComputationMutables& mutables, - IComputationWideFlowNode* flow, + TInputNode* streamOrFlow, size_t width, const std::vector<TKeyParams>& keys, size_t maxBlockLen, std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) { if (totalStateSize <= sizeof(TState8)) { - return new TBlockMergeFinalizeHashedWrapper<TKey, TState8, UseSet>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + return new TBlockMergeFinalizeHashedWrapper<TKey, TState8, UseSet, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); } if (totalStateSize <= sizeof(TState16)) { - return new TBlockMergeFinalizeHashedWrapper<TKey, TState16, UseSet>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + return new TBlockMergeFinalizeHashedWrapper<TKey, TState16, UseSet, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); } - return new TBlockMergeFinalizeHashedWrapper<TKey, TStateArena, UseSet>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + return new TBlockMergeFinalizeHashedWrapper<TKey, TStateArena, UseSet, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); } -template <bool UseSet> +template <bool UseSet, typename TInputNode> IComputationNode* MakeBlockMergeFinalizeHashedWrapper( TMaybe<ui32> totalKeysSize, bool isFixed, ui32 totalStateSize, TComputationMutables& mutables, - IComputationWideFlowNode* flow, + TInputNode* streamOrFlow, size_t width, const std::vector<TKeyParams>& keys, size_t maxBlockLen, std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) { if (totalKeysSize && *totalKeysSize <= sizeof(ui32)) { - return MakeBlockMergeFinalizeHashedWrapper<ui32, UseSet>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockMergeFinalizeHashedWrapper<ui32, UseSet>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams)); } if (totalKeysSize && *totalKeysSize <= sizeof(ui64)) { - return MakeBlockMergeFinalizeHashedWrapper<ui64, UseSet>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockMergeFinalizeHashedWrapper<ui64, UseSet>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams)); } if (totalKeysSize && *totalKeysSize <= sizeof(TKey16)) { - return MakeBlockMergeFinalizeHashedWrapper<TKey16, UseSet>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockMergeFinalizeHashedWrapper<TKey16, UseSet>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams)); } if (totalKeysSize && isFixed) { - return MakeBlockMergeFinalizeHashedWrapper<TExternalFixedSizeKey, UseSet>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockMergeFinalizeHashedWrapper<TExternalFixedSizeKey, UseSet>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams)); } - return MakeBlockMergeFinalizeHashedWrapper<TSSOKey, UseSet>(Max<ui32>(), totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); + return MakeBlockMergeFinalizeHashedWrapper<TSSOKey, UseSet>(Max<ui32>(), totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams)); } -template <typename TKey> +template <typename TKey, typename TInputNode> IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper( ui32 keyLength, ui32 totalStateSize, TComputationMutables& mutables, - IComputationWideFlowNode* flow, + TInputNode* streamOrFlow, size_t width, const std::vector<TKeyParams>& keys, size_t maxBlockLen, @@ -1943,22 +2239,23 @@ IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper( std::vector<std::vector<ui32>>&& streams) { if (totalStateSize <= sizeof(TState8)) { - return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState8>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)); + return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState8, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)); } if (totalStateSize <= sizeof(TState16)) { - return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState16>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)); + return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState16, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)); } - return new TBlockMergeManyFinalizeHashedWrapper<TKey, TStateArena>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)); + return new TBlockMergeManyFinalizeHashedWrapper<TKey, TStateArena, TInputNode>(mutables, streamOrFlow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)); } +template <typename TInputNode> IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper( TMaybe<ui32> totalKeysSize, bool isFixed, ui32 totalStateSize, TComputationMutables& mutables, - IComputationWideFlowNode* flow, + TInputNode* streamOrFlow, size_t width, const std::vector<TKeyParams>& keys, size_t maxBlockLen, @@ -1966,22 +2263,22 @@ IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper( ui32 streamIndex, std::vector<std::vector<ui32>>&& streams) { if (totalKeysSize && *totalKeysSize <= sizeof(ui32)) { - return MakeBlockMergeManyFinalizeHashedWrapper<ui32>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + return MakeBlockMergeManyFinalizeHashedWrapper<ui32>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); } if (totalKeysSize && *totalKeysSize <= sizeof(ui64)) { - return MakeBlockMergeManyFinalizeHashedWrapper<ui64>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + return MakeBlockMergeManyFinalizeHashedWrapper<ui64>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); } if (totalKeysSize && *totalKeysSize <= sizeof(TKey16)) { - return MakeBlockMergeManyFinalizeHashedWrapper<TKey16>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + return MakeBlockMergeManyFinalizeHashedWrapper<TKey16>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); } if (totalKeysSize && isFixed) { - return MakeBlockMergeManyFinalizeHashedWrapper<TExternalFixedSizeKey>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + return MakeBlockMergeManyFinalizeHashedWrapper<TExternalFixedSizeKey>(*totalKeysSize, totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); } - return MakeBlockMergeManyFinalizeHashedWrapper<TSSOKey>(Max<ui32>(), totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + return MakeBlockMergeManyFinalizeHashedWrapper<TSSOKey>(Max<ui32>(), totalStateSize, mutables, streamOrFlow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); } void PrepareKeys(const std::vector<TKeyParams>& keys, TMaybe<ui32>& totalKeysSize, bool& isFixed) { @@ -2012,14 +2309,15 @@ void FillAggStreams(TRuntimeNode streamsNode, std::vector<std::vector<ui32>>& st IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args"); - const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); - const auto wideComponents = GetWideComponents(flowType); + + const bool isStream = callable.GetInput(0).GetStaticType()->IsStream(); + MKQL_ENSURE(isStream == callable.GetType()->GetReturnType()->IsStream(), "input and output must be both either flow or stream"); + + const auto wideComponents = GetWideComponents(callable.GetInput(0).GetStaticType()); const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env); - const auto returnFlowType = AS_TYPE(TFlowType, callable.GetType()->GetReturnType()); - const auto returnWideComponents = GetWideComponents(returnFlowType); + const auto returnWideComponents = GetWideComponents(callable.GetType()->GetReturnType()); - auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); - MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + const auto wideFlowOrStream = LocateNode(ctx.NodeLocator, callable, 0); auto filterColumnVal = AS_VALUE(TOptionalLiteral, callable.GetInput(1)); std::optional<ui32> filterColumn; @@ -2030,19 +2328,28 @@ IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNod auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2)); std::vector<TAggParams<IBlockAggregatorCombineAll>> aggsParams; FillAggParams<IBlockAggregatorCombineAll>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env, false, false, returnWideComponents, 0); - return new TBlockCombineAllWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams)); + + if (isStream) { + const auto wideStream = wideFlowOrStream; + return new TBlockCombineAllWrapperFromStream(ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams)); + } else { + const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(wideFlowOrStream); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + return new TBlockCombineAllWrapperFromFlow(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams)); + } } IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args"); - const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); - const auto wideComponents = GetWideComponents(flowType); + + const bool isStream = callable.GetInput(0).GetStaticType()->IsStream(); + MKQL_ENSURE(isStream == callable.GetType()->GetReturnType()->IsStream(), "input and output must be both either flow or stream"); + + const auto wideComponents = GetWideComponents(callable.GetInput(0).GetStaticType()); const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env); - const auto returnFlowType = AS_TYPE(TFlowType, callable.GetType()->GetReturnType()); - const auto returnWideComponents = GetWideComponents(returnFlowType); + const auto returnWideComponents = GetWideComponents(callable.GetType()->GetReturnType()); - auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); - MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + const auto wideStreamOrFlow = LocateNode(ctx.NodeLocator, callable, 0); auto filterColumnVal = AS_VALUE(TOptionalLiteral, callable.GetInput(1)); std::optional<ui32> filterColumn; @@ -2066,31 +2373,51 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation PrepareKeys(keys, totalKeysSize, isFixed); const size_t maxBlockLen = CalcMaxBlockLenForOutput(callable.GetType()->GetReturnType()); - if (filterColumn) { - if (aggsParams.empty()) { - return MakeBlockCombineHashedWrapper<true, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + if (isStream) { + const auto wideStream = wideStreamOrFlow; + if (filterColumn) { + if (aggsParams.empty()) { + return MakeBlockCombineHashedWrapper<true, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockCombineHashedWrapper<false, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } } else { - return MakeBlockCombineHashedWrapper<false, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + if (aggsParams.empty()) { + return MakeBlockCombineHashedWrapper<true, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockCombineHashedWrapper<false, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } } } else { - if (aggsParams.empty()) { - return MakeBlockCombineHashedWrapper<true, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + const auto wideFlow = dynamic_cast<IComputationWideFlowNode *>(wideStreamOrFlow); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + if (filterColumn) { + if (aggsParams.empty()) { + return MakeBlockCombineHashedWrapper<true, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockCombineHashedWrapper<false, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } } else { - return MakeBlockCombineHashedWrapper<false, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + if (aggsParams.empty()) { + return MakeBlockCombineHashedWrapper<true, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockCombineHashedWrapper<false, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } } } } IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args"); - const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); - const auto wideComponents = GetWideComponents(flowType); + + const bool isStream = callable.GetInput(0).GetStaticType()->IsStream(); + MKQL_ENSURE(isStream == callable.GetType()->GetReturnType()->IsStream(), "input and output must be both either flow or stream"); + + const auto wideComponents = GetWideComponents(callable.GetInput(0).GetStaticType()); const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env); - const auto returnFlowType = AS_TYPE(TFlowType, callable.GetType()->GetReturnType()); - const auto returnWideComponents = GetWideComponents(returnFlowType); + const auto returnWideComponents = GetWideComponents(callable.GetType()->GetReturnType()); - auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); - MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + const auto wideStreamOrFlow = LocateNode(ctx.NodeLocator, callable, 0); auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(1)); std::vector<TKeyParams> keys; @@ -2108,23 +2435,35 @@ IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TCompu PrepareKeys(keys, totalKeysSize, isFixed); const size_t maxBlockLen = CalcMaxBlockLenForOutput(callable.GetType()->GetReturnType()); - if (aggsParams.empty()) { - return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + if (isStream) { + const auto wideStream = wideStreamOrFlow; + if (aggsParams.empty()) { + return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockMergeFinalizeHashedWrapper<false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } } else { - return MakeBlockMergeFinalizeHashedWrapper<false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + const auto wideFlow = dynamic_cast<IComputationWideFlowNode *>(wideStreamOrFlow); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + if (aggsParams.empty()) { + return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockMergeFinalizeHashedWrapper<false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } } } IComputationNode* WrapBlockMergeManyFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 args"); - const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); - const auto wideComponents = GetWideComponents(flowType); + + const bool isStream = callable.GetInput(0).GetStaticType()->IsStream(); + MKQL_ENSURE(isStream == callable.GetType()->GetReturnType()->IsStream(), "input and output must be both either flow or stream"); + + const auto wideComponents = GetWideComponents(callable.GetInput(0).GetStaticType()); const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env); - const auto returnFlowType = AS_TYPE(TFlowType, callable.GetType()->GetReturnType()); - const auto returnWideComponents = GetWideComponents(returnFlowType); + const auto returnWideComponents = GetWideComponents(callable.GetType()->GetReturnType()); - const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); - MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + const auto wideStreamOrFlow = LocateNode(ctx.NodeLocator, callable, 0); auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(1)); std::vector<TKeyParams> keys; @@ -2147,12 +2486,25 @@ IComputationNode* WrapBlockMergeManyFinalizeHashed(TCallable& callable, const TC totalStateSize += streams.size(); const size_t maxBlockLen = CalcMaxBlockLenForOutput(callable.GetType()->GetReturnType()); - if (aggsParams.empty()) { - return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), - keys, maxBlockLen, std::move(aggsParams)); + if (isStream){ + const auto wideStream = wideStreamOrFlow; + if (aggsParams.empty()) { + return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, tupleType->GetElementsCount(), + keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockMergeManyFinalizeHashedWrapper(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideStream, tupleType->GetElementsCount(), + keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + } } else { - return MakeBlockMergeManyFinalizeHashedWrapper(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), - keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + const auto wideFlow = dynamic_cast<IComputationWideFlowNode *>(wideStreamOrFlow); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + if (aggsParams.empty()) { + return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), + keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockMergeManyFinalizeHashedWrapper(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), + keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + } } } diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp index 81f857b4be..0a4cd30828 100644 --- a/yql/essentials/minikql/mkql_program_builder.cpp +++ b/yql/essentials/minikql/mkql_program_builder.cpp @@ -5730,14 +5730,15 @@ TRuntimeNode TProgramBuilder::BlockBitCast(TRuntimeNode value, TType* targetType return TRuntimeNode(builder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockCombineAll(TRuntimeNode flow, std::optional<ui32> filterColumn, - const TArrayRef<const TAggInfo>& aggs, TType* returnType) { - if constexpr (RuntimeVersion < 31U) { - THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; - } +TRuntimeNode TProgramBuilder::BuildBlockCombineAll(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn, + const TArrayRef<const TAggInfo>& aggs, TType* returnType) { + const auto inputType = input.GetStaticType(); + MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type"); + MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type"); + + TCallableBuilder builder(Env, callableName, returnType); + builder.Add(input); - TCallableBuilder builder(Env, __func__, returnType); - builder.Add(flow); if (!filterColumn) { builder.Add(NewEmptyOptionalDataLiteral(NUdf::TDataType<ui32>::Id)); } else { @@ -5759,14 +5760,32 @@ TRuntimeNode TProgramBuilder::BlockCombineAll(TRuntimeNode flow, std::optional<u return TRuntimeNode(builder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode flow, std::optional<ui32> filterColumn, const TArrayRef<ui32>& keys, +TRuntimeNode TProgramBuilder::BlockCombineAll(TRuntimeNode stream, std::optional<ui32> filterColumn, const TArrayRef<const TAggInfo>& aggs, TType* returnType) { if constexpr (RuntimeVersion < 31U) { THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; } - TCallableBuilder builder(Env, __func__, returnType); - builder.Add(flow); + MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type"); + MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type"); + + if constexpr (RuntimeVersion < 52U) { + const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType()); + return FromFlow(BuildBlockCombineAll(__func__, ToFlow(stream), filterColumn, aggs, flowReturnType)); + } else { + return BuildBlockCombineAll(__func__, stream, filterColumn, aggs, returnType); + } +} + +TRuntimeNode TProgramBuilder::BuildBlockCombineHashed(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn, + const TArrayRef<ui32>& keys, const TArrayRef<const TAggInfo>& aggs, TType* returnType) { + const auto inputType = input.GetStaticType(); + MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type"); + MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type"); + + TCallableBuilder builder(Env, callableName, returnType); + builder.Add(input); + if (!filterColumn) { builder.Add(NewEmptyOptionalDataLiteral(NUdf::TDataType<ui32>::Id)); } else { @@ -5794,14 +5813,31 @@ TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode flow, std::optiona return TRuntimeNode(builder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockMergeFinalizeHashed(TRuntimeNode flow, const TArrayRef<ui32>& keys, +TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode stream, std::optional<ui32> filterColumn, const TArrayRef<ui32>& keys, const TArrayRef<const TAggInfo>& aggs, TType* returnType) { if constexpr (RuntimeVersion < 31U) { THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; } - TCallableBuilder builder(Env, __func__, returnType); - builder.Add(flow); + MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type"); + MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type"); + + if constexpr (RuntimeVersion < 52U) { + const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType()); + return FromFlow(BuildBlockCombineHashed(__func__, ToFlow(stream), filterColumn, keys, aggs, flowReturnType)); + } else { + return BuildBlockCombineHashed(__func__, stream, filterColumn, keys, aggs, returnType); + } +} + +TRuntimeNode TProgramBuilder::BuildBlockMergeFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, TType* returnType) { + const auto inputType = input.GetStaticType(); + MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type"); + MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type"); + + TCallableBuilder builder(Env, callableName, returnType); + builder.Add(input); TVector<TRuntimeNode> keyNodes; for (const auto& key : keys) { @@ -5824,14 +5860,31 @@ TRuntimeNode TProgramBuilder::BlockMergeFinalizeHashed(TRuntimeNode flow, const return TRuntimeNode(builder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockMergeManyFinalizeHashed(TRuntimeNode flow, const TArrayRef<ui32>& keys, - const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) { +TRuntimeNode TProgramBuilder::BlockMergeFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, TType* returnType) { if constexpr (RuntimeVersion < 31U) { THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; } - TCallableBuilder builder(Env, __func__, returnType); - builder.Add(flow); + MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type"); + MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type"); + + if constexpr (RuntimeVersion < 52U) { + const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType()); + return FromFlow(BuildBlockMergeFinalizeHashed(__func__, ToFlow(stream), keys, aggs, flowReturnType)); + } else { + return BuildBlockMergeFinalizeHashed(__func__, stream, keys, aggs, returnType); + } +} + +TRuntimeNode TProgramBuilder::BuildBlockMergeManyFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) { + const auto inputType = input.GetStaticType(); + MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type"); + MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type"); + + TCallableBuilder builder(Env, callableName, returnType); + builder.Add(input); TVector<TRuntimeNode> keyNodes; for (const auto& key : keys) { @@ -5866,6 +5919,23 @@ TRuntimeNode TProgramBuilder::BlockMergeManyFinalizeHashed(TRuntimeNode flow, co return TRuntimeNode(builder.Build(), false); } +TRuntimeNode TProgramBuilder::BlockMergeManyFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) { + if constexpr (RuntimeVersion < 31U) { + THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; + } + + MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type"); + MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type"); + + if constexpr (RuntimeVersion < 52U) { + const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType()); + return FromFlow(BuildBlockMergeManyFinalizeHashed(__func__, ToFlow(stream), keys, aggs, streamIndex, streams, flowReturnType)); + } else { + return BuildBlockMergeManyFinalizeHashed(__func__, stream, keys, aggs, streamIndex, streams, returnType); + } +} + TRuntimeNode TProgramBuilder::ScalarApply(const TArrayRef<const TRuntimeNode>& args, const TArrayLambda& handler) { if constexpr (RuntimeVersion < 39U) { THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; diff --git a/yql/essentials/minikql/mkql_program_builder.h b/yql/essentials/minikql/mkql_program_builder.h index eca10da928..7e139cd1a9 100644 --- a/yql/essentials/minikql/mkql_program_builder.h +++ b/yql/essentials/minikql/mkql_program_builder.h @@ -760,6 +760,15 @@ protected: private: TRuntimeNode BuildWideFilter(const std::string_view& callableName, TRuntimeNode flow, const TNarrowLambda& handler); + TRuntimeNode BuildBlockCombineAll(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn, + const TArrayRef<const TAggInfo>& aggs, TType* returnType); + TRuntimeNode BuildBlockCombineHashed(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn, + const TArrayRef<ui32>& keys, const TArrayRef<const TAggInfo>& aggs, TType* returnType); + TRuntimeNode BuildBlockMergeFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, TType* returnType); + TRuntimeNode BuildBlockMergeManyFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType); + TRuntimeNode DictItems(TRuntimeNode dict, EDictItems mode); TRuntimeNode If(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch, TType* resultType); |