diff options
author | vvvv <vvvv@yandex-team.com> | 2024-11-07 04:19:26 +0300 |
---|---|---|
committer | vvvv <vvvv@yandex-team.com> | 2024-11-07 04:29:50 +0300 |
commit | 2661be00f3bc47590fda9218bf0386d6355c8c88 (patch) | |
tree | 3d316c07519191283d31c5f537efc6aabb42a2f0 /yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp | |
parent | cf2a23963ac10add28c50cc114fbf48953eca5aa (diff) | |
download | ydb-2661be00f3bc47590fda9218bf0386d6355c8c88.tar.gz |
Moved yql/minikql YQL-19206
init
[nodiff:caesar]
commit_hash:d1182ef7d430ccf7e4d37ed933c7126d7bd5d6e4
Diffstat (limited to 'yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp')
-rw-r--r-- | yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp | 2160 |
1 files changed, 2160 insertions, 0 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp b/yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp new file mode 100644 index 0000000000..2c22e4eec6 --- /dev/null +++ b/yql/essentials/minikql/comp_nodes/mkql_block_agg.cpp @@ -0,0 +1,2160 @@ +#include "mkql_block_agg.h" +#include "mkql_block_agg_factory.h" +#include "mkql_rh_hash.h" + +#include <yql/essentials/minikql/computation/mkql_block_reader.h> +#include <yql/essentials/minikql/computation/mkql_block_builder.h> +#include <yql/essentials/minikql/computation/mkql_block_impl.h> +#include <yql/essentials/minikql/computation/mkql_block_impl_codegen.h> // Y_IGNORE +#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE + +#include <yql/essentials/minikql/mkql_node_cast.h> +#include <yql/essentials/minikql/mkql_node_builder.h> + +#include <yql/essentials/minikql/arrow/arrow_defs.h> +#include <yql/essentials/minikql/arrow/arrow_util.h> +#include <yql/essentials/minikql/arrow/mkql_bit_utils.h> + +#include <yql/essentials/utils/prefetch.h> + +#include <arrow/scalar.h> +#include <arrow/array/array_primitive.h> +#include <arrow/array/builder_primitive.h> +#include <arrow/chunked_array.h> + +//#define USE_STD_UNORDERED + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +constexpr bool InlineAggState = false; + +#ifdef USE_STD_UNORDERED +template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>, typename TSettings = void> +class TDynamicHashMapImpl { +public: + using TMapType = std::unordered_map<TKey, std::vector<char>, THash, TEqual>; + using const_iterator = typename TMapType::const_iterator; + using iterator = typename TMapType::iterator; + + TDynamicHashMapImpl(size_t stateSize, const THash& hasher, const TEqual& equal) + : StateSize_(stateSize) + , Map_(0, hasher, equal) + {} + + ui64 GetSize() const { + return Map_.size(); + } + + const_iterator Begin() const { + return Map_.begin(); + } + + const_iterator End() const { + return Map_.end(); + } + + bool IsValid(const_iterator iter) const { + return true; + } + + void Advance(const_iterator& iter) const { + ++iter; + } + + iterator Insert(const TKey& key, bool& isNew) { + auto res = Map_.emplace(key, std::vector<char>()); + isNew = res.second; + if (isNew) { + res.first->second.resize(StateSize_); + } + + return res.first; + } + + template <typename TSink> + void BatchInsert(std::span<TRobinHoodBatchRequestItem<TKey>> batchRequest, TSink&& sink) { + for (size_t index = 0; index < batchRequest.size(); ++index) { + bool isNew; + auto iter = Insert(batchRequest[index].GetKey(), isNew); + sink(index, iter, isNew); + } + } + + const TKey& GetKey(const_iterator it) const { + return it->first; + } + + char* GetMutablePayload(iterator it) const { + return it->second.data(); + } + + const char* GetPayload(const_iterator it) const { + return it->second.data(); + } + + void CheckGrow() { + } + +private: + const size_t StateSize_; + TMapType Map_; +}; + +template <typename TKey, typename TPayload, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>, typename TSettings = void> +class TFixedHashMapImpl { +public: + using TMapType = std::unordered_map<TKey, TPayload, THash, TEqual>; + using const_iterator = typename TMapType::const_iterator; + using iterator = typename TMapType::iterator; + + TFixedHashMapImpl(const THash& hasher, const TEqual& equal) + : Map_(0, hasher, equal) + {} + + ui64 GetSize() const { + return Map_.size(); + } + + const_iterator Begin() const { + return Map_.begin(); + } + + const_iterator End() const { + return Map_.end(); + } + + bool IsValid(const_iterator iter) const { + return true; + } + + void Advance(const_iterator& iter) const { + ++iter; + } + + iterator Insert(const TKey& key, bool& isNew) { + auto res = Map_.emplace(key, TPayload()); + isNew = res.second; + return res.first; + } + + template <typename TSink> + void BatchInsert(std::span<TRobinHoodBatchRequestItem<TKey>> batchRequest, TSink&& sink) { + for (size_t index = 0; index < batchRequest.size(); ++index) { + bool isNew; + auto iter = Insert(batchRequest[index].GetKey(), isNew); + sink(index, iter, isNew); + } + } + + const TKey& GetKey(const_iterator it) const { + return it->first; + } + + char* GetMutablePayload(iterator it) const { + return (char*)&it->second; + } + + const char* GetPayload(const_iterator it) const { + return (const char*)&it->second; + } + + void CheckGrow() { + } + +private: + TMapType Map_; +}; + +template <typename TKey, typename TEqual = std::equal_to<TKey>, typename THash = std::hash<TKey>, typename TAllocator = std::allocator<char>, typename TSettings = void> +class THashSetImpl { +public: + using TSetType = std::unordered_set<TKey, THash, TEqual>; + using const_iterator = typename TSetType::const_iterator; + using iterator = typename TSetType::iterator; + + THashSetImpl(const THash& hasher, const TEqual& equal) + : Set_(0, hasher, equal) + {} + + ui64 GetSize() const { + return Set_.size(); + } + + const_iterator Begin() const { + return Set_.begin(); + } + + const_iterator End() const { + return Set_.end(); + } + + bool IsValid(const_iterator iter) const { + return true; + } + + void Advance(const_iterator& iter) const { + ++iter; + } + + iterator Insert(const TKey& key, bool& isNew) { + auto res = Set_.emplace(key); + isNew = res.second; + return res.first; + } + + template <typename TSink> + void BatchInsert(std::span<TRobinHoodBatchRequestItem<TKey>> batchRequest, TSink&& sink) { + for (size_t index = 0; index < batchRequest.size(); ++index) { + bool isNew; + auto iter = Insert(batchRequest[index].GetKey(), isNew); + sink(index, iter, isNew); + } + } + + void CheckGrow() { + } + + const TKey& GetKey(const_iterator it) const { + return *it; + } + + char* GetMutablePayload(iterator it) const { + Y_UNUSED(it); + return nullptr; + } + + const char* GetPayload(const_iterator it) const { + Y_UNUSED(it); + return nullptr; + } + +private: + TSetType Set_; +}; + +#else +#define TDynamicHashMapImpl TRobinHoodHashMap +#define TFixedHashMapImpl TRobinHoodHashFixedMap +#define THashSetImpl TRobinHoodHashSet +#endif + +using TState8 = ui64; +static_assert(sizeof(TState8) == 8); + +using TState16 = std::pair<ui64, ui64>; +static_assert(sizeof(TState16) == 16); + +using TStateArena = void*; +static_assert(sizeof(TStateArena) == sizeof(void*)); + +struct TExternalFixedSizeKey { + mutable const char* Data; +}; + +struct TKey16 { + ui64 Lo; + ui64 Hi; +}; + +class TSSOKey { +public: + static constexpr size_t SSO_Length = 15; + static_assert(SSO_Length < 128); // should fit into 7 bits + +private: + struct TExternal { + ui64 Length_; + const char* Ptr_; + }; + + struct TInplace { + ui8 SmallLength_; + char Buffer_[SSO_Length]; + }; + +public: + TSSOKey(const TSSOKey& other) { + memcpy(U.A, other.U.A, SSO_Length + 1); + } + + TSSOKey& operator=(const TSSOKey& other) { + memcpy(U.A, other.U.A, SSO_Length + 1); + return *this; + } + + static bool CanBeInplace(TStringBuf data) { + return data.Size() + 1 <= sizeof(TSSOKey); + } + + static TSSOKey Inplace(TStringBuf data) { + Y_ASSERT(CanBeInplace(data)); + TSSOKey ret(1 | (data.Size() << 1), 0); + memcpy(ret.U.I.Buffer_, data.Data(), data.Size()); + return ret; + } + + static TSSOKey External(TStringBuf data) { + return TSSOKey(data.Size() << 1, data.Data()); + } + + bool IsInplace() const { + return U.I.SmallLength_ & 1; + } + + TStringBuf AsView() const { + if (IsInplace()) { + // inplace + return TStringBuf(U.I.Buffer_, U.I.SmallLength_ >> 1); + } else { + // external + return TStringBuf(U.E.Ptr_, U.E.Length_ >> 1); + } + } + + void UpdateExternalPointer(const char *ptr) const { + Y_ASSERT(!IsInplace()); + const_cast<TExternal&>(U.E).Ptr_ = ptr; + } + +private: + TSSOKey(ui64 length, const char* ptr) { + U.E.Length_ = length; + U.E.Ptr_ = ptr; + } + +private: + union { + TExternal E; + TInplace I; + char A[SSO_Length + 1]; + } U; +}; + +static_assert(sizeof(TSSOKey) == TSSOKey::SSO_Length + 1); + +} +} +} + +namespace std { + template <> + struct hash<NKikimr::NMiniKQL::TKey16> { + using argument_type = NKikimr::NMiniKQL::TKey16; + using result_type = size_t; + inline result_type operator()(argument_type const& s) const noexcept { + auto hasher = std::hash<ui64>(); + return hasher(s.Hi) * 31 + hasher(s.Lo); + } + }; + + template <> + struct equal_to<NKikimr::NMiniKQL::TKey16> { + using argument_type = NKikimr::NMiniKQL::TKey16; + bool operator()(argument_type x, argument_type y) const { + return x.Hi == y.Hi && x.Lo == y.Lo; + } + }; + + template <> + struct hash<NKikimr::NMiniKQL::TSSOKey> { + using argument_type = NKikimr::NMiniKQL::TSSOKey; + using result_type = size_t; + inline result_type operator()(argument_type const& s) const noexcept { + return std::hash<std::string_view>()(s.AsView()); + } + }; + + template <> + struct equal_to<NKikimr::NMiniKQL::TSSOKey> { + using argument_type = NKikimr::NMiniKQL::TSSOKey; + bool operator()(argument_type x, argument_type y) const { + return x.AsView() == y.AsView(); + } + bool operator()(argument_type x, TStringBuf y) const { + return x.AsView() == y; + } + using is_transparent = void; + }; + + template <> + struct hash<NKikimr::NMiniKQL::TExternalFixedSizeKey> { + using argument_type = NKikimr::NMiniKQL::TExternalFixedSizeKey; + using result_type = size_t; + hash(ui32 length) + : Length(length) + {} + + inline result_type operator()(argument_type const& s) const noexcept { + return std::hash<std::string_view>()(std::string_view(s.Data, Length)); + } + + const ui32 Length; + }; + + template <> + struct equal_to<NKikimr::NMiniKQL::TExternalFixedSizeKey> { + using argument_type = NKikimr::NMiniKQL::TExternalFixedSizeKey; + equal_to(ui32 length) + : Length(length) + {} + + bool operator()(argument_type x, argument_type y) const { + return memcmp(x.Data, y.Data, Length) == 0; + } + bool operator()(argument_type x, TStringBuf y) const { + Y_ASSERT(y.Size() <= Length); + return memcmp(x.Data, y.Data(), Length) == 0; + } + using is_transparent = void; + + const ui32 Length; + }; +} + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +template <typename T> +struct TAggParams { + std::unique_ptr<IPreparedBlockAggregator<T>> Prepared_; + ui32 Column_ = 0; + TType* StateType_ = nullptr; + TType* ReturnType_ = nullptr; + ui32 Hint_ = 0; +}; + +struct TKeyParams { + ui32 Index; + TType* Type; +}; + +size_t GetBitmapPopCount(const std::shared_ptr<arrow::ArrayData>& arr) { + size_t len = (size_t)arr->length; + MKQL_ENSURE(arr->GetNullCount() == 0, "Bitmap block should not have nulls"); + const ui8* src = arr->GetValues<ui8>(1); + return GetSparseBitmapPopCount(src, len); +} + +size_t CalcMaxBlockLenForOutput(TType* out) { + const auto outputType = AS_TYPE(TFlowType, out); + const auto wideComponents = GetWideComponents(outputType); + MKQL_ENSURE(wideComponents.size() > 0, "Expecting at least one output column"); + + size_t maxBlockItemSize = 0; + for (ui32 i = 0; i < wideComponents.size() - 1; ++i) { + auto type = AS_TYPE(TBlockType, wideComponents[i]); + MKQL_ENSURE(type->GetShape() != TBlockType::EShape::Scalar, "Expecting block type"); + maxBlockItemSize = std::max(maxBlockItemSize, CalcMaxBlockItemSize(type->GetItemType())); + } + + return CalcBlockLen(maxBlockItemSize); +} + +class TBlockCombineAllWrapperCodegenBase { +protected: +#ifndef MKQL_DISABLE_CODEGEN + class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> { + private: + using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>; + llvm::PointerType*const PointerType; + llvm::IntegerType*const IsFinishedType; + public: + std::vector<llvm::Type*> GetFieldsArray() { + std::vector<llvm::Type*> result = TBase::GetFields(); + result.emplace_back(PointerType); + result.emplace_back(IsFinishedType); + return result; + } + + llvm::Constant* GetPointer() { + return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0); + } + + llvm::Constant* GetIsFinished() { + return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1); + } + + TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width) + : TBase(context) + , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width))) + , IsFinishedType(llvm::Type::getInt1Ty(Context)) + {} + }; + + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValuesImpl(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block, + IComputationWideFlowNode* flow, size_t width, size_t aggCount, + uintptr_t getStateMethodPtr, uint64_t makeStateMethodPtr, + uintptr_t processInputMethodPtr, uintptr_t makeOutputMethodPtr) const { + auto& context = ctx.Codegen.GetContext(); + + const auto valueType = Type::getInt128Ty(context); + const auto statusType = Type::getInt32Ty(context); + const auto indexType = Type::getInt64Ty(context); + const auto flagType = Type::getInt1Ty(context); + const auto arrayType = ArrayType::get(valueType, width); + const auto ptrValuesType = PointerType::getUnqual(arrayType); + + TLLVMFieldsStructureState stateFields(context, width); + const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); + const auto statePtrType = PointerType::getUnqual(stateType); + + const auto atTop = &ctx.Func->getEntryBlock().back(); + + const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), getStateMethodPtr); + const auto getType = FunctionType::get(valueType, {statePtrType, indexType}, false); + const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop); + + const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop); + new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop); + + const auto make = BasicBlock::Create(context, "make", ctx.Func); + const auto main = BasicBlock::Create(context, "main", ctx.Func); + const auto read = BasicBlock::Create(context, "read", ctx.Func); + const auto good = BasicBlock::Create(context, "good", ctx.Func); + const auto work = BasicBlock::Create(context, "work", ctx.Func); + const auto over = BasicBlock::Create(context, "over", ctx.Func); + + BranchInst::Create(make, main, IsInvalid(statePtr, block), block); + block = make; + + const auto ptrType = PointerType::getUnqual(StructType::get(context)); + const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), makeStateMethodPtr); + const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false); + const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); + CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block); + + BranchInst::Create(main, block); + + block = main; + + const auto state = new LoadInst(valueType, statePtr, "state", block); + const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block); + const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block); + + const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block); + const auto finished = new LoadInst(flagType, finishedPtr, "finished", block); + + const auto result = PHINode::Create(statusType, 3U, "result", over); + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block); + + BranchInst::Create(over, read, finished, block); + + block = read; + + const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block); + const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block); + SafeUnRefUnboxed(values, ctx, block); + + const auto getres = GetNodeValues(flow, ctx, block); + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block); + + const auto way = SwitchInst::Create(getres.first, good, 2U, block); + way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), work); + way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over); + + block = good; + + Value* array = UndefValue::get(arrayType); + for (auto idx = 0U; idx < getres.second.size(); ++idx) { + const auto value = getres.second[idx](ctx, block); + AddRefBoxed(value, ctx, block); + array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block); + } + new StoreInst(array, values, block); + + const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), processInputMethodPtr); + const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false); + const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block); + CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block); + + BranchInst::Create(read, block); + + block = work; + + const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), makeOutputMethodPtr); + const auto makeOutputType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false); + const auto makeOutputPtr = CastInst::Create(Instruction::IntToPtr, makeOutputFunc, PointerType::getUnqual(makeOutputType), "make_output_func", block); + const auto hasData = CallInst::Create(makeOutputType, makeOutputPtr, {stateArg, ctx.GetFactory()}, "make_output", block); + const auto output = SelectInst::Create(hasData, ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), "output", block); + new StoreInst(stateArg, stateOnStack, block); + + result->addIncoming(output, block); + BranchInst::Create(over, block); + + block = over; + + ICodegeneratorInlineWideNode::TGettersList getters(aggCount); + for (size_t idx = 0U; idx < getters.size(); ++idx) { + getters[idx] = [idx, getType, getPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) { + Y_UNUSED(ctx); + const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block); + return CallInst::Create(getType, getPtr, {stateArg, ConstantInt::get(indexType, idx)}, "get", block); + }; + } + return {result, std::move(getters)}; + } +#endif +}; + +class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>, + protected TBlockCombineAllWrapperCodegenBase { +using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>; +public: + TBlockCombineAllWrapper(TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional<ui32> filterColumn, + size_t width, + std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams) + : TBaseComputation(mutables, flow, EValueRepresentation::Boxed) + , Flow_(flow) + , FilterColumn_(filterColumn) + , Width_(width) + , AggsParams_(std::move(aggsParams)) + , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) + { + MKQL_ENSURE(Width_ > 0, "Missing block length column"); + } + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, + TComputationContext& ctx, + NUdf::TUnboxedValue*const* output) const + { + auto& s = GetState(state, ctx); + if (s.IsFinished_) + return EFetchResult::Finish; + + for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) { + switch (Flow_->FetchValues(ctx, fields)) { + case EFetchResult::Yield: + return EFetchResult::Yield; + case EFetchResult::One: + s.ProcessInput(); + continue; + case EFetchResult::Finish: + break; + } + if (s.MakeOutput()) { + for (size_t i = 0; i < AggsParams_.size(); ++i) { + if (const auto out = output[i]) { + *out = s.Get(i); + } + } + return EFetchResult::One; + } + return EFetchResult::Finish; + } + } +#ifndef MKQL_DISABLE_CODEGEN + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, AggsParams_.size(), + GetMethodPtr(&TState::Get), GetMethodPtr(&TBlockCombineAllWrapper::MakeState), + GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::MakeOutput)); + } +#endif +private: + struct TState : public TComputationValue<TState> { + NUdf::TUnboxedValue* Pointer_ = nullptr; + bool IsFinished_ = false; + bool HasValues_ = false; + TUnboxedValueVector Values_; + std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_; + std::vector<char> AggStates_; + const std::optional<ui32> FilterColumn_; + const size_t Width_; + + TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx) + : TComputationValue(memInfo) + , Values_(std::max(width, params.size())) + , FilterColumn_(filterColumn) + , Width_(width) + { + Pointer_ = Values_.data(); + + ui32 totalStateSize = 0; + for (const auto& p : params) { + Aggs_.emplace_back(p.Prepared_->Make(ctx)); + MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch"); + totalStateSize += Aggs_.back()->StateSize; + } + + AggStates_.resize(totalStateSize); + char* ptr = AggStates_.data(); + for (const auto& agg : Aggs_) { + agg->InitState(ptr); + ptr += agg->StateSize; + } + } + + void ProcessInput() { + const ui64 batchLength = TArrowBlock::From(Values_[Width_ - 1U]).GetDatum().scalar_as<arrow::UInt64Scalar>().value; + if (!batchLength) { + return; + } + + std::optional<ui64> filtered; + if (FilterColumn_) { + const auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum(); + if (filterDatum.is_scalar()) { + if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) { + return; + } + } else { + const ui64 popCount = GetBitmapPopCount(filterDatum.array()); + if (popCount == 0) { + return; + } + + if (popCount < batchLength) { + filtered = popCount; + } + } + } + + HasValues_ = true; + char* ptr = AggStates_.data(); + for (size_t i = 0; i < Aggs_.size(); ++i) { + Aggs_[i]->AddMany(ptr, Values_.data(), batchLength, filtered); + ptr += Aggs_[i]->StateSize; + } + } + + bool MakeOutput() { + IsFinished_ = true; + if (!HasValues_) + return false; + + char* ptr = AggStates_.data(); + for (size_t i = 0; i < Aggs_.size(); ++i) { + Values_[i] = Aggs_[i]->FinishOne(ptr); + Aggs_[i]->DestroyState(ptr); + ptr += Aggs_[i]->StateSize; + } + return true; + } + + NUdf::TUnboxedValuePod Get(size_t index) const { + return Values_[index]; + } + }; + void RegisterDependencies() const final { + FlowDependsOn(Flow_); + } + + void MakeState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { + state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx); + } + + TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { + if (state.IsInvalid()) { + MakeState(state, ctx); + + auto& s = *static_cast<TState*>(state.AsBoxed().Get()); + const auto fields = ctx.WideFields.data() + WideFieldsIndex_; + for (size_t i = 0; i < Width_; ++i) { + fields[i] = &s.Values_[i]; + } + return s; + } + return *static_cast<TState*>(state.AsBoxed().Get()); + } +private: + IComputationWideFlowNode *const Flow_; + const std::optional<ui32> FilterColumn_; + const size_t Width_; + const std::vector<TAggParams<IBlockAggregatorCombineAll>> AggsParams_; + const size_t WideFieldsIndex_; +}; + +template <typename T> +T MakeKey(TStringBuf s, ui32 keyLength) { + Y_UNUSED(keyLength); + Y_ASSERT(s.Size() <= sizeof(T)); + return *(const T*)s.Data(); +} + +template <> +TSSOKey MakeKey(TStringBuf s, ui32 keyLength) { + Y_UNUSED(keyLength); + if (TSSOKey::CanBeInplace(s)) { + return TSSOKey::Inplace(s); + } else { + return TSSOKey::External(s); + } +} + +template <> +TExternalFixedSizeKey MakeKey(TStringBuf s, ui32 keyLength) { + Y_ASSERT(s.Size() == keyLength); + return { s.Data() }; +} + +void MoveKeyToArena(const TSSOKey& key, TPagedArena& arena, ui32 keyLength) { + Y_UNUSED(keyLength); + if (key.IsInplace()) { + return; + } + + auto view = key.AsView(); + auto ptr = (char*)arena.Alloc(view.Size()); + memcpy(ptr, view.Data(), view.Size()); + key.UpdateExternalPointer(ptr); +} + +void MoveKeyToArena(const TExternalFixedSizeKey& key, TPagedArena& arena, ui32 keyLength) { + auto ptr = (char*)arena.Alloc(keyLength); + memcpy(ptr, key.Data, keyLength); + key.Data = ptr; +} + +template <typename T> +TStringBuf GetKeyView(const T& key, ui32 keyLength) { + Y_UNUSED(keyLength); + return TStringBuf((const char*)&key, sizeof(T)); +} + +template <> +TStringBuf GetKeyView(const TSSOKey& key, ui32 keyLength) { + Y_UNUSED(keyLength); + return key.AsView(); +} + +template <> +TStringBuf GetKeyView(const TExternalFixedSizeKey& key, ui32 keyLength) { + return TStringBuf(key.Data, keyLength); +} + +template <typename T> +std::equal_to<T> MakeEqual(ui32 keyLength) { + Y_UNUSED(keyLength); + return std::equal_to<T>(); +} + +template <> +std::equal_to<TExternalFixedSizeKey> MakeEqual(ui32 keyLength) { + return std::equal_to<TExternalFixedSizeKey>(keyLength); +} + +template <typename T> +std::hash<T> MakeHash(ui32 keyLength) { + Y_UNUSED(keyLength); + return std::hash<T>(); +} + +template <> +std::hash<TExternalFixedSizeKey> MakeHash(ui32 keyLength) { + return std::hash<TExternalFixedSizeKey>(keyLength); +} + +class THashedWrapperCodegenBase { +protected: +#ifndef MKQL_DISABLE_CODEGEN + class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState { + private: + using TBase = TLLVMFieldsStructureBlockState; + llvm::IntegerType*const WritingOutputType; + llvm::IntegerType*const IsFinishedType; + protected: + using TBase::Context; + public: + std::vector<llvm::Type*> GetFieldsArray() { + std::vector<llvm::Type*> result = TBase::GetFieldsArray(); + result.emplace_back(WritingOutputType); + result.emplace_back(IsFinishedType); + return result; + } + + llvm::Constant* GetWritingOutput() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields); + } + + llvm::Constant* GetIsFinished() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 1); + } + + TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width) + : TBase(context, width) + , WritingOutputType(Type::getInt1Ty(Context)) + , IsFinishedType(Type::getInt1Ty(Context)) + {} + }; + + Y_NO_INLINE ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValuesImpl( + const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block, + IComputationWideFlowNode* flow, size_t width, size_t outputWidth, + uintptr_t getStateMethodPtr, uintptr_t makeStateMethodPtr, + uintptr_t processInputMethodPtr, uintptr_t finishMethodPtr, + uintptr_t fillOutputMethodPtr, uintptr_t sliceMethodPtr) const { + auto& context = ctx.Codegen.GetContext(); + + const auto valueType = Type::getInt128Ty(context); + const auto statusType = Type::getInt32Ty(context); + const auto indexType = Type::getInt64Ty(context); + const auto flagType = Type::getInt1Ty(context); + const auto arrayType = ArrayType::get(valueType, width); + const auto ptrValuesType = PointerType::getUnqual(arrayType); + + TLLVMFieldsStructureState stateFields(context, width); + const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); + const auto statePtrType = PointerType::getUnqual(stateType); + + const auto atTop = &ctx.Func->getEntryBlock().back(); + + const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), getStateMethodPtr); + const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false); + const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop); + + const auto heightPtr = new AllocaInst(indexType, 0U, "height_ptr", atTop); + const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop); + + new StoreInst(ConstantInt::get(indexType, 0), heightPtr, atTop); + new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop); + + const auto make = BasicBlock::Create(context, "make", ctx.Func); + const auto main = BasicBlock::Create(context, "main", ctx.Func); + const auto more = BasicBlock::Create(context, "more", ctx.Func); + const auto test = BasicBlock::Create(context, "test", ctx.Func); + const auto read = BasicBlock::Create(context, "read", ctx.Func); + const auto good = BasicBlock::Create(context, "good", ctx.Func); + const auto stop = BasicBlock::Create(context, "stop", ctx.Func); + const auto work = BasicBlock::Create(context, "work", ctx.Func); + const auto fill = BasicBlock::Create(context, "fill", ctx.Func); + const auto over = BasicBlock::Create(context, "over", ctx.Func); + + BranchInst::Create(make, main, IsInvalid(statePtr, block), block); + block = make; + + const auto ptrType = PointerType::getUnqual(StructType::get(context)); + const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), makeStateMethodPtr); + const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false); + const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); + CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block); + + BranchInst::Create(main, block); + + block = main; + + const auto state = new LoadInst(valueType, statePtr, "state", block); + const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block); + const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block); + const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block); + + const auto count = new LoadInst(indexType, countPtr, "count", block); + const auto none = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, ConstantInt::get(indexType, 0), "none", block); + + BranchInst::Create(more, fill, none, block); + + block = more; + + const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block); + const auto finished = new LoadInst(flagType, finishedPtr, "finished", block); + + const auto result = PHINode::Create(statusType, 5U, "result", over); + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block); + + BranchInst::Create(over, test, finished, block); + + block = test; + + const auto writingOutputPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetWritingOutput() }, "writing_output_ptr", block); + const auto writingOutput = new LoadInst(flagType, writingOutputPtr, "writing_output", block); + + BranchInst::Create(work, read, writingOutput, block); + + block = read; + + const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block); + const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block); + SafeUnRefUnboxed(values, ctx, block); + + const auto getres = GetNodeValues(flow, ctx, block); + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block); + + const auto way = SwitchInst::Create(getres.first, good, 2U, block); + way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), stop); + way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over); + + block = good; + + Value* array = UndefValue::get(arrayType); + for (auto idx = 0U; idx < getres.second.size(); ++idx) { + const auto value = getres.second[idx](ctx, block); + AddRefBoxed(value, ctx, block); + array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block); + } + new StoreInst(array, values, block); + + const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), processInputMethodPtr); + const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false); + const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block); + CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block); + + BranchInst::Create(read, block); + + block = stop; + + const auto finishFunc = ConstantInt::get(Type::getInt64Ty(context), finishMethodPtr); + const auto finishType = FunctionType::get(flagType, {statePtrType}, false); + const auto finishPtr = CastInst::Create(Instruction::IntToPtr, finishFunc, PointerType::getUnqual(finishType), "finish_func", block); + const auto hasOutput = CallInst::Create(finishType, finishPtr, {stateArg}, "has_output", block); + + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block); + + BranchInst::Create(work, over, hasOutput, block); + + block = work; + + const auto fillBlockFunc = ConstantInt::get(Type::getInt64Ty(context), fillOutputMethodPtr); + const auto fillBlockType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false); + const auto fillBlockPtr = CastInst::Create(Instruction::IntToPtr, fillBlockFunc, PointerType::getUnqual(fillBlockType), "fill_output_func", block); + const auto hasData = CallInst::Create(fillBlockType, fillBlockPtr, {stateArg, ctx.GetFactory()}, "fill_output", block); + + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block); + + BranchInst::Create(fill, over, hasData, block); + + block = fill; + + const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), sliceMethodPtr); + const auto sliceType = FunctionType::get(indexType, {statePtrType}, false); + const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block); + const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block); + new StoreInst(slice, heightPtr, block); + new StoreInst(stateArg, stateOnStack, block); + + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block); + + BranchInst::Create(over, block); + + block = over; + + ICodegeneratorInlineWideNode::TGettersList getters(outputWidth); + for (size_t idx = 0U; idx < getters.size(); ++idx) { + getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) { + const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block); + const auto heightArg = new LoadInst(indexType, heightPtr, "height", block); + return CallInst::Create(getType, getPtr, {stateArg, heightArg, ctx.GetFactory(), ConstantInt::get(indexType, idx)}, "get", block); + }; + } + return {result, std::move(getters)}; + } +#endif +}; + +template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived> +class THashedWrapperBase : public TStatefulWideFlowCodegeneratorNode<TDerived>, + protected THashedWrapperCodegenBase +{ + using TComputationBase = TStatefulWideFlowCodegeneratorNode<TDerived>; + static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value; +public: + THashedWrapperBase(TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional<ui32> filterColumn, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<TAggregator>>&& aggsParams, + ui32 streamIndex, + std::vector<std::vector<ui32>>&& streams) + : TComputationBase(mutables, flow, EValueRepresentation::Boxed) + , Flow_(flow) + , FilterColumn_(filterColumn) + , Width_(width) + , OutputWidth_(keys.size() + aggsParams.size() + 1) + , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) + , Keys_(keys) + , MaxBlockLen_(maxBlockLen) + , AggsParams_(std::move(aggsParams)) + , KeyLength_(keyLength) + , StreamIndex_(streamIndex) + , Streams_(std::move(streams)) + { + MKQL_ENSURE(Width_ > 0, "Missing block length column"); + if constexpr (UseFilter) { + MKQL_ENSURE(filterColumn, "Missing filter column"); + MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize"); + } else { + MKQL_ENSURE(!filterColumn, "Unexpected filter column"); + } + } + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, + TComputationContext& ctx, + NUdf::TUnboxedValue*const* output) const + { + auto& s = GetState(state, ctx); + if (!s.Count) { + if (s.IsFinished_) + return EFetchResult::Finish; + + while (!s.WritingOutput_) { + const auto fields = ctx.WideFields.data() + WideFieldsIndex_; + s.Values_.assign(s.Values_.size(), NUdf::TUnboxedValuePod()); + switch (Flow_->FetchValues(ctx, fields)) { + case EFetchResult::Yield: + return EFetchResult::Yield; + case EFetchResult::One: + s.ProcessInput(ctx.HolderFactory); + continue; + case EFetchResult::Finish: + break; + } + + if (s.Finish()) + break; + else + return EFetchResult::Finish; + } + + if (!s.FillOutput(ctx.HolderFactory)) + return EFetchResult::Finish; + } + + const auto sliceSize = s.Slice(); + for (size_t i = 0; i < OutputWidth_; ++i) { + if (const auto out = output[i]) { + *out = s.Get(sliceSize, ctx.HolderFactory, i); + } + } + return EFetchResult::One; + } +#ifndef MKQL_DISABLE_CODEGEN + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, OutputWidth_, + GetMethodPtr(&TState::Get), GetMethodPtr(&THashedWrapperBase::MakeState), + GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::Finish), + GetMethodPtr(&TState::FillOutput), GetMethodPtr(&TState::Slice)); + } +#endif +private: + struct TState : public TBlockState { + bool WritingOutput_ = false; + bool IsFinished_ = false; + + const std::optional<ui32> FilterColumn_; + const std::vector<TKeyParams> Keys_; + const std::vector<TAggParams<TAggregator>>& AggsParams_; + const ui32 KeyLength_; + const ui32 StreamIndex_; + const std::vector<std::vector<ui32>> Streams_; + const size_t MaxBlockLen_; + + template<typename TKeyType> + struct THashSettings { + static constexpr bool CacheHash = std::is_same_v<TKeyType, TSSOKey>; + }; + using TDynMapImpl = TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>; + using TSetImpl = THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>; + using TFixedMapImpl = TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>; + + ui64 BatchNum_ = 0; + TUnboxedValueVector Values_; + std::vector<std::unique_ptr<TAggregator>> Aggs_; + std::vector<ui32> AggStateOffsets_; + TUnboxedValueVector UnwrappedValues_; + std::vector<std::unique_ptr<IBlockReader>> Readers_; + std::vector<std::unique_ptr<IArrayBuilder>> Builders_; + std::vector<std::unique_ptr<IAggColumnBuilder>> AggBuilders_; + bool HasValues_ = false; + ui32 TotalStateSize_ = 0; + size_t OutputBlockSize_ = 0; + std::unique_ptr<TDynMapImpl> HashMap_; + typename TDynMapImpl::const_iterator HashMapIt_; + std::unique_ptr<TSetImpl> HashSet_; + typename TSetImpl::const_iterator HashSetIt_; + std::unique_ptr<TFixedMapImpl> HashFixedMap_; + typename TFixedMapImpl::const_iterator HashFixedMapIt_; + TPagedArena Arena_; + + TState(TMemoryUsageInfo* memInfo, ui32 keyLength, ui32 streamIndex, size_t width, size_t outputWidth, std::optional<ui32> filterColumn, const std::vector<TAggParams<TAggregator>>& params, + const std::vector<std::vector<ui32>>& streams, const std::vector<TKeyParams>& keys, size_t maxBlockLen, TComputationContext& ctx) + : TBlockState(memInfo, outputWidth) + , FilterColumn_(filterColumn) + , Keys_(keys) + , AggsParams_(params) + , KeyLength_(keyLength) + , StreamIndex_(streamIndex) + , Streams_(streams) + , MaxBlockLen_(maxBlockLen) + , Values_(width) + , UnwrappedValues_(width) + , Readers_(keys.size()) + , Builders_(keys.size()) + , Arena_(TlsAllocState) + { + Pointer_ = Values_.data(); + for (size_t i = 0; i < Keys_.size(); ++i) { + auto itemType = AS_TYPE(TBlockType, Keys_[i].Type)->GetItemType(); + Readers_[i] = NYql::NUdf::MakeBlockReader(TTypeInfoHelper(), itemType); + Builders_[i] = NYql::NUdf::MakeArrayBuilder(TTypeInfoHelper(), itemType, ctx.ArrowMemoryPool, MaxBlockLen_, &ctx.Builder->GetPgBuilder()); + } + + if constexpr (Many) { + TotalStateSize_ += Streams_.size(); + } + + for (const auto& p : AggsParams_) { + Aggs_.emplace_back(p.Prepared_->Make(ctx)); + MKQL_ENSURE(Aggs_.back()->StateSize == p.Prepared_->StateSize, "State size mismatch"); + AggStateOffsets_.emplace_back(TotalStateSize_); + TotalStateSize_ += Aggs_.back()->StateSize; + } + + auto equal = MakeEqual<TKey>(KeyLength_); + auto hasher = MakeHash<TKey>(KeyLength_); + if constexpr (UseSet) { + MKQL_ENSURE(params.empty(), "Only keys are supported"); + HashSet_ = std::make_unique<THashSetImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(hasher, equal); + } else { + if (!InlineAggState) { + HashFixedMap_ = std::make_unique<TFixedHashMapImpl<TKey, TFixedAggState, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(hasher, equal); + } else { + HashMap_ = std::make_unique<TDynamicHashMapImpl<TKey, std::equal_to<TKey>, std::hash<TKey>, TMKQLAllocator<char>, THashSettings<TKey>>>(TotalStateSize_, hasher, equal); + } + } + } + + void ProcessInput(const THolderFactory& holderFactory) { + ++BatchNum_; + const auto batchLength = TArrowBlock::From(Values_.back()).GetDatum().scalar_as<arrow::UInt64Scalar>().value; + if (!batchLength) { + return; + } + + const ui8* filterBitmap = nullptr; + if constexpr (UseFilter) { + auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum(); + if (filterDatum.is_scalar()) { + if (!filterDatum.template scalar_as<arrow::UInt8Scalar>().value) { + return; + } + } else { + const auto& arr = filterDatum.array(); + filterBitmap = arr->template GetValues<ui8>(1); + ui64 popCount = GetBitmapPopCount(arr); + if (popCount == 0) { + return; + } + } + } + + const ui32* streamIndexData = nullptr; + TMaybe<ui32> streamIndexScalar; + if constexpr (Many) { + auto streamIndexDatum = TArrowBlock::From(Values_[StreamIndex_]).GetDatum(); + if (streamIndexDatum.is_scalar()) { + streamIndexScalar = streamIndexDatum.template scalar_as<arrow::UInt32Scalar>().value; + } else { + MKQL_ENSURE(streamIndexDatum.is_array(), "Expected array"); + streamIndexData = streamIndexDatum.array()->template GetValues<ui32>(1); + } + UnwrappedValues_ = Values_; + for (const auto& p : AggsParams_) { + const auto& columnDatum = TArrowBlock::From(UnwrappedValues_[p.Column_]).GetDatum(); + MKQL_ENSURE(columnDatum.is_array(), "Expected array"); + UnwrappedValues_[p.Column_] = holderFactory.CreateArrowBlock(Unwrap(*columnDatum.array(), p.StateType_)); + } + } + + HasValues_ = true; + std::vector<arrow::Datum> keysDatum; + keysDatum.reserve(Keys_.size()); + for (ui32 i = 0; i < Keys_.size(); ++i) { + keysDatum.emplace_back(TArrowBlock::From(Values_[Keys_[i].Index]).GetDatum()); + } + + std::array<TOutputBuffer, PrefetchBatchSize> out; + for (ui32 i = 0; i < PrefetchBatchSize; ++i) { + out[i].Resize(sizeof(TKey)); + } + + std::array<TRobinHoodBatchRequestItem<TKey>, PrefetchBatchSize> insertBatch; + std::array<ui64, PrefetchBatchSize> insertBatchRows; + std::array<char*, PrefetchBatchSize> insertBatchPayloads; + std::array<bool, PrefetchBatchSize> insertBatchIsNew; + ui32 insertBatchLen = 0; + + const auto processInsertBatch = [&]() { + for (ui32 i = 0; i < insertBatchLen; ++i) { + auto& r = insertBatch[i]; + TStringBuf str = out[i].Finish(); + TKey key = MakeKey<TKey>(str, KeyLength_); + r.ConstructKey(key); + } + + if constexpr (UseSet) { + HashSet_->BatchInsert({insertBatch.data(), insertBatchLen},[&](size_t index, typename TState::TSetImpl::iterator iter, bool isNew) { + Y_UNUSED(index); + if (isNew) { + if constexpr (std::is_same<TKey, TSSOKey>::value || std::is_same<TKey, TExternalFixedSizeKey>::value) { + MoveKeyToArena(HashSet_->GetKey(iter), Arena_, KeyLength_); + } + } + }); + } else { + using THashTable = std::conditional_t<InlineAggState, typename TState::TDynMapImpl, typename TState::TFixedMapImpl>; + THashTable* hash; + if constexpr (!InlineAggState) { + hash = HashFixedMap_.get(); + } else { + hash = HashMap_.get(); + } + + hash->BatchInsert({insertBatch.data(), insertBatchLen}, [&](size_t index, typename THashTable::iterator iter, bool isNew) { + if (isNew) { + if constexpr (std::is_same<TKey, TSSOKey>::value || std::is_same<TKey, TExternalFixedSizeKey>::value) { + MoveKeyToArena(hash->GetKey(iter), Arena_, KeyLength_); + } + } + + if constexpr (UseArena) { + // prefetch payloads only + auto payload = hash->GetPayload(iter); + char* ptr; + if (isNew) { + ptr = (char*)Arena_.Alloc(TotalStateSize_); + *(char**)payload = ptr; + } else { + ptr = *(char**)payload; + } + + insertBatchIsNew[index] = isNew; + insertBatchPayloads[index] = ptr; + NYql::PrefetchForWrite(ptr); + } else { + // process insert + auto payload = (char*)hash->GetPayload(iter); + auto row = insertBatchRows[index]; + ui32 streamIndex = 0; + if constexpr (Many) { + streamIndex = streamIndexScalar ? *streamIndexScalar : streamIndexData[row]; + } + + Insert(row, payload, isNew, streamIndex); + } + }); + + if constexpr (UseArena) { + for (ui32 i = 0; i < insertBatchLen; ++i) { + auto row = insertBatchRows[i]; + ui32 streamIndex = 0; + if constexpr (Many) { + streamIndex = streamIndexScalar ? *streamIndexScalar : streamIndexData[row]; + } + + bool isNew = insertBatchIsNew[i]; + char* payload = insertBatchPayloads[i]; + Insert(row, payload, isNew, streamIndex); + } + } + } + }; + + for (ui64 row = 0; row < batchLength; ++row) { + if constexpr (UseFilter) { + if (filterBitmap && !filterBitmap[row]) { + continue; + } + } + + // encode key + out[insertBatchLen].Rewind(); + for (ui32 i = 0; i < keysDatum.size(); ++i) { + if (keysDatum[i].is_scalar()) { + // TODO: more efficient code when grouping by scalar + Readers_[i]->SaveScalarItem(*keysDatum[i].scalar(), out[insertBatchLen]); + } else { + Readers_[i]->SaveItem(*keysDatum[i].array(), row, out[insertBatchLen]); + } + } + + insertBatchRows[insertBatchLen] = row; + ++insertBatchLen; + if (insertBatchLen == PrefetchBatchSize) { + processInsertBatch(); + insertBatchLen = 0; + } + } + + processInsertBatch(); + } + + bool Finish() { + if (!HasValues_) { + IsFinished_ = true; + return false; + } + + WritingOutput_ = true; + OutputBlockSize_ = 0; + PrepareAggBuilders(); + + if constexpr (UseSet) { + HashSetIt_ = HashSet_->Begin(); + } else { + if constexpr (!InlineAggState) { + HashFixedMapIt_ = HashFixedMap_->Begin(); + } else { + HashMapIt_ = HashMap_->Begin(); + } + } + return true; + } + + bool FillOutput(const THolderFactory& holderFactory) { + bool exit = false; + while (WritingOutput_) { + if constexpr (UseSet) { + for (;!exit && HashSetIt_ != HashSet_->End(); HashSet_->Advance(HashSetIt_)) { + if (!HashSet_->IsValid(HashSetIt_)) { + continue; + } + + if (OutputBlockSize_ == MaxBlockLen_) { + Flush(false, holderFactory); + //return EFetchResult::One; + exit = true; + break; + } + + const TKey& key = HashSet_->GetKey(HashSetIt_); + TInputBuffer in(GetKeyView<TKey>(key, KeyLength_)); + for (auto& kb : Builders_) { + kb->Add(in); + } + ++OutputBlockSize_; + } + break; + } else { + const bool done = InlineAggState ? + Iterate(*HashMap_, HashMapIt_) : + Iterate(*HashFixedMap_, HashFixedMapIt_); + if (done) { + break; + } + Flush(false, holderFactory); + exit = true; + break; + } + } + + if (!exit) { + IsFinished_ = true; + WritingOutput_ = false; + if (!OutputBlockSize_) + return false; + Flush(true, holderFactory); + } + + FillArrays(); + return true; + } + private: + void PrepareAggBuilders() { + if constexpr (!UseSet) { + AggBuilders_.clear(); + AggBuilders_.reserve(Aggs_.size()); + for (const auto& a : Aggs_) { + if constexpr (Finalize) { + AggBuilders_.emplace_back(a->MakeResultBuilder(MaxBlockLen_)); + } else { + AggBuilders_.emplace_back(a->MakeStateBuilder(MaxBlockLen_)); + } + } + } + } + + void Flush(bool final, const THolderFactory& holderFactory) { + if (!OutputBlockSize_) { + return; + } + + for (size_t i = 0; i < Builders_.size(); ++i) { + Values[i] = holderFactory.CreateArrowBlock(Builders_[i]->Build(final)); + } + + if constexpr (!UseSet) { + for (size_t i = 0; i < Aggs_.size(); ++i) { + Values[Builders_.size() + i] = AggBuilders_[i]->Build(); + } + if (!final) { + PrepareAggBuilders(); + } + } + + Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(OutputBlockSize_))); + OutputBlockSize_ = 0; + } + + void Insert(ui64 row, char* payload, bool isNew, ui32 currentStreamIndex) const { + char* ptr = payload; + + if (isNew) { + if constexpr (Many) { + static_assert(Finalize); + MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); + memset(ptr, 0, Streams_.size()); + ptr[currentStreamIndex] = 1; + + for (auto i : Streams_[currentStreamIndex]) { + + Aggs_[i]->LoadState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row); + } + } else { + for (size_t i = 0; i < Aggs_.size(); ++i) { + if constexpr (Finalize) { + Aggs_[i]->LoadState(ptr, BatchNum_, Values_.data(), row); + } else { + Aggs_[i]->InitKey(ptr, BatchNum_, Values_.data(), row); + } + + ptr += Aggs_[i]->StateSize; + } + } + } else { + if constexpr (Many) { + static_assert(Finalize); + MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); + + bool isNewStream = !ptr[currentStreamIndex]; + ptr[currentStreamIndex] = 1; + + for (auto i : Streams_[currentStreamIndex]) { + + if (isNewStream) { + Aggs_[i]->LoadState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row); + } else { + Aggs_[i]->UpdateState(ptr + AggStateOffsets_[i], BatchNum_, UnwrappedValues_.data(), row); + } + } + } else { + for (size_t i = 0; i < Aggs_.size(); ++i) { + if constexpr (Finalize) { + Aggs_[i]->UpdateState(ptr, BatchNum_, Values_.data(), row); + } else { + Aggs_[i]->UpdateKey(ptr, BatchNum_, Values_.data(), row); + } + + ptr += Aggs_[i]->StateSize; + } + } + } + } + + template <typename THash> + bool Iterate(THash& hash, typename THash::const_iterator& iter) { + MKQL_ENSURE(WritingOutput_, "Supposed to be called at the end"); + std::array<typename THash::const_iterator, PrefetchBatchSize> iters; + ui32 itersLen = 0; + auto iterateBatch = [&]() { + for (ui32 i = 0; i < itersLen; ++i) { + auto iter = iters[i]; + const TKey& key = hash.GetKey(iter); + auto payload = (char*)hash.GetPayload(iter); + char* ptr; + if constexpr (UseArena) { + ptr = *(char**)payload; + } else { + ptr = payload; + } + + TInputBuffer in(GetKeyView<TKey>(key, KeyLength_)); + for (auto& kb : Builders_) { + kb->Add(in); + } + + if constexpr (Many) { + for (ui32 i = 0; i < Streams_.size(); ++i) { + MKQL_ENSURE(ptr[i], "Missing partial aggregation state for stream #" << i); + } + + ptr += Streams_.size(); + } + + for (size_t i = 0; i < Aggs_.size(); ++i) { + AggBuilders_[i]->Add(ptr); + Aggs_[i]->DestroyState(ptr); + + ptr += Aggs_[i]->StateSize; + } + } + }; + + for (; iter != hash.End(); hash.Advance(iter)) { + if (!hash.IsValid(iter)) { + continue; + } + + if (OutputBlockSize_ == MaxBlockLen_) { + iterateBatch(); + return false; + } + + if (itersLen == iters.size()) { + iterateBatch(); + itersLen = 0; + } + + iters[itersLen] = iter; + ++itersLen; + ++OutputBlockSize_; + if constexpr (UseArena) { + auto payload = (char*)hash.GetPayload(iter); + auto ptr = *(char**)payload; + NYql::PrefetchForWrite(ptr); + } + + if constexpr (std::is_same<TKey, TSSOKey>::value) { + const auto& key = hash.GetKey(iter); + if (!key.IsInplace()) { + NYql::PrefetchForRead(key.AsView().Data()); + } + } else if constexpr (std::is_same<TKey, TExternalFixedSizeKey>::value) { + const auto& key = hash.GetKey(iter); + NYql::PrefetchForRead(key.Data); + } + } + + iterateBatch(); + return true; + } + }; +private: + void RegisterDependencies() const final { + this->FlowDependsOn(Flow_); + } + + void MakeState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { + state = ctx.HolderFactory.Create<TState>(KeyLength_, StreamIndex_, Width_, OutputWidth_, FilterColumn_, AggsParams_, Streams_, Keys_, MaxBlockLen_, ctx); + } + + TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { + if (state.IsInvalid()) { + MakeState(state, ctx); + + auto& s = *static_cast<TState*>(state.AsBoxed().Get()); + const auto fields = ctx.WideFields.data() + WideFieldsIndex_; + for (size_t i = 0; i < s.Values_.size(); ++i) { + fields[i] = &s.Values_[i]; + } + return s; + } + return *static_cast<TState*>(state.AsBoxed().Get()); + } + + IComputationWideFlowNode *const Flow_; + const std::optional<ui32> FilterColumn_; + const size_t Width_; + const size_t OutputWidth_; + const size_t WideFieldsIndex_; + const std::vector<TKeyParams> Keys_; + const size_t MaxBlockLen_; + const std::vector<TAggParams<TAggregator>> AggsParams_; + const ui32 KeyLength_; + const ui32 StreamIndex_; + const std::vector<std::vector<ui32>> Streams_; +}; + +template <typename TKey, typename TFixedAggState, bool UseSet, bool UseFilter> +class TBlockCombineHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>> { +public: + using TSelf = TBlockCombineHashedWrapper<TKey, TFixedAggState, UseSet, UseFilter>; + using TBase = THashedWrapperBase<TKey, IBlockAggregatorCombineKeys, TFixedAggState, UseSet, UseFilter, false, false, TSelf>; + + TBlockCombineHashedWrapper(TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional<ui32> filterColumn, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) + : TBase(mutables, flow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams), 0, {}) + {} +}; + +template <typename TKey, typename TFixedAggState, bool UseSet> +class TBlockMergeFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>> { +public: + using TSelf = TBlockMergeFinalizeHashedWrapper<TKey, TFixedAggState, UseSet>; + using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, UseSet, false, true, false, TSelf>; + + TBlockMergeFinalizeHashedWrapper(TComputationMutables& mutables, + IComputationWideFlowNode* flow, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) + : TBase(mutables, flow, {}, width, keys, maxBlockLen, keyLength, std::move(aggsParams), 0, {}) + {} +}; + +template <typename TKey, typename TFixedAggState> +class TBlockMergeManyFinalizeHashedWrapper : public THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState>> { +public: + using TSelf = TBlockMergeManyFinalizeHashedWrapper<TKey, TFixedAggState>; + using TBase = THashedWrapperBase<TKey, IBlockAggregatorFinalizeKeys, TFixedAggState, false, false, true, true, TSelf>; + + TBlockMergeManyFinalizeHashedWrapper(TComputationMutables& mutables, + IComputationWideFlowNode* flow, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams, + ui32 streamIndex, std::vector<std::vector<ui32>>&& streams) + : TBase(mutables, flow, {}, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)) + {} +}; + +template <typename TAggregator> +std::unique_ptr<IPreparedBlockAggregator<TAggregator>> PrepareBlockAggregator(const IBlockAggregatorFactory& factory, + TTupleType* tupleType, + std::optional<ui32> filterColumn, + const std::vector<ui32>& argsColumns, + const TTypeEnvironment& env, + TType* returnType, + ui32 hint); + +template <> +std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineAll>> PrepareBlockAggregator<IBlockAggregatorCombineAll>(const IBlockAggregatorFactory& factory, + TTupleType* tupleType, + std::optional<ui32> filterColumn, + const std::vector<ui32>& argsColumns, + const TTypeEnvironment& env, + TType* returnType, + ui32 hint) { + Y_UNUSED(hint); + MKQL_ENSURE(!returnType, "Unexpected return type"); + return factory.PrepareCombineAll(tupleType, filterColumn, argsColumns, env); +} + +template <> +std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineKeys>> PrepareBlockAggregator<IBlockAggregatorCombineKeys>(const IBlockAggregatorFactory& factory, + TTupleType* tupleType, + std::optional<ui32> filterColumn, + const std::vector<ui32>& argsColumns, + const TTypeEnvironment& env, + TType* returnType, + ui32 hint) { + Y_UNUSED(hint); + MKQL_ENSURE(!filterColumn, "Unexpected filter column"); + MKQL_ENSURE(!returnType, "Unexpected return type"); + return factory.PrepareCombineKeys(tupleType, argsColumns, env); +} + +template <> +std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorFinalizeKeys>> PrepareBlockAggregator<IBlockAggregatorFinalizeKeys>(const IBlockAggregatorFactory& factory, + TTupleType* tupleType, + std::optional<ui32> filterColumn, + const std::vector<ui32>& argsColumns, + const TTypeEnvironment& env, + TType* returnType, + ui32 hint) { + MKQL_ENSURE(!filterColumn, "Unexpected filter column"); + MKQL_ENSURE(returnType, "Missing return type"); + return factory.PrepareFinalizeKeys(tupleType, argsColumns, env, returnType, hint); +} + +template <typename TAggregator> +ui32 FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, std::vector<TAggParams<TAggregator>>& aggsParams, + const TTypeEnvironment& env, bool overState, bool many, TArrayRef<TType* const> returnTypes, ui32 keysCount) { + TTupleType* unwrappedTupleType = tupleType; + if (many) { + std::vector<TType*> unwrappedTypes(tupleType->GetElementsCount()); + for (ui32 i = 0; i < tupleType->GetElementsCount(); ++i) { + unwrappedTypes[i] = tupleType->GetElementType(i); + } + + for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) { + auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i)); + MKQL_ENSURE(aggVal->GetValuesCount() == 2, "Expected only one column"); + auto index = AS_VALUE(TDataLiteral, aggVal->GetValue(1))->AsValue().Get<ui32>(); + MKQL_ENSURE(index < unwrappedTypes.size(), "Bad state column index"); + auto blockType = AS_TYPE(TBlockType, unwrappedTypes[index]); + MKQL_ENSURE(blockType->GetShape() == TBlockType::EShape::Many, "State must be a block"); + bool isOptional; + auto unpacked = UnpackOptional(blockType->GetItemType(), isOptional); + MKQL_ENSURE(isOptional, "State must be optional"); + unwrappedTypes[index] = TBlockType::Create(unpacked, TBlockType::EShape::Many, env); + } + + unwrappedTupleType = TTupleType::Create(unwrappedTypes.size(), unwrappedTypes.data(), env); + } + + ui32 totalStateSize = 0; + for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) { + auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i)); + auto name = AS_VALUE(TDataLiteral, aggVal->GetValue(0))->AsValue().AsStringRef(); + + std::vector<ui32> argColumns; + for (ui32 j = 1; j < aggVal->GetValuesCount(); ++j) { + argColumns.push_back(AS_VALUE(TDataLiteral, aggVal->GetValue(j))->AsValue().Get<ui32>()); + } + + TAggParams<TAggregator> p; + if (overState) { + MKQL_ENSURE(argColumns.size() == 1, "Expected exactly one column"); + p.Column_ = argColumns[0]; + p.StateType_ = AS_TYPE(TBlockType, tupleType->GetElementType(p.Column_))->GetItemType(); + p.ReturnType_ = returnTypes[i + keysCount]; + TStringBuf left, right; + if (TStringBuf(name).TrySplit('#', left, right)) { + p.Hint_ = FromString<ui32>(right); + } + } + + p.Prepared_ = PrepareBlockAggregator<TAggregator>(GetBlockAggregatorFactory(name), unwrappedTupleType, filterColumn, argColumns, env, p.ReturnType_, p.Hint_); + + totalStateSize += p.Prepared_->StateSize; + aggsParams.emplace_back(std::move(p)); + } + + return totalStateSize; +} + +template <bool UseSet, bool UseFilter, typename TKey> +IComputationNode* MakeBlockCombineHashedWrapper( + ui32 keyLength, + ui32 totalStateSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional<ui32> filterColumn, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) { + if (totalStateSize <= sizeof(TState8)) { + return new TBlockCombineHashedWrapper<TKey, TState8, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + } + + if (totalStateSize <= sizeof(TState16)) { + return new TBlockCombineHashedWrapper<TKey, TState16, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + } + + return new TBlockCombineHashedWrapper<TKey, TStateArena, UseSet, UseFilter>(mutables, flow, filterColumn, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); +} + +template <bool UseSet, bool UseFilter> +IComputationNode* MakeBlockCombineHashedWrapper( + TMaybe<ui32> totalKeysSize, + bool isFixed, + ui32 totalStateSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional<ui32> filterColumn, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + std::vector<TAggParams<IBlockAggregatorCombineKeys>>&& aggsParams) { + if (totalKeysSize && *totalKeysSize <= sizeof(ui32)) { + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui32>(*totalKeysSize, totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); + } + + if (totalKeysSize && *totalKeysSize <= sizeof(ui64)) { + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, ui64>(*totalKeysSize, totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); + } + + if (totalKeysSize && *totalKeysSize <= sizeof(TKey16)) { + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TKey16>(*totalKeysSize, totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); + } + + if (totalKeysSize && isFixed) { + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TExternalFixedSizeKey>(*totalKeysSize, totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); + } + + return MakeBlockCombineHashedWrapper<UseSet, UseFilter, TSSOKey>(Max<ui32>(), totalStateSize, mutables, flow, filterColumn, width, keys, maxBlockLen, std::move(aggsParams)); +} + +template <typename TKey, bool UseSet> +IComputationNode* MakeBlockMergeFinalizeHashedWrapper( + ui32 keyLength, + ui32 totalStateSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) { + + if (totalStateSize <= sizeof(TState8)) { + return new TBlockMergeFinalizeHashedWrapper<TKey, TState8, UseSet>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + } + + if (totalStateSize <= sizeof(TState16)) { + return new TBlockMergeFinalizeHashedWrapper<TKey, TState16, UseSet>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); + } + + return new TBlockMergeFinalizeHashedWrapper<TKey, TStateArena, UseSet>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams)); +} + +template <bool UseSet> +IComputationNode* MakeBlockMergeFinalizeHashedWrapper( + TMaybe<ui32> totalKeysSize, + bool isFixed, + ui32 totalStateSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams) { + if (totalKeysSize && *totalKeysSize <= sizeof(ui32)) { + return MakeBlockMergeFinalizeHashedWrapper<ui32, UseSet>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); + } + + if (totalKeysSize && *totalKeysSize <= sizeof(ui64)) { + return MakeBlockMergeFinalizeHashedWrapper<ui64, UseSet>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); + } + + if (totalKeysSize && *totalKeysSize <= sizeof(TKey16)) { + return MakeBlockMergeFinalizeHashedWrapper<TKey16, UseSet>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); + } + + if (totalKeysSize && isFixed) { + return MakeBlockMergeFinalizeHashedWrapper<TExternalFixedSizeKey, UseSet>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); + } + + return MakeBlockMergeFinalizeHashedWrapper<TSSOKey, UseSet>(Max<ui32>(), totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams)); +} + +template <typename TKey> +IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper( + ui32 keyLength, + ui32 totalStateSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams, + ui32 streamIndex, + std::vector<std::vector<ui32>>&& streams) { + + if (totalStateSize <= sizeof(TState8)) { + return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState8>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)); + } + + if (totalStateSize <= sizeof(TState16)) { + return new TBlockMergeManyFinalizeHashedWrapper<TKey, TState16>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)); + } + + return new TBlockMergeManyFinalizeHashedWrapper<TKey, TStateArena>(mutables, flow, width, keys, maxBlockLen, keyLength, std::move(aggsParams), streamIndex, std::move(streams)); +} + +IComputationNode* MakeBlockMergeManyFinalizeHashedWrapper( + TMaybe<ui32> totalKeysSize, + bool isFixed, + ui32 totalStateSize, + TComputationMutables& mutables, + IComputationWideFlowNode* flow, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>>&& aggsParams, + ui32 streamIndex, + std::vector<std::vector<ui32>>&& streams) { + if (totalKeysSize && *totalKeysSize <= sizeof(ui32)) { + return MakeBlockMergeManyFinalizeHashedWrapper<ui32>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + } + + if (totalKeysSize && *totalKeysSize <= sizeof(ui64)) { + return MakeBlockMergeManyFinalizeHashedWrapper<ui64>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + } + + if (totalKeysSize && *totalKeysSize <= sizeof(TKey16)) { + return MakeBlockMergeManyFinalizeHashedWrapper<TKey16>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + } + + if (totalKeysSize && isFixed) { + return MakeBlockMergeManyFinalizeHashedWrapper<TExternalFixedSizeKey>(*totalKeysSize, totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + } + + return MakeBlockMergeManyFinalizeHashedWrapper<TSSOKey>(Max<ui32>(), totalStateSize, mutables, flow, width, keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); +} + +void PrepareKeys(const std::vector<TKeyParams>& keys, TMaybe<ui32>& totalKeysSize, bool& isFixed) { + NYql::NUdf::TBlockItemSerializeProps props; + for (auto& param : keys) { + auto type = AS_TYPE(TBlockType, param.Type); + UpdateBlockItemSerializeProps(TTypeInfoHelper(), type->GetItemType(), props); + } + + isFixed = props.IsFixed; + totalKeysSize = props.MaxSize; +} + +void FillAggStreams(TRuntimeNode streamsNode, std::vector<std::vector<ui32>>& streams) { + auto streamsVal = AS_VALUE(TTupleLiteral, streamsNode); + for (ui32 i = 0; i < streamsVal->GetValuesCount(); ++i) { + streams.emplace_back(); + auto& stream = streams.back(); + auto streamVal = AS_VALUE(TTupleLiteral, streamsVal->GetValue(i)); + for (ui32 j = 0; j < streamVal->GetValuesCount(); ++j) { + ui32 index = AS_VALUE(TDataLiteral, streamVal->GetValue(j))->AsValue().Get<ui32>(); + stream.emplace_back(index); + } + } +} + +} + +IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args"); + const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); + const auto wideComponents = GetWideComponents(flowType); + const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env); + const auto returnFlowType = AS_TYPE(TFlowType, callable.GetType()->GetReturnType()); + const auto returnWideComponents = GetWideComponents(returnFlowType); + + auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + + auto filterColumnVal = AS_VALUE(TOptionalLiteral, callable.GetInput(1)); + std::optional<ui32> filterColumn; + if (filterColumnVal->HasItem()) { + filterColumn = AS_VALUE(TDataLiteral, filterColumnVal->GetItem())->AsValue().Get<ui32>(); + } + + auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2)); + std::vector<TAggParams<IBlockAggregatorCombineAll>> aggsParams; + FillAggParams<IBlockAggregatorCombineAll>(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env, false, false, returnWideComponents, 0); + return new TBlockCombineAllWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams)); +} + +IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args"); + const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); + const auto wideComponents = GetWideComponents(flowType); + const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env); + const auto returnFlowType = AS_TYPE(TFlowType, callable.GetType()->GetReturnType()); + const auto returnWideComponents = GetWideComponents(returnFlowType); + + auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + + auto filterColumnVal = AS_VALUE(TOptionalLiteral, callable.GetInput(1)); + std::optional<ui32> filterColumn; + if (filterColumnVal->HasItem()) { + filterColumn = AS_VALUE(TDataLiteral, filterColumnVal->GetItem())->AsValue().Get<ui32>(); + } + + auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(2)); + std::vector<TKeyParams> keys; + for (ui32 i = 0; i < keysVal->GetValuesCount(); ++i) { + ui32 index = AS_VALUE(TDataLiteral, keysVal->GetValue(i))->AsValue().Get<ui32>(); + keys.emplace_back(TKeyParams{ index, tupleType->GetElementType(index) }); + } + + auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(3)); + std::vector<TAggParams<IBlockAggregatorCombineKeys>> aggsParams; + ui32 totalStateSize = FillAggParams<IBlockAggregatorCombineKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, false, false, returnWideComponents, keys.size()); + + TMaybe<ui32> totalKeysSize; + bool isFixed = false; + PrepareKeys(keys, totalKeysSize, isFixed); + + const size_t maxBlockLen = CalcMaxBlockLenForOutput(callable.GetType()->GetReturnType()); + if (filterColumn) { + if (aggsParams.empty()) { + return MakeBlockCombineHashedWrapper<true, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockCombineHashedWrapper<false, true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } + } else { + if (aggsParams.empty()) { + return MakeBlockCombineHashedWrapper<true, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockCombineHashedWrapper<false, false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } + } +} + +IComputationNode* WrapBlockMergeFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args"); + const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); + const auto wideComponents = GetWideComponents(flowType); + const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env); + const auto returnFlowType = AS_TYPE(TFlowType, callable.GetType()->GetReturnType()); + const auto returnWideComponents = GetWideComponents(returnFlowType); + + auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + + auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(1)); + std::vector<TKeyParams> keys; + for (ui32 i = 0; i < keysVal->GetValuesCount(); ++i) { + ui32 index = AS_VALUE(TDataLiteral, keysVal->GetValue(i))->AsValue().Get<ui32>(); + keys.emplace_back(TKeyParams{ index, tupleType->GetElementType(index) }); + } + + auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2)); + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>> aggsParams; + ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, true, false, returnWideComponents, keys.size()); + + TMaybe<ui32> totalKeysSize; + bool isFixed = false; + PrepareKeys(keys, totalKeysSize, isFixed); + + const size_t maxBlockLen = CalcMaxBlockLenForOutput(callable.GetType()->GetReturnType()); + if (aggsParams.empty()) { + return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockMergeFinalizeHashedWrapper<false>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), keys, maxBlockLen, std::move(aggsParams)); + } +} + +IComputationNode* WrapBlockMergeManyFinalizeHashed(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 args"); + const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); + const auto wideComponents = GetWideComponents(flowType); + const auto tupleType = TTupleType::Create(wideComponents.size(), wideComponents.data(), ctx.Env); + const auto returnFlowType = AS_TYPE(TFlowType, callable.GetType()->GetReturnType()); + const auto returnWideComponents = GetWideComponents(returnFlowType); + + const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + + auto keysVal = AS_VALUE(TTupleLiteral, callable.GetInput(1)); + std::vector<TKeyParams> keys; + for (ui32 i = 0; i < keysVal->GetValuesCount(); ++i) { + ui32 index = AS_VALUE(TDataLiteral, keysVal->GetValue(i))->AsValue().Get<ui32>(); + keys.emplace_back(TKeyParams{ index, tupleType->GetElementType(index) }); + } + + const auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2)); + std::vector<TAggParams<IBlockAggregatorFinalizeKeys>> aggsParams; + ui32 totalStateSize = FillAggParams<IBlockAggregatorFinalizeKeys>(aggsVal, tupleType, {}, aggsParams, ctx.Env, true, true, returnWideComponents, keys.size()); + + TMaybe<ui32> totalKeysSize; + bool isFixed = false; + PrepareKeys(keys, totalKeysSize, isFixed); + + const ui32 streamIndex = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui32>(); + std::vector<std::vector<ui32>> streams; + FillAggStreams(callable.GetInput(4), streams); + totalStateSize += streams.size(); + + const size_t maxBlockLen = CalcMaxBlockLenForOutput(callable.GetType()->GetReturnType()); + if (aggsParams.empty()) { + return MakeBlockMergeFinalizeHashedWrapper<true>(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), + keys, maxBlockLen, std::move(aggsParams)); + } else { + return MakeBlockMergeManyFinalizeHashedWrapper(totalKeysSize, isFixed, totalStateSize, ctx.Mutables, wideFlow, tupleType->GetElementsCount(), + keys, maxBlockLen, std::move(aggsParams), streamIndex, std::move(streams)); + } +} + +} +} |