diff options
author | vvvv <vvvv@ydb.tech> | 2023-02-22 17:06:34 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2023-02-22 17:06:34 +0300 |
commit | 65963701e9f4d71fb82dd81e0987021612296adb (patch) | |
tree | 9d2ed228ce96011d6472c1a01286657c6d0945e0 | |
parent | 91daac67540e7734dc2e7c8f198eba48d7ea890e (diff) | |
download | ydb-65963701e9f4d71fb82dd81e0987021612296adb.tar.gz |
block sort
5 files changed, 154 insertions, 33 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index b3fa592f171..53941e0cc6e 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -7280,7 +7280,7 @@ struct TPeepHoleRules { {"BlockCombineHashed", &OptimizeBlockCombine}, {"WideTop", &OptimizeTopOrSortBlocks}, {"WideTopSort", &OptimizeTopOrSortBlocks}, - //{"WideSort", &OptimizeTopOrSortBlocks}, + {"WideSort", &OptimizeTopOrSortBlocks}, }; TPeepHoleRules() diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_top.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_top.cpp index d50b5a1c76e..cfc075e855d 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_top.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_top.cpp @@ -19,17 +19,19 @@ namespace NMiniKQL { namespace { -class TTopBlocksWrapper : public TStatefulWideFlowBlockComputationNode<TTopBlocksWrapper> { +template <bool Sort, bool HasCount> +class TTopOrSortBlocksWrapper : public TStatefulWideFlowBlockComputationNode<TTopOrSortBlocksWrapper<Sort, HasCount>> { + using TSelf = TTopOrSortBlocksWrapper<Sort, HasCount>; + using TBase = TStatefulWideFlowBlockComputationNode<TSelf>; using TChunkedArrayIndex = TVector<IArrayBuilder::TArrayDataItem>; public: - TTopBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TTupleType* tupleType, IComputationNode* count, - TVector<IComputationNode*>&& directions, TVector<ui32>&& keyIndicies, bool sort) - : TStatefulWideFlowBlockComputationNode(mutables, flow, tupleType->GetElementsCount()) + TTopOrSortBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TTupleType* tupleType, IComputationNode* count, + TVector<IComputationNode*>&& directions, TVector<ui32>&& keyIndicies) + : TBase(mutables, flow, tupleType->GetElementsCount()) , Flow_(flow) , Count_(count) , Directions_(std::move(directions)) , KeyIndicies_(std::move(keyIndicies)) - , Sort_(sort) { for (ui32 i = 0; i < tupleType->GetElementsCount() - 1; ++i) { Columns_.push_back(AS_TYPE(TBlockType, tupleType->GetElementType(i))); @@ -45,7 +47,10 @@ public: } if (!s.PreparedCountAndDirections_) { - s.Count_ = Count_->GetValue(ctx).Get<ui64>(); + if constexpr (HasCount) { + s.Count_ = Count_->GetValue(ctx).Get<ui64>(); + } + for (ui32 k = 0; k < KeyIndicies_.size(); ++k) { s.Directions_[k] = Directions_[k]->GetValue(ctx).Get<bool>(); } @@ -53,9 +58,11 @@ public: s.PreparedCountAndDirections_ = true; } - if (!s.Count_) { - s.IsFinished_ = true; - return EFetchResult::Finish; + if constexpr (HasCount) { + if (!s.Count_) { + s.IsFinished_ = true; + return EFetchResult::Finish; + } } if (!s.PreparedBuilders_) { @@ -63,11 +70,23 @@ public: s.PreparedBuilders_ = true; } + if (s.WritingOutput_) { + if (s.Written_ >= s.OutputLength_) { + s.IsFinished_ = true; + return EFetchResult::Finish; + } + + s.FillSortOutputPart(Columns_, output, ctx); + return EFetchResult::One; + } + for (;;) { auto result = Flow_->FetchValues(ctx, s.ValuePointers_.data()); if (result == EFetchResult::Yield) { return result; } else if (result == EFetchResult::One) { + ui64 blockLen = TArrowBlock::From(s.Values_.back()).GetDatum().template scalar_as<arrow::UInt64Scalar>().value; + if (!s.ScalarsFilled_) { for (ui32 i = 0; i < Columns_.size(); ++i) { if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { @@ -78,7 +97,17 @@ public: s.ScalarsFilled_ = true; } - ui64 blockLen = TArrowBlock::From(s.Values_.back()).GetDatum().scalar_as<arrow::UInt64Scalar>().value; + if constexpr (!HasCount) { + for (ui32 i = 0; i < Columns_.size(); ++i) { + auto datum = TArrowBlock::From(s.Values_[i]).GetDatum(); + if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) { + s.SortInput_[i].emplace_back(datum); + } + } + + s.OutputLength_ += blockLen; + continue; + } // shrink input block TMaybe<TVector<ui64>> blockIndicies; @@ -108,13 +137,25 @@ public: } } else { + if constexpr (!HasCount) { + if (!s.OutputLength_) { + s.IsFinished_ = true; + return EFetchResult::Finish; + } + + s.SortAll(Columns_, KeyIndicies_); + s.WritingOutput_ = true; + s.FillSortOutputPart(Columns_, output, ctx); + return EFetchResult::One; + } + s.IsFinished_ = true; if (!s.BuilderLength_) { return EFetchResult::Finish; } - if (s.BuilderLength_ > s.Count_ || Sort_) { - s.CompressBuilders(Sort_, Columns_, KeyIndicies_, ctx); + if (s.BuilderLength_ > s.Count_ || Sort) { + s.CompressBuilders(Sort, Columns_, KeyIndicies_, ctx); } s.FillOutput(Columns_, output, ctx); @@ -125,17 +166,27 @@ public: private: void RegisterDependencies() const final { - if (const auto flow = FlowDependsOn(Flow_)) { - DependsOn(flow, Count_); + if (const auto flow = this->FlowDependsOn(Flow_)) { + if constexpr (HasCount) { + this->DependsOn(flow, Count_); + } for (auto dir : Directions_) { - DependsOn(flow, dir); + this->DependsOn(flow, dir); } } } class TState : public TComputationValue<TState> { + using TBase = TComputationValue<TState>; public: + bool WritingOutput_ = false; + ui64 OutputLength_ = 0; + ui64 Written_ = 0; + TVector<TVector<arrow::Datum>> SortInput_; + TVector<ui64> SortPermutation_; + TVector<TChunkedArrayIndex> SortArrays_; + bool IsFinished_ = false; bool PreparedCountAndDirections_ = false; ui64 Count_ = 0; @@ -154,7 +205,7 @@ private: TVector<NUdf::TUnboxedValue*> ValuePointers_; TState(TMemoryUsageInfo* memInfo, const TVector<ui32>& keyIndicies, const TVector<TBlockType*>& columns) - : TComputationValue(memInfo) + : TBase(memInfo) { Directions_.resize(keyIndicies.size()); LeftReaders_.resize(columns.size()); @@ -179,6 +230,9 @@ private: for (ui32 k = 0; k < keyIndicies.size(); ++k) { Comparators_[k] = NUdf::MakeBlockItemComparator(TTypeInfoHelper(), columns[keyIndicies[k]]->GetItemType()); } + + SortInput_.resize(columns.size()); + SortArrays_.resize(columns.size()); } ui64 GetStorageLength() const { @@ -188,14 +242,17 @@ private: void AllocateBuilders(const TVector<TBlockType*>& columns, TComputationContext& ctx) { BuilderMaxLength_ = GetStorageLength(); + size_t maxBlockItemSize = 0; for (ui32 i = 0; i < columns.size(); ++i) { if (columns[i]->GetShape() == TBlockType::EShape::Scalar) { continue; } - BuilderMaxLength_ = Max(BuilderMaxLength_, CalcBlockLen(CalcMaxBlockItemSize(columns[i]->GetItemType()))); + maxBlockItemSize = Max(maxBlockItemSize, CalcMaxBlockItemSize(columns[i]->GetItemType())); }; + BuilderMaxLength_ = Max(BuilderMaxLength_, CalcBlockLen(maxBlockItemSize)); + for (ui32 i = 0; i < columns.size(); ++i) { if (columns[i]->GetShape() == TBlockType::EShape::Scalar) { continue; @@ -249,6 +306,58 @@ private: BuilderLength_ = blockLen; } + void SortAll(const TVector<TBlockType*>& columns, const TVector<ui32>& keyIndicies) { + SortPermutation_.reserve(OutputLength_); + for (ui64 i = 0; i < OutputLength_; ++i) { + SortPermutation_.emplace_back(i); + } + + for (ui32 i = 0; i < columns.size(); ++i) { + ui64 offset = 0; + for (const auto& datum : SortInput_[i]) { + if (datum.is_scalar()) { + continue; + } else if (datum.is_array()) { + auto arrayData = datum.array(); + SortArrays_[i].push_back({ arrayData.get(), offset }); + offset += arrayData->length; + } else { + auto chunks = datum.chunks(); + for (auto& chunk : chunks) { + auto arrayData = chunk->data(); + SortArrays_[i].push_back({ arrayData.get(), offset }); + offset += arrayData->length; + } + } + } + } + + TBlockLess cmp(keyIndicies, *this, SortArrays_); + std::sort(SortPermutation_.begin(), SortPermutation_.end(), cmp); + } + + void FillSortOutputPart(const TVector<TBlockType*>& columns, NUdf::TUnboxedValue*const* output, TComputationContext& ctx) { + auto blockLen = Min(BuilderMaxLength_, OutputLength_ - Written_); + const bool isLast = (Written_ + blockLen == OutputLength_); + + for (ui32 i = 0; i < columns.size(); ++i) { + if (!output[i]) { + continue; + } + + if (columns[i]->GetShape() == TBlockType::EShape::Scalar) { + *output[i] = ScalarValues_[i]; + } else { + Builders_[i]->AddMany(SortArrays_[i].data(), SortArrays_[i].size(), SortPermutation_.data() + Written_, blockLen); + *output[i] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(Builders_[i]->Build(isLast))); + } + } + + *output[columns.size()] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(blockLen))); + + Written_ += blockLen; + } + void FillOutput(const TVector<TBlockType*>& columns, NUdf::TUnboxedValue*const* output, TComputationContext& ctx) { for (ui32 i = 0; i < columns.size(); ++i) { if (!output[i]) { @@ -383,12 +492,14 @@ private: IComputationNode* Count_; const TVector<IComputationNode*> Directions_; const TVector<ui32> KeyIndicies_; - const bool Sort_; TVector<TBlockType*> Columns_; }; -IComputationNode* WrapTop(TCallable& callable, const TComputationNodeFactoryContext& ctx, bool sort) { - MKQL_ENSURE(callable.GetInputsCount() > 2U && !(callable.GetInputsCount() % 2U), "Expected more arguments."); +template <bool Sort, bool HasCount> +IComputationNode* WrapTopOrSort(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + constexpr ui32 offset = HasCount ? 0 : 1; + const ui32 inputsWithCount = callable.GetInputsCount() + offset; + MKQL_ENSURE(inputsWithCount > 2U && !(inputsWithCount % 2U), "Expected more arguments."); const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); const auto tupleType = AS_TYPE(TTupleType, flowType->GetItemType()); @@ -397,32 +508,38 @@ IComputationNode* WrapTop(TCallable& callable, const TComputationNodeFactoryCont auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); - const auto count = LocateNode(ctx.NodeLocator, callable, 1); - const auto countType = AS_TYPE(TDataType, callable.GetInput(1).GetStaticType()); - MKQL_ENSURE(countType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64"); - + IComputationNode* count = nullptr; + if constexpr (HasCount) { + const auto countType = AS_TYPE(TDataType, callable.GetInput(1).GetStaticType()); + MKQL_ENSURE(countType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64"); + count = LocateNode(ctx.NodeLocator, callable, 1); + } + TVector<IComputationNode*> directions; TVector<ui32> keyIndicies; - for (ui32 i = 2; i < callable.GetInputsCount(); i += 2) { - ui32 keyIndex = AS_VALUE(TDataLiteral, callable.GetInput(i))->AsValue().Get<ui32>(); + for (ui32 i = 2; i < inputsWithCount; i += 2) { + ui32 keyIndex = AS_VALUE(TDataLiteral, callable.GetInput(i - offset))->AsValue().Get<ui32>(); MKQL_ENSURE(keyIndex + 1 < tupleType->GetElementsCount(), "Wrong key index"); keyIndicies.push_back(keyIndex); - directions.push_back(LocateNode(ctx.NodeLocator, callable, i + 1)); + directions.push_back(LocateNode(ctx.NodeLocator, callable, i + 1 - offset)); } - return new TTopBlocksWrapper(ctx.Mutables, wideFlow, tupleType, count, std::move(directions), std::move(keyIndicies), sort); + return new TTopOrSortBlocksWrapper<Sort, HasCount>(ctx.Mutables, wideFlow, tupleType, count, std::move(directions), std::move(keyIndicies)); } } //namespace IComputationNode* WrapWideTopBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - return WrapTop(callable, ctx, false); + return WrapTopOrSort<false, true>(callable, ctx); } IComputationNode* WrapWideTopSortBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - return WrapTop(callable, ctx, true); + return WrapTopOrSort<true, true>(callable, ctx); } +IComputationNode* WrapWideSortBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + return WrapTopOrSort<true, false>(callable, ctx); +} } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_top.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_top.h index 80b6cd70c5f..bb4f23b8394 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_top.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_top.h @@ -6,6 +6,7 @@ namespace NMiniKQL { IComputationNode* WrapWideTopBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx); IComputationNode* WrapWideTopSortBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx); +IComputationNode* WrapWideSortBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index 1662fed7b84..d336e46810b 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -280,6 +280,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"WideTakeBlocks", &WrapWideTakeBlocks}, {"WideTopBlocks", &WrapWideTopBlocks}, {"WideTopSortBlocks", &WrapWideTopSortBlocks}, + {"WideSortBlocks", &WrapWideSortBlocks}, {"AsScalar", &WrapAsScalar}, {"BlockCoalesce", &WrapBlockCoalesce}, {"BlockIf", &WrapBlockIf}, diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp index c1b98766358..ae59e906c70 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_wide_top_sort.cpp @@ -503,13 +503,15 @@ private: template<bool Sort, bool HasCount> IComputationNode* WrapWideTopT(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - const ui32 offset = HasCount ? 0 : 1; + constexpr ui32 offset = HasCount ? 0 : 1; const ui32 inputsWithCount = callable.GetInputsCount() + offset; MKQL_ENSURE(inputsWithCount > 2U && !(inputsWithCount % 2U), "Expected more arguments."); const auto flow = LocateNode(ctx.NodeLocator, callable, 0); IComputationNode* count = nullptr; - if (HasCount) { + if constexpr (HasCount) { + const auto countType = AS_TYPE(TDataType, callable.GetInput(1).GetStaticType()); + MKQL_ENSURE(countType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64"); count = LocateNode(ctx.NodeLocator, callable, 1); } |