diff options
author | aneporada <aneporada@ydb.tech> | 2023-07-31 22:21:35 +0300 |
---|---|---|
committer | aneporada <aneporada@ydb.tech> | 2023-07-31 22:21:35 +0300 |
commit | d2231661b577df3a282cadb55845ceefcb419c8b (patch) | |
tree | 68526f669f2750262e5e04dcc9c20681b996a158 | |
parent | f9e4743508b7930e884714cc99985ac45f84ed98 (diff) | |
download | ydb-d2231661b577df3a282cadb55845ceefcb419c8b.tar.gz |
Support scalar stream index in BlockMergeManyFinalizeHashed
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_blocks.cpp | 15 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp | 21 |
2 files changed, 27 insertions, 9 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index 1f7343585f8..6b4b1278561 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -659,9 +659,10 @@ IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr } TTypeAnnotationNode::TListType blockItemTypes; - if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr, !many)) { + if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } + YQL_ENSURE(blockItemTypes.size() > 0); TTypeAnnotationNode::TListType retMultiType; if (!ValidateBlockKeys(input->Pos(), blockItemTypes, *input->Child(1), retMultiType, ctx.Expr)) { @@ -682,7 +683,7 @@ IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr } ui32 streamIndex; - if (!TryFromString(input->Child(3)->Content(), streamIndex) || streamIndex >= blockItemTypes.size()) { + if (!TryFromString(input->Child(3)->Content(), streamIndex) || streamIndex >= blockItemTypes.size() - 1) { ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(3)->Pos()), "Bad stream index")); return IGraphTransformer::TStatus::Error; } @@ -694,6 +695,16 @@ IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr if (!ValidateAggManyStreams(*input->Child(4), input->Child(2)->ChildrenSize(), ctx.Expr)) { return IGraphTransformer::TStatus::Error; } + + // disallow any scalar columns except for streamIndex column + auto itemTypes = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems(); + for (ui32 i = 0; i + 1 < itemTypes.size(); ++i) { + bool isScalar = itemTypes[i]->GetKind() == ETypeAnnotationKind::Scalar; + if (isScalar && i != streamIndex) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Head().Pos()), TStringBuilder() << "Unexpected scalar type " << *itemTypes[i] << ", on input column #" << i)); + return IGraphTransformer::TStatus::Error; + } + } } retMultiType.push_back(ctx.Expr.MakeType<TScalarExprType>(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64))); diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp index ffb7242941f..adc2eda366a 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp @@ -701,10 +701,15 @@ public: } const ui32* streamIndexData = nullptr; + TMaybe<ui32> streamIndexScalar; if constexpr (Many) { auto streamIndexDatum = TArrowBlock::From(s.Values_[StreamIndex_]).GetDatum(); - MKQL_ENSURE(streamIndexDatum.is_array(), "Expected array"); - streamIndexData = streamIndexDatum.array()->template GetValues<ui32>(1); + if (streamIndexDatum.is_scalar()) { + streamIndexScalar = streamIndexDatum.template scalar_as<arrow::UInt32Scalar>().value; + } else { + MKQL_ENSURE(streamIndexDatum.is_array(), "Expected array"); + streamIndexData = streamIndexDatum.array()->template GetValues<ui32>(1); + } s.UnwrappedValues_ = s.Values_; for (const auto& p : AggsParams_) { const auto& columnDatum = TArrowBlock::From(s.UnwrappedValues_[p.Column_]).GetDatum(); @@ -748,10 +753,14 @@ public: s.HashSet_->CheckGrow(); } } else { + ui32 streamIndex = 0; + if constexpr (Many) { + streamIndex = streamIndexScalar ? *streamIndexScalar : streamIndexData[row]; + } if (!InlineAggState) { - Insert(*s.HashFixedMap_, key, row, streamIndexData, output, s); + Insert(*s.HashFixedMap_, key, row, streamIndex, output, s); } else { - Insert(*s.HashMap_, key, row, streamIndexData, output, s); + Insert(*s.HashMap_, key, row, streamIndex, output, s); } } } @@ -948,7 +957,7 @@ private: } template <typename THash> - void Insert(THash& hash, const TKey& key, ui64 row, const ui32* streamIndexData, NUdf::TUnboxedValue*const* output, TState& s) const { + void Insert(THash& hash, const TKey& key, ui64 row, ui32 currentStreamIndex, NUdf::TUnboxedValue*const* output, TState& s) const { bool isNew; auto iter = hash.Insert(key, isNew); char* payload = (char*)hash.GetMutablePayload(iter); @@ -964,7 +973,6 @@ private: if constexpr (Many) { static_assert(Finalize); - ui32 currentStreamIndex = streamIndexData[row]; MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); memset(ptr, 0, Streams_.size()); ptr[currentStreamIndex] = 1; @@ -1002,7 +1010,6 @@ private: if constexpr (Many) { static_assert(Finalize); - ui32 currentStreamIndex = streamIndexData[row]; MKQL_ENSURE(currentStreamIndex < Streams_.size(), "Invalid stream index"); bool isNewStream = !ptr[currentStreamIndex]; |