aboutsummaryrefslogtreecommitdiffstats
path: root/contrib
diff options
context:
space:
mode:
authorarcadia-devtools <arcadia-devtools@yandex-team.ru>2022-02-09 12:00:52 +0300
committerDaniil Cherednik <dcherednik@yandex-team.ru>2022-02-10 15:58:17 +0300
commit8e1413fed79d1e8036e65228af6c93399ccf5502 (patch)
tree502c9df7b2614d20541c7a2d39d390e9a51877cc /contrib
parent6b813c17d56d1d05f92c61ddc347d0e4d358fe85 (diff)
downloadydb-8e1413fed79d1e8036e65228af6c93399ccf5502.tar.gz
intermediate changes
ref:614ed510ddd3cdf86a8c5dbf19afd113397e0172
Diffstat (limited to 'contrib')
-rw-r--r--contrib/python/iniconfig/.dist-info/METADATA78
-rw-r--r--contrib/python/iniconfig/.dist-info/top_level.txt1
-rw-r--r--contrib/python/iniconfig/LICENSE19
-rw-r--r--contrib/python/iniconfig/README.txt51
-rw-r--r--contrib/python/iniconfig/iniconfig/__init__.py173
-rw-r--r--contrib/python/iniconfig/iniconfig/__init__.pyi31
-rw-r--r--contrib/python/iniconfig/iniconfig/py.typed (renamed from contrib/python/more-itertools/py3/more_itertools/py.typed)0
-rw-r--r--contrib/python/iniconfig/patches/01-arcadia.patch26
-rw-r--r--contrib/python/iniconfig/ya.make26
-rw-r--r--contrib/python/more-itertools/py2/.dist-info/METADATA460
-rw-r--r--contrib/python/more-itertools/py2/.dist-info/top_level.txt1
-rw-r--r--contrib/python/more-itertools/py2/LICENSE19
-rw-r--r--contrib/python/more-itertools/py2/README.rst154
-rw-r--r--contrib/python/more-itertools/py2/more_itertools/__init__.py2
-rw-r--r--contrib/python/more-itertools/py2/more_itertools/more.py2333
-rw-r--r--contrib/python/more-itertools/py2/more_itertools/recipes.py577
-rw-r--r--contrib/python/more-itertools/py2/more_itertools/tests/test_more.py2313
-rw-r--r--contrib/python/more-itertools/py2/more_itertools/tests/test_recipes.py616
-rw-r--r--contrib/python/more-itertools/py2/patches/01-fix-tests.patch18
-rw-r--r--contrib/python/more-itertools/py2/tests/ya.make18
-rw-r--r--contrib/python/more-itertools/py2/ya.make34
-rw-r--r--contrib/python/more-itertools/py3/.dist-info/METADATA521
-rw-r--r--contrib/python/more-itertools/py3/.dist-info/top_level.txt1
-rw-r--r--contrib/python/more-itertools/py3/LICENSE19
-rw-r--r--contrib/python/more-itertools/py3/README.rst200
-rw-r--r--contrib/python/more-itertools/py3/more_itertools/__init__.py4
-rw-r--r--contrib/python/more-itertools/py3/more_itertools/__init__.pyi2
-rw-r--r--contrib/python/more-itertools/py3/more_itertools/more.py4317
-rw-r--r--contrib/python/more-itertools/py3/more_itertools/more.pyi664
-rw-r--r--contrib/python/more-itertools/py3/more_itertools/recipes.py698
-rw-r--r--contrib/python/more-itertools/py3/more_itertools/recipes.pyi112
-rw-r--r--contrib/python/more-itertools/py3/patches/01-fix-tests.patch17
-rw-r--r--contrib/python/more-itertools/py3/tests/test_more.py5033
-rw-r--r--contrib/python/more-itertools/py3/tests/test_recipes.py765
-rw-r--r--contrib/python/more-itertools/py3/tests/ya.make16
-rw-r--r--contrib/python/more-itertools/py3/ya.make34
-rw-r--r--contrib/python/more-itertools/ya.make20
-rw-r--r--contrib/python/pytest/py3/.dist-info/METADATA78
-rw-r--r--contrib/python/pytest/py3/.dist-info/entry_points.txt4
-rw-r--r--contrib/python/pytest/py3/AUTHORS36
-rw-r--r--contrib/python/pytest/py3/README.rst42
-rw-r--r--contrib/python/pytest/py3/_pytest/_argcomplete.py52
-rw-r--r--contrib/python/pytest/py3/_pytest/_code/__init__.py32
-rw-r--r--contrib/python/pytest/py3/_pytest/_code/code.py656
-rw-r--r--contrib/python/pytest/py3/_pytest/_code/source.py328
-rw-r--r--contrib/python/pytest/py3/_pytest/_io/__init__.py43
-rw-r--r--contrib/python/pytest/py3/_pytest/_io/saferepr.py60
-rw-r--r--contrib/python/pytest/py3/_pytest/_io/terminalwriter.py210
-rw-r--r--contrib/python/pytest/py3/_pytest/_io/wcwidth.py55
-rw-r--r--contrib/python/pytest/py3/_pytest/_version.py3
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/__init__.py70
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/rewrite.py399
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/truncate.py41
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/util.py122
-rw-r--r--contrib/python/pytest/py3/_pytest/cacheprovider.py317
-rw-r--r--contrib/python/pytest/py3/_pytest/capture.py1318
-rw-r--r--contrib/python/pytest/py3/_pytest/compat.py252
-rw-r--r--contrib/python/pytest/py3/_pytest/config/__init__.py1078
-rw-r--r--contrib/python/pytest/py3/_pytest/config/argparsing.py156
-rw-r--r--contrib/python/pytest/py3/_pytest/config/exceptions.py10
-rw-r--r--contrib/python/pytest/py3/_pytest/config/findpaths.py220
-rw-r--r--contrib/python/pytest/py3/_pytest/debugging.py111
-rw-r--r--contrib/python/pytest/py3/_pytest/deprecated.py85
-rw-r--r--contrib/python/pytest/py3/_pytest/doctest.py256
-rw-r--r--contrib/python/pytest/py3/_pytest/faulthandler.py41
-rw-r--r--contrib/python/pytest/py3/_pytest/fixtures.py1296
-rw-r--r--contrib/python/pytest/py3/_pytest/freeze_support.py45
-rw-r--r--contrib/python/pytest/py3/_pytest/helpconfig.py83
-rw-r--r--contrib/python/pytest/py3/_pytest/hookspec.py715
-rw-r--r--contrib/python/pytest/py3/_pytest/junitxml.py446
-rw-r--r--contrib/python/pytest/py3/_pytest/logging.py704
-rw-r--r--contrib/python/pytest/py3/_pytest/main.py737
-rw-r--r--contrib/python/pytest/py3/_pytest/mark/__init__.py162
-rw-r--r--contrib/python/pytest/py3/_pytest/mark/evaluate.py132
-rw-r--r--contrib/python/pytest/py3/_pytest/mark/expression.py221
-rw-r--r--contrib/python/pytest/py3/_pytest/mark/legacy.py116
-rw-r--r--contrib/python/pytest/py3/_pytest/mark/structures.py397
-rw-r--r--contrib/python/pytest/py3/_pytest/monkeypatch.py223
-rw-r--r--contrib/python/pytest/py3/_pytest/nodes.py471
-rw-r--r--contrib/python/pytest/py3/_pytest/nose.py15
-rw-r--r--contrib/python/pytest/py3/_pytest/outcomes.py107
-rw-r--r--contrib/python/pytest/py3/_pytest/pastebin.py59
-rw-r--r--contrib/python/pytest/py3/_pytest/pathlib.py425
-rw-r--r--contrib/python/pytest/py3/_pytest/py.typed (renamed from contrib/python/more-itertools/py2/more_itertools/tests/__init__.py)0
-rw-r--r--contrib/python/pytest/py3/_pytest/pytester.py1237
-rw-r--r--contrib/python/pytest/py3/_pytest/pytester_assertions.py66
-rw-r--r--contrib/python/pytest/py3/_pytest/python.py862
-rw-r--r--contrib/python/pytest/py3/_pytest/python_api.py331
-rw-r--r--contrib/python/pytest/py3/_pytest/recwarn.py134
-rw-r--r--contrib/python/pytest/py3/_pytest/reports.py307
-rw-r--r--contrib/python/pytest/py3/_pytest/resultlog.py102
-rw-r--r--contrib/python/pytest/py3/_pytest/runner.py276
-rw-r--r--contrib/python/pytest/py3/_pytest/setuponly.py36
-rw-r--r--contrib/python/pytest/py3/_pytest/setupplan.py18
-rw-r--r--contrib/python/pytest/py3/_pytest/skipping.py309
-rw-r--r--contrib/python/pytest/py3/_pytest/stepwise.py101
-rw-r--r--contrib/python/pytest/py3/_pytest/store.py12
-rw-r--r--contrib/python/pytest/py3/_pytest/terminal.py760
-rw-r--r--contrib/python/pytest/py3/_pytest/threadexception.py90
-rw-r--r--contrib/python/pytest/py3/_pytest/timing.py12
-rw-r--r--contrib/python/pytest/py3/_pytest/tmpdir.py203
-rw-r--r--contrib/python/pytest/py3/_pytest/unittest.py243
-rw-r--r--contrib/python/pytest/py3/_pytest/unraisableexception.py93
-rw-r--r--contrib/python/pytest/py3/_pytest/warning_types.py105
-rw-r--r--contrib/python/pytest/py3/_pytest/warnings.py140
-rw-r--r--contrib/python/pytest/py3/patches/03-limit-id.patch2
-rw-r--r--contrib/python/pytest/py3/patches/04-support-cyrillic-id.patch2
-rw-r--r--contrib/python/pytest/py3/patches/05-support-readline.patch16
-rw-r--r--contrib/python/pytest/py3/patches/06-support-ya-markers.patch2
-rw-r--r--contrib/python/pytest/py3/patches/07-disable-translate-non-printable.patch2
-rw-r--r--contrib/python/pytest/py3/pytest/__init__.py42
-rw-r--r--contrib/python/pytest/py3/pytest/__main__.py6
-rw-r--r--contrib/python/pytest/py3/pytest/collect.py39
-rw-r--r--contrib/python/pytest/py3/pytest/py.typed (renamed from contrib/python/more-itertools/py3/tests/__init__.py)0
-rw-r--r--contrib/python/pytest/py3/ya.make19
-rw-r--r--contrib/python/toml/.dist-info/METADATA255
-rw-r--r--contrib/python/toml/.dist-info/top_level.txt1
-rw-r--r--contrib/python/toml/LICENSE27
-rw-r--r--contrib/python/toml/README.rst224
-rw-r--r--contrib/python/toml/toml/__init__.py25
-rw-r--r--contrib/python/toml/toml/__init__.pyi15
-rw-r--r--contrib/python/toml/toml/decoder.py1057
-rw-r--r--contrib/python/toml/toml/decoder.pyi52
-rw-r--r--contrib/python/toml/toml/encoder.py304
-rw-r--r--contrib/python/toml/toml/encoder.pyi34
-rw-r--r--contrib/python/toml/toml/ordered.py15
-rw-r--r--contrib/python/toml/toml/ordered.pyi7
-rw-r--r--contrib/python/toml/toml/tz.py24
-rw-r--r--contrib/python/toml/toml/tz.pyi9
-rw-r--r--contrib/python/toml/ya.make31
-rw-r--r--contrib/python/ya.make2
131 files changed, 13334 insertions, 26117 deletions
diff --git a/contrib/python/iniconfig/.dist-info/METADATA b/contrib/python/iniconfig/.dist-info/METADATA
new file mode 100644
index 0000000000..c078a7532f
--- /dev/null
+++ b/contrib/python/iniconfig/.dist-info/METADATA
@@ -0,0 +1,78 @@
+Metadata-Version: 2.1
+Name: iniconfig
+Version: 1.1.1
+Summary: iniconfig: brain-dead simple config-ini parsing
+Home-page: http://github.com/RonnyPfannschmidt/iniconfig
+Author: Ronny Pfannschmidt, Holger Krekel
+Author-email: opensource@ronnypfannschmidt.de, holger.krekel@gmail.com
+License: MIT License
+Platform: unix
+Platform: linux
+Platform: osx
+Platform: cygwin
+Platform: win32
+Classifier: Development Status :: 4 - Beta
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: POSIX
+Classifier: Operating System :: Microsoft :: Windows
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Topic :: Software Development :: Libraries
+Classifier: Topic :: Utilities
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 3
+
+iniconfig: brain-dead simple parsing of ini files
+=======================================================
+
+iniconfig is a small and simple INI-file parser module
+having a unique set of features:
+
+* tested against Python2.4 across to Python3.2, Jython, PyPy
+* maintains order of sections and entries
+* supports multi-line values with or without line-continuations
+* supports "#" comments everywhere
+* raises errors with proper line-numbers
+* no bells and whistles like automatic substitutions
+* iniconfig raises an Error if two sections have the same name.
+
+If you encounter issues or have feature wishes please report them to:
+
+ http://github.com/RonnyPfannschmidt/iniconfig/issues
+
+Basic Example
+===================================
+
+If you have an ini file like this::
+
+ # content of example.ini
+ [section1] # comment
+ name1=value1 # comment
+ name1b=value1,value2 # comment
+
+ [section2]
+ name2=
+ line1
+ line2
+
+then you can do::
+
+ >>> import iniconfig
+ >>> ini = iniconfig.IniConfig("example.ini")
+ >>> ini['section1']['name1'] # raises KeyError if not exists
+ 'value1'
+ >>> ini.get('section1', 'name1b', [], lambda x: x.split(","))
+ ['value1', 'value2']
+ >>> ini.get('section1', 'notexist', [], lambda x: x.split(","))
+ []
+ >>> [x.name for x in list(ini)]
+ ['section1', 'section2']
+ >>> list(list(ini)[0].items())
+ [('name1', 'value1'), ('name1b', 'value1,value2')]
+ >>> 'section1' in ini
+ True
+ >>> 'inexistendsection' in ini
+ False
+
+
diff --git a/contrib/python/iniconfig/.dist-info/top_level.txt b/contrib/python/iniconfig/.dist-info/top_level.txt
new file mode 100644
index 0000000000..9dda53692d
--- /dev/null
+++ b/contrib/python/iniconfig/.dist-info/top_level.txt
@@ -0,0 +1 @@
+iniconfig
diff --git a/contrib/python/iniconfig/LICENSE b/contrib/python/iniconfig/LICENSE
new file mode 100644
index 0000000000..31ecdfb1db
--- /dev/null
+++ b/contrib/python/iniconfig/LICENSE
@@ -0,0 +1,19 @@
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
diff --git a/contrib/python/iniconfig/README.txt b/contrib/python/iniconfig/README.txt
new file mode 100644
index 0000000000..6bbad9a8d9
--- /dev/null
+++ b/contrib/python/iniconfig/README.txt
@@ -0,0 +1,51 @@
+iniconfig: brain-dead simple parsing of ini files
+=======================================================
+
+iniconfig is a small and simple INI-file parser module
+having a unique set of features:
+
+* tested against Python2.4 across to Python3.2, Jython, PyPy
+* maintains order of sections and entries
+* supports multi-line values with or without line-continuations
+* supports "#" comments everywhere
+* raises errors with proper line-numbers
+* no bells and whistles like automatic substitutions
+* iniconfig raises an Error if two sections have the same name.
+
+If you encounter issues or have feature wishes please report them to:
+
+ http://github.com/RonnyPfannschmidt/iniconfig/issues
+
+Basic Example
+===================================
+
+If you have an ini file like this::
+
+ # content of example.ini
+ [section1] # comment
+ name1=value1 # comment
+ name1b=value1,value2 # comment
+
+ [section2]
+ name2=
+ line1
+ line2
+
+then you can do::
+
+ >>> import iniconfig
+ >>> ini = iniconfig.IniConfig("example.ini")
+ >>> ini['section1']['name1'] # raises KeyError if not exists
+ 'value1'
+ >>> ini.get('section1', 'name1b', [], lambda x: x.split(","))
+ ['value1', 'value2']
+ >>> ini.get('section1', 'notexist', [], lambda x: x.split(","))
+ []
+ >>> [x.name for x in list(ini)]
+ ['section1', 'section2']
+ >>> list(list(ini)[0].items())
+ [('name1', 'value1'), ('name1b', 'value1,value2')]
+ >>> 'section1' in ini
+ True
+ >>> 'inexistendsection' in ini
+ False
diff --git a/contrib/python/iniconfig/iniconfig/__init__.py b/contrib/python/iniconfig/iniconfig/__init__.py
new file mode 100644
index 0000000000..3209831362
--- /dev/null
+++ b/contrib/python/iniconfig/iniconfig/__init__.py
@@ -0,0 +1,173 @@
+""" brain-dead simple parser for ini-style files.
+(C) Ronny Pfannschmidt, Holger Krekel -- MIT licensed
+"""
+import os
+__all__ = ['IniConfig', 'ParseError']
+
+COMMENTCHARS = "#;"
+
+
+class ParseError(Exception):
+ def __init__(self, path, lineno, msg):
+ Exception.__init__(self, path, lineno, msg)
+ self.path = path
+ self.lineno = lineno
+ self.msg = msg
+
+ def __str__(self):
+ return "%s:%s: %s" % (self.path, self.lineno+1, self.msg)
+
+
+class SectionWrapper(object):
+ def __init__(self, config, name):
+ self.config = config
+ self.name = name
+
+ def lineof(self, name):
+ return self.config.lineof(self.name, name)
+
+ def get(self, key, default=None, convert=str):
+ return self.config.get(self.name, key,
+ convert=convert, default=default)
+
+ def __getitem__(self, key):
+ return self.config.sections[self.name][key]
+
+ def __iter__(self):
+ section = self.config.sections.get(self.name, [])
+
+ def lineof(key):
+ return self.config.lineof(self.name, key)
+ for name in sorted(section, key=lineof):
+ yield name
+
+ def items(self):
+ for name in self:
+ yield name, self[name]
+
+
+class IniConfig(object):
+ def __init__(self, path, data=None):
+ self.path = str(path) # convenience
+ if data is None:
+ if os.path.basename(self.path).startswith('pkg:'):
+ import io, pkgutil
+
+ _, package, resource = self.path.split(':')
+ content = pkgutil.get_data(package, resource)
+ f = io.StringIO(content.decode('utf-8'))
+ else:
+ f = open(self.path)
+ try:
+ tokens = self._parse(iter(f))
+ finally:
+ f.close()
+ else:
+ tokens = self._parse(data.splitlines(True))
+
+ self._sources = {}
+ self.sections = {}
+
+ for lineno, section, name, value in tokens:
+ if section is None:
+ self._raise(lineno, 'no section header defined')
+ self._sources[section, name] = lineno
+ if name is None:
+ if section in self.sections:
+ self._raise(lineno, 'duplicate section %r' % (section, ))
+ self.sections[section] = {}
+ else:
+ if name in self.sections[section]:
+ self._raise(lineno, 'duplicate name %r' % (name, ))
+ self.sections[section][name] = value
+
+ def _raise(self, lineno, msg):
+ raise ParseError(self.path, lineno, msg)
+
+ def _parse(self, line_iter):
+ result = []
+ section = None
+ for lineno, line in enumerate(line_iter):
+ name, data = self._parseline(line, lineno)
+ # new value
+ if name is not None and data is not None:
+ result.append((lineno, section, name, data))
+ # new section
+ elif name is not None and data is None:
+ if not name:
+ self._raise(lineno, 'empty section name')
+ section = name
+ result.append((lineno, section, None, None))
+ # continuation
+ elif name is None and data is not None:
+ if not result:
+ self._raise(lineno, 'unexpected value continuation')
+ last = result.pop()
+ last_name, last_data = last[-2:]
+ if last_name is None:
+ self._raise(lineno, 'unexpected value continuation')
+
+ if last_data:
+ data = '%s\n%s' % (last_data, data)
+ result.append(last[:-1] + (data,))
+ return result
+
+ def _parseline(self, line, lineno):
+ # blank lines
+ if iscommentline(line):
+ line = ""
+ else:
+ line = line.rstrip()
+ if not line:
+ return None, None
+ # section
+ if line[0] == '[':
+ realline = line
+ for c in COMMENTCHARS:
+ line = line.split(c)[0].rstrip()
+ if line[-1] == "]":
+ return line[1:-1], None
+ return None, realline.strip()
+ # value
+ elif not line[0].isspace():
+ try:
+ name, value = line.split('=', 1)
+ if ":" in name:
+ raise ValueError()
+ except ValueError:
+ try:
+ name, value = line.split(":", 1)
+ except ValueError:
+ self._raise(lineno, 'unexpected line: %r' % line)
+ return name.strip(), value.strip()
+ # continuation
+ else:
+ return None, line.strip()
+
+ def lineof(self, section, name=None):
+ lineno = self._sources.get((section, name))
+ if lineno is not None:
+ return lineno + 1
+
+ def get(self, section, name, default=None, convert=str):
+ try:
+ return convert(self.sections[section][name])
+ except KeyError:
+ return default
+
+ def __getitem__(self, name):
+ if name not in self.sections:
+ raise KeyError(name)
+ return SectionWrapper(self, name)
+
+ def __iter__(self):
+ for name in sorted(self.sections, key=self.lineof):
+ yield SectionWrapper(self, name)
+
+ def __contains__(self, arg):
+ return arg in self.sections
+
+
+def iscommentline(line):
+ c = line.lstrip()[:1]
+ return c in COMMENTCHARS
diff --git a/contrib/python/iniconfig/iniconfig/__init__.pyi b/contrib/python/iniconfig/iniconfig/__init__.pyi
new file mode 100644
index 0000000000..b6284bec3f
--- /dev/null
+++ b/contrib/python/iniconfig/iniconfig/__init__.pyi
@@ -0,0 +1,31 @@
+from typing import Callable, Iterator, Mapping, Optional, Tuple, TypeVar, Union
+from typing_extensions import Final
+
+_D = TypeVar('_D')
+_T = TypeVar('_T')
+
+class ParseError(Exception):
+ # Private __init__.
+ path: Final[str]
+ lineno: Final[int]
+ msg: Final[str]
+
+class SectionWrapper:
+ # Private __init__.
+ config: Final[IniConfig]
+ name: Final[str]
+ def __getitem__(self, key: str) -> str: ...
+ def __iter__(self) -> Iterator[str]: ...
+ def get(self, key: str, default: _D = ..., convert: Callable[[str], _T] = ...) -> Union[_T, _D]: ...
+ def items(self) -> Iterator[Tuple[str, str]]: ...
+ def lineof(self, name: str) -> Optional[int]: ...
+
+class IniConfig:
+ path: Final[str]
+ sections: Final[Mapping[str, Mapping[str, str]]]
+ def __init__(self, path: str, data: Optional[str] = None): ...
+ def __contains__(self, arg: str) -> bool: ...
+ def __getitem__(self, name: str) -> SectionWrapper: ...
+ def __iter__(self) -> Iterator[SectionWrapper]: ...
+ def get(self, section: str, name: str, default: _D = ..., convert: Callable[[str], _T] = ...) -> Union[_T, _D]: ...
+ def lineof(self, section: str, name: Optional[str] = ...) -> Optional[int]: ...
diff --git a/contrib/python/more-itertools/py3/more_itertools/py.typed b/contrib/python/iniconfig/iniconfig/py.typed
index e69de29bb2..e69de29bb2 100644
--- a/contrib/python/more-itertools/py3/more_itertools/py.typed
+++ b/contrib/python/iniconfig/iniconfig/py.typed
diff --git a/contrib/python/iniconfig/patches/01-arcadia.patch b/contrib/python/iniconfig/patches/01-arcadia.patch
new file mode 100644
index 0000000000..16d9cd88a4
--- /dev/null
+++ b/contrib/python/iniconfig/patches/01-arcadia.patch
@@ -0,0 +1,26 @@
+--- contrib/python/iniconfig/iniconfig/__init__.py (index)
++++ contrib/python/iniconfig/iniconfig/__init__.py (working tree)
+@@ -1,6 +1,7 @@
+ """ brain-dead simple parser for ini-style files.
+ (C) Ronny Pfannschmidt, Holger Krekel -- MIT licensed
+ """
++import os
+ __all__ = ['IniConfig', 'ParseError']
+
+ COMMENTCHARS = "#;"
+@@ -49,7 +50,14 @@ class IniConfig(object):
+ def __init__(self, path, data=None):
+ self.path = str(path) # convenience
+ if data is None:
+- f = open(self.path)
++ if os.path.basename(self.path).startswith('pkg:'):
++ import io, pkgutil
++
++ _, package, resource = self.path.split(':')
++ content = pkgutil.get_data(package, resource)
++ f = io.StringIO(content.decode('utf-8'))
++ else:
++ f = open(self.path)
+ try:
+ tokens = self._parse(iter(f))
+ finally:
diff --git a/contrib/python/iniconfig/ya.make b/contrib/python/iniconfig/ya.make
new file mode 100644
index 0000000000..9121ccd0ab
--- /dev/null
+++ b/contrib/python/iniconfig/ya.make
@@ -0,0 +1,26 @@
+# Generated by devtools/yamaker (pypi).
+
+PY3_LIBRARY()
+
+OWNER(g:python-contrib)
+
+VERSION(1.1.1)
+
+LICENSE(MIT)
+
+NO_LINT()
+
+PY_SRCS(
+ TOP_LEVEL
+ iniconfig/__init__.py
+ iniconfig/__init__.pyi
+)
+
+RESOURCE_FILES(
+ PREFIX contrib/python/iniconfig/
+ .dist-info/METADATA
+ .dist-info/top_level.txt
+ iniconfig/py.typed
+)
+
+END()
diff --git a/contrib/python/more-itertools/py2/.dist-info/METADATA b/contrib/python/more-itertools/py2/.dist-info/METADATA
deleted file mode 100644
index e712d08090..0000000000
--- a/contrib/python/more-itertools/py2/.dist-info/METADATA
+++ /dev/null
@@ -1,460 +0,0 @@
-Metadata-Version: 2.1
-Name: more-itertools
-Version: 5.0.0
-Summary: More routines for operating on iterables, beyond itertools
-Home-page: https://github.com/erikrose/more-itertools
-Author: Erik Rose
-Author-email: erikrose@grinchcentral.com
-License: MIT
-Keywords: itertools,iterator,iteration,filter,peek,peekable,collate,chunk,chunked
-Platform: UNKNOWN
-Classifier: Development Status :: 5 - Production/Stable
-Classifier: Intended Audience :: Developers
-Classifier: Natural Language :: English
-Classifier: License :: OSI Approved :: MIT License
-Classifier: Programming Language :: Python :: 2
-Classifier: Programming Language :: Python :: 2.7
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.4
-Classifier: Programming Language :: Python :: 3.5
-Classifier: Programming Language :: Python :: 3.6
-Classifier: Programming Language :: Python :: 3.7
-Classifier: Topic :: Software Development :: Libraries
-Requires-Dist: six (<2.0.0,>=1.0.0)
-
-==============
-More Itertools
-==============
-
-.. image:: https://coveralls.io/repos/github/erikrose/more-itertools/badge.svg?branch=master
- :target: https://coveralls.io/github/erikrose/more-itertools?branch=master
-
-Python's ``itertools`` library is a gem - you can compose elegant solutions
-for a variety of problems with the functions it provides. In ``more-itertools``
-we collect additional building blocks, recipes, and routines for working with
-Python iterables.
-
-----
-
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Grouping | `chunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked>`_, |
-| | `sliced <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliced>`_, |
-| | `distribute <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute>`_, |
-| | `divide <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.divide>`_, |
-| | `split_at <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_at>`_, |
-| | `split_before <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_before>`_, |
-| | `split_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_after>`_, |
-| | `bucket <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.bucket>`_, |
-| | `grouper <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.grouper>`_, |
-| | `partition <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partition>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Lookahead and lookback | `spy <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.spy>`_, |
-| | `peekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.peekable>`_, |
-| | `seekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.seekable>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Windowing | `windowed <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed>`_, |
-| | `stagger <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.stagger>`_, |
-| | `pairwise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pairwise>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Augmenting | `count_cycle <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.count_cycle>`_, |
-| | `intersperse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse>`_, |
-| | `padded <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.padded>`_, |
-| | `adjacent <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.adjacent>`_, |
-| | `groupby_transform <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.groupby_transform>`_, |
-| | `padnone <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.padnone>`_, |
-| | `ncycles <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ncycles>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Combining | `collapse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.collapse>`_, |
-| | `sort_together <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sort_together>`_, |
-| | `interleave <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave>`_, |
-| | `interleave_longest <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_longest>`_, |
-| | `collate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.collate>`_, |
-| | `zip_offset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_offset>`_, |
-| | `dotproduct <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.dotproduct>`_, |
-| | `flatten <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.flatten>`_, |
-| | `roundrobin <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.roundrobin>`_, |
-| | `prepend <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.prepend>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Summarizing | `ilen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ilen>`_, |
-| | `first <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first>`_, |
-| | `last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.last>`_, |
-| | `one <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.one>`_, |
-| | `unique_to_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_to_each>`_, |
-| | `locate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.locate>`_, |
-| | `rlocate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rlocate>`_, |
-| | `consecutive_groups <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consecutive_groups>`_, |
-| | `exactly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.exactly_n>`_, |
-| | `run_length <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.run_length>`_, |
-| | `map_reduce <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_reduce>`_, |
-| | `all_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_equal>`_, |
-| | `first_true <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first_true>`_, |
-| | `nth <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth>`_, |
-| | `quantify <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.quantify>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Selecting | `islice_extended <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.islice_extended>`_, |
-| | `strip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strip>`_, |
-| | `lstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.lstrip>`_, |
-| | `rstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rstrip>`_, |
-| | `take <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.take>`_, |
-| | `tail <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tail>`_, |
-| | `unique_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertoo ls.unique_everseen>`_, |
-| | `unique_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_justseen>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Combinatorics | `distinct_permutations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_permutations>`_, |
-| | `circular_shifts <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.circular_shifts>`_, |
-| | `powerset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.powerset>`_, |
-| | `random_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_product>`_, |
-| | `random_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_permutation>`_, |
-| | `random_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination>`_, |
-| | `random_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination_with_replacement>`_, |
-| | `nth_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Wrapping | `always_iterable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_iterable>`_, |
-| | `consumer <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consumer>`_, |
-| | `with_iter <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.with_iter>`_, |
-| | `iter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_except>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Others | `replace <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.replace>`_, |
-| | `numeric_range <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.numeric_range>`_, |
-| | `always_reversible <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_reversible>`_, |
-| | `side_effect <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.side_effect>`_, |
-| | `iterate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iterate>`_, |
-| | `difference <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.difference>`_, |
-| | `make_decorator <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.make_decorator>`_, |
-| | `SequenceView <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.SequenceView>`_, |
-| | `consume <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consume>`_, |
-| | `accumulate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.accumulate>`_, |
-| | `tabulate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tabulate>`_, |
-| | `repeatfunc <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeatfunc>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-
-
-Getting started
-===============
-
-To get started, install the library with `pip <https://pip.pypa.io/en/stable/>`_:
-
-.. code-block:: shell
-
- pip install more-itertools
-
-The recipes from the `itertools docs <https://docs.python.org/3/library/itertools.html#itertools-recipes>`_
-are included in the top-level package:
-
-.. code-block:: python
-
- >>> from more_itertools import flatten
- >>> iterable = [(0, 1), (2, 3)]
- >>> list(flatten(iterable))
- [0, 1, 2, 3]
-
-Several new recipes are available as well:
-
-.. code-block:: python
-
- >>> from more_itertools import chunked
- >>> iterable = [0, 1, 2, 3, 4, 5, 6, 7, 8]
- >>> list(chunked(iterable, 3))
- [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
-
- >>> from more_itertools import spy
- >>> iterable = (x * x for x in range(1, 6))
- >>> head, iterable = spy(iterable, n=3)
- >>> list(head)
- [1, 4, 9]
- >>> list(iterable)
- [1, 4, 9, 16, 25]
-
-
-
-For the full listing of functions, see the `API documentation <https://more-itertools.readthedocs.io/en/latest/api.html>`_.
-
-Development
-===========
-
-``more-itertools`` is maintained by `@erikrose <https://github.com/erikrose>`_
-and `@bbayles <https://github.com/bbayles>`_, with help from `many others <https://github.com/erikrose/more-itertools/graphs/contributors>`_.
-If you have a problem or suggestion, please file a bug or pull request in this
-repository. Thanks for contributing!
-
-
-Version History
-===============
-
-
-
-5.0.0
------
-
-* New itertools:
- * split_into (thanks to rovyko)
- * unzip (thanks to bmintz)
- * substrings (thanks to pylang)
-
-* Changes to existing itertools:
- * ilen was optimized a bit (thanks to MSeifert04, achampion, and bmintz)
- * first_true now returns ``None`` by default. This is the reason for the major version bump - see below. (thanks to sk and OJFord)
-
-* Other changes:
- * Some code for old Python versions was removed (thanks to hugovk)
- * Some documentation mistakes were corrected (thanks to belm0 and hugovk)
- * Tests now run properly on 32-bit versions of Python (thanks to Millak)
- * Newer versions of CPython and PyPy are now tested against
-
-The major version update is due to the change in the default return value of
-first_true. It's now ``None``.
-
-.. code-block:: python
-
- >>> from more_itertools import first_true
- >>> iterable = [0, '', False, [], ()] # All these are False
- >>> answer = first_true(iterable)
- >>> print(answer)
- None
-
-4.3.0
------
-
-* New itertools:
- * last (thanks to tmshn)
- * replace (thanks to pylang)
- * rlocate (thanks to jferard and pylang)
-
-* Improvements to existing itertools:
- * locate can now search for multiple items
-
-* Other changes:
- * The docs now include a nice table of tools (thanks MSeifert04)
-
-4.2.0
------
-
-* New itertools:
- * map_reduce (thanks to pylang)
- * prepend (from the `Python 3.7 docs <https://docs.python.org/3.7/library/itertools.html#itertools-recipes>`_)
-
-* Improvements to existing itertools:
- * bucket now complies with PEP 479 (thanks to irmen)
-
-* Other changes:
- * Python 3.7 is now supported (thanks to irmen)
- * Python 3.3 is no longer supported
- * The test suite no longer requires third-party modules to run
- * The API docs now include links to source code
-
-4.1.0
------
-
-* New itertools:
- * split_at (thanks to michael-celani)
- * circular_shifts (thanks to hiqua)
- * make_decorator - see the blog post `Yo, I heard you like decorators <https://sites.google.com/site/bbayles/index/decorator_factory>`_
- for a tour (thanks to pylang)
- * always_reversible (thanks to michael-celani)
- * nth_combination (from the `Python 3.7 docs <https://docs.python.org/3.7/library/itertools.html#itertools-recipes>`_)
-
-* Improvements to existing itertools:
- * seekable now has an ``elements`` method to return cached items.
- * The performance tradeoffs between roundrobin and
- interleave_longest are now documented (thanks michael-celani,
- pylang, and MSeifert04)
-
-4.0.1
------
-
-* No code changes - this release fixes how the docs display on PyPI.
-
-4.0.0
------
-
-* New itertools:
- * consecutive_groups (Based on the example in the `Python 2.4 docs <https://docs.python.org/release/2.4.4/lib/itertools-example.html>`_)
- * seekable (If you're looking for how to "reset" an iterator,
- you're in luck!)
- * exactly_n (thanks to michael-celani)
- * run_length.encode and run_length.decode
- * difference
-
-* Improvements to existing itertools:
- * The number of items between filler elements in intersperse can
- now be specified (thanks to pylang)
- * distinct_permutations and peekable got some minor
- adjustments (thanks to MSeifert04)
- * always_iterable now returns an iterator object. It also now
- allows different types to be considered iterable (thanks to jaraco)
- * bucket can now limit the keys it stores in memory
- * one now allows for custom exceptions (thanks to kalekundert)
-
-* Other changes:
- * A few typos were fixed (thanks to EdwardBetts)
- * All tests can now be run with ``python setup.py test``
-
-The major version update is due to the change in the return value of always_iterable.
-It now always returns iterator objects:
-
-.. code-block:: python
-
- >>> from more_itertools import always_iterable
- # Non-iterable objects are wrapped with iter(tuple(obj))
- >>> always_iterable(12345)
- <tuple_iterator object at 0x7fb24c9488d0>
- >>> list(always_iterable(12345))
- [12345]
- # Iterable objects are wrapped with iter()
- >>> always_iterable([1, 2, 3, 4, 5])
- <list_iterator object at 0x7fb24c948c50>
-
-3.2.0
------
-
-* New itertools:
- * lstrip, rstrip, and strip
- (thanks to MSeifert04 and pylang)
- * islice_extended
-* Improvements to existing itertools:
- * Some bugs with slicing peekable-wrapped iterables were fixed
-
-3.1.0
------
-
-* New itertools:
- * numeric_range (Thanks to BebeSparkelSparkel and MSeifert04)
- * count_cycle (Thanks to BebeSparkelSparkel)
- * locate (Thanks to pylang and MSeifert04)
-* Improvements to existing itertools:
- * A few itertools are now slightly faster due to some function
- optimizations. (Thanks to MSeifert04)
-* The docs have been substantially revised with installation notes,
- categories for library functions, links, and more. (Thanks to pylang)
-
-
-3.0.0
------
-
-* Removed itertools:
- * ``context`` has been removed due to a design flaw - see below for
- replacement options. (thanks to NeilGirdhar)
-* Improvements to existing itertools:
- * ``side_effect`` now supports ``before`` and ``after`` keyword
- arguments. (Thanks to yardsale8)
-* PyPy and PyPy3 are now supported.
-
-The major version change is due to the removal of the ``context`` function.
-Replace it with standard ``with`` statement context management:
-
-.. code-block:: python
-
- # Don't use context() anymore
- file_obj = StringIO()
- consume(print(x, file=f) for f in context(file_obj) for x in u'123')
-
- # Use a with statement instead
- file_obj = StringIO()
- with file_obj as f:
- consume(print(x, file=f) for x in u'123')
-
-2.6.0
------
-
-* New itertools:
- * ``adjacent`` and ``groupby_transform`` (Thanks to diazona)
- * ``always_iterable`` (Thanks to jaraco)
- * (Removed in 3.0.0) ``context`` (Thanks to yardsale8)
- * ``divide`` (Thanks to mozbhearsum)
-* Improvements to existing itertools:
- * ``ilen`` is now slightly faster. (Thanks to wbolster)
- * ``peekable`` can now prepend items to an iterable. (Thanks to diazona)
-
-2.5.0
------
-
-* New itertools:
- * ``distribute`` (Thanks to mozbhearsum and coady)
- * ``sort_together`` (Thanks to clintval)
- * ``stagger`` and ``zip_offset`` (Thanks to joshbode)
- * ``padded``
-* Improvements to existing itertools:
- * ``peekable`` now handles negative indexes and slices with negative
- components properly.
- * ``intersperse`` is now slightly faster. (Thanks to pylang)
- * ``windowed`` now accepts a ``step`` keyword argument.
- (Thanks to pylang)
-* Python 3.6 is now supported.
-
-2.4.1
------
-
-* Move docs 100% to readthedocs.io.
-
-2.4
------
-
-* New itertools:
- * ``accumulate``, ``all_equal``, ``first_true``, ``partition``, and
- ``tail`` from the itertools documentation.
- * ``bucket`` (Thanks to Rosuav and cvrebert)
- * ``collapse`` (Thanks to abarnet)
- * ``interleave`` and ``interleave_longest`` (Thanks to abarnet)
- * ``side_effect`` (Thanks to nvie)
- * ``sliced`` (Thanks to j4mie and coady)
- * ``split_before`` and ``split_after`` (Thanks to astronouth7303)
- * ``spy`` (Thanks to themiurgo and mathieulongtin)
-* Improvements to existing itertools:
- * ``chunked`` is now simpler and more friendly to garbage collection.
- (Contributed by coady, with thanks to piskvorky)
- * ``collate`` now delegates to ``heapq.merge`` when possible.
- (Thanks to kmike and julianpistorius)
- * ``peekable``-wrapped iterables are now indexable and sliceable.
- Iterating through ``peekable``-wrapped iterables is also faster.
- * ``one`` and ``unique_to_each`` have been simplified.
- (Thanks to coady)
-
-
-2.3
------
-
-* Added ``one`` from ``jaraco.util.itertools``. (Thanks, jaraco!)
-* Added ``distinct_permutations`` and ``unique_to_each``. (Contributed by
- bbayles)
-* Added ``windowed``. (Contributed by bbayles, with thanks to buchanae,
- jaraco, and abarnert)
-* Simplified the implementation of ``chunked``. (Thanks, nvie!)
-* Python 3.5 is now supported. Python 2.6 is no longer supported.
-* Python 3 is now supported directly; there is no 2to3 step.
-
-2.2
------
-
-* Added ``iterate`` and ``with_iter``. (Thanks, abarnert!)
-
-2.1
------
-
-* Added (tested!) implementations of the recipes from the itertools
- documentation. (Thanks, Chris Lonnen!)
-* Added ``ilen``. (Thanks for the inspiration, Matt Basta!)
-
-2.0
------
-
-* ``chunked`` now returns lists rather than tuples. After all, they're
- homogeneous. This slightly backward-incompatible change is the reason for
- the major version bump.
-* Added ``@consumer``.
-* Improved test machinery.
-
-1.1
------
-
-* Added ``first`` function.
-* Added Python 3 support.
-* Added a default arg to ``peekable.peek()``.
-* Noted how to easily test whether a peekable iterator is exhausted.
-* Rewrote documentation.
-
-1.0
------
-
-* Initial release, with ``collate``, ``peekable``, and ``chunked``. Could
- really use better docs.
-
diff --git a/contrib/python/more-itertools/py2/.dist-info/top_level.txt b/contrib/python/more-itertools/py2/.dist-info/top_level.txt
deleted file mode 100644
index a5035befb3..0000000000
--- a/contrib/python/more-itertools/py2/.dist-info/top_level.txt
+++ /dev/null
@@ -1 +0,0 @@
-more_itertools
diff --git a/contrib/python/more-itertools/py2/LICENSE b/contrib/python/more-itertools/py2/LICENSE
deleted file mode 100644
index 0a523bece3..0000000000
--- a/contrib/python/more-itertools/py2/LICENSE
+++ /dev/null
@@ -1,19 +0,0 @@
-Copyright (c) 2012 Erik Rose
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of
-this software and associated documentation files (the "Software"), to deal in
-the Software without restriction, including without limitation the rights to
-use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
-of the Software, and to permit persons to whom the Software is furnished to do
-so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/contrib/python/more-itertools/py2/README.rst b/contrib/python/more-itertools/py2/README.rst
deleted file mode 100644
index d918eb684f..0000000000
--- a/contrib/python/more-itertools/py2/README.rst
+++ /dev/null
@@ -1,154 +0,0 @@
-==============
-More Itertools
-==============
-
-.. image:: https://coveralls.io/repos/github/erikrose/more-itertools/badge.svg?branch=master
- :target: https://coveralls.io/github/erikrose/more-itertools?branch=master
-
-Python's ``itertools`` library is a gem - you can compose elegant solutions
-for a variety of problems with the functions it provides. In ``more-itertools``
-we collect additional building blocks, recipes, and routines for working with
-Python iterables.
-
-----
-
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Grouping | `chunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked>`_, |
-| | `sliced <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliced>`_, |
-| | `distribute <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute>`_, |
-| | `divide <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.divide>`_, |
-| | `split_at <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_at>`_, |
-| | `split_before <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_before>`_, |
-| | `split_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_after>`_, |
-| | `bucket <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.bucket>`_, |
-| | `grouper <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.grouper>`_, |
-| | `partition <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partition>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Lookahead and lookback | `spy <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.spy>`_, |
-| | `peekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.peekable>`_, |
-| | `seekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.seekable>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Windowing | `windowed <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed>`_, |
-| | `stagger <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.stagger>`_, |
-| | `pairwise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pairwise>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Augmenting | `count_cycle <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.count_cycle>`_, |
-| | `intersperse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse>`_, |
-| | `padded <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.padded>`_, |
-| | `adjacent <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.adjacent>`_, |
-| | `groupby_transform <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.groupby_transform>`_, |
-| | `padnone <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.padnone>`_, |
-| | `ncycles <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ncycles>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Combining | `collapse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.collapse>`_, |
-| | `sort_together <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sort_together>`_, |
-| | `interleave <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave>`_, |
-| | `interleave_longest <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_longest>`_, |
-| | `collate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.collate>`_, |
-| | `zip_offset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_offset>`_, |
-| | `dotproduct <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.dotproduct>`_, |
-| | `flatten <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.flatten>`_, |
-| | `roundrobin <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.roundrobin>`_, |
-| | `prepend <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.prepend>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Summarizing | `ilen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ilen>`_, |
-| | `first <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first>`_, |
-| | `last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.last>`_, |
-| | `one <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.one>`_, |
-| | `unique_to_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_to_each>`_, |
-| | `locate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.locate>`_, |
-| | `rlocate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rlocate>`_, |
-| | `consecutive_groups <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consecutive_groups>`_, |
-| | `exactly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.exactly_n>`_, |
-| | `run_length <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.run_length>`_, |
-| | `map_reduce <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_reduce>`_, |
-| | `all_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_equal>`_, |
-| | `first_true <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first_true>`_, |
-| | `nth <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth>`_, |
-| | `quantify <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.quantify>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Selecting | `islice_extended <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.islice_extended>`_, |
-| | `strip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strip>`_, |
-| | `lstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.lstrip>`_, |
-| | `rstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rstrip>`_, |
-| | `take <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.take>`_, |
-| | `tail <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tail>`_, |
-| | `unique_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertoo ls.unique_everseen>`_, |
-| | `unique_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_justseen>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Combinatorics | `distinct_permutations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_permutations>`_, |
-| | `circular_shifts <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.circular_shifts>`_, |
-| | `powerset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.powerset>`_, |
-| | `random_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_product>`_, |
-| | `random_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_permutation>`_, |
-| | `random_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination>`_, |
-| | `random_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination_with_replacement>`_, |
-| | `nth_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Wrapping | `always_iterable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_iterable>`_, |
-| | `consumer <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consumer>`_, |
-| | `with_iter <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.with_iter>`_, |
-| | `iter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_except>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Others | `replace <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.replace>`_, |
-| | `numeric_range <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.numeric_range>`_, |
-| | `always_reversible <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_reversible>`_, |
-| | `side_effect <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.side_effect>`_, |
-| | `iterate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iterate>`_, |
-| | `difference <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.difference>`_, |
-| | `make_decorator <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.make_decorator>`_, |
-| | `SequenceView <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.SequenceView>`_, |
-| | `consume <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consume>`_, |
-| | `accumulate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.accumulate>`_, |
-| | `tabulate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tabulate>`_, |
-| | `repeatfunc <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeatfunc>`_ |
-+------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-
-
-Getting started
-===============
-
-To get started, install the library with `pip <https://pip.pypa.io/en/stable/>`_:
-
-.. code-block:: shell
-
- pip install more-itertools
-
-The recipes from the `itertools docs <https://docs.python.org/3/library/itertools.html#itertools-recipes>`_
-are included in the top-level package:
-
-.. code-block:: python
-
- >>> from more_itertools import flatten
- >>> iterable = [(0, 1), (2, 3)]
- >>> list(flatten(iterable))
- [0, 1, 2, 3]
-
-Several new recipes are available as well:
-
-.. code-block:: python
-
- >>> from more_itertools import chunked
- >>> iterable = [0, 1, 2, 3, 4, 5, 6, 7, 8]
- >>> list(chunked(iterable, 3))
- [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
-
- >>> from more_itertools import spy
- >>> iterable = (x * x for x in range(1, 6))
- >>> head, iterable = spy(iterable, n=3)
- >>> list(head)
- [1, 4, 9]
- >>> list(iterable)
- [1, 4, 9, 16, 25]
-
-
-
-For the full listing of functions, see the `API documentation <https://more-itertools.readthedocs.io/en/latest/api.html>`_.
-
-Development
-===========
-
-``more-itertools`` is maintained by `@erikrose <https://github.com/erikrose>`_
-and `@bbayles <https://github.com/bbayles>`_, with help from `many others <https://github.com/erikrose/more-itertools/graphs/contributors>`_.
-If you have a problem or suggestion, please file a bug or pull request in this
-repository. Thanks for contributing!
diff --git a/contrib/python/more-itertools/py2/more_itertools/__init__.py b/contrib/python/more-itertools/py2/more_itertools/__init__.py
deleted file mode 100644
index bba462c3db..0000000000
--- a/contrib/python/more-itertools/py2/more_itertools/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from more_itertools.more import * # noqa
-from more_itertools.recipes import * # noqa
diff --git a/contrib/python/more-itertools/py2/more_itertools/more.py b/contrib/python/more-itertools/py2/more_itertools/more.py
deleted file mode 100644
index bd32a26130..0000000000
--- a/contrib/python/more-itertools/py2/more_itertools/more.py
+++ /dev/null
@@ -1,2333 +0,0 @@
-from __future__ import print_function
-
-from collections import Counter, defaultdict, deque
-from functools import partial, wraps
-from heapq import merge
-from itertools import (
- chain,
- compress,
- count,
- cycle,
- dropwhile,
- groupby,
- islice,
- repeat,
- starmap,
- takewhile,
- tee
-)
-from operator import itemgetter, lt, gt, sub
-from sys import maxsize, version_info
-try:
- from collections.abc import Sequence
-except ImportError:
- from collections import Sequence
-
-from six import binary_type, string_types, text_type
-from six.moves import filter, map, range, zip, zip_longest
-
-from .recipes import consume, flatten, take
-
-__all__ = [
- 'adjacent',
- 'always_iterable',
- 'always_reversible',
- 'bucket',
- 'chunked',
- 'circular_shifts',
- 'collapse',
- 'collate',
- 'consecutive_groups',
- 'consumer',
- 'count_cycle',
- 'difference',
- 'distinct_permutations',
- 'distribute',
- 'divide',
- 'exactly_n',
- 'first',
- 'groupby_transform',
- 'ilen',
- 'interleave_longest',
- 'interleave',
- 'intersperse',
- 'islice_extended',
- 'iterate',
- 'last',
- 'locate',
- 'lstrip',
- 'make_decorator',
- 'map_reduce',
- 'numeric_range',
- 'one',
- 'padded',
- 'peekable',
- 'replace',
- 'rlocate',
- 'rstrip',
- 'run_length',
- 'seekable',
- 'SequenceView',
- 'side_effect',
- 'sliced',
- 'sort_together',
- 'split_at',
- 'split_after',
- 'split_before',
- 'split_into',
- 'spy',
- 'stagger',
- 'strip',
- 'substrings',
- 'unique_to_each',
- 'unzip',
- 'windowed',
- 'with_iter',
- 'zip_offset',
-]
-
-_marker = object()
-
-
-def chunked(iterable, n):
- """Break *iterable* into lists of length *n*:
-
- >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
- [[1, 2, 3], [4, 5, 6]]
-
- If the length of *iterable* is not evenly divisible by *n*, the last
- returned list will be shorter:
-
- >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
- [[1, 2, 3], [4, 5, 6], [7, 8]]
-
- To use a fill-in value instead, see the :func:`grouper` recipe.
-
- :func:`chunked` is useful for splitting up a computation on a large number
- of keys into batches, to be pickled and sent off to worker processes. One
- example is operations on rows in MySQL, which does not implement
- server-side cursors properly and would otherwise load the entire dataset
- into RAM on the client.
-
- """
- return iter(partial(take, n, iter(iterable)), [])
-
-
-def first(iterable, default=_marker):
- """Return the first item of *iterable*, or *default* if *iterable* is
- empty.
-
- >>> first([0, 1, 2, 3])
- 0
- >>> first([], 'some default')
- 'some default'
-
- If *default* is not provided and there are no items in the iterable,
- raise ``ValueError``.
-
- :func:`first` is useful when you have a generator of expensive-to-retrieve
- values and want any arbitrary one. It is marginally shorter than
- ``next(iter(iterable), default)``.
-
- """
- try:
- return next(iter(iterable))
- except StopIteration:
- # I'm on the edge about raising ValueError instead of StopIteration. At
- # the moment, ValueError wins, because the caller could conceivably
- # want to do something different with flow control when I raise the
- # exception, and it's weird to explicitly catch StopIteration.
- if default is _marker:
- raise ValueError('first() was called on an empty iterable, and no '
- 'default value was provided.')
- return default
-
-
-def last(iterable, default=_marker):
- """Return the last item of *iterable*, or *default* if *iterable* is
- empty.
-
- >>> last([0, 1, 2, 3])
- 3
- >>> last([], 'some default')
- 'some default'
-
- If *default* is not provided and there are no items in the iterable,
- raise ``ValueError``.
- """
- try:
- try:
- # Try to access the last item directly
- return iterable[-1]
- except (TypeError, AttributeError, KeyError):
- # If not slice-able, iterate entirely using length-1 deque
- return deque(iterable, maxlen=1)[0]
- except IndexError: # If the iterable was empty
- if default is _marker:
- raise ValueError('last() was called on an empty iterable, and no '
- 'default value was provided.')
- return default
-
-
-class peekable(object):
- """Wrap an iterator to allow lookahead and prepending elements.
-
- Call :meth:`peek` on the result to get the value that will be returned
- by :func:`next`. This won't advance the iterator:
-
- >>> p = peekable(['a', 'b'])
- >>> p.peek()
- 'a'
- >>> next(p)
- 'a'
-
- Pass :meth:`peek` a default value to return that instead of raising
- ``StopIteration`` when the iterator is exhausted.
-
- >>> p = peekable([])
- >>> p.peek('hi')
- 'hi'
-
- peekables also offer a :meth:`prepend` method, which "inserts" items
- at the head of the iterable:
-
- >>> p = peekable([1, 2, 3])
- >>> p.prepend(10, 11, 12)
- >>> next(p)
- 10
- >>> p.peek()
- 11
- >>> list(p)
- [11, 12, 1, 2, 3]
-
- peekables can be indexed. Index 0 is the item that will be returned by
- :func:`next`, index 1 is the item after that, and so on:
- The values up to the given index will be cached.
-
- >>> p = peekable(['a', 'b', 'c', 'd'])
- >>> p[0]
- 'a'
- >>> p[1]
- 'b'
- >>> next(p)
- 'a'
-
- Negative indexes are supported, but be aware that they will cache the
- remaining items in the source iterator, which may require significant
- storage.
-
- To check whether a peekable is exhausted, check its truth value:
-
- >>> p = peekable(['a', 'b'])
- >>> if p: # peekable has items
- ... list(p)
- ['a', 'b']
- >>> if not p: # peekable is exhaused
- ... list(p)
- []
-
- """
- def __init__(self, iterable):
- self._it = iter(iterable)
- self._cache = deque()
-
- def __iter__(self):
- return self
-
- def __bool__(self):
- try:
- self.peek()
- except StopIteration:
- return False
- return True
-
- def __nonzero__(self):
- # For Python 2 compatibility
- return self.__bool__()
-
- def peek(self, default=_marker):
- """Return the item that will be next returned from ``next()``.
-
- Return ``default`` if there are no items left. If ``default`` is not
- provided, raise ``StopIteration``.
-
- """
- if not self._cache:
- try:
- self._cache.append(next(self._it))
- except StopIteration:
- if default is _marker:
- raise
- return default
- return self._cache[0]
-
- def prepend(self, *items):
- """Stack up items to be the next ones returned from ``next()`` or
- ``self.peek()``. The items will be returned in
- first in, first out order::
-
- >>> p = peekable([1, 2, 3])
- >>> p.prepend(10, 11, 12)
- >>> next(p)
- 10
- >>> list(p)
- [11, 12, 1, 2, 3]
-
- It is possible, by prepending items, to "resurrect" a peekable that
- previously raised ``StopIteration``.
-
- >>> p = peekable([])
- >>> next(p)
- Traceback (most recent call last):
- ...
- StopIteration
- >>> p.prepend(1)
- >>> next(p)
- 1
- >>> next(p)
- Traceback (most recent call last):
- ...
- StopIteration
-
- """
- self._cache.extendleft(reversed(items))
-
- def __next__(self):
- if self._cache:
- return self._cache.popleft()
-
- return next(self._it)
-
- next = __next__ # For Python 2 compatibility
-
- def _get_slice(self, index):
- # Normalize the slice's arguments
- step = 1 if (index.step is None) else index.step
- if step > 0:
- start = 0 if (index.start is None) else index.start
- stop = maxsize if (index.stop is None) else index.stop
- elif step < 0:
- start = -1 if (index.start is None) else index.start
- stop = (-maxsize - 1) if (index.stop is None) else index.stop
- else:
- raise ValueError('slice step cannot be zero')
-
- # If either the start or stop index is negative, we'll need to cache
- # the rest of the iterable in order to slice from the right side.
- if (start < 0) or (stop < 0):
- self._cache.extend(self._it)
- # Otherwise we'll need to find the rightmost index and cache to that
- # point.
- else:
- n = min(max(start, stop) + 1, maxsize)
- cache_len = len(self._cache)
- if n >= cache_len:
- self._cache.extend(islice(self._it, n - cache_len))
-
- return list(self._cache)[index]
-
- def __getitem__(self, index):
- if isinstance(index, slice):
- return self._get_slice(index)
-
- cache_len = len(self._cache)
- if index < 0:
- self._cache.extend(self._it)
- elif index >= cache_len:
- self._cache.extend(islice(self._it, index + 1 - cache_len))
-
- return self._cache[index]
-
-
-def _collate(*iterables, **kwargs):
- """Helper for ``collate()``, called when the user is using the ``reverse``
- or ``key`` keyword arguments on Python versions below 3.5.
-
- """
- key = kwargs.pop('key', lambda a: a)
- reverse = kwargs.pop('reverse', False)
-
- min_or_max = partial(max if reverse else min, key=itemgetter(0))
- peekables = [peekable(it) for it in iterables]
- peekables = [p for p in peekables if p] # Kill empties.
- while peekables:
- _, p = min_or_max((key(p.peek()), p) for p in peekables)
- yield next(p)
- peekables = [x for x in peekables if x]
-
-
-def collate(*iterables, **kwargs):
- """Return a sorted merge of the items from each of several already-sorted
- *iterables*.
-
- >>> list(collate('ACDZ', 'AZ', 'JKL'))
- ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z']
-
- Works lazily, keeping only the next value from each iterable in memory. Use
- :func:`collate` to, for example, perform a n-way mergesort of items that
- don't fit in memory.
-
- If a *key* function is specified, the iterables will be sorted according
- to its result:
-
- >>> key = lambda s: int(s) # Sort by numeric value, not by string
- >>> list(collate(['1', '10'], ['2', '11'], key=key))
- ['1', '2', '10', '11']
-
-
- If the *iterables* are sorted in descending order, set *reverse* to
- ``True``:
-
- >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True))
- [5, 4, 3, 2, 1, 0]
-
- If the elements of the passed-in iterables are out of order, you might get
- unexpected results.
-
- On Python 2.7, this function delegates to :func:`heapq.merge` if neither
- of the keyword arguments are specified. On Python 3.5+, this function
- is an alias for :func:`heapq.merge`.
-
- """
- if not kwargs:
- return merge(*iterables)
-
- return _collate(*iterables, **kwargs)
-
-
-# If using Python version 3.5 or greater, heapq.merge() will be faster than
-# collate - use that instead.
-if version_info >= (3, 5, 0):
- _collate_docstring = collate.__doc__
- collate = partial(merge)
- collate.__doc__ = _collate_docstring
-
-
-def consumer(func):
- """Decorator that automatically advances a PEP-342-style "reverse iterator"
- to its first yield point so you don't have to call ``next()`` on it
- manually.
-
- >>> @consumer
- ... def tally():
- ... i = 0
- ... while True:
- ... print('Thing number %s is %s.' % (i, (yield)))
- ... i += 1
- ...
- >>> t = tally()
- >>> t.send('red')
- Thing number 0 is red.
- >>> t.send('fish')
- Thing number 1 is fish.
-
- Without the decorator, you would have to call ``next(t)`` before
- ``t.send()`` could be used.
-
- """
- @wraps(func)
- def wrapper(*args, **kwargs):
- gen = func(*args, **kwargs)
- next(gen)
- return gen
- return wrapper
-
-
-def ilen(iterable):
- """Return the number of items in *iterable*.
-
- >>> ilen(x for x in range(1000000) if x % 3 == 0)
- 333334
-
- This consumes the iterable, so handle with care.
-
- """
- # This approach was selected because benchmarks showed it's likely the
- # fastest of the known implementations at the time of writing.
- # See GitHub tracker: #236, #230.
- counter = count()
- deque(zip(iterable, counter), maxlen=0)
- return next(counter)
-
-
-def iterate(func, start):
- """Return ``start``, ``func(start)``, ``func(func(start))``, ...
-
- >>> from itertools import islice
- >>> list(islice(iterate(lambda x: 2*x, 1), 10))
- [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
-
- """
- while True:
- yield start
- start = func(start)
-
-
-def with_iter(context_manager):
- """Wrap an iterable in a ``with`` statement, so it closes once exhausted.
-
- For example, this will close the file when the iterator is exhausted::
-
- upper_lines = (line.upper() for line in with_iter(open('foo')))
-
- Any context manager which returns an iterable is a candidate for
- ``with_iter``.
-
- """
- with context_manager as iterable:
- for item in iterable:
- yield item
-
-
-def one(iterable, too_short=None, too_long=None):
- """Return the first item from *iterable*, which is expected to contain only
- that item. Raise an exception if *iterable* is empty or has more than one
- item.
-
- :func:`one` is useful for ensuring that an iterable contains only one item.
- For example, it can be used to retrieve the result of a database query
- that is expected to return a single row.
-
- If *iterable* is empty, ``ValueError`` will be raised. You may specify a
- different exception with the *too_short* keyword:
-
- >>> it = []
- >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError: too many items in iterable (expected 1)'
- >>> too_short = IndexError('too few items')
- >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- IndexError: too few items
-
- Similarly, if *iterable* contains more than one item, ``ValueError`` will
- be raised. You may specify a different exception with the *too_long*
- keyword:
-
- >>> it = ['too', 'many']
- >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError: too many items in iterable (expected 1)'
- >>> too_long = RuntimeError
- >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- RuntimeError
-
- Note that :func:`one` attempts to advance *iterable* twice to ensure there
- is only one item. If there is more than one, both items will be discarded.
- See :func:`spy` or :func:`peekable` to check iterable contents less
- destructively.
-
- """
- it = iter(iterable)
-
- try:
- value = next(it)
- except StopIteration:
- raise too_short or ValueError('too few items in iterable (expected 1)')
-
- try:
- next(it)
- except StopIteration:
- pass
- else:
- raise too_long or ValueError('too many items in iterable (expected 1)')
-
- return value
-
-
-def distinct_permutations(iterable):
- """Yield successive distinct permutations of the elements in *iterable*.
-
- >>> sorted(distinct_permutations([1, 0, 1]))
- [(0, 1, 1), (1, 0, 1), (1, 1, 0)]
-
- Equivalent to ``set(permutations(iterable))``, except duplicates are not
- generated and thrown away. For larger input sequences this is much more
- efficient.
-
- Duplicate permutations arise when there are duplicated elements in the
- input iterable. The number of items returned is
- `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
- items input, and each `x_i` is the count of a distinct item in the input
- sequence.
-
- """
- def perm_unique_helper(item_counts, perm, i):
- """Internal helper function
-
- :arg item_counts: Stores the unique items in ``iterable`` and how many
- times they are repeated
- :arg perm: The permutation that is being built for output
- :arg i: The index of the permutation being modified
-
- The output permutations are built up recursively; the distinct items
- are placed until their repetitions are exhausted.
- """
- if i < 0:
- yield tuple(perm)
- else:
- for item in item_counts:
- if item_counts[item] <= 0:
- continue
- perm[i] = item
- item_counts[item] -= 1
- for x in perm_unique_helper(item_counts, perm, i - 1):
- yield x
- item_counts[item] += 1
-
- item_counts = Counter(iterable)
- length = sum(item_counts.values())
-
- return perm_unique_helper(item_counts, [None] * length, length - 1)
-
-
-def intersperse(e, iterable, n=1):
- """Intersperse filler element *e* among the items in *iterable*, leaving
- *n* items between each filler element.
-
- >>> list(intersperse('!', [1, 2, 3, 4, 5]))
- [1, '!', 2, '!', 3, '!', 4, '!', 5]
-
- >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
- [1, 2, None, 3, 4, None, 5]
-
- """
- if n == 0:
- raise ValueError('n must be > 0')
- elif n == 1:
- # interleave(repeat(e), iterable) -> e, x_0, e, e, x_1, e, x_2...
- # islice(..., 1, None) -> x_0, e, e, x_1, e, x_2...
- return islice(interleave(repeat(e), iterable), 1, None)
- else:
- # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]...
- # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
- # flatten(...) -> x_0, x_1, e, x_2, x_3...
- filler = repeat([e])
- chunks = chunked(iterable, n)
- return flatten(islice(interleave(filler, chunks), 1, None))
-
-
-def unique_to_each(*iterables):
- """Return the elements from each of the input iterables that aren't in the
- other input iterables.
-
- For example, suppose you have a set of packages, each with a set of
- dependencies::
-
- {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}
-
- If you remove one package, which dependencies can also be removed?
-
- If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
- associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for
- ``pkg_2``, and ``D`` is only needed for ``pkg_3``::
-
- >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
- [['A'], ['C'], ['D']]
-
- If there are duplicates in one input iterable that aren't in the others
- they will be duplicated in the output. Input order is preserved::
-
- >>> unique_to_each("mississippi", "missouri")
- [['p', 'p'], ['o', 'u', 'r']]
-
- It is assumed that the elements of each iterable are hashable.
-
- """
- pool = [list(it) for it in iterables]
- counts = Counter(chain.from_iterable(map(set, pool)))
- uniques = {element for element in counts if counts[element] == 1}
- return [list(filter(uniques.__contains__, it)) for it in pool]
-
-
-def windowed(seq, n, fillvalue=None, step=1):
- """Return a sliding window of width *n* over the given iterable.
-
- >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
- >>> list(all_windows)
- [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
-
- When the window is larger than the iterable, *fillvalue* is used in place
- of missing values::
-
- >>> list(windowed([1, 2, 3], 4))
- [(1, 2, 3, None)]
-
- Each window will advance in increments of *step*:
-
- >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
- [(1, 2, 3), (3, 4, 5), (5, 6, '!')]
-
- """
- if n < 0:
- raise ValueError('n must be >= 0')
- if n == 0:
- yield tuple()
- return
- if step < 1:
- raise ValueError('step must be >= 1')
-
- it = iter(seq)
- window = deque([], n)
- append = window.append
-
- # Initial deque fill
- for _ in range(n):
- append(next(it, fillvalue))
- yield tuple(window)
-
- # Appending new items to the right causes old items to fall off the left
- i = 0
- for item in it:
- append(item)
- i = (i + 1) % step
- if i % step == 0:
- yield tuple(window)
-
- # If there are items from the iterable in the window, pad with the given
- # value and emit them.
- if (i % step) and (step - i < n):
- for _ in range(step - i):
- append(fillvalue)
- yield tuple(window)
-
-
-def substrings(iterable, join_func=None):
- """Yield all of the substrings of *iterable*.
-
- >>> [''.join(s) for s in substrings('more')]
- ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']
-
- Note that non-string iterables can also be subdivided.
-
- >>> list(substrings([0, 1, 2]))
- [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]
-
- """
- # The length-1 substrings
- seq = []
- for item in iter(iterable):
- seq.append(item)
- yield (item,)
- seq = tuple(seq)
- item_count = len(seq)
-
- # And the rest
- for n in range(2, item_count + 1):
- for i in range(item_count - n + 1):
- yield seq[i:i + n]
-
-
-class bucket(object):
- """Wrap *iterable* and return an object that buckets it iterable into
- child iterables based on a *key* function.
-
- >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
- >>> s = bucket(iterable, key=lambda x: x[0])
- >>> a_iterable = s['a']
- >>> next(a_iterable)
- 'a1'
- >>> next(a_iterable)
- 'a2'
- >>> list(s['b'])
- ['b1', 'b2', 'b3']
-
- The original iterable will be advanced and its items will be cached until
- they are used by the child iterables. This may require significant storage.
-
- By default, attempting to select a bucket to which no items belong will
- exhaust the iterable and cache all values.
- If you specify a *validator* function, selected buckets will instead be
- checked against it.
-
- >>> from itertools import count
- >>> it = count(1, 2) # Infinite sequence of odd numbers
- >>> key = lambda x: x % 10 # Bucket by last digit
- >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only
- >>> s = bucket(it, key=key, validator=validator)
- >>> 2 in s
- False
- >>> list(s[2])
- []
-
- """
- def __init__(self, iterable, key, validator=None):
- self._it = iter(iterable)
- self._key = key
- self._cache = defaultdict(deque)
- self._validator = validator or (lambda x: True)
-
- def __contains__(self, value):
- if not self._validator(value):
- return False
-
- try:
- item = next(self[value])
- except StopIteration:
- return False
- else:
- self._cache[value].appendleft(item)
-
- return True
-
- def _get_values(self, value):
- """
- Helper to yield items from the parent iterator that match *value*.
- Items that don't match are stored in the local cache as they
- are encountered.
- """
- while True:
- # If we've cached some items that match the target value, emit
- # the first one and evict it from the cache.
- if self._cache[value]:
- yield self._cache[value].popleft()
- # Otherwise we need to advance the parent iterator to search for
- # a matching item, caching the rest.
- else:
- while True:
- try:
- item = next(self._it)
- except StopIteration:
- return
- item_value = self._key(item)
- if item_value == value:
- yield item
- break
- elif self._validator(item_value):
- self._cache[item_value].append(item)
-
- def __getitem__(self, value):
- if not self._validator(value):
- return iter(())
-
- return self._get_values(value)
-
-
-def spy(iterable, n=1):
- """Return a 2-tuple with a list containing the first *n* elements of
- *iterable*, and an iterator with the same items as *iterable*.
- This allows you to "look ahead" at the items in the iterable without
- advancing it.
-
- There is one item in the list by default:
-
- >>> iterable = 'abcdefg'
- >>> head, iterable = spy(iterable)
- >>> head
- ['a']
- >>> list(iterable)
- ['a', 'b', 'c', 'd', 'e', 'f', 'g']
-
- You may use unpacking to retrieve items instead of lists:
-
- >>> (head,), iterable = spy('abcdefg')
- >>> head
- 'a'
- >>> (first, second), iterable = spy('abcdefg', 2)
- >>> first
- 'a'
- >>> second
- 'b'
-
- The number of items requested can be larger than the number of items in
- the iterable:
-
- >>> iterable = [1, 2, 3, 4, 5]
- >>> head, iterable = spy(iterable, 10)
- >>> head
- [1, 2, 3, 4, 5]
- >>> list(iterable)
- [1, 2, 3, 4, 5]
-
- """
- it = iter(iterable)
- head = take(n, it)
-
- return head, chain(head, it)
-
-
-def interleave(*iterables):
- """Return a new iterable yielding from each iterable in turn,
- until the shortest is exhausted.
-
- >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
- [1, 4, 6, 2, 5, 7]
-
- For a version that doesn't terminate after the shortest iterable is
- exhausted, see :func:`interleave_longest`.
-
- """
- return chain.from_iterable(zip(*iterables))
-
-
-def interleave_longest(*iterables):
- """Return a new iterable yielding from each iterable in turn,
- skipping any that are exhausted.
-
- >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
- [1, 4, 6, 2, 5, 7, 3, 8]
-
- This function produces the same output as :func:`roundrobin`, but may
- perform better for some inputs (in particular when the number of iterables
- is large).
-
- """
- i = chain.from_iterable(zip_longest(*iterables, fillvalue=_marker))
- return (x for x in i if x is not _marker)
-
-
-def collapse(iterable, base_type=None, levels=None):
- """Flatten an iterable with multiple levels of nesting (e.g., a list of
- lists of tuples) into non-iterable types.
-
- >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
- >>> list(collapse(iterable))
- [1, 2, 3, 4, 5, 6]
-
- String types are not considered iterable and will not be collapsed.
- To avoid collapsing other types, specify *base_type*:
-
- >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
- >>> list(collapse(iterable, base_type=tuple))
- ['ab', ('cd', 'ef'), 'gh', 'ij']
-
- Specify *levels* to stop flattening after a certain level:
-
- >>> iterable = [('a', ['b']), ('c', ['d'])]
- >>> list(collapse(iterable)) # Fully flattened
- ['a', 'b', 'c', 'd']
- >>> list(collapse(iterable, levels=1)) # Only one level flattened
- ['a', ['b'], 'c', ['d']]
-
- """
- def walk(node, level):
- if (
- ((levels is not None) and (level > levels)) or
- isinstance(node, string_types) or
- ((base_type is not None) and isinstance(node, base_type))
- ):
- yield node
- return
-
- try:
- tree = iter(node)
- except TypeError:
- yield node
- return
- else:
- for child in tree:
- for x in walk(child, level + 1):
- yield x
-
- for x in walk(iterable, 0):
- yield x
-
-
-def side_effect(func, iterable, chunk_size=None, before=None, after=None):
- """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
- of items) before yielding the item.
-
- `func` must be a function that takes a single argument. Its return value
- will be discarded.
-
- *before* and *after* are optional functions that take no arguments. They
- will be executed before iteration starts and after it ends, respectively.
-
- `side_effect` can be used for logging, updating progress bars, or anything
- that is not functionally "pure."
-
- Emitting a status message:
-
- >>> from more_itertools import consume
- >>> func = lambda item: print('Received {}'.format(item))
- >>> consume(side_effect(func, range(2)))
- Received 0
- Received 1
-
- Operating on chunks of items:
-
- >>> pair_sums = []
- >>> func = lambda chunk: pair_sums.append(sum(chunk))
- >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
- [0, 1, 2, 3, 4, 5]
- >>> list(pair_sums)
- [1, 5, 9]
-
- Writing to a file-like object:
-
- >>> from io import StringIO
- >>> from more_itertools import consume
- >>> f = StringIO()
- >>> func = lambda x: print(x, file=f)
- >>> before = lambda: print(u'HEADER', file=f)
- >>> after = f.close
- >>> it = [u'a', u'b', u'c']
- >>> consume(side_effect(func, it, before=before, after=after))
- >>> f.closed
- True
-
- """
- try:
- if before is not None:
- before()
-
- if chunk_size is None:
- for item in iterable:
- func(item)
- yield item
- else:
- for chunk in chunked(iterable, chunk_size):
- func(chunk)
- for item in chunk:
- yield item
- finally:
- if after is not None:
- after()
-
-
-def sliced(seq, n):
- """Yield slices of length *n* from the sequence *seq*.
-
- >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
- [(1, 2, 3), (4, 5, 6)]
-
- If the length of the sequence is not divisible by the requested slice
- length, the last slice will be shorter.
-
- >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
- [(1, 2, 3), (4, 5, 6), (7, 8)]
-
- This function will only work for iterables that support slicing.
- For non-sliceable iterables, see :func:`chunked`.
-
- """
- return takewhile(bool, (seq[i: i + n] for i in count(0, n)))
-
-
-def split_at(iterable, pred):
- """Yield lists of items from *iterable*, where each list is delimited by
- an item where callable *pred* returns ``True``. The lists do not include
- the delimiting items.
-
- >>> list(split_at('abcdcba', lambda x: x == 'b'))
- [['a'], ['c', 'd', 'c'], ['a']]
-
- >>> list(split_at(range(10), lambda n: n % 2 == 1))
- [[0], [2], [4], [6], [8], []]
- """
- buf = []
- for item in iterable:
- if pred(item):
- yield buf
- buf = []
- else:
- buf.append(item)
- yield buf
-
-
-def split_before(iterable, pred):
- """Yield lists of items from *iterable*, where each list starts with an
- item where callable *pred* returns ``True``:
-
- >>> list(split_before('OneTwo', lambda s: s.isupper()))
- [['O', 'n', 'e'], ['T', 'w', 'o']]
-
- >>> list(split_before(range(10), lambda n: n % 3 == 0))
- [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
-
- """
- buf = []
- for item in iterable:
- if pred(item) and buf:
- yield buf
- buf = []
- buf.append(item)
- yield buf
-
-
-def split_after(iterable, pred):
- """Yield lists of items from *iterable*, where each list ends with an
- item where callable *pred* returns ``True``:
-
- >>> list(split_after('one1two2', lambda s: s.isdigit()))
- [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]
-
- >>> list(split_after(range(10), lambda n: n % 3 == 0))
- [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]
-
- """
- buf = []
- for item in iterable:
- buf.append(item)
- if pred(item) and buf:
- yield buf
- buf = []
- if buf:
- yield buf
-
-
-def split_into(iterable, sizes):
- """Yield a list of sequential items from *iterable* of length 'n' for each
- integer 'n' in *sizes*.
-
- >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
- [[1], [2, 3], [4, 5, 6]]
-
- If the sum of *sizes* is smaller than the length of *iterable*, then the
- remaining items of *iterable* will not be returned.
-
- >>> list(split_into([1,2,3,4,5,6], [2,3]))
- [[1, 2], [3, 4, 5]]
-
- If the sum of *sizes* is larger than the length of *iterable*, fewer items
- will be returned in the iteration that overruns *iterable* and further
- lists will be empty:
-
- >>> list(split_into([1,2,3,4], [1,2,3,4]))
- [[1], [2, 3], [4], []]
-
- When a ``None`` object is encountered in *sizes*, the returned list will
- contain items up to the end of *iterable* the same way that itertools.slice
- does:
-
- >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
- [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]
-
- :func:`split_into` can be useful for grouping a series of items where the
- sizes of the groups are not uniform. An example would be where in a row
- from a table, multiple columns represent elements of the same feature
- (e.g. a point represented by x,y,z) but, the format is not the same for
- all columns.
- """
- # convert the iterable argument into an iterator so its contents can
- # be consumed by islice in case it is a generator
- it = iter(iterable)
-
- for size in sizes:
- if size is None:
- yield list(it)
- return
- else:
- yield list(islice(it, size))
-
-
-def padded(iterable, fillvalue=None, n=None, next_multiple=False):
- """Yield the elements from *iterable*, followed by *fillvalue*, such that
- at least *n* items are emitted.
-
- >>> list(padded([1, 2, 3], '?', 5))
- [1, 2, 3, '?', '?']
-
- If *next_multiple* is ``True``, *fillvalue* will be emitted until the
- number of items emitted is a multiple of *n*::
-
- >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
- [1, 2, 3, 4, None, None]
-
- If *n* is ``None``, *fillvalue* will be emitted indefinitely.
-
- """
- it = iter(iterable)
- if n is None:
- for item in chain(it, repeat(fillvalue)):
- yield item
- elif n < 1:
- raise ValueError('n must be at least 1')
- else:
- item_count = 0
- for item in it:
- yield item
- item_count += 1
-
- remaining = (n - item_count) % n if next_multiple else n - item_count
- for _ in range(remaining):
- yield fillvalue
-
-
-def distribute(n, iterable):
- """Distribute the items from *iterable* among *n* smaller iterables.
-
- >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
- >>> list(group_1)
- [1, 3, 5]
- >>> list(group_2)
- [2, 4, 6]
-
- If the length of *iterable* is not evenly divisible by *n*, then the
- length of the returned iterables will not be identical:
-
- >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
- >>> [list(c) for c in children]
- [[1, 4, 7], [2, 5], [3, 6]]
-
- If the length of *iterable* is smaller than *n*, then the last returned
- iterables will be empty:
-
- >>> children = distribute(5, [1, 2, 3])
- >>> [list(c) for c in children]
- [[1], [2], [3], [], []]
-
- This function uses :func:`itertools.tee` and may require significant
- storage. If you need the order items in the smaller iterables to match the
- original iterable, see :func:`divide`.
-
- """
- if n < 1:
- raise ValueError('n must be at least 1')
-
- children = tee(iterable, n)
- return [islice(it, index, None, n) for index, it in enumerate(children)]
-
-
-def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
- """Yield tuples whose elements are offset from *iterable*.
- The amount by which the `i`-th item in each tuple is offset is given by
- the `i`-th item in *offsets*.
-
- >>> list(stagger([0, 1, 2, 3]))
- [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
- >>> list(stagger(range(8), offsets=(0, 2, 4)))
- [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]
-
- By default, the sequence will end when the final element of a tuple is the
- last item in the iterable. To continue until the first element of a tuple
- is the last item in the iterable, set *longest* to ``True``::
-
- >>> list(stagger([0, 1, 2, 3], longest=True))
- [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]
-
- By default, ``None`` will be used to replace offsets beyond the end of the
- sequence. Specify *fillvalue* to use some other value.
-
- """
- children = tee(iterable, len(offsets))
-
- return zip_offset(
- *children, offsets=offsets, longest=longest, fillvalue=fillvalue
- )
-
-
-def zip_offset(*iterables, **kwargs):
- """``zip`` the input *iterables* together, but offset the `i`-th iterable
- by the `i`-th item in *offsets*.
-
- >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
- [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]
-
- This can be used as a lightweight alternative to SciPy or pandas to analyze
- data sets in which some series have a lead or lag relationship.
-
- By default, the sequence will end when the shortest iterable is exhausted.
- To continue until the longest iterable is exhausted, set *longest* to
- ``True``.
-
- >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
- [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]
-
- By default, ``None`` will be used to replace offsets beyond the end of the
- sequence. Specify *fillvalue* to use some other value.
-
- """
- offsets = kwargs['offsets']
- longest = kwargs.get('longest', False)
- fillvalue = kwargs.get('fillvalue', None)
-
- if len(iterables) != len(offsets):
- raise ValueError("Number of iterables and offsets didn't match")
-
- staggered = []
- for it, n in zip(iterables, offsets):
- if n < 0:
- staggered.append(chain(repeat(fillvalue, -n), it))
- elif n > 0:
- staggered.append(islice(it, n, None))
- else:
- staggered.append(it)
-
- if longest:
- return zip_longest(*staggered, fillvalue=fillvalue)
-
- return zip(*staggered)
-
-
-def sort_together(iterables, key_list=(0,), reverse=False):
- """Return the input iterables sorted together, with *key_list* as the
- priority for sorting. All iterables are trimmed to the length of the
- shortest one.
-
- This can be used like the sorting function in a spreadsheet. If each
- iterable represents a column of data, the key list determines which
- columns are used for sorting.
-
- By default, all iterables are sorted using the ``0``-th iterable::
-
- >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
- >>> sort_together(iterables)
- [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
-
- Set a different key list to sort according to another iterable.
- Specifying multiple keys dictates how ties are broken::
-
- >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
- >>> sort_together(iterables, key_list=(1, 2))
- [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
-
- Set *reverse* to ``True`` to sort in descending order.
-
- >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
- [(3, 2, 1), ('a', 'b', 'c')]
-
- """
- return list(zip(*sorted(zip(*iterables),
- key=itemgetter(*key_list),
- reverse=reverse)))
-
-
-def unzip(iterable):
- """The inverse of :func:`zip`, this function disaggregates the elements
- of the zipped *iterable*.
-
- The ``i``-th iterable contains the ``i``-th element from each element
- of the zipped iterable. The first element is used to to determine the
- length of the remaining elements.
-
- >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
- >>> letters, numbers = unzip(iterable)
- >>> list(letters)
- ['a', 'b', 'c', 'd']
- >>> list(numbers)
- [1, 2, 3, 4]
-
- This is similar to using ``zip(*iterable)``, but it avoids reading
- *iterable* into memory. Note, however, that this function uses
- :func:`itertools.tee` and thus may require significant storage.
-
- """
- head, iterable = spy(iter(iterable))
- if not head:
- # empty iterable, e.g. zip([], [], [])
- return ()
- # spy returns a one-length iterable as head
- head = head[0]
- iterables = tee(iterable, len(head))
-
- def itemgetter(i):
- def getter(obj):
- try:
- return obj[i]
- except IndexError:
- # basically if we have an iterable like
- # iter([(1, 2, 3), (4, 5), (6,)])
- # the second unzipped iterable would fail at the third tuple
- # since it would try to access tup[1]
- # same with the third unzipped iterable and the second tuple
- # to support these "improperly zipped" iterables,
- # we create a custom itemgetter
- # which just stops the unzipped iterables
- # at first length mismatch
- raise StopIteration
- return getter
-
- return tuple(map(itemgetter(i), it) for i, it in enumerate(iterables))
-
-
-def divide(n, iterable):
- """Divide the elements from *iterable* into *n* parts, maintaining
- order.
-
- >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
- >>> list(group_1)
- [1, 2, 3]
- >>> list(group_2)
- [4, 5, 6]
-
- If the length of *iterable* is not evenly divisible by *n*, then the
- length of the returned iterables will not be identical:
-
- >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
- >>> [list(c) for c in children]
- [[1, 2, 3], [4, 5], [6, 7]]
-
- If the length of the iterable is smaller than n, then the last returned
- iterables will be empty:
-
- >>> children = divide(5, [1, 2, 3])
- >>> [list(c) for c in children]
- [[1], [2], [3], [], []]
-
- This function will exhaust the iterable before returning and may require
- significant storage. If order is not important, see :func:`distribute`,
- which does not first pull the iterable into memory.
-
- """
- if n < 1:
- raise ValueError('n must be at least 1')
-
- seq = tuple(iterable)
- q, r = divmod(len(seq), n)
-
- ret = []
- for i in range(n):
- start = (i * q) + (i if i < r else r)
- stop = ((i + 1) * q) + (i + 1 if i + 1 < r else r)
- ret.append(iter(seq[start:stop]))
-
- return ret
-
-
-def always_iterable(obj, base_type=(text_type, binary_type)):
- """If *obj* is iterable, return an iterator over its items::
-
- >>> obj = (1, 2, 3)
- >>> list(always_iterable(obj))
- [1, 2, 3]
-
- If *obj* is not iterable, return a one-item iterable containing *obj*::
-
- >>> obj = 1
- >>> list(always_iterable(obj))
- [1]
-
- If *obj* is ``None``, return an empty iterable:
-
- >>> obj = None
- >>> list(always_iterable(None))
- []
-
- By default, binary and text strings are not considered iterable::
-
- >>> obj = 'foo'
- >>> list(always_iterable(obj))
- ['foo']
-
- If *base_type* is set, objects for which ``isinstance(obj, base_type)``
- returns ``True`` won't be considered iterable.
-
- >>> obj = {'a': 1}
- >>> list(always_iterable(obj)) # Iterate over the dict's keys
- ['a']
- >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit
- [{'a': 1}]
-
- Set *base_type* to ``None`` to avoid any special handling and treat objects
- Python considers iterable as iterable:
-
- >>> obj = 'foo'
- >>> list(always_iterable(obj, base_type=None))
- ['f', 'o', 'o']
- """
- if obj is None:
- return iter(())
-
- if (base_type is not None) and isinstance(obj, base_type):
- return iter((obj,))
-
- try:
- return iter(obj)
- except TypeError:
- return iter((obj,))
-
-
-def adjacent(predicate, iterable, distance=1):
- """Return an iterable over `(bool, item)` tuples where the `item` is
- drawn from *iterable* and the `bool` indicates whether
- that item satisfies the *predicate* or is adjacent to an item that does.
-
- For example, to find whether items are adjacent to a ``3``::
-
- >>> list(adjacent(lambda x: x == 3, range(6)))
- [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]
-
- Set *distance* to change what counts as adjacent. For example, to find
- whether items are two places away from a ``3``:
-
- >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
- [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]
-
- This is useful for contextualizing the results of a search function.
- For example, a code comparison tool might want to identify lines that
- have changed, but also surrounding lines to give the viewer of the diff
- context.
-
- The predicate function will only be called once for each item in the
- iterable.
-
- See also :func:`groupby_transform`, which can be used with this function
- to group ranges of items with the same `bool` value.
-
- """
- # Allow distance=0 mainly for testing that it reproduces results with map()
- if distance < 0:
- raise ValueError('distance must be at least 0')
-
- i1, i2 = tee(iterable)
- padding = [False] * distance
- selected = chain(padding, map(predicate, i1), padding)
- adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1))
- return zip(adjacent_to_selected, i2)
-
-
-def groupby_transform(iterable, keyfunc=None, valuefunc=None):
- """An extension of :func:`itertools.groupby` that transforms the values of
- *iterable* after grouping them.
- *keyfunc* is a function used to compute a grouping key for each item.
- *valuefunc* is a function for transforming the items after grouping.
-
- >>> iterable = 'AaaABbBCcA'
- >>> keyfunc = lambda x: x.upper()
- >>> valuefunc = lambda x: x.lower()
- >>> grouper = groupby_transform(iterable, keyfunc, valuefunc)
- >>> [(k, ''.join(g)) for k, g in grouper]
- [('A', 'aaaa'), ('B', 'bbb'), ('C', 'cc'), ('A', 'a')]
-
- *keyfunc* and *valuefunc* default to identity functions if they are not
- specified.
-
- :func:`groupby_transform` is useful when grouping elements of an iterable
- using a separate iterable as the key. To do this, :func:`zip` the iterables
- and pass a *keyfunc* that extracts the first element and a *valuefunc*
- that extracts the second element::
-
- >>> from operator import itemgetter
- >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
- >>> values = 'abcdefghi'
- >>> iterable = zip(keys, values)
- >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
- >>> [(k, ''.join(g)) for k, g in grouper]
- [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]
-
- Note that the order of items in the iterable is significant.
- Only adjacent items are grouped together, so if you don't want any
- duplicate groups, you should sort the iterable by the key function.
-
- """
- valuefunc = (lambda x: x) if valuefunc is None else valuefunc
- return ((k, map(valuefunc, g)) for k, g in groupby(iterable, keyfunc))
-
-
-def numeric_range(*args):
- """An extension of the built-in ``range()`` function whose arguments can
- be any orderable numeric type.
-
- With only *stop* specified, *start* defaults to ``0`` and *step*
- defaults to ``1``. The output items will match the type of *stop*:
-
- >>> list(numeric_range(3.5))
- [0.0, 1.0, 2.0, 3.0]
-
- With only *start* and *stop* specified, *step* defaults to ``1``. The
- output items will match the type of *start*:
-
- >>> from decimal import Decimal
- >>> start = Decimal('2.1')
- >>> stop = Decimal('5.1')
- >>> list(numeric_range(start, stop))
- [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]
-
- With *start*, *stop*, and *step* specified the output items will match
- the type of ``start + step``:
-
- >>> from fractions import Fraction
- >>> start = Fraction(1, 2) # Start at 1/2
- >>> stop = Fraction(5, 2) # End at 5/2
- >>> step = Fraction(1, 2) # Count by 1/2
- >>> list(numeric_range(start, stop, step))
- [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]
-
- If *step* is zero, ``ValueError`` is raised. Negative steps are supported:
-
- >>> list(numeric_range(3, -1, -1.0))
- [3.0, 2.0, 1.0, 0.0]
-
- Be aware of the limitations of floating point numbers; the representation
- of the yielded numbers may be surprising.
-
- """
- argc = len(args)
- if argc == 1:
- stop, = args
- start = type(stop)(0)
- step = 1
- elif argc == 2:
- start, stop = args
- step = 1
- elif argc == 3:
- start, stop, step = args
- else:
- err_msg = 'numeric_range takes at most 3 arguments, got {}'
- raise TypeError(err_msg.format(argc))
-
- values = (start + (step * n) for n in count())
- if step > 0:
- return takewhile(partial(gt, stop), values)
- elif step < 0:
- return takewhile(partial(lt, stop), values)
- else:
- raise ValueError('numeric_range arg 3 must not be zero')
-
-
-def count_cycle(iterable, n=None):
- """Cycle through the items from *iterable* up to *n* times, yielding
- the number of completed cycles along with each item. If *n* is omitted the
- process repeats indefinitely.
-
- >>> list(count_cycle('AB', 3))
- [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]
-
- """
- iterable = tuple(iterable)
- if not iterable:
- return iter(())
- counter = count() if n is None else range(n)
- return ((i, item) for i in counter for item in iterable)
-
-
-def locate(iterable, pred=bool, window_size=None):
- """Yield the index of each item in *iterable* for which *pred* returns
- ``True``.
-
- *pred* defaults to :func:`bool`, which will select truthy items:
-
- >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
- [1, 2, 4]
-
- Set *pred* to a custom function to, e.g., find the indexes for a particular
- item.
-
- >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
- [1, 3]
-
- If *window_size* is given, then the *pred* function will be called with
- that many items. This enables searching for sub-sequences:
-
- >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
- >>> pred = lambda *args: args == (1, 2, 3)
- >>> list(locate(iterable, pred=pred, window_size=3))
- [1, 5, 9]
-
- Use with :func:`seekable` to find indexes and then retrieve the associated
- items:
-
- >>> from itertools import count
- >>> from more_itertools import seekable
- >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
- >>> it = seekable(source)
- >>> pred = lambda x: x > 100
- >>> indexes = locate(it, pred=pred)
- >>> i = next(indexes)
- >>> it.seek(i)
- >>> next(it)
- 106
-
- """
- if window_size is None:
- return compress(count(), map(pred, iterable))
-
- if window_size < 1:
- raise ValueError('window size must be at least 1')
-
- it = windowed(iterable, window_size, fillvalue=_marker)
- return compress(count(), starmap(pred, it))
-
-
-def lstrip(iterable, pred):
- """Yield the items from *iterable*, but strip any from the beginning
- for which *pred* returns ``True``.
-
- For example, to remove a set of items from the start of an iterable:
-
- >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
- >>> pred = lambda x: x in {None, False, ''}
- >>> list(lstrip(iterable, pred))
- [1, 2, None, 3, False, None]
-
- This function is analogous to to :func:`str.lstrip`, and is essentially
- an wrapper for :func:`itertools.dropwhile`.
-
- """
- return dropwhile(pred, iterable)
-
-
-def rstrip(iterable, pred):
- """Yield the items from *iterable*, but strip any from the end
- for which *pred* returns ``True``.
-
- For example, to remove a set of items from the end of an iterable:
-
- >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
- >>> pred = lambda x: x in {None, False, ''}
- >>> list(rstrip(iterable, pred))
- [None, False, None, 1, 2, None, 3]
-
- This function is analogous to :func:`str.rstrip`.
-
- """
- cache = []
- cache_append = cache.append
- for x in iterable:
- if pred(x):
- cache_append(x)
- else:
- for y in cache:
- yield y
- del cache[:]
- yield x
-
-
-def strip(iterable, pred):
- """Yield the items from *iterable*, but strip any from the
- beginning and end for which *pred* returns ``True``.
-
- For example, to remove a set of items from both ends of an iterable:
-
- >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
- >>> pred = lambda x: x in {None, False, ''}
- >>> list(strip(iterable, pred))
- [1, 2, None, 3]
-
- This function is analogous to :func:`str.strip`.
-
- """
- return rstrip(lstrip(iterable, pred), pred)
-
-
-def islice_extended(iterable, *args):
- """An extension of :func:`itertools.islice` that supports negative values
- for *stop*, *start*, and *step*.
-
- >>> iterable = iter('abcdefgh')
- >>> list(islice_extended(iterable, -4, -1))
- ['e', 'f', 'g']
-
- Slices with negative values require some caching of *iterable*, but this
- function takes care to minimize the amount of memory required.
-
- For example, you can use a negative step with an infinite iterator:
-
- >>> from itertools import count
- >>> list(islice_extended(count(), 110, 99, -2))
- [110, 108, 106, 104, 102, 100]
-
- """
- s = slice(*args)
- start = s.start
- stop = s.stop
- if s.step == 0:
- raise ValueError('step argument must be a non-zero integer or None.')
- step = s.step or 1
-
- it = iter(iterable)
-
- if step > 0:
- start = 0 if (start is None) else start
-
- if (start < 0):
- # Consume all but the last -start items
- cache = deque(enumerate(it, 1), maxlen=-start)
- len_iter = cache[-1][0] if cache else 0
-
- # Adjust start to be positive
- i = max(len_iter + start, 0)
-
- # Adjust stop to be positive
- if stop is None:
- j = len_iter
- elif stop >= 0:
- j = min(stop, len_iter)
- else:
- j = max(len_iter + stop, 0)
-
- # Slice the cache
- n = j - i
- if n <= 0:
- return
-
- for index, item in islice(cache, 0, n, step):
- yield item
- elif (stop is not None) and (stop < 0):
- # Advance to the start position
- next(islice(it, start, start), None)
-
- # When stop is negative, we have to carry -stop items while
- # iterating
- cache = deque(islice(it, -stop), maxlen=-stop)
-
- for index, item in enumerate(it):
- cached_item = cache.popleft()
- if index % step == 0:
- yield cached_item
- cache.append(item)
- else:
- # When both start and stop are positive we have the normal case
- for item in islice(it, start, stop, step):
- yield item
- else:
- start = -1 if (start is None) else start
-
- if (stop is not None) and (stop < 0):
- # Consume all but the last items
- n = -stop - 1
- cache = deque(enumerate(it, 1), maxlen=n)
- len_iter = cache[-1][0] if cache else 0
-
- # If start and stop are both negative they are comparable and
- # we can just slice. Otherwise we can adjust start to be negative
- # and then slice.
- if start < 0:
- i, j = start, stop
- else:
- i, j = min(start - len_iter, -1), None
-
- for index, item in list(cache)[i:j:step]:
- yield item
- else:
- # Advance to the stop position
- if stop is not None:
- m = stop + 1
- next(islice(it, m, m), None)
-
- # stop is positive, so if start is negative they are not comparable
- # and we need the rest of the items.
- if start < 0:
- i = start
- n = None
- # stop is None and start is positive, so we just need items up to
- # the start index.
- elif stop is None:
- i = None
- n = start + 1
- # Both stop and start are positive, so they are comparable.
- else:
- i = None
- n = start - stop
- if n <= 0:
- return
-
- cache = list(islice(it, n))
-
- for item in cache[i::step]:
- yield item
-
-
-def always_reversible(iterable):
- """An extension of :func:`reversed` that supports all iterables, not
- just those which implement the ``Reversible`` or ``Sequence`` protocols.
-
- >>> print(*always_reversible(x for x in range(3)))
- 2 1 0
-
- If the iterable is already reversible, this function returns the
- result of :func:`reversed()`. If the iterable is not reversible,
- this function will cache the remaining items in the iterable and
- yield them in reverse order, which may require significant storage.
- """
- try:
- return reversed(iterable)
- except TypeError:
- return reversed(list(iterable))
-
-
-def consecutive_groups(iterable, ordering=lambda x: x):
- """Yield groups of consecutive items using :func:`itertools.groupby`.
- The *ordering* function determines whether two items are adjacent by
- returning their position.
-
- By default, the ordering function is the identity function. This is
- suitable for finding runs of numbers:
-
- >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
- >>> for group in consecutive_groups(iterable):
- ... print(list(group))
- [1]
- [10, 11, 12]
- [20]
- [30, 31, 32, 33]
- [40]
-
- For finding runs of adjacent letters, try using the :meth:`index` method
- of a string of letters:
-
- >>> from string import ascii_lowercase
- >>> iterable = 'abcdfgilmnop'
- >>> ordering = ascii_lowercase.index
- >>> for group in consecutive_groups(iterable, ordering):
- ... print(list(group))
- ['a', 'b', 'c', 'd']
- ['f', 'g']
- ['i']
- ['l', 'm', 'n', 'o', 'p']
-
- """
- for k, g in groupby(
- enumerate(iterable), key=lambda x: x[0] - ordering(x[1])
- ):
- yield map(itemgetter(1), g)
-
-
-def difference(iterable, func=sub):
- """By default, compute the first difference of *iterable* using
- :func:`operator.sub`.
-
- >>> iterable = [0, 1, 3, 6, 10]
- >>> list(difference(iterable))
- [0, 1, 2, 3, 4]
-
- This is the opposite of :func:`accumulate`'s default behavior:
-
- >>> from more_itertools import accumulate
- >>> iterable = [0, 1, 2, 3, 4]
- >>> list(accumulate(iterable))
- [0, 1, 3, 6, 10]
- >>> list(difference(accumulate(iterable)))
- [0, 1, 2, 3, 4]
-
- By default *func* is :func:`operator.sub`, but other functions can be
- specified. They will be applied as follows::
-
- A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...
-
- For example, to do progressive division:
-
- >>> iterable = [1, 2, 6, 24, 120] # Factorial sequence
- >>> func = lambda x, y: x // y
- >>> list(difference(iterable, func))
- [1, 2, 3, 4, 5]
-
- """
- a, b = tee(iterable)
- try:
- item = next(b)
- except StopIteration:
- return iter([])
- return chain([item], map(lambda x: func(x[1], x[0]), zip(a, b)))
-
-
-class SequenceView(Sequence):
- """Return a read-only view of the sequence object *target*.
-
- :class:`SequenceView` objects are analogous to Python's built-in
- "dictionary view" types. They provide a dynamic view of a sequence's items,
- meaning that when the sequence updates, so does the view.
-
- >>> seq = ['0', '1', '2']
- >>> view = SequenceView(seq)
- >>> view
- SequenceView(['0', '1', '2'])
- >>> seq.append('3')
- >>> view
- SequenceView(['0', '1', '2', '3'])
-
- Sequence views support indexing, slicing, and length queries. They act
- like the underlying sequence, except they don't allow assignment:
-
- >>> view[1]
- '1'
- >>> view[1:-1]
- ['1', '2']
- >>> len(view)
- 4
-
- Sequence views are useful as an alternative to copying, as they don't
- require (much) extra storage.
-
- """
- def __init__(self, target):
- if not isinstance(target, Sequence):
- raise TypeError
- self._target = target
-
- def __getitem__(self, index):
- return self._target[index]
-
- def __len__(self):
- return len(self._target)
-
- def __repr__(self):
- return '{}({})'.format(self.__class__.__name__, repr(self._target))
-
-
-class seekable(object):
- """Wrap an iterator to allow for seeking backward and forward. This
- progressively caches the items in the source iterable so they can be
- re-visited.
-
- Call :meth:`seek` with an index to seek to that position in the source
- iterable.
-
- To "reset" an iterator, seek to ``0``:
-
- >>> from itertools import count
- >>> it = seekable((str(n) for n in count()))
- >>> next(it), next(it), next(it)
- ('0', '1', '2')
- >>> it.seek(0)
- >>> next(it), next(it), next(it)
- ('0', '1', '2')
- >>> next(it)
- '3'
-
- You can also seek forward:
-
- >>> it = seekable((str(n) for n in range(20)))
- >>> it.seek(10)
- >>> next(it)
- '10'
- >>> it.seek(20) # Seeking past the end of the source isn't a problem
- >>> list(it)
- []
- >>> it.seek(0) # Resetting works even after hitting the end
- >>> next(it), next(it), next(it)
- ('0', '1', '2')
-
- The cache grows as the source iterable progresses, so beware of wrapping
- very large or infinite iterables.
-
- You may view the contents of the cache with the :meth:`elements` method.
- That returns a :class:`SequenceView`, a view that updates automatically:
-
- >>> it = seekable((str(n) for n in range(10)))
- >>> next(it), next(it), next(it)
- ('0', '1', '2')
- >>> elements = it.elements()
- >>> elements
- SequenceView(['0', '1', '2'])
- >>> next(it)
- '3'
- >>> elements
- SequenceView(['0', '1', '2', '3'])
-
- """
-
- def __init__(self, iterable):
- self._source = iter(iterable)
- self._cache = []
- self._index = None
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if self._index is not None:
- try:
- item = self._cache[self._index]
- except IndexError:
- self._index = None
- else:
- self._index += 1
- return item
-
- item = next(self._source)
- self._cache.append(item)
- return item
-
- next = __next__
-
- def elements(self):
- return SequenceView(self._cache)
-
- def seek(self, index):
- self._index = index
- remainder = index - len(self._cache)
- if remainder > 0:
- consume(self, remainder)
-
-
-class run_length(object):
- """
- :func:`run_length.encode` compresses an iterable with run-length encoding.
- It yields groups of repeated items with the count of how many times they
- were repeated:
-
- >>> uncompressed = 'abbcccdddd'
- >>> list(run_length.encode(uncompressed))
- [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
-
- :func:`run_length.decode` decompresses an iterable that was previously
- compressed with run-length encoding. It yields the items of the
- decompressed iterable:
-
- >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
- >>> list(run_length.decode(compressed))
- ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']
-
- """
-
- @staticmethod
- def encode(iterable):
- return ((k, ilen(g)) for k, g in groupby(iterable))
-
- @staticmethod
- def decode(iterable):
- return chain.from_iterable(repeat(k, n) for k, n in iterable)
-
-
-def exactly_n(iterable, n, predicate=bool):
- """Return ``True`` if exactly ``n`` items in the iterable are ``True``
- according to the *predicate* function.
-
- >>> exactly_n([True, True, False], 2)
- True
- >>> exactly_n([True, True, False], 1)
- False
- >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
- True
-
- The iterable will be advanced until ``n + 1`` truthy items are encountered,
- so avoid calling it on infinite iterables.
-
- """
- return len(take(n + 1, filter(predicate, iterable))) == n
-
-
-def circular_shifts(iterable):
- """Return a list of circular shifts of *iterable*.
-
- >>> circular_shifts(range(4))
- [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]
- """
- lst = list(iterable)
- return take(len(lst), windowed(cycle(lst), len(lst)))
-
-
-def make_decorator(wrapping_func, result_index=0):
- """Return a decorator version of *wrapping_func*, which is a function that
- modifies an iterable. *result_index* is the position in that function's
- signature where the iterable goes.
-
- This lets you use itertools on the "production end," i.e. at function
- definition. This can augment what the function returns without changing the
- function's code.
-
- For example, to produce a decorator version of :func:`chunked`:
-
- >>> from more_itertools import chunked
- >>> chunker = make_decorator(chunked, result_index=0)
- >>> @chunker(3)
- ... def iter_range(n):
- ... return iter(range(n))
- ...
- >>> list(iter_range(9))
- [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
-
- To only allow truthy items to be returned:
-
- >>> truth_serum = make_decorator(filter, result_index=1)
- >>> @truth_serum(bool)
- ... def boolean_test():
- ... return [0, 1, '', ' ', False, True]
- ...
- >>> list(boolean_test())
- [1, ' ', True]
-
- The :func:`peekable` and :func:`seekable` wrappers make for practical
- decorators:
-
- >>> from more_itertools import peekable
- >>> peekable_function = make_decorator(peekable)
- >>> @peekable_function()
- ... def str_range(*args):
- ... return (str(x) for x in range(*args))
- ...
- >>> it = str_range(1, 20, 2)
- >>> next(it), next(it), next(it)
- ('1', '3', '5')
- >>> it.peek()
- '7'
- >>> next(it)
- '7'
-
- """
- # See https://sites.google.com/site/bbayles/index/decorator_factory for
- # notes on how this works.
- def decorator(*wrapping_args, **wrapping_kwargs):
- def outer_wrapper(f):
- def inner_wrapper(*args, **kwargs):
- result = f(*args, **kwargs)
- wrapping_args_ = list(wrapping_args)
- wrapping_args_.insert(result_index, result)
- return wrapping_func(*wrapping_args_, **wrapping_kwargs)
-
- return inner_wrapper
-
- return outer_wrapper
-
- return decorator
-
-
-def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
- """Return a dictionary that maps the items in *iterable* to categories
- defined by *keyfunc*, transforms them with *valuefunc*, and
- then summarizes them by category with *reducefunc*.
-
- *valuefunc* defaults to the identity function if it is unspecified.
- If *reducefunc* is unspecified, no summarization takes place:
-
- >>> keyfunc = lambda x: x.upper()
- >>> result = map_reduce('abbccc', keyfunc)
- >>> sorted(result.items())
- [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]
-
- Specifying *valuefunc* transforms the categorized items:
-
- >>> keyfunc = lambda x: x.upper()
- >>> valuefunc = lambda x: 1
- >>> result = map_reduce('abbccc', keyfunc, valuefunc)
- >>> sorted(result.items())
- [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]
-
- Specifying *reducefunc* summarizes the categorized items:
-
- >>> keyfunc = lambda x: x.upper()
- >>> valuefunc = lambda x: 1
- >>> reducefunc = sum
- >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
- >>> sorted(result.items())
- [('A', 1), ('B', 2), ('C', 3)]
-
- You may want to filter the input iterable before applying the map/reduce
- procedure:
-
- >>> all_items = range(30)
- >>> items = [x for x in all_items if 10 <= x <= 20] # Filter
- >>> keyfunc = lambda x: x % 2 # Evens map to 0; odds to 1
- >>> categories = map_reduce(items, keyfunc=keyfunc)
- >>> sorted(categories.items())
- [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
- >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
- >>> sorted(summaries.items())
- [(0, 90), (1, 75)]
-
- Note that all items in the iterable are gathered into a list before the
- summarization step, which may require significant storage.
-
- The returned object is a :obj:`collections.defaultdict` with the
- ``default_factory`` set to ``None``, such that it behaves like a normal
- dictionary.
-
- """
- valuefunc = (lambda x: x) if (valuefunc is None) else valuefunc
-
- ret = defaultdict(list)
- for item in iterable:
- key = keyfunc(item)
- value = valuefunc(item)
- ret[key].append(value)
-
- if reducefunc is not None:
- for key, value_list in ret.items():
- ret[key] = reducefunc(value_list)
-
- ret.default_factory = None
- return ret
-
-
-def rlocate(iterable, pred=bool, window_size=None):
- """Yield the index of each item in *iterable* for which *pred* returns
- ``True``, starting from the right and moving left.
-
- *pred* defaults to :func:`bool`, which will select truthy items:
-
- >>> list(rlocate([0, 1, 1, 0, 1, 0, 0])) # Truthy at 1, 2, and 4
- [4, 2, 1]
-
- Set *pred* to a custom function to, e.g., find the indexes for a particular
- item:
-
- >>> iterable = iter('abcb')
- >>> pred = lambda x: x == 'b'
- >>> list(rlocate(iterable, pred))
- [3, 1]
-
- If *window_size* is given, then the *pred* function will be called with
- that many items. This enables searching for sub-sequences:
-
- >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
- >>> pred = lambda *args: args == (1, 2, 3)
- >>> list(rlocate(iterable, pred=pred, window_size=3))
- [9, 5, 1]
-
- Beware, this function won't return anything for infinite iterables.
- If *iterable* is reversible, ``rlocate`` will reverse it and search from
- the right. Otherwise, it will search from the left and return the results
- in reverse order.
-
- See :func:`locate` to for other example applications.
-
- """
- if window_size is None:
- try:
- len_iter = len(iterable)
- return (
- len_iter - i - 1 for i in locate(reversed(iterable), pred)
- )
- except TypeError:
- pass
-
- return reversed(list(locate(iterable, pred, window_size)))
-
-
-def replace(iterable, pred, substitutes, count=None, window_size=1):
- """Yield the items from *iterable*, replacing the items for which *pred*
- returns ``True`` with the items from the iterable *substitutes*.
-
- >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
- >>> pred = lambda x: x == 0
- >>> substitutes = (2, 3)
- >>> list(replace(iterable, pred, substitutes))
- [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]
-
- If *count* is given, the number of replacements will be limited:
-
- >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
- >>> pred = lambda x: x == 0
- >>> substitutes = [None]
- >>> list(replace(iterable, pred, substitutes, count=2))
- [1, 1, None, 1, 1, None, 1, 1, 0]
-
- Use *window_size* to control the number of items passed as arguments to
- *pred*. This allows for locating and replacing subsequences.
-
- >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
- >>> window_size = 3
- >>> pred = lambda *args: args == (0, 1, 2) # 3 items passed to pred
- >>> substitutes = [3, 4] # Splice in these items
- >>> list(replace(iterable, pred, substitutes, window_size=window_size))
- [3, 4, 5, 3, 4, 5]
-
- """
- if window_size < 1:
- raise ValueError('window_size must be at least 1')
-
- # Save the substitutes iterable, since it's used more than once
- substitutes = tuple(substitutes)
-
- # Add padding such that the number of windows matches the length of the
- # iterable
- it = chain(iterable, [_marker] * (window_size - 1))
- windows = windowed(it, window_size)
-
- n = 0
- for w in windows:
- # If the current window matches our predicate (and we haven't hit
- # our maximum number of replacements), splice in the substitutes
- # and then consume the following windows that overlap with this one.
- # For example, if the iterable is (0, 1, 2, 3, 4...)
- # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
- # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
- if pred(*w):
- if (count is None) or (n < count):
- n += 1
- for s in substitutes:
- yield s
- consume(windows, window_size - 1)
- continue
-
- # If there was no match (or we've reached the replacement limit),
- # yield the first item from the window.
- if w and (w[0] is not _marker):
- yield w[0]
diff --git a/contrib/python/more-itertools/py2/more_itertools/recipes.py b/contrib/python/more-itertools/py2/more_itertools/recipes.py
deleted file mode 100644
index 3b455d4eb8..0000000000
--- a/contrib/python/more-itertools/py2/more_itertools/recipes.py
+++ /dev/null
@@ -1,577 +0,0 @@
-"""Imported from the recipes section of the itertools documentation.
-
-All functions taken from the recipes section of the itertools library docs
-[1]_.
-Some backward-compatible usability improvements have been made.
-
-.. [1] http://docs.python.org/library/itertools.html#recipes
-
-"""
-from collections import deque
-from itertools import (
- chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee
-)
-import operator
-from random import randrange, sample, choice
-
-from six import PY2
-from six.moves import filter, filterfalse, map, range, zip, zip_longest
-
-__all__ = [
- 'accumulate',
- 'all_equal',
- 'consume',
- 'dotproduct',
- 'first_true',
- 'flatten',
- 'grouper',
- 'iter_except',
- 'ncycles',
- 'nth',
- 'nth_combination',
- 'padnone',
- 'pairwise',
- 'partition',
- 'powerset',
- 'prepend',
- 'quantify',
- 'random_combination_with_replacement',
- 'random_combination',
- 'random_permutation',
- 'random_product',
- 'repeatfunc',
- 'roundrobin',
- 'tabulate',
- 'tail',
- 'take',
- 'unique_everseen',
- 'unique_justseen',
-]
-
-
-def accumulate(iterable, func=operator.add):
- """
- Return an iterator whose items are the accumulated results of a function
- (specified by the optional *func* argument) that takes two arguments.
- By default, returns accumulated sums with :func:`operator.add`.
-
- >>> list(accumulate([1, 2, 3, 4, 5])) # Running sum
- [1, 3, 6, 10, 15]
- >>> list(accumulate([1, 2, 3], func=operator.mul)) # Running product
- [1, 2, 6]
- >>> list(accumulate([0, 1, -1, 2, 3, 2], func=max)) # Running maximum
- [0, 1, 1, 2, 3, 3]
-
- This function is available in the ``itertools`` module for Python 3.2 and
- greater.
-
- """
- it = iter(iterable)
- try:
- total = next(it)
- except StopIteration:
- return
- else:
- yield total
-
- for element in it:
- total = func(total, element)
- yield total
-
-
-def take(n, iterable):
- """Return first *n* items of the iterable as a list.
-
- >>> take(3, range(10))
- [0, 1, 2]
- >>> take(5, range(3))
- [0, 1, 2]
-
- Effectively a short replacement for ``next`` based iterator consumption
- when you want more than one item, but less than the whole iterator.
-
- """
- return list(islice(iterable, n))
-
-
-def tabulate(function, start=0):
- """Return an iterator over the results of ``func(start)``,
- ``func(start + 1)``, ``func(start + 2)``...
-
- *func* should be a function that accepts one integer argument.
-
- If *start* is not specified it defaults to 0. It will be incremented each
- time the iterator is advanced.
-
- >>> square = lambda x: x ** 2
- >>> iterator = tabulate(square, -3)
- >>> take(4, iterator)
- [9, 4, 1, 0]
-
- """
- return map(function, count(start))
-
-
-def tail(n, iterable):
- """Return an iterator over the last *n* items of *iterable*.
-
- >>> t = tail(3, 'ABCDEFG')
- >>> list(t)
- ['E', 'F', 'G']
-
- """
- return iter(deque(iterable, maxlen=n))
-
-
-def consume(iterator, n=None):
- """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
- entirely.
-
- Efficiently exhausts an iterator without returning values. Defaults to
- consuming the whole iterator, but an optional second argument may be
- provided to limit consumption.
-
- >>> i = (x for x in range(10))
- >>> next(i)
- 0
- >>> consume(i, 3)
- >>> next(i)
- 4
- >>> consume(i)
- >>> next(i)
- Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- StopIteration
-
- If the iterator has fewer items remaining than the provided limit, the
- whole iterator will be consumed.
-
- >>> i = (x for x in range(3))
- >>> consume(i, 5)
- >>> next(i)
- Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- StopIteration
-
- """
- # Use functions that consume iterators at C speed.
- if n is None:
- # feed the entire iterator into a zero-length deque
- deque(iterator, maxlen=0)
- else:
- # advance to the empty slice starting at position n
- next(islice(iterator, n, n), None)
-
-
-def nth(iterable, n, default=None):
- """Returns the nth item or a default value.
-
- >>> l = range(10)
- >>> nth(l, 3)
- 3
- >>> nth(l, 20, "zebra")
- 'zebra'
-
- """
- return next(islice(iterable, n, None), default)
-
-
-def all_equal(iterable):
- """
- Returns ``True`` if all the elements are equal to each other.
-
- >>> all_equal('aaaa')
- True
- >>> all_equal('aaab')
- False
-
- """
- g = groupby(iterable)
- return next(g, True) and not next(g, False)
-
-
-def quantify(iterable, pred=bool):
- """Return the how many times the predicate is true.
-
- >>> quantify([True, False, True])
- 2
-
- """
- return sum(map(pred, iterable))
-
-
-def padnone(iterable):
- """Returns the sequence of elements and then returns ``None`` indefinitely.
-
- >>> take(5, padnone(range(3)))
- [0, 1, 2, None, None]
-
- Useful for emulating the behavior of the built-in :func:`map` function.
-
- See also :func:`padded`.
-
- """
- return chain(iterable, repeat(None))
-
-
-def ncycles(iterable, n):
- """Returns the sequence elements *n* times
-
- >>> list(ncycles(["a", "b"], 3))
- ['a', 'b', 'a', 'b', 'a', 'b']
-
- """
- return chain.from_iterable(repeat(tuple(iterable), n))
-
-
-def dotproduct(vec1, vec2):
- """Returns the dot product of the two iterables.
-
- >>> dotproduct([10, 10], [20, 20])
- 400
-
- """
- return sum(map(operator.mul, vec1, vec2))
-
-
-def flatten(listOfLists):
- """Return an iterator flattening one level of nesting in a list of lists.
-
- >>> list(flatten([[0, 1], [2, 3]]))
- [0, 1, 2, 3]
-
- See also :func:`collapse`, which can flatten multiple levels of nesting.
-
- """
- return chain.from_iterable(listOfLists)
-
-
-def repeatfunc(func, times=None, *args):
- """Call *func* with *args* repeatedly, returning an iterable over the
- results.
-
- If *times* is specified, the iterable will terminate after that many
- repetitions:
-
- >>> from operator import add
- >>> times = 4
- >>> args = 3, 5
- >>> list(repeatfunc(add, times, *args))
- [8, 8, 8, 8]
-
- If *times* is ``None`` the iterable will not terminate:
-
- >>> from random import randrange
- >>> times = None
- >>> args = 1, 11
- >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
- [2, 4, 8, 1, 8, 4]
-
- """
- if times is None:
- return starmap(func, repeat(args))
- return starmap(func, repeat(args, times))
-
-
-def pairwise(iterable):
- """Returns an iterator of paired items, overlapping, from the original
-
- >>> take(4, pairwise(count()))
- [(0, 1), (1, 2), (2, 3), (3, 4)]
-
- """
- a, b = tee(iterable)
- next(b, None)
- return zip(a, b)
-
-
-def grouper(n, iterable, fillvalue=None):
- """Collect data into fixed-length chunks or blocks.
-
- >>> list(grouper(3, 'ABCDEFG', 'x'))
- [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
-
- """
- args = [iter(iterable)] * n
- return zip_longest(fillvalue=fillvalue, *args)
-
-
-def roundrobin(*iterables):
- """Yields an item from each iterable, alternating between them.
-
- >>> list(roundrobin('ABC', 'D', 'EF'))
- ['A', 'D', 'E', 'B', 'F', 'C']
-
- This function produces the same output as :func:`interleave_longest`, but
- may perform better for some inputs (in particular when the number of
- iterables is small).
-
- """
- # Recipe credited to George Sakkis
- pending = len(iterables)
- if PY2:
- nexts = cycle(iter(it).next for it in iterables)
- else:
- nexts = cycle(iter(it).__next__ for it in iterables)
- while pending:
- try:
- for next in nexts:
- yield next()
- except StopIteration:
- pending -= 1
- nexts = cycle(islice(nexts, pending))
-
-
-def partition(pred, iterable):
- """
- Returns a 2-tuple of iterables derived from the input iterable.
- The first yields the items that have ``pred(item) == False``.
- The second yields the items that have ``pred(item) == True``.
-
- >>> is_odd = lambda x: x % 2 != 0
- >>> iterable = range(10)
- >>> even_items, odd_items = partition(is_odd, iterable)
- >>> list(even_items), list(odd_items)
- ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
-
- """
- # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
- t1, t2 = tee(iterable)
- return filterfalse(pred, t1), filter(pred, t2)
-
-
-def powerset(iterable):
- """Yields all possible subsets of the iterable.
-
- >>> list(powerset([1, 2, 3]))
- [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
-
- :func:`powerset` will operate on iterables that aren't :class:`set`
- instances, so repeated elements in the input will produce repeated elements
- in the output. Use :func:`unique_everseen` on the input to avoid generating
- duplicates:
-
- >>> seq = [1, 1, 0]
- >>> list(powerset(seq))
- [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
- >>> from more_itertools import unique_everseen
- >>> list(powerset(unique_everseen(seq)))
- [(), (1,), (0,), (1, 0)]
-
- """
- s = list(iterable)
- return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
-
-
-def unique_everseen(iterable, key=None):
- """
- Yield unique elements, preserving order.
-
- >>> list(unique_everseen('AAAABBBCCDAABBB'))
- ['A', 'B', 'C', 'D']
- >>> list(unique_everseen('ABBCcAD', str.lower))
- ['A', 'B', 'C', 'D']
-
- Sequences with a mix of hashable and unhashable items can be used.
- The function will be slower (i.e., `O(n^2)`) for unhashable items.
-
- """
- seenset = set()
- seenset_add = seenset.add
- seenlist = []
- seenlist_add = seenlist.append
- if key is None:
- for element in iterable:
- try:
- if element not in seenset:
- seenset_add(element)
- yield element
- except TypeError:
- if element not in seenlist:
- seenlist_add(element)
- yield element
- else:
- for element in iterable:
- k = key(element)
- try:
- if k not in seenset:
- seenset_add(k)
- yield element
- except TypeError:
- if k not in seenlist:
- seenlist_add(k)
- yield element
-
-
-def unique_justseen(iterable, key=None):
- """Yields elements in order, ignoring serial duplicates
-
- >>> list(unique_justseen('AAAABBBCCDAABBB'))
- ['A', 'B', 'C', 'D', 'A', 'B']
- >>> list(unique_justseen('ABBCcAD', str.lower))
- ['A', 'B', 'C', 'A', 'D']
-
- """
- return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
-
-
-def iter_except(func, exception, first=None):
- """Yields results from a function repeatedly until an exception is raised.
-
- Converts a call-until-exception interface to an iterator interface.
- Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
- to end the loop.
-
- >>> l = [0, 1, 2]
- >>> list(iter_except(l.pop, IndexError))
- [2, 1, 0]
-
- """
- try:
- if first is not None:
- yield first()
- while 1:
- yield func()
- except exception:
- pass
-
-
-def first_true(iterable, default=None, pred=None):
- """
- Returns the first true value in the iterable.
-
- If no true value is found, returns *default*
-
- If *pred* is not None, returns the first item for which
- ``pred(item) == True`` .
-
- >>> first_true(range(10))
- 1
- >>> first_true(range(10), pred=lambda x: x > 5)
- 6
- >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
- 'missing'
-
- """
- return next(filter(pred, iterable), default)
-
-
-def random_product(*args, **kwds):
- """Draw an item at random from each of the input iterables.
-
- >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
- ('c', 3, 'Z')
-
- If *repeat* is provided as a keyword argument, that many items will be
- drawn from each iterable.
-
- >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
- ('a', 2, 'd', 3)
-
- This equivalent to taking a random selection from
- ``itertools.product(*args, **kwarg)``.
-
- """
- pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1)
- return tuple(choice(pool) for pool in pools)
-
-
-def random_permutation(iterable, r=None):
- """Return a random *r* length permutation of the elements in *iterable*.
-
- If *r* is not specified or is ``None``, then *r* defaults to the length of
- *iterable*.
-
- >>> random_permutation(range(5)) # doctest:+SKIP
- (3, 4, 0, 1, 2)
-
- This equivalent to taking a random selection from
- ``itertools.permutations(iterable, r)``.
-
- """
- pool = tuple(iterable)
- r = len(pool) if r is None else r
- return tuple(sample(pool, r))
-
-
-def random_combination(iterable, r):
- """Return a random *r* length subsequence of the elements in *iterable*.
-
- >>> random_combination(range(5), 3) # doctest:+SKIP
- (2, 3, 4)
-
- This equivalent to taking a random selection from
- ``itertools.combinations(iterable, r)``.
-
- """
- pool = tuple(iterable)
- n = len(pool)
- indices = sorted(sample(range(n), r))
- return tuple(pool[i] for i in indices)
-
-
-def random_combination_with_replacement(iterable, r):
- """Return a random *r* length subsequence of elements in *iterable*,
- allowing individual elements to be repeated.
-
- >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
- (0, 0, 1, 2, 2)
-
- This equivalent to taking a random selection from
- ``itertools.combinations_with_replacement(iterable, r)``.
-
- """
- pool = tuple(iterable)
- n = len(pool)
- indices = sorted(randrange(n) for i in range(r))
- return tuple(pool[i] for i in indices)
-
-
-def nth_combination(iterable, r, index):
- """Equivalent to ``list(combinations(iterable, r))[index]``.
-
- The subsequences of *iterable* that are of length *r* can be ordered
- lexicographically. :func:`nth_combination` computes the subsequence at
- sort position *index* directly, without computing the previous
- subsequences.
-
- """
- pool = tuple(iterable)
- n = len(pool)
- if (r < 0) or (r > n):
- raise ValueError
-
- c = 1
- k = min(r, n - r)
- for i in range(1, k + 1):
- c = c * (n - k + i) // i
-
- if index < 0:
- index += c
-
- if (index < 0) or (index >= c):
- raise IndexError
-
- result = []
- while r:
- c, n, r = c * r // n, n - 1, r - 1
- while index >= c:
- index -= c
- c, n = c * (n - r) // n, n - 1
- result.append(pool[-1 - n])
-
- return tuple(result)
-
-
-def prepend(value, iterator):
- """Yield *value*, followed by the elements in *iterator*.
-
- >>> value = '0'
- >>> iterator = ['1', '2', '3']
- >>> list(prepend(value, iterator))
- ['0', '1', '2', '3']
-
- To prepend multiple values, see :func:`itertools.chain`.
-
- """
- return chain([value], iterator)
diff --git a/contrib/python/more-itertools/py2/more_itertools/tests/test_more.py b/contrib/python/more-itertools/py2/more_itertools/tests/test_more.py
deleted file mode 100644
index 5f7e13df41..0000000000
--- a/contrib/python/more-itertools/py2/more_itertools/tests/test_more.py
+++ /dev/null
@@ -1,2313 +0,0 @@
-from __future__ import division, print_function, unicode_literals
-
-from collections import OrderedDict
-from decimal import Decimal
-from doctest import DocTestSuite
-from fractions import Fraction
-from functools import partial, reduce
-from heapq import merge
-from io import StringIO
-from itertools import (
- chain,
- count,
- groupby,
- islice,
- permutations,
- product,
- repeat,
-)
-from operator import add, mul, itemgetter
-from unittest import TestCase
-
-from six.moves import filter, map, range, zip
-
-import more_itertools as mi
-
-
-def load_tests(loader, tests, ignore):
- # Add the doctests
- tests.addTests(DocTestSuite('more_itertools.more'))
- return tests
-
-
-class CollateTests(TestCase):
- """Unit tests for ``collate()``"""
- # Also accidentally tests peekable, though that could use its own tests
-
- def test_default(self):
- """Test with the default `key` function."""
- iterables = [range(4), range(7), range(3, 6)]
- self.assertEqual(
- sorted(reduce(list.__add__, [list(it) for it in iterables])),
- list(mi.collate(*iterables))
- )
-
- def test_key(self):
- """Test using a custom `key` function."""
- iterables = [range(5, 0, -1), range(4, 0, -1)]
- actual = sorted(
- reduce(list.__add__, [list(it) for it in iterables]), reverse=True
- )
- expected = list(mi.collate(*iterables, key=lambda x: -x))
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- """Be nice if passed an empty list of iterables."""
- self.assertEqual([], list(mi.collate()))
-
- def test_one(self):
- """Work when only 1 iterable is passed."""
- self.assertEqual([0, 1], list(mi.collate(range(2))))
-
- def test_reverse(self):
- """Test the `reverse` kwarg."""
- iterables = [range(4, 0, -1), range(7, 0, -1), range(3, 6, -1)]
-
- actual = sorted(
- reduce(list.__add__, [list(it) for it in iterables]), reverse=True
- )
- expected = list(mi.collate(*iterables, reverse=True))
- self.assertEqual(actual, expected)
-
- def test_alias(self):
- self.assertNotEqual(merge.__doc__, mi.collate.__doc__)
- self.assertNotEqual(partial.__doc__, mi.collate.__doc__)
-
-
-class ChunkedTests(TestCase):
- """Tests for ``chunked()``"""
-
- def test_even(self):
- """Test when ``n`` divides evenly into the length of the iterable."""
- self.assertEqual(
- list(mi.chunked('ABCDEF', 3)), [['A', 'B', 'C'], ['D', 'E', 'F']]
- )
-
- def test_odd(self):
- """Test when ``n`` does not divide evenly into the length of the
- iterable.
-
- """
- self.assertEqual(
- list(mi.chunked('ABCDE', 3)), [['A', 'B', 'C'], ['D', 'E']]
- )
-
-
-class FirstTests(TestCase):
- """Tests for ``first()``"""
-
- def test_many(self):
- """Test that it works on many-item iterables."""
- # Also try it on a generator expression to make sure it works on
- # whatever those return, across Python versions.
- self.assertEqual(mi.first(x for x in range(4)), 0)
-
- def test_one(self):
- """Test that it doesn't raise StopIteration prematurely."""
- self.assertEqual(mi.first([3]), 3)
-
- def test_empty_stop_iteration(self):
- """It should raise StopIteration for empty iterables."""
- self.assertRaises(ValueError, lambda: mi.first([]))
-
- def test_default(self):
- """It should return the provided default arg for empty iterables."""
- self.assertEqual(mi.first([], 'boo'), 'boo')
-
-
-class IterOnlyRange:
- """User-defined iterable class which only support __iter__.
-
- It is not specified to inherit ``object``, so indexing on a instance will
- raise an ``AttributeError`` rather than ``TypeError`` in Python 2.
-
- >>> r = IterOnlyRange(5)
- >>> r[0] # doctest: +SKIP
- AttributeError: IterOnlyRange instance has no attribute '__getitem__'
-
- Note: In Python 3, ``TypeError`` will be raised because ``object`` is
- inherited implicitly by default.
-
- >>> r[0] # doctest: +SKIP
- TypeError: 'IterOnlyRange' object does not support indexing
- """
- def __init__(self, n):
- """Set the length of the range."""
- self.n = n
-
- def __iter__(self):
- """Works same as range()."""
- return iter(range(self.n))
-
-
-class LastTests(TestCase):
- """Tests for ``last()``"""
-
- def test_many_nonsliceable(self):
- """Test that it works on many-item non-slice-able iterables."""
- # Also try it on a generator expression to make sure it works on
- # whatever those return, across Python versions.
- self.assertEqual(mi.last(x for x in range(4)), 3)
-
- def test_one_nonsliceable(self):
- """Test that it doesn't raise StopIteration prematurely."""
- self.assertEqual(mi.last(x for x in range(1)), 0)
-
- def test_empty_stop_iteration_nonsliceable(self):
- """It should raise ValueError for empty non-slice-able iterables."""
- self.assertRaises(ValueError, lambda: mi.last(x for x in range(0)))
-
- def test_default_nonsliceable(self):
- """It should return the provided default arg for empty non-slice-able
- iterables.
- """
- self.assertEqual(mi.last((x for x in range(0)), 'boo'), 'boo')
-
- def test_many_sliceable(self):
- """Test that it works on many-item slice-able iterables."""
- self.assertEqual(mi.last([0, 1, 2, 3]), 3)
-
- def test_one_sliceable(self):
- """Test that it doesn't raise StopIteration prematurely."""
- self.assertEqual(mi.last([3]), 3)
-
- def test_empty_stop_iteration_sliceable(self):
- """It should raise ValueError for empty slice-able iterables."""
- self.assertRaises(ValueError, lambda: mi.last([]))
-
- def test_default_sliceable(self):
- """It should return the provided default arg for empty slice-able
- iterables.
- """
- self.assertEqual(mi.last([], 'boo'), 'boo')
-
- def test_dict(self):
- """last(dic) and last(dic.keys()) should return same result."""
- dic = {'a': 1, 'b': 2, 'c': 3}
- self.assertEqual(mi.last(dic), mi.last(dic.keys()))
-
- def test_ordereddict(self):
- """last(dic) should return the last key."""
- od = OrderedDict()
- od['a'] = 1
- od['b'] = 2
- od['c'] = 3
- self.assertEqual(mi.last(od), 'c')
-
- def test_customrange(self):
- """It should work on custom class where [] raises AttributeError."""
- self.assertEqual(mi.last(IterOnlyRange(5)), 4)
-
-
-class PeekableTests(TestCase):
- """Tests for ``peekable()`` behavor not incidentally covered by testing
- ``collate()``
-
- """
- def test_peek_default(self):
- """Make sure passing a default into ``peek()`` works."""
- p = mi.peekable([])
- self.assertEqual(p.peek(7), 7)
-
- def test_truthiness(self):
- """Make sure a ``peekable`` tests true iff there are items remaining in
- the iterable.
-
- """
- p = mi.peekable([])
- self.assertFalse(p)
-
- p = mi.peekable(range(3))
- self.assertTrue(p)
-
- def test_simple_peeking(self):
- """Make sure ``next`` and ``peek`` advance and don't advance the
- iterator, respectively.
-
- """
- p = mi.peekable(range(10))
- self.assertEqual(next(p), 0)
- self.assertEqual(p.peek(), 1)
- self.assertEqual(next(p), 1)
-
- def test_indexing(self):
- """
- Indexing into the peekable shouldn't advance the iterator.
- """
- p = mi.peekable('abcdefghijkl')
-
- # The 0th index is what ``next()`` will return
- self.assertEqual(p[0], 'a')
- self.assertEqual(next(p), 'a')
-
- # Indexing further into the peekable shouldn't advance the itertor
- self.assertEqual(p[2], 'd')
- self.assertEqual(next(p), 'b')
-
- # The 0th index moves up with the iterator; the last index follows
- self.assertEqual(p[0], 'c')
- self.assertEqual(p[9], 'l')
-
- self.assertEqual(next(p), 'c')
- self.assertEqual(p[8], 'l')
-
- # Negative indexing should work too
- self.assertEqual(p[-2], 'k')
- self.assertEqual(p[-9], 'd')
- self.assertRaises(IndexError, lambda: p[-10])
-
- def test_slicing(self):
- """Slicing the peekable shouldn't advance the iterator."""
- seq = list('abcdefghijkl')
- p = mi.peekable(seq)
-
- # Slicing the peekable should just be like slicing a re-iterable
- self.assertEqual(p[1:4], seq[1:4])
-
- # Advancing the iterator moves the slices up also
- self.assertEqual(next(p), 'a')
- self.assertEqual(p[1:4], seq[1:][1:4])
-
- # Implicit starts and stop should work
- self.assertEqual(p[:5], seq[1:][:5])
- self.assertEqual(p[:], seq[1:][:])
-
- # Indexing past the end should work
- self.assertEqual(p[:100], seq[1:][:100])
-
- # Steps should work, including negative
- self.assertEqual(p[::2], seq[1:][::2])
- self.assertEqual(p[::-1], seq[1:][::-1])
-
- def test_slicing_reset(self):
- """Test slicing on a fresh iterable each time"""
- iterable = ['0', '1', '2', '3', '4', '5']
- indexes = list(range(-4, len(iterable) + 4)) + [None]
- steps = [1, 2, 3, 4, -1, -2, -3, 4]
- for slice_args in product(indexes, indexes, steps):
- it = iter(iterable)
- p = mi.peekable(it)
- next(p)
- index = slice(*slice_args)
- actual = p[index]
- expected = iterable[1:][index]
- self.assertEqual(actual, expected, slice_args)
-
- def test_slicing_error(self):
- iterable = '01234567'
- p = mi.peekable(iter(iterable))
-
- # Prime the cache
- p.peek()
- old_cache = list(p._cache)
-
- # Illegal slice
- with self.assertRaises(ValueError):
- p[1:-1:0]
-
- # Neither the cache nor the iteration should be affected
- self.assertEqual(old_cache, list(p._cache))
- self.assertEqual(list(p), list(iterable))
-
- def test_passthrough(self):
- """Iterating a peekable without using ``peek()`` or ``prepend()``
- should just give the underlying iterable's elements (a trivial test but
- useful to set a baseline in case something goes wrong)"""
- expected = [1, 2, 3, 4, 5]
- actual = list(mi.peekable(expected))
- self.assertEqual(actual, expected)
-
- # prepend() behavior tests
-
- def test_prepend(self):
- """Tests intersperesed ``prepend()`` and ``next()`` calls"""
- it = mi.peekable(range(2))
- actual = []
-
- # Test prepend() before next()
- it.prepend(10)
- actual += [next(it), next(it)]
-
- # Test prepend() between next()s
- it.prepend(11)
- actual += [next(it), next(it)]
-
- # Test prepend() after source iterable is consumed
- it.prepend(12)
- actual += [next(it)]
-
- expected = [10, 0, 11, 1, 12]
- self.assertEqual(actual, expected)
-
- def test_multi_prepend(self):
- """Tests prepending multiple items and getting them in proper order"""
- it = mi.peekable(range(5))
- actual = [next(it), next(it)]
- it.prepend(10, 11, 12)
- it.prepend(20, 21)
- actual += list(it)
- expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4]
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- """Tests prepending in front of an empty iterable"""
- it = mi.peekable([])
- it.prepend(10)
- actual = list(it)
- expected = [10]
- self.assertEqual(actual, expected)
-
- def test_prepend_truthiness(self):
- """Tests that ``__bool__()`` or ``__nonzero__()`` works properly
- with ``prepend()``"""
- it = mi.peekable(range(5))
- self.assertTrue(it)
- actual = list(it)
- self.assertFalse(it)
- it.prepend(10)
- self.assertTrue(it)
- actual += [next(it)]
- self.assertFalse(it)
- expected = [0, 1, 2, 3, 4, 10]
- self.assertEqual(actual, expected)
-
- def test_multi_prepend_peek(self):
- """Tests prepending multiple elements and getting them in reverse order
- while peeking"""
- it = mi.peekable(range(5))
- actual = [next(it), next(it)]
- self.assertEqual(it.peek(), 2)
- it.prepend(10, 11, 12)
- self.assertEqual(it.peek(), 10)
- it.prepend(20, 21)
- self.assertEqual(it.peek(), 20)
- actual += list(it)
- self.assertFalse(it)
- expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4]
- self.assertEqual(actual, expected)
-
- def test_prepend_after_stop(self):
- """Test resuming iteration after a previous exhaustion"""
- it = mi.peekable(range(3))
- self.assertEqual(list(it), [0, 1, 2])
- self.assertRaises(StopIteration, lambda: next(it))
- it.prepend(10)
- self.assertEqual(next(it), 10)
- self.assertRaises(StopIteration, lambda: next(it))
-
- def test_prepend_slicing(self):
- """Tests interaction between prepending and slicing"""
- seq = list(range(20))
- p = mi.peekable(seq)
-
- p.prepend(30, 40, 50)
- pseq = [30, 40, 50] + seq # pseq for prepended_seq
-
- # adapt the specific tests from test_slicing
- self.assertEqual(p[0], 30)
- self.assertEqual(p[1:8], pseq[1:8])
- self.assertEqual(p[1:], pseq[1:])
- self.assertEqual(p[:5], pseq[:5])
- self.assertEqual(p[:], pseq[:])
- self.assertEqual(p[:100], pseq[:100])
- self.assertEqual(p[::2], pseq[::2])
- self.assertEqual(p[::-1], pseq[::-1])
-
- def test_prepend_indexing(self):
- """Tests interaction between prepending and indexing"""
- seq = list(range(20))
- p = mi.peekable(seq)
-
- p.prepend(30, 40, 50)
-
- self.assertEqual(p[0], 30)
- self.assertEqual(next(p), 30)
- self.assertEqual(p[2], 0)
- self.assertEqual(next(p), 40)
- self.assertEqual(p[0], 50)
- self.assertEqual(p[9], 8)
- self.assertEqual(next(p), 50)
- self.assertEqual(p[8], 8)
- self.assertEqual(p[-2], 18)
- self.assertEqual(p[-9], 11)
- self.assertRaises(IndexError, lambda: p[-21])
-
- def test_prepend_iterable(self):
- """Tests prepending from an iterable"""
- it = mi.peekable(range(5))
- # Don't directly use the range() object to avoid any range-specific
- # optimizations
- it.prepend(*(x for x in range(5)))
- actual = list(it)
- expected = list(chain(range(5), range(5)))
- self.assertEqual(actual, expected)
-
- def test_prepend_many(self):
- """Tests that prepending a huge number of elements works"""
- it = mi.peekable(range(5))
- # Don't directly use the range() object to avoid any range-specific
- # optimizations
- it.prepend(*(x for x in range(20000)))
- actual = list(it)
- expected = list(chain(range(20000), range(5)))
- self.assertEqual(actual, expected)
-
- def test_prepend_reversed(self):
- """Tests prepending from a reversed iterable"""
- it = mi.peekable(range(3))
- it.prepend(*reversed((10, 11, 12)))
- actual = list(it)
- expected = [12, 11, 10, 0, 1, 2]
- self.assertEqual(actual, expected)
-
-
-class ConsumerTests(TestCase):
- """Tests for ``consumer()``"""
-
- def test_consumer(self):
- @mi.consumer
- def eater():
- while True:
- x = yield # noqa
-
- e = eater()
- e.send('hi') # without @consumer, would raise TypeError
-
-
-class DistinctPermutationsTests(TestCase):
- def test_distinct_permutations(self):
- """Make sure the output for ``distinct_permutations()`` is the same as
- set(permutations(it)).
-
- """
- iterable = ['z', 'a', 'a', 'q', 'q', 'q', 'y']
- test_output = sorted(mi.distinct_permutations(iterable))
- ref_output = sorted(set(permutations(iterable)))
- self.assertEqual(test_output, ref_output)
-
- def test_other_iterables(self):
- """Make sure ``distinct_permutations()`` accepts a different type of
- iterables.
-
- """
- # a generator
- iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y'])
- test_output = sorted(mi.distinct_permutations(iterable))
- # "reload" it
- iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y'])
- ref_output = sorted(set(permutations(iterable)))
- self.assertEqual(test_output, ref_output)
-
- # an iterator
- iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y'])
- test_output = sorted(mi.distinct_permutations(iterable))
- # "reload" it
- iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y'])
- ref_output = sorted(set(permutations(iterable)))
- self.assertEqual(test_output, ref_output)
-
-
-class IlenTests(TestCase):
- def test_ilen(self):
- """Sanity-checks for ``ilen()``."""
- # Non-empty
- self.assertEqual(
- mi.ilen(filter(lambda x: x % 10 == 0, range(101))), 11
- )
-
- # Empty
- self.assertEqual(mi.ilen((x for x in range(0))), 0)
-
- # Iterable with __len__
- self.assertEqual(mi.ilen(list(range(6))), 6)
-
-
-class WithIterTests(TestCase):
- def test_with_iter(self):
- s = StringIO('One fish\nTwo fish')
- initial_words = [line.split()[0] for line in mi.with_iter(s)]
-
- # Iterable's items should be faithfully represented
- self.assertEqual(initial_words, ['One', 'Two'])
- # The file object should be closed
- self.assertTrue(s.closed)
-
-
-class OneTests(TestCase):
- def test_basic(self):
- it = iter(['item'])
- self.assertEqual(mi.one(it), 'item')
-
- def test_too_short(self):
- it = iter([])
- self.assertRaises(ValueError, lambda: mi.one(it))
- self.assertRaises(IndexError, lambda: mi.one(it, too_short=IndexError))
-
- def test_too_long(self):
- it = count()
- self.assertRaises(ValueError, lambda: mi.one(it)) # burn 0 and 1
- self.assertEqual(next(it), 2)
- self.assertRaises(
- OverflowError, lambda: mi.one(it, too_long=OverflowError)
- )
-
-
-class IntersperseTest(TestCase):
- """ Tests for intersperse() """
-
- def test_even(self):
- iterable = (x for x in '01')
- self.assertEqual(
- list(mi.intersperse(None, iterable)), ['0', None, '1']
- )
-
- def test_odd(self):
- iterable = (x for x in '012')
- self.assertEqual(
- list(mi.intersperse(None, iterable)), ['0', None, '1', None, '2']
- )
-
- def test_nested(self):
- element = ('a', 'b')
- iterable = (x for x in '012')
- actual = list(mi.intersperse(element, iterable))
- expected = ['0', ('a', 'b'), '1', ('a', 'b'), '2']
- self.assertEqual(actual, expected)
-
- def test_not_iterable(self):
- self.assertRaises(TypeError, lambda: mi.intersperse('x', 1))
-
- def test_n(self):
- for n, element, expected in [
- (1, '_', ['0', '_', '1', '_', '2', '_', '3', '_', '4', '_', '5']),
- (2, '_', ['0', '1', '_', '2', '3', '_', '4', '5']),
- (3, '_', ['0', '1', '2', '_', '3', '4', '5']),
- (4, '_', ['0', '1', '2', '3', '_', '4', '5']),
- (5, '_', ['0', '1', '2', '3', '4', '_', '5']),
- (6, '_', ['0', '1', '2', '3', '4', '5']),
- (7, '_', ['0', '1', '2', '3', '4', '5']),
- (3, ['a', 'b'], ['0', '1', '2', ['a', 'b'], '3', '4', '5']),
- ]:
- iterable = (x for x in '012345')
- actual = list(mi.intersperse(element, iterable, n=n))
- self.assertEqual(actual, expected)
-
- def test_n_zero(self):
- self.assertRaises(
- ValueError, lambda: list(mi.intersperse('x', '012', n=0))
- )
-
-
-class UniqueToEachTests(TestCase):
- """Tests for ``unique_to_each()``"""
-
- def test_all_unique(self):
- """When all the input iterables are unique the output should match
- the input."""
- iterables = [[1, 2], [3, 4, 5], [6, 7, 8]]
- self.assertEqual(mi.unique_to_each(*iterables), iterables)
-
- def test_duplicates(self):
- """When there are duplicates in any of the input iterables that aren't
- in the rest, those duplicates should be emitted."""
- iterables = ["mississippi", "missouri"]
- self.assertEqual(
- mi.unique_to_each(*iterables), [['p', 'p'], ['o', 'u', 'r']]
- )
-
- def test_mixed(self):
- """When the input iterables contain different types the function should
- still behave properly"""
- iterables = ['x', (i for i in range(3)), [1, 2, 3], tuple()]
- self.assertEqual(mi.unique_to_each(*iterables), [['x'], [0], [3], []])
-
-
-class WindowedTests(TestCase):
- """Tests for ``windowed()``"""
-
- def test_basic(self):
- actual = list(mi.windowed([1, 2, 3, 4, 5], 3))
- expected = [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
- self.assertEqual(actual, expected)
-
- def test_large_size(self):
- """
- When the window size is larger than the iterable, and no fill value is
- given,``None`` should be filled in.
- """
- actual = list(mi.windowed([1, 2, 3, 4, 5], 6))
- expected = [(1, 2, 3, 4, 5, None)]
- self.assertEqual(actual, expected)
-
- def test_fillvalue(self):
- """
- When sizes don't match evenly, the given fill value should be used.
- """
- iterable = [1, 2, 3, 4, 5]
-
- for n, kwargs, expected in [
- (6, {}, [(1, 2, 3, 4, 5, '!')]), # n > len(iterable)
- (3, {'step': 3}, [(1, 2, 3), (4, 5, '!')]), # using ``step``
- ]:
- actual = list(mi.windowed(iterable, n, fillvalue='!', **kwargs))
- self.assertEqual(actual, expected)
-
- def test_zero(self):
- """When the window size is zero, an empty tuple should be emitted."""
- actual = list(mi.windowed([1, 2, 3, 4, 5], 0))
- expected = [tuple()]
- self.assertEqual(actual, expected)
-
- def test_negative(self):
- """When the window size is negative, ValueError should be raised."""
- with self.assertRaises(ValueError):
- list(mi.windowed([1, 2, 3, 4, 5], -1))
-
- def test_step(self):
- """The window should advance by the number of steps provided"""
- iterable = [1, 2, 3, 4, 5, 6, 7]
- for n, step, expected in [
- (3, 2, [(1, 2, 3), (3, 4, 5), (5, 6, 7)]), # n > step
- (3, 3, [(1, 2, 3), (4, 5, 6), (7, None, None)]), # n == step
- (3, 4, [(1, 2, 3), (5, 6, 7)]), # line up nicely
- (3, 5, [(1, 2, 3), (6, 7, None)]), # off by one
- (3, 6, [(1, 2, 3), (7, None, None)]), # off by two
- (3, 7, [(1, 2, 3)]), # step past the end
- (7, 8, [(1, 2, 3, 4, 5, 6, 7)]), # step > len(iterable)
- ]:
- actual = list(mi.windowed(iterable, n, step=step))
- self.assertEqual(actual, expected)
-
- # Step must be greater than or equal to 1
- with self.assertRaises(ValueError):
- list(mi.windowed(iterable, 3, step=0))
-
-
-class SubstringsTests(TestCase):
- def test_basic(self):
- iterable = (x for x in range(4))
- actual = list(mi.substrings(iterable))
- expected = [
- (0,),
- (1,),
- (2,),
- (3,),
- (0, 1),
- (1, 2),
- (2, 3),
- (0, 1, 2),
- (1, 2, 3),
- (0, 1, 2, 3),
- ]
- self.assertEqual(actual, expected)
-
- def test_strings(self):
- iterable = 'abc'
- actual = list(mi.substrings(iterable))
- expected = [
- ('a',), ('b',), ('c',), ('a', 'b'), ('b', 'c'), ('a', 'b', 'c')
- ]
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- iterable = iter([])
- actual = list(mi.substrings(iterable))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_order(self):
- iterable = [2, 0, 1]
- actual = list(mi.substrings(iterable))
- expected = [(2,), (0,), (1,), (2, 0), (0, 1), (2, 0, 1)]
- self.assertEqual(actual, expected)
-
-
-class BucketTests(TestCase):
- """Tests for ``bucket()``"""
-
- def test_basic(self):
- iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33]
- D = mi.bucket(iterable, key=lambda x: 10 * (x // 10))
-
- # In-order access
- self.assertEqual(list(D[10]), [10, 11, 12])
-
- # Out of order access
- self.assertEqual(list(D[30]), [30, 31, 33])
- self.assertEqual(list(D[20]), [20, 21, 22, 23])
-
- self.assertEqual(list(D[40]), []) # Nothing in here!
-
- def test_in(self):
- iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33]
- D = mi.bucket(iterable, key=lambda x: 10 * (x // 10))
-
- self.assertIn(10, D)
- self.assertNotIn(40, D)
- self.assertIn(20, D)
- self.assertNotIn(21, D)
-
- # Checking in-ness shouldn't advance the iterator
- self.assertEqual(next(D[10]), 10)
-
- def test_validator(self):
- iterable = count(0)
- key = lambda x: int(str(x)[0]) # First digit of each number
- validator = lambda x: 0 < x < 10 # No leading zeros
- D = mi.bucket(iterable, key, validator=validator)
- self.assertEqual(mi.take(3, D[1]), [1, 10, 11])
- self.assertNotIn(0, D) # Non-valid entries don't return True
- self.assertNotIn(0, D._cache) # Don't store non-valid entries
- self.assertEqual(list(D[0]), [])
-
-
-class SpyTests(TestCase):
- """Tests for ``spy()``"""
-
- def test_basic(self):
- original_iterable = iter('abcdefg')
- head, new_iterable = mi.spy(original_iterable)
- self.assertEqual(head, ['a'])
- self.assertEqual(
- list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g']
- )
-
- def test_unpacking(self):
- original_iterable = iter('abcdefg')
- (first, second, third), new_iterable = mi.spy(original_iterable, 3)
- self.assertEqual(first, 'a')
- self.assertEqual(second, 'b')
- self.assertEqual(third, 'c')
- self.assertEqual(
- list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g']
- )
-
- def test_too_many(self):
- original_iterable = iter('abc')
- head, new_iterable = mi.spy(original_iterable, 4)
- self.assertEqual(head, ['a', 'b', 'c'])
- self.assertEqual(list(new_iterable), ['a', 'b', 'c'])
-
- def test_zero(self):
- original_iterable = iter('abc')
- head, new_iterable = mi.spy(original_iterable, 0)
- self.assertEqual(head, [])
- self.assertEqual(list(new_iterable), ['a', 'b', 'c'])
-
-
-class InterleaveTests(TestCase):
- def test_even(self):
- actual = list(mi.interleave([1, 4, 7], [2, 5, 8], [3, 6, 9]))
- expected = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_short(self):
- actual = list(mi.interleave([1, 4], [2, 5, 7], [3, 6, 8]))
- expected = [1, 2, 3, 4, 5, 6]
- self.assertEqual(actual, expected)
-
- def test_mixed_types(self):
- it_list = ['a', 'b', 'c', 'd']
- it_str = '12345'
- it_inf = count()
- actual = list(mi.interleave(it_list, it_str, it_inf))
- expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', 3]
- self.assertEqual(actual, expected)
-
-
-class InterleaveLongestTests(TestCase):
- def test_even(self):
- actual = list(mi.interleave_longest([1, 4, 7], [2, 5, 8], [3, 6, 9]))
- expected = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_short(self):
- actual = list(mi.interleave_longest([1, 4], [2, 5, 7], [3, 6, 8]))
- expected = [1, 2, 3, 4, 5, 6, 7, 8]
- self.assertEqual(actual, expected)
-
- def test_mixed_types(self):
- it_list = ['a', 'b', 'c', 'd']
- it_str = '12345'
- it_gen = (x for x in range(3))
- actual = list(mi.interleave_longest(it_list, it_str, it_gen))
- expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', '5']
- self.assertEqual(actual, expected)
-
-
-class TestCollapse(TestCase):
- """Tests for ``collapse()``"""
-
- def test_collapse(self):
- l = [[1], 2, [[3], 4], [[[5]]]]
- self.assertEqual(list(mi.collapse(l)), [1, 2, 3, 4, 5])
-
- def test_collapse_to_string(self):
- l = [["s1"], "s2", [["s3"], "s4"], [[["s5"]]]]
- self.assertEqual(list(mi.collapse(l)), ["s1", "s2", "s3", "s4", "s5"])
-
- def test_collapse_flatten(self):
- l = [[1], [2], [[3], 4], [[[5]]]]
- self.assertEqual(list(mi.collapse(l, levels=1)), list(mi.flatten(l)))
-
- def test_collapse_to_level(self):
- l = [[1], 2, [[3], 4], [[[5]]]]
- self.assertEqual(list(mi.collapse(l, levels=2)), [1, 2, 3, 4, [5]])
- self.assertEqual(
- list(mi.collapse(mi.collapse(l, levels=1), levels=1)),
- list(mi.collapse(l, levels=2))
- )
-
- def test_collapse_to_list(self):
- l = (1, [2], (3, [4, (5,)], 'ab'))
- actual = list(mi.collapse(l, base_type=list))
- expected = [1, [2], 3, [4, (5,)], 'ab']
- self.assertEqual(actual, expected)
-
-
-class SideEffectTests(TestCase):
- """Tests for ``side_effect()``"""
-
- def test_individual(self):
- # The function increments the counter for each call
- counter = [0]
-
- def func(arg):
- counter[0] += 1
-
- result = list(mi.side_effect(func, range(10)))
- self.assertEqual(result, list(range(10)))
- self.assertEqual(counter[0], 10)
-
- def test_chunked(self):
- # The function increments the counter for each call
- counter = [0]
-
- def func(arg):
- counter[0] += 1
-
- result = list(mi.side_effect(func, range(10), 2))
- self.assertEqual(result, list(range(10)))
- self.assertEqual(counter[0], 5)
-
- def test_before_after(self):
- f = StringIO()
- collector = []
-
- def func(item):
- print(item, file=f)
- collector.append(f.getvalue())
-
- def it():
- yield 'a'
- yield 'b'
- raise RuntimeError('kaboom')
-
- before = lambda: print('HEADER', file=f)
- after = f.close
-
- try:
- mi.consume(mi.side_effect(func, it(), before=before, after=after))
- except RuntimeError:
- pass
-
- # The iterable should have been written to the file
- self.assertEqual(collector, ['HEADER\na\n', 'HEADER\na\nb\n'])
-
- # The file should be closed even though something bad happened
- self.assertTrue(f.closed)
-
- def test_before_fails(self):
- f = StringIO()
- func = lambda x: print(x, file=f)
-
- def before():
- raise RuntimeError('ouch')
-
- try:
- mi.consume(
- mi.side_effect(func, 'abc', before=before, after=f.close)
- )
- except RuntimeError:
- pass
-
- # The file should be closed even though something bad happened in the
- # before function
- self.assertTrue(f.closed)
-
-
-class SlicedTests(TestCase):
- """Tests for ``sliced()``"""
-
- def test_even(self):
- """Test when the length of the sequence is divisible by *n*"""
- seq = 'ABCDEFGHI'
- self.assertEqual(list(mi.sliced(seq, 3)), ['ABC', 'DEF', 'GHI'])
-
- def test_odd(self):
- """Test when the length of the sequence is not divisible by *n*"""
- seq = 'ABCDEFGHI'
- self.assertEqual(list(mi.sliced(seq, 4)), ['ABCD', 'EFGH', 'I'])
-
- def test_not_sliceable(self):
- seq = (x for x in 'ABCDEFGHI')
-
- with self.assertRaises(TypeError):
- list(mi.sliced(seq, 3))
-
-
-class SplitAtTests(TestCase):
- """Tests for ``split()``"""
-
- def comp_with_str_split(self, str_to_split, delim):
- pred = lambda c: c == delim
- actual = list(map(''.join, mi.split_at(str_to_split, pred)))
- expected = str_to_split.split(delim)
- self.assertEqual(actual, expected)
-
- def test_seperators(self):
- test_strs = ['', 'abcba', 'aaabbbcccddd', 'e']
- for s, delim in product(test_strs, 'abcd'):
- self.comp_with_str_split(s, delim)
-
-
-class SplitBeforeTest(TestCase):
- """Tests for ``split_before()``"""
-
- def test_starts_with_sep(self):
- actual = list(mi.split_before('xooxoo', lambda c: c == 'x'))
- expected = [['x', 'o', 'o'], ['x', 'o', 'o']]
- self.assertEqual(actual, expected)
-
- def test_ends_with_sep(self):
- actual = list(mi.split_before('ooxoox', lambda c: c == 'x'))
- expected = [['o', 'o'], ['x', 'o', 'o'], ['x']]
- self.assertEqual(actual, expected)
-
- def test_no_sep(self):
- actual = list(mi.split_before('ooo', lambda c: c == 'x'))
- expected = [['o', 'o', 'o']]
- self.assertEqual(actual, expected)
-
-
-class SplitAfterTest(TestCase):
- """Tests for ``split_after()``"""
-
- def test_starts_with_sep(self):
- actual = list(mi.split_after('xooxoo', lambda c: c == 'x'))
- expected = [['x'], ['o', 'o', 'x'], ['o', 'o']]
- self.assertEqual(actual, expected)
-
- def test_ends_with_sep(self):
- actual = list(mi.split_after('ooxoox', lambda c: c == 'x'))
- expected = [['o', 'o', 'x'], ['o', 'o', 'x']]
- self.assertEqual(actual, expected)
-
- def test_no_sep(self):
- actual = list(mi.split_after('ooo', lambda c: c == 'x'))
- expected = [['o', 'o', 'o']]
- self.assertEqual(actual, expected)
-
-
-class SplitIntoTests(TestCase):
- """Tests for ``split_into()``"""
-
- def test_iterable_just_right(self):
- """Size of ``iterable`` equals the sum of ``sizes``."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, 4]
- expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_too_small(self):
- """Size of ``iterable`` is smaller than sum of ``sizes``. Last return
- list is shorter as a result."""
- iterable = [1, 2, 3, 4, 5, 6, 7]
- sizes = [2, 3, 4]
- expected = [[1, 2], [3, 4, 5], [6, 7]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_too_small_extra(self):
- """Size of ``iterable`` is smaller than sum of ``sizes``. Second last
- return list is shorter and last return list is empty as a result."""
- iterable = [1, 2, 3, 4, 5, 6, 7]
- sizes = [2, 3, 4, 5]
- expected = [[1, 2], [3, 4, 5], [6, 7], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_too_large(self):
- """Size of ``iterable`` is larger than sum of ``sizes``. Not all
- items of iterable are returned."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, 2]
- expected = [[1, 2], [3, 4, 5], [6, 7]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_using_none_with_leftover(self):
- """Last item of ``sizes`` is None when items still remain in
- ``iterable``. Last list returned stretches to fit all remaining items
- of ``iterable``."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, None]
- expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_using_none_without_leftover(self):
- """Last item of ``sizes`` is None when no items remain in
- ``iterable``. Last list returned is empty."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, 4, None]
- expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_using_none_mid_sizes(self):
- """None is present in ``sizes`` but is not the last item. Last list
- returned stretches to fit all remaining items of ``iterable`` but
- all items in ``sizes`` after None are ignored."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, None, 4]
- expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_empty(self):
- """``iterable`` argument is empty but ``sizes`` is not. An empty
- list is returned for each item in ``sizes``."""
- iterable = []
- sizes = [2, 4, 2]
- expected = [[], [], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_empty_using_none(self):
- """``iterable`` argument is empty but ``sizes`` is not. An empty
- list is returned for each item in ``sizes`` that is not after a
- None item."""
- iterable = []
- sizes = [2, 4, None, 2]
- expected = [[], [], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_sizes_empty(self):
- """``sizes`` argument is empty but ``iterable`` is not. An empty
- generator is returned."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = []
- expected = []
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_both_empty(self):
- """Both ``sizes`` and ``iterable`` arguments are empty. An empty
- generator is returned."""
- iterable = []
- sizes = []
- expected = []
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_bool_in_sizes(self):
- """A bool object is present in ``sizes`` is treated as a 1 or 0 for
- ``True`` or ``False`` due to bool being an instance of int."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [3, True, 2, False]
- expected = [[1, 2, 3], [4], [5, 6], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_invalid_in_sizes(self):
- """A ValueError is raised if an object in ``sizes`` is neither ``None``
- or an integer."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [1, [], 3]
- with self.assertRaises(ValueError):
- list(mi.split_into(iterable, sizes))
-
- def test_invalid_in_sizes_after_none(self):
- """A item in ``sizes`` that is invalid will not raise a TypeError if it
- comes after a ``None`` item."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [3, 4, None, []]
- expected = [[1, 2, 3], [4, 5, 6, 7], [8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_generator_iterable_integrity(self):
- """Check that if ``iterable`` is an iterator, it is consumed only by as
- many items as the sum of ``sizes``."""
- iterable = (i for i in range(10))
- sizes = [2, 3]
-
- expected = [[0, 1], [2, 3, 4]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- iterable_expected = [5, 6, 7, 8, 9]
- iterable_actual = list(iterable)
- self.assertEqual(iterable_actual, iterable_expected)
-
- def test_generator_sizes_integrity(self):
- """Check that if ``sizes`` is an iterator, it is consumed only until a
- ``None`` item is reached"""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = (i for i in [1, 2, None, 3, 4])
-
- expected = [[1], [2, 3], [4, 5, 6, 7, 8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- sizes_expected = [3, 4]
- sizes_actual = list(sizes)
- self.assertEqual(sizes_actual, sizes_expected)
-
-
-class PaddedTest(TestCase):
- """Tests for ``padded()``"""
-
- def test_no_n(self):
- seq = [1, 2, 3]
-
- # No fillvalue
- self.assertEqual(mi.take(5, mi.padded(seq)), [1, 2, 3, None, None])
-
- # With fillvalue
- self.assertEqual(
- mi.take(5, mi.padded(seq, fillvalue='')), [1, 2, 3, '', '']
- )
-
- def test_invalid_n(self):
- self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=-1)))
- self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=0)))
-
- def test_valid_n(self):
- seq = [1, 2, 3, 4, 5]
-
- # No need for padding: len(seq) <= n
- self.assertEqual(list(mi.padded(seq, n=4)), [1, 2, 3, 4, 5])
- self.assertEqual(list(mi.padded(seq, n=5)), [1, 2, 3, 4, 5])
-
- # No fillvalue
- self.assertEqual(
- list(mi.padded(seq, n=7)), [1, 2, 3, 4, 5, None, None]
- )
-
- # With fillvalue
- self.assertEqual(
- list(mi.padded(seq, fillvalue='', n=7)), [1, 2, 3, 4, 5, '', '']
- )
-
- def test_next_multiple(self):
- seq = [1, 2, 3, 4, 5, 6]
-
- # No need for padding: len(seq) % n == 0
- self.assertEqual(
- list(mi.padded(seq, n=3, next_multiple=True)), [1, 2, 3, 4, 5, 6]
- )
-
- # Padding needed: len(seq) < n
- self.assertEqual(
- list(mi.padded(seq, n=8, next_multiple=True)),
- [1, 2, 3, 4, 5, 6, None, None]
- )
-
- # No padding needed: len(seq) == n
- self.assertEqual(
- list(mi.padded(seq, n=6, next_multiple=True)), [1, 2, 3, 4, 5, 6]
- )
-
- # Padding needed: len(seq) > n
- self.assertEqual(
- list(mi.padded(seq, n=4, next_multiple=True)),
- [1, 2, 3, 4, 5, 6, None, None]
- )
-
- # With fillvalue
- self.assertEqual(
- list(mi.padded(seq, fillvalue='', n=4, next_multiple=True)),
- [1, 2, 3, 4, 5, 6, '', '']
- )
-
-
-class DistributeTest(TestCase):
- """Tests for distribute()"""
-
- def test_invalid_n(self):
- self.assertRaises(ValueError, lambda: mi.distribute(-1, [1, 2, 3]))
- self.assertRaises(ValueError, lambda: mi.distribute(0, [1, 2, 3]))
-
- def test_basic(self):
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-
- for n, expected in [
- (1, [iterable]),
- (2, [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]),
- (3, [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]]),
- (10, [[n] for n in range(1, 10 + 1)]),
- ]:
- self.assertEqual(
- [list(x) for x in mi.distribute(n, iterable)], expected
- )
-
- def test_large_n(self):
- iterable = [1, 2, 3, 4]
- self.assertEqual(
- [list(x) for x in mi.distribute(6, iterable)],
- [[1], [2], [3], [4], [], []]
- )
-
-
-class StaggerTest(TestCase):
- """Tests for ``stagger()``"""
-
- def test_default(self):
- iterable = [0, 1, 2, 3]
- actual = list(mi.stagger(iterable))
- expected = [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
- self.assertEqual(actual, expected)
-
- def test_offsets(self):
- iterable = [0, 1, 2, 3]
- for offsets, expected in [
- ((-2, 0, 2), [('', 0, 2), ('', 1, 3)]),
- ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3)]),
- ((1, 2), [(1, 2), (2, 3)]),
- ]:
- all_groups = mi.stagger(iterable, offsets=offsets, fillvalue='')
- self.assertEqual(list(all_groups), expected)
-
- def test_longest(self):
- iterable = [0, 1, 2, 3]
- for offsets, expected in [
- (
- (-1, 0, 1),
- [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, ''), (3, '', '')]
- ),
- ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3), (3, '')]),
- ((1, 2), [(1, 2), (2, 3), (3, '')]),
- ]:
- all_groups = mi.stagger(
- iterable, offsets=offsets, fillvalue='', longest=True
- )
- self.assertEqual(list(all_groups), expected)
-
-
-class ZipOffsetTest(TestCase):
- """Tests for ``zip_offset()``"""
-
- def test_shortest(self):
- a_1 = [0, 1, 2, 3]
- a_2 = [0, 1, 2, 3, 4, 5]
- a_3 = [0, 1, 2, 3, 4, 5, 6, 7]
- actual = list(
- mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), fillvalue='')
- )
- expected = [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)]
- self.assertEqual(actual, expected)
-
- def test_longest(self):
- a_1 = [0, 1, 2, 3]
- a_2 = [0, 1, 2, 3, 4, 5]
- a_3 = [0, 1, 2, 3, 4, 5, 6, 7]
- actual = list(
- mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), longest=True)
- )
- expected = [
- (None, 0, 1),
- (0, 1, 2),
- (1, 2, 3),
- (2, 3, 4),
- (3, 4, 5),
- (None, 5, 6),
- (None, None, 7),
- ]
- self.assertEqual(actual, expected)
-
- def test_mismatch(self):
- iterables = [0, 1, 2], [2, 3, 4]
- offsets = (-1, 0, 1)
- self.assertRaises(
- ValueError,
- lambda: list(mi.zip_offset(*iterables, offsets=offsets))
- )
-
-
-class UnzipTests(TestCase):
- """Tests for unzip()"""
-
- def test_empty_iterable(self):
- self.assertEqual(list(mi.unzip([])), [])
- # in reality zip([], [], []) is equivalent to iter([])
- # but it doesn't hurt to test both
- self.assertEqual(list(mi.unzip(zip([], [], []))), [])
-
- def test_length_one_iterable(self):
- xs, ys, zs = mi.unzip(zip([1], [2], [3]))
- self.assertEqual(list(xs), [1])
- self.assertEqual(list(ys), [2])
- self.assertEqual(list(zs), [3])
-
- def test_normal_case(self):
- xs, ys, zs = range(10), range(1, 11), range(2, 12)
- zipped = zip(xs, ys, zs)
- xs, ys, zs = mi.unzip(zipped)
- self.assertEqual(list(xs), list(range(10)))
- self.assertEqual(list(ys), list(range(1, 11)))
- self.assertEqual(list(zs), list(range(2, 12)))
-
- def test_improperly_zipped(self):
- zipped = iter([(1, 2, 3), (4, 5), (6,)])
- xs, ys, zs = mi.unzip(zipped)
- self.assertEqual(list(xs), [1, 4, 6])
- self.assertEqual(list(ys), [2, 5])
- self.assertEqual(list(zs), [3])
-
- def test_increasingly_zipped(self):
- zipped = iter([(1, 2), (3, 4, 5), (6, 7, 8, 9)])
- unzipped = mi.unzip(zipped)
- # from the docstring:
- # len(first tuple) is the number of iterables zipped
- self.assertEqual(len(unzipped), 2)
- xs, ys = unzipped
- self.assertEqual(list(xs), [1, 3, 6])
- self.assertEqual(list(ys), [2, 4, 7])
-
-
-class SortTogetherTest(TestCase):
- """Tests for sort_together()"""
-
- def test_key_list(self):
- """tests `key_list` including default, iterables include duplicates"""
- iterables = [
- ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'],
- ['May', 'Aug.', 'May', 'June', 'July', 'July'],
- [97, 20, 100, 70, 100, 20]
- ]
-
- self.assertEqual(
- mi.sort_together(iterables),
- [
- ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'),
- ('June', 'July', 'July', 'May', 'Aug.', 'May'),
- (70, 100, 20, 97, 20, 100)
- ]
- )
-
- self.assertEqual(
- mi.sort_together(iterables, key_list=(0, 1)),
- [
- ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'),
- ('July', 'July', 'June', 'Aug.', 'May', 'May'),
- (100, 20, 70, 20, 97, 100)
- ]
- )
-
- self.assertEqual(
- mi.sort_together(iterables, key_list=(0, 1, 2)),
- [
- ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'),
- ('July', 'July', 'June', 'Aug.', 'May', 'May'),
- (20, 100, 70, 20, 97, 100)
- ]
- )
-
- self.assertEqual(
- mi.sort_together(iterables, key_list=(2,)),
- [
- ('GA', 'CT', 'CT', 'GA', 'GA', 'CT'),
- ('Aug.', 'July', 'June', 'May', 'May', 'July'),
- (20, 20, 70, 97, 100, 100)
- ]
- )
-
- def test_invalid_key_list(self):
- """tests `key_list` for indexes not available in `iterables`"""
- iterables = [
- ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'],
- ['May', 'Aug.', 'May', 'June', 'July', 'July'],
- [97, 20, 100, 70, 100, 20]
- ]
-
- self.assertRaises(
- IndexError, lambda: mi.sort_together(iterables, key_list=(5,))
- )
-
- def test_reverse(self):
- """tests `reverse` to ensure a reverse sort for `key_list` iterables"""
- iterables = [
- ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'],
- ['May', 'Aug.', 'May', 'June', 'July', 'July'],
- [97, 20, 100, 70, 100, 20]
- ]
-
- self.assertEqual(
- mi.sort_together(iterables, key_list=(0, 1, 2), reverse=True),
- [('GA', 'GA', 'GA', 'CT', 'CT', 'CT'),
- ('May', 'May', 'Aug.', 'June', 'July', 'July'),
- (100, 97, 20, 70, 100, 20)]
- )
-
- def test_uneven_iterables(self):
- """tests trimming of iterables to the shortest length before sorting"""
- iterables = [['GA', 'GA', 'GA', 'CT', 'CT', 'CT', 'MA'],
- ['May', 'Aug.', 'May', 'June', 'July', 'July'],
- [97, 20, 100, 70, 100, 20, 0]]
-
- self.assertEqual(
- mi.sort_together(iterables),
- [
- ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'),
- ('June', 'July', 'July', 'May', 'Aug.', 'May'),
- (70, 100, 20, 97, 20, 100)
- ]
- )
-
-
-class DivideTest(TestCase):
- """Tests for divide()"""
-
- def test_invalid_n(self):
- self.assertRaises(ValueError, lambda: mi.divide(-1, [1, 2, 3]))
- self.assertRaises(ValueError, lambda: mi.divide(0, [1, 2, 3]))
-
- def test_basic(self):
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-
- for n, expected in [
- (1, [iterable]),
- (2, [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]),
- (3, [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]),
- (10, [[n] for n in range(1, 10 + 1)]),
- ]:
- self.assertEqual(
- [list(x) for x in mi.divide(n, iterable)], expected
- )
-
- def test_large_n(self):
- iterable = [1, 2, 3, 4]
- self.assertEqual(
- [list(x) for x in mi.divide(6, iterable)],
- [[1], [2], [3], [4], [], []]
- )
-
-
-class TestAlwaysIterable(TestCase):
- """Tests for always_iterable()"""
- def test_single(self):
- self.assertEqual(list(mi.always_iterable(1)), [1])
-
- def test_strings(self):
- for obj in ['foo', b'bar', 'baz']:
- actual = list(mi.always_iterable(obj))
- expected = [obj]
- self.assertEqual(actual, expected)
-
- def test_base_type(self):
- dict_obj = {'a': 1, 'b': 2}
- str_obj = '123'
-
- # Default: dicts are iterable like they normally are
- default_actual = list(mi.always_iterable(dict_obj))
- default_expected = list(dict_obj)
- self.assertEqual(default_actual, default_expected)
-
- # Unitary types set: dicts are not iterable
- custom_actual = list(mi.always_iterable(dict_obj, base_type=dict))
- custom_expected = [dict_obj]
- self.assertEqual(custom_actual, custom_expected)
-
- # With unitary types set, strings are iterable
- str_actual = list(mi.always_iterable(str_obj, base_type=None))
- str_expected = list(str_obj)
- self.assertEqual(str_actual, str_expected)
-
- def test_iterables(self):
- self.assertEqual(list(mi.always_iterable([0, 1])), [0, 1])
- self.assertEqual(
- list(mi.always_iterable([0, 1], base_type=list)), [[0, 1]]
- )
- self.assertEqual(
- list(mi.always_iterable(iter('foo'))), ['f', 'o', 'o']
- )
- self.assertEqual(list(mi.always_iterable([])), [])
-
- def test_none(self):
- self.assertEqual(list(mi.always_iterable(None)), [])
-
- def test_generator(self):
- def _gen():
- yield 0
- yield 1
-
- self.assertEqual(list(mi.always_iterable(_gen())), [0, 1])
-
-
-class AdjacentTests(TestCase):
- def test_typical(self):
- actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10)))
- expected = [(True, 0), (True, 1), (False, 2), (False, 3), (True, 4),
- (True, 5), (True, 6), (False, 7), (False, 8), (False, 9)]
- self.assertEqual(actual, expected)
-
- def test_empty_iterable(self):
- actual = list(mi.adjacent(lambda x: x % 5 == 0, []))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_length_one(self):
- actual = list(mi.adjacent(lambda x: x % 5 == 0, [0]))
- expected = [(True, 0)]
- self.assertEqual(actual, expected)
-
- actual = list(mi.adjacent(lambda x: x % 5 == 0, [1]))
- expected = [(False, 1)]
- self.assertEqual(actual, expected)
-
- def test_consecutive_true(self):
- """Test that when the predicate matches multiple consecutive elements
- it doesn't repeat elements in the output"""
- actual = list(mi.adjacent(lambda x: x % 5 < 2, range(10)))
- expected = [(True, 0), (True, 1), (True, 2), (False, 3), (True, 4),
- (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)]
- self.assertEqual(actual, expected)
-
- def test_distance(self):
- actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=2))
- expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4),
- (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)]
- self.assertEqual(actual, expected)
-
- actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=3))
- expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4),
- (True, 5), (True, 6), (True, 7), (True, 8), (False, 9)]
- self.assertEqual(actual, expected)
-
- def test_large_distance(self):
- """Test distance larger than the length of the iterable"""
- iterable = range(10)
- actual = list(mi.adjacent(lambda x: x % 5 == 4, iterable, distance=20))
- expected = list(zip(repeat(True), iterable))
- self.assertEqual(actual, expected)
-
- actual = list(mi.adjacent(lambda x: False, iterable, distance=20))
- expected = list(zip(repeat(False), iterable))
- self.assertEqual(actual, expected)
-
- def test_zero_distance(self):
- """Test that adjacent() reduces to zip+map when distance is 0"""
- iterable = range(1000)
- predicate = lambda x: x % 4 == 2
- actual = mi.adjacent(predicate, iterable, 0)
- expected = zip(map(predicate, iterable), iterable)
- self.assertTrue(all(a == e for a, e in zip(actual, expected)))
-
- def test_negative_distance(self):
- """Test that adjacent() raises an error with negative distance"""
- pred = lambda x: x
- self.assertRaises(
- ValueError, lambda: mi.adjacent(pred, range(1000), -1)
- )
- self.assertRaises(
- ValueError, lambda: mi.adjacent(pred, range(10), -10)
- )
-
- def test_grouping(self):
- """Test interaction of adjacent() with groupby_transform()"""
- iterable = mi.adjacent(lambda x: x % 5 == 0, range(10))
- grouper = mi.groupby_transform(iterable, itemgetter(0), itemgetter(1))
- actual = [(k, list(g)) for k, g in grouper]
- expected = [
- (True, [0, 1]),
- (False, [2, 3]),
- (True, [4, 5, 6]),
- (False, [7, 8, 9]),
- ]
- self.assertEqual(actual, expected)
-
- def test_call_once(self):
- """Test that the predicate is only called once per item."""
- already_seen = set()
- iterable = range(10)
-
- def predicate(item):
- self.assertNotIn(item, already_seen)
- already_seen.add(item)
- return True
-
- actual = list(mi.adjacent(predicate, iterable))
- expected = [(True, x) for x in iterable]
- self.assertEqual(actual, expected)
-
-
-class GroupByTransformTests(TestCase):
- def assertAllGroupsEqual(self, groupby1, groupby2):
- """Compare two groupby objects for equality, both keys and groups."""
- for a, b in zip(groupby1, groupby2):
- key1, group1 = a
- key2, group2 = b
- self.assertEqual(key1, key2)
- self.assertListEqual(list(group1), list(group2))
- self.assertRaises(StopIteration, lambda: next(groupby1))
- self.assertRaises(StopIteration, lambda: next(groupby2))
-
- def test_default_funcs(self):
- """Test that groupby_transform() with default args mimics groupby()"""
- iterable = [(x // 5, x) for x in range(1000)]
- actual = mi.groupby_transform(iterable)
- expected = groupby(iterable)
- self.assertAllGroupsEqual(actual, expected)
-
- def test_valuefunc(self):
- iterable = [(int(x / 5), int(x / 3), x) for x in range(10)]
-
- # Test the standard usage of grouping one iterable using another's keys
- grouper = mi.groupby_transform(
- iterable, keyfunc=itemgetter(0), valuefunc=itemgetter(-1)
- )
- actual = [(k, list(g)) for k, g in grouper]
- expected = [(0, [0, 1, 2, 3, 4]), (1, [5, 6, 7, 8, 9])]
- self.assertEqual(actual, expected)
-
- grouper = mi.groupby_transform(
- iterable, keyfunc=itemgetter(1), valuefunc=itemgetter(-1)
- )
- actual = [(k, list(g)) for k, g in grouper]
- expected = [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8]), (3, [9])]
- self.assertEqual(actual, expected)
-
- # and now for something a little different
- d = dict(zip(range(10), 'abcdefghij'))
- grouper = mi.groupby_transform(
- range(10), keyfunc=lambda x: x // 5, valuefunc=d.get
- )
- actual = [(k, ''.join(g)) for k, g in grouper]
- expected = [(0, 'abcde'), (1, 'fghij')]
- self.assertEqual(actual, expected)
-
- def test_no_valuefunc(self):
- iterable = range(1000)
-
- def key(x):
- return x // 5
-
- actual = mi.groupby_transform(iterable, key, valuefunc=None)
- expected = groupby(iterable, key)
- self.assertAllGroupsEqual(actual, expected)
-
- actual = mi.groupby_transform(iterable, key) # default valuefunc
- expected = groupby(iterable, key)
- self.assertAllGroupsEqual(actual, expected)
-
-
-class NumericRangeTests(TestCase):
- def test_basic(self):
- for args, expected in [
- ((4,), [0, 1, 2, 3]),
- ((4.0,), [0.0, 1.0, 2.0, 3.0]),
- ((1.0, 4), [1.0, 2.0, 3.0]),
- ((1, 4.0), [1, 2, 3]),
- ((1.0, 5), [1.0, 2.0, 3.0, 4.0]),
- ((0, 20, 5), [0, 5, 10, 15]),
- ((0, 20, 5.0), [0.0, 5.0, 10.0, 15.0]),
- ((0, 10, 3), [0, 3, 6, 9]),
- ((0, 10, 3.0), [0.0, 3.0, 6.0, 9.0]),
- ((0, -5, -1), [0, -1, -2, -3, -4]),
- ((0.0, -5, -1), [0.0, -1.0, -2.0, -3.0, -4.0]),
- ((1, 2, Fraction(1, 2)), [Fraction(1, 1), Fraction(3, 2)]),
- ((0,), []),
- ((0.0,), []),
- ((1, 0), []),
- ((1.0, 0.0), []),
- ((Fraction(2, 1),), [Fraction(0, 1), Fraction(1, 1)]),
- ((Decimal('2.0'),), [Decimal('0.0'), Decimal('1.0')]),
- ]:
- actual = list(mi.numeric_range(*args))
- self.assertEqual(actual, expected)
- self.assertTrue(
- all(type(a) == type(e) for a, e in zip(actual, expected))
- )
-
- def test_arg_count(self):
- self.assertRaises(TypeError, lambda: list(mi.numeric_range()))
- self.assertRaises(
- TypeError, lambda: list(mi.numeric_range(0, 1, 2, 3))
- )
-
- def test_zero_step(self):
- self.assertRaises(
- ValueError, lambda: list(mi.numeric_range(1, 2, 0))
- )
-
-
-class CountCycleTests(TestCase):
- def test_basic(self):
- expected = [
- (0, 'a'), (0, 'b'), (0, 'c'),
- (1, 'a'), (1, 'b'), (1, 'c'),
- (2, 'a'), (2, 'b'), (2, 'c'),
- ]
- for actual in [
- mi.take(9, mi.count_cycle('abc')), # n=None
- list(mi.count_cycle('abc', 3)), # n=3
- ]:
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- self.assertEqual(list(mi.count_cycle('')), [])
- self.assertEqual(list(mi.count_cycle('', 2)), [])
-
- def test_negative(self):
- self.assertEqual(list(mi.count_cycle('abc', -3)), [])
-
-
-class LocateTests(TestCase):
- def test_default_pred(self):
- iterable = [0, 1, 1, 0, 1, 0, 0]
- actual = list(mi.locate(iterable))
- expected = [1, 2, 4]
- self.assertEqual(actual, expected)
-
- def test_no_matches(self):
- iterable = [0, 0, 0]
- actual = list(mi.locate(iterable))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_custom_pred(self):
- iterable = ['0', 1, 1, '0', 1, '0', '0']
- pred = lambda x: x == '0'
- actual = list(mi.locate(iterable, pred))
- expected = [0, 3, 5, 6]
- self.assertEqual(actual, expected)
-
- def test_window_size(self):
- iterable = ['0', 1, 1, '0', 1, '0', '0']
- pred = lambda *args: args == ('0', 1)
- actual = list(mi.locate(iterable, pred, window_size=2))
- expected = [0, 3]
- self.assertEqual(actual, expected)
-
- def test_window_size_large(self):
- iterable = [1, 2, 3, 4]
- pred = lambda a, b, c, d, e: True
- actual = list(mi.locate(iterable, pred, window_size=5))
- expected = [0]
- self.assertEqual(actual, expected)
-
- def test_window_size_zero(self):
- iterable = [1, 2, 3, 4]
- pred = lambda: True
- with self.assertRaises(ValueError):
- list(mi.locate(iterable, pred, window_size=0))
-
-
-class StripFunctionTests(TestCase):
- def test_hashable(self):
- iterable = list('www.example.com')
- pred = lambda x: x in set('cmowz.')
-
- self.assertEqual(list(mi.lstrip(iterable, pred)), list('example.com'))
- self.assertEqual(list(mi.rstrip(iterable, pred)), list('www.example'))
- self.assertEqual(list(mi.strip(iterable, pred)), list('example'))
-
- def test_not_hashable(self):
- iterable = [
- list('http://'), list('www'), list('.example'), list('.com')
- ]
- pred = lambda x: x in [list('http://'), list('www'), list('.com')]
-
- self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[2:])
- self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:3])
- self.assertEqual(list(mi.strip(iterable, pred)), iterable[2: 3])
-
- def test_math(self):
- iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]
- pred = lambda x: x <= 2
-
- self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[3:])
- self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:-3])
- self.assertEqual(list(mi.strip(iterable, pred)), iterable[3:-3])
-
-
-class IsliceExtendedTests(TestCase):
- def test_all(self):
- iterable = ['0', '1', '2', '3', '4', '5']
- indexes = list(range(-4, len(iterable) + 4)) + [None]
- steps = [1, 2, 3, 4, -1, -2, -3, 4]
- for slice_args in product(indexes, indexes, steps):
- try:
- actual = list(mi.islice_extended(iterable, *slice_args))
- except Exception as e:
- self.fail((slice_args, e))
-
- expected = iterable[slice(*slice_args)]
- self.assertEqual(actual, expected, slice_args)
-
- def test_zero_step(self):
- with self.assertRaises(ValueError):
- list(mi.islice_extended([1, 2, 3], 0, 1, 0))
-
-
-class ConsecutiveGroupsTest(TestCase):
- def test_numbers(self):
- iterable = [-10, -8, -7, -6, 1, 2, 4, 5, -1, 7]
- actual = [list(g) for g in mi.consecutive_groups(iterable)]
- expected = [[-10], [-8, -7, -6], [1, 2], [4, 5], [-1], [7]]
- self.assertEqual(actual, expected)
-
- def test_custom_ordering(self):
- iterable = ['1', '10', '11', '20', '21', '22', '30', '31']
- ordering = lambda x: int(x)
- actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)]
- expected = [['1'], ['10', '11'], ['20', '21', '22'], ['30', '31']]
- self.assertEqual(actual, expected)
-
- def test_exotic_ordering(self):
- iterable = [
- ('a', 'b', 'c', 'd'),
- ('a', 'c', 'b', 'd'),
- ('a', 'c', 'd', 'b'),
- ('a', 'd', 'b', 'c'),
- ('d', 'b', 'c', 'a'),
- ('d', 'c', 'a', 'b'),
- ]
- ordering = list(permutations('abcd')).index
- actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)]
- expected = [
- [('a', 'b', 'c', 'd')],
- [('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b'), ('a', 'd', 'b', 'c')],
- [('d', 'b', 'c', 'a'), ('d', 'c', 'a', 'b')],
- ]
- self.assertEqual(actual, expected)
-
-
-class DifferenceTest(TestCase):
- def test_normal(self):
- iterable = [10, 20, 30, 40, 50]
- actual = list(mi.difference(iterable))
- expected = [10, 10, 10, 10, 10]
- self.assertEqual(actual, expected)
-
- def test_custom(self):
- iterable = [10, 20, 30, 40, 50]
- actual = list(mi.difference(iterable, add))
- expected = [10, 30, 50, 70, 90]
- self.assertEqual(actual, expected)
-
- def test_roundtrip(self):
- original = list(range(100))
- accumulated = mi.accumulate(original)
- actual = list(mi.difference(accumulated))
- self.assertEqual(actual, original)
-
- def test_one(self):
- self.assertEqual(list(mi.difference([0])), [0])
-
- def test_empty(self):
- self.assertEqual(list(mi.difference([])), [])
-
-
-class SeekableTest(TestCase):
- def test_exhaustion_reset(self):
- iterable = [str(n) for n in range(10)]
-
- s = mi.seekable(iterable)
- self.assertEqual(list(s), iterable) # Normal iteration
- self.assertEqual(list(s), []) # Iterable is exhausted
-
- s.seek(0)
- self.assertEqual(list(s), iterable) # Back in action
-
- def test_partial_reset(self):
- iterable = [str(n) for n in range(10)]
-
- s = mi.seekable(iterable)
- self.assertEqual(mi.take(5, s), iterable[:5]) # Normal iteration
-
- s.seek(1)
- self.assertEqual(list(s), iterable[1:]) # Get the rest of the iterable
-
- def test_forward(self):
- iterable = [str(n) for n in range(10)]
-
- s = mi.seekable(iterable)
- self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration
-
- s.seek(3) # Skip over index 2
- self.assertEqual(list(s), iterable[3:]) # Result is similar to slicing
-
- s.seek(0) # Back to 0
- self.assertEqual(list(s), iterable) # No difference in result
-
- def test_past_end(self):
- iterable = [str(n) for n in range(10)]
-
- s = mi.seekable(iterable)
- self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration
-
- s.seek(20)
- self.assertEqual(list(s), []) # Iterable is exhausted
-
- s.seek(0) # Back to 0
- self.assertEqual(list(s), iterable) # No difference in result
-
- def test_elements(self):
- iterable = map(str, count())
-
- s = mi.seekable(iterable)
- mi.take(10, s)
-
- elements = s.elements()
- self.assertEqual(
- [elements[i] for i in range(10)], [str(n) for n in range(10)]
- )
- self.assertEqual(len(elements), 10)
-
- mi.take(10, s)
- self.assertEqual(list(elements), [str(n) for n in range(20)])
-
-
-class SequenceViewTests(TestCase):
- def test_init(self):
- view = mi.SequenceView((1, 2, 3))
- self.assertEqual(repr(view), "SequenceView((1, 2, 3))")
- self.assertRaises(TypeError, lambda: mi.SequenceView({}))
-
- def test_update(self):
- seq = [1, 2, 3]
- view = mi.SequenceView(seq)
- self.assertEqual(len(view), 3)
- self.assertEqual(repr(view), "SequenceView([1, 2, 3])")
-
- seq.pop()
- self.assertEqual(len(view), 2)
- self.assertEqual(repr(view), "SequenceView([1, 2])")
-
- def test_indexing(self):
- seq = ('a', 'b', 'c', 'd', 'e', 'f')
- view = mi.SequenceView(seq)
- for i in range(-len(seq), len(seq)):
- self.assertEqual(view[i], seq[i])
-
- def test_slicing(self):
- seq = ('a', 'b', 'c', 'd', 'e', 'f')
- view = mi.SequenceView(seq)
- n = len(seq)
- indexes = list(range(-n - 1, n + 1)) + [None]
- steps = list(range(-n, n + 1))
- steps.remove(0)
- for slice_args in product(indexes, indexes, steps):
- i = slice(*slice_args)
- self.assertEqual(view[i], seq[i])
-
- def test_abc_methods(self):
- # collections.Sequence should provide all of this functionality
- seq = ('a', 'b', 'c', 'd', 'e', 'f', 'f')
- view = mi.SequenceView(seq)
-
- # __contains__
- self.assertIn('b', view)
- self.assertNotIn('g', view)
-
- # __iter__
- self.assertEqual(list(iter(view)), list(seq))
-
- # __reversed__
- self.assertEqual(list(reversed(view)), list(reversed(seq)))
-
- # index
- self.assertEqual(view.index('b'), 1)
-
- # count
- self.assertEqual(seq.count('f'), 2)
-
-
-class RunLengthTest(TestCase):
- def test_encode(self):
- iterable = (int(str(n)[0]) for n in count(800))
- actual = mi.take(4, mi.run_length.encode(iterable))
- expected = [(8, 100), (9, 100), (1, 1000), (2, 1000)]
- self.assertEqual(actual, expected)
-
- def test_decode(self):
- iterable = [('d', 4), ('c', 3), ('b', 2), ('a', 1)]
- actual = ''.join(mi.run_length.decode(iterable))
- expected = 'ddddcccbba'
- self.assertEqual(actual, expected)
-
-
-class ExactlyNTests(TestCase):
- """Tests for ``exactly_n()``"""
-
- def test_true(self):
- """Iterable has ``n`` ``True`` elements"""
- self.assertTrue(mi.exactly_n([True, False, True], 2))
- self.assertTrue(mi.exactly_n([1, 1, 1, 0], 3))
- self.assertTrue(mi.exactly_n([False, False], 0))
- self.assertTrue(mi.exactly_n(range(100), 10, lambda x: x < 10))
-
- def test_false(self):
- """Iterable does not have ``n`` ``True`` elements"""
- self.assertFalse(mi.exactly_n([True, False, False], 2))
- self.assertFalse(mi.exactly_n([True, True, False], 1))
- self.assertFalse(mi.exactly_n([False], 1))
- self.assertFalse(mi.exactly_n([True], -1))
- self.assertFalse(mi.exactly_n(repeat(True), 100))
-
- def test_empty(self):
- """Return ``True`` if the iterable is empty and ``n`` is 0"""
- self.assertTrue(mi.exactly_n([], 0))
- self.assertFalse(mi.exactly_n([], 1))
-
-
-class AlwaysReversibleTests(TestCase):
- """Tests for ``always_reversible()``"""
-
- def test_regular_reversed(self):
- self.assertEqual(list(reversed(range(10))),
- list(mi.always_reversible(range(10))))
- self.assertEqual(list(reversed([1, 2, 3])),
- list(mi.always_reversible([1, 2, 3])))
- self.assertEqual(reversed([1, 2, 3]).__class__,
- mi.always_reversible([1, 2, 3]).__class__)
-
- def test_nonseq_reversed(self):
- # Create a non-reversible generator from a sequence
- with self.assertRaises(TypeError):
- reversed(x for x in range(10))
-
- self.assertEqual(list(reversed(range(10))),
- list(mi.always_reversible(x for x in range(10))))
- self.assertEqual(list(reversed([1, 2, 3])),
- list(mi.always_reversible(x for x in [1, 2, 3])))
- self.assertNotEqual(reversed((1, 2)).__class__,
- mi.always_reversible(x for x in (1, 2)).__class__)
-
-
-class CircularShiftsTests(TestCase):
- def test_empty(self):
- # empty iterable -> empty list
- self.assertEqual(list(mi.circular_shifts([])), [])
-
- def test_simple_circular_shifts(self):
- # test the a simple iterator case
- self.assertEqual(
- mi.circular_shifts(range(4)),
- [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]
- )
-
- def test_duplicates(self):
- # test non-distinct entries
- self.assertEqual(
- mi.circular_shifts([0, 1, 0, 1]),
- [(0, 1, 0, 1), (1, 0, 1, 0), (0, 1, 0, 1), (1, 0, 1, 0)]
- )
-
-
-class MakeDecoratorTests(TestCase):
- def test_basic(self):
- slicer = mi.make_decorator(islice)
-
- @slicer(1, 10, 2)
- def user_function(arg_1, arg_2, kwarg_1=None):
- self.assertEqual(arg_1, 'arg_1')
- self.assertEqual(arg_2, 'arg_2')
- self.assertEqual(kwarg_1, 'kwarg_1')
- return map(str, count())
-
- it = user_function('arg_1', 'arg_2', kwarg_1='kwarg_1')
- actual = list(it)
- expected = ['1', '3', '5', '7', '9']
- self.assertEqual(actual, expected)
-
- def test_result_index(self):
- def stringify(*args, **kwargs):
- self.assertEqual(args[0], 'arg_0')
- iterable = args[1]
- self.assertEqual(args[2], 'arg_2')
- self.assertEqual(kwargs['kwarg_1'], 'kwarg_1')
- return map(str, iterable)
-
- stringifier = mi.make_decorator(stringify, result_index=1)
-
- @stringifier('arg_0', 'arg_2', kwarg_1='kwarg_1')
- def user_function(n):
- return count(n)
-
- it = user_function(1)
- actual = mi.take(5, it)
- expected = ['1', '2', '3', '4', '5']
- self.assertEqual(actual, expected)
-
- def test_wrap_class(self):
- seeker = mi.make_decorator(mi.seekable)
-
- @seeker()
- def user_function(n):
- return map(str, range(n))
-
- it = user_function(5)
- self.assertEqual(list(it), ['0', '1', '2', '3', '4'])
-
- it.seek(0)
- self.assertEqual(list(it), ['0', '1', '2', '3', '4'])
-
-
-class MapReduceTests(TestCase):
- def test_default(self):
- iterable = (str(x) for x in range(5))
- keyfunc = lambda x: int(x) // 2
- actual = sorted(mi.map_reduce(iterable, keyfunc).items())
- expected = [(0, ['0', '1']), (1, ['2', '3']), (2, ['4'])]
- self.assertEqual(actual, expected)
-
- def test_valuefunc(self):
- iterable = (str(x) for x in range(5))
- keyfunc = lambda x: int(x) // 2
- valuefunc = int
- actual = sorted(mi.map_reduce(iterable, keyfunc, valuefunc).items())
- expected = [(0, [0, 1]), (1, [2, 3]), (2, [4])]
- self.assertEqual(actual, expected)
-
- def test_reducefunc(self):
- iterable = (str(x) for x in range(5))
- keyfunc = lambda x: int(x) // 2
- valuefunc = int
- reducefunc = lambda value_list: reduce(mul, value_list, 1)
- actual = sorted(
- mi.map_reduce(iterable, keyfunc, valuefunc, reducefunc).items()
- )
- expected = [(0, 0), (1, 6), (2, 4)]
- self.assertEqual(actual, expected)
-
- def test_ret(self):
- d = mi.map_reduce([1, 0, 2, 0, 1, 0], bool)
- self.assertEqual(d, {False: [0, 0, 0], True: [1, 2, 1]})
- self.assertRaises(KeyError, lambda: d[None].append(1))
-
-
-class RlocateTests(TestCase):
- def test_default_pred(self):
- iterable = [0, 1, 1, 0, 1, 0, 0]
- for it in (iterable[:], iter(iterable)):
- actual = list(mi.rlocate(it))
- expected = [4, 2, 1]
- self.assertEqual(actual, expected)
-
- def test_no_matches(self):
- iterable = [0, 0, 0]
- for it in (iterable[:], iter(iterable)):
- actual = list(mi.rlocate(it))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_custom_pred(self):
- iterable = ['0', 1, 1, '0', 1, '0', '0']
- pred = lambda x: x == '0'
- for it in (iterable[:], iter(iterable)):
- actual = list(mi.rlocate(it, pred))
- expected = [6, 5, 3, 0]
- self.assertEqual(actual, expected)
-
- def test_efficient_reversal(self):
- iterable = range(9 ** 9) # Is efficiently reversible
- target = 9 ** 9 - 2
- pred = lambda x: x == target # Find-able from the right
- actual = next(mi.rlocate(iterable, pred))
- self.assertEqual(actual, target)
-
- def test_window_size(self):
- iterable = ['0', 1, 1, '0', 1, '0', '0']
- pred = lambda *args: args == ('0', 1)
- for it in (iterable, iter(iterable)):
- actual = list(mi.rlocate(it, pred, window_size=2))
- expected = [3, 0]
- self.assertEqual(actual, expected)
-
- def test_window_size_large(self):
- iterable = [1, 2, 3, 4]
- pred = lambda a, b, c, d, e: True
- for it in (iterable, iter(iterable)):
- actual = list(mi.rlocate(iterable, pred, window_size=5))
- expected = [0]
- self.assertEqual(actual, expected)
-
- def test_window_size_zero(self):
- iterable = [1, 2, 3, 4]
- pred = lambda: True
- for it in (iterable, iter(iterable)):
- with self.assertRaises(ValueError):
- list(mi.locate(iterable, pred, window_size=0))
-
-
-class ReplaceTests(TestCase):
- def test_basic(self):
- iterable = range(10)
- pred = lambda x: x % 2 == 0
- substitutes = []
- actual = list(mi.replace(iterable, pred, substitutes))
- expected = [1, 3, 5, 7, 9]
- self.assertEqual(actual, expected)
-
- def test_count(self):
- iterable = range(10)
- pred = lambda x: x % 2 == 0
- substitutes = []
- actual = list(mi.replace(iterable, pred, substitutes, count=4))
- expected = [1, 3, 5, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_window_size(self):
- iterable = range(10)
- pred = lambda *args: args == (0, 1, 2)
- substitutes = []
- actual = list(mi.replace(iterable, pred, substitutes, window_size=3))
- expected = [3, 4, 5, 6, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_window_size_end(self):
- iterable = range(10)
- pred = lambda *args: args == (7, 8, 9)
- substitutes = []
- actual = list(mi.replace(iterable, pred, substitutes, window_size=3))
- expected = [0, 1, 2, 3, 4, 5, 6]
- self.assertEqual(actual, expected)
-
- def test_window_size_count(self):
- iterable = range(10)
- pred = lambda *args: (args == (0, 1, 2)) or (args == (7, 8, 9))
- substitutes = []
- actual = list(
- mi.replace(iterable, pred, substitutes, count=1, window_size=3)
- )
- expected = [3, 4, 5, 6, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_window_size_large(self):
- iterable = range(4)
- pred = lambda a, b, c, d, e: True
- substitutes = [5, 6, 7]
- actual = list(mi.replace(iterable, pred, substitutes, window_size=5))
- expected = [5, 6, 7]
- self.assertEqual(actual, expected)
-
- def test_window_size_zero(self):
- iterable = range(10)
- pred = lambda *args: True
- substitutes = []
- with self.assertRaises(ValueError):
- list(mi.replace(iterable, pred, substitutes, window_size=0))
-
- def test_iterable_substitutes(self):
- iterable = range(5)
- pred = lambda x: x % 2 == 0
- substitutes = iter('__')
- actual = list(mi.replace(iterable, pred, substitutes))
- expected = ['_', '_', 1, '_', '_', 3, '_', '_']
- self.assertEqual(actual, expected)
diff --git a/contrib/python/more-itertools/py2/more_itertools/tests/test_recipes.py b/contrib/python/more-itertools/py2/more_itertools/tests/test_recipes.py
deleted file mode 100644
index b3cfb62f46..0000000000
--- a/contrib/python/more-itertools/py2/more_itertools/tests/test_recipes.py
+++ /dev/null
@@ -1,616 +0,0 @@
-from doctest import DocTestSuite
-from unittest import TestCase
-
-from itertools import combinations
-from six.moves import range
-
-import more_itertools as mi
-
-
-def load_tests(loader, tests, ignore):
- # Add the doctests
- tests.addTests(DocTestSuite('more_itertools.recipes'))
- return tests
-
-
-class AccumulateTests(TestCase):
- """Tests for ``accumulate()``"""
-
- def test_empty(self):
- """Test that an empty input returns an empty output"""
- self.assertEqual(list(mi.accumulate([])), [])
-
- def test_default(self):
- """Test accumulate with the default function (addition)"""
- self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6])
-
- def test_bogus_function(self):
- """Test accumulate with an invalid function"""
- with self.assertRaises(TypeError):
- list(mi.accumulate([1, 2, 3], func=lambda x: x))
-
- def test_custom_function(self):
- """Test accumulate with a custom function"""
- self.assertEqual(
- list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3]
- )
-
-
-class TakeTests(TestCase):
- """Tests for ``take()``"""
-
- def test_simple_take(self):
- """Test basic usage"""
- t = mi.take(5, range(10))
- self.assertEqual(t, [0, 1, 2, 3, 4])
-
- def test_null_take(self):
- """Check the null case"""
- t = mi.take(0, range(10))
- self.assertEqual(t, [])
-
- def test_negative_take(self):
- """Make sure taking negative items results in a ValueError"""
- self.assertRaises(ValueError, lambda: mi.take(-3, range(10)))
-
- def test_take_too_much(self):
- """Taking more than an iterator has remaining should return what the
- iterator has remaining.
-
- """
- t = mi.take(10, range(5))
- self.assertEqual(t, [0, 1, 2, 3, 4])
-
-
-class TabulateTests(TestCase):
- """Tests for ``tabulate()``"""
-
- def test_simple_tabulate(self):
- """Test the happy path"""
- t = mi.tabulate(lambda x: x)
- f = tuple([next(t) for _ in range(3)])
- self.assertEqual(f, (0, 1, 2))
-
- def test_count(self):
- """Ensure tabulate accepts specific count"""
- t = mi.tabulate(lambda x: 2 * x, -1)
- f = (next(t), next(t), next(t))
- self.assertEqual(f, (-2, 0, 2))
-
-
-class TailTests(TestCase):
- """Tests for ``tail()``"""
-
- def test_greater(self):
- """Length of iterable is greater than requested tail"""
- self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G'])
-
- def test_equal(self):
- """Length of iterable is equal to the requested tail"""
- self.assertEqual(
- list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
- )
-
- def test_less(self):
- """Length of iterable is less than requested tail"""
- self.assertEqual(
- list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
- )
-
-
-class ConsumeTests(TestCase):
- """Tests for ``consume()``"""
-
- def test_sanity(self):
- """Test basic functionality"""
- r = (x for x in range(10))
- mi.consume(r, 3)
- self.assertEqual(3, next(r))
-
- def test_null_consume(self):
- """Check the null case"""
- r = (x for x in range(10))
- mi.consume(r, 0)
- self.assertEqual(0, next(r))
-
- def test_negative_consume(self):
- """Check that negative consumsion throws an error"""
- r = (x for x in range(10))
- self.assertRaises(ValueError, lambda: mi.consume(r, -1))
-
- def test_total_consume(self):
- """Check that iterator is totally consumed by default"""
- r = (x for x in range(10))
- mi.consume(r)
- self.assertRaises(StopIteration, lambda: next(r))
-
-
-class NthTests(TestCase):
- """Tests for ``nth()``"""
-
- def test_basic(self):
- """Make sure the nth item is returned"""
- l = range(10)
- for i, v in enumerate(l):
- self.assertEqual(mi.nth(l, i), v)
-
- def test_default(self):
- """Ensure a default value is returned when nth item not found"""
- l = range(3)
- self.assertEqual(mi.nth(l, 100, "zebra"), "zebra")
-
- def test_negative_item_raises(self):
- """Ensure asking for a negative item raises an exception"""
- self.assertRaises(ValueError, lambda: mi.nth(range(10), -3))
-
-
-class AllEqualTests(TestCase):
- """Tests for ``all_equal()``"""
-
- def test_true(self):
- """Everything is equal"""
- self.assertTrue(mi.all_equal('aaaaaa'))
- self.assertTrue(mi.all_equal([0, 0, 0, 0]))
-
- def test_false(self):
- """Not everything is equal"""
- self.assertFalse(mi.all_equal('aaaaab'))
- self.assertFalse(mi.all_equal([0, 0, 0, 1]))
-
- def test_tricky(self):
- """Not everything is identical, but everything is equal"""
- items = [1, complex(1, 0), 1.0]
- self.assertTrue(mi.all_equal(items))
-
- def test_empty(self):
- """Return True if the iterable is empty"""
- self.assertTrue(mi.all_equal(''))
- self.assertTrue(mi.all_equal([]))
-
- def test_one(self):
- """Return True if the iterable is singular"""
- self.assertTrue(mi.all_equal('0'))
- self.assertTrue(mi.all_equal([0]))
-
-
-class QuantifyTests(TestCase):
- """Tests for ``quantify()``"""
-
- def test_happy_path(self):
- """Make sure True count is returned"""
- q = [True, False, True]
- self.assertEqual(mi.quantify(q), 2)
-
- def test_custom_predicate(self):
- """Ensure non-default predicates return as expected"""
- q = range(10)
- self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5)
-
-
-class PadnoneTests(TestCase):
- """Tests for ``padnone()``"""
-
- def test_happy_path(self):
- """wrapper iterator should return None indefinitely"""
- r = range(2)
- p = mi.padnone(r)
- self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)])
-
-
-class NcyclesTests(TestCase):
- """Tests for ``nyclces()``"""
-
- def test_happy_path(self):
- """cycle a sequence three times"""
- r = ["a", "b", "c"]
- n = mi.ncycles(r, 3)
- self.assertEqual(
- ["a", "b", "c", "a", "b", "c", "a", "b", "c"],
- list(n)
- )
-
- def test_null_case(self):
- """asking for 0 cycles should return an empty iterator"""
- n = mi.ncycles(range(100), 0)
- self.assertRaises(StopIteration, lambda: next(n))
-
- def test_pathalogical_case(self):
- """asking for negative cycles should return an empty iterator"""
- n = mi.ncycles(range(100), -10)
- self.assertRaises(StopIteration, lambda: next(n))
-
-
-class DotproductTests(TestCase):
- """Tests for ``dotproduct()``'"""
-
- def test_happy_path(self):
- """simple dotproduct example"""
- self.assertEqual(400, mi.dotproduct([10, 10], [20, 20]))
-
-
-class FlattenTests(TestCase):
- """Tests for ``flatten()``"""
-
- def test_basic_usage(self):
- """ensure list of lists is flattened one level"""
- f = [[0, 1, 2], [3, 4, 5]]
- self.assertEqual(list(range(6)), list(mi.flatten(f)))
-
- def test_single_level(self):
- """ensure list of lists is flattened only one level"""
- f = [[0, [1, 2]], [[3, 4], 5]]
- self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f)))
-
-
-class RepeatfuncTests(TestCase):
- """Tests for ``repeatfunc()``"""
-
- def test_simple_repeat(self):
- """test simple repeated functions"""
- r = mi.repeatfunc(lambda: 5)
- self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)])
-
- def test_finite_repeat(self):
- """ensure limited repeat when times is provided"""
- r = mi.repeatfunc(lambda: 5, times=5)
- self.assertEqual([5, 5, 5, 5, 5], list(r))
-
- def test_added_arguments(self):
- """ensure arguments are applied to the function"""
- r = mi.repeatfunc(lambda x: x, 2, 3)
- self.assertEqual([3, 3], list(r))
-
- def test_null_times(self):
- """repeat 0 should return an empty iterator"""
- r = mi.repeatfunc(range, 0, 3)
- self.assertRaises(StopIteration, lambda: next(r))
-
-
-class PairwiseTests(TestCase):
- """Tests for ``pairwise()``"""
-
- def test_base_case(self):
- """ensure an iterable will return pairwise"""
- p = mi.pairwise([1, 2, 3])
- self.assertEqual([(1, 2), (2, 3)], list(p))
-
- def test_short_case(self):
- """ensure an empty iterator if there's not enough values to pair"""
- p = mi.pairwise("a")
- self.assertRaises(StopIteration, lambda: next(p))
-
-
-class GrouperTests(TestCase):
- """Tests for ``grouper()``"""
-
- def test_even(self):
- """Test when group size divides evenly into the length of
- the iterable.
-
- """
- self.assertEqual(
- list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')]
- )
-
- def test_odd(self):
- """Test when group size does not divide evenly into the length of the
- iterable.
-
- """
- self.assertEqual(
- list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)]
- )
-
- def test_fill_value(self):
- """Test that the fill value is used to pad the final group"""
- self.assertEqual(
- list(mi.grouper(3, 'ABCDE', 'x')),
- [('A', 'B', 'C'), ('D', 'E', 'x')]
- )
-
-
-class RoundrobinTests(TestCase):
- """Tests for ``roundrobin()``"""
-
- def test_even_groups(self):
- """Ensure ordered output from evenly populated iterables"""
- self.assertEqual(
- list(mi.roundrobin('ABC', [1, 2, 3], range(3))),
- ['A', 1, 0, 'B', 2, 1, 'C', 3, 2]
- )
-
- def test_uneven_groups(self):
- """Ensure ordered output from unevenly populated iterables"""
- self.assertEqual(
- list(mi.roundrobin('ABCD', [1, 2], range(0))),
- ['A', 1, 'B', 2, 'C', 'D']
- )
-
-
-class PartitionTests(TestCase):
- """Tests for ``partition()``"""
-
- def test_bool(self):
- """Test when pred() returns a boolean"""
- lesser, greater = mi.partition(lambda x: x > 5, range(10))
- self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5])
- self.assertEqual(list(greater), [6, 7, 8, 9])
-
- def test_arbitrary(self):
- """Test when pred() returns an integer"""
- divisibles, remainders = mi.partition(lambda x: x % 3, range(10))
- self.assertEqual(list(divisibles), [0, 3, 6, 9])
- self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8])
-
-
-class PowersetTests(TestCase):
- """Tests for ``powerset()``"""
-
- def test_combinatorics(self):
- """Ensure a proper enumeration"""
- p = mi.powerset([1, 2, 3])
- self.assertEqual(
- list(p),
- [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
- )
-
-
-class UniqueEverseenTests(TestCase):
- """Tests for ``unique_everseen()``"""
-
- def test_everseen(self):
- """ensure duplicate elements are ignored"""
- u = mi.unique_everseen('AAAABBBBCCDAABBB')
- self.assertEqual(
- ['A', 'B', 'C', 'D'],
- list(u)
- )
-
- def test_custom_key(self):
- """ensure the custom key comparison works"""
- u = mi.unique_everseen('aAbACCc', key=str.lower)
- self.assertEqual(list('abC'), list(u))
-
- def test_unhashable(self):
- """ensure things work for unhashable items"""
- iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
- u = mi.unique_everseen(iterable)
- self.assertEqual(list(u), ['a', [1, 2, 3]])
-
- def test_unhashable_key(self):
- """ensure things work for unhashable items with a custom key"""
- iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
- u = mi.unique_everseen(iterable, key=lambda x: x)
- self.assertEqual(list(u), ['a', [1, 2, 3]])
-
-
-class UniqueJustseenTests(TestCase):
- """Tests for ``unique_justseen()``"""
-
- def test_justseen(self):
- """ensure only last item is remembered"""
- u = mi.unique_justseen('AAAABBBCCDABB')
- self.assertEqual(list('ABCDAB'), list(u))
-
- def test_custom_key(self):
- """ensure the custom key comparison works"""
- u = mi.unique_justseen('AABCcAD', str.lower)
- self.assertEqual(list('ABCAD'), list(u))
-
-
-class IterExceptTests(TestCase):
- """Tests for ``iter_except()``"""
-
- def test_exact_exception(self):
- """ensure the exact specified exception is caught"""
- l = [1, 2, 3]
- i = mi.iter_except(l.pop, IndexError)
- self.assertEqual(list(i), [3, 2, 1])
-
- def test_generic_exception(self):
- """ensure the generic exception can be caught"""
- l = [1, 2]
- i = mi.iter_except(l.pop, Exception)
- self.assertEqual(list(i), [2, 1])
-
- def test_uncaught_exception_is_raised(self):
- """ensure a non-specified exception is raised"""
- l = [1, 2, 3]
- i = mi.iter_except(l.pop, KeyError)
- self.assertRaises(IndexError, lambda: list(i))
-
- def test_first(self):
- """ensure first is run before the function"""
- l = [1, 2, 3]
- f = lambda: 25
- i = mi.iter_except(l.pop, IndexError, f)
- self.assertEqual(list(i), [25, 3, 2, 1])
-
-
-class FirstTrueTests(TestCase):
- """Tests for ``first_true()``"""
-
- def test_something_true(self):
- """Test with no keywords"""
- self.assertEqual(mi.first_true(range(10)), 1)
-
- def test_nothing_true(self):
- """Test default return value."""
- self.assertIsNone(mi.first_true([0, 0, 0]))
-
- def test_default(self):
- """Test with a default keyword"""
- self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!')
-
- def test_pred(self):
- """Test with a custom predicate"""
- self.assertEqual(
- mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6
- )
-
-
-class RandomProductTests(TestCase):
- """Tests for ``random_product()``
-
- Since random.choice() has different results with the same seed across
- python versions 2.x and 3.x, these tests use highly probably events to
- create predictable outcomes across platforms.
- """
-
- def test_simple_lists(self):
- """Ensure that one item is chosen from each list in each pair.
- Also ensure that each item from each list eventually appears in
- the chosen combinations.
-
- Odds are roughly 1 in 7.1 * 10e16 that one item from either list will
- not be chosen after 100 samplings of one item from each list. Just to
- be safe, better use a known random seed, too.
-
- """
- nums = [1, 2, 3]
- lets = ['a', 'b', 'c']
- n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)])
- n, m = set(n), set(m)
- self.assertEqual(n, set(nums))
- self.assertEqual(m, set(lets))
- self.assertEqual(len(n), len(nums))
- self.assertEqual(len(m), len(lets))
-
- def test_list_with_repeat(self):
- """ensure multiple items are chosen, and that they appear to be chosen
- from one list then the next, in proper order.
-
- """
- nums = [1, 2, 3]
- lets = ['a', 'b', 'c']
- r = list(mi.random_product(nums, lets, repeat=100))
- self.assertEqual(2 * 100, len(r))
- n, m = set(r[::2]), set(r[1::2])
- self.assertEqual(n, set(nums))
- self.assertEqual(m, set(lets))
- self.assertEqual(len(n), len(nums))
- self.assertEqual(len(m), len(lets))
-
-
-class RandomPermutationTests(TestCase):
- """Tests for ``random_permutation()``"""
-
- def test_full_permutation(self):
- """ensure every item from the iterable is returned in a new ordering
-
- 15 elements have a 1 in 1.3 * 10e12 of appearing in sorted order, so
- we fix a seed value just to be sure.
-
- """
- i = range(15)
- r = mi.random_permutation(i)
- self.assertEqual(set(i), set(r))
- if i == r:
- raise AssertionError("Values were not permuted")
-
- def test_partial_permutation(self):
- """ensure all returned items are from the iterable, that the returned
- permutation is of the desired length, and that all items eventually
- get returned.
-
- Sampling 100 permutations of length 5 from a set of 15 leaves a
- (2/3)^100 chance that an item will not be chosen. Multiplied by 15
- items, there is a 1 in 2.6e16 chance that at least 1 item will not
- show up in the resulting output. Using a random seed will fix that.
-
- """
- items = range(15)
- item_set = set(items)
- all_items = set()
- for _ in range(100):
- permutation = mi.random_permutation(items, 5)
- self.assertEqual(len(permutation), 5)
- permutation_set = set(permutation)
- self.assertLessEqual(permutation_set, item_set)
- all_items |= permutation_set
- self.assertEqual(all_items, item_set)
-
-
-class RandomCombinationTests(TestCase):
- """Tests for ``random_combination()``"""
-
- def test_pseudorandomness(self):
- """ensure different subsets of the iterable get returned over many
- samplings of random combinations"""
- items = range(15)
- all_items = set()
- for _ in range(50):
- combination = mi.random_combination(items, 5)
- all_items |= set(combination)
- self.assertEqual(all_items, set(items))
-
- def test_no_replacement(self):
- """ensure that elements are sampled without replacement"""
- items = range(15)
- for _ in range(50):
- combination = mi.random_combination(items, len(items))
- self.assertEqual(len(combination), len(set(combination)))
- self.assertRaises(
- ValueError, lambda: mi.random_combination(items, len(items) + 1)
- )
-
-
-class RandomCombinationWithReplacementTests(TestCase):
- """Tests for ``random_combination_with_replacement()``"""
-
- def test_replacement(self):
- """ensure that elements are sampled with replacement"""
- items = range(5)
- combo = mi.random_combination_with_replacement(items, len(items) * 2)
- self.assertEqual(2 * len(items), len(combo))
- if len(set(combo)) == len(combo):
- raise AssertionError("Combination contained no duplicates")
-
- def test_pseudorandomness(self):
- """ensure different subsets of the iterable get returned over many
- samplings of random combinations"""
- items = range(15)
- all_items = set()
- for _ in range(50):
- combination = mi.random_combination_with_replacement(items, 5)
- all_items |= set(combination)
- self.assertEqual(all_items, set(items))
-
-
-class NthCombinationTests(TestCase):
- def test_basic(self):
- iterable = 'abcdefg'
- r = 4
- for index, expected in enumerate(combinations(iterable, r)):
- actual = mi.nth_combination(iterable, r, index)
- self.assertEqual(actual, expected)
-
- def test_long(self):
- actual = mi.nth_combination(range(180), 4, 2000000)
- expected = (2, 12, 35, 126)
- self.assertEqual(actual, expected)
-
- def test_invalid_r(self):
- for r in (-1, 3):
- with self.assertRaises(ValueError):
- mi.nth_combination([], r, 0)
-
- def test_invalid_index(self):
- with self.assertRaises(IndexError):
- mi.nth_combination('abcdefg', 3, -36)
-
-
-class PrependTests(TestCase):
- def test_basic(self):
- value = 'a'
- iterator = iter('bcdefg')
- actual = list(mi.prepend(value, iterator))
- expected = list('abcdefg')
- self.assertEqual(actual, expected)
-
- def test_multiple(self):
- value = 'ab'
- iterator = iter('cdefg')
- actual = tuple(mi.prepend(value, iterator))
- expected = ('ab',) + tuple('cdefg')
- self.assertEqual(actual, expected)
diff --git a/contrib/python/more-itertools/py2/patches/01-fix-tests.patch b/contrib/python/more-itertools/py2/patches/01-fix-tests.patch
deleted file mode 100644
index 85602736df..0000000000
--- a/contrib/python/more-itertools/py2/patches/01-fix-tests.patch
+++ /dev/null
@@ -1,18 +0,0 @@
---- contrib/python/more-itertools/py2/more_itertools/tests/test_more.py (index)
-+++ contrib/python/more-itertools/py2/more_itertools/tests/test_more.py (working tree)
-@@ -122,13 +122,13 @@ class IterOnlyRange:
- raise an ``AttributeError`` rather than ``TypeError`` in Python 2.
-
- >>> r = IterOnlyRange(5)
-- >>> r[0]
-+ >>> r[0] # doctest: +SKIP
- AttributeError: IterOnlyRange instance has no attribute '__getitem__'
-
- Note: In Python 3, ``TypeError`` will be raised because ``object`` is
- inherited implicitly by default.
-
-- >>> r[0]
-+ >>> r[0] # doctest: +SKIP
- TypeError: 'IterOnlyRange' object does not support indexing
- """
- def __init__(self, n):
diff --git a/contrib/python/more-itertools/py2/tests/ya.make b/contrib/python/more-itertools/py2/tests/ya.make
deleted file mode 100644
index 8aecf61cc6..0000000000
--- a/contrib/python/more-itertools/py2/tests/ya.make
+++ /dev/null
@@ -1,18 +0,0 @@
-PY2TEST()
-
-OWNER(g:python-contrib)
-
-PEERDIR(
- contrib/python/more-itertools
-)
-
-SRCDIR(contrib/python/more-itertools/py2/more_itertools/tests)
-
-TEST_SRCS(
- test_more.py
- test_recipes.py
-)
-
-NO_LINT()
-
-END()
diff --git a/contrib/python/more-itertools/py2/ya.make b/contrib/python/more-itertools/py2/ya.make
deleted file mode 100644
index 0a914e4f42..0000000000
--- a/contrib/python/more-itertools/py2/ya.make
+++ /dev/null
@@ -1,34 +0,0 @@
-# Generated by devtools/yamaker (pypi).
-
-PY2_LIBRARY()
-
-OWNER(g:python-contrib)
-
-VERSION(5.0.0)
-
-LICENSE(MIT)
-
-PEERDIR(
- contrib/python/six
-)
-
-NO_LINT()
-
-PY_SRCS(
- TOP_LEVEL
- more_itertools/__init__.py
- more_itertools/more.py
- more_itertools/recipes.py
-)
-
-RESOURCE_FILES(
- PREFIX contrib/python/more-itertools/py2/
- .dist-info/METADATA
- .dist-info/top_level.txt
-)
-
-END()
-
-RECURSE_FOR_TESTS(
- tests
-)
diff --git a/contrib/python/more-itertools/py3/.dist-info/METADATA b/contrib/python/more-itertools/py3/.dist-info/METADATA
deleted file mode 100644
index 9efacdd745..0000000000
--- a/contrib/python/more-itertools/py3/.dist-info/METADATA
+++ /dev/null
@@ -1,521 +0,0 @@
-Metadata-Version: 2.1
-Name: more-itertools
-Version: 8.12.0
-Summary: More routines for operating on iterables, beyond itertools
-Home-page: https://github.com/more-itertools/more-itertools
-Author: Erik Rose
-Author-email: erikrose@grinchcentral.com
-License: MIT
-Keywords: itertools,iterator,iteration,filter,peek,peekable,collate,chunk,chunked
-Platform: UNKNOWN
-Classifier: Development Status :: 5 - Production/Stable
-Classifier: Intended Audience :: Developers
-Classifier: Natural Language :: English
-Classifier: License :: OSI Approved :: MIT License
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.6
-Classifier: Programming Language :: Python :: 3.7
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3 :: Only
-Classifier: Programming Language :: Python :: Implementation :: CPython
-Classifier: Programming Language :: Python :: Implementation :: PyPy
-Classifier: Topic :: Software Development :: Libraries
-Requires-Python: >=3.5
-Description-Content-Type: text/x-rst
-License-File: LICENSE
-
-==============
-More Itertools
-==============
-
-.. image:: https://readthedocs.org/projects/more-itertools/badge/?version=latest
- :target: https://more-itertools.readthedocs.io/en/stable/
-
-Python's ``itertools`` library is a gem - you can compose elegant solutions
-for a variety of problems with the functions it provides. In ``more-itertools``
-we collect additional building blocks, recipes, and routines for working with
-Python iterables.
-
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Grouping | `chunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked>`_, |
-| | `ichunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ichunked>`_, |
-| | `sliced <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliced>`_, |
-| | `distribute <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute>`_, |
-| | `divide <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.divide>`_, |
-| | `split_at <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_at>`_, |
-| | `split_before <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_before>`_, |
-| | `split_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_after>`_, |
-| | `split_into <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_into>`_, |
-| | `split_when <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_when>`_, |
-| | `bucket <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.bucket>`_, |
-| | `unzip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unzip>`_, |
-| | `grouper <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.grouper>`_, |
-| | `partition <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partition>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Lookahead and lookback | `spy <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.spy>`_, |
-| | `peekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.peekable>`_, |
-| | `seekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.seekable>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Windowing | `windowed <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed>`_, |
-| | `substrings <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings>`_, |
-| | `substrings_indexes <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings_indexes>`_, |
-| | `stagger <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.stagger>`_, |
-| | `windowed_complete <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed_complete>`_, |
-| | `pairwise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pairwise>`_, |
-| | `triplewise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.triplewise>`_, |
-| | `sliding_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliding_window>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Augmenting | `count_cycle <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.count_cycle>`_, |
-| | `intersperse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse>`_, |
-| | `padded <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.padded>`_, |
-| | `mark_ends <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.mark_ends>`_, |
-| | `repeat_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeat_last>`_, |
-| | `adjacent <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.adjacent>`_, |
-| | `groupby_transform <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.groupby_transform>`_, |
-| | `pad_none <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pad_none>`_, |
-| | `ncycles <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ncycles>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Combining | `collapse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.collapse>`_, |
-| | `sort_together <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sort_together>`_, |
-| | `interleave <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave>`_, |
-| | `interleave_longest <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_longest>`_, |
-| | `interleave_evenly <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_evenly>`_, |
-| | `zip_offset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_offset>`_, |
-| | `zip_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_equal>`_, |
-| | `zip_broadcast <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_broadcast>`_, |
-| | `dotproduct <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.dotproduct>`_, |
-| | `convolve <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.convolve>`_, |
-| | `flatten <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.flatten>`_, |
-| | `roundrobin <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.roundrobin>`_, |
-| | `prepend <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.prepend>`_, |
-| | `value_chain <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.value_chain>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Summarizing | `ilen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ilen>`_, |
-| | `unique_to_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_to_each>`_, |
-| | `sample <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sample>`_, |
-| | `consecutive_groups <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consecutive_groups>`_, |
-| | `run_length <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.run_length>`_, |
-| | `map_reduce <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_reduce>`_, |
-| | `exactly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.exactly_n>`_, |
-| | `is_sorted <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.is_sorted>`_, |
-| | `all_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_equal>`_, |
-| | `all_unique <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_unique>`_, |
-| | `minmax <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.minmax>`_, |
-| | `first_true <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first_true>`_, |
-| | `quantify <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.quantify>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Selecting | `islice_extended <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.islice_extended>`_, |
-| | `first <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first>`_, |
-| | `last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.last>`_, |
-| | `one <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.one>`_, |
-| | `only <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.only>`_, |
-| | `strictly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strictly_n>`_, |
-| | `strip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strip>`_, |
-| | `lstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.lstrip>`_, |
-| | `rstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rstrip>`_, |
-| | `filter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.filter_except>`_, |
-| | `map_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_except>`_, |
-| | `nth_or_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_or_last>`_, |
-| | `unique_in_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_in_window>`_, |
-| | `before_and_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.before_and_after>`_, |
-| | `nth <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth>`_, |
-| | `take <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.take>`_, |
-| | `tail <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tail>`_, |
-| | `unique_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertoo ls.unique_everseen>`_, |
-| | `unique_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_justseen>`_, |
-| | `duplicates_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_everseen>`_, |
-| | `duplicates_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_justseen>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Combinatorics | `distinct_permutations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_permutations>`_, |
-| | `distinct_combinations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_combinations>`_, |
-| | `circular_shifts <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.circular_shifts>`_, |
-| | `partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partitions>`_, |
-| | `set_partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.set_partitions>`_, |
-| | `product_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.product_index>`_, |
-| | `combination_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.combination_index>`_, |
-| | `permutation_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.permutation_index>`_, |
-| | `powerset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.powerset>`_, |
-| | `random_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_product>`_, |
-| | `random_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_permutation>`_, |
-| | `random_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination>`_, |
-| | `random_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination_with_replacement>`_, |
-| | `nth_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_product>`_, |
-| | `nth_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_permutation>`_, |
-| | `nth_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Wrapping | `always_iterable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_iterable>`_, |
-| | `always_reversible <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_reversible>`_, |
-| | `countable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.countable>`_, |
-| | `consumer <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consumer>`_, |
-| | `with_iter <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.with_iter>`_, |
-| | `iter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_except>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Others | `locate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.locate>`_, |
-| | `rlocate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rlocate>`_, |
-| | `replace <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.replace>`_, |
-| | `numeric_range <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.numeric_range>`_, |
-| | `side_effect <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.side_effect>`_, |
-| | `iterate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iterate>`_, |
-| | `difference <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.difference>`_, |
-| | `make_decorator <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.make_decorator>`_, |
-| | `SequenceView <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.SequenceView>`_, |
-| | `time_limited <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.time_limited>`_, |
-| | `consume <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consume>`_, |
-| | `tabulate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tabulate>`_, |
-| | `repeatfunc <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeatfunc>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-
-
-Getting started
-===============
-
-To get started, install the library with `pip <https://pip.pypa.io/en/stable/>`_:
-
-.. code-block:: shell
-
- pip install more-itertools
-
-The recipes from the `itertools docs <https://docs.python.org/3/library/itertools.html#itertools-recipes>`_
-are included in the top-level package:
-
-.. code-block:: python
-
- >>> from more_itertools import flatten
- >>> iterable = [(0, 1), (2, 3)]
- >>> list(flatten(iterable))
- [0, 1, 2, 3]
-
-Several new recipes are available as well:
-
-.. code-block:: python
-
- >>> from more_itertools import chunked
- >>> iterable = [0, 1, 2, 3, 4, 5, 6, 7, 8]
- >>> list(chunked(iterable, 3))
- [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
-
- >>> from more_itertools import spy
- >>> iterable = (x * x for x in range(1, 6))
- >>> head, iterable = spy(iterable, n=3)
- >>> list(head)
- [1, 4, 9]
- >>> list(iterable)
- [1, 4, 9, 16, 25]
-
-
-
-For the full listing of functions, see the `API documentation <https://more-itertools.readthedocs.io/en/stable/api.html>`_.
-
-
-Links elsewhere
-===============
-
-Blog posts about ``more-itertools``:
-
-* `Yo, I heard you like decorators <https://www.bbayles.com/index/decorator_factory>`__
-* `Tour of Python Itertools <https://martinheinz.dev/blog/16>`__ (`Alternate <https://dev.to/martinheinz/tour-of-python-itertools-4122>`__)
-* `Real-World Python More Itertools <https://www.gidware.com/real-world-more-itertools/>`_
-
-
-Development
-===========
-
-``more-itertools`` is maintained by `@erikrose <https://github.com/erikrose>`_
-and `@bbayles <https://github.com/bbayles>`_, with help from `many others <https://github.com/more-itertools/more-itertools/graphs/contributors>`_.
-If you have a problem or suggestion, please file a bug or pull request in this
-repository. Thanks for contributing!
-
-
-Version History
-===============
-
-
- :noindex:
-
-8.12.0
-------
-
-* Bug fixes
- * Some documentation issues were fixed (thanks to Masynchin, spookylukey, astrojuanlu, and stephengmatthews)
- * Python 3.5 support was temporarily restored (thanks to mattbonnell)
-
-8.11.0
-------
-
-* New functions
- * The before_and_after, sliding_window, and triplewise recipes from the Python 3.10 docs were added
- * duplicates_everseen and duplicates_justseen (thanks to OrBin and DavidPratt512)
- * minmax (thanks to Ricocotam, MSeifert04, and ruancomelli)
- * strictly_n (thanks to hwalinga and NotWearingPants)
- * unique_in_window
-
-* Changes to existing functions
- * groupby_transform had its type stub improved (thanks to mjk4 and ruancomelli)
- * is_sorted now accepts a ``strict`` parameter (thanks to Dutcho and ruancomelli)
- * zip_broadcast was updated to fix a bug (thanks to kalekundert)
-
-8.10.0
-------
-
-* Changes to existing functions
- * The type stub for iter_except was improved (thanks to MarcinKonowalczyk)
-
-* Other changes:
- * Type stubs now ship with the source release (thanks to saaketp)
- * The Sphinx docs were improved (thanks to MarcinKonowalczyk)
-
-8.9.0
------
-
-* New functions
- * interleave_evenly (thanks to mbugert)
- * repeat_each (thanks to FinalSh4re)
- * chunked_even (thanks to valtron)
- * map_if (thanks to sassbalint)
- * zip_broadcast (thanks to kalekundert)
-
-* Changes to existing functions
- * The type stub for chunked was improved (thanks to PhilMacKay)
- * The type stubs for zip_equal and `zip_offset` were improved (thanks to maffoo)
- * Building Sphinx docs locally was improved (thanks to MarcinKonowalczyk)
-
-8.8.0
------
-
-* New functions
- * countable (thanks to krzysieq)
-
-* Changes to existing functions
- * split_before was updated to handle empy collections (thanks to TiunovNN)
- * unique_everseen got a performance boost (thanks to Numerlor)
- * The type hint for value_chain was corrected (thanks to vr2262)
-
-8.7.0
------
-
-* New functions
- * convolve (from the Python itertools docs)
- * product_index, combination_index, and permutation_index (thanks to N8Brooks)
- * value_chain (thanks to jenstroeger)
-
-* Changes to existing functions
- * distinct_combinations now uses a non-recursive algorithm (thanks to knutdrand)
- * pad_none is now the preferred name for padnone, though the latter remains available.
- * pairwise will now use the Python standard library implementation on Python 3.10+
- * sort_together now accepts a ``key`` argument (thanks to brianmaissy)
- * seekable now has a ``peek`` method, and can indicate whether the iterator it's wrapping is exhausted (thanks to gsakkis)
- * time_limited can now indicate whether its iterator has expired (thanks to roysmith)
- * The implementation of unique_everseen was improved (thanks to plammens)
-
-* Other changes:
- * Various documentation updates (thanks to cthoyt, Evantm, and cyphase)
-
-8.6.0
------
-
-* New itertools
- * all_unique (thanks to brianmaissy)
- * nth_product and nth_permutation (thanks to N8Brooks)
-
-* Changes to existing itertools
- * chunked and sliced now accept a ``strict`` parameter (thanks to shlomif and jtwool)
-
-* Other changes
- * Python 3.5 has reached its end of life and is no longer supported.
- * Python 3.9 is officially supported.
- * Various documentation fixes (thanks to timgates42)
-
-8.5.0
------
-
-* New itertools
- * windowed_complete (thanks to MarcinKonowalczyk)
-
-* Changes to existing itertools:
- * The is_sorted implementation was improved (thanks to cool-RR)
- * The groupby_transform now accepts a ``reducefunc`` parameter.
- * The last implementation was improved (thanks to brianmaissy)
-
-* Other changes
- * Various documentation fixes (thanks to craigrosie, samuelstjean, PiCT0)
- * The tests for distinct_combinations were improved (thanks to Minabsapi)
- * Automated tests now run on GitHub Actions. All commits now check:
- * That unit tests pass
- * That the examples in docstrings work
- * That test coverage remains high (using `coverage`)
- * For linting errors (using `flake8`)
- * For consistent style (using `black`)
- * That the type stubs work (using `mypy`)
- * That the docs build correctly (using `sphinx`)
- * That packages build correctly (using `twine`)
-
-8.4.0
------
-
-* New itertools
- * mark_ends (thanks to kalekundert)
- * is_sorted
-
-* Changes to existing itertools:
- * islice_extended can now be used with real slices (thanks to cool-RR)
- * The implementations for filter_except and map_except were improved (thanks to SergBobrovsky)
-
-* Other changes
- * Automated tests now enforce code style (using `black <https://github.com/psf/black>`__)
- * The various signatures of islice_extended and numeric_range now appear in the docs (thanks to dsfulf)
- * The test configuration for mypy was updated (thanks to blueyed)
-
-
-8.3.0
------
-
-* New itertools
- * zip_equal (thanks to frankier and alexmojaki)
-
-* Changes to existing itertools:
- * split_at, split_before, split_after, and split_when all got a ``maxsplit`` paramter (thanks to jferard and ilai-deutel)
- * split_at now accepts a ``keep_separator`` parameter (thanks to jferard)
- * distinct_permutations can now generate ``r``-length permutations (thanks to SergBobrovsky and ilai-deutel)
- * The windowed implementation was improved (thanks to SergBobrovsky)
- * The spy implementation was improved (thanks to has2k1)
-
-* Other changes
- * Type stubs are now tested with ``stubtest`` (thanks to ilai-deutel)
- * Tests now run with ``python -m unittest`` instead of ``python setup.py test`` (thanks to jdufresne)
-
-8.2.0
------
-
-* Bug fixes
- * The .pyi files for typing were updated. (thanks to blueyed and ilai-deutel)
-
-* Changes to existing itertools:
- * numeric_range now behaves more like the built-in range. (thanks to jferard)
- * bucket now allows for enumerating keys. (thanks to alexchandel)
- * sliced now should now work for numpy arrays. (thanks to sswingle)
- * seekable now has a ``maxlen`` parameter.
-
-8.1.0
------
-
-* Bug fixes
- * partition works with ``pred=None`` again. (thanks to MSeifert04)
-
-* New itertools
- * sample (thanks to tommyod)
- * nth_or_last (thanks to d-ryzhikov)
-
-* Changes to existing itertools:
- * The implementation for divide was improved. (thanks to jferard)
-
-8.0.2
------
-
-* Bug fixes
- * The type stub files are now part of the wheel distribution (thanks to keisheiled)
-
-8.0.1
------
-
-* Bug fixes
- * The type stub files now work for functions imported from the
- root package (thanks to keisheiled)
-
-8.0.0
------
-
-* New itertools and other additions
- * This library now ships type hints for use with mypy.
- (thanks to ilai-deutel for the implementation, and to gabbard and fmagin for assistance)
- * split_when (thanks to jferard)
- * repeat_last (thanks to d-ryzhikov)
-
-* Changes to existing itertools:
- * The implementation for set_partitions was improved. (thanks to jferard)
- * partition was optimized for expensive predicates. (thanks to stevecj)
- * unique_everseen and groupby_transform were re-factored. (thanks to SergBobrovsky)
- * The implementation for difference was improved. (thanks to Jabbey92)
-
-* Other changes
- * Python 3.4 has reached its end of life and is no longer supported.
- * Python 3.8 is officially supported. (thanks to jdufresne)
- * The ``collate`` function has been deprecated.
- It raises a ``DeprecationWarning`` if used, and will be removed in a future release.
- * one and only now provide more informative error messages. (thanks to gabbard)
- * Unit tests were moved outside of the main package (thanks to jdufresne)
- * Various documentation fixes (thanks to kriomant, gabbard, jdufresne)
-
-
-7.2.0
------
-
-* New itertools
- * distinct_combinations
- * set_partitions (thanks to kbarrett)
- * filter_except
- * map_except
-
-7.1.0
------
-
-* New itertools
- * ichunked (thanks davebelais and youtux)
- * only (thanks jaraco)
-
-* Changes to existing itertools:
- * numeric_range now supports ranges specified by
- ``datetime.datetime`` and ``datetime.timedelta`` objects (thanks to MSeifert04 for tests).
- * difference now supports an *initial* keyword argument.
-
-
-* Other changes
- * Various documentation fixes (thanks raimon49, pylang)
-
-7.0.0
------
-
-* New itertools:
- * time_limited
- * partitions (thanks to rominf and Saluev)
- * substrings_indexes (thanks to rominf)
-
-* Changes to existing itertools:
- * collapse now treats ``bytes`` objects the same as ``str`` objects. (thanks to Sweenpet)
-
-The major version update is due to the change in the default behavior of
-collapse. It now treats ``bytes`` objects the same as ``str`` objects.
-This aligns its behavior with always_iterable.
-
-.. code-block:: python
-
- >>> from more_itertools import collapse
- >>> iterable = [[1, 2], b'345', [6]]
- >>> print(list(collapse(iterable)))
- [1, 2, b'345', 6]
-
-6.0.0
------
-
-* Major changes:
- * Python 2.7 is no longer supported. The 5.0.0 release will be the last
- version targeting Python 2.7.
- * All future releases will target the active versions of Python 3.
- As of 2019, those are Python 3.4 and above.
- * The ``six`` library is no longer a dependency.
- * The accumulate function is no longer part of this library. You
- may import a better version from the standard ``itertools`` module.
-
-* Changes to existing itertools:
- * The order of the parameters in grouper have changed to match
- the latest recipe in the itertools documentation. Use of the old order
- will be supported in this release, but emit a ``DeprecationWarning``.
- The legacy behavior will be dropped in a future release. (thanks to jaraco)
- * distinct_permutations was improved (thanks to jferard - see also `permutations with unique values <https://stackoverflow.com/questions/6284396/permutations-with-unique-values>`_ at StackOverflow.)
- * An unused parameter was removed from substrings. (thanks to pylang)
-
-* Other changes:
- * The docs for unique_everseen were improved. (thanks to jferard and MSeifert04)
- * Several Python 2-isms were removed. (thanks to jaraco, MSeifert04, and hugovk)
-
-
diff --git a/contrib/python/more-itertools/py3/.dist-info/top_level.txt b/contrib/python/more-itertools/py3/.dist-info/top_level.txt
deleted file mode 100644
index a5035befb3..0000000000
--- a/contrib/python/more-itertools/py3/.dist-info/top_level.txt
+++ /dev/null
@@ -1 +0,0 @@
-more_itertools
diff --git a/contrib/python/more-itertools/py3/LICENSE b/contrib/python/more-itertools/py3/LICENSE
deleted file mode 100644
index 0a523bece3..0000000000
--- a/contrib/python/more-itertools/py3/LICENSE
+++ /dev/null
@@ -1,19 +0,0 @@
-Copyright (c) 2012 Erik Rose
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of
-this software and associated documentation files (the "Software"), to deal in
-the Software without restriction, including without limitation the rights to
-use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
-of the Software, and to permit persons to whom the Software is furnished to do
-so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/contrib/python/more-itertools/py3/README.rst b/contrib/python/more-itertools/py3/README.rst
deleted file mode 100644
index 4df22091a4..0000000000
--- a/contrib/python/more-itertools/py3/README.rst
+++ /dev/null
@@ -1,200 +0,0 @@
-==============
-More Itertools
-==============
-
-.. image:: https://readthedocs.org/projects/more-itertools/badge/?version=latest
- :target: https://more-itertools.readthedocs.io/en/stable/
-
-Python's ``itertools`` library is a gem - you can compose elegant solutions
-for a variety of problems with the functions it provides. In ``more-itertools``
-we collect additional building blocks, recipes, and routines for working with
-Python iterables.
-
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Grouping | `chunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked>`_, |
-| | `ichunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ichunked>`_, |
-| | `sliced <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliced>`_, |
-| | `distribute <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute>`_, |
-| | `divide <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.divide>`_, |
-| | `split_at <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_at>`_, |
-| | `split_before <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_before>`_, |
-| | `split_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_after>`_, |
-| | `split_into <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_into>`_, |
-| | `split_when <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_when>`_, |
-| | `bucket <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.bucket>`_, |
-| | `unzip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unzip>`_, |
-| | `grouper <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.grouper>`_, |
-| | `partition <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partition>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Lookahead and lookback | `spy <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.spy>`_, |
-| | `peekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.peekable>`_, |
-| | `seekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.seekable>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Windowing | `windowed <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed>`_, |
-| | `substrings <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings>`_, |
-| | `substrings_indexes <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings_indexes>`_, |
-| | `stagger <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.stagger>`_, |
-| | `windowed_complete <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed_complete>`_, |
-| | `pairwise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pairwise>`_, |
-| | `triplewise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.triplewise>`_, |
-| | `sliding_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliding_window>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Augmenting | `count_cycle <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.count_cycle>`_, |
-| | `intersperse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse>`_, |
-| | `padded <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.padded>`_, |
-| | `mark_ends <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.mark_ends>`_, |
-| | `repeat_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeat_last>`_, |
-| | `adjacent <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.adjacent>`_, |
-| | `groupby_transform <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.groupby_transform>`_, |
-| | `pad_none <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pad_none>`_, |
-| | `ncycles <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ncycles>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Combining | `collapse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.collapse>`_, |
-| | `sort_together <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sort_together>`_, |
-| | `interleave <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave>`_, |
-| | `interleave_longest <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_longest>`_, |
-| | `interleave_evenly <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_evenly>`_, |
-| | `zip_offset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_offset>`_, |
-| | `zip_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_equal>`_, |
-| | `zip_broadcast <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_broadcast>`_, |
-| | `dotproduct <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.dotproduct>`_, |
-| | `convolve <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.convolve>`_, |
-| | `flatten <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.flatten>`_, |
-| | `roundrobin <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.roundrobin>`_, |
-| | `prepend <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.prepend>`_, |
-| | `value_chain <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.value_chain>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Summarizing | `ilen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ilen>`_, |
-| | `unique_to_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_to_each>`_, |
-| | `sample <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sample>`_, |
-| | `consecutive_groups <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consecutive_groups>`_, |
-| | `run_length <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.run_length>`_, |
-| | `map_reduce <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_reduce>`_, |
-| | `exactly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.exactly_n>`_, |
-| | `is_sorted <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.is_sorted>`_, |
-| | `all_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_equal>`_, |
-| | `all_unique <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_unique>`_, |
-| | `minmax <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.minmax>`_, |
-| | `first_true <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first_true>`_, |
-| | `quantify <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.quantify>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Selecting | `islice_extended <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.islice_extended>`_, |
-| | `first <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first>`_, |
-| | `last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.last>`_, |
-| | `one <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.one>`_, |
-| | `only <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.only>`_, |
-| | `strictly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strictly_n>`_, |
-| | `strip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strip>`_, |
-| | `lstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.lstrip>`_, |
-| | `rstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rstrip>`_, |
-| | `filter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.filter_except>`_, |
-| | `map_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_except>`_, |
-| | `nth_or_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_or_last>`_, |
-| | `unique_in_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_in_window>`_, |
-| | `before_and_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.before_and_after>`_, |
-| | `nth <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth>`_, |
-| | `take <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.take>`_, |
-| | `tail <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tail>`_, |
-| | `unique_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertoo ls.unique_everseen>`_, |
-| | `unique_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_justseen>`_, |
-| | `duplicates_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_everseen>`_, |
-| | `duplicates_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_justseen>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Combinatorics | `distinct_permutations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_permutations>`_, |
-| | `distinct_combinations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_combinations>`_, |
-| | `circular_shifts <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.circular_shifts>`_, |
-| | `partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partitions>`_, |
-| | `set_partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.set_partitions>`_, |
-| | `product_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.product_index>`_, |
-| | `combination_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.combination_index>`_, |
-| | `permutation_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.permutation_index>`_, |
-| | `powerset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.powerset>`_, |
-| | `random_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_product>`_, |
-| | `random_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_permutation>`_, |
-| | `random_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination>`_, |
-| | `random_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination_with_replacement>`_, |
-| | `nth_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_product>`_, |
-| | `nth_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_permutation>`_, |
-| | `nth_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Wrapping | `always_iterable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_iterable>`_, |
-| | `always_reversible <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_reversible>`_, |
-| | `countable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.countable>`_, |
-| | `consumer <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consumer>`_, |
-| | `with_iter <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.with_iter>`_, |
-| | `iter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_except>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| Others | `locate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.locate>`_, |
-| | `rlocate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rlocate>`_, |
-| | `replace <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.replace>`_, |
-| | `numeric_range <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.numeric_range>`_, |
-| | `side_effect <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.side_effect>`_, |
-| | `iterate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iterate>`_, |
-| | `difference <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.difference>`_, |
-| | `make_decorator <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.make_decorator>`_, |
-| | `SequenceView <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.SequenceView>`_, |
-| | `time_limited <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.time_limited>`_, |
-| | `consume <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consume>`_, |
-| | `tabulate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tabulate>`_, |
-| | `repeatfunc <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeatfunc>`_ |
-+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-
-
-Getting started
-===============
-
-To get started, install the library with `pip <https://pip.pypa.io/en/stable/>`_:
-
-.. code-block:: shell
-
- pip install more-itertools
-
-The recipes from the `itertools docs <https://docs.python.org/3/library/itertools.html#itertools-recipes>`_
-are included in the top-level package:
-
-.. code-block:: python
-
- >>> from more_itertools import flatten
- >>> iterable = [(0, 1), (2, 3)]
- >>> list(flatten(iterable))
- [0, 1, 2, 3]
-
-Several new recipes are available as well:
-
-.. code-block:: python
-
- >>> from more_itertools import chunked
- >>> iterable = [0, 1, 2, 3, 4, 5, 6, 7, 8]
- >>> list(chunked(iterable, 3))
- [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
-
- >>> from more_itertools import spy
- >>> iterable = (x * x for x in range(1, 6))
- >>> head, iterable = spy(iterable, n=3)
- >>> list(head)
- [1, 4, 9]
- >>> list(iterable)
- [1, 4, 9, 16, 25]
-
-
-
-For the full listing of functions, see the `API documentation <https://more-itertools.readthedocs.io/en/stable/api.html>`_.
-
-
-Links elsewhere
-===============
-
-Blog posts about ``more-itertools``:
-
-* `Yo, I heard you like decorators <https://www.bbayles.com/index/decorator_factory>`__
-* `Tour of Python Itertools <https://martinheinz.dev/blog/16>`__ (`Alternate <https://dev.to/martinheinz/tour-of-python-itertools-4122>`__)
-* `Real-World Python More Itertools <https://www.gidware.com/real-world-more-itertools/>`_
-
-
-Development
-===========
-
-``more-itertools`` is maintained by `@erikrose <https://github.com/erikrose>`_
-and `@bbayles <https://github.com/bbayles>`_, with help from `many others <https://github.com/more-itertools/more-itertools/graphs/contributors>`_.
-If you have a problem or suggestion, please file a bug or pull request in this
-repository. Thanks for contributing!
diff --git a/contrib/python/more-itertools/py3/more_itertools/__init__.py b/contrib/python/more-itertools/py3/more_itertools/__init__.py
deleted file mode 100644
index ea38bef1f6..0000000000
--- a/contrib/python/more-itertools/py3/more_itertools/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .more import * # noqa
-from .recipes import * # noqa
-
-__version__ = '8.12.0'
diff --git a/contrib/python/more-itertools/py3/more_itertools/__init__.pyi b/contrib/python/more-itertools/py3/more_itertools/__init__.pyi
deleted file mode 100644
index 96f6e36c7f..0000000000
--- a/contrib/python/more-itertools/py3/more_itertools/__init__.pyi
+++ /dev/null
@@ -1,2 +0,0 @@
-from .more import *
-from .recipes import *
diff --git a/contrib/python/more-itertools/py3/more_itertools/more.py b/contrib/python/more-itertools/py3/more_itertools/more.py
deleted file mode 100644
index 630af973f2..0000000000
--- a/contrib/python/more-itertools/py3/more_itertools/more.py
+++ /dev/null
@@ -1,4317 +0,0 @@
-import warnings
-
-from collections import Counter, defaultdict, deque, abc
-from collections.abc import Sequence
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial, reduce, wraps
-from heapq import merge, heapify, heapreplace, heappop
-from itertools import (
- chain,
- compress,
- count,
- cycle,
- dropwhile,
- groupby,
- islice,
- repeat,
- starmap,
- takewhile,
- tee,
- zip_longest,
-)
-from math import exp, factorial, floor, log
-from queue import Empty, Queue
-from random import random, randrange, uniform
-from operator import itemgetter, mul, sub, gt, lt, ge, le
-from sys import hexversion, maxsize
-from time import monotonic
-
-from .recipes import (
- consume,
- flatten,
- pairwise,
- powerset,
- take,
- unique_everseen,
-)
-
-__all__ = [
- 'AbortThread',
- 'SequenceView',
- 'UnequalIterablesError',
- 'adjacent',
- 'all_unique',
- 'always_iterable',
- 'always_reversible',
- 'bucket',
- 'callback_iter',
- 'chunked',
- 'chunked_even',
- 'circular_shifts',
- 'collapse',
- 'collate',
- 'combination_index',
- 'consecutive_groups',
- 'consumer',
- 'count_cycle',
- 'countable',
- 'difference',
- 'distinct_combinations',
- 'distinct_permutations',
- 'distribute',
- 'divide',
- 'duplicates_everseen',
- 'duplicates_justseen',
- 'exactly_n',
- 'filter_except',
- 'first',
- 'groupby_transform',
- 'ichunked',
- 'ilen',
- 'interleave',
- 'interleave_evenly',
- 'interleave_longest',
- 'intersperse',
- 'is_sorted',
- 'islice_extended',
- 'iterate',
- 'last',
- 'locate',
- 'lstrip',
- 'make_decorator',
- 'map_except',
- 'map_if',
- 'map_reduce',
- 'mark_ends',
- 'minmax',
- 'nth_or_last',
- 'nth_permutation',
- 'nth_product',
- 'numeric_range',
- 'one',
- 'only',
- 'padded',
- 'partitions',
- 'peekable',
- 'permutation_index',
- 'product_index',
- 'raise_',
- 'repeat_each',
- 'repeat_last',
- 'replace',
- 'rlocate',
- 'rstrip',
- 'run_length',
- 'sample',
- 'seekable',
- 'set_partitions',
- 'side_effect',
- 'sliced',
- 'sort_together',
- 'split_after',
- 'split_at',
- 'split_before',
- 'split_into',
- 'split_when',
- 'spy',
- 'stagger',
- 'strip',
- 'strictly_n',
- 'substrings',
- 'substrings_indexes',
- 'time_limited',
- 'unique_in_window',
- 'unique_to_each',
- 'unzip',
- 'value_chain',
- 'windowed',
- 'windowed_complete',
- 'with_iter',
- 'zip_broadcast',
- 'zip_equal',
- 'zip_offset',
-]
-
-
-_marker = object()
-
-
-def chunked(iterable, n, strict=False):
- """Break *iterable* into lists of length *n*:
-
- >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
- [[1, 2, 3], [4, 5, 6]]
-
- By the default, the last yielded list will have fewer than *n* elements
- if the length of *iterable* is not divisible by *n*:
-
- >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
- [[1, 2, 3], [4, 5, 6], [7, 8]]
-
- To use a fill-in value instead, see the :func:`grouper` recipe.
-
- If the length of *iterable* is not divisible by *n* and *strict* is
- ``True``, then ``ValueError`` will be raised before the last
- list is yielded.
-
- """
- iterator = iter(partial(take, n, iter(iterable)), [])
- if strict:
- if n is None:
- raise ValueError('n must not be None when using strict mode.')
-
- def ret():
- for chunk in iterator:
- if len(chunk) != n:
- raise ValueError('iterable is not divisible by n.')
- yield chunk
-
- return iter(ret())
- else:
- return iterator
-
-
-def first(iterable, default=_marker):
- """Return the first item of *iterable*, or *default* if *iterable* is
- empty.
-
- >>> first([0, 1, 2, 3])
- 0
- >>> first([], 'some default')
- 'some default'
-
- If *default* is not provided and there are no items in the iterable,
- raise ``ValueError``.
-
- :func:`first` is useful when you have a generator of expensive-to-retrieve
- values and want any arbitrary one. It is marginally shorter than
- ``next(iter(iterable), default)``.
-
- """
- try:
- return next(iter(iterable))
- except StopIteration as e:
- if default is _marker:
- raise ValueError(
- 'first() was called on an empty iterable, and no '
- 'default value was provided.'
- ) from e
- return default
-
-
-def last(iterable, default=_marker):
- """Return the last item of *iterable*, or *default* if *iterable* is
- empty.
-
- >>> last([0, 1, 2, 3])
- 3
- >>> last([], 'some default')
- 'some default'
-
- If *default* is not provided and there are no items in the iterable,
- raise ``ValueError``.
- """
- try:
- if isinstance(iterable, Sequence):
- return iterable[-1]
- # Work around https://bugs.python.org/issue38525
- elif hasattr(iterable, '__reversed__') and (hexversion != 0x030800F0):
- return next(reversed(iterable))
- else:
- return deque(iterable, maxlen=1)[-1]
- except (IndexError, TypeError, StopIteration):
- if default is _marker:
- raise ValueError(
- 'last() was called on an empty iterable, and no default was '
- 'provided.'
- )
- return default
-
-
-def nth_or_last(iterable, n, default=_marker):
- """Return the nth or the last item of *iterable*,
- or *default* if *iterable* is empty.
-
- >>> nth_or_last([0, 1, 2, 3], 2)
- 2
- >>> nth_or_last([0, 1], 2)
- 1
- >>> nth_or_last([], 0, 'some default')
- 'some default'
-
- If *default* is not provided and there are no items in the iterable,
- raise ``ValueError``.
- """
- return last(islice(iterable, n + 1), default=default)
-
-
-class peekable:
- """Wrap an iterator to allow lookahead and prepending elements.
-
- Call :meth:`peek` on the result to get the value that will be returned
- by :func:`next`. This won't advance the iterator:
-
- >>> p = peekable(['a', 'b'])
- >>> p.peek()
- 'a'
- >>> next(p)
- 'a'
-
- Pass :meth:`peek` a default value to return that instead of raising
- ``StopIteration`` when the iterator is exhausted.
-
- >>> p = peekable([])
- >>> p.peek('hi')
- 'hi'
-
- peekables also offer a :meth:`prepend` method, which "inserts" items
- at the head of the iterable:
-
- >>> p = peekable([1, 2, 3])
- >>> p.prepend(10, 11, 12)
- >>> next(p)
- 10
- >>> p.peek()
- 11
- >>> list(p)
- [11, 12, 1, 2, 3]
-
- peekables can be indexed. Index 0 is the item that will be returned by
- :func:`next`, index 1 is the item after that, and so on:
- The values up to the given index will be cached.
-
- >>> p = peekable(['a', 'b', 'c', 'd'])
- >>> p[0]
- 'a'
- >>> p[1]
- 'b'
- >>> next(p)
- 'a'
-
- Negative indexes are supported, but be aware that they will cache the
- remaining items in the source iterator, which may require significant
- storage.
-
- To check whether a peekable is exhausted, check its truth value:
-
- >>> p = peekable(['a', 'b'])
- >>> if p: # peekable has items
- ... list(p)
- ['a', 'b']
- >>> if not p: # peekable is exhausted
- ... list(p)
- []
-
- """
-
- def __init__(self, iterable):
- self._it = iter(iterable)
- self._cache = deque()
-
- def __iter__(self):
- return self
-
- def __bool__(self):
- try:
- self.peek()
- except StopIteration:
- return False
- return True
-
- def peek(self, default=_marker):
- """Return the item that will be next returned from ``next()``.
-
- Return ``default`` if there are no items left. If ``default`` is not
- provided, raise ``StopIteration``.
-
- """
- if not self._cache:
- try:
- self._cache.append(next(self._it))
- except StopIteration:
- if default is _marker:
- raise
- return default
- return self._cache[0]
-
- def prepend(self, *items):
- """Stack up items to be the next ones returned from ``next()`` or
- ``self.peek()``. The items will be returned in
- first in, first out order::
-
- >>> p = peekable([1, 2, 3])
- >>> p.prepend(10, 11, 12)
- >>> next(p)
- 10
- >>> list(p)
- [11, 12, 1, 2, 3]
-
- It is possible, by prepending items, to "resurrect" a peekable that
- previously raised ``StopIteration``.
-
- >>> p = peekable([])
- >>> next(p)
- Traceback (most recent call last):
- ...
- StopIteration
- >>> p.prepend(1)
- >>> next(p)
- 1
- >>> next(p)
- Traceback (most recent call last):
- ...
- StopIteration
-
- """
- self._cache.extendleft(reversed(items))
-
- def __next__(self):
- if self._cache:
- return self._cache.popleft()
-
- return next(self._it)
-
- def _get_slice(self, index):
- # Normalize the slice's arguments
- step = 1 if (index.step is None) else index.step
- if step > 0:
- start = 0 if (index.start is None) else index.start
- stop = maxsize if (index.stop is None) else index.stop
- elif step < 0:
- start = -1 if (index.start is None) else index.start
- stop = (-maxsize - 1) if (index.stop is None) else index.stop
- else:
- raise ValueError('slice step cannot be zero')
-
- # If either the start or stop index is negative, we'll need to cache
- # the rest of the iterable in order to slice from the right side.
- if (start < 0) or (stop < 0):
- self._cache.extend(self._it)
- # Otherwise we'll need to find the rightmost index and cache to that
- # point.
- else:
- n = min(max(start, stop) + 1, maxsize)
- cache_len = len(self._cache)
- if n >= cache_len:
- self._cache.extend(islice(self._it, n - cache_len))
-
- return list(self._cache)[index]
-
- def __getitem__(self, index):
- if isinstance(index, slice):
- return self._get_slice(index)
-
- cache_len = len(self._cache)
- if index < 0:
- self._cache.extend(self._it)
- elif index >= cache_len:
- self._cache.extend(islice(self._it, index + 1 - cache_len))
-
- return self._cache[index]
-
-
-def collate(*iterables, **kwargs):
- """Return a sorted merge of the items from each of several already-sorted
- *iterables*.
-
- >>> list(collate('ACDZ', 'AZ', 'JKL'))
- ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z']
-
- Works lazily, keeping only the next value from each iterable in memory. Use
- :func:`collate` to, for example, perform a n-way mergesort of items that
- don't fit in memory.
-
- If a *key* function is specified, the iterables will be sorted according
- to its result:
-
- >>> key = lambda s: int(s) # Sort by numeric value, not by string
- >>> list(collate(['1', '10'], ['2', '11'], key=key))
- ['1', '2', '10', '11']
-
-
- If the *iterables* are sorted in descending order, set *reverse* to
- ``True``:
-
- >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True))
- [5, 4, 3, 2, 1, 0]
-
- If the elements of the passed-in iterables are out of order, you might get
- unexpected results.
-
- On Python 3.5+, this function is an alias for :func:`heapq.merge`.
-
- """
- warnings.warn(
- "collate is no longer part of more_itertools, use heapq.merge",
- DeprecationWarning,
- )
- return merge(*iterables, **kwargs)
-
-
-def consumer(func):
- """Decorator that automatically advances a PEP-342-style "reverse iterator"
- to its first yield point so you don't have to call ``next()`` on it
- manually.
-
- >>> @consumer
- ... def tally():
- ... i = 0
- ... while True:
- ... print('Thing number %s is %s.' % (i, (yield)))
- ... i += 1
- ...
- >>> t = tally()
- >>> t.send('red')
- Thing number 0 is red.
- >>> t.send('fish')
- Thing number 1 is fish.
-
- Without the decorator, you would have to call ``next(t)`` before
- ``t.send()`` could be used.
-
- """
-
- @wraps(func)
- def wrapper(*args, **kwargs):
- gen = func(*args, **kwargs)
- next(gen)
- return gen
-
- return wrapper
-
-
-def ilen(iterable):
- """Return the number of items in *iterable*.
-
- >>> ilen(x for x in range(1000000) if x % 3 == 0)
- 333334
-
- This consumes the iterable, so handle with care.
-
- """
- # This approach was selected because benchmarks showed it's likely the
- # fastest of the known implementations at the time of writing.
- # See GitHub tracker: #236, #230.
- counter = count()
- deque(zip(iterable, counter), maxlen=0)
- return next(counter)
-
-
-def iterate(func, start):
- """Return ``start``, ``func(start)``, ``func(func(start))``, ...
-
- >>> from itertools import islice
- >>> list(islice(iterate(lambda x: 2*x, 1), 10))
- [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
-
- """
- while True:
- yield start
- start = func(start)
-
-
-def with_iter(context_manager):
- """Wrap an iterable in a ``with`` statement, so it closes once exhausted.
-
- For example, this will close the file when the iterator is exhausted::
-
- upper_lines = (line.upper() for line in with_iter(open('foo')))
-
- Any context manager which returns an iterable is a candidate for
- ``with_iter``.
-
- """
- with context_manager as iterable:
- yield from iterable
-
-
-def one(iterable, too_short=None, too_long=None):
- """Return the first item from *iterable*, which is expected to contain only
- that item. Raise an exception if *iterable* is empty or has more than one
- item.
-
- :func:`one` is useful for ensuring that an iterable contains only one item.
- For example, it can be used to retrieve the result of a database query
- that is expected to return a single row.
-
- If *iterable* is empty, ``ValueError`` will be raised. You may specify a
- different exception with the *too_short* keyword:
-
- >>> it = []
- >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError: too many items in iterable (expected 1)'
- >>> too_short = IndexError('too few items')
- >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- IndexError: too few items
-
- Similarly, if *iterable* contains more than one item, ``ValueError`` will
- be raised. You may specify a different exception with the *too_long*
- keyword:
-
- >>> it = ['too', 'many']
- >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError: Expected exactly one item in iterable, but got 'too',
- 'many', and perhaps more.
- >>> too_long = RuntimeError
- >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- RuntimeError
-
- Note that :func:`one` attempts to advance *iterable* twice to ensure there
- is only one item. See :func:`spy` or :func:`peekable` to check iterable
- contents less destructively.
-
- """
- it = iter(iterable)
-
- try:
- first_value = next(it)
- except StopIteration as e:
- raise (
- too_short or ValueError('too few items in iterable (expected 1)')
- ) from e
-
- try:
- second_value = next(it)
- except StopIteration:
- pass
- else:
- msg = (
- 'Expected exactly one item in iterable, but got {!r}, {!r}, '
- 'and perhaps more.'.format(first_value, second_value)
- )
- raise too_long or ValueError(msg)
-
- return first_value
-
-
-def raise_(exception, *args):
- raise exception(*args)
-
-
-def strictly_n(iterable, n, too_short=None, too_long=None):
- """Validate that *iterable* has exactly *n* items and return them if
- it does. If it has fewer than *n* items, call function *too_short*
- with those items. If it has more than *n* items, call function
- *too_long* with the first ``n + 1`` items.
-
- >>> iterable = ['a', 'b', 'c', 'd']
- >>> n = 4
- >>> list(strictly_n(iterable, n))
- ['a', 'b', 'c', 'd']
-
- By default, *too_short* and *too_long* are functions that raise
- ``ValueError``.
-
- >>> list(strictly_n('ab', 3)) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError: too few items in iterable (got 2)
-
- >>> list(strictly_n('abc', 2)) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError: too many items in iterable (got at least 3)
-
- You can instead supply functions that do something else.
- *too_short* will be called with the number of items in *iterable*.
- *too_long* will be called with `n + 1`.
-
- >>> def too_short(item_count):
- ... raise RuntimeError
- >>> it = strictly_n('abcd', 6, too_short=too_short)
- >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- RuntimeError
-
- >>> def too_long(item_count):
- ... print('The boss is going to hear about this')
- >>> it = strictly_n('abcdef', 4, too_long=too_long)
- >>> list(it)
- The boss is going to hear about this
- ['a', 'b', 'c', 'd']
-
- """
- if too_short is None:
- too_short = lambda item_count: raise_(
- ValueError,
- 'Too few items in iterable (got {})'.format(item_count),
- )
-
- if too_long is None:
- too_long = lambda item_count: raise_(
- ValueError,
- 'Too many items in iterable (got at least {})'.format(item_count),
- )
-
- it = iter(iterable)
- for i in range(n):
- try:
- item = next(it)
- except StopIteration:
- too_short(i)
- return
- else:
- yield item
-
- try:
- next(it)
- except StopIteration:
- pass
- else:
- too_long(n + 1)
-
-
-def distinct_permutations(iterable, r=None):
- """Yield successive distinct permutations of the elements in *iterable*.
-
- >>> sorted(distinct_permutations([1, 0, 1]))
- [(0, 1, 1), (1, 0, 1), (1, 1, 0)]
-
- Equivalent to ``set(permutations(iterable))``, except duplicates are not
- generated and thrown away. For larger input sequences this is much more
- efficient.
-
- Duplicate permutations arise when there are duplicated elements in the
- input iterable. The number of items returned is
- `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
- items input, and each `x_i` is the count of a distinct item in the input
- sequence.
-
- If *r* is given, only the *r*-length permutations are yielded.
-
- >>> sorted(distinct_permutations([1, 0, 1], r=2))
- [(0, 1), (1, 0), (1, 1)]
- >>> sorted(distinct_permutations(range(3), r=2))
- [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
-
- """
- # Algorithm: https://w.wiki/Qai
- def _full(A):
- while True:
- # Yield the permutation we have
- yield tuple(A)
-
- # Find the largest index i such that A[i] < A[i + 1]
- for i in range(size - 2, -1, -1):
- if A[i] < A[i + 1]:
- break
- # If no such index exists, this permutation is the last one
- else:
- return
-
- # Find the largest index j greater than j such that A[i] < A[j]
- for j in range(size - 1, i, -1):
- if A[i] < A[j]:
- break
-
- # Swap the value of A[i] with that of A[j], then reverse the
- # sequence from A[i + 1] to form the new permutation
- A[i], A[j] = A[j], A[i]
- A[i + 1 :] = A[: i - size : -1] # A[i + 1:][::-1]
-
- # Algorithm: modified from the above
- def _partial(A, r):
- # Split A into the first r items and the last r items
- head, tail = A[:r], A[r:]
- right_head_indexes = range(r - 1, -1, -1)
- left_tail_indexes = range(len(tail))
-
- while True:
- # Yield the permutation we have
- yield tuple(head)
-
- # Starting from the right, find the first index of the head with
- # value smaller than the maximum value of the tail - call it i.
- pivot = tail[-1]
- for i in right_head_indexes:
- if head[i] < pivot:
- break
- pivot = head[i]
- else:
- return
-
- # Starting from the left, find the first value of the tail
- # with a value greater than head[i] and swap.
- for j in left_tail_indexes:
- if tail[j] > head[i]:
- head[i], tail[j] = tail[j], head[i]
- break
- # If we didn't find one, start from the right and find the first
- # index of the head with a value greater than head[i] and swap.
- else:
- for j in right_head_indexes:
- if head[j] > head[i]:
- head[i], head[j] = head[j], head[i]
- break
-
- # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
- tail += head[: i - r : -1] # head[i + 1:][::-1]
- i += 1
- head[i:], tail[:] = tail[: r - i], tail[r - i :]
-
- items = sorted(iterable)
-
- size = len(items)
- if r is None:
- r = size
-
- if 0 < r <= size:
- return _full(items) if (r == size) else _partial(items, r)
-
- return iter(() if r else ((),))
-
-
-def intersperse(e, iterable, n=1):
- """Intersperse filler element *e* among the items in *iterable*, leaving
- *n* items between each filler element.
-
- >>> list(intersperse('!', [1, 2, 3, 4, 5]))
- [1, '!', 2, '!', 3, '!', 4, '!', 5]
-
- >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
- [1, 2, None, 3, 4, None, 5]
-
- """
- if n == 0:
- raise ValueError('n must be > 0')
- elif n == 1:
- # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2...
- # islice(..., 1, None) -> x_0, e, x_1, e, x_2...
- return islice(interleave(repeat(e), iterable), 1, None)
- else:
- # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]...
- # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
- # flatten(...) -> x_0, x_1, e, x_2, x_3...
- filler = repeat([e])
- chunks = chunked(iterable, n)
- return flatten(islice(interleave(filler, chunks), 1, None))
-
-
-def unique_to_each(*iterables):
- """Return the elements from each of the input iterables that aren't in the
- other input iterables.
-
- For example, suppose you have a set of packages, each with a set of
- dependencies::
-
- {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}
-
- If you remove one package, which dependencies can also be removed?
-
- If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
- associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for
- ``pkg_2``, and ``D`` is only needed for ``pkg_3``::
-
- >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
- [['A'], ['C'], ['D']]
-
- If there are duplicates in one input iterable that aren't in the others
- they will be duplicated in the output. Input order is preserved::
-
- >>> unique_to_each("mississippi", "missouri")
- [['p', 'p'], ['o', 'u', 'r']]
-
- It is assumed that the elements of each iterable are hashable.
-
- """
- pool = [list(it) for it in iterables]
- counts = Counter(chain.from_iterable(map(set, pool)))
- uniques = {element for element in counts if counts[element] == 1}
- return [list(filter(uniques.__contains__, it)) for it in pool]
-
-
-def windowed(seq, n, fillvalue=None, step=1):
- """Return a sliding window of width *n* over the given iterable.
-
- >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
- >>> list(all_windows)
- [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
-
- When the window is larger than the iterable, *fillvalue* is used in place
- of missing values:
-
- >>> list(windowed([1, 2, 3], 4))
- [(1, 2, 3, None)]
-
- Each window will advance in increments of *step*:
-
- >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
- [(1, 2, 3), (3, 4, 5), (5, 6, '!')]
-
- To slide into the iterable's items, use :func:`chain` to add filler items
- to the left:
-
- >>> iterable = [1, 2, 3, 4]
- >>> n = 3
- >>> padding = [None] * (n - 1)
- >>> list(windowed(chain(padding, iterable), 3))
- [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
- """
- if n < 0:
- raise ValueError('n must be >= 0')
- if n == 0:
- yield tuple()
- return
- if step < 1:
- raise ValueError('step must be >= 1')
-
- window = deque(maxlen=n)
- i = n
- for _ in map(window.append, seq):
- i -= 1
- if not i:
- i = step
- yield tuple(window)
-
- size = len(window)
- if size < n:
- yield tuple(chain(window, repeat(fillvalue, n - size)))
- elif 0 < i < min(step, n):
- window += (fillvalue,) * i
- yield tuple(window)
-
-
-def substrings(iterable):
- """Yield all of the substrings of *iterable*.
-
- >>> [''.join(s) for s in substrings('more')]
- ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']
-
- Note that non-string iterables can also be subdivided.
-
- >>> list(substrings([0, 1, 2]))
- [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]
-
- """
- # The length-1 substrings
- seq = []
- for item in iter(iterable):
- seq.append(item)
- yield (item,)
- seq = tuple(seq)
- item_count = len(seq)
-
- # And the rest
- for n in range(2, item_count + 1):
- for i in range(item_count - n + 1):
- yield seq[i : i + n]
-
-
-def substrings_indexes(seq, reverse=False):
- """Yield all substrings and their positions in *seq*
-
- The items yielded will be a tuple of the form ``(substr, i, j)``, where
- ``substr == seq[i:j]``.
-
- This function only works for iterables that support slicing, such as
- ``str`` objects.
-
- >>> for item in substrings_indexes('more'):
- ... print(item)
- ('m', 0, 1)
- ('o', 1, 2)
- ('r', 2, 3)
- ('e', 3, 4)
- ('mo', 0, 2)
- ('or', 1, 3)
- ('re', 2, 4)
- ('mor', 0, 3)
- ('ore', 1, 4)
- ('more', 0, 4)
-
- Set *reverse* to ``True`` to yield the same items in the opposite order.
-
-
- """
- r = range(1, len(seq) + 1)
- if reverse:
- r = reversed(r)
- return (
- (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1)
- )
-
-
-class bucket:
- """Wrap *iterable* and return an object that buckets it iterable into
- child iterables based on a *key* function.
-
- >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
- >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character
- >>> sorted(list(s)) # Get the keys
- ['a', 'b', 'c']
- >>> a_iterable = s['a']
- >>> next(a_iterable)
- 'a1'
- >>> next(a_iterable)
- 'a2'
- >>> list(s['b'])
- ['b1', 'b2', 'b3']
-
- The original iterable will be advanced and its items will be cached until
- they are used by the child iterables. This may require significant storage.
-
- By default, attempting to select a bucket to which no items belong will
- exhaust the iterable and cache all values.
- If you specify a *validator* function, selected buckets will instead be
- checked against it.
-
- >>> from itertools import count
- >>> it = count(1, 2) # Infinite sequence of odd numbers
- >>> key = lambda x: x % 10 # Bucket by last digit
- >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only
- >>> s = bucket(it, key=key, validator=validator)
- >>> 2 in s
- False
- >>> list(s[2])
- []
-
- """
-
- def __init__(self, iterable, key, validator=None):
- self._it = iter(iterable)
- self._key = key
- self._cache = defaultdict(deque)
- self._validator = validator or (lambda x: True)
-
- def __contains__(self, value):
- if not self._validator(value):
- return False
-
- try:
- item = next(self[value])
- except StopIteration:
- return False
- else:
- self._cache[value].appendleft(item)
-
- return True
-
- def _get_values(self, value):
- """
- Helper to yield items from the parent iterator that match *value*.
- Items that don't match are stored in the local cache as they
- are encountered.
- """
- while True:
- # If we've cached some items that match the target value, emit
- # the first one and evict it from the cache.
- if self._cache[value]:
- yield self._cache[value].popleft()
- # Otherwise we need to advance the parent iterator to search for
- # a matching item, caching the rest.
- else:
- while True:
- try:
- item = next(self._it)
- except StopIteration:
- return
- item_value = self._key(item)
- if item_value == value:
- yield item
- break
- elif self._validator(item_value):
- self._cache[item_value].append(item)
-
- def __iter__(self):
- for item in self._it:
- item_value = self._key(item)
- if self._validator(item_value):
- self._cache[item_value].append(item)
-
- yield from self._cache.keys()
-
- def __getitem__(self, value):
- if not self._validator(value):
- return iter(())
-
- return self._get_values(value)
-
-
-def spy(iterable, n=1):
- """Return a 2-tuple with a list containing the first *n* elements of
- *iterable*, and an iterator with the same items as *iterable*.
- This allows you to "look ahead" at the items in the iterable without
- advancing it.
-
- There is one item in the list by default:
-
- >>> iterable = 'abcdefg'
- >>> head, iterable = spy(iterable)
- >>> head
- ['a']
- >>> list(iterable)
- ['a', 'b', 'c', 'd', 'e', 'f', 'g']
-
- You may use unpacking to retrieve items instead of lists:
-
- >>> (head,), iterable = spy('abcdefg')
- >>> head
- 'a'
- >>> (first, second), iterable = spy('abcdefg', 2)
- >>> first
- 'a'
- >>> second
- 'b'
-
- The number of items requested can be larger than the number of items in
- the iterable:
-
- >>> iterable = [1, 2, 3, 4, 5]
- >>> head, iterable = spy(iterable, 10)
- >>> head
- [1, 2, 3, 4, 5]
- >>> list(iterable)
- [1, 2, 3, 4, 5]
-
- """
- it = iter(iterable)
- head = take(n, it)
-
- return head.copy(), chain(head, it)
-
-
-def interleave(*iterables):
- """Return a new iterable yielding from each iterable in turn,
- until the shortest is exhausted.
-
- >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
- [1, 4, 6, 2, 5, 7]
-
- For a version that doesn't terminate after the shortest iterable is
- exhausted, see :func:`interleave_longest`.
-
- """
- return chain.from_iterable(zip(*iterables))
-
-
-def interleave_longest(*iterables):
- """Return a new iterable yielding from each iterable in turn,
- skipping any that are exhausted.
-
- >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
- [1, 4, 6, 2, 5, 7, 3, 8]
-
- This function produces the same output as :func:`roundrobin`, but may
- perform better for some inputs (in particular when the number of iterables
- is large).
-
- """
- i = chain.from_iterable(zip_longest(*iterables, fillvalue=_marker))
- return (x for x in i if x is not _marker)
-
-
-def interleave_evenly(iterables, lengths=None):
- """
- Interleave multiple iterables so that their elements are evenly distributed
- throughout the output sequence.
-
- >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
- >>> list(interleave_evenly(iterables))
- [1, 2, 'a', 3, 4, 'b', 5]
-
- >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
- >>> list(interleave_evenly(iterables))
- [1, 6, 4, 2, 7, 3, 8, 5]
-
- This function requires iterables of known length. Iterables without
- ``__len__()`` can be used by manually specifying lengths with *lengths*:
-
- >>> from itertools import combinations, repeat
- >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
- >>> lengths = [4 * (4 - 1) // 2, 3]
- >>> list(interleave_evenly(iterables, lengths=lengths))
- [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']
-
- Based on Bresenham's algorithm.
- """
- if lengths is None:
- try:
- lengths = [len(it) for it in iterables]
- except TypeError:
- raise ValueError(
- 'Iterable lengths could not be determined automatically. '
- 'Specify them with the lengths keyword.'
- )
- elif len(iterables) != len(lengths):
- raise ValueError('Mismatching number of iterables and lengths.')
-
- dims = len(lengths)
-
- # sort iterables by length, descending
- lengths_permute = sorted(
- range(dims), key=lambda i: lengths[i], reverse=True
- )
- lengths_desc = [lengths[i] for i in lengths_permute]
- iters_desc = [iter(iterables[i]) for i in lengths_permute]
-
- # the longest iterable is the primary one (Bresenham: the longest
- # distance along an axis)
- delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
- iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
- errors = [delta_primary // dims] * len(deltas_secondary)
-
- to_yield = sum(lengths)
- while to_yield:
- yield next(iter_primary)
- to_yield -= 1
- # update errors for each secondary iterable
- errors = [e - delta for e, delta in zip(errors, deltas_secondary)]
-
- # those iterables for which the error is negative are yielded
- # ("diagonal step" in Bresenham)
- for i, e in enumerate(errors):
- if e < 0:
- yield next(iters_secondary[i])
- to_yield -= 1
- errors[i] += delta_primary
-
-
-def collapse(iterable, base_type=None, levels=None):
- """Flatten an iterable with multiple levels of nesting (e.g., a list of
- lists of tuples) into non-iterable types.
-
- >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
- >>> list(collapse(iterable))
- [1, 2, 3, 4, 5, 6]
-
- Binary and text strings are not considered iterable and
- will not be collapsed.
-
- To avoid collapsing other types, specify *base_type*:
-
- >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
- >>> list(collapse(iterable, base_type=tuple))
- ['ab', ('cd', 'ef'), 'gh', 'ij']
-
- Specify *levels* to stop flattening after a certain level:
-
- >>> iterable = [('a', ['b']), ('c', ['d'])]
- >>> list(collapse(iterable)) # Fully flattened
- ['a', 'b', 'c', 'd']
- >>> list(collapse(iterable, levels=1)) # Only one level flattened
- ['a', ['b'], 'c', ['d']]
-
- """
-
- def walk(node, level):
- if (
- ((levels is not None) and (level > levels))
- or isinstance(node, (str, bytes))
- or ((base_type is not None) and isinstance(node, base_type))
- ):
- yield node
- return
-
- try:
- tree = iter(node)
- except TypeError:
- yield node
- return
- else:
- for child in tree:
- yield from walk(child, level + 1)
-
- yield from walk(iterable, 0)
-
-
-def side_effect(func, iterable, chunk_size=None, before=None, after=None):
- """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
- of items) before yielding the item.
-
- `func` must be a function that takes a single argument. Its return value
- will be discarded.
-
- *before* and *after* are optional functions that take no arguments. They
- will be executed before iteration starts and after it ends, respectively.
-
- `side_effect` can be used for logging, updating progress bars, or anything
- that is not functionally "pure."
-
- Emitting a status message:
-
- >>> from more_itertools import consume
- >>> func = lambda item: print('Received {}'.format(item))
- >>> consume(side_effect(func, range(2)))
- Received 0
- Received 1
-
- Operating on chunks of items:
-
- >>> pair_sums = []
- >>> func = lambda chunk: pair_sums.append(sum(chunk))
- >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
- [0, 1, 2, 3, 4, 5]
- >>> list(pair_sums)
- [1, 5, 9]
-
- Writing to a file-like object:
-
- >>> from io import StringIO
- >>> from more_itertools import consume
- >>> f = StringIO()
- >>> func = lambda x: print(x, file=f)
- >>> before = lambda: print(u'HEADER', file=f)
- >>> after = f.close
- >>> it = [u'a', u'b', u'c']
- >>> consume(side_effect(func, it, before=before, after=after))
- >>> f.closed
- True
-
- """
- try:
- if before is not None:
- before()
-
- if chunk_size is None:
- for item in iterable:
- func(item)
- yield item
- else:
- for chunk in chunked(iterable, chunk_size):
- func(chunk)
- yield from chunk
- finally:
- if after is not None:
- after()
-
-
-def sliced(seq, n, strict=False):
- """Yield slices of length *n* from the sequence *seq*.
-
- >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
- [(1, 2, 3), (4, 5, 6)]
-
- By the default, the last yielded slice will have fewer than *n* elements
- if the length of *seq* is not divisible by *n*:
-
- >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
- [(1, 2, 3), (4, 5, 6), (7, 8)]
-
- If the length of *seq* is not divisible by *n* and *strict* is
- ``True``, then ``ValueError`` will be raised before the last
- slice is yielded.
-
- This function will only work for iterables that support slicing.
- For non-sliceable iterables, see :func:`chunked`.
-
- """
- iterator = takewhile(len, (seq[i : i + n] for i in count(0, n)))
- if strict:
-
- def ret():
- for _slice in iterator:
- if len(_slice) != n:
- raise ValueError("seq is not divisible by n.")
- yield _slice
-
- return iter(ret())
- else:
- return iterator
-
-
-def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
- """Yield lists of items from *iterable*, where each list is delimited by
- an item where callable *pred* returns ``True``.
-
- >>> list(split_at('abcdcba', lambda x: x == 'b'))
- [['a'], ['c', 'd', 'c'], ['a']]
-
- >>> list(split_at(range(10), lambda n: n % 2 == 1))
- [[0], [2], [4], [6], [8], []]
-
- At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
- then there is no limit on the number of splits:
-
- >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
- [[0], [2], [4, 5, 6, 7, 8, 9]]
-
- By default, the delimiting items are not included in the output.
- The include them, set *keep_separator* to ``True``.
-
- >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
- [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]
-
- """
- if maxsplit == 0:
- yield list(iterable)
- return
-
- buf = []
- it = iter(iterable)
- for item in it:
- if pred(item):
- yield buf
- if keep_separator:
- yield [item]
- if maxsplit == 1:
- yield list(it)
- return
- buf = []
- maxsplit -= 1
- else:
- buf.append(item)
- yield buf
-
-
-def split_before(iterable, pred, maxsplit=-1):
- """Yield lists of items from *iterable*, where each list ends just before
- an item for which callable *pred* returns ``True``:
-
- >>> list(split_before('OneTwo', lambda s: s.isupper()))
- [['O', 'n', 'e'], ['T', 'w', 'o']]
-
- >>> list(split_before(range(10), lambda n: n % 3 == 0))
- [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
-
- At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
- then there is no limit on the number of splits:
-
- >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
- [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
- """
- if maxsplit == 0:
- yield list(iterable)
- return
-
- buf = []
- it = iter(iterable)
- for item in it:
- if pred(item) and buf:
- yield buf
- if maxsplit == 1:
- yield [item] + list(it)
- return
- buf = []
- maxsplit -= 1
- buf.append(item)
- if buf:
- yield buf
-
-
-def split_after(iterable, pred, maxsplit=-1):
- """Yield lists of items from *iterable*, where each list ends with an
- item where callable *pred* returns ``True``:
-
- >>> list(split_after('one1two2', lambda s: s.isdigit()))
- [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]
-
- >>> list(split_after(range(10), lambda n: n % 3 == 0))
- [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]
-
- At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
- then there is no limit on the number of splits:
-
- >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
- [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]
-
- """
- if maxsplit == 0:
- yield list(iterable)
- return
-
- buf = []
- it = iter(iterable)
- for item in it:
- buf.append(item)
- if pred(item) and buf:
- yield buf
- if maxsplit == 1:
- yield list(it)
- return
- buf = []
- maxsplit -= 1
- if buf:
- yield buf
-
-
-def split_when(iterable, pred, maxsplit=-1):
- """Split *iterable* into pieces based on the output of *pred*.
- *pred* should be a function that takes successive pairs of items and
- returns ``True`` if the iterable should be split in between them.
-
- For example, to find runs of increasing numbers, split the iterable when
- element ``i`` is larger than element ``i + 1``:
-
- >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
- [[1, 2, 3, 3], [2, 5], [2, 4], [2]]
-
- At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
- then there is no limit on the number of splits:
-
- >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
- ... lambda x, y: x > y, maxsplit=2))
- [[1, 2, 3, 3], [2, 5], [2, 4, 2]]
-
- """
- if maxsplit == 0:
- yield list(iterable)
- return
-
- it = iter(iterable)
- try:
- cur_item = next(it)
- except StopIteration:
- return
-
- buf = [cur_item]
- for next_item in it:
- if pred(cur_item, next_item):
- yield buf
- if maxsplit == 1:
- yield [next_item] + list(it)
- return
- buf = []
- maxsplit -= 1
-
- buf.append(next_item)
- cur_item = next_item
-
- yield buf
-
-
-def split_into(iterable, sizes):
- """Yield a list of sequential items from *iterable* of length 'n' for each
- integer 'n' in *sizes*.
-
- >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
- [[1], [2, 3], [4, 5, 6]]
-
- If the sum of *sizes* is smaller than the length of *iterable*, then the
- remaining items of *iterable* will not be returned.
-
- >>> list(split_into([1,2,3,4,5,6], [2,3]))
- [[1, 2], [3, 4, 5]]
-
- If the sum of *sizes* is larger than the length of *iterable*, fewer items
- will be returned in the iteration that overruns *iterable* and further
- lists will be empty:
-
- >>> list(split_into([1,2,3,4], [1,2,3,4]))
- [[1], [2, 3], [4], []]
-
- When a ``None`` object is encountered in *sizes*, the returned list will
- contain items up to the end of *iterable* the same way that itertools.slice
- does:
-
- >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
- [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]
-
- :func:`split_into` can be useful for grouping a series of items where the
- sizes of the groups are not uniform. An example would be where in a row
- from a table, multiple columns represent elements of the same feature
- (e.g. a point represented by x,y,z) but, the format is not the same for
- all columns.
- """
- # convert the iterable argument into an iterator so its contents can
- # be consumed by islice in case it is a generator
- it = iter(iterable)
-
- for size in sizes:
- if size is None:
- yield list(it)
- return
- else:
- yield list(islice(it, size))
-
-
-def padded(iterable, fillvalue=None, n=None, next_multiple=False):
- """Yield the elements from *iterable*, followed by *fillvalue*, such that
- at least *n* items are emitted.
-
- >>> list(padded([1, 2, 3], '?', 5))
- [1, 2, 3, '?', '?']
-
- If *next_multiple* is ``True``, *fillvalue* will be emitted until the
- number of items emitted is a multiple of *n*::
-
- >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
- [1, 2, 3, 4, None, None]
-
- If *n* is ``None``, *fillvalue* will be emitted indefinitely.
-
- """
- it = iter(iterable)
- if n is None:
- yield from chain(it, repeat(fillvalue))
- elif n < 1:
- raise ValueError('n must be at least 1')
- else:
- item_count = 0
- for item in it:
- yield item
- item_count += 1
-
- remaining = (n - item_count) % n if next_multiple else n - item_count
- for _ in range(remaining):
- yield fillvalue
-
-
-def repeat_each(iterable, n=2):
- """Repeat each element in *iterable* *n* times.
-
- >>> list(repeat_each('ABC', 3))
- ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
- """
- return chain.from_iterable(map(repeat, iterable, repeat(n)))
-
-
-def repeat_last(iterable, default=None):
- """After the *iterable* is exhausted, keep yielding its last element.
-
- >>> list(islice(repeat_last(range(3)), 5))
- [0, 1, 2, 2, 2]
-
- If the iterable is empty, yield *default* forever::
-
- >>> list(islice(repeat_last(range(0), 42), 5))
- [42, 42, 42, 42, 42]
-
- """
- item = _marker
- for item in iterable:
- yield item
- final = default if item is _marker else item
- yield from repeat(final)
-
-
-def distribute(n, iterable):
- """Distribute the items from *iterable* among *n* smaller iterables.
-
- >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
- >>> list(group_1)
- [1, 3, 5]
- >>> list(group_2)
- [2, 4, 6]
-
- If the length of *iterable* is not evenly divisible by *n*, then the
- length of the returned iterables will not be identical:
-
- >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
- >>> [list(c) for c in children]
- [[1, 4, 7], [2, 5], [3, 6]]
-
- If the length of *iterable* is smaller than *n*, then the last returned
- iterables will be empty:
-
- >>> children = distribute(5, [1, 2, 3])
- >>> [list(c) for c in children]
- [[1], [2], [3], [], []]
-
- This function uses :func:`itertools.tee` and may require significant
- storage. If you need the order items in the smaller iterables to match the
- original iterable, see :func:`divide`.
-
- """
- if n < 1:
- raise ValueError('n must be at least 1')
-
- children = tee(iterable, n)
- return [islice(it, index, None, n) for index, it in enumerate(children)]
-
-
-def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
- """Yield tuples whose elements are offset from *iterable*.
- The amount by which the `i`-th item in each tuple is offset is given by
- the `i`-th item in *offsets*.
-
- >>> list(stagger([0, 1, 2, 3]))
- [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
- >>> list(stagger(range(8), offsets=(0, 2, 4)))
- [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]
-
- By default, the sequence will end when the final element of a tuple is the
- last item in the iterable. To continue until the first element of a tuple
- is the last item in the iterable, set *longest* to ``True``::
-
- >>> list(stagger([0, 1, 2, 3], longest=True))
- [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]
-
- By default, ``None`` will be used to replace offsets beyond the end of the
- sequence. Specify *fillvalue* to use some other value.
-
- """
- children = tee(iterable, len(offsets))
-
- return zip_offset(
- *children, offsets=offsets, longest=longest, fillvalue=fillvalue
- )
-
-
-class UnequalIterablesError(ValueError):
- def __init__(self, details=None):
- msg = 'Iterables have different lengths'
- if details is not None:
- msg += (': index 0 has length {}; index {} has length {}').format(
- *details
- )
-
- super().__init__(msg)
-
-
-def _zip_equal_generator(iterables):
- for combo in zip_longest(*iterables, fillvalue=_marker):
- for val in combo:
- if val is _marker:
- raise UnequalIterablesError()
- yield combo
-
-
-def _zip_equal(*iterables):
- # Check whether the iterables are all the same size.
- try:
- first_size = len(iterables[0])
- for i, it in enumerate(iterables[1:], 1):
- size = len(it)
- if size != first_size:
- break
- else:
- # If we didn't break out, we can use the built-in zip.
- return zip(*iterables)
-
- # If we did break out, there was a mismatch.
- raise UnequalIterablesError(details=(first_size, i, size))
- # If any one of the iterables didn't have a length, start reading
- # them until one runs out.
- except TypeError:
- return _zip_equal_generator(iterables)
-
-
-def zip_equal(*iterables):
- """``zip`` the input *iterables* together, but raise
- ``UnequalIterablesError`` if they aren't all the same length.
-
- >>> it_1 = range(3)
- >>> it_2 = iter('abc')
- >>> list(zip_equal(it_1, it_2))
- [(0, 'a'), (1, 'b'), (2, 'c')]
-
- >>> it_1 = range(3)
- >>> it_2 = iter('abcd')
- >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- more_itertools.more.UnequalIterablesError: Iterables have different
- lengths
-
- """
- if hexversion >= 0x30A00A6:
- warnings.warn(
- (
- 'zip_equal will be removed in a future version of '
- 'more-itertools. Use the builtin zip function with '
- 'strict=True instead.'
- ),
- DeprecationWarning,
- )
-
- return _zip_equal(*iterables)
-
-
-def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
- """``zip`` the input *iterables* together, but offset the `i`-th iterable
- by the `i`-th item in *offsets*.
-
- >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
- [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]
-
- This can be used as a lightweight alternative to SciPy or pandas to analyze
- data sets in which some series have a lead or lag relationship.
-
- By default, the sequence will end when the shortest iterable is exhausted.
- To continue until the longest iterable is exhausted, set *longest* to
- ``True``.
-
- >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
- [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]
-
- By default, ``None`` will be used to replace offsets beyond the end of the
- sequence. Specify *fillvalue* to use some other value.
-
- """
- if len(iterables) != len(offsets):
- raise ValueError("Number of iterables and offsets didn't match")
-
- staggered = []
- for it, n in zip(iterables, offsets):
- if n < 0:
- staggered.append(chain(repeat(fillvalue, -n), it))
- elif n > 0:
- staggered.append(islice(it, n, None))
- else:
- staggered.append(it)
-
- if longest:
- return zip_longest(*staggered, fillvalue=fillvalue)
-
- return zip(*staggered)
-
-
-def sort_together(iterables, key_list=(0,), key=None, reverse=False):
- """Return the input iterables sorted together, with *key_list* as the
- priority for sorting. All iterables are trimmed to the length of the
- shortest one.
-
- This can be used like the sorting function in a spreadsheet. If each
- iterable represents a column of data, the key list determines which
- columns are used for sorting.
-
- By default, all iterables are sorted using the ``0``-th iterable::
-
- >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
- >>> sort_together(iterables)
- [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
-
- Set a different key list to sort according to another iterable.
- Specifying multiple keys dictates how ties are broken::
-
- >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
- >>> sort_together(iterables, key_list=(1, 2))
- [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
-
- To sort by a function of the elements of the iterable, pass a *key*
- function. Its arguments are the elements of the iterables corresponding to
- the key list::
-
- >>> names = ('a', 'b', 'c')
- >>> lengths = (1, 2, 3)
- >>> widths = (5, 2, 1)
- >>> def area(length, width):
- ... return length * width
- >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
- [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
-
- Set *reverse* to ``True`` to sort in descending order.
-
- >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
- [(3, 2, 1), ('a', 'b', 'c')]
-
- """
- if key is None:
- # if there is no key function, the key argument to sorted is an
- # itemgetter
- key_argument = itemgetter(*key_list)
- else:
- # if there is a key function, call it with the items at the offsets
- # specified by the key function as arguments
- key_list = list(key_list)
- if len(key_list) == 1:
- # if key_list contains a single item, pass the item at that offset
- # as the only argument to the key function
- key_offset = key_list[0]
- key_argument = lambda zipped_items: key(zipped_items[key_offset])
- else:
- # if key_list contains multiple items, use itemgetter to return a
- # tuple of items, which we pass as *args to the key function
- get_key_items = itemgetter(*key_list)
- key_argument = lambda zipped_items: key(
- *get_key_items(zipped_items)
- )
-
- return list(
- zip(*sorted(zip(*iterables), key=key_argument, reverse=reverse))
- )
-
-
-def unzip(iterable):
- """The inverse of :func:`zip`, this function disaggregates the elements
- of the zipped *iterable*.
-
- The ``i``-th iterable contains the ``i``-th element from each element
- of the zipped iterable. The first element is used to to determine the
- length of the remaining elements.
-
- >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
- >>> letters, numbers = unzip(iterable)
- >>> list(letters)
- ['a', 'b', 'c', 'd']
- >>> list(numbers)
- [1, 2, 3, 4]
-
- This is similar to using ``zip(*iterable)``, but it avoids reading
- *iterable* into memory. Note, however, that this function uses
- :func:`itertools.tee` and thus may require significant storage.
-
- """
- head, iterable = spy(iter(iterable))
- if not head:
- # empty iterable, e.g. zip([], [], [])
- return ()
- # spy returns a one-length iterable as head
- head = head[0]
- iterables = tee(iterable, len(head))
-
- def itemgetter(i):
- def getter(obj):
- try:
- return obj[i]
- except IndexError:
- # basically if we have an iterable like
- # iter([(1, 2, 3), (4, 5), (6,)])
- # the second unzipped iterable would fail at the third tuple
- # since it would try to access tup[1]
- # same with the third unzipped iterable and the second tuple
- # to support these "improperly zipped" iterables,
- # we create a custom itemgetter
- # which just stops the unzipped iterables
- # at first length mismatch
- raise StopIteration
-
- return getter
-
- return tuple(map(itemgetter(i), it) for i, it in enumerate(iterables))
-
-
-def divide(n, iterable):
- """Divide the elements from *iterable* into *n* parts, maintaining
- order.
-
- >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
- >>> list(group_1)
- [1, 2, 3]
- >>> list(group_2)
- [4, 5, 6]
-
- If the length of *iterable* is not evenly divisible by *n*, then the
- length of the returned iterables will not be identical:
-
- >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
- >>> [list(c) for c in children]
- [[1, 2, 3], [4, 5], [6, 7]]
-
- If the length of the iterable is smaller than n, then the last returned
- iterables will be empty:
-
- >>> children = divide(5, [1, 2, 3])
- >>> [list(c) for c in children]
- [[1], [2], [3], [], []]
-
- This function will exhaust the iterable before returning and may require
- significant storage. If order is not important, see :func:`distribute`,
- which does not first pull the iterable into memory.
-
- """
- if n < 1:
- raise ValueError('n must be at least 1')
-
- try:
- iterable[:0]
- except TypeError:
- seq = tuple(iterable)
- else:
- seq = iterable
-
- q, r = divmod(len(seq), n)
-
- ret = []
- stop = 0
- for i in range(1, n + 1):
- start = stop
- stop += q + 1 if i <= r else q
- ret.append(iter(seq[start:stop]))
-
- return ret
-
-
-def always_iterable(obj, base_type=(str, bytes)):
- """If *obj* is iterable, return an iterator over its items::
-
- >>> obj = (1, 2, 3)
- >>> list(always_iterable(obj))
- [1, 2, 3]
-
- If *obj* is not iterable, return a one-item iterable containing *obj*::
-
- >>> obj = 1
- >>> list(always_iterable(obj))
- [1]
-
- If *obj* is ``None``, return an empty iterable:
-
- >>> obj = None
- >>> list(always_iterable(None))
- []
-
- By default, binary and text strings are not considered iterable::
-
- >>> obj = 'foo'
- >>> list(always_iterable(obj))
- ['foo']
-
- If *base_type* is set, objects for which ``isinstance(obj, base_type)``
- returns ``True`` won't be considered iterable.
-
- >>> obj = {'a': 1}
- >>> list(always_iterable(obj)) # Iterate over the dict's keys
- ['a']
- >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit
- [{'a': 1}]
-
- Set *base_type* to ``None`` to avoid any special handling and treat objects
- Python considers iterable as iterable:
-
- >>> obj = 'foo'
- >>> list(always_iterable(obj, base_type=None))
- ['f', 'o', 'o']
- """
- if obj is None:
- return iter(())
-
- if (base_type is not None) and isinstance(obj, base_type):
- return iter((obj,))
-
- try:
- return iter(obj)
- except TypeError:
- return iter((obj,))
-
-
-def adjacent(predicate, iterable, distance=1):
- """Return an iterable over `(bool, item)` tuples where the `item` is
- drawn from *iterable* and the `bool` indicates whether
- that item satisfies the *predicate* or is adjacent to an item that does.
-
- For example, to find whether items are adjacent to a ``3``::
-
- >>> list(adjacent(lambda x: x == 3, range(6)))
- [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]
-
- Set *distance* to change what counts as adjacent. For example, to find
- whether items are two places away from a ``3``:
-
- >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
- [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]
-
- This is useful for contextualizing the results of a search function.
- For example, a code comparison tool might want to identify lines that
- have changed, but also surrounding lines to give the viewer of the diff
- context.
-
- The predicate function will only be called once for each item in the
- iterable.
-
- See also :func:`groupby_transform`, which can be used with this function
- to group ranges of items with the same `bool` value.
-
- """
- # Allow distance=0 mainly for testing that it reproduces results with map()
- if distance < 0:
- raise ValueError('distance must be at least 0')
-
- i1, i2 = tee(iterable)
- padding = [False] * distance
- selected = chain(padding, map(predicate, i1), padding)
- adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1))
- return zip(adjacent_to_selected, i2)
-
-
-def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
- """An extension of :func:`itertools.groupby` that can apply transformations
- to the grouped data.
-
- * *keyfunc* is a function computing a key value for each item in *iterable*
- * *valuefunc* is a function that transforms the individual items from
- *iterable* after grouping
- * *reducefunc* is a function that transforms each group of items
-
- >>> iterable = 'aAAbBBcCC'
- >>> keyfunc = lambda k: k.upper()
- >>> valuefunc = lambda v: v.lower()
- >>> reducefunc = lambda g: ''.join(g)
- >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
- [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]
-
- Each optional argument defaults to an identity function if not specified.
-
- :func:`groupby_transform` is useful when grouping elements of an iterable
- using a separate iterable as the key. To do this, :func:`zip` the iterables
- and pass a *keyfunc* that extracts the first element and a *valuefunc*
- that extracts the second element::
-
- >>> from operator import itemgetter
- >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
- >>> values = 'abcdefghi'
- >>> iterable = zip(keys, values)
- >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
- >>> [(k, ''.join(g)) for k, g in grouper]
- [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]
-
- Note that the order of items in the iterable is significant.
- Only adjacent items are grouped together, so if you don't want any
- duplicate groups, you should sort the iterable by the key function.
-
- """
- ret = groupby(iterable, keyfunc)
- if valuefunc:
- ret = ((k, map(valuefunc, g)) for k, g in ret)
- if reducefunc:
- ret = ((k, reducefunc(g)) for k, g in ret)
-
- return ret
-
-
-class numeric_range(abc.Sequence, abc.Hashable):
- """An extension of the built-in ``range()`` function whose arguments can
- be any orderable numeric type.
-
- With only *stop* specified, *start* defaults to ``0`` and *step*
- defaults to ``1``. The output items will match the type of *stop*:
-
- >>> list(numeric_range(3.5))
- [0.0, 1.0, 2.0, 3.0]
-
- With only *start* and *stop* specified, *step* defaults to ``1``. The
- output items will match the type of *start*:
-
- >>> from decimal import Decimal
- >>> start = Decimal('2.1')
- >>> stop = Decimal('5.1')
- >>> list(numeric_range(start, stop))
- [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]
-
- With *start*, *stop*, and *step* specified the output items will match
- the type of ``start + step``:
-
- >>> from fractions import Fraction
- >>> start = Fraction(1, 2) # Start at 1/2
- >>> stop = Fraction(5, 2) # End at 5/2
- >>> step = Fraction(1, 2) # Count by 1/2
- >>> list(numeric_range(start, stop, step))
- [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]
-
- If *step* is zero, ``ValueError`` is raised. Negative steps are supported:
-
- >>> list(numeric_range(3, -1, -1.0))
- [3.0, 2.0, 1.0, 0.0]
-
- Be aware of the limitations of floating point numbers; the representation
- of the yielded numbers may be surprising.
-
- ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
- is a ``datetime.timedelta`` object:
-
- >>> import datetime
- >>> start = datetime.datetime(2019, 1, 1)
- >>> stop = datetime.datetime(2019, 1, 3)
- >>> step = datetime.timedelta(days=1)
- >>> items = iter(numeric_range(start, stop, step))
- >>> next(items)
- datetime.datetime(2019, 1, 1, 0, 0)
- >>> next(items)
- datetime.datetime(2019, 1, 2, 0, 0)
-
- """
-
- _EMPTY_HASH = hash(range(0, 0))
-
- def __init__(self, *args):
- argc = len(args)
- if argc == 1:
- (self._stop,) = args
- self._start = type(self._stop)(0)
- self._step = type(self._stop - self._start)(1)
- elif argc == 2:
- self._start, self._stop = args
- self._step = type(self._stop - self._start)(1)
- elif argc == 3:
- self._start, self._stop, self._step = args
- elif argc == 0:
- raise TypeError(
- 'numeric_range expected at least '
- '1 argument, got {}'.format(argc)
- )
- else:
- raise TypeError(
- 'numeric_range expected at most '
- '3 arguments, got {}'.format(argc)
- )
-
- self._zero = type(self._step)(0)
- if self._step == self._zero:
- raise ValueError('numeric_range() arg 3 must not be zero')
- self._growing = self._step > self._zero
- self._init_len()
-
- def __bool__(self):
- if self._growing:
- return self._start < self._stop
- else:
- return self._start > self._stop
-
- def __contains__(self, elem):
- if self._growing:
- if self._start <= elem < self._stop:
- return (elem - self._start) % self._step == self._zero
- else:
- if self._start >= elem > self._stop:
- return (self._start - elem) % (-self._step) == self._zero
-
- return False
-
- def __eq__(self, other):
- if isinstance(other, numeric_range):
- empty_self = not bool(self)
- empty_other = not bool(other)
- if empty_self or empty_other:
- return empty_self and empty_other # True if both empty
- else:
- return (
- self._start == other._start
- and self._step == other._step
- and self._get_by_index(-1) == other._get_by_index(-1)
- )
- else:
- return False
-
- def __getitem__(self, key):
- if isinstance(key, int):
- return self._get_by_index(key)
- elif isinstance(key, slice):
- step = self._step if key.step is None else key.step * self._step
-
- if key.start is None or key.start <= -self._len:
- start = self._start
- elif key.start >= self._len:
- start = self._stop
- else: # -self._len < key.start < self._len
- start = self._get_by_index(key.start)
-
- if key.stop is None or key.stop >= self._len:
- stop = self._stop
- elif key.stop <= -self._len:
- stop = self._start
- else: # -self._len < key.stop < self._len
- stop = self._get_by_index(key.stop)
-
- return numeric_range(start, stop, step)
- else:
- raise TypeError(
- 'numeric range indices must be '
- 'integers or slices, not {}'.format(type(key).__name__)
- )
-
- def __hash__(self):
- if self:
- return hash((self._start, self._get_by_index(-1), self._step))
- else:
- return self._EMPTY_HASH
-
- def __iter__(self):
- values = (self._start + (n * self._step) for n in count())
- if self._growing:
- return takewhile(partial(gt, self._stop), values)
- else:
- return takewhile(partial(lt, self._stop), values)
-
- def __len__(self):
- return self._len
-
- def _init_len(self):
- if self._growing:
- start = self._start
- stop = self._stop
- step = self._step
- else:
- start = self._stop
- stop = self._start
- step = -self._step
- distance = stop - start
- if distance <= self._zero:
- self._len = 0
- else: # distance > 0 and step > 0: regular euclidean division
- q, r = divmod(distance, step)
- self._len = int(q) + int(r != self._zero)
-
- def __reduce__(self):
- return numeric_range, (self._start, self._stop, self._step)
-
- def __repr__(self):
- if self._step == 1:
- return "numeric_range({}, {})".format(
- repr(self._start), repr(self._stop)
- )
- else:
- return "numeric_range({}, {}, {})".format(
- repr(self._start), repr(self._stop), repr(self._step)
- )
-
- def __reversed__(self):
- return iter(
- numeric_range(
- self._get_by_index(-1), self._start - self._step, -self._step
- )
- )
-
- def count(self, value):
- return int(value in self)
-
- def index(self, value):
- if self._growing:
- if self._start <= value < self._stop:
- q, r = divmod(value - self._start, self._step)
- if r == self._zero:
- return int(q)
- else:
- if self._start >= value > self._stop:
- q, r = divmod(self._start - value, -self._step)
- if r == self._zero:
- return int(q)
-
- raise ValueError("{} is not in numeric range".format(value))
-
- def _get_by_index(self, i):
- if i < 0:
- i += self._len
- if i < 0 or i >= self._len:
- raise IndexError("numeric range object index out of range")
- return self._start + i * self._step
-
-
-def count_cycle(iterable, n=None):
- """Cycle through the items from *iterable* up to *n* times, yielding
- the number of completed cycles along with each item. If *n* is omitted the
- process repeats indefinitely.
-
- >>> list(count_cycle('AB', 3))
- [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]
-
- """
- iterable = tuple(iterable)
- if not iterable:
- return iter(())
- counter = count() if n is None else range(n)
- return ((i, item) for i in counter for item in iterable)
-
-
-def mark_ends(iterable):
- """Yield 3-tuples of the form ``(is_first, is_last, item)``.
-
- >>> list(mark_ends('ABC'))
- [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]
-
- Use this when looping over an iterable to take special action on its first
- and/or last items:
-
- >>> iterable = ['Header', 100, 200, 'Footer']
- >>> total = 0
- >>> for is_first, is_last, item in mark_ends(iterable):
- ... if is_first:
- ... continue # Skip the header
- ... if is_last:
- ... continue # Skip the footer
- ... total += item
- >>> print(total)
- 300
- """
- it = iter(iterable)
-
- try:
- b = next(it)
- except StopIteration:
- return
-
- try:
- for i in count():
- a = b
- b = next(it)
- yield i == 0, False, a
-
- except StopIteration:
- yield i == 0, True, a
-
-
-def locate(iterable, pred=bool, window_size=None):
- """Yield the index of each item in *iterable* for which *pred* returns
- ``True``.
-
- *pred* defaults to :func:`bool`, which will select truthy items:
-
- >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
- [1, 2, 4]
-
- Set *pred* to a custom function to, e.g., find the indexes for a particular
- item.
-
- >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
- [1, 3]
-
- If *window_size* is given, then the *pred* function will be called with
- that many items. This enables searching for sub-sequences:
-
- >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
- >>> pred = lambda *args: args == (1, 2, 3)
- >>> list(locate(iterable, pred=pred, window_size=3))
- [1, 5, 9]
-
- Use with :func:`seekable` to find indexes and then retrieve the associated
- items:
-
- >>> from itertools import count
- >>> from more_itertools import seekable
- >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
- >>> it = seekable(source)
- >>> pred = lambda x: x > 100
- >>> indexes = locate(it, pred=pred)
- >>> i = next(indexes)
- >>> it.seek(i)
- >>> next(it)
- 106
-
- """
- if window_size is None:
- return compress(count(), map(pred, iterable))
-
- if window_size < 1:
- raise ValueError('window size must be at least 1')
-
- it = windowed(iterable, window_size, fillvalue=_marker)
- return compress(count(), starmap(pred, it))
-
-
-def lstrip(iterable, pred):
- """Yield the items from *iterable*, but strip any from the beginning
- for which *pred* returns ``True``.
-
- For example, to remove a set of items from the start of an iterable:
-
- >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
- >>> pred = lambda x: x in {None, False, ''}
- >>> list(lstrip(iterable, pred))
- [1, 2, None, 3, False, None]
-
- This function is analogous to to :func:`str.lstrip`, and is essentially
- an wrapper for :func:`itertools.dropwhile`.
-
- """
- return dropwhile(pred, iterable)
-
-
-def rstrip(iterable, pred):
- """Yield the items from *iterable*, but strip any from the end
- for which *pred* returns ``True``.
-
- For example, to remove a set of items from the end of an iterable:
-
- >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
- >>> pred = lambda x: x in {None, False, ''}
- >>> list(rstrip(iterable, pred))
- [None, False, None, 1, 2, None, 3]
-
- This function is analogous to :func:`str.rstrip`.
-
- """
- cache = []
- cache_append = cache.append
- cache_clear = cache.clear
- for x in iterable:
- if pred(x):
- cache_append(x)
- else:
- yield from cache
- cache_clear()
- yield x
-
-
-def strip(iterable, pred):
- """Yield the items from *iterable*, but strip any from the
- beginning and end for which *pred* returns ``True``.
-
- For example, to remove a set of items from both ends of an iterable:
-
- >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
- >>> pred = lambda x: x in {None, False, ''}
- >>> list(strip(iterable, pred))
- [1, 2, None, 3]
-
- This function is analogous to :func:`str.strip`.
-
- """
- return rstrip(lstrip(iterable, pred), pred)
-
-
-class islice_extended:
- """An extension of :func:`itertools.islice` that supports negative values
- for *stop*, *start*, and *step*.
-
- >>> iterable = iter('abcdefgh')
- >>> list(islice_extended(iterable, -4, -1))
- ['e', 'f', 'g']
-
- Slices with negative values require some caching of *iterable*, but this
- function takes care to minimize the amount of memory required.
-
- For example, you can use a negative step with an infinite iterator:
-
- >>> from itertools import count
- >>> list(islice_extended(count(), 110, 99, -2))
- [110, 108, 106, 104, 102, 100]
-
- You can also use slice notation directly:
-
- >>> iterable = map(str, count())
- >>> it = islice_extended(iterable)[10:20:2]
- >>> list(it)
- ['10', '12', '14', '16', '18']
-
- """
-
- def __init__(self, iterable, *args):
- it = iter(iterable)
- if args:
- self._iterable = _islice_helper(it, slice(*args))
- else:
- self._iterable = it
-
- def __iter__(self):
- return self
-
- def __next__(self):
- return next(self._iterable)
-
- def __getitem__(self, key):
- if isinstance(key, slice):
- return islice_extended(_islice_helper(self._iterable, key))
-
- raise TypeError('islice_extended.__getitem__ argument must be a slice')
-
-
-def _islice_helper(it, s):
- start = s.start
- stop = s.stop
- if s.step == 0:
- raise ValueError('step argument must be a non-zero integer or None.')
- step = s.step or 1
-
- if step > 0:
- start = 0 if (start is None) else start
-
- if start < 0:
- # Consume all but the last -start items
- cache = deque(enumerate(it, 1), maxlen=-start)
- len_iter = cache[-1][0] if cache else 0
-
- # Adjust start to be positive
- i = max(len_iter + start, 0)
-
- # Adjust stop to be positive
- if stop is None:
- j = len_iter
- elif stop >= 0:
- j = min(stop, len_iter)
- else:
- j = max(len_iter + stop, 0)
-
- # Slice the cache
- n = j - i
- if n <= 0:
- return
-
- for index, item in islice(cache, 0, n, step):
- yield item
- elif (stop is not None) and (stop < 0):
- # Advance to the start position
- next(islice(it, start, start), None)
-
- # When stop is negative, we have to carry -stop items while
- # iterating
- cache = deque(islice(it, -stop), maxlen=-stop)
-
- for index, item in enumerate(it):
- cached_item = cache.popleft()
- if index % step == 0:
- yield cached_item
- cache.append(item)
- else:
- # When both start and stop are positive we have the normal case
- yield from islice(it, start, stop, step)
- else:
- start = -1 if (start is None) else start
-
- if (stop is not None) and (stop < 0):
- # Consume all but the last items
- n = -stop - 1
- cache = deque(enumerate(it, 1), maxlen=n)
- len_iter = cache[-1][0] if cache else 0
-
- # If start and stop are both negative they are comparable and
- # we can just slice. Otherwise we can adjust start to be negative
- # and then slice.
- if start < 0:
- i, j = start, stop
- else:
- i, j = min(start - len_iter, -1), None
-
- for index, item in list(cache)[i:j:step]:
- yield item
- else:
- # Advance to the stop position
- if stop is not None:
- m = stop + 1
- next(islice(it, m, m), None)
-
- # stop is positive, so if start is negative they are not comparable
- # and we need the rest of the items.
- if start < 0:
- i = start
- n = None
- # stop is None and start is positive, so we just need items up to
- # the start index.
- elif stop is None:
- i = None
- n = start + 1
- # Both stop and start are positive, so they are comparable.
- else:
- i = None
- n = start - stop
- if n <= 0:
- return
-
- cache = list(islice(it, n))
-
- yield from cache[i::step]
-
-
-def always_reversible(iterable):
- """An extension of :func:`reversed` that supports all iterables, not
- just those which implement the ``Reversible`` or ``Sequence`` protocols.
-
- >>> print(*always_reversible(x for x in range(3)))
- 2 1 0
-
- If the iterable is already reversible, this function returns the
- result of :func:`reversed()`. If the iterable is not reversible,
- this function will cache the remaining items in the iterable and
- yield them in reverse order, which may require significant storage.
- """
- try:
- return reversed(iterable)
- except TypeError:
- return reversed(list(iterable))
-
-
-def consecutive_groups(iterable, ordering=lambda x: x):
- """Yield groups of consecutive items using :func:`itertools.groupby`.
- The *ordering* function determines whether two items are adjacent by
- returning their position.
-
- By default, the ordering function is the identity function. This is
- suitable for finding runs of numbers:
-
- >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
- >>> for group in consecutive_groups(iterable):
- ... print(list(group))
- [1]
- [10, 11, 12]
- [20]
- [30, 31, 32, 33]
- [40]
-
- For finding runs of adjacent letters, try using the :meth:`index` method
- of a string of letters:
-
- >>> from string import ascii_lowercase
- >>> iterable = 'abcdfgilmnop'
- >>> ordering = ascii_lowercase.index
- >>> for group in consecutive_groups(iterable, ordering):
- ... print(list(group))
- ['a', 'b', 'c', 'd']
- ['f', 'g']
- ['i']
- ['l', 'm', 'n', 'o', 'p']
-
- Each group of consecutive items is an iterator that shares it source with
- *iterable*. When an an output group is advanced, the previous group is
- no longer available unless its elements are copied (e.g., into a ``list``).
-
- >>> iterable = [1, 2, 11, 12, 21, 22]
- >>> saved_groups = []
- >>> for group in consecutive_groups(iterable):
- ... saved_groups.append(list(group)) # Copy group elements
- >>> saved_groups
- [[1, 2], [11, 12], [21, 22]]
-
- """
- for k, g in groupby(
- enumerate(iterable), key=lambda x: x[0] - ordering(x[1])
- ):
- yield map(itemgetter(1), g)
-
-
-def difference(iterable, func=sub, *, initial=None):
- """This function is the inverse of :func:`itertools.accumulate`. By default
- it will compute the first difference of *iterable* using
- :func:`operator.sub`:
-
- >>> from itertools import accumulate
- >>> iterable = accumulate([0, 1, 2, 3, 4]) # produces 0, 1, 3, 6, 10
- >>> list(difference(iterable))
- [0, 1, 2, 3, 4]
-
- *func* defaults to :func:`operator.sub`, but other functions can be
- specified. They will be applied as follows::
-
- A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...
-
- For example, to do progressive division:
-
- >>> iterable = [1, 2, 6, 24, 120]
- >>> func = lambda x, y: x // y
- >>> list(difference(iterable, func))
- [1, 2, 3, 4, 5]
-
- If the *initial* keyword is set, the first element will be skipped when
- computing successive differences.
-
- >>> it = [10, 11, 13, 16] # from accumulate([1, 2, 3], initial=10)
- >>> list(difference(it, initial=10))
- [1, 2, 3]
-
- """
- a, b = tee(iterable)
- try:
- first = [next(b)]
- except StopIteration:
- return iter([])
-
- if initial is not None:
- first = []
-
- return chain(first, starmap(func, zip(b, a)))
-
-
-class SequenceView(Sequence):
- """Return a read-only view of the sequence object *target*.
-
- :class:`SequenceView` objects are analogous to Python's built-in
- "dictionary view" types. They provide a dynamic view of a sequence's items,
- meaning that when the sequence updates, so does the view.
-
- >>> seq = ['0', '1', '2']
- >>> view = SequenceView(seq)
- >>> view
- SequenceView(['0', '1', '2'])
- >>> seq.append('3')
- >>> view
- SequenceView(['0', '1', '2', '3'])
-
- Sequence views support indexing, slicing, and length queries. They act
- like the underlying sequence, except they don't allow assignment:
-
- >>> view[1]
- '1'
- >>> view[1:-1]
- ['1', '2']
- >>> len(view)
- 4
-
- Sequence views are useful as an alternative to copying, as they don't
- require (much) extra storage.
-
- """
-
- def __init__(self, target):
- if not isinstance(target, Sequence):
- raise TypeError
- self._target = target
-
- def __getitem__(self, index):
- return self._target[index]
-
- def __len__(self):
- return len(self._target)
-
- def __repr__(self):
- return '{}({})'.format(self.__class__.__name__, repr(self._target))
-
-
-class seekable:
- """Wrap an iterator to allow for seeking backward and forward. This
- progressively caches the items in the source iterable so they can be
- re-visited.
-
- Call :meth:`seek` with an index to seek to that position in the source
- iterable.
-
- To "reset" an iterator, seek to ``0``:
-
- >>> from itertools import count
- >>> it = seekable((str(n) for n in count()))
- >>> next(it), next(it), next(it)
- ('0', '1', '2')
- >>> it.seek(0)
- >>> next(it), next(it), next(it)
- ('0', '1', '2')
- >>> next(it)
- '3'
-
- You can also seek forward:
-
- >>> it = seekable((str(n) for n in range(20)))
- >>> it.seek(10)
- >>> next(it)
- '10'
- >>> it.seek(20) # Seeking past the end of the source isn't a problem
- >>> list(it)
- []
- >>> it.seek(0) # Resetting works even after hitting the end
- >>> next(it), next(it), next(it)
- ('0', '1', '2')
-
- Call :meth:`peek` to look ahead one item without advancing the iterator:
-
- >>> it = seekable('1234')
- >>> it.peek()
- '1'
- >>> list(it)
- ['1', '2', '3', '4']
- >>> it.peek(default='empty')
- 'empty'
-
- Before the iterator is at its end, calling :func:`bool` on it will return
- ``True``. After it will return ``False``:
-
- >>> it = seekable('5678')
- >>> bool(it)
- True
- >>> list(it)
- ['5', '6', '7', '8']
- >>> bool(it)
- False
-
- You may view the contents of the cache with the :meth:`elements` method.
- That returns a :class:`SequenceView`, a view that updates automatically:
-
- >>> it = seekable((str(n) for n in range(10)))
- >>> next(it), next(it), next(it)
- ('0', '1', '2')
- >>> elements = it.elements()
- >>> elements
- SequenceView(['0', '1', '2'])
- >>> next(it)
- '3'
- >>> elements
- SequenceView(['0', '1', '2', '3'])
-
- By default, the cache grows as the source iterable progresses, so beware of
- wrapping very large or infinite iterables. Supply *maxlen* to limit the
- size of the cache (this of course limits how far back you can seek).
-
- >>> from itertools import count
- >>> it = seekable((str(n) for n in count()), maxlen=2)
- >>> next(it), next(it), next(it), next(it)
- ('0', '1', '2', '3')
- >>> list(it.elements())
- ['2', '3']
- >>> it.seek(0)
- >>> next(it), next(it), next(it), next(it)
- ('2', '3', '4', '5')
- >>> next(it)
- '6'
-
- """
-
- def __init__(self, iterable, maxlen=None):
- self._source = iter(iterable)
- if maxlen is None:
- self._cache = []
- else:
- self._cache = deque([], maxlen)
- self._index = None
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if self._index is not None:
- try:
- item = self._cache[self._index]
- except IndexError:
- self._index = None
- else:
- self._index += 1
- return item
-
- item = next(self._source)
- self._cache.append(item)
- return item
-
- def __bool__(self):
- try:
- self.peek()
- except StopIteration:
- return False
- return True
-
- def peek(self, default=_marker):
- try:
- peeked = next(self)
- except StopIteration:
- if default is _marker:
- raise
- return default
- if self._index is None:
- self._index = len(self._cache)
- self._index -= 1
- return peeked
-
- def elements(self):
- return SequenceView(self._cache)
-
- def seek(self, index):
- self._index = index
- remainder = index - len(self._cache)
- if remainder > 0:
- consume(self, remainder)
-
-
-class run_length:
- """
- :func:`run_length.encode` compresses an iterable with run-length encoding.
- It yields groups of repeated items with the count of how many times they
- were repeated:
-
- >>> uncompressed = 'abbcccdddd'
- >>> list(run_length.encode(uncompressed))
- [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
-
- :func:`run_length.decode` decompresses an iterable that was previously
- compressed with run-length encoding. It yields the items of the
- decompressed iterable:
-
- >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
- >>> list(run_length.decode(compressed))
- ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']
-
- """
-
- @staticmethod
- def encode(iterable):
- return ((k, ilen(g)) for k, g in groupby(iterable))
-
- @staticmethod
- def decode(iterable):
- return chain.from_iterable(repeat(k, n) for k, n in iterable)
-
-
-def exactly_n(iterable, n, predicate=bool):
- """Return ``True`` if exactly ``n`` items in the iterable are ``True``
- according to the *predicate* function.
-
- >>> exactly_n([True, True, False], 2)
- True
- >>> exactly_n([True, True, False], 1)
- False
- >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
- True
-
- The iterable will be advanced until ``n + 1`` truthy items are encountered,
- so avoid calling it on infinite iterables.
-
- """
- return len(take(n + 1, filter(predicate, iterable))) == n
-
-
-def circular_shifts(iterable):
- """Return a list of circular shifts of *iterable*.
-
- >>> circular_shifts(range(4))
- [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]
- """
- lst = list(iterable)
- return take(len(lst), windowed(cycle(lst), len(lst)))
-
-
-def make_decorator(wrapping_func, result_index=0):
- """Return a decorator version of *wrapping_func*, which is a function that
- modifies an iterable. *result_index* is the position in that function's
- signature where the iterable goes.
-
- This lets you use itertools on the "production end," i.e. at function
- definition. This can augment what the function returns without changing the
- function's code.
-
- For example, to produce a decorator version of :func:`chunked`:
-
- >>> from more_itertools import chunked
- >>> chunker = make_decorator(chunked, result_index=0)
- >>> @chunker(3)
- ... def iter_range(n):
- ... return iter(range(n))
- ...
- >>> list(iter_range(9))
- [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
-
- To only allow truthy items to be returned:
-
- >>> truth_serum = make_decorator(filter, result_index=1)
- >>> @truth_serum(bool)
- ... def boolean_test():
- ... return [0, 1, '', ' ', False, True]
- ...
- >>> list(boolean_test())
- [1, ' ', True]
-
- The :func:`peekable` and :func:`seekable` wrappers make for practical
- decorators:
-
- >>> from more_itertools import peekable
- >>> peekable_function = make_decorator(peekable)
- >>> @peekable_function()
- ... def str_range(*args):
- ... return (str(x) for x in range(*args))
- ...
- >>> it = str_range(1, 20, 2)
- >>> next(it), next(it), next(it)
- ('1', '3', '5')
- >>> it.peek()
- '7'
- >>> next(it)
- '7'
-
- """
- # See https://sites.google.com/site/bbayles/index/decorator_factory for
- # notes on how this works.
- def decorator(*wrapping_args, **wrapping_kwargs):
- def outer_wrapper(f):
- def inner_wrapper(*args, **kwargs):
- result = f(*args, **kwargs)
- wrapping_args_ = list(wrapping_args)
- wrapping_args_.insert(result_index, result)
- return wrapping_func(*wrapping_args_, **wrapping_kwargs)
-
- return inner_wrapper
-
- return outer_wrapper
-
- return decorator
-
-
-def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
- """Return a dictionary that maps the items in *iterable* to categories
- defined by *keyfunc*, transforms them with *valuefunc*, and
- then summarizes them by category with *reducefunc*.
-
- *valuefunc* defaults to the identity function if it is unspecified.
- If *reducefunc* is unspecified, no summarization takes place:
-
- >>> keyfunc = lambda x: x.upper()
- >>> result = map_reduce('abbccc', keyfunc)
- >>> sorted(result.items())
- [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]
-
- Specifying *valuefunc* transforms the categorized items:
-
- >>> keyfunc = lambda x: x.upper()
- >>> valuefunc = lambda x: 1
- >>> result = map_reduce('abbccc', keyfunc, valuefunc)
- >>> sorted(result.items())
- [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]
-
- Specifying *reducefunc* summarizes the categorized items:
-
- >>> keyfunc = lambda x: x.upper()
- >>> valuefunc = lambda x: 1
- >>> reducefunc = sum
- >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
- >>> sorted(result.items())
- [('A', 1), ('B', 2), ('C', 3)]
-
- You may want to filter the input iterable before applying the map/reduce
- procedure:
-
- >>> all_items = range(30)
- >>> items = [x for x in all_items if 10 <= x <= 20] # Filter
- >>> keyfunc = lambda x: x % 2 # Evens map to 0; odds to 1
- >>> categories = map_reduce(items, keyfunc=keyfunc)
- >>> sorted(categories.items())
- [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
- >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
- >>> sorted(summaries.items())
- [(0, 90), (1, 75)]
-
- Note that all items in the iterable are gathered into a list before the
- summarization step, which may require significant storage.
-
- The returned object is a :obj:`collections.defaultdict` with the
- ``default_factory`` set to ``None``, such that it behaves like a normal
- dictionary.
-
- """
- valuefunc = (lambda x: x) if (valuefunc is None) else valuefunc
-
- ret = defaultdict(list)
- for item in iterable:
- key = keyfunc(item)
- value = valuefunc(item)
- ret[key].append(value)
-
- if reducefunc is not None:
- for key, value_list in ret.items():
- ret[key] = reducefunc(value_list)
-
- ret.default_factory = None
- return ret
-
-
-def rlocate(iterable, pred=bool, window_size=None):
- """Yield the index of each item in *iterable* for which *pred* returns
- ``True``, starting from the right and moving left.
-
- *pred* defaults to :func:`bool`, which will select truthy items:
-
- >>> list(rlocate([0, 1, 1, 0, 1, 0, 0])) # Truthy at 1, 2, and 4
- [4, 2, 1]
-
- Set *pred* to a custom function to, e.g., find the indexes for a particular
- item:
-
- >>> iterable = iter('abcb')
- >>> pred = lambda x: x == 'b'
- >>> list(rlocate(iterable, pred))
- [3, 1]
-
- If *window_size* is given, then the *pred* function will be called with
- that many items. This enables searching for sub-sequences:
-
- >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
- >>> pred = lambda *args: args == (1, 2, 3)
- >>> list(rlocate(iterable, pred=pred, window_size=3))
- [9, 5, 1]
-
- Beware, this function won't return anything for infinite iterables.
- If *iterable* is reversible, ``rlocate`` will reverse it and search from
- the right. Otherwise, it will search from the left and return the results
- in reverse order.
-
- See :func:`locate` to for other example applications.
-
- """
- if window_size is None:
- try:
- len_iter = len(iterable)
- return (len_iter - i - 1 for i in locate(reversed(iterable), pred))
- except TypeError:
- pass
-
- return reversed(list(locate(iterable, pred, window_size)))
-
-
-def replace(iterable, pred, substitutes, count=None, window_size=1):
- """Yield the items from *iterable*, replacing the items for which *pred*
- returns ``True`` with the items from the iterable *substitutes*.
-
- >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
- >>> pred = lambda x: x == 0
- >>> substitutes = (2, 3)
- >>> list(replace(iterable, pred, substitutes))
- [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]
-
- If *count* is given, the number of replacements will be limited:
-
- >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
- >>> pred = lambda x: x == 0
- >>> substitutes = [None]
- >>> list(replace(iterable, pred, substitutes, count=2))
- [1, 1, None, 1, 1, None, 1, 1, 0]
-
- Use *window_size* to control the number of items passed as arguments to
- *pred*. This allows for locating and replacing subsequences.
-
- >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
- >>> window_size = 3
- >>> pred = lambda *args: args == (0, 1, 2) # 3 items passed to pred
- >>> substitutes = [3, 4] # Splice in these items
- >>> list(replace(iterable, pred, substitutes, window_size=window_size))
- [3, 4, 5, 3, 4, 5]
-
- """
- if window_size < 1:
- raise ValueError('window_size must be at least 1')
-
- # Save the substitutes iterable, since it's used more than once
- substitutes = tuple(substitutes)
-
- # Add padding such that the number of windows matches the length of the
- # iterable
- it = chain(iterable, [_marker] * (window_size - 1))
- windows = windowed(it, window_size)
-
- n = 0
- for w in windows:
- # If the current window matches our predicate (and we haven't hit
- # our maximum number of replacements), splice in the substitutes
- # and then consume the following windows that overlap with this one.
- # For example, if the iterable is (0, 1, 2, 3, 4...)
- # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
- # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
- if pred(*w):
- if (count is None) or (n < count):
- n += 1
- yield from substitutes
- consume(windows, window_size - 1)
- continue
-
- # If there was no match (or we've reached the replacement limit),
- # yield the first item from the window.
- if w and (w[0] is not _marker):
- yield w[0]
-
-
-def partitions(iterable):
- """Yield all possible order-preserving partitions of *iterable*.
-
- >>> iterable = 'abc'
- >>> for part in partitions(iterable):
- ... print([''.join(p) for p in part])
- ['abc']
- ['a', 'bc']
- ['ab', 'c']
- ['a', 'b', 'c']
-
- This is unrelated to :func:`partition`.
-
- """
- sequence = list(iterable)
- n = len(sequence)
- for i in powerset(range(1, n)):
- yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))]
-
-
-def set_partitions(iterable, k=None):
- """
- Yield the set partitions of *iterable* into *k* parts. Set partitions are
- not order-preserving.
-
- >>> iterable = 'abc'
- >>> for part in set_partitions(iterable, 2):
- ... print([''.join(p) for p in part])
- ['a', 'bc']
- ['ab', 'c']
- ['b', 'ac']
-
-
- If *k* is not given, every set partition is generated.
-
- >>> iterable = 'abc'
- >>> for part in set_partitions(iterable):
- ... print([''.join(p) for p in part])
- ['abc']
- ['a', 'bc']
- ['ab', 'c']
- ['b', 'ac']
- ['a', 'b', 'c']
-
- """
- L = list(iterable)
- n = len(L)
- if k is not None:
- if k < 1:
- raise ValueError(
- "Can't partition in a negative or zero number of groups"
- )
- elif k > n:
- return
-
- def set_partitions_helper(L, k):
- n = len(L)
- if k == 1:
- yield [L]
- elif n == k:
- yield [[s] for s in L]
- else:
- e, *M = L
- for p in set_partitions_helper(M, k - 1):
- yield [[e], *p]
- for p in set_partitions_helper(M, k):
- for i in range(len(p)):
- yield p[:i] + [[e] + p[i]] + p[i + 1 :]
-
- if k is None:
- for k in range(1, n + 1):
- yield from set_partitions_helper(L, k)
- else:
- yield from set_partitions_helper(L, k)
-
-
-class time_limited:
- """
- Yield items from *iterable* until *limit_seconds* have passed.
- If the time limit expires before all items have been yielded, the
- ``timed_out`` parameter will be set to ``True``.
-
- >>> from time import sleep
- >>> def generator():
- ... yield 1
- ... yield 2
- ... sleep(0.2)
- ... yield 3
- >>> iterable = time_limited(0.1, generator())
- >>> list(iterable)
- [1, 2]
- >>> iterable.timed_out
- True
-
- Note that the time is checked before each item is yielded, and iteration
- stops if the time elapsed is greater than *limit_seconds*. If your time
- limit is 1 second, but it takes 2 seconds to generate the first item from
- the iterable, the function will run for 2 seconds and not yield anything.
-
- """
-
- def __init__(self, limit_seconds, iterable):
- if limit_seconds < 0:
- raise ValueError('limit_seconds must be positive')
- self.limit_seconds = limit_seconds
- self._iterable = iter(iterable)
- self._start_time = monotonic()
- self.timed_out = False
-
- def __iter__(self):
- return self
-
- def __next__(self):
- item = next(self._iterable)
- if monotonic() - self._start_time > self.limit_seconds:
- self.timed_out = True
- raise StopIteration
-
- return item
-
-
-def only(iterable, default=None, too_long=None):
- """If *iterable* has only one item, return it.
- If it has zero items, return *default*.
- If it has more than one item, raise the exception given by *too_long*,
- which is ``ValueError`` by default.
-
- >>> only([], default='missing')
- 'missing'
- >>> only([1])
- 1
- >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError: Expected exactly one item in iterable, but got 1, 2,
- and perhaps more.'
- >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- TypeError
-
- Note that :func:`only` attempts to advance *iterable* twice to ensure there
- is only one item. See :func:`spy` or :func:`peekable` to check
- iterable contents less destructively.
- """
- it = iter(iterable)
- first_value = next(it, default)
-
- try:
- second_value = next(it)
- except StopIteration:
- pass
- else:
- msg = (
- 'Expected exactly one item in iterable, but got {!r}, {!r}, '
- 'and perhaps more.'.format(first_value, second_value)
- )
- raise too_long or ValueError(msg)
-
- return first_value
-
-
-def ichunked(iterable, n):
- """Break *iterable* into sub-iterables with *n* elements each.
- :func:`ichunked` is like :func:`chunked`, but it yields iterables
- instead of lists.
-
- If the sub-iterables are read in order, the elements of *iterable*
- won't be stored in memory.
- If they are read out of order, :func:`itertools.tee` is used to cache
- elements as necessary.
-
- >>> from itertools import count
- >>> all_chunks = ichunked(count(), 4)
- >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
- >>> list(c_2) # c_1's elements have been cached; c_3's haven't been
- [4, 5, 6, 7]
- >>> list(c_1)
- [0, 1, 2, 3]
- >>> list(c_3)
- [8, 9, 10, 11]
-
- """
- source = iter(iterable)
-
- while True:
- # Check to see whether we're at the end of the source iterable
- item = next(source, _marker)
- if item is _marker:
- return
-
- # Clone the source and yield an n-length slice
- source, it = tee(chain([item], source))
- yield islice(it, n)
-
- # Advance the source iterable
- consume(source, n)
-
-
-def distinct_combinations(iterable, r):
- """Yield the distinct combinations of *r* items taken from *iterable*.
-
- >>> list(distinct_combinations([0, 0, 1], 2))
- [(0, 0), (0, 1)]
-
- Equivalent to ``set(combinations(iterable))``, except duplicates are not
- generated and thrown away. For larger input sequences this is much more
- efficient.
-
- """
- if r < 0:
- raise ValueError('r must be non-negative')
- elif r == 0:
- yield ()
- return
- pool = tuple(iterable)
- generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
- current_combo = [None] * r
- level = 0
- while generators:
- try:
- cur_idx, p = next(generators[-1])
- except StopIteration:
- generators.pop()
- level -= 1
- continue
- current_combo[level] = p
- if level + 1 == r:
- yield tuple(current_combo)
- else:
- generators.append(
- unique_everseen(
- enumerate(pool[cur_idx + 1 :], cur_idx + 1),
- key=itemgetter(1),
- )
- )
- level += 1
-
-
-def filter_except(validator, iterable, *exceptions):
- """Yield the items from *iterable* for which the *validator* function does
- not raise one of the specified *exceptions*.
-
- *validator* is called for each item in *iterable*.
- It should be a function that accepts one argument and raises an exception
- if that item is not valid.
-
- >>> iterable = ['1', '2', 'three', '4', None]
- >>> list(filter_except(int, iterable, ValueError, TypeError))
- ['1', '2', '4']
-
- If an exception other than one given by *exceptions* is raised by
- *validator*, it is raised like normal.
- """
- for item in iterable:
- try:
- validator(item)
- except exceptions:
- pass
- else:
- yield item
-
-
-def map_except(function, iterable, *exceptions):
- """Transform each item from *iterable* with *function* and yield the
- result, unless *function* raises one of the specified *exceptions*.
-
- *function* is called to transform each item in *iterable*.
- It should accept one argument.
-
- >>> iterable = ['1', '2', 'three', '4', None]
- >>> list(map_except(int, iterable, ValueError, TypeError))
- [1, 2, 4]
-
- If an exception other than one given by *exceptions* is raised by
- *function*, it is raised like normal.
- """
- for item in iterable:
- try:
- yield function(item)
- except exceptions:
- pass
-
-
-def map_if(iterable, pred, func, func_else=lambda x: x):
- """Evaluate each item from *iterable* using *pred*. If the result is
- equivalent to ``True``, transform the item with *func* and yield it.
- Otherwise, transform the item with *func_else* and yield it.
-
- *pred*, *func*, and *func_else* should each be functions that accept
- one argument. By default, *func_else* is the identity function.
-
- >>> from math import sqrt
- >>> iterable = list(range(-5, 5))
- >>> iterable
- [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
- >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
- [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
- >>> list(map_if(iterable, lambda x: x >= 0,
- ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
- [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
- """
- for item in iterable:
- yield func(item) if pred(item) else func_else(item)
-
-
-def _sample_unweighted(iterable, k):
- # Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li:
- # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
-
- # Fill up the reservoir (collection of samples) with the first `k` samples
- reservoir = take(k, iterable)
-
- # Generate random number that's the largest in a sample of k U(0,1) numbers
- # Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic
- W = exp(log(random()) / k)
-
- # The number of elements to skip before changing the reservoir is a random
- # number with a geometric distribution. Sample it using random() and logs.
- next_index = k + floor(log(random()) / log(1 - W))
-
- for index, element in enumerate(iterable, k):
-
- if index == next_index:
- reservoir[randrange(k)] = element
- # The new W is the largest in a sample of k U(0, `old_W`) numbers
- W *= exp(log(random()) / k)
- next_index += floor(log(random()) / log(1 - W)) + 1
-
- return reservoir
-
-
-def _sample_weighted(iterable, k, weights):
- # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
- # "Weighted random sampling with a reservoir".
-
- # Log-transform for numerical stability for weights that are small/large
- weight_keys = (log(random()) / weight for weight in weights)
-
- # Fill up the reservoir (collection of samples) with the first `k`
- # weight-keys and elements, then heapify the list.
- reservoir = take(k, zip(weight_keys, iterable))
- heapify(reservoir)
-
- # The number of jumps before changing the reservoir is a random variable
- # with an exponential distribution. Sample it using random() and logs.
- smallest_weight_key, _ = reservoir[0]
- weights_to_skip = log(random()) / smallest_weight_key
-
- for weight, element in zip(weights, iterable):
- if weight >= weights_to_skip:
- # The notation here is consistent with the paper, but we store
- # the weight-keys in log-space for better numerical stability.
- smallest_weight_key, _ = reservoir[0]
- t_w = exp(weight * smallest_weight_key)
- r_2 = uniform(t_w, 1) # generate U(t_w, 1)
- weight_key = log(r_2) / weight
- heapreplace(reservoir, (weight_key, element))
- smallest_weight_key, _ = reservoir[0]
- weights_to_skip = log(random()) / smallest_weight_key
- else:
- weights_to_skip -= weight
-
- # Equivalent to [element for weight_key, element in sorted(reservoir)]
- return [heappop(reservoir)[1] for _ in range(k)]
-
-
-def sample(iterable, k, weights=None):
- """Return a *k*-length list of elements chosen (without replacement)
- from the *iterable*. Like :func:`random.sample`, but works on iterables
- of unknown length.
-
- >>> iterable = range(100)
- >>> sample(iterable, 5) # doctest: +SKIP
- [81, 60, 96, 16, 4]
-
- An iterable with *weights* may also be given:
-
- >>> iterable = range(100)
- >>> weights = (i * i + 1 for i in range(100))
- >>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP
- [79, 67, 74, 66, 78]
-
- The algorithm can also be used to generate weighted random permutations.
- The relative weight of each item determines the probability that it
- appears late in the permutation.
-
- >>> data = "abcdefgh"
- >>> weights = range(1, len(data) + 1)
- >>> sample(data, k=len(data), weights=weights) # doctest: +SKIP
- ['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f']
- """
- if k == 0:
- return []
-
- iterable = iter(iterable)
- if weights is None:
- return _sample_unweighted(iterable, k)
- else:
- weights = iter(weights)
- return _sample_weighted(iterable, k, weights)
-
-
-def is_sorted(iterable, key=None, reverse=False, strict=False):
- """Returns ``True`` if the items of iterable are in sorted order, and
- ``False`` otherwise. *key* and *reverse* have the same meaning that they do
- in the built-in :func:`sorted` function.
-
- >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
- True
- >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
- False
-
- If *strict*, tests for strict sorting, that is, returns ``False`` if equal
- elements are found:
-
- >>> is_sorted([1, 2, 2])
- True
- >>> is_sorted([1, 2, 2], strict=True)
- False
-
- The function returns ``False`` after encountering the first out-of-order
- item. If there are no out-of-order items, the iterable is exhausted.
- """
-
- compare = (le if reverse else ge) if strict else (lt if reverse else gt)
- it = iterable if key is None else map(key, iterable)
- return not any(starmap(compare, pairwise(it)))
-
-
-class AbortThread(BaseException):
- pass
-
-
-class callback_iter:
- """Convert a function that uses callbacks to an iterator.
-
- Let *func* be a function that takes a `callback` keyword argument.
- For example:
-
- >>> def func(callback=None):
- ... for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
- ... if callback:
- ... callback(i, c)
- ... return 4
-
-
- Use ``with callback_iter(func)`` to get an iterator over the parameters
- that are delivered to the callback.
-
- >>> with callback_iter(func) as it:
- ... for args, kwargs in it:
- ... print(args)
- (1, 'a')
- (2, 'b')
- (3, 'c')
-
- The function will be called in a background thread. The ``done`` property
- indicates whether it has completed execution.
-
- >>> it.done
- True
-
- If it completes successfully, its return value will be available
- in the ``result`` property.
-
- >>> it.result
- 4
-
- Notes:
-
- * If the function uses some keyword argument besides ``callback``, supply
- *callback_kwd*.
- * If it finished executing, but raised an exception, accessing the
- ``result`` property will raise the same exception.
- * If it hasn't finished executing, accessing the ``result``
- property from within the ``with`` block will raise ``RuntimeError``.
- * If it hasn't finished executing, accessing the ``result`` property from
- outside the ``with`` block will raise a
- ``more_itertools.AbortThread`` exception.
- * Provide *wait_seconds* to adjust how frequently the it is polled for
- output.
-
- """
-
- def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
- self._func = func
- self._callback_kwd = callback_kwd
- self._aborted = False
- self._future = None
- self._wait_seconds = wait_seconds
- self._executor = ThreadPoolExecutor(max_workers=1)
- self._iterator = self._reader()
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- self._aborted = True
- self._executor.shutdown()
-
- def __iter__(self):
- return self
-
- def __next__(self):
- return next(self._iterator)
-
- @property
- def done(self):
- if self._future is None:
- return False
- return self._future.done()
-
- @property
- def result(self):
- if not self.done:
- raise RuntimeError('Function has not yet completed')
-
- return self._future.result()
-
- def _reader(self):
- q = Queue()
-
- def callback(*args, **kwargs):
- if self._aborted:
- raise AbortThread('canceled by user')
-
- q.put((args, kwargs))
-
- self._future = self._executor.submit(
- self._func, **{self._callback_kwd: callback}
- )
-
- while True:
- try:
- item = q.get(timeout=self._wait_seconds)
- except Empty:
- pass
- else:
- q.task_done()
- yield item
-
- if self._future.done():
- break
-
- remaining = []
- while True:
- try:
- item = q.get_nowait()
- except Empty:
- break
- else:
- q.task_done()
- remaining.append(item)
- q.join()
- yield from remaining
-
-
-def windowed_complete(iterable, n):
- """
- Yield ``(beginning, middle, end)`` tuples, where:
-
- * Each ``middle`` has *n* items from *iterable*
- * Each ``beginning`` has the items before the ones in ``middle``
- * Each ``end`` has the items after the ones in ``middle``
-
- >>> iterable = range(7)
- >>> n = 3
- >>> for beginning, middle, end in windowed_complete(iterable, n):
- ... print(beginning, middle, end)
- () (0, 1, 2) (3, 4, 5, 6)
- (0,) (1, 2, 3) (4, 5, 6)
- (0, 1) (2, 3, 4) (5, 6)
- (0, 1, 2) (3, 4, 5) (6,)
- (0, 1, 2, 3) (4, 5, 6) ()
-
- Note that *n* must be at least 0 and most equal to the length of
- *iterable*.
-
- This function will exhaust the iterable and may require significant
- storage.
- """
- if n < 0:
- raise ValueError('n must be >= 0')
-
- seq = tuple(iterable)
- size = len(seq)
-
- if n > size:
- raise ValueError('n must be <= len(seq)')
-
- for i in range(size - n + 1):
- beginning = seq[:i]
- middle = seq[i : i + n]
- end = seq[i + n :]
- yield beginning, middle, end
-
-
-def all_unique(iterable, key=None):
- """
- Returns ``True`` if all the elements of *iterable* are unique (no two
- elements are equal).
-
- >>> all_unique('ABCB')
- False
-
- If a *key* function is specified, it will be used to make comparisons.
-
- >>> all_unique('ABCb')
- True
- >>> all_unique('ABCb', str.lower)
- False
-
- The function returns as soon as the first non-unique element is
- encountered. Iterables with a mix of hashable and unhashable items can
- be used, but the function will be slower for unhashable items.
- """
- seenset = set()
- seenset_add = seenset.add
- seenlist = []
- seenlist_add = seenlist.append
- for element in map(key, iterable) if key else iterable:
- try:
- if element in seenset:
- return False
- seenset_add(element)
- except TypeError:
- if element in seenlist:
- return False
- seenlist_add(element)
- return True
-
-
-def nth_product(index, *args):
- """Equivalent to ``list(product(*args))[index]``.
-
- The products of *args* can be ordered lexicographically.
- :func:`nth_product` computes the product at sort position *index* without
- computing the previous products.
-
- >>> nth_product(8, range(2), range(2), range(2), range(2))
- (1, 0, 0, 0)
-
- ``IndexError`` will be raised if the given *index* is invalid.
- """
- pools = list(map(tuple, reversed(args)))
- ns = list(map(len, pools))
-
- c = reduce(mul, ns)
-
- if index < 0:
- index += c
-
- if not 0 <= index < c:
- raise IndexError
-
- result = []
- for pool, n in zip(pools, ns):
- result.append(pool[index % n])
- index //= n
-
- return tuple(reversed(result))
-
-
-def nth_permutation(iterable, r, index):
- """Equivalent to ``list(permutations(iterable, r))[index]```
-
- The subsequences of *iterable* that are of length *r* where order is
- important can be ordered lexicographically. :func:`nth_permutation`
- computes the subsequence at sort position *index* directly, without
- computing the previous subsequences.
-
- >>> nth_permutation('ghijk', 2, 5)
- ('h', 'i')
-
- ``ValueError`` will be raised If *r* is negative or greater than the length
- of *iterable*.
- ``IndexError`` will be raised if the given *index* is invalid.
- """
- pool = list(iterable)
- n = len(pool)
-
- if r is None or r == n:
- r, c = n, factorial(n)
- elif not 0 <= r < n:
- raise ValueError
- else:
- c = factorial(n) // factorial(n - r)
-
- if index < 0:
- index += c
-
- if not 0 <= index < c:
- raise IndexError
-
- if c == 0:
- return tuple()
-
- result = [0] * r
- q = index * factorial(n) // c if r < n else index
- for d in range(1, n + 1):
- q, i = divmod(q, d)
- if 0 <= n - d < r:
- result[n - d] = i
- if q == 0:
- break
-
- return tuple(map(pool.pop, result))
-
-
-def value_chain(*args):
- """Yield all arguments passed to the function in the same order in which
- they were passed. If an argument itself is iterable then iterate over its
- values.
-
- >>> list(value_chain(1, 2, 3, [4, 5, 6]))
- [1, 2, 3, 4, 5, 6]
-
- Binary and text strings are not considered iterable and are emitted
- as-is:
-
- >>> list(value_chain('12', '34', ['56', '78']))
- ['12', '34', '56', '78']
-
-
- Multiple levels of nesting are not flattened.
-
- """
- for value in args:
- if isinstance(value, (str, bytes)):
- yield value
- continue
- try:
- yield from value
- except TypeError:
- yield value
-
-
-def product_index(element, *args):
- """Equivalent to ``list(product(*args)).index(element)``
-
- The products of *args* can be ordered lexicographically.
- :func:`product_index` computes the first index of *element* without
- computing the previous products.
-
- >>> product_index([8, 2], range(10), range(5))
- 42
-
- ``ValueError`` will be raised if the given *element* isn't in the product
- of *args*.
- """
- index = 0
-
- for x, pool in zip_longest(element, args, fillvalue=_marker):
- if x is _marker or pool is _marker:
- raise ValueError('element is not a product of args')
-
- pool = tuple(pool)
- index = index * len(pool) + pool.index(x)
-
- return index
-
-
-def combination_index(element, iterable):
- """Equivalent to ``list(combinations(iterable, r)).index(element)``
-
- The subsequences of *iterable* that are of length *r* can be ordered
- lexicographically. :func:`combination_index` computes the index of the
- first *element*, without computing the previous combinations.
-
- >>> combination_index('adf', 'abcdefg')
- 10
-
- ``ValueError`` will be raised if the given *element* isn't one of the
- combinations of *iterable*.
- """
- element = enumerate(element)
- k, y = next(element, (None, None))
- if k is None:
- return 0
-
- indexes = []
- pool = enumerate(iterable)
- for n, x in pool:
- if x == y:
- indexes.append(n)
- tmp, y = next(element, (None, None))
- if tmp is None:
- break
- else:
- k = tmp
- else:
- raise ValueError('element is not a combination of iterable')
-
- n, _ = last(pool, default=(n, None))
-
- # Python versiosn below 3.8 don't have math.comb
- index = 1
- for i, j in enumerate(reversed(indexes), start=1):
- j = n - j
- if i <= j:
- index += factorial(j) // (factorial(i) * factorial(j - i))
-
- return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index
-
-
-def permutation_index(element, iterable):
- """Equivalent to ``list(permutations(iterable, r)).index(element)```
-
- The subsequences of *iterable* that are of length *r* where order is
- important can be ordered lexicographically. :func:`permutation_index`
- computes the index of the first *element* directly, without computing
- the previous permutations.
-
- >>> permutation_index([1, 3, 2], range(5))
- 19
-
- ``ValueError`` will be raised if the given *element* isn't one of the
- permutations of *iterable*.
- """
- index = 0
- pool = list(iterable)
- for i, x in zip(range(len(pool), -1, -1), element):
- r = pool.index(x)
- index = index * i + r
- del pool[r]
-
- return index
-
-
-class countable:
- """Wrap *iterable* and keep a count of how many items have been consumed.
-
- The ``items_seen`` attribute starts at ``0`` and increments as the iterable
- is consumed:
-
- >>> iterable = map(str, range(10))
- >>> it = countable(iterable)
- >>> it.items_seen
- 0
- >>> next(it), next(it)
- ('0', '1')
- >>> list(it)
- ['2', '3', '4', '5', '6', '7', '8', '9']
- >>> it.items_seen
- 10
- """
-
- def __init__(self, iterable):
- self._it = iter(iterable)
- self.items_seen = 0
-
- def __iter__(self):
- return self
-
- def __next__(self):
- item = next(self._it)
- self.items_seen += 1
-
- return item
-
-
-def chunked_even(iterable, n):
- """Break *iterable* into lists of approximately length *n*.
- Items are distributed such the lengths of the lists differ by at most
- 1 item.
-
- >>> iterable = [1, 2, 3, 4, 5, 6, 7]
- >>> n = 3
- >>> list(chunked_even(iterable, n)) # List lengths: 3, 2, 2
- [[1, 2, 3], [4, 5], [6, 7]]
- >>> list(chunked(iterable, n)) # List lengths: 3, 3, 1
- [[1, 2, 3], [4, 5, 6], [7]]
-
- """
-
- len_method = getattr(iterable, '__len__', None)
-
- if len_method is None:
- return _chunked_even_online(iterable, n)
- else:
- return _chunked_even_finite(iterable, len_method(), n)
-
-
-def _chunked_even_online(iterable, n):
- buffer = []
- maxbuf = n + (n - 2) * (n - 1)
- for x in iterable:
- buffer.append(x)
- if len(buffer) == maxbuf:
- yield buffer[:n]
- buffer = buffer[n:]
- yield from _chunked_even_finite(buffer, len(buffer), n)
-
-
-def _chunked_even_finite(iterable, N, n):
- if N < 1:
- return
-
- # Lists are either size `full_size <= n` or `partial_size = full_size - 1`
- q, r = divmod(N, n)
- num_lists = q + (1 if r > 0 else 0)
- q, r = divmod(N, num_lists)
- full_size = q + (1 if r > 0 else 0)
- partial_size = full_size - 1
- num_full = N - partial_size * num_lists
- num_partial = num_lists - num_full
-
- buffer = []
- iterator = iter(iterable)
-
- # Yield num_full lists of full_size
- for x in iterator:
- buffer.append(x)
- if len(buffer) == full_size:
- yield buffer
- buffer = []
- num_full -= 1
- if num_full <= 0:
- break
-
- # Yield num_partial lists of partial_size
- for x in iterator:
- buffer.append(x)
- if len(buffer) == partial_size:
- yield buffer
- buffer = []
- num_partial -= 1
-
-
-def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
- """A version of :func:`zip` that "broadcasts" any scalar
- (i.e., non-iterable) items into output tuples.
-
- >>> iterable_1 = [1, 2, 3]
- >>> iterable_2 = ['a', 'b', 'c']
- >>> scalar = '_'
- >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
- [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
-
- The *scalar_types* keyword argument determines what types are considered
- scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
- treat strings and byte strings as iterable:
-
- >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
- [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
-
- If the *strict* keyword argument is ``True``, then
- ``UnequalIterablesError`` will be raised if any of the iterables have
- different lengthss.
- """
-
- def is_scalar(obj):
- if scalar_types and isinstance(obj, scalar_types):
- return True
- try:
- iter(obj)
- except TypeError:
- return True
- else:
- return False
-
- size = len(objects)
- if not size:
- return
-
- iterables, iterable_positions = [], []
- scalars, scalar_positions = [], []
- for i, obj in enumerate(objects):
- if is_scalar(obj):
- scalars.append(obj)
- scalar_positions.append(i)
- else:
- iterables.append(iter(obj))
- iterable_positions.append(i)
-
- if len(scalars) == size:
- yield tuple(objects)
- return
-
- zipper = _zip_equal if strict else zip
- for item in zipper(*iterables):
- new_item = [None] * size
-
- for i, elem in zip(iterable_positions, item):
- new_item[i] = elem
-
- for i, elem in zip(scalar_positions, scalars):
- new_item[i] = elem
-
- yield tuple(new_item)
-
-
-def unique_in_window(iterable, n, key=None):
- """Yield the items from *iterable* that haven't been seen recently.
- *n* is the size of the lookback window.
-
- >>> iterable = [0, 1, 0, 2, 3, 0]
- >>> n = 3
- >>> list(unique_in_window(iterable, n))
- [0, 1, 2, 3, 0]
-
- The *key* function, if provided, will be used to determine uniqueness:
-
- >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
- ['a', 'b', 'c', 'd', 'a']
-
- The items in *iterable* must be hashable.
-
- """
- if n <= 0:
- raise ValueError('n must be greater than 0')
-
- window = deque(maxlen=n)
- uniques = set()
- use_key = key is not None
-
- for item in iterable:
- k = key(item) if use_key else item
- if k in uniques:
- continue
-
- if len(uniques) == n:
- uniques.discard(window[0])
-
- uniques.add(k)
- window.append(k)
-
- yield item
-
-
-def duplicates_everseen(iterable, key=None):
- """Yield duplicate elements after their first appearance.
-
- >>> list(duplicates_everseen('mississippi'))
- ['s', 'i', 's', 's', 'i', 'p', 'i']
- >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
- ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']
-
- This function is analagous to :func:`unique_everseen` and is subject to
- the same performance considerations.
-
- """
- seen_set = set()
- seen_list = []
- use_key = key is not None
-
- for element in iterable:
- k = key(element) if use_key else element
- try:
- if k not in seen_set:
- seen_set.add(k)
- else:
- yield element
- except TypeError:
- if k not in seen_list:
- seen_list.append(k)
- else:
- yield element
-
-
-def duplicates_justseen(iterable, key=None):
- """Yields serially-duplicate elements after their first appearance.
-
- >>> list(duplicates_justseen('mississippi'))
- ['s', 's', 'p']
- >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
- ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']
-
- This function is analagous to :func:`unique_justseen`.
-
- """
- return flatten(
- map(
- lambda group_tuple: islice_extended(group_tuple[1])[1:],
- groupby(iterable, key),
- )
- )
-
-
-def minmax(iterable_or_value, *others, key=None, default=_marker):
- """Returns both the smallest and largest items in an iterable
- or the largest of two or more arguments.
-
- >>> minmax([3, 1, 5])
- (1, 5)
-
- >>> minmax(4, 2, 6)
- (2, 6)
-
- If a *key* function is provided, it will be used to transform the input
- items for comparison.
-
- >>> minmax([5, 30], key=str) # '30' sorts before '5'
- (30, 5)
-
- If a *default* value is provided, it will be returned if there are no
- input items.
-
- >>> minmax([], default=(0, 0))
- (0, 0)
-
- Otherwise ``ValueError`` is raised.
-
- This function is based on the
- `recipe <http://code.activestate.com/recipes/577916/>`__ by
- Raymond Hettinger and takes care to minimize the number of comparisons
- performed.
- """
- iterable = (iterable_or_value, *others) if others else iterable_or_value
-
- it = iter(iterable)
-
- try:
- lo = hi = next(it)
- except StopIteration as e:
- if default is _marker:
- raise ValueError(
- '`minmax()` argument is an empty iterable. '
- 'Provide a `default` value to suppress this error.'
- ) from e
- return default
-
- # Different branches depending on the presence of key. This saves a lot
- # of unimportant copies which would slow the "key=None" branch
- # significantly down.
- if key is None:
- for x, y in zip_longest(it, it, fillvalue=lo):
- if y < x:
- x, y = y, x
- if x < lo:
- lo = x
- if hi < y:
- hi = y
-
- else:
- lo_key = hi_key = key(lo)
-
- for x, y in zip_longest(it, it, fillvalue=lo):
-
- x_key, y_key = key(x), key(y)
-
- if y_key < x_key:
- x, y, x_key, y_key = y, x, y_key, x_key
- if x_key < lo_key:
- lo, lo_key = x, x_key
- if hi_key < y_key:
- hi, hi_key = y, y_key
-
- return lo, hi
diff --git a/contrib/python/more-itertools/py3/more_itertools/more.pyi b/contrib/python/more-itertools/py3/more_itertools/more.pyi
deleted file mode 100644
index fe7d4bdd7a..0000000000
--- a/contrib/python/more-itertools/py3/more_itertools/more.pyi
+++ /dev/null
@@ -1,664 +0,0 @@
-"""Stubs for more_itertools.more"""
-
-from typing import (
- Any,
- Callable,
- Container,
- Dict,
- Generic,
- Hashable,
- Iterable,
- Iterator,
- List,
- Optional,
- Reversible,
- Sequence,
- Sized,
- Tuple,
- Union,
- TypeVar,
- type_check_only,
-)
-from types import TracebackType
-from typing_extensions import ContextManager, Protocol, Type, overload
-
-# Type and type variable definitions
-_T = TypeVar('_T')
-_T1 = TypeVar('_T1')
-_T2 = TypeVar('_T2')
-_U = TypeVar('_U')
-_V = TypeVar('_V')
-_W = TypeVar('_W')
-_T_co = TypeVar('_T_co', covariant=True)
-_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]])
-_Raisable = Union[BaseException, 'Type[BaseException]']
-
-@type_check_only
-class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ...
-
-@type_check_only
-class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ...
-
-def chunked(
- iterable: Iterable[_T], n: Optional[int], strict: bool = ...
-) -> Iterator[List[_T]]: ...
-@overload
-def first(iterable: Iterable[_T]) -> _T: ...
-@overload
-def first(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ...
-@overload
-def last(iterable: Iterable[_T]) -> _T: ...
-@overload
-def last(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ...
-@overload
-def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ...
-@overload
-def nth_or_last(
- iterable: Iterable[_T], n: int, default: _U
-) -> Union[_T, _U]: ...
-
-class peekable(Generic[_T], Iterator[_T]):
- def __init__(self, iterable: Iterable[_T]) -> None: ...
- def __iter__(self) -> peekable[_T]: ...
- def __bool__(self) -> bool: ...
- @overload
- def peek(self) -> _T: ...
- @overload
- def peek(self, default: _U) -> Union[_T, _U]: ...
- def prepend(self, *items: _T) -> None: ...
- def __next__(self) -> _T: ...
- @overload
- def __getitem__(self, index: int) -> _T: ...
- @overload
- def __getitem__(self, index: slice) -> List[_T]: ...
-
-def collate(*iterables: Iterable[_T], **kwargs: Any) -> Iterable[_T]: ...
-def consumer(func: _GenFn) -> _GenFn: ...
-def ilen(iterable: Iterable[object]) -> int: ...
-def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
-def with_iter(
- context_manager: ContextManager[Iterable[_T]],
-) -> Iterator[_T]: ...
-def one(
- iterable: Iterable[_T],
- too_short: Optional[_Raisable] = ...,
- too_long: Optional[_Raisable] = ...,
-) -> _T: ...
-def raise_(exception: _Raisable, *args: Any) -> None: ...
-def strictly_n(
- iterable: Iterable[_T],
- n: int,
- too_short: Optional[_GenFn] = ...,
- too_long: Optional[_GenFn] = ...,
-) -> List[_T]: ...
-def distinct_permutations(
- iterable: Iterable[_T], r: Optional[int] = ...
-) -> Iterator[Tuple[_T, ...]]: ...
-def intersperse(
- e: _U, iterable: Iterable[_T], n: int = ...
-) -> Iterator[Union[_T, _U]]: ...
-def unique_to_each(*iterables: Iterable[_T]) -> List[List[_T]]: ...
-@overload
-def windowed(
- seq: Iterable[_T], n: int, *, step: int = ...
-) -> Iterator[Tuple[Optional[_T], ...]]: ...
-@overload
-def windowed(
- seq: Iterable[_T], n: int, fillvalue: _U, step: int = ...
-) -> Iterator[Tuple[Union[_T, _U], ...]]: ...
-def substrings(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ...
-def substrings_indexes(
- seq: Sequence[_T], reverse: bool = ...
-) -> Iterator[Tuple[Sequence[_T], int, int]]: ...
-
-class bucket(Generic[_T, _U], Container[_U]):
- def __init__(
- self,
- iterable: Iterable[_T],
- key: Callable[[_T], _U],
- validator: Optional[Callable[[object], object]] = ...,
- ) -> None: ...
- def __contains__(self, value: object) -> bool: ...
- def __iter__(self) -> Iterator[_U]: ...
- def __getitem__(self, value: object) -> Iterator[_T]: ...
-
-def spy(
- iterable: Iterable[_T], n: int = ...
-) -> Tuple[List[_T], Iterator[_T]]: ...
-def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ...
-def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ...
-def interleave_evenly(
- iterables: List[Iterable[_T]], lengths: Optional[List[int]] = ...
-) -> Iterator[_T]: ...
-def collapse(
- iterable: Iterable[Any],
- base_type: Optional[type] = ...,
- levels: Optional[int] = ...,
-) -> Iterator[Any]: ...
-@overload
-def side_effect(
- func: Callable[[_T], object],
- iterable: Iterable[_T],
- chunk_size: None = ...,
- before: Optional[Callable[[], object]] = ...,
- after: Optional[Callable[[], object]] = ...,
-) -> Iterator[_T]: ...
-@overload
-def side_effect(
- func: Callable[[List[_T]], object],
- iterable: Iterable[_T],
- chunk_size: int,
- before: Optional[Callable[[], object]] = ...,
- after: Optional[Callable[[], object]] = ...,
-) -> Iterator[_T]: ...
-def sliced(
- seq: Sequence[_T], n: int, strict: bool = ...
-) -> Iterator[Sequence[_T]]: ...
-def split_at(
- iterable: Iterable[_T],
- pred: Callable[[_T], object],
- maxsplit: int = ...,
- keep_separator: bool = ...,
-) -> Iterator[List[_T]]: ...
-def split_before(
- iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
-) -> Iterator[List[_T]]: ...
-def split_after(
- iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
-) -> Iterator[List[_T]]: ...
-def split_when(
- iterable: Iterable[_T],
- pred: Callable[[_T, _T], object],
- maxsplit: int = ...,
-) -> Iterator[List[_T]]: ...
-def split_into(
- iterable: Iterable[_T], sizes: Iterable[Optional[int]]
-) -> Iterator[List[_T]]: ...
-@overload
-def padded(
- iterable: Iterable[_T],
- *,
- n: Optional[int] = ...,
- next_multiple: bool = ...
-) -> Iterator[Optional[_T]]: ...
-@overload
-def padded(
- iterable: Iterable[_T],
- fillvalue: _U,
- n: Optional[int] = ...,
- next_multiple: bool = ...,
-) -> Iterator[Union[_T, _U]]: ...
-@overload
-def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ...
-@overload
-def repeat_last(
- iterable: Iterable[_T], default: _U
-) -> Iterator[Union[_T, _U]]: ...
-def distribute(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ...
-@overload
-def stagger(
- iterable: Iterable[_T],
- offsets: _SizedIterable[int] = ...,
- longest: bool = ...,
-) -> Iterator[Tuple[Optional[_T], ...]]: ...
-@overload
-def stagger(
- iterable: Iterable[_T],
- offsets: _SizedIterable[int] = ...,
- longest: bool = ...,
- fillvalue: _U = ...,
-) -> Iterator[Tuple[Union[_T, _U], ...]]: ...
-
-class UnequalIterablesError(ValueError):
- def __init__(
- self, details: Optional[Tuple[int, int, int]] = ...
- ) -> None: ...
-
-@overload
-def zip_equal(__iter1: Iterable[_T1]) -> Iterator[Tuple[_T1]]: ...
-@overload
-def zip_equal(
- __iter1: Iterable[_T1], __iter2: Iterable[_T2]
-) -> Iterator[Tuple[_T1, _T2]]: ...
-@overload
-def zip_equal(
- __iter1: Iterable[_T],
- __iter2: Iterable[_T],
- __iter3: Iterable[_T],
- *iterables: Iterable[_T]
-) -> Iterator[Tuple[_T, ...]]: ...
-@overload
-def zip_offset(
- __iter1: Iterable[_T1],
- *,
- offsets: _SizedIterable[int],
- longest: bool = ...,
- fillvalue: None = None
-) -> Iterator[Tuple[Optional[_T1]]]: ...
-@overload
-def zip_offset(
- __iter1: Iterable[_T1],
- __iter2: Iterable[_T2],
- *,
- offsets: _SizedIterable[int],
- longest: bool = ...,
- fillvalue: None = None
-) -> Iterator[Tuple[Optional[_T1], Optional[_T2]]]: ...
-@overload
-def zip_offset(
- __iter1: Iterable[_T],
- __iter2: Iterable[_T],
- __iter3: Iterable[_T],
- *iterables: Iterable[_T],
- offsets: _SizedIterable[int],
- longest: bool = ...,
- fillvalue: None = None
-) -> Iterator[Tuple[Optional[_T], ...]]: ...
-@overload
-def zip_offset(
- __iter1: Iterable[_T1],
- *,
- offsets: _SizedIterable[int],
- longest: bool = ...,
- fillvalue: _U,
-) -> Iterator[Tuple[Union[_T1, _U]]]: ...
-@overload
-def zip_offset(
- __iter1: Iterable[_T1],
- __iter2: Iterable[_T2],
- *,
- offsets: _SizedIterable[int],
- longest: bool = ...,
- fillvalue: _U,
-) -> Iterator[Tuple[Union[_T1, _U], Union[_T2, _U]]]: ...
-@overload
-def zip_offset(
- __iter1: Iterable[_T],
- __iter2: Iterable[_T],
- __iter3: Iterable[_T],
- *iterables: Iterable[_T],
- offsets: _SizedIterable[int],
- longest: bool = ...,
- fillvalue: _U,
-) -> Iterator[Tuple[Union[_T, _U], ...]]: ...
-def sort_together(
- iterables: Iterable[Iterable[_T]],
- key_list: Iterable[int] = ...,
- key: Optional[Callable[..., Any]] = ...,
- reverse: bool = ...,
-) -> List[Tuple[_T, ...]]: ...
-def unzip(iterable: Iterable[Sequence[_T]]) -> Tuple[Iterator[_T], ...]: ...
-def divide(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ...
-def always_iterable(
- obj: object,
- base_type: Union[
- type, Tuple[Union[type, Tuple[Any, ...]], ...], None
- ] = ...,
-) -> Iterator[Any]: ...
-def adjacent(
- predicate: Callable[[_T], bool],
- iterable: Iterable[_T],
- distance: int = ...,
-) -> Iterator[Tuple[bool, _T]]: ...
-@overload
-def groupby_transform(
- iterable: Iterable[_T],
- keyfunc: None = None,
- valuefunc: None = None,
- reducefunc: None = None,
-) -> Iterator[Tuple[_T, Iterator[_T]]]: ...
-@overload
-def groupby_transform(
- iterable: Iterable[_T],
- keyfunc: Callable[[_T], _U],
- valuefunc: None,
- reducefunc: None,
-) -> Iterator[Tuple[_U, Iterator[_T]]]: ...
-@overload
-def groupby_transform(
- iterable: Iterable[_T],
- keyfunc: None,
- valuefunc: Callable[[_T], _V],
- reducefunc: None,
-) -> Iterable[Tuple[_T, Iterable[_V]]]: ...
-@overload
-def groupby_transform(
- iterable: Iterable[_T],
- keyfunc: Callable[[_T], _U],
- valuefunc: Callable[[_T], _V],
- reducefunc: None,
-) -> Iterable[Tuple[_U, Iterator[_V]]]: ...
-@overload
-def groupby_transform(
- iterable: Iterable[_T],
- keyfunc: None,
- valuefunc: None,
- reducefunc: Callable[[Iterator[_T]], _W],
-) -> Iterable[Tuple[_T, _W]]: ...
-@overload
-def groupby_transform(
- iterable: Iterable[_T],
- keyfunc: Callable[[_T], _U],
- valuefunc: None,
- reducefunc: Callable[[Iterator[_T]], _W],
-) -> Iterable[Tuple[_U, _W]]: ...
-@overload
-def groupby_transform(
- iterable: Iterable[_T],
- keyfunc: None,
- valuefunc: Callable[[_T], _V],
- reducefunc: Callable[[Iterable[_V]], _W],
-) -> Iterable[Tuple[_T, _W]]: ...
-@overload
-def groupby_transform(
- iterable: Iterable[_T],
- keyfunc: Callable[[_T], _U],
- valuefunc: Callable[[_T], _V],
- reducefunc: Callable[[Iterable[_V]], _W],
-) -> Iterable[Tuple[_U, _W]]: ...
-
-class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]):
- @overload
- def __init__(self, __stop: _T) -> None: ...
- @overload
- def __init__(self, __start: _T, __stop: _T) -> None: ...
- @overload
- def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ...
- def __bool__(self) -> bool: ...
- def __contains__(self, elem: object) -> bool: ...
- def __eq__(self, other: object) -> bool: ...
- @overload
- def __getitem__(self, key: int) -> _T: ...
- @overload
- def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ...
- def __hash__(self) -> int: ...
- def __iter__(self) -> Iterator[_T]: ...
- def __len__(self) -> int: ...
- def __reduce__(
- self,
- ) -> Tuple[Type[numeric_range[_T, _U]], Tuple[_T, _T, _U]]: ...
- def __repr__(self) -> str: ...
- def __reversed__(self) -> Iterator[_T]: ...
- def count(self, value: _T) -> int: ...
- def index(self, value: _T) -> int: ... # type: ignore
-
-def count_cycle(
- iterable: Iterable[_T], n: Optional[int] = ...
-) -> Iterable[Tuple[int, _T]]: ...
-def mark_ends(
- iterable: Iterable[_T],
-) -> Iterable[Tuple[bool, bool, _T]]: ...
-def locate(
- iterable: Iterable[object],
- pred: Callable[..., Any] = ...,
- window_size: Optional[int] = ...,
-) -> Iterator[int]: ...
-def lstrip(
- iterable: Iterable[_T], pred: Callable[[_T], object]
-) -> Iterator[_T]: ...
-def rstrip(
- iterable: Iterable[_T], pred: Callable[[_T], object]
-) -> Iterator[_T]: ...
-def strip(
- iterable: Iterable[_T], pred: Callable[[_T], object]
-) -> Iterator[_T]: ...
-
-class islice_extended(Generic[_T], Iterator[_T]):
- def __init__(
- self, iterable: Iterable[_T], *args: Optional[int]
- ) -> None: ...
- def __iter__(self) -> islice_extended[_T]: ...
- def __next__(self) -> _T: ...
- def __getitem__(self, index: slice) -> islice_extended[_T]: ...
-
-def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ...
-def consecutive_groups(
- iterable: Iterable[_T], ordering: Callable[[_T], int] = ...
-) -> Iterator[Iterator[_T]]: ...
-@overload
-def difference(
- iterable: Iterable[_T],
- func: Callable[[_T, _T], _U] = ...,
- *,
- initial: None = ...
-) -> Iterator[Union[_T, _U]]: ...
-@overload
-def difference(
- iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U
-) -> Iterator[_U]: ...
-
-class SequenceView(Generic[_T], Sequence[_T]):
- def __init__(self, target: Sequence[_T]) -> None: ...
- @overload
- def __getitem__(self, index: int) -> _T: ...
- @overload
- def __getitem__(self, index: slice) -> Sequence[_T]: ...
- def __len__(self) -> int: ...
-
-class seekable(Generic[_T], Iterator[_T]):
- def __init__(
- self, iterable: Iterable[_T], maxlen: Optional[int] = ...
- ) -> None: ...
- def __iter__(self) -> seekable[_T]: ...
- def __next__(self) -> _T: ...
- def __bool__(self) -> bool: ...
- @overload
- def peek(self) -> _T: ...
- @overload
- def peek(self, default: _U) -> Union[_T, _U]: ...
- def elements(self) -> SequenceView[_T]: ...
- def seek(self, index: int) -> None: ...
-
-class run_length:
- @staticmethod
- def encode(iterable: Iterable[_T]) -> Iterator[Tuple[_T, int]]: ...
- @staticmethod
- def decode(iterable: Iterable[Tuple[_T, int]]) -> Iterator[_T]: ...
-
-def exactly_n(
- iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ...
-) -> bool: ...
-def circular_shifts(iterable: Iterable[_T]) -> List[Tuple[_T, ...]]: ...
-def make_decorator(
- wrapping_func: Callable[..., _U], result_index: int = ...
-) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ...
-@overload
-def map_reduce(
- iterable: Iterable[_T],
- keyfunc: Callable[[_T], _U],
- valuefunc: None = ...,
- reducefunc: None = ...,
-) -> Dict[_U, List[_T]]: ...
-@overload
-def map_reduce(
- iterable: Iterable[_T],
- keyfunc: Callable[[_T], _U],
- valuefunc: Callable[[_T], _V],
- reducefunc: None = ...,
-) -> Dict[_U, List[_V]]: ...
-@overload
-def map_reduce(
- iterable: Iterable[_T],
- keyfunc: Callable[[_T], _U],
- valuefunc: None = ...,
- reducefunc: Callable[[List[_T]], _W] = ...,
-) -> Dict[_U, _W]: ...
-@overload
-def map_reduce(
- iterable: Iterable[_T],
- keyfunc: Callable[[_T], _U],
- valuefunc: Callable[[_T], _V],
- reducefunc: Callable[[List[_V]], _W],
-) -> Dict[_U, _W]: ...
-def rlocate(
- iterable: Iterable[_T],
- pred: Callable[..., object] = ...,
- window_size: Optional[int] = ...,
-) -> Iterator[int]: ...
-def replace(
- iterable: Iterable[_T],
- pred: Callable[..., object],
- substitutes: Iterable[_U],
- count: Optional[int] = ...,
- window_size: int = ...,
-) -> Iterator[Union[_T, _U]]: ...
-def partitions(iterable: Iterable[_T]) -> Iterator[List[List[_T]]]: ...
-def set_partitions(
- iterable: Iterable[_T], k: Optional[int] = ...
-) -> Iterator[List[List[_T]]]: ...
-
-class time_limited(Generic[_T], Iterator[_T]):
- def __init__(
- self, limit_seconds: float, iterable: Iterable[_T]
- ) -> None: ...
- def __iter__(self) -> islice_extended[_T]: ...
- def __next__(self) -> _T: ...
-
-@overload
-def only(
- iterable: Iterable[_T], *, too_long: Optional[_Raisable] = ...
-) -> Optional[_T]: ...
-@overload
-def only(
- iterable: Iterable[_T], default: _U, too_long: Optional[_Raisable] = ...
-) -> Union[_T, _U]: ...
-def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ...
-def distinct_combinations(
- iterable: Iterable[_T], r: int
-) -> Iterator[Tuple[_T, ...]]: ...
-def filter_except(
- validator: Callable[[Any], object],
- iterable: Iterable[_T],
- *exceptions: Type[BaseException]
-) -> Iterator[_T]: ...
-def map_except(
- function: Callable[[Any], _U],
- iterable: Iterable[_T],
- *exceptions: Type[BaseException]
-) -> Iterator[_U]: ...
-def map_if(
- iterable: Iterable[Any],
- pred: Callable[[Any], bool],
- func: Callable[[Any], Any],
- func_else: Optional[Callable[[Any], Any]] = ...,
-) -> Iterator[Any]: ...
-def sample(
- iterable: Iterable[_T],
- k: int,
- weights: Optional[Iterable[float]] = ...,
-) -> List[_T]: ...
-def is_sorted(
- iterable: Iterable[_T],
- key: Optional[Callable[[_T], _U]] = ...,
- reverse: bool = False,
- strict: bool = False,
-) -> bool: ...
-
-class AbortThread(BaseException):
- pass
-
-class callback_iter(Generic[_T], Iterator[_T]):
- def __init__(
- self,
- func: Callable[..., Any],
- callback_kwd: str = ...,
- wait_seconds: float = ...,
- ) -> None: ...
- def __enter__(self) -> callback_iter[_T]: ...
- def __exit__(
- self,
- exc_type: Optional[Type[BaseException]],
- exc_value: Optional[BaseException],
- traceback: Optional[TracebackType],
- ) -> Optional[bool]: ...
- def __iter__(self) -> callback_iter[_T]: ...
- def __next__(self) -> _T: ...
- def _reader(self) -> Iterator[_T]: ...
- @property
- def done(self) -> bool: ...
- @property
- def result(self) -> Any: ...
-
-def windowed_complete(
- iterable: Iterable[_T], n: int
-) -> Iterator[Tuple[_T, ...]]: ...
-def all_unique(
- iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ...
-) -> bool: ...
-def nth_product(index: int, *args: Iterable[_T]) -> Tuple[_T, ...]: ...
-def nth_permutation(
- iterable: Iterable[_T], r: int, index: int
-) -> Tuple[_T, ...]: ...
-def value_chain(*args: Union[_T, Iterable[_T]]) -> Iterable[_T]: ...
-def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ...
-def combination_index(
- element: Iterable[_T], iterable: Iterable[_T]
-) -> int: ...
-def permutation_index(
- element: Iterable[_T], iterable: Iterable[_T]
-) -> int: ...
-def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ...
-
-class countable(Generic[_T], Iterator[_T]):
- def __init__(self, iterable: Iterable[_T]) -> None: ...
- def __iter__(self) -> countable[_T]: ...
- def __next__(self) -> _T: ...
-
-def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[List[_T]]: ...
-def zip_broadcast(
- *objects: Union[_T, Iterable[_T]],
- scalar_types: Union[
- type, Tuple[Union[type, Tuple[Any, ...]], ...], None
- ] = ...,
- strict: bool = ...
-) -> Iterable[Tuple[_T, ...]]: ...
-def unique_in_window(
- iterable: Iterable[_T], n: int, key: Optional[Callable[[_T], _U]] = ...
-) -> Iterator[_T]: ...
-def duplicates_everseen(
- iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ...
-) -> Iterator[_T]: ...
-def duplicates_justseen(
- iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ...
-) -> Iterator[_T]: ...
-
-class _SupportsLessThan(Protocol):
- def __lt__(self, __other: Any) -> bool: ...
-
-_SupportsLessThanT = TypeVar("_SupportsLessThanT", bound=_SupportsLessThan)
-
-@overload
-def minmax(
- iterable_or_value: Iterable[_SupportsLessThanT], *, key: None = None
-) -> Tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
-@overload
-def minmax(
- iterable_or_value: Iterable[_T], *, key: Callable[[_T], _SupportsLessThan]
-) -> Tuple[_T, _T]: ...
-@overload
-def minmax(
- iterable_or_value: Iterable[_SupportsLessThanT],
- *,
- key: None = None,
- default: _U
-) -> Union[_U, Tuple[_SupportsLessThanT, _SupportsLessThanT]]: ...
-@overload
-def minmax(
- iterable_or_value: Iterable[_T],
- *,
- key: Callable[[_T], _SupportsLessThan],
- default: _U,
-) -> Union[_U, Tuple[_T, _T]]: ...
-@overload
-def minmax(
- iterable_or_value: _SupportsLessThanT,
- __other: _SupportsLessThanT,
- *others: _SupportsLessThanT
-) -> Tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
-@overload
-def minmax(
- iterable_or_value: _T,
- __other: _T,
- *others: _T,
- key: Callable[[_T], _SupportsLessThan]
-) -> Tuple[_T, _T]: ...
diff --git a/contrib/python/more-itertools/py3/more_itertools/recipes.py b/contrib/python/more-itertools/py3/more_itertools/recipes.py
deleted file mode 100644
index a2596423a4..0000000000
--- a/contrib/python/more-itertools/py3/more_itertools/recipes.py
+++ /dev/null
@@ -1,698 +0,0 @@
-"""Imported from the recipes section of the itertools documentation.
-
-All functions taken from the recipes section of the itertools library docs
-[1]_.
-Some backward-compatible usability improvements have been made.
-
-.. [1] http://docs.python.org/library/itertools.html#recipes
-
-"""
-import warnings
-from collections import deque
-from itertools import (
- chain,
- combinations,
- count,
- cycle,
- groupby,
- islice,
- repeat,
- starmap,
- tee,
- zip_longest,
-)
-import operator
-from random import randrange, sample, choice
-
-__all__ = [
- 'all_equal',
- 'before_and_after',
- 'consume',
- 'convolve',
- 'dotproduct',
- 'first_true',
- 'flatten',
- 'grouper',
- 'iter_except',
- 'ncycles',
- 'nth',
- 'nth_combination',
- 'padnone',
- 'pad_none',
- 'pairwise',
- 'partition',
- 'powerset',
- 'prepend',
- 'quantify',
- 'random_combination_with_replacement',
- 'random_combination',
- 'random_permutation',
- 'random_product',
- 'repeatfunc',
- 'roundrobin',
- 'sliding_window',
- 'tabulate',
- 'tail',
- 'take',
- 'triplewise',
- 'unique_everseen',
- 'unique_justseen',
-]
-
-
-def take(n, iterable):
- """Return first *n* items of the iterable as a list.
-
- >>> take(3, range(10))
- [0, 1, 2]
-
- If there are fewer than *n* items in the iterable, all of them are
- returned.
-
- >>> take(10, range(3))
- [0, 1, 2]
-
- """
- return list(islice(iterable, n))
-
-
-def tabulate(function, start=0):
- """Return an iterator over the results of ``func(start)``,
- ``func(start + 1)``, ``func(start + 2)``...
-
- *func* should be a function that accepts one integer argument.
-
- If *start* is not specified it defaults to 0. It will be incremented each
- time the iterator is advanced.
-
- >>> square = lambda x: x ** 2
- >>> iterator = tabulate(square, -3)
- >>> take(4, iterator)
- [9, 4, 1, 0]
-
- """
- return map(function, count(start))
-
-
-def tail(n, iterable):
- """Return an iterator over the last *n* items of *iterable*.
-
- >>> t = tail(3, 'ABCDEFG')
- >>> list(t)
- ['E', 'F', 'G']
-
- """
- return iter(deque(iterable, maxlen=n))
-
-
-def consume(iterator, n=None):
- """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
- entirely.
-
- Efficiently exhausts an iterator without returning values. Defaults to
- consuming the whole iterator, but an optional second argument may be
- provided to limit consumption.
-
- >>> i = (x for x in range(10))
- >>> next(i)
- 0
- >>> consume(i, 3)
- >>> next(i)
- 4
- >>> consume(i)
- >>> next(i)
- Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- StopIteration
-
- If the iterator has fewer items remaining than the provided limit, the
- whole iterator will be consumed.
-
- >>> i = (x for x in range(3))
- >>> consume(i, 5)
- >>> next(i)
- Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- StopIteration
-
- """
- # Use functions that consume iterators at C speed.
- if n is None:
- # feed the entire iterator into a zero-length deque
- deque(iterator, maxlen=0)
- else:
- # advance to the empty slice starting at position n
- next(islice(iterator, n, n), None)
-
-
-def nth(iterable, n, default=None):
- """Returns the nth item or a default value.
-
- >>> l = range(10)
- >>> nth(l, 3)
- 3
- >>> nth(l, 20, "zebra")
- 'zebra'
-
- """
- return next(islice(iterable, n, None), default)
-
-
-def all_equal(iterable):
- """
- Returns ``True`` if all the elements are equal to each other.
-
- >>> all_equal('aaaa')
- True
- >>> all_equal('aaab')
- False
-
- """
- g = groupby(iterable)
- return next(g, True) and not next(g, False)
-
-
-def quantify(iterable, pred=bool):
- """Return the how many times the predicate is true.
-
- >>> quantify([True, False, True])
- 2
-
- """
- return sum(map(pred, iterable))
-
-
-def pad_none(iterable):
- """Returns the sequence of elements and then returns ``None`` indefinitely.
-
- >>> take(5, pad_none(range(3)))
- [0, 1, 2, None, None]
-
- Useful for emulating the behavior of the built-in :func:`map` function.
-
- See also :func:`padded`.
-
- """
- return chain(iterable, repeat(None))
-
-
-padnone = pad_none
-
-
-def ncycles(iterable, n):
- """Returns the sequence elements *n* times
-
- >>> list(ncycles(["a", "b"], 3))
- ['a', 'b', 'a', 'b', 'a', 'b']
-
- """
- return chain.from_iterable(repeat(tuple(iterable), n))
-
-
-def dotproduct(vec1, vec2):
- """Returns the dot product of the two iterables.
-
- >>> dotproduct([10, 10], [20, 20])
- 400
-
- """
- return sum(map(operator.mul, vec1, vec2))
-
-
-def flatten(listOfLists):
- """Return an iterator flattening one level of nesting in a list of lists.
-
- >>> list(flatten([[0, 1], [2, 3]]))
- [0, 1, 2, 3]
-
- See also :func:`collapse`, which can flatten multiple levels of nesting.
-
- """
- return chain.from_iterable(listOfLists)
-
-
-def repeatfunc(func, times=None, *args):
- """Call *func* with *args* repeatedly, returning an iterable over the
- results.
-
- If *times* is specified, the iterable will terminate after that many
- repetitions:
-
- >>> from operator import add
- >>> times = 4
- >>> args = 3, 5
- >>> list(repeatfunc(add, times, *args))
- [8, 8, 8, 8]
-
- If *times* is ``None`` the iterable will not terminate:
-
- >>> from random import randrange
- >>> times = None
- >>> args = 1, 11
- >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
- [2, 4, 8, 1, 8, 4]
-
- """
- if times is None:
- return starmap(func, repeat(args))
- return starmap(func, repeat(args, times))
-
-
-def _pairwise(iterable):
- """Returns an iterator of paired items, overlapping, from the original
-
- >>> take(4, pairwise(count()))
- [(0, 1), (1, 2), (2, 3), (3, 4)]
-
- On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
-
- """
- a, b = tee(iterable)
- next(b, None)
- yield from zip(a, b)
-
-
-try:
- from itertools import pairwise as itertools_pairwise
-except ImportError:
- pairwise = _pairwise
-else:
-
- def pairwise(iterable):
- yield from itertools_pairwise(iterable)
-
- pairwise.__doc__ = _pairwise.__doc__
-
-
-def grouper(iterable, n, fillvalue=None):
- """Collect data into fixed-length chunks or blocks.
-
- >>> list(grouper('ABCDEFG', 3, 'x'))
- [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
-
- """
- if isinstance(iterable, int):
- warnings.warn(
- "grouper expects iterable as first parameter", DeprecationWarning
- )
- n, iterable = iterable, n
- args = [iter(iterable)] * n
- return zip_longest(fillvalue=fillvalue, *args)
-
-
-def roundrobin(*iterables):
- """Yields an item from each iterable, alternating between them.
-
- >>> list(roundrobin('ABC', 'D', 'EF'))
- ['A', 'D', 'E', 'B', 'F', 'C']
-
- This function produces the same output as :func:`interleave_longest`, but
- may perform better for some inputs (in particular when the number of
- iterables is small).
-
- """
- # Recipe credited to George Sakkis
- pending = len(iterables)
- nexts = cycle(iter(it).__next__ for it in iterables)
- while pending:
- try:
- for next in nexts:
- yield next()
- except StopIteration:
- pending -= 1
- nexts = cycle(islice(nexts, pending))
-
-
-def partition(pred, iterable):
- """
- Returns a 2-tuple of iterables derived from the input iterable.
- The first yields the items that have ``pred(item) == False``.
- The second yields the items that have ``pred(item) == True``.
-
- >>> is_odd = lambda x: x % 2 != 0
- >>> iterable = range(10)
- >>> even_items, odd_items = partition(is_odd, iterable)
- >>> list(even_items), list(odd_items)
- ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
-
- If *pred* is None, :func:`bool` is used.
-
- >>> iterable = [0, 1, False, True, '', ' ']
- >>> false_items, true_items = partition(None, iterable)
- >>> list(false_items), list(true_items)
- ([0, False, ''], [1, True, ' '])
-
- """
- if pred is None:
- pred = bool
-
- evaluations = ((pred(x), x) for x in iterable)
- t1, t2 = tee(evaluations)
- return (
- (x for (cond, x) in t1 if not cond),
- (x for (cond, x) in t2 if cond),
- )
-
-
-def powerset(iterable):
- """Yields all possible subsets of the iterable.
-
- >>> list(powerset([1, 2, 3]))
- [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
-
- :func:`powerset` will operate on iterables that aren't :class:`set`
- instances, so repeated elements in the input will produce repeated elements
- in the output. Use :func:`unique_everseen` on the input to avoid generating
- duplicates:
-
- >>> seq = [1, 1, 0]
- >>> list(powerset(seq))
- [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
- >>> from more_itertools import unique_everseen
- >>> list(powerset(unique_everseen(seq)))
- [(), (1,), (0,), (1, 0)]
-
- """
- s = list(iterable)
- return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
-
-
-def unique_everseen(iterable, key=None):
- """
- Yield unique elements, preserving order.
-
- >>> list(unique_everseen('AAAABBBCCDAABBB'))
- ['A', 'B', 'C', 'D']
- >>> list(unique_everseen('ABBCcAD', str.lower))
- ['A', 'B', 'C', 'D']
-
- Sequences with a mix of hashable and unhashable items can be used.
- The function will be slower (i.e., `O(n^2)`) for unhashable items.
-
- Remember that ``list`` objects are unhashable - you can use the *key*
- parameter to transform the list to a tuple (which is hashable) to
- avoid a slowdown.
-
- >>> iterable = ([1, 2], [2, 3], [1, 2])
- >>> list(unique_everseen(iterable)) # Slow
- [[1, 2], [2, 3]]
- >>> list(unique_everseen(iterable, key=tuple)) # Faster
- [[1, 2], [2, 3]]
-
- Similary, you may want to convert unhashable ``set`` objects with
- ``key=frozenset``. For ``dict`` objects,
- ``key=lambda x: frozenset(x.items())`` can be used.
-
- """
- seenset = set()
- seenset_add = seenset.add
- seenlist = []
- seenlist_add = seenlist.append
- use_key = key is not None
-
- for element in iterable:
- k = key(element) if use_key else element
- try:
- if k not in seenset:
- seenset_add(k)
- yield element
- except TypeError:
- if k not in seenlist:
- seenlist_add(k)
- yield element
-
-
-def unique_justseen(iterable, key=None):
- """Yields elements in order, ignoring serial duplicates
-
- >>> list(unique_justseen('AAAABBBCCDAABBB'))
- ['A', 'B', 'C', 'D', 'A', 'B']
- >>> list(unique_justseen('ABBCcAD', str.lower))
- ['A', 'B', 'C', 'A', 'D']
-
- """
- return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
-
-
-def iter_except(func, exception, first=None):
- """Yields results from a function repeatedly until an exception is raised.
-
- Converts a call-until-exception interface to an iterator interface.
- Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
- to end the loop.
-
- >>> l = [0, 1, 2]
- >>> list(iter_except(l.pop, IndexError))
- [2, 1, 0]
-
- Multiple exceptions can be specified as a stopping condition:
-
- >>> l = [1, 2, 3, '...', 4, 5, 6]
- >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
- [7, 6, 5]
- >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
- [4, 3, 2]
- >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
- []
-
- """
- try:
- if first is not None:
- yield first()
- while 1:
- yield func()
- except exception:
- pass
-
-
-def first_true(iterable, default=None, pred=None):
- """
- Returns the first true value in the iterable.
-
- If no true value is found, returns *default*
-
- If *pred* is not None, returns the first item for which
- ``pred(item) == True`` .
-
- >>> first_true(range(10))
- 1
- >>> first_true(range(10), pred=lambda x: x > 5)
- 6
- >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
- 'missing'
-
- """
- return next(filter(pred, iterable), default)
-
-
-def random_product(*args, repeat=1):
- """Draw an item at random from each of the input iterables.
-
- >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
- ('c', 3, 'Z')
-
- If *repeat* is provided as a keyword argument, that many items will be
- drawn from each iterable.
-
- >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
- ('a', 2, 'd', 3)
-
- This equivalent to taking a random selection from
- ``itertools.product(*args, **kwarg)``.
-
- """
- pools = [tuple(pool) for pool in args] * repeat
- return tuple(choice(pool) for pool in pools)
-
-
-def random_permutation(iterable, r=None):
- """Return a random *r* length permutation of the elements in *iterable*.
-
- If *r* is not specified or is ``None``, then *r* defaults to the length of
- *iterable*.
-
- >>> random_permutation(range(5)) # doctest:+SKIP
- (3, 4, 0, 1, 2)
-
- This equivalent to taking a random selection from
- ``itertools.permutations(iterable, r)``.
-
- """
- pool = tuple(iterable)
- r = len(pool) if r is None else r
- return tuple(sample(pool, r))
-
-
-def random_combination(iterable, r):
- """Return a random *r* length subsequence of the elements in *iterable*.
-
- >>> random_combination(range(5), 3) # doctest:+SKIP
- (2, 3, 4)
-
- This equivalent to taking a random selection from
- ``itertools.combinations(iterable, r)``.
-
- """
- pool = tuple(iterable)
- n = len(pool)
- indices = sorted(sample(range(n), r))
- return tuple(pool[i] for i in indices)
-
-
-def random_combination_with_replacement(iterable, r):
- """Return a random *r* length subsequence of elements in *iterable*,
- allowing individual elements to be repeated.
-
- >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
- (0, 0, 1, 2, 2)
-
- This equivalent to taking a random selection from
- ``itertools.combinations_with_replacement(iterable, r)``.
-
- """
- pool = tuple(iterable)
- n = len(pool)
- indices = sorted(randrange(n) for i in range(r))
- return tuple(pool[i] for i in indices)
-
-
-def nth_combination(iterable, r, index):
- """Equivalent to ``list(combinations(iterable, r))[index]``.
-
- The subsequences of *iterable* that are of length *r* can be ordered
- lexicographically. :func:`nth_combination` computes the subsequence at
- sort position *index* directly, without computing the previous
- subsequences.
-
- >>> nth_combination(range(5), 3, 5)
- (0, 3, 4)
-
- ``ValueError`` will be raised If *r* is negative or greater than the length
- of *iterable*.
- ``IndexError`` will be raised if the given *index* is invalid.
- """
- pool = tuple(iterable)
- n = len(pool)
- if (r < 0) or (r > n):
- raise ValueError
-
- c = 1
- k = min(r, n - r)
- for i in range(1, k + 1):
- c = c * (n - k + i) // i
-
- if index < 0:
- index += c
-
- if (index < 0) or (index >= c):
- raise IndexError
-
- result = []
- while r:
- c, n, r = c * r // n, n - 1, r - 1
- while index >= c:
- index -= c
- c, n = c * (n - r) // n, n - 1
- result.append(pool[-1 - n])
-
- return tuple(result)
-
-
-def prepend(value, iterator):
- """Yield *value*, followed by the elements in *iterator*.
-
- >>> value = '0'
- >>> iterator = ['1', '2', '3']
- >>> list(prepend(value, iterator))
- ['0', '1', '2', '3']
-
- To prepend multiple values, see :func:`itertools.chain`
- or :func:`value_chain`.
-
- """
- return chain([value], iterator)
-
-
-def convolve(signal, kernel):
- """Convolve the iterable *signal* with the iterable *kernel*.
-
- >>> signal = (1, 2, 3, 4, 5)
- >>> kernel = [3, 2, 1]
- >>> list(convolve(signal, kernel))
- [3, 8, 14, 20, 26, 14, 5]
-
- Note: the input arguments are not interchangeable, as the *kernel*
- is immediately consumed and stored.
-
- """
- kernel = tuple(kernel)[::-1]
- n = len(kernel)
- window = deque([0], maxlen=n) * n
- for x in chain(signal, repeat(0, n - 1)):
- window.append(x)
- yield sum(map(operator.mul, kernel, window))
-
-
-def before_and_after(predicate, it):
- """A variant of :func:`takewhile` that allows complete access to the
- remainder of the iterator.
-
- >>> it = iter('ABCdEfGhI')
- >>> all_upper, remainder = before_and_after(str.isupper, it)
- >>> ''.join(all_upper)
- 'ABC'
- >>> ''.join(remainder) # takewhile() would lose the 'd'
- 'dEfGhI'
-
- Note that the first iterator must be fully consumed before the second
- iterator can generate valid results.
- """
- it = iter(it)
- transition = []
-
- def true_iterator():
- for elem in it:
- if predicate(elem):
- yield elem
- else:
- transition.append(elem)
- return
-
- def remainder_iterator():
- yield from transition
- yield from it
-
- return true_iterator(), remainder_iterator()
-
-
-def triplewise(iterable):
- """Return overlapping triplets from *iterable*.
-
- >>> list(triplewise('ABCDE'))
- [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]
-
- """
- for (a, _), (b, c) in pairwise(pairwise(iterable)):
- yield a, b, c
-
-
-def sliding_window(iterable, n):
- """Return a sliding window of width *n* over *iterable*.
-
- >>> list(sliding_window(range(6), 4))
- [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
-
- If *iterable* has fewer than *n* items, then nothing is yielded:
-
- >>> list(sliding_window(range(3), 4))
- []
-
- For a variant with more features, see :func:`windowed`.
- """
- it = iter(iterable)
- window = deque(islice(it, n), maxlen=n)
- if len(window) == n:
- yield tuple(window)
- for x in it:
- window.append(x)
- yield tuple(window)
diff --git a/contrib/python/more-itertools/py3/more_itertools/recipes.pyi b/contrib/python/more-itertools/py3/more_itertools/recipes.pyi
deleted file mode 100644
index 4648a41b5e..0000000000
--- a/contrib/python/more-itertools/py3/more_itertools/recipes.pyi
+++ /dev/null
@@ -1,112 +0,0 @@
-"""Stubs for more_itertools.recipes"""
-from typing import (
- Any,
- Callable,
- Iterable,
- Iterator,
- List,
- Optional,
- Tuple,
- TypeVar,
- Union,
-)
-from typing_extensions import overload, Type
-
-# Type and type variable definitions
-_T = TypeVar('_T')
-_U = TypeVar('_U')
-
-def take(n: int, iterable: Iterable[_T]) -> List[_T]: ...
-def tabulate(
- function: Callable[[int], _T], start: int = ...
-) -> Iterator[_T]: ...
-def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ...
-def consume(iterator: Iterable[object], n: Optional[int] = ...) -> None: ...
-@overload
-def nth(iterable: Iterable[_T], n: int) -> Optional[_T]: ...
-@overload
-def nth(iterable: Iterable[_T], n: int, default: _U) -> Union[_T, _U]: ...
-def all_equal(iterable: Iterable[object]) -> bool: ...
-def quantify(
- iterable: Iterable[_T], pred: Callable[[_T], bool] = ...
-) -> int: ...
-def pad_none(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ...
-def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ...
-def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ...
-def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ...
-def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ...
-def repeatfunc(
- func: Callable[..., _U], times: Optional[int] = ..., *args: Any
-) -> Iterator[_U]: ...
-def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: ...
-@overload
-def grouper(
- iterable: Iterable[_T], n: int
-) -> Iterator[Tuple[Optional[_T], ...]]: ...
-@overload
-def grouper(
- iterable: Iterable[_T], n: int, fillvalue: _U
-) -> Iterator[Tuple[Union[_T, _U], ...]]: ...
-@overload
-def grouper( # Deprecated interface
- iterable: int, n: Iterable[_T]
-) -> Iterator[Tuple[Optional[_T], ...]]: ...
-@overload
-def grouper( # Deprecated interface
- iterable: int, n: Iterable[_T], fillvalue: _U
-) -> Iterator[Tuple[Union[_T, _U], ...]]: ...
-def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ...
-def partition(
- pred: Optional[Callable[[_T], object]], iterable: Iterable[_T]
-) -> Tuple[Iterator[_T], Iterator[_T]]: ...
-def powerset(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ...
-def unique_everseen(
- iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ...
-) -> Iterator[_T]: ...
-def unique_justseen(
- iterable: Iterable[_T], key: Optional[Callable[[_T], object]] = ...
-) -> Iterator[_T]: ...
-@overload
-def iter_except(
- func: Callable[[], _T],
- exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
- first: None = ...,
-) -> Iterator[_T]: ...
-@overload
-def iter_except(
- func: Callable[[], _T],
- exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
- first: Callable[[], _U],
-) -> Iterator[Union[_T, _U]]: ...
-@overload
-def first_true(
- iterable: Iterable[_T], *, pred: Optional[Callable[[_T], object]] = ...
-) -> Optional[_T]: ...
-@overload
-def first_true(
- iterable: Iterable[_T],
- default: _U,
- pred: Optional[Callable[[_T], object]] = ...,
-) -> Union[_T, _U]: ...
-def random_product(
- *args: Iterable[_T], repeat: int = ...
-) -> Tuple[_T, ...]: ...
-def random_permutation(
- iterable: Iterable[_T], r: Optional[int] = ...
-) -> Tuple[_T, ...]: ...
-def random_combination(iterable: Iterable[_T], r: int) -> Tuple[_T, ...]: ...
-def random_combination_with_replacement(
- iterable: Iterable[_T], r: int
-) -> Tuple[_T, ...]: ...
-def nth_combination(
- iterable: Iterable[_T], r: int, index: int
-) -> Tuple[_T, ...]: ...
-def prepend(value: _T, iterator: Iterable[_U]) -> Iterator[Union[_T, _U]]: ...
-def convolve(signal: Iterable[_T], kernel: Iterable[_T]) -> Iterator[_T]: ...
-def before_and_after(
- predicate: Callable[[_T], bool], it: Iterable[_T]
-) -> Tuple[Iterator[_T], Iterator[_T]]: ...
-def triplewise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T, _T]]: ...
-def sliding_window(
- iterable: Iterable[_T], n: int
-) -> Iterator[Tuple[_T, ...]]: ...
diff --git a/contrib/python/more-itertools/py3/patches/01-fix-tests.patch b/contrib/python/more-itertools/py3/patches/01-fix-tests.patch
deleted file mode 100644
index 497d4d8da4..0000000000
--- a/contrib/python/more-itertools/py3/patches/01-fix-tests.patch
+++ /dev/null
@@ -1,17 +0,0 @@
---- contrib/python/more-itertools/py3/tests/test_more.py (index)
-+++ contrib/python/more-itertools/py3/tests/test_more.py (working tree)
-@@ -177,13 +177,13 @@ class IterOnlyRange:
- """User-defined iterable class which only support __iter__.
-
- >>> r = IterOnlyRange(5)
-- >>> r[0]
-+ >>> r[0] # doctest: +SKIP
- AttributeError: IterOnlyRange instance has no attribute '__getitem__'
-
- Note: In Python 3, ``TypeError`` will be raised because ``object`` is
- inherited implicitly by default.
-
-- >>> r[0]
-+ >>> r[0] # doctest: +SKIP
- TypeError: 'IterOnlyRange' object does not support indexing
- """
diff --git a/contrib/python/more-itertools/py3/tests/test_more.py b/contrib/python/more-itertools/py3/tests/test_more.py
deleted file mode 100644
index 9a15025899..0000000000
--- a/contrib/python/more-itertools/py3/tests/test_more.py
+++ /dev/null
@@ -1,5033 +0,0 @@
-import warnings
-
-from collections import Counter, abc
-from collections.abc import Set
-from datetime import datetime, timedelta
-from decimal import Decimal
-from doctest import DocTestSuite
-from fractions import Fraction
-from functools import partial, reduce
-from heapq import merge
-from io import StringIO
-from itertools import (
- accumulate,
- chain,
- combinations,
- count,
- cycle,
- groupby,
- islice,
- permutations,
- product,
- repeat,
-)
-from operator import add, mul, itemgetter
-from pickle import loads, dumps
-from random import seed, Random
-from statistics import mean
-from string import ascii_letters
-from sys import version_info
-from time import sleep
-from traceback import format_exc
-from unittest import skipIf, TestCase
-
-import more_itertools as mi
-
-
-def load_tests(loader, tests, ignore):
- # Add the doctests
- tests.addTests(DocTestSuite('more_itertools.more'))
- return tests
-
-
-class CollateTests(TestCase):
- """Unit tests for ``collate()``"""
-
- # Also accidentally tests peekable, though that could use its own tests
-
- def test_default(self):
- """Test with the default `key` function."""
- iterables = [range(4), range(7), range(3, 6)]
- self.assertEqual(
- sorted(reduce(list.__add__, [list(it) for it in iterables])),
- list(mi.collate(*iterables)),
- )
-
- def test_key(self):
- """Test using a custom `key` function."""
- iterables = [range(5, 0, -1), range(4, 0, -1)]
- actual = sorted(
- reduce(list.__add__, [list(it) for it in iterables]), reverse=True
- )
- expected = list(mi.collate(*iterables, key=lambda x: -x))
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- """Be nice if passed an empty list of iterables."""
- self.assertEqual([], list(mi.collate()))
-
- def test_one(self):
- """Work when only 1 iterable is passed."""
- self.assertEqual([0, 1], list(mi.collate(range(2))))
-
- def test_reverse(self):
- """Test the `reverse` kwarg."""
- iterables = [range(4, 0, -1), range(7, 0, -1), range(3, 6, -1)]
-
- actual = sorted(
- reduce(list.__add__, [list(it) for it in iterables]), reverse=True
- )
- expected = list(mi.collate(*iterables, reverse=True))
- self.assertEqual(actual, expected)
-
- def test_alias(self):
- self.assertNotEqual(merge.__doc__, mi.collate.__doc__)
- self.assertNotEqual(partial.__doc__, mi.collate.__doc__)
-
-
-class ChunkedTests(TestCase):
- """Tests for ``chunked()``"""
-
- def test_even(self):
- """Test when ``n`` divides evenly into the length of the iterable."""
- self.assertEqual(
- list(mi.chunked('ABCDEF', 3)), [['A', 'B', 'C'], ['D', 'E', 'F']]
- )
-
- def test_odd(self):
- """Test when ``n`` does not divide evenly into the length of the
- iterable.
-
- """
- self.assertEqual(
- list(mi.chunked('ABCDE', 3)), [['A', 'B', 'C'], ['D', 'E']]
- )
-
- def test_none(self):
- """Test when ``n`` has the value ``None``."""
- self.assertEqual(
- list(mi.chunked('ABCDE', None)), [['A', 'B', 'C', 'D', 'E']]
- )
-
- def test_strict_false(self):
- """Test when ``n`` does not divide evenly into the length of the
- iterable and strict is false.
-
- """
- self.assertEqual(
- list(mi.chunked('ABCDE', 3, strict=False)),
- [['A', 'B', 'C'], ['D', 'E']],
- )
-
- def test_strict_being_true(self):
- """Test when ``n`` does not divide evenly into the length of the
- iterable and strict is True (raising an exception).
-
- """
-
- def f():
- return list(mi.chunked('ABCDE', 3, strict=True))
-
- self.assertRaisesRegex(ValueError, "iterable is not divisible by n", f)
- self.assertEqual(
- list(mi.chunked('ABCDEF', 3, strict=True)),
- [['A', 'B', 'C'], ['D', 'E', 'F']],
- )
-
- def test_strict_being_true_with_size_none(self):
- """Test when ``n`` has value ``None`` and the keyword strict is True
- (raising an exception).
-
- """
-
- def f():
- return list(mi.chunked('ABCDE', None, strict=True))
-
- self.assertRaisesRegex(
- ValueError, "n must not be None when using strict mode.", f
- )
-
-
-class FirstTests(TestCase):
- def test_many(self):
- # Also try it on a generator expression to make sure it works on
- # whatever those return, across Python versions.
- self.assertEqual(mi.first(x for x in range(4)), 0)
-
- def test_one(self):
- self.assertEqual(mi.first([3]), 3)
-
- def test_empty_stop_iteration(self):
- try:
- mi.first([])
- except ValueError:
- formatted_exc = format_exc()
- self.assertIn('StopIteration', formatted_exc)
- self.assertIn(
- 'The above exception was the direct cause', formatted_exc
- )
- else:
- self.fail()
-
- def test_default(self):
- self.assertEqual(mi.first([], 'boo'), 'boo')
-
-
-class IterOnlyRange:
- """User-defined iterable class which only support __iter__.
-
- >>> r = IterOnlyRange(5)
- >>> r[0] # doctest: +SKIP
- AttributeError: IterOnlyRange instance has no attribute '__getitem__'
-
- Note: In Python 3, ``TypeError`` will be raised because ``object`` is
- inherited implicitly by default.
-
- >>> r[0] # doctest: +SKIP
- TypeError: 'IterOnlyRange' object does not support indexing
- """
-
- def __init__(self, n):
- """Set the length of the range."""
- self.n = n
-
- def __iter__(self):
- """Works same as range()."""
- return iter(range(self.n))
-
-
-class LastTests(TestCase):
- def test_basic(self):
- cases = [
- (range(4), 3),
- (iter(range(4)), 3),
- (range(1), 0),
- (iter(range(1)), 0),
- (IterOnlyRange(5), 4),
- ({n: str(n) for n in range(5)}, 4),
- ]
- # Versions below 3.6.0 don't have ordered dicts
- if version_info >= (3, 6, 0):
- cases.append(({0: '0', -1: '-1', 2: '-2'}, 2))
-
- for iterable, expected in cases:
- with self.subTest(iterable=iterable):
- self.assertEqual(mi.last(iterable), expected)
-
- def test_default(self):
- for iterable, default, expected in [
- (range(1), None, 0),
- ([], None, None),
- ({}, None, None),
- (iter([]), None, None),
- ]:
- with self.subTest(args=(iterable, default)):
- self.assertEqual(mi.last(iterable, default=default), expected)
-
- def test_empty(self):
- for iterable in ([], iter(range(0))):
- with self.subTest(iterable=iterable):
- with self.assertRaises(ValueError):
- mi.last(iterable)
-
-
-class NthOrLastTests(TestCase):
- """Tests for ``nth_or_last()``"""
-
- def test_basic(self):
- self.assertEqual(mi.nth_or_last(range(3), 1), 1)
- self.assertEqual(mi.nth_or_last(range(3), 3), 2)
-
- def test_default_value(self):
- default = 42
- self.assertEqual(mi.nth_or_last(range(0), 3, default), default)
-
- def test_empty_iterable_no_default(self):
- self.assertRaises(ValueError, lambda: mi.nth_or_last(range(0), 0))
-
-
-class PeekableMixinTests:
- """Common tests for ``peekable()`` and ``seekable()`` behavior"""
-
- cls = None
-
- def test_passthrough(self):
- """Iterating a peekable without using ``peek()`` or ``prepend()``
- should just give the underlying iterable's elements (a trivial test but
- useful to set a baseline in case something goes wrong)"""
- expected = [1, 2, 3, 4, 5]
- actual = list(self.cls(expected))
- self.assertEqual(actual, expected)
-
- def test_peek_default(self):
- """Make sure passing a default into ``peek()`` works."""
- p = self.cls([])
- self.assertEqual(p.peek(7), 7)
-
- def test_truthiness(self):
- """Make sure a ``peekable`` tests true iff there are items remaining in
- the iterable.
-
- """
- p = self.cls([])
- self.assertFalse(p)
-
- p = self.cls(range(3))
- self.assertTrue(p)
-
- def test_simple_peeking(self):
- """Make sure ``next`` and ``peek`` advance and don't advance the
- iterator, respectively.
-
- """
- p = self.cls(range(10))
- self.assertEqual(next(p), 0)
- self.assertEqual(p.peek(), 1)
- self.assertEqual(p.peek(), 1)
- self.assertEqual(next(p), 1)
-
-
-class PeekableTests(PeekableMixinTests, TestCase):
- """Tests for ``peekable()`` behavior not incidentally covered by testing
- ``collate()``
-
- """
-
- cls = mi.peekable
-
- def test_indexing(self):
- """
- Indexing into the peekable shouldn't advance the iterator.
- """
- p = mi.peekable('abcdefghijkl')
-
- # The 0th index is what ``next()`` will return
- self.assertEqual(p[0], 'a')
- self.assertEqual(next(p), 'a')
-
- # Indexing further into the peekable shouldn't advance the itertor
- self.assertEqual(p[2], 'd')
- self.assertEqual(next(p), 'b')
-
- # The 0th index moves up with the iterator; the last index follows
- self.assertEqual(p[0], 'c')
- self.assertEqual(p[9], 'l')
-
- self.assertEqual(next(p), 'c')
- self.assertEqual(p[8], 'l')
-
- # Negative indexing should work too
- self.assertEqual(p[-2], 'k')
- self.assertEqual(p[-9], 'd')
- self.assertRaises(IndexError, lambda: p[-10])
-
- def test_slicing(self):
- """Slicing the peekable shouldn't advance the iterator."""
- seq = list('abcdefghijkl')
- p = mi.peekable(seq)
-
- # Slicing the peekable should just be like slicing a re-iterable
- self.assertEqual(p[1:4], seq[1:4])
-
- # Advancing the iterator moves the slices up also
- self.assertEqual(next(p), 'a')
- self.assertEqual(p[1:4], seq[1:][1:4])
-
- # Implicit starts and stop should work
- self.assertEqual(p[:5], seq[1:][:5])
- self.assertEqual(p[:], seq[1:][:])
-
- # Indexing past the end should work
- self.assertEqual(p[:100], seq[1:][:100])
-
- # Steps should work, including negative
- self.assertEqual(p[::2], seq[1:][::2])
- self.assertEqual(p[::-1], seq[1:][::-1])
-
- def test_slicing_reset(self):
- """Test slicing on a fresh iterable each time"""
- iterable = ['0', '1', '2', '3', '4', '5']
- indexes = list(range(-4, len(iterable) + 4)) + [None]
- steps = [1, 2, 3, 4, -1, -2, -3, 4]
- for slice_args in product(indexes, indexes, steps):
- it = iter(iterable)
- p = mi.peekable(it)
- next(p)
- index = slice(*slice_args)
- actual = p[index]
- expected = iterable[1:][index]
- self.assertEqual(actual, expected, slice_args)
-
- def test_slicing_error(self):
- iterable = '01234567'
- p = mi.peekable(iter(iterable))
-
- # Prime the cache
- p.peek()
- old_cache = list(p._cache)
-
- # Illegal slice
- with self.assertRaises(ValueError):
- p[1:-1:0]
-
- # Neither the cache nor the iteration should be affected
- self.assertEqual(old_cache, list(p._cache))
- self.assertEqual(list(p), list(iterable))
-
- # prepend() behavior tests
-
- def test_prepend(self):
- """Tests intersperesed ``prepend()`` and ``next()`` calls"""
- it = mi.peekable(range(2))
- actual = []
-
- # Test prepend() before next()
- it.prepend(10)
- actual += [next(it), next(it)]
-
- # Test prepend() between next()s
- it.prepend(11)
- actual += [next(it), next(it)]
-
- # Test prepend() after source iterable is consumed
- it.prepend(12)
- actual += [next(it)]
-
- expected = [10, 0, 11, 1, 12]
- self.assertEqual(actual, expected)
-
- def test_multi_prepend(self):
- """Tests prepending multiple items and getting them in proper order"""
- it = mi.peekable(range(5))
- actual = [next(it), next(it)]
- it.prepend(10, 11, 12)
- it.prepend(20, 21)
- actual += list(it)
- expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4]
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- """Tests prepending in front of an empty iterable"""
- it = mi.peekable([])
- it.prepend(10)
- actual = list(it)
- expected = [10]
- self.assertEqual(actual, expected)
-
- def test_prepend_truthiness(self):
- """Tests that ``__bool__()`` or ``__nonzero__()`` works properly
- with ``prepend()``"""
- it = mi.peekable(range(5))
- self.assertTrue(it)
- actual = list(it)
- self.assertFalse(it)
- it.prepend(10)
- self.assertTrue(it)
- actual += [next(it)]
- self.assertFalse(it)
- expected = [0, 1, 2, 3, 4, 10]
- self.assertEqual(actual, expected)
-
- def test_multi_prepend_peek(self):
- """Tests prepending multiple elements and getting them in reverse order
- while peeking"""
- it = mi.peekable(range(5))
- actual = [next(it), next(it)]
- self.assertEqual(it.peek(), 2)
- it.prepend(10, 11, 12)
- self.assertEqual(it.peek(), 10)
- it.prepend(20, 21)
- self.assertEqual(it.peek(), 20)
- actual += list(it)
- self.assertFalse(it)
- expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4]
- self.assertEqual(actual, expected)
-
- def test_prepend_after_stop(self):
- """Test resuming iteration after a previous exhaustion"""
- it = mi.peekable(range(3))
- self.assertEqual(list(it), [0, 1, 2])
- self.assertRaises(StopIteration, lambda: next(it))
- it.prepend(10)
- self.assertEqual(next(it), 10)
- self.assertRaises(StopIteration, lambda: next(it))
-
- def test_prepend_slicing(self):
- """Tests interaction between prepending and slicing"""
- seq = list(range(20))
- p = mi.peekable(seq)
-
- p.prepend(30, 40, 50)
- pseq = [30, 40, 50] + seq # pseq for prepended_seq
-
- # adapt the specific tests from test_slicing
- self.assertEqual(p[0], 30)
- self.assertEqual(p[1:8], pseq[1:8])
- self.assertEqual(p[1:], pseq[1:])
- self.assertEqual(p[:5], pseq[:5])
- self.assertEqual(p[:], pseq[:])
- self.assertEqual(p[:100], pseq[:100])
- self.assertEqual(p[::2], pseq[::2])
- self.assertEqual(p[::-1], pseq[::-1])
-
- def test_prepend_indexing(self):
- """Tests interaction between prepending and indexing"""
- seq = list(range(20))
- p = mi.peekable(seq)
-
- p.prepend(30, 40, 50)
-
- self.assertEqual(p[0], 30)
- self.assertEqual(next(p), 30)
- self.assertEqual(p[2], 0)
- self.assertEqual(next(p), 40)
- self.assertEqual(p[0], 50)
- self.assertEqual(p[9], 8)
- self.assertEqual(next(p), 50)
- self.assertEqual(p[8], 8)
- self.assertEqual(p[-2], 18)
- self.assertEqual(p[-9], 11)
- self.assertRaises(IndexError, lambda: p[-21])
-
- def test_prepend_iterable(self):
- """Tests prepending from an iterable"""
- it = mi.peekable(range(5))
- # Don't directly use the range() object to avoid any range-specific
- # optimizations
- it.prepend(*(x for x in range(5)))
- actual = list(it)
- expected = list(chain(range(5), range(5)))
- self.assertEqual(actual, expected)
-
- def test_prepend_many(self):
- """Tests that prepending a huge number of elements works"""
- it = mi.peekable(range(5))
- # Don't directly use the range() object to avoid any range-specific
- # optimizations
- it.prepend(*(x for x in range(20000)))
- actual = list(it)
- expected = list(chain(range(20000), range(5)))
- self.assertEqual(actual, expected)
-
- def test_prepend_reversed(self):
- """Tests prepending from a reversed iterable"""
- it = mi.peekable(range(3))
- it.prepend(*reversed((10, 11, 12)))
- actual = list(it)
- expected = [12, 11, 10, 0, 1, 2]
- self.assertEqual(actual, expected)
-
-
-class ConsumerTests(TestCase):
- """Tests for ``consumer()``"""
-
- def test_consumer(self):
- @mi.consumer
- def eater():
- while True:
- x = yield # noqa
-
- e = eater()
- e.send('hi') # without @consumer, would raise TypeError
-
-
-class DistinctPermutationsTests(TestCase):
- def test_distinct_permutations(self):
- """Make sure the output for ``distinct_permutations()`` is the same as
- set(permutations(it)).
-
- """
- iterable = ['z', 'a', 'a', 'q', 'q', 'q', 'y']
- test_output = sorted(mi.distinct_permutations(iterable))
- ref_output = sorted(set(permutations(iterable)))
- self.assertEqual(test_output, ref_output)
-
- def test_other_iterables(self):
- """Make sure ``distinct_permutations()`` accepts a different type of
- iterables.
-
- """
- # a generator
- iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y'])
- test_output = sorted(mi.distinct_permutations(iterable))
- # "reload" it
- iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y'])
- ref_output = sorted(set(permutations(iterable)))
- self.assertEqual(test_output, ref_output)
-
- # an iterator
- iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y'])
- test_output = sorted(mi.distinct_permutations(iterable))
- # "reload" it
- iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y'])
- ref_output = sorted(set(permutations(iterable)))
- self.assertEqual(test_output, ref_output)
-
- def test_r(self):
- for iterable, r in (
- ('mississippi', 0),
- ('mississippi', 1),
- ('mississippi', 6),
- ('mississippi', 7),
- ('mississippi', 12),
- ([0, 1, 1, 0], 0),
- ([0, 1, 1, 0], 1),
- ([0, 1, 1, 0], 2),
- ([0, 1, 1, 0], 3),
- ([0, 1, 1, 0], 4),
- (['a'], 0),
- (['a'], 1),
- (['a'], 5),
- ([], 0),
- ([], 1),
- ([], 4),
- ):
- with self.subTest(iterable=iterable, r=r):
- expected = sorted(set(permutations(iterable, r)))
- actual = sorted(mi.distinct_permutations(iter(iterable), r))
- self.assertEqual(actual, expected)
-
-
-class IlenTests(TestCase):
- def test_ilen(self):
- """Sanity-checks for ``ilen()``."""
- # Non-empty
- self.assertEqual(
- mi.ilen(filter(lambda x: x % 10 == 0, range(101))), 11
- )
-
- # Empty
- self.assertEqual(mi.ilen(x for x in range(0)), 0)
-
- # Iterable with __len__
- self.assertEqual(mi.ilen(list(range(6))), 6)
-
-
-class MinMaxTests(TestCase):
- def test_basic(self):
- for iterable, expected in (
- # easy case
- ([0, 1, 2, 3], (0, 3)),
- # min and max are not in the extremes + we have `int`s and `float`s
- ([3, 5.5, -1, 2], (-1, 5.5)),
- # unordered collection
- ({3, 5.5, -1, 2}, (-1, 5.5)),
- # with repetitions
- ([3, 5.5, float('-Inf'), 5.5], (float('-Inf'), 5.5)),
- # other collections
- ('banana', ('a', 'n')),
- ({0: 1, 2: 100, 1: 10}, (0, 2)),
- (range(3, 14), (3, 13)),
- ):
- with self.subTest(iterable=iterable, expected=expected):
- # check for expected results
- self.assertTupleEqual(mi.minmax(iterable), expected)
- # check for equality with built-in `min` and `max`
- self.assertTupleEqual(
- mi.minmax(iterable), (min(iterable), max(iterable))
- )
-
- def test_unpacked(self):
- self.assertTupleEqual(mi.minmax(2, 3, 1), (1, 3))
- self.assertTupleEqual(mi.minmax(12, 3, 4, key=str), (12, 4))
-
- def test_iterables(self):
- self.assertTupleEqual(mi.minmax(x for x in [0, 1, 2, 3]), (0, 3))
- self.assertTupleEqual(
- mi.minmax(map(str, [3, 5.5, 'a', 2])), ('2', 'a')
- )
- self.assertTupleEqual(
- mi.minmax(filter(None, [0, 3, '', None, 10])), (3, 10)
- )
-
- def test_key(self):
- self.assertTupleEqual(
- mi.minmax({(), (1, 4, 2), 'abcde', range(4)}, key=len),
- ((), 'abcde'),
- )
- self.assertTupleEqual(
- mi.minmax((x for x in [10, 3, 25]), key=str), (10, 3)
- )
-
- def test_default(self):
- with self.assertRaises(ValueError):
- mi.minmax([])
-
- self.assertIs(mi.minmax([], default=None), None)
- self.assertListEqual(mi.minmax([], default=[1, 'a']), [1, 'a'])
-
-
-class WithIterTests(TestCase):
- def test_with_iter(self):
- s = StringIO('One fish\nTwo fish')
- initial_words = [line.split()[0] for line in mi.with_iter(s)]
-
- # Iterable's items should be faithfully represented
- self.assertEqual(initial_words, ['One', 'Two'])
- # The file object should be closed
- self.assertTrue(s.closed)
-
-
-class OneTests(TestCase):
- def test_basic(self):
- it = iter(['item'])
- self.assertEqual(mi.one(it), 'item')
-
- def test_too_short(self):
- it = iter([])
- for too_short, exc_type in [
- (None, ValueError),
- (IndexError, IndexError),
- ]:
- with self.subTest(too_short=too_short):
- try:
- mi.one(it, too_short=too_short)
- except exc_type:
- formatted_exc = format_exc()
- self.assertIn('StopIteration', formatted_exc)
- self.assertIn(
- 'The above exception was the direct cause',
- formatted_exc,
- )
- else:
- self.fail()
-
- def test_too_long(self):
- it = count()
- self.assertRaises(ValueError, lambda: mi.one(it)) # burn 0 and 1
- self.assertEqual(next(it), 2)
- self.assertRaises(
- OverflowError, lambda: mi.one(it, too_long=OverflowError)
- )
-
- def test_too_long_default_message(self):
- it = count()
- self.assertRaisesRegex(
- ValueError,
- "Expected exactly one item in "
- "iterable, but got 0, 1, and "
- "perhaps more.",
- lambda: mi.one(it),
- )
-
-
-class IntersperseTest(TestCase):
- """Tests for intersperse()"""
-
- def test_even(self):
- iterable = (x for x in '01')
- self.assertEqual(
- list(mi.intersperse(None, iterable)), ['0', None, '1']
- )
-
- def test_odd(self):
- iterable = (x for x in '012')
- self.assertEqual(
- list(mi.intersperse(None, iterable)), ['0', None, '1', None, '2']
- )
-
- def test_nested(self):
- element = ('a', 'b')
- iterable = (x for x in '012')
- actual = list(mi.intersperse(element, iterable))
- expected = ['0', ('a', 'b'), '1', ('a', 'b'), '2']
- self.assertEqual(actual, expected)
-
- def test_not_iterable(self):
- self.assertRaises(TypeError, lambda: mi.intersperse('x', 1))
-
- def test_n(self):
- for n, element, expected in [
- (1, '_', ['0', '_', '1', '_', '2', '_', '3', '_', '4', '_', '5']),
- (2, '_', ['0', '1', '_', '2', '3', '_', '4', '5']),
- (3, '_', ['0', '1', '2', '_', '3', '4', '5']),
- (4, '_', ['0', '1', '2', '3', '_', '4', '5']),
- (5, '_', ['0', '1', '2', '3', '4', '_', '5']),
- (6, '_', ['0', '1', '2', '3', '4', '5']),
- (7, '_', ['0', '1', '2', '3', '4', '5']),
- (3, ['a', 'b'], ['0', '1', '2', ['a', 'b'], '3', '4', '5']),
- ]:
- iterable = (x for x in '012345')
- actual = list(mi.intersperse(element, iterable, n=n))
- self.assertEqual(actual, expected)
-
- def test_n_zero(self):
- self.assertRaises(
- ValueError, lambda: list(mi.intersperse('x', '012', n=0))
- )
-
-
-class UniqueToEachTests(TestCase):
- """Tests for ``unique_to_each()``"""
-
- def test_all_unique(self):
- """When all the input iterables are unique the output should match
- the input."""
- iterables = [[1, 2], [3, 4, 5], [6, 7, 8]]
- self.assertEqual(mi.unique_to_each(*iterables), iterables)
-
- def test_duplicates(self):
- """When there are duplicates in any of the input iterables that aren't
- in the rest, those duplicates should be emitted."""
- iterables = ["mississippi", "missouri"]
- self.assertEqual(
- mi.unique_to_each(*iterables), [['p', 'p'], ['o', 'u', 'r']]
- )
-
- def test_mixed(self):
- """When the input iterables contain different types the function should
- still behave properly"""
- iterables = ['x', (i for i in range(3)), [1, 2, 3], tuple()]
- self.assertEqual(mi.unique_to_each(*iterables), [['x'], [0], [3], []])
-
-
-class WindowedTests(TestCase):
- """Tests for ``windowed()``"""
-
- def test_basic(self):
- actual = list(mi.windowed([1, 2, 3, 4, 5], 3))
- expected = [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
- self.assertEqual(actual, expected)
-
- def test_large_size(self):
- """
- When the window size is larger than the iterable, and no fill value is
- given,``None`` should be filled in.
- """
- actual = list(mi.windowed([1, 2, 3, 4, 5], 6))
- expected = [(1, 2, 3, 4, 5, None)]
- self.assertEqual(actual, expected)
-
- def test_fillvalue(self):
- """
- When sizes don't match evenly, the given fill value should be used.
- """
- iterable = [1, 2, 3, 4, 5]
-
- for n, kwargs, expected in [
- (6, {}, [(1, 2, 3, 4, 5, '!')]), # n > len(iterable)
- (3, {'step': 3}, [(1, 2, 3), (4, 5, '!')]), # using ``step``
- ]:
- actual = list(mi.windowed(iterable, n, fillvalue='!', **kwargs))
- self.assertEqual(actual, expected)
-
- def test_zero(self):
- """When the window size is zero, an empty tuple should be emitted."""
- actual = list(mi.windowed([1, 2, 3, 4, 5], 0))
- expected = [tuple()]
- self.assertEqual(actual, expected)
-
- def test_negative(self):
- """When the window size is negative, ValueError should be raised."""
- with self.assertRaises(ValueError):
- list(mi.windowed([1, 2, 3, 4, 5], -1))
-
- def test_step(self):
- """The window should advance by the number of steps provided"""
- iterable = [1, 2, 3, 4, 5, 6, 7]
- for n, step, expected in [
- (3, 2, [(1, 2, 3), (3, 4, 5), (5, 6, 7)]), # n > step
- (3, 3, [(1, 2, 3), (4, 5, 6), (7, None, None)]), # n == step
- (3, 4, [(1, 2, 3), (5, 6, 7)]), # line up nicely
- (3, 5, [(1, 2, 3), (6, 7, None)]), # off by one
- (3, 6, [(1, 2, 3), (7, None, None)]), # off by two
- (3, 7, [(1, 2, 3)]), # step past the end
- (7, 8, [(1, 2, 3, 4, 5, 6, 7)]), # step > len(iterable)
- ]:
- actual = list(mi.windowed(iterable, n, step=step))
- self.assertEqual(actual, expected)
-
- # Step must be greater than or equal to 1
- with self.assertRaises(ValueError):
- list(mi.windowed(iterable, 3, step=0))
-
-
-class SubstringsTests(TestCase):
- def test_basic(self):
- iterable = (x for x in range(4))
- actual = list(mi.substrings(iterable))
- expected = [
- (0,),
- (1,),
- (2,),
- (3,),
- (0, 1),
- (1, 2),
- (2, 3),
- (0, 1, 2),
- (1, 2, 3),
- (0, 1, 2, 3),
- ]
- self.assertEqual(actual, expected)
-
- def test_strings(self):
- iterable = 'abc'
- actual = list(mi.substrings(iterable))
- expected = [
- ('a',),
- ('b',),
- ('c',),
- ('a', 'b'),
- ('b', 'c'),
- ('a', 'b', 'c'),
- ]
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- iterable = iter([])
- actual = list(mi.substrings(iterable))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_order(self):
- iterable = [2, 0, 1]
- actual = list(mi.substrings(iterable))
- expected = [(2,), (0,), (1,), (2, 0), (0, 1), (2, 0, 1)]
- self.assertEqual(actual, expected)
-
-
-class SubstringsIndexesTests(TestCase):
- def test_basic(self):
- sequence = [x for x in range(4)]
- actual = list(mi.substrings_indexes(sequence))
- expected = [
- ([0], 0, 1),
- ([1], 1, 2),
- ([2], 2, 3),
- ([3], 3, 4),
- ([0, 1], 0, 2),
- ([1, 2], 1, 3),
- ([2, 3], 2, 4),
- ([0, 1, 2], 0, 3),
- ([1, 2, 3], 1, 4),
- ([0, 1, 2, 3], 0, 4),
- ]
- self.assertEqual(actual, expected)
-
- def test_strings(self):
- sequence = 'abc'
- actual = list(mi.substrings_indexes(sequence))
- expected = [
- ('a', 0, 1),
- ('b', 1, 2),
- ('c', 2, 3),
- ('ab', 0, 2),
- ('bc', 1, 3),
- ('abc', 0, 3),
- ]
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- sequence = []
- actual = list(mi.substrings_indexes(sequence))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_order(self):
- sequence = [2, 0, 1]
- actual = list(mi.substrings_indexes(sequence))
- expected = [
- ([2], 0, 1),
- ([0], 1, 2),
- ([1], 2, 3),
- ([2, 0], 0, 2),
- ([0, 1], 1, 3),
- ([2, 0, 1], 0, 3),
- ]
- self.assertEqual(actual, expected)
-
- def test_reverse(self):
- sequence = [2, 0, 1]
- actual = list(mi.substrings_indexes(sequence, reverse=True))
- expected = [
- ([2, 0, 1], 0, 3),
- ([2, 0], 0, 2),
- ([0, 1], 1, 3),
- ([2], 0, 1),
- ([0], 1, 2),
- ([1], 2, 3),
- ]
- self.assertEqual(actual, expected)
-
-
-class BucketTests(TestCase):
- def test_basic(self):
- iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33]
- D = mi.bucket(iterable, key=lambda x: 10 * (x // 10))
-
- # In-order access
- self.assertEqual(list(D[10]), [10, 11, 12])
-
- # Out of order access
- self.assertEqual(list(D[30]), [30, 31, 33])
- self.assertEqual(list(D[20]), [20, 21, 22, 23])
-
- self.assertEqual(list(D[40]), []) # Nothing in here!
-
- def test_in(self):
- iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33]
- D = mi.bucket(iterable, key=lambda x: 10 * (x // 10))
-
- self.assertIn(10, D)
- self.assertNotIn(40, D)
- self.assertIn(20, D)
- self.assertNotIn(21, D)
-
- # Checking in-ness shouldn't advance the iterator
- self.assertEqual(next(D[10]), 10)
-
- def test_validator(self):
- iterable = count(0)
- key = lambda x: int(str(x)[0]) # First digit of each number
- validator = lambda x: 0 < x < 10 # No leading zeros
- D = mi.bucket(iterable, key, validator=validator)
- self.assertEqual(mi.take(3, D[1]), [1, 10, 11])
- self.assertNotIn(0, D) # Non-valid entries don't return True
- self.assertNotIn(0, D._cache) # Don't store non-valid entries
- self.assertEqual(list(D[0]), [])
-
- def test_list(self):
- iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33]
- D = mi.bucket(iterable, key=lambda x: 10 * (x // 10))
- self.assertEqual(list(D[10]), [10, 11, 12])
- self.assertEqual(list(D[20]), [20, 21, 22, 23])
- self.assertEqual(list(D[30]), [30, 31, 33])
- self.assertEqual(set(D), {10, 20, 30})
-
- def test_list_validator(self):
- iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33]
- key = lambda x: 10 * (x // 10)
- validator = lambda x: x != 20
- D = mi.bucket(iterable, key, validator=validator)
- self.assertEqual(set(D), {10, 30})
- self.assertEqual(list(D[10]), [10, 11, 12])
- self.assertEqual(list(D[20]), [])
- self.assertEqual(list(D[30]), [30, 31, 33])
-
-
-class SpyTests(TestCase):
- """Tests for ``spy()``"""
-
- def test_basic(self):
- original_iterable = iter('abcdefg')
- head, new_iterable = mi.spy(original_iterable)
- self.assertEqual(head, ['a'])
- self.assertEqual(
- list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g']
- )
-
- def test_unpacking(self):
- original_iterable = iter('abcdefg')
- (first, second, third), new_iterable = mi.spy(original_iterable, 3)
- self.assertEqual(first, 'a')
- self.assertEqual(second, 'b')
- self.assertEqual(third, 'c')
- self.assertEqual(
- list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g']
- )
-
- def test_too_many(self):
- original_iterable = iter('abc')
- head, new_iterable = mi.spy(original_iterable, 4)
- self.assertEqual(head, ['a', 'b', 'c'])
- self.assertEqual(list(new_iterable), ['a', 'b', 'c'])
-
- def test_zero(self):
- original_iterable = iter('abc')
- head, new_iterable = mi.spy(original_iterable, 0)
- self.assertEqual(head, [])
- self.assertEqual(list(new_iterable), ['a', 'b', 'c'])
-
- def test_immutable(self):
- original_iterable = iter('abcdefg')
- head, new_iterable = mi.spy(original_iterable, 3)
- head[0] = 'A'
- self.assertEqual(head, ['A', 'b', 'c'])
- self.assertEqual(
- list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g']
- )
-
-
-class InterleaveTests(TestCase):
- def test_even(self):
- actual = list(mi.interleave([1, 4, 7], [2, 5, 8], [3, 6, 9]))
- expected = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_short(self):
- actual = list(mi.interleave([1, 4], [2, 5, 7], [3, 6, 8]))
- expected = [1, 2, 3, 4, 5, 6]
- self.assertEqual(actual, expected)
-
- def test_mixed_types(self):
- it_list = ['a', 'b', 'c', 'd']
- it_str = '12345'
- it_inf = count()
- actual = list(mi.interleave(it_list, it_str, it_inf))
- expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', 3]
- self.assertEqual(actual, expected)
-
-
-class InterleaveLongestTests(TestCase):
- def test_even(self):
- actual = list(mi.interleave_longest([1, 4, 7], [2, 5, 8], [3, 6, 9]))
- expected = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_short(self):
- actual = list(mi.interleave_longest([1, 4], [2, 5, 7], [3, 6, 8]))
- expected = [1, 2, 3, 4, 5, 6, 7, 8]
- self.assertEqual(actual, expected)
-
- def test_mixed_types(self):
- it_list = ['a', 'b', 'c', 'd']
- it_str = '12345'
- it_gen = (x for x in range(3))
- actual = list(mi.interleave_longest(it_list, it_str, it_gen))
- expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', '5']
- self.assertEqual(actual, expected)
-
-
-class InterleaveEvenlyTests(TestCase):
- def test_equal_lengths(self):
- # when lengths are equal, the relative order shouldn't change
- a = [1, 2, 3]
- b = [5, 6, 7]
- actual = list(mi.interleave_evenly([a, b]))
- expected = [1, 5, 2, 6, 3, 7]
- self.assertEqual(actual, expected)
-
- def test_proportional(self):
- # easy case where the iterables have proportional length
- a = [1, 2, 3, 4]
- b = [5, 6]
- actual = list(mi.interleave_evenly([a, b]))
- expected = [1, 2, 5, 3, 4, 6]
- self.assertEqual(actual, expected)
-
- # swapping a and b should yield the same result
- actual_swapped = list(mi.interleave_evenly([b, a]))
- self.assertEqual(actual_swapped, expected)
-
- def test_not_proportional(self):
- a = [1, 2, 3, 4, 5, 6, 7]
- b = [8, 9, 10]
- expected = [1, 2, 8, 3, 4, 9, 5, 6, 10, 7]
- actual = list(mi.interleave_evenly([a, b]))
- self.assertEqual(actual, expected)
-
- def test_degenerate_one(self):
- a = [0, 1, 2, 3, 4]
- b = [5]
- expected = [0, 1, 2, 5, 3, 4]
- actual = list(mi.interleave_evenly([a, b]))
- self.assertEqual(actual, expected)
-
- def test_degenerate_empty(self):
- a = [1, 2, 3]
- b = []
- expected = [1, 2, 3]
- actual = list(mi.interleave_evenly([a, b]))
- self.assertEqual(actual, expected)
-
- def test_three_iters(self):
- a = ["a1", "a2", "a3", "a4", "a5"]
- b = ["b1", "b2", "b3"]
- c = ["c1"]
- actual = list(mi.interleave_evenly([a, b, c]))
- expected = ["a1", "b1", "a2", "c1", "a3", "b2", "a4", "b3", "a5"]
- self.assertEqual(actual, expected)
-
- def test_many_iters(self):
- # smoke test with many iterables: create iterables with a random
- # number of elements starting with a character ("a0", "a1", ...)
- rng = Random(0)
- iterables = []
- for ch in ascii_letters:
- length = rng.randint(0, 100)
- iterable = [f"{ch}{i}" for i in range(length)]
- iterables.append(iterable)
-
- interleaved = list(mi.interleave_evenly(iterables))
-
- # for each iterable, check that the result contains all its items
- for iterable, ch_expect in zip(iterables, ascii_letters):
- interleaved_actual = [
- e for e in interleaved if e.startswith(ch_expect)
- ]
- assert len(set(interleaved_actual)) == len(iterable)
-
- def test_manual_lengths(self):
- a = combinations(range(4), 2)
- len_a = 4 * (4 - 1) // 2 # == 6
- b = combinations(range(4), 3)
- len_b = 4
-
- expected = [
- (0, 1),
- (0, 1, 2),
- (0, 2),
- (0, 3),
- (0, 1, 3),
- (1, 2),
- (0, 2, 3),
- (1, 3),
- (2, 3),
- (1, 2, 3),
- ]
- actual = list(mi.interleave_evenly([a, b], lengths=[len_a, len_b]))
- self.assertEqual(expected, actual)
-
- def test_no_length_raises(self):
- # combinations doesn't have __len__, should trigger ValueError
- iterables = [range(5), combinations(range(5), 2)]
- with self.assertRaises(ValueError):
- list(mi.interleave_evenly(iterables))
-
- def test_argument_mismatch_raises(self):
- # pass mismatching number of iterables and lengths
- iterables = [range(3)]
- lengths = [3, 4]
- with self.assertRaises(ValueError):
- list(mi.interleave_evenly(iterables, lengths=lengths))
-
-
-class TestCollapse(TestCase):
- """Tests for ``collapse()``"""
-
- def test_collapse(self):
- l = [[1], 2, [[3], 4], [[[5]]]]
- self.assertEqual(list(mi.collapse(l)), [1, 2, 3, 4, 5])
-
- def test_collapse_to_string(self):
- l = [["s1"], "s2", [["s3"], "s4"], [[["s5"]]]]
- self.assertEqual(list(mi.collapse(l)), ["s1", "s2", "s3", "s4", "s5"])
-
- def test_collapse_to_bytes(self):
- l = [[b"s1"], b"s2", [[b"s3"], b"s4"], [[[b"s5"]]]]
- self.assertEqual(
- list(mi.collapse(l)), [b"s1", b"s2", b"s3", b"s4", b"s5"]
- )
-
- def test_collapse_flatten(self):
- l = [[1], [2], [[3], 4], [[[5]]]]
- self.assertEqual(list(mi.collapse(l, levels=1)), list(mi.flatten(l)))
-
- def test_collapse_to_level(self):
- l = [[1], 2, [[3], 4], [[[5]]]]
- self.assertEqual(list(mi.collapse(l, levels=2)), [1, 2, 3, 4, [5]])
- self.assertEqual(
- list(mi.collapse(mi.collapse(l, levels=1), levels=1)),
- list(mi.collapse(l, levels=2)),
- )
-
- def test_collapse_to_list(self):
- l = (1, [2], (3, [4, (5,)], 'ab'))
- actual = list(mi.collapse(l, base_type=list))
- expected = [1, [2], 3, [4, (5,)], 'ab']
- self.assertEqual(actual, expected)
-
-
-class SideEffectTests(TestCase):
- """Tests for ``side_effect()``"""
-
- def test_individual(self):
- # The function increments the counter for each call
- counter = [0]
-
- def func(arg):
- counter[0] += 1
-
- result = list(mi.side_effect(func, range(10)))
- self.assertEqual(result, list(range(10)))
- self.assertEqual(counter[0], 10)
-
- def test_chunked(self):
- # The function increments the counter for each call
- counter = [0]
-
- def func(arg):
- counter[0] += 1
-
- result = list(mi.side_effect(func, range(10), 2))
- self.assertEqual(result, list(range(10)))
- self.assertEqual(counter[0], 5)
-
- def test_before_after(self):
- f = StringIO()
- collector = []
-
- def func(item):
- print(item, file=f)
- collector.append(f.getvalue())
-
- def it():
- yield 'a'
- yield 'b'
- raise RuntimeError('kaboom')
-
- before = lambda: print('HEADER', file=f)
- after = f.close
-
- try:
- mi.consume(mi.side_effect(func, it(), before=before, after=after))
- except RuntimeError:
- pass
-
- # The iterable should have been written to the file
- self.assertEqual(collector, ['HEADER\na\n', 'HEADER\na\nb\n'])
-
- # The file should be closed even though something bad happened
- self.assertTrue(f.closed)
-
- def test_before_fails(self):
- f = StringIO()
- func = lambda x: print(x, file=f)
-
- def before():
- raise RuntimeError('ouch')
-
- try:
- mi.consume(
- mi.side_effect(func, 'abc', before=before, after=f.close)
- )
- except RuntimeError:
- pass
-
- # The file should be closed even though something bad happened in the
- # before function
- self.assertTrue(f.closed)
-
-
-class SlicedTests(TestCase):
- """Tests for ``sliced()``"""
-
- def test_even(self):
- """Test when the length of the sequence is divisible by *n*"""
- seq = 'ABCDEFGHI'
- self.assertEqual(list(mi.sliced(seq, 3)), ['ABC', 'DEF', 'GHI'])
-
- def test_odd(self):
- """Test when the length of the sequence is not divisible by *n*"""
- seq = 'ABCDEFGHI'
- self.assertEqual(list(mi.sliced(seq, 4)), ['ABCD', 'EFGH', 'I'])
-
- def test_not_sliceable(self):
- seq = (x for x in 'ABCDEFGHI')
-
- with self.assertRaises(TypeError):
- list(mi.sliced(seq, 3))
-
- def test_odd_and_strict(self):
- seq = [x for x in 'ABCDEFGHI']
-
- with self.assertRaises(ValueError):
- list(mi.sliced(seq, 4, strict=True))
-
- def test_numpy_like_array(self):
- # Numpy arrays don't behave like Python lists - calling bool()
- # on them doesn't return False for empty lists and True for non-empty
- # ones. Emulate that behavior.
- class FalseList(list):
- def __getitem__(self, key):
- ret = super().__getitem__(key)
- if isinstance(key, slice):
- return FalseList(ret)
-
- return ret
-
- def __bool__(self):
- return False
-
- seq = FalseList(range(9))
- actual = list(mi.sliced(seq, 3))
- expected = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- self.assertEqual(actual, expected)
-
-
-class SplitAtTests(TestCase):
- def test_basic(self):
- for iterable, separator in [
- ('a,bb,ccc,dddd', ','),
- (',a,bb,ccc,dddd', ','),
- ('a,bb,ccc,dddd,', ','),
- ('a,bb,ccc,,dddd', ','),
- ('', ','),
- (',', ','),
- ('a,bb,ccc,dddd', ';'),
- ]:
- with self.subTest(iterable=iterable, separator=separator):
- it = iter(iterable)
- pred = lambda x: x == separator
- actual = [''.join(x) for x in mi.split_at(it, pred)]
- expected = iterable.split(separator)
- self.assertEqual(actual, expected)
-
- def test_maxsplit(self):
- iterable = 'a,bb,ccc,dddd'
- separator = ','
- pred = lambda x: x == separator
-
- for maxsplit in range(-1, 4):
- with self.subTest(maxsplit=maxsplit):
- it = iter(iterable)
- result = mi.split_at(it, pred, maxsplit=maxsplit)
- actual = [''.join(x) for x in result]
- expected = iterable.split(separator, maxsplit)
- self.assertEqual(actual, expected)
-
- def test_keep_separator(self):
- separator = ','
- pred = lambda x: x == separator
-
- for iterable, expected in [
- ('a,bb,ccc', ['a', ',', 'bb', ',', 'ccc']),
- (',a,bb,ccc', ['', ',', 'a', ',', 'bb', ',', 'ccc']),
- ('a,bb,ccc,', ['a', ',', 'bb', ',', 'ccc', ',', '']),
- ]:
- with self.subTest(iterable=iterable):
- it = iter(iterable)
- result = mi.split_at(it, pred, keep_separator=True)
- actual = [''.join(x) for x in result]
- self.assertEqual(actual, expected)
-
- def test_combination(self):
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
- pred = lambda x: x % 3 == 0
- actual = list(
- mi.split_at(iterable, pred, maxsplit=2, keep_separator=True)
- )
- expected = [[1, 2], [3], [4, 5], [6], [7, 8, 9, 10]]
- self.assertEqual(actual, expected)
-
-
-class SplitBeforeTest(TestCase):
- """Tests for ``split_before()``"""
-
- def test_starts_with_sep(self):
- actual = list(mi.split_before('xooxoo', lambda c: c == 'x'))
- expected = [['x', 'o', 'o'], ['x', 'o', 'o']]
- self.assertEqual(actual, expected)
-
- def test_ends_with_sep(self):
- actual = list(mi.split_before('ooxoox', lambda c: c == 'x'))
- expected = [['o', 'o'], ['x', 'o', 'o'], ['x']]
- self.assertEqual(actual, expected)
-
- def test_no_sep(self):
- actual = list(mi.split_before('ooo', lambda c: c == 'x'))
- expected = [['o', 'o', 'o']]
- self.assertEqual(actual, expected)
-
- def test_empty_collection(self):
- actual = list(mi.split_before([], lambda c: bool(c)))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_max_split(self):
- for args, expected in [
- (
- ('a,b,c,d', lambda c: c == ',', -1),
- [['a'], [',', 'b'], [',', 'c'], [',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c == ',', 0),
- [['a', ',', 'b', ',', 'c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c == ',', 1),
- [['a'], [',', 'b', ',', 'c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c == ',', 2),
- [['a'], [',', 'b'], [',', 'c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c == ',', 10),
- [['a'], [',', 'b'], [',', 'c'], [',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c == '@', 2),
- [['a', ',', 'b', ',', 'c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c != ',', 2),
- [['a', ','], ['b', ','], ['c', ',', 'd']],
- ),
- ]:
- actual = list(mi.split_before(*args))
- self.assertEqual(actual, expected)
-
-
-class SplitAfterTest(TestCase):
- """Tests for ``split_after()``"""
-
- def test_starts_with_sep(self):
- actual = list(mi.split_after('xooxoo', lambda c: c == 'x'))
- expected = [['x'], ['o', 'o', 'x'], ['o', 'o']]
- self.assertEqual(actual, expected)
-
- def test_ends_with_sep(self):
- actual = list(mi.split_after('ooxoox', lambda c: c == 'x'))
- expected = [['o', 'o', 'x'], ['o', 'o', 'x']]
- self.assertEqual(actual, expected)
-
- def test_no_sep(self):
- actual = list(mi.split_after('ooo', lambda c: c == 'x'))
- expected = [['o', 'o', 'o']]
- self.assertEqual(actual, expected)
-
- def test_max_split(self):
- for args, expected in [
- (
- ('a,b,c,d', lambda c: c == ',', -1),
- [['a', ','], ['b', ','], ['c', ','], ['d']],
- ),
- (
- ('a,b,c,d', lambda c: c == ',', 0),
- [['a', ',', 'b', ',', 'c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c == ',', 1),
- [['a', ','], ['b', ',', 'c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c == ',', 2),
- [['a', ','], ['b', ','], ['c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c == ',', 10),
- [['a', ','], ['b', ','], ['c', ','], ['d']],
- ),
- (
- ('a,b,c,d', lambda c: c == '@', 2),
- [['a', ',', 'b', ',', 'c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda c: c != ',', 2),
- [['a'], [',', 'b'], [',', 'c', ',', 'd']],
- ),
- ]:
- actual = list(mi.split_after(*args))
- self.assertEqual(actual, expected)
-
-
-class SplitWhenTests(TestCase):
- """Tests for ``split_when()``"""
-
- @staticmethod
- def _split_when_before(iterable, pred):
- return mi.split_when(iterable, lambda _, c: pred(c))
-
- @staticmethod
- def _split_when_after(iterable, pred):
- return mi.split_when(iterable, lambda c, _: pred(c))
-
- # split_before emulation
- def test_before_emulation_starts_with_sep(self):
- actual = list(self._split_when_before('xooxoo', lambda c: c == 'x'))
- expected = [['x', 'o', 'o'], ['x', 'o', 'o']]
- self.assertEqual(actual, expected)
-
- def test_before_emulation_ends_with_sep(self):
- actual = list(self._split_when_before('ooxoox', lambda c: c == 'x'))
- expected = [['o', 'o'], ['x', 'o', 'o'], ['x']]
- self.assertEqual(actual, expected)
-
- def test_before_emulation_no_sep(self):
- actual = list(self._split_when_before('ooo', lambda c: c == 'x'))
- expected = [['o', 'o', 'o']]
- self.assertEqual(actual, expected)
-
- # split_after emulation
- def test_after_emulation_starts_with_sep(self):
- actual = list(self._split_when_after('xooxoo', lambda c: c == 'x'))
- expected = [['x'], ['o', 'o', 'x'], ['o', 'o']]
- self.assertEqual(actual, expected)
-
- def test_after_emulation_ends_with_sep(self):
- actual = list(self._split_when_after('ooxoox', lambda c: c == 'x'))
- expected = [['o', 'o', 'x'], ['o', 'o', 'x']]
- self.assertEqual(actual, expected)
-
- def test_after_emulation_no_sep(self):
- actual = list(self._split_when_after('ooo', lambda c: c == 'x'))
- expected = [['o', 'o', 'o']]
- self.assertEqual(actual, expected)
-
- # edge cases
- def test_empty_iterable(self):
- actual = list(mi.split_when('', lambda a, b: a != b))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_one_element(self):
- actual = list(mi.split_when('o', lambda a, b: a == b))
- expected = [['o']]
- self.assertEqual(actual, expected)
-
- def test_one_element_is_second_item(self):
- actual = list(self._split_when_before('x', lambda c: c == 'x'))
- expected = [['x']]
- self.assertEqual(actual, expected)
-
- def test_one_element_is_first_item(self):
- actual = list(self._split_when_after('x', lambda c: c == 'x'))
- expected = [['x']]
- self.assertEqual(actual, expected)
-
- def test_max_split(self):
- for args, expected in [
- (
- ('a,b,c,d', lambda a, _: a == ',', -1),
- [['a', ','], ['b', ','], ['c', ','], ['d']],
- ),
- (
- ('a,b,c,d', lambda a, _: a == ',', 0),
- [['a', ',', 'b', ',', 'c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda _, b: b == ',', 1),
- [['a'], [',', 'b', ',', 'c', ',', 'd']],
- ),
- (
- ('a,b,c,d', lambda a, _: a == ',', 2),
- [['a', ','], ['b', ','], ['c', ',', 'd']],
- ),
- (
- ('0124376', lambda a, b: a > b, -1),
- [['0', '1', '2', '4'], ['3', '7'], ['6']],
- ),
- (
- ('0124376', lambda a, b: a > b, 0),
- [['0', '1', '2', '4', '3', '7', '6']],
- ),
- (
- ('0124376', lambda a, b: a > b, 1),
- [['0', '1', '2', '4'], ['3', '7', '6']],
- ),
- (
- ('0124376', lambda a, b: a > b, 2),
- [['0', '1', '2', '4'], ['3', '7'], ['6']],
- ),
- ]:
- actual = list(mi.split_when(*args))
- self.assertEqual(actual, expected, str(args))
-
-
-class SplitIntoTests(TestCase):
- """Tests for ``split_into()``"""
-
- def test_iterable_just_right(self):
- """Size of ``iterable`` equals the sum of ``sizes``."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, 4]
- expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_too_small(self):
- """Size of ``iterable`` is smaller than sum of ``sizes``. Last return
- list is shorter as a result."""
- iterable = [1, 2, 3, 4, 5, 6, 7]
- sizes = [2, 3, 4]
- expected = [[1, 2], [3, 4, 5], [6, 7]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_too_small_extra(self):
- """Size of ``iterable`` is smaller than sum of ``sizes``. Second last
- return list is shorter and last return list is empty as a result."""
- iterable = [1, 2, 3, 4, 5, 6, 7]
- sizes = [2, 3, 4, 5]
- expected = [[1, 2], [3, 4, 5], [6, 7], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_too_large(self):
- """Size of ``iterable`` is larger than sum of ``sizes``. Not all
- items of iterable are returned."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, 2]
- expected = [[1, 2], [3, 4, 5], [6, 7]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_using_none_with_leftover(self):
- """Last item of ``sizes`` is None when items still remain in
- ``iterable``. Last list returned stretches to fit all remaining items
- of ``iterable``."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, None]
- expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_using_none_without_leftover(self):
- """Last item of ``sizes`` is None when no items remain in
- ``iterable``. Last list returned is empty."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, 4, None]
- expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_using_none_mid_sizes(self):
- """None is present in ``sizes`` but is not the last item. Last list
- returned stretches to fit all remaining items of ``iterable`` but
- all items in ``sizes`` after None are ignored."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [2, 3, None, 4]
- expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_empty(self):
- """``iterable`` argument is empty but ``sizes`` is not. An empty
- list is returned for each item in ``sizes``."""
- iterable = []
- sizes = [2, 4, 2]
- expected = [[], [], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_iterable_empty_using_none(self):
- """``iterable`` argument is empty but ``sizes`` is not. An empty
- list is returned for each item in ``sizes`` that is not after a
- None item."""
- iterable = []
- sizes = [2, 4, None, 2]
- expected = [[], [], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_sizes_empty(self):
- """``sizes`` argument is empty but ``iterable`` is not. An empty
- generator is returned."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = []
- expected = []
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_both_empty(self):
- """Both ``sizes`` and ``iterable`` arguments are empty. An empty
- generator is returned."""
- iterable = []
- sizes = []
- expected = []
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_bool_in_sizes(self):
- """A bool object is present in ``sizes`` is treated as a 1 or 0 for
- ``True`` or ``False`` due to bool being an instance of int."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [3, True, 2, False]
- expected = [[1, 2, 3], [4], [5, 6], []]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_invalid_in_sizes(self):
- """A ValueError is raised if an object in ``sizes`` is neither ``None``
- or an integer."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [1, [], 3]
- with self.assertRaises(ValueError):
- list(mi.split_into(iterable, sizes))
-
- def test_invalid_in_sizes_after_none(self):
- """A item in ``sizes`` that is invalid will not raise a TypeError if it
- comes after a ``None`` item."""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = [3, 4, None, []]
- expected = [[1, 2, 3], [4, 5, 6, 7], [8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- def test_generator_iterable_integrity(self):
- """Check that if ``iterable`` is an iterator, it is consumed only by as
- many items as the sum of ``sizes``."""
- iterable = (i for i in range(10))
- sizes = [2, 3]
-
- expected = [[0, 1], [2, 3, 4]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- iterable_expected = [5, 6, 7, 8, 9]
- iterable_actual = list(iterable)
- self.assertEqual(iterable_actual, iterable_expected)
-
- def test_generator_sizes_integrity(self):
- """Check that if ``sizes`` is an iterator, it is consumed only until a
- ``None`` item is reached"""
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- sizes = (i for i in [1, 2, None, 3, 4])
-
- expected = [[1], [2, 3], [4, 5, 6, 7, 8, 9]]
- actual = list(mi.split_into(iterable, sizes))
- self.assertEqual(actual, expected)
-
- sizes_expected = [3, 4]
- sizes_actual = list(sizes)
- self.assertEqual(sizes_actual, sizes_expected)
-
-
-class PaddedTest(TestCase):
- """Tests for ``padded()``"""
-
- def test_no_n(self):
- seq = [1, 2, 3]
-
- # No fillvalue
- self.assertEqual(mi.take(5, mi.padded(seq)), [1, 2, 3, None, None])
-
- # With fillvalue
- self.assertEqual(
- mi.take(5, mi.padded(seq, fillvalue='')), [1, 2, 3, '', '']
- )
-
- def test_invalid_n(self):
- self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=-1)))
- self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=0)))
-
- def test_valid_n(self):
- seq = [1, 2, 3, 4, 5]
-
- # No need for padding: len(seq) <= n
- self.assertEqual(list(mi.padded(seq, n=4)), [1, 2, 3, 4, 5])
- self.assertEqual(list(mi.padded(seq, n=5)), [1, 2, 3, 4, 5])
-
- # No fillvalue
- self.assertEqual(
- list(mi.padded(seq, n=7)), [1, 2, 3, 4, 5, None, None]
- )
-
- # With fillvalue
- self.assertEqual(
- list(mi.padded(seq, fillvalue='', n=7)), [1, 2, 3, 4, 5, '', '']
- )
-
- def test_next_multiple(self):
- seq = [1, 2, 3, 4, 5, 6]
-
- # No need for padding: len(seq) % n == 0
- self.assertEqual(
- list(mi.padded(seq, n=3, next_multiple=True)), [1, 2, 3, 4, 5, 6]
- )
-
- # Padding needed: len(seq) < n
- self.assertEqual(
- list(mi.padded(seq, n=8, next_multiple=True)),
- [1, 2, 3, 4, 5, 6, None, None],
- )
-
- # No padding needed: len(seq) == n
- self.assertEqual(
- list(mi.padded(seq, n=6, next_multiple=True)), [1, 2, 3, 4, 5, 6]
- )
-
- # Padding needed: len(seq) > n
- self.assertEqual(
- list(mi.padded(seq, n=4, next_multiple=True)),
- [1, 2, 3, 4, 5, 6, None, None],
- )
-
- # With fillvalue
- self.assertEqual(
- list(mi.padded(seq, fillvalue='', n=4, next_multiple=True)),
- [1, 2, 3, 4, 5, 6, '', ''],
- )
-
-
-class RepeatEachTests(TestCase):
- """Tests for repeat_each()"""
-
- def test_default(self):
- actual = list(mi.repeat_each('ABC'))
- expected = ['A', 'A', 'B', 'B', 'C', 'C']
- self.assertEqual(actual, expected)
-
- def test_basic(self):
- actual = list(mi.repeat_each('ABC', 3))
- expected = ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- actual = list(mi.repeat_each(''))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_no_repeat(self):
- actual = list(mi.repeat_each('ABC', 0))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_negative_repeat(self):
- actual = list(mi.repeat_each('ABC', -1))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_infinite_input(self):
- repeater = mi.repeat_each(cycle('AB'))
- actual = mi.take(6, repeater)
- expected = ['A', 'A', 'B', 'B', 'A', 'A']
- self.assertEqual(actual, expected)
-
-
-class RepeatLastTests(TestCase):
- def test_empty_iterable(self):
- slice_length = 3
- iterable = iter([])
- actual = mi.take(slice_length, mi.repeat_last(iterable))
- expected = [None] * slice_length
- self.assertEqual(actual, expected)
-
- def test_default_value(self):
- slice_length = 3
- iterable = iter([])
- default = '3'
- actual = mi.take(slice_length, mi.repeat_last(iterable, default))
- expected = ['3'] * slice_length
- self.assertEqual(actual, expected)
-
- def test_basic(self):
- slice_length = 10
- iterable = (str(x) for x in range(5))
- actual = mi.take(slice_length, mi.repeat_last(iterable))
- expected = ['0', '1', '2', '3', '4', '4', '4', '4', '4', '4']
- self.assertEqual(actual, expected)
-
-
-class DistributeTest(TestCase):
- """Tests for distribute()"""
-
- def test_invalid_n(self):
- self.assertRaises(ValueError, lambda: mi.distribute(-1, [1, 2, 3]))
- self.assertRaises(ValueError, lambda: mi.distribute(0, [1, 2, 3]))
-
- def test_basic(self):
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-
- for n, expected in [
- (1, [iterable]),
- (2, [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]),
- (3, [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]]),
- (10, [[n] for n in range(1, 10 + 1)]),
- ]:
- self.assertEqual(
- [list(x) for x in mi.distribute(n, iterable)], expected
- )
-
- def test_large_n(self):
- iterable = [1, 2, 3, 4]
- self.assertEqual(
- [list(x) for x in mi.distribute(6, iterable)],
- [[1], [2], [3], [4], [], []],
- )
-
-
-class StaggerTest(TestCase):
- """Tests for ``stagger()``"""
-
- def test_default(self):
- iterable = [0, 1, 2, 3]
- actual = list(mi.stagger(iterable))
- expected = [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
- self.assertEqual(actual, expected)
-
- def test_offsets(self):
- iterable = [0, 1, 2, 3]
- for offsets, expected in [
- ((-2, 0, 2), [('', 0, 2), ('', 1, 3)]),
- ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3)]),
- ((1, 2), [(1, 2), (2, 3)]),
- ]:
- all_groups = mi.stagger(iterable, offsets=offsets, fillvalue='')
- self.assertEqual(list(all_groups), expected)
-
- def test_longest(self):
- iterable = [0, 1, 2, 3]
- for offsets, expected in [
- (
- (-1, 0, 1),
- [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, ''), (3, '', '')],
- ),
- ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3), (3, '')]),
- ((1, 2), [(1, 2), (2, 3), (3, '')]),
- ]:
- all_groups = mi.stagger(
- iterable, offsets=offsets, fillvalue='', longest=True
- )
- self.assertEqual(list(all_groups), expected)
-
-
-class ZipEqualTest(TestCase):
- @skipIf(version_info[:2] < (3, 10), 'zip_equal deprecated for 3.10+')
- def test_deprecation(self):
- with warnings.catch_warnings(record=True) as caught:
- warnings.simplefilter('always')
- self.assertEqual(
- list(mi.zip_equal([1, 2], [3, 4])), [(1, 3), (2, 4)]
- )
-
- (warning,) = caught
- assert warning.category == DeprecationWarning
-
- def test_equal(self):
- lists = [0, 1, 2], [2, 3, 4]
-
- for iterables in [lists, map(iter, lists)]:
- actual = list(mi.zip_equal(*iterables))
- expected = [(0, 2), (1, 3), (2, 4)]
- self.assertEqual(actual, expected)
-
- def test_unequal_lists(self):
- two_items = [0, 1]
- three_items = [2, 3, 4]
- four_items = [5, 6, 7, 8]
-
- # the mismatch is at index 1
- try:
- list(mi.zip_equal(two_items, three_items, four_items))
- except mi.UnequalIterablesError as e:
- self.assertEqual(
- e.args[0],
- (
- 'Iterables have different lengths: '
- 'index 0 has length 2; index 1 has length 3'
- ),
- )
-
- # the mismatch is at index 2
- try:
- list(mi.zip_equal(two_items, two_items, four_items, four_items))
- except mi.UnequalIterablesError as e:
- self.assertEqual(
- e.args[0],
- (
- 'Iterables have different lengths: '
- 'index 0 has length 2; index 2 has length 4'
- ),
- )
-
- # One without length: delegate to _zip_equal_generator
- try:
- list(mi.zip_equal(two_items, iter(two_items), three_items))
- except mi.UnequalIterablesError as e:
- self.assertEqual(e.args[0], 'Iterables have different lengths')
-
-
-class ZipOffsetTest(TestCase):
- """Tests for ``zip_offset()``"""
-
- def test_shortest(self):
- a_1 = [0, 1, 2, 3]
- a_2 = [0, 1, 2, 3, 4, 5]
- a_3 = [0, 1, 2, 3, 4, 5, 6, 7]
- actual = list(
- mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), fillvalue='')
- )
- expected = [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)]
- self.assertEqual(actual, expected)
-
- def test_longest(self):
- a_1 = [0, 1, 2, 3]
- a_2 = [0, 1, 2, 3, 4, 5]
- a_3 = [0, 1, 2, 3, 4, 5, 6, 7]
- actual = list(
- mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), longest=True)
- )
- expected = [
- (None, 0, 1),
- (0, 1, 2),
- (1, 2, 3),
- (2, 3, 4),
- (3, 4, 5),
- (None, 5, 6),
- (None, None, 7),
- ]
- self.assertEqual(actual, expected)
-
- def test_mismatch(self):
- iterables = [0, 1, 2], [2, 3, 4]
- offsets = (-1, 0, 1)
- self.assertRaises(
- ValueError,
- lambda: list(mi.zip_offset(*iterables, offsets=offsets)),
- )
-
-
-class UnzipTests(TestCase):
- """Tests for unzip()"""
-
- def test_empty_iterable(self):
- self.assertEqual(list(mi.unzip([])), [])
- # in reality zip([], [], []) is equivalent to iter([])
- # but it doesn't hurt to test both
- self.assertEqual(list(mi.unzip(zip([], [], []))), [])
-
- def test_length_one_iterable(self):
- xs, ys, zs = mi.unzip(zip([1], [2], [3]))
- self.assertEqual(list(xs), [1])
- self.assertEqual(list(ys), [2])
- self.assertEqual(list(zs), [3])
-
- def test_normal_case(self):
- xs, ys, zs = range(10), range(1, 11), range(2, 12)
- zipped = zip(xs, ys, zs)
- xs, ys, zs = mi.unzip(zipped)
- self.assertEqual(list(xs), list(range(10)))
- self.assertEqual(list(ys), list(range(1, 11)))
- self.assertEqual(list(zs), list(range(2, 12)))
-
- def test_improperly_zipped(self):
- zipped = iter([(1, 2, 3), (4, 5), (6,)])
- xs, ys, zs = mi.unzip(zipped)
- self.assertEqual(list(xs), [1, 4, 6])
- self.assertEqual(list(ys), [2, 5])
- self.assertEqual(list(zs), [3])
-
- def test_increasingly_zipped(self):
- zipped = iter([(1, 2), (3, 4, 5), (6, 7, 8, 9)])
- unzipped = mi.unzip(zipped)
- # from the docstring:
- # len(first tuple) is the number of iterables zipped
- self.assertEqual(len(unzipped), 2)
- xs, ys = unzipped
- self.assertEqual(list(xs), [1, 3, 6])
- self.assertEqual(list(ys), [2, 4, 7])
-
-
-class SortTogetherTest(TestCase):
- """Tests for sort_together()"""
-
- def test_key_list(self):
- """tests `key_list` including default, iterables include duplicates"""
- iterables = [
- ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'],
- ['May', 'Aug.', 'May', 'June', 'July', 'July'],
- [97, 20, 100, 70, 100, 20],
- ]
-
- self.assertEqual(
- mi.sort_together(iterables),
- [
- ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'),
- ('June', 'July', 'July', 'May', 'Aug.', 'May'),
- (70, 100, 20, 97, 20, 100),
- ],
- )
-
- self.assertEqual(
- mi.sort_together(iterables, key_list=(0, 1)),
- [
- ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'),
- ('July', 'July', 'June', 'Aug.', 'May', 'May'),
- (100, 20, 70, 20, 97, 100),
- ],
- )
-
- self.assertEqual(
- mi.sort_together(iterables, key_list=(0, 1, 2)),
- [
- ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'),
- ('July', 'July', 'June', 'Aug.', 'May', 'May'),
- (20, 100, 70, 20, 97, 100),
- ],
- )
-
- self.assertEqual(
- mi.sort_together(iterables, key_list=(2,)),
- [
- ('GA', 'CT', 'CT', 'GA', 'GA', 'CT'),
- ('Aug.', 'July', 'June', 'May', 'May', 'July'),
- (20, 20, 70, 97, 100, 100),
- ],
- )
-
- def test_invalid_key_list(self):
- """tests `key_list` for indexes not available in `iterables`"""
- iterables = [
- ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'],
- ['May', 'Aug.', 'May', 'June', 'July', 'July'],
- [97, 20, 100, 70, 100, 20],
- ]
-
- self.assertRaises(
- IndexError, lambda: mi.sort_together(iterables, key_list=(5,))
- )
-
- def test_key_function(self):
- """tests `key` function, including interaction with `key_list`"""
- iterables = [
- ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'],
- ['May', 'Aug.', 'May', 'June', 'July', 'July'],
- [97, 20, 100, 70, 100, 20],
- ]
- self.assertEqual(
- mi.sort_together(iterables, key=lambda x: x),
- [
- ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'),
- ('June', 'July', 'July', 'May', 'Aug.', 'May'),
- (70, 100, 20, 97, 20, 100),
- ],
- )
- self.assertEqual(
- mi.sort_together(iterables, key=lambda x: x[::-1]),
- [
- ('GA', 'GA', 'GA', 'CT', 'CT', 'CT'),
- ('May', 'Aug.', 'May', 'June', 'July', 'July'),
- (97, 20, 100, 70, 100, 20),
- ],
- )
- self.assertEqual(
- mi.sort_together(
- iterables,
- key_list=(0, 2),
- key=lambda state, number: number
- if state == 'CT'
- else 2 * number,
- ),
- [
- ('CT', 'GA', 'CT', 'CT', 'GA', 'GA'),
- ('July', 'Aug.', 'June', 'July', 'May', 'May'),
- (20, 20, 70, 100, 97, 100),
- ],
- )
-
- def test_reverse(self):
- """tests `reverse` to ensure a reverse sort for `key_list` iterables"""
- iterables = [
- ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'],
- ['May', 'Aug.', 'May', 'June', 'July', 'July'],
- [97, 20, 100, 70, 100, 20],
- ]
-
- self.assertEqual(
- mi.sort_together(iterables, key_list=(0, 1, 2), reverse=True),
- [
- ('GA', 'GA', 'GA', 'CT', 'CT', 'CT'),
- ('May', 'May', 'Aug.', 'June', 'July', 'July'),
- (100, 97, 20, 70, 100, 20),
- ],
- )
-
- def test_uneven_iterables(self):
- """tests trimming of iterables to the shortest length before sorting"""
- iterables = [
- ['GA', 'GA', 'GA', 'CT', 'CT', 'CT', 'MA'],
- ['May', 'Aug.', 'May', 'June', 'July', 'July'],
- [97, 20, 100, 70, 100, 20, 0],
- ]
-
- self.assertEqual(
- mi.sort_together(iterables),
- [
- ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'),
- ('June', 'July', 'July', 'May', 'Aug.', 'May'),
- (70, 100, 20, 97, 20, 100),
- ],
- )
-
-
-class DivideTest(TestCase):
- """Tests for divide()"""
-
- def test_invalid_n(self):
- self.assertRaises(ValueError, lambda: mi.divide(-1, [1, 2, 3]))
- self.assertRaises(ValueError, lambda: mi.divide(0, [1, 2, 3]))
-
- def test_basic(self):
- iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-
- for n, expected in [
- (1, [iterable]),
- (2, [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]),
- (3, [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]),
- (10, [[n] for n in range(1, 10 + 1)]),
- ]:
- self.assertEqual(
- [list(x) for x in mi.divide(n, iterable)], expected
- )
-
- def test_large_n(self):
- self.assertEqual(
- [list(x) for x in mi.divide(6, iter(range(1, 4 + 1)))],
- [[1], [2], [3], [4], [], []],
- )
-
-
-class TestAlwaysIterable(TestCase):
- """Tests for always_iterable()"""
-
- def test_single(self):
- self.assertEqual(list(mi.always_iterable(1)), [1])
-
- def test_strings(self):
- for obj in ['foo', b'bar', 'baz']:
- actual = list(mi.always_iterable(obj))
- expected = [obj]
- self.assertEqual(actual, expected)
-
- def test_base_type(self):
- dict_obj = {'a': 1, 'b': 2}
- str_obj = '123'
-
- # Default: dicts are iterable like they normally are
- default_actual = list(mi.always_iterable(dict_obj))
- default_expected = list(dict_obj)
- self.assertEqual(default_actual, default_expected)
-
- # Unitary types set: dicts are not iterable
- custom_actual = list(mi.always_iterable(dict_obj, base_type=dict))
- custom_expected = [dict_obj]
- self.assertEqual(custom_actual, custom_expected)
-
- # With unitary types set, strings are iterable
- str_actual = list(mi.always_iterable(str_obj, base_type=None))
- str_expected = list(str_obj)
- self.assertEqual(str_actual, str_expected)
-
- # base_type handles nested tuple (via isinstance).
- base_type = ((dict,),)
- custom_actual = list(mi.always_iterable(dict_obj, base_type=base_type))
- custom_expected = [dict_obj]
- self.assertEqual(custom_actual, custom_expected)
-
- def test_iterables(self):
- self.assertEqual(list(mi.always_iterable([0, 1])), [0, 1])
- self.assertEqual(
- list(mi.always_iterable([0, 1], base_type=list)), [[0, 1]]
- )
- self.assertEqual(
- list(mi.always_iterable(iter('foo'))), ['f', 'o', 'o']
- )
- self.assertEqual(list(mi.always_iterable([])), [])
-
- def test_none(self):
- self.assertEqual(list(mi.always_iterable(None)), [])
-
- def test_generator(self):
- def _gen():
- yield 0
- yield 1
-
- self.assertEqual(list(mi.always_iterable(_gen())), [0, 1])
-
-
-class AdjacentTests(TestCase):
- def test_typical(self):
- actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10)))
- expected = [
- (True, 0),
- (True, 1),
- (False, 2),
- (False, 3),
- (True, 4),
- (True, 5),
- (True, 6),
- (False, 7),
- (False, 8),
- (False, 9),
- ]
- self.assertEqual(actual, expected)
-
- def test_empty_iterable(self):
- actual = list(mi.adjacent(lambda x: x % 5 == 0, []))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_length_one(self):
- actual = list(mi.adjacent(lambda x: x % 5 == 0, [0]))
- expected = [(True, 0)]
- self.assertEqual(actual, expected)
-
- actual = list(mi.adjacent(lambda x: x % 5 == 0, [1]))
- expected = [(False, 1)]
- self.assertEqual(actual, expected)
-
- def test_consecutive_true(self):
- """Test that when the predicate matches multiple consecutive elements
- it doesn't repeat elements in the output"""
- actual = list(mi.adjacent(lambda x: x % 5 < 2, range(10)))
- expected = [
- (True, 0),
- (True, 1),
- (True, 2),
- (False, 3),
- (True, 4),
- (True, 5),
- (True, 6),
- (True, 7),
- (False, 8),
- (False, 9),
- ]
- self.assertEqual(actual, expected)
-
- def test_distance(self):
- actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=2))
- expected = [
- (True, 0),
- (True, 1),
- (True, 2),
- (True, 3),
- (True, 4),
- (True, 5),
- (True, 6),
- (True, 7),
- (False, 8),
- (False, 9),
- ]
- self.assertEqual(actual, expected)
-
- actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=3))
- expected = [
- (True, 0),
- (True, 1),
- (True, 2),
- (True, 3),
- (True, 4),
- (True, 5),
- (True, 6),
- (True, 7),
- (True, 8),
- (False, 9),
- ]
- self.assertEqual(actual, expected)
-
- def test_large_distance(self):
- """Test distance larger than the length of the iterable"""
- iterable = range(10)
- actual = list(mi.adjacent(lambda x: x % 5 == 4, iterable, distance=20))
- expected = list(zip(repeat(True), iterable))
- self.assertEqual(actual, expected)
-
- actual = list(mi.adjacent(lambda x: False, iterable, distance=20))
- expected = list(zip(repeat(False), iterable))
- self.assertEqual(actual, expected)
-
- def test_zero_distance(self):
- """Test that adjacent() reduces to zip+map when distance is 0"""
- iterable = range(1000)
- predicate = lambda x: x % 4 == 2
- actual = mi.adjacent(predicate, iterable, 0)
- expected = zip(map(predicate, iterable), iterable)
- self.assertTrue(all(a == e for a, e in zip(actual, expected)))
-
- def test_negative_distance(self):
- """Test that adjacent() raises an error with negative distance"""
- pred = lambda x: x
- self.assertRaises(
- ValueError, lambda: mi.adjacent(pred, range(1000), -1)
- )
- self.assertRaises(
- ValueError, lambda: mi.adjacent(pred, range(10), -10)
- )
-
- def test_grouping(self):
- """Test interaction of adjacent() with groupby_transform()"""
- iterable = mi.adjacent(lambda x: x % 5 == 0, range(10))
- grouper = mi.groupby_transform(iterable, itemgetter(0), itemgetter(1))
- actual = [(k, list(g)) for k, g in grouper]
- expected = [
- (True, [0, 1]),
- (False, [2, 3]),
- (True, [4, 5, 6]),
- (False, [7, 8, 9]),
- ]
- self.assertEqual(actual, expected)
-
- def test_call_once(self):
- """Test that the predicate is only called once per item."""
- already_seen = set()
- iterable = range(10)
-
- def predicate(item):
- self.assertNotIn(item, already_seen)
- already_seen.add(item)
- return True
-
- actual = list(mi.adjacent(predicate, iterable))
- expected = [(True, x) for x in iterable]
- self.assertEqual(actual, expected)
-
-
-class GroupByTransformTests(TestCase):
- def assertAllGroupsEqual(self, groupby1, groupby2):
- for a, b in zip(groupby1, groupby2):
- key1, group1 = a
- key2, group2 = b
- self.assertEqual(key1, key2)
- self.assertListEqual(list(group1), list(group2))
- self.assertRaises(StopIteration, lambda: next(groupby1))
- self.assertRaises(StopIteration, lambda: next(groupby2))
-
- def test_default_funcs(self):
- iterable = [(x // 5, x) for x in range(1000)]
- actual = mi.groupby_transform(iterable)
- expected = groupby(iterable)
- self.assertAllGroupsEqual(actual, expected)
-
- def test_valuefunc(self):
- iterable = [(int(x / 5), int(x / 3), x) for x in range(10)]
-
- # Test the standard usage of grouping one iterable using another's keys
- grouper = mi.groupby_transform(
- iterable, keyfunc=itemgetter(0), valuefunc=itemgetter(-1)
- )
- actual = [(k, list(g)) for k, g in grouper]
- expected = [(0, [0, 1, 2, 3, 4]), (1, [5, 6, 7, 8, 9])]
- self.assertEqual(actual, expected)
-
- grouper = mi.groupby_transform(
- iterable, keyfunc=itemgetter(1), valuefunc=itemgetter(-1)
- )
- actual = [(k, list(g)) for k, g in grouper]
- expected = [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8]), (3, [9])]
- self.assertEqual(actual, expected)
-
- # and now for something a little different
- d = dict(zip(range(10), 'abcdefghij'))
- grouper = mi.groupby_transform(
- range(10), keyfunc=lambda x: x // 5, valuefunc=d.get
- )
- actual = [(k, ''.join(g)) for k, g in grouper]
- expected = [(0, 'abcde'), (1, 'fghij')]
- self.assertEqual(actual, expected)
-
- def test_no_valuefunc(self):
- iterable = range(1000)
-
- def key(x):
- return x // 5
-
- actual = mi.groupby_transform(iterable, key, valuefunc=None)
- expected = groupby(iterable, key)
- self.assertAllGroupsEqual(actual, expected)
-
- actual = mi.groupby_transform(iterable, key) # default valuefunc
- expected = groupby(iterable, key)
- self.assertAllGroupsEqual(actual, expected)
-
- def test_reducefunc(self):
- iterable = range(50)
- keyfunc = lambda k: 10 * (k // 10)
- valuefunc = lambda v: v + 1
- reducefunc = sum
- actual = list(
- mi.groupby_transform(
- iterable,
- keyfunc=keyfunc,
- valuefunc=valuefunc,
- reducefunc=reducefunc,
- )
- )
- expected = [(0, 55), (10, 155), (20, 255), (30, 355), (40, 455)]
- self.assertEqual(actual, expected)
-
-
-class NumericRangeTests(TestCase):
- def test_basic(self):
- for args, expected in [
- ((4,), [0, 1, 2, 3]),
- ((4.0,), [0.0, 1.0, 2.0, 3.0]),
- ((1.0, 4), [1.0, 2.0, 3.0]),
- ((1, 4.0), [1.0, 2.0, 3.0]),
- ((1.0, 5), [1.0, 2.0, 3.0, 4.0]),
- ((0, 20, 5), [0, 5, 10, 15]),
- ((0, 20, 5.0), [0.0, 5.0, 10.0, 15.0]),
- ((0, 10, 3), [0, 3, 6, 9]),
- ((0, 10, 3.0), [0.0, 3.0, 6.0, 9.0]),
- ((0, -5, -1), [0, -1, -2, -3, -4]),
- ((0.0, -5, -1), [0.0, -1.0, -2.0, -3.0, -4.0]),
- ((1, 2, Fraction(1, 2)), [Fraction(1, 1), Fraction(3, 2)]),
- ((0,), []),
- ((0.0,), []),
- ((1, 0), []),
- ((1.0, 0.0), []),
- ((0.1, 0.30000000000000001, 0.2), [0.1]), # IEE 754 !
- (
- (
- Decimal("0.1"),
- Decimal("0.30000000000000001"),
- Decimal("0.2"),
- ),
- [Decimal("0.1"), Decimal("0.3")],
- ), # okay with Decimal
- (
- (
- Fraction(1, 10),
- Fraction(30000000000000001, 100000000000000000),
- Fraction(2, 10),
- ),
- [Fraction(1, 10), Fraction(3, 10)],
- ), # okay with Fraction
- ((Fraction(2, 1),), [Fraction(0, 1), Fraction(1, 1)]),
- ((Decimal('2.0'),), [Decimal('0.0'), Decimal('1.0')]),
- (
- (
- datetime(2019, 3, 29, 12, 34, 56),
- datetime(2019, 3, 29, 12, 37, 55),
- timedelta(minutes=1),
- ),
- [
- datetime(2019, 3, 29, 12, 34, 56),
- datetime(2019, 3, 29, 12, 35, 56),
- datetime(2019, 3, 29, 12, 36, 56),
- ],
- ),
- ]:
- actual = list(mi.numeric_range(*args))
- self.assertEqual(expected, actual)
- self.assertTrue(
- all(type(a) == type(e) for a, e in zip(actual, expected))
- )
-
- def test_arg_count(self):
- for args, message in [
- ((), 'numeric_range expected at least 1 argument, got 0'),
- (
- (0, 1, 2, 3),
- 'numeric_range expected at most 3 arguments, got 4',
- ),
- ]:
- with self.assertRaisesRegex(TypeError, message):
- mi.numeric_range(*args)
-
- def test_zero_step(self):
- for args in [
- (1, 2, 0),
- (
- datetime(2019, 3, 29, 12, 34, 56),
- datetime(2019, 3, 29, 12, 37, 55),
- timedelta(minutes=0),
- ),
- (1.0, 2.0, 0.0),
- (Decimal("1.0"), Decimal("2.0"), Decimal("0.0")),
- (Fraction(2, 2), Fraction(4, 2), Fraction(0, 2)),
- ]:
- with self.assertRaises(ValueError):
- list(mi.numeric_range(*args))
-
- def test_bool(self):
- for args, expected in [
- ((1.0, 3.0, 1.5), True),
- ((1.0, 2.0, 1.5), True),
- ((1.0, 1.0, 1.5), False),
- ((1.0, 0.0, 1.5), False),
- ((3.0, 1.0, -1.5), True),
- ((2.0, 1.0, -1.5), True),
- ((1.0, 1.0, -1.5), False),
- ((0.0, 1.0, -1.5), False),
- ((Decimal("1.0"), Decimal("2.0"), Decimal("1.5")), True),
- ((Decimal("1.0"), Decimal("0.0"), Decimal("1.5")), False),
- ((Fraction(2, 2), Fraction(4, 2), Fraction(3, 2)), True),
- ((Fraction(2, 2), Fraction(0, 2), Fraction(3, 2)), False),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=1),
- ),
- True,
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 28),
- timedelta(hours=1),
- ),
- False,
- ),
- ]:
- self.assertEqual(expected, bool(mi.numeric_range(*args)))
-
- def test_contains(self):
- for args, expected_in, expected_not_in in [
- ((10,), range(10), (0.5,)),
- ((1.0, 9.9, 1.5), (1.0, 2.5, 4.0, 5.5, 7.0, 8.5), (0.9,)),
- ((9.0, 1.0, -1.5), (1.5, 3.0, 4.5, 6.0, 7.5, 9.0), (0.0, 0.9)),
- (
- (Decimal("1.0"), Decimal("9.9"), Decimal("1.5")),
- (
- Decimal("1.0"),
- Decimal("2.5"),
- Decimal("4.0"),
- Decimal("5.5"),
- Decimal("7.0"),
- Decimal("8.5"),
- ),
- (Decimal("0.9"),),
- ),
- (
- (Fraction(0, 1), Fraction(5, 1), Fraction(1, 2)),
- (Fraction(0, 1), Fraction(1, 2), Fraction(9, 2)),
- (Fraction(10, 2),),
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=1),
- ),
- (datetime(2019, 3, 29, 15),),
- (datetime(2019, 3, 29, 15, 30),),
- ),
- ]:
- r = mi.numeric_range(*args)
- for v in expected_in:
- self.assertTrue(v in r)
- self.assertFalse(v not in r)
-
- for v in expected_not_in:
- self.assertFalse(v in r)
- self.assertTrue(v not in r)
-
- def test_eq(self):
- for args1, args2 in [
- ((0, 5, 2), (0, 6, 2)),
- ((1.0, 9.9, 1.5), (1.0, 8.6, 1.5)),
- ((8.5, 0.0, -1.5), (8.5, 0.7, -1.5)),
- ((7.0, 0.0, 1.0), (17.0, 7.0, 0.5)),
- (
- (Decimal("1.0"), Decimal("9.9"), Decimal("1.5")),
- (Decimal("1.0"), Decimal("8.6"), Decimal("1.5")),
- ),
- (
- (Fraction(1, 1), Fraction(10, 1), Fraction(3, 2)),
- (Fraction(1, 1), Fraction(9, 1), Fraction(3, 2)),
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30, 1),
- timedelta(hours=10),
- ),
- ),
- ]:
- self.assertEqual(
- mi.numeric_range(*args1), mi.numeric_range(*args2)
- )
-
- for args1, args2 in [
- ((0, 5, 2), (0, 7, 2)),
- ((1.0, 9.9, 1.5), (1.2, 9.9, 1.5)),
- ((1.0, 9.9, 1.5), (1.0, 10.3, 1.5)),
- ((1.0, 9.9, 1.5), (1.0, 9.9, 1.4)),
- ((8.5, 0.0, -1.5), (8.4, 0.0, -1.5)),
- ((8.5, 0.0, -1.5), (8.5, -0.7, -1.5)),
- ((8.5, 0.0, -1.5), (8.5, 0.0, -1.4)),
- ((0.0, 7.0, 1.0), (7.0, 0.0, 1.0)),
- (
- (Decimal("1.0"), Decimal("10.0"), Decimal("1.5")),
- (Decimal("1.0"), Decimal("10.5"), Decimal("1.5")),
- ),
- (
- (Fraction(1, 1), Fraction(10, 1), Fraction(3, 2)),
- (Fraction(1, 1), Fraction(21, 2), Fraction(3, 2)),
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30, 15),
- timedelta(hours=10),
- ),
- ),
- ]:
- self.assertNotEqual(
- mi.numeric_range(*args1), mi.numeric_range(*args2)
- )
-
- self.assertNotEqual(mi.numeric_range(7.0), 1)
- self.assertNotEqual(mi.numeric_range(7.0), "abc")
-
- def test_get_item_by_index(self):
- for args, index, expected in [
- ((1, 6), 2, 3),
- ((1.0, 6.0, 1.5), 0, 1.0),
- ((1.0, 6.0, 1.5), 1, 2.5),
- ((1.0, 6.0, 1.5), 2, 4.0),
- ((1.0, 6.0, 1.5), 3, 5.5),
- ((1.0, 6.0, 1.5), -1, 5.5),
- ((1.0, 6.0, 1.5), -2, 4.0),
- (
- (Decimal("1.0"), Decimal("9.0"), Decimal("1.5")),
- -1,
- Decimal("8.5"),
- ),
- (
- (Fraction(1, 1), Fraction(10, 1), Fraction(3, 2)),
- 2,
- Fraction(4, 1),
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- 1,
- datetime(2019, 3, 29, 10),
- ),
- ]:
- self.assertEqual(expected, mi.numeric_range(*args)[index])
-
- for args, index in [
- ((1.0, 6.0, 1.5), 4),
- ((1.0, 6.0, 1.5), -5),
- ((6.0, 1.0, 1.5), 0),
- ((6.0, 1.0, 1.5), -1),
- ((Decimal("1.0"), Decimal("9.0"), Decimal("-1.5")), -1),
- ((Fraction(1, 1), Fraction(2, 1), Fraction(3, 2)), 2),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- 8,
- ),
- ]:
- with self.assertRaises(IndexError):
- mi.numeric_range(*args)[index]
-
- def test_get_item_by_slice(self):
- for args, sl, expected_args in [
- ((1.0, 9.0, 1.5), slice(None, None, None), (1.0, 9.0, 1.5)),
- ((1.0, 9.0, 1.5), slice(None, 1, None), (1.0, 2.5, 1.5)),
- ((1.0, 9.0, 1.5), slice(None, None, 2), (1.0, 9.0, 3.0)),
- ((1.0, 9.0, 1.5), slice(None, 2, None), (1.0, 4.0, 1.5)),
- ((1.0, 9.0, 1.5), slice(1, 2, None), (2.5, 4.0, 1.5)),
- ((1.0, 9.0, 1.5), slice(1, -1, None), (2.5, 8.5, 1.5)),
- ((1.0, 9.0, 1.5), slice(10, None, 3), (9.0, 9.0, 4.5)),
- ((1.0, 9.0, 1.5), slice(-10, None, 3), (1.0, 9.0, 4.5)),
- ((1.0, 9.0, 1.5), slice(None, -10, 3), (1.0, 1.0, 4.5)),
- ((1.0, 9.0, 1.5), slice(None, 10, 3), (1.0, 9.0, 4.5)),
- (
- (Decimal("1.0"), Decimal("9.0"), Decimal("1.5")),
- slice(1, -1, None),
- (Decimal("2.5"), Decimal("8.5"), Decimal("1.5")),
- ),
- (
- (Fraction(1, 1), Fraction(5, 1), Fraction(3, 2)),
- slice(1, -1, None),
- (Fraction(5, 2), Fraction(4, 1), Fraction(3, 2)),
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- slice(1, -1, None),
- (
- datetime(2019, 3, 29, 10),
- datetime(2019, 3, 29, 20),
- timedelta(hours=10),
- ),
- ),
- ]:
- self.assertEqual(
- mi.numeric_range(*expected_args), mi.numeric_range(*args)[sl]
- )
-
- def test_hash(self):
- for args, expected in [
- ((1.0, 6.0, 1.5), hash((1.0, 5.5, 1.5))),
- ((1.0, 7.0, 1.5), hash((1.0, 5.5, 1.5))),
- ((1.0, 7.5, 1.5), hash((1.0, 7.0, 1.5))),
- ((1.0, 1.5, 1.5), hash((1.0, 1.0, 1.5))),
- ((1.5, 1.0, 1.5), hash(range(0, 0))),
- ((1.5, 1.5, 1.5), hash(range(0, 0))),
- (
- (Decimal("1.0"), Decimal("9.0"), Decimal("1.5")),
- hash((Decimal("1.0"), Decimal("8.5"), Decimal("1.5"))),
- ),
- (
- (Fraction(1, 1), Fraction(5, 1), Fraction(3, 2)),
- hash((Fraction(1, 1), Fraction(4, 1), Fraction(3, 2))),
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- hash(
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 29, 20),
- timedelta(hours=10),
- )
- ),
- ),
- ]:
- self.assertEqual(expected, hash(mi.numeric_range(*args)))
-
- def test_iter_twice(self):
- r1 = mi.numeric_range(1.0, 9.9, 1.5)
- r2 = mi.numeric_range(8.5, 0.0, -1.5)
- self.assertEqual([1.0, 2.5, 4.0, 5.5, 7.0, 8.5], list(r1))
- self.assertEqual([1.0, 2.5, 4.0, 5.5, 7.0, 8.5], list(r1))
- self.assertEqual([8.5, 7.0, 5.5, 4.0, 2.5, 1.0], list(r2))
- self.assertEqual([8.5, 7.0, 5.5, 4.0, 2.5, 1.0], list(r2))
-
- def test_len(self):
- for args, expected in [
- ((1.0, 7.0, 1.5), 4),
- ((1.0, 7.01, 1.5), 5),
- ((7.0, 1.0, -1.5), 4),
- ((7.01, 1.0, -1.5), 5),
- ((0.1, 0.30000000000000001, 0.2), 1), # IEE 754 !
- (
- (
- Decimal("0.1"),
- Decimal("0.30000000000000001"),
- Decimal("0.2"),
- ),
- 2,
- ), # works with Decimal
- ((Decimal("1.0"), Decimal("9.0"), Decimal("1.5")), 6),
- ((Fraction(1, 1), Fraction(5, 1), Fraction(3, 2)), 3),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- 3,
- ),
- ]:
- self.assertEqual(expected, len(mi.numeric_range(*args)))
-
- def test_repr(self):
- for args, *expected in [
- ((7.0,), "numeric_range(0.0, 7.0)"),
- ((1.0, 7.0), "numeric_range(1.0, 7.0)"),
- ((7.0, 1.0, -1.5), "numeric_range(7.0, 1.0, -1.5)"),
- (
- (Decimal("1.0"), Decimal("9.0"), Decimal("1.5")),
- (
- "numeric_range(Decimal('1.0'), Decimal('9.0'), "
- "Decimal('1.5'))"
- ),
- ),
- (
- (Fraction(7, 7), Fraction(10, 2), Fraction(3, 2)),
- (
- "numeric_range(Fraction(1, 1), Fraction(5, 1), "
- "Fraction(3, 2))"
- ),
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- "numeric_range(datetime.datetime(2019, 3, 29, 0, 0), "
- "datetime.datetime(2019, 3, 30, 0, 0), "
- "datetime.timedelta(seconds=36000))",
- "numeric_range(datetime.datetime(2019, 3, 29, 0, 0), "
- "datetime.datetime(2019, 3, 30, 0, 0), "
- "datetime.timedelta(0, 36000))",
- ),
- ]:
- with self.subTest(args=args):
- self.assertIn(repr(mi.numeric_range(*args)), expected)
-
- def test_reversed(self):
- for args, expected in [
- ((7.0,), [6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0]),
- ((1.0, 7.0), [6.0, 5.0, 4.0, 3.0, 2.0, 1.0]),
- ((7.0, 1.0, -1.5), [2.5, 4.0, 5.5, 7.0]),
- ((7.0, 0.9, -1.5), [1.0, 2.5, 4.0, 5.5, 7.0]),
- (
- (Decimal("1.0"), Decimal("5.0"), Decimal("1.5")),
- [Decimal('4.0'), Decimal('2.5'), Decimal('1.0')],
- ),
- (
- (Fraction(1, 1), Fraction(5, 1), Fraction(3, 2)),
- [Fraction(4, 1), Fraction(5, 2), Fraction(1, 1)],
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- [
- datetime(2019, 3, 29, 20),
- datetime(2019, 3, 29, 10),
- datetime(2019, 3, 29),
- ],
- ),
- ]:
- self.assertEqual(expected, list(reversed(mi.numeric_range(*args))))
-
- def test_count(self):
- for args, v, c in [
- ((7.0,), 0.0, 1),
- ((7.0,), 0.5, 0),
- ((7.0,), 6.0, 1),
- ((7.0,), 7.0, 0),
- ((7.0,), 10.0, 0),
- (
- (Decimal("1.0"), Decimal("5.0"), Decimal("1.5")),
- Decimal('4.0'),
- 1,
- ),
- (
- (Fraction(1, 1), Fraction(5, 1), Fraction(3, 2)),
- Fraction(5, 2),
- 1,
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- datetime(2019, 3, 29, 20),
- 1,
- ),
- ]:
- self.assertEqual(c, mi.numeric_range(*args).count(v))
-
- def test_index(self):
- for args, v, i in [
- ((7.0,), 0.0, 0),
- ((7.0,), 6.0, 6),
- ((7.0, 0.0, -1.0), 7.0, 0),
- ((7.0, 0.0, -1.0), 1.0, 6),
- (
- (Decimal("1.0"), Decimal("5.0"), Decimal("1.5")),
- Decimal('4.0'),
- 2,
- ),
- (
- (Fraction(1, 1), Fraction(5, 1), Fraction(3, 2)),
- Fraction(5, 2),
- 1,
- ),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- datetime(2019, 3, 29, 20),
- 2,
- ),
- ]:
- self.assertEqual(i, mi.numeric_range(*args).index(v))
-
- for args, v in [
- ((0.7,), 0.5),
- ((0.7,), 7.0),
- ((0.7,), 10.0),
- ((7.0, 0.0, -1.0), 0.5),
- ((7.0, 0.0, -1.0), 0.0),
- ((7.0, 0.0, -1.0), 10.0),
- ((7.0, 0.0), 5.0),
- ((Decimal("1.0"), Decimal("5.0"), Decimal("1.5")), Decimal('4.5')),
- ((Fraction(1, 1), Fraction(5, 1), Fraction(3, 2)), Fraction(5, 3)),
- (
- (
- datetime(2019, 3, 29),
- datetime(2019, 3, 30),
- timedelta(hours=10),
- ),
- datetime(2019, 3, 30),
- ),
- ]:
- with self.assertRaises(ValueError):
- mi.numeric_range(*args).index(v)
-
- def test_parent_classes(self):
- r = mi.numeric_range(7.0)
- self.assertTrue(isinstance(r, abc.Iterable))
- self.assertFalse(isinstance(r, abc.Iterator))
- self.assertTrue(isinstance(r, abc.Sequence))
- self.assertTrue(isinstance(r, abc.Hashable))
-
- def test_bad_key(self):
- r = mi.numeric_range(7.0)
- for arg, message in [
- ('a', 'numeric range indices must be integers or slices, not str'),
- (
- (),
- 'numeric range indices must be integers or slices, not tuple',
- ),
- ]:
- with self.assertRaisesRegex(TypeError, message):
- r[arg]
-
- def test_pickle(self):
- for args in [
- (7.0,),
- (5.0, 7.0),
- (5.0, 7.0, 3.0),
- (7.0, 5.0),
- (7.0, 5.0, 4.0),
- (7.0, 5.0, -1.0),
- (Decimal("1.0"), Decimal("5.0"), Decimal("1.5")),
- (Fraction(1, 1), Fraction(5, 1), Fraction(3, 2)),
- (datetime(2019, 3, 29), datetime(2019, 3, 30)),
- ]:
- r = mi.numeric_range(*args)
- self.assertTrue(dumps(r)) # assert not empty
- self.assertEqual(r, loads(dumps(r)))
-
-
-class CountCycleTests(TestCase):
- def test_basic(self):
- expected = [
- (0, 'a'),
- (0, 'b'),
- (0, 'c'),
- (1, 'a'),
- (1, 'b'),
- (1, 'c'),
- (2, 'a'),
- (2, 'b'),
- (2, 'c'),
- ]
- for actual in [
- mi.take(9, mi.count_cycle('abc')), # n=None
- list(mi.count_cycle('abc', 3)), # n=3
- ]:
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- self.assertEqual(list(mi.count_cycle('')), [])
- self.assertEqual(list(mi.count_cycle('', 2)), [])
-
- def test_negative(self):
- self.assertEqual(list(mi.count_cycle('abc', -3)), [])
-
-
-class MarkEndsTests(TestCase):
- def test_basic(self):
- for size, expected in [
- (0, []),
- (1, [(True, True, '0')]),
- (2, [(True, False, '0'), (False, True, '1')]),
- (3, [(True, False, '0'), (False, False, '1'), (False, True, '2')]),
- (
- 4,
- [
- (True, False, '0'),
- (False, False, '1'),
- (False, False, '2'),
- (False, True, '3'),
- ],
- ),
- ]:
- with self.subTest(size=size):
- iterable = map(str, range(size))
- actual = list(mi.mark_ends(iterable))
- self.assertEqual(actual, expected)
-
-
-class LocateTests(TestCase):
- def test_default_pred(self):
- iterable = [0, 1, 1, 0, 1, 0, 0]
- actual = list(mi.locate(iterable))
- expected = [1, 2, 4]
- self.assertEqual(actual, expected)
-
- def test_no_matches(self):
- iterable = [0, 0, 0]
- actual = list(mi.locate(iterable))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_custom_pred(self):
- iterable = ['0', 1, 1, '0', 1, '0', '0']
- pred = lambda x: x == '0'
- actual = list(mi.locate(iterable, pred))
- expected = [0, 3, 5, 6]
- self.assertEqual(actual, expected)
-
- def test_window_size(self):
- iterable = ['0', 1, 1, '0', 1, '0', '0']
- pred = lambda *args: args == ('0', 1)
- actual = list(mi.locate(iterable, pred, window_size=2))
- expected = [0, 3]
- self.assertEqual(actual, expected)
-
- def test_window_size_large(self):
- iterable = [1, 2, 3, 4]
- pred = lambda a, b, c, d, e: True
- actual = list(mi.locate(iterable, pred, window_size=5))
- expected = [0]
- self.assertEqual(actual, expected)
-
- def test_window_size_zero(self):
- iterable = [1, 2, 3, 4]
- pred = lambda: True
- with self.assertRaises(ValueError):
- list(mi.locate(iterable, pred, window_size=0))
-
-
-class StripFunctionTests(TestCase):
- def test_hashable(self):
- iterable = list('www.example.com')
- pred = lambda x: x in set('cmowz.')
-
- self.assertEqual(list(mi.lstrip(iterable, pred)), list('example.com'))
- self.assertEqual(list(mi.rstrip(iterable, pred)), list('www.example'))
- self.assertEqual(list(mi.strip(iterable, pred)), list('example'))
-
- def test_not_hashable(self):
- iterable = [
- list('http://'),
- list('www'),
- list('.example'),
- list('.com'),
- ]
- pred = lambda x: x in [list('http://'), list('www'), list('.com')]
-
- self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[2:])
- self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:3])
- self.assertEqual(list(mi.strip(iterable, pred)), iterable[2:3])
-
- def test_math(self):
- iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]
- pred = lambda x: x <= 2
-
- self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[3:])
- self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:-3])
- self.assertEqual(list(mi.strip(iterable, pred)), iterable[3:-3])
-
-
-class IsliceExtendedTests(TestCase):
- def test_all(self):
- iterable = ['0', '1', '2', '3', '4', '5']
- indexes = list(range(-4, len(iterable) + 4)) + [None]
- steps = [1, 2, 3, 4, -1, -2, -3, 4]
- for slice_args in product(indexes, indexes, steps):
- with self.subTest(slice_args=slice_args):
- actual = list(mi.islice_extended(iterable, *slice_args))
- expected = iterable[slice(*slice_args)]
- self.assertEqual(actual, expected, slice_args)
-
- def test_zero_step(self):
- with self.assertRaises(ValueError):
- list(mi.islice_extended([1, 2, 3], 0, 1, 0))
-
- def test_slicing(self):
- iterable = map(str, count())
- first_slice = mi.islice_extended(iterable)[10:]
- second_slice = mi.islice_extended(first_slice)[:10]
- third_slice = mi.islice_extended(second_slice)[::2]
- self.assertEqual(list(third_slice), ['10', '12', '14', '16', '18'])
-
- def test_slicing_extensive(self):
- iterable = range(10)
- options = (None, 1, 2, 7, -1)
- for start, stop, step in product(options, options, options):
- with self.subTest(slice_args=(start, stop, step)):
- sliced_tuple_0 = tuple(
- mi.islice_extended(iterable)[start:stop:step]
- )
- sliced_tuple_1 = tuple(
- mi.islice_extended(iterable, start, stop, step)
- )
- sliced_range = tuple(iterable[start:stop:step])
- self.assertEqual(sliced_tuple_0, sliced_range)
- self.assertEqual(sliced_tuple_1, sliced_range)
-
- def test_invalid_slice(self):
- with self.assertRaises(TypeError):
- mi.islice_extended(count())[13]
-
-
-class ConsecutiveGroupsTest(TestCase):
- def test_numbers(self):
- iterable = [-10, -8, -7, -6, 1, 2, 4, 5, -1, 7]
- actual = [list(g) for g in mi.consecutive_groups(iterable)]
- expected = [[-10], [-8, -7, -6], [1, 2], [4, 5], [-1], [7]]
- self.assertEqual(actual, expected)
-
- def test_custom_ordering(self):
- iterable = ['1', '10', '11', '20', '21', '22', '30', '31']
- ordering = lambda x: int(x)
- actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)]
- expected = [['1'], ['10', '11'], ['20', '21', '22'], ['30', '31']]
- self.assertEqual(actual, expected)
-
- def test_exotic_ordering(self):
- iterable = [
- ('a', 'b', 'c', 'd'),
- ('a', 'c', 'b', 'd'),
- ('a', 'c', 'd', 'b'),
- ('a', 'd', 'b', 'c'),
- ('d', 'b', 'c', 'a'),
- ('d', 'c', 'a', 'b'),
- ]
- ordering = list(permutations('abcd')).index
- actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)]
- expected = [
- [('a', 'b', 'c', 'd')],
- [('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b'), ('a', 'd', 'b', 'c')],
- [('d', 'b', 'c', 'a'), ('d', 'c', 'a', 'b')],
- ]
- self.assertEqual(actual, expected)
-
-
-class DifferenceTest(TestCase):
- def test_normal(self):
- iterable = [10, 20, 30, 40, 50]
- actual = list(mi.difference(iterable))
- expected = [10, 10, 10, 10, 10]
- self.assertEqual(actual, expected)
-
- def test_custom(self):
- iterable = [10, 20, 30, 40, 50]
- actual = list(mi.difference(iterable, add))
- expected = [10, 30, 50, 70, 90]
- self.assertEqual(actual, expected)
-
- def test_roundtrip(self):
- original = list(range(100))
- accumulated = accumulate(original)
- actual = list(mi.difference(accumulated))
- self.assertEqual(actual, original)
-
- def test_one(self):
- self.assertEqual(list(mi.difference([0])), [0])
-
- def test_empty(self):
- self.assertEqual(list(mi.difference([])), [])
-
- @skipIf(version_info[:2] < (3, 8), 'accumulate with initial needs 3.8+')
- def test_initial(self):
- original = list(range(100))
- accumulated = accumulate(original, initial=100)
- actual = list(mi.difference(accumulated, initial=100))
- self.assertEqual(actual, original)
-
-
-class SeekableTest(PeekableMixinTests, TestCase):
- cls = mi.seekable
-
- def test_exhaustion_reset(self):
- iterable = [str(n) for n in range(10)]
-
- s = mi.seekable(iterable)
- self.assertEqual(list(s), iterable) # Normal iteration
- self.assertEqual(list(s), []) # Iterable is exhausted
-
- s.seek(0)
- self.assertEqual(list(s), iterable) # Back in action
-
- def test_partial_reset(self):
- iterable = [str(n) for n in range(10)]
-
- s = mi.seekable(iterable)
- self.assertEqual(mi.take(5, s), iterable[:5]) # Normal iteration
-
- s.seek(1)
- self.assertEqual(list(s), iterable[1:]) # Get the rest of the iterable
-
- def test_forward(self):
- iterable = [str(n) for n in range(10)]
-
- s = mi.seekable(iterable)
- self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration
-
- s.seek(3) # Skip over index 2
- self.assertEqual(list(s), iterable[3:]) # Result is similar to slicing
-
- s.seek(0) # Back to 0
- self.assertEqual(list(s), iterable) # No difference in result
-
- def test_past_end(self):
- iterable = [str(n) for n in range(10)]
-
- s = mi.seekable(iterable)
- self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration
-
- s.seek(20)
- self.assertEqual(list(s), []) # Iterable is exhausted
-
- s.seek(0) # Back to 0
- self.assertEqual(list(s), iterable) # No difference in result
-
- def test_elements(self):
- iterable = map(str, count())
-
- s = mi.seekable(iterable)
- mi.take(10, s)
-
- elements = s.elements()
- self.assertEqual(
- [elements[i] for i in range(10)], [str(n) for n in range(10)]
- )
- self.assertEqual(len(elements), 10)
-
- mi.take(10, s)
- self.assertEqual(list(elements), [str(n) for n in range(20)])
-
- def test_maxlen(self):
- iterable = map(str, count())
-
- s = mi.seekable(iterable, maxlen=4)
- self.assertEqual(mi.take(10, s), [str(n) for n in range(10)])
- self.assertEqual(list(s.elements()), ['6', '7', '8', '9'])
-
- s.seek(0)
- self.assertEqual(mi.take(14, s), [str(n) for n in range(6, 20)])
- self.assertEqual(list(s.elements()), ['16', '17', '18', '19'])
-
- def test_maxlen_zero(self):
- iterable = [str(x) for x in range(5)]
- s = mi.seekable(iterable, maxlen=0)
- self.assertEqual(list(s), iterable)
- self.assertEqual(list(s.elements()), [])
-
-
-class SequenceViewTests(TestCase):
- def test_init(self):
- view = mi.SequenceView((1, 2, 3))
- self.assertEqual(repr(view), "SequenceView((1, 2, 3))")
- self.assertRaises(TypeError, lambda: mi.SequenceView({}))
-
- def test_update(self):
- seq = [1, 2, 3]
- view = mi.SequenceView(seq)
- self.assertEqual(len(view), 3)
- self.assertEqual(repr(view), "SequenceView([1, 2, 3])")
-
- seq.pop()
- self.assertEqual(len(view), 2)
- self.assertEqual(repr(view), "SequenceView([1, 2])")
-
- def test_indexing(self):
- seq = ('a', 'b', 'c', 'd', 'e', 'f')
- view = mi.SequenceView(seq)
- for i in range(-len(seq), len(seq)):
- self.assertEqual(view[i], seq[i])
-
- def test_slicing(self):
- seq = ('a', 'b', 'c', 'd', 'e', 'f')
- view = mi.SequenceView(seq)
- n = len(seq)
- indexes = list(range(-n - 1, n + 1)) + [None]
- steps = list(range(-n, n + 1))
- steps.remove(0)
- for slice_args in product(indexes, indexes, steps):
- i = slice(*slice_args)
- self.assertEqual(view[i], seq[i])
-
- def test_abc_methods(self):
- # collections.Sequence should provide all of this functionality
- seq = ('a', 'b', 'c', 'd', 'e', 'f', 'f')
- view = mi.SequenceView(seq)
-
- # __contains__
- self.assertIn('b', view)
- self.assertNotIn('g', view)
-
- # __iter__
- self.assertEqual(list(iter(view)), list(seq))
-
- # __reversed__
- self.assertEqual(list(reversed(view)), list(reversed(seq)))
-
- # index
- self.assertEqual(view.index('b'), 1)
-
- # count
- self.assertEqual(seq.count('f'), 2)
-
-
-class RunLengthTest(TestCase):
- def test_encode(self):
- iterable = (int(str(n)[0]) for n in count(800))
- actual = mi.take(4, mi.run_length.encode(iterable))
- expected = [(8, 100), (9, 100), (1, 1000), (2, 1000)]
- self.assertEqual(actual, expected)
-
- def test_decode(self):
- iterable = [('d', 4), ('c', 3), ('b', 2), ('a', 1)]
- actual = ''.join(mi.run_length.decode(iterable))
- expected = 'ddddcccbba'
- self.assertEqual(actual, expected)
-
-
-class ExactlyNTests(TestCase):
- """Tests for ``exactly_n()``"""
-
- def test_true(self):
- """Iterable has ``n`` ``True`` elements"""
- self.assertTrue(mi.exactly_n([True, False, True], 2))
- self.assertTrue(mi.exactly_n([1, 1, 1, 0], 3))
- self.assertTrue(mi.exactly_n([False, False], 0))
- self.assertTrue(mi.exactly_n(range(100), 10, lambda x: x < 10))
-
- def test_false(self):
- """Iterable does not have ``n`` ``True`` elements"""
- self.assertFalse(mi.exactly_n([True, False, False], 2))
- self.assertFalse(mi.exactly_n([True, True, False], 1))
- self.assertFalse(mi.exactly_n([False], 1))
- self.assertFalse(mi.exactly_n([True], -1))
- self.assertFalse(mi.exactly_n(repeat(True), 100))
-
- def test_empty(self):
- """Return ``True`` if the iterable is empty and ``n`` is 0"""
- self.assertTrue(mi.exactly_n([], 0))
- self.assertFalse(mi.exactly_n([], 1))
-
-
-class AlwaysReversibleTests(TestCase):
- """Tests for ``always_reversible()``"""
-
- def test_regular_reversed(self):
- self.assertEqual(
- list(reversed(range(10))), list(mi.always_reversible(range(10)))
- )
- self.assertEqual(
- list(reversed([1, 2, 3])), list(mi.always_reversible([1, 2, 3]))
- )
- self.assertEqual(
- reversed([1, 2, 3]).__class__,
- mi.always_reversible([1, 2, 3]).__class__,
- )
-
- def test_nonseq_reversed(self):
- # Create a non-reversible generator from a sequence
- with self.assertRaises(TypeError):
- reversed(x for x in range(10))
-
- self.assertEqual(
- list(reversed(range(10))),
- list(mi.always_reversible(x for x in range(10))),
- )
- self.assertEqual(
- list(reversed([1, 2, 3])),
- list(mi.always_reversible(x for x in [1, 2, 3])),
- )
- self.assertNotEqual(
- reversed((1, 2)).__class__,
- mi.always_reversible(x for x in (1, 2)).__class__,
- )
-
-
-class CircularShiftsTests(TestCase):
- def test_empty(self):
- # empty iterable -> empty list
- self.assertEqual(list(mi.circular_shifts([])), [])
-
- def test_simple_circular_shifts(self):
- # test the a simple iterator case
- self.assertEqual(
- mi.circular_shifts(range(4)),
- [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)],
- )
-
- def test_duplicates(self):
- # test non-distinct entries
- self.assertEqual(
- mi.circular_shifts([0, 1, 0, 1]),
- [(0, 1, 0, 1), (1, 0, 1, 0), (0, 1, 0, 1), (1, 0, 1, 0)],
- )
-
-
-class MakeDecoratorTests(TestCase):
- def test_basic(self):
- slicer = mi.make_decorator(islice)
-
- @slicer(1, 10, 2)
- def user_function(arg_1, arg_2, kwarg_1=None):
- self.assertEqual(arg_1, 'arg_1')
- self.assertEqual(arg_2, 'arg_2')
- self.assertEqual(kwarg_1, 'kwarg_1')
- return map(str, count())
-
- it = user_function('arg_1', 'arg_2', kwarg_1='kwarg_1')
- actual = list(it)
- expected = ['1', '3', '5', '7', '9']
- self.assertEqual(actual, expected)
-
- def test_result_index(self):
- def stringify(*args, **kwargs):
- self.assertEqual(args[0], 'arg_0')
- iterable = args[1]
- self.assertEqual(args[2], 'arg_2')
- self.assertEqual(kwargs['kwarg_1'], 'kwarg_1')
- return map(str, iterable)
-
- stringifier = mi.make_decorator(stringify, result_index=1)
-
- @stringifier('arg_0', 'arg_2', kwarg_1='kwarg_1')
- def user_function(n):
- return count(n)
-
- it = user_function(1)
- actual = mi.take(5, it)
- expected = ['1', '2', '3', '4', '5']
- self.assertEqual(actual, expected)
-
- def test_wrap_class(self):
- seeker = mi.make_decorator(mi.seekable)
-
- @seeker()
- def user_function(n):
- return map(str, range(n))
-
- it = user_function(5)
- self.assertEqual(list(it), ['0', '1', '2', '3', '4'])
-
- it.seek(0)
- self.assertEqual(list(it), ['0', '1', '2', '3', '4'])
-
-
-class MapReduceTests(TestCase):
- def test_default(self):
- iterable = (str(x) for x in range(5))
- keyfunc = lambda x: int(x) // 2
- actual = sorted(mi.map_reduce(iterable, keyfunc).items())
- expected = [(0, ['0', '1']), (1, ['2', '3']), (2, ['4'])]
- self.assertEqual(actual, expected)
-
- def test_valuefunc(self):
- iterable = (str(x) for x in range(5))
- keyfunc = lambda x: int(x) // 2
- valuefunc = int
- actual = sorted(mi.map_reduce(iterable, keyfunc, valuefunc).items())
- expected = [(0, [0, 1]), (1, [2, 3]), (2, [4])]
- self.assertEqual(actual, expected)
-
- def test_reducefunc(self):
- iterable = (str(x) for x in range(5))
- keyfunc = lambda x: int(x) // 2
- valuefunc = int
- reducefunc = lambda value_list: reduce(mul, value_list, 1)
- actual = sorted(
- mi.map_reduce(iterable, keyfunc, valuefunc, reducefunc).items()
- )
- expected = [(0, 0), (1, 6), (2, 4)]
- self.assertEqual(actual, expected)
-
- def test_ret(self):
- d = mi.map_reduce([1, 0, 2, 0, 1, 0], bool)
- self.assertEqual(d, {False: [0, 0, 0], True: [1, 2, 1]})
- self.assertRaises(KeyError, lambda: d[None].append(1))
-
-
-class RlocateTests(TestCase):
- def test_default_pred(self):
- iterable = [0, 1, 1, 0, 1, 0, 0]
- for it in (iterable[:], iter(iterable)):
- actual = list(mi.rlocate(it))
- expected = [4, 2, 1]
- self.assertEqual(actual, expected)
-
- def test_no_matches(self):
- iterable = [0, 0, 0]
- for it in (iterable[:], iter(iterable)):
- actual = list(mi.rlocate(it))
- expected = []
- self.assertEqual(actual, expected)
-
- def test_custom_pred(self):
- iterable = ['0', 1, 1, '0', 1, '0', '0']
- pred = lambda x: x == '0'
- for it in (iterable[:], iter(iterable)):
- actual = list(mi.rlocate(it, pred))
- expected = [6, 5, 3, 0]
- self.assertEqual(actual, expected)
-
- def test_efficient_reversal(self):
- iterable = range(9 ** 9) # Is efficiently reversible
- target = 9 ** 9 - 2
- pred = lambda x: x == target # Find-able from the right
- actual = next(mi.rlocate(iterable, pred))
- self.assertEqual(actual, target)
-
- def test_window_size(self):
- iterable = ['0', 1, 1, '0', 1, '0', '0']
- pred = lambda *args: args == ('0', 1)
- for it in (iterable, iter(iterable)):
- actual = list(mi.rlocate(it, pred, window_size=2))
- expected = [3, 0]
- self.assertEqual(actual, expected)
-
- def test_window_size_large(self):
- iterable = [1, 2, 3, 4]
- pred = lambda a, b, c, d, e: True
- for it in (iterable, iter(iterable)):
- actual = list(mi.rlocate(iterable, pred, window_size=5))
- expected = [0]
- self.assertEqual(actual, expected)
-
- def test_window_size_zero(self):
- iterable = [1, 2, 3, 4]
- pred = lambda: True
- for it in (iterable, iter(iterable)):
- with self.assertRaises(ValueError):
- list(mi.locate(iterable, pred, window_size=0))
-
-
-class ReplaceTests(TestCase):
- def test_basic(self):
- iterable = range(10)
- pred = lambda x: x % 2 == 0
- substitutes = []
- actual = list(mi.replace(iterable, pred, substitutes))
- expected = [1, 3, 5, 7, 9]
- self.assertEqual(actual, expected)
-
- def test_count(self):
- iterable = range(10)
- pred = lambda x: x % 2 == 0
- substitutes = []
- actual = list(mi.replace(iterable, pred, substitutes, count=4))
- expected = [1, 3, 5, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_window_size(self):
- iterable = range(10)
- pred = lambda *args: args == (0, 1, 2)
- substitutes = []
- actual = list(mi.replace(iterable, pred, substitutes, window_size=3))
- expected = [3, 4, 5, 6, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_window_size_end(self):
- iterable = range(10)
- pred = lambda *args: args == (7, 8, 9)
- substitutes = []
- actual = list(mi.replace(iterable, pred, substitutes, window_size=3))
- expected = [0, 1, 2, 3, 4, 5, 6]
- self.assertEqual(actual, expected)
-
- def test_window_size_count(self):
- iterable = range(10)
- pred = lambda *args: (args == (0, 1, 2)) or (args == (7, 8, 9))
- substitutes = []
- actual = list(
- mi.replace(iterable, pred, substitutes, count=1, window_size=3)
- )
- expected = [3, 4, 5, 6, 7, 8, 9]
- self.assertEqual(actual, expected)
-
- def test_window_size_large(self):
- iterable = range(4)
- pred = lambda a, b, c, d, e: True
- substitutes = [5, 6, 7]
- actual = list(mi.replace(iterable, pred, substitutes, window_size=5))
- expected = [5, 6, 7]
- self.assertEqual(actual, expected)
-
- def test_window_size_zero(self):
- iterable = range(10)
- pred = lambda *args: True
- substitutes = []
- with self.assertRaises(ValueError):
- list(mi.replace(iterable, pred, substitutes, window_size=0))
-
- def test_iterable_substitutes(self):
- iterable = range(5)
- pred = lambda x: x % 2 == 0
- substitutes = iter('__')
- actual = list(mi.replace(iterable, pred, substitutes))
- expected = ['_', '_', 1, '_', '_', 3, '_', '_']
- self.assertEqual(actual, expected)
-
-
-class PartitionsTest(TestCase):
- def test_types(self):
- for iterable in ['abcd', ['a', 'b', 'c', 'd'], ('a', 'b', 'c', 'd')]:
- with self.subTest(iterable=iterable):
- actual = list(mi.partitions(iterable))
- expected = [
- [['a', 'b', 'c', 'd']],
- [['a'], ['b', 'c', 'd']],
- [['a', 'b'], ['c', 'd']],
- [['a', 'b', 'c'], ['d']],
- [['a'], ['b'], ['c', 'd']],
- [['a'], ['b', 'c'], ['d']],
- [['a', 'b'], ['c'], ['d']],
- [['a'], ['b'], ['c'], ['d']],
- ]
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- iterable = []
- actual = list(mi.partitions(iterable))
- expected = [[[]]]
- self.assertEqual(actual, expected)
-
- def test_order(self):
- iterable = iter([3, 2, 1])
- actual = list(mi.partitions(iterable))
- expected = [[[3, 2, 1]], [[3], [2, 1]], [[3, 2], [1]], [[3], [2], [1]]]
- self.assertEqual(actual, expected)
-
- def test_duplicates(self):
- iterable = [1, 1, 1]
- actual = list(mi.partitions(iterable))
- expected = [[[1, 1, 1]], [[1], [1, 1]], [[1, 1], [1]], [[1], [1], [1]]]
- self.assertEqual(actual, expected)
-
-
-class _FrozenMultiset(Set):
- """
- A helper class, useful to compare two lists without reference to the order
- of elements.
-
- FrozenMultiset represents a hashable set that allows duplicate elements.
- """
-
- def __init__(self, iterable):
- self._collection = frozenset(Counter(iterable).items())
-
- def __contains__(self, y):
- """
- >>> (0, 1) in _FrozenMultiset([(0, 1), (2,), (0, 1)])
- True
- """
- return any(y == x for x, _ in self._collection)
-
- def __iter__(self):
- """
- >>> sorted(_FrozenMultiset([(0, 1), (2,), (0, 1)]))
- [(0, 1), (0, 1), (2,)]
- """
- return (x for x, c in self._collection for _ in range(c))
-
- def __len__(self):
- """
- >>> len(_FrozenMultiset([(0, 1), (2,), (0, 1)]))
- 3
- """
- return sum(c for x, c in self._collection)
-
- def has_duplicates(self):
- """
- >>> _FrozenMultiset([(0, 1), (2,), (0, 1)]).has_duplicates()
- True
- """
- return any(c != 1 for _, c in self._collection)
-
- def __hash__(self):
- return hash(self._collection)
-
- def __repr__(self):
- return "FrozenSet([{}]".format(", ".join(repr(x) for x in iter(self)))
-
-
-class SetPartitionsTests(TestCase):
- @staticmethod
- def _normalize_partition(p):
- """
- Return a normalized, hashable, version of a partition using
- _FrozenMultiset
- """
- return _FrozenMultiset(_FrozenMultiset(g) for g in p)
-
- @staticmethod
- def _normalize_partitions(ps):
- """
- Return a normalized set of all normalized partitions using
- _FrozenMultiset
- """
- return _FrozenMultiset(
- SetPartitionsTests._normalize_partition(p) for p in ps
- )
-
- def test_repeated(self):
- it = 'aaa'
- actual = mi.set_partitions(it, 2)
- expected = [['a', 'aa'], ['a', 'aa'], ['a', 'aa']]
- self.assertEqual(
- self._normalize_partitions(expected),
- self._normalize_partitions(actual),
- )
-
- def test_each_correct(self):
- a = set(range(6))
- for p in mi.set_partitions(a):
- total = {e for g in p for e in g}
- self.assertEqual(a, total)
-
- def test_duplicates(self):
- a = set(range(6))
- for p in mi.set_partitions(a):
- self.assertFalse(self._normalize_partition(p).has_duplicates())
-
- def test_found_all(self):
- """small example, hand-checked"""
- expected = [
- [[0], [1], [2, 3, 4]],
- [[0], [1, 2], [3, 4]],
- [[0], [2], [1, 3, 4]],
- [[0], [3], [1, 2, 4]],
- [[0], [4], [1, 2, 3]],
- [[0], [1, 3], [2, 4]],
- [[0], [1, 4], [2, 3]],
- [[1], [2], [0, 3, 4]],
- [[1], [3], [0, 2, 4]],
- [[1], [4], [0, 2, 3]],
- [[1], [0, 2], [3, 4]],
- [[1], [0, 3], [2, 4]],
- [[1], [0, 4], [2, 3]],
- [[2], [3], [0, 1, 4]],
- [[2], [4], [0, 1, 3]],
- [[2], [0, 1], [3, 4]],
- [[2], [0, 3], [1, 4]],
- [[2], [0, 4], [1, 3]],
- [[3], [4], [0, 1, 2]],
- [[3], [0, 1], [2, 4]],
- [[3], [0, 2], [1, 4]],
- [[3], [0, 4], [1, 2]],
- [[4], [0, 1], [2, 3]],
- [[4], [0, 2], [1, 3]],
- [[4], [0, 3], [1, 2]],
- ]
- actual = mi.set_partitions(range(5), 3)
- self.assertEqual(
- self._normalize_partitions(expected),
- self._normalize_partitions(actual),
- )
-
- def test_stirling_numbers(self):
- """Check against https://en.wikipedia.org/wiki/
- Stirling_numbers_of_the_second_kind#Table_of_values"""
- cardinality_by_k_by_n = [
- [1],
- [1, 1],
- [1, 3, 1],
- [1, 7, 6, 1],
- [1, 15, 25, 10, 1],
- [1, 31, 90, 65, 15, 1],
- ]
- for n, cardinality_by_k in enumerate(cardinality_by_k_by_n, 1):
- for k, cardinality in enumerate(cardinality_by_k, 1):
- self.assertEqual(
- cardinality, len(list(mi.set_partitions(range(n), k)))
- )
-
- def test_no_group(self):
- def helper():
- list(mi.set_partitions(range(4), -1))
-
- self.assertRaises(ValueError, helper)
-
- def test_to_many_groups(self):
- self.assertEqual([], list(mi.set_partitions(range(4), 5)))
-
-
-class TimeLimitedTests(TestCase):
- def test_basic(self):
- def generator():
- yield 1
- yield 2
- sleep(0.2)
- yield 3
-
- iterable = mi.time_limited(0.1, generator())
- actual = list(iterable)
- expected = [1, 2]
- self.assertEqual(actual, expected)
- self.assertTrue(iterable.timed_out)
-
- def test_complete(self):
- iterable = mi.time_limited(2, iter(range(10)))
- actual = list(iterable)
- expected = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
- self.assertEqual(actual, expected)
- self.assertFalse(iterable.timed_out)
-
- def test_zero_limit(self):
- iterable = mi.time_limited(0, count())
- actual = list(iterable)
- expected = []
- self.assertEqual(actual, expected)
- self.assertTrue(iterable.timed_out)
-
- def test_invalid_limit(self):
- with self.assertRaises(ValueError):
- list(mi.time_limited(-0.1, count()))
-
-
-class OnlyTests(TestCase):
- def test_defaults(self):
- self.assertEqual(mi.only([]), None)
- self.assertEqual(mi.only([1]), 1)
- self.assertRaises(ValueError, lambda: mi.only([1, 2]))
-
- def test_custom_value(self):
- self.assertEqual(mi.only([], default='!'), '!')
- self.assertEqual(mi.only([1], default='!'), 1)
- self.assertRaises(ValueError, lambda: mi.only([1, 2], default='!'))
-
- def test_custom_exception(self):
- self.assertEqual(mi.only([], too_long=RuntimeError), None)
- self.assertEqual(mi.only([1], too_long=RuntimeError), 1)
- self.assertRaises(
- RuntimeError, lambda: mi.only([1, 2], too_long=RuntimeError)
- )
-
- def test_default_exception_message(self):
- self.assertRaisesRegex(
- ValueError,
- "Expected exactly one item in iterable, "
- "but got 'foo', 'bar', and perhaps more",
- lambda: mi.only(['foo', 'bar', 'baz']),
- )
-
-
-class IchunkedTests(TestCase):
- def test_even(self):
- iterable = (str(x) for x in range(10))
- actual = [''.join(c) for c in mi.ichunked(iterable, 5)]
- expected = ['01234', '56789']
- self.assertEqual(actual, expected)
-
- def test_odd(self):
- iterable = (str(x) for x in range(10))
- actual = [''.join(c) for c in mi.ichunked(iterable, 4)]
- expected = ['0123', '4567', '89']
- self.assertEqual(actual, expected)
-
- def test_zero(self):
- iterable = []
- actual = [list(c) for c in mi.ichunked(iterable, 0)]
- expected = []
- self.assertEqual(actual, expected)
-
- def test_negative(self):
- iterable = count()
- with self.assertRaises(ValueError):
- [list(c) for c in mi.ichunked(iterable, -1)]
-
- def test_out_of_order(self):
- iterable = map(str, count())
- it = mi.ichunked(iterable, 4)
- chunk_1 = next(it)
- chunk_2 = next(it)
- self.assertEqual(''.join(chunk_2), '4567')
- self.assertEqual(''.join(chunk_1), '0123')
-
- def test_laziness(self):
- def gen():
- yield 0
- raise RuntimeError
- yield from count(1)
-
- it = mi.ichunked(gen(), 4)
- chunk = next(it)
- self.assertEqual(next(chunk), 0)
- self.assertRaises(RuntimeError, next, it)
-
-
-class DistinctCombinationsTests(TestCase):
- def test_basic(self):
- for iterable in [
- (1, 2, 2, 3, 3, 3), # In order
- range(6), # All distinct
- 'abbccc', # Not numbers
- 'cccbba', # Backward
- 'mississippi', # No particular order
- ]:
- for r in range(len(iterable)):
- with self.subTest(iterable=iterable, r=r):
- actual = list(mi.distinct_combinations(iterable, r))
- expected = list(
- mi.unique_everseen(combinations(iterable, r))
- )
- self.assertEqual(actual, expected)
-
- def test_negative(self):
- with self.assertRaises(ValueError):
- list(mi.distinct_combinations([], -1))
-
- def test_empty(self):
- self.assertEqual(list(mi.distinct_combinations([], 2)), [])
-
-
-class FilterExceptTests(TestCase):
- def test_no_exceptions_pass(self):
- iterable = '0123'
- actual = list(mi.filter_except(int, iterable))
- expected = ['0', '1', '2', '3']
- self.assertEqual(actual, expected)
-
- def test_no_exceptions_raise(self):
- iterable = ['0', '1', 'two', '3']
- with self.assertRaises(ValueError):
- list(mi.filter_except(int, iterable))
-
- def test_raise(self):
- iterable = ['0', '1' '2', 'three', None]
- with self.assertRaises(TypeError):
- list(mi.filter_except(int, iterable, ValueError))
-
- def test_false(self):
- # Even if the validator returns false, we pass through
- validator = lambda x: False
- iterable = ['0', '1', '2', 'three', None]
- actual = list(mi.filter_except(validator, iterable, Exception))
- expected = ['0', '1', '2', 'three', None]
- self.assertEqual(actual, expected)
-
- def test_multiple(self):
- iterable = ['0', '1', '2', 'three', None, '4']
- actual = list(mi.filter_except(int, iterable, ValueError, TypeError))
- expected = ['0', '1', '2', '4']
- self.assertEqual(actual, expected)
-
-
-class MapExceptTests(TestCase):
- def test_no_exceptions_pass(self):
- iterable = '0123'
- actual = list(mi.map_except(int, iterable))
- expected = [0, 1, 2, 3]
- self.assertEqual(actual, expected)
-
- def test_no_exceptions_raise(self):
- iterable = ['0', '1', 'two', '3']
- with self.assertRaises(ValueError):
- list(mi.map_except(int, iterable))
-
- def test_raise(self):
- iterable = ['0', '1' '2', 'three', None]
- with self.assertRaises(TypeError):
- list(mi.map_except(int, iterable, ValueError))
-
- def test_multiple(self):
- iterable = ['0', '1', '2', 'three', None, '4']
- actual = list(mi.map_except(int, iterable, ValueError, TypeError))
- expected = [0, 1, 2, 4]
- self.assertEqual(actual, expected)
-
-
-class MapIfTests(TestCase):
- def test_without_func_else(self):
- iterable = list(range(-5, 5))
- actual = list(mi.map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
- expected = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
- self.assertEqual(actual, expected)
-
- def test_with_func_else(self):
- iterable = list(range(-5, 5))
- actual = list(
- mi.map_if(
- iterable, lambda x: x >= 0, lambda x: 'notneg', lambda x: 'neg'
- )
- )
- expected = ['neg'] * 5 + ['notneg'] * 5
- self.assertEqual(actual, expected)
-
- def test_empty(self):
- actual = list(mi.map_if([], lambda x: len(x) > 5, lambda x: None))
- expected = []
- self.assertEqual(actual, expected)
-
-
-class SampleTests(TestCase):
- def test_unit_case(self):
- """Test against a fixed case by seeding the random module."""
- # Beware that this test really just verifies random.random() behavior.
- # If the algorithm is changed (e.g. to a more naive implementation)
- # this test will fail, but the algorithm might be correct.
- # Also, this test can pass and the algorithm can be completely wrong.
- data = "abcdef"
- weights = list(range(1, len(data) + 1))
- seed(123)
- actual = mi.sample(data, k=2, weights=weights)
- expected = ['f', 'e']
- self.assertEqual(actual, expected)
-
- def test_length(self):
- """Check that *k* elements are sampled."""
- data = [1, 2, 3, 4, 5]
- for k in [0, 3, 5, 7]:
- sampled = mi.sample(data, k=k)
- actual = len(sampled)
- expected = min(k, len(data))
- self.assertEqual(actual, expected)
-
- def test_samling_entire_iterable(self):
- """If k=len(iterable), the sample contains the original elements."""
- data = ["a", 2, "a", 4, (1, 2, 3)]
- actual = set(mi.sample(data, k=len(data)))
- expected = set(data)
- self.assertEqual(actual, expected)
-
- def test_scale_invariance_of_weights(self):
- """The probabilit of chosing element a_i is w_i / sum(weights).
- Scaling weights should not change the probability or outcome."""
- data = "abcdef"
-
- weights = list(range(1, len(data) + 1))
- seed(123)
- first_sample = mi.sample(data, k=2, weights=weights)
-
- # Scale the weights and sample again
- weights_scaled = [w / 1e10 for w in weights]
- seed(123)
- second_sample = mi.sample(data, k=2, weights=weights_scaled)
-
- self.assertEqual(first_sample, second_sample)
-
- def test_invariance_under_permutations_unweighted(self):
- """The order of the data should not matter. This is a stochastic test,
- but it will fail in less than 1 / 10_000 cases."""
-
- # Create a data set and a reversed data set
- data = list(range(100))
- data_rev = list(reversed(data))
-
- # Sample each data set 10 times
- data_means = [mean(mi.sample(data, k=50)) for _ in range(10)]
- data_rev_means = [mean(mi.sample(data_rev, k=50)) for _ in range(10)]
-
- # The difference in the means should be low, i.e. little bias
- difference_in_means = abs(mean(data_means) - mean(data_rev_means))
-
- # The observed largest difference in 10,000 simulations was 5.09599
- self.assertTrue(difference_in_means < 5.1)
-
- def test_invariance_under_permutations_weighted(self):
- """The order of the data should not matter. This is a stochastic test,
- but it will fail in less than 1 / 10_000 cases."""
-
- # Create a data set and a reversed data set
- data = list(range(1, 101))
- data_rev = list(reversed(data))
-
- # Sample each data set 10 times
- data_means = [
- mean(mi.sample(data, k=50, weights=data)) for _ in range(10)
- ]
- data_rev_means = [
- mean(mi.sample(data_rev, k=50, weights=data_rev))
- for _ in range(10)
- ]
-
- # The difference in the means should be low, i.e. little bias
- difference_in_means = abs(mean(data_means) - mean(data_rev_means))
-
- # The observed largest difference in 10,000 simulations was 4.337999
- self.assertTrue(difference_in_means < 4.4)
-
-
-class IsSortedTests(TestCase):
- def test_basic(self):
- for iterable, kwargs, expected in [
- ([], {}, True),
- ([1], {}, True),
- ([1, 2, 3], {}, True),
- ([1, 1, 2, 3], {}, True),
- ([1, 10, 2, 3], {}, False),
- (['1', '10', '2', '3'], {}, True),
- (['1', '10', '2', '3'], {'key': int}, False),
- ([1, 2, 3], {'reverse': True}, False),
- ([1, 1, 2, 3], {'reverse': True}, False),
- ([1, 10, 2, 3], {'reverse': True}, False),
- (['3', '2', '10', '1'], {'reverse': True}, True),
- (['3', '2', '10', '1'], {'key': int, 'reverse': True}, False),
- # strict
- ([], {'strict': True}, True),
- ([1], {'strict': True}, True),
- ([1, 1], {'strict': True}, False),
- ([1, 2, 3], {'strict': True}, True),
- ([1, 1, 2, 3], {'strict': True}, False),
- ([1, 10, 2, 3], {'strict': True}, False),
- (['1', '10', '2', '3'], {'strict': True}, True),
- (['1', '10', '2', '3', '3'], {'strict': True}, False),
- (['1', '10', '2', '3'], {'strict': True, 'key': int}, False),
- ([1, 2, 3], {'strict': True, 'reverse': True}, False),
- ([1, 1, 2, 3], {'strict': True, 'reverse': True}, False),
- ([1, 10, 2, 3], {'strict': True, 'reverse': True}, False),
- (['3', '2', '10', '1'], {'strict': True, 'reverse': True}, True),
- (
- ['3', '2', '10', '10', '1'],
- {'strict': True, 'reverse': True},
- False,
- ),
- (
- ['3', '2', '10', '1'],
- {'strict': True, 'key': int, 'reverse': True},
- False,
- ),
- # We'll do the same weird thing as Python here
- (['nan', 0, 'nan', 0], {'key': float}, True),
- ([0, 'nan', 0, 'nan'], {'key': float}, True),
- (['nan', 0, 'nan', 0], {'key': float, 'reverse': True}, True),
- ([0, 'nan', 0, 'nan'], {'key': float, 'reverse': True}, True),
- ([0, 'nan', 0, 'nan'], {'strict': True, 'key': float}, True),
- (
- ['nan', 0, 'nan', 0],
- {'strict': True, 'key': float, 'reverse': True},
- True,
- ),
- ]:
- key = kwargs.get('key', None)
- reverse = kwargs.get('reverse', False)
- strict = kwargs.get('strict', False)
-
- with self.subTest(
- iterable=iterable, key=key, reverse=reverse, strict=strict
- ):
- mi_result = mi.is_sorted(
- iter(iterable), key=key, reverse=reverse, strict=strict
- )
-
- sorted_iterable = sorted(iterable, key=key, reverse=reverse)
- if strict:
- sorted_iterable = list(mi.unique_justseen(sorted_iterable))
-
- py_result = iterable == sorted_iterable
-
- self.assertEqual(mi_result, expected)
- self.assertEqual(mi_result, py_result)
-
-
-class CallbackIterTests(TestCase):
- def _target(self, cb=None, exc=None, wait=0):
- total = 0
- for i, c in enumerate('abc', 1):
- total += i
- if wait:
- sleep(wait)
- if cb:
- cb(i, c, intermediate_total=total)
- if exc:
- raise exc('error in target')
-
- return total
-
- def test_basic(self):
- func = lambda callback=None: self._target(cb=callback, wait=0.02)
- with mi.callback_iter(func, wait_seconds=0.01) as it:
- # Execution doesn't start until we begin iterating
- self.assertFalse(it.done)
-
- # Consume everything
- self.assertEqual(
- list(it),
- [
- ((1, 'a'), {'intermediate_total': 1}),
- ((2, 'b'), {'intermediate_total': 3}),
- ((3, 'c'), {'intermediate_total': 6}),
- ],
- )
-
- # After consuming everything the future is done and the
- # result is available.
- self.assertTrue(it.done)
- self.assertEqual(it.result, 6)
-
- # This examines the internal state of the ThreadPoolExecutor. This
- # isn't documented, so may break in future Python versions.
- self.assertTrue(it._executor._shutdown)
-
- def test_callback_kwd(self):
- with mi.callback_iter(self._target, callback_kwd='cb') as it:
- self.assertEqual(
- list(it),
- [
- ((1, 'a'), {'intermediate_total': 1}),
- ((2, 'b'), {'intermediate_total': 3}),
- ((3, 'c'), {'intermediate_total': 6}),
- ],
- )
-
- def test_partial_consumption(self):
- func = lambda callback=None: self._target(cb=callback)
- with mi.callback_iter(func) as it:
- self.assertEqual(next(it), ((1, 'a'), {'intermediate_total': 1}))
-
- self.assertTrue(it._executor._shutdown)
-
- def test_abort(self):
- func = lambda callback=None: self._target(cb=callback, wait=0.1)
- with mi.callback_iter(func) as it:
- self.assertEqual(next(it), ((1, 'a'), {'intermediate_total': 1}))
-
- with self.assertRaises(mi.AbortThread):
- it.result
-
- def test_no_result(self):
- func = lambda callback=None: self._target(cb=callback)
- with mi.callback_iter(func) as it:
- with self.assertRaises(RuntimeError):
- it.result
-
- def test_exception(self):
- func = lambda callback=None: self._target(cb=callback, exc=ValueError)
- with mi.callback_iter(func) as it:
- self.assertEqual(
- next(it),
- ((1, 'a'), {'intermediate_total': 1}),
- )
-
- with self.assertRaises(ValueError):
- it.result
-
-
-class WindowedCompleteTests(TestCase):
- """Tests for ``windowed_complete()``"""
-
- def test_basic(self):
- actual = list(mi.windowed_complete([1, 2, 3, 4, 5], 3))
- expected = [
- ((), (1, 2, 3), (4, 5)),
- ((1,), (2, 3, 4), (5,)),
- ((1, 2), (3, 4, 5), ()),
- ]
- self.assertEqual(actual, expected)
-
- def test_zero_length(self):
- actual = list(mi.windowed_complete([1, 2, 3], 0))
- expected = [
- ((), (), (1, 2, 3)),
- ((1,), (), (2, 3)),
- ((1, 2), (), (3,)),
- ((1, 2, 3), (), ()),
- ]
- self.assertEqual(actual, expected)
-
- def test_wrong_length(self):
- seq = [1, 2, 3, 4, 5]
- for n in (-10, -1, len(seq) + 1, len(seq) + 10):
- with self.subTest(n=n):
- with self.assertRaises(ValueError):
- list(mi.windowed_complete(seq, n))
-
- def test_every_partition(self):
- every_partition = lambda seq: chain(
- *map(partial(mi.windowed_complete, seq), range(len(seq)))
- )
-
- seq = 'ABC'
- actual = list(every_partition(seq))
- expected = [
- ((), (), ('A', 'B', 'C')),
- (('A',), (), ('B', 'C')),
- (('A', 'B'), (), ('C',)),
- (('A', 'B', 'C'), (), ()),
- ((), ('A',), ('B', 'C')),
- (('A',), ('B',), ('C',)),
- (('A', 'B'), ('C',), ()),
- ((), ('A', 'B'), ('C',)),
- (('A',), ('B', 'C'), ()),
- ]
- self.assertEqual(actual, expected)
-
-
-class AllUniqueTests(TestCase):
- def test_basic(self):
- for iterable, expected in [
- ([], True),
- ([1, 2, 3], True),
- ([1, 1], False),
- ([1, 2, 3, 1], False),
- ([1, 2, 3, '1'], True),
- ]:
- with self.subTest(args=(iterable,)):
- self.assertEqual(mi.all_unique(iterable), expected)
-
- def test_non_hashable(self):
- self.assertEqual(mi.all_unique([[1, 2], [3, 4]]), True)
- self.assertEqual(mi.all_unique([[1, 2], [3, 4], [1, 2]]), False)
-
- def test_partially_hashable(self):
- self.assertEqual(mi.all_unique([[1, 2], [3, 4], (5, 6)]), True)
- self.assertEqual(
- mi.all_unique([[1, 2], [3, 4], (5, 6), [1, 2]]), False
- )
- self.assertEqual(
- mi.all_unique([[1, 2], [3, 4], (5, 6), (5, 6)]), False
- )
-
- def test_key(self):
- iterable = ['A', 'B', 'C', 'b']
- self.assertEqual(mi.all_unique(iterable, lambda x: x), True)
- self.assertEqual(mi.all_unique(iterable, str.lower), False)
-
- def test_infinite(self):
- self.assertEqual(mi.all_unique(mi.prepend(3, count())), False)
-
-
-class NthProductTests(TestCase):
- def test_basic(self):
- iterables = ['ab', 'cdef', 'ghi']
- for index, expected in enumerate(product(*iterables)):
- actual = mi.nth_product(index, *iterables)
- self.assertEqual(actual, expected)
-
- def test_long(self):
- actual = mi.nth_product(1337, range(101), range(22), range(53))
- expected = (1, 3, 12)
- self.assertEqual(actual, expected)
-
- def test_negative(self):
- iterables = ['abc', 'de', 'fghi']
- for index, expected in enumerate(product(*iterables)):
- actual = mi.nth_product(index - 24, *iterables)
- self.assertEqual(actual, expected)
-
- def test_invalid_index(self):
- with self.assertRaises(IndexError):
- mi.nth_product(24, 'ab', 'cde', 'fghi')
-
-
-class ValueChainTests(TestCase):
- def test_empty(self):
- actual = list(mi.value_chain())
- expected = []
- self.assertEqual(actual, expected)
-
- def test_simple(self):
- actual = list(mi.value_chain(1, 2.71828, False, 'foo'))
- expected = [1, 2.71828, False, 'foo']
- self.assertEqual(actual, expected)
-
- def test_more(self):
- actual = list(mi.value_chain(b'bar', [1, 2, 3], 4, {'key': 1}))
- expected = [b'bar', 1, 2, 3, 4, 'key']
- self.assertEqual(actual, expected)
-
- def test_empty_lists(self):
- actual = list(mi.value_chain(1, 2, [], [3, 4]))
- expected = [1, 2, 3, 4]
- self.assertEqual(actual, expected)
-
- def test_complex(self):
- obj = object()
- actual = list(
- mi.value_chain(
- (1, (2, (3,))),
- ['foo', ['bar', ['baz']], 'tic'],
- {'key': {'foo': 1}},
- obj,
- )
- )
- expected = [1, (2, (3,)), 'foo', ['bar', ['baz']], 'tic', 'key', obj]
- self.assertEqual(actual, expected)
-
-
-class ProductIndexTests(TestCase):
- def test_basic(self):
- iterables = ['ab', 'cdef', 'ghi']
- first_index = {}
- for index, element in enumerate(product(*iterables)):
- actual = mi.product_index(element, *iterables)
- expected = first_index.setdefault(element, index)
- self.assertEqual(actual, expected)
-
- def test_multiplicity(self):
- iterables = ['ab', 'bab', 'cab']
- first_index = {}
- for index, element in enumerate(product(*iterables)):
- actual = mi.product_index(element, *iterables)
- expected = first_index.setdefault(element, index)
- self.assertEqual(actual, expected)
-
- def test_long(self):
- actual = mi.product_index((1, 3, 12), range(101), range(22), range(53))
- expected = 1337
- self.assertEqual(actual, expected)
-
- def test_invalid_empty(self):
- with self.assertRaises(ValueError):
- mi.product_index('', 'ab', 'cde', 'fghi')
-
- def test_invalid_small(self):
- with self.assertRaises(ValueError):
- mi.product_index('ac', 'ab', 'cde', 'fghi')
-
- def test_invalid_large(self):
- with self.assertRaises(ValueError):
- mi.product_index('achi', 'ab', 'cde', 'fghi')
-
- def test_invalid_match(self):
- with self.assertRaises(ValueError):
- mi.product_index('axf', 'ab', 'cde', 'fghi')
-
-
-class CombinationIndexTests(TestCase):
- def test_r_less_than_n(self):
- iterable = 'abcdefg'
- r = 4
- first_index = {}
- for index, element in enumerate(combinations(iterable, r)):
- actual = mi.combination_index(element, iterable)
- expected = first_index.setdefault(element, index)
- self.assertEqual(actual, expected)
-
- def test_r_equal_to_n(self):
- iterable = 'abcd'
- r = len(iterable)
- first_index = {}
- for index, element in enumerate(combinations(iterable, r=r)):
- actual = mi.combination_index(element, iterable)
- expected = first_index.setdefault(element, index)
- self.assertEqual(actual, expected)
-
- def test_multiplicity(self):
- iterable = 'abacba'
- r = 3
- first_index = {}
- for index, element in enumerate(combinations(iterable, r)):
- actual = mi.combination_index(element, iterable)
- expected = first_index.setdefault(element, index)
- self.assertEqual(actual, expected)
-
- def test_null(self):
- actual = mi.combination_index(tuple(), [])
- expected = 0
- self.assertEqual(actual, expected)
-
- def test_long(self):
- actual = mi.combination_index((2, 12, 35, 126), range(180))
- expected = 2000000
- self.assertEqual(actual, expected)
-
- def test_invalid_order(self):
- with self.assertRaises(ValueError):
- mi.combination_index(tuple('acb'), 'abcde')
-
- def test_invalid_large(self):
- with self.assertRaises(ValueError):
- mi.combination_index(tuple('abcdefg'), 'abcdef')
-
- def test_invalid_match(self):
- with self.assertRaises(ValueError):
- mi.combination_index(tuple('axe'), 'abcde')
-
-
-class PermutationIndexTests(TestCase):
- def test_r_less_than_n(self):
- iterable = 'abcdefg'
- r = 4
- first_index = {}
- for index, element in enumerate(permutations(iterable, r)):
- actual = mi.permutation_index(element, iterable)
- expected = first_index.setdefault(element, index)
- self.assertEqual(actual, expected)
-
- def test_r_equal_to_n(self):
- iterable = 'abcd'
- first_index = {}
- for index, element in enumerate(permutations(iterable)):
- actual = mi.permutation_index(element, iterable)
- expected = first_index.setdefault(element, index)
- self.assertEqual(actual, expected)
-
- def test_multiplicity(self):
- iterable = 'abacba'
- r = 3
- first_index = {}
- for index, element in enumerate(permutations(iterable, r)):
- actual = mi.permutation_index(element, iterable)
- expected = first_index.setdefault(element, index)
- self.assertEqual(actual, expected)
-
- def test_null(self):
- actual = mi.permutation_index(tuple(), [])
- expected = 0
- self.assertEqual(actual, expected)
-
- def test_long(self):
- actual = mi.permutation_index((2, 12, 35, 126), range(180))
- expected = 11631678
- self.assertEqual(actual, expected)
-
- def test_invalid_large(self):
- with self.assertRaises(ValueError):
- mi.permutation_index(tuple('abcdefg'), 'abcdef')
-
- def test_invalid_match(self):
- with self.assertRaises(ValueError):
- mi.permutation_index(tuple('axe'), 'abcde')
-
-
-class CountableTests(TestCase):
- def test_empty(self):
- iterable = []
- it = mi.countable(iterable)
- self.assertEqual(it.items_seen, 0)
- self.assertEqual(list(it), [])
-
- def test_basic(self):
- iterable = '0123456789'
- it = mi.countable(iterable)
- self.assertEqual(it.items_seen, 0)
- self.assertEqual(next(it), '0')
- self.assertEqual(it.items_seen, 1)
- self.assertEqual(''.join(it), '123456789')
- self.assertEqual(it.items_seen, 10)
-
-
-class ChunkedEvenTests(TestCase):
- """Tests for ``chunked_even()``"""
-
- def test_0(self):
- self._test_finite('', 3, [])
-
- def test_1(self):
- self._test_finite('A', 1, [['A']])
-
- def test_4(self):
- self._test_finite('ABCD', 3, [['A', 'B'], ['C', 'D']])
-
- def test_5(self):
- self._test_finite('ABCDE', 3, [['A', 'B', 'C'], ['D', 'E']])
-
- def test_6(self):
- self._test_finite('ABCDEF', 3, [['A', 'B', 'C'], ['D', 'E', 'F']])
-
- def test_7(self):
- self._test_finite(
- 'ABCDEFG', 3, [['A', 'B', 'C'], ['D', 'E'], ['F', 'G']]
- )
-
- def _test_finite(self, seq, n, expected):
- # Check with and without `len()`
- self.assertEqual(list(mi.chunked_even(seq, n)), expected)
- self.assertEqual(list(mi.chunked_even(iter(seq), n)), expected)
-
- def test_infinite(self):
- for n in range(1, 5):
- k = 0
-
- def count_with_assert():
- for i in count():
- # Look-ahead should be less than n^2
- self.assertLessEqual(i, n * k + n * n)
- yield i
-
- ls = mi.chunked_even(count_with_assert(), n)
- while k < 2:
- self.assertEqual(next(ls), list(range(k * n, (k + 1) * n)))
- k += 1
-
- def test_evenness(self):
- for N in range(1, 50):
- for n in range(1, N + 2):
- lengths = []
- items = []
- for l in mi.chunked_even(range(N), n):
- L = len(l)
- self.assertLessEqual(L, n)
- self.assertGreaterEqual(L, 1)
- lengths.append(L)
- items.extend(l)
- self.assertEqual(items, list(range(N)))
- self.assertLessEqual(max(lengths) - min(lengths), 1)
-
-
-class ZipBroadcastTests(TestCase):
- def test_basic(self):
- for objects, expected in [
- # All scalar
- ([1, 2], [(1, 2)]),
- # Scalar, iterable
- ([1, [2]], [(1, 2)]),
- # Iterable, scalar
- ([[1], 2], [(1, 2)]),
- # Mixed length
- ([1, [2, 3]], [(1, 2), (1, 3)]),
- # All iterable
- ([[1, 2], [3, 4]], [(1, 3), (2, 4)]),
- # Infinite
- ([count(), 1, [2]], [(0, 1, 2)]),
- ([count(), 1, [2, 3]], [(0, 1, 2), (1, 1, 3)]),
- ]:
- with self.subTest(expected=expected):
- actual = list(mi.zip_broadcast(*objects))
- self.assertEqual(actual, expected)
-
- def test_scalar_types(self):
- # Default: str and bytes are treated as scalar
- self.assertEqual(
- list(mi.zip_broadcast('ab', [1, 2, 3])),
- [('ab', 1), ('ab', 2), ('ab', 3)],
- )
- self.assertEqual(
- list(mi.zip_broadcast(b'ab', [1, 2, 3])),
- [(b'ab', 1), (b'ab', 2), (b'ab', 3)],
- )
- # scalar_types=None allows str and bytes to be treated as iterable
- self.assertEqual(
- list(mi.zip_broadcast('abc', [1, 2, 3], scalar_types=None)),
- [('a', 1), ('b', 2), ('c', 3)],
- )
- # Use a custom type
- self.assertEqual(
- list(mi.zip_broadcast({'a': 'b'}, [1, 2, 3], scalar_types=dict)),
- [({'a': 'b'}, 1), ({'a': 'b'}, 2), ({'a': 'b'}, 3)],
- )
-
- def test_strict(self):
- for objects, zipped in [
- ([[], [1]], []),
- ([[1], []], []),
- ([[1], [2, 3]], [(1, 2)]),
- ([[1, 2], [3]], [(1, 3)]),
- ([[1, 2], [3], [4]], [(1, 3, 4)]),
- ([[1], [2, 3], [4]], [(1, 2, 4)]),
- ([[1], [2], [3, 4]], [(1, 2, 3)]),
- ([[1], [2, 3], [4, 5]], [(1, 2, 4)]),
- ([[1, 2], [3], [4, 5]], [(1, 3, 4)]),
- ([[1, 2], [3, 4], [5]], [(1, 3, 5)]),
- (['a', [1, 2], [3, 4, 5]], [('a', 1, 3), ('a', 2, 4)]),
- ]:
- # Truncate by default
- with self.subTest(objects=objects, strict=False, zipped=zipped):
- self.assertEqual(list(mi.zip_broadcast(*objects)), zipped)
-
- # Raise an exception for strict=True
- with self.subTest(objects=objects, strict=True):
- with self.assertRaises(ValueError):
- list(mi.zip_broadcast(*objects, strict=True))
-
- def test_empty(self):
- self.assertEqual(list(mi.zip_broadcast()), [])
-
-
-class UniqueInWindowTests(TestCase):
- def test_invalid_n(self):
- with self.assertRaises(ValueError):
- list(mi.unique_in_window([], 0))
-
- def test_basic(self):
- for iterable, n, expected in [
- (range(9), 10, list(range(9))),
- (range(20), 10, list(range(20))),
- ([1, 2, 3, 4, 4, 4], 1, [1, 2, 3, 4]),
- ([1, 2, 3, 4, 4, 4], 2, [1, 2, 3, 4]),
- ([1, 2, 3, 4, 4, 4], 3, [1, 2, 3, 4]),
- ([1, 2, 3, 4, 4, 4], 4, [1, 2, 3, 4]),
- ([1, 2, 3, 4, 4, 4], 5, [1, 2, 3, 4]),
- ]:
- with self.subTest(expected=expected):
- actual = list(mi.unique_in_window(iterable, n))
- self.assertEqual(actual, expected)
-
- def test_key(self):
- iterable = [0, 1, 3, 4, 5, 6, 7, 8, 9]
- n = 3
- key = lambda x: x // 3
- actual = list(mi.unique_in_window(iterable, n, key=key))
- expected = [0, 3, 6, 9]
- self.assertEqual(actual, expected)
-
-
-class StrictlyNTests(TestCase):
- def test_basic(self):
- iterable = ['a', 'b', 'c', 'd']
- n = 4
- actual = list(mi.strictly_n(iter(iterable), n))
- expected = iterable
- self.assertEqual(actual, expected)
-
- def test_too_short_default(self):
- iterable = ['a', 'b', 'c', 'd']
- n = 5
- with self.assertRaises(ValueError) as exc:
- list(mi.strictly_n(iter(iterable), n))
-
- self.assertEqual(
- 'Too few items in iterable (got 4)', exc.exception.args[0]
- )
-
- def test_too_long_default(self):
- iterable = ['a', 'b', 'c', 'd']
- n = 3
- with self.assertRaises(ValueError) as cm:
- list(mi.strictly_n(iter(iterable), n))
-
- self.assertEqual(
- 'Too many items in iterable (got at least 4)',
- cm.exception.args[0],
- )
-
- def test_too_short_custom(self):
- call_count = 0
-
- def too_short(item_count):
- nonlocal call_count
- call_count += 1
-
- iterable = ['a', 'b', 'c', 'd']
- n = 6
- actual = []
- for item in mi.strictly_n(iter(iterable), n, too_short=too_short):
- actual.append(item)
- expected = ['a', 'b', 'c', 'd']
- self.assertEqual(actual, expected)
- self.assertEqual(call_count, 1)
-
- def test_too_long_custom(self):
- import logging
-
- iterable = ['a', 'b', 'c', 'd']
- n = 2
- too_long = lambda item_count: logging.warning(
- 'Picked the first %s items', n
- )
-
- with self.assertLogs(level='WARNING') as cm:
- actual = list(mi.strictly_n(iter(iterable), n, too_long=too_long))
-
- self.assertEqual(actual, ['a', 'b'])
- self.assertIn('Picked the first 2 items', cm.output[0])
-
-
-class DuplicatesEverSeenTests(TestCase):
- def test_basic(self):
- for iterable, expected in [
- ([], []),
- ([1, 2, 3], []),
- ([1, 1], [1]),
- ([1, 2, 1, 2], [1, 2]),
- ([1, 2, 3, '1'], []),
- ]:
- with self.subTest(args=(iterable,)):
- self.assertEqual(
- list(mi.duplicates_everseen(iterable)), expected
- )
-
- def test_non_hashable(self):
- self.assertEqual(list(mi.duplicates_everseen([[1, 2], [3, 4]])), [])
- self.assertEqual(
- list(mi.duplicates_everseen([[1, 2], [3, 4], [1, 2]])), [[1, 2]]
- )
-
- def test_partially_hashable(self):
- self.assertEqual(
- list(mi.duplicates_everseen([[1, 2], [3, 4], (5, 6)])), []
- )
- self.assertEqual(
- list(mi.duplicates_everseen([[1, 2], [3, 4], (5, 6), [1, 2]])),
- [[1, 2]],
- )
- self.assertEqual(
- list(mi.duplicates_everseen([[1, 2], [3, 4], (5, 6), (5, 6)])),
- [(5, 6)],
- )
-
- def test_key_hashable(self):
- iterable = 'HEheHEhe'
- self.assertEqual(list(mi.duplicates_everseen(iterable)), list('HEhe'))
- self.assertEqual(
- list(mi.duplicates_everseen(iterable, str.lower)),
- list('heHEhe'),
- )
-
- def test_key_non_hashable(self):
- iterable = [[1, 2], [3, 0], [5, -2], [5, 6]]
- self.assertEqual(
- list(mi.duplicates_everseen(iterable, lambda x: x)), []
- )
- self.assertEqual(
- list(mi.duplicates_everseen(iterable, sum)), [[3, 0], [5, -2]]
- )
-
- def test_key_partially_hashable(self):
- iterable = [[1, 2], (1, 2), [1, 2], [5, 6]]
- self.assertEqual(
- list(mi.duplicates_everseen(iterable, lambda x: x)), [[1, 2]]
- )
- self.assertEqual(
- list(mi.duplicates_everseen(iterable, list)), [(1, 2), [1, 2]]
- )
-
-
-class DuplicatesJustSeenTests(TestCase):
- def test_basic(self):
- for iterable, expected in [
- ([], []),
- ([1, 2, 3, 3, 2, 2], [3, 2]),
- ([1, 1], [1]),
- ([1, 2, 1, 2], []),
- ([1, 2, 3, '1'], []),
- ]:
- with self.subTest(args=(iterable,)):
- self.assertEqual(
- list(mi.duplicates_justseen(iterable)), expected
- )
-
- def test_non_hashable(self):
- self.assertEqual(list(mi.duplicates_justseen([[1, 2], [3, 4]])), [])
- self.assertEqual(
- list(
- mi.duplicates_justseen(
- [[1, 2], [3, 4], [3, 4], [3, 4], [1, 2]]
- )
- ),
- [[3, 4], [3, 4]],
- )
-
- def test_partially_hashable(self):
- self.assertEqual(
- list(mi.duplicates_justseen([[1, 2], [3, 4], (5, 6)])), []
- )
- self.assertEqual(
- list(
- mi.duplicates_justseen(
- [[1, 2], [3, 4], (5, 6), [1, 2], [1, 2]]
- )
- ),
- [[1, 2]],
- )
- self.assertEqual(
- list(
- mi.duplicates_justseen(
- [[1, 2], [3, 4], (5, 6), (5, 6), (5, 6)]
- )
- ),
- [(5, 6), (5, 6)],
- )
-
- def test_key_hashable(self):
- iterable = 'HEheHHHhEheeEe'
- self.assertEqual(list(mi.duplicates_justseen(iterable)), list('HHe'))
- self.assertEqual(
- list(mi.duplicates_justseen(iterable, str.lower)),
- list('HHheEe'),
- )
-
- def test_key_non_hashable(self):
- iterable = [[1, 2], [3, 0], [5, -2], [5, 6], [1, 2]]
- self.assertEqual(
- list(mi.duplicates_justseen(iterable, lambda x: x)), []
- )
- self.assertEqual(
- list(mi.duplicates_justseen(iterable, sum)), [[3, 0], [5, -2]]
- )
-
- def test_key_partially_hashable(self):
- iterable = [[1, 2], (1, 2), [1, 2], [5, 6], [1, 2]]
- self.assertEqual(
- list(mi.duplicates_justseen(iterable, lambda x: x)), []
- )
- self.assertEqual(
- list(mi.duplicates_justseen(iterable, list)), [(1, 2), [1, 2]]
- )
-
- def test_nested(self):
- iterable = [[[1, 2], [1, 2]], [5, 6], [5, 6]]
- self.assertEqual(list(mi.duplicates_justseen(iterable)), [[5, 6]])
diff --git a/contrib/python/more-itertools/py3/tests/test_recipes.py b/contrib/python/more-itertools/py3/tests/test_recipes.py
deleted file mode 100644
index be40995749..0000000000
--- a/contrib/python/more-itertools/py3/tests/test_recipes.py
+++ /dev/null
@@ -1,765 +0,0 @@
-import warnings
-
-from doctest import DocTestSuite
-from itertools import combinations, count, permutations
-from math import factorial
-from unittest import TestCase
-
-import more_itertools as mi
-
-
-def load_tests(loader, tests, ignore):
- # Add the doctests
- tests.addTests(DocTestSuite('more_itertools.recipes'))
- return tests
-
-
-class TakeTests(TestCase):
- """Tests for ``take()``"""
-
- def test_simple_take(self):
- """Test basic usage"""
- t = mi.take(5, range(10))
- self.assertEqual(t, [0, 1, 2, 3, 4])
-
- def test_null_take(self):
- """Check the null case"""
- t = mi.take(0, range(10))
- self.assertEqual(t, [])
-
- def test_negative_take(self):
- """Make sure taking negative items results in a ValueError"""
- self.assertRaises(ValueError, lambda: mi.take(-3, range(10)))
-
- def test_take_too_much(self):
- """Taking more than an iterator has remaining should return what the
- iterator has remaining.
-
- """
- t = mi.take(10, range(5))
- self.assertEqual(t, [0, 1, 2, 3, 4])
-
-
-class TabulateTests(TestCase):
- """Tests for ``tabulate()``"""
-
- def test_simple_tabulate(self):
- """Test the happy path"""
- t = mi.tabulate(lambda x: x)
- f = tuple([next(t) for _ in range(3)])
- self.assertEqual(f, (0, 1, 2))
-
- def test_count(self):
- """Ensure tabulate accepts specific count"""
- t = mi.tabulate(lambda x: 2 * x, -1)
- f = (next(t), next(t), next(t))
- self.assertEqual(f, (-2, 0, 2))
-
-
-class TailTests(TestCase):
- """Tests for ``tail()``"""
-
- def test_greater(self):
- """Length of iterable is greater than requested tail"""
- self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G'])
-
- def test_equal(self):
- """Length of iterable is equal to the requested tail"""
- self.assertEqual(
- list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
- )
-
- def test_less(self):
- """Length of iterable is less than requested tail"""
- self.assertEqual(
- list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
- )
-
-
-class ConsumeTests(TestCase):
- """Tests for ``consume()``"""
-
- def test_sanity(self):
- """Test basic functionality"""
- r = (x for x in range(10))
- mi.consume(r, 3)
- self.assertEqual(3, next(r))
-
- def test_null_consume(self):
- """Check the null case"""
- r = (x for x in range(10))
- mi.consume(r, 0)
- self.assertEqual(0, next(r))
-
- def test_negative_consume(self):
- """Check that negative consumsion throws an error"""
- r = (x for x in range(10))
- self.assertRaises(ValueError, lambda: mi.consume(r, -1))
-
- def test_total_consume(self):
- """Check that iterator is totally consumed by default"""
- r = (x for x in range(10))
- mi.consume(r)
- self.assertRaises(StopIteration, lambda: next(r))
-
-
-class NthTests(TestCase):
- """Tests for ``nth()``"""
-
- def test_basic(self):
- """Make sure the nth item is returned"""
- l = range(10)
- for i, v in enumerate(l):
- self.assertEqual(mi.nth(l, i), v)
-
- def test_default(self):
- """Ensure a default value is returned when nth item not found"""
- l = range(3)
- self.assertEqual(mi.nth(l, 100, "zebra"), "zebra")
-
- def test_negative_item_raises(self):
- """Ensure asking for a negative item raises an exception"""
- self.assertRaises(ValueError, lambda: mi.nth(range(10), -3))
-
-
-class AllEqualTests(TestCase):
- """Tests for ``all_equal()``"""
-
- def test_true(self):
- """Everything is equal"""
- self.assertTrue(mi.all_equal('aaaaaa'))
- self.assertTrue(mi.all_equal([0, 0, 0, 0]))
-
- def test_false(self):
- """Not everything is equal"""
- self.assertFalse(mi.all_equal('aaaaab'))
- self.assertFalse(mi.all_equal([0, 0, 0, 1]))
-
- def test_tricky(self):
- """Not everything is identical, but everything is equal"""
- items = [1, complex(1, 0), 1.0]
- self.assertTrue(mi.all_equal(items))
-
- def test_empty(self):
- """Return True if the iterable is empty"""
- self.assertTrue(mi.all_equal(''))
- self.assertTrue(mi.all_equal([]))
-
- def test_one(self):
- """Return True if the iterable is singular"""
- self.assertTrue(mi.all_equal('0'))
- self.assertTrue(mi.all_equal([0]))
-
-
-class QuantifyTests(TestCase):
- """Tests for ``quantify()``"""
-
- def test_happy_path(self):
- """Make sure True count is returned"""
- q = [True, False, True]
- self.assertEqual(mi.quantify(q), 2)
-
- def test_custom_predicate(self):
- """Ensure non-default predicates return as expected"""
- q = range(10)
- self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5)
-
-
-class PadnoneTests(TestCase):
- def test_basic(self):
- iterable = range(2)
- for func in (mi.pad_none, mi.padnone):
- with self.subTest(func=func):
- p = func(iterable)
- self.assertEqual(
- [0, 1, None, None], [next(p) for _ in range(4)]
- )
-
-
-class NcyclesTests(TestCase):
- """Tests for ``nyclces()``"""
-
- def test_happy_path(self):
- """cycle a sequence three times"""
- r = ["a", "b", "c"]
- n = mi.ncycles(r, 3)
- self.assertEqual(
- ["a", "b", "c", "a", "b", "c", "a", "b", "c"], list(n)
- )
-
- def test_null_case(self):
- """asking for 0 cycles should return an empty iterator"""
- n = mi.ncycles(range(100), 0)
- self.assertRaises(StopIteration, lambda: next(n))
-
- def test_pathalogical_case(self):
- """asking for negative cycles should return an empty iterator"""
- n = mi.ncycles(range(100), -10)
- self.assertRaises(StopIteration, lambda: next(n))
-
-
-class DotproductTests(TestCase):
- """Tests for ``dotproduct()``'"""
-
- def test_happy_path(self):
- """simple dotproduct example"""
- self.assertEqual(400, mi.dotproduct([10, 10], [20, 20]))
-
-
-class FlattenTests(TestCase):
- """Tests for ``flatten()``"""
-
- def test_basic_usage(self):
- """ensure list of lists is flattened one level"""
- f = [[0, 1, 2], [3, 4, 5]]
- self.assertEqual(list(range(6)), list(mi.flatten(f)))
-
- def test_single_level(self):
- """ensure list of lists is flattened only one level"""
- f = [[0, [1, 2]], [[3, 4], 5]]
- self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f)))
-
-
-class RepeatfuncTests(TestCase):
- """Tests for ``repeatfunc()``"""
-
- def test_simple_repeat(self):
- """test simple repeated functions"""
- r = mi.repeatfunc(lambda: 5)
- self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)])
-
- def test_finite_repeat(self):
- """ensure limited repeat when times is provided"""
- r = mi.repeatfunc(lambda: 5, times=5)
- self.assertEqual([5, 5, 5, 5, 5], list(r))
-
- def test_added_arguments(self):
- """ensure arguments are applied to the function"""
- r = mi.repeatfunc(lambda x: x, 2, 3)
- self.assertEqual([3, 3], list(r))
-
- def test_null_times(self):
- """repeat 0 should return an empty iterator"""
- r = mi.repeatfunc(range, 0, 3)
- self.assertRaises(StopIteration, lambda: next(r))
-
-
-class PairwiseTests(TestCase):
- """Tests for ``pairwise()``"""
-
- def test_base_case(self):
- """ensure an iterable will return pairwise"""
- p = mi.pairwise([1, 2, 3])
- self.assertEqual([(1, 2), (2, 3)], list(p))
-
- def test_short_case(self):
- """ensure an empty iterator if there's not enough values to pair"""
- p = mi.pairwise("a")
- self.assertRaises(StopIteration, lambda: next(p))
-
-
-class GrouperTests(TestCase):
- """Tests for ``grouper()``"""
-
- def test_even(self):
- """Test when group size divides evenly into the length of
- the iterable.
-
- """
- self.assertEqual(
- list(mi.grouper('ABCDEF', 3)), [('A', 'B', 'C'), ('D', 'E', 'F')]
- )
-
- def test_odd(self):
- """Test when group size does not divide evenly into the length of the
- iterable.
-
- """
- self.assertEqual(
- list(mi.grouper('ABCDE', 3)), [('A', 'B', 'C'), ('D', 'E', None)]
- )
-
- def test_fill_value(self):
- """Test that the fill value is used to pad the final group"""
- self.assertEqual(
- list(mi.grouper('ABCDE', 3, 'x')),
- [('A', 'B', 'C'), ('D', 'E', 'x')],
- )
-
- def test_legacy_order(self):
- """Historically, grouper expected the n as the first parameter"""
- with warnings.catch_warnings(record=True) as caught:
- warnings.simplefilter('always')
- self.assertEqual(
- list(mi.grouper(3, 'ABCDEF')),
- [('A', 'B', 'C'), ('D', 'E', 'F')],
- )
-
- (warning,) = caught
- assert warning.category == DeprecationWarning
-
-
-class RoundrobinTests(TestCase):
- """Tests for ``roundrobin()``"""
-
- def test_even_groups(self):
- """Ensure ordered output from evenly populated iterables"""
- self.assertEqual(
- list(mi.roundrobin('ABC', [1, 2, 3], range(3))),
- ['A', 1, 0, 'B', 2, 1, 'C', 3, 2],
- )
-
- def test_uneven_groups(self):
- """Ensure ordered output from unevenly populated iterables"""
- self.assertEqual(
- list(mi.roundrobin('ABCD', [1, 2], range(0))),
- ['A', 1, 'B', 2, 'C', 'D'],
- )
-
-
-class PartitionTests(TestCase):
- """Tests for ``partition()``"""
-
- def test_bool(self):
- lesser, greater = mi.partition(lambda x: x > 5, range(10))
- self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5])
- self.assertEqual(list(greater), [6, 7, 8, 9])
-
- def test_arbitrary(self):
- divisibles, remainders = mi.partition(lambda x: x % 3, range(10))
- self.assertEqual(list(divisibles), [0, 3, 6, 9])
- self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8])
-
- def test_pred_is_none(self):
- falses, trues = mi.partition(None, range(3))
- self.assertEqual(list(falses), [0])
- self.assertEqual(list(trues), [1, 2])
-
-
-class PowersetTests(TestCase):
- """Tests for ``powerset()``"""
-
- def test_combinatorics(self):
- """Ensure a proper enumeration"""
- p = mi.powerset([1, 2, 3])
- self.assertEqual(
- list(p), [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
- )
-
-
-class UniqueEverseenTests(TestCase):
- """Tests for ``unique_everseen()``"""
-
- def test_everseen(self):
- """ensure duplicate elements are ignored"""
- u = mi.unique_everseen('AAAABBBBCCDAABBB')
- self.assertEqual(['A', 'B', 'C', 'D'], list(u))
-
- def test_custom_key(self):
- """ensure the custom key comparison works"""
- u = mi.unique_everseen('aAbACCc', key=str.lower)
- self.assertEqual(list('abC'), list(u))
-
- def test_unhashable(self):
- """ensure things work for unhashable items"""
- iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
- u = mi.unique_everseen(iterable)
- self.assertEqual(list(u), ['a', [1, 2, 3]])
-
- def test_unhashable_key(self):
- """ensure things work for unhashable items with a custom key"""
- iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
- u = mi.unique_everseen(iterable, key=lambda x: x)
- self.assertEqual(list(u), ['a', [1, 2, 3]])
-
-
-class UniqueJustseenTests(TestCase):
- """Tests for ``unique_justseen()``"""
-
- def test_justseen(self):
- """ensure only last item is remembered"""
- u = mi.unique_justseen('AAAABBBCCDABB')
- self.assertEqual(list('ABCDAB'), list(u))
-
- def test_custom_key(self):
- """ensure the custom key comparison works"""
- u = mi.unique_justseen('AABCcAD', str.lower)
- self.assertEqual(list('ABCAD'), list(u))
-
-
-class IterExceptTests(TestCase):
- """Tests for ``iter_except()``"""
-
- def test_exact_exception(self):
- """ensure the exact specified exception is caught"""
- l = [1, 2, 3]
- i = mi.iter_except(l.pop, IndexError)
- self.assertEqual(list(i), [3, 2, 1])
-
- def test_generic_exception(self):
- """ensure the generic exception can be caught"""
- l = [1, 2]
- i = mi.iter_except(l.pop, Exception)
- self.assertEqual(list(i), [2, 1])
-
- def test_uncaught_exception_is_raised(self):
- """ensure a non-specified exception is raised"""
- l = [1, 2, 3]
- i = mi.iter_except(l.pop, KeyError)
- self.assertRaises(IndexError, lambda: list(i))
-
- def test_first(self):
- """ensure first is run before the function"""
- l = [1, 2, 3]
- f = lambda: 25
- i = mi.iter_except(l.pop, IndexError, f)
- self.assertEqual(list(i), [25, 3, 2, 1])
-
- def test_multiple(self):
- """ensure can catch multiple exceptions"""
-
- class Fiz(Exception):
- pass
-
- class Buzz(Exception):
- pass
-
- i = 0
-
- def fizbuzz():
- nonlocal i
- i += 1
- if i % 3 == 0:
- raise Fiz
- if i % 5 == 0:
- raise Buzz
- return i
-
- expected = ([1, 2], [4], [], [7, 8], [])
- for x in expected:
- self.assertEqual(list(mi.iter_except(fizbuzz, (Fiz, Buzz))), x)
-
-
-class FirstTrueTests(TestCase):
- """Tests for ``first_true()``"""
-
- def test_something_true(self):
- """Test with no keywords"""
- self.assertEqual(mi.first_true(range(10)), 1)
-
- def test_nothing_true(self):
- """Test default return value."""
- self.assertIsNone(mi.first_true([0, 0, 0]))
-
- def test_default(self):
- """Test with a default keyword"""
- self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!')
-
- def test_pred(self):
- """Test with a custom predicate"""
- self.assertEqual(
- mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6
- )
-
-
-class RandomProductTests(TestCase):
- """Tests for ``random_product()``
-
- Since random.choice() has different results with the same seed across
- python versions 2.x and 3.x, these tests use highly probably events to
- create predictable outcomes across platforms.
- """
-
- def test_simple_lists(self):
- """Ensure that one item is chosen from each list in each pair.
- Also ensure that each item from each list eventually appears in
- the chosen combinations.
-
- Odds are roughly 1 in 7.1 * 10e16 that one item from either list will
- not be chosen after 100 samplings of one item from each list. Just to
- be safe, better use a known random seed, too.
-
- """
- nums = [1, 2, 3]
- lets = ['a', 'b', 'c']
- n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)])
- n, m = set(n), set(m)
- self.assertEqual(n, set(nums))
- self.assertEqual(m, set(lets))
- self.assertEqual(len(n), len(nums))
- self.assertEqual(len(m), len(lets))
-
- def test_list_with_repeat(self):
- """ensure multiple items are chosen, and that they appear to be chosen
- from one list then the next, in proper order.
-
- """
- nums = [1, 2, 3]
- lets = ['a', 'b', 'c']
- r = list(mi.random_product(nums, lets, repeat=100))
- self.assertEqual(2 * 100, len(r))
- n, m = set(r[::2]), set(r[1::2])
- self.assertEqual(n, set(nums))
- self.assertEqual(m, set(lets))
- self.assertEqual(len(n), len(nums))
- self.assertEqual(len(m), len(lets))
-
-
-class RandomPermutationTests(TestCase):
- """Tests for ``random_permutation()``"""
-
- def test_full_permutation(self):
- """ensure every item from the iterable is returned in a new ordering
-
- 15 elements have a 1 in 1.3 * 10e12 of appearing in sorted order, so
- we fix a seed value just to be sure.
-
- """
- i = range(15)
- r = mi.random_permutation(i)
- self.assertEqual(set(i), set(r))
- if i == r:
- raise AssertionError("Values were not permuted")
-
- def test_partial_permutation(self):
- """ensure all returned items are from the iterable, that the returned
- permutation is of the desired length, and that all items eventually
- get returned.
-
- Sampling 100 permutations of length 5 from a set of 15 leaves a
- (2/3)^100 chance that an item will not be chosen. Multiplied by 15
- items, there is a 1 in 2.6e16 chance that at least 1 item will not
- show up in the resulting output. Using a random seed will fix that.
-
- """
- items = range(15)
- item_set = set(items)
- all_items = set()
- for _ in range(100):
- permutation = mi.random_permutation(items, 5)
- self.assertEqual(len(permutation), 5)
- permutation_set = set(permutation)
- self.assertLessEqual(permutation_set, item_set)
- all_items |= permutation_set
- self.assertEqual(all_items, item_set)
-
-
-class RandomCombinationTests(TestCase):
- """Tests for ``random_combination()``"""
-
- def test_pseudorandomness(self):
- """ensure different subsets of the iterable get returned over many
- samplings of random combinations"""
- items = range(15)
- all_items = set()
- for _ in range(50):
- combination = mi.random_combination(items, 5)
- all_items |= set(combination)
- self.assertEqual(all_items, set(items))
-
- def test_no_replacement(self):
- """ensure that elements are sampled without replacement"""
- items = range(15)
- for _ in range(50):
- combination = mi.random_combination(items, len(items))
- self.assertEqual(len(combination), len(set(combination)))
- self.assertRaises(
- ValueError, lambda: mi.random_combination(items, len(items) + 1)
- )
-
-
-class RandomCombinationWithReplacementTests(TestCase):
- """Tests for ``random_combination_with_replacement()``"""
-
- def test_replacement(self):
- """ensure that elements are sampled with replacement"""
- items = range(5)
- combo = mi.random_combination_with_replacement(items, len(items) * 2)
- self.assertEqual(2 * len(items), len(combo))
- if len(set(combo)) == len(combo):
- raise AssertionError("Combination contained no duplicates")
-
- def test_pseudorandomness(self):
- """ensure different subsets of the iterable get returned over many
- samplings of random combinations"""
- items = range(15)
- all_items = set()
- for _ in range(50):
- combination = mi.random_combination_with_replacement(items, 5)
- all_items |= set(combination)
- self.assertEqual(all_items, set(items))
-
-
-class NthCombinationTests(TestCase):
- def test_basic(self):
- iterable = 'abcdefg'
- r = 4
- for index, expected in enumerate(combinations(iterable, r)):
- actual = mi.nth_combination(iterable, r, index)
- self.assertEqual(actual, expected)
-
- def test_long(self):
- actual = mi.nth_combination(range(180), 4, 2000000)
- expected = (2, 12, 35, 126)
- self.assertEqual(actual, expected)
-
- def test_invalid_r(self):
- for r in (-1, 3):
- with self.assertRaises(ValueError):
- mi.nth_combination([], r, 0)
-
- def test_invalid_index(self):
- with self.assertRaises(IndexError):
- mi.nth_combination('abcdefg', 3, -36)
-
-
-class NthPermutationTests(TestCase):
- def test_r_less_than_n(self):
- iterable = 'abcde'
- r = 4
- for index, expected in enumerate(permutations(iterable, r)):
- actual = mi.nth_permutation(iterable, r, index)
- self.assertEqual(actual, expected)
-
- def test_r_equal_to_n(self):
- iterable = 'abcde'
- for index, expected in enumerate(permutations(iterable)):
- actual = mi.nth_permutation(iterable, None, index)
- self.assertEqual(actual, expected)
-
- def test_long(self):
- iterable = tuple(range(180))
- r = 4
- index = 1000000
- actual = mi.nth_permutation(iterable, r, index)
- expected = mi.nth(permutations(iterable, r), index)
- self.assertEqual(actual, expected)
-
- def test_null(self):
- actual = mi.nth_permutation([], 0, 0)
- expected = tuple()
- self.assertEqual(actual, expected)
-
- def test_negative_index(self):
- iterable = 'abcde'
- r = 4
- n = factorial(len(iterable)) // factorial(len(iterable) - r)
- for index, expected in enumerate(permutations(iterable, r)):
- actual = mi.nth_permutation(iterable, r, index - n)
- self.assertEqual(actual, expected)
-
- def test_invalid_index(self):
- iterable = 'abcde'
- r = 4
- n = factorial(len(iterable)) // factorial(len(iterable) - r)
- for index in [-1 - n, n + 1]:
- with self.assertRaises(IndexError):
- mi.nth_combination(iterable, r, index)
-
- def test_invalid_r(self):
- iterable = 'abcde'
- r = 4
- n = factorial(len(iterable)) // factorial(len(iterable) - r)
- for r in [-1, n + 1]:
- with self.assertRaises(ValueError):
- mi.nth_combination(iterable, r, 0)
-
-
-class PrependTests(TestCase):
- def test_basic(self):
- value = 'a'
- iterator = iter('bcdefg')
- actual = list(mi.prepend(value, iterator))
- expected = list('abcdefg')
- self.assertEqual(actual, expected)
-
- def test_multiple(self):
- value = 'ab'
- iterator = iter('cdefg')
- actual = tuple(mi.prepend(value, iterator))
- expected = ('ab',) + tuple('cdefg')
- self.assertEqual(actual, expected)
-
-
-class Convolvetests(TestCase):
- def test_moving_average(self):
- signal = iter([10, 20, 30, 40, 50])
- kernel = [0.5, 0.5]
- actual = list(mi.convolve(signal, kernel))
- expected = [
- (10 + 0) / 2,
- (20 + 10) / 2,
- (30 + 20) / 2,
- (40 + 30) / 2,
- (50 + 40) / 2,
- (0 + 50) / 2,
- ]
- self.assertEqual(actual, expected)
-
- def test_derivative(self):
- signal = iter([10, 20, 30, 40, 50])
- kernel = [1, -1]
- actual = list(mi.convolve(signal, kernel))
- expected = [10 - 0, 20 - 10, 30 - 20, 40 - 30, 50 - 40, 0 - 50]
- self.assertEqual(actual, expected)
-
- def test_infinite_signal(self):
- signal = count()
- kernel = [1, -1]
- actual = mi.take(5, mi.convolve(signal, kernel))
- expected = [0, 1, 1, 1, 1]
- self.assertEqual(actual, expected)
-
-
-class BeforeAndAfterTests(TestCase):
- def test_empty(self):
- before, after = mi.before_and_after(bool, [])
- self.assertEqual(list(before), [])
- self.assertEqual(list(after), [])
-
- def test_never_true(self):
- before, after = mi.before_and_after(bool, [0, False, None, ''])
- self.assertEqual(list(before), [])
- self.assertEqual(list(after), [0, False, None, ''])
-
- def test_never_false(self):
- before, after = mi.before_and_after(bool, [1, True, Ellipsis, ' '])
- self.assertEqual(list(before), [1, True, Ellipsis, ' '])
- self.assertEqual(list(after), [])
-
- def test_some_true(self):
- before, after = mi.before_and_after(bool, [1, True, 0, False])
- self.assertEqual(list(before), [1, True])
- self.assertEqual(list(after), [0, False])
-
-
-class TriplewiseTests(TestCase):
- def test_basic(self):
- for iterable, expected in [
- ([0], []),
- ([0, 1], []),
- ([0, 1, 2], [(0, 1, 2)]),
- ([0, 1, 2, 3], [(0, 1, 2), (1, 2, 3)]),
- ([0, 1, 2, 3, 4], [(0, 1, 2), (1, 2, 3), (2, 3, 4)]),
- ]:
- with self.subTest(expected=expected):
- actual = list(mi.triplewise(iterable))
- self.assertEqual(actual, expected)
-
-
-class SlidingWindowTests(TestCase):
- def test_basic(self):
- for iterable, n, expected in [
- ([], 0, [()]),
- ([], 1, []),
- ([0], 1, [(0,)]),
- ([0, 1], 1, [(0,), (1,)]),
- ([0, 1, 2], 2, [(0, 1), (1, 2)]),
- ([0, 1, 2], 3, [(0, 1, 2)]),
- ([0, 1, 2], 4, []),
- ([0, 1, 2, 3], 4, [(0, 1, 2, 3)]),
- ([0, 1, 2, 3, 4], 4, [(0, 1, 2, 3), (1, 2, 3, 4)]),
- ]:
- with self.subTest(expected=expected):
- actual = list(mi.sliding_window(iterable, n))
- self.assertEqual(actual, expected)
diff --git a/contrib/python/more-itertools/py3/tests/ya.make b/contrib/python/more-itertools/py3/tests/ya.make
deleted file mode 100644
index 8d3caffc22..0000000000
--- a/contrib/python/more-itertools/py3/tests/ya.make
+++ /dev/null
@@ -1,16 +0,0 @@
-PY3TEST()
-
-OWNER(g:python-contrib)
-
-PEERDIR(
- contrib/python/more-itertools
-)
-
-TEST_SRCS(
- test_more.py
- test_recipes.py
-)
-
-NO_LINT()
-
-END()
diff --git a/contrib/python/more-itertools/py3/ya.make b/contrib/python/more-itertools/py3/ya.make
deleted file mode 100644
index 3573378d83..0000000000
--- a/contrib/python/more-itertools/py3/ya.make
+++ /dev/null
@@ -1,34 +0,0 @@
-# Generated by devtools/yamaker (pypi).
-
-PY3_LIBRARY()
-
-OWNER(g:python-contrib)
-
-VERSION(8.12.0)
-
-LICENSE(MIT)
-
-NO_LINT()
-
-PY_SRCS(
- TOP_LEVEL
- more_itertools/__init__.py
- more_itertools/__init__.pyi
- more_itertools/more.py
- more_itertools/more.pyi
- more_itertools/recipes.py
- more_itertools/recipes.pyi
-)
-
-RESOURCE_FILES(
- PREFIX contrib/python/more-itertools/py3/
- .dist-info/METADATA
- .dist-info/top_level.txt
- more_itertools/py.typed
-)
-
-END()
-
-RECURSE_FOR_TESTS(
- tests
-)
diff --git a/contrib/python/more-itertools/ya.make b/contrib/python/more-itertools/ya.make
deleted file mode 100644
index 2caa580ba5..0000000000
--- a/contrib/python/more-itertools/ya.make
+++ /dev/null
@@ -1,20 +0,0 @@
-PY23_LIBRARY()
-
-LICENSE(Service-Py23-Proxy)
-
-OWNER(g:python-contrib)
-
-IF (PYTHON2)
- PEERDIR(contrib/python/more-itertools/py2)
-ELSE()
- PEERDIR(contrib/python/more-itertools/py3)
-ENDIF()
-
-NO_LINT()
-
-END()
-
-RECURSE(
- py2
- py3
-)
diff --git a/contrib/python/pytest/py3/.dist-info/METADATA b/contrib/python/pytest/py3/.dist-info/METADATA
index 2354d6e80c..ee9a695542 100644
--- a/contrib/python/pytest/py3/.dist-info/METADATA
+++ b/contrib/python/pytest/py3/.dist-info/METADATA
@@ -1,10 +1,12 @@
Metadata-Version: 2.1
Name: pytest
-Version: 5.4.3
+Version: 6.2.5
Summary: pytest: simple powerful testing with Python
Home-page: https://docs.pytest.org/en/latest/
Author: Holger Krekel, Bruno Oliveira, Ronny Pfannschmidt, Floris Bruynooghe, Brianna Laugher, Florian Bruhin and others
-License: MIT license
+License: MIT
+Project-URL: Changelog, https://docs.pytest.org/en/stable/changelog.html
+Project-URL: Twitter, https://twitter.com/pytestdotorg
Project-URL: Source, https://github.com/pytest-dev/pytest
Project-URL: Tracker, https://github.com/pytest-dev/pytest/issues
Keywords: test,unittest
@@ -16,31 +18,31 @@ Platform: win32
Classifier: Development Status :: 6 - Mature
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
-Classifier: Operating System :: POSIX
-Classifier: Operating System :: Microsoft :: Windows
Classifier: Operating System :: MacOS :: MacOS X
-Classifier: Topic :: Software Development :: Testing
-Classifier: Topic :: Software Development :: Libraries
-Classifier: Topic :: Utilities
+Classifier: Operating System :: Microsoft :: Windows
+Classifier: Operating System :: POSIX
+Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
-Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
-Requires-Python: >=3.5
-Requires-Dist: py (>=1.5.0)
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Topic :: Software Development :: Libraries
+Classifier: Topic :: Software Development :: Testing
+Classifier: Topic :: Utilities
+Requires-Python: >=3.6
+Description-Content-Type: text/x-rst
+License-File: LICENSE
+Requires-Dist: attrs (>=19.2.0)
+Requires-Dist: iniconfig
Requires-Dist: packaging
-Requires-Dist: attrs (>=17.4.0)
-Requires-Dist: more-itertools (>=4.0.0)
-Requires-Dist: pluggy (<1.0,>=0.12)
-Requires-Dist: wcwidth
-Requires-Dist: pathlib2 (>=2.2.0) ; python_version < "3.6"
+Requires-Dist: pluggy (<2.0,>=0.12)
+Requires-Dist: py (>=1.8.2)
+Requires-Dist: toml
Requires-Dist: importlib-metadata (>=0.12) ; python_version < "3.8"
Requires-Dist: atomicwrites (>=1.0) ; sys_platform == "win32"
Requires-Dist: colorama ; sys_platform == "win32"
-Provides-Extra: checkqa-mypy
-Requires-Dist: mypy (==v0.761) ; extra == 'checkqa-mypy'
Provides-Extra: testing
Requires-Dist: argcomplete ; extra == 'testing'
Requires-Dist: hypothesis (>=3.56) ; extra == 'testing'
@@ -49,8 +51,8 @@ Requires-Dist: nose ; extra == 'testing'
Requires-Dist: requests ; extra == 'testing'
Requires-Dist: xmlschema ; extra == 'testing'
-.. image:: https://docs.pytest.org/en/latest/_static/pytest1.png
- :target: https://docs.pytest.org/en/latest/
+.. image:: https://github.com/pytest-dev/pytest/raw/main/doc/en/img/pytest_logo_curves.svg
+ :target: https://docs.pytest.org/en/stable/
:align: center
:alt: pytest
@@ -66,15 +68,19 @@ Requires-Dist: xmlschema ; extra == 'testing'
.. image:: https://img.shields.io/pypi/pyversions/pytest.svg
:target: https://pypi.org/project/pytest/
-.. image:: https://codecov.io/gh/pytest-dev/pytest/branch/master/graph/badge.svg
+.. image:: https://codecov.io/gh/pytest-dev/pytest/branch/main/graph/badge.svg
:target: https://codecov.io/gh/pytest-dev/pytest
:alt: Code coverage Status
.. image:: https://travis-ci.org/pytest-dev/pytest.svg?branch=master
:target: https://travis-ci.org/pytest-dev/pytest
-.. image:: https://dev.azure.com/pytest-dev/pytest/_apis/build/status/pytest-CI?branchName=master
- :target: https://dev.azure.com/pytest-dev/pytest
+.. image:: https://github.com/pytest-dev/pytest/workflows/main/badge.svg
+ :target: https://github.com/pytest-dev/pytest/actions?query=workflow%3Amain
+
+.. image:: https://results.pre-commit.ci/badge/github/pytest-dev/pytest/main.svg
+ :target: https://results.pre-commit.ci/latest/github/pytest-dev/pytest/master
+ :alt: pre-commit.ci status
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black
@@ -122,33 +128,33 @@ To execute it::
========================== 1 failed in 0.04 seconds ===========================
-Due to ``pytest``'s detailed assertion introspection, only plain ``assert`` statements are used. See `getting-started <https://docs.pytest.org/en/latest/getting-started.html#our-first-test-run>`_ for more examples.
+Due to ``pytest``'s detailed assertion introspection, only plain ``assert`` statements are used. See `getting-started <https://docs.pytest.org/en/stable/getting-started.html#our-first-test-run>`_ for more examples.
Features
--------
-- Detailed info on failing `assert statements <https://docs.pytest.org/en/latest/assert.html>`_ (no need to remember ``self.assert*`` names);
+- Detailed info on failing `assert statements <https://docs.pytest.org/en/stable/assert.html>`_ (no need to remember ``self.assert*`` names)
- `Auto-discovery
- <https://docs.pytest.org/en/latest/goodpractices.html#python-test-discovery>`_
- of test modules and functions;
+ <https://docs.pytest.org/en/stable/goodpractices.html#python-test-discovery>`_
+ of test modules and functions
-- `Modular fixtures <https://docs.pytest.org/en/latest/fixture.html>`_ for
- managing small or parametrized long-lived test resources;
+- `Modular fixtures <https://docs.pytest.org/en/stable/fixture.html>`_ for
+ managing small or parametrized long-lived test resources
-- Can run `unittest <https://docs.pytest.org/en/latest/unittest.html>`_ (or trial),
- `nose <https://docs.pytest.org/en/latest/nose.html>`_ test suites out of the box;
+- Can run `unittest <https://docs.pytest.org/en/stable/unittest.html>`_ (or trial),
+ `nose <https://docs.pytest.org/en/stable/nose.html>`_ test suites out of the box
-- Python 3.5+ and PyPy3;
+- Python 3.6+ and PyPy3
-- Rich plugin architecture, with over 315+ `external plugins <http://plugincompat.herokuapp.com>`_ and thriving community;
+- Rich plugin architecture, with over 850+ `external plugins <http://plugincompat.herokuapp.com>`_ and thriving community
Documentation
-------------
-For full documentation, including installation, tutorials and PDF documents, please see https://docs.pytest.org/en/latest/.
+For full documentation, including installation, tutorials and PDF documents, please see https://docs.pytest.org/en/stable/.
Bugs/Requests
@@ -160,7 +166,7 @@ Please use the `GitHub issue tracker <https://github.com/pytest-dev/pytest/issue
Changelog
---------
-Consult the `Changelog <https://docs.pytest.org/en/latest/changelog.html>`__ page for fixes and enhancements of each version.
+Consult the `Changelog <https://docs.pytest.org/en/stable/changelog.html>`__ page for fixes and enhancements of each version.
Support pytest
@@ -200,10 +206,10 @@ Tidelift will coordinate the fix and disclosure.
License
-------
-Copyright Holger Krekel and others, 2004-2020.
+Copyright Holger Krekel and others, 2004-2021.
Distributed under the terms of the `MIT`_ license, pytest is free and open source software.
-.. _`MIT`: https://github.com/pytest-dev/pytest/blob/master/LICENSE
+.. _`MIT`: https://github.com/pytest-dev/pytest/blob/main/LICENSE
diff --git a/contrib/python/pytest/py3/.dist-info/entry_points.txt b/contrib/python/pytest/py3/.dist-info/entry_points.txt
index d8e4fd298f..0267c75b77 100644
--- a/contrib/python/pytest/py3/.dist-info/entry_points.txt
+++ b/contrib/python/pytest/py3/.dist-info/entry_points.txt
@@ -1,4 +1,4 @@
[console_scripts]
-py.test = pytest:main
-pytest = pytest:main
+py.test = pytest:console_main
+pytest = pytest:console_main
diff --git a/contrib/python/pytest/py3/AUTHORS b/contrib/python/pytest/py3/AUTHORS
index 3de12aa12f..2c690c5d28 100644
--- a/contrib/python/pytest/py3/AUTHORS
+++ b/contrib/python/pytest/py3/AUTHORS
@@ -21,6 +21,7 @@ Anders Hovmöller
Andras Mitzki
Andras Tim
Andrea Cimatoribus
+Andreas Motl
Andreas Zeidler
Andrey Paramonov
Andrzej Klajnert
@@ -32,6 +33,7 @@ Anthony Sottile
Anton Lodder
Antony Lee
Arel Cordero
+Ariel Pillemer
Armin Rigo
Aron Coyle
Aron Curzon
@@ -55,14 +57,17 @@ Charles Cloud
Charles Machalow
Charnjit SiNGH (CCSJ)
Chris Lamb
+Chris NeJame
Christian Boelsen
Christian Fetzer
Christian Neumüller
Christian Theunert
Christian Tismer
+Christine Mecklenborg
Christoph Buelter
Christopher Dignam
Christopher Gilling
+Claire Cecil
Claudio Madotto
CrazyMerlyn
Cyrus Maden
@@ -80,11 +85,13 @@ David Paul Röthlisberger
David Szotten
David Vierra
Daw-Ran Liou
+Debi Mishra
Denis Kirisov
Dhiren Serai
Diego Russo
Dmitry Dygalo
Dmitry Pribysh
+Dominic Mortlock
Duncan Betts
Edison Gustavo Muenz
Edoardo Batini
@@ -94,17 +101,22 @@ Elizaveta Shashkova
Endre Galaczi
Eric Hunsberger
Eric Siegerman
+Erik Aronesty
Erik M. Bray
Evan Kepner
Fabien Zarifian
Fabio Zadrozny
+Felix Nieuwenhuizen
Feng Ma
Florian Bruhin
+Florian Dahlitz
Floris Bruynooghe
Gabriel Reis
+Garvit Shubham
Gene Wood
George Kussumoto
Georgy Dyuldin
+Gleb Nikonorov
Graham Horler
Greg Price
Gregory Lee
@@ -123,6 +135,7 @@ Ilya Konstantinov
Ionuț Turturică
Iwan Briquemont
Jaap Broekhuizen
+Jakob van Santen
Jakub Mitoraj
Jan Balster
Janne Vanhala
@@ -145,9 +158,13 @@ Joshua Bronson
Jurko Gospodnetić
Justyna Janczyszyn
Kale Kundert
+Kamran Ahmad
Karl O. Pinc
+Karthikeyan Singaravelan
Katarzyna Jachim
+Katarzyna Król
Katerina Koukiou
+Keri Volans
Kevin Cox
Kevin J. Foley
Kodi B. Arfer
@@ -157,6 +174,7 @@ Kyle Altendorf
Lawrence Mitchell
Lee Kamentsky
Lev Maximov
+Lewis Cowles
Llandy Riveron Del Risco
Loic Esteve
Lukas Bednar
@@ -183,7 +201,9 @@ Matt Duck
Matt Williams
Matthias Hafner
Maxim Filipenko
+Maximilian Cosmo Sitter
mbyt
+Mickey Pashov
Michael Aquilina
Michael Birtwell
Michael Droettboom
@@ -213,15 +233,22 @@ Ondřej Súkup
Oscar Benjamin
Patrick Hayes
Pauli Virtanen
+Pavel Karateev
Paweł Adamczak
Pedro Algarvio
+Petter Strandmark
Philipp Loose
Pieter Mulder
Piotr Banaszkiewicz
+Piotr Helm
+Prakhar Gurunani
+Prashant Anand
+Prashant Sharma
Pulkit Goyal
Punyashloka Biswal
Quentin Pradet
Ralf Schmitt
+Ram Rachum
Ralph Giles
Ran Benita
Raphael Castaneda
@@ -235,16 +262,21 @@ Romain Dorgueil
Roman Bolshakov
Ronny Pfannschmidt
Ross Lawley
+Ruaridh Williamson
Russel Winder
Ryan Wooden
Samuel Dion-Girardeau
Samuel Searles-Bryant
Samuele Pedroni
+Sanket Duthade
Sankt Petersbug
Segev Finer
Serhii Mozghovyi
Seth Junot
+Shantanu Jain
+Shubham Adep
Simon Gomizelj
+Simon Kerr
Skylar Downes
Srinivas Reddy Thatiparthy
Stefan Farmbauer
@@ -254,8 +286,10 @@ Stefano Taschini
Steffen Allner
Stephan Obermann
Sven-Hendrik Haase
+Sylvain Marié
Tadek Teleżyński
Takafumi Arakaki
+Tanvi Mehta
Tarcisio Fischer
Tareq Alayan
Ted Xiao
@@ -278,6 +312,7 @@ Vidar T. Fauske
Virgil Dupras
Vitaly Lashmanov
Vlad Dragos
+Vlad Radziuk
Vladyslav Rachek
Volodymyr Piskun
Wei Lin
@@ -291,3 +326,4 @@ Xuecong Liao
Yoav Caspi
Zac Hatfield-Dodds
Zoltán Máté
+Zsolt Cserna
diff --git a/contrib/python/pytest/py3/README.rst b/contrib/python/pytest/py3/README.rst
index 864467ea21..a6ba517c66 100644
--- a/contrib/python/pytest/py3/README.rst
+++ b/contrib/python/pytest/py3/README.rst
@@ -1,5 +1,5 @@
-.. image:: https://docs.pytest.org/en/latest/_static/pytest1.png
- :target: https://docs.pytest.org/en/latest/
+.. image:: https://github.com/pytest-dev/pytest/raw/main/doc/en/img/pytest_logo_curves.svg
+ :target: https://docs.pytest.org/en/stable/
:align: center
:alt: pytest
@@ -15,15 +15,19 @@
.. image:: https://img.shields.io/pypi/pyversions/pytest.svg
:target: https://pypi.org/project/pytest/
-.. image:: https://codecov.io/gh/pytest-dev/pytest/branch/master/graph/badge.svg
+.. image:: https://codecov.io/gh/pytest-dev/pytest/branch/main/graph/badge.svg
:target: https://codecov.io/gh/pytest-dev/pytest
:alt: Code coverage Status
.. image:: https://travis-ci.org/pytest-dev/pytest.svg?branch=master
:target: https://travis-ci.org/pytest-dev/pytest
-.. image:: https://dev.azure.com/pytest-dev/pytest/_apis/build/status/pytest-CI?branchName=master
- :target: https://dev.azure.com/pytest-dev/pytest
+.. image:: https://github.com/pytest-dev/pytest/workflows/main/badge.svg
+ :target: https://github.com/pytest-dev/pytest/actions?query=workflow%3Amain
+
+.. image:: https://results.pre-commit.ci/badge/github/pytest-dev/pytest/main.svg
+ :target: https://results.pre-commit.ci/latest/github/pytest-dev/pytest/master
+ :alt: pre-commit.ci status
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black
@@ -71,33 +75,33 @@ To execute it::
========================== 1 failed in 0.04 seconds ===========================
-Due to ``pytest``'s detailed assertion introspection, only plain ``assert`` statements are used. See `getting-started <https://docs.pytest.org/en/latest/getting-started.html#our-first-test-run>`_ for more examples.
+Due to ``pytest``'s detailed assertion introspection, only plain ``assert`` statements are used. See `getting-started <https://docs.pytest.org/en/stable/getting-started.html#our-first-test-run>`_ for more examples.
Features
--------
-- Detailed info on failing `assert statements <https://docs.pytest.org/en/latest/assert.html>`_ (no need to remember ``self.assert*`` names);
+- Detailed info on failing `assert statements <https://docs.pytest.org/en/stable/assert.html>`_ (no need to remember ``self.assert*`` names)
- `Auto-discovery
- <https://docs.pytest.org/en/latest/goodpractices.html#python-test-discovery>`_
- of test modules and functions;
+ <https://docs.pytest.org/en/stable/goodpractices.html#python-test-discovery>`_
+ of test modules and functions
-- `Modular fixtures <https://docs.pytest.org/en/latest/fixture.html>`_ for
- managing small or parametrized long-lived test resources;
+- `Modular fixtures <https://docs.pytest.org/en/stable/fixture.html>`_ for
+ managing small or parametrized long-lived test resources
-- Can run `unittest <https://docs.pytest.org/en/latest/unittest.html>`_ (or trial),
- `nose <https://docs.pytest.org/en/latest/nose.html>`_ test suites out of the box;
+- Can run `unittest <https://docs.pytest.org/en/stable/unittest.html>`_ (or trial),
+ `nose <https://docs.pytest.org/en/stable/nose.html>`_ test suites out of the box
-- Python 3.5+ and PyPy3;
+- Python 3.6+ and PyPy3
-- Rich plugin architecture, with over 315+ `external plugins <http://plugincompat.herokuapp.com>`_ and thriving community;
+- Rich plugin architecture, with over 850+ `external plugins <http://plugincompat.herokuapp.com>`_ and thriving community
Documentation
-------------
-For full documentation, including installation, tutorials and PDF documents, please see https://docs.pytest.org/en/latest/.
+For full documentation, including installation, tutorials and PDF documents, please see https://docs.pytest.org/en/stable/.
Bugs/Requests
@@ -109,7 +113,7 @@ Please use the `GitHub issue tracker <https://github.com/pytest-dev/pytest/issue
Changelog
---------
-Consult the `Changelog <https://docs.pytest.org/en/latest/changelog.html>`__ page for fixes and enhancements of each version.
+Consult the `Changelog <https://docs.pytest.org/en/stable/changelog.html>`__ page for fixes and enhancements of each version.
Support pytest
@@ -149,8 +153,8 @@ Tidelift will coordinate the fix and disclosure.
License
-------
-Copyright Holger Krekel and others, 2004-2020.
+Copyright Holger Krekel and others, 2004-2021.
Distributed under the terms of the `MIT`_ license, pytest is free and open source software.
-.. _`MIT`: https://github.com/pytest-dev/pytest/blob/master/LICENSE
+.. _`MIT`: https://github.com/pytest-dev/pytest/blob/main/LICENSE
diff --git a/contrib/python/pytest/py3/_pytest/_argcomplete.py b/contrib/python/pytest/py3/_pytest/_argcomplete.py
index 7ca216ecf9..41d9d9407c 100644
--- a/contrib/python/pytest/py3/_pytest/_argcomplete.py
+++ b/contrib/python/pytest/py3/_pytest/_argcomplete.py
@@ -1,7 +1,8 @@
-"""allow bash-completion for argparse with argcomplete if installed
-needs argcomplete>=0.5.6 for python 3.2/3.3 (older versions fail
+"""Allow bash-completion for argparse with argcomplete if installed.
+
+Needs argcomplete>=0.5.6 for python 3.2/3.3 (older versions fail
to find the magic string, so _ARGCOMPLETE env. var is never set, and
-this does not need special code.
+this does not need special code).
Function try_argcomplete(parser) should be called directly before
the call to ArgumentParser.parse_args().
@@ -10,8 +11,7 @@ The filescompleter is what you normally would use on the positional
arguments specification, in order to get "dirname/" after "dirn<TAB>"
instead of the default "dirname ":
- optparser.add_argument(Config._file_or_dir, nargs='*'
- ).completer=filescompleter
+ optparser.add_argument(Config._file_or_dir, nargs='*').completer=filescompleter
Other, application specific, completers should go in the file
doing the add_argument calls as they need to be specified as .completer
@@ -20,35 +20,43 @@ attribute points to will not be used).
SPEEDUP
=======
+
The generic argcomplete script for bash-completion
-(/etc/bash_completion.d/python-argcomplete.sh )
+(/etc/bash_completion.d/python-argcomplete.sh)
uses a python program to determine startup script generated by pip.
You can speed up completion somewhat by changing this script to include
# PYTHON_ARGCOMPLETE_OK
-so the the python-argcomplete-check-easy-install-script does not
+so the python-argcomplete-check-easy-install-script does not
need to be called to find the entry point of the code and see if that is
-marked with PYTHON_ARGCOMPLETE_OK
+marked with PYTHON_ARGCOMPLETE_OK.
INSTALL/DEBUGGING
=================
+
To include this support in another application that has setup.py generated
scripts:
-- add the line:
+
+- Add the line:
# PYTHON_ARGCOMPLETE_OK
- near the top of the main python entry point
-- include in the file calling parse_args():
+ near the top of the main python entry point.
+
+- Include in the file calling parse_args():
from _argcomplete import try_argcomplete, filescompleter
- , call try_argcomplete just before parse_args(), and optionally add
- filescompleter to the positional arguments' add_argument()
+ Call try_argcomplete just before parse_args(), and optionally add
+ filescompleter to the positional arguments' add_argument().
+
If things do not work right away:
-- switch on argcomplete debugging with (also helpful when doing custom
+
+- Switch on argcomplete debugging with (also helpful when doing custom
completers):
export _ARC_DEBUG=1
-- run:
+
+- Run:
python-argcomplete-check-easy-install-script $(which appname)
echo $?
- will echo 0 if the magic line has been found, 1 if not
-- sometimes it helps to find early on errors using:
+ will echo 0 if the magic line has been found, 1 if not.
+
+- Sometimes it helps to find early on errors using:
_ARGCOMPLETE=1 _ARC_DEBUG=1 appname
which should throw a KeyError: 'COMPLINE' (which is properly set by the
global argcomplete script).
@@ -63,13 +71,13 @@ from typing import Optional
class FastFilesCompleter:
- "Fast file completer class"
+ """Fast file completer class."""
def __init__(self, directories: bool = True) -> None:
self.directories = directories
def __call__(self, prefix: str, **kwargs: Any) -> List[str]:
- """only called on non option completions"""
+ # Only called on non option completions.
if os.path.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.path.sep)
else:
@@ -77,7 +85,7 @@ class FastFilesCompleter:
completion = []
globbed = []
if "*" not in prefix and "?" not in prefix:
- # we are on unix, otherwise no bash
+ # We are on unix, otherwise no bash.
if not prefix or prefix[-1] == os.path.sep:
globbed.extend(glob(prefix + ".*"))
prefix += "*"
@@ -85,7 +93,7 @@ class FastFilesCompleter:
for x in sorted(globbed):
if os.path.isdir(x):
x += "/"
- # append stripping the prefix (like bash, not like compgen)
+ # Append stripping the prefix (like bash, not like compgen).
completion.append(x[prefix_dir:])
return completion
@@ -95,7 +103,7 @@ if os.environ.get("_ARGCOMPLETE"):
import argcomplete.completers
except ImportError:
sys.exit(-1)
- filescompleter = FastFilesCompleter() # type: Optional[FastFilesCompleter]
+ filescompleter: Optional[FastFilesCompleter] = FastFilesCompleter()
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
argcomplete.autocomplete(parser, always_complete_options=False)
diff --git a/contrib/python/pytest/py3/_pytest/_code/__init__.py b/contrib/python/pytest/py3/_pytest/_code/__init__.py
index 370e41dc9f..511d0dde66 100644
--- a/contrib/python/pytest/py3/_pytest/_code/__init__.py
+++ b/contrib/python/pytest/py3/_pytest/_code/__init__.py
@@ -1,10 +1,22 @@
-""" python inspection/code generation API """
-from .code import Code # noqa
-from .code import ExceptionInfo # noqa
-from .code import filter_traceback # noqa
-from .code import Frame # noqa
-from .code import getrawcode # noqa
-from .code import Traceback # noqa
-from .source import compile_ as compile # noqa
-from .source import getfslineno # noqa
-from .source import Source # noqa
+"""Python inspection/code generation API."""
+from .code import Code
+from .code import ExceptionInfo
+from .code import filter_traceback
+from .code import Frame
+from .code import getfslineno
+from .code import Traceback
+from .code import TracebackEntry
+from .source import getrawcode
+from .source import Source
+
+__all__ = [
+ "Code",
+ "ExceptionInfo",
+ "filter_traceback",
+ "Frame",
+ "getfslineno",
+ "getrawcode",
+ "Traceback",
+ "TracebackEntry",
+ "Source",
+]
diff --git a/contrib/python/pytest/py3/_pytest/_code/code.py b/contrib/python/pytest/py3/_pytest/_code/code.py
index 965074c924..423069330a 100644
--- a/contrib/python/pytest/py3/_pytest/_code/code.py
+++ b/contrib/python/pytest/py3/_pytest/_code/code.py
@@ -5,6 +5,7 @@ import traceback
from inspect import CO_VARARGS
from inspect import CO_VARKEYWORDS
from io import StringIO
+from pathlib import Path
from traceback import format_exception_only
from types import CodeType
from types import FrameType
@@ -15,11 +16,15 @@ from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
+from typing import Mapping
from typing import Optional
+from typing import overload
from typing import Pattern
from typing import Sequence
from typing import Set
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from weakref import ref
@@ -29,35 +34,34 @@ import pluggy
import py
import _pytest
+from _pytest._code.source import findsource
+from _pytest._code.source import getrawcode
+from _pytest._code.source import getstatementrange_ast
+from _pytest._code.source import Source
from _pytest._io import TerminalWriter
from _pytest._io.saferepr import safeformat
from _pytest._io.saferepr import saferepr
-from _pytest.compat import ATTRS_EQ_FIELD
-from _pytest.compat import overload
-from _pytest.compat import TYPE_CHECKING
+from _pytest.compat import final
+from _pytest.compat import get_real_func
if TYPE_CHECKING:
- from typing import Type
from typing_extensions import Literal
- from weakref import ReferenceType # noqa: F401
+ from weakref import ReferenceType
- from _pytest._code import Source
-
- _TracebackStyle = Literal["long", "short", "line", "no", "native"]
+ _TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"]
class Code:
- """ wrapper around Python code objects """
-
- def __init__(self, rawcode) -> None:
- if not hasattr(rawcode, "co_filename"):
- rawcode = getrawcode(rawcode)
- if not isinstance(rawcode, CodeType):
- raise TypeError("not a code object: {!r}".format(rawcode))
- self.filename = rawcode.co_filename
- self.firstlineno = rawcode.co_firstlineno - 1
- self.name = rawcode.co_name
- self.raw = rawcode
+ """Wrapper around Python code objects."""
+
+ __slots__ = ("raw",)
+
+ def __init__(self, obj: CodeType) -> None:
+ self.raw = obj
+
+ @classmethod
+ def from_function(cls, obj: object) -> "Code":
+ return cls(getrawcode(obj))
def __eq__(self, other):
return self.raw == other.raw
@@ -65,14 +69,18 @@ class Code:
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
- def __ne__(self, other):
- return not self == other
+ @property
+ def firstlineno(self) -> int:
+ return self.raw.co_firstlineno - 1
+
+ @property
+ def name(self) -> str:
+ return self.raw.co_name
@property
def path(self) -> Union[py.path.local, str]:
- """ return a path object pointing to source code (or a str in case
- of OSError / non-existing file).
- """
+ """Return a path object pointing to source code, or an ``str`` in
+ case of ``OSError`` / non-existing file."""
if not self.raw.co_filename:
return ""
try:
@@ -88,28 +96,22 @@ class Code:
@property
def fullsource(self) -> Optional["Source"]:
- """ return a _pytest._code.Source object for the full source file of the code
- """
- from _pytest._code import source
-
- full, _ = source.findsource(self.raw)
+ """Return a _pytest._code.Source object for the full source file of the code."""
+ full, _ = findsource(self.raw)
return full
def source(self) -> "Source":
- """ return a _pytest._code.Source object for the code object's source only
- """
+ """Return a _pytest._code.Source object for the code object's source only."""
# return source only for that part of code
- import _pytest._code
-
- return _pytest._code.Source(self.raw)
+ return Source(self.raw)
def getargs(self, var: bool = False) -> Tuple[str, ...]:
- """ return a tuple with the argument names for the code object
+ """Return a tuple with the argument names for the code object.
- if 'var' is set True also return the names of the variable and
- keyword arguments when present
+ If 'var' is set True also return the names of the variable and
+ keyword arguments when present.
"""
- # handfull shortcut for getting args
+ # Handy shortcut for getting args.
raw = self.raw
argcount = raw.co_argcount
if var:
@@ -122,55 +124,54 @@ class Frame:
"""Wrapper around a Python frame holding f_locals and f_globals
in which expressions can be evaluated."""
+ __slots__ = ("raw",)
+
def __init__(self, frame: FrameType) -> None:
- self.lineno = frame.f_lineno - 1
- self.f_globals = frame.f_globals
- self.f_locals = frame.f_locals
self.raw = frame
- self.code = Code(frame.f_code)
@property
- def statement(self) -> "Source":
- """ statement this frame is at """
- import _pytest._code
+ def lineno(self) -> int:
+ return self.raw.f_lineno - 1
+
+ @property
+ def f_globals(self) -> Dict[str, Any]:
+ return self.raw.f_globals
+
+ @property
+ def f_locals(self) -> Dict[str, Any]:
+ return self.raw.f_locals
+
+ @property
+ def code(self) -> Code:
+ return Code(self.raw.f_code)
+ @property
+ def statement(self) -> "Source":
+ """Statement this frame is at."""
if self.code.fullsource is None:
- return _pytest._code.Source("")
+ return Source("")
return self.code.fullsource.getstatement(self.lineno)
def eval(self, code, **vars):
- """ evaluate 'code' in the frame
+ """Evaluate 'code' in the frame.
- 'vars' are optional additional local variables
+ 'vars' are optional additional local variables.
- returns the result of the evaluation
+ Returns the result of the evaluation.
"""
f_locals = self.f_locals.copy()
f_locals.update(vars)
return eval(code, self.f_globals, f_locals)
- def exec_(self, code, **vars) -> None:
- """ exec 'code' in the frame
-
- 'vars' are optional; additional local variables
- """
- f_locals = self.f_locals.copy()
- f_locals.update(vars)
- exec(code, self.f_globals, f_locals)
-
def repr(self, object: object) -> str:
- """ return a 'safe' (non-recursive, one-line) string repr for 'object'
- """
+ """Return a 'safe' (non-recursive, one-line) string repr for 'object'."""
return saferepr(object)
- def is_true(self, object):
- return object
-
def getargs(self, var: bool = False):
- """ return a list of tuples (name, value) for all arguments
+ """Return a list of tuples (name, value) for all arguments.
- if 'var' is set True also include the variable and keyword
- arguments when present
+ If 'var' is set True, also include the variable and keyword arguments
+ when present.
"""
retval = []
for arg in self.code.getargs(var):
@@ -182,15 +183,22 @@ class Frame:
class TracebackEntry:
- """ a single entry in a traceback """
+ """A single entry in a Traceback."""
- _repr_style = None # type: Optional[Literal["short", "long"]]
- exprinfo = None
+ __slots__ = ("_rawentry", "_excinfo", "_repr_style")
- def __init__(self, rawentry: TracebackType, excinfo=None) -> None:
- self._excinfo = excinfo
+ def __init__(
+ self,
+ rawentry: TracebackType,
+ excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None,
+ ) -> None:
self._rawentry = rawentry
- self.lineno = rawentry.tb_lineno - 1
+ self._excinfo = excinfo
+ self._repr_style: Optional['Literal["short", "long"]'] = None
+
+ @property
+ def lineno(self) -> int:
+ return self._rawentry.tb_lineno - 1
def set_repr_style(self, mode: "Literal['short', 'long']") -> None:
assert mode in ("short", "long")
@@ -209,30 +217,28 @@ class TracebackEntry:
@property
def statement(self) -> "Source":
- """ _pytest._code.Source object for the current statement """
+ """_pytest._code.Source object for the current statement."""
source = self.frame.code.fullsource
assert source is not None
return source.getstatement(self.lineno)
@property
- def path(self):
- """ path to the source code """
+ def path(self) -> Union[py.path.local, str]:
+ """Path to the source code."""
return self.frame.code.path
@property
def locals(self) -> Dict[str, Any]:
- """ locals of underlying frame """
+ """Locals of underlying frame."""
return self.frame.f_locals
def getfirstlinesource(self) -> int:
return self.frame.code.firstlineno
def getsource(self, astcache=None) -> Optional["Source"]:
- """ return failing source code. """
+ """Return failing source code."""
# we use the passed in astcache to not reparse asttrees
# within exception info printing
- from _pytest._code.source import getstatementrange_ast
-
source = self.frame.code.fullsource
if source is None:
return None
@@ -255,59 +261,71 @@ class TracebackEntry:
source = property(getsource)
- def ishidden(self):
- """ return True if the current frame has a var __tracebackhide__
- resolving to True.
+ def ishidden(self) -> bool:
+ """Return True if the current frame has a var __tracebackhide__
+ resolving to True.
- If __tracebackhide__ is a callable, it gets called with the
- ExceptionInfo instance and can decide whether to hide the traceback.
+ If __tracebackhide__ is a callable, it gets called with the
+ ExceptionInfo instance and can decide whether to hide the traceback.
- mostly for internal use
+ Mostly for internal use.
"""
- f = self.frame
- tbh = f.f_locals.get(
- "__tracebackhide__", f.f_globals.get("__tracebackhide__", False)
+ tbh: Union[bool, Callable[[Optional[ExceptionInfo[BaseException]]], bool]] = (
+ False
)
+ for maybe_ns_dct in (self.frame.f_locals, self.frame.f_globals):
+ # in normal cases, f_locals and f_globals are dictionaries
+ # however via `exec(...)` / `eval(...)` they can be other types
+ # (even incorrect types!).
+ # as such, we suppress all exceptions while accessing __tracebackhide__
+ try:
+ tbh = maybe_ns_dct["__tracebackhide__"]
+ except Exception:
+ pass
+ else:
+ break
if tbh and callable(tbh):
return tbh(None if self._excinfo is None else self._excinfo())
return tbh
def __str__(self) -> str:
- try:
- fn = str(self.path)
- except py.error.Error:
- fn = "???"
name = self.frame.code.name
try:
line = str(self.statement).lstrip()
except KeyboardInterrupt:
raise
- except: # noqa
+ except BaseException:
line = "???"
- return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line)
+ # This output does not quite match Python's repr for traceback entries,
+ # but changing it to do so would break certain plugins. See
+ # https://github.com/pytest-dev/pytest/pull/7535/ for details.
+ return " File %r:%d in %s\n %s\n" % (
+ str(self.path),
+ self.lineno + 1,
+ name,
+ line,
+ )
@property
def name(self) -> str:
- """ co_name of underlying code """
+ """co_name of underlying code."""
return self.frame.code.raw.co_name
class Traceback(List[TracebackEntry]):
- """ Traceback objects encapsulate and offer higher level
- access to Traceback entries.
- """
+ """Traceback objects encapsulate and offer higher level access to Traceback entries."""
def __init__(
self,
tb: Union[TracebackType, Iterable[TracebackEntry]],
- excinfo: Optional["ReferenceType[ExceptionInfo]"] = None,
+ excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None,
) -> None:
- """ initialize from given python traceback object and ExceptionInfo """
+ """Initialize from given python traceback object and ExceptionInfo."""
self._excinfo = excinfo
if isinstance(tb, TracebackType):
def f(cur: TracebackType) -> Iterable[TracebackEntry]:
- cur_ = cur # type: Optional[TracebackType]
+ cur_: Optional[TracebackType] = cur
while cur_ is not None:
yield TracebackEntry(cur_, excinfo=excinfo)
cur_ = cur_.tb_next
@@ -321,16 +339,16 @@ class Traceback(List[TracebackEntry]):
path=None,
lineno: Optional[int] = None,
firstlineno: Optional[int] = None,
- excludepath=None,
+ excludepath: Optional[py.path.local] = None,
) -> "Traceback":
- """ return a Traceback instance wrapping part of this Traceback
+ """Return a Traceback instance wrapping part of this Traceback.
- by providing any combination of path, lineno and firstlineno, the
- first frame to start the to-be-returned traceback is determined
+ By providing any combination of path, lineno and firstlineno, the
+ first frame to start the to-be-returned traceback is determined.
- this allows cutting the first part of a Traceback instance e.g.
- for formatting reasons (removing some uninteresting bits that deal
- with handling of the exception/traceback)
+ This allows cutting the first part of a Traceback instance e.g.
+ for formatting reasons (removing some uninteresting bits that deal
+ with handling of the exception/traceback).
"""
for x in self:
code = x.frame.code
@@ -350,15 +368,13 @@ class Traceback(List[TracebackEntry]):
@overload
def __getitem__(self, key: int) -> TracebackEntry:
- raise NotImplementedError()
+ ...
- @overload # noqa: F811
- def __getitem__(self, key: slice) -> "Traceback": # noqa: F811
- raise NotImplementedError()
+ @overload
+ def __getitem__(self, key: slice) -> "Traceback":
+ ...
- def __getitem__( # noqa: F811
- self, key: Union[int, slice]
- ) -> Union[TracebackEntry, "Traceback"]:
+ def __getitem__(self, key: Union[int, slice]) -> Union[TracebackEntry, "Traceback"]:
if isinstance(key, slice):
return self.__class__(super().__getitem__(key))
else:
@@ -367,21 +383,19 @@ class Traceback(List[TracebackEntry]):
def filter(
self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden()
) -> "Traceback":
- """ return a Traceback instance with certain items removed
+ """Return a Traceback instance with certain items removed
- fn is a function that gets a single argument, a TracebackEntry
- instance, and should return True when the item should be added
- to the Traceback, False when not
+ fn is a function that gets a single argument, a TracebackEntry
+ instance, and should return True when the item should be added
+ to the Traceback, False when not.
- by default this removes all the TracebackEntries which are hidden
- (see ishidden() above)
+ By default this removes all the TracebackEntries which are hidden
+ (see ishidden() above).
"""
return Traceback(filter(fn, self), self._excinfo)
def getcrashentry(self) -> TracebackEntry:
- """ return last non-hidden traceback entry that lead
- to the exception of a traceback.
- """
+ """Return last non-hidden traceback entry that lead to the exception of a traceback."""
for i in range(-1, -len(self) - 1, -1):
entry = self[i]
if not entry.ishidden():
@@ -389,10 +403,9 @@ class Traceback(List[TracebackEntry]):
return self[-1]
def recursionindex(self) -> Optional[int]:
- """ return the index of the frame/TracebackEntry where recursion
- originates if appropriate, None if no recursion occurred
- """
- cache = {} # type: Dict[Tuple[Any, int, int], List[Dict[str, Any]]]
+ """Return the index of the frame/TracebackEntry where recursion originates if
+ appropriate, None if no recursion occurred."""
+ cache: Dict[Tuple[Any, int, int], List[Dict[str, Any]]] = {}
for i, entry in enumerate(self):
# id for the code.raw is needed to work around
# the strange metaprogramming in the decorator lib from pypi
@@ -405,12 +418,10 @@ class Traceback(List[TracebackEntry]):
f = entry.frame
loc = f.f_locals
for otherloc in values:
- if f.is_true(
- f.eval(
- co_equal,
- __recursioncache_locals_1=loc,
- __recursioncache_locals_2=otherloc,
- )
+ if f.eval(
+ co_equal,
+ __recursioncache_locals_1=loc,
+ __recursioncache_locals_2=otherloc,
):
return i
values.append(entry.frame.f_locals)
@@ -422,37 +433,36 @@ co_equal = compile(
)
-_E = TypeVar("_E", bound=BaseException)
+_E = TypeVar("_E", bound=BaseException, covariant=True)
+@final
@attr.s(repr=False)
class ExceptionInfo(Generic[_E]):
- """ wraps sys.exc_info() objects and offers
- help for navigating the traceback.
- """
+ """Wraps sys.exc_info() objects and offers help for navigating the traceback."""
_assert_start_repr = "AssertionError('assert "
- _excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
+ _excinfo = attr.ib(type=Optional[Tuple[Type["_E"], "_E", TracebackType]])
_striptext = attr.ib(type=str, default="")
_traceback = attr.ib(type=Optional[Traceback], default=None)
@classmethod
def from_exc_info(
cls,
- exc_info: Tuple["Type[_E]", "_E", TracebackType],
+ exc_info: Tuple[Type[_E], _E, TracebackType],
exprinfo: Optional[str] = None,
) -> "ExceptionInfo[_E]":
- """returns an ExceptionInfo for an existing exc_info tuple.
+ """Return an ExceptionInfo for an existing exc_info tuple.
.. warning::
Experimental API
-
- :param exprinfo: a text string helping to determine if we should
- strip ``AssertionError`` from the output, defaults
- to the exception message/``__str__()``
+ :param exprinfo:
+ A text string helping to determine if we should strip
+ ``AssertionError`` from the output. Defaults to the exception
+ message/``__str__()``.
"""
_striptext = ""
if exprinfo is None and isinstance(exc_info[1], AssertionError):
@@ -468,16 +478,16 @@ class ExceptionInfo(Generic[_E]):
def from_current(
cls, exprinfo: Optional[str] = None
) -> "ExceptionInfo[BaseException]":
- """returns an ExceptionInfo matching the current traceback
+ """Return an ExceptionInfo matching the current traceback.
.. warning::
Experimental API
-
- :param exprinfo: a text string helping to determine if we should
- strip ``AssertionError`` from the output, defaults
- to the exception message/``__str__()``
+ :param exprinfo:
+ A text string helping to determine if we should strip
+ ``AssertionError`` from the output. Defaults to the exception
+ message/``__str__()``.
"""
tup = sys.exc_info()
assert tup[0] is not None, "no current exception"
@@ -488,18 +498,17 @@ class ExceptionInfo(Generic[_E]):
@classmethod
def for_later(cls) -> "ExceptionInfo[_E]":
- """return an unfilled ExceptionInfo
- """
+ """Return an unfilled ExceptionInfo."""
return cls(None)
- def fill_unfilled(self, exc_info: Tuple["Type[_E]", _E, TracebackType]) -> None:
- """fill an unfilled ExceptionInfo created with for_later()"""
+ def fill_unfilled(self, exc_info: Tuple[Type[_E], _E, TracebackType]) -> None:
+ """Fill an unfilled ExceptionInfo created with ``for_later()``."""
assert self._excinfo is None, "ExceptionInfo was already filled"
self._excinfo = exc_info
@property
- def type(self) -> "Type[_E]":
- """the exception class"""
+ def type(self) -> Type[_E]:
+ """The exception class."""
assert (
self._excinfo is not None
), ".type can only be used after the context manager exits"
@@ -507,7 +516,7 @@ class ExceptionInfo(Generic[_E]):
@property
def value(self) -> _E:
- """the exception value"""
+ """The exception value."""
assert (
self._excinfo is not None
), ".value can only be used after the context manager exits"
@@ -515,7 +524,7 @@ class ExceptionInfo(Generic[_E]):
@property
def tb(self) -> TracebackType:
- """the exception raw traceback"""
+ """The exception raw traceback."""
assert (
self._excinfo is not None
), ".tb can only be used after the context manager exits"
@@ -523,7 +532,7 @@ class ExceptionInfo(Generic[_E]):
@property
def typename(self) -> str:
- """the type name of the exception"""
+ """The type name of the exception."""
assert (
self._excinfo is not None
), ".typename can only be used after the context manager exits"
@@ -531,7 +540,7 @@ class ExceptionInfo(Generic[_E]):
@property
def traceback(self) -> Traceback:
- """the traceback"""
+ """The traceback."""
if self._traceback is None:
self._traceback = Traceback(self.tb, excinfo=ref(self))
return self._traceback
@@ -548,12 +557,12 @@ class ExceptionInfo(Generic[_E]):
)
def exconly(self, tryshort: bool = False) -> str:
- """ return the exception as a string
+ """Return the exception as a string.
- when 'tryshort' resolves to True, and the exception is a
- _pytest._code._AssertionError, only the actual exception part of
- the exception representation is returned (so 'AssertionError: ' is
- removed from the beginning)
+ When 'tryshort' resolves to True, and the exception is a
+ _pytest._code._AssertionError, only the actual exception part of
+ the exception representation is returned (so 'AssertionError: ' is
+ removed from the beginning).
"""
lines = format_exception_only(self.type, self.value)
text = "".join(lines)
@@ -564,9 +573,12 @@ class ExceptionInfo(Generic[_E]):
return text
def errisinstance(
- self, exc: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]]
+ self, exc: Union[Type[BaseException], Tuple[Type[BaseException], ...]]
) -> bool:
- """ return True if the exception is an instance of exc """
+ """Return True if the exception is an instance of exc.
+
+ Consider using ``isinstance(excinfo.value, exc)`` instead.
+ """
return isinstance(self.value, exc)
def _getreprcrash(self) -> "ReprFileLocation":
@@ -585,14 +597,14 @@ class ExceptionInfo(Generic[_E]):
truncate_locals: bool = True,
chain: bool = True,
) -> Union["ReprExceptionInfo", "ExceptionChainRepr"]:
- """
- Return str()able representation of this exception info.
+ """Return str()able representation of this exception info.
:param bool showlocals:
Show locals per traceback entry.
Ignored if ``style=="native"``.
- :param str style: long|short|no|native traceback style
+ :param str style:
+ long|short|no|native|value traceback style.
:param bool abspath:
If paths should be changed to absolute or left unchanged.
@@ -607,7 +619,8 @@ class ExceptionInfo(Generic[_E]):
:param bool truncate_locals:
With ``showlocals==True``, make sure locals can be safely represented as strings.
- :param bool chain: if chained exceptions in Python 3 should be shown.
+ :param bool chain:
+ If chained exceptions in Python 3 should be shown.
.. versionchanged:: 3.9
@@ -634,24 +647,24 @@ class ExceptionInfo(Generic[_E]):
)
return fmt.repr_excinfo(self)
- def match(self, regexp: "Union[str, Pattern]") -> "Literal[True]":
- """
- Check whether the regular expression `regexp` matches the string
+ def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]":
+ """Check whether the regular expression `regexp` matches the string
representation of the exception using :func:`python:re.search`.
- If it matches `True` is returned.
- If it doesn't match an `AssertionError` is raised.
+
+ If it matches `True` is returned, otherwise an `AssertionError` is raised.
"""
__tracebackhide__ = True
- assert re.search(
- regexp, str(self.value)
- ), "Pattern {!r} does not match {!r}".format(regexp, str(self.value))
+ msg = "Regex pattern {!r} does not match {!r}."
+ if regexp == str(self.value):
+ msg += " Did you mean to `re.escape()` the regex?"
+ assert re.search(regexp, str(self.value)), msg.format(regexp, str(self.value))
# Return True to allow for "assert excinfo.match()".
return True
@attr.s
class FormattedExcinfo:
- """ presenting information about failing Functions and Generators. """
+ """Presenting information about failing Functions and Generators."""
# for traceback entries
flow_marker = ">"
@@ -667,17 +680,17 @@ class FormattedExcinfo:
astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False)
def _getindent(self, source: "Source") -> int:
- # figure out indent for given source
+ # Figure out indent for the given source.
try:
s = str(source.getstatement(len(source) - 1))
except KeyboardInterrupt:
raise
- except: # noqa
+ except BaseException:
try:
s = str(source[-1])
except KeyboardInterrupt:
raise
- except: # noqa
+ except BaseException:
return 0
return 4 + (len(s) - len(s.lstrip()))
@@ -697,17 +710,15 @@ class FormattedExcinfo:
def get_source(
self,
- source: "Source",
+ source: Optional["Source"],
line_index: int = -1,
- excinfo: Optional[ExceptionInfo] = None,
+ excinfo: Optional[ExceptionInfo[BaseException]] = None,
short: bool = False,
) -> List[str]:
- """ return formatted and marked up source lines. """
- import _pytest._code
-
+ """Return formatted and marked up source lines."""
lines = []
if source is None or line_index >= len(source.lines):
- source = _pytest._code.Source("???")
+ source = Source("???")
line_index = 0
if line_index < 0:
line_index += len(source)
@@ -726,11 +737,14 @@ class FormattedExcinfo:
return lines
def get_exconly(
- self, excinfo: ExceptionInfo, indent: int = 4, markall: bool = False
+ self,
+ excinfo: ExceptionInfo[BaseException],
+ indent: int = 4,
+ markall: bool = False,
) -> List[str]:
lines = []
indentstr = " " * indent
- # get the real exception information out
+ # Get the real exception information out.
exlines = excinfo.exconly(tryshort=True).split("\n")
failindent = self.fail_marker + indentstr[1:]
for line in exlines:
@@ -739,7 +753,7 @@ class FormattedExcinfo:
failindent = indentstr
return lines
- def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]:
+ def repr_locals(self, locals: Mapping[str, object]) -> Optional["ReprLocals"]:
if self.showlocals:
lines = []
keys = [loc for loc in locals if loc[0] != "@"]
@@ -756,9 +770,8 @@ class FormattedExcinfo:
str_repr = saferepr(value)
else:
str_repr = safeformat(value)
- # if len(str_repr) < 70 or not isinstance(value,
- # (list, tuple, dict)):
- lines.append("{:<10} = {}".format(name, str_repr))
+ # if len(str_repr) < 70 or not isinstance(value, (list, tuple, dict)):
+ lines.append(f"{name:<10} = {str_repr}")
# else:
# self._line("%-10s =\\" % (name,))
# # XXX
@@ -767,20 +780,19 @@ class FormattedExcinfo:
return None
def repr_traceback_entry(
- self, entry: TracebackEntry, excinfo: Optional[ExceptionInfo] = None
+ self,
+ entry: TracebackEntry,
+ excinfo: Optional[ExceptionInfo[BaseException]] = None,
) -> "ReprEntry":
- import _pytest._code
-
- source = self._getentrysource(entry)
- if source is None:
- source = _pytest._code.Source("???")
- line_index = 0
- else:
- line_index = entry.lineno - entry.getfirstlinesource()
-
- lines = [] # type: List[str]
+ lines: List[str] = []
style = entry._repr_style if entry._repr_style is not None else self.style
if style in ("short", "long"):
+ source = self._getentrysource(entry)
+ if source is None:
+ source = Source("???")
+ line_index = 0
+ else:
+ line_index = entry.lineno - entry.getfirstlinesource()
short = style == "short"
reprargs = self.repr_args(entry) if not short else None
s = self.get_source(source, line_index, excinfo, short=short)
@@ -793,9 +805,14 @@ class FormattedExcinfo:
reprfileloc = ReprFileLocation(path, entry.lineno + 1, message)
localsrepr = self.repr_locals(entry.locals)
return ReprEntry(lines, reprargs, localsrepr, reprfileloc, style)
- if excinfo:
- lines.extend(self.get_exconly(excinfo, indent=4))
- return ReprEntry(lines, None, None, None, style)
+ elif style == "value":
+ if excinfo:
+ lines.extend(str(excinfo.value).split("\n"))
+ return ReprEntry(lines, None, None, None, style)
+ else:
+ if excinfo:
+ lines.extend(self.get_exconly(excinfo, indent=4))
+ return ReprEntry(lines, None, None, None, style)
def _makepath(self, path):
if not self.abspath:
@@ -807,18 +824,23 @@ class FormattedExcinfo:
path = np
return path
- def repr_traceback(self, excinfo: ExceptionInfo) -> "ReprTraceback":
+ def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> "ReprTraceback":
traceback = excinfo.traceback
if self.tbfilter:
traceback = traceback.filter()
- if excinfo.errisinstance(RecursionError):
+ if isinstance(excinfo.value, RecursionError):
traceback, extraline = self._truncate_recursive_traceback(traceback)
else:
extraline = None
last = traceback[-1]
entries = []
+ if self.style == "value":
+ reprentry = self.repr_traceback_entry(last, excinfo)
+ entries.append(reprentry)
+ return ReprTraceback(entries, None, style=self.style)
+
for index, entry in enumerate(traceback):
einfo = (last == entry) and excinfo or None
reprentry = self.repr_traceback_entry(entry, einfo)
@@ -828,22 +850,23 @@ class FormattedExcinfo:
def _truncate_recursive_traceback(
self, traceback: Traceback
) -> Tuple[Traceback, Optional[str]]:
- """
- Truncate the given recursive traceback trying to find the starting point
- of the recursion.
+ """Truncate the given recursive traceback trying to find the starting
+ point of the recursion.
- The detection is done by going through each traceback entry and finding the
- point in which the locals of the frame are equal to the locals of a previous frame (see ``recursionindex()``.
+ The detection is done by going through each traceback entry and
+ finding the point in which the locals of the frame are equal to the
+ locals of a previous frame (see ``recursionindex()``).
- Handle the situation where the recursion process might raise an exception (for example
- comparing numpy arrays using equality raises a TypeError), in which case we do our best to
- warn the user of the error and show a limited traceback.
+ Handle the situation where the recursion process might raise an
+ exception (for example comparing numpy arrays using equality raises a
+ TypeError), in which case we do our best to warn the user of the
+ error and show a limited traceback.
"""
try:
recursionindex = traceback.recursionindex()
except Exception as e:
max_frames = 10
- extraline = (
+ extraline: Optional[str] = (
"!!! Recursion error detected, but an error occurred locating the origin of recursion.\n"
" The following exception happened when comparing locals in the stack frame:\n"
" {exc_type}: {exc_msg}\n"
@@ -853,7 +876,7 @@ class FormattedExcinfo:
exc_msg=str(e),
max_frames=max_frames,
total=len(traceback),
- ) # type: Optional[str]
+ )
# Type ignored because adding two instaces of a List subtype
# currently incorrectly has type List instead of the subtype.
traceback = traceback[:max_frames] + traceback[-max_frames:] # type: ignore
@@ -866,22 +889,26 @@ class FormattedExcinfo:
return traceback, extraline
- def repr_excinfo(self, excinfo: ExceptionInfo) -> "ExceptionChainRepr":
- repr_chain = (
- []
- ) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]]
- e = excinfo.value
- excinfo_ = excinfo # type: Optional[ExceptionInfo]
+ def repr_excinfo(
+ self, excinfo: ExceptionInfo[BaseException]
+ ) -> "ExceptionChainRepr":
+ repr_chain: List[
+ Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]
+ ] = []
+ e: Optional[BaseException] = excinfo.value
+ excinfo_: Optional[ExceptionInfo[BaseException]] = excinfo
descr = None
- seen = set() # type: Set[int]
+ seen: Set[int] = set()
while e is not None and id(e) not in seen:
seen.add(id(e))
if excinfo_:
reprtraceback = self.repr_traceback(excinfo_)
- reprcrash = excinfo_._getreprcrash() # type: Optional[ReprFileLocation]
+ reprcrash: Optional[ReprFileLocation] = (
+ excinfo_._getreprcrash() if self.style != "value" else None
+ )
else:
- # fallback to native repr if the exception doesn't have a traceback:
- # ExceptionInfo objects require a full traceback to work
+ # Fallback to native repr if the exception doesn't have a traceback:
+ # ExceptionInfo objects require a full traceback to work.
reprtraceback = ReprTracebackNative(
traceback.format_exception(type(e), e, None)
)
@@ -912,7 +939,7 @@ class FormattedExcinfo:
return ExceptionChainRepr(repr_chain)
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+@attr.s(eq=False)
class TerminalRepr:
def __str__(self) -> str:
# FYI this is called from pytest-xdist's serialization of exception
@@ -929,10 +956,15 @@ class TerminalRepr:
raise NotImplementedError()
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+# This class is abstract -- only subclasses are instantiated.
+@attr.s(eq=False)
class ExceptionRepr(TerminalRepr):
- def __attrs_post_init__(self):
- self.sections = [] # type: List[Tuple[str, str, str]]
+ # Provided by subclasses.
+ reprcrash: Optional["ReprFileLocation"]
+ reprtraceback: "ReprTraceback"
+
+ def __attrs_post_init__(self) -> None:
+ self.sections: List[Tuple[str, str, str]] = []
def addsection(self, name: str, content: str, sep: str = "-") -> None:
self.sections.append((name, content, sep))
@@ -943,7 +975,7 @@ class ExceptionRepr(TerminalRepr):
tw.line(content)
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+@attr.s(eq=False)
class ExceptionChainRepr(ExceptionRepr):
chain = attr.ib(
type=Sequence[
@@ -951,10 +983,10 @@ class ExceptionChainRepr(ExceptionRepr):
]
)
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()
# reprcrash and reprtraceback of the outermost (the newest) exception
- # in the chain
+ # in the chain.
self.reprtraceback = self.chain[-1][0]
self.reprcrash = self.chain[-1][1]
@@ -967,7 +999,7 @@ class ExceptionChainRepr(ExceptionRepr):
super().toterminal(tw)
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+@attr.s(eq=False)
class ReprExceptionInfo(ExceptionRepr):
reprtraceback = attr.ib(type="ReprTraceback")
reprcrash = attr.ib(type="ReprFileLocation")
@@ -977,7 +1009,7 @@ class ReprExceptionInfo(ExceptionRepr):
super().toterminal(tw)
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+@attr.s(eq=False)
class ReprTraceback(TerminalRepr):
reprentries = attr.ib(type=Sequence[Union["ReprEntry", "ReprEntryNative"]])
extraline = attr.ib(type=Optional[str])
@@ -986,7 +1018,7 @@ class ReprTraceback(TerminalRepr):
entrysep = "_ "
def toterminal(self, tw: TerminalWriter) -> None:
- # the entries might have different styles
+ # The entries might have different styles.
for i, entry in enumerate(self.reprentries):
if entry.style == "long":
tw.line("")
@@ -1011,16 +1043,16 @@ class ReprTracebackNative(ReprTraceback):
self.extraline = None
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+@attr.s(eq=False)
class ReprEntryNative(TerminalRepr):
lines = attr.ib(type=Sequence[str])
- style = "native" # type: _TracebackStyle
+ style: "_TracebackStyle" = "native"
def toterminal(self, tw: TerminalWriter) -> None:
tw.write("".join(self.lines))
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+@attr.s(eq=False)
class ReprEntry(TerminalRepr):
lines = attr.ib(type=Sequence[str])
reprfuncargs = attr.ib(type=Optional["ReprFuncArgs"])
@@ -1029,7 +1061,7 @@ class ReprEntry(TerminalRepr):
style = attr.ib(type="_TracebackStyle")
def _write_entry_lines(self, tw: TerminalWriter) -> None:
- """Writes the source code portions of a list of traceback entries with syntax highlighting.
+ """Write the source code portions of a list of traceback entries with syntax highlighting.
Usually entries are lines like these:
@@ -1042,28 +1074,34 @@ class ReprEntry(TerminalRepr):
character, as doing so might break line continuations.
"""
- indent_size = 4
-
- def is_fail(line):
- return line.startswith("{} ".format(FormattedExcinfo.fail_marker))
-
if not self.lines:
return
# separate indents and source lines that are not failures: we want to
# highlight the code but not the indentation, which may contain markers
# such as "> assert 0"
- indents = []
- source_lines = []
- for line in self.lines:
- if not is_fail(line):
- indents.append(line[:indent_size])
- source_lines.append(line[indent_size:])
+ fail_marker = f"{FormattedExcinfo.fail_marker} "
+ indent_size = len(fail_marker)
+ indents: List[str] = []
+ source_lines: List[str] = []
+ failure_lines: List[str] = []
+ for index, line in enumerate(self.lines):
+ is_failure_line = line.startswith(fail_marker)
+ if is_failure_line:
+ # from this point on all lines are considered part of the failure
+ failure_lines.extend(self.lines[index:])
+ break
+ else:
+ if self.style == "value":
+ source_lines.append(line)
+ else:
+ indents.append(line[:indent_size])
+ source_lines.append(line[indent_size:])
tw._write_source(source_lines, indents)
# failure lines are always completely red and bold
- for line in (x for x in self.lines if is_fail(x)):
+ for line in failure_lines:
tw.line(line, bold=True, red=True)
def toterminal(self, tw: TerminalWriter) -> None:
@@ -1094,24 +1132,24 @@ class ReprEntry(TerminalRepr):
)
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+@attr.s(eq=False)
class ReprFileLocation(TerminalRepr):
path = attr.ib(type=str, converter=str)
lineno = attr.ib(type=int)
message = attr.ib(type=str)
def toterminal(self, tw: TerminalWriter) -> None:
- # filename and lineno output for each entry,
- # using an output format that most editors understand
+ # Filename and lineno output for each entry, using an output format
+ # that most editors understand.
msg = self.message
i = msg.find("\n")
if i != -1:
msg = msg[:i]
tw.write(self.path, bold=True, red=True)
- tw.line(":{}: {}".format(self.lineno, msg))
+ tw.line(f":{self.lineno}: {msg}")
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+@attr.s(eq=False)
class ReprLocals(TerminalRepr):
lines = attr.ib(type=Sequence[str])
@@ -1120,7 +1158,7 @@ class ReprLocals(TerminalRepr):
tw.line(indent + line)
-@attr.s(**{ATTRS_EQ_FIELD: False}) # type: ignore
+@attr.s(eq=False)
class ReprFuncArgs(TerminalRepr):
args = attr.ib(type=Sequence[Tuple[str, object]])
@@ -1128,7 +1166,7 @@ class ReprFuncArgs(TerminalRepr):
if self.args:
linesofar = ""
for name, value in self.args:
- ns = "{} = {}".format(name, value)
+ ns = f"{name} = {value}"
if len(ns) + len(linesofar) + 2 > tw.fullwidth:
if linesofar:
tw.line(linesofar)
@@ -1143,49 +1181,79 @@ class ReprFuncArgs(TerminalRepr):
tw.line("")
-def getrawcode(obj, trycall: bool = True):
- """ return code object for given function. """
+def getfslineno(obj: object) -> Tuple[Union[str, py.path.local], int]:
+ """Return source location (path, lineno) for the given object.
+
+ If the source cannot be determined return ("", -1).
+
+ The line number is 0-based.
+ """
+ # xxx let decorators etc specify a sane ordering
+ # NOTE: this used to be done in _pytest.compat.getfslineno, initially added
+ # in 6ec13a2b9. It ("place_as") appears to be something very custom.
+ obj = get_real_func(obj)
+ if hasattr(obj, "place_as"):
+ obj = obj.place_as # type: ignore[attr-defined]
+
try:
- return obj.__code__
- except AttributeError:
- obj = getattr(obj, "f_code", obj)
- obj = getattr(obj, "__code__", obj)
- if trycall and not hasattr(obj, "co_firstlineno"):
- if hasattr(obj, "__call__") and not inspect.isclass(obj):
- x = getrawcode(obj.__call__, trycall=False)
- if hasattr(x, "co_firstlineno"):
- return x
- return obj
-
-
-# relative paths that we use to filter traceback entries from appearing to the user;
-# see filter_traceback
+ code = Code.from_function(obj)
+ except TypeError:
+ try:
+ fn = inspect.getsourcefile(obj) or inspect.getfile(obj) # type: ignore[arg-type]
+ except TypeError:
+ return "", -1
+
+ fspath = fn and py.path.local(fn) or ""
+ lineno = -1
+ if fspath:
+ try:
+ _, lineno = findsource(obj)
+ except OSError:
+ pass
+ return fspath, lineno
+
+ return code.path, code.firstlineno
+
+
+# Relative paths that we use to filter traceback entries from appearing to the user;
+# see filter_traceback.
# note: if we need to add more paths than what we have now we should probably use a list
-# for better maintenance
+# for better maintenance.
-_PLUGGY_DIR = py.path.local(pluggy.__file__.rstrip("oc"))
+_PLUGGY_DIR = Path(pluggy.__file__.rstrip("oc"))
# pluggy is either a package or a single module depending on the version
-if _PLUGGY_DIR.basename == "__init__.py":
- _PLUGGY_DIR = _PLUGGY_DIR.dirpath()
-_PYTEST_DIR = py.path.local(_pytest.__file__).dirpath()
-_PY_DIR = py.path.local(py.__file__).dirpath()
+if _PLUGGY_DIR.name == "__init__.py":
+ _PLUGGY_DIR = _PLUGGY_DIR.parent
+_PYTEST_DIR = Path(_pytest.__file__).parent
+_PY_DIR = Path(py.__file__).parent
def filter_traceback(entry: TracebackEntry) -> bool:
- """Return True if a TracebackEntry instance should be removed from tracebacks:
+ """Return True if a TracebackEntry instance should be included in tracebacks.
+
+ We hide traceback entries of:
+
* dynamically generated code (no code to show up for it);
* internal traceback from pytest or its internal libraries, py and pluggy.
"""
# entry.path might sometimes return a str object when the entry
- # points to dynamically generated code
- # see https://bitbucket.org/pytest-dev/py/issues/71
+ # points to dynamically generated code.
+ # See https://bitbucket.org/pytest-dev/py/issues/71.
raw_filename = entry.frame.code.raw.co_filename
is_generated = "<" in raw_filename and ">" in raw_filename
if is_generated:
return False
+
# entry.path might point to a non-existing file, in which case it will
- # also return a str object. see #1133
- p = py.path.local(entry.path)
- return (
- not p.relto(_PLUGGY_DIR) and not p.relto(_PYTEST_DIR) and not p.relto(_PY_DIR)
- )
+ # also return a str object. See #1133.
+ p = Path(entry.path)
+
+ parents = p.parents
+ if _PLUGGY_DIR in parents:
+ return False
+ if _PYTEST_DIR in parents:
+ return False
+ if _PY_DIR in parents:
+ return False
+
+ return True
diff --git a/contrib/python/pytest/py3/_pytest/_code/source.py b/contrib/python/pytest/py3/_pytest/_code/source.py
index 28c11e5d5e..6f54057c0a 100644
--- a/contrib/python/pytest/py3/_pytest/_code/source.py
+++ b/contrib/python/pytest/py3/_pytest/_code/source.py
@@ -1,76 +1,59 @@
import ast
import inspect
-import linecache
-import sys
import textwrap
import tokenize
+import types
import warnings
from bisect import bisect_right
-from types import CodeType
-from types import FrameType
-from typing import Any
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
-from typing import Sequence
+from typing import overload
from typing import Tuple
from typing import Union
-import py
-
-from _pytest.compat import get_real_func
-from _pytest.compat import overload
-from _pytest.compat import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from typing_extensions import Literal
-
class Source:
- """ an immutable object holding a source code fragment,
- possibly deindenting it.
+ """An immutable object holding a source code fragment.
+
+ When using Source(...), the source lines are deindented.
"""
- _compilecounter = 0
-
- def __init__(self, *parts, **kwargs) -> None:
- self.lines = lines = [] # type: List[str]
- de = kwargs.get("deindent", True)
- for part in parts:
- if not part:
- partlines = [] # type: List[str]
- elif isinstance(part, Source):
- partlines = part.lines
- elif isinstance(part, (tuple, list)):
- partlines = [x.rstrip("\n") for x in part]
- elif isinstance(part, str):
- partlines = part.split("\n")
- else:
- partlines = getsource(part, deindent=de).lines
- if de:
- partlines = deindent(partlines)
- lines.extend(partlines)
-
- def __eq__(self, other):
- try:
- return self.lines == other.lines
- except AttributeError:
- if isinstance(other, str):
- return str(self) == other
- return False
+ def __init__(self, obj: object = None) -> None:
+ if not obj:
+ self.lines: List[str] = []
+ elif isinstance(obj, Source):
+ self.lines = obj.lines
+ elif isinstance(obj, (tuple, list)):
+ self.lines = deindent(x.rstrip("\n") for x in obj)
+ elif isinstance(obj, str):
+ self.lines = deindent(obj.split("\n"))
+ else:
+ try:
+ rawcode = getrawcode(obj)
+ src = inspect.getsource(rawcode)
+ except TypeError:
+ src = inspect.getsource(obj) # type: ignore[arg-type]
+ self.lines = deindent(src.split("\n"))
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, Source):
+ return NotImplemented
+ return self.lines == other.lines
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@overload
def __getitem__(self, key: int) -> str:
- raise NotImplementedError()
+ ...
- @overload # noqa: F811
- def __getitem__(self, key: slice) -> "Source": # noqa: F811
- raise NotImplementedError()
+ @overload
+ def __getitem__(self, key: slice) -> "Source":
+ ...
- def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811
+ def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:
if isinstance(key, int):
return self.lines[key]
else:
@@ -87,9 +70,7 @@ class Source:
return len(self.lines)
def strip(self) -> "Source":
- """ return new source object with trailing
- and leading blank lines removed.
- """
+ """Return new Source object with trailing and leading blank lines removed."""
start, end = 0, len(self)
while start < end and not self.lines[start].strip():
start += 1
@@ -99,220 +80,36 @@ class Source:
source.lines[:] = self.lines[start:end]
return source
- def putaround(
- self, before: str = "", after: str = "", indent: str = " " * 4
- ) -> "Source":
- """ return a copy of the source object with
- 'before' and 'after' wrapped around it.
- """
- beforesource = Source(before)
- aftersource = Source(after)
- newsource = Source()
- lines = [(indent + line) for line in self.lines]
- newsource.lines = beforesource.lines + lines + aftersource.lines
- return newsource
-
def indent(self, indent: str = " " * 4) -> "Source":
- """ return a copy of the source object with
- all lines indented by the given indent-string.
- """
+ """Return a copy of the source object with all lines indented by the
+ given indent-string."""
newsource = Source()
newsource.lines = [(indent + line) for line in self.lines]
return newsource
def getstatement(self, lineno: int) -> "Source":
- """ return Source statement which contains the
- given linenumber (counted from 0).
- """
+ """Return Source statement which contains the given linenumber
+ (counted from 0)."""
start, end = self.getstatementrange(lineno)
return self[start:end]
def getstatementrange(self, lineno: int) -> Tuple[int, int]:
- """ return (start, end) tuple which spans the minimal
- statement region which containing the given lineno.
- """
+ """Return (start, end) tuple which spans the minimal statement region
+ which containing the given lineno."""
if not (0 <= lineno < len(self)):
raise IndexError("lineno out of range")
ast, start, end = getstatementrange_ast(lineno, self)
return start, end
def deindent(self) -> "Source":
- """return a new source object deindented."""
+ """Return a new Source object deindented."""
newsource = Source()
newsource.lines[:] = deindent(self.lines)
return newsource
- def isparseable(self, deindent: bool = True) -> bool:
- """ return True if source is parseable, heuristically
- deindenting it by default.
- """
- if deindent:
- source = str(self.deindent())
- else:
- source = str(self)
- try:
- ast.parse(source)
- except (SyntaxError, ValueError, TypeError):
- return False
- else:
- return True
-
def __str__(self) -> str:
return "\n".join(self.lines)
- @overload
- def compile(
- self,
- filename: Optional[str] = ...,
- mode: str = ...,
- flag: "Literal[0]" = ...,
- dont_inherit: int = ...,
- _genframe: Optional[FrameType] = ...,
- ) -> CodeType:
- raise NotImplementedError()
-
- @overload # noqa: F811
- def compile( # noqa: F811
- self,
- filename: Optional[str] = ...,
- mode: str = ...,
- flag: int = ...,
- dont_inherit: int = ...,
- _genframe: Optional[FrameType] = ...,
- ) -> Union[CodeType, ast.AST]:
- raise NotImplementedError()
-
- def compile( # noqa: F811
- self,
- filename: Optional[str] = None,
- mode: str = "exec",
- flag: int = 0,
- dont_inherit: int = 0,
- _genframe: Optional[FrameType] = None,
- ) -> Union[CodeType, ast.AST]:
- """ return compiled code object. if filename is None
- invent an artificial filename which displays
- the source/line position of the caller frame.
- """
- if not filename or py.path.local(filename).check(file=0):
- if _genframe is None:
- _genframe = sys._getframe(1) # the caller
- fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno
- base = "<%d-codegen " % self._compilecounter
- self.__class__._compilecounter += 1
- if not filename:
- filename = base + "%s:%d>" % (fn, lineno)
- else:
- filename = base + "%r %s:%d>" % (filename, fn, lineno)
- source = "\n".join(self.lines) + "\n"
- try:
- co = compile(source, filename, mode, flag)
- except SyntaxError as ex:
- # re-represent syntax errors from parsing python strings
- msglines = self.lines[: ex.lineno]
- if ex.offset:
- msglines.append(" " * ex.offset + "^")
- msglines.append("(code was compiled probably from here: %s)" % filename)
- newex = SyntaxError("\n".join(msglines))
- newex.offset = ex.offset
- newex.lineno = ex.lineno
- newex.text = ex.text
- raise newex
- else:
- if flag & ast.PyCF_ONLY_AST:
- assert isinstance(co, ast.AST)
- return co
- assert isinstance(co, CodeType)
- lines = [(x + "\n") for x in self.lines]
- # Type ignored because linecache.cache is private.
- linecache.cache[filename] = (1, None, lines, filename) # type: ignore
- return co
-
-
-#
-# public API shortcut functions
-#
-
-
-@overload
-def compile_(
- source: Union[str, bytes, ast.mod, ast.AST],
- filename: Optional[str] = ...,
- mode: str = ...,
- flags: "Literal[0]" = ...,
- dont_inherit: int = ...,
-) -> CodeType:
- raise NotImplementedError()
-
-
-@overload # noqa: F811
-def compile_( # noqa: F811
- source: Union[str, bytes, ast.mod, ast.AST],
- filename: Optional[str] = ...,
- mode: str = ...,
- flags: int = ...,
- dont_inherit: int = ...,
-) -> Union[CodeType, ast.AST]:
- raise NotImplementedError()
-
-
-def compile_( # noqa: F811
- source: Union[str, bytes, ast.mod, ast.AST],
- filename: Optional[str] = None,
- mode: str = "exec",
- flags: int = 0,
- dont_inherit: int = 0,
-) -> Union[CodeType, ast.AST]:
- """ compile the given source to a raw code object,
- and maintain an internal cache which allows later
- retrieval of the source code for the code object
- and any recursively created code objects.
- """
- if isinstance(source, ast.AST):
- # XXX should Source support having AST?
- assert filename is not None
- co = compile(source, filename, mode, flags, dont_inherit)
- assert isinstance(co, (CodeType, ast.AST))
- return co
- _genframe = sys._getframe(1) # the caller
- s = Source(source)
- return s.compile(filename, mode, flags, _genframe=_genframe)
-
-
-def getfslineno(obj: Any) -> Tuple[Union[str, py.path.local], int]:
- """ Return source location (path, lineno) for the given object.
- If the source cannot be determined return ("", -1).
-
- The line number is 0-based.
- """
- from .code import Code
-
- # xxx let decorators etc specify a sane ordering
- # NOTE: this used to be done in _pytest.compat.getfslineno, initially added
- # in 6ec13a2b9. It ("place_as") appears to be something very custom.
- obj = get_real_func(obj)
- if hasattr(obj, "place_as"):
- obj = obj.place_as
-
- try:
- code = Code(obj)
- except TypeError:
- try:
- fn = inspect.getsourcefile(obj) or inspect.getfile(obj)
- except TypeError:
- return "", -1
-
- fspath = fn and py.path.local(fn) or ""
- lineno = -1
- if fspath:
- try:
- _, lineno = findsource(obj)
- except IOError:
- pass
- return fspath, lineno
- else:
- return code.path, code.firstlineno
-
#
# helper functions
@@ -329,35 +126,34 @@ def findsource(obj) -> Tuple[Optional[Source], int]:
return source, lineno
-def getsource(obj, **kwargs) -> Source:
- from .code import getrawcode
-
- obj = getrawcode(obj)
+def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
+ """Return code object for given function."""
try:
- strsrc = inspect.getsource(obj)
- except IndentationError:
- strsrc = '"Buggy python version consider upgrading, cannot get source"'
- assert isinstance(strsrc, str)
- return Source(strsrc, **kwargs)
+ return obj.__code__ # type: ignore[attr-defined,no-any-return]
+ except AttributeError:
+ pass
+ if trycall:
+ call = getattr(obj, "__call__", None)
+ if call and not isinstance(obj, type):
+ return getrawcode(call, trycall=False)
+ raise TypeError(f"could not get code object for {obj!r}")
-def deindent(lines: Sequence[str]) -> List[str]:
+def deindent(lines: Iterable[str]) -> List[str]:
return textwrap.dedent("\n".join(lines)).splitlines()
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
- import ast
-
- # flatten all statements and except handlers into one lineno-list
- # AST's line numbers start indexing at 1
- values = [] # type: List[int]
+ # Flatten all statements and except handlers into one lineno-list.
+ # AST's line numbers start indexing at 1.
+ values: List[int] = []
for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
values.append(x.lineno - 1)
for name in ("finalbody", "orelse"):
- val = getattr(x, name, None) # type: Optional[List[ast.stmt]]
+ val: Optional[List[ast.stmt]] = getattr(x, name, None)
if val:
- # treat the finally/orelse part as its own statement
+ # Treat the finally/orelse part as its own statement.
values.append(val[0].lineno - 1 - 1)
values.sort()
insert_index = bisect_right(values, lineno)
@@ -378,13 +174,13 @@ def getstatementrange_ast(
if astnode is None:
content = str(source)
# See #4260:
- # don't produce duplicate warnings when compiling source to find ast
+ # Don't produce duplicate warnings when compiling source to find AST.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
astnode = ast.parse(content, "source", "exec")
start, end = get_statement_startend2(lineno, astnode)
- # we need to correct the end:
+ # We need to correct the end:
# - ast-parsing strips comments
# - there might be empty lines
# - we might have lesser indented code blocks at the end
@@ -392,10 +188,10 @@ def getstatementrange_ast(
end = len(source.lines)
if end > start + 1:
- # make sure we don't span differently indented code blocks
- # by using the BlockFinder helper used which inspect.getsource() uses itself
+ # Make sure we don't span differently indented code blocks
+ # by using the BlockFinder helper used which inspect.getsource() uses itself.
block_finder = inspect.BlockFinder()
- # if we start with an indented line, put blockfinder to "started" mode
+ # If we start with an indented line, put blockfinder to "started" mode.
block_finder.started = source.lines[start][0].isspace()
it = ((x + "\n") for x in source.lines[start:end])
try:
@@ -406,7 +202,7 @@ def getstatementrange_ast(
except Exception:
pass
- # the end might still point to a comment or empty line, correct it
+ # The end might still point to a comment or empty line, correct it.
while end:
line = source.lines[end - 1].lstrip()
if line.startswith("#") or not line:
diff --git a/contrib/python/pytest/py3/_pytest/_io/__init__.py b/contrib/python/pytest/py3/_pytest/_io/__init__.py
index f56579806c..db001e918c 100644
--- a/contrib/python/pytest/py3/_pytest/_io/__init__.py
+++ b/contrib/python/pytest/py3/_pytest/_io/__init__.py
@@ -1,39 +1,8 @@
-from typing import List
-from typing import Sequence
+from .terminalwriter import get_terminal_width
+from .terminalwriter import TerminalWriter
-from py.io import TerminalWriter as BaseTerminalWriter # noqa: F401
-
-class TerminalWriter(BaseTerminalWriter):
- def _write_source(self, lines: List[str], indents: Sequence[str] = ()) -> None:
- """Write lines of source code possibly highlighted.
-
- Keeping this private for now because the API is clunky. We should discuss how
- to evolve the terminal writer so we can have more precise color support, for example
- being able to write part of a line in one color and the rest in another, and so on.
- """
- if indents and len(indents) != len(lines):
- raise ValueError(
- "indents size ({}) should have same size as lines ({})".format(
- len(indents), len(lines)
- )
- )
- if not indents:
- indents = [""] * len(lines)
- source = "\n".join(lines)
- new_lines = self._highlight(source).splitlines()
- for indent, new_line in zip(indents, new_lines):
- self.line(indent + new_line)
-
- def _highlight(self, source):
- """Highlight the given source code according to the "code_highlight" option"""
- if not self.hasmarkup:
- return source
- try:
- from pygments.formatters.terminal import TerminalFormatter
- from pygments.lexers.python import PythonLexer
- from pygments import highlight
- except ImportError:
- return source
- else:
- return highlight(source, PythonLexer(), TerminalFormatter(bg="dark"))
+__all__ = [
+ "TerminalWriter",
+ "get_terminal_width",
+]
diff --git a/contrib/python/pytest/py3/_pytest/_io/saferepr.py b/contrib/python/pytest/py3/_pytest/_io/saferepr.py
index 47a00de606..5eb1e08890 100644
--- a/contrib/python/pytest/py3/_pytest/_io/saferepr.py
+++ b/contrib/python/pytest/py3/_pytest/_io/saferepr.py
@@ -1,9 +1,12 @@
import pprint
import reprlib
from typing import Any
+from typing import Dict
+from typing import IO
+from typing import Optional
-def _try_repr_or_str(obj):
+def _try_repr_or_str(obj: object) -> str:
try:
return repr(obj)
except (KeyboardInterrupt, SystemExit):
@@ -12,7 +15,7 @@ def _try_repr_or_str(obj):
return '{}("{}")'.format(type(obj).__name__, obj)
-def _format_repr_exception(exc: BaseException, obj: Any) -> str:
+def _format_repr_exception(exc: BaseException, obj: object) -> str:
try:
exc_info = _try_repr_or_str(exc)
except (KeyboardInterrupt, SystemExit):
@@ -33,16 +36,15 @@ def _ellipsize(s: str, maxsize: int) -> str:
class SafeRepr(reprlib.Repr):
- """subclass of repr.Repr that limits the resulting size of repr()
- and includes information on exceptions raised during the call.
- """
+ """repr.Repr that limits the resulting size of repr() and includes
+ information on exceptions raised during the call."""
def __init__(self, maxsize: int) -> None:
super().__init__()
self.maxstring = maxsize
self.maxsize = maxsize
- def repr(self, x: Any) -> str:
+ def repr(self, x: object) -> str:
try:
s = super().repr(x)
except (KeyboardInterrupt, SystemExit):
@@ -51,7 +53,7 @@ class SafeRepr(reprlib.Repr):
s = _format_repr_exception(exc, x)
return _ellipsize(s, self.maxsize)
- def repr_instance(self, x: Any, level: int) -> str:
+ def repr_instance(self, x: object, level: int) -> str:
try:
s = repr(x)
except (KeyboardInterrupt, SystemExit):
@@ -61,8 +63,9 @@ class SafeRepr(reprlib.Repr):
return _ellipsize(s, self.maxsize)
-def safeformat(obj: Any) -> str:
- """return a pretty printed string for the given object.
+def safeformat(obj: object) -> str:
+ """Return a pretty printed string for the given object.
+
Failing __repr__ functions of user instances will be represented
with a short exception info.
"""
@@ -72,12 +75,15 @@ def safeformat(obj: Any) -> str:
return _format_repr_exception(exc, obj)
-def saferepr(obj: Any, maxsize: int = 240) -> str:
- """return a size-limited safe repr-string for the given object.
+def saferepr(obj: object, maxsize: int = 240) -> str:
+ """Return a size-limited safe repr-string for the given object.
+
Failing __repr__ functions of user instances will be represented
with a short exception info and 'saferepr' generally takes
- care to never raise exceptions itself. This function is a wrapper
- around the Repr/reprlib functionality of the standard 2.6 lib.
+ care to never raise exceptions itself.
+
+ This function is a wrapper around the Repr/reprlib functionality of the
+ standard 2.6 lib.
"""
return SafeRepr(maxsize).repr(obj)
@@ -85,19 +91,39 @@ def saferepr(obj: Any, maxsize: int = 240) -> str:
class AlwaysDispatchingPrettyPrinter(pprint.PrettyPrinter):
"""PrettyPrinter that always dispatches (regardless of width)."""
- def _format(self, object, stream, indent, allowance, context, level):
- p = self._dispatch.get(type(object).__repr__, None)
+ def _format(
+ self,
+ object: object,
+ stream: IO[str],
+ indent: int,
+ allowance: int,
+ context: Dict[int, Any],
+ level: int,
+ ) -> None:
+ # Type ignored because _dispatch is private.
+ p = self._dispatch.get(type(object).__repr__, None) # type: ignore[attr-defined]
objid = id(object)
if objid in context or p is None:
- return super()._format(object, stream, indent, allowance, context, level)
+ # Type ignored because _format is private.
+ super()._format( # type: ignore[misc]
+ object, stream, indent, allowance, context, level,
+ )
+ return
context[objid] = 1
p(self, object, stream, indent, allowance, context, level + 1)
del context[objid]
-def _pformat_dispatch(object, indent=1, width=80, depth=None, *, compact=False):
+def _pformat_dispatch(
+ object: object,
+ indent: int = 1,
+ width: int = 80,
+ depth: Optional[int] = None,
+ *,
+ compact: bool = False,
+) -> str:
return AlwaysDispatchingPrettyPrinter(
indent=indent, width=width, depth=depth, compact=compact
).pformat(object)
diff --git a/contrib/python/pytest/py3/_pytest/_io/terminalwriter.py b/contrib/python/pytest/py3/_pytest/_io/terminalwriter.py
new file mode 100644
index 0000000000..8edf4cd75f
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/_io/terminalwriter.py
@@ -0,0 +1,210 @@
+"""Helper functions for writing to terminals and files."""
+import os
+import shutil
+import sys
+from typing import Optional
+from typing import Sequence
+from typing import TextIO
+
+from .wcwidth import wcswidth
+from _pytest.compat import final
+
+
+# This code was initially copied from py 1.8.1, file _io/terminalwriter.py.
+
+
+def get_terminal_width() -> int:
+ width, _ = shutil.get_terminal_size(fallback=(80, 24))
+
+ # The Windows get_terminal_size may be bogus, let's sanify a bit.
+ if width < 40:
+ width = 80
+
+ return width
+
+
+def should_do_markup(file: TextIO) -> bool:
+ if os.environ.get("PY_COLORS") == "1":
+ return True
+ if os.environ.get("PY_COLORS") == "0":
+ return False
+ if "NO_COLOR" in os.environ:
+ return False
+ if "FORCE_COLOR" in os.environ:
+ return True
+ return (
+ hasattr(file, "isatty") and file.isatty() and os.environ.get("TERM") != "dumb"
+ )
+
+
+@final
+class TerminalWriter:
+ _esctable = dict(
+ black=30,
+ red=31,
+ green=32,
+ yellow=33,
+ blue=34,
+ purple=35,
+ cyan=36,
+ white=37,
+ Black=40,
+ Red=41,
+ Green=42,
+ Yellow=43,
+ Blue=44,
+ Purple=45,
+ Cyan=46,
+ White=47,
+ bold=1,
+ light=2,
+ blink=5,
+ invert=7,
+ )
+
+ def __init__(self, file: Optional[TextIO] = None) -> None:
+ if file is None:
+ file = sys.stdout
+ if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32":
+ try:
+ import colorama
+ except ImportError:
+ pass
+ else:
+ file = colorama.AnsiToWin32(file).stream
+ assert file is not None
+ self._file = file
+ self.hasmarkup = should_do_markup(file)
+ self._current_line = ""
+ self._terminal_width: Optional[int] = None
+ self.code_highlight = True
+
+ @property
+ def fullwidth(self) -> int:
+ if self._terminal_width is not None:
+ return self._terminal_width
+ return get_terminal_width()
+
+ @fullwidth.setter
+ def fullwidth(self, value: int) -> None:
+ self._terminal_width = value
+
+ @property
+ def width_of_current_line(self) -> int:
+ """Return an estimate of the width so far in the current line."""
+ return wcswidth(self._current_line)
+
+ def markup(self, text: str, **markup: bool) -> str:
+ for name in markup:
+ if name not in self._esctable:
+ raise ValueError(f"unknown markup: {name!r}")
+ if self.hasmarkup:
+ esc = [self._esctable[name] for name, on in markup.items() if on]
+ if esc:
+ text = "".join("\x1b[%sm" % cod for cod in esc) + text + "\x1b[0m"
+ return text
+
+ def sep(
+ self,
+ sepchar: str,
+ title: Optional[str] = None,
+ fullwidth: Optional[int] = None,
+ **markup: bool,
+ ) -> None:
+ if fullwidth is None:
+ fullwidth = self.fullwidth
+ # The goal is to have the line be as long as possible
+ # under the condition that len(line) <= fullwidth.
+ if sys.platform == "win32":
+ # If we print in the last column on windows we are on a
+ # new line but there is no way to verify/neutralize this
+ # (we may not know the exact line width).
+ # So let's be defensive to avoid empty lines in the output.
+ fullwidth -= 1
+ if title is not None:
+ # we want 2 + 2*len(fill) + len(title) <= fullwidth
+ # i.e. 2 + 2*len(sepchar)*N + len(title) <= fullwidth
+ # 2*len(sepchar)*N <= fullwidth - len(title) - 2
+ # N <= (fullwidth - len(title) - 2) // (2*len(sepchar))
+ N = max((fullwidth - len(title) - 2) // (2 * len(sepchar)), 1)
+ fill = sepchar * N
+ line = f"{fill} {title} {fill}"
+ else:
+ # we want len(sepchar)*N <= fullwidth
+ # i.e. N <= fullwidth // len(sepchar)
+ line = sepchar * (fullwidth // len(sepchar))
+ # In some situations there is room for an extra sepchar at the right,
+ # in particular if we consider that with a sepchar like "_ " the
+ # trailing space is not important at the end of the line.
+ if len(line) + len(sepchar.rstrip()) <= fullwidth:
+ line += sepchar.rstrip()
+
+ self.line(line, **markup)
+
+ def write(self, msg: str, *, flush: bool = False, **markup: bool) -> None:
+ if msg:
+ current_line = msg.rsplit("\n", 1)[-1]
+ if "\n" in msg:
+ self._current_line = current_line
+ else:
+ self._current_line += current_line
+
+ msg = self.markup(msg, **markup)
+
+ try:
+ self._file.write(msg)
+ except UnicodeEncodeError:
+ # Some environments don't support printing general Unicode
+ # strings, due to misconfiguration or otherwise; in that case,
+ # print the string escaped to ASCII.
+ # When the Unicode situation improves we should consider
+ # letting the error propagate instead of masking it (see #7475
+ # for one brief attempt).
+ msg = msg.encode("unicode-escape").decode("ascii")
+ self._file.write(msg)
+
+ if flush:
+ self.flush()
+
+ def line(self, s: str = "", **markup: bool) -> None:
+ self.write(s, **markup)
+ self.write("\n")
+
+ def flush(self) -> None:
+ self._file.flush()
+
+ def _write_source(self, lines: Sequence[str], indents: Sequence[str] = ()) -> None:
+ """Write lines of source code possibly highlighted.
+
+ Keeping this private for now because the API is clunky. We should discuss how
+ to evolve the terminal writer so we can have more precise color support, for example
+ being able to write part of a line in one color and the rest in another, and so on.
+ """
+ if indents and len(indents) != len(lines):
+ raise ValueError(
+ "indents size ({}) should have same size as lines ({})".format(
+ len(indents), len(lines)
+ )
+ )
+ if not indents:
+ indents = [""] * len(lines)
+ source = "\n".join(lines)
+ new_lines = self._highlight(source).splitlines()
+ for indent, new_line in zip(indents, new_lines):
+ self.line(indent + new_line)
+
+ def _highlight(self, source: str) -> str:
+ """Highlight the given source code if we have markup support."""
+ if not self.hasmarkup or not self.code_highlight:
+ return source
+ try:
+ from pygments.formatters.terminal import TerminalFormatter
+ from pygments.lexers.python import PythonLexer
+ from pygments import highlight
+ except ImportError:
+ return source
+ else:
+ highlighted: str = highlight(
+ source, PythonLexer(), TerminalFormatter(bg="dark")
+ )
+ return highlighted
diff --git a/contrib/python/pytest/py3/_pytest/_io/wcwidth.py b/contrib/python/pytest/py3/_pytest/_io/wcwidth.py
new file mode 100644
index 0000000000..e5c7bf4d86
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/_io/wcwidth.py
@@ -0,0 +1,55 @@
+import unicodedata
+from functools import lru_cache
+
+
+@lru_cache(100)
+def wcwidth(c: str) -> int:
+ """Determine how many columns are needed to display a character in a terminal.
+
+ Returns -1 if the character is not printable.
+ Returns 0, 1 or 2 for other characters.
+ """
+ o = ord(c)
+
+ # ASCII fast path.
+ if 0x20 <= o < 0x07F:
+ return 1
+
+ # Some Cf/Zp/Zl characters which should be zero-width.
+ if (
+ o == 0x0000
+ or 0x200B <= o <= 0x200F
+ or 0x2028 <= o <= 0x202E
+ or 0x2060 <= o <= 0x2063
+ ):
+ return 0
+
+ category = unicodedata.category(c)
+
+ # Control characters.
+ if category == "Cc":
+ return -1
+
+ # Combining characters with zero width.
+ if category in ("Me", "Mn"):
+ return 0
+
+ # Full/Wide east asian characters.
+ if unicodedata.east_asian_width(c) in ("F", "W"):
+ return 2
+
+ return 1
+
+
+def wcswidth(s: str) -> int:
+ """Determine how many columns are needed to display a string in a terminal.
+
+ Returns -1 if the string contains non-printable characters.
+ """
+ width = 0
+ for c in unicodedata.normalize("NFC", s):
+ wc = wcwidth(c)
+ if wc < 0:
+ return -1
+ width += wc
+ return width
diff --git a/contrib/python/pytest/py3/_pytest/_version.py b/contrib/python/pytest/py3/_pytest/_version.py
index cfc10f60ae..83518587e4 100644
--- a/contrib/python/pytest/py3/_pytest/_version.py
+++ b/contrib/python/pytest/py3/_pytest/_version.py
@@ -1,4 +1,5 @@
# coding: utf-8
# file generated by setuptools_scm
# don't change, don't track in version control
-version = '5.4.3'
+version = '6.2.5'
+version_tuple = (6, 2, 5)
diff --git a/contrib/python/pytest/py3/_pytest/assertion/__init__.py b/contrib/python/pytest/py3/_pytest/assertion/__init__.py
index ee7fa6a3af..a18cf198df 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/__init__.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/__init__.py
@@ -1,24 +1,25 @@
-"""
-support for presenting detailed information in failing assertions.
-"""
+"""Support for presenting detailed information in failing assertions."""
import sys
from typing import Any
+from typing import Generator
from typing import List
from typing import Optional
+from typing import TYPE_CHECKING
from _pytest.assertion import rewrite
from _pytest.assertion import truncate
from _pytest.assertion import util
from _pytest.assertion.rewrite import assertstate_key
-from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import hookimpl
+from _pytest.config.argparsing import Parser
+from _pytest.nodes import Item
if TYPE_CHECKING:
from _pytest.main import Session
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--assert",
@@ -27,11 +28,12 @@ def pytest_addoption(parser):
choices=("rewrite", "plain"),
default="rewrite",
metavar="MODE",
- help="""Control assertion debugging tools. 'plain'
- performs no assertion debugging. 'rewrite'
- (the default) rewrites assert statements in
- test modules on import to provide assert
- expression information.""",
+ help=(
+ "Control assertion debugging tools.\n"
+ "'plain' performs no assertion debugging.\n"
+ "'rewrite' (the default) rewrites assert statements in test modules"
+ " on import to provide assert expression information."
+ ),
)
parser.addini(
"enable_assertion_pass_hook",
@@ -42,7 +44,7 @@ def pytest_addoption(parser):
)
-def register_assert_rewrite(*names) -> None:
+def register_assert_rewrite(*names: str) -> None:
"""Register one or more module names to be rewritten on import.
This function will make sure that this module or all modules inside
@@ -51,11 +53,11 @@ def register_assert_rewrite(*names) -> None:
actually imported, usually in your __init__.py if you are a plugin
using a package.
- :raise TypeError: if the given module names are not strings.
+ :raises TypeError: If the given module names are not strings.
"""
for name in names:
if not isinstance(name, str):
- msg = "expected module names as *args, got {0} instead"
+ msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
raise TypeError(msg.format(repr(names)))
for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook):
@@ -71,27 +73,27 @@ def register_assert_rewrite(*names) -> None:
class DummyRewriteHook:
"""A no-op import hook for when rewriting is disabled."""
- def mark_rewrite(self, *names):
+ def mark_rewrite(self, *names: str) -> None:
pass
class AssertionState:
"""State for the assertion plugin."""
- def __init__(self, config, mode):
+ def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
- self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
+ self.hook: Optional[rewrite.AssertionRewritingHook] = None
-def install_importhook(config):
+def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails."""
config._store[assertstate_key] = AssertionState(config, "rewrite")
config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
config._store[assertstate_key].trace("installed rewrite import hook")
- def undo():
+ def undo() -> None:
hook = config._store[assertstate_key].hook
if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook)
@@ -101,9 +103,9 @@ def install_importhook(config):
def pytest_collection(session: "Session") -> None:
- # this hook is only called when test modules are collected
+ # This hook is only called when test modules are collected
# so for example not in the master process of pytest-xdist
- # (which does not collect test modules)
+ # (which does not collect test modules).
assertstate = session.config._store.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
@@ -111,18 +113,18 @@ def pytest_collection(session: "Session") -> None:
@hookimpl(tryfirst=True, hookwrapper=True)
-def pytest_runtest_protocol(item):
- """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
+def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
+ """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
- The newinterpret and rewrite modules will use util._reprcompare if
- it exists to use custom reporting via the
- pytest_assertrepr_compare hook. This sets up this custom
+ The rewrite module will use util._reprcompare if it exists to use custom
+ reporting via the pytest_assertrepr_compare hook. This sets up this custom
comparison for the test.
"""
- def callbinrepr(op, left, right):
- # type: (str, object, object) -> Optional[str]
- """Call the pytest_assertrepr_compare hook and prepare the result
+ ihook = item.ihook
+
+ def callbinrepr(op, left: object, right: object) -> Optional[str]:
+ """Call the pytest_assertrepr_compare hook and prepare the result.
This uses the first result from the hook and then ensures the
following:
@@ -136,7 +138,7 @@ def pytest_runtest_protocol(item):
The result can be formatted by util.format_explanation() for
pretty printing.
"""
- hook_result = item.ihook.pytest_assertrepr_compare(
+ hook_result = ihook.pytest_assertrepr_compare(
config=item.config, op=op, left=left, right=right
)
for new_expl in hook_result:
@@ -152,12 +154,10 @@ def pytest_runtest_protocol(item):
saved_assert_hooks = util._reprcompare, util._assertion_pass
util._reprcompare = callbinrepr
- if item.ihook.pytest_assertion_pass.get_hookimpls():
+ if ihook.pytest_assertion_pass.get_hookimpls():
- def call_assertion_pass_hook(lineno, orig, expl):
- item.ihook.pytest_assertion_pass(
- item=item, lineno=lineno, orig=orig, expl=expl
- )
+ def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
+ ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
util._assertion_pass = call_assertion_pass_hook
@@ -166,7 +166,7 @@ def pytest_runtest_protocol(item):
util._reprcompare, util._assertion_pass = saved_assert_hooks
-def pytest_sessionfinish(session):
+def pytest_sessionfinish(session: "Session") -> None:
assertstate = session.config._store.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
index f84127dcaf..37ff076aab 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
@@ -1,4 +1,4 @@
-"""Rewrite assertion AST to produce nice error messages"""
+"""Rewrite assertion AST to produce nice error messages."""
import ast
import errno
import functools
@@ -13,11 +13,21 @@ import struct
import sys
import tokenize
import types
+from pathlib import Path
+from pathlib import PurePath
+from typing import Callable
from typing import Dict
+from typing import IO
+from typing import Iterable
from typing import List
from typing import Optional
+from typing import Sequence
from typing import Set
from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
+import py
from _pytest._io.saferepr import saferepr
from _pytest._version import version
@@ -25,22 +35,20 @@ from _pytest.assertion import util
from _pytest.assertion.util import ( # noqa: F401
format_explanation as _format_explanation,
)
-from _pytest.compat import fspath
-from _pytest.compat import TYPE_CHECKING
+from _pytest.config import Config
+from _pytest.main import Session
from _pytest.pathlib import fnmatch_ex
-from _pytest.pathlib import Path
-from _pytest.pathlib import PurePath
from _pytest.store import StoreKey
if TYPE_CHECKING:
- from _pytest.assertion import AssertionState # noqa: F401
+ from _pytest.assertion import AssertionState
assertstate_key = StoreKey["AssertionState"]()
# pytest caches rewritten pycs in pycache dirs
-PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version)
+PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
PYC_EXT = ".py" + (__debug__ and "c" or "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
@@ -48,30 +56,35 @@ PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""PEP302/PEP451 import hook which rewrites asserts."""
- def __init__(self, config):
+ def __init__(self, config: Config) -> None:
self.config = config
try:
self.fnpats = config.getini("python_files")
except ValueError:
self.fnpats = ["test_*.py", "*_test.py"]
- self.session = None
- self._rewritten_names = set() # type: Set[str]
- self._must_rewrite = set() # type: Set[str]
+ self.session: Optional[Session] = None
+ self._rewritten_names: Set[str] = set()
+ self._must_rewrite: Set[str] = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False
self._basenames_to_check_rewrite = {"conftest"}
- self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
+ self._marked_for_rewrite_cache: Dict[str, bool] = {}
self._session_paths_checked = False
- def set_session(self, session):
+ def set_session(self, session: Optional[Session]) -> None:
self.session = session
self._session_paths_checked = False
# Indirection so we can mock calls to find_spec originated from the hook during testing
_find_spec = importlib.machinery.PathFinder.find_spec
- def find_spec(self, name, path=None, target=None):
+ def find_spec(
+ self,
+ name: str,
+ path: Optional[Sequence[Union[str, bytes]]] = None,
+ target: Optional[types.ModuleType] = None,
+ ) -> Optional[importlib.machinery.ModuleSpec]:
if self._writing_pyc:
return None
state = self.config._store[assertstate_key]
@@ -79,13 +92,14 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return None
state.trace("find_module called for: %s" % name)
- spec = self._find_spec(name, path)
+ # Type ignored because mypy is confused about the `self` binding here.
+ spec = self._find_spec(name, path) # type: ignore
if (
# the import machinery could not find a file to import
spec is None
# this is a namespace package (without `__init__.py`)
# there's nothing to rewrite there
- # python3.5 - python3.6: `namespace`
+ # python3.6: `namespace`
# python3.7+: `None`
or spec.origin == "namespace"
or spec.origin is None
@@ -108,10 +122,14 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
submodule_search_locations=spec.submodule_search_locations,
)
- def create_module(self, spec):
+ def create_module(
+ self, spec: importlib.machinery.ModuleSpec
+ ) -> Optional[types.ModuleType]:
return None # default behaviour is fine
- def exec_module(self, module):
+ def exec_module(self, module: types.ModuleType) -> None:
+ assert module.__spec__ is not None
+ assert module.__spec__.origin is not None
fn = Path(module.__spec__.origin)
state = self.config._store[assertstate_key]
@@ -131,7 +149,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
ok = try_makedirs(cache_dir)
if not ok:
write = False
- state.trace("read only directory: {}".format(cache_dir))
+ state.trace(f"read only directory: {cache_dir}")
cache_name = fn.name[:-3] + PYC_TAIL
pyc = cache_dir / cache_name
@@ -139,7 +157,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
# to check for a cached pyc. This may not be optimal...
co = _read_pyc(fn, pyc, state.trace)
if co is None:
- state.trace("rewriting {!r}".format(fn))
+ state.trace(f"rewriting {fn!r}")
source_stat, co = _rewrite_test(fn, self.config)
if write:
self._writing_pyc = True
@@ -148,11 +166,11 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
finally:
self._writing_pyc = False
else:
- state.trace("found cached rewritten pyc for {}".format(fn))
+ state.trace(f"found cached rewritten pyc for {fn}")
exec(co, module.__dict__)
- def _early_rewrite_bailout(self, name, state):
- """This is a fast way to get out of rewriting modules.
+ def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
+ """A fast way to get out of rewriting modules.
Profiling has shown that the call to PathFinder.find_spec (inside of
the find_spec from this class) is a major slowdown, so, this method
@@ -161,10 +179,10 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
"""
if self.session is not None and not self._session_paths_checked:
self._session_paths_checked = True
- for path in self.session._initialpaths:
+ for initial_path in self.session._initialpaths:
# Make something as c:/projects/my_project/path.py ->
# ['c:', 'projects', 'my_project', 'path.py']
- parts = str(path).split(os.path.sep)
+ parts = str(initial_path).split(os.path.sep)
# add 'path' to basenames to be checked.
self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
@@ -187,20 +205,18 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
if self._is_marked_for_rewrite(name, state):
return False
- state.trace("early skip of rewriting module: {}".format(name))
+ state.trace(f"early skip of rewriting module: {name}")
return True
- def _should_rewrite(self, name, fn, state):
+ def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
# always rewrite conftest files
if os.path.basename(fn) == "conftest.py":
- state.trace("rewriting conftest file: {!r}".format(fn))
+ state.trace(f"rewriting conftest file: {fn!r}")
return True
if self.session is not None:
- if self.session.isinitpath(fn):
- state.trace(
- "matched test file (was specified on cmdline): {!r}".format(fn)
- )
+ if self.session.isinitpath(py.path.local(fn)):
+ state.trace(f"matched test file (was specified on cmdline): {fn!r}")
return True
# modules not passed explicitly on the command line are only
@@ -208,20 +224,18 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
fn_path = PurePath(fn)
for pat in self.fnpats:
if fnmatch_ex(pat, fn_path):
- state.trace("matched test file {!r}".format(fn))
+ state.trace(f"matched test file {fn!r}")
return True
return self._is_marked_for_rewrite(name, state)
- def _is_marked_for_rewrite(self, name: str, state):
+ def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
try:
return self._marked_for_rewrite_cache[name]
except KeyError:
for marked in self._must_rewrite:
if name == marked or name.startswith(marked + "."):
- state.trace(
- "matched marked file {!r} (from {!r})".format(name, marked)
- )
+ state.trace(f"matched marked file {name!r} (from {marked!r})")
self._marked_for_rewrite_cache[name] = True
return True
@@ -246,33 +260,37 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear()
- def _warn_already_imported(self, name):
+ def _warn_already_imported(self, name: str) -> None:
from _pytest.warning_types import PytestAssertRewriteWarning
- from _pytest.warnings import _issue_warning_captured
- _issue_warning_captured(
+ self.config.issue_config_time_warning(
PytestAssertRewriteWarning(
"Module already imported so cannot be rewritten: %s" % name
),
- self.config.hook,
stacklevel=5,
)
- def get_data(self, pathname):
+ def get_data(self, pathname: Union[str, bytes]) -> bytes:
"""Optional PEP302 get_data API."""
with open(pathname, "rb") as f:
return f.read()
-def _write_pyc_fp(fp, source_stat, co):
+def _write_pyc_fp(
+ fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
+) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
- # import. However, there's little reason deviate.
+ # import. However, there's little reason to deviate.
fp.write(importlib.util.MAGIC_NUMBER)
+ # https://www.python.org/dev/peps/pep-0552/
+ if sys.version_info >= (3, 7):
+ flags = b"\x00\x00\x00\x00"
+ fp.write(flags)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
- # "<LL" stands for 2 unsigned longs, little-ending
+ # "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
fp.write(marshal.dumps(co))
@@ -280,12 +298,17 @@ def _write_pyc_fp(fp, source_stat, co):
if sys.platform == "win32":
from atomicwrites import atomic_write
- def _write_pyc(state, co, source_stat, pyc):
+ def _write_pyc(
+ state: "AssertionState",
+ co: types.CodeType,
+ source_stat: os.stat_result,
+ pyc: Path,
+ ) -> bool:
try:
- with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
+ with atomic_write(os.fspath(pyc), mode="wb", overwrite=True) as fp:
_write_pyc_fp(fp, source_stat, co)
- except EnvironmentError as e:
- state.trace("error writing pyc file at {}: errno={}".format(pyc, e.errno))
+ except OSError as e:
+ state.trace(f"error writing pyc file at {pyc}: {e}")
# we ignore any failure to write the cache file
# there are many reasons, permission-denied, pycache dir being a
# file etc.
@@ -295,21 +318,24 @@ if sys.platform == "win32":
else:
- def _write_pyc(state, co, source_stat, pyc):
- proc_pyc = "{}.{}".format(pyc, os.getpid())
+ def _write_pyc(
+ state: "AssertionState",
+ co: types.CodeType,
+ source_stat: os.stat_result,
+ pyc: Path,
+ ) -> bool:
+ proc_pyc = f"{pyc}.{os.getpid()}"
try:
fp = open(proc_pyc, "wb")
- except EnvironmentError as e:
- state.trace(
- "error writing pyc file at {}: errno={}".format(proc_pyc, e.errno)
- )
+ except OSError as e:
+ state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False
try:
_write_pyc_fp(fp, source_stat, co)
- os.rename(proc_pyc, fspath(pyc))
- except BaseException as e:
- state.trace("error writing pyc file at {}: errno={}".format(pyc, e.errno))
+ os.rename(proc_pyc, os.fspath(pyc))
+ except OSError as e:
+ state.trace(f"error writing pyc file at {pyc}: {e}")
# we ignore any failure to write the cache file
# there are many reasons, permission-denied, pycache dir being a
# file etc.
@@ -319,48 +345,62 @@ else:
return True
-def _rewrite_test(fn, config):
- """read and rewrite *fn* and return the code object."""
- fn = fspath(fn)
- stat = os.stat(fn)
- with open(fn, "rb") as f:
+def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
+ """Read and rewrite *fn* and return the code object."""
+ fn_ = os.fspath(fn)
+ stat = os.stat(fn_)
+ with open(fn_, "rb") as f:
source = f.read()
- tree = ast.parse(source, filename=fn)
- rewrite_asserts(tree, source, fn, config)
- co = compile(tree, fn, "exec", dont_inherit=True)
+ tree = ast.parse(source, filename=fn_)
+ rewrite_asserts(tree, source, fn_, config)
+ co = compile(tree, fn_, "exec", dont_inherit=True)
return stat, co
-def _read_pyc(source, pyc, trace=lambda x: None):
+def _read_pyc(
+ source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
+) -> Optional[types.CodeType]:
"""Possibly read a pytest pyc containing rewritten code.
Return rewritten code if successful or None if not.
"""
try:
- fp = open(fspath(pyc), "rb")
- except IOError:
+ fp = open(os.fspath(pyc), "rb")
+ except OSError:
return None
with fp:
+ # https://www.python.org/dev/peps/pep-0552/
+ has_flags = sys.version_info >= (3, 7)
try:
- stat_result = os.stat(fspath(source))
+ stat_result = os.stat(os.fspath(source))
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
- data = fp.read(12)
- except EnvironmentError as e:
- trace("_read_pyc({}): EnvironmentError {}".format(source, e))
+ data = fp.read(16 if has_flags else 12)
+ except OSError as e:
+ trace(f"_read_pyc({source}): OSError {e}")
return None
# Check for invalid or out of date pyc file.
- if (
- len(data) != 12
- or data[:4] != importlib.util.MAGIC_NUMBER
- or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF)
- ):
- trace("_read_pyc(%s): invalid or out of date pyc" % source)
+ if len(data) != (16 if has_flags else 12):
+ trace("_read_pyc(%s): invalid pyc (too short)" % source)
+ return None
+ if data[:4] != importlib.util.MAGIC_NUMBER:
+ trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
+ return None
+ if has_flags and data[4:8] != b"\x00\x00\x00\x00":
+ trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
+ return None
+ mtime_data = data[8 if has_flags else 4 : 12 if has_flags else 8]
+ if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
+ trace("_read_pyc(%s): out of date" % source)
+ return None
+ size_data = data[12 if has_flags else 8 : 16 if has_flags else 12]
+ if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
+ trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
try:
co = marshal.load(fp)
except Exception as e:
- trace("_read_pyc({}): marshal.load error {}".format(source, e))
+ trace(f"_read_pyc({source}): marshal.load error {e}")
return None
if not isinstance(co, types.CodeType):
trace("_read_pyc(%s): not a code object" % source)
@@ -368,13 +408,18 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co
-def rewrite_asserts(mod, source, module_path=None, config=None):
+def rewrite_asserts(
+ mod: ast.Module,
+ source: bytes,
+ module_path: Optional[str] = None,
+ config: Optional[Config] = None,
+) -> None:
"""Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config, source).run(mod)
-def _saferepr(obj):
- """Get a safe repr of an object for assertion error messages.
+def _saferepr(obj: object) -> str:
+ r"""Get a safe repr of an object for assertion error messages.
The assertion formatting (util.format_explanation()) requires
newlines to be escaped since they are a special character for it.
@@ -382,18 +427,16 @@ def _saferepr(obj):
custom repr it is possible to contain one of the special escape
sequences, especially '\n{' and '\n}' are likely to be present in
JSON reprs.
-
"""
return saferepr(obj).replace("\n", "\\n")
-def _format_assertmsg(obj):
- """Format the custom assertion message given.
+def _format_assertmsg(obj: object) -> str:
+ r"""Format the custom assertion message given.
For strings this simply replaces newlines with '\n~' so that
util.format_explanation() will preserve them instead of escaping
newlines. For other objects saferepr() is used first.
-
"""
# reprlib appears to have a bug which means that if a string
# contains a newline it gets escaped, however if an object has a
@@ -410,7 +453,7 @@ def _format_assertmsg(obj):
return obj
-def _should_repr_global_name(obj):
+def _should_repr_global_name(obj: object) -> bool:
if callable(obj):
return False
@@ -420,16 +463,17 @@ def _should_repr_global_name(obj):
return True
-def _format_boolop(explanations, is_or):
+def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
- if isinstance(explanation, str):
- return explanation.replace("%", "%%")
- else:
- return explanation.replace(b"%", b"%%")
+ return explanation.replace("%", "%%")
-def _call_reprcompare(ops, results, expls, each_obj):
- # type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
+def _call_reprcompare(
+ ops: Sequence[str],
+ results: Sequence[bool],
+ expls: Sequence[str],
+ each_obj: Sequence[object],
+) -> str:
for i, res, expl in zip(range(len(ops)), results, expls):
try:
done = not res
@@ -444,16 +488,14 @@ def _call_reprcompare(ops, results, expls, each_obj):
return expl
-def _call_assertion_pass(lineno, orig, expl):
- # type: (int, str, str) -> None
+def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
if util._assertion_pass is not None:
util._assertion_pass(lineno, orig, expl)
-def _check_if_assertion_pass_impl():
- # type: () -> bool
- """Checks if any plugins implement the pytest_assertion_pass hook
- in order not to generate explanation unecessarily (might be expensive)"""
+def _check_if_assertion_pass_impl() -> bool:
+ """Check if any plugins implement the pytest_assertion_pass hook
+ in order not to generate explanation unecessarily (might be expensive)."""
return True if util._assertion_pass else False
@@ -502,13 +544,13 @@ def set_location(node, lineno, col_offset):
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
- """Returns a mapping from {lineno: "assertion test expression"}"""
- ret = {} # type: Dict[int, str]
+ """Return a mapping from {lineno: "assertion test expression"}."""
+ ret: Dict[int, str] = {}
depth = 0
- lines = [] # type: List[str]
- assert_lineno = None # type: Optional[int]
- seen_lines = set() # type: Set[int]
+ lines: List[str] = []
+ assert_lineno: Optional[int] = None
+ seen_lines: Set[int] = set()
def _write_and_reset() -> None:
nonlocal depth, lines, assert_lineno, seen_lines
@@ -606,10 +648,11 @@ class AssertionRewriter(ast.NodeVisitor):
This state is reset on every new assert statement visited and used
by the other visitors.
-
"""
- def __init__(self, module_path, config, source):
+ def __init__(
+ self, module_path: Optional[str], config: Optional[Config], source: bytes
+ ) -> None:
super().__init__()
self.module_path = module_path
self.config = config
@@ -622,7 +665,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.source = source
@functools.lru_cache(maxsize=1)
- def _assert_expr_to_lineno(self):
+ def _assert_expr_to_lineno(self) -> Dict[int, str]:
return _get_assertion_exprs(self.source)
def run(self, mod: ast.Module) -> None:
@@ -630,12 +673,9 @@ class AssertionRewriter(ast.NodeVisitor):
if not mod.body:
# Nothing to do.
return
- # Insert some special imports at the top of the module but after any
- # docstrings and __future__ imports.
- aliases = [
- ast.alias("builtins", "@py_builtins"),
- ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
- ]
+
+ # We'll insert some special imports at the top of the module, but after any
+ # docstrings and __future__ imports, so first figure out where that is.
doc = getattr(mod, "docstring", None)
expect_docstring = doc is None
if doc is not None and self.is_rewrite_disabled(doc):
@@ -653,26 +693,48 @@ class AssertionRewriter(ast.NodeVisitor):
return
expect_docstring = False
elif (
- not isinstance(item, ast.ImportFrom)
- or item.level > 0
- or item.module != "__future__"
+ isinstance(item, ast.ImportFrom)
+ and item.level == 0
+ and item.module == "__future__"
):
- lineno = item.lineno
+ pass
+ else:
break
pos += 1
+ # Special case: for a decorated function, set the lineno to that of the
+ # first decorator, not the `def`. Issue #4984.
+ if isinstance(item, ast.FunctionDef) and item.decorator_list:
+ lineno = item.decorator_list[0].lineno
else:
lineno = item.lineno
+ # Now actually insert the special imports.
+ if sys.version_info >= (3, 10):
+ aliases = [
+ ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
+ ast.alias(
+ "_pytest.assertion.rewrite",
+ "@pytest_ar",
+ lineno=lineno,
+ col_offset=0,
+ ),
+ ]
+ else:
+ aliases = [
+ ast.alias("builtins", "@py_builtins"),
+ ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
+ ]
imports = [
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
]
mod.body[pos:pos] = imports
+
# Collect asserts.
- nodes = [mod] # type: List[ast.AST]
+ nodes: List[ast.AST] = [mod]
while nodes:
node = nodes.pop()
for name, field in ast.iter_fields(node):
if isinstance(field, list):
- new = [] # type: List
+ new: List[ast.AST] = []
for i, child in enumerate(field):
if isinstance(child, ast.Assert):
# Transform assert.
@@ -691,51 +753,50 @@ class AssertionRewriter(ast.NodeVisitor):
nodes.append(field)
@staticmethod
- def is_rewrite_disabled(docstring):
+ def is_rewrite_disabled(docstring: str) -> bool:
return "PYTEST_DONT_REWRITE" in docstring
- def variable(self):
+ def variable(self) -> str:
"""Get a new variable."""
# Use a character invalid in python identifiers to avoid clashing.
name = "@py_assert" + str(next(self.variable_counter))
self.variables.append(name)
return name
- def assign(self, expr):
+ def assign(self, expr: ast.expr) -> ast.Name:
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load())
- def display(self, expr):
+ def display(self, expr: ast.expr) -> ast.expr:
"""Call saferepr on the expression."""
return self.helper("_saferepr", expr)
- def helper(self, name, *args):
+ def helper(self, name: str, *args: ast.expr) -> ast.expr:
"""Call a helper in this module."""
py_name = ast.Name("@pytest_ar", ast.Load())
attr = ast.Attribute(py_name, name, ast.Load())
return ast.Call(attr, list(args), [])
- def builtin(self, name):
+ def builtin(self, name: str) -> ast.Attribute:
"""Return the builtin called *name*."""
builtin_name = ast.Name("@py_builtins", ast.Load())
return ast.Attribute(builtin_name, name, ast.Load())
- def explanation_param(self, expr):
+ def explanation_param(self, expr: ast.expr) -> str:
"""Return a new named %-formatting placeholder for expr.
This creates a %-formatting placeholder for expr in the
current formatting context, e.g. ``%(py0)s``. The placeholder
and expr are placed in the current format context so that it
can be used on the next call to .pop_format_context().
-
"""
specifier = "py" + str(next(self.variable_counter))
self.explanation_specifiers[specifier] = expr
return "%(" + specifier + ")s"
- def push_format_context(self):
+ def push_format_context(self) -> None:
"""Create a new formatting context.
The format context is used for when an explanation wants to
@@ -744,19 +805,17 @@ class AssertionRewriter(ast.NodeVisitor):
.explanation_param(). Finally .pop_format_context() is used
to format a string of %-formatted values as added by
.explanation_param().
-
"""
- self.explanation_specifiers = {} # type: Dict[str, ast.expr]
+ self.explanation_specifiers: Dict[str, ast.expr] = {}
self.stack.append(self.explanation_specifiers)
- def pop_format_context(self, expl_expr):
+ def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
"""Format the %-formatted string with current format context.
- The expl_expr should be an ast.Str instance constructed from
+ The expl_expr should be an str ast.expr instance constructed from
the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string.
-
"""
current = self.stack.pop()
if self.stack:
@@ -770,43 +829,44 @@ class AssertionRewriter(ast.NodeVisitor):
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())
- def generic_visit(self, node):
+ def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
"""Handle expressions we don't have custom code for."""
assert isinstance(node, ast.expr)
res = self.assign(node)
return res, self.explanation_param(self.display(res))
- def visit_Assert(self, assert_):
+ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
"""Return the AST statements to replace the ast.Assert instance.
This rewrites the test of an assertion to provide
intermediate values and replace it with an if statement which
raises an assertion error with a detailed explanation in case
the expression is false.
-
"""
if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
from _pytest.warning_types import PytestAssertRewriteWarning
import warnings
+ # TODO: This assert should not be needed.
+ assert self.module_path is not None
warnings.warn_explicit(
PytestAssertRewriteWarning(
"assertion is always true, perhaps remove parentheses?"
),
category=None,
- filename=fspath(self.module_path),
+ filename=os.fspath(self.module_path),
lineno=assert_.lineno,
)
- self.statements = [] # type: List[ast.stmt]
- self.variables = [] # type: List[str]
+ self.statements: List[ast.stmt] = []
+ self.variables: List[str] = []
self.variable_counter = itertools.count()
if self.enable_assertion_pass_hook:
- self.format_variables = [] # type: List[str]
+ self.format_variables: List[str] = []
- self.stack = [] # type: List[Dict[str, ast.expr]]
- self.expl_stmts = [] # type: List[ast.stmt]
+ self.stack: List[Dict[str, ast.expr]] = []
+ self.expl_stmts: List[ast.stmt] = []
self.push_format_context()
# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
@@ -891,7 +951,7 @@ class AssertionRewriter(ast.NodeVisitor):
set_location(stmt, assert_.lineno, assert_.col_offset)
return self.statements
- def visit_Name(self, name):
+ def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
@@ -901,7 +961,7 @@ class AssertionRewriter(ast.NodeVisitor):
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
return name, self.explanation_param(expr)
- def visit_BoolOp(self, boolop):
+ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
res_var = self.variable()
expl_list = self.assign(ast.List([], ast.Load()))
app = ast.Attribute(expl_list, "append", ast.Load())
@@ -913,7 +973,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Process each operand, short-circuiting if needed.
for i, v in enumerate(boolop.values):
if i:
- fail_inner = [] # type: List[ast.stmt]
+ fail_inner: List[ast.stmt] = []
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
@@ -924,10 +984,10 @@ class AssertionRewriter(ast.NodeVisitor):
call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call))
if i < levels:
- cond = res # type: ast.expr
+ cond: ast.expr = res
if is_or:
cond = ast.UnaryOp(ast.Not(), cond)
- inner = [] # type: List[ast.stmt]
+ inner: List[ast.stmt] = []
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
self.statements = save
@@ -936,24 +996,21 @@ class AssertionRewriter(ast.NodeVisitor):
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
- def visit_UnaryOp(self, unary):
+ def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res))
return res, pattern % (operand_expl,)
- def visit_BinOp(self, binop):
+ def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right)
- explanation = "({} {} {})".format(left_expl, symbol, right_expl)
+ explanation = f"({left_expl} {symbol} {right_expl})"
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
return res, explanation
- def visit_Call(self, call):
- """
- visit `ast.Call` nodes
- """
+ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
@@ -974,16 +1031,16 @@ class AssertionRewriter(ast.NodeVisitor):
new_call = ast.Call(new_func, new_args, new_kwargs)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
- outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
+ outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
return res, outer_expl
- def visit_Starred(self, starred):
- # From Python 3.5, a Starred node can appear in a function call
+ def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
+ # A Starred node can appear in a function call.
res, expl = self.visit(starred.value)
new_starred = ast.Starred(res, starred.ctx)
return new_starred, "*" + expl
- def visit_Attribute(self, attr):
+ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
value, value_expl = self.visit(attr.value)
@@ -993,11 +1050,11 @@ class AssertionRewriter(ast.NodeVisitor):
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl
- def visit_Compare(self, comp: ast.Compare):
+ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
- left_expl = "({})".format(left_expl)
+ left_expl = f"({left_expl})"
res_variables = [self.variable() for i in range(len(comp.ops))]
load_names = [ast.Name(v, ast.Load()) for v in res_variables]
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
@@ -1008,11 +1065,11 @@ class AssertionRewriter(ast.NodeVisitor):
for i, op, next_operand in it:
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
- next_expl = "({})".format(next_expl)
+ next_expl = f"({next_expl})"
results.append(next_res)
sym = BINOP_MAP[op.__class__]
syms.append(ast.Str(sym))
- expl = "{} {} {}".format(left_expl, sym, next_expl)
+ expl = f"{left_expl} {sym} {next_expl}"
expls.append(ast.Str(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
self.statements.append(ast.Assign([store_names[i]], res_expr))
@@ -1026,17 +1083,19 @@ class AssertionRewriter(ast.NodeVisitor):
ast.Tuple(results, ast.Load()),
)
if len(comp.ops) > 1:
- res = ast.BoolOp(ast.And(), load_names) # type: ast.expr
+ res: ast.expr = ast.BoolOp(ast.And(), load_names)
else:
res = load_names[0]
return res, self.explanation_param(self.pop_format_context(expl_call))
-def try_makedirs(cache_dir) -> bool:
- """Attempts to create the given directory and sub-directories exist, returns True if
- successful or it already exists"""
+def try_makedirs(cache_dir: Path) -> bool:
+ """Attempt to create the given directory and sub-directories exist.
+
+ Returns True if successful or if it already exists.
+ """
try:
- os.makedirs(fspath(cache_dir), exist_ok=True)
+ os.makedirs(os.fspath(cache_dir), exist_ok=True)
except (FileNotFoundError, NotADirectoryError, FileExistsError):
# One of the path components was not a directory:
# - we're in a zip file
@@ -1053,7 +1112,7 @@ def try_makedirs(cache_dir) -> bool:
def get_cache_dir(file_path: Path) -> Path:
- """Returns the cache directory to write .pyc files for the given .py file path"""
+ """Return the cache directory to write .pyc files for the given .py file path."""
if sys.version_info >= (3, 8) and sys.pycache_prefix:
# given:
# prefix = '/tmp/pycs'
diff --git a/contrib/python/pytest/py3/_pytest/assertion/truncate.py b/contrib/python/pytest/py3/_pytest/assertion/truncate.py
index d97b05b441..5ba9ddca75 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/truncate.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/truncate.py
@@ -1,42 +1,47 @@
-"""
-Utilities for truncating assertion output.
+"""Utilities for truncating assertion output.
Current default behaviour is to truncate assertion explanations at
~8 terminal lines, unless running in "-vv" mode or running on CI.
"""
import os
+from typing import List
+from typing import Optional
+
+from _pytest.nodes import Item
+
DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = 8 * 80
USAGE_MSG = "use '-vv' to show"
-def truncate_if_required(explanation, item, max_length=None):
- """
- Truncate this assertion explanation if the given test item is eligible.
- """
+def truncate_if_required(
+ explanation: List[str], item: Item, max_length: Optional[int] = None
+) -> List[str]:
+ """Truncate this assertion explanation if the given test item is eligible."""
if _should_truncate_item(item):
return _truncate_explanation(explanation)
return explanation
-def _should_truncate_item(item):
- """
- Whether or not this test item is eligible for truncation.
- """
+def _should_truncate_item(item: Item) -> bool:
+ """Whether or not this test item is eligible for truncation."""
verbose = item.config.option.verbose
return verbose < 2 and not _running_on_ci()
-def _running_on_ci():
+def _running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars)
-def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
- """
- Truncate given list of strings that makes up the assertion explanation.
+def _truncate_explanation(
+ input_lines: List[str],
+ max_lines: Optional[int] = None,
+ max_chars: Optional[int] = None,
+) -> List[str]:
+ """Truncate given list of strings that makes up the assertion explanation.
Truncates to either 8 lines, or 640 characters - whichever the input reaches
first. The remaining lines will be replaced by a usage message.
@@ -65,15 +70,15 @@ def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
truncated_line_count += 1 # Account for the part-truncated final line
msg = "...Full output truncated"
if truncated_line_count == 1:
- msg += " ({} line hidden)".format(truncated_line_count)
+ msg += f" ({truncated_line_count} line hidden)"
else:
- msg += " ({} lines hidden)".format(truncated_line_count)
- msg += ", {}".format(USAGE_MSG)
+ msg += f" ({truncated_line_count} lines hidden)"
+ msg += f", {USAGE_MSG}"
truncated_explanation.extend(["", str(msg)])
return truncated_explanation
-def _truncate_by_char_count(input_lines, max_chars):
+def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
# Check if truncation required
if len("".join(input_lines)) <= max_chars:
return input_lines
diff --git a/contrib/python/pytest/py3/_pytest/assertion/util.py b/contrib/python/pytest/py3/_pytest/assertion/util.py
index 7d525aa4c4..da1ffd15e3 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/util.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/util.py
@@ -1,4 +1,4 @@
-"""Utilities for assertion debugging"""
+"""Utilities for assertion debugging."""
import collections.abc
import pprint
from typing import AbstractSet
@@ -9,28 +9,26 @@ from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
-from typing import Tuple
import _pytest._code
from _pytest import outcomes
from _pytest._io.saferepr import _pformat_dispatch
from _pytest._io.saferepr import safeformat
from _pytest._io.saferepr import saferepr
-from _pytest.compat import ATTRS_EQ_FIELD
# The _reprcompare attribute on the util module is used by the new assertion
# interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the
# DebugInterpreter.
-_reprcompare = None # type: Optional[Callable[[str, object, object], Optional[str]]]
+_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
-_assertion_pass = None # type: Optional[Callable[[int, str, str], None]]
+_assertion_pass: Optional[Callable[[int, str, str], None]] = None
def format_explanation(explanation: str) -> str:
- """This formats an explanation
+ r"""Format an explanation.
Normally all embedded newlines are escaped, however there are
three exceptions: \n{, \n} and \n~. The first two are intended
@@ -45,7 +43,7 @@ def format_explanation(explanation: str) -> str:
def _split_explanation(explanation: str) -> List[str]:
- """Return a list of individual lines in the explanation
+ r"""Return a list of individual lines in the explanation.
This will return a list of lines split on '\n{', '\n}' and '\n~'.
Any other newlines will be escaped and appear in the line as the
@@ -62,11 +60,11 @@ def _split_explanation(explanation: str) -> List[str]:
def _format_lines(lines: Sequence[str]) -> List[str]:
- """Format the individual lines
+ """Format the individual lines.
- This will replace the '{', '}' and '~' characters of our mini
- formatting language with the proper 'where ...', 'and ...' and ' +
- ...' text, taking care of indentation along the way.
+ This will replace the '{', '}' and '~' characters of our mini formatting
+ language with the proper 'where ...', 'and ...' and ' + ...' text, taking
+ care of indentation along the way.
Return a list of formatted lines.
"""
@@ -112,6 +110,10 @@ def isset(x: Any) -> bool:
return isinstance(x, (set, frozenset))
+def isnamedtuple(obj: Any) -> bool:
+ return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
+
+
def isdatacls(obj: Any) -> bool:
return getattr(obj, "__dataclass_fields__", None) is not None
@@ -129,7 +131,7 @@ def isiterable(obj: Any) -> bool:
def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]:
- """Return specialised explanations for some operators/operands"""
+ """Return specialised explanations for some operators/operands."""
verbose = config.getoption("verbose")
if verbose > 1:
left_repr = safeformat(left)
@@ -143,31 +145,12 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[
left_repr = saferepr(left, maxsize=maxsize)
right_repr = saferepr(right, maxsize=maxsize)
- summary = "{} {} {}".format(left_repr, op, right_repr)
+ summary = f"{left_repr} {op} {right_repr}"
explanation = None
try:
if op == "==":
- if istext(left) and istext(right):
- explanation = _diff_text(left, right, verbose)
- else:
- if issequence(left) and issequence(right):
- explanation = _compare_eq_sequence(left, right, verbose)
- elif isset(left) and isset(right):
- explanation = _compare_eq_set(left, right, verbose)
- elif isdict(left) and isdict(right):
- explanation = _compare_eq_dict(left, right, verbose)
- elif type(left) == type(right) and (isdatacls(left) or isattrs(left)):
- type_fn = (isdatacls, isattrs)
- explanation = _compare_eq_cls(left, right, verbose, type_fn)
- elif verbose > 0:
- explanation = _compare_eq_verbose(left, right)
- if isiterable(left) and isiterable(right):
- expl = _compare_eq_iterable(left, right, verbose)
- if explanation is not None:
- explanation.extend(expl)
- else:
- explanation = expl
+ explanation = _compare_eq_any(left, right, verbose)
elif op == "not in":
if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose)
@@ -187,6 +170,33 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[
return [summary] + explanation
+def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
+ explanation = []
+ if istext(left) and istext(right):
+ explanation = _diff_text(left, right, verbose)
+ else:
+ if type(left) == type(right) and (
+ isdatacls(left) or isattrs(left) or isnamedtuple(left)
+ ):
+ # Note: unlike dataclasses/attrs, namedtuples compare only the
+ # field values, not the type or field names. But this branch
+ # intentionally only handles the same-type case, which was often
+ # used in older code bases before dataclasses/attrs were available.
+ explanation = _compare_eq_cls(left, right, verbose)
+ elif issequence(left) and issequence(right):
+ explanation = _compare_eq_sequence(left, right, verbose)
+ elif isset(left) and isset(right):
+ explanation = _compare_eq_set(left, right, verbose)
+ elif isdict(left) and isdict(right):
+ explanation = _compare_eq_dict(left, right, verbose)
+ elif verbose > 0:
+ explanation = _compare_eq_verbose(left, right)
+ if isiterable(left) and isiterable(right):
+ expl = _compare_eq_iterable(left, right, verbose)
+ explanation.extend(expl)
+ return explanation
+
+
def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
"""Return the explanation for the diff between text.
@@ -195,7 +205,7 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
"""
from difflib import ndiff
- explanation = [] # type: List[str]
+ explanation: List[str] = []
if verbose < 1:
i = 0 # just in case left or right has zero length
@@ -240,7 +250,7 @@ def _compare_eq_verbose(left: Any, right: Any) -> List[str]:
left_lines = repr(left).splitlines(keepends)
right_lines = repr(right).splitlines(keepends)
- explanation = [] # type: List[str]
+ explanation: List[str] = []
explanation += ["+" + line for line in left_lines]
explanation += ["-" + line for line in right_lines]
@@ -294,7 +304,7 @@ def _compare_eq_sequence(
left: Sequence[Any], right: Sequence[Any], verbose: int = 0
) -> List[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
- explanation = [] # type: List[str]
+ explanation: List[str] = []
len_left = len(left)
len_right = len(right)
for i in range(min(len_left, len_right)):
@@ -314,9 +324,7 @@ def _compare_eq_sequence(
left_value = left[i]
right_value = right[i]
- explanation += [
- "At index {} diff: {!r} != {!r}".format(i, left_value, right_value)
- ]
+ explanation += [f"At index {i} diff: {left_value!r} != {right_value!r}"]
break
if comparing_bytes:
@@ -336,9 +344,7 @@ def _compare_eq_sequence(
extra = saferepr(right[len_left])
if len_diff == 1:
- explanation += [
- "{} contains one more item: {}".format(dir_with_more, extra)
- ]
+ explanation += [f"{dir_with_more} contains one more item: {extra}"]
else:
explanation += [
"%s contains %d more items, first extra item: %s"
@@ -367,7 +373,7 @@ def _compare_eq_set(
def _compare_eq_dict(
left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
) -> List[str]:
- explanation = [] # type: List[str]
+ explanation: List[str] = []
set_left = set(left)
set_right = set(right)
common = set_left.intersection(set_right)
@@ -405,22 +411,19 @@ def _compare_eq_dict(
return explanation
-def _compare_eq_cls(
- left: Any,
- right: Any,
- verbose: int,
- type_fns: Tuple[Callable[[Any], bool], Callable[[Any], bool]],
-) -> List[str]:
- isdatacls, isattrs = type_fns
+def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
if isdatacls(left):
all_fields = left.__dataclass_fields__
fields_to_check = [field for field, info in all_fields.items() if info.compare]
elif isattrs(left):
all_fields = left.__attrs_attrs__
- fields_to_check = [
- field.name for field in all_fields if getattr(field, ATTRS_EQ_FIELD)
- ]
+ fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
+ elif isnamedtuple(left):
+ fields_to_check = left._fields
+ else:
+ assert False
+ indent = " "
same = []
diff = []
for field in fields_to_check:
@@ -430,6 +433,8 @@ def _compare_eq_cls(
diff.append(field)
explanation = []
+ if same or diff:
+ explanation += [""]
if same and verbose < 2:
explanation.append("Omitting %s identical items, use -vv to show" % len(same))
elif same:
@@ -437,9 +442,18 @@ def _compare_eq_cls(
explanation += pprint.pformat(same).splitlines()
if diff:
explanation += ["Differing attributes:"]
+ explanation += pprint.pformat(diff).splitlines()
for field in diff:
+ field_left = getattr(left, field)
+ field_right = getattr(right, field)
+ explanation += [
+ "",
+ "Drill down into differing attribute %s:" % field,
+ ("%s%s: %r != %r") % (indent, field, field_left, field_right),
+ ]
explanation += [
- ("%s: %r != %r") % (field, getattr(left, field), getattr(right, field))
+ indent + line
+ for line in _compare_eq_any(field_left, field_right, verbose)
]
return explanation
diff --git a/contrib/python/pytest/py3/_pytest/cacheprovider.py b/contrib/python/pytest/py3/_pytest/cacheprovider.py
index a0f486089f..03acd03109 100644
--- a/contrib/python/pytest/py3/_pytest/cacheprovider.py
+++ b/contrib/python/pytest/py3/_pytest/cacheprovider.py
@@ -1,31 +1,38 @@
-"""
-merged implementation of the cache provider
-
-the name cache was not chosen to ensure pluggy automatically
-ignores the external pytest-cache
-"""
+"""Implementation of the cache provider."""
+# This plugin was not named "cache" to avoid conflicts with the external
+# pytest-cache version.
import json
import os
-from collections import OrderedDict
+from pathlib import Path
from typing import Dict
from typing import Generator
+from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
+from typing import Union
import attr
import py
-import pytest
-from .pathlib import Path
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
from _pytest import nodes
from _pytest._io import TerminalWriter
+from _pytest.compat import final
from _pytest.config import Config
+from _pytest.config import ExitCode
+from _pytest.config import hookimpl
+from _pytest.config.argparsing import Parser
+from _pytest.deprecated import check_ispytest
+from _pytest.fixtures import fixture
+from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.python import Module
+from _pytest.python import Package
+from _pytest.reports import TestReport
+
README_CONTENT = """\
# pytest cache directory #
@@ -35,7 +42,7 @@ which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
**Do not** commit this to version control.
-See [the docs](https://docs.pytest.org/en/latest/cache.html) for more information.
+See [the docs](https://docs.pytest.org/en/stable/cache.html) for more information.
"""
CACHEDIR_TAG_CONTENT = b"""\
@@ -46,10 +53,11 @@ Signature: 8a477f597d28d172789f06886806bc55
"""
-@attr.s
+@final
+@attr.s(init=False)
class Cache:
- _cachedir = attr.ib(repr=False)
- _config = attr.ib(repr=False)
+ _cachedir = attr.ib(type=Path, repr=False)
+ _config = attr.ib(type=Config, repr=False)
# sub-directory under cache-dir for directories created by "makedir"
_CACHE_PREFIX_DIRS = "d"
@@ -57,26 +65,52 @@ class Cache:
# sub-directory under cache-dir for values created by "set"
_CACHE_PREFIX_VALUES = "v"
+ def __init__(
+ self, cachedir: Path, config: Config, *, _ispytest: bool = False
+ ) -> None:
+ check_ispytest(_ispytest)
+ self._cachedir = cachedir
+ self._config = config
+
@classmethod
- def for_config(cls, config):
- cachedir = cls.cache_dir_from_config(config)
+ def for_config(cls, config: Config, *, _ispytest: bool = False) -> "Cache":
+ """Create the Cache instance for a Config.
+
+ :meta private:
+ """
+ check_ispytest(_ispytest)
+ cachedir = cls.cache_dir_from_config(config, _ispytest=True)
if config.getoption("cacheclear") and cachedir.is_dir():
- cls.clear_cache(cachedir)
- return cls(cachedir, config)
+ cls.clear_cache(cachedir, _ispytest=True)
+ return cls(cachedir, config, _ispytest=True)
@classmethod
- def clear_cache(cls, cachedir: Path):
- """Clears the sub-directories used to hold cached directories and values."""
+ def clear_cache(cls, cachedir: Path, _ispytest: bool = False) -> None:
+ """Clear the sub-directories used to hold cached directories and values.
+
+ :meta private:
+ """
+ check_ispytest(_ispytest)
for prefix in (cls._CACHE_PREFIX_DIRS, cls._CACHE_PREFIX_VALUES):
d = cachedir / prefix
if d.is_dir():
rm_rf(d)
@staticmethod
- def cache_dir_from_config(config):
- return resolve_from_str(config.getini("cache_dir"), config.rootdir)
+ def cache_dir_from_config(config: Config, *, _ispytest: bool = False) -> Path:
+ """Get the path to the cache directory for a Config.
+
+ :meta private:
+ """
+ check_ispytest(_ispytest)
+ return resolve_from_str(config.getini("cache_dir"), config.rootpath)
+
+ def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None:
+ """Issue a cache warning.
- def warn(self, fmt, **args):
+ :meta private:
+ """
+ check_ispytest(_ispytest)
import warnings
from _pytest.warning_types import PytestCacheWarning
@@ -86,52 +120,56 @@ class Cache:
stacklevel=3,
)
- def makedir(self, name):
- """ return a directory path object with the given name. If the
- directory does not yet exist, it will be created. You can use it
- to manage files likes e. g. store/retrieve database
- dumps across test sessions.
+ def makedir(self, name: str) -> py.path.local:
+ """Return a directory path object with the given name.
+
+ If the directory does not yet exist, it will be created. You can use
+ it to manage files to e.g. store/retrieve database dumps across test
+ sessions.
- :param name: must be a string not containing a ``/`` separator.
- Make sure the name contains your plugin or application
- identifiers to prevent clashes with other cache users.
+ :param name:
+ Must be a string not containing a ``/`` separator.
+ Make sure the name contains your plugin or application
+ identifiers to prevent clashes with other cache users.
"""
- name = Path(name)
- if len(name.parts) > 1:
+ path = Path(name)
+ if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators")
- res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, name)
+ res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
res.mkdir(exist_ok=True, parents=True)
return py.path.local(res)
- def _getvaluepath(self, key):
+ def _getvaluepath(self, key: str) -> Path:
return self._cachedir.joinpath(self._CACHE_PREFIX_VALUES, Path(key))
- def get(self, key, default):
- """ return cached value for the given key. If no value
- was yet cached or the value cannot be read, the specified
- default is returned.
+ def get(self, key: str, default):
+ """Return the cached value for the given key.
- :param key: must be a ``/`` separated value. Usually the first
- name is the name of your plugin or your application.
- :param default: must be provided in case of a cache-miss or
- invalid cache values.
+ If no value was yet cached or the value cannot be read, the specified
+ default is returned.
+ :param key:
+ Must be a ``/`` separated value. Usually the first
+ name is the name of your plugin or your application.
+ :param default:
+ The value to return in case of a cache-miss or invalid cache value.
"""
path = self._getvaluepath(key)
try:
with path.open("r") as f:
return json.load(f)
- except (ValueError, IOError, OSError):
+ except (ValueError, OSError):
return default
- def set(self, key, value):
- """ save value for the given key.
+ def set(self, key: str, value: object) -> None:
+ """Save value for the given key.
- :param key: must be a ``/`` separated value. Usually the first
- name is the name of your plugin or your application.
- :param value: must be of any combination of basic
- python types, including nested types
- like e. g. lists of dictionaries.
+ :param key:
+ Must be a ``/`` separated value. Usually the first
+ name is the name of your plugin or your application.
+ :param value:
+ Must be of any combination of basic python types,
+ including nested types like lists of dictionaries.
"""
path = self._getvaluepath(key)
try:
@@ -140,21 +178,21 @@ class Cache:
else:
cache_dir_exists_already = self._cachedir.exists()
path.parent.mkdir(exist_ok=True, parents=True)
- except (IOError, OSError):
- self.warn("could not create cache path {path}", path=path)
+ except OSError:
+ self.warn("could not create cache path {path}", path=path, _ispytest=True)
return
if not cache_dir_exists_already:
self._ensure_supporting_files()
data = json.dumps(value, indent=2, sort_keys=True)
try:
f = path.open("w")
- except (IOError, OSError):
- self.warn("cache could not write path {path}", path=path)
+ except OSError:
+ self.warn("cache could not write path {path}", path=path, _ispytest=True)
else:
with f:
f.write(data)
- def _ensure_supporting_files(self):
+ def _ensure_supporting_files(self) -> None:
"""Create supporting files in the cache dir that are not really part of the cache."""
readme_path = self._cachedir / "README.md"
readme_path.write_text(README_CONTENT)
@@ -168,52 +206,65 @@ class Cache:
class LFPluginCollWrapper:
- def __init__(self, lfplugin: "LFPlugin"):
+ def __init__(self, lfplugin: "LFPlugin") -> None:
self.lfplugin = lfplugin
self._collected_at_least_one_failure = False
- @pytest.hookimpl(hookwrapper=True)
- def pytest_make_collect_report(self, collector) -> Generator:
+ @hookimpl(hookwrapper=True)
+ def pytest_make_collect_report(self, collector: nodes.Collector):
if isinstance(collector, Session):
out = yield
- res = out.get_result() # type: CollectReport
+ res: CollectReport = out.get_result()
# Sort any lf-paths to the beginning.
lf_paths = self.lfplugin._last_failed_paths
res.result = sorted(
res.result, key=lambda x: 0 if Path(str(x.fspath)) in lf_paths else 1,
)
- out.force_result(res)
return
elif isinstance(collector, Module):
if Path(str(collector.fspath)) in self.lfplugin._last_failed_paths:
out = yield
res = out.get_result()
-
- filtered_result = [
- x for x in res.result if x.nodeid in self.lfplugin.lastfailed
+ result = res.result
+ lastfailed = self.lfplugin.lastfailed
+
+ # Only filter with known failures.
+ if not self._collected_at_least_one_failure:
+ if not any(x.nodeid in lastfailed for x in result):
+ return
+ self.lfplugin.config.pluginmanager.register(
+ LFPluginCollSkipfiles(self.lfplugin), "lfplugin-collskip"
+ )
+ self._collected_at_least_one_failure = True
+
+ session = collector.session
+ result[:] = [
+ x
+ for x in result
+ if x.nodeid in lastfailed
+ # Include any passed arguments (not trivial to filter).
+ or session.isinitpath(x.fspath)
+ # Keep all sub-collectors.
+ or isinstance(x, nodes.Collector)
]
- if filtered_result:
- res.result = filtered_result
- out.force_result(res)
-
- if not self._collected_at_least_one_failure:
- self.lfplugin.config.pluginmanager.register(
- LFPluginCollSkipfiles(self.lfplugin), "lfplugin-collskip"
- )
- self._collected_at_least_one_failure = True
- return res
+ return
yield
class LFPluginCollSkipfiles:
- def __init__(self, lfplugin: "LFPlugin"):
+ def __init__(self, lfplugin: "LFPlugin") -> None:
self.lfplugin = lfplugin
- @pytest.hookimpl
- def pytest_make_collect_report(self, collector) -> Optional[CollectReport]:
- if isinstance(collector, Module):
+ @hookimpl
+ def pytest_make_collect_report(
+ self, collector: nodes.Collector
+ ) -> Optional[CollectReport]:
+ # Packages are Modules, but _last_failed_paths only contains
+ # test-bearing paths and doesn't try to include the paths of their
+ # packages, so don't filter them.
+ if isinstance(collector, Module) and not isinstance(collector, Package):
if Path(str(collector.fspath)) not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1
@@ -224,18 +275,16 @@ class LFPluginCollSkipfiles:
class LFPlugin:
- """ Plugin which implements the --lf (run last-failing) option """
+ """Plugin which implements the --lf (run last-failing) option."""
def __init__(self, config: Config) -> None:
self.config = config
active_keys = "lf", "failedfirst"
self.active = any(config.getoption(key) for key in active_keys)
assert config.cache
- self.lastfailed = config.cache.get(
- "cache/lastfailed", {}
- ) # type: Dict[str, bool]
- self._previously_failed_count = None
- self._report_status = None
+ self.lastfailed: Dict[str, bool] = config.cache.get("cache/lastfailed", {})
+ self._previously_failed_count: Optional[int] = None
+ self._report_status: Optional[str] = None
self._skipped_files = 0 # count skipped files during collection due to --lf
if config.getoption("lf"):
@@ -245,22 +294,23 @@ class LFPlugin:
)
def get_last_failed_paths(self) -> Set[Path]:
- """Returns a set with all Paths()s of the previously failed nodeids."""
- rootpath = Path(str(self.config.rootdir))
+ """Return a set with all Paths()s of the previously failed nodeids."""
+ rootpath = self.config.rootpath
result = {rootpath / nodeid.split("::")[0] for nodeid in self.lastfailed}
return {x for x in result if x.exists()}
- def pytest_report_collectionfinish(self):
+ def pytest_report_collectionfinish(self) -> Optional[str]:
if self.active and self.config.getoption("verbose") >= 0:
return "run-last-failure: %s" % self._report_status
+ return None
- def pytest_runtest_logreport(self, report):
+ def pytest_runtest_logreport(self, report: TestReport) -> None:
if (report.when == "call" and report.passed) or report.skipped:
self.lastfailed.pop(report.nodeid, None)
elif report.failed:
self.lastfailed[report.nodeid] = True
- def pytest_collectreport(self, report):
+ def pytest_collectreport(self, report: CollectReport) -> None:
passed = report.outcome in ("passed", "skipped")
if passed:
if report.nodeid in self.lastfailed:
@@ -269,7 +319,12 @@ class LFPlugin:
else:
self.lastfailed[report.nodeid] = True
- def pytest_collection_modifyitems(self, session, config, items):
+ @hookimpl(hookwrapper=True, tryfirst=True)
+ def pytest_collection_modifyitems(
+ self, config: Config, items: List[nodes.Item]
+ ) -> Generator[None, None, None]:
+ yield
+
if not self.active:
return
@@ -316,30 +371,35 @@ class LFPlugin:
else:
self._report_status += "not deselecting items."
- def pytest_sessionfinish(self, session):
+ def pytest_sessionfinish(self, session: Session) -> None:
config = self.config
- if config.getoption("cacheshow") or hasattr(config, "slaveinput"):
+ if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
+ assert config.cache is not None
saved_lastfailed = config.cache.get("cache/lastfailed", {})
if saved_lastfailed != self.lastfailed:
config.cache.set("cache/lastfailed", self.lastfailed)
class NFPlugin:
- """ Plugin which implements the --nf (run new-first) option """
+ """Plugin which implements the --nf (run new-first) option."""
- def __init__(self, config):
+ def __init__(self, config: Config) -> None:
self.config = config
self.active = config.option.newfirst
- self.cached_nodeids = config.cache.get("cache/nodeids", [])
+ assert config.cache is not None
+ self.cached_nodeids = set(config.cache.get("cache/nodeids", []))
+ @hookimpl(hookwrapper=True, tryfirst=True)
def pytest_collection_modifyitems(
- self, session: Session, config: Config, items: List[nodes.Item]
- ) -> None:
- new_items = OrderedDict() # type: OrderedDict[str, nodes.Item]
+ self, items: List[nodes.Item]
+ ) -> Generator[None, None, None]:
+ yield
+
if self.active:
- other_items = OrderedDict() # type: OrderedDict[str, nodes.Item]
+ new_items: Dict[str, nodes.Item] = {}
+ other_items: Dict[str, nodes.Item] = {}
for item in items:
if item.nodeid not in self.cached_nodeids:
new_items[item.nodeid] = item
@@ -349,24 +409,26 @@ class NFPlugin:
items[:] = self._get_increasing_order(
new_items.values()
) + self._get_increasing_order(other_items.values())
+ self.cached_nodeids.update(new_items)
else:
- for item in items:
- if item.nodeid not in self.cached_nodeids:
- new_items[item.nodeid] = item
- self.cached_nodeids.extend(new_items)
+ self.cached_nodeids.update(item.nodeid for item in items)
- def _get_increasing_order(self, items):
- return sorted(items, key=lambda item: item.fspath.mtime(), reverse=True)
+ def _get_increasing_order(self, items: Iterable[nodes.Item]) -> List[nodes.Item]:
+ return sorted(items, key=lambda item: item.fspath.mtime(), reverse=True) # type: ignore[no-any-return]
- def pytest_sessionfinish(self, session):
+ def pytest_sessionfinish(self) -> None:
config = self.config
- if config.getoption("cacheshow") or hasattr(config, "slaveinput"):
+ if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
- config.cache.set("cache/nodeids", self.cached_nodeids)
+ if config.getoption("collectonly"):
+ return
+
+ assert config.cache is not None
+ config.cache.set("cache/nodeids", sorted(self.cached_nodeids))
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--lf",
@@ -381,9 +443,9 @@ def pytest_addoption(parser):
"--failed-first",
action="store_true",
dest="failedfirst",
- help="run all tests but run the last failures first. "
+ help="run all tests, but run the last failures first.\n"
"This may re-order tests and thus lead to "
- "repeated fixture setup/teardown",
+ "repeated fixture setup/teardown.",
)
group.addoption(
"--nf",
@@ -424,53 +486,58 @@ def pytest_addoption(parser):
)
-def pytest_cmdline_main(config):
+def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.cacheshow:
from _pytest.main import wrap_session
return wrap_session(config, cacheshow)
+ return None
-@pytest.hookimpl(tryfirst=True)
+@hookimpl(tryfirst=True)
def pytest_configure(config: Config) -> None:
- config.cache = Cache.for_config(config)
+ config.cache = Cache.for_config(config, _ispytest=True)
config.pluginmanager.register(LFPlugin(config), "lfplugin")
config.pluginmanager.register(NFPlugin(config), "nfplugin")
-@pytest.fixture
-def cache(request):
- """
- Return a cache object that can persist state between testing sessions.
+@fixture
+def cache(request: FixtureRequest) -> Cache:
+ """Return a cache object that can persist state between testing sessions.
cache.get(key, default)
cache.set(key, value)
- Keys must be a ``/`` separated value, where the first part is usually the
+ Keys must be ``/`` separated strings, where the first part is usually the
name of your plugin or application to avoid clashes with other cache users.
Values can be any object handled by the json stdlib module.
"""
+ assert request.config.cache is not None
return request.config.cache
-def pytest_report_header(config):
+def pytest_report_header(config: Config) -> Optional[str]:
"""Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
+ assert config.cache is not None
cachedir = config.cache._cachedir
# TODO: evaluate generating upward relative paths
# starting with .., ../.. if sensible
try:
- displaypath = cachedir.relative_to(config.rootdir)
+ displaypath = cachedir.relative_to(config.rootpath)
except ValueError:
displaypath = cachedir
- return "cachedir: {}".format(displaypath)
+ return f"cachedir: {displaypath}"
+ return None
-def cacheshow(config, session):
+def cacheshow(config: Config, session: Session) -> int:
from pprint import pformat
+ assert config.cache is not None
+
tw = TerminalWriter()
tw.line("cachedir: " + str(config.cache._cachedir))
if not config.cache._cachedir.is_dir():
@@ -486,7 +553,7 @@ def cacheshow(config, session):
vdir = basedir / Cache._CACHE_PREFIX_VALUES
tw.sep("-", "cache values for %r" % glob)
for valpath in sorted(x for x in vdir.rglob(glob) if x.is_file()):
- key = valpath.relative_to(vdir)
+ key = str(valpath.relative_to(vdir))
val = config.cache.get(key, dummy)
if val is dummy:
tw.line("%s contains unreadable content, will be ignored" % key)
@@ -503,6 +570,6 @@ def cacheshow(config, session):
# if p.check(dir=1):
# print("%s/" % p.relto(basedir))
if p.is_file():
- key = p.relative_to(basedir)
- tw.line("{} is a file of length {:d}".format(key, p.stat().st_size))
+ key = str(p.relative_to(basedir))
+ tw.line(f"{key} is a file of length {p.stat().st_size:d}")
return 0
diff --git a/contrib/python/pytest/py3/_pytest/capture.py b/contrib/python/pytest/py3/_pytest/capture.py
index 673bb07a9b..086302658c 100644
--- a/contrib/python/pytest/py3/_pytest/capture.py
+++ b/contrib/python/pytest/py3/_pytest/capture.py
@@ -1,40 +1,45 @@
-"""
-per-test stdout/stderr capturing mechanism.
-
-"""
-import collections
+"""Per-test stdout/stderr capturing mechanism."""
import contextlib
+import functools
import io
import os
import sys
from io import UnsupportedOperation
from tempfile import TemporaryFile
-from typing import BinaryIO
+from typing import Any
+from typing import AnyStr
from typing import Generator
-from typing import Iterable
+from typing import Generic
+from typing import Iterator
from typing import Optional
+from typing import TextIO
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
-import pytest
-from _pytest.compat import CaptureAndPassthroughIO
-from _pytest.compat import CaptureIO
-from _pytest.compat import TYPE_CHECKING
+from _pytest.compat import final
from _pytest.config import Config
-from _pytest.fixtures import FixtureRequest
+from _pytest.config import hookimpl
+from _pytest.config.argparsing import Parser
+from _pytest.deprecated import check_ispytest
+from _pytest.fixtures import fixture
+from _pytest.fixtures import SubRequest
+from _pytest.nodes import Collector
+from _pytest.nodes import File
+from _pytest.nodes import Item
if TYPE_CHECKING:
from typing_extensions import Literal
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"]
-patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"}
-
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group._addoption(
"--capture",
action="store",
- default="fd" if hasattr(os, "dup") else "sys",
+ default="fd",
metavar="method",
choices=["fd", "sys", "no", "tee-sys"],
help="per-test capturing method: one of fd|sys|no|tee-sys.",
@@ -48,7 +53,102 @@ def pytest_addoption(parser):
)
-@pytest.hookimpl(hookwrapper=True)
+def _colorama_workaround() -> None:
+ """Ensure colorama is imported so that it attaches to the correct stdio
+ handles on Windows.
+
+ colorama uses the terminal on import time. So if something does the
+ first import of colorama while I/O capture is active, colorama will
+ fail in various ways.
+ """
+ if sys.platform.startswith("win32"):
+ try:
+ import colorama # noqa: F401
+ except ImportError:
+ pass
+
+
+def _readline_workaround() -> None:
+ """Ensure readline is imported so that it attaches to the correct stdio
+ handles on Windows.
+
+ Pdb uses readline support where available--when not running from the Python
+ prompt, the readline module is not imported until running the pdb REPL. If
+ running pytest with the --pdb option this means the readline module is not
+ imported until after I/O capture has been started.
+
+ This is a problem for pyreadline, which is often used to implement readline
+ support on Windows, as it does not attach to the correct handles for stdout
+ and/or stdin if they have been redirected by the FDCapture mechanism. This
+ workaround ensures that readline is imported before I/O capture is setup so
+ that it can attach to the actual stdin/out for the console.
+
+ See https://github.com/pytest-dev/pytest/pull/1281.
+ """
+ if sys.platform.startswith("win32"):
+ try:
+ import readline # noqa: F401
+ except ImportError:
+ pass
+
+
+def _py36_windowsconsoleio_workaround(stream: TextIO) -> None:
+ """Workaround for Windows Unicode console handling on Python>=3.6.
+
+ Python 3.6 implemented Unicode console handling for Windows. This works
+ by reading/writing to the raw console handle using
+ ``{Read,Write}ConsoleW``.
+
+ The problem is that we are going to ``dup2`` over the stdio file
+ descriptors when doing ``FDCapture`` and this will ``CloseHandle`` the
+ handles used by Python to write to the console. Though there is still some
+ weirdness and the console handle seems to only be closed randomly and not
+ on the first call to ``CloseHandle``, or maybe it gets reopened with the
+ same handle value when we suspend capturing.
+
+ The workaround in this case will reopen stdio with a different fd which
+ also means a different handle by replicating the logic in
+ "Py_lifecycle.c:initstdio/create_stdio".
+
+ :param stream:
+ In practice ``sys.stdout`` or ``sys.stderr``, but given
+ here as parameter for unittesting purposes.
+
+ See https://github.com/pytest-dev/py/issues/103.
+ """
+ if not sys.platform.startswith("win32") or hasattr(sys, "pypy_version_info"):
+ return
+
+ # Bail out if ``stream`` doesn't seem like a proper ``io`` stream (#2666).
+ if not hasattr(stream, "buffer"): # type: ignore[unreachable]
+ return
+
+ buffered = hasattr(stream.buffer, "raw")
+ raw_stdout = stream.buffer.raw if buffered else stream.buffer # type: ignore[attr-defined]
+
+ if not isinstance(raw_stdout, io._WindowsConsoleIO): # type: ignore[attr-defined]
+ return
+
+ def _reopen_stdio(f, mode):
+ if not buffered and mode[0] == "w":
+ buffering = 0
+ else:
+ buffering = -1
+
+ return io.TextIOWrapper(
+ open(os.dup(f.fileno()), mode, buffering), # type: ignore[arg-type]
+ f.encoding,
+ f.errors,
+ f.newlines,
+ f.line_buffering,
+ )
+
+ sys.stdin = _reopen_stdio(sys.stdin, "rb")
+ sys.stdout = _reopen_stdio(sys.stdout, "wb")
+ sys.stderr = _reopen_stdio(sys.stderr, "wb")
+
+
+@hookimpl(hookwrapper=True)
def pytest_load_initial_conftests(early_config: Config):
ns = early_config.known_args_namespace
if ns.capture == "fd":
@@ -59,10 +159,10 @@ def pytest_load_initial_conftests(early_config: Config):
capman = CaptureManager(ns.capture)
pluginmanager.register(capman, "capturemanager")
- # make sure that capturemanager is properly reset at final shutdown
+ # Make sure that capturemanager is properly reset at final shutdown.
early_config.add_cleanup(capman.stop_global_capturing)
- # finally trigger conftest loading but while capturing (issue93)
+ # Finally trigger conftest loading but while capturing (issue #93).
capman.start_global_capturing()
outcome = yield
capman.suspend_global_capture()
@@ -72,395 +172,394 @@ def pytest_load_initial_conftests(early_config: Config):
sys.stderr.write(err)
-def _get_multicapture(method: "_CaptureMethod") -> "MultiCapture":
- if method == "fd":
- return MultiCapture(out=True, err=True, Capture=FDCapture)
- elif method == "sys":
- return MultiCapture(out=True, err=True, Capture=SysCapture)
- elif method == "no":
- return MultiCapture(out=False, err=False, in_=False)
- elif method == "tee-sys":
- return MultiCapture(out=True, err=True, in_=False, Capture=TeeSysCapture)
- raise ValueError("unknown capturing method: {!r}".format(method))
+# IO Helpers.
-class CaptureManager:
- """
- Capture plugin, manages that the appropriate capture method is enabled/disabled during collection and each
- test phase (setup, call, teardown). After each of those points, the captured output is obtained and
- attached to the collection/runtest report.
+class EncodedFile(io.TextIOWrapper):
+ __slots__ = ()
- There are two levels of capture:
- * global: which is enabled by default and can be suppressed by the ``-s`` option. This is always enabled/disabled
- during collection and each test phase.
- * fixture: when a test function or one of its fixture depend on the ``capsys`` or ``capfd`` fixtures. In this
- case special handling is needed to ensure the fixtures take precedence over the global capture.
- """
+ @property
+ def name(self) -> str:
+ # Ensure that file.name is a string. Workaround for a Python bug
+ # fixed in >=3.7.4: https://bugs.python.org/issue36015
+ return repr(self.buffer)
- def __init__(self, method: "_CaptureMethod") -> None:
- self._method = method
- self._global_capturing = None
- self._capture_fixture = None # type: Optional[CaptureFixture]
+ @property
+ def mode(self) -> str:
+ # TextIOWrapper doesn't expose a mode, but at least some of our
+ # tests check it.
+ return self.buffer.mode.replace("b", "")
- def __repr__(self):
- return "<CaptureManager _method={!r} _global_capturing={!r} _capture_fixture={!r}>".format(
- self._method, self._global_capturing, self._capture_fixture
+
+class CaptureIO(io.TextIOWrapper):
+ def __init__(self) -> None:
+ super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
+
+ def getvalue(self) -> str:
+ assert isinstance(self.buffer, io.BytesIO)
+ return self.buffer.getvalue().decode("UTF-8")
+
+
+class TeeCaptureIO(CaptureIO):
+ def __init__(self, other: TextIO) -> None:
+ self._other = other
+ super().__init__()
+
+ def write(self, s: str) -> int:
+ super().write(s)
+ return self._other.write(s)
+
+
+class DontReadFromInput:
+ encoding = None
+
+ def read(self, *args):
+ raise OSError(
+ "pytest: reading from stdin while output is captured! Consider using `-s`."
)
- def is_capturing(self):
- if self.is_globally_capturing():
- return "global"
- if self._capture_fixture:
- return "fixture %s" % self._capture_fixture.request.fixturename
- return False
+ readline = read
+ readlines = read
+ __next__ = read
- # Global capturing control
+ def __iter__(self):
+ return self
- def is_globally_capturing(self):
- return self._method != "no"
+ def fileno(self) -> int:
+ raise UnsupportedOperation("redirected stdin is pseudofile, has no fileno()")
- def start_global_capturing(self):
- assert self._global_capturing is None
- self._global_capturing = _get_multicapture(self._method)
- self._global_capturing.start_capturing()
+ def isatty(self) -> bool:
+ return False
- def stop_global_capturing(self):
- if self._global_capturing is not None:
- self._global_capturing.pop_outerr_to_orig()
- self._global_capturing.stop_capturing()
- self._global_capturing = None
+ def close(self) -> None:
+ pass
- def resume_global_capture(self):
- # During teardown of the python process, and on rare occasions, capture
- # attributes can be `None` while trying to resume global capture.
- if self._global_capturing is not None:
- self._global_capturing.resume_capturing()
+ @property
+ def buffer(self):
+ return self
- def suspend_global_capture(self, in_=False):
- cap = getattr(self, "_global_capturing", None)
- if cap is not None:
- cap.suspend_capturing(in_=in_)
- def suspend(self, in_=False):
- # Need to undo local capsys-et-al if it exists before disabling global capture.
- self.suspend_fixture()
- self.suspend_global_capture(in_)
+# Capture classes.
- def resume(self):
- self.resume_global_capture()
- self.resume_fixture()
- def read_global_capture(self):
- return self._global_capturing.readouterr()
+patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"}
- # Fixture Control (it's just forwarding, think about removing this later)
- @contextlib.contextmanager
- def _capturing_for_request(
- self, request: FixtureRequest
- ) -> Generator["CaptureFixture", None, None]:
- """
- Context manager that creates a ``CaptureFixture`` instance for the
- given ``request``, ensuring there is only a single one being requested
- at the same time.
+class NoCapture:
+ EMPTY_BUFFER = None
+ __init__ = start = done = suspend = resume = lambda *args: None
- This is used as a helper with ``capsys``, ``capfd`` etc.
- """
- if self._capture_fixture:
- other_name = next(
- k
- for k, v in map_fixname_class.items()
- if v is self._capture_fixture.captureclass
- )
- raise request.raiseerror(
- "cannot use {} and {} at the same time".format(
- request.fixturename, other_name
- )
- )
- capture_class = map_fixname_class[request.fixturename]
- self._capture_fixture = CaptureFixture(capture_class, request)
- self.activate_fixture()
- yield self._capture_fixture
- self._capture_fixture.close()
- self._capture_fixture = None
- def activate_fixture(self):
- """If the current item is using ``capsys`` or ``capfd``, activate them so they take precedence over
- the global capture.
- """
- if self._capture_fixture:
- self._capture_fixture._start()
+class SysCaptureBinary:
- def deactivate_fixture(self):
- """Deactivates the ``capsys`` or ``capfd`` fixture of this item, if any."""
- if self._capture_fixture:
- self._capture_fixture.close()
+ EMPTY_BUFFER = b""
- def suspend_fixture(self):
- if self._capture_fixture:
- self._capture_fixture._suspend()
+ def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None:
+ name = patchsysdict[fd]
+ self._old = getattr(sys, name)
+ self.name = name
+ if tmpfile is None:
+ if name == "stdin":
+ tmpfile = DontReadFromInput()
+ else:
+ tmpfile = CaptureIO() if not tee else TeeCaptureIO(self._old)
+ self.tmpfile = tmpfile
+ self._state = "initialized"
- def resume_fixture(self):
- if self._capture_fixture:
- self._capture_fixture._resume()
+ def repr(self, class_name: str) -> str:
+ return "<{} {} _old={} _state={!r} tmpfile={!r}>".format(
+ class_name,
+ self.name,
+ hasattr(self, "_old") and repr(self._old) or "<UNSET>",
+ self._state,
+ self.tmpfile,
+ )
- # Helper context managers
+ def __repr__(self) -> str:
+ return "<{} {} _old={} _state={!r} tmpfile={!r}>".format(
+ self.__class__.__name__,
+ self.name,
+ hasattr(self, "_old") and repr(self._old) or "<UNSET>",
+ self._state,
+ self.tmpfile,
+ )
- @contextlib.contextmanager
- def global_and_fixture_disabled(self):
- """Context manager to temporarily disable global and current fixture capturing."""
- self.suspend()
- try:
- yield
- finally:
- self.resume()
+ def _assert_state(self, op: str, states: Tuple[str, ...]) -> None:
+ assert (
+ self._state in states
+ ), "cannot {} in state {!r}: expected one of {}".format(
+ op, self._state, ", ".join(states)
+ )
- @contextlib.contextmanager
- def item_capture(self, when, item):
- self.resume_global_capture()
- self.activate_fixture()
- try:
- yield
- finally:
- self.deactivate_fixture()
- self.suspend_global_capture(in_=False)
+ def start(self) -> None:
+ self._assert_state("start", ("initialized",))
+ setattr(sys, self.name, self.tmpfile)
+ self._state = "started"
- out, err = self.read_global_capture()
- item.add_report_section(when, "stdout", out)
- item.add_report_section(when, "stderr", err)
+ def snap(self):
+ self._assert_state("snap", ("started", "suspended"))
+ self.tmpfile.seek(0)
+ res = self.tmpfile.buffer.read()
+ self.tmpfile.seek(0)
+ self.tmpfile.truncate()
+ return res
- # Hooks
+ def done(self) -> None:
+ self._assert_state("done", ("initialized", "started", "suspended", "done"))
+ if self._state == "done":
+ return
+ setattr(sys, self.name, self._old)
+ del self._old
+ self.tmpfile.close()
+ self._state = "done"
- @pytest.hookimpl(hookwrapper=True)
- def pytest_make_collect_report(self, collector):
- if isinstance(collector, pytest.File):
- self.resume_global_capture()
- outcome = yield
- self.suspend_global_capture()
- out, err = self.read_global_capture()
- rep = outcome.get_result()
- if out:
- rep.sections.append(("Captured stdout", out))
- if err:
- rep.sections.append(("Captured stderr", err))
- else:
- yield
+ def suspend(self) -> None:
+ self._assert_state("suspend", ("started", "suspended"))
+ setattr(sys, self.name, self._old)
+ self._state = "suspended"
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_setup(self, item):
- with self.item_capture("setup", item):
- yield
+ def resume(self) -> None:
+ self._assert_state("resume", ("started", "suspended"))
+ if self._state == "started":
+ return
+ setattr(sys, self.name, self.tmpfile)
+ self._state = "started"
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_call(self, item):
- with self.item_capture("call", item):
- yield
+ def writeorg(self, data) -> None:
+ self._assert_state("writeorg", ("started", "suspended"))
+ self._old.flush()
+ self._old.buffer.write(data)
+ self._old.buffer.flush()
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_teardown(self, item):
- with self.item_capture("teardown", item):
- yield
- @pytest.hookimpl(tryfirst=True)
- def pytest_keyboard_interrupt(self, excinfo):
- self.stop_global_capturing()
+class SysCapture(SysCaptureBinary):
+ EMPTY_BUFFER = "" # type: ignore[assignment]
- @pytest.hookimpl(tryfirst=True)
- def pytest_internalerror(self, excinfo):
- self.stop_global_capturing()
+ def snap(self):
+ res = self.tmpfile.getvalue()
+ self.tmpfile.seek(0)
+ self.tmpfile.truncate()
+ return res
+ def writeorg(self, data):
+ self._assert_state("writeorg", ("started", "suspended"))
+ self._old.write(data)
+ self._old.flush()
-@pytest.fixture
-def capsys(request):
- """Enable text capturing of writes to ``sys.stdout`` and ``sys.stderr``.
- The captured output is made available via ``capsys.readouterr()`` method
- calls, which return a ``(out, err)`` namedtuple.
- ``out`` and ``err`` will be ``text`` objects.
+class FDCaptureBinary:
+ """Capture IO to/from a given OS-level file descriptor.
+
+ snap() produces `bytes`.
"""
- capman = request.config.pluginmanager.getplugin("capturemanager")
- with capman._capturing_for_request(request) as fixture:
- yield fixture
+ EMPTY_BUFFER = b""
-@pytest.fixture
-def capsysbinary(request):
- """Enable bytes capturing of writes to ``sys.stdout`` and ``sys.stderr``.
+ def __init__(self, targetfd: int) -> None:
+ self.targetfd = targetfd
- The captured output is made available via ``capsysbinary.readouterr()``
- method calls, which return a ``(out, err)`` namedtuple.
- ``out`` and ``err`` will be ``bytes`` objects.
- """
- capman = request.config.pluginmanager.getplugin("capturemanager")
- with capman._capturing_for_request(request) as fixture:
- yield fixture
+ try:
+ os.fstat(targetfd)
+ except OSError:
+ # FD capturing is conceptually simple -- create a temporary file,
+ # redirect the FD to it, redirect back when done. But when the
+ # target FD is invalid it throws a wrench into this loveley scheme.
+ #
+ # Tests themselves shouldn't care if the FD is valid, FD capturing
+ # should work regardless of external circumstances. So falling back
+ # to just sys capturing is not a good option.
+ #
+ # Further complications are the need to support suspend() and the
+ # possibility of FD reuse (e.g. the tmpfile getting the very same
+ # target FD). The following approach is robust, I believe.
+ self.targetfd_invalid: Optional[int] = os.open(os.devnull, os.O_RDWR)
+ os.dup2(self.targetfd_invalid, targetfd)
+ else:
+ self.targetfd_invalid = None
+ self.targetfd_save = os.dup(targetfd)
+ if targetfd == 0:
+ self.tmpfile = open(os.devnull)
+ self.syscapture = SysCapture(targetfd)
+ else:
+ self.tmpfile = EncodedFile(
+ TemporaryFile(buffering=0),
+ encoding="utf-8",
+ errors="replace",
+ newline="",
+ write_through=True,
+ )
+ if targetfd in patchsysdict:
+ self.syscapture = SysCapture(targetfd, self.tmpfile)
+ else:
+ self.syscapture = NoCapture()
-@pytest.fixture
-def capfd(request):
- """Enable text capturing of writes to file descriptors ``1`` and ``2``.
+ self._state = "initialized"
- The captured output is made available via ``capfd.readouterr()`` method
- calls, which return a ``(out, err)`` namedtuple.
- ``out`` and ``err`` will be ``text`` objects.
- """
- if not hasattr(os, "dup"):
- pytest.skip(
- "capfd fixture needs os.dup function which is not available in this system"
+ def __repr__(self) -> str:
+ return "<{} {} oldfd={} _state={!r} tmpfile={!r}>".format(
+ self.__class__.__name__,
+ self.targetfd,
+ self.targetfd_save,
+ self._state,
+ self.tmpfile,
)
- capman = request.config.pluginmanager.getplugin("capturemanager")
- with capman._capturing_for_request(request) as fixture:
- yield fixture
+ def _assert_state(self, op: str, states: Tuple[str, ...]) -> None:
+ assert (
+ self._state in states
+ ), "cannot {} in state {!r}: expected one of {}".format(
+ op, self._state, ", ".join(states)
+ )
-@pytest.fixture
-def capfdbinary(request):
- """Enable bytes capturing of writes to file descriptors ``1`` and ``2``.
+ def start(self) -> None:
+ """Start capturing on targetfd using memorized tmpfile."""
+ self._assert_state("start", ("initialized",))
+ os.dup2(self.tmpfile.fileno(), self.targetfd)
+ self.syscapture.start()
+ self._state = "started"
- The captured output is made available via ``capfd.readouterr()`` method
- calls, which return a ``(out, err)`` namedtuple.
- ``out`` and ``err`` will be ``byte`` objects.
- """
- if not hasattr(os, "dup"):
- pytest.skip(
- "capfdbinary fixture needs os.dup function which is not available in this system"
- )
- capman = request.config.pluginmanager.getplugin("capturemanager")
- with capman._capturing_for_request(request) as fixture:
- yield fixture
+ def snap(self):
+ self._assert_state("snap", ("started", "suspended"))
+ self.tmpfile.seek(0)
+ res = self.tmpfile.buffer.read()
+ self.tmpfile.seek(0)
+ self.tmpfile.truncate()
+ return res
+ def done(self) -> None:
+ """Stop capturing, restore streams, return original capture file,
+ seeked to position zero."""
+ self._assert_state("done", ("initialized", "started", "suspended", "done"))
+ if self._state == "done":
+ return
+ os.dup2(self.targetfd_save, self.targetfd)
+ os.close(self.targetfd_save)
+ if self.targetfd_invalid is not None:
+ if self.targetfd_invalid != self.targetfd:
+ os.close(self.targetfd)
+ os.close(self.targetfd_invalid)
+ self.syscapture.done()
+ self.tmpfile.close()
+ self._state = "done"
-class CaptureFixture:
- """
- Object returned by :py:func:`capsys`, :py:func:`capsysbinary`, :py:func:`capfd` and :py:func:`capfdbinary`
- fixtures.
+ def suspend(self) -> None:
+ self._assert_state("suspend", ("started", "suspended"))
+ if self._state == "suspended":
+ return
+ self.syscapture.suspend()
+ os.dup2(self.targetfd_save, self.targetfd)
+ self._state = "suspended"
+
+ def resume(self) -> None:
+ self._assert_state("resume", ("started", "suspended"))
+ if self._state == "started":
+ return
+ self.syscapture.resume()
+ os.dup2(self.tmpfile.fileno(), self.targetfd)
+ self._state = "started"
+
+ def writeorg(self, data):
+ """Write to original file descriptor."""
+ self._assert_state("writeorg", ("started", "suspended"))
+ os.write(self.targetfd_save, data)
+
+
+class FDCapture(FDCaptureBinary):
+ """Capture IO to/from a given OS-level file descriptor.
+
+ snap() produces text.
"""
- def __init__(self, captureclass, request):
- self.captureclass = captureclass
- self.request = request
- self._capture = None
- self._captured_out = self.captureclass.EMPTY_BUFFER
- self._captured_err = self.captureclass.EMPTY_BUFFER
+ # Ignore type because it doesn't match the type in the superclass (bytes).
+ EMPTY_BUFFER = "" # type: ignore
- def _start(self):
- if self._capture is None:
- self._capture = MultiCapture(
- out=True, err=True, in_=False, Capture=self.captureclass
- )
- self._capture.start_capturing()
+ def snap(self):
+ self._assert_state("snap", ("started", "suspended"))
+ self.tmpfile.seek(0)
+ res = self.tmpfile.read()
+ self.tmpfile.seek(0)
+ self.tmpfile.truncate()
+ return res
- def close(self):
- if self._capture is not None:
- out, err = self._capture.pop_outerr_to_orig()
- self._captured_out += out
- self._captured_err += err
- self._capture.stop_capturing()
- self._capture = None
+ def writeorg(self, data):
+ """Write to original file descriptor."""
+ super().writeorg(data.encode("utf-8")) # XXX use encoding of original stream
- def readouterr(self):
- """Read and return the captured output so far, resetting the internal buffer.
- :return: captured content as a namedtuple with ``out`` and ``err`` string attributes
- """
- captured_out, captured_err = self._captured_out, self._captured_err
- if self._capture is not None:
- out, err = self._capture.readouterr()
- captured_out += out
- captured_err += err
- self._captured_out = self.captureclass.EMPTY_BUFFER
- self._captured_err = self.captureclass.EMPTY_BUFFER
- return CaptureResult(captured_out, captured_err)
+# MultiCapture
- def _suspend(self):
- """Suspends this fixture's own capturing temporarily."""
- if self._capture is not None:
- self._capture.suspend_capturing()
- def _resume(self):
- """Resumes this fixture's own capturing temporarily."""
- if self._capture is not None:
- self._capture.resume_capturing()
+# This class was a namedtuple, but due to mypy limitation[0] it could not be
+# made generic, so was replaced by a regular class which tries to emulate the
+# pertinent parts of a namedtuple. If the mypy limitation is ever lifted, can
+# make it a namedtuple again.
+# [0]: https://github.com/python/mypy/issues/685
+@final
+@functools.total_ordering
+class CaptureResult(Generic[AnyStr]):
+ """The result of :method:`CaptureFixture.readouterr`."""
- @contextlib.contextmanager
- def disabled(self):
- """Temporarily disables capture while inside the 'with' block."""
- capmanager = self.request.config.pluginmanager.getplugin("capturemanager")
- with capmanager.global_and_fixture_disabled():
- yield
+ __slots__ = ("out", "err")
+ def __init__(self, out: AnyStr, err: AnyStr) -> None:
+ self.out: AnyStr = out
+ self.err: AnyStr = err
-def safe_text_dupfile(f, mode, default_encoding="UTF8"):
- """ return an open text file object that's a duplicate of f on the
- FD-level if possible.
- """
- encoding = getattr(f, "encoding", None)
- try:
- fd = f.fileno()
- except Exception:
- if "b" not in getattr(f, "mode", "") and hasattr(f, "encoding"):
- # we seem to have a text stream, let's just use it
- return f
- else:
- newfd = os.dup(fd)
- if "b" not in mode:
- mode += "b"
- f = os.fdopen(newfd, mode, 0) # no buffering
- return EncodedFile(f, encoding or default_encoding)
-
-
-class EncodedFile:
- errors = "strict" # possibly needed by py3 code (issue555)
-
- def __init__(self, buffer: BinaryIO, encoding: str) -> None:
- self.buffer = buffer
- self.encoding = encoding
+ def __len__(self) -> int:
+ return 2
- def write(self, s: str) -> int:
- if not isinstance(s, str):
- raise TypeError(
- "write() argument must be str, not {}".format(type(s).__name__)
- )
- return self.buffer.write(s.encode(self.encoding, "replace"))
+ def __iter__(self) -> Iterator[AnyStr]:
+ return iter((self.out, self.err))
- def writelines(self, lines: Iterable[str]) -> None:
- self.buffer.writelines(x.encode(self.encoding, "replace") for x in lines)
+ def __getitem__(self, item: int) -> AnyStr:
+ return tuple(self)[item]
- @property
- def name(self) -> str:
- """Ensure that file.name is a string."""
- return repr(self.buffer)
+ def _replace(
+ self, *, out: Optional[AnyStr] = None, err: Optional[AnyStr] = None
+ ) -> "CaptureResult[AnyStr]":
+ return CaptureResult(
+ out=self.out if out is None else out, err=self.err if err is None else err
+ )
- @property
- def mode(self) -> str:
- return self.buffer.mode.replace("b", "")
+ def count(self, value: AnyStr) -> int:
+ return tuple(self).count(value)
+
+ def index(self, value) -> int:
+ return tuple(self).index(value)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, (CaptureResult, tuple)):
+ return NotImplemented
+ return tuple(self) == tuple(other)
- def __getattr__(self, name):
- return getattr(object.__getattribute__(self, "buffer"), name)
+ def __hash__(self) -> int:
+ return hash(tuple(self))
+ def __lt__(self, other: object) -> bool:
+ if not isinstance(other, (CaptureResult, tuple)):
+ return NotImplemented
+ return tuple(self) < tuple(other)
-CaptureResult = collections.namedtuple("CaptureResult", ["out", "err"])
+ def __repr__(self) -> str:
+ return f"CaptureResult(out={self.out!r}, err={self.err!r})"
-class MultiCapture:
- out = err = in_ = None
+class MultiCapture(Generic[AnyStr]):
_state = None
_in_suspended = False
- def __init__(self, out=True, err=True, in_=True, Capture=None):
- if in_:
- self.in_ = Capture(0)
- if out:
- self.out = Capture(1)
- if err:
- self.err = Capture(2)
+ def __init__(self, in_, out, err) -> None:
+ self.in_ = in_
+ self.out = out
+ self.err = err
- def __repr__(self):
+ def __repr__(self) -> str:
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
self.out, self.err, self.in_, self._state, self._in_suspended,
)
- def start_capturing(self):
+ def start_capturing(self) -> None:
self._state = "started"
if self.in_:
self.in_.start()
@@ -469,8 +568,8 @@ class MultiCapture:
if self.err:
self.err.start()
- def pop_outerr_to_orig(self):
- """ pop current snapshot out/err capture and flush to orig streams. """
+ def pop_outerr_to_orig(self) -> Tuple[AnyStr, AnyStr]:
+ """Pop current snapshot out/err capture and flush to orig streams."""
out, err = self.readouterr()
if out:
self.out.writeorg(out)
@@ -478,7 +577,7 @@ class MultiCapture:
self.err.writeorg(err)
return out, err
- def suspend_capturing(self, in_=False):
+ def suspend_capturing(self, in_: bool = False) -> None:
self._state = "suspended"
if self.out:
self.out.suspend()
@@ -488,8 +587,8 @@ class MultiCapture:
self.in_.suspend()
self._in_suspended = True
- def resume_capturing(self):
- self._state = "resumed"
+ def resume_capturing(self) -> None:
+ self._state = "started"
if self.out:
self.out.resume()
if self.err:
@@ -498,8 +597,8 @@ class MultiCapture:
self.in_.resume()
self._in_suspended = False
- def stop_capturing(self):
- """ stop capturing and reset capturing streams """
+ def stop_capturing(self) -> None:
+ """Stop capturing and reset capturing streams."""
if self._state == "stopped":
raise ValueError("was already stopped")
self._state = "stopped"
@@ -510,7 +609,11 @@ class MultiCapture:
if self.in_:
self.in_.done()
- def readouterr(self) -> CaptureResult:
+ def is_started(self) -> bool:
+ """Whether actively capturing -- not suspended or stopped."""
+ return self._state == "started"
+
+ def readouterr(self) -> CaptureResult[AnyStr]:
if self.out:
out = self.out.snap()
else:
@@ -522,332 +625,343 @@ class MultiCapture:
return CaptureResult(out, err)
-class NoCapture:
- EMPTY_BUFFER = None
- __init__ = start = done = suspend = resume = lambda *args: None
+def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]:
+ if method == "fd":
+ return MultiCapture(in_=FDCapture(0), out=FDCapture(1), err=FDCapture(2))
+ elif method == "sys":
+ return MultiCapture(in_=SysCapture(0), out=SysCapture(1), err=SysCapture(2))
+ elif method == "no":
+ return MultiCapture(in_=None, out=None, err=None)
+ elif method == "tee-sys":
+ return MultiCapture(
+ in_=None, out=SysCapture(1, tee=True), err=SysCapture(2, tee=True)
+ )
+ raise ValueError(f"unknown capturing method: {method!r}")
-class FDCaptureBinary:
- """Capture IO to/from a given os-level filedescriptor.
+# CaptureManager and CaptureFixture
- snap() produces `bytes`
- """
- EMPTY_BUFFER = b""
- _state = None
+class CaptureManager:
+ """The capture plugin.
- def __init__(self, targetfd, tmpfile=None):
- self.targetfd = targetfd
- try:
- self.targetfd_save = os.dup(self.targetfd)
- except OSError:
- self.start = lambda: None
- self.done = lambda: None
- else:
- self.start = self._start
- self.done = self._done
- if targetfd == 0:
- assert not tmpfile, "cannot set tmpfile with stdin"
- tmpfile = open(os.devnull, "r")
- self.syscapture = SysCapture(targetfd)
- else:
- if tmpfile is None:
- f = TemporaryFile()
- with f:
- tmpfile = safe_text_dupfile(f, mode="wb+")
- if targetfd in patchsysdict:
- self.syscapture = SysCapture(targetfd, tmpfile)
- else:
- self.syscapture = NoCapture()
- self.tmpfile = tmpfile
- self.tmpfile_fd = tmpfile.fileno()
-
- def __repr__(self):
- return "<{} {} oldfd={} _state={!r} tmpfile={}>".format(
- self.__class__.__name__,
- self.targetfd,
- getattr(self, "targetfd_save", "<UNSET>"),
- self._state,
- hasattr(self, "tmpfile") and repr(self.tmpfile) or "<UNSET>",
- )
+ Manages that the appropriate capture method is enabled/disabled during
+ collection and each test phase (setup, call, teardown). After each of
+ those points, the captured output is obtained and attached to the
+ collection/runtest report.
- def _start(self):
- """ Start capturing on targetfd using memorized tmpfile. """
- try:
- os.fstat(self.targetfd_save)
- except (AttributeError, OSError):
- raise ValueError("saved filedescriptor not valid anymore")
- os.dup2(self.tmpfile_fd, self.targetfd)
- self.syscapture.start()
- self._state = "started"
+ There are two levels of capture:
- def snap(self):
- self.tmpfile.seek(0)
- res = self.tmpfile.read()
- self.tmpfile.seek(0)
- self.tmpfile.truncate()
- return res
+ * global: enabled by default and can be suppressed by the ``-s``
+ option. This is always enabled/disabled during collection and each test
+ phase.
- def _done(self):
- """ stop capturing, restore streams, return original capture file,
- seeked to position zero. """
- targetfd_save = self.__dict__.pop("targetfd_save")
- os.dup2(targetfd_save, self.targetfd)
- os.close(targetfd_save)
- self.syscapture.done()
- self.tmpfile.close()
- self._state = "done"
-
- def suspend(self):
- self.syscapture.suspend()
- os.dup2(self.targetfd_save, self.targetfd)
- self._state = "suspended"
+ * fixture: when a test function or one of its fixture depend on the
+ ``capsys`` or ``capfd`` fixtures. In this case special handling is
+ needed to ensure the fixtures take precedence over the global capture.
+ """
- def resume(self):
- self.syscapture.resume()
- os.dup2(self.tmpfile_fd, self.targetfd)
- self._state = "resumed"
+ def __init__(self, method: "_CaptureMethod") -> None:
+ self._method = method
+ self._global_capturing: Optional[MultiCapture[str]] = None
+ self._capture_fixture: Optional[CaptureFixture[Any]] = None
- def writeorg(self, data):
- """ write to original file descriptor. """
- os.write(self.targetfd_save, data)
+ def __repr__(self) -> str:
+ return "<CaptureManager _method={!r} _global_capturing={!r} _capture_fixture={!r}>".format(
+ self._method, self._global_capturing, self._capture_fixture
+ )
+ def is_capturing(self) -> Union[str, bool]:
+ if self.is_globally_capturing():
+ return "global"
+ if self._capture_fixture:
+ return "fixture %s" % self._capture_fixture.request.fixturename
+ return False
-class FDCapture(FDCaptureBinary):
- """Capture IO to/from a given os-level filedescriptor.
+ # Global capturing control
- snap() produces text
- """
+ def is_globally_capturing(self) -> bool:
+ return self._method != "no"
- # Ignore type because it doesn't match the type in the superclass (bytes).
- EMPTY_BUFFER = str() # type: ignore
+ def start_global_capturing(self) -> None:
+ assert self._global_capturing is None
+ self._global_capturing = _get_multicapture(self._method)
+ self._global_capturing.start_capturing()
- def snap(self):
- res = super().snap()
- enc = getattr(self.tmpfile, "encoding", None)
- if enc and isinstance(res, bytes):
- res = str(res, enc, "replace")
- return res
+ def stop_global_capturing(self) -> None:
+ if self._global_capturing is not None:
+ self._global_capturing.pop_outerr_to_orig()
+ self._global_capturing.stop_capturing()
+ self._global_capturing = None
- def writeorg(self, data):
- """ write to original file descriptor. """
- data = data.encode("utf-8") # XXX use encoding of original stream
- os.write(self.targetfd_save, data)
+ def resume_global_capture(self) -> None:
+ # During teardown of the python process, and on rare occasions, capture
+ # attributes can be `None` while trying to resume global capture.
+ if self._global_capturing is not None:
+ self._global_capturing.resume_capturing()
+ def suspend_global_capture(self, in_: bool = False) -> None:
+ if self._global_capturing is not None:
+ self._global_capturing.suspend_capturing(in_=in_)
-class SysCaptureBinary:
+ def suspend(self, in_: bool = False) -> None:
+ # Need to undo local capsys-et-al if it exists before disabling global capture.
+ self.suspend_fixture()
+ self.suspend_global_capture(in_)
- EMPTY_BUFFER = b""
- _state = None
+ def resume(self) -> None:
+ self.resume_global_capture()
+ self.resume_fixture()
- def __init__(self, fd, tmpfile=None):
- name = patchsysdict[fd]
- self._old = getattr(sys, name)
- self.name = name
- if tmpfile is None:
- if name == "stdin":
- tmpfile = DontReadFromInput()
- else:
- tmpfile = CaptureIO()
- self.tmpfile = tmpfile
+ def read_global_capture(self) -> CaptureResult[str]:
+ assert self._global_capturing is not None
+ return self._global_capturing.readouterr()
- def __repr__(self):
- return "<{} {} _old={} _state={!r} tmpfile={!r}>".format(
- self.__class__.__name__,
- self.name,
- hasattr(self, "_old") and repr(self._old) or "<UNSET>",
- self._state,
- self.tmpfile,
- )
+ # Fixture Control
- def start(self):
- setattr(sys, self.name, self.tmpfile)
- self._state = "started"
+ def set_fixture(self, capture_fixture: "CaptureFixture[Any]") -> None:
+ if self._capture_fixture:
+ current_fixture = self._capture_fixture.request.fixturename
+ requested_fixture = capture_fixture.request.fixturename
+ capture_fixture.request.raiseerror(
+ "cannot use {} and {} at the same time".format(
+ requested_fixture, current_fixture
+ )
+ )
+ self._capture_fixture = capture_fixture
- def snap(self):
- res = self.tmpfile.buffer.getvalue()
- self.tmpfile.seek(0)
- self.tmpfile.truncate()
- return res
+ def unset_fixture(self) -> None:
+ self._capture_fixture = None
- def done(self):
- setattr(sys, self.name, self._old)
- del self._old
- self.tmpfile.close()
- self._state = "done"
+ def activate_fixture(self) -> None:
+ """If the current item is using ``capsys`` or ``capfd``, activate
+ them so they take precedence over the global capture."""
+ if self._capture_fixture:
+ self._capture_fixture._start()
- def suspend(self):
- setattr(sys, self.name, self._old)
- self._state = "suspended"
+ def deactivate_fixture(self) -> None:
+ """Deactivate the ``capsys`` or ``capfd`` fixture of this item, if any."""
+ if self._capture_fixture:
+ self._capture_fixture.close()
- def resume(self):
- setattr(sys, self.name, self.tmpfile)
- self._state = "resumed"
+ def suspend_fixture(self) -> None:
+ if self._capture_fixture:
+ self._capture_fixture._suspend()
- def writeorg(self, data):
- self._old.flush()
- self._old.buffer.write(data)
- self._old.buffer.flush()
+ def resume_fixture(self) -> None:
+ if self._capture_fixture:
+ self._capture_fixture._resume()
+ # Helper context managers
-class SysCapture(SysCaptureBinary):
- EMPTY_BUFFER = str() # type: ignore[assignment] # noqa: F821
+ @contextlib.contextmanager
+ def global_and_fixture_disabled(self) -> Generator[None, None, None]:
+ """Context manager to temporarily disable global and current fixture capturing."""
+ do_fixture = self._capture_fixture and self._capture_fixture._is_started()
+ if do_fixture:
+ self.suspend_fixture()
+ do_global = self._global_capturing and self._global_capturing.is_started()
+ if do_global:
+ self.suspend_global_capture()
+ try:
+ yield
+ finally:
+ if do_global:
+ self.resume_global_capture()
+ if do_fixture:
+ self.resume_fixture()
- def snap(self):
- res = self.tmpfile.getvalue()
- self.tmpfile.seek(0)
- self.tmpfile.truncate()
- return res
+ @contextlib.contextmanager
+ def item_capture(self, when: str, item: Item) -> Generator[None, None, None]:
+ self.resume_global_capture()
+ self.activate_fixture()
+ try:
+ yield
+ finally:
+ self.deactivate_fixture()
+ self.suspend_global_capture(in_=False)
- def writeorg(self, data):
- self._old.write(data)
- self._old.flush()
+ out, err = self.read_global_capture()
+ item.add_report_section(when, "stdout", out)
+ item.add_report_section(when, "stderr", err)
+ # Hooks
-class TeeSysCapture(SysCapture):
- def __init__(self, fd, tmpfile=None):
- name = patchsysdict[fd]
- self._old = getattr(sys, name)
- self.name = name
- if tmpfile is None:
- if name == "stdin":
- tmpfile = DontReadFromInput()
- else:
- tmpfile = CaptureAndPassthroughIO(self._old)
- self.tmpfile = tmpfile
+ @hookimpl(hookwrapper=True)
+ def pytest_make_collect_report(self, collector: Collector):
+ if isinstance(collector, File):
+ self.resume_global_capture()
+ outcome = yield
+ self.suspend_global_capture()
+ out, err = self.read_global_capture()
+ rep = outcome.get_result()
+ if out:
+ rep.sections.append(("Captured stdout", out))
+ if err:
+ rep.sections.append(("Captured stderr", err))
+ else:
+ yield
+ @hookimpl(hookwrapper=True)
+ def pytest_runtest_setup(self, item: Item) -> Generator[None, None, None]:
+ with self.item_capture("setup", item):
+ yield
-map_fixname_class = {
- "capfd": FDCapture,
- "capfdbinary": FDCaptureBinary,
- "capsys": SysCapture,
- "capsysbinary": SysCaptureBinary,
-}
+ @hookimpl(hookwrapper=True)
+ def pytest_runtest_call(self, item: Item) -> Generator[None, None, None]:
+ with self.item_capture("call", item):
+ yield
+ @hookimpl(hookwrapper=True)
+ def pytest_runtest_teardown(self, item: Item) -> Generator[None, None, None]:
+ with self.item_capture("teardown", item):
+ yield
-class DontReadFromInput:
- encoding = None
+ @hookimpl(tryfirst=True)
+ def pytest_keyboard_interrupt(self) -> None:
+ self.stop_global_capturing()
- def read(self, *args):
- raise IOError(
- "pytest: reading from stdin while output is captured! Consider using `-s`."
- )
+ @hookimpl(tryfirst=True)
+ def pytest_internalerror(self) -> None:
+ self.stop_global_capturing()
- readline = read
- readlines = read
- __next__ = read
- def __iter__(self):
- return self
+class CaptureFixture(Generic[AnyStr]):
+ """Object returned by the :fixture:`capsys`, :fixture:`capsysbinary`,
+ :fixture:`capfd` and :fixture:`capfdbinary` fixtures."""
- def fileno(self):
- raise UnsupportedOperation("redirected stdin is pseudofile, has no fileno()")
+ def __init__(
+ self, captureclass, request: SubRequest, *, _ispytest: bool = False
+ ) -> None:
+ check_ispytest(_ispytest)
+ self.captureclass = captureclass
+ self.request = request
+ self._capture: Optional[MultiCapture[AnyStr]] = None
+ self._captured_out = self.captureclass.EMPTY_BUFFER
+ self._captured_err = self.captureclass.EMPTY_BUFFER
- def isatty(self):
- return False
+ def _start(self) -> None:
+ if self._capture is None:
+ self._capture = MultiCapture(
+ in_=None, out=self.captureclass(1), err=self.captureclass(2),
+ )
+ self._capture.start_capturing()
- def close(self):
- pass
+ def close(self) -> None:
+ if self._capture is not None:
+ out, err = self._capture.pop_outerr_to_orig()
+ self._captured_out += out
+ self._captured_err += err
+ self._capture.stop_capturing()
+ self._capture = None
- @property
- def buffer(self):
- return self
+ def readouterr(self) -> CaptureResult[AnyStr]:
+ """Read and return the captured output so far, resetting the internal
+ buffer.
+ :returns:
+ The captured content as a namedtuple with ``out`` and ``err``
+ string attributes.
+ """
+ captured_out, captured_err = self._captured_out, self._captured_err
+ if self._capture is not None:
+ out, err = self._capture.readouterr()
+ captured_out += out
+ captured_err += err
+ self._captured_out = self.captureclass.EMPTY_BUFFER
+ self._captured_err = self.captureclass.EMPTY_BUFFER
+ return CaptureResult(captured_out, captured_err)
-def _colorama_workaround():
- """
- Ensure colorama is imported so that it attaches to the correct stdio
- handles on Windows.
+ def _suspend(self) -> None:
+ """Suspend this fixture's own capturing temporarily."""
+ if self._capture is not None:
+ self._capture.suspend_capturing()
- colorama uses the terminal on import time. So if something does the
- first import of colorama while I/O capture is active, colorama will
- fail in various ways.
- """
- if sys.platform.startswith("win32"):
- try:
- import colorama # noqa: F401
- except ImportError:
- pass
+ def _resume(self) -> None:
+ """Resume this fixture's own capturing temporarily."""
+ if self._capture is not None:
+ self._capture.resume_capturing()
+ def _is_started(self) -> bool:
+ """Whether actively capturing -- not disabled or closed."""
+ if self._capture is not None:
+ return self._capture.is_started()
+ return False
-def _readline_workaround():
- """
- Ensure readline is imported so that it attaches to the correct stdio
- handles on Windows.
+ @contextlib.contextmanager
+ def disabled(self) -> Generator[None, None, None]:
+ """Temporarily disable capturing while inside the ``with`` block."""
+ capmanager = self.request.config.pluginmanager.getplugin("capturemanager")
+ with capmanager.global_and_fixture_disabled():
+ yield
- Pdb uses readline support where available--when not running from the Python
- prompt, the readline module is not imported until running the pdb REPL. If
- running pytest with the --pdb option this means the readline module is not
- imported until after I/O capture has been started.
- This is a problem for pyreadline, which is often used to implement readline
- support on Windows, as it does not attach to the correct handles for stdout
- and/or stdin if they have been redirected by the FDCapture mechanism. This
- workaround ensures that readline is imported before I/O capture is setup so
- that it can attach to the actual stdin/out for the console.
+# The fixtures.
- See https://github.com/pytest-dev/pytest/pull/1281
- """
- if sys.platform.startswith("win32"):
- try:
- import readline # noqa: F401
- except ImportError:
- pass
+@fixture
+def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
+ """Enable text capturing of writes to ``sys.stdout`` and ``sys.stderr``.
-def _py36_windowsconsoleio_workaround(stream):
+ The captured output is made available via ``capsys.readouterr()`` method
+ calls, which return a ``(out, err)`` namedtuple.
+ ``out`` and ``err`` will be ``text`` objects.
"""
- Python 3.6 implemented unicode console handling for Windows. This works
- by reading/writing to the raw console handle using
- ``{Read,Write}ConsoleW``.
-
- The problem is that we are going to ``dup2`` over the stdio file
- descriptors when doing ``FDCapture`` and this will ``CloseHandle`` the
- handles used by Python to write to the console. Though there is still some
- weirdness and the console handle seems to only be closed randomly and not
- on the first call to ``CloseHandle``, or maybe it gets reopened with the
- same handle value when we suspend capturing.
+ capman = request.config.pluginmanager.getplugin("capturemanager")
+ capture_fixture = CaptureFixture[str](SysCapture, request, _ispytest=True)
+ capman.set_fixture(capture_fixture)
+ capture_fixture._start()
+ yield capture_fixture
+ capture_fixture.close()
+ capman.unset_fixture()
- The workaround in this case will reopen stdio with a different fd which
- also means a different handle by replicating the logic in
- "Py_lifecycle.c:initstdio/create_stdio".
- :param stream: in practice ``sys.stdout`` or ``sys.stderr``, but given
- here as parameter for unittesting purposes.
+@fixture
+def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, None]:
+ """Enable bytes capturing of writes to ``sys.stdout`` and ``sys.stderr``.
- See https://github.com/pytest-dev/py/issues/103
+ The captured output is made available via ``capsysbinary.readouterr()``
+ method calls, which return a ``(out, err)`` namedtuple.
+ ``out`` and ``err`` will be ``bytes`` objects.
"""
- if (
- not sys.platform.startswith("win32")
- or sys.version_info[:2] < (3, 6)
- or hasattr(sys, "pypy_version_info")
- ):
- return
+ capman = request.config.pluginmanager.getplugin("capturemanager")
+ capture_fixture = CaptureFixture[bytes](SysCaptureBinary, request, _ispytest=True)
+ capman.set_fixture(capture_fixture)
+ capture_fixture._start()
+ yield capture_fixture
+ capture_fixture.close()
+ capman.unset_fixture()
- # bail out if ``stream`` doesn't seem like a proper ``io`` stream (#2666)
- if not hasattr(stream, "buffer"):
- return
- buffered = hasattr(stream.buffer, "raw")
- raw_stdout = stream.buffer.raw if buffered else stream.buffer
+@fixture
+def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
+ """Enable text capturing of writes to file descriptors ``1`` and ``2``.
- if not isinstance(raw_stdout, io._WindowsConsoleIO):
- return
+ The captured output is made available via ``capfd.readouterr()`` method
+ calls, which return a ``(out, err)`` namedtuple.
+ ``out`` and ``err`` will be ``text`` objects.
+ """
+ capman = request.config.pluginmanager.getplugin("capturemanager")
+ capture_fixture = CaptureFixture[str](FDCapture, request, _ispytest=True)
+ capman.set_fixture(capture_fixture)
+ capture_fixture._start()
+ yield capture_fixture
+ capture_fixture.close()
+ capman.unset_fixture()
- def _reopen_stdio(f, mode):
- if not buffered and mode[0] == "w":
- buffering = 0
- else:
- buffering = -1
- return io.TextIOWrapper(
- open(os.dup(f.fileno()), mode, buffering),
- f.encoding,
- f.errors,
- f.newlines,
- f.line_buffering,
- )
+@fixture
+def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, None]:
+ """Enable bytes capturing of writes to file descriptors ``1`` and ``2``.
- sys.stdin = _reopen_stdio(sys.stdin, "rb")
- sys.stdout = _reopen_stdio(sys.stdout, "wb")
- sys.stderr = _reopen_stdio(sys.stderr, "wb")
+ The captured output is made available via ``capfd.readouterr()`` method
+ calls, which return a ``(out, err)`` namedtuple.
+ ``out`` and ``err`` will be ``byte`` objects.
+ """
+ capman = request.config.pluginmanager.getplugin("capturemanager")
+ capture_fixture = CaptureFixture[bytes](FDCaptureBinary, request, _ispytest=True)
+ capman.set_fixture(capture_fixture)
+ capture_fixture._start()
+ yield capture_fixture
+ capture_fixture.close()
+ capman.unset_fixture()
diff --git a/contrib/python/pytest/py3/_pytest/compat.py b/contrib/python/pytest/py3/_pytest/compat.py
index 4e11bcab76..c23cc962ce 100644
--- a/contrib/python/pytest/py3/_pytest/compat.py
+++ b/contrib/python/pytest/py3/_pytest/compat.py
@@ -1,52 +1,43 @@
-"""
-python version compatibility code
-"""
+"""Python version compatibility code."""
+import enum
import functools
import inspect
-import io
-import os
import re
import sys
from contextlib import contextmanager
from inspect import Parameter
from inspect import signature
+from pathlib import Path
from typing import Any
from typing import Callable
from typing import Generic
-from typing import IO
from typing import Optional
-from typing import overload
from typing import Tuple
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
-import py
-from _pytest._io.saferepr import saferepr
from _pytest.outcomes import fail
from _pytest.outcomes import TEST_OUTCOME
-if sys.version_info < (3, 5, 2):
- TYPE_CHECKING = False # type: bool
-else:
- from typing import TYPE_CHECKING
-
-
if TYPE_CHECKING:
- from typing import Type # noqa: F401 (used in type string)
+ from typing import NoReturn
+ from typing_extensions import Final
_T = TypeVar("_T")
_S = TypeVar("_S")
-NOTSET = object()
-
-MODULE_NOT_FOUND_ERROR = (
- "ModuleNotFoundError" if sys.version_info[:2] >= (3, 6) else "ImportError"
-)
-
+# fmt: off
+# Singleton type for NOTSET, as described in:
+# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
+class NotSetType(enum.Enum):
+ token = 0
+NOTSET: "Final" = NotSetType.token # noqa: E305
+# fmt: on
if sys.version_info >= (3, 8):
from importlib import metadata as importlib_metadata
@@ -62,27 +53,13 @@ def _format_args(func: Callable[..., Any]) -> str:
REGEX_TYPE = type(re.compile(""))
-if sys.version_info < (3, 6):
-
- def fspath(p):
- """os.fspath replacement, useful to point out when we should replace it by the
- real function once we drop py35.
- """
- return str(p)
-
-
-else:
- fspath = os.fspath
-
-
def is_generator(func: object) -> bool:
genfunc = inspect.isgeneratorfunction(func)
return genfunc and not iscoroutinefunction(func)
def iscoroutinefunction(func: object) -> bool:
- """
- Return True if func is a coroutine function (a function defined with async
+ """Return True if func is a coroutine function (a function defined with async
def syntax, and doesn't contain yield), or a function decorated with
@asyncio.coroutine.
@@ -94,25 +71,27 @@ def iscoroutinefunction(func: object) -> bool:
def is_async_function(func: object) -> bool:
- """Return True if the given function seems to be an async function or async generator"""
- return iscoroutinefunction(func) or (
- sys.version_info >= (3, 6) and inspect.isasyncgenfunction(func)
- )
+ """Return True if the given function seems to be an async function or
+ an async generator."""
+ return iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
-def getlocation(function, curdir=None) -> str:
+def getlocation(function, curdir: Optional[str] = None) -> str:
function = get_real_func(function)
- fn = py.path.local(inspect.getfile(function))
+ fn = Path(inspect.getfile(function))
lineno = function.__code__.co_firstlineno
if curdir is not None:
- relfn = fn.relto(curdir)
- if relfn:
+ try:
+ relfn = fn.relative_to(curdir)
+ except ValueError:
+ pass
+ else:
return "%s:%d" % (relfn, lineno + 1)
return "%s:%d" % (fn, lineno + 1)
def num_mock_patch_args(function) -> int:
- """ return number of arguments used up by mock arguments (if any) """
+ """Return number of arguments used up by mock arguments (if any)."""
patchings = getattr(function, "patchings", None)
if not patchings:
return 0
@@ -135,15 +114,15 @@ def getfuncargnames(
*,
name: str = "",
is_method: bool = False,
- cls: Optional[type] = None
+ cls: Optional[type] = None,
) -> Tuple[str, ...]:
- """Returns the names of a function's mandatory arguments.
+ """Return the names of a function's mandatory arguments.
- This should return the names of all function arguments that:
- * Aren't bound to an instance or type as in instance or class methods.
- * Don't have default values.
- * Aren't bound with functools.partial.
- * Aren't replaced with mocks.
+ Should return the names of all function arguments that:
+ * Aren't bound to an instance or type as in instance or class methods.
+ * Don't have default values.
+ * Aren't bound with functools.partial.
+ * Aren't replaced with mocks.
The is_method and cls arguments indicate that the function should
be treated as a bound method even though it's not unless, only in
@@ -164,8 +143,7 @@ def getfuncargnames(
parameters = signature(function).parameters
except (ValueError, TypeError) as e:
fail(
- "Could not determine arguments of {!r}: {}".format(function, e),
- pytrace=False,
+ f"Could not determine arguments of {function!r}: {e}", pytrace=False,
)
arg_names = tuple(
@@ -201,12 +179,13 @@ if sys.version_info < (3, 7):
else:
- from contextlib import nullcontext # noqa
+ from contextlib import nullcontext as nullcontext # noqa: F401
def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
- # Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
- # to get the arguments which were excluded from its result because they had default values
+ # Note: this code intentionally mirrors the code at the beginning of
+ # getfuncargnames, to get the arguments which were excluded from its result
+ # because they had default values.
return tuple(
p.name
for p in signature(function).parameters.values()
@@ -216,7 +195,7 @@ def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
_non_printable_ascii_translate_table = {
- i: "\\x{:02x}".format(i) for i in range(128) if i not in range(32, 127)
+ i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127)
}
_non_printable_ascii_translate_table.update(
{ord("\t"): "\\t", ord("\r"): "\\r", ord("\n"): "\\n"}
@@ -235,22 +214,21 @@ def _bytes_to_ascii(val: bytes) -> str:
def ascii_escaped(val: Union[bytes, str]) -> str:
- """If val is pure ascii, returns it as a str(). Otherwise, escapes
+ r"""If val is pure ASCII, return it as an str, otherwise, escape
bytes objects into a sequence of escaped bytes:
- b'\xc3\xb4\xc5\xd6' -> '\\xc3\\xb4\\xc5\\xd6'
+ b'\xc3\xb4\xc5\xd6' -> r'\xc3\xb4\xc5\xd6'
and escapes unicode objects into a sequence of escaped unicode
ids, e.g.:
- '4\\nV\\U00043efa\\x0eMXWB\\x1e\\u3028\\u15fd\\xcd\\U0007d944'
+ r'4\nV\U00043efa\x0eMXWB\x1e\u3028\u15fd\xcd\U0007d944'
- note:
- the obvious "v.decode('unicode-escape')" will return
- valid utf-8 unicode if it finds them in bytes, but we
+ Note:
+ The obvious "v.decode('unicode-escape')" will return
+ valid UTF-8 unicode if it finds them in bytes, but we
want to return escaped bytes for any byte, even if they match
- a utf-8 string.
-
+ a UTF-8 string.
"""
if isinstance(val, bytes):
ret = _bytes_to_ascii(val)
@@ -263,18 +241,17 @@ def ascii_escaped(val: Union[bytes, str]) -> str:
class _PytestWrapper:
"""Dummy wrapper around a function object for internal use only.
- Used to correctly unwrap the underlying function object
- when we are creating fixtures, because we wrap the function object ourselves with a decorator
- to issue warnings when the fixture function is called directly.
+ Used to correctly unwrap the underlying function object when we are
+ creating fixtures, because we wrap the function object ourselves with a
+ decorator to issue warnings when the fixture function is called directly.
"""
obj = attr.ib()
def get_real_func(obj):
- """ gets the real function object of the (possibly) wrapped object by
- functools.wraps or functools.partial.
- """
+ """Get the real function object of the (possibly) wrapped object by
+ functools.wraps or functools.partial."""
start_obj = obj
for i in range(100):
# __pytest_wrapped__ is set by @pytest.fixture when wrapping the fixture function
@@ -289,6 +266,8 @@ def get_real_func(obj):
break
obj = new_obj
else:
+ from _pytest._io.saferepr import saferepr
+
raise ValueError(
("could not find real function of {start}\nstopped at {current}").format(
start=saferepr(start_obj), current=saferepr(obj)
@@ -300,10 +279,9 @@ def get_real_func(obj):
def get_real_method(obj, holder):
- """
- Attempts to obtain the real function object that might be wrapping ``obj``, while at the same time
- returning a bound method to ``holder`` if the original object was a bound method.
- """
+ """Attempt to obtain the real function object that might be wrapping
+ ``obj``, while at the same time returning a bound method to ``holder`` if
+ the original object was a bound method."""
try:
is_method = hasattr(obj, "__func__")
obj = get_real_func(obj)
@@ -322,12 +300,13 @@ def getimfunc(func):
def safe_getattr(object: Any, name: str, default: Any) -> Any:
- """ Like getattr but return default upon any Exception or any OutcomeException.
+ """Like getattr but return default upon any Exception or any OutcomeException.
Attribute access can potentially fail for 'evil' Python objects.
See issue #214.
- It catches OutcomeException because of #2490 (issue #580), new outcomes are derived from BaseException
- instead of Exception (for more details check #2707)
+ It catches OutcomeException because of #2490 (issue #580), new outcomes
+ are derived from BaseException instead of Exception (for more details
+ check #2707).
"""
try:
return getattr(object, name, default)
@@ -343,64 +322,24 @@ def safe_isclass(obj: object) -> bool:
return False
-COLLECT_FAKEMODULE_ATTRIBUTES = (
- "Collector",
- "Module",
- "Function",
- "Instance",
- "Session",
- "Item",
- "Class",
- "File",
- "_fillfuncargs",
-)
-
-
-def _setup_collect_fakemodule() -> None:
- from types import ModuleType
- import pytest
-
- # Types ignored because the module is created dynamically.
- pytest.collect = ModuleType("pytest.collect") # type: ignore
- pytest.collect.__all__ = [] # type: ignore # used for setns
- for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES:
- setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore
-
-
-class CaptureIO(io.TextIOWrapper):
- def __init__(self) -> None:
- super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
-
- def getvalue(self) -> str:
- assert isinstance(self.buffer, io.BytesIO)
- return self.buffer.getvalue().decode("UTF-8")
-
-
-class CaptureAndPassthroughIO(CaptureIO):
- def __init__(self, other: IO) -> None:
- self._other = other
- super().__init__()
-
- def write(self, s) -> int:
- super().write(s)
- return self._other.write(s)
-
-
-if sys.version_info < (3, 5, 2):
+if TYPE_CHECKING:
+ if sys.version_info >= (3, 8):
+ from typing import final as final
+ else:
+ from typing_extensions import final as final
+elif sys.version_info >= (3, 8):
+ from typing import final as final
+else:
- def overload(f): # noqa: F811
+ def final(f):
return f
-if getattr(attr, "__version_info__", ()) >= (19, 2):
- ATTRS_EQ_FIELD = "eq"
-else:
- ATTRS_EQ_FIELD = "cmp"
-
-
if sys.version_info >= (3, 8):
- from functools import cached_property
+ from functools import cached_property as cached_property
else:
+ from typing import overload
+ from typing import Type
class cached_property(Generic[_S, _T]):
__slots__ = ("func", "__doc__")
@@ -411,18 +350,51 @@ else:
@overload
def __get__(
- self, instance: None, owner: Optional["Type[_S]"] = ...
+ self, instance: None, owner: Optional[Type[_S]] = ...
) -> "cached_property[_S, _T]":
- raise NotImplementedError()
+ ...
- @overload # noqa: F811
- def __get__( # noqa: F811
- self, instance: _S, owner: Optional["Type[_S]"] = ...
- ) -> _T:
- raise NotImplementedError()
+ @overload
+ def __get__(self, instance: _S, owner: Optional[Type[_S]] = ...) -> _T:
+ ...
- def __get__(self, instance, owner=None): # noqa: F811
+ def __get__(self, instance, owner=None):
if instance is None:
return self
value = instance.__dict__[self.func.__name__] = self.func(instance)
return value
+
+
+# Perform exhaustiveness checking.
+#
+# Consider this example:
+#
+# MyUnion = Union[int, str]
+#
+# def handle(x: MyUnion) -> int {
+# if isinstance(x, int):
+# return 1
+# elif isinstance(x, str):
+# return 2
+# else:
+# raise Exception('unreachable')
+#
+# Now suppose we add a new variant:
+#
+# MyUnion = Union[int, str, bytes]
+#
+# After doing this, we must remember ourselves to go and update the handle
+# function to handle the new variant.
+#
+# With `assert_never` we can do better:
+#
+# // raise Exception('unreachable')
+# return assert_never(x)
+#
+# Now, if we forget to handle the new variant, the type-checker will emit a
+# compile-time error, instead of the runtime error we would have gotten
+# previously.
+#
+# This also work for Enums (if you use `is` to compare) and Literals.
+def assert_never(value: "NoReturn") -> "NoReturn":
+ assert False, "Unhandled value: {} ({})".format(value, type(value).__name__)
diff --git a/contrib/python/pytest/py3/_pytest/config/__init__.py b/contrib/python/pytest/py3/_pytest/config/__init__.py
index e21b9f1e2b..bd9e2883f9 100644
--- a/contrib/python/pytest/py3/_pytest/config/__init__.py
+++ b/contrib/python/pytest/py3/_pytest/config/__init__.py
@@ -1,104 +1,142 @@
-""" command line options, ini-file and conftest.py processing. """
+"""Command line options, ini-file and conftest.py processing."""
import argparse
+import collections.abc
+import contextlib
import copy
import enum
import inspect
import os
+import re
import shlex
import sys
import types
import warnings
from functools import lru_cache
+from pathlib import Path
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Generator
+from typing import IO
+from typing import Iterable
+from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
+from typing import TextIO
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import Union
import attr
import py
-from packaging.version import Version
from pluggy import HookimplMarker
from pluggy import HookspecMarker
from pluggy import PluginManager
import _pytest._code
import _pytest.deprecated
-import _pytest.hookspec # the extension point definitions
-from .exceptions import PrintHelp
-from .exceptions import UsageError
+import _pytest.hookspec
+from .exceptions import PrintHelp as PrintHelp
+from .exceptions import UsageError as UsageError
from .findpaths import determine_setup
-from .findpaths import exists
from _pytest._code import ExceptionInfo
from _pytest._code import filter_traceback
from _pytest._io import TerminalWriter
+from _pytest.compat import final
from _pytest.compat import importlib_metadata
-from _pytest.compat import TYPE_CHECKING
from _pytest.outcomes import fail
from _pytest.outcomes import Skipped
-from _pytest.pathlib import Path
+from _pytest.pathlib import bestrelpath
+from _pytest.pathlib import import_path
+from _pytest.pathlib import ImportMode
from _pytest.store import Store
from _pytest.warning_types import PytestConfigWarning
if TYPE_CHECKING:
- from typing import Type
+ from _pytest._code.code import _TracebackStyle
+ from _pytest.terminal import TerminalReporter
from .argparsing import Argument
_PluggyPlugin = object
"""A type to represent plugin objects.
+
Plugins can be any namespace, so we can't narrow it down much, but we use an
alias to make the intent clear.
-Ideally this type would be provided by pluggy itself."""
+
+Ideally this type would be provided by pluggy itself.
+"""
hookimpl = HookimplMarker("pytest")
hookspec = HookspecMarker("pytest")
+@final
class ExitCode(enum.IntEnum):
- """
- .. versionadded:: 5.0
-
- Encodes the valid exit codes by pytest.
+ """Encodes the valid exit codes by pytest.
Currently users and plugins may supply other exit codes as well.
+
+ .. versionadded:: 5.0
"""
- #: tests passed
+ #: Tests passed.
OK = 0
- #: tests failed
+ #: Tests failed.
TESTS_FAILED = 1
- #: pytest was interrupted
+ #: pytest was interrupted.
INTERRUPTED = 2
- #: an internal error got in the way
+ #: An internal error got in the way.
INTERNAL_ERROR = 3
- #: pytest was misused
+ #: pytest was misused.
USAGE_ERROR = 4
- #: pytest couldn't find tests
+ #: pytest couldn't find tests.
NO_TESTS_COLLECTED = 5
class ConftestImportFailure(Exception):
- def __init__(self, path, excinfo):
- Exception.__init__(self, path, excinfo)
+ def __init__(
+ self,
+ path: py.path.local,
+ excinfo: Tuple[Type[Exception], Exception, TracebackType],
+ ) -> None:
+ super().__init__(path, excinfo)
self.path = path
- self.excinfo = excinfo # type: Tuple[Type[Exception], Exception, TracebackType]
+ self.excinfo = excinfo
+ def __str__(self) -> str:
+ return "{}: {} (from {})".format(
+ self.excinfo[0].__name__, self.excinfo[1], self.path
+ )
-def main(args=None, plugins=None) -> Union[int, ExitCode]:
- """ return exit code, after performing an in-process test run.
- :arg args: list of command line arguments.
+def filter_traceback_for_conftest_import_failure(
+ entry: _pytest._code.TracebackEntry,
+) -> bool:
+ """Filter tracebacks entries which point to pytest internals or importlib.
- :arg plugins: list of plugin objects to be auto-registered during
- initialization.
+ Make a special case for importlib because we use it to import test modules and conftest files
+ in _pytest.pathlib.import_path.
+ """
+ return filter_traceback(entry) and "importlib" not in str(entry.path).split(os.sep)
+
+
+def main(
+ args: Optional[Union[List[str], py.path.local]] = None,
+ plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
+) -> Union[int, ExitCode]:
+ """Perform an in-process test run.
+
+ :param args: List of command line arguments.
+ :param plugins: List of plugin objects to be auto-registered during initialization.
+
+ :returns: An exit code.
"""
try:
try:
@@ -106,10 +144,10 @@ def main(args=None, plugins=None) -> Union[int, ExitCode]:
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
- tw.line(
- "ImportError while loading conftest '{e.path}'.".format(e=e), red=True
+ tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
+ exc_info.traceback = exc_info.traceback.filter(
+ filter_traceback_for_conftest_import_failure
)
- exc_info.traceback = exc_info.traceback.filter(filter_traceback)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
@@ -121,9 +159,9 @@ def main(args=None, plugins=None) -> Union[int, ExitCode]:
return ExitCode.USAGE_ERROR
else:
try:
- ret = config.hook.pytest_cmdline_main(
+ ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
- ) # type: Union[ExitCode, int]
+ )
try:
return ExitCode(ret)
except ValueError:
@@ -133,33 +171,51 @@ def main(args=None, plugins=None) -> Union[int, ExitCode]:
except UsageError as e:
tw = TerminalWriter(sys.stderr)
for msg in e.args:
- tw.line("ERROR: {}\n".format(msg), red=True)
+ tw.line(f"ERROR: {msg}\n", red=True)
return ExitCode.USAGE_ERROR
+def console_main() -> int:
+ """The CLI entry point of pytest.
+
+ This function is not meant for programmable use; use `main()` instead.
+ """
+ # https://docs.python.org/3/library/signal.html#note-on-sigpipe
+ try:
+ code = main()
+ sys.stdout.flush()
+ return code
+ except BrokenPipeError:
+ # Python flushes standard streams on exit; redirect remaining output
+ # to devnull to avoid another BrokenPipeError at shutdown
+ devnull = os.open(os.devnull, os.O_WRONLY)
+ os.dup2(devnull, sys.stdout.fileno())
+ return 1 # Python exits with error code 1 on EPIPE
+
+
class cmdline: # compatibility namespace
main = staticmethod(main)
-def filename_arg(path, optname):
- """ Argparse type validator for filename arguments.
+def filename_arg(path: str, optname: str) -> str:
+ """Argparse type validator for filename arguments.
- :path: path of filename
- :optname: name of the option
+ :path: Path of filename.
+ :optname: Name of the option.
"""
if os.path.isdir(path):
- raise UsageError("{} must be a filename, given: {}".format(optname, path))
+ raise UsageError(f"{optname} must be a filename, given: {path}")
return path
-def directory_arg(path, optname):
+def directory_arg(path: str, optname: str) -> str:
"""Argparse type validator for directory arguments.
- :path: path of directory
- :optname: name of the option
+ :path: Path of directory.
+ :optname: Name of the option.
"""
if not os.path.isdir(path):
- raise UsageError("{} must be a directory, given: {}".format(optname, path))
+ raise UsageError(f"{optname} must be a directory, given: {path}")
return path
@@ -186,7 +242,6 @@ default_plugins = essential_plugins + (
"nose",
"assertion",
"junitxml",
- "resultlog",
"doctest",
"cacheprovider",
"freeze_support",
@@ -196,20 +251,25 @@ default_plugins = essential_plugins + (
"warnings",
"logging",
"reports",
+ *(["unraisableexception", "threadexception"] if sys.version_info >= (3, 8) else []),
"faulthandler",
)
builtin_plugins = set(default_plugins)
builtin_plugins.add("pytester")
+builtin_plugins.add("pytester_assertions")
-def get_config(args=None, plugins=None):
+def get_config(
+ args: Optional[List[str]] = None,
+ plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
+) -> "Config":
# subsequent calls to main will create a fresh instance
pluginmanager = PytestPluginManager()
config = Config(
pluginmanager,
invocation_params=Config.InvocationParams(
- args=args or (), plugins=plugins, dir=Path().resolve()
+ args=args or (), plugins=plugins, dir=Path.cwd(),
),
)
@@ -219,12 +279,12 @@ def get_config(args=None, plugins=None):
for spec in default_plugins:
pluginmanager.import_plugin(spec)
+
return config
-def get_plugin_manager():
- """
- Obtain a new instance of the
+def get_plugin_manager() -> "PytestPluginManager":
+ """Obtain a new instance of the
:py:class:`_pytest.config.PytestPluginManager`, with default plugins
already loaded.
@@ -235,8 +295,9 @@ def get_plugin_manager():
def _prepareconfig(
- args: Optional[Union[py.path.local, List[str]]] = None, plugins=None
-):
+ args: Optional[Union[py.path.local, List[str]]] = None,
+ plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
+) -> "Config":
if args is None:
args = sys.argv[1:]
elif isinstance(args, py.path.local):
@@ -254,61 +315,55 @@ def _prepareconfig(
pluginmanager.consider_pluginarg(plugin)
else:
pluginmanager.register(plugin)
- return pluginmanager.hook.pytest_cmdline_parse(
+ config = pluginmanager.hook.pytest_cmdline_parse(
pluginmanager=pluginmanager, args=args
)
+ return config
except BaseException:
config._ensure_unconfigure()
raise
-def _fail_on_non_top_pytest_plugins(conftestpath, confcutdir):
- msg = (
- "Defining 'pytest_plugins' in a non-top-level conftest is no longer supported:\n"
- "It affects the entire test suite instead of just below the conftest as expected.\n"
- " {}\n"
- "Please move it to a top level conftest file at the rootdir:\n"
- " {}\n"
- "For more information, visit:\n"
- " https://docs.pytest.org/en/latest/deprecations.html#pytest-plugins-in-non-top-level-conftest-files"
- )
- fail(msg.format(conftestpath, confcutdir), pytrace=False)
-
-
+@final
class PytestPluginManager(PluginManager):
- """
- Overwrites :py:class:`pluggy.PluginManager <pluggy.PluginManager>` to add pytest-specific
- functionality:
+ """A :py:class:`pluggy.PluginManager <pluggy.PluginManager>` with
+ additional pytest-specific functionality:
- * loading plugins from the command line, ``PYTEST_PLUGINS`` env variable and
- ``pytest_plugins`` global variables found in plugins being loaded;
- * ``conftest.py`` loading during start-up;
+ * Loading plugins from the command line, ``PYTEST_PLUGINS`` env variable and
+ ``pytest_plugins`` global variables found in plugins being loaded.
+ * ``conftest.py`` loading during start-up.
"""
- def __init__(self):
+ def __init__(self) -> None:
import _pytest.assertion
super().__init__("pytest")
# The objects are module objects, only used generically.
- self._conftest_plugins = set() # type: Set[object]
-
- # state related to local conftest plugins
- # Maps a py.path.local to a list of module objects.
- self._dirpath2confmods = {} # type: Dict[Any, List[object]]
- # Maps a py.path.local to a module object.
- self._conftestpath2mod = {} # type: Dict[Any, object]
- self._confcutdir = None
+ self._conftest_plugins: Set[types.ModuleType] = set()
+
+ # State related to local conftest plugins.
+ self._dirpath2confmods: Dict[py.path.local, List[types.ModuleType]] = {}
+ self._conftestpath2mod: Dict[Path, types.ModuleType] = {}
+ self._confcutdir: Optional[py.path.local] = None
self._noconftest = False
- # Set of py.path.local's.
- self._duplicatepaths = set() # type: Set[Any]
+ self._duplicatepaths: Set[py.path.local] = set()
+
+ # plugins that were explicitly skipped with pytest.skip
+ # list of (module name, skip reason)
+ # previously we would issue a warning when a plugin was skipped, but
+ # since we refactored warnings as first citizens of Config, they are
+ # just stored here to be used later.
+ self.skipped_plugins: List[Tuple[str, str]] = []
self.add_hookspecs(_pytest.hookspec)
self.register(self)
if os.environ.get("PYTEST_DEBUG"):
- err = sys.stderr
- encoding = getattr(err, "encoding", "utf8")
+ err: IO[str] = sys.stderr
+ encoding: str = getattr(err, "encoding", "utf8")
try:
- err = py.io.dupfile(err, encoding=encoding)
+ err = open(
+ os.dup(err.fileno()), mode=err.mode, buffering=1, encoding=encoding,
+ )
except Exception:
pass
self.trace.root.setwriter(err.write)
@@ -316,27 +371,27 @@ class PytestPluginManager(PluginManager):
# Config._consider_importhook will set a real object if required.
self.rewrite_hook = _pytest.assertion.DummyRewriteHook()
- # Used to know when we are importing conftests after the pytest_configure stage
+ # Used to know when we are importing conftests after the pytest_configure stage.
self._configured = False
- def parse_hookimpl_opts(self, plugin, name):
- # pytest hooks are always prefixed with pytest_
+ def parse_hookimpl_opts(self, plugin: _PluggyPlugin, name: str):
+ # pytest hooks are always prefixed with "pytest_",
# so we avoid accessing possibly non-readable attributes
- # (see issue #1073)
+ # (see issue #1073).
if not name.startswith("pytest_"):
return
- # ignore names which can not be hooks
+ # Ignore names which can not be hooks.
if name == "pytest_plugins":
return
method = getattr(plugin, name)
opts = super().parse_hookimpl_opts(plugin, name)
- # consider only actual functions for hooks (#3775)
+ # Consider only actual functions for hooks (#3775).
if not inspect.isroutine(method):
return
- # collect unmarked hooks as long as they have the `pytest_' prefix
+ # Collect unmarked hooks as long as they have the `pytest_' prefix.
if opts is None and name.startswith("pytest_"):
opts = {}
if opts is not None:
@@ -348,7 +403,7 @@ class PytestPluginManager(PluginManager):
opts.setdefault(name, hasattr(method, name) or name in known_marks)
return opts
- def parse_hookspec_opts(self, module_or_class, name):
+ def parse_hookspec_opts(self, module_or_class, name: str):
opts = super().parse_hookspec_opts(module_or_class, name)
if opts is None:
method = getattr(module_or_class, name)
@@ -365,7 +420,9 @@ class PytestPluginManager(PluginManager):
}
return opts
- def register(self, plugin, name=None):
+ def register(
+ self, plugin: _PluggyPlugin, name: Optional[str] = None
+ ) -> Optional[str]:
if name in _pytest.deprecated.DEPRECATED_EXTERNAL_PLUGINS:
warnings.warn(
PytestConfigWarning(
@@ -375,8 +432,8 @@ class PytestPluginManager(PluginManager):
)
)
)
- return
- ret = super().register(plugin, name)
+ return None
+ ret: Optional[str] = super().register(plugin, name)
if ret:
self.hook.pytest_plugin_registered.call_historic(
kwargs=dict(plugin=plugin, manager=self)
@@ -386,17 +443,19 @@ class PytestPluginManager(PluginManager):
self.consider_module(plugin)
return ret
- def getplugin(self, name):
- # support deprecated naming because plugins (xdist e.g.) use it
- return self.get_plugin(name)
+ def getplugin(self, name: str):
+ # Support deprecated naming because plugins (xdist e.g.) use it.
+ plugin: Optional[_PluggyPlugin] = self.get_plugin(name)
+ return plugin
- def hasplugin(self, name):
- """Return True if the plugin with the given name is registered."""
+ def hasplugin(self, name: str) -> bool:
+ """Return whether a plugin with the given name is registered."""
return bool(self.get_plugin(name))
- def pytest_configure(self, config):
+ def pytest_configure(self, config: "Config") -> None:
+ """:meta private:"""
# XXX now that the pluginmanager exposes hookimpl(tryfirst...)
- # we should remove tryfirst/trylast as markers
+ # we should remove tryfirst/trylast as markers.
config.addinivalue_line(
"markers",
"tryfirst: mark a hook implementation function such that the "
@@ -410,15 +469,15 @@ class PytestPluginManager(PluginManager):
self._configured = True
#
- # internal API for local conftest plugin handling
+ # Internal API for local conftest plugin handling.
#
- def _set_initial_conftests(self, namespace):
- """ load initial conftest files given a preparsed "namespace".
- As conftest files may add their own command line options
- which have arguments ('--my-opt somepath') we might get some
- false positives. All builtin and 3rd party plugins will have
- been loaded, however, so common options will not confuse our logic
- here.
+ def _set_initial_conftests(self, namespace: argparse.Namespace) -> None:
+ """Load initial conftest files given a preparsed "namespace".
+
+ As conftest files may add their own command line options which have
+ arguments ('--my-opt somepath') we might get some false positives.
+ All builtin and 3rd party plugins will have been loaded, however, so
+ common options will not confuse our logic here.
"""
current = py.path.local()
self._confcutdir = (
@@ -430,29 +489,33 @@ class PytestPluginManager(PluginManager):
self._using_pyargs = namespace.pyargs
testpaths = namespace.file_or_dir
foundanchor = False
- for path in testpaths:
- path = str(path)
+ for testpath in testpaths:
+ path = str(testpath)
# remove node-id syntax
i = path.find("::")
if i != -1:
path = path[:i]
anchor = current.join(path, abs=1)
- if exists(anchor): # we found some file object
- self._try_load_conftest(anchor)
+ if anchor.exists(): # we found some file object
+ self._try_load_conftest(anchor, namespace.importmode)
foundanchor = True
if not foundanchor:
- self._try_load_conftest(current)
+ self._try_load_conftest(current, namespace.importmode)
- def _try_load_conftest(self, anchor):
- self._getconftestmodules(anchor)
+ def _try_load_conftest(
+ self, anchor: py.path.local, importmode: Union[str, ImportMode]
+ ) -> None:
+ self._getconftestmodules(anchor, importmode)
# let's also consider test* subdirs
if anchor.check(dir=1):
for x in anchor.listdir("test*"):
if x.check(dir=1):
- self._getconftestmodules(x)
+ self._getconftestmodules(x, importmode)
@lru_cache(maxsize=128)
- def _getconftestmodules(self, path):
+ def _getconftestmodules(
+ self, path: py.path.local, importmode: Union[str, ImportMode],
+ ) -> List[types.ModuleType]:
if self._noconftest:
return []
@@ -461,22 +524,24 @@ class PytestPluginManager(PluginManager):
else:
directory = path
- # XXX these days we may rather want to use config.rootdir
+ # XXX these days we may rather want to use config.rootpath
# and allow users to opt into looking into the rootdir parent
- # directories instead of requiring to specify confcutdir
+ # directories instead of requiring to specify confcutdir.
clist = []
- for parent in directory.realpath().parts():
+ for parent in directory.parts():
if self._confcutdir and self._confcutdir.relto(parent):
continue
conftestpath = parent.join("conftest.py")
if conftestpath.isfile():
- mod = self._importconftest(conftestpath)
+ mod = self._importconftest(conftestpath, importmode)
clist.append(mod)
self._dirpath2confmods[directory] = clist
return clist
- def _rget_with_confmod(self, name, path):
- modules = self._getconftestmodules(path)
+ def _rget_with_confmod(
+ self, name: str, path: py.path.local, importmode: Union[str, ImportMode],
+ ) -> Tuple[types.ModuleType, Any]:
+ modules = self._getconftestmodules(path, importmode)
for mod in reversed(modules):
try:
return mod, getattr(mod, name)
@@ -484,48 +549,71 @@ class PytestPluginManager(PluginManager):
continue
raise KeyError(name)
- def _importconftest(self, conftestpath):
- # Use a resolved Path object as key to avoid loading the same conftest twice
- # with build systems that create build directories containing
+ def _importconftest(
+ self, conftestpath: py.path.local, importmode: Union[str, ImportMode],
+ ) -> types.ModuleType:
+ # Use a resolved Path object as key to avoid loading the same conftest
+ # twice with build systems that create build directories containing
# symlinks to actual files.
# Using Path().resolve() is better than py.path.realpath because
# it resolves to the correct path/drive in case-insensitive file systems (#5792)
key = Path(str(conftestpath)).resolve()
- try:
+
+ with contextlib.suppress(KeyError):
return self._conftestpath2mod[key]
- except KeyError:
- pkgpath = conftestpath.pypkgpath()
- if pkgpath is None:
- _ensure_removed_sysmodule(conftestpath.purebasename)
- try:
- mod = conftestpath.pyimport()
- if (
- hasattr(mod, "pytest_plugins")
- and self._configured
- and not self._using_pyargs
- ):
- _fail_on_non_top_pytest_plugins(conftestpath, self._confcutdir)
- except Exception:
- raise ConftestImportFailure(conftestpath, sys.exc_info())
-
- self._conftest_plugins.add(mod)
- self._conftestpath2mod[key] = mod
- dirpath = conftestpath.dirpath()
- if dirpath in self._dirpath2confmods:
- for path, mods in self._dirpath2confmods.items():
- if path and path.relto(dirpath) or path == dirpath:
- assert mod not in mods
- mods.append(mod)
- self.trace("loading conftestmodule {!r}".format(mod))
- self.consider_conftest(mod)
- return mod
+
+ pkgpath = conftestpath.pypkgpath()
+ if pkgpath is None:
+ _ensure_removed_sysmodule(conftestpath.purebasename)
+
+ try:
+ mod = import_path(conftestpath, mode=importmode)
+ except Exception as e:
+ assert e.__traceback__ is not None
+ exc_info = (type(e), e, e.__traceback__)
+ raise ConftestImportFailure(conftestpath, exc_info) from e
+
+ self._check_non_top_pytest_plugins(mod, conftestpath)
+
+ self._conftest_plugins.add(mod)
+ self._conftestpath2mod[key] = mod
+ dirpath = conftestpath.dirpath()
+ if dirpath in self._dirpath2confmods:
+ for path, mods in self._dirpath2confmods.items():
+ if path and path.relto(dirpath) or path == dirpath:
+ assert mod not in mods
+ mods.append(mod)
+ self.trace(f"loading conftestmodule {mod!r}")
+ self.consider_conftest(mod)
+ return mod
+
+ def _check_non_top_pytest_plugins(
+ self, mod: types.ModuleType, conftestpath: py.path.local,
+ ) -> None:
+ if (
+ hasattr(mod, "pytest_plugins")
+ and self._configured
+ and not self._using_pyargs
+ ):
+ msg = (
+ "Defining 'pytest_plugins' in a non-top-level conftest is no longer supported:\n"
+ "It affects the entire test suite instead of just below the conftest as expected.\n"
+ " {}\n"
+ "Please move it to a top level conftest file at the rootdir:\n"
+ " {}\n"
+ "For more information, visit:\n"
+ " https://docs.pytest.org/en/stable/deprecations.html#pytest-plugins-in-non-top-level-conftest-files"
+ )
+ fail(msg.format(conftestpath, self._confcutdir), pytrace=False)
#
# API for bootstrapping plugin loading
#
#
- def consider_preparse(self, args, *, exclude_only=False):
+ def consider_preparse(
+ self, args: Sequence[str], *, exclude_only: bool = False
+ ) -> None:
i = 0
n = len(args)
while i < n:
@@ -546,13 +634,13 @@ class PytestPluginManager(PluginManager):
continue
self.consider_pluginarg(parg)
- def consider_pluginarg(self, arg):
+ def consider_pluginarg(self, arg: str) -> None:
if arg.startswith("no:"):
name = arg[3:]
if name in essential_plugins:
raise UsageError("plugin %s cannot be disabled" % name)
- # PR #4304 : remove stepwise if cacheprovider is blocked
+ # PR #4304: remove stepwise if cacheprovider is blocked.
if name == "cacheprovider":
self.set_blocked("stepwise")
self.set_blocked("pytest_stepwise")
@@ -571,33 +659,35 @@ class PytestPluginManager(PluginManager):
del self._name2plugin["pytest_" + name]
self.import_plugin(arg, consider_entry_points=True)
- def consider_conftest(self, conftestmodule):
+ def consider_conftest(self, conftestmodule: types.ModuleType) -> None:
self.register(conftestmodule, name=conftestmodule.__file__)
- def consider_env(self):
+ def consider_env(self) -> None:
self._import_plugin_specs(os.environ.get("PYTEST_PLUGINS"))
- def consider_module(self, mod):
+ def consider_module(self, mod: types.ModuleType) -> None:
self._import_plugin_specs(getattr(mod, "pytest_plugins", []))
- def _import_plugin_specs(self, spec):
+ def _import_plugin_specs(
+ self, spec: Union[None, types.ModuleType, str, Sequence[str]]
+ ) -> None:
plugins = _get_plugin_specs_as_list(spec)
for import_spec in plugins:
self.import_plugin(import_spec)
- def import_plugin(self, modname, consider_entry_points=False):
- """
- Imports a plugin with ``modname``. If ``consider_entry_points`` is True, entry point
- names are also considered to find a plugin.
+ def import_plugin(self, modname: str, consider_entry_points: bool = False) -> None:
+ """Import a plugin with ``modname``.
+
+ If ``consider_entry_points`` is True, entry point names are also
+ considered to find a plugin.
"""
- # most often modname refers to builtin modules, e.g. "pytester",
+ # Most often modname refers to builtin modules, e.g. "pytester",
# "terminal" or "capture". Those plugins are registered under their
# basename for historic purposes but must be imported with the
# _pytest prefix.
assert isinstance(modname, str), (
"module name as text required, got %r" % modname
)
- modname = str(modname)
if self.is_blocked(modname) or self.get_plugin(modname) is not None:
return
@@ -614,42 +704,38 @@ class PytestPluginManager(PluginManager):
except ImportError as e:
raise ImportError(
'Error importing plugin "{}": {}'.format(modname, str(e.args[0]))
- ).with_traceback(e.__traceback__)
+ ).with_traceback(e.__traceback__) from e
except Skipped as e:
- from _pytest.warnings import _issue_warning_captured
-
- _issue_warning_captured(
- PytestConfigWarning("skipped plugin {!r}: {}".format(modname, e.msg)),
- self.hook,
- stacklevel=2,
- )
+ self.skipped_plugins.append((modname, e.msg or ""))
else:
mod = sys.modules[importspec]
self.register(mod, modname)
-def _get_plugin_specs_as_list(specs):
- """
- Parses a list of "plugin specs" and returns a list of plugin names.
-
- Plugin specs can be given as a list of strings separated by "," or already as a list/tuple in
- which case it is returned as a list. Specs can also be `None` in which case an
- empty list is returned.
- """
- if specs is not None and not isinstance(specs, types.ModuleType):
- if isinstance(specs, str):
- specs = specs.split(",") if specs else []
- if not isinstance(specs, (list, tuple)):
- raise UsageError(
- "Plugin specs must be a ','-separated string or a "
- "list/tuple of strings for plugin names. Given: %r" % specs
- )
+def _get_plugin_specs_as_list(
+ specs: Union[None, types.ModuleType, str, Sequence[str]]
+) -> List[str]:
+ """Parse a plugins specification into a list of plugin names."""
+ # None means empty.
+ if specs is None:
+ return []
+ # Workaround for #3899 - a submodule which happens to be called "pytest_plugins".
+ if isinstance(specs, types.ModuleType):
+ return []
+ # Comma-separated list.
+ if isinstance(specs, str):
+ return specs.split(",") if specs else []
+ # Direct specification.
+ if isinstance(specs, collections.abc.Sequence):
return list(specs)
- return []
+ raise UsageError(
+ "Plugins may be specified as a sequence or a ','-separated string of plugin names. Got: %r"
+ % specs
+ )
-def _ensure_removed_sysmodule(modname):
+def _ensure_removed_sysmodule(modname: str) -> None:
try:
del sys.modules[modname]
except KeyError:
@@ -664,11 +750,12 @@ class Notset:
notset = Notset()
-def _iter_rewritable_modules(package_files):
- """
- Given an iterable of file names in a source distribution, return the "names" that should
- be marked for assertion rewrite (for example the package "pytest_mock/__init__.py" should
- be added as "pytest_mock" in the assertion rewrite mechanism.
+def _iter_rewritable_modules(package_files: Iterable[str]) -> Iterator[str]:
+ """Given an iterable of file names in a source distribution, return the "names" that should
+ be marked for assertion rewrite.
+
+ For example the package "pytest_mock/__init__.py" should be added as "pytest_mock" in
+ the assertion rewrite mechanism.
This function has to deal with dist-info based distributions and egg based distributions
(which are still very much in use for "editable" installs).
@@ -712,11 +799,11 @@ def _iter_rewritable_modules(package_files):
yield package_name
if not seen_some:
- # at this point we did not find any packages or modules suitable for assertion
+ # At this point we did not find any packages or modules suitable for assertion
# rewriting, so we try again by stripping the first path component (to account for
- # "src" based source trees for example)
- # this approach lets us have the common case continue to be fast, as egg-distributions
- # are rarer
+ # "src" based source trees for example).
+ # This approach lets us have the common case continue to be fast, as egg-distributions
+ # are rarer.
new_package_files = []
for fn in package_files:
parts = fn.split("/")
@@ -727,29 +814,27 @@ def _iter_rewritable_modules(package_files):
yield from _iter_rewritable_modules(new_package_files)
-class Config:
- """
- Access to configuration values, pluginmanager and plugin hooks.
+def _args_converter(args: Iterable[str]) -> Tuple[str, ...]:
+ return tuple(args)
- :ivar PytestPluginManager pluginmanager: the plugin manager handles plugin registration and hook invocation.
- :ivar argparse.Namespace option: access to command line option as attributes.
+@final
+class Config:
+ """Access to configuration values, pluginmanager and plugin hooks.
- :ivar InvocationParams invocation_params:
+ :param PytestPluginManager pluginmanager:
- Object containing the parameters regarding the ``pytest.main``
+ :param InvocationParams invocation_params:
+ Object containing parameters regarding the :func:`pytest.main`
invocation.
-
- Contains the following read-only attributes:
-
- * ``args``: tuple of command-line arguments as passed to ``pytest.main()``.
- * ``plugins``: list of extra plugins, might be None.
- * ``dir``: directory where ``pytest.main()`` was invoked from.
"""
+ @final
@attr.s(frozen=True)
class InvocationParams:
- """Holds parameters passed during ``pytest.main()``
+ """Holds parameters passed during :func:`pytest.main`.
+
+ The object attributes are read-only.
.. versionadded:: 5.1
@@ -761,33 +846,64 @@ class Config:
Plugins accessing ``InvocationParams`` must be aware of that.
"""
- args = attr.ib(converter=tuple)
- plugins = attr.ib()
+ args = attr.ib(type=Tuple[str, ...], converter=_args_converter)
+ """The command-line arguments as passed to :func:`pytest.main`.
+
+ :type: Tuple[str, ...]
+ """
+ plugins = attr.ib(type=Optional[Sequence[Union[str, _PluggyPlugin]]])
+ """Extra plugins, might be `None`.
+
+ :type: Optional[Sequence[Union[str, plugin]]]
+ """
dir = attr.ib(type=Path)
+ """The directory from which :func:`pytest.main` was invoked.
- def __init__(self, pluginmanager, *, invocation_params=None) -> None:
+ :type: pathlib.Path
+ """
+
+ def __init__(
+ self,
+ pluginmanager: PytestPluginManager,
+ *,
+ invocation_params: Optional[InvocationParams] = None,
+ ) -> None:
from .argparsing import Parser, FILE_OR_DIR
if invocation_params is None:
invocation_params = self.InvocationParams(
- args=(), plugins=None, dir=Path().resolve()
+ args=(), plugins=None, dir=Path.cwd()
)
self.option = argparse.Namespace()
+ """Access to command line option as attributes.
+
+ :type: argparse.Namespace
+ """
+
self.invocation_params = invocation_params
+ """The parameters with which pytest was invoked.
+
+ :type: InvocationParams
+ """
_a = FILE_OR_DIR
self._parser = Parser(
- usage="%(prog)s [options] [{}] [{}] [...]".format(_a, _a),
+ usage=f"%(prog)s [options] [{_a}] [{_a}] [...]",
processopt=self._processopt,
)
self.pluginmanager = pluginmanager
+ """The plugin manager handles plugin registration and hook invocation.
+
+ :type: PytestPluginManager
+ """
+
self.trace = self.pluginmanager.trace.root.get("config")
self.hook = self.pluginmanager.hook
- self._inicache = {} # type: Dict[str, Any]
- self._override_ini = () # type: Sequence[str]
- self._opt2dest = {} # type: Dict[str, str]
- self._cleanup = [] # type: List[Callable[[], None]]
+ self._inicache: Dict[str, Any] = {}
+ self._override_ini: Sequence[str] = ()
+ self._opt2dest: Dict[str, str] = {}
+ self._cleanup: List[Callable[[], None]] = []
# A place where plugins can store information on the config for their
# own use. Currently only intended for internal plugins.
self._store = Store()
@@ -800,26 +916,72 @@ class Config:
if TYPE_CHECKING:
from _pytest.cacheprovider import Cache
- self.cache = None # type: Optional[Cache]
+ self.cache: Optional[Cache] = None
@property
- def invocation_dir(self):
- """Backward compatibility"""
+ def invocation_dir(self) -> py.path.local:
+ """The directory from which pytest was invoked.
+
+ Prefer to use :attr:`invocation_params.dir <InvocationParams.dir>`,
+ which is a :class:`pathlib.Path`.
+
+ :type: py.path.local
+ """
return py.path.local(str(self.invocation_params.dir))
- def add_cleanup(self, func):
- """ Add a function to be called when the config object gets out of
+ @property
+ def rootpath(self) -> Path:
+ """The path to the :ref:`rootdir <rootdir>`.
+
+ :type: pathlib.Path
+
+ .. versionadded:: 6.1
+ """
+ return self._rootpath
+
+ @property
+ def rootdir(self) -> py.path.local:
+ """The path to the :ref:`rootdir <rootdir>`.
+
+ Prefer to use :attr:`rootpath`, which is a :class:`pathlib.Path`.
+
+ :type: py.path.local
+ """
+ return py.path.local(str(self.rootpath))
+
+ @property
+ def inipath(self) -> Optional[Path]:
+ """The path to the :ref:`configfile <configfiles>`.
+
+ :type: Optional[pathlib.Path]
+
+ .. versionadded:: 6.1
+ """
+ return self._inipath
+
+ @property
+ def inifile(self) -> Optional[py.path.local]:
+ """The path to the :ref:`configfile <configfiles>`.
+
+ Prefer to use :attr:`inipath`, which is a :class:`pathlib.Path`.
+
+ :type: Optional[py.path.local]
+ """
+ return py.path.local(str(self.inipath)) if self.inipath else None
+
+ def add_cleanup(self, func: Callable[[], None]) -> None:
+ """Add a function to be called when the config object gets out of
use (usually coninciding with pytest_unconfigure)."""
self._cleanup.append(func)
- def _do_configure(self):
+ def _do_configure(self) -> None:
assert not self._configured
self._configured = True
with warnings.catch_warnings():
warnings.simplefilter("default")
self.hook.pytest_configure.call_historic(kwargs=dict(config=self))
- def _ensure_unconfigure(self):
+ def _ensure_unconfigure(self) -> None:
if self._configured:
self._configured = False
self.hook.pytest_unconfigure(config=self)
@@ -828,10 +990,15 @@ class Config:
fin = self._cleanup.pop()
fin()
- def get_terminal_writer(self):
- return self.pluginmanager.get_plugin("terminalreporter")._tw
+ def get_terminal_writer(self) -> TerminalWriter:
+ terminalreporter: TerminalReporter = self.pluginmanager.get_plugin(
+ "terminalreporter"
+ )
+ return terminalreporter._tw
- def pytest_cmdline_parse(self, pluginmanager, args):
+ def pytest_cmdline_parse(
+ self, pluginmanager: PytestPluginManager, args: List[str]
+ ) -> "Config":
try:
self.parse(args)
except UsageError:
@@ -855,9 +1022,13 @@ class Config:
return self
- def notify_exception(self, excinfo, option=None):
+ def notify_exception(
+ self,
+ excinfo: ExceptionInfo[BaseException],
+ option: Optional[argparse.Namespace] = None,
+ ) -> None:
if option and getattr(option, "fulltrace", False):
- style = "long"
+ style: _TracebackStyle = "long"
else:
style = "native"
excrepr = excinfo.getrepr(
@@ -869,16 +1040,16 @@ class Config:
sys.stderr.write("INTERNALERROR> %s\n" % line)
sys.stderr.flush()
- def cwd_relative_nodeid(self, nodeid):
- # nodeid's are relative to the rootpath, compute relative to cwd
- if self.invocation_dir != self.rootdir:
- fullpath = self.rootdir.join(nodeid)
- nodeid = self.invocation_dir.bestrelpath(fullpath)
+ def cwd_relative_nodeid(self, nodeid: str) -> str:
+ # nodeid's are relative to the rootpath, compute relative to cwd.
+ if self.invocation_params.dir != self.rootpath:
+ fullpath = self.rootpath / nodeid
+ nodeid = bestrelpath(self.invocation_params.dir, fullpath)
return nodeid
@classmethod
- def fromdictargs(cls, option_dict, args):
- """ constructor usable for subprocesses. """
+ def fromdictargs(cls, option_dict, args) -> "Config":
+ """Constructor usable for subprocesses."""
config = get_config(args)
config.option.__dict__.update(option_dict)
config.parse(args, addopts=False)
@@ -895,24 +1066,32 @@ class Config:
setattr(self.option, opt.dest, opt.default)
@hookimpl(trylast=True)
- def pytest_load_initial_conftests(self, early_config):
+ def pytest_load_initial_conftests(self, early_config: "Config") -> None:
self.pluginmanager._set_initial_conftests(early_config.known_args_namespace)
def _initini(self, args: Sequence[str]) -> None:
ns, unknown_args = self._parser.parse_known_and_unknown_args(
args, namespace=copy.copy(self.option)
)
- r = determine_setup(
+ rootpath, inipath, inicfg = determine_setup(
ns.inifilename,
ns.file_or_dir + unknown_args,
rootdir_cmd_arg=ns.rootdir or None,
config=self,
)
- self.rootdir, self.inifile, self.inicfg = r
- self._parser.extra_info["rootdir"] = self.rootdir
- self._parser.extra_info["inifile"] = self.inifile
+ self._rootpath = rootpath
+ self._inipath = inipath
+ self.inicfg = inicfg
+ self._parser.extra_info["rootdir"] = str(self.rootpath)
+ self._parser.extra_info["inifile"] = str(self.inipath)
self._parser.addini("addopts", "extra command line options", "args")
self._parser.addini("minversion", "minimally required pytest version")
+ self._parser.addini(
+ "required_plugins",
+ "plugins that must be present for pytest to run",
+ type="args",
+ default=[],
+ )
self._override_ini = ns.override_ini or ()
def _consider_importhook(self, args: Sequence[str]) -> None:
@@ -933,14 +1112,12 @@ class Config:
mode = "plain"
else:
self._mark_plugins_for_rewrite(hook)
- _warn_about_missing_assertion(mode)
+ self._warn_about_missing_assertion(mode)
- def _mark_plugins_for_rewrite(self, hook):
- """
- Given an importhook, mark for rewrite any top-level
+ def _mark_plugins_for_rewrite(self, hook) -> None:
+ """Given an importhook, mark for rewrite any top-level
modules or packages in the distribution package for
- all pytest plugins.
- """
+ all pytest plugins."""
self.pluginmanager.rewrite_hook = hook
if os.environ.get("PYTEST_DISABLE_PLUGIN_AUTOLOAD"):
@@ -983,6 +1160,9 @@ class Config:
self._validate_args(self.getini("addopts"), "via addopts config") + args
)
+ self.known_args_namespace = self._parser.parse_known_args(
+ args, namespace=copy.copy(self.option)
+ )
self._checkversion()
self._consider_importhook(args)
self.pluginmanager.consider_preparse(args, exclude_only=False)
@@ -991,50 +1171,109 @@ class Config:
# plugins are going to be loaded.
self.pluginmanager.load_setuptools_entrypoints("pytest11")
self.pluginmanager.consider_env()
- self.known_args_namespace = ns = self._parser.parse_known_args(
- args, namespace=copy.copy(self.option)
+
+ self.known_args_namespace = self._parser.parse_known_args(
+ args, namespace=copy.copy(self.known_args_namespace)
)
- if self.known_args_namespace.confcutdir is None and self.inifile:
- confcutdir = py.path.local(self.inifile).dirname
+
+ self._validate_plugins()
+ self._warn_about_skipped_plugins()
+
+ if self.known_args_namespace.strict:
+ self.issue_config_time_warning(
+ _pytest.deprecated.STRICT_OPTION, stacklevel=2
+ )
+
+ if self.known_args_namespace.confcutdir is None and self.inipath is not None:
+ confcutdir = str(self.inipath.parent)
self.known_args_namespace.confcutdir = confcutdir
try:
self.hook.pytest_load_initial_conftests(
early_config=self, args=args, parser=self._parser
)
except ConftestImportFailure as e:
- if ns.help or ns.version:
+ if self.known_args_namespace.help or self.known_args_namespace.version:
# we don't want to prevent --help/--version to work
# so just let is pass and print a warning at the end
- from _pytest.warnings import _issue_warning_captured
-
- _issue_warning_captured(
- PytestConfigWarning(
- "could not load initial conftests: {}".format(e.path)
- ),
- self.hook,
+ self.issue_config_time_warning(
+ PytestConfigWarning(f"could not load initial conftests: {e.path}"),
stacklevel=2,
)
else:
raise
- def _checkversion(self):
+ @hookimpl(hookwrapper=True)
+ def pytest_collection(self) -> Generator[None, None, None]:
+ """Validate invalid ini keys after collection is done so we take in account
+ options added by late-loading conftest files."""
+ yield
+ self._validate_config_options()
+
+ def _checkversion(self) -> None:
import pytest
minver = self.inicfg.get("minversion", None)
if minver:
+ # Imported lazily to improve start-up time.
+ from packaging.version import Version
+
+ if not isinstance(minver, str):
+ raise pytest.UsageError(
+ "%s: 'minversion' must be a single value" % self.inipath
+ )
+
if Version(minver) > Version(pytest.__version__):
raise pytest.UsageError(
- "%s:%d: requires pytest-%s, actual pytest-%s'"
- % (
- self.inicfg.config.path,
- self.inicfg.lineof("minversion"),
- minver,
- pytest.__version__,
- )
+ "%s: 'minversion' requires pytest-%s, actual pytest-%s'"
+ % (self.inipath, minver, pytest.__version__,)
)
+ def _validate_config_options(self) -> None:
+ for key in sorted(self._get_unknown_ini_keys()):
+ self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")
+
+ def _validate_plugins(self) -> None:
+ required_plugins = sorted(self.getini("required_plugins"))
+ if not required_plugins:
+ return
+
+ # Imported lazily to improve start-up time.
+ from packaging.version import Version
+ from packaging.requirements import InvalidRequirement, Requirement
+
+ plugin_info = self.pluginmanager.list_plugin_distinfo()
+ plugin_dist_info = {dist.project_name: dist.version for _, dist in plugin_info}
+
+ missing_plugins = []
+ for required_plugin in required_plugins:
+ try:
+ spec = Requirement(required_plugin)
+ except InvalidRequirement:
+ missing_plugins.append(required_plugin)
+ continue
+
+ if spec.name not in plugin_dist_info:
+ missing_plugins.append(required_plugin)
+ elif Version(plugin_dist_info[spec.name]) not in spec.specifier:
+ missing_plugins.append(required_plugin)
+
+ if missing_plugins:
+ raise UsageError(
+ "Missing required plugins: {}".format(", ".join(missing_plugins)),
+ )
+
+ def _warn_or_fail_if_strict(self, message: str) -> None:
+ if self.known_args_namespace.strict_config:
+ raise UsageError(message)
+
+ self.issue_config_time_warning(PytestConfigWarning(message), stacklevel=3)
+
+ def _get_unknown_ini_keys(self) -> List[str]:
+ parser_inicfg = self._parser._inidict
+ return [name for name in self.inicfg if name not in parser_inicfg]
+
def parse(self, args: List[str], addopts: bool = True) -> None:
- # parse given cmdline arguments into this config object.
+ # Parse given cmdline arguments into this config object.
assert not hasattr(
self, "args"
), "can only parse cmdline args at most once per Config object"
@@ -1050,40 +1289,85 @@ class Config:
args, self.option, namespace=self.option
)
if not args:
- if self.invocation_dir == self.rootdir:
+ if self.invocation_params.dir == self.rootpath:
args = self.getini("testpaths")
if not args:
- args = [str(self.invocation_dir)]
+ args = [str(self.invocation_params.dir)]
self.args = args
except PrintHelp:
pass
- def addinivalue_line(self, name, line):
- """ add a line to an ini-file option. The option must have been
- declared but might not yet be set in which case the line becomes the
- the first line in its value. """
+ def issue_config_time_warning(self, warning: Warning, stacklevel: int) -> None:
+ """Issue and handle a warning during the "configure" stage.
+
+ During ``pytest_configure`` we can't capture warnings using the ``catch_warnings_for_item``
+ function because it is not possible to have hookwrappers around ``pytest_configure``.
+
+ This function is mainly intended for plugins that need to issue warnings during
+ ``pytest_configure`` (or similar stages).
+
+ :param warning: The warning instance.
+ :param stacklevel: stacklevel forwarded to warnings.warn.
+ """
+ if self.pluginmanager.is_blocked("warnings"):
+ return
+
+ cmdline_filters = self.known_args_namespace.pythonwarnings or []
+ config_filters = self.getini("filterwarnings")
+
+ with warnings.catch_warnings(record=True) as records:
+ warnings.simplefilter("always", type(warning))
+ apply_warning_filters(config_filters, cmdline_filters)
+ warnings.warn(warning, stacklevel=stacklevel)
+
+ if records:
+ frame = sys._getframe(stacklevel - 1)
+ location = frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name
+ self.hook.pytest_warning_captured.call_historic(
+ kwargs=dict(
+ warning_message=records[0],
+ when="config",
+ item=None,
+ location=location,
+ )
+ )
+ self.hook.pytest_warning_recorded.call_historic(
+ kwargs=dict(
+ warning_message=records[0],
+ when="config",
+ nodeid="",
+ location=location,
+ )
+ )
+
+ def addinivalue_line(self, name: str, line: str) -> None:
+ """Add a line to an ini-file option. The option must have been
+ declared but might not yet be set in which case the line becomes
+ the first line in its value."""
x = self.getini(name)
assert isinstance(x, list)
x.append(line) # modifies the cached list inline
def getini(self, name: str):
- """ return configuration value from an :ref:`ini file <inifiles>`. If the
- specified name hasn't been registered through a prior
+ """Return configuration value from an :ref:`ini file <configfiles>`.
+
+ If the specified name hasn't been registered through a prior
:py:func:`parser.addini <_pytest.config.argparsing.Parser.addini>`
- call (usually from a plugin), a ValueError is raised. """
+ call (usually from a plugin), a ValueError is raised.
+ """
try:
return self._inicache[name]
except KeyError:
self._inicache[name] = val = self._getini(name)
return val
- def _getini(self, name: str) -> Any:
+ def _getini(self, name: str):
try:
description, type, default = self._parser._inidict[name]
- except KeyError:
- raise ValueError("unknown configuration value: {!r}".format(name))
- value = self._get_override_ini_value(name)
- if value is None:
+ except KeyError as e:
+ raise ValueError(f"unknown configuration value: {name!r}") from e
+ override_value = self._get_override_ini_value(name)
+ if override_value is None:
try:
value = self.inicfg[name]
except KeyError:
@@ -1092,62 +1376,86 @@ class Config:
if type is None:
return ""
return []
+ else:
+ value = override_value
+ # Coerce the values based on types.
+ #
+ # Note: some coercions are only required if we are reading from .ini files, because
+ # the file format doesn't contain type information, but when reading from toml we will
+ # get either str or list of str values (see _parse_ini_config_from_pyproject_toml).
+ # For example:
+ #
+ # ini:
+ # a_line_list = "tests acceptance"
+ # in this case, we need to split the string to obtain a list of strings.
+ #
+ # toml:
+ # a_line_list = ["tests", "acceptance"]
+ # in this case, we already have a list ready to use.
+ #
if type == "pathlist":
- dp = py.path.local(self.inicfg.config.path).dirpath()
- values = []
- for relpath in shlex.split(value):
- values.append(dp.join(relpath, abs=True))
- return values
+ # TODO: This assert is probably not valid in all cases.
+ assert self.inipath is not None
+ dp = self.inipath.parent
+ input_values = shlex.split(value) if isinstance(value, str) else value
+ return [py.path.local(str(dp / x)) for x in input_values]
elif type == "args":
- return shlex.split(value)
+ return shlex.split(value) if isinstance(value, str) else value
elif type == "linelist":
- return [t for t in map(lambda x: x.strip(), value.split("\n")) if t]
+ if isinstance(value, str):
+ return [t for t in map(lambda x: x.strip(), value.split("\n")) if t]
+ else:
+ return value
elif type == "bool":
- return bool(_strtobool(value.strip()))
+ return _strtobool(str(value).strip())
else:
- assert type is None
+ assert type in [None, "string"]
return value
- def _getconftest_pathlist(self, name, path):
+ def _getconftest_pathlist(
+ self, name: str, path: py.path.local
+ ) -> Optional[List[py.path.local]]:
try:
- mod, relroots = self.pluginmanager._rget_with_confmod(name, path)
+ mod, relroots = self.pluginmanager._rget_with_confmod(
+ name, path, self.getoption("importmode")
+ )
except KeyError:
return None
modpath = py.path.local(mod.__file__).dirpath()
- values = []
+ values: List[py.path.local] = []
for relroot in relroots:
if not isinstance(relroot, py.path.local):
- relroot = relroot.replace("/", py.path.local.sep)
+ relroot = relroot.replace("/", os.sep)
relroot = modpath.join(relroot, abs=True)
values.append(relroot)
return values
def _get_override_ini_value(self, name: str) -> Optional[str]:
value = None
- # override_ini is a list of "ini=value" options
- # always use the last item if multiple values are set for same ini-name,
- # e.g. -o foo=bar1 -o foo=bar2 will set foo to bar2
+ # override_ini is a list of "ini=value" options.
+ # Always use the last item if multiple values are set for same ini-name,
+ # e.g. -o foo=bar1 -o foo=bar2 will set foo to bar2.
for ini_config in self._override_ini:
try:
key, user_ini_value = ini_config.split("=", 1)
- except ValueError:
+ except ValueError as e:
raise UsageError(
"-o/--override-ini expects option=value style (got: {!r}).".format(
ini_config
)
- )
+ ) from e
else:
if key == name:
value = user_ini_value
return value
def getoption(self, name: str, default=notset, skip: bool = False):
- """ return command line option value.
+ """Return command line option value.
- :arg name: name of the option. You may also specify
+ :param name: Name of the option. You may also specify
the literal ``--OPT`` option instead of the "dest" option name.
- :arg default: default value if no option of that name exists.
- :arg skip: if True raise pytest.skip if option does not exists
+ :param default: Default value if no option of that name exists.
+ :param skip: If True, raise pytest.skip if option does not exists
or has a None value.
"""
name = self._opt2dest.get(name, name)
@@ -1156,77 +1464,143 @@ class Config:
if val is None and skip:
raise AttributeError(name)
return val
- except AttributeError:
+ except AttributeError as e:
if default is not notset:
return default
if skip:
import pytest
- pytest.skip("no {!r} option found".format(name))
- raise ValueError("no option named {!r}".format(name))
+ pytest.skip(f"no {name!r} option found")
+ raise ValueError(f"no option named {name!r}") from e
- def getvalue(self, name, path=None):
- """ (deprecated, use getoption()) """
+ def getvalue(self, name: str, path=None):
+ """Deprecated, use getoption() instead."""
return self.getoption(name)
- def getvalueorskip(self, name, path=None):
- """ (deprecated, use getoption(skip=True)) """
+ def getvalueorskip(self, name: str, path=None):
+ """Deprecated, use getoption(skip=True) instead."""
return self.getoption(name, skip=True)
+ def _warn_about_missing_assertion(self, mode: str) -> None:
+ if not _assertion_supported():
+ if mode == "plain":
+ warning_text = (
+ "ASSERTIONS ARE NOT EXECUTED"
+ " and FAILING TESTS WILL PASS. Are you"
+ " using python -O?"
+ )
+ else:
+ warning_text = (
+ "assertions not in test modules or"
+ " plugins will be ignored"
+ " because assert statements are not executed "
+ "by the underlying Python interpreter "
+ "(are you using python -O?)\n"
+ )
+ self.issue_config_time_warning(
+ PytestConfigWarning(warning_text), stacklevel=3,
+ )
+
+ def _warn_about_skipped_plugins(self) -> None:
+ for module_name, msg in self.pluginmanager.skipped_plugins:
+ self.issue_config_time_warning(
+ PytestConfigWarning(f"skipped plugin {module_name!r}: {msg}"),
+ stacklevel=2,
+ )
+
-def _assertion_supported():
+def _assertion_supported() -> bool:
try:
assert False
except AssertionError:
return True
else:
- return False
-
+ return False # type: ignore[unreachable]
-def _warn_about_missing_assertion(mode):
- if not _assertion_supported():
- if mode == "plain":
- sys.stderr.write(
- "WARNING: ASSERTIONS ARE NOT EXECUTED"
- " and FAILING TESTS WILL PASS. Are you"
- " using python -O?"
- )
- else:
- sys.stderr.write(
- "WARNING: assertions not in test modules or"
- " plugins will be ignored"
- " because assert statements are not executed "
- "by the underlying Python interpreter "
- "(are you using python -O?)\n"
- )
-
-def create_terminal_writer(config: Config, *args, **kwargs) -> TerminalWriter:
+def create_terminal_writer(
+ config: Config, file: Optional[TextIO] = None
+) -> TerminalWriter:
"""Create a TerminalWriter instance configured according to the options
- in the config object. Every code which requires a TerminalWriter object
- and has access to a config object should use this function.
+ in the config object.
+
+ Every code which requires a TerminalWriter object and has access to a
+ config object should use this function.
"""
- tw = TerminalWriter(*args, **kwargs)
+ tw = TerminalWriter(file=file)
+
if config.option.color == "yes":
tw.hasmarkup = True
- if config.option.color == "no":
+ elif config.option.color == "no":
tw.hasmarkup = False
+
+ if config.option.code_highlight == "yes":
+ tw.code_highlight = True
+ elif config.option.code_highlight == "no":
+ tw.code_highlight = False
+
return tw
-def _strtobool(val):
- """Convert a string representation of truth to true (1) or false (0).
+def _strtobool(val: str) -> bool:
+ """Convert a string representation of truth to True or False.
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
- .. note:: copied from distutils.util
+ .. note:: Copied from distutils.util.
"""
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
- return 1
+ return True
elif val in ("n", "no", "f", "false", "off", "0"):
- return 0
+ return False
else:
- raise ValueError("invalid truth value {!r}".format(val))
+ raise ValueError(f"invalid truth value {val!r}")
+
+
+@lru_cache(maxsize=50)
+def parse_warning_filter(
+ arg: str, *, escape: bool
+) -> Tuple[str, str, Type[Warning], str, int]:
+ """Parse a warnings filter string.
+
+ This is copied from warnings._setoption, but does not apply the filter,
+ only parses it, and makes the escaping optional.
+ """
+ parts = arg.split(":")
+ if len(parts) > 5:
+ raise warnings._OptionError(f"too many fields (max 5): {arg!r}")
+ while len(parts) < 5:
+ parts.append("")
+ action_, message, category_, module, lineno_ = [s.strip() for s in parts]
+ action: str = warnings._getaction(action_) # type: ignore[attr-defined]
+ category: Type[Warning] = warnings._getcategory(category_) # type: ignore[attr-defined]
+ if message and escape:
+ message = re.escape(message)
+ if module and escape:
+ module = re.escape(module) + r"\Z"
+ if lineno_:
+ try:
+ lineno = int(lineno_)
+ if lineno < 0:
+ raise ValueError
+ except (ValueError, OverflowError) as e:
+ raise warnings._OptionError(f"invalid lineno {lineno_!r}") from e
+ else:
+ lineno = 0
+ return action, message, category, module, lineno
+
+
+def apply_warning_filters(
+ config_filters: Iterable[str], cmdline_filters: Iterable[str]
+) -> None:
+ """Applies pytest-configured filters to the warnings module"""
+ # Filters should have this precedence: cmdline options, config.
+ # Filters should be applied in the inverse order of precedence.
+ for arg in config_filters:
+ warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
+
+ for arg in cmdline_filters:
+ warnings.filterwarnings(*parse_warning_filter(arg, escape=True))
diff --git a/contrib/python/pytest/py3/_pytest/config/argparsing.py b/contrib/python/pytest/py3/_pytest/config/argparsing.py
index 140e04e972..9a48196552 100644
--- a/contrib/python/pytest/py3/_pytest/config/argparsing.py
+++ b/contrib/python/pytest/py3/_pytest/config/argparsing.py
@@ -11,28 +11,31 @@ from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
+from typing import TYPE_CHECKING
from typing import Union
import py
-from _pytest.compat import TYPE_CHECKING
+import _pytest._io
+from _pytest.compat import final
from _pytest.config.exceptions import UsageError
if TYPE_CHECKING:
from typing import NoReturn
- from typing_extensions import Literal # noqa: F401
+ from typing_extensions import Literal
FILE_OR_DIR = "file_or_dir"
+@final
class Parser:
- """ Parser for command line arguments and ini-file values.
+ """Parser for command line arguments and ini-file values.
- :ivar extra_info: dict of generic param -> value to display in case
+ :ivar extra_info: Dict of generic param -> value to display in case
there's an error processing the command line arguments.
"""
- prog = None # type: Optional[str]
+ prog: Optional[str] = None
def __init__(
self,
@@ -40,12 +43,12 @@ class Parser:
processopt: Optional[Callable[["Argument"], None]] = None,
) -> None:
self._anonymous = OptionGroup("custom options", parser=self)
- self._groups = [] # type: List[OptionGroup]
+ self._groups: List[OptionGroup] = []
self._processopt = processopt
self._usage = usage
- self._inidict = {} # type: Dict[str, Tuple[str, Optional[str], Any]]
- self._ininames = [] # type: List[str]
- self.extra_info = {} # type: Dict[str, Any]
+ self._inidict: Dict[str, Tuple[str, Optional[str], Any]] = {}
+ self._ininames: List[str] = []
+ self.extra_info: Dict[str, Any] = {}
def processoption(self, option: "Argument") -> None:
if self._processopt:
@@ -55,11 +58,11 @@ class Parser:
def getgroup(
self, name: str, description: str = "", after: Optional[str] = None
) -> "OptionGroup":
- """ get (or create) a named option Group.
+ """Get (or create) a named option Group.
- :name: name of the option group.
- :description: long description for --help output.
- :after: name of other group, used for ordering --help output.
+ :name: Name of the option group.
+ :description: Long description for --help output.
+ :after: Name of another group, used for ordering --help output.
The returned group object has an ``addoption`` method with the same
signature as :py:func:`parser.addoption
@@ -78,15 +81,14 @@ class Parser:
return group
def addoption(self, *opts: str, **attrs: Any) -> None:
- """ register a command line option.
+ """Register a command line option.
- :opts: option names, can be short or long options.
- :attrs: same attributes which the ``add_argument()`` function of the
- `argparse library
- <https://docs.python.org/library/argparse.html>`_
+ :opts: Option names, can be short or long options.
+ :attrs: Same attributes which the ``add_argument()`` function of the
+ `argparse library <https://docs.python.org/library/argparse.html>`_
accepts.
- After command line parsing options are available on the pytest config
+ After command line parsing, options are available on the pytest config
object via ``config.option.NAME`` where ``NAME`` is usually set
by passing a ``dest`` attribute, for example
``addoption("--long", dest="NAME", ...)``.
@@ -140,9 +142,7 @@ class Parser:
args: Sequence[Union[str, py.path.local]],
namespace: Optional[argparse.Namespace] = None,
) -> argparse.Namespace:
- """parses and returns a namespace object with known arguments at this
- point.
- """
+ """Parse and return a namespace object with known arguments at this point."""
return self.parse_known_and_unknown_args(args, namespace=namespace)[0]
def parse_known_and_unknown_args(
@@ -150,9 +150,8 @@ class Parser:
args: Sequence[Union[str, py.path.local]],
namespace: Optional[argparse.Namespace] = None,
) -> Tuple[argparse.Namespace, List[str]]:
- """parses and returns a namespace object with known arguments, and
- the remaining arguments unknown at this point.
- """
+ """Parse and return a namespace object with known arguments, and
+ the remaining arguments unknown at this point."""
optparser = self._getparser()
strargs = [str(x) if isinstance(x, py.path.local) else x for x in args]
return optparser.parse_known_args(strargs, namespace=namespace)
@@ -161,29 +160,30 @@ class Parser:
self,
name: str,
help: str,
- type: Optional["Literal['pathlist', 'args', 'linelist', 'bool']"] = None,
+ type: Optional[
+ "Literal['string', 'pathlist', 'args', 'linelist', 'bool']"
+ ] = None,
default=None,
) -> None:
- """ register an ini-file option.
+ """Register an ini-file option.
- :name: name of the ini-variable
- :type: type of the variable, can be ``pathlist``, ``args``, ``linelist``
- or ``bool``.
- :default: default value if no ini-file option exists but is queried.
+ :name: Name of the ini-variable.
+ :type: Type of the variable, can be ``string``, ``pathlist``, ``args``,
+ ``linelist`` or ``bool``. Defaults to ``string`` if ``None`` or
+ not passed.
+ :default: Default value if no ini-file option exists but is queried.
The value of ini-variables can be retrieved via a call to
:py:func:`config.getini(name) <_pytest.config.Config.getini>`.
"""
- assert type in (None, "pathlist", "args", "linelist", "bool")
+ assert type in (None, "string", "pathlist", "args", "linelist", "bool")
self._inidict[name] = (help, type, default)
self._ininames.append(name)
class ArgumentError(Exception):
- """
- Raised if an Argument instance is created with invalid or
- inconsistent arguments.
- """
+ """Raised if an Argument instance is created with invalid or
+ inconsistent arguments."""
def __init__(self, msg: str, option: Union["Argument", str]) -> None:
self.msg = msg
@@ -191,26 +191,27 @@ class ArgumentError(Exception):
def __str__(self) -> str:
if self.option_id:
- return "option {}: {}".format(self.option_id, self.msg)
+ return f"option {self.option_id}: {self.msg}"
else:
return self.msg
class Argument:
- """class that mimics the necessary behaviour of optparse.Option
+ """Class that mimics the necessary behaviour of optparse.Option.
+
+ It's currently a least effort implementation and ignoring choices
+ and integer prefixes.
- it's currently a least effort implementation
- and ignoring choices and integer prefixes
https://docs.python.org/3/library/optparse.html#optparse-standard-option-types
"""
_typ_map = {"int": int, "string": str, "float": float, "complex": complex}
def __init__(self, *names: str, **attrs: Any) -> None:
- """store parms in private vars for use in add_argument"""
+ """Store parms in private vars for use in add_argument."""
self._attrs = attrs
- self._short_opts = [] # type: List[str]
- self._long_opts = [] # type: List[str]
+ self._short_opts: List[str] = []
+ self._long_opts: List[str] = []
if "%default" in (attrs.get("help") or ""):
warnings.warn(
'pytest now uses argparse. "%default" should be'
@@ -223,7 +224,7 @@ class Argument:
except KeyError:
pass
else:
- # this might raise a keyerror as well, don't want to catch that
+ # This might raise a keyerror as well, don't want to catch that.
if isinstance(typ, str):
if typ == "choice":
warnings.warn(
@@ -246,17 +247,17 @@ class Argument:
stacklevel=4,
)
attrs["type"] = Argument._typ_map[typ]
- # used in test_parseopt -> test_parse_defaultgetter
+ # Used in test_parseopt -> test_parse_defaultgetter.
self.type = attrs["type"]
else:
self.type = typ
try:
- # attribute existence is tested in Config._processopt
+ # Attribute existence is tested in Config._processopt.
self.default = attrs["default"]
except KeyError:
pass
self._set_opt_strings(names)
- dest = attrs.get("dest") # type: Optional[str]
+ dest: Optional[str] = attrs.get("dest")
if dest:
self.dest = dest
elif self._long_opts:
@@ -264,15 +265,15 @@ class Argument:
else:
try:
self.dest = self._short_opts[0][1:]
- except IndexError:
+ except IndexError as e:
self.dest = "???" # Needed for the error repr.
- raise ArgumentError("need a long or short option", self)
+ raise ArgumentError("need a long or short option", self) from e
def names(self) -> List[str]:
return self._short_opts + self._long_opts
def attrs(self) -> Mapping[str, Any]:
- # update any attributes set by processopt
+ # Update any attributes set by processopt.
attrs = "default dest help".split()
attrs.append(self.dest)
for attr in attrs:
@@ -288,9 +289,10 @@ class Argument:
return self._attrs
def _set_opt_strings(self, opts: Sequence[str]) -> None:
- """directly from optparse
+ """Directly from optparse.
- might not be necessary as this is passed to argparse later on"""
+ Might not be necessary as this is passed to argparse later on.
+ """
for opt in opts:
if len(opt) < 2:
raise ArgumentError(
@@ -316,7 +318,7 @@ class Argument:
self._long_opts.append(opt)
def __repr__(self) -> str:
- args = [] # type: List[str]
+ args: List[str] = []
if self._short_opts:
args += ["_short_opts: " + repr(self._short_opts)]
if self._long_opts:
@@ -335,16 +337,16 @@ class OptionGroup:
) -> None:
self.name = name
self.description = description
- self.options = [] # type: List[Argument]
+ self.options: List[Argument] = []
self.parser = parser
def addoption(self, *optnames: str, **attrs: Any) -> None:
- """ add an option to this group.
+ """Add an option to this group.
- if a shortened version of a long option is specified it will
+ If a shortened version of a long option is specified, it will
be suppressed in the help. addoption('--twowords', '--two-words')
results in help showing '--two-words' only, but --twowords gets
- accepted **and** the automatic destination is in args.twowords
+ accepted **and** the automatic destination is in args.twowords.
"""
conflict = set(optnames).intersection(
name for opt in self.options for name in opt.names()
@@ -385,16 +387,16 @@ class MyOptionParser(argparse.ArgumentParser):
allow_abbrev=False,
)
# extra_info is a dict of (param -> value) to display if there's
- # an usage error to provide more contextual information to the user
+ # an usage error to provide more contextual information to the user.
self.extra_info = extra_info if extra_info else {}
def error(self, message: str) -> "NoReturn":
"""Transform argparse error message into UsageError."""
- msg = "{}: error: {}".format(self.prog, message)
+ msg = f"{self.prog}: error: {message}"
if hasattr(self._parser, "_config_source_hint"):
# Type ignored because the attribute is set dynamically.
- msg = "{} ({})".format(msg, self._parser._config_source_hint) # type: ignore
+ msg = f"{msg} ({self._parser._config_source_hint})" # type: ignore
raise UsageError(self.format_usage() + msg)
@@ -404,14 +406,14 @@ class MyOptionParser(argparse.ArgumentParser):
args: Optional[Sequence[str]] = None,
namespace: Optional[argparse.Namespace] = None,
) -> argparse.Namespace:
- """allow splitting of positional arguments"""
+ """Allow splitting of positional arguments."""
parsed, unrecognized = self.parse_known_args(args, namespace)
if unrecognized:
for arg in unrecognized:
if arg and arg[0] == "-":
lines = ["unrecognized arguments: %s" % (" ".join(unrecognized))]
for k, v in sorted(self.extra_info.items()):
- lines.append(" {}: {}".format(k, v))
+ lines.append(f" {k}: {v}")
self.error("\n".join(lines))
getattr(parsed, FILE_OR_DIR).extend(unrecognized)
return parsed
@@ -456,26 +458,24 @@ class MyOptionParser(argparse.ArgumentParser):
class DropShorterLongHelpFormatter(argparse.HelpFormatter):
- """shorten help for long options that differ only in extra hyphens
+ """Shorten help for long options that differ only in extra hyphens.
- - collapse **long** options that are the same except for extra hyphens
- - shortcut if there are only two options and one of them is a short one
- - cache result on action object as this is called at least 2 times
+ - Collapse **long** options that are the same except for extra hyphens.
+ - Shortcut if there are only two options and one of them is a short one.
+ - Cache result on the action object as this is called at least 2 times.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
- """Use more accurate terminal width via pylib."""
+ # Use more accurate terminal width.
if "width" not in kwargs:
- kwargs["width"] = py.io.get_terminal_width()
+ kwargs["width"] = _pytest._io.get_terminal_width()
super().__init__(*args, **kwargs)
def _format_action_invocation(self, action: argparse.Action) -> str:
orgstr = argparse.HelpFormatter._format_action_invocation(self, action)
if orgstr and orgstr[0] != "-": # only optional arguments
return orgstr
- res = getattr(
- action, "_formatted_action_invocation", None
- ) # type: Optional[str]
+ res: Optional[str] = getattr(action, "_formatted_action_invocation", None)
if res:
return res
options = orgstr.split(", ")
@@ -484,7 +484,7 @@ class DropShorterLongHelpFormatter(argparse.HelpFormatter):
action._formatted_action_invocation = orgstr # type: ignore
return orgstr
return_list = []
- short_long = {} # type: Dict[str, str]
+ short_long: Dict[str, str] = {}
for option in options:
if len(option) == 2 or option[2] == " ":
continue
@@ -508,3 +508,15 @@ class DropShorterLongHelpFormatter(argparse.HelpFormatter):
formatted_action_invocation = ", ".join(return_list)
action._formatted_action_invocation = formatted_action_invocation # type: ignore
return formatted_action_invocation
+
+ def _split_lines(self, text, width):
+ """Wrap lines after splitting on original newlines.
+
+ This allows to have explicit line breaks in the help text.
+ """
+ import textwrap
+
+ lines = []
+ for line in text.splitlines():
+ lines.extend(textwrap.wrap(line.strip(), width))
+ return lines
diff --git a/contrib/python/pytest/py3/_pytest/config/exceptions.py b/contrib/python/pytest/py3/_pytest/config/exceptions.py
index 19fe5cb08e..4f1320e758 100644
--- a/contrib/python/pytest/py3/_pytest/config/exceptions.py
+++ b/contrib/python/pytest/py3/_pytest/config/exceptions.py
@@ -1,9 +1,11 @@
+from _pytest.compat import final
+
+
+@final
class UsageError(Exception):
- """ error in pytest usage or invocation"""
+ """Error in pytest usage or invocation."""
class PrintHelp(Exception):
- """Raised when pytest should print it's help to skip the rest of the
+ """Raised when pytest should print its help to skip the rest of the
argument parsing and validation."""
-
- pass
diff --git a/contrib/python/pytest/py3/_pytest/config/findpaths.py b/contrib/python/pytest/py3/_pytest/config/findpaths.py
index fb84160c1f..2edf54536b 100644
--- a/contrib/python/pytest/py3/_pytest/config/findpaths.py
+++ b/contrib/python/pytest/py3/_pytest/config/findpaths.py
@@ -1,111 +1,165 @@
import os
-from typing import Any
+from pathlib import Path
+from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
+from typing import Sequence
from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
-import py
+import iniconfig
from .exceptions import UsageError
-from _pytest.compat import TYPE_CHECKING
from _pytest.outcomes import fail
+from _pytest.pathlib import absolutepath
+from _pytest.pathlib import commonpath
if TYPE_CHECKING:
- from . import Config # noqa: F401
+ from . import Config
-def exists(path, ignore=EnvironmentError):
+def _parse_ini_config(path: Path) -> iniconfig.IniConfig:
+ """Parse the given generic '.ini' file using legacy IniConfig parser, returning
+ the parsed object.
+
+ Raise UsageError if the file cannot be parsed.
+ """
try:
- return path.check()
- except ignore:
- return False
+ return iniconfig.IniConfig(str(path))
+ except iniconfig.ParseError as exc:
+ raise UsageError(str(exc)) from exc
-def getcfg(args, config=None):
- """
- Search the list of arguments for a valid ini-file for pytest,
- and return a tuple of (rootdir, inifile, cfg-dict).
+def load_config_dict_from_file(
+ filepath: Path,
+) -> Optional[Dict[str, Union[str, List[str]]]]:
+ """Load pytest configuration from the given file path, if supported.
- note: config is optional and used only to issue warnings explicitly (#2891).
+ Return None if the file does not contain valid pytest configuration.
"""
- inibasenames = ["pytest.ini", "tox.ini", "setup.cfg"]
+
+ # Configuration from ini files are obtained from the [pytest] section, if present.
+ if filepath.suffix == ".ini":
+ iniconfig = _parse_ini_config(filepath)
+
+ if "pytest" in iniconfig:
+ return dict(iniconfig["pytest"].items())
+ else:
+ # "pytest.ini" files are always the source of configuration, even if empty.
+ if filepath.name == "pytest.ini":
+ return {}
+
+ # '.cfg' files are considered if they contain a "[tool:pytest]" section.
+ elif filepath.suffix == ".cfg":
+ iniconfig = _parse_ini_config(filepath)
+
+ if "tool:pytest" in iniconfig.sections:
+ return dict(iniconfig["tool:pytest"].items())
+ elif "pytest" in iniconfig.sections:
+ # If a setup.cfg contains a "[pytest]" section, we raise a failure to indicate users that
+ # plain "[pytest]" sections in setup.cfg files is no longer supported (#3086).
+ fail(CFG_PYTEST_SECTION.format(filename="setup.cfg"), pytrace=False)
+
+ # '.toml' files are considered if they contain a [tool.pytest.ini_options] table.
+ elif filepath.suffix == ".toml":
+ import toml
+
+ config = toml.load(str(filepath))
+
+ result = config.get("tool", {}).get("pytest", {}).get("ini_options", None)
+ if result is not None:
+ # TOML supports richer data types than ini files (strings, arrays, floats, ints, etc),
+ # however we need to convert all scalar values to str for compatibility with the rest
+ # of the configuration system, which expects strings only.
+ def make_scalar(v: object) -> Union[str, List[str]]:
+ return v if isinstance(v, list) else str(v)
+
+ return {k: make_scalar(v) for k, v in result.items()}
+
+ return None
+
+
+def locate_config(
+ args: Iterable[Path],
+) -> Tuple[
+ Optional[Path], Optional[Path], Dict[str, Union[str, List[str]]],
+]:
+ """Search in the list of arguments for a valid ini-file for pytest,
+ and return a tuple of (rootdir, inifile, cfg-dict)."""
+ config_names = [
+ "pytest.ini",
+ "pyproject.toml",
+ "tox.ini",
+ "setup.cfg",
+ ]
args = [x for x in args if not str(x).startswith("-")]
if not args:
- args = [py.path.local()]
+ args = [Path.cwd()]
for arg in args:
- arg = py.path.local(arg)
- for base in arg.parts(reverse=True):
- for inibasename in inibasenames:
- p = base.join(inibasename)
- if exists(p):
- try:
- iniconfig = py.iniconfig.IniConfig(p)
- except py.iniconfig.ParseError as exc:
- raise UsageError(str(exc))
-
- if (
- inibasename == "setup.cfg"
- and "tool:pytest" in iniconfig.sections
- ):
- return base, p, iniconfig["tool:pytest"]
- elif "pytest" in iniconfig.sections:
- if inibasename == "setup.cfg" and config is not None:
-
- fail(
- CFG_PYTEST_SECTION.format(filename=inibasename),
- pytrace=False,
- )
- return base, p, iniconfig["pytest"]
- elif inibasename == "pytest.ini":
- # allowed to be empty
- return base, p, {}
- return None, None, None
-
-
-def get_common_ancestor(paths: Iterable[py.path.local]) -> py.path.local:
- common_ancestor = None
+ argpath = absolutepath(arg)
+ for base in (argpath, *argpath.parents):
+ for config_name in config_names:
+ p = base / config_name
+ if p.is_file():
+ ini_config = load_config_dict_from_file(p)
+ if ini_config is not None:
+ return base, p, ini_config
+ return None, None, {}
+
+
+def get_common_ancestor(paths: Iterable[Path]) -> Path:
+ common_ancestor: Optional[Path] = None
for path in paths:
if not path.exists():
continue
if common_ancestor is None:
common_ancestor = path
else:
- if path.relto(common_ancestor) or path == common_ancestor:
+ if common_ancestor in path.parents or path == common_ancestor:
continue
- elif common_ancestor.relto(path):
+ elif path in common_ancestor.parents:
common_ancestor = path
else:
- shared = path.common(common_ancestor)
+ shared = commonpath(path, common_ancestor)
if shared is not None:
common_ancestor = shared
if common_ancestor is None:
- common_ancestor = py.path.local()
- elif common_ancestor.isfile():
- common_ancestor = common_ancestor.dirpath()
+ common_ancestor = Path.cwd()
+ elif common_ancestor.is_file():
+ common_ancestor = common_ancestor.parent
return common_ancestor
-def get_dirs_from_args(args):
- def is_option(x):
- return str(x).startswith("-")
+def get_dirs_from_args(args: Iterable[str]) -> List[Path]:
+ def is_option(x: str) -> bool:
+ return x.startswith("-")
- def get_file_part_from_node_id(x):
- return str(x).split("::")[0]
+ def get_file_part_from_node_id(x: str) -> str:
+ return x.split("::")[0]
- def get_dir_from_path(path):
- if path.isdir():
+ def get_dir_from_path(path: Path) -> Path:
+ if path.is_dir():
return path
- return py.path.local(path.dirname)
+ return path.parent
+
+ def safe_exists(path: Path) -> bool:
+ # This can throw on paths that contain characters unrepresentable at the OS level,
+ # or with invalid syntax on Windows (https://bugs.python.org/issue35306)
+ try:
+ return path.exists()
+ except OSError:
+ return False
# These look like paths but may not exist
possible_paths = (
- py.path.local(get_file_part_from_node_id(arg))
+ absolutepath(get_file_part_from_node_id(arg))
for arg in args
if not is_option(arg)
)
- return [get_dir_from_path(path) for path in possible_paths if path.exists()]
+ return [get_dir_from_path(path) for path in possible_paths if safe_exists(path)]
CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead."
@@ -113,55 +167,45 @@ CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supporte
def determine_setup(
inifile: Optional[str],
- args: List[str],
+ args: Sequence[str],
rootdir_cmd_arg: Optional[str] = None,
config: Optional["Config"] = None,
-) -> Tuple[py.path.local, Optional[str], Any]:
+) -> Tuple[Path, Optional[Path], Dict[str, Union[str, List[str]]]]:
+ rootdir = None
dirs = get_dirs_from_args(args)
if inifile:
- iniconfig = py.iniconfig.IniConfig(inifile)
- is_cfg_file = str(inifile).endswith(".cfg")
- sections = ["tool:pytest", "pytest"] if is_cfg_file else ["pytest"]
- for section in sections:
- try:
- inicfg = iniconfig[
- section
- ] # type: Optional[py.iniconfig._SectionWrapper]
- if is_cfg_file and section == "pytest" and config is not None:
- fail(
- CFG_PYTEST_SECTION.format(filename=str(inifile)), pytrace=False
- )
- break
- except KeyError:
- inicfg = None
+ inipath_ = absolutepath(inifile)
+ inipath: Optional[Path] = inipath_
+ inicfg = load_config_dict_from_file(inipath_) or {}
if rootdir_cmd_arg is None:
rootdir = get_common_ancestor(dirs)
else:
ancestor = get_common_ancestor(dirs)
- rootdir, inifile, inicfg = getcfg([ancestor], config=config)
+ rootdir, inipath, inicfg = locate_config([ancestor])
if rootdir is None and rootdir_cmd_arg is None:
- for possible_rootdir in ancestor.parts(reverse=True):
- if possible_rootdir.join("setup.py").exists():
+ for possible_rootdir in (ancestor, *ancestor.parents):
+ if (possible_rootdir / "setup.py").is_file():
rootdir = possible_rootdir
break
else:
if dirs != [ancestor]:
- rootdir, inifile, inicfg = getcfg(dirs, config=config)
+ rootdir, inipath, inicfg = locate_config(dirs)
if rootdir is None:
if config is not None:
- cwd = config.invocation_dir
+ cwd = config.invocation_params.dir
else:
- cwd = py.path.local()
+ cwd = Path.cwd()
rootdir = get_common_ancestor([cwd, ancestor])
is_fs_root = os.path.splitdrive(str(rootdir))[1] == "/"
if is_fs_root:
rootdir = ancestor
if rootdir_cmd_arg:
- rootdir = py.path.local(os.path.expandvars(rootdir_cmd_arg))
- if not rootdir.isdir():
+ rootdir = absolutepath(os.path.expandvars(rootdir_cmd_arg))
+ if not rootdir.is_dir():
raise UsageError(
"Directory '{}' not found. Check your '--rootdir' option.".format(
rootdir
)
)
- return rootdir, inifile, inicfg or {}
+ assert rootdir is not None
+ return rootdir, inipath, inicfg or {}
diff --git a/contrib/python/pytest/py3/_pytest/debugging.py b/contrib/python/pytest/py3/_pytest/debugging.py
index 07f212fa5e..b52840006b 100644
--- a/contrib/python/pytest/py3/_pytest/debugging.py
+++ b/contrib/python/pytest/py3/_pytest/debugging.py
@@ -1,13 +1,33 @@
-""" interactive debugging with PDB, the Python Debugger. """
+"""Interactive debugging with PDB, the Python Debugger."""
import argparse
import functools
import os
import sys
+import types
+from typing import Any
+from typing import Callable
+from typing import Generator
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from _pytest import outcomes
+from _pytest._code import ExceptionInfo
+from _pytest.config import Config
from _pytest.config import ConftestImportFailure
from _pytest.config import hookimpl
+from _pytest.config import PytestPluginManager
+from _pytest.config.argparsing import Parser
from _pytest.config.exceptions import UsageError
+from _pytest.nodes import Node
+from _pytest.reports import BaseReport
+
+if TYPE_CHECKING:
+ from _pytest.capture import CaptureManager
+ from _pytest.runner import CallInfo
def import_readline():
@@ -46,18 +66,18 @@ def tty():
sys.path = old_sys_path
-def _validate_usepdb_cls(value):
+def _validate_usepdb_cls(value: str) -> Tuple[str, str]:
"""Validate syntax of --pdbcls option."""
try:
modname, classname = value.split(":")
- except ValueError:
+ except ValueError as e:
raise argparse.ArgumentTypeError(
- "{!r} is not in the format 'modname:classname'".format(value)
- )
+ f"{value!r} is not in the format 'modname:classname'"
+ ) from e
return (modname, classname)
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group._addoption(
"--pdb",
@@ -81,7 +101,7 @@ def pytest_addoption(parser):
)
-def pytest_configure(config):
+def pytest_configure(config: Config) -> None:
import pdb
if config.getvalue("trace"):
@@ -98,7 +118,7 @@ def pytest_configure(config):
# NOTE: not using pytest_unconfigure, since it might get called although
# pytest_configure was not (if another plugin raises UsageError).
- def fin():
+ def fin() -> None:
(
pdb.set_trace,
pytestPDB._pluginmanager,
@@ -109,22 +129,24 @@ def pytest_configure(config):
class pytestPDB:
- """ Pseudo PDB that defers to the real pdb. """
+ """Pseudo PDB that defers to the real pdb."""
- _pluginmanager = None
- _config = None
- _saved = [] # type: list
+ _pluginmanager: Optional[PytestPluginManager] = None
+ _config: Optional[Config] = None
+ _saved: List[
+ Tuple[Callable[..., None], Optional[PytestPluginManager], Optional[Config]]
+ ] = []
_recursive_debug = 0
- _wrapped_pdb_cls = None
+ _wrapped_pdb_cls: Optional[Tuple[Type[Any], Type[Any]]] = None
@classmethod
- def _is_capturing(cls, capman):
+ def _is_capturing(cls, capman: Optional["CaptureManager"]) -> Union[str, bool]:
if capman:
return capman.is_capturing()
return False
@classmethod
- def _import_pdb_cls(cls, capman):
+ def _import_pdb_cls(cls, capman: Optional["CaptureManager"]):
if not cls._config:
import pdb
@@ -151,8 +173,8 @@ class pytestPDB:
except Exception as exc:
value = ":".join((modname, classname))
raise UsageError(
- "--pdbcls: could not import {!r}: {}".format(value, exc)
- )
+ f"--pdbcls: could not import {value!r}: {exc}"
+ ) from exc
else:
import pdb
@@ -163,10 +185,12 @@ class pytestPDB:
return wrapped_cls
@classmethod
- def _get_pdb_wrapper_class(cls, pdb_cls, capman):
+ def _get_pdb_wrapper_class(cls, pdb_cls, capman: Optional["CaptureManager"]):
import _pytest.config
- class PytestPdbWrapper(pdb_cls):
+ # Type ignored because mypy doesn't support "dynamic"
+ # inheritance like this.
+ class PytestPdbWrapper(pdb_cls): # type: ignore[valid-type,misc]
_pytest_capman = capman
_continued = False
@@ -179,6 +203,7 @@ class pytestPDB:
def do_continue(self, arg):
ret = super().do_continue(arg)
if cls._recursive_debug == 0:
+ assert cls._config is not None
tw = _pytest.config.create_terminal_writer(cls._config)
tw.line()
@@ -193,9 +218,11 @@ class pytestPDB:
"PDB continue (IO-capturing resumed for %s)"
% capturing,
)
+ assert capman is not None
capman.resume()
else:
tw.sep(">", "PDB continue")
+ assert cls._pluginmanager is not None
cls._pluginmanager.hook.pytest_leave_pdb(config=cls._config, pdb=self)
self._continued = True
return ret
@@ -246,13 +273,13 @@ class pytestPDB:
@classmethod
def _init_pdb(cls, method, *args, **kwargs):
- """ Initialize PDB debugging, dropping any IO capturing. """
+ """Initialize PDB debugging, dropping any IO capturing."""
import _pytest.config
- if cls._pluginmanager is not None:
- capman = cls._pluginmanager.getplugin("capturemanager")
+ if cls._pluginmanager is None:
+ capman: Optional[CaptureManager] = None
else:
- capman = None
+ capman = cls._pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend(in_=True)
@@ -268,7 +295,7 @@ class pytestPDB:
else:
capturing = cls._is_capturing(capman)
if capturing == "global":
- tw.sep(">", "PDB {} (IO-capturing turned off)".format(method))
+ tw.sep(">", f"PDB {method} (IO-capturing turned off)")
elif capturing:
tw.sep(
">",
@@ -276,7 +303,7 @@ class pytestPDB:
% (method, capturing),
)
else:
- tw.sep(">", "PDB {}".format(method))
+ tw.sep(">", f"PDB {method}")
_pdb = cls._import_pdb_cls(capman)(**kwargs)
@@ -285,7 +312,7 @@ class pytestPDB:
return _pdb
@classmethod
- def set_trace(cls, *args, **kwargs):
+ def set_trace(cls, *args, **kwargs) -> None:
"""Invoke debugging via ``Pdb.set_trace``, dropping any IO capturing."""
tty()
frame = sys._getframe().f_back
@@ -294,7 +321,9 @@ class pytestPDB:
class PdbInvoke:
- def pytest_exception_interact(self, node, call, report):
+ def pytest_exception_interact(
+ self, node: Node, call: "CallInfo[Any]", report: BaseReport
+ ) -> None:
capman = node.config.pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend_global_capture(in_=True)
@@ -302,31 +331,32 @@ class PdbInvoke:
sys.stdout.write(out)
sys.stdout.write(err)
tty()
+ assert call.excinfo is not None
_enter_pdb(node, call.excinfo, report)
- def pytest_internalerror(self, excrepr, excinfo):
+ def pytest_internalerror(self, excinfo: ExceptionInfo[BaseException]) -> None:
tb = _postmortem_traceback(excinfo)
post_mortem(tb)
class PdbTrace:
@hookimpl(hookwrapper=True)
- def pytest_pyfunc_call(self, pyfuncitem):
+ def pytest_pyfunc_call(self, pyfuncitem) -> Generator[None, None, None]:
wrap_pytest_function_for_tracing(pyfuncitem)
yield
def wrap_pytest_function_for_tracing(pyfuncitem):
- """Changes the python function object of the given Function item by a wrapper which actually
- enters pdb before calling the python function itself, effectively leaving the user
- in the pdb prompt in the first statement of the function.
- """
+ """Change the Python function object of the given Function item by a
+ wrapper which actually enters pdb before calling the python function
+ itself, effectively leaving the user in the pdb prompt in the first
+ statement of the function."""
_pdb = pytestPDB._init_pdb("runcall")
testfunction = pyfuncitem.obj
# we can't just return `partial(pdb.runcall, testfunction)` because (on
# python < 3.7.4) runcall's first param is `func`, which means we'd get
- # an exception if one of the kwargs to testfunction was called `func`
+ # an exception if one of the kwargs to testfunction was called `func`.
@functools.wraps(testfunction)
def wrapper(*args, **kwargs):
func = functools.partial(testfunction, *args, **kwargs)
@@ -337,12 +367,14 @@ def wrap_pytest_function_for_tracing(pyfuncitem):
def maybe_wrap_pytest_function_for_tracing(pyfuncitem):
"""Wrap the given pytestfunct item for tracing support if --trace was given in
- the command line"""
+ the command line."""
if pyfuncitem.config.getvalue("trace"):
wrap_pytest_function_for_tracing(pyfuncitem)
-def _enter_pdb(node, excinfo, rep):
+def _enter_pdb(
+ node: Node, excinfo: ExceptionInfo[BaseException], rep: BaseReport
+) -> BaseReport:
# XXX we re-use the TerminalReporter's terminalwriter
# because this seems to avoid some encoding related troubles
# for not completely clear reasons.
@@ -366,12 +398,12 @@ def _enter_pdb(node, excinfo, rep):
rep.toterminal(tw)
tw.sep(">", "entering PDB")
tb = _postmortem_traceback(excinfo)
- rep._pdbshown = True
+ rep._pdbshown = True # type: ignore[attr-defined]
post_mortem(tb)
return rep
-def _postmortem_traceback(excinfo):
+def _postmortem_traceback(excinfo: ExceptionInfo[BaseException]) -> types.TracebackType:
from doctest import UnexpectedException
if isinstance(excinfo.value, UnexpectedException):
@@ -383,10 +415,11 @@ def _postmortem_traceback(excinfo):
# Use the underlying exception instead:
return excinfo.value.excinfo[2]
else:
+ assert excinfo._excinfo is not None
return excinfo._excinfo[2]
-def post_mortem(t):
+def post_mortem(t: types.TracebackType) -> None:
p = pytestPDB._init_pdb("post_mortem")
p.reset()
p.interaction(None, t)
diff --git a/contrib/python/pytest/py3/_pytest/deprecated.py b/contrib/python/pytest/py3/_pytest/deprecated.py
index b11093910a..19b31d6653 100644
--- a/contrib/python/pytest/py3/_pytest/deprecated.py
+++ b/contrib/python/pytest/py3/_pytest/deprecated.py
@@ -1,13 +1,15 @@
-"""
-This module contains deprecation messages and bits of code used elsewhere in the codebase
-that is planned to be removed in the next pytest release.
+"""Deprecation messages and bits of code used elsewhere in the codebase that
+is planned to be removed in the next pytest release.
Keeping it in a central location makes it easy to track what is deprecated and should
be removed when the time comes.
-All constants defined in this module should be either PytestWarning instances or UnformattedWarning
+All constants defined in this module should be either instances of
+:class:`PytestWarning`, or :class:`UnformattedWarning`
in case of warnings which need to format their messages.
"""
+from warnings import warn
+
from _pytest.warning_types import PytestDeprecationWarning
from _pytest.warning_types import UnformattedWarning
@@ -19,44 +21,67 @@ DEPRECATED_EXTERNAL_PLUGINS = {
"pytest_faulthandler",
}
-FUNCARGNAMES = PytestDeprecationWarning(
- "The `funcargnames` attribute was an alias for `fixturenames`, "
- "since pytest 2.3 - use the newer attribute instead."
-)
-RESULT_LOG = PytestDeprecationWarning(
- "--result-log is deprecated, please try the new pytest-reportlog plugin.\n"
- "See https://docs.pytest.org/en/latest/deprecations.html#result-log-result-log for more information."
+FILLFUNCARGS = UnformattedWarning(
+ PytestDeprecationWarning,
+ "{name} is deprecated, use "
+ "function._request._fillfixtures() instead if you cannot avoid reaching into internals.",
)
-FIXTURE_POSITIONAL_ARGUMENTS = PytestDeprecationWarning(
- "Passing arguments to pytest.fixture() as positional arguments is deprecated - pass them "
- "as a keyword argument instead."
+PYTEST_COLLECT_MODULE = UnformattedWarning(
+ PytestDeprecationWarning,
+ "pytest.collect.{name} was moved to pytest.{name}\n"
+ "Please update to the new name.",
)
-NODE_USE_FROM_PARENT = UnformattedWarning(
- PytestDeprecationWarning,
- "direct construction of {name} has been deprecated, please use {name}.from_parent",
+YIELD_FIXTURE = PytestDeprecationWarning(
+ "@pytest.yield_fixture is deprecated.\n"
+ "Use @pytest.fixture instead; they are the same."
)
-JUNIT_XML_DEFAULT_FAMILY = PytestDeprecationWarning(
- "The 'junit_family' default value will change to 'xunit2' in pytest 6.0.\n"
- "Add 'junit_family=xunit1' to your pytest.ini file to keep the current format "
- "in future versions of pytest and silence this warning."
+MINUS_K_DASH = PytestDeprecationWarning(
+ "The `-k '-expr'` syntax to -k is deprecated.\nUse `-k 'not expr'` instead."
)
-NO_PRINT_LOGS = PytestDeprecationWarning(
- "--no-print-logs is deprecated and scheduled for removal in pytest 6.0.\n"
- "Please use --show-capture instead."
+MINUS_K_COLON = PytestDeprecationWarning(
+ "The `-k 'expr:'` syntax to -k is deprecated.\n"
+ "Please open an issue if you use this and want a replacement."
)
-COLLECT_DIRECTORY_HOOK = PytestDeprecationWarning(
- "The pytest_collect_directory hook is not working.\n"
- "Please use collect_ignore in conftests or pytest_collection_modifyitems."
+WARNING_CAPTURED_HOOK = PytestDeprecationWarning(
+ "The pytest_warning_captured is deprecated and will be removed in a future release.\n"
+ "Please use pytest_warning_recorded instead."
)
+FSCOLLECTOR_GETHOOKPROXY_ISINITPATH = PytestDeprecationWarning(
+ "The gethookproxy() and isinitpath() methods of FSCollector and Package are deprecated; "
+ "use self.session.gethookproxy() and self.session.isinitpath() instead. "
+)
-TERMINALWRITER_WRITER = PytestDeprecationWarning(
- "The TerminalReporter.writer attribute is deprecated, use TerminalReporter._tw instead at your own risk.\n"
- "See https://docs.pytest.org/en/latest/deprecations.html#terminalreporter-writer for more information."
+STRICT_OPTION = PytestDeprecationWarning(
+ "The --strict option is deprecated, use --strict-markers instead."
)
+
+PRIVATE = PytestDeprecationWarning("A private pytest class or function was used.")
+
+
+# You want to make some `__init__` or function "private".
+#
+# def my_private_function(some, args):
+# ...
+#
+# Do this:
+#
+# def my_private_function(some, args, *, _ispytest: bool = False):
+# check_ispytest(_ispytest)
+# ...
+#
+# Change all internal/allowed calls to
+#
+# my_private_function(some, args, _ispytest=True)
+#
+# All other calls will get the default _ispytest=False and trigger
+# the warning (possibly error in the future).
+def check_ispytest(ispytest: bool) -> None:
+ if not ispytest:
+ warn(PRIVATE, stacklevel=3)
diff --git a/contrib/python/pytest/py3/_pytest/doctest.py b/contrib/python/pytest/py3/_pytest/doctest.py
index e1dd9691cc..64e8f0e0ee 100644
--- a/contrib/python/pytest/py3/_pytest/doctest.py
+++ b/contrib/python/pytest/py3/_pytest/doctest.py
@@ -1,16 +1,24 @@
-""" discover and run doctests in modules and test files."""
+"""Discover and run doctests in modules and test files."""
import bdb
import inspect
import platform
import sys
import traceback
+import types
import warnings
from contextlib import contextmanager
+from typing import Any
+from typing import Callable
from typing import Dict
+from typing import Generator
+from typing import Iterable
from typing import List
from typing import Optional
+from typing import Pattern
from typing import Sequence
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import Union
import py.path
@@ -22,15 +30,17 @@ from _pytest._code.code import ReprFileLocation
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import safe_getattr
-from _pytest.compat import TYPE_CHECKING
+from _pytest.config import Config
+from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
+from _pytest.nodes import Collector
from _pytest.outcomes import OutcomeException
+from _pytest.pathlib import import_path
from _pytest.python_api import approx
from _pytest.warning_types import PytestWarning
if TYPE_CHECKING:
import doctest
- from typing import Type
DOCTEST_REPORT_CHOICE_NONE = "none"
DOCTEST_REPORT_CHOICE_CDIFF = "cdiff"
@@ -49,10 +59,10 @@ DOCTEST_REPORT_CHOICES = (
# Lazy definition of runner class
RUNNER_CLASS = None
# Lazy definition of output checker class
-CHECKER_CLASS = None # type: Optional[Type[doctest.OutputChecker]]
+CHECKER_CLASS: Optional[Type["doctest.OutputChecker"]] = None
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
parser.addini(
"doctest_optionflags",
"option flags for doctests",
@@ -102,19 +112,24 @@ def pytest_addoption(parser):
)
-def pytest_unconfigure():
+def pytest_unconfigure() -> None:
global RUNNER_CLASS
RUNNER_CLASS = None
-def pytest_collect_file(path: py.path.local, parent):
+def pytest_collect_file(
+ path: py.path.local, parent: Collector,
+) -> Optional[Union["DoctestModule", "DoctestTextfile"]]:
config = parent.config
if path.ext == ".py":
if config.option.doctestmodules and not _is_setup_py(path):
- return DoctestModule.from_parent(parent, fspath=path)
+ mod: DoctestModule = DoctestModule.from_parent(parent, fspath=path)
+ return mod
elif _is_doctest(config, path, parent):
- return DoctestTextfile.from_parent(parent, fspath=path)
+ txt: DoctestTextfile = DoctestTextfile.from_parent(parent, fspath=path)
+ return txt
+ return None
def _is_setup_py(path: py.path.local) -> bool:
@@ -124,7 +139,7 @@ def _is_setup_py(path: py.path.local) -> bool:
return b"setuptools" in contents or b"distutils" in contents
-def _is_doctest(config, path, parent):
+def _is_doctest(config: Config, path: py.path.local, parent) -> bool:
if path.ext in (".txt", ".rst") and parent.session.isinitpath(path):
return True
globs = config.getoption("doctestglob") or ["test*.txt"]
@@ -137,7 +152,7 @@ def _is_doctest(config, path, parent):
class ReprFailDoctest(TerminalRepr):
def __init__(
self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]]
- ):
+ ) -> None:
self.reprlocation_lines = reprlocation_lines
def toterminal(self, tw: TerminalWriter) -> None:
@@ -148,36 +163,49 @@ class ReprFailDoctest(TerminalRepr):
class MultipleDoctestFailures(Exception):
- def __init__(self, failures):
+ def __init__(self, failures: Sequence["doctest.DocTestFailure"]) -> None:
super().__init__()
self.failures = failures
-def _init_runner_class() -> "Type[doctest.DocTestRunner]":
+def _init_runner_class() -> Type["doctest.DocTestRunner"]:
import doctest
class PytestDoctestRunner(doctest.DebugRunner):
- """
- Runner to collect failures. Note that the out variable in this case is
- a list instead of a stdout-like object
+ """Runner to collect failures.
+
+ Note that the out variable in this case is a list instead of a
+ stdout-like object.
"""
def __init__(
- self, checker=None, verbose=None, optionflags=0, continue_on_failure=True
- ):
+ self,
+ checker: Optional["doctest.OutputChecker"] = None,
+ verbose: Optional[bool] = None,
+ optionflags: int = 0,
+ continue_on_failure: bool = True,
+ ) -> None:
doctest.DebugRunner.__init__(
self, checker=checker, verbose=verbose, optionflags=optionflags
)
self.continue_on_failure = continue_on_failure
- def report_failure(self, out, test, example, got):
+ def report_failure(
+ self, out, test: "doctest.DocTest", example: "doctest.Example", got: str,
+ ) -> None:
failure = doctest.DocTestFailure(test, example, got)
if self.continue_on_failure:
out.append(failure)
else:
raise failure
- def report_unexpected_exception(self, out, test, example, exc_info):
+ def report_unexpected_exception(
+ self,
+ out,
+ test: "doctest.DocTest",
+ example: "doctest.Example",
+ exc_info: Tuple[Type[BaseException], BaseException, types.TracebackType],
+ ) -> None:
if isinstance(exc_info[1], OutcomeException):
raise exc_info[1]
if isinstance(exc_info[1], bdb.BdbQuit):
@@ -212,24 +240,33 @@ def _get_runner(
class DoctestItem(pytest.Item):
- def __init__(self, name, parent, runner=None, dtest=None):
+ def __init__(
+ self,
+ name: str,
+ parent: "Union[DoctestTextfile, DoctestModule]",
+ runner: Optional["doctest.DocTestRunner"] = None,
+ dtest: Optional["doctest.DocTest"] = None,
+ ) -> None:
super().__init__(name, parent)
self.runner = runner
self.dtest = dtest
self.obj = None
- self.fixture_request = None
+ self.fixture_request: Optional[FixtureRequest] = None
@classmethod
def from_parent( # type: ignore
- cls, parent: "Union[DoctestTextfile, DoctestModule]", *, name, runner, dtest
+ cls,
+ parent: "Union[DoctestTextfile, DoctestModule]",
+ *,
+ name: str,
+ runner: "doctest.DocTestRunner",
+ dtest: "doctest.DocTest",
):
# incompatible signature due to to imposed limits on sublcass
- """
- the public named constructor
- """
+ """The public named constructor."""
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)
- def setup(self):
+ def setup(self) -> None:
if self.dtest is not None:
self.fixture_request = _setup_fixtures(self)
globs = dict(getfixture=self.fixture_request.getfixturevalue)
@@ -240,17 +277,19 @@ class DoctestItem(pytest.Item):
self.dtest.globs.update(globs)
def runtest(self) -> None:
+ assert self.dtest is not None
+ assert self.runner is not None
_check_all_skipped(self.dtest)
self._disable_output_capturing_for_darwin()
- failures = [] # type: List[doctest.DocTestFailure]
- self.runner.run(self.dtest, out=failures)
+ failures: List["doctest.DocTestFailure"] = []
+ # Type ignored because we change the type of `out` from what
+ # doctest expects.
+ self.runner.run(self.dtest, out=failures) # type: ignore[arg-type]
if failures:
raise MultipleDoctestFailures(failures)
- def _disable_output_capturing_for_darwin(self):
- """
- Disable output capturing. Otherwise, stdout is lost to doctest (#985)
- """
+ def _disable_output_capturing_for_darwin(self) -> None:
+ """Disable output capturing. Otherwise, stdout is lost to doctest (#985)."""
if platform.system() != "Darwin":
return
capman = self.config.pluginmanager.getplugin("capturemanager")
@@ -260,15 +299,20 @@ class DoctestItem(pytest.Item):
sys.stdout.write(out)
sys.stderr.write(err)
- def repr_failure(self, excinfo):
+ # TODO: Type ignored -- breaks Liskov Substitution.
+ def repr_failure( # type: ignore[override]
+ self, excinfo: ExceptionInfo[BaseException],
+ ) -> Union[str, TerminalRepr]:
import doctest
- failures = (
- None
- ) # type: Optional[List[Union[doctest.DocTestFailure, doctest.UnexpectedException]]]
- if excinfo.errisinstance((doctest.DocTestFailure, doctest.UnexpectedException)):
+ failures: Optional[
+ Sequence[Union[doctest.DocTestFailure, doctest.UnexpectedException]]
+ ] = (None)
+ if isinstance(
+ excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException)
+ ):
failures = [excinfo.value]
- elif excinfo.errisinstance(MultipleDoctestFailures):
+ elif isinstance(excinfo.value, MultipleDoctestFailures):
failures = excinfo.value.failures
if failures is not None:
@@ -282,7 +326,8 @@ class DoctestItem(pytest.Item):
else:
lineno = test.lineno + example.lineno + 1
message = type(failure).__name__
- reprlocation = ReprFileLocation(filename, lineno, message)
+ # TODO: ReprFileLocation doesn't expect a None lineno.
+ reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type]
checker = _get_checker()
report_choice = _get_report_choice(
self.config.getoption("doctestreport")
@@ -304,7 +349,7 @@ class DoctestItem(pytest.Item):
]
indent = ">>>"
for line in example.source.splitlines():
- lines.append("??? {} {}".format(indent, line))
+ lines.append(f"??? {indent} {line}")
indent = "..."
if isinstance(failure, doctest.DocTestFailure):
lines += checker.output_difference(
@@ -322,7 +367,8 @@ class DoctestItem(pytest.Item):
else:
return super().repr_failure(excinfo)
- def reportinfo(self) -> Tuple[py.path.local, int, str]:
+ def reportinfo(self):
+ assert self.dtest is not None
return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name
@@ -355,7 +401,7 @@ def _get_continue_on_failure(config):
continue_on_failure = config.getvalue("doctest_continue_on_failure")
if continue_on_failure:
# We need to turn off this if we use pdb since we should stop at
- # the first failure
+ # the first failure.
if config.getvalue("usepdb"):
continue_on_failure = False
return continue_on_failure
@@ -364,11 +410,11 @@ def _get_continue_on_failure(config):
class DoctestTextfile(pytest.Module):
obj = None
- def collect(self):
+ def collect(self) -> Iterable[DoctestItem]:
import doctest
- # inspired by doctest.testfile; ideally we would use it directly,
- # but it doesn't support passing a custom checker
+ # Inspired by doctest.testfile; ideally we would use it directly,
+ # but it doesn't support passing a custom checker.
encoding = self.config.getini("doctest_encoding")
text = self.fspath.read_text(encoding)
filename = str(self.fspath)
@@ -392,10 +438,9 @@ class DoctestTextfile(pytest.Module):
)
-def _check_all_skipped(test):
- """raises pytest.skip() if all examples in the given DocTest have the SKIP
- option set.
- """
+def _check_all_skipped(test: "doctest.DocTest") -> None:
+ """Raise pytest.skip() if all examples in the given DocTest have the SKIP
+ option set."""
import doctest
all_skipped = all(x.options.get(doctest.SKIP, False) for x in test.examples)
@@ -403,10 +448,9 @@ def _check_all_skipped(test):
pytest.skip("all tests skipped by +SKIP option")
-def _is_mocked(obj):
- """
- returns if a object is possibly a mock object by checking the existence of a highly improbable attribute
- """
+def _is_mocked(obj: object) -> bool:
+ """Return if an object is possibly a mock object by checking the
+ existence of a highly improbable attribute."""
return (
safe_getattr(obj, "pytest_mock_example_attribute_that_shouldnt_exist", None)
is not None
@@ -414,23 +458,24 @@ def _is_mocked(obj):
@contextmanager
-def _patch_unwrap_mock_aware():
- """
- contextmanager which replaces ``inspect.unwrap`` with a version
- that's aware of mock objects and doesn't recurse on them
- """
+def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
+ """Context manager which replaces ``inspect.unwrap`` with a version
+ that's aware of mock objects and doesn't recurse into them."""
real_unwrap = inspect.unwrap
- def _mock_aware_unwrap(obj, stop=None):
+ def _mock_aware_unwrap(
+ func: Callable[..., Any], *, stop: Optional[Callable[[Any], Any]] = None
+ ) -> Any:
try:
if stop is None or stop is _is_mocked:
- return real_unwrap(obj, stop=_is_mocked)
- return real_unwrap(obj, stop=lambda obj: _is_mocked(obj) or stop(obj))
+ return real_unwrap(func, stop=_is_mocked)
+ _stop = stop
+ return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func))
except Exception as e:
warnings.warn(
"Got %r when unwrapping %r. This is usually caused "
"by a violation of Python's object protocol; see e.g. "
- "https://github.com/pytest-dev/pytest/issues/5080" % (e, obj),
+ "https://github.com/pytest-dev/pytest/issues/5080" % (e, func),
PytestWarning,
)
raise
@@ -443,26 +488,28 @@ def _patch_unwrap_mock_aware():
class DoctestModule(pytest.Module):
- def collect(self):
+ def collect(self) -> Iterable[DoctestItem]:
import doctest
class MockAwareDocTestFinder(doctest.DocTestFinder):
- """
- a hackish doctest finder that overrides stdlib internals to fix a stdlib bug
+ """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.
https://github.com/pytest-dev/pytest/issues/3456
https://bugs.python.org/issue25532
"""
def _find_lineno(self, obj, source_lines):
- """
- Doctest code does not take into account `@property`, this is a hackish way to fix it.
+ """Doctest code does not take into account `@property`, this
+ is a hackish way to fix it.
https://bugs.python.org/issue17446
"""
if isinstance(obj, property):
obj = getattr(obj, "fget", obj)
- return doctest.DocTestFinder._find_lineno(self, obj, source_lines)
+ # Type ignored because this is a private function.
+ return doctest.DocTestFinder._find_lineno( # type: ignore
+ self, obj, source_lines,
+ )
def _find(
self, tests, obj, name, module, source_lines, globs, seen
@@ -477,16 +524,18 @@ class DoctestModule(pytest.Module):
)
if self.fspath.basename == "conftest.py":
- module = self.config.pluginmanager._importconftest(self.fspath)
+ module = self.config.pluginmanager._importconftest(
+ self.fspath, self.config.getoption("importmode")
+ )
else:
try:
- module = self.fspath.pyimport()
+ module = import_path(self.fspath)
except ImportError:
if self.config.getvalue("doctest_ignore_import_errors"):
pytest.skip("unable to import module %r" % self.fspath)
else:
raise
- # uses internal doctest module parsing mechanism
+ # Uses internal doctest module parsing mechanism.
finder = MockAwareDocTestFinder()
optionflags = get_optionflags(self)
runner = _get_runner(
@@ -503,34 +552,30 @@ class DoctestModule(pytest.Module):
)
-def _setup_fixtures(doctest_item):
- """
- Used by DoctestTextfile and DoctestItem to setup fixture information.
- """
+def _setup_fixtures(doctest_item: DoctestItem) -> FixtureRequest:
+ """Used by DoctestTextfile and DoctestItem to setup fixture information."""
- def func():
+ def func() -> None:
pass
- doctest_item.funcargs = {}
+ doctest_item.funcargs = {} # type: ignore[attr-defined]
fm = doctest_item.session._fixturemanager
- doctest_item._fixtureinfo = fm.getfixtureinfo(
+ doctest_item._fixtureinfo = fm.getfixtureinfo( # type: ignore[attr-defined]
node=doctest_item, func=func, cls=None, funcargs=False
)
- fixture_request = FixtureRequest(doctest_item)
+ fixture_request = FixtureRequest(doctest_item, _ispytest=True)
fixture_request._fillfixtures()
return fixture_request
-def _init_checker_class() -> "Type[doctest.OutputChecker]":
+def _init_checker_class() -> Type["doctest.OutputChecker"]:
import doctest
import re
class LiteralsOutputChecker(doctest.OutputChecker):
- """
- Based on doctest_nose_plugin.py from the nltk project
- (https://github.com/nltk/nltk) and on the "numtest" doctest extension
- by Sebastien Boisgerault (https://github.com/boisgera/numtest).
- """
+ # Based on doctest_nose_plugin.py from the nltk project
+ # (https://github.com/nltk/nltk) and on the "numtest" doctest extension
+ # by Sebastien Boisgerault (https://github.com/boisgera/numtest).
_unicode_literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE)
_bytes_literal_re = re.compile(r"(\W|^)[bB]([rR]?[\'\"])", re.UNICODE)
@@ -557,7 +602,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
re.VERBOSE,
)
- def check_output(self, want, got, optionflags):
+ def check_output(self, want: str, got: str, optionflags: int) -> bool:
if doctest.OutputChecker.check_output(self, want, got, optionflags):
return True
@@ -568,7 +613,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
if not allow_unicode and not allow_bytes and not allow_number:
return False
- def remove_prefixes(regex, txt):
+ def remove_prefixes(regex: Pattern[str], txt: str) -> str:
return re.sub(regex, r"\1\2", txt)
if allow_unicode:
@@ -584,15 +629,15 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
return doctest.OutputChecker.check_output(self, want, got, optionflags)
- def _remove_unwanted_precision(self, want, got):
+ def _remove_unwanted_precision(self, want: str, got: str) -> str:
wants = list(self._number_re.finditer(want))
gots = list(self._number_re.finditer(got))
if len(wants) != len(gots):
return got
offset = 0
for w, g in zip(wants, gots):
- fraction = w.group("fraction")
- exponent = w.group("exponent1")
+ fraction: Optional[str] = w.group("fraction")
+ exponent: Optional[str] = w.group("exponent1")
if exponent is None:
exponent = w.group("exponent2")
if fraction is None:
@@ -615,8 +660,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
def _get_checker() -> "doctest.OutputChecker":
- """
- Returns a doctest.OutputChecker subclass that supports some
+ """Return a doctest.OutputChecker subclass that supports some
additional options:
* ALLOW_UNICODE and ALLOW_BYTES options to ignore u'' and b''
@@ -636,36 +680,31 @@ def _get_checker() -> "doctest.OutputChecker":
def _get_allow_unicode_flag() -> int:
- """
- Registers and returns the ALLOW_UNICODE flag.
- """
+ """Register and return the ALLOW_UNICODE flag."""
import doctest
return doctest.register_optionflag("ALLOW_UNICODE")
def _get_allow_bytes_flag() -> int:
- """
- Registers and returns the ALLOW_BYTES flag.
- """
+ """Register and return the ALLOW_BYTES flag."""
import doctest
return doctest.register_optionflag("ALLOW_BYTES")
def _get_number_flag() -> int:
- """
- Registers and returns the NUMBER flag.
- """
+ """Register and return the NUMBER flag."""
import doctest
return doctest.register_optionflag("NUMBER")
def _get_report_choice(key: str) -> int:
- """
- This function returns the actual `doctest` module flag value, we want to do it as late as possible to avoid
- importing `doctest` and all its dependencies when parsing options, as it adds overhead and breaks tests.
+ """Return the actual `doctest` module flag value.
+
+ We want to do it as late as possible to avoid importing `doctest` and all
+ its dependencies when parsing options, as it adds overhead and breaks tests.
"""
import doctest
@@ -679,8 +718,7 @@ def _get_report_choice(key: str) -> int:
@pytest.fixture(scope="session")
-def doctest_namespace():
- """
- Fixture that returns a :py:class:`dict` that will be injected into the namespace of doctests.
- """
+def doctest_namespace() -> Dict[str, Any]:
+ """Fixture that returns a :py:class:`dict` that will be injected into the
+ namespace of doctests."""
return dict()
diff --git a/contrib/python/pytest/py3/_pytest/faulthandler.py b/contrib/python/pytest/py3/_pytest/faulthandler.py
index 8d723c206c..ff673b5b16 100644
--- a/contrib/python/pytest/py3/_pytest/faulthandler.py
+++ b/contrib/python/pytest/py3/_pytest/faulthandler.py
@@ -1,25 +1,28 @@
import io
import os
import sys
+from typing import Generator
from typing import TextIO
import pytest
+from _pytest.config import Config
+from _pytest.config.argparsing import Parser
+from _pytest.nodes import Item
from _pytest.store import StoreKey
fault_handler_stderr_key = StoreKey[TextIO]()
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
help = (
"Dump the traceback of all threads if a test takes "
- "more than TIMEOUT seconds to finish.\n"
- "Not available on Windows."
+ "more than TIMEOUT seconds to finish."
)
parser.addini("faulthandler_timeout", help, default=0.0)
-def pytest_configure(config):
+def pytest_configure(config: Config) -> None:
import faulthandler
if not faulthandler.is_enabled():
@@ -27,18 +30,15 @@ def pytest_configure(config):
# of enabling faulthandler before each test executes.
config.pluginmanager.register(FaultHandlerHooks(), "faulthandler-hooks")
else:
- from _pytest.warnings import _issue_warning_captured
-
# Do not handle dumping to stderr if faulthandler is already enabled, so warn
# users that the option is being ignored.
timeout = FaultHandlerHooks.get_timeout_config_value(config)
if timeout > 0:
- _issue_warning_captured(
+ config.issue_config_time_warning(
pytest.PytestConfigWarning(
"faulthandler module enabled before pytest configuration step, "
"'faulthandler_timeout' option ignored"
),
- config.hook,
stacklevel=2,
)
@@ -47,14 +47,14 @@ class FaultHandlerHooks:
"""Implements hooks that will actually install fault handler before tests execute,
as well as correctly handle pdb and internal errors."""
- def pytest_configure(self, config):
+ def pytest_configure(self, config: Config) -> None:
import faulthandler
stderr_fd_copy = os.dup(self._get_stderr_fileno())
config._store[fault_handler_stderr_key] = open(stderr_fd_copy, "w")
faulthandler.enable(file=config._store[fault_handler_stderr_key])
- def pytest_unconfigure(self, config):
+ def pytest_unconfigure(self, config: Config) -> None:
import faulthandler
faulthandler.disable()
@@ -69,7 +69,12 @@ class FaultHandlerHooks:
@staticmethod
def _get_stderr_fileno():
try:
- return sys.stderr.fileno()
+ fileno = sys.stderr.fileno()
+ # The Twisted Logger will return an invalid file descriptor since it is not backed
+ # by an FD. So, let's also forward this to the same code path as with pytest-xdist.
+ if fileno == -1:
+ raise AttributeError()
+ return fileno
except (AttributeError, io.UnsupportedOperation):
# pytest-xdist monkeypatches sys.stderr with an object that is not an actual file.
# https://docs.python.org/3/library/faulthandler.html#issue-with-file-descriptors
@@ -80,8 +85,8 @@ class FaultHandlerHooks:
def get_timeout_config_value(config):
return float(config.getini("faulthandler_timeout") or 0.0)
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_protocol(self, item):
+ @pytest.hookimpl(hookwrapper=True, trylast=True)
+ def pytest_runtest_protocol(self, item: Item) -> Generator[None, None, None]:
timeout = self.get_timeout_config_value(item.config)
stderr = item.config._store[fault_handler_stderr_key]
if timeout > 0 and stderr is not None:
@@ -96,18 +101,16 @@ class FaultHandlerHooks:
yield
@pytest.hookimpl(tryfirst=True)
- def pytest_enter_pdb(self):
- """Cancel any traceback dumping due to timeout before entering pdb.
- """
+ def pytest_enter_pdb(self) -> None:
+ """Cancel any traceback dumping due to timeout before entering pdb."""
import faulthandler
faulthandler.cancel_dump_traceback_later()
@pytest.hookimpl(tryfirst=True)
- def pytest_exception_interact(self):
+ def pytest_exception_interact(self) -> None:
"""Cancel any traceback dumping due to an interactive exception being
- raised.
- """
+ raised."""
import faulthandler
faulthandler.cancel_dump_traceback_later()
diff --git a/contrib/python/pytest/py3/_pytest/fixtures.py b/contrib/python/pytest/py3/_pytest/fixtures.py
index 22964770d2..273bcafd39 100644
--- a/contrib/python/pytest/py3/_pytest/fixtures.py
+++ b/contrib/python/pytest/py3/_pytest/fixtures.py
@@ -1,25 +1,43 @@
import functools
import inspect
-import itertools
+import os
import sys
import warnings
from collections import defaultdict
from collections import deque
-from collections import OrderedDict
+from types import TracebackType
+from typing import Any
+from typing import Callable
+from typing import cast
from typing import Dict
+from typing import Generator
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
import attr
import py
import _pytest
+from _pytest import nodes
+from _pytest._code import getfslineno
from _pytest._code.code import FormattedExcinfo
from _pytest._code.code import TerminalRepr
-from _pytest._code.source import getfslineno
from _pytest._io import TerminalWriter
from _pytest.compat import _format_args
from _pytest.compat import _PytestWrapper
+from _pytest.compat import assert_never
+from _pytest.compat import final
from _pytest.compat import get_real_func
from _pytest.compat import get_real_method
from _pytest.compat import getfuncargnames
@@ -28,69 +46,71 @@ from _pytest.compat import getlocation
from _pytest.compat import is_generator
from _pytest.compat import NOTSET
from _pytest.compat import safe_getattr
-from _pytest.compat import TYPE_CHECKING
-from _pytest.deprecated import FIXTURE_POSITIONAL_ARGUMENTS
-from _pytest.deprecated import FUNCARGNAMES
+from _pytest.config import _PluggyPlugin
+from _pytest.config import Config
+from _pytest.config.argparsing import Parser
+from _pytest.deprecated import check_ispytest
+from _pytest.deprecated import FILLFUNCARGS
+from _pytest.deprecated import YIELD_FIXTURE
+from _pytest.mark import Mark
from _pytest.mark import ParameterSet
+from _pytest.mark.structures import MarkDecorator
from _pytest.outcomes import fail
from _pytest.outcomes import TEST_OUTCOME
+from _pytest.pathlib import absolutepath
+from _pytest.store import StoreKey
if TYPE_CHECKING:
- from typing import Type
+ from typing import Deque
+ from typing import NoReturn
+ from typing_extensions import Literal
- from _pytest import nodes
from _pytest.main import Session
+ from _pytest.python import CallSpec2
+ from _pytest.python import Function
+ from _pytest.python import Metafunc
+
+ _Scope = Literal["session", "package", "module", "class", "function"]
+
+
+# The value of the fixture -- return/yield of the fixture function (type variable).
+_FixtureValue = TypeVar("_FixtureValue")
+# The type of the fixture function (type variable).
+_FixtureFunction = TypeVar("_FixtureFunction", bound=Callable[..., object])
+# The type of a fixture function (type alias generic in fixture value).
+_FixtureFunc = Union[
+ Callable[..., _FixtureValue], Callable[..., Generator[_FixtureValue, None, None]]
+]
+# The type of FixtureDef.cached_result (type alias generic in fixture value).
+_FixtureCachedResult = Union[
+ Tuple[
+ # The result.
+ _FixtureValue,
+ # Cache key.
+ object,
+ None,
+ ],
+ Tuple[
+ None,
+ # Cache key.
+ object,
+ # Exc info if raised.
+ Tuple[Type[BaseException], BaseException, TracebackType],
+ ],
+]
@attr.s(frozen=True)
-class PseudoFixtureDef:
- cached_result = attr.ib()
- scope = attr.ib()
+class PseudoFixtureDef(Generic[_FixtureValue]):
+ cached_result = attr.ib(type="_FixtureCachedResult[_FixtureValue]")
+ scope = attr.ib(type="_Scope")
-def pytest_sessionstart(session: "Session"):
- import _pytest.python
- import _pytest.nodes
-
- scopename2class.update(
- {
- "package": _pytest.python.Package,
- "class": _pytest.python.Class,
- "module": _pytest.python.Module,
- "function": _pytest.nodes.Item,
- "session": _pytest.main.Session,
- }
- )
+def pytest_sessionstart(session: "Session") -> None:
session._fixturemanager = FixtureManager(session)
-scopename2class = {} # type: Dict[str, Type[nodes.Node]]
-
-scope2props = dict(session=()) # type: Dict[str, Tuple[str, ...]]
-scope2props["package"] = ("fspath",)
-scope2props["module"] = ("fspath", "module")
-scope2props["class"] = scope2props["module"] + ("cls",)
-scope2props["instance"] = scope2props["class"] + ("instance",)
-scope2props["function"] = scope2props["instance"] + ("function", "keywords")
-
-
-def scopeproperty(name=None, doc=None):
- def decoratescope(func):
- scopename = name or func.__name__
-
- def provide(self):
- if func.__name__ in scope2props[self.scope]:
- return func(self)
- raise AttributeError(
- "{} not available in {}-scoped context".format(scopename, self.scope)
- )
-
- return property(provide, None, None, func.__doc__)
-
- return decoratescope
-
-
-def get_scope_package(node, fixturedef):
+def get_scope_package(node, fixturedef: "FixtureDef[object]"):
import pytest
cls = pytest.Package
@@ -105,25 +125,44 @@ def get_scope_package(node, fixturedef):
return current
-def get_scope_node(node, scope):
- cls = scopename2class.get(scope)
- if cls is None:
- raise ValueError("unknown scope")
- return node.getparent(cls)
+def get_scope_node(
+ node: nodes.Node, scope: "_Scope"
+) -> Optional[Union[nodes.Item, nodes.Collector]]:
+ import _pytest.python
+
+ if scope == "function":
+ return node.getparent(nodes.Item)
+ elif scope == "class":
+ return node.getparent(_pytest.python.Class)
+ elif scope == "module":
+ return node.getparent(_pytest.python.Module)
+ elif scope == "package":
+ return node.getparent(_pytest.python.Package)
+ elif scope == "session":
+ return node.getparent(_pytest.main.Session)
+ else:
+ assert_never(scope)
+
+
+# Used for storing artificial fixturedefs for direct parametrization.
+name2pseudofixturedef_key = StoreKey[Dict[str, "FixtureDef[Any]"]]()
-def add_funcarg_pseudo_fixture_def(collector, metafunc, fixturemanager):
- # this function will transform all collected calls to a functions
+def add_funcarg_pseudo_fixture_def(
+ collector: nodes.Collector, metafunc: "Metafunc", fixturemanager: "FixtureManager"
+) -> None:
+ # This function will transform all collected calls to functions
# if they use direct funcargs (i.e. direct parametrization)
# because we want later test execution to be able to rely on
# an existing FixtureDef structure for all arguments.
# XXX we can probably avoid this algorithm if we modify CallSpec2
# to directly care for creating the fixturedefs within its methods.
if not metafunc._calls[0].funcargs:
- return # this function call does not have direct parametrization
- # collect funcargs of all callspecs into a list of values
- arg2params = {}
- arg2scope = {}
+ # This function call does not have direct parametrization.
+ return
+ # Collect funcargs of all callspecs into a list of values.
+ arg2params: Dict[str, List[object]] = {}
+ arg2scope: Dict[str, _Scope] = {}
for callspec in metafunc._calls:
for argname, argvalue in callspec.funcargs.items():
assert argname not in callspec.params
@@ -136,11 +175,11 @@ def add_funcarg_pseudo_fixture_def(collector, metafunc, fixturemanager):
arg2scope[argname] = scopes[scopenum]
callspec.funcargs.clear()
- # register artificial FixtureDef's so that later at test execution
+ # Register artificial FixtureDef's so that later at test execution
# time we can rely on a proper FixtureDef to exist for fixture setup.
arg2fixturedefs = metafunc._arg2fixturedefs
for argname, valuelist in arg2params.items():
- # if we have a scope that is higher than function we need
+ # If we have a scope that is higher than function, we need
# to make sure we only ever create an according fixturedef on
# a per-scope basis. We thus store and cache the fixturedef on the
# node related to the scope.
@@ -150,46 +189,61 @@ def add_funcarg_pseudo_fixture_def(collector, metafunc, fixturemanager):
node = get_scope_node(collector, scope)
if node is None:
assert scope == "class" and isinstance(collector, _pytest.python.Module)
- # use module-level collector for class-scope (for now)
+ # Use module-level collector for class-scope (for now).
node = collector
- if node and argname in node._name2pseudofixturedef:
- arg2fixturedefs[argname] = [node._name2pseudofixturedef[argname]]
+ if node is None:
+ name2pseudofixturedef = None
+ else:
+ default: Dict[str, FixtureDef[Any]] = {}
+ name2pseudofixturedef = node._store.setdefault(
+ name2pseudofixturedef_key, default
+ )
+ if name2pseudofixturedef is not None and argname in name2pseudofixturedef:
+ arg2fixturedefs[argname] = [name2pseudofixturedef[argname]]
else:
fixturedef = FixtureDef(
- fixturemanager,
- "",
- argname,
- get_direct_param_fixture_func,
- arg2scope[argname],
- valuelist,
- False,
- False,
+ fixturemanager=fixturemanager,
+ baseid="",
+ argname=argname,
+ func=get_direct_param_fixture_func,
+ scope=arg2scope[argname],
+ params=valuelist,
+ unittest=False,
+ ids=None,
)
arg2fixturedefs[argname] = [fixturedef]
- if node is not None:
- node._name2pseudofixturedef[argname] = fixturedef
+ if name2pseudofixturedef is not None:
+ name2pseudofixturedef[argname] = fixturedef
-def getfixturemarker(obj):
- """ return fixturemarker or None if it doesn't exist or raised
+def getfixturemarker(obj: object) -> Optional["FixtureFunctionMarker"]:
+ """Return fixturemarker or None if it doesn't exist or raised
exceptions."""
try:
- return getattr(obj, "_pytestfixturefunction", None)
+ fixturemarker: Optional[FixtureFunctionMarker] = getattr(
+ obj, "_pytestfixturefunction", None
+ )
except TEST_OUTCOME:
# some objects raise errors like request (from flask import request)
# we don't expect them to be fixture functions
return None
+ return fixturemarker
+
+# Parametrized fixture key, helper alias for code below.
+_Key = Tuple[object, ...]
-def get_parametrized_fixture_keys(item, scopenum):
- """ return list of keys for all parametrized arguments which match
+
+def get_parametrized_fixture_keys(item: nodes.Item, scopenum: int) -> Iterator[_Key]:
+ """Return list of keys for all parametrized arguments which match
the specified scope. """
assert scopenum < scopenum_function # function
try:
- cs = item.callspec
+ callspec = item.callspec # type: ignore[attr-defined]
except AttributeError:
pass
else:
+ cs: CallSpec2 = callspec
# cs.indices.items() is random order of argnames. Need to
# sort this so that different calls to
# get_parametrized_fixture_keys will be deterministic.
@@ -197,67 +251,80 @@ def get_parametrized_fixture_keys(item, scopenum):
if cs._arg2scopenum[argname] != scopenum:
continue
if scopenum == 0: # session
- key = (argname, param_index)
+ key: _Key = (argname, param_index)
elif scopenum == 1: # package
key = (argname, param_index, item.fspath.dirpath())
elif scopenum == 2: # module
key = (argname, param_index, item.fspath)
elif scopenum == 3: # class
- key = (argname, param_index, item.fspath, item.cls)
+ item_cls = item.cls # type: ignore[attr-defined]
+ key = (argname, param_index, item.fspath, item_cls)
yield key
-# algorithm for sorting on a per-parametrized resource setup basis
-# it is called for scopenum==0 (session) first and performs sorting
+# Algorithm for sorting on a per-parametrized resource setup basis.
+# It is called for scopenum==0 (session) first and performs sorting
# down to the lower scopes such as to minimize number of "high scope"
-# setups and teardowns
+# setups and teardowns.
-def reorder_items(items):
- argkeys_cache = {}
- items_by_argkey = {}
+def reorder_items(items: Sequence[nodes.Item]) -> List[nodes.Item]:
+ argkeys_cache: Dict[int, Dict[nodes.Item, Dict[_Key, None]]] = {}
+ items_by_argkey: Dict[int, Dict[_Key, Deque[nodes.Item]]] = {}
for scopenum in range(0, scopenum_function):
- argkeys_cache[scopenum] = d = {}
- items_by_argkey[scopenum] = item_d = defaultdict(deque)
+ d: Dict[nodes.Item, Dict[_Key, None]] = {}
+ argkeys_cache[scopenum] = d
+ item_d: Dict[_Key, Deque[nodes.Item]] = defaultdict(deque)
+ items_by_argkey[scopenum] = item_d
for item in items:
- keys = OrderedDict.fromkeys(get_parametrized_fixture_keys(item, scopenum))
+ keys = dict.fromkeys(get_parametrized_fixture_keys(item, scopenum), None)
if keys:
d[item] = keys
for key in keys:
item_d[key].append(item)
- items = OrderedDict.fromkeys(items)
- return list(reorder_items_atscope(items, argkeys_cache, items_by_argkey, 0))
+ items_dict = dict.fromkeys(items, None)
+ return list(reorder_items_atscope(items_dict, argkeys_cache, items_by_argkey, 0))
-def fix_cache_order(item, argkeys_cache, items_by_argkey):
+def fix_cache_order(
+ item: nodes.Item,
+ argkeys_cache: Dict[int, Dict[nodes.Item, Dict[_Key, None]]],
+ items_by_argkey: Dict[int, Dict[_Key, "Deque[nodes.Item]"]],
+) -> None:
for scopenum in range(0, scopenum_function):
for key in argkeys_cache[scopenum].get(item, []):
items_by_argkey[scopenum][key].appendleft(item)
-def reorder_items_atscope(items, argkeys_cache, items_by_argkey, scopenum):
+def reorder_items_atscope(
+ items: Dict[nodes.Item, None],
+ argkeys_cache: Dict[int, Dict[nodes.Item, Dict[_Key, None]]],
+ items_by_argkey: Dict[int, Dict[_Key, "Deque[nodes.Item]"]],
+ scopenum: int,
+) -> Dict[nodes.Item, None]:
if scopenum >= scopenum_function or len(items) < 3:
return items
- ignore = set()
+ ignore: Set[Optional[_Key]] = set()
items_deque = deque(items)
- items_done = OrderedDict()
+ items_done: Dict[nodes.Item, None] = {}
scoped_items_by_argkey = items_by_argkey[scopenum]
scoped_argkeys_cache = argkeys_cache[scopenum]
while items_deque:
- no_argkey_group = OrderedDict()
+ no_argkey_group: Dict[nodes.Item, None] = {}
slicing_argkey = None
while items_deque:
item = items_deque.popleft()
if item in items_done or item in no_argkey_group:
continue
- argkeys = OrderedDict.fromkeys(
- k for k in scoped_argkeys_cache.get(item, []) if k not in ignore
+ argkeys = dict.fromkeys(
+ (k for k in scoped_argkeys_cache.get(item, []) if k not in ignore), None
)
if not argkeys:
no_argkey_group[item] = None
else:
slicing_argkey, _ = argkeys.popitem()
- # we don't have to remove relevant items from later in the deque because they'll just be ignored
+ # We don't have to remove relevant items from later in the
+ # deque because they'll just be ignored.
matching_items = [
i for i in scoped_items_by_argkey[slicing_argkey] if i in items
]
@@ -275,8 +342,22 @@ def reorder_items_atscope(items, argkeys_cache, items_by_argkey, scopenum):
return items_done
-def fillfixtures(function):
- """ fill missing funcargs for a test function. """
+def _fillfuncargs(function: "Function") -> None:
+ """Fill missing fixtures for a test function, old public API (deprecated)."""
+ warnings.warn(FILLFUNCARGS.format(name="pytest._fillfuncargs()"), stacklevel=2)
+ _fill_fixtures_impl(function)
+
+
+def fillfixtures(function: "Function") -> None:
+ """Fill missing fixtures for a test function (deprecated)."""
+ warnings.warn(
+ FILLFUNCARGS.format(name="_pytest.fixtures.fillfixtures()"), stacklevel=2
+ )
+ _fill_fixtures_impl(function)
+
+
+def _fill_fixtures_impl(function: "Function") -> None:
+ """Internal implementation to fill fixtures on the given function object."""
try:
request = function._request
except AttributeError:
@@ -284,11 +365,12 @@ def fillfixtures(function):
# with the oejskit plugin. It uses classes with funcargs
# and we thus have to work a bit to allow this.
fm = function.session._fixturemanager
+ assert function.parent is not None
fi = fm.getfixtureinfo(function.parent, function.obj, None)
function._fixtureinfo = fi
- request = function._request = FixtureRequest(function)
+ request = function._request = FixtureRequest(function, _ispytest=True)
request._fillfixtures()
- # prune out funcargs for jstests
+ # Prune out funcargs for jstests.
newfuncargs = {}
for name in fi.argnames:
newfuncargs[name] = function.funcargs[name]
@@ -303,17 +385,17 @@ def get_direct_param_fixture_func(request):
@attr.s(slots=True)
class FuncFixtureInfo:
- # original function argument names
- argnames = attr.ib(type=tuple)
- # argnames that function immediately requires. These include argnames +
+ # Original function argument names.
+ argnames = attr.ib(type=Tuple[str, ...])
+ # Argnames that function immediately requires. These include argnames +
# fixture names specified via usefixtures and via autouse=True in fixture
# definitions.
- initialnames = attr.ib(type=tuple)
- names_closure = attr.ib() # List[str]
- name2fixturedefs = attr.ib() # List[str, List[FixtureDef]]
+ initialnames = attr.ib(type=Tuple[str, ...])
+ names_closure = attr.ib(type=List[str])
+ name2fixturedefs = attr.ib(type=Dict[str, Sequence["FixtureDef[Any]"]])
- def prune_dependency_tree(self):
- """Recompute names_closure from initialnames and name2fixturedefs
+ def prune_dependency_tree(self) -> None:
+ """Recompute names_closure from initialnames and name2fixturedefs.
Can only reduce names_closure, which means that the new closure will
always be a subset of the old one. The order is preserved.
@@ -323,11 +405,11 @@ class FuncFixtureInfo:
tree. In this way the dependency tree can get pruned, and the closure
of argnames may get reduced.
"""
- closure = set()
+ closure: Set[str] = set()
working_set = set(self.initialnames)
while working_set:
argname = working_set.pop()
- # argname may be smth not included in the original names_closure,
+ # Argname may be smth not included in the original names_closure,
# in which case we ignore it. This currently happens with pseudo
# FixtureDefs which wrap 'get_direct_param_fixture_func(request)'.
# So they introduce the new dependency 'request' which might have
@@ -341,53 +423,51 @@ class FuncFixtureInfo:
class FixtureRequest:
- """ A request for a fixture from a test or fixture function.
+ """A request for a fixture from a test or fixture function.
- A request object gives access to the requesting test context
- and has an optional ``param`` attribute in case
- the fixture is parametrized indirectly.
+ A request object gives access to the requesting test context and has
+ an optional ``param`` attribute in case the fixture is parametrized
+ indirectly.
"""
- def __init__(self, pyfuncitem):
+ def __init__(self, pyfuncitem, *, _ispytest: bool = False) -> None:
+ check_ispytest(_ispytest)
self._pyfuncitem = pyfuncitem
- #: fixture for which this request is being performed
- self.fixturename = None
- #: Scope string, one of "function", "class", "module", "session"
- self.scope = "function"
- self._fixture_defs = {} # type: Dict[str, FixtureDef]
- fixtureinfo = pyfuncitem._fixtureinfo
+ #: Fixture for which this request is being performed.
+ self.fixturename: Optional[str] = None
+ #: Scope string, one of "function", "class", "module", "session".
+ self.scope: _Scope = "function"
+ self._fixture_defs: Dict[str, FixtureDef[Any]] = {}
+ fixtureinfo: FuncFixtureInfo = pyfuncitem._fixtureinfo
self._arg2fixturedefs = fixtureinfo.name2fixturedefs.copy()
- self._arg2index = {}
- self._fixturemanager = pyfuncitem.session._fixturemanager
+ self._arg2index: Dict[str, int] = {}
+ self._fixturemanager: FixtureManager = (pyfuncitem.session._fixturemanager)
@property
- def fixturenames(self):
- """names of all active fixtures in this request"""
+ def fixturenames(self) -> List[str]:
+ """Names of all active fixtures in this request."""
result = list(self._pyfuncitem._fixtureinfo.names_closure)
result.extend(set(self._fixture_defs).difference(result))
return result
@property
- def funcargnames(self):
- """ alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
- warnings.warn(FUNCARGNAMES, stacklevel=2)
- return self.fixturenames
-
- @property
def node(self):
- """ underlying collection node (depends on current request scope)"""
+ """Underlying collection node (depends on current request scope)."""
return self._getscopeitem(self.scope)
- def _getnextfixturedef(self, argname):
+ def _getnextfixturedef(self, argname: str) -> "FixtureDef[Any]":
fixturedefs = self._arg2fixturedefs.get(argname, None)
if fixturedefs is None:
- # we arrive here because of a dynamic call to
+ # We arrive here because of a dynamic call to
# getfixturevalue(argname) usage which was naturally
- # not known at parsing/collection time
+ # not known at parsing/collection time.
+ assert self._pyfuncitem.parent is not None
parentid = self._pyfuncitem.parent.nodeid
fixturedefs = self._fixturemanager.getfixturedefs(argname, parentid)
- self._arg2fixturedefs[argname] = fixturedefs
- # fixturedefs list is immutable so we maintain a decreasing index
+ # TODO: Fix this type ignore. Either add assert or adjust types.
+ # Can this be None here?
+ self._arg2fixturedefs[argname] = fixturedefs # type: ignore[assignment]
+ # fixturedefs list is immutable so we maintain a decreasing index.
index = self._arg2index.get(argname, 0) - 1
if fixturedefs is None or (-index > len(fixturedefs)):
raise FixtureLookupError(argname, self)
@@ -395,98 +475,116 @@ class FixtureRequest:
return fixturedefs[index]
@property
- def config(self):
- """ the pytest config object associated with this request. """
- return self._pyfuncitem.config
+ def config(self) -> Config:
+ """The pytest config object associated with this request."""
+ return self._pyfuncitem.config # type: ignore[no-any-return]
- @scopeproperty()
+ @property
def function(self):
- """ test function object if the request has a per-function scope. """
+ """Test function object if the request has a per-function scope."""
+ if self.scope != "function":
+ raise AttributeError(
+ f"function not available in {self.scope}-scoped context"
+ )
return self._pyfuncitem.obj
- @scopeproperty("class")
+ @property
def cls(self):
- """ class (can be None) where the test function was collected. """
+ """Class (can be None) where the test function was collected."""
+ if self.scope not in ("class", "function"):
+ raise AttributeError(f"cls not available in {self.scope}-scoped context")
clscol = self._pyfuncitem.getparent(_pytest.python.Class)
if clscol:
return clscol.obj
@property
def instance(self):
- """ instance (can be None) on which test function was collected. """
- # unittest support hack, see _pytest.unittest.TestCaseFunction
+ """Instance (can be None) on which test function was collected."""
+ # unittest support hack, see _pytest.unittest.TestCaseFunction.
try:
return self._pyfuncitem._testcase
except AttributeError:
function = getattr(self, "function", None)
return getattr(function, "__self__", None)
- @scopeproperty()
+ @property
def module(self):
- """ python module object where the test function was collected. """
+ """Python module object where the test function was collected."""
+ if self.scope not in ("function", "class", "module"):
+ raise AttributeError(f"module not available in {self.scope}-scoped context")
return self._pyfuncitem.getparent(_pytest.python.Module).obj
- @scopeproperty()
+ @property
def fspath(self) -> py.path.local:
- """ the file system path of the test module which collected this test. """
+ """The file system path of the test module which collected this test."""
+ if self.scope not in ("function", "class", "module", "package"):
+ raise AttributeError(f"module not available in {self.scope}-scoped context")
# TODO: Remove ignore once _pyfuncitem is properly typed.
return self._pyfuncitem.fspath # type: ignore
@property
def keywords(self):
- """ keywords/markers dictionary for the underlying node. """
+ """Keywords/markers dictionary for the underlying node."""
return self.node.keywords
@property
- def session(self):
- """ pytest session object. """
- return self._pyfuncitem.session
-
- def addfinalizer(self, finalizer):
- """ add finalizer/teardown function to be called after the
- last test within the requesting test context finished
- execution. """
- # XXX usually this method is shadowed by fixturedef specific ones
+ def session(self) -> "Session":
+ """Pytest session object."""
+ return self._pyfuncitem.session # type: ignore[no-any-return]
+
+ def addfinalizer(self, finalizer: Callable[[], object]) -> None:
+ """Add finalizer/teardown function to be called after the last test
+ within the requesting test context finished execution."""
+ # XXX usually this method is shadowed by fixturedef specific ones.
self._addfinalizer(finalizer, scope=self.scope)
- def _addfinalizer(self, finalizer, scope):
+ def _addfinalizer(self, finalizer: Callable[[], object], scope) -> None:
colitem = self._getscopeitem(scope)
self._pyfuncitem.session._setupstate.addfinalizer(
finalizer=finalizer, colitem=colitem
)
- def applymarker(self, marker):
- """ Apply a marker to a single test function invocation.
+ def applymarker(self, marker: Union[str, MarkDecorator]) -> None:
+ """Apply a marker to a single test function invocation.
+
This method is useful if you don't want to have a keyword/marker
on all function invocations.
- :arg marker: a :py:class:`_pytest.mark.MarkDecorator` object
- created by a call to ``pytest.mark.NAME(...)``.
+ :param marker:
+ A :py:class:`_pytest.mark.MarkDecorator` object created by a call
+ to ``pytest.mark.NAME(...)``.
"""
self.node.add_marker(marker)
- def raiseerror(self, msg):
- """ raise a FixtureLookupError with the given message. """
+ def raiseerror(self, msg: Optional[str]) -> "NoReturn":
+ """Raise a FixtureLookupError with the given message."""
raise self._fixturemanager.FixtureLookupError(None, self, msg)
- def _fillfixtures(self):
+ def _fillfixtures(self) -> None:
item = self._pyfuncitem
fixturenames = getattr(item, "fixturenames", self.fixturenames)
for argname in fixturenames:
if argname not in item.funcargs:
item.funcargs[argname] = self.getfixturevalue(argname)
- def getfixturevalue(self, argname):
- """ Dynamically run a named fixture function.
+ def getfixturevalue(self, argname: str) -> Any:
+ """Dynamically run a named fixture function.
Declaring fixtures via function argument is recommended where possible.
But if you can only decide whether to use another fixture at test
setup time, you may use this function to retrieve it inside a fixture
or test function body.
+
+ :raises pytest.FixtureLookupError:
+ If the given fixture could not be found.
"""
- return self._get_active_fixturedef(argname).cached_result[0]
+ fixturedef = self._get_active_fixturedef(argname)
+ assert fixturedef.cached_result is not None
+ return fixturedef.cached_result[0]
- def _get_active_fixturedef(self, argname):
+ def _get_active_fixturedef(
+ self, argname: str
+ ) -> Union["FixtureDef[object]", PseudoFixtureDef[object]]:
try:
return self._fixture_defs[argname]
except KeyError:
@@ -495,31 +593,34 @@ class FixtureRequest:
except FixtureLookupError:
if argname == "request":
cached_result = (self, [0], None)
- scope = "function"
+ scope: _Scope = "function"
return PseudoFixtureDef(cached_result, scope)
raise
- # remove indent to prevent the python3 exception
- # from leaking into the call
+ # Remove indent to prevent the python3 exception
+ # from leaking into the call.
self._compute_fixture_value(fixturedef)
self._fixture_defs[argname] = fixturedef
return fixturedef
- def _get_fixturestack(self):
+ def _get_fixturestack(self) -> List["FixtureDef[Any]"]:
current = self
- values = []
+ values: List[FixtureDef[Any]] = []
while 1:
fixturedef = getattr(current, "_fixturedef", None)
if fixturedef is None:
values.reverse()
return values
values.append(fixturedef)
+ assert isinstance(current, SubRequest)
current = current._parent_request
- def _compute_fixture_value(self, fixturedef: "FixtureDef") -> None:
- """
- Creates a SubRequest based on "self" and calls the execute method of the given fixturedef object. This will
- force the FixtureDef object to throw away any previous results and compute a new fixture value, which
- will be stored into the FixtureDef object itself.
+ def _compute_fixture_value(self, fixturedef: "FixtureDef[object]") -> None:
+ """Create a SubRequest based on "self" and call the execute method
+ of the given FixtureDef object.
+
+ This will force the FixtureDef object to throw away any previous
+ results and compute a new fixture value, which will be stored into
+ the FixtureDef object itself.
"""
# prepare a subrequest object before calling fixture function
# (latter managed by fixturedef)
@@ -569,33 +670,39 @@ class FixtureRequest:
fail(msg, pytrace=False)
else:
param_index = funcitem.callspec.indices[argname]
- # if a parametrize invocation set a scope it will override
- # the static scope defined with the fixture function
+ # If a parametrize invocation set a scope it will override
+ # the static scope defined with the fixture function.
paramscopenum = funcitem.callspec._arg2scopenum.get(argname)
if paramscopenum is not None:
scope = scopes[paramscopenum]
- subrequest = SubRequest(self, scope, param, param_index, fixturedef)
+ subrequest = SubRequest(
+ self, scope, param, param_index, fixturedef, _ispytest=True
+ )
- # check if a higher-level scoped fixture accesses a lower level one
+ # Check if a higher-level scoped fixture accesses a lower level one.
subrequest._check_scope(argname, self.scope, scope)
try:
- # call the fixture function
+ # Call the fixture function.
fixturedef.execute(request=subrequest)
finally:
self._schedule_finalizers(fixturedef, subrequest)
- def _schedule_finalizers(self, fixturedef, subrequest):
- # if fixture function failed it might have registered finalizers
+ def _schedule_finalizers(
+ self, fixturedef: "FixtureDef[object]", subrequest: "SubRequest"
+ ) -> None:
+ # If fixture function failed it might have registered finalizers.
self.session._setupstate.addfinalizer(
functools.partial(fixturedef.finish, request=subrequest), subrequest.node
)
- def _check_scope(self, argname, invoking_scope, requested_scope):
+ def _check_scope(
+ self, argname: str, invoking_scope: "_Scope", requested_scope: "_Scope",
+ ) -> None:
if argname == "request":
return
if scopemismatch(invoking_scope, requested_scope):
- # try to report something helpful
+ # Try to report something helpful.
lines = self._factorytraceback()
fail(
"ScopeMismatch: You tried to access the %r scoped "
@@ -605,7 +712,7 @@ class FixtureRequest:
pytrace=False,
)
- def _factorytraceback(self):
+ def _factorytraceback(self) -> List[str]:
lines = []
for fixturedef in self._get_fixturestack():
factory = fixturedef.func
@@ -615,31 +722,43 @@ class FixtureRequest:
lines.append("%s:%d: def %s%s" % (p, lineno + 1, factory.__name__, args))
return lines
- def _getscopeitem(self, scope):
+ def _getscopeitem(self, scope: "_Scope") -> Union[nodes.Item, nodes.Collector]:
if scope == "function":
- # this might also be a non-function Item despite its attribute name
- return self._pyfuncitem
- if scope == "package":
- node = get_scope_package(self._pyfuncitem, self._fixturedef)
+ # This might also be a non-function Item despite its attribute name.
+ node: Optional[Union[nodes.Item, nodes.Collector]] = self._pyfuncitem
+ elif scope == "package":
+ # FIXME: _fixturedef is not defined on FixtureRequest (this class),
+ # but on FixtureRequest (a subclass).
+ node = get_scope_package(self._pyfuncitem, self._fixturedef) # type: ignore[attr-defined]
else:
node = get_scope_node(self._pyfuncitem, scope)
if node is None and scope == "class":
- # fallback to function item itself
+ # Fallback to function item itself.
node = self._pyfuncitem
assert node, 'Could not obtain a node for scope "{}" for function {!r}'.format(
scope, self._pyfuncitem
)
return node
- def __repr__(self):
+ def __repr__(self) -> str:
return "<FixtureRequest for %r>" % (self.node)
+@final
class SubRequest(FixtureRequest):
- """ a sub request for handling getting a fixture from a
- test function/fixture. """
+ """A sub request for handling getting a fixture from a test function/fixture."""
- def __init__(self, request, scope, param, param_index, fixturedef):
+ def __init__(
+ self,
+ request: "FixtureRequest",
+ scope: "_Scope",
+ param,
+ param_index: int,
+ fixturedef: "FixtureDef[object]",
+ *,
+ _ispytest: bool = False,
+ ) -> None:
+ check_ispytest(_ispytest)
self._parent_request = request
self.fixturename = fixturedef.argname
if param is not NOTSET:
@@ -653,16 +772,20 @@ class SubRequest(FixtureRequest):
self._arg2index = request._arg2index
self._fixturemanager = request._fixturemanager
- def __repr__(self):
- return "<SubRequest {!r} for {!r}>".format(self.fixturename, self._pyfuncitem)
+ def __repr__(self) -> str:
+ return f"<SubRequest {self.fixturename!r} for {self._pyfuncitem!r}>"
- def addfinalizer(self, finalizer):
+ def addfinalizer(self, finalizer: Callable[[], object]) -> None:
+ """Add finalizer/teardown function to be called after the last test
+ within the requesting test context finished execution."""
self._fixturedef.addfinalizer(finalizer)
- def _schedule_finalizers(self, fixturedef, subrequest):
- # if the executing fixturedef was not explicitly requested in the argument list (via
+ def _schedule_finalizers(
+ self, fixturedef: "FixtureDef[object]", subrequest: "SubRequest"
+ ) -> None:
+ # If the executing fixturedef was not explicitly requested in the argument list (via
# getfixturevalue inside the fixture call) then ensure this fixture def will be finished
- # first
+ # first.
if fixturedef.argname not in self.fixturenames:
fixturedef.addfinalizer(
functools.partial(self._fixturedef.finish, request=self)
@@ -670,53 +793,56 @@ class SubRequest(FixtureRequest):
super()._schedule_finalizers(fixturedef, subrequest)
-scopes = "session package module class function".split()
+scopes: List["_Scope"] = ["session", "package", "module", "class", "function"]
scopenum_function = scopes.index("function")
-def scopemismatch(currentscope, newscope):
+def scopemismatch(currentscope: "_Scope", newscope: "_Scope") -> bool:
return scopes.index(newscope) > scopes.index(currentscope)
-def scope2index(scope, descr, where=None):
+def scope2index(scope: str, descr: str, where: Optional[str] = None) -> int:
"""Look up the index of ``scope`` and raise a descriptive value error
- if not defined.
- """
+ if not defined."""
+ strscopes: Sequence[str] = scopes
try:
- return scopes.index(scope)
+ return strscopes.index(scope)
except ValueError:
fail(
"{} {}got an unexpected scope value '{}'".format(
- descr, "from {} ".format(where) if where else "", scope
+ descr, f"from {where} " if where else "", scope
),
pytrace=False,
)
+@final
class FixtureLookupError(LookupError):
- """ could not return a requested Fixture (missing or invalid). """
+ """Could not return a requested fixture (missing or invalid)."""
- def __init__(self, argname, request, msg=None):
+ def __init__(
+ self, argname: Optional[str], request: FixtureRequest, msg: Optional[str] = None
+ ) -> None:
self.argname = argname
self.request = request
self.fixturestack = request._get_fixturestack()
self.msg = msg
def formatrepr(self) -> "FixtureLookupErrorRepr":
- tblines = [] # type: List[str]
+ tblines: List[str] = []
addline = tblines.append
stack = [self.request._pyfuncitem.obj]
stack.extend(map(lambda x: x.func, self.fixturestack))
msg = self.msg
if msg is not None:
- # the last fixture raise an error, let's present
- # it at the requesting side
+ # The last fixture raise an error, let's present
+ # it at the requesting side.
stack = stack[:-1]
for function in stack:
fspath, lineno = getfslineno(function)
try:
lines, _ = inspect.getsourcelines(get_real_func(function))
- except (IOError, IndexError, TypeError):
+ except (OSError, IndexError, TypeError):
error_msg = "file %s, line %s: source code not available"
addline(error_msg % (fspath, lineno + 1))
else:
@@ -740,7 +866,7 @@ class FixtureLookupError(LookupError):
self.argname
)
else:
- msg = "fixture '{}' not found".format(self.argname)
+ msg = f"fixture '{self.argname}' not found"
msg += "\n available fixtures: {}".format(", ".join(sorted(available)))
msg += "\n use 'pytest --fixtures [testpath]' for help on them."
@@ -748,7 +874,14 @@ class FixtureLookupError(LookupError):
class FixtureLookupErrorRepr(TerminalRepr):
- def __init__(self, filename, firstlineno, tblines, errorstring, argname):
+ def __init__(
+ self,
+ filename: Union[str, py.path.local],
+ firstlineno: int,
+ tblines: Sequence[str],
+ errorstring: str,
+ argname: Optional[str],
+ ) -> None:
self.tblines = tblines
self.errorstring = errorstring
self.filename = filename
@@ -767,55 +900,67 @@ class FixtureLookupErrorRepr(TerminalRepr):
)
for line in lines[1:]:
tw.line(
- "{} {}".format(FormattedExcinfo.flow_marker, line.strip()),
- red=True,
+ f"{FormattedExcinfo.flow_marker} {line.strip()}", red=True,
)
tw.line()
tw.line("%s:%d" % (self.filename, self.firstlineno + 1))
-def fail_fixturefunc(fixturefunc, msg):
+def fail_fixturefunc(fixturefunc, msg: str) -> "NoReturn":
fs, lineno = getfslineno(fixturefunc)
location = "{}:{}".format(fs, lineno + 1)
source = _pytest._code.Source(fixturefunc)
fail(msg + ":\n\n" + str(source.indent()) + "\n" + location, pytrace=False)
-def call_fixture_func(fixturefunc, request, kwargs):
- yieldctx = is_generator(fixturefunc)
- if yieldctx:
- it = fixturefunc(**kwargs)
- res = next(it)
- finalizer = functools.partial(_teardown_yield_fixture, fixturefunc, it)
+def call_fixture_func(
+ fixturefunc: "_FixtureFunc[_FixtureValue]", request: FixtureRequest, kwargs
+) -> _FixtureValue:
+ if is_generator(fixturefunc):
+ fixturefunc = cast(
+ Callable[..., Generator[_FixtureValue, None, None]], fixturefunc
+ )
+ generator = fixturefunc(**kwargs)
+ try:
+ fixture_result = next(generator)
+ except StopIteration:
+ raise ValueError(f"{request.fixturename} did not yield a value") from None
+ finalizer = functools.partial(_teardown_yield_fixture, fixturefunc, generator)
request.addfinalizer(finalizer)
else:
- res = fixturefunc(**kwargs)
- return res
+ fixturefunc = cast(Callable[..., _FixtureValue], fixturefunc)
+ fixture_result = fixturefunc(**kwargs)
+ return fixture_result
-def _teardown_yield_fixture(fixturefunc, it):
- """Executes the teardown of a fixture function by advancing the iterator after the
- yield and ensure the iteration ends (if not it means there is more than one yield in the function)"""
+def _teardown_yield_fixture(fixturefunc, it) -> None:
+ """Execute the teardown of a fixture function by advancing the iterator
+ after the yield and ensure the iteration ends (if not it means there is
+ more than one yield in the function)."""
try:
next(it)
except StopIteration:
pass
else:
- fail_fixturefunc(
- fixturefunc, "yield_fixture function has more than one 'yield'"
- )
+ fail_fixturefunc(fixturefunc, "fixture function has more than one 'yield'")
-def _eval_scope_callable(scope_callable, fixture_name, config):
+def _eval_scope_callable(
+ scope_callable: "Callable[[str, Config], _Scope]",
+ fixture_name: str,
+ config: Config,
+) -> "_Scope":
try:
- result = scope_callable(fixture_name=fixture_name, config=config)
- except Exception:
+ # Type ignored because there is no typing mechanism to specify
+ # keyword arguments, currently.
+ result = scope_callable(fixture_name=fixture_name, config=config) # type: ignore[call-arg]
+ except Exception as e:
raise TypeError(
"Error evaluating {} while defining fixture '{}'.\n"
"Expected a function with the signature (*, fixture_name, config)".format(
scope_callable, fixture_name
)
- )
+ ) from e
if not isinstance(result, str):
fail(
"Expected {} to return a 'str' while defining fixture '{}', but it returned:\n"
@@ -825,44 +970,55 @@ def _eval_scope_callable(scope_callable, fixture_name, config):
return result
-class FixtureDef:
- """ A container for a factory definition. """
+@final
+class FixtureDef(Generic[_FixtureValue]):
+ """A container for a factory definition."""
def __init__(
self,
- fixturemanager,
- baseid,
- argname,
- func,
- scope,
- params,
- unittest=False,
- ids=None,
- ):
+ fixturemanager: "FixtureManager",
+ baseid: Optional[str],
+ argname: str,
+ func: "_FixtureFunc[_FixtureValue]",
+ scope: "Union[_Scope, Callable[[str, Config], _Scope]]",
+ params: Optional[Sequence[object]],
+ unittest: bool = False,
+ ids: Optional[
+ Union[
+ Tuple[Union[None, str, float, int, bool], ...],
+ Callable[[Any], Optional[object]],
+ ]
+ ] = None,
+ ) -> None:
self._fixturemanager = fixturemanager
self.baseid = baseid or ""
self.has_location = baseid is not None
self.func = func
self.argname = argname
if callable(scope):
- scope = _eval_scope_callable(scope, argname, fixturemanager.config)
- self.scope = scope
+ scope_ = _eval_scope_callable(scope, argname, fixturemanager.config)
+ else:
+ scope_ = scope
self.scopenum = scope2index(
- scope or "function",
- descr="Fixture '{}'".format(func.__name__),
+ # TODO: Check if the `or` here is really necessary.
+ scope_ or "function", # type: ignore[unreachable]
+ descr=f"Fixture '{func.__name__}'",
where=baseid,
)
- self.params = params
- self.argnames = getfuncargnames(func, name=argname, is_method=unittest)
+ self.scope = scope_
+ self.params: Optional[Sequence[object]] = params
+ self.argnames: Tuple[str, ...] = getfuncargnames(
+ func, name=argname, is_method=unittest
+ )
self.unittest = unittest
self.ids = ids
- self.cached_result = None
- self._finalizers = []
+ self.cached_result: Optional[_FixtureCachedResult[_FixtureValue]] = None
+ self._finalizers: List[Callable[[], object]] = []
- def addfinalizer(self, finalizer):
+ def addfinalizer(self, finalizer: Callable[[], object]) -> None:
self._finalizers.append(finalizer)
- def finish(self, request):
+ def finish(self, request: SubRequest) -> None:
exc = None
try:
while self._finalizers:
@@ -879,77 +1035,83 @@ class FixtureDef:
finally:
hook = self._fixturemanager.session.gethookproxy(request.node.fspath)
hook.pytest_fixture_post_finalizer(fixturedef=self, request=request)
- # even if finalization fails, we invalidate
- # the cached fixture value and remove
- # all finalizers because they may be bound methods which will
- # keep instances alive
+ # Even if finalization fails, we invalidate the cached fixture
+ # value and remove all finalizers because they may be bound methods
+ # which will keep instances alive.
self.cached_result = None
self._finalizers = []
- def execute(self, request):
- # get required arguments and register our own finish()
- # with their finalization
+ def execute(self, request: SubRequest) -> _FixtureValue:
+ # Get required arguments and register our own finish()
+ # with their finalization.
for argname in self.argnames:
fixturedef = request._get_active_fixturedef(argname)
if argname != "request":
+ # PseudoFixtureDef is only for "request".
+ assert isinstance(fixturedef, FixtureDef)
fixturedef.addfinalizer(functools.partial(self.finish, request=request))
my_cache_key = self.cache_key(request)
if self.cached_result is not None:
- result, cache_key, err = self.cached_result
# note: comparison with `==` can fail (or be expensive) for e.g.
- # numpy arrays (#6497)
+ # numpy arrays (#6497).
+ cache_key = self.cached_result[1]
if my_cache_key is cache_key:
- if err is not None:
- _, val, tb = err
+ if self.cached_result[2] is not None:
+ _, val, tb = self.cached_result[2]
raise val.with_traceback(tb)
else:
+ result = self.cached_result[0]
return result
- # we have a previous but differently parametrized fixture instance
- # so we need to tear it down before creating a new one
+ # We have a previous but differently parametrized fixture instance
+ # so we need to tear it down before creating a new one.
self.finish(request)
assert self.cached_result is None
hook = self._fixturemanager.session.gethookproxy(request.node.fspath)
- return hook.pytest_fixture_setup(fixturedef=self, request=request)
+ result = hook.pytest_fixture_setup(fixturedef=self, request=request)
+ return result
- def cache_key(self, request):
+ def cache_key(self, request: SubRequest) -> object:
return request.param_index if not hasattr(request, "param") else request.param
- def __repr__(self):
+ def __repr__(self) -> str:
return "<FixtureDef argname={!r} scope={!r} baseid={!r}>".format(
self.argname, self.scope, self.baseid
)
-def resolve_fixture_function(fixturedef, request):
- """Gets the actual callable that can be called to obtain the fixture value, dealing with unittest-specific
- instances and bound methods.
- """
+def resolve_fixture_function(
+ fixturedef: FixtureDef[_FixtureValue], request: FixtureRequest
+) -> "_FixtureFunc[_FixtureValue]":
+ """Get the actual callable that can be called to obtain the fixture
+ value, dealing with unittest-specific instances and bound methods."""
fixturefunc = fixturedef.func
if fixturedef.unittest:
if request.instance is not None:
- # bind the unbound method to the TestCase instance
- fixturefunc = fixturedef.func.__get__(request.instance)
+ # Bind the unbound method to the TestCase instance.
+ fixturefunc = fixturedef.func.__get__(request.instance) # type: ignore[union-attr]
else:
- # the fixture function needs to be bound to the actual
+ # The fixture function needs to be bound to the actual
# request.instance so that code working with "fixturedef" behaves
# as expected.
if request.instance is not None:
- # handle the case where fixture is defined not in a test class, but some other class
- # (for example a plugin class with a fixture), see #2270
+ # Handle the case where fixture is defined not in a test class, but some other class
+ # (for example a plugin class with a fixture), see #2270.
if hasattr(fixturefunc, "__self__") and not isinstance(
- request.instance, fixturefunc.__self__.__class__
+ request.instance, fixturefunc.__self__.__class__ # type: ignore[union-attr]
):
return fixturefunc
fixturefunc = getimfunc(fixturedef.func)
if fixturefunc != fixturedef.func:
- fixturefunc = fixturefunc.__get__(request.instance)
+ fixturefunc = fixturefunc.__get__(request.instance) # type: ignore[union-attr]
return fixturefunc
-def pytest_fixture_setup(fixturedef, request):
- """ Execution of fixture setup. """
+def pytest_fixture_setup(
+ fixturedef: FixtureDef[_FixtureValue], request: SubRequest
+) -> _FixtureValue:
+ """Execution of fixture setup."""
kwargs = {}
for argname in fixturedef.argnames:
fixdef = request._get_active_fixturedef(argname)
@@ -963,52 +1125,80 @@ def pytest_fixture_setup(fixturedef, request):
try:
result = call_fixture_func(fixturefunc, request, kwargs)
except TEST_OUTCOME:
- fixturedef.cached_result = (None, my_cache_key, sys.exc_info())
+ exc_info = sys.exc_info()
+ assert exc_info[0] is not None
+ fixturedef.cached_result = (None, my_cache_key, exc_info)
raise
fixturedef.cached_result = (result, my_cache_key, None)
return result
-def _ensure_immutable_ids(ids):
+def _ensure_immutable_ids(
+ ids: Optional[
+ Union[
+ Iterable[Union[None, str, float, int, bool]],
+ Callable[[Any], Optional[object]],
+ ]
+ ],
+) -> Optional[
+ Union[
+ Tuple[Union[None, str, float, int, bool], ...],
+ Callable[[Any], Optional[object]],
+ ]
+]:
if ids is None:
- return
+ return None
if callable(ids):
return ids
return tuple(ids)
-def wrap_function_to_error_out_if_called_directly(function, fixture_marker):
+def _params_converter(
+ params: Optional[Iterable[object]],
+) -> Optional[Tuple[object, ...]]:
+ return tuple(params) if params is not None else None
+
+
+def wrap_function_to_error_out_if_called_directly(
+ function: _FixtureFunction, fixture_marker: "FixtureFunctionMarker",
+) -> _FixtureFunction:
"""Wrap the given fixture function so we can raise an error about it being called directly,
- instead of used as an argument in a test function.
- """
+ instead of used as an argument in a test function."""
message = (
'Fixture "{name}" called directly. Fixtures are not meant to be called directly,\n'
"but are created automatically when test functions request them as parameters.\n"
- "See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures, and\n"
- "https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly about how to update your code."
+ "See https://docs.pytest.org/en/stable/fixture.html for more information about fixtures, and\n"
+ "https://docs.pytest.org/en/stable/deprecations.html#calling-fixtures-directly about how to update your code."
).format(name=fixture_marker.name or function.__name__)
@functools.wraps(function)
def result(*args, **kwargs):
fail(message, pytrace=False)
- # keep reference to the original function in our own custom attribute so we don't unwrap
- # further than this point and lose useful wrappings like @mock.patch (#3774)
- result.__pytest_wrapped__ = _PytestWrapper(function)
+ # Keep reference to the original function in our own custom attribute so we don't unwrap
+ # further than this point and lose useful wrappings like @mock.patch (#3774).
+ result.__pytest_wrapped__ = _PytestWrapper(function) # type: ignore[attr-defined]
- return result
+ return cast(_FixtureFunction, result)
+@final
@attr.s(frozen=True)
class FixtureFunctionMarker:
- scope = attr.ib()
- params = attr.ib(converter=attr.converters.optional(tuple))
- autouse = attr.ib(default=False)
- # Ignore type because of https://github.com/python/mypy/issues/6172.
- ids = attr.ib(default=None, converter=_ensure_immutable_ids) # type: ignore
- name = attr.ib(default=None)
-
- def __call__(self, function):
+ scope = attr.ib(type="Union[_Scope, Callable[[str, Config], _Scope]]")
+ params = attr.ib(type=Optional[Tuple[object, ...]], converter=_params_converter)
+ autouse = attr.ib(type=bool, default=False)
+ ids = attr.ib(
+ type=Union[
+ Tuple[Union[None, str, float, int, bool], ...],
+ Callable[[Any], Optional[object]],
+ ],
+ default=None,
+ converter=_ensure_immutable_ids,
+ )
+ name = attr.ib(type=Optional[str], default=None)
+
+ def __call__(self, function: _FixtureFunction) -> _FixtureFunction:
if inspect.isclass(function):
raise ValueError("class fixtures not supported (maybe in the future)")
@@ -1028,153 +1218,140 @@ class FixtureFunctionMarker:
),
pytrace=False,
)
- function._pytestfixturefunction = self
- return function
-
-
-FIXTURE_ARGS_ORDER = ("scope", "params", "autouse", "ids", "name")
-
-
-def _parse_fixture_args(callable_or_scope, *args, **kwargs):
- arguments = {
- "scope": "function",
- "params": None,
- "autouse": False,
- "ids": None,
- "name": None,
- }
- kwargs = {
- key: value for key, value in kwargs.items() if arguments.get(key) != value
- }
-
- fixture_function = None
- if isinstance(callable_or_scope, str):
- args = list(args)
- args.insert(0, callable_or_scope)
- else:
- fixture_function = callable_or_scope
-
- positionals = set()
- for positional, argument_name in zip(args, FIXTURE_ARGS_ORDER):
- arguments[argument_name] = positional
- positionals.add(argument_name)
- duplicated_kwargs = {kwarg for kwarg in kwargs.keys() if kwarg in positionals}
- if duplicated_kwargs:
- raise TypeError(
- "The fixture arguments are defined as positional and keyword: {}. "
- "Use only keyword arguments.".format(", ".join(duplicated_kwargs))
- )
-
- if positionals:
- warnings.warn(FIXTURE_POSITIONAL_ARGUMENTS, stacklevel=2)
+ # Type ignored because https://github.com/python/mypy/issues/2087.
+ function._pytestfixturefunction = self # type: ignore[attr-defined]
+ return function
- arguments.update(kwargs)
- return fixture_function, arguments
+@overload
+def fixture(
+ fixture_function: _FixtureFunction,
+ *,
+ scope: "Union[_Scope, Callable[[str, Config], _Scope]]" = ...,
+ params: Optional[Iterable[object]] = ...,
+ autouse: bool = ...,
+ ids: Optional[
+ Union[
+ Iterable[Union[None, str, float, int, bool]],
+ Callable[[Any], Optional[object]],
+ ]
+ ] = ...,
+ name: Optional[str] = ...,
+) -> _FixtureFunction:
+ ...
+
+
+@overload
+def fixture(
+ fixture_function: None = ...,
+ *,
+ scope: "Union[_Scope, Callable[[str, Config], _Scope]]" = ...,
+ params: Optional[Iterable[object]] = ...,
+ autouse: bool = ...,
+ ids: Optional[
+ Union[
+ Iterable[Union[None, str, float, int, bool]],
+ Callable[[Any], Optional[object]],
+ ]
+ ] = ...,
+ name: Optional[str] = None,
+) -> FixtureFunctionMarker:
+ ...
def fixture(
- callable_or_scope=None,
- *args,
- scope="function",
- params=None,
- autouse=False,
- ids=None,
- name=None
-):
+ fixture_function: Optional[_FixtureFunction] = None,
+ *,
+ scope: "Union[_Scope, Callable[[str, Config], _Scope]]" = "function",
+ params: Optional[Iterable[object]] = None,
+ autouse: bool = False,
+ ids: Optional[
+ Union[
+ Iterable[Union[None, str, float, int, bool]],
+ Callable[[Any], Optional[object]],
+ ]
+ ] = None,
+ name: Optional[str] = None,
+) -> Union[FixtureFunctionMarker, _FixtureFunction]:
"""Decorator to mark a fixture factory function.
This decorator can be used, with or without parameters, to define a
fixture function.
The name of the fixture function can later be referenced to cause its
- invocation ahead of running tests: test
- modules or classes can use the ``pytest.mark.usefixtures(fixturename)``
- marker.
-
- Test functions can directly use fixture names as input
- arguments in which case the fixture instance returned from the fixture
- function will be injected.
-
- Fixtures can provide their values to test functions using ``return`` or ``yield``
- statements. When using ``yield`` the code block after the ``yield`` statement is executed
- as teardown code regardless of the test outcome, and must yield exactly once.
-
- :arg scope: the scope for which this fixture is shared, one of
- ``"function"`` (default), ``"class"``, ``"module"``,
- ``"package"`` or ``"session"`` (``"package"`` is considered **experimental**
- at this time).
-
- This parameter may also be a callable which receives ``(fixture_name, config)``
- as parameters, and must return a ``str`` with one of the values mentioned above.
-
- See :ref:`dynamic scope` in the docs for more information.
-
- :arg params: an optional list of parameters which will cause multiple
- invocations of the fixture function and all of the tests
- using it.
- The current parameter is available in ``request.param``.
-
- :arg autouse: if True, the fixture func is activated for all tests that
- can see it. If False (the default) then an explicit
- reference is needed to activate the fixture.
-
- :arg ids: list of string ids each corresponding to the params
- so that they are part of the test id. If no ids are provided
- they will be generated automatically from the params.
-
- :arg name: the name of the fixture. This defaults to the name of the
- decorated function. If a fixture is used in the same module in
- which it is defined, the function name of the fixture will be
- shadowed by the function arg that requests the fixture; one way
- to resolve this is to name the decorated function
- ``fixture_<fixturename>`` and then use
- ``@pytest.fixture(name='<fixturename>')``.
+ invocation ahead of running tests: test modules or classes can use the
+ ``pytest.mark.usefixtures(fixturename)`` marker.
+
+ Test functions can directly use fixture names as input arguments in which
+ case the fixture instance returned from the fixture function will be
+ injected.
+
+ Fixtures can provide their values to test functions using ``return`` or
+ ``yield`` statements. When using ``yield`` the code block after the
+ ``yield`` statement is executed as teardown code regardless of the test
+ outcome, and must yield exactly once.
+
+ :param scope:
+ The scope for which this fixture is shared; one of ``"function"``
+ (default), ``"class"``, ``"module"``, ``"package"`` or ``"session"``.
+
+ This parameter may also be a callable which receives ``(fixture_name, config)``
+ as parameters, and must return a ``str`` with one of the values mentioned above.
+
+ See :ref:`dynamic scope` in the docs for more information.
+
+ :param params:
+ An optional list of parameters which will cause multiple invocations
+ of the fixture function and all of the tests using it. The current
+ parameter is available in ``request.param``.
+
+ :param autouse:
+ If True, the fixture func is activated for all tests that can see it.
+ If False (the default), an explicit reference is needed to activate
+ the fixture.
+
+ :param ids:
+ List of string ids each corresponding to the params so that they are
+ part of the test id. If no ids are provided they will be generated
+ automatically from the params.
+
+ :param name:
+ The name of the fixture. This defaults to the name of the decorated
+ function. If a fixture is used in the same module in which it is
+ defined, the function name of the fixture will be shadowed by the
+ function arg that requests the fixture; one way to resolve this is to
+ name the decorated function ``fixture_<fixturename>`` and then use
+ ``@pytest.fixture(name='<fixturename>')``.
"""
- if params is not None:
- params = list(params)
-
- fixture_function, arguments = _parse_fixture_args(
- callable_or_scope,
- *args,
- scope=scope,
- params=params,
- autouse=autouse,
- ids=ids,
- name=name,
+ fixture_marker = FixtureFunctionMarker(
+ scope=scope, params=params, autouse=autouse, ids=ids, name=name,
)
- scope = arguments.get("scope")
- params = arguments.get("params")
- autouse = arguments.get("autouse")
- ids = arguments.get("ids")
- name = arguments.get("name")
-
- if fixture_function and params is None and autouse is False:
- # direct decoration
- return FixtureFunctionMarker(scope, params, autouse, name=name)(
- fixture_function
- )
- return FixtureFunctionMarker(scope, params, autouse, ids=ids, name=name)
+ # Direct decoration.
+ if fixture_function:
+ return fixture_marker(fixture_function)
+
+ return fixture_marker
def yield_fixture(
- callable_or_scope=None,
+ fixture_function=None,
*args,
scope="function",
params=None,
autouse=False,
ids=None,
- name=None
+ name=None,
):
- """ (return a) decorator to mark a yield-fixture factory function.
+ """(Return a) decorator to mark a yield-fixture factory function.
.. deprecated:: 3.0
Use :py:func:`pytest.fixture` directly instead.
"""
+ warnings.warn(YIELD_FIXTURE, stacklevel=2)
return fixture(
- callable_or_scope,
+ fixture_function,
*args,
scope=scope,
params=params,
@@ -1184,11 +1361,8 @@ def yield_fixture(
)
-defaultfuncargprefixmarker = fixture()
-
-
@fixture(scope="session")
-def pytestconfig(request):
+def pytestconfig(request: FixtureRequest) -> Config:
"""Session-scoped fixture that returns the :class:`_pytest.config.Config` object.
Example::
@@ -1201,7 +1375,7 @@ def pytestconfig(request):
return request.config
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
parser.addini(
"usefixtures",
type="args",
@@ -1211,8 +1385,7 @@ def pytest_addoption(parser):
class FixtureManager:
- """
- pytest fixtures definitions and information is stored and managed
+ """pytest fixture definitions and information is stored and managed
from this class.
During collection fm.parsefactories() is called multiple times to parse
@@ -1225,7 +1398,7 @@ class FixtureManager:
which themselves offer a fixturenames attribute.
The FuncFixtureInfo object holds information about fixtures and FixtureDefs
- relevant for a particular function. An initial list of fixtures is
+ relevant for a particular function. An initial list of fixtures is
assembled like this:
- ini-defined usefixtures
@@ -1235,7 +1408,7 @@ class FixtureManager:
Subsequently the funcfixtureinfo.fixturenames attribute is computed
as the closure of the fixtures needed to setup the initial fixtures,
- i. e. fixtures needed by fixture functions themselves are appended
+ i.e. fixtures needed by fixture functions themselves are appended
to the fixturenames list.
Upon the test-setup phases all fixturenames are instantiated, retrieved
@@ -1245,24 +1418,27 @@ class FixtureManager:
FixtureLookupError = FixtureLookupError
FixtureLookupErrorRepr = FixtureLookupErrorRepr
- def __init__(self, session):
+ def __init__(self, session: "Session") -> None:
self.session = session
- self.config = session.config
- self._arg2fixturedefs = {}
- self._holderobjseen = set()
- self._nodeid_and_autousenames = [("", self.config.getini("usefixtures"))]
+ self.config: Config = session.config
+ self._arg2fixturedefs: Dict[str, List[FixtureDef[Any]]] = {}
+ self._holderobjseen: Set[object] = set()
+ # A mapping from a nodeid to a list of autouse fixtures it defines.
+ self._nodeid_autousenames: Dict[str, List[str]] = {
+ "": self.config.getini("usefixtures"),
+ }
session.config.pluginmanager.register(self, "funcmanage")
- def _get_direct_parametrize_args(self, node):
- """This function returns all the direct parametrization
- arguments of a node, so we don't mistake them for fixtures
+ def _get_direct_parametrize_args(self, node: nodes.Node) -> List[str]:
+ """Return all direct parametrization arguments of a node, so we don't
+ mistake them for fixtures.
- Check https://github.com/pytest-dev/pytest/issues/5036
+ Check https://github.com/pytest-dev/pytest/issues/5036.
- This things are done later as well when dealing with parametrization
- so this could be improved
+ These things are done later as well when dealing with parametrization
+ so this could be improved.
"""
- parametrize_argnames = []
+ parametrize_argnames: List[str] = []
for marker in node.iter_markers(name="parametrize"):
if not marker.kwargs.get("indirect", False):
p_argnames, _ = ParameterSet._parse_parametrize_args(
@@ -1272,78 +1448,82 @@ class FixtureManager:
return parametrize_argnames
- def getfixtureinfo(self, node, func, cls, funcargs=True):
+ def getfixtureinfo(
+ self, node: nodes.Node, func, cls, funcargs: bool = True
+ ) -> FuncFixtureInfo:
if funcargs and not getattr(node, "nofuncargs", False):
argnames = getfuncargnames(func, name=node.name, cls=cls)
else:
argnames = ()
- usefixtures = itertools.chain.from_iterable(
- mark.args for mark in node.iter_markers(name="usefixtures")
+ usefixtures = tuple(
+ arg for mark in node.iter_markers(name="usefixtures") for arg in mark.args
)
- initialnames = tuple(usefixtures) + argnames
+ initialnames = usefixtures + argnames
fm = node.session._fixturemanager
initialnames, names_closure, arg2fixturedefs = fm.getfixtureclosure(
initialnames, node, ignore_args=self._get_direct_parametrize_args(node)
)
return FuncFixtureInfo(argnames, initialnames, names_closure, arg2fixturedefs)
- def pytest_plugin_registered(self, plugin):
+ def pytest_plugin_registered(self, plugin: _PluggyPlugin) -> None:
nodeid = None
try:
- p = py.path.local(plugin.__file__).realpath()
+ p = absolutepath(plugin.__file__) # type: ignore[attr-defined]
except AttributeError:
pass
else:
- from _pytest import nodes
-
- # construct the base nodeid which is later used to check
+ # Construct the base nodeid which is later used to check
# what fixtures are visible for particular tests (as denoted
- # by their test id)
- if p.basename.startswith("conftest.py"):
- nodeid = p.dirpath().relto(self.config.rootdir)
- if p.sep != nodes.SEP:
- nodeid = nodeid.replace(p.sep, nodes.SEP)
+ # by their test id).
+ if p.name.startswith("conftest.py"):
+ try:
+ nodeid = str(p.parent.relative_to(self.config.rootpath))
+ except ValueError:
+ nodeid = ""
+ if nodeid == ".":
+ nodeid = ""
+ if os.sep != nodes.SEP:
+ nodeid = nodeid.replace(os.sep, nodes.SEP)
self.parsefactories(plugin, nodeid)
- def _getautousenames(self, nodeid):
- """ return a tuple of fixture names to be used. """
- autousenames = []
- for baseid, basenames in self._nodeid_and_autousenames:
- if nodeid.startswith(baseid):
- if baseid:
- i = len(baseid)
- nextchar = nodeid[i : i + 1]
- if nextchar and nextchar not in ":/":
- continue
- autousenames.extend(basenames)
- return autousenames
-
- def getfixtureclosure(self, fixturenames, parentnode, ignore_args=()):
- # collect the closure of all fixtures , starting with the given
+ def _getautousenames(self, nodeid: str) -> Iterator[str]:
+ """Return the names of autouse fixtures applicable to nodeid."""
+ for parentnodeid in nodes.iterparentnodeids(nodeid):
+ basenames = self._nodeid_autousenames.get(parentnodeid)
+ if basenames:
+ yield from basenames
+
+ def getfixtureclosure(
+ self,
+ fixturenames: Tuple[str, ...],
+ parentnode: nodes.Node,
+ ignore_args: Sequence[str] = (),
+ ) -> Tuple[Tuple[str, ...], List[str], Dict[str, Sequence[FixtureDef[Any]]]]:
+ # Collect the closure of all fixtures, starting with the given
# fixturenames as the initial set. As we have to visit all
# factory definitions anyway, we also return an arg2fixturedefs
# mapping so that the caller can reuse it and does not have
# to re-discover fixturedefs again for each fixturename
- # (discovering matching fixtures for a given name/node is expensive)
+ # (discovering matching fixtures for a given name/node is expensive).
parentid = parentnode.nodeid
- fixturenames_closure = self._getautousenames(parentid)
+ fixturenames_closure = list(self._getautousenames(parentid))
- def merge(otherlist):
+ def merge(otherlist: Iterable[str]) -> None:
for arg in otherlist:
if arg not in fixturenames_closure:
fixturenames_closure.append(arg)
merge(fixturenames)
- # at this point, fixturenames_closure contains what we call "initialnames",
+ # At this point, fixturenames_closure contains what we call "initialnames",
# which is a set of fixturenames the function immediately requests. We
# need to return it as well, so save this.
initialnames = tuple(fixturenames_closure)
- arg2fixturedefs = {}
+ arg2fixturedefs: Dict[str, Sequence[FixtureDef[Any]]] = {}
lastlen = -1
while lastlen != len(fixturenames_closure):
lastlen = len(fixturenames_closure)
@@ -1357,7 +1537,7 @@ class FixtureManager:
arg2fixturedefs[argname] = fixturedefs
merge(fixturedefs[-1].argnames)
- def sort_by_scope(arg_name):
+ def sort_by_scope(arg_name: str) -> int:
try:
fixturedefs = arg2fixturedefs[arg_name]
except KeyError:
@@ -1368,41 +1548,58 @@ class FixtureManager:
fixturenames_closure.sort(key=sort_by_scope)
return initialnames, fixturenames_closure, arg2fixturedefs
- def pytest_generate_tests(self, metafunc):
+ def pytest_generate_tests(self, metafunc: "Metafunc") -> None:
+ """Generate new tests based on parametrized fixtures used by the given metafunc"""
+
+ def get_parametrize_mark_argnames(mark: Mark) -> Sequence[str]:
+ args, _ = ParameterSet._parse_parametrize_args(*mark.args, **mark.kwargs)
+ return args
+
for argname in metafunc.fixturenames:
- faclist = metafunc._arg2fixturedefs.get(argname)
- if faclist:
- fixturedef = faclist[-1]
+ # Get the FixtureDefs for the argname.
+ fixture_defs = metafunc._arg2fixturedefs.get(argname)
+ if not fixture_defs:
+ # Will raise FixtureLookupError at setup time if not parametrized somewhere
+ # else (e.g @pytest.mark.parametrize)
+ continue
+
+ # If the test itself parametrizes using this argname, give it
+ # precedence.
+ if any(
+ argname in get_parametrize_mark_argnames(mark)
+ for mark in metafunc.definition.iter_markers("parametrize")
+ ):
+ continue
+
+ # In the common case we only look at the fixture def with the
+ # closest scope (last in the list). But if the fixture overrides
+ # another fixture, while requesting the super fixture, keep going
+ # in case the super fixture is parametrized (#1953).
+ for fixturedef in reversed(fixture_defs):
+ # Fixture is parametrized, apply it and stop.
if fixturedef.params is not None:
- markers = list(metafunc.definition.iter_markers("parametrize"))
- for parametrize_mark in markers:
- if "argnames" in parametrize_mark.kwargs:
- argnames = parametrize_mark.kwargs["argnames"]
- else:
- argnames = parametrize_mark.args[0]
-
- if not isinstance(argnames, (tuple, list)):
- argnames = [
- x.strip() for x in argnames.split(",") if x.strip()
- ]
- if argname in argnames:
- break
- else:
- metafunc.parametrize(
- argname,
- fixturedef.params,
- indirect=True,
- scope=fixturedef.scope,
- ids=fixturedef.ids,
- )
- else:
- continue # will raise FixtureLookupError at setup time
+ metafunc.parametrize(
+ argname,
+ fixturedef.params,
+ indirect=True,
+ scope=fixturedef.scope,
+ ids=fixturedef.ids,
+ )
+ break
+
+ # Not requesting the overridden super fixture, stop.
+ if argname not in fixturedef.argnames:
+ break
- def pytest_collection_modifyitems(self, items):
- # separate parametrized setups
+ # Try next super fixture, if any.
+
+ def pytest_collection_modifyitems(self, items: List[nodes.Item]) -> None:
+ # Separate parametrized setups.
items[:] = reorder_items(items)
- def parsefactories(self, node_or_obj, nodeid=NOTSET, unittest=False):
+ def parsefactories(
+ self, node_or_obj, nodeid=NOTSET, unittest: bool = False
+ ) -> None:
if nodeid is not NOTSET:
holderobj = node_or_obj
else:
@@ -1419,25 +1616,26 @@ class FixtureManager:
obj = safe_getattr(holderobj, name, None)
marker = getfixturemarker(obj)
if not isinstance(marker, FixtureFunctionMarker):
- # magic globals with __getattr__ might have got us a wrong
- # fixture attribute
+ # Magic globals with __getattr__ might have got us a wrong
+ # fixture attribute.
continue
if marker.name:
name = marker.name
- # during fixture definition we wrap the original fixture function
- # to issue a warning if called directly, so here we unwrap it in order to not emit the warning
- # when pytest itself calls the fixture function
+ # During fixture definition we wrap the original fixture function
+ # to issue a warning if called directly, so here we unwrap it in
+ # order to not emit the warning when pytest itself calls the
+ # fixture function.
obj = get_real_method(obj, holderobj)
fixture_def = FixtureDef(
- self,
- nodeid,
- name,
- obj,
- marker.scope,
- marker.params,
+ fixturemanager=self,
+ baseid=nodeid,
+ argname=name,
+ func=obj,
+ scope=marker.scope,
+ params=marker.params,
unittest=unittest,
ids=marker.ids,
)
@@ -1456,15 +1654,16 @@ class FixtureManager:
autousenames.append(name)
if autousenames:
- self._nodeid_and_autousenames.append((nodeid or "", autousenames))
+ self._nodeid_autousenames.setdefault(nodeid or "", []).extend(autousenames)
- def getfixturedefs(self, argname, nodeid):
- """
- Gets a list of fixtures which are applicable to the given node id.
+ def getfixturedefs(
+ self, argname: str, nodeid: str
+ ) -> Optional[Sequence[FixtureDef[Any]]]:
+ """Get a list of fixtures which are applicable to the given node id.
- :param str argname: name of the fixture to search for
- :param str nodeid: full node id of the requesting test.
- :return: list[FixtureDef]
+ :param str argname: Name of the fixture to search for.
+ :param str nodeid: Full node id of the requesting test.
+ :rtype: Sequence[FixtureDef]
"""
try:
fixturedefs = self._arg2fixturedefs[argname]
@@ -1472,9 +1671,10 @@ class FixtureManager:
return None
return tuple(self._matchfactories(fixturedefs, nodeid))
- def _matchfactories(self, fixturedefs, nodeid):
- from _pytest import nodes
-
+ def _matchfactories(
+ self, fixturedefs: Iterable[FixtureDef[Any]], nodeid: str
+ ) -> Iterator[FixtureDef[Any]]:
+ parentnodeids = set(nodes.iterparentnodeids(nodeid))
for fixturedef in fixturedefs:
- if nodes.ischildnode(fixturedef.baseid, nodeid):
+ if fixturedef.baseid in parentnodeids:
yield fixturedef
diff --git a/contrib/python/pytest/py3/_pytest/freeze_support.py b/contrib/python/pytest/py3/_pytest/freeze_support.py
index f9d613a2b6..8b93ed5f7f 100644
--- a/contrib/python/pytest/py3/_pytest/freeze_support.py
+++ b/contrib/python/pytest/py3/_pytest/freeze_support.py
@@ -1,14 +1,14 @@
-"""
-Provides a function to report all internal modules for using freezing tools
-pytest
-"""
+"""Provides a function to report all internal modules for using freezing
+tools."""
+import types
+from typing import Iterator
+from typing import List
+from typing import Union
-def freeze_includes():
- """
- Returns a list of module names used by pytest that should be
- included by cx_freeze.
- """
+def freeze_includes() -> List[str]:
+ """Return a list of module names used by pytest that should be
+ included by cx_freeze."""
import py
import _pytest
@@ -17,25 +17,26 @@ def freeze_includes():
return result
-def _iter_all_modules(package, prefix=""):
- """
- Iterates over the names of all modules that can be found in the given
+def _iter_all_modules(
+ package: Union[str, types.ModuleType], prefix: str = "",
+) -> Iterator[str]:
+ """Iterate over the names of all modules that can be found in the given
package, recursively.
- Example:
- _iter_all_modules(_pytest) ->
- ['_pytest.assertion.newinterpret',
- '_pytest.capture',
- '_pytest.core',
- ...
- ]
+
+ >>> import _pytest
+ >>> list(_iter_all_modules(_pytest))
+ ['_pytest._argcomplete', '_pytest._code.code', ...]
"""
import os
import pkgutil
- if type(package) is not str:
- path, prefix = package.__path__[0], package.__name__ + "."
- else:
+ if isinstance(package, str):
path = package
+ else:
+ # Type ignored because typeshed doesn't define ModuleType.__path__
+ # (only defined on packages).
+ package_path = package.__path__ # type: ignore[attr-defined]
+ path, prefix = package_path[0], package.__name__ + "."
for _, name, is_package in pkgutil.iter_modules([path]):
if is_package:
for m in _iter_all_modules(os.path.join(path, name), prefix=name + "."):
diff --git a/contrib/python/pytest/py3/_pytest/helpconfig.py b/contrib/python/pytest/py3/_pytest/helpconfig.py
index ae37fdea45..4384d07b26 100644
--- a/contrib/python/pytest/py3/_pytest/helpconfig.py
+++ b/contrib/python/pytest/py3/_pytest/helpconfig.py
@@ -1,17 +1,24 @@
-""" version info, help messages, tracing configuration. """
+"""Version info, help messages, tracing configuration."""
import os
import sys
from argparse import Action
+from typing import List
+from typing import Optional
+from typing import Union
import py
import pytest
+from _pytest.config import Config
+from _pytest.config import ExitCode
from _pytest.config import PrintHelp
+from _pytest.config.argparsing import Parser
class HelpAction(Action):
- """This is an argparse Action that will raise an exception in
- order to skip the rest of the argument parsing when --help is passed.
+ """An argparse Action that will raise an exception in order to skip the
+ rest of the argument parsing when --help is passed.
+
This prevents argparse from quitting due to missing required arguments
when any are defined, for example by ``pytest_addoption``.
This is similar to the way that the builtin argparse --help option is
@@ -31,18 +38,21 @@ class HelpAction(Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, self.const)
- # We should only skip the rest of the parsing after preparse is done
+ # We should only skip the rest of the parsing after preparse is done.
if getattr(parser._parser, "after_preparse", False):
raise PrintHelp
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--version",
"-V",
- action="store_true",
- help="display pytest version and information about plugins.",
+ action="count",
+ default=0,
+ dest="version",
+ help="display pytest version and information about plugins."
+ "When given twice, also display information about plugins.",
)
group._addoption(
"-h",
@@ -57,7 +67,7 @@ def pytest_addoption(parser):
dest="plugins",
default=[],
metavar="name",
- help="early-load given plugin module name or entry point (multi-allowed). "
+ help="early-load given plugin module name or entry point (multi-allowed).\n"
"To avoid loading of plugins, use the `no:` prefix, e.g. "
"`no:doctest`.",
)
@@ -87,7 +97,7 @@ def pytest_addoption(parser):
@pytest.hookimpl(hookwrapper=True)
def pytest_cmdline_parse():
outcome = yield
- config = outcome.get_result()
+ config: Config = outcome.get_result()
if config.option.debug:
path = os.path.abspath("pytestdebug.log")
debugfile = open(path, "w")
@@ -106,7 +116,7 @@ def pytest_cmdline_parse():
undo_tracing = config.pluginmanager.enable_tracing()
sys.stderr.write("writing pytestdebug information to %s\n" % path)
- def unset_tracing():
+ def unset_tracing() -> None:
debugfile.close()
sys.stderr.write("wrote pytestdebug information to %s\n" % debugfile.name)
config.trace.root.setwriter(None)
@@ -115,20 +125,23 @@ def pytest_cmdline_parse():
config.add_cleanup(unset_tracing)
-def showversion(config):
- sys.stderr.write(
- "This is pytest version {}, imported from {}\n".format(
- pytest.__version__, pytest.__file__
+def showversion(config: Config) -> None:
+ if config.option.version > 1:
+ sys.stderr.write(
+ "This is pytest version {}, imported from {}\n".format(
+ pytest.__version__, pytest.__file__
+ )
)
- )
- plugininfo = getpluginversioninfo(config)
- if plugininfo:
- for line in plugininfo:
- sys.stderr.write(line + "\n")
+ plugininfo = getpluginversioninfo(config)
+ if plugininfo:
+ for line in plugininfo:
+ sys.stderr.write(line + "\n")
+ else:
+ sys.stderr.write(f"pytest {pytest.__version__}\n")
-def pytest_cmdline_main(config):
- if config.option.version:
+def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
+ if config.option.version > 0:
showversion(config)
return 0
elif config.option.help:
@@ -136,9 +149,10 @@ def pytest_cmdline_main(config):
showhelp(config)
config._ensure_unconfigure()
return 0
+ return None
-def showhelp(config):
+def showhelp(config: Config) -> None:
import textwrap
reporter = config.pluginmanager.get_plugin("terminalreporter")
@@ -157,7 +171,9 @@ def showhelp(config):
help, type, default = config._parser._inidict[name]
if type is None:
type = "string"
- spec = "{} ({}):".format(name, type)
+ if help is None:
+ raise TypeError(f"help argument cannot be None for {name}")
+ spec = f"{name} ({type}):"
tw.write(" %s" % spec)
spec_len = len(spec)
if spec_len > (indent_len - 3):
@@ -178,9 +194,10 @@ def showhelp(config):
tw.write(" " * (indent_len - spec_len - 2))
wrapped = textwrap.wrap(help, columns - indent_len, break_on_hyphens=False)
- tw.line(wrapped[0])
- for line in wrapped[1:]:
- tw.line(indent + line)
+ if wrapped:
+ tw.line(wrapped[0])
+ for line in wrapped[1:]:
+ tw.line(indent + line)
tw.line()
tw.line("environment variables:")
@@ -191,7 +208,7 @@ def showhelp(config):
("PYTEST_DEBUG", "set to enable debug tracing of pytest's internals"),
]
for name, help in vars:
- tw.line(" {:<24} {}".format(name, help))
+ tw.line(f" {name:<24} {help}")
tw.line()
tw.line()
@@ -211,24 +228,22 @@ def showhelp(config):
conftest_options = [("pytest_plugins", "list of plugin names to load")]
-def getpluginversioninfo(config):
+def getpluginversioninfo(config: Config) -> List[str]:
lines = []
plugininfo = config.pluginmanager.list_plugin_distinfo()
if plugininfo:
lines.append("setuptools registered plugins:")
for plugin, dist in plugininfo:
loc = getattr(plugin, "__file__", repr(plugin))
- content = "{}-{} at {}".format(dist.project_name, dist.version, loc)
+ content = f"{dist.project_name}-{dist.version} at {loc}"
lines.append(" " + content)
return lines
-def pytest_report_header(config):
+def pytest_report_header(config: Config) -> List[str]:
lines = []
if config.option.debug or config.option.traceconfig:
- lines.append(
- "using: pytest-{} pylib-{}".format(pytest.__version__, py.__version__)
- )
+ lines.append(f"using: pytest-{pytest.__version__} pylib-{py.__version__}")
verinfo = getpluginversioninfo(config)
if verinfo:
@@ -242,5 +257,5 @@ def pytest_report_header(config):
r = plugin.__file__
else:
r = repr(plugin)
- lines.append(" {:<20}: {}".format(name, r))
+ lines.append(f" {name:<20}: {r}")
return lines
diff --git a/contrib/python/pytest/py3/_pytest/hookspec.py b/contrib/python/pytest/py3/_pytest/hookspec.py
index 1e16d092d0..e499b742c7 100644
--- a/contrib/python/pytest/py3/_pytest/hookspec.py
+++ b/contrib/python/pytest/py3/_pytest/hookspec.py
@@ -1,14 +1,46 @@
-""" hook specifications for pytest plugins, invoked from main.py and builtin plugins. """
+"""Hook specifications for pytest plugins which are invoked by pytest itself
+and by builtin plugins."""
from typing import Any
+from typing import Dict
+from typing import List
+from typing import Mapping
from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+import py.path
from pluggy import HookspecMarker
-from .deprecated import COLLECT_DIRECTORY_HOOK
-from _pytest.compat import TYPE_CHECKING
+from _pytest.deprecated import WARNING_CAPTURED_HOOK
if TYPE_CHECKING:
+ import pdb
+ import warnings
+ from typing_extensions import Literal
+
+ from _pytest._code.code import ExceptionRepr
+ from _pytest.code import ExceptionInfo
+ from _pytest.config import Config
+ from _pytest.config import ExitCode
+ from _pytest.config import PytestPluginManager
+ from _pytest.config import _PluggyPlugin
+ from _pytest.config.argparsing import Parser
+ from _pytest.fixtures import FixtureDef
+ from _pytest.fixtures import SubRequest
from _pytest.main import Session
+ from _pytest.nodes import Collector
+ from _pytest.nodes import Item
+ from _pytest.outcomes import Exit
+ from _pytest.python import Function
+ from _pytest.python import Metafunc
+ from _pytest.python import Module
+ from _pytest.python import PyCollector
+ from _pytest.reports import CollectReport
+ from _pytest.reports import TestReport
+ from _pytest.runner import CallInfo
+ from _pytest.terminal import TerminalReporter
hookspec = HookspecMarker("pytest")
@@ -19,12 +51,11 @@ hookspec = HookspecMarker("pytest")
@hookspec(historic=True)
-def pytest_addhooks(pluginmanager):
- """called at plugin registration time to allow adding new hooks via a call to
+def pytest_addhooks(pluginmanager: "PytestPluginManager") -> None:
+ """Called at plugin registration time to allow adding new hooks via a call to
``pluginmanager.add_hookspecs(module_or_class, prefix)``.
-
- :param _pytest.config.PytestPluginManager pluginmanager: pytest plugin manager
+ :param _pytest.config.PytestPluginManager pluginmanager: pytest plugin manager.
.. note::
This hook is incompatible with ``hookwrapper=True``.
@@ -32,11 +63,13 @@ def pytest_addhooks(pluginmanager):
@hookspec(historic=True)
-def pytest_plugin_registered(plugin, manager):
- """ a new pytest plugin got registered.
+def pytest_plugin_registered(
+ plugin: "_PluggyPlugin", manager: "PytestPluginManager"
+) -> None:
+ """A new pytest plugin got registered.
- :param plugin: the plugin module or instance
- :param _pytest.config.PytestPluginManager manager: pytest plugin manager
+ :param plugin: The plugin module or instance.
+ :param _pytest.config.PytestPluginManager manager: pytest plugin manager.
.. note::
This hook is incompatible with ``hookwrapper=True``.
@@ -44,8 +77,8 @@ def pytest_plugin_registered(plugin, manager):
@hookspec(historic=True)
-def pytest_addoption(parser, pluginmanager):
- """register argparse-style options and ini-style config values,
+def pytest_addoption(parser: "Parser", pluginmanager: "PytestPluginManager") -> None:
+ """Register argparse-style options and ini-style config values,
called once at the beginning of a test run.
.. note::
@@ -54,15 +87,16 @@ def pytest_addoption(parser, pluginmanager):
files situated at the tests root directory due to how pytest
:ref:`discovers plugins during startup <pluginorder>`.
- :arg _pytest.config.argparsing.Parser parser: To add command line options, call
+ :param _pytest.config.argparsing.Parser parser:
+ To add command line options, call
:py:func:`parser.addoption(...) <_pytest.config.argparsing.Parser.addoption>`.
To add ini-file values call :py:func:`parser.addini(...)
<_pytest.config.argparsing.Parser.addini>`.
- :arg _pytest.config.PytestPluginManager pluginmanager: pytest plugin manager,
- which can be used to install :py:func:`hookspec`'s or :py:func:`hookimpl`'s
- and allow one plugin to call another plugin's hooks to change how
- command line options are added.
+ :param _pytest.config.PytestPluginManager pluginmanager:
+ pytest plugin manager, which can be used to install :py:func:`hookspec`'s
+ or :py:func:`hookimpl`'s and allow one plugin to call another plugin's hooks
+ to change how command line options are added.
Options can later be accessed through the
:py:class:`config <_pytest.config.Config>` object, respectively:
@@ -82,9 +116,8 @@ def pytest_addoption(parser, pluginmanager):
@hookspec(historic=True)
-def pytest_configure(config):
- """
- Allows plugins and conftest files to perform initial configuration.
+def pytest_configure(config: "Config") -> None:
+ """Allow plugins and conftest files to perform initial configuration.
This hook is called for every plugin and initial conftest file
after command line options have been parsed.
@@ -95,7 +128,7 @@ def pytest_configure(config):
.. note::
This hook is incompatible with ``hookwrapper=True``.
- :arg _pytest.config.Config config: pytest config object
+ :param _pytest.config.Config config: The pytest config object.
"""
@@ -106,21 +139,24 @@ def pytest_configure(config):
@hookspec(firstresult=True)
-def pytest_cmdline_parse(pluginmanager, args):
- """return initialized config object, parsing the specified args.
+def pytest_cmdline_parse(
+ pluginmanager: "PytestPluginManager", args: List[str]
+) -> Optional["Config"]:
+ """Return an initialized config object, parsing the specified args.
- Stops at first non-None result, see :ref:`firstresult`
+ Stops at first non-None result, see :ref:`firstresult`.
.. note::
- This hook will only be called for plugin classes passed to the ``plugins`` arg when using `pytest.main`_ to
- perform an in-process test run.
+ This hook will only be called for plugin classes passed to the
+ ``plugins`` arg when using `pytest.main`_ to perform an in-process
+ test run.
- :param _pytest.config.PytestPluginManager pluginmanager: pytest plugin manager
- :param list[str] args: list of arguments passed on the command line
+ :param _pytest.config.PytestPluginManager pluginmanager: Pytest plugin manager.
+ :param List[str] args: List of arguments passed on the command line.
"""
-def pytest_cmdline_preparse(config, args):
+def pytest_cmdline_preparse(config: "Config", args: List[str]) -> None:
"""(**Deprecated**) modify command line arguments before option parsing.
This hook is considered deprecated and will be removed in a future pytest version. Consider
@@ -129,35 +165,37 @@ def pytest_cmdline_preparse(config, args):
.. note::
This hook will not be called for ``conftest.py`` files, only for setuptools plugins.
- :param _pytest.config.Config config: pytest config object
- :param list[str] args: list of arguments passed on the command line
+ :param _pytest.config.Config config: The pytest config object.
+ :param List[str] args: Arguments passed on the command line.
"""
@hookspec(firstresult=True)
-def pytest_cmdline_main(config):
- """ called for performing the main command line action. The default
+def pytest_cmdline_main(config: "Config") -> Optional[Union["ExitCode", int]]:
+ """Called for performing the main command line action. The default
implementation will invoke the configure hooks and runtest_mainloop.
.. note::
This hook will not be called for ``conftest.py`` files, only for setuptools plugins.
- Stops at first non-None result, see :ref:`firstresult`
+ Stops at first non-None result, see :ref:`firstresult`.
- :param _pytest.config.Config config: pytest config object
+ :param _pytest.config.Config config: The pytest config object.
"""
-def pytest_load_initial_conftests(early_config, parser, args):
- """ implements the loading of initial conftest files ahead
+def pytest_load_initial_conftests(
+ early_config: "Config", parser: "Parser", args: List[str]
+) -> None:
+ """Called to implement the loading of initial conftest files ahead
of command line option parsing.
.. note::
This hook will not be called for ``conftest.py`` files, only for setuptools plugins.
- :param _pytest.config.Config early_config: pytest config object
- :param list[str] args: list of arguments passed on the command line
- :param _pytest.config.argparsing.Parser parser: to add command line options
+ :param _pytest.config.Config early_config: The pytest config object.
+ :param List[str] args: Arguments passed on the command line.
+ :param _pytest.config.argparsing.Parser parser: To add command line options.
"""
@@ -167,87 +205,114 @@ def pytest_load_initial_conftests(early_config, parser, args):
@hookspec(firstresult=True)
-def pytest_collection(session: "Session") -> Optional[Any]:
- """Perform the collection protocol for the given session.
+def pytest_collection(session: "Session") -> Optional[object]:
+ """Perform the collection phase for the given session.
Stops at first non-None result, see :ref:`firstresult`.
+ The return value is not used, but only stops further processing.
+
+ The default collection phase is this (see individual hooks for full details):
+
+ 1. Starting from ``session`` as the initial collector:
+
+ 1. ``pytest_collectstart(collector)``
+ 2. ``report = pytest_make_collect_report(collector)``
+ 3. ``pytest_exception_interact(collector, call, report)`` if an interactive exception occurred
+ 4. For each collected node:
+
+ 1. If an item, ``pytest_itemcollected(item)``
+ 2. If a collector, recurse into it.
+
+ 5. ``pytest_collectreport(report)``
+
+ 2. ``pytest_collection_modifyitems(session, config, items)``
+
+ 1. ``pytest_deselected(items)`` for any deselected items (may be called multiple times)
+
+ 3. ``pytest_collection_finish(session)``
+ 4. Set ``session.items`` to the list of collected items
+ 5. Set ``session.testscollected`` to the number of collected items
+
+ You can implement this hook to only perform some action before collection,
+ for example the terminal plugin uses it to start displaying the collection
+ counter (and returns `None`).
- :param _pytest.main.Session session: the pytest session object
+ :param pytest.Session session: The pytest session object.
"""
-def pytest_collection_modifyitems(session, config, items):
- """ called after collection has been performed, may filter or re-order
+def pytest_collection_modifyitems(
+ session: "Session", config: "Config", items: List["Item"]
+) -> None:
+ """Called after collection has been performed. May filter or re-order
the items in-place.
- :param _pytest.main.Session session: the pytest session object
- :param _pytest.config.Config config: pytest config object
- :param List[_pytest.nodes.Item] items: list of item objects
+ :param pytest.Session session: The pytest session object.
+ :param _pytest.config.Config config: The pytest config object.
+ :param List[pytest.Item] items: List of item objects.
"""
-def pytest_collection_finish(session):
- """ called after collection has been performed and modified.
+def pytest_collection_finish(session: "Session") -> None:
+ """Called after collection has been performed and modified.
- :param _pytest.main.Session session: the pytest session object
+ :param pytest.Session session: The pytest session object.
"""
@hookspec(firstresult=True)
-def pytest_ignore_collect(path, config):
- """ return True to prevent considering this path for collection.
+def pytest_ignore_collect(path: py.path.local, config: "Config") -> Optional[bool]:
+ """Return True to prevent considering this path for collection.
+
This hook is consulted for all files and directories prior to calling
more specific hooks.
- Stops at first non-None result, see :ref:`firstresult`
+ Stops at first non-None result, see :ref:`firstresult`.
- :param path: a :py:class:`py.path.local` - the path to analyze
- :param _pytest.config.Config config: pytest config object
+ :param py.path.local path: The path to analyze.
+ :param _pytest.config.Config config: The pytest config object.
"""
-@hookspec(firstresult=True, warn_on_impl=COLLECT_DIRECTORY_HOOK)
-def pytest_collect_directory(path, parent):
- """ called before traversing a directory for collection files.
+def pytest_collect_file(
+ path: py.path.local, parent: "Collector"
+) -> "Optional[Collector]":
+ """Create a Collector for the given path, or None if not relevant.
- Stops at first non-None result, see :ref:`firstresult`
+ The new node needs to have the specified ``parent`` as a parent.
- :param path: a :py:class:`py.path.local` - the path to analyze
- """
-
-
-def pytest_collect_file(path, parent):
- """ return collection Node or None for the given path. Any new node
- needs to have the specified ``parent`` as a parent.
-
- :param path: a :py:class:`py.path.local` - the path to collect
+ :param py.path.local path: The path to collect.
"""
# logging hooks for collection
-def pytest_collectstart(collector):
- """ collector starts collecting. """
+def pytest_collectstart(collector: "Collector") -> None:
+ """Collector starts collecting."""
-def pytest_itemcollected(item):
- """ we just collected a test item. """
+def pytest_itemcollected(item: "Item") -> None:
+ """We just collected a test item."""
-def pytest_collectreport(report):
- """ collector finished collecting. """
+def pytest_collectreport(report: "CollectReport") -> None:
+ """Collector finished collecting."""
-def pytest_deselected(items):
- """ called for test items deselected, e.g. by keyword. """
+def pytest_deselected(items: Sequence["Item"]) -> None:
+ """Called for deselected test items, e.g. by keyword.
+
+ May be called multiple times.
+ """
@hookspec(firstresult=True)
-def pytest_make_collect_report(collector):
- """ perform ``collector.collect()`` and return a CollectReport.
+def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectReport]":
+ """Perform ``collector.collect()`` and return a CollectReport.
- Stops at first non-None result, see :ref:`firstresult` """
+ Stops at first non-None result, see :ref:`firstresult`.
+ """
# -------------------------------------------------------------------------
@@ -256,165 +321,232 @@ def pytest_make_collect_report(collector):
@hookspec(firstresult=True)
-def pytest_pycollect_makemodule(path, parent):
- """ return a Module collector or None for the given path.
+def pytest_pycollect_makemodule(path: py.path.local, parent) -> Optional["Module"]:
+ """Return a Module collector or None for the given path.
+
This hook will be called for each matching test module path.
The pytest_collect_file hook needs to be used if you want to
create test modules for files that do not match as a test module.
- Stops at first non-None result, see :ref:`firstresult`
+ Stops at first non-None result, see :ref:`firstresult`.
- :param path: a :py:class:`py.path.local` - the path of module to collect
+ :param py.path.local path: The path of module to collect.
"""
@hookspec(firstresult=True)
-def pytest_pycollect_makeitem(collector, name, obj):
- """ return custom item/collector for a python object in a module, or None.
+def pytest_pycollect_makeitem(
+ collector: "PyCollector", name: str, obj: object
+) -> Union[None, "Item", "Collector", List[Union["Item", "Collector"]]]:
+ """Return a custom item/collector for a Python object in a module, or None.
- Stops at first non-None result, see :ref:`firstresult` """
+ Stops at first non-None result, see :ref:`firstresult`.
+ """
@hookspec(firstresult=True)
-def pytest_pyfunc_call(pyfuncitem):
- """ call underlying test function.
+def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
+ """Call underlying test function.
- Stops at first non-None result, see :ref:`firstresult` """
+ Stops at first non-None result, see :ref:`firstresult`.
+ """
-def pytest_generate_tests(metafunc):
- """ generate (multiple) parametrized calls to a test function."""
+def pytest_generate_tests(metafunc: "Metafunc") -> None:
+ """Generate (multiple) parametrized calls to a test function."""
@hookspec(firstresult=True)
-def pytest_make_parametrize_id(config, val, argname):
- """Return a user-friendly string representation of the given ``val`` that will be used
- by @pytest.mark.parametrize calls. Return None if the hook doesn't know about ``val``.
+def pytest_make_parametrize_id(
+ config: "Config", val: object, argname: str
+) -> Optional[str]:
+ """Return a user-friendly string representation of the given ``val``
+ that will be used by @pytest.mark.parametrize calls, or None if the hook
+ doesn't know about ``val``.
+
The parameter name is available as ``argname``, if required.
- Stops at first non-None result, see :ref:`firstresult`
+ Stops at first non-None result, see :ref:`firstresult`.
- :param _pytest.config.Config config: pytest config object
- :param val: the parametrized value
- :param str argname: the automatic parameter name produced by pytest
+ :param _pytest.config.Config config: The pytest config object.
+ :param val: The parametrized value.
+ :param str argname: The automatic parameter name produced by pytest.
"""
# -------------------------------------------------------------------------
-# generic runtest related hooks
+# runtest related hooks
# -------------------------------------------------------------------------
@hookspec(firstresult=True)
-def pytest_runtestloop(session):
- """ called for performing the main runtest loop
- (after collection finished).
+def pytest_runtestloop(session: "Session") -> Optional[object]:
+ """Perform the main runtest loop (after collection finished).
+
+ The default hook implementation performs the runtest protocol for all items
+ collected in the session (``session.items``), unless the collection failed
+ or the ``collectonly`` pytest option is set.
+
+ If at any point :py:func:`pytest.exit` is called, the loop is
+ terminated immediately.
- Stops at first non-None result, see :ref:`firstresult`
+ If at any point ``session.shouldfail`` or ``session.shouldstop`` are set, the
+ loop is terminated after the runtest protocol for the current item is finished.
- :param _pytest.main.Session session: the pytest session object
+ :param pytest.Session session: The pytest session object.
+
+ Stops at first non-None result, see :ref:`firstresult`.
+ The return value is not used, but only stops further processing.
"""
@hookspec(firstresult=True)
-def pytest_runtest_protocol(item, nextitem):
- """ implements the runtest_setup/call/teardown protocol for
- the given test item, including capturing exceptions and calling
- reporting hooks.
-
- :arg item: test item for which the runtest protocol is performed.
+def pytest_runtest_protocol(
+ item: "Item", nextitem: "Optional[Item]"
+) -> Optional[object]:
+ """Perform the runtest protocol for a single test item.
- :arg nextitem: the scheduled-to-be-next test item (or None if this
- is the end my friend). This argument is passed on to
- :py:func:`pytest_runtest_teardown`.
+ The default runtest protocol is this (see individual hooks for full details):
- :return boolean: True if no further hook implementations should be invoked.
+ - ``pytest_runtest_logstart(nodeid, location)``
+ - Setup phase:
+ - ``call = pytest_runtest_setup(item)`` (wrapped in ``CallInfo(when="setup")``)
+ - ``report = pytest_runtest_makereport(item, call)``
+ - ``pytest_runtest_logreport(report)``
+ - ``pytest_exception_interact(call, report)`` if an interactive exception occurred
- Stops at first non-None result, see :ref:`firstresult` """
+ - Call phase, if the the setup passed and the ``setuponly`` pytest option is not set:
+ - ``call = pytest_runtest_call(item)`` (wrapped in ``CallInfo(when="call")``)
+ - ``report = pytest_runtest_makereport(item, call)``
+ - ``pytest_runtest_logreport(report)``
+ - ``pytest_exception_interact(call, report)`` if an interactive exception occurred
+ - Teardown phase:
+ - ``call = pytest_runtest_teardown(item, nextitem)`` (wrapped in ``CallInfo(when="teardown")``)
+ - ``report = pytest_runtest_makereport(item, call)``
+ - ``pytest_runtest_logreport(report)``
+ - ``pytest_exception_interact(call, report)`` if an interactive exception occurred
-def pytest_runtest_logstart(nodeid, location):
- """ signal the start of running a single test item.
+ - ``pytest_runtest_logfinish(nodeid, location)``
- This hook will be called **before** :func:`pytest_runtest_setup`, :func:`pytest_runtest_call` and
- :func:`pytest_runtest_teardown` hooks.
+ :param item: Test item for which the runtest protocol is performed.
+ :param nextitem: The scheduled-to-be-next test item (or None if this is the end my friend).
- :param str nodeid: full id of the item
- :param location: a triple of ``(filename, linenum, testname)``
+ Stops at first non-None result, see :ref:`firstresult`.
+ The return value is not used, but only stops further processing.
"""
-def pytest_runtest_logfinish(nodeid, location):
- """ signal the complete finish of running a single test item.
+def pytest_runtest_logstart(
+ nodeid: str, location: Tuple[str, Optional[int], str]
+) -> None:
+ """Called at the start of running the runtest protocol for a single item.
- This hook will be called **after** :func:`pytest_runtest_setup`, :func:`pytest_runtest_call` and
- :func:`pytest_runtest_teardown` hooks.
+ See :func:`pytest_runtest_protocol` for a description of the runtest protocol.
- :param str nodeid: full id of the item
- :param location: a triple of ``(filename, linenum, testname)``
+ :param str nodeid: Full node ID of the item.
+ :param location: A tuple of ``(filename, lineno, testname)``.
"""
-def pytest_runtest_setup(item):
- """ called before ``pytest_runtest_call(item)``. """
+def pytest_runtest_logfinish(
+ nodeid: str, location: Tuple[str, Optional[int], str]
+) -> None:
+ """Called at the end of running the runtest protocol for a single item.
+ See :func:`pytest_runtest_protocol` for a description of the runtest protocol.
-def pytest_runtest_call(item):
- """ called to execute the test ``item``. """
+ :param str nodeid: Full node ID of the item.
+ :param location: A tuple of ``(filename, lineno, testname)``.
+ """
-def pytest_runtest_teardown(item, nextitem):
- """ called after ``pytest_runtest_call``.
+def pytest_runtest_setup(item: "Item") -> None:
+ """Called to perform the setup phase for a test item.
- :arg nextitem: the scheduled-to-be-next test item (None if no further
- test item is scheduled). This argument can be used to
- perform exact teardowns, i.e. calling just enough finalizers
- so that nextitem only needs to call setup-functions.
+ The default implementation runs ``setup()`` on ``item`` and all of its
+ parents (which haven't been setup yet). This includes obtaining the
+ values of fixtures required by the item (which haven't been obtained
+ yet).
"""
-@hookspec(firstresult=True)
-def pytest_runtest_makereport(item, call):
- """ return a :py:class:`_pytest.runner.TestReport` object
- for the given :py:class:`pytest.Item <_pytest.main.Item>` and
- :py:class:`_pytest.runner.CallInfo`.
+def pytest_runtest_call(item: "Item") -> None:
+ """Called to run the test for test item (the call phase).
- Stops at first non-None result, see :ref:`firstresult` """
+ The default implementation calls ``item.runtest()``.
+ """
-def pytest_runtest_logreport(report):
- """ process a test setup/call/teardown report relating to
- the respective phase of executing a test. """
+def pytest_runtest_teardown(item: "Item", nextitem: Optional["Item"]) -> None:
+ """Called to perform the teardown phase for a test item.
+ The default implementation runs the finalizers and calls ``teardown()``
+ on ``item`` and all of its parents (which need to be torn down). This
+ includes running the teardown phase of fixtures required by the item (if
+ they go out of scope).
-@hookspec(firstresult=True)
-def pytest_report_to_serializable(config, report):
- """
- Serializes the given report object into a data structure suitable for sending
- over the wire, e.g. converted to JSON.
+ :param nextitem:
+ The scheduled-to-be-next test item (None if no further test item is
+ scheduled). This argument can be used to perform exact teardowns,
+ i.e. calling just enough finalizers so that nextitem only needs to
+ call setup-functions.
"""
@hookspec(firstresult=True)
-def pytest_report_from_serializable(config, data):
+def pytest_runtest_makereport(
+ item: "Item", call: "CallInfo[None]"
+) -> Optional["TestReport"]:
+ """Called to create a :py:class:`_pytest.reports.TestReport` for each of
+ the setup, call and teardown runtest phases of a test item.
+
+ See :func:`pytest_runtest_protocol` for a description of the runtest protocol.
+
+ :param CallInfo[None] call: The ``CallInfo`` for the phase.
+
+ Stops at first non-None result, see :ref:`firstresult`.
"""
- Restores a report object previously serialized with pytest_report_to_serializable().
+
+
+def pytest_runtest_logreport(report: "TestReport") -> None:
+ """Process the :py:class:`_pytest.reports.TestReport` produced for each
+ of the setup, call and teardown runtest phases of an item.
+
+ See :func:`pytest_runtest_protocol` for a description of the runtest protocol.
"""
+@hookspec(firstresult=True)
+def pytest_report_to_serializable(
+ config: "Config", report: Union["CollectReport", "TestReport"],
+) -> Optional[Dict[str, Any]]:
+ """Serialize the given report object into a data structure suitable for
+ sending over the wire, e.g. converted to JSON."""
+
+
+@hookspec(firstresult=True)
+def pytest_report_from_serializable(
+ config: "Config", data: Dict[str, Any],
+) -> Optional[Union["CollectReport", "TestReport"]]:
+ """Restore a report object previously serialized with pytest_report_to_serializable()."""
+
+
# -------------------------------------------------------------------------
# Fixture related hooks
# -------------------------------------------------------------------------
@hookspec(firstresult=True)
-def pytest_fixture_setup(fixturedef, request):
- """ performs fixture setup execution.
+def pytest_fixture_setup(
+ fixturedef: "FixtureDef[Any]", request: "SubRequest"
+) -> Optional[object]:
+ """Perform fixture setup execution.
- :return: The return value of the call to the fixture function
+ :returns: The return value of the call to the fixture function.
- Stops at first non-None result, see :ref:`firstresult`
+ Stops at first non-None result, see :ref:`firstresult`.
.. note::
If the fixture function returns None, other implementations of
@@ -423,7 +555,9 @@ def pytest_fixture_setup(fixturedef, request):
"""
-def pytest_fixture_post_finalizer(fixturedef, request):
+def pytest_fixture_post_finalizer(
+ fixturedef: "FixtureDef[Any]", request: "SubRequest"
+) -> None:
"""Called after fixture teardown, but before the cache is cleared, so
the fixture result ``fixturedef.cached_result`` is still available (not
``None``)."""
@@ -434,26 +568,28 @@ def pytest_fixture_post_finalizer(fixturedef, request):
# -------------------------------------------------------------------------
-def pytest_sessionstart(session):
- """ called after the ``Session`` object has been created and before performing collection
+def pytest_sessionstart(session: "Session") -> None:
+ """Called after the ``Session`` object has been created and before performing collection
and entering the run test loop.
- :param _pytest.main.Session session: the pytest session object
+ :param pytest.Session session: The pytest session object.
"""
-def pytest_sessionfinish(session, exitstatus):
- """ called after whole test run finished, right before returning the exit status to the system.
+def pytest_sessionfinish(
+ session: "Session", exitstatus: Union[int, "ExitCode"],
+) -> None:
+ """Called after whole test run finished, right before returning the exit status to the system.
- :param _pytest.main.Session session: the pytest session object
- :param int exitstatus: the status which pytest will return to the system
+ :param pytest.Session session: The pytest session object.
+ :param int exitstatus: The status which pytest will return to the system.
"""
-def pytest_unconfigure(config):
- """ called before test process is exited.
+def pytest_unconfigure(config: "Config") -> None:
+ """Called before test process is exited.
- :param _pytest.config.Config config: pytest config object
+ :param _pytest.config.Config config: The pytest config object.
"""
@@ -462,26 +598,25 @@ def pytest_unconfigure(config):
# -------------------------------------------------------------------------
-def pytest_assertrepr_compare(config, op, left, right):
- """return explanation for comparisons in failing assert expressions.
+def pytest_assertrepr_compare(
+ config: "Config", op: str, left: object, right: object
+) -> Optional[List[str]]:
+ """Return explanation for comparisons in failing assert expressions.
Return None for no custom explanation, otherwise return a list
- of strings. The strings will be joined by newlines but any newlines
- *in* a string will be escaped. Note that all but the first line will
+ of strings. The strings will be joined by newlines but any newlines
+ *in* a string will be escaped. Note that all but the first line will
be indented slightly, the intention is for the first line to be a summary.
- :param _pytest.config.Config config: pytest config object
+ :param _pytest.config.Config config: The pytest config object.
"""
-def pytest_assertion_pass(item, lineno, orig, expl):
- """
- **(Experimental)**
+def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> None:
+ """**(Experimental)** Called whenever an assertion passes.
.. versionadded:: 5.0
- Hook called whenever an assertion *passes*.
-
Use this hook to do some processing after a passing assertion.
The original assertion information is available in the `orig` string
and the pytest introspected assertion information is available in the
@@ -498,30 +633,39 @@ def pytest_assertion_pass(item, lineno, orig, expl):
You need to **clean the .pyc** files in your project directory and interpreter libraries
when enabling this option, as assertions will require to be re-written.
- :param _pytest.nodes.Item item: pytest item object of current test
- :param int lineno: line number of the assert statement
- :param string orig: string with original assertion
- :param string expl: string with assert explanation
+ :param pytest.Item item: pytest item object of current test.
+ :param int lineno: Line number of the assert statement.
+ :param str orig: String with the original assertion.
+ :param str expl: String with the assert explanation.
.. note::
This hook is **experimental**, so its parameters or even the hook itself might
be changed/removed without warning in any future pytest release.
- If you find this hook useful, please share your feedback opening an issue.
+ If you find this hook useful, please share your feedback in an issue.
"""
# -------------------------------------------------------------------------
-# hooks for influencing reporting (invoked from _pytest_terminal)
+# Hooks for influencing reporting (invoked from _pytest_terminal).
# -------------------------------------------------------------------------
-def pytest_report_header(config, startdir):
- """ return a string or list of strings to be displayed as header info for terminal reporting.
+def pytest_report_header(
+ config: "Config", startdir: py.path.local
+) -> Union[str, List[str]]:
+ """Return a string or list of strings to be displayed as header info for terminal reporting.
- :param _pytest.config.Config config: pytest config object
- :param startdir: py.path object with the starting dir
+ :param _pytest.config.Config config: The pytest config object.
+ :param py.path.local startdir: The starting dir.
+
+ .. note::
+
+ Lines returned by a plugin are displayed before those of plugins which
+ ran before it.
+ If you want to have your line(s) displayed first, use
+ :ref:`trylast=True <plugin-hookorder>`.
.. note::
@@ -531,45 +675,85 @@ def pytest_report_header(config, startdir):
"""
-def pytest_report_collectionfinish(config, startdir, items):
- """
+def pytest_report_collectionfinish(
+ config: "Config", startdir: py.path.local, items: Sequence["Item"],
+) -> Union[str, List[str]]:
+ """Return a string or list of strings to be displayed after collection
+ has finished successfully.
+
+ These strings will be displayed after the standard "collected X items" message.
+
.. versionadded:: 3.2
- return a string or list of strings to be displayed after collection has finished successfully.
+ :param _pytest.config.Config config: The pytest config object.
+ :param py.path.local startdir: The starting dir.
+ :param items: List of pytest items that are going to be executed; this list should not be modified.
- This strings will be displayed after the standard "collected X items" message.
+ .. note::
- :param _pytest.config.Config config: pytest config object
- :param startdir: py.path object with the starting dir
- :param items: list of pytest items that are going to be executed; this list should not be modified.
+ Lines returned by a plugin are displayed before those of plugins which
+ ran before it.
+ If you want to have your line(s) displayed first, use
+ :ref:`trylast=True <plugin-hookorder>`.
"""
@hookspec(firstresult=True)
-def pytest_report_teststatus(report, config):
- """ return result-category, shortletter and verbose word for reporting.
+def pytest_report_teststatus(
+ report: Union["CollectReport", "TestReport"], config: "Config"
+) -> Tuple[
+ str, str, Union[str, Mapping[str, bool]],
+]:
+ """Return result-category, shortletter and verbose word for status
+ reporting.
- :param _pytest.config.Config config: pytest config object
+ The result-category is a category in which to count the result, for
+ example "passed", "skipped", "error" or the empty string.
- Stops at first non-None result, see :ref:`firstresult` """
+ The shortletter is shown as testing progresses, for example ".", "s",
+ "E" or the empty string.
+ The verbose word is shown as testing progresses in verbose mode, for
+ example "PASSED", "SKIPPED", "ERROR" or the empty string.
-def pytest_terminal_summary(terminalreporter, exitstatus, config):
+ pytest may style these implicitly according to the report outcome.
+ To provide explicit styling, return a tuple for the verbose word,
+ for example ``"rerun", "R", ("RERUN", {"yellow": True})``.
+
+ :param report: The report object whose status is to be returned.
+ :param _pytest.config.Config config: The pytest config object.
+
+ Stops at first non-None result, see :ref:`firstresult`.
+ """
+
+
+def pytest_terminal_summary(
+ terminalreporter: "TerminalReporter", exitstatus: "ExitCode", config: "Config",
+) -> None:
"""Add a section to terminal summary reporting.
- :param _pytest.terminal.TerminalReporter terminalreporter: the internal terminal reporter object
- :param int exitstatus: the exit status that will be reported back to the OS
- :param _pytest.config.Config config: pytest config object
+ :param _pytest.terminal.TerminalReporter terminalreporter: The internal terminal reporter object.
+ :param int exitstatus: The exit status that will be reported back to the OS.
+ :param _pytest.config.Config config: The pytest config object.
.. versionadded:: 4.2
The ``config`` parameter.
"""
-@hookspec(historic=True)
-def pytest_warning_captured(warning_message, when, item, location):
- """
- Process a warning captured by the internal pytest warnings plugin.
+@hookspec(historic=True, warn_on_impl=WARNING_CAPTURED_HOOK)
+def pytest_warning_captured(
+ warning_message: "warnings.WarningMessage",
+ when: "Literal['config', 'collect', 'runtest']",
+ item: Optional["Item"],
+ location: Optional[Tuple[str, int, str]],
+) -> None:
+ """(**Deprecated**) Process a warning captured by the internal pytest warnings plugin.
+
+ .. deprecated:: 6.0
+
+ This hook is considered deprecated and will be removed in a future pytest version.
+ Use :func:`pytest_warning_recorded` instead.
:param warnings.WarningMessage warning_message:
The captured warning. This is the same object produced by :py:func:`warnings.catch_warnings`, and contains
@@ -583,27 +767,66 @@ def pytest_warning_captured(warning_message, when, item, location):
* ``"runtest"``: during test execution.
:param pytest.Item|None item:
- **DEPRECATED**: This parameter is incompatible with ``pytest-xdist``, and will always receive ``None``
- in a future release.
-
The item being executed if ``when`` is ``"runtest"``, otherwise ``None``.
:param tuple location:
- Holds information about the execution context of the captured warning (filename, linenumber, function).
- ``function`` evaluates to <module> when the execution context is at the module level.
+ When available, holds information about the execution context of the captured
+ warning (filename, linenumber, function). ``function`` evaluates to <module>
+ when the execution context is at the module level.
+ """
+
+
+@hookspec(historic=True)
+def pytest_warning_recorded(
+ warning_message: "warnings.WarningMessage",
+ when: "Literal['config', 'collect', 'runtest']",
+ nodeid: str,
+ location: Optional[Tuple[str, int, str]],
+) -> None:
+ """Process a warning captured by the internal pytest warnings plugin.
+
+ :param warnings.WarningMessage warning_message:
+ The captured warning. This is the same object produced by :py:func:`warnings.catch_warnings`, and contains
+ the same attributes as the parameters of :py:func:`warnings.showwarning`.
+
+ :param str when:
+ Indicates when the warning was captured. Possible values:
+
+ * ``"config"``: during pytest configuration/initialization stage.
+ * ``"collect"``: during test collection.
+ * ``"runtest"``: during test execution.
+
+ :param str nodeid:
+ Full id of the item.
+
+ :param tuple|None location:
+ When available, holds information about the execution context of the captured
+ warning (filename, linenumber, function). ``function`` evaluates to <module>
+ when the execution context is at the module level.
+
+ .. versionadded:: 6.0
"""
# -------------------------------------------------------------------------
-# doctest hooks
+# Hooks for influencing skipping
# -------------------------------------------------------------------------
-@hookspec(firstresult=True)
-def pytest_doctest_prepare_content(content):
- """ return processed content for a given doctest
+def pytest_markeval_namespace(config: "Config") -> Dict[str, Any]:
+ """Called when constructing the globals dictionary used for
+ evaluating string conditions in xfail/skipif markers.
+
+ This is useful when the condition for a marker requires
+ objects that are expensive or impossible to obtain during
+ collection time, which is required by normal boolean
+ conditions.
+
+ .. versionadded:: 6.2
- Stops at first non-None result, see :ref:`firstresult` """
+ :param _pytest.config.Config config: The pytest config object.
+ :returns: A dictionary of additional globals to add.
+ """
# -------------------------------------------------------------------------
@@ -611,38 +834,58 @@ def pytest_doctest_prepare_content(content):
# -------------------------------------------------------------------------
-def pytest_internalerror(excrepr, excinfo):
- """ called for internal errors. """
+def pytest_internalerror(
+ excrepr: "ExceptionRepr", excinfo: "ExceptionInfo[BaseException]",
+) -> Optional[bool]:
+ """Called for internal errors.
+
+ Return True to suppress the fallback handling of printing an
+ INTERNALERROR message directly to sys.stderr.
+ """
-def pytest_keyboard_interrupt(excinfo):
- """ called for keyboard interrupt. """
+def pytest_keyboard_interrupt(
+ excinfo: "ExceptionInfo[Union[KeyboardInterrupt, Exit]]",
+) -> None:
+ """Called for keyboard interrupt."""
-def pytest_exception_interact(node, call, report):
- """called when an exception was raised which can potentially be
+def pytest_exception_interact(
+ node: Union["Item", "Collector"],
+ call: "CallInfo[Any]",
+ report: Union["CollectReport", "TestReport"],
+) -> None:
+ """Called when an exception was raised which can potentially be
interactively handled.
- This hook is only called if an exception was raised
- that is not an internal exception like ``skip.Exception``.
+ May be called during collection (see :py:func:`pytest_make_collect_report`),
+ in which case ``report`` is a :py:class:`_pytest.reports.CollectReport`.
+
+ May be called during runtest of an item (see :py:func:`pytest_runtest_protocol`),
+ in which case ``report`` is a :py:class:`_pytest.reports.TestReport`.
+
+ This hook is not called if the exception that was raised is an internal
+ exception like ``skip.Exception``.
"""
-def pytest_enter_pdb(config, pdb):
- """ called upon pdb.set_trace(), can be used by plugins to take special
- action just before the python debugger enters in interactive mode.
+def pytest_enter_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
+ """Called upon pdb.set_trace().
+
+ Can be used by plugins to take special action just before the python
+ debugger enters interactive mode.
- :param _pytest.config.Config config: pytest config object
- :param pdb.Pdb pdb: Pdb instance
+ :param _pytest.config.Config config: The pytest config object.
+ :param pdb.Pdb pdb: The Pdb instance.
"""
-def pytest_leave_pdb(config, pdb):
- """ called when leaving pdb (e.g. with continue after pdb.set_trace()).
+def pytest_leave_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
+ """Called when leaving pdb (e.g. with continue after pdb.set_trace()).
Can be used by plugins to take special action just after the python
debugger leaves interactive mode.
- :param _pytest.config.Config config: pytest config object
- :param pdb.Pdb pdb: Pdb instance
+ :param _pytest.config.Config config: The pytest config object.
+ :param pdb.Pdb pdb: The Pdb instance.
"""
diff --git a/contrib/python/pytest/py3/_pytest/junitxml.py b/contrib/python/pytest/py3/_pytest/junitxml.py
index 77e1843127..c4761cd3b8 100644
--- a/contrib/python/pytest/py3/_pytest/junitxml.py
+++ b/contrib/python/pytest/py3/_pytest/junitxml.py
@@ -1,71 +1,70 @@
-"""
- report test results in JUnit-XML format,
- for use with Jenkins and build integration servers.
-
+"""Report test results in JUnit-XML format, for use with Jenkins and build
+integration servers.
Based on initial code from Ross Lawley.
-Output conforms to https://github.com/jenkinsci/xunit-plugin/blob/master/
-src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd
+Output conforms to
+https://github.com/jenkinsci/xunit-plugin/blob/master/src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd
"""
import functools
import os
import platform
import re
-import sys
-import time
+import xml.etree.ElementTree as ET
from datetime import datetime
-
-import py
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Match
+from typing import Optional
+from typing import Tuple
+from typing import Union
import pytest
-from _pytest import deprecated
from _pytest import nodes
+from _pytest import timing
+from _pytest._code.code import ExceptionRepr
+from _pytest._code.code import ReprFileLocation
+from _pytest.config import Config
from _pytest.config import filename_arg
+from _pytest.config.argparsing import Parser
+from _pytest.fixtures import FixtureRequest
+from _pytest.reports import TestReport
from _pytest.store import StoreKey
-from _pytest.warnings import _issue_warning_captured
+from _pytest.terminal import TerminalReporter
xml_key = StoreKey["LogXML"]()
-class Junit(py.xml.Namespace):
- pass
-
-
-# We need to get the subset of the invalid unicode ranges according to
-# XML 1.0 which are valid in this python build. Hence we calculate
-# this dynamically instead of hardcoding it. The spec range of valid
-# chars is: Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD]
-# | [#x10000-#x10FFFF]
-_legal_chars = (0x09, 0x0A, 0x0D)
-_legal_ranges = ((0x20, 0x7E), (0x80, 0xD7FF), (0xE000, 0xFFFD), (0x10000, 0x10FFFF))
-_legal_xml_re = [
- "{}-{}".format(chr(low), chr(high))
- for (low, high) in _legal_ranges
- if low < sys.maxunicode
-]
-_legal_xml_re = [chr(x) for x in _legal_chars] + _legal_xml_re
-illegal_xml_re = re.compile("[^%s]" % "".join(_legal_xml_re))
-del _legal_chars
-del _legal_ranges
-del _legal_xml_re
-
-_py_ext_re = re.compile(r"\.py$")
+def bin_xml_escape(arg: object) -> str:
+ r"""Visually escape invalid XML characters.
+ For example, transforms
+ 'hello\aworld\b'
+ into
+ 'hello#x07world#x08'
+ Note that the #xABs are *not* XML escapes - missing the ampersand &#xAB.
+ The idea is to escape visually for the user rather than for XML itself.
+ """
-def bin_xml_escape(arg):
- def repl(matchobj):
+ def repl(matchobj: Match[str]) -> str:
i = ord(matchobj.group())
if i <= 0xFF:
return "#x%02X" % i
else:
return "#x%04X" % i
- return py.xml.raw(illegal_xml_re.sub(repl, py.xml.escape(arg)))
+ # The spec range of valid chars is:
+ # Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
+ # For an unknown(?) reason, we disallow #x7F (DEL) as well.
+ illegal_xml_re = (
+ "[^\u0009\u000A\u000D\u0020-\u007E\u0080-\uD7FF\uE000-\uFFFD\u10000-\u10FFFF]"
+ )
+ return re.sub(illegal_xml_re, repl, str(arg))
-def merge_family(left, right):
+def merge_family(left, right) -> None:
result = {}
for kl, vl in left.items():
for kr, vr in right.items():
@@ -79,68 +78,63 @@ families = {}
families["_base"] = {"testcase": ["classname", "name"]}
families["_base_legacy"] = {"testcase": ["file", "line", "url"]}
-# xUnit 1.x inherits legacy attributes
+# xUnit 1.x inherits legacy attributes.
families["xunit1"] = families["_base"].copy()
merge_family(families["xunit1"], families["_base_legacy"])
-# xUnit 2.x uses strict base attributes
+# xUnit 2.x uses strict base attributes.
families["xunit2"] = families["_base"]
class _NodeReporter:
- def __init__(self, nodeid, xml):
+ def __init__(self, nodeid: Union[str, TestReport], xml: "LogXML") -> None:
self.id = nodeid
self.xml = xml
self.add_stats = self.xml.add_stats
self.family = self.xml.family
self.duration = 0
- self.properties = []
- self.nodes = []
- self.testcase = None
- self.attrs = {}
+ self.properties: List[Tuple[str, str]] = []
+ self.nodes: List[ET.Element] = []
+ self.attrs: Dict[str, str] = {}
- def append(self, node):
- self.xml.add_stats(type(node).__name__)
+ def append(self, node: ET.Element) -> None:
+ self.xml.add_stats(node.tag)
self.nodes.append(node)
- def add_property(self, name, value):
+ def add_property(self, name: str, value: object) -> None:
self.properties.append((str(name), bin_xml_escape(value)))
- def add_attribute(self, name, value):
+ def add_attribute(self, name: str, value: object) -> None:
self.attrs[str(name)] = bin_xml_escape(value)
- def make_properties_node(self):
- """Return a Junit node containing custom properties, if any.
- """
+ def make_properties_node(self) -> Optional[ET.Element]:
+ """Return a Junit node containing custom properties, if any."""
if self.properties:
- return Junit.properties(
- [
- Junit.property(name=name, value=value)
- for name, value in self.properties
- ]
- )
- return ""
+ properties = ET.Element("properties")
+ for name, value in self.properties:
+ properties.append(ET.Element("property", name=name, value=value))
+ return properties
+ return None
- def record_testreport(self, testreport):
- assert not self.testcase
+ def record_testreport(self, testreport: TestReport) -> None:
names = mangle_test_address(testreport.nodeid)
existing_attrs = self.attrs
classnames = names[:-1]
if self.xml.prefix:
classnames.insert(0, self.xml.prefix)
- attrs = {
+ attrs: Dict[str, str] = {
"classname": ".".join(classnames),
"name": bin_xml_escape(names[-1]),
"file": testreport.location[0],
}
if testreport.location[1] is not None:
- attrs["line"] = testreport.location[1]
+ attrs["line"] = str(testreport.location[1])
if hasattr(testreport, "url"):
attrs["url"] = testreport.url
self.attrs = attrs
- self.attrs.update(existing_attrs) # restore any user-defined attributes
+ self.attrs.update(existing_attrs) # Restore any user-defined attributes.
- # Preserve legacy testcase behavior
+ # Preserve legacy testcase behavior.
if self.family == "xunit1":
return
@@ -152,19 +146,20 @@ class _NodeReporter:
temp_attrs[key] = self.attrs[key]
self.attrs = temp_attrs
- def to_xml(self):
- testcase = Junit.testcase(time="%.3f" % self.duration, **self.attrs)
- testcase.append(self.make_properties_node())
- for node in self.nodes:
- testcase.append(node)
+ def to_xml(self) -> ET.Element:
+ testcase = ET.Element("testcase", self.attrs, time="%.3f" % self.duration)
+ properties = self.make_properties_node()
+ if properties is not None:
+ testcase.append(properties)
+ testcase.extend(self.nodes)
return testcase
- def _add_simple(self, kind, message, data=None):
- data = bin_xml_escape(data)
- node = kind(data, message=message)
+ def _add_simple(self, tag: str, message: str, data: Optional[str] = None) -> None:
+ node = ET.Element(tag, message=message)
+ node.text = bin_xml_escape(data)
self.append(node)
- def write_captured_output(self, report):
+ def write_captured_output(self, report: TestReport) -> None:
if not self.xml.log_passing_tests and report.passed:
return
@@ -187,81 +182,89 @@ class _NodeReporter:
if content_all:
self._write_content(report, content_all, "system-out")
- def _prepare_content(self, content, header):
+ def _prepare_content(self, content: str, header: str) -> str:
return "\n".join([header.center(80, "-"), content, ""])
- def _write_content(self, report, content, jheader):
- tag = getattr(Junit, jheader)
- self.append(tag(bin_xml_escape(content)))
+ def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
+ tag = ET.Element(jheader)
+ tag.text = bin_xml_escape(content)
+ self.append(tag)
- def append_pass(self, report):
+ def append_pass(self, report: TestReport) -> None:
self.add_stats("passed")
- def append_failure(self, report):
+ def append_failure(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
if hasattr(report, "wasxfail"):
- self._add_simple(Junit.skipped, "xfail-marked test passes unexpectedly")
+ self._add_simple("skipped", "xfail-marked test passes unexpectedly")
else:
- if hasattr(report.longrepr, "reprcrash"):
- message = report.longrepr.reprcrash.message
- elif isinstance(report.longrepr, str):
- message = report.longrepr
+ assert report.longrepr is not None
+ reprcrash: Optional[ReprFileLocation] = getattr(
+ report.longrepr, "reprcrash", None
+ )
+ if reprcrash is not None:
+ message = reprcrash.message
else:
message = str(report.longrepr)
message = bin_xml_escape(message)
- fail = Junit.failure(message=message)
- fail.append(bin_xml_escape(report.longrepr))
- self.append(fail)
+ self._add_simple("failure", message, str(report.longrepr))
- def append_collect_error(self, report):
+ def append_collect_error(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
- self.append(
- Junit.error(bin_xml_escape(report.longrepr), message="collection failure")
- )
+ assert report.longrepr is not None
+ self._add_simple("error", "collection failure", str(report.longrepr))
- def append_collect_skipped(self, report):
- self._add_simple(Junit.skipped, "collection skipped", report.longrepr)
+ def append_collect_skipped(self, report: TestReport) -> None:
+ self._add_simple("skipped", "collection skipped", str(report.longrepr))
+
+ def append_error(self, report: TestReport) -> None:
+ assert report.longrepr is not None
+ reprcrash: Optional[ReprFileLocation] = getattr(
+ report.longrepr, "reprcrash", None
+ )
+ if reprcrash is not None:
+ reason = reprcrash.message
+ else:
+ reason = str(report.longrepr)
- def append_error(self, report):
if report.when == "teardown":
- msg = "test teardown failure"
+ msg = f'failed on teardown with "{reason}"'
else:
- msg = "test setup failure"
- self._add_simple(Junit.error, msg, report.longrepr)
+ msg = f'failed on setup with "{reason}"'
+ self._add_simple("error", msg, str(report.longrepr))
- def append_skipped(self, report):
+ def append_skipped(self, report: TestReport) -> None:
if hasattr(report, "wasxfail"):
xfailreason = report.wasxfail
if xfailreason.startswith("reason: "):
xfailreason = xfailreason[8:]
- self.append(
- Junit.skipped(
- "", type="pytest.xfail", message=bin_xml_escape(xfailreason)
- )
- )
+ xfailreason = bin_xml_escape(xfailreason)
+ skipped = ET.Element("skipped", type="pytest.xfail", message=xfailreason)
+ self.append(skipped)
else:
+ assert isinstance(report.longrepr, tuple)
filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "):
skipreason = skipreason[9:]
- details = "{}:{}: {}".format(filename, lineno, skipreason)
+ details = f"{filename}:{lineno}: {skipreason}"
- self.append(
- Junit.skipped(
- bin_xml_escape(details),
- type="pytest.skip",
- message=bin_xml_escape(skipreason),
- )
- )
+ skipped = ET.Element("skipped", type="pytest.skip", message=skipreason)
+ skipped.text = bin_xml_escape(details)
+ self.append(skipped)
self.write_captured_output(report)
- def finalize(self):
- data = self.to_xml().unicode(indent=0)
+ def finalize(self) -> None:
+ data = self.to_xml()
self.__dict__.clear()
- self.to_xml = lambda: py.xml.raw(data)
+ # Type ignored becuase mypy doesn't like overriding a method.
+ # Also the return value doesn't match...
+ self.to_xml = lambda: data # type: ignore[assignment]
-def _warn_incompatibility_with_xunit2(request, fixture_name):
- """Emits a PytestWarning about the given fixture being incompatible with newer xunit revisions"""
+def _warn_incompatibility_with_xunit2(
+ request: FixtureRequest, fixture_name: str
+) -> None:
+ """Emit a PytestWarning about the given fixture being incompatible with newer xunit revisions."""
from _pytest.warning_types import PytestWarning
xml = request.config._store.get(xml_key, None)
@@ -276,12 +279,14 @@ def _warn_incompatibility_with_xunit2(request, fixture_name):
@pytest.fixture
-def record_property(request):
- """Add an extra properties the calling test.
+def record_property(request: FixtureRequest) -> Callable[[str, object], None]:
+ """Add extra properties to the calling test.
+
User properties become part of the test report and are available to the
configured reporters, like JUnit XML.
- The fixture is callable with ``(name, value)``, with value being automatically
- xml-encoded.
+
+ The fixture is callable with ``name, value``. The value is automatically
+ XML-encoded.
Example::
@@ -290,17 +295,18 @@ def record_property(request):
"""
_warn_incompatibility_with_xunit2(request, "record_property")
- def append_property(name, value):
+ def append_property(name: str, value: object) -> None:
request.node.user_properties.append((name, value))
return append_property
@pytest.fixture
-def record_xml_attribute(request):
+def record_xml_attribute(request: FixtureRequest) -> Callable[[str, object], None]:
"""Add extra xml attributes to the tag for the calling test.
- The fixture is callable with ``(name, value)``, with value being
- automatically xml-encoded
+
+ The fixture is callable with ``name, value``. The value is
+ automatically XML-encoded.
"""
from _pytest.warning_types import PytestExperimentalApiWarning
@@ -311,7 +317,7 @@ def record_xml_attribute(request):
_warn_incompatibility_with_xunit2(request, "record_xml_attribute")
# Declare noop
- def add_attr_noop(name, value):
+ def add_attr_noop(name: str, value: object) -> None:
pass
attr_func = add_attr_noop
@@ -324,20 +330,21 @@ def record_xml_attribute(request):
return attr_func
-def _check_record_param_type(param, v):
+def _check_record_param_type(param: str, v: str) -> None:
"""Used by record_testsuite_property to check that the given parameter name is of the proper
- type"""
+ type."""
__tracebackhide__ = True
if not isinstance(v, str):
- msg = "{param} parameter needs to be a string, but {g} given"
+ msg = "{param} parameter needs to be a string, but {g} given" # type: ignore[unreachable]
raise TypeError(msg.format(param=param, g=type(v).__name__))
@pytest.fixture(scope="session")
-def record_testsuite_property(request):
- """
- Records a new ``<property>`` tag as child of the root ``<testsuite>``. This is suitable to
- writing global information regarding the entire test suite, and is compatible with ``xunit2`` JUnit family.
+def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object], None]:
+ """Record a new ``<property>`` tag as child of the root ``<testsuite>``.
+
+ This is suitable to writing global information regarding the entire test
+ suite, and is compatible with ``xunit2`` JUnit family.
This is a ``session``-scoped fixture which is called with ``(name, value)``. Example:
@@ -348,12 +355,18 @@ def record_testsuite_property(request):
record_testsuite_property("STORAGE_TYPE", "CEPH")
``name`` must be a string, ``value`` will be converted to a string and properly xml-escaped.
+
+ .. warning::
+
+ Currently this fixture **does not work** with the
+ `pytest-xdist <https://github.com/pytest-dev/pytest-xdist>`__ plugin. See issue
+ `#7767 <https://github.com/pytest-dev/pytest/issues/7767>`__ for details.
"""
__tracebackhide__ = True
- def record_func(name, value):
- """noop function in case --junitxml was not passed in the command-line"""
+ def record_func(name: str, value: object) -> None:
+ """No-op function in case --junitxml was not passed in the command-line."""
__tracebackhide__ = True
_check_record_param_type("name", name)
@@ -363,7 +376,7 @@ def record_testsuite_property(request):
return record_func
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group.addoption(
"--junitxml",
@@ -404,18 +417,17 @@ def pytest_addoption(parser):
default="total",
) # choices=['total', 'call'])
parser.addini(
- "junit_family", "Emit XML for schema: one of legacy|xunit1|xunit2", default=None
+ "junit_family",
+ "Emit XML for schema: one of legacy|xunit1|xunit2",
+ default="xunit2",
)
-def pytest_configure(config):
+def pytest_configure(config: Config) -> None:
xmlpath = config.option.xmlpath
- # prevent opening xmllog on slave nodes (xdist)
- if xmlpath and not hasattr(config, "slaveinput"):
+ # Prevent opening xmllog on worker nodes (xdist).
+ if xmlpath and not hasattr(config, "workerinput"):
junit_family = config.getini("junit_family")
- if not junit_family:
- _issue_warning_captured(deprecated.JUNIT_XML_DEFAULT_FAMILY, config.hook, 2)
- junit_family = "xunit1"
config._store[xml_key] = LogXML(
xmlpath,
config.option.junitprefix,
@@ -428,24 +440,24 @@ def pytest_configure(config):
config.pluginmanager.register(config._store[xml_key])
-def pytest_unconfigure(config):
+def pytest_unconfigure(config: Config) -> None:
xml = config._store.get(xml_key, None)
if xml:
del config._store[xml_key]
config.pluginmanager.unregister(xml)
-def mangle_test_address(address):
+def mangle_test_address(address: str) -> List[str]:
path, possible_open_bracket, params = address.partition("[")
names = path.split("::")
try:
names.remove("()")
except ValueError:
pass
- # convert file path to dotted path
+ # Convert file path to dotted path.
names[0] = names[0].replace(nodes.SEP, ".")
- names[0] = _py_ext_re.sub("", names[0])
- # put any params back
+ names[0] = re.sub(r"\.py$", "", names[0])
+ # Put any params back.
names[-1] += possible_open_bracket + params
return names
@@ -454,13 +466,13 @@ class LogXML:
def __init__(
self,
logfile,
- prefix,
- suite_name="pytest",
- logging="no",
- report_duration="total",
+ prefix: Optional[str],
+ suite_name: str = "pytest",
+ logging: str = "no",
+ report_duration: str = "total",
family="xunit1",
- log_passing_tests=True,
- ):
+ log_passing_tests: bool = True,
+ ) -> None:
logfile = os.path.expanduser(os.path.expandvars(logfile))
self.logfile = os.path.normpath(os.path.abspath(logfile))
self.prefix = prefix
@@ -469,33 +481,37 @@ class LogXML:
self.log_passing_tests = log_passing_tests
self.report_duration = report_duration
self.family = family
- self.stats = dict.fromkeys(["error", "passed", "failure", "skipped"], 0)
- self.node_reporters = {} # nodeid -> _NodeReporter
- self.node_reporters_ordered = []
- self.global_properties = []
+ self.stats: Dict[str, int] = dict.fromkeys(
+ ["error", "passed", "failure", "skipped"], 0
+ )
+ self.node_reporters: Dict[
+ Tuple[Union[str, TestReport], object], _NodeReporter
+ ] = ({})
+ self.node_reporters_ordered: List[_NodeReporter] = []
+ self.global_properties: List[Tuple[str, str]] = []
# List of reports that failed on call but teardown is pending.
- self.open_reports = []
+ self.open_reports: List[TestReport] = []
self.cnt_double_fail_tests = 0
- # Replaces convenience family with real family
+ # Replaces convenience family with real family.
if self.family == "legacy":
self.family = "xunit1"
- def finalize(self, report):
+ def finalize(self, report: TestReport) -> None:
nodeid = getattr(report, "nodeid", report)
- # local hack to handle xdist report order
- slavenode = getattr(report, "node", None)
- reporter = self.node_reporters.pop((nodeid, slavenode))
+ # Local hack to handle xdist report order.
+ workernode = getattr(report, "node", None)
+ reporter = self.node_reporters.pop((nodeid, workernode))
if reporter is not None:
reporter.finalize()
- def node_reporter(self, report):
- nodeid = getattr(report, "nodeid", report)
- # local hack to handle xdist report order
- slavenode = getattr(report, "node", None)
+ def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporter:
+ nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
+ # Local hack to handle xdist report order.
+ workernode = getattr(report, "node", None)
- key = nodeid, slavenode
+ key = nodeid, workernode
if key in self.node_reporters:
# TODO: breaks for --dist=each
@@ -508,23 +524,23 @@ class LogXML:
return reporter
- def add_stats(self, key):
+ def add_stats(self, key: str) -> None:
if key in self.stats:
self.stats[key] += 1
- def _opentestcase(self, report):
+ def _opentestcase(self, report: TestReport) -> _NodeReporter:
reporter = self.node_reporter(report)
reporter.record_testreport(report)
return reporter
- def pytest_runtest_logreport(self, report):
- """handle a setup/call/teardown report, generating the appropriate
- xml tags as necessary.
+ def pytest_runtest_logreport(self, report: TestReport) -> None:
+ """Handle a setup/call/teardown report, generating the appropriate
+ XML tags as necessary.
- note: due to plugins like xdist, this hook may be called in interlaced
- order with reports from other nodes. for example:
+ Note: due to plugins like xdist, this hook may be called in interlaced
+ order with reports from other nodes. For example:
- usual call order:
+ Usual call order:
-> setup node1
-> call node1
-> teardown node1
@@ -532,7 +548,7 @@ class LogXML:
-> call node2
-> teardown node2
- possible call order in xdist:
+ Possible call order in xdist:
-> setup node1
-> call node1
-> setup node2
@@ -547,7 +563,7 @@ class LogXML:
reporter.append_pass(report)
elif report.failed:
if report.when == "teardown":
- # The following vars are needed when xdist plugin is used
+ # The following vars are needed when xdist plugin is used.
report_wid = getattr(report, "worker_id", None)
report_ii = getattr(report, "item_index", None)
close_report = next(
@@ -565,7 +581,7 @@ class LogXML:
if close_report:
# We need to open new testcase in case we have failure in
# call and error in teardown in order to follow junit
- # schema
+ # schema.
self.finalize(close_report)
self.cnt_double_fail_tests += 1
reporter = self._opentestcase(report)
@@ -585,7 +601,7 @@ class LogXML:
reporter.write_captured_output(report)
for propname, propvalue in report.user_properties:
- reporter.add_property(propname, propvalue)
+ reporter.add_property(propname, str(propvalue))
self.finalize(report)
report_wid = getattr(report, "worker_id", None)
@@ -605,15 +621,14 @@ class LogXML:
if close_report:
self.open_reports.remove(close_report)
- def update_testcase_duration(self, report):
- """accumulates total duration for nodeid from given report and updates
- the Junit.testcase with the new total if already created.
- """
+ def update_testcase_duration(self, report: TestReport) -> None:
+ """Accumulate total duration for nodeid from given report and update
+ the Junit.testcase with the new total if already created."""
if self.report_duration == "total" or report.when == self.report_duration:
reporter = self.node_reporter(report)
reporter.duration += getattr(report, "duration", 0.0)
- def pytest_collectreport(self, report):
+ def pytest_collectreport(self, report: TestReport) -> None:
if not report.passed:
reporter = self._opentestcase(report)
if report.failed:
@@ -621,20 +636,20 @@ class LogXML:
else:
reporter.append_collect_skipped(report)
- def pytest_internalerror(self, excrepr):
+ def pytest_internalerror(self, excrepr: ExceptionRepr) -> None:
reporter = self.node_reporter("internal")
reporter.attrs.update(classname="pytest", name="internal")
- reporter._add_simple(Junit.error, "internal error", excrepr)
+ reporter._add_simple("error", "internal error", str(excrepr))
- def pytest_sessionstart(self):
- self.suite_start_time = time.time()
+ def pytest_sessionstart(self) -> None:
+ self.suite_start_time = timing.time()
- def pytest_sessionfinish(self):
+ def pytest_sessionfinish(self) -> None:
dirname = os.path.dirname(os.path.abspath(self.logfile))
if not os.path.isdir(dirname):
os.makedirs(dirname)
logfile = open(self.logfile, "w", encoding="utf-8")
- suite_stop_time = time.time()
+ suite_stop_time = timing.time()
suite_time_delta = suite_stop_time - self.suite_start_time
numtests = (
@@ -646,37 +661,40 @@ class LogXML:
)
logfile.write('<?xml version="1.0" encoding="utf-8"?>')
- suite_node = Junit.testsuite(
- self._get_global_properties_node(),
- [x.to_xml() for x in self.node_reporters_ordered],
+ suite_node = ET.Element(
+ "testsuite",
name=self.suite_name,
- errors=self.stats["error"],
- failures=self.stats["failure"],
- skipped=self.stats["skipped"],
- tests=numtests,
+ errors=str(self.stats["error"]),
+ failures=str(self.stats["failure"]),
+ skipped=str(self.stats["skipped"]),
+ tests=str(numtests),
time="%.3f" % suite_time_delta,
timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(),
hostname=platform.node(),
)
- logfile.write(Junit.testsuites([suite_node]).unicode(indent=0))
+ global_properties = self._get_global_properties_node()
+ if global_properties is not None:
+ suite_node.append(global_properties)
+ for node_reporter in self.node_reporters_ordered:
+ suite_node.append(node_reporter.to_xml())
+ testsuites = ET.Element("testsuites")
+ testsuites.append(suite_node)
+ logfile.write(ET.tostring(testsuites, encoding="unicode"))
logfile.close()
- def pytest_terminal_summary(self, terminalreporter):
- terminalreporter.write_sep("-", "generated xml file: %s" % (self.logfile))
+ def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None:
+ terminalreporter.write_sep("-", f"generated xml file: {self.logfile}")
- def add_global_property(self, name, value):
+ def add_global_property(self, name: str, value: object) -> None:
__tracebackhide__ = True
_check_record_param_type("name", name)
self.global_properties.append((name, bin_xml_escape(value)))
- def _get_global_properties_node(self):
- """Return a Junit node containing custom properties, if any.
- """
+ def _get_global_properties_node(self) -> Optional[ET.Element]:
+ """Return a Junit node containing custom properties, if any."""
if self.global_properties:
- return Junit.properties(
- [
- Junit.property(name=name, value=value)
- for name, value in self.global_properties
- ]
- )
- return ""
+ properties = ET.Element("properties")
+ for name, value in self.global_properties:
+ properties.append(ET.Element("property", name=name, value=value))
+ return properties
+ return None
diff --git a/contrib/python/pytest/py3/_pytest/logging.py b/contrib/python/pytest/py3/_pytest/logging.py
index 5e60a23217..2e4847328a 100644
--- a/contrib/python/pytest/py3/_pytest/logging.py
+++ b/contrib/python/pytest/py3/_pytest/logging.py
@@ -1,38 +1,56 @@
-""" Access and control log capturing. """
+"""Access and control log capturing."""
import logging
+import os
import re
+import sys
from contextlib import contextmanager
from io import StringIO
+from pathlib import Path
from typing import AbstractSet
from typing import Dict
from typing import Generator
from typing import List
from typing import Mapping
from typing import Optional
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
-import pytest
from _pytest import nodes
+from _pytest._io import TerminalWriter
+from _pytest.capture import CaptureManager
+from _pytest.compat import final
from _pytest.compat import nullcontext
from _pytest.config import _strtobool
from _pytest.config import Config
from _pytest.config import create_terminal_writer
-from _pytest.pathlib import Path
+from _pytest.config import hookimpl
+from _pytest.config import UsageError
+from _pytest.config.argparsing import Parser
+from _pytest.deprecated import check_ispytest
+from _pytest.fixtures import fixture
+from _pytest.fixtures import FixtureRequest
+from _pytest.main import Session
+from _pytest.store import StoreKey
+from _pytest.terminal import TerminalReporter
+
DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s"
DEFAULT_LOG_DATE_FORMAT = "%H:%M:%S"
_ANSI_ESCAPE_SEQ = re.compile(r"\x1b\[[\d;]+m")
+caplog_handler_key = StoreKey["LogCaptureHandler"]()
+caplog_records_key = StoreKey[Dict[str, List[logging.LogRecord]]]()
-def _remove_ansi_escape_sequences(text):
+def _remove_ansi_escape_sequences(text: str) -> str:
return _ANSI_ESCAPE_SEQ.sub("", text)
class ColoredLevelFormatter(logging.Formatter):
- """
- Colorize the %(levelname)..s part of the log format passed to __init__.
- """
+ """A logging formatter which colorizes the %(levelname)..s part of the
+ log format passed to __init__."""
- LOGLEVEL_COLOROPTS = {
+ LOGLEVEL_COLOROPTS: Mapping[int, AbstractSet[str]] = {
logging.CRITICAL: {"red"},
logging.ERROR: {"red", "bold"},
logging.WARNING: {"yellow"},
@@ -40,13 +58,13 @@ class ColoredLevelFormatter(logging.Formatter):
logging.INFO: {"green"},
logging.DEBUG: {"purple"},
logging.NOTSET: set(),
- } # type: Mapping[int, AbstractSet[str]]
+ }
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*s)")
- def __init__(self, terminalwriter, *args, **kwargs) -> None:
+ def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._original_fmt = self._style._fmt
- self._level_to_fmt_mapping = {} # type: Dict[int, str]
+ self._level_to_fmt_mapping: Dict[int, str] = {}
assert self._fmt is not None
levelname_fmt_match = self.LEVELNAME_FMT_REGEX.search(self._fmt)
@@ -68,7 +86,7 @@ class ColoredLevelFormatter(logging.Formatter):
colorized_formatted_levelname, self._fmt
)
- def format(self, record):
+ def format(self, record: logging.LogRecord) -> str:
fmt = self._level_to_fmt_mapping.get(record.levelno, self._original_fmt)
self._style._fmt = fmt
return super().format(record)
@@ -81,19 +99,21 @@ class PercentStyleMultiline(logging.PercentStyle):
formats the message as if each line were logged separately.
"""
- def __init__(self, fmt, auto_indent):
+ def __init__(self, fmt: str, auto_indent: Union[int, str, bool, None]) -> None:
super().__init__(fmt)
self._auto_indent = self._get_auto_indent(auto_indent)
@staticmethod
- def _update_message(record_dict, message):
+ def _update_message(
+ record_dict: Dict[str, object], message: str
+ ) -> Dict[str, object]:
tmp = record_dict.copy()
tmp["message"] = message
return tmp
@staticmethod
- def _get_auto_indent(auto_indent_option) -> int:
- """Determines the current auto indentation setting
+ def _get_auto_indent(auto_indent_option: Union[int, str, bool, None]) -> int:
+ """Determine the current auto indentation setting.
Specify auto indent behavior (on/off/fixed) by passing in
extra={"auto_indent": [value]} to the call to logging.log() or
@@ -111,20 +131,29 @@ class PercentStyleMultiline(logging.PercentStyle):
Any other values for the option are invalid, and will silently be
converted to the default.
- :param any auto_indent_option: User specified option for indentation
- from command line, config or extra kwarg. Accepts int, bool or str.
- str option accepts the same range of values as boolean config options,
- as well as positive integers represented in str form.
+ :param None|bool|int|str auto_indent_option:
+ User specified option for indentation from command line, config
+ or extra kwarg. Accepts int, bool or str. str option accepts the
+ same range of values as boolean config options, as well as
+ positive integers represented in str form.
- :returns: indentation value, which can be
+ :returns:
+ Indentation value, which can be
-1 (automatically determine indentation) or
0 (auto-indent turned off) or
>0 (explicitly set indentation position).
"""
- if type(auto_indent_option) is int:
+ if auto_indent_option is None:
+ return 0
+ elif isinstance(auto_indent_option, bool):
+ if auto_indent_option:
+ return -1
+ else:
+ return 0
+ elif isinstance(auto_indent_option, int):
return int(auto_indent_option)
- elif type(auto_indent_option) is str:
+ elif isinstance(auto_indent_option, str):
try:
return int(auto_indent_option)
except ValueError:
@@ -134,17 +163,14 @@ class PercentStyleMultiline(logging.PercentStyle):
return -1
except ValueError:
return 0
- elif type(auto_indent_option) is bool:
- if auto_indent_option:
- return -1
return 0
- def format(self, record):
+ def format(self, record: logging.LogRecord) -> str:
if "\n" in record.message:
if hasattr(record, "auto_indent"):
- # passed in from the "extra={}" kwarg on the call to logging.log()
- auto_indent = self._get_auto_indent(record.auto_indent)
+ # Passed in from the "extra={}" kwarg on the call to logging.log().
+ auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined]
else:
auto_indent = self._auto_indent
@@ -157,14 +183,14 @@ class PercentStyleMultiline(logging.PercentStyle):
lines[0]
)
else:
- # optimizes logging by allowing a fixed indentation
+ # Optimizes logging by allowing a fixed indentation.
indentation = auto_indent
lines[0] = formatted
return ("\n" + " " * indentation).join(lines)
return self._fmt % record.__dict__
-def get_option_ini(config, *names):
+def get_option_ini(config: Config, *names: str):
for name in names:
ret = config.getoption(name) # 'default' arg won't work as expected
if ret is None:
@@ -173,7 +199,7 @@ def get_option_ini(config, *names):
return ret
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
"""Add options to control log capturing."""
group = parser.getgroup("logging")
@@ -184,15 +210,6 @@ def pytest_addoption(parser):
group.addoption(option, dest=dest, **kwargs)
add_option_ini(
- "--no-print-logs",
- dest="log_print",
- action="store_const",
- const=False,
- default=True,
- type="bool",
- help="disable printing caught logs on failed tests.",
- )
- add_option_ini(
"--log-level",
dest="log_level",
default=None,
@@ -268,109 +285,121 @@ def pytest_addoption(parser):
)
-@contextmanager
-def catching_logs(handler, formatter=None, level=None):
+_HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
+
+
+# Not using @contextmanager for performance reasons.
+class catching_logs:
"""Context manager that prepares the whole logging machinery properly."""
- root_logger = logging.getLogger()
-
- if formatter is not None:
- handler.setFormatter(formatter)
- if level is not None:
- handler.setLevel(level)
-
- # Adding the same handler twice would confuse logging system.
- # Just don't do that.
- add_new_handler = handler not in root_logger.handlers
-
- if add_new_handler:
- root_logger.addHandler(handler)
- if level is not None:
- orig_level = root_logger.level
- root_logger.setLevel(min(orig_level, level))
- try:
- yield handler
- finally:
- if level is not None:
- root_logger.setLevel(orig_level)
- if add_new_handler:
- root_logger.removeHandler(handler)
+
+ __slots__ = ("handler", "level", "orig_level")
+
+ def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None:
+ self.handler = handler
+ self.level = level
+
+ def __enter__(self):
+ root_logger = logging.getLogger()
+ if self.level is not None:
+ self.handler.setLevel(self.level)
+ root_logger.addHandler(self.handler)
+ if self.level is not None:
+ self.orig_level = root_logger.level
+ root_logger.setLevel(min(self.orig_level, self.level))
+ return self.handler
+
+ def __exit__(self, type, value, traceback):
+ root_logger = logging.getLogger()
+ if self.level is not None:
+ root_logger.setLevel(self.orig_level)
+ root_logger.removeHandler(self.handler)
class LogCaptureHandler(logging.StreamHandler):
"""A logging handler that stores log records and the log text."""
+ stream: StringIO
+
def __init__(self) -> None:
- """Creates a new log handler."""
- logging.StreamHandler.__init__(self, StringIO())
- self.records = [] # type: List[logging.LogRecord]
+ """Create a new log handler."""
+ super().__init__(StringIO())
+ self.records: List[logging.LogRecord] = []
def emit(self, record: logging.LogRecord) -> None:
"""Keep the log records in a list in addition to the log text."""
self.records.append(record)
- logging.StreamHandler.emit(self, record)
+ super().emit(record)
def reset(self) -> None:
self.records = []
self.stream = StringIO()
+ def handleError(self, record: logging.LogRecord) -> None:
+ if logging.raiseExceptions:
+ # Fail the test if the log message is bad (emit failed).
+ # The default behavior of logging is to print "Logging error"
+ # to stderr with the call stack and some extra details.
+ # pytest wants to make such mistakes visible during testing.
+ raise
+
+@final
class LogCaptureFixture:
"""Provides access and control of log capturing."""
- def __init__(self, item) -> None:
- """Creates a new funcarg."""
+ def __init__(self, item: nodes.Node, *, _ispytest: bool = False) -> None:
+ check_ispytest(_ispytest)
self._item = item
- # dict of log name -> log level
- self._initial_log_levels = {} # type: Dict[str, int]
+ self._initial_handler_level: Optional[int] = None
+ # Dict of log name -> log level.
+ self._initial_logger_levels: Dict[Optional[str], int] = {}
def _finalize(self) -> None:
- """Finalizes the fixture.
+ """Finalize the fixture.
This restores the log levels changed by :meth:`set_level`.
"""
- # restore log levels
- for logger_name, level in self._initial_log_levels.items():
+ # Restore log levels.
+ if self._initial_handler_level is not None:
+ self.handler.setLevel(self._initial_handler_level)
+ for logger_name, level in self._initial_logger_levels.items():
logger = logging.getLogger(logger_name)
logger.setLevel(level)
@property
def handler(self) -> LogCaptureHandler:
- """
+ """Get the logging handler used by the fixture.
+
:rtype: LogCaptureHandler
"""
- return self._item.catch_log_handler # type: ignore[no-any-return] # noqa: F723
+ return self._item._store[caplog_handler_key]
def get_records(self, when: str) -> List[logging.LogRecord]:
- """
- Get the logging records for one of the possible test phases.
+ """Get the logging records for one of the possible test phases.
:param str when:
Which test phase to obtain the records from. Valid values are: "setup", "call" and "teardown".
+ :returns: The list of captured records at the given stage.
:rtype: List[logging.LogRecord]
- :return: the list of captured records at the given stage
.. versionadded:: 3.4
"""
- handler = self._item.catch_log_handlers.get(when)
- if handler:
- return handler.records # type: ignore[no-any-return] # noqa: F723
- else:
- return []
+ return self._item._store[caplog_records_key].get(when, [])
@property
- def text(self):
- """Returns the formatted log text."""
+ def text(self) -> str:
+ """The formatted log text."""
return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
@property
- def records(self):
- """Returns the list of log records."""
+ def records(self) -> List[logging.LogRecord]:
+ """The list of log records."""
return self.handler.records
@property
- def record_tuples(self):
- """Returns a list of a stripped down version of log records intended
+ def record_tuples(self) -> List[Tuple[str, int, str]]:
+ """A list of a stripped down version of log records intended
for use in assertion comparison.
The format of the tuple is:
@@ -380,61 +409,71 @@ class LogCaptureFixture:
return [(r.name, r.levelno, r.getMessage()) for r in self.records]
@property
- def messages(self):
- """Returns a list of format-interpolated log messages.
+ def messages(self) -> List[str]:
+ """A list of format-interpolated log messages.
- Unlike 'records', which contains the format string and parameters for interpolation, log messages in this list
- are all interpolated.
- Unlike 'text', which contains the output from the handler, log messages in this list are unadorned with
- levels, timestamps, etc, making exact comparisons more reliable.
+ Unlike 'records', which contains the format string and parameters for
+ interpolation, log messages in this list are all interpolated.
- Note that traceback or stack info (from :func:`logging.exception` or the `exc_info` or `stack_info` arguments
- to the logging functions) is not included, as this is added by the formatter in the handler.
+ Unlike 'text', which contains the output from the handler, log
+ messages in this list are unadorned with levels, timestamps, etc,
+ making exact comparisons more reliable.
+
+ Note that traceback or stack info (from :func:`logging.exception` or
+ the `exc_info` or `stack_info` arguments to the logging functions) is
+ not included, as this is added by the formatter in the handler.
.. versionadded:: 3.7
"""
return [r.getMessage() for r in self.records]
- def clear(self):
+ def clear(self) -> None:
"""Reset the list of log records and the captured log text."""
self.handler.reset()
- def set_level(self, level, logger=None):
- """Sets the level for capturing of logs. The level will be restored to its previous value at the end of
- the test.
-
- :param int level: the logger to level.
- :param str logger: the logger to update the level. If not given, the root logger level is updated.
+ def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None:
+ """Set the level of a logger for the duration of a test.
.. versionchanged:: 3.4
- The levels of the loggers changed by this function will be restored to their initial values at the
- end of the test.
+ The levels of the loggers changed by this function will be
+ restored to their initial values at the end of the test.
+
+ :param int level: The level.
+ :param str logger: The logger to update. If not given, the root logger.
"""
- logger_name = logger
- logger = logging.getLogger(logger_name)
- # save the original log-level to restore it during teardown
- self._initial_log_levels.setdefault(logger_name, logger.level)
- logger.setLevel(level)
+ logger_obj = logging.getLogger(logger)
+ # Save the original log-level to restore it during teardown.
+ self._initial_logger_levels.setdefault(logger, logger_obj.level)
+ logger_obj.setLevel(level)
+ if self._initial_handler_level is None:
+ self._initial_handler_level = self.handler.level
+ self.handler.setLevel(level)
@contextmanager
- def at_level(self, level, logger=None):
- """Context manager that sets the level for capturing of logs. After the end of the 'with' statement the
- level is restored to its original value.
+ def at_level(
+ self, level: int, logger: Optional[str] = None
+ ) -> Generator[None, None, None]:
+ """Context manager that sets the level for capturing of logs. After
+ the end of the 'with' statement the level is restored to its original
+ value.
- :param int level: the logger to level.
- :param str logger: the logger to update the level. If not given, the root logger level is updated.
+ :param int level: The level.
+ :param str logger: The logger to update. If not given, the root logger.
"""
- logger = logging.getLogger(logger)
- orig_level = logger.level
- logger.setLevel(level)
+ logger_obj = logging.getLogger(logger)
+ orig_level = logger_obj.level
+ logger_obj.setLevel(level)
+ handler_orig_level = self.handler.level
+ self.handler.setLevel(level)
try:
yield
finally:
- logger.setLevel(orig_level)
+ logger_obj.setLevel(orig_level)
+ self.handler.setLevel(handler_orig_level)
-@pytest.fixture
-def caplog(request):
+@fixture
+def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]:
"""Access and control log capturing.
Captured logs are available through the following properties/methods::
@@ -445,7 +484,7 @@ def caplog(request):
* caplog.record_tuples -> list of (logger_name, level, message) tuples
* caplog.clear() -> clear captured records and formatted log output string
"""
- result = LogCaptureFixture(request.node)
+ result = LogCaptureFixture(request.node, _ispytest=True)
yield result
result._finalize()
@@ -464,84 +503,92 @@ def get_log_level_for_setting(config: Config, *setting_names: str) -> Optional[i
log_level = log_level.upper()
try:
return int(getattr(logging, log_level, log_level))
- except ValueError:
+ except ValueError as e:
# Python logging does not recognise this as a logging level
- raise pytest.UsageError(
+ raise UsageError(
"'{}' is not recognized as a logging level name for "
"'{}'. Please consider passing the "
"logging level num instead.".format(log_level, setting_name)
- )
+ ) from e
# run after terminalreporter/capturemanager are configured
-@pytest.hookimpl(trylast=True)
-def pytest_configure(config):
+@hookimpl(trylast=True)
+def pytest_configure(config: Config) -> None:
config.pluginmanager.register(LoggingPlugin(config), "logging-plugin")
class LoggingPlugin:
- """Attaches to the logging module and captures log messages for each test.
- """
+ """Attaches to the logging module and captures log messages for each test."""
def __init__(self, config: Config) -> None:
- """Creates a new plugin to capture log messages.
+ """Create a new plugin to capture log messages.
The formatter can be safely shared across all handlers so
create a single one for the entire test session here.
"""
self._config = config
- self.print_logs = get_option_ini(config, "log_print")
- if not self.print_logs:
- from _pytest.warnings import _issue_warning_captured
- from _pytest.deprecated import NO_PRINT_LOGS
-
- _issue_warning_captured(NO_PRINT_LOGS, self._config.hook, stacklevel=2)
-
+ # Report logging.
self.formatter = self._create_formatter(
get_option_ini(config, "log_format"),
get_option_ini(config, "log_date_format"),
get_option_ini(config, "log_auto_indent"),
)
self.log_level = get_log_level_for_setting(config, "log_level")
+ self.caplog_handler = LogCaptureHandler()
+ self.caplog_handler.setFormatter(self.formatter)
+ self.report_handler = LogCaptureHandler()
+ self.report_handler.setFormatter(self.formatter)
+ # File logging.
self.log_file_level = get_log_level_for_setting(config, "log_file_level")
- self.log_file_format = get_option_ini(config, "log_file_format", "log_format")
- self.log_file_date_format = get_option_ini(
+ log_file = get_option_ini(config, "log_file") or os.devnull
+ if log_file != os.devnull:
+ directory = os.path.dirname(os.path.abspath(log_file))
+ if not os.path.isdir(directory):
+ os.makedirs(directory)
+
+ self.log_file_handler = _FileHandler(log_file, mode="w", encoding="UTF-8")
+ log_file_format = get_option_ini(config, "log_file_format", "log_format")
+ log_file_date_format = get_option_ini(
config, "log_file_date_format", "log_date_format"
)
- self.log_file_formatter = logging.Formatter(
- self.log_file_format, datefmt=self.log_file_date_format
- )
- log_file = get_option_ini(config, "log_file")
- if log_file:
- self.log_file_handler = logging.FileHandler(
- log_file, mode="w", encoding="UTF-8"
- ) # type: Optional[logging.FileHandler]
- self.log_file_handler.setFormatter(self.log_file_formatter)
- else:
- self.log_file_handler = None
-
- self.log_cli_handler = None
-
- self.live_logs_context = lambda: nullcontext()
- # Note that the lambda for the live_logs_context is needed because
- # live_logs_context can otherwise not be entered multiple times due
- # to limitations of contextlib.contextmanager.
+ log_file_formatter = logging.Formatter(
+ log_file_format, datefmt=log_file_date_format
+ )
+ self.log_file_handler.setFormatter(log_file_formatter)
+ # CLI/live logging.
+ self.log_cli_level = get_log_level_for_setting(
+ config, "log_cli_level", "log_level"
+ )
if self._log_cli_enabled():
- self._setup_cli_logging()
+ terminal_reporter = config.pluginmanager.get_plugin("terminalreporter")
+ capture_manager = config.pluginmanager.get_plugin("capturemanager")
+ # if capturemanager plugin is disabled, live logging still works.
+ self.log_cli_handler: Union[
+ _LiveLoggingStreamHandler, _LiveLoggingNullHandler
+ ] = _LiveLoggingStreamHandler(terminal_reporter, capture_manager)
+ else:
+ self.log_cli_handler = _LiveLoggingNullHandler()
+ log_cli_formatter = self._create_formatter(
+ get_option_ini(config, "log_cli_format", "log_format"),
+ get_option_ini(config, "log_cli_date_format", "log_date_format"),
+ get_option_ini(config, "log_auto_indent"),
+ )
+ self.log_cli_handler.setFormatter(log_cli_formatter)
def _create_formatter(self, log_format, log_date_format, auto_indent):
- # color option doesn't exist if terminal plugin is disabled
+ # Color option doesn't exist if terminal plugin is disabled.
color = getattr(self._config.option, "color", "no")
if color != "no" and ColoredLevelFormatter.LEVELNAME_FMT_REGEX.search(
log_format
):
- formatter = ColoredLevelFormatter(
+ formatter: logging.Formatter = ColoredLevelFormatter(
create_terminal_writer(self._config), log_format, log_date_format
- ) # type: logging.Formatter
+ )
else:
formatter = logging.Formatter(log_format, log_date_format)
@@ -551,223 +598,192 @@ class LoggingPlugin:
return formatter
- def _setup_cli_logging(self):
- config = self._config
- terminal_reporter = config.pluginmanager.get_plugin("terminalreporter")
- if terminal_reporter is None:
- # terminal reporter is disabled e.g. by pytest-xdist.
- return
-
- capture_manager = config.pluginmanager.get_plugin("capturemanager")
- # if capturemanager plugin is disabled, live logging still works.
- log_cli_handler = _LiveLoggingStreamHandler(terminal_reporter, capture_manager)
-
- log_cli_formatter = self._create_formatter(
- get_option_ini(config, "log_cli_format", "log_format"),
- get_option_ini(config, "log_cli_date_format", "log_date_format"),
- get_option_ini(config, "log_auto_indent"),
- )
-
- log_cli_level = get_log_level_for_setting(config, "log_cli_level", "log_level")
- self.log_cli_handler = log_cli_handler
- self.live_logs_context = lambda: catching_logs(
- log_cli_handler, formatter=log_cli_formatter, level=log_cli_level
- )
+ def set_log_path(self, fname: str) -> None:
+ """Set the filename parameter for Logging.FileHandler().
- def set_log_path(self, fname):
- """Public method, which can set filename parameter for
- Logging.FileHandler(). Also creates parent directory if
- it does not exist.
+ Creates parent directory if it does not exist.
.. warning::
- Please considered as an experimental API.
+ This is an experimental API.
"""
- fname = Path(fname)
+ fpath = Path(fname)
- if not fname.is_absolute():
- fname = Path(self._config.rootdir, fname)
+ if not fpath.is_absolute():
+ fpath = self._config.rootpath / fpath
- if not fname.parent.exists():
- fname.parent.mkdir(exist_ok=True, parents=True)
+ if not fpath.parent.exists():
+ fpath.parent.mkdir(exist_ok=True, parents=True)
- self.log_file_handler = logging.FileHandler(
- str(fname), mode="w", encoding="UTF-8"
- )
- self.log_file_handler.setFormatter(self.log_file_formatter)
+ stream = fpath.open(mode="w", encoding="UTF-8")
+ if sys.version_info >= (3, 7):
+ old_stream = self.log_file_handler.setStream(stream)
+ else:
+ old_stream = self.log_file_handler.stream
+ self.log_file_handler.acquire()
+ try:
+ self.log_file_handler.flush()
+ self.log_file_handler.stream = stream
+ finally:
+ self.log_file_handler.release()
+ if old_stream:
+ old_stream.close()
def _log_cli_enabled(self):
- """Return True if log_cli should be considered enabled, either explicitly
- or because --log-cli-level was given in the command-line.
- """
- return self._config.getoption(
+ """Return whether live logging is enabled."""
+ enabled = self._config.getoption(
"--log-cli-level"
) is not None or self._config.getini("log_cli")
+ if not enabled:
+ return False
- @pytest.hookimpl(hookwrapper=True, tryfirst=True)
- def pytest_collection(self) -> Generator[None, None, None]:
- with self.live_logs_context():
- if self.log_cli_handler:
- self.log_cli_handler.set_when("collection")
+ terminal_reporter = self._config.pluginmanager.get_plugin("terminalreporter")
+ if terminal_reporter is None:
+ # terminal reporter is disabled e.g. by pytest-xdist.
+ return False
- if self.log_file_handler is not None:
- with catching_logs(self.log_file_handler, level=self.log_file_level):
- yield
- else:
+ return True
+
+ @hookimpl(hookwrapper=True, tryfirst=True)
+ def pytest_sessionstart(self) -> Generator[None, None, None]:
+ self.log_cli_handler.set_when("sessionstart")
+
+ with catching_logs(self.log_cli_handler, level=self.log_cli_level):
+ with catching_logs(self.log_file_handler, level=self.log_file_level):
yield
- @contextmanager
- def _runtest_for(self, item, when):
- with self._runtest_for_main(item, when):
- if self.log_file_handler is not None:
- with catching_logs(self.log_file_handler, level=self.log_file_level):
- yield
- else:
+ @hookimpl(hookwrapper=True, tryfirst=True)
+ def pytest_collection(self) -> Generator[None, None, None]:
+ self.log_cli_handler.set_when("collection")
+
+ with catching_logs(self.log_cli_handler, level=self.log_cli_level):
+ with catching_logs(self.log_file_handler, level=self.log_file_level):
yield
- @contextmanager
- def _runtest_for_main(
- self, item: nodes.Item, when: str
- ) -> Generator[None, None, None]:
- """Implements the internals of pytest_runtest_xxx() hook."""
- with catching_logs(
- LogCaptureHandler(), formatter=self.formatter, level=self.log_level
- ) as log_handler:
- if self.log_cli_handler:
- self.log_cli_handler.set_when(when)
-
- if item is None:
- yield # run the test
- return
-
- if not hasattr(item, "catch_log_handlers"):
- item.catch_log_handlers = {} # type: ignore[attr-defined] # noqa: F821
- item.catch_log_handlers[when] = log_handler # type: ignore[attr-defined] # noqa: F821
- item.catch_log_handler = log_handler # type: ignore[attr-defined] # noqa: F821
- try:
- yield # run test
- finally:
- if when == "teardown":
- del item.catch_log_handler # type: ignore[attr-defined] # noqa: F821
- del item.catch_log_handlers # type: ignore[attr-defined] # noqa: F821
-
- if self.print_logs:
- # Add a captured log section to the report.
- log = log_handler.stream.getvalue().strip()
- item.add_report_section(when, "log", log)
-
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_setup(self, item):
- with self._runtest_for(item, "setup"):
+ @hookimpl(hookwrapper=True)
+ def pytest_runtestloop(self, session: Session) -> Generator[None, None, None]:
+ if session.config.option.collectonly:
yield
+ return
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_call(self, item):
- with self._runtest_for(item, "call"):
- yield
+ if self._log_cli_enabled() and self._config.getoption("verbose") < 1:
+ # The verbose flag is needed to avoid messy test progress output.
+ self._config.option.verbose = 1
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_teardown(self, item):
- with self._runtest_for(item, "teardown"):
- yield
+ with catching_logs(self.log_cli_handler, level=self.log_cli_level):
+ with catching_logs(self.log_file_handler, level=self.log_file_level):
+ yield # Run all the tests.
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_logstart(self):
- if self.log_cli_handler:
- self.log_cli_handler.reset()
- with self._runtest_for(None, "start"):
- yield
+ @hookimpl
+ def pytest_runtest_logstart(self) -> None:
+ self.log_cli_handler.reset()
+ self.log_cli_handler.set_when("start")
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_logfinish(self):
- with self._runtest_for(None, "finish"):
- yield
+ @hookimpl
+ def pytest_runtest_logreport(self) -> None:
+ self.log_cli_handler.set_when("logreport")
+
+ def _runtest_for(self, item: nodes.Item, when: str) -> Generator[None, None, None]:
+ """Implement the internals of the pytest_runtest_xxx() hooks."""
+ with catching_logs(
+ self.caplog_handler, level=self.log_level,
+ ) as caplog_handler, catching_logs(
+ self.report_handler, level=self.log_level,
+ ) as report_handler:
+ caplog_handler.reset()
+ report_handler.reset()
+ item._store[caplog_records_key][when] = caplog_handler.records
+ item._store[caplog_handler_key] = caplog_handler
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtest_logreport(self):
- with self._runtest_for(None, "logreport"):
yield
- @pytest.hookimpl(hookwrapper=True, tryfirst=True)
- def pytest_sessionfinish(self):
- with self.live_logs_context():
- if self.log_cli_handler:
- self.log_cli_handler.set_when("sessionfinish")
- if self.log_file_handler is not None:
- try:
- with catching_logs(
- self.log_file_handler, level=self.log_file_level
- ):
- yield
- finally:
- # Close the FileHandler explicitly.
- # (logging.shutdown might have lost the weakref?!)
- self.log_file_handler.close()
- else:
- yield
+ log = report_handler.stream.getvalue().strip()
+ item.add_report_section(when, "log", log)
- @pytest.hookimpl(hookwrapper=True, tryfirst=True)
- def pytest_sessionstart(self):
- with self.live_logs_context():
- if self.log_cli_handler:
- self.log_cli_handler.set_when("sessionstart")
- if self.log_file_handler is not None:
- with catching_logs(self.log_file_handler, level=self.log_file_level):
- yield
- else:
+ @hookimpl(hookwrapper=True)
+ def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None, None, None]:
+ self.log_cli_handler.set_when("setup")
+
+ empty: Dict[str, List[logging.LogRecord]] = {}
+ item._store[caplog_records_key] = empty
+ yield from self._runtest_for(item, "setup")
+
+ @hookimpl(hookwrapper=True)
+ def pytest_runtest_call(self, item: nodes.Item) -> Generator[None, None, None]:
+ self.log_cli_handler.set_when("call")
+
+ yield from self._runtest_for(item, "call")
+
+ @hookimpl(hookwrapper=True)
+ def pytest_runtest_teardown(self, item: nodes.Item) -> Generator[None, None, None]:
+ self.log_cli_handler.set_when("teardown")
+
+ yield from self._runtest_for(item, "teardown")
+ del item._store[caplog_records_key]
+ del item._store[caplog_handler_key]
+
+ @hookimpl
+ def pytest_runtest_logfinish(self) -> None:
+ self.log_cli_handler.set_when("finish")
+
+ @hookimpl(hookwrapper=True, tryfirst=True)
+ def pytest_sessionfinish(self) -> Generator[None, None, None]:
+ self.log_cli_handler.set_when("sessionfinish")
+
+ with catching_logs(self.log_cli_handler, level=self.log_cli_level):
+ with catching_logs(self.log_file_handler, level=self.log_file_level):
yield
- @pytest.hookimpl(hookwrapper=True)
- def pytest_runtestloop(self, session):
- """Runs all collected test items."""
+ @hookimpl
+ def pytest_unconfigure(self) -> None:
+ # Close the FileHandler explicitly.
+ # (logging.shutdown might have lost the weakref?!)
+ self.log_file_handler.close()
- if session.config.option.collectonly:
- yield
- return
- if self._log_cli_enabled() and self._config.getoption("verbose") < 1:
- # setting verbose flag is needed to avoid messy test progress output
- self._config.option.verbose = 1
+class _FileHandler(logging.FileHandler):
+ """A logging FileHandler with pytest tweaks."""
- with self.live_logs_context():
- if self.log_file_handler is not None:
- with catching_logs(self.log_file_handler, level=self.log_file_level):
- yield # run all the tests
- else:
- yield # run all the tests
+ def handleError(self, record: logging.LogRecord) -> None:
+ # Handled by LogCaptureHandler.
+ pass
class _LiveLoggingStreamHandler(logging.StreamHandler):
- """
- Custom StreamHandler used by the live logging feature: it will write a newline before the first log message
- in each test.
+ """A logging StreamHandler used by the live logging feature: it will
+ write a newline before the first log message in each test.
- During live logging we must also explicitly disable stdout/stderr capturing otherwise it will get captured
- and won't appear in the terminal.
+ During live logging we must also explicitly disable stdout/stderr
+ capturing otherwise it will get captured and won't appear in the
+ terminal.
"""
- def __init__(self, terminal_reporter, capture_manager):
- """
- :param _pytest.terminal.TerminalReporter terminal_reporter:
- :param _pytest.capture.CaptureManager capture_manager:
- """
- logging.StreamHandler.__init__(self, stream=terminal_reporter)
+ # Officially stream needs to be a IO[str], but TerminalReporter
+ # isn't. So force it.
+ stream: TerminalReporter = None # type: ignore
+
+ def __init__(
+ self,
+ terminal_reporter: TerminalReporter,
+ capture_manager: Optional[CaptureManager],
+ ) -> None:
+ logging.StreamHandler.__init__(self, stream=terminal_reporter) # type: ignore[arg-type]
self.capture_manager = capture_manager
self.reset()
self.set_when(None)
self._test_outcome_written = False
- def reset(self):
- """Reset the handler; should be called before the start of each test"""
+ def reset(self) -> None:
+ """Reset the handler; should be called before the start of each test."""
self._first_record_emitted = False
- def set_when(self, when):
- """Prepares for the given test phase (setup/call/teardown)"""
+ def set_when(self, when: Optional[str]) -> None:
+ """Prepare for the given test phase (setup/call/teardown)."""
self._when = when
self._section_name_shown = False
if when == "start":
self._test_outcome_written = False
- def emit(self, record):
+ def emit(self, record: logging.LogRecord) -> None:
ctx_manager = (
self.capture_manager.global_and_fixture_disabled()
if self.capture_manager
@@ -784,4 +800,22 @@ class _LiveLoggingStreamHandler(logging.StreamHandler):
if not self._section_name_shown and self._when:
self.stream.section("live log " + self._when, sep="-", bold=True)
self._section_name_shown = True
- logging.StreamHandler.emit(self, record)
+ super().emit(record)
+
+ def handleError(self, record: logging.LogRecord) -> None:
+ # Handled by LogCaptureHandler.
+ pass
+
+
+class _LiveLoggingNullHandler(logging.NullHandler):
+ """A logging handler used when live logging is disabled."""
+
+ def reset(self) -> None:
+ pass
+
+ def set_when(self, when: str) -> None:
+ pass
+
+ def handleError(self, record: logging.LogRecord) -> None:
+ # Handled by LogCaptureHandler.
+ pass
diff --git a/contrib/python/pytest/py3/_pytest/main.py b/contrib/python/pytest/py3/_pytest/main.py
index 61eb7ca74c..41a33d4494 100644
--- a/contrib/python/pytest/py3/_pytest/main.py
+++ b/contrib/python/pytest/py3/_pytest/main.py
@@ -1,16 +1,23 @@
-""" core implementation of testing process: init, session, runtest loop. """
+"""Core implementation of the testing process: init, session, runtest loop."""
+import argparse
import fnmatch
import functools
import importlib
import os
import sys
+from pathlib import Path
from typing import Callable
from typing import Dict
from typing import FrozenSet
+from typing import Iterator
from typing import List
from typing import Optional
+from typing import overload
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import Union
import attr
@@ -18,32 +25,45 @@ import py
import _pytest._code
from _pytest import nodes
-from _pytest.compat import TYPE_CHECKING
+from _pytest.compat import final
from _pytest.config import Config
from _pytest.config import directory_arg
from _pytest.config import ExitCode
from _pytest.config import hookimpl
+from _pytest.config import PytestPluginManager
from _pytest.config import UsageError
+from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureManager
from _pytest.outcomes import exit
+from _pytest.pathlib import absolutepath
+from _pytest.pathlib import bestrelpath
+from _pytest.pathlib import visit
from _pytest.reports import CollectReport
+from _pytest.reports import TestReport
from _pytest.runner import collect_one_node
from _pytest.runner import SetupState
if TYPE_CHECKING:
- from typing import Type
from typing_extensions import Literal
- from _pytest.python import Package
-
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
parser.addini(
"norecursedirs",
"directory patterns to avoid for recursion",
type="args",
- default=[".*", "build", "dist", "CVS", "_darcs", "{arch}", "*.egg", "venv"],
+ default=[
+ "*.egg",
+ ".*",
+ "_darcs",
+ "build",
+ "CVS",
+ "dist",
+ "node_modules",
+ "venv",
+ "{arch}",
+ ],
)
parser.addini(
"testpaths",
@@ -61,6 +81,20 @@ def pytest_addoption(parser):
const=1,
help="exit instantly on first error or failed test.",
)
+ group = parser.getgroup("pytest-warnings")
+ group.addoption(
+ "-W",
+ "--pythonwarnings",
+ action="append",
+ help="set which warnings to report, see -W option of python itself.",
+ )
+ parser.addini(
+ "filterwarnings",
+ type="linelist",
+ help="Each line specifies a pattern for "
+ "warnings.filterwarnings. "
+ "Processed after -W/--pythonwarnings.",
+ )
group._addoption(
"--maxfail",
metavar="num",
@@ -71,12 +105,19 @@ def pytest_addoption(parser):
help="exit after first num failures or errors.",
)
group._addoption(
+ "--strict-config",
+ action="store_true",
+ help="any warnings encountered while parsing the `pytest` section of the configuration file raise errors.",
+ )
+ group._addoption(
"--strict-markers",
- "--strict",
action="store_true",
help="markers not registered in the `markers` section of the configuration file raise errors.",
)
group._addoption(
+ "--strict", action="store_true", help="(deprecated) alias to --strict-markers.",
+ )
+ group._addoption(
"-c",
metavar="file",
type=str,
@@ -161,12 +202,21 @@ def pytest_addoption(parser):
default=False,
help="Don't ignore tests in a local virtualenv directory",
)
+ group.addoption(
+ "--import-mode",
+ default="prepend",
+ choices=["prepend", "append", "importlib"],
+ dest="importmode",
+ help="prepend/append to sys.path when importing test modules and conftest files, "
+ "default is to prepend.",
+ )
group = parser.getgroup("debugconfig", "test session debugging and configuration")
group.addoption(
"--basetemp",
dest="basetemp",
default=None,
+ type=validate_basetemp,
metavar="dir",
help=(
"base temporary directory for this test run."
@@ -175,10 +225,38 @@ def pytest_addoption(parser):
)
+def validate_basetemp(path: str) -> str:
+ # GH 7119
+ msg = "basetemp must not be empty, the current working directory or any parent directory of it"
+
+ # empty path
+ if not path:
+ raise argparse.ArgumentTypeError(msg)
+
+ def is_ancestor(base: Path, query: Path) -> bool:
+ """Return whether query is an ancestor of base."""
+ if base == query:
+ return True
+ for parent in base.parents:
+ if parent == query:
+ return True
+ return False
+
+ # check if path is an ancestor of cwd
+ if is_ancestor(Path.cwd(), Path(path).absolute()):
+ raise argparse.ArgumentTypeError(msg)
+
+ # check symlinks for ancestors
+ if is_ancestor(Path.cwd().resolve(), Path(path).resolve()):
+ raise argparse.ArgumentTypeError(msg)
+
+ return path
+
+
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
- """Skeleton command line program"""
+ """Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
@@ -196,17 +274,15 @@ def wrap_session(
session.exitstatus = ExitCode.TESTS_FAILED
except (KeyboardInterrupt, exit.Exception):
excinfo = _pytest._code.ExceptionInfo.from_current()
- exitstatus = ExitCode.INTERRUPTED # type: Union[int, ExitCode]
+ exitstatus: Union[int, ExitCode] = ExitCode.INTERRUPTED
if isinstance(excinfo.value, exit.Exception):
if excinfo.value.returncode is not None:
exitstatus = excinfo.value.returncode
if initstate < 2:
- sys.stderr.write(
- "{}: {}\n".format(excinfo.typename, excinfo.value.msg)
- )
+ sys.stderr.write(f"{excinfo.typename}: {excinfo.value.msg}\n")
config.hook.pytest_keyboard_interrupt(excinfo=excinfo)
session.exitstatus = exitstatus
- except: # noqa
+ except BaseException:
session.exitstatus = ExitCode.INTERNAL_ERROR
excinfo = _pytest._code.ExceptionInfo.from_current()
try:
@@ -216,7 +292,7 @@ def wrap_session(
session.exitstatus = exc.returncode
sys.stderr.write("{}: {}\n".format(type(exc).__name__, exc))
else:
- if excinfo.errisinstance(SystemExit):
+ if isinstance(excinfo.value, SystemExit):
sys.stderr.write("mainloop: caught unexpected SystemExit!\n")
finally:
@@ -236,13 +312,13 @@ def wrap_session(
return session.exitstatus
-def pytest_cmdline_main(config):
+def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
return wrap_session(config, _main)
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
- """ default command line protocol for initialization, session,
- running tests and reporting. """
+ """Default command line protocol for initialization, session,
+ running tests and reporting."""
config.hook.pytest_collection(session=session)
config.hook.pytest_runtestloop(session=session)
@@ -253,11 +329,11 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
return None
-def pytest_collection(session):
- return session.perform_collect()
+def pytest_collection(session: "Session") -> None:
+ session.perform_collect()
-def pytest_runtestloop(session):
+def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
@@ -277,9 +353,9 @@ def pytest_runtestloop(session):
return True
-def _in_venv(path):
- """Attempts to detect if ``path`` is the root of a Virtual Environment by
- checking for the existence of the appropriate activate script"""
+def _in_venv(path: py.path.local) -> bool:
+ """Attempt to detect if ``path`` is the root of a Virtual Environment by
+ checking for the existence of the appropriate activate script."""
bindir = path.join("Scripts" if sys.platform.startswith("win") else "bin")
if not bindir.isdir():
return False
@@ -294,9 +370,7 @@ def _in_venv(path):
return any([fname.basename in activates for fname in bindir.listdir()])
-def pytest_ignore_collect(
- path: py.path.local, config: Config
-) -> "Optional[Literal[True]]":
+def pytest_ignore_collect(path: py.path.local, config: Config) -> Optional[bool]:
ignore_paths = config._getconftest_pathlist("collect_ignore", path=path.dirpath())
ignore_paths = ignore_paths or []
excludeopt = config.getoption("ignore")
@@ -323,7 +397,7 @@ def pytest_ignore_collect(
return None
-def pytest_collection_modifyitems(items, config):
+def pytest_collection_modifyitems(items: List[nodes.Item], config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or [])
if not deselect_prefixes:
return
@@ -341,76 +415,69 @@ def pytest_collection_modifyitems(items, config):
items[:] = remaining
-class NoMatch(Exception):
- """ raised if matching cannot locate a matching names. """
+class FSHookProxy:
+ def __init__(self, pm: PytestPluginManager, remove_mods) -> None:
+ self.pm = pm
+ self.remove_mods = remove_mods
+
+ def __getattr__(self, name: str):
+ x = self.pm.subset_hook_caller(name, remove_plugins=self.remove_mods)
+ self.__dict__[name] = x
+ return x
class Interrupted(KeyboardInterrupt):
- """ signals an interrupted test run. """
+ """Signals that the test run was interrupted."""
- __module__ = "builtins" # for py3
+ __module__ = "builtins" # For py3.
class Failed(Exception):
- """ signals a stop as failed test run. """
+ """Signals a stop as failed test run."""
@attr.s
-class _bestrelpath_cache(dict):
- path = attr.ib(type=py.path.local)
+class _bestrelpath_cache(Dict[Path, str]):
+ path = attr.ib(type=Path)
- def __missing__(self, path: py.path.local) -> str:
- r = self.path.bestrelpath(path) # type: str
+ def __missing__(self, path: Path) -> str:
+ r = bestrelpath(self.path, path)
self[path] = r
return r
+@final
class Session(nodes.FSCollector):
Interrupted = Interrupted
Failed = Failed
# Set on the session by runner.pytest_sessionstart.
- _setupstate = None # type: SetupState
+ _setupstate: SetupState
# Set on the session by fixtures.pytest_sessionstart.
- _fixturemanager = None # type: FixtureManager
- exitstatus = None # type: Union[int, ExitCode]
+ _fixturemanager: FixtureManager
+ exitstatus: Union[int, ExitCode]
def __init__(self, config: Config) -> None:
- nodes.FSCollector.__init__(
- self, config.rootdir, parent=None, config=config, session=self, nodeid=""
+ super().__init__(
+ config.rootdir, parent=None, config=config, session=self, nodeid=""
)
self.testsfailed = 0
self.testscollected = 0
- self.shouldstop = False
- self.shouldfail = False
+ self.shouldstop: Union[bool, str] = False
+ self.shouldfail: Union[bool, str] = False
self.trace = config.trace.root.get("collection")
self.startdir = config.invocation_dir
- self._initialpaths = frozenset() # type: FrozenSet[py.path.local]
-
- # Keep track of any collected nodes in here, so we don't duplicate fixtures
- self._collection_node_cache1 = (
- {}
- ) # type: Dict[py.path.local, Sequence[nodes.Collector]]
- self._collection_node_cache2 = (
- {}
- ) # type: Dict[Tuple[Type[nodes.Collector], py.path.local], nodes.Collector]
- self._collection_node_cache3 = (
- {}
- ) # type: Dict[Tuple[Type[nodes.Collector], str], CollectReport]
-
- # Dirnames of pkgs with dunder-init files.
- self._collection_pkg_roots = {} # type: Dict[py.path.local, Package]
+ self._initialpaths: FrozenSet[py.path.local] = frozenset()
- self._bestrelpathcache = _bestrelpath_cache(
- config.rootdir
- ) # type: Dict[py.path.local, str]
+ self._bestrelpathcache: Dict[Path, str] = _bestrelpath_cache(config.rootpath)
self.config.pluginmanager.register(self, name="session")
@classmethod
- def from_config(cls, config):
- return cls._create(config)
+ def from_config(cls, config: Config) -> "Session":
+ session: Session = cls._create(config)
+ return session
- def __repr__(self):
+ def __repr__(self) -> str:
return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % (
self.__class__.__name__,
self.name,
@@ -419,19 +486,21 @@ class Session(nodes.FSCollector):
self.testscollected,
)
- def _node_location_to_relpath(self, node_path: py.path.local) -> str:
- # bestrelpath is a quite slow function
+ def _node_location_to_relpath(self, node_path: Path) -> str:
+ # bestrelpath is a quite slow function.
return self._bestrelpathcache[node_path]
@hookimpl(tryfirst=True)
- def pytest_collectstart(self):
+ def pytest_collectstart(self) -> None:
if self.shouldfail:
raise self.Failed(self.shouldfail)
if self.shouldstop:
raise self.Interrupted(self.shouldstop)
@hookimpl(tryfirst=True)
- def pytest_runtest_logreport(self, report):
+ def pytest_runtest_logreport(
+ self, report: Union[TestReport, CollectReport]
+ ) -> None:
if report.failed and not hasattr(report, "wasxfail"):
self.testsfailed += 1
maxfail = self.config.getvalue("maxfail")
@@ -440,238 +509,296 @@ class Session(nodes.FSCollector):
pytest_collectreport = pytest_runtest_logreport
- def isinitpath(self, path):
+ def isinitpath(self, path: py.path.local) -> bool:
return path in self._initialpaths
def gethookproxy(self, fspath: py.path.local):
- return super()._gethookproxy(fspath)
+ # Check if we have the common case of running
+ # hooks with all conftest.py files.
+ pm = self.config.pluginmanager
+ my_conftestmodules = pm._getconftestmodules(
+ fspath, self.config.getoption("importmode")
+ )
+ remove_mods = pm._conftest_plugins.difference(my_conftestmodules)
+ if remove_mods:
+ # One or more conftests are not in use at this fspath.
+ proxy = FSHookProxy(pm, remove_mods)
+ else:
+ # All plugins are active for this fspath.
+ proxy = self.config.hook
+ return proxy
+
+ def _recurse(self, direntry: "os.DirEntry[str]") -> bool:
+ if direntry.name == "__pycache__":
+ return False
+ path = py.path.local(direntry.path)
+ ihook = self.gethookproxy(path.dirpath())
+ if ihook.pytest_ignore_collect(path=path, config=self.config):
+ return False
+ norecursepatterns = self.config.getini("norecursedirs")
+ if any(path.check(fnmatch=pat) for pat in norecursepatterns):
+ return False
+ return True
+
+ def _collectfile(
+ self, path: py.path.local, handle_dupes: bool = True
+ ) -> Sequence[nodes.Collector]:
+ assert (
+ path.isfile()
+ ), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format(
+ path, path.isdir(), path.exists(), path.islink()
+ )
+ ihook = self.gethookproxy(path)
+ if not self.isinitpath(path):
+ if ihook.pytest_ignore_collect(path=path, config=self.config):
+ return ()
+
+ if handle_dupes:
+ keepduplicates = self.config.getoption("keepduplicates")
+ if not keepduplicates:
+ duplicate_paths = self.config.pluginmanager._duplicatepaths
+ if path in duplicate_paths:
+ return ()
+ else:
+ duplicate_paths.add(path)
+
+ return ihook.pytest_collect_file(path=path, parent=self) # type: ignore[no-any-return]
+
+ @overload
+ def perform_collect(
+ self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ...
+ ) -> Sequence[nodes.Item]:
+ ...
+
+ @overload
+ def perform_collect(
+ self, args: Optional[Sequence[str]] = ..., genitems: bool = ...
+ ) -> Sequence[Union[nodes.Item, nodes.Collector]]:
+ ...
+
+ def perform_collect(
+ self, args: Optional[Sequence[str]] = None, genitems: bool = True
+ ) -> Sequence[Union[nodes.Item, nodes.Collector]]:
+ """Perform the collection phase for this session.
+
+ This is called by the default
+ :func:`pytest_collection <_pytest.hookspec.pytest_collection>` hook
+ implementation; see the documentation of this hook for more details.
+ For testing purposes, it may also be called directly on a fresh
+ ``Session``.
+
+ This function normally recursively expands any collectors collected
+ from the session to their items, and only items are returned. For
+ testing purposes, this may be suppressed by passing ``genitems=False``,
+ in which case the return value contains these collectors unexpanded,
+ and ``session.items`` is empty.
+ """
+ if args is None:
+ args = self.config.args
+
+ self.trace("perform_collect", self, args)
+ self.trace.root.indent += 1
+
+ self._notfound: List[Tuple[str, Sequence[nodes.Collector]]] = []
+ self._initial_parts: List[Tuple[py.path.local, List[str]]] = []
+ self.items: List[nodes.Item] = []
- def perform_collect(self, args=None, genitems=True):
hook = self.config.hook
+
+ items: Sequence[Union[nodes.Item, nodes.Collector]] = self.items
try:
- items = self._perform_collect(args, genitems)
+ initialpaths: List[py.path.local] = []
+ for arg in args:
+ fspath, parts = resolve_collection_argument(
+ self.config.invocation_params.dir,
+ arg,
+ as_pypath=self.config.option.pyargs,
+ )
+ self._initial_parts.append((fspath, parts))
+ initialpaths.append(fspath)
+ self._initialpaths = frozenset(initialpaths)
+ rep = collect_one_node(self)
+ self.ihook.pytest_collectreport(report=rep)
+ self.trace.root.indent -= 1
+ if self._notfound:
+ errors = []
+ for arg, cols in self._notfound:
+ line = f"(no name {arg!r} in any of {cols!r})"
+ errors.append(f"not found: {arg}\n{line}")
+ raise UsageError(*errors)
+ if not genitems:
+ items = rep.result
+ else:
+ if rep.passed:
+ for node in rep.result:
+ self.items.extend(self.genitems(node))
+
self.config.pluginmanager.check_pending()
hook.pytest_collection_modifyitems(
session=self, config=self.config, items=items
)
finally:
hook.pytest_collection_finish(session=self)
+
self.testscollected = len(items)
return items
- def _perform_collect(self, args, genitems):
- if args is None:
- args = self.config.args
- self.trace("perform_collect", self, args)
- self.trace.root.indent += 1
- self._notfound = []
- initialpaths = [] # type: List[py.path.local]
- self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]]
- self.items = items = []
- for arg in args:
- fspath, parts = self._parsearg(arg)
- self._initial_parts.append((fspath, parts))
- initialpaths.append(fspath)
- self._initialpaths = frozenset(initialpaths)
- rep = collect_one_node(self)
- self.ihook.pytest_collectreport(report=rep)
- self.trace.root.indent -= 1
- if self._notfound:
- errors = []
- for arg, exc in self._notfound:
- line = "(no name {!r} in any of {!r})".format(arg, exc.args[0])
- errors.append("not found: {}\n{}".format(arg, line))
- raise UsageError(*errors)
- if not genitems:
- return rep.result
- else:
- if rep.passed:
- for node in rep.result:
- self.items.extend(self.genitems(node))
- return items
+ def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]:
+ from _pytest.python import Package
- def collect(self):
- for fspath, parts in self._initial_parts:
- self.trace("processing argument", (fspath, parts))
- self.trace.root.indent += 1
- try:
- yield from self._collect(fspath, parts)
- except NoMatch as exc:
- report_arg = "::".join((str(fspath), *parts))
- # we are inside a make_report hook so
- # we cannot directly pass through the exception
- self._notfound.append((report_arg, exc))
+ # Keep track of any collected nodes in here, so we don't duplicate fixtures.
+ node_cache1: Dict[py.path.local, Sequence[nodes.Collector]] = {}
+ node_cache2: Dict[
+ Tuple[Type[nodes.Collector], py.path.local], nodes.Collector
+ ] = ({})
- self.trace.root.indent -= 1
- self._collection_node_cache1.clear()
- self._collection_node_cache2.clear()
- self._collection_node_cache3.clear()
- self._collection_pkg_roots.clear()
+ # Keep track of any collected collectors in matchnodes paths, so they
+ # are not collected more than once.
+ matchnodes_cache: Dict[Tuple[Type[nodes.Collector], str], CollectReport] = ({})
- def _collect(self, argpath, names):
- from _pytest.python import Package
+ # Dirnames of pkgs with dunder-init files.
+ pkg_roots: Dict[str, Package] = {}
+
+ for argpath, names in self._initial_parts:
+ self.trace("processing argument", (argpath, names))
+ self.trace.root.indent += 1
- # Start with a Session root, and delve to argpath item (dir or file)
- # and stack all Packages found on the way.
- # No point in finding packages when collecting doctests
- if not self.config.getoption("doctestmodules", False):
- pm = self.config.pluginmanager
- for parent in reversed(argpath.parts()):
- if pm._confcutdir and pm._confcutdir.relto(parent):
- break
-
- if parent.isdir():
- pkginit = parent.join("__init__.py")
- if pkginit.isfile():
- if pkginit not in self._collection_node_cache1:
+ # Start with a Session root, and delve to argpath item (dir or file)
+ # and stack all Packages found on the way.
+ # No point in finding packages when collecting doctests.
+ if not self.config.getoption("doctestmodules", False):
+ pm = self.config.pluginmanager
+ for parent in reversed(argpath.parts()):
+ if pm._confcutdir and pm._confcutdir.relto(parent):
+ break
+
+ if parent.isdir():
+ pkginit = parent.join("__init__.py")
+ if pkginit.isfile() and pkginit not in node_cache1:
col = self._collectfile(pkginit, handle_dupes=False)
if col:
if isinstance(col[0], Package):
- self._collection_pkg_roots[parent] = col[0]
- # always store a list in the cache, matchnodes expects it
- self._collection_node_cache1[col[0].fspath] = [col[0]]
-
- # If it's a directory argument, recurse and look for any Subpackages.
- # Let the Package collector deal with subnodes, don't collect here.
- if argpath.check(dir=1):
- assert not names, "invalid arg {!r}".format((argpath, names))
-
- seen_dirs = set()
- for path in argpath.visit(
- fil=self._visit_filter, rec=self._recurse, bf=True, sort=True
- ):
- dirpath = path.dirpath()
- if dirpath not in seen_dirs:
- # Collect packages first.
- seen_dirs.add(dirpath)
- pkginit = dirpath.join("__init__.py")
- if pkginit.exists():
- for x in self._collectfile(pkginit):
+ pkg_roots[str(parent)] = col[0]
+ node_cache1[col[0].fspath] = [col[0]]
+
+ # If it's a directory argument, recurse and look for any Subpackages.
+ # Let the Package collector deal with subnodes, don't collect here.
+ if argpath.check(dir=1):
+ assert not names, "invalid arg {!r}".format((argpath, names))
+
+ seen_dirs: Set[py.path.local] = set()
+ for direntry in visit(str(argpath), self._recurse):
+ if not direntry.is_file():
+ continue
+
+ path = py.path.local(direntry.path)
+ dirpath = path.dirpath()
+
+ if dirpath not in seen_dirs:
+ # Collect packages first.
+ seen_dirs.add(dirpath)
+ pkginit = dirpath.join("__init__.py")
+ if pkginit.exists():
+ for x in self._collectfile(pkginit):
+ yield x
+ if isinstance(x, Package):
+ pkg_roots[str(dirpath)] = x
+ if str(dirpath) in pkg_roots:
+ # Do not collect packages here.
+ continue
+
+ for x in self._collectfile(path):
+ key = (type(x), x.fspath)
+ if key in node_cache2:
+ yield node_cache2[key]
+ else:
+ node_cache2[key] = x
yield x
- if isinstance(x, Package):
- self._collection_pkg_roots[dirpath] = x
- if dirpath in self._collection_pkg_roots:
- # Do not collect packages here.
+ else:
+ assert argpath.check(file=1)
+
+ if argpath in node_cache1:
+ col = node_cache1[argpath]
+ else:
+ collect_root = pkg_roots.get(argpath.dirname, self)
+ col = collect_root._collectfile(argpath, handle_dupes=False)
+ if col:
+ node_cache1[argpath] = col
+
+ matching = []
+ work: List[
+ Tuple[Sequence[Union[nodes.Item, nodes.Collector]], Sequence[str]]
+ ] = [(col, names)]
+ while work:
+ self.trace("matchnodes", col, names)
+ self.trace.root.indent += 1
+
+ matchnodes, matchnames = work.pop()
+ for node in matchnodes:
+ if not matchnames:
+ matching.append(node)
+ continue
+ if not isinstance(node, nodes.Collector):
+ continue
+ key = (type(node), node.nodeid)
+ if key in matchnodes_cache:
+ rep = matchnodes_cache[key]
+ else:
+ rep = collect_one_node(node)
+ matchnodes_cache[key] = rep
+ if rep.passed:
+ submatchnodes = []
+ for r in rep.result:
+ # TODO: Remove parametrized workaround once collection structure contains
+ # parametrization.
+ if (
+ r.name == matchnames[0]
+ or r.name.split("[")[0] == matchnames[0]
+ ):
+ submatchnodes.append(r)
+ if submatchnodes:
+ work.append((submatchnodes, matchnames[1:]))
+ # XXX Accept IDs that don't have "()" for class instances.
+ elif len(rep.result) == 1 and rep.result[0].name == "()":
+ work.append((rep.result, matchnames))
+ else:
+ # Report collection failures here to avoid failing to run some test
+ # specified in the command line because the module could not be
+ # imported (#134).
+ node.ihook.pytest_collectreport(report=rep)
+
+ self.trace("matchnodes finished -> ", len(matching), "nodes")
+ self.trace.root.indent -= 1
+
+ if not matching:
+ report_arg = "::".join((str(argpath), *names))
+ self._notfound.append((report_arg, col))
continue
- for x in self._collectfile(path):
- key = (type(x), x.fspath)
- if key in self._collection_node_cache2:
- yield self._collection_node_cache2[key]
- else:
- self._collection_node_cache2[key] = x
- yield x
- else:
- assert argpath.check(file=1)
+ # If __init__.py was the only file requested, then the matched
+ # node will be the corresponding Package (by default), and the
+ # first yielded item will be the __init__ Module itself, so
+ # just use that. If this special case isn't taken, then all the
+ # files in the package will be yielded.
+ if argpath.basename == "__init__.py" and isinstance(
+ matching[0], Package
+ ):
+ try:
+ yield next(iter(matching[0].collect()))
+ except StopIteration:
+ # The package collects nothing with only an __init__.py
+ # file in it, which gets ignored by the default
+ # "python_files" option.
+ pass
+ continue
- if argpath in self._collection_node_cache1:
- col = self._collection_node_cache1[argpath]
- else:
- collect_root = self._collection_pkg_roots.get(argpath.dirname, self)
- col = collect_root._collectfile(argpath, handle_dupes=False)
- if col:
- self._collection_node_cache1[argpath] = col
- m = self.matchnodes(col, names)
- # If __init__.py was the only file requested, then the matched node will be
- # the corresponding Package, and the first yielded item will be the __init__
- # Module itself, so just use that. If this special case isn't taken, then all
- # the files in the package will be yielded.
- if argpath.basename == "__init__.py":
- try:
- yield next(m[0].collect())
- except StopIteration:
- # The package collects nothing with only an __init__.py
- # file in it, which gets ignored by the default
- # "python_files" option.
- pass
- return
- yield from m
-
- @staticmethod
- def _visit_filter(f):
- return f.check(file=1)
-
- def _tryconvertpyarg(self, x):
- """Convert a dotted module name to path."""
- try:
- spec = importlib.util.find_spec(x)
- # AttributeError: looks like package module, but actually filename
- # ImportError: module does not exist
- # ValueError: not a module name
- except (AttributeError, ImportError, ValueError):
- return x
- if spec is None or spec.origin in {None, "namespace"}:
- return x
- elif spec.submodule_search_locations:
- return os.path.dirname(spec.origin)
- else:
- return spec.origin
-
- def _parsearg(self, arg):
- """ return (fspath, names) tuple after checking the file exists. """
- strpath, *parts = str(arg).split("::")
- if self.config.option.pyargs:
- strpath = self._tryconvertpyarg(strpath)
- relpath = strpath.replace("/", os.sep)
- fspath = self.config.invocation_dir.join(relpath, abs=True)
- if not fspath.check():
- if self.config.option.pyargs:
- raise UsageError(
- "file or package not found: " + arg + " (missing __init__.py?)"
- )
- raise UsageError("file not found: " + arg)
- fspath = fspath.realpath()
- return (fspath, parts)
+ yield from matching
- def matchnodes(self, matching, names):
- self.trace("matchnodes", matching, names)
- self.trace.root.indent += 1
- nodes = self._matchnodes(matching, names)
- num = len(nodes)
- self.trace("matchnodes finished -> ", num, "nodes")
- self.trace.root.indent -= 1
- if num == 0:
- raise NoMatch(matching, names[:1])
- return nodes
-
- def _matchnodes(self, matching, names):
- if not matching or not names:
- return matching
- name = names[0]
- assert name
- nextnames = names[1:]
- resultnodes = []
- for node in matching:
- if isinstance(node, nodes.Item):
- if not names:
- resultnodes.append(node)
- continue
- assert isinstance(node, nodes.Collector)
- key = (type(node), node.nodeid)
- if key in self._collection_node_cache3:
- rep = self._collection_node_cache3[key]
- else:
- rep = collect_one_node(node)
- self._collection_node_cache3[key] = rep
- if rep.passed:
- has_matched = False
- for x in rep.result:
- # TODO: remove parametrized workaround once collection structure contains parametrization
- if x.name == name or x.name.split("[")[0] == name:
- resultnodes.extend(self.matchnodes([x], nextnames))
- has_matched = True
- # XXX accept IDs that don't have "()" for class instances
- if not has_matched and len(rep.result) == 1 and x.name == "()":
- nextnames.insert(0, name)
- resultnodes.extend(self.matchnodes([x], nextnames))
- else:
- # report collection failures here to avoid failing to run some test
- # specified in the command line because the module could not be
- # imported (#134)
- node.ihook.pytest_collectreport(report=rep)
- return resultnodes
+ self.trace.root.indent -= 1
- def genitems(self, node):
+ def genitems(
+ self, node: Union[nodes.Item, nodes.Collector]
+ ) -> Iterator[nodes.Item]:
self.trace("genitems", node)
if isinstance(node, nodes.Item):
node.ihook.pytest_itemcollected(item=node)
@@ -683,3 +810,67 @@ class Session(nodes.FSCollector):
for subnode in rep.result:
yield from self.genitems(subnode)
node.ihook.pytest_collectreport(report=rep)
+
+
+def search_pypath(module_name: str) -> str:
+ """Search sys.path for the given a dotted module name, and return its file system path."""
+ try:
+ spec = importlib.util.find_spec(module_name)
+ # AttributeError: looks like package module, but actually filename
+ # ImportError: module does not exist
+ # ValueError: not a module name
+ except (AttributeError, ImportError, ValueError):
+ return module_name
+ if spec is None or spec.origin is None or spec.origin == "namespace":
+ return module_name
+ elif spec.submodule_search_locations:
+ return os.path.dirname(spec.origin)
+ else:
+ return spec.origin
+
+
+def resolve_collection_argument(
+ invocation_path: Path, arg: str, *, as_pypath: bool = False
+) -> Tuple[py.path.local, List[str]]:
+ """Parse path arguments optionally containing selection parts and return (fspath, names).
+
+ Command-line arguments can point to files and/or directories, and optionally contain
+ parts for specific tests selection, for example:
+
+ "pkg/tests/test_foo.py::TestClass::test_foo"
+
+ This function ensures the path exists, and returns a tuple:
+
+ (py.path.path("/full/path/to/pkg/tests/test_foo.py"), ["TestClass", "test_foo"])
+
+ When as_pypath is True, expects that the command-line argument actually contains
+ module paths instead of file-system paths:
+
+ "pkg.tests.test_foo::TestClass::test_foo"
+
+ In which case we search sys.path for a matching module, and then return the *path* to the
+ found module.
+
+ If the path doesn't exist, raise UsageError.
+ If the path is a directory and selection parts are present, raise UsageError.
+ """
+ strpath, *parts = str(arg).split("::")
+ if as_pypath:
+ strpath = search_pypath(strpath)
+ fspath = invocation_path / strpath
+ fspath = absolutepath(fspath)
+ if not fspath.exists():
+ msg = (
+ "module or package not found: {arg} (missing __init__.py?)"
+ if as_pypath
+ else "file or directory not found: {arg}"
+ )
+ raise UsageError(msg.format(arg=arg))
+ if parts and fspath.is_dir():
+ msg = (
+ "package argument cannot contain :: selection parts: {arg}"
+ if as_pypath
+ else "directory argument cannot contain :: selection parts: {arg}"
+ )
+ raise UsageError(msg.format(arg=arg))
+ return py.path.local(str(fspath)), parts
diff --git a/contrib/python/pytest/py3/_pytest/mark/__init__.py b/contrib/python/pytest/py3/_pytest/mark/__init__.py
index dab0cf149f..329a11c4ae 100644
--- a/contrib/python/pytest/py3/_pytest/mark/__init__.py
+++ b/contrib/python/pytest/py3/_pytest/mark/__init__.py
@@ -1,8 +1,16 @@
-""" generic mechanism for marking and selecting python functions. """
+"""Generic mechanism for marking and selecting python functions."""
+import warnings
+from typing import AbstractSet
+from typing import Collection
+from typing import List
from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
-from .legacy import matchkeyword
-from .legacy import matchmark
+import attr
+
+from .expression import Expression
+from .expression import ParseError
from .structures import EMPTY_PARAMETERSET_OPTION
from .structures import get_empty_parameterset_mark
from .structures import Mark
@@ -11,37 +19,56 @@ from .structures import MarkDecorator
from .structures import MarkGenerator
from .structures import ParameterSet
from _pytest.config import Config
+from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
+from _pytest.config.argparsing import Parser
+from _pytest.deprecated import MINUS_K_COLON
+from _pytest.deprecated import MINUS_K_DASH
from _pytest.store import StoreKey
-__all__ = ["Mark", "MarkDecorator", "MarkGenerator", "get_empty_parameterset_mark"]
+if TYPE_CHECKING:
+ from _pytest.nodes import Item
+
+
+__all__ = [
+ "MARK_GEN",
+ "Mark",
+ "MarkDecorator",
+ "MarkGenerator",
+ "ParameterSet",
+ "get_empty_parameterset_mark",
+]
old_mark_config_key = StoreKey[Optional[Config]]()
-def param(*values, **kw):
+def param(
+ *values: object,
+ marks: Union[MarkDecorator, Collection[Union[MarkDecorator, Mark]]] = (),
+ id: Optional[str] = None,
+) -> ParameterSet:
"""Specify a parameter in `pytest.mark.parametrize`_ calls or
:ref:`parametrized fixtures <fixture-parametrize-marks>`.
.. code-block:: python
- @pytest.mark.parametrize("test_input,expected", [
- ("3+5", 8),
- pytest.param("6*9", 42, marks=pytest.mark.xfail),
- ])
+ @pytest.mark.parametrize(
+ "test_input,expected",
+ [("3+5", 8), pytest.param("6*9", 42, marks=pytest.mark.xfail),],
+ )
def test_eval(test_input, expected):
assert eval(test_input) == expected
- :param values: variable args of the values of the parameter set, in order.
- :keyword marks: a single mark or a list of marks to be applied to this parameter set.
- :keyword str id: the id to attribute to this parameter set.
+ :param values: Variable args of the values of the parameter set, in order.
+ :keyword marks: A single mark or a list of marks to be applied to this parameter set.
+ :keyword str id: The id to attribute to this parameter set.
"""
- return ParameterSet.param(*values, **kw)
+ return ParameterSet.param(*values, marks=marks, id=id)
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group._addoption(
"-k",
@@ -69,8 +96,8 @@ def pytest_addoption(parser):
dest="markexpr",
default="",
metavar="MARKEXPR",
- help="only run tests matching given mark expression. "
- "example: -m 'mark1 and not mark2'.",
+ help="only run tests matching given mark expression.\n"
+ "For example: -m 'mark1 and not mark2'.",
)
group.addoption(
@@ -84,7 +111,7 @@ def pytest_addoption(parser):
@hookimpl(tryfirst=True)
-def pytest_cmdline_main(config):
+def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
import _pytest.config
if config.option.markers:
@@ -100,23 +127,87 @@ def pytest_cmdline_main(config):
config._ensure_unconfigure()
return 0
+ return None
+
+
+@attr.s(slots=True)
+class KeywordMatcher:
+ """A matcher for keywords.
+
+ Given a list of names, matches any substring of one of these names. The
+ string inclusion check is case-insensitive.
+
+ Will match on the name of colitem, including the names of its parents.
+ Only matches names of items which are either a :class:`Class` or a
+ :class:`Function`.
+
+ Additionally, matches on names in the 'extra_keyword_matches' set of
+ any item, as well as names directly assigned to test functions.
+ """
+
+ _names = attr.ib(type=AbstractSet[str])
+
+ @classmethod
+ def from_item(cls, item: "Item") -> "KeywordMatcher":
+ mapped_names = set()
+
+ # Add the names of the current item and any parent items.
+ import pytest
+
+ for node in item.listchain():
+ if not isinstance(node, (pytest.Instance, pytest.Session)):
+ mapped_names.add(node.name)
+
+ # Add the names added as extra keywords to current or parent items.
+ mapped_names.update(item.listextrakeywords())
+
+ # Add the names attached to the current function through direct assignment.
+ function_obj = getattr(item, "function", None)
+ if function_obj:
+ mapped_names.update(function_obj.__dict__)
-def deselect_by_keyword(items, config):
+ # Add the markers to the keywords as we no longer handle them correctly.
+ mapped_names.update(mark.name for mark in item.iter_markers())
+
+ return cls(mapped_names)
+
+ def __call__(self, subname: str) -> bool:
+ subname = subname.lower()
+ names = (name.lower() for name in self._names)
+
+ for name in names:
+ if subname in name:
+ return True
+ return False
+
+
+def deselect_by_keyword(items: "List[Item]", config: Config) -> None:
keywordexpr = config.option.keyword.lstrip()
if not keywordexpr:
return
if keywordexpr.startswith("-"):
+ # To be removed in pytest 7.0.0.
+ warnings.warn(MINUS_K_DASH, stacklevel=2)
keywordexpr = "not " + keywordexpr[1:]
selectuntil = False
if keywordexpr[-1:] == ":":
+ # To be removed in pytest 7.0.0.
+ warnings.warn(MINUS_K_COLON, stacklevel=2)
selectuntil = True
keywordexpr = keywordexpr[:-1]
+ try:
+ expression = Expression.compile(keywordexpr)
+ except ParseError as e:
+ raise UsageError(
+ f"Wrong expression passed to '-k': {keywordexpr}: {e}"
+ ) from None
+
remaining = []
deselected = []
for colitem in items:
- if keywordexpr and not matchkeyword(colitem, keywordexpr):
+ if keywordexpr and not expression.evaluate(KeywordMatcher.from_item(colitem)):
deselected.append(colitem)
else:
if selectuntil:
@@ -128,15 +219,38 @@ def deselect_by_keyword(items, config):
items[:] = remaining
-def deselect_by_mark(items, config):
+@attr.s(slots=True)
+class MarkMatcher:
+ """A matcher for markers which are present.
+
+ Tries to match on any marker names, attached to the given colitem.
+ """
+
+ own_mark_names = attr.ib()
+
+ @classmethod
+ def from_item(cls, item) -> "MarkMatcher":
+ mark_names = {mark.name for mark in item.iter_markers()}
+ return cls(mark_names)
+
+ def __call__(self, name: str) -> bool:
+ return name in self.own_mark_names
+
+
+def deselect_by_mark(items: "List[Item]", config: Config) -> None:
matchexpr = config.option.markexpr
if not matchexpr:
return
+ try:
+ expression = Expression.compile(matchexpr)
+ except ParseError as e:
+ raise UsageError(f"Wrong expression passed to '-m': {matchexpr}: {e}") from None
+
remaining = []
deselected = []
for item in items:
- if matchmark(item, matchexpr):
+ if expression.evaluate(MarkMatcher.from_item(item)):
remaining.append(item)
else:
deselected.append(item)
@@ -146,12 +260,12 @@ def deselect_by_mark(items, config):
items[:] = remaining
-def pytest_collection_modifyitems(items, config):
+def pytest_collection_modifyitems(items: "List[Item]", config: Config) -> None:
deselect_by_keyword(items, config)
deselect_by_mark(items, config)
-def pytest_configure(config):
+def pytest_configure(config: Config) -> None:
config._store[old_mark_config_key] = MARK_GEN._config
MARK_GEN._config = config
@@ -164,5 +278,5 @@ def pytest_configure(config):
)
-def pytest_unconfigure(config):
+def pytest_unconfigure(config: Config) -> None:
MARK_GEN._config = config._store.get(old_mark_config_key, None)
diff --git a/contrib/python/pytest/py3/_pytest/mark/evaluate.py b/contrib/python/pytest/py3/_pytest/mark/evaluate.py
deleted file mode 100644
index 772baf31b6..0000000000
--- a/contrib/python/pytest/py3/_pytest/mark/evaluate.py
+++ /dev/null
@@ -1,132 +0,0 @@
-import os
-import platform
-import sys
-import traceback
-from typing import Any
-from typing import Dict
-
-from ..outcomes import fail
-from ..outcomes import TEST_OUTCOME
-from _pytest.config import Config
-from _pytest.store import StoreKey
-
-
-evalcache_key = StoreKey[Dict[str, Any]]()
-
-
-def cached_eval(config: Config, expr: str, d: Dict[str, object]) -> Any:
- default = {} # type: Dict[str, object]
- evalcache = config._store.setdefault(evalcache_key, default)
- try:
- return evalcache[expr]
- except KeyError:
- import _pytest._code
-
- exprcode = _pytest._code.compile(expr, mode="eval")
- evalcache[expr] = x = eval(exprcode, d)
- return x
-
-
-class MarkEvaluator:
- def __init__(self, item, name):
- self.item = item
- self._marks = None
- self._mark = None
- self._mark_name = name
-
- def __bool__(self):
- # don't cache here to prevent staleness
- return bool(self._get_marks())
-
- __nonzero__ = __bool__
-
- def wasvalid(self):
- return not hasattr(self, "exc")
-
- def _get_marks(self):
- return list(self.item.iter_markers(name=self._mark_name))
-
- def invalidraise(self, exc):
- raises = self.get("raises")
- if not raises:
- return
- return not isinstance(exc, raises)
-
- def istrue(self):
- try:
- return self._istrue()
- except TEST_OUTCOME:
- self.exc = sys.exc_info()
- if isinstance(self.exc[1], SyntaxError):
- # TODO: Investigate why SyntaxError.offset is Optional, and if it can be None here.
- assert self.exc[1].offset is not None
- msg = [" " * (self.exc[1].offset + 4) + "^"]
- msg.append("SyntaxError: invalid syntax")
- else:
- msg = traceback.format_exception_only(*self.exc[:2])
- fail(
- "Error evaluating %r expression\n"
- " %s\n"
- "%s" % (self._mark_name, self.expr, "\n".join(msg)),
- pytrace=False,
- )
-
- def _getglobals(self):
- d = {"os": os, "sys": sys, "platform": platform, "config": self.item.config}
- if hasattr(self.item, "obj"):
- d.update(self.item.obj.__globals__)
- return d
-
- def _istrue(self):
- if hasattr(self, "result"):
- return self.result
- self._marks = self._get_marks()
-
- if self._marks:
- self.result = False
- for mark in self._marks:
- self._mark = mark
- if "condition" in mark.kwargs:
- args = (mark.kwargs["condition"],)
- else:
- args = mark.args
-
- for expr in args:
- self.expr = expr
- if isinstance(expr, str):
- d = self._getglobals()
- result = cached_eval(self.item.config, expr, d)
- else:
- if "reason" not in mark.kwargs:
- # XXX better be checked at collection time
- msg = (
- "you need to specify reason=STRING "
- "when using booleans as conditions."
- )
- fail(msg)
- result = bool(expr)
- if result:
- self.result = True
- self.reason = mark.kwargs.get("reason", None)
- self.expr = expr
- return self.result
-
- if not args:
- self.result = True
- self.reason = mark.kwargs.get("reason", None)
- return self.result
- return False
-
- def get(self, attr, default=None):
- if self._mark is None:
- return default
- return self._mark.kwargs.get(attr, default)
-
- def getexplanation(self):
- expl = getattr(self, "reason", None) or self.get("reason", None)
- if not expl:
- if not hasattr(self, "expr"):
- return ""
- else:
- return "condition: " + str(self.expr)
- return expl
diff --git a/contrib/python/pytest/py3/_pytest/mark/expression.py b/contrib/python/pytest/py3/_pytest/mark/expression.py
new file mode 100644
index 0000000000..dc3991b10c
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/mark/expression.py
@@ -0,0 +1,221 @@
+r"""Evaluate match expressions, as used by `-k` and `-m`.
+
+The grammar is:
+
+expression: expr? EOF
+expr: and_expr ('or' and_expr)*
+and_expr: not_expr ('and' not_expr)*
+not_expr: 'not' not_expr | '(' expr ')' | ident
+ident: (\w|:|\+|-|\.|\[|\])+
+
+The semantics are:
+
+- Empty expression evaluates to False.
+- ident evaluates to True of False according to a provided matcher function.
+- or/and/not evaluate according to the usual boolean semantics.
+"""
+import ast
+import enum
+import re
+import types
+from typing import Callable
+from typing import Iterator
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
+
+import attr
+
+if TYPE_CHECKING:
+ from typing import NoReturn
+
+
+__all__ = [
+ "Expression",
+ "ParseError",
+]
+
+
+class TokenType(enum.Enum):
+ LPAREN = "left parenthesis"
+ RPAREN = "right parenthesis"
+ OR = "or"
+ AND = "and"
+ NOT = "not"
+ IDENT = "identifier"
+ EOF = "end of input"
+
+
+@attr.s(frozen=True, slots=True)
+class Token:
+ type = attr.ib(type=TokenType)
+ value = attr.ib(type=str)
+ pos = attr.ib(type=int)
+
+
+class ParseError(Exception):
+ """The expression contains invalid syntax.
+
+ :param column: The column in the line where the error occurred (1-based).
+ :param message: A description of the error.
+ """
+
+ def __init__(self, column: int, message: str) -> None:
+ self.column = column
+ self.message = message
+
+ def __str__(self) -> str:
+ return f"at column {self.column}: {self.message}"
+
+
+class Scanner:
+ __slots__ = ("tokens", "current")
+
+ def __init__(self, input: str) -> None:
+ self.tokens = self.lex(input)
+ self.current = next(self.tokens)
+
+ def lex(self, input: str) -> Iterator[Token]:
+ pos = 0
+ while pos < len(input):
+ if input[pos] in (" ", "\t"):
+ pos += 1
+ elif input[pos] == "(":
+ yield Token(TokenType.LPAREN, "(", pos)
+ pos += 1
+ elif input[pos] == ")":
+ yield Token(TokenType.RPAREN, ")", pos)
+ pos += 1
+ else:
+ match = re.match(r"(:?\w|:|\+|-|\.|\[|\])+", input[pos:])
+ if match:
+ value = match.group(0)
+ if value == "or":
+ yield Token(TokenType.OR, value, pos)
+ elif value == "and":
+ yield Token(TokenType.AND, value, pos)
+ elif value == "not":
+ yield Token(TokenType.NOT, value, pos)
+ else:
+ yield Token(TokenType.IDENT, value, pos)
+ pos += len(value)
+ else:
+ raise ParseError(
+ pos + 1, 'unexpected character "{}"'.format(input[pos]),
+ )
+ yield Token(TokenType.EOF, "", pos)
+
+ def accept(self, type: TokenType, *, reject: bool = False) -> Optional[Token]:
+ if self.current.type is type:
+ token = self.current
+ if token.type is not TokenType.EOF:
+ self.current = next(self.tokens)
+ return token
+ if reject:
+ self.reject((type,))
+ return None
+
+ def reject(self, expected: Sequence[TokenType]) -> "NoReturn":
+ raise ParseError(
+ self.current.pos + 1,
+ "expected {}; got {}".format(
+ " OR ".join(type.value for type in expected), self.current.type.value,
+ ),
+ )
+
+
+# True, False and None are legal match expression identifiers,
+# but illegal as Python identifiers. To fix this, this prefix
+# is added to identifiers in the conversion to Python AST.
+IDENT_PREFIX = "$"
+
+
+def expression(s: Scanner) -> ast.Expression:
+ if s.accept(TokenType.EOF):
+ ret: ast.expr = ast.NameConstant(False)
+ else:
+ ret = expr(s)
+ s.accept(TokenType.EOF, reject=True)
+ return ast.fix_missing_locations(ast.Expression(ret))
+
+
+def expr(s: Scanner) -> ast.expr:
+ ret = and_expr(s)
+ while s.accept(TokenType.OR):
+ rhs = and_expr(s)
+ ret = ast.BoolOp(ast.Or(), [ret, rhs])
+ return ret
+
+
+def and_expr(s: Scanner) -> ast.expr:
+ ret = not_expr(s)
+ while s.accept(TokenType.AND):
+ rhs = not_expr(s)
+ ret = ast.BoolOp(ast.And(), [ret, rhs])
+ return ret
+
+
+def not_expr(s: Scanner) -> ast.expr:
+ if s.accept(TokenType.NOT):
+ return ast.UnaryOp(ast.Not(), not_expr(s))
+ if s.accept(TokenType.LPAREN):
+ ret = expr(s)
+ s.accept(TokenType.RPAREN, reject=True)
+ return ret
+ ident = s.accept(TokenType.IDENT)
+ if ident:
+ return ast.Name(IDENT_PREFIX + ident.value, ast.Load())
+ s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
+
+
+class MatcherAdapter(Mapping[str, bool]):
+ """Adapts a matcher function to a locals mapping as required by eval()."""
+
+ def __init__(self, matcher: Callable[[str], bool]) -> None:
+ self.matcher = matcher
+
+ def __getitem__(self, key: str) -> bool:
+ return self.matcher(key[len(IDENT_PREFIX) :])
+
+ def __iter__(self) -> Iterator[str]:
+ raise NotImplementedError()
+
+ def __len__(self) -> int:
+ raise NotImplementedError()
+
+
+class Expression:
+ """A compiled match expression as used by -k and -m.
+
+ The expression can be evaulated against different matchers.
+ """
+
+ __slots__ = ("code",)
+
+ def __init__(self, code: types.CodeType) -> None:
+ self.code = code
+
+ @classmethod
+ def compile(self, input: str) -> "Expression":
+ """Compile a match expression.
+
+ :param input: The input expression - one line.
+ """
+ astexpr = expression(Scanner(input))
+ code: types.CodeType = compile(
+ astexpr, filename="<pytest match expression>", mode="eval",
+ )
+ return Expression(code)
+
+ def evaluate(self, matcher: Callable[[str], bool]) -> bool:
+ """Evaluate the match expression.
+
+ :param matcher:
+ Given an identifier, should return whether it matches or not.
+ Should be prepared to handle arbitrary strings as input.
+
+ :returns: Whether the expression matches or not.
+ """
+ ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))
+ return ret
diff --git a/contrib/python/pytest/py3/_pytest/mark/legacy.py b/contrib/python/pytest/py3/_pytest/mark/legacy.py
deleted file mode 100644
index 3d7a194b61..0000000000
--- a/contrib/python/pytest/py3/_pytest/mark/legacy.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""
-this is a place where we put datastructures used by legacy apis
-we hope to remove
-"""
-import keyword
-from typing import Set
-
-import attr
-
-from _pytest.compat import TYPE_CHECKING
-from _pytest.config import UsageError
-
-if TYPE_CHECKING:
- from _pytest.nodes import Item # noqa: F401 (used in type string)
-
-
-@attr.s
-class MarkMapping:
- """Provides a local mapping for markers where item access
- resolves to True if the marker is present. """
-
- own_mark_names = attr.ib()
-
- @classmethod
- def from_item(cls, item):
- mark_names = {mark.name for mark in item.iter_markers()}
- return cls(mark_names)
-
- def __getitem__(self, name):
- return name in self.own_mark_names
-
-
-@attr.s
-class KeywordMapping:
- """Provides a local mapping for keywords.
- Given a list of names, map any substring of one of these names to True.
- """
-
- _names = attr.ib(type=Set[str])
-
- @classmethod
- def from_item(cls, item: "Item") -> "KeywordMapping":
- mapped_names = set()
-
- # Add the names of the current item and any parent items
- import pytest
-
- for item in item.listchain():
- if not isinstance(item, pytest.Instance):
- mapped_names.add(item.name)
-
- # Add the names added as extra keywords to current or parent items
- mapped_names.update(item.listextrakeywords())
-
- # Add the names attached to the current function through direct assignment
- function_obj = getattr(item, "function", None)
- if function_obj:
- mapped_names.update(function_obj.__dict__)
-
- # add the markers to the keywords as we no longer handle them correctly
- mapped_names.update(mark.name for mark in item.iter_markers())
-
- return cls(mapped_names)
-
- def __getitem__(self, subname: str) -> bool:
- """Return whether subname is included within stored names.
-
- The string inclusion check is case-insensitive.
-
- """
- subname = subname.lower()
- names = (name.lower() for name in self._names)
-
- for name in names:
- if subname in name:
- return True
- return False
-
-
-python_keywords_allowed_list = ["or", "and", "not"]
-
-
-def matchmark(colitem, markexpr):
- """Tries to match on any marker names, attached to the given colitem."""
- try:
- return eval(markexpr, {}, MarkMapping.from_item(colitem))
- except SyntaxError as e:
- raise SyntaxError(str(e) + "\nMarker expression must be valid Python!")
-
-
-def matchkeyword(colitem, keywordexpr):
- """Tries to match given keyword expression to given collector item.
-
- Will match on the name of colitem, including the names of its parents.
- Only matches names of items which are either a :class:`Class` or a
- :class:`Function`.
- Additionally, matches on names in the 'extra_keyword_matches' set of
- any item, as well as names directly assigned to test functions.
- """
- mapping = KeywordMapping.from_item(colitem)
- if " " not in keywordexpr:
- # special case to allow for simple "-k pass" and "-k 1.3"
- return mapping[keywordexpr]
- elif keywordexpr.startswith("not ") and " " not in keywordexpr[4:]:
- return not mapping[keywordexpr[4:]]
- for kwd in keywordexpr.split():
- if keyword.iskeyword(kwd) and kwd not in python_keywords_allowed_list:
- raise UsageError(
- "Python keyword '{}' not accepted in expressions passed to '-k'".format(
- kwd
- )
- )
- try:
- return eval(keywordexpr, {}, mapping)
- except SyntaxError:
- raise UsageError("Wrong expression passed to '-k': {}".format(keywordexpr))
diff --git a/contrib/python/pytest/py3/_pytest/mark/structures.py b/contrib/python/pytest/py3/_pytest/mark/structures.py
index 50ad81baa6..f5736a4c1c 100644
--- a/contrib/python/pytest/py3/_pytest/mark/structures.py
+++ b/contrib/python/pytest/py3/_pytest/mark/structures.py
@@ -1,39 +1,68 @@
+import collections.abc
import inspect
import warnings
-from collections import namedtuple
-from collections.abc import MutableMapping
+from typing import Any
+from typing import Callable
+from typing import Collection
from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import Mapping
+from typing import MutableMapping
+from typing import NamedTuple
from typing import Optional
+from typing import overload
+from typing import Sequence
from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import attr
-from .._code.source import getfslineno
+from .._code import getfslineno
from ..compat import ascii_escaped
+from ..compat import final
from ..compat import NOTSET
+from ..compat import NotSetType
+from _pytest.config import Config
from _pytest.outcomes import fail
from _pytest.warning_types import PytestUnknownMarkWarning
+if TYPE_CHECKING:
+ from ..nodes import Node
+
+
EMPTY_PARAMETERSET_OPTION = "empty_parameter_set_mark"
-def istestfunc(func):
+def istestfunc(func) -> bool:
return (
hasattr(func, "__call__")
and getattr(func, "__name__", "<lambda>") != "<lambda>"
)
-def get_empty_parameterset_mark(config, argnames, func):
+def get_empty_parameterset_mark(
+ config: Config, argnames: Sequence[str], func
+) -> "MarkDecorator":
from ..nodes import Collector
+ fs, lineno = getfslineno(func)
+ reason = "got empty parameter set %r, function %s at %s:%d" % (
+ argnames,
+ func.__name__,
+ fs,
+ lineno,
+ )
+
requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION)
if requested_mark in ("", None, "skip"):
- mark = MARK_GEN.skip
+ mark = MARK_GEN.skip(reason=reason)
elif requested_mark == "xfail":
- mark = MARK_GEN.xfail(run=False)
+ mark = MARK_GEN.xfail(reason=reason, run=False)
elif requested_mark == "fail_at_collect":
f_name = func.__name__
_, lineno = getfslineno(func)
@@ -42,23 +71,30 @@ def get_empty_parameterset_mark(config, argnames, func):
)
else:
raise LookupError(requested_mark)
- fs, lineno = getfslineno(func)
- reason = "got empty parameter set %r, function %s at %s:%d" % (
- argnames,
- func.__name__,
- fs,
- lineno,
- )
- return mark(reason=reason)
+ return mark
-class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
+class ParameterSet(
+ NamedTuple(
+ "ParameterSet",
+ [
+ ("values", Sequence[Union[object, NotSetType]]),
+ ("marks", Collection[Union["MarkDecorator", "Mark"]]),
+ ("id", Optional[str]),
+ ],
+ )
+):
@classmethod
- def param(cls, *values, marks=(), id=None):
+ def param(
+ cls,
+ *values: object,
+ marks: Union["MarkDecorator", Collection[Union["MarkDecorator", "Mark"]]] = (),
+ id: Optional[str] = None,
+ ) -> "ParameterSet":
if isinstance(marks, MarkDecorator):
marks = (marks,)
else:
- assert isinstance(marks, (tuple, list, set))
+ assert isinstance(marks, collections.abc.Collection)
if id is not None:
if not isinstance(id, str):
@@ -69,15 +105,20 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
return cls(values, marks, id)
@classmethod
- def extract_from(cls, parameterset, force_tuple=False):
- """
+ def extract_from(
+ cls,
+ parameterset: Union["ParameterSet", Sequence[object], object],
+ force_tuple: bool = False,
+ ) -> "ParameterSet":
+ """Extract from an object or objects.
+
:param parameterset:
- a legacy style parameterset that may or may not be a tuple,
- and may or may not be wrapped into a mess of mark objects
+ A legacy style parameterset that may or may not be a tuple,
+ and may or may not be wrapped into a mess of mark objects.
:param force_tuple:
- enforce tuple wrapping so single argument tuple values
- don't get decomposed and break tests
+ Enforce tuple wrapping so single argument tuple values
+ don't get decomposed and break tests.
"""
if isinstance(parameterset, cls):
@@ -85,10 +126,20 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
if force_tuple:
return cls.param(parameterset)
else:
- return cls(parameterset, marks=[], id=None)
+ # TODO: Refactor to fix this type-ignore. Currently the following
+ # passes type-checking but crashes:
+ #
+ # @pytest.mark.parametrize(('x', 'y'), [1, 2])
+ # def test_foo(x, y): pass
+ return cls(parameterset, marks=[], id=None) # type: ignore[arg-type]
@staticmethod
- def _parse_parametrize_args(argnames, argvalues, *args, **kwargs):
+ def _parse_parametrize_args(
+ argnames: Union[str, List[str], Tuple[str, ...]],
+ argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
+ *args,
+ **kwargs,
+ ) -> Tuple[Union[List[str], Tuple[str, ...]], bool]:
if not isinstance(argnames, (tuple, list)):
argnames = [x.strip() for x in argnames.split(",") if x.strip()]
force_tuple = len(argnames) == 1
@@ -97,19 +148,29 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
return argnames, force_tuple
@staticmethod
- def _parse_parametrize_parameters(argvalues, force_tuple):
+ def _parse_parametrize_parameters(
+ argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
+ force_tuple: bool,
+ ) -> List["ParameterSet"]:
return [
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
]
@classmethod
- def _for_parametrize(cls, argnames, argvalues, func, config, function_definition):
+ def _for_parametrize(
+ cls,
+ argnames: Union[str, List[str], Tuple[str, ...]],
+ argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
+ func,
+ config: Config,
+ nodeid: str,
+ ) -> Tuple[Union[List[str], Tuple[str, ...]], List["ParameterSet"]]:
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
del argvalues
if parameters:
- # check all parameter sets have the correct number of values
+ # Check all parameter sets have the correct number of values.
for param in parameters:
if len(param.values) != len(argnames):
msg = (
@@ -120,7 +181,7 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
)
fail(
msg.format(
- nodeid=function_definition.nodeid,
+ nodeid=nodeid,
values=param.values,
names=argnames,
names_len=len(argnames),
@@ -129,8 +190,8 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
pytrace=False,
)
else:
- # empty parameter set (likely computed at runtime): create a single
- # parameter set with NOTSET values, with the "empty parameter set" mark applied to it
+ # Empty parameter set (likely computed at runtime): create a single
+ # parameter set with NOTSET values, with the "empty parameter set" mark applied to it.
mark = get_empty_parameterset_mark(config, argnames, func)
parameters.append(
ParameterSet(values=(NOTSET,) * len(argnames), marks=[mark], id=None)
@@ -138,35 +199,39 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
return argnames, parameters
+@final
@attr.s(frozen=True)
class Mark:
- #: name of the mark
+ #: Name of the mark.
name = attr.ib(type=str)
- #: positional arguments of the mark decorator
- args = attr.ib() # List[object]
- #: keyword arguments of the mark decorator
- kwargs = attr.ib() # Dict[str, object]
+ #: Positional arguments of the mark decorator.
+ args = attr.ib(type=Tuple[Any, ...])
+ #: Keyword arguments of the mark decorator.
+ kwargs = attr.ib(type=Mapping[str, Any])
- #: source Mark for ids with parametrize Marks
+ #: Source Mark for ids with parametrize Marks.
_param_ids_from = attr.ib(type=Optional["Mark"], default=None, repr=False)
- #: resolved/generated ids with parametrize Marks
- _param_ids_generated = attr.ib(type=Optional[List[str]], default=None, repr=False)
+ #: Resolved/generated ids with parametrize Marks.
+ _param_ids_generated = attr.ib(
+ type=Optional[Sequence[str]], default=None, repr=False
+ )
- def _has_param_ids(self):
+ def _has_param_ids(self) -> bool:
return "ids" in self.kwargs or len(self.args) >= 4
def combined_with(self, other: "Mark") -> "Mark":
- """
- :param other: the mark to combine with
- :type other: Mark
- :rtype: Mark
+ """Return a new Mark which is a combination of this
+ Mark and another Mark.
+
+ Combines by appending args and merging kwargs.
- combines by appending args and merging the mappings
+ :param Mark other: The mark to combine with.
+ :rtype: Mark
"""
assert self.name == other.name
# Remember source of ids with parametrize Marks.
- param_ids_from = None # type: Optional[Mark]
+ param_ids_from: Optional[Mark] = None
if self.name == "parametrize":
if other._has_param_ids():
param_ids_from = other
@@ -181,13 +246,20 @@ class Mark:
)
+# A generic parameter designating an object to which a Mark may
+# be applied -- a test function (callable) or class.
+# Note: a lambda is not allowed, but this can't be represented.
+_Markable = TypeVar("_Markable", bound=Union[Callable[..., object], type])
+
+
@attr.s
class MarkDecorator:
- """ A decorator for test functions and test classes. When applied
- it will create :class:`Mark` objects which are often created like this::
+ """A decorator for applying a mark on test functions and classes.
+
+ MarkDecorators are created with ``pytest.mark``::
- mark1 = pytest.mark.NAME # simple MarkDecorator
- mark2 = pytest.mark.NAME(name1=value) # parametrized MarkDecorator
+ mark1 = pytest.mark.NAME # Simple MarkDecorator
+ mark2 = pytest.mark.NAME(name1=value) # Parametrized MarkDecorator
and can then be applied as decorators to test functions::
@@ -195,64 +267,75 @@ class MarkDecorator:
def test_function():
pass
- When a MarkDecorator instance is called it does the following:
+ When a MarkDecorator is called it does the following:
1. If called with a single class as its only positional argument and no
- additional keyword arguments, it attaches itself to the class so it
+ additional keyword arguments, it attaches the mark to the class so it
gets applied automatically to all test cases found in that class.
- 2. If called with a single function as its only positional argument and
- no additional keyword arguments, it attaches a MarkInfo object to the
- function, containing all the arguments already stored internally in
- the MarkDecorator.
- 3. When called in any other case, it performs a 'fake construction' call,
- i.e. it returns a new MarkDecorator instance with the original
- MarkDecorator's content updated with the arguments passed to this
- call.
-
- Note: The rules above prevent MarkDecorator objects from storing only a
- single function or class reference as their positional argument with no
- additional keyword or positional arguments.
+ 2. If called with a single function as its only positional argument and
+ no additional keyword arguments, it attaches the mark to the function,
+ containing all the arguments already stored internally in the
+ MarkDecorator.
+
+ 3. When called in any other case, it returns a new MarkDecorator instance
+ with the original MarkDecorator's content updated with the arguments
+ passed to this call.
+
+ Note: The rules above prevent MarkDecorators from storing only a single
+ function or class reference as their positional argument with no
+ additional keyword or positional arguments. You can work around this by
+ using `with_args()`.
"""
- mark = attr.ib(validator=attr.validators.instance_of(Mark))
+ mark = attr.ib(type=Mark, validator=attr.validators.instance_of(Mark))
@property
- def name(self):
- """alias for mark.name"""
+ def name(self) -> str:
+ """Alias for mark.name."""
return self.mark.name
@property
- def args(self):
- """alias for mark.args"""
+ def args(self) -> Tuple[Any, ...]:
+ """Alias for mark.args."""
return self.mark.args
@property
- def kwargs(self):
- """alias for mark.kwargs"""
+ def kwargs(self) -> Mapping[str, Any]:
+ """Alias for mark.kwargs."""
return self.mark.kwargs
@property
- def markname(self):
+ def markname(self) -> str:
return self.name # for backward-compat (2.4.1 had this attr)
- def __repr__(self):
- return "<MarkDecorator {!r}>".format(self.mark)
+ def __repr__(self) -> str:
+ return f"<MarkDecorator {self.mark!r}>"
- def with_args(self, *args, **kwargs):
- """ return a MarkDecorator with extra arguments added
+ def with_args(self, *args: object, **kwargs: object) -> "MarkDecorator":
+ """Return a MarkDecorator with extra arguments added.
- unlike call this can be used even if the sole argument is a callable/class
+ Unlike calling the MarkDecorator, with_args() can be used even
+ if the sole argument is a callable/class.
- :return: MarkDecorator
+ :rtype: MarkDecorator
"""
-
mark = Mark(self.name, args, kwargs)
return self.__class__(self.mark.combined_with(mark))
- def __call__(self, *args, **kwargs):
- """ if passed a single callable argument: decorate it with mark info.
- otherwise add *args/**kwargs in-place to mark information. """
+ # Type ignored because the overloads overlap with an incompatible
+ # return type. Not much we can do about that. Thankfully mypy picks
+ # the first match so it works out even if we break the rules.
+ @overload
+ def __call__(self, arg: _Markable) -> _Markable: # type: ignore[misc]
+ pass
+
+ @overload
+ def __call__(self, *args: object, **kwargs: object) -> "MarkDecorator":
+ pass
+
+ def __call__(self, *args: object, **kwargs: object):
+ """Call the MarkDecorator."""
if args and not kwargs:
func = args[0]
is_class = inspect.isclass(func)
@@ -262,10 +345,8 @@ class MarkDecorator:
return self.with_args(*args, **kwargs)
-def get_unpacked_marks(obj):
- """
- obtain the unpacked marks that are stored on an object
- """
+def get_unpacked_marks(obj) -> List[Mark]:
+ """Obtain the unpacked marks that are stored on an object."""
mark_list = getattr(obj, "pytestmark", [])
if not isinstance(mark_list, list):
mark_list = [mark_list]
@@ -273,10 +354,9 @@ def get_unpacked_marks(obj):
def normalize_mark_list(mark_list: Iterable[Union[Mark, MarkDecorator]]) -> List[Mark]:
- """
- normalizes marker decorating helpers to mark objects
+ """Normalize marker decorating helpers to mark objects.
- :type mark_list: List[Union[Mark, Markdecorator]]
+ :type List[Union[Mark, Markdecorator]] mark_list:
:rtype: List[Mark]
"""
extracted = [
@@ -284,34 +364,118 @@ def normalize_mark_list(mark_list: Iterable[Union[Mark, MarkDecorator]]) -> List
] # unpack MarkDecorator
for mark in extracted:
if not isinstance(mark, Mark):
- raise TypeError("got {!r} instead of Mark".format(mark))
+ raise TypeError(f"got {mark!r} instead of Mark")
return [x for x in extracted if isinstance(x, Mark)]
-def store_mark(obj, mark):
- """store a Mark on an object
- this is used to implement the Mark declarations/decorators correctly
+def store_mark(obj, mark: Mark) -> None:
+ """Store a Mark on an object.
+
+ This is used to implement the Mark declarations/decorators correctly.
"""
assert isinstance(mark, Mark), mark
- # always reassign name to avoid updating pytestmark
- # in a reference that was only borrowed
+ # Always reassign name to avoid updating pytestmark in a reference that
+ # was only borrowed.
obj.pytestmark = get_unpacked_marks(obj) + [mark]
+# Typing for builtin pytest marks. This is cheating; it gives builtin marks
+# special privilege, and breaks modularity. But practicality beats purity...
+if TYPE_CHECKING:
+ from _pytest.fixtures import _Scope
+
+ class _SkipMarkDecorator(MarkDecorator):
+ @overload # type: ignore[override,misc]
+ def __call__(self, arg: _Markable) -> _Markable:
+ ...
+
+ @overload
+ def __call__(self, reason: str = ...) -> "MarkDecorator":
+ ...
+
+ class _SkipifMarkDecorator(MarkDecorator):
+ def __call__( # type: ignore[override]
+ self,
+ condition: Union[str, bool] = ...,
+ *conditions: Union[str, bool],
+ reason: str = ...,
+ ) -> MarkDecorator:
+ ...
+
+ class _XfailMarkDecorator(MarkDecorator):
+ @overload # type: ignore[override,misc]
+ def __call__(self, arg: _Markable) -> _Markable:
+ ...
+
+ @overload
+ def __call__(
+ self,
+ condition: Union[str, bool] = ...,
+ *conditions: Union[str, bool],
+ reason: str = ...,
+ run: bool = ...,
+ raises: Union[Type[BaseException], Tuple[Type[BaseException], ...]] = ...,
+ strict: bool = ...,
+ ) -> MarkDecorator:
+ ...
+
+ class _ParametrizeMarkDecorator(MarkDecorator):
+ def __call__( # type: ignore[override]
+ self,
+ argnames: Union[str, List[str], Tuple[str, ...]],
+ argvalues: Iterable[Union[ParameterSet, Sequence[object], object]],
+ *,
+ indirect: Union[bool, Sequence[str]] = ...,
+ ids: Optional[
+ Union[
+ Iterable[Union[None, str, float, int, bool]],
+ Callable[[Any], Optional[object]],
+ ]
+ ] = ...,
+ scope: Optional[_Scope] = ...,
+ ) -> MarkDecorator:
+ ...
+
+ class _UsefixturesMarkDecorator(MarkDecorator):
+ def __call__( # type: ignore[override]
+ self, *fixtures: str
+ ) -> MarkDecorator:
+ ...
+
+ class _FilterwarningsMarkDecorator(MarkDecorator):
+ def __call__( # type: ignore[override]
+ self, *filters: str
+ ) -> MarkDecorator:
+ ...
+
+
+@final
class MarkGenerator:
- """ Factory for :class:`MarkDecorator` objects - exposed as
- a ``pytest.mark`` singleton instance. Example::
+ """Factory for :class:`MarkDecorator` objects - exposed as
+ a ``pytest.mark`` singleton instance.
+
+ Example::
import pytest
+
@pytest.mark.slowtest
def test_function():
pass
- will set a 'slowtest' :class:`MarkInfo` object
- on the ``test_function`` object. """
+ applies a 'slowtest' :class:`Mark` on ``test_function``.
+ """
+
+ _config: Optional[Config] = None
+ _markers: Set[str] = set()
- _config = None
- _markers = set() # type: Set[str]
+ # See TYPE_CHECKING above.
+ if TYPE_CHECKING:
+ skip: _SkipMarkDecorator
+ skipif: _SkipifMarkDecorator
+ xfail: _XfailMarkDecorator
+ parametrize: _ParametrizeMarkDecorator
+ usefixtures: _UsefixturesMarkDecorator
+ filterwarnings: _FilterwarningsMarkDecorator
def __getattr__(self, name: str) -> MarkDecorator:
if name[0] == "_":
@@ -335,21 +499,21 @@ class MarkGenerator:
# If the name is not in the set of known marks after updating,
# then it really is time to issue a warning or an error.
if name not in self._markers:
- if self._config.option.strict_markers:
+ if self._config.option.strict_markers or self._config.option.strict:
fail(
- "{!r} not found in `markers` configuration option".format(name),
+ f"{name!r} not found in `markers` configuration option",
pytrace=False,
)
# Raise a specific error for common misspellings of "parametrize".
if name in ["parameterize", "parametrise", "parameterise"]:
__tracebackhide__ = True
- fail("Unknown '{}' mark, did you mean 'parametrize'?".format(name))
+ fail(f"Unknown '{name}' mark, did you mean 'parametrize'?")
warnings.warn(
"Unknown pytest.mark.%s - is this a typo? You can register "
"custom marks to avoid this warning - for details, see "
- "https://docs.pytest.org/en/latest/mark.html" % name,
+ "https://docs.pytest.org/en/stable/mark.html" % name,
PytestUnknownMarkWarning,
2,
)
@@ -360,13 +524,14 @@ class MarkGenerator:
MARK_GEN = MarkGenerator()
-class NodeKeywords(MutableMapping):
- def __init__(self, node):
+@final
+class NodeKeywords(MutableMapping[str, Any]):
+ def __init__(self, node: "Node") -> None:
self.node = node
self.parent = node.parent
self._markers = {node.name: True}
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> Any:
try:
return self._markers[key]
except KeyError:
@@ -374,24 +539,24 @@ class NodeKeywords(MutableMapping):
raise
return self.parent.keywords[key]
- def __setitem__(self, key, value):
+ def __setitem__(self, key: str, value: Any) -> None:
self._markers[key] = value
- def __delitem__(self, key):
+ def __delitem__(self, key: str) -> None:
raise ValueError("cannot delete key in keywords dict")
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
seen = self._seen()
return iter(seen)
- def _seen(self):
+ def _seen(self) -> Set[str]:
seen = set(self._markers)
if self.parent is not None:
seen.update(self.parent.keywords)
return seen
- def __len__(self):
+ def __len__(self) -> int:
return len(self._seen())
- def __repr__(self):
- return "<NodeKeywords for node {}>".format(self.node)
+ def __repr__(self) -> str:
+ return f"<NodeKeywords for node {self.node}>"
diff --git a/contrib/python/pytest/py3/_pytest/monkeypatch.py b/contrib/python/pytest/py3/_pytest/monkeypatch.py
index ce1c0f6510..a052f693ac 100644
--- a/contrib/python/pytest/py3/_pytest/monkeypatch.py
+++ b/contrib/python/pytest/py3/_pytest/monkeypatch.py
@@ -1,22 +1,37 @@
-""" monkeypatching and mocking functionality. """
+"""Monkeypatching and mocking functionality."""
import os
import re
import sys
import warnings
from contextlib import contextmanager
+from pathlib import Path
+from typing import Any
from typing import Generator
-
-import pytest
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
+
+from _pytest.compat import final
from _pytest.fixtures import fixture
-from _pytest.pathlib import Path
+from _pytest.warning_types import PytestWarning
RE_IMPORT_ERROR_NAME = re.compile(r"^No module named (.*)$")
+K = TypeVar("K")
+V = TypeVar("V")
+
+
@fixture
-def monkeypatch():
- """The returned ``monkeypatch`` fixture provides these
- helper methods to modify objects, dictionaries or os.environ::
+def monkeypatch() -> Generator["MonkeyPatch", None, None]:
+ """A convenient fixture for monkey-patching.
+
+ The fixture provides these methods to modify objects, dictionaries or
+ os.environ::
monkeypatch.setattr(obj, name, value, raising=True)
monkeypatch.delattr(obj, name, raising=True)
@@ -27,18 +42,17 @@ def monkeypatch():
monkeypatch.syspath_prepend(path)
monkeypatch.chdir(path)
- All modifications will be undone after the requesting
- test function or fixture has finished. The ``raising``
- parameter determines if a KeyError or AttributeError
- will be raised if the set/deletion operation has no target.
+ All modifications will be undone after the requesting test function or
+ fixture has finished. The ``raising`` parameter determines if a KeyError
+ or AttributeError will be raised if the set/deletion operation has no target.
"""
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
-def resolve(name):
- # simplified from zope.dottedname
+def resolve(name: str) -> object:
+ # Simplified from zope.dottedname.
parts = name.split(".")
used = parts.pop(0)
@@ -51,38 +65,35 @@ def resolve(name):
pass
else:
continue
- # we use explicit un-nesting of the handling block in order
- # to avoid nested exceptions on python 3
+ # We use explicit un-nesting of the handling block in order
+ # to avoid nested exceptions.
try:
__import__(used)
except ImportError as ex:
- # str is used for py2 vs py3
expected = str(ex).split()[-1]
if expected == used:
raise
else:
- raise ImportError("import error in {}: {}".format(used, ex))
+ raise ImportError(f"import error in {used}: {ex}") from ex
found = annotated_getattr(found, part, used)
return found
-def annotated_getattr(obj, name, ann):
+def annotated_getattr(obj: object, name: str, ann: str) -> object:
try:
obj = getattr(obj, name)
- except AttributeError:
+ except AttributeError as e:
raise AttributeError(
"{!r} object at {} has no attribute {!r}".format(
type(obj).__name__, ann, name
)
- )
+ ) from e
return obj
-def derive_importpath(import_path, raising):
- if not isinstance(import_path, str) or "." not in import_path:
- raise TypeError(
- "must be absolute import path string, not {!r}".format(import_path)
- )
+def derive_importpath(import_path: str, raising: bool) -> Tuple[str, object]:
+ if not isinstance(import_path, str) or "." not in import_path: # type: ignore[unreachable]
+ raise TypeError(f"must be absolute import path string, not {import_path!r}")
module, attr = import_path.rsplit(".", 1)
target = resolve(module)
if raising:
@@ -91,32 +102,46 @@ def derive_importpath(import_path, raising):
class Notset:
- def __repr__(self):
+ def __repr__(self) -> str:
return "<notset>"
notset = Notset()
+@final
class MonkeyPatch:
- """ Object returned by the ``monkeypatch`` fixture keeping a record of setattr/item/env/syspath changes.
+ """Helper to conveniently monkeypatch attributes/items/environment
+ variables/syspath.
+
+ Returned by the :fixture:`monkeypatch` fixture.
+
+ :versionchanged:: 6.2
+ Can now also be used directly as `pytest.MonkeyPatch()`, for when
+ the fixture is not available. In this case, use
+ :meth:`with MonkeyPatch.context() as mp: <context>` or remember to call
+ :meth:`undo` explicitly.
"""
- def __init__(self):
- self._setattr = []
- self._setitem = []
- self._cwd = None
- self._savesyspath = None
+ def __init__(self) -> None:
+ self._setattr: List[Tuple[object, str, object]] = []
+ self._setitem: List[Tuple[MutableMapping[Any, Any], object, object]] = ([])
+ self._cwd: Optional[str] = None
+ self._savesyspath: Optional[List[str]] = None
+ @classmethod
@contextmanager
- def context(self) -> Generator["MonkeyPatch", None, None]:
- """
- Context manager that returns a new :class:`MonkeyPatch` object which
- undoes any patching done inside the ``with`` block upon exit:
+ def context(cls) -> Generator["MonkeyPatch", None, None]:
+ """Context manager that returns a new :class:`MonkeyPatch` object
+ which undoes any patching done inside the ``with`` block upon exit.
+
+ Example:
.. code-block:: python
import functools
+
+
def test_partial(monkeypatch):
with monkeypatch.context() as m:
m.setattr(functools, "partial", 3)
@@ -125,30 +150,46 @@ class MonkeyPatch:
such as mocking ``stdlib`` functions that might break pytest itself if mocked (for examples
of this see `#3290 <https://github.com/pytest-dev/pytest/issues/3290>`_.
"""
- m = MonkeyPatch()
+ m = cls()
try:
yield m
finally:
m.undo()
- def setattr(self, target, name, value=notset, raising=True):
- """ Set attribute value on target, memorizing the old value.
- By default raise AttributeError if the attribute did not exist.
+ @overload
+ def setattr(
+ self, target: str, name: object, value: Notset = ..., raising: bool = ...,
+ ) -> None:
+ ...
+
+ @overload
+ def setattr(
+ self, target: object, name: str, value: object, raising: bool = ...,
+ ) -> None:
+ ...
+
+ def setattr(
+ self,
+ target: Union[str, object],
+ name: Union[object, str],
+ value: object = notset,
+ raising: bool = True,
+ ) -> None:
+ """Set attribute value on target, memorizing the old value.
For convenience you can specify a string as ``target`` which
will be interpreted as a dotted import path, with the last part
- being the attribute name. Example:
+ being the attribute name. For example,
``monkeypatch.setattr("os.getcwd", lambda: "/")``
would set the ``getcwd`` function of the ``os`` module.
- The ``raising`` value determines if the setattr should fail
- if the attribute is not already present (defaults to True
- which means it will raise).
+ Raises AttributeError if the attribute does not exist, unless
+ ``raising`` is set to False.
"""
__tracebackhide__ = True
import inspect
- if value is notset:
+ if isinstance(value, Notset):
if not isinstance(target, str):
raise TypeError(
"use setattr(target, name, value) or "
@@ -157,10 +198,17 @@ class MonkeyPatch:
)
value = name
name, target = derive_importpath(target, raising)
+ else:
+ if not isinstance(name, str):
+ raise TypeError(
+ "use setattr(target, name, value) with name being a string or "
+ "setattr(target, value) with target being a dotted "
+ "import string"
+ )
oldval = getattr(target, name, notset)
if raising and oldval is notset:
- raise AttributeError("{!r} has no attribute {!r}".format(target, name))
+ raise AttributeError(f"{target!r} has no attribute {name!r}")
# avoid class descriptors like staticmethod/classmethod
if inspect.isclass(target):
@@ -168,21 +216,25 @@ class MonkeyPatch:
self._setattr.append((target, name, oldval))
setattr(target, name, value)
- def delattr(self, target, name=notset, raising=True):
- """ Delete attribute ``name`` from ``target``, by default raise
- AttributeError it the attribute did not previously exist.
+ def delattr(
+ self,
+ target: Union[object, str],
+ name: Union[str, Notset] = notset,
+ raising: bool = True,
+ ) -> None:
+ """Delete attribute ``name`` from ``target``.
If no ``name`` is specified and ``target`` is a string
it will be interpreted as a dotted import path with the
last part being the attribute name.
- If ``raising`` is set to False, no exception will be raised if the
- attribute is missing.
+ Raises AttributeError it the attribute does not exist, unless
+ ``raising`` is set to False.
"""
__tracebackhide__ = True
import inspect
- if name is notset:
+ if isinstance(name, Notset):
if not isinstance(target, str):
raise TypeError(
"use delattr(target, name) or "
@@ -202,16 +254,16 @@ class MonkeyPatch:
self._setattr.append((target, name, oldval))
delattr(target, name)
- def setitem(self, dic, name, value):
- """ Set dictionary entry ``name`` to value. """
+ def setitem(self, dic: MutableMapping[K, V], name: K, value: V) -> None:
+ """Set dictionary entry ``name`` to value."""
self._setitem.append((dic, name, dic.get(name, notset)))
dic[name] = value
- def delitem(self, dic, name, raising=True):
- """ Delete ``name`` from dict. Raise KeyError if it doesn't exist.
+ def delitem(self, dic: MutableMapping[K, V], name: K, raising: bool = True) -> None:
+ """Delete ``name`` from dict.
- If ``raising`` is set to False, no exception will be raised if the
- key is missing.
+ Raises ``KeyError`` if it doesn't exist, unless ``raising`` is set to
+ False.
"""
if name not in dic:
if raising:
@@ -220,13 +272,16 @@ class MonkeyPatch:
self._setitem.append((dic, name, dic.get(name, notset)))
del dic[name]
- def setenv(self, name, value, prepend=None):
- """ Set environment variable ``name`` to ``value``. If ``prepend``
- is a character, read the current environment variable value
- and prepend the ``value`` adjoined with the ``prepend`` character."""
+ def setenv(self, name: str, value: str, prepend: Optional[str] = None) -> None:
+ """Set environment variable ``name`` to ``value``.
+
+ If ``prepend`` is a character, read the current environment variable
+ value and prepend the ``value`` adjoined with the ``prepend``
+ character.
+ """
if not isinstance(value, str):
- warnings.warn(
- pytest.PytestWarning(
+ warnings.warn( # type: ignore[unreachable]
+ PytestWarning(
"Value of environment variable {name} type should be str, but got "
"{value!r} (type: {type}); converted to str implicitly".format(
name=name, value=value, type=type(value).__name__
@@ -239,17 +294,17 @@ class MonkeyPatch:
value = value + prepend + os.environ[name]
self.setitem(os.environ, name, value)
- def delenv(self, name, raising=True):
- """ Delete ``name`` from the environment. Raise KeyError if it does
- not exist.
+ def delenv(self, name: str, raising: bool = True) -> None:
+ """Delete ``name`` from the environment.
- If ``raising`` is set to False, no exception will be raised if the
- environment variable is missing.
+ Raises ``KeyError`` if it does not exist, unless ``raising`` is set to
+ False.
"""
- self.delitem(os.environ, name, raising=raising)
+ environ: MutableMapping[str, str] = os.environ
+ self.delitem(environ, name, raising=raising)
- def syspath_prepend(self, path):
- """ Prepend ``path`` to ``sys.path`` list of import locations. """
+ def syspath_prepend(self, path) -> None:
+ """Prepend ``path`` to ``sys.path`` list of import locations."""
from pkg_resources import fixup_namespace_packages
if self._savesyspath is None:
@@ -270,8 +325,9 @@ class MonkeyPatch:
invalidate_caches()
- def chdir(self, path):
- """ Change the current working directory to the specified path.
+ def chdir(self, path) -> None:
+ """Change the current working directory to the specified path.
+
Path can be a string or a py.path.local object.
"""
if self._cwd is None:
@@ -279,15 +335,16 @@ class MonkeyPatch:
if hasattr(path, "chdir"):
path.chdir()
elif isinstance(path, Path):
- # modern python uses the fspath protocol here LEGACY
+ # Modern python uses the fspath protocol here LEGACY
os.chdir(str(path))
else:
os.chdir(path)
- def undo(self):
- """ Undo previous changes. This call consumes the
- undo stack. Calling it a second time has no effect unless
- you do more monkeypatching after the undo call.
+ def undo(self) -> None:
+ """Undo previous changes.
+
+ This call consumes the undo stack. Calling it a second time has no
+ effect unless you do more monkeypatching after the undo call.
There is generally no need to call `undo()`, since it is
called automatically during tear-down.
@@ -304,14 +361,14 @@ class MonkeyPatch:
else:
delattr(obj, name)
self._setattr[:] = []
- for dictionary, name, value in reversed(self._setitem):
+ for dictionary, key, value in reversed(self._setitem):
if value is notset:
try:
- del dictionary[name]
+ del dictionary[key]
except KeyError:
- pass # was already deleted, so we have the desired state
+ pass # Was already deleted, so we have the desired state.
else:
- dictionary[name] = value
+ dictionary[key] = value
self._setitem[:] = []
if self._savesyspath is not None:
sys.path[:] = self._savesyspath
diff --git a/contrib/python/pytest/py3/_pytest/nodes.py b/contrib/python/pytest/py3/_pytest/nodes.py
index 6f22a8daaa..27434fb6a6 100644
--- a/contrib/python/pytest/py3/_pytest/nodes.py
+++ b/contrib/python/pytest/py3/_pytest/nodes.py
@@ -1,121 +1,143 @@
import os
import warnings
-from functools import lru_cache
-from typing import Any
-from typing import Dict
+from pathlib import Path
+from typing import Callable
+from typing import Iterable
+from typing import Iterator
from typing import List
from typing import Optional
+from typing import overload
from typing import Set
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import py
import _pytest._code
-from _pytest._code.code import ExceptionChainRepr
+from _pytest._code import getfslineno
from _pytest._code.code import ExceptionInfo
-from _pytest._code.code import ReprExceptionInfo
-from _pytest._code.source import getfslineno
+from _pytest._code.code import TerminalRepr
from _pytest.compat import cached_property
-from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import ConftestImportFailure
-from _pytest.config import PytestPluginManager
-from _pytest.deprecated import NODE_USE_FROM_PARENT
-from _pytest.fixtures import FixtureDef
-from _pytest.fixtures import FixtureLookupError
-from _pytest.fixtures import FixtureLookupErrorRepr
+from _pytest.deprecated import FSCOLLECTOR_GETHOOKPROXY_ISINITPATH
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import NodeKeywords
from _pytest.outcomes import fail
-from _pytest.pathlib import Path
+from _pytest.pathlib import absolutepath
from _pytest.store import Store
if TYPE_CHECKING:
# Imported here due to circular import.
- from _pytest.main import Session # noqa: F401
+ from _pytest.main import Session
+ from _pytest._code.code import _TracebackStyle
+
SEP = "/"
tracebackcutdir = py.path.local(_pytest.__file__).dirpath()
-@lru_cache(maxsize=None)
-def _splitnode(nodeid):
- """Split a nodeid into constituent 'parts'.
+def iterparentnodeids(nodeid: str) -> Iterator[str]:
+ """Return the parent node IDs of a given node ID, inclusive.
- Node IDs are strings, and can be things like:
- ''
- 'testing/code'
- 'testing/code/test_excinfo.py'
- 'testing/code/test_excinfo.py::TestFormattedExcinfo'
+ For the node ID
- Return values are lists e.g.
- []
- ['testing', 'code']
- ['testing', 'code', 'test_excinfo.py']
- ['testing', 'code', 'test_excinfo.py', 'TestFormattedExcinfo']
- """
- if nodeid == "":
- # If there is no root node at all, return an empty list so the caller's logic can remain sane
- return ()
- parts = nodeid.split(SEP)
- # Replace single last element 'test_foo.py::Bar' with multiple elements 'test_foo.py', 'Bar'
- parts[-1:] = parts[-1].split("::")
- # Convert parts into a tuple to avoid possible errors with caching of a mutable type
- return tuple(parts)
+ "testing/code/test_excinfo.py::TestFormattedExcinfo::test_repr_source"
+ the result would be
-def ischildnode(baseid, nodeid):
- """Return True if the nodeid is a child node of the baseid.
+ ""
+ "testing"
+ "testing/code"
+ "testing/code/test_excinfo.py"
+ "testing/code/test_excinfo.py::TestFormattedExcinfo"
+ "testing/code/test_excinfo.py::TestFormattedExcinfo::test_repr_source"
- E.g. 'foo/bar::Baz' is a child of 'foo', 'foo/bar' and 'foo/bar::Baz', but not of 'foo/blorp'
+ Note that :: parts are only considered at the last / component.
"""
- base_parts = _splitnode(baseid)
- node_parts = _splitnode(nodeid)
- if len(node_parts) < len(base_parts):
- return False
- return node_parts[: len(base_parts)] == base_parts
+ pos = 0
+ sep = SEP
+ yield ""
+ while True:
+ at = nodeid.find(sep, pos)
+ if at == -1 and sep == SEP:
+ sep = "::"
+ elif at == -1:
+ if nodeid:
+ yield nodeid
+ break
+ else:
+ if at:
+ yield nodeid[:at]
+ pos = at + len(sep)
+
+
+_NodeType = TypeVar("_NodeType", bound="Node")
class NodeMeta(type):
def __call__(self, *k, **kw):
- warnings.warn(NODE_USE_FROM_PARENT.format(name=self.__name__), stacklevel=2)
- return super().__call__(*k, **kw)
+ msg = (
+ "Direct construction of {name} has been deprecated, please use {name}.from_parent.\n"
+ "See "
+ "https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent"
+ " for more details."
+ ).format(name=self.__name__)
+ fail(msg, pytrace=False)
def _create(self, *k, **kw):
return super().__call__(*k, **kw)
class Node(metaclass=NodeMeta):
- """ base class for Collector and Item the test collection tree.
- Collector subclasses have children, Items are terminal nodes."""
+ """Base class for Collector and Item, the components of the test
+ collection tree.
+
+ Collector subclasses have children; Items are leaf nodes.
+ """
+
+ # Use __slots__ to make attribute access faster.
+ # Note that __dict__ is still available.
+ __slots__ = (
+ "name",
+ "parent",
+ "config",
+ "session",
+ "fspath",
+ "_nodeid",
+ "_store",
+ "__dict__",
+ )
def __init__(
self,
name: str,
- parent: Optional["Node"] = None,
+ parent: "Optional[Node]" = None,
config: Optional[Config] = None,
- session: Optional["Session"] = None,
+ session: "Optional[Session]" = None,
fspath: Optional[py.path.local] = None,
nodeid: Optional[str] = None,
) -> None:
- #: a unique name within the scope of the parent node
+ #: A unique name within the scope of the parent node.
self.name = name
- #: the parent collector node.
+ #: The parent collector node.
self.parent = parent
- #: the pytest config object
+ #: The pytest config object.
if config:
- self.config = config
+ self.config: Config = config
else:
if not parent:
raise TypeError("config or parent must be provided")
self.config = parent.config
- #: the session this node is part of
+ #: The pytest session this node is part of.
if session:
self.session = session
else:
@@ -123,20 +145,17 @@ class Node(metaclass=NodeMeta):
raise TypeError("session or parent must be provided")
self.session = parent.session
- #: filesystem path where this node was collected from (can be None)
+ #: Filesystem path where this node was collected from (can be None).
self.fspath = fspath or getattr(parent, "fspath", None)
- #: keywords/markers collected from all scopes
+ #: Keywords/markers collected from all scopes.
self.keywords = NodeKeywords(self)
- #: the marker objects belonging to this node
- self.own_markers = [] # type: List[Mark]
-
- #: allow adding of extra keywords to use for matching
- self.extra_keyword_matches = set() # type: Set[str]
+ #: The marker objects belonging to this node.
+ self.own_markers: List[Mark] = []
- # used for storing artificial fixturedefs for direct parametrization
- self._name2pseudofixturedef = {} # type: Dict[str, FixtureDef]
+ #: Allow adding of extra keywords to use for matching.
+ self.extra_keyword_matches: Set[str] = set()
if nodeid is not None:
assert "::()" not in nodeid
@@ -154,15 +173,15 @@ class Node(metaclass=NodeMeta):
@classmethod
def from_parent(cls, parent: "Node", **kw):
- """
- Public Constructor for Nodes
+ """Public constructor for Nodes.
This indirection got introduced in order to enable removing
the fragile logic from the node constructors.
- Subclasses can use ``super().from_parent(...)`` when overriding the construction
+ Subclasses can use ``super().from_parent(...)`` when overriding the
+ construction.
- :param parent: the parent node of this test Node
+ :param parent: The parent node of this Node.
"""
if "config" in kw:
raise TypeError("config is not a valid argument for from_parent")
@@ -172,64 +191,67 @@ class Node(metaclass=NodeMeta):
@property
def ihook(self):
- """ fspath sensitive hook proxy used to call pytest hooks"""
+ """fspath-sensitive hook proxy used to call pytest hooks."""
return self.session.gethookproxy(self.fspath)
- def __repr__(self):
+ def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None))
- def warn(self, warning):
- """Issue a warning for this item.
+ def warn(self, warning: Warning) -> None:
+ """Issue a warning for this Node.
- Warnings will be displayed after the test session, unless explicitly suppressed
+ Warnings will be displayed after the test session, unless explicitly suppressed.
- :param Warning warning: the warning instance to issue. Must be a subclass of PytestWarning.
+ :param Warning warning:
+ The warning instance to issue.
- :raise ValueError: if ``warning`` instance is not a subclass of PytestWarning.
+ :raises ValueError: If ``warning`` instance is not a subclass of Warning.
Example usage:
.. code-block:: python
node.warn(PytestWarning("some message"))
+ node.warn(UserWarning("some message"))
+ .. versionchanged:: 6.2
+ Any subclass of :class:`Warning` is now accepted, rather than only
+ :class:`PytestWarning <pytest.PytestWarning>` subclasses.
"""
- from _pytest.warning_types import PytestWarning
-
- if not isinstance(warning, PytestWarning):
+ # enforce type checks here to avoid getting a generic type error later otherwise.
+ if not isinstance(warning, Warning):
raise ValueError(
- "warning must be an instance of PytestWarning or subclass, got {!r}".format(
+ "warning must be an instance of Warning or subclass, got {!r}".format(
warning
)
)
path, lineno = get_fslocation_from_item(self)
+ assert lineno is not None
warnings.warn_explicit(
- warning,
- category=None,
- filename=str(path),
- lineno=lineno + 1 if lineno is not None else None,
+ warning, category=None, filename=str(path), lineno=lineno + 1,
)
- # methods for ordering nodes
+ # Methods for ordering nodes.
+
@property
- def nodeid(self):
- """ a ::-separated string denoting its collection tree address. """
+ def nodeid(self) -> str:
+ """A ::-separated string denoting its collection tree address."""
return self._nodeid
- def __hash__(self):
- return hash(self.nodeid)
+ def __hash__(self) -> int:
+ return hash(self._nodeid)
- def setup(self):
+ def setup(self) -> None:
pass
- def teardown(self):
+ def teardown(self) -> None:
pass
- def listchain(self):
- """ return list of all parent collectors up to self,
- starting from root of collection tree. """
+ def listchain(self) -> List["Node"]:
+ """Return list of all parent collectors up to self, starting from
+ the root of collection tree."""
chain = []
- item = self # type: Optional[Node]
+ item: Optional[Node] = self
while item is not None:
chain.append(item)
item = item.parent
@@ -239,12 +261,10 @@ class Node(metaclass=NodeMeta):
def add_marker(
self, marker: Union[str, MarkDecorator], append: bool = True
) -> None:
- """dynamically add a marker object to the node.
+ """Dynamically add a marker object to the node.
- :type marker: ``str`` or ``pytest.mark.*`` object
- :param marker:
- ``append=True`` whether to append the marker,
- if ``False`` insert at position ``0``.
+ :param append:
+ Whether to append the marker, or prepend it.
"""
from _pytest.mark import MARK_GEN
@@ -254,78 +274,93 @@ class Node(metaclass=NodeMeta):
marker_ = getattr(MARK_GEN, marker)
else:
raise ValueError("is not a string or pytest.mark.* Marker")
- self.keywords[marker_.name] = marker
+ self.keywords[marker_.name] = marker_
if append:
self.own_markers.append(marker_.mark)
else:
self.own_markers.insert(0, marker_.mark)
- def iter_markers(self, name=None):
- """
- :param name: if given, filter the results by the name attribute
+ def iter_markers(self, name: Optional[str] = None) -> Iterator[Mark]:
+ """Iterate over all markers of the node.
- iterate over all markers of the node
+ :param name: If given, filter the results by the name attribute.
"""
return (x[1] for x in self.iter_markers_with_node(name=name))
- def iter_markers_with_node(self, name=None):
- """
- :param name: if given, filter the results by the name attribute
+ def iter_markers_with_node(
+ self, name: Optional[str] = None
+ ) -> Iterator[Tuple["Node", Mark]]:
+ """Iterate over all markers of the node.
- iterate over all markers of the node
- returns sequence of tuples (node, mark)
+ :param name: If given, filter the results by the name attribute.
+ :returns: An iterator of (node, mark) tuples.
"""
for node in reversed(self.listchain()):
for mark in node.own_markers:
if name is None or getattr(mark, "name", None) == name:
yield node, mark
- def get_closest_marker(self, name, default=None):
- """return the first marker matching the name, from closest (for example function) to farther level (for example
- module level).
+ @overload
+ def get_closest_marker(self, name: str) -> Optional[Mark]:
+ ...
+
+ @overload
+ def get_closest_marker(self, name: str, default: Mark) -> Mark:
+ ...
+
+ def get_closest_marker(
+ self, name: str, default: Optional[Mark] = None
+ ) -> Optional[Mark]:
+ """Return the first marker matching the name, from closest (for
+ example function) to farther level (for example module level).
- :param default: fallback return value of no marker was found
- :param name: name to filter by
+ :param default: Fallback return value if no marker was found.
+ :param name: Name to filter by.
"""
return next(self.iter_markers(name=name), default)
- def listextrakeywords(self):
- """ Return a set of all extra keywords in self and any parents."""
- extra_keywords = set() # type: Set[str]
+ def listextrakeywords(self) -> Set[str]:
+ """Return a set of all extra keywords in self and any parents."""
+ extra_keywords: Set[str] = set()
for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches)
return extra_keywords
- def listnames(self):
+ def listnames(self) -> List[str]:
return [x.name for x in self.listchain()]
- def addfinalizer(self, fin):
- """ register a function to be called when this node is finalized.
+ def addfinalizer(self, fin: Callable[[], object]) -> None:
+ """Register a function to be called when this node is finalized.
This method can only be called when this node is active
in a setup chain, for example during self.setup().
"""
self.session._setupstate.addfinalizer(fin, self)
- def getparent(self, cls):
- """ get the next parent node (including ourself)
- which is an instance of the given class"""
- current = self # type: Optional[Node]
+ def getparent(self, cls: Type[_NodeType]) -> Optional[_NodeType]:
+ """Get the next parent node (including self) which is an instance of
+ the given class."""
+ current: Optional[Node] = self
while current and not isinstance(current, cls):
current = current.parent
+ assert current is None or isinstance(current, cls)
return current
- def _prunetraceback(self, excinfo):
+ def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
pass
def _repr_failure_py(
- self, excinfo: ExceptionInfo[BaseException], style=None,
- ) -> Union[str, ReprExceptionInfo, ExceptionChainRepr, FixtureLookupErrorRepr]:
+ self,
+ excinfo: ExceptionInfo[BaseException],
+ style: "Optional[_TracebackStyle]" = None,
+ ) -> TerminalRepr:
+ from _pytest.fixtures import FixtureLookupError
+
if isinstance(excinfo.value, ConftestImportFailure):
excinfo = ExceptionInfo(excinfo.value.excinfo)
if isinstance(excinfo.value, fail.Exception):
if not excinfo.value.pytrace:
- return str(excinfo.value)
+ style = "value"
if isinstance(excinfo.value, FixtureLookupError):
return excinfo.value.formatrepr()
if self.config.getoption("fulltrace", False):
@@ -356,7 +391,7 @@ class Node(metaclass=NodeMeta):
# It will be better to just always display paths relative to invocation_dir, but
# this requires a lot of plumbing (#6428).
try:
- abspath = Path(os.getcwd()) != Path(self.config.invocation_dir)
+ abspath = Path(os.getcwd()) != self.config.invocation_params.dir
except OSError:
abspath = True
@@ -370,49 +405,59 @@ class Node(metaclass=NodeMeta):
)
def repr_failure(
- self, excinfo, style=None
- ) -> Union[str, ReprExceptionInfo, ExceptionChainRepr, FixtureLookupErrorRepr]:
+ self,
+ excinfo: ExceptionInfo[BaseException],
+ style: "Optional[_TracebackStyle]" = None,
+ ) -> Union[str, TerminalRepr]:
+ """Return a representation of a collection or test failure.
+
+ :param excinfo: Exception information for the failure.
+ """
return self._repr_failure_py(excinfo, style)
def get_fslocation_from_item(
- item: "Item",
+ node: "Node",
) -> Tuple[Union[str, py.path.local], Optional[int]]:
- """Tries to extract the actual location from an item, depending on available attributes:
+ """Try to extract the actual location from a node, depending on available attributes:
- * "fslocation": a pair (path, lineno)
- * "obj": a Python object that the item wraps.
+ * "location": a pair (path, lineno)
+ * "obj": a Python object that the node wraps.
* "fspath": just a path
- :rtype: a tuple of (str|LocalPath, int) with filename and line number.
+ :rtype: A tuple of (str|py.path.local, int) with filename and line number.
"""
- try:
- return item.location[:2]
- except AttributeError:
- pass
- obj = getattr(item, "obj", None)
+ # See Item.location.
+ location: Optional[Tuple[str, Optional[int], str]] = getattr(node, "location", None)
+ if location is not None:
+ return location[:2]
+ obj = getattr(node, "obj", None)
if obj is not None:
return getfslineno(obj)
- return getattr(item, "fspath", "unknown location"), -1
+ return getattr(node, "fspath", "unknown location"), -1
class Collector(Node):
- """ Collector instances create children through collect()
- and thus iteratively build a tree.
- """
+ """Collector instances create children through collect() and thus
+ iteratively build a tree."""
class CollectError(Exception):
- """ an error during collection, contains a custom message. """
+ """An error during collection, contains a custom message."""
- def collect(self):
- """ returns a list of children (items and collectors)
- for this collection node.
- """
+ def collect(self) -> Iterable[Union["Item", "Collector"]]:
+ """Return a list of children (items and collectors) for this
+ collection node."""
raise NotImplementedError("abstract")
- def repr_failure(self, excinfo):
- """ represent a collection failure. """
- if excinfo.errisinstance(self.CollectError) and not self.config.getoption(
+ # TODO: This omits the style= parameter which breaks Liskov Substitution.
+ def repr_failure( # type: ignore[override]
+ self, excinfo: ExceptionInfo[BaseException]
+ ) -> Union[str, TerminalRepr]:
+ """Return a representation of a collection failure.
+
+ :param excinfo: Exception information for the failure.
+ """
+ if isinstance(excinfo.value, self.CollectError) and not self.config.getoption(
"fulltrace", False
):
exc = excinfo.value
@@ -426,7 +471,7 @@ class Collector(Node):
return self._repr_failure_py(excinfo, style=tbstyle)
- def _prunetraceback(self, excinfo):
+ def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
if hasattr(self, "fspath"):
traceback = excinfo.traceback
ntraceback = traceback.cut(path=self.fspath)
@@ -441,23 +486,14 @@ def _check_initialpaths_for_relpath(session, fspath):
return fspath.relto(initial_path)
-class FSHookProxy:
- def __init__(
- self, fspath: py.path.local, pm: PytestPluginManager, remove_mods
- ) -> None:
- self.fspath = fspath
- self.pm = pm
- self.remove_mods = remove_mods
-
- def __getattr__(self, name: str):
- x = self.pm.subset_hook_caller(name, remove_plugins=self.remove_mods)
- self.__dict__[name] = x
- return x
-
-
class FSCollector(Collector):
def __init__(
- self, fspath: py.path.local, parent=None, config=None, session=None, nodeid=None
+ self,
+ fspath: py.path.local,
+ parent=None,
+ config: Optional[Config] = None,
+ session: Optional["Session"] = None,
+ nodeid: Optional[str] = None,
) -> None:
name = fspath.basename
if parent is not None:
@@ -479,91 +515,56 @@ class FSCollector(Collector):
super().__init__(name, parent, config, session, nodeid=nodeid, fspath=fspath)
- self._norecursepatterns = self.config.getini("norecursedirs")
-
@classmethod
def from_parent(cls, parent, *, fspath, **kw):
- """
- The public constructor
- """
+ """The public constructor."""
return super().from_parent(parent=parent, fspath=fspath, **kw)
- def _gethookproxy(self, fspath: py.path.local):
- # check if we have the common case of running
- # hooks with all conftest.py files
- pm = self.config.pluginmanager
- my_conftestmodules = pm._getconftestmodules(fspath)
- remove_mods = pm._conftest_plugins.difference(my_conftestmodules)
- if remove_mods:
- # one or more conftests are not in use at this fspath
- proxy = FSHookProxy(fspath, pm, remove_mods)
- else:
- # all plugins are active for this fspath
- proxy = self.config.hook
- return proxy
-
- def _recurse(self, dirpath: py.path.local) -> bool:
- if dirpath.basename == "__pycache__":
- return False
- ihook = self._gethookproxy(dirpath.dirpath())
- if ihook.pytest_ignore_collect(path=dirpath, config=self.config):
- return False
- for pat in self._norecursepatterns:
- if dirpath.check(fnmatch=pat):
- return False
- ihook = self._gethookproxy(dirpath)
- ihook.pytest_collect_directory(path=dirpath, parent=self)
- return True
-
- def _collectfile(self, path, handle_dupes=True):
- assert (
- path.isfile()
- ), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format(
- path, path.isdir(), path.exists(), path.islink()
- )
- ihook = self.gethookproxy(path)
- if not self.isinitpath(path):
- if ihook.pytest_ignore_collect(path=path, config=self.config):
- return ()
-
- if handle_dupes:
- keepduplicates = self.config.getoption("keepduplicates")
- if not keepduplicates:
- duplicate_paths = self.config.pluginmanager._duplicatepaths
- if path in duplicate_paths:
- return ()
- else:
- duplicate_paths.add(path)
+ def gethookproxy(self, fspath: py.path.local):
+ warnings.warn(FSCOLLECTOR_GETHOOKPROXY_ISINITPATH, stacklevel=2)
+ return self.session.gethookproxy(fspath)
- return ihook.pytest_collect_file(path=path, parent=self)
+ def isinitpath(self, path: py.path.local) -> bool:
+ warnings.warn(FSCOLLECTOR_GETHOOKPROXY_ISINITPATH, stacklevel=2)
+ return self.session.isinitpath(path)
class File(FSCollector):
- """ base class for collecting tests from a file. """
+ """Base class for collecting tests from a file.
+
+ :ref:`non-python tests`.
+ """
class Item(Node):
- """ a basic test invocation item. Note that for a single function
- there might be multiple test invocation items.
+ """A basic test invocation item.
+
+ Note that for a single function there might be multiple test invocation items.
"""
nextitem = None
- def __init__(self, name, parent=None, config=None, session=None, nodeid=None):
+ def __init__(
+ self,
+ name,
+ parent=None,
+ config: Optional[Config] = None,
+ session: Optional["Session"] = None,
+ nodeid: Optional[str] = None,
+ ) -> None:
super().__init__(name, parent, config, session, nodeid=nodeid)
- self._report_sections = [] # type: List[Tuple[str, str, str]]
+ self._report_sections: List[Tuple[str, str, str]] = []
- #: user properties is a list of tuples (name, value) that holds user
- #: defined properties for this test.
- self.user_properties = [] # type: List[Tuple[str, Any]]
+ #: A list of tuples (name, value) that holds user defined properties
+ #: for this test.
+ self.user_properties: List[Tuple[str, object]] = []
def runtest(self) -> None:
raise NotImplementedError("runtest must be implemented by Item subclass")
def add_report_section(self, when: str, key: str, content: str) -> None:
- """
- Adds a new report section, similar to what's done internally to add stdout and
- stderr captured output::
+ """Add a new report section, similar to what's done internally to add
+ stdout and stderr captured output::
item.add_report_section("call", "stdout", "report section contents")
@@ -572,7 +573,6 @@ class Item(Node):
:param str key:
Name of the section, can be customized at will. Pytest uses ``"stdout"`` and
``"stderr"`` internally.
-
:param str content:
The full contents as a string.
"""
@@ -585,10 +585,7 @@ class Item(Node):
@cached_property
def location(self) -> Tuple[str, Optional[int], str]:
location = self.reportinfo()
- if isinstance(location[0], py.path.local):
- fspath = location[0]
- else:
- fspath = py.path.local(location[0])
+ fspath = absolutepath(str(location[0]))
relfspath = self.session._node_location_to_relpath(fspath)
assert type(location[2]) is str
return (relfspath, location[1], location[2])
diff --git a/contrib/python/pytest/py3/_pytest/nose.py b/contrib/python/pytest/py3/_pytest/nose.py
index d6f3c2b224..bb8f99772a 100644
--- a/contrib/python/pytest/py3/_pytest/nose.py
+++ b/contrib/python/pytest/py3/_pytest/nose.py
@@ -1,16 +1,17 @@
-""" run test suites written for nose. """
+"""Run testsuites written for nose."""
from _pytest import python
from _pytest import unittest
from _pytest.config import hookimpl
+from _pytest.nodes import Item
@hookimpl(trylast=True)
def pytest_runtest_setup(item):
if is_potential_nosetest(item):
if not call_optional(item.obj, "setup"):
- # call module level setup if there is no object level one
+ # Call module level setup if there is no object level one.
call_optional(item.parent.obj, "setup")
- # XXX this implies we only call teardown when setup worked
+ # XXX This implies we only call teardown when setup worked.
item.session._setupstate.addfinalizer((lambda: teardown_nose(item)), item)
@@ -20,9 +21,9 @@ def teardown_nose(item):
call_optional(item.parent.obj, "teardown")
-def is_potential_nosetest(item):
- # extra check needed since we do not do nose style setup/teardown
- # on direct unittest style classes
+def is_potential_nosetest(item: Item) -> bool:
+ # Extra check needed since we do not do nose style setup/teardown
+ # on direct unittest style classes.
return isinstance(item, python.Function) and not isinstance(
item, unittest.TestCaseFunction
)
@@ -33,6 +34,6 @@ def call_optional(obj, name):
isfixture = hasattr(method, "_pytestfixturefunction")
if method is not None and not isfixture and callable(method):
# If there's any problems allow the exception to raise rather than
- # silently ignoring them
+ # silently ignoring them.
method()
return True
diff --git a/contrib/python/pytest/py3/_pytest/outcomes.py b/contrib/python/pytest/py3/_pytest/outcomes.py
index bed73c94de..8f6203fd7f 100644
--- a/contrib/python/pytest/py3/_pytest/outcomes.py
+++ b/contrib/python/pytest/py3/_pytest/outcomes.py
@@ -1,21 +1,17 @@
-"""
-exception classes and constants handling test outcomes
-as well as functions creating them
-"""
+"""Exception classes and constants handling test outcomes as well as
+functions creating them."""
import sys
from typing import Any
from typing import Callable
from typing import cast
from typing import Optional
+from typing import Type
from typing import TypeVar
-from packaging.version import Version
-
-TYPE_CHECKING = False # avoid circular import through compat
+TYPE_CHECKING = False # Avoid circular import through compat.
if TYPE_CHECKING:
from typing import NoReturn
- from typing import Type # noqa: F401 (Used in string type annotation.)
from typing_extensions import Protocol
else:
# typing.Protocol is only available starting from Python 3.8. It is also
@@ -27,13 +23,12 @@ else:
class OutcomeException(BaseException):
- """ OutcomeException and its subclass instances indicate and
- contain info about test and collection outcomes.
- """
+ """OutcomeException and its subclass instances indicate and contain info
+ about test and collection outcomes."""
def __init__(self, msg: Optional[str] = None, pytrace: bool = True) -> None:
if msg is not None and not isinstance(msg, str):
- error_msg = (
+ error_msg = ( # type: ignore[unreachable]
"{} expected string as 'msg' parameter, got '{}' instead.\n"
"Perhaps you meant to use a mark?"
)
@@ -43,9 +38,9 @@ class OutcomeException(BaseException):
self.pytrace = pytrace
def __repr__(self) -> str:
- if self.msg:
+ if self.msg is not None:
return self.msg
- return "<{} instance>".format(self.__class__.__name__)
+ return f"<{self.__class__.__name__} instance>"
__str__ = __repr__
@@ -69,13 +64,13 @@ class Skipped(OutcomeException):
class Failed(OutcomeException):
- """ raised from an explicit call to pytest.fail() """
+ """Raised from an explicit call to pytest.fail()."""
__module__ = "builtins"
class Exit(Exception):
- """ raised for immediate program exits (no tracebacks/summaries)"""
+ """Raised for immediate program exits (no tracebacks/summaries)."""
def __init__(
self, msg: str = "unknown reason", returncode: Optional[int] = None
@@ -88,13 +83,13 @@ class Exit(Exception):
# Elaborate hack to work around https://github.com/python/mypy/issues/2087.
# Ideally would just be `exit.Exception = Exit` etc.
-_F = TypeVar("_F", bound=Callable)
-_ET = TypeVar("_ET", bound="Type[BaseException]")
+_F = TypeVar("_F", bound=Callable[..., object])
+_ET = TypeVar("_ET", bound=Type[BaseException])
class _WithException(Protocol[_F, _ET]):
- Exception = None # type: _ET
- __call__ = None # type: _F
+ Exception: _ET
+ __call__: _F
def _with_exception(exception_type: _ET) -> Callable[[_F], _WithException[_F, _ET]]:
@@ -106,16 +101,15 @@ def _with_exception(exception_type: _ET) -> Callable[[_F], _WithException[_F, _E
return decorate
-# exposed helper methods
+# Exposed helper methods.
@_with_exception(Exit)
def exit(msg: str, returncode: Optional[int] = None) -> "NoReturn":
- """
- Exit testing process.
+ """Exit testing process.
- :param str msg: message to display upon exit.
- :param int returncode: return code to be used when exiting pytest.
+ :param str msg: Message to display upon exit.
+ :param int returncode: Return code to be used when exiting pytest.
"""
__tracebackhide__ = True
raise Exit(msg, returncode)
@@ -123,20 +117,20 @@ def exit(msg: str, returncode: Optional[int] = None) -> "NoReturn":
@_with_exception(Skipped)
def skip(msg: str = "", *, allow_module_level: bool = False) -> "NoReturn":
- """
- Skip an executing test with the given message.
+ """Skip an executing test with the given message.
This function should be called only during testing (setup, call or teardown) or
during collection by using the ``allow_module_level`` flag. This function can
be called in doctests as well.
- :kwarg bool allow_module_level: allows this function to be called at
- module level, skipping the rest of the module. Default to False.
+ :param bool allow_module_level:
+ Allows this function to be called at module level, skipping the rest
+ of the module. Defaults to False.
.. note::
- It is better to use the :ref:`pytest.mark.skipif ref` marker when possible to declare a test to be
- skipped under certain conditions like mismatching platforms or
- dependencies.
+ It is better to use the :ref:`pytest.mark.skipif ref` marker when
+ possible to declare a test to be skipped under certain conditions
+ like mismatching platforms or dependencies.
Similarly, use the ``# doctest: +SKIP`` directive (see `doctest.SKIP
<https://docs.python.org/3/library/doctest.html#doctest.SKIP>`_)
to skip a doctest statically.
@@ -147,11 +141,12 @@ def skip(msg: str = "", *, allow_module_level: bool = False) -> "NoReturn":
@_with_exception(Failed)
def fail(msg: str = "", pytrace: bool = True) -> "NoReturn":
- """
- Explicitly fail an executing test with the given message.
+ """Explicitly fail an executing test with the given message.
- :param str msg: the message to show the user as reason for the failure.
- :param bool pytrace: if false the msg represents the full failure information and no
+ :param str msg:
+ The message to show the user as reason for the failure.
+ :param bool pytrace:
+ If False, msg represents the full failure information and no
python traceback will be reported.
"""
__tracebackhide__ = True
@@ -159,19 +154,19 @@ def fail(msg: str = "", pytrace: bool = True) -> "NoReturn":
class XFailed(Failed):
- """ raised from an explicit call to pytest.xfail() """
+ """Raised from an explicit call to pytest.xfail()."""
@_with_exception(XFailed)
def xfail(reason: str = "") -> "NoReturn":
- """
- Imperatively xfail an executing test or setup functions with the given reason.
+ """Imperatively xfail an executing test or setup function with the given reason.
This function should be called only during testing (setup, call or teardown).
.. note::
- It is better to use the :ref:`pytest.mark.xfail ref` marker when possible to declare a test to be
- xfailed under certain conditions like known bugs or missing features.
+ It is better to use the :ref:`pytest.mark.xfail ref` marker when
+ possible to declare a test to be xfailed under certain conditions
+ like known bugs or missing features.
"""
__tracebackhide__ = True
raise XFailed(reason)
@@ -180,17 +175,20 @@ def xfail(reason: str = "") -> "NoReturn":
def importorskip(
modname: str, minversion: Optional[str] = None, reason: Optional[str] = None
) -> Any:
- """Imports and returns the requested module ``modname``, or skip the
+ """Import and return the requested module ``modname``, or skip the
current test if the module cannot be imported.
- :param str modname: the name of the module to import
- :param str minversion: if given, the imported module's ``__version__``
- attribute must be at least this minimal version, otherwise the test is
- still skipped.
- :param str reason: if given, this reason is shown as the message when the
- module cannot be imported.
- :returns: The imported module. This should be assigned to its canonical
- name.
+ :param str modname:
+ The name of the module to import.
+ :param str minversion:
+ If given, the imported module's ``__version__`` attribute must be at
+ least this minimal version, otherwise the test is still skipped.
+ :param str reason:
+ If given, this reason is shown as the message when the module cannot
+ be imported.
+
+ :returns:
+ The imported module. This should be assigned to its canonical name.
Example::
@@ -202,21 +200,24 @@ def importorskip(
compile(modname, "", "eval") # to catch syntaxerrors
with warnings.catch_warnings():
- # make sure to ignore ImportWarnings that might happen because
+ # Make sure to ignore ImportWarnings that might happen because
# of existing directories with the same name we're trying to
- # import but without a __init__.py file
+ # import but without a __init__.py file.
warnings.simplefilter("ignore")
try:
__import__(modname)
except ImportError as exc:
if reason is None:
- reason = "could not import {!r}: {}".format(modname, exc)
+ reason = f"could not import {modname!r}: {exc}"
raise Skipped(reason, allow_module_level=True) from None
mod = sys.modules[modname]
if minversion is None:
return mod
verattr = getattr(mod, "__version__", None)
if minversion is not None:
+ # Imported lazily to improve start-up time.
+ from packaging.version import Version
+
if verattr is None or Version(verattr) < Version(minversion):
raise Skipped(
"module %r has __version__ %r, required is: %r"
diff --git a/contrib/python/pytest/py3/_pytest/pastebin.py b/contrib/python/pytest/py3/_pytest/pastebin.py
index 3f4a7502d5..131873c174 100644
--- a/contrib/python/pytest/py3/_pytest/pastebin.py
+++ b/contrib/python/pytest/py3/_pytest/pastebin.py
@@ -1,15 +1,21 @@
-""" submit failure or test session information to a pastebin service. """
+"""Submit failure or test session information to a pastebin service."""
import tempfile
+from io import StringIO
from typing import IO
+from typing import Union
import pytest
+from _pytest.config import Config
+from _pytest.config import create_terminal_writer
+from _pytest.config.argparsing import Parser
from _pytest.store import StoreKey
+from _pytest.terminal import TerminalReporter
pastebinfile_key = StoreKey[IO[bytes]]()
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group._addoption(
"--pastebin",
@@ -23,14 +29,14 @@ def pytest_addoption(parser):
@pytest.hookimpl(trylast=True)
-def pytest_configure(config):
+def pytest_configure(config: Config) -> None:
if config.option.pastebin == "all":
tr = config.pluginmanager.getplugin("terminalreporter")
- # if no terminal reporter plugin is present, nothing we can do here;
- # this can happen when this function executes in a slave node
- # when using pytest-xdist, for example
+ # If no terminal reporter plugin is present, nothing we can do here;
+ # this can happen when this function executes in a worker node
+ # when using pytest-xdist, for example.
if tr is not None:
- # pastebin file will be utf-8 encoded binary file
+ # pastebin file will be UTF-8 encoded binary file.
config._store[pastebinfile_key] = tempfile.TemporaryFile("w+b")
oldwrite = tr._tw.write
@@ -43,29 +49,28 @@ def pytest_configure(config):
tr._tw.write = tee_write
-def pytest_unconfigure(config):
+def pytest_unconfigure(config: Config) -> None:
if pastebinfile_key in config._store:
pastebinfile = config._store[pastebinfile_key]
- # get terminal contents and delete file
+ # Get terminal contents and delete file.
pastebinfile.seek(0)
sessionlog = pastebinfile.read()
pastebinfile.close()
del config._store[pastebinfile_key]
- # undo our patching in the terminal reporter
+ # Undo our patching in the terminal reporter.
tr = config.pluginmanager.getplugin("terminalreporter")
del tr._tw.__dict__["write"]
- # write summary
+ # Write summary.
tr.write_sep("=", "Sending information to Paste Service")
pastebinurl = create_new_paste(sessionlog)
tr.write_line("pastebin session-log: %s\n" % pastebinurl)
-def create_new_paste(contents):
- """
- Creates a new paste using bpaste.net service.
+def create_new_paste(contents: Union[str, bytes]) -> str:
+ """Create a new paste using the bpaste.net service.
- :contents: paste contents as utf-8 encoded bytes
- :returns: url to the pasted contents or error message
+ :contents: Paste contents string.
+ :returns: URL to the pasted contents, or an error message.
"""
import re
from urllib.request import urlopen
@@ -74,7 +79,7 @@ def create_new_paste(contents):
params = {"code": contents, "lexer": "text", "expiry": "1week"}
url = "https://bpaste.net"
try:
- response = (
+ response: str = (
urlopen(url, data=urlencode(params).encode("ascii")).read().decode("utf-8")
)
except OSError as exc_info: # urllib errors
@@ -86,24 +91,20 @@ def create_new_paste(contents):
return "bad response: invalid format ('" + response + "')"
-def pytest_terminal_summary(terminalreporter):
- import _pytest.config
-
+def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
if terminalreporter.config.option.pastebin != "failed":
return
- tr = terminalreporter
- if "failed" in tr.stats:
+ if "failed" in terminalreporter.stats:
terminalreporter.write_sep("=", "Sending information to Paste Service")
- for rep in terminalreporter.stats.get("failed"):
+ for rep in terminalreporter.stats["failed"]:
try:
msg = rep.longrepr.reprtraceback.reprentries[-1].reprfileloc
except AttributeError:
- msg = tr._getfailureheadline(rep)
- tw = _pytest.config.create_terminal_writer(
- terminalreporter.config, stringio=True
- )
+ msg = terminalreporter._getfailureheadline(rep)
+ file = StringIO()
+ tw = create_terminal_writer(terminalreporter.config, file)
rep.toterminal(tw)
- s = tw.stringio.getvalue()
+ s = file.getvalue()
assert len(s)
pastebinurl = create_new_paste(s)
- tr.write_line("{} --> {}".format(msg, pastebinurl))
+ terminalreporter.write_line(f"{msg} --> {pastebinurl}")
diff --git a/contrib/python/pytest/py3/_pytest/pathlib.py b/contrib/python/pytest/py3/_pytest/pathlib.py
index 2f04b02d7a..7d9269a185 100644
--- a/contrib/python/pytest/py3/_pytest/pathlib.py
+++ b/contrib/python/pytest/py3/_pytest/pathlib.py
@@ -1,69 +1,84 @@
import atexit
+import contextlib
import fnmatch
+import importlib.util
import itertools
import os
import shutil
import sys
import uuid
import warnings
+from enum import Enum
+from errno import EBADF
+from errno import ELOOP
+from errno import ENOENT
+from errno import ENOTDIR
from functools import partial
from os.path import expanduser
from os.path import expandvars
from os.path import isabs
from os.path import sep
+from pathlib import Path
+from pathlib import PurePath
from posixpath import sep as posix_sep
+from types import ModuleType
+from typing import Callable
from typing import Iterable
from typing import Iterator
+from typing import Optional
from typing import Set
from typing import TypeVar
from typing import Union
+import py
+
+from _pytest.compat import assert_never
+from _pytest.outcomes import skip
from _pytest.warning_types import PytestWarning
-if sys.version_info[:2] >= (3, 6):
- from pathlib import Path, PurePath
-else:
- from pathlib2 import Path, PurePath
+LOCK_TIMEOUT = 60 * 60 * 24 * 3
-__all__ = ["Path", "PurePath"]
+_AnyPurePath = TypeVar("_AnyPurePath", bound=PurePath)
-LOCK_TIMEOUT = 60 * 60 * 3
+# The following function, variables and comments were
+# copied from cpython 3.9 Lib/pathlib.py file.
+# EBADF - guard against macOS `stat` throwing EBADF
+_IGNORED_ERRORS = (ENOENT, ENOTDIR, EBADF, ELOOP)
-_AnyPurePath = TypeVar("_AnyPurePath", bound=PurePath)
+_IGNORED_WINERRORS = (
+ 21, # ERROR_NOT_READY - drive exists but is not accessible
+ 1921, # ERROR_CANT_RESOLVE_FILENAME - fix for broken symlink pointing to itself
+)
-def get_lock_path(path: _AnyPurePath) -> _AnyPurePath:
- return path.joinpath(".lock")
+def _ignore_error(exception):
+ return (
+ getattr(exception, "errno", None) in _IGNORED_ERRORS
+ or getattr(exception, "winerror", None) in _IGNORED_WINERRORS
+ )
-def ensure_reset_dir(path: Path) -> None:
- """
- ensures the given path is an empty directory
- """
- if path.exists():
- rm_rf(path)
- path.mkdir()
+def get_lock_path(path: _AnyPurePath) -> _AnyPurePath:
+ return path.joinpath(".lock")
def on_rm_rf_error(func, path: str, exc, *, start_path: Path) -> bool:
- """Handles known read-only errors during rmtree.
+ """Handle known read-only errors during rmtree.
The returned value is used only by our own tests.
"""
exctype, excvalue = exc[:2]
- # another process removed the file in the middle of the "rm_rf" (xdist for example)
- # more context: https://github.com/pytest-dev/pytest/issues/5974#issuecomment-543799018
+ # Another process removed the file in the middle of the "rm_rf" (xdist for example).
+ # More context: https://github.com/pytest-dev/pytest/issues/5974#issuecomment-543799018
if isinstance(excvalue, FileNotFoundError):
return False
if not isinstance(excvalue, PermissionError):
warnings.warn(
- PytestWarning(
- "(rm_rf) error removing {}\n{}: {}".format(path, exctype, excvalue)
- )
+ PytestWarning(f"(rm_rf) error removing {path}\n{exctype}: {excvalue}")
)
return False
@@ -91,7 +106,7 @@ def on_rm_rf_error(func, path: str, exc, *, start_path: Path) -> bool:
if p.is_file():
for parent in p.parents:
chmod_rw(str(parent))
- # stop when we reach the original path passed to rm_rf
+ # Stop when we reach the original path passed to rm_rf.
if parent == start_path:
break
chmod_rw(str(path))
@@ -119,7 +134,7 @@ def ensure_extended_length_path(path: Path) -> Path:
def get_extended_length_path_str(path: str) -> str:
- """Converts to extended length path as a str"""
+ """Convert a path to a Windows extended length path."""
long_path_prefix = "\\\\?\\"
unc_long_path_prefix = "\\\\?\\UNC\\"
if path.startswith((long_path_prefix, unc_long_path_prefix)):
@@ -132,15 +147,14 @@ def get_extended_length_path_str(path: str) -> str:
def rm_rf(path: Path) -> None:
"""Remove the path contents recursively, even if some elements
- are read-only.
- """
+ are read-only."""
path = ensure_extended_length_path(path)
onerror = partial(on_rm_rf_error, start_path=path)
shutil.rmtree(str(path), onerror=onerror)
def find_prefixed(root: Path, prefix: str) -> Iterator[Path]:
- """finds all elements in root that begin with the prefix, case insensitive"""
+ """Find all elements in root that begin with the prefix, case insensitive."""
l_prefix = prefix.lower()
for x in root.iterdir():
if x.name.lower().startswith(l_prefix):
@@ -148,10 +162,10 @@ def find_prefixed(root: Path, prefix: str) -> Iterator[Path]:
def extract_suffixes(iter: Iterable[PurePath], prefix: str) -> Iterator[str]:
- """
- :param iter: iterator over path names
- :param prefix: expected prefix of the path names
- :returns: the parts of the paths following the prefix
+ """Return the parts of the paths following the prefix.
+
+ :param iter: Iterator over path names.
+ :param prefix: Expected prefix of the path names.
"""
p_len = len(prefix)
for p in iter:
@@ -159,13 +173,12 @@ def extract_suffixes(iter: Iterable[PurePath], prefix: str) -> Iterator[str]:
def find_suffixes(root: Path, prefix: str) -> Iterator[str]:
- """combines find_prefixes and extract_suffixes
- """
+ """Combine find_prefixes and extract_suffixes."""
return extract_suffixes(find_prefixed(root, prefix), prefix)
def parse_num(maybe_num) -> int:
- """parses number path suffixes, returns -1 on error"""
+ """Parse number path suffixes, returns -1 on error."""
try:
return int(maybe_num)
except ValueError:
@@ -175,13 +188,13 @@ def parse_num(maybe_num) -> int:
def _force_symlink(
root: Path, target: Union[str, PurePath], link_to: Union[str, Path]
) -> None:
- """helper to create the current symlink
+ """Helper to create the current symlink.
- it's full of race conditions that are reasonably ok to ignore
- for the context of best effort linking to the latest test run
+ It's full of race conditions that are reasonably OK to ignore
+ for the context of best effort linking to the latest test run.
- the presumption being that in case of much parallelism
- the inaccuracy is going to be acceptable
+ The presumption being that in case of much parallelism
+ the inaccuracy is going to be acceptable.
"""
current_symlink = root.joinpath(target)
try:
@@ -194,46 +207,46 @@ def _force_symlink(
pass
-def make_numbered_dir(root: Path, prefix: str) -> Path:
- """create a directory with an increased number as suffix for the given prefix"""
+def make_numbered_dir(root: Path, prefix: str, mode: int = 0o700) -> Path:
+ """Create a directory with an increased number as suffix for the given prefix."""
for i in range(10):
# try up to 10 times to create the folder
max_existing = max(map(parse_num, find_suffixes(root, prefix)), default=-1)
new_number = max_existing + 1
- new_path = root.joinpath("{}{}".format(prefix, new_number))
+ new_path = root.joinpath(f"{prefix}{new_number}")
try:
- new_path.mkdir()
+ new_path.mkdir(mode=mode)
except Exception:
pass
else:
_force_symlink(root, prefix + "current", new_path)
return new_path
else:
- raise EnvironmentError(
+ raise OSError(
"could not create numbered dir with prefix "
"{prefix} in {root} after 10 tries".format(prefix=prefix, root=root)
)
def create_cleanup_lock(p: Path) -> Path:
- """crates a lock to prevent premature folder cleanup"""
+ """Create a lock to prevent premature folder cleanup."""
lock_path = get_lock_path(p)
try:
fd = os.open(str(lock_path), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
except FileExistsError as e:
- raise EnvironmentError("cannot create lockfile in {path}".format(path=p)) from e
+ raise OSError(f"cannot create lockfile in {p}") from e
else:
pid = os.getpid()
spid = str(pid).encode()
os.write(fd, spid)
os.close(fd)
if not lock_path.is_file():
- raise EnvironmentError("lock path got renamed after successful creation")
+ raise OSError("lock path got renamed after successful creation")
return lock_path
def register_cleanup_lock_removal(lock_path: Path, register=atexit.register):
- """registers a cleanup function for removing a lock, by default on atexit"""
+ """Register a cleanup function for removing a lock, by default on atexit."""
pid = os.getpid()
def cleanup_on_exit(lock_path: Path = lock_path, original_pid: int = pid) -> None:
@@ -243,66 +256,76 @@ def register_cleanup_lock_removal(lock_path: Path, register=atexit.register):
return
try:
lock_path.unlink()
- except (OSError, IOError):
+ except OSError:
pass
return register(cleanup_on_exit)
def maybe_delete_a_numbered_dir(path: Path) -> None:
- """removes a numbered directory if its lock can be obtained and it does not seem to be in use"""
+ """Remove a numbered directory if its lock can be obtained and it does
+ not seem to be in use."""
path = ensure_extended_length_path(path)
lock_path = None
try:
lock_path = create_cleanup_lock(path)
parent = path.parent
- garbage = parent.joinpath("garbage-{}".format(uuid.uuid4()))
+ garbage = parent.joinpath(f"garbage-{uuid.uuid4()}")
path.rename(garbage)
rm_rf(garbage)
- except (OSError, EnvironmentError):
+ except OSError:
# known races:
# * other process did a cleanup at the same time
# * deletable folder was found
# * process cwd (Windows)
return
finally:
- # if we created the lock, ensure we remove it even if we failed
- # to properly remove the numbered dir
+ # If we created the lock, ensure we remove it even if we failed
+ # to properly remove the numbered dir.
if lock_path is not None:
try:
lock_path.unlink()
- except (OSError, IOError):
+ except OSError:
pass
def ensure_deletable(path: Path, consider_lock_dead_if_created_before: float) -> bool:
- """checks if a lock exists and breaks it if its considered dead"""
+ """Check if `path` is deletable based on whether the lock file is expired."""
if path.is_symlink():
return False
lock = get_lock_path(path)
- if not lock.exists():
- return True
+ try:
+ if not lock.is_file():
+ return True
+ except OSError:
+ # we might not have access to the lock file at all, in this case assume
+ # we don't have access to the entire directory (#7491).
+ return False
try:
lock_time = lock.stat().st_mtime
except Exception:
return False
else:
if lock_time < consider_lock_dead_if_created_before:
- lock.unlink()
- return True
- else:
- return False
+ # We want to ignore any errors while trying to remove the lock such as:
+ # - PermissionDenied, like the file permissions have changed since the lock creation;
+ # - FileNotFoundError, in case another pytest process got here first;
+ # and any other cause of failure.
+ with contextlib.suppress(OSError):
+ lock.unlink()
+ return True
+ return False
def try_cleanup(path: Path, consider_lock_dead_if_created_before: float) -> None:
- """tries to cleanup a folder if we can ensure it's deletable"""
+ """Try to cleanup a folder if we can ensure it's deletable."""
if ensure_deletable(path, consider_lock_dead_if_created_before):
maybe_delete_a_numbered_dir(path)
def cleanup_candidates(root: Path, prefix: str, keep: int) -> Iterator[Path]:
- """lists candidates for numbered directories to be removed - follows py.path"""
+ """List candidates for numbered directories to be removed - follows py.path."""
max_existing = max(map(parse_num, find_suffixes(root, prefix)), default=-1)
max_delete = max_existing - keep
paths = find_prefixed(root, prefix)
@@ -316,7 +339,7 @@ def cleanup_candidates(root: Path, prefix: str, keep: int) -> Iterator[Path]:
def cleanup_numbered_dir(
root: Path, prefix: str, keep: int, consider_lock_dead_if_created_before: float
) -> None:
- """cleanup for lock driven numbered directories"""
+ """Cleanup for lock driven numbered directories."""
for path in cleanup_candidates(root, prefix, keep):
try_cleanup(path, consider_lock_dead_if_created_before)
for path in root.glob("garbage-*"):
@@ -324,53 +347,54 @@ def cleanup_numbered_dir(
def make_numbered_dir_with_cleanup(
- root: Path, prefix: str, keep: int, lock_timeout: float
+ root: Path, prefix: str, keep: int, lock_timeout: float, mode: int,
) -> Path:
- """creates a numbered dir with a cleanup lock and removes old ones"""
+ """Create a numbered dir with a cleanup lock and remove old ones."""
e = None
for i in range(10):
try:
- p = make_numbered_dir(root, prefix)
+ p = make_numbered_dir(root, prefix, mode)
lock_path = create_cleanup_lock(p)
register_cleanup_lock_removal(lock_path)
except Exception as exc:
e = exc
else:
consider_lock_dead_if_created_before = p.stat().st_mtime - lock_timeout
- cleanup_numbered_dir(
- root=root,
- prefix=prefix,
- keep=keep,
- consider_lock_dead_if_created_before=consider_lock_dead_if_created_before,
+ # Register a cleanup for program exit
+ atexit.register(
+ cleanup_numbered_dir,
+ root,
+ prefix,
+ keep,
+ consider_lock_dead_if_created_before,
)
return p
assert e is not None
raise e
-def resolve_from_str(input, root):
- assert not isinstance(input, Path), "would break on py2"
- root = Path(root)
+def resolve_from_str(input: str, rootpath: Path) -> Path:
input = expanduser(input)
input = expandvars(input)
if isabs(input):
return Path(input)
else:
- return root.joinpath(input)
+ return rootpath.joinpath(input)
def fnmatch_ex(pattern: str, path) -> bool:
- """FNMatcher port from py.path.common which works with PurePath() instances.
+ """A port of FNMatcher from py.path.common which works with PurePath() instances.
- The difference between this algorithm and PurePath.match() is that the latter matches "**" glob expressions
- for each part of the path, while this algorithm uses the whole path instead.
+ The difference between this algorithm and PurePath.match() is that the
+ latter matches "**" glob expressions for each part of the path, while
+ this algorithm uses the whole path instead.
For example:
- "tests/foo/bar/doc/test_foo.py" matches pattern "tests/**/doc/test*.py" with this algorithm, but not with
- PurePath.match().
+ "tests/foo/bar/doc/test_foo.py" matches pattern "tests/**/doc/test*.py"
+ with this algorithm, but not with PurePath.match().
- This algorithm was ported to keep backward-compatibility with existing settings which assume paths match according
- this logic.
+ This algorithm was ported to keep backward-compatibility with existing
+ settings which assume paths match according this logic.
References:
* https://bugs.python.org/issue29249
@@ -390,10 +414,241 @@ def fnmatch_ex(pattern: str, path) -> bool:
else:
name = str(path)
if path.is_absolute() and not os.path.isabs(pattern):
- pattern = "*{}{}".format(os.sep, pattern)
+ pattern = f"*{os.sep}{pattern}"
return fnmatch.fnmatch(name, pattern)
def parts(s: str) -> Set[str]:
parts = s.split(sep)
return {sep.join(parts[: i + 1]) or sep for i in range(len(parts))}
+
+
+def symlink_or_skip(src, dst, **kwargs):
+ """Make a symlink, or skip the test in case symlinks are not supported."""
+ try:
+ os.symlink(str(src), str(dst), **kwargs)
+ except OSError as e:
+ skip(f"symlinks not supported: {e}")
+
+
+class ImportMode(Enum):
+ """Possible values for `mode` parameter of `import_path`."""
+
+ prepend = "prepend"
+ append = "append"
+ importlib = "importlib"
+
+
+class ImportPathMismatchError(ImportError):
+ """Raised on import_path() if there is a mismatch of __file__'s.
+
+ This can happen when `import_path` is called multiple times with different filenames that has
+ the same basename but reside in packages
+ (for example "/tests1/test_foo.py" and "/tests2/test_foo.py").
+ """
+
+
+def import_path(
+ p: Union[str, py.path.local, Path],
+ *,
+ mode: Union[str, ImportMode] = ImportMode.prepend,
+) -> ModuleType:
+ """Import and return a module from the given path, which can be a file (a module) or
+ a directory (a package).
+
+ The import mechanism used is controlled by the `mode` parameter:
+
+ * `mode == ImportMode.prepend`: the directory containing the module (or package, taking
+ `__init__.py` files into account) will be put at the *start* of `sys.path` before
+ being imported with `__import__.
+
+ * `mode == ImportMode.append`: same as `prepend`, but the directory will be appended
+ to the end of `sys.path`, if not already in `sys.path`.
+
+ * `mode == ImportMode.importlib`: uses more fine control mechanisms provided by `importlib`
+ to import the module, which avoids having to use `__import__` and muck with `sys.path`
+ at all. It effectively allows having same-named test modules in different places.
+
+ :raises ImportPathMismatchError:
+ If after importing the given `path` and the module `__file__`
+ are different. Only raised in `prepend` and `append` modes.
+ """
+ mode = ImportMode(mode)
+
+ path = Path(str(p))
+
+ if not path.exists():
+ raise ImportError(path)
+
+ if mode is ImportMode.importlib:
+ module_name = path.stem
+
+ for meta_importer in sys.meta_path:
+ spec = meta_importer.find_spec(module_name, [str(path.parent)])
+ if spec is not None:
+ break
+ else:
+ spec = importlib.util.spec_from_file_location(module_name, str(path))
+
+ if spec is None:
+ raise ImportError(
+ "Can't find module {} at location {}".format(module_name, str(path))
+ )
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod) # type: ignore[union-attr]
+ return mod
+
+ pkg_path = resolve_package_path(path)
+ if pkg_path is not None:
+ pkg_root = pkg_path.parent
+ names = list(path.with_suffix("").relative_to(pkg_root).parts)
+ if names[-1] == "__init__":
+ names.pop()
+ module_name = ".".join(names)
+ else:
+ pkg_root = path.parent
+ module_name = path.stem
+
+ # Change sys.path permanently: restoring it at the end of this function would cause surprising
+ # problems because of delayed imports: for example, a conftest.py file imported by this function
+ # might have local imports, which would fail at runtime if we restored sys.path.
+ if mode is ImportMode.append:
+ if str(pkg_root) not in sys.path:
+ sys.path.append(str(pkg_root))
+ elif mode is ImportMode.prepend:
+ if str(pkg_root) != sys.path[0]:
+ sys.path.insert(0, str(pkg_root))
+ else:
+ assert_never(mode)
+
+ importlib.import_module(module_name)
+
+ mod = sys.modules[module_name]
+ if path.name == "__init__.py":
+ return mod
+
+ ignore = os.environ.get("PY_IGNORE_IMPORTMISMATCH", "")
+ if ignore != "1":
+ module_file = mod.__file__
+ if module_file.endswith((".pyc", ".pyo")):
+ module_file = module_file[:-1]
+ if module_file.endswith(os.path.sep + "__init__.py"):
+ module_file = module_file[: -(len(os.path.sep + "__init__.py"))]
+
+ try:
+ is_same = _is_same(str(path), module_file)
+ except FileNotFoundError:
+ is_same = False
+
+ if not is_same:
+ raise ImportPathMismatchError(module_name, module_file, path)
+
+ return mod
+
+
+# Implement a special _is_same function on Windows which returns True if the two filenames
+# compare equal, to circumvent os.path.samefile returning False for mounts in UNC (#7678).
+if sys.platform.startswith("win"):
+
+ def _is_same(f1: str, f2: str) -> bool:
+ return Path(f1) == Path(f2) or os.path.samefile(f1, f2)
+
+
+else:
+
+ def _is_same(f1: str, f2: str) -> bool:
+ return os.path.samefile(f1, f2)
+
+
+def resolve_package_path(path: Path) -> Optional[Path]:
+ """Return the Python package path by looking for the last
+ directory upwards which still contains an __init__.py.
+
+ Returns None if it can not be determined.
+ """
+ result = None
+ for parent in itertools.chain((path,), path.parents):
+ if parent.is_dir():
+ if not parent.joinpath("__init__.py").is_file():
+ break
+ if not parent.name.isidentifier():
+ break
+ result = parent
+ return result
+
+
+def visit(
+ path: str, recurse: Callable[["os.DirEntry[str]"], bool]
+) -> Iterator["os.DirEntry[str]"]:
+ """Walk a directory recursively, in breadth-first order.
+
+ Entries at each directory level are sorted.
+ """
+
+ # Skip entries with symlink loops and other brokenness, so the caller doesn't
+ # have to deal with it.
+ entries = []
+ for entry in os.scandir(path):
+ try:
+ entry.is_file()
+ except OSError as err:
+ if _ignore_error(err):
+ continue
+ raise
+ entries.append(entry)
+
+ entries.sort(key=lambda entry: entry.name)
+
+ yield from entries
+
+ for entry in entries:
+ if entry.is_dir() and recurse(entry):
+ yield from visit(entry.path, recurse)
+
+
+def absolutepath(path: Union[Path, str]) -> Path:
+ """Convert a path to an absolute path using os.path.abspath.
+
+ Prefer this over Path.resolve() (see #6523).
+ Prefer this over Path.absolute() (not public, doesn't normalize).
+ """
+ return Path(os.path.abspath(str(path)))
+
+
+def commonpath(path1: Path, path2: Path) -> Optional[Path]:
+ """Return the common part shared with the other path, or None if there is
+ no common part.
+
+ If one path is relative and one is absolute, returns None.
+ """
+ try:
+ return Path(os.path.commonpath((str(path1), str(path2))))
+ except ValueError:
+ return None
+
+
+def bestrelpath(directory: Path, dest: Path) -> str:
+ """Return a string which is a relative path from directory to dest such
+ that directory/bestrelpath == dest.
+
+ The paths must be either both absolute or both relative.
+
+ If no such path can be determined, returns dest.
+ """
+ if dest == directory:
+ return os.curdir
+ # Find the longest common directory.
+ base = commonpath(directory, dest)
+ # Can be the case on Windows for two absolute paths on different drives.
+ # Can be the case for two relative paths without common prefix.
+ # Can be the case for a relative path and an absolute path.
+ if not base:
+ return str(dest)
+ reldirectory = directory.relative_to(base)
+ reldest = dest.relative_to(base)
+ return os.path.join(
+ # Back from directory to base.
+ *([os.pardir] * len(reldirectory.parts)),
+ # Forward from base to dest.
+ *reldest.parts,
+ )
diff --git a/contrib/python/more-itertools/py2/more_itertools/tests/__init__.py b/contrib/python/pytest/py3/_pytest/py.typed
index e69de29bb2..e69de29bb2 100644
--- a/contrib/python/more-itertools/py2/more_itertools/tests/__init__.py
+++ b/contrib/python/pytest/py3/_pytest/py.typed
diff --git a/contrib/python/pytest/py3/_pytest/pytester.py b/contrib/python/pytest/py3/_pytest/pytester.py
index 9df3ed779d..31259d1bdc 100644
--- a/contrib/python/pytest/py3/_pytest/pytester.py
+++ b/contrib/python/pytest/py3/_pytest/pytester.py
@@ -1,57 +1,84 @@
-"""(disabled by default) support for testing pytest and pytest plugins."""
+"""(Disabled by default) support for testing pytest and pytest plugins.
+
+PYTEST_DONT_REWRITE
+"""
import collections.abc
+import contextlib
import gc
import importlib
import os
import platform
import re
+import shutil
import subprocess
import sys
-import time
import traceback
from fnmatch import fnmatch
from io import StringIO
+from pathlib import Path
+from typing import Any
from typing import Callable
from typing import Dict
+from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
+from typing import overload
from typing import Sequence
+from typing import TextIO
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import Union
from weakref import WeakKeyDictionary
+import attr
import py
+from iniconfig import IniConfig
+from iniconfig import SectionWrapper
-import pytest
+from _pytest import timing
from _pytest._code import Source
-from _pytest.capture import MultiCapture
-from _pytest.capture import SysCapture
-from _pytest.compat import TYPE_CHECKING
+from _pytest.capture import _get_multicapture
+from _pytest.compat import final
from _pytest.config import _PluggyPlugin
+from _pytest.config import Config
from _pytest.config import ExitCode
+from _pytest.config import hookimpl
+from _pytest.config import main
+from _pytest.config import PytestPluginManager
+from _pytest.config.argparsing import Parser
+from _pytest.deprecated import check_ispytest
+from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.monkeypatch import MonkeyPatch
from _pytest.nodes import Collector
from _pytest.nodes import Item
-from _pytest.pathlib import Path
-from _pytest.python import Module
+from _pytest.outcomes import fail
+from _pytest.outcomes import importorskip
+from _pytest.outcomes import skip
+from _pytest.pathlib import make_numbered_dir
+from _pytest.reports import CollectReport
from _pytest.reports import TestReport
-from _pytest.tmpdir import TempdirFactory
+from _pytest.tmpdir import TempPathFactory
+from _pytest.warning_types import PytestWarning
if TYPE_CHECKING:
- from typing import Type
+ from typing_extensions import Literal
import pexpect
+pytest_plugins = ["pytester_assertions"]
+
+
IGNORE_PAM = [ # filenames added when obtaining details about the current user
"/var/lib/sss/mc/passwd"
]
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
parser.addoption(
"--lsof",
action="store_true",
@@ -76,7 +103,7 @@ def pytest_addoption(parser):
)
-def pytest_configure(config):
+def pytest_configure(config: Config) -> None:
if config.getvalue("lsof"):
checker = LsofFdLeakChecker()
if checker.matching_platform():
@@ -90,21 +117,16 @@ def pytest_configure(config):
class LsofFdLeakChecker:
- def get_open_files(self):
- out = self._exec_lsof()
- open_files = self._parse_lsof_output(out)
- return open_files
-
- def _exec_lsof(self):
- pid = os.getpid()
- # py3: use subprocess.DEVNULL directly.
- with open(os.devnull, "wb") as devnull:
- return subprocess.check_output(
- ("lsof", "-Ffn0", "-p", str(pid)), stderr=devnull
- ).decode()
-
- def _parse_lsof_output(self, out):
- def isopen(line):
+ def get_open_files(self) -> List[Tuple[str, str]]:
+ out = subprocess.run(
+ ("lsof", "-Ffn0", "-p", str(os.getpid())),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
+ check=True,
+ universal_newlines=True,
+ ).stdout
+
+ def isopen(line: str) -> bool:
return line.startswith("f") and (
"deleted" not in line
and "mem" not in line
@@ -126,16 +148,16 @@ class LsofFdLeakChecker:
return open_files
- def matching_platform(self):
+ def matching_platform(self) -> bool:
try:
- subprocess.check_output(("lsof", "-v"))
+ subprocess.run(("lsof", "-v"), check=True)
except (OSError, subprocess.CalledProcessError):
return False
else:
return True
- @pytest.hookimpl(hookwrapper=True, tryfirst=True)
- def pytest_runtest_protocol(self, item):
+ @hookimpl(hookwrapper=True, tryfirst=True)
+ def pytest_runtest_protocol(self, item: Item) -> Generator[None, None, None]:
lines1 = self.get_open_files()
yield
if hasattr(sys, "pypy_version_info"):
@@ -145,61 +167,60 @@ class LsofFdLeakChecker:
new_fds = {t[0] for t in lines2} - {t[0] for t in lines1}
leaked_files = [t for t in lines2 if t[0] in new_fds]
if leaked_files:
- error = []
- error.append("***** %s FD leakage detected" % len(leaked_files))
- error.extend([str(f) for f in leaked_files])
- error.append("*** Before:")
- error.extend([str(f) for f in lines1])
- error.append("*** After:")
- error.extend([str(f) for f in lines2])
- error.append(error[0])
- error.append("*** function %s:%s: %s " % item.location)
- error.append("See issue #2366")
- item.warn(pytest.PytestWarning("\n".join(error)))
+ error = [
+ "***** %s FD leakage detected" % len(leaked_files),
+ *(str(f) for f in leaked_files),
+ "*** Before:",
+ *(str(f) for f in lines1),
+ "*** After:",
+ *(str(f) for f in lines2),
+ "***** %s FD leakage detected" % len(leaked_files),
+ "*** function %s:%s: %s " % item.location,
+ "See issue #2366",
+ ]
+ item.warn(PytestWarning("\n".join(error)))
# used at least by pytest-xdist plugin
-@pytest.fixture
+@fixture
def _pytest(request: FixtureRequest) -> "PytestArg":
"""Return a helper which offers a gethookrecorder(hook) method which
returns a HookRecorder instance which helps to make assertions about called
- hooks.
-
- """
+ hooks."""
return PytestArg(request)
class PytestArg:
def __init__(self, request: FixtureRequest) -> None:
- self.request = request
+ self._request = request
def gethookrecorder(self, hook) -> "HookRecorder":
hookrecorder = HookRecorder(hook._pm)
- self.request.addfinalizer(hookrecorder.finish_recording)
+ self._request.addfinalizer(hookrecorder.finish_recording)
return hookrecorder
-def get_public_names(values):
+def get_public_names(values: Iterable[str]) -> List[str]:
"""Only return names from iterator values without a leading underscore."""
return [x for x in values if x[0] != "_"]
class ParsedCall:
- def __init__(self, name, kwargs):
+ def __init__(self, name: str, kwargs) -> None:
self.__dict__.update(kwargs)
self._name = name
- def __repr__(self):
+ def __repr__(self) -> str:
d = self.__dict__.copy()
del d["_name"]
- return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
+ return f"<ParsedCall {self._name!r}(**{d!r})>"
if TYPE_CHECKING:
# The class has undetermined attributes, this tells mypy about it.
- def __getattr__(self, key):
- raise NotImplementedError()
+ def __getattr__(self, key: str):
+ ...
class HookRecorder:
@@ -207,12 +228,12 @@ class HookRecorder:
This wraps all the hook calls in the plugin manager, recording each call
before propagating the normal calls.
-
"""
- def __init__(self, pluginmanager) -> None:
+ def __init__(self, pluginmanager: PytestPluginManager) -> None:
self._pluginmanager = pluginmanager
- self.calls = [] # type: List[ParsedCall]
+ self.calls: List[ParsedCall] = []
+ self.ret: Optional[Union[int, ExitCode]] = None
def before(hook_name: str, hook_impls, kwargs) -> None:
self.calls.append(ParsedCall(hook_name, kwargs))
@@ -230,7 +251,7 @@ class HookRecorder:
names = names.split()
return [call for call in self.calls if call._name in names]
- def assert_contains(self, entries) -> None:
+ def assert_contains(self, entries: Sequence[Tuple[str, str]]) -> None:
__tracebackhide__ = True
i = 0
entries = list(entries)
@@ -249,7 +270,7 @@ class HookRecorder:
break
print("NONAMEMATCH", name, "with", call)
else:
- pytest.fail("could not find {!r} check {!r}".format(name, check))
+ fail(f"could not find {name!r} check {check!r}")
def popcall(self, name: str) -> ParsedCall:
__tracebackhide__ = True
@@ -257,9 +278,9 @@ class HookRecorder:
if call._name == name:
del self.calls[i]
return call
- lines = ["could not find call {!r}, in:".format(name)]
+ lines = [f"could not find call {name!r}, in:"]
lines.extend([" %s" % x for x in self.calls])
- pytest.fail("\n".join(lines))
+ fail("\n".join(lines))
def getcall(self, name: str) -> ParsedCall:
values = self.getcalls(name)
@@ -268,23 +289,47 @@ class HookRecorder:
# functionality for test reports
+ @overload
+ def getreports(
+ self, names: "Literal['pytest_collectreport']",
+ ) -> Sequence[CollectReport]:
+ ...
+
+ @overload
+ def getreports(
+ self, names: "Literal['pytest_runtest_logreport']",
+ ) -> Sequence[TestReport]:
+ ...
+
+ @overload
+ def getreports(
+ self,
+ names: Union[str, Iterable[str]] = (
+ "pytest_collectreport",
+ "pytest_runtest_logreport",
+ ),
+ ) -> Sequence[Union[CollectReport, TestReport]]:
+ ...
+
def getreports(
self,
- names: Union[
- str, Iterable[str]
- ] = "pytest_runtest_logreport pytest_collectreport",
- ) -> List[TestReport]:
+ names: Union[str, Iterable[str]] = (
+ "pytest_collectreport",
+ "pytest_runtest_logreport",
+ ),
+ ) -> Sequence[Union[CollectReport, TestReport]]:
return [x.report for x in self.getcalls(names)]
def matchreport(
self,
inamepart: str = "",
- names: Union[
- str, Iterable[str]
- ] = "pytest_runtest_logreport pytest_collectreport",
- when=None,
- ):
- """return a testreport whose dotted import path matches"""
+ names: Union[str, Iterable[str]] = (
+ "pytest_runtest_logreport",
+ "pytest_collectreport",
+ ),
+ when: Optional[str] = None,
+ ) -> Union[CollectReport, TestReport]:
+ """Return a testreport whose dotted import path matches."""
values = []
for rep in self.getreports(names=names):
if not when and rep.when != "call" and rep.passed:
@@ -307,31 +352,61 @@ class HookRecorder:
)
return values[0]
+ @overload
+ def getfailures(
+ self, names: "Literal['pytest_collectreport']",
+ ) -> Sequence[CollectReport]:
+ ...
+
+ @overload
+ def getfailures(
+ self, names: "Literal['pytest_runtest_logreport']",
+ ) -> Sequence[TestReport]:
+ ...
+
+ @overload
+ def getfailures(
+ self,
+ names: Union[str, Iterable[str]] = (
+ "pytest_collectreport",
+ "pytest_runtest_logreport",
+ ),
+ ) -> Sequence[Union[CollectReport, TestReport]]:
+ ...
+
def getfailures(
self,
- names: Union[
- str, Iterable[str]
- ] = "pytest_runtest_logreport pytest_collectreport",
- ) -> List[TestReport]:
+ names: Union[str, Iterable[str]] = (
+ "pytest_collectreport",
+ "pytest_runtest_logreport",
+ ),
+ ) -> Sequence[Union[CollectReport, TestReport]]:
return [rep for rep in self.getreports(names) if rep.failed]
- def getfailedcollections(self) -> List[TestReport]:
+ def getfailedcollections(self) -> Sequence[CollectReport]:
return self.getfailures("pytest_collectreport")
def listoutcomes(
self,
- ) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
+ ) -> Tuple[
+ Sequence[TestReport],
+ Sequence[Union[CollectReport, TestReport]],
+ Sequence[Union[CollectReport, TestReport]],
+ ]:
passed = []
skipped = []
failed = []
- for rep in self.getreports("pytest_collectreport pytest_runtest_logreport"):
+ for rep in self.getreports(
+ ("pytest_collectreport", "pytest_runtest_logreport")
+ ):
if rep.passed:
if rep.when == "call":
+ assert isinstance(rep, TestReport)
passed.append(rep)
elif rep.skipped:
skipped.append(rep)
else:
- assert rep.failed, "Unexpected outcome: {!r}".format(rep)
+ assert rep.failed, f"Unexpected outcome: {rep!r}"
failed.append(rep)
return passed, skipped, failed
@@ -340,38 +415,62 @@ class HookRecorder:
def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None:
__tracebackhide__ = True
+ from _pytest.pytester_assertions import assertoutcome
outcomes = self.listoutcomes()
- realpassed, realskipped, realfailed = outcomes
- obtained = {
- "passed": len(realpassed),
- "skipped": len(realskipped),
- "failed": len(realfailed),
- }
- expected = {"passed": passed, "skipped": skipped, "failed": failed}
- assert obtained == expected, outcomes
+ assertoutcome(
+ outcomes, passed=passed, skipped=skipped, failed=failed,
+ )
def clear(self) -> None:
self.calls[:] = []
-@pytest.fixture
+@fixture
def linecomp() -> "LineComp":
+ """A :class: `LineComp` instance for checking that an input linearly
+ contains a sequence of strings."""
return LineComp()
-@pytest.fixture(name="LineMatcher")
-def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]":
+@fixture(name="LineMatcher")
+def LineMatcher_fixture(request: FixtureRequest) -> Type["LineMatcher"]:
+ """A reference to the :class: `LineMatcher`.
+
+ This is instantiable with a list of lines (without their trailing newlines).
+ This is useful for testing large texts, such as the output of commands.
+ """
return LineMatcher
-@pytest.fixture
-def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir":
- return Testdir(request, tmpdir_factory)
+@fixture
+def pytester(request: FixtureRequest, tmp_path_factory: TempPathFactory) -> "Pytester":
+ """
+ Facilities to write tests/configuration files, execute pytest in isolation, and match
+ against expected output, perfect for black-box testing of pytest plugins.
+
+ It attempts to isolate the test run from external factors as much as possible, modifying
+ the current working directory to ``path`` and environment variables during initialization.
+
+ It is particularly useful for testing plugins. It is similar to the :fixture:`tmp_path`
+ fixture but provides methods which aid in testing pytest itself.
+ """
+ return Pytester(request, tmp_path_factory, _ispytest=True)
+
+
+@fixture
+def testdir(pytester: "Pytester") -> "Testdir":
+ """
+ Identical to :fixture:`pytester`, and provides an instance whose methods return
+ legacy ``py.path.local`` objects instead when applicable.
+
+ New code should avoid using :fixture:`testdir` in favor of :fixture:`pytester`.
+ """
+ return Testdir(pytester, _ispytest=True)
-@pytest.fixture
-def _sys_snapshot():
+@fixture
+def _sys_snapshot() -> Generator[None, None, None]:
snappaths = SysPathsSnapshot()
snapmods = SysModulesSnapshot()
yield
@@ -379,8 +478,8 @@ def _sys_snapshot():
snappaths.restore()
-@pytest.fixture
-def _config_for_test():
+@fixture
+def _config_for_test() -> Generator[Config, None, None]:
from _pytest.config import get_config
config = get_config()
@@ -388,26 +487,14 @@ def _config_for_test():
config._ensure_unconfigure() # cleanup, e.g. capman closing tmpfiles.
-# regex to match the session duration string in the summary: "74.34s"
+# Regex to match the session duration string in the summary: "74.34s".
rex_session_duration = re.compile(r"\d+\.\d\ds")
-# regex to match all the counts and phrases in the summary line: "34 passed, 111 skipped"
+# Regex to match all the counts and phrases in the summary line: "34 passed, 111 skipped".
rex_outcome = re.compile(r"(\d+) (\w+)")
class RunResult:
- """The result of running a command.
-
- Attributes:
-
- :ivar ret: the return value
- :ivar outlines: list of lines captured from stdout
- :ivar errlines: list of lines captured from stderr
- :ivar stdout: :py:class:`LineMatcher` of stdout, use ``stdout.str()`` to
- reconstruct stdout or the commonly used ``stdout.fnmatch_lines()``
- method
- :ivar stderr: :py:class:`LineMatcher` of stderr
- :ivar duration: duration in seconds
- """
+ """The result of running a command."""
def __init__(
self,
@@ -417,14 +504,24 @@ class RunResult:
duration: float,
) -> None:
try:
- self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode]
+ self.ret: Union[int, ExitCode] = ExitCode(ret)
+ """The return value."""
except ValueError:
self.ret = ret
self.outlines = outlines
+ """List of lines captured from stdout."""
self.errlines = errlines
+ """List of lines captured from stderr."""
self.stdout = LineMatcher(outlines)
+ """:class:`LineMatcher` of stdout.
+
+ Use e.g. :func:`str(stdout) <LineMatcher.__str__()>` to reconstruct stdout, or the commonly used
+ :func:`stdout.fnmatch_lines() <LineMatcher.fnmatch_lines()>` method.
+ """
self.stderr = LineMatcher(errlines)
+ """:class:`LineMatcher` of stderr."""
self.duration = duration
+ """Duration in seconds."""
def __repr__(self) -> str:
return (
@@ -433,54 +530,65 @@ class RunResult:
)
def parseoutcomes(self) -> Dict[str, int]:
- """Return a dictionary of outcomestring->num from parsing the terminal
+ """Return a dictionary of outcome noun -> count from parsing the terminal
output that the test process produced.
+ The returned nouns will always be in plural form::
+
+ ======= 1 failed, 1 passed, 1 warning, 1 error in 0.13s ====
+
+ Will return ``{"failed": 1, "passed": 1, "warnings": 1, "errors": 1}``.
+ """
+ return self.parse_summary_nouns(self.outlines)
+
+ @classmethod
+ def parse_summary_nouns(cls, lines) -> Dict[str, int]:
+ """Extract the nouns from a pytest terminal summary line.
+
+ It always returns the plural noun for consistency::
+
+ ======= 1 failed, 1 passed, 1 warning, 1 error in 0.13s ====
+
+ Will return ``{"failed": 1, "passed": 1, "warnings": 1, "errors": 1}``.
"""
- for line in reversed(self.outlines):
+ for line in reversed(lines):
if rex_session_duration.search(line):
outcomes = rex_outcome.findall(line)
ret = {noun: int(count) for (count, noun) in outcomes}
break
else:
raise ValueError("Pytest terminal summary report not found")
- if "errors" in ret:
- assert "error" not in ret
- ret["error"] = ret.pop("errors")
- return ret
+
+ to_plural = {
+ "warning": "warnings",
+ "error": "errors",
+ }
+ return {to_plural.get(k, k): v for k, v in ret.items()}
def assert_outcomes(
self,
passed: int = 0,
skipped: int = 0,
failed: int = 0,
- error: int = 0,
+ errors: int = 0,
xpassed: int = 0,
xfailed: int = 0,
) -> None:
"""Assert that the specified outcomes appear with the respective
- numbers (0 means it didn't occur) in the text output from a test run.
- """
+ numbers (0 means it didn't occur) in the text output from a test run."""
__tracebackhide__ = True
-
- d = self.parseoutcomes()
- obtained = {
- "passed": d.get("passed", 0),
- "skipped": d.get("skipped", 0),
- "failed": d.get("failed", 0),
- "error": d.get("error", 0),
- "xpassed": d.get("xpassed", 0),
- "xfailed": d.get("xfailed", 0),
- }
- expected = {
- "passed": passed,
- "skipped": skipped,
- "failed": failed,
- "error": error,
- "xpassed": xpassed,
- "xfailed": xfailed,
- }
- assert obtained == expected
+ from _pytest.pytester_assertions import assert_outcomes
+
+ outcomes = self.parseoutcomes()
+ assert_outcomes(
+ outcomes,
+ passed=passed,
+ skipped=skipped,
+ failed=failed,
+ errors=errors,
+ xpassed=xpassed,
+ xfailed=xfailed,
+ )
class CwdSnapshot:
@@ -492,7 +600,7 @@ class CwdSnapshot:
class SysModulesSnapshot:
- def __init__(self, preserve: Optional[Callable[[str], bool]] = None):
+ def __init__(self, preserve: Optional[Callable[[str], bool]] = None) -> None:
self.__preserve = preserve
self.__saved = dict(sys.modules)
@@ -513,22 +621,24 @@ class SysPathsSnapshot:
sys.path[:], sys.meta_path[:] = self.__saved
-class Testdir:
- """Temporary test directory with tools to test/run pytest itself.
+@final
+class Pytester:
+ """
+ Facilities to write tests/configuration files, execute pytest in isolation, and match
+ against expected output, perfect for black-box testing of pytest plugins.
- This is based on the ``tmpdir`` fixture but provides a number of methods
- which aid with testing pytest itself. Unless :py:meth:`chdir` is used all
- methods will use :py:attr:`tmpdir` as their current working directory.
+ It attempts to isolate the test run from external factors as much as possible, modifying
+ the current working directory to ``path`` and environment variables during initialization.
Attributes:
- :ivar tmpdir: The :py:class:`py.path.local` instance of the temporary directory.
+ :ivar Path path: temporary directory path used to create files/run tests from, etc.
- :ivar plugins: A list of plugins to use with :py:meth:`parseconfig` and
+ :ivar plugins:
+ A list of plugins to use with :py:meth:`parseconfig` and
:py:meth:`runpytest`. Initially this is an empty list but plugins can
be added to the list. The type of items to add to the list depends on
the method using them so refer to them for details.
-
"""
__test__ = False
@@ -538,85 +648,102 @@ class Testdir:
class TimeoutExpired(Exception):
pass
- def __init__(self, request: FixtureRequest, tmpdir_factory: TempdirFactory) -> None:
- self.request = request
- self._mod_collections = (
- WeakKeyDictionary()
- ) # type: WeakKeyDictionary[Module, List[Union[Item, Collector]]]
+ def __init__(
+ self,
+ request: FixtureRequest,
+ tmp_path_factory: TempPathFactory,
+ *,
+ _ispytest: bool = False,
+ ) -> None:
+ check_ispytest(_ispytest)
+ self._request = request
+ self._mod_collections: WeakKeyDictionary[
+ Collector, List[Union[Item, Collector]]
+ ] = (WeakKeyDictionary())
if request.function:
- name = request.function.__name__ # type: str
+ name: str = request.function.__name__
else:
name = request.node.name
self._name = name
- self.tmpdir = tmpdir_factory.mktemp(name, numbered=True)
- self.test_tmproot = tmpdir_factory.mktemp("tmp-" + name, numbered=True)
- self.plugins = [] # type: List[Union[str, _PluggyPlugin]]
+ self._path: Path = tmp_path_factory.mktemp(name, numbered=True)
+ self.plugins: List[Union[str, _PluggyPlugin]] = []
self._cwd_snapshot = CwdSnapshot()
self._sys_path_snapshot = SysPathsSnapshot()
self._sys_modules_snapshot = self.__take_sys_modules_snapshot()
self.chdir()
- self.request.addfinalizer(self.finalize)
- self._method = self.request.config.getoption("--runpytest")
+ self._request.addfinalizer(self._finalize)
+ self._method = self._request.config.getoption("--runpytest")
+ self._test_tmproot = tmp_path_factory.mktemp(f"tmp-{name}", numbered=True)
- mp = self.monkeypatch = MonkeyPatch()
- mp.setenv("PYTEST_DEBUG_TEMPROOT", str(self.test_tmproot))
+ self._monkeypatch = mp = MonkeyPatch()
+ mp.setenv("PYTEST_DEBUG_TEMPROOT", str(self._test_tmproot))
# Ensure no unexpected caching via tox.
mp.delenv("TOX_ENV_DIR", raising=False)
# Discard outer pytest options.
mp.delenv("PYTEST_ADDOPTS", raising=False)
# Ensure no user config is used.
- tmphome = str(self.tmpdir)
+ tmphome = str(self.path)
mp.setenv("HOME", tmphome)
mp.setenv("USERPROFILE", tmphome)
# Do not use colors for inner runs by default.
mp.setenv("PY_COLORS", "0")
- def __repr__(self):
- return "<Testdir {!r}>".format(self.tmpdir)
+ @property
+ def path(self) -> Path:
+ """Temporary directory where files are created and pytest is executed."""
+ return self._path
- def __str__(self):
- return str(self.tmpdir)
+ def __repr__(self) -> str:
+ return f"<Pytester {self.path!r}>"
- def finalize(self):
- """Clean up global state artifacts.
+ def _finalize(self) -> None:
+ """
+ Clean up global state artifacts.
Some methods modify the global interpreter state and this tries to
- clean this up. It does not remove the temporary directory however so
+ clean this up. It does not remove the temporary directory however so
it can be looked at after the test run has finished.
-
"""
self._sys_modules_snapshot.restore()
self._sys_path_snapshot.restore()
self._cwd_snapshot.restore()
- self.monkeypatch.undo()
+ self._monkeypatch.undo()
- def __take_sys_modules_snapshot(self):
- # some zope modules used by twisted-related tests keep internal state
+ def __take_sys_modules_snapshot(self) -> SysModulesSnapshot:
+ # Some zope modules used by twisted-related tests keep internal state
# and can't be deleted; we had some trouble in the past with
- # `zope.interface` for example
+ # `zope.interface` for example.
+ #
+ # Preserve readline due to https://bugs.python.org/issue41033.
+ # pexpect issues a SIGWINCH.
def preserve_module(name):
- return name.startswith("zope")
+ return name.startswith(("zope", "readline"))
return SysModulesSnapshot(preserve=preserve_module)
- def make_hook_recorder(self, pluginmanager):
+ def make_hook_recorder(self, pluginmanager: PytestPluginManager) -> HookRecorder:
"""Create a new :py:class:`HookRecorder` for a PluginManager."""
pluginmanager.reprec = reprec = HookRecorder(pluginmanager)
- self.request.addfinalizer(reprec.finish_recording)
+ self._request.addfinalizer(reprec.finish_recording)
return reprec
- def chdir(self):
+ def chdir(self) -> None:
"""Cd into the temporary directory.
This is done automatically upon instantiation.
-
"""
- self.tmpdir.chdir()
+ os.chdir(self.path)
- def _makefile(self, ext, lines, files, encoding="utf-8"):
+ def _makefile(
+ self,
+ ext: str,
+ lines: Sequence[Union[Any, bytes]],
+ files: Dict[str, str],
+ encoding: str = "utf-8",
+ ) -> Path:
items = list(files.items())
- def to_text(s):
+ def to_text(s: Union[Any, bytes]) -> str:
return s.decode(encoding) if isinstance(s, bytes) else str(s)
if lines:
@@ -626,144 +753,189 @@ class Testdir:
ret = None
for basename, value in items:
- p = self.tmpdir.join(basename).new(ext=ext)
- p.dirpath().ensure_dir()
- source = Source(value)
- source = "\n".join(to_text(line) for line in source.lines)
- p.write(source.strip().encode(encoding), "wb")
+ p = self.path.joinpath(basename).with_suffix(ext)
+ p.parent.mkdir(parents=True, exist_ok=True)
+ source_ = Source(value)
+ source = "\n".join(to_text(line) for line in source_.lines)
+ p.write_text(source.strip(), encoding=encoding)
if ret is None:
ret = p
+ assert ret is not None
return ret
- def makefile(self, ext, *args, **kwargs):
- r"""Create new file(s) in the testdir.
+ def makefile(self, ext: str, *args: str, **kwargs: str) -> Path:
+ r"""Create new file(s) in the test directory.
- :param str ext: The extension the file(s) should use, including the dot, e.g. `.py`.
- :param list[str] args: All args will be treated as strings and joined using newlines.
- The result will be written as contents to the file. The name of the
- file will be based on the test function requesting this fixture.
- :param kwargs: Each keyword is the name of a file, while the value of it will
- be written as contents of the file.
+ :param str ext:
+ The extension the file(s) should use, including the dot, e.g. `.py`.
+ :param args:
+ All args are treated as strings and joined using newlines.
+ The result is written as contents to the file. The name of the
+ file is based on the test function requesting this fixture.
+ :param kwargs:
+ Each keyword is the name of a file, while the value of it will
+ be written as contents of the file.
Examples:
.. code-block:: python
- testdir.makefile(".txt", "line1", "line2")
+ pytester.makefile(".txt", "line1", "line2")
- testdir.makefile(".ini", pytest="[pytest]\naddopts=-rs\n")
+ pytester.makefile(".ini", pytest="[pytest]\naddopts=-rs\n")
"""
return self._makefile(ext, args, kwargs)
- def makeconftest(self, source):
+ def makeconftest(self, source: str) -> Path:
"""Write a contest.py file with 'source' as contents."""
return self.makepyfile(conftest=source)
- def makeini(self, source):
+ def makeini(self, source: str) -> Path:
"""Write a tox.ini file with 'source' as contents."""
return self.makefile(".ini", tox=source)
- def getinicfg(self, source):
+ def getinicfg(self, source: str) -> SectionWrapper:
"""Return the pytest section from the tox.ini config file."""
p = self.makeini(source)
- return py.iniconfig.IniConfig(p)["pytest"]
+ return IniConfig(str(p))["pytest"]
+
+ def makepyprojecttoml(self, source: str) -> Path:
+ """Write a pyproject.toml file with 'source' as contents.
+
+ .. versionadded:: 6.0
+ """
+ return self.makefile(".toml", pyproject=source)
+
+ def makepyfile(self, *args, **kwargs) -> Path:
+ r"""Shortcut for .makefile() with a .py extension.
+
+ Defaults to the test name with a '.py' extension, e.g test_foobar.py, overwriting
+ existing files.
+
+ Examples:
+
+ .. code-block:: python
+
+ def test_something(pytester):
+ # Initial file is created test_something.py.
+ pytester.makepyfile("foobar")
+ # To create multiple files, pass kwargs accordingly.
+ pytester.makepyfile(custom="foobar")
+ # At this point, both 'test_something.py' & 'custom.py' exist in the test directory.
- def makepyfile(self, *args, **kwargs):
- """Shortcut for .makefile() with a .py extension."""
+ """
return self._makefile(".py", args, kwargs)
- def maketxtfile(self, *args, **kwargs):
- """Shortcut for .makefile() with a .txt extension."""
+ def maketxtfile(self, *args, **kwargs) -> Path:
+ r"""Shortcut for .makefile() with a .txt extension.
+
+ Defaults to the test name with a '.txt' extension, e.g test_foobar.txt, overwriting
+ existing files.
+
+ Examples:
+
+ .. code-block:: python
+
+ def test_something(pytester):
+ # Initial file is created test_something.txt.
+ pytester.maketxtfile("foobar")
+ # To create multiple files, pass kwargs accordingly.
+ pytester.maketxtfile(custom="foobar")
+ # At this point, both 'test_something.txt' & 'custom.txt' exist in the test directory.
+
+ """
return self._makefile(".txt", args, kwargs)
- def syspathinsert(self, path=None):
+ def syspathinsert(
+ self, path: Optional[Union[str, "os.PathLike[str]"]] = None
+ ) -> None:
"""Prepend a directory to sys.path, defaults to :py:attr:`tmpdir`.
This is undone automatically when this object dies at the end of each
test.
"""
if path is None:
- path = self.tmpdir
+ path = self.path
- self.monkeypatch.syspath_prepend(str(path))
+ self._monkeypatch.syspath_prepend(str(path))
- def mkdir(self, name):
+ def mkdir(self, name: str) -> Path:
"""Create a new (sub)directory."""
- return self.tmpdir.mkdir(name)
+ p = self.path / name
+ p.mkdir()
+ return p
- def mkpydir(self, name):
+ def mkpydir(self, name: str) -> Path:
"""Create a new python package.
This creates a (sub)directory with an empty ``__init__.py`` file so it
- gets recognised as a python package.
-
+ gets recognised as a Python package.
"""
- p = self.mkdir(name)
- p.ensure("__init__.py")
+ p = self.path / name
+ p.mkdir()
+ p.joinpath("__init__.py").touch()
return p
- def copy_example(self, name=None):
+ def copy_example(self, name: Optional[str] = None) -> Path:
"""Copy file from project's directory into the testdir.
:param str name: The name of the file to copy.
- :return: path to the copied directory (inside ``self.tmpdir``).
+ :return: path to the copied directory (inside ``self.path``).
"""
- import warnings
- from _pytest.warning_types import PYTESTER_COPY_EXAMPLE
-
- warnings.warn(PYTESTER_COPY_EXAMPLE, stacklevel=2)
- example_dir = self.request.config.getini("pytester_example_dir")
+ example_dir = self._request.config.getini("pytester_example_dir")
if example_dir is None:
raise ValueError("pytester_example_dir is unset, can't copy examples")
- example_dir = self.request.config.rootdir.join(example_dir)
+ example_dir = Path(str(self._request.config.rootdir)) / example_dir
- for extra_element in self.request.node.iter_markers("pytester_example_path"):
+ for extra_element in self._request.node.iter_markers("pytester_example_path"):
assert extra_element.args
- example_dir = example_dir.join(*extra_element.args)
+ example_dir = example_dir.joinpath(*extra_element.args)
if name is None:
func_name = self._name
maybe_dir = example_dir / func_name
maybe_file = example_dir / (func_name + ".py")
- if maybe_dir.isdir():
+ if maybe_dir.is_dir():
example_path = maybe_dir
- elif maybe_file.isfile():
+ elif maybe_file.is_file():
example_path = maybe_file
else:
raise LookupError(
- "{} cant be found as module or package in {}".format(
- func_name, example_dir.bestrelpath(self.request.config.rootdir)
- )
+ f"{func_name} can't be found as module or package in {example_dir}"
)
else:
- example_path = example_dir.join(name)
-
- if example_path.isdir() and not example_path.join("__init__.py").isfile():
- example_path.copy(self.tmpdir)
- return self.tmpdir
- elif example_path.isfile():
- result = self.tmpdir.join(example_path.basename)
- example_path.copy(result)
+ example_path = example_dir.joinpath(name)
+
+ if example_path.is_dir() and not example_path.joinpath("__init__.py").is_file():
+ # TODO: py.path.local.copy can copy files to existing directories,
+ # while with shutil.copytree the destination directory cannot exist,
+ # we will need to roll our own in order to drop py.path.local completely
+ py.path.local(example_path).copy(py.path.local(self.path))
+ return self.path
+ elif example_path.is_file():
+ result = self.path.joinpath(example_path.name)
+ shutil.copy(example_path, result)
return result
else:
raise LookupError(
- 'example "{}" is not found as a file or directory'.format(example_path)
+ f'example "{example_path}" is not found as a file or directory'
)
Session = Session
- def getnode(self, config, arg):
+ def getnode(
+ self, config: Config, arg: Union[str, "os.PathLike[str]"]
+ ) -> Optional[Union[Collector, Item]]:
"""Return the collection node of a file.
- :param config: :py:class:`_pytest.config.Config` instance, see
- :py:meth:`parseconfig` and :py:meth:`parseconfigure` to create the
- configuration
-
- :param arg: a :py:class:`py.path.local` instance of the file
-
+ :param _pytest.config.Config config:
+ A pytest config.
+ See :py:meth:`parseconfig` and :py:meth:`parseconfigure` for creating it.
+ :param py.path.local arg:
+ Path to the file.
"""
session = Session.from_config(config)
assert "::" not in str(arg)
@@ -773,15 +945,15 @@ class Testdir:
config.hook.pytest_sessionfinish(session=session, exitstatus=ExitCode.OK)
return res
- def getpathnode(self, path):
+ def getpathnode(self, path: Union[str, "os.PathLike[str]"]):
"""Return the collection node of a file.
This is like :py:meth:`getnode` but uses :py:meth:`parseconfigure` to
create the (configured) pytest Config instance.
- :param path: a :py:class:`py.path.local` instance of the file
-
+ :param py.path.local path: Path to the file.
"""
+ path = py.path.local(path)
config = self.parseconfigure(path)
session = Session.from_config(config)
x = session.fspath.bestrelpath(path)
@@ -790,66 +962,67 @@ class Testdir:
config.hook.pytest_sessionfinish(session=session, exitstatus=ExitCode.OK)
return res
- def genitems(self, colitems):
+ def genitems(self, colitems: Sequence[Union[Item, Collector]]) -> List[Item]:
"""Generate all test items from a collection node.
This recurses into the collection node and returns a list of all the
test items contained within.
-
"""
session = colitems[0].session
- result = []
+ result: List[Item] = []
for colitem in colitems:
result.extend(session.genitems(colitem))
return result
- def runitem(self, source):
+ def runitem(self, source: str) -> Any:
"""Run the "test_func" Item.
The calling test instance (class containing the test method) must
provide a ``.getrunner()`` method which should return a runner which
can run the test protocol for a single item, e.g.
:py:func:`_pytest.runner.runtestprotocol`.
-
"""
# used from runner functional tests
item = self.getitem(source)
# the test class where we are called from wants to provide the runner
- testclassinstance = self.request.instance
+ testclassinstance = self._request.instance
runner = testclassinstance.getrunner()
return runner(item)
- def inline_runsource(self, source, *cmdlineargs):
+ def inline_runsource(self, source: str, *cmdlineargs) -> HookRecorder:
"""Run a test module in process using ``pytest.main()``.
This run writes "source" into a temporary file and runs
``pytest.main()`` on it, returning a :py:class:`HookRecorder` instance
for the result.
- :param source: the source code of the test module
+ :param source: The source code of the test module.
- :param cmdlineargs: any extra command line arguments to use
-
- :return: :py:class:`HookRecorder` instance of the result
+ :param cmdlineargs: Any extra command line arguments to use.
+ :returns: :py:class:`HookRecorder` instance of the result.
"""
p = self.makepyfile(source)
values = list(cmdlineargs) + [p]
return self.inline_run(*values)
- def inline_genitems(self, *args):
+ def inline_genitems(self, *args) -> Tuple[List[Item], HookRecorder]:
"""Run ``pytest.main(['--collectonly'])`` in-process.
Runs the :py:func:`pytest.main` function to run all of pytest inside
the test process itself like :py:meth:`inline_run`, but returns a
tuple of the collected items and a :py:class:`HookRecorder` instance.
-
"""
rec = self.inline_run("--collect-only", *args)
items = [x.item for x in rec.getcalls("pytest_itemcollected")]
return items, rec
- def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False):
+ def inline_run(
+ self,
+ *args: Union[str, "os.PathLike[str]"],
+ plugins=(),
+ no_reraise_ctrlc: bool = False,
+ ) -> HookRecorder:
"""Run ``pytest.main()`` in-process, returning a HookRecorder.
Runs the :py:func:`pytest.main` function to run all of pytest inside
@@ -858,14 +1031,15 @@ class Testdir:
from that run than can be done by matching stdout/stderr from
:py:meth:`runpytest`.
- :param args: command line arguments to pass to :py:func:`pytest.main`
-
- :kwarg plugins: extra plugin instances the ``pytest.main()`` instance should use.
-
- :kwarg no_reraise_ctrlc: typically we reraise keyboard interrupts from the child run. If
+ :param args:
+ Command line arguments to pass to :py:func:`pytest.main`.
+ :param plugins:
+ Extra plugin instances the ``pytest.main()`` instance should use.
+ :param no_reraise_ctrlc:
+ Typically we reraise keyboard interrupts from the child run. If
True, the KeyboardInterrupt exception is captured.
- :return: a :py:class:`HookRecorder` instance
+ :returns: A :py:class:`HookRecorder` instance.
"""
# (maybe a cpython bug?) the importlib cache sometimes isn't updated
# properly between file creation and inline_run (especially if imports
@@ -891,11 +1065,11 @@ class Testdir:
rec = []
class Collect:
- def pytest_configure(x, config):
+ def pytest_configure(x, config: Config) -> None:
rec.append(self.make_hook_recorder(config.pluginmanager))
plugins.append(Collect())
- ret = pytest.main(list(args), plugins=plugins)
+ ret = main([str(x) for x in args], plugins=plugins)
if len(rec) == 1:
reprec = rec.pop()
else:
@@ -903,10 +1077,10 @@ class Testdir:
class reprec: # type: ignore
pass
- reprec.ret = ret
+ reprec.ret = ret # type: ignore
- # typically we reraise keyboard interrupts from the child run
- # because it's our user requesting interruption of the testing
+ # Typically we reraise keyboard interrupts from the child run
+ # because it's our user requesting interruption of the testing.
if ret == ExitCode.INTERRUPTED and not no_reraise_ctrlc:
calls = reprec.getcalls("pytest_keyboard_interrupt")
if calls and calls[-1].excinfo.type == KeyboardInterrupt:
@@ -916,16 +1090,17 @@ class Testdir:
for finalizer in finalizers:
finalizer()
- def runpytest_inprocess(self, *args, **kwargs) -> RunResult:
+ def runpytest_inprocess(
+ self, *args: Union[str, "os.PathLike[str]"], **kwargs: Any
+ ) -> RunResult:
"""Return result of running pytest in-process, providing a similar
- interface to what self.runpytest() provides.
- """
+ interface to what self.runpytest() provides."""
syspathinsert = kwargs.pop("syspathinsert", False)
if syspathinsert:
self.syspathinsert()
- now = time.time()
- capture = MultiCapture(Capture=SysCapture)
+ now = timing.time()
+ capture = _get_multicapture("sys")
capture.start_capturing()
try:
try:
@@ -952,34 +1127,37 @@ class Testdir:
sys.stdout.write(out)
sys.stderr.write(err)
+ assert reprec.ret is not None
res = RunResult(
- reprec.ret, out.splitlines(), err.splitlines(), time.time() - now
+ reprec.ret, out.splitlines(), err.splitlines(), timing.time() - now
)
res.reprec = reprec # type: ignore
return res
- def runpytest(self, *args, **kwargs) -> RunResult:
+ def runpytest(
+ self, *args: Union[str, "os.PathLike[str]"], **kwargs: Any
+ ) -> RunResult:
"""Run pytest inline or in a subprocess, depending on the command line
- option "--runpytest" and return a :py:class:`RunResult`.
-
- """
- args = self._ensure_basetemp(args)
+ option "--runpytest" and return a :py:class:`RunResult`."""
+ new_args = self._ensure_basetemp(args)
if self._method == "inprocess":
- return self.runpytest_inprocess(*args, **kwargs)
+ return self.runpytest_inprocess(*new_args, **kwargs)
elif self._method == "subprocess":
- return self.runpytest_subprocess(*args, **kwargs)
- raise RuntimeError("Unrecognized runpytest option: {}".format(self._method))
-
- def _ensure_basetemp(self, args):
- args = list(args)
- for x in args:
+ return self.runpytest_subprocess(*new_args, **kwargs)
+ raise RuntimeError(f"Unrecognized runpytest option: {self._method}")
+
+ def _ensure_basetemp(
+ self, args: Sequence[Union[str, "os.PathLike[str]"]]
+ ) -> List[Union[str, "os.PathLike[str]"]]:
+ new_args = list(args)
+ for x in new_args:
if str(x).startswith("--basetemp"):
break
else:
- args.append("--basetemp=%s" % self.tmpdir.dirpath("basetemp"))
- return args
+ new_args.append("--basetemp=%s" % self.path.parent.joinpath("basetemp"))
+ return new_args
- def parseconfig(self, *args):
+ def parseconfig(self, *args: Union[str, "os.PathLike[str]"]) -> Config:
"""Return a new pytest Config instance from given commandline args.
This invokes the pytest bootstrapping code in _pytest.config to create
@@ -989,41 +1167,40 @@ class Testdir:
If :py:attr:`plugins` has been populated they should be plugin modules
to be registered with the PluginManager.
-
"""
- args = self._ensure_basetemp(args)
-
import _pytest.config
- config = _pytest.config._prepareconfig(args, self.plugins)
+ new_args = self._ensure_basetemp(args)
+ new_args = [str(x) for x in new_args]
+
+ config = _pytest.config._prepareconfig(new_args, self.plugins) # type: ignore[arg-type]
# we don't know what the test will do with this half-setup config
# object and thus we make sure it gets unconfigured properly in any
# case (otherwise capturing could still be active, for example)
- self.request.addfinalizer(config._ensure_unconfigure)
+ self._request.addfinalizer(config._ensure_unconfigure)
return config
- def parseconfigure(self, *args):
+ def parseconfigure(self, *args: Union[str, "os.PathLike[str]"]) -> Config:
"""Return a new pytest configured Config instance.
- This returns a new :py:class:`_pytest.config.Config` instance like
+ Returns a new :py:class:`_pytest.config.Config` instance like
:py:meth:`parseconfig`, but also calls the pytest_configure hook.
"""
config = self.parseconfig(*args)
config._do_configure()
return config
- def getitem(self, source, funcname="test_func"):
+ def getitem(self, source: str, funcname: str = "test_func") -> Item:
"""Return the test item for a test function.
- This writes the source to a python file and runs pytest's collection on
+ Writes the source to a python file and runs pytest's collection on
the resulting module, returning the test item for the requested
function name.
- :param source: the module source
-
- :param funcname: the name of the test function for which to return a
- test item
-
+ :param source:
+ The module source.
+ :param funcname:
+ The name of the test function for which to return a test item.
"""
items = self.getitems(source)
for item in items:
@@ -1033,37 +1210,39 @@ class Testdir:
funcname, source, items
)
- def getitems(self, source):
+ def getitems(self, source: str) -> List[Item]:
"""Return all test items collected from the module.
- This writes the source to a python file and runs pytest's collection on
+ Writes the source to a Python file and runs pytest's collection on
the resulting module, returning all test items contained within.
-
"""
modcol = self.getmodulecol(source)
return self.genitems([modcol])
- def getmodulecol(self, source, configargs=(), withinit=False):
+ def getmodulecol(
+ self, source: Union[str, Path], configargs=(), *, withinit: bool = False
+ ):
"""Return the module collection node for ``source``.
- This writes ``source`` to a file using :py:meth:`makepyfile` and then
+ Writes ``source`` to a file using :py:meth:`makepyfile` and then
runs the pytest collection on it, returning the collection node for the
test module.
- :param source: the source code of the module to collect
-
- :param configargs: any extra arguments to pass to
- :py:meth:`parseconfigure`
+ :param source:
+ The source code of the module to collect.
- :param withinit: whether to also write an ``__init__.py`` file to the
- same directory to ensure it is a package
+ :param configargs:
+ Any extra arguments to pass to :py:meth:`parseconfigure`.
+ :param withinit:
+ Whether to also write an ``__init__.py`` file to the same
+ directory to ensure it is a package.
"""
if isinstance(source, Path):
- path = self.tmpdir.join(str(source))
+ path = self.path.joinpath(source)
assert not withinit, "not supported for paths"
else:
- kw = {self._name: Source(source).strip()}
+ kw = {self._name: str(source)}
path = self.makepyfile(**kw)
if withinit:
self.makepyfile(__init__="#")
@@ -1071,16 +1250,15 @@ class Testdir:
return self.getnode(config, path)
def collect_by_name(
- self, modcol: Module, name: str
+ self, modcol: Collector, name: str
) -> Optional[Union[Item, Collector]]:
"""Return the collection node for name from the module collection.
- This will search a module collection node for a collection node
- matching the given name.
+ Searchs a module collection node for a collection node matching the
+ given name.
- :param modcol: a module collection node; see :py:meth:`getmodulecol`
-
- :param name: the name of the node to return
+ :param modcol: A module collection node; see :py:meth:`getmodulecol`.
+ :param name: The name of the node to return.
"""
if modcol not in self._mod_collections:
self._mod_collections[modcol] = list(modcol.collect())
@@ -1092,18 +1270,17 @@ class Testdir:
def popen(
self,
cmdargs,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
+ stdout: Union[int, TextIO] = subprocess.PIPE,
+ stderr: Union[int, TextIO] = subprocess.PIPE,
stdin=CLOSE_STDIN,
- **kw
+ **kw,
):
"""Invoke subprocess.Popen.
- This calls subprocess.Popen making sure the current working directory
- is in the PYTHONPATH.
+ Calls subprocess.Popen making sure the current working directory is
+ in the PYTHONPATH.
You probably want to use :py:meth:`run` instead.
-
"""
env = os.environ.copy()
env["PYTHONPATH"] = os.pathsep.join(
@@ -1111,7 +1288,7 @@ class Testdir:
)
kw["env"] = env
- if stdin is Testdir.CLOSE_STDIN:
+ if stdin is self.CLOSE_STDIN:
kw["stdin"] = subprocess.PIPE
elif isinstance(stdin, bytes):
kw["stdin"] = subprocess.PIPE
@@ -1119,42 +1296,53 @@ class Testdir:
kw["stdin"] = stdin
popen = subprocess.Popen(cmdargs, stdout=stdout, stderr=stderr, **kw)
- if stdin is Testdir.CLOSE_STDIN:
+ if stdin is self.CLOSE_STDIN:
+ assert popen.stdin is not None
popen.stdin.close()
elif isinstance(stdin, bytes):
+ assert popen.stdin is not None
popen.stdin.write(stdin)
return popen
- def run(self, *cmdargs, timeout=None, stdin=CLOSE_STDIN) -> RunResult:
+ def run(
+ self,
+ *cmdargs: Union[str, "os.PathLike[str]"],
+ timeout: Optional[float] = None,
+ stdin=CLOSE_STDIN,
+ ) -> RunResult:
"""Run a command with arguments.
Run a process using subprocess.Popen saving the stdout and stderr.
- :param args: the sequence of arguments to pass to `subprocess.Popen()`
- :kwarg timeout: the period in seconds after which to timeout and raise
- :py:class:`Testdir.TimeoutExpired`
- :kwarg stdin: optional standard input. Bytes are being send, closing
+ :param cmdargs:
+ The sequence of arguments to pass to `subprocess.Popen()`, with path-like objects
+ being converted to ``str`` automatically.
+ :param timeout:
+ The period in seconds after which to timeout and raise
+ :py:class:`Pytester.TimeoutExpired`.
+ :param stdin:
+ Optional standard input. Bytes are being send, closing
the pipe, otherwise it is passed through to ``popen``.
Defaults to ``CLOSE_STDIN``, which translates to using a pipe
(``subprocess.PIPE``) that gets closed.
- Returns a :py:class:`RunResult`.
-
+ :rtype: RunResult
"""
__tracebackhide__ = True
+ # TODO: Remove type ignore in next mypy release.
+ # https://github.com/python/typeshed/pull/4582
cmdargs = tuple(
- str(arg) if isinstance(arg, py.path.local) else arg for arg in cmdargs
+ os.fspath(arg) if isinstance(arg, os.PathLike) else arg for arg in cmdargs # type: ignore[misc]
)
- p1 = self.tmpdir.join("stdout")
- p2 = self.tmpdir.join("stderr")
+ p1 = self.path.joinpath("stdout")
+ p2 = self.path.joinpath("stderr")
print("running:", *cmdargs)
- print(" in:", py.path.local())
- f1 = open(str(p1), "w", encoding="utf8")
- f2 = open(str(p2), "w", encoding="utf8")
- try:
- now = time.time()
+ print(" in:", Path.cwd())
+
+ with p1.open("w", encoding="utf8") as f1, p2.open("w", encoding="utf8") as f2:
+ now = timing.time()
popen = self.popen(
cmdargs,
stdin=stdin,
@@ -1162,10 +1350,10 @@ class Testdir:
stderr=f2,
close_fds=(sys.platform != "win32"),
)
- if isinstance(stdin, bytes):
+ if popen.stdin is not None:
popen.stdin.close()
- def handle_timeout():
+ def handle_timeout() -> None:
__tracebackhide__ = True
timeout_message = (
@@ -1184,48 +1372,43 @@ class Testdir:
ret = popen.wait(timeout)
except subprocess.TimeoutExpired:
handle_timeout()
- finally:
- f1.close()
- f2.close()
- f1 = open(str(p1), "r", encoding="utf8")
- f2 = open(str(p2), "r", encoding="utf8")
- try:
+
+ with p1.open(encoding="utf8") as f1, p2.open(encoding="utf8") as f2:
out = f1.read().splitlines()
err = f2.read().splitlines()
- finally:
- f1.close()
- f2.close()
+
self._dump_lines(out, sys.stdout)
self._dump_lines(err, sys.stderr)
- try:
+
+ with contextlib.suppress(ValueError):
ret = ExitCode(ret)
- except ValueError:
- pass
- return RunResult(ret, out, err, time.time() - now)
+ return RunResult(ret, out, err, timing.time() - now)
def _dump_lines(self, lines, fp):
try:
for line in lines:
print(line, file=fp)
except UnicodeEncodeError:
- print("couldn't print to {} because of encoding".format(fp))
+ print(f"couldn't print to {fp} because of encoding")
- def _getpytestargs(self):
+ def _getpytestargs(self) -> Tuple[str, ...]:
return sys.executable, "-mpytest"
def runpython(self, script) -> RunResult:
"""Run a python script using sys.executable as interpreter.
- Returns a :py:class:`RunResult`.
-
+ :rtype: RunResult
"""
return self.run(sys.executable, script)
def runpython_c(self, command):
- """Run python -c "command", return a :py:class:`RunResult`."""
+ """Run python -c "command".
+
+ :rtype: RunResult
+ """
return self.run(sys.executable, "-c", command)
- def runpytest_subprocess(self, *args, timeout=None) -> RunResult:
+ def runpytest_subprocess(self, *args, timeout: Optional[float] = None) -> RunResult:
"""Run pytest as a subprocess with given arguments.
Any plugins added to the :py:attr:`plugins` list will be added using the
@@ -1234,16 +1417,16 @@ class Testdir:
with "runpytest-" to not conflict with the normal numbered pytest
location for temporary files and directories.
- :param args: the sequence of arguments to pass to the pytest subprocess
- :param timeout: the period in seconds after which to timeout and raise
- :py:class:`Testdir.TimeoutExpired`
+ :param args:
+ The sequence of arguments to pass to the pytest subprocess.
+ :param timeout:
+ The period in seconds after which to timeout and raise
+ :py:class:`Pytester.TimeoutExpired`.
- Returns a :py:class:`RunResult`.
+ :rtype: RunResult
"""
__tracebackhide__ = True
- p = py.path.local.make_numbered_dir(
- prefix="runpytest-", keep=None, rootdir=self.tmpdir
- )
+ p = make_numbered_dir(root=self.path, prefix="runpytest-", mode=0o700)
args = ("--basetemp=%s" % p,) + args
plugins = [x for x in self.plugins if isinstance(x, str)]
if plugins:
@@ -1260,29 +1443,27 @@ class Testdir:
directory locations.
The pexpect child is returned.
-
"""
- basetemp = self.tmpdir.mkdir("temp-pexpect")
+ basetemp = self.path / "temp-pexpect"
+ basetemp.mkdir(mode=0o700)
invoke = " ".join(map(str, self._getpytestargs()))
- cmd = "{} --basetemp={} {}".format(invoke, basetemp, string)
+ cmd = f"{invoke} --basetemp={basetemp} {string}"
return self.spawn(cmd, expect_timeout=expect_timeout)
def spawn(self, cmd: str, expect_timeout: float = 10.0) -> "pexpect.spawn":
"""Run a command using pexpect.
The pexpect child is returned.
-
"""
- pexpect = pytest.importorskip("pexpect", "3.0")
+ pexpect = importorskip("pexpect", "3.0")
if hasattr(sys, "pypy_version_info") and "64" in platform.machine():
- pytest.skip("pypy-64 bit not supported")
+ skip("pypy-64 bit not supported")
if not hasattr(pexpect, "spawn"):
- pytest.skip("pexpect.spawn not available")
- logfile = self.tmpdir.join("spawn.out").open("wb")
+ skip("pexpect.spawn not available")
+ logfile = self.path.joinpath("spawn.out").open("wb")
- child = pexpect.spawn(cmd, logfile=logfile)
- self.request.addfinalizer(logfile.close)
- child.timeout = expect_timeout
+ child = pexpect.spawn(cmd, logfile=logfile, timeout=expect_timeout)
+ self._request.addfinalizer(logfile.close)
return child
@@ -1304,6 +1485,217 @@ class LineComp:
LineMatcher(lines1).fnmatch_lines(lines2)
+@final
+@attr.s(repr=False, str=False, init=False)
+class Testdir:
+ """
+ Similar to :class:`Pytester`, but this class works with legacy py.path.local objects instead.
+
+ All methods just forward to an internal :class:`Pytester` instance, converting results
+ to `py.path.local` objects as necessary.
+ """
+
+ __test__ = False
+
+ CLOSE_STDIN = Pytester.CLOSE_STDIN
+ TimeoutExpired = Pytester.TimeoutExpired
+ Session = Pytester.Session
+
+ def __init__(self, pytester: Pytester, *, _ispytest: bool = False) -> None:
+ check_ispytest(_ispytest)
+ self._pytester = pytester
+
+ @property
+ def tmpdir(self) -> py.path.local:
+ """Temporary directory where tests are executed."""
+ return py.path.local(self._pytester.path)
+
+ @property
+ def test_tmproot(self) -> py.path.local:
+ return py.path.local(self._pytester._test_tmproot)
+
+ @property
+ def request(self):
+ return self._pytester._request
+
+ @property
+ def plugins(self):
+ return self._pytester.plugins
+
+ @plugins.setter
+ def plugins(self, plugins):
+ self._pytester.plugins = plugins
+
+ @property
+ def monkeypatch(self) -> MonkeyPatch:
+ return self._pytester._monkeypatch
+
+ def make_hook_recorder(self, pluginmanager) -> HookRecorder:
+ """See :meth:`Pytester.make_hook_recorder`."""
+ return self._pytester.make_hook_recorder(pluginmanager)
+
+ def chdir(self) -> None:
+ """See :meth:`Pytester.chdir`."""
+ return self._pytester.chdir()
+
+ def finalize(self) -> None:
+ """See :meth:`Pytester._finalize`."""
+ return self._pytester._finalize()
+
+ def makefile(self, ext, *args, **kwargs) -> py.path.local:
+ """See :meth:`Pytester.makefile`."""
+ return py.path.local(str(self._pytester.makefile(ext, *args, **kwargs)))
+
+ def makeconftest(self, source) -> py.path.local:
+ """See :meth:`Pytester.makeconftest`."""
+ return py.path.local(str(self._pytester.makeconftest(source)))
+
+ def makeini(self, source) -> py.path.local:
+ """See :meth:`Pytester.makeini`."""
+ return py.path.local(str(self._pytester.makeini(source)))
+
+ def getinicfg(self, source: str) -> SectionWrapper:
+ """See :meth:`Pytester.getinicfg`."""
+ return self._pytester.getinicfg(source)
+
+ def makepyprojecttoml(self, source) -> py.path.local:
+ """See :meth:`Pytester.makepyprojecttoml`."""
+ return py.path.local(str(self._pytester.makepyprojecttoml(source)))
+
+ def makepyfile(self, *args, **kwargs) -> py.path.local:
+ """See :meth:`Pytester.makepyfile`."""
+ return py.path.local(str(self._pytester.makepyfile(*args, **kwargs)))
+
+ def maketxtfile(self, *args, **kwargs) -> py.path.local:
+ """See :meth:`Pytester.maketxtfile`."""
+ return py.path.local(str(self._pytester.maketxtfile(*args, **kwargs)))
+
+ def syspathinsert(self, path=None) -> None:
+ """See :meth:`Pytester.syspathinsert`."""
+ return self._pytester.syspathinsert(path)
+
+ def mkdir(self, name) -> py.path.local:
+ """See :meth:`Pytester.mkdir`."""
+ return py.path.local(str(self._pytester.mkdir(name)))
+
+ def mkpydir(self, name) -> py.path.local:
+ """See :meth:`Pytester.mkpydir`."""
+ return py.path.local(str(self._pytester.mkpydir(name)))
+
+ def copy_example(self, name=None) -> py.path.local:
+ """See :meth:`Pytester.copy_example`."""
+ return py.path.local(str(self._pytester.copy_example(name)))
+
+ def getnode(self, config: Config, arg) -> Optional[Union[Item, Collector]]:
+ """See :meth:`Pytester.getnode`."""
+ return self._pytester.getnode(config, arg)
+
+ def getpathnode(self, path):
+ """See :meth:`Pytester.getpathnode`."""
+ return self._pytester.getpathnode(path)
+
+ def genitems(self, colitems: List[Union[Item, Collector]]) -> List[Item]:
+ """See :meth:`Pytester.genitems`."""
+ return self._pytester.genitems(colitems)
+
+ def runitem(self, source):
+ """See :meth:`Pytester.runitem`."""
+ return self._pytester.runitem(source)
+
+ def inline_runsource(self, source, *cmdlineargs):
+ """See :meth:`Pytester.inline_runsource`."""
+ return self._pytester.inline_runsource(source, *cmdlineargs)
+
+ def inline_genitems(self, *args):
+ """See :meth:`Pytester.inline_genitems`."""
+ return self._pytester.inline_genitems(*args)
+
+ def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False):
+ """See :meth:`Pytester.inline_run`."""
+ return self._pytester.inline_run(
+ *args, plugins=plugins, no_reraise_ctrlc=no_reraise_ctrlc
+ )
+
+ def runpytest_inprocess(self, *args, **kwargs) -> RunResult:
+ """See :meth:`Pytester.runpytest_inprocess`."""
+ return self._pytester.runpytest_inprocess(*args, **kwargs)
+
+ def runpytest(self, *args, **kwargs) -> RunResult:
+ """See :meth:`Pytester.runpytest`."""
+ return self._pytester.runpytest(*args, **kwargs)
+
+ def parseconfig(self, *args) -> Config:
+ """See :meth:`Pytester.parseconfig`."""
+ return self._pytester.parseconfig(*args)
+
+ def parseconfigure(self, *args) -> Config:
+ """See :meth:`Pytester.parseconfigure`."""
+ return self._pytester.parseconfigure(*args)
+
+ def getitem(self, source, funcname="test_func"):
+ """See :meth:`Pytester.getitem`."""
+ return self._pytester.getitem(source, funcname)
+
+ def getitems(self, source):
+ """See :meth:`Pytester.getitems`."""
+ return self._pytester.getitems(source)
+
+ def getmodulecol(self, source, configargs=(), withinit=False):
+ """See :meth:`Pytester.getmodulecol`."""
+ return self._pytester.getmodulecol(
+ source, configargs=configargs, withinit=withinit
+ )
+
+ def collect_by_name(
+ self, modcol: Collector, name: str
+ ) -> Optional[Union[Item, Collector]]:
+ """See :meth:`Pytester.collect_by_name`."""
+ return self._pytester.collect_by_name(modcol, name)
+
+ def popen(
+ self,
+ cmdargs,
+ stdout: Union[int, TextIO] = subprocess.PIPE,
+ stderr: Union[int, TextIO] = subprocess.PIPE,
+ stdin=CLOSE_STDIN,
+ **kw,
+ ):
+ """See :meth:`Pytester.popen`."""
+ return self._pytester.popen(cmdargs, stdout, stderr, stdin, **kw)
+
+ def run(self, *cmdargs, timeout=None, stdin=CLOSE_STDIN) -> RunResult:
+ """See :meth:`Pytester.run`."""
+ return self._pytester.run(*cmdargs, timeout=timeout, stdin=stdin)
+
+ def runpython(self, script) -> RunResult:
+ """See :meth:`Pytester.runpython`."""
+ return self._pytester.runpython(script)
+
+ def runpython_c(self, command):
+ """See :meth:`Pytester.runpython_c`."""
+ return self._pytester.runpython_c(command)
+
+ def runpytest_subprocess(self, *args, timeout=None) -> RunResult:
+ """See :meth:`Pytester.runpytest_subprocess`."""
+ return self._pytester.runpytest_subprocess(*args, timeout=timeout)
+
+ def spawn_pytest(
+ self, string: str, expect_timeout: float = 10.0
+ ) -> "pexpect.spawn":
+ """See :meth:`Pytester.spawn_pytest`."""
+ return self._pytester.spawn_pytest(string, expect_timeout=expect_timeout)
+
+ def spawn(self, cmd: str, expect_timeout: float = 10.0) -> "pexpect.spawn":
+ """See :meth:`Pytester.spawn`."""
+ return self._pytester.spawn(cmd, expect_timeout=expect_timeout)
+
+ def __repr__(self) -> str:
+ return f"<Testdir {self.tmpdir!r}>"
+
+ def __str__(self) -> str:
+ return str(self.tmpdir)
+
+
class LineMatcher:
"""Flexible matching of text.
@@ -1316,7 +1708,15 @@ class LineMatcher:
def __init__(self, lines: List[str]) -> None:
self.lines = lines
- self._log_output = [] # type: List[str]
+ self._log_output: List[str] = []
+
+ def __str__(self) -> str:
+ """Return the entire original text.
+
+ .. versionadded:: 6.2
+ You can use :meth:`str` in older versions.
+ """
+ return "\n".join(self.lines)
def _getlines(self, lines2: Union[str, Sequence[str], Source]) -> Sequence[str]:
if isinstance(lines2, str):
@@ -1326,14 +1726,12 @@ class LineMatcher:
return lines2
def fnmatch_lines_random(self, lines2: Sequence[str]) -> None:
- """Check lines exist in the output in any order (using :func:`python:fnmatch.fnmatch`).
- """
+ """Check lines exist in the output in any order (using :func:`python:fnmatch.fnmatch`)."""
__tracebackhide__ = True
self._match_lines_random(lines2, fnmatch)
def re_match_lines_random(self, lines2: Sequence[str]) -> None:
- """Check lines exist in the output in any order (using :func:`python:re.match`).
- """
+ """Check lines exist in the output in any order (using :func:`python:re.match`)."""
__tracebackhide__ = True
self._match_lines_random(lines2, lambda name, pat: bool(re.match(pat, name)))
@@ -1378,8 +1776,8 @@ class LineMatcher:
wildcards. If they do not match a pytest.fail() is called. The
matches and non-matches are also shown as part of the error message.
- :param lines2: string patterns to match.
- :param consecutive: match lines consecutive?
+ :param lines2: String patterns to match.
+ :param consecutive: Match lines consecutively?
"""
__tracebackhide__ = True
self._match_lines(lines2, fnmatch, "fnmatch", consecutive=consecutive)
@@ -1411,24 +1809,27 @@ class LineMatcher:
match_func: Callable[[str, str], bool],
match_nickname: str,
*,
- consecutive: bool = False
+ consecutive: bool = False,
) -> None:
"""Underlying implementation of ``fnmatch_lines`` and ``re_match_lines``.
- :param list[str] lines2: list of string patterns to match. The actual
- format depends on ``match_func``
- :param match_func: a callable ``match_func(line, pattern)`` where line
- is the captured line from stdout/stderr and pattern is the matching
- pattern
- :param str match_nickname: the nickname for the match function that
- will be logged to stdout when a match occurs
- :param consecutive: match lines consecutively?
+ :param Sequence[str] lines2:
+ List of string patterns to match. The actual format depends on
+ ``match_func``.
+ :param match_func:
+ A callable ``match_func(line, pattern)`` where line is the
+ captured line from stdout/stderr and pattern is the matching
+ pattern.
+ :param str match_nickname:
+ The nickname for the match function that will be logged to stdout
+ when a match occurs.
+ :param consecutive:
+ Match lines consecutively?
"""
if not isinstance(lines2, collections.abc.Sequence):
raise TypeError("invalid type for lines2: {}".format(type(lines2).__name__))
lines2 = self._getlines(lines2)
lines1 = self.lines[:]
- nextline = None
extralines = []
__tracebackhide__ = True
wnick = len(match_nickname) + 1
@@ -1450,7 +1851,7 @@ class LineMatcher:
break
else:
if consecutive and started:
- msg = "no consecutive match: {!r}".format(line)
+ msg = f"no consecutive match: {line!r}"
self._log(msg)
self._log(
"{:>{width}}".format("with:", width=wnick), repr(nextline)
@@ -1464,7 +1865,7 @@ class LineMatcher:
self._log("{:>{width}}".format("and:", width=wnick), repr(nextline))
extralines.append(nextline)
else:
- msg = "remains unmatched: {!r}".format(line)
+ msg = f"remains unmatched: {line!r}"
self._log(msg)
self._fail(msg)
self._log_output = []
@@ -1472,7 +1873,7 @@ class LineMatcher:
def no_fnmatch_line(self, pat: str) -> None:
"""Ensure captured lines do not match the given pattern, using ``fnmatch.fnmatch``.
- :param str pat: the pattern to match lines.
+ :param str pat: The pattern to match lines.
"""
__tracebackhide__ = True
self._no_match_line(pat, fnmatch, "fnmatch")
@@ -1480,7 +1881,7 @@ class LineMatcher:
def no_re_match_line(self, pat: str) -> None:
"""Ensure captured lines do not match the given pattern, using ``re.match``.
- :param str pat: the regular expression to match lines.
+ :param str pat: The regular expression to match lines.
"""
__tracebackhide__ = True
self._no_match_line(
@@ -1490,16 +1891,16 @@ class LineMatcher:
def _no_match_line(
self, pat: str, match_func: Callable[[str, str], bool], match_nickname: str
) -> None:
- """Ensure captured lines does not have a the given pattern, using ``fnmatch.fnmatch``
+ """Ensure captured lines does not have a the given pattern, using ``fnmatch.fnmatch``.
- :param str pat: the pattern to match lines
+ :param str pat: The pattern to match lines.
"""
__tracebackhide__ = True
nomatch_printed = False
wnick = len(match_nickname) + 1
for line in self.lines:
if match_func(line, pat):
- msg = "{}: {!r}".format(match_nickname, pat)
+ msg = f"{match_nickname}: {pat!r}"
self._log(msg)
self._log("{:>{width}}".format("with:", width=wnick), repr(line))
self._fail(msg)
@@ -1514,8 +1915,8 @@ class LineMatcher:
__tracebackhide__ = True
log_text = self._log_text
self._log_output = []
- pytest.fail(log_text)
+ fail(log_text)
def str(self) -> str:
"""Return the entire original text."""
- return "\n".join(self.lines)
+ return str(self)
diff --git a/contrib/python/pytest/py3/_pytest/pytester_assertions.py b/contrib/python/pytest/py3/_pytest/pytester_assertions.py
new file mode 100644
index 0000000000..630c1d3331
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/pytester_assertions.py
@@ -0,0 +1,66 @@
+"""Helper plugin for pytester; should not be loaded on its own."""
+# This plugin contains assertions used by pytester. pytester cannot
+# contain them itself, since it is imported by the `pytest` module,
+# hence cannot be subject to assertion rewriting, which requires a
+# module to not be already imported.
+from typing import Dict
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+from _pytest.reports import CollectReport
+from _pytest.reports import TestReport
+
+
+def assertoutcome(
+ outcomes: Tuple[
+ Sequence[TestReport],
+ Sequence[Union[CollectReport, TestReport]],
+ Sequence[Union[CollectReport, TestReport]],
+ ],
+ passed: int = 0,
+ skipped: int = 0,
+ failed: int = 0,
+) -> None:
+ __tracebackhide__ = True
+
+ realpassed, realskipped, realfailed = outcomes
+ obtained = {
+ "passed": len(realpassed),
+ "skipped": len(realskipped),
+ "failed": len(realfailed),
+ }
+ expected = {"passed": passed, "skipped": skipped, "failed": failed}
+ assert obtained == expected, outcomes
+
+
+def assert_outcomes(
+ outcomes: Dict[str, int],
+ passed: int = 0,
+ skipped: int = 0,
+ failed: int = 0,
+ errors: int = 0,
+ xpassed: int = 0,
+ xfailed: int = 0,
+) -> None:
+ """Assert that the specified outcomes appear with the respective
+ numbers (0 means it didn't occur) in the text output from a test run."""
+ __tracebackhide__ = True
+
+ obtained = {
+ "passed": outcomes.get("passed", 0),
+ "skipped": outcomes.get("skipped", 0),
+ "failed": outcomes.get("failed", 0),
+ "errors": outcomes.get("errors", 0),
+ "xpassed": outcomes.get("xpassed", 0),
+ "xfailed": outcomes.get("xfailed", 0),
+ }
+ expected = {
+ "passed": passed,
+ "skipped": skipped,
+ "failed": failed,
+ "errors": errors,
+ "xpassed": xpassed,
+ "xfailed": xfailed,
+ }
+ assert obtained == expected
diff --git a/contrib/python/pytest/py3/_pytest/python.py b/contrib/python/pytest/py3/_pytest/python.py
index 0c1df99e36..f1a47d7d33 100644
--- a/contrib/python/pytest/py3/_pytest/python.py
+++ b/contrib/python/pytest/py3/_pytest/python.py
@@ -1,22 +1,29 @@
-""" Python test discovery, setup and run of test functions. """
+"""Python test discovery, setup and run of test functions."""
import enum
import fnmatch
import inspect
import itertools
import os
import sys
-import typing
+import types
import warnings
from collections import Counter
from collections import defaultdict
-from collections.abc import Sequence
from functools import partial
+from typing import Any
from typing import Callable
from typing import Dict
+from typing import Generator
from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import Mapping
from typing import Optional
+from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import Union
import py
@@ -25,11 +32,13 @@ import _pytest
from _pytest import fixtures
from _pytest import nodes
from _pytest._code import filter_traceback
+from _pytest._code import getfslineno
from _pytest._code.code import ExceptionInfo
-from _pytest._code.source import getfslineno
+from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest._io.saferepr import saferepr
from _pytest.compat import ascii_escaped
+from _pytest.compat import final
from _pytest.compat import get_default_arg_names
from _pytest.compat import get_real_func
from _pytest.compat import getimfunc
@@ -42,34 +51,33 @@ from _pytest.compat import safe_getattr
from _pytest.compat import safe_isclass
from _pytest.compat import STRING_TYPES
from _pytest.config import Config
+from _pytest.config import ExitCode
from _pytest.config import hookimpl
-from _pytest.deprecated import FUNCARGNAMES
+from _pytest.config.argparsing import Parser
+from _pytest.deprecated import FSCOLLECTOR_GETHOOKPROXY_ISINITPATH
from _pytest.fixtures import FuncFixtureInfo
+from _pytest.main import Session
from _pytest.mark import MARK_GEN
from _pytest.mark import ParameterSet
from _pytest.mark.structures import get_unpacked_marks
from _pytest.mark.structures import Mark
+from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import normalize_mark_list
from _pytest.outcomes import fail
from _pytest.outcomes import skip
+from _pytest.pathlib import import_path
+from _pytest.pathlib import ImportPathMismatchError
from _pytest.pathlib import parts
+from _pytest.pathlib import visit
from _pytest.warning_types import PytestCollectionWarning
from _pytest.warning_types import PytestUnhandledCoroutineWarning
+if TYPE_CHECKING:
+ from typing_extensions import Literal
+ from _pytest.fixtures import _Scope
-def pyobj_property(name):
- def get(self):
- node = self.getparent(getattr(__import__("pytest"), name))
- if node is not None:
- return node.obj
- doc = "python {} object this node was collected from (can be None).".format(
- name.lower()
- )
- return property(get, None, None, doc)
-
-
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--fixtures",
@@ -114,32 +122,24 @@ def pytest_addoption(parser):
"side effects(use at your own risk)",
)
- group.addoption(
- "--import-mode",
- default="prepend",
- choices=["prepend", "append"],
- dest="importmode",
- help="prepend/append to sys.path when importing test modules, "
- "default is to prepend.",
- )
-
-def pytest_cmdline_main(config):
+def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.showfixtures:
showfixtures(config)
return 0
if config.option.show_fixtures_per_test:
show_fixtures_per_test(config)
return 0
+ return None
def pytest_generate_tests(metafunc: "Metafunc") -> None:
for marker in metafunc.definition.iter_markers(name="parametrize"):
# TODO: Fix this type-ignore (overlapping kwargs).
- metafunc.parametrize(*marker.args, **marker.kwargs, _param_mark=marker) # type: ignore[misc] # noqa: F821
+ metafunc.parametrize(*marker.args, **marker.kwargs, _param_mark=marker) # type: ignore[misc]
-def pytest_configure(config):
+def pytest_configure(config: Config) -> None:
config.addinivalue_line(
"markers",
"parametrize(argnames, argvalues): call a test function multiple "
@@ -148,14 +148,14 @@ def pytest_configure(config):
"or a list of tuples of values if argnames specifies multiple names. "
"Example: @parametrize('arg1', [1,2]) would lead to two calls of the "
"decorated test function, one with arg1=1 and another with arg1=2."
- "see https://docs.pytest.org/en/latest/parametrize.html for more info "
+ "see https://docs.pytest.org/en/stable/parametrize.html for more info "
"and examples.",
)
config.addinivalue_line(
"markers",
"usefixtures(fixturename1, fixturename2, ...): mark tests as needing "
"all of the specified fixtures. see "
- "https://docs.pytest.org/en/latest/fixture.html#usefixtures ",
+ "https://docs.pytest.org/en/stable/fixture.html#usefixtures ",
)
@@ -164,16 +164,17 @@ def async_warn_and_skip(nodeid: str) -> None:
msg += (
"You need to install a suitable plugin for your async framework, for example:\n"
)
+ msg += " - anyio\n"
msg += " - pytest-asyncio\n"
- msg += " - pytest-trio\n"
msg += " - pytest-tornasync\n"
+ msg += " - pytest-trio\n"
msg += " - pytest-twisted"
warnings.warn(PytestUnhandledCoroutineWarning(msg.format(nodeid)))
skip(msg="async def function and no async plugin installed (see warnings)")
@hookimpl(trylast=True)
-def pytest_pyfunc_call(pyfuncitem: "Function"):
+def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
@@ -185,45 +186,47 @@ def pytest_pyfunc_call(pyfuncitem: "Function"):
return True
-def pytest_collect_file(path, parent):
+def pytest_collect_file(
+ path: py.path.local, parent: nodes.Collector
+) -> Optional["Module"]:
ext = path.ext
if ext == ".py":
if not parent.session.isinitpath(path):
if not path_matches_patterns(
path, parent.config.getini("python_files") + ["__init__.py"]
):
- return
+ return None
ihook = parent.session.gethookproxy(path)
- return ihook.pytest_pycollect_makemodule(path=path, parent=parent)
+ module: Module = ihook.pytest_pycollect_makemodule(path=path, parent=parent)
+ return module
+ return None
-def path_matches_patterns(path, patterns):
- """Returns True if the given py.path.local matches one of the patterns in the list of globs given"""
+def path_matches_patterns(path: py.path.local, patterns: Iterable[str]) -> bool:
+ """Return whether path matches any of the patterns in the list of globs given."""
return any(path.fnmatch(pattern) for pattern in patterns)
-def pytest_pycollect_makemodule(path, parent):
+def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Module":
if path.basename == "__init__.py":
- return Package.from_parent(parent, fspath=path)
- return Module.from_parent(parent, fspath=path)
+ pkg: Package = Package.from_parent(parent, fspath=path)
+ return pkg
+ mod: Module = Module.from_parent(parent, fspath=path)
+ return mod
-@hookimpl(hookwrapper=True)
-def pytest_pycollect_makeitem(collector, name, obj):
- outcome = yield
- res = outcome.get_result()
- if res is not None:
- return
- # nothing was collected elsewhere, let's do it here
+@hookimpl(trylast=True)
+def pytest_pycollect_makeitem(collector: "PyCollector", name: str, obj: object):
+ # Nothing was collected elsewhere, let's do it here.
if safe_isclass(obj):
if collector.istestclass(obj, name):
- outcome.force_result(Class.from_parent(collector, name=name, obj=obj))
+ return Class.from_parent(collector, name=name, obj=obj)
elif collector.istestfunction(obj, name):
- # mock seems to store unbound methods (issue473), normalize it
+ # mock seems to store unbound methods (issue473), normalize it.
obj = getattr(obj, "__func__", obj)
# We need to try and unwrap the function if it's a functools.partial
# or a functools.wrapped.
- # We mustn't if it's been wrapped with mock.patch (python 2 only)
+ # We mustn't if it's been wrapped with mock.patch (python 2 only).
if not (inspect.isfunction(obj) or inspect.isfunction(get_real_func(obj))):
filename, lineno = getfslineno(obj)
warnings.warn_explicit(
@@ -244,15 +247,42 @@ def pytest_pycollect_makeitem(collector, name, obj):
res.warn(PytestCollectionWarning(reason))
else:
res = list(collector._genfunctions(name, obj))
- outcome.force_result(res)
+ return res
class PyobjMixin:
- module = pyobj_property("Module")
- cls = pyobj_property("Class")
- instance = pyobj_property("Instance")
_ALLOW_MARKERS = True
+ # Function and attributes that the mixin needs (for type-checking only).
+ if TYPE_CHECKING:
+ name: str = ""
+ parent: Optional[nodes.Node] = None
+ own_markers: List[Mark] = []
+
+ def getparent(self, cls: Type[nodes._NodeType]) -> Optional[nodes._NodeType]:
+ ...
+
+ def listchain(self) -> List[nodes.Node]:
+ ...
+
+ @property
+ def module(self):
+ """Python module object this node was collected from (can be None)."""
+ node = self.getparent(Module)
+ return node.obj if node is not None else None
+
+ @property
+ def cls(self):
+ """Python class object this node was collected from (can be None)."""
+ node = self.getparent(Class)
+ return node.obj if node is not None else None
+
+ @property
+ def instance(self):
+ """Python instance object this node was collected from (can be None)."""
+ node = self.getparent(Instance)
+ return node.obj if node is not None else None
+
@property
def obj(self):
"""Underlying Python object."""
@@ -270,11 +300,14 @@ class PyobjMixin:
self._obj = value
def _getobj(self):
- """Gets the underlying Python object. May be overwritten by subclasses."""
- return getattr(self.parent.obj, self.name)
-
- def getmodpath(self, stopatmodule=True, includemodule=False):
- """ return python path relative to the containing module. """
+ """Get the underlying Python object. May be overwritten by subclasses."""
+ # TODO: Improve the type of `parent` such that assert/ignore aren't needed.
+ assert self.parent is not None
+ obj = self.parent.obj # type: ignore[attr-defined]
+ return getattr(obj, self.name)
+
+ def getmodpath(self, stopatmodule: bool = True, includemodule: bool = False) -> str:
+ """Return Python path relative to the containing module."""
chain = self.listchain()
chain.reverse()
parts = []
@@ -301,7 +334,7 @@ class PyobjMixin:
file_path = sys.modules[obj.__module__].__file__
if file_path.endswith(".pyc"):
file_path = file_path[:-1]
- fspath = file_path # type: Union[py.path.local, str]
+ fspath: Union[py.path.local, str] = file_path
lineno = compat_co_firstlineno
else:
fspath, lineno = getfslineno(obj)
@@ -310,26 +343,46 @@ class PyobjMixin:
return fspath, lineno, modpath
+# As an optimization, these builtin attribute names are pre-ignored when
+# iterating over an object during collection -- the pytest_pycollect_makeitem
+# hook is not called for them.
+# fmt: off
+class _EmptyClass: pass # noqa: E701
+IGNORED_ATTRIBUTES = frozenset.union( # noqa: E305
+ frozenset(),
+ # Module.
+ dir(types.ModuleType("empty_module")),
+ # Some extra module attributes the above doesn't catch.
+ {"__builtins__", "__file__", "__cached__"},
+ # Class.
+ dir(_EmptyClass),
+ # Instance.
+ dir(_EmptyClass()),
+)
+del _EmptyClass
+# fmt: on
+
+
class PyCollector(PyobjMixin, nodes.Collector):
- def funcnamefilter(self, name):
+ def funcnamefilter(self, name: str) -> bool:
return self._matches_prefix_or_glob_option("python_functions", name)
- def isnosetest(self, obj):
- """ Look for the __test__ attribute, which is applied by the
- @nose.tools.istest decorator
+ def isnosetest(self, obj: object) -> bool:
+ """Look for the __test__ attribute, which is applied by the
+ @nose.tools.istest decorator.
"""
# We explicitly check for "is True" here to not mistakenly treat
# classes with a custom __getattr__ returning something truthy (like a
# function) as test classes.
return safe_getattr(obj, "__test__", False) is True
- def classnamefilter(self, name):
+ def classnamefilter(self, name: str) -> bool:
return self._matches_prefix_or_glob_option("python_classes", name)
- def istestfunction(self, obj, name):
+ def istestfunction(self, obj: object, name: str) -> bool:
if self.funcnamefilter(name) or self.isnosetest(obj):
if isinstance(obj, staticmethod):
- # static methods need to be unwrapped
+ # staticmethods need to be unwrapped.
obj = safe_getattr(obj, "__func__", False)
return (
safe_getattr(obj, "__call__", False)
@@ -338,48 +391,54 @@ class PyCollector(PyobjMixin, nodes.Collector):
else:
return False
- def istestclass(self, obj, name):
+ def istestclass(self, obj: object, name: str) -> bool:
return self.classnamefilter(name) or self.isnosetest(obj)
- def _matches_prefix_or_glob_option(self, option_name, name):
- """
- checks if the given name matches the prefix or glob-pattern defined
- in ini configuration.
- """
+ def _matches_prefix_or_glob_option(self, option_name: str, name: str) -> bool:
+ """Check if the given name matches the prefix or glob-pattern defined
+ in ini configuration."""
for option in self.config.getini(option_name):
if name.startswith(option):
return True
- # check that name looks like a glob-string before calling fnmatch
+ # Check that name looks like a glob-string before calling fnmatch
# because this is called for every name in each collected module,
- # and fnmatch is somewhat expensive to call
+ # and fnmatch is somewhat expensive to call.
elif ("*" in option or "?" in option or "[" in option) and fnmatch.fnmatch(
name, option
):
return True
return False
- def collect(self):
+ def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
if not getattr(self.obj, "__test__", True):
return []
# NB. we avoid random getattrs and peek in the __dict__ instead
# (XXX originally introduced from a PyPy need, still true?)
dicts = [getattr(self.obj, "__dict__", {})]
- for basecls in inspect.getmro(self.obj.__class__):
+ for basecls in self.obj.__class__.__mro__:
dicts.append(basecls.__dict__)
- seen = {}
- values = []
+ seen: Set[str] = set()
+ values: List[Union[nodes.Item, nodes.Collector]] = []
+ ihook = self.ihook
for dic in dicts:
+ # Note: seems like the dict can change during iteration -
+ # be careful not to remove the list() without consideration.
for name, obj in list(dic.items()):
+ if name in IGNORED_ATTRIBUTES:
+ continue
if name in seen:
continue
- seen[name] = True
- res = self._makeitem(name, obj)
+ seen.add(name)
+ res = ihook.pytest_pycollect_makeitem(
+ collector=self, name=name, obj=obj
+ )
if res is None:
continue
- if not isinstance(res, list):
- res = [res]
- values.extend(res)
+ elif isinstance(res, list):
+ values.extend(res)
+ else:
+ values.append(res)
def sort_key(item):
fspath, lineno, _ = item.reportinfo()
@@ -388,12 +447,10 @@ class PyCollector(PyobjMixin, nodes.Collector):
values.sort(key=sort_key)
return values
- def _makeitem(self, name, obj):
- # assert self.ihook.fspath == self.fspath, self
- return self.ihook.pytest_pycollect_makeitem(collector=self, name=name, obj=obj)
-
- def _genfunctions(self, name, funcobj):
- module = self.getparent(Module).obj
+ def _genfunctions(self, name: str, funcobj) -> Iterator["Function"]:
+ modulecol = self.getparent(Module)
+ assert modulecol is not None
+ module = modulecol.obj
clscol = self.getparent(Class)
cls = clscol and clscol.obj or None
fm = self.session._fixturemanager
@@ -407,7 +464,7 @@ class PyCollector(PyobjMixin, nodes.Collector):
methods = []
if hasattr(module, "pytest_generate_tests"):
methods.append(module.pytest_generate_tests)
- if hasattr(cls, "pytest_generate_tests"):
+ if cls is not None and hasattr(cls, "pytest_generate_tests"):
methods.append(cls().pytest_generate_tests)
self.ihook.pytest_generate_tests.call_extra(methods, dict(metafunc=metafunc))
@@ -415,16 +472,16 @@ class PyCollector(PyobjMixin, nodes.Collector):
if not metafunc._calls:
yield Function.from_parent(self, name=name, fixtureinfo=fixtureinfo)
else:
- # add funcargs() as fixturedefs to fixtureinfo.arg2fixturedefs
+ # Add funcargs() as fixturedefs to fixtureinfo.arg2fixturedefs.
fixtures.add_funcarg_pseudo_fixture_def(self, metafunc, fm)
- # add_funcarg_pseudo_fixture_def may have shadowed some fixtures
+ # Add_funcarg_pseudo_fixture_def may have shadowed some fixtures
# with direct parametrization, so make sure we update what the
# function really needs.
fixtureinfo.prune_dependency_tree()
for callspec in metafunc._calls:
- subname = "{}[{}]".format(name, callspec.id)
+ subname = f"{name}[{callspec.id}]"
yield Function.from_parent(
self,
name=subname,
@@ -437,19 +494,19 @@ class PyCollector(PyobjMixin, nodes.Collector):
class Module(nodes.File, PyCollector):
- """ Collector for test classes and functions. """
+ """Collector for test classes and functions."""
def _getobj(self):
return self._importtestmodule()
- def collect(self):
+ def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
self._inject_setup_module_fixture()
self._inject_setup_function_fixture()
self.session._fixturemanager.parsefactories(self)
return super().collect()
- def _inject_setup_module_fixture(self):
- """Injects a hidden autouse, module scoped fixture into the collected module object
+ def _inject_setup_module_fixture(self) -> None:
+ """Inject a hidden autouse, module scoped fixture into the collected module object
that invokes setUpModule/tearDownModule if either or both are available.
Using a fixture to invoke this methods ensures we play nicely and unsurprisingly with
@@ -465,8 +522,13 @@ class Module(nodes.File, PyCollector):
if setup_module is None and teardown_module is None:
return
- @fixtures.fixture(autouse=True, scope="module")
- def xunit_setup_module_fixture(request):
+ @fixtures.fixture(
+ autouse=True,
+ scope="module",
+ # Use a unique name to speed up lookup.
+ name=f"xunit_setup_module_fixture_{self.obj.__name__}",
+ )
+ def xunit_setup_module_fixture(request) -> Generator[None, None, None]:
if setup_module is not None:
_call_with_optional_argument(setup_module, request.module)
yield
@@ -475,8 +537,8 @@ class Module(nodes.File, PyCollector):
self.obj.__pytest_setup_module = xunit_setup_module_fixture
- def _inject_setup_function_fixture(self):
- """Injects a hidden autouse, function scoped fixture into the collected module object
+ def _inject_setup_function_fixture(self) -> None:
+ """Inject a hidden autouse, function scoped fixture into the collected module object
that invokes setup_function/teardown_function if either or both are available.
Using a fixture to invoke this methods ensures we play nicely and unsurprisingly with
@@ -489,8 +551,13 @@ class Module(nodes.File, PyCollector):
if setup_function is None and teardown_function is None:
return
- @fixtures.fixture(autouse=True, scope="function")
- def xunit_setup_function_fixture(request):
+ @fixtures.fixture(
+ autouse=True,
+ scope="function",
+ # Use a unique name to speed up lookup.
+ name=f"xunit_setup_function_fixture_{self.obj.__name__}",
+ )
+ def xunit_setup_function_fixture(request) -> Generator[None, None, None]:
if request.instance is not None:
# in this case we are bound to an instance, so we need to let
# setup_method handle this
@@ -505,13 +572,15 @@ class Module(nodes.File, PyCollector):
self.obj.__pytest_setup_function = xunit_setup_function_fixture
def _importtestmodule(self):
- # we assume we are only called once per module
+ # We assume we are only called once per module.
importmode = self.config.getoption("--import-mode")
try:
- mod = self.fspath.pyimport(ensuresyspath=importmode)
- except SyntaxError:
- raise self.CollectError(ExceptionInfo.from_current().getrepr(style="short"))
- except self.fspath.ImportMismatchError as e:
+ mod = import_path(self.fspath, mode=importmode)
+ except SyntaxError as e:
+ raise self.CollectError(
+ ExceptionInfo.from_current().getrepr(style="short")
+ ) from e
+ except ImportPathMismatchError as e:
raise self.CollectError(
"import file mismatch:\n"
"imported module %r has this __file__ attribute:\n"
@@ -520,8 +589,8 @@ class Module(nodes.File, PyCollector):
" %s\n"
"HINT: remove __pycache__ / .pyc files and/or use a "
"unique basename for your test file modules" % e.args
- )
- except ImportError:
+ ) from e
+ except ImportError as e:
exc_info = ExceptionInfo.from_current()
if self.config.getoption("verbose") < 2:
exc_info.traceback = exc_info.traceback.filter(filter_traceback)
@@ -536,8 +605,8 @@ class Module(nodes.File, PyCollector):
"Hint: make sure your test modules/packages have valid Python names.\n"
"Traceback:\n"
"{traceback}".format(fspath=self.fspath, traceback=formatted_tb)
- )
- except _pytest.runner.Skipped as e:
+ ) from e
+ except skip.Exception as e:
if e.allow_module_level:
raise
raise self.CollectError(
@@ -545,7 +614,7 @@ class Module(nodes.File, PyCollector):
"To decorate a test function, use the @pytest.mark.skip "
"or @pytest.mark.skipif decorators instead, and to skip a "
"module use `pytestmark = pytest.mark.{skip,skipif}."
- )
+ ) from e
self.config.pluginmanager.consider_module(mod)
return mod
@@ -560,18 +629,17 @@ class Package(Module):
session=None,
nodeid=None,
) -> None:
- # NOTE: could be just the following, but kept as-is for compat.
+ # NOTE: Could be just the following, but kept as-is for compat.
# nodes.FSCollector.__init__(self, fspath, parent=parent)
session = parent.session
nodes.FSCollector.__init__(
self, fspath, parent=parent, config=config, session=session, nodeid=nodeid
)
+ self.name = os.path.basename(str(fspath.dirname))
- self.name = fspath.dirname
-
- def setup(self):
- # not using fixtures to call setup_module here because autouse fixtures
- # from packages are not called automatically (#4085)
+ def setup(self) -> None:
+ # Not using fixtures to call setup_module here because autouse fixtures
+ # from packages are not called automatically (#4085).
setup_module = _get_first_non_fixture_func(
self.obj, ("setUpModule", "setup_module")
)
@@ -586,45 +654,84 @@ class Package(Module):
self.addfinalizer(func)
def gethookproxy(self, fspath: py.path.local):
- return super()._gethookproxy(fspath)
+ warnings.warn(FSCOLLECTOR_GETHOOKPROXY_ISINITPATH, stacklevel=2)
+ return self.session.gethookproxy(fspath)
+
+ def isinitpath(self, path: py.path.local) -> bool:
+ warnings.warn(FSCOLLECTOR_GETHOOKPROXY_ISINITPATH, stacklevel=2)
+ return self.session.isinitpath(path)
+
+ def _recurse(self, direntry: "os.DirEntry[str]") -> bool:
+ if direntry.name == "__pycache__":
+ return False
+ path = py.path.local(direntry.path)
+ ihook = self.session.gethookproxy(path.dirpath())
+ if ihook.pytest_ignore_collect(path=path, config=self.config):
+ return False
+ norecursepatterns = self.config.getini("norecursedirs")
+ if any(path.check(fnmatch=pat) for pat in norecursepatterns):
+ return False
+ return True
+
+ def _collectfile(
+ self, path: py.path.local, handle_dupes: bool = True
+ ) -> Sequence[nodes.Collector]:
+ assert (
+ path.isfile()
+ ), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format(
+ path, path.isdir(), path.exists(), path.islink()
+ )
+ ihook = self.session.gethookproxy(path)
+ if not self.session.isinitpath(path):
+ if ihook.pytest_ignore_collect(path=path, config=self.config):
+ return ()
+
+ if handle_dupes:
+ keepduplicates = self.config.getoption("keepduplicates")
+ if not keepduplicates:
+ duplicate_paths = self.config.pluginmanager._duplicatepaths
+ if path in duplicate_paths:
+ return ()
+ else:
+ duplicate_paths.add(path)
- def isinitpath(self, path):
- return path in self.session._initialpaths
+ return ihook.pytest_collect_file(path=path, parent=self) # type: ignore[no-any-return]
- def collect(self):
+ def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
this_path = self.fspath.dirpath()
init_module = this_path.join("__init__.py")
if init_module.check(file=1) and path_matches_patterns(
init_module, self.config.getini("python_files")
):
yield Module.from_parent(self, fspath=init_module)
- pkg_prefixes = set()
- for path in this_path.visit(rec=self._recurse, bf=True, sort=True):
+ pkg_prefixes: Set[py.path.local] = set()
+ for direntry in visit(str(this_path), recurse=self._recurse):
+ path = py.path.local(direntry.path)
+
# We will visit our own __init__.py file, in which case we skip it.
- is_file = path.isfile()
- if is_file:
- if path.basename == "__init__.py" and path.dirpath() == this_path:
+ if direntry.is_file():
+ if direntry.name == "__init__.py" and path.dirpath() == this_path:
continue
- parts_ = parts(path.strpath)
+ parts_ = parts(direntry.path)
if any(
- pkg_prefix in parts_ and pkg_prefix.join("__init__.py") != path
+ str(pkg_prefix) in parts_ and pkg_prefix.join("__init__.py") != path
for pkg_prefix in pkg_prefixes
):
continue
- if is_file:
+ if direntry.is_file():
yield from self._collectfile(path)
- elif not path.isdir():
+ elif not direntry.is_dir():
# Broken symlink or invalid/missing file.
continue
elif path.join("__init__.py").check(file=1):
pkg_prefixes.add(path)
-def _call_with_optional_argument(func, arg):
+def _call_with_optional_argument(func, arg) -> None:
"""Call the given function with the given argument if func accepts one argument, otherwise
- calls func without arguments"""
+ calls func without arguments."""
arg_count = func.__code__.co_argcount
if inspect.ismethod(func):
arg_count -= 1
@@ -634,11 +741,9 @@ def _call_with_optional_argument(func, arg):
func()
-def _get_first_non_fixture_func(obj, names):
+def _get_first_non_fixture_func(obj: object, names: Iterable[str]):
"""Return the attribute from the given object to be used as a setup/teardown
- xunit-style function, but only if not marked as a fixture to
- avoid calling it twice.
- """
+ xunit-style function, but only if not marked as a fixture to avoid calling it twice."""
for name in names:
meth = getattr(obj, name, None)
if meth is not None and fixtures.getfixturemarker(meth) is None:
@@ -646,19 +751,18 @@ def _get_first_non_fixture_func(obj, names):
class Class(PyCollector):
- """ Collector for test methods. """
+ """Collector for test methods."""
@classmethod
def from_parent(cls, parent, *, name, obj=None):
- """
- The public constructor
- """
+ """The public constructor."""
return super().from_parent(name=name, parent=parent)
- def collect(self):
+ def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
if not safe_getattr(self.obj, "__test__", True):
return []
if hasinit(self.obj):
+ assert self.parent is not None
self.warn(
PytestCollectionWarning(
"cannot collect test class %r because it has a "
@@ -668,6 +772,7 @@ class Class(PyCollector):
)
return []
elif hasnew(self.obj):
+ assert self.parent is not None
self.warn(
PytestCollectionWarning(
"cannot collect test class %r because it has a "
@@ -682,8 +787,8 @@ class Class(PyCollector):
return [Instance.from_parent(self, name="()")]
- def _inject_setup_class_fixture(self):
- """Injects a hidden autouse, class scoped fixture into the collected class object
+ def _inject_setup_class_fixture(self) -> None:
+ """Inject a hidden autouse, class scoped fixture into the collected class object
that invokes setup_class/teardown_class if either or both are available.
Using a fixture to invoke this methods ensures we play nicely and unsurprisingly with
@@ -694,8 +799,13 @@ class Class(PyCollector):
if setup_class is None and teardown_class is None:
return
- @fixtures.fixture(autouse=True, scope="class")
- def xunit_setup_class_fixture(cls):
+ @fixtures.fixture(
+ autouse=True,
+ scope="class",
+ # Use a unique name to speed up lookup.
+ name=f"xunit_setup_class_fixture_{self.obj.__qualname__}",
+ )
+ def xunit_setup_class_fixture(cls) -> Generator[None, None, None]:
if setup_class is not None:
func = getimfunc(setup_class)
_call_with_optional_argument(func, self.obj)
@@ -706,8 +816,8 @@ class Class(PyCollector):
self.obj.__pytest_setup_class = xunit_setup_class_fixture
- def _inject_setup_method_fixture(self):
- """Injects a hidden autouse, function scoped fixture into the collected class object
+ def _inject_setup_method_fixture(self) -> None:
+ """Inject a hidden autouse, function scoped fixture into the collected class object
that invokes setup_method/teardown_method if either or both are available.
Using a fixture to invoke this methods ensures we play nicely and unsurprisingly with
@@ -718,8 +828,13 @@ class Class(PyCollector):
if setup_method is None and teardown_method is None:
return
- @fixtures.fixture(autouse=True, scope="function")
- def xunit_setup_method_fixture(self, request):
+ @fixtures.fixture(
+ autouse=True,
+ scope="function",
+ # Use a unique name to speed up lookup.
+ name=f"xunit_setup_method_fixture_{self.obj.__qualname__}",
+ )
+ def xunit_setup_method_fixture(self, request) -> Generator[None, None, None]:
method = request.function
if setup_method is not None:
func = getattr(self, "setup_method")
@@ -734,14 +849,17 @@ class Class(PyCollector):
class Instance(PyCollector):
_ALLOW_MARKERS = False # hack, destroy later
- # instances share the object with their parents in a way
+ # Instances share the object with their parents in a way
# that duplicates markers instances if not taken out
- # can be removed at node structure reorganization time
+ # can be removed at node structure reorganization time.
def _getobj(self):
- return self.parent.obj()
+ # TODO: Improve the type of `parent` such that assert/ignore aren't needed.
+ assert self.parent is not None
+ obj = self.parent.obj # type: ignore[attr-defined]
+ return obj()
- def collect(self):
+ def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
self.session._fixturemanager.parsefactories(self)
return super().collect()
@@ -750,29 +868,33 @@ class Instance(PyCollector):
return self.obj
-def hasinit(obj):
- init = getattr(obj, "__init__", None)
+def hasinit(obj: object) -> bool:
+ init: object = getattr(obj, "__init__", None)
if init:
return init != object.__init__
+ return False
-def hasnew(obj):
- new = getattr(obj, "__new__", None)
+def hasnew(obj: object) -> bool:
+ new: object = getattr(obj, "__new__", None)
if new:
return new != object.__new__
+ return False
+@final
class CallSpec2:
- def __init__(self, metafunc):
+ def __init__(self, metafunc: "Metafunc") -> None:
self.metafunc = metafunc
- self.funcargs = {}
- self._idlist = []
- self.params = {}
- self._arg2scopenum = {} # used for sorting parametrized resources
- self.marks = []
- self.indices = {}
-
- def copy(self):
+ self.funcargs: Dict[str, object] = {}
+ self._idlist: List[str] = []
+ self.params: Dict[str, object] = {}
+ # Used for sorting parametrized resources.
+ self._arg2scopenum: Dict[str, int] = {}
+ self.marks: List[Mark] = []
+ self.indices: Dict[str, int] = {}
+
+ def copy(self) -> "CallSpec2":
cs = CallSpec2(self.metafunc)
cs.funcargs.update(self.funcargs)
cs.params.update(self.params)
@@ -782,34 +904,49 @@ class CallSpec2:
cs._idlist = list(self._idlist)
return cs
- def _checkargnotcontained(self, arg):
+ def _checkargnotcontained(self, arg: str) -> None:
if arg in self.params or arg in self.funcargs:
- raise ValueError("duplicate {!r}".format(arg))
+ raise ValueError(f"duplicate {arg!r}")
- def getparam(self, name):
+ def getparam(self, name: str) -> object:
try:
return self.params[name]
- except KeyError:
- raise ValueError(name)
+ except KeyError as e:
+ raise ValueError(name) from e
@property
- def id(self):
+ def id(self) -> str:
return "-".join(map(str, self._idlist))
- def setmulti2(self, valtypes, argnames, valset, id, marks, scopenum, param_index):
+ def setmulti2(
+ self,
+ valtypes: Mapping[str, "Literal['params', 'funcargs']"],
+ argnames: Sequence[str],
+ valset: Iterable[object],
+ id: str,
+ marks: Iterable[Union[Mark, MarkDecorator]],
+ scopenum: int,
+ param_index: int,
+ ) -> None:
for arg, val in zip(argnames, valset):
self._checkargnotcontained(arg)
valtype_for_arg = valtypes[arg]
- getattr(self, valtype_for_arg)[arg] = val
+ if valtype_for_arg == "params":
+ self.params[arg] = val
+ elif valtype_for_arg == "funcargs":
+ self.funcargs[arg] = val
+ else: # pragma: no cover
+ assert False, f"Unhandled valtype for arg: {valtype_for_arg}"
self.indices[arg] = param_index
self._arg2scopenum[arg] = scopenum
self._idlist.append(id)
self.marks.extend(normalize_mark_list(marks))
+@final
class Metafunc:
- """
- Metafunc objects are passed to the :func:`pytest_generate_tests <_pytest.hookspec.pytest_generate_tests>` hook.
+ """Objects passed to the :func:`pytest_generate_tests <_pytest.hookspec.pytest_generate_tests>` hook.
+
They help to inspect a test function and to generate tests according to
test configuration or values specified in the class or module where a
test function is defined.
@@ -823,71 +960,71 @@ class Metafunc:
cls=None,
module=None,
) -> None:
+ #: Access to the underlying :class:`_pytest.python.FunctionDefinition`.
self.definition = definition
- #: access to the :class:`_pytest.config.Config` object for the test session
+ #: Access to the :class:`_pytest.config.Config` object for the test session.
self.config = config
- #: the module object where the test function is defined in.
+ #: The module object where the test function is defined in.
self.module = module
- #: underlying python test function
+ #: Underlying Python test function.
self.function = definition.obj
- #: set of fixture names required by the test function
+ #: Set of fixture names required by the test function.
self.fixturenames = fixtureinfo.names_closure
- #: class object where the test function is defined in or ``None``.
+ #: Class object where the test function is defined in or ``None``.
self.cls = cls
- self._calls = [] # type: List[CallSpec2]
+ self._calls: List[CallSpec2] = []
self._arg2fixturedefs = fixtureinfo.name2fixturedefs
- @property
- def funcargnames(self):
- """ alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
- warnings.warn(FUNCARGNAMES, stacklevel=2)
- return self.fixturenames
-
def parametrize(
self,
argnames: Union[str, List[str], Tuple[str, ...]],
- argvalues: Iterable[Union[ParameterSet, typing.Sequence[object], object]],
- indirect: Union[bool, typing.Sequence[str]] = False,
+ argvalues: Iterable[Union[ParameterSet, Sequence[object], object]],
+ indirect: Union[bool, Sequence[str]] = False,
ids: Optional[
Union[
Iterable[Union[None, str, float, int, bool]],
- Callable[[object], Optional[object]],
+ Callable[[Any], Optional[object]],
]
] = None,
- scope: "Optional[str]" = None,
+ scope: "Optional[_Scope]" = None,
*,
- _param_mark: Optional[Mark] = None
+ _param_mark: Optional[Mark] = None,
) -> None:
- """ Add new invocations to the underlying test function using the list
+ """Add new invocations to the underlying test function using the list
of argvalues for the given argnames. Parametrization is performed
during the collection phase. If you need to setup expensive resources
see about setting indirect to do it rather at test setup time.
- :arg argnames: a comma-separated string denoting one or more argument
- names, or a list/tuple of argument strings.
+ :param argnames:
+ A comma-separated string denoting one or more argument names, or
+ a list/tuple of argument strings.
- :arg argvalues: The list of argvalues determines how often a
- test is invoked with different argument values. If only one
- argname was specified argvalues is a list of values. If N
- argnames were specified, argvalues must be a list of N-tuples,
- where each tuple-element specifies a value for its respective
- argname.
+ :param argvalues:
+ The list of argvalues determines how often a test is invoked with
+ different argument values.
- :arg indirect: The list of argnames or boolean. A list of arguments'
- names (subset of argnames). If True the list contains all names from
- the argnames. Each argvalue corresponding to an argname in this list will
+ If only one argname was specified argvalues is a list of values.
+ If N argnames were specified, argvalues must be a list of
+ N-tuples, where each tuple-element specifies a value for its
+ respective argname.
+
+ :param indirect:
+ A list of arguments' names (subset of argnames) or a boolean.
+ If True the list contains all names from the argnames. Each
+ argvalue corresponding to an argname in this list will
be passed as request.param to its respective argname fixture
function so that it can perform more expensive setups during the
setup phase of a test rather than at collection time.
- :arg ids: sequence of (or generator for) ids for ``argvalues``,
- or a callable to return part of the id for each argvalue.
+ :param ids:
+ Sequence of (or generator for) ids for ``argvalues``,
+ or a callable to return part of the id for each argvalue.
With sequences (and generators like ``itertools.count()``) the
returned ids should be of type ``string``, ``int``, ``float``,
@@ -905,7 +1042,8 @@ class Metafunc:
If no ids are provided they will be generated automatically from
the argvalues.
- :arg scope: if specified it denotes the scope of the parameters.
+ :param scope:
+ If specified it denotes the scope of the parameters.
The scope is used for grouping tests by parameter instances.
It will also override any fixture-function defined scope, allowing
to set a dynamic scope using test context or configuration.
@@ -917,7 +1055,7 @@ class Metafunc:
argvalues,
self.function,
self.config,
- function_definition=self.definition,
+ nodeid=self.definition.nodeid,
)
del argvalues
@@ -940,19 +1078,21 @@ class Metafunc:
if generated_ids is not None:
ids = generated_ids
- ids = self._resolve_arg_ids(argnames, ids, parameters, item=self.definition)
+ ids = self._resolve_arg_ids(
+ argnames, ids, parameters, nodeid=self.definition.nodeid
+ )
# Store used (possibly generated) ids with parametrize Marks.
if _param_mark and _param_mark._param_ids_from and generated_ids is None:
object.__setattr__(_param_mark._param_ids_from, "_param_ids_generated", ids)
scopenum = scope2index(
- scope, descr="parametrize() call in {}".format(self.function.__name__)
+ scope, descr=f"parametrize() call in {self.function.__name__}"
)
- # create the new calls: if we are parametrize() multiple times (by applying the decorator
+ # Create the new calls: if we are parametrize() multiple times (by applying the decorator
# more than once) then we accumulate those calls generating the cartesian product
- # of all calls
+ # of all calls.
newcalls = []
for callspec in self._calls or [CallSpec2(self)]:
for param_index, (param_id, param_set) in enumerate(zip(ids, parameters)):
@@ -971,25 +1111,25 @@ class Metafunc:
def _resolve_arg_ids(
self,
- argnames: typing.Sequence[str],
+ argnames: Sequence[str],
ids: Optional[
Union[
Iterable[Union[None, str, float, int, bool]],
- Callable[[object], Optional[object]],
+ Callable[[Any], Optional[object]],
]
],
- parameters: typing.Sequence[ParameterSet],
- item,
+ parameters: Sequence[ParameterSet],
+ nodeid: str,
) -> List[str]:
- """Resolves the actual ids for the given argnames, based on the ``ids`` parameter given
+ """Resolve the actual ids for the given argnames, based on the ``ids`` parameter given
to ``parametrize``.
- :param List[str] argnames: list of argument names passed to ``parametrize()``.
- :param ids: the ids parameter of the parametrized call (see docs).
- :param List[ParameterSet] parameters: the list of parameter values, same size as ``argnames``.
- :param Item item: the item that generated this parametrized call.
+ :param List[str] argnames: List of argument names passed to ``parametrize()``.
+ :param ids: The ids parameter of the parametrized call (see docs).
+ :param List[ParameterSet] parameters: The list of parameter values, same size as ``argnames``.
+ :param str str: The nodeid of the item that generated this parametrized call.
:rtype: List[str]
- :return: the list of ids for each argname given
+ :returns: The list of ids for each argname given.
"""
if ids is None:
idfn = None
@@ -1000,21 +1140,21 @@ class Metafunc:
else:
idfn = None
ids_ = self._validate_ids(ids, parameters, self.function.__name__)
- return idmaker(argnames, parameters, idfn, ids_, self.config, item=item)
+ return idmaker(argnames, parameters, idfn, ids_, self.config, nodeid=nodeid)
def _validate_ids(
self,
ids: Iterable[Union[None, str, float, int, bool]],
- parameters: typing.Sequence[ParameterSet],
+ parameters: Sequence[ParameterSet],
func_name: str,
) -> List[Union[None, str]]:
try:
- num_ids = len(ids) # type: ignore[arg-type] # noqa: F821
+ num_ids = len(ids) # type: ignore[arg-type]
except TypeError:
try:
iter(ids)
- except TypeError:
- raise TypeError("ids must be a callable or an iterable")
+ except TypeError as e:
+ raise TypeError("ids must be a callable or an iterable") from e
num_ids = len(parameters)
# num_ids == 0 is a special case: https://github.com/pytest-dev/pytest/issues/1849
@@ -1029,7 +1169,10 @@ class Metafunc:
elif isinstance(id_value, (float, int, bool)):
new_ids.append(str(id_value))
else:
- msg = "In {}: ids must be list of string/float/int/bool, found: {} (type: {!r}) at index {}"
+ msg = ( # type: ignore[unreachable]
+ "In {}: ids must be list of string/float/int/bool, "
+ "found: {} (type: {!r}) at index {}"
+ )
fail(
msg.format(func_name, saferepr(id_value), type(id_value), idx),
pytrace=False,
@@ -1037,22 +1180,23 @@ class Metafunc:
return new_ids
def _resolve_arg_value_types(
- self,
- argnames: typing.Sequence[str],
- indirect: Union[bool, typing.Sequence[str]],
- ) -> Dict[str, str]:
- """Resolves if each parametrized argument must be considered a parameter to a fixture or a "funcarg"
- to the function, based on the ``indirect`` parameter of the parametrized() call.
-
- :param List[str] argnames: list of argument names passed to ``parametrize()``.
- :param indirect: same ``indirect`` parameter of ``parametrize()``.
+ self, argnames: Sequence[str], indirect: Union[bool, Sequence[str]],
+ ) -> Dict[str, "Literal['params', 'funcargs']"]:
+ """Resolve if each parametrized argument must be considered a
+ parameter to a fixture or a "funcarg" to the function, based on the
+ ``indirect`` parameter of the parametrized() call.
+
+ :param List[str] argnames: List of argument names passed to ``parametrize()``.
+ :param indirect: Same as the ``indirect`` parameter of ``parametrize()``.
:rtype: Dict[str, str]
A dict mapping each arg name to either:
* "params" if the argname should be the parameter of a fixture of the same name.
* "funcargs" if the argname should be a parameter to the parametrized test function.
"""
if isinstance(indirect, bool):
- valtypes = dict.fromkeys(argnames, "params" if indirect else "funcargs")
+ valtypes: Dict[str, Literal["params", "funcargs"]] = dict.fromkeys(
+ argnames, "params" if indirect else "funcargs"
+ )
elif isinstance(indirect, Sequence):
valtypes = dict.fromkeys(argnames, "funcargs")
for arg in indirect:
@@ -1074,16 +1218,13 @@ class Metafunc:
return valtypes
def _validate_if_using_arg_names(
- self,
- argnames: typing.Sequence[str],
- indirect: Union[bool, typing.Sequence[str]],
+ self, argnames: Sequence[str], indirect: Union[bool, Sequence[str]],
) -> None:
- """
- Check if all argnames are being used, by default values, or directly/indirectly.
+ """Check if all argnames are being used, by default values, or directly/indirectly.
- :param List[str] argnames: list of argument names passed to ``parametrize()``.
- :param indirect: same ``indirect`` parameter of ``parametrize()``.
- :raise ValueError: if validation fails.
+ :param List[str] argnames: List of argument names passed to ``parametrize()``.
+ :param indirect: Same as the ``indirect`` parameter of ``parametrize()``.
+ :raises ValueError: If validation fails.
"""
default_arg_names = set(get_default_arg_names(self.function))
func_name = self.function.__name__
@@ -1102,12 +1243,16 @@ class Metafunc:
else:
name = "fixture" if indirect else "argument"
fail(
- "In {}: function uses no {} '{}'".format(func_name, name, arg),
+ f"In {func_name}: function uses no {name} '{arg}'",
pytrace=False,
)
-def _find_parametrized_scope(argnames, arg2fixturedefs, indirect):
+def _find_parametrized_scope(
+ argnames: Sequence[str],
+ arg2fixturedefs: Mapping[str, Sequence[fixtures.FixtureDef[object]]],
+ indirect: Union[bool, Sequence[str]],
+) -> "fixtures._Scope":
"""Find the most appropriate scope for a parametrized call based on its arguments.
When there's at least one direct argument, always use "function" scope.
@@ -1117,9 +1262,7 @@ def _find_parametrized_scope(argnames, arg2fixturedefs, indirect):
Related to issue #1832, based on code posted by @Kingdread.
"""
- from _pytest.fixtures import scopes
-
- if isinstance(indirect, (list, tuple)):
+ if isinstance(indirect, Sequence):
all_arguments_are_fixtures = len(indirect) == len(argnames)
else:
all_arguments_are_fixtures = bool(indirect)
@@ -1132,8 +1275,8 @@ def _find_parametrized_scope(argnames, arg2fixturedefs, indirect):
if name in argnames
]
if used_scopes:
- # Takes the most narrow scope from used fixtures
- for scope in reversed(scopes):
+ # Takes the most narrow scope from used fixtures.
+ for scope in reversed(fixtures.scopes):
if scope in used_scopes:
return scope
@@ -1157,8 +1300,8 @@ def _idval(
val: object,
argname: str,
idx: int,
- idfn: Optional[Callable[[object], Optional[object]]],
- item,
+ idfn: Optional[Callable[[Any], Optional[object]]],
+ nodeid: Optional[str],
config: Optional[Config],
) -> str:
if idfn:
@@ -1167,13 +1310,14 @@ def _idval(
if generated_id is not None:
val = generated_id
except Exception as e:
- msg = "{}: error raised while trying to determine id of parameter '{}' at position {}"
- msg = msg.format(item.nodeid, argname, idx)
+ prefix = f"{nodeid}: " if nodeid is not None else ""
+ msg = "error raised while trying to determine id of parameter '{}' at position {}"
+ msg = prefix + msg.format(argname, idx)
raise ValueError(msg) from e
elif config:
- hook_id = config.hook.pytest_make_parametrize_id(
+ hook_id: Optional[str] = config.hook.pytest_make_parametrize_id(
config=config, val=val, argname=argname
- ) # type: Optional[str]
+ )
if hook_id:
return hook_id
@@ -1183,11 +1327,14 @@ def _idval(
return str(val)
elif isinstance(val, REGEX_TYPE):
return ascii_escaped(val.pattern)
+ elif val is NOTSET:
+ # Fallback to default. Note that NOTSET is an enum.Enum.
+ pass
elif isinstance(val, enum.Enum):
return str(val)
elif isinstance(getattr(val, "__name__", None), str):
- # name of a class, function, module, etc.
- name = getattr(val, "__name__") # type: str
+ # Name of a class, function, module, etc.
+ name: str = getattr(val, "__name__")
return name
return str(argname) + str(idx)
@@ -1223,17 +1370,17 @@ def _idvalset(
idx: int,
parameterset: ParameterSet,
argnames: Iterable[str],
- idfn: Optional[Callable[[object], Optional[object]]],
+ idfn: Optional[Callable[[Any], Optional[object]]],
ids: Optional[List[Union[None, str]]],
- item,
+ nodeid: Optional[str],
config: Optional[Config],
-):
+) -> str:
if parameterset.id is not None:
return parameterset.id
id = None if ids is None or idx >= len(ids) else ids[idx]
if id is None:
this_id = [
- _idval(val, argname, idx, idfn, item=item, config=config)
+ _idval(val, argname, idx, idfn, nodeid=nodeid, config=config)
for val, argname in zip(parameterset.values, argnames)
]
return "-".join(this_id)
@@ -1244,13 +1391,15 @@ def _idvalset(
def idmaker(
argnames: Iterable[str],
parametersets: Iterable[ParameterSet],
- idfn: Optional[Callable[[object], Optional[object]]] = None,
+ idfn: Optional[Callable[[Any], Optional[object]]] = None,
ids: Optional[List[Union[None, str]]] = None,
config: Optional[Config] = None,
- item=None,
+ nodeid: Optional[str] = None,
) -> List[str]:
resolved_ids = [
- _idvalset(valindex, parameterset, argnames, idfn, ids, config=config, item=item)
+ _idvalset(
+ valindex, parameterset, argnames, idfn, ids, config=config, nodeid=nodeid
+ )
for valindex, parameterset in enumerate(parametersets)
]
@@ -1258,13 +1407,13 @@ def idmaker(
unique_ids = set(resolved_ids)
if len(unique_ids) != len(resolved_ids):
- # Record the number of occurrences of each test ID
+ # Record the number of occurrences of each test ID.
test_id_counts = Counter(resolved_ids)
- # Map the test ID to its next suffix
- test_id_suffixes = defaultdict(int) # type: Dict[str, int]
+ # Map the test ID to its next suffix.
+ test_id_suffixes: Dict[str, int] = defaultdict(int)
- # Suffix non-unique IDs to make them unique
+ # Suffix non-unique IDs to make them unique.
for index, test_id in enumerate(resolved_ids):
if test_id_counts[test_id] > 1:
resolved_ids[index] = "{}{}".format(test_id, test_id_suffixes[test_id])
@@ -1279,7 +1428,7 @@ def show_fixtures_per_test(config):
return wrap_session(config, _show_fixtures_per_test)
-def _show_fixtures_per_test(config, session):
+def _show_fixtures_per_test(config: Config, session: Session) -> None:
import _pytest.config
session.perform_collect()
@@ -1288,16 +1437,16 @@ def _show_fixtures_per_test(config, session):
verbose = config.getvalue("verbose")
def get_best_relpath(func):
- loc = getlocation(func, curdir)
- return curdir.bestrelpath(loc)
+ loc = getlocation(func, str(curdir))
+ return curdir.bestrelpath(py.path.local(loc))
- def write_fixture(fixture_def):
+ def write_fixture(fixture_def: fixtures.FixtureDef[object]) -> None:
argname = fixture_def.argname
if verbose <= 0 and argname.startswith("_"):
return
if verbose > 0:
bestrel = get_best_relpath(fixture_def.func)
- funcargspec = "{} -- {}".format(argname, bestrel)
+ funcargspec = f"{argname} -- {bestrel}"
else:
funcargspec = argname
tw.line(funcargspec, green=True)
@@ -1307,37 +1456,35 @@ def _show_fixtures_per_test(config, session):
else:
tw.line(" no docstring available", red=True)
- def write_item(item):
- try:
- info = item._fixtureinfo
- except AttributeError:
- # doctests items have no _fixtureinfo attribute
- return
- if not info.name2fixturedefs:
- # this test item does not use any fixtures
+ def write_item(item: nodes.Item) -> None:
+ # Not all items have _fixtureinfo attribute.
+ info: Optional[FuncFixtureInfo] = getattr(item, "_fixtureinfo", None)
+ if info is None or not info.name2fixturedefs:
+ # This test item does not use any fixtures.
return
tw.line()
- tw.sep("-", "fixtures used by {}".format(item.name))
- tw.sep("-", "({})".format(get_best_relpath(item.function)))
- # dict key not used in loop but needed for sorting
+ tw.sep("-", f"fixtures used by {item.name}")
+ # TODO: Fix this type ignore.
+ tw.sep("-", "({})".format(get_best_relpath(item.function))) # type: ignore[attr-defined]
+ # dict key not used in loop but needed for sorting.
for _, fixturedefs in sorted(info.name2fixturedefs.items()):
assert fixturedefs is not None
if not fixturedefs:
continue
- # last item is expected to be the one used by the test item
+ # Last item is expected to be the one used by the test item.
write_fixture(fixturedefs[-1])
for session_item in session.items:
write_item(session_item)
-def showfixtures(config):
+def showfixtures(config: Config) -> Union[int, ExitCode]:
from _pytest.main import wrap_session
return wrap_session(config, _showfixtures_main)
-def _showfixtures_main(config, session):
+def _showfixtures_main(config: Config, session: Session) -> None:
import _pytest.config
session.perform_collect()
@@ -1348,14 +1495,14 @@ def _showfixtures_main(config, session):
fm = session._fixturemanager
available = []
- seen = set()
+ seen: Set[Tuple[str, str]] = set()
for argname, fixturedefs in fm._arg2fixturedefs.items():
assert fixturedefs is not None
if not fixturedefs:
continue
for fixturedef in fixturedefs:
- loc = getlocation(fixturedef.func, curdir)
+ loc = getlocation(fixturedef.func, str(curdir))
if (fixturedef.argname, loc) in seen:
continue
seen.add((fixturedef.argname, loc))
@@ -1363,7 +1510,7 @@ def _showfixtures_main(config, session):
(
len(fixturedef.baseid),
fixturedef.func.__module__,
- curdir.bestrelpath(loc),
+ curdir.bestrelpath(py.path.local(loc)),
fixturedef.argname,
fixturedef,
)
@@ -1375,7 +1522,7 @@ def _showfixtures_main(config, session):
if currentmodule != module:
if not module.startswith("_pytest."):
tw.line()
- tw.sep("-", "fixtures defined from {}".format(module))
+ tw.sep("-", f"fixtures defined from {module}")
currentmodule = module
if verbose <= 0 and argname[0] == "_":
continue
@@ -1385,46 +1532,80 @@ def _showfixtures_main(config, session):
if verbose > 0:
tw.write(" -- %s" % bestrel, yellow=True)
tw.write("\n")
- loc = getlocation(fixturedef.func, curdir)
+ loc = getlocation(fixturedef.func, str(curdir))
doc = inspect.getdoc(fixturedef.func)
if doc:
write_docstring(tw, doc)
else:
- tw.line(" {}: no docstring available".format(loc), red=True)
+ tw.line(f" {loc}: no docstring available", red=True)
tw.line()
def write_docstring(tw: TerminalWriter, doc: str, indent: str = " ") -> None:
for line in doc.split("\n"):
- tw.write(indent + line + "\n")
+ tw.line(indent + line)
class Function(PyobjMixin, nodes.Item):
- """ a Function Item is responsible for setting up and executing a
- Python test function.
+ """An Item responsible for setting up and executing a Python test function.
+
+ param name:
+ The full function name, including any decorations like those
+ added by parametrization (``my_func[my_param]``).
+ param parent:
+ The parent Node.
+ param config:
+ The pytest Config object.
+ param callspec:
+ If given, this is function has been parametrized and the callspec contains
+ meta information about the parametrization.
+ param callobj:
+ If given, the object which will be called when the Function is invoked,
+ otherwise the callobj will be obtained from ``parent`` using ``originalname``.
+ param keywords:
+ Keywords bound to the function object for "-k" matching.
+ param session:
+ The pytest Session object.
+ param fixtureinfo:
+ Fixture information already resolved at this fixture node..
+ param originalname:
+ The attribute name to use for accessing the underlying function object.
+ Defaults to ``name``. Set this if name is different from the original name,
+ for example when it contains decorations like those added by parametrization
+ (``my_func[my_param]``).
"""
- # disable since functions handle it themselves
+ # Disable since functions handle it themselves.
_ALLOW_MARKERS = False
def __init__(
self,
- name,
+ name: str,
parent,
- args=None,
- config=None,
+ config: Optional[Config] = None,
callspec: Optional[CallSpec2] = None,
callobj=NOTSET,
keywords=None,
- session=None,
+ session: Optional[Session] = None,
fixtureinfo: Optional[FuncFixtureInfo] = None,
- originalname=None,
+ originalname: Optional[str] = None,
) -> None:
super().__init__(name, parent, config=config, session=session)
- self._args = args
+
if callobj is not NOTSET:
self.obj = callobj
+ #: Original function name, without any decorations (for example
+ #: parametrization adds a ``"[...]"`` suffix to function names), used to access
+ #: the underlying function object from ``parent`` (in case ``callobj`` is not given
+ #: explicitly).
+ #:
+ #: .. versionadded:: 3.0
+ self.originalname = originalname or name
+
+ # Note: when FunctionDefinition is introduced, we should change ``originalname``
+ # to a readonly property that returns FunctionDefinition.name.
+
self.keywords.update(self.obj.__dict__)
self.own_markers.extend(get_unpacked_marks(self.obj))
if callspec:
@@ -1455,63 +1636,46 @@ class Function(PyobjMixin, nodes.Item):
fixtureinfo = self.session._fixturemanager.getfixtureinfo(
self, self.obj, self.cls, funcargs=True
)
- self._fixtureinfo = fixtureinfo # type: FuncFixtureInfo
+ self._fixtureinfo: FuncFixtureInfo = fixtureinfo
self.fixturenames = fixtureinfo.names_closure
self._initrequest()
- #: original function name, without any decorations (for example
- #: parametrization adds a ``"[...]"`` suffix to function names).
- #:
- #: .. versionadded:: 3.0
- self.originalname = originalname
-
@classmethod
def from_parent(cls, parent, **kw): # todo: determine sound type limitations
- """
- The public constructor
- """
+ """The public constructor."""
return super().from_parent(parent=parent, **kw)
- def _initrequest(self):
- self.funcargs = {}
- self._request = fixtures.FixtureRequest(self)
+ def _initrequest(self) -> None:
+ self.funcargs: Dict[str, object] = {}
+ self._request = fixtures.FixtureRequest(self, _ispytest=True)
@property
def function(self):
- "underlying python 'function' object"
+ """Underlying python 'function' object."""
return getimfunc(self.obj)
def _getobj(self):
- name = self.name
- i = name.find("[") # parametrization
- if i != -1:
- name = name[:i]
- return getattr(self.parent.obj, name)
+ assert self.parent is not None
+ return getattr(self.parent.obj, self.originalname) # type: ignore[attr-defined]
@property
def _pyfuncitem(self):
- "(compatonly) for code expecting pytest-2.2 style request objects"
+ """(compatonly) for code expecting pytest-2.2 style request objects."""
return self
- @property
- def funcargnames(self):
- """ alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
- warnings.warn(FUNCARGNAMES, stacklevel=2)
- return self.fixturenames
-
def runtest(self) -> None:
- """ execute the underlying test function. """
+ """Execute the underlying test function."""
self.ihook.pytest_pyfunc_call(pyfuncitem=self)
def setup(self) -> None:
if isinstance(self.parent, Instance):
self.parent.newinstance()
self.obj = self._getobj()
- fixtures.fillfixtures(self)
+ self._request._fillfixtures()
- def _prunetraceback(self, excinfo: ExceptionInfo) -> None:
+ def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
if hasattr(self, "_obj") and not self.config.getoption("fulltrace", False):
- code = _pytest._code.Code(get_real_func(self.obj))
+ code = _pytest._code.Code.from_function(get_real_func(self.obj))
path, firstlineno = code.path, code.firstlineno
traceback = excinfo.traceback
ntraceback = traceback.cut(path=path, firstlineno=firstlineno)
@@ -1524,14 +1688,16 @@ class Function(PyobjMixin, nodes.Item):
excinfo.traceback = ntraceback.filter()
# issue364: mark all but first and last frames to
- # only show a single-line message for each frame
+ # only show a single-line message for each frame.
if self.config.getoption("tbstyle", "auto") == "auto":
if len(excinfo.traceback) > 2:
for entry in excinfo.traceback[1:-1]:
entry.set_repr_style("short")
- def repr_failure(self, excinfo, outerr=None):
- assert outerr is None, "XXX outerr usage is deprecated"
+ # TODO: Type ignored -- breaks Liskov Substitution.
+ def repr_failure( # type: ignore[override]
+ self, excinfo: ExceptionInfo[BaseException],
+ ) -> Union[str, TerminalRepr]:
style = self.config.getoption("tbstyle", "auto")
if style == "auto":
style = "long"
@@ -1540,11 +1706,11 @@ class Function(PyobjMixin, nodes.Item):
class FunctionDefinition(Function):
"""
- internal hack until we get actual definition nodes instead of the
- crappy metafunc hack
+ This class is a step gap solution until we evolve to have actual function definition nodes
+ and manage to get rid of ``metafunc``.
"""
def runtest(self) -> None:
- raise RuntimeError("function definitions are not supposed to be used")
+ raise RuntimeError("function definitions are not supposed to be run as tests")
setup = runtest
diff --git a/contrib/python/pytest/py3/_pytest/python_api.py b/contrib/python/pytest/py3/_pytest/python_api.py
index df97181f4f..81ce4f8953 100644
--- a/contrib/python/pytest/py3/_pytest/python_api.py
+++ b/contrib/python/pytest/py3/_pytest/python_api.py
@@ -1,40 +1,36 @@
-import inspect
import math
import pprint
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sized
from decimal import Decimal
-from itertools import filterfalse
-from numbers import Number
+from numbers import Complex
from types import TracebackType
from typing import Any
from typing import Callable
from typing import cast
from typing import Generic
from typing import Optional
+from typing import overload
from typing import Pattern
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from more_itertools.more import always_iterable
+if TYPE_CHECKING:
+ from numpy import ndarray
+
import _pytest._code
-from _pytest.compat import overload
+from _pytest.compat import final
from _pytest.compat import STRING_TYPES
-from _pytest.compat import TYPE_CHECKING
from _pytest.outcomes import fail
-if TYPE_CHECKING:
- from typing import Type # noqa: F401 (used in type string)
-
-BASE_TYPE = (type, STRING_TYPES)
-
-
-def _non_numeric_type_error(value, at):
- at_str = " at {}".format(at) if at else ""
+def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
+ at_str = f" at {at}" if at else ""
return TypeError(
"cannot make approximate comparisons to non-numeric values: {!r} {}".format(
value, at_str
@@ -46,16 +42,14 @@ def _non_numeric_type_error(value, at):
class ApproxBase:
- """
- Provide shared utilities for making approximate comparisons between numbers
- or sequences of numbers.
- """
+ """Provide shared utilities for making approximate comparisons between
+ numbers or sequences of numbers."""
# Tell numpy to use our `__eq__` operator instead of its.
__array_ufunc__ = None
__array_priority__ = 100
- def __init__(self, expected, rel=None, abs=None, nan_ok=False):
+ def __init__(self, expected, rel=None, abs=None, nan_ok: bool = False) -> None:
__tracebackhide__ = True
self.expected = expected
self.abs = abs
@@ -63,10 +57,10 @@ class ApproxBase:
self.nan_ok = nan_ok
self._check_type()
- def __repr__(self):
+ def __repr__(self) -> str:
raise NotImplementedError
- def __eq__(self, actual):
+ def __eq__(self, actual) -> bool:
return all(
a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
)
@@ -74,23 +68,21 @@ class ApproxBase:
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
- def __ne__(self, actual):
+ def __ne__(self, actual) -> bool:
return not (actual == self)
- def _approx_scalar(self, x):
+ def _approx_scalar(self, x) -> "ApproxScalar":
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
def _yield_comparisons(self, actual):
- """
- Yield all the pairs of numbers to be compared. This is used to
- implement the `__eq__` method.
+ """Yield all the pairs of numbers to be compared.
+
+ This is used to implement the `__eq__` method.
"""
raise NotImplementedError
- def _check_type(self):
- """
- Raise a TypeError if the expected value is not a valid type.
- """
+ def _check_type(self) -> None:
+ """Raise a TypeError if the expected value is not a valid type."""
# This is only a concern if the expected value is a sequence. In every
# other case, the approx() function ensures that the expected value has
# a numeric type. For this reason, the default is to do nothing. The
@@ -107,24 +99,22 @@ def _recursive_list_map(f, x):
class ApproxNumpy(ApproxBase):
- """
- Perform approximate comparisons where the expected value is numpy array.
- """
+ """Perform approximate comparisons where the expected value is numpy array."""
- def __repr__(self):
+ def __repr__(self) -> str:
list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
- return "approx({!r})".format(list_scalars)
+ return f"approx({list_scalars!r})"
- def __eq__(self, actual):
+ def __eq__(self, actual) -> bool:
import numpy as np
- # self.expected is supposed to always be an array here
+ # self.expected is supposed to always be an array here.
if not np.isscalar(actual):
try:
actual = np.asarray(actual)
- except: # noqa
- raise TypeError("cannot compare '{}' to numpy.ndarray".format(actual))
+ except Exception as e:
+ raise TypeError(f"cannot compare '{actual}' to numpy.ndarray") from e
if not np.isscalar(actual) and actual.shape != self.expected.shape:
return False
@@ -147,18 +137,19 @@ class ApproxNumpy(ApproxBase):
class ApproxMapping(ApproxBase):
- """
- Perform approximate comparisons where the expected value is a mapping with
- numeric values (the keys can be anything).
- """
+ """Perform approximate comparisons where the expected value is a mapping
+ with numeric values (the keys can be anything)."""
- def __repr__(self):
+ def __repr__(self) -> str:
return "approx({!r})".format(
{k: self._approx_scalar(v) for k, v in self.expected.items()}
)
- def __eq__(self, actual):
- if set(actual.keys()) != set(self.expected.keys()):
+ def __eq__(self, actual) -> bool:
+ try:
+ if set(actual.keys()) != set(self.expected.keys()):
+ return False
+ except AttributeError:
return False
return ApproxBase.__eq__(self, actual)
@@ -167,23 +158,18 @@ class ApproxMapping(ApproxBase):
for k in self.expected.keys():
yield actual[k], self.expected[k]
- def _check_type(self):
+ def _check_type(self) -> None:
__tracebackhide__ = True
for key, value in self.expected.items():
if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
- elif not isinstance(value, Number):
- raise _non_numeric_type_error(self.expected, at="key={!r}".format(key))
class ApproxSequencelike(ApproxBase):
- """
- Perform approximate comparisons where the expected value is a sequence of
- numbers.
- """
+ """Perform approximate comparisons where the expected value is a sequence of numbers."""
- def __repr__(self):
+ def __repr__(self) -> str:
seq_type = type(self.expected)
if seq_type not in (tuple, list, set):
seq_type = list
@@ -191,77 +177,90 @@ class ApproxSequencelike(ApproxBase):
seq_type(self._approx_scalar(x) for x in self.expected)
)
- def __eq__(self, actual):
- if len(actual) != len(self.expected):
+ def __eq__(self, actual) -> bool:
+ try:
+ if len(actual) != len(self.expected):
+ return False
+ except TypeError:
return False
return ApproxBase.__eq__(self, actual)
def _yield_comparisons(self, actual):
return zip(actual, self.expected)
- def _check_type(self):
+ def _check_type(self) -> None:
__tracebackhide__ = True
for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
- elif not isinstance(x, Number):
- raise _non_numeric_type_error(
- self.expected, at="index {}".format(index)
- )
class ApproxScalar(ApproxBase):
- """
- Perform approximate comparisons where the expected value is a single number.
- """
+ """Perform approximate comparisons where the expected value is a single number."""
# Using Real should be better than this Union, but not possible yet:
# https://github.com/python/typeshed/pull/3108
- DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 # type: Union[float, Decimal]
- DEFAULT_RELATIVE_TOLERANCE = 1e-6 # type: Union[float, Decimal]
+ DEFAULT_ABSOLUTE_TOLERANCE: Union[float, Decimal] = 1e-12
+ DEFAULT_RELATIVE_TOLERANCE: Union[float, Decimal] = 1e-6
- def __repr__(self):
- """
- Return a string communicating both the expected value and the tolerance
- for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'.
+ def __repr__(self) -> str:
+ """Return a string communicating both the expected value and the
+ tolerance for the comparison being made.
+
+ For example, ``1.0 ± 1e-6``, ``(3+4j) ± 5e-6 ∠ ±180°``.
"""
- # Infinities aren't compared using tolerances, so don't show a
- # tolerance. Need to call abs to handle complex numbers, e.g. (inf + 1j)
- if math.isinf(abs(self.expected)):
+ # Don't show a tolerance for values that aren't compared using
+ # tolerances, i.e. non-numerics and infinities. Need to call abs to
+ # handle complex numbers, e.g. (inf + 1j).
+ if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
+ abs(self.expected) # type: ignore[arg-type]
+ ):
return str(self.expected)
# If a sensible tolerance can't be calculated, self.tolerance will
# raise a ValueError. In this case, display '???'.
try:
- vetted_tolerance = "{:.1e}".format(self.tolerance)
- if isinstance(self.expected, complex) and not math.isinf(self.tolerance):
+ vetted_tolerance = f"{self.tolerance:.1e}"
+ if (
+ isinstance(self.expected, Complex)
+ and self.expected.imag
+ and not math.isinf(self.tolerance)
+ ):
vetted_tolerance += " ∠ ±180°"
except ValueError:
vetted_tolerance = "???"
- return "{} ± {}".format(self.expected, vetted_tolerance)
+ return f"{self.expected} ± {vetted_tolerance}"
- def __eq__(self, actual):
- """
- Return true if the given value is equal to the expected value within
- the pre-specified tolerance.
- """
- if _is_numpy_array(actual):
+ def __eq__(self, actual) -> bool:
+ """Return whether the given value is equal to the expected value
+ within the pre-specified tolerance."""
+ asarray = _as_numpy_array(actual)
+ if asarray is not None:
# Call ``__eq__()`` manually to prevent infinite-recursion with
# numpy<1.13. See #3748.
- return all(self.__eq__(a) for a in actual.flat)
+ return all(self.__eq__(a) for a in asarray.flat)
# Short-circuit exact equality.
if actual == self.expected:
return True
+ # If either type is non-numeric, fall back to strict equality.
+ # NB: we need Complex, rather than just Number, to ensure that __abs__,
+ # __sub__, and __float__ are defined.
+ if not (
+ isinstance(self.expected, (Complex, Decimal))
+ and isinstance(actual, (Complex, Decimal))
+ ):
+ return False
+
# Allow the user to control whether NaNs are considered equal to each
# other or not. The abs() calls are for compatibility with complex
# numbers.
- if math.isnan(abs(self.expected)):
- return self.nan_ok and math.isnan(abs(actual))
+ if math.isnan(abs(self.expected)): # type: ignore[arg-type]
+ return self.nan_ok and math.isnan(abs(actual)) # type: ignore[arg-type]
# Infinity shouldn't be approximately equal to anything but itself, but
# if there's a relative tolerance, it will be infinite and infinity
@@ -269,21 +268,22 @@ class ApproxScalar(ApproxBase):
# case would have been short circuited above, so here we can just
# return false if the expected value is infinite. The abs() call is
# for compatibility with complex numbers.
- if math.isinf(abs(self.expected)):
+ if math.isinf(abs(self.expected)): # type: ignore[arg-type]
return False
# Return true if the two numbers are within the tolerance.
- return abs(self.expected - actual) <= self.tolerance
+ result: bool = abs(self.expected - actual) <= self.tolerance
+ return result
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@property
def tolerance(self):
- """
- Return the tolerance for the comparison. This could be either an
- absolute tolerance or a relative tolerance, depending on what the user
- specified or which would be larger.
+ """Return the tolerance for the comparison.
+
+ This could be either an absolute tolerance or a relative tolerance,
+ depending on what the user specified or which would be larger.
"""
def set_default(x, default):
@@ -295,7 +295,7 @@ class ApproxScalar(ApproxBase):
if absolute_tolerance < 0:
raise ValueError(
- "absolute tolerance can't be negative: {}".format(absolute_tolerance)
+ f"absolute tolerance can't be negative: {absolute_tolerance}"
)
if math.isnan(absolute_tolerance):
raise ValueError("absolute tolerance can't be NaN.")
@@ -317,7 +317,7 @@ class ApproxScalar(ApproxBase):
if relative_tolerance < 0:
raise ValueError(
- "relative tolerance can't be negative: {}".format(absolute_tolerance)
+ f"relative tolerance can't be negative: {absolute_tolerance}"
)
if math.isnan(relative_tolerance):
raise ValueError("relative tolerance can't be NaN.")
@@ -327,17 +327,14 @@ class ApproxScalar(ApproxBase):
class ApproxDecimal(ApproxScalar):
- """
- Perform approximate comparisons where the expected value is a decimal.
- """
+ """Perform approximate comparisons where the expected value is a Decimal."""
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
-def approx(expected, rel=None, abs=None, nan_ok=False):
- """
- Assert that two numbers (or two sets of numbers) are equal to each other
+def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
+ """Assert that two numbers (or two sets of numbers) are equal to each other
within some tolerance.
Due to the `intricacies of floating-point arithmetic`__, numbers that we
@@ -429,6 +426,18 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
>>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12)
True
+ You can also use ``approx`` to compare nonnumeric types, or dicts and
+ sequences containing nonnumeric types, in which case it falls back to
+ strict equality. This can be useful for comparing dicts and sequences that
+ can contain optional values::
+
+ >>> {"required": 1.0000005, "optional": None} == approx({"required": 1, "optional": None})
+ True
+ >>> [None, 1.0000005] == approx([None,1])
+ True
+ >>> ["foo", 1.0000005] == approx([None,1])
+ False
+
If you're thinking about using ``approx``, then you might want to know how
it compares to other good ways of comparing floating-point numbers. All of
these algorithms are based on relative and absolute tolerances and should
@@ -440,7 +449,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
both ``a`` and ``b``, this test is symmetric (i.e. neither ``a`` nor
``b`` is a "reference value"). You have to specify an absolute tolerance
if you want to compare to ``0.0`` because there is no tolerance by
- default. Only available in python>=3.5. `More information...`__
+ default. `More information...`__
__ https://docs.python.org/3/library/math.html#math.isclose
@@ -451,7 +460,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
think of ``b`` as the reference value. Support for comparing sequences
is provided by ``numpy.allclose``. `More information...`__
- __ http://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.isclose.html
+ __ https://numpy.org/doc/stable/reference/generated/numpy.isclose.html
- ``unittest.TestCase.assertAlmostEqual(a, b)``: True if ``a`` and ``b``
are within an absolute tolerance of ``1e-7``. No relative tolerance is
@@ -486,6 +495,14 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
follows a fixed behavior. `More information...`__
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__
+
+ .. versionchanged:: 3.7.1
+ ``approx`` raises ``TypeError`` when it encounters a dict value or
+ sequence element of nonnumeric type.
+
+ .. versionchanged:: 6.1.0
+ ``approx`` falls back to strict equality for nonnumeric types instead
+ of raising ``TypeError``.
"""
# Delegate the comparison to a class that knows how to deal with the type
@@ -506,36 +523,50 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
__tracebackhide__ = True
if isinstance(expected, Decimal):
- cls = ApproxDecimal
- elif isinstance(expected, Number):
- cls = ApproxScalar
+ cls: Type[ApproxBase] = ApproxDecimal
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif _is_numpy_array(expected):
+ expected = _as_numpy_array(expected)
cls = ApproxNumpy
elif (
isinstance(expected, Iterable)
and isinstance(expected, Sized)
- and not isinstance(expected, STRING_TYPES)
+ # Type ignored because the error is wrong -- not unreachable.
+ and not isinstance(expected, STRING_TYPES) # type: ignore[unreachable]
):
cls = ApproxSequencelike
else:
- raise _non_numeric_type_error(expected, at=None)
+ cls = ApproxScalar
return cls(expected, rel, abs, nan_ok)
-def _is_numpy_array(obj):
+def _is_numpy_array(obj: object) -> bool:
+ """
+ Return true if the given object is implicitly convertible to ndarray,
+ and numpy is already imported.
"""
- Return true if the given object is a numpy array. Make a special effort to
- avoid importing numpy unless it's really necessary.
+ return _as_numpy_array(obj) is not None
+
+
+def _as_numpy_array(obj: object) -> Optional["ndarray"]:
+ """
+ Return an ndarray if the given object is implicitly convertible to ndarray,
+ and numpy is already imported, otherwise None.
"""
import sys
- np = sys.modules.get("numpy")
+ np: Any = sys.modules.get("numpy")
if np is not None:
- return isinstance(obj, np.ndarray)
- return False
+ # avoid infinite recursion on numpy scalars, which have __array__
+ if np.isscalar(obj):
+ return None
+ elif isinstance(obj, np.ndarray):
+ return obj
+ elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"):
+ return np.asarray(obj)
+ return None
# builtin pytest.raises helper
@@ -545,33 +576,31 @@ _E = TypeVar("_E", bound=BaseException)
@overload
def raises(
- expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
+ expected_exception: Union[Type[_E], Tuple[Type[_E], ...]],
*,
- match: "Optional[Union[str, Pattern]]" = ...
+ match: Optional[Union[str, Pattern[str]]] = ...,
) -> "RaisesContext[_E]":
- ... # pragma: no cover
+ ...
-@overload # noqa: F811
-def raises( # noqa: F811
- expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
- func: Callable,
+@overload
+def raises(
+ expected_exception: Union[Type[_E], Tuple[Type[_E], ...]],
+ func: Callable[..., Any],
*args: Any,
- **kwargs: Any
+ **kwargs: Any,
) -> _pytest._code.ExceptionInfo[_E]:
- ... # pragma: no cover
+ ...
-def raises( # noqa: F811
- expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
- *args: Any,
- **kwargs: Any
+def raises(
+ expected_exception: Union[Type[_E], Tuple[Type[_E], ...]], *args: Any, **kwargs: Any
) -> Union["RaisesContext[_E]", _pytest._code.ExceptionInfo[_E]]:
- r"""
- Assert that a code block/function call raises ``expected_exception``
+ r"""Assert that a code block/function call raises ``expected_exception``
or raise a failure exception otherwise.
- :kwparam match: if specified, a string containing a regular expression,
+ :kwparam match:
+ If specified, a string containing a regular expression,
or a regular expression object, that is tested against the string
representation of the exception using ``re.search``. To match a literal
string that may contain `special characters`__, the pattern can
@@ -589,7 +618,8 @@ def raises( # noqa: F811
Use ``pytest.raises`` as a context manager, which will capture the exception of the given
type::
- >>> with raises(ZeroDivisionError):
+ >>> import pytest
+ >>> with pytest.raises(ZeroDivisionError):
... 1/0
If the code block does not raise the expected exception (``ZeroDivisionError`` in the example
@@ -598,16 +628,16 @@ def raises( # noqa: F811
You can also use the keyword argument ``match`` to assert that the
exception matches a text or regex::
- >>> with raises(ValueError, match='must be 0 or None'):
+ >>> with pytest.raises(ValueError, match='must be 0 or None'):
... raise ValueError("value must be 0 or None")
- >>> with raises(ValueError, match=r'must be \d+$'):
+ >>> with pytest.raises(ValueError, match=r'must be \d+$'):
... raise ValueError("value must be 42")
The context manager produces an :class:`ExceptionInfo` object which can be used to inspect the
details of the captured exception::
- >>> with raises(ValueError) as exc_info:
+ >>> with pytest.raises(ValueError) as exc_info:
... raise ValueError("value must be 42")
>>> assert exc_info.type is ValueError
>>> assert exc_info.value.args[0] == "value must be 42"
@@ -621,7 +651,7 @@ def raises( # noqa: F811
not be executed. For example::
>>> value = 15
- >>> with raises(ValueError) as exc_info:
+ >>> with pytest.raises(ValueError) as exc_info:
... if value > 10:
... raise ValueError("value must be <= 10")
... assert exc_info.type is ValueError # this will not execute
@@ -629,7 +659,7 @@ def raises( # noqa: F811
Instead, the following approach must be taken (note the difference in
scope)::
- >>> with raises(ValueError) as exc_info:
+ >>> with pytest.raises(ValueError) as exc_info:
... if value > 10:
... raise ValueError("value must be <= 10")
...
@@ -677,16 +707,21 @@ def raises( # noqa: F811
documentation for :ref:`the try statement <python:try>`.
"""
__tracebackhide__ = True
- for exc in filterfalse(
- inspect.isclass, always_iterable(expected_exception, BASE_TYPE) # type: ignore[arg-type] # noqa: F821
- ):
- msg = "exceptions must be derived from BaseException, not %s"
- raise TypeError(msg % type(exc))
- message = "DID NOT RAISE {}".format(expected_exception)
+ if isinstance(expected_exception, type):
+ excepted_exceptions: Tuple[Type[_E], ...] = (expected_exception,)
+ else:
+ excepted_exceptions = expected_exception
+ for exc in excepted_exceptions:
+ if not isinstance(exc, type) or not issubclass(exc, BaseException): # type: ignore[unreachable]
+ msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
+ not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
+ raise TypeError(msg.format(not_a))
+
+ message = f"DID NOT RAISE {expected_exception}"
if not args:
- match = kwargs.pop("match", None)
+ match: Optional[Union[str, Pattern[str]]] = kwargs.pop("match", None)
if kwargs:
msg = "Unexpected keyword arguments passed to pytest.raises: "
msg += ", ".join(sorted(kwargs))
@@ -710,20 +745,22 @@ def raises( # noqa: F811
fail(message)
+# This doesn't work with mypy for now. Use fail.Exception instead.
raises.Exception = fail.Exception # type: ignore
+@final
class RaisesContext(Generic[_E]):
def __init__(
self,
- expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
+ expected_exception: Union[Type[_E], Tuple[Type[_E], ...]],
message: str,
- match_expr: Optional[Union[str, "Pattern"]] = None,
+ match_expr: Optional[Union[str, Pattern[str]]] = None,
) -> None:
self.expected_exception = expected_exception
self.message = message
self.match_expr = match_expr
- self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
+ self.excinfo: Optional[_pytest._code.ExceptionInfo[_E]] = None
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later()
@@ -731,7 +768,7 @@ class RaisesContext(Generic[_E]):
def __exit__(
self,
- exc_type: Optional["Type[BaseException]"],
+ exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
@@ -742,9 +779,7 @@ class RaisesContext(Generic[_E]):
if not issubclass(exc_type, self.expected_exception):
return False
# Cast to narrow the exception type now that it's verified.
- exc_info = cast(
- Tuple["Type[_E]", _E, TracebackType], (exc_type, exc_val, exc_tb)
- )
+ exc_info = cast(Tuple[Type[_E], _E, TracebackType], (exc_type, exc_val, exc_tb))
self.excinfo.fill_unfilled(exc_info)
if self.match_expr is not None:
self.excinfo.match(self.match_expr)
diff --git a/contrib/python/pytest/py3/_pytest/recwarn.py b/contrib/python/pytest/py3/_pytest/recwarn.py
index c57c94b1cb..d872d9da40 100644
--- a/contrib/python/pytest/py3/_pytest/recwarn.py
+++ b/contrib/python/pytest/py3/_pytest/recwarn.py
@@ -1,53 +1,79 @@
-""" recording warnings during test function execution. """
+"""Record warnings during test function execution."""
import re
import warnings
from types import TracebackType
from typing import Any
from typing import Callable
+from typing import Generator
from typing import Iterator
from typing import List
from typing import Optional
+from typing import overload
from typing import Pattern
from typing import Tuple
+from typing import Type
+from typing import TypeVar
from typing import Union
-from _pytest.compat import overload
-from _pytest.compat import TYPE_CHECKING
-from _pytest.fixtures import yield_fixture
+from _pytest.compat import final
+from _pytest.deprecated import check_ispytest
+from _pytest.fixtures import fixture
from _pytest.outcomes import fail
-if TYPE_CHECKING:
- from typing import Type
+T = TypeVar("T")
-@yield_fixture
-def recwarn():
+
+@fixture
+def recwarn() -> Generator["WarningsRecorder", None, None]:
"""Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions.
See http://docs.python.org/library/warnings.html for information
on warning categories.
"""
- wrec = WarningsRecorder()
+ wrec = WarningsRecorder(_ispytest=True)
with wrec:
warnings.simplefilter("default")
yield wrec
-def deprecated_call(func=None, *args, **kwargs):
- """context manager that can be used to ensure a block of code triggers a
- ``DeprecationWarning`` or ``PendingDeprecationWarning``::
+@overload
+def deprecated_call(
+ *, match: Optional[Union[str, Pattern[str]]] = ...
+) -> "WarningsRecorder":
+ ...
+
+
+@overload
+def deprecated_call(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
+ ...
+
+
+def deprecated_call(
+ func: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any
+) -> Union["WarningsRecorder", Any]:
+ """Assert that code produces a ``DeprecationWarning`` or ``PendingDeprecationWarning``.
+
+ This function can be used as a context manager::
>>> import warnings
>>> def api_call_v2():
... warnings.warn('use v3 of this api', DeprecationWarning)
... return 200
- >>> with deprecated_call():
+ >>> import pytest
+ >>> with pytest.deprecated_call():
... assert api_call_v2() == 200
- ``deprecated_call`` can also be used by passing a function and ``*args`` and ``*kwargs``,
- in which case it will ensure calling ``func(*args, **kwargs)`` produces one of the warnings
- types above.
+ It can also be used by passing a function and ``*args`` and ``**kwargs``,
+ in which case it will ensure calling ``func(*args, **kwargs)`` produces one of
+ the warnings types above. The return value is the return value of the function.
+
+ In the context manager form you may use the keyword argument ``match`` to assert
+ that the warning matches a text or regex.
+
+ The context manager produces a list of :class:`warnings.WarningMessage` objects,
+ one for each warning raised.
"""
__tracebackhide__ = True
if func is not None:
@@ -57,29 +83,28 @@ def deprecated_call(func=None, *args, **kwargs):
@overload
def warns(
- expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
+ expected_warning: Optional[Union[Type[Warning], Tuple[Type[Warning], ...]]],
*,
- match: "Optional[Union[str, Pattern]]" = ...
+ match: Optional[Union[str, Pattern[str]]] = ...,
) -> "WarningsChecker":
- raise NotImplementedError()
+ ...
-@overload # noqa: F811
-def warns( # noqa: F811
- expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
- func: Callable,
+@overload
+def warns(
+ expected_warning: Optional[Union[Type[Warning], Tuple[Type[Warning], ...]]],
+ func: Callable[..., T],
*args: Any,
- match: Optional[Union[str, "Pattern"]] = ...,
- **kwargs: Any
-) -> Union[Any]:
- raise NotImplementedError()
+ **kwargs: Any,
+) -> T:
+ ...
-def warns( # noqa: F811
- expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
+def warns(
+ expected_warning: Optional[Union[Type[Warning], Tuple[Type[Warning], ...]]],
*args: Any,
- match: Optional[Union[str, "Pattern"]] = None,
- **kwargs: Any
+ match: Optional[Union[str, Pattern[str]]] = None,
+ **kwargs: Any,
) -> Union["WarningsChecker", Any]:
r"""Assert that code raises a particular class of warning.
@@ -91,21 +116,22 @@ def warns( # noqa: F811
one for each warning raised.
This function can be used as a context manager, or any of the other ways
- ``pytest.raises`` can be used::
+ :func:`pytest.raises` can be used::
- >>> with warns(RuntimeWarning):
+ >>> import pytest
+ >>> with pytest.warns(RuntimeWarning):
... warnings.warn("my warning", RuntimeWarning)
In the context manager form you may use the keyword argument ``match`` to assert
- that the exception matches a text or regex::
+ that the warning matches a text or regex::
- >>> with warns(UserWarning, match='must be 0 or None'):
+ >>> with pytest.warns(UserWarning, match='must be 0 or None'):
... warnings.warn("value must be 0 or None", UserWarning)
- >>> with warns(UserWarning, match=r'must be \d+$'):
+ >>> with pytest.warns(UserWarning, match=r'must be \d+$'):
... warnings.warn("value must be 42", UserWarning)
- >>> with warns(UserWarning, match=r'must be \d+$'):
+ >>> with pytest.warns(UserWarning, match=r'must be \d+$'):
... warnings.warn("this is not here", UserWarning)
Traceback (most recent call last):
...
@@ -119,14 +145,14 @@ def warns( # noqa: F811
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
raise TypeError(msg)
- return WarningsChecker(expected_warning, match_expr=match)
+ return WarningsChecker(expected_warning, match_expr=match, _ispytest=True)
else:
func = args[0]
if not callable(func):
raise TypeError(
"{!r} object (type: {}) must be callable".format(func, type(func))
)
- with WarningsChecker(expected_warning):
+ with WarningsChecker(expected_warning, _ispytest=True):
return func(*args[1:], **kwargs)
@@ -136,21 +162,23 @@ class WarningsRecorder(warnings.catch_warnings):
Adapted from `warnings.catch_warnings`.
"""
- def __init__(self):
- super().__init__(record=True)
+ def __init__(self, *, _ispytest: bool = False) -> None:
+ check_ispytest(_ispytest)
+ # Type ignored due to the way typeshed handles warnings.catch_warnings.
+ super().__init__(record=True) # type: ignore[call-arg]
self._entered = False
- self._list = [] # type: List[warnings._Record]
+ self._list: List[warnings.WarningMessage] = []
@property
- def list(self) -> List["warnings._Record"]:
+ def list(self) -> List["warnings.WarningMessage"]:
"""The list of recorded warnings."""
return self._list
- def __getitem__(self, i: int) -> "warnings._Record":
+ def __getitem__(self, i: int) -> "warnings.WarningMessage":
"""Get a recorded warning by index."""
return self._list[i]
- def __iter__(self) -> Iterator["warnings._Record"]:
+ def __iter__(self) -> Iterator["warnings.WarningMessage"]:
"""Iterate through the recorded warnings."""
return iter(self._list)
@@ -158,7 +186,7 @@ class WarningsRecorder(warnings.catch_warnings):
"""The number of recorded warnings."""
return len(self._list)
- def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
+ def pop(self, cls: Type[Warning] = Warning) -> "warnings.WarningMessage":
"""Pop the first recorded warning, raise exception if not exists."""
for i, w in enumerate(self._list):
if issubclass(w.category, cls):
@@ -185,7 +213,7 @@ class WarningsRecorder(warnings.catch_warnings):
def __exit__(
self,
- exc_type: Optional["Type[BaseException]"],
+ exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
@@ -200,15 +228,19 @@ class WarningsRecorder(warnings.catch_warnings):
self._entered = False
+@final
class WarningsChecker(WarningsRecorder):
def __init__(
self,
expected_warning: Optional[
- Union["Type[Warning]", Tuple["Type[Warning]", ...]]
+ Union[Type[Warning], Tuple[Type[Warning], ...]]
] = None,
- match_expr: Optional[Union[str, "Pattern"]] = None,
+ match_expr: Optional[Union[str, Pattern[str]]] = None,
+ *,
+ _ispytest: bool = False,
) -> None:
- super().__init__()
+ check_ispytest(_ispytest)
+ super().__init__(_ispytest=True)
msg = "exceptions must be derived from Warning, not %s"
if expected_warning is None:
@@ -228,7 +260,7 @@ class WarningsChecker(WarningsRecorder):
def __exit__(
self,
- exc_type: Optional["Type[BaseException]"],
+ exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
diff --git a/contrib/python/pytest/py3/_pytest/reports.py b/contrib/python/pytest/py3/_pytest/reports.py
index 4fa465ea71..58f12517c5 100644
--- a/contrib/python/pytest/py3/_pytest/reports.py
+++ b/contrib/python/pytest/py3/_pytest/reports.py
@@ -1,9 +1,17 @@
from io import StringIO
+from pathlib import Path
from pprint import pprint
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import attr
@@ -11,6 +19,7 @@ import py
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
+from _pytest._code.code import ExceptionRepr
from _pytest._code.code import ReprEntry
from _pytest._code.code import ReprEntryNative
from _pytest._code.code import ReprExceptionInfo
@@ -20,30 +29,42 @@ from _pytest._code.code import ReprLocals
from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
-from _pytest.compat import TYPE_CHECKING
-from _pytest.nodes import Node
+from _pytest.compat import final
+from _pytest.config import Config
+from _pytest.nodes import Collector
+from _pytest.nodes import Item
from _pytest.outcomes import skip
-from _pytest.pathlib import Path
+if TYPE_CHECKING:
+ from typing import NoReturn
+ from typing_extensions import Literal
-def getslaveinfoline(node):
+ from _pytest.runner import CallInfo
+
+
+def getworkerinfoline(node):
try:
- return node._slaveinfocache
+ return node._workerinfocache
except AttributeError:
- d = node.slaveinfo
+ d = node.workerinfo
ver = "%s.%s.%s" % d["version_info"][:3]
- node._slaveinfocache = s = "[{}] {} -- Python {} {}".format(
+ node._workerinfocache = s = "[{}] {} -- Python {} {}".format(
d["id"], d["sysplatform"], ver, d["executable"]
)
return s
+_R = TypeVar("_R", bound="BaseReport")
+
+
class BaseReport:
- when = None # type: Optional[str]
- location = None # type: Optional[Tuple[str, Optional[int], str]]
- longrepr = None
- sections = [] # type: List[Tuple[str, str]]
- nodeid = None # type: str
+ when: Optional[str]
+ location: Optional[Tuple[str, Optional[int], str]]
+ longrepr: Union[
+ None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr
+ ]
+ sections: List[Tuple[str, str]]
+ nodeid: str
def __init__(self, **kw: Any) -> None:
self.__dict__.update(kw)
@@ -51,46 +72,48 @@ class BaseReport:
if TYPE_CHECKING:
# Can have arbitrary fields given to __init__().
def __getattr__(self, key: str) -> Any:
- raise NotImplementedError()
+ ...
- def toterminal(self, out) -> None:
+ def toterminal(self, out: TerminalWriter) -> None:
if hasattr(self, "node"):
- out.line(getslaveinfoline(self.node)) # type: ignore
+ out.line(getworkerinfoline(self.node))
longrepr = self.longrepr
if longrepr is None:
return
if hasattr(longrepr, "toterminal"):
- longrepr.toterminal(out)
+ longrepr_terminal = cast(TerminalRepr, longrepr)
+ longrepr_terminal.toterminal(out)
else:
try:
- out.line(longrepr)
+ s = str(longrepr)
except UnicodeEncodeError:
- out.line("<unprintable longrepr>")
+ s = "<unprintable longrepr>"
+ out.line(s)
- def get_sections(self, prefix):
+ def get_sections(self, prefix: str) -> Iterator[Tuple[str, str]]:
for name, content in self.sections:
if name.startswith(prefix):
yield prefix, content
@property
- def longreprtext(self):
- """
- Read-only property that returns the full string representation
- of ``longrepr``.
+ def longreprtext(self) -> str:
+ """Read-only property that returns the full string representation of
+ ``longrepr``.
.. versionadded:: 3.0
"""
- tw = TerminalWriter(stringio=True)
+ file = StringIO()
+ tw = TerminalWriter(file)
tw.hasmarkup = False
self.toterminal(tw)
- exc = tw.stringio.getvalue()
+ exc = file.getvalue()
return exc.strip()
@property
- def caplog(self):
- """Return captured log lines, if log capturing is enabled
+ def caplog(self) -> str:
+ """Return captured log lines, if log capturing is enabled.
.. versionadded:: 3.5
"""
@@ -99,8 +122,8 @@ class BaseReport:
)
@property
- def capstdout(self):
- """Return captured text from stdout, if capturing is enabled
+ def capstdout(self) -> str:
+ """Return captured text from stdout, if capturing is enabled.
.. versionadded:: 3.0
"""
@@ -109,8 +132,8 @@ class BaseReport:
)
@property
- def capstderr(self):
- """Return captured text from stderr, if capturing is enabled
+ def capstderr(self) -> str:
+ """Return captured text from stderr, if capturing is enabled.
.. versionadded:: 3.0
"""
@@ -127,12 +150,9 @@ class BaseReport:
return self.nodeid.split("::")[0]
@property
- def count_towards_summary(self):
- """
- **Experimental**
-
- Returns True if this report should be counted towards the totals shown at the end of the
- test session: "1 passed, 1 failure, etc".
+ def count_towards_summary(self) -> bool:
+ """**Experimental** Whether this report should be counted towards the
+ totals shown at the end of the test session: "1 passed, 1 failure, etc".
.. note::
@@ -142,12 +162,10 @@ class BaseReport:
return True
@property
- def head_line(self):
- """
- **Experimental**
-
- Returns the head line shown with longrepr output for this report, more commonly during
- traceback representation during failures::
+ def head_line(self) -> Optional[str]:
+ """**Experimental** The head line shown with longrepr output for this
+ report, more commonly during traceback representation during
+ failures::
________ Test.foo ________
@@ -162,31 +180,31 @@ class BaseReport:
if self.location is not None:
fspath, lineno, domain = self.location
return domain
+ return None
- def _get_verbose_word(self, config):
+ def _get_verbose_word(self, config: Config):
_category, _short, verbose = config.hook.pytest_report_teststatus(
report=self, config=config
)
return verbose
- def _to_json(self):
- """
- This was originally the serialize_report() function from xdist (ca03269).
+ def _to_json(self) -> Dict[str, Any]:
+ """Return the contents of this report as a dict of builtin entries,
+ suitable for serialization.
- Returns the contents of this report as a dict of builtin entries, suitable for
- serialization.
+ This was originally the serialize_report() function from xdist (ca03269).
Experimental method.
"""
return _report_to_json(self)
@classmethod
- def _from_json(cls, reportdict):
- """
- This was originally the serialize_report() function from xdist (ca03269).
+ def _from_json(cls: Type[_R], reportdict: Dict[str, object]) -> _R:
+ """Create either a TestReport or CollectReport, depending on the calling class.
- Factory method that returns either a TestReport or CollectReport, depending on the calling
- class. It's the callers responsibility to know which class to pass here.
+ It is the callers responsibility to know which class to pass here.
+
+ This was originally the serialize_report() function from xdist (ca03269).
Experimental method.
"""
@@ -194,7 +212,9 @@ class BaseReport:
return cls(**kwargs)
-def _report_unserialization_failure(type_name, report_class, reportdict):
+def _report_unserialization_failure(
+ type_name: str, report_class: Type[BaseReport], reportdict
+) -> "NoReturn":
url = "https://github.com/pytest-dev/pytest/issues"
stream = StringIO()
pprint("-" * 100, stream=stream)
@@ -206,85 +226,93 @@ def _report_unserialization_failure(type_name, report_class, reportdict):
raise RuntimeError(stream.getvalue())
+@final
class TestReport(BaseReport):
- """ Basic test report object (also used for setup and teardown calls if
- they fail).
- """
+ """Basic test report object (also used for setup and teardown calls if
+ they fail)."""
__test__ = False
def __init__(
self,
- nodeid,
+ nodeid: str,
location: Tuple[str, Optional[int], str],
keywords,
- outcome,
- longrepr,
- when,
- sections=(),
- duration=0,
- user_properties=None,
- **extra
+ outcome: "Literal['passed', 'failed', 'skipped']",
+ longrepr: Union[
+ None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr
+ ],
+ when: "Literal['setup', 'call', 'teardown']",
+ sections: Iterable[Tuple[str, str]] = (),
+ duration: float = 0,
+ user_properties: Optional[Iterable[Tuple[str, object]]] = None,
+ **extra,
) -> None:
- #: normalized collection node id
+ #: Normalized collection nodeid.
self.nodeid = nodeid
- #: a (filesystempath, lineno, domaininfo) tuple indicating the
+ #: A (filesystempath, lineno, domaininfo) tuple indicating the
#: actual location of a test item - it might be different from the
#: collected one e.g. if a method is inherited from a different module.
- self.location = location # type: Tuple[str, Optional[int], str]
+ self.location: Tuple[str, Optional[int], str] = location
- #: a name -> value dictionary containing all keywords and
+ #: A name -> value dictionary containing all keywords and
#: markers associated with a test invocation.
self.keywords = keywords
- #: test outcome, always one of "passed", "failed", "skipped".
+ #: Test outcome, always one of "passed", "failed", "skipped".
self.outcome = outcome
#: None or a failure representation.
self.longrepr = longrepr
- #: one of 'setup', 'call', 'teardown' to indicate runtest phase.
+ #: One of 'setup', 'call', 'teardown' to indicate runtest phase.
self.when = when
- #: user properties is a list of tuples (name, value) that holds user
- #: defined properties of the test
+ #: User properties is a list of tuples (name, value) that holds user
+ #: defined properties of the test.
self.user_properties = list(user_properties or [])
- #: list of pairs ``(str, str)`` of extra information which needs to
+ #: List of pairs ``(str, str)`` of extra information which needs to
#: marshallable. Used by pytest to add captured text
#: from ``stdout`` and ``stderr``, but may be used by other plugins
#: to add arbitrary information to reports.
self.sections = list(sections)
- #: time it took to run just the test
+ #: Time it took to run just the test.
self.duration = duration
self.__dict__.update(extra)
- def __repr__(self):
+ def __repr__(self) -> str:
return "<{} {!r} when={!r} outcome={!r}>".format(
self.__class__.__name__, self.nodeid, self.when, self.outcome
)
@classmethod
- def from_item_and_call(cls, item, call) -> "TestReport":
- """
- Factory method to create and fill a TestReport with standard item and call info.
- """
+ def from_item_and_call(cls, item: Item, call: "CallInfo[None]") -> "TestReport":
+ """Create and fill a TestReport with standard item and call info."""
when = call.when
- duration = call.stop - call.start
+ # Remove "collect" from the Literal type -- only for collection calls.
+ assert when != "collect"
+ duration = call.duration
keywords = {x: 1 for x in item.keywords}
excinfo = call.excinfo
sections = []
if not call.excinfo:
- outcome = "passed"
- longrepr = None
+ outcome: Literal["passed", "failed", "skipped"] = "passed"
+ longrepr: Union[
+ None,
+ ExceptionInfo[BaseException],
+ Tuple[str, int, str],
+ str,
+ TerminalRepr,
+ ] = (None)
else:
if not isinstance(excinfo, ExceptionInfo):
outcome = "failed"
longrepr = excinfo
- elif excinfo.errisinstance(skip.Exception):
+ elif isinstance(excinfo.value, skip.Exception):
outcome = "skipped"
r = excinfo._getreprcrash()
longrepr = (str(r.path), r.lineno, r.message)
@@ -297,7 +325,7 @@ class TestReport(BaseReport):
excinfo, style=item.config.getoption("tbstyle", "auto")
)
for rwhen, key, content in item._report_sections:
- sections.append(("Captured {} {}".format(key, rwhen), content))
+ sections.append((f"Captured {key} {rwhen}", content))
return cls(
item.nodeid,
item.location,
@@ -311,45 +339,74 @@ class TestReport(BaseReport):
)
+@final
class CollectReport(BaseReport):
+ """Collection report object."""
+
when = "collect"
def __init__(
- self, nodeid: str, outcome, longrepr, result: List[Node], sections=(), **extra
+ self,
+ nodeid: str,
+ outcome: "Literal['passed', 'skipped', 'failed']",
+ longrepr,
+ result: Optional[List[Union[Item, Collector]]],
+ sections: Iterable[Tuple[str, str]] = (),
+ **extra,
) -> None:
+ #: Normalized collection nodeid.
self.nodeid = nodeid
+
+ #: Test outcome, always one of "passed", "failed", "skipped".
self.outcome = outcome
+
+ #: None or a failure representation.
self.longrepr = longrepr
+
+ #: The collected items and collection nodes.
self.result = result or []
+
+ #: List of pairs ``(str, str)`` of extra information which needs to
+ #: marshallable.
+ # Used by pytest to add captured text : from ``stdout`` and ``stderr``,
+ # but may be used by other plugins : to add arbitrary information to
+ # reports.
self.sections = list(sections)
+
self.__dict__.update(extra)
@property
def location(self):
return (self.fspath, None, self.fspath)
- def __repr__(self):
+ def __repr__(self) -> str:
return "<CollectReport {!r} lenresult={} outcome={!r}>".format(
self.nodeid, len(self.result), self.outcome
)
class CollectErrorRepr(TerminalRepr):
- def __init__(self, msg):
+ def __init__(self, msg: str) -> None:
self.longrepr = msg
- def toterminal(self, out) -> None:
+ def toterminal(self, out: TerminalWriter) -> None:
out.line(self.longrepr, red=True)
-def pytest_report_to_serializable(report):
+def pytest_report_to_serializable(
+ report: Union[CollectReport, TestReport]
+) -> Optional[Dict[str, Any]]:
if isinstance(report, (TestReport, CollectReport)):
data = report._to_json()
data["$report_type"] = report.__class__.__name__
return data
+ # TODO: Check if this is actually reachable.
+ return None # type: ignore[unreachable]
-def pytest_report_from_serializable(data):
+def pytest_report_from_serializable(
+ data: Dict[str, Any],
+) -> Optional[Union[CollectReport, TestReport]]:
if "$report_type" in data:
if data["$report_type"] == "TestReport":
return TestReport._from_json(data)
@@ -358,45 +415,53 @@ def pytest_report_from_serializable(data):
assert False, "Unknown report_type unserialize data: {}".format(
data["$report_type"]
)
+ return None
-def _report_to_json(report):
- """
- This was originally the serialize_report() function from xdist (ca03269).
+def _report_to_json(report: BaseReport) -> Dict[str, Any]:
+ """Return the contents of this report as a dict of builtin entries,
+ suitable for serialization.
- Returns the contents of this report as a dict of builtin entries, suitable for
- serialization.
+ This was originally the serialize_report() function from xdist (ca03269).
"""
- def serialize_repr_entry(entry):
- entry_data = {"type": type(entry).__name__, "data": attr.asdict(entry)}
- for key, value in entry_data["data"].items():
+ def serialize_repr_entry(
+ entry: Union[ReprEntry, ReprEntryNative]
+ ) -> Dict[str, Any]:
+ data = attr.asdict(entry)
+ for key, value in data.items():
if hasattr(value, "__dict__"):
- entry_data["data"][key] = attr.asdict(value)
+ data[key] = attr.asdict(value)
+ entry_data = {"type": type(entry).__name__, "data": data}
return entry_data
- def serialize_repr_traceback(reprtraceback: ReprTraceback):
+ def serialize_repr_traceback(reprtraceback: ReprTraceback) -> Dict[str, Any]:
result = attr.asdict(reprtraceback)
result["reprentries"] = [
serialize_repr_entry(x) for x in reprtraceback.reprentries
]
return result
- def serialize_repr_crash(reprcrash: Optional[ReprFileLocation]):
+ def serialize_repr_crash(
+ reprcrash: Optional[ReprFileLocation],
+ ) -> Optional[Dict[str, Any]]:
if reprcrash is not None:
return attr.asdict(reprcrash)
else:
return None
- def serialize_longrepr(rep):
- result = {
- "reprcrash": serialize_repr_crash(rep.longrepr.reprcrash),
- "reprtraceback": serialize_repr_traceback(rep.longrepr.reprtraceback),
- "sections": rep.longrepr.sections,
+ def serialize_exception_longrepr(rep: BaseReport) -> Dict[str, Any]:
+ assert rep.longrepr is not None
+ # TODO: Investigate whether the duck typing is really necessary here.
+ longrepr = cast(ExceptionRepr, rep.longrepr)
+ result: Dict[str, Any] = {
+ "reprcrash": serialize_repr_crash(longrepr.reprcrash),
+ "reprtraceback": serialize_repr_traceback(longrepr.reprtraceback),
+ "sections": longrepr.sections,
}
- if isinstance(rep.longrepr, ExceptionChainRepr):
+ if isinstance(longrepr, ExceptionChainRepr):
result["chain"] = []
- for repr_traceback, repr_crash, description in rep.longrepr.chain:
+ for repr_traceback, repr_crash, description in longrepr.chain:
result["chain"].append(
(
serialize_repr_traceback(repr_traceback),
@@ -413,7 +478,7 @@ def _report_to_json(report):
if hasattr(report.longrepr, "reprtraceback") and hasattr(
report.longrepr, "reprcrash"
):
- d["longrepr"] = serialize_longrepr(report)
+ d["longrepr"] = serialize_exception_longrepr(report)
else:
d["longrepr"] = str(report.longrepr)
else:
@@ -426,11 +491,11 @@ def _report_to_json(report):
return d
-def _report_kwargs_from_json(reportdict):
- """
- This was originally the serialize_report() function from xdist (ca03269).
+def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
+ """Return **kwargs that can be used to construct a TestReport or
+ CollectReport instance.
- Returns **kwargs that can be used to construct a TestReport or CollectReport instance.
+ This was originally the serialize_report() function from xdist (ca03269).
"""
def deserialize_repr_entry(entry_data):
@@ -447,13 +512,13 @@ def _report_kwargs_from_json(reportdict):
if data["reprlocals"]:
reprlocals = ReprLocals(data["reprlocals"]["lines"])
- reprentry = ReprEntry(
+ reprentry: Union[ReprEntry, ReprEntryNative] = ReprEntry(
lines=data["lines"],
reprfuncargs=reprfuncargs,
reprlocals=reprlocals,
reprfileloc=reprfileloc,
style=data["style"],
- ) # type: Union[ReprEntry, ReprEntryNative]
+ )
elif entry_type == "ReprEntryNative":
reprentry = ReprEntryNative(data["lines"])
else:
@@ -466,7 +531,7 @@ def _report_kwargs_from_json(reportdict):
]
return ReprTraceback(**repr_traceback_dict)
- def deserialize_repr_crash(repr_crash_dict: Optional[dict]):
+ def deserialize_repr_crash(repr_crash_dict: Optional[Dict[str, Any]]):
if repr_crash_dict is not None:
return ReprFileLocation(**repr_crash_dict)
else:
@@ -494,9 +559,9 @@ def _report_kwargs_from_json(reportdict):
description,
)
)
- exception_info = ExceptionChainRepr(
- chain
- ) # type: Union[ExceptionChainRepr,ReprExceptionInfo]
+ exception_info: Union[
+ ExceptionChainRepr, ReprExceptionInfo
+ ] = ExceptionChainRepr(chain)
else:
exception_info = ReprExceptionInfo(reprtraceback, reprcrash)
diff --git a/contrib/python/pytest/py3/_pytest/resultlog.py b/contrib/python/pytest/py3/_pytest/resultlog.py
deleted file mode 100644
index 3cfa9e0e96..0000000000
--- a/contrib/python/pytest/py3/_pytest/resultlog.py
+++ /dev/null
@@ -1,102 +0,0 @@
-""" log machine-parseable test session result information in a plain
-text file.
-"""
-import os
-
-import py
-
-from _pytest.store import StoreKey
-
-
-resultlog_key = StoreKey["ResultLog"]()
-
-
-def pytest_addoption(parser):
- group = parser.getgroup("terminal reporting", "resultlog plugin options")
- group.addoption(
- "--resultlog",
- "--result-log",
- action="store",
- metavar="path",
- default=None,
- help="DEPRECATED path for machine-readable result log.",
- )
-
-
-def pytest_configure(config):
- resultlog = config.option.resultlog
- # prevent opening resultlog on slave nodes (xdist)
- if resultlog and not hasattr(config, "slaveinput"):
- dirname = os.path.dirname(os.path.abspath(resultlog))
- if not os.path.isdir(dirname):
- os.makedirs(dirname)
- logfile = open(resultlog, "w", 1) # line buffered
- config._store[resultlog_key] = ResultLog(config, logfile)
- config.pluginmanager.register(config._store[resultlog_key])
-
- from _pytest.deprecated import RESULT_LOG
- from _pytest.warnings import _issue_warning_captured
-
- _issue_warning_captured(RESULT_LOG, config.hook, stacklevel=2)
-
-
-def pytest_unconfigure(config):
- resultlog = config._store.get(resultlog_key, None)
- if resultlog:
- resultlog.logfile.close()
- del config._store[resultlog_key]
- config.pluginmanager.unregister(resultlog)
-
-
-class ResultLog:
- def __init__(self, config, logfile):
- self.config = config
- self.logfile = logfile # preferably line buffered
-
- def write_log_entry(self, testpath, lettercode, longrepr):
- print("{} {}".format(lettercode, testpath), file=self.logfile)
- for line in longrepr.splitlines():
- print(" %s" % line, file=self.logfile)
-
- def log_outcome(self, report, lettercode, longrepr):
- testpath = getattr(report, "nodeid", None)
- if testpath is None:
- testpath = report.fspath
- self.write_log_entry(testpath, lettercode, longrepr)
-
- def pytest_runtest_logreport(self, report):
- if report.when != "call" and report.passed:
- return
- res = self.config.hook.pytest_report_teststatus(
- report=report, config=self.config
- )
- code = res[1]
- if code == "x":
- longrepr = str(report.longrepr)
- elif code == "X":
- longrepr = ""
- elif report.passed:
- longrepr = ""
- elif report.skipped:
- longrepr = str(report.longrepr[2])
- else:
- longrepr = str(report.longrepr)
- self.log_outcome(report, code, longrepr)
-
- def pytest_collectreport(self, report):
- if not report.passed:
- if report.failed:
- code = "F"
- longrepr = str(report.longrepr)
- else:
- assert report.skipped
- code = "S"
- longrepr = "%s:%d: %s" % report.longrepr
- self.log_outcome(report, code, longrepr)
-
- def pytest_internalerror(self, excrepr):
- reprcrash = getattr(excrepr, "reprcrash", None)
- path = getattr(reprcrash, "path", None)
- if path is None:
- path = "cwd:%s" % py.path.local()
- self.write_log_entry(path, "!", str(excrepr))
diff --git a/contrib/python/pytest/py3/_pytest/runner.py b/contrib/python/pytest/py3/_pytest/runner.py
index 412ea44a87..794690ddb0 100644
--- a/contrib/python/pytest/py3/_pytest/runner.py
+++ b/contrib/python/pytest/py3/_pytest/runner.py
@@ -1,37 +1,49 @@
-""" basic collect and runtest protocol implementations """
+"""Basic collect and runtest protocol implementations."""
import bdb
import os
import sys
-from time import time
from typing import Callable
+from typing import cast
from typing import Dict
+from typing import Generic
from typing import List
from typing import Optional
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
import attr
+from .reports import BaseReport
from .reports import CollectErrorRepr
from .reports import CollectReport
from .reports import TestReport
+from _pytest import timing
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
-from _pytest.compat import TYPE_CHECKING
+from _pytest._code.code import TerminalRepr
+from _pytest.compat import final
+from _pytest.config.argparsing import Parser
from _pytest.nodes import Collector
+from _pytest.nodes import Item
from _pytest.nodes import Node
from _pytest.outcomes import Exit
from _pytest.outcomes import Skipped
from _pytest.outcomes import TEST_OUTCOME
if TYPE_CHECKING:
- from typing import Type
from typing_extensions import Literal
+ from _pytest.main import Session
+ from _pytest.terminal import TerminalReporter
+
#
-# pytest plugin hooks
+# pytest plugin hooks.
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "reporting", after="general")
group.addoption(
"--durations",
@@ -41,10 +53,19 @@ def pytest_addoption(parser):
metavar="N",
help="show N slowest setup/test durations (N=0 for all).",
)
+ group.addoption(
+ "--durations-min",
+ action="store",
+ type=float,
+ default=0.005,
+ metavar="N",
+ help="Minimal duration in seconds for inclusion in slowest list. Default 0.005",
+ )
-def pytest_terminal_summary(terminalreporter):
+def pytest_terminal_summary(terminalreporter: "TerminalReporter") -> None:
durations = terminalreporter.config.option.durations
+ durations_min = terminalreporter.config.option.durations_min
verbose = terminalreporter.config.getvalue("verbose")
if durations is None:
return
@@ -56,41 +77,46 @@ def pytest_terminal_summary(terminalreporter):
dlist.append(rep)
if not dlist:
return
- dlist.sort(key=lambda x: x.duration)
- dlist.reverse()
+ dlist.sort(key=lambda x: x.duration, reverse=True) # type: ignore[no-any-return]
if not durations:
- tr.write_sep("=", "slowest test durations")
+ tr.write_sep("=", "slowest durations")
else:
- tr.write_sep("=", "slowest %s test durations" % durations)
+ tr.write_sep("=", "slowest %s durations" % durations)
dlist = dlist[:durations]
- for rep in dlist:
- if verbose < 2 and rep.duration < 0.005:
+ for i, rep in enumerate(dlist):
+ if verbose < 2 and rep.duration < durations_min:
tr.write_line("")
- tr.write_line("(0.00 durations hidden. Use -vv to show these durations.)")
+ tr.write_line(
+ "(%s durations < %gs hidden. Use -vv to show these durations.)"
+ % (len(dlist) - i, durations_min)
+ )
break
- tr.write_line("{:02.2f}s {:<8} {}".format(rep.duration, rep.when, rep.nodeid))
+ tr.write_line(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}")
-def pytest_sessionstart(session):
+def pytest_sessionstart(session: "Session") -> None:
session._setupstate = SetupState()
-def pytest_sessionfinish(session):
+def pytest_sessionfinish(session: "Session") -> None:
session._setupstate.teardown_all()
-def pytest_runtest_protocol(item, nextitem):
- item.ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
+def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
+ ihook = item.ihook
+ ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
runtestprotocol(item, nextitem=nextitem)
- item.ihook.pytest_runtest_logfinish(nodeid=item.nodeid, location=item.location)
+ ihook.pytest_runtest_logfinish(nodeid=item.nodeid, location=item.location)
return True
-def runtestprotocol(item, log=True, nextitem=None):
+def runtestprotocol(
+ item: Item, log: bool = True, nextitem: Optional[Item] = None
+) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
- if hasrequest and not item._request:
- item._initrequest()
+ if hasrequest and not item._request: # type: ignore[attr-defined]
+ item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
@@ -99,15 +125,15 @@ def runtestprotocol(item, log=True, nextitem=None):
if not item.config.getoption("setuponly", False):
reports.append(call_and_report(item, "call", log))
reports.append(call_and_report(item, "teardown", log, nextitem=nextitem))
- # after all teardown hooks have been called
- # want funcargs and request info to go away
+ # After all teardown hooks have been called
+ # want funcargs and request info to go away.
if hasrequest:
- item._request = False
- item.funcargs = None
+ item._request = False # type: ignore[attr-defined]
+ item.funcargs = None # type: ignore[attr-defined]
return reports
-def show_test_item(item):
+def show_test_item(item: Item) -> None:
"""Show test function, parameters and the fixtures of the test item."""
tw = item.config.get_terminal_writer()
tw.line()
@@ -116,14 +142,15 @@ def show_test_item(item):
used_fixtures = sorted(getattr(item, "fixturenames", []))
if used_fixtures:
tw.write(" (fixtures used: {})".format(", ".join(used_fixtures)))
+ tw.flush()
-def pytest_runtest_setup(item):
+def pytest_runtest_setup(item: Item) -> None:
_update_current_test_var(item, "setup")
item.session._setupstate.prepare(item)
-def pytest_runtest_call(item):
+def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
@@ -143,21 +170,22 @@ def pytest_runtest_call(item):
raise e
-def pytest_runtest_teardown(item, nextitem):
+def pytest_runtest_teardown(item: Item, nextitem: Optional[Item]) -> None:
_update_current_test_var(item, "teardown")
item.session._setupstate.teardown_exact(item, nextitem)
_update_current_test_var(item, None)
-def _update_current_test_var(item, when):
- """
- Update PYTEST_CURRENT_TEST to reflect the current item and stage.
+def _update_current_test_var(
+ item: Item, when: Optional["Literal['setup', 'call', 'teardown']"]
+) -> None:
+ """Update :envvar:`PYTEST_CURRENT_TEST` to reflect the current item and stage.
- If ``when`` is None, delete PYTEST_CURRENT_TEST from the environment.
+ If ``when`` is None, delete ``PYTEST_CURRENT_TEST`` from the environment.
"""
var_name = "PYTEST_CURRENT_TEST"
if when:
- value = "{} ({})".format(item.nodeid, when)
+ value = f"{item.nodeid} ({when})"
# don't allow null bytes on environment variables (see #2644, #2957)
value = value.replace("\x00", "(null)")
os.environ[var_name] = value
@@ -165,7 +193,7 @@ def _update_current_test_var(item, when):
os.environ.pop(var_name)
-def pytest_report_teststatus(report):
+def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if report.when in ("setup", "teardown"):
if report.failed:
# category, shortletter, verbose-word
@@ -174,6 +202,7 @@ def pytest_report_teststatus(report):
return "skipped", "s", "SKIPPED"
else:
return "", "", ""
+ return None
#
@@ -181,11 +210,11 @@ def pytest_report_teststatus(report):
def call_and_report(
- item, when: "Literal['setup', 'call', 'teardown']", log=True, **kwds
-):
+ item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
+) -> TestReport:
call = call_runtest_hook(item, when, **kwds)
hook = item.ihook
- report = hook.pytest_runtest_makereport(item=item, call=call)
+ report: TestReport = hook.pytest_runtest_makereport(item=item, call=call)
if log:
hook.pytest_runtest_logreport(report=report)
if check_interactive_exception(call, report):
@@ -193,24 +222,33 @@ def call_and_report(
return report
-def check_interactive_exception(call, report):
- return call.excinfo and not (
- hasattr(report, "wasxfail")
- or call.excinfo.errisinstance(Skipped)
- or call.excinfo.errisinstance(bdb.BdbQuit)
- )
+def check_interactive_exception(call: "CallInfo[object]", report: BaseReport) -> bool:
+ """Check whether the call raised an exception that should be reported as
+ interactive."""
+ if call.excinfo is None:
+ # Didn't raise.
+ return False
+ if hasattr(report, "wasxfail"):
+ # Exception was expected.
+ return False
+ if isinstance(call.excinfo.value, (Skipped, bdb.BdbQuit)):
+ # Special control flow exception.
+ return False
+ return True
-def call_runtest_hook(item, when: "Literal['setup', 'call', 'teardown']", **kwds):
+def call_runtest_hook(
+ item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
+) -> "CallInfo[None]":
if when == "setup":
- ihook = item.ihook.pytest_runtest_setup
+ ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
- assert False, "Unhandled runtest hook case: {}".format(when)
- reraise = (Exit,) # type: Tuple[Type[BaseException], ...]
+ assert False, f"Unhandled runtest hook case: {when}"
+ reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
return CallInfo.from_call(
@@ -218,60 +256,99 @@ def call_runtest_hook(item, when: "Literal['setup', 'call', 'teardown']", **kwds
)
+TResult = TypeVar("TResult", covariant=True)
+
+
+@final
@attr.s(repr=False)
-class CallInfo:
- """ Result/Exception info a function invocation. """
+class CallInfo(Generic[TResult]):
+ """Result/Exception info a function invocation.
+
+ :param T result:
+ The return value of the call, if it didn't raise. Can only be
+ accessed if excinfo is None.
+ :param Optional[ExceptionInfo] excinfo:
+ The captured exception of the call, if it raised.
+ :param float start:
+ The system time when the call started, in seconds since the epoch.
+ :param float stop:
+ The system time when the call ended, in seconds since the epoch.
+ :param float duration:
+ The call duration, in seconds.
+ :param str when:
+ The context of invocation: "setup", "call", "teardown", ...
+ """
- _result = attr.ib()
- excinfo = attr.ib(type=Optional[ExceptionInfo])
- start = attr.ib()
- stop = attr.ib()
- when = attr.ib()
+ _result = attr.ib(type="Optional[TResult]")
+ excinfo = attr.ib(type=Optional[ExceptionInfo[BaseException]])
+ start = attr.ib(type=float)
+ stop = attr.ib(type=float)
+ duration = attr.ib(type=float)
+ when = attr.ib(type="Literal['collect', 'setup', 'call', 'teardown']")
@property
- def result(self):
+ def result(self) -> TResult:
if self.excinfo is not None:
- raise AttributeError("{!r} has no valid result".format(self))
- return self._result
+ raise AttributeError(f"{self!r} has no valid result")
+ # The cast is safe because an exception wasn't raised, hence
+ # _result has the expected function return type (which may be
+ # None, that's why a cast and not an assert).
+ return cast(TResult, self._result)
@classmethod
- def from_call(cls, func, when, reraise=None) -> "CallInfo":
- #: context of invocation: one of "setup", "call",
- #: "teardown", "memocollect"
- start = time()
+ def from_call(
+ cls,
+ func: "Callable[[], TResult]",
+ when: "Literal['collect', 'setup', 'call', 'teardown']",
+ reraise: Optional[
+ Union[Type[BaseException], Tuple[Type[BaseException], ...]]
+ ] = None,
+ ) -> "CallInfo[TResult]":
excinfo = None
+ start = timing.time()
+ precise_start = timing.perf_counter()
try:
- result = func()
- except: # noqa
+ result: Optional[TResult] = func()
+ except BaseException:
excinfo = ExceptionInfo.from_current()
- if reraise is not None and excinfo.errisinstance(reraise):
+ if reraise is not None and isinstance(excinfo.value, reraise):
raise
result = None
- stop = time()
- return cls(start=start, stop=stop, when=when, result=result, excinfo=excinfo)
-
- def __repr__(self):
+ # use the perf counter
+ precise_stop = timing.perf_counter()
+ duration = precise_stop - precise_start
+ stop = timing.time()
+ return cls(
+ start=start,
+ stop=stop,
+ duration=duration,
+ when=when,
+ result=result,
+ excinfo=excinfo,
+ )
+
+ def __repr__(self) -> str:
if self.excinfo is None:
- return "<CallInfo when={!r} result: {!r}>".format(self.when, self._result)
- return "<CallInfo when={!r} excinfo={!r}>".format(self.when, self.excinfo)
+ return f"<CallInfo when={self.when!r} result: {self._result!r}>"
+ return f"<CallInfo when={self.when!r} excinfo={self.excinfo!r}>"
-def pytest_runtest_makereport(item, call):
+def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
return TestReport.from_item_and_call(item, call)
def pytest_make_collect_report(collector: Collector) -> CollectReport:
call = CallInfo.from_call(lambda: list(collector.collect()), "collect")
- longrepr = None
+ longrepr: Union[None, Tuple[str, int, str], str, TerminalRepr] = None
if not call.excinfo:
- outcome = "passed"
+ outcome: Literal["passed", "skipped", "failed"] = "passed"
else:
skip_exceptions = [Skipped]
unittest = sys.modules.get("unittest")
if unittest is not None:
# Type ignored because unittest is loaded dynamically.
skip_exceptions.append(unittest.SkipTest) # type: ignore
- if call.excinfo.errisinstance(tuple(skip_exceptions)):
+ if isinstance(call.excinfo.value, tuple(skip_exceptions)):
outcome = "skipped"
r_ = collector._repr_failure_py(call.excinfo, "line")
assert isinstance(r_, ExceptionChainRepr), repr(r_)
@@ -282,24 +359,24 @@ def pytest_make_collect_report(collector: Collector) -> CollectReport:
outcome = "failed"
errorinfo = collector.repr_failure(call.excinfo)
if not hasattr(errorinfo, "toterminal"):
+ assert isinstance(errorinfo, str)
errorinfo = CollectErrorRepr(errorinfo)
longrepr = errorinfo
- rep = CollectReport(
- collector.nodeid, outcome, longrepr, getattr(call, "result", None)
- )
+ result = call.result if not call.excinfo else None
+ rep = CollectReport(collector.nodeid, outcome, longrepr, result)
rep.call = call # type: ignore # see collect_one_node
return rep
class SetupState:
- """ shared state for setting up/tearing down test items or collectors. """
+ """Shared state for setting up/tearing down test items or collectors."""
def __init__(self):
- self.stack = [] # type: List[Node]
- self._finalizers = {} # type: Dict[Node, List[Callable[[], None]]]
+ self.stack: List[Node] = []
+ self._finalizers: Dict[Node, List[Callable[[], object]]] = {}
- def addfinalizer(self, finalizer, colitem):
- """ attach a finalizer to the given colitem. """
+ def addfinalizer(self, finalizer: Callable[[], object], colitem) -> None:
+ """Attach a finalizer to the given colitem."""
assert colitem and not isinstance(colitem, tuple)
assert callable(finalizer)
# assert colitem in self.stack # some unit tests don't setup stack :/
@@ -309,7 +386,7 @@ class SetupState:
colitem = self.stack.pop()
self._teardown_with_finalization(colitem)
- def _callfinalizers(self, colitem):
+ def _callfinalizers(self, colitem) -> None:
finalizers = self._finalizers.pop(colitem, None)
exc = None
while finalizers:
@@ -324,24 +401,24 @@ class SetupState:
if exc:
raise exc
- def _teardown_with_finalization(self, colitem):
+ def _teardown_with_finalization(self, colitem) -> None:
self._callfinalizers(colitem)
colitem.teardown()
for colitem in self._finalizers:
assert colitem in self.stack
- def teardown_all(self):
+ def teardown_all(self) -> None:
while self.stack:
self._pop_and_teardown()
for key in list(self._finalizers):
self._teardown_with_finalization(key)
assert not self._finalizers
- def teardown_exact(self, item, nextitem):
+ def teardown_exact(self, item, nextitem) -> None:
needed_collectors = nextitem and nextitem.listchain() or []
self._teardown_towards(needed_collectors)
- def _teardown_towards(self, needed_collectors):
+ def _teardown_towards(self, needed_collectors) -> None:
exc = None
while self.stack:
if self.stack == needed_collectors[: len(self.stack)]:
@@ -356,30 +433,29 @@ class SetupState:
if exc:
raise exc
- def prepare(self, colitem):
- """ setup objects along the collector chain to the test-method
- and teardown previously setup objects."""
- needed_collectors = colitem.listchain()
- self._teardown_towards(needed_collectors)
+ def prepare(self, colitem) -> None:
+ """Setup objects along the collector chain to the test-method."""
- # check if the last collection node has raised an error
+ # Check if the last collection node has raised an error.
for col in self.stack:
if hasattr(col, "_prepare_exc"):
- exc = col._prepare_exc
+ exc = col._prepare_exc # type: ignore[attr-defined]
raise exc
+
+ needed_collectors = colitem.listchain()
for col in needed_collectors[len(self.stack) :]:
self.stack.append(col)
try:
col.setup()
except TEST_OUTCOME as e:
- col._prepare_exc = e
+ col._prepare_exc = e # type: ignore[attr-defined]
raise e
-def collect_one_node(collector):
+def collect_one_node(collector: Collector) -> CollectReport:
ihook = collector.ihook
ihook.pytest_collectstart(collector=collector)
- rep = ihook.pytest_make_collect_report(collector=collector)
+ rep: CollectReport = ihook.pytest_make_collect_report(collector=collector)
call = rep.__dict__.pop("call", None)
if call and check_interactive_exception(call, rep):
ihook.pytest_exception_interact(node=collector, call=call, report=rep)
diff --git a/contrib/python/pytest/py3/_pytest/setuponly.py b/contrib/python/pytest/py3/_pytest/setuponly.py
index aa5a95ff92..44a1094c0d 100644
--- a/contrib/python/pytest/py3/_pytest/setuponly.py
+++ b/contrib/python/pytest/py3/_pytest/setuponly.py
@@ -1,7 +1,17 @@
+from typing import Generator
+from typing import Optional
+from typing import Union
+
import pytest
+from _pytest._io.saferepr import saferepr
+from _pytest.config import Config
+from _pytest.config import ExitCode
+from _pytest.config.argparsing import Parser
+from _pytest.fixtures import FixtureDef
+from _pytest.fixtures import SubRequest
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--setuponly",
@@ -18,7 +28,9 @@ def pytest_addoption(parser):
@pytest.hookimpl(hookwrapper=True)
-def pytest_fixture_setup(fixturedef, request):
+def pytest_fixture_setup(
+ fixturedef: FixtureDef[object], request: SubRequest
+) -> Generator[None, None, None]:
yield
if request.config.option.setupshow:
if hasattr(request, "param"):
@@ -26,24 +38,25 @@ def pytest_fixture_setup(fixturedef, request):
# display it now and during the teardown (in .finish()).
if fixturedef.ids:
if callable(fixturedef.ids):
- fixturedef.cached_param = fixturedef.ids(request.param)
+ param = fixturedef.ids(request.param)
else:
- fixturedef.cached_param = fixturedef.ids[request.param_index]
+ param = fixturedef.ids[request.param_index]
else:
- fixturedef.cached_param = request.param
+ param = request.param
+ fixturedef.cached_param = param # type: ignore[attr-defined]
_show_fixture_action(fixturedef, "SETUP")
-def pytest_fixture_post_finalizer(fixturedef) -> None:
+def pytest_fixture_post_finalizer(fixturedef: FixtureDef[object]) -> None:
if fixturedef.cached_result is not None:
config = fixturedef._fixturemanager.config
if config.option.setupshow:
_show_fixture_action(fixturedef, "TEARDOWN")
if hasattr(fixturedef, "cached_param"):
- del fixturedef.cached_param
+ del fixturedef.cached_param # type: ignore[attr-defined]
-def _show_fixture_action(fixturedef, msg):
+def _show_fixture_action(fixturedef: FixtureDef[object], msg: str) -> None:
config = fixturedef._fixturemanager.config
capman = config.pluginmanager.getplugin("capturemanager")
if capman:
@@ -66,13 +79,16 @@ def _show_fixture_action(fixturedef, msg):
tw.write(" (fixtures used: {})".format(", ".join(deps)))
if hasattr(fixturedef, "cached_param"):
- tw.write("[{}]".format(fixturedef.cached_param))
+ tw.write("[{}]".format(saferepr(fixturedef.cached_param, maxsize=42))) # type: ignore[attr-defined]
+
+ tw.flush()
if capman:
capman.resume_global_capture()
@pytest.hookimpl(tryfirst=True)
-def pytest_cmdline_main(config):
+def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.setuponly:
config.option.setupshow = True
+ return None
diff --git a/contrib/python/pytest/py3/_pytest/setupplan.py b/contrib/python/pytest/py3/_pytest/setupplan.py
index 6fdd3aed06..9ba81ccaf0 100644
--- a/contrib/python/pytest/py3/_pytest/setupplan.py
+++ b/contrib/python/pytest/py3/_pytest/setupplan.py
@@ -1,7 +1,15 @@
+from typing import Optional
+from typing import Union
+
import pytest
+from _pytest.config import Config
+from _pytest.config import ExitCode
+from _pytest.config.argparsing import Parser
+from _pytest.fixtures import FixtureDef
+from _pytest.fixtures import SubRequest
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--setupplan",
@@ -13,16 +21,20 @@ def pytest_addoption(parser):
@pytest.hookimpl(tryfirst=True)
-def pytest_fixture_setup(fixturedef, request):
+def pytest_fixture_setup(
+ fixturedef: FixtureDef[object], request: SubRequest
+) -> Optional[object]:
# Will return a dummy fixture if the setuponly option is provided.
if request.config.option.setupplan:
my_cache_key = fixturedef.cache_key(request)
fixturedef.cached_result = (None, my_cache_key, None)
return fixturedef.cached_result
+ return None
@pytest.hookimpl(tryfirst=True)
-def pytest_cmdline_main(config):
+def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.setupplan:
config.option.setuponly = True
config.option.setupshow = True
+ return None
diff --git a/contrib/python/pytest/py3/_pytest/skipping.py b/contrib/python/pytest/py3/_pytest/skipping.py
index fe8742c667..9aacfecee7 100644
--- a/contrib/python/pytest/py3/_pytest/skipping.py
+++ b/contrib/python/pytest/py3/_pytest/skipping.py
@@ -1,18 +1,30 @@
-""" support for skip/xfail functions and markers. """
+"""Support for skip/xfail functions and markers."""
+import os
+import platform
+import sys
+import traceback
+from collections.abc import Mapping
+from typing import Generator
+from typing import Optional
+from typing import Tuple
+from typing import Type
+
+import attr
+
+from _pytest.config import Config
from _pytest.config import hookimpl
-from _pytest.mark.evaluate import MarkEvaluator
+from _pytest.config.argparsing import Parser
+from _pytest.mark.structures import Mark
+from _pytest.nodes import Item
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
+from _pytest.reports import BaseReport
+from _pytest.runner import CallInfo
from _pytest.store import StoreKey
-skipped_by_mark_key = StoreKey[bool]()
-evalxfail_key = StoreKey[MarkEvaluator]()
-unexpectedsuccess_key = StoreKey[str]()
-
-
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--runxfail",
@@ -31,7 +43,7 @@ def pytest_addoption(parser):
)
-def pytest_configure(config):
+def pytest_configure(config: Config) -> None:
if config.option.runxfail:
# yay a hack
import pytest
@@ -42,7 +54,7 @@ def pytest_configure(config):
def nop(*args, **kwargs):
pass
- nop.Exception = xfail.Exception
+ nop.Exception = xfail.Exception # type: ignore[attr-defined]
setattr(pytest, "xfail", nop)
config.addinivalue_line(
@@ -53,131 +65,260 @@ def pytest_configure(config):
)
config.addinivalue_line(
"markers",
- "skipif(condition): skip the given test function if eval(condition) "
- "results in a True value. Evaluation happens within the "
- "module global context. Example: skipif('sys.platform == \"win32\"') "
- "skips the test if we are on the win32 platform. see "
- "https://docs.pytest.org/en/latest/skipping.html",
+ "skipif(condition, ..., *, reason=...): "
+ "skip the given test function if any of the conditions evaluate to True. "
+ "Example: skipif(sys.platform == 'win32') skips the test if we are on the win32 platform. "
+ "See https://docs.pytest.org/en/stable/reference.html#pytest-mark-skipif",
)
config.addinivalue_line(
"markers",
- "xfail(condition, reason=None, run=True, raises=None, strict=False): "
- "mark the test function as an expected failure if eval(condition) "
- "has a True value. Optionally specify a reason for better reporting "
+ "xfail(condition, ..., *, reason=..., run=True, raises=None, strict=xfail_strict): "
+ "mark the test function as an expected failure if any of the conditions "
+ "evaluate to True. Optionally specify a reason for better reporting "
"and run=False if you don't even want to execute the test function. "
"If only specific exception(s) are expected, you can list them in "
"raises, and if the test fails in other ways, it will be reported as "
- "a true failure. See https://docs.pytest.org/en/latest/skipping.html",
+ "a true failure. See https://docs.pytest.org/en/stable/reference.html#pytest-mark-xfail",
)
-@hookimpl(tryfirst=True)
-def pytest_runtest_setup(item):
- # Check if skip or skipif are specified as pytest marks
- item._store[skipped_by_mark_key] = False
- eval_skipif = MarkEvaluator(item, "skipif")
- if eval_skipif.istrue():
- item._store[skipped_by_mark_key] = True
- skip(eval_skipif.getexplanation())
-
- for skip_info in item.iter_markers(name="skip"):
- item._store[skipped_by_mark_key] = True
- if "reason" in skip_info.kwargs:
- skip(skip_info.kwargs["reason"])
- elif skip_info.args:
- skip(skip_info.args[0])
+def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool, str]:
+ """Evaluate a single skipif/xfail condition.
+
+ If an old-style string condition is given, it is eval()'d, otherwise the
+ condition is bool()'d. If this fails, an appropriately formatted pytest.fail
+ is raised.
+
+ Returns (result, reason). The reason is only relevant if the result is True.
+ """
+ # String condition.
+ if isinstance(condition, str):
+ globals_ = {
+ "os": os,
+ "sys": sys,
+ "platform": platform,
+ "config": item.config,
+ }
+ for dictionary in reversed(
+ item.ihook.pytest_markeval_namespace(config=item.config)
+ ):
+ if not isinstance(dictionary, Mapping):
+ raise ValueError(
+ "pytest_markeval_namespace() needs to return a dict, got {!r}".format(
+ dictionary
+ )
+ )
+ globals_.update(dictionary)
+ if hasattr(item, "obj"):
+ globals_.update(item.obj.__globals__) # type: ignore[attr-defined]
+ try:
+ filename = f"<{mark.name} condition>"
+ condition_code = compile(condition, filename, "eval")
+ result = eval(condition_code, globals_)
+ except SyntaxError as exc:
+ msglines = [
+ "Error evaluating %r condition" % mark.name,
+ " " + condition,
+ " " + " " * (exc.offset or 0) + "^",
+ "SyntaxError: invalid syntax",
+ ]
+ fail("\n".join(msglines), pytrace=False)
+ except Exception as exc:
+ msglines = [
+ "Error evaluating %r condition" % mark.name,
+ " " + condition,
+ *traceback.format_exception_only(type(exc), exc),
+ ]
+ fail("\n".join(msglines), pytrace=False)
+
+ # Boolean condition.
+ else:
+ try:
+ result = bool(condition)
+ except Exception as exc:
+ msglines = [
+ "Error evaluating %r condition as a boolean" % mark.name,
+ *traceback.format_exception_only(type(exc), exc),
+ ]
+ fail("\n".join(msglines), pytrace=False)
+
+ reason = mark.kwargs.get("reason", None)
+ if reason is None:
+ if isinstance(condition, str):
+ reason = "condition: " + condition
else:
- skip("unconditional skip")
+ # XXX better be checked at collection time
+ msg = (
+ "Error evaluating %r: " % mark.name
+ + "you need to specify reason=STRING when using booleans as conditions."
+ )
+ fail(msg, pytrace=False)
- item._store[evalxfail_key] = MarkEvaluator(item, "xfail")
- check_xfail_no_run(item)
+ return result, reason
-@hookimpl(hookwrapper=True)
-def pytest_pyfunc_call(pyfuncitem):
- check_xfail_no_run(pyfuncitem)
- outcome = yield
- passed = outcome.excinfo is None
- if passed:
- check_strict_xfail(pyfuncitem)
+@attr.s(slots=True, frozen=True)
+class Skip:
+ """The result of evaluate_skip_marks()."""
+
+ reason = attr.ib(type=str)
+
+def evaluate_skip_marks(item: Item) -> Optional[Skip]:
+ """Evaluate skip and skipif marks on item, returning Skip if triggered."""
+ for mark in item.iter_markers(name="skipif"):
+ if "condition" not in mark.kwargs:
+ conditions = mark.args
+ else:
+ conditions = (mark.kwargs["condition"],)
+
+ # Unconditional.
+ if not conditions:
+ reason = mark.kwargs.get("reason", "")
+ return Skip(reason)
-def check_xfail_no_run(item):
- """check xfail(run=False)"""
- if not item.config.option.runxfail:
- evalxfail = item._store[evalxfail_key]
- if evalxfail.istrue():
- if not evalxfail.get("run", True):
- xfail("[NOTRUN] " + evalxfail.getexplanation())
+ # If any of the conditions are true.
+ for condition in conditions:
+ result, reason = evaluate_condition(item, mark, condition)
+ if result:
+ return Skip(reason)
+ for mark in item.iter_markers(name="skip"):
+ if "reason" in mark.kwargs:
+ reason = mark.kwargs["reason"]
+ elif mark.args:
+ reason = mark.args[0]
+ else:
+ reason = "unconditional skip"
+ return Skip(reason)
-def check_strict_xfail(pyfuncitem):
- """check xfail(strict=True) for the given PASSING test"""
- evalxfail = pyfuncitem._store[evalxfail_key]
- if evalxfail.istrue():
- strict_default = pyfuncitem.config.getini("xfail_strict")
- is_strict_xfail = evalxfail.get("strict", strict_default)
- if is_strict_xfail:
- del pyfuncitem._store[evalxfail_key]
- explanation = evalxfail.getexplanation()
- fail("[XPASS(strict)] " + explanation, pytrace=False)
+ return None
+
+
+@attr.s(slots=True, frozen=True)
+class Xfail:
+ """The result of evaluate_xfail_marks()."""
+
+ reason = attr.ib(type=str)
+ run = attr.ib(type=bool)
+ strict = attr.ib(type=bool)
+ raises = attr.ib(type=Optional[Tuple[Type[BaseException], ...]])
+
+
+def evaluate_xfail_marks(item: Item) -> Optional[Xfail]:
+ """Evaluate xfail marks on item, returning Xfail if triggered."""
+ for mark in item.iter_markers(name="xfail"):
+ run = mark.kwargs.get("run", True)
+ strict = mark.kwargs.get("strict", item.config.getini("xfail_strict"))
+ raises = mark.kwargs.get("raises", None)
+ if "condition" not in mark.kwargs:
+ conditions = mark.args
+ else:
+ conditions = (mark.kwargs["condition"],)
+
+ # Unconditional.
+ if not conditions:
+ reason = mark.kwargs.get("reason", "")
+ return Xfail(reason, run, strict, raises)
+
+ # If any of the conditions are true.
+ for condition in conditions:
+ result, reason = evaluate_condition(item, mark, condition)
+ if result:
+ return Xfail(reason, run, strict, raises)
+
+ return None
+
+
+# Whether skipped due to skip or skipif marks.
+skipped_by_mark_key = StoreKey[bool]()
+# Saves the xfail mark evaluation. Can be refreshed during call if None.
+xfailed_key = StoreKey[Optional[Xfail]]()
+unexpectedsuccess_key = StoreKey[str]()
+
+
+@hookimpl(tryfirst=True)
+def pytest_runtest_setup(item: Item) -> None:
+ skipped = evaluate_skip_marks(item)
+ item._store[skipped_by_mark_key] = skipped is not None
+ if skipped:
+ skip(skipped.reason)
+
+ item._store[xfailed_key] = xfailed = evaluate_xfail_marks(item)
+ if xfailed and not item.config.option.runxfail and not xfailed.run:
+ xfail("[NOTRUN] " + xfailed.reason)
@hookimpl(hookwrapper=True)
-def pytest_runtest_makereport(item, call):
+def pytest_runtest_call(item: Item) -> Generator[None, None, None]:
+ xfailed = item._store.get(xfailed_key, None)
+ if xfailed is None:
+ item._store[xfailed_key] = xfailed = evaluate_xfail_marks(item)
+
+ if xfailed and not item.config.option.runxfail and not xfailed.run:
+ xfail("[NOTRUN] " + xfailed.reason)
+
+ yield
+
+ # The test run may have added an xfail mark dynamically.
+ xfailed = item._store.get(xfailed_key, None)
+ if xfailed is None:
+ item._store[xfailed_key] = xfailed = evaluate_xfail_marks(item)
+
+
+@hookimpl(hookwrapper=True)
+def pytest_runtest_makereport(item: Item, call: CallInfo[None]):
outcome = yield
rep = outcome.get_result()
- evalxfail = item._store.get(evalxfail_key, None)
+ xfailed = item._store.get(xfailed_key, None)
# unittest special case, see setting of unexpectedsuccess_key
if unexpectedsuccess_key in item._store and rep.when == "call":
reason = item._store[unexpectedsuccess_key]
if reason:
- rep.longrepr = "Unexpected success: {}".format(reason)
+ rep.longrepr = f"Unexpected success: {reason}"
else:
rep.longrepr = "Unexpected success"
rep.outcome = "failed"
-
elif item.config.option.runxfail:
pass # don't interfere
- elif call.excinfo and call.excinfo.errisinstance(xfail.Exception):
+ elif call.excinfo and isinstance(call.excinfo.value, xfail.Exception):
+ assert call.excinfo.value.msg is not None
rep.wasxfail = "reason: " + call.excinfo.value.msg
rep.outcome = "skipped"
- elif evalxfail and not rep.skipped and evalxfail.wasvalid() and evalxfail.istrue():
+ elif not rep.skipped and xfailed:
if call.excinfo:
- if evalxfail.invalidraise(call.excinfo.value):
+ raises = xfailed.raises
+ if raises is not None and not isinstance(call.excinfo.value, raises):
rep.outcome = "failed"
else:
rep.outcome = "skipped"
- rep.wasxfail = evalxfail.getexplanation()
+ rep.wasxfail = xfailed.reason
elif call.when == "call":
- strict_default = item.config.getini("xfail_strict")
- is_strict_xfail = evalxfail.get("strict", strict_default)
- explanation = evalxfail.getexplanation()
- if is_strict_xfail:
+ if xfailed.strict:
rep.outcome = "failed"
- rep.longrepr = "[XPASS(strict)] {}".format(explanation)
+ rep.longrepr = "[XPASS(strict)] " + xfailed.reason
else:
rep.outcome = "passed"
- rep.wasxfail = explanation
- elif (
+ rep.wasxfail = xfailed.reason
+
+ if (
item._store.get(skipped_by_mark_key, True)
and rep.skipped
and type(rep.longrepr) is tuple
):
- # skipped by mark.skipif; change the location of the failure
+ # Skipped by mark.skipif; change the location of the failure
# to point to the item definition, otherwise it will display
- # the location of where the skip exception was raised within pytest
+ # the location of where the skip exception was raised within pytest.
_, _, reason = rep.longrepr
- filename, line = item.location[:2]
- rep.longrepr = filename, line + 1, reason
-
-
-# called by terminalreporter progress reporting
+ filename, line = item.reportinfo()[:2]
+ assert line is not None
+ rep.longrepr = str(filename), line + 1, reason
-def pytest_report_teststatus(report):
+def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if hasattr(report, "wasxfail"):
if report.skipped:
return "xfailed", "x", "XFAIL"
elif report.passed:
return "xpassed", "X", "XPASS"
+ return None
diff --git a/contrib/python/pytest/py3/_pytest/stepwise.py b/contrib/python/pytest/py3/_pytest/stepwise.py
index 6fa21cd1c6..197577c790 100644
--- a/contrib/python/pytest/py3/_pytest/stepwise.py
+++ b/contrib/python/pytest/py3/_pytest/stepwise.py
@@ -1,79 +1,92 @@
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+
import pytest
+from _pytest import nodes
+from _pytest.config import Config
+from _pytest.config.argparsing import Parser
+from _pytest.main import Session
+from _pytest.reports import TestReport
+
+if TYPE_CHECKING:
+ from _pytest.cacheprovider import Cache
+STEPWISE_CACHE_DIR = "cache/stepwise"
-def pytest_addoption(parser):
+
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--sw",
"--stepwise",
action="store_true",
+ default=False,
dest="stepwise",
help="exit on test failure and continue from last failing test next time",
)
group.addoption(
+ "--sw-skip",
"--stepwise-skip",
action="store_true",
+ default=False,
dest="stepwise_skip",
help="ignore the first failing test but stop on the next failing test",
)
@pytest.hookimpl
-def pytest_configure(config):
- config.pluginmanager.register(StepwisePlugin(config), "stepwiseplugin")
+def pytest_configure(config: Config) -> None:
+ # We should always have a cache as cache provider plugin uses tryfirst=True
+ if config.getoption("stepwise"):
+ config.pluginmanager.register(StepwisePlugin(config), "stepwiseplugin")
+
+
+def pytest_sessionfinish(session: Session) -> None:
+ if not session.config.getoption("stepwise"):
+ assert session.config.cache is not None
+ # Clear the list of failing tests if the plugin is not active.
+ session.config.cache.set(STEPWISE_CACHE_DIR, [])
class StepwisePlugin:
- def __init__(self, config):
+ def __init__(self, config: Config) -> None:
self.config = config
- self.active = config.getvalue("stepwise")
- self.session = None
+ self.session: Optional[Session] = None
self.report_status = ""
+ assert config.cache is not None
+ self.cache: Cache = config.cache
+ self.lastfailed: Optional[str] = self.cache.get(STEPWISE_CACHE_DIR, None)
+ self.skip: bool = config.getoption("stepwise_skip")
- if self.active:
- self.lastfailed = config.cache.get("cache/stepwise", None)
- self.skip = config.getvalue("stepwise_skip")
-
- def pytest_sessionstart(self, session):
+ def pytest_sessionstart(self, session: Session) -> None:
self.session = session
- def pytest_collection_modifyitems(self, session, config, items):
- if not self.active:
- return
+ def pytest_collection_modifyitems(
+ self, config: Config, items: List[nodes.Item]
+ ) -> None:
if not self.lastfailed:
self.report_status = "no previously failed tests, not skipping."
return
- already_passed = []
- found = False
-
- # Make a list of all tests that have been run before the last failing one.
- for item in items:
+ # check all item nodes until we find a match on last failed
+ failed_index = None
+ for index, item in enumerate(items):
if item.nodeid == self.lastfailed:
- found = True
+ failed_index = index
break
- else:
- already_passed.append(item)
# If the previously failed test was not found among the test items,
# do not skip any tests.
- if not found:
+ if failed_index is None:
self.report_status = "previously failed test not found, not skipping."
- already_passed = []
else:
- self.report_status = "skipping {} already passed items.".format(
- len(already_passed)
- )
-
- for item in already_passed:
- items.remove(item)
-
- config.hook.pytest_deselected(items=already_passed)
-
- def pytest_runtest_logreport(self, report):
- if not self.active:
- return
+ self.report_status = f"skipping {failed_index} already passed items."
+ deselected = items[:failed_index]
+ del items[:failed_index]
+ config.hook.pytest_deselected(items=deselected)
+ def pytest_runtest_logreport(self, report: TestReport) -> None:
if report.failed:
if self.skip:
# Remove test from the failed ones (if it exists) and unset the skip option
@@ -85,6 +98,7 @@ class StepwisePlugin:
else:
# Mark test as the last failing and interrupt the test session.
self.lastfailed = report.nodeid
+ assert self.session is not None
self.session.shouldstop = (
"Test failed, continuing from this test next run."
)
@@ -96,13 +110,10 @@ class StepwisePlugin:
if report.nodeid == self.lastfailed:
self.lastfailed = None
- def pytest_report_collectionfinish(self):
- if self.active and self.config.getoption("verbose") >= 0 and self.report_status:
- return "stepwise: %s" % self.report_status
+ def pytest_report_collectionfinish(self) -> Optional[str]:
+ if self.config.getoption("verbose") >= 0 and self.report_status:
+ return f"stepwise: {self.report_status}"
+ return None
- def pytest_sessionfinish(self, session):
- if self.active:
- self.config.cache.set("cache/stepwise", self.lastfailed)
- else:
- # Clear the list of failing tests if the plugin is not active.
- self.config.cache.set("cache/stepwise", [])
+ def pytest_sessionfinish(self) -> None:
+ self.cache.set(STEPWISE_CACHE_DIR, self.lastfailed)
diff --git a/contrib/python/pytest/py3/_pytest/store.py b/contrib/python/pytest/py3/_pytest/store.py
index eed50d103a..e5008cfc5a 100644
--- a/contrib/python/pytest/py3/_pytest/store.py
+++ b/contrib/python/pytest/py3/_pytest/store.py
@@ -27,7 +27,7 @@ class StoreKey(Generic[T]):
class Store:
"""Store is a type-safe heterogenous mutable mapping that
allows keys and value types to be defined separately from
- where it is defined.
+ where it (the Store) is created.
Usually you will be given an object which has a ``Store``:
@@ -77,13 +77,13 @@ class Store:
Good solution: module Internal adds a ``Store`` to the object. Module
External mints StoreKeys for its own keys. Module External stores and
- retrieves its data using its keys.
+ retrieves its data using these keys.
"""
__slots__ = ("_store",)
def __init__(self) -> None:
- self._store = {} # type: Dict[StoreKey[Any], object]
+ self._store: Dict[StoreKey[Any], object] = {}
def __setitem__(self, key: StoreKey[T], value: T) -> None:
"""Set a value for key."""
@@ -92,7 +92,7 @@ class Store:
def __getitem__(self, key: StoreKey[T]) -> T:
"""Get the value for key.
- Raises KeyError if the key wasn't set before.
+ Raises ``KeyError`` if the key wasn't set before.
"""
return cast(T, self._store[key])
@@ -116,10 +116,10 @@ class Store:
def __delitem__(self, key: StoreKey[T]) -> None:
"""Delete the value for key.
- Raises KeyError if the key wasn't set before.
+ Raises ``KeyError`` if the key wasn't set before.
"""
del self._store[key]
def __contains__(self, key: StoreKey[T]) -> bool:
- """Returns whether key was set."""
+ """Return whether key was set."""
return key in self._store
diff --git a/contrib/python/pytest/py3/_pytest/terminal.py b/contrib/python/pytest/py3/_pytest/terminal.py
index 812afd258b..fbfb09aecf 100644
--- a/contrib/python/pytest/py3/_pytest/terminal.py
+++ b/contrib/python/pytest/py3/_pytest/terminal.py
@@ -1,39 +1,61 @@
-""" terminal reporting of the full testing process.
+"""Terminal reporting of the full testing process.
This is a good source for looking at the various reporting hooks.
"""
import argparse
-import collections
import datetime
+import inspect
import platform
import sys
-import time
import warnings
+from collections import Counter
from functools import partial
+from pathlib import Path
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
+from typing import Generator
from typing import List
from typing import Mapping
from typing import Optional
+from typing import Sequence
from typing import Set
+from typing import TextIO
from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
import attr
import pluggy
import py
-from more_itertools import collapse
-import pytest
+import _pytest._version
from _pytest import nodes
-from _pytest._io import TerminalWriter
+from _pytest import timing
+from _pytest._code import ExceptionInfo
+from _pytest._code.code import ExceptionRepr
+from _pytest._io.wcwidth import wcswidth
+from _pytest.compat import final
+from _pytest.config import _PluggyPlugin
from _pytest.config import Config
from _pytest.config import ExitCode
-from _pytest.deprecated import TERMINALWRITER_WRITER
-from _pytest.main import Session
+from _pytest.config import hookimpl
+from _pytest.config.argparsing import Parser
+from _pytest.nodes import Item
+from _pytest.nodes import Node
+from _pytest.pathlib import absolutepath
+from _pytest.pathlib import bestrelpath
+from _pytest.reports import BaseReport
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
+if TYPE_CHECKING:
+ from typing_extensions import Literal
+
+ from _pytest.main import Session
+
+
REPORT_COLLECTING_RESOLUTION = 0.5
KNOWN_TYPES = (
@@ -51,14 +73,20 @@ _REPORTCHARS_DEFAULT = "fE"
class MoreQuietAction(argparse.Action):
- """
- a modified copy of the argparse count action which counts down and updates
- the legacy quiet attribute at the same time
+ """A modified copy of the argparse count action which counts down and updates
+ the legacy quiet attribute at the same time.
- used to unify verbosity handling
+ Used to unify verbosity handling.
"""
- def __init__(self, option_strings, dest, default=None, required=False, help=None):
+ def __init__(
+ self,
+ option_strings: Sequence[str],
+ dest: str,
+ default: object = None,
+ required: bool = False,
+ help: Optional[str] = None,
+ ) -> None:
super().__init__(
option_strings=option_strings,
dest=dest,
@@ -68,14 +96,20 @@ class MoreQuietAction(argparse.Action):
help=help,
)
- def __call__(self, parser, namespace, values, option_string=None):
+ def __call__(
+ self,
+ parser: argparse.ArgumentParser,
+ namespace: argparse.Namespace,
+ values: Union[str, Sequence[object], None],
+ option_string: Optional[str] = None,
+ ) -> None:
new_count = getattr(namespace, self.dest, 0) - 1
setattr(namespace, self.dest, new_count)
# todo Deprecate config.quiet
namespace.quiet = getattr(namespace, "quiet", 0) + 1
-def pytest_addoption(parser):
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "reporting", after="general")
group._addoption(
"-v",
@@ -86,6 +120,20 @@ def pytest_addoption(parser):
help="increase verbosity.",
)
group._addoption(
+ "--no-header",
+ action="store_true",
+ default=False,
+ dest="no_header",
+ help="disable header",
+ )
+ group._addoption(
+ "--no-summary",
+ action="store_true",
+ default=False,
+ dest="no_summary",
+ help="disable summary",
+ )
+ group._addoption(
"-q",
"--quiet",
action=MoreQuietAction,
@@ -162,6 +210,12 @@ def pytest_addoption(parser):
choices=["yes", "no", "auto"],
help="color terminal output (yes/no/auto).",
)
+ group._addoption(
+ "--code-highlight",
+ default="yes",
+ choices=["yes", "no"],
+ help="Whether code should be highlighted (only if --color is also enabled)",
+ )
parser.addini(
"console_output_style",
@@ -183,7 +237,7 @@ def pytest_configure(config: Config) -> None:
def getreportopt(config: Config) -> str:
- reportchars = config.option.reportchars
+ reportchars: str = config.option.reportchars
old_aliases = {"F", "S"}
reportopts = ""
@@ -207,15 +261,15 @@ def getreportopt(config: Config) -> str:
return reportopts
-@pytest.hookimpl(trylast=True) # after _pytest.runner
-def pytest_report_teststatus(report: TestReport) -> Tuple[str, str, str]:
+@hookimpl(trylast=True) # after _pytest.runner
+def pytest_report_teststatus(report: BaseReport) -> Tuple[str, str, str]:
letter = "F"
if report.passed:
letter = "."
elif report.skipped:
letter = "s"
- outcome = report.outcome
+ outcome: str = report.outcome
if report.when in ("collect", "setup", "teardown") and outcome == "failed":
outcome = "error"
letter = "E"
@@ -225,127 +279,131 @@ def pytest_report_teststatus(report: TestReport) -> Tuple[str, str, str]:
@attr.s
class WarningReport:
- """
- Simple structure to hold warnings information captured by ``pytest_warning_captured``.
+ """Simple structure to hold warnings information captured by ``pytest_warning_recorded``.
- :ivar str message: user friendly message about the warning
- :ivar str|None nodeid: node id that generated the warning (see ``get_location``).
+ :ivar str message:
+ User friendly message about the warning.
+ :ivar str|None nodeid:
+ nodeid that generated the warning (see ``get_location``).
:ivar tuple|py.path.local fslocation:
- file system location of the source of the warning (see ``get_location``).
+ File system location of the source of the warning (see ``get_location``).
"""
message = attr.ib(type=str)
nodeid = attr.ib(type=Optional[str], default=None)
- fslocation = attr.ib(default=None)
+ fslocation = attr.ib(
+ type=Optional[Union[Tuple[str, int], py.path.local]], default=None
+ )
count_towards_summary = True
- def get_location(self, config):
- """
- Returns the more user-friendly information about the location
- of a warning, or None.
- """
+ def get_location(self, config: Config) -> Optional[str]:
+ """Return the more user-friendly information about the location of a warning, or None."""
if self.nodeid:
return self.nodeid
if self.fslocation:
if isinstance(self.fslocation, tuple) and len(self.fslocation) >= 2:
filename, linenum = self.fslocation[:2]
- relpath = py.path.local(filename).relto(config.invocation_dir)
- if not relpath:
- relpath = str(filename)
- return "{}:{}".format(relpath, linenum)
+ relpath = bestrelpath(
+ config.invocation_params.dir, absolutepath(filename)
+ )
+ return f"{relpath}:{linenum}"
else:
return str(self.fslocation)
return None
+@final
class TerminalReporter:
- def __init__(self, config: Config, file=None) -> None:
+ def __init__(self, config: Config, file: Optional[TextIO] = None) -> None:
import _pytest.config
self.config = config
self._numcollected = 0
- self._session = None # type: Optional[Session]
- self._showfspath = None
+ self._session: Optional[Session] = None
+ self._showfspath: Optional[bool] = None
- self.stats = {} # type: Dict[str, List[Any]]
- self._main_color = None # type: Optional[str]
- self._known_types = None # type: Optional[List]
+ self.stats: Dict[str, List[Any]] = {}
+ self._main_color: Optional[str] = None
+ self._known_types: Optional[List[str]] = None
self.startdir = config.invocation_dir
+ self.startpath = config.invocation_params.dir
if file is None:
file = sys.stdout
self._tw = _pytest.config.create_terminal_writer(config, file)
self._screen_width = self._tw.fullwidth
- self.currentfspath = None # type: Any
+ self.currentfspath: Union[None, Path, str, int] = None
self.reportchars = getreportopt(config)
self.hasmarkup = self._tw.hasmarkup
self.isatty = file.isatty()
- self._progress_nodeids_reported = set() # type: Set[str]
+ self._progress_nodeids_reported: Set[str] = set()
self._show_progress_info = self._determine_show_progress_info()
- self._collect_report_last_write = None # type: Optional[float]
+ self._collect_report_last_write: Optional[float] = None
+ self._already_displayed_warnings: Optional[int] = None
+ self._keyboardinterrupt_memo: Optional[ExceptionRepr] = None
- @property
- def writer(self) -> TerminalWriter:
- warnings.warn(TERMINALWRITER_WRITER, stacklevel=2)
- return self._tw
-
- @writer.setter
- def writer(self, value: TerminalWriter):
- warnings.warn(TERMINALWRITER_WRITER, stacklevel=2)
- self._tw = value
-
- def _determine_show_progress_info(self):
- """Return True if we should display progress information based on the current config"""
+ def _determine_show_progress_info(self) -> "Literal['progress', 'count', False]":
+ """Return whether we should display progress information based on the current config."""
# do not show progress if we are not capturing output (#3038)
if self.config.getoption("capture", "no") == "no":
return False
# do not show progress if we are showing fixture setup/teardown
if self.config.getoption("setupshow", False):
return False
- cfg = self.config.getini("console_output_style")
- if cfg in ("progress", "count"):
- return cfg
- return False
+ cfg: str = self.config.getini("console_output_style")
+ if cfg == "progress":
+ return "progress"
+ elif cfg == "count":
+ return "count"
+ else:
+ return False
@property
- def verbosity(self):
- return self.config.option.verbose
+ def verbosity(self) -> int:
+ verbosity: int = self.config.option.verbose
+ return verbosity
@property
- def showheader(self):
+ def showheader(self) -> bool:
return self.verbosity >= 0
@property
- def showfspath(self):
+ def no_header(self) -> bool:
+ return bool(self.config.option.no_header)
+
+ @property
+ def no_summary(self) -> bool:
+ return bool(self.config.option.no_summary)
+
+ @property
+ def showfspath(self) -> bool:
if self._showfspath is None:
return self.verbosity >= 0
return self._showfspath
@showfspath.setter
- def showfspath(self, value):
+ def showfspath(self, value: Optional[bool]) -> None:
self._showfspath = value
@property
- def showlongtestinfo(self):
+ def showlongtestinfo(self) -> bool:
return self.verbosity > 0
- def hasopt(self, char):
+ def hasopt(self, char: str) -> bool:
char = {"xfailed": "x", "skipped": "s"}.get(char, char)
return char in self.reportchars
- def write_fspath_result(self, nodeid, res, **markup):
- fspath = self.config.rootdir.join(nodeid.split("::")[0])
- # NOTE: explicitly check for None to work around py bug, and for less
- # overhead in general (https://github.com/pytest-dev/py/pull/207).
+ def write_fspath_result(self, nodeid: str, res, **markup: bool) -> None:
+ fspath = self.config.rootpath / nodeid.split("::")[0]
if self.currentfspath is None or fspath != self.currentfspath:
if self.currentfspath is not None and self._show_progress_info:
self._write_progress_information_filling_space()
self.currentfspath = fspath
- fspath = self.startdir.bestrelpath(fspath)
+ relfspath = bestrelpath(self.startpath, fspath)
self._tw.line()
- self._tw.write(fspath + " ")
- self._tw.write(res, **markup)
+ self._tw.write(relfspath + " ")
+ self._tw.write(res, flush=True, **markup)
- def write_ensure_prefix(self, prefix, extra="", **kwargs):
+ def write_ensure_prefix(self, prefix: str, extra: str = "", **kwargs) -> None:
if self.currentfspath != prefix:
self._tw.line()
self.currentfspath = prefix
@@ -354,25 +412,28 @@ class TerminalReporter:
self._tw.write(extra, **kwargs)
self.currentfspath = -2
- def ensure_newline(self):
+ def ensure_newline(self) -> None:
if self.currentfspath:
self._tw.line()
self.currentfspath = None
- def write(self, content, **markup):
- self._tw.write(content, **markup)
+ def write(self, content: str, *, flush: bool = False, **markup: bool) -> None:
+ self._tw.write(content, flush=flush, **markup)
- def write_line(self, line, **markup):
+ def flush(self) -> None:
+ self._tw.flush()
+
+ def write_line(self, line: Union[str, bytes], **markup: bool) -> None:
if not isinstance(line, str):
line = str(line, errors="replace")
self.ensure_newline()
self._tw.line(line, **markup)
- def rewrite(self, line, **markup):
- """
- Rewinds the terminal cursor to the beginning and writes the given line.
+ def rewrite(self, line: str, **markup: bool) -> None:
+ """Rewinds the terminal cursor to the beginning and writes the given line.
- :kwarg erase: if True, will also add spaces until the full terminal width to ensure
+ :param erase:
+ If True, will also add spaces until the full terminal width to ensure
previous lines are properly erased.
The rest of the keyword arguments are markup instructions.
@@ -386,73 +447,84 @@ class TerminalReporter:
line = str(line)
self._tw.write("\r" + line + fill, **markup)
- def write_sep(self, sep, title=None, **markup):
+ def write_sep(
+ self,
+ sep: str,
+ title: Optional[str] = None,
+ fullwidth: Optional[int] = None,
+ **markup: bool,
+ ) -> None:
self.ensure_newline()
- self._tw.sep(sep, title, **markup)
+ self._tw.sep(sep, title, fullwidth, **markup)
- def section(self, title, sep="=", **kw):
+ def section(self, title: str, sep: str = "=", **kw: bool) -> None:
self._tw.sep(sep, title, **kw)
- def line(self, msg, **kw):
+ def line(self, msg: str, **kw: bool) -> None:
self._tw.line(msg, **kw)
- def _add_stats(self, category: str, items: List) -> None:
+ def _add_stats(self, category: str, items: Sequence[Any]) -> None:
set_main_color = category not in self.stats
- self.stats.setdefault(category, []).extend(items[:])
+ self.stats.setdefault(category, []).extend(items)
if set_main_color:
self._set_main_color()
- def pytest_internalerror(self, excrepr):
+ def pytest_internalerror(self, excrepr: ExceptionRepr) -> bool:
for line in str(excrepr).split("\n"):
self.write_line("INTERNALERROR> " + line)
- return 1
+ return True
- def pytest_warning_captured(self, warning_message, item):
- # from _pytest.nodes import get_fslocation_from_item
+ def pytest_warning_recorded(
+ self, warning_message: warnings.WarningMessage, nodeid: str,
+ ) -> None:
from _pytest.warnings import warning_record_to_str
fslocation = warning_message.filename, warning_message.lineno
message = warning_record_to_str(warning_message)
- nodeid = item.nodeid if item is not None else ""
warning_report = WarningReport(
fslocation=fslocation, message=message, nodeid=nodeid
)
self._add_stats("warnings", [warning_report])
- def pytest_plugin_registered(self, plugin):
+ def pytest_plugin_registered(self, plugin: _PluggyPlugin) -> None:
if self.config.option.traceconfig:
- msg = "PLUGIN registered: {}".format(plugin)
- # XXX this event may happen during setup/teardown time
+ msg = f"PLUGIN registered: {plugin}"
+ # XXX This event may happen during setup/teardown time
# which unfortunately captures our output here
- # which garbles our output if we use self.write_line
+ # which garbles our output if we use self.write_line.
self.write_line(msg)
- def pytest_deselected(self, items):
+ def pytest_deselected(self, items: Sequence[Item]) -> None:
self._add_stats("deselected", items)
- def pytest_runtest_logstart(self, nodeid, location):
- # ensure that the path is printed before the
- # 1st test of a module starts running
+ def pytest_runtest_logstart(
+ self, nodeid: str, location: Tuple[str, Optional[int], str]
+ ) -> None:
+ # Ensure that the path is printed before the
+ # 1st test of a module starts running.
if self.showlongtestinfo:
line = self._locationline(nodeid, *location)
self.write_ensure_prefix(line, "")
+ self.flush()
elif self.showfspath:
- fsid = nodeid.split("::")[0]
- self.write_fspath_result(fsid, "")
+ self.write_fspath_result(nodeid, "")
+ self.flush()
def pytest_runtest_logreport(self, report: TestReport) -> None:
self._tests_ran = True
rep = report
- res = self.config.hook.pytest_report_teststatus(report=rep, config=self.config)
+ res: Tuple[
+ str, str, Union[str, Tuple[str, Mapping[str, bool]]]
+ ] = self.config.hook.pytest_report_teststatus(report=rep, config=self.config)
category, letter, word = res
- if isinstance(word, tuple):
- word, markup = word
- else:
+ if not isinstance(word, tuple):
markup = None
+ else:
+ word, markup = word
self._add_stats(category, [rep])
if not letter and not word:
- # probably passed setup/teardown
+ # Probably passed setup/teardown.
return
running_xdist = hasattr(rep, "node")
if markup is None:
@@ -468,20 +540,27 @@ class TerminalReporter:
else:
markup = {}
if self.verbosity <= 0:
- if not running_xdist and self.showfspath:
- self.write_fspath_result(rep.nodeid, letter, **markup)
- else:
- self._tw.write(letter, **markup)
+ self._tw.write(letter, **markup)
else:
self._progress_nodeids_reported.add(rep.nodeid)
line = self._locationline(rep.nodeid, *rep.location)
if not running_xdist:
self.write_ensure_prefix(line, word, **markup)
+ if rep.skipped or hasattr(report, "wasxfail"):
+ available_width = (
+ (self._tw.fullwidth - self._tw.width_of_current_line)
+ - len(" [100%]")
+ - 1
+ )
+ reason = _get_raw_skip_reason(rep)
+ reason_ = _format_trimmed(" ({})", reason, available_width)
+ if reason and reason_ is not None:
+ self._tw.write(reason_)
if self._show_progress_info:
self._write_progress_information_filling_space()
else:
self.ensure_newline()
- self._tw.write("[%s]" % rep.node.gateway.id) # type: ignore
+ self._tw.write("[%s]" % rep.node.gateway.id)
if self._show_progress_info:
self._tw.write(
self._get_progress_information_message() + " ", cyan=True
@@ -491,12 +570,14 @@ class TerminalReporter:
self._tw.write(word, **markup)
self._tw.write(" " + line)
self.currentfspath = -2
+ self.flush()
@property
- def _is_last_item(self):
+ def _is_last_item(self) -> bool:
+ assert self._session is not None
return len(self._progress_nodeids_reported) == self._session.testscollected
- def pytest_runtest_logfinish(self, nodeid):
+ def pytest_runtest_logfinish(self, nodeid: str) -> None:
assert self._session
if self.verbosity <= 0 and self._show_progress_info:
if self._show_progress_info == "count":
@@ -524,9 +605,9 @@ class TerminalReporter:
if collected:
progress = self._progress_nodeids_reported
counter_format = "{{:{}d}}".format(len(str(collected)))
- format_string = " [{}/{{}}]".format(counter_format)
+ format_string = f" [{counter_format}/{{}}]"
return format_string.format(len(progress), collected)
- return " [ {} / {} ]".format(collected, collected)
+ return f" [ {collected} / {collected} ]"
else:
if collected:
return " [{:3d}%]".format(
@@ -534,47 +615,43 @@ class TerminalReporter:
)
return " [100%]"
- def _write_progress_information_filling_space(self):
+ def _write_progress_information_filling_space(self) -> None:
color, _ = self._get_main_color()
msg = self._get_progress_information_message()
w = self._width_of_current_line
fill = self._tw.fullwidth - w - 1
- self.write(msg.rjust(fill), **{color: True})
+ self.write(msg.rjust(fill), flush=True, **{color: True})
@property
- def _width_of_current_line(self):
- """Return the width of current line, using the superior implementation of py-1.6 when available"""
- try:
- return self._tw.width_of_current_line
- except AttributeError:
- # py < 1.6.0
- return self._tw.chars_on_current_line
+ def _width_of_current_line(self) -> int:
+ """Return the width of the current line."""
+ return self._tw.width_of_current_line
def pytest_collection(self) -> None:
if self.isatty:
if self.config.option.verbose >= 0:
- self.write("collecting ... ", bold=True)
- self._collect_report_last_write = time.time()
+ self.write("collecting ... ", flush=True, bold=True)
+ self._collect_report_last_write = timing.time()
elif self.config.option.verbose >= 1:
- self.write("collecting ... ", bold=True)
+ self.write("collecting ... ", flush=True, bold=True)
def pytest_collectreport(self, report: CollectReport) -> None:
if report.failed:
self._add_stats("error", [report])
elif report.skipped:
self._add_stats("skipped", [report])
- items = [x for x in report.result if isinstance(x, pytest.Item)]
+ items = [x for x in report.result if isinstance(x, Item)]
self._numcollected += len(items)
if self.isatty:
self.report_collect()
- def report_collect(self, final=False):
+ def report_collect(self, final: bool = False) -> None:
if self.config.option.verbose < 0:
return
if not final:
# Only write "collecting" report every 0.5s.
- t = time.time()
+ t = timing.time()
if (
self._collect_report_last_write is not None
and self._collect_report_last_write > t - REPORT_COLLECTING_RESOLUTION
@@ -608,49 +685,55 @@ class TerminalReporter:
else:
self.write_line(line)
- @pytest.hookimpl(trylast=True)
- def pytest_sessionstart(self, session: Session) -> None:
+ @hookimpl(trylast=True)
+ def pytest_sessionstart(self, session: "Session") -> None:
self._session = session
- self._sessionstarttime = time.time()
+ self._sessionstarttime = timing.time()
if not self.showheader:
return
self.write_sep("=", "test session starts", bold=True)
verinfo = platform.python_version()
- msg = "platform {} -- Python {}".format(sys.platform, verinfo)
- pypy_version_info = getattr(sys, "pypy_version_info", None)
- if pypy_version_info:
- verinfo = ".".join(map(str, pypy_version_info[:3]))
- msg += "[pypy-{}-{}]".format(verinfo, pypy_version_info[3])
- msg += ", pytest-{}, py-{}, pluggy-{}".format(
- pytest.__version__, py.__version__, pluggy.__version__
- )
- if (
- self.verbosity > 0
- or self.config.option.debug
- or getattr(self.config.option, "pastebin", None)
- ):
- msg += " -- " + str(sys.executable)
- self.write_line(msg)
- lines = self.config.hook.pytest_report_header(
- config=self.config, startdir=self.startdir
- )
- self._write_report_lines_from_hooks(lines)
+ if not self.no_header:
+ msg = f"platform {sys.platform} -- Python {verinfo}"
+ pypy_version_info = getattr(sys, "pypy_version_info", None)
+ if pypy_version_info:
+ verinfo = ".".join(map(str, pypy_version_info[:3]))
+ msg += "[pypy-{}-{}]".format(verinfo, pypy_version_info[3])
+ msg += ", pytest-{}, py-{}, pluggy-{}".format(
+ _pytest._version.version, py.__version__, pluggy.__version__
+ )
+ if (
+ self.verbosity > 0
+ or self.config.option.debug
+ or getattr(self.config.option, "pastebin", None)
+ ):
+ msg += " -- " + str(sys.executable)
+ self.write_line(msg)
+ lines = self.config.hook.pytest_report_header(
+ config=self.config, startdir=self.startdir
+ )
+ self._write_report_lines_from_hooks(lines)
+
+ def _write_report_lines_from_hooks(
+ self, lines: Sequence[Union[str, Sequence[str]]]
+ ) -> None:
+ for line_or_lines in reversed(lines):
+ if isinstance(line_or_lines, str):
+ self.write_line(line_or_lines)
+ else:
+ for line in line_or_lines:
+ self.write_line(line)
- def _write_report_lines_from_hooks(self, lines):
- lines.reverse()
- for line in collapse(lines):
- self.write_line(line)
+ def pytest_report_header(self, config: Config) -> List[str]:
+ line = "rootdir: %s" % config.rootpath
- def pytest_report_header(self, config):
- line = "rootdir: %s" % config.rootdir
+ if config.inipath:
+ line += ", configfile: " + bestrelpath(config.rootpath, config.inipath)
- if config.inifile:
- line += ", inifile: " + config.rootdir.bestrelpath(config.inifile)
+ testpaths: List[str] = config.getini("testpaths")
+ if config.invocation_params.dir == config.rootpath and config.args == testpaths:
+ line += ", testpaths: {}".format(", ".join(testpaths))
- testpaths = config.getini("testpaths")
- if testpaths and config.args == testpaths:
- rel_paths = [config.rootdir.bestrelpath(x) for x in testpaths]
- line += ", testpaths: {}".format(", ".join(rel_paths))
result = [line]
plugininfo = config.pluginmanager.list_plugin_distinfo()
@@ -658,41 +741,40 @@ class TerminalReporter:
result.append("plugins: %s" % ", ".join(_plugin_nameversions(plugininfo)))
return result
- def pytest_collection_finish(self, session):
+ def pytest_collection_finish(self, session: "Session") -> None:
self.report_collect(True)
- if self.config.getoption("collectonly"):
- self._printcollecteditems(session.items)
-
lines = self.config.hook.pytest_report_collectionfinish(
config=self.config, startdir=self.startdir, items=session.items
)
self._write_report_lines_from_hooks(lines)
if self.config.getoption("collectonly"):
+ if session.items:
+ if self.config.option.verbose > -1:
+ self._tw.line("")
+ self._printcollecteditems(session.items)
+
failed = self.stats.get("failed")
if failed:
self._tw.sep("!", "collection failures")
for rep in failed:
rep.toterminal(self._tw)
- def _printcollecteditems(self, items):
- # to print out items and their parent collectors
+ def _printcollecteditems(self, items: Sequence[Item]) -> None:
+ # To print out items and their parent collectors
# we take care to leave out Instances aka ()
- # because later versions are going to get rid of them anyway
+ # because later versions are going to get rid of them anyway.
if self.config.option.verbose < 0:
if self.config.option.verbose < -1:
- counts = {} # type: Dict[str, int]
- for item in items:
- name = item.nodeid.split("::", 1)[0]
- counts[name] = counts.get(name, 0) + 1
+ counts = Counter(item.nodeid.split("::", 1)[0] for item in items)
for name, count in sorted(counts.items()):
self._tw.line("%s: %d" % (name, count))
else:
for item in items:
self._tw.line(item.nodeid)
return
- stack = []
+ stack: List[Node] = []
indent = ""
for item in items:
needed_collectors = item.listchain()[1:] # strip root node
@@ -705,14 +787,18 @@ class TerminalReporter:
if col.name == "()": # Skip Instances.
continue
indent = (len(stack) - 1) * " "
- self._tw.line("{}{}".format(indent, col))
+ self._tw.line(f"{indent}{col}")
if self.config.option.verbose >= 1:
- if hasattr(col, "_obj") and col._obj.__doc__:
- for line in col._obj.__doc__.strip().splitlines():
- self._tw.line("{}{}".format(indent + " ", line.strip()))
-
- @pytest.hookimpl(hookwrapper=True)
- def pytest_sessionfinish(self, session: Session, exitstatus: ExitCode):
+ obj = getattr(col, "obj", None)
+ doc = inspect.getdoc(obj) if obj else None
+ if doc:
+ for line in doc.splitlines():
+ self._tw.line("{}{}".format(indent + " ", line))
+
+ @hookimpl(hookwrapper=True)
+ def pytest_sessionfinish(
+ self, session: "Session", exitstatus: Union[int, ExitCode]
+ ):
outcome = yield
outcome.get_result()
self._tw.line("")
@@ -723,21 +809,21 @@ class TerminalReporter:
ExitCode.USAGE_ERROR,
ExitCode.NO_TESTS_COLLECTED,
)
- if exitstatus in summary_exit_codes:
+ if exitstatus in summary_exit_codes and not self.no_summary:
self.config.hook.pytest_terminal_summary(
terminalreporter=self, exitstatus=exitstatus, config=self.config
)
if session.shouldfail:
- self.write_sep("!", session.shouldfail, red=True)
+ self.write_sep("!", str(session.shouldfail), red=True)
if exitstatus == ExitCode.INTERRUPTED:
self._report_keyboardinterrupt()
- del self._keyboardinterrupt_memo
+ self._keyboardinterrupt_memo = None
elif session.shouldstop:
- self.write_sep("!", session.shouldstop, red=True)
+ self.write_sep("!", str(session.shouldstop), red=True)
self.summary_stats()
- @pytest.hookimpl(hookwrapper=True)
- def pytest_terminal_summary(self):
+ @hookimpl(hookwrapper=True)
+ def pytest_terminal_summary(self) -> Generator[None, None, None]:
self.summary_errors()
self.summary_failures()
self.summary_warnings()
@@ -747,15 +833,17 @@ class TerminalReporter:
# Display any extra warnings from teardown here (if any).
self.summary_warnings()
- def pytest_keyboard_interrupt(self, excinfo):
+ def pytest_keyboard_interrupt(self, excinfo: ExceptionInfo[BaseException]) -> None:
self._keyboardinterrupt_memo = excinfo.getrepr(funcargs=True)
- def pytest_unconfigure(self):
- if hasattr(self, "_keyboardinterrupt_memo"):
+ def pytest_unconfigure(self) -> None:
+ if self._keyboardinterrupt_memo is not None:
self._report_keyboardinterrupt()
- def _report_keyboardinterrupt(self):
+ def _report_keyboardinterrupt(self) -> None:
excrepr = self._keyboardinterrupt_memo
+ assert excrepr is not None
+ assert excrepr.reprcrash is not None
msg = excrepr.reprcrash.message
self.write_sep("!", msg)
if "KeyboardInterrupt" in msg:
@@ -778,14 +866,14 @@ class TerminalReporter:
line += "[".join(values)
return line
- # collect_fspath comes from testid which has a "/"-normalized path
+ # collect_fspath comes from testid which has a "/"-normalized path.
if fspath:
res = mkrel(nodeid)
if self.verbosity >= 2 and nodeid.split("::")[0] != fspath.replace(
"\\", nodes.SEP
):
- res += " <- " + self.startdir.bestrelpath(fspath)
+ res += " <- " + bestrelpath(self.startpath, fspath)
else:
res = "[location]"
return res + " "
@@ -806,24 +894,22 @@ class TerminalReporter:
return ""
#
- # summaries for sessionfinish
+ # Summaries for sessionfinish.
#
- def getreports(self, name):
+ def getreports(self, name: str):
values = []
for x in self.stats.get(name, []):
if not hasattr(x, "_pdbshown"):
values.append(x)
return values
- def summary_warnings(self):
+ def summary_warnings(self) -> None:
if self.hasopt("w"):
- all_warnings = self.stats.get(
- "warnings"
- ) # type: Optional[List[WarningReport]]
+ all_warnings: Optional[List[WarningReport]] = self.stats.get("warnings")
if not all_warnings:
return
- final = hasattr(self, "_already_displayed_warnings")
+ final = self._already_displayed_warnings is not None
if final:
warning_reports = all_warnings[self._already_displayed_warnings :]
else:
@@ -832,15 +918,13 @@ class TerminalReporter:
if not warning_reports:
return
- reports_grouped_by_message = (
- collections.OrderedDict()
- ) # type: collections.OrderedDict[str, List[WarningReport]]
+ reports_grouped_by_message: Dict[str, List[WarningReport]] = {}
for wr in warning_reports:
reports_grouped_by_message.setdefault(wr.message, []).append(wr)
- def collapsed_location_report(reports: List[WarningReport]):
+ def collapsed_location_report(reports: List[WarningReport]) -> str:
locations = []
- for w in warning_reports:
+ for w in reports:
location = w.get_location(self.config)
if location:
locations.append(location)
@@ -848,20 +932,18 @@ class TerminalReporter:
if len(locations) < 10:
return "\n".join(map(str, locations))
- counts_by_filename = collections.Counter(
+ counts_by_filename = Counter(
str(loc).split("::", 1)[0] for loc in locations
)
return "\n".join(
- "{0}: {1} test{2} with warning{2}".format(
- k, v, "s" if v > 1 else ""
- )
+ "{}: {} warning{}".format(k, v, "s" if v > 1 else "")
for k, v in counts_by_filename.items()
)
title = "warnings summary (final)" if final else "warnings summary"
self.write_sep("=", title, yellow=True, bold=False)
- for message, warning_reports in reports_grouped_by_message.items():
- maybe_location = collapsed_location_report(warning_reports)
+ for message, message_reports in reports_grouped_by_message.items():
+ maybe_location = collapsed_location_report(message_reports)
if maybe_location:
self._tw.line(maybe_location)
lines = message.splitlines()
@@ -871,12 +953,12 @@ class TerminalReporter:
message = message.rstrip()
self._tw.line(message)
self._tw.line()
- self._tw.line("-- Docs: https://docs.pytest.org/en/latest/warnings.html")
+ self._tw.line("-- Docs: https://docs.pytest.org/en/stable/warnings.html")
- def summary_passes(self):
+ def summary_passes(self) -> None:
if self.config.option.tbstyle != "no":
if self.hasopt("P"):
- reports = self.getreports("passed")
+ reports: List[TestReport] = self.getreports("passed")
if not reports:
return
self.write_sep("=", "PASSES")
@@ -888,9 +970,10 @@ class TerminalReporter:
self._handle_teardown_sections(rep.nodeid)
def _get_teardown_reports(self, nodeid: str) -> List[TestReport]:
+ reports = self.getreports("")
return [
report
- for report in self.getreports("")
+ for report in reports
if report.when == "teardown" and report.nodeid == nodeid
]
@@ -911,9 +994,9 @@ class TerminalReporter:
content = content[:-1]
self._tw.line(content)
- def summary_failures(self):
+ def summary_failures(self) -> None:
if self.config.option.tbstyle != "no":
- reports = self.getreports("failed")
+ reports: List[BaseReport] = self.getreports("failed")
if not reports:
return
self.write_sep("=", "FAILURES")
@@ -928,9 +1011,9 @@ class TerminalReporter:
self._outrep_summary(rep)
self._handle_teardown_sections(rep.nodeid)
- def summary_errors(self):
+ def summary_errors(self) -> None:
if self.config.option.tbstyle != "no":
- reports = self.getreports("error")
+ reports: List[BaseReport] = self.getreports("error")
if not reports:
return
self.write_sep("=", "ERRORS")
@@ -939,11 +1022,11 @@ class TerminalReporter:
if rep.when == "collect":
msg = "ERROR collecting " + msg
else:
- msg = "ERROR at {} of {}".format(rep.when, msg)
+ msg = f"ERROR at {rep.when} of {msg}"
self.write_sep("_", msg, red=True, bold=True)
self._outrep_summary(rep)
- def _outrep_summary(self, rep):
+ def _outrep_summary(self, rep: BaseReport) -> None:
rep.toterminal(self._tw)
showcapture = self.config.option.showcapture
if showcapture == "no":
@@ -956,11 +1039,11 @@ class TerminalReporter:
content = content[:-1]
self._tw.line(content)
- def summary_stats(self):
+ def summary_stats(self) -> None:
if self.verbosity < -1:
return
- session_duration = time.time() - self._sessionstarttime
+ session_duration = timing.time() - self._sessionstarttime
(parts, main_color) = self.build_summary_stats_line()
line_parts = []
@@ -1012,7 +1095,7 @@ class TerminalReporter:
for rep in xfailed:
verbose_word = rep._get_verbose_word(self.config)
pos = _get_pos(self.config, rep)
- lines.append("{} {}".format(verbose_word, pos))
+ lines.append(f"{verbose_word} {pos}")
reason = rep.wasxfail
if reason:
lines.append(" " + str(reason))
@@ -1023,11 +1106,11 @@ class TerminalReporter:
verbose_word = rep._get_verbose_word(self.config)
pos = _get_pos(self.config, rep)
reason = rep.wasxfail
- lines.append("{} {} {}".format(verbose_word, pos, reason))
+ lines.append(f"{verbose_word} {pos} {reason}")
def show_skipped(lines: List[str]) -> None:
- skipped = self.stats.get("skipped", [])
- fskips = _folded_skips(skipped) if skipped else []
+ skipped: List[CollectReport] = self.stats.get("skipped", [])
+ fskips = _folded_skips(self.startpath, skipped) if skipped else []
if not fskips:
return
verbose_word = skipped[0]._get_verbose_word(self.config)
@@ -1042,16 +1125,16 @@ class TerminalReporter:
else:
lines.append("%s [%d] %s: %s" % (verbose_word, num, fspath, reason))
- REPORTCHAR_ACTIONS = {
+ REPORTCHAR_ACTIONS: Mapping[str, Callable[[List[str]], None]] = {
"x": show_xfailed,
"X": show_xpassed,
"f": partial(show_simple, "failed"),
"s": show_skipped,
"p": partial(show_simple, "passed"),
"E": partial(show_simple, "error"),
- } # type: Mapping[str, Callable[[List[str]], None]]
+ }
- lines = [] # type: List[str]
+ lines: List[str] = []
for char in self.reportchars:
action = REPORTCHAR_ACTIONS.get(char)
if action: # skipping e.g. "P" (passed with output) here.
@@ -1082,7 +1165,7 @@ class TerminalReporter:
return main_color
def _set_main_color(self) -> None:
- unknown_types = [] # type: List[str]
+ unknown_types: List[str] = []
for found_type in self.stats.keys():
if found_type: # setup/teardown reports have an empty key, ignore them
if found_type not in KNOWN_TYPES and found_type not in unknown_types:
@@ -1091,87 +1174,168 @@ class TerminalReporter:
self._main_color = self._determine_main_color(bool(unknown_types))
def build_summary_stats_line(self) -> Tuple[List[Tuple[str, Dict[str, bool]]], str]:
- main_color, known_types = self._get_main_color()
+ """
+ Build the parts used in the last summary stats line.
+
+ The summary stats line is the line shown at the end, "=== 12 passed, 2 errors in Xs===".
+
+ This function builds a list of the "parts" that make up for the text in that line, in
+ the example above it would be:
+
+ [
+ ("12 passed", {"green": True}),
+ ("2 errors", {"red": True}
+ ]
+ That last dict for each line is a "markup dictionary", used by TerminalWriter to
+ color output.
+
+ The final color of the line is also determined by this function, and is the second
+ element of the returned tuple.
+ """
+ if self.config.getoption("collectonly"):
+ return self._build_collect_only_summary_stats_line()
+ else:
+ return self._build_normal_summary_stats_line()
+
+ def _get_reports_to_display(self, key: str) -> List[Any]:
+ """Get test/collection reports for the given status key, such as `passed` or `error`."""
+ reports = self.stats.get(key, [])
+ return [x for x in reports if getattr(x, "count_towards_summary", True)]
+
+ def _build_normal_summary_stats_line(
+ self,
+ ) -> Tuple[List[Tuple[str, Dict[str, bool]]], str]:
+ main_color, known_types = self._get_main_color()
parts = []
+
for key in known_types:
- reports = self.stats.get(key, None)
+ reports = self._get_reports_to_display(key)
if reports:
- count = sum(
- 1 for rep in reports if getattr(rep, "count_towards_summary", True)
- )
+ count = len(reports)
color = _color_for_type.get(key, _color_for_type_default)
markup = {color: True, "bold": color == main_color}
- parts.append(("%d %s" % _make_plural(count, key), markup))
+ parts.append(("%d %s" % pluralize(count, key), markup))
if not parts:
parts = [("no tests ran", {_color_for_type_default: True})]
return parts, main_color
+ def _build_collect_only_summary_stats_line(
+ self,
+ ) -> Tuple[List[Tuple[str, Dict[str, bool]]], str]:
+ deselected = len(self._get_reports_to_display("deselected"))
+ errors = len(self._get_reports_to_display("error"))
-def _get_pos(config, rep):
+ if self._numcollected == 0:
+ parts = [("no tests collected", {"yellow": True})]
+ main_color = "yellow"
+
+ elif deselected == 0:
+ main_color = "green"
+ collected_output = "%d %s collected" % pluralize(self._numcollected, "test")
+ parts = [(collected_output, {main_color: True})]
+ else:
+ all_tests_were_deselected = self._numcollected == deselected
+ if all_tests_were_deselected:
+ main_color = "yellow"
+ collected_output = f"no tests collected ({deselected} deselected)"
+ else:
+ main_color = "green"
+ selected = self._numcollected - deselected
+ collected_output = f"{selected}/{self._numcollected} tests collected ({deselected} deselected)"
+
+ parts = [(collected_output, {main_color: True})]
+
+ if errors:
+ main_color = _color_for_type["error"]
+ parts += [("%d %s" % pluralize(errors, "error"), {main_color: True})]
+
+ return parts, main_color
+
+
+def _get_pos(config: Config, rep: BaseReport):
nodeid = config.cwd_relative_nodeid(rep.nodeid)
return nodeid
-def _get_line_with_reprcrash_message(config, rep, termwidth):
- """Get summary line for a report, trying to add reprcrash message."""
- from wcwidth import wcswidth
+def _format_trimmed(format: str, msg: str, available_width: int) -> Optional[str]:
+ """Format msg into format, ellipsizing it if doesn't fit in available_width.
+
+ Returns None if even the ellipsis can't fit.
+ """
+ # Only use the first line.
+ i = msg.find("\n")
+ if i != -1:
+ msg = msg[:i]
+
+ ellipsis = "..."
+ format_width = wcswidth(format.format(""))
+ if format_width + len(ellipsis) > available_width:
+ return None
+
+ if format_width + wcswidth(msg) > available_width:
+ available_width -= len(ellipsis)
+ msg = msg[:available_width]
+ while format_width + wcswidth(msg) > available_width:
+ msg = msg[:-1]
+ msg += ellipsis
+
+ return format.format(msg)
+
+def _get_line_with_reprcrash_message(
+ config: Config, rep: BaseReport, termwidth: int
+) -> str:
+ """Get summary line for a report, trying to add reprcrash message."""
verbose_word = rep._get_verbose_word(config)
pos = _get_pos(config, rep)
- line = "{} {}".format(verbose_word, pos)
- len_line = wcswidth(line)
- ellipsis, len_ellipsis = "...", 3
- if len_line > termwidth - len_ellipsis:
- # No space for an additional message.
- return line
+ line = f"{verbose_word} {pos}"
+ line_width = wcswidth(line)
try:
- msg = rep.longrepr.reprcrash.message
+ # Type ignored intentionally -- possible AttributeError expected.
+ msg = rep.longrepr.reprcrash.message # type: ignore[union-attr]
except AttributeError:
pass
else:
- # Only use the first line.
- i = msg.find("\n")
- if i != -1:
- msg = msg[:i]
- len_msg = wcswidth(msg)
-
- sep, len_sep = " - ", 3
- max_len_msg = termwidth - len_line - len_sep
- if max_len_msg >= len_ellipsis:
- if len_msg > max_len_msg:
- max_len_msg -= len_ellipsis
- msg = msg[:max_len_msg]
- while wcswidth(msg) > max_len_msg:
- msg = msg[:-1]
- msg += ellipsis
- line += sep + msg
+ available_width = termwidth - line_width
+ msg = _format_trimmed(" - {}", msg, available_width)
+ if msg is not None:
+ line += msg
+
return line
-def _folded_skips(skipped):
- d = {}
+def _folded_skips(
+ startpath: Path, skipped: Sequence[CollectReport],
+) -> List[Tuple[int, str, Optional[int], str]]:
+ d: Dict[Tuple[str, Optional[int], str], List[CollectReport]] = {}
for event in skipped:
- key = event.longrepr
- assert len(key) == 3, (event, key)
+ assert event.longrepr is not None
+ assert isinstance(event.longrepr, tuple), (event, event.longrepr)
+ assert len(event.longrepr) == 3, (event, event.longrepr)
+ fspath, lineno, reason = event.longrepr
+ # For consistency, report all fspaths in relative form.
+ fspath = bestrelpath(startpath, Path(fspath))
keywords = getattr(event, "keywords", {})
- # folding reports with global pytestmark variable
- # this is workaround, because for now we cannot identify the scope of a skip marker
- # TODO: revisit after marks scope would be fixed
+ # Folding reports with global pytestmark variable.
+ # This is a workaround, because for now we cannot identify the scope of a skip marker
+ # TODO: Revisit after marks scope would be fixed.
if (
event.when == "setup"
and "skip" in keywords
and "pytestmark" not in keywords
):
- key = (key[0], None, key[2])
+ key: Tuple[str, Optional[int], str] = (fspath, None, reason)
+ else:
+ key = (fspath, lineno, reason)
d.setdefault(key, []).append(event)
- values = []
+ values: List[Tuple[int, str, Optional[int], str]] = []
for key, events in d.items():
- values.append((len(events),) + key)
+ values.append((len(events), *key))
return values
@@ -1184,9 +1348,9 @@ _color_for_type = {
_color_for_type_default = "yellow"
-def _make_plural(count, noun):
+def pluralize(count: int, noun: str) -> Tuple[int, str]:
# No need to pluralize words such as `failed` or `passed`.
- if noun not in ["error", "warnings"]:
+ if noun not in ["error", "warnings", "test"]:
return count, noun
# The `warnings` key is plural. To avoid API breakage, we keep it that way but
@@ -1198,24 +1362,44 @@ def _make_plural(count, noun):
def _plugin_nameversions(plugininfo) -> List[str]:
- values = [] # type: List[str]
+ values: List[str] = []
for plugin, dist in plugininfo:
- # gets us name and version!
+ # Gets us name and version!
name = "{dist.project_name}-{dist.version}".format(dist=dist)
- # questionable convenience, but it keeps things short
+ # Questionable convenience, but it keeps things short.
if name.startswith("pytest-"):
name = name[7:]
- # we decided to print python package names
- # they can have more than one plugin
+ # We decided to print python package names they can have more than one plugin.
if name not in values:
values.append(name)
return values
def format_session_duration(seconds: float) -> str:
- """Format the given seconds in a human readable manner to show in the final summary"""
+ """Format the given seconds in a human readable manner to show in the final summary."""
if seconds < 60:
- return "{:.2f}s".format(seconds)
+ return f"{seconds:.2f}s"
else:
dt = datetime.timedelta(seconds=int(seconds))
- return "{:.2f}s ({})".format(seconds, dt)
+ return f"{seconds:.2f}s ({dt})"
+
+
+def _get_raw_skip_reason(report: TestReport) -> str:
+ """Get the reason string of a skip/xfail/xpass test report.
+
+ The string is just the part given by the user.
+ """
+ if hasattr(report, "wasxfail"):
+ reason = cast(str, report.wasxfail)
+ if reason.startswith("reason: "):
+ reason = reason[len("reason: ") :]
+ return reason
+ else:
+ assert report.skipped
+ assert isinstance(report.longrepr, tuple)
+ _, _, reason = report.longrepr
+ if reason.startswith("Skipped: "):
+ reason = reason[len("Skipped: ") :]
+ elif reason == "Skipped":
+ reason = ""
+ return reason
diff --git a/contrib/python/pytest/py3/_pytest/threadexception.py b/contrib/python/pytest/py3/_pytest/threadexception.py
new file mode 100644
index 0000000000..1c1f62fdb7
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/threadexception.py
@@ -0,0 +1,90 @@
+import threading
+import traceback
+import warnings
+from types import TracebackType
+from typing import Any
+from typing import Callable
+from typing import Generator
+from typing import Optional
+from typing import Type
+
+import pytest
+
+
+# Copied from cpython/Lib/test/support/threading_helper.py, with modifications.
+class catch_threading_exception:
+ """Context manager catching threading.Thread exception using
+ threading.excepthook.
+
+ Storing exc_value using a custom hook can create a reference cycle. The
+ reference cycle is broken explicitly when the context manager exits.
+
+ Storing thread using a custom hook can resurrect it if it is set to an
+ object which is being finalized. Exiting the context manager clears the
+ stored object.
+
+ Usage:
+ with threading_helper.catch_threading_exception() as cm:
+ # code spawning a thread which raises an exception
+ ...
+ # check the thread exception: use cm.args
+ ...
+ # cm.args attribute no longer exists at this point
+ # (to break a reference cycle)
+ """
+
+ def __init__(self) -> None:
+ # See https://github.com/python/typeshed/issues/4767 regarding the underscore.
+ self.args: Optional["threading._ExceptHookArgs"] = None
+ self._old_hook: Optional[Callable[["threading._ExceptHookArgs"], Any]] = None
+
+ def _hook(self, args: "threading._ExceptHookArgs") -> None:
+ self.args = args
+
+ def __enter__(self) -> "catch_threading_exception":
+ self._old_hook = threading.excepthook
+ threading.excepthook = self._hook
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ assert self._old_hook is not None
+ threading.excepthook = self._old_hook
+ self._old_hook = None
+ del self.args
+
+
+def thread_exception_runtest_hook() -> Generator[None, None, None]:
+ with catch_threading_exception() as cm:
+ yield
+ if cm.args:
+ if cm.args.thread is not None:
+ thread_name = cm.args.thread.name
+ else:
+ thread_name = "<unknown>"
+ msg = f"Exception in thread {thread_name}\n\n"
+ msg += "".join(
+ traceback.format_exception(
+ cm.args.exc_type, cm.args.exc_value, cm.args.exc_traceback,
+ )
+ )
+ warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))
+
+
+@pytest.hookimpl(hookwrapper=True, trylast=True)
+def pytest_runtest_setup() -> Generator[None, None, None]:
+ yield from thread_exception_runtest_hook()
+
+
+@pytest.hookimpl(hookwrapper=True, tryfirst=True)
+def pytest_runtest_call() -> Generator[None, None, None]:
+ yield from thread_exception_runtest_hook()
+
+
+@pytest.hookimpl(hookwrapper=True, tryfirst=True)
+def pytest_runtest_teardown() -> Generator[None, None, None]:
+ yield from thread_exception_runtest_hook()
diff --git a/contrib/python/pytest/py3/_pytest/timing.py b/contrib/python/pytest/py3/_pytest/timing.py
new file mode 100644
index 0000000000..925163a585
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/timing.py
@@ -0,0 +1,12 @@
+"""Indirection for time functions.
+
+We intentionally grab some "time" functions internally to avoid tests mocking "time" to affect
+pytest runtime information (issue #185).
+
+Fixture "mock_timing" also interacts with this module for pytest's own tests.
+"""
+from time import perf_counter
+from time import sleep
+from time import time
+
+__all__ = ["perf_counter", "sleep", "time"]
diff --git a/contrib/python/pytest/py3/_pytest/tmpdir.py b/contrib/python/pytest/py3/_pytest/tmpdir.py
index 85c5b83810..a6bd383a9c 100644
--- a/contrib/python/pytest/py3/_pytest/tmpdir.py
+++ b/contrib/python/pytest/py3/_pytest/tmpdir.py
@@ -1,90 +1,112 @@
-""" support for providing temporary directories to test functions. """
+"""Support for providing temporary directories to test functions."""
import os
import re
+import sys
import tempfile
+from pathlib import Path
from typing import Optional
import attr
import py
-import pytest
-from .pathlib import ensure_reset_dir
from .pathlib import LOCK_TIMEOUT
from .pathlib import make_numbered_dir
from .pathlib import make_numbered_dir_with_cleanup
-from .pathlib import Path
+from .pathlib import rm_rf
+from _pytest.compat import final
+from _pytest.config import Config
+from _pytest.deprecated import check_ispytest
+from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.monkeypatch import MonkeyPatch
-@attr.s
+@final
+@attr.s(init=False)
class TempPathFactory:
"""Factory for temporary directories under the common base temp directory.
- The base directory can be configured using the ``--basetemp`` option."""
-
- _given_basetemp = attr.ib(
- type=Path,
- # using os.path.abspath() to get absolute path instead of resolve() as it
- # does not work the same in all platforms (see #4427)
- # Path.absolute() exists, but it is not public (see https://bugs.python.org/issue25012)
- # Ignore type because of https://github.com/python/mypy/issues/6172.
- converter=attr.converters.optional(
- lambda p: Path(os.path.abspath(str(p))) # type: ignore
- ),
- )
+ The base directory can be configured using the ``--basetemp`` option.
+ """
+
+ _given_basetemp = attr.ib(type=Optional[Path])
_trace = attr.ib()
- _basetemp = attr.ib(type=Optional[Path], default=None)
+ _basetemp = attr.ib(type=Optional[Path])
+
+ def __init__(
+ self,
+ given_basetemp: Optional[Path],
+ trace,
+ basetemp: Optional[Path] = None,
+ *,
+ _ispytest: bool = False,
+ ) -> None:
+ check_ispytest(_ispytest)
+ if given_basetemp is None:
+ self._given_basetemp = None
+ else:
+ # Use os.path.abspath() to get absolute path instead of resolve() as it
+ # does not work the same in all platforms (see #4427).
+ # Path.absolute() exists, but it is not public (see https://bugs.python.org/issue25012).
+ self._given_basetemp = Path(os.path.abspath(str(given_basetemp)))
+ self._trace = trace
+ self._basetemp = basetemp
@classmethod
- def from_config(cls, config) -> "TempPathFactory":
- """
- :param config: a pytest configuration
+ def from_config(
+ cls, config: Config, *, _ispytest: bool = False,
+ ) -> "TempPathFactory":
+ """Create a factory according to pytest configuration.
+
+ :meta private:
"""
+ check_ispytest(_ispytest)
return cls(
- given_basetemp=config.option.basetemp, trace=config.trace.get("tmpdir")
+ given_basetemp=config.option.basetemp,
+ trace=config.trace.get("tmpdir"),
+ _ispytest=True,
)
- def _ensure_relative_to_basetemp(self, basename: str):
+ def _ensure_relative_to_basetemp(self, basename: str) -> str:
basename = os.path.normpath(basename)
if (self.getbasetemp() / basename).resolve().parent != self.getbasetemp():
- raise ValueError(
- "{} is not a normalized and relative path".format(basename)
- )
+ raise ValueError(f"{basename} is not a normalized and relative path")
return basename
def mktemp(self, basename: str, numbered: bool = True) -> Path:
- """Creates a new temporary directory managed by the factory.
+ """Create a new temporary directory managed by the factory.
:param basename:
Directory base name, must be a relative path.
:param numbered:
- If True, ensure the directory is unique by adding a number
- prefix greater than any existing one: ``basename="foo"`` and ``numbered=True``
+ If ``True``, ensure the directory is unique by adding a numbered
+ suffix greater than any existing one: ``basename="foo-"`` and ``numbered=True``
means that this function will create directories named ``"foo-0"``,
``"foo-1"``, ``"foo-2"`` and so on.
- :return:
+ :returns:
The path to the new directory.
"""
basename = self._ensure_relative_to_basetemp(basename)
if not numbered:
p = self.getbasetemp().joinpath(basename)
- p.mkdir()
+ p.mkdir(mode=0o700)
else:
- p = make_numbered_dir(root=self.getbasetemp(), prefix=basename)
+ p = make_numbered_dir(root=self.getbasetemp(), prefix=basename, mode=0o700)
self._trace("mktemp", p)
return p
def getbasetemp(self) -> Path:
- """ return base temporary directory. """
+ """Return the base temporary directory, creating it if needed."""
if self._basetemp is not None:
return self._basetemp
if self._given_basetemp is not None:
basetemp = self._given_basetemp
- ensure_reset_dir(basetemp)
+ if basetemp.exists():
+ rm_rf(basetemp)
+ basetemp.mkdir(mode=0o700)
basetemp = basetemp.resolve()
else:
from_env = os.environ.get("PYTEST_DEBUG_TEMPROOT")
@@ -92,41 +114,66 @@ class TempPathFactory:
user = get_user() or "unknown"
# use a sub-directory in the temproot to speed-up
# make_numbered_dir() call
- rootdir = temproot.joinpath("pytest-of-{}".format(user))
- rootdir.mkdir(exist_ok=True)
+ rootdir = temproot.joinpath(f"pytest-of-{user}")
+ rootdir.mkdir(mode=0o700, exist_ok=True)
+ # Because we use exist_ok=True with a predictable name, make sure
+ # we are the owners, to prevent any funny business (on unix, where
+ # temproot is usually shared).
+ # Also, to keep things private, fixup any world-readable temp
+ # rootdir's permissions. Historically 0o755 was used, so we can't
+ # just error out on this, at least for a while.
+ if sys.platform != "win32":
+ uid = os.getuid()
+ rootdir_stat = rootdir.stat()
+ # getuid shouldn't fail, but cpython defines such a case.
+ # Let's hope for the best.
+ if uid != -1:
+ if rootdir_stat.st_uid != uid:
+ raise OSError(
+ f"The temporary directory {rootdir} is not owned by the current user. "
+ "Fix this and try again."
+ )
+ if (rootdir_stat.st_mode & 0o077) != 0:
+ os.chmod(rootdir, rootdir_stat.st_mode & ~0o077)
basetemp = make_numbered_dir_with_cleanup(
- prefix="pytest-", root=rootdir, keep=3, lock_timeout=LOCK_TIMEOUT
+ prefix="pytest-",
+ root=rootdir,
+ keep=3,
+ lock_timeout=LOCK_TIMEOUT,
+ mode=0o700,
)
assert basetemp is not None, basetemp
- self._basetemp = t = basetemp
- self._trace("new basetemp", t)
- return t
+ self._basetemp = basetemp
+ self._trace("new basetemp", basetemp)
+ return basetemp
-@attr.s
+@final
+@attr.s(init=False)
class TempdirFactory:
- """
- backward comptibility wrapper that implements
- :class:``py.path.local`` for :class:``TempPathFactory``
- """
+ """Backward comptibility wrapper that implements :class:``py.path.local``
+ for :class:``TempPathFactory``."""
_tmppath_factory = attr.ib(type=TempPathFactory)
+ def __init__(
+ self, tmppath_factory: TempPathFactory, *, _ispytest: bool = False
+ ) -> None:
+ check_ispytest(_ispytest)
+ self._tmppath_factory = tmppath_factory
+
def mktemp(self, basename: str, numbered: bool = True) -> py.path.local:
- """
- Same as :meth:`TempPathFactory.mkdir`, but returns a ``py.path.local`` object.
- """
+ """Same as :meth:`TempPathFactory.mktemp`, but returns a ``py.path.local`` object."""
return py.path.local(self._tmppath_factory.mktemp(basename, numbered).resolve())
- def getbasetemp(self):
- """backward compat wrapper for ``_tmppath_factory.getbasetemp``"""
+ def getbasetemp(self) -> py.path.local:
+ """Backward compat wrapper for ``_tmppath_factory.getbasetemp``."""
return py.path.local(self._tmppath_factory.getbasetemp().resolve())
def get_user() -> Optional[str]:
"""Return the current user name, or None if getuser() does not work
- in the current environment (see #1010).
- """
+ in the current environment (see #1010)."""
import getpass
try:
@@ -135,7 +182,7 @@ def get_user() -> Optional[str]:
return None
-def pytest_configure(config) -> None:
+def pytest_configure(config: Config) -> None:
"""Create a TempdirFactory and attach it to the config object.
This is to comply with existing plugins which expect the handler to be
@@ -143,25 +190,23 @@ def pytest_configure(config) -> None:
to the tmpdir_factory session fixture.
"""
mp = MonkeyPatch()
- tmppath_handler = TempPathFactory.from_config(config)
- t = TempdirFactory(tmppath_handler)
+ tmppath_handler = TempPathFactory.from_config(config, _ispytest=True)
+ t = TempdirFactory(tmppath_handler, _ispytest=True)
config._cleanup.append(mp.undo)
mp.setattr(config, "_tmp_path_factory", tmppath_handler, raising=False)
mp.setattr(config, "_tmpdirhandler", t, raising=False)
-@pytest.fixture(scope="session")
+@fixture(scope="session")
def tmpdir_factory(request: FixtureRequest) -> TempdirFactory:
- """Return a :class:`_pytest.tmpdir.TempdirFactory` instance for the test session.
- """
+ """Return a :class:`_pytest.tmpdir.TempdirFactory` instance for the test session."""
# Set dynamically by pytest_configure() above.
return request.config._tmpdirhandler # type: ignore
-@pytest.fixture(scope="session")
+@fixture(scope="session")
def tmp_path_factory(request: FixtureRequest) -> TempPathFactory:
- """Return a :class:`_pytest.tmpdir.TempPathFactory` instance for the test session.
- """
+ """Return a :class:`_pytest.tmpdir.TempPathFactory` instance for the test session."""
# Set dynamically by pytest_configure() above.
return request.config._tmp_path_factory # type: ignore
@@ -174,30 +219,36 @@ def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path:
return factory.mktemp(name, numbered=True)
-@pytest.fixture
-def tmpdir(tmp_path):
- """Return a temporary directory path object
- which is unique to each test function invocation,
- created as a sub directory of the base temporary
- directory. The returned object is a `py.path.local`_
- path object.
+@fixture
+def tmpdir(tmp_path: Path) -> py.path.local:
+ """Return a temporary directory path object which is unique to each test
+ function invocation, created as a sub directory of the base temporary
+ directory.
+
+ By default, a new base temporary directory is created each test session,
+ and old bases are removed after 3 sessions, to aid in debugging. If
+ ``--basetemp`` is used then it is cleared each session. See :ref:`base
+ temporary directory`.
+
+ The returned object is a `py.path.local`_ path object.
.. _`py.path.local`: https://py.readthedocs.io/en/latest/path.html
"""
return py.path.local(tmp_path)
-@pytest.fixture
+@fixture
def tmp_path(request: FixtureRequest, tmp_path_factory: TempPathFactory) -> Path:
- """Return a temporary directory path object
- which is unique to each test function invocation,
- created as a sub directory of the base temporary
- directory. The returned object is a :class:`pathlib.Path`
- object.
+ """Return a temporary directory path object which is unique to each test
+ function invocation, created as a sub directory of the base temporary
+ directory.
- .. note::
+ By default, a new base temporary directory is created each test session,
+ and old bases are removed after 3 sessions, to aid in debugging. If
+ ``--basetemp`` is used then it is cleared each session. See :ref:`base
+ temporary directory`.
- in python < 3.6 this is a pathlib2.Path
+ The returned object is a :class:`pathlib.Path` object.
"""
return _mk_tmp(request, tmp_path_factory)
diff --git a/contrib/python/pytest/py3/_pytest/unittest.py b/contrib/python/pytest/py3/_pytest/unittest.py
index 36158c62d2..55f15efe4b 100644
--- a/contrib/python/pytest/py3/_pytest/unittest.py
+++ b/contrib/python/pytest/py3/_pytest/unittest.py
@@ -1,40 +1,70 @@
-""" discovery and running of std-library "unittest" style tests. """
+"""Discover and run std-library "unittest" style tests."""
import sys
import traceback
+import types
+from typing import Any
+from typing import Callable
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
import _pytest._code
import pytest
from _pytest.compat import getimfunc
from _pytest.compat import is_async_function
from _pytest.config import hookimpl
+from _pytest.fixtures import FixtureRequest
+from _pytest.nodes import Collector
+from _pytest.nodes import Item
from _pytest.outcomes import exit
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
from _pytest.python import Class
from _pytest.python import Function
+from _pytest.python import PyCollector
from _pytest.runner import CallInfo
from _pytest.skipping import skipped_by_mark_key
from _pytest.skipping import unexpectedsuccess_key
+if TYPE_CHECKING:
+ import unittest
-def pytest_pycollect_makeitem(collector, name, obj):
- # has unittest been imported and is obj a subclass of its TestCase?
+ from _pytest.fixtures import _Scope
+
+ _SysExcInfoType = Union[
+ Tuple[Type[BaseException], BaseException, types.TracebackType],
+ Tuple[None, None, None],
+ ]
+
+
+def pytest_pycollect_makeitem(
+ collector: PyCollector, name: str, obj: object
+) -> Optional["UnitTestCase"]:
+ # Has unittest been imported and is obj a subclass of its TestCase?
try:
- if not issubclass(obj, sys.modules["unittest"].TestCase):
- return
+ ut = sys.modules["unittest"]
+ # Type ignored because `ut` is an opaque module.
+ if not issubclass(obj, ut.TestCase): # type: ignore
+ return None
except Exception:
- return
- # yes, so let's collect it
- return UnitTestCase.from_parent(collector, name=name, obj=obj)
+ return None
+ # Yes, so let's collect it.
+ item: UnitTestCase = UnitTestCase.from_parent(collector, name=name, obj=obj)
+ return item
class UnitTestCase(Class):
- # marker for fixturemanger.getfixtureinfo()
- # to declare that our children do not support funcargs
+ # Marker for fixturemanger.getfixtureinfo()
+ # to declare that our children do not support funcargs.
nofuncargs = True
- def collect(self):
+ def collect(self) -> Iterable[Union[Item, Collector]]:
from unittest import TestLoader
cls = self.obj
@@ -61,82 +91,128 @@ class UnitTestCase(Class):
runtest = getattr(self.obj, "runTest", None)
if runtest is not None:
ut = sys.modules.get("twisted.trial.unittest", None)
- if ut is None or runtest != ut.TestCase.runTest:
- # TODO: callobj consistency
+ # Type ignored because `ut` is an opaque module.
+ if ut is None or runtest != ut.TestCase.runTest: # type: ignore
yield TestCaseFunction.from_parent(self, name="runTest")
- def _inject_setup_teardown_fixtures(self, cls):
+ def _inject_setup_teardown_fixtures(self, cls: type) -> None:
"""Injects a hidden auto-use fixture to invoke setUpClass/setup_method and corresponding
- teardown functions (#517)"""
+ teardown functions (#517)."""
class_fixture = _make_xunit_fixture(
- cls, "setUpClass", "tearDownClass", scope="class", pass_self=False
+ cls,
+ "setUpClass",
+ "tearDownClass",
+ "doClassCleanups",
+ scope="class",
+ pass_self=False,
)
if class_fixture:
- cls.__pytest_class_setup = class_fixture
+ cls.__pytest_class_setup = class_fixture # type: ignore[attr-defined]
method_fixture = _make_xunit_fixture(
- cls, "setup_method", "teardown_method", scope="function", pass_self=True
+ cls,
+ "setup_method",
+ "teardown_method",
+ None,
+ scope="function",
+ pass_self=True,
)
if method_fixture:
- cls.__pytest_method_setup = method_fixture
+ cls.__pytest_method_setup = method_fixture # type: ignore[attr-defined]
-def _make_xunit_fixture(obj, setup_name, teardown_name, scope, pass_self):
+def _make_xunit_fixture(
+ obj: type,
+ setup_name: str,
+ teardown_name: str,
+ cleanup_name: Optional[str],
+ scope: "_Scope",
+ pass_self: bool,
+):
setup = getattr(obj, setup_name, None)
teardown = getattr(obj, teardown_name, None)
if setup is None and teardown is None:
return None
- @pytest.fixture(scope=scope, autouse=True)
- def fixture(self, request):
+ if cleanup_name:
+ cleanup = getattr(obj, cleanup_name, lambda *args: None)
+ else:
+
+ def cleanup(*args):
+ pass
+
+ @pytest.fixture(
+ scope=scope,
+ autouse=True,
+ # Use a unique name to speed up lookup.
+ name=f"unittest_{setup_name}_fixture_{obj.__qualname__}",
+ )
+ def fixture(self, request: FixtureRequest) -> Generator[None, None, None]:
if _is_skipped(self):
reason = self.__unittest_skip_why__
pytest.skip(reason)
if setup is not None:
- if pass_self:
- setup(self, request.function)
- else:
- setup()
+ try:
+ if pass_self:
+ setup(self, request.function)
+ else:
+ setup()
+ # unittest does not call the cleanup function for every BaseException, so we
+ # follow this here.
+ except Exception:
+ if pass_self:
+ cleanup(self)
+ else:
+ cleanup()
+
+ raise
yield
- if teardown is not None:
+ try:
+ if teardown is not None:
+ if pass_self:
+ teardown(self, request.function)
+ else:
+ teardown()
+ finally:
if pass_self:
- teardown(self, request.function)
+ cleanup(self)
else:
- teardown()
+ cleanup()
return fixture
class TestCaseFunction(Function):
nofuncargs = True
- _excinfo = None
- _testcase = None
-
- def setup(self):
- # a bound method to be called during teardown() if set (see 'runtest()')
- self._explicit_tearDown = None
- self._testcase = self.parent.obj(self.name)
+ _excinfo: Optional[List[_pytest._code.ExceptionInfo[BaseException]]] = None
+ _testcase: Optional["unittest.TestCase"] = None
+
+ def setup(self) -> None:
+ # A bound method to be called during teardown() if set (see 'runtest()').
+ self._explicit_tearDown: Optional[Callable[[], None]] = None
+ assert self.parent is not None
+ self._testcase = self.parent.obj(self.name) # type: ignore[attr-defined]
self._obj = getattr(self._testcase, self.name)
if hasattr(self, "_request"):
self._request._fillfixtures()
- def teardown(self):
+ def teardown(self) -> None:
if self._explicit_tearDown is not None:
self._explicit_tearDown()
self._explicit_tearDown = None
self._testcase = None
self._obj = None
- def startTest(self, testcase):
+ def startTest(self, testcase: "unittest.TestCase") -> None:
pass
- def _addexcinfo(self, rawexcinfo):
- # unwrap potential exception info (see twisted trial support below)
+ def _addexcinfo(self, rawexcinfo: "_SysExcInfoType") -> None:
+ # Unwrap potential exception info (see twisted trial support below).
rawexcinfo = getattr(rawexcinfo, "_rawexcinfo", rawexcinfo)
try:
- excinfo = _pytest._code.ExceptionInfo(rawexcinfo)
- # invoke the attributes to trigger storing the traceback
- # trial causes some issue there
+ excinfo = _pytest._code.ExceptionInfo(rawexcinfo) # type: ignore[arg-type]
+ # Invoke the attributes to trigger storing the traceback
+ # trial causes some issue there.
excinfo.value
excinfo.traceback
except TypeError:
@@ -151,7 +227,7 @@ class TestCaseFunction(Function):
fail("".join(values), pytrace=False)
except (fail.Exception, KeyboardInterrupt):
raise
- except: # noqa
+ except BaseException:
fail(
"ERROR: Unknown Incompatible Exception "
"representation:\n%r" % (rawexcinfo,),
@@ -163,7 +239,9 @@ class TestCaseFunction(Function):
excinfo = _pytest._code.ExceptionInfo.from_current()
self.__dict__.setdefault("_excinfo", []).append(excinfo)
- def addError(self, testcase, rawexcinfo):
+ def addError(
+ self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
+ ) -> None:
try:
if isinstance(rawexcinfo[1], exit.Exception):
exit(rawexcinfo[1].msg)
@@ -171,68 +249,82 @@ class TestCaseFunction(Function):
pass
self._addexcinfo(rawexcinfo)
- def addFailure(self, testcase, rawexcinfo):
+ def addFailure(
+ self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
+ ) -> None:
self._addexcinfo(rawexcinfo)
- def addSkip(self, testcase, reason):
+ def addSkip(self, testcase: "unittest.TestCase", reason: str) -> None:
try:
skip(reason)
except skip.Exception:
self._store[skipped_by_mark_key] = True
self._addexcinfo(sys.exc_info())
- def addExpectedFailure(self, testcase, rawexcinfo, reason=""):
+ def addExpectedFailure(
+ self,
+ testcase: "unittest.TestCase",
+ rawexcinfo: "_SysExcInfoType",
+ reason: str = "",
+ ) -> None:
try:
xfail(str(reason))
except xfail.Exception:
self._addexcinfo(sys.exc_info())
- def addUnexpectedSuccess(self, testcase, reason=""):
+ def addUnexpectedSuccess(
+ self, testcase: "unittest.TestCase", reason: str = ""
+ ) -> None:
self._store[unexpectedsuccess_key] = reason
- def addSuccess(self, testcase):
+ def addSuccess(self, testcase: "unittest.TestCase") -> None:
pass
- def stopTest(self, testcase):
+ def stopTest(self, testcase: "unittest.TestCase") -> None:
pass
def _expecting_failure(self, test_method) -> bool:
"""Return True if the given unittest method (or the entire class) is marked
- with @expectedFailure"""
+ with @expectedFailure."""
expecting_failure_method = getattr(
test_method, "__unittest_expecting_failure__", False
)
expecting_failure_class = getattr(self, "__unittest_expecting_failure__", False)
return bool(expecting_failure_class or expecting_failure_method)
- def runtest(self):
+ def runtest(self) -> None:
from _pytest.debugging import maybe_wrap_pytest_function_for_tracing
+ assert self._testcase is not None
+
maybe_wrap_pytest_function_for_tracing(self)
- # let the unittest framework handle async functions
+ # Let the unittest framework handle async functions.
if is_async_function(self.obj):
- self._testcase(self)
+ # Type ignored because self acts as the TestResult, but is not actually one.
+ self._testcase(result=self) # type: ignore[arg-type]
else:
- # when --pdb is given, we want to postpone calling tearDown() otherwise
+ # When --pdb is given, we want to postpone calling tearDown() otherwise
# when entering the pdb prompt, tearDown() would have probably cleaned up
- # instance variables, which makes it difficult to debug
- # arguably we could always postpone tearDown(), but this changes the moment where the
+ # instance variables, which makes it difficult to debug.
+ # Arguably we could always postpone tearDown(), but this changes the moment where the
# TestCase instance interacts with the results object, so better to only do it
- # when absolutely needed
+ # when absolutely needed.
if self.config.getoption("usepdb") and not _is_skipped(self.obj):
self._explicit_tearDown = self._testcase.tearDown
setattr(self._testcase, "tearDown", lambda *args: None)
- # we need to update the actual bound method with self.obj, because
- # wrap_pytest_function_for_tracing replaces self.obj by a wrapper
+ # We need to update the actual bound method with self.obj, because
+ # wrap_pytest_function_for_tracing replaces self.obj by a wrapper.
setattr(self._testcase, self.name, self.obj)
try:
- self._testcase(result=self)
+ self._testcase(result=self) # type: ignore[arg-type]
finally:
delattr(self._testcase, self.name)
- def _prunetraceback(self, excinfo):
+ def _prunetraceback(
+ self, excinfo: _pytest._code.ExceptionInfo[BaseException]
+ ) -> None:
Function._prunetraceback(self, excinfo)
traceback = excinfo.traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest")
@@ -242,7 +334,7 @@ class TestCaseFunction(Function):
@hookimpl(tryfirst=True)
-def pytest_runtest_makereport(item, call):
+def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> None:
if isinstance(item, TestCaseFunction):
if item._excinfo:
call.excinfo = item._excinfo.pop(0)
@@ -252,21 +344,26 @@ def pytest_runtest_makereport(item, call):
pass
unittest = sys.modules.get("unittest")
- if unittest and call.excinfo and call.excinfo.errisinstance(unittest.SkipTest):
- # let's substitute the excinfo with a pytest.skip one
- call2 = CallInfo.from_call(
- lambda: pytest.skip(str(call.excinfo.value)), call.when
+ if (
+ unittest
+ and call.excinfo
+ and isinstance(call.excinfo.value, unittest.SkipTest) # type: ignore[attr-defined]
+ ):
+ excinfo = call.excinfo
+ # Let's substitute the excinfo with a pytest.skip one.
+ call2 = CallInfo[None].from_call(
+ lambda: pytest.skip(str(excinfo.value)), call.when
)
call.excinfo = call2.excinfo
-# twisted trial support
+# Twisted trial support.
@hookimpl(hookwrapper=True)
-def pytest_runtest_protocol(item):
+def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
if isinstance(item, TestCaseFunction) and "twisted.trial.unittest" in sys.modules:
- ut = sys.modules["twisted.python.failure"]
+ ut: Any = sys.modules["twisted.python.failure"]
Failure__init__ = ut.Failure.__init__
check_testcase_implements_trial_reporter()
@@ -293,7 +390,7 @@ def pytest_runtest_protocol(item):
yield
-def check_testcase_implements_trial_reporter(done=[]):
+def check_testcase_implements_trial_reporter(done: List[int] = []) -> None:
if done:
return
from zope.interface import classImplements
@@ -304,5 +401,5 @@ def check_testcase_implements_trial_reporter(done=[]):
def _is_skipped(obj) -> bool:
- """Return True if the given object has been marked with @unittest.skip"""
+ """Return True if the given object has been marked with @unittest.skip."""
return bool(getattr(obj, "__unittest_skip__", False))
diff --git a/contrib/python/pytest/py3/_pytest/unraisableexception.py b/contrib/python/pytest/py3/_pytest/unraisableexception.py
new file mode 100644
index 0000000000..fcb5d8237c
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/unraisableexception.py
@@ -0,0 +1,93 @@
+import sys
+import traceback
+import warnings
+from types import TracebackType
+from typing import Any
+from typing import Callable
+from typing import Generator
+from typing import Optional
+from typing import Type
+
+import pytest
+
+
+# Copied from cpython/Lib/test/support/__init__.py, with modifications.
+class catch_unraisable_exception:
+ """Context manager catching unraisable exception using sys.unraisablehook.
+
+ Storing the exception value (cm.unraisable.exc_value) creates a reference
+ cycle. The reference cycle is broken explicitly when the context manager
+ exits.
+
+ Storing the object (cm.unraisable.object) can resurrect it if it is set to
+ an object which is being finalized. Exiting the context manager clears the
+ stored object.
+
+ Usage:
+ with catch_unraisable_exception() as cm:
+ # code creating an "unraisable exception"
+ ...
+ # check the unraisable exception: use cm.unraisable
+ ...
+ # cm.unraisable attribute no longer exists at this point
+ # (to break a reference cycle)
+ """
+
+ def __init__(self) -> None:
+ self.unraisable: Optional["sys.UnraisableHookArgs"] = None
+ self._old_hook: Optional[Callable[["sys.UnraisableHookArgs"], Any]] = None
+
+ def _hook(self, unraisable: "sys.UnraisableHookArgs") -> None:
+ # Storing unraisable.object can resurrect an object which is being
+ # finalized. Storing unraisable.exc_value creates a reference cycle.
+ self.unraisable = unraisable
+
+ def __enter__(self) -> "catch_unraisable_exception":
+ self._old_hook = sys.unraisablehook
+ sys.unraisablehook = self._hook
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ assert self._old_hook is not None
+ sys.unraisablehook = self._old_hook
+ self._old_hook = None
+ del self.unraisable
+
+
+def unraisable_exception_runtest_hook() -> Generator[None, None, None]:
+ with catch_unraisable_exception() as cm:
+ yield
+ if cm.unraisable:
+ if cm.unraisable.err_msg is not None:
+ err_msg = cm.unraisable.err_msg
+ else:
+ err_msg = "Exception ignored in"
+ msg = f"{err_msg}: {cm.unraisable.object!r}\n\n"
+ msg += "".join(
+ traceback.format_exception(
+ cm.unraisable.exc_type,
+ cm.unraisable.exc_value,
+ cm.unraisable.exc_traceback,
+ )
+ )
+ warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))
+
+
+@pytest.hookimpl(hookwrapper=True, tryfirst=True)
+def pytest_runtest_setup() -> Generator[None, None, None]:
+ yield from unraisable_exception_runtest_hook()
+
+
+@pytest.hookimpl(hookwrapper=True, tryfirst=True)
+def pytest_runtest_call() -> Generator[None, None, None]:
+ yield from unraisable_exception_runtest_hook()
+
+
+@pytest.hookimpl(hookwrapper=True, tryfirst=True)
+def pytest_runtest_teardown() -> Generator[None, None, None]:
+ yield from unraisable_exception_runtest_hook()
diff --git a/contrib/python/pytest/py3/_pytest/warning_types.py b/contrib/python/pytest/py3/_pytest/warning_types.py
index 2e03c578c0..2eadd9fe4d 100644
--- a/contrib/python/pytest/py3/_pytest/warning_types.py
+++ b/contrib/python/pytest/py3/_pytest/warning_types.py
@@ -1,81 +1,60 @@
from typing import Any
from typing import Generic
+from typing import Type
from typing import TypeVar
import attr
-from _pytest.compat import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from typing import Type # noqa: F401 (used in type string)
+from _pytest.compat import final
class PytestWarning(UserWarning):
- """
- Bases: :class:`UserWarning`.
-
- Base class for all warnings emitted by pytest.
- """
+ """Base class for all warnings emitted by pytest."""
__module__ = "pytest"
+@final
class PytestAssertRewriteWarning(PytestWarning):
- """
- Bases: :class:`PytestWarning`.
-
- Warning emitted by the pytest assert rewrite module.
- """
+ """Warning emitted by the pytest assert rewrite module."""
__module__ = "pytest"
+@final
class PytestCacheWarning(PytestWarning):
- """
- Bases: :class:`PytestWarning`.
-
- Warning emitted by the cache plugin in various situations.
- """
+ """Warning emitted by the cache plugin in various situations."""
__module__ = "pytest"
+@final
class PytestConfigWarning(PytestWarning):
- """
- Bases: :class:`PytestWarning`.
-
- Warning emitted for configuration issues.
- """
+ """Warning emitted for configuration issues."""
__module__ = "pytest"
+@final
class PytestCollectionWarning(PytestWarning):
- """
- Bases: :class:`PytestWarning`.
-
- Warning emitted when pytest is not able to collect a file or symbol in a module.
- """
+ """Warning emitted when pytest is not able to collect a file or symbol in a module."""
__module__ = "pytest"
+@final
class PytestDeprecationWarning(PytestWarning, DeprecationWarning):
- """
- Bases: :class:`pytest.PytestWarning`, :class:`DeprecationWarning`.
-
- Warning class for features that will be removed in a future version.
- """
+ """Warning class for features that will be removed in a future version."""
__module__ = "pytest"
+@final
class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
- """
- Bases: :class:`pytest.PytestWarning`, :class:`FutureWarning`.
+ """Warning category used to denote experiments in pytest.
- Warning category used to denote experiments in pytest. Use sparingly as the API might change or even be
- removed completely in future version
+ Use sparingly as the API might change or even be removed completely in a
+ future version.
"""
__module__ = "pytest"
@@ -89,24 +68,45 @@ class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
)
+@final
class PytestUnhandledCoroutineWarning(PytestWarning):
- """
- Bases: :class:`PytestWarning`.
+ """Warning emitted for an unhandled coroutine.
- Warning emitted when pytest encounters a test function which is a coroutine,
- but it was not handled by any async-aware plugin. Coroutine test functions
- are not natively supported.
+ A coroutine was encountered when collecting test functions, but was not
+ handled by any async-aware plugin.
+ Coroutine test functions are not natively supported.
"""
__module__ = "pytest"
+@final
class PytestUnknownMarkWarning(PytestWarning):
+ """Warning emitted on use of unknown markers.
+
+ See :ref:`mark` for details.
"""
- Bases: :class:`PytestWarning`.
- Warning emitted on use of unknown markers.
- See https://docs.pytest.org/en/latest/mark.html for details.
+ __module__ = "pytest"
+
+
+@final
+class PytestUnraisableExceptionWarning(PytestWarning):
+ """An unraisable exception was reported.
+
+ Unraisable exceptions are exceptions raised in :meth:`__del__ <object.__del__>`
+ implementations and similar situations when the exception cannot be raised
+ as normal.
+ """
+
+ __module__ = "pytest"
+
+
+@final
+class PytestUnhandledThreadExceptionWarning(PytestWarning):
+ """An unhandled exception occurred in a :class:`~threading.Thread`.
+
+ Such exceptions don't propagate normally.
"""
__module__ = "pytest"
@@ -115,19 +115,18 @@ class PytestUnknownMarkWarning(PytestWarning):
_W = TypeVar("_W", bound=PytestWarning)
+@final
@attr.s
class UnformattedWarning(Generic[_W]):
- """Used to hold warnings that need to format their message at runtime, as opposed to a direct message.
+ """A warning meant to be formatted during runtime.
- Using this class avoids to keep all the warning types and messages in this module, avoiding misuse.
+ This is used to hold warnings that need to format their message at runtime,
+ as opposed to a direct message.
"""
- category = attr.ib(type="Type[_W]")
+ category = attr.ib(type=Type["_W"])
template = attr.ib(type=str)
def format(self, **kwargs: Any) -> _W:
- """Returns an instance of the warning category, formatted with given kwargs"""
+ """Return an instance of the warning category, formatted with given kwargs."""
return self.category(self.template.format(**kwargs))
-
-
-PYTESTER_COPY_EXAMPLE = PytestExperimentalApiWarning.simple("testdir.copy_example")
diff --git a/contrib/python/pytest/py3/_pytest/warnings.py b/contrib/python/pytest/py3/_pytest/warnings.py
index 2a4d189d57..35eed96df5 100644
--- a/contrib/python/pytest/py3/_pytest/warnings.py
+++ b/contrib/python/pytest/py3/_pytest/warnings.py
@@ -2,106 +2,88 @@ import sys
import warnings
from contextlib import contextmanager
from typing import Generator
+from typing import Optional
+from typing import TYPE_CHECKING
import pytest
+from _pytest.config import apply_warning_filters
+from _pytest.config import Config
+from _pytest.config import parse_warning_filter
from _pytest.main import Session
+from _pytest.nodes import Item
+from _pytest.terminal import TerminalReporter
+if TYPE_CHECKING:
+ from typing_extensions import Literal
-def _setoption(wmod, arg):
- """
- Copy of the warning._setoption function but does not escape arguments.
- """
- parts = arg.split(":")
- if len(parts) > 5:
- raise wmod._OptionError("too many fields (max 5): {!r}".format(arg))
- while len(parts) < 5:
- parts.append("")
- action, message, category, module, lineno = [s.strip() for s in parts]
- action = wmod._getaction(action)
- category = wmod._getcategory(category)
- if lineno:
- try:
- lineno = int(lineno)
- if lineno < 0:
- raise ValueError
- except (ValueError, OverflowError):
- raise wmod._OptionError("invalid lineno {!r}".format(lineno))
- else:
- lineno = 0
- wmod.filterwarnings(action, message, category, module, lineno)
-
-
-def pytest_addoption(parser):
- group = parser.getgroup("pytest-warnings")
- group.addoption(
- "-W",
- "--pythonwarnings",
- action="append",
- help="set which warnings to report, see -W option of python itself.",
- )
- parser.addini(
- "filterwarnings",
- type="linelist",
- help="Each line specifies a pattern for "
- "warnings.filterwarnings. "
- "Processed after -W/--pythonwarnings.",
- )
-
-def pytest_configure(config):
+def pytest_configure(config: Config) -> None:
config.addinivalue_line(
"markers",
"filterwarnings(warning): add a warning filter to the given test. "
- "see https://docs.pytest.org/en/latest/warnings.html#pytest-mark-filterwarnings ",
+ "see https://docs.pytest.org/en/stable/warnings.html#pytest-mark-filterwarnings ",
)
@contextmanager
-def catch_warnings_for_item(config, ihook, when, item):
- """
- Context manager that catches warnings generated in the contained execution block.
+def catch_warnings_for_item(
+ config: Config,
+ ihook,
+ when: "Literal['config', 'collect', 'runtest']",
+ item: Optional[Item],
+) -> Generator[None, None, None]:
+ """Context manager that catches warnings generated in the contained execution block.
``item`` can be None if we are not in the context of an item execution.
- Each warning captured triggers the ``pytest_warning_captured`` hook.
+ Each warning captured triggers the ``pytest_warning_recorded`` hook.
"""
- cmdline_filters = config.getoption("pythonwarnings") or []
- inifilters = config.getini("filterwarnings")
+ config_filters = config.getini("filterwarnings")
+ cmdline_filters = config.known_args_namespace.pythonwarnings or []
with warnings.catch_warnings(record=True) as log:
# mypy can't infer that record=True means log is not None; help it.
assert log is not None
if not sys.warnoptions:
- # if user is not explicitly configuring warning filters, show deprecation warnings by default (#2908)
+ # If user is not explicitly configuring warning filters, show deprecation warnings by default (#2908).
warnings.filterwarnings("always", category=DeprecationWarning)
warnings.filterwarnings("always", category=PendingDeprecationWarning)
- # filters should have this precedence: mark, cmdline options, ini
- # filters should be applied in the inverse order of precedence
- for arg in inifilters:
- _setoption(warnings, arg)
-
- for arg in cmdline_filters:
- warnings._setoption(arg)
+ apply_warning_filters(config_filters, cmdline_filters)
+ # apply filters from "filterwarnings" marks
+ nodeid = "" if item is None else item.nodeid
if item is not None:
for mark in item.iter_markers(name="filterwarnings"):
for arg in mark.args:
- _setoption(warnings, arg)
+ warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
yield
for warning_message in log:
ihook.pytest_warning_captured.call_historic(
- kwargs=dict(warning_message=warning_message, when=when, item=item)
+ kwargs=dict(
+ warning_message=warning_message,
+ when=when,
+ item=item,
+ location=None,
+ )
+ )
+ ihook.pytest_warning_recorded.call_historic(
+ kwargs=dict(
+ warning_message=warning_message,
+ nodeid=nodeid,
+ when=when,
+ location=None,
+ )
)
-def warning_record_to_str(warning_message):
+def warning_record_to_str(warning_message: warnings.WarningMessage) -> str:
"""Convert a warnings.WarningMessage to a string."""
warn_msg = warning_message.message
msg = warnings.formatwarning(
- warn_msg,
+ str(warn_msg),
warning_message.category,
warning_message.filename,
warning_message.lineno,
@@ -111,7 +93,7 @@ def warning_record_to_str(warning_message):
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
-def pytest_runtest_protocol(item):
+def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
with catch_warnings_for_item(
config=item.config, ihook=item.ihook, when="runtest", item=item
):
@@ -128,7 +110,9 @@ def pytest_collection(session: Session) -> Generator[None, None, None]:
@pytest.hookimpl(hookwrapper=True)
-def pytest_terminal_summary(terminalreporter):
+def pytest_terminal_summary(
+ terminalreporter: TerminalReporter,
+) -> Generator[None, None, None]:
config = terminalreporter.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
@@ -137,7 +121,7 @@ def pytest_terminal_summary(terminalreporter):
@pytest.hookimpl(hookwrapper=True)
-def pytest_sessionfinish(session):
+def pytest_sessionfinish(session: Session) -> Generator[None, None, None]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
@@ -145,25 +129,11 @@ def pytest_sessionfinish(session):
yield
-def _issue_warning_captured(warning, hook, stacklevel):
- """
- This function should be used instead of calling ``warnings.warn`` directly when we are in the "configure" stage:
- at this point the actual options might not have been set, so we manually trigger the pytest_warning_captured
- hook so we can display these warnings in the terminal. This is a hack until we can sort out #2891.
-
- :param warning: the warning instance.
- :param hook: the hook caller
- :param stacklevel: stacklevel forwarded to warnings.warn
- """
- with warnings.catch_warnings(record=True) as records:
- warnings.simplefilter("always", type(warning))
- warnings.warn(warning, stacklevel=stacklevel)
- # Mypy can't infer that record=True means records is not None; help it.
- assert records is not None
- frame = sys._getframe(stacklevel - 1)
- location = frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name
- hook.pytest_warning_captured.call_historic(
- kwargs=dict(
- warning_message=records[0], when="config", item=None, location=location
- )
- )
+@pytest.hookimpl(hookwrapper=True)
+def pytest_load_initial_conftests(
+ early_config: "Config",
+) -> Generator[None, None, None]:
+ with catch_warnings_for_item(
+ config=early_config, ihook=early_config.hook, when="config", item=None
+ ):
+ yield
diff --git a/contrib/python/pytest/py3/patches/03-limit-id.patch b/contrib/python/pytest/py3/patches/03-limit-id.patch
index ba1c199517..f0a57f8599 100644
--- a/contrib/python/pytest/py3/patches/03-limit-id.patch
+++ b/contrib/python/pytest/py3/patches/03-limit-id.patch
@@ -1,6 +1,6 @@
--- contrib/python/pytest/py3/_pytest/python.py (index)
+++ contrib/python/pytest/py3/_pytest/python.py (working tree)
-@@ -1192,6 +1192,33 @@ def _idval(val, argname, idx, idfn, item, config):
+@@ -1339,6 +1339,33 @@ def _idval(val, argname, idx, idfn, item, config):
return str(argname) + str(idx)
diff --git a/contrib/python/pytest/py3/patches/04-support-cyrillic-id.patch b/contrib/python/pytest/py3/patches/04-support-cyrillic-id.patch
index c45fd08282..c601f183f9 100644
--- a/contrib/python/pytest/py3/patches/04-support-cyrillic-id.patch
+++ b/contrib/python/pytest/py3/patches/04-support-cyrillic-id.patch
@@ -1,6 +1,6 @@
--- contrib/python/pytest/py3/_pytest/compat.py (index)
+++ contrib/python/pytest/py3/_pytest/compat.py (working tree)
-@@ -255,7 +255,7 @@ if _PY3:
+@@ -233,7 +233,7 @@ if _PY3:
if isinstance(val, bytes):
ret = _bytes_to_ascii(val)
else:
diff --git a/contrib/python/pytest/py3/patches/05-support-readline.patch b/contrib/python/pytest/py3/patches/05-support-readline.patch
index 11d6b94f0d..0f77979192 100644
--- a/contrib/python/pytest/py3/patches/05-support-readline.patch
+++ b/contrib/python/pytest/py3/patches/05-support-readline.patch
@@ -1,15 +1,15 @@
--- contrib/python/pytest/py3/_pytest/debugging.py (index)
+++ contrib/python/pytest/py3/_pytest/debugging.py (working tree)
@@ -1,6 +1,7 @@ from __future__ import absolute_import
- """ interactive debugging with PDB, the Python Debugger. """
+ """Interactive debugging with PDB, the Python Debugger."""
import argparse
import functools
+import os
import sys
-
- from _pytest import outcomes
-@@ -9,6 +10,42 @@ from _pytest import outcomes
- from _pytest.config.exceptions import UsageError
+ import types
+ from typing import Any
+@@ -29,6 +30,42 @@ from _pytest import outcomes
+ from _pytest.runner import CallInfo
+def import_readline():
@@ -48,10 +48,10 @@
+ sys.path = old_sys_path
+
+
- def _validate_usepdb_cls(value):
+ def _validate_usepdb_cls(value: str) -> Tuple[str, str]:
"""Validate syntax of --pdbcls option."""
try:
-@@ -250,6 +287,7 @@ class pytestPDB(object):
+@@ -277,6 +314,7 @@ class pytestPDB(object):
@classmethod
def set_trace(cls, *args, **kwargs) -> None:
"""Invoke debugging via ``Pdb.set_trace``, dropping any IO capturing."""
@@ -64,6 +64,6 @@
sys.stdout.write(out)
sys.stdout.write(err)
+ tty()
+ assert call.excinfo is not None
_enter_pdb(node, call.excinfo, report)
- def pytest_internalerror(self, excrepr, excinfo):
diff --git a/contrib/python/pytest/py3/patches/06-support-ya-markers.patch b/contrib/python/pytest/py3/patches/06-support-ya-markers.patch
index 373d89c866..0726477fe8 100644
--- a/contrib/python/pytest/py3/patches/06-support-ya-markers.patch
+++ b/contrib/python/pytest/py3/patches/06-support-ya-markers.patch
@@ -1,6 +1,6 @@
--- contrib/python/pytest/py3/_pytest/mark/structures.py (index)
+++ contrib/python/pytest/py3/_pytest/mark/structures.py (working tree)
-@@ -326,7 +326,10 @@ class MarkGenerator(object):
+@@ -490,7 +490,10 @@ class MarkGenerator(object):
# example lines: "skipif(condition): skip the given test if..."
# or "hypothesis: tests which use Hypothesis", so to get the
# marker name we split on both `:` and `(`.
diff --git a/contrib/python/pytest/py3/patches/07-disable-translate-non-printable.patch b/contrib/python/pytest/py3/patches/07-disable-translate-non-printable.patch
index 1d0b80c6d4..dd81a883e8 100644
--- a/contrib/python/pytest/py3/patches/07-disable-translate-non-printable.patch
+++ b/contrib/python/pytest/py3/patches/07-disable-translate-non-printable.patch
@@ -1,6 +1,6 @@
--- contrib/python/pytest/py3/_pytest/compat.py (index)
+++ contrib/python/pytest/py3/_pytest/compat.py (working tree)
-@@ -256,7 +256,7 @@ if _PY3:
+@@ -234,7 +234,7 @@ if _PY3:
ret = _bytes_to_ascii(val)
else:
ret = val
diff --git a/contrib/python/pytest/py3/pytest/__init__.py b/contrib/python/pytest/py3/pytest/__init__.py
index 33bc3d0fbe..70177f9504 100644
--- a/contrib/python/pytest/py3/pytest/__init__.py
+++ b/contrib/python/pytest/py3/pytest/__init__.py
@@ -1,24 +1,29 @@
# PYTHON_ARGCOMPLETE_OK
-"""
-pytest: unit and functional testing with Python.
-"""
+"""pytest: unit and functional testing with Python."""
+from . import collect
from _pytest import __version__
from _pytest.assertion import register_assert_rewrite
-from _pytest.compat import _setup_collect_fakemodule
+from _pytest.cacheprovider import Cache
+from _pytest.capture import CaptureFixture
from _pytest.config import cmdline
+from _pytest.config import console_main
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import hookspec
from _pytest.config import main
from _pytest.config import UsageError
from _pytest.debugging import pytestPDB as __pytestPDB
-from _pytest.fixtures import fillfixtures as _fillfuncargs
+from _pytest.fixtures import _fillfuncargs
from _pytest.fixtures import fixture
+from _pytest.fixtures import FixtureLookupError
+from _pytest.fixtures import FixtureRequest
from _pytest.fixtures import yield_fixture
from _pytest.freeze_support import freeze_includes
+from _pytest.logging import LogCaptureFixture
from _pytest.main import Session
from _pytest.mark import MARK_GEN as mark
from _pytest.mark import param
+from _pytest.monkeypatch import MonkeyPatch
from _pytest.nodes import Collector
from _pytest.nodes import File
from _pytest.nodes import Item
@@ -27,6 +32,8 @@ from _pytest.outcomes import fail
from _pytest.outcomes import importorskip
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
+from _pytest.pytester import Pytester
+from _pytest.pytester import Testdir
from _pytest.python import Class
from _pytest.python import Function
from _pytest.python import Instance
@@ -35,7 +42,10 @@ from _pytest.python import Package
from _pytest.python_api import approx
from _pytest.python_api import raises
from _pytest.recwarn import deprecated_call
+from _pytest.recwarn import WarningsRecorder
from _pytest.recwarn import warns
+from _pytest.tmpdir import TempdirFactory
+from _pytest.tmpdir import TempPathFactory
from _pytest.warning_types import PytestAssertRewriteWarning
from _pytest.warning_types import PytestCacheWarning
from _pytest.warning_types import PytestCollectionWarning
@@ -43,25 +53,32 @@ from _pytest.warning_types import PytestConfigWarning
from _pytest.warning_types import PytestDeprecationWarning
from _pytest.warning_types import PytestExperimentalApiWarning
from _pytest.warning_types import PytestUnhandledCoroutineWarning
+from _pytest.warning_types import PytestUnhandledThreadExceptionWarning
from _pytest.warning_types import PytestUnknownMarkWarning
+from _pytest.warning_types import PytestUnraisableExceptionWarning
from _pytest.warning_types import PytestWarning
-
set_trace = __pytestPDB.set_trace
__all__ = [
"__version__",
"_fillfuncargs",
"approx",
+ "Cache",
+ "CaptureFixture",
"Class",
"cmdline",
+ "collect",
"Collector",
+ "console_main",
"deprecated_call",
"exit",
"ExitCode",
"fail",
"File",
"fixture",
+ "FixtureLookupError",
+ "FixtureRequest",
"freeze_includes",
"Function",
"hookimpl",
@@ -69,9 +86,11 @@ __all__ = [
"importorskip",
"Instance",
"Item",
+ "LogCaptureFixture",
"main",
"mark",
"Module",
+ "MonkeyPatch",
"Package",
"param",
"PytestAssertRewriteWarning",
@@ -80,20 +99,23 @@ __all__ = [
"PytestConfigWarning",
"PytestDeprecationWarning",
"PytestExperimentalApiWarning",
+ "Pytester",
"PytestUnhandledCoroutineWarning",
+ "PytestUnhandledThreadExceptionWarning",
"PytestUnknownMarkWarning",
+ "PytestUnraisableExceptionWarning",
"PytestWarning",
"raises",
"register_assert_rewrite",
"Session",
"set_trace",
"skip",
+ "TempPathFactory",
+ "Testdir",
+ "TempdirFactory",
"UsageError",
+ "WarningsRecorder",
"warns",
"xfail",
"yield_fixture",
]
-
-
-_setup_collect_fakemodule()
-del _setup_collect_fakemodule
diff --git a/contrib/python/pytest/py3/pytest/__main__.py b/contrib/python/pytest/py3/pytest/__main__.py
index 01b2f6ccfe..b170152937 100644
--- a/contrib/python/pytest/py3/pytest/__main__.py
+++ b/contrib/python/pytest/py3/pytest/__main__.py
@@ -1,7 +1,5 @@
-"""
-pytest entry point
-"""
+"""The pytest entry point."""
import pytest
if __name__ == "__main__":
- raise SystemExit(pytest.main())
+ raise SystemExit(pytest.console_main())
diff --git a/contrib/python/pytest/py3/pytest/collect.py b/contrib/python/pytest/py3/pytest/collect.py
new file mode 100644
index 0000000000..2edf4470f4
--- /dev/null
+++ b/contrib/python/pytest/py3/pytest/collect.py
@@ -0,0 +1,39 @@
+import sys
+import warnings
+from types import ModuleType
+from typing import Any
+from typing import List
+
+import pytest
+from _pytest.deprecated import PYTEST_COLLECT_MODULE
+
+COLLECT_FAKEMODULE_ATTRIBUTES = [
+ "Collector",
+ "Module",
+ "Function",
+ "Instance",
+ "Session",
+ "Item",
+ "Class",
+ "File",
+ "_fillfuncargs",
+]
+
+
+class FakeCollectModule(ModuleType):
+ def __init__(self) -> None:
+ super().__init__("pytest.collect")
+ self.__all__ = list(COLLECT_FAKEMODULE_ATTRIBUTES)
+ self.__pytest = pytest
+
+ def __dir__(self) -> List[str]:
+ return dir(super()) + self.__all__
+
+ def __getattr__(self, name: str) -> Any:
+ if name not in self.__all__:
+ raise AttributeError(name)
+ warnings.warn(PYTEST_COLLECT_MODULE.format(name=name), stacklevel=2)
+ return getattr(pytest, name)
+
+
+sys.modules["pytest.collect"] = FakeCollectModule()
diff --git a/contrib/python/more-itertools/py3/tests/__init__.py b/contrib/python/pytest/py3/pytest/py.typed
index e69de29bb2..e69de29bb2 100644
--- a/contrib/python/more-itertools/py3/tests/__init__.py
+++ b/contrib/python/pytest/py3/pytest/py.typed
diff --git a/contrib/python/pytest/py3/ya.make b/contrib/python/pytest/py3/ya.make
index 88de1914cb..1d9a6034e4 100644
--- a/contrib/python/pytest/py3/ya.make
+++ b/contrib/python/pytest/py3/ya.make
@@ -2,17 +2,17 @@ PY3_LIBRARY()
OWNER(dmitko g:python-contrib)
-VERSION(5.4.3)
+VERSION(6.2.5)
LICENSE(MIT)
PEERDIR(
contrib/python/attrs
- contrib/python/more-itertools
+ contrib/python/iniconfig
contrib/python/packaging
contrib/python/pluggy
contrib/python/py
- contrib/python/wcwidth
+ contrib/python/toml
)
IF (OS_WINDOWS)
@@ -37,6 +37,8 @@ PY_SRCS(
_pytest/_code/source.py
_pytest/_io/__init__.py
_pytest/_io/saferepr.py
+ _pytest/_io/terminalwriter.py
+ _pytest/_io/wcwidth.py
_pytest/_version.py
_pytest/assertion/__init__.py
_pytest/assertion/rewrite.py
@@ -61,8 +63,7 @@ PY_SRCS(
_pytest/logging.py
_pytest/main.py
_pytest/mark/__init__.py
- _pytest/mark/evaluate.py
- _pytest/mark/legacy.py
+ _pytest/mark/expression.py
_pytest/mark/structures.py
_pytest/monkeypatch.py
_pytest/nodes.py
@@ -71,11 +72,11 @@ PY_SRCS(
_pytest/pastebin.py
_pytest/pathlib.py
_pytest/pytester.py
+ _pytest/pytester_assertions.py
_pytest/python.py
_pytest/python_api.py
_pytest/recwarn.py
_pytest/reports.py
- _pytest/resultlog.py
_pytest/runner.py
_pytest/setuponly.py
_pytest/setupplan.py
@@ -83,12 +84,16 @@ PY_SRCS(
_pytest/stepwise.py
_pytest/store.py
_pytest/terminal.py
+ _pytest/threadexception.py
+ _pytest/timing.py
_pytest/tmpdir.py
_pytest/unittest.py
+ _pytest/unraisableexception.py
_pytest/warning_types.py
_pytest/warnings.py
pytest/__init__.py
pytest/__main__.py
+ pytest/collect.py
)
RESOURCE_FILES(
@@ -96,6 +101,8 @@ RESOURCE_FILES(
.dist-info/METADATA
.dist-info/entry_points.txt
.dist-info/top_level.txt
+ _pytest/py.typed
+ pytest/py.typed
)
END()
diff --git a/contrib/python/toml/.dist-info/METADATA b/contrib/python/toml/.dist-info/METADATA
new file mode 100644
index 0000000000..6f2635ce4d
--- /dev/null
+++ b/contrib/python/toml/.dist-info/METADATA
@@ -0,0 +1,255 @@
+Metadata-Version: 2.1
+Name: toml
+Version: 0.10.2
+Summary: Python Library for Tom's Obvious, Minimal Language
+Home-page: https://github.com/uiri/toml
+Author: William Pearson
+Author-email: uiri@xqz.ca
+License: MIT
+Platform: UNKNOWN
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 2.6
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.3
+Classifier: Programming Language :: Python :: 3.4
+Classifier: Programming Language :: Python :: 3.5
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Requires-Python: >=2.6, !=3.0.*, !=3.1.*, !=3.2.*
+
+****
+TOML
+****
+
+.. image:: https://img.shields.io/pypi/v/toml
+ :target: https://pypi.org/project/toml/
+
+.. image:: https://travis-ci.org/uiri/toml.svg?branch=master
+ :target: https://travis-ci.org/uiri/toml
+
+.. image:: https://img.shields.io/pypi/pyversions/toml.svg
+ :target: https://pypi.org/project/toml/
+
+
+A Python library for parsing and creating `TOML <https://en.wikipedia.org/wiki/TOML>`_.
+
+The module passes `the TOML test suite <https://github.com/BurntSushi/toml-test>`_.
+
+See also:
+
+* `The TOML Standard <https://github.com/toml-lang/toml>`_
+* `The currently supported TOML specification <https://github.com/toml-lang/toml/blob/v0.5.0/README.md>`_
+
+Installation
+============
+
+To install the latest release on `PyPI <https://pypi.org/project/toml/>`_,
+simply run:
+
+::
+
+ pip install toml
+
+Or to install the latest development version, run:
+
+::
+
+ git clone https://github.com/uiri/toml.git
+ cd toml
+ python setup.py install
+
+Quick Tutorial
+==============
+
+*toml.loads* takes in a string containing standard TOML-formatted data and
+returns a dictionary containing the parsed data.
+
+.. code:: pycon
+
+ >>> import toml
+ >>> toml_string = """
+ ... # This is a TOML document.
+ ...
+ ... title = "TOML Example"
+ ...
+ ... [owner]
+ ... name = "Tom Preston-Werner"
+ ... dob = 1979-05-27T07:32:00-08:00 # First class dates
+ ...
+ ... [database]
+ ... server = "192.168.1.1"
+ ... ports = [ 8001, 8001, 8002 ]
+ ... connection_max = 5000
+ ... enabled = true
+ ...
+ ... [servers]
+ ...
+ ... # Indentation (tabs and/or spaces) is allowed but not required
+ ... [servers.alpha]
+ ... ip = "10.0.0.1"
+ ... dc = "eqdc10"
+ ...
+ ... [servers.beta]
+ ... ip = "10.0.0.2"
+ ... dc = "eqdc10"
+ ...
+ ... [clients]
+ ... data = [ ["gamma", "delta"], [1, 2] ]
+ ...
+ ... # Line breaks are OK when inside arrays
+ ... hosts = [
+ ... "alpha",
+ ... "omega"
+ ... ]
+ ... """
+ >>> parsed_toml = toml.loads(toml_string)
+
+
+*toml.dumps* takes a dictionary and returns a string containing the
+corresponding TOML-formatted data.
+
+.. code:: pycon
+
+ >>> new_toml_string = toml.dumps(parsed_toml)
+ >>> print(new_toml_string)
+ title = "TOML Example"
+ [owner]
+ name = "Tom Preston-Werner"
+ dob = 1979-05-27T07:32:00Z
+ [database]
+ server = "192.168.1.1"
+ ports = [ 8001, 8001, 8002,]
+ connection_max = 5000
+ enabled = true
+ [clients]
+ data = [ [ "gamma", "delta",], [ 1, 2,],]
+ hosts = [ "alpha", "omega",]
+ [servers.alpha]
+ ip = "10.0.0.1"
+ dc = "eqdc10"
+ [servers.beta]
+ ip = "10.0.0.2"
+ dc = "eqdc10"
+
+*toml.dump* takes a dictionary and a file descriptor and returns a string containing the
+corresponding TOML-formatted data.
+
+.. code:: pycon
+
+ >>> with open('new_toml_file.toml', 'w') as f:
+ ... new_toml_string = toml.dump(parsed_toml, f)
+ >>> print(new_toml_string)
+ title = "TOML Example"
+ [owner]
+ name = "Tom Preston-Werner"
+ dob = 1979-05-27T07:32:00Z
+ [database]
+ server = "192.168.1.1"
+ ports = [ 8001, 8001, 8002,]
+ connection_max = 5000
+ enabled = true
+ [clients]
+ data = [ [ "gamma", "delta",], [ 1, 2,],]
+ hosts = [ "alpha", "omega",]
+ [servers.alpha]
+ ip = "10.0.0.1"
+ dc = "eqdc10"
+ [servers.beta]
+ ip = "10.0.0.2"
+ dc = "eqdc10"
+
+For more functions, view the API Reference below.
+
+Note
+----
+
+For Numpy users, by default the data types ``np.floatX`` will not be translated to floats by toml, but will instead be encoded as strings. To get around this, specify the ``TomlNumpyEncoder`` when saving your data.
+
+.. code:: pycon
+
+ >>> import toml
+ >>> import numpy as np
+ >>> a = np.arange(0, 10, dtype=np.double)
+ >>> output = {'a': a}
+ >>> toml.dumps(output)
+ 'a = [ "0.0", "1.0", "2.0", "3.0", "4.0", "5.0", "6.0", "7.0", "8.0", "9.0",]\n'
+ >>> toml.dumps(output, encoder=toml.TomlNumpyEncoder())
+ 'a = [ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,]\n'
+
+API Reference
+=============
+
+``toml.load(f, _dict=dict)``
+ Parse a file or a list of files as TOML and return a dictionary.
+
+ :Args:
+ * ``f``: A path to a file, list of filepaths (to be read into single
+ object) or a file descriptor
+ * ``_dict``: The class of the dictionary object to be returned
+
+ :Returns:
+ A dictionary (or object ``_dict``) containing parsed TOML data
+
+ :Raises:
+ * ``TypeError``: When ``f`` is an invalid type or is a list containing
+ invalid types
+ * ``TomlDecodeError``: When an error occurs while decoding the file(s)
+
+``toml.loads(s, _dict=dict)``
+ Parse a TOML-formatted string to a dictionary.
+
+ :Args:
+ * ``s``: The TOML-formatted string to be parsed
+ * ``_dict``: Specifies the class of the returned toml dictionary
+
+ :Returns:
+ A dictionary (or object ``_dict``) containing parsed TOML data
+
+ :Raises:
+ * ``TypeError``: When a non-string object is passed
+ * ``TomlDecodeError``: When an error occurs while decoding the
+ TOML-formatted string
+
+``toml.dump(o, f, encoder=None)``
+ Write a dictionary to a file containing TOML-formatted data
+
+ :Args:
+ * ``o``: An object to be converted into TOML
+ * ``f``: A File descriptor where the TOML-formatted output should be stored
+ * ``encoder``: An instance of ``TomlEncoder`` (or subclass) for encoding the object. If ``None``, will default to ``TomlEncoder``
+
+ :Returns:
+ A string containing the TOML-formatted data corresponding to object ``o``
+
+ :Raises:
+ * ``TypeError``: When anything other than file descriptor is passed
+
+``toml.dumps(o, encoder=None)``
+ Create a TOML-formatted string from an input object
+
+ :Args:
+ * ``o``: An object to be converted into TOML
+ * ``encoder``: An instance of ``TomlEncoder`` (or subclass) for encoding the object. If ``None``, will default to ``TomlEncoder``
+
+ :Returns:
+ A string containing the TOML-formatted data corresponding to object ``o``
+
+
+
+Licensing
+=========
+
+This project is released under the terms of the MIT Open Source License. View
+*LICENSE.txt* for more information.
+
+
diff --git a/contrib/python/toml/.dist-info/top_level.txt b/contrib/python/toml/.dist-info/top_level.txt
new file mode 100644
index 0000000000..bd79a658fe
--- /dev/null
+++ b/contrib/python/toml/.dist-info/top_level.txt
@@ -0,0 +1 @@
+toml
diff --git a/contrib/python/toml/LICENSE b/contrib/python/toml/LICENSE
new file mode 100644
index 0000000000..5010e3075e
--- /dev/null
+++ b/contrib/python/toml/LICENSE
@@ -0,0 +1,27 @@
+The MIT License
+
+Copyright 2013-2019 William Pearson
+Copyright 2015-2016 Julien Enselme
+Copyright 2016 Google Inc.
+Copyright 2017 Samuel Vasko
+Copyright 2017 Nate Prewitt
+Copyright 2017 Jack Evans
+Copyright 2019 Filippo Broggini
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE. \ No newline at end of file
diff --git a/contrib/python/toml/README.rst b/contrib/python/toml/README.rst
new file mode 100644
index 0000000000..b65ae72ef4
--- /dev/null
+++ b/contrib/python/toml/README.rst
@@ -0,0 +1,224 @@
+****
+TOML
+****
+
+.. image:: https://img.shields.io/pypi/v/toml
+ :target: https://pypi.org/project/toml/
+
+.. image:: https://travis-ci.org/uiri/toml.svg?branch=master
+ :target: https://travis-ci.org/uiri/toml
+
+.. image:: https://img.shields.io/pypi/pyversions/toml.svg
+ :target: https://pypi.org/project/toml/
+
+
+A Python library for parsing and creating `TOML <https://en.wikipedia.org/wiki/TOML>`_.
+
+The module passes `the TOML test suite <https://github.com/BurntSushi/toml-test>`_.
+
+See also:
+
+* `The TOML Standard <https://github.com/toml-lang/toml>`_
+* `The currently supported TOML specification <https://github.com/toml-lang/toml/blob/v0.5.0/README.md>`_
+
+Installation
+============
+
+To install the latest release on `PyPI <https://pypi.org/project/toml/>`_,
+simply run:
+
+::
+
+ pip install toml
+
+Or to install the latest development version, run:
+
+::
+
+ git clone https://github.com/uiri/toml.git
+ cd toml
+ python setup.py install
+
+Quick Tutorial
+==============
+
+*toml.loads* takes in a string containing standard TOML-formatted data and
+returns a dictionary containing the parsed data.
+
+.. code:: pycon
+
+ >>> import toml
+ >>> toml_string = """
+ ... # This is a TOML document.
+ ...
+ ... title = "TOML Example"
+ ...
+ ... [owner]
+ ... name = "Tom Preston-Werner"
+ ... dob = 1979-05-27T07:32:00-08:00 # First class dates
+ ...
+ ... [database]
+ ... server = "192.168.1.1"
+ ... ports = [ 8001, 8001, 8002 ]
+ ... connection_max = 5000
+ ... enabled = true
+ ...
+ ... [servers]
+ ...
+ ... # Indentation (tabs and/or spaces) is allowed but not required
+ ... [servers.alpha]
+ ... ip = "10.0.0.1"
+ ... dc = "eqdc10"
+ ...
+ ... [servers.beta]
+ ... ip = "10.0.0.2"
+ ... dc = "eqdc10"
+ ...
+ ... [clients]
+ ... data = [ ["gamma", "delta"], [1, 2] ]
+ ...
+ ... # Line breaks are OK when inside arrays
+ ... hosts = [
+ ... "alpha",
+ ... "omega"
+ ... ]
+ ... """
+ >>> parsed_toml = toml.loads(toml_string)
+
+
+*toml.dumps* takes a dictionary and returns a string containing the
+corresponding TOML-formatted data.
+
+.. code:: pycon
+
+ >>> new_toml_string = toml.dumps(parsed_toml)
+ >>> print(new_toml_string)
+ title = "TOML Example"
+ [owner]
+ name = "Tom Preston-Werner"
+ dob = 1979-05-27T07:32:00Z
+ [database]
+ server = "192.168.1.1"
+ ports = [ 8001, 8001, 8002,]
+ connection_max = 5000
+ enabled = true
+ [clients]
+ data = [ [ "gamma", "delta",], [ 1, 2,],]
+ hosts = [ "alpha", "omega",]
+ [servers.alpha]
+ ip = "10.0.0.1"
+ dc = "eqdc10"
+ [servers.beta]
+ ip = "10.0.0.2"
+ dc = "eqdc10"
+
+*toml.dump* takes a dictionary and a file descriptor and returns a string containing the
+corresponding TOML-formatted data.
+
+.. code:: pycon
+
+ >>> with open('new_toml_file.toml', 'w') as f:
+ ... new_toml_string = toml.dump(parsed_toml, f)
+ >>> print(new_toml_string)
+ title = "TOML Example"
+ [owner]
+ name = "Tom Preston-Werner"
+ dob = 1979-05-27T07:32:00Z
+ [database]
+ server = "192.168.1.1"
+ ports = [ 8001, 8001, 8002,]
+ connection_max = 5000
+ enabled = true
+ [clients]
+ data = [ [ "gamma", "delta",], [ 1, 2,],]
+ hosts = [ "alpha", "omega",]
+ [servers.alpha]
+ ip = "10.0.0.1"
+ dc = "eqdc10"
+ [servers.beta]
+ ip = "10.0.0.2"
+ dc = "eqdc10"
+
+For more functions, view the API Reference below.
+
+Note
+----
+
+For Numpy users, by default the data types ``np.floatX`` will not be translated to floats by toml, but will instead be encoded as strings. To get around this, specify the ``TomlNumpyEncoder`` when saving your data.
+
+.. code:: pycon
+
+ >>> import toml
+ >>> import numpy as np
+ >>> a = np.arange(0, 10, dtype=np.double)
+ >>> output = {'a': a}
+ >>> toml.dumps(output)
+ 'a = [ "0.0", "1.0", "2.0", "3.0", "4.0", "5.0", "6.0", "7.0", "8.0", "9.0",]\n'
+ >>> toml.dumps(output, encoder=toml.TomlNumpyEncoder())
+ 'a = [ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,]\n'
+
+API Reference
+=============
+
+``toml.load(f, _dict=dict)``
+ Parse a file or a list of files as TOML and return a dictionary.
+
+ :Args:
+ * ``f``: A path to a file, list of filepaths (to be read into single
+ object) or a file descriptor
+ * ``_dict``: The class of the dictionary object to be returned
+
+ :Returns:
+ A dictionary (or object ``_dict``) containing parsed TOML data
+
+ :Raises:
+ * ``TypeError``: When ``f`` is an invalid type or is a list containing
+ invalid types
+ * ``TomlDecodeError``: When an error occurs while decoding the file(s)
+
+``toml.loads(s, _dict=dict)``
+ Parse a TOML-formatted string to a dictionary.
+
+ :Args:
+ * ``s``: The TOML-formatted string to be parsed
+ * ``_dict``: Specifies the class of the returned toml dictionary
+
+ :Returns:
+ A dictionary (or object ``_dict``) containing parsed TOML data
+
+ :Raises:
+ * ``TypeError``: When a non-string object is passed
+ * ``TomlDecodeError``: When an error occurs while decoding the
+ TOML-formatted string
+
+``toml.dump(o, f, encoder=None)``
+ Write a dictionary to a file containing TOML-formatted data
+
+ :Args:
+ * ``o``: An object to be converted into TOML
+ * ``f``: A File descriptor where the TOML-formatted output should be stored
+ * ``encoder``: An instance of ``TomlEncoder`` (or subclass) for encoding the object. If ``None``, will default to ``TomlEncoder``
+
+ :Returns:
+ A string containing the TOML-formatted data corresponding to object ``o``
+
+ :Raises:
+ * ``TypeError``: When anything other than file descriptor is passed
+
+``toml.dumps(o, encoder=None)``
+ Create a TOML-formatted string from an input object
+
+ :Args:
+ * ``o``: An object to be converted into TOML
+ * ``encoder``: An instance of ``TomlEncoder`` (or subclass) for encoding the object. If ``None``, will default to ``TomlEncoder``
+
+ :Returns:
+ A string containing the TOML-formatted data corresponding to object ``o``
+
+
+
+Licensing
+=========
+
+This project is released under the terms of the MIT Open Source License. View
+*LICENSE.txt* for more information.
diff --git a/contrib/python/toml/toml/__init__.py b/contrib/python/toml/toml/__init__.py
new file mode 100644
index 0000000000..7719ac23a7
--- /dev/null
+++ b/contrib/python/toml/toml/__init__.py
@@ -0,0 +1,25 @@
+"""Python module which parses and emits TOML.
+
+Released under the MIT license.
+"""
+
+from toml import encoder
+from toml import decoder
+
+__version__ = "0.10.2"
+_spec_ = "0.5.0"
+
+load = decoder.load
+loads = decoder.loads
+TomlDecoder = decoder.TomlDecoder
+TomlDecodeError = decoder.TomlDecodeError
+TomlPreserveCommentDecoder = decoder.TomlPreserveCommentDecoder
+
+dump = encoder.dump
+dumps = encoder.dumps
+TomlEncoder = encoder.TomlEncoder
+TomlArraySeparatorEncoder = encoder.TomlArraySeparatorEncoder
+TomlPreserveInlineDictEncoder = encoder.TomlPreserveInlineDictEncoder
+TomlNumpyEncoder = encoder.TomlNumpyEncoder
+TomlPreserveCommentEncoder = encoder.TomlPreserveCommentEncoder
+TomlPathlibEncoder = encoder.TomlPathlibEncoder
diff --git a/contrib/python/toml/toml/__init__.pyi b/contrib/python/toml/toml/__init__.pyi
new file mode 100644
index 0000000000..94c20f449c
--- /dev/null
+++ b/contrib/python/toml/toml/__init__.pyi
@@ -0,0 +1,15 @@
+from toml import decoder as decoder, encoder as encoder
+
+load = decoder.load
+loads = decoder.loads
+TomlDecoder = decoder.TomlDecoder
+TomlDecodeError = decoder.TomlDecodeError
+TomlPreserveCommentDecoder = decoder.TomlPreserveCommentDecoder
+dump = encoder.dump
+dumps = encoder.dumps
+TomlEncoder = encoder.TomlEncoder
+TomlArraySeparatorEncoder = encoder.TomlArraySeparatorEncoder
+TomlPreserveInlineDictEncoder = encoder.TomlPreserveInlineDictEncoder
+TomlNumpyEncoder = encoder.TomlNumpyEncoder
+TomlPreserveCommentEncoder = encoder.TomlPreserveCommentEncoder
+TomlPathlibEncoder = encoder.TomlPathlibEncoder
diff --git a/contrib/python/toml/toml/decoder.py b/contrib/python/toml/toml/decoder.py
new file mode 100644
index 0000000000..bf400e9761
--- /dev/null
+++ b/contrib/python/toml/toml/decoder.py
@@ -0,0 +1,1057 @@
+import datetime
+import io
+from os import linesep
+import re
+import sys
+
+from toml.tz import TomlTz
+
+if sys.version_info < (3,):
+ _range = xrange # noqa: F821
+else:
+ unicode = str
+ _range = range
+ basestring = str
+ unichr = chr
+
+
+def _detect_pathlib_path(p):
+ if (3, 4) <= sys.version_info:
+ import pathlib
+ if isinstance(p, pathlib.PurePath):
+ return True
+ return False
+
+
+def _ispath(p):
+ if isinstance(p, (bytes, basestring)):
+ return True
+ return _detect_pathlib_path(p)
+
+
+def _getpath(p):
+ if (3, 6) <= sys.version_info:
+ import os
+ return os.fspath(p)
+ if _detect_pathlib_path(p):
+ return str(p)
+ return p
+
+
+try:
+ FNFError = FileNotFoundError
+except NameError:
+ FNFError = IOError
+
+
+TIME_RE = re.compile(r"([0-9]{2}):([0-9]{2}):([0-9]{2})(\.([0-9]{3,6}))?")
+
+
+class TomlDecodeError(ValueError):
+ """Base toml Exception / Error."""
+
+ def __init__(self, msg, doc, pos):
+ lineno = doc.count('\n', 0, pos) + 1
+ colno = pos - doc.rfind('\n', 0, pos)
+ emsg = '{} (line {} column {} char {})'.format(msg, lineno, colno, pos)
+ ValueError.__init__(self, emsg)
+ self.msg = msg
+ self.doc = doc
+ self.pos = pos
+ self.lineno = lineno
+ self.colno = colno
+
+
+# Matches a TOML number, which allows underscores for readability
+_number_with_underscores = re.compile('([0-9])(_([0-9]))*')
+
+
+class CommentValue(object):
+ def __init__(self, val, comment, beginline, _dict):
+ self.val = val
+ separator = "\n" if beginline else " "
+ self.comment = separator + comment
+ self._dict = _dict
+
+ def __getitem__(self, key):
+ return self.val[key]
+
+ def __setitem__(self, key, value):
+ self.val[key] = value
+
+ def dump(self, dump_value_func):
+ retstr = dump_value_func(self.val)
+ if isinstance(self.val, self._dict):
+ return self.comment + "\n" + unicode(retstr)
+ else:
+ return unicode(retstr) + self.comment
+
+
+def _strictly_valid_num(n):
+ n = n.strip()
+ if not n:
+ return False
+ if n[0] == '_':
+ return False
+ if n[-1] == '_':
+ return False
+ if "_." in n or "._" in n:
+ return False
+ if len(n) == 1:
+ return True
+ if n[0] == '0' and n[1] not in ['.', 'o', 'b', 'x']:
+ return False
+ if n[0] == '+' or n[0] == '-':
+ n = n[1:]
+ if len(n) > 1 and n[0] == '0' and n[1] != '.':
+ return False
+ if '__' in n:
+ return False
+ return True
+
+
+def load(f, _dict=dict, decoder=None):
+ """Parses named file or files as toml and returns a dictionary
+
+ Args:
+ f: Path to the file to open, array of files to read into single dict
+ or a file descriptor
+ _dict: (optional) Specifies the class of the returned toml dictionary
+ decoder: The decoder to use
+
+ Returns:
+ Parsed toml file represented as a dictionary
+
+ Raises:
+ TypeError -- When f is invalid type
+ TomlDecodeError: Error while decoding toml
+ IOError / FileNotFoundError -- When an array with no valid (existing)
+ (Python 2 / Python 3) file paths is passed
+ """
+
+ if _ispath(f):
+ with io.open(_getpath(f), encoding='utf-8') as ffile:
+ return loads(ffile.read(), _dict, decoder)
+ elif isinstance(f, list):
+ from os import path as op
+ from warnings import warn
+ if not [path for path in f if op.exists(path)]:
+ error_msg = "Load expects a list to contain filenames only."
+ error_msg += linesep
+ error_msg += ("The list needs to contain the path of at least one "
+ "existing file.")
+ raise FNFError(error_msg)
+ if decoder is None:
+ decoder = TomlDecoder(_dict)
+ d = decoder.get_empty_table()
+ for l in f: # noqa: E741
+ if op.exists(l):
+ d.update(load(l, _dict, decoder))
+ else:
+ warn("Non-existent filename in list with at least one valid "
+ "filename")
+ return d
+ else:
+ try:
+ return loads(f.read(), _dict, decoder)
+ except AttributeError:
+ raise TypeError("You can only load a file descriptor, filename or "
+ "list")
+
+
+_groupname_re = re.compile(r'^[A-Za-z0-9_-]+$')
+
+
+def loads(s, _dict=dict, decoder=None):
+ """Parses string as toml
+
+ Args:
+ s: String to be parsed
+ _dict: (optional) Specifies the class of the returned toml dictionary
+
+ Returns:
+ Parsed toml file represented as a dictionary
+
+ Raises:
+ TypeError: When a non-string is passed
+ TomlDecodeError: Error while decoding toml
+ """
+
+ implicitgroups = []
+ if decoder is None:
+ decoder = TomlDecoder(_dict)
+ retval = decoder.get_empty_table()
+ currentlevel = retval
+ if not isinstance(s, basestring):
+ raise TypeError("Expecting something like a string")
+
+ if not isinstance(s, unicode):
+ s = s.decode('utf8')
+
+ original = s
+ sl = list(s)
+ openarr = 0
+ openstring = False
+ openstrchar = ""
+ multilinestr = False
+ arrayoftables = False
+ beginline = True
+ keygroup = False
+ dottedkey = False
+ keyname = 0
+ key = ''
+ prev_key = ''
+ line_no = 1
+
+ for i, item in enumerate(sl):
+ if item == '\r' and sl[i + 1] == '\n':
+ sl[i] = ' '
+ continue
+ if keyname:
+ key += item
+ if item == '\n':
+ raise TomlDecodeError("Key name found without value."
+ " Reached end of line.", original, i)
+ if openstring:
+ if item == openstrchar:
+ oddbackslash = False
+ k = 1
+ while i >= k and sl[i - k] == '\\':
+ oddbackslash = not oddbackslash
+ k += 1
+ if not oddbackslash:
+ keyname = 2
+ openstring = False
+ openstrchar = ""
+ continue
+ elif keyname == 1:
+ if item.isspace():
+ keyname = 2
+ continue
+ elif item == '.':
+ dottedkey = True
+ continue
+ elif item.isalnum() or item == '_' or item == '-':
+ continue
+ elif (dottedkey and sl[i - 1] == '.' and
+ (item == '"' or item == "'")):
+ openstring = True
+ openstrchar = item
+ continue
+ elif keyname == 2:
+ if item.isspace():
+ if dottedkey:
+ nextitem = sl[i + 1]
+ if not nextitem.isspace() and nextitem != '.':
+ keyname = 1
+ continue
+ if item == '.':
+ dottedkey = True
+ nextitem = sl[i + 1]
+ if not nextitem.isspace() and nextitem != '.':
+ keyname = 1
+ continue
+ if item == '=':
+ keyname = 0
+ prev_key = key[:-1].rstrip()
+ key = ''
+ dottedkey = False
+ else:
+ raise TomlDecodeError("Found invalid character in key name: '" +
+ item + "'. Try quoting the key name.",
+ original, i)
+ if item == "'" and openstrchar != '"':
+ k = 1
+ try:
+ while sl[i - k] == "'":
+ k += 1
+ if k == 3:
+ break
+ except IndexError:
+ pass
+ if k == 3:
+ multilinestr = not multilinestr
+ openstring = multilinestr
+ else:
+ openstring = not openstring
+ if openstring:
+ openstrchar = "'"
+ else:
+ openstrchar = ""
+ if item == '"' and openstrchar != "'":
+ oddbackslash = False
+ k = 1
+ tripquote = False
+ try:
+ while sl[i - k] == '"':
+ k += 1
+ if k == 3:
+ tripquote = True
+ break
+ if k == 1 or (k == 3 and tripquote):
+ while sl[i - k] == '\\':
+ oddbackslash = not oddbackslash
+ k += 1
+ except IndexError:
+ pass
+ if not oddbackslash:
+ if tripquote:
+ multilinestr = not multilinestr
+ openstring = multilinestr
+ else:
+ openstring = not openstring
+ if openstring:
+ openstrchar = '"'
+ else:
+ openstrchar = ""
+ if item == '#' and (not openstring and not keygroup and
+ not arrayoftables):
+ j = i
+ comment = ""
+ try:
+ while sl[j] != '\n':
+ comment += s[j]
+ sl[j] = ' '
+ j += 1
+ except IndexError:
+ break
+ if not openarr:
+ decoder.preserve_comment(line_no, prev_key, comment, beginline)
+ if item == '[' and (not openstring and not keygroup and
+ not arrayoftables):
+ if beginline:
+ if len(sl) > i + 1 and sl[i + 1] == '[':
+ arrayoftables = True
+ else:
+ keygroup = True
+ else:
+ openarr += 1
+ if item == ']' and not openstring:
+ if keygroup:
+ keygroup = False
+ elif arrayoftables:
+ if sl[i - 1] == ']':
+ arrayoftables = False
+ else:
+ openarr -= 1
+ if item == '\n':
+ if openstring or multilinestr:
+ if not multilinestr:
+ raise TomlDecodeError("Unbalanced quotes", original, i)
+ if ((sl[i - 1] == "'" or sl[i - 1] == '"') and (
+ sl[i - 2] == sl[i - 1])):
+ sl[i] = sl[i - 1]
+ if sl[i - 3] == sl[i - 1]:
+ sl[i - 3] = ' '
+ elif openarr:
+ sl[i] = ' '
+ else:
+ beginline = True
+ line_no += 1
+ elif beginline and sl[i] != ' ' and sl[i] != '\t':
+ beginline = False
+ if not keygroup and not arrayoftables:
+ if sl[i] == '=':
+ raise TomlDecodeError("Found empty keyname. ", original, i)
+ keyname = 1
+ key += item
+ if keyname:
+ raise TomlDecodeError("Key name found without value."
+ " Reached end of file.", original, len(s))
+ if openstring: # reached EOF and have an unterminated string
+ raise TomlDecodeError("Unterminated string found."
+ " Reached end of file.", original, len(s))
+ s = ''.join(sl)
+ s = s.split('\n')
+ multikey = None
+ multilinestr = ""
+ multibackslash = False
+ pos = 0
+ for idx, line in enumerate(s):
+ if idx > 0:
+ pos += len(s[idx - 1]) + 1
+
+ decoder.embed_comments(idx, currentlevel)
+
+ if not multilinestr or multibackslash or '\n' not in multilinestr:
+ line = line.strip()
+ if line == "" and (not multikey or multibackslash):
+ continue
+ if multikey:
+ if multibackslash:
+ multilinestr += line
+ else:
+ multilinestr += line
+ multibackslash = False
+ closed = False
+ if multilinestr[0] == '[':
+ closed = line[-1] == ']'
+ elif len(line) > 2:
+ closed = (line[-1] == multilinestr[0] and
+ line[-2] == multilinestr[0] and
+ line[-3] == multilinestr[0])
+ if closed:
+ try:
+ value, vtype = decoder.load_value(multilinestr)
+ except ValueError as err:
+ raise TomlDecodeError(str(err), original, pos)
+ currentlevel[multikey] = value
+ multikey = None
+ multilinestr = ""
+ else:
+ k = len(multilinestr) - 1
+ while k > -1 and multilinestr[k] == '\\':
+ multibackslash = not multibackslash
+ k -= 1
+ if multibackslash:
+ multilinestr = multilinestr[:-1]
+ else:
+ multilinestr += "\n"
+ continue
+ if line[0] == '[':
+ arrayoftables = False
+ if len(line) == 1:
+ raise TomlDecodeError("Opening key group bracket on line by "
+ "itself.", original, pos)
+ if line[1] == '[':
+ arrayoftables = True
+ line = line[2:]
+ splitstr = ']]'
+ else:
+ line = line[1:]
+ splitstr = ']'
+ i = 1
+ quotesplits = decoder._get_split_on_quotes(line)
+ quoted = False
+ for quotesplit in quotesplits:
+ if not quoted and splitstr in quotesplit:
+ break
+ i += quotesplit.count(splitstr)
+ quoted = not quoted
+ line = line.split(splitstr, i)
+ if len(line) < i + 1 or line[-1].strip() != "":
+ raise TomlDecodeError("Key group not on a line by itself.",
+ original, pos)
+ groups = splitstr.join(line[:-1]).split('.')
+ i = 0
+ while i < len(groups):
+ groups[i] = groups[i].strip()
+ if len(groups[i]) > 0 and (groups[i][0] == '"' or
+ groups[i][0] == "'"):
+ groupstr = groups[i]
+ j = i + 1
+ while ((not groupstr[0] == groupstr[-1]) or
+ len(groupstr) == 1):
+ j += 1
+ if j > len(groups) + 2:
+ raise TomlDecodeError("Invalid group name '" +
+ groupstr + "' Something " +
+ "went wrong.", original, pos)
+ groupstr = '.'.join(groups[i:j]).strip()
+ groups[i] = groupstr[1:-1]
+ groups[i + 1:j] = []
+ else:
+ if not _groupname_re.match(groups[i]):
+ raise TomlDecodeError("Invalid group name '" +
+ groups[i] + "'. Try quoting it.",
+ original, pos)
+ i += 1
+ currentlevel = retval
+ for i in _range(len(groups)):
+ group = groups[i]
+ if group == "":
+ raise TomlDecodeError("Can't have a keygroup with an empty "
+ "name", original, pos)
+ try:
+ currentlevel[group]
+ if i == len(groups) - 1:
+ if group in implicitgroups:
+ implicitgroups.remove(group)
+ if arrayoftables:
+ raise TomlDecodeError("An implicitly defined "
+ "table can't be an array",
+ original, pos)
+ elif arrayoftables:
+ currentlevel[group].append(decoder.get_empty_table()
+ )
+ else:
+ raise TomlDecodeError("What? " + group +
+ " already exists?" +
+ str(currentlevel),
+ original, pos)
+ except TypeError:
+ currentlevel = currentlevel[-1]
+ if group not in currentlevel:
+ currentlevel[group] = decoder.get_empty_table()
+ if i == len(groups) - 1 and arrayoftables:
+ currentlevel[group] = [decoder.get_empty_table()]
+ except KeyError:
+ if i != len(groups) - 1:
+ implicitgroups.append(group)
+ currentlevel[group] = decoder.get_empty_table()
+ if i == len(groups) - 1 and arrayoftables:
+ currentlevel[group] = [decoder.get_empty_table()]
+ currentlevel = currentlevel[group]
+ if arrayoftables:
+ try:
+ currentlevel = currentlevel[-1]
+ except KeyError:
+ pass
+ elif line[0] == "{":
+ if line[-1] != "}":
+ raise TomlDecodeError("Line breaks are not allowed in inline"
+ "objects", original, pos)
+ try:
+ decoder.load_inline_object(line, currentlevel, multikey,
+ multibackslash)
+ except ValueError as err:
+ raise TomlDecodeError(str(err), original, pos)
+ elif "=" in line:
+ try:
+ ret = decoder.load_line(line, currentlevel, multikey,
+ multibackslash)
+ except ValueError as err:
+ raise TomlDecodeError(str(err), original, pos)
+ if ret is not None:
+ multikey, multilinestr, multibackslash = ret
+ return retval
+
+
+def _load_date(val):
+ microsecond = 0
+ tz = None
+ try:
+ if len(val) > 19:
+ if val[19] == '.':
+ if val[-1].upper() == 'Z':
+ subsecondval = val[20:-1]
+ tzval = "Z"
+ else:
+ subsecondvalandtz = val[20:]
+ if '+' in subsecondvalandtz:
+ splitpoint = subsecondvalandtz.index('+')
+ subsecondval = subsecondvalandtz[:splitpoint]
+ tzval = subsecondvalandtz[splitpoint:]
+ elif '-' in subsecondvalandtz:
+ splitpoint = subsecondvalandtz.index('-')
+ subsecondval = subsecondvalandtz[:splitpoint]
+ tzval = subsecondvalandtz[splitpoint:]
+ else:
+ tzval = None
+ subsecondval = subsecondvalandtz
+ if tzval is not None:
+ tz = TomlTz(tzval)
+ microsecond = int(int(subsecondval) *
+ (10 ** (6 - len(subsecondval))))
+ else:
+ tz = TomlTz(val[19:])
+ except ValueError:
+ tz = None
+ if "-" not in val[1:]:
+ return None
+ try:
+ if len(val) == 10:
+ d = datetime.date(
+ int(val[:4]), int(val[5:7]),
+ int(val[8:10]))
+ else:
+ d = datetime.datetime(
+ int(val[:4]), int(val[5:7]),
+ int(val[8:10]), int(val[11:13]),
+ int(val[14:16]), int(val[17:19]), microsecond, tz)
+ except ValueError:
+ return None
+ return d
+
+
+def _load_unicode_escapes(v, hexbytes, prefix):
+ skip = False
+ i = len(v) - 1
+ while i > -1 and v[i] == '\\':
+ skip = not skip
+ i -= 1
+ for hx in hexbytes:
+ if skip:
+ skip = False
+ i = len(hx) - 1
+ while i > -1 and hx[i] == '\\':
+ skip = not skip
+ i -= 1
+ v += prefix
+ v += hx
+ continue
+ hxb = ""
+ i = 0
+ hxblen = 4
+ if prefix == "\\U":
+ hxblen = 8
+ hxb = ''.join(hx[i:i + hxblen]).lower()
+ if hxb.strip('0123456789abcdef'):
+ raise ValueError("Invalid escape sequence: " + hxb)
+ if hxb[0] == "d" and hxb[1].strip('01234567'):
+ raise ValueError("Invalid escape sequence: " + hxb +
+ ". Only scalar unicode points are allowed.")
+ v += unichr(int(hxb, 16))
+ v += unicode(hx[len(hxb):])
+ return v
+
+
+# Unescape TOML string values.
+
+# content after the \
+_escapes = ['0', 'b', 'f', 'n', 'r', 't', '"']
+# What it should be replaced by
+_escapedchars = ['\0', '\b', '\f', '\n', '\r', '\t', '\"']
+# Used for substitution
+_escape_to_escapedchars = dict(zip(_escapes, _escapedchars))
+
+
+def _unescape(v):
+ """Unescape characters in a TOML string."""
+ i = 0
+ backslash = False
+ while i < len(v):
+ if backslash:
+ backslash = False
+ if v[i] in _escapes:
+ v = v[:i - 1] + _escape_to_escapedchars[v[i]] + v[i + 1:]
+ elif v[i] == '\\':
+ v = v[:i - 1] + v[i:]
+ elif v[i] == 'u' or v[i] == 'U':
+ i += 1
+ else:
+ raise ValueError("Reserved escape sequence used")
+ continue
+ elif v[i] == '\\':
+ backslash = True
+ i += 1
+ return v
+
+
+class InlineTableDict(object):
+ """Sentinel subclass of dict for inline tables."""
+
+
+class TomlDecoder(object):
+
+ def __init__(self, _dict=dict):
+ self._dict = _dict
+
+ def get_empty_table(self):
+ return self._dict()
+
+ def get_empty_inline_table(self):
+ class DynamicInlineTableDict(self._dict, InlineTableDict):
+ """Concrete sentinel subclass for inline tables.
+ It is a subclass of _dict which is passed in dynamically at load
+ time
+
+ It is also a subclass of InlineTableDict
+ """
+
+ return DynamicInlineTableDict()
+
+ def load_inline_object(self, line, currentlevel, multikey=False,
+ multibackslash=False):
+ candidate_groups = line[1:-1].split(",")
+ groups = []
+ if len(candidate_groups) == 1 and not candidate_groups[0].strip():
+ candidate_groups.pop()
+ while len(candidate_groups) > 0:
+ candidate_group = candidate_groups.pop(0)
+ try:
+ _, value = candidate_group.split('=', 1)
+ except ValueError:
+ raise ValueError("Invalid inline table encountered")
+ value = value.strip()
+ if ((value[0] == value[-1] and value[0] in ('"', "'")) or (
+ value[0] in '-0123456789' or
+ value in ('true', 'false') or
+ (value[0] == "[" and value[-1] == "]") or
+ (value[0] == '{' and value[-1] == '}'))):
+ groups.append(candidate_group)
+ elif len(candidate_groups) > 0:
+ candidate_groups[0] = (candidate_group + "," +
+ candidate_groups[0])
+ else:
+ raise ValueError("Invalid inline table value encountered")
+ for group in groups:
+ status = self.load_line(group, currentlevel, multikey,
+ multibackslash)
+ if status is not None:
+ break
+
+ def _get_split_on_quotes(self, line):
+ doublequotesplits = line.split('"')
+ quoted = False
+ quotesplits = []
+ if len(doublequotesplits) > 1 and "'" in doublequotesplits[0]:
+ singlequotesplits = doublequotesplits[0].split("'")
+ doublequotesplits = doublequotesplits[1:]
+ while len(singlequotesplits) % 2 == 0 and len(doublequotesplits):
+ singlequotesplits[-1] += '"' + doublequotesplits[0]
+ doublequotesplits = doublequotesplits[1:]
+ if "'" in singlequotesplits[-1]:
+ singlequotesplits = (singlequotesplits[:-1] +
+ singlequotesplits[-1].split("'"))
+ quotesplits += singlequotesplits
+ for doublequotesplit in doublequotesplits:
+ if quoted:
+ quotesplits.append(doublequotesplit)
+ else:
+ quotesplits += doublequotesplit.split("'")
+ quoted = not quoted
+ return quotesplits
+
+ def load_line(self, line, currentlevel, multikey, multibackslash):
+ i = 1
+ quotesplits = self._get_split_on_quotes(line)
+ quoted = False
+ for quotesplit in quotesplits:
+ if not quoted and '=' in quotesplit:
+ break
+ i += quotesplit.count('=')
+ quoted = not quoted
+ pair = line.split('=', i)
+ strictly_valid = _strictly_valid_num(pair[-1])
+ if _number_with_underscores.match(pair[-1]):
+ pair[-1] = pair[-1].replace('_', '')
+ while len(pair[-1]) and (pair[-1][0] != ' ' and pair[-1][0] != '\t' and
+ pair[-1][0] != "'" and pair[-1][0] != '"' and
+ pair[-1][0] != '[' and pair[-1][0] != '{' and
+ pair[-1].strip() != 'true' and
+ pair[-1].strip() != 'false'):
+ try:
+ float(pair[-1])
+ break
+ except ValueError:
+ pass
+ if _load_date(pair[-1]) is not None:
+ break
+ if TIME_RE.match(pair[-1]):
+ break
+ i += 1
+ prev_val = pair[-1]
+ pair = line.split('=', i)
+ if prev_val == pair[-1]:
+ raise ValueError("Invalid date or number")
+ if strictly_valid:
+ strictly_valid = _strictly_valid_num(pair[-1])
+ pair = ['='.join(pair[:-1]).strip(), pair[-1].strip()]
+ if '.' in pair[0]:
+ if '"' in pair[0] or "'" in pair[0]:
+ quotesplits = self._get_split_on_quotes(pair[0])
+ quoted = False
+ levels = []
+ for quotesplit in quotesplits:
+ if quoted:
+ levels.append(quotesplit)
+ else:
+ levels += [level.strip() for level in
+ quotesplit.split('.')]
+ quoted = not quoted
+ else:
+ levels = pair[0].split('.')
+ while levels[-1] == "":
+ levels = levels[:-1]
+ for level in levels[:-1]:
+ if level == "":
+ continue
+ if level not in currentlevel:
+ currentlevel[level] = self.get_empty_table()
+ currentlevel = currentlevel[level]
+ pair[0] = levels[-1].strip()
+ elif (pair[0][0] == '"' or pair[0][0] == "'") and \
+ (pair[0][-1] == pair[0][0]):
+ pair[0] = _unescape(pair[0][1:-1])
+ k, koffset = self._load_line_multiline_str(pair[1])
+ if k > -1:
+ while k > -1 and pair[1][k + koffset] == '\\':
+ multibackslash = not multibackslash
+ k -= 1
+ if multibackslash:
+ multilinestr = pair[1][:-1]
+ else:
+ multilinestr = pair[1] + "\n"
+ multikey = pair[0]
+ else:
+ value, vtype = self.load_value(pair[1], strictly_valid)
+ try:
+ currentlevel[pair[0]]
+ raise ValueError("Duplicate keys!")
+ except TypeError:
+ raise ValueError("Duplicate keys!")
+ except KeyError:
+ if multikey:
+ return multikey, multilinestr, multibackslash
+ else:
+ currentlevel[pair[0]] = value
+
+ def _load_line_multiline_str(self, p):
+ poffset = 0
+ if len(p) < 3:
+ return -1, poffset
+ if p[0] == '[' and (p.strip()[-1] != ']' and
+ self._load_array_isstrarray(p)):
+ newp = p[1:].strip().split(',')
+ while len(newp) > 1 and newp[-1][0] != '"' and newp[-1][0] != "'":
+ newp = newp[:-2] + [newp[-2] + ',' + newp[-1]]
+ newp = newp[-1]
+ poffset = len(p) - len(newp)
+ p = newp
+ if p[0] != '"' and p[0] != "'":
+ return -1, poffset
+ if p[1] != p[0] or p[2] != p[0]:
+ return -1, poffset
+ if len(p) > 5 and p[-1] == p[0] and p[-2] == p[0] and p[-3] == p[0]:
+ return -1, poffset
+ return len(p) - 1, poffset
+
+ def load_value(self, v, strictly_valid=True):
+ if not v:
+ raise ValueError("Empty value is invalid")
+ if v == 'true':
+ return (True, "bool")
+ elif v.lower() == 'true':
+ raise ValueError("Only all lowercase booleans allowed")
+ elif v == 'false':
+ return (False, "bool")
+ elif v.lower() == 'false':
+ raise ValueError("Only all lowercase booleans allowed")
+ elif v[0] == '"' or v[0] == "'":
+ quotechar = v[0]
+ testv = v[1:].split(quotechar)
+ triplequote = False
+ triplequotecount = 0
+ if len(testv) > 1 and testv[0] == '' and testv[1] == '':
+ testv = testv[2:]
+ triplequote = True
+ closed = False
+ for tv in testv:
+ if tv == '':
+ if triplequote:
+ triplequotecount += 1
+ else:
+ closed = True
+ else:
+ oddbackslash = False
+ try:
+ i = -1
+ j = tv[i]
+ while j == '\\':
+ oddbackslash = not oddbackslash
+ i -= 1
+ j = tv[i]
+ except IndexError:
+ pass
+ if not oddbackslash:
+ if closed:
+ raise ValueError("Found tokens after a closed " +
+ "string. Invalid TOML.")
+ else:
+ if not triplequote or triplequotecount > 1:
+ closed = True
+ else:
+ triplequotecount = 0
+ if quotechar == '"':
+ escapeseqs = v.split('\\')[1:]
+ backslash = False
+ for i in escapeseqs:
+ if i == '':
+ backslash = not backslash
+ else:
+ if i[0] not in _escapes and (i[0] != 'u' and
+ i[0] != 'U' and
+ not backslash):
+ raise ValueError("Reserved escape sequence used")
+ if backslash:
+ backslash = False
+ for prefix in ["\\u", "\\U"]:
+ if prefix in v:
+ hexbytes = v.split(prefix)
+ v = _load_unicode_escapes(hexbytes[0], hexbytes[1:],
+ prefix)
+ v = _unescape(v)
+ if len(v) > 1 and v[1] == quotechar and (len(v) < 3 or
+ v[1] == v[2]):
+ v = v[2:-2]
+ return (v[1:-1], "str")
+ elif v[0] == '[':
+ return (self.load_array(v), "array")
+ elif v[0] == '{':
+ inline_object = self.get_empty_inline_table()
+ self.load_inline_object(v, inline_object)
+ return (inline_object, "inline_object")
+ elif TIME_RE.match(v):
+ h, m, s, _, ms = TIME_RE.match(v).groups()
+ time = datetime.time(int(h), int(m), int(s), int(ms) if ms else 0)
+ return (time, "time")
+ else:
+ parsed_date = _load_date(v)
+ if parsed_date is not None:
+ return (parsed_date, "date")
+ if not strictly_valid:
+ raise ValueError("Weirdness with leading zeroes or "
+ "underscores in your number.")
+ itype = "int"
+ neg = False
+ if v[0] == '-':
+ neg = True
+ v = v[1:]
+ elif v[0] == '+':
+ v = v[1:]
+ v = v.replace('_', '')
+ lowerv = v.lower()
+ if '.' in v or ('x' not in v and ('e' in v or 'E' in v)):
+ if '.' in v and v.split('.', 1)[1] == '':
+ raise ValueError("This float is missing digits after "
+ "the point")
+ if v[0] not in '0123456789':
+ raise ValueError("This float doesn't have a leading "
+ "digit")
+ v = float(v)
+ itype = "float"
+ elif len(lowerv) == 3 and (lowerv == 'inf' or lowerv == 'nan'):
+ v = float(v)
+ itype = "float"
+ if itype == "int":
+ v = int(v, 0)
+ if neg:
+ return (0 - v, itype)
+ return (v, itype)
+
+ def bounded_string(self, s):
+ if len(s) == 0:
+ return True
+ if s[-1] != s[0]:
+ return False
+ i = -2
+ backslash = False
+ while len(s) + i > 0:
+ if s[i] == "\\":
+ backslash = not backslash
+ i -= 1
+ else:
+ break
+ return not backslash
+
+ def _load_array_isstrarray(self, a):
+ a = a[1:-1].strip()
+ if a != '' and (a[0] == '"' or a[0] == "'"):
+ return True
+ return False
+
+ def load_array(self, a):
+ atype = None
+ retval = []
+ a = a.strip()
+ if '[' not in a[1:-1] or "" != a[1:-1].split('[')[0].strip():
+ strarray = self._load_array_isstrarray(a)
+ if not a[1:-1].strip().startswith('{'):
+ a = a[1:-1].split(',')
+ else:
+ # a is an inline object, we must find the matching parenthesis
+ # to define groups
+ new_a = []
+ start_group_index = 1
+ end_group_index = 2
+ open_bracket_count = 1 if a[start_group_index] == '{' else 0
+ in_str = False
+ while end_group_index < len(a[1:]):
+ if a[end_group_index] == '"' or a[end_group_index] == "'":
+ if in_str:
+ backslash_index = end_group_index - 1
+ while (backslash_index > -1 and
+ a[backslash_index] == '\\'):
+ in_str = not in_str
+ backslash_index -= 1
+ in_str = not in_str
+ if not in_str and a[end_group_index] == '{':
+ open_bracket_count += 1
+ if in_str or a[end_group_index] != '}':
+ end_group_index += 1
+ continue
+ elif a[end_group_index] == '}' and open_bracket_count > 1:
+ open_bracket_count -= 1
+ end_group_index += 1
+ continue
+
+ # Increase end_group_index by 1 to get the closing bracket
+ end_group_index += 1
+
+ new_a.append(a[start_group_index:end_group_index])
+
+ # The next start index is at least after the closing
+ # bracket, a closing bracket can be followed by a comma
+ # since we are in an array.
+ start_group_index = end_group_index + 1
+ while (start_group_index < len(a[1:]) and
+ a[start_group_index] != '{'):
+ start_group_index += 1
+ end_group_index = start_group_index + 1
+ a = new_a
+ b = 0
+ if strarray:
+ while b < len(a) - 1:
+ ab = a[b].strip()
+ while (not self.bounded_string(ab) or
+ (len(ab) > 2 and
+ ab[0] == ab[1] == ab[2] and
+ ab[-2] != ab[0] and
+ ab[-3] != ab[0])):
+ a[b] = a[b] + ',' + a[b + 1]
+ ab = a[b].strip()
+ if b < len(a) - 2:
+ a = a[:b + 1] + a[b + 2:]
+ else:
+ a = a[:b + 1]
+ b += 1
+ else:
+ al = list(a[1:-1])
+ a = []
+ openarr = 0
+ j = 0
+ for i in _range(len(al)):
+ if al[i] == '[':
+ openarr += 1
+ elif al[i] == ']':
+ openarr -= 1
+ elif al[i] == ',' and not openarr:
+ a.append(''.join(al[j:i]))
+ j = i + 1
+ a.append(''.join(al[j:]))
+ for i in _range(len(a)):
+ a[i] = a[i].strip()
+ if a[i] != '':
+ nval, ntype = self.load_value(a[i])
+ if atype:
+ if ntype != atype:
+ raise ValueError("Not a homogeneous array")
+ else:
+ atype = ntype
+ retval.append(nval)
+ return retval
+
+ def preserve_comment(self, line_no, key, comment, beginline):
+ pass
+
+ def embed_comments(self, idx, currentlevel):
+ pass
+
+
+class TomlPreserveCommentDecoder(TomlDecoder):
+
+ def __init__(self, _dict=dict):
+ self.saved_comments = {}
+ super(TomlPreserveCommentDecoder, self).__init__(_dict)
+
+ def preserve_comment(self, line_no, key, comment, beginline):
+ self.saved_comments[line_no] = (key, comment, beginline)
+
+ def embed_comments(self, idx, currentlevel):
+ if idx not in self.saved_comments:
+ return
+
+ key, comment, beginline = self.saved_comments[idx]
+ currentlevel[key] = CommentValue(currentlevel[key], comment, beginline,
+ self._dict)
diff --git a/contrib/python/toml/toml/decoder.pyi b/contrib/python/toml/toml/decoder.pyi
new file mode 100644
index 0000000000..967d3dd15a
--- /dev/null
+++ b/contrib/python/toml/toml/decoder.pyi
@@ -0,0 +1,52 @@
+from toml.tz import TomlTz as TomlTz
+from typing import Any, Optional
+
+unicode = str
+basestring = str
+unichr = chr
+FNFError = FileNotFoundError
+FNFError = IOError
+TIME_RE: Any
+
+class TomlDecodeError(ValueError):
+ msg: Any = ...
+ doc: Any = ...
+ pos: Any = ...
+ lineno: Any = ...
+ colno: Any = ...
+ def __init__(self, msg: Any, doc: Any, pos: Any) -> None: ...
+
+class CommentValue:
+ val: Any = ...
+ comment: Any = ...
+ def __init__(self, val: Any, comment: Any, beginline: Any, _dict: Any) -> None: ...
+ def __getitem__(self, key: Any): ...
+ def __setitem__(self, key: Any, value: Any) -> None: ...
+ def dump(self, dump_value_func: Any): ...
+
+def load(f: Union[str, list, IO[str]],
+ _dict: Type[MutableMapping[str, Any]] = ...,
+ decoder: TomlDecoder = ...) \
+ -> MutableMapping[str, Any]: ...
+def loads(s: str, _dict: Type[MutableMapping[str, Any]] = ..., decoder: TomlDecoder = ...) \
+ -> MutableMapping[str, Any]: ...
+
+class InlineTableDict: ...
+
+class TomlDecoder:
+ def __init__(self, _dict: Any = ...) -> None: ...
+ def get_empty_table(self): ...
+ def get_empty_inline_table(self): ...
+ def load_inline_object(self, line: Any, currentlevel: Any, multikey: bool = ..., multibackslash: bool = ...) -> None: ...
+ def load_line(self, line: Any, currentlevel: Any, multikey: Any, multibackslash: Any): ...
+ def load_value(self, v: Any, strictly_valid: bool = ...): ...
+ def bounded_string(self, s: Any): ...
+ def load_array(self, a: Any): ...
+ def preserve_comment(self, line_no: Any, key: Any, comment: Any, beginline: Any) -> None: ...
+ def embed_comments(self, idx: Any, currentlevel: Any) -> None: ...
+
+class TomlPreserveCommentDecoder(TomlDecoder):
+ saved_comments: Any = ...
+ def __init__(self, _dict: Any = ...) -> None: ...
+ def preserve_comment(self, line_no: Any, key: Any, comment: Any, beginline: Any) -> None: ...
+ def embed_comments(self, idx: Any, currentlevel: Any) -> None: ...
diff --git a/contrib/python/toml/toml/encoder.py b/contrib/python/toml/toml/encoder.py
new file mode 100644
index 0000000000..bf17a72b62
--- /dev/null
+++ b/contrib/python/toml/toml/encoder.py
@@ -0,0 +1,304 @@
+import datetime
+import re
+import sys
+from decimal import Decimal
+
+from toml.decoder import InlineTableDict
+
+if sys.version_info >= (3,):
+ unicode = str
+
+
+def dump(o, f, encoder=None):
+ """Writes out dict as toml to a file
+
+ Args:
+ o: Object to dump into toml
+ f: File descriptor where the toml should be stored
+ encoder: The ``TomlEncoder`` to use for constructing the output string
+
+ Returns:
+ String containing the toml corresponding to dictionary
+
+ Raises:
+ TypeError: When anything other than file descriptor is passed
+ """
+
+ if not f.write:
+ raise TypeError("You can only dump an object to a file descriptor")
+ d = dumps(o, encoder=encoder)
+ f.write(d)
+ return d
+
+
+def dumps(o, encoder=None):
+ """Stringifies input dict as toml
+
+ Args:
+ o: Object to dump into toml
+ encoder: The ``TomlEncoder`` to use for constructing the output string
+
+ Returns:
+ String containing the toml corresponding to dict
+
+ Examples:
+ ```python
+ >>> import toml
+ >>> output = {
+ ... 'a': "I'm a string",
+ ... 'b': ["I'm", "a", "list"],
+ ... 'c': 2400
+ ... }
+ >>> toml.dumps(output)
+ 'a = "I\'m a string"\nb = [ "I\'m", "a", "list",]\nc = 2400\n'
+ ```
+ """
+
+ retval = ""
+ if encoder is None:
+ encoder = TomlEncoder(o.__class__)
+ addtoretval, sections = encoder.dump_sections(o, "")
+ retval += addtoretval
+ outer_objs = [id(o)]
+ while sections:
+ section_ids = [id(section) for section in sections.values()]
+ for outer_obj in outer_objs:
+ if outer_obj in section_ids:
+ raise ValueError("Circular reference detected")
+ outer_objs += section_ids
+ newsections = encoder.get_empty_table()
+ for section in sections:
+ addtoretval, addtosections = encoder.dump_sections(
+ sections[section], section)
+
+ if addtoretval or (not addtoretval and not addtosections):
+ if retval and retval[-2:] != "\n\n":
+ retval += "\n"
+ retval += "[" + section + "]\n"
+ if addtoretval:
+ retval += addtoretval
+ for s in addtosections:
+ newsections[section + "." + s] = addtosections[s]
+ sections = newsections
+ return retval
+
+
+def _dump_str(v):
+ if sys.version_info < (3,) and hasattr(v, 'decode') and isinstance(v, str):
+ v = v.decode('utf-8')
+ v = "%r" % v
+ if v[0] == 'u':
+ v = v[1:]
+ singlequote = v.startswith("'")
+ if singlequote or v.startswith('"'):
+ v = v[1:-1]
+ if singlequote:
+ v = v.replace("\\'", "'")
+ v = v.replace('"', '\\"')
+ v = v.split("\\x")
+ while len(v) > 1:
+ i = -1
+ if not v[0]:
+ v = v[1:]
+ v[0] = v[0].replace("\\\\", "\\")
+ # No, I don't know why != works and == breaks
+ joinx = v[0][i] != "\\"
+ while v[0][:i] and v[0][i] == "\\":
+ joinx = not joinx
+ i -= 1
+ if joinx:
+ joiner = "x"
+ else:
+ joiner = "u00"
+ v = [v[0] + joiner + v[1]] + v[2:]
+ return unicode('"' + v[0] + '"')
+
+
+def _dump_float(v):
+ return "{}".format(v).replace("e+0", "e+").replace("e-0", "e-")
+
+
+def _dump_time(v):
+ utcoffset = v.utcoffset()
+ if utcoffset is None:
+ return v.isoformat()
+ # The TOML norm specifies that it's local time thus we drop the offset
+ return v.isoformat()[:-6]
+
+
+class TomlEncoder(object):
+
+ def __init__(self, _dict=dict, preserve=False):
+ self._dict = _dict
+ self.preserve = preserve
+ self.dump_funcs = {
+ str: _dump_str,
+ unicode: _dump_str,
+ list: self.dump_list,
+ bool: lambda v: unicode(v).lower(),
+ int: lambda v: v,
+ float: _dump_float,
+ Decimal: _dump_float,
+ datetime.datetime: lambda v: v.isoformat().replace('+00:00', 'Z'),
+ datetime.time: _dump_time,
+ datetime.date: lambda v: v.isoformat()
+ }
+
+ def get_empty_table(self):
+ return self._dict()
+
+ def dump_list(self, v):
+ retval = "["
+ for u in v:
+ retval += " " + unicode(self.dump_value(u)) + ","
+ retval += "]"
+ return retval
+
+ def dump_inline_table(self, section):
+ """Preserve inline table in its compact syntax instead of expanding
+ into subsection.
+
+ https://github.com/toml-lang/toml#user-content-inline-table
+ """
+ retval = ""
+ if isinstance(section, dict):
+ val_list = []
+ for k, v in section.items():
+ val = self.dump_inline_table(v)
+ val_list.append(k + " = " + val)
+ retval += "{ " + ", ".join(val_list) + " }\n"
+ return retval
+ else:
+ return unicode(self.dump_value(section))
+
+ def dump_value(self, v):
+ # Lookup function corresponding to v's type
+ dump_fn = self.dump_funcs.get(type(v))
+ if dump_fn is None and hasattr(v, '__iter__'):
+ dump_fn = self.dump_funcs[list]
+ # Evaluate function (if it exists) else return v
+ return dump_fn(v) if dump_fn is not None else self.dump_funcs[str](v)
+
+ def dump_sections(self, o, sup):
+ retstr = ""
+ if sup != "" and sup[-1] != ".":
+ sup += '.'
+ retdict = self._dict()
+ arraystr = ""
+ for section in o:
+ section = unicode(section)
+ qsection = section
+ if not re.match(r'^[A-Za-z0-9_-]+$', section):
+ qsection = _dump_str(section)
+ if not isinstance(o[section], dict):
+ arrayoftables = False
+ if isinstance(o[section], list):
+ for a in o[section]:
+ if isinstance(a, dict):
+ arrayoftables = True
+ if arrayoftables:
+ for a in o[section]:
+ arraytabstr = "\n"
+ arraystr += "[[" + sup + qsection + "]]\n"
+ s, d = self.dump_sections(a, sup + qsection)
+ if s:
+ if s[0] == "[":
+ arraytabstr += s
+ else:
+ arraystr += s
+ while d:
+ newd = self._dict()
+ for dsec in d:
+ s1, d1 = self.dump_sections(d[dsec], sup +
+ qsection + "." +
+ dsec)
+ if s1:
+ arraytabstr += ("[" + sup + qsection +
+ "." + dsec + "]\n")
+ arraytabstr += s1
+ for s1 in d1:
+ newd[dsec + "." + s1] = d1[s1]
+ d = newd
+ arraystr += arraytabstr
+ else:
+ if o[section] is not None:
+ retstr += (qsection + " = " +
+ unicode(self.dump_value(o[section])) + '\n')
+ elif self.preserve and isinstance(o[section], InlineTableDict):
+ retstr += (qsection + " = " +
+ self.dump_inline_table(o[section]))
+ else:
+ retdict[qsection] = o[section]
+ retstr += arraystr
+ return (retstr, retdict)
+
+
+class TomlPreserveInlineDictEncoder(TomlEncoder):
+
+ def __init__(self, _dict=dict):
+ super(TomlPreserveInlineDictEncoder, self).__init__(_dict, True)
+
+
+class TomlArraySeparatorEncoder(TomlEncoder):
+
+ def __init__(self, _dict=dict, preserve=False, separator=","):
+ super(TomlArraySeparatorEncoder, self).__init__(_dict, preserve)
+ if separator.strip() == "":
+ separator = "," + separator
+ elif separator.strip(' \t\n\r,'):
+ raise ValueError("Invalid separator for arrays")
+ self.separator = separator
+
+ def dump_list(self, v):
+ t = []
+ retval = "["
+ for u in v:
+ t.append(self.dump_value(u))
+ while t != []:
+ s = []
+ for u in t:
+ if isinstance(u, list):
+ for r in u:
+ s.append(r)
+ else:
+ retval += " " + unicode(u) + self.separator
+ t = s
+ retval += "]"
+ return retval
+
+
+class TomlNumpyEncoder(TomlEncoder):
+
+ def __init__(self, _dict=dict, preserve=False):
+ import numpy as np
+ super(TomlNumpyEncoder, self).__init__(_dict, preserve)
+ self.dump_funcs[np.float16] = _dump_float
+ self.dump_funcs[np.float32] = _dump_float
+ self.dump_funcs[np.float64] = _dump_float
+ self.dump_funcs[np.int16] = self._dump_int
+ self.dump_funcs[np.int32] = self._dump_int
+ self.dump_funcs[np.int64] = self._dump_int
+
+ def _dump_int(self, v):
+ return "{}".format(int(v))
+
+
+class TomlPreserveCommentEncoder(TomlEncoder):
+
+ def __init__(self, _dict=dict, preserve=False):
+ from toml.decoder import CommentValue
+ super(TomlPreserveCommentEncoder, self).__init__(_dict, preserve)
+ self.dump_funcs[CommentValue] = lambda v: v.dump(self.dump_value)
+
+
+class TomlPathlibEncoder(TomlEncoder):
+
+ def _dump_pathlib_path(self, v):
+ return _dump_str(str(v))
+
+ def dump_value(self, v):
+ if (3, 4) <= sys.version_info:
+ import pathlib
+ if isinstance(v, pathlib.PurePath):
+ v = str(v)
+ return super(TomlPathlibEncoder, self).dump_value(v)
diff --git a/contrib/python/toml/toml/encoder.pyi b/contrib/python/toml/toml/encoder.pyi
new file mode 100644
index 0000000000..194a3583ef
--- /dev/null
+++ b/contrib/python/toml/toml/encoder.pyi
@@ -0,0 +1,34 @@
+from toml.decoder import InlineTableDict as InlineTableDict
+from typing import Any, Optional
+
+unicode = str
+
+def dump(o: Mapping[str, Any], f: IO[str], encoder: TomlEncoder = ...) -> str: ...
+def dumps(o: Mapping[str, Any], encoder: TomlEncoder = ...) -> str: ...
+
+class TomlEncoder:
+ preserve: Any = ...
+ dump_funcs: Any = ...
+ def __init__(self, _dict: Any = ..., preserve: bool = ...): ...
+ def get_empty_table(self): ...
+ def dump_list(self, v: Any): ...
+ def dump_inline_table(self, section: Any): ...
+ def dump_value(self, v: Any): ...
+ def dump_sections(self, o: Any, sup: Any): ...
+
+class TomlPreserveInlineDictEncoder(TomlEncoder):
+ def __init__(self, _dict: Any = ...) -> None: ...
+
+class TomlArraySeparatorEncoder(TomlEncoder):
+ separator: Any = ...
+ def __init__(self, _dict: Any = ..., preserve: bool = ..., separator: str = ...) -> None: ...
+ def dump_list(self, v: Any): ...
+
+class TomlNumpyEncoder(TomlEncoder):
+ def __init__(self, _dict: Any = ..., preserve: bool = ...) -> None: ...
+
+class TomlPreserveCommentEncoder(TomlEncoder):
+ def __init__(self, _dict: Any = ..., preserve: bool = ...): ...
+
+class TomlPathlibEncoder(TomlEncoder):
+ def dump_value(self, v: Any): ...
diff --git a/contrib/python/toml/toml/ordered.py b/contrib/python/toml/toml/ordered.py
new file mode 100644
index 0000000000..9c20c41a1b
--- /dev/null
+++ b/contrib/python/toml/toml/ordered.py
@@ -0,0 +1,15 @@
+from collections import OrderedDict
+from toml import TomlEncoder
+from toml import TomlDecoder
+
+
+class TomlOrderedDecoder(TomlDecoder):
+
+ def __init__(self):
+ super(self.__class__, self).__init__(_dict=OrderedDict)
+
+
+class TomlOrderedEncoder(TomlEncoder):
+
+ def __init__(self):
+ super(self.__class__, self).__init__(_dict=OrderedDict)
diff --git a/contrib/python/toml/toml/ordered.pyi b/contrib/python/toml/toml/ordered.pyi
new file mode 100644
index 0000000000..0f4292dc94
--- /dev/null
+++ b/contrib/python/toml/toml/ordered.pyi
@@ -0,0 +1,7 @@
+from toml import TomlDecoder as TomlDecoder, TomlEncoder as TomlEncoder
+
+class TomlOrderedDecoder(TomlDecoder):
+ def __init__(self) -> None: ...
+
+class TomlOrderedEncoder(TomlEncoder):
+ def __init__(self) -> None: ...
diff --git a/contrib/python/toml/toml/tz.py b/contrib/python/toml/toml/tz.py
new file mode 100644
index 0000000000..bf20593a26
--- /dev/null
+++ b/contrib/python/toml/toml/tz.py
@@ -0,0 +1,24 @@
+from datetime import tzinfo, timedelta
+
+
+class TomlTz(tzinfo):
+ def __init__(self, toml_offset):
+ if toml_offset == "Z":
+ self._raw_offset = "+00:00"
+ else:
+ self._raw_offset = toml_offset
+ self._sign = -1 if self._raw_offset[0] == '-' else 1
+ self._hours = int(self._raw_offset[1:3])
+ self._minutes = int(self._raw_offset[4:6])
+
+ def __deepcopy__(self, memo):
+ return self.__class__(self._raw_offset)
+
+ def tzname(self, dt):
+ return "UTC" + self._raw_offset
+
+ def utcoffset(self, dt):
+ return self._sign * timedelta(hours=self._hours, minutes=self._minutes)
+
+ def dst(self, dt):
+ return timedelta(0)
diff --git a/contrib/python/toml/toml/tz.pyi b/contrib/python/toml/toml/tz.pyi
new file mode 100644
index 0000000000..fe37aead6e
--- /dev/null
+++ b/contrib/python/toml/toml/tz.pyi
@@ -0,0 +1,9 @@
+from datetime import tzinfo
+from typing import Any
+
+class TomlTz(tzinfo):
+ def __init__(self, toml_offset: Any) -> None: ...
+ def __deepcopy__(self, memo: Any): ...
+ def tzname(self, dt: Any): ...
+ def utcoffset(self, dt: Any): ...
+ def dst(self, dt: Any): ...
diff --git a/contrib/python/toml/ya.make b/contrib/python/toml/ya.make
new file mode 100644
index 0000000000..104e501e8e
--- /dev/null
+++ b/contrib/python/toml/ya.make
@@ -0,0 +1,31 @@
+OWNER(g:python-contrib)
+
+PY23_LIBRARY()
+
+LICENSE(MIT)
+
+VERSION(0.10.2)
+
+PY_SRCS(
+ TOP_LEVEL
+ toml/__init__.py
+ toml/decoder.py
+ toml/encoder.py
+ toml/ordered.py
+ toml/tz.py
+)
+
+RESOURCE_FILES(
+ PREFIX contrib/python/toml/
+ .dist-info/METADATA
+ .dist-info/top_level.txt
+ toml/__init__.pyi
+ toml/decoder.pyi
+ toml/encoder.pyi
+ toml/ordered.pyi
+ toml/tz.pyi
+)
+
+NO_LINT()
+
+END()
diff --git a/contrib/python/ya.make b/contrib/python/ya.make
index be27f59479..0577a9fb1f 100644
--- a/contrib/python/ya.make
+++ b/contrib/python/ya.make
@@ -944,7 +944,7 @@ RECURSE(
requests-file
requests-mock
requests-oauthlib
- requests_toolbelt
+ requests-toolbelt
requests-unixsocket
responses
respx