diff options
author | arcadia-devtools <arcadia-devtools@yandex-team.ru> | 2022-02-09 12:00:52 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 15:58:17 +0300 |
commit | 8e1413fed79d1e8036e65228af6c93399ccf5502 (patch) | |
tree | 502c9df7b2614d20541c7a2d39d390e9a51877cc /contrib/python/pytest/py3/_pytest/_code/source.py | |
parent | 6b813c17d56d1d05f92c61ddc347d0e4d358fe85 (diff) | |
download | ydb-8e1413fed79d1e8036e65228af6c93399ccf5502.tar.gz |
intermediate changes
ref:614ed510ddd3cdf86a8c5dbf19afd113397e0172
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/_code/source.py')
-rw-r--r-- | contrib/python/pytest/py3/_pytest/_code/source.py | 328 |
1 files changed, 62 insertions, 266 deletions
diff --git a/contrib/python/pytest/py3/_pytest/_code/source.py b/contrib/python/pytest/py3/_pytest/_code/source.py index 28c11e5d5e..6f54057c0a 100644 --- a/contrib/python/pytest/py3/_pytest/_code/source.py +++ b/contrib/python/pytest/py3/_pytest/_code/source.py @@ -1,76 +1,59 @@ import ast import inspect -import linecache -import sys import textwrap import tokenize +import types import warnings from bisect import bisect_right -from types import CodeType -from types import FrameType -from typing import Any +from typing import Iterable from typing import Iterator from typing import List from typing import Optional -from typing import Sequence +from typing import overload from typing import Tuple from typing import Union -import py - -from _pytest.compat import get_real_func -from _pytest.compat import overload -from _pytest.compat import TYPE_CHECKING - -if TYPE_CHECKING: - from typing_extensions import Literal - class Source: - """ an immutable object holding a source code fragment, - possibly deindenting it. + """An immutable object holding a source code fragment. + + When using Source(...), the source lines are deindented. """ - _compilecounter = 0 - - def __init__(self, *parts, **kwargs) -> None: - self.lines = lines = [] # type: List[str] - de = kwargs.get("deindent", True) - for part in parts: - if not part: - partlines = [] # type: List[str] - elif isinstance(part, Source): - partlines = part.lines - elif isinstance(part, (tuple, list)): - partlines = [x.rstrip("\n") for x in part] - elif isinstance(part, str): - partlines = part.split("\n") - else: - partlines = getsource(part, deindent=de).lines - if de: - partlines = deindent(partlines) - lines.extend(partlines) - - def __eq__(self, other): - try: - return self.lines == other.lines - except AttributeError: - if isinstance(other, str): - return str(self) == other - return False + def __init__(self, obj: object = None) -> None: + if not obj: + self.lines: List[str] = [] + elif isinstance(obj, Source): + self.lines = obj.lines + elif isinstance(obj, (tuple, list)): + self.lines = deindent(x.rstrip("\n") for x in obj) + elif isinstance(obj, str): + self.lines = deindent(obj.split("\n")) + else: + try: + rawcode = getrawcode(obj) + src = inspect.getsource(rawcode) + except TypeError: + src = inspect.getsource(obj) # type: ignore[arg-type] + self.lines = deindent(src.split("\n")) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Source): + return NotImplemented + return self.lines == other.lines # Ignore type because of https://github.com/python/mypy/issues/4266. __hash__ = None # type: ignore @overload def __getitem__(self, key: int) -> str: - raise NotImplementedError() + ... - @overload # noqa: F811 - def __getitem__(self, key: slice) -> "Source": # noqa: F811 - raise NotImplementedError() + @overload + def __getitem__(self, key: slice) -> "Source": + ... - def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811 + def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: if isinstance(key, int): return self.lines[key] else: @@ -87,9 +70,7 @@ class Source: return len(self.lines) def strip(self) -> "Source": - """ return new source object with trailing - and leading blank lines removed. - """ + """Return new Source object with trailing and leading blank lines removed.""" start, end = 0, len(self) while start < end and not self.lines[start].strip(): start += 1 @@ -99,220 +80,36 @@ class Source: source.lines[:] = self.lines[start:end] return source - def putaround( - self, before: str = "", after: str = "", indent: str = " " * 4 - ) -> "Source": - """ return a copy of the source object with - 'before' and 'after' wrapped around it. - """ - beforesource = Source(before) - aftersource = Source(after) - newsource = Source() - lines = [(indent + line) for line in self.lines] - newsource.lines = beforesource.lines + lines + aftersource.lines - return newsource - def indent(self, indent: str = " " * 4) -> "Source": - """ return a copy of the source object with - all lines indented by the given indent-string. - """ + """Return a copy of the source object with all lines indented by the + given indent-string.""" newsource = Source() newsource.lines = [(indent + line) for line in self.lines] return newsource def getstatement(self, lineno: int) -> "Source": - """ return Source statement which contains the - given linenumber (counted from 0). - """ + """Return Source statement which contains the given linenumber + (counted from 0).""" start, end = self.getstatementrange(lineno) return self[start:end] def getstatementrange(self, lineno: int) -> Tuple[int, int]: - """ return (start, end) tuple which spans the minimal - statement region which containing the given lineno. - """ + """Return (start, end) tuple which spans the minimal statement region + which containing the given lineno.""" if not (0 <= lineno < len(self)): raise IndexError("lineno out of range") ast, start, end = getstatementrange_ast(lineno, self) return start, end def deindent(self) -> "Source": - """return a new source object deindented.""" + """Return a new Source object deindented.""" newsource = Source() newsource.lines[:] = deindent(self.lines) return newsource - def isparseable(self, deindent: bool = True) -> bool: - """ return True if source is parseable, heuristically - deindenting it by default. - """ - if deindent: - source = str(self.deindent()) - else: - source = str(self) - try: - ast.parse(source) - except (SyntaxError, ValueError, TypeError): - return False - else: - return True - def __str__(self) -> str: return "\n".join(self.lines) - @overload - def compile( - self, - filename: Optional[str] = ..., - mode: str = ..., - flag: "Literal[0]" = ..., - dont_inherit: int = ..., - _genframe: Optional[FrameType] = ..., - ) -> CodeType: - raise NotImplementedError() - - @overload # noqa: F811 - def compile( # noqa: F811 - self, - filename: Optional[str] = ..., - mode: str = ..., - flag: int = ..., - dont_inherit: int = ..., - _genframe: Optional[FrameType] = ..., - ) -> Union[CodeType, ast.AST]: - raise NotImplementedError() - - def compile( # noqa: F811 - self, - filename: Optional[str] = None, - mode: str = "exec", - flag: int = 0, - dont_inherit: int = 0, - _genframe: Optional[FrameType] = None, - ) -> Union[CodeType, ast.AST]: - """ return compiled code object. if filename is None - invent an artificial filename which displays - the source/line position of the caller frame. - """ - if not filename or py.path.local(filename).check(file=0): - if _genframe is None: - _genframe = sys._getframe(1) # the caller - fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno - base = "<%d-codegen " % self._compilecounter - self.__class__._compilecounter += 1 - if not filename: - filename = base + "%s:%d>" % (fn, lineno) - else: - filename = base + "%r %s:%d>" % (filename, fn, lineno) - source = "\n".join(self.lines) + "\n" - try: - co = compile(source, filename, mode, flag) - except SyntaxError as ex: - # re-represent syntax errors from parsing python strings - msglines = self.lines[: ex.lineno] - if ex.offset: - msglines.append(" " * ex.offset + "^") - msglines.append("(code was compiled probably from here: %s)" % filename) - newex = SyntaxError("\n".join(msglines)) - newex.offset = ex.offset - newex.lineno = ex.lineno - newex.text = ex.text - raise newex - else: - if flag & ast.PyCF_ONLY_AST: - assert isinstance(co, ast.AST) - return co - assert isinstance(co, CodeType) - lines = [(x + "\n") for x in self.lines] - # Type ignored because linecache.cache is private. - linecache.cache[filename] = (1, None, lines, filename) # type: ignore - return co - - -# -# public API shortcut functions -# - - -@overload -def compile_( - source: Union[str, bytes, ast.mod, ast.AST], - filename: Optional[str] = ..., - mode: str = ..., - flags: "Literal[0]" = ..., - dont_inherit: int = ..., -) -> CodeType: - raise NotImplementedError() - - -@overload # noqa: F811 -def compile_( # noqa: F811 - source: Union[str, bytes, ast.mod, ast.AST], - filename: Optional[str] = ..., - mode: str = ..., - flags: int = ..., - dont_inherit: int = ..., -) -> Union[CodeType, ast.AST]: - raise NotImplementedError() - - -def compile_( # noqa: F811 - source: Union[str, bytes, ast.mod, ast.AST], - filename: Optional[str] = None, - mode: str = "exec", - flags: int = 0, - dont_inherit: int = 0, -) -> Union[CodeType, ast.AST]: - """ compile the given source to a raw code object, - and maintain an internal cache which allows later - retrieval of the source code for the code object - and any recursively created code objects. - """ - if isinstance(source, ast.AST): - # XXX should Source support having AST? - assert filename is not None - co = compile(source, filename, mode, flags, dont_inherit) - assert isinstance(co, (CodeType, ast.AST)) - return co - _genframe = sys._getframe(1) # the caller - s = Source(source) - return s.compile(filename, mode, flags, _genframe=_genframe) - - -def getfslineno(obj: Any) -> Tuple[Union[str, py.path.local], int]: - """ Return source location (path, lineno) for the given object. - If the source cannot be determined return ("", -1). - - The line number is 0-based. - """ - from .code import Code - - # xxx let decorators etc specify a sane ordering - # NOTE: this used to be done in _pytest.compat.getfslineno, initially added - # in 6ec13a2b9. It ("place_as") appears to be something very custom. - obj = get_real_func(obj) - if hasattr(obj, "place_as"): - obj = obj.place_as - - try: - code = Code(obj) - except TypeError: - try: - fn = inspect.getsourcefile(obj) or inspect.getfile(obj) - except TypeError: - return "", -1 - - fspath = fn and py.path.local(fn) or "" - lineno = -1 - if fspath: - try: - _, lineno = findsource(obj) - except IOError: - pass - return fspath, lineno - else: - return code.path, code.firstlineno - # # helper functions @@ -329,35 +126,34 @@ def findsource(obj) -> Tuple[Optional[Source], int]: return source, lineno -def getsource(obj, **kwargs) -> Source: - from .code import getrawcode - - obj = getrawcode(obj) +def getrawcode(obj: object, trycall: bool = True) -> types.CodeType: + """Return code object for given function.""" try: - strsrc = inspect.getsource(obj) - except IndentationError: - strsrc = '"Buggy python version consider upgrading, cannot get source"' - assert isinstance(strsrc, str) - return Source(strsrc, **kwargs) + return obj.__code__ # type: ignore[attr-defined,no-any-return] + except AttributeError: + pass + if trycall: + call = getattr(obj, "__call__", None) + if call and not isinstance(obj, type): + return getrawcode(call, trycall=False) + raise TypeError(f"could not get code object for {obj!r}") -def deindent(lines: Sequence[str]) -> List[str]: +def deindent(lines: Iterable[str]) -> List[str]: return textwrap.dedent("\n".join(lines)).splitlines() def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]: - import ast - - # flatten all statements and except handlers into one lineno-list - # AST's line numbers start indexing at 1 - values = [] # type: List[int] + # Flatten all statements and except handlers into one lineno-list. + # AST's line numbers start indexing at 1. + values: List[int] = [] for x in ast.walk(node): if isinstance(x, (ast.stmt, ast.ExceptHandler)): values.append(x.lineno - 1) for name in ("finalbody", "orelse"): - val = getattr(x, name, None) # type: Optional[List[ast.stmt]] + val: Optional[List[ast.stmt]] = getattr(x, name, None) if val: - # treat the finally/orelse part as its own statement + # Treat the finally/orelse part as its own statement. values.append(val[0].lineno - 1 - 1) values.sort() insert_index = bisect_right(values, lineno) @@ -378,13 +174,13 @@ def getstatementrange_ast( if astnode is None: content = str(source) # See #4260: - # don't produce duplicate warnings when compiling source to find ast + # Don't produce duplicate warnings when compiling source to find AST. with warnings.catch_warnings(): warnings.simplefilter("ignore") astnode = ast.parse(content, "source", "exec") start, end = get_statement_startend2(lineno, astnode) - # we need to correct the end: + # We need to correct the end: # - ast-parsing strips comments # - there might be empty lines # - we might have lesser indented code blocks at the end @@ -392,10 +188,10 @@ def getstatementrange_ast( end = len(source.lines) if end > start + 1: - # make sure we don't span differently indented code blocks - # by using the BlockFinder helper used which inspect.getsource() uses itself + # Make sure we don't span differently indented code blocks + # by using the BlockFinder helper used which inspect.getsource() uses itself. block_finder = inspect.BlockFinder() - # if we start with an indented line, put blockfinder to "started" mode + # If we start with an indented line, put blockfinder to "started" mode. block_finder.started = source.lines[start][0].isspace() it = ((x + "\n") for x in source.lines[start:end]) try: @@ -406,7 +202,7 @@ def getstatementrange_ast( except Exception: pass - # the end might still point to a comment or empty line, correct it + # The end might still point to a comment or empty line, correct it. while end: line = source.lines[end - 1].lstrip() if line.startswith("#") or not line: |