aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-11-10 20:11:10 +0300
committervvvv <vvvv@ydb.tech>2022-11-10 20:11:10 +0300
commit27861ff30998a509903c4213f30192757f483dc1 (patch)
treeaadff4d208b21aa03579c4a1257d54732279c4fc
parentcfb04309ef3039a207f5b377e65c47e791977699 (diff)
downloadydb-27861ff30998a509903c4213f30192757f483dc1.tar.gz
avg
-rw-r--r--ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp11
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp38
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp8
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp6
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h8
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp126
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.h1
7 files changed, 169 insertions, 29 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
index 30cbbc2a3b..9b3b90034c 100644
--- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
+++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
@@ -4346,10 +4346,19 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext&
}
auto multiInputType = node->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>();
+ TVector<const TTypeAnnotationNode*> allInputTypes;
for (const auto& i : multiInputType->GetItems()) {
- if (i->GetKind() == ETypeAnnotationKind::Block) {
+ if (i->GetKind() == ETypeAnnotationKind::Block || i->GetKind() == ETypeAnnotationKind::Scalar) {
return node;
}
+
+ allInputTypes.push_back(i);
+ }
+
+ bool supportedInputTypes = false;
+ YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Pos()), allInputTypes, supportedInputTypes, ctx));
+ if (!supportedInputTypes) {
+ return node;
}
TExprNode::TListType blockArgs;
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
index 4b01cc8351..421b80dbcb 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
@@ -11,18 +11,26 @@ namespace NMiniKQL {
namespace {
+struct TAggParams {
+ TStringBuf Name;
+ TTupleType* TupleType;
+ std::vector<ui32> ArgColumns;
+};
+
class TBlockCombineAllWrapper : public TStatefulWideFlowComputationNode<TBlockCombineAllWrapper> {
public:
TBlockCombineAllWrapper(TComputationMutables& mutables,
IComputationWideFlowNode* flow,
ui32 countColumn,
+ std::optional<ui32> filterColumn,
size_t width,
- TVector<std::unique_ptr<IBlockAggregator>>&& aggs)
+ TVector<TAggParams>&& aggsParams)
: TStatefulWideFlowComputationNode(mutables, flow, EValueRepresentation::Any)
, Flow_(flow)
, CountColumn_(countColumn)
+ , FilterColumn_(filterColumn)
, Width_(width)
- , Aggs_(std::move(aggs))
+ , AggsParams_(std::move(aggsParams))
{
}
@@ -46,9 +54,9 @@ public:
}
s.HasValues_ = true;
- for (size_t i = 0; i < Aggs_.size(); ++i) {
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
if (output[i]) {
- Aggs_[i]->AddMany(s.Values_.data(), batchLength);
+ s.Aggs_[i]->AddMany(s.Values_.data(), batchLength);
}
}
} else {
@@ -57,9 +65,9 @@ public:
return EFetchResult::Finish;
}
- for (size_t i = 0; i < Aggs_.size(); ++i) {
+ for (size_t i = 0; i < s.Aggs_.size(); ++i) {
if (auto* out = output[i]; out != nullptr) {
- *out = Aggs_[i]->Finish();
+ *out = s.Aggs_[i]->Finish();
}
}
@@ -74,10 +82,11 @@ private:
struct TState : public TComputationValue<TState> {
TVector<NUdf::TUnboxedValue> Values_;
TVector<NUdf::TUnboxedValue*> ValuePointers_;
+ TVector<std::unique_ptr<IBlockAggregator>> Aggs_;
bool IsFinished_ = false;
bool HasValues_ = false;
- TState(TMemoryUsageInfo* memInfo, size_t width)
+ TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const TVector<TAggParams>& params, const THolderFactory& holderFactory)
: TComputationValue(memInfo)
, Values_(width)
, ValuePointers_(width)
@@ -85,6 +94,10 @@ private:
for (size_t i = 0; i < width; ++i) {
ValuePointers_[i] = &Values_[i];
}
+
+ for (const auto& p : params) {
+ Aggs_.emplace_back(MakeBlockAggregator(p.Name, p.TupleType, filterColumn, p.ArgColumns, holderFactory));
+ }
}
};
@@ -95,7 +108,7 @@ private:
TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
if (!state.HasValue()) {
- state = ctx.HolderFactory.Create<TState>(Width_);
+ state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx.HolderFactory);
}
return *static_cast<TState*>(state.AsBoxed().Get());
}
@@ -107,8 +120,9 @@ private:
private:
IComputationWideFlowNode* Flow_;
const ui32 CountColumn_;
+ std::optional<ui32> FilterColumn_;
const size_t Width_;
- TVector<std::unique_ptr<IBlockAggregator>> Aggs_;
+ const TVector<TAggParams> AggsParams_;
};
}
@@ -130,7 +144,7 @@ IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNod
MKQL_ENSURE(!filterColumn, "Filter column is not supported yet");
auto aggsVal = AS_VALUE(TTupleLiteral, callable.GetInput(3));
- TVector<std::unique_ptr<IBlockAggregator>> aggs;
+ TVector<TAggParams> aggsParams;
for (ui32 i = 0; i < aggsVal->GetValuesCount(); ++i) {
auto aggVal = AS_VALUE(TTupleLiteral, aggsVal->GetValue(i));
auto name = AS_VALUE(TDataLiteral, aggVal->GetValue(0))->AsValue().AsStringRef();
@@ -140,10 +154,10 @@ IComputationNode* WrapBlockCombineAll(TCallable& callable, const TComputationNod
argColumns.push_back(AS_VALUE(TDataLiteral, aggVal->GetValue(j))->AsValue().Get<ui32>());
}
- aggs.emplace_back(MakeBlockAggregator(name, tupleType, filterColumn, argColumns));
+ aggsParams.emplace_back(TAggParams{ TStringBuf(name), tupleType, argColumns });
}
- return new TBlockCombineAllWrapper(ctx.Mutables, wideFlow, countColumn, tupleType->GetElementsCount(), std::move(aggs));
+ return new TBlockCombineAllWrapper(ctx.Mutables, wideFlow, countColumn, filterColumn, tupleType->GetElementsCount(), std::move(aggsParams));
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
index 765f9928f2..169e12a655 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
@@ -57,9 +57,11 @@ public:
std::unique_ptr<IBlockAggregator> Make(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
- const std::vector<ui32>& argsColumns) const final {
+ const std::vector<ui32>& argsColumns,
+ const THolderFactory& holderFactory) const final {
Y_UNUSED(tupleType);
Y_UNUSED(argsColumns);
+ Y_UNUSED(holderFactory);
return std::make_unique<TCountAllBlockAggregator>(filterColumn);
}
};
@@ -69,8 +71,10 @@ public:
std::unique_ptr<IBlockAggregator> Make(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
- const std::vector<ui32>& argsColumns) const final {
+ const std::vector<ui32>& argsColumns,
+ const THolderFactory& holderFactory) const final {
Y_UNUSED(tupleType);
+ Y_UNUSED(holderFactory);
return std::make_unique<TCountBlockAggregator>(filterColumn, argsColumns[0]);
}
};
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp
index 732fd8d984..6dc31dc518 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp
@@ -13,6 +13,7 @@ struct TAggregatorFactories {
Factories["count_all"] = MakeBlockCountAllFactory();
Factories["count"] = MakeBlockCountFactory();
Factories["sum"] = MakeBlockSumFactory();
+ Factories["avg"] = MakeBlockAvgFactory();
}
};
@@ -20,14 +21,15 @@ std::unique_ptr<IBlockAggregator> MakeBlockAggregator(
TStringBuf name,
TTupleType* tupleType,
std::optional<ui32> filterColumn,
- const std::vector<ui32>& argsColumns) {
+ const std::vector<ui32>& argsColumns,
+ const THolderFactory& holderFactory) {
const auto& f = Singleton<TAggregatorFactories>()->Factories;
auto it = f.find(name);
if (it == f.end()) {
throw yexception() << "Unsupported block aggregation function: " << name;
}
- return it->second->Make(tupleType, filterColumn, argsColumns);
+ return it->second->Make(tupleType, filterColumn, argsColumns, holderFactory);
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
index 21a4be24cd..a60b243e30 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.h
@@ -26,11 +26,14 @@ protected:
const std::optional<ui32> FilterColumn_;
};
+class THolderFactory;
+
std::unique_ptr<IBlockAggregator> MakeBlockAggregator(
TStringBuf name,
TTupleType* tupleType,
std::optional<ui32> filterColumn,
- const std::vector<ui32>& argsColumns);
+ const std::vector<ui32>& argsColumns,
+ const THolderFactory& holderFactory);
class IBlockAggregatorFactory {
public:
@@ -39,7 +42,8 @@ public:
virtual std::unique_ptr<IBlockAggregator> Make(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
- const std::vector<ui32>& argsColumns) const = 0;
+ const std::vector<ui32>& argsColumns,
+ const THolderFactory& holderFactory) const = 0;
};
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
index 4ab1fa71e0..b20fce0c5c 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
@@ -3,12 +3,14 @@
#include <ydb/library/yql/minikql/mkql_node_builder.h>
#include <ydb/library/yql/minikql/mkql_node_cast.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
+
#include <arrow/scalar.h>
namespace NKikimr {
namespace NMiniKQL {
-template <typename TIn, typename TState, typename TInScalar>
+template <typename TIn, typename TSum, typename TInScalar>
class TSumBlockAggregator : public TBlockAggregatorBase {
public:
TSumBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn)
@@ -21,7 +23,67 @@ public:
const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
if (datum.is_scalar()) {
if (datum.scalar()->is_valid) {
- State_ += batchLength * datum.scalar_as<TInScalar>().value;
+ Sum_ += batchLength * datum.scalar_as<TInScalar>().value;
+ Count_ += batchLength;
+ }
+ } else {
+ const auto& array = datum.array();
+ auto ptr = array->GetValues<TIn>(1);
+ auto len = array->length;
+ auto count = len - array->GetNullCount();
+ if (!count) {
+ return;
+ }
+
+ Count_ += count;
+ TSum sum = Sum_;
+ if (array->GetNullCount() == 0) {
+ for (int64_t i = 0; i < len; ++i) {
+ sum += ptr[i];
+ }
+ } else {
+ auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
+ for (int64_t i = 0; i < len; ++i) {
+ ui64 fullIndex = i + array->offset;
+ // bit 1 -> mask 0xFF..FF, bit 0 -> mask 0x00..00
+ TIn mask = (((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1) ^ 1) - TIn(1);
+ sum += (ptr[i] & mask);
+ }
+ }
+
+ Sum_ = sum;
+ }
+ }
+
+ NUdf::TUnboxedValue Finish() final {
+ if (!Count_) {
+ return NUdf::TUnboxedValuePod();
+ }
+
+ return NUdf::TUnboxedValuePod(Sum_);
+ }
+
+private:
+ const ui32 ArgColumn_;
+ TSum Sum_ = 0;
+ ui64 Count_ = 0;
+};
+
+template <typename TIn, typename TInScalar>
+class TAvgBlockAggregator : public TBlockAggregatorBase {
+public:
+ TAvgBlockAggregator(std::optional<ui32> filterColumn, ui32 argColumn, const THolderFactory& holderFactory)
+ : TBlockAggregatorBase(filterColumn)
+ , ArgColumn_(argColumn)
+ , HolderFactory_(holderFactory)
+ {
+ }
+
+ void AddMany(const NUdf::TUnboxedValue* columns, ui64 batchLength) final {
+ const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
+ if (datum.is_scalar()) {
+ if (datum.scalar()->is_valid) {
+ Sum_ += double(batchLength * datum.scalar_as<TInScalar>().value);
Count_ += batchLength;
}
} else {
@@ -34,22 +96,22 @@ public:
}
Count_ += count;
- TState state = State_;
+ double sum = Sum_;
if (array->GetNullCount() == 0) {
for (int64_t i = 0; i < len; ++i) {
- state += ptr[i];
+ sum += double(ptr[i]);
}
} else {
auto nullBitmapPtr = array->GetValues<uint8_t>(0, 0);
for (int64_t i = 0; i < len; ++i) {
ui64 fullIndex = i + array->offset;
// bit 1 -> mask 0xFF..FF, bit 0 -> mask 0x00..00
- TState mask = (((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1) ^ 1) - TState(1);
- state += ptr[i] & mask;
+ TIn mask = (((nullBitmapPtr[fullIndex >> 3] >> (fullIndex & 0x07)) & 1) ^ 1) - TIn(1);
+ sum += double(ptr[i] & mask);
}
}
- State_ = state;
+ Sum_ = sum;
}
}
@@ -58,12 +120,17 @@ public:
return NUdf::TUnboxedValuePod();
}
- return NUdf::TUnboxedValuePod(State_);
+ NUdf::TUnboxedValue* items;
+ auto arr = HolderFactory_.CreateDirectArrayHolder(2, items);
+ items[0] = NUdf::TUnboxedValuePod(Sum_);
+ items[1] = NUdf::TUnboxedValuePod(Count_);
+ return arr;
}
private:
const ui32 ArgColumn_;
- TState State_ = 0;
+ const THolderFactory& HolderFactory_;
+ double Sum_ = 0;
ui64 Count_ = 0;
};
@@ -72,7 +139,9 @@ public:
std::unique_ptr<IBlockAggregator> Make(
TTupleType* tupleType,
std::optional<ui32> filterColumn,
- const std::vector<ui32>& argsColumns) const final {
+ const std::vector<ui32>& argsColumns,
+ const THolderFactory& holderFactory) const final {
+ Y_UNUSED(holderFactory);
auto argType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]))->GetItemType();
bool isOptional;
auto dataType = UnpackOptionalData(argType, isOptional);
@@ -99,9 +168,46 @@ public:
}
};
+class TBlockAvgFactory : public IBlockAggregatorFactory {
+public:
+ std::unique_ptr<IBlockAggregator> Make(
+ TTupleType* tupleType,
+ std::optional<ui32> filterColumn,
+ const std::vector<ui32>& argsColumns,
+ const THolderFactory& holderFactory) const final {
+ auto argType = AS_TYPE(TBlockType, tupleType->GetElementType(argsColumns[0]))->GetItemType();
+ bool isOptional;
+ auto dataType = UnpackOptionalData(argType, isOptional);
+ switch (*dataType->GetDataSlot()) {
+ case NUdf::EDataSlot::Int8:
+ return std::make_unique<TAvgBlockAggregator<i8, arrow::Int8Scalar>>(filterColumn, argsColumns[0], holderFactory);
+ case NUdf::EDataSlot::Uint8:
+ return std::make_unique<TAvgBlockAggregator<ui8, arrow::UInt8Scalar>>(filterColumn, argsColumns[0], holderFactory);
+ case NUdf::EDataSlot::Int16:
+ return std::make_unique<TAvgBlockAggregator<i16, arrow::Int16Scalar>>(filterColumn, argsColumns[0], holderFactory);
+ case NUdf::EDataSlot::Uint16:
+ return std::make_unique<TAvgBlockAggregator<ui16, arrow::UInt16Scalar>>(filterColumn, argsColumns[0], holderFactory);
+ case NUdf::EDataSlot::Int32:
+ return std::make_unique<TAvgBlockAggregator<i32, arrow::Int32Scalar>>(filterColumn, argsColumns[0], holderFactory);
+ case NUdf::EDataSlot::Uint32:
+ return std::make_unique<TAvgBlockAggregator<ui32, arrow::UInt32Scalar>>(filterColumn, argsColumns[0], holderFactory);
+ case NUdf::EDataSlot::Int64:
+ return std::make_unique<TAvgBlockAggregator<i64, arrow::Int64Scalar>>(filterColumn, argsColumns[0], holderFactory);
+ case NUdf::EDataSlot::Uint64:
+ return std::make_unique<TAvgBlockAggregator<ui64, arrow::UInt64Scalar>>(filterColumn, argsColumns[0], holderFactory);
+ default:
+ throw yexception() << "Unsupported AVG input type";
+ }
+ }
+};
+
std::unique_ptr<IBlockAggregatorFactory> MakeBlockSumFactory() {
return std::make_unique<TBlockSumFactory>();
}
+
+std::unique_ptr<IBlockAggregatorFactory> MakeBlockAvgFactory() {
+ return std::make_unique<TBlockAvgFactory>();
+}
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.h
index 14ac5d8c13..403b21f560 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.h
@@ -5,6 +5,7 @@ namespace NKikimr {
namespace NMiniKQL {
std::unique_ptr<IBlockAggregatorFactory> MakeBlockSumFactory();
+std::unique_ptr<IBlockAggregatorFactory> MakeBlockAvgFactory();
}
}