diff options
| author | robot-contrib <[email protected]> | 2025-11-11 13:17:49 +0300 |
|---|---|---|
| committer | robot-contrib <[email protected]> | 2025-11-11 13:40:18 +0300 |
| commit | 452ab533ea3ab6d559d4df083f5d43466e764dd9 (patch) | |
| tree | cdfdda3d7bd4c51aa1da18df54c9375fed7c2751 /contrib/python | |
| parent | 679fed2244780e5d17597bbe39ef7ae7dcca9073 (diff) | |
Update contrib/python/Jinja2/py3 to 3.1.6
commit_hash:244282cfaf643bf8c4718c0d6cb16149761365a3
Diffstat (limited to 'contrib/python')
144 files changed, 28065 insertions, 311 deletions
diff --git a/contrib/python/Jinja2/py3/.dist-info/METADATA b/contrib/python/Jinja2/py3/.dist-info/METADATA index 265cc32e135..ffef2ff3bfa 100644 --- a/contrib/python/Jinja2/py3/.dist-info/METADATA +++ b/contrib/python/Jinja2/py3/.dist-info/METADATA @@ -1,6 +1,6 @@ -Metadata-Version: 2.1 +Metadata-Version: 2.4 Name: Jinja2 -Version: 3.1.4 +Version: 3.1.6 Summary: A very fast and expressive template engine. Maintainer-email: Pallets <[email protected]> Requires-Python: >=3.7 @@ -14,6 +14,7 @@ Classifier: Programming Language :: Python Classifier: Topic :: Internet :: WWW/HTTP :: Dynamic Content Classifier: Topic :: Text Processing :: Markup :: HTML Classifier: Typing :: Typed +License-File: LICENSE.txt Requires-Dist: MarkupSafe>=2.0 Requires-Dist: Babel>=2.7 ; extra == "i18n" Project-URL: Changes, https://jinja.palletsprojects.com/changes/ @@ -52,18 +53,17 @@ restricting functionality too much. ## In A Nutshell -.. code-block:: jinja - - {% extends "base.html" %} - {% block title %}Members{% endblock %} - {% block content %} - <ul> - {% for user in users %} - <li><a href="{{ user.url }}">{{ user.username }}</a></li> - {% endfor %} - </ul> - {% endblock %} - +```jinja +{% extends "base.html" %} +{% block title %}Members{% endblock %} +{% block content %} + <ul> + {% for user in users %} + <li><a href="{{ user.url }}">{{ user.username }}</a></li> + {% endfor %} + </ul> +{% endblock %} +``` ## Donate @@ -74,3 +74,11 @@ donate today][]. [please donate today]: https://palletsprojects.com/donate +## Contributing + +See our [detailed contributing documentation][contrib] for many ways to +contribute, including reporting issues, requesting features, asking or answering +questions, and making PRs. 
+ +[contrib]: https://palletsprojects.com/contributing/ + diff --git a/contrib/python/Jinja2/py3/README.md b/contrib/python/Jinja2/py3/README.md index 330970b5948..d1a6870d08a 100644 --- a/contrib/python/Jinja2/py3/README.md +++ b/contrib/python/Jinja2/py3/README.md @@ -27,18 +27,17 @@ restricting functionality too much. ## In A Nutshell -.. code-block:: jinja - - {% extends "base.html" %} - {% block title %}Members{% endblock %} - {% block content %} - <ul> - {% for user in users %} - <li><a href="{{ user.url }}">{{ user.username }}</a></li> - {% endfor %} - </ul> - {% endblock %} - +```jinja +{% extends "base.html" %} +{% block title %}Members{% endblock %} +{% block content %} + <ul> + {% for user in users %} + <li><a href="{{ user.url }}">{{ user.username }}</a></li> + {% endfor %} + </ul> +{% endblock %} +``` ## Donate @@ -48,3 +47,11 @@ allow the maintainers to devote more time to the projects, [please donate today][]. [please donate today]: https://palletsprojects.com/donate + +## Contributing + +See our [detailed contributing documentation][contrib] for many ways to +contribute, including reporting issues, requesting features, asking or answering +questions, and making PRs. 
+ +[contrib]: https://palletsprojects.com/contributing/ diff --git a/contrib/python/Jinja2/py3/jinja2/__init__.py b/contrib/python/Jinja2/py3/jinja2/__init__.py index 720343c0c6c..7972659ffab 100644 --- a/contrib/python/Jinja2/py3/jinja2/__init__.py +++ b/contrib/python/Jinja2/py3/jinja2/__init__.py @@ -36,4 +36,4 @@ from .utils import pass_environment as pass_environment from .utils import pass_eval_context as pass_eval_context from .utils import select_autoescape as select_autoescape -__version__ = "3.1.4" +__version__ = "3.1.6" diff --git a/contrib/python/Jinja2/py3/jinja2/async_utils.py b/contrib/python/Jinja2/py3/jinja2/async_utils.py index e65219e497b..f0c140205c5 100644 --- a/contrib/python/Jinja2/py3/jinja2/async_utils.py +++ b/contrib/python/Jinja2/py3/jinja2/async_utils.py @@ -6,6 +6,9 @@ from functools import wraps from .utils import _PassArg from .utils import pass_eval_context +if t.TYPE_CHECKING: + import typing_extensions as te + V = t.TypeVar("V") @@ -64,18 +67,30 @@ async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V": if inspect.isawaitable(value): return await t.cast("t.Awaitable[V]", value) - return t.cast("V", value) + return value + + +class _IteratorToAsyncIterator(t.Generic[V]): + def __init__(self, iterator: "t.Iterator[V]"): + self._iterator = iterator + + def __aiter__(self) -> "te.Self": + return self + + async def __anext__(self) -> V: + try: + return next(self._iterator) + except StopIteration as e: + raise StopAsyncIteration(e.value) from e -async def auto_aiter( +def auto_aiter( iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", ) -> "t.AsyncIterator[V]": if hasattr(iterable, "__aiter__"): - async for item in t.cast("t.AsyncIterable[V]", iterable): - yield item + return iterable.__aiter__() else: - for item in iterable: - yield item + return _IteratorToAsyncIterator(iter(iterable)) async def auto_to_list( diff --git a/contrib/python/Jinja2/py3/jinja2/compiler.py b/contrib/python/Jinja2/py3/jinja2/compiler.py index 
274071750f0..a4ff6a1b11a 100644 --- a/contrib/python/Jinja2/py3/jinja2/compiler.py +++ b/contrib/python/Jinja2/py3/jinja2/compiler.py @@ -55,7 +55,7 @@ def optimizeconst(f: F) -> F: return f(self, node, frame, **kwargs) - return update_wrapper(t.cast(F, new_func), f) + return update_wrapper(new_func, f) # type: ignore[return-value] def _make_binop(op: str) -> t.Callable[["CodeGenerator", nodes.BinExpr, "Frame"], None]: @@ -216,7 +216,7 @@ class Frame: # or compile time. self.soft_frame = False - def copy(self) -> "Frame": + def copy(self) -> "te.Self": """Create a copy of the current one.""" rv = object.__new__(self.__class__) rv.__dict__.update(self.__dict__) @@ -229,7 +229,7 @@ class Frame: return Frame(self.eval_ctx, level=self.symbols.level + 1) return Frame(self.eval_ctx, self) - def soft(self) -> "Frame": + def soft(self) -> "te.Self": """Return a soft frame. A soft frame may not be modified as standalone thing as it shares the resources with the frame it was created of, but it's not a rootlevel frame any longer. 
@@ -811,7 +811,7 @@ class CodeGenerator(NodeVisitor): self.writeline("_block_vars.update({") else: self.writeline("context.vars.update({") - for idx, name in enumerate(vars): + for idx, name in enumerate(sorted(vars)): if idx: self.write(", ") ref = frame.symbols.ref(name) @@ -821,7 +821,7 @@ class CodeGenerator(NodeVisitor): if len(public_names) == 1: self.writeline(f"context.exported_vars.add({public_names[0]!r})") else: - names_str = ", ".join(map(repr, public_names)) + names_str = ", ".join(map(repr, sorted(public_names))) self.writeline(f"context.exported_vars.update(({names_str}))") # -- Statement Visitors @@ -902,12 +902,15 @@ class CodeGenerator(NodeVisitor): if not self.environment.is_async: self.writeline("yield from parent_template.root_render_func(context)") else: - self.writeline( - "async for event in parent_template.root_render_func(context):" - ) + self.writeline("agen = parent_template.root_render_func(context)") + self.writeline("try:") + self.indent() + self.writeline("async for event in agen:") self.indent() self.writeline("yield event") self.outdent() + self.outdent() + self.writeline("finally: await agen.aclose()") self.outdent(1 + (not self.has_known_extends)) # at this point we now have the blocks collected and can visit them too. 
@@ -977,14 +980,20 @@ class CodeGenerator(NodeVisitor): f"yield from context.blocks[{node.name!r}][0]({context})", node ) else: + self.writeline(f"gen = context.blocks[{node.name!r}][0]({context})") + self.writeline("try:") + self.indent() self.writeline( - f"{self.choose_async()}for event in" - f" context.blocks[{node.name!r}][0]({context}):", + f"{self.choose_async()}for event in gen:", node, ) self.indent() self.simple_write("event", frame) self.outdent() + self.outdent() + self.writeline( + f"finally: {self.choose_async('await gen.aclose()', 'gen.close()')}" + ) self.outdent(level) @@ -1057,26 +1066,33 @@ class CodeGenerator(NodeVisitor): self.writeline("else:") self.indent() - skip_event_yield = False + def loop_body() -> None: + self.indent() + self.simple_write("event", frame) + self.outdent() + if node.with_context: self.writeline( - f"{self.choose_async()}for event in template.root_render_func(" + f"gen = template.root_render_func(" "template.new_context(context.get_all(), True," - f" {self.dump_local_context(frame)})):" + f" {self.dump_local_context(frame)}))" + ) + self.writeline("try:") + self.indent() + self.writeline(f"{self.choose_async()}for event in gen:") + loop_body() + self.outdent() + self.writeline( + f"finally: {self.choose_async('await gen.aclose()', 'gen.close()')}" ) elif self.environment.is_async: self.writeline( "for event in (await template._get_default_module_async())" "._body_stream:" ) + loop_body() else: self.writeline("yield from template._get_default_module()._body_stream") - skip_event_yield = True - - if not skip_event_yield: - self.indent() - self.simple_write("event", frame) - self.outdent() if node.ignore_missing: self.outdent() @@ -1125,9 +1141,14 @@ class CodeGenerator(NodeVisitor): ) self.writeline(f"if {frame.symbols.ref(alias)} is missing:") self.indent() + # The position will contain the template name, and will be formatted + # into a string that will be compiled into an f-string. 
Curly braces + # in the name must be replaced with escapes so that they will not be + # executed as part of the f-string. + position = self.position(node).replace("{", "{{").replace("}", "}}") message = ( "the template {included_template.__name__!r}" - f" (imported on {self.position(node)})" + f" (imported on {position})" f" does not export the requested name {name!r}" ) self.writeline( @@ -1560,6 +1581,29 @@ class CodeGenerator(NodeVisitor): def visit_Assign(self, node: nodes.Assign, frame: Frame) -> None: self.push_assign_tracking() + + # ``a.b`` is allowed for assignment, and is parsed as an NSRef. However, + # it is only valid if it references a Namespace object. Emit a check for + # that for each ref here, before assignment code is emitted. This can't + # be done in visit_NSRef as the ref could be in the middle of a tuple. + seen_refs: t.Set[str] = set() + + for nsref in node.find_all(nodes.NSRef): + if nsref.name in seen_refs: + # Only emit the check for each reference once, in case the same + # ref is used multiple times in a tuple, `ns.a, ns.b = c, d`. + continue + + seen_refs.add(nsref.name) + ref = frame.symbols.ref(nsref.name) + self.writeline(f"if not isinstance({ref}, Namespace):") + self.indent() + self.writeline( + "raise TemplateRuntimeError" + '("cannot assign attribute on non-namespace object")' + ) + self.outdent() + self.newline(node) self.visit(node.target, frame) self.write(" = ") @@ -1616,17 +1660,11 @@ class CodeGenerator(NodeVisitor): self.write(ref) def visit_NSRef(self, node: nodes.NSRef, frame: Frame) -> None: - # NSRefs can only be used to store values; since they use the normal - # `foo.bar` notation they will be parsed as a normal attribute access - # when used anywhere but in a `set` context + # NSRef is a dotted assignment target a.b=c, but uses a[b]=c internally. + # visit_Assign emits code to validate that each ref is to a Namespace + # object only. 
That can't be emitted here as the ref could be in the + # middle of a tuple assignment. ref = frame.symbols.ref(node.name) - self.writeline(f"if not isinstance({ref}, Namespace):") - self.indent() - self.writeline( - "raise TemplateRuntimeError" - '("cannot assign attribute on non-namespace object")' - ) - self.outdent() self.writeline(f"{ref}[{node.attr!r}]") def visit_Const(self, node: nodes.Const, frame: Frame) -> None: diff --git a/contrib/python/Jinja2/py3/jinja2/debug.py b/contrib/python/Jinja2/py3/jinja2/debug.py index 7ed7e9297e0..eeeeee78b62 100644 --- a/contrib/python/Jinja2/py3/jinja2/debug.py +++ b/contrib/python/Jinja2/py3/jinja2/debug.py @@ -152,7 +152,7 @@ def get_template_locals(real_locals: t.Mapping[str, t.Any]) -> t.Dict[str, t.Any available at that point in the template. """ # Start with the current template context. - ctx: "t.Optional[Context]" = real_locals.get("context") + ctx: t.Optional[Context] = real_locals.get("context") if ctx is not None: data: t.Dict[str, t.Any] = ctx.get_all().copy() diff --git a/contrib/python/Jinja2/py3/jinja2/environment.py b/contrib/python/Jinja2/py3/jinja2/environment.py index 1d3be0bed08..0fc6e5be87a 100644 --- a/contrib/python/Jinja2/py3/jinja2/environment.py +++ b/contrib/python/Jinja2/py3/jinja2/environment.py @@ -123,7 +123,7 @@ def load_extensions( return result -def _environment_config_check(environment: "Environment") -> "Environment": +def _environment_config_check(environment: _env_bound) -> _env_bound: """Perform a sanity check on the environment.""" assert issubclass( environment.undefined, Undefined @@ -406,8 +406,8 @@ class Environment: cache_size: int = missing, auto_reload: bool = missing, bytecode_cache: t.Optional["BytecodeCache"] = missing, - enable_async: bool = False, - ) -> "Environment": + enable_async: bool = missing, + ) -> "te.Self": """Create a new overlay environment that shares all the data with the current environment except for cache and the overridden attributes. 
Extensions cannot be removed for an overlayed environment. An overlayed @@ -419,8 +419,11 @@ class Environment: copied over so modifications on the original environment may not shine through. + .. versionchanged:: 3.1.5 + ``enable_async`` is applied correctly. + .. versionchanged:: 3.1.2 - Added the ``newline_sequence``,, ``keep_trailing_newline``, + Added the ``newline_sequence``, ``keep_trailing_newline``, and ``enable_async`` parameters to match ``__init__``. """ args = dict(locals()) @@ -706,7 +709,7 @@ class Environment: return compile(source, filename, "exec") @typing.overload - def compile( # type: ignore + def compile( self, source: t.Union[str, nodes.Template], name: t.Optional[str] = None, @@ -1248,7 +1251,7 @@ class Template: namespace: t.MutableMapping[str, t.Any], globals: t.MutableMapping[str, t.Any], ) -> "Template": - t: "Template" = object.__new__(cls) + t: Template = object.__new__(cls) t.environment = environment t.globals = globals t.name = namespace["name"] @@ -1282,19 +1285,7 @@ class Template: if self.environment.is_async: import asyncio - close = False - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - close = True - - try: - return loop.run_until_complete(self.render_async(*args, **kwargs)) - finally: - if close: - loop.close() + return asyncio.run(self.render_async(*args, **kwargs)) ctx = self.new_context(dict(*args, **kwargs)) @@ -1358,7 +1349,7 @@ class Template: async def generate_async( self, *args: t.Any, **kwargs: t.Any - ) -> t.AsyncIterator[str]: + ) -> t.AsyncGenerator[str, object]: """An async version of :meth:`generate`. Works very similarly but returns an async iterator instead. 
""" @@ -1370,8 +1361,14 @@ class Template: ctx = self.new_context(dict(*args, **kwargs)) try: - async for event in self.root_render_func(ctx): # type: ignore - yield event + agen = self.root_render_func(ctx) + try: + async for event in agen: # type: ignore + yield event + finally: + # we can't use async with aclosing(...) because that's only + # in 3.10+ + await agen.aclose() # type: ignore except Exception: yield self.environment.handle_exception() diff --git a/contrib/python/Jinja2/py3/jinja2/ext.py b/contrib/python/Jinja2/py3/jinja2/ext.py index 8d0810cd480..c7af8d45f06 100644 --- a/contrib/python/Jinja2/py3/jinja2/ext.py +++ b/contrib/python/Jinja2/py3/jinja2/ext.py @@ -89,7 +89,7 @@ class Extension: def __init__(self, environment: Environment) -> None: self.environment = environment - def bind(self, environment: Environment) -> "Extension": + def bind(self, environment: Environment) -> "te.Self": """Create a copy of this extension bound to another environment.""" rv = object.__new__(self.__class__) rv.__dict__.update(self.__dict__) diff --git a/contrib/python/Jinja2/py3/jinja2/filters.py b/contrib/python/Jinja2/py3/jinja2/filters.py index acd11976e4f..2bcba4fbd3c 100644 --- a/contrib/python/Jinja2/py3/jinja2/filters.py +++ b/contrib/python/Jinja2/py3/jinja2/filters.py @@ -6,6 +6,7 @@ import re import typing import typing as t from collections import abc +from inspect import getattr_static from itertools import chain from itertools import groupby @@ -438,7 +439,7 @@ def do_sort( @pass_environment -def do_unique( +def sync_do_unique( environment: "Environment", value: "t.Iterable[V]", case_sensitive: bool = False, @@ -470,6 +471,18 @@ def do_unique( yield item +@async_variant(sync_do_unique) # type: ignore +async def do_unique( + environment: "Environment", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + case_sensitive: bool = False, + attribute: t.Optional[t.Union[str, int]] = None, +) -> "t.Iterator[V]": + return sync_do_unique( + environment, await 
auto_to_list(value), case_sensitive, attribute + ) + + def _min_or_max( environment: "Environment", value: "t.Iterable[V]", @@ -987,7 +1000,7 @@ def do_int(value: t.Any, default: int = 0, base: int = 10) -> int: # this quirk is necessary so that "42.23"|int gives 42. try: return int(float(value)) - except (TypeError, ValueError): + except (TypeError, ValueError, OverflowError): return default @@ -1116,7 +1129,7 @@ def do_batch( {%- endfor %} </table> """ - tmp: "t.List[V]" = [] + tmp: t.List[V] = [] for item in value: if len(tmp) == linecount: @@ -1399,31 +1412,25 @@ def do_reverse(value: t.Union[str, t.Iterable[V]]) -> t.Union[str, t.Iterable[V] def do_attr( environment: "Environment", obj: t.Any, name: str ) -> t.Union[Undefined, t.Any]: - """Get an attribute of an object. ``foo|attr("bar")`` works like - ``foo.bar`` just that always an attribute is returned and items are not - looked up. + """Get an attribute of an object. ``foo|attr("bar")`` works like + ``foo.bar``, but returns undefined instead of falling back to ``foo["bar"]`` + if the attribute doesn't exist. See :ref:`Notes on subscriptions <notes-on-subscriptions>` for more details. """ + # Environment.getattr will fall back to obj[name] if obj.name doesn't exist. + # But we want to call env.getattr to get behavior such as sandboxing. + # Determine if the attr exists first, so we know the fallback won't trigger. try: - name = str(name) - except UnicodeError: - pass - else: - try: - value = getattr(obj, name) - except AttributeError: - pass - else: - if environment.sandboxed: - environment = t.cast("SandboxedEnvironment", environment) - - if not environment.is_safe_attribute(obj, name, value): - return environment.unsafe_undefined(obj, name) - - return value + # This avoids executing properties/descriptors, but misses __getattr__ + # and __getattribute__ dynamic attrs. + getattr_static(obj, name) + except AttributeError: + # This finds dynamic attrs, and we know it's not a descriptor at this point. 
+ if not hasattr(obj, name): + return environment.undefined(obj=obj, name=name) - return environment.undefined(obj=obj, name=name) + return environment.getattr(obj, name) @typing.overload @@ -1629,8 +1636,8 @@ def sync_do_selectattr( .. code-block:: python - (u for user in users if user.is_active) - (u for user in users if test_none(user.email)) + (user for user in users if user.is_active) + (user for user in users if test_none(user.email)) .. versionadded:: 2.7 """ @@ -1667,8 +1674,8 @@ def sync_do_rejectattr( .. code-block:: python - (u for user in users if not user.is_active) - (u for user in users if not test_none(user.email)) + (user for user in users if not user.is_active) + (user for user in users if not test_none(user.email)) .. versionadded:: 2.7 """ @@ -1768,7 +1775,7 @@ def prepare_select_or_reject( args = args[1 + off :] def func(item: t.Any) -> t.Any: - return context.environment.call_test(name, item, args, kwargs) + return context.environment.call_test(name, item, args, kwargs, context) except LookupError: func = bool # type: ignore diff --git a/contrib/python/Jinja2/py3/jinja2/idtracking.py b/contrib/python/Jinja2/py3/jinja2/idtracking.py index 995ebaa0c81..e6dd8cd1110 100644 --- a/contrib/python/Jinja2/py3/jinja2/idtracking.py +++ b/contrib/python/Jinja2/py3/jinja2/idtracking.py @@ -3,6 +3,9 @@ import typing as t from . 
import nodes from .visitor import NodeVisitor +if t.TYPE_CHECKING: + import typing_extensions as te + VAR_LOAD_PARAMETER = "param" VAR_LOAD_RESOLVE = "resolve" VAR_LOAD_ALIAS = "alias" @@ -83,7 +86,7 @@ class Symbols: ) return rv - def copy(self) -> "Symbols": + def copy(self) -> "te.Self": rv = object.__new__(self.__class__) rv.__dict__.update(self.__dict__) rv.refs = self.refs.copy() @@ -118,23 +121,20 @@ class Symbols: self._define_ref(name, load=(VAR_LOAD_RESOLVE, name)) def branch_update(self, branch_symbols: t.Sequence["Symbols"]) -> None: - stores: t.Dict[str, int] = {} + stores: t.Set[str] = set() + for branch in branch_symbols: - for target in branch.stores: - if target in self.stores: - continue - stores[target] = stores.get(target, 0) + 1 + stores.update(branch.stores) + + stores.difference_update(self.stores) for sym in branch_symbols: self.refs.update(sym.refs) self.loads.update(sym.loads) self.stores.update(sym.stores) - for name, branch_count in stores.items(): - if branch_count == len(branch_symbols): - continue - - target = self.find_ref(name) # type: ignore + for name in stores: + target = self.find_ref(name) assert target is not None, "should not happen" if self.parent is not None: @@ -146,7 +146,7 @@ class Symbols: def dump_stores(self) -> t.Dict[str, str]: rv: t.Dict[str, str] = {} - node: t.Optional["Symbols"] = self + node: t.Optional[Symbols] = self while node is not None: for name in sorted(node.stores): @@ -159,7 +159,7 @@ class Symbols: def dump_param_targets(self) -> t.Set[str]: rv = set() - node: t.Optional["Symbols"] = self + node: t.Optional[Symbols] = self while node is not None: for target, (instr, _) in self.loads.items(): diff --git a/contrib/python/Jinja2/py3/jinja2/lexer.py b/contrib/python/Jinja2/py3/jinja2/lexer.py index 62b0471a3a6..9b1c969791f 100644 --- a/contrib/python/Jinja2/py3/jinja2/lexer.py +++ b/contrib/python/Jinja2/py3/jinja2/lexer.py @@ -262,7 +262,7 @@ class Failure: self.message = message self.error_class = cls 
- def __call__(self, lineno: int, filename: str) -> "te.NoReturn": + def __call__(self, lineno: int, filename: t.Optional[str]) -> "te.NoReturn": raise self.error_class(self.message, lineno, filename) @@ -329,7 +329,7 @@ class TokenStream: filename: t.Optional[str], ): self._iter = iter(generator) - self._pushed: "te.Deque[Token]" = deque() + self._pushed: te.Deque[Token] = deque() self.name = name self.filename = filename self.closed = False @@ -757,7 +757,7 @@ class Lexer: for idx, token in enumerate(tokens): # failure group - if token.__class__ is Failure: + if isinstance(token, Failure): raise token(lineno, filename) # bygroup is a bit more complex, in that case we # yield for the current token the first named @@ -778,7 +778,7 @@ class Lexer: data = groups[idx] if data or token not in ignore_if_empty: - yield lineno, token, data + yield lineno, token, data # type: ignore[misc] lineno += data.count("\n") + newlines_stripped newlines_stripped = 0 diff --git a/contrib/python/Jinja2/py3/jinja2/loaders.py b/contrib/python/Jinja2/py3/jinja2/loaders.py index 24951c35d2b..809b1dd7df5 100644 --- a/contrib/python/Jinja2/py3/jinja2/loaders.py +++ b/contrib/python/Jinja2/py3/jinja2/loaders.py @@ -7,6 +7,7 @@ import os import posixpath import sys import pkgutil +import importlib.resources import typing as t import weakref import zipimport @@ -207,7 +208,12 @@ class FileSystemLoader(BaseLoader): if os.path.isfile(filename): break else: - raise TemplateNotFound(template) + plural = "path" if len(self.searchpath) == 1 else "paths" + paths_str = ", ".join(repr(p) for p in self.searchpath) + raise TemplateNotFound( + template, + f"{template!r} not found in search {plural}: {paths_str}", + ) with open(filename, encoding=self.encoding) as f: contents = f.read() @@ -241,6 +247,30 @@ class FileSystemLoader(BaseLoader): return sorted(found) +if sys.version_info >= (3, 13): + + def _get_zipimporter_files(z: t.Any) -> t.Dict[str, object]: + try: + get_files = z._get_files + except 
AttributeError as e: + raise TypeError( + "This zip import does not have the required" + " metadata to list templates." + ) from e + return get_files() +else: + + def _get_zipimporter_files(z: t.Any) -> t.Dict[str, object]: + try: + files = z._files + except AttributeError as e: + raise TypeError( + "This zip import does not have the required" + " metadata to list templates." + ) from e + return files # type: ignore[no-any-return] + + class PackageLoader(BaseLoader): """Load templates from a directory in a Python package. @@ -280,6 +310,7 @@ class PackageLoader(BaseLoader): package_path: "str" = "templates", encoding: str = "utf-8", skip_unknown_package: bool = False, + check_templates: bool = False, ) -> None: package_path = os.path.normpath(package_path).rstrip(os.path.sep) @@ -310,7 +341,7 @@ class PackageLoader(BaseLoader): self._loader = loader self._archive = None self._package = package - template_root = None + self._check_templates = check_templates if isinstance(loader, zipimport.zipimporter): self._archive = loader.archive @@ -318,6 +349,15 @@ class PackageLoader(BaseLoader): template_root = os.path.join(pkgdir, package_path).rstrip(os.path.sep) elif hasattr(loader, "arcadia_source_finder"): template_root = os.path.dirname(package.__file__).rstrip(os.path.sep) + try: + if package_path: + importlib.resources.files(package_name).joinpath(package_path) + except FileNotFoundError: + if self._check_templates: + raise ValueError( + f"PackageLoader could not find a {package_path!r} directory" + f" in the {package_name!r} package." + ) else: roots: t.List[str] = [] @@ -329,18 +369,23 @@ class PackageLoader(BaseLoader): elif spec.origin is not None: roots.append(os.path.dirname(spec.origin)) + if not roots: + raise ValueError( + f"The {package_name!r} package was not installed in a" + " way that PackageLoader understands." 
+ ) + for root in roots: root = os.path.join(root, package_path) if os.path.isdir(root): template_root = root break - - if template_root is None: - raise ValueError( - f"The {package_name!r} package was not installed in a" - " way that PackageLoader understands." - ) + else: + raise ValueError( + f"PackageLoader could not find a {package_path!r} directory" + f" in the {package_name!r} package." + ) self._template_root = template_root @@ -416,11 +461,7 @@ class PackageLoader(BaseLoader): for name in filenames ) else: - if not hasattr(self._loader, "_files"): - raise TypeError( - "This zip import does not have the required" - " metadata to list templates." - ) + files = _get_zipimporter_files(self._loader) # Package is a zip file. prefix = ( @@ -429,7 +470,7 @@ class PackageLoader(BaseLoader): ) offset = len(prefix) - for name in self._loader._files.keys(): + for name in files: # Find names under the templates directory that aren't directories. if name.startswith(prefix) and name[-1] != os.path.sep: results.append(name[offset:].replace(os.path.sep, "/")) @@ -444,7 +485,7 @@ class DictLoader(BaseLoader): >>> loader = DictLoader({'index.html': 'source here'}) - Because auto reloading is rarely useful this is disabled per default. + Because auto reloading is rarely useful this is disabled by default. """ def __init__(self, mapping: t.Mapping[str, str]) -> None: @@ -627,10 +668,7 @@ class ModuleLoader(BaseLoader): Example usage: - >>> loader = ChoiceLoader([ - ... ModuleLoader('/path/to/compiled/templates'), - ... FileSystemLoader('/path/to/templates') - ... ]) + >>> loader = ModuleLoader('/path/to/compiled/templates') Templates can be precompiled with :meth:`Environment.compile_templates`. 
""" diff --git a/contrib/python/Jinja2/py3/jinja2/parser.py b/contrib/python/Jinja2/py3/jinja2/parser.py index 0ec997fb499..f4117754aaf 100644 --- a/contrib/python/Jinja2/py3/jinja2/parser.py +++ b/contrib/python/Jinja2/py3/jinja2/parser.py @@ -64,7 +64,7 @@ class Parser: self.filename = filename self.closed = False self.extensions: t.Dict[ - str, t.Callable[["Parser"], t.Union[nodes.Node, t.List[nodes.Node]]] + str, t.Callable[[Parser], t.Union[nodes.Node, t.List[nodes.Node]]] ] = {} for extension in environment.iter_extensions(): for tag in extension.tags: @@ -487,21 +487,18 @@ class Parser: """ target: nodes.Expr - if with_namespace and self.stream.look().type == "dot": - token = self.stream.expect("name") - next(self.stream) # dot - attr = self.stream.expect("name") - target = nodes.NSRef(token.value, attr.value, lineno=token.lineno) - elif name_only: + if name_only: token = self.stream.expect("name") target = nodes.Name(token.value, "store", lineno=token.lineno) else: if with_tuple: target = self.parse_tuple( - simplified=True, extra_end_rules=extra_end_rules + simplified=True, + extra_end_rules=extra_end_rules, + with_namespace=with_namespace, ) else: - target = self.parse_primary() + target = self.parse_primary(with_namespace=with_namespace) target.set_ctx("store") @@ -643,17 +640,25 @@ class Parser: node = self.parse_filter_expr(node) return node - def parse_primary(self) -> nodes.Expr: + def parse_primary(self, with_namespace: bool = False) -> nodes.Expr: + """Parse a name or literal value. 
If ``with_namespace`` is enabled, also + parse namespace attr refs, for use in assignments.""" token = self.stream.current node: nodes.Expr if token.type == "name": + next(self.stream) if token.value in ("true", "false", "True", "False"): node = nodes.Const(token.value in ("true", "True"), lineno=token.lineno) elif token.value in ("none", "None"): node = nodes.Const(None, lineno=token.lineno) + elif with_namespace and self.stream.current.type == "dot": + # If namespace attributes are allowed at this point, and the next + # token is a dot, produce a namespace reference. + next(self.stream) + attr = self.stream.expect("name") + node = nodes.NSRef(token.value, attr.value, lineno=token.lineno) else: node = nodes.Name(token.value, "load", lineno=token.lineno) - next(self.stream) elif token.type == "string": next(self.stream) buf = [token.value] @@ -683,6 +688,7 @@ class Parser: with_condexpr: bool = True, extra_end_rules: t.Optional[t.Tuple[str, ...]] = None, explicit_parentheses: bool = False, + with_namespace: bool = False, ) -> t.Union[nodes.Tuple, nodes.Expr]: """Works like `parse_expression` but if multiple expressions are delimited by a comma a :class:`~jinja2.nodes.Tuple` node is created. @@ -690,8 +696,9 @@ class Parser: if no commas where found. The default parsing mode is a full tuple. If `simplified` is `True` - only names and literals are parsed. The `no_condexpr` parameter is - forwarded to :meth:`parse_expression`. + only names and literals are parsed; ``with_namespace`` allows namespace + attr refs as well. The `no_condexpr` parameter is forwarded to + :meth:`parse_expression`. Because tuples do not require delimiters and may end in a bogus comma an extra hint is needed that marks the end of a tuple. 
For example @@ -704,13 +711,14 @@ class Parser: """ lineno = self.stream.current.lineno if simplified: - parse = self.parse_primary - elif with_condexpr: - parse = self.parse_expression + + def parse() -> nodes.Expr: + return self.parse_primary(with_namespace=with_namespace) + else: def parse() -> nodes.Expr: - return self.parse_expression(with_condexpr=False) + return self.parse_expression(with_condexpr=with_condexpr) args: t.List[nodes.Expr] = [] is_tuple = False diff --git a/contrib/python/Jinja2/py3/jinja2/runtime.py b/contrib/python/Jinja2/py3/jinja2/runtime.py index 4325c8deb22..09119e2ae55 100644 --- a/contrib/python/Jinja2/py3/jinja2/runtime.py +++ b/contrib/python/Jinja2/py3/jinja2/runtime.py @@ -172,7 +172,7 @@ class Context: ): self.parent = parent self.vars: t.Dict[str, t.Any] = {} - self.environment: "Environment" = environment + self.environment: Environment = environment self.eval_ctx = EvalContext(self.environment, name) self.exported_vars: t.Set[str] = set() self.name = name @@ -367,7 +367,7 @@ class BlockReference: @internalcode async def _async_call(self) -> str: - rv = concat( + rv = self._context.environment.concat( # type: ignore [x async for x in self._stack[self._depth](self._context)] # type: ignore ) @@ -381,7 +381,9 @@ class BlockReference: if self._context.environment.is_async: return self._async_call() # type: ignore - rv = concat(self._stack[self._depth](self._context)) + rv = self._context.environment.concat( # type: ignore + self._stack[self._depth](self._context) + ) if self._context.eval_ctx.autoescape: return Markup(rv) @@ -792,8 +794,8 @@ class Macro: class Undefined: - """The default undefined type. This undefined type can be printed and - iterated over, but every other access will raise an :exc:`UndefinedError`: + """The default undefined type. This can be printed, iterated, and treated as + a boolean. Any other operation will raise an :exc:`UndefinedError`. 
>>> foo = Undefined(name='foo') >>> str(foo) @@ -858,7 +860,11 @@ class Undefined: @internalcode def __getattr__(self, name: str) -> t.Any: - if name[:2] == "__": + # Raise AttributeError on requests for names that appear to be unimplemented + # dunder methods to keep Python's internal protocol probing behaviors working + # properly in cases where another exception type could cause unexpected or + # difficult-to-diagnose failures. + if name[:2] == "__" and name[-2:] == "__": raise AttributeError(name) return self._fail_with_undefined_error() @@ -982,10 +988,20 @@ class ChainableUndefined(Undefined): def __html__(self) -> str: return str(self) - def __getattr__(self, _: str) -> "ChainableUndefined": + def __getattr__(self, name: str) -> "ChainableUndefined": + # Raise AttributeError on requests for names that appear to be unimplemented + # dunder methods to avoid confusing Python with truthy non-method objects that + # do not implement the protocol being probed for. e.g., copy.copy(Undefined()) + # fails spectacularly if getattr(Undefined(), '__setstate__') returns an + # Undefined object instead of raising AttributeError to signal that it does not + # support that style of object initialization. + if name[:2] == "__" and name[-2:] == "__": + raise AttributeError(name) + return self - __getitem__ = __getattr__ # type: ignore + def __getitem__(self, _name: str) -> "ChainableUndefined": # type: ignore[override] + return self class DebugUndefined(Undefined): @@ -1044,13 +1060,3 @@ class StrictUndefined(Undefined): __iter__ = __str__ = __len__ = Undefined._fail_with_undefined_error __eq__ = __ne__ = __bool__ = __hash__ = Undefined._fail_with_undefined_error __contains__ = Undefined._fail_with_undefined_error - - -# Remove slots attributes, after the metaclass is applied they are -# unneeded and contain wrong data for subclasses. 
-del ( - Undefined.__slots__, - ChainableUndefined.__slots__, - DebugUndefined.__slots__, - StrictUndefined.__slots__, -) diff --git a/contrib/python/Jinja2/py3/jinja2/sandbox.py b/contrib/python/Jinja2/py3/jinja2/sandbox.py index 0b4fc12d347..9c9dae22f1c 100644 --- a/contrib/python/Jinja2/py3/jinja2/sandbox.py +++ b/contrib/python/Jinja2/py3/jinja2/sandbox.py @@ -5,11 +5,12 @@ Useful when the template itself comes from an untrusted source. import operator import types import typing as t +from _string import formatter_field_name_split # type: ignore from collections import abc from collections import deque +from functools import update_wrapper from string import Formatter -from _string import formatter_field_name_split # type: ignore from markupsafe import EscapeFormatter from markupsafe import Markup @@ -60,7 +61,9 @@ _mutable_spec: t.Tuple[t.Tuple[t.Type[t.Any], t.FrozenSet[str]], ...] = ( ), ( abc.MutableSequence, - frozenset(["append", "reverse", "insert", "sort", "extend", "remove"]), + frozenset( + ["append", "clear", "pop", "reverse", "insert", "sort", "extend", "remove"] + ), ), ( deque, @@ -81,20 +84,6 @@ _mutable_spec: t.Tuple[t.Tuple[t.Type[t.Any], t.FrozenSet[str]], ...] = ( ) -def inspect_format_method(callable: t.Callable[..., t.Any]) -> t.Optional[str]: - if not isinstance( - callable, (types.MethodType, types.BuiltinMethodType) - ) or callable.__name__ not in ("format", "format_map"): - return None - - obj = callable.__self__ - - if isinstance(obj, str): - return obj - - return None - - def safe_range(*args: int) -> range: """A range that can't generate ranges with a length of more than MAX_RANGE items. 
@@ -314,6 +303,9 @@ class SandboxedEnvironment(Environment): except AttributeError: pass else: + fmt = self.wrap_str_format(value) + if fmt is not None: + return fmt if self.is_safe_attribute(obj, argument, value): return value return self.unsafe_undefined(obj, argument) @@ -331,6 +323,9 @@ class SandboxedEnvironment(Environment): except (TypeError, LookupError): pass else: + fmt = self.wrap_str_format(value) + if fmt is not None: + return fmt if self.is_safe_attribute(obj, attribute, value): return value return self.unsafe_undefined(obj, attribute) @@ -346,34 +341,49 @@ class SandboxedEnvironment(Environment): exc=SecurityError, ) - def format_string( - self, - s: str, - args: t.Tuple[t.Any, ...], - kwargs: t.Dict[str, t.Any], - format_func: t.Optional[t.Callable[..., t.Any]] = None, - ) -> str: - """If a format call is detected, then this is routed through this - method so that our safety sandbox can be used for it. + def wrap_str_format(self, value: t.Any) -> t.Optional[t.Callable[..., str]]: + """If the given value is a ``str.format`` or ``str.format_map`` method, + return a new function than handles sandboxing. This is done at access + rather than in :meth:`call`, so that calls made without ``call`` are + also sandboxed. 
""" + if not isinstance( + value, (types.MethodType, types.BuiltinMethodType) + ) or value.__name__ not in ("format", "format_map"): + return None + + f_self: t.Any = value.__self__ + + if not isinstance(f_self, str): + return None + + str_type: t.Type[str] = type(f_self) + is_format_map = value.__name__ == "format_map" formatter: SandboxedFormatter - if isinstance(s, Markup): - formatter = SandboxedEscapeFormatter(self, escape=s.escape) + + if isinstance(f_self, Markup): + formatter = SandboxedEscapeFormatter(self, escape=f_self.escape) else: formatter = SandboxedFormatter(self) - if format_func is not None and format_func.__name__ == "format_map": - if len(args) != 1 or kwargs: - raise TypeError( - "format_map() takes exactly one argument" - f" {len(args) + (kwargs is not None)} given" - ) + vformat = formatter.vformat + + def wrapper(*args: t.Any, **kwargs: t.Any) -> str: + if is_format_map: + if kwargs: + raise TypeError("format_map() takes no keyword arguments") + + if len(args) != 1: + raise TypeError( + f"format_map() takes exactly one argument ({len(args)} given)" + ) + + kwargs = args[0] + args = () - kwargs = args[0] - args = () + return str_type(vformat(f_self, args, kwargs)) - rv = formatter.vformat(s, args, kwargs) - return type(s)(rv) + return update_wrapper(wrapper, value) def call( __self, # noqa: B902 @@ -383,9 +393,6 @@ class SandboxedEnvironment(Environment): **kwargs: t.Any, ) -> t.Any: """Call an object from sandboxed code.""" - fmt = inspect_format_method(__obj) - if fmt is not None: - return __self.format_string(fmt, args, kwargs, __obj) # the double prefixes are to avoid double keyword argument # errors when proxying the call. 
diff --git a/contrib/python/Jinja2/py3/jinja2/utils.py b/contrib/python/Jinja2/py3/jinja2/utils.py index 7fb76935aa3..7c922629a92 100644 --- a/contrib/python/Jinja2/py3/jinja2/utils.py +++ b/contrib/python/Jinja2/py3/jinja2/utils.py @@ -18,8 +18,17 @@ if t.TYPE_CHECKING: F = t.TypeVar("F", bound=t.Callable[..., t.Any]) -# special singleton representing missing values for the runtime -missing: t.Any = type("MissingType", (), {"__repr__": lambda x: "missing"})() + +class _MissingType: + def __repr__(self) -> str: + return "missing" + + def __reduce__(self) -> str: + return "missing" + + +missing: t.Any = _MissingType() +"""Special singleton representing missing values for the runtime.""" internal_code: t.MutableSet[CodeType] = set() @@ -324,6 +333,8 @@ def urlize( elif ( "@" in middle and not middle.startswith("www.") + # ignore values like `@a@b` + and not middle.startswith("@") and ":" not in middle and _email_re.match(middle) ): @@ -428,7 +439,7 @@ class LRUCache: def __init__(self, capacity: int) -> None: self.capacity = capacity self._mapping: t.Dict[t.Any, t.Any] = {} - self._queue: "te.Deque[t.Any]" = deque() + self._queue: te.Deque[t.Any] = deque() self._postinit() def _postinit(self) -> None: @@ -453,7 +464,7 @@ class LRUCache: def __getnewargs__(self) -> t.Tuple[t.Any, ...]: return (self.capacity,) - def copy(self) -> "LRUCache": + def copy(self) -> "te.Self": """Return a shallow copy of the instance.""" rv = self.__class__(self.capacity) rv._mapping.update(self._mapping) diff --git a/contrib/python/Jinja2/py3/patches/03-fix-PackageLoader.patch b/contrib/python/Jinja2/py3/patches/03-fix-PackageLoader.patch index d4c98d43f53..7713e499611 100644 --- a/contrib/python/Jinja2/py3/patches/03-fix-PackageLoader.patch +++ b/contrib/python/Jinja2/py3/patches/03-fix-PackageLoader.patch @@ -1,10 +1,11 @@ --- contrib/python/Jinja2/py3/jinja2/loaders.py (index) +++ contrib/python/Jinja2/py3/jinja2/loaders.py (working tree) -@@ -5,6 +5,7 @@ sources. 
+@@ -5,6 +5,8 @@ sources. import os import posixpath import sys +import pkgutil ++import importlib.resources import typing as t import weakref import zipimport @@ -17,7 +18,7 @@ def split_template_path(template: str) -> t.List[str]: """Split a path into segments and perform a sanity check. If it detects -@@ -288,19 +291,22 @@ class PackageLoader(BaseLoader): +@@ -288,18 +291,31 @@ class PackageLoader(BaseLoader): # Make sure the package exists. This also makes namespace # packages work, otherwise get_loader returns None. @@ -30,7 +31,7 @@ self._loader = loader self._archive = None + self._package = package - template_root = None ++ self._check_templates = False if isinstance(loader, zipimport.zipimporter): self._archive = loader.archive @@ -38,6 +39,15 @@ template_root = os.path.join(pkgdir, package_path).rstrip(os.path.sep) + elif hasattr(loader, "arcadia_source_finder"): + template_root = os.path.dirname(package.__file__).rstrip(os.path.sep) ++ try: ++ if package_path: ++ importlib.resources.files(package_name).joinpath(package_path) ++ except FileNotFoundError: ++ if self._check_templates: ++ raise ValueError( ++ f"PackageLoader could not find a {package_path!r} directory" ++ f" in the {package_name!r} package." 
++ ) else: roots: t.List[str] = [] diff --git a/contrib/python/Jinja2/py3/patches/04-fix-PackageLoader-2.patch b/contrib/python/Jinja2/py3/patches/04-fix-PackageLoader-2.patch new file mode 100644 index 00000000000..d130ba0fbda --- /dev/null +++ b/contrib/python/Jinja2/py3/patches/04-fix-PackageLoader-2.patch @@ -0,0 +1,27 @@ +--- contrib/python/Jinja2/py3/jinja2/loaders.py (index) ++++ contrib/python/Jinja2/py3/jinja2/loaders.py (working tree) +@@ -310,6 +310,7 @@ class PackageLoader(BaseLoader): + package_path: "str" = "templates", + encoding: str = "utf-8", + skip_unknown_package: bool = False, ++ check_templates: bool = False, + ) -> None: + package_path = os.path.normpath(package_path).rstrip(os.path.sep) + +@@ -340,7 +341,7 @@ class PackageLoader(BaseLoader): + self._loader = loader + self._archive = None + self._package = package +- self._check_templates = False ++ self._check_templates = check_templates + + if isinstance(loader, zipimport.zipimporter): + self._archive = loader.archive +--- contrib/python/Jinja2/py3/tests/test_loader.py (index) ++++ contrib/python/Jinja2/py3/tests/test_loader.py (working tree) +@@ -436,4 +436,4 @@ def test_pep_451_import_hook(): + + def test_package_loader_no_dir() -> None: + with pytest.raises(ValueError, match="could not find a 'templates' directory"): +- PackageLoader("jinja2") ++ PackageLoader("jinja2", check_templates=True) diff --git a/contrib/python/Jinja2/py3/tests/test_api.py b/contrib/python/Jinja2/py3/tests/test_api.py index ff3fcb138bb..4472b85ac00 100644 --- a/contrib/python/Jinja2/py3/tests/test_api.py +++ b/contrib/python/Jinja2/py3/tests/test_api.py @@ -323,8 +323,6 @@ class TestUndefined: assert und1 == und2 assert und1 != 42 assert hash(und1) == hash(und2) == hash(Undefined()) - with pytest.raises(AttributeError): - getattr(Undefined, "__slots__") # noqa: B009 def test_chainable_undefined(self): env = Environment(undefined=ChainableUndefined) @@ -335,8 +333,6 @@ class TestUndefined: assert 
env.from_string("{{ foo.missing }}").render(foo=42) == "" assert env.from_string("{{ not missing }}").render() == "True" pytest.raises(UndefinedError, env.from_string("{{ missing - 1}}").render) - with pytest.raises(AttributeError): - getattr(ChainableUndefined, "__slots__") # noqa: B009 # The following tests ensure subclass functionality works as expected assert env.from_string('{{ missing.bar["baz"] }}').render() == "" @@ -368,8 +364,6 @@ class TestUndefined: str(DebugUndefined(hint=undefined_hint)) == f"{{{{ undefined value printed: {undefined_hint} }}}}" ) - with pytest.raises(AttributeError): - getattr(DebugUndefined, "__slots__") # noqa: B009 def test_strict_undefined(self): env = Environment(undefined=StrictUndefined) @@ -386,8 +380,6 @@ class TestUndefined: env.from_string('{{ missing|default("default", true) }}').render() == "default" ) - with pytest.raises(AttributeError): - getattr(StrictUndefined, "__slots__") # noqa: B009 assert env.from_string('{{ "foo" if false }}').render() == "" def test_indexing_gives_undefined(self): @@ -433,3 +425,11 @@ class TestLowLevel: env = CustomEnvironment() tmpl = env.from_string("{{ foo }}") assert tmpl.render() == "resolve-foo" + + +def test_overlay_enable_async(env): + assert not env.is_async + assert not env.overlay().is_async + env_async = env.overlay(enable_async=True) + assert env_async.is_async + assert not env_async.overlay(enable_async=False).is_async diff --git a/contrib/python/Jinja2/py3/tests/test_async.py b/contrib/python/Jinja2/py3/tests/test_async.py index c9ba70c3e66..4edced9dd9f 100644 --- a/contrib/python/Jinja2/py3/tests/test_async.py +++ b/contrib/python/Jinja2/py3/tests/test_async.py @@ -1,6 +1,7 @@ import asyncio import pytest +import trio from jinja2 import ChainableUndefined from jinja2 import DictLoader @@ -13,7 +14,16 @@ from jinja2.exceptions import UndefinedError from jinja2.nativetypes import NativeEnvironment -def test_basic_async(): +def _asyncio_run(async_fn, *args): + return 
asyncio.run(async_fn(*args)) + + [email protected](params=[_asyncio_run, trio.run], ids=["asyncio", "trio"]) +def run_async_fn(request): + return request.param + + +def test_basic_async(run_async_fn): t = Template( "{% for item in [1, 2, 3] %}[{{ item }}]{% endfor %}", enable_async=True ) @@ -21,11 +31,11 @@ def test_basic_async(): async def func(): return await t.render_async() - rv = asyncio.run(func()) + rv = run_async_fn(func) assert rv == "[1][2][3]" -def test_await_on_calls(): +def test_await_on_calls(run_async_fn): t = Template("{{ async_func() + normal_func() }}", enable_async=True) async def async_func(): @@ -37,7 +47,7 @@ def test_await_on_calls(): async def func(): return await t.render_async(async_func=async_func, normal_func=normal_func) - rv = asyncio.run(func()) + rv = run_async_fn(func) assert rv == "65" @@ -54,7 +64,7 @@ def test_await_on_calls_normal_render(): assert rv == "65" -def test_await_and_macros(): +def test_await_and_macros(run_async_fn): t = Template( "{% macro foo(x) %}[{{ x }}][{{ async_func() }}]{% endmacro %}{{ foo(42) }}", enable_async=True, @@ -66,11 +76,11 @@ def test_await_and_macros(): async def func(): return await t.render_async(async_func=async_func) - rv = asyncio.run(func()) + rv = run_async_fn(func) assert rv == "[42][42]" -def test_async_blocks(): +def test_async_blocks(run_async_fn): t = Template( "{% block foo %}<Test>{% endblock %}{{ self.foo() }}", enable_async=True, @@ -80,7 +90,7 @@ def test_async_blocks(): async def func(): return await t.render_async() - rv = asyncio.run(func()) + rv = run_async_fn(func) assert rv == "<Test><Test>" @@ -156,8 +166,8 @@ class TestAsyncImports: test_env_async.from_string('{% from "foo" import bar, with, context %}') test_env_async.from_string('{% from "foo" import bar, with with context %}') - def test_exports(self, test_env_async): - coro = test_env_async.from_string( + def test_exports(self, test_env_async, run_async_fn): + coro_fn = test_env_async.from_string( """ {% macro 
toplevel() %}...{% endmacro %} {% macro __private() %}...{% endmacro %} @@ -166,9 +176,9 @@ class TestAsyncImports: {% macro notthere() %}{% endmacro %} {% endfor %} """ - )._get_default_module_async() - m = asyncio.run(coro) - assert asyncio.run(m.toplevel()) == "..." + )._get_default_module_async + m = run_async_fn(coro_fn) + assert run_async_fn(m.toplevel) == "..." assert not hasattr(m, "__missing") assert m.variable == 42 assert not hasattr(m, "notthere") @@ -457,17 +467,19 @@ class TestAsyncForLoop: ) assert tmpl.render(items=reversed([3, 2, 1])) == "1,2,3" - def test_loop_errors(self, test_env_async): + def test_loop_errors(self, test_env_async, run_async_fn): tmpl = test_env_async.from_string( """{% for item in [1] if loop.index == 0 %}...{% endfor %}""" ) - pytest.raises(UndefinedError, tmpl.render) + with pytest.raises(UndefinedError): + run_async_fn(tmpl.render_async) + tmpl = test_env_async.from_string( """{% for item in [] %}...{% else %}{{ loop }}{% endfor %}""" ) - assert tmpl.render() == "" + assert run_async_fn(tmpl.render_async) == "" def test_loop_filter(self, test_env_async): tmpl = test_env_async.from_string( @@ -597,7 +609,7 @@ class TestAsyncForLoop: assert t.render(a=dict(b=[1, 2, 3])) == "1" -def test_namespace_awaitable(test_env_async): +def test_namespace_awaitable(test_env_async, run_async_fn): async def _test(): t = test_env_async.from_string( '{% set ns = namespace(foo="Bar") %}{{ ns.foo }}' @@ -605,10 +617,10 @@ def test_namespace_awaitable(test_env_async): actual = await t.render_async() assert actual == "Bar" - asyncio.run(_test()) + run_async_fn(_test) -def test_chainable_undefined_aiter(): +def test_chainable_undefined_aiter(run_async_fn): async def _test(): t = Template( "{% for x in a['b']['c'] %}{{ x }}{% endfor %}", @@ -618,7 +630,7 @@ def test_chainable_undefined_aiter(): rv = await t.render_async(a={}) assert rv == "" - asyncio.run(_test()) + run_async_fn(_test) @pytest.fixture @@ -626,22 +638,22 @@ def async_native_env(): 
return NativeEnvironment(enable_async=True) -def test_native_async(async_native_env): +def test_native_async(async_native_env, run_async_fn): async def _test(): t = async_native_env.from_string("{{ x }}") rv = await t.render_async(x=23) assert rv == 23 - asyncio.run(_test()) + run_async_fn(_test) -def test_native_list_async(async_native_env): +def test_native_list_async(async_native_env, run_async_fn): async def _test(): t = async_native_env.from_string("{{ x }}") rv = await t.render_async(x=list(range(3))) assert rv == [0, 1, 2] - asyncio.run(_test()) + run_async_fn(_test) def test_getitem_after_filter(): @@ -658,3 +670,65 @@ def test_getitem_after_call(): t = env.from_string("{{ add_each(a, 2)[1:] }}") out = t.render(a=range(3)) assert out == "[3, 4]" + + +def test_basic_generate_async(run_async_fn): + t = Template( + "{% for item in [1, 2, 3] %}[{{ item }}]{% endfor %}", enable_async=True + ) + + async def func(): + agen = t.generate_async() + try: + return await agen.__anext__() + finally: + await agen.aclose() + + rv = run_async_fn(func) + assert rv == "[" + + +def test_include_generate_async(run_async_fn, test_env_async): + t = test_env_async.from_string('{% include "header" %}') + + async def func(): + agen = t.generate_async() + try: + return await agen.__anext__() + finally: + await agen.aclose() + + rv = run_async_fn(func) + assert rv == "[" + + +def test_blocks_generate_async(run_async_fn): + t = Template( + "{% block foo %}<Test>{% endblock %}{{ self.foo() }}", + enable_async=True, + autoescape=True, + ) + + async def func(): + agen = t.generate_async() + try: + return await agen.__anext__() + finally: + await agen.aclose() + + rv = run_async_fn(func) + assert rv == "<Test>" + + +def test_async_extend(run_async_fn, test_env_async): + t = test_env_async.from_string('{% extends "header" %}') + + async def func(): + agen = t.generate_async() + try: + return await agen.__anext__() + finally: + await agen.aclose() + + rv = run_async_fn(func) + assert rv == 
"[" diff --git a/contrib/python/Jinja2/py3/tests/test_async_filters.py b/contrib/python/Jinja2/py3/tests/test_async_filters.py index f5b2627ad87..e9892f1edcd 100644 --- a/contrib/python/Jinja2/py3/tests/test_async_filters.py +++ b/contrib/python/Jinja2/py3/tests/test_async_filters.py @@ -1,6 +1,9 @@ +import asyncio +import contextlib from collections import namedtuple import pytest +import trio from markupsafe import Markup from jinja2 import Environment @@ -26,10 +29,39 @@ def env_async(): return Environment(enable_async=True) +def _asyncio_run(async_fn, *args): + return asyncio.run(async_fn(*args)) + + [email protected](params=[_asyncio_run, trio.run], ids=["asyncio", "trio"]) +def run_async_fn(request): + return request.param + + +async def closing_factory(): + async with contextlib.AsyncExitStack() as stack: + + def closing(maybe_agen): + try: + aclose = maybe_agen.aclose + except AttributeError: + pass + else: + stack.push_async_callback(aclose) + return maybe_agen + + yield closing + + @mark_dualiter("foo", lambda: range(10)) -def test_first(env_async, foo): - tmpl = env_async.from_string("{{ foo()|first }}") - out = tmpl.render(foo=foo) +def test_first(env_async, foo, run_async_fn): + async def test(): + async with closing_factory() as closing: + tmpl = env_async.from_string("{{ closing(foo())|first }}") + return await tmpl.render_async(foo=foo, closing=closing) + + out = run_async_fn(test) assert out == "0" @@ -245,18 +277,30 @@ def test_slice(env_async, items): ) -def test_custom_async_filter(env_async): +def test_unique_with_async_gen(env_async): + items = ["a", "b", "c", "c", "a", "d", "z"] + tmpl = env_async.from_string("{{ items|reject('==', 'z')|unique|list }}") + out = tmpl.render(items=items) + assert out == "['a', 'b', 'c', 'd']" + + +def test_custom_async_filter(env_async, run_async_fn): async def customfilter(val): return str(val) - env_async.filters["customfilter"] = customfilter - tmpl = env_async.from_string("{{ 'static'|customfilter }} {{ 
arg|customfilter }}") - out = tmpl.render(arg="dynamic") + async def test(): + env_async.filters["customfilter"] = customfilter + tmpl = env_async.from_string( + "{{ 'static'|customfilter }} {{ arg|customfilter }}" + ) + return await tmpl.render_async(arg="dynamic") + + out = run_async_fn(test) assert out == "static dynamic" @mark_dualiter("items", lambda: range(10)) -def test_custom_async_iteratable_filter(env_async, items): +def test_custom_async_iteratable_filter(env_async, items, run_async_fn): async def customfilter(iterable): items = [] async for item in auto_aiter(iterable): @@ -265,9 +309,13 @@ def test_custom_async_iteratable_filter(env_async, items): break return ",".join(items) - env_async.filters["customfilter"] = customfilter - tmpl = env_async.from_string( - "{{ items()|customfilter }} .. {{ [3, 4, 5, 6]|customfilter }}" - ) - out = tmpl.render(items=items) + async def test(): + async with closing_factory() as closing: + env_async.filters["customfilter"] = customfilter + tmpl = env_async.from_string( + "{{ closing(items())|customfilter }} .. {{ [3, 4, 5, 6]|customfilter }}" + ) + return await tmpl.render_async(items=items, closing=closing) + + out = run_async_fn(test) assert out == "0,1,2 .. 
3,4,5" diff --git a/contrib/python/Jinja2/py3/tests/test_compile.py b/contrib/python/Jinja2/py3/tests/test_compile.py index 42a773f21cc..e1a5391ea25 100644 --- a/contrib/python/Jinja2/py3/tests/test_compile.py +++ b/contrib/python/Jinja2/py3/tests/test_compile.py @@ -1,6 +1,9 @@ import os import re +import pytest + +from jinja2 import UndefinedError from jinja2.environment import Environment from jinja2.loaders import DictLoader @@ -26,3 +29,80 @@ def test_import_as_with_context_deterministic(tmp_path): expect = [f"'bar{i}': " for i in range(10)] found = re.findall(r"'bar\d': ", content)[:10] assert found == expect + + +def test_top_level_set_vars_unpacking_deterministic(tmp_path): + src = "\n".join(f"{{% set a{i}, b{i}, c{i} = tuple_var{i} %}}" for i in range(10)) + env = Environment(loader=DictLoader({"foo": src})) + env.compile_templates(tmp_path, zip=None) + name = os.listdir(tmp_path)[0] + content = (tmp_path / name).read_text("utf8") + expect = [ + f"context.vars.update({{'a{i}': l_0_a{i}, 'b{i}': l_0_b{i}, 'c{i}': l_0_c{i}}})" + for i in range(10) + ] + found = re.findall( + r"context\.vars\.update\(\{'a\d': l_0_a\d, 'b\d': l_0_b\d, 'c\d': l_0_c\d\}\)", + content, + )[:10] + assert found == expect + expect = [ + f"context.exported_vars.update(('a{i}', 'b{i}', 'c{i}'))" for i in range(10) + ] + found = re.findall( + r"context\.exported_vars\.update\(\('a\d', 'b\d', 'c\d'\)\)", + content, + )[:10] + assert found == expect + + +def test_loop_set_vars_unpacking_deterministic(tmp_path): + src = "\n".join(f" {{% set a{i}, b{i}, c{i} = tuple_var{i} %}}" for i in range(10)) + src = f"{{% for i in seq %}}\n{src}\n{{% endfor %}}" + env = Environment(loader=DictLoader({"foo": src})) + env.compile_templates(tmp_path, zip=None) + name = os.listdir(tmp_path)[0] + content = (tmp_path / name).read_text("utf8") + expect = [ + f"_loop_vars.update({{'a{i}': l_1_a{i}, 'b{i}': l_1_b{i}, 'c{i}': l_1_c{i}}})" + for i in range(10) + ] + found = re.findall( + 
r"_loop_vars\.update\(\{'a\d': l_1_a\d, 'b\d': l_1_b\d, 'c\d': l_1_c\d\}\)", + content, + )[:10] + assert found == expect + + +def test_block_set_vars_unpacking_deterministic(tmp_path): + src = "\n".join(f" {{% set a{i}, b{i}, c{i} = tuple_var{i} %}}" for i in range(10)) + src = f"{{% block test %}}\n{src}\n{{% endblock test %}}" + env = Environment(loader=DictLoader({"foo": src})) + env.compile_templates(tmp_path, zip=None) + name = os.listdir(tmp_path)[0] + content = (tmp_path / name).read_text("utf8") + expect = [ + f"_block_vars.update({{'a{i}': l_0_a{i}, 'b{i}': l_0_b{i}, 'c{i}': l_0_c{i}}})" + for i in range(10) + ] + found = re.findall( + r"_block_vars\.update\(\{'a\d': l_0_a\d, 'b\d': l_0_b\d, 'c\d': l_0_c\d\}\)", + content, + )[:10] + assert found == expect + + +def test_undefined_import_curly_name(): + env = Environment( + loader=DictLoader( + { + "{bad}": "{% from 'macro' import m %}{{ m() }}", + "macro": "", + } + ) + ) + + # Must not raise `NameError: 'bad' is not defined`, as that would indicate + # that `{bad}` is being interpreted as an f-string. It must be escaped. 
+ with pytest.raises(UndefinedError): + env.get_template("{bad}").render() diff --git a/contrib/python/Jinja2/py3/tests/test_core_tags.py b/contrib/python/Jinja2/py3/tests/test_core_tags.py index 4bb95e0240a..2d847a2c9af 100644 --- a/contrib/python/Jinja2/py3/tests/test_core_tags.py +++ b/contrib/python/Jinja2/py3/tests/test_core_tags.py @@ -538,6 +538,14 @@ class TestSet: ) assert tmpl.render() == "13|37" + def test_namespace_set_tuple(self, env_trim): + tmpl = env_trim.from_string( + "{% set ns = namespace(a=12, b=36) %}" + "{% set ns.a, ns.b = ns.a + 1, ns.b + 1 %}" + "{{ ns.a }}|{{ ns.b }}" + ) + assert tmpl.render() == "13|37" + def test_block_escaping_filtered(self): env = Environment(autoescape=True) tmpl = env.from_string( diff --git a/contrib/python/Jinja2/py3/tests/test_filters.py b/contrib/python/Jinja2/py3/tests/test_filters.py index d8e9114d0f6..2cb53ac9d07 100644 --- a/contrib/python/Jinja2/py3/tests/test_filters.py +++ b/contrib/python/Jinja2/py3/tests/test_filters.py @@ -196,6 +196,7 @@ class TestFilter: ("abc", "0"), ("32.32", "32"), ("12345678901234567890", "12345678901234567890"), + ("1e10000", "0"), ), ) def test_int(self, env, value, expect): diff --git a/contrib/python/Jinja2/py3/tests/test_loader.py b/contrib/python/Jinja2/py3/tests/test_loader.py index d3b4ddf1ba7..0cdcf94bb07 100644 --- a/contrib/python/Jinja2/py3/tests/test_loader.py +++ b/contrib/python/Jinja2/py3/tests/test_loader.py @@ -2,7 +2,6 @@ import importlib.abc import importlib.machinery import importlib.util import os -import platform import shutil import sys import tempfile @@ -183,6 +182,24 @@ class TestFileSystemLoader: t = e.get_template("foo/test.html") assert t.filename == str(self.searchpath / "foo" / "test.html") + def test_error_includes_paths(self, env, filesystem_loader): + env.loader = filesystem_loader + + with pytest.raises(TemplateNotFound) as info: + env.get_template("missing") + + e_str = str(info.value) + assert e_str.startswith("'missing' not found in search 
path: ") + + filesystem_loader.searchpath.append("other") + + with pytest.raises(TemplateNotFound) as info: + env.get_template("missing") + + e_str = str(info.value) + assert e_str.startswith("'missing' not found in search paths: ") + assert ", 'other'" in e_str + class TestModuleLoader: archive = None @@ -367,8 +384,8 @@ def test_package_zip_source(package_zip_loader, template, expect): @pytest.mark.xfail( - platform.python_implementation() == "PyPy", - reason="PyPy's zipimporter doesn't have a '_files' attribute.", + sys.implementation.name == "pypy", + reason="zipimporter doesn't have a '_files' attribute", raises=TypeError, ) def test_package_zip_list(package_zip_loader): @@ -415,3 +432,8 @@ def test_pep_451_import_hook(): assert "test.html" in package_loader.list_templates() finally: sys.meta_path[:] = before + + +def test_package_loader_no_dir() -> None: + with pytest.raises(ValueError, match="could not find a 'templates' directory"): + PackageLoader("jinja2", check_templates=True) diff --git a/contrib/python/Jinja2/py3/tests/test_nativetypes.py b/contrib/python/Jinja2/py3/tests/test_nativetypes.py index 8c85252518d..13690818079 100644 --- a/contrib/python/Jinja2/py3/tests/test_nativetypes.py +++ b/contrib/python/Jinja2/py3/tests/test_nativetypes.py @@ -160,3 +160,13 @@ def test_macro(env): result = t.render() assert result == 2 assert isinstance(result, int) + + +def test_block(env): + t = env.from_string( + "{% block b %}{% for i in range(1) %}{{ loop.index }}{% endfor %}" + "{% endblock %}{{ self.b() }}" + ) + result = t.render() + assert result == 11 + assert isinstance(result, int) diff --git a/contrib/python/Jinja2/py3/tests/test_regression.py b/contrib/python/Jinja2/py3/tests/test_regression.py index 7bd4d15649a..93d72c5e6f8 100644 --- a/contrib/python/Jinja2/py3/tests/test_regression.py +++ b/contrib/python/Jinja2/py3/tests/test_regression.py @@ -737,6 +737,28 @@ End""" ) assert tmpl.render() == "hellohellohello" + def 
test_pass_context_with_select(self, env): + @pass_context + def is_foo(ctx, s): + assert ctx is not None + return s == "foo" + + env.tests["foo"] = is_foo + tmpl = env.from_string( + "{% for x in ['one', 'foo'] | select('foo') %}{{ x }}{% endfor %}" + ) + assert tmpl.render() == "foo" + + +def test_load_parameter_when_set_in_all_if_branches(env): + tmpl = env.from_string( + "{% if True %}{{ a.b }}{% set a = 1 %}" + "{% elif False %}{% set a = 2 %}" + "{% else %}{% set a = 3 %}{% endif %}" + "{{ a }}" + ) + assert tmpl.render(a={"b": 0}) == "01" + @pytest.mark.parametrize("unicode_char", ["\N{FORM FEED}", "\x85"]) def test_unicode_whitespace(env, unicode_char): diff --git a/contrib/python/Jinja2/py3/tests/test_runtime.py b/contrib/python/Jinja2/py3/tests/test_runtime.py index 1978c64104a..3cd3be15fb6 100644 --- a/contrib/python/Jinja2/py3/tests/test_runtime.py +++ b/contrib/python/Jinja2/py3/tests/test_runtime.py @@ -1,6 +1,15 @@ +import copy import itertools +import pickle +import pytest + +from jinja2 import ChainableUndefined +from jinja2 import DebugUndefined +from jinja2 import StrictUndefined from jinja2 import Template +from jinja2 import TemplateRuntimeError +from jinja2 import Undefined from jinja2.runtime import LoopContext TEST_IDX_TEMPLATE_STR_1 = ( @@ -73,3 +82,44 @@ def test_mock_not_pass_arg_marker(): out = t.render(calc=Calc()) # Would be "1" if context argument was passed. 
assert out == "0" + + +_undefined_types = (Undefined, ChainableUndefined, DebugUndefined, StrictUndefined) + + [email protected]("undefined_type", _undefined_types) +def test_undefined_copy(undefined_type): + undef = undefined_type("a hint", ["foo"], "a name", TemplateRuntimeError) + copied = copy.copy(undef) + + assert copied is not undef + assert copied._undefined_hint is undef._undefined_hint + assert copied._undefined_obj is undef._undefined_obj + assert copied._undefined_name is undef._undefined_name + assert copied._undefined_exception is undef._undefined_exception + + [email protected]("undefined_type", _undefined_types) +def test_undefined_deepcopy(undefined_type): + undef = undefined_type("a hint", ["foo"], "a name", TemplateRuntimeError) + copied = copy.deepcopy(undef) + + assert copied._undefined_hint is undef._undefined_hint + assert copied._undefined_obj is not undef._undefined_obj + assert copied._undefined_obj == undef._undefined_obj + assert copied._undefined_name is undef._undefined_name + assert copied._undefined_exception is undef._undefined_exception + + [email protected]("undefined_type", _undefined_types) +def test_undefined_pickle(undefined_type): + undef = undefined_type("a hint", ["foo"], "a name", TemplateRuntimeError) + copied = pickle.loads(pickle.dumps(undef)) + + assert copied._undefined_hint is not undef._undefined_hint + assert copied._undefined_hint == undef._undefined_hint + assert copied._undefined_obj is not undef._undefined_obj + assert copied._undefined_obj == undef._undefined_obj + assert copied._undefined_name is not undef._undefined_name + assert copied._undefined_name == undef._undefined_name + assert copied._undefined_exception is undef._undefined_exception diff --git a/contrib/python/Jinja2/py3/tests/test_security.py b/contrib/python/Jinja2/py3/tests/test_security.py index 0e8dc5c0385..3a13781926e 100644 --- a/contrib/python/Jinja2/py3/tests/test_security.py +++ b/contrib/python/Jinja2/py3/tests/test_security.py @@ -58,6 
+58,8 @@ class TestSandbox: def test_immutable_environment(self, env): env = ImmutableSandboxedEnvironment() pytest.raises(SecurityError, env.from_string("{{ [].append(23) }}").render) + pytest.raises(SecurityError, env.from_string("{{ [].clear() }}").render) + pytest.raises(SecurityError, env.from_string("{{ [1].pop() }}").render) pytest.raises(SecurityError, env.from_string("{{ {1:2}.clear() }}").render) def test_restricted(self, env): @@ -171,3 +173,30 @@ class TestStringFormatMap: '{{ ("a{x.foo}b{y}"|safe).format_map({"x":{"foo": 42}, "y":"<foo>"}) }}' ) assert t.render() == "a42b<foo>" + + def test_indirect_call(self): + def run(value, arg): + return value.run(arg) + + env = SandboxedEnvironment() + env.filters["run"] = run + t = env.from_string( + """{% set + ns = namespace(run="{0.__call__.__builtins__[__import__]}".format) + %} + {{ ns | run(not_here) }} + """ + ) + + with pytest.raises(SecurityError): + t.render() + + def test_attr_filter(self) -> None: + env = SandboxedEnvironment() + t = env.from_string( + """{{ "{0.__call__.__builtins__[__import__]}" + | attr("format")(not_here) }}""" + ) + + with pytest.raises(SecurityError): + t.render() diff --git a/contrib/python/Jinja2/py3/tests/test_utils.py b/contrib/python/Jinja2/py3/tests/test_utils.py index feaf8dc1d41..6f42eab5e24 100644 --- a/contrib/python/Jinja2/py3/tests/test_utils.py +++ b/contrib/python/Jinja2/py3/tests/test_utils.py @@ -1,3 +1,4 @@ +import copy import pickle import random from collections import deque @@ -141,6 +142,14 @@ class TestEscapeUrlizeTarget: "http://example.org</a>" ) + def test_urlize_mail_mastodon(self): + fr = "[email protected]\n@[email protected]\n" + to = ( + '<a href="mailto:[email protected]">' + "[email protected]</a>\n@[email protected]\n" + ) + assert urlize(fr) == to + class TestLoremIpsum: def test_lorem_ipsum_markup(self): @@ -183,3 +192,14 @@ def test_consume(): consume(x) with pytest.raises(StopIteration): next(x) + + [email protected]("protocol", 
range(pickle.HIGHEST_PROTOCOL + 1)) +def test_pickle_missing(protocol: int) -> None: + """Test that missing can be pickled while remaining a singleton.""" + assert pickle.loads(pickle.dumps(missing, protocol)) is missing + + +def test_copy_missing() -> None: + """Test that missing can be copied while remaining a singleton.""" + assert copy.copy(missing) is missing diff --git a/contrib/python/Jinja2/py3/tests/ya.make b/contrib/python/Jinja2/py3/tests/ya.make index c9f59c881f4..7ae1b5e1b9c 100644 --- a/contrib/python/Jinja2/py3/tests/ya.make +++ b/contrib/python/Jinja2/py3/tests/ya.make @@ -2,6 +2,11 @@ PY3TEST() PEERDIR( contrib/python/Jinja2 + contrib/python/trio +) + +DATA( + arcadia/contrib/python/Jinja2/py3/tests ) PY_SRCS( @@ -9,10 +14,6 @@ PY_SRCS( res/__init__.py ) -DATA( - arcadia/contrib/python/Jinja2/py3/tests/res -) - RESOURCE_FILES( PREFIX contrib/python/Jinja2/py3/tests/ res/templates/broken.html @@ -23,31 +24,7 @@ RESOURCE_FILES( res/templates2/foo ) -TEST_SRCS( - conftest.py - test_api.py - test_async.py - test_async_filters.py - test_bytecode_cache.py - test_compile.py - test_core_tags.py - test_debug.py - test_ext.py - test_filters.py - test_idtracking.py - test_imports.py - test_inheritance.py - test_lexnparse.py - test_loader.py - test_nativetypes.py - test_nodes.py - test_pickle.py - test_regression.py - test_runtime.py - test_security.py - test_tests.py - test_utils.py -) +ALL_PYTEST_SRCS() NO_LINT() diff --git a/contrib/python/Jinja2/py3/ya.make b/contrib/python/Jinja2/py3/ya.make index f51dfbfcce8..67a1f90b4f2 100644 --- a/contrib/python/Jinja2/py3/ya.make +++ b/contrib/python/Jinja2/py3/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(3.1.4) +VERSION(3.1.6) LICENSE(BSD-3-Clause) diff --git a/contrib/python/outcome/.dist-info/METADATA b/contrib/python/outcome/.dist-info/METADATA new file mode 100644 index 00000000000..2a8636fddcf --- /dev/null +++ b/contrib/python/outcome/.dist-info/METADATA @@ -0,0 +1,63 @@ +Metadata-Version: 2.1 +Name: outcome 
+Version: 1.3.0.post0 +Summary: Capture the outcome of Python function calls. +Home-page: https://github.com/python-trio/outcome +Author: Frazer McLean +Author-email: [email protected] +License: MIT OR Apache-2.0 +Project-URL: Documentation, https://outcome.readthedocs.io/en/latest/ +Project-URL: Chat, https://gitter.im/python-trio/general +Project-URL: Changelog, https://outcome.readthedocs.io/en/latest/history.html +Keywords: result +Classifier: Development Status :: 5 - Production/Stable +Classifier: Framework :: Trio +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: POSIX :: Linux +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: Microsoft :: Windows +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Typing :: Typed +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE +License-File: LICENSE.APACHE2 +License-File: LICENSE.MIT +Requires-Dist: attrs >=19.2.0 + +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://outcome.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +.. image:: https://travis-ci.org/python-trio/trio.svg?branch=master + :target: https://travis-ci.org/python-trio/outcome + :alt: Automated test status (Linux and MacOS) + +.. 
image:: https://ci.appveyor.com/api/projects/status/c54uu4rxlgs2usmj/branch/master?svg=true + :target: https://ci.appveyor.com/project/RazerM/outcome/history + :alt: Automated test status (Windows) + +.. image:: https://codecov.io/gh/python-trio/trio/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-trio/outcome + :alt: Test coverage + +outcome +======= + +Welcome to `outcome <https://github.com/python-trio/outcome>`__! + +Capture the outcome of Python function calls. Extracted from the +`Trio <https://github.com/python-trio/trio>`__ project. + +License: Your choice of MIT or Apache License 2.0 diff --git a/contrib/python/outcome/.dist-info/top_level.txt b/contrib/python/outcome/.dist-info/top_level.txt new file mode 100644 index 00000000000..e6d2d58a1eb --- /dev/null +++ b/contrib/python/outcome/.dist-info/top_level.txt @@ -0,0 +1 @@ +outcome diff --git a/contrib/python/outcome/LICENSE b/contrib/python/outcome/LICENSE new file mode 100644 index 00000000000..51f34429178 --- /dev/null +++ b/contrib/python/outcome/LICENSE @@ -0,0 +1,3 @@ +This software is made available under the terms of *either* of the +licenses found in LICENSE.APACHE2 or LICENSE.MIT. Contributions to are +made under the terms of *both* these licenses. diff --git a/contrib/python/outcome/LICENSE.APACHE2 b/contrib/python/outcome/LICENSE.APACHE2 new file mode 100644 index 00000000000..d6456956733 --- /dev/null +++ b/contrib/python/outcome/LICENSE.APACHE2 @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrib/python/outcome/LICENSE.MIT b/contrib/python/outcome/LICENSE.MIT new file mode 100644 index 00000000000..b8bb9718592 --- /dev/null +++ b/contrib/python/outcome/LICENSE.MIT @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/contrib/python/outcome/README.rst b/contrib/python/outcome/README.rst new file mode 100644 index 00000000000..d3a27dc41aa --- /dev/null +++ b/contrib/python/outcome/README.rst @@ -0,0 +1,29 @@ +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://outcome.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +.. image:: https://travis-ci.org/python-trio/trio.svg?branch=master + :target: https://travis-ci.org/python-trio/outcome + :alt: Automated test status (Linux and MacOS) + +.. image:: https://ci.appveyor.com/api/projects/status/c54uu4rxlgs2usmj/branch/master?svg=true + :target: https://ci.appveyor.com/project/RazerM/outcome/history + :alt: Automated test status (Windows) + +.. image:: https://codecov.io/gh/python-trio/trio/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-trio/outcome + :alt: Test coverage + +outcome +======= + +Welcome to `outcome <https://github.com/python-trio/outcome>`__! + +Capture the outcome of Python function calls. Extracted from the +`Trio <https://github.com/python-trio/trio>`__ project. 
+ +License: Your choice of MIT or Apache License 2.0 diff --git a/contrib/python/outcome/outcome/__init__.py b/contrib/python/outcome/outcome/__init__.py new file mode 100644 index 00000000000..9e6b453710e --- /dev/null +++ b/contrib/python/outcome/outcome/__init__.py @@ -0,0 +1,20 @@ +"""Top-level package for outcome.""" + +from ._impl import ( + Error as Error, + Maybe as Maybe, + Outcome as Outcome, + Value as Value, + acapture as acapture, + capture as capture, +) +from ._util import AlreadyUsedError as AlreadyUsedError, fixup_module_metadata +from ._version import __version__ as __version__ + +__all__ = ( + 'Error', 'Outcome', 'Value', 'Maybe', 'acapture', 'capture', + 'AlreadyUsedError' +) + +fixup_module_metadata(__name__, globals()) +del fixup_module_metadata diff --git a/contrib/python/outcome/outcome/_impl.py b/contrib/python/outcome/outcome/_impl.py new file mode 100644 index 00000000000..004b72dad21 --- /dev/null +++ b/contrib/python/outcome/outcome/_impl.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import abc +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + Awaitable, + Callable, + Generator, + Generic, + NoReturn, + TypeVar, + Union, + overload, +) + +import attr + +from ._util import AlreadyUsedError, remove_tb_frames + +if TYPE_CHECKING: + from typing_extensions import ParamSpec, final + ArgsT = ParamSpec("ArgsT") +else: + + def final(func): + return func + + +__all__ = ['Error', 'Outcome', 'Maybe', 'Value', 'acapture', 'capture'] + +ValueT = TypeVar("ValueT", covariant=True) +ResultT = TypeVar("ResultT") + + +@overload +def capture( + # NoReturn = raises exception, so we should get an error. + sync_fn: Callable[ArgsT, NoReturn], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, +) -> Error: + ... + + +@overload +def capture( + sync_fn: Callable[ArgsT, ResultT], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, +) -> Value[ResultT] | Error: + ... 
+ + +def capture( + sync_fn: Callable[ArgsT, ResultT], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, +) -> Value[ResultT] | Error: + """Run ``sync_fn(*args, **kwargs)`` and capture the result. + + Returns: + Either a :class:`Value` or :class:`Error` as appropriate. + + """ + try: + return Value(sync_fn(*args, **kwargs)) + except BaseException as exc: + exc = remove_tb_frames(exc, 1) + return Error(exc) + + +@overload +async def acapture( + async_fn: Callable[ArgsT, Awaitable[NoReturn]], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, +) -> Error: + ... + + +@overload +async def acapture( + async_fn: Callable[ArgsT, Awaitable[ResultT]], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, +) -> Value[ResultT] | Error: + ... + + +async def acapture( + async_fn: Callable[ArgsT, Awaitable[ResultT]], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, +) -> Value[ResultT] | Error: + """Run ``await async_fn(*args, **kwargs)`` and capture the result. + + Returns: + Either a :class:`Value` or :class:`Error` as appropriate. + + """ + try: + return Value(await async_fn(*args, **kwargs)) + except BaseException as exc: + exc = remove_tb_frames(exc, 1) + return Error(exc) + + [email protected](repr=False, init=False, slots=True) +class Outcome(abc.ABC, Generic[ValueT]): + """An abstract class representing the result of a Python computation. + + This class has two concrete subclasses: :class:`Value` representing a + value, and :class:`Error` representing an exception. + + In addition to the methods described below, comparison operators on + :class:`Value` and :class:`Error` objects (``==``, ``<``, etc.) check that + the other object is also a :class:`Value` or :class:`Error` object + respectively, and then compare the contained objects. + + :class:`Outcome` objects are hashable if the contained objects are + hashable. 
+ + """ + _unwrapped: bool = attr.ib(default=False, eq=False, init=False) + + def _set_unwrapped(self) -> None: + if self._unwrapped: + raise AlreadyUsedError + object.__setattr__(self, '_unwrapped', True) + + @abc.abstractmethod + def unwrap(self) -> ValueT: + """Return or raise the contained value or exception. + + These two lines of code are equivalent:: + + x = fn(*args) + x = outcome.capture(fn, *args).unwrap() + + """ + + @abc.abstractmethod + def send(self, gen: Generator[ResultT, ValueT, object]) -> ResultT: + """Send or throw the contained value or exception into the given + generator object. + + Args: + gen: A generator object supporting ``.send()`` and ``.throw()`` + methods. + + """ + + @abc.abstractmethod + async def asend(self, agen: AsyncGenerator[ResultT, ValueT]) -> ResultT: + """Send or throw the contained value or exception into the given async + generator object. + + Args: + agen: An async generator object supporting ``.asend()`` and + ``.athrow()`` methods. + + """ + + +@final [email protected](frozen=True, repr=False, slots=True) +class Value(Outcome[ValueT], Generic[ValueT]): + """Concrete :class:`Outcome` subclass representing a regular value. + + """ + + value: ValueT = attr.ib() + """The contained value.""" + + def __repr__(self) -> str: + return f'Value({self.value!r})' + + def unwrap(self) -> ValueT: + self._set_unwrapped() + return self.value + + def send(self, gen: Generator[ResultT, ValueT, object]) -> ResultT: + self._set_unwrapped() + return gen.send(self.value) + + async def asend(self, agen: AsyncGenerator[ResultT, ValueT]) -> ResultT: + self._set_unwrapped() + return await agen.asend(self.value) + + +@final [email protected](frozen=True, repr=False, slots=True) +class Error(Outcome[NoReturn]): + """Concrete :class:`Outcome` subclass representing a raised exception. 
+ + """ + + error: BaseException = attr.ib( + validator=attr.validators.instance_of(BaseException) + ) + """The contained exception object.""" + + def __repr__(self) -> str: + return f'Error({self.error!r})' + + def unwrap(self) -> NoReturn: + self._set_unwrapped() + # Tracebacks show the 'raise' line below out of context, so let's give + # this variable a name that makes sense out of context. + captured_error = self.error + try: + raise captured_error + finally: + # We want to avoid creating a reference cycle here. Python does + # collect cycles just fine, so it wouldn't be the end of the world + # if we did create a cycle, but the cyclic garbage collector adds + # latency to Python programs, and the more cycles you create, the + # more often it runs, so it's nicer to avoid creating them in the + # first place. For more details see: + # + # https://github.com/python-trio/trio/issues/1770 + # + # In particuar, by deleting this local variables from the 'unwrap' + # methods frame, we avoid the 'captured_error' object's + # __traceback__ from indirectly referencing 'captured_error'. + del captured_error, self + + def send(self, gen: Generator[ResultT, NoReturn, object]) -> ResultT: + self._set_unwrapped() + return gen.throw(self.error) + + async def asend(self, agen: AsyncGenerator[ResultT, NoReturn]) -> ResultT: + self._set_unwrapped() + return await agen.athrow(self.error) + + +# A convenience alias to a union of both results, allowing exhaustiveness checking. 
+Maybe = Union[Value[ValueT], Error] diff --git a/contrib/python/outcome/outcome/_util.py b/contrib/python/outcome/outcome/_util.py new file mode 100644 index 00000000000..98ad50f98ce --- /dev/null +++ b/contrib/python/outcome/outcome/_util.py @@ -0,0 +1,33 @@ +from typing import Any, Dict + + +class AlreadyUsedError(RuntimeError): + """An Outcome can only be unwrapped once.""" + pass + + +def fixup_module_metadata( + module_name: str, + namespace: Dict[str, object], +) -> None: + def fix_one(obj: object) -> None: + mod = getattr(obj, "__module__", None) + if mod is not None and mod.startswith("outcome."): + obj.__module__ = module_name + if isinstance(obj, type): + for attr_value in obj.__dict__.values(): + fix_one(attr_value) + + all_list = namespace["__all__"] + assert isinstance(all_list, (tuple, list)), repr(all_list) + for objname in all_list: + obj = namespace[objname] + fix_one(obj) + + +def remove_tb_frames(exc: BaseException, n: int) -> BaseException: + tb = exc.__traceback__ + for _ in range(n): + assert tb is not None + tb = tb.tb_next + return exc.with_traceback(tb) diff --git a/contrib/python/outcome/outcome/_version.py b/contrib/python/outcome/outcome/_version.py new file mode 100644 index 00000000000..5bd0d9c6bed --- /dev/null +++ b/contrib/python/outcome/outcome/_version.py @@ -0,0 +1,7 @@ +# This file is imported from __init__.py and exec'd from setup.py +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Final + +__version__: 'Final[str]' = "1.3.0.post0" diff --git a/contrib/python/outcome/outcome/py.typed b/contrib/python/outcome/outcome/py.typed new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/contrib/python/outcome/outcome/py.typed diff --git a/contrib/python/outcome/ya.make b/contrib/python/outcome/ya.make new file mode 100644 index 00000000000..1677b7b74dd --- /dev/null +++ b/contrib/python/outcome/ya.make @@ -0,0 +1,30 @@ +# Generated by devtools/yamaker (pypi). 
+ +PY3_LIBRARY() + +VERSION(1.3.0.post0) + +LICENSE(Apache-2.0 AND BSD-3-Clause AND MIT AND "MIT OR Apache-2.0") + +PEERDIR( + contrib/python/attrs +) + +NO_LINT() + +PY_SRCS( + TOP_LEVEL + outcome/__init__.py + outcome/_impl.py + outcome/_util.py + outcome/_version.py +) + +RESOURCE_FILES( + PREFIX contrib/python/outcome/ + .dist-info/METADATA + .dist-info/top_level.txt + outcome/py.typed +) + +END() diff --git a/contrib/python/sniffio/.dist-info/METADATA b/contrib/python/sniffio/.dist-info/METADATA new file mode 100644 index 00000000000..88968aed169 --- /dev/null +++ b/contrib/python/sniffio/.dist-info/METADATA @@ -0,0 +1,104 @@ +Metadata-Version: 2.1 +Name: sniffio +Version: 1.3.1 +Summary: Sniff out which async library your code is running under +Author-email: "Nathaniel J. Smith" <[email protected]> +License: MIT OR Apache-2.0 +Project-URL: Homepage, https://github.com/python-trio/sniffio +Project-URL: Documentation, https://sniffio.readthedocs.io/ +Project-URL: Changelog, https://sniffio.readthedocs.io/en/latest/history.html +Keywords: async,trio,asyncio +Classifier: License :: OSI Approved :: MIT License +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Framework :: Trio +Classifier: Framework :: AsyncIO +Classifier: Operating System :: POSIX :: Linux +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: Microsoft :: Windows +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Intended Audience :: Developers +Classifier: Development Status :: 5 - Production/Stable +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE +License-File: LICENSE.APACHE2 +License-File: LICENSE.MIT + +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. 
image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://sniffio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +.. image:: https://img.shields.io/pypi/v/sniffio.svg + :target: https://pypi.org/project/sniffio + :alt: Latest PyPi version + +.. image:: https://img.shields.io/conda/vn/conda-forge/sniffio.svg + :target: https://anaconda.org/conda-forge/sniffio + :alt: Latest conda-forge version + +.. image:: https://travis-ci.org/python-trio/sniffio.svg?branch=master + :target: https://travis-ci.org/python-trio/sniffio + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-trio/sniffio/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-trio/sniffio + :alt: Test coverage + +================================================================= +sniffio: Sniff out which async library your code is running under +================================================================= + +You're writing a library. You've decided to be ambitious, and support +multiple async I/O packages, like `Trio +<https://trio.readthedocs.io>`__, and `asyncio +<https://docs.python.org/3/library/asyncio.html>`__, and ... You've +written a bunch of clever code to handle all the differences. But... +how do you know *which* piece of clever code to run? + +This is a tiny package whose only purpose is to let you detect which +async library your code is running under. + +* Documentation: https://sniffio.readthedocs.io + +* Bug tracker and source code: https://github.com/python-trio/sniffio + +* License: MIT or Apache License 2.0, your choice + +* Contributor guide: https://trio.readthedocs.io/en/latest/contributing.html + +* Code of conduct: Contributors are requested to follow our `code of + conduct + <https://trio.readthedocs.io/en/latest/code-of-conduct.html>`_ + in all project spaces. + +This library is maintained by the Trio project, as a service to the +async Python community as a whole. 
+ + +Quickstart +---------- + +.. code-block:: python3 + + from sniffio import current_async_library + import trio + import asyncio + + async def print_library(): + library = current_async_library() + print("This is:", library) + + # Prints "This is trio" + trio.run(print_library) + + # Prints "This is asyncio" + asyncio.run(print_library()) + +For more details, including how to add support to new async libraries, +`please peruse our fine manual <https://sniffio.readthedocs.io>`__. diff --git a/contrib/python/sniffio/.dist-info/top_level.txt b/contrib/python/sniffio/.dist-info/top_level.txt new file mode 100644 index 00000000000..01c650244d0 --- /dev/null +++ b/contrib/python/sniffio/.dist-info/top_level.txt @@ -0,0 +1 @@ +sniffio diff --git a/contrib/python/sniffio/.yandex_meta/yamaker.yaml b/contrib/python/sniffio/.yandex_meta/yamaker.yaml new file mode 100644 index 00000000000..1a9504527f4 --- /dev/null +++ b/contrib/python/sniffio/.yandex_meta/yamaker.yaml @@ -0,0 +1,2 @@ +mark_as_tests: + - sniffio/_tests/* diff --git a/contrib/python/sniffio/LICENSE b/contrib/python/sniffio/LICENSE new file mode 100644 index 00000000000..51f34429178 --- /dev/null +++ b/contrib/python/sniffio/LICENSE @@ -0,0 +1,3 @@ +This software is made available under the terms of *either* of the +licenses found in LICENSE.APACHE2 or LICENSE.MIT. Contributions to are +made under the terms of *both* these licenses. diff --git a/contrib/python/sniffio/LICENSE.APACHE2 b/contrib/python/sniffio/LICENSE.APACHE2 new file mode 100644 index 00000000000..d6456956733 --- /dev/null +++ b/contrib/python/sniffio/LICENSE.APACHE2 @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrib/python/sniffio/LICENSE.MIT b/contrib/python/sniffio/LICENSE.MIT new file mode 100644 index 00000000000..b8bb9718592 --- /dev/null +++ b/contrib/python/sniffio/LICENSE.MIT @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/contrib/python/sniffio/README.rst b/contrib/python/sniffio/README.rst new file mode 100644 index 00000000000..2a62cea5a0a --- /dev/null +++ b/contrib/python/sniffio/README.rst @@ -0,0 +1,76 @@ +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://sniffio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +.. image:: https://img.shields.io/pypi/v/sniffio.svg + :target: https://pypi.org/project/sniffio + :alt: Latest PyPi version + +.. image:: https://img.shields.io/conda/vn/conda-forge/sniffio.svg + :target: https://anaconda.org/conda-forge/sniffio + :alt: Latest conda-forge version + +.. image:: https://travis-ci.org/python-trio/sniffio.svg?branch=master + :target: https://travis-ci.org/python-trio/sniffio + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-trio/sniffio/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-trio/sniffio + :alt: Test coverage + +================================================================= +sniffio: Sniff out which async library your code is running under +================================================================= + +You're writing a library. 
You've decided to be ambitious, and support +multiple async I/O packages, like `Trio +<https://trio.readthedocs.io>`__, and `asyncio +<https://docs.python.org/3/library/asyncio.html>`__, and ... You've +written a bunch of clever code to handle all the differences. But... +how do you know *which* piece of clever code to run? + +This is a tiny package whose only purpose is to let you detect which +async library your code is running under. + +* Documentation: https://sniffio.readthedocs.io + +* Bug tracker and source code: https://github.com/python-trio/sniffio + +* License: MIT or Apache License 2.0, your choice + +* Contributor guide: https://trio.readthedocs.io/en/latest/contributing.html + +* Code of conduct: Contributors are requested to follow our `code of + conduct + <https://trio.readthedocs.io/en/latest/code-of-conduct.html>`_ + in all project spaces. + +This library is maintained by the Trio project, as a service to the +async Python community as a whole. + + +Quickstart +---------- + +.. code-block:: python3 + + from sniffio import current_async_library + import trio + import asyncio + + async def print_library(): + library = current_async_library() + print("This is:", library) + + # Prints "This is trio" + trio.run(print_library) + + # Prints "This is asyncio" + asyncio.run(print_library()) + +For more details, including how to add support to new async libraries, +`please peruse our fine manual <https://sniffio.readthedocs.io>`__. 
diff --git a/contrib/python/sniffio/sniffio/__init__.py b/contrib/python/sniffio/sniffio/__init__.py new file mode 100644 index 00000000000..63f2f19e409 --- /dev/null +++ b/contrib/python/sniffio/sniffio/__init__.py @@ -0,0 +1,17 @@ +"""Top-level package for sniffio.""" + +__all__ = [ + "current_async_library", + "AsyncLibraryNotFoundError", + "current_async_library_cvar", + "thread_local", +] + +from ._version import __version__ + +from ._impl import ( + current_async_library, + AsyncLibraryNotFoundError, + current_async_library_cvar, + thread_local, +) diff --git a/contrib/python/sniffio/sniffio/_impl.py b/contrib/python/sniffio/sniffio/_impl.py new file mode 100644 index 00000000000..c1a7bbf218b --- /dev/null +++ b/contrib/python/sniffio/sniffio/_impl.py @@ -0,0 +1,95 @@ +from contextvars import ContextVar +from typing import Optional +import sys +import threading + +current_async_library_cvar = ContextVar( + "current_async_library_cvar", default=None +) # type: ContextVar[Optional[str]] + + +class _ThreadLocal(threading.local): + # Since threading.local provides no explicit mechanism for setting + # a default for a value, a custom class with a class attribute is used + # instead. + name = None # type: Optional[str] + + +thread_local = _ThreadLocal() + + +class AsyncLibraryNotFoundError(RuntimeError): + pass + + +def current_async_library() -> str: + """Detect which async library is currently running. + + The following libraries are currently supported: + + ================ =========== ============================ + Library Requires Magic string + ================ =========== ============================ + **Trio** Trio v0.6+ ``"trio"`` + **Curio** - ``"curio"`` + **asyncio** ``"asyncio"`` + **Trio-asyncio** v0.8.2+ ``"trio"`` or ``"asyncio"``, + depending on current mode + ================ =========== ============================ + + Returns: + A string like ``"trio"``. 
+ + Raises: + AsyncLibraryNotFoundError: if called from synchronous context, + or if the current async library was not recognized. + + Examples: + + .. code-block:: python3 + + from sniffio import current_async_library + + async def generic_sleep(seconds): + library = current_async_library() + if library == "trio": + import trio + await trio.sleep(seconds) + elif library == "asyncio": + import asyncio + await asyncio.sleep(seconds) + # ... and so on ... + else: + raise RuntimeError(f"Unsupported library {library!r}") + + """ + value = thread_local.name + if value is not None: + return value + + value = current_async_library_cvar.get() + if value is not None: + return value + + # Need to sniff for asyncio + if "asyncio" in sys.modules: + import asyncio + try: + current_task = asyncio.current_task # type: ignore[attr-defined] + except AttributeError: + current_task = asyncio.Task.current_task # type: ignore[attr-defined] + try: + if current_task() is not None: + return "asyncio" + except RuntimeError: + pass + + # Sniff for curio (for now) + if 'curio' in sys.modules: + from curio.meta import curio_running + if curio_running(): + return 'curio' + + raise AsyncLibraryNotFoundError( + "unknown async library, or not in async context" + ) diff --git a/contrib/python/sniffio/sniffio/_version.py b/contrib/python/sniffio/sniffio/_version.py new file mode 100644 index 00000000000..0495d10545c --- /dev/null +++ b/contrib/python/sniffio/sniffio/_version.py @@ -0,0 +1,3 @@ +# This file is imported from __init__.py and exec'd from setup.py + +__version__ = "1.3.1" diff --git a/contrib/python/sniffio/sniffio/py.typed b/contrib/python/sniffio/sniffio/py.typed new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/contrib/python/sniffio/sniffio/py.typed diff --git a/contrib/python/sniffio/ya.make b/contrib/python/sniffio/ya.make new file mode 100644 index 00000000000..165b99c587e --- /dev/null +++ b/contrib/python/sniffio/ya.make @@ -0,0 +1,25 @@ +# Generated by 
devtools/yamaker (pypi). + +PY3_LIBRARY() + +VERSION(1.3.1) + +LICENSE(Apache-2.0 AND MIT) + +NO_LINT() + +PY_SRCS( + TOP_LEVEL + sniffio/__init__.py + sniffio/_impl.py + sniffio/_version.py +) + +RESOURCE_FILES( + PREFIX contrib/python/sniffio/ + .dist-info/METADATA + .dist-info/top_level.txt + sniffio/py.typed +) + +END() diff --git a/contrib/python/trio/.dist-info/METADATA b/contrib/python/trio/.dist-info/METADATA new file mode 100644 index 00000000000..b75c591ed29 --- /dev/null +++ b/contrib/python/trio/.dist-info/METADATA @@ -0,0 +1,186 @@ +Metadata-Version: 2.4 +Name: trio +Version: 0.31.0 +Summary: A friendly Python library for async concurrency and I/O +Author-email: "Nathaniel J. Smith" <[email protected]> +License-Expression: MIT OR Apache-2.0 +Project-URL: Homepage, https://github.com/python-trio/trio +Project-URL: Documentation, https://trio.readthedocs.io/ +Project-URL: Changelog, https://trio.readthedocs.io/en/latest/history.html +Keywords: async,io,networking,trio +Classifier: Development Status :: 4 - Beta +Classifier: Framework :: Trio +Classifier: Intended Audience :: Developers +Classifier: Operating System :: POSIX :: Linux +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: POSIX :: BSD +Classifier: Operating System :: Microsoft :: Windows +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Topic :: System :: Networking +Classifier: Typing :: Typed +Requires-Python: >=3.9 +Description-Content-Type: text/x-rst +License-File: LICENSE +License-File: LICENSE.APACHE2 +License-File: LICENSE.MIT 
+Requires-Dist: attrs>=23.2.0 +Requires-Dist: sortedcontainers +Requires-Dist: idna +Requires-Dist: outcome +Requires-Dist: sniffio>=1.3.0 +Requires-Dist: cffi>=1.14; os_name == "nt" and implementation_name != "pypy" +Requires-Dist: exceptiongroup; python_version < "3.11" +Dynamic: license-file + +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. image:: https://img.shields.io/badge/forum-join%20now-blue.svg + :target: https://trio.discourse.group + :alt: Join forum + +.. image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://trio.readthedocs.io + :alt: Documentation + +.. image:: https://img.shields.io/pypi/v/trio.svg + :target: https://pypi.org/project/trio + :alt: Latest PyPi version + +.. image:: https://img.shields.io/conda/vn/conda-forge/trio.svg + :target: https://anaconda.org/conda-forge/trio + :alt: Latest conda-forge version + +.. image:: https://codecov.io/gh/python-trio/trio/branch/main/graph/badge.svg + :target: https://codecov.io/gh/python-trio/trio + :alt: Test coverage + +Trio – a friendly Python library for async concurrency and I/O +============================================================== + +.. image:: https://raw.githubusercontent.com/python-trio/trio/9b0bec646a31e0d0f67b8b6ecc6939726faf3e17/logo/logo-with-background.svg + :width: 200px + :align: right + +The Trio project aims to produce a production-quality, +`permissively licensed +<https://github.com/python-trio/trio/blob/main/LICENSE>`__, +async/await-native I/O library for Python. Like all async libraries, +its main purpose is to help you write programs that do **multiple +things at the same time** with **parallelized I/O**. A web spider that +wants to fetch lots of pages in parallel, a web server that needs to +juggle lots of downloads and websocket connections simultaneously, a +process supervisor monitoring multiple subprocesses... that sort of +thing. 
Compared to other libraries, Trio attempts to distinguish +itself with an obsessive focus on **usability** and +**correctness**. Concurrency is complicated; we try to make it *easy* +to get things *right*. + +Trio was built from the ground up to take advantage of the `latest +Python features <https://www.python.org/dev/peps/pep-0492/>`__, and +draws inspiration from `many sources +<https://github.com/python-trio/trio/wiki/Reading-list>`__, in +particular Dave Beazley's `Curio <https://curio.readthedocs.io/>`__. +The resulting design is radically simpler than older competitors like +`asyncio <https://docs.python.org/3/library/asyncio.html>`__ and +`Twisted <https://twistedmatrix.com/>`__, yet just as capable. Trio is +the Python I/O library I always wanted; I find it makes building +I/O-oriented programs easier, less error-prone, and just plain more +fun. `Perhaps you'll find the same +<https://github.com/python-trio/trio/wiki/Testimonials>`__. + +Trio is a mature and well-tested project: the overall design is solid, +and the existing features are fully documented and widely used in +production. While we occasionally make minor interface adjustments, +breaking changes are rare. We encourage you to use Trio with confidence, +but if you rely on long-term API stability, consider `subscribing to +issue #1 <https://github.com/python-trio/trio/issues/1>`__ for advance +notice of any compatibility updates. + + +Where to next? +-------------- + +**I want to try it out!** Awesome! We have a `friendly tutorial +<https://trio.readthedocs.io/en/stable/tutorial.html>`__ to get you +started; no prior experience with async coding is required. 
+ +**Ugh, I don't want to read all that – show me some code!** If you're +impatient, then here's a `simple concurrency example +<https://trio.readthedocs.io/en/stable/tutorial.html#tutorial-example-tasks-intro>`__, +an `echo client +<https://trio.readthedocs.io/en/stable/tutorial.html#tutorial-echo-client-example>`__, +and an `echo server +<https://trio.readthedocs.io/en/stable/tutorial.html#tutorial-echo-server-example>`__. + +**How does Trio make programs easier to read and reason about than +competing approaches?** Trio is based on a new way of thinking that we +call "structured concurrency". The best theoretical introduction is +the article `Notes on structured concurrency, or: Go statement +considered harmful +<https://vorpus.org/blog/notes-on-structured-concurrency-or-go-statement-considered-harmful/>`__. +Or, `check out this talk at PyCon 2018 +<https://www.youtube.com/watch?v=oLkfnc_UMcE>`__ to see a +demonstration of implementing the "Happy Eyeballs" algorithm in an +older library versus Trio. + +**Cool, but will it work on my system?** Probably! As long as you have +some kind of Python 3.9-or-better (CPython or `currently maintained versions of +PyPy3 <https://doc.pypy.org/en/latest/faq.html#which-python-versions-does-pypy-implement>`__ +are both fine), and are using Linux, macOS, Windows, or FreeBSD, then Trio +will work. Other environments might work too, but those +are the ones we test on. And all of our dependencies are pure Python, +except for CFFI on Windows, which has wheels available, so +installation should be easy (no C compiler needed). + +**I tried it, but it's not working.** Sorry to hear that! 
You can try +asking for help in our `chat room +<https://gitter.im/python-trio/general>`__ or `forum +<https://trio.discourse.group>`__, `filing a bug +<https://github.com/python-trio/trio/issues/new>`__, or `posting a +question on StackOverflow +<https://stackoverflow.com/questions/ask?tags=python+python-trio>`__, +and we'll do our best to help you out. + +**Trio is awesome, and I want to help make it more awesome!** You're +the best! There's tons of work to do – filling in missing +functionality, building up an ecosystem of Trio-using libraries, +usability testing (e.g., maybe try teaching yourself or a friend to +use Trio and make a list of every error message you hit and place +where you got confused?), improving the docs, ... check out our `guide +for contributors +<https://trio.readthedocs.io/en/stable/contributing.html>`__! + +**I don't have any immediate plans to use it, but I love geeking out +about I/O library design!** That's a little weird? But let's be +honest, you'll fit in great around here. We have a `whole sub-forum +for discussing structured concurrency +<https://trio.discourse.group/c/structured-concurrency>`__ (developers +of other systems welcome!). Or check out our `discussion of design +choices +<https://trio.readthedocs.io/en/stable/design.html#user-level-api-principles>`__, +`reading list +<https://github.com/python-trio/trio/wiki/Reading-list>`__, and +`issues tagged design-discussion +<https://github.com/python-trio/trio/labels/design%20discussion>`__. + +**I want to make sure my company's lawyers won't get angry at me!** No +worries, Trio is permissively licensed under your choice of MIT or +Apache 2. See `LICENSE +<https://github.com/python-trio/trio/blob/main/LICENSE>`__ for details. + + +Code of conduct +--------------- + +Contributors are requested to follow our `code of conduct +<https://trio.readthedocs.io/en/stable/code-of-conduct.html>`__ in all +project spaces. 
diff --git a/contrib/python/trio/.dist-info/entry_points.txt b/contrib/python/trio/.dist-info/entry_points.txt new file mode 100644 index 00000000000..6563bece5ed --- /dev/null +++ b/contrib/python/trio/.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[hypothesis] +trio = trio._core._run:_hypothesis_plugin_setup diff --git a/contrib/python/trio/.dist-info/top_level.txt b/contrib/python/trio/.dist-info/top_level.txt new file mode 100644 index 00000000000..ae0d704f07c --- /dev/null +++ b/contrib/python/trio/.dist-info/top_level.txt @@ -0,0 +1 @@ +trio diff --git a/contrib/python/trio/.yandex_meta/yamaker.yaml b/contrib/python/trio/.yandex_meta/yamaker.yaml new file mode 100644 index 00000000000..2c838790a4e --- /dev/null +++ b/contrib/python/trio/.yandex_meta/yamaker.yaml @@ -0,0 +1,3 @@ +mark_as_tests: + - trio/_tests/* + - trio/**/_tests/* diff --git a/contrib/python/trio/LICENSE b/contrib/python/trio/LICENSE new file mode 100644 index 00000000000..b79c96408a4 --- /dev/null +++ b/contrib/python/trio/LICENSE @@ -0,0 +1,3 @@ +This software is made available under the terms of *either* of the +licenses found in LICENSE.APACHE2 or LICENSE.MIT. Contributions to +Trio are made under the terms of *both* these licenses. diff --git a/contrib/python/trio/LICENSE.APACHE2 b/contrib/python/trio/LICENSE.APACHE2 new file mode 100644 index 00000000000..d6456956733 --- /dev/null +++ b/contrib/python/trio/LICENSE.APACHE2 @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrib/python/trio/LICENSE.MIT b/contrib/python/trio/LICENSE.MIT new file mode 100644 index 00000000000..c26b9f32ae5 --- /dev/null +++ b/contrib/python/trio/LICENSE.MIT @@ -0,0 +1,22 @@ +Copyright Contributors to the Trio project. 
+ +The MIT License (MIT) + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/contrib/python/trio/README.rst b/contrib/python/trio/README.rst new file mode 100644 index 00000000000..f62c2a37af1 --- /dev/null +++ b/contrib/python/trio/README.rst @@ -0,0 +1,145 @@ +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. image:: https://img.shields.io/badge/forum-join%20now-blue.svg + :target: https://trio.discourse.group + :alt: Join forum + +.. image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://trio.readthedocs.io + :alt: Documentation + +.. image:: https://img.shields.io/pypi/v/trio.svg + :target: https://pypi.org/project/trio + :alt: Latest PyPi version + +.. image:: https://img.shields.io/conda/vn/conda-forge/trio.svg + :target: https://anaconda.org/conda-forge/trio + :alt: Latest conda-forge version + +.. 
image:: https://codecov.io/gh/python-trio/trio/branch/main/graph/badge.svg + :target: https://codecov.io/gh/python-trio/trio + :alt: Test coverage + +Trio – a friendly Python library for async concurrency and I/O +============================================================== + +.. image:: https://raw.githubusercontent.com/python-trio/trio/9b0bec646a31e0d0f67b8b6ecc6939726faf3e17/logo/logo-with-background.svg + :width: 200px + :align: right + +The Trio project aims to produce a production-quality, +`permissively licensed +<https://github.com/python-trio/trio/blob/main/LICENSE>`__, +async/await-native I/O library for Python. Like all async libraries, +its main purpose is to help you write programs that do **multiple +things at the same time** with **parallelized I/O**. A web spider that +wants to fetch lots of pages in parallel, a web server that needs to +juggle lots of downloads and websocket connections simultaneously, a +process supervisor monitoring multiple subprocesses... that sort of +thing. Compared to other libraries, Trio attempts to distinguish +itself with an obsessive focus on **usability** and +**correctness**. Concurrency is complicated; we try to make it *easy* +to get things *right*. + +Trio was built from the ground up to take advantage of the `latest +Python features <https://www.python.org/dev/peps/pep-0492/>`__, and +draws inspiration from `many sources +<https://github.com/python-trio/trio/wiki/Reading-list>`__, in +particular Dave Beazley's `Curio <https://curio.readthedocs.io/>`__. +The resulting design is radically simpler than older competitors like +`asyncio <https://docs.python.org/3/library/asyncio.html>`__ and +`Twisted <https://twistedmatrix.com/>`__, yet just as capable. Trio is +the Python I/O library I always wanted; I find it makes building +I/O-oriented programs easier, less error-prone, and just plain more +fun. `Perhaps you'll find the same +<https://github.com/python-trio/trio/wiki/Testimonials>`__. 
+ +Trio is a mature and well-tested project: the overall design is solid, +and the existing features are fully documented and widely used in +production. While we occasionally make minor interface adjustments, +breaking changes are rare. We encourage you to use Trio with confidence, +but if you rely on long-term API stability, consider `subscribing to +issue #1 <https://github.com/python-trio/trio/issues/1>`__ for advance +notice of any compatibility updates. + + +Where to next? +-------------- + +**I want to try it out!** Awesome! We have a `friendly tutorial +<https://trio.readthedocs.io/en/stable/tutorial.html>`__ to get you +started; no prior experience with async coding is required. + +**Ugh, I don't want to read all that – show me some code!** If you're +impatient, then here's a `simple concurrency example +<https://trio.readthedocs.io/en/stable/tutorial.html#tutorial-example-tasks-intro>`__, +an `echo client +<https://trio.readthedocs.io/en/stable/tutorial.html#tutorial-echo-client-example>`__, +and an `echo server +<https://trio.readthedocs.io/en/stable/tutorial.html#tutorial-echo-server-example>`__. + +**How does Trio make programs easier to read and reason about than +competing approaches?** Trio is based on a new way of thinking that we +call "structured concurrency". The best theoretical introduction is +the article `Notes on structured concurrency, or: Go statement +considered harmful +<https://vorpus.org/blog/notes-on-structured-concurrency-or-go-statement-considered-harmful/>`__. +Or, `check out this talk at PyCon 2018 +<https://www.youtube.com/watch?v=oLkfnc_UMcE>`__ to see a +demonstration of implementing the "Happy Eyeballs" algorithm in an +older library versus Trio. + +**Cool, but will it work on my system?** Probably! 
As long as you have +some kind of Python 3.9-or-better (CPython or `currently maintained versions of +PyPy3 <https://doc.pypy.org/en/latest/faq.html#which-python-versions-does-pypy-implement>`__ +are both fine), and are using Linux, macOS, Windows, or FreeBSD, then Trio +will work. Other environments might work too, but those +are the ones we test on. And all of our dependencies are pure Python, +except for CFFI on Windows, which has wheels available, so +installation should be easy (no C compiler needed). + +**I tried it, but it's not working.** Sorry to hear that! You can try +asking for help in our `chat room +<https://gitter.im/python-trio/general>`__ or `forum +<https://trio.discourse.group>`__, `filing a bug +<https://github.com/python-trio/trio/issues/new>`__, or `posting a +question on StackOverflow +<https://stackoverflow.com/questions/ask?tags=python+python-trio>`__, +and we'll do our best to help you out. + +**Trio is awesome, and I want to help make it more awesome!** You're +the best! There's tons of work to do – filling in missing +functionality, building up an ecosystem of Trio-using libraries, +usability testing (e.g., maybe try teaching yourself or a friend to +use Trio and make a list of every error message you hit and place +where you got confused?), improving the docs, ... check out our `guide +for contributors +<https://trio.readthedocs.io/en/stable/contributing.html>`__! + +**I don't have any immediate plans to use it, but I love geeking out +about I/O library design!** That's a little weird? But let's be +honest, you'll fit in great around here. We have a `whole sub-forum +for discussing structured concurrency +<https://trio.discourse.group/c/structured-concurrency>`__ (developers +of other systems welcome!). 
Or check out our `discussion of design +choices +<https://trio.readthedocs.io/en/stable/design.html#user-level-api-principles>`__, +`reading list +<https://github.com/python-trio/trio/wiki/Reading-list>`__, and +`issues tagged design-discussion +<https://github.com/python-trio/trio/labels/design%20discussion>`__. + +**I want to make sure my company's lawyers won't get angry at me!** No +worries, Trio is permissively licensed under your choice of MIT or +Apache 2. See `LICENSE +<https://github.com/python-trio/trio/blob/main/LICENSE>`__ for details. + + +Code of conduct +--------------- + +Contributors are requested to follow our `code of conduct +<https://trio.readthedocs.io/en/stable/code-of-conduct.html>`__ in all +project spaces. diff --git a/contrib/python/trio/trio/__init__.py b/contrib/python/trio/trio/__init__.py new file mode 100644 index 00000000000..b937ac5b937 --- /dev/null +++ b/contrib/python/trio/trio/__init__.py @@ -0,0 +1,133 @@ +"""Trio - A friendly Python library for async concurrency and I/O""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +# General layout: +# +# trio/_core/... is the self-contained core library. It does various +# shenanigans to export a consistent "core API", but parts of the core API are +# too low-level to be recommended for regular use. +# +# trio/*.py define a set of more usable tools on top of this. They import from +# trio._core and from each other. +# +# This file pulls together the friendly public API, by re-exporting the more +# innocuous bits of the _core API + the higher-level tools from trio/*.py. +# +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) +# +# must be imported early to avoid circular import +from ._core import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED # isort: split + +# Submodules imported by default +from . 
import abc, from_thread, lowlevel, socket, to_thread +from ._channel import ( + MemoryChannelStatistics as MemoryChannelStatistics, + MemoryReceiveChannel as MemoryReceiveChannel, + MemorySendChannel as MemorySendChannel, + as_safe_channel as as_safe_channel, + open_memory_channel as open_memory_channel, +) +from ._core import ( + BrokenResourceError as BrokenResourceError, + BusyResourceError as BusyResourceError, + Cancelled as Cancelled, + CancelScope as CancelScope, + ClosedResourceError as ClosedResourceError, + EndOfChannel as EndOfChannel, + Nursery as Nursery, + RunFinishedError as RunFinishedError, + TaskStatus as TaskStatus, + TrioInternalError as TrioInternalError, + WouldBlock as WouldBlock, + current_effective_deadline as current_effective_deadline, + current_time as current_time, + open_nursery as open_nursery, + run as run, +) +from ._deprecate import TrioDeprecationWarning as TrioDeprecationWarning +from ._dtls import ( + DTLSChannel as DTLSChannel, + DTLSChannelStatistics as DTLSChannelStatistics, + DTLSEndpoint as DTLSEndpoint, +) +from ._file_io import open_file as open_file, wrap_file as wrap_file +from ._highlevel_generic import ( + StapledStream as StapledStream, + aclose_forcefully as aclose_forcefully, +) +from ._highlevel_open_tcp_listeners import ( + open_tcp_listeners as open_tcp_listeners, + serve_tcp as serve_tcp, +) +from ._highlevel_open_tcp_stream import open_tcp_stream as open_tcp_stream +from ._highlevel_open_unix_stream import open_unix_socket as open_unix_socket +from ._highlevel_serve_listeners import serve_listeners as serve_listeners +from ._highlevel_socket import ( + SocketListener as SocketListener, + SocketStream as SocketStream, +) +from ._highlevel_ssl_helpers import ( + open_ssl_over_tcp_listeners as open_ssl_over_tcp_listeners, + open_ssl_over_tcp_stream as open_ssl_over_tcp_stream, + serve_ssl_over_tcp as serve_ssl_over_tcp, +) +from ._path import Path as Path, PosixPath as PosixPath, WindowsPath as WindowsPath +from 
._signals import open_signal_receiver as open_signal_receiver +from ._ssl import ( + NeedHandshakeError as NeedHandshakeError, + SSLListener as SSLListener, + SSLStream as SSLStream, +) +from ._subprocess import Process as Process, run_process as run_process +from ._sync import ( + CapacityLimiter as CapacityLimiter, + CapacityLimiterStatistics as CapacityLimiterStatistics, + Condition as Condition, + ConditionStatistics as ConditionStatistics, + Event as Event, + EventStatistics as EventStatistics, + Lock as Lock, + LockStatistics as LockStatistics, + Semaphore as Semaphore, + StrictFIFOLock as StrictFIFOLock, +) +from ._timeouts import ( + TooSlowError as TooSlowError, + fail_after as fail_after, + fail_at as fail_at, + move_on_after as move_on_after, + move_on_at as move_on_at, + sleep as sleep, + sleep_forever as sleep_forever, + sleep_until as sleep_until, +) +from ._version import __version__ as __version__ + +# Not imported by default, but mentioned here so static analysis tools like +# pylint will know that it exists. +if TYPE_CHECKING: + from . import testing + +from . 
import _deprecate as _deprecate + +_deprecate.deprecate_attributes(__name__, {}) + +# Having the public path in .__module__ attributes is important for: +# - exception names in printed tracebacks +# - sphinx :show-inheritance: +# - deprecation warnings +# - pickle +# - probably other stuff +from ._util import fixup_module_metadata + +fixup_module_metadata(__name__, globals()) +fixup_module_metadata(lowlevel.__name__, lowlevel.__dict__) +fixup_module_metadata(socket.__name__, socket.__dict__) +fixup_module_metadata(abc.__name__, abc.__dict__) +fixup_module_metadata(from_thread.__name__, from_thread.__dict__) +fixup_module_metadata(to_thread.__name__, to_thread.__dict__) +del fixup_module_metadata +del TYPE_CHECKING diff --git a/contrib/python/trio/trio/__main__.py b/contrib/python/trio/trio/__main__.py new file mode 100644 index 00000000000..3b7c898ad50 --- /dev/null +++ b/contrib/python/trio/trio/__main__.py @@ -0,0 +1,3 @@ +from trio._repl import main + +main(locals()) diff --git a/contrib/python/trio/trio/_abc.py b/contrib/python/trio/trio/_abc.py new file mode 100644 index 00000000000..abb68243810 --- /dev/null +++ b/contrib/python/trio/trio/_abc.py @@ -0,0 +1,714 @@ +from __future__ import annotations + +import socket +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, TypeVar + +import trio + +if TYPE_CHECKING: + from types import TracebackType + + from typing_extensions import Self + + # both of these introduce circular imports if outside a TYPE_CHECKING guard + from ._socket import SocketType + from .lowlevel import Task + + +class Clock(ABC): + """The interface for custom run loop clocks.""" + + __slots__ = () + + @abstractmethod + def start_clock(self) -> None: + """Do any setup this clock might need. + + Called at the beginning of the run. + + """ + + @abstractmethod + def current_time(self) -> float: + """Return the current time, according to this clock. 
+ + This is used to implement functions like :func:`trio.current_time` and + :func:`trio.move_on_after`. + + Returns: + float: The current time. + + """ + + @abstractmethod + def deadline_to_sleep_time(self, deadline: float) -> float: + """Compute the real time until the given deadline. + + This is called before we enter a system-specific wait function like + :func:`select.select`, to get the timeout to pass. + + For a clock using wall-time, this should be something like:: + + return deadline - self.current_time() + + but of course it may be different if you're implementing some kind of + virtual clock. + + Args: + deadline (float): The absolute time of the next deadline, + according to this clock. + + Returns: + float: The number of real seconds to sleep until the given + deadline. May be :data:`math.inf`. + + """ + + +class Instrument(ABC): # noqa: B024 # conceptually is ABC + """The interface for run loop instrumentation. + + Instruments don't have to inherit from this abstract base class, and all + of these methods are optional. This class serves mostly as documentation. + + """ + + __slots__ = () + + def before_run(self) -> None: + """Called at the beginning of :func:`trio.run`.""" + return + + def after_run(self) -> None: + """Called just before :func:`trio.run` returns.""" + return + + def task_spawned(self, task: Task) -> None: + """Called when the given task is created. + + Args: + task (trio.lowlevel.Task): The new task. + + """ + return + + def task_scheduled(self, task: Task) -> None: + """Called when the given task becomes runnable. + + It may still be some time before it actually runs, if there are other + runnable tasks ahead of it. + + Args: + task (trio.lowlevel.Task): The task that became runnable. + + """ + return + + def before_task_step(self, task: Task) -> None: + """Called immediately before we resume running the given task. + + Args: + task (trio.lowlevel.Task): The task that is about to run. 
+ + """ + return + + def after_task_step(self, task: Task) -> None: + """Called when we return to the main run loop after a task has yielded. + + Args: + task (trio.lowlevel.Task): The task that just ran. + + """ + return + + def task_exited(self, task: Task) -> None: + """Called when the given task exits. + + Args: + task (trio.lowlevel.Task): The finished task. + + """ + return + + def before_io_wait(self, timeout: float) -> None: + """Called before blocking to wait for I/O readiness. + + Args: + timeout (float): The number of seconds we are willing to wait. + + """ + return + + def after_io_wait(self, timeout: float) -> None: + """Called after handling pending I/O. + + Args: + timeout (float): The number of seconds we were willing to + wait. This much time may or may not have elapsed, depending on + whether any I/O was ready. + + """ + return + + +class HostnameResolver(ABC): + """If you have a custom hostname resolver, then implementing + :class:`HostnameResolver` allows you to register this to be used by Trio. + + See :func:`trio.socket.set_custom_hostname_resolver`. + + """ + + __slots__ = () + + @abstractmethod + async def getaddrinfo( + self, + host: bytes | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes], + ] + ]: + """A custom implementation of :func:`~trio.socket.getaddrinfo`. + + Called by :func:`trio.socket.getaddrinfo`. + + If ``host`` is given as a numeric IP address, then + :func:`~trio.socket.getaddrinfo` may handle the request itself rather + than calling this method. 
+ + Any required IDNA encoding is handled before calling this function; + your implementation can assume that it will never see U-labels like + ``"café.com"``, and only needs to handle A-labels like + ``b"xn--caf-dma.com"``.""" # spellchecker:disable-line + + @abstractmethod + async def getnameinfo( + self, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, + ) -> tuple[str, str]: + """A custom implementation of :func:`~trio.socket.getnameinfo`. + + Called by :func:`trio.socket.getnameinfo`. + + """ + + +class SocketFactory(ABC): + """If you write a custom class implementing the Trio socket interface, + then you can use a :class:`SocketFactory` to get Trio to use it. + + See :func:`trio.socket.set_custom_socket_factory`. + + """ + + __slots__ = () + + @abstractmethod + def socket( + self, + family: socket.AddressFamily | int = socket.AF_INET, + type: socket.SocketKind | int = socket.SOCK_STREAM, + proto: int = 0, + ) -> SocketType: + """Create and return a socket object. + + Your socket object must inherit from :class:`trio.socket.SocketType`, + which is an empty class whose only purpose is to "mark" which classes + should be considered valid Trio sockets. + + Called by :func:`trio.socket.socket`. + + Note that unlike :func:`trio.socket.socket`, this does not take a + ``fileno=`` argument. If a ``fileno=`` is specified, then + :func:`trio.socket.socket` returns a regular Trio socket object + instead of calling this method. + + """ + + +class AsyncResource(ABC): + """A standard interface for resources that need to be cleaned up, and + where that cleanup may require blocking operations. + + This class distinguishes between "graceful" closes, which may perform I/O + and thus block, and a "forceful" close, which cannot. 
For example, cleanly + shutting down a TLS-encrypted connection requires sending a "goodbye" + message; but if a peer has become non-responsive, then sending this + message might block forever, so we may want to just drop the connection + instead. Therefore the :meth:`aclose` method is unusual in that it + should always close the connection (or at least make its best attempt) + *even if it fails*; failure indicates a failure to achieve grace, not a + failure to close the connection. + + Objects that implement this interface can be used as async context + managers, i.e., you can write:: + + async with create_resource() as some_async_resource: + ... + + Entering the context manager is synchronous (not a checkpoint); exiting it + calls :meth:`aclose`. The default implementations of + ``__aenter__`` and ``__aexit__`` should be adequate for all subclasses. + + """ + + __slots__ = () + + @abstractmethod + async def aclose(self) -> None: + """Close this resource, possibly blocking. + + IMPORTANT: This method may block in order to perform a "graceful" + shutdown. But, if this fails, then it still *must* close any + underlying resources before returning. An error from this method + indicates a failure to achieve grace, *not* a failure to close the + connection. + + For example, suppose we call :meth:`aclose` on a TLS-encrypted + connection. This requires sending a "goodbye" message; but if the peer + has become non-responsive, then our attempt to send this message might + block forever, and eventually time out and be cancelled. In this case + the :meth:`aclose` method on :class:`~trio.SSLStream` will + immediately close the underlying transport stream using + :func:`trio.aclose_forcefully` before raising :exc:`~trio.Cancelled`. + + If the resource is already closed, then this method should silently + succeed. 
+ + Once this method completes, any other pending or future operations on + this resource should generally raise :exc:`~trio.ClosedResourceError`, + unless there's a good reason to do otherwise. + + See also: :func:`trio.aclose_forcefully`. + + """ + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.aclose() + + +class SendStream(AsyncResource): + """A standard interface for sending data on a byte stream. + + The underlying stream may be unidirectional, or bidirectional. If it's + bidirectional, then you probably want to also implement + :class:`ReceiveStream`, which makes your object a :class:`Stream`. + + :class:`SendStream` objects also implement the :class:`AsyncResource` + interface, so they can be closed by calling :meth:`~AsyncResource.aclose` + or using an ``async with`` block. + + If you want to send Python objects rather than raw bytes, see + :class:`SendChannel`. + + """ + + __slots__ = () + + @abstractmethod + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + """Sends the given data through the stream, blocking if necessary. + + Args: + data (bytes, bytearray, or memoryview): The data to send. + + Raises: + trio.BusyResourceError: if another task is already executing a + :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or + :meth:`HalfCloseableStream.send_eof` on this stream. + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`send_all` is running. + + Most low-level operations in Trio provide a guarantee: if they raise + :exc:`trio.Cancelled`, this means that they had no effect, so the + system remains in a known state. This is **not true** for + :meth:`send_all`. 
If this operation raises :exc:`trio.Cancelled` (or + any other exception for that matter), then it may have sent some, all, + or none of the requested data, and there is no way to know which. + + """ + + @abstractmethod + async def wait_send_all_might_not_block(self) -> None: + """Block until it's possible that :meth:`send_all` might not block. + + This method may return early: it's possible that after it returns, + :meth:`send_all` will still block. (In the worst case, if no better + implementation is available, then it might always return immediately + without blocking. It's nice to do better than that when possible, + though.) + + This method **must not** return *late*: if it's possible for + :meth:`send_all` to complete without blocking, then it must + return. When implementing it, err on the side of returning early. + + Raises: + trio.BusyResourceError: if another task is already executing a + :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or + :meth:`HalfCloseableStream.send_eof` on this stream. + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`wait_send_all_might_not_block` is running. + + Note: + + This method is intended to aid in implementing protocols that want + to delay choosing which data to send until the last moment. E.g., + suppose you're working on an implementation of a remote display server + like `VNC + <https://en.wikipedia.org/wiki/Virtual_Network_Computing>`__, and + the network connection is currently backed up so that if you call + :meth:`send_all` now then it will sit for 0.5 seconds before actually + sending anything. 
In this case it doesn't make sense to take a + screenshot, then wait 0.5 seconds, and then send it, because the + screen will keep changing while you wait; it's better to wait 0.5 + seconds, then take the screenshot, and then send it, because this + way the data you deliver will be more + up-to-date. Using :meth:`wait_send_all_might_not_block` makes it + possible to implement the better strategy. + + If you use this method, you might also want to read up on + ``TCP_NOTSENT_LOWAT``. + + Further reading: + + * `Prioritization Only Works When There's Pending Data to Prioritize + <https://insouciant.org/tech/prioritization-only-works-when-theres-pending-data-to-prioritize/>`__ + + * WWDC 2015: Your App and Next Generation Networks: `slides + <http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1>`__, + `video and transcript + <https://developer.apple.com/videos/play/wwdc2015/719/>`__ + + """ + + +class ReceiveStream(AsyncResource): + """A standard interface for receiving data on a byte stream. + + The underlying stream may be unidirectional, or bidirectional. If it's + bidirectional, then you probably want to also implement + :class:`SendStream`, which makes your object a :class:`Stream`. + + :class:`ReceiveStream` objects also implement the :class:`AsyncResource` + interface, so they can be closed by calling :meth:`~AsyncResource.aclose` + or using an ``async with`` block. + + If you want to receive Python objects rather than raw bytes, see + :class:`ReceiveChannel`. + + `ReceiveStream` objects can be used in ``async for`` loops. Each iteration + will produce an arbitrary sized chunk of bytes, like calling + `receive_some` with no arguments. Every chunk will contain at least one + byte, and the loop automatically exits when reaching end-of-file. 
+ + """ + + __slots__ = () + + @abstractmethod + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: + """Wait until there is data available on this stream, and then return + some of it. + + A return value of ``b""`` (an empty bytestring) indicates that the + stream has reached end-of-file. Implementations should be careful that + they return ``b""`` if, and only if, the stream has reached + end-of-file! + + Args: + max_bytes (int): The maximum number of bytes to return. Must be + greater than zero. Optional; if omitted, then the stream object + is free to pick a reasonable default. + + Returns: + bytes or bytearray: The data received. + + Raises: + trio.BusyResourceError: if two tasks attempt to call + :meth:`receive_some` on the same stream at the same time. + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`receive_some` is running. + + """ + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> bytes | bytearray: + data = await self.receive_some() + if not data: + raise StopAsyncIteration + return data + + +class Stream(SendStream, ReceiveStream): + """A standard interface for interacting with bidirectional byte streams. + + A :class:`Stream` is an object that implements both the + :class:`SendStream` and :class:`ReceiveStream` interfaces. + + If implementing this interface, you should consider whether you can go one + step further and implement :class:`HalfCloseableStream`. + + """ + + __slots__ = () + + +class HalfCloseableStream(Stream): + """This interface extends :class:`Stream` to also allow closing the send + part of the stream without closing the receive part. + + """ + + __slots__ = () + + @abstractmethod + async def send_eof(self) -> None: + """Send an end-of-file indication on this stream, if possible. 
+ + The difference between :meth:`send_eof` and + :meth:`~AsyncResource.aclose` is that :meth:`send_eof` is a + *unidirectional* end-of-file indication. After you call this method, + you shouldn't try sending any more data on this stream, and your + remote peer should receive an end-of-file indication (eventually, + after receiving all the data you sent before that). But, they may + continue to send data to you, and you can continue to receive it by + calling :meth:`~ReceiveStream.receive_some`. You can think of it as + calling :meth:`~AsyncResource.aclose` on just the + :class:`SendStream` "half" of the stream object (and in fact that's + literally how :class:`trio.StapledStream` implements it). + + Examples: + + * On a socket, this corresponds to ``shutdown(..., SHUT_WR)`` (`man + page <https://linux.die.net/man/2/shutdown>`__). + + * The SSH protocol provides the ability to multiplex bidirectional + "channels" on top of a single encrypted connection. A Trio + implementation of SSH could expose these channels as + :class:`HalfCloseableStream` objects, and calling :meth:`send_eof` + would send an ``SSH_MSG_CHANNEL_EOF`` request (see `RFC 4254 §5.3 + <https://tools.ietf.org/html/rfc4254#section-5.3>`__). + + * On an SSL/TLS-encrypted connection, the protocol doesn't provide any + way to do a unidirectional shutdown without closing the connection + entirely, so :class:`~trio.SSLStream` implements + :class:`Stream`, not :class:`HalfCloseableStream`. + + If an EOF has already been sent, then this method should silently + succeed. + + Raises: + trio.BusyResourceError: if another task is already executing a + :meth:`~SendStream.send_all`, + :meth:`~SendStream.wait_send_all_might_not_block`, or + :meth:`send_eof` on this stream. + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`send_eof` is running. 
+ + """ + + +# A regular invariant generic type +T = TypeVar("T") + +# The type of object produced by a ReceiveChannel (covariant because +# ReceiveChannel[Derived] can be passed to someone expecting +# ReceiveChannel[Base]) +ReceiveType = TypeVar("ReceiveType", covariant=True) + +# The type of object accepted by a SendChannel (contravariant because +# SendChannel[Base] can be passed to someone expecting +# SendChannel[Derived]) +SendType = TypeVar("SendType", contravariant=True) + +# The type of object produced by a Listener (covariant plus must be +# an AsyncResource) +T_resource = TypeVar("T_resource", bound=AsyncResource, covariant=True) + + +class Listener(AsyncResource, Generic[T_resource]): + """A standard interface for listening for incoming connections. + + :class:`Listener` objects also implement the :class:`AsyncResource` + interface, so they can be closed by calling :meth:`~AsyncResource.aclose` + or using an ``async with`` block. + + """ + + __slots__ = () + + @abstractmethod + async def accept(self) -> T_resource: + """Wait until an incoming connection arrives, and then return it. + + Returns: + AsyncResource: An object representing the incoming connection. In + practice this is generally some kind of :class:`Stream`, + but in principle you could also define a :class:`Listener` that + returned, say, channel objects. + + Raises: + trio.BusyResourceError: if two tasks attempt to call + :meth:`accept` on the same listener at the same time. + trio.ClosedResourceError: if you previously closed this listener + object, or if another task closes this listener object while + :meth:`accept` is running. + + Listeners don't generally raise :exc:`~trio.BrokenResourceError`, + because for listeners there is no general condition of "the + network/remote peer broke the connection" that can be handled in a + generic way, like there is for streams. 
Other errors *can* occur and + be raised from :meth:`accept` – for example, if you run out of file + descriptors then you might get an :class:`OSError` with its errno set + to ``EMFILE``. + + """ + + +class SendChannel(AsyncResource, Generic[SendType]): + """A standard interface for sending Python objects to some receiver. + + `SendChannel` objects also implement the `AsyncResource` interface, so + they can be closed by calling `~AsyncResource.aclose` or using an ``async + with`` block. + + If you want to send raw bytes rather than Python objects, see + `SendStream`. + + """ + + __slots__ = () + + @abstractmethod + async def send(self, value: SendType) -> None: + """Attempt to send an object through the channel, blocking if necessary. + + Args: + value (object): The object to send. + + Raises: + trio.BrokenResourceError: if something has gone wrong, and the + channel is broken. For example, you may get this if the receiver + has already been closed. + trio.ClosedResourceError: if you previously closed this + :class:`SendChannel` object, or if another task closes it while + :meth:`send` is running. + trio.BusyResourceError: some channels allow multiple tasks to call + `send` at the same time, but others don't. If you try to call + `send` simultaneously from multiple tasks on a channel that + doesn't support it, then you can get `~trio.BusyResourceError`. + + """ + + +class ReceiveChannel(AsyncResource, Generic[ReceiveType]): + """A standard interface for receiving Python objects from some sender. + + You can iterate over a :class:`ReceiveChannel` using an ``async for`` + loop:: + + async for value in receive_channel: + ... + + This is equivalent to calling :meth:`receive` repeatedly. The loop exits + without error when `receive` raises `~trio.EndOfChannel`. + + `ReceiveChannel` objects also implement the `AsyncResource` interface, so + they can be closed by calling `~AsyncResource.aclose` or using an ``async + with`` block. 
+ + If you want to receive raw bytes rather than Python objects, see + `ReceiveStream`. + + """ + + __slots__ = () + + @abstractmethod + async def receive(self) -> ReceiveType: + """Attempt to receive an incoming object, blocking if necessary. + + Returns: + object: Whatever object was received. + + Raises: + trio.EndOfChannel: if the sender has been closed cleanly, and no + more objects are coming. This is not an error condition. + trio.ClosedResourceError: if you previously closed this + :class:`ReceiveChannel` object. + trio.BrokenResourceError: if something has gone wrong, and the + channel is broken. + trio.BusyResourceError: some channels allow multiple tasks to call + `receive` at the same time, but others don't. If you try to call + `receive` simultaneously from multiple tasks on a channel that + doesn't support it, then you can get `~trio.BusyResourceError`. + + """ + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> ReceiveType: + try: + return await self.receive() + except trio.EndOfChannel: + raise StopAsyncIteration from None + + +# these are necessary for Sphinx's :show-inheritance: with type args. +# (this should be removed if possible) +# see: https://github.com/python/cpython/issues/123250 +SendChannel.__module__ = SendChannel.__module__.replace("_abc", "abc") +ReceiveChannel.__module__ = ReceiveChannel.__module__.replace("_abc", "abc") +Listener.__module__ = Listener.__module__.replace("_abc", "abc") + + +class Channel(SendChannel[T], ReceiveChannel[T]): + """A standard interface for interacting with bidirectional channels. + + A `Channel` is an object that implements both the `SendChannel` and + `ReceiveChannel` interfaces, so you can both send and receive objects. 
+ + """ + + __slots__ = () + + +# see above +Channel.__module__ = Channel.__module__.replace("_abc", "abc") diff --git a/contrib/python/trio/trio/_channel.py b/contrib/python/trio/trio/_channel.py new file mode 100644 index 00000000000..2afca9d7cdb --- /dev/null +++ b/contrib/python/trio/trio/_channel.py @@ -0,0 +1,615 @@ +from __future__ import annotations + +import sys +from collections import OrderedDict, deque +from collections.abc import AsyncGenerator, Callable # noqa: TC003 # Needed for Sphinx +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from functools import wraps +from math import inf +from typing import ( + TYPE_CHECKING, + Generic, +) + +import attrs +from outcome import Error, Value + +import trio + +from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T +from ._core import Abort, RaiseCancelT, Task, enable_ki_protection +from ._util import ( + MultipleExceptionError, + NoPublicConstructor, + final, + generic_function, + raise_single_exception_from_group, +) + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + +if TYPE_CHECKING: + from types import TracebackType + + from typing_extensions import ParamSpec, Self + + P = ParamSpec("P") +elif "sphinx.ext.autodoc" in sys.modules: + # P needs to exist for Sphinx to parse the type hints successfully. + try: + from typing_extensions import ParamSpec + except ImportError: + P = ... # This is valid in Callable, though not correct + else: + P = ParamSpec("P") + + +def _open_memory_channel( + max_buffer_size: int | float, # noqa: PYI041 +) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: + """Open a channel for passing objects between tasks within a process. + + Memory channels are lightweight, cheap to allocate, and entirely + in-memory. They don't involve any operating-system resources, or any kind + of serialization. 
They just pass Python objects directly between tasks + (with a possible stop in an internal buffer along the way). + + Channel objects can be closed by calling `~trio.abc.AsyncResource.aclose` + or using ``async with``. They are *not* automatically closed when garbage + collected. Closing memory channels isn't mandatory, but it is generally a + good idea, because it helps avoid situations where tasks get stuck waiting + on a channel when there's no-one on the other side. See + :ref:`channel-shutdown` for details. + + Memory channel operations are all atomic with respect to + cancellation, either `~trio.abc.ReceiveChannel.receive` will + successfully return an object, or it will raise :exc:`Cancelled` + while leaving the channel unchanged. + + Args: + max_buffer_size (int or math.inf): The maximum number of items that can + be buffered in the channel before :meth:`~trio.abc.SendChannel.send` + blocks. Choosing a sensible value here is important to ensure that + backpressure is communicated promptly and avoid unnecessary latency; + see :ref:`channel-buffering` for more details. If in doubt, use 0. + + Returns: + A pair ``(send_channel, receive_channel)``. If you have + trouble remembering which order these go in, remember: data + flows from left → right. + + In addition to the standard channel methods, all memory channel objects + provide a ``statistics()`` method, which returns an object with the + following fields: + + * ``current_buffer_used``: The number of items currently stored in the + channel buffer. + * ``max_buffer_size``: The maximum number of items allowed in the buffer, + as passed to :func:`open_memory_channel`. + * ``open_send_channels``: The number of open + :class:`MemorySendChannel` endpoints pointing to this channel. + Initially 1, but can be increased by + :meth:`MemorySendChannel.clone`. + * ``open_receive_channels``: Likewise, but for open + :class:`MemoryReceiveChannel` endpoints. 
+ * ``tasks_waiting_send``: The number of tasks blocked in ``send`` on this + channel (summing over all clones). + * ``tasks_waiting_receive``: The number of tasks blocked in ``receive`` on + this channel (summing over all clones). + + """ + if max_buffer_size != inf and not isinstance(max_buffer_size, int): + raise TypeError("max_buffer_size must be an integer or math.inf") + if max_buffer_size < 0: + raise ValueError("max_buffer_size must be >= 0") + state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size) + return ( + MemorySendChannel[T]._create(state), + MemoryReceiveChannel[T]._create(state), + ) + + +# This workaround requires python3.9+, once older python versions are not supported +# or there's a better way of achieving type-checking on a generic factory function, +# it could replace the normal function header +if TYPE_CHECKING: + # written as a class so you can say open_memory_channel[int](5) + class open_memory_channel(tuple["MemorySendChannel[T]", "MemoryReceiveChannel[T]"]): + def __new__( # type: ignore[misc] # "must return a subtype" + cls, + max_buffer_size: int | float, # noqa: PYI041 + ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: + return _open_memory_channel(max_buffer_size) + + def __init__(self, max_buffer_size: int | float) -> None: # noqa: PYI041 + ... + +else: + # apply the generic_function decorator to make open_memory_channel indexable + # so it's valid to say e.g. 
``open_memory_channel[bytes](5)`` at runtime + open_memory_channel = generic_function(_open_memory_channel) + + +class MemoryChannelStatistics: + current_buffer_used: int + max_buffer_size: int | float + open_send_channels: int + open_receive_channels: int + tasks_waiting_send: int + tasks_waiting_receive: int + + +class MemoryChannelState(Generic[T]): + max_buffer_size: int | float + data: deque[T] = attrs.Factory(deque) + # Counts of open endpoints using this state + open_send_channels: int = 0 + open_receive_channels: int = 0 + # {task: value} + send_tasks: OrderedDict[Task, T] = attrs.Factory(OrderedDict) + # {task: None} + receive_tasks: OrderedDict[Task, None] = attrs.Factory(OrderedDict) + + def statistics(self) -> MemoryChannelStatistics: + return MemoryChannelStatistics( + current_buffer_used=len(self.data), + max_buffer_size=self.max_buffer_size, + open_send_channels=self.open_send_channels, + open_receive_channels=self.open_receive_channels, + tasks_waiting_send=len(self.send_tasks), + tasks_waiting_receive=len(self.receive_tasks), + ) + + +@final [email protected](eq=False, repr=False, slots=False) +class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor): + _state: MemoryChannelState[SendType] + _closed: bool = False + # This is just the tasks waiting on *this* object. As compared to + # self._state.send_tasks, which includes tasks from this object and + # all clones. + _tasks: set[Task] = attrs.Factory(set) + + def __attrs_post_init__(self) -> None: + self._state.open_send_channels += 1 + + def __repr__(self) -> str: + return f"<send channel at {id(self):#x}, using buffer at {id(self._state):#x}>" + + def statistics(self) -> MemoryChannelStatistics: + """Returns a `MemoryChannelStatistics` for the memory channel this is + associated with.""" + # XX should we also report statistics specific to this object? 
+ return self._state.statistics() + + @enable_ki_protection + def send_nowait(self, value: SendType) -> None: + """Like `~trio.abc.SendChannel.send`, but if the channel's buffer is + full, raises `WouldBlock` instead of blocking. + + """ + if self._closed: + raise trio.ClosedResourceError + if self._state.open_receive_channels == 0: + raise trio.BrokenResourceError + if self._state.receive_tasks: + assert not self._state.data + task, _ = self._state.receive_tasks.popitem(last=False) + task.custom_sleep_data._tasks.remove(task) + trio.lowlevel.reschedule(task, Value(value)) + elif len(self._state.data) < self._state.max_buffer_size: + self._state.data.append(value) + else: + raise trio.WouldBlock + + @enable_ki_protection + async def send(self, value: SendType) -> None: + """See `SendChannel.send <trio.abc.SendChannel.send>`. + + Memory channels allow multiple tasks to call `send` at the same time. + + """ + await trio.lowlevel.checkpoint_if_cancelled() + try: + self.send_nowait(value) + except trio.WouldBlock: + pass + else: + await trio.lowlevel.cancel_shielded_checkpoint() + return + + task = trio.lowlevel.current_task() + self._tasks.add(task) + self._state.send_tasks[task] = value + task.custom_sleep_data = self + + def abort_fn(_: RaiseCancelT) -> Abort: + self._tasks.remove(task) + del self._state.send_tasks[task] + return trio.lowlevel.Abort.SUCCEEDED + + await trio.lowlevel.wait_task_rescheduled(abort_fn) + + # Return type must be stringified or use a TypeVar + @enable_ki_protection + def clone(self) -> MemorySendChannel[SendType]: + """Clone this send channel object. + + This returns a new `MemorySendChannel` object, which acts as a + duplicate of the original: sending on the new object does exactly the + same thing as sending on the old object. (If you're familiar with + `os.dup`, then this is a similar idea.) + + However, closing one of the objects does not close the other, and + receivers don't get `EndOfChannel` until *all* clones have been + closed. 
+ + This is useful for communication patterns that involve multiple + producers all sending objects to the same destination. If you give + each producer its own clone of the `MemorySendChannel`, and then make + sure to close each `MemorySendChannel` when it's finished, receivers + will automatically get notified when all producers are finished. See + :ref:`channel-mpmc` for examples. + + Raises: + trio.ClosedResourceError: if you already closed this + `MemorySendChannel` object. + + """ + if self._closed: + raise trio.ClosedResourceError + return MemorySendChannel._create(self._state) + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + @enable_ki_protection + def close(self) -> None: + """Close this send channel object synchronously. + + All channel objects have an asynchronous `~.AsyncResource.aclose` method. + Memory channels can also be closed synchronously. This has the same + effect on the channel and other tasks using it, but `close` is not a + trio checkpoint. This simplifies cleaning up in cancelled tasks. + + Using ``with send_channel:`` will close the channel object on leaving + the with block. + + """ + if self._closed: + return + self._closed = True + for task in self._tasks: + trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError())) + del self._state.send_tasks[task] + self._tasks.clear() + self._state.open_send_channels -= 1 + if self._state.open_send_channels == 0: + assert not self._state.send_tasks + for task in self._state.receive_tasks: + task.custom_sleep_data._tasks.remove(task) + trio.lowlevel.reschedule(task, Error(trio.EndOfChannel())) + self._state.receive_tasks.clear() + + @enable_ki_protection + async def aclose(self) -> None: + """Close this send channel object asynchronously. 
+ + See `MemorySendChannel.close`.""" + self.close() + await trio.lowlevel.checkpoint() + + +@final [email protected](eq=False, repr=False, slots=False) +class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor): + _state: MemoryChannelState[ReceiveType] + _closed: bool = False + _tasks: set[trio._core._run.Task] = attrs.Factory(set) + + def __attrs_post_init__(self) -> None: + self._state.open_receive_channels += 1 + + def statistics(self) -> MemoryChannelStatistics: + """Returns a `MemoryChannelStatistics` for the memory channel this is + associated with.""" + return self._state.statistics() + + def __repr__(self) -> str: + return ( + f"<receive channel at {id(self):#x}, using buffer at {id(self._state):#x}>" + ) + + @enable_ki_protection + def receive_nowait(self) -> ReceiveType: + """Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing + ready to receive, raises `WouldBlock` instead of blocking. + + """ + if self._closed: + raise trio.ClosedResourceError + if self._state.send_tasks: + task, value = self._state.send_tasks.popitem(last=False) + task.custom_sleep_data._tasks.remove(task) + trio.lowlevel.reschedule(task) + self._state.data.append(value) + # Fall through + if self._state.data: + return self._state.data.popleft() + if not self._state.open_send_channels: + raise trio.EndOfChannel + raise trio.WouldBlock + + @enable_ki_protection + async def receive(self) -> ReceiveType: + """See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`. + + Memory channels allow multiple tasks to call `receive` at the same + time. The first task will get the first item sent, the second task + will get the second item sent, and so on. 
+ + """ + await trio.lowlevel.checkpoint_if_cancelled() + try: + value = self.receive_nowait() + except trio.WouldBlock: + pass + else: + await trio.lowlevel.cancel_shielded_checkpoint() + return value + + task = trio.lowlevel.current_task() + self._tasks.add(task) + self._state.receive_tasks[task] = None + task.custom_sleep_data = self + + def abort_fn(_: RaiseCancelT) -> Abort: + self._tasks.remove(task) + del self._state.receive_tasks[task] + return trio.lowlevel.Abort.SUCCEEDED + + # Not strictly guaranteed to return ReceiveType, but will do so unless + # you intentionally reschedule with a bad value. + return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[no-any-return] + + @enable_ki_protection + def clone(self) -> MemoryReceiveChannel[ReceiveType]: + """Clone this receive channel object. + + This returns a new `MemoryReceiveChannel` object, which acts as a + duplicate of the original: receiving on the new object does exactly + the same thing as receiving on the old object. + + However, closing one of the objects does not close the other, and the + underlying channel is not closed until all clones are closed. (If + you're familiar with `os.dup`, then this is a similar idea.) + + This is useful for communication patterns that involve multiple + consumers all receiving objects from the same underlying channel. See + :ref:`channel-mpmc` for examples. + + .. warning:: The clones all share the same underlying channel. + Whenever a clone :meth:`receive`\\s a value, it is removed from the + channel and the other clones do *not* receive that value. If you + want to send multiple copies of the same stream of values to + multiple destinations, like :func:`itertools.tee`, then you need to + find some other solution; this method does *not* do that. + + Raises: + trio.ClosedResourceError: if you already closed this + `MemoryReceiveChannel` object. 
+ + """ + if self._closed: + raise trio.ClosedResourceError + return MemoryReceiveChannel._create(self._state) + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + @enable_ki_protection + def close(self) -> None: + """Close this receive channel object synchronously. + + All channel objects have an asynchronous `~.AsyncResource.aclose` method. + Memory channels can also be closed synchronously. This has the same + effect on the channel and other tasks using it, but `close` is not a + trio checkpoint. This simplifies cleaning up in cancelled tasks. + + Using ``with receive_channel:`` will close the channel object on + leaving the with block. + + """ + if self._closed: + return + self._closed = True + for task in self._tasks: + trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError())) + del self._state.receive_tasks[task] + self._tasks.clear() + self._state.open_receive_channels -= 1 + if self._state.open_receive_channels == 0: + assert not self._state.receive_tasks + for task in self._state.send_tasks: + task.custom_sleep_data._tasks.remove(task) + trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError())) + self._state.send_tasks.clear() + self._state.data.clear() + + @enable_ki_protection + async def aclose(self) -> None: + """Close this receive channel object asynchronously. 
+ + See `MemoryReceiveChannel.close`.""" + self.close() + await trio.lowlevel.checkpoint() + + +class RecvChanWrapper(ReceiveChannel[T]): + def __init__( + self, recv_chan: MemoryReceiveChannel[T], send_semaphore: trio.Semaphore + ) -> None: + self._recv_chan = recv_chan + self._send_semaphore = send_semaphore + + async def receive(self) -> T: + self._send_semaphore.release() + return await self._recv_chan.receive() + + async def aclose(self) -> None: + await self._recv_chan.aclose() + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self._recv_chan.close() + + +def as_safe_channel( + fn: Callable[P, AsyncGenerator[T, None]], +) -> Callable[P, AbstractAsyncContextManager[ReceiveChannel[T]]]: + """Decorate an async generator function to make it cancellation-safe. + + The ``yield`` keyword offers a very convenient way to write iterators... + which makes it really unfortunate that async generators are so difficult + to call correctly. Yielding from the inside of a cancel scope or a nursery + to the outside `violates structured concurrency <https://xkcd.com/292/>`_ + with consequences explained in :pep:`789`. Even then, resource cleanup + errors remain common (:pep:`533`) unless you wrap every call in + :func:`~contextlib.aclosing`. + + This decorator gives you the best of both worlds: with careful exception + handling and a background task we preserve structured concurrency by + offering only the safe interface, and you can still write your iterables + with the convenience of ``yield``. For example:: + + @as_safe_channel + async def my_async_iterable(arg, *, kwarg=True): + while ...: + item = await ... + yield item + + async with my_async_iterable(...) as recv_chan: + async for item in recv_chan: + ... 
+ + While the combined async-with-async-for can be inconvenient at first, + the context manager is indispensable for both correctness and for prompt + cleanup of resources. + """ + # Perhaps a future PEP will adopt `async with for` syntax, like + # https://coconut.readthedocs.io/en/master/DOCS.html#async-with-for + + @asynccontextmanager + @wraps(fn) + async def context_manager( + *args: P.args, **kwargs: P.kwargs + ) -> AsyncGenerator[trio._channel.RecvChanWrapper[T], None]: + send_chan, recv_chan = trio.open_memory_channel[T](0) + try: + async with trio.open_nursery(strict_exception_groups=True) as nursery: + agen = fn(*args, **kwargs) + send_semaphore = trio.Semaphore(0) + # `nursery.start` to make sure that we will clean up send_chan & agen + # If this errors we don't close `recv_chan`, but the caller + # never gets access to it, so that's not a problem. + await nursery.start( + _move_elems_to_channel, agen, send_chan, send_semaphore + ) + # `async with recv_chan` could eat exceptions, so use sync cm + with RecvChanWrapper(recv_chan, send_semaphore) as wrapped_recv_chan: + yield wrapped_recv_chan + # User has exited context manager, cancel to immediately close the + # abandoned generator if it's still alive. + nursery.cancel_scope.cancel( + "exited trio.as_safe_channel context manager" + ) + except BaseExceptionGroup as eg: + try: + raise_single_exception_from_group(eg) + except MultipleExceptionError: + # In case user has except* we make it possible for them to handle the + # exceptions. 
+ raise BaseExceptionGroup( + "Encountered exception during cleanup of generator object, as well as exception in the contextmanager body - unable to unwrap.", + [eg], + ) from None + + async def _move_elems_to_channel( + agen: AsyncGenerator[T, None], + send_chan: trio.MemorySendChannel[T], + send_semaphore: trio.Semaphore, + task_status: trio.TaskStatus, + ) -> None: + # `async with send_chan` will eat exceptions, + # see https://github.com/python-trio/trio/issues/1559 + with send_chan: + # replace try-finally with contextlib.aclosing once python39 is + # dropped: + try: + task_status.started() + while True: + # wait for receiver to call next on the aiter + await send_semaphore.acquire() + try: + value = await agen.__anext__() + except StopAsyncIteration: + return + # Send the value to the channel + await send_chan.send(value) + finally: + # work around `.aclose()` not suppressing GeneratorExit in an + # ExceptionGroup: + # TODO: make an issue on CPython about this + try: + await agen.aclose() + except BaseExceptionGroup as exceptions: + removed, narrowed_exceptions = exceptions.split(GeneratorExit) + + # TODO: extract a helper to flatten exception groups + removed_exceptions: list[BaseException | None] = [removed] + genexits_seen = 0 + for e in removed_exceptions: + if isinstance(e, BaseExceptionGroup): + removed_exceptions.extend(e.exceptions) # noqa: B909 + else: + genexits_seen += 1 + + if genexits_seen > 1: + exc = AssertionError("More than one GeneratorExit found.") + if narrowed_exceptions is None: + narrowed_exceptions = exceptions.derive([exc]) + else: + narrowed_exceptions = narrowed_exceptions.derive( + [*narrowed_exceptions.exceptions, exc] + ) + if narrowed_exceptions is not None: + raise narrowed_exceptions from None + + return context_manager diff --git a/contrib/python/trio/trio/_core/__init__.py b/contrib/python/trio/trio/_core/__init__.py new file mode 100644 index 00000000000..f9d8068f0cc --- /dev/null +++ 
b/contrib/python/trio/trio/_core/__init__.py @@ -0,0 +1,94 @@ +""" +This namespace represents the core functionality that has to be built-in +and deal with private internal data structures. Things in this namespace +are publicly available in either trio, trio.lowlevel, or trio.testing. +""" + +import sys +import typing as _t + +from ._entry_queue import TrioToken +from ._exceptions import ( + BrokenResourceError, + BusyResourceError, + Cancelled, + ClosedResourceError, + EndOfChannel, + RunFinishedError, + TrioInternalError, + WouldBlock, +) +from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection +from ._local import RunVar, RunVarToken +from ._mock_clock import MockClock +from ._parking_lot import ( + ParkingLot, + ParkingLotStatistics, + add_parking_lot_breaker, + remove_parking_lot_breaker, +) + +# Imports that always exist +from ._run import ( + TASK_STATUS_IGNORED, + CancelScope, + Nursery, + RunStatistics, + Task, + TaskStatus, + add_instrument, + checkpoint, + checkpoint_if_cancelled, + current_clock, + current_effective_deadline, + current_root_task, + current_statistics, + current_task, + current_time, + current_trio_token, + in_trio_run, + in_trio_task, + notify_closing, + open_nursery, + remove_instrument, + reschedule, + run, + spawn_system_task, + start_guest_run, + wait_all_tasks_blocked, + wait_readable, + wait_writable, +) +from ._thread_cache import start_thread_soon + +# Has to come after _run to resolve a circular import +from ._traps import ( + Abort, + RaiseCancelT, + cancel_shielded_checkpoint, + permanently_detach_coroutine_object, + reattach_detached_coroutine_object, + temporarily_detach_coroutine_object, + wait_task_rescheduled, +) +from ._unbounded_queue import UnboundedQueue, UnboundedQueueStatistics + +# Windows imports +if sys.platform == "win32" or ( + not _t.TYPE_CHECKING and "sphinx.ext.autodoc" in sys.modules +): + from ._run import ( + current_iocp, + monitor_completion_key, + readinto_overlapped, + 
register_with_iocp, + wait_overlapped, + write_overlapped, + ) +# Kqueue imports +if (sys.platform != "linux" and sys.platform != "win32") or ( + not _t.TYPE_CHECKING and "sphinx.ext.autodoc" in sys.modules +): + from ._run import current_kqueue, monitor_kevent, wait_kevent + +del sys # It would be better to import sys as _sys, but mypy does not understand it diff --git a/contrib/python/trio/trio/_core/_asyncgens.py b/contrib/python/trio/trio/_core/_asyncgens.py new file mode 100644 index 00000000000..fea41e0e4de --- /dev/null +++ b/contrib/python/trio/trio/_core/_asyncgens.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import logging +import sys +import warnings +import weakref +from typing import TYPE_CHECKING, NoReturn, TypeVar + +import attrs + +from .. import _core +from .._util import name_asyncgen +from . import _run + +# Used to log exceptions in async generator finalizers +ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors") + +if TYPE_CHECKING: + from collections.abc import Callable + from types import AsyncGeneratorType + + from typing_extensions import ParamSpec + + _P = ParamSpec("_P") + + _WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]] + _ASYNC_GEN_SET = set[AsyncGeneratorType[object, NoReturn]] +else: + _WEAK_ASYNC_GEN_SET = weakref.WeakSet + _ASYNC_GEN_SET = set + +_R = TypeVar("_R") + + +@_core.disable_ki_protection +def _call_without_ki_protection( + f: Callable[_P, _R], + /, + *args: _P.args, + **kwargs: _P.kwargs, +) -> _R: + return f(*args, **kwargs) + + [email protected](eq=False) +class AsyncGenerators: + # Async generators are added to this set when first iterated. Any + # left after the main task exits will be closed before trio.run() + # returns. During most of the run, this is a WeakSet so GC works. 
+ # During shutdown, when we're finalizing all the remaining + # asyncgens after the system nursery has been closed, it's a + # regular set so we don't have to deal with GC firing at + # unexpected times. + alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET) + # The ids of foreign async generators are added to this set when first + # iterated. Usually it is not safe to refer to ids like this, but because + # we're using a finalizer we can ensure ids in this set do not outlive + # their async generator. + foreign: set[int] = attrs.Factory(set) + + # This collects async generators that get garbage collected during + # the one-tick window between the system nursery closing and the + # init task starting end-of-run asyncgen finalization. + trailing_needs_finalize: _ASYNC_GEN_SET = attrs.Factory(_ASYNC_GEN_SET) + + prev_hooks: sys._asyncgen_hooks = attrs.field(init=False) + + def install_hooks(self, runner: _run.Runner) -> None: + def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None: + if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"): + self.alive.add(agen) + else: + # An async generator first iterated outside of a Trio + # task doesn't belong to Trio. Probably we're in guest + # mode and the async generator belongs to our host. + # A strong set of ids is one of the only good places to + # remember this fact, at least until + # https://github.com/python/cpython/issues/85093 is implemented. + self.foreign.add(id(agen)) + if self.prev_hooks.firstiter is not None: + self.prev_hooks.firstiter(agen) + + def finalize_in_trio_context( + agen: AsyncGeneratorType[object, NoReturn], + agen_name: str, + ) -> None: + try: + runner.spawn_system_task( + self._finalize_one, + agen, + agen_name, + name=f"close asyncgen {agen_name} (abandoned)", + ) + except RuntimeError: + # There is a one-tick window where the system nursery + # is closed but the init task hasn't yet made + # self.asyncgens a strong set to disable GC. We seem to + # have hit it. 
+ self.trailing_needs_finalize.add(agen) + + @_core.enable_ki_protection + def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: + try: + self.foreign.remove(id(agen)) + except KeyError: + is_ours = True + else: + is_ours = False + + agen_name = name_asyncgen(agen) + if is_ours: + runner.entry_queue.run_sync_soon( + finalize_in_trio_context, + agen, + agen_name, + ) + + # Do this last, because it might raise an exception + # depending on the user's warnings filter. (That + # exception will be printed to the terminal and + # ignored, since we're running in GC context.) + warnings.warn( + f"Async generator {agen_name!r} was garbage collected before it " + "had been exhausted. Surround its use in 'async with " + "aclosing(...):' to ensure that it gets cleaned up as soon as " + "you're done using it.", + ResourceWarning, + stacklevel=2, + source=agen, + ) + else: + # Not ours -> forward to the host loop's async generator finalizer + finalizer = self.prev_hooks.finalizer + if finalizer is not None: + _call_without_ki_protection(finalizer, agen) + else: + # Host has no finalizer. Reimplement the default + # Python behavior with no hooks installed: throw in + # GeneratorExit, step once, raise RuntimeError if + # it doesn't exit. + closer = agen.aclose() + try: + # If the next thing is a yield, this will raise RuntimeError + # which we allow to propagate + _call_without_ki_protection(closer.send, None) + except StopIteration: + pass + else: + # If the next thing is an await, we get here. 
Give a nicer + # error than the default "async generator ignored GeneratorExit" + raise RuntimeError( + f"Non-Trio async generator {agen_name!r} awaited something " + "during finalization; install a finalization hook to " + "support this, or wrap it in 'async with aclosing(...):'", + ) + + self.prev_hooks = sys.get_asyncgen_hooks() + sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer) # type: ignore[arg-type] # Finalizer doesn't use AsyncGeneratorType + + async def finalize_remaining(self, runner: _run.Runner) -> None: + # This is called from init after shutting down the system nursery. + # The only tasks running at this point are init and + # the run_sync_soon task, and since the system nursery is closed, + # there's no way for user code to spawn more. + assert _core.current_task() is runner.init_task + assert len(runner.tasks) == 2 + + # To make async generator finalization easier to reason + # about, we'll shut down asyncgen garbage collection by turning + # the alive WeakSet into a regular set. + self.alive = set(self.alive) + + # Process all pending run_sync_soon callbacks, in case one of + # them was an asyncgen finalizer that snuck in under the wire. + runner.entry_queue.run_sync_soon(runner.reschedule, runner.init_task) + await _core.wait_task_rescheduled( + lambda _: _core.Abort.FAILED, # pragma: no cover + ) + self.alive.update(self.trailing_needs_finalize) + self.trailing_needs_finalize.clear() + + # None of the still-living tasks use async generators, so + # every async generator must be suspended at a yield point -- + # there's no one to be doing the iteration. That's good, + # because aclose() only works on an asyncgen that's suspended + # at a yield point. (If it's suspended at an event loop trap, + # because someone is in the middle of iterating it, then you + # get a RuntimeError on 3.8+, and a nasty surprise on earlier + # versions due to https://bugs.python.org/issue32526.) 
+ # + # However, once we start aclose() of one async generator, it + # might start fetching the next value from another, thus + # preventing us from closing that other (at least until + # aclose() of the first one is complete). This constraint + # effectively requires us to finalize the remaining asyncgens + # in arbitrary order, rather than doing all of them at the + # same time. On 3.8+ we could defer any generator with + # ag_running=True to a later batch, but that only catches + # the case where our aclose() starts after the user's + # asend()/etc. If our aclose() starts first, then the + # user's asend()/etc will raise RuntimeError, since they're + # probably not checking ag_running. + # + # It might be possible to allow some parallelized cleanup if + # we can determine that a certain set of asyncgens have no + # interdependencies, using gc.get_referents() and such. + # But just doing one at a time will typically work well enough + # (since each aclose() executes in a cancelled scope) and + # is much easier to reason about. + + # It's possible that that cleanup code will itself create + # more async generators, so we iterate repeatedly until + # all are gone. + while self.alive: + batch = self.alive + self.alive = _ASYNC_GEN_SET() + for agen in batch: + await self._finalize_one(agen, name_asyncgen(agen)) + + def close(self) -> None: + sys.set_asyncgen_hooks(*self.prev_hooks) + + async def _finalize_one( + self, + agen: AsyncGeneratorType[object, NoReturn], + name: object, + ) -> None: + try: + # This shield ensures that finalize_asyncgen never exits + # with an exception, not even a Cancelled. The inside + # is cancelled so there's no deadlock risk. 
+ with _core.CancelScope(shield=True) as cancel_scope: + cancel_scope.cancel( + reason="disallow async work when closing async generators during trio shutdown" + ) + await agen.aclose() + except BaseException: + ASYNCGEN_LOGGER.exception( + "Exception ignored during finalization of async generator %r -- " + "surround your use of the generator in 'async with aclosing(...):' " + "to raise exceptions like this in the context where they're generated", + name, + ) diff --git a/contrib/python/trio/trio/_core/_concat_tb.py b/contrib/python/trio/trio/_core/_concat_tb.py new file mode 100644 index 00000000000..9d0291ccf80 --- /dev/null +++ b/contrib/python/trio/trio/_core/_concat_tb.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from types import TracebackType + + +# this is used for collapsing single-exception ExceptionGroups when using +# `strict_exception_groups=False`. Once that is retired this function can +# be removed as well. +def concat_tb( + head: TracebackType | None, + tail: TracebackType | None, +) -> TracebackType | None: + # We have to use an iterative algorithm here, because in the worst case + # this might be a RecursionError stack that is by definition too deep to + # process by recursion! + head_tbs = [] + pointer = head + while pointer is not None: + head_tbs.append(pointer) + pointer = pointer.tb_next + current_head = tail + for head_tb in reversed(head_tbs): + current_head = TracebackType( + current_head, head_tb.tb_frame, head_tb.tb_lasti, head_tb.tb_lineno + ) + return current_head diff --git a/contrib/python/trio/trio/_core/_entry_queue.py b/contrib/python/trio/trio/_core/_entry_queue.py new file mode 100644 index 00000000000..988b45ca00d --- /dev/null +++ b/contrib/python/trio/trio/_core/_entry_queue.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import threading +from collections import deque +from collections.abc import Callable +from typing import TYPE_CHECKING, NoReturn + +import attrs + +from .. 
import _core +from .._util import NoPublicConstructor, final +from ._wakeup_socketpair import WakeupSocketpair + +if TYPE_CHECKING: + from typing_extensions import TypeVarTuple, Unpack + + PosArgsT = TypeVarTuple("PosArgsT") + +Function = Callable[..., object] # type: ignore[explicit-any] +Job = tuple[Function, tuple[object, ...]] + + +class EntryQueue: + # This used to use a queue.Queue. but that was broken, because Queues are + # implemented in Python, and not reentrant -- so it was thread-safe, but + # not signal-safe. deque is implemented in C, so each operation is atomic + # WRT threads (and this is guaranteed in the docs), AND each operation is + # atomic WRT signal delivery (signal handlers can run on either side, but + # not *during* a deque operation). dict makes similar guarantees - and + # it's even ordered! + queue: deque[Job] = attrs.Factory(deque) + idempotent_queue: dict[Job, None] = attrs.Factory(dict) + + wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair) + done: bool = False + # Must be a reentrant lock, because it's acquired from signal handlers. + # RLock is signal-safe as of cpython 3.2. NB that this does mean that the + # lock is effectively *disabled* when we enter from signal context. The + # way we use the lock this is OK though, because when + # run_sync_soon is called from a signal it's atomic WRT the + # main thread -- it just might happen at some inconvenient place. But if + # you look at the one place where the main thread holds the lock, it's + # just to make 1 assignment, so that's atomic WRT a signal anyway. + lock: threading.RLock = attrs.Factory(threading.RLock) + + async def task(self) -> None: + assert _core.currently_ki_protected() + # RLock has two implementations: a signal-safe version in _thread, and + # and signal-UNsafe version in threading. We need the signal safe + # version. 
Python 3.2 and later should always use this anyway, but, + # since the symptoms if this goes wrong are just "weird rare + # deadlocks", then let's make a little check. + # See: + # https://bugs.python.org/issue13697#msg237140 + assert self.lock.__class__.__module__ == "_thread" + + def run_cb(job: Job) -> None: + # We run this with KI protection enabled; it's the callback's + # job to disable it if it wants it disabled. Exceptions are + # treated like system task exceptions (i.e., converted into + # TrioInternalError and cause everything to shut down). + sync_fn, args = job + try: + sync_fn(*args) + except BaseException as exc: + + async def kill_everything( # noqa: RUF029 # await not used + exc: BaseException, + ) -> NoReturn: + raise exc + + try: + _core.spawn_system_task(kill_everything, exc) + except RuntimeError: + # We're quite late in the shutdown process and the + # system nursery is already closed. + # TODO(2020-06): this is a gross hack and should + # be fixed soon when we address #1607. + parent_nursery = _core.current_task().parent_nursery + if parent_nursery is None: + raise AssertionError( + "Internal error: `parent_nursery` should never be `None`", + ) from exc # pragma: no cover + parent_nursery.start_soon(kill_everything, exc) + + # This has to be carefully written to be safe in the face of new items + # being queued while we iterate, and to do a bounded amount of work on + # each pass: + def run_all_bounded() -> None: + for _ in range(len(self.queue)): + run_cb(self.queue.popleft()) + for job in list(self.idempotent_queue): + del self.idempotent_queue[job] + run_cb(job) + + try: + while True: + run_all_bounded() + if not self.queue and not self.idempotent_queue: + await self.wakeup.wait_woken() + else: + await _core.checkpoint() + except _core.Cancelled: + # Keep the work done with this lock held as minimal as possible, + # because it doesn't protect us against concurrent signal delivery + # (see the comment above). 
Notice that this code would still be + # correct if written like: + # self.done = True + # with self.lock: + # pass + # because all we want is to force run_sync_soon + # to either be completely before or completely after the write to + # done. That's why we don't need the lock to protect + # against signal handlers. + with self.lock: + self.done = True + # No more jobs will be submitted, so just clear out any residual + # ones: + run_all_bounded() + assert not self.queue + assert not self.idempotent_queue + + def close(self) -> None: + self.wakeup.close() + + def size(self) -> int: + return len(self.queue) + len(self.idempotent_queue) + + def run_sync_soon( + self, + sync_fn: Callable[[Unpack[PosArgsT]], object], + *args: Unpack[PosArgsT], + idempotent: bool = False, + ) -> None: + with self.lock: + if self.done: + raise _core.RunFinishedError("run() has exited") + # We have to hold the lock all the way through here, because + # otherwise the main thread might exit *while* we're doing these + # calls, and then our queue item might not be processed, or the + # wakeup call might trigger an OSError b/c the IO manager has + # already been shut down. + if idempotent: + self.idempotent_queue[sync_fn, args] = None + else: + self.queue.append((sync_fn, args)) + self.wakeup.wakeup_thread_and_signal_safe() + + +@final [email protected](eq=False) +class TrioToken(metaclass=NoPublicConstructor): + """An opaque object representing a single call to :func:`trio.run`. + + It has no public constructor; instead, see :func:`current_trio_token`. + + This object has two uses: + + 1. It lets you re-enter the Trio run loop from external threads or signal + handlers. This is the low-level primitive that :func:`trio.to_thread` + and `trio.from_thread` use to communicate with worker threads, that + `trio.open_signal_receiver` uses to receive notifications about + signals, and so forth. + + 2. 
Each call to :func:`trio.run` has exactly one associated + :class:`TrioToken` object, so you can use it to identify a particular + call. + + """ + + _reentry_queue: EntryQueue + + def run_sync_soon( + self, + sync_fn: Callable[[Unpack[PosArgsT]], object], + *args: Unpack[PosArgsT], + idempotent: bool = False, + ) -> None: + """Schedule a call to ``sync_fn(*args)`` to occur in the context of a + Trio task. + + This is safe to call from the main thread, from other threads, and + from signal handlers. This is the fundamental primitive used to + re-enter the Trio run loop from outside of it. + + The call will happen "soon", but there's no guarantee about exactly + when, and no mechanism provided for finding out when it's happened. + If you need this, you'll have to build your own. + + The call is effectively run as part of a system task (see + :func:`~trio.lowlevel.spawn_system_task`). In particular this means + that: + + * :exc:`KeyboardInterrupt` protection is *enabled* by default; if + you want ``sync_fn`` to be interruptible by control-C, then you + need to use :func:`~trio.lowlevel.disable_ki_protection` + explicitly. + + * If ``sync_fn`` raises an exception, then it's converted into a + :exc:`~trio.TrioInternalError` and *all* tasks are cancelled. You + should be careful that ``sync_fn`` doesn't crash. + + All calls with ``idempotent=False`` are processed in strict + first-in first-out order. + + If ``idempotent=True``, then ``sync_fn`` and ``args`` must be + hashable, and Trio will make a best-effort attempt to discard any + call submission which is equal to an already-pending call. Trio + will process these in first-in first-out order. + + Any ordering guarantees apply separately to ``idempotent=False`` + and ``idempotent=True`` calls; there's no rule for how calls in the + different categories are ordered with respect to each other. + + :raises trio.RunFinishedError: + if the associated call to :func:`trio.run` + has already exited. 
(Any call that *doesn't* raise this error + is guaranteed to be fully processed before :func:`trio.run` + exits.) + + """ + self._reentry_queue.run_sync_soon(sync_fn, *args, idempotent=idempotent) diff --git a/contrib/python/trio/trio/_core/_exceptions.py b/contrib/python/trio/trio/_core/_exceptions.py new file mode 100644 index 00000000000..f70d5e0e95b --- /dev/null +++ b/contrib/python/trio/trio/_core/_exceptions.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Literal + +import attrs + +from trio._util import NoPublicConstructor, final + +if TYPE_CHECKING: + from collections.abc import Callable + + from typing_extensions import Self, TypeAlias + +CancelReasonLiteral: TypeAlias = Literal[ + "KeyboardInterrupt", + "deadline", + "explicit", + "nursery", + "shutdown", + "unknown", +] + + +class TrioInternalError(Exception): + """Raised by :func:`run` if we encounter a bug in Trio, or (possibly) a + misuse of one of the low-level :mod:`trio.lowlevel` APIs. + + This should never happen! If you get this error, please file a bug. + + Unfortunately, if you get this error it also means that all bets are off – + Trio doesn't know what is going on and its normal invariants may be void. + (For example, we might have "lost track" of a task. Or lost track of all + tasks.) Again, though, this shouldn't happen. + + """ + + +class RunFinishedError(RuntimeError): + """Raised by `trio.from_thread.run` and similar functions if the + corresponding call to :func:`trio.run` has already finished. + + """ + + +class WouldBlock(Exception): + """Raised by ``X_nowait`` functions if ``X`` would block.""" + + +@final [email protected](eq=False, kw_only=True) +class Cancelled(BaseException, metaclass=NoPublicConstructor): + """Raised by blocking calls if the surrounding scope has been cancelled. + + You should let this exception propagate, to be caught by the relevant + cancel scope. 
To remind you of this, it inherits from :exc:`BaseException` + instead of :exc:`Exception`, just like :exc:`KeyboardInterrupt` and + :exc:`SystemExit` do. This means that if you write something like:: + + try: + ... + except Exception: + ... + + then this *won't* catch a :exc:`Cancelled` exception. + + You cannot raise :exc:`Cancelled` yourself. Attempting to do so + will produce a :exc:`TypeError`. Use :meth:`cancel_scope.cancel() + <trio.CancelScope.cancel>` instead. + + .. note:: + + In the US it's also common to see this word spelled "canceled", with + only one "l". This is a `recent + <https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=5&smoothing=3&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__ + and `US-specific + <https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=18&smoothing=3&share=&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__ + innovation, and even in the US both forms are still commonly used. So + for consistency with the rest of the world and with "cancellation" + (which always has two "l"s), Trio uses the two "l" spelling + everywhere. + + """ + + source: CancelReasonLiteral = "unknown" + # repr(Task), so as to avoid gc troubles from holding a reference + source_task: str | None = None + reason: str | None = None + + def __str__(self) -> str: + return ( + f"cancelled due to {self.source}" + + ("" if self.reason is None else f" with reason {self.reason!r}") + + ("" if self.source_task is None else f" from task {self.source_task}") + ) + + def __reduce__(self) -> tuple[Callable[[], Cancelled], tuple[()]]: + # The `__reduce__` tuple does not support directly passing kwargs, and the + # kwargs are required so we can't use the third item for adding to __dict__, + # so we use partial. 
+ return ( + partial( + Cancelled._create, + source=self.source, + source_task=self.source_task, + reason=self.reason, + ), + (), + ) + + if TYPE_CHECKING: + # for type checking on internal code + @classmethod + def _create( + cls, + *, + source: CancelReasonLiteral = "unknown", + source_task: str | None = None, + reason: str | None = None, + ) -> Self: ... + + +class BusyResourceError(Exception): + """Raised when a task attempts to use a resource that some other task is + already using, and this would lead to bugs and nonsense. + + For example, if two tasks try to send data through the same socket at the + same time, Trio will raise :class:`BusyResourceError` instead of letting + the data get scrambled. + + """ + + +class ClosedResourceError(Exception): + """Raised when attempting to use a resource after it has been closed. + + Note that "closed" here means that *your* code closed the resource, + generally by calling a method with a name like ``close`` or ``aclose``, or + by exiting a context manager. If a problem arises elsewhere – for example, + because of a network failure, or because a remote peer closed their end of + a connection – then that should be indicated by a different exception + class, like :exc:`BrokenResourceError` or an :exc:`OSError` subclass. + + """ + + +class BrokenResourceError(Exception): + """Raised when an attempt to use a resource fails due to external + circumstances. + + For example, you might get this if you try to send data on a stream where + the remote side has already closed the connection. + + You *don't* get this error if *you* closed the resource – in that case you + get :class:`ClosedResourceError`. + + This exception's ``__cause__`` attribute will often contain more + information about the underlying error. + + """ + + +class EndOfChannel(Exception): + """Raised when trying to receive from a :class:`trio.abc.ReceiveChannel` + that has no more data to receive. 
+ + This is analogous to an "end-of-file" condition, but for channels. + + """ diff --git a/contrib/python/trio/trio/_core/_generated_instrumentation.py b/contrib/python/trio/trio/_core/_generated_instrumentation.py new file mode 100644 index 00000000000..d03ef9db7de --- /dev/null +++ b/contrib/python/trio/trio/_core/_generated_instrumentation.py @@ -0,0 +1,50 @@ +# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ._ki import enable_ki_protection +from ._run import GLOBAL_RUN_CONTEXT + +if TYPE_CHECKING: + from ._instrumentation import Instrument + +__all__ = ["add_instrument", "remove_instrument"] + + +@enable_ki_protection +def add_instrument(instrument: Instrument) -> None: + """Start instrumenting the current run loop with the given instrument. + + Args: + instrument (trio.abc.Instrument): The instrument to activate. + + If ``instrument`` is already active, does nothing. + + """ + try: + return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def remove_instrument(instrument: Instrument) -> None: + """Stop instrumenting the current run loop with the given instrument. + + Args: + instrument (trio.abc.Instrument): The instrument to de-activate. + + Raises: + KeyError: if the instrument is not currently active. This could + occur either because you never added it, or because you added it + and then it raised an unhandled exception and was automatically + deactivated. 
+ + """ + try: + return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument) + except AttributeError: + raise RuntimeError("must be called from async context") from None diff --git a/contrib/python/trio/trio/_core/_generated_io_epoll.py b/contrib/python/trio/trio/_core/_generated_io_epoll.py new file mode 100644 index 00000000000..41cbb406502 --- /dev/null +++ b/contrib/python/trio/trio/_core/_generated_io_epoll.py @@ -0,0 +1,98 @@ +# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING + +from ._ki import enable_ki_protection +from ._run import GLOBAL_RUN_CONTEXT + +if TYPE_CHECKING: + from .._file_io import _HasFileNo + +assert not TYPE_CHECKING or sys.platform == "linux" + + +__all__ = ["notify_closing", "wait_readable", "wait_writable"] + + +@enable_ki_protection +async def wait_readable(fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``fd`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``fd`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. 
+ """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +async def wait_writable(fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``fd``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def notify_closing(fd: int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. 
+ """ + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None diff --git a/contrib/python/trio/trio/_core/_generated_io_kqueue.py b/contrib/python/trio/trio/_core/_generated_io_kqueue.py new file mode 100644 index 00000000000..556d29e1f26 --- /dev/null +++ b/contrib/python/trio/trio/_core/_generated_io_kqueue.py @@ -0,0 +1,153 @@ +# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING + +from ._ki import enable_ki_protection +from ._run import GLOBAL_RUN_CONTEXT + +if TYPE_CHECKING: + import select + from collections.abc import Callable + from contextlib import AbstractContextManager + + from .. import _core + from .._file_io import _HasFileNo + from ._traps import Abort, RaiseCancelT + +assert not TYPE_CHECKING or sys.platform == "darwin" + + +__all__ = [ + "current_kqueue", + "monitor_kevent", + "notify_closing", + "wait_kevent", + "wait_readable", + "wait_writable", +] + + +@enable_ki_protection +def current_kqueue() -> select.kqueue: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__. + """ + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def monitor_kevent( + ident: int, filter: int +) -> AbstractContextManager[_core.UnboundedQueue[select.kevent]]: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__. 
+ """ + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +async def wait_kevent( + ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort] +) -> Abort: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__. + """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent( + ident, filter, abort_func + ) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +async def wait_readable(fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``fd`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``fd`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. 
+ """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +async def wait_writable(fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``fd``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def notify_closing(fd: int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. 
+ """ + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None diff --git a/contrib/python/trio/trio/_core/_generated_io_windows.py b/contrib/python/trio/trio/_core/_generated_io_windows.py new file mode 100644 index 00000000000..211f81215c5 --- /dev/null +++ b/contrib/python/trio/trio/_core/_generated_io_windows.py @@ -0,0 +1,204 @@ +# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING + +from ._ki import enable_ki_protection +from ._run import GLOBAL_RUN_CONTEXT + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + from typing_extensions import Buffer + + from .._file_io import _HasFileNo + from ._unbounded_queue import UnboundedQueue + from ._windows_cffi import CData, Handle + +assert not TYPE_CHECKING or sys.platform == "win32" + + +__all__ = [ + "current_iocp", + "monitor_completion_key", + "notify_closing", + "readinto_overlapped", + "register_with_iocp", + "wait_overlapped", + "wait_readable", + "wait_writable", + "write_overlapped", +] + + +@enable_ki_protection +async def wait_readable(sock: _HasFileNo | int) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``sock`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``sock`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. 
File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +async def wait_writable(sock: _HasFileNo | int) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``sock``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def notify_closing(handle: Handle | int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. 
Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. + """ + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def register_with_iocp(handle: int | CData) -> None: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. + """ + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> object: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. + """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped( + handle_, lpOverlapped + ) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +async def write_overlapped( + handle: int | CData, data: Buffer, file_offset: int = 0 +) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. 
+ """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped( + handle, data, file_offset + ) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +async def readinto_overlapped( + handle: int | CData, buffer: Buffer, file_offset: int = 0 +) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. + """ + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped( + handle, buffer, file_offset + ) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def current_iocp() -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. + """ + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def monitor_completion_key() -> ( + AbstractContextManager[tuple[int, UnboundedQueue[object]]] +): + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. 
+ """ + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() + except AttributeError: + raise RuntimeError("must be called from async context") from None diff --git a/contrib/python/trio/trio/_core/_generated_run.py b/contrib/python/trio/trio/_core/_generated_run.py new file mode 100644 index 00000000000..db1454e6c76 --- /dev/null +++ b/contrib/python/trio/trio/_core/_generated_run.py @@ -0,0 +1,269 @@ +# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ._ki import enable_ki_protection +from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT, RunStatistics, Task + +if TYPE_CHECKING: + import contextvars + from collections.abc import Awaitable, Callable + + from outcome import Outcome + from typing_extensions import Unpack + + from .._abc import Clock + from ._entry_queue import TrioToken + from ._run import PosArgT + + +__all__ = [ + "current_clock", + "current_root_task", + "current_statistics", + "current_time", + "current_trio_token", + "reschedule", + "spawn_system_task", + "wait_all_tasks_blocked", +] + + +@enable_ki_protection +def current_statistics() -> RunStatistics: + """Returns ``RunStatistics``, which contains run-loop-level debugging information. + + Currently, the following fields are defined: + + * ``tasks_living`` (int): The number of tasks that have been spawned + and not yet exited. + * ``tasks_runnable`` (int): The number of tasks that are currently + queued on the run queue (as opposed to blocked waiting for something + to happen). + * ``seconds_to_next_deadline`` (float): The time until the next + pending cancel scope deadline. May be negative if the deadline has + expired but we haven't yet processed cancellations. May be + :data:`~math.inf` if there are no pending deadlines. 
+ * ``run_sync_soon_queue_size`` (int): The number of + unprocessed callbacks queued via + :meth:`trio.lowlevel.TrioToken.run_sync_soon`. + * ``io_statistics`` (object): Some statistics from Trio's I/O + backend. This always has an attribute ``backend`` which is a string + naming which operating-system-specific I/O backend is in use; the + other attributes vary between backends. + + """ + try: + return GLOBAL_RUN_CONTEXT.runner.current_statistics() + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def current_time() -> float: + """Returns the current time according to Trio's internal clock. + + Returns: + float: The current time. + + Raises: + RuntimeError: if not inside a call to :func:`trio.run`. + + """ + try: + return GLOBAL_RUN_CONTEXT.runner.current_time() + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def current_clock() -> Clock: + """Returns the current :class:`~trio.abc.Clock`.""" + try: + return GLOBAL_RUN_CONTEXT.runner.current_clock() + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def current_root_task() -> Task | None: + """Returns the current root :class:`Task`. + + This is the task that is the ultimate parent of all other tasks. + + """ + try: + return GLOBAL_RUN_CONTEXT.runner.current_root_task() + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def reschedule(task: Task, next_send: Outcome[object] = _NO_SEND) -> None: + """Reschedule the given task with the given + :class:`outcome.Outcome`. + + See :func:`wait_task_rescheduled` for the gory details. + + There must be exactly one call to :func:`reschedule` for every call to + :func:`wait_task_rescheduled`. 
(And when counting, keep in mind that + returning :data:`Abort.SUCCEEDED` from an abort callback is equivalent + to calling :func:`reschedule` once.) + + Args: + task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked + in a call to :func:`wait_task_rescheduled`. + next_send (outcome.Outcome): the value (or error) to return (or + raise) from :func:`wait_task_rescheduled`. + + """ + try: + return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def spawn_system_task( + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + *args: Unpack[PosArgT], + name: object = None, + context: contextvars.Context | None = None, +) -> Task: + """Spawn a "system" task. + + System tasks have a few differences from regular tasks: + + * They don't need an explicit nursery; instead they go into the + internal "system nursery". + + * If a system task raises an exception, then it's converted into a + :exc:`~trio.TrioInternalError` and *all* tasks are cancelled. If you + write a system task, you should be careful to make sure it doesn't + crash. + + * System tasks are automatically cancelled when the main task exits. + + * By default, system tasks have :exc:`KeyboardInterrupt` protection + *enabled*. If you want your task to be interruptible by control-C, + then you need to use :func:`disable_ki_protection` explicitly (and + come up with some plan for what to do with a + :exc:`KeyboardInterrupt`, given that system tasks aren't allowed to + raise exceptions). + + * System tasks do not inherit context variables from their creator. + + Towards the end of a call to :meth:`trio.run`, after the main + task and all system tasks have exited, the system nursery + becomes closed. At this point, new calls to + :func:`spawn_system_task` will raise ``RuntimeError("Nursery + is closed to new arrivals")`` instead of creating a system + task. 
It's possible to encounter this state either in + a ``finally`` block in an async generator, or in a callback + passed to :meth:`TrioToken.run_sync_soon` at the right moment. + + Args: + async_fn: An async callable. + args: Positional arguments for ``async_fn``. If you want to pass + keyword arguments, use :func:`functools.partial`. + name: The name for this task. Only used for debugging/introspection + (e.g. ``repr(task_obj)``). If this isn't a string, + :func:`spawn_system_task` will try to make it one. A common use + case is if you're wrapping a function before spawning a new + task, you might pass the original function as the ``name=`` to + make debugging easier. + context: An optional ``contextvars.Context`` object with context variables + to use for this task. You would normally get a copy of the current + context with ``context = contextvars.copy_context()`` and then you would + pass that ``context`` object here. + + Returns: + Task: the newly spawned task + + """ + try: + return GLOBAL_RUN_CONTEXT.runner.spawn_system_task( + async_fn, *args, name=name, context=context + ) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +def current_trio_token() -> TrioToken: + """Retrieve the :class:`TrioToken` for the current call to + :func:`trio.run`. + + """ + try: + return GLOBAL_RUN_CONTEXT.runner.current_trio_token() + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +@enable_ki_protection +async def wait_all_tasks_blocked(cushion: float = 0.0) -> None: + """Block until there are no runnable tasks. + + This is useful in testing code when you want to give other tasks a + chance to "settle down". The calling task is blocked, and doesn't wake + up until all other tasks are also blocked for at least ``cushion`` + seconds. 
(Setting a non-zero ``cushion`` is intended to handle cases + like two tasks talking to each other over a local socket, where we + want to ignore the potential brief moment between a send and receive + when all tasks are blocked.) + + Note that ``cushion`` is measured in *real* time, not the Trio clock + time. + + If there are multiple tasks blocked in :func:`wait_all_tasks_blocked`, + then the one with the shortest ``cushion`` is the one woken (and + this task becoming unblocked resets the timers for the remaining + tasks). If there are multiple tasks that have exactly the same + ``cushion``, then all are woken. + + You should also consider :class:`trio.testing.Sequencer`, which + provides a more explicit way to control execution ordering within a + test, and will often produce more readable tests. + + Example: + Here's an example of one way to test that Trio's locks are fair: we + take the lock in the parent, start a child, wait for the child to be + blocked waiting for the lock (!), and then check that we can't + release and immediately re-acquire the lock:: + + async def lock_taker(lock): + await lock.acquire() + lock.release() + + async def test_lock_fairness(): + lock = trio.Lock() + await lock.acquire() + async with trio.open_nursery() as nursery: + nursery.start_soon(lock_taker, lock) + # child hasn't run yet, we have the lock + assert lock.locked() + assert lock._owner is trio.lowlevel.current_task() + await trio.testing.wait_all_tasks_blocked() + # now the child has run and is blocked on lock.acquire(), we + # still have the lock + assert lock.locked() + assert lock._owner is trio.lowlevel.current_task() + lock.release() + try: + # The child has a prior claim, so we can't have it + lock.acquire_nowait() + except trio.WouldBlock: + assert lock._owner is not trio.lowlevel.current_task() + print("PASS") + else: + print("FAIL") + + """ + try: + return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion) + except AttributeError: + raise 
RuntimeError("must be called from async context") from None diff --git a/contrib/python/trio/trio/_core/_generated_windows_ffi.py b/contrib/python/trio/trio/_core/_generated_windows_ffi.py new file mode 100644 index 00000000000..0178993e375 --- /dev/null +++ b/contrib/python/trio/trio/_core/_generated_windows_ffi.py @@ -0,0 +1,10 @@ +# auto-generated file +import _cffi_backend + +ffi = _cffi_backend.FFI('trio._core._generated_windows_ffi', + _version = 0x2601, + _types = b'\x00\x00\x39\x0D\x00\x00\x1A\x01\x00\x00\x0A\x01\x00\x00\x72\x03\x00\x00\x0A\x01\x00\x00\x03\x11\x00\x00\x0A\x01\x00\x00\x02\x03\x00\x00\x6D\x03\x00\x00\x03\x11\x00\x00\x00\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x00\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x04\x01\x00\x00\x00\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x0A\x01\x00\x00\x03\x11\x00\x00\x0A\x01\x00\x00\x03\x11\x00\x00\x0A\x01\x00\x00\x07\x11\x00\x00\x08\x11\x00\x00\x00\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x03\x11\x00\x00\x0A\x01\x00\x00\x07\x11\x00\x00\x08\x11\x00\x00\x00\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x72\x03\x00\x00\x0A\x01\x00\x00\x07\x11\x00\x00\x08\x11\x00\x00\x00\x0F\x00\x00\x39\x0D\x00\x00\x00\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x02\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x08\x11\x00\x00\x02\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x6E\x03\x00\x00\x0A\x01\x00\x00\x07\x11\x00\x00\x0A\x01\x00\x00\x07\x01\x00\x00\x02\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x07\x01\x00\x00\x02\x0F\x00\x00\x39\x0D\x00\x00\x03\x11\x00\x00\x0A\x01\x00\x00\x1A\x01\x00\x00\x08\x11\x00\x00\x02\x0F\x00\x00\x02\x0D\x00\x00\x08\x01\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x0A\x01\x00\x00\x03\x03\x00\x00\x07\x01\x00\x00\x0A\x01\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x03\x11\x00\x00\x0A\x01\x00\x00\x00\x0F\x00\x00\x03\x0D\x00\x00\x03\x11\x00\x00\x07\x01\x00\x00\x07\x01\x00\x00\x03\x11\x00\x00\x00\x0F\x00\x00\x03\x0D\x00\x00\x73\x03\x00\x00\x0A\x01\x00\x00\x0A\x01\x00\x00\x03\x11\x00\x00\x0A\x01\x00\x00\x0A\x01\x
00\x00\x03\x11\x00\x00\x00\x0F\x00\x00\x03\x0D\x00\x00\x03\x11\x00\x00\x03\x11\x00\x00\x1A\x01\x00\x00\x0A\x01\x00\x00\x02\x0F\x00\x00\x68\x03\x00\x00\x02\x09\x00\x00\x68\x05\x00\x00\x00\x01\x00\x00\x6C\x03\x00\x00\x03\x09\x00\x00\x04\x09\x00\x00\x05\x09\x00\x00\x17\x01\x00\x00\x01\x09\x00\x00\x00\x09\x00\x00\x00\x01\x00\x00\x10\x01', + _globals = (b'\x00\x00\x2F\x23CancelIoEx',0,b'\x00\x00\x2C\x23CloseHandle',0,b'\x00\x00\x52\x23CreateEventA',0,b'\x00\x00\x58\x23CreateFileW',0,b'\x00\x00\x61\x23CreateIoCompletionPort',0,b'\x00\x00\x12\x23DeviceIoControl',0,b'\x00\x00\x33\x23GetQueuedCompletionStatusEx',0,b'\x00\x00\x3F\x23PostQueuedCompletionStatus',0,b'\x00\x00\x1C\x23ReadFile',0,b'\x00\x00\x0B\x23ResetEvent',0,b'\x00\x00\x45\x23RtlNtStatusToDosError',0,b'\x00\x00\x3B\x23SetConsoleCtrlHandler',0,b'\x00\x00\x0B\x23SetEvent',0,b'\x00\x00\x0E\x23SetFileCompletionNotificationModes',0,b'\x00\x00\x2A\x23WSAGetLastError',0,b'\x00\x00\x00\x23WSAIoctl',0,b'\x00\x00\x48\x23WaitForMultipleObjects',0,b'\x00\x00\x4E\x23WaitForSingleObject',0,b'\x00\x00\x23\x23WriteFile',0), + _struct_unions = ((b'\x00\x00\x00\x71\x00\x00\x00\x03$1',b'\x00\x00\x70\x11DUMMYSTRUCTNAME',b'\x00\x00\x03\x11Pointer'),(b'\x00\x00\x00\x70\x00\x00\x00\x02$2',b'\x00\x00\x02\x11Offset',b'\x00\x00\x02\x11OffsetHigh'),(b'\x00\x00\x00\x68\x00\x00\x00\x02_AFD_POLL_HANDLE_INFO',b'\x00\x00\x03\x11Handle',b'\x00\x00\x02\x11Events',b'\x00\x00\x46\x11Status'),(b'\x00\x00\x00\x6C\x00\x00\x00\x02_AFD_POLL_INFO',b'\x00\x00\x6F\x11Timeout',b'\x00\x00\x02\x11NumberOfHandles',b'\x00\x00\x02\x11Exclusive',b'\x00\x00\x69\x11Handles'),(b'\x00\x00\x00\x6D\x00\x00\x00\x02_OVERLAPPED',b'\x00\x00\x01\x11Internal',b'\x00\x00\x01\x11InternalHigh',b'\x00\x00\x71\x11DUMMYUNIONNAME',b'\x00\x00\x03\x11hEvent'),(b'\x00\x00\x00\x6E\x00\x00\x00\x02_OVERLAPPED_ENTRY',b'\x00\x00\x01\x11lpCompletionKey',b'\x00\x00\x08\x11lpOverlapped',b'\x00\x00\x01\x11Internal',b'\x00\x00\x02\x11dwNumberOfBytesTransferred')), + _typenames = 
(b'\x00\x00\x00\x68AFD_POLL_HANDLE_INFO',b'\x00\x00\x00\x6CAFD_POLL_INFO',b'\x00\x00\x00\x39BOOL',b'\x00\x00\x00\x10BOOLEAN',b'\x00\x00\x00\x10BYTE',b'\x00\x00\x00\x02DWORD',b'\x00\x00\x00\x03HANDLE',b'\x00\x00\x00\x6FLARGE_INTEGER',b'\x00\x00\x00\x03LPCSTR',b'\x00\x00\x00\x25LPCVOID',b'\x00\x00\x00\x59LPCWSTR',b'\x00\x00\x00\x07LPDWORD',b'\x00\x00\x00\x08LPOVERLAPPED',b'\x00\x00\x00\x35LPOVERLAPPED_ENTRY',b'\x00\x00\x00\x03LPSECURITY_ATTRIBUTES',b'\x00\x00\x00\x03LPVOID',b'\x00\x00\x00\x08LPWSAOVERLAPPED',b'\x00\x00\x00\x46NTSTATUS',b'\x00\x00\x00\x6DOVERLAPPED',b'\x00\x00\x00\x6EOVERLAPPED_ENTRY',b'\x00\x00\x00\x67PAFD_POLL_HANDLE_INFO',b'\x00\x00\x00\x6BPAFD_POLL_INFO',b'\x00\x00\x00\x07PULONG',b'\x00\x00\x00\x03PVOID',b'\x00\x00\x00\x01SOCKET',b'\x00\x00\x00\x10UCHAR',b'\x00\x00\x00\x01UINT_PTR',b'\x00\x00\x00\x02ULONG',b'\x00\x00\x00\x01ULONG_PTR',b'\x00\x00\x00\x6DWSAOVERLAPPED',b'\x00\x00\x00\x02u_long'), +) diff --git a/contrib/python/trio/trio/_core/_instrumentation.py b/contrib/python/trio/trio/_core/_instrumentation.py new file mode 100644 index 00000000000..f2a106e29b4 --- /dev/null +++ b/contrib/python/trio/trio/_core/_instrumentation.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import logging +import types +from collections import UserDict +from typing import TYPE_CHECKING, TypeVar + +from .._abc import Instrument + +# Used to log exceptions in instruments +INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument") + +if TYPE_CHECKING: + from collections.abc import Sequence + + T = TypeVar("T") + + +# Decorator to mark methods public. This does nothing by itself, but +# trio/_tools/gen_exports.py looks for it. +def _public(fn: T) -> T: + return fn + + +class Instruments(UserDict[str, dict[Instrument, None]]): + """A collection of `trio.abc.Instrument` organized by hook. 
+ + Instrumentation calls are rather expensive, and we don't want a + rarely-used instrument (like before_run()) to slow down hot + operations (like before_task_step()). Thus, we cache the set of + instruments to be called for each hook, and skip the instrumentation + call if there's nothing currently installed for that hook. + """ + + __slots__ = () + + def __init__(self, incoming: Sequence[Instrument]) -> None: + super().__init__({"_all": {}}) + for instrument in incoming: + self.add_instrument(instrument) + + @_public + def add_instrument(self, instrument: Instrument) -> None: + """Start instrumenting the current run loop with the given instrument. + + Args: + instrument (trio.abc.Instrument): The instrument to activate. + + If ``instrument`` is already active, does nothing. + + """ + if instrument in self.data["_all"]: + return + self.data["_all"][instrument] = None + try: + for name in dir(instrument): + if name.startswith("_"): + continue + try: + prototype = getattr(Instrument, name) + except AttributeError: + continue + impl = getattr(instrument, name) + if isinstance(impl, types.MethodType) and impl.__func__ is prototype: + # Inherited unchanged from _abc.Instrument + continue + self.data.setdefault(name, {})[instrument] = None + except: + self.remove_instrument(instrument) + raise + + @_public + def remove_instrument(self, instrument: Instrument) -> None: + """Stop instrumenting the current run loop with the given instrument. + + Args: + instrument (trio.abc.Instrument): The instrument to de-activate. + + Raises: + KeyError: if the instrument is not currently active. This could + occur either because you never added it, or because you added it + and then it raised an unhandled exception and was automatically + deactivated. 
+ + """ + # If instrument isn't present, the KeyError propagates out + self.data["_all"].pop(instrument) + for hookname, instruments in list(self.data.items()): + if instrument in instruments: + del instruments[instrument] + if not instruments: + del self.data[hookname] + + def call( + self, + hookname: str, + *args: object, + ) -> None: + """Call hookname(*args) on each applicable instrument. + + You must first check whether there are any instruments installed for + that hook, e.g.:: + + if "before_task_step" in instruments: + instruments.call("before_task_step", task) + """ + for instrument in list(self.data[hookname]): + try: + getattr(instrument, hookname)(*args) + except BaseException: + self.remove_instrument(instrument) + INSTRUMENT_LOGGER.exception( + "Exception raised when calling %r on instrument %r. " + "Instrument has been disabled.", + hookname, + instrument, + ) diff --git a/contrib/python/trio/trio/_core/_io_common.py b/contrib/python/trio/trio/_core/_io_common.py new file mode 100644 index 00000000000..14cd9d33e63 --- /dev/null +++ b/contrib/python/trio/trio/_core/_io_common.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING + +import outcome + +from .. 
# Utility function shared between _io_epoll and _io_windows
def wake_all(waiters: EpollWaiters | AFDWaiters, exc: BaseException) -> None:
    """Wake the reader and writer recorded on ``waiters``, delivering ``exc``.

    Each non-``None`` task in ``waiters.read_task`` / ``waiters.write_task``
    is cleared. Tasks other than the currently running one are rescheduled
    with a shallow copy of ``exc``. If the running task is itself one of the
    waiters, ``exc`` is raised directly, but only after both slots have been
    processed.

    If ``_core.current_task()`` raises ``RuntimeError``, no task is treated
    as the current one.
    """
    try:
        running = _core.current_task()
    except RuntimeError:
        running = None
    woke_self = False
    for slot in ("read_task", "write_task"):
        parked = getattr(waiters, slot)
        if parked is None:
            continue
        if parked is running:
            # The running task cannot be rescheduled; raise once the loop
            # has finished waking the other waiter.
            woke_self = True
        else:
            _core.reschedule(parked, outcome.Error(copy.copy(exc)))
        setattr(waiters, slot, None)
    if woke_self:
        raise exc
When you call epoll_ctl, you +# pass in an fd; that gets converted to an (fd number, file object) tuple by +# looking up the fd in the process's fd table at the time of the call. When an +# event happens on the file object, epoll_wait drops the file object part, and +# just returns the fd number in its event. So from the outside it looks like +# it's keeping a table of fds, but really it's a bit more complicated. This +# has some subtle consequences. +# +# In general, file objects inside the kernel are reference counted. Each entry +# in a process's fd table holds a strong reference to the corresponding file +# object, and most operations that use file objects take a temporary strong +# reference while they're working. So when you call close() on an fd, that +# might or might not cause the file object to be deallocated -- it depends on +# whether there are any other references to that file object. Some common ways +# this can happen: +# +# - after calling dup(), you have two fds in the same process referring to the +# same file object. Even if you close one fd (= remove that entry from the +# fd table), the file object will be kept alive by the other fd. +# - when calling fork(), the child inherits a copy of the parent's fd table, +# so all the file objects get another reference. (But if the fork() is +# followed by exec(), then all of the child's fds that have the CLOEXEC flag +# set will be closed at that point.) +# - most syscalls that work on fds take a strong reference to the underlying +# file object while they're using it. So if there's one thread blocked in +# read(fd), and then another thread calls close() on the last fd referring +# to that object, the underlying file won't actually be closed until +# after read() returns. +# +# However, epoll does *not* take a reference to any of the file objects in its +# interest set (that's what makes it similar to a WeakKeyDictionary). 
File +# objects inside an epoll interest set will be deallocated if all *other* +# references to them are closed. And when that happens, the epoll object will +# automatically deregister that file object and stop reporting events on it. +# So that's quite handy. +# +# But, what happens if we do this? +# +# fd1 = open(...) +# epoll_ctl(EPOLL_CTL_ADD, fd1, ...) +# fd2 = dup(fd1) +# close(fd1) +# +# In this case, the dup() keeps the underlying file object alive, so it +# remains registered in the epoll object's interest set, as the tuple (fd1, +# file object). But, fd1 no longer refers to this file object! You might think +# there was some magic to handle this, but unfortunately no; the consequences +# are totally predictable from what I said above: +# +# If any events occur on the file object, then epoll will report them as +# happening on fd1, even though that doesn't make sense. +# +# Perhaps we would like to deregister fd1 to stop getting nonsensical events. +# But how? When we call epoll_ctl, we have to pass an fd number, which will +# get expanded to an (fd number, file object) tuple. We can't pass fd1, +# because when epoll_ctl tries to look it up, it won't find our file object. +# And we can't pass fd2, because that will get expanded to (fd2, file object), +# which is a different lookup key. In fact, it's *impossible* to de-register +# this fd! +# +# We could even have fd1 get assigned to another file object, and then we can +# have multiple keys registered simultaneously using the same fd number, like: +# (fd1, file object 1), (fd1, file object 2). And if events happen on either +# file object, then epoll will happily report that something happened to +# "fd1". +# +# Now here's what makes this especially nasty: suppose the old file object +# becomes, say, readable. That means that every time we call epoll_wait, it +# will return immediately to tell us that "fd1" is readable. 
Normally, we +# would handle this by de-registering fd1, waking up the corresponding call to +# wait_readable, then the user will call read() or recv() or something, and +# we're fine. But if this happens on a stale fd where we can't remove the +# registration, then we might get stuck in a state where epoll_wait *always* +# returns immediately, so our event loop becomes unable to sleep, and now our +# program is burning 100% of the CPU doing nothing, with no way out. +# +# +# What does this mean for Trio? +# ----------------------------- +# +# Since we don't control the user's code, we have no way to guarantee that we +# don't get stuck with stale fd's in our epoll interest set. For example, a +# user could call wait_readable(fd) in one task, and then while that's +# running, they might close(fd) from another task. In this situation, they're +# *supposed* to call notify_closing(fd) to let us know what's happening, so we +# can interrupt the wait_readable() call and avoid getting into this mess. And +# that's the only thing that can possibly work correctly in all cases. But +# sometimes user code has bugs. So if this does happen, we'd like to degrade +# gracefully, and survive without corrupting Trio's internal state or +# otherwise causing the whole program to explode messily. +# +# Our solution: we always use EPOLLONESHOT. This way, we might get *one* +# spurious event on a stale fd, but then epoll will automatically silence it +# until we explicitly say that we want more events... and if we have a stale +# fd, then we actually can't re-enable it! So we can't get stuck in an +# infinite busy-loop. If there's a stale fd hanging around, then it might +# cause a spurious `BusyResourceError`, or cause one wait_* call to return +# before it should have... but in general, the wait_* functions are allowed to +# have some spurious wakeups; the user code will just attempt the operation, +# get EWOULDBLOCK, and call wait_* again. 
And the program as a whole will +# survive, any exceptions will propagate, etc. +# +# As a bonus, EPOLLONESHOT also saves us having to explicitly deregister fds +# on the normal wakeup path, so it's a bit more efficient in general. +# +# However, EPOLLONESHOT has a few trade-offs to consider: +# +# First, you can't combine EPOLLONESHOT with EPOLLEXCLUSIVE. This is a bit sad +# in one somewhat rare case: if you have a multi-process server where a group +# of processes all share the same listening socket, then EPOLLEXCLUSIVE can be +# used to avoid "thundering herd" problems when a new connection comes in. But +# this isn't too bad. It's not clear if EPOLLEXCLUSIVE even works for us +# anyway: +# +# https://stackoverflow.com/questions/41582560/how-does-epolls-epollexclusive-mode-interact-with-level-triggering +# +# And it's not clear that EPOLLEXCLUSIVE is a great approach either: +# +# https://blog.cloudflare.com/the-sad-state-of-linux-socket-balancing/ +# +# And if we do need to support this, we could always add support through some +# more-specialized API in the future. So this isn't a blocker to using +# EPOLLONESHOT. +# +# Second, EPOLLONESHOT does not actually *deregister* the fd after delivering +# an event (EPOLL_CTL_DEL). Instead, it keeps the fd registered, but +# effectively does an EPOLL_CTL_MOD to set the fd's interest flags to +# all-zeros. So we could still end up with an fd hanging around in the +# interest set for a long time, even if we're not using it. +# +# Fortunately, this isn't a problem, because it's only a weak reference – if +# we have a stale fd that's been silenced by EPOLLONESHOT, then it wastes a +# tiny bit of kernel memory remembering this fd that can never be revived, but +# when the underlying file object is eventually closed, that memory will be +# reclaimed. So that's OK. 
+# +# The other issue is that when someone calls wait_*, using EPOLLONESHOT means +# that if we have ever waited for this fd before, we have to use EPOLL_CTL_MOD +# to re-enable it; but if it's a new fd, we have to use EPOLL_CTL_ADD. How do +# we know which one to use? There's no reasonable way to track which fds are +# currently registered -- remember, we're assuming the user might have gone +# and rearranged their fds without telling us! +# +# Fortunately, this also has a simple solution: if we wait on a socket or +# other fd once, then we'll probably wait on it lots of times. And the epoll +# object itself knows which fds it already has registered. So when an fd comes +# in, we optimistically assume that it's been waited on before, and try doing +# EPOLL_CTL_MOD. And if that fails with an ENOENT error, then we try again +# with EPOLL_CTL_ADD. +# +# So that's why this code is the way it is. And now you know more than you +# wanted to about how epoll works. + + [email protected](eq=False) +class EpollIOManager: + # Using lambda here because otherwise crash on import with gevent monkey patching + # See https://github.com/python-trio/trio/issues/2848 + _epoll: select.epoll = attrs.Factory(lambda: select.epoll()) + # {fd: EpollWaiters} + _registered: defaultdict[int, EpollWaiters] = attrs.Factory( + lambda: defaultdict(EpollWaiters), + ) + _force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair) + _force_wakeup_fd: int | None = None + + def __attrs_post_init__(self) -> None: + self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN) + self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() + + def statistics(self) -> _EpollStatistics: + tasks_waiting_read = 0 + tasks_waiting_write = 0 + for waiter in self._registered.values(): + if waiter.read_task is not None: + tasks_waiting_read += 1 + if waiter.write_task is not None: + tasks_waiting_write += 1 + return _EpollStatistics( + tasks_waiting_read=tasks_waiting_read, + 
def process_events(self, events: EventResult) -> None:
    """Handle the ``(fd, flags)`` pairs produced by ``get_events``.

    An event on the force-wakeup fd only drains the wakeup socketpair. For
    any other fd, the recorded interest flags are reset, the waiting writer
    and/or reader selected by ``flags`` is rescheduled and cleared, and the
    epoll registration for that fd is refreshed.
    """
    for ready_fd, event_mask in events:
        if ready_fd == self._force_wakeup_fd:
            self._force_wakeup.drain()
            continue
        entry = self._registered[ready_fd]
        # Delivering an event under EPOLLONESHOT always clears the flags.
        entry.current_flags = 0
        # Trick borrowed from selectors.EpollSelector: any bit other than
        # EPOLLIN wakes the writer and any bit other than EPOLLOUT wakes the
        # reader, so EPOLLHUP / EPOLLERR wake both of them.
        wakes_writer = event_mask & ~select.EPOLLIN
        wakes_reader = event_mask & ~select.EPOLLOUT
        if wakes_writer:
            writer = entry.write_task
            if writer is not None:
                _core.reschedule(writer)
                entry.write_task = None
        if wakes_reader:
            reader = entry.read_task
            if reader is not None:
                _core.reschedule(reader)
                entry.read_task = None
        self._update_registrations(ready_fd)
async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None:
    """Park the current task in the ``attr_name`` slot of ``fd`` until woken.

    ``fd`` is either an integer or an object whose ``fileno()`` supplies
    one. ``attr_name`` names the waiter slot to occupy (``wait_readable``
    passes ``"read_task"`` and ``wait_writable`` passes ``"write_task"``).

    Raises ``BusyResourceError`` if that slot is already taken. If the wait
    is cancelled, the slot is cleared and the epoll registration refreshed.
    """
    fileno = fd if isinstance(fd, int) else fd.fileno()
    slot = self._registered[fileno]
    if getattr(slot, attr_name) is not None:
        raise _core.BusyResourceError(
            "another task is already reading / writing this fd",
        )
    setattr(slot, attr_name, _core.current_task())
    self._update_registrations(fileno)

    def cancel_wait(_: RaiseCancelT) -> Abort:
        # Give up the slot and let epoll know this direction is no longer
        # wanted; the abort always succeeds.
        setattr(slot, attr_name, None)
        self._update_registrations(fileno)
        return _core.Abort.SUCCEEDED

    await _core.wait_task_rescheduled(cancel_wait)
@_public
def notify_closing(self, fd: int | _HasFileNo) -> None:
    """Wake every task waiting on ``fd`` because it is about to be closed.

    Any `wait_readable` or `wait_writable` call pending on the given object
    is woken immediately and raises `~trio.ClosedResourceError`. Call this
    before closing a file descriptor (on Unix) or socket (on Windows).

    Nothing is actually closed here; that is still the caller's job. A new
    task could start waiting between this call and the real close, so the
    usual safe sequence is:

    1. Explicitly mark the object as closed, so that new attempts to use it
       fail before they start.
    2. Call `notify_closing` to wake up the tasks that are already waiting.
    3. Actually close the object.

    The steps may be done in a different order if that is more convenient,
    *but only if* there are no checkpoints between them, so that other
    tasks observe them as a single atomic step.
    """
    fileno = fd if isinstance(fd, int) else fd.fileno()
    waiters = self._registered[fileno]
    closed = _core.ClosedResourceError("another task closed this fd")
    wake_all(waiters, closed)
    del self._registered[fileno]
    try:
        self._epoll.unregister(fileno)
    except (OSError, ValueError):
        # Best effort: failures to unregister the fd are ignored.
        pass
def get_events(self, timeout: float) -> EventResult:
    """Collect all pending kevents, waiting up to ``timeout`` for the first batch.

    ``kqueue.control`` is called repeatedly until it returns fewer events
    than were asked for. Only the first call uses ``timeout``; every
    follow-up call uses a timeout of 0.
    """
    # max_events has to be > 0 for kqueue. Asking for one more than the
    # number of registrations is meant to make a single short batch prove
    # that every pending event has been fetched.
    batch_limit = len(self._registered) + 1
    collected = []
    wait_for = timeout
    while True:
        batch = self._kqueue.control([], batch_limit, wait_for)
        collected.extend(batch)
        if len(batch) < batch_limit:
            return collected
        # A full batch: go around again without blocking.
        # TODO: test this line
        wait_for = 0
@contextmanager
@_public
def monitor_kevent(
    self,
    ident: int,
    filter: int,
) -> Iterator[_core.UnboundedQueue[select.kevent]]:
    """TODO: these are implemented, but are currently more of a sketch than
    anything real. See `#26
    <https://github.com/python-trio/trio/issues/26>`__.

    Context manager that registers a fresh ``UnboundedQueue`` as the
    receiver for the ``(ident, filter)`` pair, yields it, and removes the
    registration again on exit.

    Raises ``BusyResourceError`` if the pair already has a listener.
    """
    registration = (ident, filter)
    if registration in self._registered:
        raise _core.BusyResourceError(
            "attempt to register multiple listeners for same ident/filter pair",
        )
    queue = _core.UnboundedQueue[select.kevent]()
    self._registered[registration] = queue
    try:
        yield queue
    finally:
        # Drop the registration however the with-block was left.
        self._registered.pop(registration)
async def _wait_common(
    self,
    fd: int | _HasFileNo,
    filter: int,
) -> None:
    """Arm a one-shot kevent for ``(fd, filter)`` and wait for it to fire.

    ``fd`` is either an integer or an object whose ``fileno()`` supplies
    one. The wait itself is delegated to ``wait_kevent``; on cancellation
    the kevent is deleted again, tolerating the case where kqueue no longer
    knows about it.
    """
    fileno = fd if isinstance(fd, int) else fd.fileno()
    arm = select.kevent(fileno, filter, select.KQ_EV_ADD | select.KQ_EV_ONESHOT)
    self._kqueue.control([arm], 0)

    def disarm(_: RaiseCancelT) -> Abort:
        removal = select.kevent(fileno, filter, select.KQ_EV_DELETE)
        try:
            self._kqueue.control([removal], 0)
        except OSError as exc:
            # kqueue tracks individual fds (*not* the underlying file
            # object; _io_epoll.py has a long discussion of why that
            # distinction matters) and automatically deregisters an event
            # when its fd is closed. So if kqueue.control reports that it
            # doesn't know this event, the fd was probably closed behind
            # our backs. (Too bad we can't be woken up when that happens
            # instead of finding out after the fact.)
            #
            # FreeBSD reports this using EBADF. macOS uses ENOENT.
            if exc.errno not in (errno.EBADF, errno.ENOENT):  # pragma: no cover
                # As far as we know, this branch can't happen.
                raise
        return _core.Abort.SUCCEEDED

    await self.wait_kevent(fileno, filter, disarm)
@_public
def notify_closing(self, fd: int | _HasFileNo) -> None:
    """Tell tasks waiting on ``fd`` that it is about to be closed.

    Call this before closing a file descriptor (on Unix) or socket (on
    Windows). Every `wait_readable` or `wait_writable` call pending on the
    given object wakes up immediately and raises
    `~trio.ClosedResourceError`.

    The object itself is not closed here; the caller still has to do that
    afterwards. No new task should start waiting on the object between this
    call and the actual close, so the usual order is:

    1. Explicitly mark the object as closed, so that new attempts to use it
       abort before they start.
    2. Call `notify_closing` to wake up the users that already exist.
    3. Actually close the object.

    A different order is fine if that is more convenient, *but only if*
    there are no checkpoints between the steps, so that other tasks see
    them happen as a single atomic step.
    """
    fileno = fd if isinstance(fd, int) else fd.fileno()

    for kq_filter in (select.KQ_FILTER_READ, select.KQ_FILTER_WRITE):
        slot = (fileno, kq_filter)
        waiter = self._registered.get(slot)
        if waiter is None:
            continue
        if type(waiter) is not _core.Task:
            # XX this is an interesting example of a case where being able
            # to close a queue would be useful...
            raise NotImplementedError(
                "can't close an fd that monitor_kevent is using",
            )
        # Remove the kernel registration first, then wake the task with the
        # error and forget about it.
        removal = select.kevent(fileno, kq_filter, select.KQ_EV_DELETE)
        self._kqueue.control([removal], 0)
        closed = _core.ClosedResourceError("another task closed this fd")
        _core.reschedule(waiter, outcome.Error(closed))
        del self._registered[slot]
import _core +from ._io_common import wake_all +from ._run import _public +from ._windows_cffi import ( + INVALID_HANDLE_VALUE, + AFDPollFlags, + CData, + CompletionModes, + CType, + ErrorCodes, + FileFlags, + Handle, + IoControlCodes, + WSAIoctls, + _handle, + _Overlapped, + ffi, + kernel32, + ntdll, + raise_winerror, + ws2_32, +) + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + from typing_extensions import Buffer, TypeAlias + + from .._file_io import _HasFileNo + from ._traps import Abort, RaiseCancelT + from ._unbounded_queue import UnboundedQueue + +EventResult: TypeAlias = int +T = TypeVar("T") + +# There's a lot to be said about the overall design of a Windows event +# loop. See +# +# https://github.com/python-trio/trio/issues/52 +# +# for discussion. This now just has some lower-level notes: +# +# How IOCP fits together: +# +# The general model is that you call some function like ReadFile or WriteFile +# to tell the kernel that you want it to perform some operation, and the +# kernel goes off and does that in the background, then at some point later it +# sends you a notification that the operation is complete. There are some more +# exotic APIs that don't quite fit this pattern, but most APIs do. +# +# Each background operation is tracked using an OVERLAPPED struct, that +# uniquely identifies that particular operation. +# +# An "IOCP" (or "I/O completion port") is an object that lets the kernel send +# us these notifications -- basically it's just a kernel->userspace queue. +# +# Each IOCP notification is represented by an OVERLAPPED_ENTRY struct, which +# contains 3 fields: +# - The "completion key". This is an opaque integer that we pick, and use +# however is convenient. +# - pointer to the OVERLAPPED struct for the completed operation. +# - dwNumberOfBytesTransferred (an integer). 
+# +# And in addition, for regular I/O, the OVERLAPPED structure gets filled in +# with: +# - result code (named "Internal") +# - number of bytes transferred (named "InternalHigh"); usually redundant +# with dwNumberOfBytesTransferred. +# +# There are also some other entries in OVERLAPPED which only matter on input: +# - Offset and OffsetHigh which are inputs to {Read,Write}File and +# otherwise always zero +# - hEvent which is for if you aren't using IOCP; we always set it to zero. +# +# That describes the usual pattern for operations and the usual meaning of +# these struct fields, but really these are just some arbitrary chunks of +# bytes that get passed back and forth, so some operations like to overload +# them to mean something else. +# +# You can also directly queue an OVERLAPPED_ENTRY object to an IOCP by calling +# PostQueuedCompletionStatus. When you use this you get to set all the +# OVERLAPPED_ENTRY fields to arbitrary values. +# +# You can request to cancel any operation if you know which handle it was +# issued on + the OVERLAPPED struct that identifies it (via CancelIoEx). This +# request might fail because the operation has already completed, or it might +# be queued to happen in the background, so you only find out whether it +# succeeded or failed later, when we get back the notification for the +# operation being complete. +# +# There are three types of operations that we support: +# +# == Regular I/O operations on handles (e.g. files or named pipes) == +# +# Implemented by: register_with_iocp, wait_overlapped +# +# To use these, you have to register the handle with your IOCP first. Once +# it's registered, any operations on that handle will automatically send +# completion events to that IOCP, with a completion key that you specify *when +# the handle is registered* (so you can't use different completion keys for +# different operations). 
+# +# We give these two dedicated completion keys: CKeys.WAIT_OVERLAPPED for +# regular operations, and CKeys.LATE_CANCEL that's used to make +# wait_overlapped cancellable even if the user forgot to call +# register_with_iocp. The problem here is that after we request the cancel, +# wait_overlapped keeps blocking until it sees the completion notification... +# but if the user forgot to register_with_iocp, then the completion will never +# come, so the cancellation will never resolve. To avoid this, whenever we try +# to cancel an I/O operation and the cancellation fails, we use +# PostQueuedCompletionStatus to send a CKeys.LATE_CANCEL notification. If this +# arrives before the real completion, we assume the user forgot to call +# register_with_iocp on their handle, and raise an error accordingly. +# +# == Socket state notifications == +# +# Implemented by: wait_readable, wait_writable +# +# The public APIs that windows provides for this are all really awkward and +# don't integrate with IOCP. So we drop down to a lower level, and talk +# directly to the socket device driver in the kernel, which is called "AFD". +# Unfortunately, this is a totally undocumented internal API. Fortunately +# libuv also does this, so we can be pretty confident that MS won't break it +# on us, and there is a *little* bit of information out there if you go +# digging. +# +# Basically: we open a magic file that refers to the AFD driver, register the +# magic file with our IOCP, and then we can issue regular overlapped I/O +# operations on that handle. Specifically, the operation we use is called +# IOCTL_AFD_POLL, which lets us pass in a buffer describing which events we're +# interested in on a given socket (readable, writable, etc.). Later, when the +# operation completes, the kernel rewrites the buffer we passed in to record +# which events happened, and uses IOCP as normal to notify us that this +# operation has completed. 
+# +# Unfortunately, the Windows kernel seems to have bugs if you try to issue +# multiple simultaneous IOCTL_AFD_POLL operations on the same socket (see +# https://github.com/python-trio/trio/wiki/notes-to-self#afd-labpy). +# So if a user calls wait_readable and +# wait_writable at the same time, we have to combine those into a single +# IOCTL_AFD_POLL. This means we can't just use the wait_overlapped machinery. +# Instead we have some dedicated code to handle these operations, and a +# dedicated completion key CKeys.AFD_POLL. +# +# Sources of information: +# - https://github.com/python-trio/trio/issues/52 +# - Wepoll: https://github.com/piscisaureus/wepoll/ +# - libuv: https://github.com/libuv/libuv/ +# - ReactOS: https://github.com/reactos/reactos/ +# - Ancient leaked copies of the Windows NT and Winsock source code: +# https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655 +# https://github.com/metoo10987/WinNT4/blob/f5c14e6b42c8f45c20fe88d14c61f9d6e0386b8e/private/ntos/afd/poll.c#L68-L707 +# - The WSAEventSelect docs (this exposes a finer-grained set of events than +# select(), so if you squint you can treat it as a source of information on +# the fine-grained AFD poll types) +# +# +# == Everything else == +# +# There are also some weirder APIs for interacting with IOCP. For example, the +# "Job" API lets you specify an IOCP handle and "completion key", and then in +# the future whenever certain events happen it uses IOCP to send a +# notification. These notifications don't correspond to any particular +# operation; they're just spontaneous messages you get. The +# "dwNumberOfBytesTransferred" field gets repurposed to carry an identifier +# for the message type (e.g. JOB_OBJECT_MSG_EXIT_PROCESS), and the +# "lpOverlapped" field gets repurposed to carry some arbitrary data that +# depends on the message type (e.g. the pid of the process that exited).
+# +# To handle these, we have monitor_completion_key, where we hand out an +# unassigned completion key, let users set it up however they want, and then +# get any events that arrive on that key. +# +# (Note: monitor_completion_key is not documented or fully baked; expect it to +# change in the future.) + + +# Our completion keys +class CKeys(enum.IntEnum): + AFD_POLL = 0 + WAIT_OVERLAPPED = 1 + LATE_CANCEL = 2 + FORCE_WAKEUP = 3 + USER_DEFINED = 4 # and above + + +# AFD_POLL has a finer-grained set of events than other APIs. We collapse them +# down into Unix-style "readable" and "writable". +# +# Note: AFD_POLL_LOCAL_CLOSE isn't a reliable substitute for notify_closing(), +# because even if the user closes the socket *handle*, the socket *object* +# could still remain open, e.g. if the socket was dup'ed (possibly into +# another process). Explicitly calling notify_closing() guarantees that +# everyone waiting on the *handle* wakes up, which is what you'd expect. +# +# However, we can't avoid getting LOCAL_CLOSE notifications -- the kernel +# delivers them whether we ask for them or not -- so better to include them +# here for documentation, and so that when we check (delivered & requested) we +# get a match. + +READABLE_FLAGS = ( + AFDPollFlags.AFD_POLL_RECEIVE + | AFDPollFlags.AFD_POLL_ACCEPT + | AFDPollFlags.AFD_POLL_DISCONNECT # other side sent an EOF + | AFDPollFlags.AFD_POLL_ABORT + | AFDPollFlags.AFD_POLL_LOCAL_CLOSE +) + +WRITABLE_FLAGS = ( + AFDPollFlags.AFD_POLL_SEND + | AFDPollFlags.AFD_POLL_CONNECT_FAIL + | AFDPollFlags.AFD_POLL_ABORT + | AFDPollFlags.AFD_POLL_LOCAL_CLOSE +) + + +# Annoyingly, while the API makes it *seem* like you can happily issue as many +# independent AFD_POLL operations as you want without them interfering with +# each other, in fact if you issue two AFD_POLL operations for the same socket +# at the same time with notification going to the same IOCP port, then Windows +# gets super confused. 
For example, if we issue one operation from +# wait_readable, and another independent operation from wait_writable, then +# Windows may complete the wait_writable operation when the socket becomes +# readable. +# +# To avoid this, we have to coalesce all the operations on a single socket +# into one, and when the set of waiters changes we have to throw away the old +# operation and start a new one. [email protected](eq=False) +class AFDWaiters: + read_task: _core.Task | None = None + write_task: _core.Task | None = None + current_op: AFDPollOp | None = None + + +# Just used for internal type checking. +class _AFDHandle(Protocol): + Handle: Handle + Status: int + Events: int + + +# Just used for internal type checking. +class _AFDPollInfo(Protocol): + Timeout: int + NumberOfHandles: int + Exclusive: int + Handles: list[_AFDHandle] + + +# We also need to bundle up all the info for a single op into a standalone +# object, because we need to keep all these objects alive until the operation +# finishes, even if we're throwing it away. [email protected](eq=False) +class AFDPollOp: + lpOverlapped: CData + poll_info: _AFDPollInfo + waiters: AFDWaiters + afd_group: AFDGroup + + +# The Windows kernel has a weird issue when using AFD handles. If you have N +# instances of wait_readable/wait_writable registered with a single AFD handle, +# then cancelling any one of them takes something like O(N**2) time. So if we +# used just a single AFD handle, then cancellation would quickly become very +# expensive, e.g. a program with N active sockets would take something like +# O(N**3) time to unwind after control-C. The solution is to spread our sockets +# out over multiple AFD handles, so that N doesn't grow too large for any +# individual handle. 
+MAX_AFD_GROUP_SIZE = 500 # at 1000, the cubic scaling is just starting to bite + + [email protected](eq=False) +class AFDGroup: + size: int + handle: Handle + + +assert not TYPE_CHECKING or sys.platform == "win32" + + [email protected](eq=False) +class _WindowsStatistics: + tasks_waiting_read: int + tasks_waiting_write: int + tasks_waiting_overlapped: int + completion_key_monitors: int + backend: Literal["windows"] = attrs.field(init=False, default="windows") + + +# Maximum number of events to dequeue from the completion port on each pass +# through the run loop. Somewhat arbitrary. Should be large enough to collect +# a good set of tasks on each loop, but not so large to waste tons of memory. +# (Each WindowsIOManager holds a buffer whose size is ~32x this number.) +MAX_EVENTS = 1000 + + +def _check(success: T) -> T: + if not success: + raise_winerror() + return success + + +def _get_underlying_socket( + sock: _HasFileNo | int | Handle, + *, + which: WSAIoctls = WSAIoctls.SIO_BASE_HANDLE, +) -> Handle: + if hasattr(sock, "fileno"): + sock = sock.fileno() + base_ptr = ffi.new("HANDLE *") + out_size = ffi.new("DWORD *") + failed = ws2_32.WSAIoctl( + ffi.cast("SOCKET", sock), + which, + ffi.NULL, + 0, + base_ptr, + ffi.sizeof("HANDLE"), + out_size, + ffi.NULL, + ffi.NULL, + ) + if failed: + code = ws2_32.WSAGetLastError() + raise_winerror(code) + return Handle(base_ptr[0]) + + +def _get_base_socket(sock: _HasFileNo | int | Handle) -> Handle: + # There is a development kit for LSPs called Komodia Redirector. + # It does some unusual (some might say evil) things like intercepting + # SIO_BASE_HANDLE (fails) and SIO_BSP_HANDLE_SELECT (returns the same + # socket) in a misguided attempt to prevent bypassing it. It's been used + # in malware including the infamous Lenovo Superfish incident from 2015, + # but unfortunately is also used in some legitimate products such as + # parental control tools and Astrill VPN. 
Komodia happens to not + # block SIO_BSP_HANDLE_POLL, so we'll try SIO_BASE_HANDLE and fall back + # to SIO_BSP_HANDLE_POLL if it doesn't work. + # References: + # - https://github.com/piscisaureus/wepoll/blob/0598a791bf9cbbf480793d778930fc635b044980/wepoll.c#L2223 + # - https://github.com/tokio-rs/mio/issues/1314 + + while True: + try: + # If this is not a Komodia-intercepted socket, we can just use + # SIO_BASE_HANDLE. + return _get_underlying_socket(sock) + except OSError as ex: + if ex.winerror == ErrorCodes.ERROR_NOT_SOCKET: + # SIO_BASE_HANDLE might fail even without LSP intervention, + # if we get something that's not a socket. + raise + if hasattr(sock, "fileno"): + sock = sock.fileno() + sock = _handle(sock) + next_sock = _get_underlying_socket( + sock, + which=WSAIoctls.SIO_BSP_HANDLE_POLL, + ) + if next_sock == sock: + # If BSP_HANDLE_POLL returns the same socket we already had, + # then there's no layering going on and we need to fail + # to prevent an infinite loop. + raise RuntimeError( + "Unexpected network configuration detected: " + "SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't " + "return a different socket. Please file a bug at " + "https://github.com/python-trio/trio/issues/new, " + "and include the output of running: " + "netsh winsock show catalog", + ) from ex + # Otherwise we've gotten at least one layer deeper, so + # loop back around to keep digging. + sock = next_sock + + +def _afd_helper_handle() -> Handle: + # The "AFD" driver is exposed at the NT path "\Device\Afd". We're using + # the Win32 CreateFile, though, so we have to pass a Win32 path. \\.\ is + # how Win32 refers to the NT \GLOBAL??\ directory, and GLOBALROOT is a + # symlink inside that directory that points to the root of the NT path + # system. So by sticking that in front of the NT path, we get a Win32 + # path. Alternatively, we could use NtCreateFile directly, since it takes + # an NT path. But we already wrap CreateFileW so this was easier. 
+ # References: + # https://blogs.msdn.microsoft.com/jeremykuhne/2016/05/02/dos-to-nt-a-paths-journey/ + # https://stackoverflow.com/a/21704022 + # + # I'm actually not sure what the \Trio part at the end of the path does. + # Wepoll uses \Device\Afd\Wepoll, so I just copied them. (I'm guessing it + # might be visible in some debug tools, and is otherwise arbitrary?) + rawname = r"\\.\GLOBALROOT\Device\Afd\Trio".encode("utf-16le") + b"\0\0" + rawname_buf = ffi.from_buffer(rawname) + + handle = kernel32.CreateFileW( + ffi.cast("LPCWSTR", rawname_buf), + FileFlags.SYNCHRONIZE, + FileFlags.FILE_SHARE_READ | FileFlags.FILE_SHARE_WRITE, + ffi.NULL, # no security attributes + FileFlags.OPEN_EXISTING, + FileFlags.FILE_FLAG_OVERLAPPED, + ffi.NULL, # no template file + ) + if handle == INVALID_HANDLE_VALUE: # pragma: no cover + raise_winerror() + return handle + + [email protected](slots=False) +class CompletionKeyEventInfo: + lpOverlapped: CData | int + dwNumberOfBytesTransferred: int + + +class WindowsIOManager: + def __init__(self) -> None: + # If this method raises an exception, then __del__ could run on a + # half-initialized object. So we initialize everything that __del__ + # touches to safe values up front, before we do anything that can + # fail. 
+ self._iocp = None + self._all_afd_handles: list[Handle] = [] + + self._iocp = _check( + kernel32.CreateIoCompletionPort(INVALID_HANDLE_VALUE, ffi.NULL, 0, 0), + ) + self._events = ffi.new("OVERLAPPED_ENTRY[]", MAX_EVENTS) + + self._vacant_afd_groups: set[AFDGroup] = set() + # {lpOverlapped: AFDPollOp} + self._afd_ops: dict[CData, AFDPollOp] = {} + # {socket handle: AFDWaiters} + self._afd_waiters: dict[Handle, AFDWaiters] = {} + + # {lpOverlapped: task} + self._overlapped_waiters: dict[CData, _core.Task] = {} + self._posted_too_late_to_cancel: set[CData] = set() + + self._completion_key_queues: dict[int, UnboundedQueue[object]] = {} + self._completion_key_counter = itertools.count(CKeys.USER_DEFINED) + + with socket.socket() as s: + # We assume we're not working with any LSP that changes + # how select() is supposed to work. Validate this by + # ensuring that the result of SIO_BSP_HANDLE_SELECT (the + # LSP-hookable mechanism for "what should I use for + # select()?") matches that of SIO_BASE_HANDLE ("what is + # the real non-hooked underlying socket here?"). + # + # This doesn't work for Komodia-based LSPs; see the comments + # in _get_base_socket() for details. But we have special + # logic for those, so we just skip this check if + # SIO_BASE_HANDLE fails. + + # LSPs can in theory override this, but we believe that it never + # actually happens in the wild (except Komodia) + select_handle = _get_underlying_socket( + s, + which=WSAIoctls.SIO_BSP_HANDLE_SELECT, + ) + try: + # LSPs shouldn't override this... + base_handle = _get_underlying_socket(s, which=WSAIoctls.SIO_BASE_HANDLE) + except OSError: + # But Komodia-based LSPs do anyway, in a way that causes + # a failure with WSAEFAULT. We have special handling for + # them in _get_base_socket(). Make sure it works. + _get_base_socket(s) + else: + if base_handle != select_handle: + raise RuntimeError( + "Unexpected network configuration detected: " + "SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ. 
" + "Please file a bug at " + "https://github.com/python-trio/trio/issues/new, " + "and include the output of running: " + "netsh winsock show catalog", + ) + + def close(self) -> None: + try: + if self._iocp is not None: + iocp = self._iocp + self._iocp = None + _check(kernel32.CloseHandle(iocp)) + finally: + while self._all_afd_handles: + afd_handle = self._all_afd_handles.pop() + _check(kernel32.CloseHandle(afd_handle)) + + def __del__(self) -> None: + self.close() + + def statistics(self) -> _WindowsStatistics: + tasks_waiting_read = 0 + tasks_waiting_write = 0 + for waiter in self._afd_waiters.values(): + if waiter.read_task is not None: + tasks_waiting_read += 1 + if waiter.write_task is not None: + tasks_waiting_write += 1 + return _WindowsStatistics( + tasks_waiting_read=tasks_waiting_read, + tasks_waiting_write=tasks_waiting_write, + tasks_waiting_overlapped=len(self._overlapped_waiters), + completion_key_monitors=len(self._completion_key_queues), + ) + + def force_wakeup(self) -> None: + assert self._iocp is not None + _check( + kernel32.PostQueuedCompletionStatus( + self._iocp, + 0, + CKeys.FORCE_WAKEUP, + ffi.NULL, + ), + ) + + def get_events(self, timeout: float) -> EventResult: + received = ffi.new("PULONG") + milliseconds = round(1000 * timeout) + if timeout > 0 and milliseconds == 0: + milliseconds = 1 + try: + assert self._iocp is not None + _check( + kernel32.GetQueuedCompletionStatusEx( + self._iocp, + self._events, + MAX_EVENTS, + received, + milliseconds, + 0, + ), + ) + except OSError as exc: + if exc.winerror != ErrorCodes.WAIT_TIMEOUT: # pragma: no cover + raise + return 0 + result = received[0] + assert isinstance(result, int) + return result + + def process_events(self, received: EventResult) -> None: + for i in range(received): + entry = self._events[i] + if entry.lpCompletionKey == CKeys.AFD_POLL: + lpo = entry.lpOverlapped + op = self._afd_ops.pop(lpo) + waiters = op.waiters + if waiters.current_op is not op: + # Stale op, nothing to do 
+ pass + else: + waiters.current_op = None + # I don't think this can happen, so if it does let's crash + # and get a debug trace. + if lpo.Internal != 0: # pragma: no cover + code = ntdll.RtlNtStatusToDosError(lpo.Internal) + raise_winerror(code) + flags = op.poll_info.Handles[0].Events + if waiters.read_task and flags & READABLE_FLAGS: + _core.reschedule(waiters.read_task) + waiters.read_task = None + if waiters.write_task and flags & WRITABLE_FLAGS: + _core.reschedule(waiters.write_task) + waiters.write_task = None + self._refresh_afd(op.poll_info.Handles[0].Handle) + elif entry.lpCompletionKey == CKeys.WAIT_OVERLAPPED: + # Regular I/O event, dispatch on lpOverlapped + waiter = self._overlapped_waiters.pop(entry.lpOverlapped) + overlapped = entry.lpOverlapped + transferred = entry.dwNumberOfBytesTransferred + info = CompletionKeyEventInfo( + lpOverlapped=overlapped, + dwNumberOfBytesTransferred=transferred, + ) + _core.reschedule(waiter, Value(info)) + elif entry.lpCompletionKey == CKeys.LATE_CANCEL: + # Post made by a regular I/O event's abort_fn + # after it failed to cancel the I/O. If we still + # have a waiter with this lpOverlapped, we didn't + # get the regular I/O completion and almost + # certainly the user forgot to call + # register_with_iocp. + self._posted_too_late_to_cancel.remove(entry.lpOverlapped) + try: + waiter = self._overlapped_waiters.pop(entry.lpOverlapped) + except KeyError: + # Looks like the actual completion got here before this + # fallback post did -- we're in the "expected" case of + # too-late-to-cancel, where the user did nothing wrong. + # Nothing more to do. + pass + else: + exc = _core.TrioInternalError( + f"Failed to cancel overlapped I/O in {waiter.name} and didn't " + "receive the completion either. 
Did you forget to " + "call register_with_iocp()?", + ) + # Raising this out of handle_io ensures that + # the user will see our message even if some + # other task is in an uncancellable wait due + # to the same underlying forgot-to-register + # issue (if their CancelIoEx succeeds, we + # have no way of noticing that their completion + # won't arrive). Unfortunately it loses the + # task traceback. If you're debugging this + # error and can't tell where it's coming from, + # try changing this line to + # _core.reschedule(waiter, outcome.Error(exc)) + raise exc + elif entry.lpCompletionKey == CKeys.FORCE_WAKEUP: + pass + else: + # dispatch on lpCompletionKey + queue = self._completion_key_queues[entry.lpCompletionKey] + overlapped = int(ffi.cast("uintptr_t", entry.lpOverlapped)) + transferred = entry.dwNumberOfBytesTransferred + info = CompletionKeyEventInfo( + lpOverlapped=overlapped, + dwNumberOfBytesTransferred=transferred, + ) + queue.put_nowait(info) + + def _register_with_iocp(self, handle_: int | CData, completion_key: int) -> None: + handle = _handle(handle_) + assert self._iocp is not None + _check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0)) + # Supposedly this makes things slightly faster, by disabling the + # ability to do WaitForSingleObject(handle). We would never want to do + # that anyway, so might as well get the extra speed (if any). 
+ # Ref: http://www.lenholgate.com/blog/2009/09/interesting-blog-posts-on-high-performance-servers.html + _check( + kernel32.SetFileCompletionNotificationModes( + handle, + CompletionModes.FILE_SKIP_SET_EVENT_ON_HANDLE, + ), + ) + + ################################################################ + # AFD stuff + ################################################################ + + def _refresh_afd(self, base_handle: Handle) -> None: + waiters = self._afd_waiters[base_handle] + if waiters.current_op is not None: + afd_group = waiters.current_op.afd_group + try: + _check( + kernel32.CancelIoEx( + afd_group.handle, + waiters.current_op.lpOverlapped, + ), + ) + except OSError as exc: + if exc.winerror != ErrorCodes.ERROR_NOT_FOUND: + # I don't think this is possible, so if it happens let's + # crash noisily. + raise # pragma: no cover + waiters.current_op = None + afd_group.size -= 1 + self._vacant_afd_groups.add(afd_group) + + flags = 0 + if waiters.read_task is not None: + flags |= READABLE_FLAGS + if waiters.write_task is not None: + flags |= WRITABLE_FLAGS + + if not flags: + del self._afd_waiters[base_handle] + else: + try: + afd_group = self._vacant_afd_groups.pop() + except KeyError: + afd_group = AFDGroup(0, _afd_helper_handle()) + self._register_with_iocp(afd_group.handle, CKeys.AFD_POLL) + self._all_afd_handles.append(afd_group.handle) + self._vacant_afd_groups.add(afd_group) + + lpOverlapped = ffi.new("LPOVERLAPPED") + + poll_info = cast("_AFDPollInfo", ffi.new("AFD_POLL_INFO *")) + poll_info.Timeout = 2**63 - 1 # INT64_MAX + poll_info.NumberOfHandles = 1 + poll_info.Exclusive = 0 + poll_info.Handles[0].Handle = base_handle + poll_info.Handles[0].Status = 0 + poll_info.Handles[0].Events = flags + + try: + _check( + kernel32.DeviceIoControl( + afd_group.handle, + IoControlCodes.IOCTL_AFD_POLL, + cast("CType", poll_info), # type: ignore[arg-type] + ffi.sizeof("AFD_POLL_INFO"), + cast("CType", poll_info), # type: ignore[arg-type] + ffi.sizeof("AFD_POLL_INFO"), + 
ffi.NULL, + lpOverlapped, + ), + ) + except OSError as exc: + if exc.winerror != ErrorCodes.ERROR_IO_PENDING: + # This could happen if the socket handle got closed behind + # our back while a wait_* call was pending, and we tried + # to re-issue the call. Clear our state and wake up any + # pending calls. + del self._afd_waiters[base_handle] + # Do this last, because it could raise. + wake_all(waiters, exc) + return + op = AFDPollOp(lpOverlapped, poll_info, waiters, afd_group) + waiters.current_op = op + self._afd_ops[lpOverlapped] = op + afd_group.size += 1 + if afd_group.size >= MAX_AFD_GROUP_SIZE: + self._vacant_afd_groups.remove(afd_group) + + async def _afd_poll(self, sock: _HasFileNo | int, mode: str) -> None: + base_handle = _get_base_socket(sock) + waiters = self._afd_waiters.get(base_handle) + if waiters is None: + waiters = AFDWaiters() + self._afd_waiters[base_handle] = waiters + if getattr(waiters, mode) is not None: + raise _core.BusyResourceError + setattr(waiters, mode, _core.current_task()) + # Could potentially raise if the handle is somehow invalid; that's OK, + # we let it escape. + self._refresh_afd(base_handle) + + def abort_fn(_: RaiseCancelT) -> Abort: + setattr(waiters, mode, None) + self._refresh_afd(base_handle) + return _core.Abort.SUCCEEDED + + await _core.wait_task_rescheduled(abort_fn) + + @_public + async def wait_readable(self, sock: _HasFileNo | int) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``sock`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``sock`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. 
File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + await self._afd_poll(sock, "read_task") + + @_public + async def wait_writable(self, sock: _HasFileNo | int) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``sock``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + await self._afd_poll(sock, "write_task") + + @_public + def notify_closing(self, handle: Handle | int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. 
This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. + """ + handle = _get_base_socket(handle) + waiters = self._afd_waiters.get(handle) + if waiters is not None: + wake_all(waiters, _core.ClosedResourceError()) + self._refresh_afd(handle) + + ################################################################ + # Regular overlapped operations + ################################################################ + + @_public + def register_with_iocp(self, handle: int | CData) -> None: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. + """ + self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED) + + @_public + async def wait_overlapped( + self, + handle_: int | CData, + lpOverlapped: CData | int, + ) -> object: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. + """ + handle = _handle(handle_) + if isinstance(lpOverlapped, int): # TODO: test this line + lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped) + if lpOverlapped in self._overlapped_waiters: # TODO: test this line + raise _core.BusyResourceError( + "another task is already waiting on that lpOverlapped", + ) + task = _core.current_task() + self._overlapped_waiters[lpOverlapped] = task + raise_cancel = None + + def abort(raise_cancel_: RaiseCancelT) -> Abort: + nonlocal raise_cancel + raise_cancel = raise_cancel_ + try: + _check(kernel32.CancelIoEx(handle, lpOverlapped)) + except OSError as exc: + if exc.winerror == ErrorCodes.ERROR_NOT_FOUND: + assert self._iocp is not None + # Too late to cancel. 
If this happens because the + # operation is already completed, we don't need to do + # anything; we'll get a notification of that completion + # soon. But another possibility is that the operation was + # performed on a handle that wasn't registered with our + # IOCP (ie, the user forgot to call register_with_iocp), + # in which case we're just never going to see the + # completion. To avoid an uncancellable infinite sleep in + # the latter case, we'll PostQueuedCompletionStatus here, + # and if our post arrives before the original completion + # does, we'll assume the handle wasn't registered. + _check( + kernel32.PostQueuedCompletionStatus( + self._iocp, + 0, + CKeys.LATE_CANCEL, + lpOverlapped, + ), + ) + # Keep the lpOverlapped referenced so its address + # doesn't get reused until our posted completion + # status has been processed. Otherwise, we can + # get confused about which completion goes with + # which I/O. + self._posted_too_late_to_cancel.add(lpOverlapped) + else: # pragma: no cover + raise _core.TrioInternalError( + "CancelIoEx failed with unexpected error", + ) from exc + return _core.Abort.FAILED + + # TODO: what type does this return? + info = await _core.wait_task_rescheduled(abort) + lpOverlappedTyped = cast("_Overlapped", lpOverlapped) + if lpOverlappedTyped.Internal != 0: + # the lpOverlapped reports the error as an NT status code, + # which we must convert back to a Win32 error code before + # it will produce the right sorts of exceptions + code = ntdll.RtlNtStatusToDosError(lpOverlappedTyped.Internal) + if code == ErrorCodes.ERROR_OPERATION_ABORTED: + if raise_cancel is not None: + raise_cancel() + else: + # We didn't request this cancellation, so assume + # it happened due to the underlying handle being + # closed before the operation could complete. 
+ raise _core.ClosedResourceError("another task closed this resource") + else: + raise_winerror(code) + return info + + async def _perform_overlapped( + self, + handle: int | CData, + submit_fn: Callable[[_Overlapped], None], + ) -> _Overlapped: + # submit_fn(lpOverlapped) submits some I/O + # it may raise an OSError with ERROR_IO_PENDING + # the handle must already be registered using + # register_with_iocp(handle) + # This always does a schedule point, but it's possible that the + # operation will not be cancellable, depending on how Windows is + # feeling today. So we need to check for cancellation manually. + await _core.checkpoint_if_cancelled() + lpOverlapped = cast("_Overlapped", ffi.new("LPOVERLAPPED")) + try: + submit_fn(lpOverlapped) + except OSError as exc: + if exc.winerror != ErrorCodes.ERROR_IO_PENDING: + raise + await self.wait_overlapped(handle, cast("CData", lpOverlapped)) + return lpOverlapped + + @_public + async def write_overlapped( + self, + handle: int | CData, + data: Buffer, + file_offset: int = 0, + ) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. 
+ """ + with ffi.from_buffer(data) as cbuf: + + def submit_write(lpOverlapped: _Overlapped) -> None: + # yes, these are the real documented names + offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME + offset_fields.Offset = file_offset & 0xFFFFFFFF + offset_fields.OffsetHigh = file_offset >> 32 + _check( + kernel32.WriteFile( + _handle(handle), + ffi.cast("LPCVOID", cbuf), + len(cbuf), + ffi.NULL, + lpOverlapped, + ), + ) + + lpOverlapped = await self._perform_overlapped(handle, submit_write) + # this is "number of bytes transferred" + return lpOverlapped.InternalHigh + + @_public + async def readinto_overlapped( + self, + handle: int | CData, + buffer: Buffer, + file_offset: int = 0, + ) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. + """ + with ffi.from_buffer(buffer, require_writable=True) as cbuf: + + def submit_read(lpOverlapped: _Overlapped) -> None: + offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME + offset_fields.Offset = file_offset & 0xFFFFFFFF + offset_fields.OffsetHigh = file_offset >> 32 + _check( + kernel32.ReadFile( + _handle(handle), + ffi.cast("LPVOID", cbuf), + len(cbuf), + ffi.NULL, + lpOverlapped, + ), + ) + + lpOverlapped = await self._perform_overlapped(handle, submit_read) + return lpOverlapped.InternalHigh + + ################################################################ + # Raw IOCP operations + ################################################################ + + @_public + def current_iocp(self) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. 
+ """ + assert self._iocp is not None + return int(ffi.cast("uintptr_t", self._iocp)) + + @contextmanager + @_public + def monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + <https://github.com/python-trio/trio/issues/26>`__ and `#52 + <https://github.com/python-trio/trio/issues/52>`__. + """ + key = next(self._completion_key_counter) + queue = _core.UnboundedQueue[object]() + self._completion_key_queues[key] = queue + try: + yield (key, queue) + finally: + del self._completion_key_queues[key] diff --git a/contrib/python/trio/trio/_core/_ki.py b/contrib/python/trio/trio/_core/_ki.py new file mode 100644 index 00000000000..9fa849229a3 --- /dev/null +++ b/contrib/python/trio/trio/_core/_ki.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +import signal +import sys +import types +import weakref +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar + +import attrs + +from .._util import is_main_thread +from ._run_context import GLOBAL_RUN_CONTEXT + +if TYPE_CHECKING: + import types + from collections.abc import Callable + + from typing_extensions import Self, TypeGuard +# In ordinary single-threaded Python code, when you hit control-C, it raises +# an exception and automatically does all the regular unwinding stuff. +# +# In Trio code, we would like hitting control-C to raise an exception and +# automatically do all the regular unwinding stuff. In particular, we would +# like to maintain our invariant that all tasks always run to completion (one +# way or another), by unwinding all of them. +# +# But it's basically impossible to write the core task running code in such a +# way that it can maintain this invariant in the face of KeyboardInterrupt +# exceptions arising at arbitrary bytecode positions. 
Similarly, if a +# KeyboardInterrupt happened at the wrong moment inside pretty much any of our +# inter-task synchronization or I/O primitives, then the system state could +# get corrupted and prevent our being able to clean up properly. +# +# So, we need a way to defer KeyboardInterrupt processing from these critical +# sections. +# +# Things that don't work: +# +# - Listen for SIGINT and process it in a system task: works fine for +# well-behaved programs that regularly pass through the event loop, but if +# user-code goes into an infinite loop then it can't be interrupted. Which +# is unfortunate, since dealing with infinite loops is what +# KeyboardInterrupt is for! +# +# - Use pthread_sigmask to disable signal delivery during critical section: +# (a) windows has no pthread_sigmask, (b) python threads start with all +# signals unblocked, so if there are any threads around they'll receive the +# signal and then tell the main thread to run the handler, even if the main +# thread has that signal blocked. +# +# - Install a signal handler which checks a global variable to decide whether +# to raise the exception immediately (if we're in a non-critical section), +# or to schedule it on the event loop (if we're in a critical section). The +# problem here is that it's impossible to transition safely out of user code: +# +# with keyboard_interrupt_enabled: +# msg = coro.send(value) +# +# If this raises a KeyboardInterrupt, it might be because the coroutine got +# interrupted and has unwound... or it might be the KeyboardInterrupt +# arrived just *after* 'send' returned, so the coroutine is still running, +# but we just lost the message it sent. (And worse, in our actual task +# runner, the send is hidden inside a utility function etc.) +# +# Solution: +# +# Mark *stack frames* as being interrupt-safe or interrupt-unsafe, and from +# the signal handler check which kind of frame we're currently in when +# deciding whether to raise or schedule the exception. 
+# +# There are still some cases where this can fail, like if someone hits +# control-C while the process is in the event loop, and then it immediately +# enters an infinite loop in user code. In this case the user has to hit +# control-C a second time. And of course if the user code is written so that +# it doesn't actually exit after a task crashes and everything gets cancelled, +# then there's not much to be done. (Hitting control-C repeatedly might help, +# but in general the solution is to kill the process some other way, just like +# for any Python program that's written to catch and ignore +# KeyboardInterrupt.) + +_T = TypeVar("_T") + + +class _IdRef(weakref.ref[_T]): + __slots__ = ("_hash",) + _hash: int + + def __new__( + cls, + ob: _T, + callback: Callable[[Self], object] | None = None, + /, + ) -> Self: + self: Self = weakref.ref.__new__(cls, ob, callback) + self._hash = object.__hash__(ob) + return self + + def __eq__(self, other: object) -> bool: + if self is other: + return True + + if not isinstance(other, _IdRef): + return NotImplemented + + my_obj = None + try: + my_obj = self() + return my_obj is not None and my_obj is other() + finally: + del my_obj + + # we're overriding a builtin so we do need this + def __ne__(self, other: object) -> bool: + return not self == other + + def __hash__(self) -> int: + return self._hash + + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +# see also: https://github.com/python/cpython/issues/88306 +class WeakKeyIdentityDictionary(Generic[_KT, _VT]): + def __init__(self) -> None: + self._data: dict[_IdRef[_KT], _VT] = {} + + def remove( + k: _IdRef[_KT], + selfref: weakref.ref[ + WeakKeyIdentityDictionary[_KT, _VT] + ] = weakref.ref( # noqa: B008 # function-call-in-default-argument + self, + ), + ) -> None: + self = selfref() + if self is not None: + try: # noqa: SIM105 # suppressible-exception + del self._data[k] + except KeyError: + pass + + self._remove = remove + + def __getitem__(self, k: _KT) -> _VT: + return 
self._data[_IdRef(k)] + + def __setitem__(self, k: _KT, v: _VT) -> None: + self._data[_IdRef(k, self._remove)] = v + + +_CODE_KI_PROTECTION_STATUS_WMAP: WeakKeyIdentityDictionary[ + types.CodeType, + bool, +] = WeakKeyIdentityDictionary() + + +# This is to support the async_generator package necessary for aclosing on <3.10 +# functions decorated @async_generator are given this magic property that's a +# reference to the object itself +# see python-trio/async_generator/async_generator/_impl.py +def legacy_isasyncgenfunction( + obj: object, +) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]: + return getattr(obj, "_async_gen_function", None) == id(obj) + + +# NB: according to the signal.signal docs, 'frame' can be None on entry to +# this function: +def ki_protection_enabled(frame: types.FrameType | None) -> bool: + try: + task = GLOBAL_RUN_CONTEXT.task + except AttributeError: + task_ki_protected = False + task_frame = None + else: + task_ki_protected = task._ki_protected + task_frame = task.coro.cr_frame + + while frame is not None: + try: + v = _CODE_KI_PROTECTION_STATUS_WMAP[frame.f_code] + except KeyError: + pass + else: + return bool(v) + if frame.f_code.co_name == "__del__": + return True + if frame is task_frame: + return task_ki_protected + frame = frame.f_back + return True + + +def currently_ki_protected() -> bool: + r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection + enabled. + + It's surprisingly easy to think that one's :exc:`KeyboardInterrupt` + protection is enabled when it isn't, or vice-versa. This function tells + you what Trio thinks of the matter, which makes it useful for ``assert``\s + and unit tests. + + Returns: + bool: True if protection is enabled, and False otherwise. 
+ + """ + return ki_protection_enabled(sys._getframe()) + + +class _SupportsCode(Protocol): + __code__: types.CodeType + + +_T_supports_code = TypeVar("_T_supports_code", bound=_SupportsCode) + + +def enable_ki_protection(f: _T_supports_code, /) -> _T_supports_code: + """Decorator to enable KI protection.""" + orig = f + + if legacy_isasyncgenfunction(f): + f = f.__wrapped__ # type: ignore + + _CODE_KI_PROTECTION_STATUS_WMAP[f.__code__] = True + return orig + + +def disable_ki_protection(f: _T_supports_code, /) -> _T_supports_code: + """Decorator to disable KI protection.""" + orig = f + + if legacy_isasyncgenfunction(f): + f = f.__wrapped__ # type: ignore + + _CODE_KI_PROTECTION_STATUS_WMAP[f.__code__] = False + return orig + + [email protected](slots=False) +class KIManager: + handler: Callable[[int, types.FrameType | None], None] | None = None + + def install( + self, + deliver_cb: Callable[[], object], + restrict_keyboard_interrupt_to_checkpoints: bool, + ) -> None: + assert self.handler is None + if ( + not is_main_thread() + or signal.getsignal(signal.SIGINT) != signal.default_int_handler + ): + return + + def handler(signum: int, frame: types.FrameType | None) -> None: + assert signum == signal.SIGINT + protection_enabled = ki_protection_enabled(frame) + if protection_enabled or restrict_keyboard_interrupt_to_checkpoints: + deliver_cb() + else: + raise KeyboardInterrupt + + self.handler = handler + signal.signal(signal.SIGINT, handler) + + def close(self) -> None: + if self.handler is not None: + if signal.getsignal(signal.SIGINT) is self.handler: + signal.signal(signal.SIGINT, signal.default_int_handler) + self.handler = None diff --git a/contrib/python/trio/trio/_core/_local.py b/contrib/python/trio/trio/_core/_local.py new file mode 100644 index 00000000000..fff1234f59e --- /dev/null +++ b/contrib/python/trio/trio/_core/_local.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from typing import Generic, TypeVar, cast + +# Runvar implementations 
+import attrs + +from .._util import NoPublicConstructor, final +from . import _run + +T = TypeVar("T") + + +@final +class _NoValue: ... + + +@final [email protected](eq=False) +class RunVarToken(Generic[T], metaclass=NoPublicConstructor): + _var: RunVar[T] + previous_value: T | type[_NoValue] = _NoValue + redeemed: bool = attrs.field(default=False, init=False) + + @classmethod + def _empty(cls, var: RunVar[T]) -> RunVarToken[T]: + return cls._create(var) + + +@final [email protected](eq=False, repr=False) +class RunVar(Generic[T]): + """The run-local variant of a context variable. + + :class:`RunVar` objects are similar to context variable objects, + except that they are shared across a single call to :func:`trio.run` + rather than a single task. + + """ + + _name: str = attrs.field(alias="name") + _default: T | type[_NoValue] = attrs.field(default=_NoValue, alias="default") + + def get(self, default: T | type[_NoValue] = _NoValue) -> T: + """Gets the value of this :class:`RunVar` for the current run call.""" + try: + return cast("T", _run.GLOBAL_RUN_CONTEXT.runner._locals[self]) + except AttributeError: + raise RuntimeError("Cannot be used outside of a run context") from None + except KeyError: + # contextvars consistency + # `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released + if default is not _NoValue: + return default # type: ignore[return-value] + + if self._default is not _NoValue: + return self._default # type: ignore[return-value] + + raise LookupError(self) from None + + def set(self, value: T) -> RunVarToken[T]: + """Sets the value of this :class:`RunVar` for this current run + call. + + """ + try: + old_value = self.get() + except LookupError: + token = RunVarToken._empty(self) + else: + token = RunVarToken[T]._create(self, old_value) + + # This can't fail, because if we weren't in Trio context then the + # get() above would have failed. 
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value + return token + + def reset(self, token: RunVarToken[T]) -> None: + """Resets the value of this :class:`RunVar` to what it was + previously specified by the token. + + """ + if token is None: + raise TypeError("token must not be none") + + if token.redeemed: + raise ValueError("token has already been used") + + if token._var is not self: + raise ValueError("token is not for us") + + previous = token.previous_value + try: + if previous is _NoValue: + _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) + else: + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous + except AttributeError: + raise RuntimeError("Cannot be used outside of a run context") from None + + token.redeemed = True + + def __repr__(self) -> str: + return f"<RunVar name={self._name!r}>" diff --git a/contrib/python/trio/trio/_core/_mock_clock.py b/contrib/python/trio/trio/_core/_mock_clock.py new file mode 100644 index 00000000000..e437464b65e --- /dev/null +++ b/contrib/python/trio/trio/_core/_mock_clock.py @@ -0,0 +1,165 @@ +import time +from math import inf + +from .. import _core +from .._abc import Clock +from .._util import final +from ._run import GLOBAL_RUN_CONTEXT + +################################################################ +# The glorious MockClock +################################################################ + + +# Prior art: +# https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html +# https://github.com/ztellman/manifold/issues/57 +@final +class MockClock(Clock): + """A user-controllable clock suitable for writing tests. + + Args: + rate (float): the initial :attr:`rate`. + autojump_threshold (float): the initial :attr:`autojump_threshold`. + + .. attribute:: rate + + How many seconds of clock time pass per second of real time. Default is + 0.0, i.e. the clock only advances through manuals calls to :meth:`jump` + or when the :attr:`autojump_threshold` is triggered. 
You can assign to + this attribute to change it. + + .. attribute:: autojump_threshold + + The clock keeps an eye on the run loop, and if at any point it detects + that all tasks have been blocked for this many real seconds (i.e., + according to the actual clock, not this clock), then the clock + automatically jumps ahead to the run loop's next scheduled + timeout. Default is :data:`math.inf`, i.e., to never autojump. You can + assign to this attribute to change it. + + Basically the idea is that if you have code or tests that use sleeps + and timeouts, you can use this to make it run much faster, totally + automatically. (At least, as long as those sleeps/timeouts are + happening inside Trio; if your test involves talking to external + service and waiting for it to timeout then obviously we can't help you + there.) + + You should set this to the smallest value that lets you reliably avoid + "false alarms" where some I/O is in flight (e.g. between two halves of + a socketpair) but the threshold gets triggered and time gets advanced + anyway. This will depend on the details of your tests and test + environment. If you aren't doing any I/O (like in our sleeping example + above) then just set it to zero, and the clock will jump whenever all + tasks are blocked. + + .. note:: If you use ``autojump_threshold`` and + `wait_all_tasks_blocked` at the same time, then you might wonder how + they interact, since they both cause things to happen after the run + loop goes idle for some time. The answer is: + `wait_all_tasks_blocked` takes priority. If there's a task blocked + in `wait_all_tasks_blocked`, then the autojump feature treats that + as active task and does *not* jump the clock. + + """ + + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf) -> None: + # when the real clock said 'real_base', the virtual time was + # 'virtual_base', and since then it's advanced at 'rate' virtual + # seconds per real second. 
+ self._real_base = 0.0 + self._virtual_base = 0.0 + self._rate = 0.0 + + # kept as an attribute so that our tests can monkeypatch it + self._real_clock = time.perf_counter + + # use the property update logic to set initial values + self.rate = rate + self.autojump_threshold = autojump_threshold + + def __repr__(self) -> str: + return f"<MockClock, time={self.current_time():.7f}, rate={self._rate} @ {id(self):#x}>" + + @property + def rate(self) -> float: + return self._rate + + @rate.setter + def rate(self, new_rate: float) -> None: + if new_rate < 0: + raise ValueError("rate must be >= 0") + else: + real = self._real_clock() + virtual = self._real_to_virtual(real) + self._virtual_base = virtual + self._real_base = real + self._rate = float(new_rate) + + @property + def autojump_threshold(self) -> float: + return self._autojump_threshold + + @autojump_threshold.setter + def autojump_threshold(self, new_autojump_threshold: float) -> None: + self._autojump_threshold = float(new_autojump_threshold) + self._try_resync_autojump_threshold() + + # runner.clock_autojump_threshold is an internal API that isn't easily + # usable by custom third-party Clock objects. If you need access to this + # functionality, let us know, and we'll figure out how to make a public + # API. Discussion: + # + # https://github.com/python-trio/trio/issues/1587 + def _try_resync_autojump_threshold(self) -> None: + try: + runner = GLOBAL_RUN_CONTEXT.runner + if runner.is_guest: + runner.force_guest_tick_asap() + except AttributeError: + pass + else: + if runner.clock is self: + runner.clock_autojump_threshold = self._autojump_threshold + + # Invoked by the run loop when runner.clock_autojump_threshold is + # exceeded. 
+ def _autojump(self) -> None: + statistics = _core.current_statistics() + jump = statistics.seconds_to_next_deadline + if 0 < jump < inf: + self.jump(jump) + + def _real_to_virtual(self, real: float) -> float: + real_offset = real - self._real_base + virtual_offset = self._rate * real_offset + return self._virtual_base + virtual_offset + + def start_clock(self) -> None: + self._try_resync_autojump_threshold() + + def current_time(self) -> float: + return self._real_to_virtual(self._real_clock()) + + def deadline_to_sleep_time(self, deadline: float) -> float: + virtual_timeout = deadline - self.current_time() + if virtual_timeout <= 0: + return 0 + elif self._rate > 0: + return virtual_timeout / self._rate + else: + return 999999999 + + def jump(self, seconds: float) -> None: + """Manually advance the clock by the given number of seconds. + + Args: + seconds (float): the number of seconds to jump the clock forward. + + Raises: + ValueError: if you try to pass a negative value for ``seconds``. + + """ + if seconds < 0: + raise ValueError("time can't go backwards") + self._virtual_base += seconds diff --git a/contrib/python/trio/trio/_core/_parking_lot.py b/contrib/python/trio/trio/_core/_parking_lot.py new file mode 100644 index 00000000000..ddf62761176 --- /dev/null +++ b/contrib/python/trio/trio/_core/_parking_lot.py @@ -0,0 +1,317 @@ +# ParkingLot provides an abstraction for a fair waitqueue with cancellation +# and requeuing support. Inspiration: +# +# https://webkit.org/blog/6161/locking-in-webkit/ +# https://amanieu.github.io/parking_lot/ +# +# which were in turn heavily influenced by +# +# http://gee.cs.oswego.edu/dl/papers/aqs.pdf +# +# Compared to these, our use of cooperative scheduling allows some +# simplifications (no need for internal locking). On the other hand, the need +# to support Trio's strong cancellation semantics adds some complications +# (tasks need to know where they're queued so they can cancel). 
Also, in the +# above work, the ParkingLot is a global structure that holds a collection of +# waitqueues keyed by lock address, and which are opportunistically allocated +# and destroyed as contention arises; this allows the worst-case memory usage +# for all waitqueues to be O(#tasks). Here we allocate a separate wait queue +# for each synchronization object, so we're O(#objects + #tasks). This isn't +# *so* bad, since compared to theirs our synchronization objects are heavier +# and our tasks are lighter, so for us #objects is smaller and #tasks +# is larger. +# +# This is in the core for two reasons. First, it's used by +# UnboundedQueue, and UnboundedQueue is used for a number of things in the +# core. And second, it's responsible for providing fairness to all of our +# high-level synchronization primitives (locks, queues, etc.). For now with +# our FIFO scheduler this is relatively trivial (it's just a FIFO waitqueue), +# but if in the future we start supporting task priorities or fair scheduling +# +# https://github.com/python-trio/trio/issues/32 +# +# then all we'll have to do is update this. (Well, full-fledged task +# priorities might also require priority inheritance, which would require more +# work.) +# +# For discussion of data structures to use here, see: +# +# https://github.com/dabeaz/curio/issues/136 +# +# (and also the articles above). Currently we use an OrderedDict, whose +# insertion order gives us FIFO ordering. The main advantage of +# this is that it's easy to implement :-). An intrusive doubly-linked list +# would also be a natural approach, so long as we only handle FIFO ordering. +# +# XX: should we switch to the shared global ParkingLot approach?
+# +# XX: we should probably add support for "parking tokens" to allow for +# task-fair RWlock (basically: when parking, a task needs to be able to mark +# itself as a reader or a writer, and then a task-fair wakeup policy is, wake +# the next task, and if it's a reader then keep waking tasks so long as they +# are readers). Without this I think you can implement write-biased or +# read-biased RWlocks (by using two parking lots and drawing from whichever is +# preferred), but not task-fair -- and task-fair plays much more nicely with +# WFQ. (Consider what happens in the two-lot implementation if you're +# write-biased but all the pending writers are blocked at the scheduler level +# by the WFQ logic...) +# ...alternatively, "phase-fair" RWlocks are pretty interesting: +# http://www.cs.unc.edu/~anderson/papers/ecrts09b.pdf +# Useful summary: +# https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/locks/ReadWriteLock.html +# +# XX: if we do add WFQ, then we might have to drop the current feature where +# unpark returns the tasks that were unparked. Rationale: suppose that at the +# time we call unpark, the next task is deprioritized... and then, before it +# becomes runnable, a new task parks which *is* runnable. Ideally we should +# immediately wake the new task, and leave the old task on the queue for +# later. But this means we can't commit to which task we are unparking when +# unpark is called. +# +# See: https://github.com/python-trio/trio/issues/53 +from __future__ import annotations + +import inspect +import math +from collections import OrderedDict +from typing import TYPE_CHECKING + +import attrs +import outcome + +from .. import _core +from .._util import final + +if TYPE_CHECKING: + from collections.abc import Iterator + + from ._run import Task + + +GLOBAL_PARKING_LOT_BREAKER: dict[Task, list[ParkingLot]] = {} + + +def add_parking_lot_breaker(task: Task, lot: ParkingLot) -> None: + """Register a task as a breaker for a lot.
See :meth:`ParkingLot.break_lot`. + + raises: + trio.BrokenResourceError: if the task has already exited. + """ + if inspect.getcoroutinestate(task.coro) == inspect.CORO_CLOSED: + raise _core._exceptions.BrokenResourceError( + "Attempted to add already exited task as lot breaker.", + ) + if task not in GLOBAL_PARKING_LOT_BREAKER: + GLOBAL_PARKING_LOT_BREAKER[task] = [lot] + else: + GLOBAL_PARKING_LOT_BREAKER[task].append(lot) + + +def remove_parking_lot_breaker(task: Task, lot: ParkingLot) -> None: + """Deregister a task as a breaker for a lot. See :meth:`ParkingLot.break_lot`""" + try: + GLOBAL_PARKING_LOT_BREAKER[task].remove(lot) + except (KeyError, ValueError): + raise RuntimeError( + "Attempted to remove task as breaker for a lot it is not registered for", + ) from None + if not GLOBAL_PARKING_LOT_BREAKER[task]: + del GLOBAL_PARKING_LOT_BREAKER[task] + + +class ParkingLotStatistics: + """An object containing debugging information for a ParkingLot. + + Currently, the following fields are defined: + + * ``tasks_waiting`` (int): The number of tasks blocked on this lot's + :meth:`trio.lowlevel.ParkingLot.park` method. + + """ + + tasks_waiting: int + + +@final [email protected](eq=False) +class ParkingLot: + """A fair wait queue with cancellation and requeuing. + + This class encapsulates the tricky parts of implementing a wait + queue. It's useful for implementing higher-level synchronization + primitives like queues and locks. + + In addition to the methods below, you can use ``len(parking_lot)`` to get + the number of parked tasks, and ``if parking_lot: ...`` to check whether + there are any parked tasks. 
+ + """ + + # {task: None}, we just want a deque where we can quickly delete random + # items + _parked: OrderedDict[Task, None] = attrs.field(factory=OrderedDict, init=False) + broken_by: list[Task] = attrs.field(factory=list, init=False) + + def __len__(self) -> int: + """Returns the number of parked tasks.""" + return len(self._parked) + + def __bool__(self) -> bool: + """True if there are parked tasks, False otherwise.""" + return bool(self._parked) + + # XX this currently returns None + # if we ever add the ability to repark while one's resuming place in + # line (for false wakeups), then we could have it return a ticket that + # abstracts the "place in line" concept. + @_core.enable_ki_protection + async def park(self) -> None: + """Park the current task until woken by a call to :meth:`unpark` or + :meth:`unpark_all`. + + Raises: + BrokenResourceError: if attempting to park in a broken lot, or the lot + breaks before we get to unpark. + + """ + if self.broken_by: + raise _core.BrokenResourceError( + f"Attempted to park in parking lot broken by {self.broken_by}", + ) + task = _core.current_task() + self._parked[task] = None + task.custom_sleep_data = self + + def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: + del task.custom_sleep_data._parked[task] + return _core.Abort.SUCCEEDED + + await _core.wait_task_rescheduled(abort_fn) + + def _pop_several(self, count: int | float) -> Iterator[Task]: # noqa: PYI041 + if isinstance(count, float): + if math.isinf(count): + count = len(self._parked) + else: + raise ValueError("Cannot pop a non-integer number of tasks.") + else: + count = min(count, len(self._parked)) + for _ in range(count): + task, _ = self._parked.popitem(last=False) + yield task + + @_core.enable_ki_protection + def unpark(self, *, count: int | float = 1) -> list[Task]: # noqa: PYI041 + """Unpark one or more tasks. + + This wakes up ``count`` tasks that are blocked in :meth:`park`. 
If + there are fewer than ``count`` tasks parked, then wakes as many tasks + are available and then returns successfully. + + Args: + count (int | math.inf): the number of tasks to unpark. + + """ + tasks = list(self._pop_several(count)) + for task in tasks: + _core.reschedule(task) + return tasks + + def unpark_all(self) -> list[Task]: + """Unpark all parked tasks.""" + return self.unpark(count=len(self)) + + @_core.enable_ki_protection + def repark( + self, + new_lot: ParkingLot, + *, + count: int | float = 1, # noqa: PYI041 + ) -> None: + """Move parked tasks from one :class:`ParkingLot` object to another. + + This dequeues ``count`` tasks from one lot, and requeues them on + another, preserving order. For example:: + + async def parker(lot): + print("sleeping") + await lot.park() + print("woken") + + async def main(): + lot1 = trio.lowlevel.ParkingLot() + lot2 = trio.lowlevel.ParkingLot() + async with trio.open_nursery() as nursery: + nursery.start_soon(parker, lot1) + await trio.testing.wait_all_tasks_blocked() + assert len(lot1) == 1 + assert len(lot2) == 0 + lot1.repark(lot2) + assert len(lot1) == 0 + assert len(lot2) == 1 + # This wakes up the task that was originally parked in lot1 + lot2.unpark() + + If there are fewer than ``count`` tasks parked, then reparks as many + tasks as are available and then returns successfully. + + Args: + new_lot (ParkingLot): the parking lot to move tasks to. + count (int|math.inf): the number of tasks to move. + + """ + if not isinstance(new_lot, ParkingLot): + raise TypeError("new_lot must be a ParkingLot") + for task in self._pop_several(count): + new_lot._parked[task] = None + task.custom_sleep_data = new_lot + + def repark_all(self, new_lot: ParkingLot) -> None: + """Move all parked tasks from one :class:`ParkingLot` object to + another. + + See :meth:`repark` for details. 
+ + """ + return self.repark(new_lot, count=len(self)) + + def break_lot(self, task: Task | None = None) -> None: + """Break this lot, with ``task`` noted as the task that broke it. + + This causes all parked tasks to raise an error, and any + future tasks attempting to park to error. Unpark & repark become no-ops as the + parking lot is empty. + + The error raised contains a reference to the task sent as a parameter. The task + is also saved in the parking lot in the ``broken_by`` attribute. + """ + if task is None: + task = _core.current_task() + + # if lot is already broken, just mark this as another breaker and return + if self.broken_by: + self.broken_by.append(task) + return + + self.broken_by.append(task) + + for parked_task in self._parked: + _core.reschedule( + parked_task, + outcome.Error( + _core.BrokenResourceError(f"Parking lot broken by {task}"), + ), + ) + self._parked.clear() + + def statistics(self) -> ParkingLotStatistics: + """Return an object containing debugging information. + + Currently the following fields are defined: + + * ``tasks_waiting``: The number of tasks blocked on this lot's + :meth:`park` method. 
+ + """ + return ParkingLotStatistics(tasks_waiting=len(self._parked)) diff --git a/contrib/python/trio/trio/_core/_run.py b/contrib/python/trio/trio/_core/_run.py new file mode 100644 index 00000000000..5303dfe75d4 --- /dev/null +++ b/contrib/python/trio/trio/_core/_run.py @@ -0,0 +1,3139 @@ +from __future__ import annotations + +import enum +import functools +import gc +import itertools +import random +import select +import sys +import warnings +from collections import deque +from contextlib import AbstractAsyncContextManager, contextmanager, suppress +from contextvars import copy_context +from heapq import heapify, heappop, heappush +from math import inf, isnan +from time import perf_counter +from typing import ( + TYPE_CHECKING, + Any, + Final, + NoReturn, + Protocol, + cast, + overload, +) + +import attrs +from outcome import Error, Outcome, Value, capture +from sniffio import thread_local as sniffio_library +from sortedcontainers import SortedDict + +from .. import _core +from .._abc import Clock, Instrument +from .._deprecate import warn_deprecated +from .._util import NoPublicConstructor, coroutine_or_error, final +from ._asyncgens import AsyncGenerators +from ._concat_tb import concat_tb +from ._entry_queue import EntryQueue, TrioToken +from ._exceptions import ( + Cancelled, + CancelReasonLiteral, + RunFinishedError, + TrioInternalError, +) +from ._instrumentation import Instruments +from ._ki import KIManager, enable_ki_protection +from ._parking_lot import GLOBAL_PARKING_LOT_BREAKER +from ._run_context import GLOBAL_RUN_CONTEXT as GLOBAL_RUN_CONTEXT +from ._thread_cache import start_thread_soon +from ._traps import ( + Abort, + CancelShieldedCheckpoint, + PermanentlyDetachCoroutineObject, + WaitTaskRescheduled, + cancel_shielded_checkpoint, + wait_task_rescheduled, +) + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + + +if TYPE_CHECKING: + import contextvars + import types + from collections.abc import ( + Awaitable, + 
Callable, + Generator, + Iterator, + Sequence, + ) + from types import TracebackType + + # for some strange reason Sphinx works with outcome.Outcome, but not Outcome, in + # start_guest_run. Same with types.FrameType in iter_await_frames + import outcome + from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack + + PosArgT = TypeVarTuple("PosArgT") + StatusT = TypeVar("StatusT", default=None) + StatusT_contra = TypeVar("StatusT_contra", contravariant=True, default=None) + BaseExcT = TypeVar("BaseExcT", bound=BaseException) +else: + from typing import TypeVar + + StatusT = TypeVar("StatusT") + StatusT_contra = TypeVar("StatusT_contra", contravariant=True) + +RetT = TypeVar("RetT") + + +DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000 + +# Passed as a sentinel +_NO_SEND: Final[Outcome[object]] = cast("Outcome[object]", object()) + +# Used to track if an exceptiongroup can be collapsed +NONSTRICT_EXCEPTIONGROUP_NOTE = 'This is a "loose" ExceptionGroup, and may be collapsed by Trio if it only contains one exception - typically after `Cancelled` has been stripped from it. Note this has consequences for exception handling, and strict_exception_groups=True is recommended.' + + +@final +class _NoStatus(metaclass=NoPublicConstructor): + """Sentinel for unset TaskStatus._value.""" + + +# Decorator to mark methods public. This does nothing by itself, but +# trio/_tools/gen_exports.py looks for it. +def _public(fn: RetT) -> RetT: + return fn + + +# When running under Hypothesis, we want examples to be reproducible and +# shrinkable. We therefore register `_hypothesis_plugin_setup()` as a +# plugin, so that importing *Hypothesis* will make Trio's task +# scheduling loop deterministic. We have a test for that, of course. +# Before Hypothesis supported entry-point plugins this integration was +# handled by pytest-trio, but we want it to work in e.g. unittest too. 
+_ALLOW_DETERMINISTIC_SCHEDULING: Final = False +_r = random.Random() + + +# no cover because we don't check the hypothesis plugin works with hypothesis +def _hypothesis_plugin_setup() -> None: # pragma: no cover + from hypothesis import register_random + + global _ALLOW_DETERMINISTIC_SCHEDULING + _ALLOW_DETERMINISTIC_SCHEDULING = True # type: ignore + register_random(_r) + + # monkeypatch repr_callable to make repr's way better + # requires importing hypothesis (in the test file or in conftest.py) + try: + from hypothesis.internal.reflection import get_pretty_function_description + + import trio.testing._raises_group + + def repr_callable(fun: Callable[[BaseExcT], bool]) -> str: + # add quotes around the signature + return repr(get_pretty_function_description(fun)) + + trio.testing._raises_group.repr_callable = repr_callable + except ImportError: + pass + + +def _count_context_run_tb_frames() -> int: + """Count implementation dependent traceback frames from Context.run() + + On CPython, Context.run() is implemented in C and doesn't show up in + tracebacks. On PyPy, it is implemented in Python and adds 1 frame to + tracebacks. 
+ + Returns: + int: Traceback frame count + + """ + + def function_with_unique_name_xyzzy() -> NoReturn: + try: + 1 / 0 # noqa: B018 # We need a ZeroDivisionError to fire + except ZeroDivisionError: + raise + else: # pragma: no cover + raise TrioInternalError( + "A ZeroDivisionError should have been raised, but it wasn't.", + ) + + ctx = copy_context() + try: + ctx.run(function_with_unique_name_xyzzy) + except ZeroDivisionError as exc: + tb = exc.__traceback__ + # Skip the frame where we caught it + tb = tb.tb_next # type: ignore[union-attr] + count = 0 + while tb.tb_frame.f_code.co_name != "function_with_unique_name_xyzzy": # type: ignore[union-attr] + tb = tb.tb_next # type: ignore[union-attr] + count += 1 + return count + else: # pragma: no cover + raise TrioInternalError( + f"The purpose of {function_with_unique_name_xyzzy.__name__} is " + "to raise a ZeroDivisionError, but it didn't.", + ) + + +CONTEXT_RUN_TB_FRAMES: Final = _count_context_run_tb_frames() + + +class SystemClock(Clock): + # Add a large random offset to our clock to ensure that if people + # accidentally call time.perf_counter() directly or start comparing clocks + # between different runs, then they'll notice the bug quickly: + offset: float = attrs.Factory(lambda: _r.uniform(10000, 200000)) + + def start_clock(self) -> None: + pass + + # In cPython 3, on every platform except Windows, perf_counter is + # exactly the same as time.monotonic; and on Windows, it uses + # QueryPerformanceCounter instead of GetTickCount64. 
+ def current_time(self) -> float: + return self.offset + perf_counter() + + def deadline_to_sleep_time(self, deadline: float) -> float: + return deadline - self.current_time() + + +class IdlePrimedTypes(enum.Enum): + WAITING_FOR_IDLE = 1 + AUTOJUMP_CLOCK = 2 + + +################################################################ +# CancelScope and friends +################################################################ + + +def collapse_exception_group( + excgroup: BaseExceptionGroup[BaseException], +) -> BaseException: + """Recursively collapse any single-exception groups into that single contained + exception. + + """ + exceptions = list(excgroup.exceptions) + modified = False + for i, exc in enumerate(exceptions): + if isinstance(exc, BaseExceptionGroup): + new_exc = collapse_exception_group(exc) + if new_exc is not exc: + modified = True + exceptions[i] = new_exc + + if ( + len(exceptions) == 1 + and isinstance(excgroup, BaseExceptionGroup) + and NONSTRICT_EXCEPTIONGROUP_NOTE in getattr(excgroup, "__notes__", ()) + ): + exceptions[0].__traceback__ = concat_tb( + excgroup.__traceback__, + exceptions[0].__traceback__, + ) + return exceptions[0] + elif modified: + return excgroup.derive(exceptions) + else: + return excgroup + + [email protected](eq=False) +class Deadlines: + """A container of deadlined cancel scopes. + + Only contains scopes with non-infinite deadlines that are currently + attached to at least one task. 
+ + """ + + # Heap of (deadline, id(CancelScope), CancelScope) + _heap: list[tuple[float, int, CancelScope]] = attrs.Factory(list) + # Count of active deadlines (those that haven't been changed) + _active: int = 0 + + def add(self, deadline: float, cancel_scope: CancelScope) -> None: + heappush(self._heap, (deadline, id(cancel_scope), cancel_scope)) + self._active += 1 + + def remove(self, deadline: float, cancel_scope: CancelScope) -> None: + self._active -= 1 + + def next_deadline(self) -> float: + while self._heap: + deadline, _, cancel_scope = self._heap[0] + if deadline == cancel_scope._registered_deadline: + return deadline + else: + # This entry is stale; discard it and try again + heappop(self._heap) + return inf + + def _prune(self) -> None: + # In principle, it's possible for a cancel scope to toggle back and + # forth repeatedly between the same two deadlines, and end up with + # lots of stale entries that *look* like they're still active, because + # their deadline is correct, but in fact are redundant. So when + # pruning we have to eliminate entries with the wrong deadline, *and* + # eliminate duplicates. 
+ seen = set() + pruned_heap = [] + for deadline, tiebreaker, cancel_scope in self._heap: + if deadline == cancel_scope._registered_deadline: + if cancel_scope in seen: + continue + seen.add(cancel_scope) + pruned_heap.append((deadline, tiebreaker, cancel_scope)) + # See test_cancel_scope_deadline_duplicates for a test that exercises + # this assert: + assert len(pruned_heap) == self._active + heapify(pruned_heap) + self._heap = pruned_heap + + def expire(self, now: float) -> bool: + did_something = False + while self._heap and self._heap[0][0] <= now: + deadline, _, cancel_scope = heappop(self._heap) + if deadline == cancel_scope._registered_deadline: + did_something = True + # This implicitly calls self.remove(), so we don't need to + # decrement _active here + cancel_scope._cancel(CancelReason(source="deadline")) + # If we've accumulated too many stale entries, then prune the heap to + # keep it under control. (We only do this occasionally in a batch, to + # keep the amortized cost down) + if len(self._heap) > self._active * 2 + DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: + self._prune() + return did_something + + +class CancelReason: + """Attached to a :class:`CancelScope` upon cancellation with details of the source of the + cancellation, which is then used to construct the string in a :exc:`Cancelled`. + Users can pass a ``reason`` str to :meth:`CancelScope.cancel` to set it. + + Not publicly exported or documented. + """ + + source: CancelReasonLiteral + source_task: str | None = None + reason: str | None = None + + [email protected](eq=False) +class CancelStatus: + """Tracks the cancellation status for a contiguous extent + of code that will become cancelled, or not, as a unit. + + Each task has at all times a single "active" CancelStatus whose + cancellation state determines whether checkpoints executed in that + task raise Cancelled. Each 'with CancelScope(...)' context is + associated with a particular CancelStatus. 
When a task enters + such a context, a CancelStatus is created which becomes the active + CancelStatus for that task; when the 'with' block is exited, the + active CancelStatus for that task goes back to whatever it was + before. + + CancelStatus objects are arranged in a tree whose structure + mirrors the lexical nesting of the cancel scope contexts. When a + CancelStatus becomes cancelled, it notifies all of its direct + children, who become cancelled in turn (and continue propagating + the cancellation down the tree) unless they are shielded. (There + will be at most one such child except in the case of a + CancelStatus that immediately encloses a nursery.) At the leaves + of this tree are the tasks themselves, which get woken up to deliver + an abort when their direct parent CancelStatus becomes cancelled. + + You can think of CancelStatus as being responsible for the + "plumbing" of cancellations as opposed to CancelScope which is + responsible for the origination of them. + + """ + + # Our associated cancel scope. Can be any object with attributes + # `deadline`, `shield`, and `cancel_called`, but in current usage + # is always a CancelScope object. Must not be None. + _scope: CancelScope = attrs.field(alias="scope") + + # True iff the tasks in self._tasks should receive cancellations + # when they checkpoint. Always True when scope.cancel_called is True; + # may also be True due to a cancellation propagated from our + # parent. Unlike scope.cancel_called, this does not necessarily stay + # true once it becomes true. For example, we might become + # effectively cancelled due to the cancel scope two levels out + # becoming cancelled, but then the cancel scope one level out + # becomes shielded so we're not effectively cancelled anymore. + effectively_cancelled: bool = False + + # The CancelStatus whose cancellations can propagate to us; we + # become effectively cancelled when they do, unless scope.shield + # is True. 
May be None (for the outermost CancelStatus in a call + # to trio.run(), briefly during TaskStatus.started(), or during + # recovery from misnesting of cancel scopes). + _parent: CancelStatus | None = attrs.field(default=None, repr=False, alias="parent") + + # All of the CancelStatuses that have this CancelStatus as their parent. + _children: set[CancelStatus] = attrs.field(factory=set, init=False, repr=False) + + # Tasks whose cancellation state is currently tied directly to + # the cancellation state of this CancelStatus object. Don't modify + # this directly; instead, use Task._activate_cancel_status(). + # Invariant: all(task._cancel_status is self for task in self._tasks) + _tasks: set[Task] = attrs.field(factory=set, init=False, repr=False) + + # Set to True on still-active cancel statuses that are children + # of a cancel status that's been closed. This is used to permit + # recovery from misnested cancel scopes (well, at least enough + # recovery to show a useful traceback). + abandoned_by_misnesting: bool = attrs.field(default=False, init=False, repr=False) + + def __attrs_post_init__(self) -> None: + if self._parent is not None: + self._parent._children.add(self) + self.recalculate() + + # parent/children/tasks accessors are used by TaskStatus.started() + + @property + def parent(self) -> CancelStatus | None: + return self._parent + + @parent.setter + def parent(self, parent: CancelStatus | None) -> None: + if self._parent is not None: + self._parent._children.remove(self) + self._parent = parent + if self._parent is not None: + self._parent._children.add(self) + self.recalculate() + + @property + def children(self) -> frozenset[CancelStatus]: + return frozenset(self._children) + + @property + def tasks(self) -> frozenset[Task]: + return frozenset(self._tasks) + + def encloses(self, other: CancelStatus | None) -> bool: + """Returns true if this cancel status is a direct or indirect + parent of cancel status *other*, or if *other* is *self*. 
+ """ + while other is not None: + if other is self: + return True + other = other.parent + return False + + def close(self) -> None: + self.parent = None # now we're not a child of self.parent anymore + if self._tasks or self._children: + # Cancel scopes weren't exited in opposite order of being + # entered. CancelScope._close() deals with raising an error + # if appropriate; our job is to leave things in a reasonable + # state for unwinding our dangling children. We choose to leave + # this part of the CancelStatus tree unlinked from everyone + # else, cancelled, and marked so that exiting a CancelScope + # within the abandoned subtree doesn't affect the active + # CancelStatus. Note that it's possible for us to get here + # without CancelScope._close() raising an error, if a + # nursery's cancel scope is closed within the nursery's + # nested child and no other cancel scopes are involved, + # but in that case task_exited() will deal with raising + # the error. + self._mark_abandoned() + + # Since our CancelScope is about to forget about us, and we + # have no parent anymore, there's nothing left to call + # recalculate(). So, we can stay cancelled by setting + # effectively_cancelled and updating our children. + self.effectively_cancelled = True + for task in self._tasks: + task._attempt_delivery_of_any_pending_cancel() + for child in self._children: + child.recalculate() + + @property + def parent_cancellation_is_visible_to_us(self) -> bool: + return ( + self._parent is not None + and not self._scope.shield + and self._parent.effectively_cancelled + ) + + def recalculate(self) -> None: + # This does a depth-first traversal over this and descendent cancel + # statuses, to ensure their state is up-to-date. It's basically a + # recursive algorithm, but we use an explicit stack to avoid any + # issues with stack overflow. 
+ todo = [self] + while todo: + current = todo.pop() + new_state = ( + current._scope.cancel_called + or current.parent_cancellation_is_visible_to_us + ) + if new_state != current.effectively_cancelled: + if ( + current._scope._cancel_reason is None + and current.parent_cancellation_is_visible_to_us + ): + assert current._parent is not None + current._scope._cancel_reason = ( + current._parent._scope._cancel_reason + ) + current.effectively_cancelled = new_state + if new_state: + for task in current._tasks: + task._attempt_delivery_of_any_pending_cancel() + todo.extend(current._children) + + def _mark_abandoned(self) -> None: + self.abandoned_by_misnesting = True + for child in self._children: + child._mark_abandoned() + + def effective_deadline(self) -> float: + if self.effectively_cancelled: + return -inf + if self._parent is None or self._scope.shield: + return self._scope.deadline + return min(self._scope.deadline, self._parent.effective_deadline()) + + +MISNESTING_ADVICE = """ +This is probably a bug in your code, that has caused Trio's internal state to +become corrupted. We'll do our best to recover, but from now on there are +no guarantees. 
+ +Typically this is caused by one of the following: + - yielding within a generator or async generator that's opened a cancel + scope or nursery (unless the generator is a @contextmanager or + @asynccontextmanager); see https://github.com/python-trio/trio/issues/638 + - manually calling __enter__ or __exit__ on a trio.CancelScope, or + __aenter__ or __aexit__ on the object returned by trio.open_nursery(); + doing so correctly is difficult and you should use @[async]contextmanager + instead, or maybe [Async]ExitStack + - using [Async]ExitStack to interleave the entries/exits of cancel scopes + and/or nurseries in a way that couldn't be achieved by some nesting of + 'with' and 'async with' blocks + - using the low-level coroutine object protocol to execute some parts of + an async function in a different cancel scope/nursery context than + other parts +If you don't believe you're doing any of these things, please file a bug: +https://github.com/python-trio/trio/issues/new +""" + + +@final [email protected](eq=False, repr=False) +class CancelScope: + """A *cancellation scope*: the link between a unit of cancellable + work and Trio's cancellation system. + + A :class:`CancelScope` becomes associated with some cancellable work + when it is used as a context manager surrounding that work:: + + cancel_scope = trio.CancelScope() + ... + with cancel_scope: + await long_running_operation() + + Inside the ``with`` block, a cancellation of ``cancel_scope`` (via + a call to its :meth:`cancel` method or via the expiry of its + :attr:`deadline`) will immediately interrupt the + ``long_running_operation()`` by raising :exc:`Cancelled` at its + next :ref:`checkpoint <checkpoints>`. + + The context manager ``__enter__`` returns the :class:`CancelScope` + object itself, so you can also write ``with trio.CancelScope() as + cancel_scope:``. 
+ + If a cancel scope becomes cancelled before entering its ``with`` block, + the :exc:`Cancelled` exception will be raised at the first + checkpoint inside the ``with`` block. This allows a + :class:`CancelScope` to be created in one :ref:`task <tasks>` and + passed to another, so that the first task can later cancel some work + inside the second. + + Cancel scopes are not reusable or reentrant; that is, each cancel + scope can be used for at most one ``with`` block. (You'll get a + :exc:`RuntimeError` if you violate this rule.) + + The :class:`CancelScope` constructor takes initial values for the + cancel scope's :attr:`deadline` and :attr:`shield` attributes; these + may be freely modified after construction, whether or not the scope + has been entered yet, and changes take immediate effect. + """ + + _cancel_status: CancelStatus | None = attrs.field(default=None, init=False) + _has_been_entered: bool = attrs.field(default=False, init=False) + _registered_deadline: float = attrs.field(default=inf, init=False) + _cancel_called: bool = attrs.field(default=False, init=False) + cancelled_caught: bool = attrs.field(default=False, init=False) + + _cancel_reason: CancelReason | None = attrs.field(default=None, init=False) + + # Constructor arguments: + _relative_deadline: float = attrs.field( + default=inf, + kw_only=True, + alias="relative_deadline", + ) + _deadline: float = attrs.field(default=inf, kw_only=True, alias="deadline") + _shield: bool = attrs.field(default=False, kw_only=True, alias="shield") + + def __attrs_post_init__(self) -> None: + if isnan(self._deadline): + raise ValueError("deadline must not be NaN") + if isnan(self._relative_deadline): + raise ValueError("relative deadline must not be NaN") + if self._relative_deadline < 0: + raise ValueError("timeout must be non-negative") + if self._relative_deadline != inf and self._deadline != inf: + raise ValueError( + "Cannot specify both a deadline and a relative deadline", + ) + + @enable_ki_protection + 
def __enter__(self) -> Self: + task = _core.current_task() + if self._has_been_entered: + raise RuntimeError( + "Each CancelScope may only be used for a single 'with' block", + ) + self._has_been_entered = True + + if self._relative_deadline != inf: + assert self._deadline == inf + self._deadline = current_time() + self._relative_deadline + self._relative_deadline = inf + + if current_time() >= self._deadline: + self._cancel(CancelReason(source="deadline")) + with self._might_change_registered_deadline(): + self._cancel_status = CancelStatus(scope=self, parent=task._cancel_status) + task._activate_cancel_status(self._cancel_status) + return self + + def _close(self, exc: BaseException | None) -> BaseException | None: + if self._cancel_status is None: + new_exc = RuntimeError( + f"Cancel scope stack corrupted: attempted to exit {self!r} " + "which had already been exited", + ) + new_exc.__context__ = exc + return new_exc + scope_task = current_task() + if scope_task._cancel_status is not self._cancel_status: + # Cancel scope misnesting: this cancel scope isn't the most + # recently opened by this task (that's still open). That is, + # our assumptions about context managers forming a stack + # have been violated. Try and make the best of it. + if self._cancel_status.abandoned_by_misnesting: + # We are an inner cancel scope that was still active when + # some outer scope was closed. The closure of that outer + # scope threw an error, so we don't need to throw another + # one; it would just confuse the traceback. + pass + elif not self._cancel_status.encloses(scope_task._cancel_status): + # This task isn't even indirectly contained within the + # cancel scope it's trying to close. Raise an error + # without changing any state. 
+ new_exc = RuntimeError( + f"Cancel scope stack corrupted: attempted to exit {self!r} " + f"from unrelated {scope_task!r}\n{MISNESTING_ADVICE}", + ) + new_exc.__context__ = exc + return new_exc + else: + # Otherwise, there's some inner cancel scope(s) that + # we're abandoning by closing this outer one. + # CancelStatus.close() will take care of the plumbing; + # we just need to make sure we don't let the error + # pass silently. + new_exc = RuntimeError( + f"Cancel scope stack corrupted: attempted to exit {self!r} " + f"in {scope_task!r} that's still within its child {scope_task._cancel_status._scope!r}\n{MISNESTING_ADVICE}", + ) + new_exc.__context__ = exc + exc = new_exc + scope_task._activate_cancel_status(self._cancel_status.parent) + else: + scope_task._activate_cancel_status(self._cancel_status.parent) + if ( + exc is not None + and self._cancel_status.effectively_cancelled + and not self._cancel_status.parent_cancellation_is_visible_to_us + ) or ( + scope_task._cancel_status is not self._cancel_status + and self._cancel_status.abandoned_by_misnesting + ): + if isinstance(exc, Cancelled): + self.cancelled_caught = True + exc = None + elif isinstance(exc, BaseExceptionGroup): + matched, exc = exc.split(Cancelled) + if matched: + self.cancelled_caught = True + + if exc: + exc = collapse_exception_group(exc) + + self._cancel_status.close() + with self._might_change_registered_deadline(): + self._cancel_status = None + return exc + + @enable_ki_protection + def __exit__( + self, + etype: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: + # NB: NurseryManager calls _close() directly rather than __exit__(), + # so __exit__() must be just _close() plus this logic for adapting + # the exception-filtering result to the context manager API. + + # Tracebacks show the 'raise' line below out of context, so let's give + # this variable a name that makes sense out of context. 
+ remaining_error_after_cancel_scope = self._close(exc) + if remaining_error_after_cancel_scope is None: + return True + elif remaining_error_after_cancel_scope is exc: + return False + else: + # Copied verbatim from the old MultiErrorCatcher. Python doesn't + # allow us to encapsulate this __context__ fixup. + old_context = remaining_error_after_cancel_scope.__context__ + try: + raise remaining_error_after_cancel_scope + finally: + _, value, _ = sys.exc_info() + assert value is remaining_error_after_cancel_scope + value.__context__ = old_context + # delete references from locals to avoid creating cycles + # see test_cancel_scope_exit_doesnt_create_cyclic_garbage + # Note: still relevant + del remaining_error_after_cancel_scope, value, _, exc + + def __repr__(self) -> str: + if self._cancel_status is not None: + binding = "active" + elif self._has_been_entered: + binding = "exited" + else: + binding = "unbound" + + if self._cancel_called: + state = ", cancelled" + elif self._deadline == inf: + state = "" + else: + try: + now = current_time() + except RuntimeError: # must be called from async context + state = "" + else: + state = ", deadline is {:.2f} seconds {}".format( + abs(self._deadline - now), + "from now" if self._deadline >= now else "ago", + ) + + return f"<trio.CancelScope at {id(self):#x}, {binding}{state}>" + + @contextmanager + @enable_ki_protection + def _might_change_registered_deadline(self) -> Iterator[None]: + try: + yield + finally: + old = self._registered_deadline + if self._cancel_status is None or self._cancel_called: + new = inf + else: + new = self._deadline + if old != new: + self._registered_deadline = new + runner = GLOBAL_RUN_CONTEXT.runner + if runner.is_guest: + old_next_deadline = runner.deadlines.next_deadline() + if old != inf: + runner.deadlines.remove(old, self) + if new != inf: + runner.deadlines.add(new, self) + if runner.is_guest: + new_next_deadline = runner.deadlines.next_deadline() + if old_next_deadline != 
new_next_deadline: + runner.force_guest_tick_asap() + + @property + def deadline(self) -> float: + """Read-write, :class:`float`. An absolute time on the current + run's clock at which this scope will automatically become + cancelled. You can adjust the deadline by modifying this + attribute, e.g.:: + + # I need a little more time! + cancel_scope.deadline += 30 + + Note that for efficiency, the core run loop only checks for + expired deadlines every once in a while. This means that in + certain cases there may be a short delay between when the clock + says the deadline should have expired, and when checkpoints + start raising :exc:`~trio.Cancelled`. This is a very obscure + corner case that you're unlikely to notice, but we document it + for completeness. (If this *does* cause problems for you, of + course, then `we want to know! + <https://github.com/python-trio/trio/issues>`__) + + Defaults to :data:`math.inf`, which means "no deadline", though + this can be overridden by the ``deadline=`` argument to + the :class:`~trio.CancelScope` constructor. + """ + if self._relative_deadline != inf: + assert self._deadline == inf + warnings.warn( + DeprecationWarning( + "unentered relative cancel scope does not have an absolute deadline. Use `.relative_deadline`", + ), + stacklevel=2, + ) + return current_time() + self._relative_deadline + return self._deadline + + @deadline.setter + def deadline(self, new_deadline: float) -> None: + if isnan(new_deadline): + raise ValueError("deadline must not be NaN") + if self._relative_deadline != inf: + assert self._deadline == inf + warnings.warn( + DeprecationWarning( + "unentered relative cancel scope does not have an absolute deadline. Transforming into an absolute cancel scope. 
First set `.relative_deadline = math.inf` if you do want an absolute cancel scope.", + ), + stacklevel=2, + ) + self._relative_deadline = inf + with self._might_change_registered_deadline(): + self._deadline = float(new_deadline) + + @property + def relative_deadline(self) -> float: + """Read-write, :class:`float`. The number of seconds remaining until this + scope's deadline, relative to the current time. + + Defaults to :data:`math.inf` ("no deadline"). Must be non-negative. + + When modified + Before entering: sets the deadline relative to when the scope enters. + After entering: sets a new deadline relative to the current time. + + Raises: + RuntimeError: if trying to read or modify an unentered scope with an absolute deadline, i.e. when :attr:`is_relative` is ``False``. + """ + if self._has_been_entered: + return self._deadline - current_time() + elif self._deadline != inf: + assert self._relative_deadline == inf + raise RuntimeError( + "unentered non-relative cancel scope does not have a relative deadline", + ) + return self._relative_deadline + + @relative_deadline.setter + def relative_deadline(self, new_relative_deadline: float) -> None: + if isnan(new_relative_deadline): + raise ValueError("relative deadline must not be NaN") + if new_relative_deadline < 0: + raise ValueError("relative deadline must be non-negative") + if self._has_been_entered: + with self._might_change_registered_deadline(): + self._deadline = current_time() + float(new_relative_deadline) + elif self._deadline != inf: + assert self._relative_deadline == inf + raise RuntimeError( + "unentered non-relative cancel scope does not have a relative deadline", + ) + else: + self._relative_deadline = new_relative_deadline + + @property + def is_relative(self) -> bool | None: + """Returns None after entering. 
Returns False if both deadline and + relative_deadline are inf.""" + assert not (self._deadline != inf and self._relative_deadline != inf) + if self._has_been_entered: + return None + return self._relative_deadline != inf + + @property + def shield(self) -> bool: + """Read-write, :class:`bool`, default :data:`False`. So long as + this is set to :data:`True`, then the code inside this scope + will not receive :exc:`~trio.Cancelled` exceptions from scopes + that are outside this scope. They can still receive + :exc:`~trio.Cancelled` exceptions from (1) this scope, or (2) + scopes inside this scope. You can modify this attribute:: + + with trio.CancelScope() as cancel_scope: + cancel_scope.shield = True + # This cannot be interrupted by any means short of + # killing the process: + await sleep(10) + + cancel_scope.shield = False + # Now this can be cancelled normally: + await sleep(10) + + Defaults to :data:`False`, though this can be overridden by the + ``shield=`` argument to the :class:`~trio.CancelScope` constructor. + """ + return self._shield + + @shield.setter + @enable_ki_protection + def shield(self, new_value: bool) -> None: + if not isinstance(new_value, bool): + raise TypeError("shield must be a bool") + self._shield = new_value + if self._cancel_status is not None: + self._cancel_status.recalculate() + + @enable_ki_protection + def _cancel(self, cancel_reason: CancelReason | None) -> None: + """Internal sources of cancellation should use this instead of :meth:`cancel` + in order to set a more detailed :class:`CancelReason`. + Helper or high-level functions can use `cancel`. + """ + if self._cancel_called: + return + + if self._cancel_reason is None: + self._cancel_reason = cancel_reason + + with self._might_change_registered_deadline(): + self._cancel_called = True + + if self._cancel_status is not None: + self._cancel_status.recalculate() + + @enable_ki_protection + def cancel(self, reason: str | None = None) -> None: + """Cancels this scope immediately. 
+ + The optional ``reason`` argument accepts a string, which will be attached to + any resulting :exc:`Cancelled` exception to help you understand where that + cancellation is coming from and why it happened. + + This method is idempotent, i.e., if the scope was already + cancelled then this method silently does nothing. + """ + try: + current_task = repr(_core.current_task()) + except RuntimeError: + current_task = None + self._cancel( + CancelReason(reason=reason, source="explicit", source_task=current_task) + ) + + @property + def cancel_called(self) -> bool: + """Readonly :class:`bool`. Records whether cancellation has been + requested for this scope, either by an explicit call to + :meth:`cancel` or by the deadline expiring. + + This attribute being True does *not* necessarily mean that the + code within the scope has been, or will be, affected by the + cancellation. For example, if :meth:`cancel` was called after + the last checkpoint in the ``with`` block, when it's too late to + deliver a :exc:`~trio.Cancelled` exception, then this attribute + will still be True. + + This attribute is mostly useful for debugging and introspection. + If you want to know whether or not a chunk of code was actually + cancelled, then :attr:`cancelled_caught` is usually more + appropriate. + """ + if ( # noqa: SIM102 # collapsible-if but this way is nicer + self._cancel_status is not None or not self._has_been_entered + ): + # Scope is active or not yet entered: make sure cancel_called + # is true if the deadline has passed. This shouldn't + # be able to actually change behavior, since we check for + # deadline expiry on scope entry and at every checkpoint, + # but it makes the value returned by cancel_called more + # closely match expectations. 
+ if not self._cancel_called and current_time() >= self._deadline: + self._cancel(CancelReason(source="deadline")) + return self._cancel_called + + +################################################################ +# Nursery and friends +################################################################ + + +class TaskStatus(Protocol[StatusT_contra]): + """The interface provided by :meth:`Nursery.start()` to the spawned task. + + This is provided via the ``task_status`` keyword-only parameter. + """ + + @overload + def started(self: TaskStatus[None]) -> None: ... + + @overload + def started(self, value: StatusT_contra) -> None: ... + + def started(self, value: StatusT_contra | None = None) -> None: + """Tasks call this method to indicate that they have initialized. + + See `nursery.start() <trio.Nursery.start>` for more information. + """ + + +# This code needs to be read alongside the code from Nursery.start to make +# sense. [email protected](eq=False, repr=False, slots=False) +class _TaskStatus(TaskStatus[StatusT]): + _old_nursery: Nursery + _new_nursery: Nursery + # NoStatus is a sentinel. + _value: StatusT | type[_NoStatus] = _NoStatus + + def __repr__(self) -> str: + return f"<Task status object at {id(self):#x}>" + + @overload + def started(self: _TaskStatus[None]) -> None: ... + + @overload + def started(self: _TaskStatus[StatusT], value: StatusT) -> None: ... + + def started(self, value: StatusT | None = None) -> None: + if self._value is not _NoStatus: + raise RuntimeError("called 'started' twice on the same task status") + self._value = cast("StatusT", value) # If None, StatusT == None + + # If the old nursery is cancelled, then quietly quit now; the child + # will eventually exit on its own, and we don't want to risk moving + # children that might have propagating Cancelled exceptions into + # a place with no cancelled cancel scopes to catch them. 
+ assert self._old_nursery._cancel_status is not None + if self._old_nursery._cancel_status.effectively_cancelled: + return + + # Can't be closed, b/c we checked in start() and then _pending_starts + # should keep it open. + assert not self._new_nursery._closed + + # Move tasks from the old nursery to the new + tasks = self._old_nursery._children + self._old_nursery._children = set() + for task in tasks: + task._parent_nursery = self._new_nursery + task._eventual_parent_nursery = None + self._new_nursery._children.add(task) + + # Move all children of the old nursery's cancel status object + # to be underneath the new nursery instead. This includes both + # tasks and child cancel status objects. + # NB: If the new nursery is cancelled, reparenting a cancel + # status to be underneath it can invoke an abort_fn, which might + # do something evil like cancel the old nursery. We thus break + # everything off from the old nursery before we start attaching + # anything to the new. + cancel_status_children = self._old_nursery._cancel_status.children + cancel_status_tasks = set(self._old_nursery._cancel_status.tasks) + cancel_status_tasks.discard(self._old_nursery._parent_task) + for cancel_status in cancel_status_children: + cancel_status.parent = None + for task in cancel_status_tasks: + task._activate_cancel_status(None) + for cancel_status in cancel_status_children: + cancel_status.parent = self._new_nursery._cancel_status + for task in cancel_status_tasks: + task._activate_cancel_status(self._new_nursery._cancel_status) + + # That should have removed all the children from the old nursery + assert not self._old_nursery._children + + # And finally, poke the old nursery so it notices that all its + # children have disappeared and can exit. + self._old_nursery._check_nursery_closed() + + [email protected](slots=False) +class NurseryManager: + """Nursery context manager. 
+ + Note we explicitly avoid @asynccontextmanager and @async_generator + since they add a lot of extraneous stack frames to exceptions, as + well as cause problematic behavior with handling of StopIteration + and StopAsyncIteration. + + """ + + strict_exception_groups: bool = True + + @enable_ki_protection + async def __aenter__(self) -> Nursery: + self._scope = CancelScope() + self._scope.__enter__() + self._nursery = Nursery._create( + current_task(), + self._scope, + self.strict_exception_groups, + ) + return self._nursery + + @enable_ki_protection + async def __aexit__( + self, + etype: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: + new_exc = await self._nursery._nested_child_finished(exc) + # Tracebacks show the 'raise' line below out of context, so let's give + # this variable a name that makes sense out of context. + combined_error_from_nursery = self._scope._close(new_exc) + if combined_error_from_nursery is None: + return True + elif combined_error_from_nursery is exc: + return False + else: + # Copied verbatim from the old MultiErrorCatcher. Python doesn't + # allow us to encapsulate this __context__ fixup. 
+ old_context = combined_error_from_nursery.__context__ + try: + raise combined_error_from_nursery + finally: + _, value, _ = sys.exc_info() + assert value is combined_error_from_nursery + value.__context__ = old_context + # delete references from locals to avoid creating cycles + # see test_cancel_scope_exit_doesnt_create_cyclic_garbage + del _, combined_error_from_nursery, value, new_exc + + # make sure these raise errors in static analysis if called + if not TYPE_CHECKING: + + def __enter__(self) -> NoReturn: + raise RuntimeError( + "use 'async with open_nursery(...)', not 'with open_nursery(...)'", + ) + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> NoReturn: # pragma: no cover + raise AssertionError("Never called, but should be defined") + + +def open_nursery( + strict_exception_groups: bool | None = None, +) -> AbstractAsyncContextManager[Nursery]: + """Returns an async context manager which must be used to create a + new `Nursery`. + + It does not block on entry; on exit it blocks until all child tasks + have exited. If no child tasks are running on exit, it will insert a + schedule point (but no cancellation point) - equivalent to + :func:`trio.lowlevel.cancel_shielded_checkpoint`. This means a nursery + is never the source of a cancellation exception, it only propagates it + from sub-tasks. + + Args: + strict_exception_groups (bool): Unless set to False, even a single raised exception + will be wrapped in an exception group. If not specified, uses the value passed + to :func:`run`, which defaults to true. Setting it to False will be deprecated + and ultimately removed in a future version of Trio. + + """ + # only warn if explicitly set to falsy, not if we get it from the global context. 
+ if strict_exception_groups is not None and not strict_exception_groups: + warn_deprecated( + "open_nursery(strict_exception_groups=False)", + version="0.25.0", + issue=2929, + instead=( + "the default value of True and rewrite exception handlers to handle ExceptionGroups. " + "See https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors" + ), + use_triodeprecationwarning=True, + ) + + if strict_exception_groups is None: + strict_exception_groups = GLOBAL_RUN_CONTEXT.runner.strict_exception_groups + + return NurseryManager(strict_exception_groups=strict_exception_groups) + + +@final +class Nursery(metaclass=NoPublicConstructor): + """A context which may be used to spawn (or cancel) child tasks. + + Not constructed directly, use `open_nursery` instead. + + The nursery will remain open until all child tasks have completed, + or until it is cancelled, at which point it will cancel all its + remaining child tasks and close. + + Nurseries ensure the absence of orphaned Tasks, since all running + tasks will belong to an open Nursery. + + Attributes: + cancel_scope: + Creating a nursery also implicitly creates a cancellation scope, + which is exposed as the :attr:`cancel_scope` attribute. This is + used internally to implement the logic where if an error occurs + then ``__aexit__`` cancels all children, but you can use it for + other things, e.g. if you want to explicitly cancel all children + in response to some external event. + """ + + def __init__( + self, + parent_task: Task, + cancel_scope: CancelScope, + strict_exception_groups: bool, + ) -> None: + self._parent_task = parent_task + self._strict_exception_groups = strict_exception_groups + parent_task._child_nurseries.append(self) + # the cancel status that children inherit - we take a snapshot, so it + # won't be affected by any changes in the parent. 
+ self._cancel_status = parent_task._cancel_status + # the cancel scope that directly surrounds us; used for cancelling all + # children. + self.cancel_scope = cancel_scope + assert self.cancel_scope._cancel_status is self._cancel_status + self._children: set[Task] = set() + self._pending_excs: list[BaseException] = [] + # The "nested child" is how this code refers to the contents of the + # nursery's 'async with' block, which acts like a child Task in all + # the ways we can make it. + self._nested_child_running = True + self._parent_waiting_in_aexit = False + self._pending_starts = 0 + self._closed = False + + @property + def child_tasks(self) -> frozenset[Task]: + """(`frozenset`): Contains all the child :class:`~trio.lowlevel.Task` + objects which are still running.""" + return frozenset(self._children) + + @property + def parent_task(self) -> Task: + "(`~trio.lowlevel.Task`): The Task that opened this nursery." + return self._parent_task + + def _add_exc(self, exc: BaseException, reason: CancelReason | None) -> None: + self._pending_excs.append(exc) + self.cancel_scope._cancel(reason) + + def _check_nursery_closed(self) -> None: + if not any([self._nested_child_running, self._children, self._pending_starts]): + self._closed = True + if self._parent_waiting_in_aexit: + self._parent_waiting_in_aexit = False + GLOBAL_RUN_CONTEXT.runner.reschedule(self._parent_task) + + def _child_finished( + self, + task: Task, + outcome: Outcome[object], + ) -> None: + self._children.remove(task) + if self._closed and not hasattr(self, "_pending_excs"): + # We're abandoned by misnested nurseries, the result of the task is lost. 
+ return + if isinstance(outcome, Error): + self._add_exc( + outcome.error, + CancelReason( + source="nursery", + source_task=repr(task), + reason=f"child task raised exception {outcome.error!r}", + ), + ) + self._check_nursery_closed() + + async def _nested_child_finished( + self, + nested_child_exc: BaseException | None, + ) -> BaseException | None: + # Returns ExceptionGroup instance (or any exception if the nursery is in loose mode + # and there is just one contained exception) if there are pending exceptions + if nested_child_exc is not None: + self._add_exc( + nested_child_exc, + reason=CancelReason( + source="nursery", + source_task=repr(self._parent_task), + reason=f"Code block inside nursery contextmanager raised exception {nested_child_exc!r}", + ), + ) + self._nested_child_running = False + self._check_nursery_closed() + + if not self._closed: + # If we have a KeyboardInterrupt injected, we want to save it in + # the nursery's final exceptions list. But if it's just a + # Cancelled, then we don't -- see gh-1457. + def aborted(raise_cancel: _core.RaiseCancelT) -> Abort: + exn = capture(raise_cancel).error + if not isinstance(exn, Cancelled): + self._add_exc( + exn, + CancelReason( + source="KeyboardInterrupt", + source_task=repr(self._parent_task), + ), + ) + # see test_cancel_scope_exit_doesnt_create_cyclic_garbage + del exn # prevent cyclic garbage creation + return Abort.FAILED + + self._parent_waiting_in_aexit = True + await wait_task_rescheduled(aborted) + else: + # Nothing to wait for, so execute a schedule point, but don't + # allow us to be cancelled, just like the other branch. We + # still need to catch and store non-Cancelled exceptions. + try: + await cancel_shielded_checkpoint() + except BaseException as exc: + # there's no children to cancel, so don't need to supply cancel reason + self._add_exc(exc, reason=None) + + popped = self._parent_task._child_nurseries.pop() + assert popped is self, "Nursery misnesting detected!" 
+ if self._pending_excs: + try: + if not self._strict_exception_groups and len(self._pending_excs) == 1: + return self._pending_excs[0] + exception = BaseExceptionGroup( + "Exceptions from Trio nursery", + self._pending_excs, + ) + if not self._strict_exception_groups: + exception.add_note(NONSTRICT_EXCEPTIONGROUP_NOTE) + return exception + finally: + # avoid a garbage cycle + # (see test_locals_destroyed_promptly_on_cancel) + del self._pending_excs + return None + + def start_soon( + self, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + *args: Unpack[PosArgT], + name: object = None, + ) -> None: + """Creates a child task, scheduling ``await async_fn(*args)``. + + If you want to run a function and immediately wait for its result, + then you don't need a nursery; just use ``await async_fn(*args)``. + If you want to wait for the task to initialize itself before + continuing, see :meth:`start`, the other fundamental method for + creating concurrent tasks in Trio. + + Note that this is *not* an async function and you don't use await + when calling it. It sets up the new task, but then returns + immediately, *before* the new task has a chance to do anything. + New tasks may start running in any order, and at any checkpoint the + scheduler chooses - at latest when the nursery is waiting to exit. + + It's possible to pass a nursery object into another task, which + allows that task to start new child tasks in the first task's + nursery. + + The child task inherits its parent nursery's cancel scopes. + + Args: + async_fn: An async callable. + args: Positional arguments for ``async_fn``. If you want + to pass keyword arguments, use + :func:`functools.partial`. + name: The name for this task. Only used for + debugging/introspection + (e.g. ``repr(task_obj)``). If this isn't a string, + :meth:`start_soon` will try to make it one. 
A + common use case is if you're wrapping a function + before spawning a new task, you might pass the + original function as the ``name=`` to make + debugging easier. + + Raises: + RuntimeError: If this nursery is no longer open + (i.e. its ``async with`` block has + exited). + """ + GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name) + + # Typing changes blocked by https://github.com/python/mypy/pull/17512 + async def start( # type: ignore[explicit-any] + self, + async_fn: Callable[..., Awaitable[object]], + *args: object, + name: object = None, + ) -> Any: + r"""Creates and initializes a child task. + + Like :meth:`start_soon`, but blocks until the new task has + finished initializing itself, and optionally returns some + information from it. + + The ``async_fn`` must accept a ``task_status`` keyword argument, + and it must make sure that it (or someone) eventually calls + :meth:`task_status.started() <TaskStatus.started>`. + + The conventional way to define ``async_fn`` is like:: + + async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): + ... # Caller is blocked waiting for this code to run + task_status.started() + ... # This async code can be interleaved with the caller + + :attr:`trio.TASK_STATUS_IGNORED` is a special global object with + a do-nothing ``started`` method. This way your function supports + being called either like ``await nursery.start(async_fn, arg1, + arg2)`` or directly like ``await async_fn(arg1, arg2)``, and + either way it can call :meth:`task_status.started() <TaskStatus.started>` + without worrying about which mode it's in. Defining your function like + this will make it obvious to readers that it supports being used + in both modes. + + Before the child calls :meth:`task_status.started() <TaskStatus.started>`, + it's effectively run underneath the call to :meth:`start`: if it + raises an exception then that exception is reported by + :meth:`start`, and does *not* propagate out of the nursery. 
If + :meth:`start` is cancelled, then the child task is also + cancelled. + + When the child calls :meth:`task_status.started() <TaskStatus.started>`, + it's moved out from underneath :meth:`start` and into the given nursery. + + If the child task passes a value to :meth:`task_status.started(value) <TaskStatus.started>`, + then :meth:`start` returns this value. Otherwise, it returns ``None``. + """ + if self._closed: + raise RuntimeError("Nursery is closed to new arrivals") + try: + self._pending_starts += 1 + # wrap internal nursery in try-except to unroll any exceptiongroups + # to avoid wrapping pre-started() exceptions in an extra ExceptionGroup. + # See #2611. + try: + # set strict_exception_groups = True to make sure we always unwrap + # *this* nursery's exceptiongroup + async with open_nursery(strict_exception_groups=True) as old_nursery: + task_status: _TaskStatus[object | None] = _TaskStatus( + old_nursery, + self, + ) + thunk = functools.partial(async_fn, task_status=task_status) + task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( + thunk, + args, + old_nursery, + name, + ) + task._eventual_parent_nursery = self + # Wait for either TaskStatus.started or an exception to + # cancel this nursery: + except BaseExceptionGroup as exc: + if len(exc.exceptions) == 1: + raise exc.exceptions[0] from None + raise TrioInternalError( + "Internal nursery should not have multiple tasks. This can be " + 'caused by the user managing to access the "old" nursery in ' + "`task_status` and spawning tasks in it.", + ) from exc + + # If we get here, then the child either got reparented or exited + # normally. The complicated logic is all in TaskStatus.started(). + # (Any exceptions propagate directly out of the above.) 
+ if task_status._value is _NoStatus: + raise RuntimeError("child exited without calling task_status.started()") + return task_status._value + finally: + self._pending_starts -= 1 + self._check_nursery_closed() + + def __del__(self) -> None: + assert not self._children + + +################################################################ +# Task and friends +################################################################ + + +@final [email protected](eq=False, repr=False) +class Task(metaclass=NoPublicConstructor): # type: ignore[explicit-any] + _parent_nursery: Nursery | None + coro: types.CoroutineType[Any, Outcome[object], Any] # type: ignore[explicit-any] + _runner: Runner + name: str + context: contextvars.Context + _counter: int = attrs.field(init=False, factory=itertools.count().__next__) + _ki_protected: bool + + # Invariant: + # - for unscheduled tasks, _next_send_fn and _next_send are both None + # - for scheduled tasks, _next_send_fn(_next_send) resumes the task; + # usually _next_send_fn is self.coro.send and _next_send is an + # Outcome. When recovering from a foreign await, _next_send_fn is + # self.coro.throw and _next_send is an exception. _next_send_fn + # will effectively be at the top of every task's call stack, so + # it should be written in C if you don't want to pollute Trio + # tracebacks with extraneous frames. + # - for scheduled tasks, custom_sleep_data is None + # Tasks start out unscheduled. 
+ _next_send_fn: Callable[[Any], object] | None = None # type: ignore[explicit-any] + _next_send: Outcome[Any] | BaseException | None = None # type: ignore[explicit-any] + _abort_func: Callable[[_core.RaiseCancelT], Abort] | None = None + custom_sleep_data: Any = None # type: ignore[explicit-any] + + # For introspection and nursery.start() + _child_nurseries: list[Nursery] = attrs.Factory(list) + _eventual_parent_nursery: Nursery | None = None + + # these are counts of how many cancel/schedule points this task has + # executed, for assert{_no,}_checkpoints + # XX maybe these should be exposed as part of a statistics() method? + _cancel_points: int = 0 + _schedule_points: int = 0 + + def __repr__(self) -> str: + return f"<Task {self.name!r} at {id(self):#x}>" + + @property + def parent_nursery(self) -> Nursery | None: + """The nursery this task is inside (or None if this is the "init" + task). + + Example use case: drawing a visualization of the task tree in a + debugger. + + """ + return self._parent_nursery + + @property + def eventual_parent_nursery(self) -> Nursery | None: + """The nursery this task will be inside after it calls + ``task_status.started()``. + + If this task has already called ``started()``, or if it was not + spawned using `nursery.start() <trio.Nursery.start>`, then + its `eventual_parent_nursery` is ``None``. + + """ + return self._eventual_parent_nursery + + @property + def child_nurseries(self) -> list[Nursery]: + """The nurseries this task contains. + + This is a list, with outer nurseries before inner nurseries. + + """ + return list(self._child_nurseries) + + def iter_await_frames(self) -> Iterator[tuple[types.FrameType, int]]: + """Iterates recursively over the coroutine-like objects this + task is waiting on, yielding the frame and line number at each + frame. + + This is similar to `traceback.walk_stack` in a synchronous + context. 
Note that `traceback.walk_stack` returns frames from + the bottom of the call stack to the top, while this function + starts from `Task.coro <trio.lowlevel.Task.coro>` and works it + way down. + + Example usage: extracting a stack trace:: + + import traceback + + def print_stack_for_task(task): + ss = traceback.StackSummary.extract(task.iter_await_frames()) + print("".join(ss.format())) + + """ + # Ignore static typing as we're doing lots of dynamic introspection + coro: Any = self.coro # type: ignore[explicit-any] + while coro is not None: + if hasattr(coro, "cr_frame"): + # A real coroutine + yield coro.cr_frame, coro.cr_frame.f_lineno + coro = coro.cr_await + elif hasattr(coro, "gi_frame"): + # A generator decorated with @types.coroutine + yield coro.gi_frame, coro.gi_frame.f_lineno + coro = coro.gi_yieldfrom + elif coro.__class__.__name__ in [ + "async_generator_athrow", + "async_generator_asend", + ]: + # cannot extract the generator directly, see https://github.com/python/cpython/issues/76991 + # we can however use the gc to look through the object + for referent in gc.get_referents(coro): + if hasattr(referent, "ag_frame"): # pragma: no branch + yield referent.ag_frame, referent.ag_frame.f_lineno + coro = referent.ag_await + break + else: # pragma: no cover + # either cpython changed or we are running on an alternative python implementation + return + else: # pragma: no cover + return + + ################ + # Cancellation + ################ + + # The CancelStatus object that is currently active for this task. + # Don't change this directly; instead, use _activate_cancel_status(). + # This can be None, but only in the init task. 
+ _cancel_status: CancelStatus = attrs.field(default=None, repr=False) + + def _activate_cancel_status(self, cancel_status: CancelStatus | None) -> None: + if self._cancel_status is not None: + self._cancel_status._tasks.remove(self) + self._cancel_status = cancel_status # type: ignore[assignment] + if self._cancel_status is not None: + self._cancel_status._tasks.add(self) + if self._cancel_status.effectively_cancelled: + self._attempt_delivery_of_any_pending_cancel() + + def _attempt_abort(self, raise_cancel: _core.RaiseCancelT) -> None: + # Either the abort succeeds, in which case we will reschedule the + # task, or else it fails, in which case it will worry about + # rescheduling itself (hopefully eventually calling reraise to raise + # the given exception, but not necessarily). + + # This is only called by the functions immediately below, which both check + # `self.abort_func is not None`. + assert self._abort_func is not None, "FATAL INTERNAL ERROR" + + success = self._abort_func(raise_cancel) + if type(success) is not Abort: + raise TrioInternalError("abort function must return Abort enum") + # We only attempt to abort once per blocking call, regardless of + # whether we succeeded or failed. 
+ self._abort_func = None + if success is Abort.SUCCEEDED: + self._runner.reschedule(self, capture(raise_cancel)) + + def _attempt_delivery_of_any_pending_cancel(self) -> None: + if self._abort_func is None: + return + if not self._cancel_status.effectively_cancelled: + return + + reason = self._cancel_status._scope._cancel_reason + + def raise_cancel() -> NoReturn: + if reason is None: + raise Cancelled._create(source="unknown", reason="misnesting") + else: + raise Cancelled._create( + source=reason.source, + reason=reason.reason, + source_task=reason.source_task, + ) + + self._attempt_abort(raise_cancel) + + def _attempt_delivery_of_pending_ki(self) -> None: + assert self._runner.ki_pending + if self._abort_func is None: + return + + def raise_cancel() -> NoReturn: + self._runner.ki_pending = False + raise KeyboardInterrupt + + self._attempt_abort(raise_cancel) + + +################################################################ +# The central Runner object +################################################################ + + +class RunStatistics: + """An object containing run-loop-level debugging information. + + Currently, the following fields are defined: + + * ``tasks_living`` (int): The number of tasks that have been spawned + and not yet exited. + * ``tasks_runnable`` (int): The number of tasks that are currently + queued on the run queue (as opposed to blocked waiting for something + to happen). + * ``seconds_to_next_deadline`` (float): The time until the next + pending cancel scope deadline. May be negative if the deadline has + expired but we haven't yet processed cancellations. May be + :data:`~math.inf` if there are no pending deadlines. + * ``run_sync_soon_queue_size`` (int): The number of + unprocessed callbacks queued via + :meth:`trio.lowlevel.TrioToken.run_sync_soon`. + * ``io_statistics`` (object): Some statistics from Trio's I/O + backend. 
This always has an attribute ``backend`` which is a string + naming which operating-system-specific I/O backend is in use; the + other attributes vary between backends. + """ + + tasks_living: int + tasks_runnable: int + seconds_to_next_deadline: float + io_statistics: IOStatistics + run_sync_soon_queue_size: int + + +# This holds all the state that gets trampolined back and forth between +# callbacks when we're running in guest mode. +# +# It has to be a separate object from Runner, and Runner *cannot* hold +# references to it (directly or indirectly)! +# +# The idea is that we want a chance to detect if our host loop quits and stops +# driving us forward. We detect that by unrolled_run_gen being garbage +# collected, and hitting its 'except GeneratorExit:' block. So this only +# happens if unrolled_run_gen is GCed. +# +# The Runner state is referenced from the global GLOBAL_RUN_CONTEXT. The only +# way it gets *un*referenced is by unrolled_run_gen completing, e.g. by being +# GCed. But if Runner has a direct or indirect reference to it, and the host +# loop has abandoned it, then this will never happen! +# +# So this object can reference Runner, but Runner can't reference it. The only +# references to it are the "in flight" callback chain on the host loop / +# worker thread. 
@attrs.define(eq=False)
class GuestState:  # type: ignore[explicit-any]
    """Driver for a guest-mode run.

    Holds the ``unrolled_run`` generator together with the host-loop
    callbacks, and advances the generator one step per :meth:`guest_tick`.
    """

    runner: Runner
    run_sync_soon_threadsafe: Callable[[Callable[[], object]], object]
    run_sync_soon_not_threadsafe: Callable[[Callable[[], object]], object]
    done_callback: Callable[[Outcome[Any]], object]  # type: ignore[explicit-any]
    unrolled_run_gen: Generator[float, EventResult, None]
    unrolled_run_next_send: Outcome[Any] = attrs.Factory(lambda: Value(None))  # type: ignore[explicit-any]

    def guest_tick(self) -> None:
        """Advance the scheduler generator by one step and arrange the next tick.

        The pending outcome is sent into ``unrolled_run_gen``. If the
        generator finishes, ``done_callback`` receives
        ``runner.main_task_outcome``; if it raises ``TrioInternalError``,
        ``done_callback`` receives that error instead. Otherwise the
        generator has yielded a timeout, and the I/O wait is either satisfied
        immediately (non-blocking poll) or pushed to a worker thread.
        """
        prev_library, sniffio_library.name = sniffio_library.name, "trio"
        try:
            timeout = self.unrolled_run_next_send.send(self.unrolled_run_gen)
        except StopIteration:
            assert self.runner.main_task_outcome is not None
            self.done_callback(self.runner.main_task_outcome)
            return
        except TrioInternalError as exc:
            self.done_callback(Error(exc))
            return
        finally:
            sniffio_library.name = prev_library

        # Optimization: try to skip going into the thread if we can avoid it
        events_outcome: Value[EventResult] | Error = capture(
            self.runner.io_manager.get_events,
            0,
        )
        if timeout <= 0 or isinstance(events_outcome, Error) or events_outcome.value:
            # No need to go into the thread
            self.unrolled_run_next_send = events_outcome
            self.runner.guest_tick_scheduled = True
            self.run_sync_soon_not_threadsafe(self.guest_tick)
        else:
            # Need to go into the thread and call get_events() there
            self.runner.guest_tick_scheduled = False

            def get_events() -> EventResult:
                return self.runner.io_manager.get_events(timeout)

            def deliver(events_outcome: Outcome[EventResult]) -> None:
                def in_main_thread() -> None:
                    self.unrolled_run_next_send = events_outcome
                    self.runner.guest_tick_scheduled = True
                    self.guest_tick()

                self.run_sync_soon_threadsafe(in_main_thread)

            start_thread_soon(get_events, deliver)


@attrs.define(eq=False)
class Runner:  # type: ignore[explicit-any]
    """Scheduler state for a single run.

    Bundles the clock, instruments, I/O manager and KI manager with the run
    queue, the set of live tasks, pending deadlines, and the bookkeeping for
    the init/main tasks, KeyboardInterrupt delivery, and guest mode.
    """

    clock: Clock
    instruments: Instruments
    io_manager: TheIOManager
    ki_manager: KIManager
    strict_exception_groups: bool

    # Run-local values, see _local.py
    _locals: dict[_core.RunVar[Any], object] = attrs.Factory(dict)  # type: ignore[explicit-any]

    runq: deque[Task] = attrs.Factory(deque)
    tasks: set[Task] = attrs.Factory(set)

    deadlines: Deadlines = attrs.Factory(Deadlines)

    init_task: Task | None = None
    system_nursery: Nursery | None = None
    system_context: contextvars.Context = attrs.field(kw_only=True)
    main_task: Task | None = None
    main_task_outcome: Outcome[object] | None = None

    entry_queue: EntryQueue = attrs.Factory(EntryQueue)
    trio_token: TrioToken | None = None
    asyncgens: AsyncGenerators = attrs.Factory(AsyncGenerators)

    # If everything goes idle for this long, we call clock._autojump()
    clock_autojump_threshold: float = inf

    # Guest mode stuff
    is_guest: bool = False
    guest_tick_scheduled: bool = False

    def force_guest_tick_asap(self) -> None:
        """Wake the I/O manager so a guest tick happens promptly.

        Does nothing when a guest tick is already scheduled; otherwise marks
        one as scheduled and calls ``io_manager.force_wakeup()``.
        """
        if self.guest_tick_scheduled:
            return
        self.guest_tick_scheduled = True
        self.io_manager.force_wakeup()

    def close(self) -> None:
        """Tear down the runner's resources at the end of a run.

        Closes the I/O manager, the entry queue and the async-generator
        tracker, fires the ``after_run`` instrument hook if one is
        registered, and closes the KI manager last.
        """
        self.io_manager.close()
        self.entry_queue.close()
        self.asyncgens.close()
        if "after_run" in self.instruments:
            self.instruments.call("after_run")
        # This is where KI protection gets disabled, so we do it last
        self.ki_manager.close()

    @_public
    def current_statistics(self) -> RunStatistics:
        """Returns ``RunStatistics``, which contains run-loop-level debugging information.

        Currently, the following fields are defined:

        * ``tasks_living`` (int): The number of tasks that have been spawned
          and not yet exited.
        * ``tasks_runnable`` (int): The number of tasks that are currently
          queued on the run queue (as opposed to blocked waiting for something
          to happen).
        * ``seconds_to_next_deadline`` (float): The time until the next
          pending cancel scope deadline. May be negative if the deadline has
          expired but we haven't yet processed cancellations. May be
          :data:`~math.inf` if there are no pending deadlines.
        * ``run_sync_soon_queue_size`` (int): The number of
          unprocessed callbacks queued via
          :meth:`trio.lowlevel.TrioToken.run_sync_soon`.
        * ``io_statistics`` (object): Some statistics from Trio's I/O
          backend. This always has an attribute ``backend`` which is a string
          naming which operating-system-specific I/O backend is in use; the
          other attributes vary between backends.

        """
        seconds_to_next_deadline = self.deadlines.next_deadline() - self.current_time()
        return RunStatistics(
            tasks_living=len(self.tasks),
            tasks_runnable=len(self.runq),
            seconds_to_next_deadline=seconds_to_next_deadline,
            io_statistics=self.io_manager.statistics(),
            run_sync_soon_queue_size=self.entry_queue.size(),
        )

    @_public
    def current_time(self) -> float:
        """Returns the current time according to Trio's internal clock.

        Returns:
            float: The current time.

        Raises:
            RuntimeError: if not inside a call to :func:`trio.run`.

        """
        return self.clock.current_time()

    @_public
    def current_clock(self) -> Clock:
        """Returns the current :class:`~trio.abc.Clock`."""
        return self.clock

    @_public
    def current_root_task(self) -> Task | None:
        """Returns the current root :class:`Task`.

        This is the task that is the ultimate parent of all other tasks.

        """
        return self.init_task

    ################
    # Core task handling primitives
    ################

    @_public
    def reschedule(self, task: Task, next_send: Outcome[object] = _NO_SEND) -> None:
        """Reschedule the given task with the given
        :class:`outcome.Outcome`.

        See :func:`wait_task_rescheduled` for the gory details.

        There must be exactly one call to :func:`reschedule` for every call to
        :func:`wait_task_rescheduled`. (And when counting, keep in mind that
        returning :data:`Abort.SUCCEEDED` from an abort callback is equivalent
        to calling :func:`reschedule` once.)

        Args:
          task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked
              in a call to :func:`wait_task_rescheduled`.
          next_send (outcome.Outcome): the value (or error) to return (or
              raise) from :func:`wait_task_rescheduled`.

        """
        if next_send is _NO_SEND:
            next_send = Value(None)

        assert task._runner is self
        assert task._next_send_fn is None
        task._next_send_fn = task.coro.send
        task._next_send = next_send
        task._abort_func = None
        task.custom_sleep_data = None
        # In guest mode, an empty run queue means the scheduler may be parked
        # waiting for I/O, so it has to be woken up to notice the new task.
        if not self.runq and self.is_guest:
            self.force_guest_tick_asap()
        self.runq.append(task)
        if "task_scheduled" in self.instruments:
            self.instruments.call("task_scheduled", task)

    def spawn_impl(
        self,
        async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
        args: tuple[Unpack[PosArgT]],
        nursery: Nursery | None,
        name: object,
        *,
        system_task: bool = False,
        context: contextvars.Context | None = None,
    ) -> Task:
        """Create a task for ``async_fn(*args)`` and put it on the run queue.

        Args:
          async_fn: the async callable to run.
          args: positional arguments for ``async_fn``.
          nursery: the parent nursery, or ``None`` for the init task.
          name: task name; if not a string it is derived from the callable
              (``functools.partial`` objects are unwrapped first).
          system_task: if true, the task is created KI-protected and, when no
              ``context`` is given, runs in a copy of ``system_context``.
          context: explicit ``contextvars.Context`` to run the task in.

        Returns:
          The newly created :class:`Task`.

        Raises:
          RuntimeError: if ``nursery`` is already closed.
        """
        ######
        # Make sure the nursery is in working order
        ######

        # This sorta feels like it should be a method on nursery, except it
        # has to handle nursery=None for init. And it touches the internals of
        # all kinds of objects.
        if nursery is not None and nursery._closed:
            raise RuntimeError("Nursery is closed to new arrivals")
        if nursery is None:
            assert self.init_task is None

        ######
        # Propagate contextvars
        ######
        if context is None:
            context = self.system_context.copy() if system_task else copy_context()

        ######
        # Call the function and get the coroutine object, while giving helpful
        # errors for common mistakes.
        ######
        # TypeVarTuple passed into ParamSpec function confuses Mypy.
        coro = context.run(coroutine_or_error, async_fn, *args)  # type: ignore[arg-type]

        if name is None:
            name = async_fn
        if isinstance(name, functools.partial):
            name = name.func
        if not isinstance(name, str):
            try:
                name = f"{name.__module__}.{name.__qualname__}"  # type: ignore[attr-defined]
            except AttributeError:
                name = repr(name)

        # very old Cython versions (<0.29.24) have the attribute, but with a value of None
        if getattr(coro, "cr_frame", None) is None:
            # This async function is implemented in C or Cython
            async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT:
                return await orig_coro

            coro = python_wrapper(coro)
        assert coro.cr_frame is not None, "Coroutine frame should exist"  # type: ignore[attr-defined]

        ######
        # Set up the Task object
        ######
        task = Task._create(
            coro=coro,
            parent_nursery=nursery,
            runner=self,
            name=name,
            context=context,
            ki_protected=system_task,
        )

        self.tasks.add(task)
        if nursery is not None:
            nursery._children.add(task)
            task._activate_cancel_status(nursery._cancel_status)

        if "task_spawned" in self.instruments:
            self.instruments.call("task_spawned", task)
        # Special case: normally next_send should be an Outcome, but for the
        # very first send we have to send a literal unboxed None.
        self.reschedule(task, None)  # type: ignore[arg-type]
        return task

    def task_exited(self, task: Task, outcome: Outcome[object]) -> None:
        """Do the bookkeeping for a task that has finished with ``outcome``.

        Child nurseries the task left open are cancelled and closed, parking
        lots registered for the task are broken, and a misnested cancel
        scope/nursery replaces the outcome with a ``RuntimeError``. The task
        is then removed from ``self.tasks``. The init task's outcome is
        unwrapped here; any other task reports to its parent nursery (the
        main task's real outcome is stored in ``main_task_outcome`` and the
        nursery sees ``Value(None)`` instead).
        """
        if task._child_nurseries:
            for nursery in task._child_nurseries:
                nursery.cancel_scope._cancel(
                    CancelReason(
                        source="nursery",
                        reason="Parent Task exited prematurely, abandoning this nursery without exiting it properly.",
                        source_task=repr(task),
                    )
                )
                nursery._closed = True

        # break parking lots associated with the exiting task
        if task in GLOBAL_PARKING_LOT_BREAKER:
            for lot in GLOBAL_PARKING_LOT_BREAKER[task]:
                lot.break_lot(task)
            del GLOBAL_PARKING_LOT_BREAKER[task]

        if (
            task._cancel_status is not None
            and task._cancel_status.abandoned_by_misnesting
            and task._cancel_status.parent is None
        ) or task._child_nurseries:
            reason = "Nursery" if task._child_nurseries else "Cancel scope"
            # The cancel scope surrounding this task's nursery was closed
            # before the task exited. Force the task to exit with an error,
            # since the error might not have been caught elsewhere. See the
            # comments in CancelStatus.close().
            try:
                # Raise this, rather than just constructing it, to get a
                # traceback frame included
                raise RuntimeError(
                    f"{reason} stack corrupted: {reason} surrounding "
                    f"{task!r} was closed before the task exited\n{MISNESTING_ADVICE}",
                )
            except RuntimeError as new_exc:
                if isinstance(outcome, Error):
                    new_exc.__context__ = outcome.error
                outcome = Error(new_exc)

        task._activate_cancel_status(None)
        self.tasks.remove(task)
        if task is self.init_task:
            # If the init task crashed, then something is very wrong and we
            # let the error propagate. (It'll eventually be wrapped in a
            # TrioInternalError.)
            outcome.unwrap()
            # the init task should be the last task to exit. If not, then
            # something is very wrong.
            if self.tasks:  # pragma: no cover
                raise TrioInternalError
        else:
            if task is self.main_task:
                self.main_task_outcome = outcome
                outcome = Value(None)
            assert task._parent_nursery is not None, task
            task._parent_nursery._child_finished(task, outcome)

        if "task_exited" in self.instruments:
            self.instruments.call("task_exited", task)

    ################
    # System tasks and init
    ################

    @_public
    def spawn_system_task(
        self,
        async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
        *args: Unpack[PosArgT],
        name: object = None,
        context: contextvars.Context | None = None,
    ) -> Task:
        """Spawn a "system" task.

        System tasks have a few differences from regular tasks:

        * They don't need an explicit nursery; instead they go into the
          internal "system nursery".

        * If a system task raises an exception, then it's converted into a
          :exc:`~trio.TrioInternalError` and *all* tasks are cancelled. If you
          write a system task, you should be careful to make sure it doesn't
          crash.

        * System tasks are automatically cancelled when the main task exits.

        * By default, system tasks have :exc:`KeyboardInterrupt` protection
          *enabled*. If you want your task to be interruptible by control-C,
          then you need to use :func:`disable_ki_protection` explicitly (and
          come up with some plan for what to do with a
          :exc:`KeyboardInterrupt`, given that system tasks aren't allowed to
          raise exceptions).

        * System tasks do not inherit context variables from their creator.

        Towards the end of a call to :meth:`trio.run`, after the main
        task and all system tasks have exited, the system nursery
        becomes closed. At this point, new calls to
        :func:`spawn_system_task` will raise ``RuntimeError("Nursery
        is closed to new arrivals")`` instead of creating a system
        task. It's possible to encounter this state either in
        a ``finally`` block in an async generator, or in a callback
        passed to :meth:`TrioToken.run_sync_soon` at the right moment.

        Args:
          async_fn: An async callable.
          args: Positional arguments for ``async_fn``. If you want to pass
              keyword arguments, use :func:`functools.partial`.
          name: The name for this task. Only used for debugging/introspection
              (e.g. ``repr(task_obj)``). If this isn't a string,
              :func:`spawn_system_task` will try to make it one. A common use
              case is if you're wrapping a function before spawning a new
              task, you might pass the original function as the ``name=`` to
              make debugging easier.
          context: An optional ``contextvars.Context`` object with context variables
              to use for this task. You would normally get a copy of the current
              context with ``context = contextvars.copy_context()`` and then you would
              pass that ``context`` object here.

        Returns:
          Task: the newly spawned task

        """
        return self.spawn_impl(
            async_fn,
            args,
            self.system_nursery,
            name,
            system_task=True,
            context=context,
        )

    async def init(
        self,
        async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
        args: tuple[Unpack[PosArgT]],
    ) -> None:
        """Body of the init task: start the main task and supervise shutdown.

        Three nested nurseries are opened (run_sync_soon, system, main). The
        main task and the entry-queue task are spawned; once the main
        nursery exits, the system nursery is cancelled, remaining async
        generators are finalized, and finally the run_sync_soon nursery is
        cancelled. If spawning the main task raises, the exception is stored
        in ``main_task_outcome`` and the method returns early.
        """
        # run_sync_soon task runs here:
        async with open_nursery() as run_sync_soon_nursery:
            # All other system tasks run here:
            async with open_nursery() as self.system_nursery:
                # Only the main task runs here:
                async with open_nursery() as main_task_nursery:
                    try:
                        self.main_task = self.spawn_impl(
                            async_fn,
                            args,
                            main_task_nursery,
                            None,
                        )
                    except BaseException as exc:
                        self.main_task_outcome = Error(exc)
                        return
                    self.spawn_impl(
                        self.entry_queue.task,
                        (),
                        run_sync_soon_nursery,
                        "<TrioToken.run_sync_soon task>",
                        system_task=True,
                    )

                # Main task is done; start shutting down system tasks
                self.system_nursery.cancel_scope._cancel(
                    CancelReason(
                        source="shutdown",
                        reason="main task done, shutting down system tasks",
                        source_task=repr(self.init_task),
                    )
                )

            # System nursery is closed; finalize remaining async generators
            await self.asyncgens.finalize_remaining(self)

            # There are no more asyncgens, which means no more user-provided
            # code except possibly run_sync_soon callbacks. It's finally safe
            # to stop the run_sync_soon task and exit run().
            run_sync_soon_nursery.cancel_scope._cancel(
                CancelReason(
                    source="shutdown",
                    reason="main task done, shutting down run_sync_soon callbacks",
                    source_task=repr(self.init_task),
                )
            )

    ################
    # Outside context problems
    ################

    @_public
    def current_trio_token(self) -> TrioToken:
        """Retrieve the :class:`TrioToken` for the current call to
        :func:`trio.run`.

        """
        # Created lazily on first request, then cached for the rest of the run.
        if self.trio_token is None:
            self.trio_token = TrioToken._create(self.entry_queue)
        return self.trio_token

    ################
    # KI handling
    ################

    ki_pending: bool = False

    # deliver_ki is broke. Maybe move all the actual logic and state into
    # RunToken, and we'll only have one instance per runner? But then we can't
    # have a public constructor. Eh, but current_run_token() returning a
    # unique object per run feels pretty nice. Maybe let's just go for it. And
    # keep the class public so people can isinstance() it if they want.

    # This gets called from signal context
    def deliver_ki(self) -> None:
        """Record a pending KeyboardInterrupt and queue its delivery.

        ``RunFinishedError`` from the entry queue is suppressed, leaving
        ``ki_pending`` set.
        """
        self.ki_pending = True
        with suppress(RunFinishedError):
            self.entry_queue.run_sync_soon(self._deliver_ki_cb)

    def _deliver_ki_cb(self) -> None:
        """Entry-queue callback: hand a pending KI to the main task, unless
        the run is already on its way out.
        """
        if not self.ki_pending:
            return
        # Can't happen because main_task and run_sync_soon_task are created at
        # the same time -- so even if KI arrives before main_task is created,
        # we won't get here until afterwards.
        assert self.main_task is not None
        if self.main_task_outcome is not None:
            # We're already in the process of exiting -- leave ki_pending set
            # and we'll check it again on our way out of run().
            return
        self.main_task._attempt_delivery_of_pending_ki()

    ################
    # Quiescing
    ################

    # sortedcontainers doesn't have types, and is reportedly very hard to type:
    # https://github.com/grantjenks/python-sortedcontainers/issues/68
    waiting_for_idle: Any = attrs.Factory(SortedDict)  # type: ignore[explicit-any]

    @_public
    async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None:
        """Block until there are no runnable tasks.

        This is useful in testing code when you want to give other tasks a
        chance to "settle down". The calling task is blocked, and doesn't wake
        up until all other tasks are also blocked for at least ``cushion``
        seconds. (Setting a non-zero ``cushion`` is intended to handle cases
        like two tasks talking to each other over a local socket, where we
        want to ignore the potential brief moment between a send and receive
        when all tasks are blocked.)

        Note that ``cushion`` is measured in *real* time, not the Trio clock
        time.

        If there are multiple tasks blocked in :func:`wait_all_tasks_blocked`,
        then the one with the shortest ``cushion`` is the one woken (and
        this task becoming unblocked resets the timers for the remaining
        tasks). If there are multiple tasks that have exactly the same
        ``cushion``, then all are woken.

        You should also consider :class:`trio.testing.Sequencer`, which
        provides a more explicit way to control execution ordering within a
        test, and will often produce more readable tests.

        Example:
          Here's an example of one way to test that Trio's locks are fair: we
          take the lock in the parent, start a child, wait for the child to be
          blocked waiting for the lock (!), and then check that we can't
          release and immediately re-acquire the lock::

             async def lock_taker(lock):
                 await lock.acquire()
                 lock.release()

             async def test_lock_fairness():
                 lock = trio.Lock()
                 await lock.acquire()
                 async with trio.open_nursery() as nursery:
                     nursery.start_soon(lock_taker, lock)
                     # child hasn't run yet, we have the lock
                     assert lock.locked()
                     assert lock._owner is trio.lowlevel.current_task()
                     await trio.testing.wait_all_tasks_blocked()
                     # now the child has run and is blocked on lock.acquire(), we
                     # still have the lock
                     assert lock.locked()
                     assert lock._owner is trio.lowlevel.current_task()
                     lock.release()
                     try:
                         # The child has a prior claim, so we can't have it
                         lock.acquire_nowait()
                     except trio.WouldBlock:
                         assert lock._owner is not trio.lowlevel.current_task()
                         print("PASS")
                     else:
                         print("FAIL")

        """
        task = current_task()
        # Keyed by (cushion, id) so the sorted mapping yields the shortest
        # cushion first while keeping entries for different tasks distinct.
        key = (cushion, id(task))
        self.waiting_for_idle[key] = task

        def abort(_: _core.RaiseCancelT) -> Abort:
            del self.waiting_for_idle[key]
            return Abort.SUCCEEDED

        await wait_task_rescheduled(abort)


################################################################
# run
################################################################
#
# Trio's core task scheduler and coroutine runner is in 'unrolled_run'. It's
# called that because it has an unusual feature: it's actually a generator.
# Whenever it needs to fetch IO events from the OS, it yields, and waits for
# its caller to send the IO events back in. So the loop is "unrolled" into a
# sequence of generator send() calls.
#
# The reason for this unusual design is to support two different modes of
# operation, where the IO is handled differently.
#
# In normal mode using trio.run, the scheduler and IO run in the same thread:
#
#   Main thread:
#
#   +---------------------------+
#   | Run tasks                 |
#   | (unrolled_run)            |
#   +---------------------------+
#   | Block waiting for I/O     |
#   | (io_manager.get_events)   |
#   +---------------------------+
#   | Run tasks                 |
#   | (unrolled_run)            |
#   +---------------------------+
#   | Block waiting for I/O     |
#   | (io_manager.get_events)   |
#   +---------------------------+
#   :
#
#
# In guest mode using trio.lowlevel.start_guest_run, the scheduler runs on the
# main thread as a host loop callback, but blocking for IO gets pushed into a
# worker thread:
#
#   Main thread executing host loop:           Trio I/O thread:
#
#   +---------------------------+
#   | Run Trio tasks            |
#   | (unrolled_run)            |
#   +---------------------------+ --------------+
#                                               v
#   +---------------------------+              +----------------------------+
#   | Host loop does whatever   |              | Block waiting for Trio I/O |
#   | it wants                  |              | (io_manager.get_events)    |
#   +---------------------------+              +----------------------------+
#                                               |
#   +---------------------------+ <-------------+
#   | Run Trio tasks            |
#   | (unrolled_run)            |
#   +---------------------------+ --------------+
#                                               v
#   +---------------------------+              +----------------------------+
#   | Host loop does whatever   |              | Block waiting for Trio I/O |
#   | it wants                  |              | (io_manager.get_events)    |
#   +---------------------------+              +----------------------------+
#   :                                           :
#
# Most of Trio's internals don't need to care about this difference. The main
# complication it creates is that in guest mode, we might need to wake up not
# just due to OS-reported IO events, but also because of code running on the
# host loop calling reschedule() or changing task deadlines. Search for
# 'is_guest' to see the special cases we need to handle this.
+ + +def setup_runner( + clock: Clock | None, + instruments: Sequence[Instrument], + restrict_keyboard_interrupt_to_checkpoints: bool, + strict_exception_groups: bool, +) -> Runner: + """Create a Runner object and install it as the GLOBAL_RUN_CONTEXT.""" + # It wouldn't be *hard* to support nested calls to run(), but I can't + # think of a single good reason for it, so let's be conservative for + # now: + if in_trio_run(): + raise RuntimeError("Attempted to call run() from inside a run()") + + if clock is None: + clock = SystemClock() + instrument_group = Instruments(instruments) + io_manager = TheIOManager() + system_context = copy_context() + ki_manager = KIManager() + + runner = Runner( + clock=clock, + instruments=instrument_group, + io_manager=io_manager, + system_context=system_context, + ki_manager=ki_manager, + strict_exception_groups=strict_exception_groups, + ) + runner.asyncgens.install_hooks(runner) + + # This is where KI protection gets enabled, so we want to do it early - in + # particular before we start modifying global state like GLOBAL_RUN_CONTEXT + ki_manager.install(runner.deliver_ki, restrict_keyboard_interrupt_to_checkpoints) + + GLOBAL_RUN_CONTEXT.runner = runner + return runner + + +def run( + async_fn: Callable[[Unpack[PosArgT]], Awaitable[RetT]], + *args: Unpack[PosArgT], + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = True, +) -> RetT: + """Run a Trio-flavored async function, and return the result. + + Calling:: + + run(async_fn, *args) + + is the equivalent of:: + + await async_fn(*args) + + except that :func:`run` can (and must) be called from a synchronous + context. + + This is Trio's main entry point. Almost every other function in Trio + requires that you be inside a call to :func:`run`. + + Args: + async_fn: An async function. + + args: Positional arguments to be passed to *async_fn*. 
If you need to + pass keyword arguments, then use :func:`functools.partial`. + + clock: ``None`` to use the default system-specific monotonic clock; + otherwise, an object implementing the :class:`trio.abc.Clock` + interface, like (for example) a :class:`trio.testing.MockClock` + instance. + + instruments (list of :class:`trio.abc.Instrument` objects): Any + instrumentation you want to apply to this run. This can also be + modified during the run; see :ref:`instrumentation`. + + restrict_keyboard_interrupt_to_checkpoints (bool): What happens if the + user hits control-C while :func:`run` is running? If this argument + is False (the default), then you get the standard Python behavior: a + :exc:`KeyboardInterrupt` exception will immediately interrupt + whatever task is running (or if no task is running, then Trio will + wake up a task to be interrupted). Alternatively, if you set this + argument to True, then :exc:`KeyboardInterrupt` delivery will be + delayed: it will be *only* be raised at :ref:`checkpoints + <checkpoints>`, like a :exc:`Cancelled` exception. + + The default behavior is nice because it means that even if you + accidentally write an infinite loop that never executes any + checkpoints, then you can still break out of it using control-C. + The alternative behavior is nice if you're paranoid about a + :exc:`KeyboardInterrupt` at just the wrong place leaving your + program in an inconsistent state, because it means that you only + have to worry about :exc:`KeyboardInterrupt` at the exact same + places where you already have to worry about :exc:`Cancelled`. + + This setting has no effect if your program has registered a custom + SIGINT handler, or if :func:`run` is called from anywhere but the + main thread (this is a Python limitation), or if you use + :func:`open_signal_receiver` to catch SIGINT. + + strict_exception_groups (bool): Unless set to False, nurseries will always wrap + even a single raised exception in an exception group. 
This can be overridden + on the level of individual nurseries. Setting it to False will be deprecated + and ultimately removed in a future version of Trio. + + Returns: + Whatever ``async_fn`` returns. + + Raises: + TrioInternalError: if an unexpected error is encountered inside Trio's + internal machinery. This is a bug and you should `let us know + <https://github.com/python-trio/trio/issues>`__. + + Anything else: if ``async_fn`` raises an exception, then :func:`run` + propagates it. + + """ + if strict_exception_groups is not None and not strict_exception_groups: + warn_deprecated( + "trio.run(..., strict_exception_groups=False)", + version="0.25.0", + issue=2929, + instead=( + "the default value of True and rewrite exception handlers to handle ExceptionGroups. " + "See https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors" + ), + use_triodeprecationwarning=True, + ) + + __tracebackhide__ = True + + runner = setup_runner( + clock, + instruments, + restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups, + ) + + prev_library, sniffio_library.name = sniffio_library.name, "trio" + try: + gen = unrolled_run(runner, async_fn, args) + # Need to send None in the first time. + next_send: EventResult = None # type: ignore[assignment] + while True: + try: + timeout = gen.send(next_send) + except StopIteration: + break + next_send = runner.io_manager.get_events(timeout) + finally: + sniffio_library.name = prev_library + # Inlined copy of runner.main_task_outcome.unwrap() to avoid + # cluttering every single Trio traceback with an extra frame. 
+ if isinstance(runner.main_task_outcome, Value): + return cast("RetT", runner.main_task_outcome.value) + elif isinstance(runner.main_task_outcome, Error): + raise runner.main_task_outcome.error + else: # pragma: no cover + raise AssertionError(runner.main_task_outcome) + + +def start_guest_run( # type: ignore[explicit-any] + async_fn: Callable[..., Awaitable[RetT]], + *args: object, + run_sync_soon_threadsafe: Callable[[Callable[[], object]], object], + done_callback: Callable[[outcome.Outcome[RetT]], object], + run_sync_soon_not_threadsafe: ( + Callable[[Callable[[], object]], object] | None + ) = None, + host_uses_signal_set_wakeup_fd: bool = False, + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = True, +) -> None: + """Start a "guest" run of Trio on top of some other "host" event loop. + + Each host loop can only have one guest run at a time. + + You should always let the Trio run finish before stopping the host loop; + if not, it may leave Trio's internal data structures in an inconsistent + state. You might be able to get away with it if you immediately exit the + program, but it's safest not to go there in the first place. + + Generally, the best way to do this is wrap this in a function that starts + the host loop and then immediately starts the guest run, and then shuts + down the host when the guest run completes. + + Once :func:`start_guest_run` returns successfully, the guest run + has been set up enough that you can invoke sync-colored Trio + functions such as :func:`~trio.current_time`, :func:`spawn_system_task`, + and :func:`current_trio_token`. If a `~trio.TrioInternalError` occurs + during this early setup of the guest run, it will be raised out of + :func:`start_guest_run`. 
All other errors, including all errors + raised by the *async_fn*, will be delivered to your + *done_callback* at some point after :func:`start_guest_run` returns + successfully. + + Args: + + run_sync_soon_threadsafe: An arbitrary callable, which will be passed a + function as its sole argument:: + + def my_run_sync_soon_threadsafe(fn): + ... + + This callable should schedule ``fn()`` to be run by the host on its + next pass through its loop. **Must support being called from + arbitrary threads.** + + done_callback: An arbitrary callable:: + + def my_done_callback(run_outcome): + ... + + When the Trio run has finished, Trio will invoke this callback to let + you know. The argument is an `outcome.Outcome`, reporting what would + have been returned or raised by `trio.run`. This function can do + anything you want, but commonly you'll want it to shut down the + host loop, unwrap the outcome, etc. + + run_sync_soon_not_threadsafe: Like ``run_sync_soon_threadsafe``, but + will only be called from inside the host loop's main thread. + Optional, but if your host loop allows you to implement this more + efficiently than ``run_sync_soon_threadsafe`` then passing it will + make things a bit faster. + + host_uses_signal_set_wakeup_fd (bool): Pass `True` if your host loop + uses `signal.set_wakeup_fd`, and `False` otherwise. For more details, + see :ref:`guest-run-implementation`. + + For the meaning of other arguments, see `trio.run`. + + """ + if strict_exception_groups is not None and not strict_exception_groups: + warn_deprecated( + "trio.start_guest_run(..., strict_exception_groups=False)", + version="0.25.0", + issue=2929, + instead=( + "the default value of True and rewrite exception handlers to handle ExceptionGroups. 
" + "See https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors" + ), + use_triodeprecationwarning=True, + ) + + runner = setup_runner( + clock, + instruments, + restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups, + ) + runner.is_guest = True + runner.guest_tick_scheduled = True + + if run_sync_soon_not_threadsafe is None: + run_sync_soon_not_threadsafe = run_sync_soon_threadsafe + + guest_state = GuestState( + runner=runner, + run_sync_soon_threadsafe=run_sync_soon_threadsafe, + run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe, + done_callback=done_callback, + unrolled_run_gen=unrolled_run( + runner, + async_fn, + args, + host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, + ), + ) + + # Run a few ticks of the guest run synchronously, so that by the + # time we return, the system nursery exists and callers can use + # spawn_system_task. We don't actually run any user code during + # this time, so it shouldn't be possible to get an exception here, + # except for a TrioInternalError. + next_send = cast( + "EventResult", + None, + ) # First iteration must be `None`, every iteration after that is EventResult + for _tick in range(5): # expected need is 2 iterations + leave some wiggle room + if runner.system_nursery is not None: + # We're initialized enough to switch to async guest ticks + break + try: + timeout = guest_state.unrolled_run_gen.send(next_send) + except StopIteration: # pragma: no cover + raise TrioInternalError( + "Guest runner exited before system nursery was initialized", + ) from None + if timeout != 0: # pragma: no cover + guest_state.unrolled_run_gen.throw( + TrioInternalError( + "Guest runner blocked before system nursery was initialized", + ), + ) + # next_send should be the return value of + # IOManager.get_events() if no I/O was waiting, which is + # platform-dependent. We don't actually check for I/O during + # this init phase because no one should be expecting any yet. 
+ if sys.platform == "win32": + next_send = 0 + else: + next_send = [] + else: # pragma: no cover + guest_state.unrolled_run_gen.throw( + TrioInternalError( + "Guest runner yielded too many times before " + "system nursery was initialized", + ), + ) + + guest_state.unrolled_run_next_send = Value(next_send) + run_sync_soon_not_threadsafe(guest_state.guest_tick) + + +# 24 hours is arbitrary, but it avoids issues like people setting timeouts of +# 10**20 and then getting integer overflows in the underlying system calls. +_MAX_TIMEOUT: Final = 24 * 60 * 60 + + +# Weird quirk: this is written as a generator in order to support "guest +# mode", where our core event loop gets unrolled into a series of callbacks on +# the host loop. If you're doing a regular trio.run then this gets run +# straight through. +@enable_ki_protection +def unrolled_run( + runner: Runner, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + args: tuple[Unpack[PosArgT]], + host_uses_signal_set_wakeup_fd: bool = False, +) -> Generator[float, EventResult, None]: + __tracebackhide__ = True + + try: + if not host_uses_signal_set_wakeup_fd: + runner.entry_queue.wakeup.wakeup_on_signals() + + if "before_run" in runner.instruments: + runner.instruments.call("before_run") + runner.clock.start_clock() + runner.init_task = runner.spawn_impl( + runner.init, + (async_fn, args), + None, + "<init>", + system_task=True, + ) + + # You know how people talk about "event loops"? 
This 'while' loop right + # here is our event loop: + while runner.tasks: + if runner.runq: + timeout: float = 0 + else: + deadline = runner.deadlines.next_deadline() + timeout = runner.clock.deadline_to_sleep_time(deadline) + timeout = min(max(0, timeout), _MAX_TIMEOUT) + + idle_primed = None + if runner.waiting_for_idle: + cushion, _ = runner.waiting_for_idle.keys()[0] + if cushion < timeout: + timeout = cushion + idle_primed = IdlePrimedTypes.WAITING_FOR_IDLE + # We use 'elif' here because if there are tasks in + # wait_all_tasks_blocked, then those tasks will wake up without + # jumping the clock, so we don't need to autojump. + elif runner.clock_autojump_threshold < timeout: + timeout = runner.clock_autojump_threshold + idle_primed = IdlePrimedTypes.AUTOJUMP_CLOCK + + if "before_io_wait" in runner.instruments: + runner.instruments.call("before_io_wait", timeout) + + # Driver will call io_manager.get_events(timeout) and pass it back + # in through the yield + events = yield timeout + runner.io_manager.process_events(events) + + if "after_io_wait" in runner.instruments: + runner.instruments.call("after_io_wait", timeout) + + # Process cancellations due to deadline expiry + now = runner.clock.current_time() + if runner.deadlines.expire(now): + idle_primed = None + + # idle_primed != None means: if the IO wait hit the timeout, and + # still nothing is happening, then we should start waking up + # wait_all_tasks_blocked tasks or autojump the clock. But there + # are some subtleties in defining "nothing is happening". + # + # 'not runner.runq' means that no tasks are currently runnable. + # 'not events' means that the last IO wait call hit its full + # timeout. These are very similar, and if idle_primed != None and + # we're running in regular mode then they always go together. 
But, + # in *guest* mode, they can happen independently, even when + # idle_primed=True: + # + # - runner.runq=empty and events=True: the host loop adjusted a + # deadline and that forced an IO wakeup before the timeout expired, + # even though no actual tasks were scheduled. + # + # - runner.runq=nonempty and events=False: the IO wait hit its + # timeout, but then some code in the host thread rescheduled a task + # before we got here. + # + # So we need to check both. + if idle_primed is not None and not runner.runq and not events: + if idle_primed is IdlePrimedTypes.WAITING_FOR_IDLE: + while runner.waiting_for_idle: + key, task = runner.waiting_for_idle.peekitem(0) + if key[0] == cushion: + del runner.waiting_for_idle[key] + runner.reschedule(task) + else: + break + else: + assert idle_primed is IdlePrimedTypes.AUTOJUMP_CLOCK + assert isinstance(runner.clock, _core.MockClock) + runner.clock._autojump() + + # Process all runnable tasks, but only the ones that are already + # runnable now. Anything that becomes runnable during this cycle + # needs to wait until the next pass. This avoids various + # starvation issues by ensuring that there's never an unbounded + # delay between successive checks for I/O. + # + # Also, we randomize the order of each batch to avoid assumptions + # about scheduling order sneaking in. In the long run, I suspect + # we'll either (a) use strict FIFO ordering and document that for + # predictability/determinism, or (b) implement a more + # sophisticated scheduler (e.g. some variant of fair queueing), + # for better behavior under load. For now, this is the worst of + # both worlds - but it keeps our options open. (If we do decide to + # go all in on deterministic scheduling, then there are other + # things that will probably need to change too, like the deadlines + # tie-breaker and the non-deterministic ordering of + # task._notify_queues.) 
+ batch = list(runner.runq) + runner.runq.clear() + if _ALLOW_DETERMINISTIC_SCHEDULING: + # We're running under Hypothesis, and pytest-trio has patched + # this in to make the scheduler deterministic and avoid flaky + # tests. It's not worth the (small) performance cost in normal + # operation, since we'll shuffle the list and _r is only + # seeded for tests. + batch.sort(key=lambda t: t._counter) + _r.shuffle(batch) + else: + # 50% chance of reversing the batch, this way each task + # can appear before/after any other task. + if _r.random() < 0.5: + batch.reverse() + while batch: + task = batch.pop() + GLOBAL_RUN_CONTEXT.task = task + + if "before_task_step" in runner.instruments: + runner.instruments.call("before_task_step", task) + + next_send_fn = task._next_send_fn + next_send = task._next_send + task._next_send_fn = task._next_send = None + final_outcome: Outcome[object] | None = None + + assert next_send_fn is not None + + try: + # We used to unwrap the Outcome object here and send/throw + # its contents in directly, but it turns out that .throw() + # is buggy on CPython (all versions at time of writing): + # https://bugs.python.org/issue29587 + # https://bugs.python.org/issue29590 + # https://bugs.python.org/issue40694 + # https://github.com/python/cpython/issues/108668 + # So now we send in the Outcome object and unwrap it on the + # other side. + msg = task.context.run(next_send_fn, next_send) + except StopIteration as stop_iteration: + final_outcome = Value(stop_iteration.value) + except BaseException as task_exc: + # Store for later, removing uninteresting top frames: 1 + # frame we always remove, because it's this function + # catching it, and then in addition we remove however many + # more Context.run adds. + tb = task_exc.__traceback__ + for _ in range(1 + CONTEXT_RUN_TB_FRAMES): + if tb is not None: # pragma: no branch + tb = tb.tb_next + final_outcome = Error(task_exc.with_traceback(tb)) + # Remove local refs so that e.g. 
cancelled coroutine locals + # are not kept alive by this frame until another exception + # comes along. + del tb + + if final_outcome is not None: + # We can't call this directly inside the except: blocks + # above, because then the exceptions end up attaching + # themselves to other exceptions as __context__ in + # unwanted ways. + runner.task_exited(task, final_outcome) + # final_outcome may contain a traceback ref. It's not as + # crucial compared to the above, but this will allow more + # prompt release of resources in coroutine locals. + final_outcome = None + else: + task._schedule_points += 1 + if msg is CancelShieldedCheckpoint: + runner.reschedule(task) + elif type(msg) is WaitTaskRescheduled: + task._cancel_points += 1 + task._abort_func = msg.abort_func + # KI is "outside" all cancel scopes, so check for it + # before checking for regular cancellation: + if runner.ki_pending and task is runner.main_task: + task._attempt_delivery_of_pending_ki() + task._attempt_delivery_of_any_pending_cancel() + elif type(msg) is PermanentlyDetachCoroutineObject: + # Pretend the task just exited with the given outcome + runner.task_exited(task, msg.final_outcome) + else: + exc = TypeError( + f"trio.run received unrecognized yield message {msg!r}. " + "Are you trying to use a library written for some " + "other framework like asyncio? That won't work " + "without some kind of compatibility shim.", + ) + # The foreign library probably doesn't adhere to our + # protocol of unwrapping whatever outcome gets sent in. + # Instead, we'll arrange to throw `exc` in directly, + # which works for at least asyncio and curio. 
+ runner.reschedule(task, exc) # type: ignore[arg-type] + task._next_send_fn = task.coro.throw + # prevent long-lived reference + # TODO: develop test for this deletion + del msg + + if "after_task_step" in runner.instruments: + runner.instruments.call("after_task_step", task) + del GLOBAL_RUN_CONTEXT.task + # prevent long-lived references + # TODO: develop test for this deletion + del task, next_send, next_send_fn + + except GeneratorExit: + # The run-loop generator has been garbage collected without finishing + warnings.warn( + RuntimeWarning( + "Trio guest run got abandoned without properly finishing... " + "weird stuff might happen", + ), + stacklevel=1, + ) + except TrioInternalError: + raise + except BaseException as exc: + raise TrioInternalError("internal error in Trio - please file a bug!") from exc + finally: + runner.close() + GLOBAL_RUN_CONTEXT.__dict__.clear() + + # Have to do this after runner.close() has disabled KI protection, + # because otherwise there's a race where ki_pending could get set + # after we check it. + if runner.ki_pending: + ki = KeyboardInterrupt() + if isinstance(runner.main_task_outcome, Error): + ki.__context__ = runner.main_task_outcome.error + runner.main_task_outcome = Error(ki) + + +################################################################ +# Other public API functions +################################################################ + + +class _TaskStatusIgnored(TaskStatus[object]): + def __repr__(self) -> str: + return "TASK_STATUS_IGNORED" + + def started(self, value: object = None) -> None: + pass + + +TASK_STATUS_IGNORED: Final[TaskStatus[object]] = _TaskStatusIgnored() + + +def current_task() -> Task: + """Return the :class:`Task` object representing the current task. + + Returns: + Task: the :class:`Task` that called :func:`current_task`. 
+ + """ + + try: + return GLOBAL_RUN_CONTEXT.task + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +def current_effective_deadline() -> float: + """Returns the current effective deadline for the current task. + + This function examines all the cancellation scopes that are currently in + effect (taking into account shielding), and returns the deadline that will + expire first. + + One example of where this might be useful is if your code is trying to + decide whether to begin an expensive operation like an RPC call, but wants + to skip it if it knows that it can't possibly complete in the available + time. Another example would be if you're using a protocol like gRPC that + `propagates timeout information to the remote peer + <http://www.grpc.io/docs/guides/concepts.html#deadlines>`__; this function + gives a way to fetch that information so you can send it along. + + If this is called in a context where a cancellation is currently active + (i.e., a blocking call will immediately raise :exc:`Cancelled`), then + the returned deadline is ``-inf``. If it is called in a context where no + scopes have a deadline set, it returns ``inf``. + + Returns: + float: the effective deadline, as an absolute time. + + """ + return current_task()._cancel_status.effective_deadline() + + +async def checkpoint() -> None: + """A pure :ref:`checkpoint <checkpoints>`. + + This checks for cancellation and allows other tasks to be scheduled, + without otherwise blocking. + + Note that the scheduler has the option of ignoring this and continuing to + run the current task if it decides this is appropriate (e.g. for increased + efficiency). + + Equivalent to ``await trio.sleep(0)`` (which is implemented by calling + :func:`checkpoint`.) + + """ + # The scheduler is what checks timeouts and converts them into + # cancellations. So by doing the schedule point first, we ensure that the + # cancel point has the most up-to-date info. 
+ await cancel_shielded_checkpoint() + task = current_task() + task._cancel_points += 1 + if task._cancel_status.effectively_cancelled or ( + task is task._runner.main_task and task._runner.ki_pending + ): + cs = CancelScope(deadline=-inf) + if ( + task._cancel_status._scope._cancel_reason is None + and task is task._runner.main_task + and task._runner.ki_pending + ): + task._cancel_status._scope._cancel_reason = CancelReason( + source="KeyboardInterrupt" + ) + assert task._cancel_status._scope._cancel_reason is not None + cs._cancel_reason = task._cancel_status._scope._cancel_reason + with cs: + await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) + + +async def checkpoint_if_cancelled() -> None: + """Issue a :ref:`checkpoint <checkpoints>` if the calling context has been + cancelled. + + Equivalent to (but potentially more efficient than):: + + if trio.current_effective_deadline() == -inf: + await trio.lowlevel.checkpoint() + + This is either a no-op, or else it allows other tasks to be scheduled and + then raises :exc:`trio.Cancelled`. + + Typically used together with :func:`cancel_shielded_checkpoint`. + + """ + task = current_task() + if task._cancel_status.effectively_cancelled or ( + task is task._runner.main_task and task._runner.ki_pending + ): + await _core.checkpoint() + raise AssertionError("this should never happen") # pragma: no cover + task._cancel_points += 1 + + +def in_trio_run() -> bool: + """Check whether we are in a Trio run. + This returns `True` if and only if :func:`~trio.current_time` will succeed. + + See also the discussion of differing ways of :ref:`detecting Trio <trio_contexts>`. + """ + return hasattr(GLOBAL_RUN_CONTEXT, "runner") + + +def in_trio_task() -> bool: + """Check whether we are in a Trio task. + This returns `True` if and only if :func:`~trio.lowlevel.current_task` will succeed. + + See also the discussion of differing ways of :ref:`detecting Trio <trio_contexts>`. 
+ """ + return hasattr(GLOBAL_RUN_CONTEXT, "task") + + +# export everything for the documentation +if "sphinx.ext.autodoc" in sys.modules: + from ._generated_io_epoll import * + from ._generated_io_kqueue import * + from ._generated_io_windows import * + +if sys.platform == "win32": + from ._generated_io_windows import * + from ._io_windows import ( + EventResult as EventResult, + WindowsIOManager as TheIOManager, + _WindowsStatistics as IOStatistics, + ) +elif sys.platform == "linux" or (not TYPE_CHECKING and hasattr(select, "epoll")): + from ._generated_io_epoll import * + from ._io_epoll import ( + EpollIOManager as TheIOManager, + EventResult as EventResult, + _EpollStatistics as IOStatistics, + ) +elif TYPE_CHECKING or hasattr(select, "kqueue"): + from ._generated_io_kqueue import * + from ._io_kqueue import ( + EventResult as EventResult, + KqueueIOManager as TheIOManager, + _KqueueStatistics as IOStatistics, + ) +else: # pragma: no cover + _patchers = sorted({"eventlet", "gevent"}.intersection(sys.modules)) + if _patchers: + raise NotImplementedError( + "unsupported platform or primitives Trio depends on are monkey-patched out by " + + ", ".join(_patchers), + ) + + raise NotImplementedError("unsupported platform") + +from ._generated_instrumentation import * +from ._generated_run import * diff --git a/contrib/python/trio/trio/_core/_run_context.py b/contrib/python/trio/trio/_core/_run_context.py new file mode 100644 index 00000000000..085bff9a345 --- /dev/null +++ b/contrib/python/trio/trio/_core/_run_context.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Final + +if TYPE_CHECKING: + from ._run import Runner, Task + + +class RunContext(threading.local): + runner: Runner + task: Task + + +GLOBAL_RUN_CONTEXT: Final = RunContext() diff --git a/contrib/python/trio/trio/_core/_thread_cache.py b/contrib/python/trio/trio/_core/_thread_cache.py new file mode 100644 index 00000000000..44820e7711f --- 
/dev/null +++ b/contrib/python/trio/trio/_core/_thread_cache.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import ctypes +import ctypes.util +import os +import sys +import traceback +from functools import partial +from itertools import count +from threading import Lock, Thread +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +import outcome + +if TYPE_CHECKING: + from collections.abc import Callable + +RetT = TypeVar("RetT") + + +def _to_os_thread_name(name: str) -> bytes: + # ctypes handles the trailing \00 + return name.encode("ascii", errors="replace")[:15] + + +# used to construct the method used to set os thread name, or None, depending on platform. +# called once on import +def get_os_thread_name_func() -> Callable[[int | None, str], None] | None: + def namefunc( + setname: Callable[[int, bytes], int], + ident: int | None, + name: str, + ) -> None: + # Thread.ident is None "if it has not been started". Unclear if that can happen + # with current usage. + if ident is not None: # pragma: no cover + setname(ident, _to_os_thread_name(name)) + + # namefunc on Mac also takes an ident, even if pthread_setname_np doesn't/can't use it + # so the caller doesn't need to care about platform. + def darwin_namefunc( + setname: Callable[[bytes], int], + ident: int | None, + name: str, + ) -> None: + # I don't know if Mac can rename threads that haven't been started, but default + # to no to be on the safe side. 
+ if ident is not None: # pragma: no cover + setname(_to_os_thread_name(name)) + + # find the pthread library + # this will fail on windows and musl + libpthread_path = ctypes.util.find_library("pthread") + if not libpthread_path: + # musl includes pthread functions directly in libc.so + # (but note that find_library("c") does not work on musl, + # see: https://github.com/python/cpython/issues/65821) + # so try that library instead + # if it doesn't exist, CDLL() will fail below + libpthread_path = "libc.so" + + # Sometimes windows can find the path, but gives a permission error when + # accessing it. Catching a wider exception in case of more esoteric errors. + # https://github.com/python-trio/trio/issues/2688 + try: + libpthread = ctypes.CDLL(libpthread_path) + except Exception: # pragma: no cover + return None + + # get the setname method from it + # afaik this should never fail + pthread_setname_np = getattr(libpthread, "pthread_setname_np", None) + if pthread_setname_np is None: # pragma: no cover + return None + + # specify function prototype + pthread_setname_np.restype = ctypes.c_int + + # on mac OSX pthread_setname_np does not take a thread id, + # it only lets threads name themselves, which is not a problem for us. + # Just need to make sure to call it correctly + if sys.platform == "darwin": + pthread_setname_np.argtypes = [ctypes.c_char_p] + return partial(darwin_namefunc, pthread_setname_np) + + # otherwise assume linux parameter conventions. Should also work on *BSD + pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + return partial(namefunc, pthread_setname_np) + + +# construct os thread name method +set_os_thread_name = get_os_thread_name_func() + +# The "thread cache" is a simple unbounded thread pool, i.e., it automatically +# spawns as many threads as needed to handle all the requests it's given. 
Its +# only purpose is to cache worker threads so that they don't have to be +# started from scratch every time we want to delegate some work to a thread. +# It's expected that some higher-level code will track how many threads are in +# use to avoid overwhelming the system (e.g. the limiter= argument to +# trio.to_thread.run_sync). +# +# To maximize sharing, there's only one thread cache per process, even if you +# have multiple calls to trio.run. +# +# Guarantees: +# +# It's safe to call start_thread_soon simultaneously from +# multiple threads. +# +# Idle threads are chosen in LIFO order, i.e. we *don't* spread work evenly +# over all threads. Instead we try to let some threads do most of the work +# while others sit idle as much as possible. Compared to FIFO, this has better +# memory cache behavior, and it makes it easier to detect when we have too +# many threads, so idle ones can exit. +# +# This code assumes that 'dict' has the following properties: +# +# - __setitem__, __delitem__, and popitem are all thread-safe and atomic with +# respect to each other. This is guaranteed by the GIL. +# +# - popitem returns the most-recently-added item (i.e., __setitem__ + popitem +# give you a LIFO queue). This relies on dicts being insertion-ordered, like +# they are in py36+. + +# How long a thread will idle waiting for new work before it gives up and exits. +# This value is pretty arbitrary; I don't think it matters too much. +IDLE_TIMEOUT = 10 # seconds + +name_counter = count() + + +class WorkerThread(Generic[RetT]): + __slots__ = ("_default_name", "_job", "_thread", "_thread_cache", "_worker_lock") + + def __init__(self, thread_cache: ThreadCache) -> None: + self._job: ( + tuple[ + Callable[[], RetT], + Callable[[outcome.Outcome[RetT]], object], + str | None, + ] + | None + ) = None + self._thread_cache = thread_cache + # This Lock is used in an unconventional way. 
+ # + # "Unlocked" means we have a pending job that's been assigned to us; + # "locked" means that we don't. + # + # Initially we have no job, so it starts out in locked state. + self._worker_lock = Lock() + self._worker_lock.acquire() + self._default_name = f"Trio thread {next(name_counter)}" + + self._thread = Thread(target=self._work, name=self._default_name, daemon=True) + + if set_os_thread_name: + set_os_thread_name(self._thread.ident, self._default_name) + self._thread.start() + + def _handle_job(self) -> None: + # Handle job in a separate method to ensure user-created + # objects are cleaned up in a consistent manner. + assert self._job is not None + fn, deliver, name = self._job + self._job = None + + # set name + if name is not None: + self._thread.name = name + if set_os_thread_name: + set_os_thread_name(self._thread.ident, name) + result = outcome.capture(fn) + + # reset name if it was changed + if name is not None: + self._thread.name = self._default_name + if set_os_thread_name: + set_os_thread_name(self._thread.ident, self._default_name) + + # Tell the cache that we're available to be assigned a new + # job. We do this *before* calling 'deliver', so that if + # 'deliver' triggers a new job, it can be assigned to us + # instead of spawning a new thread. + self._thread_cache._idle_workers[self] = None + try: + deliver(result) + except BaseException as e: + print("Exception while delivering result of thread", file=sys.stderr) + traceback.print_exception(type(e), e, e.__traceback__) + + def _work(self) -> None: + while True: + if self._worker_lock.acquire(timeout=IDLE_TIMEOUT): + # We got a job + self._handle_job() + else: + # Timeout acquiring lock, so we can probably exit. But, + # there's a race condition: we might be assigned a job *just* + # as we're about to exit. So we have to check. 
+ try: + del self._thread_cache._idle_workers[self] + except KeyError: + # Someone else removed us from the idle worker queue, so + # they must be in the process of assigning us a job - loop + # around and wait for it. + continue + else: + # We successfully removed ourselves from the idle + # worker queue, so no more jobs are incoming; it's safe to + # exit. + return + + +class ThreadCache: + __slots__ = ("_idle_workers",) + + def __init__(self) -> None: + self._idle_workers: dict[WorkerThread[Any], None] = {} # type: ignore[explicit-any] + + def start_thread_soon( + self, + fn: Callable[[], RetT], + deliver: Callable[[outcome.Outcome[RetT]], object], + name: str | None = None, + ) -> None: + worker: WorkerThread[RetT] + try: + worker, _ = self._idle_workers.popitem() + except KeyError: + worker = WorkerThread(self) + worker._job = (fn, deliver, name) + worker._worker_lock.release() + + +THREAD_CACHE = ThreadCache() + + +def start_thread_soon( + fn: Callable[[], RetT], + deliver: Callable[[outcome.Outcome[RetT]], object], + name: str | None = None, +) -> None: + """Runs ``deliver(outcome.capture(fn))`` in a worker thread. + + Generally ``fn`` does some blocking work, and ``deliver`` delivers the + result back to whoever is interested. + + This is a low-level, no-frills interface, very similar to using + `threading.Thread` to spawn a thread directly. The main difference is + that this function tries to reuse threads when possible, so it can be + a bit faster than `threading.Thread`. + + Worker threads have the `~threading.Thread.daemon` flag set, which means + that if your main thread exits, worker threads will automatically be + killed. If you want to make sure that your ``fn`` runs to completion, then + you should make sure that the main thread remains alive until ``deliver`` + is called. + + It is safe to call this function simultaneously from multiple threads. + + Args: + + fn (sync function): Performs arbitrary blocking work. 
+ + deliver (sync function): Takes the `outcome.Outcome` of ``fn``, and + delivers it. *Must not block.* + + Because worker threads are cached and reused for multiple calls, neither + function should mutate thread-level state, like `threading.local` objects + – or if they do, they should be careful to revert their changes before + returning. + + Note: + + The split between ``fn`` and ``deliver`` serves two purposes. First, + it's convenient, since most callers need something like this anyway. + + Second, it avoids a small race condition that could cause too many + threads to be spawned. Consider a program that wants to run several + jobs sequentially on a thread, so the main thread submits a job, waits + for it to finish, submits another job, etc. In theory, this program + should only need one worker thread. But what could happen is: + + 1. Worker thread: First job finishes, and calls ``deliver``. + + 2. Main thread: receives notification that the job finished, and calls + ``start_thread_soon``. + + 3. Main thread: sees that no worker threads are marked idle, so spawns + a second worker thread. + + 4. Original worker thread: marks itself as idle. + + To avoid this, threads mark themselves as idle *before* calling + ``deliver``. + + Is this potential extra thread a major problem? Maybe not, but it's + easy enough to avoid, and we figure that if the user is trying to + limit how many threads they're using then it's polite to respect that. + + """ + THREAD_CACHE.start_thread_soon(fn, deliver, name) + + +def clear_worker_threads() -> None: + # This is OK because the child process does not actually have any + # worker threads. Additionally, while WorkerThread keeps a strong + # reference and so would get affected, the only place those are + # stored is here. 
+ THREAD_CACHE._idle_workers.clear() + + +if hasattr(os, "register_at_fork"): + os.register_at_fork(after_in_child=clear_worker_threads) diff --git a/contrib/python/trio/trio/_core/_traps.py b/contrib/python/trio/trio/_core/_traps.py new file mode 100644 index 00000000000..60f72d1295a --- /dev/null +++ b/contrib/python/trio/trio/_core/_traps.py @@ -0,0 +1,310 @@ +"""These are the only functions that ever yield back to the task runner.""" + +from __future__ import annotations + +import enum +import types + +# Jedi gets mad in test_static_tool_sees_class_members if we use collections Callable +from typing import TYPE_CHECKING, Any, Callable, NoReturn, Union, cast + +import attrs +import outcome + +from . import _run + +if TYPE_CHECKING: + from collections.abc import Awaitable, Generator + + from typing_extensions import TypeAlias + + from ._run import Task + +RaiseCancelT: TypeAlias = Callable[[], NoReturn] + + +# This class object is used as a singleton. +# Not exported in the trio._core namespace, but imported directly by _run. +class CancelShieldedCheckpoint: + __slots__ = () + + +# Not exported in the trio._core namespace, but imported directly by _run. +@attrs.frozen(slots=False) +class WaitTaskRescheduled: + abort_func: Callable[[RaiseCancelT], Abort] + + +# Not exported in the trio._core namespace, but imported directly by _run. +@attrs.frozen(slots=False) +class PermanentlyDetachCoroutineObject: + final_outcome: outcome.Outcome[object] + + +MessageType: TypeAlias = Union[ + type[CancelShieldedCheckpoint], + WaitTaskRescheduled, + PermanentlyDetachCoroutineObject, + object, +] + + +# Helper for the bottommost 'yield'. You can't use 'yield' inside an async +# function, but you can inside a generator, and if you decorate your generator +# with @types.coroutine, then it's even awaitable. 
However, it's still not a +# real async function: in particular, it isn't recognized by +# inspect.iscoroutinefunction, and it doesn't trigger the unawaited coroutine +# tracking machinery. Since our traps are public APIs, we make them real async +# functions, and then this helper takes care of the actual yield: +@types.coroutine +def _real_async_yield( + obj: MessageType, +) -> Generator[MessageType, None, None]: + return (yield obj) + + +# Real yield value is from trio's main loop, but type checkers can't +# understand that, so we cast it to make type checkers understand. +_async_yield = cast( + "Callable[[MessageType], Awaitable[outcome.Outcome[object]]]", + _real_async_yield, +) + + +async def cancel_shielded_checkpoint() -> None: + """Introduce a schedule point, but not a cancel point. + + This is *not* a :ref:`checkpoint <checkpoints>`, but it is half of a + checkpoint, and when combined with :func:`checkpoint_if_cancelled` it can + make a full checkpoint. + + Equivalent to (but potentially more efficient than):: + + with trio.CancelScope(shield=True): + await trio.lowlevel.checkpoint() + + """ + (await _async_yield(CancelShieldedCheckpoint)).unwrap() + + +# Return values for abort functions +class Abort(enum.Enum): + """:class:`enum.Enum` used as the return value from abort functions. + + See :func:`wait_task_rescheduled` for details. + + .. data:: SUCCEEDED + FAILED + + """ + + SUCCEEDED = 1 + FAILED = 2 + + +# Should always return the type a Task "expects", unless you willfully reschedule it +# with a bad value. +async def wait_task_rescheduled( # type: ignore[explicit-any] + abort_func: Callable[[RaiseCancelT], Abort], +) -> Any: + """Put the current task to sleep, with cancellation support. + + This is the lowest-level API for blocking in Trio. Every time a + :class:`~trio.lowlevel.Task` blocks, it does so by calling this function + (usually indirectly via some higher-level API). + + This is a tricky interface with no guard rails. 
If you can use + :class:`ParkingLot` or the built-in I/O wait functions instead, then you + should. + + Generally the way it works is that before calling this function, you make + arrangements for "someone" to call :func:`reschedule` on the current task + at some later point. + + Then you call :func:`wait_task_rescheduled`, passing in ``abort_func``, an + "abort callback". + + (Terminology: in Trio, "aborting" is the process of attempting to + interrupt a blocked task to deliver a cancellation.) + + There are two possibilities for what happens next: + + 1. "Someone" calls :func:`reschedule` on the current task, and + :func:`wait_task_rescheduled` returns or raises whatever value or error + was passed to :func:`reschedule`. + + 2. The call's context transitions to a cancelled state (e.g. due to a + timeout expiring). When this happens, the ``abort_func`` is called. Its + interface looks like:: + + def abort_func(raise_cancel): + ... + return trio.lowlevel.Abort.SUCCEEDED # or FAILED + + It should attempt to clean up any state associated with this call, and + in particular, arrange that :func:`reschedule` will *not* be called + later. If (and only if!) it is successful, then it should return + :data:`Abort.SUCCEEDED`, in which case the task will automatically be + rescheduled with an appropriate :exc:`~trio.Cancelled` error. + + Otherwise, it should return :data:`Abort.FAILED`. This means that the + task can't be cancelled at this time, and still has to make sure that + "someone" eventually calls :func:`reschedule`. + + At that point there are again two possibilities. You can simply ignore + the cancellation altogether: wait for the operation to complete and + then reschedule and continue as normal. (For example, this is what + :func:`trio.to_thread.run_sync` does if cancellation is disabled.) + The other possibility is that the ``abort_func`` does succeed in + cancelling the operation, but for some reason isn't able to report that + right away. 
(Example: on Windows, it's possible to request that an + async ("overlapped") I/O operation be cancelled, but this request is + *also* asynchronous – you don't find out until later whether the + operation was actually cancelled or not.) To report a delayed + cancellation, then you should reschedule the task yourself, and call + the ``raise_cancel`` callback passed to ``abort_func`` to raise a + :exc:`~trio.Cancelled` (or possibly :exc:`KeyboardInterrupt`) exception + into this task. Either of the approaches sketched below can work:: + + # Option 1: + # Catch the exception from raise_cancel and inject it into the task. + # (This is what Trio does automatically for you if you return + # Abort.SUCCEEDED.) + trio.lowlevel.reschedule(task, outcome.capture(raise_cancel)) + + # Option 2: + # wait to be woken by "someone", and then decide whether to raise + # the error from inside the task. + outer_raise_cancel = None + def abort(inner_raise_cancel): + nonlocal outer_raise_cancel + outer_raise_cancel = inner_raise_cancel + TRY_TO_CANCEL_OPERATION() + return trio.lowlevel.Abort.FAILED + await wait_task_rescheduled(abort) + if OPERATION_WAS_SUCCESSFULLY_CANCELLED: + # raises the error + outer_raise_cancel() + + In any case it's guaranteed that we only call the ``abort_func`` at most + once per call to :func:`wait_task_rescheduled`. + + Sometimes, it's useful to be able to share some mutable sleep-related data + between the sleeping task, the abort function, and the waking task. You + can use the sleeping task's :data:`~Task.custom_sleep_data` attribute to + store this data, and Trio won't touch it, except to make sure that it gets + cleared when the task is rescheduled. + + .. warning:: + + If your ``abort_func`` raises an error, or returns any value other than + :data:`Abort.SUCCEEDED` or :data:`Abort.FAILED`, then Trio will crash + violently. Be careful! 
Similarly, it is entirely possible to deadlock a + Trio program by failing to reschedule a blocked task, or cause havoc by + calling :func:`reschedule` too many times. Remember what we said up + above about how you should use a higher-level API if at all possible? + + """ + return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap() + + +async def permanently_detach_coroutine_object( + final_outcome: outcome.Outcome[object], +) -> object: + """Permanently detach the current task from the Trio scheduler. + + Normally, a Trio task doesn't exit until its coroutine object exits. When + you call this function, Trio acts like the coroutine object just exited + and the task terminates with the given outcome. This is useful if you want + to permanently switch the coroutine object over to a different coroutine + runner. + + When the calling coroutine enters this function it's running under Trio, + and when the function returns it's running under the foreign coroutine + runner. + + You should make sure that the coroutine object has released any + Trio-specific resources it has acquired (e.g. nurseries). + + Args: + final_outcome (outcome.Outcome): Trio acts as if the current task exited + with the given return value or exception. + + Returns or raises whatever value or exception the new coroutine runner + uses to resume the coroutine. + + """ + if _run.current_task().child_nurseries: + raise RuntimeError( + "can't permanently detach a coroutine object with open nurseries", + ) + return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome)) + + +async def temporarily_detach_coroutine_object( + abort_func: Callable[[RaiseCancelT], Abort], +) -> object: + """Temporarily detach the current coroutine object from the Trio + scheduler. + + When the calling coroutine enters this function it's running under Trio, + and when the function returns it's running under the foreign coroutine + runner. 
+ + The Trio :class:`Task` will continue to exist, but will be suspended until + you use :func:`reattach_detached_coroutine_object` to resume it. In the + mean time, you can use another coroutine runner to schedule the coroutine + object. In fact, you have to – the function doesn't return until the + coroutine is advanced from outside. + + Note that you'll need to save the current :class:`Task` object to later + resume; you can retrieve it with :func:`current_task`. You can also use + this :class:`Task` object to retrieve the coroutine object – see + :data:`Task.coro`. + + Args: + abort_func: Same as for :func:`wait_task_rescheduled`, except that it + must return :data:`Abort.FAILED`. (If it returned + :data:`Abort.SUCCEEDED`, then Trio would attempt to reschedule the + detached task directly without going through + :func:`reattach_detached_coroutine_object`, which would be bad.) + Your ``abort_func`` should still arrange for whatever the coroutine + object is doing to be cancelled, and then reattach to Trio and call + the ``raise_cancel`` callback, if possible. + + Returns or raises whatever value or exception the new coroutine runner + uses to resume the coroutine. + + """ + return await _async_yield(WaitTaskRescheduled(abort_func)) + + +async def reattach_detached_coroutine_object(task: Task, yield_value: object) -> None: + """Reattach a coroutine object that was detached using + :func:`temporarily_detach_coroutine_object`. + + When the calling coroutine enters this function it's running under the + foreign coroutine runner, and when the function returns it's running under + Trio. + + This must be called from inside the coroutine being resumed, and yields + whatever value you pass in. (Presumably you'll pass a value that will + cause the current coroutine runner to stop scheduling this task.) Then the + coroutine is resumed by the Trio scheduler at the next opportunity. + + Args: + task (Task): The Trio task object that the current coroutine was + detached from. 
+ yield_value (object): The object to yield to the current coroutine + runner. + + """ + # This is a kind of crude check – in particular, it can fail if the + # passed-in task is where the coroutine *runner* is running. But this is + # an experts-only interface, and there's no easy way to do a more accurate + # check, so I guess that's OK. + if not task.coro.cr_running: + raise RuntimeError("given task does not match calling coroutine") + _run.reschedule(task, outcome.Value("reattaching")) + value = await _async_yield(yield_value) + assert value == outcome.Value("reattaching") diff --git a/contrib/python/trio/trio/_core/_unbounded_queue.py b/contrib/python/trio/trio/_core/_unbounded_queue.py new file mode 100644 index 00000000000..b9e7974841c --- /dev/null +++ b/contrib/python/trio/trio/_core/_unbounded_queue.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +import attrs + +from .. import _core +from .._deprecate import deprecated +from .._util import final + +T = TypeVar("T") + +if TYPE_CHECKING: + from typing_extensions import Self + + +class UnboundedQueueStatistics: + """An object containing debugging information. + + Currently, the following fields are defined: + + * ``qsize``: The number of items currently in the queue. + * ``tasks_waiting``: The number of tasks blocked on this queue's + :meth:`get_batch` method. + + """ + + qsize: int + tasks_waiting: int + + +@final +class UnboundedQueue(Generic[T]): + """An unbounded queue suitable for certain unusual forms of inter-task + communication. + + This class is designed for use as a queue in cases where the producer for + some reason cannot be subjected to back-pressure, i.e., :meth:`put_nowait` + has to always succeed. In order to prevent the queue backlog from actually + growing without bound, the consumer API is modified to dequeue items in + "batches". 
If a consumer task processes each batch without yielding, then + this helps achieve (but does not guarantee) an effective bound on the + queue's memory use, at the cost of potentially increasing system latencies + in general. You should generally prefer to use a memory channel + instead if you can. + + Currently each batch completely empties the queue, but `this may change in + the future <https://github.com/python-trio/trio/issues/51>`__. + + A :class:`UnboundedQueue` object can be used as an asynchronous iterator, + where each iteration returns a new batch of items. I.e., these two loops + are equivalent:: + + async for batch in queue: + ... + + while True: + obj = await queue.get_batch() + ... + + """ + + @deprecated( + "0.9.0", + issue=497, + thing="trio.lowlevel.UnboundedQueue", + instead="trio.open_memory_channel(math.inf)", + use_triodeprecationwarning=True, + ) + def __init__(self) -> None: + self._lot = _core.ParkingLot() + self._data: list[T] = [] + # used to allow handoff from put to the first task in the lot + self._can_get = False + + def __repr__(self) -> str: + return f"<UnboundedQueue holding {len(self._data)} items>" + + def qsize(self) -> int: + """Returns the number of items currently in the queue.""" + return len(self._data) + + def empty(self) -> bool: + """Returns True if the queue is empty, False otherwise. + + There is some subtlety to interpreting this method's return value: see + `issue #63 <https://github.com/python-trio/trio/issues/63>`__. + + """ + return not self._data + + @_core.enable_ki_protection + def put_nowait(self, obj: T) -> None: + """Put an object into the queue, without blocking. + + This always succeeds, because the queue is unbounded. We don't provide + a blocking ``put`` method, because it would never need to block. + + Args: + obj (object): The object to enqueue. 
+ + """ + if not self._data: + assert not self._can_get + if self._lot: + self._lot.unpark(count=1) + else: + self._can_get = True + self._data.append(obj) + + def _get_batch_protected(self) -> list[T]: + data = self._data.copy() + self._data.clear() + self._can_get = False + return data + + def get_batch_nowait(self) -> list[T]: + """Attempt to get the next batch from the queue, without blocking. + + Returns: + list: A list of dequeued items, in order. On a successful call this + list is always non-empty; if it would be empty we raise + :exc:`~trio.WouldBlock` instead. + + Raises: + ~trio.WouldBlock: if the queue is empty. + + """ + if not self._can_get: + raise _core.WouldBlock + return self._get_batch_protected() + + async def get_batch(self) -> list[T]: + """Get the next batch from the queue, blocking as necessary. + + Returns: + list: A list of dequeued items, in order. This list is always + non-empty. + + """ + await _core.checkpoint_if_cancelled() + if not self._can_get: + await self._lot.park() + return self._get_batch_protected() + else: + try: + return self._get_batch_protected() + finally: + await _core.cancel_shielded_checkpoint() + + def statistics(self) -> UnboundedQueueStatistics: + """Return an :class:`UnboundedQueueStatistics` object containing debugging information.""" + return UnboundedQueueStatistics( + qsize=len(self._data), + tasks_waiting=self._lot.statistics().tasks_waiting, + ) + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> list[T]: + return await self.get_batch() diff --git a/contrib/python/trio/trio/_core/_wakeup_socketpair.py b/contrib/python/trio/trio/_core/_wakeup_socketpair.py new file mode 100644 index 00000000000..ea4567017f0 --- /dev/null +++ b/contrib/python/trio/trio/_core/_wakeup_socketpair.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import contextlib +import signal +import socket +import warnings + +from .. 
import _core +from .._util import is_main_thread + + +class WakeupSocketpair: + def __init__(self) -> None: + # explicitly typed to please `pyright --verifytypes` without `--ignoreexternal` + self.wakeup_sock: socket.socket + self.write_sock: socket.socket + + self.wakeup_sock, self.write_sock = socket.socketpair() + self.wakeup_sock.setblocking(False) + self.write_sock.setblocking(False) + # This somewhat reduces the amount of memory wasted queueing up data + # for wakeups. With these settings, maximum number of 1-byte sends + # before getting BlockingIOError: + # Linux 4.8: 6 + # macOS (darwin 15.5): 1 + # Windows 10: 525347 + # Windows you're weird. (And on Windows setting SNDBUF to 0 makes send + # blocking, even on non-blocking sockets, so don't do that.) + self.wakeup_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1) + self.write_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1) + # On Windows this is a TCP socket so this might matter. On other + # platforms this fails b/c AF_UNIX sockets aren't actually TCP. + with contextlib.suppress(OSError): + self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.old_wakeup_fd: int | None = None + + def wakeup_thread_and_signal_safe(self) -> None: + with contextlib.suppress(BlockingIOError): + self.write_sock.send(b"\x00") + + async def wait_woken(self) -> None: + await _core.wait_readable(self.wakeup_sock) + self.drain() + + def drain(self) -> None: + try: + while True: + self.wakeup_sock.recv(2**16) + except BlockingIOError: + pass + + def wakeup_on_signals(self) -> None: + assert self.old_wakeup_fd is None + if not is_main_thread(): + return + fd = self.write_sock.fileno() + self.old_wakeup_fd = signal.set_wakeup_fd(fd, warn_on_full_buffer=False) + if self.old_wakeup_fd != -1: + warnings.warn( + RuntimeWarning( + "It looks like Trio's signal handling code might have " + "collided with another library you're using. 
If you're " + "running Trio in guest mode, then this might mean you " + "should set host_uses_signal_set_wakeup_fd=True. " + "Otherwise, file a bug on Trio and we'll help you figure " + "out what's going on.", + ), + stacklevel=1, + ) + + def close(self) -> None: + self.wakeup_sock.close() + self.write_sock.close() + if self.old_wakeup_fd is not None: + signal.set_wakeup_fd(self.old_wakeup_fd) diff --git a/contrib/python/trio/trio/_core/_windows_cffi.py b/contrib/python/trio/trio/_core/_windows_cffi.py new file mode 100644 index 00000000000..0e3c0b10b37 --- /dev/null +++ b/contrib/python/trio/trio/_core/_windows_cffi.py @@ -0,0 +1,314 @@ +from __future__ import annotations + +import enum +from typing import TYPE_CHECKING, NewType, NoReturn, Protocol, cast + +if TYPE_CHECKING: + import cffi + from typing_extensions import TypeAlias + + CData: TypeAlias = cffi.api.FFI.CData + CType: TypeAlias = cffi.api.FFI.CType + +from ._generated_windows_ffi import ffi + +################################################################ +# Functions and types +################################################################ + +if not TYPE_CHECKING: + CData: TypeAlias = ffi.CData + CType: TypeAlias = ffi.CType + +AlwaysNull: TypeAlias = CData # We currently always pass ffi.NULL here. +Handle = NewType("Handle", CData) +HandleArray = NewType("HandleArray", CData) + + +class _Kernel32(Protocol): + """Statically typed version of the kernel32.dll functions we use.""" + + def CreateIoCompletionPort( + self, + FileHandle: Handle, + ExistingCompletionPort: CData | AlwaysNull, + CompletionKey: int, + NumberOfConcurrentThreads: int, + /, + ) -> Handle: ... + + def CreateEventA( + self, + lpEventAttributes: AlwaysNull, + bManualReset: bool, + bInitialState: bool, + lpName: AlwaysNull, + /, + ) -> Handle: ... + + def SetFileCompletionNotificationModes( + self, + handle: Handle, + flags: CompletionModes, + /, + ) -> int: ... 
+ + def PostQueuedCompletionStatus( + self, + CompletionPort: Handle, + dwNumberOfBytesTransferred: int, + dwCompletionKey: int, + lpOverlapped: CData | AlwaysNull, + /, + ) -> bool: ... + + def CancelIoEx( + self, + hFile: Handle, + lpOverlapped: CData | AlwaysNull, + /, + ) -> bool: ... + + def WriteFile( + self, + hFile: Handle, + # not sure about this type + lpBuffer: CData, + nNumberOfBytesToWrite: int, + lpNumberOfBytesWritten: AlwaysNull, + lpOverlapped: _Overlapped, + /, + ) -> bool: ... + + def ReadFile( + self, + hFile: Handle, + # not sure about this type + lpBuffer: CData, + nNumberOfBytesToRead: int, + lpNumberOfBytesRead: AlwaysNull, + lpOverlapped: _Overlapped, + /, + ) -> bool: ... + + def GetQueuedCompletionStatusEx( + self, + CompletionPort: Handle, + lpCompletionPortEntries: CData, + ulCount: int, + ulNumEntriesRemoved: CData, + dwMilliseconds: int, + fAlertable: bool | int, + /, + ) -> CData: ... + + def CreateFileW( + self, + lpFileName: CData, + dwDesiredAccess: FileFlags, + dwShareMode: FileFlags, + lpSecurityAttributes: AlwaysNull, + dwCreationDisposition: FileFlags, + dwFlagsAndAttributes: FileFlags, + hTemplateFile: AlwaysNull, + /, + ) -> Handle: ... + + def WaitForSingleObject(self, hHandle: Handle, dwMilliseconds: int, /) -> CData: ... + + def WaitForMultipleObjects( + self, + nCount: int, + lpHandles: HandleArray, + bWaitAll: bool, + dwMilliseconds: int, + /, + ) -> ErrorCodes: ... + + def SetEvent(self, handle: Handle, /) -> None: ... + + def CloseHandle(self, handle: Handle, /) -> bool: ... + + def DeviceIoControl( + self, + hDevice: Handle, + dwIoControlCode: int, + # this is wrong (it's not always null) + lpInBuffer: AlwaysNull, + nInBufferSize: int, + # this is also wrong + lpOutBuffer: AlwaysNull, + nOutBufferSize: int, + lpBytesReturned: AlwaysNull, + lpOverlapped: CData, + /, + ) -> bool: ... 
+ + +class _Nt(Protocol): + """Statically typed version of the ntdll.dll functions we use.""" + + def RtlNtStatusToDosError(self, status: int, /) -> ErrorCodes: ... + + +class _Ws2(Protocol): + """Statically typed version of the ws2_32.dll functions we use.""" + + def WSAGetLastError(self) -> int: ... + + def WSAIoctl( + self, + socket: CData, + dwIoControlCode: WSAIoctls, + lpvInBuffer: AlwaysNull, + cbInBuffer: int, + lpvOutBuffer: CData, + cbOutBuffer: int, + lpcbBytesReturned: CData, # int* + lpOverlapped: AlwaysNull, + # actually LPWSAOVERLAPPED_COMPLETION_ROUTINE + lpCompletionRoutine: AlwaysNull, + /, + ) -> int: ... + + +class _DummyStruct(Protocol): + Offset: int + OffsetHigh: int + + +class _DummyUnion(Protocol): + DUMMYSTRUCTNAME: _DummyStruct + Pointer: object + + +class _Overlapped(Protocol): + Internal: int + InternalHigh: int + DUMMYUNIONNAME: _DummyUnion + hEvent: Handle + + +kernel32 = cast("_Kernel32", ffi.dlopen("kernel32.dll")) +ntdll = cast("_Nt", ffi.dlopen("ntdll.dll")) +ws2_32 = cast("_Ws2", ffi.dlopen("ws2_32.dll")) + +################################################################ +# Magic numbers +################################################################ + +# Here's a great resource for looking these up: +# https://www.magnumdb.com +# (Tip: check the box to see "Hex value") + +INVALID_HANDLE_VALUE = Handle(ffi.cast("HANDLE", -1)) + + +class ErrorCodes(enum.IntEnum): + STATUS_TIMEOUT = 0x102 + WAIT_TIMEOUT = 0x102 + WAIT_ABANDONED = 0x80 + WAIT_OBJECT_0 = 0x00 # object is signaled + WAIT_FAILED = 0xFFFFFFFF + ERROR_IO_PENDING = 997 + ERROR_OPERATION_ABORTED = 995 + ERROR_ABANDONED_WAIT_0 = 735 + ERROR_INVALID_HANDLE = 6 + ERROR_INVALID_PARAMETER = 87 + ERROR_NOT_FOUND = 1168 + ERROR_NOT_SOCKET = 10038 + + +class FileFlags(enum.IntFlag): + GENERIC_READ = 0x80000000 + SYNCHRONIZE = 0x00100000 + FILE_FLAG_OVERLAPPED = 0x40000000 + FILE_SHARE_READ = 1 + FILE_SHARE_WRITE = 2 + FILE_SHARE_DELETE = 4 + CREATE_NEW = 1 + CREATE_ALWAYS = 2 + 
OPEN_EXISTING = 3 + OPEN_ALWAYS = 4 + TRUNCATE_EXISTING = 5 + + +class AFDPollFlags(enum.IntFlag): + # These are drawn from a combination of: + # https://github.com/piscisaureus/wepoll/blob/master/src/afd.h + # https://github.com/reactos/reactos/blob/master/sdk/include/reactos/drivers/afd/shared.h + AFD_POLL_RECEIVE = 0x0001 + AFD_POLL_RECEIVE_EXPEDITED = 0x0002 # OOB/urgent data + AFD_POLL_SEND = 0x0004 + AFD_POLL_DISCONNECT = 0x0008 # received EOF (FIN) + AFD_POLL_ABORT = 0x0010 # received RST + AFD_POLL_LOCAL_CLOSE = 0x0020 # local socket object closed + AFD_POLL_CONNECT = 0x0040 # socket is successfully connected + AFD_POLL_ACCEPT = 0x0080 # you can call accept on this socket + AFD_POLL_CONNECT_FAIL = 0x0100 # connect() terminated unsuccessfully + # See WSAEventSelect docs for more details on these four: + AFD_POLL_QOS = 0x0200 + AFD_POLL_GROUP_QOS = 0x0400 + AFD_POLL_ROUTING_INTERFACE_CHANGE = 0x0800 + AFD_POLL_EVENT_ADDRESS_LIST_CHANGE = 0x1000 + + +class WSAIoctls(enum.IntEnum): + SIO_BASE_HANDLE = 0x48000022 + SIO_BSP_HANDLE_SELECT = 0x4800001C + SIO_BSP_HANDLE_POLL = 0x4800001D + + +class CompletionModes(enum.IntFlag): + FILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 0x1 + FILE_SKIP_SET_EVENT_ON_HANDLE = 0x2 + + +class IoControlCodes(enum.IntEnum): + IOCTL_AFD_POLL = 0x00012024 + + +################################################################ +# Generic helpers +################################################################ + + +def _handle(obj: int | CData) -> Handle: + # For now, represent handles as either cffi HANDLEs or as ints. If you + # try to pass in a file descriptor instead, it's not going to work + # out. (For that msvcrt.get_osfhandle does the trick, but I don't know if + # we'll actually need that for anything...) For sockets this doesn't + # matter, Python never allocates an fd. So let's wait until we actually + # encounter the problem before worrying about it. 
+ if isinstance(obj, int): + return Handle(ffi.cast("HANDLE", obj)) + return Handle(obj) + + +def handle_array(count: int) -> HandleArray: + """Make an array of handles.""" + return HandleArray(ffi.new(f"HANDLE[{count}]")) + + +def raise_winerror( + winerror: int | None = None, + *, + filename: str | None = None, + filename2: str | None = None, +) -> NoReturn: + # assert sys.platform == "win32" # TODO: make this work in MyPy + # ... in the meanwhile, ffi.getwinerror() is undefined on non-Windows, necessitating the type + # ignores. + + if winerror is None: + err = ffi.getwinerror() # type: ignore[attr-defined,unused-ignore] + if err is None: + raise RuntimeError("No error set?") + winerror, msg = err + else: + err = ffi.getwinerror(winerror) # type: ignore[attr-defined,unused-ignore] + if err is None: + raise RuntimeError("No error set?") + _, msg = err + # https://docs.python.org/3/library/exceptions.html#OSError + raise OSError(0, msg, filename, winerror, filename2) diff --git a/contrib/python/trio/trio/_deprecate.py b/contrib/python/trio/trio/_deprecate.py new file mode 100644 index 00000000000..5c827b4fda1 --- /dev/null +++ b/contrib/python/trio/trio/_deprecate.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import sys +import warnings +from functools import wraps +from typing import TYPE_CHECKING, ClassVar, TypeVar + +import attrs + +if TYPE_CHECKING: + from collections.abc import Callable + + from typing_extensions import ParamSpec + + ArgsT = ParamSpec("ArgsT") + +RetT = TypeVar("RetT") + + +# We want our warnings to be visible by default (at least for now), but we +# also want it to be possible to override that using the -W switch. AFAICT +# this means we cannot inherit from DeprecationWarning, because the only way +# to make it visible by default then would be to add our own filter at import +# time, but that would override -W switches... 
+class TrioDeprecationWarning(FutureWarning): + """Warning emitted if you use deprecated Trio functionality. + + While a relatively mature project, Trio remains committed to refining its + design and improving usability. As part of this, we occasionally deprecate + or remove functionality that proves suboptimal. If you use Trio, we + recommend `subscribing to issue #1 + <https://github.com/python-trio/trio/issues/1>`__ to get information about + upcoming deprecations and other backwards compatibility breaking changes. + + Despite the name, this class currently inherits from + :class:`FutureWarning`, not :class:`DeprecationWarning`, because until a + 1.0 release, we want these warnings to be visible by default. You can hide + them by installing a filter or with the ``-W`` switch: see the + :mod:`warnings` documentation for details. + """ + + +def _url_for_issue(issue: int) -> str: + return f"https://github.com/python-trio/trio/issues/{issue}" + + +def _stringify(thing: object) -> str: + if hasattr(thing, "__module__") and hasattr(thing, "__qualname__"): + return f"{thing.__module__}.{thing.__qualname__}" + return str(thing) + + +def warn_deprecated( + thing: object, + version: str, + *, + issue: int | None, + instead: object, + stacklevel: int = 2, + use_triodeprecationwarning: bool = False, +) -> None: + stacklevel += 1 + msg = f"{_stringify(thing)} is deprecated since Trio {version}" + if instead is None: + msg += " with no replacement" + else: + msg += f"; use {_stringify(instead)} instead" + if issue is not None: + msg += f" ({_url_for_issue(issue)})" + if use_triodeprecationwarning: + warning_class: type[Warning] = TrioDeprecationWarning + else: + warning_class = DeprecationWarning + warnings.warn(warning_class(msg), stacklevel=stacklevel) + + +# @deprecated("0.2.0", issue=..., instead=...) +# def ... 
+def deprecated( + version: str, + *, + thing: object = None, + issue: int | None, + instead: object, + use_triodeprecationwarning: bool = False, +) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]: + def do_wrap(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: + nonlocal thing + + @wraps(fn) + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: + warn_deprecated( + thing, + version, + instead=instead, + issue=issue, + use_triodeprecationwarning=use_triodeprecationwarning, + ) + return fn(*args, **kwargs) + + # If our __module__ or __qualname__ get modified, we want to pick up + # on that, so we read them off the wrapper object instead of the (now + # hidden) fn object + if thing is None: + thing = wrapper + + if wrapper.__doc__ is not None: + doc = wrapper.__doc__ + doc = doc.rstrip() + doc += "\n\n" + doc += f".. deprecated:: {version}\n" + if instead is not None: + doc += f" Use {_stringify(instead)} instead.\n" + if issue is not None: + doc += f" For details, see `issue #{issue} <{_url_for_issue(issue)}>`__.\n" + doc += "\n" + wrapper.__doc__ = doc + + return wrapper + + return do_wrap + + +def deprecated_alias( + old_qualname: str, + new_fn: Callable[ArgsT, RetT], + version: str, + *, + issue: int | None, +) -> Callable[ArgsT, RetT]: + @deprecated(version, issue=issue, instead=new_fn) + @wraps(new_fn, assigned=("__module__", "__annotations__")) + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: + """Deprecated alias.""" + return new_fn(*args, **kwargs) + + wrapper.__qualname__ = old_qualname + wrapper.__name__ = old_qualname.rpartition(".")[-1] + return wrapper + + +@attrs.frozen(slots=False) +class DeprecatedAttribute: + _not_set: ClassVar[object] = object() + + value: object + version: str + issue: int | None + instead: object = _not_set + + +def deprecate_attributes( + module_name: str, deprecated_attributes: dict[str, DeprecatedAttribute] +) -> None: + def __getattr__(name: str) -> object: + if name in 
deprecated_attributes: + info = deprecated_attributes[name] + instead = info.instead + if instead is DeprecatedAttribute._not_set: + instead = info.value + thing = f"{module_name}.{name}" + warn_deprecated(thing, info.version, issue=info.issue, instead=instead) + return info.value + + msg = "module '{}' has no attribute '{}'" + raise AttributeError(msg.format(module_name, name)) + + sys.modules[module_name].__getattr__ = __getattr__ # type: ignore[method-assign] diff --git a/contrib/python/trio/trio/_dtls.py b/contrib/python/trio/trio/_dtls.py new file mode 100644 index 00000000000..a7dff634d90 --- /dev/null +++ b/contrib/python/trio/trio/_dtls.py @@ -0,0 +1,1387 @@ +# Implementation of DTLS 1.2, using pyopenssl +# https://datatracker.ietf.org/doc/html/rfc6347 +# +# OpenSSL's APIs for DTLS are extremely awkward and limited, which forces us to jump +# through a *lot* of hoops and implement important chunks of the protocol ourselves. +# Hopefully they fix this before implementing DTLS 1.3, because it's a very different +# protocol, and it's probably impossible to pull tricks like we do here. 
+ +from __future__ import annotations + +import contextlib +import enum +import errno +import hmac +import os +import struct +import warnings +import weakref +from itertools import count +from typing import ( + TYPE_CHECKING, + Generic, + TypeVar, + Union, +) +from weakref import ReferenceType, WeakValueDictionary + +import attrs + +import trio + +from ._util import NoPublicConstructor, final + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Iterable, Iterator + from types import TracebackType + + # See DTLSEndpoint.__init__ for why this is imported here + from OpenSSL import SSL # noqa: TC004 + from typing_extensions import Self, TypeAlias, TypeVarTuple, Unpack + + from trio._socket import AddressFormat + from trio.socket import SocketType + + PosArgsT = TypeVarTuple("PosArgsT") + +MAX_UDP_PACKET_SIZE = 65527 + + +def packet_header_overhead(sock: SocketType) -> int: + if sock.family == trio.socket.AF_INET: + return 28 + else: + return 48 + + +def worst_case_mtu(sock: SocketType) -> int: + if sock.family == trio.socket.AF_INET: + return 576 - packet_header_overhead(sock) + else: + return 1280 - packet_header_overhead(sock) # TODO: test this line + + +def best_guess_mtu(sock: SocketType) -> int: + return 1500 - packet_header_overhead(sock) + + +# There are a bunch of different RFCs that define these codes, so for a +# comprehensive collection look here: +# https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml +class ContentType(enum.IntEnum): + change_cipher_spec = 20 + alert = 21 + handshake = 22 + application_data = 23 + heartbeat = 24 + + +class HandshakeType(enum.IntEnum): + hello_request = 0 + client_hello = 1 + server_hello = 2 + hello_verify_request = 3 + new_session_ticket = 4 + end_of_early_data = 4 + encrypted_extensions = 8 + certificate = 11 + server_key_exchange = 12 + certificate_request = 13 + server_hello_done = 14 + certificate_verify = 15 + client_key_exchange = 16 + finished = 20 + certificate_url = 21 + 
certificate_status = 22 + supplemental_data = 23 + key_update = 24 + compressed_certificate = 25 + ekt_key = 26 + message_hash = 254 + + +class ProtocolVersion: + DTLS10 = bytes([254, 255]) + DTLS12 = bytes([254, 253]) + + +EPOCH_MASK = 0xFFFF << (6 * 8) + + +# Conventions: +# - All functions that handle network data end in _untrusted. +# - All functions that end in _untrusted MUST make sure that bad data from the +# network can *only* cause BadPacket to be raised. No IndexError or +# struct.error or whatever. +class BadPacket(Exception): + pass + + +# This checks that the DTLS 'epoch' field is 0, which is true iff we're in the +# initial handshake. It doesn't check the ContentType, because not all +# handshake messages have ContentType==handshake -- for example, +# ChangeCipherSpec is used during the handshake but has its own ContentType. +# +# Cannot fail. +def part_of_handshake_untrusted(packet: bytes) -> bool: + # If the packet is too short, then slicing will successfully return a + # short string, which will necessarily fail to match. + return packet[3:5] == b"\x00\x00" + + +# Cannot fail +def is_client_hello_untrusted(packet: bytes) -> bool: + try: + return ( + packet[0] == ContentType.handshake + and packet[13] == HandshakeType.client_hello + ) + except IndexError: + # Invalid DTLS record + return False + + +# DTLS records are: +# - 1 byte content type +# - 2 bytes version +# - 8 bytes epoch+seqno +# Technically this is 2 bytes epoch then 6 bytes seqno, but we treat it as +# a single 8-byte integer, where epoch changes are represented as jumping +# forward by 2**(6*8). 
+# - 2 bytes payload length (unsigned big-endian) +# - payload +RECORD_HEADER = struct.Struct("!B2sQH") + + +def to_hex(data: bytes) -> str: # pragma: no cover + return data.hex() + + +class Record: + content_type: int + version: bytes = attrs.field(repr=to_hex) + epoch_seqno: int + payload: bytes = attrs.field(repr=to_hex) + + +def records_untrusted(packet: bytes) -> Iterator[Record]: + i = 0 + while i < len(packet): + try: + ct, version, epoch_seqno, payload_len = RECORD_HEADER.unpack_from(packet, i) + # Marked as no-cover because at time of writing, this code is unreachable + # (records_untrusted only gets called on packets that are either trusted or that + # have passed is_client_hello_untrusted, which filters out short packets) + except struct.error as exc: # pragma: no cover + raise BadPacket("invalid record header") from exc + i += RECORD_HEADER.size + payload = packet[i : i + payload_len] + if len(payload) != payload_len: + raise BadPacket("short record") + i += payload_len + yield Record(ct, version, epoch_seqno, payload) + + +def encode_record(record: Record) -> bytes: + header = RECORD_HEADER.pack( + record.content_type, + record.version, + record.epoch_seqno, + len(record.payload), + ) + return header + record.payload + + +# Handshake messages are: +# - 1 byte message type +# - 3 bytes total message length +# - 2 bytes message sequence number +# - 3 bytes fragment offset +# - 3 bytes fragment length +HANDSHAKE_MESSAGE_HEADER = struct.Struct("!B3sH3s3s") + + +class HandshakeFragment: + msg_type: int + msg_len: int + msg_seq: int + frag_offset: int + frag_len: int + frag: bytes = attrs.field(repr=to_hex) + + +def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment: + # Raises BadPacket if decoding fails + try: + ( + msg_type, + msg_len_bytes, + msg_seq, + frag_offset_bytes, + frag_len_bytes, + ) = HANDSHAKE_MESSAGE_HEADER.unpack_from(payload) + except struct.error as exc: # TODO: test this line + raise BadPacket("bad handshake message 
header") from exc + # 'struct' doesn't have built-in support for 24-bit integers, so we + # have to do it by hand. These can't fail. + msg_len = int.from_bytes(msg_len_bytes, "big") + frag_offset = int.from_bytes(frag_offset_bytes, "big") + frag_len = int.from_bytes(frag_len_bytes, "big") + frag = payload[HANDSHAKE_MESSAGE_HEADER.size :] + if len(frag) != frag_len: + raise BadPacket("handshake fragment length doesn't match record length") + return HandshakeFragment( + msg_type, + msg_len, + msg_seq, + frag_offset, + frag_len, + frag, + ) + + +def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes: + hs_header = HANDSHAKE_MESSAGE_HEADER.pack( + hsf.msg_type, + hsf.msg_len.to_bytes(3, "big"), + hsf.msg_seq, + hsf.frag_offset.to_bytes(3, "big"), + hsf.frag_len.to_bytes(3, "big"), + ) + return hs_header + hsf.frag + + +def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]: + # Raises BadPacket if parsing fails + # Returns (record epoch_seqno, cookie from the packet, data that should be + # hashed into cookie) + try: + # ClientHello has to be the first record in the packet + record = next(records_untrusted(packet)) + # no-cover because at time of writing, this is unreachable: + # decode_client_hello_untrusted is only called on packets that have passed + # is_client_hello_untrusted, which confirms the content type. + if record.content_type != ContentType.handshake: # pragma: no cover + raise BadPacket("not a handshake record") + fragment = decode_handshake_fragment_untrusted(record.payload) + if fragment.msg_type != HandshakeType.client_hello: + raise BadPacket("not a ClientHello") + # ClientHello can't be fragmented, because reassembly requires holding + # per-connection state, and we refuse to allocate per-connection state + # until after we get a valid ClientHello. 
+ if fragment.frag_offset != 0: + raise BadPacket("fragmented ClientHello") + if fragment.frag_len != fragment.msg_len: + raise BadPacket("fragmented ClientHello") + + # As per RFC 6347: + # + # When responding to a HelloVerifyRequest, the client MUST use the + # same parameter values (version, random, session_id, cipher_suites, + # compression_method) as it did in the original ClientHello. The + # server SHOULD use those values to generate its cookie and verify that + # they are correct upon cookie receipt. + # + # However, the record-layer framing can and will change (e.g. the + # second ClientHello will have a new record-layer sequence number). So + # we need to pull out the handshake message alone, discarding the + # record-layer stuff, and then we're going to hash all of it *except* + # the cookie. + + body = fragment.frag + # ClientHello is: + # + # - 2 bytes client_version + # - 32 bytes random + # - 1 byte session_id length + # - session_id + # - 1 byte cookie length + # - cookie + # - everything else + # + # So to find the cookie, we need to figure out how long the + # session_id is and skip past it. + session_id_len = body[2 + 32] + cookie_len_offset = 2 + 32 + 1 + session_id_len + cookie_len = body[cookie_len_offset] + + cookie_start = cookie_len_offset + 1 + cookie_end = cookie_start + cookie_len + + before_cookie = body[:cookie_len_offset] + cookie = body[cookie_start:cookie_end] + after_cookie = body[cookie_end:] + + if len(cookie) != cookie_len: + raise BadPacket("short cookie") + return (record.epoch_seqno, cookie, before_cookie + after_cookie) + + except (struct.error, IndexError) as exc: + raise BadPacket("bad ClientHello") from exc + + +class HandshakeMessage: + record_version: bytes = attrs.field(repr=to_hex) + msg_type: HandshakeType + msg_seq: int + body: bytearray = attrs.field(repr=to_hex) + + +# ChangeCipherSpec is part of the handshake, but it's not a "handshake +# message" and can't be fragmented the same way. Sigh. 
+class PseudoHandshakeMessage: + record_version: bytes = attrs.field(repr=to_hex) + content_type: int + payload: bytes = attrs.field(repr=to_hex) + + +# The final record in a handshake is Finished, which is encrypted, can't be fragmented +# (at least by us), and keeps its record number (because it's in a new epoch). So we +# just pass it through unchanged. (Fortunately, the payload is only a single hash value, +# so the largest it will ever be is 64 bytes for a 512-bit hash. Which is small enough +# that it never requires fragmenting to fit into a UDP packet. +class OpaqueHandshakeMessage: + record: Record + + +_AnyHandshakeMessage: TypeAlias = Union[ + HandshakeMessage, + PseudoHandshakeMessage, + OpaqueHandshakeMessage, +] + + +# This takes a raw outgoing handshake volley that openssl generated, and +# reconstructs the handshake messages inside it, so that we can repack them +# into records while retransmitting. So the data ought to be well-behaved -- +# it's not coming from the network. +def decode_volley_trusted( + volley: bytes, +) -> list[_AnyHandshakeMessage]: + messages: list[_AnyHandshakeMessage] = [] + messages_by_seq = {} + for record in records_untrusted(volley): + # ChangeCipherSpec isn't a handshake message, so it can't be fragmented. + # Handshake messages with epoch > 0 are encrypted, so we can't fragment them + # either. Fortunately, ChangeCipherSpec has a 1 byte payload, and the only + # encrypted handshake message is Finished, whose payload is a single hash value + # -- so 32 bytes for SHA-256, 64 for SHA-512, etc. Neither is going to be so + # large that it has to be fragmented to fit into a single packet. 
class RecordEncoder:
    """Packs handshake messages into packets no larger than a given MTU.

    Records that this class frames itself (pseudo-handshake and regular
    handshake messages) are numbered from an internal counter. Opaque messages
    are re-encoded from their stored record and do not consume a number.
    """

    def __init__(self) -> None:
        # Sequence numbers for the records we frame ourselves.
        self._record_seq = count()

    def set_first_record_number(self, n: int) -> None:
        """Restart record numbering so that the next framed record gets number ``n``."""
        self._record_seq = count(n)

    def encode_volley(
        self,
        messages: Iterable[_AnyHandshakeMessage],
        mtu: int,
    ) -> list[bytearray]:
        """Encode ``messages`` into a list of packets, each at most ``mtu`` bytes.

        Opaque and pseudo-handshake messages are written as a single record
        each. Regular handshake messages are split into as many fragments as
        needed to fit, with an empty body still producing one fragment.

        Args:
          messages: The handshake messages to send, in order.
          mtu: Maximum size in bytes of each returned packet.

        Returns:
          The encoded packets. No empty packet is ever included.

        Raises:
          ValueError: if ``mtu`` leaves no room for a handshake fragment even
            in an empty packet (previously this case looped forever).
          AssertionError: if a record that cannot be fragmented is larger
            than ``mtu``.

        """
        packets: list[bytearray] = []
        packet = bytearray()
        for message in messages:
            if isinstance(message, OpaqueHandshakeMessage):
                encoded = encode_record(message.record)
                # Only flush a packet that actually holds something; flushing an
                # empty one would make us send a zero-length datagram.
                if packet and mtu - len(packet) - len(encoded) <= 0:  # TODO: test this line
                    packets.append(packet)
                    packet = bytearray()
                packet += encoded
                assert len(packet) <= mtu
            elif isinstance(message, PseudoHandshakeMessage):
                space = mtu - len(packet) - RECORD_HEADER.size - len(message.payload)
                if packet and space <= 0:  # TODO: test this line
                    packets.append(packet)
                    packet = bytearray()
                packet += RECORD_HEADER.pack(
                    message.content_type,
                    message.record_version,
                    next(self._record_seq),
                    len(message.payload),
                )
                packet += message.payload
                assert len(packet) <= mtu
            else:
                msg_len_bytes = len(message.body).to_bytes(3, "big")
                frag_offset = 0
                frags_encoded = 0
                # If message.body is empty, then we still want to encode it in one
                # fragment, not zero.
                while frag_offset < len(message.body) or not frags_encoded:
                    space = (
                        mtu
                        - len(packet)
                        - RECORD_HEADER.size
                        - HANDSHAKE_MESSAGE_HEADER.size
                    )
                    if space <= 0:
                        if not packet:
                            # Starting a fresh packet can't free up any more
                            # room, so retrying would spin forever while piling
                            # up empty packets.
                            raise ValueError(
                                f"MTU of {mtu} bytes is too small to fit a "
                                "handshake fragment",
                            )
                        packets.append(packet)
                        packet = bytearray()
                        continue
                    frag = message.body[frag_offset : frag_offset + space]
                    frag_offset_bytes = frag_offset.to_bytes(3, "big")
                    frag_len_bytes = len(frag).to_bytes(3, "big")
                    frag_offset += len(frag)

                    packet += RECORD_HEADER.pack(
                        ContentType.handshake,
                        message.record_version,
                        next(self._record_seq),
                        HANDSHAKE_MESSAGE_HEADER.size + len(frag),
                    )

                    packet += HANDSHAKE_MESSAGE_HEADER.pack(
                        message.msg_type,
                        msg_len_bytes,
                        message.msg_seq,
                        frag_offset_bytes,
                        frag_len_bytes,
                    )

                    packet += frag

                    frags_encoded += 1
                    assert len(packet) <= mtu

        if packet:
            packets.append(packet)

        return packets
+# +# So the magic cookie needs the following properties: +# - No-one can forge it without knowing our secret key +# - It ensures that the ip, port, and ClientHello contents from the response +# match those in the challenge +# - It expires after a short-ish period (so that if an attacker manages to steal one, it +# won't be useful for long) +# - It doesn't require storing any peer-specific state on our side +# +# To do that, we take the ip/port/ClientHello data and compute an HMAC of them, using a +# secret key we generate on startup. We also include: +# +# - The current time (using Trio's clock), rounded down to a multiple of 30 seconds +# - A random salt +# +# Then the cookie is the salt and the HMAC digest concatenated together. +# +# When verifying a cookie, we use the salt + new ip/port/ClientHello data to recompute +# the HMAC digest, for both the current time and the current time minus 30 seconds, and +# if either of them matches, we consider the cookie good. +# +# Including the rounded-off time like this means that each cookie is good for at least +# 30 seconds, and possibly as much as 60 seconds. +# +# The salt is probably not necessary -- I'm pretty sure that all it does is make it hard +# for an attacker to figure out when our clock ticks over a 30 second boundary. Which is +# probably pretty harmless? But it's easier to add the salt than to convince myself that +# it's *completely* harmless, so, salt it is. + +COOKIE_REFRESH_INTERVAL = 30 # seconds +KEY_BYTES = 32 +COOKIE_HASH = "sha256" +SALT_BYTES = 8 +# 32 bytes was the maximum cookie length in DTLS 1.0. DTLS 1.2 raised it to 255. I doubt +# there are any DTLS 1.0 implementations still in the wild, but really 32 bytes is +# plenty, and it also gets rid of a confusing warning in Wireshark output. +# +# We truncate the cookie to 32 bytes, of which 8 bytes is salt, so that leaves 24 bytes +# of truncated HMAC = 192 bit security, which is still massive overkill. (TCP uses 32 +# *bits* for this.)
def _signable(*fields: bytes) -> bytes:
    """Serialize ``fields`` into one byte string suitable for signing.

    Each field is written as an 8-byte big-endian length followed by the field
    itself, so field boundaries are unambiguous: different field sequences
    never serialize to the same bytes.
    """
    return b"".join(
        piece
        for chunk in fields
        for piece in (struct.pack("!Q", len(chunk)), chunk)
    )
+ return hmac.compare_digest(cookie, cur_cookie) | hmac.compare_digest( + cookie, + old_cookie, + ) + else: + return False + + +def challenge_for( + key: bytes, + address: AddressFormat, + epoch_seqno: int, + client_hello_bits: bytes, +) -> bytes: + salt = os.urandom(SALT_BYTES) + tick = _current_cookie_tick() + cookie = _make_cookie(key, salt, tick, address, client_hello_bits) + + # HelloVerifyRequest body is: + # - 2 bytes version + # - length-prefixed cookie + # + # The DTLS 1.2 spec says that for this message specifically we should use + # the DTLS 1.0 version. + # + # (It also says the opposite of that, but that part is a mistake: + # https://www.rfc-editor.org/errata/eid4103 + # ). + # + # And I guess we use this for both the message-level and record-level + # ProtocolVersions, since we haven't negotiated anything else yet? + body = ProtocolVersion.DTLS10 + bytes([len(cookie)]) + cookie + + # RFC says have to copy the client's record number + # Errata says it should be handshake message number + # Openssl copies back record sequence number, and always sets message seq + # number 0. So I guess we'll follow openssl. 
async def handle_client_hello_untrusted(
    endpoint: DTLSEndpoint,
    address: AddressFormat,
    packet: bytes,
) -> None:
    """Process one incoming packet that looks like a ClientHello from ``address``.

    Returns silently if the endpoint has no listening context or the packet
    cannot be parsed. If the cookie in the packet is not valid for this
    address and ClientHello, the challenge built by `challenge_for` is sent
    back (send failures are ignored). Otherwise a new `DTLSChannel` is
    created, fed the packet, registered in ``endpoint._streams`` (replacing
    any older channel for the same address) and queued on
    ``endpoint._incoming_connections_q``.
    """
    # it's trivial to write a simple function that directly calls this to
    # get code coverage, but it should maybe:
    # 1. be removed
    # 2. be asserted
    # 3. Write a complicated test case where this happens "organically"
    if endpoint._listening_context is None:  # pragma: no cover
        return

    try:
        epoch_seqno, cookie, bits = decode_client_hello_untrusted(packet)
    except BadPacket:
        # Unparseable input from the network: drop it.
        return

    # The cookie-signing key is created lazily, the first time it's needed.
    if endpoint._listening_key is None:
        endpoint._listening_key = os.urandom(KEY_BYTES)

    if not valid_cookie(endpoint._listening_key, cookie, address, bits):
        # No (valid) cookie yet: answer with a challenge carrying a fresh one.
        challenge_packet = challenge_for(
            endpoint._listening_key,
            address,
            epoch_seqno,
            bits,
        )
        try:
            async with endpoint._send_lock:
                await endpoint.socket.sendto(challenge_packet, address)
        except (OSError, trio.ClosedResourceError):
            # Best effort only; a failed send is treated like a lost packet.
            pass
    else:
        # We got a real, valid ClientHello!
        stream = DTLSChannel._create(endpoint, address, endpoint._listening_context)
        # Our HelloVerifyRequest had some sequence number. We need our future sequence
        # numbers to be larger than it, so our peer knows that our future records aren't
        # stale/duplicates. But, we don't know what this sequence number was. What we do
        # know is:
        # - the HelloVerifyRequest seqno was copied from the initial ClientHello
        # - the new ClientHello has a higher seqno than the initial ClientHello
        # So, if we copy the new ClientHello's seqno into our first real handshake
        # record and increment from there, that should work.
        stream._record_encoder.set_first_record_number(epoch_seqno)
        # Process the ClientHello
        try:
            stream._ssl.bio_write(packet)
            stream._ssl.DTLSv1_listen()
        except SSL.Error:  # pragma: no cover
            # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello
            # after all.
            return

        # Check if we have an existing association
        old_stream = endpoint._streams.get(address)
        if old_stream is not None:
            if old_stream._client_hello == (cookie, bits):
                # ...This was just a duplicate of the last ClientHello, so never mind.
                return
            else:
                # Ok, this *really is* a new handshake; the old stream should go away.
                old_stream._set_replaced()
        # Remember which ClientHello created this stream, so that a duplicate of
        # it can be recognized by the check above.
        stream._client_hello = (cookie, bits)
        endpoint._streams[address] = stream
        endpoint._incoming_connections_q.s.send_nowait(stream)
# The attrs decorator generates the ``__init__`` that ``DTLSChannel.statistics``
# relies on when it builds this object from a single positional argument.
# Without it, that call fails with a TypeError.
@attrs.frozen
class DTLSChannelStatistics:
    """Currently this has only one attribute:

    - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
      incoming packets from this peer that Trio successfully received from the
      network, but then got dropped because the internal channel buffer was full. If
      this is non-zero, then you might want to call ``receive`` more often, or use a
      larger ``incoming_packets_buffer``, or just not worry about it because your
      UDP-based protocol should be able to handle the occasional lost packet, right?

    """

    incoming_packets_dropped_in_trio: int
+ + This class has no public constructor – you get instances by calling + `DTLSEndpoint.serve` or `~DTLSEndpoint.connect`. + + .. attribute:: endpoint + + The `DTLSEndpoint` that this connection is using. + + .. attribute:: peer_address + + The IP/port of the remote peer that this connection is associated with. + + """ + + def __init__( + self, + endpoint: DTLSEndpoint, + peer_address: AddressFormat, + ctx: SSL.Context, + ) -> None: + self.endpoint = endpoint + self.peer_address = peer_address + self._packets_dropped_in_trio = 0 + self._client_hello = None + self._did_handshake = False + self._ssl = SSL.Connection(ctx) + self._handshake_mtu = 0 + # This calls self._ssl.set_ciphertext_mtu, which is important, because if you + # don't call it then openssl doesn't work. + self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket)) + self._replaced = False + self._closed = False + self._q = _Queue[bytes](endpoint.incoming_packets_buffer) + self._handshake_lock = trio.Lock() + self._record_encoder: RecordEncoder = RecordEncoder() + + self._final_volley: list[_AnyHandshakeMessage] = [] + + def _set_replaced(self) -> None: + self._replaced = True + # Any packets we already received could maybe possibly still be processed, but + # there are no more coming. So we close this on the sender side. + self._q.s.close() + + def _check_replaced(self) -> None: + if self._replaced: + raise trio.BrokenResourceError( + "peer tore down this connection to start a new one", + ) + + # XX on systems where we can (maybe just Linux?) take advantage of the kernel's PMTU + # estimate + + # XX should we send close-notify when closing? It seems particularly pointless for + # DTLS where packets are all independent and can be lost anyway. We do at least need + # to handle receiving it properly though, which might be easier if we send it... + + def close(self) -> None: + """Close this connection. 
+ + `DTLSChannel`\\s don't actually own any OS-level resources – the + socket is owned by the `DTLSEndpoint`, not the individual connections. So + you don't really *have* to call this. But it will interrupt any other tasks + calling `receive` with a `ClosedResourceError`, and cause future attempts to use + this connection to fail. + + You can also use this object as a synchronous or asynchronous context manager. + + """ + if self._closed: + return + self._closed = True + if self.endpoint._streams.get(self.peer_address) is self: + del self.endpoint._streams[self.peer_address] + # Will wake any tasks waiting on self._q.get with a + # ClosedResourceError + self._q.r.close() + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self.close() + + async def aclose(self) -> None: + """Close this connection, but asynchronously. + + This is included to satisfy the `trio.abc.Channel` contract. It's + identical to `close`, but async. + + """ + self.close() + await trio.lowlevel.checkpoint() + + async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None: + packets = self._record_encoder.encode_volley( + volley_messages, + self._handshake_mtu, + ) + for packet in packets: + async with self.endpoint._send_lock: + await self.endpoint.socket.sendto(packet, self.peer_address) + + async def _resend_final_volley(self) -> None: + await self._send_volley(self._final_volley) + + async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None: + """Perform the handshake. + + Calling this is optional – if you don't, then it will be automatically called + the first time you call `send` or `receive`. 
But calling it explicitly can be + useful in case you want to control the retransmit timeout, use a cancel scope to + place an overall timeout on the handshake, or catch errors from the handshake + specifically. + + It's safe to call this multiple times, or call it simultaneously from multiple + tasks – the first call will perform the handshake, and the rest will be no-ops. + + Args: + + initial_retransmit_timeout (float): Since UDP is an unreliable protocol, it's + possible that some of the packets we send during the handshake will get + lost. To handle this, DTLS uses a timer to automatically retransmit + handshake packets that don't receive a response. This lets you set the + timeout we use to detect packet loss. Ideally, it should be set to ~1.5 + times the round-trip time to your peer, but 1 second is a reasonable + default. There's `some useful guidance here + <https://tlswg.org/dtls13-spec/draft-ietf-tls-dtls13.html#name-timer-values>`__. + + This is the *initial* timeout, because if packets keep being lost then Trio + will automatically back off to longer values, to avoid overloading the + network. + + """ + async with self._handshake_lock: + if self._did_handshake: + return + + timeout = initial_retransmit_timeout + volley_messages: list[_AnyHandshakeMessage] = [] + volley_failed_sends = 0 + + def read_volley() -> list[_AnyHandshakeMessage]: + volley_bytes = _read_loop(self._ssl.bio_read) + new_volley_messages = decode_volley_trusted(volley_bytes) + if ( + new_volley_messages + and volley_messages + and isinstance(new_volley_messages[0], HandshakeMessage) + and isinstance(volley_messages[0], HandshakeMessage) + and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq + ): + # openssl decided to retransmit; discard because we handle + # retransmits ourselves + return [] + else: + return new_volley_messages + + # If we're a client, we send the initial volley. 
If we're a server, then + # the initial ClientHello has already been inserted into self._ssl's + # read BIO. So either way, we start by generating a new volley. + with contextlib.suppress(SSL.WantReadError): + self._ssl.do_handshake() + volley_messages = read_volley() + # If we don't have messages to send in our initial volley, then something + # has gone very wrong. (I'm not sure this can actually happen without an + # error from OpenSSL, but we check just in case.) + if not volley_messages: # pragma: no cover + raise SSL.Error("something wrong with peer's ClientHello") + + while True: + # -- at this point, we need to either send or re-send a volley -- + assert volley_messages + self._check_replaced() + await self._send_volley(volley_messages) + # -- then this is where we wait for a reply -- + self.endpoint._ensure_receive_loop() + with trio.move_on_after(timeout) as cscope: + async for packet in self._q.r: + self._ssl.bio_write(packet) + try: + self._ssl.do_handshake() + # We ignore generic SSL.Error here, because you can get those + # from random invalid packets + except (SSL.WantReadError, SSL.Error): + pass + else: + # No exception -> the handshake is done, and we can + # switch into data transfer mode. + self._did_handshake = True + # Might be empty, but that's ok -- we'll just send no + # packets. + self._final_volley = read_volley() + await self._send_volley(self._final_volley) + return + maybe_volley = read_volley() + if maybe_volley: + if ( + isinstance(maybe_volley[0], PseudoHandshakeMessage) + and maybe_volley[0].content_type == ContentType.alert + ): # TODO: test this line + # we're sending an alert (e.g. due to a corrupted + # packet). We want to send it once, but don't save it to + # retransmit -- keep the last volley as the current + # volley. + await self._send_volley(maybe_volley) + else: + # We managed to get all of the peer's volley and + # generate a new one ourselves! break out of the 'for' + # loop and restart the timer. 
+ volley_messages = maybe_volley + # "Implementations SHOULD retain the current timer value + # until a transmission without loss occurs, at which + # time the value may be reset to the initial value." + if volley_failed_sends == 0: + timeout = initial_retransmit_timeout + volley_failed_sends = 0 + break + else: + assert self._replaced + self._check_replaced() + if cscope.cancelled_caught: + # Timeout expired. Double timeout for backoff, with a limit of 60 + # seconds (this matches what openssl does, and also the + # recommendation in draft-ietf-tls-dtls13). + timeout = min(2 * timeout, 60.0) + volley_failed_sends += 1 + if volley_failed_sends == 2: + # We tried sending this twice and they both failed. Maybe our + # PMTU estimate is wrong? Let's try dropping it to the minimum + # and hope that helps. + self._handshake_mtu = min( + self._handshake_mtu, + worst_case_mtu(self.endpoint.socket), + ) + + async def send(self, data: bytes) -> None: + """Send a packet of data, securely.""" + + if self._closed: + raise trio.ClosedResourceError + if not data: + raise ValueError("openssl doesn't support sending empty DTLS packets") + if not self._did_handshake: + await self.do_handshake() + self._check_replaced() + self._ssl.write(data) + async with self.endpoint._send_lock: + await self.endpoint.socket.sendto( + _read_loop(self._ssl.bio_read), + self.peer_address, + ) + + async def receive(self) -> bytes: + """Fetch the next packet of data from this connection's peer, waiting if + necessary. + + This is safe to call from multiple tasks simultaneously, in case you have some + reason to do that. And more importantly, it's cancellation-safe, meaning that + cancelling a call to `receive` will never cause a packet to be lost or corrupt + the underlying connection. + + """ + if not self._did_handshake: + await self.do_handshake() + # If the packet isn't really valid, then openssl can decode it to the empty + # string (e.g. 
b/c it's a late-arriving handshake packet, or a duplicate copy of + # a data packet). Skip over these instead of returning them. + while True: + try: + packet = await self._q.r.receive() + except trio.EndOfChannel: + assert self._replaced + self._check_replaced() + self._ssl.bio_write(packet) + cleartext = _read_loop(self._ssl.read) + if cleartext: + return cleartext + + def set_ciphertext_mtu(self, new_mtu: int) -> None: + """Tells Trio the `largest amount of data that can be sent in a single packet to + this peer <https://en.wikipedia.org/wiki/Maximum_transmission_unit>`__. + + Trio doesn't actually enforce this limit – if you pass a huge packet to `send`, + then we'll dutifully encrypt it and attempt to send it. But calling this method + does have two useful effects: + + - If called before the handshake is performed, then Trio will automatically + fragment handshake messages to fit within the given MTU. It also might + fragment them even smaller, if it detects signs of packet loss, so setting + this should never be necessary to make a successful connection. But, the + packet loss detection only happens after multiple timeouts have expired, so if + you have reason to believe that a smaller MTU is required, then you can set + this to skip those timeouts and establish the connection more quickly. + + - It changes the value returned from `get_cleartext_mtu`. So if you have some + kind of estimate of the network-level MTU, then you can use this to figure out + how much overhead DTLS will need for hashes/padding/etc., and how much space + you have left for your application data. + + The MTU here is measuring the largest UDP *payload* you think can be sent, the + amount of encrypted data that can be handed to the operating system in a single + call to `send`. It should *not* include IP/UDP headers. Note that OS estimates + of the MTU often are link-layer MTUs, so you have to subtract off 28 bytes on + IPv4 and 48 bytes on IPv6 to get the ciphertext MTU. 
+ + By default, Trio assumes an MTU of 1472 bytes on IPv4, and 1452 bytes on IPv6, + which correspond to the common Ethernet MTU of 1500 bytes after accounting for + IP/UDP overhead. + + """ + self._handshake_mtu = new_mtu + self._ssl.set_ciphertext_mtu(new_mtu) + + def get_cleartext_mtu(self) -> int: + """Returns the largest number of bytes that you can pass in a single call to + `send` while still fitting within the network-level MTU. + + See `set_ciphertext_mtu` for more details. + + """ + if not self._did_handshake: + raise trio.NeedHandshakeError + return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return] + + def statistics(self) -> DTLSChannelStatistics: + """Returns a `DTLSChannelStatistics` object with statistics about this connection.""" + return DTLSChannelStatistics(self._packets_dropped_in_trio) + + +@final +class DTLSEndpoint: + """A DTLS endpoint. + + A single UDP socket can handle arbitrarily many DTLS connections simultaneously, + acting as a client or server as needed. A `DTLSEndpoint` object holds a UDP socket + and manages these connections, which are represented as `DTLSChannel` objects. + + Args: + socket: (trio.socket.SocketType): A ``SOCK_DGRAM`` socket. If you want to accept + incoming connections in server mode, then you should probably bind the socket to + some known port. + incoming_packets_buffer (int): Each `DTLSChannel` using this socket has its own + buffer that holds incoming packets until you call `~DTLSChannel.receive` to read + them. This lets you adjust the size of this buffer. `~DTLSChannel.statistics` + lets you check if the buffer has overflowed. + + .. attribute:: socket + incoming_packets_buffer + + Both constructor arguments are also exposed as attributes, in case you need to + access them later. 
+ + """ + + def __init__( + self, + socket: SocketType, + *, + incoming_packets_buffer: int = 10, + ) -> None: + # We do this lazily on first construction, so only people who actually use DTLS + # have to install PyOpenSSL. + global SSL + from OpenSSL import SSL + + # for __del__, in case the next line raises + self._initialized: bool = False + if socket.type != trio.socket.SOCK_DGRAM: + raise ValueError("DTLS requires a SOCK_DGRAM socket") + self._initialized = True + self.socket: SocketType = socket + + self.incoming_packets_buffer = incoming_packets_buffer + self._token = trio.lowlevel.current_trio_token() + # We don't need to track handshaking vs non-handshake connections + # separately. We only keep one connection per remote address; as soon + # as a peer provides a valid cookie, we can immediately tear down the + # old connection. + # {remote address: DTLSChannel} + self._streams: WeakValueDictionary[AddressFormat, DTLSChannel] = ( + WeakValueDictionary() + ) + self._listening_context: SSL.Context | None = None + self._listening_key: bytes | None = None + self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) + self._send_lock = trio.Lock() + self._closed = False + self._receive_loop_spawned = False + + def _ensure_receive_loop(self) -> None: + # We have to spawn this lazily, because on Windows it will immediately error out + # if the socket isn't already bound -- which for clients might not happen until + # after we send our first packet. + if not self._receive_loop_spawned: + trio.lowlevel.spawn_system_task( + dtls_receive_loop, + weakref.ref(self), + self.socket, + ) + self._receive_loop_spawned = True + + def __del__(self) -> None: + # Do nothing if this object was never fully constructed + if not self._initialized: + return + # Close the socket in Trio context (if our Trio context still exists), so that + # the background task gets notified about the closure and can exit. 
    async def serve(
        self,
        ssl_context: SSL.Context,
        async_fn: Callable[[DTLSChannel, Unpack[PosArgsT]], Awaitable[object]],
        *args: Unpack[PosArgsT],
        task_status: trio.TaskStatus[None] = trio.TASK_STATUS_IGNORED,
    ) -> None:
        """Listen for incoming connections, and spawn a handler for each using an
        internal nursery.

        Similar to `~trio.serve_tcp`, this function never returns until cancelled, or
        the `DTLSEndpoint` is closed and all handlers have exited.

        Usage commonly looks like::

            async def handler(dtls_channel):
                ...

            async with trio.open_nursery() as nursery:
                await nursery.start(dtls_endpoint.serve, ssl_context, handler)
                # ... do other things here ...

        The ``dtls_channel`` passed into the handler function has already performed the
        "cookie exchange" part of the DTLS handshake, so the peer address is
        trustworthy. But the actual cryptographic handshake doesn't happen until you
        start using it, giving you a chance for any last minute configuration, and the
        option to catch and handle handshake errors.

        Args:
          ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for
            incoming connections.
          async_fn: The handler function that will be invoked for each incoming
            connection.
          *args: Additional arguments to pass to the handler function.
          task_status: Reports ``started()`` once this endpoint has begun listening
            with ``ssl_context``.

        Raises:
          trio.ClosedResourceError: if this endpoint has been closed.
          trio.BusyResourceError: if another call to ``serve`` is already listening
            on this endpoint.
          RuntimeError: if the socket has not been bound yet.

        """
        self._check_closed()
        if self._listening_context is not None:
            raise trio.BusyResourceError("another task is already listening")
        # getsockname() is used here purely as a probe for "is the socket bound?"
        try:
            self.socket.getsockname()
        except OSError:  # TODO: test this line
            raise RuntimeError(
                "DTLS socket must be bound before it can serve",
            ) from None
        self._ensure_receive_loop()
        # We do cookie verification ourselves, so tell OpenSSL not to worry about it.
        # (The real check is the valid_cookie call in handle_client_hello_untrusted.)
        ssl_context.set_cookie_verify_callback(lambda *_: True)
        set_ssl_context_options(ssl_context)
        try:
            # While this is set, the endpoint counts as listening.
            self._listening_context = ssl_context
            task_status.started()

            async def handler_wrapper(stream: DTLSChannel) -> None:
                # Using the channel as a context manager closes it when the
                # handler finishes, however it exits.
                with stream:
                    await async_fn(stream, *args)

            async with trio.open_nursery() as nursery:
                async for stream in self._incoming_connections_q.r:  # pragma: no branch
                    nursery.start_soon(handler_wrapper, stream)
        finally:
            # Allow a later serve() call once this one is done.
            self._listening_context = None
def set_ssl_context_options(ctx: SSL.Context) -> None:
    """Switch on the OpenSSL options that every DTLS connection requires.

    Args:
        ctx: The PyOpenSSL context to configure; it is modified in place.
    """
    # OP_NO_QUERY_MTU keeps openssl from trying to query the memory BIO's MTU
    # and then breaking. OP_NO_RENEGOTIATION turns off renegotiation, which is
    # too complex for us to support and isn't useful anyway -- especially for
    # DTLS, where it's equivalent to just performing a new handshake.
    mandatory_options = (
        SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION  # type: ignore[attr-defined]
    )
    ctx.set_options(mandatory_options)
"isatty", + "newlines", + "readable", + "seekable", + "writable", + # not defined in *IOBase: + "buffer", + "raw", + "line_buffering", + "closefd", + "name", + "mode", + "getvalue", + "getbuffer", +} + +# This list is also in the docs, make sure to keep them in sync +_FILE_ASYNC_METHODS: set[str] = { + "flush", + "read", + "read1", + "readall", + "readinto", + "readline", + "readlines", + "seek", + "tell", + "truncate", + "write", + "writelines", + # not defined in *IOBase: + "readinto1", + "peek", +} + + +FileT = TypeVar("FileT") +FileT_co = TypeVar("FileT_co", covariant=True) +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) +AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True) +AnyStr_contra = TypeVar("AnyStr_contra", str, bytes, contravariant=True) + +# This is a little complicated. IO objects have a lot of methods, and which are available on +# different types varies wildly. We want to match the interface of whatever file we're wrapping. +# This pile of protocols each has one sync method/property, meaning they're going to be compatible +# with a file class that supports that method/property. The ones parameterized with AnyStr take +# either str or bytes depending. + +# The wrapper is then a generic class, where the typevar is set to the type of the sync file we're +# wrapping. For generics, adding a type to self has a special meaning - properties/methods can be +# conditional - it's only valid to call them if the object you're accessing them on is compatible +# with that type hint. By using the protocols, the type checker will be checking to see if the +# wrapped type has that method, and only allow the methods that do to be called. We can then alter +# the signature however it needs to match runtime behaviour. 
# More info: https://mypy.readthedocs.io/en/stable/more_types.html#advanced-uses-of-self-types
if TYPE_CHECKING:
    from typing_extensions import Buffer, Protocol

    # fmt: off

    class _HasClosed(Protocol):
        @property
        def closed(self) -> bool: ...

    class _HasEncoding(Protocol):
        @property
        def encoding(self) -> str: ...

    class _HasErrors(Protocol):
        @property
        def errors(self) -> str | None: ...

    class _HasFileNo(Protocol):
        def fileno(self) -> int: ...

    class _HasIsATTY(Protocol):
        def isatty(self) -> bool: ...

    class _HasNewlines(Protocol[T_co]):
        # Type varies here - documented to be None, tuple of strings, strings. Typeshed uses Any.
        @property
        def newlines(self) -> T_co: ...

    class _HasReadable(Protocol):
        def readable(self) -> bool: ...

    class _HasSeekable(Protocol):
        def seekable(self) -> bool: ...

    class _HasWritable(Protocol):
        def writable(self) -> bool: ...

    class _HasBuffer(Protocol):
        @property
        def buffer(self) -> BinaryIO: ...

    class _HasRaw(Protocol):
        @property
        def raw(self) -> io.RawIOBase: ...

    class _HasLineBuffering(Protocol):
        @property
        def line_buffering(self) -> bool: ...

    class _HasCloseFD(Protocol):
        @property
        def closefd(self) -> bool: ...

    class _HasName(Protocol):
        @property
        def name(self) -> str: ...

    class _HasMode(Protocol):
        @property
        def mode(self) -> str: ...

    class _CanGetValue(Protocol[AnyStr_co]):
        def getvalue(self) -> AnyStr_co: ...

    class _CanGetBuffer(Protocol):
        def getbuffer(self) -> memoryview: ...

    class _CanFlush(Protocol):
        def flush(self) -> None: ...

    class _CanRead(Protocol[AnyStr_co]):
        def read(self, size: int | None = ..., /) -> AnyStr_co: ...

    class _CanRead1(Protocol):
        def read1(self, size: int | None = ..., /) -> bytes: ...

    class _CanReadAll(Protocol[AnyStr_co]):
        def readall(self) -> AnyStr_co: ...

    class _CanReadInto(Protocol):
        def readinto(self, buf: Buffer, /) -> int | None: ...

    class _CanReadInto1(Protocol):
        def readinto1(self, buffer: Buffer, /) -> int: ...

    class _CanReadLine(Protocol[AnyStr_co]):
        def readline(self, size: int = ..., /) -> AnyStr_co: ...

    class _CanReadLines(Protocol[AnyStr]):
        def readlines(self, hint: int = ..., /) -> list[AnyStr]: ...

    class _CanSeek(Protocol):
        def seek(self, target: int, whence: int = 0, /) -> int: ...

    class _CanTell(Protocol):
        def tell(self) -> int: ...

    class _CanTruncate(Protocol):
        def truncate(self, size: int | None = ..., /) -> int: ...

    class _CanWrite(Protocol[T_contra]):
        def write(self, data: T_contra, /) -> int: ...

    class _CanWriteLines(Protocol[T_contra]):
        # The lines parameter varies for bytes/str, so use a typevar to make the async match.
        def writelines(self, lines: Iterable[T_contra], /) -> None: ...

    class _CanPeek(Protocol[AnyStr_co]):
        def peek(self, size: int = 0, /) -> AnyStr_co: ...

    class _CanDetach(Protocol[T_co]):
        # The T typevar will be the unbuffered/binary file this file wraps.
        def detach(self) -> T_co: ...

    class _CanClose(Protocol):
        def close(self) -> None: ...


# FileT needs to be covariant for the protocol trick to work - the real IO types are effectively a
# subtype of the protocols.
class AsyncIOWrapper(AsyncResource, Generic[FileT_co]):
    """A generic :class:`~io.IOBase` wrapper that implements the :term:`asynchronous
    file object` interface. Wrapped methods that could block are executed in
    :meth:`trio.to_thread.run_sync`.

    All properties and methods defined in :mod:`~io` are exposed by this
    wrapper, if they exist in the wrapped file object.
    """

    def __init__(self, file: FileT_co) -> None:
        self._wrapped = file

    @property
    def wrapped(self) -> FileT_co:
        """object: A reference to the wrapped file object"""

        return self._wrapped

    if not TYPE_CHECKING:

        def __getattr__(self, name: str) -> object:
            # Only reached for names not found through normal lookup.
            # Names in _FILE_SYNC_ATTRS are forwarded directly to the wrapped
            # file; names in _FILE_ASYNC_METHODS get an async wrapper that
            # runs the wrapped method in a worker thread.
            if name in _FILE_SYNC_ATTRS:
                return getattr(self._wrapped, name)
            if name in _FILE_ASYNC_METHODS:
                meth = getattr(self._wrapped, name)

                @async_wraps(self.__class__, self._wrapped.__class__, name)
                async def wrapper(
                    *args: Callable[..., T],
                    **kwargs: object | str | bool | CapacityLimiter | None,
                ) -> T:
                    func = partial(meth, *args, **kwargs)
                    return await trio.to_thread.run_sync(func)

                # cache the generated method
                setattr(self, name, wrapper)
                return wrapper

            raise AttributeError(name)

    def __dir__(self) -> Iterable[str]:
        # Advertise only those forwarded names the wrapped file really has.
        attrs = set(super().__dir__())
        attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a))
        attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a))
        return attrs

    def __aiter__(self) -> AsyncIOWrapper[FileT_co]:
        return self

    async def __anext__(self: AsyncIOWrapper[_CanReadLine[AnyStr]]) -> AnyStr:
        # An empty line from readline() ends the iteration.
        line = await self.readline()
        if line:
            return line
        else:
            raise StopAsyncIteration

    async def detach(self: AsyncIOWrapper[_CanDetach[T]]) -> AsyncIOWrapper[T]:
        """Like :meth:`io.BufferedIOBase.detach`, but async.

        This also re-wraps the result in a new :term:`asynchronous file object`
        wrapper.

        """

        raw = await trio.to_thread.run_sync(self._wrapped.detach)
        return wrap_file(raw)

    async def aclose(self: AsyncIOWrapper[_CanClose]) -> None:
        """Like :meth:`io.IOBase.close`, but async.

        This is also shielded from cancellation; if a cancellation scope is
        cancelled, the wrapped file object will still be safely closed.

        """

        # ensure the underlying file is closed during cancellation
        with trio.CancelScope(shield=True):
            await trio.to_thread.run_sync(self._wrapped.close)

        await trio.lowlevel.checkpoint_if_cancelled()

    if TYPE_CHECKING:
        # fmt: off
        # Based on typing.IO and io stubs.
        @property
        def closed(self: AsyncIOWrapper[_HasClosed]) -> bool: ...
        @property
        def encoding(self: AsyncIOWrapper[_HasEncoding]) -> str: ...
        @property
        def errors(self: AsyncIOWrapper[_HasErrors]) -> str | None: ...
        @property
        def newlines(self: AsyncIOWrapper[_HasNewlines[T]]) -> T: ...
        @property
        def buffer(self: AsyncIOWrapper[_HasBuffer]) -> BinaryIO: ...
        @property
        def raw(self: AsyncIOWrapper[_HasRaw]) -> io.RawIOBase: ...
        @property
        def line_buffering(self: AsyncIOWrapper[_HasLineBuffering]) -> int: ...
        @property
        def closefd(self: AsyncIOWrapper[_HasCloseFD]) -> bool: ...
        @property
        def name(self: AsyncIOWrapper[_HasName]) -> str: ...
        @property
        def mode(self: AsyncIOWrapper[_HasMode]) -> str: ...

        def fileno(self: AsyncIOWrapper[_HasFileNo]) -> int: ...
        def isatty(self: AsyncIOWrapper[_HasIsATTY]) -> bool: ...
        def readable(self: AsyncIOWrapper[_HasReadable]) -> bool: ...
        def seekable(self: AsyncIOWrapper[_HasSeekable]) -> bool: ...
        def writable(self: AsyncIOWrapper[_HasWritable]) -> bool: ...
        def getvalue(self: AsyncIOWrapper[_CanGetValue[AnyStr]]) -> AnyStr: ...
        def getbuffer(self: AsyncIOWrapper[_CanGetBuffer]) -> memoryview: ...
        async def flush(self: AsyncIOWrapper[_CanFlush]) -> None: ...
        async def read(self: AsyncIOWrapper[_CanRead[AnyStr]], size: int | None = -1, /) -> AnyStr: ...
        async def read1(self: AsyncIOWrapper[_CanRead1], size: int | None = -1, /) -> bytes: ...
        async def readall(self: AsyncIOWrapper[_CanReadAll[AnyStr]]) -> AnyStr: ...
        async def readinto(self: AsyncIOWrapper[_CanReadInto], buf: Buffer, /) -> int | None: ...
        async def readline(self: AsyncIOWrapper[_CanReadLine[AnyStr]], size: int = -1, /) -> AnyStr: ...
        # `hint` mirrors _CanReadLines.readlines; the runtime wrapper forwards
        # positional arguments, so the stub must accept it too.
        async def readlines(self: AsyncIOWrapper[_CanReadLines[AnyStr]], hint: int = -1, /) -> list[AnyStr]: ...
        async def seek(self: AsyncIOWrapper[_CanSeek], target: int, whence: int = 0, /) -> int: ...
        async def tell(self: AsyncIOWrapper[_CanTell]) -> int: ...
        async def truncate(self: AsyncIOWrapper[_CanTruncate], size: int | None = None, /) -> int: ...
        async def write(self: AsyncIOWrapper[_CanWrite[T]], data: T, /) -> int: ...
        async def writelines(self: AsyncIOWrapper[_CanWriteLines[T]], lines: Iterable[T], /) -> None: ...
        async def readinto1(self: AsyncIOWrapper[_CanReadInto1], buffer: Buffer, /) -> int: ...
        async def peek(self: AsyncIOWrapper[_CanPeek[AnyStr]], size: int = 0, /) -> AnyStr: ...


# Type hints are copied from builtin open.
_OpenFile = Union["StrOrBytesPath", int]
_Opener = Callable[[str, int], int]


@overload
async def open_file(
    file: _OpenFile,
    mode: OpenTextMode = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.TextIOWrapper]: ...


@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryMode,
    buffering: Literal[0],
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.FileIO]: ...


@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryModeUpdating,
    buffering: Literal[-1, 1] = -1,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedRandom]: ...
@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryModeWriting,
    buffering: Literal[-1, 1] = -1,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedWriter]: ...


@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryModeReading,
    buffering: Literal[-1, 1] = -1,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[io.BufferedReader]: ...


@overload
async def open_file(
    file: _OpenFile,
    mode: OpenBinaryMode,
    buffering: int,
    encoding: None = None,
    errors: None = None,
    newline: None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[BinaryIO]: ...


@overload
async def open_file(  # type: ignore[explicit-any]  # Any usage matches builtins.open().
    file: _OpenFile,
    mode: str,
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[IO[Any]]: ...


async def open_file(
    file: _OpenFile,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    closefd: bool = True,
    opener: _Opener | None = None,
) -> AsyncIOWrapper[object]:
    """Asynchronous version of :func:`open`.

    All arguments are handed to :func:`io.open` positionally, in the order
    they are declared here; the call itself runs in a worker thread.

    Returns:
        An :term:`asynchronous file object`

    Example::

        async with await trio.open_file(filename) as f:
            async for line in f:
                pass

        assert f.closed

    See also:
      :func:`trio.Path.open`

    """
    # Open in a worker thread, then wrap the resulting sync file object.
    sync_file = await trio.to_thread.run_sync(
        io.open,
        file,
        mode,
        buffering,
        encoding,
        errors,
        newline,
        closefd,
        opener,
    )
    return wrap_file(sync_file)


def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]:
    """Wrap any file object so that it offers an asynchronous file object
    interface.

    Args:
        file: a :term:`file object`

    Returns:
        An :term:`asynchronous file object` that wraps ``file``

    Raises:
        TypeError: if ``file`` lacks a callable ``close``, or has neither a
            callable ``read`` nor a callable ``write``.

    Example::

        async_file = trio.wrap_file(StringIO('asdf'))

        assert await async_file.read() == 'asdf'

    """

    def has(attr: str) -> bool:
        # A missing attribute yields None, which is not callable.
        return callable(getattr(file, attr, None))

    looks_like_file = has("close") and (has("read") or has("write"))
    if not looks_like_file:
        raise TypeError(
            f"{file} does not implement required duck-file methods: "
            "close and (read or write)",
        )

    return AsyncIOWrapper(file)
def _is_halfclosable(stream: SendStream) -> TypeGuard[HalfCloseableStream]:
    """Tell whether ``stream`` has a ``send_eof`` attribute."""
    return hasattr(stream, "send_eof")


@final
@attrs.define(eq=False, slots=False)
class StapledStream(
    HalfCloseableStream,
    Generic[SendStreamT, ReceiveStreamT],
):
    """This class `staples <https://en.wikipedia.org/wiki/Staple_(fastener)>`__
    together two unidirectional streams to make a single bidirectional stream.

    Args:
      send_stream (~trio.abc.SendStream): The stream to use for sending.
      receive_stream (~trio.abc.ReceiveStream): The stream to use for
          receiving.

    Example:

       A silly way to make a stream that echoes back whatever you write to
       it::

          left, right = trio.testing.memory_stream_pair()
          echo_stream = StapledStream(SocketStream(left), SocketStream(right))
          await echo_stream.send_all(b"x")
          assert await echo_stream.receive_some() == b"x"

    :class:`StapledStream` objects implement the methods in the
    :class:`~trio.abc.HalfCloseableStream` interface. They also have two
    additional public attributes:

    .. attribute:: send_stream

       The underlying :class:`~trio.abc.SendStream`. :meth:`send_all` and
       :meth:`wait_send_all_might_not_block` are delegated to this object.

    .. attribute:: receive_stream

       The underlying :class:`~trio.abc.ReceiveStream`. :meth:`receive_some`
       is delegated to this object.

    """

    send_stream: SendStreamT
    receive_stream: ReceiveStreamT

    async def send_all(self, data: bytes | bytearray | memoryview) -> None:
        """Calls ``self.send_stream.send_all``."""
        sender = self.send_stream
        return await sender.send_all(data)

    async def wait_send_all_might_not_block(self) -> None:
        """Calls ``self.send_stream.wait_send_all_might_not_block``."""
        sender = self.send_stream
        return await sender.wait_send_all_might_not_block()

    async def send_eof(self) -> None:
        """Shuts down the send side of the stream.

        If :meth:`self.send_stream.send_eof() <trio.abc.HalfCloseableStream.send_eof>` exists,
        then this calls it. Otherwise, this calls
        :meth:`self.send_stream.aclose() <trio.abc.AsyncResource.aclose>`.
        """
        sender = self.send_stream
        if not _is_halfclosable(sender):
            # No half-close support: closing the send stream is the fallback.
            return await sender.aclose()
        return await sender.send_eof()

    # we intentionally accept more types from the caller than we support returning
    async def receive_some(self, max_bytes: int | None = None) -> bytes:
        """Calls ``self.receive_stream.receive_some``."""
        receiver = self.receive_stream
        return await receiver.receive_some(max_bytes)

    async def aclose(self) -> None:
        """Calls ``aclose`` on both underlying streams."""
        # The receive side is closed even when closing the send side raises.
        try:
            await self.send_stream.aclose()
        finally:
            await self.receive_stream.aclose()
OTOH if +# there isn't room in the backlog queue, then their connect stalls, possibly +# for a long time, which is pretty much the same thing. +# +# A large backlog can also use a bit more kernel memory, but this seems fairly +# negligible these days. +# +# So this suggests we should make the backlog as large as possible. This also +# matches what Golang does. However, they do it in a weird way, where they +# have a bunch of code to sniff out the configured upper limit for backlog on +# different operating systems. But on every system, passing in a too-large +# backlog just causes it to be silently truncated to the configured maximum, +# so this is unnecessary -- we can just pass in "infinity" and get the maximum +# that way. (Verified on Windows, Linux, macOS using +# https://github.com/python-trio/trio/wiki/notes-to-self#measure-listen-backlogpy +def _compute_backlog(backlog: int | None) -> int: + # Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are + # missing overflow protection, so we apply our own overflow protection. + # https://github.com/golang/go/issues/5030 + if not isinstance(backlog, int) and backlog is not None: + raise TypeError(f"backlog must be an int or None, not {backlog!r}") + if backlog is None: + return 0xFFFF + return min(backlog, 0xFFFF) + + +async def open_tcp_listeners( + port: int, + *, + host: str | bytes | None = None, + backlog: int | None = None, +) -> list[trio.SocketListener]: + """Create :class:`SocketListener` objects to listen for TCP connections. + + Args: + + port (int): The port to listen on. + + If you use 0 as your port, then the kernel will automatically pick + an arbitrary open port. But be careful: if you use this feature when + binding to multiple IP addresses, then each IP address will get its + own random port, and the returned listeners will probably be + listening on different ports. 
In particular, this will happen if you + use ``host=None`` – which is the default – because in this case + :func:`open_tcp_listeners` will bind to both the IPv4 wildcard + address (``0.0.0.0``) and also the IPv6 wildcard address (``::``). + + host (str, bytes, or None): The local interface to bind to. This is + passed to :func:`~socket.getaddrinfo` with the ``AI_PASSIVE`` flag + set. + + If you want to bind to the wildcard address on both IPv4 and IPv6, + in order to accept connections on all available interfaces, then + pass ``None``. This is the default. + + If you have a specific interface you want to bind to, pass its IP + address or hostname here. If a hostname resolves to multiple IP + addresses, this function will open one listener on each of them. + + If you want to use only IPv4, or only IPv6, but want to accept on + all interfaces, pass the family-specific wildcard address: + ``"0.0.0.0"`` for IPv4-only and ``"::"`` for IPv6-only. + + backlog (int or None): The listen backlog to use. If you leave this as + ``None`` then Trio will pick a good default. (Currently: whatever + your system has configured as the maximum backlog.) + + Returns: + list of :class:`SocketListener` + + Raises: + :class:`TypeError` if invalid arguments. + + """ + # getaddrinfo sometimes allows port=None, sometimes not (depending on + # whether host=None). 
And on some systems it treats "" as 0, others it + # doesn't: + # http://klickverbot.at/blog/2012/01/getaddrinfo-edge-case-behavior-on-windows-linux-and-osx/ + if not isinstance(port, int): + raise TypeError(f"port must be an int not {port!r}") + + computed_backlog = _compute_backlog(backlog) + + addresses = await tsocket.getaddrinfo( + host, + port, + type=tsocket.SOCK_STREAM, + flags=tsocket.AI_PASSIVE, + ) + + listeners = [] + unsupported_address_families = [] + try: + for family, type_, proto, _, sockaddr in addresses: + try: + sock = tsocket.socket(family, type_, proto) + except OSError as ex: + if ex.errno == errno.EAFNOSUPPORT: + # If a system only supports IPv4, or only IPv6, it + # is still likely that getaddrinfo will return + # both an IPv4 and an IPv6 address. As long as at + # least one of the returned addresses can be + # turned into a socket, we won't complain about a + # failure to create the other. + unsupported_address_families.append(ex) + continue + else: + raise + try: + # See https://github.com/python-trio/trio/issues/39 + if sys.platform != "win32": + sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_REUSEADDR, 1) + + if family == tsocket.AF_INET6: + sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, 1) + + await sock.bind(sockaddr) + sock.listen(computed_backlog) + + listeners.append(trio.SocketListener(sock)) + except: + sock.close() + raise + except: + for listener in listeners: + listener.socket.close() + raise + + if unsupported_address_families and not listeners: + msg = ( + "This system doesn't support any of the kinds of " + "socket that that address could use" + ) + raise OSError(errno.EAFNOSUPPORT, msg) from ExceptionGroup( + msg, + unsupported_address_families, + ) + + return listeners + + +async def serve_tcp( + handler: Callable[[trio.SocketStream], Awaitable[object]], + port: int, + *, + host: str | bytes | None = None, + backlog: int | None = None, + handler_nursery: trio.Nursery | None = None, + task_status: 
TaskStatus[list[trio.SocketListener]] = trio.TASK_STATUS_IGNORED, +) -> None: + """Listen for incoming TCP connections, and for each one start a task + running ``handler(stream)``. + + This is a thin convenience wrapper around :func:`open_tcp_listeners` and + :func:`serve_listeners` – see them for full details. + + .. warning:: + + If ``handler`` raises an exception, then this function doesn't do + anything special to catch it – so by default the exception will + propagate out and crash your server. If you don't want this, then catch + exceptions inside your ``handler``, or use a ``handler_nursery`` object + that responds to exceptions in some other way. + + When used with ``nursery.start`` you get back the newly opened listeners. + So, for example, if you want to start a server in your test suite and then + connect to it to check that it's working properly, you can use something + like:: + + from trio import SocketListener, SocketStream + from trio.testing import open_stream_to_socket_listener + + async with trio.open_nursery() as nursery: + listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) + client_stream: SocketStream = await open_stream_to_socket_listener(listeners[0]) + + # Then send and receive data on 'client_stream', for example: + await client_stream.send_all(b"GET / HTTP/1.0\\r\\n\\r\\n") + + This avoids several common pitfalls: + + 1. It lets the kernel pick a random open port, so your test suite doesn't + depend on any particular port being open. + + 2. It waits for the server to be accepting connections on that port before + ``start`` returns, so there's no race condition where the incoming + connection arrives before the server is ready. + + 3. It uses the Listener object to find out which port was picked, so it + can connect to the right place. + + Args: + handler: The handler to start for each incoming connection. Passed to + :func:`serve_listeners`. + + port: The port to listen on. 
Use 0 to let the kernel pick an open port. + Passed to :func:`open_tcp_listeners`. + + host (str, bytes, or None): The host interface to listen on; use + ``None`` to bind to the wildcard address. Passed to + :func:`open_tcp_listeners`. + + backlog: The listen backlog, or None to have a good default picked. + Passed to :func:`open_tcp_listeners`. + + handler_nursery: The nursery to start handlers in, or None to use an + internal nursery. Passed to :func:`serve_listeners`. + + task_status: This function can be used with ``nursery.start``. + + Returns: + This function only returns when cancelled. + + """ + listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog) + await trio.serve_listeners( + handler, + listeners, + handler_nursery=handler_nursery, + task_status=task_status, + ) diff --git a/contrib/python/trio/trio/_highlevel_open_tcp_stream.py b/contrib/python/trio/trio/_highlevel_open_tcp_stream.py new file mode 100644 index 00000000000..1787f4a97e9 --- /dev/null +++ b/contrib/python/trio/trio/_highlevel_open_tcp_stream.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +import sys +from contextlib import contextmanager, suppress +from typing import TYPE_CHECKING, Any + +import trio +from trio.socket import SOCK_STREAM, SocketType, getaddrinfo, socket + +if TYPE_CHECKING: + from collections.abc import Generator, MutableSequence + from socket import AddressFamily, SocketKind + + from trio._socket import AddressFormat + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup, ExceptionGroup + + +# Implementation of RFC 6555 "Happy eyeballs" +# https://tools.ietf.org/html/rfc6555 +# +# Basically, the problem here is that if we want to connect to some host, and +# DNS returns multiple IP addresses, then we don't know which of them will +# actually work -- it can happen that some of them are reachable, and some of +# them are not. 
One particularly common situation where this happens is on a +# host that thinks it has ipv6 connectivity, but really doesn't. But in +# principle this could happen for any kind of multi-home situation (e.g. the +# route to one mirror is down but another is up). +# +# The naive algorithm (e.g. the stdlib's socket.create_connection) would be to +# pick one of the IP addresses and try to connect; if that fails, try the +# next; etc. The problem with this is that TCP is stubborn, and if the first +# address is a blackhole then it might take a very long time (tens of seconds) +# before that connection attempt fails. +# +# That's where RFC 6555 comes in. It tells us that what we do is: +# - get the list of IPs from getaddrinfo, trusting the order it gives us (with +# one exception noted in section 5.4) +# - start a connection attempt to the first IP +# - when this fails OR if it's still going after DELAY seconds, then start a +# connection attempt to the second IP +# - when this fails OR if it's still going after another DELAY seconds, then +# start a connection attempt to the third IP +# - ... repeat until we run out of IPs. +# +# Our implementation is similarly straightforward: we spawn a chain of tasks, +# where each one (a) waits until the previous connection has failed or DELAY +# seconds have passed, (b) spawns the next task, (c) attempts to connect. As +# soon as any task crashes or succeeds, we cancel all the tasks and return. +# +# Note: this currently doesn't attempt to cache any results, so if you make +# multiple connections to the same host it'll re-run the happy-eyeballs +# algorithm each time. RFC 6555 is pretty confusing about whether this is +# allowed. Section 4 describes an algorithm that attempts ipv4 and ipv6 +# simultaneously, and then says "The client MUST cache information regarding +# the outcome of each connection attempt, and it uses that information to +# avoid thrashing the network with subsequent attempts." 
Then section 4.2 says +# "implementations MUST prefer the first IP address family returned by the +# host's address preference policy, unless implementing a stateful +# algorithm". Here "stateful" means "one that caches information about +# previous attempts". So my reading of this is that IF you're starting ipv4 +# and ipv6 at the same time then you MUST cache the result for ~ten minutes, +# but IF you're "preferring" one protocol by trying it first (like we are), +# then you don't need to cache. +# +# Caching is quite tricky: to get it right you need to do things like detect +# when the network interfaces are reconfigured, and if you get it wrong then +# connection attempts basically just don't work. So we don't even try. + +# "Firefox and Chrome use 300 ms" +# https://tools.ietf.org/html/rfc6555#section-6 +# Though +# https://www.researchgate.net/profile/Vaibhav_Bajpai3/publication/304568993_Measuring_the_Effects_of_Happy_Eyeballs/links/5773848e08ae6f328f6c284c/Measuring-the-Effects-of-Happy-Eyeballs.pdf +# claims that Firefox actually uses 0 ms, unless an about:config option is +# toggled and then it uses 250 ms. +DEFAULT_DELAY = 0.250 + +# How should we call getaddrinfo? In particular, should we use AI_ADDRCONFIG? +# +# The idea of AI_ADDRCONFIG is that it only returns addresses that might +# work. E.g., if getaddrinfo knows that you don't have any IPv6 connectivity, +# then it doesn't return any IPv6 addresses. And this is kinda nice, because +# it means maybe you can skip sending AAAA requests entirely. But in practice, +# it doesn't really work right. +# +# - on Linux/glibc, empirically, the default is to return all addresses, and +# with AI_ADDRCONFIG then it only returns IPv6 addresses if there is at least +# one non-loopback IPv6 address configured... but this can be a link-local +# address, so in practice I guess this is basically always configured if IPv6 +# is enabled at all. 
OTOH if you pass in "::1" as the target address with +# AI_ADDRCONFIG and there's no *external* IPv6 address configured, you get an +# error. So AI_ADDRCONFIG mostly doesn't do anything, even when you would want +# it to, and when it does do something it might break things that would have +# worked. +# +# - on Windows 10, empirically, if no IPv6 address is configured then by +# default they are also suppressed from getaddrinfo (flags=0 and +# flags=AI_ADDRCONFIG seem to do the same thing). If you pass AI_ALL, then you +# get the full list. +# ...except for localhost! getaddrinfo("localhost", "80") gives me ::1, even +# though there's no ipv6 and other queries only return ipv4. +# If you pass in an IPv6 IP address as the target address, then that's always +# returned OK, even with AI_ADDRCONFIG set and no IPv6 configured. +# +# But I guess other versions of windows messed this up, judging from these bug +# reports: +# https://bugs.chromium.org/p/chromium/issues/detail?id=5234 +# https://bugs.chromium.org/p/chromium/issues/detail?id=32522#c50 +# +# So basically the options are either to use AI_ADDRCONFIG and then add some +# complicated special cases to work around its brokenness, or else don't use +# AI_ADDRCONFIG and accept that sometimes on legacy/misconfigured networks +# we'll waste 300 ms trying to connect to a blackholed destination. +# +# Twisted and Tornado always use default flags. I think we'll do the same. 
+ + +@contextmanager +def close_all() -> Generator[set[SocketType], None, None]: + sockets_to_close: set[SocketType] = set() + try: + yield sockets_to_close + finally: + errs = [] + for sock in sockets_to_close: + try: + sock.close() + except BaseException as exc: + errs.append(exc) + if len(errs) == 1: + raise errs[0] + elif errs: + raise BaseExceptionGroup("", errs) + + +def reorder_for_rfc_6555_section_5_4( # type: ignore[explicit-any] + targets: MutableSequence[tuple[AddressFamily, SocketKind, int, str, Any]], +) -> None: + # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address + # families (e.g. IPv4 and IPv6), then you should make sure that your first + # and second attempts use different families: + # + # https://tools.ietf.org/html/rfc6555#section-5.4 + # + # This function post-processes the results from getaddrinfo, in-place, to + # satisfy this requirement. + for i in range(1, len(targets)): + if targets[i][0] != targets[0][0]: + # Found the first entry with a different address family; move it + # so that it becomes the second item on the list. + if i != 1: + targets.insert(1, targets.pop(i)) + break + + +def format_host_port(host: str | bytes, port: int | str) -> str: + host = host.decode("ascii") if isinstance(host, bytes) else host + if ":" in host: + return f"[{host}]:{port}" + else: + return f"{host}:{port}" + + +# Twisted's HostnameEndpoint has a good set of configurables: +# https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.HostnameEndpoint.html +# +# - per-connection timeout +# this doesn't seem useful -- we let you set a timeout on the whole thing +# using Trio's normal mechanisms, and that seems like enough +# - delay between attempts +# - bind address (but not port!) +# they *don't* support multiple address bindings, like giving the ipv4 and +# ipv6 addresses of the host. 
+# I think maybe our semantics should be: we accept a list of bind addresses, +# and we bind to the first one that is compatible with the +# connection attempt we want to make, and if none are compatible then we +# don't try to connect to that target. +# +# XX TODO: implement bind address support +# +# Actually, the best option is probably to be explicit: {AF_INET: "...", +# AF_INET6: "..."} +# this might be simpler after +async def open_tcp_stream( + host: str | bytes, + port: int, + *, + happy_eyeballs_delay: float | None = DEFAULT_DELAY, + local_address: str | None = None, +) -> trio.SocketStream: + """Connect to the given host and port over TCP. + + If the given ``host`` has multiple IP addresses associated with it, then + we have a problem: which one do we use? + + One approach would be to attempt to connect to the first one, and then if + that fails, attempt to connect to the second one ... until we've tried all + of them. But the problem with this is that if the first IP address is + unreachable (for example, because it's an IPv6 address and our network + discards IPv6 packets), then we might end up waiting tens of seconds for + the first connection attempt to timeout before we try the second address. + + Another approach would be to attempt to connect to all of the addresses at + the same time, in parallel, and then use whichever connection succeeds + first, abandoning the others. This would be fast, but create a lot of + unnecessary load on the network and the remote server. + + This function strikes a balance between these two extremes: it works its + way through the available addresses one at a time, like the first + approach; but, if ``happy_eyeballs_delay`` seconds have passed and it's + still waiting for an attempt to succeed or fail, then it gets impatient + and starts the next connection attempt in parallel. As soon as any one + connection attempt succeeds, all the other attempts are cancelled. 
This + avoids unnecessary load because most connections will succeed after just + one or two attempts, but if one of the addresses is unreachable then it + doesn't slow us down too much. + + This is known as a "happy eyeballs" algorithm, and our particular variant + is modelled after how Chrome connects to webservers; see `RFC 6555 + <https://tools.ietf.org/html/rfc6555>`__ for more details. + + Args: + host (str or bytes): The host to connect to. Can be an IPv4 address, + IPv6 address, or a hostname. + + port (int): The port to connect to. + + happy_eyeballs_delay (float or None): How many seconds to wait for each + connection attempt to succeed or fail before getting impatient and + starting another one in parallel. Set to `math.inf` if you want + to limit to only one connection attempt at a time (like + :func:`socket.create_connection`). `None` means the default: 0.25 (250 ms). + + local_address (None or str): The local IP address or hostname to use as + the source for outgoing connections. If ``None``, we let the OS pick + the source IP. + + This is useful in some exotic networking configurations where your + host has multiple IP addresses, and you want to force the use of a + specific one. + + Note that if you pass an IPv4 ``local_address``, then you won't be + able to connect to IPv6 hosts, and vice-versa. If you want to take + advantage of this to force the use of IPv4 or IPv6 without + specifying an exact source address, you can use the IPv4 wildcard + address ``local_address="0.0.0.0"``, or the IPv6 wildcard address + ``local_address="::"``. + + Returns: + SocketStream: a :class:`~trio.abc.Stream` connected to the given server. + + Raises: + OSError: if the connection fails. + + See also: + open_ssl_over_tcp_stream + + """ + + # To keep our public API surface smaller, rule out some cases that + # getaddrinfo will accept in some circumstances, but that act weird or + # have non-portable behavior or are just plain not useful. 
+ if not isinstance(host, (str, bytes)): + raise ValueError(f"host must be str or bytes, not {host!r}") + if not isinstance(port, int): + raise TypeError(f"port must be int, not {port!r}") + + if happy_eyeballs_delay is None: + happy_eyeballs_delay = DEFAULT_DELAY + + targets = await getaddrinfo(host, port, type=SOCK_STREAM) + + # I don't think this can actually happen -- if there are no results, + # getaddrinfo should have raised OSError instead of returning an empty + # list. But let's be paranoid and handle it anyway: + if not targets: + msg = f"no results found for hostname lookup: {format_host_port(host, port)}" + raise OSError(msg) + + reorder_for_rfc_6555_section_5_4(targets) + + # This list records all the connection failures that we ignored. + oserrors: list[OSError] = [] + + # Keeps track of the socket that we're going to complete with, + # need to make sure this isn't automatically closed + winning_socket: SocketType | None = None + + # Try connecting to the specified address. Possible outcomes: + # - success: record connected socket in winning_socket and cancel + # concurrent attempts + # - failure: record exception in oserrors, set attempt_failed allowing + # the next connection attempt to start early + # code needs to ensure sockets can be closed appropriately in the + # face of crash or cancellation + async def attempt_connect( + socket_args: tuple[AddressFamily, SocketKind, int], + sockaddr: AddressFormat, + attempt_failed: trio.Event, + ) -> None: + nonlocal winning_socket + + try: + sock = socket(*socket_args) + open_sockets.add(sock) + + if local_address is not None: + # TCP connections are identified by a 4-tuple: + # + # (local IP, local port, remote IP, remote port) + # + # So if a single local IP wants to make multiple connections + # to the same (remote IP, remote port) pair, then those + # connections have to use different local ports, or else TCP + # won't be able to tell them apart. 
OTOH, if you have multiple + # connections to different remote IP/ports, then those + # connections can share a local port. + # + # Normally, when you call bind(), the kernel will immediately + # assign a specific local port to your socket. At this point + # the kernel doesn't know which (remote IP, remote port) + # you're going to use, so it has to pick a local port that + # *no* other connection is using. That's the only way to + # guarantee that this local port will be usable later when we + # call connect(). (Alternatively, you can set SO_REUSEADDR to + # allow multiple nascent connections to share the same port, + # but then connect() might fail with EADDRNOTAVAIL if we get + # unlucky and our TCP 4-tuple ends up colliding with another + # unrelated connection.) + # + # So calling bind() before connect() works, but it disables + # sharing of local ports. This is inefficient: it makes you + # more likely to run out of local ports. + # + # But on some versions of Linux, we can re-enable sharing of + # local ports by setting a special flag. This flag tells + # bind() to only bind the IP, and not the port. That way, + # connect() is allowed to pick the port, and it can do a + # better job of it because it knows the remote IP/port. + with suppress(OSError, AttributeError): + sock.setsockopt( + trio.socket.IPPROTO_IP, + trio.socket.IP_BIND_ADDRESS_NO_PORT, + 1, + ) + try: + await sock.bind((local_address, 0)) + except OSError: + raise OSError( + f"local_address={local_address!r} is incompatible " + f"with remote address {sockaddr!r}", + ) from None + + await sock.connect(sockaddr) + + # Success! Save the winning socket and cancel all outstanding + # connection attempts. + winning_socket = sock + nursery.cancel_scope.cancel(reason="successfully found a socket") + except OSError as exc: + # This connection attempt failed, but the next one might + # succeed. 
Save the error for later so we can report it if + # everything fails, and tell the next attempt that it should go + # ahead (if it hasn't already). + oserrors.append(exc) + attempt_failed.set() + + with close_all() as open_sockets: + # nursery spawns a task for each connection attempt, will be + # cancelled by the task that gets a successful connection + async with trio.open_nursery() as nursery: + for address_family, socket_type, proto, _, addr in targets: + # create an event to indicate connection failure, + # allowing the next target to be tried early + attempt_failed = trio.Event() + + nursery.start_soon( + attempt_connect, + (address_family, socket_type, proto), + addr, + attempt_failed, + ) + + # give this attempt at most this time before moving on + with trio.move_on_after(happy_eyeballs_delay): + await attempt_failed.wait() + + # nothing succeeded + if winning_socket is None: + assert len(oserrors) == len(targets) + msg = f"all attempts to connect to {format_host_port(host, port)} failed" + raise OSError(msg) from ExceptionGroup(msg, oserrors) + else: + stream = trio.SocketStream(winning_socket) + open_sockets.remove(winning_socket) + return stream diff --git a/contrib/python/trio/trio/_highlevel_open_unix_stream.py b/contrib/python/trio/trio/_highlevel_open_unix_stream.py new file mode 100644 index 00000000000..d419574369c --- /dev/null +++ b/contrib/python/trio/trio/_highlevel_open_unix_stream.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import os +from contextlib import contextmanager +from typing import TYPE_CHECKING, Protocol, TypeVar + +import trio +from trio.socket import SOCK_STREAM, socket + +if TYPE_CHECKING: + from collections.abc import Generator + + +class Closable(Protocol): + def close(self) -> None: ... 
+ + +CloseT = TypeVar("CloseT", bound=Closable) + + +try: + from trio.socket import AF_UNIX + + has_unix = True +except ImportError: + has_unix = False + + +@contextmanager +def close_on_error(obj: CloseT) -> Generator[CloseT, None, None]: + try: + yield obj + except: + obj.close() + raise + + +async def open_unix_socket( + filename: str | bytes | os.PathLike[str] | os.PathLike[bytes], +) -> trio.SocketStream: + """Opens a connection to the specified + `Unix domain socket <https://en.wikipedia.org/wiki/Unix_domain_socket>`__. + + You must have read/write permission on the specified file to connect. + + Args: + filename (str or bytes): The filename to open the connection to. + + Returns: + SocketStream: a :class:`~trio.abc.Stream` connected to the given file. + + Raises: + OSError: If the socket file could not be connected to. + RuntimeError: If AF_UNIX sockets are not supported. + """ + if not has_unix: + raise RuntimeError("Unix sockets are not supported on this platform") + + # much more simplified logic vs tcp sockets - one socket type and only one + # possible location to connect to + sock = socket(AF_UNIX, SOCK_STREAM) + with close_on_error(sock): + await sock.connect(os.fspath(filename)) + + return trio.SocketStream(sock) diff --git a/contrib/python/trio/trio/_highlevel_serve_listeners.py b/contrib/python/trio/trio/_highlevel_serve_listeners.py new file mode 100644 index 00000000000..008caaabea5 --- /dev/null +++ b/contrib/python/trio/trio/_highlevel_serve_listeners.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import errno +import logging +import os +from collections.abc import Awaitable, Callable +from typing import Any, NoReturn, TypeVar + +import trio + +# Errors that accept(2) can return, and which indicate that the system is +# overloaded +ACCEPT_CAPACITY_ERRNOS = { + errno.EMFILE, + errno.ENFILE, + errno.ENOMEM, + errno.ENOBUFS, +} + +# How long to sleep when we get one of those errors +SLEEP_TIME = 0.100 + +# The logger we use to 
complain when this happens +LOGGER = logging.getLogger("trio.serve_listeners") + + +StreamT = TypeVar("StreamT", bound=trio.abc.AsyncResource) +ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any]) # type: ignore[explicit-any] +Handler = Callable[[StreamT], Awaitable[object]] + + +async def _run_handler(stream: StreamT, handler: Handler[StreamT]) -> None: + try: + await handler(stream) + finally: + await trio.aclose_forcefully(stream) + + +async def _serve_one_listener( + listener: trio.abc.Listener[StreamT], + handler_nursery: trio.Nursery, + handler: Handler[StreamT], +) -> NoReturn: + async with listener: + while True: + try: + stream = await listener.accept() + except OSError as exc: + if exc.errno in ACCEPT_CAPACITY_ERRNOS: + LOGGER.error( + "accept returned %s (%s); retrying in %s seconds", + errno.errorcode[exc.errno], + os.strerror(exc.errno), + SLEEP_TIME, + exc_info=True, + ) + await trio.sleep(SLEEP_TIME) + else: + raise + else: + handler_nursery.start_soon(_run_handler, stream, handler) + + +# This cannot be typed correctly, we need generic typevar bounds / HKT to indicate the +# relationship between StreamT & ListenerT. +# https://github.com/python/typing/issues/1226 +# https://github.com/python/typing/issues/548 + + +async def serve_listeners( # type: ignore[explicit-any] + handler: Handler[StreamT], + listeners: list[ListenerT], + *, + handler_nursery: trio.Nursery | None = None, + task_status: trio.TaskStatus[list[ListenerT]] = trio.TASK_STATUS_IGNORED, +) -> NoReturn: + r"""Listen for incoming connections on ``listeners``, and for each one + start a task running ``handler(stream)``. + + .. warning:: + + If ``handler`` raises an exception, then this function doesn't do + anything special to catch it – so by default the exception will + propagate out and crash your server. If you don't want this, then catch + exceptions inside your ``handler``, or use a ``handler_nursery`` object + that responds to exceptions in some other way. 
+ + Args: + + handler: An async callable, that will be invoked like + ``handler_nursery.start_soon(handler, stream)`` for each incoming + connection. + + listeners: A list of :class:`~trio.abc.Listener` objects. + :func:`serve_listeners` takes responsibility for closing them. + + handler_nursery: The nursery used to start handlers, or any object with + a ``start_soon`` method. If ``None`` (the default), then + :func:`serve_listeners` will create a new nursery internally and use + that. + + task_status: This function can be used with ``nursery.start``, which + will return ``listeners``. + + Returns: + + This function never returns unless cancelled. + + Resource handling: + + If ``handler`` neglects to close the ``stream``, then it will be closed + using :func:`trio.aclose_forcefully`. + + Error handling: + + Most errors coming from :meth:`~trio.abc.Listener.accept` are allowed to + propagate out (crashing the server in the process). However, some errors – + those which indicate that the server is temporarily overloaded – are + handled specially. These are :class:`OSError`\s with one of the following + errnos: + + * ``EMFILE``: process is out of file descriptors + * ``ENFILE``: system is out of file descriptors + * ``ENOBUFS``, ``ENOMEM``: the kernel hit some sort of memory limitation + when trying to create a socket object + + When :func:`serve_listeners` gets one of these errors, then it: + + * Logs the error to the standard library logger ``trio.serve_listeners`` + (level = ERROR, with exception information included). By default this + causes it to be printed to stderr. + * Waits 100 ms before calling ``accept`` again, in hopes that the + system will recover. 
+ + """ + async with trio.open_nursery() as nursery: + if handler_nursery is None: + handler_nursery = nursery + for listener in listeners: + nursery.start_soon(_serve_one_listener, listener, handler_nursery, handler) + # The listeners are already queueing connections when we're called, + # but we wait until the end to call started() just in case we get an + # error or whatever. + task_status.started(listeners) + + raise AssertionError( + "_serve_one_listener should never complete", + ) # pragma: no cover diff --git a/contrib/python/trio/trio/_highlevel_socket.py b/contrib/python/trio/trio/_highlevel_socket.py new file mode 100644 index 00000000000..142ab11e073 --- /dev/null +++ b/contrib/python/trio/trio/_highlevel_socket.py @@ -0,0 +1,423 @@ +# "High-level" networking interface +from __future__ import annotations + +import errno +from contextlib import contextmanager, suppress +from typing import TYPE_CHECKING, overload + +import trio + +from . import socket as tsocket +from ._util import ConflictDetector, final +from .abc import HalfCloseableStream, Listener + +if TYPE_CHECKING: + from collections.abc import Generator + + from ._socket import SocketType + +import sys + +if sys.version_info >= (3, 12): + # NOTE: this isn't in the `TYPE_CHECKING` since for some reason + # sphinx doesn't autoreload this module for SocketStream + # (hypothesis: it's our module renaming magic) + from collections.abc import Buffer +elif TYPE_CHECKING: + from typing_extensions import Buffer + +# XX TODO: this number was picked arbitrarily. We should do experiments to +# tune it. (Or make it dynamic -- one idea is to start small and increase it +# if we observe single reads filling up the whole buffer, at least within some +# limits.) 
+DEFAULT_RECEIVE_SIZE = 65536 + +_closed_stream_errnos = { + # Unix + errno.EBADF, + # Windows + errno.ENOTSOCK, +} + + +@contextmanager +def _translate_socket_errors_to_stream_errors() -> Generator[None, None, None]: + try: + yield + except OSError as exc: + if exc.errno in _closed_stream_errnos: + raise trio.ClosedResourceError("this socket was already closed") from None + else: + raise trio.BrokenResourceError(f"socket connection broken: {exc}") from exc + + +@final +class SocketStream(HalfCloseableStream): + """An implementation of the :class:`trio.abc.HalfCloseableStream` + interface based on a raw network socket. + + Args: + socket: The Trio socket object to wrap. Must have type ``SOCK_STREAM``, + and be connected. + + By default for TCP sockets, :class:`SocketStream` enables ``TCP_NODELAY``, + and (on platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with + a reasonable buffer size (currently 16 KiB) – see `issue #72 + <https://github.com/python-trio/trio/issues/72>`__ for discussion. You can + of course override these defaults by calling :meth:`setsockopt`. + + Once a :class:`SocketStream` object is constructed, it implements the full + :class:`trio.abc.HalfCloseableStream` interface. In addition, it provides + a few extra features: + + .. attribute:: socket + + The Trio socket object that this stream wraps. + + """ + + def __init__(self, socket: SocketType) -> None: + if not isinstance(socket, tsocket.SocketType): + raise TypeError("SocketStream requires a Trio socket object") + if socket.type != tsocket.SOCK_STREAM: + raise ValueError("SocketStream requires a SOCK_STREAM socket") + + self.socket = socket + self._send_conflict_detector = ConflictDetector( + "another task is currently sending data on this SocketStream", + ) + + # Socket defaults: + + # Not supported on e.g. 
unix domain sockets + with suppress(OSError): + self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, True) + + if hasattr(tsocket, "TCP_NOTSENT_LOWAT"): + # 16 KiB is pretty arbitrary and could probably do with some + # tuning. (Apple is also setting this by default in CFNetwork + # apparently -- I'm curious what value they're using, though I + # couldn't find it online trivially. CFNetwork-129.20 source + # has no mentions of TCP_NOTSENT_LOWAT. This presentation says + # "typically 8 kilobytes": + # http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1 + # ). The theory is that you want it to be bandwidth * + # rescheduling interval. + with suppress(OSError): + self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2**14) + + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + if self.socket.did_shutdown_SHUT_WR: + raise trio.ClosedResourceError("can't send data after sending EOF") + with self._send_conflict_detector: + with _translate_socket_errors_to_stream_errors(): + with memoryview(data) as data: + if not data: + if self.socket.fileno() == -1: + raise trio.ClosedResourceError("socket was already closed") + await trio.lowlevel.checkpoint() + return + total_sent = 0 + while total_sent < len(data): + with data[total_sent:] as remaining: + sent = await self.socket.send(remaining) + total_sent += sent + + async def wait_send_all_might_not_block(self) -> None: + with self._send_conflict_detector: + if self.socket.fileno() == -1: + raise trio.ClosedResourceError + with _translate_socket_errors_to_stream_errors(): + await self.socket.wait_writable() + + async def send_eof(self) -> None: + with self._send_conflict_detector: + await trio.lowlevel.checkpoint() + # On macOS, calling shutdown a second time raises ENOTCONN, but + # send_eof needs to be idempotent. 
+ if self.socket.did_shutdown_SHUT_WR: + return + with _translate_socket_errors_to_stream_errors(): + self.socket.shutdown(tsocket.SHUT_WR) + + async def receive_some(self, max_bytes: int | None = None) -> bytes: + if max_bytes is None: + max_bytes = DEFAULT_RECEIVE_SIZE + if max_bytes < 1: + raise ValueError("max_bytes must be >= 1") + with _translate_socket_errors_to_stream_errors(): + return await self.socket.recv(max_bytes) + + async def aclose(self) -> None: + self.socket.close() + await trio.lowlevel.checkpoint() + + # __aenter__, __aexit__ inherited from HalfCloseableStream are OK + + @overload + def setsockopt(self, level: int, option: int, value: int | Buffer) -> None: ... + + @overload + def setsockopt(self, level: int, option: int, value: None, length: int) -> None: ... + + # TODO: rename `length` to `optlen` + def setsockopt( + self, + level: int, + option: int, + value: int | Buffer | None, + length: int | None = None, + ) -> None: + """Set an option on the underlying socket. + + See :meth:`socket.socket.setsockopt` for details. + + """ + if length is None: + if value is None: + raise TypeError( + "invalid value for argument 'value', must not be None when specifying length", + ) + return self.socket.setsockopt(level, option, value) + if value is not None: + raise TypeError( + f"invalid value for argument 'value': {value!r}, must be None when specifying optlen", + ) + return self.socket.setsockopt(level, option, value, length) + + @overload + def getsockopt(self, level: int, option: int) -> int: ... + + @overload + def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: ... + + def getsockopt(self, level: int, option: int, buffersize: int = 0) -> int | bytes: + """Check the current value of an option on the underlying socket. + + See :meth:`socket.socket.getsockopt` for details. + + """ + # This is to work around + # https://bitbucket.org/pypy/pypy/issues/2561 + # We should be able to drop it when the next PyPy3 beta is released. 
+ if buffersize == 0: + return self.socket.getsockopt(level, option) + else: + return self.socket.getsockopt(level, option, buffersize) + + +################################################################ +# SocketListener +################################################################ + +# Accept error handling +# ===================== +# +# Literature review +# ----------------- +# +# Here's a list of all the possible errors that accept() can return, according +# to the POSIX spec or the Linux, FreeBSD, macOS, and Windows docs: +# +# Can't happen with a Trio socket: +# - EAGAIN/(WSA)EWOULDBLOCK +# - EINTR +# - WSANOTINITIALISED +# - WSAEINPROGRESS: a blocking call is already in progress +# - WSAEINTR: someone called WSACancelBlockingCall, but we don't make blocking +# calls in the first place +# +# Something is wrong with our call: +# - EBADF: not a file descriptor +# - (WSA)EINVAL: socket isn't listening, or (Linux, BSD) bad flags +# - (WSA)ENOTSOCK: not a socket +# - (WSA)EOPNOTSUPP: this kind of socket doesn't support accept +# - (Linux, FreeBSD, Windows) EFAULT: the sockaddr pointer points to readonly +# memory +# +# Something is wrong with the environment: +# - (WSA)EMFILE: this process hit its fd limit +# - ENFILE: the system hit its fd limit +# - (WSA)ENOBUFS, ENOMEM: unspecified memory problems +# +# Something is wrong with the connection we were going to accept. There's a +# ton of variability between systems here: +# - ECONNABORTED: documented everywhere, but apparently only the BSDs do this +# (signals a connection was closed/reset before being accepted) +# - EPROTO: unspecified protocol error +# - (Linux) EPERM: firewall rule prevented connection +# - (Linux) ENETDOWN, EPROTO, ENOPROTOOPT, EHOSTDOWN, ENONET, EHOSTUNREACH, +# EOPNOTSUPP, ENETUNREACH, ENOSR, ESOCKTNOSUPPORT, EPROTONOSUPPORT, +# ETIMEDOUT, ... 
or any other error that the socket could give, because +# apparently if an error happens on a connection before it's accept()ed, +# Linux will report that error from accept(). +# - (Windows) WSAECONNRESET, WSAENETDOWN +# +# +# Code review +# ----------- +# +# What do other libraries do? +# +# Twisted on Unix or when using nonblocking I/O on Windows: +# - ignores EPERM, with comment about Linux firewalls +# - logs and ignores EMFILE, ENOBUFS, ENFILE, ENOMEM, ECONNABORTED +# Comment notes that ECONNABORTED is a BSDism and that Linux returns the +# socket before having it fail, and macOS just silently discards it. +# - other errors are raised, which is logged + kills the socket +# ref: src/twisted/internet/tcp.py, Port.doRead +# +# Twisted using IOCP on Windows: +# - logs and ignores all errors +# ref: src/twisted/internet/iocpreactor/tcp.py, Port.handleAccept +# +# Tornado: +# - ignores ECONNABORTED (comment notes that it was observed on FreeBSD) +# - everything else raised, but all this does (by default) is cause it to be +# logged and then ignored +# (ref: tornado/netutil.py, tornado/ioloop.py) +# +# libuv on Unix: +# - ignores ECONNABORTED +# - does a "trick" for EMFILE or ENFILE +# - all other errors passed to the connection_cb to be handled +# (ref: src/unix/stream.c:uv__server_io, uv__emfile_trick) +# +# libuv on Windows: +# src/win/tcp.c:uv_tcp_queue_accept +# this calls AcceptEx, and then arranges to call: +# src/win/tcp.c:uv_process_tcp_accept_req +# this gets the result from AcceptEx. If the original AcceptEx call failed, +# then "we stop accepting connections and report this error to the +# connection callback". I think this is for things like ENOTSOCK. If +# AcceptEx successfully queues an overlapped operation, and then that +# reports an error, it's just discarded. 
+# +# asyncio, selector mode: +# - ignores EWOULDBLOCK, EINTR, ECONNABORTED +# - on EMFILE, ENFILE, ENOBUFS, ENOMEM, logs an error and then disables the +# listening loop for 1 second +# - everything else raises, but then the event loop just logs and ignores it +# (selector_events.py: BaseSelectorEventLoop._accept_connection) +# +# +# What should we do? +# ------------------ +# +# When accept() returns an error, we can either ignore it or raise it. +# +# We have a long list of errors that should be ignored, and a long list of +# errors that should be raised. The big question is what to do with an error +# that isn't on either list. On Linux apparently you can get nearly arbitrary +# errors from accept() and they should be ignored, because it just indicates a +# socket that crashed before it began, and there isn't really anything to be +# done about this, plus on other platforms you may not get any indication at +# all, so programs have to tolerate not getting any indication too. OTOH if we +# get an unexpected error then it could indicate something arbitrarily bad -- +# after all, it's unexpected. +# +# Given that we know that other libraries seem to be getting along fine with a +# fairly minimal list of errors to ignore, I think we'll be OK if we write +# down that list and then raise on everything else. +# +# The other question is what to do about the capacity problem errors: EMFILE, +# ENFILE, ENOBUFS, ENOMEM. Just flat out ignoring these is clearly not optimal +# -- at the very least you want to log them, and probably you want to take +# some remedial action. And if we ignore them then it prevents higher levels +# from doing anything clever with them. So we raise them. + +_ignorable_accept_errno_names = [ + # Linux can do this when a connection is denied by the firewall + "EPERM", + # BSDs with an early close/reset + "ECONNABORTED", + # All the other miscellany noted above -- may not happen in practice, but + # whatever.
+ "EPROTO", + "ENETDOWN", + "ENOPROTOOPT", + "EHOSTDOWN", + "ENONET", + "EHOSTUNREACH", + "EOPNOTSUPP", + "ENETUNREACH", + "ENOSR", + "ESOCKTNOSUPPORT", + "EPROTONOSUPPORT", + "ETIMEDOUT", + "ECONNRESET", +] + +# Not all errnos are defined on all platforms +_ignorable_accept_errnos: set[int] = set() +for name in _ignorable_accept_errno_names: + with suppress(AttributeError): + _ignorable_accept_errnos.add(getattr(errno, name)) + + +@final +class SocketListener(Listener[SocketStream]): + """A :class:`~trio.abc.Listener` that uses a listening socket to accept + incoming connections as :class:`SocketStream` objects. + + Args: + socket: The Trio socket object to wrap. Must have type ``SOCK_STREAM``, + and be listening. + + Note that the :class:`SocketListener` "takes ownership" of the given + socket; closing the :class:`SocketListener` will also close the socket. + + .. attribute:: socket + + The Trio socket object that this stream wraps. + + """ + + def __init__(self, socket: SocketType) -> None: + if not isinstance(socket, tsocket.SocketType): + raise TypeError("SocketListener requires a Trio socket object") + if socket.type != tsocket.SOCK_STREAM: + raise ValueError("SocketListener requires a SOCK_STREAM socket") + try: + listening = socket.getsockopt(tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN) + except OSError: + # SO_ACCEPTCONN fails on macOS; we just have to trust the user. + pass + else: + if not listening: + raise ValueError("SocketListener requires a listening socket") + + self.socket = socket + + async def accept(self) -> SocketStream: + """Accept an incoming connection. + + Returns: + :class:`SocketStream` + + Raises: + OSError: if the underlying call to ``accept`` raises an unexpected + error. + ClosedResourceError: if you already closed the socket. + + This method handles routine errors like ``ECONNABORTED``, but passes + other errors on to its caller. 
In particular, it does *not* make any + special effort to handle resource exhaustion errors like ``EMFILE``, + ``ENFILE``, ``ENOBUFS``, ``ENOMEM``. + + """ + while True: + try: + sock, _ = await self.socket.accept() + except OSError as exc: + if exc.errno in _closed_stream_errnos: + raise trio.ClosedResourceError from None + if exc.errno not in _ignorable_accept_errnos: + raise + else: + return SocketStream(sock) + + async def aclose(self) -> None: + """Close this listener and its underlying socket.""" + self.socket.close() + await trio.lowlevel.checkpoint() diff --git a/contrib/python/trio/trio/_highlevel_ssl_helpers.py b/contrib/python/trio/trio/_highlevel_ssl_helpers.py new file mode 100644 index 00000000000..1239491a43b --- /dev/null +++ b/contrib/python/trio/trio/_highlevel_ssl_helpers.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import ssl +from typing import TYPE_CHECKING, NoReturn, TypeVar + +import trio + +from ._highlevel_open_tcp_stream import DEFAULT_DELAY + +T = TypeVar("T") + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from ._highlevel_socket import SocketStream + + +# It might have been nice to take a ssl_protocols= argument here to set up +# NPN/ALPN, but to do this we have to mutate the context object, which is OK +# if it's one we created, but not OK if it's one that was passed in... and +# the one major protocol using NPN/ALPN is HTTP/2, which mandates that you use +# a specially configured SSLContext anyway! I also thought maybe we could copy +# the given SSLContext and then mutate the copy, but it's no good as SSLContext +# objects can't be copied: https://bugs.python.org/issue33023. +# So... let's punt on that for now. Hopefully we'll be getting a new Python +# TLS API soon and can revisit this then. 
+async def open_ssl_over_tcp_stream( + host: str | bytes, + port: int, + *, + https_compatible: bool = False, + ssl_context: ssl.SSLContext | None = None, + happy_eyeballs_delay: float | None = DEFAULT_DELAY, +) -> trio.SSLStream[SocketStream]: + """Make a TLS-encrypted Connection to the given host and port over TCP. + + This is a convenience wrapper that calls :func:`open_tcp_stream` and + wraps the result in an :class:`~trio.SSLStream`. + + This function does not perform the TLS handshake; you can do it + manually by calling :meth:`~trio.SSLStream.do_handshake`, or else + it will be performed automatically the first time you send or receive + data. + + Args: + host (bytes or str): The host to connect to. We require the server + to have a TLS certificate valid for this hostname. + port (int): The port to connect to. + https_compatible (bool): Set this to True if you're connecting to a web + server. See :class:`~trio.SSLStream` for details. Default: + False. + ssl_context (:class:`~ssl.SSLContext` or None): The SSL context to + use. If None (the default), :func:`ssl.create_default_context` + will be called to create a context. + happy_eyeballs_delay (float): See :func:`open_tcp_stream`. + + Returns: + trio.SSLStream: the encrypted connection to the server. 
+ + """ + tcp_stream = await trio.open_tcp_stream( + host, + port, + happy_eyeballs_delay=happy_eyeballs_delay, + ) + if ssl_context is None: + ssl_context = ssl.create_default_context() + + if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): + ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF + + return trio.SSLStream( + tcp_stream, + ssl_context, + server_hostname=host, + https_compatible=https_compatible, + ) + + +async def open_ssl_over_tcp_listeners( + port: int, + ssl_context: ssl.SSLContext, + *, + host: str | bytes | None = None, + https_compatible: bool = False, + backlog: int | None = None, +) -> list[trio.SSLListener[SocketStream]]: + """Start listening for SSL/TLS-encrypted TCP connections to the given port. + + Args: + port (int): The port to listen on. See :func:`open_tcp_listeners`. + ssl_context (~ssl.SSLContext): The SSL context to use for all incoming + connections. + host (str, bytes, or None): The address to bind to; use ``None`` to bind + to the wildcard address. See :func:`open_tcp_listeners`. + https_compatible (bool): See :class:`~trio.SSLStream` for details. + backlog (int or None): See :func:`open_tcp_listeners` for details. + + """ + tcp_listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog) + ssl_listeners = [ + trio.SSLListener(tcp_listener, ssl_context, https_compatible=https_compatible) + for tcp_listener in tcp_listeners + ] + return ssl_listeners + + +async def serve_ssl_over_tcp( + handler: Callable[[trio.SSLStream[SocketStream]], Awaitable[object]], + port: int, + ssl_context: ssl.SSLContext, + *, + host: str | bytes | None = None, + https_compatible: bool = False, + backlog: int | None = None, + handler_nursery: trio.Nursery | None = None, + task_status: trio.TaskStatus[ + list[trio.SSLListener[SocketStream]] + ] = trio.TASK_STATUS_IGNORED, +) -> NoReturn: + """Listen for incoming TCP connections, and for each one start a task + running ``handler(stream)``. 
+ + This is a thin convenience wrapper around + :func:`open_ssl_over_tcp_listeners` and :func:`serve_listeners` – see them + for full details. + + .. warning:: + + If ``handler`` raises an exception, then this function doesn't do + anything special to catch it – so by default the exception will + propagate out and crash your server. If you don't want this, then catch + exceptions inside your ``handler``, or use a ``handler_nursery`` object + that responds to exceptions in some other way. + + When used with ``nursery.start`` you get back the newly opened listeners. + See the documentation for :func:`serve_tcp` for an example where this is + useful. + + Args: + handler: The handler to start for each incoming connection. Passed to + :func:`serve_listeners`. + + port (int): The port to listen on. Use 0 to let the kernel pick + an open port. Ultimately passed to :func:`open_tcp_listeners`. + + ssl_context (~ssl.SSLContext): The SSL context to use for all incoming + connections. Passed to :func:`open_ssl_over_tcp_listeners`. + + host (str, bytes, or None): The address to bind to; use ``None`` to bind + to the wildcard address. Ultimately passed to + :func:`open_tcp_listeners`. + + https_compatible (bool): Set this to True if you want to use + "HTTPS-style" TLS. See :class:`~trio.SSLStream` for details. + + backlog (int or None): See :func:`open_tcp_listeners` for details. + + handler_nursery: The nursery to start handlers in, or None to use an + internal nursery. Passed to :func:`serve_listeners`. + + task_status: This function can be used with ``nursery.start``. + + Returns: + This function only returns when cancelled.
+ + """ + listeners = await trio.open_ssl_over_tcp_listeners( + port, + ssl_context, + host=host, + https_compatible=https_compatible, + backlog=backlog, + ) + await trio.serve_listeners( + handler, + listeners, + handler_nursery=handler_nursery, + task_status=task_status, + ) diff --git a/contrib/python/trio/trio/_path.py b/contrib/python/trio/trio/_path.py new file mode 100644 index 00000000000..af6fbe00597 --- /dev/null +++ b/contrib/python/trio/trio/_path.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import os +import pathlib +import sys +from functools import partial, update_wrapper +from inspect import cleandoc +from typing import IO, TYPE_CHECKING, Any, BinaryIO, ClassVar, TypeVar, overload + +from trio._file_io import AsyncIOWrapper, wrap_file +from trio._util import final +from trio.to_thread import run_sync + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Iterable + from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWrapper + + from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeReading, + OpenBinaryModeUpdating, + OpenBinaryModeWriting, + OpenTextMode, + ) + from typing_extensions import Concatenate, Literal, ParamSpec, Self + + P = ParamSpec("P") + + PathT = TypeVar("PathT", bound="Path") + T = TypeVar("T") + + +def _wraps_async( # type: ignore[explicit-any] + wrapped: Callable[..., object], +) -> Callable[[Callable[P, T]], Callable[P, Awaitable[T]]]: + def decorator(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + return await run_sync(partial(fn, *args, **kwargs)) + + update_wrapper(wrapper, wrapped) + if wrapped.__doc__: + module = wrapped.__module__ + # these are exported specially from CPython's intersphinx inventory + module = module.replace("pathlib._local", "pathlib") + module = module.replace("pathlib._abc", "pathlib") + + name = wrapped.__qualname__ + name = name.replace( + "PathBase", "Path" + ) # I'm not sure 
why this is necessary + + wrapper.__doc__ = ( + f"Like :meth:`~{module}.{name}`, but async.\n" + f"\n" + f"{cleandoc(wrapped.__doc__)}\n" + ) + return wrapper + + return decorator + + +def _wrap_method( + fn: Callable[Concatenate[pathlib.Path, P], T], +) -> Callable[Concatenate[Path, P], Awaitable[T]]: + @_wraps_async(fn) + def wrapper(self: Path, /, *args: P.args, **kwargs: P.kwargs) -> T: + return fn(self._wrapped_cls(self), *args, **kwargs) + + return wrapper + + +def _wrap_method_path( + fn: Callable[Concatenate[pathlib.Path, P], pathlib.Path], +) -> Callable[Concatenate[PathT, P], Awaitable[PathT]]: + @_wraps_async(fn) + def wrapper(self: PathT, /, *args: P.args, **kwargs: P.kwargs) -> PathT: + return self.__class__(fn(self._wrapped_cls(self), *args, **kwargs)) + + return wrapper + + +def _wrap_method_path_iterable( + fn: Callable[Concatenate[pathlib.Path, P], Iterable[pathlib.Path]], +) -> Callable[Concatenate[PathT, P], Awaitable[Iterable[PathT]]]: + @_wraps_async(fn) + def wrapper(self: PathT, /, *args: P.args, **kwargs: P.kwargs) -> Iterable[PathT]: + return map(self.__class__, [*fn(self._wrapped_cls(self), *args, **kwargs)]) + + if wrapper.__doc__: + wrapper.__doc__ += ( + f"\n" + f"This is an async method that returns a synchronous iterator, so you\n" + f"use it like:\n" + f"\n" + f".. code:: python\n" + f"\n" + f" for subpath in await mypath.{fn.__name__}():\n" + f" ...\n" + f"\n" + f".. note::\n" + f"\n" + f" The iterator is loaded into memory immediately during the initial\n" + f" call (see `issue #501\n" + f" <https://github.com/python-trio/trio/issues/501>`__ for discussion).\n" + ) + return wrapper + + +class Path(pathlib.PurePath): + """An async :class:`pathlib.Path` that executes blocking methods in :meth:`trio.to_thread.run_sync`. + + Instantiating :class:`Path` returns a concrete platform-specific subclass, one of :class:`PosixPath` or + :class:`WindowsPath`. 
+ """ + + __slots__ = () + + _wrapped_cls: ClassVar[type[pathlib.Path]] + + def __new__(cls, *args: str | os.PathLike[str]) -> Self: + if cls is Path: + cls = WindowsPath if os.name == "nt" else PosixPath # type: ignore[assignment] + return super().__new__(cls, *args) + + @classmethod + @_wraps_async(pathlib.Path.cwd) + def cwd(cls) -> Self: + return cls(pathlib.Path.cwd()) + + @classmethod + @_wraps_async(pathlib.Path.home) + def home(cls) -> Self: + return cls(pathlib.Path.home()) + + @overload + async def open( + self, + mode: OpenTextMode = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> AsyncIOWrapper[TextIOWrapper]: ... + + @overload + async def open( + self, + mode: OpenBinaryMode, + buffering: Literal[0], + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[FileIO]: ... + + @overload + async def open( + self, + mode: OpenBinaryModeUpdating, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[BufferedRandom]: ... + + @overload + async def open( + self, + mode: OpenBinaryModeWriting, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[BufferedWriter]: ... + + @overload + async def open( + self, + mode: OpenBinaryModeReading, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[BufferedReader]: ... + + @overload + async def open( + self, + mode: OpenBinaryMode, + buffering: int = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[BinaryIO]: ... + + @overload + async def open( # type: ignore[explicit-any] # Any usage matches builtins.open(). 
+ self, + mode: str, + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> AsyncIOWrapper[IO[Any]]: ... + + @_wraps_async(pathlib.Path.open) + def open(self, *args: Any, **kwargs: Any) -> AsyncIOWrapper[IO[Any]]: # type: ignore[misc, explicit-any] # Overload return mismatch. + return wrap_file(self._wrapped_cls(self).open(*args, **kwargs)) + + def __repr__(self) -> str: + return f"trio.Path({str(self)!r})" + + stat = _wrap_method(pathlib.Path.stat) + chmod = _wrap_method(pathlib.Path.chmod) + exists = _wrap_method(pathlib.Path.exists) + glob = _wrap_method_path_iterable(pathlib.Path.glob) + rglob = _wrap_method_path_iterable(pathlib.Path.rglob) + is_dir = _wrap_method(pathlib.Path.is_dir) + is_file = _wrap_method(pathlib.Path.is_file) + is_symlink = _wrap_method(pathlib.Path.is_symlink) + is_socket = _wrap_method(pathlib.Path.is_socket) + is_fifo = _wrap_method(pathlib.Path.is_fifo) + is_block_device = _wrap_method(pathlib.Path.is_block_device) + is_char_device = _wrap_method(pathlib.Path.is_char_device) + if sys.version_info >= (3, 12): + is_junction = _wrap_method(pathlib.Path.is_junction) + iterdir = _wrap_method_path_iterable(pathlib.Path.iterdir) + lchmod = _wrap_method(pathlib.Path.lchmod) + lstat = _wrap_method(pathlib.Path.lstat) + mkdir = _wrap_method(pathlib.Path.mkdir) + if sys.platform != "win32": + owner = _wrap_method(pathlib.Path.owner) + group = _wrap_method(pathlib.Path.group) + if sys.platform != "win32" or sys.version_info >= (3, 12): + is_mount = _wrap_method(pathlib.Path.is_mount) + readlink = _wrap_method_path(pathlib.Path.readlink) + rename = _wrap_method_path(pathlib.Path.rename) + replace = _wrap_method_path(pathlib.Path.replace) + resolve = _wrap_method_path(pathlib.Path.resolve) + rmdir = _wrap_method(pathlib.Path.rmdir) + symlink_to = _wrap_method(pathlib.Path.symlink_to) + if sys.version_info >= (3, 10): + hardlink_to = _wrap_method(pathlib.Path.hardlink_to) + touch = 
_wrap_method(pathlib.Path.touch) + unlink = _wrap_method(pathlib.Path.unlink) + absolute = _wrap_method_path(pathlib.Path.absolute) + expanduser = _wrap_method_path(pathlib.Path.expanduser) + read_bytes = _wrap_method(pathlib.Path.read_bytes) + read_text = _wrap_method(pathlib.Path.read_text) + samefile = _wrap_method(pathlib.Path.samefile) + write_bytes = _wrap_method(pathlib.Path.write_bytes) + write_text = _wrap_method(pathlib.Path.write_text) + if sys.version_info < (3, 12): + link_to = _wrap_method(pathlib.Path.link_to) + if sys.version_info >= (3, 13): + full_match = _wrap_method(pathlib.Path.full_match) + + def as_uri(self) -> str: + return pathlib.Path.as_uri(self) + + +if Path.relative_to.__doc__: # pragma: no branch + Path.relative_to.__doc__ = Path.relative_to.__doc__.replace(" `..` ", " ``..`` ") + + +@final +class PosixPath(Path, pathlib.PurePosixPath): + """An async :class:`pathlib.PosixPath` that executes blocking methods in :meth:`trio.to_thread.run_sync`.""" + + __slots__ = () + + _wrapped_cls: ClassVar[type[pathlib.Path]] = pathlib.PosixPath + + +@final +class WindowsPath(Path, pathlib.PureWindowsPath): + """An async :class:`pathlib.WindowsPath` that executes blocking methods in :meth:`trio.to_thread.run_sync`.""" + + __slots__ = () + + _wrapped_cls: ClassVar[type[pathlib.Path]] = pathlib.WindowsPath diff --git a/contrib/python/trio/trio/_repl.py b/contrib/python/trio/trio/_repl.py new file mode 100644 index 00000000000..5a96e687898 --- /dev/null +++ b/contrib/python/trio/trio/_repl.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import ast +import contextlib +import inspect +import sys +import warnings +from code import InteractiveConsole +from types import CodeType, FrameType, FunctionType +from typing import Callable + +import outcome + +import trio +import trio.lowlevel +from trio._util import final + + +class SuppressDecorator(contextlib.ContextDecorator, contextlib.suppress): + pass + + +@SuppressDecorator(KeyboardInterrupt) 
+@trio.lowlevel.disable_ki_protection +def terminal_newline() -> None: # TODO: test this line + import fcntl + import termios + + # Fake up a newline char as if user had typed it at the terminal + try: + fcntl.ioctl(sys.stdin, termios.TIOCSTI, b"\n") # type: ignore[attr-defined, unused-ignore] + except OSError as e: + print(f"\nPress enter! Newline injection failed: {e}", end="", flush=True) + + +@final +class TrioInteractiveConsole(InteractiveConsole): + def __init__(self, repl_locals: dict[str, object] | None = None) -> None: + super().__init__(locals=repl_locals) + self.token: trio.lowlevel.TrioToken | None = None + self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT + self.interrupted = False + + def runcode(self, code: CodeType) -> None: + func = FunctionType(code, self.locals) + if inspect.iscoroutinefunction(func): + result = trio.from_thread.run(outcome.acapture, func) + else: + result = trio.from_thread.run_sync(outcome.capture, func) + if isinstance(result, outcome.Error): + # If it is SystemExit, quit the repl. Otherwise, print the traceback. + # If there is a SystemExit inside a BaseExceptionGroup, it probably isn't + # the user trying to quit the repl, but rather an error in the code. So, we + # don't try to inspect groups for SystemExit. Instead, we just print and + # return to the REPL. + if isinstance(result.error, SystemExit): + raise result.error + else: + # Inline our own version of self.showtraceback that can use + # outcome.Error.error directly to print clean tracebacks. + # This also means overriding self.showtraceback does nothing. + sys.last_type, sys.last_value = type(result.error), result.error + sys.last_traceback = result.error.__traceback__ + # see https://docs.python.org/3/library/sys.html#sys.last_exc + if sys.version_info >= (3, 12): + sys.last_exc = result.error + + # We always use sys.excepthook, unlike other implementations. + # This means that overriding self.write also does nothing to tbs.
+ sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback) + # clear any residual KI + trio.from_thread.run(trio.lowlevel.checkpoint_if_cancelled) + # trio.from_thread.check_cancelled() has too long of a memory + + if sys.platform == "win32": # TODO: test this line + + def raw_input(self, prompt: str = "") -> str: + try: + return input(prompt) + except EOFError: + # check if trio has a pending KI + trio.from_thread.run(trio.lowlevel.checkpoint_if_cancelled) + raise + + else: + + def raw_input(self, prompt: str = "") -> str: + from signal import SIGINT, signal + + assert not self.interrupted + + def install_handler() -> ( + Callable[[int, FrameType | None], None] | int | None + ): + def handler( + sig: int, frame: FrameType | None + ) -> None: # TODO: test this line + self.interrupted = True + token.run_sync_soon(terminal_newline, idempotent=True) + + token = trio.lowlevel.current_trio_token() + + return signal(SIGINT, handler) + + prev_handler = trio.from_thread.run_sync(install_handler) + try: + return input(prompt) + finally: + trio.from_thread.run_sync(signal, SIGINT, prev_handler) + if self.interrupted: # TODO: test this line + raise KeyboardInterrupt + + def write(self, output: str) -> None: + if self.interrupted: # TODO: test this line + assert output == "\nKeyboardInterrupt\n" + sys.stderr.write(output[1:]) + self.interrupted = False + else: + sys.stderr.write(output) + + +async def run_repl(console: TrioInteractiveConsole) -> None: + banner = ( + f"trio REPL {sys.version} on {sys.platform}\n" + f'Use "await" directly instead of "trio.run()".\n' + f'Type "help", "copyright", "credits" or "license" ' + f"for more information.\n" + f'{getattr(sys, "ps1", ">>> ")}import trio' + ) + try: + await trio.to_thread.run_sync(console.interact, banner) + finally: + warnings.filterwarnings( + "ignore", + message=r"^coroutine .* was never awaited$", + category=RuntimeWarning, + ) + + +def main(original_locals: dict[str, object]) -> None: + with 
contextlib.suppress(ImportError): + import readline # noqa: F401 + + repl_locals: dict[str, object] = {"trio": trio} + for key in { + "__name__", + "__package__", + "__loader__", + "__spec__", + "__builtins__", + "__file__", + }: + repl_locals[key] = original_locals[key] + + console = TrioInteractiveConsole(repl_locals) + trio.run(run_repl, console) diff --git a/contrib/python/trio/trio/_signals.py b/contrib/python/trio/trio/_signals.py new file mode 100644 index 00000000000..729c48ad4e1 --- /dev/null +++ b/contrib/python/trio/trio/_signals.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import signal +from collections import OrderedDict +from contextlib import contextmanager +from typing import TYPE_CHECKING + +import trio + +from ._util import ConflictDetector, is_main_thread + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Callable, Generator, Iterable + from types import FrameType + + from typing_extensions import Self + +# Discussion of signal handling strategies: +# +# - On Windows signals barely exist. There are no options; signal handlers are +# the only available API. +# +# - On Linux signalfd is arguably the natural way. Semantics: signalfd acts as +# an *alternative* signal delivery mechanism. The way you use it is to mask +# out the relevant signals process-wide (so that they don't get delivered +# the normal way), and then when you read from signalfd that actually counts +# as delivering it (despite the mask). The problem with this is that we +# don't have any reliable way to mask out signals process-wide -- the only +# way to do that in Python is to call pthread_sigmask from the main thread +# *before starting any other threads*, and as a library we can't really +# impose that, and the failure mode is annoying (signals get delivered via +# signal handlers whether we want them to or not). +# +# - on macOS/*BSD, kqueue is the natural way. Semantics: kqueue acts as an +# *extra* signal delivery mechanism. 
Signals are delivered the normal +# way, *and* are delivered to kqueue. So you want to set them to SIG_IGN so +# that they don't end up pending forever (I guess?). I can't find any actual +# docs on how masking and EVFILT_SIGNAL interact. I did see someone note +# that if a signal is pending when the kqueue filter is added then you +# *don't* get notified of that, which makes sense. But still, we have to +# manipulate signal state (e.g. setting SIG_IGN) which as far as Python is +# concerned means we have to do this from the main thread. +# +# So in summary, there don't seem to be any compelling advantages to using the +# platform-native signal notification systems; they're kinda nice, but it's +# simpler to implement the naive signal-handler-based system once and be +# done. (The big advantage would be if there were a reliable way to monitor +# for SIGCHLD from outside the main thread and without interfering with other +# libraries that also want to monitor for SIGCHLD. But there isn't. I guess +# kqueue might give us that, but in kqueue we don't need it, because kqueue +# can directly monitor for child process state changes.) 
+ + +@contextmanager +def _signal_handler( + signals: Iterable[int], + handler: Callable[[int, FrameType | None], object] | int | signal.Handlers | None, +) -> Generator[None, None, None]: + original_handlers = {} + try: + for signum in set(signals): + original_handlers[signum] = signal.signal(signum, handler) + yield + finally: + for signum, original_handler in original_handlers.items(): + signal.signal(signum, original_handler) + + +class SignalReceiver: + def __init__(self) -> None: + # {signal num: None} + self._pending: OrderedDict[int, None] = OrderedDict() + self._lot = trio.lowlevel.ParkingLot() + self._conflict_detector = ConflictDetector( + "only one task can iterate on a signal receiver at a time", + ) + self._closed = False + + def _add(self, signum: int) -> None: + if self._closed: + signal.raise_signal(signum) + else: + self._pending[signum] = None + self._lot.unpark() + + def _redeliver_remaining(self) -> None: + # First make sure that any signals still in the delivery pipeline will + # get redelivered + self._closed = True + + # And then redeliver any that are sitting in pending. This is done + # using a weird recursive construct to make sure we process everything + # even if some of the handlers raise exceptions. + def deliver_next() -> None: + if self._pending: + signum, _ = self._pending.popitem(last=False) + try: + signal.raise_signal(signum) + finally: + deliver_next() + + deliver_next() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> int: + if self._closed: + raise RuntimeError("open_signal_receiver block already exited") + # In principle it would be possible to support multiple concurrent + # calls to __anext__, but doing it without race conditions is quite + # tricky, and there doesn't seem to be any point in trying. 
+ with self._conflict_detector: + if not self._pending: + await self._lot.park() + else: + await trio.lowlevel.checkpoint() + signum, _ = self._pending.popitem(last=False) + return signum + + +def get_pending_signal_count(rec: AsyncIterator[int]) -> int: + """Helper for tests, not public or otherwise used.""" + # open_signal_receiver() always produces SignalReceiver, this should not fail. + assert isinstance(rec, SignalReceiver) + return len(rec._pending) + + +@contextmanager +def open_signal_receiver( + *signals: signal.Signals | int, +) -> Generator[AsyncIterator[int], None, None]: + """A context manager for catching signals. + + Entering this context manager starts listening for the given signals and + returns an async iterator; exiting the context manager stops listening. + + The async iterator blocks until a signal arrives, and then yields it. + + Note that if you leave the ``with`` block while the iterator has + unextracted signals still pending inside it, then they will be + re-delivered using Python's regular signal handling logic. This avoids a + race condition when a signal arrives just before we exit the ``with`` + block. + + Args: + signals: the signals to listen for. + + Raises: + TypeError: if no signals were provided. + + RuntimeError: if you try to use this anywhere except Python's main + thread. (This is a Python limitation.) + + Example: + + A common convention for Unix daemons is that they should reload their + configuration when they receive a ``SIGHUP``.
Here's a sketch of what + that might look like using :func:`open_signal_receiver`:: + + with trio.open_signal_receiver(signal.SIGHUP) as signal_aiter: + async for signum in signal_aiter: + assert signum == signal.SIGHUP + reload_configuration() + + """ + if not signals: + raise TypeError("No signals were provided") + + if not is_main_thread(): + raise RuntimeError( + "Sorry, open_signal_receiver is only possible when running in " + "Python interpreter's main thread", + ) + token = trio.lowlevel.current_trio_token() + queue = SignalReceiver() + + def handler(signum: int, frame: FrameType | None) -> None: + token.run_sync_soon(queue._add, signum, idempotent=True) + + try: + with _signal_handler(signals, handler): + yield queue + finally: + queue._redeliver_remaining() diff --git a/contrib/python/trio/trio/_socket.py b/contrib/python/trio/trio/_socket.py new file mode 100644 index 00000000000..8454970693d --- /dev/null +++ b/contrib/python/trio/trio/_socket.py @@ -0,0 +1,1325 @@ +from __future__ import annotations + +import os +import select +import socket as _stdlib_socket +import sys +from operator import index +from socket import AddressFamily, SocketKind +from typing import ( + TYPE_CHECKING, + Any, + SupportsIndex, + TypeVar, + Union, + overload, +) + +import idna as _idna + +import trio +from trio._util import wraps as _wraps + +from . import _core + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Iterable + from types import TracebackType + + from typing_extensions import Buffer, Concatenate, ParamSpec, Self, TypeAlias + + from ._abc import HostnameResolver, SocketFactory + + P = ParamSpec("P") + + +T = TypeVar("T") + +# _stdlib_socket.socket supports 13 different socket families, see +# https://docs.python.org/3/library/socket.html#socket-families +# and the return type of several methods in SocketType will depend on those. 
Typeshed +# has ended up typing those return types as `Any` in most cases, but for users that +# know which family/families they're working in we could make SocketType a generic type, +# where you specify the return values you expect from those methods depending on the +# protocol the socket will be handling. +# But without the ability to default the value to `Any` it will be overly cumbersome for +# most users, so currently we just specify it as `Any`. Otherwise we would write: +# `AddressFormat = TypeVar("AddressFormat")` +# but instead we simply do: +AddressFormat: TypeAlias = Any # type: ignore[explicit-any] + + +# Usage: +# +# async with _try_sync(): +# return sync_call_that_might_fail_with_exception() +# # we only get here if the sync call in fact did fail with a +# # BlockingIOError +# return await do_it_properly_with_a_check_point() +# +class _try_sync: + def __init__( + self, + blocking_exc_override: Callable[[BaseException], bool] | None = None, + ) -> None: + self._blocking_exc_override = blocking_exc_override + + def _is_blocking_io_error(self, exc: BaseException) -> bool: + if self._blocking_exc_override is None: + return isinstance(exc, BlockingIOError) + else: + return self._blocking_exc_override(exc) + + async def __aenter__(self) -> None: + await trio.lowlevel.checkpoint_if_cancelled() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool: + if exc_value is not None and self._is_blocking_io_error(exc_value): + # Discard the exception and fall through to the code below the + # block + return True + else: + await trio.lowlevel.cancel_shielded_checkpoint() + # Let the return or exception propagate + return False + + +################################################################ +# Overrides +################################################################ + +_resolver: _core.RunVar[HostnameResolver | None] = _core.RunVar("hostname_resolver") 
+_socket_factory: _core.RunVar[SocketFactory | None] = _core.RunVar("socket_factory") + + +def set_custom_hostname_resolver( + hostname_resolver: HostnameResolver | None, +) -> HostnameResolver | None: + """Set a custom hostname resolver. + + By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions + use the standard system resolver functions. This function allows you to + customize that behavior. The main intended use case is for testing, but it + might also be useful for using third-party resolvers like `c-ares + <https://c-ares.haxx.se/>`__ (though be warned that these rarely make + perfect drop-in replacements for the system resolver). See + :class:`trio.abc.HostnameResolver` for more details. + + Setting a custom hostname resolver affects all future calls to + :func:`getaddrinfo` and :func:`getnameinfo` within the enclosing call to + :func:`trio.run`. All other hostname resolution in Trio is implemented in + terms of these functions. + + Generally you should call this function just once, right at the beginning + of your program. + + Args: + hostname_resolver (trio.abc.HostnameResolver or None): The new custom + hostname resolver, or None to restore the default behavior. + + Returns: + The previous hostname resolver (which may be None). + + """ + old = _resolver.get(None) + _resolver.set(hostname_resolver) + return old + + +def set_custom_socket_factory( + socket_factory: SocketFactory | None, +) -> SocketFactory | None: + """Set a custom socket object factory. + + This function allows you to replace Trio's normal socket class with a + custom class. This is very useful for testing, and probably a bad idea in + any other circumstance. See :class:`trio.abc.SocketFactory` for more + details. + + Setting a custom socket factory affects all future calls to :func:`socket` + within the enclosing call to :func:`trio.run`. + + Generally you should call this function just once, right at the beginning + of your program. 
+ + Args: + socket_factory (trio.abc.SocketFactory or None): The new custom + socket factory, or None to restore the default behavior. + + Returns: + The previous socket factory (which may be None). + + """ + old = _socket_factory.get(None) + _socket_factory.set(socket_factory) + return old + + +################################################################ +# getaddrinfo and friends +################################################################ + +# AI_NUMERICSERV may be missing on some older platforms, so use it when available. +# See: https://github.com/python-trio/trio/issues/3133 +_NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST +_NUMERIC_ONLY |= getattr(_stdlib_socket, "AI_NUMERICSERV", 0) + + +# It would be possible to @overload the return value depending on Literal[AddressFamily.INET/6], but should probably be added in typeshed first +async def getaddrinfo( + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes], + ] +]: + """Look up a numeric address given a name. + + Arguments and return values are identical to :func:`socket.getaddrinfo`, + except that this version is async. + + Also, :func:`trio.socket.getaddrinfo` correctly uses IDNA 2008 to process + non-ASCII domain names. (:func:`socket.getaddrinfo` uses IDNA 2003, which + can give the wrong result in some cases and cause you to connect to a + different host than the one you intended; see `bpo-17305 + <https://bugs.python.org/issue17305>`__.) + + This function's behavior can be customized using + :func:`set_custom_hostname_resolver`. + + """ + + # If host and port are numeric, then getaddrinfo doesn't block and we can + # skip the whole thread thing, which seems worthwhile. 
So we try first + # with the _NUMERIC_ONLY flags set, and then only spawn a thread if that + # fails with EAI_NONAME: + def numeric_only_failure(exc: BaseException) -> bool: + return ( + isinstance(exc, _stdlib_socket.gaierror) + and exc.errno == _stdlib_socket.EAI_NONAME + ) + + async with _try_sync(numeric_only_failure): + return _stdlib_socket.getaddrinfo( + host, + port, + family, + type, + proto, + flags | _NUMERIC_ONLY, + ) + # That failed; it's a real hostname. We'd better use a thread. + # + # Also, it might be a unicode hostname, in which case we want to do our + # own encoding using the idna module, rather than letting Python do + # it. (Python will use the old IDNA 2003 standard, and possibly get the + # wrong answer - see bpo-17305). However, the idna module is picky, and + # will refuse to process some valid hostname strings, like "::1". So if + # it's already ascii, we pass it through; otherwise, we IDNA-encode it. + if isinstance(host, str): + try: + host = host.encode("ascii") + except UnicodeEncodeError: + # UTS-46 defines various normalizations; in particular, by default + # idna.encode will error out if the hostname has Capital Letters + # in it; with uts46=True it will lowercase them instead. + host = _idna.encode(host, uts46=True) + hr = _resolver.get(None) + if hr is not None: + return await hr.getaddrinfo(host, port, family, type, proto, flags) + else: + return await trio.to_thread.run_sync( + _stdlib_socket.getaddrinfo, + host, + port, + family, + type, + proto, + flags, + abandon_on_cancel=True, + ) + + +async def getnameinfo( + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, +) -> tuple[str, str]: + """Look up a name given a numeric address. + + Arguments and return values are identical to :func:`socket.getnameinfo`, + except that this version is async. + + This function's behavior can be customized using + :func:`set_custom_hostname_resolver`. 
+ + """ + hr = _resolver.get(None) + if hr is not None: + return await hr.getnameinfo(sockaddr, flags) + else: + return await trio.to_thread.run_sync( + _stdlib_socket.getnameinfo, + sockaddr, + flags, + abandon_on_cancel=True, + ) + + +async def getprotobyname(name: str) -> int: + """Look up a protocol number by name. (Rarely used.) + + Like :func:`socket.getprotobyname`, but async. + + """ + return await trio.to_thread.run_sync( + _stdlib_socket.getprotobyname, + name, + abandon_on_cancel=True, + ) + + +# obsolete gethostbyname etc. intentionally omitted +# likewise for create_connection (use open_tcp_stream instead) + +################################################################ +# Socket "constructors" +################################################################ + + +def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType: + """Convert a standard library :class:`socket.socket` object into a Trio + socket object. + + """ + return _SocketType(sock) + + +@_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) +def fromfd( + fd: SupportsIndex, + family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, +) -> SocketType: + """Like :func:`socket.fromfd`, but returns a Trio socket object.""" + family, type_, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) + return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type_, proto)) + + +if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket, "fromshare") +): + + @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) + def fromshare(info: bytes) -> SocketType: + """Like :func:`socket.fromshare`, but returns a Trio socket object.""" + return from_stdlib_socket(_stdlib_socket.fromshare(info)) + + +if sys.platform == "win32": + FamilyT: TypeAlias = int + TypeT: TypeAlias = int + FamilyDefault = _stdlib_socket.AF_INET +else: + FamilyDefault: None = None + FamilyT: TypeAlias = Union[int, 
AddressFamily, None] + TypeT: TypeAlias = Union[_stdlib_socket.socket, int] + + +@_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) +def socketpair( + family: FamilyT = FamilyDefault, + type: TypeT = SocketKind.SOCK_STREAM, + proto: int = 0, +) -> tuple[SocketType, SocketType]: + """Like :func:`socket.socketpair`, but returns a pair of Trio socket + objects. + + """ + left, right = _stdlib_socket.socketpair(family, type, proto) + return (from_stdlib_socket(left), from_stdlib_socket(right)) + + +@_wraps(_stdlib_socket.socket, assigned=(), updated=()) +def socket( + family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, + fileno: int | None = None, +) -> SocketType: + """Create a new Trio socket, like :class:`socket.socket`. + + This function's behavior can be customized using + :func:`set_custom_socket_factory`. + + """ + if fileno is None: + sf = _socket_factory.get(None) + if sf is not None: + return sf.socket(family, type, proto) + else: + family, type, proto = _sniff_sockopts_for_fileno( # noqa: A001 + family, + type, + proto, + fileno, + ) + stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno) + return from_stdlib_socket(stdlib_socket) + + +def _sniff_sockopts_for_fileno( + family: AddressFamily | int, + type_: SocketKind | int, + proto: int, + fileno: int | None, +) -> tuple[AddressFamily | int, SocketKind | int, int]: + """Correct SOCKOPTS for given fileno, falling back to provided values.""" + # Wrap the raw fileno into a Python socket object + # This object might have the wrong metadata, but it lets us easily call getsockopt + # and then we'll throw it away and construct a new one with the correct metadata. 
+ if sys.platform != "linux": + return family, type_, proto + from socket import ( # type: ignore[attr-defined,unused-ignore] + SO_DOMAIN, + SO_PROTOCOL, + SO_TYPE, + SOL_SOCKET, + ) + + sockobj = _stdlib_socket.socket(family, type_, proto, fileno=fileno) + try: + family = sockobj.getsockopt(SOL_SOCKET, SO_DOMAIN) + proto = sockobj.getsockopt(SOL_SOCKET, SO_PROTOCOL) + type_ = sockobj.getsockopt(SOL_SOCKET, SO_TYPE) + finally: + # Unwrap it again, so that sockobj.__del__ doesn't try to close our socket + sockobj.detach() + return family, type_, proto + + +################################################################ +# SocketType +################################################################ + +# sock.type gets weird stuff set in it, in particular on Linux: +# +# https://bugs.python.org/issue21327 +# +# But on other platforms (e.g. Windows) SOCK_NONBLOCK and SOCK_CLOEXEC aren't +# even defined. To recover the actual socket type (e.g. SOCK_STREAM) from a +# socket.type attribute, mask with this: +_SOCK_TYPE_MASK = ~( + getattr(_stdlib_socket, "SOCK_NONBLOCK", 0) + | getattr(_stdlib_socket, "SOCK_CLOEXEC", 0) +) + + +def _make_simple_sock_method_wrapper( + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], + maybe_avail: bool = False, +) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]: + @_wraps(fn, assigned=("__name__",), updated=()) + async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T: + return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs) + + wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async. + + """ + if maybe_avail: + wrapper.__doc__ += ( + f"Only available on platforms where :meth:`socket.socket.{fn.__name__}` is " + "available." + ) + return wrapper + + +# Helpers to work with the (hostname, port) language that Python uses for socket +# addresses everywhere. 
Split out into a standalone function so it can be reused by +# FakeNet. + + +# Takes an address in Python's representation, and returns a new address in +# the same representation, but with names resolved to numbers, +# etc. +# +# local=True means that the address is being used with bind() or similar +# local=False means that the address is being used with connect() or sendto() or +# similar. +# + + +# Using a TypeVar to indicate we return the same type of address appears to give errors +# when passed a union of address types. +# @overload likely works, but is extremely verbose. +# NOTE: this function does not always checkpoint +async def _resolve_address_nocp( + type_: int, + family: AddressFamily, + proto: int, + *, + ipv6_v6only: bool | int, + address: AddressFormat, + local: bool, +) -> AddressFormat: + # Do some pre-checking (or exit early for non-IP sockets) + if family == _stdlib_socket.AF_INET: + if not isinstance(address, tuple) or not len(address) == 2: + raise ValueError("address should be a (host, port) tuple") + elif family == _stdlib_socket.AF_INET6: + if not isinstance(address, tuple) or not 2 <= len(address) <= 4: + raise ValueError( + "address should be a (host, port, [flowinfo, [scopeid]]) tuple", + ) + elif hasattr(_stdlib_socket, "AF_UNIX") and family == _stdlib_socket.AF_UNIX: + # unwrap path-likes + assert isinstance(address, (str, bytes, os.PathLike)) + return os.fspath(address) + else: + return address + + # -- From here on we know we have IPv4 or IPv6 -- + host: str | None + host, port, *_ = address + # Fast path for the simple case: already-resolved IP address, + # already-resolved port. This is particularly important for UDP, since + # every sendto call goes through here. 
+ if isinstance(port, int) and host is not None: + try: + _stdlib_socket.inet_pton(family, host) + except (OSError, TypeError): + pass + else: + return address + # Special cases to match the stdlib, see gh-277 + if host == "": + host = None + if host == "<broadcast>": + host = "255.255.255.255" + flags = 0 + if local: + flags |= _stdlib_socket.AI_PASSIVE + # Since we always pass in an explicit family here, AI_ADDRCONFIG + # doesn't add any value -- if we have no ipv6 connectivity and are + # working with an ipv6 socket, then things will break soon enough! And + # if we do enable it, then it makes it impossible to even run tests + # for ipv6 address resolution on travis-ci, which as of 2017-03-07 has + # no ipv6. + # flags |= AI_ADDRCONFIG + if family == _stdlib_socket.AF_INET6 and not ipv6_v6only: + flags |= _stdlib_socket.AI_V4MAPPED + gai_res = await getaddrinfo(host, port, family, type_, proto, flags) + # AFAICT from the spec it's not possible for getaddrinfo to return an + # empty list. 
+ assert len(gai_res) >= 1 + # Address is the last item in the first entry + (*_, normed), *_ = gai_res + # The above ignored any flowid and scopeid in the passed-in address, + # so restore them if present: + if family == _stdlib_socket.AF_INET6: + list_normed = list(normed) + assert len(normed) == 4 + if len(address) >= 3: + list_normed[2] = address[2] + if len(address) >= 4: + list_normed[3] = address[3] + return tuple(list_normed) + return normed + + +class SocketType: + def __init__(self) -> None: + # make sure this __init__ works with multiple inheritance + super().__init__() + # and only raises error if it's directly constructed + if type(self) is SocketType: + raise TypeError( + "SocketType is an abstract class; use trio.socket.socket if you " + "want to construct a socket object", + ) + + def detach(self) -> int: + raise NotImplementedError + + def fileno(self) -> int: + raise NotImplementedError + + def getpeername(self) -> AddressFormat: + raise NotImplementedError + + def getsockname(self) -> AddressFormat: + raise NotImplementedError + + @overload + def getsockopt(self, level: int, optname: int) -> int: ... + + @overload + def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: ... + + def getsockopt( + self, + level: int, + optname: int, + buflen: int | None = None, + ) -> int | bytes: + raise NotImplementedError + + @overload + def setsockopt(self, level: int, optname: int, value: int | Buffer) -> None: ... + + @overload + def setsockopt( + self, + level: int, + optname: int, + value: None, + optlen: int, + ) -> None: ... 
+ + def setsockopt( + self, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + raise NotImplementedError + + def listen(self, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + raise NotImplementedError + + def get_inheritable(self) -> bool: + raise NotImplementedError + + def set_inheritable(self, inheritable: bool) -> None: + raise NotImplementedError + + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + ): + + def share(self, process_id: int) -> bytes: + raise NotImplementedError + + def __enter__(self) -> Self: + raise NotImplementedError + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + raise NotImplementedError + + @property + def family(self) -> AddressFamily: + raise NotImplementedError + + @property + def type(self) -> SocketKind: + raise NotImplementedError + + @property + def proto(self) -> int: + raise NotImplementedError + + @property + def did_shutdown_SHUT_WR(self) -> bool: + """Return True if the socket has been shut down with the SHUT_WR flag""" + raise NotImplementedError + + def __repr__(self) -> str: + raise NotImplementedError + + def dup(self) -> SocketType: + raise NotImplementedError + + def close(self) -> None: + raise NotImplementedError + + async def bind(self, address: AddressFormat) -> None: + raise NotImplementedError + + def shutdown(self, flag: int) -> None: + raise NotImplementedError + + def is_readable(self) -> bool: + """Return True if the socket is readable. 
This is checked with `select.select` on Windows, otherwise `select.poll`.""" + raise NotImplementedError + + async def wait_writable(self) -> None: + """Convenience method that calls trio.lowlevel.wait_writable for the object.""" + raise NotImplementedError + + async def accept(self) -> tuple[SocketType, AddressFormat]: + raise NotImplementedError + + async def connect(self, address: AddressFormat) -> None: + raise NotImplementedError + + def recv(self, buflen: int, flags: int = 0, /) -> Awaitable[bytes]: + raise NotImplementedError + + def recv_into( + self, + buffer: Buffer, + nbytes: int = 0, + flags: int = 0, + ) -> Awaitable[int]: + raise NotImplementedError + + # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] + def recvfrom( + self, + bufsize: int, + flags: int = 0, + /, + ) -> Awaitable[tuple[bytes, AddressFormat]]: + raise NotImplementedError + + # return type of socket.socket.recvfrom_into in typeshed is tuple[int, Any] + def recvfrom_into( + self, + buffer: Buffer, + nbytes: int = 0, + flags: int = 0, + ) -> Awaitable[tuple[int, AddressFormat]]: + raise NotImplementedError + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg") + ): + + def recvmsg( + self, + bufsize: int, + ancbufsize: int = 0, + flags: int = 0, + /, + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, object]]: + raise NotImplementedError + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into") + ): + + def recvmsg_into( + self, + buffers: Iterable[Buffer], + ancbufsize: int = 0, + flags: int = 0, + /, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, object]]: + raise NotImplementedError + + def send(self, bytes: Buffer, flags: int = 0, /) -> Awaitable[int]: + raise NotImplementedError + + @overload + async def sendto( + self, + data: Buffer, + address: tuple[object, ...] | str | Buffer, + /, + ) -> int: ... 
+ + @overload + async def sendto( + self, + data: Buffer, + flags: int, + address: tuple[object, ...] | str | Buffer, + /, + ) -> int: ... + + async def sendto(self, *args: object) -> int: + raise NotImplementedError + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg") + ): + + @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) + async def sendmsg( + self, + buffers: Iterable[Buffer], + ancdata: Iterable[tuple[int, int, Buffer]] = (), + flags: int = 0, + address: AddressFormat | None = None, + /, + ) -> int: + raise NotImplementedError + + +# copy docstrings from socket.SocketType / socket.socket +for name, obj in SocketType.__dict__.items(): + # skip dunders and already defined docstrings + if name.startswith("__") or obj.__doc__: + continue + # try both socket.socket and socket.SocketType + for stdlib_type in _stdlib_socket.socket, _stdlib_socket.SocketType: + stdlib_obj = getattr(stdlib_type, name, None) + if stdlib_obj and stdlib_obj.__doc__: + break + else: + continue + obj.__doc__ = stdlib_obj.__doc__ + + +class _SocketType(SocketType): + def __init__(self, sock: _stdlib_socket.socket) -> None: + if type(sock) is not _stdlib_socket.socket: + # For example, ssl.SSLSocket subclasses socket.socket, but we + # certainly don't want to blindly wrap one of those. 
+ raise TypeError( + f"expected object of type 'socket.socket', not '{type(sock).__name__}'", + ) + self._sock = sock + self._sock.setblocking(False) + self._did_shutdown_SHUT_WR = False + + ################################################################ + # Simple + portable methods and attributes + ################################################################ + + # forwarded methods + def detach(self) -> int: + return self._sock.detach() + + def fileno(self) -> int: + return self._sock.fileno() + + def getpeername(self) -> AddressFormat: + return self._sock.getpeername() + + def getsockname(self) -> AddressFormat: + return self._sock.getsockname() + + @overload + def getsockopt(self, level: int, optname: int) -> int: ... + + @overload + def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: ... + + def getsockopt( + self, + level: int, + optname: int, + buflen: int | None = None, + ) -> int | bytes: + if buflen is None: + return self._sock.getsockopt(level, optname) + return self._sock.getsockopt(level, optname, buflen) + + @overload + def setsockopt(self, level: int, optname: int, value: int | Buffer) -> None: ... + + @overload + def setsockopt( + self, + level: int, + optname: int, + value: None, + optlen: int, + ) -> None: ... + + def setsockopt( + self, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + if optlen is None: + if value is None: + raise TypeError( + "invalid value for argument 'value', must not be None when specifying optlen", + ) + return self._sock.setsockopt(level, optname, value) + if value is not None: + raise TypeError( + f"invalid value for argument 'value': {value!r}, must be None when specifying optlen", + ) + + # Note: PyPy may crash here due to setsockopt only supporting + # four parameters. 
+ return self._sock.setsockopt(level, optname, value, optlen) + + def listen(self, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + return self._sock.listen(backlog) + + def get_inheritable(self) -> bool: + return self._sock.get_inheritable() + + def set_inheritable(self, inheritable: bool) -> None: + return self._sock.set_inheritable(inheritable) + + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + ): + + def share(self, process_id: int) -> bytes: + return self._sock.share(process_id) + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self._sock.__exit__(exc_type, exc_value, traceback) + + @property + def family(self) -> AddressFamily: + return self._sock.family + + @property + def type(self) -> SocketKind: + return self._sock.type + + @property + def proto(self) -> int: + return self._sock.proto + + @property + def did_shutdown_SHUT_WR(self) -> bool: + return self._did_shutdown_SHUT_WR + + def __repr__(self) -> str: + return repr(self._sock).replace("socket.socket", "trio.socket.socket") + + def dup(self) -> SocketType: + """Same as :meth:`socket.socket.dup`.""" + return _SocketType(self._sock.dup()) + + def close(self) -> None: + if self._sock.fileno() != -1: + trio.lowlevel.notify_closing(self._sock) + self._sock.close() + + async def bind(self, address: AddressFormat) -> None: + address = await self._resolve_address_nocp(address, local=True) + if ( + hasattr(_stdlib_socket, "AF_UNIX") + and self.family == _stdlib_socket.AF_UNIX + and address[0] + ): + # Use a thread for the filesystem traversal (unless it's an + # abstract domain socket) + return await trio.to_thread.run_sync(self._sock.bind, address) + else: + # POSIX actually says that bind can return EWOULDBLOCK and + # complete asynchronously, like connect. 
But in practice AFAICT + # there aren't yet any real systems that do this, so we'll worry + # about it when it happens. + await trio.lowlevel.checkpoint() + return self._sock.bind(address) + + def shutdown(self, flag: int) -> None: + # no need to worry about return value b/c always returns None: + self._sock.shutdown(flag) + # only do this if the call succeeded: + if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]: + self._did_shutdown_SHUT_WR = True + + def is_readable(self) -> bool: + # use select.select on Windows, and select.poll everywhere else + if sys.platform == "win32": + rready, _, _ = select.select([self._sock], [], [], 0) + return bool(rready) + p = select.poll() + p.register(self._sock, select.POLLIN) + return bool(p.poll(0)) + + async def wait_writable(self) -> None: + await _core.wait_writable(self._sock) + + async def _resolve_address_nocp( + self, + address: AddressFormat, + *, + local: bool, + ) -> AddressFormat: + if self.family == _stdlib_socket.AF_INET6: + ipv6_v6only = self._sock.getsockopt( + _stdlib_socket.IPPROTO_IPV6, + _stdlib_socket.IPV6_V6ONLY, + ) + else: + ipv6_v6only = False + return await _resolve_address_nocp( + self.type, + self.family, + self.proto, + ipv6_v6only=ipv6_v6only, + address=address, + local=local, + ) + + # args and kwargs must be starred, otherwise pyright complains: + # '"args" member of ParamSpec is valid only when used with *args parameter' + # '"kwargs" member of ParamSpec is valid only when used with **kwargs parameter' + # wait_fn and fn must also be first in the signature + # 'Keyword parameter cannot appear in signature after ParamSpec args parameter' + + async def _nonblocking_helper( + self, + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: + # We have to reconcile two conflicting goals: + # - We want to make it look like we always blocked in doing these + # operations. 
The obvious way is to always do an IO wait before + # calling the function. + # - But, we also want to provide the correct semantics, and part + # of that means giving correct errors. So, for example, if you + # haven't called .listen(), then .accept() raises an error + # immediately. But in this same circumstance, then on macOS, the + # socket does not register as readable. So if we block waiting + # for read *before* we call accept, then we'll be waiting + # forever instead of properly raising an error. (On Linux, + # interestingly, AFAICT a socket that can't possibly read/write + # *does* count as readable/writable for select() purposes. But + # not on macOS.) + # + # So, we have to call the function once, with the appropriate + # cancellation/yielding sandwich if it succeeds, and if it gives + # BlockingIOError *then* we fall back to IO wait. + # + # XX think if this can be combined with the similar logic for IOCP + # submission... + async with _try_sync(): + return fn(self._sock, *args, **kwargs) + # First attempt raised BlockingIOError: + while True: + await wait_fn(self._sock) + try: + return fn(self._sock, *args, **kwargs) + except BlockingIOError: + pass + + ################################################################ + # accept + ################################################################ + + _accept = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.accept, + _core.wait_readable, + ) + + async def accept(self) -> tuple[SocketType, AddressFormat]: + """Like :meth:`socket.socket.accept`, but async.""" + sock, addr = await self._accept() + return from_stdlib_socket(sock), addr + + ################################################################ + # connect + ################################################################ + + async def connect(self, address: AddressFormat) -> None: + # nonblocking connect is weird -- you call it to start things + # off, then the socket becomes writable as a completion + # notification. 
This means it isn't really cancellable... we close the + # socket if cancelled, to avoid confusion. + try: + address = await self._resolve_address_nocp(address, local=False) + async with _try_sync(): + # An interesting puzzle: can a non-blocking connect() return EINTR + # (= raise InterruptedError)? PEP 475 specifically left this as + # the one place where it lets an InterruptedError escape instead + # of automatically retrying. This is based on the idea that EINTR + # from connect means that the connection was already started, and + # will continue in the background. For a blocking connect, this + # sort of makes sense: if it returns EINTR then the connection + # attempt is continuing in the background, and on many systems you + # can't then call connect() again because there is already a + # connect happening. See: + # + # http://www.madore.org/~david/computers/connect-intr.html + # + # For a non-blocking connect, it doesn't make as much sense -- + # surely the interrupt didn't happen after we successfully + # initiated the connect and are just waiting for it to complete, + # because a non-blocking connect does not wait! And the spec + # describes the interaction between EINTR/blocking connect, but + # doesn't have anything useful to say about non-blocking connect: + # + # http://pubs.opengroup.org/onlinepubs/007904975/functions/connect.html + # + # So we have a conundrum: if EINTR means that the connect() hasn't + # happened (like it does for essentially every other syscall), + # then InterruptedError should be caught and retried. If EINTR + # means that the connect() has successfully started, then + # InterruptedError should be caught and ignored. Which should we + # do? + # + # In practice, the resolution is probably that non-blocking + # connect simply never returns EINTR, so the question of how to + # handle it is moot. 
Someone spelunked macOS/FreeBSD and + # confirmed this is true there: + # + # https://stackoverflow.com/questions/14134440/eintr-and-non-blocking-calls + # + # and exarkun seems to think it's true in general of non-blocking + # calls: + # + # https://twistedmatrix.com/pipermail/twisted-python/2010-September/022864.html + # (and indeed, AFAICT twisted doesn't try to handle + # InterruptedError). + # + # So we don't try to catch InterruptedError. This way if it + # happens, someone will hopefully tell us, and then hopefully we + # can investigate their system to figure out what its semantics + # are. + return self._sock.connect(address) + # It raised BlockingIOError, meaning that it's started the + # connection attempt. We wait for it to complete: + await _core.wait_writable(self._sock) + except trio.Cancelled: + # We can't really cancel a connect, and the socket is in an + # indeterminate state. Better to close it so we don't get + # confused. + self._sock.close() + raise + # Okay, the connect finished, but it might have failed: + err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR) + if err != 0: + raise OSError(err, f"Error connecting to {address!r}: {os.strerror(err)}") + + ################################################################ + # recv + ################################################################ + + # Not possible to typecheck with a Callable (due to DefaultArg), nor with a + # callback Protocol (https://github.com/python/typing/discussions/1040) + # but this seems to work. If not explicitly defined then pyright --verifytypes will + # complain about AmbiguousType + if TYPE_CHECKING: + + def recv(self, buflen: int, flags: int = 0, /) -> Awaitable[bytes]: ... 
+ + recv = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.recv, + _core.wait_readable, + ) + + ################################################################ + # recv_into + ################################################################ + + if TYPE_CHECKING: + + def recv_into( + self, + /, + buffer: Buffer, + nbytes: int = 0, + flags: int = 0, + ) -> Awaitable[int]: ... + + recv_into = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.recv_into, + _core.wait_readable, + ) + + ################################################################ + # recvfrom + ################################################################ + + if TYPE_CHECKING: + # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] + def recvfrom( + self, + bufsize: int, + flags: int = 0, + /, + ) -> Awaitable[tuple[bytes, AddressFormat]]: ... + + recvfrom = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.recvfrom, + _core.wait_readable, + ) + + ################################################################ + # recvfrom_into + ################################################################ + + if TYPE_CHECKING: + # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] + def recvfrom_into( + self, + /, + buffer: Buffer, + nbytes: int = 0, + flags: int = 0, + ) -> Awaitable[tuple[int, AddressFormat]]: ... + + recvfrom_into = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.recvfrom_into, + _core.wait_readable, + ) + + ################################################################ + # recvmsg + ################################################################ + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg") + ): + if TYPE_CHECKING: + + def recvmsg( + self, + bufsize: int, + ancbufsize: int = 0, + flags: int = 0, + /, + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, object]]: ... 
+ + recvmsg = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.recvmsg, + _core.wait_readable, + maybe_avail=True, + ) + + ################################################################ + # recvmsg_into + ################################################################ + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into") + ): + if TYPE_CHECKING: + + def recvmsg_into( + self, + buffers: Iterable[Buffer], + ancbufsize: int = 0, + flags: int = 0, + /, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, object]]: ... + + recvmsg_into = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.recvmsg_into, + _core.wait_readable, + maybe_avail=True, + ) + + ################################################################ + # send + ################################################################ + + if TYPE_CHECKING: + + def send(self, bytes: Buffer, flags: int = 0, /) -> Awaitable[int]: ... + + send = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.send, + _core.wait_writable, + ) + + ################################################################ + # sendto + ################################################################ + + @overload + async def sendto( + self, + data: Buffer, + address: tuple[object, ...] | str | Buffer, + /, + ) -> int: ... + + @overload + async def sendto( + self, + data: Buffer, + flags: int, + address: tuple[object, ...] | str | Buffer, + /, + ) -> int: ... + + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) + async def sendto(self, *args: object) -> int: + """Similar to :meth:`socket.socket.sendto`, but async.""" + # args is: data[, flags], address + # and kwargs are not accepted + args_list = list(args) + args_list[-1] = await self._resolve_address_nocp(args[-1], local=False) + # args_list is Any, which isn't the signature of sendto(). + # We don't care about invalid types, sendto() will do the checking. 
+ return await self._nonblocking_helper( + _core.wait_writable, + _stdlib_socket.socket.sendto, # type: ignore[arg-type] + *args_list, + ) + + ################################################################ + # sendmsg + ################################################################ + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg") + ): + + @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) + async def sendmsg( + self, + buffers: Iterable[Buffer], + ancdata: Iterable[tuple[int, int, Buffer]] = (), + flags: int = 0, + address: AddressFormat | None = None, + /, + ) -> int: + """Similar to :meth:`socket.socket.sendmsg`, but async. + + Only available on platforms where :meth:`socket.socket.sendmsg` is + available. + + """ + if address is not None: + address = await self._resolve_address_nocp(address, local=False) + return await self._nonblocking_helper( + _core.wait_writable, + _stdlib_socket.socket.sendmsg, + buffers, + ancdata, + flags, + address, + ) + + ################################################################ + # sendfile + ################################################################ + + # Not implemented yet: + # async def sendfile(self, file, offset=0, count=None): + # XX + + # Intentionally omitted: + # sendall + # makefile + # setblocking/getblocking + # settimeout/gettimeout + # timeout diff --git a/contrib/python/trio/trio/_ssl.py b/contrib/python/trio/trio/_ssl.py new file mode 100644 index 00000000000..52c5137ea15 --- /dev/null +++ b/contrib/python/trio/trio/_ssl.py @@ -0,0 +1,964 @@ +from __future__ import annotations + +import contextlib +import operator as _operator +import ssl as _stdlib_ssl +from enum import Enum as _Enum +from typing import TYPE_CHECKING, Any, ClassVar, Final as TFinal, Generic, TypeVar + +import trio + +from . 
import _sync +from ._highlevel_generic import aclose_forcefully +from ._util import ConflictDetector, final +from .abc import Listener, Stream + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from typing_extensions import TypeVarTuple, Unpack + + Ts = TypeVarTuple("Ts") + +# General theory of operation: +# +# We implement an API that closely mirrors the stdlib ssl module's blocking +# API, and we do it using the stdlib ssl module's non-blocking in-memory API. +# The stdlib non-blocking in-memory API is barely documented, and acts as a +# thin wrapper around openssl, whose documentation also leaves something to be +# desired. So here's the main things you need to know to understand the code +# in this file: +# +# We use an ssl.SSLObject, which exposes the four main I/O operations: +# +# - do_handshake: performs the initial handshake. Must be called once at the +# beginning of each connection; is a no-op once it's completed once. +# +# - write: takes some unencrypted data and attempts to send it to the remote +# peer. + +# - read: attempts to decrypt and return some data from the remote peer. +# +# - unwrap: this is weirdly named; maybe it helps to realize that the thing it +# wraps is called SSL_shutdown. It sends a cryptographically signed message +# saying "I'm closing this connection now", and then waits to receive the +# same from the remote peer (unless we already received one, in which case +# it returns immediately). +# +# All of these operations read and write from some in-memory buffers called +# "BIOs", which are an opaque OpenSSL-specific object that's basically +# semantically equivalent to a Python bytearray. When they want to send some +# bytes to the remote peer, they append them to the outgoing BIO, and when +# they want to receive some bytes from the remote peer, they try to pull them +# out of the incoming BIO. "Sending" always succeeds, because the outgoing BIO +# can always be extended to hold more data. 
"Receiving" acts sort of like a +# non-blocking socket: it might manage to get some data immediately, or it +# might fail and need to be tried again later. We can also directly add or +# remove data from the BIOs whenever we want. +# +# Now the problem is that while these I/O operations are opaque atomic +# operations from the point of view of us calling them, under the hood they +# might require some arbitrary sequence of sends and receives from the remote +# peer. This is particularly true for do_handshake, which generally requires a +# few round trips, but it's also true for write and read, due to an evil thing +# called "renegotiation". +# +# Renegotiation is the process by which one of the peers might arbitrarily +# decide to redo the handshake at any time. Did I mention it's evil? It's +# pretty evil, and almost universally hated. The HTTP/2 spec forbids the use +# of TLS renegotiation for HTTP/2 connections. TLS 1.3 removes it from the +# protocol entirely. It's impossible to trigger a renegotiation if using +# Python's ssl module. OpenSSL's renegotiation support is pretty buggy [1]. +# Nonetheless, it does get used in real life, mostly in two cases: +# +# 1) Normally in TLS 1.2 and below, when the client side of a connection wants +# to present a certificate to prove their identity, that certificate gets sent +# in plaintext. This is bad, because it means that anyone eavesdropping can +# see who's connecting – it's like sending your username in plain text. Not as +# bad as sending your password in plain text, but still, pretty bad. However, +# renegotiations *are* encrypted. So as a workaround, it's not uncommon for +# systems that want to use client certificates to first do an anonymous +# handshake, and then to turn around and do a second handshake (= +# renegotiation) and this time ask for a client cert. Or sometimes this is +# done on a case-by-case basis, e.g. 
a web server might accept a connection, +# read the request, and then once it sees the page you're asking for it might +# stop and ask you for a certificate. +# +# 2) In principle the same TLS connection can be used for an arbitrarily long +# time, and might transmit arbitrarily large amounts of data. But this creates +# a cryptographic problem: an attacker who has access to arbitrarily large +# amounts of data that's all encrypted using the same key may eventually be +# able to use this to figure out the key. Is this a real practical problem? I +# have no idea, I'm not a cryptographer. In any case, some people worry that +# it's a problem, so their TLS libraries are designed to automatically trigger +# a renegotiation every once in a while on some sort of timer. +# +# The end result is that you might be going along, minding your own business, +# and then *bam*! a wild renegotiation appears! And you just have to cope. +# +# The reason that coping with renegotiations is difficult is that some +# unassuming "read" or "write" call might find itself unable to progress until +# it does a handshake, which remember is a process with multiple round +# trips. So read might have to send data, and write might have to receive +# data, and this might happen multiple times. And some of those attempts might +# fail because there isn't any data yet, and need to be retried. Managing all +# this is pretty complicated. +# +# Here's how openssl (and thus the stdlib ssl module) handle this. All of the +# I/O operations above follow the same rules. When you call one of them: +# +# - it might write some data to the outgoing BIO +# - it might read some data from the incoming BIO +# - it might raise SSLWantReadError if it can't complete without reading more +# data from the incoming BIO. This is important: the "read" in ReadError +# refers to reading from the *underlying* stream. 
+# - (and in principle it might raise SSLWantWriteError too, but that never +# happens when using memory BIOs, so never mind) +# +# If it doesn't raise an error, then the operation completed successfully +# (though we still need to take any outgoing data out of the memory buffer and +# put it onto the wire). If it *does* raise an error, then we need to retry +# *exactly that method call* later – in particular, if a 'write' failed, we +# need to try again later *with the same data*, because openssl might have +# already committed some of the initial parts of our data to its output even +# though it didn't tell us that, and has remembered that the next time we call +# write it needs to skip the first 1024 bytes or whatever it is. (Well, +# technically, we're actually allowed to call 'write' again with a data buffer +# which is the same as our old one PLUS some extra stuff added onto the end, +# but in Trio that never comes up so never mind.) +# +# There are some people online who claim that once you've gotten a Want*Error +# then the *very next call* you make to openssl *must* be the same as the +# previous one. I'm pretty sure those people are wrong. In particular, it's +# okay to call write, get a WantReadError, and then call read a few times; +# it's just that *the next time you call write*, it has to be with the same +# data. +# +# One final wrinkle: we want our SSLStream to support full-duplex operation, +# i.e. it should be possible for one task to be calling send_all while another +# task is calling receive_some. But renegotiation makes this a big hassle, because +# even if SSLStream's users restrict themselves to one task calling send_all and one +# task calling receive_some, those two tasks might end up both wanting to call +# send_all, or both to call receive_some at the same time *on the underlying +# stream*. So we have to do some careful locking to hide this problem from our +# users. +# +# (Renegotiation is evil.) 
+# +# So our basic strategy is to define a single helper method called "_retry", +# which has generic logic for dealing with SSLWantReadError, pushing data from +# the outgoing BIO to the wire, reading data from the wire to the incoming +# BIO, retrying an I/O call until it works, and synchronizing with other tasks +# that might be calling _retry concurrently. Basically it takes an SSLObject +# non-blocking in-memory method and converts it into a Trio async blocking +# method. _retry is only about 30 lines of code, but all these cases +# multiplied by concurrent calls make it extremely tricky, so there are lots +# of comments down below on the details, and a really extensive test suite in +# test_ssl.py. And now you know *why* it's so tricky, and can probably +# understand how it works. +# +# [1] https://rt.openssl.org/Ticket/Display.html?id=3712 + +# XX how closely should we match the stdlib API? +# - maybe suppress_ragged_eofs=False is a better default? +# - maybe check crypto folks for advice? +# - this is also interesting: https://bugs.python.org/issue8108#msg102867 + +# Definitely keep an eye on Cory's TLS API ideas on security-sig etc. + +# XX document behavior on cancellation/error (i.e.: all is lost abandon +# stream) +# docs will need to make very clear that this is different from all the other +# cancellations in core Trio + + +T = TypeVar("T") + +################################################################ +# SSLStream +################################################################ + +# Ideally, when the user calls SSLStream.receive_some() with no argument, then +# we should do exactly one call to self.transport_stream.receive_some(), +# decrypt everything we got, and return it. Unfortunately, the way openssl's +# API works, we have to pick how much data we want to allow when we call +# read(), and then it (potentially) triggers a call to +# transport_stream.receive_some(). 
So at the time we pick the amount of data +# to decrypt, we don't know how much data we've read. As a simple heuristic, +# we record the max amount of data returned by previous calls to +# transport_stream.receive_some(), and we use that for future calls to read(). +# But what do we use for the very first call? That's what this constant sets. +# +# Note that the value passed to read() is a limit on the amount of +# *decrypted* data, but we can only see the size of the *encrypted* data +# returned by transport_stream.receive_some(). TLS adds a small amount of +# framing overhead, and TLS compression is rarely used these days because it's +# insecure. So the size of the encrypted data should be a slight over-estimate +# of the size of the decrypted data, which is exactly what we want. +# +# The specific value is not really based on anything; it might be worth tuning +# at some point. But, if you have an TCP connection with the typical 1500 byte +# MTU and an initial window of 10 (see RFC 6928), then the initial burst of +# data will be limited to ~15000 bytes (or a bit less due to IP-level framing +# overhead), so this is chosen to be larger than that. +STARTING_RECEIVE_SIZE: TFinal = 16384 + + +def _is_eof(exc: BaseException | None) -> bool: + # There appears to be a bug on Python 3.10, where SSLErrors + # aren't properly translated into SSLEOFErrors. + # This stringly-typed error check is borrowed from the AnyIO + # project. + return isinstance(exc, _stdlib_ssl.SSLEOFError) or ( + "UNEXPECTED_EOF_WHILE_READING" in getattr(exc, "strerror", ()) + ) + + +class NeedHandshakeError(Exception): + """Some :class:`SSLStream` methods can't return any meaningful data until + after the handshake. If you call them before the handshake, they raise + this error. 
+ + """ + + +class _Once: + __slots__ = ("_afn", "_args", "_done", "started") + + def __init__( + self, + afn: Callable[[*Ts], Awaitable[object]], + *args: Unpack[Ts], + ) -> None: + self._afn = afn + self._args = args + self.started = False + self._done = _sync.Event() + + async def ensure(self, *, checkpoint: bool) -> None: + if not self.started: + self.started = True + await self._afn(*self._args) + self._done.set() + elif not checkpoint and self._done.is_set(): + return + else: + await self._done.wait() + + @property + def done(self) -> bool: + return bool(self._done.is_set()) + + +_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"]) + +# invariant +T_Stream = TypeVar("T_Stream", bound=Stream) + + +@final +class SSLStream(Stream, Generic[T_Stream]): + r"""Encrypted communication using SSL/TLS. + + :class:`SSLStream` wraps an arbitrary :class:`~trio.abc.Stream`, and + allows you to perform encrypted communication over it using the usual + :class:`~trio.abc.Stream` interface. You pass regular data to + :meth:`send_all`, then it encrypts it and sends the encrypted data on the + underlying :class:`~trio.abc.Stream`; :meth:`receive_some` takes encrypted + data out of the underlying :class:`~trio.abc.Stream` and decrypts it + before returning it. + + You should read the standard library's :mod:`ssl` documentation carefully + before attempting to use this class, and probably other general + documentation on SSL/TLS as well. SSL/TLS is subtle and quick to + anger. Really. I'm not kidding. + + Args: + transport_stream (~trio.abc.Stream): The stream used to transport + encrypted data. Required. + + ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` used for + this connection. Required. Usually created by calling + :func:`ssl.create_default_context`. + + server_hostname (str, bytes, or None): The name of the server being + connected to. 
Used for `SNI + <https://en.wikipedia.org/wiki/Server_Name_Indication>`__ and for + validating the server's certificate (if hostname checking is + enabled). This is effectively mandatory for clients, and actually + mandatory if ``ssl_context.check_hostname`` is ``True``. + + server_side (bool): Whether this stream is acting as a client or + server. Defaults to False, i.e. client mode. + + https_compatible (bool): There are two versions of SSL/TLS commonly + encountered in the wild: the standard version, and the version used + for HTTPS (HTTP-over-SSL/TLS). + + Standard-compliant SSL/TLS implementations always send a + cryptographically signed ``close_notify`` message before closing the + connection. This is important because if the underlying transport + were simply closed, then there wouldn't be any way for the other + side to know whether the connection was intentionally closed by the + peer that they negotiated a cryptographic connection to, or by some + `man-in-the-middle + <https://en.wikipedia.org/wiki/Man-in-the-middle_attack>`__ attacker + who can't manipulate the cryptographic stream, but can manipulate + the transport layer (a so-called "truncation attack"). + + However, this part of the standard is widely ignored by real-world + HTTPS implementations, which means that if you want to interoperate + with them, then you NEED to ignore it too. + + Fortunately this isn't as bad as it sounds, because the HTTP + protocol already includes its own equivalent of ``close_notify``, so + doing this again at the SSL/TLS level is redundant. But not all + protocols do! Therefore, by default Trio implements the safer + standard-compliant version (``https_compatible=False``). But if + you're speaking HTTPS or some other protocol where + ``close_notify``\s are commonly skipped, then you should set + ``https_compatible=True``; with this setting, Trio will neither + expect nor send ``close_notify`` messages. 
+ + If you have code that was written to use :class:`ssl.SSLSocket` and + now you're porting it to Trio, then it may be useful to know that a + difference between :class:`SSLStream` and :class:`ssl.SSLSocket` is + that :class:`~ssl.SSLSocket` implements the + ``https_compatible=True`` behavior by default. + + Attributes: + transport_stream (trio.abc.Stream): The underlying transport stream + that was passed to ``__init__``. An example of when this would be + useful is if you're using :class:`SSLStream` over a + :class:`~trio.SocketStream` and want to call the + :class:`~trio.SocketStream`'s :meth:`~trio.SocketStream.setsockopt` + method. + + Internally, this class is implemented using an instance of + :class:`ssl.SSLObject`, and all of :class:`~ssl.SSLObject`'s methods and + attributes are re-exported as methods and attributes on this class. + However, there is one difference: :class:`~ssl.SSLObject` has several + methods that return information about the encrypted connection, like + :meth:`~ssl.SSLSocket.cipher` or + :meth:`~ssl.SSLSocket.selected_alpn_protocol`. If you call them before the + handshake, when they can't possibly return useful data, then + :class:`ssl.SSLObject` returns None, but :class:`trio.SSLStream` + raises :exc:`NeedHandshakeError`. + + This also means that if you register a SNI callback using + `~ssl.SSLContext.sni_callback`, then the first argument your callback + receives will be a :class:`ssl.SSLObject`. + + """ + + # Note: any new arguments here should likely also be added to + # SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers. 
+ def __init__( + self, + transport_stream: T_Stream, + ssl_context: _stdlib_ssl.SSLContext, + *, + server_hostname: str | bytes | None = None, + server_side: bool = False, + https_compatible: bool = False, + ) -> None: + self.transport_stream: T_Stream = transport_stream + self._state = _State.OK + self._https_compatible = https_compatible + self._outgoing = _stdlib_ssl.MemoryBIO() + self._delayed_outgoing: bytes | None = None + self._incoming = _stdlib_ssl.MemoryBIO() + self._ssl_object = ssl_context.wrap_bio( + self._incoming, + self._outgoing, + server_side=server_side, + server_hostname=server_hostname, + ) + # Tracks whether we've already done the initial handshake + self._handshook = _Once(self._do_handshake) + + # These are used to synchronize access to self.transport_stream + self._inner_send_lock = _sync.StrictFIFOLock() + self._inner_recv_count = 0 + self._inner_recv_lock = _sync.Lock() + + # These are used to make sure that our caller doesn't attempt to make + # multiple concurrent calls to send_all/wait_send_all_might_not_block + # or to receive_some. 
+ self._outer_send_conflict_detector = ConflictDetector( + "another task is currently sending data on this SSLStream", + ) + self._outer_recv_conflict_detector = ConflictDetector( + "another task is currently receiving data on this SSLStream", + ) + + self._estimated_receive_size = STARTING_RECEIVE_SIZE + + _forwarded: ClassVar = { + "context", + "server_side", + "server_hostname", + "session", + "session_reused", + "getpeercert", + "selected_npn_protocol", + "cipher", + "shared_ciphers", + "compression", + "pending", + "get_channel_binding", + "selected_alpn_protocol", + "version", + } + + _after_handshake: ClassVar = { + "session_reused", + "getpeercert", + "selected_npn_protocol", + "cipher", + "shared_ciphers", + "compression", + "get_channel_binding", + "selected_alpn_protocol", + "version", + } + + def __getattr__( # type: ignore[explicit-any] + self, + name: str, + ) -> Any: + if name in self._forwarded: + if name in self._after_handshake and not self._handshook.done: + raise NeedHandshakeError(f"call do_handshake() before calling {name!r}") + + return getattr(self._ssl_object, name) + else: + raise AttributeError(name) + + def __setattr__(self, name: str, value: object) -> None: + if name in self._forwarded: + setattr(self._ssl_object, name, value) + else: + super().__setattr__(name, value) + + def __dir__(self) -> list[str]: + return list(super().__dir__()) + list(self._forwarded) + + def _check_status(self) -> None: + if self._state is _State.OK: + return + elif self._state is _State.BROKEN: + raise trio.BrokenResourceError + elif self._state is _State.CLOSED: + raise trio.ClosedResourceError + else: # pragma: no cover + raise AssertionError() + + # This is probably the single trickiest function in Trio. It has lots of + # comments, though, just make sure to think carefully if you ever have to + # touch it. The big comment at the top of this file will help explain + # too. 
+ async def _retry( + self, + fn: Callable[[*Ts], T], + *args: Unpack[Ts], + ignore_want_read: bool = False, + is_handshake: bool = False, + ) -> T | None: + await trio.lowlevel.checkpoint_if_cancelled() + yielded = False + finished = False + while not finished: + # WARNING: this code needs to be very careful with when it + # calls 'await'! There might be multiple tasks calling this + # function at the same time trying to do different operations, + # so we need to be careful to: + # + # 1) interact with the SSLObject, then + # 2) await on exactly one thing that lets us make forward + # progress, then + # 3) loop or exit + # + # In particular we don't want to yield while interacting with + # the SSLObject (because it's shared state, so someone else + # might come in and mess with it while we're suspended), and + # we don't want to yield *before* starting the operation that + # will help us make progress, because then someone else might + # come in and leapfrog us. + + # Call the SSLObject method, and get its result. + # + # NB: despite what the docs say, SSLWantWriteError can't + # happen – "Writes to memory BIOs will always succeed if + # memory is available: that is their size can grow + # indefinitely." + # https://wiki.openssl.org/index.php/Manual:BIO_s_mem(3) + want_read = False + ret = None + try: + ret = fn(*args) + except _stdlib_ssl.SSLWantReadError: + want_read = True + except (_stdlib_ssl.SSLError, _stdlib_ssl.CertificateError) as exc: + self._state = _State.BROKEN + raise trio.BrokenResourceError from exc + else: + finished = True + if ignore_want_read: + want_read = False + finished = True + to_send = self._outgoing.read() + + # Some versions of SSL_do_handshake have a bug in how they handle + # the TLS 1.3 handshake on the server side: after the handshake + # finishes, they automatically send session tickets, even though + # the client may not be expecting data to arrive at this point and + # sending it could cause a deadlock or lost data. 
This applies at + # least to OpenSSL 1.1.1c and earlier, and the OpenSSL devs + # currently have no plans to fix it: + # + # https://github.com/openssl/openssl/issues/7948 + # https://github.com/openssl/openssl/issues/7967 + # + # The correct behavior is to wait to send session tickets on the + # first call to SSL_write. (This is what BoringSSL does.) So, we + # use a heuristic to detect when OpenSSL has tried to send session + # tickets, and we manually delay sending them until the + # appropriate moment. For more discussion see: + # + # https://github.com/python-trio/trio/issues/819#issuecomment-517529763 + if ( + is_handshake + and not want_read + and self._ssl_object.server_side + and self._ssl_object.version() == "TLSv1.3" + ): + assert self._delayed_outgoing is None + self._delayed_outgoing = to_send + to_send = b"" + + # Outputs from the above code block are: + # + # - to_send: bytestring; if non-empty then we need to send + # this data to make forward progress + # + # - want_read: True if we need to receive_some some data to make + # forward progress + # + # - finished: False means that we need to retry the call to + # fn(*args) again, after having pushed things forward. True + # means we still need to do whatever was said (in particular + # send any data in to_send), but once we do then we're + # done. + # + # - ret: the operation's return value. (Meaningless unless + # finished is True.) + # + # Invariant: want_read and finished can't both be True at the + # same time. + # + # Now we need to move things forward. There are two things we + # might have to do, and any given operation might require + # either, both, or neither to proceed: + # + # - send the data in to_send + # + # - receive_some some data and put it into the incoming BIO + # + # Our strategy is: if there's data to send, send it; + # *otherwise* if there's data to receive_some, receive_some it. + # + # If both need to happen, then we only send. Why? 
Well, we + # know that *right now* we have to both send and receive_some + # before the operation can complete. But as soon as we yield, + # that information becomes potentially stale – e.g. while + # we're sending, some other task might go and receive_some the + # data we need and put it into the incoming BIO. And if it + # does, then we *definitely don't* want to do a receive_some – + # there might not be any more data coming, and we'd deadlock! + # We could do something tricky to keep track of whether a + # receive_some happens while we're sending, but the case where + # we have to do both is very unusual (only during a + # renegotiation), so it's better to keep things simple. So we + # do just one potentially-blocking operation, then check again + # for fresh information. + # + # And we prioritize sending over receiving because, if there + # are multiple tasks that want to receive_some, then it + # doesn't matter what order they go in. But if there are + # multiple tasks that want to send, then they each have + # different data, and the data needs to get put onto the wire + # in the same order that it was retrieved from the outgoing + # BIO. So if we have data to send, that *needs* to be the + # *very* *next* *thing* we do, to make sure no-one else sneaks + # in before us. Or if we can't send immediately because + # someone else is, then we at least need to get in line + # immediately. + if to_send: + # NOTE: This relies on the lock being strict FIFO fair! + async with self._inner_send_lock: + yielded = True + try: + if self._delayed_outgoing is not None: + to_send = self._delayed_outgoing + to_send + self._delayed_outgoing = None + await self.transport_stream.send_all(to_send) + except: + # Some unknown amount of our data got sent, and we + # don't know how much. This stream is doomed. + self._state = _State.BROKEN + raise + elif want_read: + # It's possible that someone else is already blocked in + # transport_stream.receive_some. 
If so then we want to + # wait for them to finish, but we don't want to call + # transport_stream.receive_some again ourselves; we just + # want to loop around and check if their contribution + # helped anything. So we make a note of how many times + # some task has been through here before taking the lock, + # and if it's changed by the time we get the lock, then we + # skip calling transport_stream.receive_some and loop + # around immediately. + recv_count = self._inner_recv_count + async with self._inner_recv_lock: + yielded = True + if recv_count == self._inner_recv_count: + data = await self.transport_stream.receive_some() + if not data: + self._incoming.write_eof() + else: + self._estimated_receive_size = max( + self._estimated_receive_size, + len(data), + ) + self._incoming.write(data) + self._inner_recv_count += 1 + if not yielded: + await trio.lowlevel.cancel_shielded_checkpoint() + return ret + + async def _do_handshake(self) -> None: + try: + await self._retry(self._ssl_object.do_handshake, is_handshake=True) + except: + self._state = _State.BROKEN + raise + + async def do_handshake(self) -> None: + """Ensure that the initial handshake has completed. + + The SSL protocol requires an initial handshake to exchange + certificates, select cryptographic keys, and so forth, before any + actual data can be sent or received. You don't have to call this + method; if you don't, then :class:`SSLStream` will automatically + perform the handshake as needed, the first time you try to send or + receive data. But if you want to trigger it manually – for example, + because you want to look at the peer's certificate before you start + talking to them – then you can call this method. + + If the initial handshake is already in progress in another task, this + waits for it to complete and then returns. + + If the initial handshake has already completed, this returns + immediately without doing anything (except executing a checkpoint). + + .. 
warning:: If this method is cancelled, then it may leave the + :class:`SSLStream` in an unusable state. If this happens then any + future attempt to use the object will raise + :exc:`trio.BrokenResourceError`. + + """ + self._check_status() + await self._handshook.ensure(checkpoint=True) + + # Most things work if we don't explicitly force do_handshake to be called + # before calling receive_some or send_all, because openssl will + # automatically perform the handshake on the first SSL_{read,write} + # call. BUT, allowing openssl to do this will disable Python's hostname + # checking!!! See: + # https://bugs.python.org/issue30141 + # So we *definitely* have to make sure that do_handshake is called + # before doing anything else. + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: + """Read some data from the underlying transport, decrypt it, and + return it. + + See :meth:`trio.abc.ReceiveStream.receive_some` for details. + + .. warning:: If this method is cancelled while the initial handshake + or a renegotiation are in progress, then it may leave the + :class:`SSLStream` in an unusable state. If this happens then any + future attempt to use the object will raise + :exc:`trio.BrokenResourceError`. + + """ + with self._outer_recv_conflict_detector: + self._check_status() + try: + await self._handshook.ensure(checkpoint=False) + except trio.BrokenResourceError as exc: + # For some reason, EOF before handshake sometimes raises + # SSLSyscallError instead of SSLEOFError (e.g. on my linux + # laptop, but not on appveyor). Thanks openssl. + if self._https_compatible and ( + isinstance(exc.__cause__, _stdlib_ssl.SSLSyscallError) + or _is_eof(exc.__cause__) + ): + await trio.lowlevel.checkpoint() + return b"" + else: + raise + if max_bytes is None: + # If we somehow have more data already in our pending buffer + # than the estimate receive size, bump up our size a bit for + # this read only. 
+ max_bytes = max(self._estimated_receive_size, self._incoming.pending) + else: + max_bytes = _operator.index(max_bytes) + if max_bytes < 1: + raise ValueError("max_bytes must be >= 1") + try: + received = await self._retry(self._ssl_object.read, max_bytes) + assert received is not None + return received + except trio.BrokenResourceError as exc: + # This isn't quite equivalent to just returning b"" in the + # first place, because we still end up with self._state set to + # BROKEN. But that's actually fine, because after getting an + # EOF on TLS then the only thing you can do is close the + # stream, and closing doesn't care about the state. + + if self._https_compatible and _is_eof(exc.__cause__): + await trio.lowlevel.checkpoint() + return b"" + else: + raise + + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + """Encrypt some data and then send it on the underlying transport. + + See :meth:`trio.abc.SendStream.send_all` for details. + + .. warning:: If this method is cancelled, then it may leave the + :class:`SSLStream` in an unusable state. If this happens then any + attempt to use the object will raise + :exc:`trio.BrokenResourceError`. + + """ + with self._outer_send_conflict_detector: + self._check_status() + await self._handshook.ensure(checkpoint=False) + # SSLObject interprets write(b"") as an EOF for some reason, which + # is not what we want. + if not data: + await trio.lowlevel.checkpoint() + return + await self._retry(self._ssl_object.write, data) + + async def unwrap(self) -> tuple[Stream, bytes | bytearray]: + """Cleanly close down the SSL/TLS encryption layer, allowing the + underlying stream to be used for unencrypted communication. + + You almost certainly don't need this. + + Returns: + A pair ``(transport_stream, trailing_bytes)``, where + ``transport_stream`` is the underlying transport stream, and + ``trailing_bytes`` is a byte string. 
Since :class:`SSLStream` + doesn't necessarily know where the end of the encrypted data will + be, it can happen that it accidentally reads too much from the + underlying stream. ``trailing_bytes`` contains this extra data; you + should process it as if it was returned from a call to + ``transport_stream.receive_some(...)``. + + """ + with self._outer_recv_conflict_detector, self._outer_send_conflict_detector: + self._check_status() + await self._handshook.ensure(checkpoint=False) + await self._retry(self._ssl_object.unwrap) + transport_stream = self.transport_stream + self._state = _State.CLOSED + self.transport_stream = None # type: ignore[assignment] # State is CLOSED now, nothing should use + return (transport_stream, self._incoming.read()) + + async def aclose(self) -> None: + """Gracefully shut down this connection, and close the underlying + transport. + + If ``https_compatible`` is False (the default), then this attempts to + first send a ``close_notify`` and then close the underlying stream by + calling its :meth:`~trio.abc.AsyncResource.aclose` method. + + If ``https_compatible`` is set to True, then this simply closes the + underlying stream and marks this stream as closed. + + """ + if self._state is _State.CLOSED: + await trio.lowlevel.checkpoint() + return + if self._state is _State.BROKEN or self._https_compatible: + self._state = _State.CLOSED + await self.transport_stream.aclose() + return + try: + # https_compatible=False, so we're in spec-compliant mode and have + # to send close_notify so that the other side gets a cryptographic + # assurance that we've called aclose. Of course, we can't do + # anything cryptographic until after we've completed the + # handshake: + await self._handshook.ensure(checkpoint=False) + # Then, we call SSL_shutdown *once*, because we want to send a + # close_notify but *not* wait for the other side to send back a + # response. 
In principle it would be more polite to wait for the + # other side to reply with their own close_notify. However, if + # they aren't paying attention (e.g., if they're just sending + # data and not receiving) then we will never notice our + # close_notify and we'll be waiting forever. Eventually we'll time + # out (hopefully), but it's still kind of nasty. And we can't + # require the other side to always be receiving, because (a) + # backpressure is kind of important, and (b) I bet there are + # broken TLS implementations out there that don't receive all the + # time. (Like e.g. anyone using Python ssl in synchronous mode.) + # + # The send-then-immediately-close behavior is explicitly allowed + # by the TLS specs, so we're ok on that. + # + # Subtlety: SSLObject.unwrap will immediately call it a second + # time, and the second time will raise SSLWantReadError because + # there hasn't been time for the other side to respond + # yet. (Unless they spontaneously sent a close_notify before we + # called this, and it's either already been processed or gets + # pulled out of the buffer by Python's second call.) So the way to + # do what we want is to ignore SSLWantReadError on this call. + # + # Also, because the other side might have already sent + # close_notify and closed their connection then it's possible that + # our attempt to send close_notify will raise + # BrokenResourceError. This is totally legal, and in fact can happen + # with two well-behaved Trio programs talking to each other, so we + # don't want to raise an error. So we suppress BrokenResourceError + # here. (This is safe, because literally the only thing this call + # to _retry will do is send the close_notify alert, so that's + # surely where the error comes from.) + # + # FYI in some cases this could also raise SSLSyscallError which I + # think is because SSL_shutdown is terrible. (Check out that note + # at the bottom of the man page saying that it sometimes gets + # raised spuriously.) 
I haven't seen this since we switched to + # immediately closing the socket, and I don't know exactly what + # conditions cause it and how to respond, so for now we're just + # letting that happen. But if you start seeing it, then hopefully + # this will give you a little head start on tracking it down, + # because whoa did this puzzle us at the 2017 PyCon sprints. + # + # Also, if someone else is blocked in send/receive, then we aren't + # going to be able to do a clean shutdown. If that happens, we'll + # just do an unclean shutdown. + with contextlib.suppress(trio.BrokenResourceError, trio.BusyResourceError): + await self._retry(self._ssl_object.unwrap, ignore_want_read=True) + except: + # Failure! Kill the stream and move on. + await aclose_forcefully(self.transport_stream) + raise + else: + # Success! Gracefully close the underlying stream. + await self.transport_stream.aclose() + finally: + self._state = _State.CLOSED + + async def wait_send_all_might_not_block(self) -> None: + """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`.""" + # This method's implementation is deceptively simple. + # + # First, we take the outer send lock, because of Trio's standard + # semantics that wait_send_all_might_not_block and send_all + # conflict. + with self._outer_send_conflict_detector: + self._check_status() + # Then we take the inner send lock. We know that no other tasks + # are calling self.send_all or self.wait_send_all_might_not_block, + # because we have the outer_send_lock. But! There might be another + # task calling self.receive_some -> transport_stream.send_all, in + # which case if we were to call + # transport_stream.wait_send_all_might_not_block directly we'd + # have two tasks doing write-related operations on + # transport_stream simultaneously, which is not allowed. 
We + # *don't* want to raise this conflict to our caller, because it's + # purely an internal affair – all they did was call + # wait_send_all_might_not_block and receive_some at the same time, + # which is totally valid. And waiting for the lock is OK, because + # a call to send_all certainly wouldn't complete while the other + # task holds the lock. + async with self._inner_send_lock: + # Now we have the lock, which creates another potential + # problem: what if a call to self.receive_some attempts to do + # transport_stream.send_all now? It'll have to wait for us to + # finish! But that's OK, because we release the lock as soon + # as the underlying stream becomes writable, and the + # self.receive_some call wasn't going to make any progress + # until then anyway. + # + # Of course, this does mean we might return *before* the + # stream is logically writable, because immediately after we + # return self.receive_some might write some data and make it + # non-writable again. But that's OK too, + # wait_send_all_might_not_block only guarantees that it + # doesn't return late. + await self.transport_stream.wait_send_all_might_not_block() + + +# this is necessary for Sphinx, see also `_abc.py` +SSLStream.__module__ = SSLStream.__module__.replace("._ssl", "") + + +@final +class SSLListener(Listener[SSLStream[T_Stream]]): + """A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers. + + :class:`SSLListener` wraps around another Listener, and converts + all incoming connections to encrypted connections by wrapping them + in a :class:`SSLStream`. + + Args: + transport_listener (~trio.abc.Listener): The listener whose incoming + connections will be wrapped in :class:`SSLStream`. + + ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` that will be + used for incoming connections. + + https_compatible (bool): Passed on to :class:`SSLStream`. + + Attributes: + transport_listener (trio.abc.Listener): The underlying listener that was + passed to ``__init__``. 
+ + """ + + def __init__( + self, + transport_listener: Listener[T_Stream], + ssl_context: _stdlib_ssl.SSLContext, + *, + https_compatible: bool = False, + ) -> None: + self.transport_listener = transport_listener + self._ssl_context = ssl_context + self._https_compatible = https_compatible + + async def accept(self) -> SSLStream[T_Stream]: + """Accept the next connection and wrap it in an :class:`SSLStream`. + + See :meth:`trio.abc.Listener.accept` for details. + + """ + transport_stream = await self.transport_listener.accept() + return SSLStream( + transport_stream, + self._ssl_context, + server_side=True, + https_compatible=self._https_compatible, + ) + + async def aclose(self) -> None: + """Close the transport listener.""" + await self.transport_listener.aclose() diff --git a/contrib/python/trio/trio/_subprocess.py b/contrib/python/trio/trio/_subprocess.py new file mode 100644 index 00000000000..d4faf317f87 --- /dev/null +++ b/contrib/python/trio/trio/_subprocess.py @@ -0,0 +1,1186 @@ +from __future__ import annotations + +import contextlib +import os +import subprocess +import sys +import warnings +from contextlib import ExitStack +from functools import partial +from typing import ( + TYPE_CHECKING, + Final, + Literal, + Protocol, + TypedDict, + Union, + overload, +) + +import trio + +from ._core import ClosedResourceError, TaskStatus +from ._highlevel_generic import StapledStream +from ._subprocess_platform import ( + create_pipe_from_child_output, + create_pipe_to_child_stdin, + wait_child_exiting, +) +from ._sync import Lock +from ._util import NoPublicConstructor, final + +if TYPE_CHECKING: + import signal + from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence + from io import TextIOWrapper + + from typing_extensions import TypeAlias, Unpack + + from ._abc import ReceiveStream, SendStream + + +# Sphinx cannot parse the stringified version +StrOrBytesPath: TypeAlias = Union[str, bytes, os.PathLike[str], os.PathLike[bytes]] + + +# 
Linux-specific, but has complex lifetime management stuff so we hard-code it +# here instead of hiding it behind the _subprocess_platform abstraction +can_try_pidfd_open: bool +if TYPE_CHECKING: + + def pidfd_open(fd: int, flags: int) -> int: ... + + from ._subprocess_platform import ClosableReceiveStream, ClosableSendStream + +else: + can_try_pidfd_open = True + try: + from os import pidfd_open + except ImportError: + if sys.platform == "linux": + # this workaround is needed on: + # - CPython <= 3.8 + # - non-CPython (maybe?) + # - Anaconda's interpreter (as it is built to assume an older + # than current linux kernel) + # + # The last point implies that other custom builds might not work; + # therefore, no assertion should be here. + import ctypes + + _cdll_for_pidfd_open = ctypes.CDLL(None, use_errno=True) + _cdll_for_pidfd_open.syscall.restype = ctypes.c_long + # pid and flags are actually int-sized, but the syscall() function + # always takes longs. (Except on x32 where long is 32-bits and syscall + # takes 64-bit arguments. But in the unlikely case that anyone is + # using x32, this will still work, b/c we only need to pass in 32 bits + # of data, and the C ABI doesn't distinguish between passing 32-bit vs + # 64-bit integers; our 32-bit values will get loaded into 64-bit + # registers where syscall() will find them.) + _cdll_for_pidfd_open.syscall.argtypes = [ + ctypes.c_long, # syscall number + ctypes.c_long, # pid + ctypes.c_long, # flags + ] + __NR_pidfd_open = 434 + + def pidfd_open(fd: int, flags: int) -> int: + result = _cdll_for_pidfd_open.syscall(__NR_pidfd_open, fd, flags) + if result < 0: # pragma: no cover + err = ctypes.get_errno() + raise OSError(err, os.strerror(err)) + return result + + else: + can_try_pidfd_open = False + + +class HasFileno(Protocol): + """Represents any file-like object that has a file descriptor.""" + + def fileno(self) -> int: ... + + +@final +class Process(metaclass=NoPublicConstructor): + r"""A child process. 
Like :class:`subprocess.Popen`, but async. + + This class has no public constructor. The most common way to get a + `Process` object is to combine `Nursery.start` with `run_process`:: + + process_object = await nursery.start(run_process, ...) + + This way, `run_process` supervises the process and makes sure that it is + cleaned up properly, while optionally checking the return value, feeding + it input, and so on. + + If you need more control – for example, because you want to spawn a child + process that outlives your program – then another option is to use + `trio.lowlevel.open_process`:: + + process_object = await trio.lowlevel.open_process(...) + + Attributes: + args (str or list): The ``command`` passed at construction time, + specifying the process to execute and its arguments. + pid (int): The process ID of the child process managed by this object. + stdin (trio.abc.SendStream or None): A stream connected to the child's + standard input stream: when you write bytes here, they become available + for the child to read. Only available if the :class:`Process` + was constructed using ``stdin=PIPE``; otherwise this will be None. + stdout (trio.abc.ReceiveStream or None): A stream connected to + the child's standard output stream: when the child writes to + standard output, the written bytes become available for you + to read here. Only available if the :class:`Process` was + constructed using ``stdout=PIPE``; otherwise this will be None. + stderr (trio.abc.ReceiveStream or None): A stream connected to + the child's standard error stream: when the child writes to + standard error, the written bytes become available for you + to read here. Only available if the :class:`Process` was + constructed using ``stderr=PIPE``; otherwise this will be None. + stdio (trio.StapledStream or None): A stream that sends data to + the child's standard input and receives from the child's standard + output. 
Only available if both :attr:`stdin` and :attr:`stdout` are + available; otherwise this will be None. + + """ + + # We're always in binary mode. + universal_newlines: Final = False + encoding: Final = None + errors: Final = None + + # Available for the per-platform wait_child_exiting() implementations + # to stash some state; waitid platforms use this to avoid spawning + # arbitrarily many threads if wait() keeps getting cancelled. + _wait_for_exit_data: object = None + + def __init__( + self, + popen: subprocess.Popen[bytes], + stdin: SendStream | None, + stdout: ReceiveStream | None, + stderr: ReceiveStream | None, + ) -> None: + self._proc = popen + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr + + self.stdio: StapledStream[SendStream, ReceiveStream] | None = None + if self.stdin is not None and self.stdout is not None: + self.stdio = StapledStream(self.stdin, self.stdout) + + self._wait_lock: Lock = Lock() + + self._pidfd: TextIOWrapper | None = None + if can_try_pidfd_open: + try: + fd: int = pidfd_open(self._proc.pid, 0) + except OSError: # pragma: no cover + # Well, we tried, but it didn't work (probably because we're + # running on an older kernel, or in an older sandbox, that + # hasn't been updated to support pidfd_open). We'll fall back + # on waitid instead. + pass + else: + # It worked! Wrap the raw fd up in a Python file object to + # make sure it'll get closed. 
+ # SIM115: open-file-with-context-handler + self._pidfd = open(fd) # noqa: SIM115 + + self.args: StrOrBytesPath | Sequence[StrOrBytesPath] = self._proc.args + self.pid: int = self._proc.pid + + def __repr__(self) -> str: + returncode = self.returncode + if returncode is None: + status = f"running with PID {self.pid}" + else: + if returncode < 0: + status = f"exited with signal {-returncode}" + else: + status = f"exited with status {returncode}" + return f"<trio.Process {self.args!r}: {status}>" + + @property + def returncode(self) -> int | None: + """The exit status of the process (an integer), or ``None`` if it's + still running. + + By convention, a return code of zero indicates success. On + UNIX, negative values indicate termination due to a signal, + e.g., -11 if terminated by signal 11 (``SIGSEGV``). On + Windows, a process that exits due to a call to + :meth:`Process.terminate` will have an exit status of 1. + + Unlike the standard library `subprocess.Popen.returncode`, you don't + have to call `poll` or `wait` to update this attribute; it's + automatically updated as needed, and will always give you the latest + information. + + """ + result = self._proc.poll() + if result is not None: + self._close_pidfd() + return result + + def _close_pidfd(self) -> None: + if self._pidfd is not None: + trio.lowlevel.notify_closing(self._pidfd.fileno()) + self._pidfd.close() + self._pidfd = None + + async def wait(self) -> int: + """Block until the process exits. + + Returns: + The exit status of the process; see :attr:`returncode`. 
+ """ + async with self._wait_lock: + if self.poll() is None: + if self._pidfd is not None: + with contextlib.suppress( + ClosedResourceError, + ): # something else (probably a call to poll) already closed the pidfd + await trio.lowlevel.wait_readable(self._pidfd.fileno()) + else: + await wait_child_exiting(self) + # We have to use .wait() here, not .poll(), because on macOS + # (and maybe other systems, who knows), there's a race + # condition inside the kernel that creates a tiny window where + # kqueue reports that the process has exited, but + # waitpid(WNOHANG) can't yet reap it. So this .wait() may + # actually block for a tiny fraction of a second. + self._proc.wait() + self._close_pidfd() + assert self._proc.returncode is not None + return self._proc.returncode + + def poll(self) -> int | None: + """Returns the exit status of the process (an integer), or ``None`` if + it's still running. + + Note that on Trio (unlike the standard library `subprocess.Popen`), + ``process.poll()`` and ``process.returncode`` always give the same + result. See `returncode` for more details. This method is only + included to make it easier to port code from `subprocess`. + + """ + return self.returncode + + def send_signal(self, sig: signal.Signals | int) -> None: + """Send signal ``sig`` to the process. + + On UNIX, ``sig`` may be any signal defined in the + :mod:`signal` module, such as ``signal.SIGINT`` or + ``signal.SIGTERM``. On Windows, it may be anything accepted by + the standard library :meth:`subprocess.Popen.send_signal`. + """ + self._proc.send_signal(sig) + + def terminate(self) -> None: + """Terminate the process, politely if possible. + + On UNIX, this is equivalent to + ``send_signal(signal.SIGTERM)``; by convention this requests + graceful termination, but a misbehaving or buggy process might + ignore it. On Windows, :meth:`terminate` forcibly terminates the + process in the same manner as :meth:`kill`. 
+ """ + self._proc.terminate() + + def kill(self) -> None: + """Immediately terminate the process. + + On UNIX, this is equivalent to + ``send_signal(signal.SIGKILL)``. On Windows, it calls + ``TerminateProcess``. In both cases, the process cannot + prevent itself from being killed, but the termination will be + delivered asynchronously; use :meth:`wait` if you want to + ensure the process is actually dead before proceeding. + """ + self._proc.kill() + + +async def _open_process( + command: StrOrBytesPath | Sequence[StrOrBytesPath], + *, + stdin: int | HasFileno | None = None, + stdout: int | HasFileno | None = None, + stderr: int | HasFileno | None = None, + **options: object, +) -> Process: + r"""Execute a child program in a new process. + + After construction, you can interact with the child process by writing data to its + `~trio.Process.stdin` stream (a `~trio.abc.SendStream`), reading data from its + `~trio.Process.stdout` and/or `~trio.Process.stderr` streams (both + `~trio.abc.ReceiveStream`\s), sending it signals using `~trio.Process.terminate`, + `~trio.Process.kill`, or `~trio.Process.send_signal`, and waiting for it to exit + using `~trio.Process.wait`. See `trio.Process` for details. + + Each standard stream is only available if you specify that a pipe should be created + for it. For example, if you pass ``stdin=subprocess.PIPE``, you can write to the + `~trio.Process.stdin` stream, else `~trio.Process.stdin` will be ``None``. + + Unlike `trio.run_process`, this function doesn't do any kind of automatic + management of the child process. It's up to you to implement whatever semantics you + want. + + Args: + command: The command to run. Typically this is a sequence of strings or + bytes such as ``['ls', '-l', 'directory with spaces']``, where the + first element names the executable to invoke and the other elements + specify its arguments. 
With ``shell=True`` in the ``**options``, or on + Windows, ``command`` can be a string or bytes, which will be parsed + following platform-dependent :ref:`quoting rules + <subprocess-quoting>`. In all cases ``command`` can be a path or a + sequence of paths. + stdin: Specifies what the child process's standard input + stream should connect to: output written by the parent + (``subprocess.PIPE``), nothing (``subprocess.DEVNULL``), + or an open file (pass a file descriptor or something whose + ``fileno`` method returns one). If ``stdin`` is unspecified, + the child process will have the same standard input stream + as its parent. + stdout: Like ``stdin``, but for the child process's standard output + stream. + stderr: Like ``stdin``, but for the child process's standard error + stream. An additional value ``subprocess.STDOUT`` is supported, + which causes the child's standard output and standard error + messages to be intermixed on a single standard output stream, + attached to whatever the ``stdout`` option says to attach it to. + **options: Other :ref:`general subprocess options <subprocess-options>` + are also accepted. + + Returns: + A new `trio.Process` object. + + Raises: + OSError: if the process spawning fails, for example because the + specified command could not be found. + + """ + for key in ("universal_newlines", "text", "encoding", "errors", "bufsize"): + if options.get(key): + raise TypeError( + "trio.Process only supports communicating over " + f"unbuffered byte streams; the '{key}' option is not supported", + ) + + if os.name == "posix": + # TODO: how do paths and sequences thereof play with `shell=True`? 
+ if isinstance(command, (str, bytes)) and not options.get("shell"): + raise TypeError( + "command must be a sequence (not a string or bytes) if " + "shell=False on UNIX systems", + ) + if not isinstance(command, (str, bytes)) and options.get("shell"): + raise TypeError( + "command must be a string or bytes (not a sequence) if " + "shell=True on UNIX systems", + ) + + trio_stdin: ClosableSendStream | None = None + trio_stdout: ClosableReceiveStream | None = None + trio_stderr: ClosableReceiveStream | None = None + # Close the parent's handle for each child side of a pipe; we want the child to + # have the only copy, so that when it exits we can read EOF on our side. The + # trio ends of pipes will be transferred to the Process object, which will be + # responsible for their lifetime. If process spawning fails, though, we still + # want to close them before letting the failure bubble out + with ExitStack() as always_cleanup, ExitStack() as cleanup_on_fail: + if stdin == subprocess.PIPE: + trio_stdin, stdin = create_pipe_to_child_stdin() + always_cleanup.callback(os.close, stdin) + cleanup_on_fail.callback(trio_stdin.close) + if stdout == subprocess.PIPE: + trio_stdout, stdout = create_pipe_from_child_output() + always_cleanup.callback(os.close, stdout) + cleanup_on_fail.callback(trio_stdout.close) + if stderr == subprocess.STDOUT: + # If we created a pipe for stdout, pass the same pipe for + # stderr. If stdout was some non-pipe thing (DEVNULL or a + # given FD), pass the same thing. If stdout was passed as + # None, keep stderr as STDOUT to allow subprocess to dup + # our stdout. Regardless of which of these is applicable, + # don't create a new Trio stream for stderr -- if stdout + # is piped, stderr will be intermixed on the stdout stream. 
+ if stdout is not None: + stderr = stdout + elif stderr == subprocess.PIPE: + trio_stderr, stderr = create_pipe_from_child_output() + always_cleanup.callback(os.close, stderr) + cleanup_on_fail.callback(trio_stderr.close) + + popen = await trio.to_thread.run_sync( + partial( + subprocess.Popen, + command, + stdin=stdin, + stdout=stdout, + stderr=stderr, + **options, + ), + ) + # We did not fail, so dismiss the stack for the trio ends + cleanup_on_fail.pop_all() + + return Process._create(popen, trio_stdin, trio_stdout, trio_stderr) + + +# async function missing await +async def _windows_deliver_cancel(p: Process) -> None: # noqa: RUF029 + try: + p.terminate() + except OSError as exc: + warnings.warn( + RuntimeWarning(f"TerminateProcess on {p!r} failed with: {exc!r}"), + stacklevel=1, + ) + + +async def _posix_deliver_cancel(p: Process) -> None: + try: + p.terminate() + await trio.sleep(5) + warnings.warn( + RuntimeWarning( + f"process {p!r} ignored SIGTERM for 5 seconds. " + "(Maybe you should pass a custom deliver_cancel?) " + "Trying SIGKILL.", + ), + stacklevel=1, + ) + p.kill() + except OSError as exc: + warnings.warn( + RuntimeWarning(f"tried to kill process {p!r}, but failed with: {exc!r}"), + stacklevel=1, + ) + + +# Use a private name, so we can declare platform-specific stubs below. +# This is also the signature read by Sphinx +async def _run_process( + command: StrOrBytesPath | Sequence[StrOrBytesPath], + *, + stdin: bytes | bytearray | memoryview | int | HasFileno | None = b"", + capture_stdout: bool = False, + capture_stderr: bool = False, + check: bool = True, + deliver_cancel: Callable[[Process], Awaitable[object]] | None = None, + task_status: TaskStatus[Process] = trio.TASK_STATUS_IGNORED, + **options: object, +) -> subprocess.CompletedProcess[bytes]: + """Run ``command`` in a subprocess and wait for it to complete. + + This function can be called in two different ways. 
+ + One option is a direct call, like:: + + completed_process_info = await trio.run_process(...) + + In this case, it returns a :class:`subprocess.CompletedProcess` instance + describing the results. Use this if you want to treat a process like a + function call. + + The other option is to run it as a task using `Nursery.start` – the enhanced version + of `~Nursery.start_soon` that lets a task pass back a value during startup:: + + process = await nursery.start(trio.run_process, ...) + + In this case, `~Nursery.start` returns a `Process` object that you can use + to interact with the process while it's running. Use this if you want to + treat a process like a background task. + + Either way, `run_process` makes sure that the process has exited before + returning, handles cancellation, optionally checks for errors, and + provides some convenient shorthands for dealing with the child's + input/output. + + **Input:** `run_process` supports all the same ``stdin=`` arguments as + `subprocess.Popen`. In addition, if you simply want to pass in some fixed + data, you can pass a plain `bytes` object, and `run_process` will take + care of setting up a pipe, feeding in the data you gave, and then sending + end-of-file. The default is ``b""``, which means that the child will receive + an empty stdin. If you want the child to instead read from the parent's + stdin, use ``stdin=None``. + + **Output:** By default, any output produced by the subprocess is + passed through to the standard output and error streams of the + parent Trio process. + + When calling `run_process` directly, you can capture the subprocess's output by + passing ``capture_stdout=True`` to capture the subprocess's standard output, and/or + ``capture_stderr=True`` to capture its standard error. 
Captured data is collected up + by Trio into an in-memory buffer, and then provided as the + :attr:`~subprocess.CompletedProcess.stdout` and/or + :attr:`~subprocess.CompletedProcess.stderr` attributes of the returned + :class:`~subprocess.CompletedProcess` object. The value for any stream that was not + captured will be ``None``. + + If you want to capture both stdout and stderr while keeping them + separate, pass ``capture_stdout=True, capture_stderr=True``. + + If you want to capture both stdout and stderr but mixed together + in the order they were printed, use: ``capture_stdout=True, stderr=subprocess.STDOUT``. + This directs the child's stderr into its stdout, so the combined + output will be available in the `~subprocess.CompletedProcess.stdout` + attribute. + + If you're using ``await nursery.start(trio.run_process, ...)`` and want to capture + the subprocess's output for further processing, then use ``stdout=subprocess.PIPE`` + and then make sure to read the data out of the `Process.stdout` stream. If you want + to capture stderr separately, use ``stderr=subprocess.PIPE``. If you want to capture + both, but mixed together in the correct order, use ``stdout=subprocess.PIPE, + stderr=subprocess.STDOUT``. + + **Error checking:** If the subprocess exits with a nonzero status + code, indicating failure, :func:`run_process` raises a + :exc:`subprocess.CalledProcessError` exception rather than + returning normally. The captured outputs are still available as + the :attr:`~subprocess.CalledProcessError.stdout` and + :attr:`~subprocess.CalledProcessError.stderr` attributes of that + exception. To disable this behavior, so that :func:`run_process` + returns normally even if the subprocess exits abnormally, pass ``check=False``. 
+ + Note that this can make the ``capture_stdout`` and ``capture_stderr`` + arguments useful even when starting `run_process` as a task: if you only + care about the output if the process fails, then you can enable capturing + and then read the output off of the `~subprocess.CalledProcessError`. + + **Cancellation:** If cancelled, `run_process` sends a termination + request to the subprocess, then waits for it to fully exit. The + ``deliver_cancel`` argument lets you control how the process is terminated. + + .. note:: `run_process` is intentionally similar to the standard library + `subprocess.run`, but some of the defaults are different. Specifically, we + default to: + + - ``check=True``, because `"errors should never pass silently / unless + explicitly silenced" <https://www.python.org/dev/peps/pep-0020/>`__. + + - ``stdin=b""``, because it produces less-confusing results if a subprocess + unexpectedly tries to read from stdin. + + To get the `subprocess.run` semantics, use ``check=False, stdin=None``. + + Args: + command (list or str): The command to run. Typically this is a + sequence of strings such as ``['ls', '-l', 'directory with spaces']``, + where the first element names the executable to invoke and the other + elements specify its arguments. With ``shell=True`` in the + ``**options``, or on Windows, ``command`` may alternatively + be a string, which will be parsed following platform-dependent + :ref:`quoting rules <subprocess-quoting>`. + + stdin (:obj:`bytes`, subprocess.PIPE, file descriptor, or None): The + bytes to provide to the subprocess on its standard input stream, or + ``None`` if the subprocess's standard input should come from the + same place as the parent Trio process's standard input. As is the + case with the :mod:`subprocess` module, you can also pass a file + descriptor or an object with a ``fileno()`` method, in which case + the subprocess's standard input will come from that file. 
+ + When starting `run_process` as a background task, you can also use + ``stdin=subprocess.PIPE``, in which case `Process.stdin` will be a + `~trio.abc.SendStream` that you can use to send data to the child. + + capture_stdout (bool): If true, capture the bytes that the subprocess + writes to its standard output stream and return them in the + `~subprocess.CompletedProcess.stdout` attribute of the returned + `subprocess.CompletedProcess` or `subprocess.CalledProcessError`. + + capture_stderr (bool): If true, capture the bytes that the subprocess + writes to its standard error stream and return them in the + `~subprocess.CompletedProcess.stderr` attribute of the returned + `~subprocess.CompletedProcess` or `subprocess.CalledProcessError`. + + check (bool): If false, don't validate that the subprocess exits + successfully. You should be sure to check the + ``returncode`` attribute of the returned object if you pass + ``check=False``, so that errors don't pass silently. + + deliver_cancel (async function or None): If `run_process` is cancelled, + then it needs to kill the child process. There are multiple ways to + do this, so we let you customize it. + + If you pass None (the default), then the behavior depends on the + platform: + + - On Windows, Trio calls ``TerminateProcess``, which should kill the + process immediately. + + - On Unix-likes, the default behavior is to send a ``SIGTERM``, wait + 5 seconds, and send a ``SIGKILL``. + + Alternatively, you can customize this behavior by passing in an + arbitrary async function, which will be called with the `Process` + object as an argument. 
For example, the default Unix behavior could + be implemented like this:: + + async def my_deliver_cancel(process): + process.send_signal(signal.SIGTERM) + await trio.sleep(5) + process.send_signal(signal.SIGKILL) + + When the process actually exits, the ``deliver_cancel`` function + will automatically be cancelled – so if the process exits after + ``SIGTERM``, then we'll never reach the ``SIGKILL``. + + In any case, `run_process` will always wait for the child process to + exit before raising `Cancelled`. + + **options: :func:`run_process` also accepts any :ref:`general subprocess + options <subprocess-options>` and passes them on to the + :class:`~trio.Process` constructor. This includes the + ``stdout`` and ``stderr`` options, which provide additional + redirection possibilities such as ``stderr=subprocess.STDOUT``, + ``stdout=subprocess.DEVNULL``, or file descriptors. + + Returns: + + When called normally – a `subprocess.CompletedProcess` instance + describing the return code and outputs. + + When called via `Nursery.start` – a `trio.Process` instance. + + Raises: + UnicodeError: if ``stdin`` is specified as a Unicode string, rather + than bytes + ValueError: if multiple redirections are specified for the same + stream, e.g., both ``capture_stdout=True`` and + ``stdout=subprocess.DEVNULL`` + subprocess.CalledProcessError: if ``check=False`` is not passed + and the process exits with a nonzero exit status + OSError: if an error is encountered starting or communicating with + the process + ExceptionGroup: if exceptions occur in ``deliver_cancel``, + or when exceptions occur when communicating with the subprocess. + If strict_exception_groups is set to false in the global context, + which is deprecated, then single exceptions will be collapsed. + + .. note:: The child process runs in the same process group as the parent + Trio process, so a Ctrl+C will be delivered simultaneously to both + parent and child. 
If you don't want this behavior, consult your + platform's documentation for starting child processes in a different + process group. + + """ + + if isinstance(stdin, str): + raise UnicodeError("process stdin must be bytes, not str") + if task_status is trio.TASK_STATUS_IGNORED: + if stdin is subprocess.PIPE: + raise ValueError( + "stdin=subprocess.PIPE is only valid with nursery.start, " + "since that's the only way to access the pipe; use nursery.start " + "or pass the data you want to write directly", + ) + if options.get("stdout") is subprocess.PIPE: + raise ValueError( + "stdout=subprocess.PIPE is only valid with nursery.start, " + "since that's the only way to access the pipe", + ) + if options.get("stderr") is subprocess.PIPE: + raise ValueError( + "stderr=subprocess.PIPE is only valid with nursery.start, " + "since that's the only way to access the pipe", + ) + if isinstance(stdin, (bytes, bytearray, memoryview)): + input_ = stdin + options["stdin"] = subprocess.PIPE + else: + # stdin should be something acceptable to Process + # (None, DEVNULL, a file descriptor, etc) and Process + # will raise if it's not + input_ = None + options["stdin"] = stdin + + if capture_stdout: + if "stdout" in options: + raise ValueError("can't specify both stdout and capture_stdout") + options["stdout"] = subprocess.PIPE + if capture_stderr: + if "stderr" in options: + raise ValueError("can't specify both stderr and capture_stderr") + options["stderr"] = subprocess.PIPE + + if deliver_cancel is None: + if os.name == "nt": + deliver_cancel = _windows_deliver_cancel + else: + assert os.name == "posix" + deliver_cancel = _posix_deliver_cancel + + stdout_chunks: list[bytes | bytearray] = [] + stderr_chunks: list[bytes | bytearray] = [] + + async def feed_input(stream: SendStream) -> None: + async with stream: + try: + assert input_ is not None + await stream.send_all(input_) + except trio.BrokenResourceError: + pass + + async def read_output( + stream: ReceiveStream, + chunks: 
list[bytes | bytearray], + ) -> None: + async with stream: + async for chunk in stream: + chunks.append(chunk) # noqa: PERF401 + + # Opening the process does not need to be inside the nursery, so we put it outside + # so any exceptions get directly seen by users. + proc = await _open_process(command, **options) # type: ignore[arg-type] + async with trio.open_nursery() as nursery: + try: + if input_ is not None: + assert proc.stdin is not None + nursery.start_soon(feed_input, proc.stdin) + proc.stdin = None + proc.stdio = None + if capture_stdout: + assert proc.stdout is not None + nursery.start_soon(read_output, proc.stdout, stdout_chunks) + proc.stdout = None + proc.stdio = None + if capture_stderr: + assert proc.stderr is not None + nursery.start_soon(read_output, proc.stderr, stderr_chunks) + proc.stderr = None + task_status.started(proc) + await proc.wait() + except BaseException: + with trio.CancelScope(shield=True): + killer_cscope = trio.CancelScope(shield=True) + + async def killer() -> None: + with killer_cscope: + await deliver_cancel(proc) + + nursery.start_soon(killer) + await proc.wait() + killer_cscope.cancel(reason="trio internal implementation detail") + raise + + stdout = b"".join(stdout_chunks) if capture_stdout else None + stderr = b"".join(stderr_chunks) if capture_stderr else None + + if proc.returncode and check: + raise subprocess.CalledProcessError( + proc.returncode, + proc.args, + output=stdout, + stderr=stderr, + ) + else: + assert proc.returncode is not None + return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr) + + +# There's a lot of duplication here because type checkers don't +# have a good way to represent overloads that differ only +# slightly. 
A cheat sheet: +# +# - on Windows, command is Union[str, Sequence[str]]; +# on Unix, command is str if shell=True and Sequence[str] otherwise +# +# - on Windows, there are startupinfo and creationflags options; +# on Unix, there are preexec_fn, restore_signals, start_new_session, +# pass_fds, group (3.9+), extra_groups (3.9+), user (3.9+), +# umask (3.9+), pipesize (3.10+), process_group (3.11+) +# +# - run_process() has the signature of open_process() plus arguments +# capture_stdout, capture_stderr, check, deliver_cancel, the ability +# to pass bytes as stdin, and the ability to run in `nursery.start` + + +class GeneralProcessArgs(TypedDict, total=False): + """Arguments shared between all runs.""" + + stdout: int | HasFileno | None + stderr: int | HasFileno | None + close_fds: bool + cwd: StrOrBytesPath | None + env: Mapping[str, str] | None + executable: StrOrBytesPath | None + + +if TYPE_CHECKING: + if sys.platform == "win32": + + class WindowsProcessArgs(GeneralProcessArgs, total=False): + """Arguments shared between all Windows runs.""" + + shell: bool + startupinfo: subprocess.STARTUPINFO | None + creationflags: int + + async def open_process( + command: StrOrBytesPath | Sequence[StrOrBytesPath], + *, + stdin: int | HasFileno | None = None, + **kwargs: Unpack[WindowsProcessArgs], + ) -> trio.Process: + r"""Execute a child program in a new process. + + After construction, you can interact with the child process by writing data to its + `~trio.Process.stdin` stream (a `~trio.abc.SendStream`), reading data from its + `~trio.Process.stdout` and/or `~trio.Process.stderr` streams (both + `~trio.abc.ReceiveStream`\s), sending it signals using `~trio.Process.terminate`, + `~trio.Process.kill`, or `~trio.Process.send_signal`, and waiting for it to exit + using `~trio.Process.wait`. See `trio.Process` for details. + + Each standard stream is only available if you specify that a pipe should be created + for it. 
For example, if you pass ``stdin=subprocess.PIPE``, you can write to the + `~trio.Process.stdin` stream, else `~trio.Process.stdin` will be ``None``. + + Unlike `trio.run_process`, this function doesn't do any kind of automatic + management of the child process. It's up to you to implement whatever semantics you + want. + + Args: + command (list or str): The command to run. Typically this is a + sequence of strings such as ``['ls', '-l', 'directory with spaces']``, + where the first element names the executable to invoke and the other + elements specify its arguments. With ``shell=True`` in the + ``**options``, or on Windows, ``command`` may alternatively + be a string, which will be parsed following platform-dependent + :ref:`quoting rules <subprocess-quoting>`. + stdin: Specifies what the child process's standard input + stream should connect to: output written by the parent + (``subprocess.PIPE``), nothing (``subprocess.DEVNULL``), + or an open file (pass a file descriptor or something whose + ``fileno`` method returns one). If ``stdin`` is unspecified, + the child process will have the same standard input stream + as its parent. + stdout: Like ``stdin``, but for the child process's standard output + stream. + stderr: Like ``stdin``, but for the child process's standard error + stream. An additional value ``subprocess.STDOUT`` is supported, + which causes the child's standard output and standard error + messages to be intermixed on a single standard output stream, + attached to whatever the ``stdout`` option says to attach it to. + **options: Other :ref:`general subprocess options <subprocess-options>` + are also accepted. + + Returns: + A new `trio.Process` object. + + Raises: + OSError: if the process spawning fails, for example because the + specified command could not be found. + + """ + ... 
+ + async def run_process( + command: StrOrBytesPath | Sequence[StrOrBytesPath], + *, + task_status: TaskStatus[Process] = trio.TASK_STATUS_IGNORED, + stdin: bytes | bytearray | memoryview | int | HasFileno | None = None, + capture_stdout: bool = False, + capture_stderr: bool = False, + check: bool = True, + deliver_cancel: Callable[[Process], Awaitable[object]] | None = None, + **kwargs: Unpack[WindowsProcessArgs], + ) -> subprocess.CompletedProcess[bytes]: + """Run ``command`` in a subprocess and wait for it to complete. + + This function can be called in two different ways. + + One option is a direct call, like:: + + completed_process_info = await trio.run_process(...) + + In this case, it returns a :class:`subprocess.CompletedProcess` instance + describing the results. Use this if you want to treat a process like a + function call. + + The other option is to run it as a task using `Nursery.start` – the enhanced version + of `~Nursery.start_soon` that lets a task pass back a value during startup:: + + process = await nursery.start(trio.run_process, ...) + + In this case, `~Nursery.start` returns a `Process` object that you can use + to interact with the process while it's running. Use this if you want to + treat a process like a background task. + + Either way, `run_process` makes sure that the process has exited before + returning, handles cancellation, optionally checks for errors, and + provides some convenient shorthands for dealing with the child's + input/output. + + **Input:** `run_process` supports all the same ``stdin=`` arguments as + `subprocess.Popen`. In addition, if you simply want to pass in some fixed + data, you can pass a plain `bytes` object, and `run_process` will take + care of setting up a pipe, feeding in the data you gave, and then sending + end-of-file. The default is ``b""``, which means that the child will receive + an empty stdin. If you want the child to instead read from the parent's + stdin, use ``stdin=None``. 
+ + **Output:** By default, any output produced by the subprocess is + passed through to the standard output and error streams of the + parent Trio process. + + When calling `run_process` directly, you can capture the subprocess's output by + passing ``capture_stdout=True`` to capture the subprocess's standard output, and/or + ``capture_stderr=True`` to capture its standard error. Captured data is collected up + by Trio into an in-memory buffer, and then provided as the + :attr:`~subprocess.CompletedProcess.stdout` and/or + :attr:`~subprocess.CompletedProcess.stderr` attributes of the returned + :class:`~subprocess.CompletedProcess` object. The value for any stream that was not + captured will be ``None``. + + If you want to capture both stdout and stderr while keeping them + separate, pass ``capture_stdout=True, capture_stderr=True``. + + If you want to capture both stdout and stderr but mixed together + in the order they were printed, use: ``capture_stdout=True, stderr=subprocess.STDOUT``. + This directs the child's stderr into its stdout, so the combined + output will be available in the `~subprocess.CompletedProcess.stdout` + attribute. + + If you're using ``await nursery.start(trio.run_process, ...)`` and want to capture + the subprocess's output for further processing, then use ``stdout=subprocess.PIPE`` + and then make sure to read the data out of the `Process.stdout` stream. If you want + to capture stderr separately, use ``stderr=subprocess.PIPE``. If you want to capture + both, but mixed together in the correct order, use ``stdout=subprocess.PIPE, + stderr=subprocess.STDOUT``. + + **Error checking:** If the subprocess exits with a nonzero status + code, indicating failure, :func:`run_process` raises a + :exc:`subprocess.CalledProcessError` exception rather than + returning normally. 
The captured outputs are still available as + the :attr:`~subprocess.CalledProcessError.stdout` and + :attr:`~subprocess.CalledProcessError.stderr` attributes of that + exception. To disable this behavior, so that :func:`run_process` + returns normally even if the subprocess exits abnormally, pass ``check=False``. + + Note that this can make the ``capture_stdout`` and ``capture_stderr`` + arguments useful even when starting `run_process` as a task: if you only + care about the output if the process fails, then you can enable capturing + and then read the output off of the `~subprocess.CalledProcessError`. + + **Cancellation:** If cancelled, `run_process` sends a termination + request to the subprocess, then waits for it to fully exit. The + ``deliver_cancel`` argument lets you control how the process is terminated. + + .. note:: `run_process` is intentionally similar to the standard library + `subprocess.run`, but some of the defaults are different. Specifically, we + default to: + + - ``check=True``, because `"errors should never pass silently / unless + explicitly silenced" <https://www.python.org/dev/peps/pep-0020/>`__. + + - ``stdin=b""``, because it produces less-confusing results if a subprocess + unexpectedly tries to read from stdin. + + To get the `subprocess.run` semantics, use ``check=False, stdin=None``. + + Args: + command (list or str): The command to run. Typically this is a + sequence of strings such as ``['ls', '-l', 'directory with spaces']``, + where the first element names the executable to invoke and the other + elements specify its arguments. With ``shell=True`` in the + ``**options``, or on Windows, ``command`` may alternatively + be a string, which will be parsed following platform-dependent + :ref:`quoting rules <subprocess-quoting>`. 
+ + stdin (:obj:`bytes`, subprocess.PIPE, file descriptor, or None): The + bytes to provide to the subprocess on its standard input stream, or + ``None`` if the subprocess's standard input should come from the + same place as the parent Trio process's standard input. As is the + case with the :mod:`subprocess` module, you can also pass a file + descriptor or an object with a ``fileno()`` method, in which case + the subprocess's standard input will come from that file. + + When starting `run_process` as a background task, you can also use + ``stdin=subprocess.PIPE``, in which case `Process.stdin` will be a + `~trio.abc.SendStream` that you can use to send data to the child. + + capture_stdout (bool): If true, capture the bytes that the subprocess + writes to its standard output stream and return them in the + `~subprocess.CompletedProcess.stdout` attribute of the returned + `subprocess.CompletedProcess` or `subprocess.CalledProcessError`. + + capture_stderr (bool): If true, capture the bytes that the subprocess + writes to its standard error stream and return them in the + `~subprocess.CompletedProcess.stderr` attribute of the returned + `~subprocess.CompletedProcess` or `subprocess.CalledProcessError`. + + check (bool): If false, don't validate that the subprocess exits + successfully. You should be sure to check the + ``returncode`` attribute of the returned object if you pass + ``check=False``, so that errors don't pass silently. + + deliver_cancel (async function or None): If `run_process` is cancelled, + then it needs to kill the child process. There are multiple ways to + do this, so we let you customize it. + + If you pass None (the default), then the behavior depends on the + platform: + + - On Windows, Trio calls ``TerminateProcess``, which should kill the + process immediately. + + - On Unix-likes, the default behavior is to send a ``SIGTERM``, wait + 5 seconds, and send a ``SIGKILL``. 
+ + Alternatively, you can customize this behavior by passing in an + arbitrary async function, which will be called with the `Process` + object as an argument. For example, the default Unix behavior could + be implemented like this:: + + async def my_deliver_cancel(process): + process.send_signal(signal.SIGTERM) + await trio.sleep(5) + process.send_signal(signal.SIGKILL) + + When the process actually exits, the ``deliver_cancel`` function + will automatically be cancelled – so if the process exits after + ``SIGTERM``, then we'll never reach the ``SIGKILL``. + + In any case, `run_process` will always wait for the child process to + exit before raising `Cancelled`. + + **options: :func:`run_process` also accepts any :ref:`general subprocess + options <subprocess-options>` and passes them on to the + :class:`~trio.Process` constructor. This includes the + ``stdout`` and ``stderr`` options, which provide additional + redirection possibilities such as ``stderr=subprocess.STDOUT``, + ``stdout=subprocess.DEVNULL``, or file descriptors. + + Returns: + + When called normally – a `subprocess.CompletedProcess` instance + describing the return code and outputs. + + When called via `Nursery.start` – a `trio.Process` instance. + + Raises: + UnicodeError: if ``stdin`` is specified as a Unicode string, rather + than bytes + ValueError: if multiple redirections are specified for the same + stream, e.g., both ``capture_stdout=True`` and + ``stdout=subprocess.DEVNULL`` + subprocess.CalledProcessError: if ``check=False`` is not passed + and the process exits with a nonzero exit status + OSError: if an error is encountered starting or communicating with + the process + + .. note:: The child process runs in the same process group as the parent + Trio process, so a Ctrl+C will be delivered simultaneously to both + parent and child. If you don't want this behavior, consult your + platform's documentation for starting child processes in a different + process group. + + """ + ... 
+ + else: # Unix + # pyright doesn't give any error about overloads missing docstrings as they're + # overloads. But might still be a problem for other static analyzers / docstring + # readers (?) + + class UnixProcessArgs3_9(GeneralProcessArgs, total=False): + """Arguments shared between all Unix runs.""" + + preexec_fn: Callable[[], object] | None + restore_signals: bool + start_new_session: bool + pass_fds: Sequence[int] + + # 3.9+ + group: str | int | None + extra_groups: Iterable[str | int] | None + user: str | int | None + umask: int + + class UnixProcessArgs3_10(UnixProcessArgs3_9, total=False): + """Arguments shared between all Unix runs on 3.10+.""" + + pipesize: int + + class UnixProcessArgs3_11(UnixProcessArgs3_10, total=False): + """Arguments shared between all Unix runs on 3.11+.""" + + process_group: int | None + + class UnixRunProcessMixin(TypedDict, total=False): + """Arguments unique to run_process on Unix.""" + + task_status: TaskStatus[Process] + capture_stdout: bool + capture_stderr: bool + check: bool + deliver_cancel: Callable[[Process], Awaitable[None]] | None + + # TODO: once https://github.com/python/mypy/issues/18692 is + # fixed, move the `UnixRunProcessArgs` definition down. 
+ if sys.version_info >= (3, 11): + UnixProcessArgs = UnixProcessArgs3_11 + + class UnixRunProcessArgs(UnixProcessArgs3_11, UnixRunProcessMixin): + """Arguments for run_process on Unix with 3.11+""" + + elif sys.version_info >= (3, 10): + UnixProcessArgs = UnixProcessArgs3_10 + + class UnixRunProcessArgs(UnixProcessArgs3_10, UnixRunProcessMixin): + """Arguments for run_process on Unix with 3.10+""" + + else: + UnixProcessArgs = UnixProcessArgs3_9 + + class UnixRunProcessArgs(UnixProcessArgs3_9, UnixRunProcessMixin): + """Arguments for run_process on Unix with 3.9+""" + + @overload # type: ignore[no-overload-impl] + async def open_process( + command: StrOrBytesPath, + *, + stdin: int | HasFileno | None = None, + shell: Literal[True], + **kwargs: Unpack[UnixProcessArgs], + ) -> trio.Process: ... + + @overload + async def open_process( + command: Sequence[StrOrBytesPath], + *, + stdin: int | HasFileno | None = None, + shell: bool = False, + **kwargs: Unpack[UnixProcessArgs], + ) -> trio.Process: ... + + @overload # type: ignore[no-overload-impl] + async def run_process( + command: StrOrBytesPath, + *, + stdin: bytes | bytearray | memoryview | int | HasFileno | None = b"", + shell: Literal[True], + **kwargs: Unpack[UnixRunProcessArgs], + ) -> subprocess.CompletedProcess[bytes]: ... + + @overload + async def run_process( + command: Sequence[StrOrBytesPath], + *, + stdin: bytes | bytearray | memoryview | int | HasFileno | None = b"", + shell: bool = False, + **kwargs: Unpack[UnixRunProcessArgs], + ) -> subprocess.CompletedProcess[bytes]: ... + +else: + # At runtime, use the actual implementations. 
+ open_process = _open_process + open_process.__name__ = open_process.__qualname__ = "open_process" + + run_process = _run_process + run_process.__name__ = run_process.__qualname__ = "run_process" diff --git a/contrib/python/trio/trio/_subprocess_platform/__init__.py b/contrib/python/trio/trio/_subprocess_platform/__init__.py new file mode 100644 index 00000000000..daa28d8cd2d --- /dev/null +++ b/contrib/python/trio/trio/_subprocess_platform/__init__.py @@ -0,0 +1,123 @@ +# Platform-specific subprocess bits'n'pieces. +from __future__ import annotations + +import os +import sys +from typing import TYPE_CHECKING + +import trio + +from .. import _core, _subprocess +from .._abc import ReceiveStream, SendStream # noqa: TC001 + +_wait_child_exiting_error: ImportError | None = None +_create_child_pipe_error: ImportError | None = None + + +if TYPE_CHECKING: + # internal types for the pipe representations used in type checking only + class ClosableSendStream(SendStream): + def close(self) -> None: ... + + class ClosableReceiveStream(ReceiveStream): + def close(self) -> None: ... + + +# Fallback versions of the functions provided -- implementations +# per OS are imported atop these at the bottom of the module. +async def wait_child_exiting(process: _subprocess.Process) -> None: + """Block until the child process managed by ``process`` is exiting. + + It is invalid to call this function if the process has already + been waited on; that is, ``process.returncode`` must be None. + + When this function returns, it indicates that a call to + :meth:`subprocess.Popen.wait` will immediately be able to + return the process's exit status. The actual exit status is not + consumed by this call, since :class:`~subprocess.Popen` wants + to be able to do that itself. 
+ """ + raise NotImplementedError from _wait_child_exiting_error # pragma: no cover + + +def create_pipe_to_child_stdin() -> tuple[ClosableSendStream, int]: + """Create a new pipe suitable for sending data from this + process to the standard input of a child we're about to spawn. + + Returns: + A pair ``(trio_end, subprocess_end)`` where ``trio_end`` is a + :class:`~trio.abc.SendStream` and ``subprocess_end`` is + something suitable for passing as the ``stdin`` argument of + :class:`subprocess.Popen`. + """ + raise NotImplementedError from _create_child_pipe_error # pragma: no cover + + +def create_pipe_from_child_output() -> tuple[ClosableReceiveStream, int]: + """Create a new pipe suitable for receiving data into this + process from the standard output or error stream of a child + we're about to spawn. + + Returns: + A pair ``(trio_end, subprocess_end)`` where ``trio_end`` is a + :class:`~trio.abc.ReceiveStream` and ``subprocess_end`` is + something suitable for passing as the ``stdout`` or ``stderr`` argument of + :class:`subprocess.Popen`. 
+ """ + raise NotImplementedError from _create_child_pipe_error # pragma: no cover + + +try: + if sys.platform == "win32": + from .windows import wait_child_exiting + elif sys.platform != "linux" and (TYPE_CHECKING or hasattr(_core, "wait_kevent")): + from .kqueue import wait_child_exiting + else: + # as it's an exported symbol, noqa'd + from .waitid import wait_child_exiting # noqa: F401 +except ImportError as ex: # pragma: no cover + _wait_child_exiting_error = ex + +try: + if TYPE_CHECKING: + # Not worth type checking these definitions + pass + + elif os.name == "posix": + + def create_pipe_to_child_stdin() -> tuple[trio.lowlevel.FdStream, int]: + rfd, wfd = os.pipe() + return trio.lowlevel.FdStream(wfd), rfd + + def create_pipe_from_child_output() -> tuple[trio.lowlevel.FdStream, int]: + rfd, wfd = os.pipe() + return trio.lowlevel.FdStream(rfd), wfd + + elif os.name == "nt": + import msvcrt + + # This isn't exported or documented, but it's also not + # underscore-prefixed, and seems kosher to use. The asyncio docs + # for 3.5 included an example that imported socketpair from + # windows_utils (before socket.socketpair existed on Windows), and + # when asyncio.windows_utils.socketpair was removed in 3.7, the + # removal was mentioned in the release notes. 
+ from asyncio.windows_utils import pipe as windows_pipe + + from .._windows_pipes import PipeReceiveStream, PipeSendStream + + def create_pipe_to_child_stdin() -> tuple[PipeSendStream, int]: + # for stdin, we want the write end (our end) to use overlapped I/O + rh, wh = windows_pipe(overlapped=(False, True)) + return PipeSendStream(wh), msvcrt.open_osfhandle(rh, os.O_RDONLY) + + def create_pipe_from_child_output() -> tuple[PipeReceiveStream, int]: + # for stdout/err, it's the read end that's overlapped + rh, wh = windows_pipe(overlapped=(True, False)) + return PipeReceiveStream(rh), msvcrt.open_osfhandle(wh, 0) + + else: # pragma: no cover + raise ImportError("pipes not implemented on this platform") + +except ImportError as ex: # pragma: no cover + _create_child_pipe_error = ex diff --git a/contrib/python/trio/trio/_subprocess_platform/kqueue.py b/contrib/python/trio/trio/_subprocess_platform/kqueue.py new file mode 100644 index 00000000000..2283bb5360c --- /dev/null +++ b/contrib/python/trio/trio/_subprocess_platform/kqueue.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import select +import sys +from typing import TYPE_CHECKING + +from .. 
import _core, _subprocess + +assert (sys.platform != "win32" and sys.platform != "linux") or not TYPE_CHECKING + + +async def wait_child_exiting(process: _subprocess.Process) -> None: + kqueue = _core.current_kqueue() + try: + from select import KQ_NOTE_EXIT + except ImportError: # pragma: no cover + # pypy doesn't define KQ_NOTE_EXIT: + # https://bitbucket.org/pypy/pypy/issues/2921/ + # I verified this value against both Darwin and FreeBSD + KQ_NOTE_EXIT = 0x80000000 + + def make_event(flags: int) -> select.kevent: + return select.kevent( + process.pid, + filter=select.KQ_FILTER_PROC, + flags=flags, + fflags=KQ_NOTE_EXIT, + ) + + try: + kqueue.control([make_event(select.KQ_EV_ADD | select.KQ_EV_ONESHOT)], 0) + except ProcessLookupError: # pragma: no cover + # This can supposedly happen if the process is in the process + # of exiting, and it can even be the case that kqueue says the + # process doesn't exist before waitpid(WNOHANG) says it hasn't + # exited yet. See the discussion in https://chromium.googlesource.com/ + # chromium/src/base/+/master/process/kill_mac.cc . + # We haven't actually seen this error occur since we added + # locking to prevent multiple calls to wait_child_exiting() + # for the same process simultaneously, but given the explanation + # in Chromium it seems we should still keep the check. + return + + def abort(_: _core.RaiseCancelT) -> _core.Abort: + kqueue.control([make_event(select.KQ_EV_DELETE)], 0) + return _core.Abort.SUCCEEDED + + await _core.wait_kevent(process.pid, select.KQ_FILTER_PROC, abort) diff --git a/contrib/python/trio/trio/_subprocess_platform/waitid.py b/contrib/python/trio/trio/_subprocess_platform/waitid.py new file mode 100644 index 00000000000..ebf83b48028 --- /dev/null +++ b/contrib/python/trio/trio/_subprocess_platform/waitid.py @@ -0,0 +1,113 @@ +import errno +import math +import os +import sys +from typing import TYPE_CHECKING + +from .. 
import _core, _subprocess +from .._sync import CapacityLimiter, Event +from .._threads import to_thread_run_sync + +assert (sys.platform != "win32" and sys.platform != "darwin") or not TYPE_CHECKING + +try: + from os import waitid + + def sync_wait_reapable(pid: int) -> None: + waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT) + +except ImportError: + # pypy doesn't define os.waitid so we need to pull it out ourselves + # using cffi: https://bitbucket.org/pypy/pypy/issues/2922/ + import cffi + + waitid_ffi = cffi.FFI() + + # Believe it or not, siginfo_t starts with fields in the + # same layout on both Linux and Darwin. The Linux structure + # is bigger so that's what we use to size `pad`; while + # there are a few extra fields in there, most of it is + # true padding which would not be written by the syscall. + waitid_ffi.cdef( + """ +typedef struct siginfo_s { + int si_signo; + int si_errno; + int si_code; + int si_pid; + int si_uid; + int si_status; + int pad[26]; +} siginfo_t; +int waitid(int idtype, int id, siginfo_t* result, int options); +""", + ) + waitid_cffi = waitid_ffi.dlopen(None).waitid # type: ignore[attr-defined] + + def sync_wait_reapable(pid: int) -> None: + P_PID = 1 + WEXITED = 0x00000004 + if sys.platform == "darwin": # pragma: no cover + # waitid() is not exposed on Python on Darwin but does + # work through CFFI; note that we typically won't get + # here since Darwin also defines kqueue + WNOWAIT = 0x00000020 + else: + WNOWAIT = 0x01000000 + result = waitid_ffi.new("siginfo_t *") + while waitid_cffi(P_PID, pid, result, WEXITED | WNOWAIT) < 0: + got_errno = waitid_ffi.errno + if got_errno == errno.EINTR: + continue + raise OSError(got_errno, os.strerror(got_errno)) + + +# adapted from +# https://github.com/python-trio/trio/issues/4#issuecomment-398967572 + +waitid_limiter = CapacityLimiter(math.inf) + + +async def _waitid_system_task(pid: int, event: Event) -> None: + """Spawn a thread that waits for ``pid`` to exit, then wake any tasks + that 
were waiting on it. + """ + # abandon_on_cancel=True: if this task is cancelled, then we abandon the + # thread to keep running waitpid in the background. Since this is + # always run as a system task, this will only happen if the whole + # call to trio.run is shutting down. + + try: + await to_thread_run_sync( + sync_wait_reapable, + pid, + abandon_on_cancel=True, + limiter=waitid_limiter, + ) + except OSError: + # If waitid fails, waitpid will fail too, so it still makes + # sense to wake up the callers of wait_process_exiting(). The + # most likely reason for this error in practice is a child + # exiting when wait() is not possible because SIGCHLD is + # ignored. + pass + finally: + event.set() + + +async def wait_child_exiting(process: "_subprocess.Process") -> None: + # Logic of this function: + # - The first time we get called, we create an Event and start + # an instance of _waitid_system_task that will set the Event + # when waitid() completes. If that Event is set before + # we get cancelled, we're good. + # - Otherwise, a following call after the cancellation must + # reuse the Event created during the first call, lest we + # create an arbitrary number of threads waiting on the same + # process. + + if process._wait_for_exit_data is None: + process._wait_for_exit_data = event = Event() + _core.spawn_system_task(_waitid_system_task, process.pid, event) + assert isinstance(process._wait_for_exit_data, Event) + await process._wait_for_exit_data.wait() diff --git a/contrib/python/trio/trio/_subprocess_platform/windows.py b/contrib/python/trio/trio/_subprocess_platform/windows.py new file mode 100644 index 00000000000..81fb960e4bb --- /dev/null +++ b/contrib/python/trio/trio/_subprocess_platform/windows.py @@ -0,0 +1,11 @@ +from typing import TYPE_CHECKING + +from .._wait_for_object import WaitForSingleObject + +if TYPE_CHECKING: + from .. 
import _subprocess + + +async def wait_child_exiting(process: "_subprocess.Process") -> None: + # _handle is not in Popen stubs, though it is present on Windows. + await WaitForSingleObject(int(process._proc._handle)) # type: ignore[attr-defined] diff --git a/contrib/python/trio/trio/_sync.py b/contrib/python/trio/trio/_sync.py new file mode 100644 index 00000000000..d026f4bc37c --- /dev/null +++ b/contrib/python/trio/trio/_sync.py @@ -0,0 +1,908 @@ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Literal, Protocol, TypeVar + +import attrs + +import trio + +from . import _core +from ._core import ( + Abort, + ParkingLot, + RaiseCancelT, + add_parking_lot_breaker, + enable_ki_protection, + remove_parking_lot_breaker, +) +from ._deprecate import warn_deprecated +from ._util import final + +if TYPE_CHECKING: + from collections.abc import Callable + from types import TracebackType + + from typing_extensions import deprecated + + from ._core import Task + from ._core._parking_lot import ParkingLotStatistics +else: + T = TypeVar("T") + + def deprecated( + message: str, + /, + *, + category: type[Warning] | None = DeprecationWarning, + stacklevel: int = 1, + ) -> Callable[[T], T]: + def wrapper(f: T) -> T: + return f + + return wrapper + + +class EventStatistics: + """An object containing debugging information. + + Currently the following fields are defined: + + * ``tasks_waiting``: The number of tasks blocked on this event's + :meth:`trio.Event.wait` method. + + """ + + tasks_waiting: int + + +@final [email protected](repr=False, eq=False) +class Event: + """A waitable boolean value useful for inter-task synchronization, + inspired by :class:`threading.Event`. + + An event object has an internal boolean flag, representing whether + the event has happened yet. The flag is initially False, and the + :meth:`wait` method waits until the flag is True. If the flag is + already True, then :meth:`wait` returns immediately. 
(If the event has + already happened, there's nothing to wait for.) The :meth:`set` method + sets the flag to True, and wakes up any waiters. + + This behavior is useful because it helps avoid race conditions and + lost wakeups: it doesn't matter whether :meth:`set` gets called just + before or after :meth:`wait`. If you want a lower-level wakeup + primitive that doesn't have this protection, consider :class:`Condition` + or :class:`trio.lowlevel.ParkingLot`. + + .. note:: Unlike `threading.Event`, `trio.Event` has no + `~threading.Event.clear` method. In Trio, once an `Event` has happened, + it cannot un-happen. If you need to represent a series of events, + consider creating a new `Event` object for each one (they're cheap!), + or other synchronization methods like :ref:`channels <channels>` or + `trio.lowlevel.ParkingLot`. + + """ + + _tasks: set[Task] = attrs.field(factory=set, init=False) + _flag: bool = attrs.field(default=False, init=False) + + def is_set(self) -> bool: + """Return the current value of the internal flag.""" + return self._flag + + @enable_ki_protection + def set(self) -> None: + """Set the internal flag value to True, and wake any waiting tasks.""" + if not self._flag: + self._flag = True + for task in self._tasks: + _core.reschedule(task) + self._tasks.clear() + + async def wait(self) -> None: + """Block until the internal flag value becomes True. + + If it's already True, then this method returns immediately. + + """ + if self._flag: + await trio.lowlevel.checkpoint() + else: + task = _core.current_task() + self._tasks.add(task) + + def abort_fn(_: RaiseCancelT) -> Abort: + self._tasks.remove(task) + return _core.Abort.SUCCEEDED + + await _core.wait_task_rescheduled(abort_fn) + + def statistics(self) -> EventStatistics: + """Return an object containing debugging information. + + Currently the following fields are defined: + + * ``tasks_waiting``: The number of tasks blocked on this event's + :meth:`wait` method. 
+ + """ + return EventStatistics(tasks_waiting=len(self._tasks)) + + @deprecated( + "trio.Event.__bool__ is deprecated since Trio 0.31.0; use trio.Event.is_set instead (https://github.com/python-trio/trio/issues/3238)", + stacklevel=2, + ) + def __bool__(self) -> Literal[True]: + """Return True and raise warning.""" + warn_deprecated( + self.__bool__, + "0.31.0", + issue=3238, + instead=self.is_set, + ) + return True + + +class _HasAcquireRelease(Protocol): + """Only classes with acquire() and release() can use the mixin's implementations.""" + + async def acquire(self) -> object: ... + + def release(self) -> object: ... + + +class AsyncContextManagerMixin: + @enable_ki_protection + async def __aenter__(self: _HasAcquireRelease) -> None: + await self.acquire() + + @enable_ki_protection + async def __aexit__( + self: _HasAcquireRelease, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.release() + + +class CapacityLimiterStatistics: + """An object containing debugging information. + + Currently the following fields are defined: + + * ``borrowed_tokens``: The number of tokens currently borrowed from + the sack. + * ``total_tokens``: The total number of tokens in the sack. Usually + this will be larger than ``borrowed_tokens``, but it's possibly for + it to be smaller if :attr:`trio.CapacityLimiter.total_tokens` was recently decreased. + * ``borrowers``: A list of all tasks or other entities that currently + hold a token. + * ``tasks_waiting``: The number of tasks blocked on this + :class:`CapacityLimiter`\'s :meth:`trio.CapacityLimiter.acquire` or + :meth:`trio.CapacityLimiter.acquire_on_behalf_of` methods. + + """ + + borrowed_tokens: int + total_tokens: int | float + borrowers: list[Task | object] + tasks_waiting: int + + +# Can be a generic type with a default of Task if/when PEP 696 is released +# and implemented in type checkers. 
Making it fully generic would currently +# introduce a lot of unnecessary hassle. +@final +class CapacityLimiter(AsyncContextManagerMixin): + """An object for controlling access to a resource with limited capacity. + + Sometimes you need to put a limit on how many tasks can do something at + the same time. For example, you might want to use some threads to run + multiple blocking I/O operations in parallel... but if you use too many + threads at once, then your system can become overloaded and it'll actually + make things slower. One popular solution is to impose a policy like "run + up to 40 threads at the same time, but no more". But how do you implement + a policy like this? + + That's what :class:`CapacityLimiter` is for. You can think of a + :class:`CapacityLimiter` object as a sack that starts out holding some fixed + number of tokens:: + + limit = trio.CapacityLimiter(40) + + Then tasks can come along and borrow a token out of the sack:: + + # Borrow a token: + async with limit: + # We are holding a token! + await perform_expensive_operation() + # Exiting the 'async with' block puts the token back into the sack + + And crucially, if you try to borrow a token but the sack is empty, then + you have to wait for another task to finish what it's doing and put its + token back first before you can take it and continue. + + Another way to think of it: a :class:`CapacityLimiter` is like a sofa with a + fixed number of seats, and if they're all taken then you have to wait for + someone to get up before you can sit down. + + By default, :func:`trio.to_thread.run_sync` uses a + :class:`CapacityLimiter` to limit the number of threads running at once; + see `trio.to_thread.current_default_thread_limiter` for details. + + If you're familiar with semaphores, then you can think of this as a + restricted semaphore that's specialized for one common use case, with + additional error checking. For a more traditional semaphore, see + :class:`Semaphore`. + + .. 
note:: + + Don't confuse this with the `"leaky bucket" + <https://en.wikipedia.org/wiki/Leaky_bucket>`__ or `"token bucket" + <https://en.wikipedia.org/wiki/Token_bucket>`__ algorithms used to + limit bandwidth usage on networks. The basic idea of using tokens to + track a resource limit is similar, but this is a very simple sack where + tokens aren't automatically created or destroyed over time; they're + just borrowed and then put back. + + """ + + # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing + def __init__(self, total_tokens: int | float) -> None: # noqa: PYI041 + self._lot = ParkingLot() + self._borrowers: set[Task | object] = set() + # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of + self._pending_borrowers: dict[Task, Task | object] = {} + # invoke the property setter for validation + self.total_tokens: int | float = total_tokens + assert self._total_tokens == total_tokens + + def __repr__(self) -> str: + return f"<trio.CapacityLimiter at {id(self):#x}, {len(self._borrowers)}/{self._total_tokens} with {len(self._lot)} waiting>" + + @property + def total_tokens(self) -> int | float: + """The total capacity available. + + You can change :attr:`total_tokens` by assigning to this attribute. If + you make it larger, then the appropriate number of waiting tasks will + be woken immediately to take the new tokens. If you decrease + total_tokens below the number of tasks that are currently using the + resource, then all current tasks will be allowed to finish as normal, + but no new tasks will be allowed in until the total number of tasks + drops below the new total_tokens. 
+ + """ + return self._total_tokens + + @total_tokens.setter + def total_tokens(self, new_total_tokens: int | float) -> None: # noqa: PYI041 + if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf: + raise TypeError("total_tokens must be an int or math.inf") + if new_total_tokens < 1: + raise ValueError("total_tokens must be >= 1") + self._total_tokens = new_total_tokens + self._wake_waiters() + + def _wake_waiters(self) -> None: + available = self._total_tokens - len(self._borrowers) + for woken in self._lot.unpark(count=available): + self._borrowers.add(self._pending_borrowers.pop(woken)) + + @property + def borrowed_tokens(self) -> int: + """The amount of capacity that's currently in use.""" + return len(self._borrowers) + + @property + def available_tokens(self) -> int | float: + """The amount of capacity that's available to use.""" + return self.total_tokens - self.borrowed_tokens + + @enable_ki_protection + def acquire_nowait(self) -> None: + """Borrow a token from the sack, without blocking. + + Raises: + WouldBlock: if no tokens are available. + RuntimeError: if the current task already holds one of this sack's + tokens. + + """ + self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task()) + + @enable_ki_protection + def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None: + """Borrow a token from the sack on behalf of ``borrower``, without + blocking. + + Args: + borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object + used to record who is borrowing this token. This is used by + :func:`trio.to_thread.run_sync` to allow threads to "hold + tokens", with the intention in the future of using it to `allow + deadlock detection and other useful things + <https://github.com/python-trio/trio/issues/182>`__ + + Raises: + WouldBlock: if no tokens are available. + RuntimeError: if ``borrower`` already holds one of this sack's + tokens. 
+ + """ + if borrower in self._borrowers: + raise RuntimeError( + "this borrower is already holding one of this CapacityLimiter's tokens", + ) + if len(self._borrowers) < self._total_tokens and not self._lot: + self._borrowers.add(borrower) + else: + raise trio.WouldBlock + + @enable_ki_protection + async def acquire(self) -> None: + """Borrow a token from the sack, blocking if necessary. + + Raises: + RuntimeError: if the current task already holds one of this sack's + tokens. + + """ + await self.acquire_on_behalf_of(trio.lowlevel.current_task()) + + @enable_ki_protection + async def acquire_on_behalf_of(self, borrower: Task | object) -> None: + """Borrow a token from the sack on behalf of ``borrower``, blocking if + necessary. + + Args: + borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object + used to record who is borrowing this token; see + :meth:`acquire_on_behalf_of_nowait` for details. + + Raises: + RuntimeError: if ``borrower`` task already holds one of this sack's + tokens. + + """ + await trio.lowlevel.checkpoint_if_cancelled() + try: + self.acquire_on_behalf_of_nowait(borrower) + except trio.WouldBlock: + task = trio.lowlevel.current_task() + self._pending_borrowers[task] = borrower + try: + await self._lot.park() + except trio.Cancelled: + self._pending_borrowers.pop(task) + raise + else: + await trio.lowlevel.cancel_shielded_checkpoint() + + @enable_ki_protection + def release(self) -> None: + """Put a token back into the sack. + + Raises: + RuntimeError: if the current task has not acquired one of this + sack's tokens. + + """ + self.release_on_behalf_of(trio.lowlevel.current_task()) + + @enable_ki_protection + def release_on_behalf_of(self, borrower: Task | object) -> None: + """Put a token back into the sack on behalf of ``borrower``. + + Raises: + RuntimeError: if the given borrower has not acquired one of this + sack's tokens. 
+ + """ + if borrower not in self._borrowers: + raise RuntimeError( + "this borrower isn't holding any of this CapacityLimiter's tokens", + ) + self._borrowers.remove(borrower) + self._wake_waiters() + + def statistics(self) -> CapacityLimiterStatistics: + """Return an object containing debugging information. + + Currently the following fields are defined: + + * ``borrowed_tokens``: The number of tokens currently borrowed from + the sack. + * ``total_tokens``: The total number of tokens in the sack. Usually + this will be larger than ``borrowed_tokens``, but it's possibly for + it to be smaller if :attr:`total_tokens` was recently decreased. + * ``borrowers``: A list of all tasks or other entities that currently + hold a token. + * ``tasks_waiting``: The number of tasks blocked on this + :class:`CapacityLimiter`\'s :meth:`acquire` or + :meth:`acquire_on_behalf_of` methods. + + """ + return CapacityLimiterStatistics( + borrowed_tokens=len(self._borrowers), + total_tokens=self._total_tokens, + # Use a list instead of a frozenset just in case we start to allow + # one borrower to hold multiple tokens in the future + borrowers=list(self._borrowers), + tasks_waiting=len(self._lot), + ) + + +@final +class Semaphore(AsyncContextManagerMixin): + """A `semaphore <https://en.wikipedia.org/wiki/Semaphore_(programming)>`__. + + A semaphore holds an integer value, which can be incremented by + calling :meth:`release` and decremented by calling :meth:`acquire` – but + the value is never allowed to drop below zero. If the value is zero, then + :meth:`acquire` will block until someone calls :meth:`release`. + + If you're looking for a :class:`Semaphore` to limit the number of tasks + that can access some resource simultaneously, then consider using a + :class:`CapacityLimiter` instead. + + This object's interface is similar to, but different from, that of + :class:`threading.Semaphore`. 
+ + A :class:`Semaphore` object can be used as an async context manager; it + blocks on entry but not on exit. + + Args: + initial_value (int): A non-negative integer giving semaphore's initial + value. + max_value (int or None): If given, makes this a "bounded" semaphore that + raises an error if the value is about to exceed the given + ``max_value``. + + """ + + def __init__(self, initial_value: int, *, max_value: int | None = None) -> None: + if not isinstance(initial_value, int): + raise TypeError("initial_value must be an int") + if initial_value < 0: + raise ValueError("initial value must be >= 0") + if max_value is not None: + if not isinstance(max_value, int): + raise TypeError("max_value must be None or an int") + if max_value < initial_value: + raise ValueError("max_values must be >= initial_value") + + # Invariants: + # bool(self._lot) implies self._value == 0 + # (or equivalently: self._value > 0 implies not self._lot) + self._lot = trio.lowlevel.ParkingLot() + self._value = initial_value + self._max_value = max_value + + def __repr__(self) -> str: + if self._max_value is None: + max_value_str = "" + else: + max_value_str = f", max_value={self._max_value}" + return f"<trio.Semaphore({self._value}{max_value_str}) at {id(self):#x}>" + + @property + def value(self) -> int: + """The current value of the semaphore.""" + return self._value + + @property + def max_value(self) -> int | None: + """The maximum allowed value. May be None to indicate no limit.""" + return self._max_value + + @enable_ki_protection + def acquire_nowait(self) -> None: + """Attempt to decrement the semaphore value, without blocking. + + Raises: + WouldBlock: if the value is zero. + + """ + if self._value > 0: + assert not self._lot + self._value -= 1 + else: + raise trio.WouldBlock + + @enable_ki_protection + async def acquire(self) -> None: + """Decrement the semaphore value, blocking if necessary to avoid + letting it drop below zero. 
+ + """ + await trio.lowlevel.checkpoint_if_cancelled() + try: + self.acquire_nowait() + except trio.WouldBlock: + await self._lot.park() + else: + await trio.lowlevel.cancel_shielded_checkpoint() + + @enable_ki_protection + def release(self) -> None: + """Increment the semaphore value, possibly waking a task blocked in + :meth:`acquire`. + + Raises: + ValueError: if incrementing the value would cause it to exceed + :attr:`max_value`. + + """ + if self._lot: + assert self._value == 0 + self._lot.unpark(count=1) + else: + if self._max_value is not None and self._value == self._max_value: + raise ValueError("semaphore released too many times") + self._value += 1 + + def statistics(self) -> ParkingLotStatistics: + """Return an object containing debugging information. + + Currently the following fields are defined: + + * ``tasks_waiting``: The number of tasks blocked on this semaphore's + :meth:`acquire` method. + + """ + return self._lot.statistics() + + +class LockStatistics: + """An object containing debugging information for a Lock. + + Currently the following fields are defined: + + * ``locked`` (boolean): indicating whether the lock is held. + * ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock, + or None if the lock is not held. + * ``tasks_waiting`` (int): The number of tasks blocked on this lock's + :meth:`trio.Lock.acquire` method. + + """ + + locked: bool + owner: Task | None + tasks_waiting: int + + [email protected](eq=False, repr=False, slots=False) +class _LockImpl(AsyncContextManagerMixin): + _lot: ParkingLot = attrs.field(factory=ParkingLot, init=False) + _owner: Task | None = attrs.field(default=None, init=False) + + def __repr__(self) -> str: + if self.locked(): + s1 = "locked" + s2 = f" with {len(self._lot)} waiters" + else: + s1 = "unlocked" + s2 = "" + return f"<{s1} {self.__class__.__name__} object at {id(self):#x}{s2}>" + + def locked(self) -> bool: + """Check whether the lock is currently held. 
+ + Returns: + bool: True if the lock is held, False otherwise. + + """ + return self._owner is not None + + @enable_ki_protection + def acquire_nowait(self) -> None: + """Attempt to acquire the lock, without blocking. + + Raises: + WouldBlock: if the lock is held. + + """ + + task = trio.lowlevel.current_task() + if self._owner is task: + raise RuntimeError("attempt to re-acquire an already held Lock") + elif self._owner is None and not self._lot: + # No-one owns it + self._owner = task + add_parking_lot_breaker(task, self._lot) + else: + raise trio.WouldBlock + + @enable_ki_protection + async def acquire(self) -> None: + """Acquire the lock, blocking if necessary. + + Raises: + BrokenResourceError: if the owner of the lock exits without releasing. + """ + await trio.lowlevel.checkpoint_if_cancelled() + try: + self.acquire_nowait() + except trio.WouldBlock: + try: + # NOTE: it's important that the contended acquire path is just + # "_lot.park()", because that's how Condition.wait() acquires the + # lock as well. + await self._lot.park() + except trio.BrokenResourceError: + raise trio.BrokenResourceError( + f"Owner of this lock exited without releasing: {self._owner}", + ) from None + else: + await trio.lowlevel.cancel_shielded_checkpoint() + + @enable_ki_protection + def release(self) -> None: + """Release the lock. + + Raises: + RuntimeError: if the calling task does not hold the lock. + + """ + task = trio.lowlevel.current_task() + if task is not self._owner: + raise RuntimeError("can't release a Lock you don't own") + remove_parking_lot_breaker(self._owner, self._lot) + if self._lot: + (self._owner,) = self._lot.unpark(count=1) + add_parking_lot_breaker(self._owner, self._lot) + else: + self._owner = None + + def statistics(self) -> LockStatistics: + """Return an object containing debugging information. + + Currently the following fields are defined: + + * ``locked``: boolean indicating whether the lock is held. 
+ * ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock, + or None if the lock is not held. + * ``tasks_waiting``: The number of tasks blocked on this lock's + :meth:`acquire` method. + + """ + return LockStatistics( + locked=self.locked(), + owner=self._owner, + tasks_waiting=len(self._lot), + ) + + +@final +class Lock(_LockImpl): + """A classic `mutex + <https://en.wikipedia.org/wiki/Lock_(computer_science)>`__. + + This is a non-reentrant, single-owner lock. Unlike + :class:`threading.Lock`, only the owner of the lock is allowed to release + it. + + A :class:`Lock` object can be used as an async context manager; it + blocks on entry but not on exit. + + """ + + +@final +class StrictFIFOLock(_LockImpl): + r"""A variant of :class:`Lock` where tasks are guaranteed to acquire the + lock in strict first-come-first-served order. + + An example of when this is useful is if you're implementing something like + :class:`trio.SSLStream` or an HTTP/2 server using `h2 + <https://hyper-h2.readthedocs.io/>`__, where you have multiple concurrent + tasks that are interacting with a shared state machine, and at + unpredictable moments the state machine requests that a chunk of data be + sent over the network. (For example, when using h2 simply reading incoming + data can occasionally `create outgoing data to send + <https://http2.github.io/http2-spec/#PING>`__.) The challenge is to make + sure that these chunks are sent in the correct order, without being + garbled. + + One option would be to use a regular :class:`Lock`, and wrap it around + every interaction with the state machine:: + + # This approach is sometimes workable but often sub-optimal; see below + async with lock: + state_machine.do_something() + if state_machine.has_data_to_send(): + await conn.sendall(state_machine.get_data_to_send()) + + But this can be problematic. 
If you're using h2 then *usually* reading + incoming data doesn't create the need to send any data, so we don't want + to force every task that tries to read from the network to sit and wait + a potentially long time for ``sendall`` to finish. And in some situations + this could even potentially cause a deadlock, if the remote peer is + waiting for you to read some data before it accepts the data you're + sending. + + :class:`StrictFIFOLock` provides an alternative. We can rewrite our + example like:: + + # Note: no awaits between when we start using the state machine and + # when we block to take the lock! + state_machine.do_something() + if state_machine.has_data_to_send(): + # Notice that we fetch the data to send out of the state machine + # *before* sleeping, so that other tasks won't see it. + chunk = state_machine.get_data_to_send() + async with strict_fifo_lock: + await conn.sendall(chunk) + + First we do all our interaction with the state machine in a single + scheduling quantum (notice there are no ``await``\s in there), so it's + automatically atomic with respect to other tasks. And then if and only if + we have data to send, we get in line to send it – and + :class:`StrictFIFOLock` guarantees that each task will send its data in + the same order that the state machine generated it. + + Currently, :class:`StrictFIFOLock` is identical to :class:`Lock`, + but (a) this may not always be true in the future, especially if Trio ever + implements `more sophisticated scheduling policies + <https://github.com/python-trio/trio/issues/32>`__, and (b) the above code + is relying on a pretty subtle property of its lock. Using a + :class:`StrictFIFOLock` acts as an executable reminder that you're relying + on this property. + + """ + + +class ConditionStatistics: + r"""An object containing debugging information for a Condition. 
+ + Currently the following fields are defined: + + * ``tasks_waiting`` (int): The number of tasks blocked on this condition's + :meth:`trio.Condition.wait` method. + * ``lock_statistics``: The result of calling the underlying + :class:`Lock`\s :meth:`~Lock.statistics` method. + + """ + + tasks_waiting: int + lock_statistics: LockStatistics + + +@final +class Condition(AsyncContextManagerMixin): + """A classic `condition variable + <https://en.wikipedia.org/wiki/Monitor_(synchronization)>`__, similar to + :class:`threading.Condition`. + + A :class:`Condition` object can be used as an async context manager to + acquire the underlying lock; it blocks on entry but not on exit. + + Args: + lock (Lock): the lock object to use. If given, must be a + :class:`trio.Lock`. If None, a new :class:`Lock` will be allocated + and used. + + """ + + def __init__(self, lock: Lock | None = None) -> None: + if lock is None: + lock = Lock() + if type(lock) is not Lock: + raise TypeError("lock must be a trio.Lock") + self._lock = lock + self._lot = trio.lowlevel.ParkingLot() + + def locked(self) -> bool: + """Check whether the underlying lock is currently held. + + Returns: + bool: True if the lock is held, False otherwise. + + """ + return self._lock.locked() + + def acquire_nowait(self) -> None: + """Attempt to acquire the underlying lock, without blocking. + + Raises: + WouldBlock: if the lock is currently held. + + """ + return self._lock.acquire_nowait() + + async def acquire(self) -> None: + """Acquire the underlying lock, blocking if necessary. + + Raises: + BrokenResourceError: if the owner of the underlying lock exits without releasing. + """ + await self._lock.acquire() + + def release(self) -> None: + """Release the underlying lock.""" + self._lock.release() + + @enable_ki_protection + async def wait(self) -> None: + """Wait for another task to call :meth:`notify` or + :meth:`notify_all`. + + When calling this method, you must hold the lock. 
It releases the lock + while waiting, and then re-acquires it before waking up. + + There is a subtlety with how this method interacts with cancellation: + when cancelled it will block to re-acquire the lock before raising + :exc:`Cancelled`. This may cause cancellation to be less prompt than + expected. The advantage is that it makes code like this work:: + + async with condition: + await condition.wait() + + If we didn't re-acquire the lock before waking up, and :meth:`wait` + were cancelled here, then we'd crash in ``condition.__aexit__`` when + we tried to release the lock we no longer held. + + Raises: + RuntimeError: if the calling task does not hold the lock. + BrokenResourceError: if the owner of the lock exits without releasing, when attempting to re-acquire. + + """ + if trio.lowlevel.current_task() is not self._lock._owner: + raise RuntimeError("must hold the lock to wait") + self.release() + # NOTE: we go to sleep on self._lot, but we'll wake up on + # self._lock._lot. That's all that's required to acquire a Lock. + try: + await self._lot.park() + except: + with trio.CancelScope(shield=True): + await self.acquire() + raise + + def notify(self, n: int = 1) -> None: + """Wake one or more tasks that are blocked in :meth:`wait`. + + Args: + n (int): The number of tasks to wake. + + Raises: + RuntimeError: if the calling task does not hold the lock. + + """ + if trio.lowlevel.current_task() is not self._lock._owner: + raise RuntimeError("must hold the lock to notify") + self._lot.repark(self._lock._lot, count=n) + + def notify_all(self) -> None: + """Wake all tasks that are currently blocked in :meth:`wait`. + + Raises: + RuntimeError: if the calling task does not hold the lock. + + """ + if trio.lowlevel.current_task() is not self._lock._owner: + raise RuntimeError("must hold the lock to notify") + self._lot.repark_all(self._lock._lot) + + def statistics(self) -> ConditionStatistics: + r"""Return an object containing debugging information. 
+ + Currently the following fields are defined: + + * ``tasks_waiting``: The number of tasks blocked on this condition's + :meth:`wait` method. + * ``lock_statistics``: The result of calling the underlying + :class:`Lock`\s :meth:`~Lock.statistics` method. + + """ + return ConditionStatistics( + tasks_waiting=len(self._lot), + lock_statistics=self._lock.statistics(), + ) diff --git a/contrib/python/trio/trio/_threads.py b/contrib/python/trio/trio/_threads.py new file mode 100644 index 00000000000..4b1e54f5402 --- /dev/null +++ b/contrib/python/trio/trio/_threads.py @@ -0,0 +1,610 @@ +from __future__ import annotations + +import contextlib +import contextvars +import inspect +import queue as stdlib_queue +import threading +from itertools import count +from typing import TYPE_CHECKING, Generic, TypeVar + +import attrs +import outcome +from attrs import define +from sniffio import current_async_library_cvar + +import trio + +from ._core import ( + RunVar, + TrioToken, + checkpoint, + disable_ki_protection, + enable_ki_protection, + start_thread_soon, +) +from ._sync import CapacityLimiter, Event +from ._util import coroutine_or_error + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Generator + + from typing_extensions import TypeVarTuple, Unpack + + from trio._core._traps import RaiseCancelT + + Ts = TypeVarTuple("Ts") + +RetT = TypeVar("RetT") + + +class _ParentTaskData(threading.local): + """Global due to Threading API, thread local storage for data related to the + parent task of native Trio threads.""" + + token: TrioToken + abandon_on_cancel: bool + cancel_register: list[RaiseCancelT | None] + task_register: list[trio.lowlevel.Task | None] + + +PARENT_TASK_DATA = _ParentTaskData() + +_limiter_local: RunVar[CapacityLimiter] = RunVar("limiter") +# I pulled this number out of the air; it isn't based on anything. Probably we +# should make some kind of measurements to pick a good value. 
+DEFAULT_LIMIT = 40 +_thread_counter = count() + + +@define +class _ActiveThreadCount: + count: int + event: Event + + +_active_threads_local: RunVar[_ActiveThreadCount] = RunVar("active_threads") + + +@contextlib.contextmanager +def _track_active_thread() -> Generator[None, None, None]: + try: + active_threads_local = _active_threads_local.get() + except LookupError: + active_threads_local = _ActiveThreadCount(0, Event()) + _active_threads_local.set(active_threads_local) + + active_threads_local.count += 1 + try: + yield + finally: + active_threads_local.count -= 1 + if active_threads_local.count == 0: + active_threads_local.event.set() + active_threads_local.event = Event() + + +async def wait_all_threads_completed() -> None: + """Wait until no threads are still running tasks. + + This is intended to be used when testing code with trio.to_thread to + make sure no tasks are still making progress in a thread. See the + following code for a usage example:: + + async def wait_all_settled(): + while True: + await trio.testing.wait_all_threads_completed() + await trio.testing.wait_all_tasks_blocked() + if trio.testing.active_thread_count() == 0: + break + """ + + await checkpoint() + + try: + active_threads_local = _active_threads_local.get() + except LookupError: + # If there would have been active threads, the + # _active_threads_local would have been set + return + + while active_threads_local.count != 0: + await active_threads_local.event.wait() + + +def active_thread_count() -> int: + """Returns the number of threads that are currently running a task + + See `trio.testing.wait_all_threads_completed` + """ + try: + return _active_threads_local.get().count + except LookupError: + return 0 + + +def current_default_thread_limiter() -> CapacityLimiter: + """Get the default `~trio.CapacityLimiter` used by + `trio.to_thread.run_sync`. + + The most common reason to call this would be if you want to modify its + :attr:`~trio.CapacityLimiter.total_tokens` attribute.
+ + """ + try: + limiter = _limiter_local.get() + except LookupError: + limiter = CapacityLimiter(DEFAULT_LIMIT) + _limiter_local.set(limiter) + return limiter + + +# Eventually we might build this into a full-fledged deadlock-detection +# system; see https://github.com/python-trio/trio/issues/182 +# But for now we just need an object to stand in for the thread, so we can +# keep track of who's holding the CapacityLimiter's token. +@attrs.frozen(eq=False, slots=False) +class ThreadPlaceholder: + name: str + + +# Types for the to_thread_run_sync message loop +@attrs.frozen(eq=False, slots=False) +class Run(Generic[RetT]): # type: ignore[explicit-any] + afn: Callable[..., Awaitable[RetT]] # type: ignore[explicit-any] + args: tuple[object, ...] + context: contextvars.Context = attrs.field( + init=False, + factory=contextvars.copy_context, + ) + queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attrs.field( + init=False, + factory=stdlib_queue.SimpleQueue, + ) + + @disable_ki_protection + async def unprotected_afn(self) -> RetT: + coro = coroutine_or_error(self.afn, *self.args) + return await coro + + async def run(self) -> None: + # we use extra checkpoints to pick up and reset any context changes + task = trio.lowlevel.current_task() + old_context = task.context + task.context = self.context.copy() + await trio.lowlevel.cancel_shielded_checkpoint() + result = await outcome.acapture(self.unprotected_afn) + task.context = old_context + await trio.lowlevel.cancel_shielded_checkpoint() + self.queue.put_nowait(result) + + async def run_system(self) -> None: + result = await outcome.acapture(self.unprotected_afn) + self.queue.put_nowait(result) + + def run_in_host_task(self, token: TrioToken) -> None: + task_register = PARENT_TASK_DATA.task_register + + def in_trio_thread() -> None: + task = task_register[0] + assert task is not None, "guaranteed by abandon_on_cancel semantics" + trio.lowlevel.reschedule(task, outcome.Value(self)) + +
 token.run_sync_soon(in_trio_thread) + + def run_in_system_nursery(self, token: TrioToken) -> None: + def in_trio_thread() -> None: + try: + trio.lowlevel.spawn_system_task( + self.run_system, + name=self.afn, + context=self.context, + ) + except RuntimeError: # system nursery is closed + self.queue.put_nowait( + outcome.Error(trio.RunFinishedError("system nursery is closed")), + ) + + token.run_sync_soon(in_trio_thread) + + +@attrs.frozen(eq=False, slots=False) +class RunSync(Generic[RetT]): # type: ignore[explicit-any] + fn: Callable[..., RetT] # type: ignore[explicit-any] + args: tuple[object, ...] + context: contextvars.Context = attrs.field( + init=False, + factory=contextvars.copy_context, + ) + queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attrs.field( + init=False, + factory=stdlib_queue.SimpleQueue, + ) + + @disable_ki_protection + def unprotected_fn(self) -> RetT: + ret = self.context.run(self.fn, *self.args) + + if inspect.iscoroutine(ret): + # Manually close coroutine to avoid RuntimeWarnings + ret.close() + raise TypeError( + "Trio expected a synchronous function, but {!r} appears to be " + "asynchronous".format(getattr(self.fn, "__qualname__", self.fn)), + ) + + return ret + + def run_sync(self) -> None: + result = outcome.capture(self.unprotected_fn) + self.queue.put_nowait(result) + + def run_in_host_task(self, token: TrioToken) -> None: + task_register = PARENT_TASK_DATA.task_register + + def in_trio_thread() -> None: + task = task_register[0] + assert task is not None, "guaranteed by abandon_on_cancel semantics" + trio.lowlevel.reschedule(task, outcome.Value(self)) + + token.run_sync_soon(in_trio_thread) + + def run_in_system_nursery(self, token: TrioToken) -> None: + token.run_sync_soon(self.run_sync) + + +@enable_ki_protection +async def to_thread_run_sync( + sync_fn: Callable[[Unpack[Ts]], RetT], + *args: Unpack[Ts], + thread_name: str | None = None, + abandon_on_cancel: bool = False, + limiter: CapacityLimiter | None = None, +) ->
RetT: + """Convert a blocking operation into an async operation using a thread. + + These two lines are equivalent:: + + sync_fn(*args) + await trio.to_thread.run_sync(sync_fn, *args) + + except that if ``sync_fn`` takes a long time, then the first line will + block the Trio loop while it runs, while the second line allows other Trio + tasks to continue working while ``sync_fn`` runs. This is accomplished by + pushing the call to ``sync_fn(*args)`` off into a worker thread. + + From inside the worker thread, you can get back into Trio using the + functions in `trio.from_thread`. + + Args: + sync_fn: An arbitrary synchronous callable. + *args: Positional arguments to pass to sync_fn. If you need keyword + arguments, use :func:`functools.partial`. + abandon_on_cancel (bool): Whether to abandon this thread upon + cancellation of this operation. See discussion below. + thread_name (str): Optional string to set the name of the thread. + Will always set `threading.Thread.name`, but only set the os name + if pthread.h is available (i.e. most POSIX installations). + pthread names are limited to 15 characters, and can be read from + ``/proc/<PID>/task/<SPID>/comm`` or with ``ps -eT``, among others. + Defaults to ``{sync_fn.__name__|None} from {trio.lowlevel.current_task().name}``. + limiter (None, or CapacityLimiter-like object): + An object used to limit the number of simultaneous threads. Most + commonly this will be a `~trio.CapacityLimiter`, but it could be + anything providing compatible + :meth:`~trio.CapacityLimiter.acquire_on_behalf_of` and + :meth:`~trio.CapacityLimiter.release_on_behalf_of` methods. This + function will call ``acquire_on_behalf_of`` before starting the + thread, and ``release_on_behalf_of`` after the thread has finished. + + If None (the default), uses the default `~trio.CapacityLimiter`, as + returned by :func:`current_default_thread_limiter`. 
+ + **Cancellation handling**: Cancellation is a tricky issue here, because + neither Python nor the operating systems it runs on provide any general + mechanism for cancelling an arbitrary synchronous function running in a + thread. This function will always check for cancellation on entry, before + starting the thread. But once the thread is running, there are two ways it + can handle being cancelled: + + * If ``abandon_on_cancel=False``, the function ignores the cancellation and + keeps going, just like if we had called ``sync_fn`` synchronously. This + is the default behavior. + + * If ``abandon_on_cancel=True``, then this function immediately raises + `~trio.Cancelled`. In this case **the thread keeps running in + background** – we just abandon it to do whatever it's going to do, and + silently discard any return value or errors that it raises. Only use + this if you know that the operation is safe and side-effect free. (For + example: :func:`trio.socket.getaddrinfo` uses a thread with + ``abandon_on_cancel=True``, because it doesn't really affect anything if a + stray hostname lookup keeps running in the background.) + + The ``limiter`` is only released after the thread has *actually* + finished – which in the case of cancellation may be some time after this + function has returned. If :func:`trio.run` finishes before the thread + does, then the limiter release method will never be called at all. + + .. warning:: + + You should not use this function to call long-running CPU-bound + functions! In addition to the usual GIL-related reasons why using + threads for CPU-bound work is not very effective in Python, there is an + additional problem: on CPython, `CPU-bound threads tend to "starve out" + IO-bound threads <https://bugs.python.org/issue7946>`__, so using + threads for CPU-bound work is likely to adversely affect the main + thread running Trio. 
If you need to do this, you're better off using a + worker process, or perhaps PyPy (which still has a GIL, but may do a + better job of fairly allocating CPU time between threads). + + Returns: + Whatever ``sync_fn(*args)`` returns. + + Raises: + Exception: Whatever ``sync_fn(*args)`` raises. + + """ + await trio.lowlevel.checkpoint_if_cancelled() + # raise early if abandon_on_cancel.__bool__ raises + # and give a new name to ensure mypy knows it's never None + abandon_bool = bool(abandon_on_cancel) + if limiter is None: + limiter = current_default_thread_limiter() + + # Holds a reference to the task that's blocked in this function waiting + # for the result – or None if this function was cancelled and we should + # discard the result. + task_register: list[trio.lowlevel.Task | None] = [trio.lowlevel.current_task()] + # Holds a reference to the raise_cancel function provided if a cancellation + # is attempted against this task - or None if no such delivery has happened. + cancel_register: list[RaiseCancelT | None] = [None] # type: ignore[assignment] + name = f"trio.to_thread.run_sync-{next(_thread_counter)}" + placeholder = ThreadPlaceholder(name) + + # This function gets scheduled into the Trio run loop to deliver the + # thread's result. + def report_back_in_trio_thread_fn(result: outcome.Outcome[RetT]) -> None: + def do_release_then_return_result() -> RetT: + # release_on_behalf_of is an arbitrary user-defined method, so it + # might raise an error. If it does, we want that error to + # replace the regular return value, and if the regular return was + # already an exception then we want them to chain. 
+ try: + return result.unwrap() + finally: + limiter.release_on_behalf_of(placeholder) + + result = outcome.capture(do_release_then_return_result) + if task_register[0] is not None: + trio.lowlevel.reschedule(task_register[0], outcome.Value(result)) + + current_trio_token = trio.lowlevel.current_trio_token() + + if thread_name is None: + thread_name = f"{getattr(sync_fn, '__name__', None)} from {trio.lowlevel.current_task().name}" + + def worker_fn() -> RetT: + PARENT_TASK_DATA.token = current_trio_token + PARENT_TASK_DATA.abandon_on_cancel = abandon_bool + PARENT_TASK_DATA.cancel_register = cancel_register + PARENT_TASK_DATA.task_register = task_register + try: + ret = context.run(sync_fn, *args) + + if inspect.iscoroutine(ret): + # Manually close coroutine to avoid RuntimeWarnings + ret.close() + raise TypeError( + "Trio expected a sync function, but {!r} appears to be " + "asynchronous".format(getattr(sync_fn, "__qualname__", sync_fn)), + ) + + return ret + finally: + del PARENT_TASK_DATA.token + del PARENT_TASK_DATA.abandon_on_cancel + del PARENT_TASK_DATA.cancel_register + del PARENT_TASK_DATA.task_register + + context = contextvars.copy_context() + # Trio doesn't use current_async_library_cvar, but if someone + # else set it, it would now shine through since + # sniffio.thread_local isn't set in the new thread. Make sure + # the new thread sees that it's not running in async context. + context.run(current_async_library_cvar.set, None) + + def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: + # If the entire run finished, the task we're trying to contact is + # certainly long gone -- it must have been cancelled and abandoned + # us. Just ignore the error in this case. 
+ with contextlib.suppress(trio.RunFinishedError): + current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result) + + await limiter.acquire_on_behalf_of(placeholder) + with _track_active_thread(): + try: + start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name) + except: + limiter.release_on_behalf_of(placeholder) + raise + + def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: + # fill so from_thread_check_cancelled can raise + # 'raise_cancel' will immediately delete its reason object, so we make + # a copy in each thread + cancel_register[0] = raise_cancel + if abandon_bool: + # empty so report_back_in_trio_thread_fn cannot reschedule + task_register[0] = None + return trio.lowlevel.Abort.SUCCEEDED + else: + return trio.lowlevel.Abort.FAILED + + while True: + # wait_task_rescheduled return value cannot be typed + msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[object] = ( + await trio.lowlevel.wait_task_rescheduled(abort) + ) + if isinstance(msg_from_thread, outcome.Outcome): + return msg_from_thread.unwrap() + elif isinstance(msg_from_thread, Run): + await msg_from_thread.run() + elif isinstance(msg_from_thread, RunSync): + msg_from_thread.run_sync() + else: # pragma: no cover, internal debugging guard TODO: use assert_never + raise TypeError( + f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.", + ) + del msg_from_thread + + +def from_thread_check_cancelled() -> None: + """Raise `trio.Cancelled` if the associated Trio task entered a cancelled status. + + Only applicable to threads spawned by `trio.to_thread.run_sync`. Poll to allow + ``abandon_on_cancel=False`` threads to raise :exc:`~trio.Cancelled` at a suitable + place, or to end abandoned ``abandon_on_cancel=True`` threads sooner than they may + otherwise. 
+ + Raises: + Cancelled: If the corresponding call to `trio.to_thread.run_sync` has had a + delivery of cancellation attempted against it, regardless of the value of + ``abandon_on_cancel`` supplied as an argument to it. + RuntimeError: If this thread is not spawned from `trio.to_thread.run_sync`. + + .. note:: + + To be precise, :func:`~trio.from_thread.check_cancelled` checks whether the task + running :func:`trio.to_thread.run_sync` has ever been cancelled since the last + time it was running a :func:`trio.from_thread.run` or :func:`trio.from_thread.run_sync` + function. It may raise `trio.Cancelled` even if a cancellation occurred that was + later hidden by a modification to `trio.CancelScope.shield` between the cancelled + `~trio.CancelScope` and :func:`trio.to_thread.run_sync`. This differs from the + behavior of normal Trio checkpoints, which raise `~trio.Cancelled` only if the + cancellation is still active when the checkpoint executes. The distinction here is + *exceedingly* unlikely to be relevant to your application, but we mention it + for completeness. 
+ """ + try: + raise_cancel = PARENT_TASK_DATA.cancel_register[0] + except AttributeError: + raise RuntimeError( + "this thread wasn't created by Trio, can't check for cancellation", + ) from None + if raise_cancel is not None: + raise_cancel() + + +def _send_message_to_trio( + trio_token: TrioToken | None, + message_to_trio: Run[RetT] | RunSync[RetT], +) -> RetT: + """Shared logic of from_thread functions""" + token_provided = trio_token is not None + + if not token_provided: + try: + trio_token = PARENT_TASK_DATA.token + except AttributeError: + raise RuntimeError( + "this thread wasn't created by Trio, pass kwarg trio_token=...", + ) from None + elif not isinstance(trio_token, TrioToken): + raise RuntimeError("Passed kwarg trio_token is not of type TrioToken") + + # Avoid deadlock by making sure we're not called from Trio thread + try: + trio.lowlevel.current_task() + except RuntimeError: + pass + else: + raise RuntimeError("this is a blocking function; call it from a thread") + + if token_provided or PARENT_TASK_DATA.abandon_on_cancel: + message_to_trio.run_in_system_nursery(trio_token) + else: + message_to_trio.run_in_host_task(trio_token) + + return message_to_trio.queue.get().unwrap() + + +def from_thread_run( + afn: Callable[[Unpack[Ts]], Awaitable[RetT]], + *args: Unpack[Ts], + trio_token: TrioToken | None = None, +) -> RetT: + """Run the given async function in the parent Trio thread, blocking until it + is complete. + + Returns: + Whatever ``afn(*args)`` returns. + + Returns or raises whatever the given function returns or raises. It + can also raise exceptions of its own: + + Raises: + RunFinishedError: if the corresponding call to :func:`trio.run` has + already completed, or if the run has started its final cleanup phase + and can no longer spawn new system tasks. 
+ Cancelled: If the original call to :func:`trio.to_thread.run_sync` is cancelled + (if *trio_token* is None) or the call to :func:`trio.run` completes + (if *trio_token* is not None) while ``afn(*args)`` is running, + then *afn* is likely to raise :exc:`trio.Cancelled`. + RuntimeError: if you try calling this from inside the Trio thread, + which would otherwise cause a deadlock, or if no ``trio_token`` was + provided, and we can't infer one from context. + TypeError: if ``afn`` is not an asynchronous function. + + **Locating a TrioToken**: There are two ways to specify which + `trio.run` loop to reenter: + + - Spawn this thread from `trio.to_thread.run_sync`. Trio will + automatically capture the relevant Trio token and use it + to re-enter the same Trio task. + - Pass a keyword argument, ``trio_token`` specifying a specific + `trio.run` loop to re-enter. This is useful in case you have a + "foreign" thread, spawned using some other framework, and still want + to enter Trio, or if you want to use a new system task to call ``afn``, + maybe to avoid the cancellation context of a corresponding + `trio.to_thread.run_sync` task. You can get this token from + :func:`trio.lowlevel.current_trio_token`. + """ + return _send_message_to_trio(trio_token, Run(afn, args)) + + +def from_thread_run_sync( + fn: Callable[[Unpack[Ts]], RetT], + *args: Unpack[Ts], + trio_token: TrioToken | None = None, +) -> RetT: + """Run the given sync function in the parent Trio thread, blocking until it + is complete. + + Returns: + Whatever ``fn(*args)`` returns. + + Returns or raises whatever the given function returns or raises. It + can also raise exceptions of its own: + + Raises: + RunFinishedError: if the corresponding call to `trio.run` has + already completed. + RuntimeError: if you try calling this from inside the Trio thread, + which would otherwise cause a deadlock or if no ``trio_token`` was + provided, and we can't infer one from context. + TypeError: if ``fn`` is an async function. 
+ + **Locating a TrioToken**: There are two ways to specify which + `trio.run` loop to reenter: + + - Spawn this thread from `trio.to_thread.run_sync`. Trio will + automatically capture the relevant Trio token and use it when you + want to re-enter Trio. + - Pass a keyword argument, ``trio_token`` specifying a specific + `trio.run` loop to re-enter. This is useful in case you have a + "foreign" thread, spawned using some other framework, and still want + to enter Trio, or if you want to use a new system task to call ``fn``, + maybe to avoid the cancellation context of a corresponding + `trio.to_thread.run_sync` task. + """ + return _send_message_to_trio(trio_token, RunSync(fn, args)) diff --git a/contrib/python/trio/trio/_timeouts.py b/contrib/python/trio/trio/_timeouts.py new file mode 100644 index 00000000000..d95cbe4cfc2 --- /dev/null +++ b/contrib/python/trio/trio/_timeouts.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import math +import sys +from contextlib import contextmanager +from typing import TYPE_CHECKING, NoReturn + +import trio + +if TYPE_CHECKING: + from collections.abc import Generator + + +def move_on_at(deadline: float, *, shield: bool = False) -> trio.CancelScope: + """Use as a context manager to create a cancel scope with the given + absolute deadline. + + Args: + deadline (float): The deadline. + shield (bool): Initial value for the `~trio.CancelScope.shield` attribute + of the newly created cancel scope. + + Raises: + ValueError: if deadline is NaN. + + """ + # CancelScope validates that deadline isn't math.nan + return trio.CancelScope(deadline=deadline, shield=shield) + + +def move_on_after( + seconds: float, + *, + shield: bool = False, +) -> trio.CancelScope: + """Use as a context manager to create a cancel scope whose deadline is + set to now + *seconds*. + + The deadline of the cancel scope is calculated upon entering. + + Args: + seconds (float): The timeout. 
+ shield (bool): Initial value for the `~trio.CancelScope.shield` attribute + of the newly created cancel scope. + + Raises: + ValueError: if ``seconds`` is less than zero or NaN. + + """ + # duplicate validation logic to have the correct parameter name + if seconds < 0: + raise ValueError("`seconds` must be non-negative") + if math.isnan(seconds): + raise ValueError("`seconds` must not be NaN") + return trio.CancelScope( + shield=shield, + relative_deadline=seconds, + ) + + +async def sleep_forever() -> NoReturn: + """Pause execution of the current task forever (or until cancelled). + + Equivalent to calling ``await sleep(math.inf)``, except that if manually + rescheduled this will raise a `RuntimeError`. + + Raises: + RuntimeError: if rescheduled + + """ + await trio.lowlevel.wait_task_rescheduled(lambda _: trio.lowlevel.Abort.SUCCEEDED) + raise RuntimeError("Should never have been rescheduled!") + + +async def sleep_until(deadline: float) -> None: + """Pause execution of the current task until the given time. + + The difference between :func:`sleep` and :func:`sleep_until` is that the + former takes a relative time and the latter takes an absolute time + according to Trio's internal clock (as returned by :func:`current_time`). + + Args: + deadline (float): The time at which we should wake up again. May be in + the past, in which case this function executes a checkpoint but + does not block. + + Raises: + ValueError: if deadline is NaN. + + """ + with move_on_at(deadline): + await sleep_forever() + + +async def sleep(seconds: float) -> None: + """Pause execution of the current task for the given number of seconds. + + Args: + seconds (float): The number of seconds to sleep. May be zero to + insert a checkpoint without actually blocking. + + Raises: + ValueError: if *seconds* is negative or NaN. 
+ + """ + if seconds < 0: + raise ValueError("`seconds` must be non-negative") + if seconds == 0: + await trio.lowlevel.checkpoint() + else: + await sleep_until(trio.current_time() + seconds) + + +class TooSlowError(Exception): + """Raised by :func:`fail_after` and :func:`fail_at` if the timeout + expires. + + """ + + +@contextmanager +def fail_at( + deadline: float, + *, + shield: bool = False, +) -> Generator[trio.CancelScope, None, None]: + """Creates a cancel scope with the given deadline, and raises an error if it + is actually cancelled. + + This function and :func:`move_on_at` are similar in that both create a + cancel scope with a given absolute deadline, and if the deadline expires + then both will cause :exc:`Cancelled` to be raised within the scope. The + difference is that when the :exc:`Cancelled` exception reaches + :func:`move_on_at`, it's caught and discarded. When it reaches + :func:`fail_at`, then it's caught and :exc:`TooSlowError` is raised in its + place. + + Args: + deadline (float): The deadline. + shield (bool): Initial value for the `~trio.CancelScope.shield` attribute + of the newly created cancel scope. + + Raises: + TooSlowError: if a :exc:`Cancelled` exception is raised in this scope + and caught by the context manager. + ValueError: if deadline is NaN. + + """ + with move_on_at(deadline, shield=shield) as scope: + yield scope + if scope.cancelled_caught: + raise TooSlowError + + +@contextmanager +def fail_after( + seconds: float, + *, + shield: bool = False, +) -> Generator[trio.CancelScope, None, None]: + """Creates a cancel scope with the given timeout, and raises an error if + it is actually cancelled. + + This function and :func:`move_on_after` are similar in that both create a + cancel scope with a given timeout, and if the timeout expires then both + will cause :exc:`Cancelled` to be raised within the scope. The difference + is that when the :exc:`Cancelled` exception reaches :func:`move_on_after`, + it's caught and discarded. 
When it reaches :func:`fail_after`, then it's + caught and :exc:`TooSlowError` is raised in its place. + + The deadline of the cancel scope is calculated upon entering. + + Args: + seconds (float): The timeout. + shield (bool): Initial value for the `~trio.CancelScope.shield` attribute + of the newly created cancel scope. + + Raises: + TooSlowError: if a :exc:`Cancelled` exception is raised in this scope + and caught by the context manager. + ValueError: if *seconds* is less than zero or NaN. + + """ + with move_on_after(seconds, shield=shield) as scope: + yield scope + if scope.cancelled_caught: + raise TooSlowError + + +# Users don't need to know that fail_at & fail_after wraps move_on_at and move_on_after +# and there is no functional difference. So we replace the return value when generating +# documentation. +if "sphinx.ext.autodoc" in sys.modules: + import inspect + + for c in (fail_at, fail_after): + c.__signature__ = inspect.Signature.from_callable(c).replace(return_annotation=trio.CancelScope) # type: ignore[union-attr] diff --git a/contrib/python/trio/trio/_tools/__init__.py b/contrib/python/trio/trio/_tools/__init__.py new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/contrib/python/trio/trio/_tools/__init__.py diff --git a/contrib/python/trio/trio/_tools/gen_exports.py b/contrib/python/trio/trio/_tools/gen_exports.py new file mode 100644 index 00000000000..101e0e4912d --- /dev/null +++ b/contrib/python/trio/trio/_tools/gen_exports.py @@ -0,0 +1,403 @@ +#! 
 /usr/bin/env python3 +""" +Code generation script for class methods +to be exported as public API +""" +from __future__ import annotations + +import argparse +import ast +import os +import subprocess +import sys +from pathlib import Path +from textwrap import indent +from typing import TYPE_CHECKING + +import attrs + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + from typing_extensions import TypeGuard + +# keep these imports up to date with conditional imports in test_gen_exports +# isort: split +import astor + +PREFIX = "_generated" + +HEADER = """# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +from __future__ import annotations + +import sys + +from ._ki import enable_ki_protection +from ._run import GLOBAL_RUN_CONTEXT +""" + +TEMPLATE = """try: + return{}GLOBAL_RUN_CONTEXT.{}.{} +except AttributeError: + raise RuntimeError("must be called from async context") from None +""" + + +@attrs.define +class File: + path: Path + modname: str + platform: str = attrs.field(default="", kw_only=True) + imports: str = attrs.field(default="", kw_only=True) + + +def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: + """Check if the AST node is either a function + or an async function + """ + return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + + +def is_public(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: + """Check if the AST node has a _public decorator""" + if is_function(node): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == "_public": + return True + return False + + +def get_public_methods( + tree: ast.AST, +) -> Iterator[ast.FunctionDef | ast.AsyncFunctionDef]: + """Return a list of methods marked as public.
+ The function walks the given tree and extracts + all objects that are functions which are marked + public. + """ + for node in ast.walk(tree): + if is_public(node): + yield node + + +def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> str: + """Given a function definition, create a string that represents taking all + the arguments from the function, and passing them through to another + invocation of the same function. + + Example input: ast.parse("def f(a, *, b): ...") + Example output: "(a, b=b)" + """ + call_args = [arg.arg for arg in funcdef.args.args] + if funcdef.args.vararg: + call_args.append("*" + funcdef.args.vararg.arg) + for arg in funcdef.args.kwonlyargs: + call_args.append(arg.arg + "=" + arg.arg) # noqa: PERF401 # clarity + if funcdef.args.kwarg: + call_args.append("**" + funcdef.args.kwarg.arg) + return "({})".format(", ".join(call_args)) + + +def run_black(file: File, source: str) -> tuple[bool, str]: + """Run black on the specified file. + + Returns: + Tuple of success and result string. + ex.: + (False, "Failed to run black!\nerror: cannot format ...") + (True, "<formatted source>") + + Raises: + ImportError: If black is not installed. + """ + # imported to check that `subprocess` calls will succeed + import black # noqa: F401 + + # Black has an undocumented API, but it doesn't easily allow reading configuration from + # pyproject.toml, and simultaneously pass in / receive the code as a string. + # https://github.com/psf/black/issues/779 + result = subprocess.run( + # "-" as a filename = use stdin, return on stdout. + [sys.executable, "-m", "black", "--stdin-filename", file.path, "-"], + input=source, + capture_output=True, + encoding="utf8", + ) + + if result.returncode != 0: + return False, f"Failed to run black!\n{result.stderr}" + return True, result.stdout + + +def run_ruff(file: File, source: str) -> tuple[bool, str]: + """Run ruff on the specified file. + + Returns: + Tuple of success and result string. 
+ ex.: + (False, "Failed to run ruff!\nerror: Failed to parse ...") + (True, "<formatted source>") + + Raises: + ImportError: If ruff is not installed. + """ + # imported to check that `subprocess` calls will succeed + import ruff # noqa: F401 + + result = subprocess.run( + # "-" as a filename = use stdin, return on stdout. + [ + sys.executable, + "-m", + "ruff", + "check", + "--fix", + "--unsafe-fixes", + "--stdin-filename", + file.path, + "-", + ], + input=source, + capture_output=True, + encoding="utf8", + ) + + if result.returncode != 0: + return False, f"Failed to run ruff!\n{result.stderr}" + return True, result.stdout + + +def run_linters(file: File, source: str) -> str: + """Format the specified file using black and ruff. + + Returns: + Formatted source code. + + Raises: + ImportError: If either is not installed. + SystemExit: If either failed. + """ + + for fn in (run_black, run_ruff): + success, source = fn(file, source) + if not success: + print(source) + sys.exit(1) + + return source + + +def gen_public_wrappers_source(file: File) -> str: + """Scan the given .py file for @_public decorators, and generate wrapper + functions. + + """ + header = [HEADER] + header.append(file.imports) + if file.platform: + # Simple checks to avoid repeating imports. If this messes up, type checkers/tests will + # just give errors. 
+ if "TYPE_CHECKING" not in file.imports: + header.append("from typing import TYPE_CHECKING\n") + if "import sys" not in file.imports: # pragma: no cover + header.append("import sys\n") + header.append( + f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n', + ) + + generated = ["".join(header)] + + source = astor.code_to_ast.parse_file(file.path) + method_names = [] + for method in get_public_methods(source): + # Remove self from arguments + assert method.args.args[0].arg == "self" + del method.args.args[0] + method_names.append(method.name) + + for dec in method.decorator_list: # pragma: no cover + if isinstance(dec, ast.Name) and dec.id == "contextmanager": + is_cm = True + break + else: + is_cm = False + + # Remove decorators + method.decorator_list = [ast.Name("enable_ki_protection")] + + # Create pass through arguments + new_args = create_passthrough_args(method) + + # Remove method body without the docstring + if ast.get_docstring(method) is None: + del method.body[:] + else: + # The first entry is always the docstring + del method.body[1:] + + # Create the function definition including the body + func = astor.to_source(method, indent_with=" " * 4) + + if is_cm: # pragma: no cover + func = func.replace("->Iterator", "->AbstractContextManager") + + # Create export function body + template = TEMPLATE.format( + " await " if isinstance(method, ast.AsyncFunctionDef) else " ", + file.modname, + method.name + new_args, + ) + + # Assemble function definition arguments and body + snippet = func + indent(template, " " * 4) + + # Append the snippet to the corresponding module + generated.append(snippet) + + method_names.sort() + # Insert after the header, before function definitions + generated.insert(1, f"__all__ = {method_names!r}") + return "\n\n".join(generated) + + +def matches_disk_files(new_files: dict[str, str]) -> bool: + for new_path, new_source in new_files.items(): + if not os.path.exists(new_path): + return False + old_source = 
Path(new_path).read_text(encoding="utf-8") + if old_source != new_source: + return False + return True + + +def process(files: Iterable[File], *, do_test: bool) -> None: + new_files = {} + for file in files: + print("Scanning:", file.path) + new_source = gen_public_wrappers_source(file) + new_source = run_linters(file, new_source) + dirname, basename = os.path.split(file.path) + new_path = os.path.join(dirname, PREFIX + basename) + new_files[new_path] = new_source + matches_disk = matches_disk_files(new_files) + if do_test: + if not matches_disk: + print("Generated sources are outdated. Please regenerate.") + sys.exit(1) + else: + print("Generated sources are up to date.") + else: + for new_path, new_source in new_files.items(): + with open(new_path, "w", encoding="utf-8", newline="\n") as fp: + fp.write(new_source) + print("Regenerated sources successfully.") + if not matches_disk: # TODO: test this branch + # With pre-commit integration, show that we edited files. + sys.exit(1) + + +# This is in fact run in CI, but only in the formatting check job, which +# doesn't collect coverage. 
+def main() -> None: # pragma: no cover + parser = argparse.ArgumentParser( + description="Generate python code for public api wrappers", + ) + parser.add_argument( + "--test", + "-t", + action="store_true", + help="test if code is still up to date", + ) + parsed_args = parser.parse_args() + + source_root = Path.cwd() + # Double-check we found the right directory + assert (source_root / "LICENSE").exists() + core = source_root / "src/trio/_core" + to_wrap = [ + File(core / "_run.py", "runner", imports=IMPORTS_RUN), + File( + core / "_instrumentation.py", + "runner.instruments", + imports=IMPORTS_INSTRUMENT, + ), + File( + core / "_io_windows.py", + "runner.io_manager", + platform="win32", + imports=IMPORTS_WINDOWS, + ), + File( + core / "_io_epoll.py", + "runner.io_manager", + platform="linux", + imports=IMPORTS_EPOLL, + ), + File( + core / "_io_kqueue.py", + "runner.io_manager", + platform="darwin", + imports=IMPORTS_KQUEUE, + ), + ] + + process(to_wrap, do_test=parsed_args.test) + + +IMPORTS_RUN = """\ +from collections.abc import Awaitable, Callable +from typing import Any, TYPE_CHECKING + +from outcome import Outcome +import contextvars + +from ._run import _NO_SEND, RunStatistics, Task +from ._entry_queue import TrioToken +from .._abc import Clock + +if TYPE_CHECKING: + from typing_extensions import Unpack + from ._run import PosArgT +""" +IMPORTS_INSTRUMENT = """\ +from ._instrumentation import Instrument +""" + +IMPORTS_EPOLL = """\ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .._file_io import _HasFileNo +""" + +IMPORTS_KQUEUE = """\ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import select + from collections.abc import Callable + from contextlib import AbstractContextManager + + from .. 
import _core + from .._file_io import _HasFileNo + from ._traps import Abort, RaiseCancelT +""" + +IMPORTS_WINDOWS = """\ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + from typing_extensions import Buffer + + from .._file_io import _HasFileNo + from ._unbounded_queue import UnboundedQueue + from ._windows_cffi import Handle, CData +""" + + +if __name__ == "__main__": + main() diff --git a/contrib/python/trio/trio/_tools/mypy_annotate.py b/contrib/python/trio/trio/_tools/mypy_annotate.py new file mode 100644 index 00000000000..1d625ad7aee --- /dev/null +++ b/contrib/python/trio/trio/_tools/mypy_annotate.py @@ -0,0 +1,126 @@ +"""Translates Mypy's output into GitHub's error/warning annotation syntax. + +See: https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions + +This first is run with Mypy's output piped in, to collect messages in +mypy_annotate.dat. After all platforms run, we run this again, which prints the +messages in GitHub's format but with cross-platform failures deduplicated. +""" + +from __future__ import annotations + +import argparse +import pickle +import re +import sys + +import attrs + +# Example: 'package/filename.py:42:1:46:3: error: Type error here [code]' +report_re = re.compile( + r""" + ([^:]+): # Filename (anything but ":") + ([0-9]+): # Line number (start) + (?:([0-9]+): # Optional column number + (?:([0-9]+):([0-9]+):)? # then also optionally, 2 more numbers for end columns + )? 
+ \s*(error|warn|note): # Kind, prefixed with space + (.+) # Message + """, + re.VERBOSE, +) + +mypy_to_github = { + "error": "error", + "warn": "warning", + "note": "notice", +} + + [email protected](kw_only=True) +class Result: + """Accumulated results, used as a dict key to deduplicate.""" + + filename: str + start_line: int + kind: str + message: str + start_col: int | None = None + end_line: int | None = None + end_col: int | None = None + + +def process_line(line: str) -> Result | None: + if match := report_re.fullmatch(line.rstrip()): + filename, st_line, st_col, end_line, end_col, kind, message = match.groups() + return Result( + filename=filename, + start_line=int(st_line), + start_col=int(st_col) if st_col is not None else None, + end_line=int(end_line) if end_line is not None else None, + end_col=int(end_col) if end_col is not None else None, + kind=mypy_to_github[kind], + message=message, + ) + else: + return None + + +def export(results: dict[Result, list[str]]) -> None: + """Display the collected results.""" + for res, platforms in results.items(): + print(f"::{res.kind} file={res.filename},line={res.start_line},", end="") + if res.start_col is not None: + print(f"col={res.start_col},", end="") + if res.end_col is not None and res.end_line is not None: + print(f"endLine={res.end_line},endColumn={res.end_col},", end="") + message = f"({res.start_line}:{res.start_col} - {res.end_line}:{res.end_col}):{res.message}" + else: + message = f"({res.start_line}:{res.start_col}):{res.message}" + else: + message = f"{res.start_line}:{res.message}" + print(f"title=Mypy-{'+'.join(platforms)}::{res.filename}:{message}") + + +def main(argv: list[str]) -> None: + """Look for error messages, and convert the format.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--dumpfile", + help="File to write pickled messages to.", + required=True, + ) + parser.add_argument( + "--platform", + help="OS name, if set Mypy should be piped to 
stdin.", + default=None, + ) + cmd_line = parser.parse_args(argv) + + results: dict[Result, list[str]] + try: + with open(cmd_line.dumpfile, "rb") as f: + results = pickle.load(f) + except (FileNotFoundError, pickle.UnpicklingError): + # If we fail to load, assume it's an old result. + results = {} + + if cmd_line.platform is None: + # Write out the results. + export(results) + else: + platform: str = cmd_line.platform + for line in sys.stdin: + parsed = process_line(line) + if parsed is not None: + try: + results[parsed].append(platform) + except KeyError: + results[parsed] = [platform] + sys.stdout.write(line) + with open(cmd_line.dumpfile, "wb") as f: + pickle.dump(results, f) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/contrib/python/trio/trio/_tools/sync_requirements.py b/contrib/python/trio/trio/_tools/sync_requirements.py new file mode 100644 index 00000000000..43337e29dcd --- /dev/null +++ b/contrib/python/trio/trio/_tools/sync_requirements.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 + +"""Sync Requirements - Automatically upgrade test requirements pinned +versions from pre-commit config file.""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import TYPE_CHECKING + +from yaml import load as load_yaml + +if TYPE_CHECKING: + from collections.abc import Generator + + from yaml import CLoader as _CLoader, Loader as _Loader + + Loader: type[_CLoader | _Loader] + +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader + + +def yield_pre_commit_version_data( + pre_commit_text: str, +) -> Generator[tuple[str, str], None, None]: + """Yield (name, rev) tuples from pre-commit config file.""" + pre_commit_config = load_yaml(pre_commit_text, Loader) + for repo in pre_commit_config["repos"]: + if "repo" not in repo or "rev" not in repo: + continue + url = repo["repo"] + name = url.rsplit("/", 1)[-1] + rev = repo["rev"].removeprefix("v") + yield name, rev + + +def 
update_requirements( + requirements: Path, + version_data: dict[str, str], +) -> bool: + """Return if updated requirements file. + + Update requirements file to match versions in version_data.""" + changed = False + old_lines = requirements.read_text(encoding="utf-8").splitlines(True) + + with requirements.open("w", encoding="utf-8") as file: + for line in old_lines: + # If comment or not version mark line, ignore. + if line.startswith("#") or "==" not in line: + file.write(line) + continue + name, rest = line.split("==", 1) + # Maintain extra markers if they exist + old_version = rest.strip() + extra = "\n" + if ";" in rest: + old_version, extra = rest.split(";", 1) + old_version = old_version.strip() + extra = " ;" + extra + version = version_data.get(name) + # If does not exist, skip + if version is None: + file.write(line) + continue + # Otherwise might have changed + new_line = f"{name}=={version}{extra}" + if new_line != line: + if not changed: + changed = True + print("Changed test requirements version to match pre-commit") + print(f"{name}=={old_version} -> {name}=={version}") + file.write(new_line) + return changed + + +if __name__ == "__main__": + source_root = Path.cwd().absolute() + + # Double-check we found the right directory + assert (source_root / "LICENSE").exists() + pre_commit = source_root / ".pre-commit-config.yaml" + test_requirements = source_root / "test-requirements.txt" + + pre_commit_text = pre_commit.read_text(encoding="utf-8") + + # Get tool versions from pre-commit + # Get correct names + pre_commit_versions = { + name.removesuffix("-mirror").removesuffix("-pre-commit"): version + for name, version in yield_pre_commit_version_data(pre_commit_text) + } + changed = update_requirements(test_requirements, pre_commit_versions) + sys.exit(int(changed)) diff --git a/contrib/python/trio/trio/_tools/windows_ffi_build.py b/contrib/python/trio/trio/_tools/windows_ffi_build.py new file mode 100644 index 00000000000..a9a39410876 --- /dev/null +++ 
b/contrib/python/trio/trio/_tools/windows_ffi_build.py @@ -0,0 +1,220 @@ +# builder for CFFI out-of-line mode, for reduced import time. +# run this to generate `trio._core._generated_windows_ffi`. +import re + +import cffi + +LIB = """ +// https://msdn.microsoft.com/en-us/library/windows/desktop/aa383751(v=vs.85).aspx +typedef int BOOL; +typedef unsigned char BYTE; +typedef unsigned char UCHAR; +typedef BYTE BOOLEAN; +typedef void* PVOID; +typedef PVOID HANDLE; +typedef unsigned long DWORD; +typedef unsigned long ULONG; +typedef unsigned int NTSTATUS; +typedef unsigned long u_long; +typedef ULONG *PULONG; +typedef const void *LPCVOID; +typedef void *LPVOID; +typedef const wchar_t *LPCWSTR; +typedef DWORD* LPDWORD; + +typedef uintptr_t ULONG_PTR; +typedef uintptr_t UINT_PTR; + +typedef UINT_PTR SOCKET; + +typedef struct _OVERLAPPED { + ULONG_PTR Internal; + ULONG_PTR InternalHigh; + union { + struct { + DWORD Offset; + DWORD OffsetHigh; + } DUMMYSTRUCTNAME; + PVOID Pointer; + } DUMMYUNIONNAME; + + HANDLE hEvent; +} OVERLAPPED, *LPOVERLAPPED; + +typedef OVERLAPPED WSAOVERLAPPED; +typedef LPOVERLAPPED LPWSAOVERLAPPED; +typedef PVOID LPSECURITY_ATTRIBUTES; +typedef PVOID LPCSTR; + +typedef struct _OVERLAPPED_ENTRY { + ULONG_PTR lpCompletionKey; + LPOVERLAPPED lpOverlapped; + ULONG_PTR Internal; + DWORD dwNumberOfBytesTransferred; +} OVERLAPPED_ENTRY, *LPOVERLAPPED_ENTRY; + +// kernel32.dll +HANDLE WINAPI CreateIoCompletionPort( + _In_ HANDLE FileHandle, + _In_opt_ HANDLE ExistingCompletionPort, + _In_ ULONG_PTR CompletionKey, + _In_ DWORD NumberOfConcurrentThreads +); + +BOOL SetFileCompletionNotificationModes( + HANDLE FileHandle, + UCHAR Flags +); + +HANDLE CreateFileW( + LPCWSTR lpFileName, + DWORD dwDesiredAccess, + DWORD dwShareMode, + LPSECURITY_ATTRIBUTES lpSecurityAttributes, + DWORD dwCreationDisposition, + DWORD dwFlagsAndAttributes, + HANDLE hTemplateFile +); + +BOOL WINAPI CloseHandle( + _In_ HANDLE hObject +); + +BOOL WINAPI PostQueuedCompletionStatus( + 
_In_ HANDLE CompletionPort, + _In_ DWORD dwNumberOfBytesTransferred, + _In_ ULONG_PTR dwCompletionKey, + _In_opt_ LPOVERLAPPED lpOverlapped +); + +BOOL WINAPI GetQueuedCompletionStatusEx( + _In_ HANDLE CompletionPort, + _Out_ LPOVERLAPPED_ENTRY lpCompletionPortEntries, + _In_ ULONG ulCount, + _Out_ PULONG ulNumEntriesRemoved, + _In_ DWORD dwMilliseconds, + _In_ BOOL fAlertable +); + +BOOL WINAPI CancelIoEx( + _In_ HANDLE hFile, + _In_opt_ LPOVERLAPPED lpOverlapped +); + +BOOL WriteFile( + HANDLE hFile, + LPCVOID lpBuffer, + DWORD nNumberOfBytesToWrite, + LPDWORD lpNumberOfBytesWritten, + LPOVERLAPPED lpOverlapped +); + +BOOL ReadFile( + HANDLE hFile, + LPVOID lpBuffer, + DWORD nNumberOfBytesToRead, + LPDWORD lpNumberOfBytesRead, + LPOVERLAPPED lpOverlapped +); + +BOOL WINAPI SetConsoleCtrlHandler( + _In_opt_ void* HandlerRoutine, + _In_ BOOL Add +); + +HANDLE CreateEventA( + LPSECURITY_ATTRIBUTES lpEventAttributes, + BOOL bManualReset, + BOOL bInitialState, + LPCSTR lpName +); + +BOOL SetEvent( + HANDLE hEvent +); + +BOOL ResetEvent( + HANDLE hEvent +); + +DWORD WaitForSingleObject( + HANDLE hHandle, + DWORD dwMilliseconds +); + +DWORD WaitForMultipleObjects( + DWORD nCount, + HANDLE *lpHandles, + BOOL bWaitAll, + DWORD dwMilliseconds +); + +ULONG RtlNtStatusToDosError( + NTSTATUS Status +); + +int WSAIoctl( + SOCKET s, + DWORD dwIoControlCode, + LPVOID lpvInBuffer, + DWORD cbInBuffer, + LPVOID lpvOutBuffer, + DWORD cbOutBuffer, + LPDWORD lpcbBytesReturned, + LPWSAOVERLAPPED lpOverlapped, + // actually LPWSAOVERLAPPED_COMPLETION_ROUTINE + void* lpCompletionRoutine +); + +int WSAGetLastError(); + +BOOL DeviceIoControl( + HANDLE hDevice, + DWORD dwIoControlCode, + LPVOID lpInBuffer, + DWORD nInBufferSize, + LPVOID lpOutBuffer, + DWORD nOutBufferSize, + LPDWORD lpBytesReturned, + LPOVERLAPPED lpOverlapped +); + +// From https://github.com/piscisaureus/wepoll/blob/master/src/afd.h +typedef struct _AFD_POLL_HANDLE_INFO { + HANDLE Handle; + ULONG Events; + NTSTATUS 
Status; +} AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO; + +// This is really defined as a messy union to allow stuff like +// i.DUMMYSTRUCTNAME.LowPart, but we don't need those complications. +// Under all that it's just an int64. +typedef int64_t LARGE_INTEGER; + +typedef struct _AFD_POLL_INFO { + LARGE_INTEGER Timeout; + ULONG NumberOfHandles; + ULONG Exclusive; + AFD_POLL_HANDLE_INFO Handles[1]; +} AFD_POLL_INFO, *PAFD_POLL_INFO; + +""" + +# cribbed from pywincffi +# programmatically strips out those annotations MSDN likes, like _In_ +LIB = re.sub(r"\b(_In_|_Inout_|_Out_|_Outptr_|_Reserved_)(opt_)?\b", " ", LIB) + +# Other fixups: +# - get rid of FAR, cffi doesn't like it +LIB = re.sub(r"\bFAR\b", " ", LIB) +# - PASCAL is apparently an alias for __stdcall (on modern compilers - modern +# being _MSC_VER >= 800) +LIB = re.sub(r"\bPASCAL\b", "__stdcall", LIB) + +ffibuilder = cffi.FFI() +# a bit hacky but, it works +ffibuilder.set_source("trio._core._generated_windows_ffi", None) +ffibuilder.cdef(LIB) + +if __name__ == "__main__": + ffibuilder.compile("src") diff --git a/contrib/python/trio/trio/_unix_pipes.py b/contrib/python/trio/trio/_unix_pipes.py new file mode 100644 index 00000000000..dbe4358b4c5 --- /dev/null +++ b/contrib/python/trio/trio/_unix_pipes.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import errno +import os +import sys +from typing import TYPE_CHECKING + +import trio + +from ._abc import Stream +from ._util import ConflictDetector, final + +if TYPE_CHECKING: + from typing import Final as FinalType + +assert not TYPE_CHECKING or sys.platform != "win32" + +# XX TODO: is this a good number? who knows... it does match the default Linux +# pipe capacity though. +DEFAULT_RECEIVE_SIZE: FinalType = 65536 + + +class _FdHolder: + # This class holds onto a raw file descriptor, in non-blocking mode, and + # is responsible for managing its lifecycle. 
In particular, it's + # responsible for making sure it gets closed, and also for tracking + # whether it's been closed. + # + # The way we track closure is to set the .fd field to -1, discarding the + # original value. You might think that this is a strange idea, since it + # overloads the same field to do two different things. Wouldn't it be more + # natural to have a dedicated .closed field? But that would be more + # error-prone. Fds are represented by small integers, and once an fd is + # closed, its integer value may be reused immediately. If we accidentally + # used the old fd after being closed, we might end up doing something to + # another unrelated fd that happened to get assigned the same integer + # value. By throwing away the integer value immediately, it becomes + # impossible to make this mistake – we'll just get an EBADF. + # + # (This trick was copied from the stdlib socket module.) + fd: int + + def __init__(self, fd: int) -> None: + # make sure self.fd is always initialized to *something*, because even + # if we error out here then __del__ will run and access it. + self.fd = -1 + if not isinstance(fd, int): + raise TypeError("file descriptor must be an int") + self.fd = fd + # Store original state, and ensure non-blocking mode is enabled + self._original_is_blocking = os.get_blocking(fd) + os.set_blocking(fd, False) + + @property + def closed(self) -> bool: + return self.fd == -1 + + def _raw_close(self) -> None: + # This doesn't assume it's in a Trio context, so it can be called from + # __del__. You should never call it from Trio context, because it + # skips calling notify_fd_close. But from __del__, skipping that is + # OK, because notify_fd_close just wakes up other tasks that are + # waiting on this fd, and those tasks hold a reference to this object. + # So if __del__ is being called, we know there aren't any tasks that + # need to be woken. 
@final
class FdStream(Stream):
    """A Trio stream backed by a raw file descriptor such as a pipe or TTY.

    *fd* has to be open for reading and/or writing and must support
    non-blocking I/O; pipes and TTYs qualify, on-disk files most likely do
    not. The stream owns the descriptor from construction onwards: closing
    the stream closes the fd, and, as with `os.fdopen`, the fd should not be
    used directly once it has been wrapped.

    Trio needs the underlying open file to be in non-blocking mode, and that
    setting is shared by every descriptor referring to the same open file,
    including ones obtained through `os.dup` or inherited across `os.fork`
    by other threads or processes, which are unlikely to cope with suddenly
    non-blocking I/O. ``FdStream(os.dup(sys.stdin.fileno()))`` therefore
    does give you a stream for standard input, but only safely if stdin is
    not shared with any other process and no synchronous `sys.stdin` method
    is called until the stream is closed. `Issue #174
    <https://github.com/python-trio/trio/issues/174>`__ discusses what makes
    this restriction hard to relax.

    .. warning:: Because non-blocking mode applies to the whole open file
       description, a program whose standard streams are all attached to
       a TTY (as in a terminal emulator) will see all of them turn
       non-blocking as soon as an `FdStream` is built for any one of them.
       Wrapping standard input, for instance, can make Python loggers start
       failing with `BlockingIOError`.

    Args:
      fd (int): The fd to be wrapped.

    Returns:
      A new `FdStream` object.
    """

    def __init__(self, fd: int) -> None:
        self._fd_holder = _FdHolder(fd)
        # One detector per direction: a sender and a receiver may run
        # concurrently, but two senders (or two receivers) may not.
        self._send_conflict_detector = ConflictDetector(
            "another task is using this stream for send",
        )
        self._receive_conflict_detector = ConflictDetector(
            "another task is using this stream for receive",
        )

    async def send_all(self, data: bytes) -> None:
        """Write all of *data* to the fd, waiting for writability as needed.

        Raises `trio.ClosedResourceError` if the stream is closed (also
        reported when the OS answers ``EBADF``) and
        `trio.BrokenResourceError` for any other `OSError`.
        """
        with self._send_conflict_detector:
            holder = self._fd_holder
            # The closed check has to come first: send_all(b"") on a closed
            # pipe must still raise, and the loop below never runs for b"".
            if holder.closed:
                raise trio.ClosedResourceError("file was already closed")
            await trio.lowlevel.checkpoint()
            total = len(data)
            # adapted from the SocketStream code
            with memoryview(data) as whole:
                offset = 0
                while offset < total:
                    with whole[offset:] as pending:
                        try:
                            offset += os.write(holder.fd, pending)
                        except BlockingIOError:
                            await trio.lowlevel.wait_writable(holder.fd)
                        except OSError as exc:
                            if exc.errno != errno.EBADF:
                                raise trio.BrokenResourceError from exc
                            raise trio.ClosedResourceError(
                                "file was already closed",
                            ) from None

    async def wait_send_all_might_not_block(self) -> None:
        """Wait until the fd is reported writable."""
        with self._send_conflict_detector:
            holder = self._fd_holder
            if holder.closed:
                raise trio.ClosedResourceError("file was already closed")
            try:
                await trio.lowlevel.wait_writable(holder.fd)
            except BrokenPipeError as exc:
                # kqueue: raises EPIPE on wait_writable instead
                # of sending, which is annoying
                raise trio.BrokenResourceError from exc

    async def receive_some(self, max_bytes: int | None = None) -> bytes:
        """Read up to *max_bytes* bytes (default ``DEFAULT_RECEIVE_SIZE``).

        Raises `TypeError` / `ValueError` for a non-int or non-positive
        *max_bytes*, `trio.ClosedResourceError` when the OS reports
        ``EBADF``, and `trio.BrokenResourceError` for any other `OSError`.
        """
        with self._receive_conflict_detector:
            if max_bytes is None:
                max_bytes = DEFAULT_RECEIVE_SIZE
            elif not isinstance(max_bytes, int):
                raise TypeError("max_bytes must be integer >= 1")
            elif max_bytes < 1:
                raise ValueError("max_bytes must be integer >= 1")

            await trio.lowlevel.checkpoint()
            holder = self._fd_holder
            while True:
                try:
                    return os.read(holder.fd, max_bytes)
                except BlockingIOError:
                    # Nothing available yet; sleep until readable, then retry.
                    await trio.lowlevel.wait_readable(holder.fd)
                except OSError as exc:
                    if exc.errno != errno.EBADF:
                        raise trio.BrokenResourceError from exc
                    raise trio.ClosedResourceError(
                        "file was already closed",
                    ) from None

    def close(self) -> None:
        """Close the underlying fd holder."""
        self._fd_holder.close()

    async def aclose(self) -> None:
        """Close the stream, then execute a checkpoint."""
        self.close()
        await trio.lowlevel.checkpoint()

    def fileno(self) -> int:
        """Return the fd currently stored in the holder."""
        return self._fd_holder.fd
# See: #461 as to why this is needed.
# threading.main_thread() can lie to us if somebody else edits the threading
# ident cache to replace the main thread: threading.current_thread() then
# returns a _DummyThread, the C-c check fails, and so on.
# Trying to use signal outside the main thread fails, so that gives a
# reliable main-thread check that does not depend on a potentially modified
# threading module.
def is_main_thread() -> bool:
    """Attempt to reliably check if we are in the main thread.

    Re-installs the current SIGINT handler as a probe: returns ``True`` if
    that succeeds and ``False`` if it raises `TypeError` or `ValueError`.
    """
    try:
        current_handler = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, current_handler)
    except (TypeError, ValueError):
        return False
    return True
+ return value.__class__.__name__ in ("Future", "Deferred") + + # Make sure a sync-fn-that-returns-coroutine still sees itself as being + # in trio context + prev_loop, sniffio_loop.name = sniffio_loop.name, "trio" + + try: + coro = async_fn(*args) + + except TypeError: + # Give good error for: nursery.start_soon(trio.sleep(1)) + if isinstance(async_fn, collections.abc.Coroutine): + # explicitly close coroutine to avoid RuntimeWarning + async_fn.close() + + raise TypeError( + "Trio was expecting an async function, but instead it got " + f"a coroutine object {async_fn!r}\n" + "\n" + "Probably you did something like:\n" + "\n" + f" trio.run({async_fn.__name__}(...)) # incorrect!\n" + f" nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n" + "\n" + "Instead, you want (notice the parentheses!):\n" + "\n" + f" trio.run({async_fn.__name__}, ...) # correct!\n" + f" nursery.start_soon({async_fn.__name__}, ...) # correct!", + ) from None + + # Give good error for: nursery.start_soon(future) + if _return_value_looks_like_wrong_library(async_fn): + raise TypeError( + "Trio was expecting an async function, but instead it got " + f"{async_fn!r} – are you trying to use a library written for " + "asyncio/twisted/tornado or similar? That won't work " + "without some sort of compatibility shim.", + ) from None + + raise + + finally: + sniffio_loop.name = prev_loop + + # We can't check iscoroutinefunction(async_fn), because that will fail + # for things like functools.partial objects wrapping an async + # function. So we have to just call it and then check whether the + # return value is a coroutine object. + # Note: will not be necessary on python>=3.8, see https://bugs.python.org/issue34890 + # TODO: python3.7 support is now dropped, so the above can be addressed. 
+ if not isinstance(coro, collections.abc.Coroutine): + # Give good error for: nursery.start_soon(func_returning_future) + if _return_value_looks_like_wrong_library(coro): + raise TypeError( + f"Trio got unexpected {coro!r} – are you trying to use a " + "library written for asyncio/twisted/tornado or similar? " + "That won't work without some sort of compatibility shim.", + ) + + if inspect.isasyncgen(coro): + raise TypeError( + "start_soon expected an async function but got an async " + f"generator {coro!r}", + ) + + # Give good error for: nursery.start_soon(some_sync_fn) + raise TypeError( + "Trio expected an async function, but {!r} appears to be " + "synchronous".format(getattr(async_fn, "__qualname__", async_fn)), + ) + + return coro + + +class ConflictDetector: + """Detect when two tasks are about to perform operations that would + conflict. + + Use as a synchronous context manager; if two tasks enter it at the same + time then the second one raises an error. You can use it when there are + two pieces of code that *would* collide and need a lock if they ever were + called at the same time, but that should never happen. + + We use this in particular for things like, making sure that two different + tasks don't call sendall simultaneously on the same stream. 
def fixup_module_metadata(
    module_name: str,
    namespace: collections.abc.Mapping[str, object],
) -> None:
    """Make the public objects in *namespace* look like they live in *module_name*.

    For every entry of *namespace* whose key does not start with ``_``, and
    whose ``__module__`` starts with ``"trio."``, this rewrites
    ``__module__`` to *module_name* and, where present, ``__name__`` and
    ``__qualname__`` to the exported name. Classes are processed
    recursively through their ``__dict__``, so methods and nested classes
    are fixed up too.

    Args:
      module_name: The value to store in each object's ``__module__``.
      namespace: Mapping of exported name to object (for example a module's
        ``globals()``).
    """
    seen_ids: set[int] = set()

    def fix_one(qualname: str, name: str, obj: object) -> None:
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if mod is None or not mod.startswith("trio."):
            return

        obj.__module__ = module_name
        # Modules, unlike everything else in Python, put fully-qualified
        # names into their __name__ attribute. We check for "." to avoid
        # rewriting these.
        if hasattr(obj, "__name__") and "." not in obj.__name__:
            obj.__name__ = name
            if hasattr(obj, "__qualname__"):
                obj.__qualname__ = qualname
        if isinstance(obj, type):
            for attr_name, attr_value in obj.__dict__.items():
                # Extend the *current* qualname rather than the top-level
                # export name, so members of nested classes end up as
                # "Outer.Inner.attr" instead of "Outer.attr".
                fix_one(f"{qualname}.{attr_name}", attr_name, attr_value)

    for objname, obj in namespace.items():
        if not objname.startswith("_"):  # ignore private attributes
            fix_one(objname, objname, obj)
def name_asyncgen(agen: AsyncGeneratorType[object, NoReturn]) -> str:
    """Build a ``module.qualname`` label for the async generator *agen*.

    The module part comes from ``__name__`` in the globals of the
    generator's frame, falling back to ``<filename>`` from its code object
    when the frame or that key is unavailable. The name part is the
    generator's ``__qualname__``, falling back to the code object's
    ``co_name``. Objects without an ``ag_code`` attribute are described by
    their ``repr()`` instead.
    """
    if not hasattr(agen, "ag_code"):  # pragma: no cover
        return repr(agen)

    try:
        frame_globals = agen.ag_frame.f_globals
        owner = frame_globals["__name__"]
    except (AttributeError, KeyError):
        owner = f"<{agen.ag_code.co_filename}>"

    try:
        leaf = agen.__qualname__
    except AttributeError:
        leaf = agen.ag_code.co_name

    return f"{owner}.{leaf}"
def raise_single_exception_from_group(
    eg: BaseExceptionGroup[BaseException],
) -> NoReturn:
    """Re-raise the single meaningful exception contained in *eg*.

    The group is assumed to hold at most one non-cancelled exception, which
    is re-raised as a standalone exception. That exception may itself be an
    exception group, in which case it is not unwrapped any further.

    Only the direct members of ``eg.exceptions`` are inspected:

    * If any member is a :exc:`KeyboardInterrupt` or :exc:`SystemExit`, a
      new exception of the same type and with the same ``args`` is raised
      immediately, with the entire group as its cause.
    * Otherwise, if there is exactly one non-:exc:`Cancelled` member, it is
      re-raised.
    * Otherwise, if the group contains only :exc:`Cancelled`, the first one
      encountered is re-raised.

    The re-raised member keeps its own context and cause; the cause/context
    of the group itself is discarded.

    Raises:
      MultipleExceptionError: If more than one non-cancelled exception is
        encountered. The group is attached as the cause.
    """
    # immediately bail out if there's any KI or SystemExit
    for e in eg.exceptions:
        if isinstance(e, (KeyboardInterrupt, SystemExit)):
            raise type(e)(*e.args) from eg

    cancelled_exception: trio.Cancelled | None = None
    noncancelled_exception: BaseException | None = None

    for e in eg.exceptions:
        if isinstance(e, trio.Cancelled):
            # Remember only the first Cancelled; later ones are ignored.
            if cancelled_exception is None:
                cancelled_exception = e
        elif noncancelled_exception is None:
            noncancelled_exception = e
        else:
            # A second non-cancelled exception: there is no single
            # exception to unwrap to.
            raise MultipleExceptionError(
                "Attempted to unwrap exceptiongroup with multiple non-cancelled exceptions. This is often caused by a bug in the caller."
            ) from eg

    if noncancelled_exception is not None:
        raise_saving_context(noncancelled_exception)

    assert cancelled_exception is not None, "group can't be empty"
    raise_saving_context(cancelled_exception)
def WaitForMultipleObjects_sync(*handles: int | CData) -> None:
    """Wait for any of the given Windows handles to be signaled.

    This is a synchronous call: it copies *handles* into a native handle
    array and passes it to ``kernel32.WaitForMultipleObjects`` with an
    infinite timeout. If the call reports ``WAIT_FAILED``,
    ``raise_winerror()`` is invoked.
    """
    count = len(handles)
    native_handles = handle_array(count)
    for index, handle in enumerate(handles):
        native_handles[index] = handle
    infinite = 0xFFFFFFFF  # INFINITE
    status = kernel32.WaitForMultipleObjects(
        count,
        native_handles,
        False,
        infinite,
    )  # blocking
    if status == ErrorCodes.WAIT_FAILED:
        raise_winerror()
import _core +from ._abc import ReceiveStream, SendStream +from ._core._windows_cffi import _handle, kernel32, raise_winerror +from ._util import ConflictDetector, final + +assert sys.platform == "win32" or not TYPE_CHECKING + +# XX TODO: don't just make this up based on nothing. +DEFAULT_RECEIVE_SIZE = 65536 + + +# See the comments on _unix_pipes._FdHolder for discussion of why we set the +# handle to -1 when it's closed. +class _HandleHolder: + def __init__(self, handle: int) -> None: + self.handle = -1 + if not isinstance(handle, int): + raise TypeError("handle must be an int") + self.handle = handle + _core.register_with_iocp(self.handle) + + @property + def closed(self) -> bool: + return self.handle == -1 + + def close(self) -> None: + if self.closed: + return + handle = self.handle + self.handle = -1 + if not kernel32.CloseHandle(_handle(handle)): + raise_winerror() + + def __del__(self) -> None: + self.close() + + +@final +class PipeSendStream(SendStream): + """Represents a send stream over a Windows named pipe that has been + opened in OVERLAPPED mode. + """ + + def __init__(self, handle: int) -> None: + self._handle_holder = _HandleHolder(handle) + self._conflict_detector = ConflictDetector( + "another task is currently using this pipe", + ) + + async def send_all(self, data: bytes) -> None: + with self._conflict_detector: + if self._handle_holder.closed: + raise _core.ClosedResourceError("this pipe is already closed") + + if not data: + await _core.checkpoint() + return + + try: + written = await _core.write_overlapped(self._handle_holder.handle, data) + except BrokenPipeError as ex: + raise _core.BrokenResourceError from ex + # By my reading of MSDN, this assert is guaranteed to pass so long + # as the pipe isn't in nonblocking mode, but... let's just + # double-check. 
@final
class PipeReceiveStream(ReceiveStream):
    """Receive stream over the reading end of a pipe, given its handle."""

    def __init__(self, handle: int) -> None:
        self._handle_holder = _HandleHolder(handle)
        self._conflict_detector = ConflictDetector(
            "another task is currently using this pipe",
        )

    async def receive_some(self, max_bytes: int | None = None) -> bytes:
        """Read up to *max_bytes* bytes (default ``DEFAULT_RECEIVE_SIZE``).

        Returns ``b""`` when the read fails with `BrokenPipeError` while
        this stream is still open. Raises ``ClosedResourceError`` if the
        stream is closed before or during the read, and `TypeError` /
        `ValueError` for a non-int or non-positive *max_bytes*.
        """
        with self._conflict_detector:
            holder = self._handle_holder
            if holder.closed:
                raise _core.ClosedResourceError("this pipe is already closed")

            if max_bytes is None:
                max_bytes = DEFAULT_RECEIVE_SIZE
            elif not isinstance(max_bytes, int):
                raise TypeError("max_bytes must be integer >= 1")
            elif max_bytes < 1:
                raise ValueError("max_bytes must be integer >= 1")

            scratch = bytearray(max_bytes)
            try:
                received = await _core.readinto_overlapped(
                    holder.handle,
                    scratch,
                )
            except BrokenPipeError:
                if holder.closed:
                    raise _core.ClosedResourceError(
                        "another task closed this pipe",
                    ) from None

                # Windows raises BrokenPipeError on one end of a pipe
                # whenever the other end closes, regardless of direction.
                # Convert this to the Unix behavior of returning EOF to the
                # reader when the writer closes.
                #
                # Since no exception is raised here, a checkpoint is
                # required; readinto_overlapped raised, so it might not have
                # checkpointed for us, and we do it manually.
                await _core.checkpoint()
                return b""

            # Trim the unused tail so only the bytes actually read remain.
            del scratch[received:]
            return scratch

    def close(self) -> None:
        """Close the underlying handle holder."""
        self._handle_holder.close()

    async def aclose(self) -> None:
        """Close the stream, then execute a checkpoint."""
        self.close()
        await _core.checkpoint()
a/contrib/python/trio/trio/lowlevel.py b/contrib/python/trio/trio/lowlevel.py new file mode 100644 index 00000000000..b6621d47edd --- /dev/null +++ b/contrib/python/trio/trio/lowlevel.py @@ -0,0 +1,95 @@ +""" +This namespace represents low-level functionality not intended for daily use, +but useful for extending Trio's functionality. +""" + +# imports are renamed with leading underscores to indicate they are not part of the public API +import select as _select + +# static checkers don't understand if importing this as _sys, so it's deleted later +import sys +import typing as _t + +# Generally available symbols +from ._core import ( + Abort as Abort, + ParkingLot as ParkingLot, + ParkingLotStatistics as ParkingLotStatistics, + RaiseCancelT as RaiseCancelT, + RunStatistics as RunStatistics, + RunVar as RunVar, + RunVarToken as RunVarToken, + Task as Task, + TrioToken as TrioToken, + UnboundedQueue as UnboundedQueue, + UnboundedQueueStatistics as UnboundedQueueStatistics, + add_instrument as add_instrument, + add_parking_lot_breaker as add_parking_lot_breaker, + cancel_shielded_checkpoint as cancel_shielded_checkpoint, + checkpoint as checkpoint, + checkpoint_if_cancelled as checkpoint_if_cancelled, + current_clock as current_clock, + current_root_task as current_root_task, + current_statistics as current_statistics, + current_task as current_task, + current_trio_token as current_trio_token, + currently_ki_protected as currently_ki_protected, + disable_ki_protection as disable_ki_protection, + enable_ki_protection as enable_ki_protection, + in_trio_run as in_trio_run, + in_trio_task as in_trio_task, + notify_closing as notify_closing, + permanently_detach_coroutine_object as permanently_detach_coroutine_object, + reattach_detached_coroutine_object as reattach_detached_coroutine_object, + remove_instrument as remove_instrument, + remove_parking_lot_breaker as remove_parking_lot_breaker, + reschedule as reschedule, + spawn_system_task as spawn_system_task, + 
start_guest_run as start_guest_run, + start_thread_soon as start_thread_soon, + temporarily_detach_coroutine_object as temporarily_detach_coroutine_object, + wait_readable as wait_readable, + wait_task_rescheduled as wait_task_rescheduled, + wait_writable as wait_writable, +) +from ._subprocess import open_process as open_process + +# This is the union of a subset of trio/_core/ and some things from trio/*.py. +# See comments in trio/__init__.py for details. + +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) + +if sys.platform == "win32" or ( + not _t.TYPE_CHECKING and "sphinx.ext.autodoc" in sys.modules +): + # Windows symbols + from ._core import ( + current_iocp as current_iocp, + monitor_completion_key as monitor_completion_key, + readinto_overlapped as readinto_overlapped, + register_with_iocp as register_with_iocp, + wait_overlapped as wait_overlapped, + write_overlapped as write_overlapped, + ) + + # don't let documentation import the actual implementation + if sys.platform == "win32": # pragma: no branch + from ._wait_for_object import WaitForSingleObject as WaitForSingleObject + +if sys.platform != "win32" or ( + not _t.TYPE_CHECKING and "sphinx.ext.autodoc" in sys.modules +): + # Unix symbols + from ._unix_pipes import FdStream as FdStream + + # Kqueue-specific symbols + if ( + sys.platform != "linux" and (_t.TYPE_CHECKING or not hasattr(_select, "epoll")) + ) or (not _t.TYPE_CHECKING and "sphinx.ext.autodoc" in sys.modules): + from ._core import ( + current_kqueue as current_kqueue, + monitor_kevent as monitor_kevent, + wait_kevent as wait_kevent, + ) + +del sys diff --git a/contrib/python/trio/trio/py.typed b/contrib/python/trio/trio/py.typed new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/contrib/python/trio/trio/py.typed diff --git a/contrib/python/trio/trio/socket.py b/contrib/python/trio/trio/socket.py new file mode 100644 index 00000000000..cfcb9943c83 --- /dev/null +++ 
b/contrib/python/trio/trio/socket.py @@ -0,0 +1,602 @@ +from __future__ import annotations + +# This is a public namespace, so we don't want to expose any non-underscored +# attributes that aren't actually part of our public API. But it's very +# annoying to carefully always use underscored names for module-level +# temporaries, imports, etc. when implementing the module. So we put the +# implementation in an underscored module, and then re-export the public parts +# here. +# We still have some underscore names though but only a few. +import socket as _stdlib_socket + +# static checkers don't understand if importing this as _sys, so it's deleted later +import sys +import typing as _t + +from . import _socket + +_bad_symbols: set[str] = set() +if sys.platform == "win32": + # See https://github.com/python-trio/trio/issues/39 + # Do not import for windows platform + # (you can still get it from stdlib socket, of course, if you want it) + _bad_symbols.add("SO_REUSEADDR") + +# Dynamically re-export whatever constants this particular Python happens to +# have: +globals().update( + { + _name: getattr(_stdlib_socket, _name) + for _name in _stdlib_socket.__all__ + if _name.isupper() and _name not in _bad_symbols + }, +) + +# import the overwrites +from contextlib import suppress as _suppress + +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) +from ._socket import ( + SocketType as SocketType, + from_stdlib_socket as from_stdlib_socket, + fromfd as fromfd, + getaddrinfo as getaddrinfo, + getnameinfo as getnameinfo, + getprotobyname as getprotobyname, + set_custom_hostname_resolver as set_custom_hostname_resolver, + set_custom_socket_factory as set_custom_socket_factory, + socket as socket, + socketpair as socketpair, +) + +# not always available so expose only if +if sys.platform == "win32" or not _t.TYPE_CHECKING: + with _suppress(ImportError): + from ._socket import fromshare as fromshare + +# expose these functions to trio.socket +from 
socket import ( + gaierror as gaierror, + gethostname as gethostname, + herror as herror, + htonl as htonl, + htons as htons, + inet_aton as inet_aton, + inet_ntoa as inet_ntoa, + inet_ntop as inet_ntop, + inet_pton as inet_pton, + ntohs as ntohs, +) + +if sys.implementation.name == "cpython": + from socket import ( + if_indextoname as if_indextoname, + if_nametoindex as if_nametoindex, + ) + + # For android devices, if_nameindex support was introduced in API 24, + # so it doesn't exist for any version prior. + with _suppress(ImportError): + from socket import ( + if_nameindex as if_nameindex, + ) + + +# not always available so expose only if +if sys.platform != "win32" or not _t.TYPE_CHECKING: + with _suppress(ImportError): + from socket import ( + sethostname as sethostname, + ) + +if _t.TYPE_CHECKING: + IP_BIND_ADDRESS_NO_PORT: int +else: + try: + IP_BIND_ADDRESS_NO_PORT # noqa: B018 # "useless expression" + except NameError: + if sys.platform == "linux": + IP_BIND_ADDRESS_NO_PORT = 24 + +del sys + + +# The socket module exports a bunch of platform-specific constants. We want to +# re-export them. Since the exact set of constants varies depending on Python +# version, platform, the libc installed on the system where Python was built, +# etc., we figure out which constants to re-export dynamically at runtime (see +# above). But that confuses static analysis tools like jedi and mypy. So this +# import statement statically lists every constant that *could* be +# exported. There's a test in test_exports.py to make sure that the list is +# kept up to date. 
+if _t.TYPE_CHECKING: + from socket import ( # type: ignore[attr-defined] + AF_ALG as AF_ALG, + AF_APPLETALK as AF_APPLETALK, + AF_ASH as AF_ASH, + AF_ATMPVC as AF_ATMPVC, + AF_ATMSVC as AF_ATMSVC, + AF_AX25 as AF_AX25, + AF_BLUETOOTH as AF_BLUETOOTH, + AF_BRIDGE as AF_BRIDGE, + AF_CAN as AF_CAN, + AF_ECONET as AF_ECONET, + AF_HYPERV as AF_HYPERV, + AF_INET as AF_INET, + AF_INET6 as AF_INET6, + AF_IPX as AF_IPX, + AF_IRDA as AF_IRDA, + AF_KEY as AF_KEY, + AF_LINK as AF_LINK, + AF_LLC as AF_LLC, + AF_NETBEUI as AF_NETBEUI, + AF_NETLINK as AF_NETLINK, + AF_NETROM as AF_NETROM, + AF_PACKET as AF_PACKET, + AF_PPPOX as AF_PPPOX, + AF_QIPCRTR as AF_QIPCRTR, + AF_RDS as AF_RDS, + AF_ROSE as AF_ROSE, + AF_ROUTE as AF_ROUTE, + AF_SECURITY as AF_SECURITY, + AF_SNA as AF_SNA, + AF_SYSTEM as AF_SYSTEM, + AF_TIPC as AF_TIPC, + AF_UNIX as AF_UNIX, + AF_UNSPEC as AF_UNSPEC, + AF_VSOCK as AF_VSOCK, + AF_WANPIPE as AF_WANPIPE, + AF_X25 as AF_X25, + AI_ADDRCONFIG as AI_ADDRCONFIG, + AI_ALL as AI_ALL, + AI_CANONNAME as AI_CANONNAME, + AI_DEFAULT as AI_DEFAULT, + AI_MASK as AI_MASK, + AI_NUMERICHOST as AI_NUMERICHOST, + AI_NUMERICSERV as AI_NUMERICSERV, + AI_PASSIVE as AI_PASSIVE, + AI_V4MAPPED as AI_V4MAPPED, + AI_V4MAPPED_CFG as AI_V4MAPPED_CFG, + ALG_OP_DECRYPT as ALG_OP_DECRYPT, + ALG_OP_ENCRYPT as ALG_OP_ENCRYPT, + ALG_OP_SIGN as ALG_OP_SIGN, + ALG_OP_VERIFY as ALG_OP_VERIFY, + ALG_SET_AEAD_ASSOCLEN as ALG_SET_AEAD_ASSOCLEN, + ALG_SET_AEAD_AUTHSIZE as ALG_SET_AEAD_AUTHSIZE, + ALG_SET_IV as ALG_SET_IV, + ALG_SET_KEY as ALG_SET_KEY, + ALG_SET_OP as ALG_SET_OP, + ALG_SET_PUBKEY as ALG_SET_PUBKEY, + BDADDR_ANY as BDADDR_ANY, + BDADDR_LOCAL as BDADDR_LOCAL, + BTPROTO_HCI as BTPROTO_HCI, + BTPROTO_L2CAP as BTPROTO_L2CAP, + BTPROTO_RFCOMM as BTPROTO_RFCOMM, + BTPROTO_SCO as BTPROTO_SCO, + CAN_BCM as CAN_BCM, + CAN_BCM_CAN_FD_FRAME as CAN_BCM_CAN_FD_FRAME, + CAN_BCM_RX_ANNOUNCE_RESUME as CAN_BCM_RX_ANNOUNCE_RESUME, + CAN_BCM_RX_CHANGED as CAN_BCM_RX_CHANGED, + CAN_BCM_RX_CHECK_DLC as 
CAN_BCM_RX_CHECK_DLC, + CAN_BCM_RX_DELETE as CAN_BCM_RX_DELETE, + CAN_BCM_RX_FILTER_ID as CAN_BCM_RX_FILTER_ID, + CAN_BCM_RX_NO_AUTOTIMER as CAN_BCM_RX_NO_AUTOTIMER, + CAN_BCM_RX_READ as CAN_BCM_RX_READ, + CAN_BCM_RX_RTR_FRAME as CAN_BCM_RX_RTR_FRAME, + CAN_BCM_RX_SETUP as CAN_BCM_RX_SETUP, + CAN_BCM_RX_STATUS as CAN_BCM_RX_STATUS, + CAN_BCM_RX_TIMEOUT as CAN_BCM_RX_TIMEOUT, + CAN_BCM_SETTIMER as CAN_BCM_SETTIMER, + CAN_BCM_STARTTIMER as CAN_BCM_STARTTIMER, + CAN_BCM_TX_ANNOUNCE as CAN_BCM_TX_ANNOUNCE, + CAN_BCM_TX_COUNTEVT as CAN_BCM_TX_COUNTEVT, + CAN_BCM_TX_CP_CAN_ID as CAN_BCM_TX_CP_CAN_ID, + CAN_BCM_TX_DELETE as CAN_BCM_TX_DELETE, + CAN_BCM_TX_EXPIRED as CAN_BCM_TX_EXPIRED, + CAN_BCM_TX_READ as CAN_BCM_TX_READ, + CAN_BCM_TX_RESET_MULTI_IDX as CAN_BCM_TX_RESET_MULTI_IDX, + CAN_BCM_TX_SEND as CAN_BCM_TX_SEND, + CAN_BCM_TX_SETUP as CAN_BCM_TX_SETUP, + CAN_BCM_TX_STATUS as CAN_BCM_TX_STATUS, + CAN_EFF_FLAG as CAN_EFF_FLAG, + CAN_EFF_MASK as CAN_EFF_MASK, + CAN_ERR_FLAG as CAN_ERR_FLAG, + CAN_ERR_MASK as CAN_ERR_MASK, + CAN_ISOTP as CAN_ISOTP, + CAN_J1939 as CAN_J1939, + CAN_RAW as CAN_RAW, + CAN_RAW_ERR_FILTER as CAN_RAW_ERR_FILTER, + CAN_RAW_FD_FRAMES as CAN_RAW_FD_FRAMES, + CAN_RAW_FILTER as CAN_RAW_FILTER, + CAN_RAW_JOIN_FILTERS as CAN_RAW_JOIN_FILTERS, + CAN_RAW_LOOPBACK as CAN_RAW_LOOPBACK, + CAN_RAW_RECV_OWN_MSGS as CAN_RAW_RECV_OWN_MSGS, + CAN_RTR_FLAG as CAN_RTR_FLAG, + CAN_SFF_MASK as CAN_SFF_MASK, + CAPI as CAPI, + CMSG_LEN as CMSG_LEN, + CMSG_SPACE as CMSG_SPACE, + EAGAIN as EAGAIN, + EAI_ADDRFAMILY as EAI_ADDRFAMILY, + EAI_AGAIN as EAI_AGAIN, + EAI_BADFLAGS as EAI_BADFLAGS, + EAI_BADHINTS as EAI_BADHINTS, + EAI_FAIL as EAI_FAIL, + EAI_FAMILY as EAI_FAMILY, + EAI_MAX as EAI_MAX, + EAI_MEMORY as EAI_MEMORY, + EAI_NODATA as EAI_NODATA, + EAI_NONAME as EAI_NONAME, + EAI_OVERFLOW as EAI_OVERFLOW, + EAI_PROTOCOL as EAI_PROTOCOL, + EAI_SERVICE as EAI_SERVICE, + EAI_SOCKTYPE as EAI_SOCKTYPE, + EAI_SYSTEM as EAI_SYSTEM, + EBADF as EBADF, + ETH_P_ALL as 
ETH_P_ALL, + ETHERTYPE_ARP as ETHERTYPE_ARP, + ETHERTYPE_IP as ETHERTYPE_IP, + ETHERTYPE_IPV6 as ETHERTYPE_IPV6, + ETHERTYPE_VLAN as ETHERTYPE_VLAN, + EWOULDBLOCK as EWOULDBLOCK, + FD_ACCEPT as FD_ACCEPT, + FD_CLOSE as FD_CLOSE, + FD_CLOSE_BIT as FD_CLOSE_BIT, + FD_CONNECT as FD_CONNECT, + FD_CONNECT_BIT as FD_CONNECT_BIT, + FD_READ as FD_READ, + FD_SETSIZE as FD_SETSIZE, + FD_WRITE as FD_WRITE, + HCI_DATA_DIR as HCI_DATA_DIR, + HCI_FILTER as HCI_FILTER, + HCI_TIME_STAMP as HCI_TIME_STAMP, + HV_GUID_BROADCAST as HV_GUID_BROADCAST, + HV_GUID_CHILDREN as HV_GUID_CHILDREN, + HV_GUID_LOOPBACK as HV_GUID_LOOPBACK, + HV_GUID_PARENT as HV_GUID_PARENT, + HV_GUID_WILDCARD as HV_GUID_WILDCARD, + HV_GUID_ZERO as HV_GUID_ZERO, + HV_PROTOCOL_RAW as HV_PROTOCOL_RAW, + HVSOCKET_ADDRESS_FLAG_PASSTHRU as HVSOCKET_ADDRESS_FLAG_PASSTHRU, + HVSOCKET_CONNECT_TIMEOUT as HVSOCKET_CONNECT_TIMEOUT, + HVSOCKET_CONNECT_TIMEOUT_MAX as HVSOCKET_CONNECT_TIMEOUT_MAX, + HVSOCKET_CONNECTED_SUSPEND as HVSOCKET_CONNECTED_SUSPEND, + INADDR_ALLHOSTS_GROUP as INADDR_ALLHOSTS_GROUP, + INADDR_ANY as INADDR_ANY, + INADDR_BROADCAST as INADDR_BROADCAST, + INADDR_LOOPBACK as INADDR_LOOPBACK, + INADDR_MAX_LOCAL_GROUP as INADDR_MAX_LOCAL_GROUP, + INADDR_NONE as INADDR_NONE, + INADDR_UNSPEC_GROUP as INADDR_UNSPEC_GROUP, + INFINITE as INFINITE, + IOCTL_VM_SOCKETS_GET_LOCAL_CID as IOCTL_VM_SOCKETS_GET_LOCAL_CID, + IP_ADD_MEMBERSHIP as IP_ADD_MEMBERSHIP, + IP_ADD_SOURCE_MEMBERSHIP as IP_ADD_SOURCE_MEMBERSHIP, + IP_BLOCK_SOURCE as IP_BLOCK_SOURCE, + IP_DEFAULT_MULTICAST_LOOP as IP_DEFAULT_MULTICAST_LOOP, + IP_DEFAULT_MULTICAST_TTL as IP_DEFAULT_MULTICAST_TTL, + IP_DROP_MEMBERSHIP as IP_DROP_MEMBERSHIP, + IP_DROP_SOURCE_MEMBERSHIP as IP_DROP_SOURCE_MEMBERSHIP, + IP_FREEBIND as IP_FREEBIND, + IP_HDRINCL as IP_HDRINCL, + IP_MAX_MEMBERSHIPS as IP_MAX_MEMBERSHIPS, + IP_MULTICAST_IF as IP_MULTICAST_IF, + IP_MULTICAST_LOOP as IP_MULTICAST_LOOP, + IP_MULTICAST_TTL as IP_MULTICAST_TTL, + IP_OPTIONS as IP_OPTIONS, + 
IP_PKTINFO as IP_PKTINFO, + IP_RECVDSTADDR as IP_RECVDSTADDR, + IP_RECVERR as IP_RECVERR, + IP_RECVOPTS as IP_RECVOPTS, + IP_RECVORIGDSTADDR as IP_RECVORIGDSTADDR, + IP_RECVRETOPTS as IP_RECVRETOPTS, + IP_RECVTOS as IP_RECVTOS, + IP_RECVTTL as IP_RECVTTL, + IP_RETOPTS as IP_RETOPTS, + IP_TOS as IP_TOS, + IP_TRANSPARENT as IP_TRANSPARENT, + IP_TTL as IP_TTL, + IP_UNBLOCK_SOURCE as IP_UNBLOCK_SOURCE, + IPPORT_RESERVED as IPPORT_RESERVED, + IPPORT_USERRESERVED as IPPORT_USERRESERVED, + IPPROTO_AH as IPPROTO_AH, + IPPROTO_CBT as IPPROTO_CBT, + IPPROTO_DSTOPTS as IPPROTO_DSTOPTS, + IPPROTO_EGP as IPPROTO_EGP, + IPPROTO_EON as IPPROTO_EON, + IPPROTO_ESP as IPPROTO_ESP, + IPPROTO_FRAGMENT as IPPROTO_FRAGMENT, + IPPROTO_GGP as IPPROTO_GGP, + IPPROTO_GRE as IPPROTO_GRE, + IPPROTO_HELLO as IPPROTO_HELLO, + IPPROTO_HOPOPTS as IPPROTO_HOPOPTS, + IPPROTO_ICLFXBM as IPPROTO_ICLFXBM, + IPPROTO_ICMP as IPPROTO_ICMP, + IPPROTO_ICMPV6 as IPPROTO_ICMPV6, + IPPROTO_IDP as IPPROTO_IDP, + IPPROTO_IGMP as IPPROTO_IGMP, + IPPROTO_IGP as IPPROTO_IGP, + IPPROTO_IP as IPPROTO_IP, + IPPROTO_IPCOMP as IPPROTO_IPCOMP, + IPPROTO_IPIP as IPPROTO_IPIP, + IPPROTO_IPV4 as IPPROTO_IPV4, + IPPROTO_IPV6 as IPPROTO_IPV6, + IPPROTO_L2TP as IPPROTO_L2TP, + IPPROTO_MAX as IPPROTO_MAX, + IPPROTO_MOBILE as IPPROTO_MOBILE, + IPPROTO_MPTCP as IPPROTO_MPTCP, + IPPROTO_ND as IPPROTO_ND, + IPPROTO_NONE as IPPROTO_NONE, + IPPROTO_PGM as IPPROTO_PGM, + IPPROTO_PIM as IPPROTO_PIM, + IPPROTO_PUP as IPPROTO_PUP, + IPPROTO_RAW as IPPROTO_RAW, + IPPROTO_RDP as IPPROTO_RDP, + IPPROTO_ROUTING as IPPROTO_ROUTING, + IPPROTO_RSVP as IPPROTO_RSVP, + IPPROTO_SCTP as IPPROTO_SCTP, + IPPROTO_ST as IPPROTO_ST, + IPPROTO_TCP as IPPROTO_TCP, + IPPROTO_TP as IPPROTO_TP, + IPPROTO_UDP as IPPROTO_UDP, + IPPROTO_UDPLITE as IPPROTO_UDPLITE, + IPPROTO_XTP as IPPROTO_XTP, + IPV6_CHECKSUM as IPV6_CHECKSUM, + IPV6_DONTFRAG as IPV6_DONTFRAG, + IPV6_DSTOPTS as IPV6_DSTOPTS, + IPV6_HOPLIMIT as IPV6_HOPLIMIT, + IPV6_HOPOPTS as IPV6_HOPOPTS, + 
IPV6_JOIN_GROUP as IPV6_JOIN_GROUP, + IPV6_LEAVE_GROUP as IPV6_LEAVE_GROUP, + IPV6_MULTICAST_HOPS as IPV6_MULTICAST_HOPS, + IPV6_MULTICAST_IF as IPV6_MULTICAST_IF, + IPV6_MULTICAST_LOOP as IPV6_MULTICAST_LOOP, + IPV6_NEXTHOP as IPV6_NEXTHOP, + IPV6_PATHMTU as IPV6_PATHMTU, + IPV6_PKTINFO as IPV6_PKTINFO, + IPV6_RECVDSTOPTS as IPV6_RECVDSTOPTS, + IPV6_RECVERR as IPV6_RECVERR, + IPV6_RECVHOPLIMIT as IPV6_RECVHOPLIMIT, + IPV6_RECVHOPOPTS as IPV6_RECVHOPOPTS, + IPV6_RECVPATHMTU as IPV6_RECVPATHMTU, + IPV6_RECVPKTINFO as IPV6_RECVPKTINFO, + IPV6_RECVRTHDR as IPV6_RECVRTHDR, + IPV6_RECVTCLASS as IPV6_RECVTCLASS, + IPV6_RTHDR as IPV6_RTHDR, + IPV6_RTHDR_TYPE_0 as IPV6_RTHDR_TYPE_0, + IPV6_RTHDRDSTOPTS as IPV6_RTHDRDSTOPTS, + IPV6_TCLASS as IPV6_TCLASS, + IPV6_UNICAST_HOPS as IPV6_UNICAST_HOPS, + IPV6_USE_MIN_MTU as IPV6_USE_MIN_MTU, + IPV6_V6ONLY as IPV6_V6ONLY, + J1939_EE_INFO_NONE as J1939_EE_INFO_NONE, + J1939_EE_INFO_TX_ABORT as J1939_EE_INFO_TX_ABORT, + J1939_FILTER_MAX as J1939_FILTER_MAX, + J1939_IDLE_ADDR as J1939_IDLE_ADDR, + J1939_MAX_UNICAST_ADDR as J1939_MAX_UNICAST_ADDR, + J1939_NLA_BYTES_ACKED as J1939_NLA_BYTES_ACKED, + J1939_NLA_PAD as J1939_NLA_PAD, + J1939_NO_ADDR as J1939_NO_ADDR, + J1939_NO_NAME as J1939_NO_NAME, + J1939_NO_PGN as J1939_NO_PGN, + J1939_PGN_ADDRESS_CLAIMED as J1939_PGN_ADDRESS_CLAIMED, + J1939_PGN_ADDRESS_COMMANDED as J1939_PGN_ADDRESS_COMMANDED, + J1939_PGN_MAX as J1939_PGN_MAX, + J1939_PGN_PDU1_MAX as J1939_PGN_PDU1_MAX, + J1939_PGN_REQUEST as J1939_PGN_REQUEST, + LOCAL_PEERCRED as LOCAL_PEERCRED, + MSG_BCAST as MSG_BCAST, + MSG_CMSG_CLOEXEC as MSG_CMSG_CLOEXEC, + MSG_CONFIRM as MSG_CONFIRM, + MSG_CTRUNC as MSG_CTRUNC, + MSG_DONTROUTE as MSG_DONTROUTE, + MSG_DONTWAIT as MSG_DONTWAIT, + MSG_EOF as MSG_EOF, + MSG_EOR as MSG_EOR, + MSG_ERRQUEUE as MSG_ERRQUEUE, + MSG_FASTOPEN as MSG_FASTOPEN, + MSG_MCAST as MSG_MCAST, + MSG_MORE as MSG_MORE, + MSG_NOSIGNAL as MSG_NOSIGNAL, + MSG_NOTIFICATION as MSG_NOTIFICATION, + MSG_OOB as MSG_OOB, + 
MSG_PEEK as MSG_PEEK, + MSG_TRUNC as MSG_TRUNC, + MSG_WAITALL as MSG_WAITALL, + NETLINK_CRYPTO as NETLINK_CRYPTO, + NETLINK_DNRTMSG as NETLINK_DNRTMSG, + NETLINK_FIREWALL as NETLINK_FIREWALL, + NETLINK_IP6_FW as NETLINK_IP6_FW, + NETLINK_NFLOG as NETLINK_NFLOG, + NETLINK_ROUTE as NETLINK_ROUTE, + NETLINK_USERSOCK as NETLINK_USERSOCK, + NETLINK_XFRM as NETLINK_XFRM, + NI_DGRAM as NI_DGRAM, + NI_IDN as NI_IDN, + NI_MAXHOST as NI_MAXHOST, + NI_MAXSERV as NI_MAXSERV, + NI_NAMEREQD as NI_NAMEREQD, + NI_NOFQDN as NI_NOFQDN, + NI_NUMERICHOST as NI_NUMERICHOST, + NI_NUMERICSERV as NI_NUMERICSERV, + PACKET_BROADCAST as PACKET_BROADCAST, + PACKET_FASTROUTE as PACKET_FASTROUTE, + PACKET_HOST as PACKET_HOST, + PACKET_LOOPBACK as PACKET_LOOPBACK, + PACKET_MULTICAST as PACKET_MULTICAST, + PACKET_OTHERHOST as PACKET_OTHERHOST, + PACKET_OUTGOING as PACKET_OUTGOING, + PF_CAN as PF_CAN, + PF_PACKET as PF_PACKET, + PF_RDS as PF_RDS, + PF_SYSTEM as PF_SYSTEM, + POLLERR as POLLERR, + POLLHUP as POLLHUP, + POLLIN as POLLIN, + POLLMSG as POLLMSG, + POLLNVAL as POLLNVAL, + POLLOUT as POLLOUT, + POLLPRI as POLLPRI, + POLLRDBAND as POLLRDBAND, + POLLRDNORM as POLLRDNORM, + POLLWRNORM as POLLWRNORM, + RCVALL_MAX as RCVALL_MAX, + RCVALL_OFF as RCVALL_OFF, + RCVALL_ON as RCVALL_ON, + RCVALL_SOCKETLEVELONLY as RCVALL_SOCKETLEVELONLY, + SCM_CREDENTIALS as SCM_CREDENTIALS, + SCM_CREDS as SCM_CREDS, + SCM_J1939_DEST_ADDR as SCM_J1939_DEST_ADDR, + SCM_J1939_DEST_NAME as SCM_J1939_DEST_NAME, + SCM_J1939_ERRQUEUE as SCM_J1939_ERRQUEUE, + SCM_J1939_PRIO as SCM_J1939_PRIO, + SCM_RIGHTS as SCM_RIGHTS, + SHUT_RD as SHUT_RD, + SHUT_RDWR as SHUT_RDWR, + SHUT_WR as SHUT_WR, + SIO_KEEPALIVE_VALS as SIO_KEEPALIVE_VALS, + SIO_LOOPBACK_FAST_PATH as SIO_LOOPBACK_FAST_PATH, + SIO_RCVALL as SIO_RCVALL, + SIOCGIFINDEX as SIOCGIFINDEX, + SIOCGIFNAME as SIOCGIFNAME, + SO_ACCEPTCONN as SO_ACCEPTCONN, + SO_BINDTODEVICE as SO_BINDTODEVICE, + SO_BINDTOIFINDEX as SO_BINDTOIFINDEX, + SO_BROADCAST as SO_BROADCAST, + 
SO_BTH_ENCRYPT as SO_BTH_ENCRYPT, + SO_BTH_MTU as SO_BTH_MTU, + SO_BTH_MTU_MAX as SO_BTH_MTU_MAX, + SO_BTH_MTU_MIN as SO_BTH_MTU_MIN, + SO_DEBUG as SO_DEBUG, + SO_DOMAIN as SO_DOMAIN, + SO_DONTROUTE as SO_DONTROUTE, + SO_ERROR as SO_ERROR, + SO_EXCLUSIVEADDRUSE as SO_EXCLUSIVEADDRUSE, + SO_INCOMING_CPU as SO_INCOMING_CPU, + SO_J1939_ERRQUEUE as SO_J1939_ERRQUEUE, + SO_J1939_FILTER as SO_J1939_FILTER, + SO_J1939_PROMISC as SO_J1939_PROMISC, + SO_J1939_SEND_PRIO as SO_J1939_SEND_PRIO, + SO_KEEPALIVE as SO_KEEPALIVE, + SO_LINGER as SO_LINGER, + SO_MARK as SO_MARK, + SO_OOBINLINE as SO_OOBINLINE, + SO_ORIGINAL_DST as SO_ORIGINAL_DST, + SO_PASSCRED as SO_PASSCRED, + SO_PASSSEC as SO_PASSSEC, + SO_PEERCRED as SO_PEERCRED, + SO_PEERSEC as SO_PEERSEC, + SO_PRIORITY as SO_PRIORITY, + SO_PROTOCOL as SO_PROTOCOL, + SO_RCVBUF as SO_RCVBUF, + SO_RCVLOWAT as SO_RCVLOWAT, + SO_RCVTIMEO as SO_RCVTIMEO, + SO_REUSEADDR as SO_REUSEADDR, + SO_REUSEPORT as SO_REUSEPORT, + SO_SETFIB as SO_SETFIB, + SO_SNDBUF as SO_SNDBUF, + SO_SNDLOWAT as SO_SNDLOWAT, + SO_SNDTIMEO as SO_SNDTIMEO, + SO_TYPE as SO_TYPE, + SO_USELOOPBACK as SO_USELOOPBACK, + SO_VM_SOCKETS_BUFFER_MAX_SIZE as SO_VM_SOCKETS_BUFFER_MAX_SIZE, + SO_VM_SOCKETS_BUFFER_MIN_SIZE as SO_VM_SOCKETS_BUFFER_MIN_SIZE, + SO_VM_SOCKETS_BUFFER_SIZE as SO_VM_SOCKETS_BUFFER_SIZE, + SOCK_CLOEXEC as SOCK_CLOEXEC, + SOCK_DGRAM as SOCK_DGRAM, + SOCK_NONBLOCK as SOCK_NONBLOCK, + SOCK_RAW as SOCK_RAW, + SOCK_RDM as SOCK_RDM, + SOCK_SEQPACKET as SOCK_SEQPACKET, + SOCK_STREAM as SOCK_STREAM, + SOL_ALG as SOL_ALG, + SOL_CAN_BASE as SOL_CAN_BASE, + SOL_CAN_RAW as SOL_CAN_RAW, + SOL_HCI as SOL_HCI, + SOL_IP as SOL_IP, + SOL_RDS as SOL_RDS, + SOL_RFCOMM as SOL_RFCOMM, + SOL_SOCKET as SOL_SOCKET, + SOL_TCP as SOL_TCP, + SOL_TIPC as SOL_TIPC, + SOL_UDP as SOL_UDP, + SOMAXCONN as SOMAXCONN, + SYSPROTO_CONTROL as SYSPROTO_CONTROL, + TCP_CC_INFO as TCP_CC_INFO, + TCP_CONGESTION as TCP_CONGESTION, + TCP_CONNECTION_INFO as TCP_CONNECTION_INFO, + TCP_CORK as 
TCP_CORK, + TCP_DEFER_ACCEPT as TCP_DEFER_ACCEPT, + TCP_FASTOPEN as TCP_FASTOPEN, + TCP_FASTOPEN_CONNECT as TCP_FASTOPEN_CONNECT, + TCP_FASTOPEN_KEY as TCP_FASTOPEN_KEY, + TCP_FASTOPEN_NO_COOKIE as TCP_FASTOPEN_NO_COOKIE, + TCP_INFO as TCP_INFO, + TCP_INQ as TCP_INQ, + TCP_KEEPALIVE as TCP_KEEPALIVE, + TCP_KEEPCNT as TCP_KEEPCNT, + TCP_KEEPIDLE as TCP_KEEPIDLE, + TCP_KEEPINTVL as TCP_KEEPINTVL, + TCP_LINGER2 as TCP_LINGER2, + TCP_MAXSEG as TCP_MAXSEG, + TCP_MD5SIG as TCP_MD5SIG, + TCP_MD5SIG_EXT as TCP_MD5SIG_EXT, + TCP_NODELAY as TCP_NODELAY, + TCP_NOTSENT_LOWAT as TCP_NOTSENT_LOWAT, + TCP_QUEUE_SEQ as TCP_QUEUE_SEQ, + TCP_QUICKACK as TCP_QUICKACK, + TCP_REPAIR as TCP_REPAIR, + TCP_REPAIR_OPTIONS as TCP_REPAIR_OPTIONS, + TCP_REPAIR_QUEUE as TCP_REPAIR_QUEUE, + TCP_REPAIR_WINDOW as TCP_REPAIR_WINDOW, + TCP_SAVE_SYN as TCP_SAVE_SYN, + TCP_SAVED_SYN as TCP_SAVED_SYN, + TCP_SYNCNT as TCP_SYNCNT, + TCP_THIN_DUPACK as TCP_THIN_DUPACK, + TCP_THIN_LINEAR_TIMEOUTS as TCP_THIN_LINEAR_TIMEOUTS, + TCP_TIMESTAMP as TCP_TIMESTAMP, + TCP_TX_DELAY as TCP_TX_DELAY, + TCP_ULP as TCP_ULP, + TCP_USER_TIMEOUT as TCP_USER_TIMEOUT, + TCP_WINDOW_CLAMP as TCP_WINDOW_CLAMP, + TCP_ZEROCOPY_RECEIVE as TCP_ZEROCOPY_RECEIVE, + TIPC_ADDR_ID as TIPC_ADDR_ID, + TIPC_ADDR_NAME as TIPC_ADDR_NAME, + TIPC_ADDR_NAMESEQ as TIPC_ADDR_NAMESEQ, + TIPC_CFG_SRV as TIPC_CFG_SRV, + TIPC_CLUSTER_SCOPE as TIPC_CLUSTER_SCOPE, + TIPC_CONN_TIMEOUT as TIPC_CONN_TIMEOUT, + TIPC_CRITICAL_IMPORTANCE as TIPC_CRITICAL_IMPORTANCE, + TIPC_DEST_DROPPABLE as TIPC_DEST_DROPPABLE, + TIPC_HIGH_IMPORTANCE as TIPC_HIGH_IMPORTANCE, + TIPC_IMPORTANCE as TIPC_IMPORTANCE, + TIPC_LOW_IMPORTANCE as TIPC_LOW_IMPORTANCE, + TIPC_MEDIUM_IMPORTANCE as TIPC_MEDIUM_IMPORTANCE, + TIPC_NODE_SCOPE as TIPC_NODE_SCOPE, + TIPC_PUBLISHED as TIPC_PUBLISHED, + TIPC_SRC_DROPPABLE as TIPC_SRC_DROPPABLE, + TIPC_SUB_CANCEL as TIPC_SUB_CANCEL, + TIPC_SUB_PORTS as TIPC_SUB_PORTS, + TIPC_SUB_SERVICE as TIPC_SUB_SERVICE, + TIPC_SUBSCR_TIMEOUT as 
TIPC_SUBSCR_TIMEOUT, + TIPC_TOP_SRV as TIPC_TOP_SRV, + TIPC_WAIT_FOREVER as TIPC_WAIT_FOREVER, + TIPC_WITHDRAWN as TIPC_WITHDRAWN, + TIPC_ZONE_SCOPE as TIPC_ZONE_SCOPE, + UDPLITE_RECV_CSCOV as UDPLITE_RECV_CSCOV, + UDPLITE_SEND_CSCOV as UDPLITE_SEND_CSCOV, + VM_SOCKETS_INVALID_VERSION as VM_SOCKETS_INVALID_VERSION, + VMADDR_CID_ANY as VMADDR_CID_ANY, + VMADDR_CID_HOST as VMADDR_CID_HOST, + VMADDR_CID_LOCAL as VMADDR_CID_LOCAL, + VMADDR_PORT_ANY as VMADDR_PORT_ANY, + WSA_FLAG_OVERLAPPED as WSA_FLAG_OVERLAPPED, + WSA_INVALID_HANDLE as WSA_INVALID_HANDLE, + WSA_INVALID_PARAMETER as WSA_INVALID_PARAMETER, + WSA_IO_INCOMPLETE as WSA_IO_INCOMPLETE, + WSA_IO_PENDING as WSA_IO_PENDING, + WSA_NOT_ENOUGH_MEMORY as WSA_NOT_ENOUGH_MEMORY, + WSA_OPERATION_ABORTED as WSA_OPERATION_ABORTED, + WSA_WAIT_FAILED as WSA_WAIT_FAILED, + WSA_WAIT_TIMEOUT as WSA_WAIT_TIMEOUT, + ) diff --git a/contrib/python/trio/trio/testing/__init__.py b/contrib/python/trio/trio/testing/__init__.py new file mode 100644 index 00000000000..d93d33aab7d --- /dev/null +++ b/contrib/python/trio/trio/testing/__init__.py @@ -0,0 +1,39 @@ +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) + +from .._core import ( + MockClock as MockClock, + wait_all_tasks_blocked as wait_all_tasks_blocked, +) +from .._threads import ( + active_thread_count as active_thread_count, + wait_all_threads_completed as wait_all_threads_completed, +) +from .._util import fixup_module_metadata +from ._check_streams import ( + check_half_closeable_stream as check_half_closeable_stream, + check_one_way_stream as check_one_way_stream, + check_two_way_stream as check_two_way_stream, +) +from ._checkpoints import ( + assert_checkpoints as assert_checkpoints, + assert_no_checkpoints as assert_no_checkpoints, +) +from ._memory_streams import ( + MemoryReceiveStream as MemoryReceiveStream, + MemorySendStream as MemorySendStream, + lockstep_stream_one_way_pair as lockstep_stream_one_way_pair, + 
lockstep_stream_pair as lockstep_stream_pair, + memory_stream_one_way_pair as memory_stream_one_way_pair, + memory_stream_pair as memory_stream_pair, + memory_stream_pump as memory_stream_pump, +) +from ._network import open_stream_to_socket_listener as open_stream_to_socket_listener +from ._raises_group import Matcher as Matcher, RaisesGroup as RaisesGroup +from ._sequencer import Sequencer as Sequencer +from ._trio_test import trio_test as trio_test + +################################################################ + + +fixup_module_metadata(__name__, globals()) +del fixup_module_metadata diff --git a/contrib/python/trio/trio/testing/_check_streams.py b/contrib/python/trio/trio/testing/_check_streams.py new file mode 100644 index 00000000000..e58e2ddfed2 --- /dev/null +++ b/contrib/python/trio/trio/testing/_check_streams.py @@ -0,0 +1,569 @@ +# Generic stream tests +from __future__ import annotations + +import random +import sys +from collections.abc import Awaitable, Callable, Generator +from contextlib import contextmanager, suppress +from typing import ( + TYPE_CHECKING, + Generic, + TypeVar, +) + +from .. 
import CancelScope, _core +from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream +from .._highlevel_generic import aclose_forcefully +from ._checkpoints import assert_checkpoints + +if TYPE_CHECKING: + from types import TracebackType + + from typing_extensions import ParamSpec, TypeAlias + + ArgsT = ParamSpec("ArgsT") + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + +Res1 = TypeVar("Res1", bound=AsyncResource) +Res2 = TypeVar("Res2", bound=AsyncResource) +StreamMaker: TypeAlias = Callable[[], Awaitable[tuple[Res1, Res2]]] + + +class _ForceCloseBoth(Generic[Res1, Res2]): + def __init__(self, both: tuple[Res1, Res2]) -> None: + self._first, self._second = both + + async def __aenter__(self) -> tuple[Res1, Res2]: + return self._first, self._second + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + try: + await aclose_forcefully(self._first) + finally: + await aclose_forcefully(self._second) + + +# This is used in this file instead of pytest.raises in order to avoid a dependency +# on pytest, as the check_* functions are publicly exported. 
+@contextmanager +def _assert_raises( + expected_exc: type[BaseException], + wrapped: bool = False, +) -> Generator[None, None, None]: + __tracebackhide__ = True + try: + yield + except BaseExceptionGroup as exc: + assert wrapped, "caught exceptiongroup, but expected an unwrapped exception" + # assert in except block ignored below + assert len(exc.exceptions) == 1 # noqa: PT017 + assert isinstance(exc.exceptions[0], expected_exc) # noqa: PT017 + except expected_exc: + assert not wrapped, "caught exception, but expected an exceptiongroup" + else: + raise AssertionError(f"expected exception: {expected_exc}") + + +async def check_one_way_stream( + stream_maker: StreamMaker[SendStream, ReceiveStream], + clogged_stream_maker: StreamMaker[SendStream, ReceiveStream] | None, +) -> None: + """Perform a number of generic tests on a custom one-way stream + implementation. + + Args: + stream_maker: An async (!) function which returns a connected + (:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`) + pair. + clogged_stream_maker: Either None, or an async function similar to + stream_maker, but with the extra property that the returned stream + is in a state where ``send_all`` and + ``wait_send_all_might_not_block`` will block until ``receive_some`` + has been called. This allows for more thorough testing of some edge + cases, especially around ``wait_send_all_might_not_block``. + + Raises: + AssertionError: if a test fails. + + """ + async with _ForceCloseBoth(await stream_maker()) as (s, r): + assert isinstance(s, SendStream) + assert isinstance(r, ReceiveStream) + + async def do_send_all(data: bytes | bytearray | memoryview) -> None: + with assert_checkpoints(): # We're testing that it doesn't return anything. 
+ assert await s.send_all(data) is None # type: ignore[func-returns-value] + + async def do_receive_some(max_bytes: int | None = None) -> bytes | bytearray: + with assert_checkpoints(): + return await r.receive_some(max_bytes) + + async def checked_receive_1(expected: bytes) -> None: + assert await do_receive_some(1) == expected + + async def do_aclose(resource: AsyncResource) -> None: + with assert_checkpoints(): + await resource.aclose() + + # Simple sending/receiving + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all, b"x") + nursery.start_soon(checked_receive_1, b"x") + + async def send_empty_then_y() -> None: + # Streams should tolerate sending b"" without giving it any + # special meaning. + await do_send_all(b"") + await do_send_all(b"y") + + async with _core.open_nursery() as nursery: + nursery.start_soon(send_empty_then_y) + nursery.start_soon(checked_receive_1, b"y") + + # ---- Checking various argument types ---- + + # send_all accepts bytearray and memoryview + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all, bytearray(b"1")) + nursery.start_soon(checked_receive_1, b"1") + + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all, memoryview(b"2")) + nursery.start_soon(checked_receive_1, b"2") + + # max_bytes must be a positive integer + with _assert_raises(ValueError): + await r.receive_some(-1) + with _assert_raises(ValueError): + await r.receive_some(0) + with _assert_raises(TypeError): + await r.receive_some(1.5) # type: ignore[arg-type] + # it can also be missing or None + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all, b"x") + assert await do_receive_some() == b"x" + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all, b"x") + assert await do_receive_some(None) == b"x" + + with _assert_raises(_core.BusyResourceError, wrapped=True): + async with _core.open_nursery() as nursery: + nursery.start_soon(do_receive_some, 1) 
+ nursery.start_soon(do_receive_some, 1) + + # Method always has to exist, and an empty stream with a blocked + # receive_some should *always* allow send_all. (Technically it's legal + # for send_all to wait until receive_some is called to run, though; a + # stream doesn't *have* to have any internal buffering. That's why we + # start a concurrent receive_some call, then cancel it.) + async def simple_check_wait_send_all_might_not_block( + scope: CancelScope, + ) -> None: + with assert_checkpoints(): + await s.wait_send_all_might_not_block() + scope.cancel() + + async with _core.open_nursery() as nursery: + nursery.start_soon( + simple_check_wait_send_all_might_not_block, + nursery.cancel_scope, + ) + nursery.start_soon(do_receive_some, 1) + + # closing the r side leads to BrokenResourceError on the s side + # (eventually) + async def expect_broken_stream_on_send() -> None: + with _assert_raises(_core.BrokenResourceError): + while True: + await do_send_all(b"x" * 100) + + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_broken_stream_on_send) + nursery.start_soon(do_aclose, r) + + # once detected, the stream stays broken + with _assert_raises(_core.BrokenResourceError): + await do_send_all(b"x" * 100) + + # r closed -> ClosedResourceError on the receive side + with _assert_raises(_core.ClosedResourceError): + await do_receive_some(4096) + + # we can close the same stream repeatedly, it's fine + await do_aclose(r) + await do_aclose(r) + + # closing the sender side + await do_aclose(s) + + # now trying to send raises ClosedResourceError + with _assert_raises(_core.ClosedResourceError): + await do_send_all(b"x" * 100) + + # even if it's an empty send + with _assert_raises(_core.ClosedResourceError): + await do_send_all(b"") + + # ditto for wait_send_all_might_not_block + with _assert_raises(_core.ClosedResourceError): + with assert_checkpoints(): + await s.wait_send_all_might_not_block() + + # and again, repeated closing is fine + await 
do_aclose(s) + await do_aclose(s) + + async with _ForceCloseBoth(await stream_maker()) as (s, r): + # if send-then-graceful-close, receiver gets data then b"" + async def send_then_close() -> None: + await do_send_all(b"y") + await do_aclose(s) + + async def receive_send_then_close() -> None: + # We want to make sure that if the sender closes the stream before + # we read anything, then we still get all the data. But some + # streams might block on the do_send_all call. So we let the + # sender get as far as it can, then we receive. + await _core.wait_all_tasks_blocked() + await checked_receive_1(b"y") + await checked_receive_1(b"") + await do_aclose(r) + + async with _core.open_nursery() as nursery: + nursery.start_soon(send_then_close) + nursery.start_soon(receive_send_then_close) + + async with _ForceCloseBoth(await stream_maker()) as (s, r): + await aclose_forcefully(r) + + with _assert_raises(_core.BrokenResourceError): + while True: + await do_send_all(b"x" * 100) + + with _assert_raises(_core.ClosedResourceError): + await do_receive_some(4096) + + async with _ForceCloseBoth(await stream_maker()) as (s, r): + await aclose_forcefully(s) + + with _assert_raises(_core.ClosedResourceError): + await do_send_all(b"123") + + # after the sender does a forceful close, the receiver might either + # get BrokenResourceError or a clean b""; either is OK. Not OK would be + # if it freezes, or returns data. 
+ with suppress(_core.BrokenResourceError): + await checked_receive_1(b"") + + # cancelled aclose still closes + async with _ForceCloseBoth(await stream_maker()) as (s, r): + with _core.CancelScope() as scope: + scope.cancel() + await r.aclose() + + with _core.CancelScope() as scope: + scope.cancel() + await s.aclose() + + with _assert_raises(_core.ClosedResourceError): + await do_send_all(b"123") + + with _assert_raises(_core.ClosedResourceError): + await do_receive_some(4096) + + # Check that we can still gracefully close a stream after an operation has + # been cancelled. This can be challenging if cancellation can leave the + # stream internals in an inconsistent state, e.g. for + # SSLStream. Unfortunately this test isn't very thorough; the really + # challenging case for something like SSLStream is it gets cancelled + # *while* it's sending data on the underlying, not before. But testing + # that requires some special-case handling of the particular stream setup; + # we can't do it here. Maybe we could do a bit better with + # https://github.com/python-trio/trio/issues/77 + async with _ForceCloseBoth(await stream_maker()) as (s, r): + + async def expect_cancelled( + afn: Callable[ArgsT, Awaitable[object]], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, + ) -> None: + with _assert_raises(_core.Cancelled): + await afn(*args, **kwargs) + + with _core.CancelScope() as scope: + scope.cancel() + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_cancelled, do_send_all, b"x") + nursery.start_soon(expect_cancelled, do_receive_some, 1) + + async with _core.open_nursery() as nursery: + nursery.start_soon(do_aclose, s) + nursery.start_soon(do_aclose, r) + + # Check that if a task is blocked in receive_some, then closing the + # receive stream causes it to wake up. 
+ async with _ForceCloseBoth(await stream_maker()) as (s, r): + + async def receive_expecting_closed() -> None: + with _assert_raises(_core.ClosedResourceError): + await r.receive_some(10) + + async with _core.open_nursery() as nursery: + nursery.start_soon(receive_expecting_closed) + await _core.wait_all_tasks_blocked() + await aclose_forcefully(r) + + # check wait_send_all_might_not_block, if we can + if clogged_stream_maker is not None: + async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): + record: list[str] = [] + + async def waiter(cancel_scope: CancelScope) -> None: + record.append("waiter sleeping") + with assert_checkpoints(): + await s.wait_send_all_might_not_block() + record.append("waiter wokeup") + cancel_scope.cancel() + + async def receiver() -> None: + # give wait_send_all_might_not_block a chance to block + await _core.wait_all_tasks_blocked() + record.append("receiver starting") + while True: + await r.receive_some(16834) + + async with _core.open_nursery() as nursery: + nursery.start_soon(waiter, nursery.cancel_scope) + await _core.wait_all_tasks_blocked() + nursery.start_soon(receiver) + + assert record == [ + "waiter sleeping", + "receiver starting", + "waiter wokeup", + ] + + async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): + # simultaneous wait_send_all_might_not_block fails + with _assert_raises(_core.BusyResourceError, wrapped=True): + async with _core.open_nursery() as nursery: + nursery.start_soon(s.wait_send_all_might_not_block) + nursery.start_soon(s.wait_send_all_might_not_block) + + # and simultaneous send_all and wait_send_all_might_not_block (NB + # this test might destroy the stream b/c we end up cancelling + # send_all and e.g. 
SSLStream can't handle that, so we have to + # recreate afterwards) + with _assert_raises(_core.BusyResourceError, wrapped=True): + async with _core.open_nursery() as nursery: + nursery.start_soon(s.wait_send_all_might_not_block) + nursery.start_soon(s.send_all, b"123") + + async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): + # send_all and send_all blocked simultaneously should also raise + # (but again this might destroy the stream) + with _assert_raises(_core.BusyResourceError, wrapped=True): + async with _core.open_nursery() as nursery: + nursery.start_soon(s.send_all, b"123") + nursery.start_soon(s.send_all, b"123") + + # closing the receiver causes wait_send_all_might_not_block to return, + # with or without an exception + async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): + + async def sender() -> None: + try: + with assert_checkpoints(): + await s.wait_send_all_might_not_block() + except _core.BrokenResourceError: # pragma: no cover + pass + + async def receiver() -> None: + await _core.wait_all_tasks_blocked() + await aclose_forcefully(r) + + async with _core.open_nursery() as nursery: + nursery.start_soon(sender) + nursery.start_soon(receiver) + + # and again with the call starting after the close + async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): + await aclose_forcefully(r) + try: + with assert_checkpoints(): + await s.wait_send_all_might_not_block() + except _core.BrokenResourceError: # pragma: no cover + pass + + # Check that if a task is blocked in a send-side method, then closing + # the send stream causes it to wake up. 
+ async def close_soon(s: SendStream) -> None: + await _core.wait_all_tasks_blocked() + await aclose_forcefully(s) + + async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): + async with _core.open_nursery() as nursery: + nursery.start_soon(close_soon, s) + with _assert_raises(_core.ClosedResourceError): + await s.send_all(b"xyzzy") + + async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): + async with _core.open_nursery() as nursery: + nursery.start_soon(close_soon, s) + with _assert_raises(_core.ClosedResourceError): + await s.wait_send_all_might_not_block() + + +async def check_two_way_stream( + stream_maker: StreamMaker[Stream, Stream], + clogged_stream_maker: StreamMaker[Stream, Stream] | None, +) -> None: + """Perform a number of generic tests on a custom two-way stream + implementation. + + This is similar to :func:`check_one_way_stream`, except that the maker + functions are expected to return objects implementing the + :class:`~trio.abc.Stream` interface. + + This function tests a *superset* of what :func:`check_one_way_stream` + checks – if you call this, then you don't need to also call + :func:`check_one_way_stream`. 
+ + """ + await check_one_way_stream(stream_maker, clogged_stream_maker) + + async def flipped_stream_maker() -> tuple[Stream, Stream]: + return (await stream_maker())[::-1] + + flipped_clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None + + if clogged_stream_maker is not None: + + async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]: + return (await clogged_stream_maker())[::-1] + + else: + flipped_clogged_stream_maker = None + await check_one_way_stream(flipped_stream_maker, flipped_clogged_stream_maker) + + async with _ForceCloseBoth(await stream_maker()) as (s1, s2): + assert isinstance(s1, Stream) + assert isinstance(s2, Stream) + + # Duplex can be a bit tricky, might as well check it as well + DUPLEX_TEST_SIZE = 2**20 + CHUNK_SIZE_MAX = 2**14 + + r = random.Random(0) + i = r.getrandbits(8 * DUPLEX_TEST_SIZE) + test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little") + + async def sender( + s: Stream, + data: bytes | bytearray | memoryview, + seed: int, + ) -> None: + r = random.Random(seed) + m = memoryview(data) + while m: + chunk_size = r.randint(1, CHUNK_SIZE_MAX) + await s.send_all(m[:chunk_size]) + m = m[chunk_size:] + + async def receiver(s: Stream, data: bytes | bytearray, seed: int) -> None: + r = random.Random(seed) + got = bytearray() + while len(got) < len(data): + chunk = await s.receive_some(r.randint(1, CHUNK_SIZE_MAX)) + assert chunk + got += chunk + assert got == data + + async with _core.open_nursery() as nursery: + nursery.start_soon(sender, s1, test_data, 0) + nursery.start_soon(sender, s2, test_data[::-1], 1) + nursery.start_soon(receiver, s1, test_data[::-1], 2) + nursery.start_soon(receiver, s2, test_data, 3) + + async def expect_receive_some_empty() -> None: + assert await s2.receive_some(10) == b"" + await s2.aclose() + + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_receive_some_empty) + nursery.start_soon(s1.aclose) + + +async def check_half_closeable_stream( + stream_maker: 
StreamMaker[HalfCloseableStream, HalfCloseableStream], + clogged_stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream] | None, +) -> None: + """Perform a number of generic tests on a custom half-closeable stream + implementation. + + This is similar to :func:`check_two_way_stream`, except that the maker + functions are expected to return objects that implement the + :class:`~trio.abc.HalfCloseableStream` interface. + + This function tests a *superset* of what :func:`check_two_way_stream` + checks – if you call this, then you don't need to also call + :func:`check_two_way_stream`. + + """ + await check_two_way_stream(stream_maker, clogged_stream_maker) + + async with _ForceCloseBoth(await stream_maker()) as (s1, s2): + assert isinstance(s1, HalfCloseableStream) + assert isinstance(s2, HalfCloseableStream) + + async def send_x_then_eof(s: HalfCloseableStream) -> None: + await s.send_all(b"x") + with assert_checkpoints(): + await s.send_eof() + + async def expect_x_then_eof(r: HalfCloseableStream) -> None: + await _core.wait_all_tasks_blocked() + assert await r.receive_some(10) == b"x" + assert await r.receive_some(10) == b"" + + async with _core.open_nursery() as nursery: + nursery.start_soon(send_x_then_eof, s1) + nursery.start_soon(expect_x_then_eof, s2) + + # now sending is disallowed + with _assert_raises(_core.ClosedResourceError): + await s1.send_all(b"y") + + # but we can do send_eof again + with assert_checkpoints(): + await s1.send_eof() + + # and we can still send stuff back the other way + async with _core.open_nursery() as nursery: + nursery.start_soon(send_x_then_eof, s2) + nursery.start_soon(expect_x_then_eof, s1) + + if clogged_stream_maker is not None: + async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2): + # send_all and send_eof simultaneously is not ok + with _assert_raises(_core.BusyResourceError, wrapped=True): + async with _core.open_nursery() as nursery: + nursery.start_soon(s1.send_all, b"x") + await 
@contextmanager
def _assert_yields_or_not(expected: bool) -> Generator[None, None, None]:
    """Verify whether the wrapped block did (or did not) run a checkpoint.

    The current task's ``_cancel_points`` and ``_schedule_points`` counters are
    snapshotted on entry and compared again on exit.

    Raises:
      AssertionError: if ``expected`` is true, the block finished without an
          exception and at least one counter is unchanged; or if ``expected``
          is false and at least one counter changed (checked even when the
          block raised).
    """
    __tracebackhide__ = True
    task = _core.current_task()
    cancel_before, schedule_before = task._cancel_points, task._schedule_points

    def cancel_moved() -> bool:
        return task._cancel_points != cancel_before

    def schedule_moved() -> bool:
        return task._schedule_points != schedule_before

    try:
        yield
        # Only reached when the block did not raise: both counters must have
        # advanced for the block to count as a checkpoint.
        if expected and not (cancel_moved() and schedule_moved()):
            raise AssertionError("assert_checkpoints block did not yield!")
    finally:
        # Runs on every exit path, including when the block raised.
        if not expected and (cancel_moved() or schedule_moved()):
            raise AssertionError("assert_no_checkpoints block yielded!")
+ + Example: + Check that :func:`trio.sleep` is a checkpoint, even if it doesn't + block:: + + with trio.testing.assert_checkpoints(): + await trio.sleep(0) + + """ + __tracebackhide__ = True + return _assert_yields_or_not(True) + + +def assert_no_checkpoints() -> AbstractContextManager[None]: + """Use as a context manager to check that the code inside the ``with`` + block does not execute any :ref:`checkpoints <checkpoints>`. + + Raises: + AssertionError: if a checkpoint was executed. + + Example: + Synchronous code never contains any checkpoints, but we can double-check + that:: + + send_channel, receive_channel = trio.open_memory_channel(10) + with trio.testing.assert_no_checkpoints(): + send_channel.send_nowait(None) + + """ + __tracebackhide__ = True + return _assert_yields_or_not(False) diff --git a/contrib/python/trio/trio/testing/_fake_net.py b/contrib/python/trio/trio/testing/_fake_net.py new file mode 100644 index 00000000000..5d63112e17c --- /dev/null +++ b/contrib/python/trio/trio/testing/_fake_net.py @@ -0,0 +1,590 @@ +# This should eventually be cleaned up and become public, but for right now I'm just +# implementing enough to test DTLS. 
def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int:
    """Copy ``data`` into ``buffers`` in order, filling each before the next.

    Args:
      data: The payload to distribute.
      buffers: Writable buffer objects. Buffers whose memoryview is not a
          one-dimensional view of unsigned bytes (for example an
          ``array.array("H")``) are written through a byte-cast view.

    Returns:
      The number of bytes copied. This is less than ``len(data)`` when the
      buffers run out of room first.
    """
    written = 0
    for buf in buffers:  # pragma: no branch
        with memoryview(buf) as mbuf:
            # Assigning ``bytes`` into a memoryview slice requires matching
            # item formats, so anything that is not already a flat "B" view is
            # cast. Flat "B" views are used as-is, exactly as before.
            needs_cast = mbuf.ndim != 1 or mbuf.format != "B"
            with (mbuf.cast("B") if needs_cast else mbuf) as flat:
                next_piece = data[written : written + flat.nbytes]
                flat[: len(next_piece)] = next_piece
        written += len(next_piece)
        if written == len(data):  # pragma: no branch
            break
    return written
@attrs.frozen
class UDPEndpoint:
    """An (IP address, port) pair naming one end of a UDP exchange.

    Frozen so that instances compare by value and are hashable.
    """

    ip: IPAddress
    port: int

    def as_python_sockaddr(self) -> tuple[str, int] | tuple[str, int, int, int]:
        """Return the address as a socket-style tuple.

        IPv4 endpoints give ``(host, port)``; IPv6 endpoints get two extra
        zero fields appended, giving a 4-tuple.
        """
        sockaddr: tuple[str, int] | tuple[str, int, int, int] = (
            self.ip.compressed,
            self.port,
        )
        if isinstance(self.ip, ipaddress.IPv6Address):
            sockaddr += (0, 0)  # type: ignore[assignment]
        return sockaddr

    @classmethod
    def from_python_sockaddr(
        cls,
        sockaddr: tuple[str, int] | tuple[str, int, int, int],
    ) -> UDPEndpoint:
        """Build an endpoint from the first two fields of a sockaddr tuple."""
        ip, port = sockaddr[:2]
        return cls(ip=ipaddress.ip_address(ip), port=port)


@attrs.frozen
class UDPBinding:
    """The local endpoint a socket is bound to.

    Must be hashable with value equality: instances are used as dictionary
    keys and looked up with freshly constructed, equal instances.
    """

    local: UDPEndpoint
    # remote: UDPEndpoint  # ??


@attrs.frozen
class UDPPacket:
    """A single datagram: where it came from, where it goes, and its bytes."""

    source: UDPEndpoint
    destination: UDPEndpoint
    # Show the payload as hex in the repr.
    payload: bytes = attrs.field(repr=lambda p: p.hex())

    # not used/tested anywhere
    def reply(self, payload: bytes) -> UDPPacket:  # pragma: no cover
        """Return a new packet with source and destination swapped."""
        return UDPPacket(
            source=self.destination,
            destination=self.source,
            payload=payload,
        )


@attrs.frozen
class FakeSocketFactory(trio.abc.SocketFactory):
    """Socket factory that creates :class:`FakeSocket` objects on ``fake_net``."""

    fake_net: FakeNet

    def socket(self, family: int, type_: int, proto: int) -> FakeSocket:  # type: ignore[override]
        return FakeSocket._create(self.fake_net, family, type_, proto)


@attrs.frozen
class FakeHostnameResolver(trio.abc.HostnameResolver):
    """Hostname resolver stub: every lookup raises ``NotImplementedError``."""

    fake_net: FakeNet

    async def getaddrinfo(
        self,
        host: bytes | None,
        port: bytes | str | int | None,
        family: int = 0,
        type: int = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
        ]
    ]:
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")

    async def getnameinfo(
        self,
        sockaddr: tuple[str, int] | tuple[str, int, int, int],
        flags: int,
    ) -> tuple[str, str]:
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")
ipaddress.IPv4Network("1.0.0.0/8").hosts() # untested + self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts() # untested + self._auto_port_iter = iter(range(50000, 65535)) + + self._bound: dict[UDPBinding, FakeSocket] = {} + + self.route_packet = None + + def _bind(self, binding: UDPBinding, socket: FakeSocket) -> None: + if binding in self._bound: + _fake_err(errno.EADDRINUSE) + self._bound[binding] = socket + + def enable(self) -> None: + trio.socket.set_custom_socket_factory(FakeSocketFactory(self)) + trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self)) + + def send_packet(self, packet: UDPPacket) -> None: + if self.route_packet is None: + self.deliver_packet(packet) + else: + self.route_packet(packet) + + def deliver_packet(self, packet: UDPPacket) -> None: + binding = UDPBinding(local=packet.destination) + if binding in self._bound: + self._bound[binding]._deliver_packet(packet) + else: + # No valid destination, so drop it + pass + + +@final +class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor): + def __init__( + self, + fake_net: FakeNet, + family: AddressFamily, + type: SocketKind, + proto: int, + ) -> None: + self._fake_net = fake_net + + if not family: # pragma: no cover + family = trio.socket.AF_INET + if not type: # pragma: no cover + type = trio.socket.SOCK_STREAM # noqa: A001 # name shadowing builtin + + if family not in (trio.socket.AF_INET, trio.socket.AF_INET6): + raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}") + if type != trio.socket.SOCK_DGRAM: + raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}") + + self._family = family + self._type = type + self._proto = proto + + self._closed = False + + self._packet_sender, self._packet_receiver = trio.open_memory_channel[ + UDPPacket + ](float("inf")) + + # This is the source-of-truth for what port etc. 
this socket is bound to + self._binding: UDPBinding | None = None + + @property + def type(self) -> SocketKind: + return self._type + + @property + def family(self) -> AddressFamily: + return self._family + + @property + def proto(self) -> int: + return self._proto + + def _check_closed(self) -> None: + if self._closed: + _fake_err(errno.EBADF) + + def close(self) -> None: + if self._closed: + return + self._closed = True + if self._binding is not None: + del self._fake_net._bound[self._binding] + self._packet_receiver.close() + + async def _resolve_address_nocp( + self, + address: object, + *, + local: bool, + ) -> tuple[str, int]: + return await trio._socket._resolve_address_nocp( # type: ignore[no-any-return] + self.type, + self.family, + self.proto, + address=address, + ipv6_v6only=False, + local=local, + ) + + def _deliver_packet(self, packet: UDPPacket) -> None: + # sending to a closed socket -- UDP packets get dropped + with contextlib.suppress(trio.BrokenResourceError): + self._packet_sender.send_nowait(packet) + + ################################################################ + # Actual IO operation implementations + ################################################################ + + async def bind(self, addr: object) -> None: + self._check_closed() + if self._binding is not None: + _fake_err(errno.EINVAL) + await trio.lowlevel.checkpoint() + ip_str, port, *_ = await self._resolve_address_nocp(addr, local=True) + assert _ == [], "TODO: handle other values?" 
+ + ip = ipaddress.ip_address(ip_str) + assert _family_for(ip) == self.family + # We convert binds to INET_ANY into binds to localhost + if ip == ipaddress.ip_address("0.0.0.0"): + ip = ipaddress.ip_address("127.0.0.1") + elif ip == ipaddress.ip_address("::"): + ip = ipaddress.ip_address("::1") + if port == 0: + port = next(self._fake_net._auto_port_iter) + binding = UDPBinding(local=UDPEndpoint(ip, port)) + self._fake_net._bind(binding, self) + self._binding = binding + + async def connect(self, peer: object) -> NoReturn: + raise NotImplementedError("FakeNet does not (yet) support connected sockets") + + async def _sendmsg( + self, + buffers: Iterable[Buffer], + ancdata: Iterable[tuple[int, int, Buffer]] = (), + flags: int = 0, + address: AddressFormat | None = None, + ) -> int: + self._check_closed() + + await trio.lowlevel.checkpoint() + + if address is not None: + address = await self._resolve_address_nocp(address, local=False) + if ancdata: + raise NotImplementedError("FakeNet doesn't support ancillary data") + if flags: + raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}") + + if address is None: + _fake_err(errno.ENOTCONN) + + destination = UDPEndpoint.from_python_sockaddr(address) + + if self._binding is None: + await self.bind((_wildcard_ip_for(self.family).compressed, 0)) + + payload = b"".join(buffers) + + assert self._binding is not None + packet = UDPPacket( + source=self._binding.local, + destination=destination, + payload=payload, + ) + + self._fake_net.send_packet(packet) + + return len(payload) + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "sendmsg") + ): + sendmsg = _sendmsg + + async def _recvmsg_into( + self, + buffers: Iterable[Buffer], + ancbufsize: int = 0, + flags: int = 0, + ) -> tuple[ + int, + list[tuple[int, int, bytes]], + int, + tuple[str, int] | tuple[str, int, int, int], + ]: + if ancbufsize != 0: + raise NotImplementedError("FakeNet doesn't support ancillary data") + if flags 
!= 0: + raise NotImplementedError("FakeNet doesn't support any recv flags") + if self._binding is None: + # I messed this up a few times when writing tests ... but it also never happens + # in any of the existing tests, so maybe it could be intentional... + raise NotImplementedError( + "The code will most likely hang if you try to receive on a fakesocket " + "without a binding. If that is not the case, or you explicitly want to " + "test that, remove this warning.", + ) + + self._check_closed() + + ancdata: list[tuple[int, int, bytes]] = [] + msg_flags = 0 + + packet = await self._packet_receiver.receive() + address = packet.source.as_python_sockaddr() + written = _scatter(packet.payload, buffers) + if written < len(packet.payload): + msg_flags |= trio.socket.MSG_TRUNC + return written, ancdata, msg_flags, address + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "sendmsg") + ): + recvmsg_into = _recvmsg_into + + ################################################################ + # Simple state query stuff + ################################################################ + + def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]: + self._check_closed() + if self._binding is not None: + return self._binding.local.as_python_sockaddr() + elif self.family == trio.socket.AF_INET: + return ("0.0.0.0", 0) + else: + assert self.family == trio.socket.AF_INET6 + return ("::", 0) + + # TODO: This method is not tested, and seems to make incorrect assumptions. It should maybe raise NotImplementedError. 
+ def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]: + self._check_closed() + if self._binding is not None: + assert hasattr( + self._binding, + "remote", + ), "This method seems to assume that self._binding has a remote UDPEndpoint" + if self._binding.remote is not None: # pragma: no cover + assert isinstance( + self._binding.remote, + UDPEndpoint, + ), "Self._binding.remote should be a UDPEndpoint" + return self._binding.remote.as_python_sockaddr() + _fake_err(errno.ENOTCONN) + + @overload + def getsockopt(self, /, level: int, optname: int) -> int: ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... + + def getsockopt( + self, + /, + level: int, + optname: int, + buflen: int | None = None, + ) -> int | bytes: + self._check_closed() + raise OSError(f"FakeNet doesn't implement getsockopt({level}, {optname})") + + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ... + + @overload + def setsockopt( + self, + /, + level: int, + optname: int, + value: None, + optlen: int, + ) -> None: ... 
+ + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + self._check_closed() + + if (level, optname) == ( + trio.socket.IPPROTO_IPV6, + trio.socket.IPV6_V6ONLY, + ) and not value: + raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True") + + raise OSError(f"FakeNet doesn't implement setsockopt({level}, {optname}, ...)") + + ################################################################ + # Various boilerplate and trivial stubs + ################################################################ + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: builtins.type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + async def send(self, data: Buffer, flags: int = 0) -> int: + return await self.sendto(data, flags, None) + + # __ prefixed arguments because typeshed uses that and typechecker issues + @overload + async def sendto( + self, + __data: Buffer, # noqa: PYI063 + __address: tuple[object, ...] | str | Buffer, + ) -> int: ... + + # __ prefixed arguments because typeshed uses that and typechecker issues + @overload + async def sendto( + self, + __data: Buffer, # noqa: PYI063 + __flags: int, + __address: tuple[object, ...] | str | Buffer | None, + ) -> int: ... + + async def sendto( # type: ignore[explicit-any] + self, + *args: Any, + ) -> int: + data: Buffer + flags: int + address: tuple[object, ...] 
| str | Buffer + if len(args) == 2: + data, address = args + flags = 0 + elif len(args) == 3: + data, flags, address = args + else: + raise TypeError("wrong number of arguments") + return await self._sendmsg([data], [], flags, address) + + async def recv(self, bufsize: int, flags: int = 0) -> bytes: + data, _address = await self.recvfrom(bufsize, flags) + return data + + async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int: + got_bytes, _address = await self.recvfrom_into(buf, nbytes, flags) + return got_bytes + + async def recvfrom( + self, + bufsize: int, + flags: int = 0, + ) -> tuple[bytes, AddressFormat]: + data, _ancdata, _msg_flags, address = await self._recvmsg(bufsize, flags) + return data, address + + async def recvfrom_into( + self, + buf: Buffer, + nbytes: int = 0, + flags: int = 0, + ) -> tuple[int, AddressFormat]: + if nbytes != 0 and nbytes != memoryview(buf).nbytes: + raise NotImplementedError("partial recvfrom_into") + got_nbytes, _ancdata, _msg_flags, address = await self._recvmsg_into( + [buf], + 0, + flags, + ) + return got_nbytes, address + + async def _recvmsg( + self, + bufsize: int, + ancbufsize: int = 0, + flags: int = 0, + ) -> tuple[bytes, list[tuple[int, int, bytes]], int, AddressFormat]: + buf = bytearray(bufsize) + got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into( + [buf], + ancbufsize, + flags, + ) + return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address) + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "sendmsg") + ): + recvmsg = _recvmsg + + def fileno(self) -> int: + raise NotImplementedError("can't get fileno() for FakeNet sockets") + + def detach(self) -> int: + raise NotImplementedError("can't detach() a FakeNet socket") + + def get_inheritable(self) -> bool: + return False + + def set_inheritable(self, inheritable: bool) -> None: + if inheritable: + raise NotImplementedError("FakeNet can't make inheritable sockets") + + if sys.platform == 
class _UnboundedByteQueue:
    """Unbounded in-memory byte buffer with a blocking reader side.

    Only one task may fetch at a time; concurrent fetches are rejected by the
    conflict detector.
    """

    def __init__(self) -> None:
        self._data = bytearray()
        self._closed = False
        self._lot = _core.ParkingLot()
        self._fetch_lock = _util.ConflictDetector(
            "another task is already fetching data",
        )

    # This object treats "close" as being like closing the send side of a
    # channel: so after close(), calling put() raises ClosedResourceError, and
    # calling the get() variants drains the buffer and then returns an empty
    # bytearray.
    def close(self) -> None:
        """Mark the queue closed and wake every parked reader."""
        self._closed = True
        self._lot.unpark_all()

    def close_and_wipe(self) -> None:
        """Discard any buffered bytes, then close."""
        self._data = bytearray()
        self.close()

    def put(self, data: bytes | bytearray | memoryview) -> None:
        """Append ``data`` and wake every parked reader.

        Raises:
          trio.ClosedResourceError: if the queue has been closed.
        """
        if self._closed:
            raise _core.ClosedResourceError("virtual connection closed")
        self._data += data
        self._lot.unpark_all()

    def _check_max_bytes(self, max_bytes: int | None) -> None:
        """Validate ``max_bytes``: None, or an integer-like value >= 1."""
        if max_bytes is None:
            return
        max_bytes = operator.index(max_bytes)
        if max_bytes < 1:
            raise ValueError("max_bytes must be >= 1")

    def _get_impl(self, max_bytes: int | None) -> bytearray:
        """Remove and return up to ``max_bytes`` bytes (everything if None).

        Callers must ensure the queue is closed or non-empty. An empty
        bytearray is returned only when there is no data, i.e. at EOF.
        """
        assert self._closed or self._data
        if max_bytes is None:
            max_bytes = len(self._data)
        if self._data:
            chunk = self._data[:max_bytes]
            del self._data[:max_bytes]
            assert chunk
            return chunk
        else:
            return bytearray()

    def get_nowait(self, max_bytes: int | None = None) -> bytearray:
        """Fetch without blocking.

        Raises:
          trio.WouldBlock: if the queue is open and empty.
        """
        with self._fetch_lock:
            self._check_max_bytes(max_bytes)
            if not self._closed and not self._data:
                raise _core.WouldBlock
            return self._get_impl(max_bytes)

    async def get(self, max_bytes: int | None = None) -> bytearray:
        """Fetch data, waiting until some is available or the queue is closed."""
        with self._fetch_lock:
            self._check_max_bytes(max_bytes)
            if self._closed or self._data:
                await _core.checkpoint()
            else:
                # put() wakes parked readers even when handed zero bytes, so a
                # wakeup does not guarantee there is anything to return.
                # Re-check and park again until there is data or the queue is
                # closed; otherwise _get_impl's assertion would fail.
                while not self._closed and not self._data:
                    await self._lot.park()
            return self._get_impl(max_bytes)
attribute:: send_all_hook + wait_send_all_might_not_block_hook + close_hook + + All of these hooks are also exposed as attributes on the object, and + you can change them at any time. + + """ + + def __init__( + self, + send_all_hook: AsyncHook | None = None, + wait_send_all_might_not_block_hook: AsyncHook | None = None, + close_hook: SyncHook | None = None, + ) -> None: + self._conflict_detector = _util.ConflictDetector( + "another task is using this stream", + ) + self._outgoing = _UnboundedByteQueue() + self.send_all_hook = send_all_hook + self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook + self.close_hook = close_hook + + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + """Places the given data into the object's internal buffer, and then + calls the :attr:`send_all_hook` (if any). + + """ + # Execute two checkpoints so we have more of a chance to detect + # buggy user code that calls this twice at the same time. + with self._conflict_detector: + await _core.checkpoint() + await _core.checkpoint() + self._outgoing.put(data) + if self.send_all_hook is not None: + await self.send_all_hook() + + async def wait_send_all_might_not_block(self) -> None: + """Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and + then returns immediately. + + """ + # Execute two checkpoints so that we have more of a chance to detect + # buggy user code that calls this twice at the same time. + with self._conflict_detector: + await _core.checkpoint() + await _core.checkpoint() + # check for being closed: + self._outgoing.put(b"") + if self.wait_send_all_might_not_block_hook is not None: + await self.wait_send_all_might_not_block_hook() + + def close(self) -> None: + """Marks this stream as closed, and then calls the :attr:`close_hook` + (if any). + + """ + # XXX should this cancel any pending calls to the send_all_hook and + # wait_send_all_might_not_block_hook? 
Those are the only places where + # send_all and wait_send_all_might_not_block can be blocked. + # + # The way we set things up, send_all_hook is memory_stream_pump, and + # wait_send_all_might_not_block_hook is unset. memory_stream_pump is + # synchronous. So normally, send_all and wait_send_all_might_not_block + # cannot block at all. + self._outgoing.close() + if self.close_hook is not None: + self.close_hook() + + async def aclose(self) -> None: + """Same as :meth:`close`, but async.""" + self.close() + await _core.checkpoint() + + async def get_data(self, max_bytes: int | None = None) -> bytearray: + """Retrieves data from the internal buffer, blocking if necessary. + + Args: + max_bytes (int or None): The maximum amount of data to + retrieve. None (the default) means to retrieve all the data + that's present (but still blocks until at least one byte is + available). + + Returns: + If this stream has been closed, an empty bytearray. Otherwise, the + requested data. + + """ + return await self._outgoing.get(max_bytes) + + def get_data_nowait(self, max_bytes: int | None = None) -> bytearray: + """Retrieves data from the internal buffer, but doesn't block. + + See :meth:`get_data` for details. + + Raises: + trio.WouldBlock: if no data is available to retrieve. + + """ + return self._outgoing.get_nowait(max_bytes) + + +@_util.final +class MemoryReceiveStream(ReceiveStream): + """An in-memory :class:`~trio.abc.ReceiveStream`. + + Args: + receive_some_hook: An async function, or None. Called from + :meth:`receive_some`. Can do whatever you like. + close_hook: A synchronous function, or None. Called from :meth:`close` + and :meth:`aclose`. Can do whatever you like. + + .. attribute:: receive_some_hook + close_hook + + Both hooks are also exposed as attributes on the object, and you can + change them at any time. 
+ + """ + + def __init__( + self, + receive_some_hook: AsyncHook | None = None, + close_hook: SyncHook | None = None, + ) -> None: + self._conflict_detector = _util.ConflictDetector( + "another task is using this stream", + ) + self._incoming = _UnboundedByteQueue() + self._closed = False + self.receive_some_hook = receive_some_hook + self.close_hook = close_hook + + async def receive_some(self, max_bytes: int | None = None) -> bytearray: + """Calls the :attr:`receive_some_hook` (if any), and then retrieves + data from the internal buffer, blocking if necessary. + + """ + # Execute two checkpoints so we have more of a chance to detect + # buggy user code that calls this twice at the same time. + with self._conflict_detector: + await _core.checkpoint() + await _core.checkpoint() + if self._closed: + raise _core.ClosedResourceError + if self.receive_some_hook is not None: + await self.receive_some_hook() + # self._incoming's closure state tracks whether we got an EOF. + # self._closed tracks whether we, ourselves, are closed. + # self.close() sends an EOF to wake us up and sets self._closed, + # so after we wake up we have to check self._closed again. + data = await self._incoming.get(max_bytes) + if self._closed: + raise _core.ClosedResourceError + return data + + def close(self) -> None: + """Discards any pending data from the internal buffer, and marks this + stream as closed. 
def memory_stream_pump(
    memory_send_stream: MemorySendStream,
    memory_receive_stream: MemoryReceiveStream,
    *,
    max_bytes: int | None = None,
) -> bool:
    """Move buffered data from a :class:`MemorySendStream` into a
    :class:`MemoryReceiveStream`.

    Args:
      memory_send_stream (MemorySendStream): The stream to drain.
      memory_receive_stream (MemoryReceiveStream): The stream to feed.
      max_bytes (int or None): Upper bound on the amount moved by this call,
          or None to move everything currently buffered.

    Returns:
      True if something was transferred (data, or an end-of-file marker when
      the fetched chunk is empty), or False if the send stream had nothing
      available.

    Raises:
      trio.BrokenResourceError: if feeding the receive stream raises
          :exc:`trio.ClosedResourceError`.

    """
    try:
        chunk = memory_send_stream.get_data_nowait(max_bytes)
    except _core.WouldBlock:
        return False

    try:
        if chunk:
            memory_receive_stream.put_data(chunk)
        else:
            # An empty chunk is forwarded as an end-of-file marker.
            memory_receive_stream.put_eof()
    except _core.ClosedResourceError:
        raise _core.BrokenResourceError("MemoryReceiveStream was closed") from None
    return True
+ + """ + send_stream = MemorySendStream() + recv_stream = MemoryReceiveStream() + + def pump_from_send_stream_to_recv_stream() -> None: + memory_stream_pump(send_stream, recv_stream) + + # await not used + async def async_pump_from_send_stream_to_recv_stream() -> None: # noqa: RUF029 + pump_from_send_stream_to_recv_stream() + + send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream + send_stream.close_hook = pump_from_send_stream_to_recv_stream + return send_stream, recv_stream + + +def _make_stapled_pair( + one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]], +) -> tuple[ + StapledStream[SendStreamT, ReceiveStreamT], + StapledStream[SendStreamT, ReceiveStreamT], +]: + pipe1_send, pipe1_recv = one_way_pair() + pipe2_send, pipe2_recv = one_way_pair() + stream1 = StapledStream(pipe1_send, pipe2_recv) + stream2 = StapledStream(pipe2_send, pipe1_recv) + return stream1, stream2 + + +def memory_stream_pair() -> tuple[ + StapledStream[MemorySendStream, MemoryReceiveStream], + StapledStream[MemorySendStream, MemoryReceiveStream], +]: + """Create a connected, pure-Python, bidirectional stream with infinite + buffering and flexible configuration options. + + This is a convenience function that creates two one-way streams using + :func:`memory_stream_one_way_pair`, and then uses + :class:`~trio.StapledStream` to combine them into a single bidirectional + stream. + + This is like a no-operating-system-involved, Trio-streamsified version of + :func:`socket.socketpair`. + + Returns: + A pair of :class:`~trio.StapledStream` objects that are connected so + that data automatically flows from one to the other in both directions. 
+ + After creating a stream pair, you can send data back and forth, which is + enough for simple tests:: + + left, right = memory_stream_pair() + await left.send_all(b"123") + assert await right.receive_some() == b"123" + await right.send_all(b"456") + assert await left.receive_some() == b"456" + + But if you read the docs for :class:`~trio.StapledStream` and + :func:`memory_stream_one_way_pair`, you'll see that all the pieces + involved in wiring this up are public APIs, so you can adjust to suit the + requirements of your tests. For example, here's how to tweak a stream so + that data flowing from left to right trickles in one byte at a time (but + data flowing from right to left proceeds at full speed):: + + left, right = memory_stream_pair() + async def trickle(): + # left is a StapledStream, and left.send_stream is a MemorySendStream + # right is a StapledStream, and right.receive_stream is a MemoryReceiveStream + while memory_stream_pump(left.send_stream, right.receive_stream, max_bytes=1): + # Pause between each byte + await trio.sleep(1) + # Normally this send_all_hook calls memory_stream_pump directly without + # passing in a max_bytes. We replace it with our custom version: + left.send_stream.send_all_hook = trickle + + And here's a simple test using our modified stream objects:: + + async def sender(): + await left.send_all(b"12345") + await left.send_eof() + + async def receiver(): + async for data in right: + print(data) + + async with trio.open_nursery() as nursery: + nursery.start_soon(sender) + nursery.start_soon(receiver) + + By default, this will print ``b"12345"`` and then immediately exit; with + our trickle stream it instead sleeps 1 second, then prints ``b"1"``, then + sleeps 1 second, then prints ``b"2"``, etc. + + Pro-tip: you can insert sleep calls (like in our example above) to + manipulate the flow of data across tasks...
and then use + :class:`MockClock` and its :attr:`~MockClock.autojump_threshold` + functionality to keep your test suite running quickly. + + If you want to stress test a protocol implementation, one nice trick is to + use the :mod:`random` module (preferably with a fixed seed) to move random + numbers of bytes at a time, and insert random sleeps in between them. You + can also set up a custom :attr:`~MemoryReceiveStream.receive_some_hook` if + you want to manipulate things on the receiving side, and not just the + sending side. + + """ + return _make_stapled_pair(memory_stream_one_way_pair) + + +################################################################ +# In-memory streams - Lockstep version +################################################################ + + +class _LockstepByteQueue: + def __init__(self) -> None: + self._data = bytearray() + self._sender_closed = False + self._receiver_closed = False + self._receiver_waiting = False + self._waiters = _core.ParkingLot() + self._send_conflict_detector = _util.ConflictDetector( + "another task is already sending", + ) + self._receive_conflict_detector = _util.ConflictDetector( + "another task is already receiving", + ) + + def _something_happened(self) -> None: + self._waiters.unpark_all() + + # Always wakes up when one side is closed, because everyone always reacts + # to that. 
+ async def _wait_for(self, fn: Callable[[], bool]) -> None: + while True: + if fn(): + break + if self._sender_closed or self._receiver_closed: + break + await self._waiters.park() + await _core.checkpoint() + + def close_sender(self) -> None: + self._sender_closed = True + self._something_happened() + + def close_receiver(self) -> None: + self._receiver_closed = True + self._something_happened() + + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + with self._send_conflict_detector: + if self._sender_closed: + raise _core.ClosedResourceError + if self._receiver_closed: + raise _core.BrokenResourceError + assert not self._data + self._data += data + self._something_happened() + await self._wait_for(lambda: self._data == b"") + if self._sender_closed: + raise _core.ClosedResourceError + if self._data and self._receiver_closed: + raise _core.BrokenResourceError + + async def wait_send_all_might_not_block(self) -> None: + with self._send_conflict_detector: + if self._sender_closed: + raise _core.ClosedResourceError + if self._receiver_closed: + await _core.checkpoint() + return + await self._wait_for(lambda: self._receiver_waiting) + if self._sender_closed: + raise _core.ClosedResourceError + + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: + with self._receive_conflict_detector: + # Argument validation + if max_bytes is not None: + max_bytes = operator.index(max_bytes) + if max_bytes < 1: + raise ValueError("max_bytes must be >= 1") + # State validation + if self._receiver_closed: + raise _core.ClosedResourceError + # Wake wait_send_all_might_not_block and wait for data + self._receiver_waiting = True + self._something_happened() + try: + await self._wait_for(lambda: self._data != b"") + finally: + self._receiver_waiting = False + if self._receiver_closed: + raise _core.ClosedResourceError + # Get data, possibly waking send_all + if self._data: + # Neat trick: if max_bytes is None, then obj[:max_bytes] is + # 
the same as obj[:]. + got = self._data[:max_bytes] + del self._data[:max_bytes] + self._something_happened() + return got + else: + assert self._sender_closed + return b"" + + +class _LockstepSendStream(SendStream): + def __init__(self, lbq: _LockstepByteQueue) -> None: + self._lbq = lbq + + def close(self) -> None: + self._lbq.close_sender() + + async def aclose(self) -> None: + self.close() + await _core.checkpoint() + + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + await self._lbq.send_all(data) + + async def wait_send_all_might_not_block(self) -> None: + await self._lbq.wait_send_all_might_not_block() + + +class _LockstepReceiveStream(ReceiveStream): + def __init__(self, lbq: _LockstepByteQueue) -> None: + self._lbq = lbq + + def close(self) -> None: + self._lbq.close_receiver() + + async def aclose(self) -> None: + self.close() + await _core.checkpoint() + + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: + return await self._lbq.receive_some(max_bytes) + + +def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]: + """Create a connected, pure Python, unidirectional stream where data flows + in lockstep. + + Returns: + A tuple + (:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`). + + This stream has *absolutely no* buffering. Each call to + :meth:`~trio.abc.SendStream.send_all` will block until all the given data + has been returned by a call to + :meth:`~trio.abc.ReceiveStream.receive_some`. + + This can be useful for testing flow control mechanisms in an extreme case, + or for setting up "clogged" streams to use with + :func:`check_one_way_stream` and friends. + + In addition to fulfilling the :class:`~trio.abc.SendStream` and + :class:`~trio.abc.ReceiveStream` interfaces, the return objects + also have a synchronous ``close`` method. 
+ + """ + + lbq = _LockstepByteQueue() + return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq) + + +def lockstep_stream_pair() -> tuple[ + StapledStream[SendStream, ReceiveStream], + StapledStream[SendStream, ReceiveStream], +]: + """Create a connected, pure-Python, bidirectional stream where data flows + in lockstep. + + Returns: + A tuple (:class:`~trio.StapledStream`, :class:`~trio.StapledStream`). + + This is a convenience function that creates two one-way streams using + :func:`lockstep_stream_one_way_pair`, and then uses + :class:`~trio.StapledStream` to combine them into a single bidirectional + stream. + + """ + return _make_stapled_pair(lockstep_stream_one_way_pair) diff --git a/contrib/python/trio/trio/testing/_network.py b/contrib/python/trio/trio/testing/_network.py new file mode 100644 index 00000000000..fddbbf0fdc7 --- /dev/null +++ b/contrib/python/trio/trio/testing/_network.py @@ -0,0 +1,36 @@ +from .. import socket as tsocket +from .._highlevel_socket import SocketListener, SocketStream + + +async def open_stream_to_socket_listener( + socket_listener: SocketListener, +) -> SocketStream: + """Connect to the given :class:`~trio.SocketListener`. + + This is particularly useful in tests when you want to let a server pick + its own port, and then connect to it:: + + listeners = await trio.open_tcp_listeners(0) + client = await trio.testing.open_stream_to_socket_listener(listeners[0]) + + Args: + socket_listener (~trio.SocketListener): The + :class:`~trio.SocketListener` to connect to. + + Returns: + SocketStream: a stream connected to the given listener. 
+ + """ + family = socket_listener.socket.family + sockaddr = socket_listener.socket.getsockname() + if family in (tsocket.AF_INET, tsocket.AF_INET6): + sockaddr = list(sockaddr) + if sockaddr[0] == "0.0.0.0": + sockaddr[0] = "127.0.0.1" + if sockaddr[0] == "::": + sockaddr[0] = "::1" + sockaddr = tuple(sockaddr) + + sock = tsocket.socket(family=family) + await sock.connect(sockaddr) + return SocketStream(sock) diff --git a/contrib/python/trio/trio/testing/_raises_group.py b/contrib/python/trio/trio/testing/_raises_group.py new file mode 100644 index 00000000000..6001e4dab10 --- /dev/null +++ b/contrib/python/trio/trio/testing/_raises_group.py @@ -0,0 +1,1021 @@ +from __future__ import annotations + +import re +import sys +from abc import ABC, abstractmethod +from re import Pattern +from textwrap import indent +from typing import ( + TYPE_CHECKING, + Generic, + Literal, + cast, + overload, +) + +from trio._util import final + +if TYPE_CHECKING: + import builtins + + # sphinx will *only* work if we use types.TracebackType, and import + # *inside* TYPE_CHECKING. No other combination works..... + import types + from collections.abc import Callable, Sequence + + from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback + from typing_extensions import TypeGuard, TypeVar + + # this conditional definition is because we want to allow a TypeVar default + MatchE = TypeVar( + "MatchE", + bound=BaseException, + default=BaseException, + covariant=True, + ) +else: + from typing import TypeVar + + MatchE = TypeVar("MatchE", bound=BaseException, covariant=True) + +# RaisesGroup doesn't work with a default. 
BaseExcT_co = TypeVar("BaseExcT_co", bound=BaseException, covariant=True)
BaseExcT_1 = TypeVar("BaseExcT_1", bound=BaseException)
BaseExcT_2 = TypeVar("BaseExcT_2", bound=BaseException)
ExcT_1 = TypeVar("ExcT_1", bound=Exception)
ExcT_2 = TypeVar("ExcT_2", bound=Exception)

# Interpreters older than 3.11 take the exception-group classes from the
# ``exceptiongroup`` package.
if sys.version_info < (3, 11):
    from exceptiongroup import BaseExceptionGroup, ExceptionGroup


@final
class _ExceptionInfo(Generic[MatchE]):
    """Small stand-in for ``pytest.ExceptionInfo``, used when pytest is not available.

    It only implements the subset needed by :class:`trio.testing.RaisesGroup`
    and :class:`trio.testing.Matcher`; the pytest-only helpers are declared so
    the signatures exist, but they raise :class:`NotImplementedError`.
    """

    # ``(exception class, exception instance, traceback)``, or ``None`` while
    # the object is still unfilled.
    _excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None

    def __init__(
        self,
        excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None,
    ) -> None:
        self._excinfo = excinfo

    def fill_unfilled(
        self,
        exc_info: tuple[type[MatchE], MatchE, types.TracebackType],
    ) -> None:
        """Store ``exc_info`` in an instance that was created with ``for_later()``."""
        still_empty = self._excinfo is None
        assert still_empty, "ExceptionInfo was already filled"
        self._excinfo = exc_info

    @classmethod
    def for_later(cls) -> _ExceptionInfo[MatchE]:
        """Build an instance that holds no exception yet."""
        return cls(None)

    # Note, special cased in sphinx config, since "type" conflicts.
    @property
    def type(self) -> type[MatchE]:
        """The class of the stored exception."""
        stored = self._excinfo
        assert (
            stored is not None
        ), ".type can only be used after the context manager exits"
        return stored[0]

    @property
    def value(self) -> MatchE:
        """The stored exception instance."""
        stored = self._excinfo
        assert (
            stored is not None
        ), ".value can only be used after the context manager exits"
        return stored[1]

    @property
    def tb(self) -> types.TracebackType:
        """The raw traceback of the stored exception."""
        stored = self._excinfo
        assert (
            stored is not None
        ), ".tb can only be used after the context manager exits"
        return stored[2]

    def exconly(self, tryshort: bool = False) -> str:
        # Declared for signature parity only; always raises.
        raise NotImplementedError(
            "This is a helper method only available if you use RaisesGroup "
            "with the pytest package installed",
        )

    def errisinstance(
        self,
        exc: builtins.type[BaseException] | tuple[builtins.type[BaseException], ...],
    ) -> bool:
        # Declared for signature parity only; always raises.
        raise NotImplementedError(
            "This is a helper method only available if you use RaisesGroup "
            "with the pytest package installed",
        )

    def getrepr(
        self,
        showlocals: bool = False,
        style: str = "long",
        abspath: bool = False,
        tbfilter: bool | Callable[[_ExceptionInfo], Traceback] = True,
        funcargs: bool = False,
        truncate_locals: bool = True,
        chain: bool = True,
    ) -> ReprExceptionInfo | ExceptionChainRepr:
        # Declared for signature parity only; always raises.
        raise NotImplementedError(
            "This is a helper method only available if you use RaisesGroup "
            "with the pytest package installed",
        )


# Type checkers are not able to do conditional types depending on installed packages, so
# we've added signatures for all helpers to _ExceptionInfo, and then always use that.
# If this ends up leading to problems, we can resort to always using _ExceptionInfo and
# users that want to use getrepr/errisinstance/exconly can write helpers on their own, or
# we reimplement them ourselves...or get this merged in upstream pytest.
+if TYPE_CHECKING: + ExceptionInfo = _ExceptionInfo + +else: + try: + from pytest import ExceptionInfo # noqa: PT013 + except ImportError: # pragma: no cover + ExceptionInfo = _ExceptionInfo + + +# copied from pytest.ExceptionInfo +def _stringify_exception(exc: BaseException) -> str: + return "\n".join( + [ + getattr(exc, "message", str(exc)), + *getattr(exc, "__notes__", []), + ], + ) + + +# String patterns default to including the unicode flag. +_REGEX_NO_FLAGS = re.compile(r"").flags + + +def _match_pattern(match: Pattern[str]) -> str | Pattern[str]: + """helper function to remove redundant `re.compile` calls when printing regex""" + return match.pattern if match.flags == _REGEX_NO_FLAGS else match + + +def repr_callable(fun: Callable[[BaseExcT_1], bool]) -> str: + """Get the repr of a ``check`` parameter. + + Split out so it can be monkeypatched (e.g. by our hypothesis plugin) + """ + return repr(fun) + + +def _exception_type_name(e: type[BaseException]) -> str: + return repr(e.__name__) + + +def _check_raw_type( + expected_type: type[BaseException] | None, + exception: BaseException, +) -> str | None: + if expected_type is None: + return None + + if not isinstance( + exception, + expected_type, + ): + actual_type_str = _exception_type_name(type(exception)) + expected_type_str = _exception_type_name(expected_type) + if isinstance(exception, BaseExceptionGroup) and not issubclass( + expected_type, BaseExceptionGroup + ): + return f"Unexpected nested {actual_type_str}, expected {expected_type_str}" + return f"{actual_type_str} is not of type {expected_type_str}" + return None + + +class AbstractMatcher(ABC, Generic[BaseExcT_co]): + """ABC with common functionality shared between Matcher and RaisesGroup""" + + def __init__( + self, + match: str | Pattern[str] | None, + check: Callable[[BaseExcT_co], bool] | None, + ) -> None: + if isinstance(match, str): + self.match: Pattern[str] | None = re.compile(match) + else: + self.match = match + self.check = check + 
self._fail_reason: str | None = None + + # used to suppress repeated printing of `repr(self.check)` + self._nested: bool = False + + @property + def fail_reason(self) -> str | None: + """Set after a call to `matches` to give a human-readable + reason for why the match failed. + When used as a context manager the string will be given as the text of an + `AssertionError`""" + return self._fail_reason + + def _check_check( + self: AbstractMatcher[BaseExcT_1], + exception: BaseExcT_1, + ) -> bool: + if self.check is None: + return True + + if self.check(exception): + return True + + check_repr = "" if self._nested else " " + repr_callable(self.check) + self._fail_reason = f"check{check_repr} did not return True" + return False + + def _check_match(self, e: BaseException) -> bool: + if self.match is None or re.search( + self.match, + stringified_exception := _stringify_exception(e), + ): + return True + + maybe_specify_type = ( + f" of {_exception_type_name(type(e))}" + if isinstance(e, BaseExceptionGroup) + else "" + ) + self._fail_reason = f"Regex pattern {_match_pattern(self.match)!r} did not match {stringified_exception!r}{maybe_specify_type}" + if _match_pattern(self.match) == stringified_exception: + self._fail_reason += "\n Did you mean to `re.escape()` the regex?" + return False + + # TODO: when transitioning to pytest, harmonize Matcher and RaisesGroup + # signatures. One names the parameter `exc_val` and the other `exception` + @abstractmethod + def matches( + self: AbstractMatcher[BaseExcT_1], exc_val: BaseException + ) -> TypeGuard[BaseExcT_1]: + """Check if an exception matches the requirements of this AbstractMatcher. + If it fails, `AbstractMatcher.fail_reason` should be set. + """ + + +@final +class Matcher(AbstractMatcher[MatchE]): + """Helper class to be used together with RaisesGroups when you want to specify requirements on sub-exceptions. 
Only specifying the type is redundant, and it's also unnecessary when the type is a nested `RaisesGroup` since it supports the same arguments. + The type is checked with `isinstance`, and does not need to be an exact match. If that is wanted you can use the ``check`` parameter. + :meth:`Matcher.matches` can also be used standalone to check individual exceptions. + + Examples:: + + with RaisesGroup(Matcher(ValueError, match="string")): + ... + with RaisesGroup(Matcher(check=lambda x: x.args == (3, "hello"))): + ... + with RaisesGroup(Matcher(check=lambda x: type(x) is ValueError)): + ... + + Tip: if you install ``hypothesis`` and import it in ``conftest.py`` you will get + readable ``repr``s of ``check`` callables in the output. + """ + + # At least one of the three parameters must be passed. + @overload + def __init__( + self, + exception_type: type[MatchE], + match: str | Pattern[str] = ..., + check: Callable[[MatchE], bool] = ..., + ) -> None: ... + + @overload + def __init__( + self: Matcher[BaseException], # Give E a value. + *, + match: str | Pattern[str], + # If exception_type is not provided, check() must do any typechecks itself. + check: Callable[[BaseException], bool] = ..., + ) -> None: ... + + @overload + def __init__(self, *, check: Callable[[BaseException], bool]) -> None: ...
+ + def __init__( + self, + exception_type: type[MatchE] | None = None, + match: str | Pattern[str] | None = None, + check: Callable[[MatchE], bool] | None = None, + ): + super().__init__(match, check) + if exception_type is None and match is None and check is None: + raise ValueError("You must specify at least one parameter to match on.") + if exception_type is not None and not issubclass(exception_type, BaseException): + raise ValueError( + f"exception_type {exception_type} must be a subclass of BaseException", + ) + self.exception_type = exception_type + + def matches( + self, + exception: BaseException, + ) -> TypeGuard[MatchE]: + """Check if an exception matches the requirements of this Matcher. + If it fails, `Matcher.fail_reason` will be set. + + Examples:: + + assert Matcher(ValueError).matches(my_exception) + # is equivalent to + assert isinstance(my_exception, ValueError) + + # this can be useful when checking e.g. the ``__cause__`` of an exception. + with pytest.raises(ValueError) as excinfo: + ... + assert Matcher(SyntaxError, match="foo").matches(excinfo.value.__cause__) + # above line is equivalent to + assert isinstance(excinfo.value.__cause__, SyntaxError) + assert re.search("foo", str(excinfo.value.__cause__)) + + """ + if not self._check_type(exception): + return False + + if not self._check_match(exception): + return False + + return self._check_check(exception) + + def __repr__(self) -> str: + parameters = [] + if self.exception_type is not None: + parameters.append(self.exception_type.__name__) + if self.match is not None: + # If no flags were specified, discard the redundant re.compile() here. 
+ parameters.append( + f"match={_match_pattern(self.match)!r}", + ) + if self.check is not None: + parameters.append(f"check={repr_callable(self.check)}") + return f'Matcher({", ".join(parameters)})' + + def _check_type(self, exception: BaseException) -> TypeGuard[MatchE]: + self._fail_reason = _check_raw_type(self.exception_type, exception) + return self._fail_reason is None + + +@final +class RaisesGroup(AbstractMatcher[BaseExceptionGroup[BaseExcT_co]]): + """Contextmanager for checking for an expected `ExceptionGroup`. + This works similar to ``pytest.raises``, and a version of it will hopefully be added upstream, after which this can be deprecated and removed. See https://github.com/pytest-dev/pytest/issues/11538 + + + The catching behaviour differs from :ref:`except* <except_star>` in multiple different ways, being much stricter by default. By using ``allow_unwrapped=True`` and ``flatten_subgroups=True`` you can match ``except*`` fully when expecting a single exception. + + #. All specified exceptions must be present, *and no others*. + + * If you expect a variable number of exceptions you need to use ``pytest.raises(ExceptionGroup)`` and manually check the contained exceptions. Consider making use of :func:`Matcher.matches`. + + #. It will only catch exceptions wrapped in an exceptiongroup by default. + + * With ``allow_unwrapped=True`` you can specify a single expected exception or `Matcher` and it will match the exception even if it is not inside an `ExceptionGroup`. If you expect one of several different exception types you need to use a `Matcher` object. + + #. By default it cares about the full structure with nested `ExceptionGroup`'s. You can specify nested `ExceptionGroup`'s by passing `RaisesGroup` objects as expected exceptions. + + * With ``flatten_subgroups=True`` it will "flatten" the raised `ExceptionGroup`, extracting all exceptions inside any nested :class:`ExceptionGroup`, before matching. 
+ + It does not care about the order of the exceptions, so ``RaisesGroup(ValueError, TypeError)`` is equivalent to ``RaisesGroup(TypeError, ValueError)``. + + Examples:: + + with RaisesGroup(ValueError): + raise ExceptionGroup("", (ValueError(),)) + with RaisesGroup(ValueError, ValueError, Matcher(TypeError, match="expected int")): + ... + with RaisesGroup(KeyboardInterrupt, match="hello", check=lambda x: type(x) is BaseExceptionGroup): + ... + with RaisesGroup(RaisesGroup(ValueError)): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + # flatten_subgroups + with RaisesGroup(ValueError, flatten_subgroups=True): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + # allow_unwrapped + with RaisesGroup(ValueError, allow_unwrapped=True): + raise ValueError + + + `RaisesGroup.matches` can also be used directly to check a standalone exception group. + + + The matching algorithm is greedy, which means cases such as this may fail:: + + with RaisesGroup(ValueError, Matcher(ValueError, match="hello")): + raise ExceptionGroup("", (ValueError("hello"), ValueError("goodbye"))) + + even though it generally does not care about the order of the exceptions in the group. + To avoid the above you should specify the first ValueError with a Matcher as well. + + Tip: if you install ``hypothesis`` and import it in ``conftest.py`` you will get + readable ``repr``s of ``check`` callables in the output. + """ + + # allow_unwrapped=True requires: singular exception, exception not being + # RaisesGroup instance, match is None, check is None + @overload + def __init__( + self, + exception: type[BaseExcT_co] | Matcher[BaseExcT_co], + *, + allow_unwrapped: Literal[True], + flatten_subgroups: bool = False, + ) -> None: ...
    # flatten_subgroups = True also requires no nested RaisesGroup
    # (this signature therefore accepts only exception types and Matchers, and
    # makes ``flatten_subgroups`` a required keyword typed as Literal[True]).
    @overload
    def __init__(
        self,
        exception: type[BaseExcT_co] | Matcher[BaseExcT_co],
        *other_exceptions: type[BaseExcT_co] | Matcher[BaseExcT_co],
        flatten_subgroups: Literal[True],
        match: str | Pattern[str] | None = None,
        check: Callable[[BaseExceptionGroup[BaseExcT_co]], bool] | None = None,
    ) -> None: ...

    # simplify the typevars if possible (the following 3 are equivalent but go simpler->complicated)
    # ... the first handles RaisesGroup[ValueError], the second RaisesGroup[ExceptionGroup[ValueError]],
    # the third RaisesGroup[ValueError | ExceptionGroup[ValueError]].
    # ... otherwise, we will get results like RaisesGroup[ValueError | ExceptionGroup[Never]] (I think)
    # (technically correct but misleading)

    # Exception-bound, flat: only exception types and Matchers are expected.
    @overload
    def __init__(
        self: RaisesGroup[ExcT_1],
        exception: type[ExcT_1] | Matcher[ExcT_1],
        *other_exceptions: type[ExcT_1] | Matcher[ExcT_1],
        match: str | Pattern[str] | None = None,
        check: Callable[[ExceptionGroup[ExcT_1]], bool] | None = None,
    ) -> None: ...

    # Exception-bound, nested: only nested RaisesGroup objects are expected.
    @overload
    def __init__(
        self: RaisesGroup[ExceptionGroup[ExcT_2]],
        exception: RaisesGroup[ExcT_2],
        *other_exceptions: RaisesGroup[ExcT_2],
        match: str | Pattern[str] | None = None,
        check: Callable[[ExceptionGroup[ExceptionGroup[ExcT_2]]], bool] | None = None,
    ) -> None: ...

    # Exception-bound, mixed: types, Matchers and nested RaisesGroup objects.
    @overload
    def __init__(
        self: RaisesGroup[ExcT_1 | ExceptionGroup[ExcT_2]],
        exception: type[ExcT_1] | Matcher[ExcT_1] | RaisesGroup[ExcT_2],
        *other_exceptions: type[ExcT_1] | Matcher[ExcT_1] | RaisesGroup[ExcT_2],
        match: str | Pattern[str] | None = None,
        check: (
            Callable[[ExceptionGroup[ExcT_1 | ExceptionGroup[ExcT_2]]], bool] | None
        ) = None,
    ) -> None: ...

    # same as the above 3 but handling BaseException

    # BaseException-bound, flat.
    @overload
    def __init__(
        self: RaisesGroup[BaseExcT_1],
        exception: type[BaseExcT_1] | Matcher[BaseExcT_1],
        *other_exceptions: type[BaseExcT_1] | Matcher[BaseExcT_1],
        match: str | Pattern[str] | None = None,
        check: Callable[[BaseExceptionGroup[BaseExcT_1]], bool] | None = None,
    ) -> None: ...

    # BaseException-bound, nested.
    @overload
    def __init__(
        self: RaisesGroup[BaseExceptionGroup[BaseExcT_2]],
        exception: RaisesGroup[BaseExcT_2],
        *other_exceptions: RaisesGroup[BaseExcT_2],
        match: str | Pattern[str] | None = None,
        check: (
            Callable[[BaseExceptionGroup[BaseExceptionGroup[BaseExcT_2]]], bool] | None
        ) = None,
    ) -> None: ...

    # BaseException-bound, mixed.
    @overload
    def __init__(
        self: RaisesGroup[BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]],
        exception: type[BaseExcT_1] | Matcher[BaseExcT_1] | RaisesGroup[BaseExcT_2],
        *other_exceptions: type[BaseExcT_1]
        | Matcher[BaseExcT_1]
        | RaisesGroup[BaseExcT_2],
        match: str | Pattern[str] | None = None,
        check: (
            Callable[
                [BaseExceptionGroup[BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]]],
                bool,
            ]
            | None
        ) = None,
    ) -> None: ...
    def __init__(
        self: RaisesGroup[ExcT_1 | BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]],
        exception: type[BaseExcT_1] | Matcher[BaseExcT_1] | RaisesGroup[BaseExcT_2],
        *other_exceptions: type[BaseExcT_1]
        | Matcher[BaseExcT_1]
        | RaisesGroup[BaseExcT_2],
        allow_unwrapped: bool = False,
        flatten_subgroups: bool = False,
        match: str | Pattern[str] | None = None,
        check: (
            Callable[[BaseExceptionGroup[BaseExcT_1]], bool]
            | Callable[[ExceptionGroup[ExcT_1]], bool]
            | None
        ) = None,
    ):
        """Record the expected exceptions and validate the option combination.

        Args:
            exception: First expected entry: an exception type, a ``Matcher``
                or a nested ``RaisesGroup``.
            *other_exceptions: Further expected entries of the same kinds.
            allow_unwrapped: Stored on the instance. Rejected together with
                more than one expected entry, with a ``RaisesGroup`` as the
                expected entry, or with ``match``/``check``.
            flatten_subgroups: Stored on the instance. Rejected when any
                expected entry is a ``RaisesGroup``.
            match: Forwarded to the base class initialiser.
            check: Forwarded to the base class initialiser.

        Raises:
            ValueError: For the invalid combinations above, or when an
                expected entry is neither a ``BaseException`` subclass, a
                ``Matcher`` nor a ``RaisesGroup``.
        """
        # The type hint on the `self` and `check` parameters uses different formats
        # that are *very* hard to reconcile while adhering to the overloads, so we cast
        # it to avoid an error when passing it to super().__init__
        check = cast(
            "Callable[["
            "BaseExceptionGroup[ExcT_1|BaseExcT_1|BaseExceptionGroup[BaseExcT_2]]"
            "], bool]",
            check,
        )
        super().__init__(match, check)
        self.expected_exceptions: tuple[
            type[BaseExcT_co] | Matcher[BaseExcT_co] | RaisesGroup[BaseException], ...
        ] = (
            exception,
            *other_exceptions,
        )
        self.allow_unwrapped = allow_unwrapped
        self.flatten_subgroups: bool = flatten_subgroups
        # Becomes True in the loop below as soon as one expected entry is not
        # an `Exception` subclass (or a nested RaisesGroup already has it set).
        self.is_baseexceptiongroup: bool = False

        # The three `allow_unwrapped` checks run in this order, so the first
        # violated rule decides which message the caller sees.
        if allow_unwrapped and other_exceptions:
            raise ValueError(
                "You cannot specify multiple exceptions with `allow_unwrapped=True.`"
                " If you want to match one of multiple possible exceptions you should"
                " use a `Matcher`."
                " E.g. `Matcher(check=lambda e: isinstance(e, (...)))`",
            )
        if allow_unwrapped and isinstance(exception, RaisesGroup):
            raise ValueError(
                "`allow_unwrapped=True` has no effect when expecting a `RaisesGroup`."
                " You might want it in the expected `RaisesGroup`, or"
                " `flatten_subgroups=True` if you don't care about the structure.",
            )
        if allow_unwrapped and (match is not None or check is not None):
            raise ValueError(
                "`allow_unwrapped=True` bypasses the `match` and `check` parameters"
                " if the exception is unwrapped. If you intended to match/check the"
                " exception you should use a `Matcher` object. If you want to match/check"
                " the exceptiongroup when the exception *is* wrapped you need to"
                " do e.g. `if isinstance(exc.value, ExceptionGroup):"
                " assert RaisesGroup(...).matches(exc.value)` afterwards.",
            )

        # verify `expected_exceptions` and set `self.is_baseexceptiongroup`
        for exc in self.expected_exceptions:
            if isinstance(exc, RaisesGroup):
                if self.flatten_subgroups:
                    raise ValueError(
                        "You cannot specify a nested structure inside a RaisesGroup with"
                        " `flatten_subgroups=True`. The parameter will flatten subgroups"
                        " in the raised exceptiongroup before matching, which would never"
                        " match a nested structure.",
                    )
                self.is_baseexceptiongroup |= exc.is_baseexceptiongroup
                # Mark the child as nested.
                exc._nested = True
            elif isinstance(exc, Matcher):
                if exc.exception_type is not None:
                    # Matcher __init__ assures it's a subclass of BaseException
                    self.is_baseexceptiongroup |= not issubclass(
                        exc.exception_type,
                        Exception,
                    )
                exc._nested = True
            elif isinstance(exc, type) and issubclass(exc, BaseException):
                self.is_baseexceptiongroup |= not issubclass(exc, Exception)
            else:
                raise ValueError(
                    f'Invalid argument "{exc!r}" must be exception type, Matcher, or'
                    " RaisesGroup.",
                )

    # Typing-only signatures for `__enter__`: the first covers groups of
    # `Exception` subclasses, the second groups of any `BaseException`.
    @overload
    def __enter__(
        self: RaisesGroup[ExcT_1],
    ) -> ExceptionInfo[ExceptionGroup[ExcT_1]]: ...
    @overload
    def __enter__(
        self: RaisesGroup[BaseExcT_1],
    ) -> ExceptionInfo[BaseExceptionGroup[BaseExcT_1]]: ...
    def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[BaseException]]:
        """Enter the context: create `self.excinfo` and return it.

        The object comes from `ExceptionInfo.for_later()`; `__exit__` later
        calls `fill_unfilled` on it once the raised group has been verified.
        """
        self.excinfo: ExceptionInfo[BaseExceptionGroup[BaseExcT_co]] = (
            ExceptionInfo.for_later()
        )
        return self.excinfo

    def __repr__(self) -> str:
        """Return ``RaisesGroup(...)`` listing the expected entries followed by
        every option that is set (`allow_unwrapped`, `flatten_subgroups`,
        `match`, `check`)."""
        # Exception types are shown by bare name, Matcher/RaisesGroup by repr.
        parameters = [
            e.__name__ if isinstance(e, type) else repr(e)
            for e in self.expected_exceptions
        ]
        if self.allow_unwrapped:
            parameters.append(f"allow_unwrapped={self.allow_unwrapped}")
        if self.flatten_subgroups:
            parameters.append(f"flatten_subgroups={self.flatten_subgroups}")
        if self.match is not None:
            # If no flags were specified, discard the redundant re.compile() here.
            parameters.append(f"match={_match_pattern(self.match)!r}")
        if self.check is not None:
            parameters.append(f"check={repr_callable(self.check)}")
        return f"RaisesGroup({', '.join(parameters)})"

    def _unroll_exceptions(
        self,
        exceptions: Sequence[BaseException],
    ) -> Sequence[BaseException]:
        """Used if `flatten_subgroups=True`.

        Recursively replaces every `BaseExceptionGroup` in `exceptions` with
        its contained exceptions and returns the resulting flat list, keeping
        the original order.
        """
        res: list[BaseException] = []
        for exc in exceptions:
            if isinstance(exc, BaseExceptionGroup):
                res.extend(self._unroll_exceptions(exc.exceptions))

            else:
                res.append(exc)
        return res

    @overload
    def matches(
        self: RaisesGroup[ExcT_1],
        exc_val: BaseException | None,
    ) -> TypeGuard[ExceptionGroup[ExcT_1]]: ...
    @overload
    def matches(
        self: RaisesGroup[BaseExcT_1],
        exc_val: BaseException | None,
    ) -> TypeGuard[BaseExceptionGroup[BaseExcT_1]]: ...

    def matches(
        self,
        exc_val: BaseException | None,
    ) -> TypeGuard[BaseExceptionGroup[BaseExcT_co]]:
        """Check if an exception matches the requirements of this RaisesGroup.
        If it fails, `RaisesGroup.fail_reason` will be set.

        The checks run in this order: `exc_val` must not be None; it must be
        an exception group (or, with `allow_unwrapped`, a bare exception that
        satisfies the single expected entry); the group itself must pass
        `_check_match`; the contained exceptions must pair up with the
        expected entries; finally the result of `_check_check` is returned.

        Example::

            with pytest.raises(TypeError) as excinfo:
                ...
            assert RaisesGroup(ValueError).matches(excinfo.value.__cause__)
            # the above line is equivalent to
            myexc = excinfo.value.__cause__
            assert isinstance(myexc, BaseExceptionGroup)
            assert len(myexc.exceptions) == 1
            assert isinstance(myexc.exceptions[0], ValueError)
        """
        self._fail_reason = None
        if exc_val is None:
            self._fail_reason = "exception is None"
            return False
        if not isinstance(exc_val, BaseExceptionGroup):
            # we opt to only print type of the exception here, as the repr would
            # likely be quite long
            not_group_msg = f"{type(exc_val).__name__!r} is not an exception group"
            if len(self.expected_exceptions) > 1:
                self._fail_reason = not_group_msg
                return False
            # if we have 1 expected exception, check if it would work even if
            # allow_unwrapped is not set
            res = self._check_expected(self.expected_exceptions[0], exc_val)
            if res is None and self.allow_unwrapped:
                return True

            # Not accepted: choose the most helpful explanation.
            if res is None:
                self._fail_reason = (
                    f"{not_group_msg}, but would match with `allow_unwrapped=True`"
                )
            elif self.allow_unwrapped:
                self._fail_reason = res
            else:
                self._fail_reason = not_group_msg
            return False

        actual_exceptions: Sequence[BaseException] = exc_val.exceptions
        if self.flatten_subgroups:
            actual_exceptions = self._unroll_exceptions(actual_exceptions)

        if not self._check_match(exc_val):
            # `_check_match` is re-run below against the contained exception,
            # so remember the reason recorded for the group itself.
            old_reason = self._fail_reason
            if (
                len(actual_exceptions) == len(self.expected_exceptions) == 1
                and isinstance(expected := self.expected_exceptions[0], type)
                and isinstance(actual := actual_exceptions[0], expected)
                and self._check_match(actual)
            ):
                # The pattern matches the single contained exception rather
                # than the group: extend the reason with a suggestion.
                assert self.match is not None, "can't be None if _check_match failed"
                assert self._fail_reason is old_reason is not None
                self._fail_reason += f", but matched the expected {self._repr_expected(expected)}. You might want RaisesGroup(Matcher({expected.__name__}, match={_match_pattern(self.match)!r}))"
            else:
                self._fail_reason = old_reason
            return False

        # do the full check on expected exceptions
        if not self._check_exceptions(
            exc_val,
            actual_exceptions,
        ):
            assert self._fail_reason is not None
            old_reason = self._fail_reason
            # if we're not expecting a nested structure, and there is one, do a second
            # pass where we try flattening it
            if (
                not self.flatten_subgroups
                and not any(
                    isinstance(e, RaisesGroup) for e in self.expected_exceptions
                )
                and any(isinstance(e, BaseExceptionGroup) for e in actual_exceptions)
                and self._check_exceptions(
                    exc_val,
                    self._unroll_exceptions(exc_val.exceptions),
                )
            ):
                # only indent if it's a single-line reason. In a multi-line there's already
                # indented lines that this does not belong to.
                # NOTE(review): the width of this indent literal may have been
                # altered by whitespace collapsing in this rendering -- confirm
                # against the upstream file.
                indent = " " if "\n" not in self._fail_reason else ""
                self._fail_reason = (
                    old_reason
                    + f"\n{indent}Did you mean to use `flatten_subgroups=True`?"
                )
            else:
                self._fail_reason = old_reason
            return False

        # Only run `self.check` once we know `exc_val` is of the correct type.
        # TODO: if this fails, we should say the *group* did not match
        return self._check_check(exc_val)

    @staticmethod
    def _check_expected(
        expected_type: (
            type[BaseException] | Matcher[BaseException] | RaisesGroup[BaseException]
        ),
        exception: BaseException,
    ) -> str | None:
        """Helper method for `RaisesGroup.matches` and `RaisesGroup._check_exceptions`
        to check one of potentially several expected exceptions.

        Returns None when `exception` satisfies `expected_type`, otherwise a
        string describing why it does not.
        """
        if isinstance(expected_type, type):
            return _check_raw_type(expected_type, exception)
        res = expected_type.matches(exception)
        if res:
            return None
        assert expected_type.fail_reason is not None
        # A reason that starts with a newline is multi-line: keep the leading
        # newline in front of the repr and indent the reason below it.
        # NOTE(review): the width of the indent prefix literal may have been
        # altered by whitespace collapsing in this rendering -- confirm against
        # the upstream file.
        if expected_type.fail_reason.startswith("\n"):
            return f"\n{expected_type!r}: {indent(expected_type.fail_reason, ' ')}"
        return f"{expected_type!r}: {expected_type.fail_reason}"

    @staticmethod
    def _repr_expected(e: type[BaseException] | AbstractMatcher[BaseException]) -> str:
        """Get the repr of an expected type/Matcher/RaisesGroup, but we only want
        the name if it's a type"""
        if isinstance(e, type):
            return _exception_type_name(e)
        return repr(e)

    @overload
    def _check_exceptions(
        self: RaisesGroup[ExcT_1],
        _exc_val: Exception,
        actual_exceptions: Sequence[Exception],
    ) -> TypeGuard[ExceptionGroup[ExcT_1]]: ...
    @overload
    def _check_exceptions(
        self: RaisesGroup[BaseExcT_1],
        _exc_val: BaseException,
        actual_exceptions: Sequence[BaseException],
    ) -> TypeGuard[BaseExceptionGroup[BaseExcT_1]]: ...
+ + def _check_exceptions( + self, + _exc_val: BaseException, + actual_exceptions: Sequence[BaseException], + ) -> TypeGuard[BaseExceptionGroup[BaseExcT_co]]: + """helper method for RaisesGroup.matches that attempts to pair up expected and actual exceptions""" + # full table with all results + results = ResultHolder(self.expected_exceptions, actual_exceptions) + + # (indexes of) raised exceptions that haven't (yet) found an expected + remaining_actual = list(range(len(actual_exceptions))) + # (indexes of) expected exceptions that haven't found a matching raised + failed_expected: list[int] = [] + # successful greedy matches + matches: dict[int, int] = {} + + # loop over expected exceptions first to get a more predictable result + for i_exp, expected in enumerate(self.expected_exceptions): + for i_rem in remaining_actual: + res = self._check_expected(expected, actual_exceptions[i_rem]) + results.set_result(i_exp, i_rem, res) + if res is None: + remaining_actual.remove(i_rem) + matches[i_exp] = i_rem + break + else: + failed_expected.append(i_exp) + + # All exceptions matched up successfully + if not remaining_actual and not failed_expected: + return True + + # in case of a single expected and single raised we simplify the output + if 1 == len(actual_exceptions) == len(self.expected_exceptions): + assert not matches + self._fail_reason = res + return False + + # The test case is failing, so we can do a slow and exhaustive check to find + # duplicate matches etc that will be helpful in debugging + for i_exp, expected in enumerate(self.expected_exceptions): + for i_actual, actual in enumerate(actual_exceptions): + if results.has_result(i_exp, i_actual): + continue + results.set_result( + i_exp, i_actual, self._check_expected(expected, actual) + ) + + successful_str = ( + f"{len(matches)} matched exception{'s' if len(matches) > 1 else ''}. 
" + if matches + else "" + ) + + # all expected were found + if not failed_expected and results.no_match_for_actual(remaining_actual): + self._fail_reason = f"{successful_str}Unexpected exception(s): {[actual_exceptions[i] for i in remaining_actual]!r}" + return False + # all raised exceptions were expected + if not remaining_actual and results.no_match_for_expected(failed_expected): + self._fail_reason = f"{successful_str}Too few exceptions raised, found no match for: [{', '.join(self._repr_expected(self.expected_exceptions[i]) for i in failed_expected)}]" + return False + + # if there's only one remaining and one failed, and the unmatched didn't match anything else, + # we elect to only print why the remaining and the failed didn't match. + if ( + 1 == len(remaining_actual) == len(failed_expected) + and results.no_match_for_actual(remaining_actual) + and results.no_match_for_expected(failed_expected) + ): + self._fail_reason = f"{successful_str}{results.get_result(failed_expected[0], remaining_actual[0])}" + return False + + # there's both expected and raised exceptions without matches + s = "" + if matches: + s += f"\n{successful_str}" + indent_1 = " " * 2 + indent_2 = " " * 4 + + if not remaining_actual: + s += "\nToo few exceptions raised!" + elif not failed_expected: + s += "\nUnexpected exception(s)!" 
+ + if failed_expected: + s += "\nThe following expected exceptions did not find a match:" + rev_matches = {v: k for k, v in matches.items()} + for i_failed in failed_expected: + s += ( + f"\n{indent_1}{self._repr_expected(self.expected_exceptions[i_failed])}" + ) + for i_actual, actual in enumerate(actual_exceptions): + if results.get_result(i_exp, i_actual) is None: + # we print full repr of match target + s += f"\n{indent_2}It matches {actual!r} which was paired with {self._repr_expected(self.expected_exceptions[rev_matches[i_actual]])}" + + if remaining_actual: + s += "\nThe following raised exceptions did not find a match" + for i_actual in remaining_actual: + s += f"\n{indent_1}{actual_exceptions[i_actual]!r}:" + for i_exp, expected in enumerate(self.expected_exceptions): + res = results.get_result(i_exp, i_actual) + if i_exp in failed_expected: + assert res is not None + if res[0] != "\n": + s += "\n" + s += indent(res, indent_2) + if res is None: + # we print full repr of match target + s += f"\n{indent_2}It matches {self._repr_expected(expected)} which was paired with {actual_exceptions[matches[i_exp]]!r}" + + if len(self.expected_exceptions) == len(actual_exceptions) and possible_match( + results + ): + s += "\nThere exist a possible match when attempting an exhaustive check, but RaisesGroup uses a greedy algorithm. Please make your expected exceptions more stringent with `Matcher` etc so the greedy algorithm can function." 
+ self._fail_reason = s + return False + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> bool: + __tracebackhide__ = True + assert ( + exc_type is not None + ), f"DID NOT RAISE any exception, expected {self.expected_type()}" + assert ( + self.excinfo is not None + ), "Internal error - should have been constructed in __enter__" + + group_str = ( + "(group)" + if self.allow_unwrapped and not issubclass(exc_type, BaseExceptionGroup) + else "group" + ) + + assert self.matches( + exc_val, + ), f"Raised exception {group_str} did not match: {self._fail_reason}" + + # Cast to narrow the exception type now that it's verified. + exc_info = cast( + "tuple[type[BaseExceptionGroup[BaseExcT_co]], BaseExceptionGroup[BaseExcT_co], types.TracebackType]", + (exc_type, exc_val, exc_tb), + ) + self.excinfo.fill_unfilled(exc_info) + return True + + def expected_type(self) -> str: + subexcs = [] + for e in self.expected_exceptions: + if isinstance(e, Matcher): + subexcs.append(str(e)) + elif isinstance(e, RaisesGroup): + subexcs.append(e.expected_type()) + elif isinstance(e, type): + subexcs.append(e.__name__) + else: # pragma: no cover + raise AssertionError("unknown type") + group_type = "Base" if self.is_baseexceptiongroup else "" + return f"{group_type}ExceptionGroup({', '.join(subexcs)})" + + +@final +class NotChecked: ... + + +class ResultHolder: + def __init__( + self, + expected_exceptions: tuple[ + type[BaseException] | AbstractMatcher[BaseException], ... 
+ ], + actual_exceptions: Sequence[BaseException], + ) -> None: + self.results: list[list[str | type[NotChecked] | None]] = [ + [NotChecked for _ in expected_exceptions] for _ in actual_exceptions + ] + + def set_result(self, expected: int, actual: int, result: str | None) -> None: + self.results[actual][expected] = result + + def get_result(self, expected: int, actual: int) -> str | None: + res = self.results[actual][expected] + # mypy doesn't support `assert res is not NotChecked` + assert not isinstance(res, type) + return res + + def has_result(self, expected: int, actual: int) -> bool: + return self.results[actual][expected] is not NotChecked + + def no_match_for_expected(self, expected: list[int]) -> bool: + for i in expected: + for actual_results in self.results: + assert actual_results[i] is not NotChecked + if actual_results[i] is None: + return False + return True + + def no_match_for_actual(self, actual: list[int]) -> bool: + for i in actual: + for res in self.results[i]: + assert res is not NotChecked + if res is None: + return False + return True + + +def possible_match(results: ResultHolder, used: set[int] | None = None) -> bool: + if used is None: + used = set() + curr_row = len(used) + if curr_row == len(results.results): + return True + + for i, val in enumerate(results.results[curr_row]): + if val is None and i not in used and possible_match(results, used | {i}): + return True + return False diff --git a/contrib/python/trio/trio/testing/_sequencer.py b/contrib/python/trio/trio/testing/_sequencer.py new file mode 100644 index 00000000000..32171cb2a27 --- /dev/null +++ b/contrib/python/trio/trio/testing/_sequencer.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from collections import defaultdict +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +import attrs + +from .. 
import Event, _core, _util + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + +@_util.final [email protected](eq=False, slots=False) +class Sequencer: + """A convenience class for forcing code in different tasks to run in an + explicit linear order. + + Instances of this class implement a ``__call__`` method which returns an + async context manager. The idea is that you pass a sequence number to + ``__call__`` to say where this block of code should go in the linear + sequence. Block 0 starts immediately, and then block N doesn't start until + block N-1 has finished. + + Example: + An extremely elaborate way to print the numbers 0-5, in order:: + + async def worker1(seq): + async with seq(0): + print(0) + async with seq(4): + print(4) + + async def worker2(seq): + async with seq(2): + print(2) + async with seq(5): + print(5) + + async def worker3(seq): + async with seq(1): + print(1) + async with seq(3): + print(3) + + async def main(): + seq = trio.testing.Sequencer() + async with trio.open_nursery() as nursery: + nursery.start_soon(worker1, seq) + nursery.start_soon(worker2, seq) + nursery.start_soon(worker3, seq) + + """ + + _sequence_points: defaultdict[int, Event] = attrs.field( + factory=lambda: defaultdict(Event), + init=False, + ) + _claimed: set[int] = attrs.field(factory=set, init=False) + _broken: bool = attrs.field(default=False, init=False) + + @asynccontextmanager + async def __call__(self, position: int) -> AsyncIterator[None]: + if position in self._claimed: + raise RuntimeError(f"Attempted to reuse sequence point {position}") + if self._broken: + raise RuntimeError("sequence broken!") + self._claimed.add(position) + if position != 0: + try: + await self._sequence_points[position].wait() + except _core.Cancelled: + self._broken = True + for event in self._sequence_points.values(): + event.set() + raise RuntimeError( + "Sequencer wait cancelled -- sequence broken", + ) from None + else: + if self._broken: + raise 
RuntimeError("sequence broken!") + try: + yield + finally: + self._sequence_points[position + 1].set() diff --git a/contrib/python/trio/trio/testing/_trio_test.py b/contrib/python/trio/trio/testing/_trio_test.py new file mode 100644 index 00000000000..226e5591966 --- /dev/null +++ b/contrib/python/trio/trio/testing/_trio_test.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from functools import partial, wraps +from typing import TYPE_CHECKING, TypeVar + +from .. import _core +from ..abc import Clock, Instrument + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from typing_extensions import ParamSpec + + ArgsT = ParamSpec("ArgsT") + + +RetT = TypeVar("RetT") + + +def trio_test(fn: Callable[ArgsT, Awaitable[RetT]]) -> Callable[ArgsT, RetT]: + """Converts an async test function to be synchronous, running via Trio. + + Usage:: + + @trio_test + async def test_whatever(): + await ... + + If a pytest fixture is passed in that subclasses the :class:`~trio.abc.Clock` or + :class:`~trio.abc.Instrument` ABCs, then those are passed to :meth:`trio.run()`. 
+ """ + + @wraps(fn) + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: + __tracebackhide__ = True + clocks = [c for c in kwargs.values() if isinstance(c, Clock)] + if not clocks: + clock = None + elif len(clocks) == 1: + clock = clocks[0] + else: + raise ValueError("too many clocks spoil the broth!") + instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] + return _core.run( + partial(fn, *args, **kwargs), + clock=clock, + instruments=instruments, + ) + + return wrapper diff --git a/contrib/python/trio/trio/to_thread.py b/contrib/python/trio/trio/to_thread.py new file mode 100644 index 00000000000..45ea5b480bf --- /dev/null +++ b/contrib/python/trio/trio/to_thread.py @@ -0,0 +1,4 @@ +from ._threads import current_default_thread_limiter, to_thread_run_sync as run_sync + +# need to use __all__ for pyright --verifytypes to see re-exports when renaming them +__all__ = ["current_default_thread_limiter", "run_sync"] diff --git a/contrib/python/trio/ya.make b/contrib/python/trio/ya.make new file mode 100644 index 00000000000..6349037343b --- /dev/null +++ b/contrib/python/trio/ya.make @@ -0,0 +1,122 @@ +# Generated by devtools/yamaker (pypi). 
+ +PY3_LIBRARY() + +VERSION(0.31.0) + +LICENSE(Apache-2.0 AND BSD-3-Clause AND LicenseRef-scancode-unknown-license-reference AND MIT AND "MIT OR Apache-2.0") + +PEERDIR( + contrib/python/attrs + contrib/python/idna + contrib/python/outcome + contrib/python/sniffio + contrib/python/sortedcontainers +) + +NO_LINT() + +NO_CHECK_IMPORTS( + trio._core._generated_windows_ffi + trio._core._io_kqueue + trio._core._io_windows + trio._core._windows_cffi + trio._subprocess_platform.windows + trio._tools.gen_exports + trio._tools.sync_requirements + trio._tools.windows_ffi_build + trio._wait_for_object + trio._windows_pipes +) + +PY_SRCS( + TOP_LEVEL + trio/__init__.py + trio/__main__.py + trio/_abc.py + trio/_channel.py + trio/_core/__init__.py + trio/_core/_asyncgens.py + trio/_core/_concat_tb.py + trio/_core/_entry_queue.py + trio/_core/_exceptions.py + trio/_core/_generated_instrumentation.py + trio/_core/_generated_io_epoll.py + trio/_core/_generated_io_kqueue.py + trio/_core/_generated_io_windows.py + trio/_core/_generated_run.py + trio/_core/_generated_windows_ffi.py + trio/_core/_instrumentation.py + trio/_core/_io_common.py + trio/_core/_io_epoll.py + trio/_core/_io_kqueue.py + trio/_core/_io_windows.py + trio/_core/_ki.py + trio/_core/_local.py + trio/_core/_mock_clock.py + trio/_core/_parking_lot.py + trio/_core/_run.py + trio/_core/_run_context.py + trio/_core/_thread_cache.py + trio/_core/_traps.py + trio/_core/_unbounded_queue.py + trio/_core/_wakeup_socketpair.py + trio/_core/_windows_cffi.py + trio/_deprecate.py + trio/_dtls.py + trio/_file_io.py + trio/_highlevel_generic.py + trio/_highlevel_open_tcp_listeners.py + trio/_highlevel_open_tcp_stream.py + trio/_highlevel_open_unix_stream.py + trio/_highlevel_serve_listeners.py + trio/_highlevel_socket.py + trio/_highlevel_ssl_helpers.py + trio/_path.py + trio/_repl.py + trio/_signals.py + trio/_socket.py + trio/_ssl.py + trio/_subprocess.py + trio/_subprocess_platform/__init__.py + 
trio/_subprocess_platform/kqueue.py + trio/_subprocess_platform/waitid.py + trio/_subprocess_platform/windows.py + trio/_sync.py + trio/_threads.py + trio/_timeouts.py + trio/_tools/__init__.py + trio/_tools/gen_exports.py + trio/_tools/mypy_annotate.py + trio/_tools/sync_requirements.py + trio/_tools/windows_ffi_build.py + trio/_unix_pipes.py + trio/_util.py + trio/_version.py + trio/_wait_for_object.py + trio/_windows_pipes.py + trio/abc.py + trio/from_thread.py + trio/lowlevel.py + trio/socket.py + trio/testing/__init__.py + trio/testing/_check_streams.py + trio/testing/_checkpoints.py + trio/testing/_fake_net.py + trio/testing/_memory_streams.py + trio/testing/_network.py + trio/testing/_raises_group.py + trio/testing/_sequencer.py + trio/testing/_trio_test.py + trio/to_thread.py +) + +RESOURCE_FILES( + PREFIX contrib/python/trio/ + .dist-info/METADATA + .dist-info/entry_points.txt + .dist-info/top_level.txt + trio/py.typed +) + +END() |
