diff options
author | vvvv <[email protected]> | 2024-12-19 09:31:27 +0300 |
---|---|---|
committer | vvvv <[email protected]> | 2024-12-19 10:03:05 +0300 |
commit | e9866bdcbc5f04545c60374b690b1e04005e091b (patch) | |
tree | 20f48408e6b219c967b482e0d48e54fe19e9142b | |
parent | 2ecfb874d70a7c6b6261f7581508c176f389d59c (diff) |
Reduce bloat in ToDict
before 2936M
after 2897M
commit_hash:ce4eacbffea54983891bbb1be7675efcdc84ceee
-rw-r--r-- | yql/essentials/minikql/comp_nodes/mkql_todict.cpp | 397 |
1 files changed, 242 insertions, 155 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_todict.cpp b/yql/essentials/minikql/comp_nodes/mkql_todict.cpp index bd396695692..af4d937855e 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_todict.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_todict.cpp @@ -24,7 +24,66 @@ using NYql::EnsureDynamicCast; namespace { -class THashedMultiMapAccumulator { +class ISetAccumulator { +public: + virtual ~ISetAccumulator() = default; + virtual void Add(NUdf::TUnboxedValue&& key) = 0; + virtual NUdf::TUnboxedValue Build() = 0; +}; + +class ISetAccumulatorFactory { +public: + virtual ~ISetAccumulatorFactory() = default; + virtual bool IsSorted() const = 0; + virtual std::unique_ptr<ISetAccumulator> Create(TType* keyType, const TKeyTypes& keyTypes, bool isTuple, bool encoded, + const NUdf::ICompare* compare, const NUdf::IEquate* equate, const NUdf::IHash* hash, TComputationContext& ctx, + ui64 itemsCountHint) const = 0; +}; + +class IMapAccumulator { +public: + virtual ~IMapAccumulator() = default; + virtual void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) = 0; + virtual NUdf::TUnboxedValue Build() = 0; +}; + +class IMapAccumulatorFactory { +public: + virtual ~IMapAccumulatorFactory() = default; + virtual bool IsSorted() const = 0; + virtual std::unique_ptr<IMapAccumulator> Create(TType* keyType, TType* payloadType, const TKeyTypes& keyTypes, bool isTuple, bool encoded, + const NUdf::ICompare* compare, const NUdf::IEquate* equate, const NUdf::IHash* hash, TComputationContext& ctx, ui64 itemsCountHint) const = 0; +}; + +template <typename T> +class TSetAccumulatorFactory : public ISetAccumulatorFactory { +public: + bool IsSorted() const final { + return T::IsSorted; + } + + std::unique_ptr<ISetAccumulator> Create(TType* keyType, const TKeyTypes& keyTypes, bool isTuple, bool encoded, + const NUdf::ICompare* compare, const NUdf::IEquate* equate, const NUdf::IHash* hash, TComputationContext& ctx, + ui64 itemsCountHint) const { + return std::make_unique<T>(keyType, keyTypes, isTuple, encoded, compare, equate, hash, ctx, itemsCountHint); + } +}; + +template <typename T> +class TMapAccumulatorFactory : public IMapAccumulatorFactory { +public: + bool IsSorted() const final { + return T::IsSorted; + } + + std::unique_ptr<IMapAccumulator> Create(TType* keyType, TType* payloadType, const TKeyTypes& keyTypes, bool isTuple, bool encoded, + const NUdf::ICompare* compare, const NUdf::IEquate* equate, const NUdf::IHash* hash, TComputationContext& ctx, + ui64 itemsCountHint) const { + return std::make_unique<T>(keyType, payloadType, keyTypes, isTuple, encoded, compare, equate, hash, ctx, itemsCountHint); + } +}; + +class THashedMultiMapAccumulator : public IMapAccumulator { using TMapType = TValuesDictHashMap; TComputationContext& Ctx; @@ -54,7 +113,7 @@ public: Map.reserve(itemsCountHint); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { if (Packer) { key = MakeString(Packer->Pack(key)); @@ -67,7 +126,7 @@ public: it->second.Push(std::move(payload)); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { const auto filler = [this](TValuesDictHashMap& targetMap) { targetMap = std::move(Map); @@ -77,7 +136,7 @@ public: } }; -class THashedMapAccumulator { +class THashedMapAccumulator : public IMapAccumulator { using TMapType = TValuesDictHashMap; TComputationContext& Ctx; @@ -107,7 +166,7 @@ public: Map.reserve(itemsCountHint); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { if (Packer) { key = MakeString(Packer->Pack(key)); @@ -116,7 +175,7 @@ public: Map.emplace(std::move(key), std::move(payload)); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { const auto filler = [this](TMapType& targetMap) { targetMap = std::move(Map); @@ -127,7 +186,7 @@ public: }; template<typename T, bool OptionalKey> -class THashedSingleFixedMultiMapAccumulator { +class THashedSingleFixedMultiMapAccumulator : public IMapAccumulator { using TMapType = TValuesDictHashSingleFixedMap<T>; TComputationContext& Ctx; @@ -152,7 +211,7 @@ public: CurrentEmptyVectorForInsert = Ctx.HolderFactory.NewVectorHolder(); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) { + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { if constexpr (OptionalKey) { if (!key) { NullPayloads.emplace_back(std::move(payload)); @@ -166,7 +225,7 @@ public: insertInfo.first->second.Push(payload.Release()); } - NUdf::TUnboxedValue Build() { + NUdf::TUnboxedValue Build() final { std::optional<NUdf::TUnboxedValue> nullPayload; if (NullPayloads.size()) { nullPayload = Ctx.HolderFactory.VectorAsVectorHolder(std::move(NullPayloads)); @@ -176,7 +235,7 @@ public: }; template<typename T, bool OptionalKey> -class THashedSingleFixedMapAccumulator { +class THashedSingleFixedMapAccumulator : public IMapAccumulator { using TMapType = TValuesDictHashSingleFixedMap<T>; TComputationContext& Ctx; @@ -201,7 +260,7 @@ public: Map.reserve(itemsCountHint); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { if constexpr (OptionalKey) { if (!key) { @@ -212,13 +271,13 @@ public: Map.emplace(key.Get<T>(), std::move(payload)); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { return Ctx.HolderFactory.CreateDirectHashedSingleFixedMapHolder<T, OptionalKey>(std::move(Map), std::move(NullPayload)); } }; -class THashedSetAccumulator { +class THashedSetAccumulator : public ISetAccumulator { using TSetType = TValuesDictHashSet; TComputationContext& Ctx; @@ -246,7 +305,7 @@ public: Set.reserve(itemsCountHint); } - void Add(NUdf::TUnboxedValue&& key) + void Add(NUdf::TUnboxedValue&& key) final { if (Packer) { key = MakeString(Packer->Pack(key)); @@ -255,7 +314,7 @@ public: Set.emplace(std::move(key)); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { const auto filler = [this](TSetType& targetSet) { targetSet = std::move(Set); @@ -266,7 +325,7 @@ public: }; template <typename T, bool OptionalKey> -class THashedSingleFixedSetAccumulator { +class THashedSingleFixedSetAccumulator : public ISetAccumulator{ using TSetType = TValuesDictHashSingleFixedSet<T>; TComputationContext& Ctx; @@ -290,7 +349,7 @@ public: Set.reserve(itemsCountHint); } - void Add(NUdf::TUnboxedValue&& key) + void Add(NUdf::TUnboxedValue&& key) final { if constexpr (OptionalKey) { if (!key) { @@ -301,14 +360,14 @@ public: Set.emplace(key.Get<T>()); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { return Ctx.HolderFactory.CreateDirectHashedSingleFixedSetHolder<T, OptionalKey>(std::move(Set), HasNull); } }; template <typename T, bool OptionalKey> -class THashedSingleFixedCompactSetAccumulator { +class THashedSingleFixedCompactSetAccumulator : public ISetAccumulator { using TSetType = TValuesDictHashSingleFixedCompactSet<T>; TComputationContext& Ctx; @@ -333,7 +392,7 @@ public: Set.SetMaxLoadFactor(COMPACT_HASH_MAX_LOAD_FACTOR); } - void Add(NUdf::TUnboxedValue&& key) + void Add(NUdf::TUnboxedValue&& key) final { if constexpr (OptionalKey) { if (!key) { @@ -344,13 +403,13 @@ public: Set.Insert(key.Get<T>()); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { return Ctx.HolderFactory.CreateDirectHashedSingleFixedCompactSetHolder<T, OptionalKey>(std::move(Set), HasNull); } }; -class THashedCompactSetAccumulator { +class THashedCompactSetAccumulator : public ISetAccumulator { using TSetType = TValuesDictHashCompactSet; TComputationContext& Ctx; @@ -376,12 +435,12 @@ public: Set.SetMaxLoadFactor(COMPACT_HASH_MAX_LOAD_FACTOR); } - void Add(NUdf::TUnboxedValue&& key) + void Add(NUdf::TUnboxedValue&& key) final { Set.Insert(AddSmallValue(Pool, KeyPacker->Pack(key))); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { return Ctx.HolderFactory.CreateDirectHashedCompactSetHolder(std::move(Set), std::move(Pool), KeyType, &Ctx); } @@ -391,7 +450,7 @@ template <bool Multi> class THashedCompactMapAccumulator; template <> -class THashedCompactMapAccumulator<false> { +class THashedCompactMapAccumulator<false> : public IMapAccumulator { using TMapType = TValuesDictHashCompactMap; TComputationContext& Ctx; @@ -419,19 +478,19 @@ public: Map.SetMaxLoadFactor(COMPACT_HASH_MAX_LOAD_FACTOR); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { Map.InsertNew(AddSmallValue(Pool, KeyPacker->Pack(key)), AddSmallValue(Pool, PayloadPacker->Pack(payload))); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { return Ctx.HolderFactory.CreateDirectHashedCompactMapHolder(std::move(Map), std::move(Pool), KeyType, PayloadType, &Ctx); } }; template <> -class THashedCompactMapAccumulator<true> { +class THashedCompactMapAccumulator<true> : public IMapAccumulator { using TMapType = TValuesDictHashCompactMultiMap; TComputationContext& Ctx; @@ -459,12 +518,12 @@ public: Map.SetMaxLoadFactor(COMPACT_HASH_MAX_LOAD_FACTOR); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { Map.Insert(AddSmallValue(Pool, KeyPacker->Pack(key)), AddSmallValue(Pool, PayloadPacker->Pack(payload))); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { return Ctx.HolderFactory.CreateDirectHashedCompactMultiMapHolder(std::move(Map), std::move(Pool), KeyType, PayloadType, &Ctx); } @@ -474,7 +533,7 @@ template <typename T, bool OptionalKey, bool Multi> class THashedSingleFixedCompactMapAccumulator; template <typename T, bool OptionalKey> -class THashedSingleFixedCompactMapAccumulator<T, OptionalKey, false> { +class THashedSingleFixedCompactMapAccumulator<T, OptionalKey, false> : public IMapAccumulator { using TMapType = TValuesDictHashSingleFixedCompactMap<T>; TComputationContext& Ctx; @@ -502,7 +561,7 @@ public: Map.SetMaxLoadFactor(COMPACT_HASH_MAX_LOAD_FACTOR); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { if constexpr (OptionalKey) { if (!key) { @@ -513,14 +572,14 @@ public: Map.InsertNew(key.Get<T>(), AddSmallValue(Pool, PayloadPacker->Pack(payload))); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { return Ctx.HolderFactory.CreateDirectHashedSingleFixedCompactMapHolder<T, OptionalKey>(std::move(Map), std::move(NullPayload), std::move(Pool), PayloadType, &Ctx); } }; template <typename T, bool OptionalKey> -class THashedSingleFixedCompactMapAccumulator<T, OptionalKey, true> { +class THashedSingleFixedCompactMapAccumulator<T, OptionalKey, true> : public IMapAccumulator { using TMapType = TValuesDictHashSingleFixedCompactMultiMap<T>; TComputationContext& Ctx; @@ -548,7 +607,7 @@ public: Map.SetMaxLoadFactor(COMPACT_HASH_MAX_LOAD_FACTOR); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { if constexpr (OptionalKey) { if (!key) { @@ -559,13 +618,13 @@ public: Map.Insert(key.Get<T>(), AddSmallValue(Pool, PayloadPacker->Pack(payload))); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { return Ctx.HolderFactory.CreateDirectHashedSingleFixedCompactMultiMapHolder<T, OptionalKey>(std::move(Map), std::move(NullPayloads), std::move(Pool), PayloadType, &Ctx); } }; -class TSortedSetAccumulator { +class TSortedSetAccumulator : public ISetAccumulator { TComputationContext& Ctx; TType* KeyType; const TKeyTypes& KeyTypes; @@ -591,7 +650,7 @@ public: Items.reserve(itemsCountHint); } - void Add(NUdf::TUnboxedValue&& key) + void Add(NUdf::TUnboxedValue&& key) final { if (Packer) { key = MakeString(Packer->Encode(key, false)); @@ -600,7 +659,7 @@ public: Items.emplace_back(std::move(key)); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { const TSortedSetFiller filler = [this](TUnboxedValueVector& values) { std::stable_sort(Items.begin(), Items.end(), TValueLess(KeyTypes, IsTuple, Compare)); @@ -617,7 +676,7 @@ template<bool IsMulti> class TSortedMapAccumulator; template<> -class TSortedMapAccumulator<false> { +class TSortedMapAccumulator<false> : public IMapAccumulator { TComputationContext& Ctx; TType* KeyType; const TKeyTypes& KeyTypes; @@ -649,7 +708,7 @@ public: Items.reserve(itemsCountHint); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { if (Packer) { key = MakeString(Packer->Encode(key, false)); @@ -658,7 +717,7 @@ public: Items.emplace_back(std::move(key), std::move(payload)); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { const TSortedDictFiller filler = [this](TKeyPayloadPairVector& values) { values = std::move(Items); @@ -670,7 +729,7 @@ public: }; template<> -class TSortedMapAccumulator<true> { +class TSortedMapAccumulator<true> : public IMapAccumulator { TComputationContext& Ctx; TType* KeyType; const TKeyTypes& KeyTypes; @@ -696,7 +755,7 @@ public: Items.reserve(itemsCountHint); } - void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) + void Add(NUdf::TUnboxedValue&& key, NUdf::TUnboxedValue&& payload) final { if (Packer) { key = MakeString(Packer->Encode(key, false)); @@ -705,7 +764,7 @@ public: Items.emplace_back(std::move(key), std::move(payload)); } - NUdf::TUnboxedValue Build() + NUdf::TUnboxedValue Build() final { const TSortedDictFiller filler = [this](TKeyPayloadPairVector& values) { std::stable_sort(Items.begin(), Items.end(), TKeyPayloadPairLess(KeyTypes, IsTuple, Compare)); @@ -740,14 +799,13 @@ public: } }; -template <typename TSetAccumulator, bool IsStream> -class TSetWrapper : public TMutableComputationNode<TSetWrapper<TSetAccumulator, IsStream>> { - typedef TMutableComputationNode<TSetWrapper<TSetAccumulator, IsStream>> TBaseComputation; +class TSetWrapper : public TMutableComputationNode<TSetWrapper> { + typedef TMutableComputationNode<TSetWrapper> TBaseComputation; public: class TStreamValue : public TComputationValue<TStreamValue> { public: TStreamValue(TMemoryUsageInfo* memInfo, NUdf::TUnboxedValue&& input, IComputationExternalNode* const item, - IComputationNode* const key, TSetAccumulator&& setAccum, TComputationContext& ctx) + IComputationNode* const key, std::unique_ptr<ISetAccumulator>&& setAccum, TComputationContext& ctx) : TComputationValue<TStreamValue>(memInfo) , Input(std::move(input)) , Item(item) @@ -766,11 +824,11 @@ public: switch (auto status = Input.Fetch(item)) { case NUdf::EFetchStatus::Ok: { Item->SetValue(Ctx, std::move(item)); - SetAccum.Add(Key->GetValue(Ctx)); + SetAccum->Add(Key->GetValue(Ctx)); break; // and continue } case NUdf::EFetchStatus::Finish: { - result = SetAccum.Build(); + result = SetAccum->Build(); Finished = true; return NUdf::EFetchStatus::Ok; } @@ -784,31 +842,33 @@ public: NUdf::TUnboxedValue Input; IComputationExternalNode* const Item; IComputationNode* const Key; - TSetAccumulator SetAccum; + const std::unique_ptr<ISetAccumulator> SetAccum; TComputationContext& Ctx; bool Finished = false; }; TSetWrapper(TComputationMutables& mutables, TType* keyType, IComputationNode* list, IComputationExternalNode* item, - IComputationNode* key, ui64 itemsCountHint) + IComputationNode* key, ui64 itemsCountHint, bool isStream, std::unique_ptr<ISetAccumulatorFactory> factory) : TBaseComputation(mutables, EValueRepresentation::Boxed) , KeyType(keyType) , List(list) , Item(item) , Key(key) , ItemsCountHint(itemsCountHint) + , IsStream(isStream) + , Factory(std::move(factory)) { GetDictionaryKeyTypes(KeyType, KeyTypes, IsTuple, Encoded, UseIHash); - Compare = UseIHash && TSetAccumulator::IsSorted ? MakeCompareImpl(KeyType) : nullptr; + Compare = UseIHash && Factory->IsSorted() ? MakeCompareImpl(KeyType) : nullptr; Equate = UseIHash ? MakeEquateImpl(KeyType) : nullptr; - Hash = UseIHash && !TSetAccumulator::IsSorted ? MakeHashImpl(KeyType) : nullptr; + Hash = UseIHash && !Factory->IsSorted() ? MakeHashImpl(KeyType) : nullptr; } NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - if constexpr (IsStream) { + if (IsStream) { return ctx.HolderFactory.Create<TStreamValue>(List->GetValue(ctx), Item, Key, - TSetAccumulator(KeyType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), + Factory->Create(KeyType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), ctx, ItemsCountHint), ctx); } @@ -821,17 +881,17 @@ public: return ctx.HolderFactory.GetEmptyContainerLazy(); } - TSetAccumulator accumulator(KeyType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), + auto acc = Factory->Create(KeyType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), ctx, itemsCountHint); TThresher<false>::DoForEachItem(list, - [this, &accumulator, &ctx] (NUdf::TUnboxedValue&& item) { + [this, &acc, &ctx] (NUdf::TUnboxedValue&& item) { Item->SetValue(ctx, std::move(item)); - accumulator.Add(Key->GetValue(ctx)); + acc->Add(Key->GetValue(ctx)); } ); - return accumulator.Build().Release(); + return acc->Build().Release(); } private: @@ -846,6 +906,8 @@ private: IComputationExternalNode* const Item; IComputationNode* const Key; const ui64 ItemsCountHint; + const bool IsStream; + const std::unique_ptr<ISetAccumulatorFactory> Factory; TKeyTypes KeyTypes; bool IsTuple; bool Encoded; @@ -882,42 +944,43 @@ public: }; #endif -template <typename TSetAccumulator> -class TSqueezeSetFlowWrapper : public TStatefulFlowCodegeneratorNode<TSqueezeSetFlowWrapper<TSetAccumulator>> { - using TBase = TStatefulFlowCodegeneratorNode<TSqueezeSetFlowWrapper<TSetAccumulator>>; +class TSqueezeSetFlowWrapper : public TStatefulFlowCodegeneratorNode<TSqueezeSetFlowWrapper> { + using TBase = TStatefulFlowCodegeneratorNode<TSqueezeSetFlowWrapper>; public: class TState : public TComputationValue<TState> { using TBase = TComputationValue<TState>; public: - TState(TMemoryUsageInfo* memInfo, TSetAccumulator&& setAccum) + TState(TMemoryUsageInfo* memInfo, std::unique_ptr<ISetAccumulator>&& setAccum) : TBase(memInfo), SetAccum(std::move(setAccum)) {} NUdf::TUnboxedValuePod Build() { - return SetAccum.Build().Release(); + return SetAccum->Build().Release(); } void Insert(NUdf::TUnboxedValuePod value) { - SetAccum.Add(value); + SetAccum->Add(value); } private: - TSetAccumulator SetAccum; + const std::unique_ptr<ISetAccumulator> SetAccum; }; TSqueezeSetFlowWrapper(TComputationMutables& mutables, TType* keyType, - IComputationNode* flow, IComputationExternalNode* item, IComputationNode* key, ui64 itemsCountHint) + IComputationNode* flow, IComputationExternalNode* item, IComputationNode* key, ui64 itemsCountHint, + std::unique_ptr<ISetAccumulatorFactory> factory) : TBase(mutables, flow, EValueRepresentation::Boxed, EValueRepresentation::Any) , KeyType(keyType) , Flow(flow) , Item(item) , Key(key) , ItemsCountHint(itemsCountHint) + , Factory(std::move(factory)) { GetDictionaryKeyTypes(KeyType, KeyTypes, IsTuple, Encoded, UseIHash); - Compare = UseIHash && TSetAccumulator::IsSorted ? MakeCompareImpl(KeyType) : nullptr; + Compare = UseIHash && Factory->IsSorted() ? MakeCompareImpl(KeyType) : nullptr; Equate = UseIHash ? MakeEquateImpl(KeyType) : nullptr; - Hash = UseIHash && !TSetAccumulator::IsSorted ? MakeHashImpl(KeyType) : nullptr; + Hash = UseIHash && !Factory->IsSorted() ? MakeHashImpl(KeyType) : nullptr; } NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { @@ -963,7 +1026,7 @@ public: const auto ptrType = PointerType::getUnqual(StructType::get(context)); const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); - const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSqueezeSetFlowWrapper<TSetAccumulator>::MakeState)); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSqueezeSetFlowWrapper::MakeState)); const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false); const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block); @@ -1039,7 +1102,7 @@ public: #endif private: void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const { - state = ctx.HolderFactory.Create<TState>(TSetAccumulator(KeyType, KeyTypes, IsTuple, Encoded, + state = ctx.HolderFactory.Create<TState>(Factory->Create(KeyType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), ctx, ItemsCountHint)); } @@ -1055,6 +1118,7 @@ private: IComputationExternalNode* const Item; IComputationNode* const Key; const ui64 ItemsCountHint; + const std::unique_ptr<ISetAccumulatorFactory> Factory; TKeyTypes KeyTypes; bool IsTuple; bool Encoded; @@ -1065,44 +1129,45 @@ private: NUdf::IHash::TPtr Hash; }; -template <typename TSetAccumulator> -class TSqueezeSetWideWrapper : public TStatefulFlowCodegeneratorNode<TSqueezeSetWideWrapper<TSetAccumulator>> { - using TBase = TStatefulFlowCodegeneratorNode<TSqueezeSetWideWrapper<TSetAccumulator>>; +class TSqueezeSetWideWrapper : public TStatefulFlowCodegeneratorNode<TSqueezeSetWideWrapper> { + using TBase = TStatefulFlowCodegeneratorNode<TSqueezeSetWideWrapper>; public: class TState : public TComputationValue<TState> { using TBase = TComputationValue<TState>; public: - TState(TMemoryUsageInfo* memInfo, TSetAccumulator&& setAccum) + TState(TMemoryUsageInfo* memInfo, std::unique_ptr<ISetAccumulator>&& setAccum) : TBase(memInfo), SetAccum(std::move(setAccum)) {} NUdf::TUnboxedValuePod Build() { - return SetAccum.Build().Release(); + return SetAccum->Build().Release(); } void Insert(NUdf::TUnboxedValuePod value) { - SetAccum.Add(value); + SetAccum->Add(value); } private: - TSetAccumulator SetAccum; + const std::unique_ptr<ISetAccumulator> SetAccum; }; TSqueezeSetWideWrapper(TComputationMutables& mutables, TType* keyType, - IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, IComputationNode* key, ui64 itemsCountHint) + IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, IComputationNode* key, + ui64 itemsCountHint, std::unique_ptr<ISetAccumulatorFactory> factory) : TBase(mutables, flow, EValueRepresentation::Boxed, EValueRepresentation::Any) , KeyType(keyType) , Flow(flow) , Items(std::move(items)) , Key(key) , ItemsCountHint(itemsCountHint) + , Factory(std::move(factory)) , PasstroughKey(GetPasstroughtMap(TComputationNodePtrVector{Key}, Items).front()) , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Items.size())) { GetDictionaryKeyTypes(KeyType, KeyTypes, IsTuple, Encoded, UseIHash); - Compare = UseIHash && TSetAccumulator::IsSorted ? MakeCompareImpl(KeyType) : nullptr; + Compare = UseIHash && Factory->IsSorted() ? MakeCompareImpl(KeyType) : nullptr; Equate = UseIHash ? MakeEquateImpl(KeyType) : nullptr; - Hash = UseIHash && !TSetAccumulator::IsSorted ? MakeHashImpl(KeyType) : nullptr; + Hash = UseIHash && !Factory->IsSorted() ? MakeHashImpl(KeyType) : nullptr; } NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { @@ -1152,7 +1217,7 @@ public: const auto ptrType = PointerType::getUnqual(StructType::get(context)); const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); - const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSqueezeSetWideWrapper<TSetAccumulator>::MakeState)); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSqueezeSetWideWrapper::MakeState)); const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false); const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block); @@ -1234,7 +1299,7 @@ public: #endif private: void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const { - state = ctx.HolderFactory.Create<TState>(TSetAccumulator(KeyType, KeyTypes, IsTuple, Encoded, + state = ctx.HolderFactory.Create<TState>(Factory->Create(KeyType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), ctx, ItemsCountHint)); } @@ -1250,6 +1315,7 @@ private: const TComputationExternalNodePtrVector Items; IComputationNode* const Key; const ui64 ItemsCountHint; + const std::unique_ptr<ISetAccumulatorFactory> Factory; TKeyTypes KeyTypes; bool IsTuple; bool Encoded; @@ -1264,14 +1330,13 @@ private: NUdf::IHash::TPtr Hash; }; -template <typename TMapAccumulator, bool IsStream> -class TMapWrapper : public TMutableComputationNode<TMapWrapper<TMapAccumulator, IsStream>> { - typedef TMutableComputationNode<TMapWrapper<TMapAccumulator, IsStream>> TBaseComputation; +class TMapWrapper : public TMutableComputationNode<TMapWrapper> { + typedef TMutableComputationNode<TMapWrapper> TBaseComputation; public: class TStreamValue : public TComputationValue<TStreamValue> { public: TStreamValue(TMemoryUsageInfo* memInfo, NUdf::TUnboxedValue&& input, IComputationExternalNode* const item, - IComputationNode* const key, IComputationNode* const payload, TMapAccumulator&& mapAccum, TComputationContext& ctx) + IComputationNode* const key, IComputationNode* const payload, std::unique_ptr<IMapAccumulator>&& mapAccum, TComputationContext& ctx) : TComputationValue<TStreamValue>(memInfo) , Input(std::move(input)) , Item(item) @@ -1291,11 +1356,11 @@ public: switch (auto status = Input.Fetch(item)) { case NUdf::EFetchStatus::Ok: { Item->SetValue(Ctx, std::move(item)); - MapAccum.Add(Key->GetValue(Ctx), Payload->GetValue(Ctx)); + MapAccum->Add(Key->GetValue(Ctx), Payload->GetValue(Ctx)); break; // and continue } case NUdf::EFetchStatus::Finish: { - result = MapAccum.Build(); + result = MapAccum->Build(); Finished = true; return NUdf::EFetchStatus::Ok; } @@ -1310,13 +1375,13 @@ public: IComputationExternalNode* const Item; IComputationNode* const Key; IComputationNode* const Payload; - TMapAccumulator MapAccum; + const std::unique_ptr<IMapAccumulator> MapAccum; TComputationContext& Ctx; bool Finished = false; }; TMapWrapper(TComputationMutables& mutables, TType* keyType, TType* payloadType, IComputationNode* list, IComputationExternalNode* item, - IComputationNode* key, IComputationNode* payload, ui64 itemsCountHint) + IComputationNode* key, IComputationNode* payload, ui64 itemsCountHint, bool isStream, std::unique_ptr<IMapAccumulatorFactory> factory) : TBaseComputation(mutables, EValueRepresentation::Boxed) , KeyType(keyType) , PayloadType(payloadType) @@ -1325,18 +1390,20 @@ public: , Key(key) , Payload(payload) , ItemsCountHint(itemsCountHint) + , IsStream(isStream) + , Factory(std::move(factory)) { GetDictionaryKeyTypes(KeyType, KeyTypes, IsTuple, Encoded, UseIHash); - Compare = UseIHash && TMapAccumulator::IsSorted ? MakeCompareImpl(KeyType) : nullptr; + Compare = UseIHash && Factory->IsSorted() ? MakeCompareImpl(KeyType) : nullptr; Equate = UseIHash ? MakeEquateImpl(KeyType) : nullptr; - Hash = UseIHash && !TMapAccumulator::IsSorted ? MakeHashImpl(KeyType) : nullptr; + Hash = UseIHash && !Factory->IsSorted() ? MakeHashImpl(KeyType) : nullptr; } NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - if constexpr (IsStream) { + if (IsStream) { return ctx.HolderFactory.Create<TStreamValue>(List->GetValue(ctx), Item, Key, Payload, - TMapAccumulator(KeyType, PayloadType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), + Factory->Create(KeyType, PayloadType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), ctx, ItemsCountHint), ctx); } @@ -1350,17 +1417,17 @@ public: return ctx.HolderFactory.GetEmptyContainerLazy(); } - TMapAccumulator accumulator(KeyType, PayloadType, KeyTypes, IsTuple, Encoded, + auto acc = Factory->Create(KeyType, PayloadType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), ctx, itemsCountHint); TThresher<false>::DoForEachItem(list, - [this, &accumulator, &ctx] (NUdf::TUnboxedValue&& item) { + [this, &acc, &ctx] (NUdf::TUnboxedValue&& item) { Item->SetValue(ctx, std::move(item)); - accumulator.Add(Key->GetValue(ctx), Payload->GetValue(ctx)); + acc->Add(Key->GetValue(ctx), Payload->GetValue(ctx)); } ); - return accumulator.Build().Release(); + return acc->Build().Release(); } private: @@ -1378,6 +1445,8 @@ private: IComputationNode* const Key; IComputationNode* const Payload; const ui64 ItemsCountHint; + const bool IsStream; + const std::unique_ptr<IMapAccumulatorFactory> Factory; TKeyTypes KeyTypes; bool IsTuple; bool Encoded; @@ -1388,31 +1457,30 @@ private: NUdf::IHash::TPtr Hash; }; -template <typename TMapAccumulator> -class TSqueezeMapFlowWrapper : public TStatefulFlowCodegeneratorNode<TSqueezeMapFlowWrapper<TMapAccumulator>> { - using TBase = TStatefulFlowCodegeneratorNode<TSqueezeMapFlowWrapper<TMapAccumulator>>; +class TSqueezeMapFlowWrapper : public TStatefulFlowCodegeneratorNode<TSqueezeMapFlowWrapper> { + using TBase = TStatefulFlowCodegeneratorNode<TSqueezeMapFlowWrapper>; public: class TState : public TComputationValue<TState> { using TBase = TComputationValue<TState>; public: - TState(TMemoryUsageInfo* memInfo, TMapAccumulator&& mapAccum) + TState(TMemoryUsageInfo* memInfo, std::unique_ptr<IMapAccumulator>&& mapAccum) : TBase(memInfo), MapAccum(std::move(mapAccum)) {} NUdf::TUnboxedValuePod Build() { - return MapAccum.Build().Release(); + return MapAccum->Build().Release(); } void Insert(NUdf::TUnboxedValuePod key, NUdf::TUnboxedValuePod value) { - MapAccum.Add(key, value); + MapAccum->Add(key, value); } private: - TMapAccumulator MapAccum; + const std::unique_ptr<IMapAccumulator> MapAccum; }; TSqueezeMapFlowWrapper(TComputationMutables& mutables, TType* keyType, TType* payloadType, IComputationNode* flow, IComputationExternalNode* item, IComputationNode* key, IComputationNode* payload, - ui64 itemsCountHint) + ui64 itemsCountHint, std::unique_ptr<IMapAccumulatorFactory> factory) : TBase(mutables, flow, EValueRepresentation::Boxed, EValueRepresentation::Any) , KeyType(keyType) , PayloadType(payloadType) @@ -1421,12 +1489,13 @@ public: , Key(key) , Payload(payload) , ItemsCountHint(itemsCountHint) + , Factory(std::move(factory)) { GetDictionaryKeyTypes(KeyType, KeyTypes, IsTuple, Encoded, UseIHash); - Compare = UseIHash && TMapAccumulator::IsSorted ? MakeCompareImpl(KeyType) : nullptr; + Compare = UseIHash && Factory->IsSorted() ? MakeCompareImpl(KeyType) : nullptr; Equate = UseIHash ? MakeEquateImpl(KeyType) : nullptr; - Hash = UseIHash && !TMapAccumulator::IsSorted ? MakeHashImpl(KeyType) : nullptr; + Hash = UseIHash && !Factory->IsSorted() ? MakeHashImpl(KeyType) : nullptr; } NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { @@ -1471,7 +1540,7 @@ public: const auto ptrType = PointerType::getUnqual(StructType::get(context)); const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); - const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSqueezeMapFlowWrapper<TMapAccumulator>::MakeState)); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSqueezeMapFlowWrapper::MakeState)); const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false); const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block); @@ -1549,7 +1618,7 @@ public: #endif private: void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const { - state = ctx.HolderFactory.Create<TState>(TMapAccumulator(KeyType, PayloadType, KeyTypes, IsTuple, Encoded, + state = ctx.HolderFactory.Create<TState>(Factory->Create(KeyType, PayloadType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), ctx, ItemsCountHint)); } @@ -1568,6 +1637,7 @@ private: IComputationNode* const Key; IComputationNode* const Payload; const ui64 ItemsCountHint; + const std::unique_ptr<IMapAccumulatorFactory> Factory; TKeyTypes KeyTypes; bool IsTuple; bool Encoded; @@ -1578,31 +1648,30 @@ private: NUdf::IHash::TPtr Hash; }; -template <typename TMapAccumulator> -class TSqueezeMapWideWrapper : public TStatefulFlowCodegeneratorNode<TSqueezeMapWideWrapper<TMapAccumulator>> { - using TBase = TStatefulFlowCodegeneratorNode<TSqueezeMapWideWrapper<TMapAccumulator>>; +class TSqueezeMapWideWrapper : public TStatefulFlowCodegeneratorNode<TSqueezeMapWideWrapper> { + using TBase = TStatefulFlowCodegeneratorNode<TSqueezeMapWideWrapper>; public: class TState : public TComputationValue<TState> { using TBase = TComputationValue<TState>; public: - TState(TMemoryUsageInfo* memInfo, TMapAccumulator&& mapAccum) + TState(TMemoryUsageInfo* memInfo, std::unique_ptr<IMapAccumulator>&& mapAccum) : TBase(memInfo), MapAccum(std::move(mapAccum)) {} NUdf::TUnboxedValuePod Build() { - return MapAccum.Build().Release(); + return MapAccum->Build().Release(); } void Insert(NUdf::TUnboxedValuePod key, NUdf::TUnboxedValuePod value) { - MapAccum.Add(key, value); + MapAccum->Add(key, value); } private: - TMapAccumulator MapAccum; + const std::unique_ptr<IMapAccumulator> MapAccum; }; TSqueezeMapWideWrapper(TComputationMutables& mutables, TType* keyType, TType* payloadType, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, IComputationNode* key, IComputationNode* payload, - ui64 itemsCountHint) + ui64 itemsCountHint, std::unique_ptr<IMapAccumulatorFactory> factory) : TBase(mutables, flow, EValueRepresentation::Boxed, EValueRepresentation::Any) , KeyType(keyType) , PayloadType(payloadType) @@ -1611,15 +1680,16 @@ public: , Key(key) , Payload(payload) , ItemsCountHint(itemsCountHint) + , Factory(std::move(factory)) , PasstroughKey(GetPasstroughtMap(TComputationNodePtrVector{Key}, Items).front()) , PasstroughPayload(GetPasstroughtMap(TComputationNodePtrVector{Payload}, Items).front()) , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Items.size())) { GetDictionaryKeyTypes(KeyType, KeyTypes, IsTuple, Encoded, UseIHash); - Compare = UseIHash && TMapAccumulator::IsSorted ? MakeCompareImpl(KeyType) : nullptr; + Compare = UseIHash && Factory->IsSorted() ? MakeCompareImpl(KeyType) : nullptr; Equate = UseIHash ? MakeEquateImpl(KeyType) : nullptr; - Hash = UseIHash && !TMapAccumulator::IsSorted ? MakeHashImpl(KeyType) : nullptr; + Hash = UseIHash && !Factory->IsSorted() ? MakeHashImpl(KeyType) : nullptr; } NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { @@ -1669,7 +1739,7 @@ public: const auto ptrType = PointerType::getUnqual(StructType::get(context)); const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); - const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSqueezeMapWideWrapper<TMapAccumulator>::MakeState)); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TSqueezeMapWideWrapper::MakeState)); const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false); const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block); @@ -1753,7 +1823,7 @@ public: #endif private: void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const { - state = ctx.HolderFactory.Create<TState>(TMapAccumulator(KeyType, PayloadType, KeyTypes, IsTuple, Encoded, + state = ctx.HolderFactory.Create<TState>(Factory->Create(KeyType, PayloadType, KeyTypes, IsTuple, Encoded, Compare.Get(), Equate.Get(), Hash.Get(), ctx, ItemsCountHint)); } @@ -1773,6 +1843,7 @@ private: IComputationNode* const Key; IComputationNode* const Payload; const ui64 ItemsCountHint; + const std::unique_ptr<IMapAccumulatorFactory> Factory; TKeyTypes KeyTypes; bool IsTuple; bool Encoded; @@ -1797,26 +1868,28 @@ IComputationNode* WrapToSet(TCallable& callable, const TNodeLocator& nodeLocator const auto flow = LocateNode(nodeLocator, callable, 0U); const auto keySelector = LocateNode(nodeLocator, callable, callable.GetInputsCount() - 5U); + auto factory = std::make_unique<TSetAccumulatorFactory<TAccumulator>>(); + if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) { const auto width = callable.GetInputsCount() - 6U; TComputationExternalNodePtrVector args(width, nullptr); auto index = 0U; std::generate_n(args.begin(), width, [&](){ return LocateExternalNode(nodeLocator, callable, ++index); }); - return new TSqueezeSetWideWrapper<TAccumulator>(mutables, keyType, wide, std::move(args), keySelector, itemsCountHint); + return new TSqueezeSetWideWrapper(mutables, keyType, wide, std::move(args), keySelector, itemsCountHint, std::move(factory)); } const auto itemArg = LocateExternalNode(nodeLocator, callable, 1U); const auto type = callable.GetInput(0U).GetStaticType(); if (type->IsList()) { - return new TSetWrapper<TAccumulator, false>(mutables, keyType, flow, itemArg, keySelector, itemsCountHint); + return new TSetWrapper(mutables, keyType, flow, itemArg, keySelector, itemsCountHint, false, std::move(factory)); } if (type->IsFlow()) { - return new TSqueezeSetFlowWrapper<TAccumulator>(mutables, keyType, flow, itemArg, keySelector, itemsCountHint); + return new TSqueezeSetFlowWrapper(mutables, keyType, flow, itemArg, keySelector, itemsCountHint, std::move(factory)); } if (type->IsStream()) { - return new TSetWrapper<TAccumulator, true>(mutables, keyType, flow, itemArg, keySelector, itemsCountHint); + return new TSetWrapper(mutables, keyType, flow, itemArg, keySelector, itemsCountHint, true, std::move(factory)); } THROW yexception() << "Expected list, flow or stream."; @@ -1833,37 +1906,37 @@ IComputationNode* WrapToMap(TCallable& callable, const TNodeLocator& nodeLocator const auto keySelector = LocateNode(nodeLocator, callable, callable.GetInputsCount() - 5U); const auto payloadSelector = LocateNode(nodeLocator, callable, callable.GetInputsCount() - 4U); + auto factory = std::make_unique<TMapAccumulatorFactory<TAccumulator>>(); if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) { const auto width = callable.GetInputsCount() - 6U; TComputationExternalNodePtrVector args(width, nullptr); auto index = 0U; std::generate(args.begin(), args.end(), [&](){ return LocateExternalNode(nodeLocator, callable, ++index); }); - return new TSqueezeMapWideWrapper<TAccumulator>(mutables, keyType, payloadType, wide, std::move(args), keySelector, payloadSelector, itemsCountHint); + return new TSqueezeMapWideWrapper(mutables, keyType, payloadType, wide, std::move(args), keySelector, payloadSelector, itemsCountHint, std::move(factory)); } const auto itemArg = LocateExternalNode(nodeLocator, callable, 1U); const auto type = callable.GetInput(0U).GetStaticType(); if (type->IsList()) { - return new TMapWrapper<TAccumulator, false>(mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint); + return new TMapWrapper(mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint, false, std::move(factory)); } if (type->IsFlow()) { - return new TSqueezeMapFlowWrapper<TAccumulator>(mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint); + return new TSqueezeMapFlowWrapper(mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint, std::move(factory)); } if (type->IsStream()) { - return new TMapWrapper<TAccumulator, true>(mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint); + return new TMapWrapper(mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint, true, std::move(factory)); } THROW yexception() << "Expected list, flow or stream."; } -template <bool IsList> -IComputationNode* WrapToSortedDictInternal(TCallable& callable, const TComputationNodeFactoryContext& ctx) { +IComputationNode* WrapToSortedDictInternal(TCallable& callable, const TComputationNodeFactoryContext& ctx, bool isList) { MKQL_ENSURE(callable.GetInputsCount() >= 6U, "Expected six or more args."); const auto type = callable.GetInput(0U).GetStaticType(); - if constexpr (IsList) { + if (isList) { MKQL_ENSURE(type->IsList(), "Expected list."); } else { MKQL_ENSURE(type->IsFlow() || type->IsStream(), "Expected flow or stream."); @@ -1887,56 +1960,70 @@ IComputationNode* WrapToSortedDictInternal(TCallable& callable, const TComputati std::generate(args.begin(), args.end(), [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); }); if (!isMulti && payloadType->IsVoid()) { - return new TSqueezeSetWideWrapper<TSortedSetAccumulator>(ctx.Mutables, keyType, wide, std::move(args), keySelector, itemsCountHint); + return new TSqueezeSetWideWrapper(ctx.Mutables, keyType, wide, std::move(args), keySelector, itemsCountHint, + std::make_unique<TSetAccumulatorFactory<TSortedSetAccumulator>>()); } else if (isMulti) { - return new TSqueezeMapWideWrapper<TSortedMapAccumulator<true>>(ctx.Mutables, keyType, payloadType, wide, std::move(args), keySelector, payloadSelector, itemsCountHint); + return new TSqueezeMapWideWrapper(ctx.Mutables, keyType, payloadType, wide, std::move(args), keySelector, payloadSelector, itemsCountHint, + std::make_unique<TMapAccumulatorFactory<TSortedMapAccumulator<true>>>()); } else { - return new TSqueezeMapWideWrapper<TSortedMapAccumulator<false>>(ctx.Mutables, keyType, payloadType, wide, std::move(args), keySelector, payloadSelector, itemsCountHint); + return new TSqueezeMapWideWrapper(ctx.Mutables, keyType, payloadType, wide, std::move(args), keySelector, payloadSelector, itemsCountHint, + std::make_unique<TMapAccumulatorFactory<TSortedMapAccumulator<false>>>()); } } const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 1U); if (!isMulti && payloadType->IsVoid()) { + auto factory = std::make_unique<TSetAccumulatorFactory<TSortedSetAccumulator>>(); if (type->IsList()) { - return new TSetWrapper<TSortedSetAccumulator, false>(ctx.Mutables, keyType, flow, itemArg, keySelector, itemsCountHint); + return new TSetWrapper(ctx.Mutables, keyType, flow, itemArg, keySelector, itemsCountHint, + false, std::move(factory)); } if (type->IsFlow()) { - return new TSqueezeSetFlowWrapper<TSortedSetAccumulator>(ctx.Mutables, keyType, flow, itemArg, keySelector, itemsCountHint); + return new TSqueezeSetFlowWrapper(ctx.Mutables, keyType, flow, itemArg, keySelector, + itemsCountHint, std::move(factory)); } if (type->IsStream()) { - return new TSetWrapper<TSortedSetAccumulator, true>(ctx.Mutables, keyType, flow, itemArg, keySelector, itemsCountHint); + return new TSetWrapper(ctx.Mutables, keyType, flow, itemArg, keySelector, itemsCountHint, + true, std::move(factory)); } } else if (isMulti) { + auto factory = std::make_unique<TMapAccumulatorFactory<TSortedMapAccumulator<true>>>(); if (type->IsList()) { - return new TMapWrapper<TSortedMapAccumulator<true>, false>(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint); + return new TMapWrapper(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint, + false, std::move(factory)); } if (type->IsFlow()) { - return new TSqueezeMapFlowWrapper<TSortedMapAccumulator<true>>(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint); + return new TSqueezeMapFlowWrapper(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, + itemsCountHint, std::move(factory)); } if (type->IsStream()) { - return new TMapWrapper<TSortedMapAccumulator<true>, true>(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint); + return new TMapWrapper(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint, + true, std::move(factory)); } } else { + auto factory = std::make_unique<TMapAccumulatorFactory<TSortedMapAccumulator<false>>>(); if (type->IsList()) { - return new TMapWrapper<TSortedMapAccumulator<false>, false>(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint); + return new TMapWrapper(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint, + false, std::move(factory)); } if (type->IsFlow()) { - return new TSqueezeMapFlowWrapper<TSortedMapAccumulator<false>>(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint); + return new TSqueezeMapFlowWrapper(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, + itemsCountHint, std::move(factory)); } if (type->IsStream()) { - return new TMapWrapper<TSortedMapAccumulator<false>, true>(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint); + return new TMapWrapper(ctx.Mutables, keyType, payloadType, flow, itemArg, keySelector, payloadSelector, itemsCountHint, + true, std::move(factory)); } } THROW yexception() << "Expected list, flow or stream."; } -template <bool IsList> -IComputationNode* WrapToHashedDictInternal(TCallable& callable, const TComputationNodeFactoryContext& ctx) { +IComputationNode* WrapToHashedDictInternal(TCallable& callable, const TComputationNodeFactoryContext& ctx, bool isList) { MKQL_ENSURE(callable.GetInputsCount() >= 6U, "Expected six or more args."); const auto type = callable.GetInput(0U).GetStaticType(); - if constexpr (IsList) { + if (isList) { MKQL_ENSURE(type->IsList(), "Expected list."); } else { MKQL_ENSURE(type->IsFlow() || type->IsStream(), "Expected flow or stream."); @@ -2064,19 +2151,19 @@ IComputationNode* WrapToHashedDictInternal(TCallable& callable, const TComputati } IComputationNode* WrapToSortedDict(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - return WrapToSortedDictInternal<true>(callable, ctx); + return WrapToSortedDictInternal(callable, ctx, true); } IComputationNode* WrapToHashedDict(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - return WrapToHashedDictInternal<true>(callable, ctx); + return WrapToHashedDictInternal(callable, ctx, true); } IComputationNode* WrapSqueezeToSortedDict(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - return WrapToSortedDictInternal<false>(callable, ctx); + return WrapToSortedDictInternal(callable, ctx, false); } IComputationNode* WrapSqueezeToHashedDict(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - return WrapToHashedDictInternal<false>(callable, ctx); + return WrapToHashedDictInternal(callable, ctx, false); } } |