aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2023-02-20 18:58:57 +0300
committervvvv <vvvv@ydb.tech>2023-02-20 18:58:57 +0300
commitcaced409f95cbd73bbb5acd5f6ca0eb85ee282ab (patch)
treee4c1e77a69af550b6914fabadc670667abffa453
parenta181588e011454f9b95afde48fd5a71f7513461b (diff)
downloadydb-caced409f95cbd73bbb5acd5f6ca0eb85ee282ab.tar.gz
mkql part of WideSort
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp1
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp174
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h1
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp83
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.cpp34
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.h4
-rw-r--r--ydb/library/yql/minikql/mkql_runtime_version.h2
8 files changed, 238 insertions, 63 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
index 21400a0e14..1662fed7b8 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
@@ -324,6 +324,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller {
{"WideChopper", &WrapWideChopper},
{"WideTop", &WrapWideTop},
{"WideTopSort", &WrapWideTopSort},
+ {"WideSort", &WrapWideSort},
{"WideFlowArg", &WrapWideFlowArg},
{"Source", &WrapSource},
{"RangeCreate", &WrapRangeCreate},
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp
index 33848acde6..a7e5f4b0ca 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp
@@ -29,10 +29,11 @@ struct TMyValueCompare {
using TComparePtr = int(*)(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*);
using TCompareFunc = std::function<int(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)>;
-class TState : public TComputationValue<TState> {
-using TBase = TComputationValue<TState>;
+template <bool HasCount>
+class TState : public TComputationValue<TState<HasCount>> {
+using TBase = TComputationValue<TState<HasCount>>;
public:
-using TLLVMBase = TLLVMFieldsStructure<TComputationValue<TState>>;
+using TLLVMBase = TLLVMFieldsStructure<TComputationValue<TState<HasCount>>>;
private:
using TStorage = std::vector<NUdf::TUnboxedValue, TMKQLAllocator<NUdf::TUnboxedValue, EMemorySubPool::Temporary>>;
using TFields = std::vector<NUdf::TUnboxedValue*, TMKQLAllocator<NUdf::TUnboxedValue*, EMemorySubPool::Temporary>>;
@@ -43,15 +44,30 @@ private:
}
void ResetFields() {
- auto ptr = Tongue = Free.back();
+ NUdf::TUnboxedValuePod* ptr;
+ if constexpr (HasCount) {
+ ptr = Tongue = Free.back();
+ } else {
+ auto pos = Storage.size();
+ Storage.insert(Storage.end(), Indexes.size(), {});
+ ptr = Tongue = Storage.data() + pos;
+ }
+
std::for_each(Indexes.cbegin(), Indexes.cend(), [&](ui32 index) { Fields[index] = static_cast<NUdf::TUnboxedValue*>(ptr++); });
}
public:
TState(TMemoryUsageInfo* memInfo, ui64 count, const bool* directons, size_t keyWidth, const TCompareFunc& compare, const std::vector<ui32>& indexes)
: TBase(memInfo), Count(count), Indexes(indexes), Directions(directons, directons + keyWidth)
, LessFunc(std::bind(std::less<int>(), std::bind(compare, Directions.data(), std::placeholders::_1, std::placeholders::_2), 0))
- , Storage(GetStorageSize() * Indexes.size()), Free(GetStorageSize(), nullptr), Fields(Indexes.size(), nullptr)
+ , Fields(Indexes.size(), nullptr)
{
+ if constexpr (!HasCount) {
+ ResetFields();
+ return;
+ }
+
+ Storage.resize(GetStorageSize() * Indexes.size());
+ Free.resize(GetStorageSize(), nullptr);
if (Count) {
Full.reserve(GetStorageSize());
auto ptr = Storage.data();
@@ -70,6 +86,11 @@ public:
}
bool Push() {
+ if constexpr (!HasCount) {
+ ResetFields();
+ return true;
+ }
+
if (Full.size() + 1U == GetStorageSize()) {
Free.pop_back();
@@ -100,6 +121,18 @@ public:
template<bool Sort>
void Seal() {
+ if constexpr (!HasCount) {
+ static_assert (Sort);
+ Storage.resize(Storage.size() - Indexes.size());
+ Full.reserve(Storage.size() / Indexes.size());
+ for (auto it = Storage.begin(); it != Storage.end(); it += Indexes.size()) {
+ Full.emplace_back(&*it);
+ }
+
+ std::sort(Full.rbegin(), Full.rend(), LessFunc);
+ return;
+ }
+
Free.clear();
Free.shrink_to_fit();
@@ -135,9 +168,10 @@ private:
};
#ifndef MKQL_DISABLE_CODEGEN
-class TLLVMFieldsStructureState: public TState::TLLVMBase {
+template <bool HasCount>
+class TLLVMFieldsStructureState: public TState<HasCount>::TLLVMBase {
private:
- using TBase = TState::TLLVMBase;
+ using TBase = typename TState<HasCount>::TLLVMBase;
llvm::IntegerType* ValueType;
llvm::PointerType* PtrValueType;
llvm::IntegerType* StatusType;
@@ -171,13 +205,13 @@ public:
};
#endif
-template<bool Sort>
-class TWideTopWrapper: public TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort>>
+template<bool Sort, bool HasCount>
+class TWideTopWrapper: public TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort, HasCount>>
#ifndef MKQL_DISABLE_CODEGEN
, public ICodegeneratorRootNode
#endif
{
-using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort>>;
+using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort, HasCount>>;
public:
TWideTopWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, TComputationNodePtrVector&& directions, TKeyTypes&& keyTypes, std::vector<ui32>&& indexes, std::vector<EValueRepresentation>&& representations)
: TBaseComputation(mutables, flow, EValueRepresentation::Boxed), Flow(flow), Count(count), Directions(std::move(directions)), KeyTypes(std::move(keyTypes)), Indexes(std::move(indexes)), Representations(std::move(representations))
@@ -185,13 +219,19 @@ public:
EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
if (!state.HasValue()) {
- const auto count = Count->GetValue(ctx).Get<ui64>();
+ ui64 count;
+ if constexpr (HasCount) {
+ count = Count->GetValue(ctx).Get<ui64>();
+ } else {
+ count = 0;
+ }
+
std::vector<bool> dirs(Directions.size());
std::transform(Directions.cbegin(), Directions.cend(), dirs.begin(), [&ctx](IComputationNode* dir){ return dir->GetValue(ctx).Get<bool>(); });
MakeState(ctx, state, count, dirs.data());
}
- if (const auto ptr = static_cast<TState*>(state.AsBoxed().Get())) {
+ if (const auto ptr = static_cast<TState<HasCount>*>(state.AsBoxed().Get())) {
while (EFetchResult::Finish != ptr->InputStatus) {
switch (ptr->InputStatus = Flow->FetchValues(ctx, ptr->GetFields())) {
case EFetchResult::One:
@@ -200,7 +240,7 @@ public:
case EFetchResult::Yield:
return EFetchResult::Yield;
case EFetchResult::Finish:
- ptr->Seal<Sort>();
+ ptr->template Seal<Sort>();
break;
}
}
@@ -230,7 +270,7 @@ public:
const auto statusType = Type::getInt32Ty(context);
const auto indexType = Type::getInt32Ty(ctx.Codegen->GetContext());
- TLLVMFieldsStructureState stateFields(context);
+ TLLVMFieldsStructureState<HasCount> stateFields(context);
const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
const auto statePtrType = PointerType::getUnqual(stateType);
@@ -255,8 +295,13 @@ public:
BranchInst::Create(main, make, HasValue(statePtr, block), block);
block = make;
- const auto count = GetNodeValue(Count, ctx, block);
- const auto trunc = GetterFor<ui64>(count, context, block);
+ llvm::Value* trunc;
+ if constexpr (HasCount) {
+ const auto count = GetNodeValue(Count, ctx, block);
+ trunc = GetterFor<ui64>(count, context, block);
+ } else {
+ trunc = ConstantInt::get(Type::getInt64Ty(context), 0U);
+ }
const auto dirs = new AllocaInst(ArrayType::get(Type::getInt1Ty(context), Directions.size()), 0U, "dirs", block);
for (auto i = 0U; i < Directions.size(); ++i) {
@@ -311,7 +356,7 @@ public:
block = rest;
new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), statusPtr, block);
- const auto sealFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Seal<Sort>));
+ const auto sealFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::template Seal<Sort>));
const auto sealType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType()}, false);
const auto sealPtr = CastInst::Create(Instruction::IntToPtr, sealFunc, PointerType::getUnqual(sealType), "seal", block);
CallInst::Create(sealPtr, {stateArg}, "", block);
@@ -328,40 +373,52 @@ public:
placeholders[i] = GetElementPtrInst::CreateInBounds(tongue, {ConstantInt::get(indexType, i)}, (TString("placeholder_") += ToString(i)).c_str(), block);
}
- for (auto i = 0U; i < KeyTypes.size(); ++i) {
- const auto item = getres.second[Indexes[i]](ctx, block);
- new StoreInst(item, placeholders[i], block);
+ if constexpr (!HasCount) {
+ for (auto i = 0; i < Representations.size(); ++i) {
+ const auto item = getres.second[Indexes[i]](ctx, block);
+ ValueAddRef(Representations[i], item, ctx, block);
+ new StoreInst(item, placeholders[i], block);
+ }
+
+ } else {
+ for (auto i = 0U; i < KeyTypes.size(); ++i) {
+ const auto item = getres.second[Indexes[i]](ctx, block);
+ new StoreInst(item, placeholders[i], block);
+ }
}
- const auto pushFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Push));
+
+ const auto pushFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::Push));
const auto pushType = FunctionType::get(Type::getInt1Ty(context), {stateArg->getType()}, false);
const auto pushPtr = CastInst::Create(Instruction::IntToPtr, pushFunc, PointerType::getUnqual(pushType), "function", block);
const auto accepted = CallInst::Create(pushPtr, {stateArg}, "accepted", block);
+ if constexpr (HasCount) {
+ const auto push = BasicBlock::Create(context, "push", ctx.Func);
+ const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
- const auto push = BasicBlock::Create(context, "push", ctx.Func);
- const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
-
- BranchInst::Create(push, skip, accepted, block);
+ BranchInst::Create(push, skip, accepted, block);
- block = push;
+ block = push;
- for (auto i = 0U; i < KeyTypes.size(); ++i) {
- ValueAddRef(Representations[i], placeholders[i], ctx, block);
- }
+ for (auto i = 0U; i < KeyTypes.size(); ++i) {
+ ValueAddRef(Representations[i], placeholders[i], ctx, block);
+ }
- for (auto i = KeyTypes.size(); i < Representations.size(); ++i) {
- const auto item = getres.second[Indexes[i]](ctx, block);
- ValueAddRef(Representations[i], item, ctx, block);
- new StoreInst(item, placeholders[i], block);
- }
+ for (auto i = KeyTypes.size(); i < Representations.size(); ++i) {
+ const auto item = getres.second[Indexes[i]](ctx, block);
+ ValueAddRef(Representations[i], item, ctx, block);
+ new StoreInst(item, placeholders[i], block);
+ }
+
- BranchInst::Create(loop, block);
+ BranchInst::Create(loop, block);
- block = skip;
+ block = skip;
- for (auto i = 0U; i < KeyTypes.size(); ++i) {
- ValueCleanup(Representations[i], placeholders[i], ctx, block);
- new StoreInst(ConstantInt::get(valueType, 0), placeholders[i], block);
+ for (auto i = 0U; i < KeyTypes.size(); ++i) {
+ ValueCleanup(Representations[i], placeholders[i], ctx, block);
+ new StoreInst(ConstantInt::get(valueType, 0), placeholders[i], block);
+ }
}
BranchInst::Create(loop, block);
@@ -372,7 +429,7 @@ public:
const auto good = BasicBlock::Create(context, "good", ctx.Func);
- const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract));
+ const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState<HasCount>::Extract));
const auto extractType = FunctionType::get(outputPtrType, {stateArg->getType()}, false);
const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block);
const auto out = CallInst::Create(extractPtr, {stateArg}, "out", block);
@@ -397,15 +454,18 @@ public:
private:
void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state, ui64 count, const bool* directions) const {
#ifdef MKQL_DISABLE_CODEGEN
- state = ctx.HolderFactory.Create<TState>(count, directions, Directions.size(), TMyValueCompare(KeyTypes), Indexes);
+ state = ctx.HolderFactory.Create<TState<HasCount>>(count, directions, Directions.size(), TMyValueCompare(KeyTypes), Indexes);
#else
- state = ctx.HolderFactory.Create<TState>(count, directions, Directions.size(), ctx.ExecuteLLVM && Compare ? TCompareFunc(Compare) : TCompareFunc(TMyValueCompare(KeyTypes)), Indexes);
+ state = ctx.HolderFactory.Create<TState<HasCount>>(count, directions, Directions.size(), ctx.ExecuteLLVM && Compare ? TCompareFunc(Compare) : TCompareFunc(TMyValueCompare(KeyTypes)), Indexes);
#endif
}
void RegisterDependencies() const final {
if (const auto flow = this->FlowDependsOn(Flow)) {
- TWideTopWrapper::DependsOn(flow, Count);
+ if constexpr (HasCount) {
+ TWideTopWrapper::DependsOn(flow, Count);
+ }
+
std::for_each(Directions.cbegin(), Directions.cend(), std::bind(&TWideTopWrapper::DependsOn, flow, std::placeholders::_1));
}
}
@@ -441,20 +501,26 @@ private:
}
-template<bool Sort>
+template<bool Sort, bool HasCount>
IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- MKQL_ENSURE(callable.GetInputsCount() > 2U && !(callable.GetInputsCount() % 2U), "Expected more arguments.");
+ const ui32 offset = HasCount ? 0 : 1;
+ const ui32 inputsWithCount = callable.GetInputsCount() + offset;
+ MKQL_ENSURE(inputsWithCount > 2U && !(inputsWithCount % 2U), "Expected more arguments.");
const auto flow = LocateNode(ctx.NodeLocator, callable, 0);
- const auto count = LocateNode(ctx.NodeLocator, callable, 1);
- const auto keyWidth = (callable.GetInputsCount() >> 1U) - 1U;
+ IComputationNode* count = nullptr;
+ if (HasCount) {
+ count = LocateNode(ctx.NodeLocator, callable, 1);
+ }
+
+ const auto keyWidth = (inputsWithCount >> 1U) - 1U;
const auto inputType = AS_TYPE(TTupleType, AS_TYPE(TFlowType, callable.GetType()->GetReturnType())->GetItemType());
std::vector<ui32> indexes(inputType->GetElementsCount());
std::iota(indexes.begin(), indexes.end(), 0U);
TKeyTypes keyTypes(keyWidth);
for (auto i = 0U; i < keyTypes.size(); ++i) {
- const auto keyIndex = AS_VALUE(TDataLiteral, callable.GetInput((i + 1U) << 1U))->AsValue().Get<ui32>();
+ const auto keyIndex = AS_VALUE(TDataLiteral, callable.GetInput(((i + 1U) << 1U) - offset))->AsValue().Get<ui32>();
std::swap(indexes[i], indexes[indexes[keyIndex]]);
keyTypes[i].first = *UnpackOptionalData(inputType->GetElementType(keyIndex), keyTypes[i].second)->GetDataSlot();
}
@@ -464,22 +530,26 @@ IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactor
representations[i] = GetValueRepresentation(inputType->GetElementType(indexes[i]));
TComputationNodePtrVector directions(keyWidth);
- auto index = 1U;
+ auto index = 1U - offset;
std::generate(directions.begin(), directions.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++++index); });
if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
- return new TWideTopWrapper<Sort>(ctx.Mutables, wide, count, std::move(directions), std::move(keyTypes), std::move(indexes), std::move(representations));
+ return new TWideTopWrapper<Sort, HasCount>(ctx.Mutables, wide, count, std::move(directions), std::move(keyTypes), std::move(indexes), std::move(representations));
}
THROW yexception() << "Expected wide flow.";
}
IComputationNode* WrapWideTop(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- return WrapWideTopT<false>(callable, ctx);
+ return WrapWideTopT<false, true>(callable, ctx);
}
IComputationNode* WrapWideTopSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- return WrapWideTopT<true>(callable, ctx);
+ return WrapWideTopT<true, true>(callable, ctx);
+}
+
+IComputationNode* WrapWideSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
+ return WrapWideTopT<true, false>(callable, ctx);
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h
index d31c3d8553..e56bbc1d36 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h
@@ -7,6 +7,7 @@ namespace NMiniKQL {
IComputationNode* WrapWideTop(TCallable& callable, const TComputationNodeFactoryContext& ctx);
IComputationNode* WrapWideTopSort(TCallable& callable, const TComputationNodeFactoryContext& ctx);
+IComputationNode* WrapWideSort(TCallable& callable, const TComputationNodeFactoryContext& ctx);
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h b/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h
index c48e5f88e3..bf20ea2c11 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h
@@ -15,7 +15,7 @@
if (!(v.AsStringRef() == (expected))) { \
UNIT_FAIL_IMPL( \
"equal assertion failed", \
- Sprintf("%s == %s", #unboxed, #expected)); \
+ Sprintf("%s %s == %s", #unboxed, TString(v.AsStringRef()).c_str(), #expected)); \
} \
} while (0)
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp
index 3dca8d8d16..b52b617587 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp
@@ -563,6 +563,89 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideTopTest) {
}
}
#endif
+
+#if !defined(MKQL_RUNTIME_VERSION) || MKQL_RUNTIME_VERSION >= 34u
+Y_UNIT_TEST_SUITE(TMiniKQLWideSortTest) {
+ Y_UNIT_TEST_LLVM(SortByFirstKeyAsc) {
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id);
+ const auto tupleType = pb.NewTupleType({dataType, dataType});
+
+ const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one");
+ const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two");
+
+ const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one");
+ const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two");
+
+ const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1");
+ const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2");
+ const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3");
+ const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4");
+ const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5");
+ const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6");
+ const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7");
+ const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8");
+ const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9");
+
+ const auto data1 = pb.NewTuple(tupleType, {keyOne, value1});
+
+ const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2});
+ const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3});
+
+ const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4});
+
+ const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5});
+ const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6});
+ const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7});
+ const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8});
+ const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9});
+
+ const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9});
+
+ const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideSort(pb.ExpandMap(pb.ToFlow(list),
+ [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }),
+ {{0U, pb.NewDataLiteral<bool>(true)}}),
+ [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); }
+ ));
+
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue().GetListIterator();
+ NUdf::TUnboxedValue item;
+ UNIT_ASSERT(iterator.Next(item));
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key one");
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 1");
+ UNIT_ASSERT(iterator.Next(item));
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two");
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 3");
+ UNIT_ASSERT(iterator.Next(item));
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two");
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 2");
+ UNIT_ASSERT(iterator.Next(item));
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key one");
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 4");
+ UNIT_ASSERT(iterator.Next(item));
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two");
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 9");
+ UNIT_ASSERT(iterator.Next(item));
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two");
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 8");
+ UNIT_ASSERT(iterator.Next(item));
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two");
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 7");
+ UNIT_ASSERT(iterator.Next(item));
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two");
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 6");
+ UNIT_ASSERT(iterator.Next(item));
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two");
+ UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 5");
+ UNIT_ASSERT(!iterator.Next(item));
+ UNIT_ASSERT(!iterator.Next(item));
+ }
+}
+#endif
+
}
}
diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp
index 4ebec923a9..5c1e38847d 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_program_builder.cpp
@@ -1484,11 +1484,15 @@ TRuntimeNode TProgramBuilder::WideTakeBlocks(TRuntimeNode flow, TRuntimeNode cou
}
TRuntimeNode TProgramBuilder::WideTopBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
- return BuildWideTop(__func__, flow, count, keys);
+ return BuildWideTopOrSort(__func__, flow, count, keys);
}
TRuntimeNode TProgramBuilder::WideTopSortBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
- return BuildWideTop(__func__, flow, count, keys);
+ return BuildWideTopOrSort(__func__, flow, count, keys);
+}
+
+TRuntimeNode TProgramBuilder::WideSortBlocks(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
+ return BuildWideTopOrSort(__func__, flow, Nothing(), keys);
}
TRuntimeNode TProgramBuilder::AsScalar(TRuntimeNode value) {
@@ -1671,17 +1675,28 @@ TRuntimeNode TProgramBuilder::Sort(TRuntimeNode list, TRuntimeNode ascending, co
TRuntimeNode TProgramBuilder::WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
{
- return BuildWideTop(__func__, flow, count, keys);
+ return BuildWideTopOrSort(__func__, flow, count, keys);
}
TRuntimeNode TProgramBuilder::WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
{
- return BuildWideTop(__func__, flow, count, keys);
+ return BuildWideTopOrSort(__func__, flow, count, keys);
+}
+
+TRuntimeNode TProgramBuilder::WideSort(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys)
+{
+ return BuildWideTopOrSort(__func__, flow, Nothing(), keys);
}
-TRuntimeNode TProgramBuilder::BuildWideTop(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
- if constexpr (RuntimeVersion < 33U) {
- THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName;
+TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) {
+ if (count) {
+ if constexpr (RuntimeVersion < 33U) {
+ THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName;
+ }
+ } else {
+ if constexpr (RuntimeVersion < 34U) {
+ THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName;
+ }
}
const auto width = AS_TYPE(TTupleType, AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType())->GetElementsCount();
@@ -1689,7 +1704,10 @@ TRuntimeNode TProgramBuilder::BuildWideTop(const std::string_view& callableName,
TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType());
callableBuilder.Add(flow);
- callableBuilder.Add(count);
+ if (count) {
+ callableBuilder.Add(*count);
+ }
+
std::for_each(keys.cbegin(), keys.cend(), [&](const std::pair<ui32, TRuntimeNode>& key) {
MKQL_ENSURE(key.first < width, "Key index too large: " << key.first);
callableBuilder.Add(NewDataLiteral(key.first));
diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h
index 45dd4d200d..643bb1e7b9 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.h
+++ b/ydb/library/yql/minikql/mkql_program_builder.h
@@ -248,6 +248,7 @@ public:
TRuntimeNode WideTakeBlocks(TRuntimeNode flow, TRuntimeNode count);
TRuntimeNode WideTopBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys);
TRuntimeNode WideTopSortBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys);
+ TRuntimeNode WideSortBlocks(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys);
TRuntimeNode AsScalar(TRuntimeNode value);
TRuntimeNode BlockCompress(TRuntimeNode flow, ui32 bitmapIndex);
TRuntimeNode BlockExpandChunked(TRuntimeNode flow);
@@ -409,6 +410,7 @@ public:
TRuntimeNode WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys);
TRuntimeNode WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys);
+ TRuntimeNode WideSort(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys);
TRuntimeNode Length(TRuntimeNode listOrDict);
TRuntimeNode Iterator(TRuntimeNode list, const TArrayRef<const TRuntimeNode>& dependentNodes);
@@ -725,7 +727,7 @@ private:
TRuntimeNode BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members,
const std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>>& filteredItems);
- TRuntimeNode BuildWideTop(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys);
+ TRuntimeNode BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys);
TRuntimeNode InvokeBinary(const std::string_view& callableName, TType* type, TRuntimeNode data1, TRuntimeNode data2);
TRuntimeNode AggrCompare(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2);
diff --git a/ydb/library/yql/minikql/mkql_runtime_version.h b/ydb/library/yql/minikql/mkql_runtime_version.h
index 601e40788e..e776c5577d 100644
--- a/ydb/library/yql/minikql/mkql_runtime_version.h
+++ b/ydb/library/yql/minikql/mkql_runtime_version.h
@@ -24,7 +24,7 @@ namespace NMiniKQL {
// 1. Bump this version every time incompatible runtime nodes are introduced.
// 2. Make sure you provide runtime node generation for previous runtime versions.
#ifndef MKQL_RUNTIME_VERSION
-#define MKQL_RUNTIME_VERSION 33U
+#define MKQL_RUNTIME_VERSION 34U
#endif
// History: