diff options
author | aneporada <[email protected]> | 2023-02-28 18:41:59 +0300 |
---|---|---|
committer | aneporada <[email protected]> | 2023-02-28 18:41:59 +0300 |
commit | 77d7dac32d298e05839da6b4a860259cf37a01d8 (patch) | |
tree | 4d65d8d0f3579c6fa4c9728738223447b486a721 | |
parent | 6bfb9762bfe45c454b90c6807f43a1645da09b57 (diff) |
Switched min/max to IArrayBuilder, simplify code and reduce copy-paste
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp | 507 |
1 files changed, 161 insertions, 346 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp index 87b207e4277..f91f486b631 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp @@ -50,14 +50,14 @@ inline void UpdateMinMax(TMaybe<T>& state, bool& stateUpdated, T value) { template<typename TTag, typename TString, bool IsMin> class TMinMaxBlockStringAggregator; -template <typename TTag, typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TMinMaxBlockAggregatorNullableOrScalar; +template <typename TTag, bool IsNullable, bool IsScalar, typename TIn, bool IsMin> +class TMinMaxBlockFixedAggregator; -template <typename TTag, typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TMinMaxBlockAggregator; +template <bool IsNullable, typename TIn, bool IsMin> +struct TState; template <typename TIn, bool IsMin> -struct TState { +struct TState<true, TIn, IsMin> { TIn Value_; ui8 IsValid_ = 0; @@ -71,10 +71,10 @@ struct TState { }; template <typename TIn, bool IsMin> -struct TSimpleState { +struct TState<false, TIn, IsMin> { TIn Value_; - TSimpleState() { + TState() { if constexpr (IsMin) { Value_ = std::numeric_limits<TIn>::max(); } else { @@ -85,55 +85,30 @@ struct TSimpleState { using TGenericState = NUdf::TUnboxedValuePod; -template <typename TIn, bool IsMin, typename TBuilder> +template <bool IsNullable, typename TIn, bool IsMin> class TColumnBuilder : public IAggColumnBuilder { + using TBuilder = typename NYql::NUdf::TFixedSizeArrayBuilder<TIn, IsNullable>; + using TStateType = TState<IsNullable, TIn, IsMin>; public: - TColumnBuilder(ui64 size, const std::shared_ptr<arrow::DataType>& dataType, TComputationContext& ctx) - : Builder_(dataType, &ctx.ArrowMemoryPool) + TColumnBuilder(ui64 size, TType* type, TComputationContext& ctx) + : Builder_(TTypeInfoHelper(), type, ctx.ArrowMemoryPool, size) , Ctx_(ctx) { - ARROW_OK(Builder_.Reserve(size)); } void Add(const void* state) final { - auto typedState = static_cast<const TState<TIn, IsMin>*>(state); - if (typedState->IsValid_) { - Builder_.UnsafeAppend(typedState->Value_); - } else { - Builder_.UnsafeAppendNull(); + auto typedState = static_cast<const TStateType*>(state); + if constexpr (IsNullable) { + if (!typedState->IsValid_) { + Builder_.Add(TBlockItem()); + return; + } } + Builder_.Add(TBlockItem(typedState->Value_)); } NUdf::TUnboxedValue Build() final { - std::shared_ptr<arrow::ArrayData> result; - ARROW_OK(Builder_.FinishInternal(&result)); - return Ctx_.HolderFactory.CreateArrowBlock(result); - } - -private: - TBuilder Builder_; - TComputationContext& Ctx_; -}; - -template <typename TIn, bool IsMin, typename TBuilder> -class TSimpleColumnBuilder : public IAggColumnBuilder { -public: - TSimpleColumnBuilder(ui64 size, const std::shared_ptr<arrow::DataType>& dataType, TComputationContext& ctx) - : Builder_(dataType, &ctx.ArrowMemoryPool) - , Ctx_(ctx) - { - ARROW_OK(Builder_.Reserve(size)); - } - - void Add(const void* state) final { - auto typedState = static_cast<const TSimpleState<TIn, IsMin>*>(state); - Builder_.UnsafeAppend(typedState->Value_); - } - - NUdf::TUnboxedValue Build() final { - std::shared_ptr<arrow::ArrayData> result; - ARROW_OK(Builder_.FinishInternal(&result)); - return Ctx_.HolderFactory.CreateArrowBlock(result); + return Ctx_.HolderFactory.CreateArrowBlock(Builder_.Build(true)); } private: @@ -390,54 +365,60 @@ private: TType* const Type_; }; -template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TMinMaxBlockAggregatorNullableOrScalar<TCombineAllTag, TIn, TInScalar, TBuilder, IsMin> : public TCombineAllTag::TBase { +template <bool IsNullable, bool IsScalar, typename TIn, bool IsMin> +class TMinMaxBlockFixedAggregator<TCombineAllTag, IsNullable, IsScalar, TIn, IsMin> : public TCombineAllTag::TBase { public: using TBase = TCombineAllTag::TBase; + using TStateType = TState<IsNullable, TIn, IsMin>; + using TInScalar = typename TPrimitiveDataType<TIn>::TScalarResult; - TMinMaxBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn, - const std::shared_ptr<arrow::DataType>& builderDataType, TComputationContext& ctx) - : TBase(sizeof(TState<TIn, IsMin>), filterColumn, ctx) + TMinMaxBlockFixedAggregator(TType* type, std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx) + : TBase(sizeof(TStateType), filterColumn, ctx) , ArgColumn_(argColumn) { - Y_UNUSED(builderDataType); + Y_UNUSED(type); } void InitState(void* state) final { - new(state) TState<TIn, IsMin>(); + new(state) TStateType(); } void DestroyState(void* state) noexcept final { - static_assert(std::is_trivially_destructible<TState<TIn, IsMin>>::value); + static_assert(std::is_trivially_destructible<TStateType>::value); Y_UNUSED(state); } void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { - auto typedState = static_cast<TState<TIn, IsMin>*>(state); + auto typedState = static_cast<TStateType*>(state); Y_UNUSED(batchLength); const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); - if (datum.is_scalar()) { - if (datum.scalar()->is_valid) { + if constexpr (IsScalar) { + Y_ENSURE(datum.is_scalar()); + if constexpr (IsNullable) { + if (datum.scalar()->is_valid) { + typedState->Value_ = datum.scalar_as<TInScalar>().value; + typedState->IsValid_ = 1; + } + } else { typedState->Value_ = datum.scalar_as<TInScalar>().value; - typedState->IsValid_ = 1; } } else { const auto& array = datum.array(); auto ptr = array->GetValues<TIn>(1); auto len = array->length; - auto count = len - array->GetNullCount(); + auto nullCount = IsNullable ? array->GetNullCount() : 0; + auto count = len - nullCount; if (!count) { return; } if (!filtered) { - typedState->IsValid_ = 1; TIn value = typedState->Value_; - if (array->GetNullCount() == 0) { - for (int64_t i = 0; i < len; ++i) { - value = UpdateMinMax<IsMin>(value, ptr[i]); - } - } else { + if constexpr (IsNullable) { + typedState->IsValid_ = 1; + } + + if (IsNullable && nullCount != 0) { auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0); for (int64_t i = 0; i < len; ++i) { ui64 fullIndex = i + array->offset; @@ -445,6 +426,10 @@ public: TIn mask = -TIn((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1); value = UpdateMinMax<IsMin>(value, TIn((ptr[i] & mask) | (value & ~mask))); } + } else { + for (int64_t i = 0; i < len; ++i) { + value = UpdateMinMax<IsMin>(value, ptr[i]); + } } typedState->Value_ = value; @@ -455,14 +440,8 @@ public: const ui8* filterBitmap = filterArray->template GetValues<uint8_t>(1); TIn value = typedState->Value_; - if (array->GetNullCount() == 0) { - typedState->IsValid_ = 1; - for (int64_t i = 0; i < len; ++i) { - TIn filterMask = -TIn(filterBitmap[i]); - value = UpdateMinMax<IsMin>(value, TIn((ptr[i] & filterMask) | (value & ~filterMask))); - } - } else { - ui64 count = 0; + ui64 validCount = 0; + if (IsNullable && nullCount != 0) { auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0); for (int64_t i = 0; i < len; ++i) { ui64 fullIndex = i + array->offset; @@ -471,21 +450,30 @@ public: TIn filterMask = -TIn(filterBitmap[i]); mask &= filterMask; value = UpdateMinMax<IsMin>(value, TIn((ptr[i] & mask) | (value & ~mask))); - count += mask & 1; + validCount += mask & 1; + } + } else { + for (int64_t i = 0; i < len; ++i) { + TIn filterMask = -TIn(filterBitmap[i]); + validCount += filterBitmap[i]; + value = UpdateMinMax<IsMin>(value, TIn((ptr[i] & filterMask) | (value & ~filterMask))); } - - typedState->IsValid_ |= count ? 1 : 0; } + if constexpr (IsNullable) { + typedState->IsValid_ |= validCount ? 1 : 0; + } typedState->Value_ = value; } } } NUdf::TUnboxedValue FinishOne(const void* state) final { - auto typedState = static_cast<const TState<TIn, IsMin>*>(state); - if (!typedState->IsValid_) { - return NUdf::TUnboxedValuePod(); + auto typedState = static_cast<const TStateType*>(state); + if constexpr (IsNullable) { + if (!typedState->IsValid_) { + return NUdf::TUnboxedValuePod(); + } } return NUdf::TUnboxedValuePod(typedState->Value_); @@ -495,251 +483,115 @@ private: const ui32 ArgColumn_; }; -template <typename TIn, typename TInScalar, bool IsMin> -void PushValueToState(TState<TIn, IsMin>* typedState, const arrow::Datum& datum, ui64 row) { - if (datum.is_scalar()) { - if (datum.scalar()->is_valid) { +template <bool IsNullable, bool IsScalar, typename TIn, bool IsMin> +static void PushValueToState(TState<IsNullable, TIn, IsMin>* typedState, const arrow::Datum& datum, ui64 row) { + using TInScalar = typename TPrimitiveDataType<TIn>::TScalarResult; + if constexpr (IsScalar) { + Y_ENSURE(datum.is_scalar()); + if constexpr (IsNullable) { + if (datum.scalar()->is_valid) { + typedState->Value_ = datum.scalar_as<TInScalar>().value; + typedState->IsValid_ = 1; + } + } else { typedState->Value_ = datum.scalar_as<TInScalar>().value; - typedState->IsValid_ = 1; } } else { - const auto& array = datum.array(); - auto ptr = array->GetValues<TIn>(1); - if (array->GetNullCount() == 0) { - typedState->IsValid_ = 1; - typedState->Value_ = UpdateMinMax<IsMin>(typedState->Value_, ptr[row]); - } else { - auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0); - ui64 fullIndex = row + array->offset; - // bit 1 -> mask 0xFF..FF, bit 0 -> mask 0x00..00 - TIn mask = -TIn((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1); - typedState->Value_ = UpdateMinMax<IsMin>(typedState->Value_, TIn((ptr[row] & mask) | (typedState->Value_ & ~mask))); - typedState->IsValid_ |= mask & 1; - } - } -} - -template <typename TIn, bool IsMin> -void PushValueToSimpleState(TSimpleState<TIn, IsMin>* typedState, const arrow::Datum& datum, ui64 row) { - const auto& array = datum.array(); - auto ptr = array->GetValues<TIn>(1); - typedState->Value_ = UpdateMinMax<IsMin>(typedState->Value_, ptr[row]); -} - -template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TMinMaxBlockAggregatorNullableOrScalar<TCombineKeysTag, TIn, TInScalar, TBuilder, IsMin> : public TCombineKeysTag::TBase { -public: - using TBase = TCombineKeysTag::TBase; - - TMinMaxBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn, - const std::shared_ptr<arrow::DataType>& builderDataType, TComputationContext& ctx) - : TBase(sizeof(TState<TIn, IsMin>), filterColumn, ctx) - , ArgColumn_(argColumn) - , BuilderDataType_(builderDataType) - { - } - - void InitKey(void* state, const NUdf::TUnboxedValue* columns, ui64 row) final { - new(state) TState<TIn, IsMin>(); - UpdateKey(state, columns, row); - } - - void DestroyState(void* state) noexcept final { - static_assert(std::is_trivially_destructible<TState<TIn, IsMin>>::value); - Y_UNUSED(state); - } - - void UpdateKey(void* state, const NUdf::TUnboxedValue* columns, ui64 row) final { - auto typedState = static_cast<TState<TIn, IsMin>*>(state); - const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); - PushValueToState<TIn, TInScalar, IsMin>(typedState, datum, row); - } - - std::unique_ptr<IAggColumnBuilder> MakeStateBuilder(ui64 size) final { - return std::make_unique<TColumnBuilder<TIn, IsMin, TBuilder>>(size, BuilderDataType_, Ctx_); - } - -private: - const ui32 ArgColumn_; - const std::shared_ptr<arrow::DataType> BuilderDataType_; -}; - -template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TMinMaxBlockAggregatorNullableOrScalar<TFinalizeKeysTag, TIn, TInScalar, TBuilder, IsMin> : public TFinalizeKeysTag::TBase { -public: - using TBase = TFinalizeKeysTag::TBase; - - TMinMaxBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn, - const std::shared_ptr<arrow::DataType>& builderDataType, TComputationContext& ctx) - : TBase(sizeof(TState<TIn, IsMin>), filterColumn, ctx) - , ArgColumn_(argColumn) - , BuilderDataType_(builderDataType) - { - } - - void LoadState(void* state, const NUdf::TUnboxedValue* columns, ui64 row) final { - new(state) TState<TIn, IsMin>(); - UpdateState(state, columns, row); - } - - void DestroyState(void* state) noexcept final { - static_assert(std::is_trivially_destructible<TState<TIn, IsMin>>::value); - Y_UNUSED(state); - } - - void UpdateState(void* state, const NUdf::TUnboxedValue* columns, ui64 row) final { - auto typedState = static_cast<TState<TIn, IsMin>*>(state); - const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); - PushValueToState<TIn, TInScalar, IsMin>(typedState, datum, row); - } - - std::unique_ptr<IAggColumnBuilder> MakeResultBuilder(ui64 size) final { - return std::make_unique<TColumnBuilder<TIn, IsMin, TBuilder>>(size, BuilderDataType_, Ctx_); - } - -private: - const ui32 ArgColumn_; - const std::shared_ptr<arrow::DataType> BuilderDataType_; -}; - -template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TMinMaxBlockAggregator<TCombineAllTag, TIn, TInScalar, TBuilder, IsMin> : public TCombineAllTag::TBase { -public: - using TBase = TCombineAllTag::TBase; - - TMinMaxBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, - const std::shared_ptr<arrow::DataType>& builderDataType, TComputationContext& ctx) - : TBase(sizeof(TSimpleState<TIn, IsMin>), filterColumn, ctx) - , ArgColumn_(argColumn) - { - Y_UNUSED(builderDataType); - } - - void InitState(void* state) final { - new(state) TSimpleState<TIn, IsMin>; - } - - void DestroyState(void* state) noexcept final { - static_assert(std::is_trivially_destructible<TSimpleState<TIn, IsMin>>::value); - Y_UNUSED(state); - } - - void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final { - auto typedState = static_cast<TSimpleState<TIn, IsMin>*>(state); - Y_UNUSED(batchLength); - const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); - MKQL_ENSURE(datum.is_array(), "Expected array"); - const auto& array = datum.array(); + const auto &array = datum.array(); auto ptr = array->GetValues<TIn>(1); - auto len = array->length; - MKQL_ENSURE(array->GetNullCount() == 0, "Expected no nulls"); - MKQL_ENSURE(len > 0, "Expected at least one value"); - if (!filtered) { - TIn value = typedState->Value_; - for (int64_t i = 0; i < len; ++i) { - value = UpdateMinMax<IsMin>(value, ptr[i]); + if constexpr (IsNullable) { + if (array->GetNullCount() == 0) { + typedState->IsValid_ = 1; + typedState->Value_ = UpdateMinMax<IsMin>(typedState->Value_, ptr[row]); + } else { + auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0); + ui64 fullIndex = row + array->offset; + // bit 1 -> mask 0xFF..FF, bit 0 -> mask 0x00..00 + TIn mask = -TIn((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1); + typedState->Value_ = UpdateMinMax<IsMin>(typedState->Value_, TIn((ptr[row] & mask) | (typedState->Value_ & ~mask))); + typedState->IsValid_ |= mask & 1; } - - typedState->Value_ = value; } else { - const auto& filterDatum = TArrowBlock::From(columns[*FilterColumn_]).GetDatum(); - const auto& filterArray = filterDatum.array(); - MKQL_ENSURE(filterArray->GetNullCount() == 0, "Expected non-nullable bool column"); - const ui8* filterBitmap = filterArray->template GetValues<uint8_t>(1); - - TIn value = typedState->Value_; - for (int64_t i = 0; i < len; ++i) { - TIn filterMask = -TIn(filterBitmap[i]); - value = UpdateMinMax<IsMin>(value, TIn((ptr[i] & filterMask) | (value & ~filterMask))); - } - - typedState->Value_ = value; + typedState->Value_ = UpdateMinMax<IsMin>(typedState->Value_, ptr[row]); } } +} - NUdf::TUnboxedValue FinishOne(const void* state) final { - auto typedState = static_cast<const TSimpleState<TIn, IsMin>*>(state); - return NUdf::TUnboxedValuePod(typedState->Value_); - } - -private: - const ui32 ArgColumn_; -}; - -template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TMinMaxBlockAggregator<TCombineKeysTag, TIn, TInScalar, TBuilder, IsMin> : public TCombineKeysTag::TBase { +template <bool IsNullable, bool IsScalar, typename TIn, bool IsMin> +class TMinMaxBlockFixedAggregator<TCombineKeysTag, IsNullable, IsScalar, TIn, IsMin> : public TCombineKeysTag::TBase { public: using TBase = TCombineKeysTag::TBase; + using TStateType = TState<IsNullable, TIn, IsMin>; - TMinMaxBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, - const std::shared_ptr<arrow::DataType>& builderDataType, TComputationContext& ctx) - : TBase(sizeof(TSimpleState<TIn, IsMin>), filterColumn, ctx) + TMinMaxBlockFixedAggregator(TType* type, std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx) + : TBase(sizeof(TStateType), filterColumn, ctx) , ArgColumn_(argColumn) - , BuilderDataType_(builderDataType) + , Type_(type) { } void InitKey(void* state, const NUdf::TUnboxedValue* columns, ui64 row) final { - new(state) TSimpleState<TIn, IsMin>(); + new(state) TStateType(); UpdateKey(state, columns, row); } void DestroyState(void* state) noexcept final { - static_assert(std::is_trivially_destructible<TSimpleState<TIn, IsMin>>::value); + static_assert(std::is_trivially_destructible<TStateType>::value); Y_UNUSED(state); } void UpdateKey(void* state, const NUdf::TUnboxedValue* columns, ui64 row) final { - auto typedState = static_cast<TSimpleState<TIn, IsMin>*>(state); + auto typedState = static_cast<TStateType*>(state); const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); - PushValueToSimpleState<TIn, IsMin>(typedState, datum, row); + PushValueToState<IsNullable, IsScalar, TIn, IsMin>(typedState, datum, row); } std::unique_ptr<IAggColumnBuilder> MakeStateBuilder(ui64 size) final { - return std::make_unique<TSimpleColumnBuilder<TIn, IsMin, TBuilder>>(size, BuilderDataType_, Ctx_); + return std::make_unique<TColumnBuilder<IsNullable, TIn, IsMin>>(size, Type_, Ctx_); } private: const ui32 ArgColumn_; const std::shared_ptr<arrow::DataType> BuilderDataType_; + TType* const Type_; }; -template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TMinMaxBlockAggregator<TFinalizeKeysTag, TIn, TInScalar, TBuilder, IsMin> : public TFinalizeKeysTag::TBase { +template <bool IsNullable, bool IsScalar, typename TIn, bool IsMin> +class TMinMaxBlockFixedAggregator<TFinalizeKeysTag, IsNullable, IsScalar, TIn, IsMin> : public TFinalizeKeysTag::TBase { public: using TBase = TFinalizeKeysTag::TBase; + using TStateType = TState<IsNullable, TIn, IsMin>; - TMinMaxBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, - const std::shared_ptr<arrow::DataType>& builderDataType, TComputationContext& ctx) - : TBase(sizeof(TSimpleState<TIn, IsMin>), filterColumn, ctx) + TMinMaxBlockFixedAggregator(TType* type, std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx) + : TBase(sizeof(TStateType), filterColumn, ctx) , ArgColumn_(argColumn) - , BuilderDataType_(builderDataType) + , Type_(type) { } void LoadState(void* state, const NUdf::TUnboxedValue* columns, ui64 row) final { - new(state) TSimpleState<TIn, IsMin>(); + new(state) TStateType(); UpdateState(state, columns, row); } void DestroyState(void* state) noexcept final { - static_assert(std::is_trivially_destructible<TSimpleState<TIn, IsMin>>::value); + static_assert(std::is_trivially_destructible<TStateType>::value); Y_UNUSED(state); } void UpdateState(void* state, const NUdf::TUnboxedValue* columns, ui64 row) final { - auto typedState = static_cast<TSimpleState<TIn, IsMin>*>(state); + auto typedState = static_cast<TStateType*>(state); const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum(); - PushValueToSimpleState<TIn, IsMin>(typedState, datum, row); + PushValueToState<IsNullable, IsScalar, TIn, IsMin>(typedState, datum, row); } std::unique_ptr<IAggColumnBuilder> MakeResultBuilder(ui64 size) final { - return std::make_unique<TSimpleColumnBuilder<TIn, IsMin, TBuilder>>(size, BuilderDataType_, Ctx_); + return std::make_unique<TColumnBuilder<IsNullable, TIn, IsMin>>(size, Type_, Ctx_); } private: const ui32 ArgColumn_; - const std::shared_ptr<arrow::DataType> BuilderDataType_; + TType* const Type_; }; template<typename TTag, typename TStringType, bool IsMin> @@ -763,51 +615,42 @@ private: const ui32 ArgColumn_; }; -template <typename TTag, typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TPreparedMinMaxBlockAggregatorNullableOrScalar : public TTag::TPreparedAggregator { +template <typename TTag, bool IsNullable, bool IsScalar, typename TIn, bool IsMin> +class TPreparedMinMaxBlockFixedAggregator : public TTag::TPreparedAggregator { public: using TBase = typename TTag::TPreparedAggregator; + using TStateType = TState<IsNullable, TIn, IsMin>; - TPreparedMinMaxBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn, - const std::shared_ptr<arrow::DataType>& builderDataType) - : TBase(sizeof(TState<TIn, IsMin>)) + TPreparedMinMaxBlockFixedAggregator(TType* type, std::optional<ui32> filterColumn, ui32 argColumn) + : TBase(sizeof(TStateType)) + , Type_(type) , FilterColumn_(filterColumn) , ArgColumn_(argColumn) - , BuilderDataType_(builderDataType) {} std::unique_ptr<typename TTag::TAggregator> Make(TComputationContext& ctx) const final { - return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<TTag, TIn, TInScalar, TBuilder, IsMin>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx); + return std::make_unique<TMinMaxBlockFixedAggregator<TTag, IsNullable, IsScalar, TIn, IsMin>>(Type_, FilterColumn_, ArgColumn_, ctx); } private: + TType* const Type_; const std::optional<ui32> FilterColumn_; const ui32 ArgColumn_; - const std::shared_ptr<arrow::DataType> BuilderDataType_; }; -template <typename TTag, typename TIn, typename TInScalar, typename TBuilder, bool IsMin> -class TPreparedMinMaxBlockAggregator : public TTag::TPreparedAggregator { -public: - using TBase = typename TTag::TPreparedAggregator; - - TPreparedMinMaxBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, - const std::shared_ptr<arrow::DataType>& builderDataType) - : TBase(sizeof(TSimpleState<TIn, IsMin>)) - , FilterColumn_(filterColumn) - , ArgColumn_(argColumn) - , BuilderDataType_(builderDataType) - {} - - std::unique_ptr<typename TTag::TAggregator> Make(TComputationContext& ctx) const final { - return std::make_unique<TMinMaxBlockAggregator<TTag, TIn, TInScalar, TBuilder, IsMin>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx); +template<typename TTag, typename TIn, bool IsMin> +std::unique_ptr<typename TTag::TPreparedAggregator> PrepareMinMaxFixed(TType* type, bool isOptional, bool isScalar, std::optional<ui32> filterColumn, ui32 argColumn) { + if (isScalar) { + if (isOptional) { + return std::make_unique<TPreparedMinMaxBlockFixedAggregator<TTag, true, true, TIn, IsMin>>(type, filterColumn, argColumn); + } + return std::make_unique<TPreparedMinMaxBlockFixedAggregator<TTag, false, true, TIn, IsMin>>(type, filterColumn, argColumn); } - -private: - const std::optional<ui32> FilterColumn_; - const ui32 ArgColumn_; - const std::shared_ptr<arrow::DataType> BuilderDataType_; -}; + if (isOptional) { + return std::make_unique<TPreparedMinMaxBlockFixedAggregator<TTag, true, false, TIn, IsMin>>(type, filterColumn, argColumn); + } + return std::make_unique<TPreparedMinMaxBlockFixedAggregator<TTag, false, false, TIn, IsMin>>(type, filterColumn, argColumn); +} template <typename TTag, bool IsMin> std::unique_ptr<typename TTag::TPreparedAggregator> PrepareMinMax(TTupleType* tupleType, std::optional<ui32> filterColumn, ui32 argColumn) { @@ -823,59 +666,31 @@ std::unique_ptr<typename TTag::TPreparedAggregator> PrepareMinMax(TTupleType* tu using TStringType = NUdf::TUtf8; return std::make_unique<TPreparedMinMaxBlockStringAggregator<TTag, TStringType, IsMin>>(argType, filterColumn, argColumn); } - if (blockType->GetShape() == TBlockType::EShape::Scalar || isOptional) { - switch (slot) { - case NUdf::EDataSlot::Int8: - return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<TTag, i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argColumn, arrow::int8()); - case NUdf::EDataSlot::Bool: - case NUdf::EDataSlot::Uint8: - return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<TTag, ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argColumn, arrow::uint8()); - case NUdf::EDataSlot::Int16: - return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<TTag, i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argColumn, arrow::int16()); - case NUdf::EDataSlot::Uint16: - case NUdf::EDataSlot::Date: - return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<TTag, ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argColumn, arrow::uint16()); - case NUdf::EDataSlot::Int32: - return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<TTag, i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argColumn, arrow::int32()); - case NUdf::EDataSlot::Uint32: - case NUdf::EDataSlot::Datetime: - return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<TTag, ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argColumn, arrow::uint32()); - case NUdf::EDataSlot::Int64: - case NUdf::EDataSlot::Interval: - return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<TTag, i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argColumn, arrow::int64()); - case NUdf::EDataSlot::Uint64: - case NUdf::EDataSlot::Timestamp: - return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<TTag, ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argColumn, arrow::uint64()); - default: - throw yexception() << "Unsupported MIN/MAX input type"; - } - } - else { - switch (*dataType->GetDataSlot()) { - case NUdf::EDataSlot::Int8: - return std::make_unique<TPreparedMinMaxBlockAggregator<TTag, i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argColumn, arrow::int8()); - case NUdf::EDataSlot::Uint8: - case NUdf::EDataSlot::Bool: - return std::make_unique<TPreparedMinMaxBlockAggregator<TTag, ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argColumn, arrow::uint8()); - case NUdf::EDataSlot::Int16: - return std::make_unique<TPreparedMinMaxBlockAggregator<TTag, i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argColumn, arrow::int16()); - case NUdf::EDataSlot::Uint16: - case NUdf::EDataSlot::Date: - return std::make_unique<TPreparedMinMaxBlockAggregator<TTag, ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argColumn, arrow::uint16()); - case NUdf::EDataSlot::Int32: - return std::make_unique<TPreparedMinMaxBlockAggregator<TTag, i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argColumn, arrow::int32()); - case NUdf::EDataSlot::Uint32: - case NUdf::EDataSlot::Datetime: - return std::make_unique<TPreparedMinMaxBlockAggregator<TTag, ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argColumn, arrow::uint32()); - case NUdf::EDataSlot::Int64: - case NUdf::EDataSlot::Interval: - return std::make_unique<TPreparedMinMaxBlockAggregator<TTag, i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argColumn, arrow::int64()); - case NUdf::EDataSlot::Uint64: - case NUdf::EDataSlot::Timestamp: - return std::make_unique<TPreparedMinMaxBlockAggregator<TTag, ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argColumn, arrow::uint64()); - default: - throw yexception() << "Unsupported MIN/MAX input type"; - } + bool isScalar = blockType->GetShape() == TBlockType::EShape::Scalar; + switch (slot) { + case NUdf::EDataSlot::Int8: + return PrepareMinMaxFixed<TTag, i8, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn); + case NUdf::EDataSlot::Bool: + case NUdf::EDataSlot::Uint8: + return PrepareMinMaxFixed<TTag, ui8, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn); + case NUdf::EDataSlot::Int16: + return PrepareMinMaxFixed<TTag, i16, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn); + case NUdf::EDataSlot::Uint16: + case NUdf::EDataSlot::Date: + return PrepareMinMaxFixed<TTag, ui16, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn); + case NUdf::EDataSlot::Int32: + return PrepareMinMaxFixed<TTag, i32, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn); + case NUdf::EDataSlot::Uint32: + case NUdf::EDataSlot::Datetime: + return PrepareMinMaxFixed<TTag, ui32, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn); + case NUdf::EDataSlot::Int64: + case NUdf::EDataSlot::Interval: + return PrepareMinMaxFixed<TTag, i64, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn); + case NUdf::EDataSlot::Uint64: + case NUdf::EDataSlot::Timestamp: + return PrepareMinMaxFixed<TTag, ui64, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn); + default: + throw yexception() << "Unsupported MIN/MAX input type"; } } @@ -909,7 +724,7 @@ public: } }; -} +} // namespace std::unique_ptr<IBlockAggregatorFactory> MakeBlockMinFactory() { return std::make_unique<TBlockMinMaxFactory<true>>(); |