diff options
author | zverevgeny <zverevgeny@ydb.tech> | 2023-09-08 07:08:49 +0300 |
---|---|---|
committer | zverevgeny <zverevgeny@ydb.tech> | 2023-09-08 07:30:58 +0300 |
commit | b27ebe54a0677e957e81d60d8286b52ad72265af (patch) | |
tree | b85df044d67c66ac077e32649899afb5a6c7e1d4 | |
parent | 563e21fcd7280e84fb359d0a8de2403a5924a3d8 (diff) | |
download | ydb-b27ebe54a0677e957e81d60d8286b52ad72265af.tar.gz |
YQL-16325 augmented input data for measures
3 files changed, 233 insertions, 25 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp index e7a59032f4..6b4764a8e3 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp @@ -1,9 +1,12 @@ #include "mkql_match_recognize_matched_vars.h" +#include "mkql_match_recognize_measure_arg.h" +#include <ydb/library/yql/core/sql_types/match_recognize.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_pack.h> #include <ydb/library/yql/minikql/mkql_node_cast.h> #include <ydb/library/yql/minikql/mkql_runtime_version.h> +#include <ydb/library/yql/minikql/mkql_string_util.h> #include <ydb/library/yql/core/sql_types/match_recognize.h> #include <deque> @@ -11,9 +14,6 @@ namespace NKikimr::NMiniKQL { namespace NMatchRecognize { -enum class EMeasureColumnSource {Classifier = 0, MatchNumber = 1, Input}; -using TMeasureInputColumnOrder = TVector<std::pair<EMeasureColumnSource, size_t>>; - enum class EOutputColumnSource {PartitionKey, Measure}; using TOutputColumnOrder = TVector<std::pair<EOutputColumnSource, size_t>>; @@ -33,22 +33,29 @@ public: TBackTrackingMatchRecognize( NUdf::TUnboxedValue&& partitionKey, IComputationExternalNode* matchedVarsArg, + const TUnboxedValueVector& varNames, + IComputationExternalNode* measureInputDataArg, + const TMeasureInputColumnOrder& measureInputColumnOrder, const TComputationNodePtrVector& measures, const TOutputColumnOrder& outputColumnOrder, IComputationExternalNode* currentRowIndexArg, - IComputationExternalNode* defineInputDataArg, + IComputationExternalNode* inputDataArg, const TComputationNodePtrVector& defines, const TContainerCacheOnContext& cache ) : PartitionKey(std::move(partitionKey)) , MatchedVarsArg(matchedVarsArg) + , VarNames(varNames) + , MeasureInputDataArg(measureInputDataArg) + , MeasureInputColumnOrder(measureInputColumnOrder) , Measures(measures) , OutputColumnOrder(outputColumnOrder) , CurrentRowIndexArg(currentRowIndexArg) - , DefineInputDataArg(defineInputDataArg) + , InputDataArg(inputDataArg) , Defines(defines) , Cache(cache) , CurMatchedVars(Defines.size()) + , MatchNumber(0) { } @@ -60,8 +67,15 @@ public: NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) override { if (Matches.empty()) return NUdf::TUnboxedValue{}; - MatchedVarsArg->SetValue(ctx, ToValue(ctx, Matches.front())); + MatchedVarsArg->SetValue(ctx, ToValue(ctx, std::move(Matches.front()))); Matches.pop_front(); + MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>( + InputDataArg->GetValue(ctx), + MeasureInputColumnOrder, + MatchedVarsArg->GetValue(ctx), + VarNames, + ++MatchNumber + )); NUdf::TUnboxedValue *itemsPtr = nullptr; const auto result = Cache.NewArray(ctx, OutputColumnOrder.size(), itemsPtr); for (auto const& c: OutputColumnOrder) { @@ -80,7 +94,7 @@ public: //Assume, that data moved to IComputationExternalNode node, will not be modified or released //till the end of the current function auto rowsSize = Rows.size(); - DefineInputDataArg->SetValue(ctx, ctx.HolderFactory.VectorAsVectorHolder(std::move(Rows))); + InputDataArg->SetValue(ctx, ctx.HolderFactory.VectorAsVectorHolder(std::move(Rows))); for (size_t i = 0; i != rowsSize; ++i) { CurrentRowIndexArg->SetValue(ctx, NUdf::TUnboxedValuePod(static_cast<ui64>(i))); for (size_t v = 0; v != Defines.size(); ++v) { @@ -109,15 +123,19 @@ public: private: const NUdf::TUnboxedValue PartitionKey; IComputationExternalNode* const MatchedVarsArg; + const TUnboxedValueVector& VarNames; + IComputationExternalNode* const MeasureInputDataArg; + const TMeasureInputColumnOrder& MeasureInputColumnOrder; const TComputationNodePtrVector& Measures; const TOutputColumnOrder& OutputColumnOrder; IComputationExternalNode* const CurrentRowIndexArg; - IComputationExternalNode* const DefineInputDataArg; + IComputationExternalNode* const InputDataArg; const TComputationNodePtrVector& Defines; const TContainerCacheOnContext& Cache; TUnboxedValueVector Rows; TMatchedVars CurMatchedVars; std::deque<TMatchedVars> Matches; + ui64 MatchNumber; }; class TStreamingMatchRecognize: public IProcessMatchRecognize { @@ -168,7 +186,7 @@ public: NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) override { if (!HasMatch) return NUdf::TUnboxedValue{}; - MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue>(MatchedVars)); + MatchedVarsArg->SetValue(ctx, ToValue(ctx, MatchedVars)); HasMatch = false; NUdf::TUnboxedValue *itemsPtr = nullptr; const auto result = Cache.NewArray(ctx, OutputColumnOrder.size(), itemsPtr); @@ -211,10 +229,13 @@ public: IComputationNode *partitionKey, TType* partitionKeyType, IComputationExternalNode* matchedVarsArg, + const TUnboxedValueVector& varNames, + IComputationExternalNode* measureInputDataArg, + const TMeasureInputColumnOrder& measureInputColumnOrder, const TComputationNodePtrVector& measures, TOutputColumnOrder&& outputColumnOrder, IComputationExternalNode* currentRowIndexArg, - IComputationExternalNode* defineInputDataArg, + IComputationExternalNode* inputDataArg, const TComputationNodePtrVector& defines ) :TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded) @@ -223,10 +244,13 @@ public: , PartitionKey(partitionKey) , PartitionKeyType(partitionKeyType) , MatchedVarsArg(matchedVarsArg) + , VarNames(varNames) + , MeasureInputDataArg(measureInputDataArg) + , MeasureInputColumnOrder(measureInputColumnOrder) , Measures(measures) , OutputColumnOrder(outputColumnOrder) , CurrentRowIndexArg(currentRowIndexArg) - , DefineInputDataArg(defineInputDataArg) + , inputDataArg(inputDataArg) , Defines(defines) , Cache(mutables) {} @@ -238,10 +262,13 @@ public: PartitionKey, PartitionKeyType, MatchedVarsArg, + VarNames, + MeasureInputDataArg, + MeasureInputColumnOrder, Measures, OutputColumnOrder, CurrentRowIndexArg, - DefineInputDataArg, + inputDataArg, Defines, Cache ); @@ -271,10 +298,13 @@ private: IComputationNode* partitionKey, TType* partitionKeyType, IComputationExternalNode* matchedVarsArg, + const TUnboxedValueVector& varNames, + IComputationExternalNode* measureInputDataArg, + const TMeasureInputColumnOrder& measureInputColumnOrder, const TComputationNodePtrVector& measures, const TOutputColumnOrder& outputColumnOrder, IComputationExternalNode* currentRowIndexArg, - IComputationExternalNode* defineInputDataArg, + IComputationExternalNode* inputDataArg, const TComputationNodePtrVector& defines, const TContainerCacheOnContext& cache ) @@ -283,10 +313,13 @@ private: , PartitionKey(partitionKey) , PartitionKeyPacker(true, partitionKeyType) , MatchedVarsArg(matchedVarsArg) + , VarNames(varNames) + , MeasureInputDataArg(measureInputDataArg) + , MeasureInputColumnOrder(measureInputColumnOrder) , Measures(measures) , OutputColumnOrder(outputColumnOrder) , CurrentRowIndexArg(currentRowIndexArg) - , DefineInputDataArg(defineInputDataArg) + , InputDataArg(inputDataArg) , Defines(defines) , Cache(cache) { @@ -334,10 +367,13 @@ private: return Partitions.emplace_hint(it, TString(packedKey), std::make_unique<TBackTrackingMatchRecognize>( std::move(partitionKey), MatchedVarsArg, + VarNames, + MeasureInputDataArg, + MeasureInputColumnOrder, Measures, OutputColumnOrder, CurrentRowIndexArg, - DefineInputDataArg, + InputDataArg, Defines, Cache )); @@ -357,10 +393,13 @@ private: //to be passed to partitions IComputationExternalNode* const MatchedVarsArg; + const TUnboxedValueVector& VarNames; + IComputationExternalNode* MeasureInputDataArg; + const TMeasureInputColumnOrder& MeasureInputColumnOrder; TComputationNodePtrVector Measures; const TOutputColumnOrder& OutputColumnOrder; IComputationExternalNode* const CurrentRowIndexArg; - IComputationExternalNode* const DefineInputDataArg; + IComputationExternalNode* const InputDataArg; const TComputationNodePtrVector& Defines; const TContainerCacheOnContext& Cache; }; @@ -370,8 +409,9 @@ private: if (const auto flow = FlowDependsOn(InputFlow)) { Own(flow, InputRowArg); Own(flow, MatchedVarsArg); + Own(flow, MeasureInputDataArg); Own(flow, CurrentRowIndexArg); - Own(flow, DefineInputDataArg); + Own(flow, inputDataArg); DependsOn(flow, PartitionKey); for (auto& m: Measures) { DependsOn(flow, m); @@ -387,10 +427,13 @@ private: IComputationNode* const PartitionKey; TType* const PartitionKeyType; IComputationExternalNode* const MatchedVarsArg; + const TUnboxedValueVector VarNames; + IComputationExternalNode* const MeasureInputDataArg; + const TMeasureInputColumnOrder MeasureInputColumnOrder; const TComputationNodePtrVector Measures; const TOutputColumnOrder OutputColumnOrder; IComputationExternalNode* const CurrentRowIndexArg; - IComputationExternalNode* const DefineInputDataArg; + IComputationExternalNode* const inputDataArg; const TComputationNodePtrVector Defines; const TContainerCacheOnContext Cache; }; @@ -446,6 +489,28 @@ TRowPattern ConvertPattern(const TRuntimeNode& pattern) { return result; } +TMeasureInputColumnOrder GetMeasureColumnOrder(const TListLiteral& specialColumnIndexes, ui32 inputRowColumnCount) { + using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns; + //Use Last enum value to denote that c colum comes from the input table + TMeasureInputColumnOrder result(inputRowColumnCount + specialColumnIndexes.GetItemsCount(), std::make_pair(EMeasureInputDataSpecialColumns::Last, 0)); + if (specialColumnIndexes.GetItemsCount() != 0) { + MKQL_ENSURE(specialColumnIndexes.GetItemsCount() == static_cast<size_t>(EMeasureInputDataSpecialColumns::Last), + "Internal logic error"); + for (size_t i = 0; i != specialColumnIndexes.GetItemsCount(); ++i) { + auto ind = AS_VALUE(TDataLiteral, specialColumnIndexes.GetItems()[i])->AsValue().Get<ui32>(); + result[ind] = std::make_pair(static_cast<EMeasureInputDataSpecialColumns>(i), 0); + } + } + //update indexes for input table columns + ui32 inputIdx = 0; + for (auto& [t, i]: result) { + if (EMeasureInputDataSpecialColumns::Last == t) { + i = inputIdx++; + } + } + return result; +} + TComputationNodePtrVector ConvertVectorOfCallables(const TRuntimeNode::TList& v, const TComputationNodeFactoryContext& ctx) { TComputationNodePtrVector result; result.reserve(v.size()); @@ -455,6 +520,16 @@ TComputationNodePtrVector ConvertVectorOfCallables(const TRuntimeNode::TList& v, return result; } +TUnboxedValueVector ConvertListOfStrings(const TRuntimeNode& l) { + TUnboxedValueVector result; + const auto& list = AS_VALUE(TListLiteral, l); + result.reserve(list->GetItemsCount()); + for (ui32 i = 0; i != list->GetItemsCount(); ++i) { + result.push_back(MakeString(AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().AsStringRef())); + } + return result; +} + } //namespace NMatchRecognize @@ -477,10 +552,10 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation } const auto& pattern = callable.GetInput(inputIndex++); const auto& currentRowIndexArg = callable.GetInput(inputIndex++); - const auto& defineInputDataArg = callable.GetInput(inputIndex++); - const auto& defineNames = callable.GetInput(inputIndex++); + const auto& inputDataArg = callable.GetInput(inputIndex++); + const auto& varNames = callable.GetInput(inputIndex++); TRuntimeNode::TList defines; - for (size_t i = 0; i != AS_VALUE(TListLiteral, defineNames)->GetItemsCount(); ++i) { + for (size_t i = 0; i != AS_VALUE(TListLiteral, varNames)->GetItemsCount(); ++i) { defines.push_back(callable.GetInput(inputIndex++)); } MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count"); @@ -497,10 +572,16 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()) , partitionKeySelector.GetStaticType() , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode())) + , ConvertListOfStrings(varNames) + , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode())) + , GetMeasureColumnOrder( + *AS_VALUE(TListLiteral, measureSpecialColumnIndexes), + AS_VALUE(TDataLiteral, inputRowColumnCount)->AsValue().Get<ui32>() + ) , ConvertVectorOfCallables(measures, ctx) , GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes) , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode())) - , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *defineInputDataArg.GetNode())) + , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode())) , ConvertVectorOfCallables(defines, ctx) ); } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_measure_arg.h b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_measure_arg.h new file mode 100644 index 0000000000..54d328075d --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_measure_arg.h @@ -0,0 +1,127 @@ +#pragma once + +#include "mkql_match_recognize_matched_vars.h" +#include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h> +#include <ydb/library/yql/core/sql_types/match_recognize.h> +#include <ydb/library/yql/minikql/mkql_string_util.h> + +namespace NKikimr::NMiniKQL::NMatchRecognize { + +using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns; + +using TMeasureInputColumnOrder = TVector<std::pair<EMeasureInputDataSpecialColumns, size_t>>; + +//Input row augmented with lightweight special columns for calculating MEASURE lambdas +class TRowForMeasureValue: public TComputationValue<TRowForMeasureValue> +{ +public: + TRowForMeasureValue( + TMemoryUsageInfo* memInfo, + NUdf::TUnboxedValue inputRow, + ui64 rowIndex, + const TMeasureInputColumnOrder& columnOrder, + const NUdf::TUnboxedValue& matchedVars, + const TUnboxedValueVector& varNames, + ui64 matchNumber + ) + : TComputationValue<TRowForMeasureValue>(memInfo) + , InputRow(inputRow) + , RowIndex(rowIndex) + , ColumnOrder(columnOrder) + , MatchedVars(matchedVars) + , VarNames(varNames) + , MatchNumber(matchNumber) + {} + NUdf::TUnboxedValue GetElement(ui32 index) const override { + switch(ColumnOrder[index].first) { + case EMeasureInputDataSpecialColumns::Classifier: { + auto varIterator = MatchedVars.GetListIterator(); + MKQL_ENSURE(varIterator, "Internal logic error"); + NUdf::TUnboxedValue var; + size_t varIndex = 0; + while(varIterator.Next(var)) { + auto rangeIterator = var.GetListIterator(); + MKQL_ENSURE(varIterator, "Internal logic error"); + NUdf::TUnboxedValue range; + while(rangeIterator.Next(range)) { + const auto from = range.GetElement(0).Get<ui64>(); + const auto to = range.GetElement(1).Get<ui64>(); + if (RowIndex >= from and RowIndex <= to) { + return VarNames[varIndex]; + } + } + ++varIndex; + } + MKQL_ENSURE(MatchedVars.GetListLength() == varIndex, "Internal logic error"); + return MakeString(""); + } + case EMeasureInputDataSpecialColumns::MatchNumber: + return NUdf::TUnboxedValuePod(MatchNumber); + case EMeasureInputDataSpecialColumns::Last: //Last corresponds to columns from the input table row + return InputRow.GetElement(ColumnOrder[index].second); + } + } +private: + const NUdf::TUnboxedValue InputRow; + const ui64 RowIndex; + const TMeasureInputColumnOrder& ColumnOrder; + const NUdf::TUnboxedValue& MatchedVars; + const TUnboxedValueVector& VarNames; + ui64 MatchNumber; +}; + +class TMeasureInputDataValue: public TComputationValue<TMeasureInputDataValue> { + using Base = TComputationValue<TMeasureInputDataValue>; +public: + TMeasureInputDataValue(TMemoryUsageInfo* memInfo, + const NUdf::TUnboxedValue& inputData, + const TMeasureInputColumnOrder& columnOrder, + const NUdf::TUnboxedValue& matchedVars, + const TUnboxedValueVector& varNames, + ui64 matchNumber) + : Base(memInfo) + , InputData(inputData) + , ColumnOrder(columnOrder) + , MatchedVars(matchedVars) + , VarNames(varNames) + , MatchNumber(matchNumber) + {} + + bool HasFastListLength() const override { + return InputData.HasFastListLength(); + } + + ui64 GetListLength() const override { + return InputData.GetListLength(); + } + + //TODO https://st.yandex-team.ru/YQL-16508 + //NUdf::TUnboxedValue GetListIterator() const override; + + NUdf::IBoxedValuePtr ToIndexDictImpl(const NUdf::IValueBuilder& builder) const override { + Y_UNUSED(builder); + return const_cast<TMeasureInputDataValue*>(this); + } + + NUdf::TUnboxedValue Lookup(const NUdf::TUnboxedValuePod& key) const override { + auto inputRow = InputData.Lookup(key); + return NUdf::TUnboxedValuePod{new TRowForMeasureValue( + GetMemInfo(), + inputRow, + key.Get<ui64>(), + ColumnOrder, + MatchedVars, + VarNames, + MatchNumber + )}; + } +private: + const NUdf::TUnboxedValue InputData; + const TMeasureInputColumnOrder& ColumnOrder; + const NUdf::TUnboxedValue MatchedVars; + const TUnboxedValueVector& VarNames; + const ui64 MatchNumber; +}; + +}//namespace NKikimr::NMiniKQL::NMatchRecognize + diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 07a2841f3a..3b2eceee7a 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -5930,12 +5930,12 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( std::vector<TRuntimeNode> defineNames(patternVarLookup.size()); std::vector<TRuntimeNode> defineNodes(patternVarLookup.size()); - const auto& defineInputDataArg = Arg(TListType::Create(inputRowType, Env)); + const auto& inputDataArg = Arg(TListType::Create(inputRowType, Env)); const auto& currentRowIndexArg = Arg(TDataType::Create(NUdf::TDataType<ui64>::Id, Env)); for (const auto& [v, i]: patternVarLookup) { defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v); if (const auto it = defineLookup.find(v); it != defineLookup.end()) { - defineNodes[i] = getDefines[it->second].second(defineInputDataArg, matchedVarsArg, currentRowIndexArg); + defineNodes[i] = getDefines[it->second].second(inputDataArg, matchedVarsArg, currentRowIndexArg); } else { //no predicate for var if ("$" == v || "^" == v) { @@ -5970,7 +5970,7 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( callableBuilder.Add(PatternToRuntimeNode(pattern, *this)); callableBuilder.Add(currentRowIndexArg); - callableBuilder.Add(defineInputDataArg); + callableBuilder.Add(inputDataArg); const auto stringType = NewDataType(NUdf::EDataSlot::String); callableBuilder.Add(TRuntimeNode(TListLiteral::Create(defineNames.begin(), defineNames.size(), TListType::Create(stringType, Env), Env), true)); for (const auto& d: defineNodes) { |