aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/tools/cython/Cython/Compiler/TreeFragment.py
diff options
context:
space:
mode:
authoralexv-smirnov <alex@ydb.tech>2023-06-13 11:05:01 +0300
committeralexv-smirnov <alex@ydb.tech>2023-06-13 11:05:01 +0300
commitbf0f13dd39ee3e65092ba3572bb5b1fcd125dcd0 (patch)
tree1d1df72c0541a59a81439842f46d95396d3e7189 /contrib/tools/cython/Cython/Compiler/TreeFragment.py
parent8bfdfa9a9bd19bddbc58d888e180fbd1218681be (diff)
downloadydb-bf0f13dd39ee3e65092ba3572bb5b1fcd125dcd0.tar.gz
add ymake export to ydb
Diffstat (limited to 'contrib/tools/cython/Cython/Compiler/TreeFragment.py')
-rw-r--r--contrib/tools/cython/Cython/Compiler/TreeFragment.py275
1 files changed, 275 insertions, 0 deletions
diff --git a/contrib/tools/cython/Cython/Compiler/TreeFragment.py b/contrib/tools/cython/Cython/Compiler/TreeFragment.py
new file mode 100644
index 0000000000..b85da8191a
--- /dev/null
+++ b/contrib/tools/cython/Cython/Compiler/TreeFragment.py
@@ -0,0 +1,275 @@
+#
+# TreeFragments - parsing of strings to trees
+#
+
+"""
+Support for parsing strings into code trees.
+"""
+
+from __future__ import absolute_import
+
+import re
+from io import StringIO
+
+from .Scanning import PyrexScanner, StringSourceDescriptor
+from .Symtab import ModuleScope
+from . import PyrexTypes
+from .Visitor import VisitorTransform
+from .Nodes import Node, StatListNode
+from .ExprNodes import NameNode
+from .StringEncoding import _unicode
+from . import Parsing
+from . import Main
+from . import UtilNodes
+
+
+class StringParseContext(Main.Context):
+ def __init__(self, name, include_directories=None, compiler_directives=None, cpp=False):
+ if include_directories is None:
+ include_directories = []
+ if compiler_directives is None:
+ compiler_directives = {}
+ # TODO: see if "language_level=3" also works for our internal code here.
+ Main.Context.__init__(self, include_directories, compiler_directives, cpp=cpp, language_level=2)
+ self.module_name = name
+
+ def find_module(self, module_name, relative_to=None, pos=None, need_pxd=1, absolute_fallback=True):
+ if module_name not in (self.module_name, 'cython'):
+ raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
+ return ModuleScope(module_name, parent_module=None, context=self)
+
+
+def parse_from_strings(name, code, pxds=None, level=None, initial_pos=None,
+ context=None, allow_struct_enum_decorator=False):
+ """
+ Utility method to parse a (unicode) string of code. This is mostly
+ used for internal Cython compiler purposes (creating code snippets
+ that transforms should emit, as well as unit testing).
+
+ code - a unicode string containing Cython (module-level) code
+ name - a descriptive name for the code source (to use in error messages etc.)
+
+ RETURNS
+
+ The tree, i.e. a ModuleNode. The ModuleNode's scope attribute is
+ set to the scope used when parsing.
+ """
+ if context is None:
+ context = StringParseContext(name)
+ # Since source files carry an encoding, it makes sense in this context
+ # to use a unicode string so that code fragments don't have to bother
+ # with encoding. This means that test code passed in should not have an
+ # encoding header.
+ assert isinstance(code, _unicode), "unicode code snippets only please"
+ encoding = "UTF-8"
+
+ module_name = name
+ if initial_pos is None:
+ initial_pos = (name, 1, 0)
+ code_source = StringSourceDescriptor(name, code)
+
+ scope = context.find_module(module_name, pos=initial_pos, need_pxd=False)
+
+ buf = StringIO(code)
+
+ scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
+ scope = scope, context = context, initial_pos = initial_pos)
+ ctx = Parsing.Ctx(allow_struct_enum_decorator=allow_struct_enum_decorator)
+
+ if level is None:
+ tree = Parsing.p_module(scanner, 0, module_name, ctx=ctx)
+ tree.scope = scope
+ tree.is_pxd = False
+ else:
+ tree = Parsing.p_code(scanner, level=level, ctx=ctx)
+
+ tree.scope = scope
+ return tree
+
+
+class TreeCopier(VisitorTransform):
+ def visit_Node(self, node):
+ if node is None:
+ return node
+ else:
+ c = node.clone_node()
+ self.visitchildren(c)
+ return c
+
+
+class ApplyPositionAndCopy(TreeCopier):
+ def __init__(self, pos):
+ super(ApplyPositionAndCopy, self).__init__()
+ self.pos = pos
+
+ def visit_Node(self, node):
+ copy = super(ApplyPositionAndCopy, self).visit_Node(node)
+ copy.pos = self.pos
+ return copy
+
+
+class TemplateTransform(VisitorTransform):
+ """
+ Makes a copy of a template tree while doing substitutions.
+
+ A dictionary "substitutions" should be passed in when calling
+ the transform; mapping names to replacement nodes. Then replacement
+ happens like this:
+ - If an ExprStatNode contains a single NameNode, whose name is
+ a key in the substitutions dictionary, the ExprStatNode is
+ replaced with a copy of the tree given in the dictionary.
+ It is the responsibility of the caller that the replacement
+ node is a valid statement.
+ - If a single NameNode is otherwise encountered, it is replaced
+ if its name is listed in the substitutions dictionary in the
+ same way. It is the responsibility of the caller to make sure
+ that the replacement nodes is a valid expression.
+
+ Also a list "temps" should be passed. Any names listed will
+ be transformed into anonymous, temporary names.
+
+ Currently supported for tempnames is:
+ NameNode
+ (various function and class definition nodes etc. should be added to this)
+
+ Each replacement node gets the position of the substituted node
+ recursively applied to every member node.
+ """
+
+ temp_name_counter = 0
+
+ def __call__(self, node, substitutions, temps, pos):
+ self.substitutions = substitutions
+ self.pos = pos
+ tempmap = {}
+ temphandles = []
+ for temp in temps:
+ TemplateTransform.temp_name_counter += 1
+ handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
+ tempmap[temp] = handle
+ temphandles.append(handle)
+ self.tempmap = tempmap
+ result = super(TemplateTransform, self).__call__(node)
+ if temps:
+ result = UtilNodes.TempsBlockNode(self.get_pos(node),
+ temps=temphandles,
+ body=result)
+ return result
+
+ def get_pos(self, node):
+ if self.pos:
+ return self.pos
+ else:
+ return node.pos
+
+ def visit_Node(self, node):
+ if node is None:
+ return None
+ else:
+ c = node.clone_node()
+ if self.pos is not None:
+ c.pos = self.pos
+ self.visitchildren(c)
+ return c
+
+ def try_substitution(self, node, key):
+ sub = self.substitutions.get(key)
+ if sub is not None:
+ pos = self.pos
+ if pos is None: pos = node.pos
+ return ApplyPositionAndCopy(pos)(sub)
+ else:
+ return self.visit_Node(node) # make copy as usual
+
+ def visit_NameNode(self, node):
+ temphandle = self.tempmap.get(node.name)
+ if temphandle:
+ # Replace name with temporary
+ return temphandle.ref(self.get_pos(node))
+ else:
+ return self.try_substitution(node, node.name)
+
+ def visit_ExprStatNode(self, node):
+ # If an expression-as-statement consists of only a replaceable
+ # NameNode, we replace the entire statement, not only the NameNode
+ if isinstance(node.expr, NameNode):
+ return self.try_substitution(node, node.expr.name)
+ else:
+ return self.visit_Node(node)
+
+
+def copy_code_tree(node):
+ return TreeCopier()(node)
+
+
+_match_indent = re.compile(u"^ *").match
+
+
+def strip_common_indent(lines):
+ """Strips empty lines and common indentation from the list of strings given in lines"""
+ # TODO: Facilitate textwrap.indent instead
+ lines = [x for x in lines if x.strip() != u""]
+ if lines:
+ minindent = min([len(_match_indent(x).group(0)) for x in lines])
+ lines = [x[minindent:] for x in lines]
+ return lines
+
+
+class TreeFragment(object):
+ def __init__(self, code, name=None, pxds=None, temps=None, pipeline=None, level=None, initial_pos=None):
+ if pxds is None:
+ pxds = {}
+ if temps is None:
+ temps = []
+ if pipeline is None:
+ pipeline = []
+ if not name:
+ name = "(tree fragment)"
+
+ if isinstance(code, _unicode):
+ def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
+
+ fmt_code = fmt(code)
+ fmt_pxds = {}
+ for key, value in pxds.items():
+ fmt_pxds[key] = fmt(value)
+ mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
+ if level is None:
+ t = t.body # Make sure a StatListNode is at the top
+ if not isinstance(t, StatListNode):
+ t = StatListNode(pos=mod.pos, stats=[t])
+ for transform in pipeline:
+ if transform is None:
+ continue
+ t = transform(t)
+ self.root = t
+ elif isinstance(code, Node):
+ if pxds:
+ raise NotImplementedError()
+ self.root = code
+ else:
+ raise ValueError("Unrecognized code format (accepts unicode and Node)")
+ self.temps = temps
+
+ def copy(self):
+ return copy_code_tree(self.root)
+
+ def substitute(self, nodes=None, temps=None, pos = None):
+ if nodes is None:
+ nodes = {}
+ if temps is None:
+ temps = []
+ return TemplateTransform()(self.root,
+ substitutions = nodes,
+ temps = self.temps + temps, pos = pos)
+
+
+class SetPosTransform(VisitorTransform):
+ def __init__(self, pos):
+ super(SetPosTransform, self).__init__()
+ self.pos = pos
+
+ def visit_Node(self, node):
+ node.pos = self.pos
+ self.visitchildren(node)
+ return node