aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorzverevgeny <zverevgeny@ydb.tech>2023-08-30 12:54:06 +0300
committerzverevgeny <zverevgeny@ydb.tech>2023-08-30 13:58:23 +0300
commitd8e8dfdd57e6c0f1e39a8839c4f41d979dd9920c (patch)
treecf3929356268e7c24fc401d3b3bd2b3a9ccb57c4
parentd211395d364252c63eb2824a57d7966afd123ce6 (diff)
downloadydb-d8e8dfdd57e6c0f1e39a8839c4f41d979dd9920c.tar.gz
YQL-16325 matched vars for MATCH_RECOGNIZE
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp65
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_matched_vars.h95
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_matched_vars_ut.cpp41
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/ya.make1
4 files changed, 186 insertions, 16 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
index 99b5b932be4..dd42271bdd8 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
@@ -1,3 +1,4 @@
+#include "mkql_match_recognize_matched_vars.h"
#include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h>
#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
#include <ydb/library/yql/minikql/computation/mkql_computation_node_pack.h>
@@ -15,6 +16,8 @@ using TMeasureInputColumnOrder = std::vector<std::pair<EMeasureColumnSource, siz
enum class EOutputColumnSource {PartitionKey, Measure};
using TOutputColumnOrder = std::vector<std::pair<EOutputColumnSource, size_t>>;
+using namespace NMatchRecognize;
+
//Process one partition of input data
struct IProcessMatchRecognize {
///return true if it has output data ready
@@ -28,27 +31,43 @@ class TStreamMatchRecognize: public IProcessMatchRecognize {
public:
TStreamMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
+ IComputationExternalNode* matchedVarsArg,
std::vector<IComputationNode*>& measures,
const TOutputColumnOrder& outputColumnOrder,
const TContainerCacheOnContext& cache
)
: PartitionKey(std::move(partitionKey))
+ , MatchedVarsArg(matchedVarsArg)
, Measures(measures)
, OutputColumnOrder(outputColumnOrder)
, Cache(cache)
+ , MatchedVars(2) //Assume pattern (A B B)*, where A matches every 3rd row (i%3 == 0) and B matches the rest
, HasMatch(false)
+ , RowCount(0)
{
}
bool ProcessInputRow(NUdf::TUnboxedValue&& row) override{
- //Assume match on every row, TODO fixme
Y_UNUSED(row);
- HasMatch = true;
+ //Assume pattern (A B B)*, where A matches every 3rd row (i%3 == 0) and B matches the rest
+ switch (RowCount % 3) {
+ case 0:
+ MatchedVars[0].push_back({RowCount, RowCount});
+ break;
+ case 1:
+ MatchedVars[1].push_back({RowCount, RowCount});
+ break;
+ case 2:
+ MatchedVars[1].back().second++;
+ break;
+ }
+ ++RowCount;
return HasMatch;
}
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) override {
if (!HasMatch)
return NUdf::TUnboxedValue::Invalid();
+ MatchedVarsArg->SetValue(ctx, NUdf::TUnboxedValuePod(new TMatchedVarsValue(&ctx.HolderFactory.GetMemInfo(), MatchedVars)));
HasMatch = false;
NUdf::TUnboxedValue *itemsPtr = nullptr;
const auto result = Cache.NewArray(ctx, OutputColumnOrder.size(), itemsPtr);
@@ -67,16 +86,18 @@ public:
return result;
}
bool ProcessEndOfData() override {
- //TODO
- return false;
+ HasMatch = true;
+ return HasMatch;
}
private:
const NUdf::TUnboxedValue PartitionKey;
+ IComputationExternalNode* const MatchedVarsArg;
const std::vector<IComputationNode*>& Measures;
const TOutputColumnOrder& OutputColumnOrder;
const TContainerCacheOnContext& Cache;
+ TMatchedVars MatchedVars;
bool HasMatch;
-
+ size_t RowCount;
};
@@ -87,17 +108,19 @@ public:
IComputationExternalNode *inputRowArg,
IComputationNode *partitionKey,
TType* partitionKeyType,
+ IComputationExternalNode* matchedVarsArg,
std::vector<IComputationNode*>&& measures,
TOutputColumnOrder&& outputColumnOrder
)
- : TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded)
- , InputFlow(inputFlow)
- , InputRowArg(inputRowArg)
- , PartitionKey(partitionKey)
- , PartitionKeyType(partitionKeyType)
- , Measures(measures)
- , OutputColumnOrder(outputColumnOrder)
- , Cache(mutables)
+ :TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded)
+ , InputFlow(inputFlow)
+ , InputRowArg(inputRowArg)
+ , PartitionKey(partitionKey)
+ , PartitionKeyType(partitionKeyType)
+ , MatchedVarsArg(matchedVarsArg)
+ , Measures(measures)
+ , OutputColumnOrder(outputColumnOrder)
+ , Cache(mutables)
{}
NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue &stateValue, TComputationContext &ctx) const {
@@ -106,6 +129,7 @@ public:
InputRowArg,
PartitionKey,
PartitionKeyType,
+ MatchedVarsArg,
Measures,
OutputColumnOrder,
Cache
@@ -136,6 +160,7 @@ private:
IComputationExternalNode* inputRowArg,
IComputationNode* partitionKey,
TType* partitionKeyType,
+ IComputationExternalNode* matchedVarsArg,
const std::vector<IComputationNode*>& measures,
const TOutputColumnOrder& outputColumnOrder,
const TContainerCacheOnContext& cache
@@ -144,6 +169,7 @@ private:
, InputRowArg(inputRowArg)
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
+ , MatchedVarsArg(matchedVarsArg)
, Measures(measures)
, OutputColumnOrder(outputColumnOrder)
, Cache(cache)
@@ -191,6 +217,7 @@ private:
} else {
return Partitions.emplace_hint(it, TString(packedKey), std::make_unique<TStreamMatchRecognize>(
std::move(partitionKey),
+ MatchedVarsArg,
Measures,
OutputColumnOrder,
Cache
@@ -210,6 +237,7 @@ private:
TValuePackerGeneric<false> PartitionKeyPacker;
//to be passed to partitions
+ IComputationExternalNode* const MatchedVarsArg;
std::vector<IComputationNode*> Measures;
const TOutputColumnOrder& OutputColumnOrder;
const TContainerCacheOnContext& Cache;
@@ -219,14 +247,19 @@ private:
void RegisterDependencies() const final {
if (const auto flow = FlowDependsOn(InputFlow)) {
Own(flow, InputRowArg);
+ Own(flow, MatchedVarsArg);
+ DependsOn(flow, PartitionKey);
+ for (auto& m: Measures) {
+ DependsOn(flow, m);
+ }
}
- DependsOn(PartitionKey, InputRowArg);
}
IComputationNode* const InputFlow;
IComputationExternalNode* const InputRowArg;
IComputationNode* const PartitionKey;
TType* const PartitionKeyType;
+ IComputationExternalNode* const MatchedVarsArg;
std::vector<IComputationNode*> Measures;
TOutputColumnOrder OutputColumnOrder;
const TContainerCacheOnContext Cache;
@@ -279,7 +312,7 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
const auto& measureInputDataArg = callable.GetInput(inputIndex++);
const auto& measureSpecialColumnIndexes = callable.GetInput(inputIndex++);
const auto& inputRowColumnCount = callable.GetInput(inputIndex++);
- const auto& matchedRangesArg = callable.GetInput(inputIndex++);
+ const auto& matchedVarsArg = callable.GetInput(inputIndex++);
const auto& measureColumnIndexes = callable.GetInput(inputIndex++);
std::vector<TRuntimeNode> measures;
for (size_t i = 0; i != AS_VALUE(TListLiteral, measureColumnIndexes)->GetItemsCount(); ++i) {
@@ -289,13 +322,13 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
Y_UNUSED(measureInputDataArg);
Y_UNUSED(measureSpecialColumnIndexes);
Y_UNUSED(inputRowColumnCount);
- Y_UNUSED(matchedRangesArg);
return new TMatchRecognizeWrapper(ctx.Mutables, GetValueRepresentation(inputFlow.GetStaticType())
, LocateNode(ctx.NodeLocator, *inputFlow.GetNode())
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode()))
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
, partitionKeySelector.GetStaticType()
+ , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode()))
, ConvertVectorOfCallables(measures, ctx)
, GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes)
);
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_matched_vars.h b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_matched_vars.h
new file mode 100644
index 00000000000..91ae930f6f8
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_matched_vars.h
@@ -0,0 +1,95 @@
+#pragma once
+#include "../computation/mkql_computation_node_impl.h"
+
+namespace NKikimr::NMiniKQL::NMatchRecognize {
+
+using TMatchedRange = std::pair<ui64, ui64>;
+
+using TMatchedVar = std::vector<TMatchedRange>;
+
+using TMatchedVars = std::vector<TMatchedVar>;
+
+class TMatchedVarsValue : public TComputationValue<TMatchedVarsValue> {
+ class TRangeValue: public TComputationValue<TRangeValue> {
+ public:
+ TRangeValue(TMemoryUsageInfo* memInfo, const TMatchedRange& r)
+ : TComputationValue<TRangeValue>(memInfo)
+ , Range(r)
+ {
+ }
+
+ NUdf::TUnboxedValue* GetElements() const override {
+ return nullptr;
+ }
+ NUdf::TUnboxedValue GetElement(ui32 index) const override {
+ MKQL_ENSURE(index < 2, "Index out of range");
+ switch(index) {
+ case 0: return NUdf::TUnboxedValuePod(Range.first);
+ case 1: return NUdf::TUnboxedValuePod(Range.second);
+ }
+ return NUdf::TUnboxedValuePod();
+ }
+ private:
+ const TMatchedRange& Range;
+ };
+
+ class TListRangeValue: public TComputationValue<TListRangeValue> {
+ public:
+ TListRangeValue(TMemoryUsageInfo* memInfo, const TMatchedVar& v)
+ : TComputationValue<TListRangeValue>(memInfo)
+ , Var(v)
+ {
+ }
+ class TIterator : public TComputationValue<TIterator> {
+ public:
+ TIterator(TMemoryUsageInfo *memInfo, const std::vector<TMatchedRange>& ranges)
+ : TComputationValue<TIterator>(memInfo)
+ , Ranges(ranges)
+ , Index(0)
+ {}
+
+ private:
+ bool Next(NUdf::TUnboxedValue& value) override {
+ if (Ranges.size() == Index){
+ return false;
+ }
+ value = NUdf::TUnboxedValuePod(new TRangeValue(GetMemInfo(), Ranges[Index++]));
+ return true;
+ }
+
+ const std::vector<TMatchedRange>& Ranges;
+ size_t Index;
+ };
+
+ bool HasFastListLength() const override {
+ return true;
+ }
+
+ ui64 GetListLength() const override {
+ return Var.size();
+ }
+
+ bool HasListItems() const override {
+ return !Var.empty();
+ }
+
+ NUdf::TUnboxedValue GetListIterator() const override {
+ return NUdf::TUnboxedValuePod(new TIterator(GetMemInfo(), Var));
+ }
+ private:
+ const TMatchedVar& Var;
+ };
+public:
+ TMatchedVarsValue(TMemoryUsageInfo* memInfo, const std::vector<TMatchedVar>& vars)
+ : TComputationValue<TMatchedVarsValue>(memInfo)
+ , Vars(vars)
+ {
+ }
+
+ NUdf::TUnboxedValue GetElement(ui32 index) const override {
+ return NUdf::TUnboxedValuePod(new TListRangeValue(GetMemInfo(), Vars[index]));
+ }
+private:
+ const std::vector<TMatchedVar>& Vars;
+};
+}//namespace NKikimr::NMiniKQL::NMatchRecognize
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_matched_vars_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_matched_vars_ut.cpp
new file mode 100644
index 00000000000..8973dedd053
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_matched_vars_ut.cpp
@@ -0,0 +1,41 @@
+#include "../mkql_match_recognize_matched_vars.h"
+#include <library/cpp/testing/unittest/registar.h>
+
+namespace NKikimr::NMiniKQL::NMatchRecognize {
+
+Y_UNIT_TEST_SUITE(MatchRecognizeMatchedVars) {
+ TMemoryUsageInfo memUsage("MatchedVars");
+ Y_UNIT_TEST(MatchedVarsEmpty) {
+ TScopedAlloc alloc(__LOCATION__);
+ {
+ TMatchedVars vars{};
+ NUdf::TUnboxedValue value(NUdf::TUnboxedValuePod(new TMatchedVarsValue(&memUsage, vars)));
+ UNIT_ASSERT(value.HasValue());
+ }
+ }
+ Y_UNIT_TEST(MatchedVars) {
+ TScopedAlloc alloc(__LOCATION__);
+ {
+ TMatchedVar A{{1, 4}, {7, 9}, {100, 200}};
+ TMatchedVar B{{1, 6}};
+ TMatchedVars vars{A, B};
+ NUdf::TUnboxedValue value(NUdf::TUnboxedValuePod(new TMatchedVarsValue(&memUsage, vars)));
+ UNIT_ASSERT(value.HasValue());
+ auto a = value.GetElement(0);
+ UNIT_ASSERT(a.HasValue());
+ UNIT_ASSERT_VALUES_EQUAL(3, a.GetListLength());
+ auto iter = a.GetListIterator();
+ UNIT_ASSERT(iter.HasValue());
+ NUdf::TUnboxedValue last;
+ while (iter.Next(last))
+ ;
+ UNIT_ASSERT(last.HasValue());
+ UNIT_ASSERT_VALUES_EQUAL(100, last.GetElement(0).Get<ui64>());
+ UNIT_ASSERT_VALUES_EQUAL(200, last.GetElement(1).Get<ui64>());
+ auto b = value.GetElement(1);
+ UNIT_ASSERT(b.HasValue());
+ UNIT_ASSERT_VALUES_EQUAL(1, b.GetListLength());
+ }
+ }
+}
+}//namespace NKikimr::NMiniKQL::TMatchRecognize
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/ya.make b/ydb/library/yql/minikql/comp_nodes/ut/ya.make
index 2303397956f..4cdc8bf9b5a 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/ya.make
+++ b/ydb/library/yql/minikql/comp_nodes/ut/ya.make
@@ -40,6 +40,7 @@ SRCS(
mkql_join_dict_ut.cpp
mkql_grace_join_ut.cpp
mkql_map_join_ut.cpp
+ mkql_match_recognize_matched_vars_ut.cpp
mkql_safe_circular_buffer_ut.cpp
mkql_sort_ut.cpp
mkql_switch_ut.cpp