aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAidar Samerkhanov <aidarsamer@ydb.tech>2024-04-11 17:27:28 +0400
committerGitHub <noreply@github.com>2024-04-11 17:27:28 +0400
commit7dfc51a47057a55016028f9e6e8e035f86faa86f (patch)
treef53abe49298b80d58f337de019a3a3c9772c149b
parentee73bc343ec419aecb3bcdd8b44944361e5065b6 (diff)
downloadydb-7dfc51a47057a55016028f9e6e8e035f86faa86f.tar.gz
YQL-17167: Add Spilling support to Sort operator (#3339)
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp517
1 files changed, 475 insertions, 42 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp
index 324cb1d72ef..7139a82bc42 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp
@@ -2,6 +2,7 @@
#include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
#include <ydb/library/yql/minikql/computation/mkql_llvm_base.h> // Y_IGNORE
+#include <ydb/library/yql/minikql/computation/mkql_spiller_adapter.h>
#include <ydb/library/yql/minikql/computation/presort.h>
#include <ydb/library/yql/minikql/mkql_node_builder.h>
#include <ydb/library/yql/minikql/mkql_node_cast.h>
@@ -10,6 +11,7 @@
#include <ydb/library/yql/utils/sort.h>
+
namespace NKikimr {
namespace NMiniKQL {
@@ -78,14 +80,127 @@ struct TMyValueCompare {
const std::vector<TRuntimeKeyInfo> Keys;
};
+using TAsyncWriteOperation = std::optional<NThreading::TFuture<ISpiller::TKey>>;
+using TAsyncReadOperation = std::optional<NThreading::TFuture<std::optional<TRope>>>;
+using TStorage = std::vector<NUdf::TUnboxedValue, TMKQLAllocator<NUdf::TUnboxedValue, EMemorySubPool::Temporary>>;
+
// One spilled run of sorted rows together with its in-flight async I/O state.
// Write/FinishWrite/Read each hand back the pending future (if any) so the
// caller can yield until the spiller completes the operation.
struct TSpilledData {
    using TPtr = TSpilledData*;

    TSpilledData(std::unique_ptr<TWideUnboxedValuesSpillerAdapter> &&spiller)
        : Spiller(std::move(spiller)) {}

    // Enqueue one wide row (size unboxed values) for spilling. A non-empty
    // result means the write went asynchronous and the caller must wait on it.
    TAsyncWriteOperation Write(NUdf::TUnboxedValue* item, size_t size) {
        AsyncWriteOperation = Spiller->WriteWideItem({item, size});
        return AsyncWriteOperation;
    }

    // Flush the adapter's tail buffer; same waiting contract as Write().
    TAsyncWriteOperation FinishWrite() {
        AsyncWriteOperation = Spiller->FinishWriting();
        return AsyncWriteOperation;
    }

    // Advance reading of this run. Returns a pending future while I/O is in
    // flight; std::nullopt means either a row was delivered into `buffer` or
    // the run is exhausted (the caller must distinguish via Empty()).
    TAsyncReadOperation Read(TStorage &buffer, TComputationContext& ctx) {
        if (AsyncReadOperation) {
            if (AsyncReadOperation->HasValue()) {
                // NOTE(review): assumes the completed future never carries
                // std::nullopt, because Empty() is checked before each read
                // is issued -- confirm against TWideUnboxedValuesSpillerAdapter.
                Spiller->AsyncReadCompleted(AsyncReadOperation->ExtractValue().value(), ctx.HolderFactory);
                AsyncReadOperation = std::nullopt;
            } else {
                return AsyncReadOperation;
            }
        }
        if (Spiller->Empty()) {
            // Whole run consumed; latch the terminal state for Empty().
            IsFinished = true;
            return std::nullopt;
        }
        AsyncReadOperation = Spiller->ExtractWideItem(buffer);
        return AsyncReadOperation;
    }

    // True only after Read() has observed the underlying spiller drained.
    bool Empty() const {
        return IsFinished;
    }

    std::unique_ptr<TWideUnboxedValuesSpillerAdapter> Spiller;
    TAsyncWriteOperation AsyncWriteOperation = std::nullopt;
    TAsyncReadOperation AsyncReadOperation = std::nullopt;
    bool IsFinished = false;
};
+
+class TSpilledUnboxedValuesIterator {
+private:
+ TStorage Data;
+ TSpilledData::TPtr SpilledData;
+ std::function<bool(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)> LessFunc;
+ ui32 Width_;
+ TComputationContext* Ctx;
+ bool HasValue = false;
+public:
+
+ TSpilledUnboxedValuesIterator(
+ const std::function<bool(const NUdf::TUnboxedValuePod*,const NUdf::TUnboxedValuePod*)>& lessFunc,
+ TSpilledData::TPtr spilledData,
+ size_t dataWidth,
+ TComputationContext* ctx
+ )
+ : SpilledData(spilledData)
+ , LessFunc(lessFunc)
+ , Width_(dataWidth)
+ , Ctx(ctx)
+ {
+ Data.resize(Width_);
+ }
+
+ EFetchResult Read() {
+ if (!HasValue) {
+ if (SpilledData->Read(Data, *Ctx)) {
+ return EFetchResult::Yield;
+ }
+ if (SpilledData->Empty()) {
+ return EFetchResult::Finish;
+ }
+ }
+ HasValue = true;
+ return EFetchResult::One;
+ }
+
+ bool CheckForInit() {
+ Read();
+ return HasValue;
+ }
+
+ bool IsFinished() const {
+ return SpilledData->Empty();
+ }
+
+ bool operator<(const TSpilledUnboxedValuesIterator& item) const {
+ return !LessFunc(GetValue(), item.GetValue());
+ }
+
+ ui32 Width() const {
+ return Width_;
+ }
+
+ void Pop() {
+ HasValue = false;
+ Read();
+ }
+
+ NKikimr::NUdf::TUnboxedValue* GetValue() {
+ return &*Data.begin();
+ }
+ const NKikimr::NUdf::TUnboxedValue* GetValue() const {
+ return &*Data.begin();
+ }
+};
+
using TComparePtr = int(*)(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*);
using TCompareFunc = std::function<int(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)>;
-template <bool HasCount>
-class TState : public TComputationValue<TState<HasCount>> {
-using TBase = TComputationValue<TState<HasCount>>;
+template <bool Sort, bool HasCount>
+class TState : public TComputationValue<TState<Sort, HasCount>> {
+using TBase = TComputationValue<TState<Sort, HasCount>>;
private:
- using TStorage = std::vector<NUdf::TUnboxedValue, TMKQLAllocator<NUdf::TUnboxedValue, EMemorySubPool::Temporary>>;
using TFields = std::vector<NUdf::TUnboxedValue*, TMKQLAllocator<NUdf::TUnboxedValue*, EMemorySubPool::Temporary>>;
using TPointers = std::vector<NUdf::TUnboxedValuePod*, TMKQLAllocator<NUdf::TUnboxedValuePod*, EMemorySubPool::Temporary>>;
@@ -106,8 +221,12 @@ private:
std::for_each(Indexes.cbegin(), Indexes.cend(), [&](ui32 index) { Fields[index] = static_cast<NUdf::TUnboxedValue*>(ptr++); });
}
public:
- TState(TMemoryUsageInfo* memInfo, ui64 count, const bool* directons, size_t keyWidth, const TCompareFunc& compare, const std::vector<ui32>& indexes)
- : TBase(memInfo), Count(count), Indexes(indexes), Directions(directons, directons + keyWidth)
+ TState(TMemoryUsageInfo* memInfo, ui64 count, const bool* directons, size_t keyWidth, const TCompareFunc& compare, const std::vector<ui32>& indexes, IComputationWideFlowNode *const flow)
+ : TBase(memInfo)
+ , Flow(flow)
+ , Count(count)
+ , Indexes(indexes)
+ , Directions(directons, directons + keyWidth)
, LessFunc(std::bind(std::less<int>(), std::bind(compare, Directions.data(), std::placeholders::_1, std::placeholders::_2), 0))
, Fields(Indexes.size(), nullptr)
{
@@ -131,6 +250,32 @@ public:
InputStatus = EFetchResult::Finish;
}
+ virtual EFetchResult DoCalculate(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) {
+ while (EFetchResult::Finish != InputStatus) {
+ switch (InputStatus = Flow->FetchValues(ctx, GetFields())) {
+ case EFetchResult::One:
+ Put();
+ continue;
+ case EFetchResult::Yield:
+ return EFetchResult::Yield;
+ case EFetchResult::Finish:
+ Seal();
+ break;
+ }
+ }
+
+ if (auto extract = Extract()) {
+ for (const auto index : Indexes)
+ if (const auto to = output[index])
+ *to = std::move(*extract++);
+ else
+ ++extract;
+ return EFetchResult::One;
+ }
+
+ return EFetchResult::Finish;
+ }
+
NUdf::TUnboxedValue*const* GetFields() const {
return Fields.data();
}
@@ -169,7 +314,6 @@ public:
return true;
}
- template<bool Sort>
void Seal() {
if constexpr (!HasCount) {
static_assert (Sort);
@@ -208,6 +352,284 @@ public:
NUdf::TUnboxedValuePod* Tongue = nullptr;
NUdf::TUnboxedValuePod* Throat = nullptr;
private:
+ IComputationWideFlowNode *const Flow;
+ const ui64 Count;
+ const std::vector<ui32> Indexes;
+ const std::vector<bool> Directions;
+ const std::function<bool(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)> LessFunc;
+ TStorage Storage;
+ TPointers Free, Full;
+ TFields Fields;
+};
+
+template <bool Sort, bool HasCount>
+class TSpillingSupportState : public TComputationValue<TSpillingSupportState<Sort, HasCount>> {
+using TBase = TComputationValue<TSpillingSupportState<Sort, HasCount>>;
+private:
+ using TStorage = std::vector<NUdf::TUnboxedValue, TMKQLAllocator<NUdf::TUnboxedValue, EMemorySubPool::Temporary>>;
+ using TFields = std::vector<NUdf::TUnboxedValue*, TMKQLAllocator<NUdf::TUnboxedValue*, EMemorySubPool::Temporary>>;
+ using TPointers = std::vector<NUdf::TUnboxedValuePod*, TMKQLAllocator<NUdf::TUnboxedValuePod*, EMemorySubPool::Temporary>>;
+
+ enum class EOperatingMode {
+ InMemory,
+ Spilling,
+ ProcessSpilled
+ };
+
+ size_t GetStorageSize() const {
+ return std::max<size_t>(Count << 2ULL, 1ULL << 8ULL);
+ }
+
+ void ResetFields() {
+ NUdf::TUnboxedValuePod* ptr;
+ if constexpr (!HasCount) {
+ auto pos = Storage.size();
+ Storage.insert(Storage.end(), Indexes.size(), {});
+ ptr = Storage.data() + pos;
+ }
+
+ std::for_each(Indexes.cbegin(), Indexes.cend(), [&](ui32 index) { Fields[index] = static_cast<NUdf::TUnboxedValue*>(ptr++); });
+ }
+
+public:
    // Spilling-capable Sort state. Only the pure Sort flavour is supported:
    // for the Top/TopSort flavour (HasCount) construction throws.
    // LessFunc wraps the three-way `compare` (with per-key Directions) into a
    // strict "less" predicate: compare(...) < 0.
    TSpillingSupportState(TMemoryUsageInfo* memInfo, ui64 count, const bool* directons, size_t keyWidth, const TCompareFunc& compare,
        const std::vector<ui32>& indexes, IComputationWideFlowNode *const flow, TMultiType* tupleMultiType)
        : TBase(memInfo)
        , Flow(flow)
        , Count(count)
        , Indexes(indexes)
        , Directions(directons, directons + keyWidth)
        , LessFunc(std::bind(std::less<int>(), std::bind(compare, Directions.data(), std::placeholders::_1, std::placeholders::_2), 0))
        , Fields(Indexes.size(), nullptr)
        , TupleMultiType(tupleMultiType)
    {
        if constexpr (!HasCount) {
            // Prepare the first input-row placeholder.
            ResetFields();
            return;
        }
        throw yexception() << "Spilling doesn't support TopSort.";
    }
+
    // Top-level fetch: a small state machine over EOperatingMode.
    // InMemory accumulates input (and may flip to Spilling on memory
    // pressure), Spilling writes accumulated rows to disk one step per call,
    // and ProcessSpilled merge-reads the spilled runs back in sort order.
    virtual EFetchResult DoCalculate(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) {
        while (true) {
            switch(GetMode()) {
                case EOperatingMode::InMemory: {
                    auto r = DoCalculateInMemory(ctx, output);
                    // Only surface the result if the mode did not change;
                    // otherwise loop and continue in the new mode.
                    if (GetMode() == TSpillingSupportState::EOperatingMode::InMemory) {
                        return r;
                    }
                    break;
                }
                case EOperatingMode::Spilling: {
                    DoCalculateWithSpilling(ctx);
                    // Still spilling means async I/O is in flight: yield.
                    if (GetMode() == EOperatingMode::Spilling) {
                        return EFetchResult::Yield;
                    }
                    break;
                }
                case EOperatingMode::ProcessSpilled: {
                    return ProcessSpilledData(output);
                }

            }
        }
        Y_UNREACHABLE();
    }
+
+private:
+
    // Pull rows from the upstream Flow into memory; once the input is
    // finished, emit rows one per call via Extract().
    EFetchResult DoCalculateInMemory(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) {
        while (EFetchResult::Finish != InputStatus) {
            switch (InputStatus = Flow->FetchValues(ctx, GetFields())) {
                case EFetchResult::One:
                    if (Put()) {
                        // Row accepted; if memory is tight, start a spill run.
                        if (!HasMemoryForProcessing()) {
                            SwitchMode(EOperatingMode::Spilling, ctx);
                            return EFetchResult::Yield;
                        }
                    }
                    continue;
                case EFetchResult::Yield:
                    return EFetchResult::Yield;
                case EFetchResult::Finish:
                {
                    // If anything was already spilled, spill the in-memory
                    // remainder too so ALL data is merged from disk.
                    if (!SpilledStates.empty()) {
                        SwitchMode(EOperatingMode::Spilling, ctx);
                        return EFetchResult::Yield;
                    }
                    Seal();
                    break;
                }
            }
        }

        // Copy the next sorted row into the requested output columns;
        // a null output slot means the column is unused downstream.
        if (auto extract = Extract()) {
            for (const auto index : Indexes)
                if (const auto to = output[index])
                    *to = std::move(*extract++);
                else
                    ++extract;
            return EFetchResult::One;
        }

        return EFetchResult::Finish;
    }
+
    // One spilling step. Always yields: either async I/O is still pending
    // (SpillState returned false), or this run finished and we switch mode.
    EFetchResult DoCalculateWithSpilling(TComputationContext& ctx) {
        if (!SpillState()) {
            return EFetchResult::Yield;
        }
        // Run fully written: re-arm the input-row placeholder.
        ResetFields();
        auto nextMode = (IsReadFromChannelFinished() ? EOperatingMode::ProcessSpilled : EOperatingMode::InMemory);
        SwitchMode(nextMode, ctx);
        return EFetchResult::Yield;
    }
+
    // K-way merge over the spilled runs: a heap of run iterators whose
    // minimal element (moved to the back by pop_heap) is the next output row.
    EFetchResult ProcessSpilledData(NUdf::TUnboxedValue*const* output) {
        if (SpilledUnboxedValuesIterators.empty()) {
            return EFetchResult::Finish;
        }

        // Every iterator must have a row loaded before heap operations are
        // valid; yield while any of them still waits on an async read.
        for (auto &spilledUnboxedValuesIterator : SpilledUnboxedValuesIterators) {
            if (!spilledUnboxedValuesIterator.CheckForInit()) {
                return EFetchResult::Yield;
            }
        }
        if (!IsHeapBuilt) {
            std::make_heap(SpilledUnboxedValuesIterators.begin(), SpilledUnboxedValuesIterators.end());
            IsHeapBuilt = true;
        } else {
            // The back iterator got a fresh row last call: re-insert it.
            std::push_heap(SpilledUnboxedValuesIterators.begin(), SpilledUnboxedValuesIterators.end());
        }

        // Move the minimal iterator to the back and emit its current row.
        std::pop_heap(SpilledUnboxedValuesIterators.begin(), SpilledUnboxedValuesIterators.end());
        auto &currentIt = SpilledUnboxedValuesIterators.back();
        NKikimr::NUdf::TUnboxedValue* res = currentIt.GetValue();
        for (const auto index : Indexes)
        {
            if (const auto to = output[index])
                *to = std::move(*res++);
            else
                ++res;
        }
        currentIt.Pop();
        // Drained runs are dropped so the heap shrinks as runs finish.
        if (currentIt.IsFinished()) {
            SpilledUnboxedValuesIterators.pop_back();
        }
        return EFetchResult::One;
    }
+
    // Scratch row slots handed to Flow->FetchValues (laid out per Indexes).
    NUdf::TUnboxedValue*const* GetFields() const {
        return Fields.data();
    }
+
    // Accept the row just fetched into Fields. For the supported Sort
    // flavour every row is kept, so this just allocates the next placeholder.
    // Returns true when the row was materialized into Storage.
    bool Put() {
        if constexpr (!HasCount) {
            ResetFields();
            return true;
        }

        throw yexception() << "Spilling doesn't support TopSort.";
    }
+
    // Sort the accumulated rows. Full ends up ordered so that Full.back()
    // is the FIRST row to emit (Extract pops from the back).
    void Seal() {
        if constexpr (!HasCount) {
            static_assert (Sort);
            // Remove placeholder for new data
            Storage.resize(Storage.size() - Indexes.size());

            // One pointer per row; rows are laid back-to-back in Storage.
            Full.reserve(Storage.size() / Indexes.size());
            for (auto it = Storage.begin(); it != Storage.end(); it += Indexes.size()) {
                Full.emplace_back(&*it);
            }

            // Sorting the reversed range ascending leaves Full descending,
            // i.e. the smallest row at the back.
            std::sort(Full.rbegin(), Full.rend(), LessFunc);
            return;
        }

        throw yexception() << "Spilling doesn't support TopSort.";
    }
+
+ NUdf::TUnboxedValue* Extract() {
+ if (Full.empty())
+ return nullptr;
+
+ const auto ptr = Full.back();
+ Full.pop_back();
+ return static_cast<NUdf::TUnboxedValue*>(ptr);
+ }
+
    EOperatingMode GetMode() const { return Mode; }

    // Memory-pressure probe. Spilling is currently hard-disabled: this
    // always reports "enough memory", so the Spilling path is never entered
    // from memory pressure (only via the commented check below once enabled).
    bool HasMemoryForProcessing() const {
        // TODO: Change to enable spilling
        // return !TlsAllocState->IsMemoryYellowZoneEnabled();
        return true;
    }

    // True once the upstream flow reported Finish.
    bool IsReadFromChannelFinished() const {
        return InputStatus == EFetchResult::Finish;
    }
+
    // Transition the state machine. Entering Spilling creates a fresh
    // spiller (one TSpilledData per spilled run); entering ProcessSpilled
    // wraps every spilled run into a merge iterator.
    void SwitchMode(EOperatingMode mode, TComputationContext& ctx) {
        switch(mode) {
            case EOperatingMode::InMemory:
                break;
            case EOperatingMode::Spilling:
            {
                auto spiller = ctx.SpillerFactory->CreateSpiller();
                const size_t PACK_SIZE = 5_MB; // target serialized chunk size per write
                SpilledStates.emplace_back(std::make_unique<TWideUnboxedValuesSpillerAdapter>(spiller, TupleMultiType, PACK_SIZE));
                break;
            }
            case EOperatingMode::ProcessSpilled:
            {
                SpilledUnboxedValuesIterators.reserve(SpilledStates.size());
                for (auto &state: SpilledStates) {
                    SpilledUnboxedValuesIterators.emplace_back(LessFunc, &state, Indexes.size(), &ctx);
                }
                break;
            }
        }
        Mode = mode;
    }
+
    // One resumable step of dumping the in-memory sorted data into the most
    // recently created spiller. Returns false while async I/O is pending
    // (caller yields and re-enters), true once this run is fully written.
    bool SpillState() {
        MKQL_ENSURE(!SpilledStates.empty(), "At least one Spiller must be created to spill data in Sort operation.");
        auto &lastSpilledState = SpilledStates.back();
        if (lastSpilledState.AsyncWriteOperation.has_value()) {
            // Re-entry: complete the write we yielded on last time.
            if (!lastSpilledState.AsyncWriteOperation->HasValue()) {
                return false;
            }
            lastSpilledState.Spiller->AsyncWriteCompleted(lastSpilledState.AsyncWriteOperation->ExtractValue());
            lastSpilledState.AsyncWriteOperation = std::nullopt;
        } else {
            // First entry for this run: sort what has been accumulated.
            Seal();
            if (Full.empty()) {
                // Nothing to spill
                SpilledStates.pop_back();
                return true;
            }
        }

        // Stream the remaining sorted rows; bail out (yield) whenever a
        // write goes asynchronous -- Extract() has already consumed the row,
        // which the adapter is assumed to have buffered.
        while (auto extract = Extract()) {
            auto writeOp = lastSpilledState.Write(extract, Indexes.size());
            if (writeOp) {
                return false;
            }
        }

        auto writeFinishOp = lastSpilledState.FinishWrite();
        if (writeFinishOp){
            return false;
        }
        // Every row of this run is on disk; release the in-memory buffer.
        Storage.resize(0);

        return true;
    }
+
+ EFetchResult InputStatus = EFetchResult::One;
+ IComputationWideFlowNode *const Flow;
const ui64 Count;
const std::vector<ui32> Indexes;
const std::vector<bool> Directions;
@@ -215,13 +637,18 @@ private:
TStorage Storage;
TPointers Free, Full;
TFields Fields;
+ TMultiType* TupleMultiType;
+ std::vector<TSpilledData> SpilledStates;
+ EOperatingMode Mode = EOperatingMode::InMemory;
+ std::vector<TSpilledUnboxedValuesIterator> SpilledUnboxedValuesIterators;
+ bool IsHeapBuilt = false;
};
#ifndef MKQL_DISABLE_CODEGEN
-template <bool HasCount>
-class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState<HasCount>>> {
+template <bool Sort, bool HasCount>
+class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState<Sort, HasCount>>> {
private:
- using TBase = TLLVMFieldsStructure<TComputationValue<TState<HasCount>>>;
+ using TBase = TLLVMFieldsStructure<TComputationValue<TState<Sort, HasCount>>>;
llvm::IntegerType* ValueType;
llvm::PointerType* PtrValueType;
llvm::IntegerType* StatusType;
@@ -264,9 +691,9 @@ class TWideTopWrapper: public TStatefulWideFlowCodegeneratorNode<TWideTopWrapper
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort, HasCount>>;
public:
TWideTopWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, TComputationNodePtrVector&& directions, std::vector<TKeyInfo>&& keys,
- std::vector<ui32>&& indexes, std::vector<EValueRepresentation>&& representations)
+ std::vector<ui32>&& indexes, std::vector<EValueRepresentation>&& representations, TMultiType* tupleMultiType)
: TBaseComputation(mutables, flow, EValueRepresentation::Boxed), Flow(flow), Count(count), Directions(std::move(directions)), Keys(std::move(keys))
- , Indexes(std::move(indexes)), Representations(std::move(representations))
+ , Indexes(std::move(indexes)), Representations(std::move(representations)), TupleMultiType(tupleMultiType)
{
for (const auto& x : Keys) {
if (x.Compare || x.PresortType) {
@@ -290,33 +717,23 @@ public:
std::vector<bool> dirs(Directions.size());
std::transform(Directions.cbegin(), Directions.cend(), dirs.begin(), [&ctx](IComputationNode* dir){ return dir->GetValue(ctx).Get<bool>(); });
- MakeState(ctx, state, count, dirs.data());
+ if (!ctx.ExecuteLLVM) {
+ MakeSpillingSupportState(ctx, state, count, dirs.data());
+ } else {
+ MakeState(ctx, state, count, dirs.data());
+ }
}
- if (const auto ptr = static_cast<TState<HasCount>*>(state.AsBoxed().Get())) {
- while (EFetchResult::Finish != ptr->InputStatus) {
- switch (ptr->InputStatus = Flow->FetchValues(ctx, ptr->GetFields())) {
- case EFetchResult::One:
- ptr->Put();
- continue;
- case EFetchResult::Yield:
- return EFetchResult::Yield;
- case EFetchResult::Finish:
- ptr->template Seal<Sort>();
- break;
- }
+ // To avoid dynamic_cast implementation in LLVM implementation
+ // This is temporary solution. Final result will have just one state here.
+ if (!ctx.ExecuteLLVM) {
+ if (const auto ptr = static_cast<TSpillingSupportState<Sort, HasCount>*>(state.AsBoxed().Get())) {
+ return ptr->DoCalculate(ctx, output);
}
-
- if (auto extract = ptr->Extract()) {
- for (const auto index : Indexes)
- if (const auto to = output[index])
- *to = std::move(*extract++);
- else
- ++extract;
- return EFetchResult::One;
+ } else {
+ if (const auto ptr = static_cast<TState<Sort, HasCount>*>(state.AsBoxed().Get())) {
+ return ptr->DoCalculate(ctx, output);
}
-
- return EFetchResult::Finish;
}
Y_UNREACHABLE();
@@ -330,7 +747,7 @@ public:
const auto statusType = Type::getInt32Ty(context);
const auto indexType = Type::getInt32Ty(ctx.Codegen.GetContext());
- TLLVMFieldsStructureState<HasCount> stateFields(context);
+ TLLVMFieldsStructureState<Sort, HasCount> stateFields(context);
const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
const auto statePtrType = PointerType::getUnqual(stateType);
@@ -419,7 +836,7 @@ public:
block = rest;
new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), statusPtr, block);
- const auto sealFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::template Seal<Sort>));
+ const auto sealFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<Sort, HasCount>::Seal));
const auto sealType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType()}, false);
const auto sealPtr = CastInst::Create(Instruction::IntToPtr, sealFunc, PointerType::getUnqual(sealType), "seal", block);
CallInst::Create(sealType, sealPtr, {stateArg}, "", block);
@@ -450,7 +867,7 @@ public:
}
- const auto pushFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::Put));
+ const auto pushFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<Sort, HasCount>::Put));
const auto pushType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false);
const auto pushPtr = CastInst::Create(Instruction::IntToPtr, pushFunc, PointerType::getUnqual(pushType), "function", block);
const auto accepted = CallInst::Create(pushType, pushPtr, {stateArg}, "accepted", block);
@@ -490,7 +907,7 @@ public:
const auto good = BasicBlock::Create(context, "good", ctx.Func);
- const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::Extract));
+ const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<Sort, HasCount>::Extract));
const auto extractType = FunctionType::get(outputPtrType, {stateArg->getType()}, false);
const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
const auto out = CallInst::Create(extractType, extractPtr, {stateArg}, "out", block);
@@ -515,12 +932,20 @@ public:
private:
    // Create the classic (non-spilling) state. With codegen available, the
    // LLVM-compiled comparator is preferred when LLVM execution is active.
    void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state, ui64 count, const bool* directions) const {
#ifdef MKQL_DISABLE_CODEGEN
        state = ctx.HolderFactory.Create<TState<Sort, HasCount>>(count, directions, Directions.size(), TMyValueCompare(Keys), Indexes, Flow);
#else
        state = ctx.HolderFactory.Create<TState<Sort, HasCount>>(count, directions, Directions.size(), ctx.ExecuteLLVM && Compare ? TCompareFunc(Compare) : TCompareFunc(TMyValueCompare(Keys)), Indexes, Flow);
#endif
    }
