aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@yandex-team.ru>2022-03-17 16:43:06 +0300
committervvvv <vvvv@yandex-team.ru>2022-03-17 16:43:06 +0300
commit65cc3636e4a6df0b569be1fde6c5aad0d8fa4d55 (patch)
treeeb31d9b392f1d8771590e93397c460b7bc6cd700
parentd62a2feb1b4868615e72ece28cde5d7b70fa4243 (diff)
downloadydb-65cc3636e4a6df0b569be1fde6c5aad0d8fa4d55.tar.gz
YQL-13710 support of aggregations in pg_catalog
ref:6a4187e92350b626ce44922f35920e48b58ecea0
-rw-r--r--ydb/library/yql/parser/pg_catalog/CMakeLists.txt8
-rw-r--r--ydb/library/yql/parser/pg_catalog/catalog.cpp257
-rw-r--r--ydb/library/yql/parser/pg_catalog/catalog.h23
-rw-r--r--ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp60
4 files changed, 328 insertions, 20 deletions
diff --git a/ydb/library/yql/parser/pg_catalog/CMakeLists.txt b/ydb/library/yql/parser/pg_catalog/CMakeLists.txt
index c4eed486c1..768259591b 100644
--- a/ydb/library/yql/parser/pg_catalog/CMakeLists.txt
+++ b/ydb/library/yql/parser/pg_catalog/CMakeLists.txt
@@ -28,6 +28,7 @@ target_sources(yql-parser-pg_catalog.global PRIVATE
${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/f1a4bc0ed0162412cffb7ad89af053b9.cpp
${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/1ecd65c896f36a3990d870644a6da9c8.cpp
${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/46b25697572e7e60703079b71cd18295.cpp
+ ${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/15e039bc35a86f84476091e328dd74ea.cpp
)
resources(yql-parser-pg_catalog.global
${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/99bdd9e510834dd355c2457fcccc53d5.cpp
@@ -57,3 +58,10 @@ resources(yql-parser-pg_catalog.global
KEYS
pg_cast.dat
)
+resources(yql-parser-pg_catalog.global
+ ${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/15e039bc35a86f84476091e328dd74ea.cpp
+ INPUTS
+ ${CMAKE_SOURCE_DIR}/contrib/libs/postgresql/src/include/catalog/pg_aggregate.dat
+ KEYS
+ pg_aggregate.dat
+)
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.cpp b/ydb/library/yql/parser/pg_catalog/catalog.cpp
index ec856c801f..e9f049dcd4 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.cpp
+++ b/ydb/library/yql/parser/pg_catalog/catalog.cpp
@@ -16,6 +16,8 @@ using TTypes = THashMap<ui32, TTypeDesc>;
using TCasts = THashMap<ui32, TCastDesc>;
+using TAggregations = THashMap<ui32, TAggregateDesc>;
+
class TParser {
public:
void Do(const TString& dat) {
@@ -383,8 +385,8 @@ public:
void OnFinish() override {
if (IsSupported) {
- LastCast.CastId = 1 + Casts.size();
- Casts[LastCast.CastId] = LastCast;
+ auto id = 1 + Casts.size();
+ Casts[id] = LastCast;
}
LastCast = TCastDesc();
@@ -400,6 +402,191 @@ private:
bool IsSupported = true;
};
+// Parser callbacks that turn key/value records of the aggregate catalog into TAggregateDesc
+// entries stored in the supplied TAggregations map.
+// Type names are resolved through TypeByName; function names through ProcByName (name -> proc ids)
+// and Procs (proc id -> description).
+// NOTE(review): OnKey/OnFinish are presumably invoked by TParser::Do per field and at the end of
+// each record - TParser's body is not visible here, confirm against it.
+class TAggregationsParser : public TParser {
+public:
+ // All four arguments are kept by reference; they must outlive the parser.
+ TAggregationsParser(TAggregations& aggregations, const THashMap<TString, ui32>& typeByName,
+ const THashMap<TString, TVector<ui32>>& procByName, const TProcs& procs)
+ : Aggregations(aggregations)
+ , TypeByName(typeByName)
+ , ProcByName(procByName)
+ , Procs(procs)
+ {}
+
+ // Records one field of the current record. Only aggtranstype and aggkind are decoded right
+ // away; function names and the oid are stashed as strings and resolved in FillSupported().
+ // Keys other than the ones listed below are ignored.
+ void OnKey(const TString& key, const TString& value) override {
+ // NOTE(review): ProcByName is read by FillSupported()/ResolveFunc(), so this Y_UNUSED looks stale.
+ Y_UNUSED(ProcByName);
+ if (key == "aggtranstype") {
+ // The transition state type must be a known type name.
+ auto typeId = TypeByName.FindPtr(value);
+ Y_ENSURE(typeId);
+ LastAggregation.TransTypeId = *typeId;
+ } else if (key == "aggfnoid") {
+ LastOid = value;
+ } else if (key == "aggtransfn") {
+ LastTransFunc = value;
+ } else if (key == "aggfinalfn") {
+ LastFinalFunc = value;
+ } else if (key == "aggcombinefn") {
+ LastCombineFunc = value;
+ } else if (key == "aggserialfn") {
+ LastSerializeFunc = value;
+ } else if (key == "aggdeserialfn") {
+ LastDeserializeFunc = value;
+ } else if (key == "aggkind") {
+ // Single-letter kind code; anything but n/o/h is rejected.
+ if (value == "n") {
+ LastAggregation.Kind = EAggKind::Normal;
+ } else if (value == "o") {
+ LastAggregation.Kind = EAggKind::OrderedSet;
+ } else if (value == "h") {
+ LastAggregation.Kind = EAggKind::Hypothetical;
+ } else {
+ ythrow yexception() << "Unknown aggkind value: " << value;
+ }
+ }
+ }
+
+ // Finishes the current record: if FillSupported() succeeds the description is stored under the
+ // next sequential id (current map size + 1); in every case the per-record state is reset.
+ void OnFinish() override {
+ if (IsSupported) {
+ if (FillSupported()) {
+ auto id = Aggregations.size() + 1;
+ Aggregations[id] = LastAggregation;
+ }
+ }
+
+ LastAggregation = TAggregateDesc();
+ IsSupported = true;
+ LastOid = "";
+ LastTransFunc = "";
+ LastFinalFunc = "";
+ LastCombineFunc = "";
+ LastSerializeFunc = "";
+ LastDeserializeFunc = "";
+ }
+
+ // Resolves the stashed strings of the current record into LastAggregation (transition function
+ // id, name, argument types, optional function ids).
+ // Returns false when the record refers to a function name that is missing from ProcByName,
+ // i.e. the aggregate is skipped; inconsistencies in known data fail a Y_ENSURE check instead.
+ bool FillSupported() {
+ // These three fields are mandatory for every record.
+ Y_ENSURE(LastAggregation.TransTypeId);
+ Y_ENSURE(LastOid);
+ Y_ENSURE(LastTransFunc);
+ auto transFuncIdsPtr = ProcByName.FindPtr(LastTransFunc);
+ if (!transFuncIdsPtr) {
+ // e.g. variadic ordered_set_transition_multi
+ return false;
+ }
+
+ // Among the procs sharing the transition function name pick the one whose first argument
+ // has the transition state type; a second match or no match at all fails a Y_ENSURE.
+ for (const auto id : *transFuncIdsPtr) {
+ auto procPtr = Procs.FindPtr(id);
+ Y_ENSURE(procPtr);
+ if (procPtr->ArgTypes.size() >= 1 && procPtr->ArgTypes[0] == LastAggregation.TransTypeId) {
+ Y_ENSURE(!LastAggregation.TransFuncId);
+ LastAggregation.TransFuncId = id;
+ }
+ }
+
+ Y_ENSURE(LastAggregation.TransFuncId);
+
+ // oid format: name(arg1,arg2...)
+ auto pos1 = LastOid.find('(');
+ if (pos1 != TString::npos) {
+ LastAggregation.Name = LastOid.substr(0, pos1);
+ auto pos = pos1 + 1;
+ for (;;) {
+ // Next delimiter is whichever of ',' / ')' comes first (npos compares as the largest value).
+ auto nextPos = Min(LastOid.find(',', pos), LastOid.find(')', pos));
+ Y_ENSURE(nextPos != TString::npos);
+ if (pos == nextPos) {
+ // Empty token, e.g. "name()": no (more) arguments.
+ break;
+ }
+
+ auto arg = LastOid.substr(pos, nextPos - pos);
+ auto argTypeId = TypeByName.FindPtr(arg);
+ Y_ENSURE(argTypeId);
+ LastAggregation.ArgTypes.push_back(*argTypeId);
+ pos = nextPos;
+ if (LastOid[pos] == ')') {
+ break;
+ } else {
+ // Step over the ',' to the start of the next argument.
+ ++pos;
+ }
+ }
+ } else {
+ // no signature in oid, use transfunc
+ // The aggregate's arguments are the transition function's arguments without the leading state argument.
+ LastAggregation.Name = LastOid;
+ auto procPtr = Procs.FindPtr(LastAggregation.TransFuncId);
+ Y_ENSURE(procPtr);
+ LastAggregation.ArgTypes = procPtr->ArgTypes;
+ Y_ENSURE(LastAggregation.ArgTypes.size() >= 1);
+ Y_ENSURE(LastAggregation.ArgTypes[0] == LastAggregation.TransTypeId);
+ LastAggregation.ArgTypes.erase(LastAggregation.ArgTypes.begin());
+ }
+
+ Y_ENSURE(!LastAggregation.Name.empty());
+ // Optional functions: final and serialize are matched on one state argument, combine on
+ // two, deserialize by name only (stateArgsCount == 0).
+ if (!ResolveFunc(LastFinalFunc, LastAggregation.FinalFuncId, 1)) {
+ return false;
+ }
+
+ if (!ResolveFunc(LastCombineFunc, LastAggregation.CombineFuncId, 2)) {
+ return false;
+ }
+
+ if (!ResolveFunc(LastSerializeFunc, LastAggregation.SerializeFuncId, 1)) {
+ return false;
+ }
+
+ if (!ResolveFunc(LastDeserializeFunc, LastAggregation.DeserializeFuncId, 0)) {
+ return false;
+ }
+
+ return true;
+ }
+
+ // Resolves an optional function name into a proc id written to funcId.
+ // name            - function name; when empty nothing is resolved and true is returned.
+ // funcId          - output; expected to be 0 on entry (a second assignment fails a Y_ENSURE).
+ // stateArgsCount  - number of arguments that must all have the transition state type; 0 skips
+ //                   the type check and requires the name to map to exactly one proc.
+ // Returns false only when the name is not present in ProcByName.
+ bool ResolveFunc(const TString& name, ui32& funcId, ui32 stateArgsCount) {
+ if (name) {
+ auto funcIdsPtr = ProcByName.FindPtr(name);
+ if (!funcIdsPtr) {
+ return false;
+ }
+
+ if (!stateArgsCount) {
+ Y_ENSURE(funcIdsPtr->size() == 1);
+ }
+
+ for (const auto id : *funcIdsPtr) {
+ auto procPtr = Procs.FindPtr(id);
+ Y_ENSURE(procPtr);
+ bool found = true;
+ // NOTE(review): the type comparison runs only when the candidate has exactly
+ // stateArgsCount arguments; a candidate with a different arity keeps found == true
+ // and is accepted. Confirm this leniency is intended (e.g. functions taking extra
+ // arguments) before tightening it.
+ if (stateArgsCount > 0 && procPtr->ArgTypes.size() == stateArgsCount) {
+ for (ui32 i = 0; i < stateArgsCount; ++i) {
+ if (procPtr->ArgTypes[i] != LastAggregation.TransTypeId) {
+ found = false;
+ break;
+ }
+ }
+ }
+
+ if (found) {
+ // At most one candidate may match.
+ Y_ENSURE(!funcId);
+ funcId = id;
+ }
+ }
+
+ // A known name must yield a match.
+ Y_ENSURE(funcId);
+ }
+
+ return true;
+ }
+
+private:
+ TAggregations& Aggregations; // output map, keyed by sequential id
+ const THashMap<TString, ui32>& TypeByName;
+ const THashMap<TString, TVector<ui32>>& ProcByName;
+ const TProcs& Procs;
+ TAggregateDesc LastAggregation; // description being built for the current record
+ bool IsSupported = true; // NOTE(review): never set to false in this class as shown
+ TString LastOid; // raw aggfnoid value
+ TString LastTransFunc; // raw aggtransfn value
+ TString LastFinalFunc; // raw aggfinalfn value
+ TString LastCombineFunc; // raw aggcombinefn value
+ TString LastSerializeFunc; // raw aggserialfn value
+ TString LastDeserializeFunc; // raw aggdeserialfn value
+};
+
TOperators ParseOperators(const TString& dat, const THashMap<TString, ui32>& typeByName,
const THashMap<TString, TVector<ui32>>& procByName) {
TOperators ret;
@@ -408,6 +595,14 @@ TOperators ParseOperators(const TString& dat, const THashMap<TString, ui32>& typ
return ret;
}
+// Builds the aggregate map from the text of a .dat catalog file.
+// typeByName resolves type names to type ids; procByName and procs resolve the
+// function names referenced by each aggregate record.
+TAggregations ParseAggregations(const TString& dat, const THashMap<TString, ui32>& typeByName,
+    const THashMap<TString, TVector<ui32>>& procByName, const TProcs& procs) {
+    TAggregations aggregations;
+    TAggregationsParser aggParser{aggregations, typeByName, procByName, procs};
+    aggParser.Do(dat);
+    return aggregations;
+}
+
TProcs ParseProcs(const TString& dat, const THashMap<TString, ui32>& typeByName) {
TProcs ret;
TProcsParser parser(ret, typeByName);
@@ -440,6 +635,8 @@ struct TCatalog {
Y_ENSURE(NResource::FindExact("pg_proc.dat", &procData));
TString castData;
Y_ENSURE(NResource::FindExact("pg_cast.dat", &castData));
+ TString aggData;
+ Y_ENSURE(NResource::FindExact("pg_aggregate.dat", &aggData));
THashMap<ui32, TLazyTypeInfo> lazyTypeInfos;
Types = ParseTypes(typeData, lazyTypeInfos);
for (const auto& [k, v] : Types) {
@@ -527,6 +724,11 @@ struct TCatalog {
for (const auto&[k, v] : Operators) {
OperatorsByName[v.Name].push_back(k);
}
+
+ Aggregations = ParseAggregations(aggData, TypeByName, ProcByName, Procs);
+ for (const auto&[k, v] : Aggregations) {
+ AggregationsByName[v.Name].push_back(k);
+ }
}
static const TCatalog& Instance() {
@@ -537,14 +739,16 @@ struct TCatalog {
TProcs Procs;
TTypes Types;
TCasts Casts;
+ TAggregations Aggregations;
THashMap<TString, TVector<ui32>> ProcByName;
THashMap<TString, ui32> TypeByName;
THashMap<std::pair<ui32, ui32>, ui32> CastsByDir;
THashMap<TString, TVector<ui32>> OperatorsByName;
+ THashMap<TString, TVector<ui32>> AggregationsByName;
};
-bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
- if (argTypeIds.size() != d.ArgTypes.size()) {
+bool ValidateArgs(const TVector<ui32>& descArgTypeIds, const TVector<ui32>& argTypeIds) {
+ if (argTypeIds.size() != descArgTypeIds.size()) {
return false;
}
@@ -554,7 +758,7 @@ bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
continue; // NULL
}
- if (argTypeIds[i] != d.ArgTypes[i]) {
+ if (argTypeIds[i] != descArgTypeIds[i]) {
found = false;
break;
}
@@ -563,6 +767,10 @@ bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
return found;
}
+// Checks caller-supplied argument type ids against the procedure's declared argument types.
+bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
+    const auto& declaredArgTypes = d.ArgTypes;
+    return ValidateArgs(declaredArgTypes, argTypeIds);
+}
+
const TProcDesc& LookupProc(ui32 procId, const TVector<ui32>& argTypeIds) {
const auto& catalog = TCatalog::Instance();
auto procPtr = catalog.Procs.FindPtr(procId);
@@ -651,16 +859,6 @@ const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId) {
return *castPtr;
}
-const TCastDesc& LookupCast(ui32 castId) {
- const auto& catalog = TCatalog::Instance();
- auto castPtr = catalog.Casts.FindPtr(castId);
- if (!castPtr) {
- throw yexception() << "No such cast: " << castId;
- }
-
- return *castPtr;
-}
-
bool ValidateOperArgs(const TOperDesc& d, const TVector<ui32>& argTypeIds) {
ui32 size = d.Kind == EOperKind::Binary ? 2 : 1;
if (argTypeIds.size() != size) {
@@ -733,4 +931,33 @@ const TOperDesc& LookupOper(ui32 operId) {
return *operPtr;
}
+// Tells whether the catalog knows an aggregate with the given name, regardless of signature.
+bool HasAggregation(const TStringBuf& name) {
+    return TCatalog::Instance().AggregationsByName.contains(name);
+}
+
+// Checks caller-supplied argument type ids against the aggregate's declared argument types.
+bool ValidateAggregateArgs(const TAggregateDesc& d, const TVector<ui32>& argTypeIds) {
+    const auto& declaredArgTypes = d.ArgTypes;
+    return ValidateArgs(declaredArgTypes, argTypeIds);
+}
+
+// Finds the aggregate overload registered under `name` whose argument types are accepted by
+// ValidateAggregateArgs for `argTypeIds`. Overloads are tried in the order they are listed in
+// AggregationsByName and the first acceptable one is returned.
+// Throws yexception when the name is unknown or when no overload fits.
+const TAggregateDesc& LookupAggregation(const TStringBuf& name, const TVector<ui32>& argTypeIds) {
+    const auto& catalog = TCatalog::Instance();
+    const auto* overloadIds = catalog.AggregationsByName.FindPtr(name);
+    if (overloadIds == nullptr) {
+        throw yexception() << "No such aggregate: " << name;
+    }
+
+    for (const auto& aggId : *overloadIds) {
+        const auto* desc = catalog.Aggregations.FindPtr(aggId);
+        // Every id listed by name must be present in the id -> description map.
+        Y_ENSURE(desc);
+        if (ValidateAggregateArgs(*desc, argTypeIds)) {
+            return *desc;
+        }
+    }
+
+    throw yexception() << "Unable to find an overload for aggregate " << name << " with given argument types";
+}
+
}
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.h b/ydb/library/yql/parser/pg_catalog/catalog.h
index a2c86a439e..538db1375f 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.h
+++ b/ydb/library/yql/parser/pg_catalog/catalog.h
@@ -51,13 +51,30 @@ enum class ECastMethod {
};
struct TCastDesc {
- ui32 CastId = 0;
ui32 SourceId = 0;
ui32 TargetId = 0;
ECastMethod Method = ECastMethod::Function;
ui32 FunctionId = 0;
};
+// Kind of an aggregate, decoded by the catalog parser from the single-letter `aggkind` field.
+enum class EAggKind {
+ Normal, // aggkind "n"
+ OrderedSet, // aggkind "o"
+ Hypothetical // aggkind "h"
+};
+
+// Description of one aggregate overload loaded from pg_aggregate.dat.
+// The *FuncId fields hold proc ids (usable with LookupProc); an optional function that the
+// catalog record does not name keeps the default id 0.
+struct TAggregateDesc {
+ TString Name; // aggregate name: `aggfnoid` up to the '(' (or the whole value if there is none)
+ TVector<ui32> ArgTypes; // type ids of the aggregate's arguments
+ EAggKind Kind = EAggKind::Normal; // decoded from `aggkind`
+ ui32 TransTypeId = 0; // type id of the transition state (`aggtranstype`)
+ ui32 TransFuncId = 0; // proc id resolved from `aggtransfn`
+ ui32 FinalFuncId = 0; // proc id resolved from `aggfinalfn`, 0 if not set
+ ui32 CombineFuncId = 0; // proc id resolved from `aggcombinefn`, 0 if not set
+ ui32 SerializeFuncId = 0; // proc id resolved from `aggserialfn`, 0 if not set
+ ui32 DeserializeFuncId = 0; // proc id resolved from `aggdeserialfn`, 0 if not set
+};
+
const TProcDesc& LookupProc(const TString& name, const TVector<ui32>& argTypeIds);
const TProcDesc& LookupProc(ui32 procId, const TVector<ui32>& argTypeIds);
const TProcDesc& LookupProc(ui32 procId);
@@ -68,10 +85,12 @@ const TTypeDesc& LookupType(ui32 typeId);
bool HasCast(ui32 sourceId, ui32 targetId);
const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId);
-const TCastDesc& LookupCast(ui32 castId);
const TOperDesc& LookupOper(const TString& name, const TVector<ui32>& argTypeIds);
const TOperDesc& LookupOper(ui32 operId, const TVector<ui32>& argTypeIds);
const TOperDesc& LookupOper(ui32 operId);
+// Returns true if the catalog contains an aggregate with this name (any signature).
+bool HasAggregation(const TStringBuf& name);
+// Returns the aggregate overload registered under `name` that accepts `argTypeIds`;
+// throws yexception if the name is unknown or no overload fits the argument types.
+const TAggregateDesc& LookupAggregation(const TStringBuf& name, const TVector<ui32>& argTypeIds);
+
}
diff --git a/ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp b/ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp
index 17d50ed0ae..28539412a5 100644
--- a/ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp
+++ b/ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp
@@ -34,10 +34,10 @@ Y_UNIT_TEST_SUITE(TTypesTests) {
UNIT_ASSERT_VALUES_EQUAL(ret.ElementTypeId, LookupType("float8").TypeId);
ret = LookupType(1009);
- UNIT_ASSERT_VALUES_EQUAL(ret.TypeId, 25);
+ UNIT_ASSERT_VALUES_EQUAL(ret.TypeId, 1009);
UNIT_ASSERT_VALUES_EQUAL(ret.ArrayTypeId, 1009);
- UNIT_ASSERT_VALUES_EQUAL(ret.Name, "text");
- UNIT_ASSERT_VALUES_EQUAL(ret.ElementTypeId, 0);
+ UNIT_ASSERT_VALUES_EQUAL(ret.Name, "_text");
+ UNIT_ASSERT_VALUES_EQUAL(ret.ElementTypeId, 25);
}
}
@@ -99,3 +99,57 @@ Y_UNIT_TEST_SUITE(TCastsTests) {
UNIT_ASSERT_VALUES_EQUAL(ret.FunctionId, 0);
}
}
+
+// Checks aggregate lookup by name and argument types against the bundled catalog.
+Y_UNIT_TEST_SUITE(TAggregationsTests) {
+    // An unknown aggregate name is reported with yexception.
+    Y_UNIT_TEST(TestMissing) {
+        UNIT_ASSERT_EXCEPTION(LookupAggregation("foo", {}), yexception);
+    }
+
+    Y_UNIT_TEST(TestOk) {
+        const auto int4Id = LookupType("int4").TypeId;
+        const auto int8Id = LookupType("int8").TypeId;
+        const auto internalId = LookupType("internal").TypeId;
+        const auto textId = LookupType("text").TypeId;
+        const auto float8Id = LookupType("float8").TypeId;
+
+        // sum(int4): int8 state, combine function only.
+        {
+            const auto& agg = LookupAggregation("sum", {int4Id});
+            UNIT_ASSERT_VALUES_EQUAL(agg.TransTypeId, int8Id);
+            UNIT_ASSERT_VALUES_EQUAL(agg.Name, "sum");
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes.size(), 1);
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes[0], int4Id);
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.TransFuncId).Name, "int4_sum");
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.CombineFuncId).Name, "int8pl");
+            UNIT_ASSERT_VALUES_EQUAL(agg.FinalFuncId, 0);
+            UNIT_ASSERT_VALUES_EQUAL(agg.SerializeFuncId, 0);
+            UNIT_ASSERT_VALUES_EQUAL(agg.DeserializeFuncId, 0);
+        }
+
+        // sum(int8): internal state with the full set of support functions.
+        {
+            const auto& agg = LookupAggregation("sum", {int8Id});
+            UNIT_ASSERT_VALUES_EQUAL(agg.TransTypeId, internalId);
+            UNIT_ASSERT_VALUES_EQUAL(agg.Name, "sum");
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes.size(), 1);
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes[0], int8Id);
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.TransFuncId).Name, "int8_avg_accum");
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.CombineFuncId).Name, "int8_avg_combine");
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.FinalFuncId).Name, "numeric_poly_sum");
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.SerializeFuncId).Name, "int8_avg_serialize");
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.DeserializeFuncId).Name, "int8_avg_deserialize");
+        }
+
+        // string_agg(text, text): two arguments, final function but no combine function.
+        {
+            const auto& agg = LookupAggregation("string_agg", {textId, textId});
+            UNIT_ASSERT_VALUES_EQUAL(agg.TransTypeId, internalId);
+            UNIT_ASSERT_VALUES_EQUAL(agg.Name, "string_agg");
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes.size(), 2);
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes[0], textId);
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes[1], textId);
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.TransFuncId).Name, "string_agg_transfn");
+            UNIT_ASSERT_VALUES_EQUAL(agg.CombineFuncId, 0);
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.FinalFuncId).Name, "string_agg_finalfn");
+            UNIT_ASSERT_VALUES_EQUAL(agg.SerializeFuncId, 0);
+            UNIT_ASSERT_VALUES_EQUAL(agg.DeserializeFuncId, 0);
+        }
+
+        // regr_count(float8, float8): int8 state, combine function only.
+        {
+            const auto& agg = LookupAggregation("regr_count", {float8Id, float8Id});
+            UNIT_ASSERT_VALUES_EQUAL(agg.TransTypeId, int8Id);
+            UNIT_ASSERT_VALUES_EQUAL(agg.Name, "regr_count");
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes.size(), 2);
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes[0], float8Id);
+            UNIT_ASSERT_VALUES_EQUAL(agg.ArgTypes[1], float8Id);
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.TransFuncId).Name, "int8inc_float8_float8");
+            UNIT_ASSERT_VALUES_EQUAL(LookupProc(agg.CombineFuncId).Name, "int8pl");
+            UNIT_ASSERT_VALUES_EQUAL(agg.FinalFuncId, 0);
+            UNIT_ASSERT_VALUES_EQUAL(agg.SerializeFuncId, 0);
+            UNIT_ASSERT_VALUES_EQUAL(agg.DeserializeFuncId, 0);
+        }
+    }
+}