aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-12-14 16:02:08 +0300
committervvvv <vvvv@ydb.tech>2022-12-14 16:02:08 +0300
commit58e33e587065c6bac9600a43369d97363ed84184 (patch)
tree734e5f169e324915338d368f1d9e88c33be2dcd5
parent7ab8d966fbb1720e24f2b00d111f19fa60ff0a40 (diff)
downloadydb-58e33e587065c6bac9600a43369d97363ed84184.tar.gz
support of avg over GROUP BY with keys
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp20
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp44
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp6
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h15
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp77
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp200
6 files changed, 274 insertions, 88 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
index 5ffdb8ad4a9..ae8af5df98f 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
@@ -111,9 +111,7 @@ namespace NMiniKQL {
namespace {
struct TAggParams {
- TStringBuf Name;
- TTupleType* TupleType;
- std::vector<ui32> ArgColumns;
+ std::unique_ptr<IPreparedBlockAggregator> Prepared_;
};
struct TKeyParams {
@@ -412,8 +410,7 @@ private:
ui32 totalStateSize = 0;
for (const auto& p : params) {
- Aggs_.emplace_back(MakeBlockAggregator(p.Name, p.TupleType, filterColumn, p.ArgColumns, ctx));
-
+ Aggs_.emplace_back(p.Prepared_->Make(ctx));
totalStateSize += Aggs_.back()->StateSize;
}
@@ -716,8 +713,7 @@ private:
}
for (const auto& p : params) {
- Aggs_.emplace_back(MakeBlockAggregator(p.Name, p.TupleType, filterColumn, p.ArgColumns, ctx));
-
+ Aggs_.emplace_back(p.Prepared_->Make(ctx));
TotalStateSize_ += Aggs_.back()->StateSize;
}
@@ -756,7 +752,7 @@ private:
std::vector<std::unique_ptr<IKeySerializer>> KeySerializers_;
};
-void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, TVector<TAggParams>& aggsParams) {
+void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, std::optional<ui32> filterColumn, TVector<TAggParams>& aggsParams, const TTypeEnvironment& env) {
for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) {
auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i));
auto name = AS_VALUE(TDataLiteral, aggVal->GetValue(0))->AsValue().AsStringRef();
@@ -766,7 +762,9 @@ void FillAggParams(TTupleLiteral* aggsVal, TTupleType* tupleType, TVector<TAggPa
argColumns.push_back(AS_VALUE(TDataLiteral, aggVal->GetValue(j))->AsValue().Get<ui32>());
}
- aggsParams.emplace_back(TAggParams{ TStringBuf(name), tupleType, argColumns });
+ TAggParams p;
+ p.Prepared_ = PrepareBlockAggregator(name, tupleType, filterColumn, argColumns, env);
+ aggsParams.emplace_back(std::move(p));
}
}
@@ -809,7 +807,7 @@ IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNod
auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(2));
TVector<TAggParams> aggsParams;
- FillAggParams(aggsVal, tupleType, aggsParams);
+ FillAggParams(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env);
return new TBlockCombineAllWrapper(ctx.Mutables, wideFlow, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams));
}
@@ -836,7 +834,7 @@ IComputationNode* WrapBlockCombineHashed(TCallable& callable, const TComputation
auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(3));
TVector<TAggParams> aggsParams;
- FillAggParams(aggsVal, tupleType, aggsParams);
+ FillAggParams(aggsVal, tupleType, filterColumn, aggsParams, ctx.Env);
ui32 totalKeysSize = 0;
std::vector<std::unique_ptr<IKeySerializer>> keySerializers;
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
index 0224d23a5d2..5d1cae3a812 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
@@ -195,28 +195,60 @@ private:
const ui32 ArgColumn_;
};
+class TPreparedCountAllBlockAggregator : public IPreparedBlockAggregator {
+public:
+ TPreparedCountAllBlockAggregator(std::optional<ui32> filterColumn)
+ : FilterColumn_(filterColumn)
+ {}
+
+ std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final {
+ return std::make_unique<TCountAllBlockAggregator>(FilterColumn_, ctx);
+ }
+
+private:
+ const std::optional<ui32> FilterColumn_;
+};
+
class TBlockCountAllFactory : public IBlockAggregatorFactory {
public:
- std::unique_ptr<IBlockAggregator> Make(
+ std::unique_ptr<IPreparedBlockAggregator> Prepare(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
- TComputationContext& ctx) const final {
+ const TTypeEnvironment& env) const final {
Y_UNUSED(tupleType);
Y_UNUSED(argsColumns);
- return std::make_unique<TCountAllBlockAggregator>(filterColumn, ctx);
+ Y_UNUSED(env);
+ return std::make_unique<TPreparedCountAllBlockAggregator>(filterColumn);
}
};
+class TPreparedCountBlockAggregator : public IPreparedBlockAggregator {
+public:
+ TPreparedCountBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn)
+ : FilterColumn_(filterColumn)
+ , ArgColumn_(argColumn)
+ {}
+
+ std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final {
+ return std::make_unique<TCountBlockAggregator>(FilterColumn_, ArgColumn_, ctx);
+ }
+
+private:
+ const std::optional<ui32> FilterColumn_;
+ const ui32 ArgColumn_;
+};
+
class TBlockCountFactory : public IBlockAggregatorFactory {
public:
- std::unique_ptr<IBlockAggregator> Make(
+ std::unique_ptr<IPreparedBlockAggregator> Prepare(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
- TComputationContext& ctx) const final {
+ const TTypeEnvironment& env) const final {
Y_UNUSED(tupleType);
- return std::make_unique<TCountBlockAggregator>(filterColumn, argsColumns[0], ctx);
+ Y_UNUSED(env);
+ return std::make_unique<TPreparedCountBlockAggregator>(filterColumn, argsColumns[0]);
}
};
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp
index a0a66aa9c96..e96e5f3de2b 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp
@@ -20,19 +20,19 @@ struct TAggregatorFactories {
}
};
-std::unique_ptr<IBlockAggregator> MakeBlockAggregator(
+std::unique_ptr<IPreparedBlockAggregator> PrepareBlockAggregator(
TStringBuf name,
TTupleType* tupleType,
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
- TComputationContext& ctx) {
+ const TTypeEnvironment& env) {
const auto& f = Singleton<TAggregatorFactories>()->Factories;
auto it = f.find(name);
if (it == f.end()) {
throw yexception() << "Unsupported block aggregation function: " << name;
}
- return it->second->Make(tupleType, filterColumn, argsColumns, ctx);
+ return it->second->Prepare(tupleType, filterColumn, argsColumns, env);
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
index 2eae714e020..80b9de36723 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
@@ -52,24 +52,29 @@ protected:
TComputationContext& Ctx_;
};
-class THolderFactory;
+class IPreparedBlockAggregator {
+public:
+ virtual ~IPreparedBlockAggregator() = default;
+
+ virtual std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const = 0;
+};
-std::unique_ptr<IBlockAggregator> MakeBlockAggregator(
+std::unique_ptr<IPreparedBlockAggregator> PrepareBlockAggregator(
TStringBuf name,
TTupleType* tupleType,
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
- TComputationContext& ctx);
+ const TTypeEnvironment& env);
class IBlockAggregatorFactory {
public:
virtual ~IBlockAggregatorFactory() = default;
- virtual std::unique_ptr<IBlockAggregator> Make(
+ virtual std::unique_ptr<IPreparedBlockAggregator> Prepare(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
- TComputationContext& ctx) const = 0;
+ const TTypeEnvironment& env) const = 0;
};
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp
index e961aedaf14..2992593eafc 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp
@@ -308,14 +308,55 @@ private:
const std::shared_ptr<arrow::DataType> BuilderDataType_;
};
+template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin>
+class TPreparedMinMaxBlockAggregatorNullableOrScalar : public IPreparedBlockAggregator {
+public:
+ TPreparedMinMaxBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn,
+ const std::shared_ptr<arrow::DataType>& builderDataType)
+ : FilterColumn_(filterColumn)
+ , ArgColumn_(argColumn)
+ , BuilderDataType_(builderDataType)
+ {}
+
+ std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final {
+ return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<TIn, TInScalar, TBuilder, IsMin>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx);
+ }
+
+private:
+ const std::optional<ui32> FilterColumn_;
+ const ui32 ArgColumn_;
+ const std::shared_ptr<arrow::DataType> BuilderDataType_;
+};
+
+template <typename TIn, typename TInScalar, typename TBuilder, bool IsMin>
+class TPreparedMinMaxBlockAggregator : public IPreparedBlockAggregator {
+public:
+ TPreparedMinMaxBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn,
+ const std::shared_ptr<arrow::DataType>& builderDataType)
+ : FilterColumn_(filterColumn)
+ , ArgColumn_(argColumn)
+ , BuilderDataType_(builderDataType)
+ {}
+
+ std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final {
+ return std::make_unique<TMinMaxBlockAggregator<TIn, TInScalar, TBuilder, IsMin>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx);
+ }
+
+private:
+ const std::optional<ui32> FilterColumn_;
+ const ui32 ArgColumn_;
+ const std::shared_ptr<arrow::DataType> BuilderDataType_;
+};
+
template <bool IsMin>
class TBlockMinMaxFactory : public IBlockAggregatorFactory {
public:
- std::unique_ptr<IBlockAggregator> Make(
+ std::unique_ptr<IPreparedBlockAggregator> Prepare(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
- TComputationContext& ctx) const final {
+ const TTypeEnvironment& env) const final {
+ Y_UNUSED(env);
auto blockType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]));
auto argType = blockType->GetItemType();
bool isOptional;
@@ -323,52 +364,52 @@ public:
if (blockType->GetShape() == TBlockType::EShape::Scalar || isOptional) {
switch (*dataType->GetDataSlot()) {
case NUdf::EDataSlot::Int8:
- return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int8(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int8());
case NUdf::EDataSlot::Bool:
case NUdf::EDataSlot::Uint8:
- return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint8(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint8());
case NUdf::EDataSlot::Int16:
- return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int16(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int16());
case NUdf::EDataSlot::Uint16:
case NUdf::EDataSlot::Date:
- return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint16(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint16());
case NUdf::EDataSlot::Int32:
- return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int32(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int32());
case NUdf::EDataSlot::Uint32:
case NUdf::EDataSlot::Datetime:
- return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint32(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint32());
case NUdf::EDataSlot::Int64:
case NUdf::EDataSlot::Interval:
- return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint64:
case NUdf::EDataSlot::Timestamp:
- return std::make_unique<TMinMaxBlockAggregatorNullableOrScalar<ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregatorNullableOrScalar<ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint64());
default:
throw yexception() << "Unsupported MIN/MAX input type";
}
} else {
switch (*dataType->GetDataSlot()) {
case NUdf::EDataSlot::Int8:
- return std::make_unique<TMinMaxBlockAggregator<i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int8(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregator<i8, arrow::Int8Scalar, arrow::Int8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int8());
case NUdf::EDataSlot::Uint8:
case NUdf::EDataSlot::Bool:
- return std::make_unique<TMinMaxBlockAggregator<ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint8(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregator<ui8, arrow::UInt8Scalar, arrow::UInt8Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint8());
case NUdf::EDataSlot::Int16:
- return std::make_unique<TMinMaxBlockAggregator<i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int16(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregator<i16, arrow::Int16Scalar, arrow::Int16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int16());
case NUdf::EDataSlot::Uint16:
case NUdf::EDataSlot::Date:
- return std::make_unique<TMinMaxBlockAggregator<ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint16(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregator<ui16, arrow::UInt16Scalar, arrow::UInt16Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint16());
case NUdf::EDataSlot::Int32:
- return std::make_unique<TMinMaxBlockAggregator<i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int32(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregator<i32, arrow::Int32Scalar, arrow::Int32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int32());
case NUdf::EDataSlot::Uint32:
case NUdf::EDataSlot::Datetime:
- return std::make_unique<TMinMaxBlockAggregator<ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint32(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregator<ui32, arrow::UInt32Scalar, arrow::UInt32Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint32());
case NUdf::EDataSlot::Int64:
case NUdf::EDataSlot::Interval:
- return std::make_unique<TMinMaxBlockAggregator<i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregator<i64, arrow::Int64Scalar, arrow::Int64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint64:
case NUdf::EDataSlot::Timestamp:
- return std::make_unique<TMinMaxBlockAggregator<ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedMinMaxBlockAggregator<ui64, arrow::UInt64Scalar, arrow::UInt64Builder, IsMin>>(filterColumn, argsColumns[0], arrow::uint64());
default:
throw yexception() << "Unsupported MIN/MAX input type";
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
index 5d1c162ffd3..9894a3c2317 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
@@ -288,9 +288,62 @@ public:
ui64 Count_ = 0;
};
- TAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, TComputationContext& ctx)
+ class TColumnBuilder : public IAggColumnBuilder {
+ public:
+ TColumnBuilder(ui64 size, const std::shared_ptr<arrow::DataType>& arrowType, TComputationContext& ctx)
+ : ArrowType_(arrowType)
+ , Ctx_(ctx)
+ , NullBitmapBuilder_(&ctx.ArrowMemoryPool)
+ , SumBuilder_(arrow::float64(), &ctx.ArrowMemoryPool)
+ , CountBuilder_(arrow::uint64(), &ctx.ArrowMemoryPool)
+ {
+ ARROW_OK(NullBitmapBuilder_.Reserve(size));
+ ARROW_OK(SumBuilder_.Reserve(size));
+ ARROW_OK(CountBuilder_.Reserve(size));
+ }
+
+ void Add(const void* state) final {
+ auto typedState = static_cast<const TState*>(state);
+ if (typedState->Count_) {
+ NullBitmapBuilder_.UnsafeAppend(true);
+ SumBuilder_.UnsafeAppend(typedState->Sum_);
+ CountBuilder_.UnsafeAppend(typedState->Count_);
+ } else {
+ NullBitmapBuilder_.UnsafeAppend(false);
+ SumBuilder_.UnsafeAppendNull();
+ CountBuilder_.UnsafeAppendNull();
+ }
+ }
+
+ NUdf::TUnboxedValue Build() final {
+ std::shared_ptr<arrow::ArrayData> sumResult;
+ std::shared_ptr<arrow::ArrayData> countResult;
+ ARROW_OK(SumBuilder_.FinishInternal(&sumResult));
+ ARROW_OK(CountBuilder_.FinishInternal(&countResult));
+ std::shared_ptr<arrow::Buffer> nullBitmap;
+ auto length = NullBitmapBuilder_.length();
+ auto nullCount = NullBitmapBuilder_.false_count();
+ ARROW_OK(NullBitmapBuilder_.Finish(&nullBitmap));
+
+ auto arrayData = arrow::ArrayData::Make(ArrowType_, length, { nullBitmap }, nullCount, 0);
+ arrayData->child_data.push_back(sumResult);
+ arrayData->child_data.push_back(countResult);
+ return Ctx_.HolderFactory.CreateArrowBlock(arrow::Datum(arrayData));
+ }
+
+ private:
+ const std::shared_ptr<arrow::DataType> ArrowType_;
+ TComputationContext& Ctx_;
+ arrow::TypedBufferBuilder<bool> NullBitmapBuilder_;
+ arrow::DoubleBuilder SumBuilder_;
+ arrow::UInt64Builder CountBuilder_;
+ };
+
+ TAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn,
+ const std::shared_ptr<arrow::DataType> builderDataType, TComputationContext& ctx)
: TBlockAggregatorBase(sizeof(TState), filterColumn, ctx)
, ArgColumn_(argColumn)
+ , BuilderDataType_(builderDataType)
{
}
@@ -411,21 +464,62 @@ public:
}
std::unique_ptr<IAggColumnBuilder> MakeBuilder(ui64 size) final {
- Y_UNUSED(size);
- MKQL_ENSURE(false, "TODO: support of tuples");
+ return std::make_unique<TColumnBuilder>(size, BuilderDataType_, Ctx_);
+ }
+
+private:
+ const ui32 ArgColumn_;
+ const std::shared_ptr<arrow::DataType> BuilderDataType_;
+};
+
+template <typename TIn, typename TSum, typename TBuilder, typename TInScalar>
+class TPreparedSumBlockAggregatorNullableOrScalar : public IPreparedBlockAggregator {
+public:
+ TPreparedSumBlockAggregatorNullableOrScalar(std::optional<ui32> filterColumn, ui32 argColumn,
+ const std::shared_ptr<arrow::DataType>& builderDataType)
+ : FilterColumn_(filterColumn)
+ , ArgColumn_(argColumn)
+ , BuilderDataType_(builderDataType)
+ {}
+
+ std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final {
+ return std::make_unique<TSumBlockAggregatorNullableOrScalar<TIn, TSum, TBuilder, TInScalar>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx);
+ }
+
+private:
+ const std::optional<ui32> FilterColumn_;
+ const ui32 ArgColumn_;
+ const std::shared_ptr<arrow::DataType> BuilderDataType_;
+};
+
+template <typename TIn, typename TSum, typename TBuilder, typename TInScalar>
+class TPreparedSumBlockAggregator : public IPreparedBlockAggregator {
+public:
+ TPreparedSumBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn,
+ const std::shared_ptr<arrow::DataType>& builderDataType)
+ : FilterColumn_(filterColumn)
+ , ArgColumn_(argColumn)
+ , BuilderDataType_(builderDataType)
+ {}
+
+ std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final {
+ return std::make_unique<TSumBlockAggregator<TIn, TSum, TBuilder, TInScalar>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx);
}
private:
+ const std::optional<ui32> FilterColumn_;
const ui32 ArgColumn_;
+ const std::shared_ptr<arrow::DataType> BuilderDataType_;
};
class TBlockSumFactory : public IBlockAggregatorFactory {
public:
- std::unique_ptr<IBlockAggregator> Make(
+ std::unique_ptr<IPreparedBlockAggregator> Prepare(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
- TComputationContext& ctx) const final {
+ const TTypeEnvironment& env) const final {
+ Y_UNUSED(env);
auto blockType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]));
auto argType = blockType->GetItemType();
bool isOptional;
@@ -433,50 +527,42 @@ public:
if (blockType->GetShape() == TBlockType::EShape::Scalar || isOptional) {
switch (*dataType->GetDataSlot()) {
case NUdf::EDataSlot::Int8:
- return std::make_unique<TSumBlockAggregatorNullableOrScalar<i8, i64, arrow::Int64Builder, arrow::Int8Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<i8, i64, arrow::Int64Builder, arrow::Int8Scalar>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint8:
- return std::make_unique<TSumBlockAggregatorNullableOrScalar<ui8, ui64, arrow::UInt64Builder, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<ui8, ui64, arrow::UInt64Builder, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], arrow::uint64());
case NUdf::EDataSlot::Int16:
- return std::make_unique<TSumBlockAggregatorNullableOrScalar<i16, i64, arrow::Int64Builder, arrow::Int16Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<i16, i64, arrow::Int64Builder, arrow::Int16Scalar>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint16:
- case NUdf::EDataSlot::Date:
- return std::make_unique<TSumBlockAggregatorNullableOrScalar<ui16, ui64, arrow::UInt64Builder, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<ui16, ui64, arrow::UInt64Builder, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], arrow::uint64());
case NUdf::EDataSlot::Int32:
- return std::make_unique<TSumBlockAggregatorNullableOrScalar<i32, i64, arrow::Int64Builder, arrow::Int32Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<i32, i64, arrow::Int64Builder, arrow::Int32Scalar>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint32:
- case NUdf::EDataSlot::Datetime:
- return std::make_unique<TSumBlockAggregatorNullableOrScalar<ui32, ui64, arrow::UInt64Builder, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<ui32, ui64, arrow::UInt64Builder, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], arrow::uint64());
case NUdf::EDataSlot::Int64:
- case NUdf::EDataSlot::Interval:
- return std::make_unique<TSumBlockAggregatorNullableOrScalar<i64, i64, arrow::Int64Builder, arrow::Int64Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<i64, i64, arrow::Int64Builder, arrow::Int64Scalar>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint64:
- case NUdf::EDataSlot::Timestamp:
- return std::make_unique<TSumBlockAggregatorNullableOrScalar<ui64, ui64, arrow::UInt64Builder, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregatorNullableOrScalar<ui64, ui64, arrow::UInt64Builder, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], arrow::uint64());
default:
throw yexception() << "Unsupported SUM input type";
}
} else {
switch (*dataType->GetDataSlot()) {
case NUdf::EDataSlot::Int8:
- return std::make_unique<TSumBlockAggregator<i8, i64, arrow::Int64Builder, arrow::Int8Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregator<i8, i64, arrow::Int64Builder, arrow::Int8Scalar>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint8:
- return std::make_unique<TSumBlockAggregator<ui8, ui64, arrow::UInt64Builder, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregator<ui8, ui64, arrow::UInt64Builder, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], arrow::uint64());
case NUdf::EDataSlot::Int16:
- return std::make_unique<TSumBlockAggregator<i16, i64, arrow::Int64Builder, arrow::Int16Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregator<i16, i64, arrow::Int64Builder, arrow::Int16Scalar>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint16:
- case NUdf::EDataSlot::Date:
- return std::make_unique<TSumBlockAggregator<ui16, ui64, arrow::UInt64Builder, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregator<ui16, ui64, arrow::UInt64Builder, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], arrow::uint64());
case NUdf::EDataSlot::Int32:
- return std::make_unique<TSumBlockAggregator<i32, i64, arrow::Int64Builder, arrow::Int32Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregator<i32, i64, arrow::Int64Builder, arrow::Int32Scalar>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint32:
- case NUdf::EDataSlot::Datetime:
- return std::make_unique<TSumBlockAggregator<ui32, ui64, arrow::UInt64Builder, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregator<ui32, ui64, arrow::UInt64Builder, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], arrow::uint64());
case NUdf::EDataSlot::Int64:
- case NUdf::EDataSlot::Interval:
- return std::make_unique<TSumBlockAggregator<i64, i64, arrow::Int64Builder, arrow::Int64Scalar>>(filterColumn, argsColumns[0], arrow::int64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregator<i64, i64, arrow::Int64Builder, arrow::Int64Scalar>>(filterColumn, argsColumns[0], arrow::int64());
case NUdf::EDataSlot::Uint64:
- case NUdf::EDataSlot::Timestamp:
- return std::make_unique<TSumBlockAggregator<ui64, ui64, arrow::UInt64Builder, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], arrow::uint64(), ctx);
+ return std::make_unique<TPreparedSumBlockAggregator<ui64, ui64, arrow::UInt64Builder, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], arrow::uint64());
default:
throw yexception() << "Unsupported SUM input type";
}
@@ -484,37 +570,61 @@ public:
}
};
+template <typename TIn, typename TInScalar>
+class TPreparedAvgBlockAggregator : public IPreparedBlockAggregator {
+public:
+ TPreparedAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn,
+ const std::shared_ptr<arrow::DataType>& builderDataType)
+ : FilterColumn_(filterColumn)
+ , ArgColumn_(argColumn)
+ , BuilderDataType_(builderDataType)
+ {}
+
+ std::unique_ptr<IBlockAggregator> Make(TComputationContext& ctx) const final {
+ return std::make_unique<TAvgBlockAggregator<TIn, TInScalar>>(FilterColumn_, ArgColumn_, BuilderDataType_, ctx);
+ }
+
+private:
+ const std::optional<ui32> FilterColumn_;
+ const ui32 ArgColumn_;
+ const std::shared_ptr<arrow::DataType> BuilderDataType_;
+};
+
class TBlockAvgFactory : public IBlockAggregatorFactory {
public:
- std::unique_ptr<IBlockAggregator> Make(
+ std::unique_ptr<IPreparedBlockAggregator> Prepare(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
- TComputationContext& ctx) const final {
- auto argType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]))->GetItemType();
+ const TTypeEnvironment& env) const final {
+
+ auto doubleType = TDataType::Create(NUdf::TDataType<double>::Id, env);
+ auto ui64Type = TDataType::Create(NUdf::TDataType<ui64>::Id, env);
+ TVector<TType*> tupleElements = { doubleType, ui64Type };
+ auto avgRetType = TTupleType::Create(2, tupleElements.data(), env);
+ std::shared_ptr<arrow::DataType> builderDataType;
bool isOptional;
+ MKQL_ENSURE(ConvertArrowType(avgRetType, isOptional, builderDataType), "Unsupported builder type");
+
+ auto argType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]))->GetItemType();
auto dataType = UnpackOptionalData(argType, isOptional);
switch (*dataType->GetDataSlot()) {
case NUdf::EDataSlot::Int8:
- return std::make_unique<TAvgBlockAggregator<i8, arrow::Int8Scalar>>(filterColumn, argsColumns[0], ctx);
+ return std::make_unique<TPreparedAvgBlockAggregator<i8, arrow::Int8Scalar>>(filterColumn, argsColumns[0], builderDataType);
case NUdf::EDataSlot::Uint8:
- return std::make_unique<TAvgBlockAggregator<ui8, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], ctx);
+ return std::make_unique<TPreparedAvgBlockAggregator<ui8, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], builderDataType);
case NUdf::EDataSlot::Int16:
- return std::make_unique<TAvgBlockAggregator<i16, arrow::Int16Scalar>>(filterColumn, argsColumns[0], ctx);
+ return std::make_unique<TPreparedAvgBlockAggregator<i16, arrow::Int16Scalar>>(filterColumn, argsColumns[0], builderDataType);
case NUdf::EDataSlot::Uint16:
- case NUdf::EDataSlot::Date:
- return std::make_unique<TAvgBlockAggregator<ui16, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], ctx);
+ return std::make_unique<TPreparedAvgBlockAggregator<ui16, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], builderDataType);
case NUdf::EDataSlot::Int32:
- return std::make_unique<TAvgBlockAggregator<i32, arrow::Int32Scalar>>(filterColumn, argsColumns[0], ctx);
+ return std::make_unique<TPreparedAvgBlockAggregator<i32, arrow::Int32Scalar>>(filterColumn, argsColumns[0], builderDataType);
case NUdf::EDataSlot::Uint32:
- case NUdf::EDataSlot::Datetime:
- return std::make_unique<TAvgBlockAggregator<ui32, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], ctx);
+ return std::make_unique<TPreparedAvgBlockAggregator<ui32, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], builderDataType);
case NUdf::EDataSlot::Int64:
- case NUdf::EDataSlot::Interval:
- return std::make_unique<TAvgBlockAggregator<i64, arrow::Int64Scalar>>(filterColumn, argsColumns[0], ctx);
+ return std::make_unique<TPreparedAvgBlockAggregator<i64, arrow::Int64Scalar>>(filterColumn, argsColumns[0], builderDataType);
case NUdf::EDataSlot::Uint64:
- case NUdf::EDataSlot::Timestamp:
- return std::make_unique<TAvgBlockAggregator<ui64, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], ctx);
+ return std::make_unique<TPreparedAvgBlockAggregator<ui64, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], builderDataType);
default:
throw yexception() << "Unsupported AVG input type";
}
@@ -528,6 +638,6 @@ std::unique_ptr<IBlockAggregatorFactory> MakeBlockSumFactory() {
std::unique_ptr<IBlockAggregatorFactory> MakeBlockAvgFactory() {
return std::make_unique<TBlockAvgFactory>();
}
-
+
}
}