diff options
author | a-romanov <Anton.Romanov@ydb.tech> | 2023-01-24 10:33:49 +0300 |
---|---|---|
committer | a-romanov <Anton.Romanov@ydb.tech> | 2023-01-24 10:33:49 +0300 |
commit | bc3cabbcad7e8c95f9bba817859ede1db8ac009b (patch) | |
tree | eb852058da2b2aedf0aa276a6c9a9dbf5de41c4a | |
parent | 535b420b872da60210194dbaf1f9e5560bcc308b (diff) | |
download | ydb-bc3cabbcad7e8c95f9bba817859ede1db8ac009b.tar.gz |
WideTop[Sort] mkql layer.
13 files changed, 1299 insertions, 3 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt index a85f5d940d..4ded4d6924 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt @@ -149,6 +149,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_condense.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_map.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_withcontext.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_zip.cpp ) diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt index ae199f0008..b7a73edfa8 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt @@ -150,6 +150,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_condense.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_map.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_withcontext.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_zip.cpp ) diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt index ae199f0008..b7a73edfa8 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt @@ -150,6 +150,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_condense.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_map.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_withcontext.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_zip.cpp ) diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index ddeecfa047..63df468c8e 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -102,6 +102,7 @@ #include "mkql_wide_condense.h" #include "mkql_wide_filter.h" #include "mkql_wide_map.h" +#include "mkql_wide_top_sort.h" #include "mkql_withcontext.h" #include "mkql_zip.h" @@ -316,6 +317,8 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"WideLastCombiner", &WrapWideLastCombiner}, {"WideCondense1", &WrapWideCondense1}, {"WideChopper", &WrapWideChopper}, + {"WideTop", &WrapWideTop}, + {"WideTopSort", &WrapWideTopSort}, {"WideFlowArg", &WrapWideFlowArg}, {"Source", &WrapSource}, {"RangeCreate", &WrapRangeCreate}, diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp new file mode 100644 index 0000000000..485b96a9aa --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp @@ -0,0 +1,459 @@ +#include "mkql_wide_top_sort.h" +#include "mkql_llvm_base.h" + +#include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h> +#include <ydb/library/yql/minikql/mkql_node_builder.h> +#include <ydb/library/yql/minikql/mkql_node_cast.h> +#include <ydb/library/yql/minikql/defs.h> +#include <ydb/library/yql/utils/cast.h> + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +struct TMyValueCompare { + TMyValueCompare(const TKeyTypes& types) + : Types(types) + {} + + int operator()(const bool* directions, const NUdf::TUnboxedValuePod* left, const NUdf::TUnboxedValuePod* right) const { + return CompareValues(left, right, Types, directions); + } + + const TKeyTypes& Types; +}; + +using TComparePtr = int(*)(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*); +using TCompareFunc = std::function<int(const bool*, const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)>; + +class TState : public TComputationValue<TState> { +using TBase = TComputationValue<TState>; +public: +using TLLVMBase = TLLVMFieldsStructure<TComputationValue<TState>>; +private: + using TStorage = std::vector<NUdf::TUnboxedValue, TMKQLAllocator<NUdf::TUnboxedValue, EMemorySubPool::Temporary>>; + using TFields = std::vector<NUdf::TUnboxedValue*, TMKQLAllocator<NUdf::TUnboxedValue*, EMemorySubPool::Temporary>>; + using TPointers = std::vector<NUdf::TUnboxedValuePod*, TMKQLAllocator<NUdf::TUnboxedValuePod*, EMemorySubPool::Temporary>>; + + size_t GetStorageSize() const { + return std::max<size_t>(Count << 2ULL, 1ULL << 8ULL); + } + + void ResetFields() { + const auto ptr = Free.back(); + auto i = 0U; + std::generate(Fields.begin(), Fields.end(), [&]() { return &static_cast<NUdf::TUnboxedValue&>(ptr[Indexes[i++]]); }); + Tongue = ptr; + } +public: + TState(TMemoryUsageInfo* memInfo, ui64 count, const bool* directons, const TCompareFunc& compare, const std::vector<ui32>& indexes) + : TBase(memInfo), Count(count), Indexes(indexes), Directions(directons, directons + Indexes.size()) + , LessFunc(std::bind(std::less<int>(), std::bind(compare, Directions.data(), std::placeholders::_1, std::placeholders::_2), 0)) + , Storage(GetStorageSize() * Indexes.size()), Free(GetStorageSize(), nullptr), Fields(Indexes.size(), nullptr) + { + if (Count) { + Full.reserve(GetStorageSize()); + auto ptr = Storage.data(); + std::generate(Free.begin(), Free.end(), [&ptr, this]() { + const auto p = ptr; + ptr += Indexes.size(); + return p; + }); + ResetFields(); + } else + InputStatus = EFetchResult::Finish; + } + + NUdf::TUnboxedValue*const* GetFields() const { + return Fields.data(); + } + + void Push() { + if (Full.size() < Count) { + Full.emplace_back(Free.back()); + Free.pop_back(); + ResetFields(); + return; + } else if (!Throat) { + Throat = *std::max_element(Full.cbegin(), Full.cend(), LessFunc); + } + + if (!LessFunc(Tongue, Throat)) { + std::fill_n(static_cast<NUdf::TUnboxedValue*>(Tongue), Indexes.size(), NUdf::TUnboxedValuePod()); + return; + } + + Full.emplace_back(Free.back()); + Free.pop_back(); + + if (Full.size() == GetStorageSize()) { + std::nth_element(Full.begin(), Full.begin() + Count, Full.end(), LessFunc); + std::copy(Full.cbegin() + Count, Full.cend(), std::back_inserter(Free)); + Full.resize(Count); + std::for_each(Free.cbegin(), Free.cend(), [this](NUdf::TUnboxedValuePod* ptr) { + std::fill_n(static_cast<NUdf::TUnboxedValue*>(ptr), Indexes.size(), NUdf::TUnboxedValuePod()); + }); + Throat = *std::max_element(Full.cbegin(), Full.cend(), LessFunc); + } + + ResetFields(); + } + + template<bool Sort> + void Seal() { + Free.clear(); + Free.shrink_to_fit(); + + if (Full.size() > Count) { + std::nth_element(Full.begin(), Full.begin() + Count, Full.end(), LessFunc); + Full.resize(Count); + } + + if constexpr (Sort) + std::sort(Full.rbegin(), Full.rend(), LessFunc); + } + + NUdf::TUnboxedValue* Extract() { + if (Full.empty()) + return nullptr; + + const auto ptr = Full.back(); + Full.pop_back(); + return static_cast<NUdf::TUnboxedValue*>(ptr); + } + + EFetchResult InputStatus = EFetchResult::One; + NUdf::TUnboxedValuePod* Tongue = nullptr; + NUdf::TUnboxedValuePod* Throat = nullptr; +private: + const ui64 Count; + const std::vector<ui32> Indexes; + const std::vector<bool> Directions; + const std::function<bool(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*)> LessFunc; + TStorage Storage; + TPointers Free, Full; + TFields Fields; +}; + +#ifndef MKQL_DISABLE_CODEGEN +class TLLVMFieldsStructureState: public TState::TLLVMBase { +private: + using TBase = TState::TLLVMBase; + llvm::IntegerType* ValueType; + llvm::PointerType* PtrValueType; + llvm::IntegerType* StatusType; +protected: + using TBase::Context; +public: + std::vector<llvm::Type*> GetFieldsArray() { + std::vector<llvm::Type*> result = TBase::GetFields(); + result.emplace_back(StatusType); //status + result.emplace_back(PtrValueType); //tongue + result.emplace_back(PtrValueType); //throat + result.emplace_back(Type::getInt64Ty(Context)); //count + return result; + } + + llvm::Constant* GetStatus() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0); + } + + llvm::Constant* GetTongue() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1); + } + + TLLVMFieldsStructureState(llvm::LLVMContext& context) + : TBase(context) + , ValueType(Type::getInt128Ty(Context)) + , PtrValueType(PointerType::getUnqual(ValueType)) + , StatusType(Type::getInt32Ty(Context)) { + + } +}; +#endif + +template<bool Sort> +class TWideTopWrapper: public TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort>> +#ifndef MKQL_DISABLE_CODEGEN + , public ICodegeneratorRootNode +#endif +{ +using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTopWrapper<Sort>>; +public: + TWideTopWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, TComputationNodePtrVector&& directions, TKeyTypes&& keyTypes, std::vector<ui32>&& indexes, std::vector<EValueRepresentation>&& representations) + : TBaseComputation(mutables, flow, EValueRepresentation::Boxed), Flow(flow), Count(count), Directions(std::move(directions)), KeyTypes(std::move(keyTypes)), Indexes(std::move(indexes)), Representations(std::move(representations)) + {} + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + if (!state.HasValue()) { + const auto count = Count->GetValue(ctx).Get<ui64>(); + std::vector<bool> dirs(Directions.size()); + std::transform(Directions.cbegin(), Directions.cend(), dirs.begin(), [&ctx](IComputationNode* dir){ return dir->GetValue(ctx).Get<bool>(); }); + MakeState(ctx, state, count, dirs.data()); + } + + if (const auto ptr = static_cast<TState*>(state.AsBoxed().Get())) { + while (EFetchResult::Finish != ptr->InputStatus) { + switch (ptr->InputStatus = Flow->FetchValues(ctx, ptr->GetFields())) { + case EFetchResult::One: + ptr->Push(); + continue; + case EFetchResult::Yield: + return EFetchResult::Yield; + case EFetchResult::Finish: + ptr->Seal<Sort>(); + break; + } + } + + if (auto extract = ptr->Extract()) { + for (const auto index : Indexes) + if (const auto to = output[index]) + *to = std::move(*extract++); + else + ++extract; + return EFetchResult::One; + } + + return EFetchResult::Finish; + } + + Y_UNREACHABLE(); + } +#ifndef MKQL_DISABLE_CODEGEN + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + auto& context = ctx.Codegen->GetContext(); + + const auto valueType = Type::getInt128Ty(context); + const auto ptrValueType = PointerType::getUnqual(valueType); + const auto structPtrType = PointerType::getUnqual(StructType::get(context)); + const auto contextType = GetCompContextType(context); + const auto statusType = Type::getInt32Ty(context); + const auto indexType = Type::getInt32Ty(ctx.Codegen->GetContext()); + + TLLVMFieldsStructureState stateFields(context); + const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); + + const auto statePtrType = PointerType::getUnqual(stateType); + + const auto outputPtrType = PointerType::getUnqual(ArrayType::get(valueType, Representations.size())); + const auto outs = new AllocaInst(outputPtrType, 0U, "outs", &ctx.Func->getEntryBlock().back()); + + ICodegeneratorInlineWideNode::TGettersList getters(Representations.size()); + + for (auto i = 0U; i < getters.size(); ++i) { + getters[Indexes[i]] = [i, outs, indexType](const TCodegenContext& ctx, BasicBlock*& block) { + const auto values = new LoadInst(outs, "values", block); + const auto pointer = GetElementPtrInst::CreateInBounds(values, {ConstantInt::get(indexType, 0), ConstantInt::get(indexType, i)}, (TString("ptr_") += ToString(i)).c_str(), block); + return new LoadInst(pointer, (TString("load_") += ToString(i)).c_str(), block); + }; + } + + const auto make = BasicBlock::Create(context, "make", ctx.Func); + const auto main = BasicBlock::Create(context, "main", ctx.Func); + const auto more = BasicBlock::Create(context, "more", ctx.Func); + + BranchInst::Create(main, make, HasValue(statePtr, block), block); + block = make; + + const auto count = GetNodeValue(Count, ctx, block); + const auto trunc = GetterFor<ui64>(count, context, block); + + const auto dirs = new AllocaInst(ArrayType::get(Type::getInt1Ty(context), Directions.size()), 0U, "dirs", block); + for (auto i = 0U; i < Directions.size(); ++i) { + const auto dir = GetNodeValue(Directions[i], ctx, block); + const auto cut = GetterFor<bool>(dir, context, block); + const auto ptr = GetElementPtrInst::CreateInBounds(dirs, {ConstantInt::get(indexType, 0), ConstantInt::get(indexType, i)}, "ptr", block); + new StoreInst(cut, ptr, block); + } + + const auto ptrType = PointerType::getUnqual(StructType::get(context)); + const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideTopWrapper::MakeState)); + const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType(), trunc->getType(), dirs->getType()}, false); + const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); + CallInst::Create(makeFuncPtr, {self, ctx.Ctx, statePtr, trunc, dirs}, "", block); + BranchInst::Create(main, block); + + block = main; + + const auto state = new LoadInst(statePtr, "state", block); + const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block); + const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block); + BranchInst::Create(more, block); + + block = more; + + const auto loop = BasicBlock::Create(context, "loop", ctx.Func); + const auto full = BasicBlock::Create(context, "full", ctx.Func); + const auto over = BasicBlock::Create(context, "over", ctx.Func); + const auto result = PHINode::Create(statusType, 3U, "result", over); + + const auto statusPtr = GetElementPtrInst::CreateInBounds(stateArg, {stateFields.This(), stateFields.GetStatus()}, "last", block); + const auto last = new LoadInst(statusPtr, "last", block); + const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, last, ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), "finish", block); + + BranchInst::Create(full, loop, finish, block); + + { + const auto rest = BasicBlock::Create(context, "rest", ctx.Func); + const auto good = BasicBlock::Create(context, "good", ctx.Func); + + block = loop; + + const auto getres = GetNodeValues(Flow, ctx, block); + + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block); + + const auto choise = SwitchInst::Create(getres.first, good, 2U, block); + choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), over); + choise->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), rest); + + block = rest; + + new StoreInst(ConstantInt::get(last->getType(), static_cast<i32>(EFetchResult::Finish)), statusPtr, block); + const auto sealFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Seal<Sort>)); + const auto sealType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType()}, false); + const auto sealPtr = CastInst::Create(Instruction::IntToPtr, sealFunc, PointerType::getUnqual(sealType), "seal", block); + CallInst::Create(sealPtr, {stateArg}, "", block); + + BranchInst::Create(full, block); + + block = good; + + const auto tonguePtr = GetElementPtrInst::CreateInBounds(stateArg, { stateFields.This(), stateFields.GetTongue() }, "tongue_ptr", block); + const auto tongue = new LoadInst(tonguePtr, "tongue", block); + + for (auto i = 0U; i < Representations.size(); ++i) { + const auto item = getres.second[Indexes[i]](ctx, block); + ValueAddRef(Representations[i], item, ctx, block); + const auto ptr = GetElementPtrInst::CreateInBounds(tongue, {ConstantInt::get(Type::getInt32Ty(context), i)}, (TString("ptr_") += ToString(i)).c_str(), block); + new StoreInst(item, ptr, block); + } + + const auto pushFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Push)); + const auto pushType = FunctionType::get(Type::getVoidTy(context), {stateArg->getType()}, false); + const auto pushPtr = CastInst::Create(Instruction::IntToPtr, pushFunc, PointerType::getUnqual(pushType), "function", block); + CallInst::Create(pushPtr, {stateArg}, "", block); + + BranchInst::Create(loop, block); + } + + { + block = full; + + const auto good = BasicBlock::Create(context, "good", ctx.Func); + + const auto extractFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Extract)); + const auto extractType = FunctionType::get(outputPtrType, {stateArg->getType()}, false); + const auto extractPtr = CastInst::Create(Instruction::IntToPtr, extractFunc, PointerType::getUnqual(extractType), "extract", block); + const auto out = CallInst::Create(extractPtr, {stateArg}, "out", block); + const auto has = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, out, ConstantPointerNull::get(outputPtrType), "has", block); + + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block); + + BranchInst::Create(good, over, has, block); + + block = good; + + new StoreInst(out, outs, block); + + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block); + BranchInst::Create(over, block); + } + + block = over; + return {result, std::move(getters)}; + } +#endif +private: + void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state, ui64 count, const bool* directions) const { +#ifdef MKQL_DISABLE_CODEGEN + state = ctx.HolderFactory.Create<TState>(count, directions, TMyValueCompare(KeyTypes), Indexes); +#else + state = ctx.HolderFactory.Create<TState>(count, directions, ctx.ExecuteLLVM && Compare ? TCompareFunc(Compare) : TCompareFunc(TMyValueCompare(KeyTypes)), Indexes); +#endif + } + + void RegisterDependencies() const final { + if (const auto flow = this->FlowDependsOn(Flow)) { + TWideTopWrapper::DependsOn(flow, Count); + std::for_each(Directions.cbegin(), Directions.cend(), std::bind(&TWideTopWrapper::DependsOn, flow, std::placeholders::_1)); + } + } + + IComputationWideFlowNode *const Flow; + IComputationNode *const Count; + const TComputationNodePtrVector Directions; + const TKeyTypes KeyTypes; + const std::vector<ui32> Indexes; + const std::vector<EValueRepresentation> Representations; +#ifndef MKQL_DISABLE_CODEGEN + TComparePtr Compare = nullptr; + + Function* CompareFunc = nullptr; + + TString MakeName() const { + TStringStream out; + out << this->DebugString() << "::Compare_(" << static_cast<const void*>(this) << ")."; + return out.Str(); + } + + void FinalizeFunctions(const NYql::NCodegen::ICodegen::TPtr& codegen) final { + if (CompareFunc) { + Compare = reinterpret_cast<TComparePtr>(codegen->GetPointerToFunction(CompareFunc)); + } + } + + void GenerateFunctions(const NYql::NCodegen::ICodegen::TPtr& codegen) final { + codegen->ExportSymbol(CompareFunc = GenerateCompareFunction(codegen, MakeName(), KeyTypes)); + } +#endif +}; + +} + +template<bool Sort> +IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() > 2U && !(callable.GetInputsCount() % 2U), "Expected more arguments."); + + const auto flow = LocateNode(ctx.NodeLocator, callable, 0); + const auto count = LocateNode(ctx.NodeLocator, callable, 1); + const auto keyWidth = (callable.GetInputsCount() >> 1U) - 1U; + const auto inputType = AS_TYPE(TTupleType, AS_TYPE(TFlowType, callable.GetType()->GetReturnType())->GetItemType()); + std::vector<ui32> indexes(inputType->GetElementsCount()); + std::iota(indexes.begin(), indexes.end(), 0U); + + TKeyTypes keyTypes(keyWidth); + for (auto i = 0U; i < keyTypes.size(); ++i) { + const auto keyIndex = AS_VALUE(TDataLiteral, callable.GetInput((i + 1U) << 1U))->AsValue().Get<ui32>(); + std::swap(indexes[i], indexes[indexes[keyIndex]]); + keyTypes[i].first = *UnpackOptionalData(inputType->GetElementType(keyIndex), keyTypes[i].second)->GetDataSlot(); + } + + std::vector<EValueRepresentation> representations(inputType->GetElementsCount()); + for (auto i = 0U; i < representations.size(); ++i) + representations[i] = GetValueRepresentation(inputType->GetElementType(indexes[i])); + + TComputationNodePtrVector directions(keyWidth); + auto index = 1U; + std::generate(directions.begin(), directions.end(), [&](){ return LocateNode(ctx.NodeLocator, callable, ++++index); }); + + if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) { + return new TWideTopWrapper<Sort>(ctx.Mutables, wide, count, std::move(directions), std::move(keyTypes), std::move(indexes), std::move(representations)); + } + + THROW yexception() << "Expected wide flow."; +} + +IComputationNode* WrapWideTop(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + return WrapWideTopT<false>(callable, ctx); +} + +IComputationNode* WrapWideTopSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + return WrapWideTopT<true>(callable, ctx); +} + +} +} diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h new file mode 100644 index 0000000000..d31c3d8553 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.h @@ -0,0 +1,14 @@ +#pragma once + +#include <ydb/library/yql/minikql/computation/mkql_computation_node.h> + +namespace NKikimr { +namespace NMiniKQL { + +IComputationNode* WrapWideTop(TCallable& callable, const TComputationNodeFactoryContext& ctx); +IComputationNode* WrapWideTopSort(TCallable& callable, const TComputationNodeFactoryContext& ctx); + +} +} + + diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp new file mode 100644 index 0000000000..aa91f211d1 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_top_sort_ut.cpp @@ -0,0 +1,568 @@ +#include "mkql_computation_node_ut.h" +#include <ydb/library/yql/minikql/mkql_runtime_version.h> + +#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> + +#include <cstring> + +namespace NKikimr { +namespace NMiniKQL { + +#if !defined(MKQL_RUNTIME_VERSION) || MKQL_RUNTIME_VERSION >= 33u +Y_UNIT_TEST_SUITE(TMiniKQLWideTopTest) { + Y_UNIT_TEST_LLVM(TopByFirstKeyAsc) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id); + const auto tupleType = pb.NewTupleType({dataType, dataType}); + + const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one"); + const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two"); + + const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one"); + const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two"); + + const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1"); + const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2"); + const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3"); + const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4"); + const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5"); + const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6"); + const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7"); + const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8"); + const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9"); + + const auto data1 = pb.NewTuple(tupleType, {keyOne, value1}); + + const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2}); + const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3}); + + const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4}); + + const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5}); + const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6}); + const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7}); + const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8}); + const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9}); + + const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9}); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideTop(pb.ExpandMap(pb.ToFlow(list), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }), + pb.NewDataLiteral<ui64>(4ULL), {{0U, pb.NewDataLiteral<bool>(true)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 4"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 3"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 2"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 1"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } + + Y_UNIT_TEST_LLVM(TopByFirstKeyDesc) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id); + const auto tupleType = pb.NewTupleType({dataType, dataType}); + + const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one"); + const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two"); + + const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one"); + const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two"); + + const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1"); + const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2"); + const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3"); + const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4"); + const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5"); + const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6"); + const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7"); + const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8"); + const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9"); + + const auto data1 = pb.NewTuple(tupleType, {keyOne, value1}); + + const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2}); + const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3}); + + const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4}); + + const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5}); + const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6}); + const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7}); + const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8}); + const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9}); + + const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9}); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideTop(pb.ExpandMap(pb.ToFlow(list), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }), + pb.NewDataLiteral<ui64>(6ULL), {{0U, pb.NewDataLiteral<bool>(false)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 4"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 9"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 6"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 7"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 8"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 5"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } + + Y_UNIT_TEST_LLVM(TopBySecondKeyAsc) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id); + const auto tupleType = pb.NewTupleType({dataType, dataType}); + + const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one"); + const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two"); + + const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one"); + const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two"); + + const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1"); + const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2"); + const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3"); + const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4"); + const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5"); + const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6"); + const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7"); + const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8"); + const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9"); + + const auto data1 = pb.NewTuple(tupleType, {keyOne, value1}); + + const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2}); + const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3}); + + const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4}); + + const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5}); + const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6}); + const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7}); + const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8}); + const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9}); + + const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9}); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideTop(pb.ExpandMap(pb.ToFlow(list), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }), + pb.NewDataLiteral<ui64>(3ULL), {{1U, pb.NewDataLiteral<bool>(true)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 3"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 2"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 1"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } + + Y_UNIT_TEST_LLVM(TopBySecondKeyDesc) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id); + const auto tupleType = pb.NewTupleType({dataType, dataType}); + + const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one"); + const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two"); + + const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one"); + const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two"); + + const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1"); + const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2"); + const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3"); + const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4"); + const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5"); + const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6"); + const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7"); + const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8"); + const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9"); + + const auto data1 = pb.NewTuple(tupleType, {keyOne, value1}); + + const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2}); + const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3}); + + const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4}); + + const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5}); + const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6}); + const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7}); + const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8}); + const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9}); + + const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9}); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideTop(pb.ExpandMap(pb.ToFlow(list), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }), + pb.NewDataLiteral<ui64>(2ULL), {{1U, pb.NewDataLiteral<bool>(false)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 8"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 9"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } + + Y_UNIT_TEST_LLVM(TopSortByFirstSecondAscDesc) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id); + const auto tupleType = pb.NewTupleType({dataType, dataType}); + + const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one"); + const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two"); + + const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one"); + const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two"); + + const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1"); + const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2"); + const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3"); + const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4"); + const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5"); + const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6"); + const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7"); + const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8"); + const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9"); + + const auto data1 = pb.NewTuple(tupleType, {keyOne, value1}); + + const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2}); + const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3}); + + const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4}); + + const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5}); + const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6}); + const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7}); + const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8}); + const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9}); + + const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9}); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideTopSort(pb.ExpandMap(pb.ToFlow(list), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }), + pb.NewDataLiteral<ui64>(4ULL), {{0U, pb.NewDataLiteral<bool>(true)}, {1U, pb.NewDataLiteral<bool>(false)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 1"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 3"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 2"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 4"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } + + Y_UNIT_TEST_LLVM(TopSortByFirstSecondDescAsc) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id); + const auto tupleType = pb.NewTupleType({dataType, dataType}); + + const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one"); + const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two"); + + const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one"); + const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two"); + + const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1"); + const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2"); + const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3"); + const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4"); + const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5"); + const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6"); + const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7"); + const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8"); + const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9"); + + const auto data1 = pb.NewTuple(tupleType, {keyOne, value1}); + + const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2}); + const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3}); + + const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4}); + + const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5}); + const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6}); + const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7}); + const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8}); + const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9}); + + const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9}); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideTopSort(pb.ExpandMap(pb.ToFlow(list), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }), + pb.NewDataLiteral<ui64>(6ULL), {{0U, pb.NewDataLiteral<bool>(false)}, {1U, pb.NewDataLiteral<bool>(true)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 5"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 6"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 7"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 8"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 9"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 4"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } + + Y_UNIT_TEST_LLVM(TopSortBySecondFirstAscDesc) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id); + const auto tupleType = pb.NewTupleType({dataType, dataType}); + + const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one"); + const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two"); + + const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one"); + const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two"); + + const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1"); + const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2"); + const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3"); + const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4"); + const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5"); + const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6"); + const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7"); + const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8"); + const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9"); + + const auto data1 = pb.NewTuple(tupleType, {keyOne, value1}); + + const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2}); + const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3}); + + const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4}); + + const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5}); + const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6}); + const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7}); + const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8}); + const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9}); + + const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9}); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideTopSort(pb.ExpandMap(pb.ToFlow(list), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }), + pb.NewDataLiteral<ui64>(4ULL), {{1U, pb.NewDataLiteral<bool>(true)}, {0U, pb.NewDataLiteral<bool>(false)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 1"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 2"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 3"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 4"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } + + Y_UNIT_TEST_LLVM(TopSortBySecondFirstDescAsc) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<const char*>::Id); + const auto tupleType = pb.NewTupleType({dataType, dataType}); + + const auto keyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("key one"); + const auto keyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("key two"); + + const auto longKeyOne = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key one"); + const auto longKeyTwo = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long key two"); + + const auto value1 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 1"); + const auto value2 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 2"); + const auto value3 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 3"); + const auto value4 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 4"); + const auto value5 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 5"); + const auto value6 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 6"); + const auto value7 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 7"); + const auto value8 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 8"); + const auto value9 = pb.NewDataLiteral<NUdf::EDataSlot::String>("very long value 9"); + + const auto data1 = pb.NewTuple(tupleType, {keyOne, value1}); + + const auto data2 = pb.NewTuple(tupleType, {keyTwo, value2}); + const auto data3 = pb.NewTuple(tupleType, {keyTwo, value3}); + + const auto data4 = pb.NewTuple(tupleType, {longKeyOne, value4}); + + const auto data5 = pb.NewTuple(tupleType, {longKeyTwo, value5}); + const auto data6 = pb.NewTuple(tupleType, {longKeyTwo, value6}); + const auto data7 = pb.NewTuple(tupleType, {longKeyTwo, value7}); + const auto data8 = pb.NewTuple(tupleType, {longKeyTwo, value8}); + const auto data9 = pb.NewTuple(tupleType, {longKeyTwo, value9}); + + const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9}); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideTopSort(pb.ExpandMap(pb.ToFlow(list), + [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }), + pb.NewDataLiteral<ui64>(6ULL), {{1U, pb.NewDataLiteral<bool>(false)}, {0U, pb.NewDataLiteral<bool>(true)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 9"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 8"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 7"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 6"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key two"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 5"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(0), "very long key one"); + UNBOXED_VALUE_STR_EQUAL(item.GetElement(1), "very long value 4"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } + + Y_UNIT_TEST_LLVM(TopSortLargeList) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto minusday = i64(-24LL * 60LL * 60LL * 1000000LL); // -1 Day + const auto step = pb.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&minusday, sizeof(minusday))); + const auto list = pb.ListFromRange(pb.NewTzDataLiteral<NUdf::TTzDate>(30000u, 42u), pb.NewTzDataLiteral<NUdf::TTzDate>(10000u, 42u), step); + + const auto pgmReturn = pb.Collect(pb.NarrowMap(pb.WideTopSort(pb.ExpandMap(pb.ToFlow(pb.Enumerate(list)), + [&](TRuntimeNode item) -> TRuntimeNode::TList { + const auto utf = pb.ToString<true>(pb.Nth(item, 1U)); + const auto day = pb.StrictFromString(pb.Substring(utf, pb.NewDataLiteral<ui32>(8U), pb.NewDataLiteral<ui32>(2U)), pb.NewDataType(NUdf::EDataSlot::Uint8)); + return {utf, pb.Nth(item, 0U), day, pb.Nth(item, 1U)}; + }), + pb.NewDataLiteral<ui64>(7ULL), {{2U, pb.NewDataLiteral<bool>(true)}, {1U, pb.NewDataLiteral<bool>(false)}}), + [&](TRuntimeNode::TList items) -> TRuntimeNode { return items.front(); } + )); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item, "1997-06-01,Africa/Mbabane"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item, "1997-07-01,Africa/Mbabane"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item, "1997-08-01,Africa/Mbabane"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item, "1997-09-01,Africa/Mbabane"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item, "1997-10-01,Africa/Mbabane"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item, "1997-11-01,Africa/Mbabane"); + UNIT_ASSERT(iterator.Next(item)); + UNBOXED_VALUE_STR_EQUAL(item, "1997-12-01,Africa/Mbabane"); + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } +} +#endif +} +} + diff --git a/ydb/library/yql/minikql/computation/mkql_computation_node_codegen.cpp b/ydb/library/yql/minikql/computation/mkql_computation_node_codegen.cpp index 4d700f2f01..3479db37b5 100644 --- a/ydb/library/yql/minikql/computation/mkql_computation_node_codegen.cpp +++ b/ydb/library/yql/minikql/computation/mkql_computation_node_codegen.cpp @@ -235,6 +235,10 @@ Value* GetPointerFromUnboxed(Value* value, const TCodegenContext& ctx, BasicBloc } } +ui32 MyCompareStrings(NUdf::TUnboxedValuePod lhs, NUdf::TUnboxedValuePod rhs) { + return NUdf::CompareStrings(lhs, rhs); +} + bool MyEquteStrings(NUdf::TUnboxedValuePod lhs, NUdf::TUnboxedValuePod rhs) { return NUdf::EquateStrings(lhs, rhs); } @@ -322,6 +326,143 @@ Value* GenEqualsFunction(NUdf::EDataSlot slot, bool isOptional, Value* lv, Value return isOptional ? GenEqualsFunction<true>(slot, lv, rv, ctx, block) : GenEqualsFunction<false>(slot, lv, rv, ctx, block); } +template <bool IsOptional> +Value* GenCompareFunction(NUdf::EDataSlot slot, Value* lv, Value* rv, TCodegenContext& ctx, BasicBlock*& block); + +template <> +Value* GenCompareFunction<false>(NUdf::EDataSlot slot, Value* lv, Value* rv, TCodegenContext& ctx, BasicBlock*& block) { + auto& context = ctx.Codegen->GetContext(); + + const auto& info = NUdf::GetDataTypeInfo(slot); + + if ((info.Features & NUdf::EDataTypeFeatures::CommonType) && (info.Features & NUdf::EDataTypeFeatures::StringType || NUdf::EDataSlot::Uuid == slot || NUdf::EDataSlot::DyNumber == slot)) { + return CallBinaryUnboxedValueFunction(&MyCompareStrings, Type::getInt32Ty(context), lv, rv, ctx.Codegen, block); + } + + const bool extra = info.Features & (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::TzDateType); + const auto resultType = Type::getInt32Ty(context); + + const auto exit = BasicBlock::Create(context, "exit", ctx.Func); + const auto test = BasicBlock::Create(context, "test", ctx.Func); + + const auto res = PHINode::Create(resultType, extra ? 3U : 2U, "result", exit); + + const auto lhs = GetterFor(slot, lv, context, block); + const auto rhs = GetterFor(slot, rv, context, block); + + if (info.Features & NUdf::EDataTypeFeatures::FloatType) { + const auto more = BasicBlock::Create(context, "more", ctx.Func); + const auto next = BasicBlock::Create(context, "next", ctx.Func); + + const auto uno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, lhs, rhs, "unorded", block); + + BranchInst::Create(more, next, uno, block); + block = more; + + const auto luno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(lhs->getType(), 0.0), lhs, "luno", block); + const auto runo = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(rhs->getType(), 0.0), rhs, "runo", block); + const auto once = BinaryOperator::CreateXor(luno, runo, "xor", block); + + const auto left = SelectInst::Create(luno, ConstantInt::get(resultType, 1), ConstantInt::get(resultType, -1), "left", block); + const auto both = SelectInst::Create(once, left, ConstantInt::get(resultType, 0), "both", block); + + res->addIncoming(both, block); + BranchInst::Create(exit, block); + + block = next; + } + + const auto equals = info.Features & NUdf::EDataTypeFeatures::FloatType ? + CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_OEQ, lhs, rhs, "equals", block): + CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, lhs, rhs, "equals", block); + + if (info.Features & NUdf::EDataTypeFeatures::TzDateType) { + const auto more = BasicBlock::Create(context, "more", ctx.Func); + const auto next = BasicBlock::Create(context, "next", ctx.Func); + + BranchInst::Create(more, test, equals, block); + + block = more; + + const auto ltz = GetterForTimezone(context, lv, block); + const auto rtz = GetterForTimezone(context, rv, block); + const auto tzeq = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, ltz, rtz, "tzeq", block); + res->addIncoming(ConstantInt::get(resultType, 0), block); + BranchInst::Create(exit, next, tzeq, block); + + block = next; + const auto tzlt = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, ltz, rtz, "tzlt", block); + const auto tzout = SelectInst::Create(tzlt, ConstantInt::get(resultType, -1), ConstantInt::get(resultType, 1), "tzout", block); + res->addIncoming(tzout, block); + BranchInst::Create(exit, block); + } else { + res->addIncoming(ConstantInt::get(resultType, 0), block); + BranchInst::Create(exit, test, equals, block); + } + + block = test; + + const auto less = info.Features & NUdf::EDataTypeFeatures::FloatType ? + CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_OLT, lhs, rhs, "less", block): // float + info.Features & (NUdf::EDataTypeFeatures::SignedIntegralType | NUdf::EDataTypeFeatures::TimeIntervalType | NUdf::EDataTypeFeatures::DecimalType) ? + CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, lhs, rhs, "less", block): // signed + info.Features & (NUdf::EDataTypeFeatures::UnsignedIntegralType | NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType) ? + CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, lhs, rhs, "less", block): // unsigned + rhs; // bool + + const auto out = SelectInst::Create(less, ConstantInt::get(resultType, -1), ConstantInt::get(resultType, 1), "out", block); + res->addIncoming(out, block); + BranchInst::Create(exit, block); + + block = exit; + return res; +} + +template <> +Value* GenCompareFunction<true>(NUdf::EDataSlot slot, Value* lv, Value* rv, TCodegenContext& ctx, BasicBlock*& block) { + auto& context = ctx.Codegen->GetContext(); + + const auto tiny = BasicBlock::Create(context, "tiny", ctx.Func); + const auto side = BasicBlock::Create(context, "side", ctx.Func); + const auto test = BasicBlock::Create(context, "test", ctx.Func); + const auto done = BasicBlock::Create(context, "done", ctx.Func); + + const auto resultType = Type::getInt32Ty(context); + const auto res = PHINode::Create(resultType, 2U, "result", done); + + const auto le = IsEmpty(lv, block); + const auto re = IsEmpty(rv, block); + + const auto any = BinaryOperator::CreateOr(le, re, "or", block); + + BranchInst::Create(tiny, test, any, block); + + block = tiny; + + const auto both = BinaryOperator::CreateAnd(le, re, "and", block); + res->addIncoming(ConstantInt::get(resultType, 0), block); + BranchInst::Create(done, side, both, block); + + block = side; + + const auto out = SelectInst::Create(le, ConstantInt::get(resultType, -1), ConstantInt::get(resultType, 1), "out", block); + res->addIncoming(out, block); + BranchInst::Create(done, block); + + block = test; + + const auto comp = GenCompareFunction<false>(slot, lv, rv, ctx, block); + res->addIncoming(comp, block); + BranchInst::Create(done, block); + + block = done; + return res; +} + +Value* GenCompareFunction(NUdf::EDataSlot slot, bool isOptional, Value* lv, Value* rv, TCodegenContext& ctx, BasicBlock*& block) { + return isOptional ? GenCompareFunction<true>(slot, lv, rv, ctx, block) : GenCompareFunction<false>(slot, lv, rv, ctx, block); +} + Value* GenCombineHashes(Value* first, Value* second, BasicBlock* block) { // key += ~(key << 32); const auto x01 = BinaryOperator::CreateShl(first, ConstantInt::get(first->getType(), 32), "x01", block); @@ -774,6 +915,70 @@ Function* GenerateHashFunction(const NYql::NCodegen::ICodegen::TPtr& codegen, co return ctx.Func; } +Function* GenerateCompareFunction(const NYql::NCodegen::ICodegen::TPtr& codegen, const TString& name, const TKeyTypes& types) { + auto& module = codegen->GetModule(); + if (const auto f = module.getFunction(name.c_str())) + return f; + + auto& context = codegen->GetContext(); + const auto valueType = Type::getInt128Ty(context); + const auto elementsType = ArrayType::get(valueType, types.size()); + const auto ptrType = PointerType::getUnqual(elementsType); + const auto ptrDirsType = PointerType::getUnqual(ArrayType::get(Type::getInt1Ty(context), types.size())); + const auto returnType = Type::getInt32Ty(context); + + const auto funcType = FunctionType::get(returnType, {ptrDirsType, ptrType, ptrType}, false); + + TCodegenContext ctx(codegen); + ctx.AlwaysInline = true; + ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee()); + + auto args = ctx.Func->arg_begin(); + + const auto main = BasicBlock::Create(context, "main", ctx.Func); + auto block = main; + + const auto dp = &*args; + const auto lv = &*++args; + const auto rv = &*++args; + + if (types.empty()) { + ReturnInst::Create(context, ConstantInt::get(returnType, 0), block); + return ctx.Func; + } + + const auto directions = new LoadInst(dp, "directions", block); + const auto elementsOne = new LoadInst(lv, "elements_one", block); + const auto elementsTwo = new LoadInst(rv, "elements_two", block); + const auto zero = ConstantInt::get(returnType, 0); + + for (auto i = 0U; i < types.size(); ++i) { + const auto nextOne = ExtractValueInst::Create(elementsOne, i, (TString("next_one_") += ToString(i)).c_str(), block); + const auto nextTwo = ExtractValueInst::Create(elementsTwo, i, (TString("next_two_") += ToString(i)).c_str(), block); + + const auto exit = BasicBlock::Create(context, (TString("exit_") += ToString(i)).c_str(), ctx.Func); + const auto step = BasicBlock::Create(context, (TString("step_") += ToString(i)).c_str(), ctx.Func); + + const auto test = GenCompareFunction(types[i].first, types[i].second, nextOne, nextTwo, ctx, block); + const auto skip = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, zero, test, (TString("skip_") += ToString(i)).c_str(), block); + + BranchInst::Create(step, exit, skip, block); + + block = exit; + + const auto dir = ExtractValueInst::Create(directions, i, (TString("dir_") += ToString(i)).c_str(), block); + const auto neg = BinaryOperator::CreateNeg(test, (TString("neg_") += ToString(i)).c_str(), block); + const auto out = SelectInst::Create(dir, test, neg, (TString("neg_") += ToString(i)).c_str(), block); + + ReturnInst::Create(context, out, block); + + block = step; + } + + ReturnInst::Create(context, zero, block); + return ctx.Func; +} + void GenInvalidate(const TCodegenContext& ctx, const std::vector<std::pair<ui32, EValueRepresentation>>& invalidationSet, BasicBlock*& block) { auto& context = ctx.Codegen->GetContext(); const auto indexType = Type::getInt32Ty(context); diff --git a/ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h b/ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h index 77230e4825..fa10a28e6f 100644 --- a/ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h +++ b/ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h @@ -401,6 +401,7 @@ Function* GenerateHashFunction(const NYql::NCodegen::ICodegen::TPtr& codegen, co Function* GenerateEqualsFunction(const NYql::NCodegen::ICodegen::TPtr& codegen, const TString& name, const TKeyTypes& types); Function* GenerateHashFunction(const NYql::NCodegen::ICodegen::TPtr& codegen, const TString& name, const TKeyTypes& types); +Function* GenerateCompareFunction(const NYql::NCodegen::ICodegen::TPtr& codegen, const TString& name, const TKeyTypes& types); template <typename TDerived> class TDecoratorCodegeneratorNode: public TDecoratorComputationNode<TDerived>, public ICodegeneratorInlineNode diff --git a/ydb/library/yql/minikql/computation/mkql_computation_node_holders.h b/ydb/library/yql/minikql/computation/mkql_computation_node_holders.h index 74b769f180..be60525110 100644 --- a/ydb/library/yql/minikql/computation/mkql_computation_node_holders.h +++ b/ydb/library/yql/minikql/computation/mkql_computation_node_holders.h @@ -46,8 +46,7 @@ using TUnboxedValueDeque = std::deque<NUdf::TUnboxedValue, TMKQLAllocator<NUdf:: using TKeyPayloadPair = std::pair<NUdf::TUnboxedValue, NUdf::TUnboxedValue>; using TKeyPayloadPairVector = std::vector<TKeyPayloadPair, TMKQLAllocator<TKeyPayloadPair>>; -inline int CompareValues(NUdf::EDataSlot type, - bool asc, bool isOptional, const NUdf::TUnboxedValuePod& lhs, const NUdf::TUnboxedValuePod& rhs) { +inline int CompareValues(NUdf::EDataSlot type, bool asc, bool isOptional, const NUdf::TUnboxedValuePod& lhs, const NUdf::TUnboxedValuePod& rhs) { int cmp; if (isOptional) { if (!lhs && !rhs) { @@ -74,6 +73,16 @@ inline int CompareValues(NUdf::EDataSlot type, return cmp; } +inline int CompareValues(const NUdf::TUnboxedValuePod* left, const NUdf::TUnboxedValuePod* right, const TKeyTypes& types, const bool* directions) { + for (ui32 i = 0; i < types.size(); ++i) { + if (const auto cmp = CompareValues(types[i].first, directions[i], types[i].second, left[i], right[i])) { + return cmp; + } + } + + return 0; +} + inline int CompareKeys(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right, const TKeyTypes& types, bool isTuple) { if (isTuple) { if (left && right) diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index a4f25520f0..fd6f946229 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -1661,6 +1661,35 @@ TRuntimeNode TProgramBuilder::Sort(TRuntimeNode list, TRuntimeNode ascending, co return BuildSort(__func__, list, ascending, keyExtractor); } +TRuntimeNode TProgramBuilder::WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) +{ + return BuildWideTop(__func__, flow, count, keys); +} + +TRuntimeNode TProgramBuilder::WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) +{ + return BuildWideTop(__func__, flow, count, keys); +} + +TRuntimeNode TProgramBuilder::BuildWideTop(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { + if constexpr (RuntimeVersion < 33U) { + THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << callableName; + } + + const auto width = AS_TYPE(TTupleType, AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType())->GetElementsCount(); + MKQL_ENSURE(!keys.empty() && keys.size() <= width, "Unexpected keys count: " << keys.size()); + + TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType()); + callableBuilder.Add(flow); + callableBuilder.Add(count); + std::for_each(keys.cbegin(), keys.cend(), [&](const std::pair<ui32, TRuntimeNode>& key) { + MKQL_ENSURE(key.first < width, "Key index too large: " << key.first); + callableBuilder.Add(NewDataLiteral(key.first)); + callableBuilder.Add(key.second); + }); + return TRuntimeNode(callableBuilder.Build(), false); +} + TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) { if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) { diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index 24ec38a1d0..d8012bd063 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -403,6 +403,9 @@ public: TRuntimeNode WideLastCombiner(TRuntimeNode flow, const TWideLambda& keyExtractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish); TRuntimeNode WideCondense1(TRuntimeNode stream, const TWideLambda& init, const TWideSwitchLambda& switcher, const TBinaryWideLambda& handler, bool useCtx = false); + TRuntimeNode WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); + TRuntimeNode WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); + TRuntimeNode Length(TRuntimeNode listOrDict); TRuntimeNode Iterator(TRuntimeNode list, const TArrayRef<const TRuntimeNode>& dependentNodes); TRuntimeNode EmptyIterator(TType* streamType); @@ -718,6 +721,8 @@ private: TRuntimeNode BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members, const std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>>& filteredItems); + TRuntimeNode BuildWideTop(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); + TRuntimeNode InvokeBinary(const std::string_view& callableName, TType* type, TRuntimeNode data1, TRuntimeNode data2); TRuntimeNode AggrCompare(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2); TRuntimeNode DataCompare(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2); diff --git a/ydb/library/yql/minikql/mkql_runtime_version.h b/ydb/library/yql/minikql/mkql_runtime_version.h index 043e54c485..601e40788e 100644 --- a/ydb/library/yql/minikql/mkql_runtime_version.h +++ b/ydb/library/yql/minikql/mkql_runtime_version.h @@ -24,7 +24,7 @@ namespace NMiniKQL { // 1. Bump this version every time incompatible runtime nodes are introduced. // 2. Make sure you provide runtime node generation for previous runtime versions. #ifndef MKQL_RUNTIME_VERSION -#define MKQL_RUNTIME_VERSION 32U +#define MKQL_RUNTIME_VERSION 33U #endif // History: |