diff options
author | vvvv <vvvv@yandex-team.ru> | 2022-03-17 16:43:06 +0300 |
---|---|---|
committer | vvvv <vvvv@yandex-team.ru> | 2022-03-17 16:43:06 +0300 |
commit | 65cc3636e4a6df0b569be1fde6c5aad0d8fa4d55 (patch) | |
tree | eb31d9b392f1d8771590e93397c460b7bc6cd700 | |
parent | d62a2feb1b4868615e72ece28cde5d7b70fa4243 (diff) | |
download | ydb-65cc3636e4a6df0b569be1fde6c5aad0d8fa4d55.tar.gz |
YQL-13710 support of aggregations in pg_catalog
ref:6a4187e92350b626ce44922f35920e48b58ecea0
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/CMakeLists.txt | 8 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/catalog.cpp | 257 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/catalog.h | 23 | ||||
-rw-r--r-- | ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp | 60 |
4 files changed, 328 insertions, 20 deletions
```diff
diff --git a/ydb/library/yql/parser/pg_catalog/CMakeLists.txt b/ydb/library/yql/parser/pg_catalog/CMakeLists.txt
index c4eed486c1..768259591b 100644
--- a/ydb/library/yql/parser/pg_catalog/CMakeLists.txt
+++ b/ydb/library/yql/parser/pg_catalog/CMakeLists.txt
@@ -28,6 +28,7 @@ target_sources(yql-parser-pg_catalog.global PRIVATE
   ${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/f1a4bc0ed0162412cffb7ad89af053b9.cpp
   ${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/1ecd65c896f36a3990d870644a6da9c8.cpp
   ${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/46b25697572e7e60703079b71cd18295.cpp
+  ${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/15e039bc35a86f84476091e328dd74ea.cpp
 )
 resources(yql-parser-pg_catalog.global
   ${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/99bdd9e510834dd355c2457fcccc53d5.cpp
@@ -57,3 +58,10 @@ resources(yql-parser-pg_catalog.global
   KEYS
   pg_cast.dat
 )
+resources(yql-parser-pg_catalog.global
+  ${CMAKE_BINARY_DIR}/ydb/library/yql/parser/pg_catalog/15e039bc35a86f84476091e328dd74ea.cpp
+  INPUTS
+  ${CMAKE_SOURCE_DIR}/contrib/libs/postgresql/src/include/catalog/pg_aggregate.dat
+  KEYS
+  pg_aggregate.dat
+)
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.cpp b/ydb/library/yql/parser/pg_catalog/catalog.cpp
index ec856c801f..e9f049dcd4 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.cpp
+++ b/ydb/library/yql/parser/pg_catalog/catalog.cpp
@@ -16,6 +16,8 @@ using TTypes = THashMap<ui32, TTypeDesc>;
 
 using TCasts = THashMap<ui32, TCastDesc>;
 
+using TAggregations = THashMap<ui32, TAggregateDesc>;
+
 class TParser {
 public:
     void Do(const TString& dat) {
@@ -383,8 +385,8 @@ public:
 
     void OnFinish() override {
         if (IsSupported) {
-            LastCast.CastId = 1 + Casts.size();
-            Casts[LastCast.CastId] = LastCast;
+            auto id = 1 + Casts.size();
+            Casts[id] = LastCast;
         }
 
         LastCast = TCastDesc();
@@ -400,6 +402,191 @@ private:
     bool IsSupported = true;
 };
 
+class TAggregationsParser : public TParser {
+public:
+    TAggregationsParser(TAggregations& aggregations, const THashMap<TString, ui32>& typeByName,
+        const THashMap<TString, TVector<ui32>>& procByName, const TProcs& procs)
+        : Aggregations(aggregations)
+        , TypeByName(typeByName)
+        , ProcByName(procByName)
+        , Procs(procs)
+    {}
+
+    void OnKey(const TString& key, const TString& value) override {
+        Y_UNUSED(ProcByName);
+        if (key == "aggtranstype") {
+            auto typeId = TypeByName.FindPtr(value);
+            Y_ENSURE(typeId);
+            LastAggregation.TransTypeId = *typeId;
+        } else if (key == "aggfnoid") {
+            LastOid = value;
+        } else if (key == "aggtransfn") {
+            LastTransFunc = value;
+        } else if (key == "aggfinalfn") {
+            LastFinalFunc = value;
+        } else if (key == "aggcombinefn") {
+            LastCombineFunc = value;
+        } else if (key == "aggserialfn") {
+            LastSerializeFunc = value;
+        } else if (key == "aggdeserialfn") {
+            LastDeserializeFunc = value;
+        } else if (key == "aggkind") {
+            if (value == "n") {
+                LastAggregation.Kind = EAggKind::Normal;
+            } else if (value == "o") {
+                LastAggregation.Kind = EAggKind::OrderedSet;
+            } else if (value == "h") {
+                LastAggregation.Kind = EAggKind::Hypothetical;
+            } else {
+                ythrow yexception() << "Unknown aggkind value: " << value;
+            }
+        }
+    }
+
+    void OnFinish() override {
+        if (IsSupported) {
+            if (FillSupported()) {
+                auto id = Aggregations.size() + 1;
+                Aggregations[id] = LastAggregation;
+            }
+        }
+
+        LastAggregation = TAggregateDesc();
+        IsSupported = true;
+        LastOid = "";
+        LastTransFunc = "";
+        LastFinalFunc = "";
+        LastCombineFunc = "";
+        LastSerializeFunc = "";
+        LastDeserializeFunc = "";
+    }
+
+    bool FillSupported() {
+        Y_ENSURE(LastAggregation.TransTypeId);
+        Y_ENSURE(LastOid);
+        Y_ENSURE(LastTransFunc);
+        auto transFuncIdsPtr = ProcByName.FindPtr(LastTransFunc);
+        if (!transFuncIdsPtr) {
+            // e.g. variadic ordered_set_transition_multi
+            return false;
+        }
+
+        for (const auto id : *transFuncIdsPtr) {
+            auto procPtr = Procs.FindPtr(id);
+            Y_ENSURE(procPtr);
+            if (procPtr->ArgTypes.size() >= 1 && procPtr->ArgTypes[0] == LastAggregation.TransTypeId) {
+                Y_ENSURE(!LastAggregation.TransFuncId);
+                LastAggregation.TransFuncId = id;
+            }
+        }
+
+        Y_ENSURE(LastAggregation.TransFuncId);
+
+        // oid format: name(arg1,arg2...)
+        auto pos1 = LastOid.find('(');
+        if (pos1 != TString::npos) {
+            LastAggregation.Name = LastOid.substr(0, pos1);
+            auto pos = pos1 + 1;
+            for (;;) {
+                auto nextPos = Min(LastOid.find(',', pos), LastOid.find(')', pos));
+                Y_ENSURE(nextPos != TString::npos);
+                if (pos == nextPos) {
+                    break;
+                }
+
+                auto arg = LastOid.substr(pos, nextPos - pos);
+                auto argTypeId = TypeByName.FindPtr(arg);
+                Y_ENSURE(argTypeId);
+                LastAggregation.ArgTypes.push_back(*argTypeId);
+                pos = nextPos;
+                if (LastOid[pos] == ')') {
+                    break;
+                } else {
+                    ++pos;
+                }
+            }
+        } else {
+            // no signature in oid, use transfunc
+            LastAggregation.Name = LastOid;
+            auto procPtr = Procs.FindPtr(LastAggregation.TransFuncId);
+            Y_ENSURE(procPtr);
+            LastAggregation.ArgTypes = procPtr->ArgTypes;
+            Y_ENSURE(LastAggregation.ArgTypes.size() >= 1);
+            Y_ENSURE(LastAggregation.ArgTypes[0] == LastAggregation.TransTypeId);
+            LastAggregation.ArgTypes.erase(LastAggregation.ArgTypes.begin());
+        }
+
+        Y_ENSURE(!LastAggregation.Name.empty());
+        if (!ResolveFunc(LastFinalFunc, LastAggregation.FinalFuncId, 1)) {
+            return false;
+        }
+
+        if (!ResolveFunc(LastCombineFunc, LastAggregation.CombineFuncId, 2)) {
+            return false;
+        }
+
+        if (!ResolveFunc(LastSerializeFunc, LastAggregation.SerializeFuncId, 1)) {
+            return false;
+        }
+
+        if (!ResolveFunc(LastDeserializeFunc, LastAggregation.DeserializeFuncId, 0)) {
+            return false;
+        }
+
+        return true;
+    }
+
+    bool ResolveFunc(const TString& name, ui32& funcId, ui32 stateArgsCount) {
+        if (name) {
+            auto funcIdsPtr = ProcByName.FindPtr(name);
+            if (!funcIdsPtr) {
+                return false;
+            }
+
+            if (!stateArgsCount) {
+                Y_ENSURE(funcIdsPtr->size() == 1);
+            }
+
+            for (const auto id : *funcIdsPtr) {
+                auto procPtr = Procs.FindPtr(id);
+                Y_ENSURE(procPtr);
+                bool found = true;
+                if (stateArgsCount > 0 && procPtr->ArgTypes.size() == stateArgsCount) {
+                    for (ui32 i = 0; i < stateArgsCount; ++i) {
+                        if (procPtr->ArgTypes[i] != LastAggregation.TransTypeId) {
+                            found = false;
+                            break;
+                        }
+                    }
+                }
+
+                if (found) {
+                    Y_ENSURE(!funcId);
+                    funcId = id;
+                }
+            }
+
+            Y_ENSURE(funcId);
+        }
+
+        return true;
+    }
+
+private:
+    TAggregations& Aggregations;
+    const THashMap<TString, ui32>& TypeByName;
+    const THashMap<TString, TVector<ui32>>& ProcByName;
+    const TProcs& Procs;
+    TAggregateDesc LastAggregation;
+    bool IsSupported = true;
+    TString LastOid;
+    TString LastTransFunc;
+    TString LastFinalFunc;
+    TString LastCombineFunc;
+    TString LastSerializeFunc;
+    TString LastDeserializeFunc;
+};
+
 TOperators ParseOperators(const TString& dat, const THashMap<TString, ui32>& typeByName,
     const THashMap<TString, TVector<ui32>>& procByName) {
     TOperators ret;
@@ -408,6 +595,14 @@ TOperators ParseOperators(const TString& dat, const THashMap<TString, ui32>& typ
     return ret;
 }
 
+TAggregations ParseAggregations(const TString& dat, const THashMap<TString, ui32>& typeByName,
+    const THashMap<TString, TVector<ui32>>& procByName, const TProcs& procs) {
+    TAggregations ret;
+    TAggregationsParser parser(ret, typeByName, procByName, procs);
+    parser.Do(dat);
+    return ret;
+}
+
 TProcs ParseProcs(const TString& dat, const THashMap<TString, ui32>& typeByName) {
     TProcs ret;
     TProcsParser parser(ret, typeByName);
@@ -440,6 +635,8 @@ struct TCatalog {
         Y_ENSURE(NResource::FindExact("pg_proc.dat", &procData));
         TString castData;
         Y_ENSURE(NResource::FindExact("pg_cast.dat", &castData));
+        TString aggData;
+        Y_ENSURE(NResource::FindExact("pg_aggregate.dat", &aggData));
         THashMap<ui32, TLazyTypeInfo> lazyTypeInfos;
         Types = ParseTypes(typeData, lazyTypeInfos);
         for (const auto& [k, v] : Types) {
@@ -527,6 +724,11 @@ struct TCatalog {
         for (const auto&[k, v] : Operators) {
             OperatorsByName[v.Name].push_back(k);
         }
+
+        Aggregations = ParseAggregations(aggData, TypeByName, ProcByName, Procs);
+        for (const auto&[k, v] : Aggregations) {
+            AggregationsByName[v.Name].push_back(k);
+        }
     }
 
     static const TCatalog& Instance() {
@@ -537,14 +739,16 @@ struct TCatalog {
     TProcs Procs;
     TTypes Types;
     TCasts Casts;
+    TAggregations Aggregations;
     THashMap<TString, TVector<ui32>> ProcByName;
     THashMap<TString, ui32> TypeByName;
     THashMap<std::pair<ui32, ui32>, ui32> CastsByDir;
     THashMap<TString, TVector<ui32>> OperatorsByName;
+    THashMap<TString, TVector<ui32>> AggregationsByName;
 };
 
-bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
-    if (argTypeIds.size() != d.ArgTypes.size()) {
+bool ValidateArgs(const TVector<ui32>& descArgTypeIds, const TVector<ui32>& argTypeIds) {
+    if (argTypeIds.size() != descArgTypeIds.size()) {
         return false;
     }
 
@@ -554,7 +758,7 @@ bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
             continue; // NULL
         }
 
-        if (argTypeIds[i] != d.ArgTypes[i]) {
+        if (argTypeIds[i] != descArgTypeIds[i]) {
             found = false;
             break;
         }
@@ -563,6 +767,10 @@ bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
     return found;
 }
 
+bool ValidateProcArgs(const TProcDesc& d, const TVector<ui32>& argTypeIds) {
+    return ValidateArgs(d.ArgTypes, argTypeIds);
+}
+
 const TProcDesc& LookupProc(ui32 procId, const TVector<ui32>& argTypeIds) {
     const auto& catalog = TCatalog::Instance();
     auto procPtr = catalog.Procs.FindPtr(procId);
@@ -651,16 +859,6 @@ const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId) {
     return *castPtr;
 }
 
-const TCastDesc& LookupCast(ui32 castId) {
-    const auto& catalog = TCatalog::Instance();
-    auto castPtr = catalog.Casts.FindPtr(castId);
-    if (!castPtr) {
-        throw yexception() << "No such cast: " << castId;
-    }
-
-    return *castPtr;
-}
-
 bool ValidateOperArgs(const TOperDesc& d, const TVector<ui32>& argTypeIds) {
     ui32 size = d.Kind == EOperKind::Binary ? 2 : 1;
     if (argTypeIds.size() != size) {
@@ -733,4 +931,33 @@ const TOperDesc& LookupOper(ui32 operId) {
     return *operPtr;
 }
 
+bool HasAggregation(const TStringBuf& name) {
+    const auto& catalog = TCatalog::Instance();
+    return catalog.AggregationsByName.contains(name);
+}
+
+bool ValidateAggregateArgs(const TAggregateDesc& d, const TVector<ui32>& argTypeIds) {
+    return ValidateArgs(d.ArgTypes, argTypeIds);
+}
+
+const TAggregateDesc& LookupAggregation(const TStringBuf& name, const TVector<ui32>& argTypeIds) {
+    const auto& catalog = TCatalog::Instance();
+    auto aggIdPtr = catalog.AggregationsByName.FindPtr(name);
+    if (!aggIdPtr) {
+        throw yexception() << "No such aggregate: " << name;
+    }
+
+    for (const auto& id : *aggIdPtr) {
+        const auto& d = catalog.Aggregations.FindPtr(id);
+        Y_ENSURE(d);
+        if (!ValidateAggregateArgs(*d, argTypeIds)) {
+            continue;
+        }
+
+        return *d;
+    }
+
+    throw yexception() << "Unable to find an overload for aggregate " << name << " with given argument types";
+}
+
 }
diff --git a/ydb/library/yql/parser/pg_catalog/catalog.h b/ydb/library/yql/parser/pg_catalog/catalog.h
index a2c86a439e..538db1375f 100644
--- a/ydb/library/yql/parser/pg_catalog/catalog.h
+++ b/ydb/library/yql/parser/pg_catalog/catalog.h
@@ -51,13 +51,30 @@ enum class ECastMethod {
 };
 
 struct TCastDesc {
-    ui32 CastId = 0;
     ui32 SourceId = 0;
     ui32 TargetId = 0;
     ECastMethod Method = ECastMethod::Function;
     ui32 FunctionId = 0;
 };
 
+enum class EAggKind {
+    Normal,
+    OrderedSet,
+    Hypothetical
+};
+
+struct TAggregateDesc {
+    TString Name;
+    TVector<ui32> ArgTypes;
+    EAggKind Kind = EAggKind::Normal;
+    ui32 TransTypeId = 0;
+    ui32 TransFuncId = 0;
+    ui32 FinalFuncId = 0;
+    ui32 CombineFuncId = 0;
+    ui32 SerializeFuncId = 0;
+    ui32 DeserializeFuncId = 0;
+};
+
 const TProcDesc& LookupProc(const TString& name, const TVector<ui32>& argTypeIds);
 const TProcDesc& LookupProc(ui32 procId, const TVector<ui32>& argTypeIds);
 const TProcDesc& LookupProc(ui32 procId);
@@ -68,10 +85,12 @@ const TTypeDesc& LookupType(ui32 typeId);
 
 bool HasCast(ui32 sourceId, ui32 targetId);
 const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId);
-const TCastDesc& LookupCast(ui32 castId);
 
 const TOperDesc& LookupOper(const TString& name, const TVector<ui32>& argTypeIds);
 const TOperDesc& LookupOper(ui32 operId, const TVector<ui32>& argTypeIds);
 const TOperDesc& LookupOper(ui32 operId);
 
+bool HasAggregation(const TStringBuf& name);
+const TAggregateDesc& LookupAggregation(const TStringBuf& name, const TVector<ui32>& argTypeIds);
+
 }
diff --git a/ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp b/ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp
index 17d50ed0ae..28539412a5 100644
--- a/ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp
+++ b/ydb/library/yql/parser/pg_catalog/ut/catalog_ut.cpp
@@ -34,10 +34,10 @@ Y_UNIT_TEST_SUITE(TTypesTests) {
         UNIT_ASSERT_VALUES_EQUAL(ret.ElementTypeId, LookupType("float8").TypeId);
 
         ret = LookupType(1009);
-        UNIT_ASSERT_VALUES_EQUAL(ret.TypeId, 25);
+        UNIT_ASSERT_VALUES_EQUAL(ret.TypeId, 1009);
         UNIT_ASSERT_VALUES_EQUAL(ret.ArrayTypeId, 1009);
-        UNIT_ASSERT_VALUES_EQUAL(ret.Name, "text");
-        UNIT_ASSERT_VALUES_EQUAL(ret.ElementTypeId, 0);
+        UNIT_ASSERT_VALUES_EQUAL(ret.Name, "_text");
+        UNIT_ASSERT_VALUES_EQUAL(ret.ElementTypeId, 25);
     }
 }
 
@@ -99,3 +99,57 @@ Y_UNIT_TEST_SUITE(TCastsTests) {
         UNIT_ASSERT_VALUES_EQUAL(ret.FunctionId, 0);
     }
 }
+
+Y_UNIT_TEST_SUITE(TAggregationsTests) {
+    Y_UNIT_TEST(TestMissing) {
+        UNIT_ASSERT_EXCEPTION(LookupAggregation("foo", {}), yexception);
+    }
+
+    Y_UNIT_TEST(TestOk) {
+        auto ret = LookupAggregation("sum", {LookupType("int4").TypeId});
+        UNIT_ASSERT_VALUES_EQUAL(ret.TransTypeId, LookupType("int8").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(ret.Name, "sum");
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes.size(), 1);
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes[0], LookupType("int4").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.TransFuncId).Name, "int4_sum");
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.CombineFuncId).Name, "int8pl");
+        UNIT_ASSERT_VALUES_EQUAL(ret.FinalFuncId, 0);
+        UNIT_ASSERT_VALUES_EQUAL(ret.SerializeFuncId, 0);
+        UNIT_ASSERT_VALUES_EQUAL(ret.DeserializeFuncId, 0);
+
+        ret = LookupAggregation("sum", {LookupType("int8").TypeId});
+        UNIT_ASSERT_VALUES_EQUAL(ret.TransTypeId, LookupType("internal").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(ret.Name, "sum");
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes.size(), 1);
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes[0], LookupType("int8").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.TransFuncId).Name, "int8_avg_accum");
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.CombineFuncId).Name, "int8_avg_combine");
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.FinalFuncId).Name, "numeric_poly_sum");
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.SerializeFuncId).Name, "int8_avg_serialize");
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.DeserializeFuncId).Name, "int8_avg_deserialize");
+
+        ret = LookupAggregation("string_agg", {LookupType("text").TypeId, LookupType("text").TypeId});
+        UNIT_ASSERT_VALUES_EQUAL(ret.TransTypeId, LookupType("internal").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(ret.Name, "string_agg");
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes.size(), 2);
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes[0], LookupType("text").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes[1], LookupType("text").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.TransFuncId).Name, "string_agg_transfn");
+        UNIT_ASSERT_VALUES_EQUAL(ret.CombineFuncId, 0);
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.FinalFuncId).Name, "string_agg_finalfn");
+        UNIT_ASSERT_VALUES_EQUAL(ret.SerializeFuncId, 0);
+        UNIT_ASSERT_VALUES_EQUAL(ret.DeserializeFuncId, 0);
+
+        ret = LookupAggregation("regr_count", {LookupType("float8").TypeId, LookupType("float8").TypeId});
+        UNIT_ASSERT_VALUES_EQUAL(ret.TransTypeId, LookupType("int8").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(ret.Name, "regr_count");
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes.size(), 2);
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes[0], LookupType("float8").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(ret.ArgTypes[1], LookupType("float8").TypeId);
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.TransFuncId).Name, "int8inc_float8_float8");
+        UNIT_ASSERT_VALUES_EQUAL(LookupProc(ret.CombineFuncId).Name, "int8pl");
+        UNIT_ASSERT_VALUES_EQUAL(ret.FinalFuncId, 0);
+        UNIT_ASSERT_VALUES_EQUAL(ret.SerializeFuncId, 0);
+        UNIT_ASSERT_VALUES_EQUAL(ret.DeserializeFuncId, 0);
+    }
+}
```