diff options
author | zverevgeny <zverevgeny@ydb.tech> | 2023-08-30 12:54:06 +0300 |
---|---|---|
committer | zverevgeny <zverevgeny@ydb.tech> | 2023-08-30 13:58:23 +0300 |
commit | d8e8dfdd57e6c0f1e39a8839c4f41d979dd9920c (patch) | |
tree | cf3929356268e7c24fc401d3b3bd2b3a9ccb57c4 | |
parent | d211395d364252c63eb2824a57d7966afd123ce6 (diff) | |
download | ydb-d8e8dfdd57e6c0f1e39a8839c4f41d979dd9920c.tar.gz |
YQL-16325 matched vars for MATCH_RECOGNIZE
4 files changed, 186 insertions, 16 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp index 99b5b932be4..dd42271bdd8 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp @@ -1,3 +1,4 @@ +#include "mkql_match_recognize_matched_vars.h" #include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_pack.h> @@ -15,6 +16,8 @@ using TMeasureInputColumnOrder = std::vector<std::pair<EMeasureColumnSource, siz enum class EOutputColumnSource {PartitionKey, Measure}; using TOutputColumnOrder = std::vector<std::pair<EOutputColumnSource, size_t>>; +using namespace NMatchRecognize; + //Process one partition of input data struct IProcessMatchRecognize { ///return true if it has output data ready @@ -28,27 +31,43 @@ class TStreamMatchRecognize: public IProcessMatchRecognize { public: TStreamMatchRecognize( NUdf::TUnboxedValue&& partitionKey, + IComputationExternalNode* matchedVarsArg, std::vector<IComputationNode*>& measures, const TOutputColumnOrder& outputColumnOrder, const TContainerCacheOnContext& cache ) : PartitionKey(std::move(partitionKey)) + , MatchedVarsArg(matchedVarsArg) , Measures(measures) , OutputColumnOrder(outputColumnOrder) , Cache(cache) + , MatchedVars(2) //Assume pattern (A B B)*, where A matches every 3rd row (i%3 == 0) and B matches the rest , HasMatch(false) + , RowCount(0) { } bool ProcessInputRow(NUdf::TUnboxedValue&& row) override{ - //Assume match on every row, TODO fixme Y_UNUSED(row); - HasMatch = true; + //Assume pattern (A B B)*, where A matches every 3rd row (i%3 == 0) and B matches the rest + switch (RowCount % 3) { + case 0: + MatchedVars[0].push_back({RowCount, RowCount}); + break; + case 1: + MatchedVars[1].push_back({RowCount, RowCount}); + break; + case 2: + MatchedVars[1].back().second++; + break; + } + ++RowCount; return HasMatch; } NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) override { if (!HasMatch) return NUdf::TUnboxedValue::Invalid(); + MatchedVarsArg->SetValue(ctx, NUdf::TUnboxedValuePod(new TMatchedVarsValue(&ctx.HolderFactory.GetMemInfo(), MatchedVars))); HasMatch = false; NUdf::TUnboxedValue *itemsPtr = nullptr; const auto result = Cache.NewArray(ctx, OutputColumnOrder.size(), itemsPtr); @@ -67,16 +86,18 @@ public: return result; } bool ProcessEndOfData() override { - //TODO - return false; + HasMatch = true; + return HasMatch; } private: const NUdf::TUnboxedValue PartitionKey; + IComputationExternalNode* const MatchedVarsArg; const std::vector<IComputationNode*>& Measures; const TOutputColumnOrder& OutputColumnOrder; const TContainerCacheOnContext& Cache; + TMatchedVars MatchedVars; bool HasMatch; - + size_t RowCount; }; @@ -87,17 +108,19 @@ public: IComputationExternalNode *inputRowArg, IComputationNode *partitionKey, TType* partitionKeyType, + IComputationExternalNode* matchedVarsArg, std::vector<IComputationNode*>&& measures, TOutputColumnOrder&& outputColumnOrder ) - : TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded) - , InputFlow(inputFlow) - , InputRowArg(inputRowArg) - , PartitionKey(partitionKey) - , PartitionKeyType(partitionKeyType) - , Measures(measures) - , OutputColumnOrder(outputColumnOrder) - , Cache(mutables) + :TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded) + , InputFlow(inputFlow) + , InputRowArg(inputRowArg) + , PartitionKey(partitionKey) + , PartitionKeyType(partitionKeyType) + , MatchedVarsArg(matchedVarsArg) + , Measures(measures) + , OutputColumnOrder(outputColumnOrder) + , Cache(mutables) {} NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue &stateValue, TComputationContext &ctx) const { @@ -106,6 +129,7 @@ public: InputRowArg, PartitionKey, PartitionKeyType, + MatchedVarsArg, Measures, OutputColumnOrder, Cache @@ -136,6 +160,7 @@ private: IComputationExternalNode* inputRowArg, IComputationNode* partitionKey, TType* partitionKeyType, + IComputationExternalNode* matchedVarsArg, const std::vector<IComputationNode*>& measures, const TOutputColumnOrder& outputColumnOrder, const TContainerCacheOnContext& cache @@ -144,6 +169,7 @@ private: , InputRowArg(inputRowArg) , PartitionKey(partitionKey) , PartitionKeyPacker(true, partitionKeyType) + , MatchedVarsArg(matchedVarsArg) , Measures(measures) , OutputColumnOrder(outputColumnOrder) , Cache(cache) @@ -191,6 +217,7 @@ private: } else { return Partitions.emplace_hint(it, TString(packedKey), std::make_unique<TStreamMatchRecognize>( std::move(partitionKey), + MatchedVarsArg, Measures, OutputColumnOrder, Cache @@ -210,6 +237,7 @@ private: TValuePackerGeneric<false> PartitionKeyPacker; //to be passed to partitions + IComputationExternalNode* const MatchedVarsArg; std::vector<IComputationNode*> Measures; const TOutputColumnOrder& OutputColumnOrder; const TContainerCacheOnContext& Cache; @@ -219,14 +247,19 @@ private: void RegisterDependencies() const final { if (const auto flow = FlowDependsOn(InputFlow)) { Own(flow, InputRowArg); + Own(flow, MatchedVarsArg); + DependsOn(flow, PartitionKey); + for (auto& m: Measures) { + DependsOn(flow, m); + } } - DependsOn(PartitionKey, InputRowArg); } IComputationNode* const InputFlow; IComputationExternalNode* const InputRowArg; IComputationNode* const PartitionKey; TType* const PartitionKeyType; + IComputationExternalNode* const MatchedVarsArg; std::vector<IComputationNode*> Measures; TOutputColumnOrder OutputColumnOrder; const TContainerCacheOnContext Cache; @@ -279,7 +312,7 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation const auto& measureInputDataArg = callable.GetInput(inputIndex++); const auto& measureSpecialColumnIndexes = callable.GetInput(inputIndex++); const auto& inputRowColumnCount = callable.GetInput(inputIndex++); - const auto& matchedRangesArg = callable.GetInput(inputIndex++); + const auto& matchedVarsArg = callable.GetInput(inputIndex++); const auto& measureColumnIndexes = callable.GetInput(inputIndex++); std::vector<TRuntimeNode> measures; for (size_t i = 0; i != AS_VALUE(TListLiteral, measureColumnIndexes)->GetItemsCount(); ++i) { @@ -289,13 +322,13 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation Y_UNUSED(measureInputDataArg); Y_UNUSED(measureSpecialColumnIndexes); Y_UNUSED(inputRowColumnCount); - Y_UNUSED(matchedRangesArg); return new TMatchRecognizeWrapper(ctx.Mutables, GetValueRepresentation(inputFlow.GetStaticType()) , LocateNode(ctx.NodeLocator, *inputFlow.GetNode()) , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())) , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()) , partitionKeySelector.GetStaticType() + , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode())) , ConvertVectorOfCallables(measures, ctx) , GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes) ); diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_matched_vars.h b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_matched_vars.h new file mode 100644 index 00000000000..91ae930f6f8 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_matched_vars.h @@ -0,0 +1,95 @@ +#pragma once +#include "../computation/mkql_computation_node_impl.h" + +namespace NKikimr::NMiniKQL::NMatchRecognize { + +using TMatchedRange = std::pair<ui64, ui64>; + +using TMatchedVar = std::vector<TMatchedRange>; + +using TMatchedVars = std::vector<TMatchedVar>; + +class TMatchedVarsValue : public TComputationValue<TMatchedVarsValue> { + class TRangeValue: public TComputationValue<TRangeValue> { + public: + TRangeValue(TMemoryUsageInfo* memInfo, const TMatchedRange& r) + : TComputationValue<TRangeValue>(memInfo) + , Range(r) + { + } + + NUdf::TUnboxedValue* GetElements() const override { + return nullptr; + } + NUdf::TUnboxedValue GetElement(ui32 index) const override { + MKQL_ENSURE(index < 2, "Index out of range"); + switch(index) { + case 0: return NUdf::TUnboxedValuePod(Range.first); + case 1: return NUdf::TUnboxedValuePod(Range.second); + } + return NUdf::TUnboxedValuePod(); + } + private: + const TMatchedRange& Range; + }; + + class TListRangeValue: public TComputationValue<TListRangeValue> { + public: + TListRangeValue(TMemoryUsageInfo* memInfo, const TMatchedVar& v) + : TComputationValue<TListRangeValue>(memInfo) + , Var(v) + { + } + class TIterator : public TComputationValue<TIterator> { + public: + TIterator(TMemoryUsageInfo *memInfo, const std::vector<TMatchedRange>& ranges) + : TComputationValue<TIterator>(memInfo) + , Ranges(ranges) + , Index(0) + {} + + private: + bool Next(NUdf::TUnboxedValue& value) override { + if (Ranges.size() == Index){ + return false; + } + value = NUdf::TUnboxedValuePod(new TRangeValue(GetMemInfo(), Ranges[Index++])); + return true; + } + + const std::vector<TMatchedRange>& Ranges; + size_t Index; + }; + + bool HasFastListLength() const override { + return true; + } + + ui64 GetListLength() const override { + return Var.size(); + } + + bool HasListItems() const override { + return !Var.empty(); + } + + NUdf::TUnboxedValue GetListIterator() const override { + return NUdf::TUnboxedValuePod(new TIterator(GetMemInfo(), Var)); + } + private: + const TMatchedVar& Var; + }; +public: + TMatchedVarsValue(TMemoryUsageInfo* memInfo, const std::vector<TMatchedVar>& vars) + : TComputationValue<TMatchedVarsValue>(memInfo) + , Vars(vars) + { + } + + NUdf::TUnboxedValue GetElement(ui32 index) const override { + return NUdf::TUnboxedValuePod(new TListRangeValue(GetMemInfo(), Vars[index])); + } +private: + const std::vector<TMatchedVar>& Vars; +}; +}//namespace NKikimr::NMiniKQL::NMatchRecognize diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_matched_vars_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_matched_vars_ut.cpp new file mode 100644 index 00000000000..8973dedd053 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_matched_vars_ut.cpp @@ -0,0 +1,41 @@ +#include "../mkql_match_recognize_matched_vars.h" +#include <library/cpp/testing/unittest/registar.h> + +namespace NKikimr::NMiniKQL::NMatchRecognize { + +Y_UNIT_TEST_SUITE(MatchRecognizeMatchedVars) { + TMemoryUsageInfo memUsage("MatchedVars"); + Y_UNIT_TEST(MatchedVarsEmpty) { + TScopedAlloc alloc(__LOCATION__); + { + TMatchedVars vars{}; + NUdf::TUnboxedValue value(NUdf::TUnboxedValuePod(new TMatchedVarsValue(&memUsage, vars))); + UNIT_ASSERT(value.HasValue()); + } + } + Y_UNIT_TEST(MatchedVars) { + TScopedAlloc alloc(__LOCATION__); + { + TMatchedVar A{{1, 4}, {7, 9}, {100, 200}}; + TMatchedVar B{{1, 6}}; + TMatchedVars vars{A, B}; + NUdf::TUnboxedValue value(NUdf::TUnboxedValuePod(new TMatchedVarsValue(&memUsage, vars))); + UNIT_ASSERT(value.HasValue()); + auto a = value.GetElement(0); + UNIT_ASSERT(a.HasValue()); + UNIT_ASSERT_VALUES_EQUAL(3, a.GetListLength()); + auto iter = a.GetListIterator(); + UNIT_ASSERT(iter.HasValue()); + NUdf::TUnboxedValue last; + while (iter.Next(last)) + ; + UNIT_ASSERT(last.HasValue()); + UNIT_ASSERT_VALUES_EQUAL(100, last.GetElement(0).Get<ui64>()); + UNIT_ASSERT_VALUES_EQUAL(200, last.GetElement(1).Get<ui64>()); + auto b = value.GetElement(1); + UNIT_ASSERT(b.HasValue()); + UNIT_ASSERT_VALUES_EQUAL(1, b.GetListLength()); + } + } +} +}//namespace NKikimr::NMiniKQL::TMatchRecognize diff --git a/ydb/library/yql/minikql/comp_nodes/ut/ya.make b/ydb/library/yql/minikql/comp_nodes/ut/ya.make index 2303397956f..4cdc8bf9b5a 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/ya.make +++ b/ydb/library/yql/minikql/comp_nodes/ut/ya.make @@ -40,6 +40,7 @@ SRCS( mkql_join_dict_ut.cpp mkql_grace_join_ut.cpp mkql_map_join_ut.cpp + mkql_match_recognize_matched_vars_ut.cpp mkql_safe_circular_buffer_ut.cpp mkql_sort_ut.cpp mkql_switch_ut.cpp |