summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author vvvv <[email protected]> 2023-05-26 23:40:06 +0300
committer vvvv <[email protected]> 2023-05-26 23:40:06 +0300
commit 92529fe7ede3280622c6a21f3f50b303bd0ca765 (patch)
tree dd7ccd00f3c6fda5bbd5ea31c83a1b0409b901e9
parent 4eb34319ae07ef0fd4a7e7a9f3bd07c15dbb724d (diff)
Support of filter column in PG combine all aggregators
-rw-r--r-- ydb/library/yql/parser/pg_wrapper/arrow.h 44
1 file changed, 34 insertions, 10 deletions
diff --git a/ydb/library/yql/parser/pg_wrapper/arrow.h b/ydb/library/yql/parser/pg_wrapper/arrow.h
index 00472433a4f..1ef5710cf22 100644
--- a/ydb/library/yql/parser/pg_wrapper/arrow.h
+++ b/ydb/library/yql/parser/pg_wrapper/arrow.h
@@ -458,20 +458,24 @@ public:
{}
private:
+ template <bool HasFilter>
class TCombineAllAggregator : public NKikimr::NMiniKQL::TCombineAllTag::TBase {
public:
using TBase = NKikimr::NMiniKQL::TCombineAllTag::TBase;
TCombineAllAggregator(TTransFunc transFunc, TSerializeFunc serializeFunc, const std::vector<ui32>& argsColumns,
- const NPg::TAggregateDesc& aggDesc, NKikimr::NMiniKQL::TComputationContext& ctx)
+ std::optional<ui32> filterColumn, const NPg::TAggregateDesc& aggDesc, NKikimr::NMiniKQL::TComputationContext& ctx)
: TBase(sizeof(NullableDatum), std::optional<ui32>(), ctx)
, TransFunc(transFunc)
, SerializeFunc(serializeFunc)
, ArgsColumns(argsColumns)
+ , FilterColumn(filterColumn)
, AggDesc(aggDesc)
{
if (!HasInitValue && IsTransStrict) {
Y_ENSURE(AggDesc.ArgTypes.size() == 1);
}
+
+ Y_ENSURE(HasFilter == FilterColumn.has_value());
const auto& transDesc = NPg::LookupProc(AggDesc.TransFuncId);
for (ui32 i = 1; i < transDesc.ArgTypes.size(); ++i) {
@@ -558,25 +562,33 @@ private:
}
}
+ const ui8* filterBitmap = nullptr;
+ if constexpr(HasFilter) {
+ const auto& filterDatum = NKikimr::NMiniKQL::TArrowBlock::From(columns[*FilterColumn]).GetDatum();
+ const auto& filterArray = filterDatum.array();
+ Y_ENSURE(filterArray->GetNullCount() == 0);
+ filterBitmap = filterArray->template GetValues<uint8_t>(1);
+ }
+
WithPgTry(AggDesc.Name, [&]() {
if (hasNulls) {
if (hasScalars) {
- AddManyImpl<true, true>(typedState, values, batchLength);
+ AddManyImpl<true, true>(typedState, values, batchLength, filterBitmap);
} else {
- AddManyImpl<true, false>(typedState, values, batchLength);
+ AddManyImpl<true, false>(typedState, values, batchLength, filterBitmap);
}
} else {
if (hasScalars) {
- AddManyImpl<false, true>(typedState, values, batchLength);
+ AddManyImpl<false, true>(typedState, values, batchLength, filterBitmap);
} else {
- AddManyImpl<false, false>(typedState, values, batchLength);
+ AddManyImpl<false, false>(typedState, values, batchLength, filterBitmap);
}
}
});
}
template <bool HasNulls, bool HasScalars>
- void AddManyImpl(NullableDatum* typedState, const std::vector<arrow::Datum>& values, ui64 batchLength) {
+ void AddManyImpl(NullableDatum* typedState, const std::vector<arrow::Datum>& values, ui64 batchLength, const ui8* filterBitmap) {
LOCAL_FCINFO(transCallInfo, FUNC_MAX_ARGS);
transCallInfo->flinfo = &TransFuncInfo;
transCallInfo->nargs = 1;
@@ -589,6 +601,12 @@ private:
inputArgsAccessor.Bind(values, 1);
for (ui64 i = 0; i < batchLength; ++i) {
+ if constexpr (HasFilter) {
+ if (!filterBitmap[i]) {
+ continue;
+ }
+ }
+
Datum ret;
if constexpr (!TTransArgsPolicy::VarArgs) {
if (!constexpr_for_tuple([&](auto const& j, auto const& v) {
@@ -710,6 +728,7 @@ SkipCall:;
const TTransFunc TransFunc;
const TSerializeFunc SerializeFunc;
const std::vector<ui32> ArgsColumns;
+ const std::optional<ui32> FilterColumn;
const NPg::TAggregateDesc& AggDesc;
std::vector<bool> IsFixedArg;
bool IsTransTypeCString;
@@ -722,22 +741,28 @@ SkipCall:;
class TPreparedCombineAllAggregator : public NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorCombineAll>{
public:
TPreparedCombineAllAggregator(TTransFunc transFunc, TSerializeFunc serializeFunc, const std::vector<ui32>& argsColumns,
- const NPg::TAggregateDesc& aggDesc)
+ std::optional<ui32> filterColumn, const NPg::TAggregateDesc& aggDesc)
: IPreparedBlockAggregator(sizeof(NullableDatum))
, TransFunc(transFunc)
, SerializeFunc(serializeFunc)
, ArgsColumns(argsColumns)
+ , FilterColumn(filterColumn)
, AggDesc(aggDesc)
{}
private:
std::unique_ptr<NKikimr::NMiniKQL::IBlockAggregatorCombineAll> Make(NKikimr::NMiniKQL::TComputationContext& ctx) const {
- return std::make_unique<TCombineAllAggregator>(TransFunc, SerializeFunc, ArgsColumns, AggDesc, ctx);
+ if (FilterColumn.has_value()) {
+ return std::make_unique<TCombineAllAggregator<true>>(TransFunc, SerializeFunc, ArgsColumns, FilterColumn, AggDesc, ctx);
+ } else {
+ return std::make_unique<TCombineAllAggregator<false>>(TransFunc, SerializeFunc, ArgsColumns, FilterColumn, AggDesc, ctx);
+ }
}
const TTransFunc TransFunc;
const TSerializeFunc SerializeFunc;
const std::vector<ui32> ArgsColumns;
+ const std::optional<ui32> FilterColumn;
const NPg::TAggregateDesc& AggDesc;
};
@@ -746,8 +771,7 @@ public:
std::optional<ui32> filterColumn,
const std::vector<ui32>& argsColumns,
const NPg::TAggregateDesc& aggDesc) const {
- Y_ENSURE(!filterColumn); // TODO
- return std::make_unique<TPreparedCombineAllAggregator>(TransFunc, SerializeFunc, argsColumns, aggDesc);
+ return std::make_unique<TPreparedCombineAllAggregator>(TransFunc, SerializeFunc, argsColumns, filterColumn, aggDesc);
}
std::unique_ptr<NKikimr::NMiniKQL::IPreparedBlockAggregator<NKikimr::NMiniKQL::IBlockAggregatorCombineKeys>> PrepareCombineKeys(