diff options
author | AlexSm <alex@ydb.tech> | 2024-11-28 15:41:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-28 15:41:10 +0100 |
commit | 2f8998014b614a26927adaad429c80717c247058 (patch) | |
tree | fe2b3057b03fc53cd809f96a6dba7a4bd6d7281d /yql/essentials/minikql | |
parent | 6067a04d33e1f48cf6f6712eb49fcac5f651b631 (diff) | |
parent | 2b93e092495c43045e2db1ea6a2abc16e85f8381 (diff) | |
download | ydb-2f8998014b614a26927adaad429c80717c247058.tar.gz |
Merge pull request #12088 from ydb-platform/mergelibs-241128-1021
Library import 241128-1021
Diffstat (limited to 'yql/essentials/minikql')
6 files changed, 223 insertions, 42 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_wide_combine.cpp b/yql/essentials/minikql/comp_nodes/mkql_wide_combine.cpp index 75e5e0dd6f..af9e1d4ce0 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_wide_combine.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_wide_combine.cpp @@ -16,6 +16,9 @@ #include <util/string/cast.h> + +#include <contrib/libs/xxhash/xxhash.h> + namespace NKikimr { namespace NMiniKQL { @@ -315,6 +318,7 @@ public: CurrentPage = &Storage.emplace_back(RowSize() * CountRowsOnPage, NUdf::TUnboxedValuePod()); CurrentPosition = 0; Tongue = CurrentPage->data(); + StoredDataSize = 0; CleanupCurrentContext(); return true; @@ -345,6 +349,7 @@ public: EFetchResult InputStatus = EFetchResult::One; NUdf::TUnboxedValuePod* Tongue = nullptr; NUdf::TUnboxedValuePod* Throat = nullptr; + i64 StoredDataSize = 0; private: std::optional<TStorageIterator> ExtractIt; @@ -485,8 +490,7 @@ public: return isNew ? ETasteResult::Init : ETasteResult::Update; } - auto hash = Hasher(ViewForKeyAndState.data()); - auto bucketId = hash % SpilledBucketCount; + auto bucketId = ChooseBucket(ViewForKeyAndState.data()); auto& bucket = SpilledBuckets[bucketId]; if (bucket.BucketState == TSpilledBucket::EBucketState::InMemory) { @@ -530,7 +534,14 @@ public: return value; } + private: + ui64 ChooseBucket(const NUdf::TUnboxedValuePod *const key) { + auto provided_hash = Hasher(key); + XXH64_hash_t bucket = XXH64(&provided_hash, sizeof(provided_hash), 0) % SpilledBucketCount; + return bucket; + } + EUpdateResult FlushSpillingBuffersAndWait() { UpdateSpillingBuckets(); @@ -593,14 +604,17 @@ private: SplitStateSpillingBucket = -1; } while (const auto keyAndState = static_cast<NUdf::TUnboxedValue *>(InMemoryProcessingState.Extract())) { - auto hash = Hasher(keyAndState); //Hasher uses only key for hashing - auto bucketId = hash % SpilledBucketCount; + auto bucketId = ChooseBucket(keyAndState); // This uses only key for hashing auto& bucket = SpilledBuckets[bucketId]; 
bucket.LineCount++; if (bucket.BucketState != TSpilledBucket::EBucketState::InMemory) { - bucket.BucketState = TSpilledBucket::EBucketState::SpillingState; + if (bucket.BucketState != TSpilledBucket::EBucketState::SpillingState) { + bucket.BucketState = TSpilledBucket::EBucketState::SpillingState; + SpillingBucketsCount++; + } + bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()}); for (size_t i = 0; i < KeyAndStateType->GetElementsCount(); ++i) { //releasing values stored in unsafe TUnboxedValue buffer @@ -629,10 +643,11 @@ private: ui32 bucketNumToSpill = GetLargestInMemoryBucketNumber(); SplitStateSpillingBucket = bucketNumToSpill; - InMemoryBucketsCount--; auto& bucket = SpilledBuckets[bucketNumToSpill]; bucket.BucketState = TSpilledBucket::EBucketState::SpillingState; + SpillingBucketsCount++; + InMemoryBucketsCount--; while (const auto keyAndState = static_cast<NUdf::TUnboxedValue*>(bucket.InMemoryProcessingState->Extract())) { bucket.AsyncWriteOperation = bucket.SpilledState->WriteWideItem({keyAndState, KeyAndStateType->GetElementsCount()}); @@ -662,6 +677,7 @@ private: bucket.InMemoryProcessingState->ReadMore<false>(); bucket.BucketState = TSpilledBucket::EBucketState::SpillingData; + SpillingBucketsCount--; } } @@ -846,6 +862,12 @@ private: YQL_LOG(INFO) << "switching Memory mode to ProcessSpilled"; MKQL_ENSURE(EOperatingMode::Spilling == Mode, "Internal logic error"); MKQL_ENSURE(SpilledBuckets.size() == SpilledBucketCount, "Internal logic error"); + + std::sort(SpilledBuckets.begin(), SpilledBuckets.end(), [](const TSpilledBucket& lhs, const TSpilledBucket& rhs) { + bool lhs_in_memory = lhs.BucketState == TSpilledBucket::EBucketState::InMemory; + bool rhs_in_memory = rhs.BucketState == TSpilledBucket::EBucketState::InMemory; + return lhs_in_memory > rhs_in_memory; + }); break; } @@ -904,6 +926,7 @@ private: llvm::IntegerType* ValueType; llvm::PointerType* PtrValueType; llvm::IntegerType* 
StatusType; + llvm::IntegerType* StoredType; protected: using TBase::Context; public: @@ -912,6 +935,7 @@ public: result.emplace_back(StatusType); //status result.emplace_back(PtrValueType); //tongue result.emplace_back(PtrValueType); //throat + result.emplace_back(StoredType); //StoredDataSize result.emplace_back(Type::getInt32Ty(Context)); //size result.emplace_back(Type::getInt32Ty(Context)); //size return result; @@ -929,11 +953,16 @@ public: return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 2); } + llvm::Constant* GetStored() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3); + } + TLLVMFieldsStructureState(llvm::LLVMContext& context) : TBase(context) , ValueType(Type::getInt128Ty(Context)) , PtrValueType(PointerType::getUnqual(ValueType)) - , StatusType(Type::getInt32Ty(Context)) { + , StatusType(Type::getInt32Ty(Context)) + , StoredType(Type::getInt64Ty(Context)) { } }; @@ -988,6 +1017,10 @@ public: ptr->InputStatus = Flow->FetchValues(ctx, fields); if constexpr (SkipYields) { if (EFetchResult::Yield == ptr->InputStatus) { + if (MemLimit) { + const auto currentUsage = ctx.HolderFactory.GetMemoryUsed(); + ptr->StoredDataSize += currentUsage > initUsage ? currentUsage - initUsage : 0; + } return EFetchResult::Yield; } else if (EFetchResult::Finish == ptr->InputStatus) { break; @@ -1000,7 +1033,7 @@ public: Nodes.ExtractKey(ctx, fields, static_cast<NUdf::TUnboxedValue*>(ptr->Tongue)); Nodes.ProcessItem(ctx, ptr->TasteIt() ? 
nullptr : static_cast<NUdf::TUnboxedValue*>(ptr->Tongue), static_cast<NUdf::TUnboxedValue*>(ptr->Throat)); - } while (!ctx.template CheckAdjustedMemLimit<TrackRss>(MemLimit, initUsage)); + } while (!ctx.template CheckAdjustedMemLimit<TrackRss>(MemLimit, initUsage - ptr->StoredDataSize)); ptr->PushStat(ctx.Stats); } @@ -1019,6 +1052,7 @@ public: const auto valueType = Type::getInt128Ty(context); const auto ptrValueType = PointerType::getUnqual(valueType); const auto statusType = Type::getInt32Ty(context); + const auto storedType = Type::getInt64Ty(context); TLLVMFieldsStructureState stateFields(context); const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); @@ -1113,6 +1147,28 @@ public: block = save; + if (MemLimit) { + const auto storedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStored() }, "stored_ptr", block); + const auto lastStored = new LoadInst(storedType, storedPtr, "last_stored", block); + const auto currentUsage = GetMemoryUsed(MemLimit, ctx, block); + + + const auto skipSavingUsed = BasicBlock::Create(context, "skip_saving_used", ctx.Func); + const auto saveUsed = BasicBlock::Create(context, "save_used", ctx.Func); + const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, currentUsage, used, "check", block); + BranchInst::Create(saveUsed, skipSavingUsed, check, block); + + block = saveUsed; + + const auto usedMemory = BinaryOperator::CreateSub(GetMemoryUsed(MemLimit, ctx, block), used, "used_memory", block); + const auto inc = BinaryOperator::CreateAdd(lastStored, usedMemory, "inc", block); + new StoreInst(inc, storedPtr, block); + + BranchInst::Create(skipSavingUsed, block); + + block = skipSavingUsed; + } + new StoreInst(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), statusPtr, block); result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block); BranchInst::Create(over, block); @@ -1249,7 +1305,14 @@ 
public: block = test; - const auto check = CheckAdjustedMemLimit<TrackRss>(MemLimit, used, ctx, block); + auto totalUsed = used; + if (MemLimit) { + const auto storedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetStored() }, "stored_ptr", block); + const auto lastStored = new LoadInst(storedType, storedPtr, "last_stored", block); + totalUsed = BinaryOperator::CreateSub(used, lastStored, "decr", block); + } + + const auto check = CheckAdjustedMemLimit<TrackRss>(MemLimit, totalUsed, ctx, block); BranchInst::Create(done, loop, check, block); block = done; diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_wide_combine_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_wide_combine_ut.cpp index 904f38ac67..55cb6babbf 100644 --- a/yql/essentials/minikql/comp_nodes/ut/mkql_wide_combine_ut.cpp +++ b/yql/essentials/minikql/comp_nodes/ut/mkql_wide_combine_ut.cpp @@ -15,8 +15,13 @@ namespace NMiniKQL { namespace { constexpr auto border = 9124596000000000ULL; -constexpr ui64 g_Yield = std::numeric_limits<ui64>::max(); -constexpr ui64 g_TestYieldStreamData[] = {0, 1, 0, 2, g_Yield, 0, g_Yield, 1, 2, 0, 1, 3, 0, g_Yield, 1, 2}; + +struct TTestStreamParams { + static constexpr ui64 Yield = std::numeric_limits<ui64>::max(); + + ui64 StringSize = 1; + std::vector<ui64> TestYieldStreamData; +}; class TTestStreamWrapper: public TMutableComputationNode<TTestStreamWrapper> { using TBaseComputation = TMutableComputationNode<TTestStreamWrapper>; @@ -25,19 +30,19 @@ public: public: using TBase = TComputationValue<TStreamValue>; - TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx) - : TBase(memInfo), CompCtx(compCtx) + TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, TTestStreamParams& params) + : TBase(memInfo), CompCtx(compCtx), Params(params) {} private: NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override { - constexpr auto size = Y_ARRAY_SIZE(g_TestYieldStreamData); + auto size = 
Params.TestYieldStreamData.size(); if (Index == size) { return NUdf::EFetchStatus::Finish; } - const auto val = g_TestYieldStreamData[Index]; - if (g_Yield == val) { + const auto val = Params.TestYieldStreamData[Index]; + if (Params.Yield == val) { ++Index; return NUdf::EFetchStatus::Yield; } @@ -45,7 +50,7 @@ public: NUdf::TUnboxedValue* items = nullptr; result = CompCtx.HolderFactory.CreateDirectArrayHolder(2, items); items[0] = NUdf::TUnboxedValuePod(val); - items[1] = NUdf::TUnboxedValuePod(MakeString(ToString(val))); + items[1] = NUdf::TUnboxedValuePod(MakeString(ToString(val) * Params.StringSize)); ++Index; @@ -55,27 +60,31 @@ public: private: TComputationContext& CompCtx; ui64 Index = 0; + TTestStreamParams& Params; }; - TTestStreamWrapper(TComputationMutables& mutables) + TTestStreamWrapper(TComputationMutables& mutables, TTestStreamParams& params) : TBaseComputation(mutables) + , Params(params) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - return ctx.HolderFactory.Create<TStreamValue>(ctx); + return ctx.HolderFactory.Create<TStreamValue>(ctx, Params); } private: void RegisterDependencies() const final {} + + TTestStreamParams& Params; }; -IComputationNode* WrapTestStream(const TComputationNodeFactoryContext& ctx) { - return new TTestStreamWrapper(ctx.Mutables); +IComputationNode* WrapTestStream(const TComputationNodeFactoryContext& ctx, TTestStreamParams& params) { + return new TTestStreamWrapper(ctx.Mutables, params); } -TComputationNodeFactory GetNodeFactory() { - return [](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* { +TComputationNodeFactory GetNodeFactory(TTestStreamParams& params) { + return [&params](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* { if (callable.GetType()->GetName() == "TestYieldStream") { - return WrapTestStream(ctx); + return WrapTestStream(ctx, params); } return GetBuiltinFactory()(callable, ctx); }; @@ -456,7 +465,9 @@ 
Y_UNIT_TEST_SUITE(TMiniKQLWideCombinerTest) { } #if !defined(MKQL_RUNTIME_VERSION) || MKQL_RUNTIME_VERSION >= 46u Y_UNIT_TEST_LLVM(TestHasLimitButPasstroughtYields) { - TSetup<LLVM> setup(GetNodeFactory()); + TTestStreamParams params; + params.TestYieldStreamData = {0, 1, 0, 2, TTestStreamParams::Yield, 0, TTestStreamParams::Yield, 1, 2, 0, 1, 3, 0, TTestStreamParams::Yield, 1, 2}; + TSetup<LLVM> setup(GetNodeFactory(params)); TProgramBuilder& pb = *setup.PgmBuilder; const auto stream = MakeStream<LLVM>(setup); @@ -486,6 +497,40 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideCombinerTest) { UNIT_ASSERT_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Finish); } #endif +#if !defined(MKQL_RUNTIME_VERSION) || MKQL_RUNTIME_VERSION >= 46u + Y_UNIT_TEST_LLVM(TestSkipYieldRespectsMemLimit) { + TTestStreamParams params; + params.StringSize = 50000; + params.TestYieldStreamData = {0, TTestStreamParams::Yield, 2, TTestStreamParams::Yield, 3, TTestStreamParams::Yield, 4}; + TSetup<LLVM> setup(GetNodeFactory(params)); + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto stream = MakeStream<LLVM>(setup); + const auto pgmReturn = pb.FromFlow(pb.NarrowMap(pb.WideCombiner(pb.ExpandMap(pb.ToFlow(stream), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Member(item, "a"), pb.Member(item, "b")}; }), -100000LL, + [&](TRuntimeNode::TList items) -> TRuntimeNode::TList { return {items.front()}; }, + [&](TRuntimeNode::TList, TRuntimeNode::TList items) -> TRuntimeNode::TList { return items; }, + [&](TRuntimeNode::TList, TRuntimeNode::TList items, TRuntimeNode::TList state) -> TRuntimeNode::TList { return {state.front(), pb.AggrConcat(state.back(), items.back())}; }, + [&](TRuntimeNode::TList, TRuntimeNode::TList state) -> TRuntimeNode::TList { return state; }), + [&](TRuntimeNode::TList items) { return items.back(); } + )); + const auto graph = setup.BuildGraph(pgmReturn); + const auto streamVal = graph->GetValue(); + NUdf::TUnboxedValue result; + + // skip first 2 yields + 
UNIT_ASSERT_VALUES_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Yield); + UNIT_ASSERT_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Yield); + // return all the collected values + UNIT_ASSERT_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Ok); + UNIT_ASSERT_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Ok); + UNIT_ASSERT_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Ok); + UNIT_ASSERT_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Yield); + UNIT_ASSERT_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Ok); + UNIT_ASSERT_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Finish); + UNIT_ASSERT_EQUAL(streamVal.Fetch(result), NUdf::EFetchStatus::Finish); + } +#endif } Y_UNIT_TEST_SUITE(TMiniKQLWideCombinerPerfTest) { @@ -1351,6 +1396,63 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) { const auto fetchStatus = streamVal.Fetch(item); UNIT_ASSERT_EQUAL(fetchStatus, NUdf::EFetchStatus::Finish); } + + Y_UNIT_TEST_LLVM(TestSpillingBucketsDistribution) { + const size_t expectedBucketsCount = 128; + const size_t sampleSize = 8 * 128; + + TSetup<LLVM, true> setup; + + std::vector<std::pair<ui64, ui64>> samples(sampleSize); + std::generate(samples.begin(), samples.end(), [key = (ui64)1] () mutable -> std::pair<ui64, ui64> { + key += 64; + return {key, 1}; + }); + + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto listType = pb.NewListType(pb.NewTupleType({pb.NewDataType(NUdf::TDataType<ui64>::Id), pb.NewDataType(NUdf::TDataType<ui64>::Id)})); + const auto list = TCallableBuilder(pb.GetTypeEnvironment(), "TestList", listType).Build(); + + const auto pgmReturn = pb.FromFlow(pb.NarrowMap(pb.WideLastCombinerWithSpilling(pb.ExpandMap(pb.ToFlow(TRuntimeNode(list, false)), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return { pb.Nth(item, 0U), pb.Nth(item, 1U) }; }), + [&](TRuntimeNode::TList items) -> TRuntimeNode::TList { return {items.front()}; }, + [&](TRuntimeNode::TList, TRuntimeNode::TList items) -> TRuntimeNode::TList { return 
{items.back()}; }, + [&](TRuntimeNode::TList, TRuntimeNode::TList items, TRuntimeNode::TList state) -> TRuntimeNode::TList { return {pb.AggrAdd(state.front(), items.back())}; }, + [&](TRuntimeNode::TList keys, TRuntimeNode::TList state) -> TRuntimeNode::TList { return {keys.front(), state.front()}; }), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(items); } + )); + + const auto spillerFactory = std::make_shared<TMockSpillerFactory>(); + const auto graph = setup.BuildGraph(pgmReturn, {list}); + graph->GetContext().SpillerFactory = spillerFactory; + + NUdf::TUnboxedValue* items = nullptr; + graph->GetEntryPoint(0, true)->SetValue(graph->GetContext(), graph->GetHolderFactory().CreateDirectArrayHolder(samples.size(), items)); + for (const auto& sample : samples) { + NUdf::TUnboxedValue* pair = nullptr; + *items++ = graph->GetHolderFactory().CreateDirectArrayHolder(2U, pair); + pair[0] = NUdf::TUnboxedValuePod(sample.first); + pair[1] = NUdf::TUnboxedValuePod(sample.second); + } + + const auto& value = graph->GetValue(); + + NUdf::TUnboxedValue item; + while (value.Fetch(item) != NUdf::EFetchStatus::Finish) { + ; + } + + UNIT_ASSERT_EQUAL_C(spillerFactory->GetCreatedSpillers().size(), 1, "WideLastCombiner expected to create one spiller "); + const auto wideCombinerSpiller = std::dynamic_pointer_cast<TMockSpiller>(spillerFactory->GetCreatedSpillers()[0]); + UNIT_ASSERT_C(wideCombinerSpiller != nullptr, "MockSpillerFactory expected to create only MockSpillers"); + + auto flushedBucketsSizes = wideCombinerSpiller->GetPutSizes(); + UNIT_ASSERT_EQUAL_C(flushedBucketsSizes.size(), expectedBucketsCount, "Spiller doesn't Put expected number of buckets"); + + auto anyEmpty = std::any_of(flushedBucketsSizes.begin(), flushedBucketsSizes.end(), [](size_t size) { return size == 0; }); + UNIT_ASSERT_C(!anyEmpty, "Spiller flushed empty bucket"); + } } Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerPerfTest) { diff --git 
a/yql/essentials/minikql/comp_nodes/ya.make.inc b/yql/essentials/minikql/comp_nodes/ya.make.inc index 3c1531eac7..518f5f3c43 100644 --- a/yql/essentials/minikql/comp_nodes/ya.make.inc +++ b/yql/essentials/minikql/comp_nodes/ya.make.inc @@ -143,7 +143,7 @@ COPY( AUTO FROM ${ORIG_SRC_DIR} ${ORIG_SOURCES} - OUTPUT_INCLUDES + OUTPUT_INCLUDES ${BINDIR}/yql/essentials/minikql/computation/mkql_computation_node_codegen.h ${BINDIR}/yql/essentials/minikql/computation/mkql_block_impl_codegen.h ${BINDIR}/yql/essentials/minikql/computation/mkql_llvm_base.h @@ -154,6 +154,7 @@ COPY( PEERDIR( contrib/libs/apache/arrow + contrib/libs/xxhash yql/essentials/types/binary_json yql/essentials/minikql yql/essentials/minikql/arrow diff --git a/yql/essentials/minikql/computation/mkql_spiller_adapter.h b/yql/essentials/minikql/computation/mkql_spiller_adapter.h index 462a2a7c5e..8ddcfe46be 100644 --- a/yql/essentials/minikql/computation/mkql_spiller_adapter.h +++ b/yql/essentials/minikql/computation/mkql_spiller_adapter.h @@ -28,12 +28,12 @@ public: /// In this case a caller must wait operation completion and call StoreCompleted. 
/// Design note: not using Subscribe on a Future here to avoid possible race condition std::optional<NThreading::TFuture<ISpiller::TKey>> WriteWideItem(const TArrayRef<NUdf::TUnboxedValuePod>& wideItem) { - Packer.AddWideItem(wideItem.data(), wideItem.size()); - if(Packer.PackedSizeEstimate() > SizeLimit) { - return Spiller->Put(std::move(Packer.Finish())); - } else { + Packer.AddWideItem(wideItem.data(), wideItem.size()); + if (Packer.PackedSizeEstimate() > SizeLimit) { + return Spiller->Put(std::move(Packer.Finish())); + } else { return std::nullopt; - } + } } std::optional<NThreading::TFuture<ISpiller::TKey>> FinishWriting() { diff --git a/yql/essentials/minikql/computation/mock_spiller_factory_ut.h b/yql/essentials/minikql/computation/mock_spiller_factory_ut.h index c053b2c52e..4b0b2ed24a 100644 --- a/yql/essentials/minikql/computation/mock_spiller_factory_ut.h +++ b/yql/essentials/minikql/computation/mock_spiller_factory_ut.h @@ -12,8 +12,17 @@ public: } ISpiller::TPtr CreateSpiller() override { - return CreateMockSpiller(); + auto new_spiller = CreateMockSpiller(); + Spillers_.push_back(new_spiller); + return new_spiller; } + + const std::vector<ISpiller::TPtr>& GetCreatedSpillers() const { + return Spillers_; + } + +private: + std::vector<ISpiller::TPtr> Spillers_; }; } // namespace NKikimr::NMiniKQL diff --git a/yql/essentials/minikql/computation/mock_spiller_ut.h b/yql/essentials/minikql/computation/mock_spiller_ut.h index 42846eab1f..715018f3e0 100644 --- a/yql/essentials/minikql/computation/mock_spiller_ut.h +++ b/yql/essentials/minikql/computation/mock_spiller_ut.h @@ -11,22 +11,23 @@ namespace NKikimr::NMiniKQL { class TMockSpiller: public ISpiller{ public: TMockSpiller() - : NextKey(0) + : NextKey_(0) {} NThreading::TFuture<TKey> Put(NYql::TChunkedBuffer&& blob) override { auto promise = NThreading::NewPromise<ISpiller::TKey>(); - auto key = NextKey; - Storage[key] = std::move(blob); - NextKey++; + auto key = NextKey_; + Storage_[key] = 
std::move(blob); + PutSizes_.push_back(Storage_[key].Size()); + NextKey_++; promise.SetValue(key); return promise.GetFuture();; } NThreading::TFuture<std::optional<NYql::TChunkedBuffer>> Get(TKey key) override { auto promise = NThreading::NewPromise<std::optional<NYql::TChunkedBuffer>>(); - if (auto it = Storage.find(key); it != Storage.end()) { + if (auto it = Storage_.find(key); it != Storage_.end()) { promise.SetValue(it->second); } else { promise.SetValue(std::nullopt); @@ -37,9 +38,9 @@ public: NThreading::TFuture<std::optional<NYql::TChunkedBuffer>> Extract(TKey key) override { auto promise = NThreading::NewPromise<std::optional<NYql::TChunkedBuffer>>(); - if (auto it = Storage.find(key); it != Storage.end()) { + if (auto it = Storage_.find(key); it != Storage_.end()) { promise.SetValue(std::move(it->second)); - Storage.erase(it); + Storage_.erase(it); } else { promise.SetValue(std::nullopt); } @@ -49,12 +50,17 @@ public: NThreading::TFuture<void> Delete(TKey key) override { auto promise = NThreading::NewPromise<void>(); promise.SetValue(); - Storage.erase(key); + Storage_.erase(key); return promise.GetFuture(); } + + const std::vector<size_t>& GetPutSizes() const { + return PutSizes_; + } private: - ISpiller::TKey NextKey; - std::unordered_map<ISpiller::TKey, NYql::TChunkedBuffer> Storage; + ISpiller::TKey NextKey_; + std::unordered_map<ISpiller::TKey, NYql::TChunkedBuffer> Storage_; + std::vector<size_t> PutSizes_; }; inline ISpiller::TPtr CreateMockSpiller() { return std::make_shared<TMockSpiller>(); |