diff options
author | vvvv <vvvv@ydb.tech> | 2023-02-20 18:58:57 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2023-02-20 18:58:57 +0300 |
commit | caced409f95cbd73bbb5acd5f6ca0eb85ee282ab (patch) | |
tree | e4c1e77a69af550b6914fabadc670667abffa453 | |
parent | a181588e011454f9b95afde48fd5a71f7513461b (diff) | |
download | ydb-caced409f95cbd73bbb5acd5f6ca0eb85ee282ab.tar.gz |
mkql part of WideSort
8 files changed, 238 insertions, 63 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index 21400a0e14..1662fed7b8 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -324,6 +324,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"WideChopper", &WrapWideChopper}, {"WideTop", &WrapWideTop}, {"WideTopSort", &WrapWideTopSort}, + {"WideSort", &WrapWideSort}, {"WideFlowArg", &WrapWideFlowArg}, {"Source", &WrapSource}, {"RangeCreate", &WrapRangeCreate}, diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp index 33848acde6..a7e5f4b0ca 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp @@ -29,10 +29,11 @@ struct TMyValueCompare { using TComparePtr = int(*)(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*); using TCompareFunc = std::function<int(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)>; -class TState : public TComputationValue<TState> { -using TBase = TComputationValue<TState>; +template <bool HasCount> +class TState : public TComputationValue<TState<HasCount>> { +using TBase = TComputationValue<TState<HasCount>>; public: -using TLLVMBase = TLLVMFieldsStructure<TComputationValue<TState>>; +using TLLVMBase = TLLVMFieldsStructure<TComputationValue<TState<HasCount>>>; private: using TStorage = std::vector<NUdf::TUnboxedValue, TMKQLAllocator<NUdf::TUnboxedValue, EMemorySubPool::Temporary>>; using TFields = std::vector<NUdf::TUnboxedValue*, TMKQLAllocator<NUdf::TUnboxedValue*, EMemorySubPool::Temporary>>; @@ -43,15 +44,30 @@ private: } void ResetFields() { - auto ptr = Tongue = Free.back(); + NUdf::TUnboxedValuePod* ptr; + if constexpr (HasCount) { + ptr = Tongue = Free.back(); + } else { + auto pos = Storage.size(); + Storage.insert(Storage.end(), Indexes.size(), {}); + ptr = Tongue = Storage.data() + pos; + } + std::for_each(Indexes.cbegin(), Indexes.cend(), [&](ui32 index) { Fields[index] = static_cast<NUdf::TUnboxedValue*>(ptr++); }); } public: TState(TMemoryUsageInfo* memInfo, ui64 count, const bool* directons, size_t keyWidth, const TCompareFunc& compare, const std::vector<ui32>& indexes) : TBase(memInfo), Count(count), Indexes(indexes), Directions(directons, directons + keyWidth) , LessFunc(std::bind(std::less<int>(), std::bind(compare, Directions.data(), std::placeholders::_1, std::placeholders::_2), 0)) - , Storage(GetStorageSize() * Indexes.size()), Free(GetStorageSize(), nullptr), Fields(Indexes.size(), nullptr) + , Fields(Indexes.size(), nullptr) { + if constexpr (!HasCount) { + ResetFields(); + return; + } + + Storage.resize(GetStorageSize() * Indexes.size()); + Free.resize(GetStorageSize(), nullptr); if (Count) { Full.reserve(GetStorageSize()); auto ptr = Storage.data(); @@ -70,6 +86,11 @@ public: } bool Push() { + if constexpr (!HasCount) { + ResetFields(); + return true; + } + if (Full.size() + 1U == GetStorageSize()) { Free.pop_back(); @@ -100,6 +121,18 @@ public: template<bool Sort> void Seal() { + if constexpr (!HasCount) { + static_assert (Sort); + Storage.resize(Storage.size() - Indexes.size()); + Full.reserve(Storage.size() / Indexes.size()); + for (auto it = Storage.begin(); it != Storage.end(); it += Indexes.size()) { + Full.emplace_back(&*it); + } + + std::sort(Full.rbegin(), Full.rend(), LessFunc); + return; + } + Free.clear(); Free.shrink_to_fit(); @@ -135,9 +168,10 @@ private: }; #ifndef MKQL_DISABLE_CODEGEN -class TLLVMFieldsStructureState: public TState::TLLVMBase { +template <bool HasCount> +class TLLVMFieldsStructureState: public TState<HasCount>::TLLVMBase { private: - using TBase = TState::TLLVMBase; + using TBase = typename TState<HasCount>::TLLVMBase; llvm::IntegerType* ValueType; llvm::PointerType* PtrValueType; llvm::IntegerType* StatusType; @@ -171,13 +205,13 @@ public: }; #endif -template<bool Sort> -class TWideTopWrapper: public TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort>> +template<bool Sort, bool HasCount> +class TWideTopWrapper: public TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort, HasCount>> #ifndef MKQL_DISABLE_CODEGEN , public ICodegeneratorRootNode #endif { -using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort>>; +using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort, HasCount>>; public: TWideTopWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, TComputationNodePtrVector&& directions, TKeyTypes&& keyTypes, std::vector<ui32>&& indexes, std::vector<EValueRepresentation>&& representations) : TBaseComputation(mutables, flow, EValueRepresentation::Boxed), Flow(flow), Count(count), Directions(std::move(directions)), KeyTypes(std::move(keyTypes)), Indexes(std::move(indexes)), Representations(std::move(representations)) @@ -185,13 +219,19 @@ public: EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { if (!state.HasValue()) { - const auto count = Count->GetValue(ctx).Get<ui64>(); + ui64 count; + if constexpr (HasCount) { + count = Count->GetValue(ctx).Get<ui64>(); + } else { + count = 0; + } + std::vector<bool> dirs(Directions.size()); std::transform(Directions.cbegin(), Directions.cend(), dirs.begin(), [&ctx](IComputationNode* dir){ return dir->GetValue(ctx).Get<bool>(); }); MakeState(ctx, state, count, dirs.data()); } - if (const auto ptr = static_cast<TState*>(state.AsBoxed().Get())) { + if (const auto ptr = static_cast<TState<HasCount>*>(state.AsBoxed().Get())) { while (EFetchResult::Finish != ptr->InputStatus) { switch (ptr->InputStatus = Flow->FetchValues(ctx, ptr->GetFields())) { case EFetchResult::One: @@ -200,7 +240,7 @@ public: case EFetchResult::Yield: return EFetchResult::Yield; case EFetchResult::Finish: - ptr->Seal<Sort>(); + ptr->template Seal<Sort>(); break; } } @@ -230,7 +270,7 @@ public: const auto statusType = Type::getInt32Ty(context); const auto indexType = Type::getInt32Ty(ctx.Codegen->GetContext()); - TLLVMFieldsStructureState stateFields(context); + TLLVMFieldsStructureState<HasCount> stateFields(context); const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); const auto statePtrType = PointerType::getUnqual(stateType); @@ -255,8 +295,13 @@ public: BranchInst::Create(main, make, HasValue(statePtr, block), block); block = make; - const auto count = GetNodeValue(Count, ctx, block); - const auto trunc = GetterFor<ui64>(count, context, block); + llvm::Value* trunc; + if constexpr (HasCount) { + const auto count = GetNodeValue(Count, ctx, block); + trunc = GetterFor<ui64>(count, context, block); + } else { + trunc = ConstantInt::get(Type::getInt64Ty(context), 0U); + } const auto dirs = new AllocaInst(ArrayType::get(Type::getInt1Ty(context), Directions.size()), 0U, "dirs", block); for (auto i = 0U; i < Directions.size(); ++i) { @@ -311,7 +356,7 @@ public: block = rest; new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), statusPtr, block); - const auto sealFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Seal<Sort>)); + const auto sealFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::template Seal<Sort>)); const auto sealType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType()}, false); const auto sealPtr = CastInst::Create(Instruction::IntToPtr, sealFunc, PointerType::getUnqual(sealType), "seal", block); CallInst::Create(sealPtr, {stateArg}, "", block); @@ -328,40 +373,52 @@ public: placeholders[i] = GetElementPtrInst::CreateInBounds(tongue, {ConstantInt::get(indexType, i)}, (TString("placeholder_") += ToString(i)).c_str(), block); } - for (auto i = 0U; i < KeyTypes.size(); ++i) { - const auto item = getres.second[Indexes[i]](ctx, block); - new StoreInst(item, placeholders[i], block); + if constexpr (!HasCount) { + for (auto i = 0; i < Representations.size(); ++i) { + const auto item = getres.second[Indexes[i]](ctx, block); + ValueAddRef(Representations[i], item, ctx, block); + new StoreInst(item, placeholders[i], block); + } + + } else { + for (auto i = 0U; i < KeyTypes.size(); ++i) { + const auto item = getres.second[Indexes[i]](ctx, block); + new StoreInst(item, placeholders[i], block); + } } - const auto pushFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Push)); + + const auto pushFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::Push)); const auto pushType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false); const auto pushPtr = CastInst::Create(Instruction::IntToPtr, pushFunc, PointerType::getUnqual(pushType), "function", block); const auto accepted = CallInst::Create(pushPtr, {stateArg}, "accepted", block); + if constexpr (HasCount) { + const auto push = BasicBlock::Create(context, "push", ctx.Func); + const auto skip = BasicBlock::Create(context, "skip", ctx.Func); - const auto push = BasicBlock::Create(context, "push", ctx.Func); - const auto skip = BasicBlock::Create(context, "skip", ctx.Func); - - BranchInst::Create(push, skip, accepted, block); + BranchInst::Create(push, skip, accepted, block); - block = push; + block = push; - for (auto i = 0U; i < KeyTypes.size(); ++i) { - ValueAddRef(Representations[i], placeholders[i], ctx, block); - } + for (auto i = 0U; i < KeyTypes.size(); ++i) { + ValueAddRef(Representations[i], placeholders[i], ctx, block); + } - for (auto i = KeyTypes.size(); i < Representations.size(); ++i) { - const auto item = getres.second[Indexes[i]](ctx, block); - ValueAddRef(Representations[i], item, ctx, block); - new StoreInst(item, placeholders[i], block); - } + for (auto i = KeyTypes.size(); i < Representations.size(); ++i) { + const auto item = getres.second[Indexes[i]](ctx, block); + ValueAddRef(Representations[i], item, ctx, block); + new StoreInst(item, placeholders[i], block); + } + - BranchInst::Create(loop, block); + BranchInst::Create(loop, block); - block = skip; + block = skip; - for (auto i = 0U; i < KeyTypes.size(); ++i) { - ValueCleanup(Representations[i], placeholders[i], ctx, block); - new StoreInst(ConstantInt::get(valueType, 0), placeholders[i], block); + for (auto i = 0U; i < KeyTypes.size(); ++i) { + ValueCleanup(Representations[i], placeholders[i], ctx, block); + new StoreInst(ConstantInt::get(valueType, 0), placeholders[i], block); + } } BranchInst::Create(loop, block); @@ -372,7 +429,7 @@ public: const auto good = BasicBlock::Create(context, "good", ctx.Func); - const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract)); + const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::Extract)); const auto extractType = FunctionType::get(outputPtrType, {stateArg->getType()}, false); const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block); const auto out = CallInst::Create(extractPtr, {stateArg}, "out", block); @@ -397,15 +454,18 @@ public: private: void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state, ui64 count, const bool* directions) const { #ifdef MKQL_DISABLE_CODEGEN - state = ctx.HolderFactory.Create<TState>(count, directions, Directions.size(), TMyValueCompare(KeyTypes), Indexes); + state = ctx.HolderFactory.Create<TState<HasCount>>(count, directions, Directions.size(), TMyValueCompare(KeyTypes), Indexes); #else - state = ctx.HolderFactory.Create<TState>(count, directions, Directions.size(), ctx.ExecuteLLVM && Compare ? TCompareFunc(Compare) : TCompareFunc(TMyValueCompare(KeyTypes)), Indexes); + state = ctx.HolderFactory.Create<TState<HasCount>>(count, directions, Directions.size(), ctx.ExecuteLLVM && Compare ? TCompareFunc(Compare) : TCompareFunc(TMyValueCompare(KeyTypes)), Indexes); #endif } void RegisterDependencies() const final { if (const auto flow = this->FlowDependsOn(Flow)) { - TWideTopWrapper::DependsOn(flow, Count); + if constexpr (HasCount) { + TWideTopWrapper::DependsOn(flow, Count); + } + std::for_each(Directions.cbegin(), Directions.cend(), std::bind(&TWideTopWrapper::DependsOn, flow, std::placeholders::_1)); } } @@ -441,20 +501,26 @@ private: } -template<bool Sort> +template<bool Sort, bool HasCount> IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - MKQL_ENSURE(callable.GetInputsCount() > 2U && !(callable.GetInputsCount() % 2U), "Expected more arguments."); + const ui32 offset = HasCount ? 0 : 1; + const ui32 inputsWithCount = callable.GetInputsCount() + offset; + MKQL_ENSURE(inputsWithCount > 2U && !(inputsWithCount % 2U), "Expected more arguments."); const auto flow = LocateNode(ctx.NodeLocator, callable, 0); - const auto count = LocateNode(ctx.NodeLocator, callable, 1); - const auto keyWidth = (callable.GetInputsCount() >> 1U) - 1U; + IComputationNode* count = nullptr; + if (HasCount) { + count = LocateNode(ctx.NodeLocator, callable, 1); + } + + const auto keyWidth = (inputsWithCount >> 1U) - 1U; const auto inputType = AS_TYPE(TTupleType, AS_TYPE(TFlowType, callable.GetType()->GetReturnType())->GetItemType()); std::vector<ui32> indexes(inputType->GetElementsCount()); std::iota(indexes.begin(), indexes.end(), 0U); TKeyTypes keyTypes(keyWidth); for (auto i = 0U; i < keyTypes.size(); ++i) { - const auto keyIndex = AS_VALUE(TDataLiteral, callable.GetInput((i + 1U) << 1U))->AsValue().Get<ui32>(); + const auto keyIndex = AS_VALUE(TDataLiteral, callable.GetInput(((i + 1U) << 1U) - offset))->AsValue().Get<ui32>(); std::swap(indexes[i], indexes[indexes[keyIndex]]); keyTypes[i].first = *UnpackOptionalData(inputType->GetElementType(keyIndex), keyTypes[i].second)->GetDataSlot(); } @@ -464,22 +530,26 @@ IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactor representations[i] = GetValueRepresentation(inputType->GetElementType(indexes[i])); TComputationNodePtrVector directions(keyWidth); - auto index = 1U; + auto index = 1U - offset; std::generate(directions.begin(), directions.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++++index); }); if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) { - return new TWideTopWrapper<Sort>(ctx.Mutables, wide, count, std::move(directions), std::move(keyTypes), std::move(indexes), std::move(representations)); + return new TWideTopWrapper<Sort, HasCount>(ctx.Mutables, wide, count, std::move(directions), std::move(keyTypes), std::move(indexes), std::move(representations)); } THROW yexception() << "Expected wide flow."; } IComputationNode* WrapWideTop(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - return WrapWideTopT<false>(callable, ctx); + return WrapWideTopT<false, true>(callable, ctx); } IComputationNode* WrapWideTopSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - return WrapWideTopT<true>(callable, ctx); + return WrapWideTopT<true, true>(callable, ctx); +} + +IComputationNode* WrapWideSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + return WrapWideTopT<true, false>(callable, ctx); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h index d31c3d8553..e56bbc1d36 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h @@ -7,6 +7,7 @@ namespace NMiniKQL { IComputationNode* WrapWideTop(TCallable& callable, const TComputationNodeFactoryContext& ctx); IComputationNode* WrapWideTopSort(TCallable& callable, const TComputationNodeFactoryContext& ctx); +IComputationNode* WrapWideSort(TCallable& callable, const TComputationNodeFactoryContext& ctx); } } diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h b/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h index c48e5f88e3..bf20ea2c11 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h @@ -15,7 +15,7 @@ if (!(v.AsStringRef() == (expected))) { \ UNIT_FAIL_IMPL( \ "equal assertion failed", \ - Sprintf("%s == %s", #unboxed, #expected)); \ + Sprintf("%s %s == %s", #unboxed, TString(v.AsStringRef()).c_str(), #expected)); \ } \ } while (0) diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp index 3dca8d8d16..b52b617587 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp @@ -563,6 +563,89 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideTopTest) { } } #endif + +#if !defined(MKQL_RUNTIME_VERSION) || MKQL_RUNTIME_VERSION >= 34u +Y_UNIT_TEST_SUITE(TMiniKQLWideSortTest) { + Y_UNIT_TEST_LLVM(SortByFirstKeyAsc) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id); + const auto tupleType = pb.NewTupleType({dataType, dataType}); + + const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one"); + const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two"); + + const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one"); + const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two"); + + const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1"); + const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2"); + const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3"); + const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4"); + const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5"); + const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6"); + const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7"); + const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8"); + const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9"); + + const auto data1 = pb.NewTuple(tupleType, {keyOne, value1}); + + const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2}); + const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3}); + + const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4}); + + const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5}); + const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6}); + const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7}); + const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8}); + const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9}); + + const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9}); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideSort(pb.ExpandMap(pb.ToFlow(list), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }), + {{0U, pb.NewDataLiteral<bool>(true)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 1"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 3"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 2"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 4"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 9"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 8"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 7"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 6"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 5"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } +} +#endif + } } diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 4ebec923a9..5c1e38847d 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -1484,11 +1484,15 @@ TRuntimeNode TProgramBuilder::WideTakeBlocks(TRuntimeNode flow, TRuntimeNode cou } TRuntimeNode TProgramBuilder::WideTopBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTop(__func__, flow, count, keys); + return BuildWideTopOrSort(__func__, flow, count, keys); } TRuntimeNode TProgramBuilder::WideTopSortBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTop(__func__, flow, count, keys); + return BuildWideTopOrSort(__func__, flow, count, keys); +} + +TRuntimeNode TProgramBuilder::WideSortBlocks(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { + return BuildWideTopOrSort(__func__, flow, Nothing(), keys); } TRuntimeNode TProgramBuilder::AsScalar(TRuntimeNode value) { @@ -1671,17 +1675,28 @@ TRuntimeNode TProgramBuilder::Sort(TRuntimeNode list, TRuntimeNode ascending, co TRuntimeNode TProgramBuilder::WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTop(__func__, flow, count, keys); + return BuildWideTopOrSort(__func__, flow, count, keys); } TRuntimeNode TProgramBuilder::WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTop(__func__, flow, count, keys); + return BuildWideTopOrSort(__func__, flow, count, keys); +} + +TRuntimeNode TProgramBuilder::WideSort(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) +{ + return BuildWideTopOrSort(__func__, flow, Nothing(), keys); } -TRuntimeNode TProgramBuilder::BuildWideTop(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - if constexpr (RuntimeVersion < 33U) { - THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName; +TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { + if (count) { + if constexpr (RuntimeVersion < 33U) { + THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName; + } + } else { + if constexpr (RuntimeVersion < 34U) { + THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName; + } } const auto width = AS_TYPE(TTupleType, AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType())->GetElementsCount(); @@ -1689,7 +1704,10 @@ TRuntimeNode TProgramBuilder::BuildWideTop(const std::string_view& callableName, TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType()); callableBuilder.Add(flow); - callableBuilder.Add(count); + if (count) { + callableBuilder.Add(*count); + } + std::for_each(keys.cbegin(), keys.cend(), [&](const std::pair<ui32, TRuntimeNode>& key) { MKQL_ENSURE(key.first < width, "Key index too large: " << key.first); callableBuilder.Add(NewDataLiteral(key.first)); diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index 45dd4d200d..643bb1e7b9 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -248,6 +248,7 @@ public: TRuntimeNode WideTakeBlocks(TRuntimeNode flow, TRuntimeNode count); TRuntimeNode WideTopBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); TRuntimeNode WideTopSortBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); + TRuntimeNode WideSortBlocks(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); TRuntimeNode AsScalar(TRuntimeNode value); TRuntimeNode BlockCompress(TRuntimeNode flow, ui32 bitmapIndex); TRuntimeNode BlockExpandChunked(TRuntimeNode flow); @@ -409,6 +410,7 @@ public: TRuntimeNode WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); TRuntimeNode WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); + TRuntimeNode WideSort(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); TRuntimeNode Length(TRuntimeNode listOrDict); TRuntimeNode Iterator(TRuntimeNode list, const TArrayRef<const TRuntimeNode>& dependentNodes); @@ -725,7 +727,7 @@ private: TRuntimeNode BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members, const std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>>& filteredItems); - TRuntimeNode BuildWideTop(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); + TRuntimeNode BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); TRuntimeNode InvokeBinary(const std::string_view& callableName, TType* type, TRuntimeNode data1, TRuntimeNode data2); TRuntimeNode AggrCompare(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2); diff --git a/ydb/library/yql/minikql/mkql_runtime_version.h b/ydb/library/yql/minikql/mkql_runtime_version.h index 601e40788e..e776c5577d 100644 --- a/ydb/library/yql/minikql/mkql_runtime_version.h +++ b/ydb/library/yql/minikql/mkql_runtime_version.h @@ -24,7 +24,7 @@ namespace NMiniKQL { // 1. Bump this version every time incompatible runtime nodes are introduced. // 2. Make sure you provide runtime node generation for previous runtime versions. #ifndef MKQL_RUNTIME_VERSION -#define MKQL_RUNTIME_VERSION 33U +#define MKQL_RUNTIME_VERSION 34U #endif // History: |