diff options
5 files changed, 192 insertions, 100 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp index e730a28c5d..a3418056e8 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp @@ -76,10 +76,13 @@ public: } s.HasValues_ = true; + char* ptr = s.AggStates_.data(); for (size_t i = 0; i < s.Aggs_.size(); ++i) { if (output[i]) { - s.Aggs_[i]->AddMany(s.Values_.data(), batchLength, filtered); + s.Aggs_[i]->AddMany(ptr, s.Values_.data(), batchLength, filtered); } + + ptr += s.Aggs_[i]->StateSize; } } else { s.IsFinished_ = true; @@ -87,10 +90,13 @@ public: return EFetchResult::Finish; } + char* ptr = s.AggStates_.data(); for (size_t i = 0; i < s.Aggs_.size(); ++i) { if (auto* out = output[i]; out != nullptr) { - *out = s.Aggs_[i]->Finish(); + *out = s.Aggs_[i]->FinishOne(ptr); } + + ptr += s.Aggs_[i]->StateSize; } return EFetchResult::One; @@ -107,6 +113,7 @@ private: TVector<std::unique_ptr<IBlockAggregator>> Aggs_; bool IsFinished_ = false; bool HasValues_ = false; + TVector<char> AggStates_; TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const TVector<TAggParams>& params, const THolderFactory& holderFactory) : TComputationValue(memInfo) @@ -117,8 +124,18 @@ private: ValuePointers_[i] = &Values_[i]; } + ui32 totalStateSize = 0; for (const auto& p : params) { Aggs_.emplace_back(MakeBlockAggregator(p.Name, p.TupleType, filterColumn, p.ArgColumns, holderFactory)); + + totalStateSize += Aggs_.back()->StateSize; + } + + AggStates_.resize(totalStateSize); + char* ptr = AggStates_.data(); + for (const auto& agg : Aggs_) { + agg->InitState(ptr); + ptr += agg->StateSize; } } }; diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp index 5afa75b3df..e99e140ed2 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp @@ -5,52 +5,68 @@ namespace NMiniKQL { class TCountAllBlockAggregator : public TBlockAggregatorBase { public: + struct TState { + ui64 Count_ = 0; + }; + TCountAllBlockAggregator(std::optional<ui32> filterColumn) - : TBlockAggregatorBase(filterColumn) + : TBlockAggregatorBase(sizeof(TState), filterColumn) { } - void AddMany(const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + void InitState(void* state) final { + new(state) TState(); + } + + void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + auto typedState = static_cast<TState*>(state); Y_UNUSED(columns); if (filtered) { - State_ += *filtered; + typedState->Count_ += *filtered; } else { - State_ += batchLength; + typedState->Count_ += batchLength; } } - NUdf::TUnboxedValue Finish() final { - return NUdf::TUnboxedValuePod(State_); + NUdf::TUnboxedValue FinishOne(const void* state) final { + auto typedState = static_cast<const TState*>(state); + return NUdf::TUnboxedValuePod(typedState->Count_); } - -private: - ui64 State_ = 0; }; class TCountBlockAggregator : public TBlockAggregatorBase { public: + struct TState { + ui64 Count_ = 0; + }; + TCountBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn) - : TBlockAggregatorBase(filterColumn) + : TBlockAggregatorBase(sizeof(TState), filterColumn) , ArgColumn_(argColumn) { } - void AddMany(const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + void InitState(void* state) final { + new(state) TState(); + } + + void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + auto typedState = static_cast<TState*>(state); const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); if (datum.is_scalar()) { if (datum.scalar()->is_valid) { - State_ += filtered ? *filtered : batchLength; + typedState->Count_ += filtered ? *filtered : batchLength; } } else { const auto& array = datum.array(); if (!filtered) { - State_ += array->length - array->GetNullCount(); + typedState->Count_ += array->length - array->GetNullCount(); } else if (array->GetNullCount() == array->length) { // all nulls return; } else if (array->GetNullCount() == 0) { // no nulls - State_ += *filtered; + typedState->Count_ += *filtered; } else { const auto& filterDatum = TArrowBlock::From(columns[*FilterColumn_]).GetDatum(); // intersect masks from nulls and filter column @@ -58,7 +74,7 @@ public: MKQL_ENSURE(filterArray->GetNullCount() == 0, "Expected non-nullable bool column"); auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0); auto filterBitmap = filterArray->GetValues<uint8_t>(1, 0); - auto state = State_; + auto state = typedState->Count_; for (ui32 i = 0; i < array->length; ++i) { ui64 fullIndex = i + array->offset; auto bit1 = ((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1); @@ -66,18 +82,18 @@ public: state += bit1 & bit2; } - State_ = state; + typedState->Count_ = state; } } } - NUdf::TUnboxedValue Finish() final { - return NUdf::TUnboxedValuePod(State_); + NUdf::TUnboxedValue FinishOne(const void* state) final { + auto typedState = static_cast<const TState*>(state); + return NUdf::TUnboxedValuePod(typedState->Count_); } private: const ui32 ArgColumn_; - ui64 State_ = 0; }; class TBlockCountAllFactory : public IBlockAggregatorFactory { diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h index fc3a44270f..0e13b27dc2 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h @@ -10,15 +10,24 @@ class IBlockAggregator { public: virtual ~IBlockAggregator() = default; - virtual void AddMany(const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) = 0; + virtual void InitState(void* state) = 0; - virtual NUdf::TUnboxedValue Finish() = 0; + virtual void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) = 0; + + virtual NUdf::TUnboxedValue FinishOne(const void* state) = 0; + + const ui32 StateSize; + + explicit IBlockAggregator(ui32 stateSize) + : StateSize(stateSize) + {} }; class TBlockAggregatorBase : public IBlockAggregator { public: - TBlockAggregatorBase(std::optional<ui32> filterColumn) - : FilterColumn_(filterColumn) + TBlockAggregatorBase(ui32 stateSize, std::optional<ui32> filterColumn) + : IBlockAggregator(stateSize) + , FilterColumn_(filterColumn) { } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp index 1fc14ae633..9407f90840 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp @@ -22,24 +22,37 @@ T UpdateMinMax(T x, T y) { template <typename TIn, typename TInScalar, bool IsMin> class TMinMaxBlockAggregatorNullableOrScalar : public TBlockAggregatorBase { public: + struct TState { + TIn Value_; + bool IsValid_ = false; + + TState() { + if constexpr (IsMin) { + Value_ = std::numeric_limits<TIn>::max(); + } else { + Value_ = std::numeric_limits<TIn>::min(); + } + } + }; + TMinMaxBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn) - : TBlockAggregatorBase(filterColumn) + : TBlockAggregatorBase(sizeof(TState), filterColumn) , ArgColumn_(argColumn) { - if constexpr (IsMin) { - Value_ = std::numeric_limits<TIn>::max(); - } else { - Value_ = std::numeric_limits<TIn>::min(); - } } - void AddMany(const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + void InitState(void* state) final { + new(state) TState(); + } + + void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + auto typedState = static_cast<TState*>(state); Y_UNUSED(batchLength); const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); if (datum.is_scalar()) { if (datum.scalar()->is_valid) { - Value_ = datum.scalar_as<TInScalar>().value; - IsValid_ = true; + typedState->Value_ = datum.scalar_as<TInScalar>().value; + typedState->IsValid_ = true; } } else { const auto& array = datum.array(); @@ -51,8 +64,8 @@ public: } if (!filtered) { - IsValid_ = true; - TIn value = Value_; + typedState->IsValid_ = true; + TIn value = typedState->Value_; if (array->GetNullCount() == 0) { for (int64_t i = 0; i < len; ++i) { value = UpdateMinMax<IsMin>(value, ptr[i]); @@ -67,16 +80,16 @@ public: } } - Value_ = value; + typedState->Value_ = value; } else { const auto& filterDatum = TArrowBlock::From(columns[*FilterColumn_]).GetDatum(); const auto& filterArray = filterDatum.array(); MKQL_ENSURE(filterArray->GetNullCount() == 0, "Expected non-nullable bool column"); const ui8* filterBitmap = filterArray->template GetValues<uint8_t>(1); - TIn value = Value_; + TIn value = typedState->Value_; if (array->GetNullCount() == 0) { - IsValid_ = true; + typedState->IsValid_ = true; for (int64_t i = 0; i < len; ++i) { TIn filterMask = (((*filterBitmap++) & 1) ^ 1) - TIn(1); value = UpdateMinMax<IsMin>(value, TIn((ptr[i] & filterMask) | (value & ~filterMask))); @@ -94,43 +107,53 @@ public: count += mask & 1; } - IsValid_ = IsValid_ || count > 0; + typedState->IsValid_ = typedState->IsValid_ || count > 0; } - Value_ = value; + typedState->Value_ = value; } } } - NUdf::TUnboxedValue Finish() final { - if (!IsValid_) { + NUdf::TUnboxedValue FinishOne(const void* state) final { + auto typedState = static_cast<const TState*>(state); + if (!typedState->IsValid_) { return NUdf::TUnboxedValuePod(); } - return NUdf::TUnboxedValuePod(Value_); + return NUdf::TUnboxedValuePod(typedState->Value_); } private: const ui32 ArgColumn_; - TIn Value_ = 0; - bool IsValid_ = false; }; template <typename TIn, typename TInScalar, bool IsMin> class TMinMaxBlockAggregator: public TBlockAggregatorBase { public: + struct TState { + TIn Value_; + TState() { + if constexpr (IsMin) { + Value_ = std::numeric_limits<TIn>::max(); + } else { + Value_ = std::numeric_limits<TIn>::min(); + } + } + }; + TMinMaxBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn) - : TBlockAggregatorBase(filterColumn) + : TBlockAggregatorBase(sizeof(TState), filterColumn) , ArgColumn_(argColumn) { - if constexpr (IsMin) { - Value_ = std::numeric_limits<TIn>::max(); - } else { - Value_ = std::numeric_limits<TIn>::min(); - } } - void AddMany(const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + void InitState(void* state) final { + new(state) TState; + } + + void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + auto typedState = static_cast<TState*>(state); Y_UNUSED(batchLength); const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); MKQL_ENSURE(datum.is_array(), "Expected array"); @@ -140,36 +163,36 @@ public: MKQL_ENSURE(array->GetNullCount() == 0, "Expected no nulls"); MKQL_ENSURE(len > 0, "Expected at least one value"); if (!filtered) { - TIn value = Value_; + TIn value = typedState->Value_; for (int64_t i = 0; i < len; ++i) { value = UpdateMinMax<IsMin>(value, ptr[i]); } - Value_ = value; + typedState->Value_ = value; } else { const auto& filterDatum = TArrowBlock::From(columns[*FilterColumn_]).GetDatum(); const auto& filterArray = filterDatum.array(); MKQL_ENSURE(filterArray->GetNullCount() == 0, "Expected non-nullable bool column"); const ui8* filterBitmap = filterArray->template GetValues<uint8_t>(1); - TIn value = Value_; + TIn value = typedState->Value_; for (int64_t i = 0; i < len; ++i) { ui64 fullIndex = i + array->offset; TIn filterMask = (((*filterBitmap++) & 1) ^ 1) - TIn(1); value = UpdateMinMax<IsMin>(value, TIn((ptr[i] & filterMask) | (value & ~filterMask))); } - Value_ = value; + typedState->Value_ = value; } } - NUdf::TUnboxedValue Finish() final { - return NUdf::TUnboxedValuePod(Value_); + NUdf::TUnboxedValue FinishOne(const void* state) final { + auto typedState = static_cast<const TState*>(state); + return NUdf::TUnboxedValuePod(typedState->Value_); } private: const ui32 ArgColumn_; - TIn Value_ = 0; }; template <bool IsMin> diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp index a4033fc71b..347ad0a26e 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp @@ -13,18 +13,28 @@ namespace NMiniKQL { template <typename TIn, typename TSum, typename TInScalar> class TSumBlockAggregatorNullableOrScalar : public TBlockAggregatorBase { public: + struct TState { + TSum Sum_ = 0; + bool IsValid_ = false; + }; + TSumBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn) - : TBlockAggregatorBase(filterColumn) + : TBlockAggregatorBase(sizeof(TState), filterColumn) , ArgColumn_(argColumn) { } - void AddMany(const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + void InitState(void* state) final { + new(state) TState(); + } + + void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + auto typedState = static_cast<TState*>(state); const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); if (datum.is_scalar()) { if (datum.scalar()->is_valid) { - Sum_ += (filtered ? *filtered : batchLength) * datum.scalar_as<TInScalar>().value; - IsValid_ = true; + typedState->Sum_ += (filtered ? *filtered : batchLength) * datum.scalar_as<TInScalar>().value; + typedState->IsValid_ = true; } } else { const auto& array = datum.array(); @@ -36,8 +46,8 @@ public: } if (!filtered) { - IsValid_ = true; - TSum sum = Sum_; + typedState->IsValid_ = true; + TSum sum = typedState->Sum_; if (array->GetNullCount() == 0) { for (int64_t i = 0; i < len; ++i) { sum += ptr[i]; @@ -52,15 +62,15 @@ public: } } - Sum_ = sum; + typedState->Sum_ = sum; } else { const auto& filterDatum = TArrowBlock::From(columns[*FilterColumn_]).GetDatum(); const auto& filterArray = filterDatum.array(); MKQL_ENSURE(filterArray->GetNullCount() == 0, "Expected non-nullable bool column"); auto filterBitmap = filterArray->template GetValues<uint8_t>(1, 0); - TSum sum = Sum_; + TSum sum = typedState->Sum_; if (array->GetNullCount() == 0) { - IsValid_ = true; + typedState->IsValid_ = true; for (int64_t i = 0; i < len; ++i) { ui64 fullIndex = i + array->offset; // bit 1 -> mask 0xFF..FF, bit 0 -> mask 0x00..00 @@ -80,38 +90,46 @@ public: count += mask & 1; } - IsValid_ = IsValid_ || count > 0; + typedState->IsValid_ = typedState->IsValid_ || count > 0; } - Sum_ = sum; + typedState->Sum_ = sum; } } } - NUdf::TUnboxedValue Finish() final { - if (!IsValid_) { + NUdf::TUnboxedValue FinishOne(const void* state) final { + auto typedState = static_cast<const TState*>(state); + if (!typedState->IsValid_) { return NUdf::TUnboxedValuePod(); } - return NUdf::TUnboxedValuePod(Sum_); + return NUdf::TUnboxedValuePod(typedState->Sum_); } private: const ui32 ArgColumn_; - TSum Sum_ = 0; - bool IsValid_ = false; }; template <typename TIn, typename TSum, typename TInScalar> class TSumBlockAggregator : public TBlockAggregatorBase { public: + struct TState { + TSum Sum_ = 0; + }; + TSumBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn) - : TBlockAggregatorBase(filterColumn) + : TBlockAggregatorBase(sizeof(TState), filterColumn) , ArgColumn_(argColumn) { } - void AddMany(const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + void InitState(void* state) final { + new(state) TState(); + } + + void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + auto typedState = static_cast<TState*>(state); Y_UNUSED(batchLength); const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); MKQL_ENSURE(datum.is_array(), "Expected array"); @@ -121,7 +139,7 @@ public: MKQL_ENSURE(array->GetNullCount() == 0, "Expected no nulls"); MKQL_ENSURE(len > 0, "Expected at least one value"); - TSum sum = Sum_; + TSum sum = typedState->Sum_; if (!filtered) { for (int64_t i = 0; i < len; ++i) { sum += ptr[i]; @@ -139,34 +157,44 @@ public: } } - Sum_ = sum; + typedState->Sum_ = sum; } - NUdf::TUnboxedValue Finish() final { - return NUdf::TUnboxedValuePod(Sum_); + NUdf::TUnboxedValue FinishOne(const void* state) final { + auto typedState = static_cast<const TState*>(state); + return NUdf::TUnboxedValuePod(typedState->Sum_); } private: const ui32 ArgColumn_; - TSum Sum_ = 0; }; template <typename TIn, typename TInScalar> class TAvgBlockAggregator : public TBlockAggregatorBase { public: + struct TState { + double Sum_ = 0; + ui64 Count_ = 0; + }; + TAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, const THolderFactory& holderFactory) - : TBlockAggregatorBase(filterColumn) + : TBlockAggregatorBase(sizeof(TState), filterColumn) , ArgColumn_(argColumn) , HolderFactory_(holderFactory) { } - void AddMany(const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + void InitState(void* state) final { + new(state) TState(); + } + + void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { + auto typedState = static_cast<TState*>(state); const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); if (datum.is_scalar()) { if (datum.scalar()->is_valid) { - Sum_ += double((filtered ? *filtered : batchLength) * datum.scalar_as<TInScalar>().value); - Count_ += batchLength; + typedState->Sum_ += double((filtered ? *filtered : batchLength) * datum.scalar_as<TInScalar>().value); + typedState->Count_ += batchLength; } } else { const auto& array = datum.array(); @@ -178,8 +206,8 @@ public: } if (!filtered) { - Count_ += count; - double sum = Sum_; + typedState->Count_ += count; + double sum = typedState->Sum_; if (array->GetNullCount() == 0) { for (int64_t i = 0; i < len; ++i) { sum += double(ptr[i]); @@ -194,15 +222,15 @@ public: } } - Sum_ = sum; + typedState->Sum_ = sum; } else { const auto& filterDatum = TArrowBlock::From(columns[*FilterColumn_]).GetDatum(); const auto& filterArray = filterDatum.array(); MKQL_ENSURE(filterArray->GetNullCount() == 0, "Expected non-nullable bool column"); auto filterBitmap = filterArray->template GetValues<uint8_t>(1, 0); - double sum = Sum_; - ui64 count = Count_; + double sum = typedState->Sum_; + ui64 count = typedState->Count_; if (array->GetNullCount() == 0) { for (int64_t i = 0; i < len; ++i) { ui64 fullIndex = i + array->offset; @@ -224,29 +252,28 @@ public: } } - Sum_ = sum; - Count_ = count; + typedState->Sum_ = sum; + typedState->Count_ = count; } } } - NUdf::TUnboxedValue Finish() final { - if (!Count_) { + NUdf::TUnboxedValue FinishOne(const void* state) final { + auto typedState = static_cast<const TState*>(state); + if (!typedState->Count_) { return NUdf::TUnboxedValuePod(); } NUdf::TUnboxedValue* items; auto arr = HolderFactory_.CreateDirectArrayHolder(2, items); - items[0] = NUdf::TUnboxedValuePod(Sum_); - items[1] = NUdf::TUnboxedValuePod(Count_); + items[0] = NUdf::TUnboxedValuePod(typedState->Sum_); + items[1] = NUdf::TUnboxedValuePod(typedState->Count_); return arr; } private: const ui32 ArgColumn_; const THolderFactory& HolderFactory_; - double Sum_ = 0; - ui64 Count_ = 0; }; class TBlockSumFactory : public IBlockAggregatorFactory { |