+ void MakeSpillingSupportState(TComputationContext& ctx, NUdf::TUnboxedValue& state, ui64 count, const bool* directions) const {
+ if (Sort && !HasCount && !ctx.ExecuteLLVM) {
+ state = ctx.HolderFactory.Create<TSpillingSupportState<Sort, HasCount>>(count, directions, Directions.size(), TMyValueCompare(Keys), Indexes, Flow, TupleMultiType);
+ return;
+ }
+ state = ctx.HolderFactory.Create<TState<Sort, HasCount>>(count, directions, Directions.size(), TMyValueCompare(Keys), Indexes, Flow);
+ }
+
void RegisterDependencies() const final {
if (const auto flow = this->FlowDependsOn(Flow)) {
if constexpr (HasCount) {
@@ -538,6 +963,7 @@ private:
const std::vector<ui32> Indexes;
const std::vector<EValueRepresentation> Representations;
TKeyTypes KeyTypes;
+ TMultiType* TupleMultiType;
bool HasComplexType = false;
#ifndef MKQL_DISABLE_CODEGEN
@@ -587,10 +1013,14 @@ IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactor
std::unordered_set<ui32> keyIndexes;
std::vector<TKeyInfo> keys(keyWidth);
+ std::vector<TType*> tupleTypes;
+ tupleTypes.reserve(inputWideComponents.size());
+
for (auto i = 0U; i < keyWidth; ++i) {
const auto keyIndex = AS_VALUE(TDataLiteral, callable.GetInput(((i + 1U) << 1U) - offset))->AsValue().Get<ui32>();
indexes[i] = keyIndex;
keyIndexes.emplace(keyIndex);
+ tupleTypes.emplace_back(inputWideComponents[keyIndex]);
bool isTuple;
bool encoded;
@@ -608,6 +1038,7 @@ IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactor
}
}
+
size_t payloadPos = keyWidth;
for (auto i = 0U; i < indexes.size(); ++i) {
if (keyIndexes.contains(i)) {
@@ -615,19 +1046,21 @@ IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactor
}
indexes[payloadPos++] = i;
+ tupleTypes.emplace_back(inputWideComponents[i]);
}
std::vector<EValueRepresentation> representations(inputWideComponents.size());
for (auto i = 0U; i < representations.size(); ++i)
representations[i] = GetValueRepresentation(inputWideComponents[indexes[i]]);
+ auto tupleMultiType = TMultiType::Create(tupleTypes.size(),tupleTypes.data(), ctx.Env);
TComputationNodePtrVector directions(keyWidth);
auto index = 1U - offset;
std::generate(directions.begin(), directions.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++++index); });
if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
return new TWideTopWrapper<Sort, HasCount>(ctx.Mutables, wide, count, std::move(directions), std::move(keys),
- std::move(indexes), std::move(representations));
+ std::move(indexes), std::move(representations), tupleMultiType);
}
THROW yexception() << "Expected wide flow.";