diff options
| author | arcadia-devtools <[email protected]> | 2022-02-09 12:00:52 +0300 |
|---|---|---|
| committer | Daniil Cherednik <[email protected]> | 2022-02-10 15:58:17 +0300 |
| commit | 8e1413fed79d1e8036e65228af6c93399ccf5502 (patch) | |
| tree | 502c9df7b2614d20541c7a2d39d390e9a51877cc /contrib/python/pytest/py3/_pytest/_code | |
| parent | 6b813c17d56d1d05f92c61ddc347d0e4d358fe85 (diff) | |
intermediate changes
ref:614ed510ddd3cdf86a8c5dbf19afd113397e0172
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/_code')
| -rw-r--r-- | contrib/python/pytest/py3/_pytest/_code/__init__.py | 32 | ||||
| -rw-r--r-- | contrib/python/pytest/py3/_pytest/_code/code.py | 656 | ||||
| -rw-r--r-- | contrib/python/pytest/py3/_pytest/_code/source.py | 328 |
3 files changed, 446 insertions, 570 deletions
diff --git a/contrib/python/pytest/py3/_pytest/_code/__init__.py b/contrib/python/pytest/py3/_pytest/_code/__init__.py index 370e41dc9f3..511d0dde661 100644 --- a/contrib/python/pytest/py3/_pytest/_code/__init__.py +++ b/contrib/python/pytest/py3/_pytest/_code/__init__.py @@ -1,10 +1,22 @@ -""" python inspection/code generation API """ -from .code import Code # noqa -from .code import ExceptionInfo # noqa -from .code import filter_traceback # noqa -from .code import Frame # noqa -from .code import getrawcode # noqa -from .code import Traceback # noqa -from .source import compile_ as compile # noqa -from .source import getfslineno # noqa -from .source import Source # noqa +"""Python inspection/code generation API.""" +from .code import Code +from .code import ExceptionInfo +from .code import filter_traceback +from .code import Frame +from .code import getfslineno +from .code import Traceback +from .code import TracebackEntry +from .source import getrawcode +from .source import Source + +__all__ = [ + "Code", + "ExceptionInfo", + "filter_traceback", + "Frame", + "getfslineno", + "getrawcode", + "Traceback", + "TracebackEntry", + "Source", +] diff --git a/contrib/python/pytest/py3/_pytest/_code/code.py b/contrib/python/pytest/py3/_pytest/_code/code.py index 965074c924b..423069330a5 100644 --- a/contrib/python/pytest/py3/_pytest/_code/code.py +++ b/contrib/python/pytest/py3/_pytest/_code/code.py @@ -5,6 +5,7 @@ import traceback from inspect import CO_VARARGS from inspect import CO_VARKEYWORDS from io import StringIO +from pathlib import Path from traceback import format_exception_only from types import CodeType from types import FrameType @@ -15,11 +16,15 @@ from typing import Dict from typing import Generic from typing import Iterable from typing import List +from typing import Mapping from typing import Optional +from typing import overload from typing import Pattern from typing import Sequence from typing import Set from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from weakref import ref @@ -29,35 +34,34 @@ import pluggy import py import _pytest +from _pytest._code.source import findsource +from _pytest._code.source import getrawcode +from _pytest._code.source import getstatementrange_ast +from _pytest._code.source import Source from _pytest._io import TerminalWriter from _pytest._io.saferepr import safeformat from _pytest._io.saferepr import saferepr -from _pytest.compat import ATTRS_EQ_FIELD -from _pytest.compat import overload -from _pytest.compat import TYPE_CHECKING +from _pytest.compat import final +from _pytest.compat import get_real_func if TYPE_CHECKING: - from typing import Type from typing_extensions import Literal - from weakref import ReferenceType # noqa: F401 + from weakref import ReferenceType - from _pytest._code import Source - - _TracebackStyle = Literal["long", "short", "line", "no", "native"] + _TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"] class Code: - """ wrapper around Python code objects """ - - def __init__(self, rawcode) -> None: - if not hasattr(rawcode, "co_filename"): - rawcode = getrawcode(rawcode) - if not isinstance(rawcode, CodeType): - raise TypeError("not a code object: {!r}".format(rawcode)) - self.filename = rawcode.co_filename - self.firstlineno = rawcode.co_firstlineno - 1 - self.name = rawcode.co_name - self.raw = rawcode + """Wrapper around Python code objects.""" + + __slots__ = ("raw",) + + def __init__(self, obj: CodeType) -> None: + self.raw = obj + + @classmethod + def from_function(cls, obj: object) -> "Code": + return cls(getrawcode(obj)) def __eq__(self, other): return self.raw == other.raw @@ -65,14 +69,18 @@ class Code: # Ignore type because of https://github.com/python/mypy/issues/4266. __hash__ = None # type: ignore - def __ne__(self, other): - return not self == other + @property + def firstlineno(self) -> int: + return self.raw.co_firstlineno - 1 + + @property + def name(self) -> str: + return self.raw.co_name @property def path(self) -> Union[py.path.local, str]: - """ return a path object pointing to source code (or a str in case - of OSError / non-existing file). - """ + """Return a path object pointing to source code, or an ``str`` in + case of ``OSError`` / non-existing file.""" if not self.raw.co_filename: return "" try: @@ -88,28 +96,22 @@ class Code: @property def fullsource(self) -> Optional["Source"]: - """ return a _pytest._code.Source object for the full source file of the code - """ - from _pytest._code import source - - full, _ = source.findsource(self.raw) + """Return a _pytest._code.Source object for the full source file of the code.""" + full, _ = findsource(self.raw) return full def source(self) -> "Source": - """ return a _pytest._code.Source object for the code object's source only - """ + """Return a _pytest._code.Source object for the code object's source only.""" # return source only for that part of code - import _pytest._code - - return _pytest._code.Source(self.raw) + return Source(self.raw) def getargs(self, var: bool = False) -> Tuple[str, ...]: - """ return a tuple with the argument names for the code object + """Return a tuple with the argument names for the code object. - if 'var' is set True also return the names of the variable and - keyword arguments when present + If 'var' is set True also return the names of the variable and + keyword arguments when present. """ - # handfull shortcut for getting args + # Handy shortcut for getting args. raw = self.raw argcount = raw.co_argcount if var: @@ -122,55 +124,54 @@ class Frame: """Wrapper around a Python frame holding f_locals and f_globals in which expressions can be evaluated.""" + __slots__ = ("raw",) + def __init__(self, frame: FrameType) -> None: - self.lineno = frame.f_lineno - 1 - self.f_globals = frame.f_globals - self.f_locals = frame.f_locals self.raw = frame - self.code = Code(frame.f_code) @property - def statement(self) -> "Source": - """ statement this frame is at """ - import _pytest._code + def lineno(self) -> int: + return self.raw.f_lineno - 1 + + @property + def f_globals(self) -> Dict[str, Any]: + return self.raw.f_globals + + @property + def f_locals(self) -> Dict[str, Any]: + return self.raw.f_locals + + @property + def code(self) -> Code: + return Code(self.raw.f_code) + @property + def statement(self) -> "Source": + """Statement this frame is at.""" if self.code.fullsource is None: - return _pytest._code.Source("") + return Source("") return self.code.fullsource.getstatement(self.lineno) def eval(self, code, **vars): - """ evaluate 'code' in the frame + """Evaluate 'code' in the frame. - 'vars' are optional additional local variables + 'vars' are optional additional local variables. - returns the result of the evaluation + Returns the result of the evaluation. """ f_locals = self.f_locals.copy() f_locals.update(vars) return eval(code, self.f_globals, f_locals) - def exec_(self, code, **vars) -> None: - """ exec 'code' in the frame - - 'vars' are optional; additional local variables - """ - f_locals = self.f_locals.copy() - f_locals.update(vars) - exec(code, self.f_globals, f_locals) - def repr(self, object: object) -> str: - """ return a 'safe' (non-recursive, one-line) string repr for 'object' - """ + """Return a 'safe' (non-recursive, one-line) string repr for 'object'.""" return saferepr(object) - def is_true(self, object): - return object - def getargs(self, var: bool = False): - """ return a list of tuples (name, value) for all arguments + """Return a list of tuples (name, value) for all arguments. - if 'var' is set True also include the variable and keyword - arguments when present + If 'var' is set True, also include the variable and keyword arguments + when present. """ retval = [] for arg in self.code.getargs(var): @@ -182,15 +183,22 @@ class Frame: class TracebackEntry: - """ a single entry in a traceback """ + """A single entry in a Traceback.""" - _repr_style = None # type: Optional[Literal["short", "long"]] - exprinfo = None + __slots__ = ("_rawentry", "_excinfo", "_repr_style") - def __init__(self, rawentry: TracebackType, excinfo=None) -> None: - self._excinfo = excinfo + def __init__( + self, + rawentry: TracebackType, + excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None, + ) -> None: self._rawentry = rawentry - self.lineno = rawentry.tb_lineno - 1 + self._excinfo = excinfo + self._repr_style: Optional['Literal["short", "long"]'] = None + + @property + def lineno(self) -> int: + return self._rawentry.tb_lineno - 1 def set_repr_style(self, mode: "Literal['short', 'long']") -> None: assert mode in ("short", "long") @@ -209,30 +217,28 @@ class TracebackEntry: @property def statement(self) -> "Source": - """ _pytest._code.Source object for the current statement """ + """_pytest._code.Source object for the current statement.""" source = self.frame.code.fullsource assert source is not None return source.getstatement(self.lineno) @property - def path(self): - """ path to the source code """ + def path(self) -> Union[py.path.local, str]: + """Path to the source code.""" return self.frame.code.path @property def locals(self) -> Dict[str, Any]: - """ locals of underlying frame """ + """Locals of underlying frame.""" return self.frame.f_locals def getfirstlinesource(self) -> int: return self.frame.code.firstlineno def getsource(self, astcache=None) -> Optional["Source"]: - """ return failing source code. """ + """Return failing source code.""" # we use the passed in astcache to not reparse asttrees # within exception info printing - from _pytest._code.source import getstatementrange_ast - source = self.frame.code.fullsource if source is None: return None @@ -255,59 +261,71 @@ class TracebackEntry: source = property(getsource) - def ishidden(self): - """ return True if the current frame has a var __tracebackhide__ - resolving to True. + def ishidden(self) -> bool: + """Return True if the current frame has a var __tracebackhide__ + resolving to True. - If __tracebackhide__ is a callable, it gets called with the - ExceptionInfo instance and can decide whether to hide the traceback. + If __tracebackhide__ is a callable, it gets called with the + ExceptionInfo instance and can decide whether to hide the traceback. - mostly for internal use + Mostly for internal use. """ - f = self.frame - tbh = f.f_locals.get( - "__tracebackhide__", f.f_globals.get("__tracebackhide__", False) + tbh: Union[bool, Callable[[Optional[ExceptionInfo[BaseException]]], bool]] = ( + False ) + for maybe_ns_dct in (self.frame.f_locals, self.frame.f_globals): + # in normal cases, f_locals and f_globals are dictionaries + # however via `exec(...)` / `eval(...)` they can be other types + # (even incorrect types!). + # as such, we suppress all exceptions while accessing __tracebackhide__ + try: + tbh = maybe_ns_dct["__tracebackhide__"] + except Exception: + pass + else: + break if tbh and callable(tbh): return tbh(None if self._excinfo is None else self._excinfo()) return tbh def __str__(self) -> str: - try: - fn = str(self.path) - except py.error.Error: - fn = "???" name = self.frame.code.name try: line = str(self.statement).lstrip() except KeyboardInterrupt: raise - except: # noqa + except BaseException: line = "???" - return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line) + # This output does not quite match Python's repr for traceback entries, + # but changing it to do so would break certain plugins. See + # https://github.com/pytest-dev/pytest/pull/7535/ for details. + return " File %r:%d in %s\n %s\n" % ( + str(self.path), + self.lineno + 1, + name, + line, + ) @property def name(self) -> str: - """ co_name of underlying code """ + """co_name of underlying code.""" return self.frame.code.raw.co_name class Traceback(List[TracebackEntry]): - """ Traceback objects encapsulate and offer higher level - access to Traceback entries. - """ + """Traceback objects encapsulate and offer higher level access to Traceback entries.""" def __init__( self, tb: Union[TracebackType, Iterable[TracebackEntry]], - excinfo: Optional["ReferenceType[ExceptionInfo]"] = None, + excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None, ) -> None: - """ initialize from given python traceback object and ExceptionInfo """ + """Initialize from given python traceback object and ExceptionInfo.""" self._excinfo = excinfo if isinstance(tb, TracebackType): def f(cur: TracebackType) -> Iterable[TracebackEntry]: - cur_ = cur # type: Optional[TracebackType] + cur_: Optional[TracebackType] = cur while cur_ is not None: yield TracebackEntry(cur_, excinfo=excinfo) cur_ = cur_.tb_next @@ -321,16 +339,16 @@ class Traceback(List[TracebackEntry]): path=None, lineno: Optional[int] = None, firstlineno: Optional[int] = None, - excludepath=None, + excludepath: Optional[py.path.local] = None, ) -> "Traceback": - """ return a Traceback instance wrapping part of this Traceback + """Return a Traceback instance wrapping part of this Traceback. - by providing any combination of path, lineno and firstlineno, the - first frame to start the to-be-returned traceback is determined + By providing any combination of path, lineno and firstlineno, the + first frame to start the to-be-returned traceback is determined. - this allows cutting the first part of a Traceback instance e.g. - for formatting reasons (removing some uninteresting bits that deal - with handling of the exception/traceback) + This allows cutting the first part of a Traceback instance e.g. + for formatting reasons (removing some uninteresting bits that deal + with handling of the exception/traceback). """ for x in self: code = x.frame.code @@ -350,15 +368,13 @@ class Traceback(List[TracebackEntry]): @overload def __getitem__(self, key: int) -> TracebackEntry: - raise NotImplementedError() + ... - @overload # noqa: F811 - def __getitem__(self, key: slice) -> "Traceback": # noqa: F811 - raise NotImplementedError() + @overload + def __getitem__(self, key: slice) -> "Traceback": + ... - def __getitem__( # noqa: F811 - self, key: Union[int, slice] - ) -> Union[TracebackEntry, "Traceback"]: + def __getitem__(self, key: Union[int, slice]) -> Union[TracebackEntry, "Traceback"]: if isinstance(key, slice): return self.__class__(super().__getitem__(key)) else: @@ -367,21 +383,19 @@ class Traceback(List[TracebackEntry]): def filter( self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden() ) -> "Traceback": - """ return a Traceback instance with certain items removed + """Return a Traceback instance with certain items removed - fn is a function that gets a single argument, a TracebackEntry - instance, and should return True when the item should be added - to the Traceback, False when not + fn is a function that gets a single argument, a TracebackEntry + instance, and should return True when the item should be added + to the Traceback, False when not. - by default this removes all the TracebackEntries which are hidden - (see ishidden() above) + By default this removes all the TracebackEntries which are hidden + (see ishidden() above). """ return Traceback(filter(fn, self), self._excinfo) def getcrashentry(self) -> TracebackEntry: - """ return last non-hidden traceback entry that lead - to the exception of a traceback. - """ + """Return last non-hidden traceback entry that lead to the exception of a traceback.""" for i in range(-1, -len(self) - 1, -1): entry = self[i] if not entry.ishidden(): @@ -389,10 +403,9 @@ class Traceback(List[TracebackEntry]): return self[-1] def recursionindex(self) -> Optional[int]: - """ return the index of the frame/TracebackEntry where recursion - originates if appropriate, None if no recursion occurred - """ - cache = {} # type: Dict[Tuple[Any, int, int], List[Dict[str, Any]]] + """Return the index of the frame/TracebackEntry where recursion originates if + appropriate, None if no recursion occurred.""" + cache: Dict[Tuple[Any, int, int], List[Dict[str, Any]]] = {} for i, entry in enumerate(self): # id for the code.raw is needed to work around # the strange metaprogramming in the decorator lib from pypi @@ -405,12 +418,10 @@ class Traceback(List[TracebackEntry]): f = entry.frame loc = f.f_locals for otherloc in values: - if f.is_true( - f.eval( - co_equal, - __recursioncache_locals_1=loc, - __recursioncache_locals_2=otherloc, - ) + if f.eval( + co_equal, + __recursioncache_locals_1=loc, + __recursioncache_locals_2=otherloc, ): return i values.append(entry.frame.f_locals) @@ -422,37 +433,36 @@ co_equal = compile( ) -_E = TypeVar("_E", bound=BaseException) +_E = TypeVar("_E", bound=BaseException, covariant=True) +@final @attr.s(repr=False) class ExceptionInfo(Generic[_E]): - """ wraps sys.exc_info() objects and offers - help for navigating the traceback. - """ + """Wraps sys.exc_info() objects and offers help for navigating the traceback.""" _assert_start_repr = "AssertionError('assert " - _excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]]) + _excinfo = attr.ib(type=Optional[Tuple[Type["_E"], "_E", TracebackType]]) _striptext = attr.ib(type=str, default="") _traceback = attr.ib(type=Optional[Traceback], default=None) @classmethod def from_exc_info( cls, - exc_info: Tuple["Type[_E]", "_E", TracebackType], + exc_info: Tuple[Type[_E], _E, TracebackType], exprinfo: Optional[str] = None, ) -> "ExceptionInfo[_E]": - """returns an ExceptionInfo for an existing exc_info tuple. + """Return an ExceptionInfo for an existing exc_info tuple. .. warning:: Experimental API - - :param exprinfo: a text string helping to determine if we should - strip ``AssertionError`` from the output, defaults - to the exception message/``__str__()`` + :param exprinfo: + A text string helping to determine if we should strip + ``AssertionError`` from the output. Defaults to the exception + message/``__str__()``. """ _striptext = "" if exprinfo is None and isinstance(exc_info[1], AssertionError): @@ -468,16 +478,16 @@ class ExceptionInfo(Generic[_E]): def from_current( cls, exprinfo: Optional[str] = None ) -> "ExceptionInfo[BaseException]": - """returns an ExceptionInfo matching the current traceback + """Return an ExceptionInfo matching the current traceback. .. warning:: Experimental API - - :param exprinfo: a text string helping to determine if we should - strip ``AssertionError`` from the output, defaults - to the exception message/``__str__()`` + :param exprinfo: + A text string helping to determine if we should strip + ``AssertionError`` from the output. Defaults to the exception + message/``__str__()``. """ tup = sys.exc_info() assert tup[0] is not None, "no current exception" @@ -488,18 +498,17 @@ class ExceptionInfo(Generic[_E]): @classmethod def for_later(cls) -> "ExceptionInfo[_E]": - """return an unfilled ExceptionInfo - """ + """Return an unfilled ExceptionInfo.""" return cls(None) - def fill_unfilled(self, exc_info: Tuple["Type[_E]", _E, TracebackType]) -> None: - """fill an unfilled ExceptionInfo created with for_later()""" + def fill_unfilled(self, exc_info: Tuple[Type[_E], _E, TracebackType]) -> None: + """Fill an unfilled ExceptionInfo created with ``for_later()``.""" assert self._excinfo is None, "ExceptionInfo was already filled" self._excinfo = exc_info @property - def type(self) -> "Type[_E]": - """the exception class""" + def type(self) -> Type[_E]: + """The exception class.""" assert ( self._excinfo is not None ), ".type can only be used after the context manager exits" @@ -507,7 +516,7 @@ class ExceptionInfo(Generic[_E]): @property def value(self) -> _E: - """the exception value""" + """The exception value.""" assert ( self._excinfo is not None ), ".value can only be used after the context manager exits" @@ -515,7 +524,7 @@ class ExceptionInfo(Generic[_E]): @property def tb(self) -> TracebackType: - """the exception raw traceback""" + """The exception raw traceback.""" assert ( self._excinfo is not None ), ".tb can only be used after the context manager exits" @@ -523,7 +532,7 @@ class ExceptionInfo(Generic[_E]): @property def typename(self) -> str: - """the type name of the exception""" + """The type name of the exception.""" assert ( self._excinfo is not None ), ".typename can only be used after the context manager exits" @@ -531,7 +540,7 @@ class ExceptionInfo(Generic[_E]): @property def traceback(self) -> Traceback: - """the traceback""" + """The traceback.""" if self._traceback is None: self._traceback = Traceback(self.tb, excinfo=ref(self)) return self._traceback @@ -548,12 +557,12 @@ class ExceptionInfo(Generic[_E]): ) def exconly(self, tryshort: bool = False) -> str: - """ return the exception as a string + """Return the exception as a string. - when 'tryshort' resolves to True, and the exception is a - _pytest._code._AssertionError, only the actual exception part of - the exception representation is returned (so 'AssertionError: ' is - removed from the beginning) + When 'tryshort' resolves to True, and the exception is a + _pytest._code._AssertionError, only the actual exception part of + the exception representation is returned (so 'AssertionError: ' is + removed from the beginning). """ lines = format_exception_only(self.type, self.value) text = "".join(lines) @@ -564,9 +573,12 @@ class ExceptionInfo(Generic[_E]): return text def errisinstance( - self, exc: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]] + self, exc: Union[Type[BaseException], Tuple[Type[BaseException], ...]] ) -> bool: - """ return True if the exception is an instance of exc """ + """Return True if the exception is an instance of exc. + + Consider using ``isinstance(excinfo.value, exc)`` instead. + """ return isinstance(self.value, exc) def _getreprcrash(self) -> "ReprFileLocation": @@ -585,14 +597,14 @@ class ExceptionInfo(Generic[_E]): truncate_locals: bool = True, chain: bool = True, ) -> Union["ReprExceptionInfo", "ExceptionChainRepr"]: - """ - Return str()able representation of this exception info. + """Return str()able representation of this exception info. :param bool showlocals: Show locals per traceback entry. Ignored if ``style=="native"``. - :param str style: long|short|no|native traceback style + :param str style: + long|short|no|native|value traceback style. :param bool abspath: If paths should be changed to absolute or left unchanged. @@ -607,7 +619,8 @@ class ExceptionInfo(Generic[_E]): :param bool truncate_locals: With ``showlocals==True``, make sure locals can be safely represented as strings. - :param bool chain: if chained exceptions in Python 3 should be shown. + :param bool chain: + If chained exceptions in Python 3 should be shown. .. versionchanged:: 3.9 @@ -634,24 +647,24 @@ class ExceptionInfo(Generic[_E]): ) return fmt.repr_excinfo(self) - def match(self, regexp: "Union[str, Pattern]") -> "Literal[True]": - """ - Check whether the regular expression `regexp` matches the string + def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]": + """Check whether the regular expression `regexp` matches the string representation of the exception using :func:`python:re.search`. - If it matches `True` is returned. - If it doesn't match an `AssertionError` is raised. + + If it matches `True` is returned, otherwise an `AssertionError` is raised. """ __tracebackhide__ = True - assert re.search( - regexp, str(self.value) - ), "Pattern {!r} does not match {!r}".format(regexp, str(self.value)) + msg = "Regex pattern {!r} does not match {!r}." + if regexp == str(self.value): + msg += " Did you mean to `re.escape()` the regex?" + assert re.search(regexp, str(self.value)), msg.format(regexp, str(self.value)) # Return True to allow for "assert excinfo.match()". return True @attr.s class FormattedExcinfo: - """ presenting information about failing Functions and Generators. """ + """Presenting information about failing Functions and Generators.""" # for traceback entries flow_marker = ">" @@ -667,17 +680,17 @@ class FormattedExcinfo: astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False) def _getindent(self, source: "Source") -> int: - # figure out indent for given source + # Figure out indent for the given source. try: s = str(source.getstatement(len(source) - 1)) except KeyboardInterrupt: raise - except: # noqa + except BaseException: try: s = str(source[-1]) except KeyboardInterrupt: raise - except: # noqa + except BaseException: return 0 return 4 + (len(s) - len(s.lstrip())) @@ -697,17 +710,15 @@ class FormattedExcinfo: def get_source( self, - source: "Source", + source: Optional["Source"], line_index: int = -1, - excinfo: Optional[ExceptionInfo] = None, + excinfo: Optional[ExceptionInfo[BaseException]] = None, short: bool = False, ) -> List[str]: - """ return formatted and marked up source lines. """ - import _pytest._code - + """Return formatted and marked up source lines.""" lines = [] if source is None or line_index >= len(source.lines): - source = _pytest._code.Source("???") + source = Source("???") line_index = 0 if line_index < 0: line_index += len(source) @@ -726,11 +737,14 @@ class FormattedExcinfo: return lines def get_exconly( - self, excinfo: ExceptionInfo, indent: int = 4, markall: bool = False + self, + excinfo: ExceptionInfo[BaseException], + indent: int = 4, + markall: bool = False, ) -> List[str]: lines = [] indentstr = " " * indent - # get the real exception information out + # Get the real exception information out. exlines = excinfo.exconly(tryshort=True).split("\n") failindent = self.fail_marker + indentstr[1:] for line in exlines: @@ -739,7 +753,7 @@ class FormattedExcinfo: failindent = indentstr return lines - def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]: + def repr_locals(self, locals: Mapping[str, object]) -> Optional["ReprLocals"]: if self.showlocals: lines = [] keys = [loc for loc in locals if loc[0] != "@"] @@ -756,9 +770,8 @@ class FormattedExcinfo: str_repr = saferepr(value) else: str_repr = safeformat(value) - # if len(str_repr) < 70 or not isinstance(value, - # (list, tuple, dict)): - lines.append("{:<10} = {}".format(name, str_repr)) + # if len(str_repr) < 70 or not isinstance(value, (list, tuple, dict)): + lines.append(f"{name:<10} = {str_repr}") # else: # self._line("%-10s =\\" % (name,)) # # XXX @@ -767,20 +780,19 @@ class FormattedExcinfo: return None def repr_traceback_entry( - self, entry: TracebackEntry, excinfo: Optional[ExceptionInfo] = None + self, + entry: TracebackEntry, + excinfo: Optional[ExceptionInfo[BaseException]] = None, ) -> "ReprEntry": - import _pytest._code - - source = self._getentrysource(entry) - if source is None: - source = _pytest._code.Source("???") - line_index = 0 - else: - line_index = entry.lineno - entry.getfirstlinesource() - - lines = [] # type: List[str] + lines: List[str] = [] style = entry._repr_style if entry._repr_style is not None else self.style if style in ("short", "long"): + source = self._getentrysource(entry) + if source is None: + source = Source("???") + line_index = 0 + else: + line_index = entry.lineno - entry.getfirstlinesource() short = style == "short" reprargs = self.repr_args(entry) if not short else None s = self.get_source(source, line_index, excinfo, short=short) @@ -793,9 +805,14 @@ class FormattedExcinfo: reprfileloc = ReprFileLocation(path, entry.lineno + 1, message) localsrepr = self.repr_locals(entry.locals) return ReprEntry(lines, reprargs, localsrepr, reprfileloc, style) - if excinfo: - lines.extend(self.get_exconly(excinfo, indent=4)) - return ReprEntry(lines, None, None, None, style) + elif style == "value": + if excinfo: + lines.extend(str(excinfo.value).split("\n")) + return ReprEntry(lines, None, None, None, style) + else: + if excinfo: + lines.extend(self.get_exconly(excinfo, indent=4)) + return ReprEntry(lines, None, None, None, style) def _makepath(self, path): if not self.abspath: @@ -807,18 +824,23 @@ class FormattedExcinfo: path = np return path - def repr_traceback(self, excinfo: ExceptionInfo) -> "ReprTraceback": + def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> "ReprTraceback": traceback = excinfo.traceback if self.tbfilter: traceback = traceback.filter() - if excinfo.errisinstance(RecursionError): + if isinstance(excinfo.value, RecursionError): traceback, extraline = self._truncate_recursive_traceback(traceback) else: extraline = None last = traceback[-1] entries = [] + if self.style == "value": + reprentry = self.repr_traceback_entry(last, excinfo) + entries.append(reprentry) + return ReprTraceback(entries, None, style=self.style) + for index, entry in enumerate(traceback): einfo = (last == entry) and excinfo or None reprentry = self.repr_traceback_entry(entry, einfo) @@ -828,22 +850,23 @@ class FormattedExcinfo: def _truncate_recursive_traceback( self, traceback: Traceback ) -> Tuple[Traceback, Optional[str]]: - """ - Truncate the given recursive traceback trying to find the starting point - of the recursion. + """Truncate the given recursive traceback trying to find the starting + point of the recursion. - The detection is done by going through each traceback entry and finding the - point in which the locals of the frame are equal to the locals of a previous frame (see ``recursionindex()``. + The detection is done by going through each traceback entry and + finding the point in which the locals of the frame are equal to the + locals of a previous frame (see ``recursionindex()``). - Handle the situation where the recursion process might raise an exception (for example - comparing numpy arrays using equality raises a TypeError), in which case we do our best to - warn the user of the error and show a limited traceback. + Handle the situation where the recursion process might raise an + exception (for example comparing numpy arrays using equality raises a + TypeError), in which case we do our best to warn the user of the + error and show a limited traceback. """ try: recursionindex = traceback.recursionindex() except Exception as e: max_frames = 10 - extraline = ( + extraline: Optional[str] = ( "!!! Recursion error detected, but an error occurred locating the origin of recursion.\n" " The following exception happened when comparing locals in the stack frame:\n" " {exc_type}: {exc_msg}\n" @@ -853,7 +876,7 @@ class FormattedExcinfo: exc_msg=str(e), max_frames=max_frames, total=len(traceback), - ) # type: Optional[str] + ) # Type ignored because adding two instaces of a List subtype # currently incorrectly has type List instead of the subtype. traceback = traceback[:max_frames] + traceback[-max_frames:] # type: ignore @@ -866,22 +889,26 @@ class FormattedExcinfo: return traceback, extraline - def repr_excinfo(self, excinfo: ExceptionInfo) -> "ExceptionChainRepr": - repr_chain = ( - [] - ) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]] - e = excinfo.value - excinfo_ = excinfo # type: Optional[ExceptionInfo] + def repr_excinfo( + self, excinfo: ExceptionInfo[BaseException] + ) -> "ExceptionChainRepr": + repr_chain: List[ + Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]] + ] = [] + e: Optional[BaseException] = excinfo.value + excinfo_: Optional[ExceptionInfo[BaseException]] = excinfo descr = None - seen = set() # type: Set[int] + seen: Set[int] = set() while e is not None and id(e) not in seen: seen.add(id(e)) if excinfo_: reprtraceback = self.repr_traceback(excinfo_) - reprcrash = excinfo_._getreprcrash() # type: Optional[ReprFileLocation] + reprcrash: Optional[ReprFileLocation] = ( + excinfo_._getreprcrash() if self.style != "value" else None + ) else: - # fallback to native repr if the exception doesn't have a traceback: - # ExceptionInfo objects require a full traceback to work + # Fallback to native repr if the exception doesn't have a traceback: + # ExceptionInfo objects require a full traceback to work. reprtraceback = ReprTracebackNative( traceback.format_exception(type(e), e, None) ) @@ -912,7 +939,7 @@ class FormattedExcinfo: return ExceptionChainRepr(repr_chain) [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore [email protected](eq=False) class TerminalRepr: def __str__(self) -> str: # FYI this is called from pytest-xdist's serialization of exception @@ -929,10 +956,15 @@ class TerminalRepr: raise NotImplementedError() [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore +# This class is abstract -- only subclasses are instantiated. [email protected](eq=False) class ExceptionRepr(TerminalRepr): - def __attrs_post_init__(self): - self.sections = [] # type: List[Tuple[str, str, str]] + # Provided by subclasses. + reprcrash: Optional["ReprFileLocation"] + reprtraceback: "ReprTraceback" + + def __attrs_post_init__(self) -> None: + self.sections: List[Tuple[str, str, str]] = [] def addsection(self, name: str, content: str, sep: str = "-") -> None: self.sections.append((name, content, sep)) @@ -943,7 +975,7 @@ class ExceptionRepr(TerminalRepr): tw.line(content) [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore [email protected](eq=False) class ExceptionChainRepr(ExceptionRepr): chain = attr.ib( type=Sequence[ @@ -951,10 +983,10 @@ class ExceptionChainRepr(ExceptionRepr): ] ) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: super().__attrs_post_init__() # reprcrash and reprtraceback of the outermost (the newest) exception - # in the chain + # in the chain. self.reprtraceback = self.chain[-1][0] self.reprcrash = self.chain[-1][1] @@ -967,7 +999,7 @@ class ExceptionChainRepr(ExceptionRepr): super().toterminal(tw) [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore [email protected](eq=False) class ReprExceptionInfo(ExceptionRepr): reprtraceback = attr.ib(type="ReprTraceback") reprcrash = attr.ib(type="ReprFileLocation") @@ -977,7 +1009,7 @@ class ReprExceptionInfo(ExceptionRepr): super().toterminal(tw) [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore [email protected](eq=False) class ReprTraceback(TerminalRepr): reprentries = attr.ib(type=Sequence[Union["ReprEntry", "ReprEntryNative"]]) extraline = attr.ib(type=Optional[str]) @@ -986,7 +1018,7 @@ class ReprTraceback(TerminalRepr): entrysep = "_ " def toterminal(self, tw: TerminalWriter) -> None: - # the entries might have different styles + # The entries might have different styles. for i, entry in enumerate(self.reprentries): if entry.style == "long": tw.line("") @@ -1011,16 +1043,16 @@ class ReprTracebackNative(ReprTraceback): self.extraline = None [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore [email protected](eq=False) class ReprEntryNative(TerminalRepr): lines = attr.ib(type=Sequence[str]) - style = "native" # type: _TracebackStyle + style: "_TracebackStyle" = "native" def toterminal(self, tw: TerminalWriter) -> None: tw.write("".join(self.lines)) [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore [email protected](eq=False) class ReprEntry(TerminalRepr): lines = attr.ib(type=Sequence[str]) reprfuncargs = attr.ib(type=Optional["ReprFuncArgs"]) @@ -1029,7 +1061,7 @@ class ReprEntry(TerminalRepr): style = attr.ib(type="_TracebackStyle") def _write_entry_lines(self, tw: TerminalWriter) -> None: - """Writes the source code portions of a list of traceback entries with syntax highlighting. + """Write the source code portions of a list of traceback entries with syntax highlighting. Usually entries are lines like these: @@ -1042,28 +1074,34 @@ class ReprEntry(TerminalRepr): character, as doing so might break line continuations. """ - indent_size = 4 - - def is_fail(line): - return line.startswith("{} ".format(FormattedExcinfo.fail_marker)) - if not self.lines: return # separate indents and source lines that are not failures: we want to # highlight the code but not the indentation, which may contain markers # such as "> assert 0" - indents = [] - source_lines = [] - for line in self.lines: - if not is_fail(line): - indents.append(line[:indent_size]) - source_lines.append(line[indent_size:]) + fail_marker = f"{FormattedExcinfo.fail_marker} " + indent_size = len(fail_marker) + indents: List[str] = [] + source_lines: List[str] = [] + failure_lines: List[str] = [] + for index, line in enumerate(self.lines): + is_failure_line = line.startswith(fail_marker) + if is_failure_line: + # from this point on all lines are considered part of the failure + failure_lines.extend(self.lines[index:]) + break + else: + if self.style == "value": + source_lines.append(line) + else: + indents.append(line[:indent_size]) + source_lines.append(line[indent_size:]) tw._write_source(source_lines, indents) # failure lines are always completely red and bold - for line in (x for x in self.lines if is_fail(x)): + for line in failure_lines: tw.line(line, bold=True, red=True) def toterminal(self, tw: TerminalWriter) -> None: @@ -1094,24 +1132,24 @@ class ReprEntry(TerminalRepr): ) [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore [email protected](eq=False) class ReprFileLocation(TerminalRepr): path = attr.ib(type=str, converter=str) lineno = attr.ib(type=int) message = attr.ib(type=str) def toterminal(self, tw: TerminalWriter) -> None: - # filename and lineno output for each entry, - # using an output format that most editors understand + # Filename and lineno output for each entry, using an output format + # that most editors understand. msg = self.message i = msg.find("\n") if i != -1: msg = msg[:i] tw.write(self.path, bold=True, red=True) - tw.line(":{}: {}".format(self.lineno, msg)) + tw.line(f":{self.lineno}: {msg}") [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore [email protected](eq=False) class ReprLocals(TerminalRepr): lines = attr.ib(type=Sequence[str]) @@ -1120,7 +1158,7 @@ class ReprLocals(TerminalRepr): tw.line(indent + line) [email protected](**{ATTRS_EQ_FIELD: False}) # type: ignore [email protected](eq=False) class ReprFuncArgs(TerminalRepr): args = attr.ib(type=Sequence[Tuple[str, object]]) @@ -1128,7 +1166,7 @@ class ReprFuncArgs(TerminalRepr): if self.args: linesofar = "" for name, value in self.args: - ns = "{} = {}".format(name, value) + ns = f"{name} = {value}" if len(ns) + len(linesofar) + 2 > tw.fullwidth: if linesofar: tw.line(linesofar) @@ -1143,49 +1181,79 @@ class ReprFuncArgs(TerminalRepr): tw.line("") -def getrawcode(obj, trycall: bool = True): - """ return code object for given function. """ +def getfslineno(obj: object) -> Tuple[Union[str, py.path.local], int]: + """Return source location (path, lineno) for the given object. + + If the source cannot be determined return ("", -1). + + The line number is 0-based. + """ + # xxx let decorators etc specify a sane ordering + # NOTE: this used to be done in _pytest.compat.getfslineno, initially added + # in 6ec13a2b9. It ("place_as") appears to be something very custom. + obj = get_real_func(obj) + if hasattr(obj, "place_as"): + obj = obj.place_as # type: ignore[attr-defined] + try: - return obj.__code__ - except AttributeError: - obj = getattr(obj, "f_code", obj) - obj = getattr(obj, "__code__", obj) - if trycall and not hasattr(obj, "co_firstlineno"): - if hasattr(obj, "__call__") and not inspect.isclass(obj): - x = getrawcode(obj.__call__, trycall=False) - if hasattr(x, "co_firstlineno"): - return x - return obj - - -# relative paths that we use to filter traceback entries from appearing to the user; -# see filter_traceback + code = Code.from_function(obj) + except TypeError: + try: + fn = inspect.getsourcefile(obj) or inspect.getfile(obj) # type: ignore[arg-type] + except TypeError: + return "", -1 + + fspath = fn and py.path.local(fn) or "" + lineno = -1 + if fspath: + try: + _, lineno = findsource(obj) + except OSError: + pass + return fspath, lineno + + return code.path, code.firstlineno + + +# Relative paths that we use to filter traceback entries from appearing to the user; +# see filter_traceback. # note: if we need to add more paths than what we have now we should probably use a list -# for better maintenance +# for better maintenance. -_PLUGGY_DIR = py.path.local(pluggy.__file__.rstrip("oc")) +_PLUGGY_DIR = Path(pluggy.__file__.rstrip("oc")) # pluggy is either a package or a single module depending on the version -if _PLUGGY_DIR.basename == "__init__.py": - _PLUGGY_DIR = _PLUGGY_DIR.dirpath() -_PYTEST_DIR = py.path.local(_pytest.__file__).dirpath() -_PY_DIR = py.path.local(py.__file__).dirpath() +if _PLUGGY_DIR.name == "__init__.py": + _PLUGGY_DIR = _PLUGGY_DIR.parent +_PYTEST_DIR = Path(_pytest.__file__).parent +_PY_DIR = Path(py.__file__).parent def filter_traceback(entry: TracebackEntry) -> bool: - """Return True if a TracebackEntry instance should be removed from tracebacks: + """Return True if a TracebackEntry instance should be included in tracebacks. + + We hide traceback entries of: + * dynamically generated code (no code to show up for it); * internal traceback from pytest or its internal libraries, py and pluggy. """ # entry.path might sometimes return a str object when the entry - # points to dynamically generated code - # see https://bitbucket.org/pytest-dev/py/issues/71 + # points to dynamically generated code. + # See https://bitbucket.org/pytest-dev/py/issues/71. raw_filename = entry.frame.code.raw.co_filename is_generated = "<" in raw_filename and ">" in raw_filename if is_generated: return False + # entry.path might point to a non-existing file, in which case it will - # also return a str object. see #1133 - p = py.path.local(entry.path) - return ( - not p.relto(_PLUGGY_DIR) and not p.relto(_PYTEST_DIR) and not p.relto(_PY_DIR) - ) + # also return a str object. See #1133. + p = Path(entry.path) + + parents = p.parents + if _PLUGGY_DIR in parents: + return False + if _PYTEST_DIR in parents: + return False + if _PY_DIR in parents: + return False + + return True diff --git a/contrib/python/pytest/py3/_pytest/_code/source.py b/contrib/python/pytest/py3/_pytest/_code/source.py index 28c11e5d5e3..6f54057c0a9 100644 --- a/contrib/python/pytest/py3/_pytest/_code/source.py +++ b/contrib/python/pytest/py3/_pytest/_code/source.py @@ -1,76 +1,59 @@ import ast import inspect -import linecache -import sys import textwrap import tokenize +import types import warnings from bisect import bisect_right -from types import CodeType -from types import FrameType -from typing import Any +from typing import Iterable from typing import Iterator from typing import List from typing import Optional -from typing import Sequence +from typing import overload from typing import Tuple from typing import Union -import py - -from _pytest.compat import get_real_func -from _pytest.compat import overload -from _pytest.compat import TYPE_CHECKING - -if TYPE_CHECKING: - from typing_extensions import Literal - class Source: - """ an immutable object holding a source code fragment, - possibly deindenting it. + """An immutable object holding a source code fragment. + + When using Source(...), the source lines are deindented. """ - _compilecounter = 0 - - def __init__(self, *parts, **kwargs) -> None: - self.lines = lines = [] # type: List[str] - de = kwargs.get("deindent", True) - for part in parts: - if not part: - partlines = [] # type: List[str] - elif isinstance(part, Source): - partlines = part.lines - elif isinstance(part, (tuple, list)): - partlines = [x.rstrip("\n") for x in part] - elif isinstance(part, str): - partlines = part.split("\n") - else: - partlines = getsource(part, deindent=de).lines - if de: - partlines = deindent(partlines) - lines.extend(partlines) - - def __eq__(self, other): - try: - return self.lines == other.lines - except AttributeError: - if isinstance(other, str): - return str(self) == other - return False + def __init__(self, obj: object = None) -> None: + if not obj: + self.lines: List[str] = [] + elif isinstance(obj, Source): + self.lines = obj.lines + elif isinstance(obj, (tuple, list)): + self.lines = deindent(x.rstrip("\n") for x in obj) + elif isinstance(obj, str): + self.lines = deindent(obj.split("\n")) + else: + try: + rawcode = getrawcode(obj) + src = inspect.getsource(rawcode) + except TypeError: + src = inspect.getsource(obj) # type: ignore[arg-type] + self.lines = deindent(src.split("\n")) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Source): + return NotImplemented + return self.lines == other.lines # Ignore type because of https://github.com/python/mypy/issues/4266. __hash__ = None # type: ignore @overload def __getitem__(self, key: int) -> str: - raise NotImplementedError() + ... - @overload # noqa: F811 - def __getitem__(self, key: slice) -> "Source": # noqa: F811 - raise NotImplementedError() + @overload + def __getitem__(self, key: slice) -> "Source": + ... - def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811 + def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: if isinstance(key, int): return self.lines[key] else: @@ -87,9 +70,7 @@ class Source: return len(self.lines) def strip(self) -> "Source": - """ return new source object with trailing - and leading blank lines removed. - """ + """Return new Source object with trailing and leading blank lines removed.""" start, end = 0, len(self) while start < end and not self.lines[start].strip(): start += 1 @@ -99,220 +80,36 @@ class Source: source.lines[:] = self.lines[start:end] return source - def putaround( - self, before: str = "", after: str = "", indent: str = " " * 4 - ) -> "Source": - """ return a copy of the source object with - 'before' and 'after' wrapped around it. - """ - beforesource = Source(before) - aftersource = Source(after) - newsource = Source() - lines = [(indent + line) for line in self.lines] - newsource.lines = beforesource.lines + lines + aftersource.lines - return newsource - def indent(self, indent: str = " " * 4) -> "Source": - """ return a copy of the source object with - all lines indented by the given indent-string. - """ + """Return a copy of the source object with all lines indented by the + given indent-string.""" newsource = Source() newsource.lines = [(indent + line) for line in self.lines] return newsource def getstatement(self, lineno: int) -> "Source": - """ return Source statement which contains the - given linenumber (counted from 0). - """ + """Return Source statement which contains the given linenumber + (counted from 0).""" start, end = self.getstatementrange(lineno) return self[start:end] def getstatementrange(self, lineno: int) -> Tuple[int, int]: - """ return (start, end) tuple which spans the minimal - statement region which containing the given lineno. - """ + """Return (start, end) tuple which spans the minimal statement region + which containing the given lineno.""" if not (0 <= lineno < len(self)): raise IndexError("lineno out of range") ast, start, end = getstatementrange_ast(lineno, self) return start, end def deindent(self) -> "Source": - """return a new source object deindented.""" + """Return a new Source object deindented.""" newsource = Source() newsource.lines[:] = deindent(self.lines) return newsource - def isparseable(self, deindent: bool = True) -> bool: - """ return True if source is parseable, heuristically - deindenting it by default. - """ - if deindent: - source = str(self.deindent()) - else: - source = str(self) - try: - ast.parse(source) - except (SyntaxError, ValueError, TypeError): - return False - else: - return True - def __str__(self) -> str: return "\n".join(self.lines) - @overload - def compile( - self, - filename: Optional[str] = ..., - mode: str = ..., - flag: "Literal[0]" = ..., - dont_inherit: int = ..., - _genframe: Optional[FrameType] = ..., - ) -> CodeType: - raise NotImplementedError() - - @overload # noqa: F811 - def compile( # noqa: F811 - self, - filename: Optional[str] = ..., - mode: str = ..., - flag: int = ..., - dont_inherit: int = ..., - _genframe: Optional[FrameType] = ..., - ) -> Union[CodeType, ast.AST]: - raise NotImplementedError() - - def compile( # noqa: F811 - self, - filename: Optional[str] = None, - mode: str = "exec", - flag: int = 0, - dont_inherit: int = 0, - _genframe: Optional[FrameType] = None, - ) -> Union[CodeType, ast.AST]: - """ return compiled code object. if filename is None - invent an artificial filename which displays - the source/line position of the caller frame. - """ - if not filename or py.path.local(filename).check(file=0): - if _genframe is None: - _genframe = sys._getframe(1) # the caller - fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno - base = "<%d-codegen " % self._compilecounter - self.__class__._compilecounter += 1 - if not filename: - filename = base + "%s:%d>" % (fn, lineno) - else: - filename = base + "%r %s:%d>" % (filename, fn, lineno) - source = "\n".join(self.lines) + "\n" - try: - co = compile(source, filename, mode, flag) - except SyntaxError as ex: - # re-represent syntax errors from parsing python strings - msglines = self.lines[: ex.lineno] - if ex.offset: - msglines.append(" " * ex.offset + "^") - msglines.append("(code was compiled probably from here: %s)" % filename) - newex = SyntaxError("\n".join(msglines)) - newex.offset = ex.offset - newex.lineno = ex.lineno - newex.text = ex.text - raise newex - else: - if flag & ast.PyCF_ONLY_AST: - assert isinstance(co, ast.AST) - return co - assert isinstance(co, CodeType) - lines = [(x + "\n") for x in self.lines] - # Type ignored because linecache.cache is private. - linecache.cache[filename] = (1, None, lines, filename) # type: ignore - return co - - -# -# public API shortcut functions -# - - -@overload -def compile_( - source: Union[str, bytes, ast.mod, ast.AST], - filename: Optional[str] = ..., - mode: str = ..., - flags: "Literal[0]" = ..., - dont_inherit: int = ..., -) -> CodeType: - raise NotImplementedError() - - -@overload # noqa: F811 -def compile_( # noqa: F811 - source: Union[str, bytes, ast.mod, ast.AST], - filename: Optional[str] = ..., - mode: str = ..., - flags: int = ..., - dont_inherit: int = ..., -) -> Union[CodeType, ast.AST]: - raise NotImplementedError() - - -def compile_( # noqa: F811 - source: Union[str, bytes, ast.mod, ast.AST], - filename: Optional[str] = None, - mode: str = "exec", - flags: int = 0, - dont_inherit: int = 0, -) -> Union[CodeType, ast.AST]: - """ compile the given source to a raw code object, - and maintain an internal cache which allows later - retrieval of the source code for the code object - and any recursively created code objects. - """ - if isinstance(source, ast.AST): - # XXX should Source support having AST? - assert filename is not None - co = compile(source, filename, mode, flags, dont_inherit) - assert isinstance(co, (CodeType, ast.AST)) - return co - _genframe = sys._getframe(1) # the caller - s = Source(source) - return s.compile(filename, mode, flags, _genframe=_genframe) - - -def getfslineno(obj: Any) -> Tuple[Union[str, py.path.local], int]: - """ Return source location (path, lineno) for the given object. - If the source cannot be determined return ("", -1). - - The line number is 0-based. - """ - from .code import Code - - # xxx let decorators etc specify a sane ordering - # NOTE: this used to be done in _pytest.compat.getfslineno, initially added - # in 6ec13a2b9. It ("place_as") appears to be something very custom. - obj = get_real_func(obj) - if hasattr(obj, "place_as"): - obj = obj.place_as - - try: - code = Code(obj) - except TypeError: - try: - fn = inspect.getsourcefile(obj) or inspect.getfile(obj) - except TypeError: - return "", -1 - - fspath = fn and py.path.local(fn) or "" - lineno = -1 - if fspath: - try: - _, lineno = findsource(obj) - except IOError: - pass - return fspath, lineno - else: - return code.path, code.firstlineno - # # helper functions @@ -329,35 +126,34 @@ def findsource(obj) -> Tuple[Optional[Source], int]: return source, lineno -def getsource(obj, **kwargs) -> Source: - from .code import getrawcode - - obj = getrawcode(obj) +def getrawcode(obj: object, trycall: bool = True) -> types.CodeType: + """Return code object for given function.""" try: - strsrc = inspect.getsource(obj) - except IndentationError: - strsrc = '"Buggy python version consider upgrading, cannot get source"' - assert isinstance(strsrc, str) - return Source(strsrc, **kwargs) + return obj.__code__ # type: ignore[attr-defined,no-any-return] + except AttributeError: + pass + if trycall: + call = getattr(obj, "__call__", None) + if call and not isinstance(obj, type): + return getrawcode(call, trycall=False) + raise TypeError(f"could not get code object for {obj!r}") -def deindent(lines: Sequence[str]) -> List[str]: +def deindent(lines: Iterable[str]) -> List[str]: return textwrap.dedent("\n".join(lines)).splitlines() def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]: - import ast - - # flatten all statements and except handlers into one lineno-list - # AST's line numbers start indexing at 1 - values = [] # type: List[int] + # Flatten all statements and except handlers into one lineno-list. + # AST's line numbers start indexing at 1. + values: List[int] = [] for x in ast.walk(node): if isinstance(x, (ast.stmt, ast.ExceptHandler)): values.append(x.lineno - 1) for name in ("finalbody", "orelse"): - val = getattr(x, name, None) # type: Optional[List[ast.stmt]] + val: Optional[List[ast.stmt]] = getattr(x, name, None) if val: - # treat the finally/orelse part as its own statement + # Treat the finally/orelse part as its own statement. values.append(val[0].lineno - 1 - 1) values.sort() insert_index = bisect_right(values, lineno) @@ -378,13 +174,13 @@ def getstatementrange_ast( if astnode is None: content = str(source) # See #4260: - # don't produce duplicate warnings when compiling source to find ast + # Don't produce duplicate warnings when compiling source to find AST. with warnings.catch_warnings(): warnings.simplefilter("ignore") astnode = ast.parse(content, "source", "exec") start, end = get_statement_startend2(lineno, astnode) - # we need to correct the end: + # We need to correct the end: # - ast-parsing strips comments # - there might be empty lines # - we might have lesser indented code blocks at the end @@ -392,10 +188,10 @@ def getstatementrange_ast( end = len(source.lines) if end > start + 1: - # make sure we don't span differently indented code blocks - # by using the BlockFinder helper used which inspect.getsource() uses itself + # Make sure we don't span differently indented code blocks + # by using the BlockFinder helper used which inspect.getsource() uses itself. block_finder = inspect.BlockFinder() - # if we start with an indented line, put blockfinder to "started" mode + # If we start with an indented line, put blockfinder to "started" mode. block_finder.started = source.lines[start][0].isspace() it = ((x + "\n") for x in source.lines[start:end]) try: @@ -406,7 +202,7 @@ def getstatementrange_ast( except Exception: pass - # the end might still point to a comment or empty line, correct it + # The end might still point to a comment or empty line, correct it. while end: line = source.lines[end - 1].lstrip() if line.startswith("#") or not line: |
