diff options
author | vvvv <vvvv@ydb.tech> | 2022-12-14 16:02:08 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-12-14 16:02:08 +0300 |
commit | 58e33e587065c6bac9600a43369d97363ed84184 (patch) | |
tree | 734e5f169e324915338d368f1d9e88c33be2dcd5 | |
parent | 7ab8d966fbb1720e24f2b00d111f19fa60ff0a40 (diff) | |
download | ydb-58e33e587065c6bac9600a43369d97363ed84184.tar.gz |
support of avg over GROUP BY with keys
6 files changed, 274 insertions, 88 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp index 5ffdb8ad4a9..ae8af5df98f 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp @@ -111,9 +111,7 @@ namespace NMiniKQL { namespace { struct TAggParams { - TStringBuf Name; - TTupleType* TupleType; - std::vector<ui32> ArgColumns; + std::unique_ptr<IPreparedBlockAggregator> Prepared_; }; struct TKeyParams { @@ -412,8 +410,7 @@ private: ui32 totalStateSize = 0; for (const auto& p : params) { - Aggs_.emplace_back(MakeBlockAggregator(p.Name, p.TupleType, filterColumn, p.ArgColumns, ctx)); - + Aggs_.emplace_back(p.Prepared_->Make(ctx)); totalStateSize += Aggs_.back()->StateSize; } @@ -716,8 +713,7 @@ private: } for (const auto& p : params) { - Aggs_.emplace_back(MakeBlockAggregator(p.Name, p.TupleType, filterColumn, p.ArgColumns, ctx)); - + Aggs_.emplace_back(p.Prepared_->Make(ctx)); TotalStateSize_ += Aggs_.back()->StateSize; } @@ -756,7 +752,7 @@ private: std::vector<std::unique_ptr<IKeySerializer>> KeySerializers_; }; -void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, TVector<TAggParams>& aggsParams) { +void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, TVector<TAggParams>& aggsParams, const TTypeEnvironment& env) { for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) { auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i)); auto name = AS_VALUE(TDataLiteral, aggVal->GetValue(0))->AsValue().AsStringRef(); @@ -766,7 +762,9 @@ void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, TVector<TAggPa argColumns.push_back(AS_VALUE(TDataLiteral, aggVal->GetValue(j))->AsValue().Get<ui32>()); } - aggsParams.emplace_back(TAggParams{ TStringBuf(name), tupleType, argColumns }); + TAggParams p; + p.Prepared_ = PrepareBlockAggregator(name, tupleType, filterColumn, argColumns, env); + aggsParams.emplace_back(std::move(p)); } } @@ -809,7 +807,7 @@ IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNod auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2)); TVector<TAggParams> aggsParams; - FillAggParams(aggsVal, tupleType, aggsParams); + FillAggParams(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env); return new TBlockCombineAllWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams)); } @@ -836,7 +834,7 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(3)); TVector<TAggParams> aggsParams; - FillAggParams(aggsVal, tupleType, aggsParams); + FillAggParams(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env); ui32 totalKeysSize = 0; std::vector<std::unique_ptr<IKeySerializer>> keySerializers; diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp index 0224d23a5d2..5d1cae3a812 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp @@ -195,28 +195,60 @@ private: const ui32 ArgColumn_; }; +class TPreparedCountAllBlockAggregator : public IPreparedBlockAggregator { +public: + TPreparedCountAllBlockAggregator(std::optional<ui32> filterColumn) + : FilterColumn_(filterColumn) + {} + + std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final { + return std::make_unique<TCountAllBlockAggregator>(FilterColumn_, ctx); + } + +private: + const std::optional<ui32> FilterColumn_; +}; + class TBlockCountAllFactory : public IBlockAggregatorFactory { public: - std::unique_ptr<IBlockAggregator> Make( + std::unique_ptr<IPreparedBlockAggregator> Prepare( TTupleType* tupleType, std::optional<ui32> filterColumn, const std::vector<ui32>& argsColumns, - TComputationContext& ctx) const final { + const TTypeEnvironment& env) const final { Y_UNUSED(tupleType); Y_UNUSED(argsColumns); - return std::make_unique<TCountAllBlockAggregator>(filterColumn, ctx); + Y_UNUSED(env); + return std::make_unique<TPreparedCountAllBlockAggregator>(filterColumn); } }; +class TPreparedCountBlockAggregator : public IPreparedBlockAggregator { +public: + TPreparedCountBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn) + : FilterColumn_(filterColumn) + , ArgColumn_(argColumn) + {} + + std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final { + return std::make_unique<TCountBlockAggregator>(FilterColumn_, ArgColumn_, ctx); + } + +private: + const std::optional<ui32> FilterColumn_; + const ui32 ArgColumn_; +}; + class TBlockCountFactory : public IBlockAggregatorFactory { public: - std::unique_ptr<IBlockAggregator> Make( + std::unique_ptr<IPreparedBlockAggregator> Prepare( TTupleType* tupleType, std::optional<ui32> filterColumn, const std::vector<ui32>& argsColumns, - TComputationContext& ctx) const final { + const TTypeEnvironment& env) const final { Y_UNUSED(tupleType); - return std::make_unique<TCountBlockAggregator>(filterColumn, argsColumns[0], ctx); + Y_UNUSED(env); + return std::make_unique<TPreparedCountBlockAggregator>(filterColumn, argsColumns[0]); } }; diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp index a0a66aa9c96..e96e5f3de2b 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp @@ -20,19 +20,19 @@ struct TAggregatorFactories { } }; -std::unique_ptr<IBlockAggregator> MakeBlockAggregator( +std::unique_ptr<IPreparedBlockAggregator> PrepareBlockAggregator( TStringBuf name, TTupleType* tupleType, std::optional<ui32> filterColumn, const std::vector<ui32>& argsColumns, - TComputationContext& ctx) { + const TTypeEnvironment& env) { const auto& f = Singleton<TAggregatorFactories>()->Factories; auto it = f.find(name); if (it == f.end()) { throw yexception() << "Unsupported block aggregation function: " << name; } - return it->second->Make(tupleType, filterColumn, argsColumns, ctx); + return it->second->Prepare(tupleType, filterColumn, argsColumns, env); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h index 2eae714e020..80b9de36723 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h @@ -52,24 +52,29 @@ protected: TComputationContext& Ctx_; }; -class THolderFactory; +class IPreparedBlockAggregator { +public: + virtual ~IPreparedBlockAggregator() = default; + + virtual std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const = 0; +}; -std::unique_ptr<IBlockAggregator> MakeBlockAggregator( +std::unique_ptr<IPreparedBlockAggregator> PrepareBlockAggregator( TStringBuf name, TTupleType* tupleType, std::optional<ui32> filterColumn, const std::vector<ui32>& argsColumns, - TComputationContext& ctx); + const TTypeEnvironment& env); class IBlockAggregatorFactory { public: virtual ~IBlockAggregatorFactory() = default; - virtual std::unique_ptr<IBlockAggregator> Make( + virtual std::unique_ptr<IPreparedBlockAggregator> Prepare( TTupleType* tupleType, std::optional<ui32> filterColumn, const std::vector<ui32>& argsColumns, - TComputationContext& ctx) const = 0; + const TTypeEnvironment& env) const = 0; }; } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp index e961aedaf14..2992593eafc 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp @@ -308,14 +308,55 @@ private: const std::shared_ptr<arrow::DataType> BuilderDataType_; }; +template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin> +class TPreparedMinMaxBlockAggregatorNullableOrScalar : public IPreparedBlockAggregator { +public: + TPreparedMinMaxBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn, + const std::shared_ptr<arrow::DataType>& builderDataType) + : FilterColumn_(filterColumn) + , ArgColumn_(argColumn) + , BuilderDataType_(builderDataType) + {} + + std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final { + return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<TIn, TInScalar, TBuilder, IsMin>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx); + } + +private: + const std::optional<ui32> FilterColumn_; + const ui32 ArgColumn_; + const std::shared_ptr<arrow::DataType> BuilderDataType_; +}; + +template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin> +class TPreparedMinMaxBlockAggregator : public IPreparedBlockAggregator { +public: + TPreparedMinMaxBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, + const std::shared_ptr<arrow::DataType>& builderDataType) + : FilterColumn_(filterColumn) + , ArgColumn_(argColumn) + , BuilderDataType_(builderDataType) + {} + + std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final { + return std::make_unique<TMinMaxBlockAggregator<TIn, TInScalar, TBuilder, IsMin>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx); + } + +private: + const std::optional<ui32> FilterColumn_; + const ui32 ArgColumn_; + const std::shared_ptr<arrow::DataType> BuilderDataType_; +}; + template <bool IsMin> class TBlockMinMaxFactory : public IBlockAggregatorFactory { public: - std::unique_ptr<IBlockAggregator> Make( + std::unique_ptr<IPreparedBlockAggregator> Prepare( TTupleType* tupleType, std::optional<ui32> filterColumn, const std::vector<ui32>& argsColumns, - TComputationContext& ctx) const final { + const TTypeEnvironment& env) const final { + Y_UNUSED(env); auto blockType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0])); auto argType = blockType->GetItemType(); bool isOptional; @@ -323,52 +364,52 @@ public: if (blockType->GetShape() == TBlockType::EShape::Scalar || isOptional) { switch (*dataType->GetDataSlot()) { case NUdf::EDataSlot::Int8: - return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int8(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int8()); case NUdf::EDataSlot::Bool: case NUdf::EDataSlot::Uint8: - return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint8(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint8()); case NUdf::EDataSlot::Int16: - return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int16(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int16()); case NUdf::EDataSlot::Uint16: case NUdf::EDataSlot::Date: - return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint16(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint16()); case NUdf::EDataSlot::Int32: - return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int32(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int32()); case NUdf::EDataSlot::Uint32: case NUdf::EDataSlot::Datetime: - return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint32(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint32()); case NUdf::EDataSlot::Int64: case NUdf::EDataSlot::Interval: - return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint64: case NUdf::EDataSlot::Timestamp: - return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint64()); default: throw yexception() << "Unsupported MIN/MAX input type"; } } else { switch (*dataType->GetDataSlot()) { case NUdf::EDataSlot::Int8: - return std::make_unique<TMinMaxBlockAggregator<i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int8(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregator<i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int8()); case NUdf::EDataSlot::Uint8: case NUdf::EDataSlot::Bool: - return std::make_unique<TMinMaxBlockAggregator<ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint8(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregator<ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint8()); case NUdf::EDataSlot::Int16: - return std::make_unique<TMinMaxBlockAggregator<i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int16(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregator<i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int16()); case NUdf::EDataSlot::Uint16: case NUdf::EDataSlot::Date: - return std::make_unique<TMinMaxBlockAggregator<ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint16(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregator<ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint16()); case NUdf::EDataSlot::Int32: - return std::make_unique<TMinMaxBlockAggregator<i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int32(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregator<i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int32()); case NUdf::EDataSlot::Uint32: case NUdf::EDataSlot::Datetime: - return std::make_unique<TMinMaxBlockAggregator<ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint32(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregator<ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint32()); case NUdf::EDataSlot::Int64: case NUdf::EDataSlot::Interval: - return std::make_unique<TMinMaxBlockAggregator<i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregator<i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint64: case NUdf::EDataSlot::Timestamp: - return std::make_unique<TMinMaxBlockAggregator<ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedMinMaxBlockAggregator<ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint64()); default: throw yexception() << "Unsupported MIN/MAX input type"; } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp index 5d1c162ffd3..9894a3c2317 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp @@ -288,9 +288,62 @@ public: ui64 Count_ = 0; }; - TAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx) + class TColumnBuilder : public IAggColumnBuilder { + public: + TColumnBuilder(ui64 size, const std::shared_ptr<arrow::DataType>& arrowType, TComputationContext& ctx) + : ArrowType_(arrowType) + , Ctx_(ctx) + , NullBitmapBuilder_(&ctx.ArrowMemoryPool) + , SumBuilder_(arrow::float64(), &ctx.ArrowMemoryPool) + , CountBuilder_(arrow::uint64(), &ctx.ArrowMemoryPool) + { + ARROW_OK(NullBitmapBuilder_.Reserve(size)); + ARROW_OK(SumBuilder_.Reserve(size)); + ARROW_OK(CountBuilder_.Reserve(size)); + } + + void Add(const void* state) final { + auto typedState = static_cast<const TState*>(state); + if (typedState->Count_) { + NullBitmapBuilder_.UnsafeAppend(true); + SumBuilder_.UnsafeAppend(typedState->Sum_); + CountBuilder_.UnsafeAppend(typedState->Count_); + } else { + NullBitmapBuilder_.UnsafeAppend(false); + SumBuilder_.UnsafeAppendNull(); + CountBuilder_.UnsafeAppendNull(); + } + } + + NUdf::TUnboxedValue Build() final { + std::shared_ptr<arrow::ArrayData> sumResult; + std::shared_ptr<arrow::ArrayData> countResult; + ARROW_OK(SumBuilder_.FinishInternal(&sumResult)); + ARROW_OK(CountBuilder_.FinishInternal(&countResult)); + std::shared_ptr<arrow::Buffer> nullBitmap; + auto length = NullBitmapBuilder_.length(); + auto nullCount = NullBitmapBuilder_.false_count(); + ARROW_OK(NullBitmapBuilder_.Finish(&nullBitmap)); + + auto arrayData = arrow::ArrayData::Make(ArrowType_, length, { nullBitmap }, nullCount, 0); + arrayData->child_data.push_back(sumResult); + arrayData->child_data.push_back(countResult); + return Ctx_.HolderFactory.CreateArrowBlock(arrow::Datum(arrayData)); + } + + private: + const std::shared_ptr<arrow::DataType> ArrowType_; + TComputationContext& Ctx_; + arrow::TypedBufferBuilder<bool> NullBitmapBuilder_; + arrow::DoubleBuilder SumBuilder_; + arrow::UInt64Builder CountBuilder_; + }; + + TAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, + const std::shared_ptr<arrow::DataType> builderDataType, TComputationContext& ctx) : TBlockAggregatorBase(sizeof(TState), filterColumn, ctx) , ArgColumn_(argColumn) + , BuilderDataType_(builderDataType) { } @@ -411,21 +464,62 @@ public: } std::unique_ptr<IAggColumnBuilder> MakeBuilder(ui64 size) final { - Y_UNUSED(size); - MKQL_ENSURE(false, "TODO: support of tuples"); + return std::make_unique<TColumnBuilder>(size, BuilderDataType_, Ctx_); + } + +private: + const ui32 ArgColumn_; + const std::shared_ptr<arrow::DataType> BuilderDataType_; +}; + +template <typename TIn, typename TSum, typename TBuilder, typename TInScalar> +class TPreparedSumBlockAggregatorNullableOrScalar : public IPreparedBlockAggregator { +public: + TPreparedSumBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn, + const std::shared_ptr<arrow::DataType>& builderDataType) + : FilterColumn_(filterColumn) + , ArgColumn_(argColumn) + , BuilderDataType_(builderDataType) + {} + + std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final { + return std::make_unique<TSumBlockAggregatorNullableOrScalar<TIn, TSum, TBuilder, TInScalar>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx); + } + +private: + const std::optional<ui32> FilterColumn_; + const ui32 ArgColumn_; + const std::shared_ptr<arrow::DataType> BuilderDataType_; +}; + +template <typename TIn, typename TSum, typename TBuilder, typename TInScalar> +class TPreparedSumBlockAggregator : public IPreparedBlockAggregator { +public: + TPreparedSumBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, + const std::shared_ptr<arrow::DataType>& builderDataType) + : FilterColumn_(filterColumn) + , ArgColumn_(argColumn) + , BuilderDataType_(builderDataType) + {} + + std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final { + return std::make_unique<TSumBlockAggregator<TIn, TSum, TBuilder, TInScalar>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx); } private: + const std::optional<ui32> FilterColumn_; const ui32 ArgColumn_; + const std::shared_ptr<arrow::DataType> BuilderDataType_; }; class TBlockSumFactory : public IBlockAggregatorFactory { public: - std::unique_ptr<IBlockAggregator> Make( + std::unique_ptr<IPreparedBlockAggregator> Prepare( TTupleType* tupleType, std::optional<ui32> filterColumn, const std::vector<ui32>& argsColumns, - TComputationContext& ctx) const final { + const TTypeEnvironment& env) const final { + Y_UNUSED(env); auto blockType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0])); auto argType = blockType->GetItemType(); bool isOptional; @@ -433,50 +527,42 @@ public: if (blockType->GetShape() == TBlockType::EShape::Scalar || isOptional) { switch (*dataType->GetDataSlot()) { case NUdf::EDataSlot::Int8: - return std::make_unique<TSumBlockAggregatorNullableOrScalar<i8, i64, arrow::Int64Builder, arrow::Int8Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<i8, i64, arrow::Int64Builder, arrow::Int8Scalar>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint8: - return std::make_unique<TSumBlockAggregatorNullableOrScalar<ui8, ui64, arrow::UInt64Builder, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<ui8, ui64, arrow::UInt64Builder, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], arrow::uint64()); case NUdf::EDataSlot::Int16: - return std::make_unique<TSumBlockAggregatorNullableOrScalar<i16, i64, arrow::Int64Builder, arrow::Int16Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<i16, i64, arrow::Int64Builder, arrow::Int16Scalar>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint16: - case NUdf::EDataSlot::Date: - return std::make_unique<TSumBlockAggregatorNullableOrScalar<ui16, ui64, arrow::UInt64Builder, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<ui16, ui64, arrow::UInt64Builder, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], arrow::uint64()); case NUdf::EDataSlot::Int32: - return std::make_unique<TSumBlockAggregatorNullableOrScalar<i32, i64, arrow::Int64Builder, arrow::Int32Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<i32, i64, arrow::Int64Builder, arrow::Int32Scalar>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint32: - case NUdf::EDataSlot::Datetime: - return std::make_unique<TSumBlockAggregatorNullableOrScalar<ui32, ui64, arrow::UInt64Builder, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<ui32, ui64, arrow::UInt64Builder, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], arrow::uint64()); case NUdf::EDataSlot::Int64: - case NUdf::EDataSlot::Interval: - return std::make_unique<TSumBlockAggregatorNullableOrScalar<i64, i64, arrow::Int64Builder, arrow::Int64Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<i64, i64, arrow::Int64Builder, arrow::Int64Scalar>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint64: - case NUdf::EDataSlot::Timestamp: - return std::make_unique<TSumBlockAggregatorNullableOrScalar<ui64, ui64, arrow::UInt64Builder, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<ui64, ui64, arrow::UInt64Builder, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], arrow::uint64()); default: throw yexception() << "Unsupported SUM input type"; } } else { switch (*dataType->GetDataSlot()) { case NUdf::EDataSlot::Int8: - return std::make_unique<TSumBlockAggregator<i8, i64, arrow::Int64Builder, arrow::Int8Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedSumBlockAggregator<i8, i64, arrow::Int64Builder, arrow::Int8Scalar>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint8: - return std::make_unique<TSumBlockAggregator<ui8, ui64, arrow::UInt64Builder, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedSumBlockAggregator<ui8, ui64, arrow::UInt64Builder, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], arrow::uint64()); case NUdf::EDataSlot::Int16: - return std::make_unique<TSumBlockAggregator<i16, i64, arrow::Int64Builder, arrow::Int16Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedSumBlockAggregator<i16, i64, arrow::Int64Builder, arrow::Int16Scalar>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint16: - case NUdf::EDataSlot::Date: - return std::make_unique<TSumBlockAggregator<ui16, ui64, arrow::UInt64Builder, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedSumBlockAggregator<ui16, ui64, arrow::UInt64Builder, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], arrow::uint64()); case NUdf::EDataSlot::Int32: - return std::make_unique<TSumBlockAggregator<i32, i64, arrow::Int64Builder, arrow::Int32Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedSumBlockAggregator<i32, i64, arrow::Int64Builder, arrow::Int32Scalar>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint32: - case NUdf::EDataSlot::Datetime: - return std::make_unique<TSumBlockAggregator<ui32, ui64, arrow::UInt64Builder, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedSumBlockAggregator<ui32, ui64, arrow::UInt64Builder, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], arrow::uint64()); case NUdf::EDataSlot::Int64: - case NUdf::EDataSlot::Interval: - return std::make_unique<TSumBlockAggregator<i64, i64, arrow::Int64Builder, arrow::Int64Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx); + return std::make_unique<TPreparedSumBlockAggregator<i64, i64, arrow::Int64Builder, arrow::Int64Scalar>>(filterColumn, argsColumns[0], arrow::int64()); case NUdf::EDataSlot::Uint64: - case NUdf::EDataSlot::Timestamp: - return std::make_unique<TSumBlockAggregator<ui64, ui64, arrow::UInt64Builder, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx); + return std::make_unique<TPreparedSumBlockAggregator<ui64, ui64, arrow::UInt64Builder, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], arrow::uint64()); default: throw yexception() << "Unsupported SUM input type"; } @@ -484,37 +570,61 @@ public: } }; +template <typename TIn, typename TInScalar> +class TPreparedAvgBlockAggregator : public IPreparedBlockAggregator { +public: + TPreparedAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, + const std::shared_ptr<arrow::DataType>& builderDataType) + : FilterColumn_(filterColumn) + , ArgColumn_(argColumn) + , BuilderDataType_(builderDataType) + {} + + std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final { + return std::make_unique<TAvgBlockAggregator<TIn, TInScalar>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx); + } + +private: + const std::optional<ui32> FilterColumn_; + const ui32 ArgColumn_; + const std::shared_ptr<arrow::DataType> BuilderDataType_; +}; + class TBlockAvgFactory : public IBlockAggregatorFactory { public: - std::unique_ptr<IBlockAggregator> Make( + std::unique_ptr<IPreparedBlockAggregator> Prepare( TTupleType* tupleType, std::optional<ui32> filterColumn, const std::vector<ui32>& argsColumns, - TComputationContext& ctx) const final { - auto argType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]))->GetItemType(); + const TTypeEnvironment& env) const final { + + auto doubleType = TDataType::Create(NUdf::TDataType<double>::Id, env); + auto ui64Type = TDataType::Create(NUdf::TDataType<ui64>::Id, env); + TVector<TType*> tupleElements = { doubleType, ui64Type }; + auto avgRetType = TTupleType::Create(2, tupleElements.data(), env); + std::shared_ptr<arrow::DataType> builderDataType; bool isOptional; + MKQL_ENSURE(ConvertArrowType(avgRetType, isOptional, builderDataType), "Unsupported builder type"); + + auto argType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]))->GetItemType(); auto dataType = UnpackOptionalData(argType, isOptional); switch (*dataType->GetDataSlot()) { case NUdf::EDataSlot::Int8: - return std::make_unique<TAvgBlockAggregator<i8, arrow::Int8Scalar>>(filterColumn, argsColumns[0], ctx); + return std::make_unique<TPreparedAvgBlockAggregator<i8, arrow::Int8Scalar>>(filterColumn, argsColumns[0], builderDataType); case NUdf::EDataSlot::Uint8: - return std::make_unique<TAvgBlockAggregator<ui8, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], ctx); + return std::make_unique<TPreparedAvgBlockAggregator<ui8, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], builderDataType); case NUdf::EDataSlot::Int16: - return std::make_unique<TAvgBlockAggregator<i16, arrow::Int16Scalar>>(filterColumn, argsColumns[0], ctx); + return std::make_unique<TPreparedAvgBlockAggregator<i16, arrow::Int16Scalar>>(filterColumn, argsColumns[0], builderDataType); case NUdf::EDataSlot::Uint16: - case NUdf::EDataSlot::Date: - return std::make_unique<TAvgBlockAggregator<ui16, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], ctx); + return std::make_unique<TPreparedAvgBlockAggregator<ui16, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], builderDataType); case NUdf::EDataSlot::Int32: - return std::make_unique<TAvgBlockAggregator<i32, arrow::Int32Scalar>>(filterColumn, argsColumns[0], ctx); + return std::make_unique<TPreparedAvgBlockAggregator<i32, arrow::Int32Scalar>>(filterColumn, argsColumns[0], builderDataType); case NUdf::EDataSlot::Uint32: - case NUdf::EDataSlot::Datetime: - return std::make_unique<TAvgBlockAggregator<ui32, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], ctx); + return std::make_unique<TPreparedAvgBlockAggregator<ui32, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], builderDataType); case NUdf::EDataSlot::Int64: - case NUdf::EDataSlot::Interval: - return std::make_unique<TAvgBlockAggregator<i64, arrow::Int64Scalar>>(filterColumn, argsColumns[0], ctx); + return std::make_unique<TPreparedAvgBlockAggregator<i64, arrow::Int64Scalar>>(filterColumn, argsColumns[0], builderDataType); case NUdf::EDataSlot::Uint64: - case NUdf::EDataSlot::Timestamp: - return std::make_unique<TAvgBlockAggregator<ui64, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], ctx); + return std::make_unique<TPreparedAvgBlockAggregator<ui64, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], builderDataType); default: throw yexception() << "Unsupported AVG input type"; } @@ -528,6 +638,6 @@ std::unique_ptr<IBlockAggregatorFactory> MakeBlockSumFactory() { std::unique_ptr<IBlockAggregatorFactory> MakeBlockAvgFactory() { return std::make_unique<TBlockAvgFactory>(); } - + } } |