diff options
author | vvvv <[email protected]> | 2023-05-26 23:40:06 +0300 |
---|---|---|
committer | vvvv <[email protected]> | 2023-05-26 23:40:06 +0300 |
commit | 92529fe7ede3280622c6a21f3f50b303bd0ca765 (patch) | |
tree | dd7ccd00f3c6fda5bbd5ea31c83a1b0409b901e9 | |
parent | 4eb34319ae07ef0fd4a7e7a9f3bd07c15dbb724d (diff) |
Support of filter column in PG combine all aggregators
-rw-r--r-- | ydb/library/yql/parser/pg_wrapper/arrow.h | 44 |
1 file changed, 34 insertions, 10 deletions
diff --git a/ydb/library/yql/parser/pg_wrapper/arrow.h b/ydb/library/yql/parser/pg_wrapper/arrow.h index 00472433a4f..1ef5710cf22 100644 --- a/ydb/library/yql/parser/pg_wrapper/arrow.h +++ b/ydb/library/yql/parser/pg_wrapper/arrow.h @@ -458,20 +458,24 @@ public: {} private: + template <bool HasFilter> class TCombineAllAggregator : public NKikimr::NMiniKQL::TCombineAllTag::TBase { public: using TBase = NKikimr::NMiniKQL::TCombineAllTag::TBase; TCombineAllAggregator(TTransFunc transFunc, TSerializeFunc serializeFunc, const std::vector<ui32>& argsColumns, - const NPg::TAggregateDesc& aggDesc, NKikimr::NMiniKQL::TComputationContext& ctx) + std::optional<ui32> filterColumn, const NPg::TAggregateDesc& aggDesc, NKikimr::NMiniKQL::TComputationContext& ctx) : TBase(sizeof(NullableDatum), std::optional<ui32>(), ctx) , TransFunc(transFunc) , SerializeFunc(serializeFunc) , ArgsColumns(argsColumns) + , FilterColumn(filterColumn) , AggDesc(aggDesc) { if (!HasInitValue && IsTransStrict) { Y_ENSURE(AggDesc.ArgTypes.size() == 1); } + + Y_ENSURE(HasFilter == FilterColumn.has_value()); const auto& transDesc = NPg::LookupProc(AggDesc.TransFuncId); for (ui32 i = 1; i < transDesc.ArgTypes.size(); ++i) { @@ -558,25 +562,33 @@ private: } } + const ui8* filterBitmap = nullptr; + if constexpr(HasFilter) { + const auto& filterDatum = NKikimr::NMiniKQL::TArrowBlock::From(columns[*FilterColumn]).GetDatum(); + const auto& filterArray = filterDatum.array(); + Y_ENSURE(filterArray->GetNullCount() == 0); + filterBitmap = filterArray->template GetValues<uint8_t>(1); + } + WithPgTry(AggDesc.Name, [&]() { if (hasNulls) { if (hasScalars) { - AddManyImpl<true, true>(typedState, values, batchLength); + AddManyImpl<true, true>(typedState, values, batchLength, filterBitmap); } else { - AddManyImpl<true, false>(typedState, values, batchLength); + AddManyImpl<true, false>(typedState, values, batchLength, filterBitmap); } } else { if (hasScalars) { - AddManyImpl<false, true>(typedState, values, batchLength); 
+ AddManyImpl<false, true>(typedState, values, batchLength, filterBitmap); } else { - AddManyImpl<false, false>(typedState, values, batchLength); + AddManyImpl<false, false>(typedState, values, batchLength, filterBitmap); } } }); } template <bool HasNulls, bool HasScalars> - void AddManyImpl(NullableDatum* typedState, const std::vector<arrow::Datum>& values, ui64 batchLength) { + void AddManyImpl(NullableDatum* typedState, const std::vector<arrow::Datum>& values, ui64 batchLength, const ui8* filterBitmap) { LOCAL_FCINFO(transCallInfo, FUNC_MAX_ARGS); transCallInfo->flinfo = &TransFuncInfo; transCallInfo->nargs = 1; @@ -589,6 +601,12 @@ private: inputArgsAccessor.Bind(values, 1); for (ui64 i = 0; i < batchLength; ++i) { + if constexpr (HasFilter) { + if (!filterBitmap[i]) { + continue; + } + } + Datum ret; if constexpr (!TTransArgsPolicy::VarArgs) { if (!constexpr_for_tuple([&](auto const& j, auto const& v) { @@ -710,6 +728,7 @@ SkipCall:; const TTransFunc TransFunc; const TSerializeFunc SerializeFunc; const std::vector<ui32> ArgsColumns; + const std::optional<ui32> FilterColumn; const NPg::TAggregateDesc& AggDesc; std::vector<bool> IsFixedArg; bool IsTransTypeCString; @@ -722,22 +741,28 @@ SkipCall:; class TPreparedCombineAllAggregator : public NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorCombineAll>{ public: TPreparedCombineAllAggregator(TTransFunc transFunc, TSerializeFunc serializeFunc, const std::vector<ui32>& argsColumns, - const NPg::TAggregateDesc& aggDesc) + std::optional<ui32> filterColumn, const NPg::TAggregateDesc& aggDesc) : IPreparedBlockAggregator(sizeof(NullableDatum)) , TransFunc(transFunc) , SerializeFunc(serializeFunc) , ArgsColumns(argsColumns) + , FilterColumn(filterColumn) , AggDesc(aggDesc) {} private: std::unique_ptr<NKikimr::NMiniKQL::IBlockAggregatorCombineAll> Make(NKikimr::NMiniKQL::TComputationContext& ctx) const { - return std::make_unique<TCombineAllAggregator>(TransFunc, SerializeFunc, ArgsColumns, 
AggDesc, ctx); + if (FilterColumn.has_value()) { + return std::make_unique<TCombineAllAggregator<true>>(TransFunc, SerializeFunc, ArgsColumns, FilterColumn, AggDesc, ctx); + } else { + return std::make_unique<TCombineAllAggregator<false>>(TransFunc, SerializeFunc, ArgsColumns, FilterColumn, AggDesc, ctx); + } } const TTransFunc TransFunc; const TSerializeFunc SerializeFunc; const std::vector<ui32> ArgsColumns; + const std::optional<ui32> FilterColumn; const NPg::TAggregateDesc& AggDesc; }; @@ -746,8 +771,7 @@ public: std::optional<ui32> filterColumn, const std::vector<ui32>& argsColumns, const NPg::TAggregateDesc& aggDesc) const { - Y_ENSURE(!filterColumn); // TODO - return std::make_unique<TPreparedCombineAllAggregator>(TransFunc, SerializeFunc, argsColumns, aggDesc); + return std::make_unique<TPreparedCombineAllAggregator>(TransFunc, SerializeFunc, argsColumns, filterColumn, aggDesc); } std::unique_ptr<NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorCombineKeys>> PrepareCombineKeys( |