diff options
author | atarasov5 <[email protected]> | 2025-07-09 12:58:08 +0300 |
---|---|---|
committer | atarasov5 <[email protected]> | 2025-07-09 13:23:47 +0300 |
commit | 7c10e44e4cedbf09d754e8ff05392a9d168af143 (patch) | |
tree | efcb9964c9ae273f31c7e4893b41421aaaffff0b | |
parent | 3611e9a59db095abcc485ca2609a38274bbec210 (diff) |
YQL-20080: flow -> stream rewrite
В этом PR переписал ноды `Wide{Top,TopSort,Sort}Blocks` с flow- на stream-реализацию.
Я разбил PR на два коммита: первый просто перемещает классы вверх-вниз по файлу, второй содержит сами изменения.
[Прогон тестов](https://nda.ya.ru/t/P9kfAmHr7GFmgy) с понижением Runtime версии
commit_hash:0813a74aaa904b12846692c0e7504334170ea6db
10 files changed, 638 insertions, 434 deletions
diff --git a/yql/essentials/core/peephole_opt/yql_opt_peephole_physical.cpp b/yql/essentials/core/peephole_opt/yql_opt_peephole_physical.cpp index 7005f48e9e3..b1382cbf5ab 100644 --- a/yql/essentials/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/yql/essentials/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -297,16 +297,12 @@ TExprNode::TPtr OptimizeBlockCompress(const TExprNode::TPtr& node, TExprContext& TExprNode::TPtr OptimizeBlocksTopOrSort(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& types) { Y_UNUSED(types); const auto& input = node->HeadPtr(); - if (input->IsCallable("ToFlow") && input->Head().IsCallable("ReplicateScalars")) { - const auto& replicateScalars = input->HeadPtr(); + if (input->IsCallable("ReplicateScalars")) { // Technically, the code below rewrites the following sequence - // (Wide{Top,TopSort,Sort}Blocks (ToFlow (ReplicateScalars (<input>)))) - // into (ToFlow (ReplicateScalars (FromFlow (Wide{Top,TopSort,Sort}Blocks (<input>))))), - // but ToFlow/FromFlow wrappers will be removed when all other - // nodes in block pipeline start using WideStream instead of the - // WideFlow. Hence, the logging is left intact. - YQL_CLOG(DEBUG, CorePeepHole) << "Swap " << node->Content() << " with " << replicateScalars->Content(); - return SwapFlowNodeWithStreamNode(node, replicateScalars, ctx); + // (Wide{Top,TopSort,Sort}Blocks (ReplicateScalars (<input>))) + // into (ReplicateScalars (Wide{Top,TopSort,Sort}Blocks (<input>))). 
+ YQL_CLOG(DEBUG, CorePeepHole) << "Swap " << node->Content() << " with " << input->Content(); + return ctx.SwapWithHead(*node); } return node; @@ -6933,24 +6929,25 @@ TExprNode::TPtr OptimizeTopOrSortBlocks(const TExprNode::TPtr& node, TExprContex TString newName = node->Content() + TString("Blocks"); YQL_CLOG(DEBUG, CorePeepHole) << "Convert " << node->Content() << " to " << newName; auto children = node->ChildrenList(); + // clang-format off children[0] = ctx.Builder(node->Pos()) - .Callable("ToFlow") - .Callable(0, "WideToBlocks") - .Callable(0, "FromFlow") - .Add(0, children[0]) - .Seal() + .Callable("WideToBlocks") + .Callable(0, "FromFlow") + .Add(0, children[0]) .Seal() .Seal() .Build(); + // clang-format on + + // clang-format off return ctx.Builder(node->Pos()) .Callable("ToFlow") .Callable(0, "WideFromBlocks") - .Callable(0, "FromFlow") - .Add(0, ctx.NewCallable(node->Pos(), newName, std::move(children))) - .Seal() + .Add(0, ctx.NewCallable(node->Pos(), newName, std::move(children))) .Seal() .Seal() .Build(); + // clang-format on } // TODO(YQL): Implement one more optimization for block types. 
diff --git a/yql/essentials/core/type_ann/type_ann_blocks.cpp b/yql/essentials/core/type_ann/type_ann_blocks.cpp index 9a1a7fdee1b..be824399d89 100644 --- a/yql/essentials/core/type_ann/type_ann_blocks.cpp +++ b/yql/essentials/core/type_ann/type_ann_blocks.cpp @@ -1029,7 +1029,7 @@ IGraphTransformer::TStatus WideTopBlocksWrapper(const TExprNode::TPtr& input, TE } TTypeAnnotationNode::TListType blockItemTypes; - if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -1060,7 +1060,7 @@ IGraphTransformer::TStatus WideSortBlocksWrapper(const TExprNode::TPtr& input, T } TTypeAnnotationNode::TListType blockItemTypes; - if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } diff --git a/yql/essentials/minikql/comp_nodes/mkql_block_top.cpp b/yql/essentials/minikql/comp_nodes/mkql_block_top.cpp index 7da4e7c4919..c9a1e67cf7e 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_block_top.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_block_top.cpp @@ -22,12 +22,390 @@ namespace NMiniKQL { namespace { -template <bool Sort, bool HasCount> -class TTopOrSortBlocksWrapper : public TStatefulWideFlowCodegeneratorNode<TTopOrSortBlocksWrapper<Sort, HasCount>> { -using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TTopOrSortBlocksWrapper<Sort, HasCount>>; using TChunkedArrayIndex = std::vector<IArrayBuilder::TArrayDataItem>; + +TChunkedArrayIndex MakeChunkedArrayIndex(const arrow::Datum& datum) { + TChunkedArrayIndex result; + if (datum.is_array()) { + result.push_back({datum.array().get(), 0}); + } else { + auto chunks = datum.chunks(); + ui64 offset = 0; + for (auto& chunk : chunks) { + auto arrayData = chunk->data(); + result.push_back({arrayData.get(), offset}); + offset += arrayData->length; + } 
+ } + return result; +} + +template <bool Sort, bool HasCount> +class TTopOrSortBlocksState: public TBlockState { public: - TTopOrSortBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TArrayRef<TType* const> wideComponents, IComputationNode* count, + bool WritingOutput_ = false; + bool IsFinished_ = false; + + ui64 OutputLength_ = 0; + ui64 Written_ = 0; + const std::vector<bool> Directions_; + const ui64 Count_; + const std::vector<TBlockType*> Columns_; + const std::vector<ui32> KeyIndicies_; + std::vector<std::vector<arrow::Datum>> SortInput_; + std::vector<ui64> SortPermutation_; + std::vector<TChunkedArrayIndex> SortArrays_; + + bool ScalarsFilled_ = false; + TUnboxedValueVector ScalarValues_; + std::vector<std::unique_ptr<IBlockReader>> LeftReaders_; + std::vector<std::unique_ptr<IBlockReader>> RightReaders_; + std::vector<std::unique_ptr<IArrayBuilder>> Builders_; + ui64 BuilderMaxLength_ = 0; + ui64 BuilderLength_ = 0; + std::vector<NUdf::IBlockItemComparator::TPtr> Comparators_; // by key columns only + + class TBlockLess { + public: + TBlockLess(const std::vector<ui32>& keyIndicies, const TTopOrSortBlocksState<Sort, HasCount>& state, const std::vector<TChunkedArrayIndex>& arrayIndicies) + : KeyIndicies_(keyIndicies) + , ArrayIndicies_(arrayIndicies) + , State_(state) + { + } + + bool operator()(ui64 lhs, ui64 rhs) const { + if (KeyIndicies_.size() == 1) { + auto i = KeyIndicies_[0]; + auto& arrayIndex = ArrayIndicies_[i]; + if (arrayIndex.empty()) { + // scalar + return false; + } + + auto leftItem = GetBlockItem(*State_.LeftReaders_[i], arrayIndex, lhs); + auto rightItem = GetBlockItem(*State_.RightReaders_[i], arrayIndex, rhs); + if (State_.Directions_[0]) { + return State_.Comparators_[0]->Less(leftItem, rightItem); + } else { + return State_.Comparators_[0]->Greater(leftItem, rightItem); + } + } else { + for (ui32 k = 0; k < KeyIndicies_.size(); ++k) { + auto i = KeyIndicies_[k]; + auto& arrayIndex = ArrayIndicies_[i]; + if 
(arrayIndex.empty()) { + // scalar + continue; + } + + auto leftItem = GetBlockItem(*State_.LeftReaders_[i], arrayIndex, lhs); + auto rightItem = GetBlockItem(*State_.RightReaders_[i], arrayIndex, rhs); + auto cmp = State_.Comparators_[k]->Compare(leftItem, rightItem); + if (cmp == 0) { + continue; + } + + if (State_.Directions_[k]) { + return cmp < 0; + } else { + return cmp > 0; + } + } + + return false; + } + } + + private: + static TBlockItem GetBlockItem(IBlockReader& reader, const TChunkedArrayIndex& arrayIndex, ui64 idx) { + Y_DEBUG_ABORT_UNLESS(!arrayIndex.empty()); + if (arrayIndex.size() == 1) { + return reader.GetItem(*arrayIndex.front().Data, idx); + } + + auto it = LookupArrayDataItem(arrayIndex.data(), arrayIndex.size(), idx); + return reader.GetItem(*it->Data, idx); + } + + const std::vector<ui32>& KeyIndicies_; + const std::vector<TChunkedArrayIndex> ArrayIndicies_; + const TTopOrSortBlocksState<Sort, HasCount>& State_; + }; + + TTopOrSortBlocksState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const std::vector<ui32>& keyIndicies, const std::vector<TBlockType*>& columns, const bool* directions, ui64 count) + : TBlockState(memInfo, columns.size() + 1U) + , IsFinished_(HasCount && !count) + , Directions_(directions, directions + keyIndicies.size()) + , Count_(count) + , Columns_(columns) + , KeyIndicies_(keyIndicies) + , SortInput_(Columns_.size()) + , SortArrays_(Columns_.size()) + , ScalarValues_(Columns_.size()) + , LeftReaders_(Columns_.size()) + , RightReaders_(Columns_.size()) + , Builders_(Columns_.size()) + , Comparators_(KeyIndicies_.size()) + { + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { + continue; + } + + LeftReaders_[i] = MakeBlockReader(TTypeInfoHelper(), columns[i]->GetItemType()); + RightReaders_[i] = MakeBlockReader(TTypeInfoHelper(), columns[i]->GetItemType()); + } + + for (ui32 k = 0; k < KeyIndicies_.size(); ++k) { + Comparators_[k] = 
TBlockTypeHelper().MakeComparator(Columns_[KeyIndicies_[k]]->GetItemType()); + } + + BuilderMaxLength_ = GetStorageLength(); + size_t maxBlockItemSize = 0; + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { + continue; + } + + maxBlockItemSize = Max(maxBlockItemSize, CalcMaxBlockItemSize(Columns_[i]->GetItemType())); + }; + + BuilderMaxLength_ = Max(BuilderMaxLength_, CalcBlockLen(maxBlockItemSize)); + + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { + continue; + } + + Builders_[i] = MakeArrayBuilder(TTypeInfoHelper(), Columns_[i]->GetItemType(), ctx.ArrowMemoryPool, BuilderMaxLength_, &ctx.Builder->GetPgBuilder()); + } + } + + void Add(const NUdf::TUnboxedValuePod value, size_t idx) { + Values[idx] = value; + } + + void ProcessInput() { + const ui64 blockLen = TArrowBlock::From(Values.back()).GetDatum().template scalar_as<arrow::UInt64Scalar>().value; + + if (!ScalarsFilled_) { + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { + ScalarValues_[i] = std::move(Values[i]); + } + } + + ScalarsFilled_ = true; + } + + if constexpr (!HasCount) { + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) { + auto datum = TArrowBlock::From(Values[i]).GetDatum(); + SortInput_[i].emplace_back(datum); + } + } + + OutputLength_ += blockLen; + Values.assign(Values.size(), NUdf::TUnboxedValuePod()); + return; + } + + // shrink input block + std::optional<std::vector<ui64>> blockIndicies; + if (blockLen > Count_) { + blockIndicies.emplace(); + blockIndicies->reserve(blockLen); + for (ui64 row = 0; row < blockLen; ++row) { + blockIndicies->emplace_back(row); + } + + std::vector<TChunkedArrayIndex> arrayIndicies(Columns_.size()); + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) { + auto datum = 
TArrowBlock::From(Values[i]).GetDatum(); + arrayIndicies[i] = MakeChunkedArrayIndex(datum); + } + } + + const TBlockLess cmp(KeyIndicies_, *this, arrayIndicies); + NYql::FastNthElement(blockIndicies->begin(), blockIndicies->begin() + Count_, blockIndicies->end(), cmp); + } + + // copy all to builders + AddTop(Columns_, blockIndicies, blockLen); + if (BuilderLength_ + Count_ > BuilderMaxLength_) { + CompressBuilders(false); + } + + Values.assign(Values.size(), NUdf::TUnboxedValuePod()); + } + + ui64 GetStorageLength() const { + return 2 * Count_; + } + + void CompressBuilders(bool sort) { + Y_ABORT_UNLESS(ScalarsFilled_); + std::vector<TChunkedArrayIndex> arrayIndicies(Columns_.size()); + std::vector<arrow::Datum> tmpDatums(Columns_.size()); + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) { + auto datum = Builders_[i]->Build(false); + arrayIndicies[i] = MakeChunkedArrayIndex(datum); + tmpDatums[i] = std::move(datum); + } + } + + std::vector<ui64> blockIndicies; + blockIndicies.reserve(BuilderLength_); + for (ui64 row = 0; row < BuilderLength_; ++row) { + blockIndicies.push_back(row); + } + + const ui64 blockLen = Min(BuilderLength_, Count_); + const TBlockLess cmp(KeyIndicies_, *this, arrayIndicies); + if (BuilderLength_ <= Count_) { + if (sort) { + std::sort(blockIndicies.begin(), blockIndicies.end(), cmp); + } + } else { + if (sort) { + NYql::FastPartialSort(blockIndicies.begin(), blockIndicies.begin() + blockLen, blockIndicies.end(), cmp); + } else { + NYql::FastNthElement(blockIndicies.begin(), blockIndicies.begin() + blockLen, blockIndicies.end(), cmp); + } + } + + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { + continue; + } + + auto& arrayIndex = arrayIndicies[i]; + Builders_[i]->AddMany(arrayIndex.data(), arrayIndex.size(), blockIndicies.data(), blockLen); + } + + BuilderLength_ = blockLen; + } + + void SortAll() { + 
SortPermutation_.reserve(OutputLength_); + for (ui64 i = 0; i < OutputLength_; ++i) { + SortPermutation_.emplace_back(i); + } + + for (ui32 i = 0; i < Columns_.size(); ++i) { + ui64 offset = 0; + for (const auto& datum : SortInput_[i]) { + if (datum.is_scalar()) { + continue; + } else if (datum.is_array()) { + auto arrayData = datum.array(); + SortArrays_[i].push_back({arrayData.get(), offset}); + offset += arrayData->length; + } else { + auto chunks = datum.chunks(); + for (auto& chunk : chunks) { + auto arrayData = chunk->data(); + SortArrays_[i].push_back({arrayData.get(), offset}); + offset += arrayData->length; + } + } + } + } + + TBlockLess cmp(KeyIndicies_, *this, SortArrays_); + std::sort(SortPermutation_.begin(), SortPermutation_.end(), cmp); + } + + bool FillOutput(const THolderFactory& holderFactory) { + if (WritingOutput_) { + FillSortOutputPart(holderFactory); + } else if constexpr (!HasCount) { + if (!OutputLength_) { + IsFinished_ = true; + return false; + } + + SortAll(); + WritingOutput_ = true; + FillSortOutputPart(holderFactory); + } else { + IsFinished_ = true; + if (!BuilderLength_) { + return false; + } + + if (BuilderLength_ > Count_ || Sort) { + CompressBuilders(Sort); + } + + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { + Values[i] = ScalarValues_[i]; + } else { + Values[i] = holderFactory.CreateArrowBlock(arrow::Datum(Builders_[i]->Build(true))); + } + } + + Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(BuilderLength_))); + } + FillArrays(); + return true; + } + + void FillSortOutputPart(const THolderFactory& holderFactory) { + auto blockLen = Min(BuilderMaxLength_, OutputLength_ - Written_); + const bool isLast = (Written_ + blockLen == OutputLength_); + + for (ui32 i = 0; i < Columns_.size(); ++i) { + if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { + Values[i] = ScalarValues_[i]; + } else { + 
Builders_[i]->AddMany(SortArrays_[i].data(), SortArrays_[i].size(), SortPermutation_.data() + Written_, blockLen); + Values[i] = holderFactory.CreateArrowBlock(arrow::Datum(Builders_[i]->Build(isLast))); + } + } + + Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(blockLen))); + Written_ += blockLen; + if (Written_ >= OutputLength_) { + IsFinished_ = true; + } + } + + void AddTop(const std::vector<TBlockType*>& columns, const std::optional<std::vector<ui64>>& blockIndicies, ui64 blockLen) { + for (ui32 i = 0; i < columns.size(); ++i) { + if (columns[i]->GetShape() == TBlockType::EShape::Scalar) { + continue; + } + + const auto& datum = TArrowBlock::From(Values[i]).GetDatum(); + auto arrayIndex = MakeChunkedArrayIndex(datum); + if (blockIndicies) { + Builders_[i]->AddMany(arrayIndex.data(), arrayIndex.size(), blockIndicies->data(), Count_); + } else { + Builders_[i]->AddMany(arrayIndex.data(), arrayIndex.size(), ui64(0), blockLen); + } + } + + if (blockIndicies) { + BuilderLength_ += Count_; + } else { + BuilderLength_ += blockLen; + } + } +}; + +template <bool Sort, bool HasCount> +class TTopOrSortBlocksFlowWrapper : public TStatefulWideFlowCodegeneratorNode<TTopOrSortBlocksFlowWrapper<Sort, HasCount>> { + using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TTopOrSortBlocksFlowWrapper<Sort, HasCount>>; + using TState = TTopOrSortBlocksState<Sort, HasCount>; + +public: + TTopOrSortBlocksFlowWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TArrayRef<TType* const> wideComponents, IComputationNode* count, TComputationNodePtrVector&& directions, std::vector<ui32>&& keyIndicies) : TBaseComputation(mutables, flow, EValueRepresentation::Boxed) , Flow_(flow) @@ -95,7 +473,7 @@ public: const auto atTop = &ctx.Func->getEntryBlock().back(); - const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TState::Get>()); + const auto getFunc = 
ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TTopOrSortBlocksState<Sort, HasCount>::Get>()); const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false); const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop); @@ -137,7 +515,7 @@ public: const auto ptrType = PointerType::getUnqual(StructType::get(context)); const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); - const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TTopOrSortBlocksWrapper::MakeState>()); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TTopOrSortBlocksFlowWrapper::MakeState>()); const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType(), dirs->getType(), trunc->getType()}, false); const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr, dirs, trunc}, "", block); @@ -195,7 +573,7 @@ public: } new StoreInst(array, values, block); - const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TState::ProcessInput>()); + const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TTopOrSortBlocksState<Sort, HasCount>::ProcessInput>()); const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false); const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block); CallInst::Create(processBlockType, processBlockPtr, {stateArg}, "", block); @@ -204,7 +582,7 @@ public: block = work; - const auto fillBlockFunc = 
ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TState::FillOutput>()); + const auto fillBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TTopOrSortBlocksState<Sort, HasCount>::FillOutput>()); const auto fillBlockType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false); const auto fillBlockPtr = CastInst::Create(Instruction::IntToPtr, fillBlockFunc, PointerType::getUnqual(fillBlockType), "fill_output_func", block); const auto hasData = CallInst::Create(fillBlockType, fillBlockPtr, {stateArg, ctx.GetFactory()}, "fill_output", block); @@ -215,7 +593,7 @@ public: block = fill; - const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TState::Slice>()); + const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr<&TTopOrSortBlocksState<Sort, HasCount>::Slice>()); const auto sliceType = FunctionType::get(indexType, {statePtrType}, false); const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block); const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block); @@ -252,293 +630,6 @@ private: } } - class TState : public TBlockState { - public: - bool WritingOutput_ = false; - bool IsFinished_ = false; - - ui64 OutputLength_ = 0; - ui64 Written_ = 0; - const std::vector<bool> Directions_; - const ui64 Count_; - const std::vector<TBlockType*> Columns_; - const std::vector<ui32> KeyIndicies_; - std::vector<std::vector<arrow::Datum>> SortInput_; - std::vector<ui64> SortPermutation_; - std::vector<TChunkedArrayIndex> SortArrays_; - - bool ScalarsFilled_ = false; - TUnboxedValueVector ScalarValues_; - std::vector<std::unique_ptr<IBlockReader>> LeftReaders_; - std::vector<std::unique_ptr<IBlockReader>> RightReaders_; - std::vector<std::unique_ptr<IArrayBuilder>> Builders_; - ui64 BuilderMaxLength_ = 0; - ui64 BuilderLength_ = 0; - std::vector<NUdf::IBlockItemComparator::TPtr> 
Comparators_; // by key columns only - - TState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const std::vector<ui32>& keyIndicies, const std::vector<TBlockType*>& columns, const bool* directions, ui64 count) - : TBlockState(memInfo, columns.size() + 1U) - , IsFinished_(HasCount && !count) - , Directions_(directions, directions + keyIndicies.size()) - , Count_(count) - , Columns_(columns) - , KeyIndicies_(keyIndicies) - , SortInput_(Columns_.size()) - , SortArrays_(Columns_.size()) - , LeftReaders_(Columns_.size()) - , RightReaders_(Columns_.size()) - , Builders_(Columns_.size()) - , Comparators_(KeyIndicies_.size()) - { - for (ui32 i = 0; i < Columns_.size(); ++i) { - if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { - continue; - } - - LeftReaders_[i] = MakeBlockReader(TTypeInfoHelper(), columns[i]->GetItemType()); - RightReaders_[i] = MakeBlockReader(TTypeInfoHelper(), columns[i]->GetItemType()); - } - - for (ui32 k = 0; k < KeyIndicies_.size(); ++k) { - Comparators_[k] = TBlockTypeHelper().MakeComparator(Columns_[KeyIndicies_[k]]->GetItemType()); - } - - BuilderMaxLength_ = GetStorageLength(); - size_t maxBlockItemSize = 0; - for (ui32 i = 0; i < Columns_.size(); ++i) { - if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { - continue; - } - - maxBlockItemSize = Max(maxBlockItemSize, CalcMaxBlockItemSize(Columns_[i]->GetItemType())); - }; - - BuilderMaxLength_ = Max(BuilderMaxLength_, CalcBlockLen(maxBlockItemSize)); - - for (ui32 i = 0; i < Columns_.size(); ++i) { - if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { - continue; - } - - Builders_[i] = MakeArrayBuilder(TTypeInfoHelper(), Columns_[i]->GetItemType(), ctx.ArrowMemoryPool, BuilderMaxLength_, &ctx.Builder->GetPgBuilder()); - } - } - - void Add(const NUdf::TUnboxedValuePod value, size_t idx) { - Values[idx] = value; - } - - void ProcessInput() { - const ui64 blockLen = TArrowBlock::From(Values.back()).GetDatum().template scalar_as<arrow::UInt64Scalar>().value; - - 
if (!ScalarsFilled_) { - for (ui32 i = 0; i < Columns_.size(); ++i) { - if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { - ScalarValues_[i] = std::move(Values[i]); - } - } - - ScalarsFilled_ = true; - } - - if constexpr (!HasCount) { - for (ui32 i = 0; i < Columns_.size(); ++i) { - auto datum = TArrowBlock::From(Values[i]).GetDatum(); - if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) { - SortInput_[i].emplace_back(datum); - } - } - - OutputLength_ += blockLen; - Values.assign(Values.size(), NUdf::TUnboxedValuePod()); - return; - } - - // shrink input block - std::optional<std::vector<ui64>> blockIndicies; - if (blockLen > Count_) { - blockIndicies.emplace(); - blockIndicies->reserve(blockLen); - for (ui64 row = 0; row < blockLen; ++row) { - blockIndicies->emplace_back(row); - } - - std::vector<TChunkedArrayIndex> arrayIndicies(Columns_.size()); - for (ui32 i = 0; i < Columns_.size(); ++i) { - if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) { - auto datum = TArrowBlock::From(Values[i]).GetDatum(); - arrayIndicies[i] = MakeChunkedArrayIndex(datum); - } - } - - const TBlockLess cmp(KeyIndicies_, *this, arrayIndicies); - NYql::FastNthElement(blockIndicies->begin(), blockIndicies->begin() + Count_, blockIndicies->end(), cmp); - } - - // copy all to builders - AddTop(Columns_, blockIndicies, blockLen); - if (BuilderLength_ + Count_ > BuilderMaxLength_) { - CompressBuilders(false); - } - - Values.assign(Values.size(), NUdf::TUnboxedValuePod()); - } - - ui64 GetStorageLength() const { - return 2 * Count_; - } - - void CompressBuilders(bool sort) { - Y_ABORT_UNLESS(ScalarsFilled_); - std::vector<TChunkedArrayIndex> arrayIndicies(Columns_.size()); - std::vector<arrow::Datum> tmpDatums(Columns_.size()); - for (ui32 i = 0; i < Columns_.size(); ++i) { - if (Columns_[i]->GetShape() != TBlockType::EShape::Scalar) { - auto datum = Builders_[i]->Build(false); - arrayIndicies[i] = MakeChunkedArrayIndex(datum); - tmpDatums[i] = std::move(datum); - } 
- } - - std::vector<ui64> blockIndicies; - blockIndicies.reserve(BuilderLength_); - for (ui64 row = 0; row < BuilderLength_; ++row) { - blockIndicies.push_back(row); - } - - const ui64 blockLen = Min(BuilderLength_, Count_); - const TBlockLess cmp(KeyIndicies_, *this, arrayIndicies); - if (BuilderLength_ <= Count_) { - if (sort) { - std::sort(blockIndicies.begin(), blockIndicies.end(), cmp); - } - } else { - if (sort) { - NYql::FastPartialSort(blockIndicies.begin(), blockIndicies.begin() + blockLen, blockIndicies.end(), cmp); - } else { - NYql::FastNthElement(blockIndicies.begin(), blockIndicies.begin() + blockLen, blockIndicies.end(), cmp); - } - } - - for (ui32 i = 0; i < Columns_.size(); ++i) { - if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { - continue; - } - - auto& arrayIndex = arrayIndicies[i]; - Builders_[i]->AddMany(arrayIndex.data(), arrayIndex.size(), blockIndicies.data(), blockLen); - } - - BuilderLength_ = blockLen; - } - - void SortAll() { - SortPermutation_.reserve(OutputLength_); - for (ui64 i = 0; i < OutputLength_; ++i) { - SortPermutation_.emplace_back(i); - } - - for (ui32 i = 0; i < Columns_.size(); ++i) { - ui64 offset = 0; - for (const auto& datum : SortInput_[i]) { - if (datum.is_scalar()) { - continue; - } else if (datum.is_array()) { - auto arrayData = datum.array(); - SortArrays_[i].push_back({ arrayData.get(), offset }); - offset += arrayData->length; - } else { - auto chunks = datum.chunks(); - for (auto& chunk : chunks) { - auto arrayData = chunk->data(); - SortArrays_[i].push_back({ arrayData.get(), offset }); - offset += arrayData->length; - } - } - } - } - - TBlockLess cmp(KeyIndicies_, *this, SortArrays_); - std::sort(SortPermutation_.begin(), SortPermutation_.end(), cmp); - } - - bool FillOutput(const THolderFactory& holderFactory) { - if (WritingOutput_) { - FillSortOutputPart(holderFactory); - } else if constexpr (!HasCount) { - if (!OutputLength_) { - IsFinished_ = true; - return false; - } - - SortAll(); - 
WritingOutput_ = true; - FillSortOutputPart(holderFactory); - } else { - IsFinished_ = true; - if (!BuilderLength_) { - return false; - } - - if (BuilderLength_ > Count_ || Sort) { - CompressBuilders(Sort); - } - - for (ui32 i = 0; i < Columns_.size(); ++i) { - if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { - Values[i] = ScalarValues_[i]; - } else { - Values[i] = holderFactory.CreateArrowBlock(arrow::Datum(Builders_[i]->Build(true))); - } - } - - Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(BuilderLength_))); - } - FillArrays(); - return true; - } - - void FillSortOutputPart(const THolderFactory& holderFactory) { - auto blockLen = Min(BuilderMaxLength_, OutputLength_ - Written_); - const bool isLast = (Written_ + blockLen == OutputLength_); - - for (ui32 i = 0; i < Columns_.size(); ++i) { - if (Columns_[i]->GetShape() == TBlockType::EShape::Scalar) { - Values[i] = ScalarValues_[i]; - } else { - Builders_[i]->AddMany(SortArrays_[i].data(), SortArrays_[i].size(), SortPermutation_.data() + Written_, blockLen); - Values[i] = holderFactory.CreateArrowBlock(arrow::Datum(Builders_[i]->Build(isLast))); - } - } - - Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(blockLen))); - Written_ += blockLen; - if (Written_ >= OutputLength_) - IsFinished_ = true; - } - - void AddTop(const std::vector<TBlockType*>& columns, const std::optional<std::vector<ui64>>& blockIndicies, ui64 blockLen) { - for (ui32 i = 0; i < columns.size(); ++i) { - if (columns[i]->GetShape() == TBlockType::EShape::Scalar) { - continue; - } - - const auto& datum = TArrowBlock::From(Values[i]).GetDatum(); - auto arrayIndex = MakeChunkedArrayIndex(datum); - if (blockIndicies) { - Builders_[i]->AddMany(arrayIndex.data(), arrayIndex.size(), blockIndicies->data(), Count_); - } else { - Builders_[i]->AddMany(arrayIndex.data(), arrayIndex.size(), ui64(0), blockLen); - } - } - - if (blockIndicies) { - 
BuilderLength_ += Count_; - } else { - BuilderLength_ += blockLen; - } - } - }; #ifndef MKQL_DISABLE_CODEGEN class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState { private: @@ -574,7 +665,7 @@ private: state = ctx.HolderFactory.Create<TState>(ctx, KeyIndicies_, Columns_, directions, count); } - TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { + TTopOrSortBlocksState<Sort, HasCount>& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { if (state.IsInvalid()) { std::vector<bool> dirs(Directions_.size()); std::transform(Directions_.cbegin(), Directions_.cend(), dirs.begin(), [&ctx](IComputationNode* dir){ return dir->GetValue(ctx).Get<bool>(); }); @@ -595,90 +686,125 @@ private: return *static_cast<TState*>(state.AsBoxed().Get()); } - static TChunkedArrayIndex MakeChunkedArrayIndex(const arrow::Datum& datum) { - TChunkedArrayIndex result; - if (datum.is_array()) { - result.push_back({datum.array().get(), 0}); + IComputationWideFlowNode *const Flow_; + IComputationNode *const Count_; + const TComputationNodePtrVector Directions_; + const std::vector<ui32> KeyIndicies_; + std::vector<TBlockType*> Columns_; + const size_t WideFieldsIndex_; +}; + +template <bool Sort, bool HasCount> +class TTopOrSortBlocksStreamWrapper: public TMutableComputationNode<TTopOrSortBlocksStreamWrapper<Sort, HasCount>> { + using TBaseComputation = TMutableComputationNode<TTopOrSortBlocksStreamWrapper>; + using TState = TTopOrSortBlocksState<Sort, HasCount>; + +public: + TTopOrSortBlocksStreamWrapper(TComputationMutables& mutables, + IComputationNode* stream, + TArrayRef<TType* const> wideComponents, + IComputationNode* count, + TComputationNodePtrVector&& directions, + std::vector<ui32>&& keyIndicies) + : TBaseComputation(mutables, EValueRepresentation::Boxed) + , Stream_(stream) + , Count_(count) + , Directions_(std::move(directions)) + , KeyIndicies_(std::move(keyIndicies)) + , 
WideFieldsIndex_(mutables.IncrementWideFieldsIndex(wideComponents.size())) + { + for (ui32 i = 0; i < wideComponents.size() - 1; ++i) { + Columns_.push_back(AS_TYPE(TBlockType, wideComponents[i])); + } + } + + NUdf::TUnboxedValue MakeState(TComputationContext& ctx) const { + std::vector<bool> dirs(Directions_.size()); + std::transform(Directions_.cbegin(), Directions_.cend(), dirs.begin(), [&ctx](IComputationNode* dir) { return dir->GetValue(ctx).Get<bool>(); }); + if constexpr (HasCount) { + return ctx.HolderFactory.Create<TState>(ctx, KeyIndicies_, Columns_, dirs.data(), Count_->GetValue(ctx).Get<ui64>()); } else { - auto chunks = datum.chunks(); - ui64 offset = 0; - for (auto& chunk : chunks) { - auto arrayData = chunk->data(); - result.push_back({arrayData.get(), offset}); - offset += arrayData->length; - } + return ctx.HolderFactory.Create<TState>(ctx, KeyIndicies_, Columns_, dirs.data(), 0); } - return result; } - class TBlockLess { - public: - TBlockLess(const std::vector<ui32>& keyIndicies, const TState& state, const std::vector<TChunkedArrayIndex>& arrayIndicies) - : KeyIndicies_(keyIndicies) - , ArrayIndicies_(arrayIndicies) - , State_(state) - {} + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + auto state = MakeState(ctx); + return ctx.HolderFactory.Create<TStreamValue>(ctx.HolderFactory, + std::move(state), + std::move(Stream_->GetValue(ctx))); + } - bool operator()(ui64 lhs, ui64 rhs) const { - if (KeyIndicies_.size() == 1) { - auto i = KeyIndicies_[0]; - auto& arrayIndex = ArrayIndicies_[i]; - if (arrayIndex.empty()) { - // scalar - return false; - } +private: + class TStreamValue: public TComputationValue<TStreamValue> { + using TBase = TComputationValue<TStreamValue>; - auto leftItem = GetBlockItem(*State_.LeftReaders_[i], arrayIndex, lhs); - auto rightItem = GetBlockItem(*State_.RightReaders_[i], arrayIndex, rhs); - if (State_.Directions_[0]) { - return State_.Comparators_[0]->Less(leftItem, rightItem); - } else { - return 
State_.Comparators_[0]->Greater(leftItem, rightItem); - } - } else { - for (ui32 k = 0; k < KeyIndicies_.size(); ++k) { - auto i = KeyIndicies_[k]; - auto& arrayIndex = ArrayIndicies_[i]; - if (arrayIndex.empty()) { - // scalar - continue; - } + public: + TStreamValue(TMemoryUsageInfo* memInfo, + const THolderFactory& holderFactory, + NUdf::TUnboxedValue&& blockState, + NUdf::TUnboxedValue&& stream) + : TBase(memInfo) + , BlockState_(std::move(blockState)) + , Stream_(std::move(stream)) + , HolderFactory_(holderFactory) + { + } - auto leftItem = GetBlockItem(*State_.LeftReaders_[i], arrayIndex, lhs); - auto rightItem = GetBlockItem(*State_.RightReaders_[i], arrayIndex, rhs); - auto cmp = State_.Comparators_[k]->Compare(leftItem, rightItem); - if (cmp == 0) { - continue; - } + private: + NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) { + auto& blockState = *static_cast<TState*>(BlockState_.AsBoxed().Get()); + Y_DEBUG_ABORT_UNLESS(blockState.Values.size() == width); + Y_DEBUG_ABORT_UNLESS(blockState.Values.size() == blockState.Columns_.size() + 1); + auto* inputFields = blockState.Pointer; + + if (!blockState.Count) { + if (blockState.IsFinished_) { + return NUdf::EFetchStatus::Finish; + } - if (State_.Directions_[k]) { - return cmp < 0; - } else { - return cmp > 0; + if (!blockState.WritingOutput_) { + while (true) { + switch (Stream_.WideFetch(inputFields, width)) { + case NUdf::EFetchStatus::Yield: + return NUdf::EFetchStatus::Yield; + case NUdf::EFetchStatus::Ok: + blockState.ProcessInput(); + continue; + case NUdf::EFetchStatus::Finish: + break; + } + break; } } - return false; - } - } - private: - static TBlockItem GetBlockItem(IBlockReader& reader, const TChunkedArrayIndex& arrayIndex, ui64 idx) { - Y_DEBUG_ABORT_UNLESS(!arrayIndex.empty()); - if (arrayIndex.size() == 1) { - return reader.GetItem(*arrayIndex.front().Data, idx); + if (!blockState.FillOutput(HolderFactory_)) { + return NUdf::EFetchStatus::Finish; + } } - auto it = 
LookupArrayDataItem(arrayIndex.data(), arrayIndex.size(), idx); - return reader.GetItem(*it->Data, idx); + const auto sliceSize = blockState.Slice(); + for (size_t i = 0; i < width; ++i) { + output[i] = blockState.Get(sliceSize, HolderFactory_, i); + } + return NUdf::EFetchStatus::Ok; } - const std::vector<ui32>& KeyIndicies_; - const std::vector<TChunkedArrayIndex> ArrayIndicies_; - const TState& State_; + NUdf::TUnboxedValue BlockState_; + NUdf::TUnboxedValue Stream_; + const THolderFactory& HolderFactory_; }; - IComputationWideFlowNode *const Flow_; - IComputationNode *const Count_; + void RegisterDependencies() const final { + this->DependsOn(Stream_); + this->DependsOn(Count_); + for (auto dir : Directions_) { + this->DependsOn(dir); + } + } + + IComputationNode* const Stream_; + IComputationNode* const Count_; const TComputationNodePtrVector Directions_; const std::vector<ui32> KeyIndicies_; std::vector<TBlockType*> Columns_; @@ -690,13 +816,14 @@ IComputationNode* WrapTopOrSort(TCallable& callable, const TComputationNodeFacto constexpr ui32 offset = HasCount ? 
0 : 1; const ui32 inputsWithCount = callable.GetInputsCount() + offset; MKQL_ENSURE(inputsWithCount > 2U && !(inputsWithCount % 2U), "Expected more arguments."); + const TType* const inputType = callable.GetInput(0).GetStaticType(); + + MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either WideStream or WideFlow as an input"); - const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); - const auto wideComponents = GetWideComponents(flowType); + const auto wideComponents = GetWideComponents(inputType); MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column"); - const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); - MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + auto node = LocateNode(ctx.NodeLocator, callable, 0); IComputationNode* count = nullptr; if constexpr (HasCount) { @@ -714,7 +841,12 @@ IComputationNode* WrapTopOrSort(TCallable& callable, const TComputationNodeFacto directions.push_back(LocateNode(ctx.NodeLocator, callable, i + 1 - offset)); } - return new TTopOrSortBlocksWrapper<Sort, HasCount>(ctx.Mutables, wideFlow, wideComponents, count, std::move(directions), std::move(keyIndicies)); + const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(node); + if (!wideFlow) { + MKQL_ENSURE(inputType->IsStream(), "Expecting stream as input type."); + return new TTopOrSortBlocksStreamWrapper<Sort, HasCount>(ctx.Mutables, node, wideComponents, count, std::move(directions), std::move(keyIndicies)); + } + return new TTopOrSortBlocksFlowWrapper<Sort, HasCount>(ctx.Mutables, wideFlow, wideComponents, count, std::move(directions), std::move(keyIndicies)); } } //namespace diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_block_top_sort_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_block_top_sort_ut.cpp index 5a0f737ec6b..43edba1b9b1 100644 --- a/yql/essentials/minikql/comp_nodes/ut/mkql_block_top_sort_ut.cpp +++ 
b/yql/essentials/minikql/comp_nodes/ut/mkql_block_top_sort_ut.cpp @@ -49,10 +49,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLBlockTopTest) { const auto wideFlow = pb.ExpandMap(pb.ToFlow(list), [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }); - const auto blockFlow = pb.ToFlow(pb.WideToBlocks(pb.FromFlow(wideFlow))); - const auto topBlocks = pb.WideTopBlocks(blockFlow, + const auto blockStream = pb.WideToBlocks(pb.FromFlow(wideFlow)); + const auto topBlocks = pb.WideTopBlocks(blockStream, pb.NewDataLiteral<ui64>(4ULL), {{0U, pb.NewDataLiteral<bool>(true)}}); - const auto topFlow = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(topBlocks))); + const auto topFlow = pb.ToFlow(pb.WideFromBlocks(topBlocks)); const auto pgmReturn = pb.Collect(pb.NarrowMap(topFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } )); @@ -116,10 +116,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLBlockTopTest) { const auto wideFlow = pb.ExpandMap(pb.ToFlow(list), [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }); - const auto blockFlow = pb.ToFlow(pb.WideToBlocks(pb.FromFlow(wideFlow))); - const auto topBlocks = pb.WideTopBlocks(blockFlow, + const auto blockStream = pb.WideToBlocks(pb.FromFlow(wideFlow)); + const auto topBlocks = pb.WideTopBlocks(blockStream, pb.NewDataLiteral<ui64>(6ULL), {{0U, pb.NewDataLiteral<bool>(false)}}); - const auto topFlow = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(topBlocks))); + const auto topFlow = pb.ToFlow(pb.WideFromBlocks(topBlocks)); const auto pgmReturn = pb.Collect(pb.NarrowMap(topFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } )); @@ -189,10 +189,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLBlockTopTest) { const auto wideFlow = pb.ExpandMap(pb.ToFlow(list), [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }); - const auto blockFlow = 
pb.ToFlow(pb.WideToBlocks(pb.FromFlow(wideFlow))); - const auto topBlocks = pb.WideTopBlocks(blockFlow, + const auto blockStream = pb.WideToBlocks(pb.FromFlow(wideFlow)); + const auto topBlocks = pb.WideTopBlocks(blockStream, pb.NewDataLiteral<ui64>(3ULL), {{1U, pb.NewDataLiteral<bool>(true)}}); - const auto topFlow = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(topBlocks))); + const auto topFlow = pb.ToFlow(pb.WideFromBlocks(topBlocks)); const auto pgmReturn = pb.Collect(pb.NarrowMap(topFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } )); @@ -253,10 +253,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLBlockTopTest) { const auto wideFlow = pb.ExpandMap(pb.ToFlow(list), [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }); - const auto blockFlow = pb.ToFlow(pb.WideToBlocks(pb.FromFlow(wideFlow))); - const auto topBlocks = pb.WideTopBlocks(blockFlow, + const auto blockStream = pb.WideToBlocks(pb.FromFlow(wideFlow)); + const auto topBlocks = pb.WideTopBlocks(blockStream, pb.NewDataLiteral<ui64>(2ULL), {{1U, pb.NewDataLiteral<bool>(false)}}); - const auto topFlow = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(topBlocks))); + const auto topFlow = pb.ToFlow(pb.WideFromBlocks(topBlocks)); const auto pgmReturn = pb.Collect(pb.NarrowMap(topFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } )); @@ -314,10 +314,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLBlockTopTest) { const auto wideFlow = pb.ExpandMap(pb.ToFlow(list), [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }); - const auto blockFlow = pb.ToFlow(pb.WideToBlocks(pb.FromFlow(wideFlow))); - const auto topSortBlocks = pb.WideTopSortBlocks(blockFlow, + const auto blockStream = pb.WideToBlocks(pb.FromFlow(wideFlow)); + const auto topSortBlocks = pb.WideTopSortBlocks(blockStream, pb.NewDataLiteral<ui64>(4ULL), {{0U, pb.NewDataLiteral<bool>(true)}, {1U, 
pb.NewDataLiteral<bool>(false)}}); - const auto topSortFlow = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(topSortBlocks))); + const auto topSortFlow = pb.ToFlow(pb.WideFromBlocks(topSortBlocks)); const auto pgmReturn = pb.Collect(pb.NarrowMap(topSortFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } )); @@ -381,10 +381,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLBlockTopTest) { const auto wideFlow = pb.ExpandMap(pb.ToFlow(list), [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }); - const auto blockFlow = pb.ToFlow(pb.WideToBlocks(pb.FromFlow(wideFlow))); - const auto topSortBlocks = pb.WideTopSortBlocks(blockFlow, + const auto blockStream = pb.WideToBlocks(pb.FromFlow(wideFlow)); + const auto topSortBlocks = pb.WideTopSortBlocks(blockStream, pb.NewDataLiteral<ui64>(6ULL), {{0U, pb.NewDataLiteral<bool>(false)}, {1U, pb.NewDataLiteral<bool>(true)}}); - const auto topSortFlow = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(topSortBlocks))); + const auto topSortFlow = pb.ToFlow(pb.WideFromBlocks(topSortBlocks)); const auto pgmReturn = pb.Collect(pb.NarrowMap(topSortFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } )); @@ -454,10 +454,12 @@ Y_UNIT_TEST_SUITE(TMiniKQLBlockTopTest) { const auto wideFlow = pb.ExpandMap(pb.ToFlow(list), [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }); - const auto blockFlow = pb.ToFlow(pb.WideToBlocks(pb.FromFlow(wideFlow))); - const auto topSortBlocks = pb.WideTopSortBlocks(blockFlow, - pb.NewDataLiteral<ui64>(4ULL), {{1U, pb.NewDataLiteral<bool>(true)}, {0U, pb.NewDataLiteral<bool>(false)}}); - const auto topSortFlow = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(topSortBlocks))); + const auto blockStream = pb.WideToBlocks(pb.FromFlow(wideFlow)); + const auto topSortBlocks = pb.WideTopSortBlocks(blockStream, + pb.NewDataLiteral<ui64>(4ULL), + {{1U, pb.NewDataLiteral<bool>(true)}, + {0U, 
pb.NewDataLiteral<bool>(false)}}); + const auto topSortFlow = pb.ToFlow(pb.WideFromBlocks(topSortBlocks)); const auto pgmReturn = pb.Collect(pb.NarrowMap(topSortFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } )); @@ -521,10 +523,12 @@ Y_UNIT_TEST_SUITE(TMiniKQLBlockTopTest) { const auto wideFlow = pb.ExpandMap(pb.ToFlow(list), [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }); - const auto blockFlow = pb.ToFlow(pb.WideToBlocks(pb.FromFlow(wideFlow))); - const auto topSortBlocks = pb.WideTopSortBlocks(blockFlow, - pb.NewDataLiteral<ui64>(6ULL), {{1U, pb.NewDataLiteral<bool>(false)}, {0U, pb.NewDataLiteral<bool>(true)}}); - const auto topSortFlow = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(topSortBlocks))); + const auto blockStream = pb.WideToBlocks(pb.FromFlow(wideFlow)); + const auto topSortBlocks = pb.WideTopSortBlocks(blockStream, + pb.NewDataLiteral<ui64>(6ULL), + {{1U, pb.NewDataLiteral<bool>(false)}, + {0U, pb.NewDataLiteral<bool>(true)}}); + const auto topSortFlow = pb.ToFlow(pb.WideFromBlocks(topSortBlocks)); const auto pgmReturn = pb.Collect(pb.NarrowMap(topSortFlow, [&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } )); @@ -596,10 +600,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLBlockSortTest) { const auto wideFlow = pb.ExpandMap(pb.ToFlow(list), [&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }); - const auto blockFlow = pb.ToFlow(pb.WideToBlocks(pb.FromFlow(wideFlow))); - const auto sortBlocks = pb.WideSortBlocks(blockFlow, + const auto blockStream = pb.WideToBlocks(pb.FromFlow(wideFlow)); + const auto sortBlocks = pb.WideSortBlocks(blockStream, {{0U, pb.NewDataLiteral<bool>(true)}}); - const auto sortFlow = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(sortBlocks))); + const auto sortFlow = pb.ToFlow(pb.WideFromBlocks(sortBlocks)); const auto pgmReturn = pb.Collect(pb.NarrowMap(sortFlow, 
[&](TRuntimeNode::TList items) -> TRuntimeNode { return pb.NewTuple(tupleType, items); } )); diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp index 274090f7b18..49332e3fe0c 100644 --- a/yql/essentials/minikql/mkql_program_builder.cpp +++ b/yql/essentials/minikql/mkql_program_builder.cpp @@ -1574,15 +1574,15 @@ TRuntimeNode TProgramBuilder::WideTakeBlocks(TRuntimeNode flow, TRuntimeNode cou } TRuntimeNode TProgramBuilder::WideTopBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTopOrSort(__func__, flow, count, keys); + return BuildWideTopOrSort(__func__, flow, count, keys, /*isBlocks=*/true); } TRuntimeNode TProgramBuilder::WideTopSortBlocks(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTopOrSort(__func__, flow, count, keys); + return BuildWideTopOrSort(__func__, flow, count, keys, /*isBlocks=*/true); } TRuntimeNode TProgramBuilder::WideSortBlocks(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTopOrSort(__func__, flow, Nothing(), keys); + return BuildWideTopOrSort(__func__, flow, Nothing(), keys, /*isBlocks=*/true); } TRuntimeNode TProgramBuilder::AsScalar(TRuntimeNode value) { @@ -1900,25 +1900,42 @@ TRuntimeNode TProgramBuilder::Sort(TRuntimeNode list, TRuntimeNode ascending, co TRuntimeNode TProgramBuilder::WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTopOrSort(__func__, flow, count, keys); + return BuildWideTopOrSort(__func__, flow, count, keys, /*isBlocks=*/false); } TRuntimeNode TProgramBuilder::WideTopSort(TRuntimeNode flow, TRuntimeNode count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTopOrSort(__func__, flow, count, keys); + return BuildWideTopOrSort(__func__, flow, count, keys, /*isBlocks=*/false); } 
TRuntimeNode TProgramBuilder::WideSort(TRuntimeNode flow, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - return BuildWideTopOrSort(__func__, flow, Nothing(), keys); + return BuildWideTopOrSort(__func__, flow, Nothing(), keys, /*isBlocks=*/false); } -TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys) { - const auto width = GetWideComponentsCount(AS_TYPE(TFlowType, flow.GetStaticType())); - MKQL_ENSURE(!keys.empty() && keys.size() <= width, "Unexpected keys count: " << keys.size()); +TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode stream, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys, bool isBlocks) { + if (isBlocks) { + return BuildWideTopOrSortImpl(callableName, stream, count, keys, TType::EKind::Stream); + } else { + return BuildWideTopOrSortImpl(callableName, stream, count, keys, TType::EKind::Flow); + } +} - TCallableBuilder callableBuilder(Env_, callableName, flow.GetStaticType()); - callableBuilder.Add(flow); +TRuntimeNode TProgramBuilder::BuildWideTopOrSortImpl(const std::string_view& callableName, TRuntimeNode stream, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys, TType::EKind streamKind) { + MKQL_ENSURE(stream.GetStaticType()->GetKind() == streamKind, "Mismatched input type"); + const auto width = GetWideComponentsCount(stream.GetStaticType()); + MKQL_ENSURE(!keys.empty() && keys.size() <= width, "Unexpected keys count: " << keys.size()); + bool shouldRewriteToFlow = RuntimeVersion < 64U && streamKind == TType::EKind::Stream; + if (shouldRewriteToFlow) { + // Preserve the old behaviour for ABI compatibility. + // Emit (FromFlow (Wide{Top,TopSort,Sort}Blocks (ToFlow (<stream>)))) to + // process the flow in favor to the given stream following + // the older MKQL ABI. 
+ // FIXME: Drop the branch below, when the time comes. + stream = ToFlow(stream); + } + TCallableBuilder callableBuilder(Env_, callableName, stream.GetStaticType()); + callableBuilder.Add(stream); if (count) { callableBuilder.Add(*count); } @@ -1928,7 +1945,17 @@ TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callabl callableBuilder.Add(NewDataLiteral(key.first)); callableBuilder.Add(key.second); }); - return TRuntimeNode(callableBuilder.Build(), false); + + auto resultNode = TRuntimeNode(callableBuilder.Build(), false); + if (shouldRewriteToFlow) { + // Preserve the old behaviour for ABI compatibility. + // Emit (FromFlow (Wide{Top,TopSort,Sort}Blocks (ToFlow (<stream>)))) to + // process the flow in favor to the given stream following + // the older MKQL ABI. + // FIXME: Drop the branch below, when the time comes. + return FromFlow(resultNode); + } + return resultNode; } TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) { diff --git a/yql/essentials/minikql/mkql_program_builder.h b/yql/essentials/minikql/mkql_program_builder.h index f5a6fc28e73..6d6d02a083d 100644 --- a/yql/essentials/minikql/mkql_program_builder.h +++ b/yql/essentials/minikql/mkql_program_builder.h @@ -806,7 +806,8 @@ private: TRuntimeNode BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members, const std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>>& filteredItems); - TRuntimeNode BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys); + TRuntimeNode BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys, bool isBlocks); + TRuntimeNode BuildWideTopOrSortImpl(const 
std::string_view& callableName, TRuntimeNode flow, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys, TType::EKind streamKind); TRuntimeNode InvokeBinary(const std::string_view& callableName, TType* type, TRuntimeNode data1, TRuntimeNode data2); TRuntimeNode AggrCompare(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2); diff --git a/yql/essentials/minikql/mkql_runtime_version.h b/yql/essentials/minikql/mkql_runtime_version.h index eb64e14b02b..ea8be58513c 100644 --- a/yql/essentials/minikql/mkql_runtime_version.h +++ b/yql/essentials/minikql/mkql_runtime_version.h @@ -24,7 +24,7 @@ namespace NMiniKQL { // 1. Bump this version every time incompatible runtime nodes are introduced. // 2. Make sure you provide runtime node generation for previous runtime versions. #ifndef MKQL_RUNTIME_VERSION -#define MKQL_RUNTIME_VERSION 63U +#define MKQL_RUNTIME_VERSION 64U #endif // History: diff --git a/yql/essentials/tests/s-expressions/minirun/part5/canondata/result.json b/yql/essentials/tests/s-expressions/minirun/part5/canondata/result.json index 0bf1144eecf..c8be00cd9ee 100644 --- a/yql/essentials/tests/s-expressions/minirun/part5/canondata/result.json +++ b/yql/essentials/tests/s-expressions/minirun/part5/canondata/result.json @@ -13,6 +13,20 @@ "uri": "https://{canondata_backend}/1936273/fc47b74ce53fdd3426495bb35fba8f4c1bc5deda/resource.tar.gz#test.test_Blocks-BlockOrderedExtend-default.txt-Results_/results.txt" } ], + "test.test[Blocks-BlocksSort+ReplicateScalars-default.txt-Debug]": [ + { + "checksum": "8adeb84aac9d8865516a4dd8cc3b1359", + "size": 635, + "uri": "https://{canondata_backend}/1942100/7f62e1d7e5ababbffae9cdcf4da35d121501ff83/resource.tar.gz#test.test_Blocks-BlocksSort+ReplicateScalars-default.txt-Debug_/opt.yql" + } + ], + "test.test[Blocks-BlocksSort+ReplicateScalars-default.txt-Results]": [ + { + "checksum": "ce4b5ec69129932b188215b28cd20e2a", + "size": 887, + "uri": 
"https://{canondata_backend}/1942100/7f62e1d7e5ababbffae9cdcf4da35d121501ff83/resource.tar.gz#test.test_Blocks-BlocksSort+ReplicateScalars-default.txt-Results_/results.txt" + } + ], "test.test[Builtins-DivePrefixMembersOpt-default.txt-Debug]": [ { "checksum": "7c408ddcedc1deae859a58f6f29ff7fd", diff --git a/yql/essentials/tests/s-expressions/suites/Blocks/BlocksSort+ReplicateScalars.yqls b/yql/essentials/tests/s-expressions/suites/Blocks/BlocksSort+ReplicateScalars.yqls new file mode 100644 index 00000000000..a1ef286ce7a --- /dev/null +++ b/yql/essentials/tests/s-expressions/suites/Blocks/BlocksSort+ReplicateScalars.yqls @@ -0,0 +1,29 @@ +( +(let world (Configure! world (DataSource 'config) 'BlockEngine 'force)) +(let wconf (DataSink 'result)) + +(let x1 (AsStruct '('"x" (Uint64 '"1")))) +(let x2 (AsStruct '('"x" (Uint64 '"2")))) +(let x3 (AsStruct '('"x" (Uint64 '"3")))) +(let x4 (AsStruct '('"x" (Uint64 '"4")))) + +(let list (AsList x1 x2 x3 x4)) + +(let expandLambda (lambda '(item) (Member item '"x"))) +(let wideStream (FromFlow (ExpandMap (ToFlow list) expandLambda))) +(let wideBlockStream (WideToBlocks wideStream)) +(let wideFlowScalar (WideMap (ToFlow wideBlockStream) (lambda '($x, $blockSize) (AsScalar (Uint64 '"123")) $blockSize))) + +(let sortParams '('('0 (Bool 'true)))) + +(let replicateParams '('"0")) + +(let nopFromBlocksToBlocks (WideFromBlocks (WideSortBlocks (ReplicateScalars (FromFlow wideFlowScalar) replicateParams) sortParams))) + +(let narrowLambda (lambda '(x) (AsStruct '('"x" x)))) +(let scalarList (ForwardList (NarrowMap (ToFlow nopFromBlocksToBlocks) narrowLambda))) + +(let world (Write! world wconf (Key) scalarList '('('type)))) +(let world (Commit! 
world wconf)) +(return world) +) diff --git a/yql/essentials/tests/sql/minirun/part2/canondata/result.json b/yql/essentials/tests/sql/minirun/part2/canondata/result.json index cd286c78c0c..f9a05fffd89 100644 --- a/yql/essentials/tests/sql/minirun/part2/canondata/result.json +++ b/yql/essentials/tests/sql/minirun/part2/canondata/result.json @@ -351,9 +351,9 @@ ], "test.test[blocks-sort-default.txt-Peephole]": [ { - "checksum": "756e6c3f006f33c75a3dedea68c889ac", - "size": 919, - "uri": "https://{canondata_backend}/1917492/6fdf85f7e05da60eed58efcacce70c29bce9a047/resource.tar.gz#test.test_blocks-sort-default.txt-Peephole_/opt.yql" + "checksum": "fc3dd8d7c0631fdd696c87d359fd7c6b", + "size": 905, + "uri": "https://{canondata_backend}/1775059/6773e6533ba4ece8ecb355472912579ca2809089/resource.tar.gz#test.test_blocks-sort-default.txt-Peephole_/opt.yql" } ], "test.test[blocks-sort-default.txt-Results]": [ |