diff options
author | alexv-smirnov <alex@ydb.tech> | 2023-06-13 11:05:01 +0300 |
---|---|---|
committer | alexv-smirnov <alex@ydb.tech> | 2023-06-13 11:05:01 +0300 |
commit | bf0f13dd39ee3e65092ba3572bb5b1fcd125dcd0 (patch) | |
tree | 1d1df72c0541a59a81439842f46d95396d3e7189 /contrib/tools/cython/Cython/Compiler/TreeFragment.py | |
parent | 8bfdfa9a9bd19bddbc58d888e180fbd1218681be (diff) | |
download | ydb-bf0f13dd39ee3e65092ba3572bb5b1fcd125dcd0.tar.gz |
add ymake export to ydb
Diffstat (limited to 'contrib/tools/cython/Cython/Compiler/TreeFragment.py')
-rw-r--r-- | contrib/tools/cython/Cython/Compiler/TreeFragment.py | 275 |
1 files changed, 275 insertions, 0 deletions
diff --git a/contrib/tools/cython/Cython/Compiler/TreeFragment.py b/contrib/tools/cython/Cython/Compiler/TreeFragment.py new file mode 100644 index 0000000000..b85da8191a --- /dev/null +++ b/contrib/tools/cython/Cython/Compiler/TreeFragment.py @@ -0,0 +1,275 @@ +# +# TreeFragments - parsing of strings to trees +# + +""" +Support for parsing strings into code trees. +""" + +from __future__ import absolute_import + +import re +from io import StringIO + +from .Scanning import PyrexScanner, StringSourceDescriptor +from .Symtab import ModuleScope +from . import PyrexTypes +from .Visitor import VisitorTransform +from .Nodes import Node, StatListNode +from .ExprNodes import NameNode +from .StringEncoding import _unicode +from . import Parsing +from . import Main +from . import UtilNodes + + +class StringParseContext(Main.Context): + def __init__(self, name, include_directories=None, compiler_directives=None, cpp=False): + if include_directories is None: + include_directories = [] + if compiler_directives is None: + compiler_directives = {} + # TODO: see if "language_level=3" also works for our internal code here. + Main.Context.__init__(self, include_directories, compiler_directives, cpp=cpp, language_level=2) + self.module_name = name + + def find_module(self, module_name, relative_to=None, pos=None, need_pxd=1, absolute_fallback=True): + if module_name not in (self.module_name, 'cython'): + raise AssertionError("Not yet supporting any cimports/includes from string code snippets") + return ModuleScope(module_name, parent_module=None, context=self) + + +def parse_from_strings(name, code, pxds=None, level=None, initial_pos=None, + context=None, allow_struct_enum_decorator=False): + """ + Utility method to parse a (unicode) string of code. This is mostly + used for internal Cython compiler purposes (creating code snippets + that transforms should emit, as well as unit testing). + + code - a unicode string containing Cython (module-level) code + name - a descriptive name for the code source (to use in error messages etc.) + + RETURNS + + The tree, i.e. a ModuleNode. The ModuleNode's scope attribute is + set to the scope used when parsing. + """ + if context is None: + context = StringParseContext(name) + # Since source files carry an encoding, it makes sense in this context + # to use a unicode string so that code fragments don't have to bother + # with encoding. This means that test code passed in should not have an + # encoding header. + assert isinstance(code, _unicode), "unicode code snippets only please" + encoding = "UTF-8" + + module_name = name + if initial_pos is None: + initial_pos = (name, 1, 0) + code_source = StringSourceDescriptor(name, code) + + scope = context.find_module(module_name, pos=initial_pos, need_pxd=False) + + buf = StringIO(code) + + scanner = PyrexScanner(buf, code_source, source_encoding = encoding, + scope = scope, context = context, initial_pos = initial_pos) + ctx = Parsing.Ctx(allow_struct_enum_decorator=allow_struct_enum_decorator) + + if level is None: + tree = Parsing.p_module(scanner, 0, module_name, ctx=ctx) + tree.scope = scope + tree.is_pxd = False + else: + tree = Parsing.p_code(scanner, level=level, ctx=ctx) + + tree.scope = scope + return tree + + +class TreeCopier(VisitorTransform): + def visit_Node(self, node): + if node is None: + return node + else: + c = node.clone_node() + self.visitchildren(c) + return c + + +class ApplyPositionAndCopy(TreeCopier): + def __init__(self, pos): + super(ApplyPositionAndCopy, self).__init__() + self.pos = pos + + def visit_Node(self, node): + copy = super(ApplyPositionAndCopy, self).visit_Node(node) + copy.pos = self.pos + return copy + + +class TemplateTransform(VisitorTransform): + """ + Makes a copy of a template tree while doing substitutions. + + A dictionary "substitutions" should be passed in when calling + the transform; mapping names to replacement nodes. Then replacement + happens like this: + - If an ExprStatNode contains a single NameNode, whose name is + a key in the substitutions dictionary, the ExprStatNode is + replaced with a copy of the tree given in the dictionary. + It is the responsibility of the caller that the replacement + node is a valid statement. + - If a single NameNode is otherwise encountered, it is replaced + if its name is listed in the substitutions dictionary in the + same way. It is the responsibility of the caller to make sure + that the replacement nodes is a valid expression. + + Also a list "temps" should be passed. Any names listed will + be transformed into anonymous, temporary names. + + Currently supported for tempnames is: + NameNode + (various function and class definition nodes etc. should be added to this) + + Each replacement node gets the position of the substituted node + recursively applied to every member node. + """ + + temp_name_counter = 0 + + def __call__(self, node, substitutions, temps, pos): + self.substitutions = substitutions + self.pos = pos + tempmap = {} + temphandles = [] + for temp in temps: + TemplateTransform.temp_name_counter += 1 + handle = UtilNodes.TempHandle(PyrexTypes.py_object_type) + tempmap[temp] = handle + temphandles.append(handle) + self.tempmap = tempmap + result = super(TemplateTransform, self).__call__(node) + if temps: + result = UtilNodes.TempsBlockNode(self.get_pos(node), + temps=temphandles, + body=result) + return result + + def get_pos(self, node): + if self.pos: + return self.pos + else: + return node.pos + + def visit_Node(self, node): + if node is None: + return None + else: + c = node.clone_node() + if self.pos is not None: + c.pos = self.pos + self.visitchildren(c) + return c + + def try_substitution(self, node, key): + sub = self.substitutions.get(key) + if sub is not None: + pos = self.pos + if pos is None: pos = node.pos + return ApplyPositionAndCopy(pos)(sub) + else: + return self.visit_Node(node) # make copy as usual + + def visit_NameNode(self, node): + temphandle = self.tempmap.get(node.name) + if temphandle: + # Replace name with temporary + return temphandle.ref(self.get_pos(node)) + else: + return self.try_substitution(node, node.name) + + def visit_ExprStatNode(self, node): + # If an expression-as-statement consists of only a replaceable + # NameNode, we replace the entire statement, not only the NameNode + if isinstance(node.expr, NameNode): + return self.try_substitution(node, node.expr.name) + else: + return self.visit_Node(node) + + +def copy_code_tree(node): + return TreeCopier()(node) + + +_match_indent = re.compile(u"^ *").match + + +def strip_common_indent(lines): + """Strips empty lines and common indentation from the list of strings given in lines""" + # TODO: Facilitate textwrap.indent instead + lines = [x for x in lines if x.strip() != u""] + if lines: + minindent = min([len(_match_indent(x).group(0)) for x in lines]) + lines = [x[minindent:] for x in lines] + return lines + + +class TreeFragment(object): + def __init__(self, code, name=None, pxds=None, temps=None, pipeline=None, level=None, initial_pos=None): + if pxds is None: + pxds = {} + if temps is None: + temps = [] + if pipeline is None: + pipeline = [] + if not name: + name = "(tree fragment)" + + if isinstance(code, _unicode): + def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) + + fmt_code = fmt(code) + fmt_pxds = {} + for key, value in pxds.items(): + fmt_pxds[key] = fmt(value) + mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos) + if level is None: + t = t.body # Make sure a StatListNode is at the top + if not isinstance(t, StatListNode): + t = StatListNode(pos=mod.pos, stats=[t]) + for transform in pipeline: + if transform is None: + continue + t = transform(t) + self.root = t + elif isinstance(code, Node): + if pxds: + raise NotImplementedError() + self.root = code + else: + raise ValueError("Unrecognized code format (accepts unicode and Node)") + self.temps = temps + + def copy(self): + return copy_code_tree(self.root) + + def substitute(self, nodes=None, temps=None, pos = None): + if nodes is None: + nodes = {} + if temps is None: + temps = [] + return TemplateTransform()(self.root, + substitutions = nodes, + temps = self.temps + temps, pos = pos) + + +class SetPosTransform(VisitorTransform): + def __init__(self, pos): + super(SetPosTransform, self).__init__() + self.pos = pos + + def visit_Node(self, node): + node.pos = self.pos + self.visitchildren(node) + return node |