diff options
author | Aidar Samerkhanov <aidarsamer@ydb.tech> | 2024-04-11 17:27:28 +0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-11 17:27:28 +0400 |
commit | 7dfc51a47057a55016028f9e6e8e035f86faa86f (patch) | |
tree | f53abe49298b80d58f337de019a3a3c9772c149b | |
parent | ee73bc343ec419aecb3bcdd8b44944361e5065b6 (diff) | |
download | ydb-7dfc51a47057a55016028f9e6e8e035f86faa86f.tar.gz |
YQL-17167: Add Spilling support to Sort operator (#3339)
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp | 517 |
1 files changed, 475 insertions, 42 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp index 324cb1d72ef..7139a82bc42 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp @@ -2,6 +2,7 @@ #include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE #include <ydb/library/yql/minikql/computation/mkql_llvm_base.h> // Y_IGNORE +#include <ydb/library/yql/minikql/computation/mkql_spiller_adapter.h> #include <ydb/library/yql/minikql/computation/presort.h> #include <ydb/library/yql/minikql/mkql_node_builder.h> #include <ydb/library/yql/minikql/mkql_node_cast.h> @@ -10,6 +11,7 @@ #include <ydb/library/yql/utils/sort.h> + namespace NKikimr { namespace NMiniKQL { @@ -78,14 +80,127 @@ struct TMyValueCompare { const std::vector<TRuntimeKeyInfo> Keys; }; +using TAsyncWriteOperation = std::optional<NThreading::TFuture<ISpiller::TKey>>; +using TAsyncReadOperation = std::optional<NThreading::TFuture<std::optional<TRope>>>; +using TStorage = std::vector<NUdf::TUnboxedValue, TMKQLAllocator<NUdf::TUnboxedValue, EMemorySubPool::Temporary>>; + +struct TSpilledData { + using TPtr = TSpilledData*; + + TSpilledData(std::unique_ptr<TWideUnboxedValuesSpillerAdapter> &&spiller) + : Spiller(std::move(spiller)) {} + + TAsyncWriteOperation Write(NUdf::TUnboxedValue* item, size_t size) { + AsyncWriteOperation = Spiller->WriteWideItem({item, size}); + return AsyncWriteOperation; + } + + TAsyncWriteOperation FinishWrite() { + AsyncWriteOperation = Spiller->FinishWriting(); + return AsyncWriteOperation; + } + + TAsyncReadOperation Read(TStorage &buffer, TComputationContext& ctx) { + if (AsyncReadOperation) { + if (AsyncReadOperation->HasValue()) { + Spiller->AsyncReadCompleted(AsyncReadOperation->ExtractValue().value(), ctx.HolderFactory); + AsyncReadOperation = std::nullopt; + } else { + return AsyncReadOperation; + } + } + if (Spiller->Empty()) 
{ + IsFinished = true; + return std::nullopt; + } + AsyncReadOperation = Spiller->ExtractWideItem(buffer); + return AsyncReadOperation; + } + + bool Empty() const { + return IsFinished; + } + + std::unique_ptr<TWideUnboxedValuesSpillerAdapter> Spiller; + TAsyncWriteOperation AsyncWriteOperation = std::nullopt; + TAsyncReadOperation AsyncReadOperation = std::nullopt; + bool IsFinished = false; +}; + +class TSpilledUnboxedValuesIterator { +private: + TStorage Data; + TSpilledData::TPtr SpilledData; + std::function<bool(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)> LessFunc; + ui32 Width_; + TComputationContext* Ctx; + bool HasValue = false; +public: + + TSpilledUnboxedValuesIterator( + const std::function<bool(const NUdf::TUnboxedValuePod*,const NUdf::TUnboxedValuePod*)>& lessFunc, + TSpilledData::TPtr spilledData, + size_t dataWidth, + TComputationContext* ctx + ) + : SpilledData(spilledData) + , LessFunc(lessFunc) + , Width_(dataWidth) + , Ctx(ctx) + { + Data.resize(Width_); + } + + EFetchResult Read() { + if (!HasValue) { + if (SpilledData->Read(Data, *Ctx)) { + return EFetchResult::Yield; + } + if (SpilledData->Empty()) { + return EFetchResult::Finish; + } + } + HasValue = true; + return EFetchResult::One; + } + + bool CheckForInit() { + Read(); + return HasValue; + } + + bool IsFinished() const { + return SpilledData->Empty(); + } + + bool operator<(const TSpilledUnboxedValuesIterator& item) const { + return !LessFunc(GetValue(), item.GetValue()); + } + + ui32 Width() const { + return Width_; + } + + void Pop() { + HasValue = false; + Read(); + } + + NKikimr::NUdf::TUnboxedValue* GetValue() { + return &*Data.begin(); + } + const NKikimr::NUdf::TUnboxedValue* GetValue() const { + return &*Data.begin(); + } +}; + using TComparePtr = int(*)(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*); using TCompareFunc = std::function<int(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)>; -template <bool 
HasCount> -class TState : public TComputationValue<TState<HasCount>> { -using TBase = TComputationValue<TState<HasCount>>; +template <bool Sort, bool HasCount> +class TState : public TComputationValue<TState<Sort, HasCount>> { +using TBase = TComputationValue<TState<Sort, HasCount>>; private: - using TStorage = std::vector<NUdf::TUnboxedValue, TMKQLAllocator<NUdf::TUnboxedValue, EMemorySubPool::Temporary>>; using TFields = std::vector<NUdf::TUnboxedValue*, TMKQLAllocator<NUdf::TUnboxedValue*, EMemorySubPool::Temporary>>; using TPointers = std::vector<NUdf::TUnboxedValuePod*, TMKQLAllocator<NUdf::TUnboxedValuePod*, EMemorySubPool::Temporary>>; @@ -106,8 +221,12 @@ private: std::for_each(Indexes.cbegin(), Indexes.cend(), [&](ui32 index) { Fields[index] = static_cast<NUdf::TUnboxedValue*>(ptr++); }); } public: - TState(TMemoryUsageInfo* memInfo, ui64 count, const bool* directons, size_t keyWidth, const TCompareFunc& compare, const std::vector<ui32>& indexes) - : TBase(memInfo), Count(count), Indexes(indexes), Directions(directons, directons + keyWidth) + TState(TMemoryUsageInfo* memInfo, ui64 count, const bool* directons, size_t keyWidth, const TCompareFunc& compare, const std::vector<ui32>& indexes, IComputationWideFlowNode *const flow) + : TBase(memInfo) + , Flow(flow) + , Count(count) + , Indexes(indexes) + , Directions(directons, directons + keyWidth) , LessFunc(std::bind(std::less<int>(), std::bind(compare, Directions.data(), std::placeholders::_1, std::placeholders::_2), 0)) , Fields(Indexes.size(), nullptr) { @@ -131,6 +250,32 @@ public: InputStatus = EFetchResult::Finish; } + virtual EFetchResult DoCalculate(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) { + while (EFetchResult::Finish != InputStatus) { + switch (InputStatus = Flow->FetchValues(ctx, GetFields())) { + case EFetchResult::One: + Put(); + continue; + case EFetchResult::Yield: + return EFetchResult::Yield; + case EFetchResult::Finish: + Seal(); + break; + } + } + + if (auto extract = 
Extract()) { + for (const auto index : Indexes) + if (const auto to = output[index]) + *to = std::move(*extract++); + else + ++extract; + return EFetchResult::One; + } + + return EFetchResult::Finish; + } + NUdf::TUnboxedValue*const* GetFields() const { return Fields.data(); } @@ -169,7 +314,6 @@ public: return true; } - template<bool Sort> void Seal() { if constexpr (!HasCount) { static_assert (Sort); @@ -208,6 +352,284 @@ public: NUdf::TUnboxedValuePod* Tongue = nullptr; NUdf::TUnboxedValuePod* Throat = nullptr; private: + IComputationWideFlowNode *const Flow; + const ui64 Count; + const std::vector<ui32> Indexes; + const std::vector<bool> Directions; + const std::function<bool(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)> LessFunc; + TStorage Storage; + TPointers Free, Full; + TFields Fields; +}; + +template <bool Sort, bool HasCount> +class TSpillingSupportState : public TComputationValue<TSpillingSupportState<Sort, HasCount>> { +using TBase = TComputationValue<TSpillingSupportState<Sort, HasCount>>; +private: + using TStorage = std::vector<NUdf::TUnboxedValue, TMKQLAllocator<NUdf::TUnboxedValue, EMemorySubPool::Temporary>>; + using TFields = std::vector<NUdf::TUnboxedValue*, TMKQLAllocator<NUdf::TUnboxedValue*, EMemorySubPool::Temporary>>; + using TPointers = std::vector<NUdf::TUnboxedValuePod*, TMKQLAllocator<NUdf::TUnboxedValuePod*, EMemorySubPool::Temporary>>; + + enum class EOperatingMode { + InMemory, + Spilling, + ProcessSpilled + }; + + size_t GetStorageSize() const { + return std::max<size_t>(Count << 2ULL, 1ULL << 8ULL); + } + + void ResetFields() { + NUdf::TUnboxedValuePod* ptr; + if constexpr (!HasCount) { + auto pos = Storage.size(); + Storage.insert(Storage.end(), Indexes.size(), {}); + ptr = Storage.data() + pos; + } + + std::for_each(Indexes.cbegin(), Indexes.cend(), [&](ui32 index) { Fields[index] = static_cast<NUdf::TUnboxedValue*>(ptr++); }); + } + +public: + TSpillingSupportState(TMemoryUsageInfo* memInfo, ui64 count, const 
bool* directons, size_t keyWidth, const TCompareFunc& compare, + const std::vector<ui32>& indexes, IComputationWideFlowNode *const flow, TMultiType* tupleMultiType) + : TBase(memInfo) + , Flow(flow) + , Count(count) + , Indexes(indexes) + , Directions(directons, directons + keyWidth) + , LessFunc(std::bind(std::less<int>(), std::bind(compare, Directions.data(), std::placeholders::_1, std::placeholders::_2), 0)) + , Fields(Indexes.size(), nullptr) + , TupleMultiType(tupleMultiType) + { + if constexpr (!HasCount) { + ResetFields(); + return; + } + throw yexception() << "Spilling doesn't support TopSort."; + } + + virtual EFetchResult DoCalculate(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) { + while (true) { + switch(GetMode()) { + case EOperatingMode::InMemory: { + auto r = DoCalculateInMemory(ctx, output); + if (GetMode() == TSpillingSupportState::EOperatingMode::InMemory) { + return r; + } + break; + } + case EOperatingMode::Spilling: { + DoCalculateWithSpilling(ctx); + if (GetMode() == EOperatingMode::Spilling) { + return EFetchResult::Yield; + } + break; + } + case EOperatingMode::ProcessSpilled: { + return ProcessSpilledData(output); + } + + } + } + Y_UNREACHABLE(); + } + +private: + + EFetchResult DoCalculateInMemory(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) { + while (EFetchResult::Finish != InputStatus) { + switch (InputStatus = Flow->FetchValues(ctx, GetFields())) { + case EFetchResult::One: + if (Put()) { + if (!HasMemoryForProcessing()) { + SwitchMode(EOperatingMode::Spilling, ctx); + return EFetchResult::Yield; + } + } + continue; + case EFetchResult::Yield: + return EFetchResult::Yield; + case EFetchResult::Finish: + { + if (!SpilledStates.empty()) { + SwitchMode(EOperatingMode::Spilling, ctx); + return EFetchResult::Yield; + } + Seal(); + break; + } + } + } + + if (auto extract = Extract()) { + for (const auto index : Indexes) + if (const auto to = output[index]) + *to = std::move(*extract++); + else + ++extract; + 
return EFetchResult::One; + } + + return EFetchResult::Finish; + } + + EFetchResult DoCalculateWithSpilling(TComputationContext& ctx) { + if (!SpillState()) { + return EFetchResult::Yield; + } + ResetFields(); + auto nextMode = (IsReadFromChannelFinished() ? EOperatingMode::ProcessSpilled : EOperatingMode::InMemory); + SwitchMode(nextMode, ctx); + return EFetchResult::Yield; + } + + EFetchResult ProcessSpilledData(NUdf::TUnboxedValue*const* output) { + if (SpilledUnboxedValuesIterators.empty()) { + return EFetchResult::Finish; + } + + for (auto &spilledUnboxedValuesIterator : SpilledUnboxedValuesIterators) { + if (!spilledUnboxedValuesIterator.CheckForInit()) { + return EFetchResult::Yield; + } + } + if (!IsHeapBuilt) { + std::make_heap(SpilledUnboxedValuesIterators.begin(), SpilledUnboxedValuesIterators.end()); + IsHeapBuilt = true; + } else { + std::push_heap(SpilledUnboxedValuesIterators.begin(), SpilledUnboxedValuesIterators.end()); + } + + std::pop_heap(SpilledUnboxedValuesIterators.begin(), SpilledUnboxedValuesIterators.end()); + auto &currentIt = SpilledUnboxedValuesIterators.back(); + NKikimr::NUdf::TUnboxedValue* res = currentIt.GetValue(); + for (const auto index : Indexes) + { + if (const auto to = output[index]) + *to = std::move(*res++); + else + ++res; + } + currentIt.Pop(); + if (currentIt.IsFinished()) { + SpilledUnboxedValuesIterators.pop_back(); + } + return EFetchResult::One; + } + + NUdf::TUnboxedValue*const* GetFields() const { + return Fields.data(); + } + + bool Put() { + if constexpr (!HasCount) { + ResetFields(); + return true; + } + + throw yexception() << "Spilling doesn't support TopSort."; + } + + void Seal() { + if constexpr (!HasCount) { + static_assert (Sort); + // Remove placeholder for new data + Storage.resize(Storage.size() - Indexes.size()); + + Full.reserve(Storage.size() / Indexes.size()); + for (auto it = Storage.begin(); it != Storage.end(); it += Indexes.size()) { + Full.emplace_back(&*it); + } + + std::sort(Full.rbegin(), 
Full.rend(), LessFunc); + return; + } + + throw yexception() << "Spilling doesn't support TopSort."; + } + + NUdf::TUnboxedValue* Extract() { + if (Full.empty()) + return nullptr; + + const auto ptr = Full.back(); + Full.pop_back(); + return static_cast<NUdf::TUnboxedValue*>(ptr); + } + + EOperatingMode GetMode() const { return Mode; } + + bool HasMemoryForProcessing() const { + // TODO: Change to enable spilling + // return !TlsAllocState->IsMemoryYellowZoneEnabled(); + return true; + } + + bool IsReadFromChannelFinished() const { + return InputStatus == EFetchResult::Finish; + } + + void SwitchMode(EOperatingMode mode, TComputationContext& ctx) { + switch(mode) { + case EOperatingMode::InMemory: + break; + case EOperatingMode::Spilling: + { + auto spiller = ctx.SpillerFactory->CreateSpiller(); + const size_t PACK_SIZE = 5_MB; + SpilledStates.emplace_back(std::make_unique<TWideUnboxedValuesSpillerAdapter>(spiller, TupleMultiType, PACK_SIZE)); + break; + } + case EOperatingMode::ProcessSpilled: + { + SpilledUnboxedValuesIterators.reserve(SpilledStates.size()); + for (auto &state: SpilledStates) { + SpilledUnboxedValuesIterators.emplace_back(LessFunc, &state, Indexes.size(), &ctx); + } + break; + } + } + Mode = mode; + } + + bool SpillState() { + MKQL_ENSURE(!SpilledStates.empty(), "At least one Spiller must be created to spill data in Sort operation."); + auto &lastSpilledState = SpilledStates.back(); + if (lastSpilledState.AsyncWriteOperation.has_value()) { + if (!lastSpilledState.AsyncWriteOperation->HasValue()) { + return false; + } + lastSpilledState.Spiller->AsyncWriteCompleted(lastSpilledState.AsyncWriteOperation->ExtractValue()); + lastSpilledState.AsyncWriteOperation = std::nullopt; + } else { + Seal(); + if (Full.empty()) { + // Nothing to spill + SpilledStates.pop_back(); + return true; + } + } + + while (auto extract = Extract()) { + auto writeOp = lastSpilledState.Write(extract, Indexes.size()); + if (writeOp) { + return false; + } + } + + auto 
writeFinishOp = lastSpilledState.FinishWrite(); + if (writeFinishOp){ + return false; + } + Storage.resize(0); + + return true; + } + + EFetchResult InputStatus = EFetchResult::One; + IComputationWideFlowNode *const Flow; const ui64 Count; const std::vector<ui32> Indexes; const std::vector<bool> Directions; @@ -215,13 +637,18 @@ private: TStorage Storage; TPointers Free, Full; TFields Fields; + TMultiType* TupleMultiType; + std::vector<TSpilledData> SpilledStates; + EOperatingMode Mode = EOperatingMode::InMemory; + std::vector<TSpilledUnboxedValuesIterator> SpilledUnboxedValuesIterators; + bool IsHeapBuilt = false; }; #ifndef MKQL_DISABLE_CODEGEN -template <bool HasCount> -class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState<HasCount>>> { +template <bool Sort, bool HasCount> +class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState<Sort, HasCount>>> { private: - using TBase = TLLVMFieldsStructure<TComputationValue<TState<HasCount>>>; + using TBase = TLLVMFieldsStructure<TComputationValue<TState<Sort, HasCount>>>; llvm::IntegerType* ValueType; llvm::PointerType* PtrValueType; llvm::IntegerType* StatusType; @@ -264,9 +691,9 @@ class TWideTopWrapper: public TStatefulWideFlowCodegeneratorNode<TWideTopWrapper using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort, HasCount>>; public: TWideTopWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, TComputationNodePtrVector&& directions, std::vector<TKeyInfo>&& keys, - std::vector<ui32>&& indexes, std::vector<EValueRepresentation>&& representations) + std::vector<ui32>&& indexes, std::vector<EValueRepresentation>&& representations, TMultiType* tupleMultiType) : TBaseComputation(mutables, flow, EValueRepresentation::Boxed), Flow(flow), Count(count), Directions(std::move(directions)), Keys(std::move(keys)) - , Indexes(std::move(indexes)), Representations(std::move(representations)) + , 
Indexes(std::move(indexes)), Representations(std::move(representations)), TupleMultiType(tupleMultiType) { for (const auto& x : Keys) { if (x.Compare || x.PresortType) { @@ -290,33 +717,23 @@ public: std::vector<bool> dirs(Directions.size()); std::transform(Directions.cbegin(), Directions.cend(), dirs.begin(), [&ctx](IComputationNode* dir){ return dir->GetValue(ctx).Get<bool>(); }); - MakeState(ctx, state, count, dirs.data()); + if (!ctx.ExecuteLLVM) { + MakeSpillingSupportState(ctx, state, count, dirs.data()); + } else { + MakeState(ctx, state, count, dirs.data()); + } } - if (const auto ptr = static_cast<TState<HasCount>*>(state.AsBoxed().Get())) { - while (EFetchResult::Finish != ptr->InputStatus) { - switch (ptr->InputStatus = Flow->FetchValues(ctx, ptr->GetFields())) { - case EFetchResult::One: - ptr->Put(); - continue; - case EFetchResult::Yield: - return EFetchResult::Yield; - case EFetchResult::Finish: - ptr->template Seal<Sort>(); - break; - } + // To avoid dynamic_cast implementation in LLVM implementation + // This is temporary solution. Final result will have just one state here. 
+ if (!ctx.ExecuteLLVM) { + if (const auto ptr = static_cast<TSpillingSupportState<Sort, HasCount>*>(state.AsBoxed().Get())) { + return ptr->DoCalculate(ctx, output); } - - if (auto extract = ptr->Extract()) { - for (const auto index : Indexes) - if (const auto to = output[index]) - *to = std::move(*extract++); - else - ++extract; - return EFetchResult::One; + } else { + if (const auto ptr = static_cast<TState<Sort, HasCount>*>(state.AsBoxed().Get())) { + return ptr->DoCalculate(ctx, output); } - - return EFetchResult::Finish; } Y_UNREACHABLE(); @@ -330,7 +747,7 @@ public: const auto statusType = Type::getInt32Ty(context); const auto indexType = Type::getInt32Ty(ctx.Codegen.GetContext()); - TLLVMFieldsStructureState<HasCount> stateFields(context); + TLLVMFieldsStructureState<Sort, HasCount> stateFields(context); const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); const auto statePtrType = PointerType::getUnqual(stateType); @@ -419,7 +836,7 @@ public: block = rest; new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), statusPtr, block); - const auto sealFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::template Seal<Sort>)); + const auto sealFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<Sort, HasCount>::Seal)); const auto sealType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType()}, false); const auto sealPtr = CastInst::Create(Instruction::IntToPtr, sealFunc, PointerType::getUnqual(sealType), "seal", block); CallInst::Create(sealType, sealPtr, {stateArg}, "", block); @@ -450,7 +867,7 @@ public: } - const auto pushFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::Put)); + const auto pushFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<Sort, HasCount>::Put)); const auto pushType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false); const auto pushPtr = 
CastInst::Create(Instruction::IntToPtr, pushFunc, PointerType::getUnqual(pushType), "function", block); const auto accepted = CallInst::Create(pushType, pushPtr, {stateArg}, "accepted", block); @@ -490,7 +907,7 @@ public: const auto good = BasicBlock::Create(context, "good", ctx.Func); - const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::Extract)); + const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<Sort, HasCount>::Extract)); const auto extractType = FunctionType::get(outputPtrType, {stateArg->getType()}, false); const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block); const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block); @@ -515,12 +932,20 @@ public: private: void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state, ui64 count, const bool* directions) const { #ifdef MKQL_DISABLE_CODEGEN - state = ctx.HolderFactory.Create<TState<HasCount>>(count, directions, Directions.size(), TMyValueCompare(Keys), Indexes); + state = ctx.HolderFactory.Create<TState<Sort, HasCount>>(count, directions, Directions.size(), TMyValueCompare(Keys), Indexes, Flow); #else - state = ctx.HolderFactory.Create<TState<HasCount>>(count, directions, Directions.size(), ctx.ExecuteLLVM && Compare ? TCompareFunc(Compare) : TCompareFunc(TMyValueCompare(Keys)), Indexes); + state = ctx.HolderFactory.Create<TState<Sort, HasCount>>(count, directions, Directions.size(), ctx.ExecuteLLVM && Compare ? 
TCompareFunc(Compare) : TCompareFunc(TMyValueCompare(Keys)), Indexes, Flow); #endif } + void MakeSpillingSupportState(TComputationContext& ctx, NUdf::TUnboxedValue& state, ui64 count, const bool* directions) const { + if (Sort && !HasCount && !ctx.ExecuteLLVM) { + state = ctx.HolderFactory.Create<TSpillingSupportState<Sort, HasCount>>(count, directions, Directions.size(), TMyValueCompare(Keys), Indexes, Flow, TupleMultiType); + return; + } + state = ctx.HolderFactory.Create<TState<Sort, HasCount>>(count, directions, Directions.size(), TMyValueCompare(Keys), Indexes, Flow); + } + void RegisterDependencies() const final { if (const auto flow = this->FlowDependsOn(Flow)) { if constexpr (HasCount) { @@ -538,6 +963,7 @@ private: const std::vector<ui32> Indexes; const std::vector<EValueRepresentation> Representations; TKeyTypes KeyTypes; + TMultiType* TupleMultiType; bool HasComplexType = false; #ifndef MKQL_DISABLE_CODEGEN @@ -587,10 +1013,14 @@ IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactor std::unordered_set<ui32> keyIndexes; std::vector<TKeyInfo> keys(keyWidth); + std::vector<TType*> tupleTypes; + tupleTypes.reserve(inputWideComponents.size()); + for (auto i = 0U; i < keyWidth; ++i) { const auto keyIndex = AS_VALUE(TDataLiteral, callable.GetInput(((i + 1U) << 1U) - offset))->AsValue().Get<ui32>(); indexes[i] = keyIndex; keyIndexes.emplace(keyIndex); + tupleTypes.emplace_back(inputWideComponents[keyIndex]); bool isTuple; bool encoded; @@ -608,6 +1038,7 @@ IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactor } } + size_t payloadPos = keyWidth; for (auto i = 0U; i < indexes.size(); ++i) { if (keyIndexes.contains(i)) { @@ -615,19 +1046,21 @@ IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactor } indexes[payloadPos++] = i; + tupleTypes.emplace_back(inputWideComponents[i]); } std::vector<EValueRepresentation> representations(inputWideComponents.size()); for (auto i = 0U; i < 
representations.size(); ++i) representations[i] = GetValueRepresentation(inputWideComponents[indexes[i]]); + auto tupleMultiType = TMultiType::Create(tupleTypes.size(),tupleTypes.data(), ctx.Env); TComputationNodePtrVector directions(keyWidth); auto index = 1U - offset; std::generate(directions.begin(), directions.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++++index); }); if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) { return new TWideTopWrapper<Sort, HasCount>(ctx.Mutables, wide, count, std::move(directions), std::move(keys), - std::move(indexes), std::move(representations)); + std::move(indexes), std::move(representations), tupleMultiType); } THROW yexception() << "Expected wide flow."; |