aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorzverevgeny <zverevgeny@ydb.tech>2023-09-08 07:08:49 +0300
committerzverevgeny <zverevgeny@ydb.tech>2023-09-08 07:30:58 +0300
commitb27ebe54a0677e957e81d60d8286b52ad72265af (patch)
treeb85df044d67c66ac077e32649899afb5a6c7e1d4
parent563e21fcd7280e84fb359d0a8de2403a5924a3d8 (diff)
downloadydb-b27ebe54a0677e957e81d60d8286b52ad72265af.tar.gz
YQL-16325 augmented input data for measures
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp125
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_measure_arg.h127
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.cpp6
3 files changed, 233 insertions, 25 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
index e7a59032f4..6b4764a8e3 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp
@@ -1,9 +1,12 @@
#include "mkql_match_recognize_matched_vars.h"
+#include "mkql_match_recognize_measure_arg.h"
+#include <ydb/library/yql/core/sql_types/match_recognize.h>
#include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h>
#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
#include <ydb/library/yql/minikql/computation/mkql_computation_node_pack.h>
#include <ydb/library/yql/minikql/mkql_node_cast.h>
#include <ydb/library/yql/minikql/mkql_runtime_version.h>
+#include <ydb/library/yql/minikql/mkql_string_util.h>
#include <ydb/library/yql/core/sql_types/match_recognize.h>
#include <deque>
@@ -11,9 +14,6 @@ namespace NKikimr::NMiniKQL {
namespace NMatchRecognize {
-enum class EMeasureColumnSource {Classifier = 0, MatchNumber = 1, Input};
-using TMeasureInputColumnOrder = TVector<std::pair<EMeasureColumnSource, size_t>>;
-
enum class EOutputColumnSource {PartitionKey, Measure};
using TOutputColumnOrder = TVector<std::pair<EOutputColumnSource, size_t>>;
@@ -33,22 +33,29 @@ public:
TBackTrackingMatchRecognize(
NUdf::TUnboxedValue&& partitionKey,
IComputationExternalNode* matchedVarsArg,
+ const TUnboxedValueVector& varNames,
+ IComputationExternalNode* measureInputDataArg,
+ const TMeasureInputColumnOrder& measureInputColumnOrder,
const TComputationNodePtrVector& measures,
const TOutputColumnOrder& outputColumnOrder,
IComputationExternalNode* currentRowIndexArg,
- IComputationExternalNode* defineInputDataArg,
+ IComputationExternalNode* inputDataArg,
const TComputationNodePtrVector& defines,
const TContainerCacheOnContext& cache
)
: PartitionKey(std::move(partitionKey))
, MatchedVarsArg(matchedVarsArg)
+ , VarNames(varNames)
+ , MeasureInputDataArg(measureInputDataArg)
+ , MeasureInputColumnOrder(measureInputColumnOrder)
, Measures(measures)
, OutputColumnOrder(outputColumnOrder)
, CurrentRowIndexArg(currentRowIndexArg)
- , DefineInputDataArg(defineInputDataArg)
+ , InputDataArg(inputDataArg)
, Defines(defines)
, Cache(cache)
, CurMatchedVars(Defines.size())
+ , MatchNumber(0)
{
}
@@ -60,8 +67,15 @@ public:
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) override {
if (Matches.empty())
return NUdf::TUnboxedValue{};
- MatchedVarsArg->SetValue(ctx, ToValue(ctx, Matches.front()));
+ MatchedVarsArg->SetValue(ctx, ToValue(ctx, std::move(Matches.front())));
Matches.pop_front();
+ MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
+ InputDataArg->GetValue(ctx),
+ MeasureInputColumnOrder,
+ MatchedVarsArg->GetValue(ctx),
+ VarNames,
+ ++MatchNumber
+ ));
NUdf::TUnboxedValue *itemsPtr = nullptr;
const auto result = Cache.NewArray(ctx, OutputColumnOrder.size(), itemsPtr);
for (auto const& c: OutputColumnOrder) {
@@ -80,7 +94,7 @@ public:
//Assume, that data moved to IComputationExternalNode node, will not be modified or released
//till the end of the current function
auto rowsSize = Rows.size();
- DefineInputDataArg->SetValue(ctx, ctx.HolderFactory.VectorAsVectorHolder(std::move(Rows)));
+ InputDataArg->SetValue(ctx, ctx.HolderFactory.VectorAsVectorHolder(std::move(Rows)));
for (size_t i = 0; i != rowsSize; ++i) {
CurrentRowIndexArg->SetValue(ctx, NUdf::TUnboxedValuePod(static_cast<ui64>(i)));
for (size_t v = 0; v != Defines.size(); ++v) {
@@ -109,15 +123,19 @@ public:
private:
const NUdf::TUnboxedValue PartitionKey;
IComputationExternalNode* const MatchedVarsArg;
+ const TUnboxedValueVector& VarNames;
+ IComputationExternalNode* const MeasureInputDataArg;
+ const TMeasureInputColumnOrder& MeasureInputColumnOrder;
const TComputationNodePtrVector& Measures;
const TOutputColumnOrder& OutputColumnOrder;
IComputationExternalNode* const CurrentRowIndexArg;
- IComputationExternalNode* const DefineInputDataArg;
+ IComputationExternalNode* const InputDataArg;
const TComputationNodePtrVector& Defines;
const TContainerCacheOnContext& Cache;
TUnboxedValueVector Rows;
TMatchedVars CurMatchedVars;
std::deque<TMatchedVars> Matches;
+ ui64 MatchNumber;
};
class TStreamingMatchRecognize: public IProcessMatchRecognize {
@@ -168,7 +186,7 @@ public:
NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) override {
if (!HasMatch)
return NUdf::TUnboxedValue{};
- MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue>(MatchedVars));
+ MatchedVarsArg->SetValue(ctx, ToValue(ctx, MatchedVars));
HasMatch = false;
NUdf::TUnboxedValue *itemsPtr = nullptr;
const auto result = Cache.NewArray(ctx, OutputColumnOrder.size(), itemsPtr);
@@ -211,10 +229,13 @@ public:
IComputationNode *partitionKey,
TType* partitionKeyType,
IComputationExternalNode* matchedVarsArg,
+ const TUnboxedValueVector& varNames,
+ IComputationExternalNode* measureInputDataArg,
+ const TMeasureInputColumnOrder& measureInputColumnOrder,
const TComputationNodePtrVector& measures,
TOutputColumnOrder&& outputColumnOrder,
IComputationExternalNode* currentRowIndexArg,
- IComputationExternalNode* defineInputDataArg,
+ IComputationExternalNode* inputDataArg,
const TComputationNodePtrVector& defines
)
:TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded)
@@ -223,10 +244,13 @@ public:
, PartitionKey(partitionKey)
, PartitionKeyType(partitionKeyType)
, MatchedVarsArg(matchedVarsArg)
+ , VarNames(varNames)
+ , MeasureInputDataArg(measureInputDataArg)
+ , MeasureInputColumnOrder(measureInputColumnOrder)
, Measures(measures)
, OutputColumnOrder(outputColumnOrder)
, CurrentRowIndexArg(currentRowIndexArg)
- , DefineInputDataArg(defineInputDataArg)
+ , inputDataArg(inputDataArg)
, Defines(defines)
, Cache(mutables)
{}
@@ -238,10 +262,13 @@ public:
PartitionKey,
PartitionKeyType,
MatchedVarsArg,
+ VarNames,
+ MeasureInputDataArg,
+ MeasureInputColumnOrder,
Measures,
OutputColumnOrder,
CurrentRowIndexArg,
- DefineInputDataArg,
+ inputDataArg,
Defines,
Cache
);
@@ -271,10 +298,13 @@ private:
IComputationNode* partitionKey,
TType* partitionKeyType,
IComputationExternalNode* matchedVarsArg,
+ const TUnboxedValueVector& varNames,
+ IComputationExternalNode* measureInputDataArg,
+ const TMeasureInputColumnOrder& measureInputColumnOrder,
const TComputationNodePtrVector& measures,
const TOutputColumnOrder& outputColumnOrder,
IComputationExternalNode* currentRowIndexArg,
- IComputationExternalNode* defineInputDataArg,
+ IComputationExternalNode* inputDataArg,
const TComputationNodePtrVector& defines,
const TContainerCacheOnContext& cache
)
@@ -283,10 +313,13 @@ private:
, PartitionKey(partitionKey)
, PartitionKeyPacker(true, partitionKeyType)
, MatchedVarsArg(matchedVarsArg)
+ , VarNames(varNames)
+ , MeasureInputDataArg(measureInputDataArg)
+ , MeasureInputColumnOrder(measureInputColumnOrder)
, Measures(measures)
, OutputColumnOrder(outputColumnOrder)
, CurrentRowIndexArg(currentRowIndexArg)
- , DefineInputDataArg(defineInputDataArg)
+ , InputDataArg(inputDataArg)
, Defines(defines)
, Cache(cache)
{
@@ -334,10 +367,13 @@ private:
return Partitions.emplace_hint(it, TString(packedKey), std::make_unique<TBackTrackingMatchRecognize>(
std::move(partitionKey),
MatchedVarsArg,
+ VarNames,
+ MeasureInputDataArg,
+ MeasureInputColumnOrder,
Measures,
OutputColumnOrder,
CurrentRowIndexArg,
- DefineInputDataArg,
+ InputDataArg,
Defines,
Cache
));
@@ -357,10 +393,13 @@ private:
//to be passed to partitions
IComputationExternalNode* const MatchedVarsArg;
+ const TUnboxedValueVector& VarNames;
+ IComputationExternalNode* MeasureInputDataArg;
+ const TMeasureInputColumnOrder& MeasureInputColumnOrder;
TComputationNodePtrVector Measures;
const TOutputColumnOrder& OutputColumnOrder;
IComputationExternalNode* const CurrentRowIndexArg;
- IComputationExternalNode* const DefineInputDataArg;
+ IComputationExternalNode* const InputDataArg;
const TComputationNodePtrVector& Defines;
const TContainerCacheOnContext& Cache;
};
@@ -370,8 +409,9 @@ private:
if (const auto flow = FlowDependsOn(InputFlow)) {
Own(flow, InputRowArg);
Own(flow, MatchedVarsArg);
+ Own(flow, MeasureInputDataArg);
Own(flow, CurrentRowIndexArg);
- Own(flow, DefineInputDataArg);
+ Own(flow, inputDataArg);
DependsOn(flow, PartitionKey);
for (auto& m: Measures) {
DependsOn(flow, m);
@@ -387,10 +427,13 @@ private:
IComputationNode* const PartitionKey;
TType* const PartitionKeyType;
IComputationExternalNode* const MatchedVarsArg;
+ const TUnboxedValueVector VarNames;
+ IComputationExternalNode* const MeasureInputDataArg;
+ const TMeasureInputColumnOrder MeasureInputColumnOrder;
const TComputationNodePtrVector Measures;
const TOutputColumnOrder OutputColumnOrder;
IComputationExternalNode* const CurrentRowIndexArg;
- IComputationExternalNode* const DefineInputDataArg;
+ IComputationExternalNode* const inputDataArg;
const TComputationNodePtrVector Defines;
const TContainerCacheOnContext Cache;
};
@@ -446,6 +489,28 @@ TRowPattern ConvertPattern(const TRuntimeNode& pattern) {
return result;
}
+//Builds the column layout of a row passed to MEASURES lambdas.
+//specialColumnIndexes: either empty, or one position per special column, listed in the order
+//  of EMeasureInputDataSpecialColumns values; each item is the index of that special column
+//  in the resulting row.
+//inputRowColumnCount: number of columns coming from the input table row.
+//Returns a vector with one entry per resulting column: (special column kind, 0) for special columns
+//and (Last, index in the input row) for input table columns, numbered in order of appearance.
+TMeasureInputColumnOrder GetMeasureColumnOrder(const TListLiteral& specialColumnIndexes, ui32 inputRowColumnCount) {
+    using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns;
+    const auto specialColumnCount = specialColumnIndexes.GetItemsCount();
+    //Use Last enum value to denote that a column comes from the input table
+    TMeasureInputColumnOrder result(inputRowColumnCount + specialColumnCount, std::make_pair(EMeasureInputDataSpecialColumns::Last, 0));
+    if (specialColumnCount != 0) {
+        MKQL_ENSURE(specialColumnCount == static_cast<size_t>(EMeasureInputDataSpecialColumns::Last),
+            "Internal logic error");
+        for (size_t i = 0; i != specialColumnCount; ++i) {
+            const auto ind = AS_VALUE(TDataLiteral, specialColumnIndexes.GetItems()[i])->AsValue().Get<ui32>();
+            //The index comes from the program literal: never write outside of the result
+            MKQL_ENSURE(ind < result.size(), "Internal logic error");
+            result[ind] = std::make_pair(static_cast<EMeasureInputDataSpecialColumns>(i), 0);
+        }
+    }
+    //update indexes for input table columns
+    ui32 inputIdx = 0;
+    for (auto& [t, i]: result) {
+        if (EMeasureInputDataSpecialColumns::Last == t) {
+            i = inputIdx++;
+        }
+    }
+    //Fails when several special columns claim the same position
+    MKQL_ENSURE(inputIdx == inputRowColumnCount, "Internal logic error");
+    return result;
+}
+
TComputationNodePtrVector ConvertVectorOfCallables(const TRuntimeNode::TList& v, const TComputationNodeFactoryContext& ctx) {
TComputationNodePtrVector result;
result.reserve(v.size());
@@ -455,6 +520,16 @@ TComputationNodePtrVector ConvertVectorOfCallables(const TRuntimeNode::TList& v,
return result;
}
+//Converts a list literal of data literals into a vector of string values,
+//preserving the order of the items.
+TUnboxedValueVector ConvertListOfStrings(const TRuntimeNode& l) {
+    const auto literal = AS_VALUE(TListLiteral, l);
+    const ui32 count = literal->GetItemsCount();
+    const auto items = literal->GetItems();
+
+    TUnboxedValueVector result;
+    result.reserve(count);
+    for (ui32 pos = 0; pos < count; ++pos) {
+        const auto data = AS_VALUE(TDataLiteral, items[pos]);
+        result.push_back(MakeString(data->AsValue().AsStringRef()));
+    }
+    return result;
+}
+
} //namespace NMatchRecognize
@@ -477,10 +552,10 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
}
const auto& pattern = callable.GetInput(inputIndex++);
const auto& currentRowIndexArg = callable.GetInput(inputIndex++);
- const auto& defineInputDataArg = callable.GetInput(inputIndex++);
- const auto& defineNames = callable.GetInput(inputIndex++);
+ const auto& inputDataArg = callable.GetInput(inputIndex++);
+ const auto& varNames = callable.GetInput(inputIndex++);
TRuntimeNode::TList defines;
- for (size_t i = 0; i != AS_VALUE(TListLiteral, defineNames)->GetItemsCount(); ++i) {
+ for (size_t i = 0; i != AS_VALUE(TListLiteral, varNames)->GetItemsCount(); ++i) {
defines.push_back(callable.GetInput(inputIndex++));
}
MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count");
@@ -497,10 +572,16 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
, LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
, partitionKeySelector.GetStaticType()
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode()))
+ , ConvertListOfStrings(varNames)
+ , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode()))
+ , GetMeasureColumnOrder(
+ *AS_VALUE(TListLiteral, measureSpecialColumnIndexes),
+ AS_VALUE(TDataLiteral, inputRowColumnCount)->AsValue().Get<ui32>()
+ )
, ConvertVectorOfCallables(measures, ctx)
, GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes)
, static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode()))
- , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *defineInputDataArg.GetNode()))
+ , static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode()))
, ConvertVectorOfCallables(defines, ctx)
);
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_measure_arg.h b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_measure_arg.h
new file mode 100644
index 0000000000..54d328075d
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_measure_arg.h
@@ -0,0 +1,127 @@
+#pragma once
+
+#include "mkql_match_recognize_matched_vars.h"
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h>
+#include <ydb/library/yql/core/sql_types/match_recognize.h>
+#include <ydb/library/yql/minikql/mkql_string_util.h>
+
+namespace NKikimr::NMiniKQL::NMatchRecognize {
+
+using NYql::NMatchRecognize::EMeasureInputDataSpecialColumns;
+
+using TMeasureInputColumnOrder = TVector<std::pair<EMeasureInputDataSpecialColumns, size_t>>;
+
+//Input row augmented with lightweight special columns for calculating MEASURE lambdas.
+//Columns are resolved on access via columnOrder: an entry is either a special column
+//(Classifier, MatchNumber) computed here, or (Last, N) meaning column N of the input row.
+//columnOrder and varNames are kept by reference and must outlive this value.
+class TRowForMeasureValue: public TComputationValue<TRowForMeasureValue>
+{
+public:
+    TRowForMeasureValue(
+        TMemoryUsageInfo* memInfo,
+        NUdf::TUnboxedValue inputRow,
+        ui64 rowIndex,
+        const TMeasureInputColumnOrder& columnOrder,
+        const NUdf::TUnboxedValue& matchedVars,
+        const TUnboxedValueVector& varNames,
+        ui64 matchNumber
+    )
+        : TComputationValue<TRowForMeasureValue>(memInfo)
+        , InputRow(inputRow)
+        , RowIndex(rowIndex)
+        , ColumnOrder(columnOrder)
+        , MatchedVars(matchedVars)
+        , VarNames(varNames)
+        , MatchNumber(matchNumber)
+    {}
+
+    //Returns the value of the column at the given position of the augmented row
+    NUdf::TUnboxedValue GetElement(ui32 index) const override {
+        switch(ColumnOrder[index].first) {
+            case EMeasureInputDataSpecialColumns::Classifier: {
+                //MatchedVars is iterated as a list of vars; each var is a list of ranges,
+                //a range holds inclusive row indexes (from, to) in elements 0 and 1.
+                //The name of the first var having a range that contains RowIndex is returned.
+                auto varIterator = MatchedVars.GetListIterator();
+                MKQL_ENSURE(varIterator, "Internal logic error");
+                NUdf::TUnboxedValue var;
+                size_t varIndex = 0;
+                while(varIterator.Next(var)) {
+                    auto rangeIterator = var.GetListIterator();
+                    //Check the iterator just obtained, not the outer one
+                    MKQL_ENSURE(rangeIterator, "Internal logic error");
+                    NUdf::TUnboxedValue range;
+                    while(rangeIterator.Next(range)) {
+                        const auto from = range.GetElement(0).Get<ui64>();
+                        const auto to = range.GetElement(1).Get<ui64>();
+                        if (RowIndex >= from and RowIndex <= to) {
+                            return VarNames[varIndex];
+                        }
+                    }
+                    ++varIndex;
+                }
+                MKQL_ENSURE(MatchedVars.GetListLength() == varIndex, "Internal logic error");
+                //The row is not covered by any var
+                return MakeString("");
+            }
+            case EMeasureInputDataSpecialColumns::MatchNumber:
+                return NUdf::TUnboxedValuePod(MatchNumber);
+            case EMeasureInputDataSpecialColumns::Last: //Last corresponds to columns from the input table row
+                return InputRow.GetElement(ColumnOrder[index].second);
+        }
+        //Only an invalid enum value gets here; fail explicitly instead of
+        //flowing off the end of a non-void function
+        MKQL_ENSURE(false, "Internal logic error");
+        return NUdf::TUnboxedValue{};
+    }
+private:
+    const NUdf::TUnboxedValue InputRow;
+    const ui64 RowIndex;
+    const TMeasureInputColumnOrder& ColumnOrder;
+    //Held by value: the row must not depend on the lifetime of the value it was looked up from
+    const NUdf::TUnboxedValue MatchedVars;
+    const TUnboxedValueVector& VarNames;
+    const ui64 MatchNumber;
+};
+
+//View over the input rows used as the argument of MEASURES lambdas.
+//List length queries are forwarded to the wrapped inputData; rows are accessed by index
+//through Lookup, which wraps the input row into TRowForMeasureValue.
+//columnOrder and varNames are kept by reference and must outlive this value.
+class TMeasureInputDataValue: public TComputationValue<TMeasureInputDataValue> {
+    using Base = TComputationValue<TMeasureInputDataValue>;
+public:
+    TMeasureInputDataValue(TMemoryUsageInfo* memInfo,
+        const NUdf::TUnboxedValue& inputData,
+        const TMeasureInputColumnOrder& columnOrder,
+        const NUdf::TUnboxedValue& matchedVars,
+        const TUnboxedValueVector& varNames,
+        ui64 matchNumber)
+        : Base(memInfo)
+        , InputData(inputData)
+        , ColumnOrder(columnOrder)
+        , MatchedVars(matchedVars)
+        , VarNames(varNames)
+        , MatchNumber(matchNumber)
+    {}
+
+    bool HasFastListLength() const override {
+        return InputData.HasFastListLength();
+    }
+
+    ui64 GetListLength() const override {
+        return InputData.GetListLength();
+    }
+
+    //TODO https://st.yandex-team.ru/YQL-16508
+    //NUdf::TUnboxedValue GetListIterator() const override;
+
+    //This value serves as its own index dictionary, see Lookup
+    NUdf::IBoxedValuePtr ToIndexDictImpl(const NUdf::IValueBuilder& builder) const override {
+        Y_UNUSED(builder);
+        return const_cast<TMeasureInputDataValue*>(this);
+    }
+
+    //key is a ui64 row index. Returns the augmented row, or an empty value
+    //when the wrapped input data has no row for the key.
+    NUdf::TUnboxedValue Lookup(const NUdf::TUnboxedValuePod& key) const override {
+        auto inputRow = InputData.Lookup(key);
+        if (!inputRow) {
+            //Propagate the miss instead of wrapping an empty row,
+            //which would fail later on access to an input column
+            return NUdf::TUnboxedValue{};
+        }
+        return NUdf::TUnboxedValuePod{new TRowForMeasureValue(
+            GetMemInfo(),
+            inputRow,
+            key.Get<ui64>(),
+            ColumnOrder,
+            MatchedVars,
+            VarNames,
+            MatchNumber
+        )};
+    }
+private:
+    const NUdf::TUnboxedValue InputData;
+    const TMeasureInputColumnOrder& ColumnOrder;
+    const NUdf::TUnboxedValue MatchedVars;
+    const TUnboxedValueVector& VarNames;
+    const ui64 MatchNumber;
+};
+
+}//namespace NKikimr::NMiniKQL::NMatchRecognize
+
diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp
index 07a2841f3a..3b2eceee7a 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_program_builder.cpp
@@ -5930,12 +5930,12 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
std::vector<TRuntimeNode> defineNames(patternVarLookup.size());
std::vector<TRuntimeNode> defineNodes(patternVarLookup.size());
- const auto& defineInputDataArg = Arg(TListType::Create(inputRowType, Env));
+ const auto& inputDataArg = Arg(TListType::Create(inputRowType, Env));
const auto& currentRowIndexArg = Arg(TDataType::Create(NUdf::TDataType<ui64>::Id, Env));
for (const auto& [v, i]: patternVarLookup) {
defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v);
if (const auto it = defineLookup.find(v); it != defineLookup.end()) {
- defineNodes[i] = getDefines[it->second].second(defineInputDataArg, matchedVarsArg, currentRowIndexArg);
+ defineNodes[i] = getDefines[it->second].second(inputDataArg, matchedVarsArg, currentRowIndexArg);
}
else { //no predicate for var
if ("$" == v || "^" == v) {
@@ -5970,7 +5970,7 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore(
callableBuilder.Add(PatternToRuntimeNode(pattern, *this));
callableBuilder.Add(currentRowIndexArg);
- callableBuilder.Add(defineInputDataArg);
+ callableBuilder.Add(inputDataArg);
const auto stringType = NewDataType(NUdf::EDataSlot::String);
callableBuilder.Add(TRuntimeNode(TListLiteral::Create(defineNames.begin(), defineNames.size(), TListType::Create(stringType, Env), Env), true));
for (const auto& d: defineNodes) {