summaryrefslogtreecommitdiffstats
path: root/yql/essentials/parser/pg_wrapper/generate_kernels.py
diff options
context:
space:
mode:
authorvvvv <[email protected]>2024-11-07 12:29:36 +0300
committervvvv <[email protected]>2024-11-07 13:49:47 +0300
commitd4c258e9431675bab6745c8638df6e3dfd4dca6b (patch)
treeb5efcfa11351152a4c872fccaea35749141c0b11 /yql/essentials/parser/pg_wrapper/generate_kernels.py
parent13a4f274caef5cfdaf0263b24e4d6bdd5521472b (diff)
Moved other yql/essentials libs YQL-19206
init commit_hash:7d4c435602078407bbf20dd3c32f9c90d2bbcbc0
Diffstat (limited to 'yql/essentials/parser/pg_wrapper/generate_kernels.py')
-rwxr-xr-xyql/essentials/parser/pg_wrapper/generate_kernels.py425
1 files changed, 425 insertions, 0 deletions
diff --git a/yql/essentials/parser/pg_wrapper/generate_kernels.py b/yql/essentials/parser/pg_wrapper/generate_kernels.py
new file mode 100755
index 00000000000..da00d08e3b2
--- /dev/null
+++ b/yql/essentials/parser/pg_wrapper/generate_kernels.py
@@ -0,0 +1,425 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import json
+import zlib
+from collections import defaultdict
+
+SPLIT_FILES = 20
+
+def is_strict(oid_per_name, funcs, name):
+ found = None
+ for oid in oid_per_name[name]:
+ strict = funcs[oid][1]
+ if found is None:
+ found = strict
+ else:
+ assert found == strict
+ return "true" if found else "false"
+
+def is_result_fixed(oid_per_name, catalog_by_oid, name):
+ found = None
+ for oid in oid_per_name[name]:
+ fixed = catalog_by_oid[oid]["ret_type_fixed"]
+ if found is None:
+ found = fixed
+ else:
+ assert found == fixed
+ return "true" if found else "false"
+
+def get_fixed_args(oid_per_name, catalog_by_oid, name):
+ found = None
+ for oid in oid_per_name[name]:
+ if "var_type" in catalog_by_oid[oid]:
+ return None
+ fixed = [x["arg_type_fixed"] for x in catalog_by_oid[oid]["args"]]
+ if found is None:
+ found = fixed
+ else:
+ # e.g. range_constructor2
+ if found != fixed:
+ return None
+ return found
+
+def main():
+ pg_sources = []
+ with open("pg_sources.inc") as f:
+ for line in f:
+ pg_sources.append(line.rstrip())
+ with open("../../tools/pg_catalog_dump/dump.json") as f:
+ catalog = json.load(f)
+ catalog_by_oid = {}
+ catalog_funcs = set()
+ for proc in catalog["proc"]:
+ catalog_by_oid[proc["oid"]] = proc
+ catalog_funcs.add(proc["src"])
+ catalog_aggs_by_id = {}
+ for agg in catalog["aggregation"]:
+ if not agg["combine_func_id"]:
+ continue
+ catalog_aggs_by_id[agg["agg_id"]] = agg
+ assert len(agg["args"]) <= 2
+
+ funcs={}
+ with open("postgresql/src/backend/utils/fmgrtab.c") as f:
+ parse=False
+ for line in f:
+ if "fmgr_builtins[]" in line:
+ parse=True
+ continue
+ if not parse:
+ continue
+ if line.startswith("}"):
+ parse=False
+ continue
+ c=line.strip()[1:-2].split(", ")
+ oid=int(c[0])
+ nargs=int(c[1])
+ strict=c[2].strip()=="true"
+ retset=c[3].strip()=="true"
+ name=c[4].strip().strip('"')
+ func=c[5].strip()
+ if retset: continue
+ if name!=func:
+ print(name,func)
+ continue
+ if not oid in catalog_by_oid:
+ print("skipped by catalog: ",name)
+ continue
+ funcs[oid] = (nargs, strict, name)
+ print("funcs: ", len(funcs))
+ func_names=set(x[2] for x in funcs.values())
+ print("unique names: ", len(func_names))
+ print("aggs: ", len(catalog_aggs_by_id))
+ oid_per_name={}
+ all_found_funcs=set()
+ for x in funcs:
+ name = funcs[x][2]
+ if not name in oid_per_name:
+ oid_per_name[name]=[]
+ oid_per_name[name].append(x)
+ symbols={}
+ for i in range(len(pg_sources)):
+ line = pg_sources[i]
+ if not line.endswith(".c"):
+ continue
+ cfile = line.strip()
+ found_funcs = set()
+ #print(cfile)
+ with open(cfile) as f:
+ for srcline in f:
+ pos=srcline.find("(PG_FUNCTION_ARGS)")
+ if pos!=-1:
+ names=[srcline[0:pos].strip()]
+ elif srcline.startswith("CMPFUNC("):
+ pos=srcline.find(",")
+ names=[srcline[8:pos]]
+ elif srcline.startswith("TSVECTORCMPFUNC("):
+ pos=srcline.find(",")
+ names=["tsvector_"+srcline[16:pos]]
+ elif srcline.startswith("PSEUDOTYPE_DUMMY_IO_FUNCS("):
+ pos=srcline.find(")")
+ names=[srcline[26:pos]+"_in", srcline[26:pos]+"_out"]
+ elif srcline.startswith("PSEUDOTYPE_DUMMY_INPUT_FUNC(") and "\\" not in srcline:
+ pos=srcline.find(")")
+ names=[srcline[28:pos]+"_in"]
+ elif srcline.startswith("PSEUDOTYPE_DUMMY_BINARY_IO_FUNCS"):
+ pos=srcline.find(")")
+ names=[srcline[33:pos]+"_send", srcline[33:pos]+"_recv"]
+ elif srcline.startswith("PSEUDOTYPE_DUMMY_RECEIVE_FUNC(") and "\\" not in srcline:
+ pos=srcline.find(")")
+ names=[srcline[30:pos]+"_recv"]
+ elif srcline.startswith("PG_STAT_GET_DBENTRY_FLOAT8_MS("):
+ pos=srcline.find(")")
+ names=["pg_stat_get_db_" + srcline[30:pos]]
+ elif srcline.startswith("PG_STAT_GET_DBENTRY_INT64("):
+ pos=srcline.find(")")
+ names=["pg_stat_get_db_" + srcline[26:pos]]
+ elif srcline.startswith("PG_STAT_GET_RELENTRY_INT64("):
+ pos=srcline.find(")")
+ names=["pg_stat_get_" + srcline[27:pos]]
+ elif srcline.startswith("PG_STAT_GET_RELENTRY_TIMESTAMPTZ("):
+ pos=srcline.find(")")
+ names=["pg_stat_get_" + srcline[33:pos]]
+ elif srcline.startswith("PG_STAT_GET_XACT_RELENTRY_INT64("):
+ pos=srcline.find(")")
+ names=["pg_stat_get_xact_" + srcline[32:pos]]
+ elif srcline.startswith("PG_STAT_GET_FUNCENTRY_FLOAT8_MS("):
+ pos=srcline.find(")")
+ names=["pg_stat_get_function_" + srcline[32:pos]]
+ elif srcline.startswith("PG_STAT_GET_XACT_FUNCENTRY_FLOAT8_MS("):
+ pos=srcline.find(")")
+ names=["pg_stat_get_xact_function_" + srcline[37:pos]]
+ else:
+ continue
+ for name in names:
+ if name in func_names:
+ found_funcs.add(name)
+ all_found_funcs.add(name)
+ if not found_funcs:
+ continue
+ print(cfile, len(found_funcs))
+ symbols[cfile] = found_funcs
+ split_symbols = []
+ split_all_found_funcs = []
+ for i in range(SPLIT_FILES):
+ curr_symbols = {}
+ curr_all_found_funcs = set()
+ for cfile in symbols:
+ if zlib.crc32(cfile.encode("utf8")) % SPLIT_FILES != i: continue
+ curr_symbols[cfile] = symbols[cfile]
+ curr_all_found_funcs.update(symbols[cfile])
+
+ split_symbols.append(curr_symbols)
+ split_all_found_funcs.append(curr_all_found_funcs)
+ # check if all functions are available
+ split_for_agg = {}
+ for agg in catalog_aggs_by_id.values():
+ oids = set()
+ oids.add(agg["trans_func_id"])
+ if agg["serialize_func_id"]:
+ assert catalog_by_oid[agg["serialize_func_id"]]["strict"]
+ oids.add(agg["serialize_func_id"])
+ if agg["deserialize_func_id"]:
+ assert catalog_by_oid[agg["deserialize_func_id"]]["strict"]
+ oids.add(agg["deserialize_func_id"])
+ if agg["final_func_id"]:
+ oids.add(agg["final_func_id"])
+ oids.add(agg["combine_func_id"])
+ names = [catalog_by_oid[oid]["src"] for oid in oids]
+
+ for i in range(SPLIT_FILES):
+ with open("pg_bc."+str(i)+".inc", "w") as bc:
+ bc.write("LLVM_BC(\n" + \
+ "\n".join((" " + x) for x in sorted(split_symbols[i].keys())) + \
+ "\n pg_kernels."+str(i)+".cpp\n" + \
+ "\n NAME PgFuncs" + str(i) + "\n" + \
+ "\n SYMBOLS\n" + \
+ "\n".join((" arrow_" + x) for x in sorted(split_all_found_funcs[i])) + \
+ "\n)\n")
+ for i in range(SPLIT_FILES):
+ with open("pg_proc_policies." + str(i) + ".inc", "w") as p:
+ for x in sorted(split_all_found_funcs[i]):
+ fixed_args = get_fixed_args(oid_per_name, catalog_by_oid, x)
+ if fixed_args is not None:
+ p.write("struct TArgs_NAME_Policy {\n".replace("NAME", x))
+ p.write(" static constexpr bool VarArgs = false;\n")
+ p.write(" static constexpr std::array<bool, N> IsFixedArg = {V};\n" \
+ .replace("N", str(len(fixed_args))) \
+ .replace("V", ",".join("true" if x else "false" for x in fixed_args)))
+ p.write("};\n")
+ else:
+ print("polymorphic args:", x)
+ for i in range(SPLIT_FILES):
+ with open("pg_kernels." + str(i) + ".inc", "w") as k:
+ for x in sorted(split_all_found_funcs[i]):
+ fixed_args = get_fixed_args(oid_per_name, catalog_by_oid, x)
+ k.write(
+ "TExecFunc arrow_NAME() { return TGenericExec<TPgDirectFunc<&NAME>, STRICT, IS_RESULT_FIXED, POLICY>({}); }\n" \
+ .replace("NAME", x) \
+ .replace("STRICT", is_strict(oid_per_name, funcs, x)) \
+ .replace("IS_RESULT_FIXED", is_result_fixed(oid_per_name, catalog_by_oid, x))
+ .replace("POLICY", "TArgs_" + x + "_Policy" if fixed_args is not None else "TDefaultArgsPolicy"))
+ for i in range(SPLIT_FILES):
+ with open("pg_kernels.slow." + str(i) + ".inc", "w") as k:
+ for x in sorted(split_all_found_funcs[i]):
+ k.write(
+ "TExecFunc arrow_NAME() { return MakeIndirectExec<STRICT, IS_RESULT_FIXED>(&NAME); }\n" \
+ .replace("NAME", x) \
+ .replace("STRICT", is_strict(oid_per_name, funcs, x)) \
+ .replace("IS_RESULT_FIXED", is_result_fixed(oid_per_name, catalog_by_oid, x)))
+ with open("pg_kernels_fwd.inc", "w") as k:
+ k.write(\
+ "\n".join("extern TExecFunc arrow_NAME();".replace("NAME", x) for x in sorted(all_found_funcs)) + \
+ "\n")
+ for i in range(SPLIT_FILES):
+ with open("pg_kernels_register." + str(i) + ".inc", "w") as r:
+ for name in oid_per_name:
+ if not name in split_all_found_funcs[i]: continue
+ for oid in sorted(oid_per_name[name]):
+ r.write("RegisterExec(" + str(oid) + ", arrow_" + name + "());\n")
+ for slow in [False, True]:
+ with open("pg_aggs" + (".slow" if slow else "") + ".inc","w") as p:
+ for agg_id in sorted(catalog_aggs_by_id.keys()):
+ agg = catalog_aggs_by_id[agg_id]
+ trans_func = catalog_by_oid[agg["trans_func_id"]]["src"]
+ trans_fixed_args = None if slow else get_fixed_args(oid_per_name, catalog_by_oid, trans_func)
+ combine_func = catalog_by_oid[agg["combine_func_id"]]["src"]
+ combine_fixed_args = None if slow else get_fixed_args(oid_per_name, catalog_by_oid, combine_func)
+ serialize_func = ""
+ serialize_fixed_args = None
+ if agg["serialize_func_id"]:
+ serialize_func = catalog_by_oid[agg["serialize_func_id"]]["src"]
+ serialize_fixed_args = None if slow else get_fixed_args(oid_per_name, catalog_by_oid, serialize_func)
+ deserialize_func = ""
+ deserialize_fixed_args = None
+ if agg["deserialize_func_id"]:
+ deserialize_func = catalog_by_oid[agg["deserialize_func_id"]]["src"]
+ deserialize_fixed_args = None if slow else get_fixed_args(oid_per_name, catalog_by_oid, deserialize_func)
+ final_func = ""
+ final_fixed_args = None
+ if agg["final_func_id"]:
+ final_func = catalog_by_oid[agg["final_func_id"]]["src"]
+ final_fixed_args = None if slow else get_fixed_args(oid_per_name, catalog_by_oid, final_func)
+
+ p.write("auto MakePgAgg_" + agg["name"] + "_" + str(agg_id) + "() {\n"
+ " return TGenericAgg<\n \
+ TRANS_FUNC, IS_TRANS_STRICT, TRANS_ARGS_POLICY,\n \
+ COMBINE_FUNC, IS_COMBINE_STRICT, COMBINE_ARGS_POLICY,\n \
+ HAS_SERIALIZE_FUNC, SERIALIZE_FUNC1, SERIALIZE_ARGS_POLICY1,\n \
+ HAS_DESERIALIZE_FUNC, DESERIALIZE_FUNC, DESERIALIZE_ARGS_POLICY,\n \
+ HAS_FINAL_FUNC, FINAL_FUNC, IS_FINAL_STRICT, FINAL_ARGS_POLICY,\n \
+ TRANS_TYPE_FIXED, SERIALIZED_TYPE_FIXED, FINAL_TYPE_FIXED, HAS_INIT_VALUE\n \
+ >(TRANS_OBJ, COMBINE_OBJ, SERIALIZE1_OBJ, DESERIALIZE_OBJ, FINAL_OBJ);\n" \
+ .replace("TRANS_FUNC", "TPgIndirectFunc" if slow else "TPgDirectFunc<&" + trans_func + ">") \
+ .replace("IS_TRANS_STRICT", "true" if catalog_by_oid[agg["trans_func_id"]]["strict"] else "false") \
+ .replace("TRANS_ARGS_POLICY", "TArgs_" + trans_func + "_Policy" if trans_fixed_args is not None else "TDefaultArgsPolicy") \
+ .replace("COMBINE_FUNC", "TPgIndirectFunc" if slow else "TPgDirectFunc<&" + combine_func + ">") \
+ .replace("IS_COMBINE_STRICT", "true" if catalog_by_oid[agg["combine_func_id"]]["strict"] else "false") \
+ .replace("COMBINE_ARGS_POLICY", "TArgs_" + combine_func + "_Policy" if combine_fixed_args is not None else "TDefaultArgsPolicy") \
+ .replace("HAS_SERIALIZE_FUNC", "true" if serialize_func else "false") \
+ .replace("SERIALIZE_FUNC1", "TPgDirectFunc<&" + serialize_func + ">" if serialize_func and not slow else "TPgIndirectFunc") \
+ .replace("SERIALIZE_ARGS_POLICY1", "TArgs_" + serialize_func + "_Policy" if serialize_fixed_args is not None else "TDefaultArgsPolicy") \
+ .replace("HAS_DESERIALIZE_FUNC", "true" if deserialize_func else "false") \
+ .replace("DESERIALIZE_FUNC", "TPgDirectFunc<&" + deserialize_func + ">" if deserialize_func and not slow else "TPgIndirectFunc") \
+ .replace("DESERIALIZE_ARGS_POLICY", "TArgs_" + deserialize_func + "_Policy" if deserialize_fixed_args is not None else "TDefaultArgsPolicy") \
+ .replace("HAS_FINAL_FUNC", "true" if final_func else "false") \
+ .replace("FINAL_FUNC", "TPgDirectFunc<&" + final_func + ">" if final_func and not slow else "TPgIndirectFunc") \
+ .replace("IS_FINAL_STRICT", "true" if final_func and catalog_by_oid[agg["final_func_id"]]["strict"] else "false") \
+ .replace("FINAL_ARGS_POLICY", "TArgs_" + final_func + "_Policy" if final_fixed_args is not None else "TDefaultArgsPolicy") \
+ .replace("TRANS_TYPE_FIXED", "true" if agg["trans_type_fixed"] else "false") \
+ .replace("SERIALIZED_TYPE_FIXED", "true" if agg["serialized_type_fixed"] else "false") \
+ .replace("FINAL_TYPE_FIXED", "true" if agg["ret_type_fixed"] else "false") \
+ .replace("HAS_INIT_VALUE", "true" if agg["has_init_value"] else "false") \
+ .replace("TRANS_OBJ", "&" + trans_func if slow else "{}") \
+ .replace("COMBINE_OBJ", "&" + combine_func if slow else "{}") \
+ .replace("SERIALIZE1_OBJ", ("&" + serialize_func if slow else "{}") if serialize_func else "nullptr") \
+ .replace("DESERIALIZE_OBJ", ("&" +deserialize_func if slow else "{}") if deserialize_func else "nullptr") \
+ .replace("FINAL_OBJ", ("&" + final_func if slow else "{}") if final_func else "nullptr") \
+ )
+ p.write("}\n")
+
+ agg_names = defaultdict(list)
+ with open("pg_aggs_register.inc","w") as p:
+ for agg_id in sorted(catalog_aggs_by_id.keys()):
+ agg_names[catalog_aggs_by_id[agg_id]["name"]].append(agg_id)
+ for name in agg_names:
+ p.write(
+ ("class TPgAggFactory_NAME: public IBlockAggregatorFactory {\n" \
+ "std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineAll>> PrepareCombineAll(\n" \
+ " TTupleType* tupleType,\n" \
+ " std::optional<ui32> filterColumn,\n" \
+ " const std::vector<ui32>& argsColumns,\n" \
+ " const TTypeEnvironment& env) const final {\n" \
+ " const auto& aggDesc = ResolveAggregation(\"NAME\", tupleType, argsColumns, nullptr);\n" \
+ " switch (aggDesc.AggId) {\n" +
+ "".join([" case " + str(agg_id) + ": return MakePgAgg_NAME_" + str(agg_id) + "().PrepareCombineAll(filterColumn, argsColumns, aggDesc);\n" for agg_id in agg_names[name]]) +
+ " default: throw yexception() << \"Unsupported agg id: \" << aggDesc.AggId;\n" \
+ " }\n" \
+ "}\n" \
+ "\n" \
+ "std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorCombineKeys>> PrepareCombineKeys(\n" \
+ " TTupleType* tupleType,\n" \
+ " const std::vector<ui32>& argsColumns,\n" \
+ " const TTypeEnvironment& env) const final {\n" \
+ " const auto& aggDesc = ResolveAggregation(\"NAME\", tupleType, argsColumns, nullptr);\n"
+ " switch (aggDesc.AggId) {\n" +
+ "".join([" case " + str(agg_id) + ": return MakePgAgg_NAME_" + str(agg_id) + "().PrepareCombineKeys(argsColumns, aggDesc);\n" for agg_id in agg_names[name]]) +
+ " default: throw yexception() << \"Unsupported agg id: \" << aggDesc.AggId;\n" \
+ " }\n" \
+ "}\n" \
+ "\n" \
+ "std::unique_ptr<IPreparedBlockAggregator<IBlockAggregatorFinalizeKeys>> PrepareFinalizeKeys(\n" \
+ " TTupleType* tupleType,\n" \
+ " const std::vector<ui32>& argsColumns,\n" \
+ " const TTypeEnvironment& env,\n" \
+ " TType* returnType,\n" \
+ " ui32 hint) const final {\n" \
+ " const auto& aggDesc = ResolveAggregation(\"NAME\", tupleType, argsColumns, returnType, hint);\n"
+ " switch (aggDesc.AggId) {\n" +
+ "".join([" case " + str(agg_id) + ": return MakePgAgg_NAME_" + str(agg_id) + "().PrepareFinalizeKeys(argsColumns.front(), aggDesc);\n" for agg_id in agg_names[name]]) +
+ " default: throw yexception() << \"Unsupported agg id: \" << aggDesc.AggId;\n" \
+ " }\n" \
+ "}\n" \
+ "};\n").replace("NAME", name))
+ for name in agg_names:
+ p.write('registry.emplace("pg_' + name + '", std::make_unique<TPgAggFactory_' + name + '>());\n')
+
+ for i in range(SPLIT_FILES):
+ with open("pg_kernels." + str(i) + ".cpp","w") as f:
+ f.write(
+ 'extern "C" {\n'
+ '#include "postgres.h"\n'
+ '#include "fmgr.h"\n'
+ '#include "postgresql/src/backend/utils/fmgrprotos.h"\n'
+ '#undef Abs\n'
+ '#undef Min\n'
+ '#undef Max\n'
+ '#undef TypeName\n'
+ '#undef SortBy\n'
+ '#undef Sort\n'
+ '#undef Unique\n'
+ '#undef LOG\n'
+ '#undef INFO\n'
+ '#undef NOTICE\n'
+ '#undef WARNING\n'
+ '#undef ERROR\n'
+ '#undef FATAL\n'
+ '#undef PANIC\n'
+ '#undef open\n'
+ '#undef fopen\n'
+ '#undef bind\n'
+ '#undef locale_t\n'
+ '#undef strtou64\n'
+ '}\n'
+ '\n'
+ '#include "arrow.h"\n'
+ '\n'
+ 'namespace NYql {\n'
+ '\n'
+ 'extern "C" {\n'
+ '\n'
+ 'Y_PRAGMA_DIAGNOSTIC_PUSH\n'
+ 'Y_PRAGMA("GCC diagnostic ignored \\"-Wreturn-type-c-linkage\\"")\n'
+ '#ifdef USE_SLOW_PG_KERNELS\n'
+ '#include "pg_kernels.slow.INDEX.inc"\n'
+ '#else\n'
+ '#include "pg_proc_policies.INDEX.inc"\n'
+ '#include "pg_kernels.INDEX.inc"\n'
+ '#endif\n'
+ 'Y_PRAGMA_DIAGNOSTIC_POP\n'
+ '\n'
+ '}\n'
+ '\n'
+ '}\n'.replace("INDEX",str(i))
+ )
+
+ with open("pg_kernels_register.all.inc","w") as f:
+ for i in range(SPLIT_FILES):
+ f.write('#include "pg_kernels_register.INDEX.inc"\n'.replace("INDEX", str(i)))
+
+ with open("pg_proc_policies.all.inc","w") as f:
+ for i in range(SPLIT_FILES):
+ f.write('#include "pg_proc_policies.INDEX.inc"\n'.replace("INDEX", str(i)))
+
+ with open("pg_kernel_sources.inc","w") as f:
+ f.write("SRCS(\n")
+ for i in range(SPLIT_FILES):
+ f.write(' pg_kernels.INDEX.cpp\n'.replace("INDEX", str(i)))
+ f.write(")\n")
+
+ with open("pg_bc.all.inc","w") as f:
+ for i in range(SPLIT_FILES):
+ f.write('INCLUDE(pg_bc.INDEX.inc)\n'.replace("INDEX", str(i)))
+
+ print("found funcs: ",len(all_found_funcs))
+ print("agg names: ",len(agg_names))
+ print("agg funcs: ",len(catalog_aggs_by_id))
+ missing=func_names.difference(all_found_funcs)
+ if missing:
+ print("missing funcs: ",len(missing))
+ print(missing)
+
+if __name__ == "__main__":
+ main()