diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2024-05-25 22:39:06 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2024-05-25 22:48:55 +0300 |
commit | 7ddcd63286aa8b7a10c462f4d29507790855d43b (patch) | |
tree | ff8411bc6137aa568453f070adb7a508072841fd | |
parent | 93dedde5f347af8bf358ce6a131662ed784220f4 (diff) | |
download | ydb-7ddcd63286aa8b7a10c462f4d29507790855d43b.tar.gz |
Intermediate changes
36 files changed, 3240 insertions, 2467 deletions
diff --git a/contrib/python/anyio/.dist-info/METADATA b/contrib/python/anyio/.dist-info/METADATA index 5e46476e02..e02715ca28 100644 --- a/contrib/python/anyio/.dist-info/METADATA +++ b/contrib/python/anyio/.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: anyio -Version: 3.7.1 +Version: 4.3.0 Summary: High level compatibility layer for multiple asynchronous event loop implementations Author-email: Alex Grönholm <alex.gronholm@nextday.fi> License: MIT @@ -15,36 +15,35 @@ Classifier: Framework :: AnyIO Classifier: Typing :: Typed Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 -Requires-Python: >=3.7 +Classifier: Programming Language :: Python :: 3.12 +Requires-Python: >=3.8 Description-Content-Type: text/x-rst License-File: LICENSE -Requires-Dist: idna (>=2.8) -Requires-Dist: sniffio (>=1.1) -Requires-Dist: exceptiongroup ; python_version < "3.11" -Requires-Dist: typing-extensions ; python_version < "3.8" +Requires-Dist: idna >=2.8 +Requires-Dist: sniffio >=1.1 +Requires-Dist: exceptiongroup >=1.0.2 ; python_version < "3.11" +Requires-Dist: typing-extensions >=4.1 ; python_version < "3.11" Provides-Extra: doc Requires-Dist: packaging ; extra == 'doc' -Requires-Dist: Sphinx ; extra == 'doc' -Requires-Dist: sphinx-rtd-theme (>=1.2.2) ; extra == 'doc' -Requires-Dist: sphinxcontrib-jquery ; extra == 'doc' -Requires-Dist: sphinx-autodoc-typehints (>=1.2.0) ; extra == 'doc' +Requires-Dist: Sphinx >=7 ; extra == 'doc' +Requires-Dist: sphinx-rtd-theme ; extra == 'doc' +Requires-Dist: sphinx-autodoc-typehints >=1.2.0 ; extra == 'doc' Provides-Extra: test Requires-Dist: anyio[trio] ; extra == 'test' -Requires-Dist: coverage[toml] (>=4.5) ; extra == 'test' -Requires-Dist: hypothesis (>=4.0) ; extra == 'test' -Requires-Dist: psutil (>=5.9) ; extra == 'test' -Requires-Dist: pytest (>=7.0) ; extra == 'test' -Requires-Dist: pytest-mock (>=3.6.1) ; extra == 'test' +Requires-Dist: coverage[toml] >=7 ; extra == 'test' +Requires-Dist: exceptiongroup >=1.2.0 ; extra == 'test' +Requires-Dist: hypothesis >=4.0 ; extra == 'test' +Requires-Dist: psutil >=5.9 ; extra == 'test' +Requires-Dist: pytest >=7.0 ; extra == 'test' +Requires-Dist: pytest-mock >=3.6.1 ; extra == 'test' Requires-Dist: trustme ; extra == 'test' -Requires-Dist: uvloop (>=0.17) ; (python_version < "3.12" and platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test' -Requires-Dist: mock (>=4) ; (python_version < "3.8") and extra == 'test' +Requires-Dist: uvloop >=0.17 ; (platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test' Provides-Extra: trio -Requires-Dist: trio (<0.22) ; extra == 'trio' +Requires-Dist: trio >=0.23 ; extra == 'trio' .. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg :target: https://github.com/agronholm/anyio/actions/workflows/test.yml diff --git a/contrib/python/anyio/anyio/__init__.py b/contrib/python/anyio/anyio/__init__.py index 29fb3561e4..7bfe231645 100644 --- a/contrib/python/anyio/anyio/__init__.py +++ b/contrib/python/anyio/anyio/__init__.py @@ -1,165 +1,72 @@ from __future__ import annotations -__all__ = ( - "maybe_async", - "maybe_async_cm", - "run", - "sleep", - "sleep_forever", - "sleep_until", - "current_time", - "get_all_backends", - "get_cancelled_exc_class", - "BrokenResourceError", - "BrokenWorkerProcess", - "BusyResourceError", - "ClosedResourceError", - "DelimiterNotFound", - "EndOfStream", - "ExceptionGroup", - "IncompleteRead", - "TypedAttributeLookupError", - "WouldBlock", - "AsyncFile", - "Path", - "open_file", - "wrap_file", - "aclose_forcefully", - "open_signal_receiver", - "connect_tcp", - "connect_unix", - "create_tcp_listener", - "create_unix_listener", - "create_udp_socket", - "create_connected_udp_socket", - "getaddrinfo", - "getnameinfo", - "wait_socket_readable", - "wait_socket_writable", - "create_memory_object_stream", - "run_process", - "open_process", - "create_lock", - "CapacityLimiter", - "CapacityLimiterStatistics", - "Condition", - "ConditionStatistics", - "Event", - "EventStatistics", - "Lock", - "LockStatistics", - "Semaphore", - "SemaphoreStatistics", - "create_condition", - "create_event", - "create_semaphore", - "create_capacity_limiter", - "open_cancel_scope", - "fail_after", - "move_on_after", - "current_effective_deadline", - "TASK_STATUS_IGNORED", - "CancelScope", - "create_task_group", - "TaskInfo", - "get_current_task", - "get_running_tasks", - "wait_all_tasks_blocked", - "run_sync_in_worker_thread", - "run_async_from_thread", - "run_sync_from_thread", - "current_default_worker_thread_limiter", - "create_blocking_portal", - "start_blocking_portal", - "typed_attribute", - "TypedAttributeSet", - "TypedAttributeProvider", -) - from typing import Any -from ._core._compat import maybe_async, maybe_async_cm -from ._core._eventloop import ( - current_time, - get_all_backends, - get_cancelled_exc_class, - run, - sleep, - sleep_forever, - sleep_until, -) -from ._core._exceptions import ( - BrokenResourceError, - BrokenWorkerProcess, - BusyResourceError, - ClosedResourceError, - DelimiterNotFound, - EndOfStream, - ExceptionGroup, - IncompleteRead, - TypedAttributeLookupError, - WouldBlock, -) -from ._core._fileio import AsyncFile, Path, open_file, wrap_file -from ._core._resources import aclose_forcefully -from ._core._signals import open_signal_receiver +from ._core._eventloop import current_time as current_time +from ._core._eventloop import get_all_backends as get_all_backends +from ._core._eventloop import get_cancelled_exc_class as get_cancelled_exc_class +from ._core._eventloop import run as run +from ._core._eventloop import sleep as sleep +from ._core._eventloop import sleep_forever as sleep_forever +from ._core._eventloop import sleep_until as sleep_until +from ._core._exceptions import BrokenResourceError as BrokenResourceError +from ._core._exceptions import BrokenWorkerProcess as BrokenWorkerProcess +from ._core._exceptions import BusyResourceError as BusyResourceError +from ._core._exceptions import ClosedResourceError as ClosedResourceError +from ._core._exceptions import DelimiterNotFound as DelimiterNotFound +from ._core._exceptions import EndOfStream as EndOfStream +from ._core._exceptions import IncompleteRead as IncompleteRead +from ._core._exceptions import TypedAttributeLookupError as TypedAttributeLookupError +from ._core._exceptions import WouldBlock as WouldBlock +from ._core._fileio import AsyncFile as AsyncFile +from ._core._fileio import Path as Path +from ._core._fileio import open_file as open_file +from ._core._fileio import wrap_file as wrap_file +from ._core._resources import aclose_forcefully as aclose_forcefully +from ._core._signals import open_signal_receiver as open_signal_receiver +from ._core._sockets import connect_tcp as connect_tcp +from ._core._sockets import connect_unix as connect_unix +from ._core._sockets import create_connected_udp_socket as create_connected_udp_socket from ._core._sockets import ( - connect_tcp, - connect_unix, - create_connected_udp_socket, - create_tcp_listener, - create_udp_socket, - create_unix_listener, - getaddrinfo, - getnameinfo, - wait_socket_readable, - wait_socket_writable, + create_connected_unix_datagram_socket as create_connected_unix_datagram_socket, ) -from ._core._streams import create_memory_object_stream -from ._core._subprocesses import open_process, run_process +from ._core._sockets import create_tcp_listener as create_tcp_listener +from ._core._sockets import create_udp_socket as create_udp_socket +from ._core._sockets import create_unix_datagram_socket as create_unix_datagram_socket +from ._core._sockets import create_unix_listener as create_unix_listener +from ._core._sockets import getaddrinfo as getaddrinfo +from ._core._sockets import getnameinfo as getnameinfo +from ._core._sockets import wait_socket_readable as wait_socket_readable +from ._core._sockets import wait_socket_writable as wait_socket_writable +from ._core._streams import create_memory_object_stream as create_memory_object_stream +from ._core._subprocesses import open_process as open_process +from ._core._subprocesses import run_process as run_process +from ._core._synchronization import CapacityLimiter as CapacityLimiter from ._core._synchronization import ( - CapacityLimiter, - CapacityLimiterStatistics, - Condition, - ConditionStatistics, - Event, - EventStatistics, - Lock, - LockStatistics, - Semaphore, - SemaphoreStatistics, - create_capacity_limiter, - create_condition, - create_event, - create_lock, - create_semaphore, -) -from ._core._tasks import ( - TASK_STATUS_IGNORED, - CancelScope, - create_task_group, - current_effective_deadline, - fail_after, - move_on_after, - open_cancel_scope, -) -from ._core._testing import ( - TaskInfo, - get_current_task, - get_running_tasks, - wait_all_tasks_blocked, -) -from ._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute - -# Re-exported here, for backwards compatibility -# isort: off -from .to_thread import current_default_worker_thread_limiter, run_sync_in_worker_thread -from .from_thread import ( - create_blocking_portal, - run_async_from_thread, - run_sync_from_thread, - start_blocking_portal, + CapacityLimiterStatistics as CapacityLimiterStatistics, ) +from ._core._synchronization import Condition as Condition +from ._core._synchronization import ConditionStatistics as ConditionStatistics +from ._core._synchronization import Event as Event +from ._core._synchronization import EventStatistics as EventStatistics +from ._core._synchronization import Lock as Lock +from ._core._synchronization import LockStatistics as LockStatistics +from ._core._synchronization import ResourceGuard as ResourceGuard +from ._core._synchronization import Semaphore as Semaphore +from ._core._synchronization import SemaphoreStatistics as SemaphoreStatistics +from ._core._tasks import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED +from ._core._tasks import CancelScope as CancelScope +from ._core._tasks import create_task_group as create_task_group +from ._core._tasks import current_effective_deadline as current_effective_deadline +from ._core._tasks import fail_after as fail_after +from ._core._tasks import move_on_after as move_on_after +from ._core._testing import TaskInfo as TaskInfo +from ._core._testing import get_current_task as get_current_task +from ._core._testing import get_running_tasks as get_running_tasks +from ._core._testing import wait_all_tasks_blocked as wait_all_tasks_blocked +from ._core._typedattr import TypedAttributeProvider as TypedAttributeProvider +from ._core._typedattr import TypedAttributeSet as TypedAttributeSet +from ._core._typedattr import typed_attribute as typed_attribute # Re-export imports so they look like they live directly in this package key: str diff --git a/contrib/python/anyio/anyio/_backends/_asyncio.py b/contrib/python/anyio/anyio/_backends/_asyncio.py index bfdb4ea7e1..2699bf8146 100644 --- a/contrib/python/anyio/anyio/_backends/_asyncio.py +++ b/contrib/python/anyio/anyio/_backends/_asyncio.py @@ -6,23 +6,34 @@ import concurrent.futures import math import socket import sys +import threading +from asyncio import ( + AbstractEventLoop, + CancelledError, + all_tasks, + create_task, + current_task, + get_running_loop, + sleep, +) from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined] from collections import OrderedDict, deque +from collections.abc import AsyncIterator, Generator, Iterable from concurrent.futures import Future +from contextlib import suppress from contextvars import Context, copy_context from dataclasses import dataclass from functools import partial, wraps from inspect import ( CORO_RUNNING, CORO_SUSPENDED, - GEN_RUNNING, - GEN_SUSPENDED, getcoroutinestate, - getgeneratorstate, + iscoroutine, ) from io import IOBase from os import PathLike from queue import Queue +from signal import Signals from socket import AddressFamily, SocketKind from threading import Thread from types import TracebackType @@ -33,15 +44,13 @@ from typing import ( Awaitable, Callable, Collection, + ContextManager, Coroutine, - Generator, - Iterable, Mapping, Optional, Sequence, Tuple, TypeVar, - Union, cast, ) from weakref import WeakKeyDictionary @@ -49,7 +58,6 @@ from weakref import WeakKeyDictionary import sniffio from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc -from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable from .._core._eventloop import claim_worker_thread, threadlocals from .._core._exceptions import ( BrokenResourceError, @@ -58,40 +66,220 @@ from .._core._exceptions import ( EndOfStream, WouldBlock, ) -from .._core._exceptions import ExceptionGroup as BaseExceptionGroup -from .._core._sockets import GetAddrInfoReturnType, convert_ipv6_sockaddr +from .._core._sockets import convert_ipv6_sockaddr +from .._core._streams import create_memory_object_stream from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter from .._core._synchronization import Event as BaseEvent from .._core._synchronization import ResourceGuard from .._core._tasks import CancelScope as BaseCancelScope -from ..abc import IPSockAddrType, UDPPacketType +from ..abc import ( + AsyncBackend, + IPSockAddrType, + SocketListener, + UDPPacketType, + UNIXDatagramPacketType, +) from ..lowlevel import RunVar +from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -if sys.version_info >= (3, 8): - - def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: - return task.get_coro() +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec +if sys.version_info >= (3, 11): + from asyncio import Runner + from typing import TypeVarTuple, Unpack else: + import contextvars + import enum + import signal + from asyncio import coroutines, events, exceptions, tasks + + from exceptiongroup import BaseExceptionGroup + from typing_extensions import TypeVarTuple, Unpack + + class _State(enum.Enum): + CREATED = "created" + INITIALIZED = "initialized" + CLOSED = "closed" + + class Runner: + # Copied from CPython 3.11 + def __init__( + self, + *, + debug: bool | None = None, + loop_factory: Callable[[], AbstractEventLoop] | None = None, + ): + self._state = _State.CREATED + self._debug = debug + self._loop_factory = loop_factory + self._loop: AbstractEventLoop | None = None + self._context = None + self._interrupt_count = 0 + self._set_event_loop = False + + def __enter__(self) -> Runner: + self._lazy_init() + return self + + def __exit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: + self.close() + + def close(self) -> None: + """Shutdown and close event loop.""" + if self._state is not _State.INITIALIZED: + return + try: + loop = self._loop + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + if hasattr(loop, "shutdown_default_executor"): + loop.run_until_complete(loop.shutdown_default_executor()) + else: + loop.run_until_complete(_shutdown_default_executor(loop)) + finally: + if self._set_event_loop: + events.set_event_loop(None) + loop.close() + self._loop = None + self._state = _State.CLOSED + + def get_loop(self) -> AbstractEventLoop: + """Return embedded event loop.""" + self._lazy_init() + return self._loop + + def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval: + """Run a coroutine inside the embedded event loop.""" + if not coroutines.iscoroutine(coro): + raise ValueError(f"a coroutine was expected, got {coro!r}") + + if events._get_running_loop() is not None: + # fail fast with short traceback + raise RuntimeError( + "Runner.run() cannot be called from a running event loop" + ) + + self._lazy_init() - def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: - return task._coro + if context is None: + context = self._context + task = context.run(self._loop.create_task, coro) + if ( + threading.current_thread() is threading.main_thread() + and signal.getsignal(signal.SIGINT) is signal.default_int_handler + ): + sigint_handler = partial(self._on_sigint, main_task=task) + try: + signal.signal(signal.SIGINT, sigint_handler) + except ValueError: + # `signal.signal` may throw if `threading.main_thread` does + # not support signals (e.g. embedded interpreter with signals + # not registered - see gh-91880) + sigint_handler = None + else: + sigint_handler = None -from asyncio import all_tasks, create_task, current_task, get_running_loop -from asyncio import run as native_run + self._interrupt_count = 0 + try: + return self._loop.run_until_complete(task) + except exceptions.CancelledError: + if self._interrupt_count > 0: + uncancel = getattr(task, "uncancel", None) + if uncancel is not None and uncancel() == 0: + raise KeyboardInterrupt() + raise # CancelledError + finally: + if ( + sigint_handler is not None + and signal.getsignal(signal.SIGINT) is sigint_handler + ): + signal.signal(signal.SIGINT, signal.default_int_handler) + def _lazy_init(self) -> None: + if self._state is _State.CLOSED: + raise RuntimeError("Runner is closed") + if self._state is _State.INITIALIZED: + return + if self._loop_factory is None: + self._loop = events.new_event_loop() + if not self._set_event_loop: + # Call set_event_loop only once to avoid calling + # attach_loop multiple times on child watchers + events.set_event_loop(self._loop) + self._set_event_loop = True + else: + self._loop = self._loop_factory() + if self._debug is not None: + self._loop.set_debug(self._debug) + self._context = contextvars.copy_context() + self._state = _State.INITIALIZED + + def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None: + self._interrupt_count += 1 + if self._interrupt_count == 1 and not main_task.done(): + main_task.cancel() + # wakeup loop if it is blocked by select() with long timeout + self._loop.call_soon_threadsafe(lambda: None) + return + raise KeyboardInterrupt() -def _get_task_callbacks(task: asyncio.Task) -> Iterable[Callable]: - return [cb for cb, context in task._callbacks] + def _cancel_all_tasks(loop: AbstractEventLoop) -> None: + to_cancel = tasks.all_tasks(loop) + if not to_cancel: + return + for task in to_cancel: + task.cancel() -T_Retval = TypeVar("T_Retval") -T_contra = TypeVar("T_contra", contravariant=True) + loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True)) -# Check whether there is native support for task names in asyncio (3.8+) -_native_task_names = hasattr(asyncio.Task, "get_name") + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + async def _shutdown_default_executor(loop: AbstractEventLoop) -> None: + """Schedule the shutdown of the default executor.""" + + def _do_shutdown(future: asyncio.futures.Future) -> None: + try: + loop._default_executor.shutdown(wait=True) # type: ignore[attr-defined] + loop.call_soon_threadsafe(future.set_result, None) + except Exception as ex: + loop.call_soon_threadsafe(future.set_exception, ex) + + loop._executor_shutdown_called = True + if loop._default_executor is None: + return + future = loop.create_future() + thread = threading.Thread(target=_do_shutdown, args=(future,)) + thread.start() + try: + await future + finally: + thread.join() + + +T_Retval = TypeVar("T_Retval") +T_contra = TypeVar("T_contra", contravariant=True) +PosArgsT = TypeVarTuple("PosArgsT") +P = ParamSpec("P") _root_task: RunVar[asyncio.Task | None] = RunVar("_root_task") @@ -104,7 +292,8 @@ def find_root_task() -> asyncio.Task: # Look for a task that has been started via run_until_complete() for task in all_tasks(): if task._callbacks and not task.done(): - for cb in _get_task_callbacks(task): + callbacks = [cb for cb, context in task._callbacks] + for cb in callbacks: if ( cb is _run_until_complete_cb or getattr(cb, "__module__", None) == "uvloop.loop" @@ -136,87 +325,22 @@ def get_callable_name(func: Callable) -> str: # Event loop # -_run_vars = ( - WeakKeyDictionary() -) # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] - -current_token = get_running_loop +_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary() def _task_started(task: asyncio.Task) -> bool: """Return ``True`` if the task has been started and has not finished.""" - coro = cast(Coroutine[Any, Any, Any], get_coro(task)) try: - return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED) + return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED) except AttributeError: - try: - return getgeneratorstate(cast(Generator, coro)) in ( - GEN_RUNNING, - GEN_SUSPENDED, - ) - except AttributeError: - # task coro is async_genenerator_asend https://bugs.python.org/issue37771 - raise Exception(f"Cannot determine if task {task} has started or not") - - -def _maybe_set_event_loop_policy( - policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool -) -> None: - # On CPython, use uvloop when possible if no other policy has been given and if not - # explicitly disabled - if policy is None and use_uvloop and sys.implementation.name == "cpython": - try: - import uvloop - except ImportError: - pass - else: - # Test for missing shutdown_default_executor() (uvloop 0.14.0 and earlier) - if not hasattr( - asyncio.AbstractEventLoop, "shutdown_default_executor" - ) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"): - policy = uvloop.EventLoopPolicy() - - if policy is not None: - asyncio.set_event_loop_policy(policy) - - -def run( - func: Callable[..., Awaitable[T_Retval]], - *args: object, - debug: bool = False, - use_uvloop: bool = False, - policy: asyncio.AbstractEventLoopPolicy | None = None, -) -> T_Retval: - @wraps(func) - async def wrapper() -> T_Retval: - task = cast(asyncio.Task, current_task()) - task_state = TaskState(None, get_callable_name(func), None) - _task_states[task] = task_state - if _native_task_names: - task.set_name(task_state.name) - - try: - return await func(*args) - finally: - del _task_states[task] - - _maybe_set_event_loop_policy(policy, use_uvloop) - return native_run(wrapper(), debug=debug) - - -# -# Miscellaneous -# - -sleep = asyncio.sleep + # task coro is async_genenerator_asend https://bugs.python.org/issue37771 + raise Exception(f"Cannot determine if task {task} has started or not") from None # # Timeouts and cancellation # -CancelledError = asyncio.CancelledError - class CancelScope(BaseCancelScope): def __new__( @@ -228,14 +352,16 @@ class CancelScope(BaseCancelScope): self._deadline = deadline self._shield = shield self._parent_scope: CancelScope | None = None + self._child_scopes: set[CancelScope] = set() self._cancel_called = False + self._cancelled_caught = False self._active = False self._timeout_handle: asyncio.TimerHandle | None = None self._cancel_handle: asyncio.Handle | None = None self._tasks: set[asyncio.Task] = set() self._host_task: asyncio.Task | None = None - self._timeout_expired = False self._cancel_calls: int = 0 + self._cancelling: int | None = None def __enter__(self) -> CancelScope: if self._active: @@ -248,19 +374,23 @@ class CancelScope(BaseCancelScope): try: task_state = _task_states[host_task] except KeyError: - task_name = host_task.get_name() if _native_task_names else None - task_state = TaskState(None, task_name, self) + task_state = TaskState(None, self) _task_states[host_task] = task_state else: self._parent_scope = task_state.cancel_scope task_state.cancel_scope = self + if self._parent_scope is not None: + self._parent_scope._child_scopes.add(self) + self._parent_scope._tasks.remove(host_task) self._timeout() self._active = True + if sys.version_info >= (3, 11): + self._cancelling = self._host_task.cancelling() # Start cancelling the host task if the scope was cancelled before entering if self._cancel_called: - self._deliver_cancellation() + self._deliver_cancellation(self) return self @@ -292,56 +422,60 @@ class CancelScope(BaseCancelScope): self._timeout_handle = None self._tasks.remove(self._host_task) + if self._parent_scope is not None: + self._parent_scope._child_scopes.remove(self) + self._parent_scope._tasks.add(self._host_task) host_task_state.cancel_scope = self._parent_scope - # Restart the cancellation effort in the farthest directly cancelled parent scope if this - # one was shielded - if self._shield: - self._deliver_cancellation_to_parent() + # Restart the cancellation effort in the closest directly cancelled parent + # scope if this one was shielded + self._restart_cancellation_in_parent() - if exc_val is not None: - exceptions = ( - exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val] - ) - if all(isinstance(exc, CancelledError) for exc in exceptions): - if self._timeout_expired: - return self._uncancel() - elif not self._cancel_called: - # Task was cancelled natively - return None - elif not self._parent_cancelled(): - # This scope was directly cancelled - return self._uncancel() + if self._cancel_called and exc_val is not None: + for exc in iterate_exceptions(exc_val): + if isinstance(exc, CancelledError): + self._cancelled_caught = self._uncancel(exc) + if self._cancelled_caught: + break + + return self._cancelled_caught return None - def _uncancel(self) -> bool: - if sys.version_info < (3, 11) or self._host_task is None: + def _uncancel(self, cancelled_exc: CancelledError) -> bool: + if sys.version_info < (3, 9) or self._host_task is None: self._cancel_calls = 0 return True - # Uncancel all AnyIO cancellations - for i in range(self._cancel_calls): - self._host_task.uncancel() + # Undo all cancellations done by this scope + if self._cancelling is not None: + while self._cancel_calls: + self._cancel_calls -= 1 + if self._host_task.uncancel() <= self._cancelling: + return True self._cancel_calls = 0 - return not self._host_task.cancelling() + return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args def _timeout(self) -> None: if self._deadline != math.inf: loop = get_running_loop() if loop.time() >= self._deadline: - self._timeout_expired = True self.cancel() else: self._timeout_handle = loop.call_at(self._deadline, self._timeout) - def _deliver_cancellation(self) -> None: + def _deliver_cancellation(self, origin: CancelScope) -> bool: """ Deliver cancellation to directly contained tasks and nested cancel scopes. - Schedule another run at the end if we still have tasks eligible for cancellation. + Schedule another run at the end if we still have tasks eligible for + cancellation. + + :param origin: the cancel scope that originated the cancellation + :return: ``True`` if the delivery needs to be retried on the next cycle + """ should_retry = False current = current_task() @@ -349,37 +483,46 @@ class CancelScope(BaseCancelScope): if task._must_cancel: # type: ignore[attr-defined] continue - # The task is eligible for cancellation if it has started and is not in a cancel - # scope shielded from this one - cancel_scope = _task_states[task].cancel_scope - while cancel_scope is not self: - if cancel_scope is None or cancel_scope._shield: - break - else: - cancel_scope = cancel_scope._parent_scope - else: - should_retry = True - if task is not current and ( - task is self._host_task or _task_started(task) - ): + # The task is eligible for cancellation if it has started + should_retry = True + if task is not current and (task is self._host_task or _task_started(task)): + waiter = task._fut_waiter # type: ignore[attr-defined] + if not isinstance(waiter, asyncio.Future) or not waiter.done(): self._cancel_calls += 1 - task.cancel() + if sys.version_info >= (3, 9): + task.cancel(f"Cancelled by cancel scope {id(origin):x}") + else: + task.cancel() + + # Deliver cancellation to child scopes that aren't shielded or running their own + # cancellation callbacks + for scope in self._child_scopes: + if not scope._shield and not scope.cancel_called: + should_retry = scope._deliver_cancellation(origin) or should_retry # Schedule another callback if there are still tasks left - if should_retry: - self._cancel_handle = get_running_loop().call_soon( - self._deliver_cancellation - ) - else: - self._cancel_handle = None + if origin is self: + if should_retry: + self._cancel_handle = get_running_loop().call_soon( + self._deliver_cancellation, origin + ) + else: + self._cancel_handle = None - def _deliver_cancellation_to_parent(self) -> None: - """Start cancellation effort in the farthest directly cancelled parent scope""" + return should_retry + + def _restart_cancellation_in_parent(self) -> None: + """ + Restart the cancellation effort in the closest directly cancelled parent scope. + + """ scope = self._parent_scope - scope_to_cancel: CancelScope | None = None while scope is not None: - if scope._cancel_called and scope._cancel_handle is None: - scope_to_cancel = scope + if scope._cancel_called: + if scope._cancel_handle is None: + scope._deliver_cancellation(scope) + + break # No point in looking beyond any shielded scope if scope._shield: @@ -387,9 +530,6 @@ class CancelScope(BaseCancelScope): scope = scope._parent_scope - if scope_to_cancel is not None: - scope_to_cancel._deliver_cancellation() - def _parent_cancelled(self) -> bool: # Check whether any parent has been cancelled cancel_scope = self._parent_scope @@ -401,7 +541,7 @@ class CancelScope(BaseCancelScope): return False - def cancel(self) -> DeprecatedAwaitable: + def cancel(self) -> None: if not self._cancel_called: if self._timeout_handle: self._timeout_handle.cancel() @@ -409,9 +549,7 @@ class CancelScope(BaseCancelScope): self._cancel_called = True if self._host_task is not None: - self._deliver_cancellation() - - return DeprecatedAwaitable(self.cancel) + self._deliver_cancellation(self) @property def deadline(self) -> float: @@ -432,6 +570,10 @@ class CancelScope(BaseCancelScope): return self._cancel_called @property + def cancelled_caught(self) -> bool: + return self._cancelled_caught + + @property def shield(self) -> bool: return self._shield @@ -440,59 +582,7 @@ class CancelScope(BaseCancelScope): if self._shield != value: self._shield = value if not value: - self._deliver_cancellation_to_parent() - - -async def checkpoint() -> None: - await sleep(0) - - -async def checkpoint_if_cancelled() -> None: - task = current_task() - if task is None: - return - - try: - cancel_scope = _task_states[task].cancel_scope - except KeyError: - return - - while cancel_scope: - if cancel_scope.cancel_called: - await sleep(0) - elif cancel_scope.shield: - break - else: - cancel_scope = cancel_scope._parent_scope - - -async def cancel_shielded_checkpoint() -> None: - with CancelScope(shield=True): - await sleep(0) - - -def current_effective_deadline() -> float: - try: - cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index] - except KeyError: - return math.inf - - deadline = math.inf - while cancel_scope: - deadline = min(deadline, cancel_scope.deadline) - if cancel_scope._cancel_called: - deadline = -math.inf - break - elif cancel_scope.shield: - break - else: - cancel_scope = cancel_scope._parent_scope - - return deadline - - -def current_time() -> float: - return get_running_loop().time() + self._restart_cancellation_in_parent() # @@ -502,20 +592,14 @@ def current_time() -> float: class TaskState: """ - Encapsulates auxiliary task information that cannot be added to the Task instance itself - because there are no guarantees about its implementation. + Encapsulates auxiliary task information that cannot be added to the Task instance + itself because there are no guarantees about its implementation. """ - __slots__ = "parent_id", "name", "cancel_scope" + __slots__ = "parent_id", "cancel_scope" - def __init__( - self, - parent_id: int | None, - name: str | None, - cancel_scope: CancelScope | None, - ): + def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None): self.parent_id = parent_id - self.name = name self.cancel_scope = cancel_scope @@ -527,12 +611,6 @@ _task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, Task # -class ExceptionGroup(BaseExceptionGroup): - def __init__(self, exceptions: list[BaseException]): - super().__init__() - self.exceptions = exceptions - - class _AsyncioTaskStatus(abc.TaskStatus): def __init__(self, future: asyncio.Future, parent_id: int): self._future = future @@ -550,11 +628,22 @@ class _AsyncioTaskStatus(abc.TaskStatus): _task_states[task].parent_id = self._parent_id +def iterate_exceptions( + exception: BaseException, +) -> Generator[BaseException, None, None]: + if isinstance(exception, BaseExceptionGroup): + for exc in exception.exceptions: + yield from iterate_exceptions(exc) + else: + yield exception + + class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self.cancel_scope: CancelScope = CancelScope() self._active = False self._exceptions: list[BaseException] = [] + self._tasks: set[asyncio.Task] = set() async def __aenter__(self) -> TaskGroup: self.cancel_scope.__enter__() @@ -570,98 +659,49 @@ class TaskGroup(abc.TaskGroup): ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb) if exc_val is not None: self.cancel_scope.cancel() - self._exceptions.append(exc_val) + if not isinstance(exc_val, CancelledError): + self._exceptions.append(exc_val) - while self.cancel_scope._tasks: + cancelled_exc_while_waiting_tasks: CancelledError | None = None + while self._tasks: try: - await asyncio.wait(self.cancel_scope._tasks) - except asyncio.CancelledError: + await asyncio.wait(self._tasks) + except CancelledError as exc: + # This task was cancelled natively; reraise the CancelledError later + # unless this task was already interrupted by another exception self.cancel_scope.cancel() + if cancelled_exc_while_waiting_tasks is None: + cancelled_exc_while_waiting_tasks = exc self._active = False - if not self.cancel_scope._parent_cancelled(): - exceptions = self._filter_cancellation_errors(self._exceptions) - else: - exceptions = self._exceptions + if self._exceptions: + raise BaseExceptionGroup( + "unhandled errors in a TaskGroup", self._exceptions + ) - try: - if len(exceptions) > 1: - if all( - isinstance(e, CancelledError) and not e.args for e in exceptions - ): - # Tasks were cancelled natively, without a cancellation message - raise CancelledError - else: - raise ExceptionGroup(exceptions) - elif exceptions and exceptions[0] is not exc_val: - raise exceptions[0] - except BaseException as exc: - # Clear the context here, as it can only be done in-flight. - # If the context is not cleared, it can result in recursive tracebacks (see #145). - exc.__context__ = None - raise + # Raise the CancelledError received while waiting for child tasks to exit, + # unless the context manager itself was previously exited with another + # exception, or if any of the child tasks raised an exception other than + # CancelledError + if cancelled_exc_while_waiting_tasks: + if exc_val is None or ignore_exception: + raise cancelled_exc_while_waiting_tasks return ignore_exception - @staticmethod - def _filter_cancellation_errors( - exceptions: Sequence[BaseException], - ) -> list[BaseException]: - filtered_exceptions: list[BaseException] = [] - for exc in exceptions: - if isinstance(exc, ExceptionGroup): - new_exceptions = TaskGroup._filter_cancellation_errors(exc.exceptions) - if len(new_exceptions) > 1: - filtered_exceptions.append(exc) - elif len(new_exceptions) == 1: - filtered_exceptions.append(new_exceptions[0]) - elif new_exceptions: - new_exc = ExceptionGroup(new_exceptions) - new_exc.__cause__ = exc.__cause__ - new_exc.__context__ = exc.__context__ - new_exc.__traceback__ = exc.__traceback__ - filtered_exceptions.append(new_exc) - elif not isinstance(exc, CancelledError) or exc.args: - filtered_exceptions.append(exc) - - return filtered_exceptions - - async def _run_wrapped_task( - self, coro: Coroutine, task_status_future: asyncio.Future | None - ) -> None: - # This is the code path for Python 3.7 on which asyncio freaks out if a task - # raises a BaseException. - __traceback_hide__ = __tracebackhide__ = True # noqa: F841 - task = cast(asyncio.Task, current_task()) - try: - await coro - except BaseException as exc: - if task_status_future is None or task_status_future.done(): - self._exceptions.append(exc) - self.cancel_scope.cancel() - else: - task_status_future.set_exception(exc) - else: - if task_status_future is not None and not task_status_future.done(): - task_status_future.set_exception( - RuntimeError("Child exited without calling task_status.started()") - ) - finally: - if task in self.cancel_scope._tasks: - self.cancel_scope._tasks.remove(task) - del _task_states[task] - def _spawn( self, - func: Callable[..., Awaitable[Any]], - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + args: tuple[Unpack[PosArgsT]], name: object, task_status_future: asyncio.Future | None = None, ) -> asyncio.Task: def task_done(_task: asyncio.Task) -> None: - # This is the code path for Python 3.8+ - assert _task in self.cancel_scope._tasks - self.cancel_scope._tasks.remove(_task) + task_state = _task_states[_task] + assert task_state.cancel_scope is not None + assert _task in task_state.cancel_scope._tasks + task_state.cancel_scope._tasks.remove(_task) + self._tasks.remove(task) del _task_states[_task] try: @@ -674,8 +714,11 @@ class TaskGroup(abc.TaskGroup): if exc is not None: if task_status_future is None or task_status_future.done(): - self._exceptions.append(exc) - self.cancel_scope.cancel() + if not isinstance(exc, CancelledError): + self._exceptions.append(exc) + + if not self.cancel_scope._parent_cancelled(): + self.cancel_scope.cancel() else: task_status_future.set_exception(exc) elif task_status_future is not None and not task_status_future.done(): @@ -688,11 +731,6 @@ class TaskGroup(abc.TaskGroup): "This task group is not active; no new tasks can be started." ) - options: dict[str, Any] = {} - name = get_callable_name(func) if name is None else str(name) - if _native_task_names: - options["name"] = name - kwargs = {} if task_status_future: parent_id = id(current_task()) @@ -703,46 +741,52 @@ class TaskGroup(abc.TaskGroup): parent_id = id(self.cancel_scope._host_task) coro = func(*args, **kwargs) - if not asyncio.iscoroutine(coro): + if not iscoroutine(coro): + prefix = f"{func.__module__}." if hasattr(func, "__module__") else "" raise TypeError( - f"Expected an async function, but {func} appears to be synchronous" + f"Expected {prefix}{func.__qualname__}() to return a coroutine, but " + f"the return value ({coro!r}) is not a coroutine object" ) - foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame") - if foreign_coro or sys.version_info < (3, 8): - coro = self._run_wrapped_task(coro, task_status_future) - - task = create_task(coro, **options) - if not foreign_coro and sys.version_info >= (3, 8): - task.add_done_callback(task_done) + name = get_callable_name(func) if name is None else str(name) + task = create_task(coro, name=name) + task.add_done_callback(task_done) # Make the spawned task inherit the task group's cancel scope _task_states[task] = TaskState( - parent_id=parent_id, name=name, cancel_scope=self.cancel_scope + parent_id=parent_id, cancel_scope=self.cancel_scope ) self.cancel_scope._tasks.add(task) + self._tasks.add(task) return task def start_soon( - self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + name: object = None, ) -> None: self._spawn(func, args, name) async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None - ) -> None: + ) -> Any: future: asyncio.Future = asyncio.Future() task = self._spawn(func, args, name, future) - # If the task raises an exception after sending a start value without a switch point - # between, the task group is cancelled and this method never proceeds to process the - # completed future. That's why we have to have a shielded cancel scope here. - with CancelScope(shield=True): - try: - return await future - except CancelledError: - task.cancel() - raise + # If the task raises an exception after sending a start value without a switch + # point between, the task group is cancelled and this method never proceeds to + # process the completed future. That's why we have to have a shielded cancel + # scope here. + try: + return await future + except CancelledError: + # Cancel the task and wait for it to exit before returning + task.cancel() + with CancelScope(shield=True), suppress(CancelledError): + await task + + raise # @@ -767,15 +811,15 @@ class WorkerThread(Thread): self.idle_workers = idle_workers self.loop = root_task._loop self.queue: Queue[ - tuple[Context, Callable, tuple, asyncio.Future] | None + tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None ] = Queue(2) - self.idle_since = current_time() + self.idle_since = AsyncIOBackend.current_time() self.stopping = False def _report_result( self, future: asyncio.Future, result: Any, exc: BaseException | None ) -> None: - self.idle_since = current_time() + self.idle_since = AsyncIOBackend.current_time() if not self.stopping: self.idle_workers.append(self) @@ -791,22 +835,24 @@ class WorkerThread(Thread): future.set_result(result) def run(self) -> None: - with claim_worker_thread("asyncio"): - threadlocals.loop = self.loop + with claim_worker_thread(AsyncIOBackend, self.loop): while True: item = self.queue.get() if item is None: # Shutdown command received return - context, func, args, future = item + context, func, args, future, cancel_scope = item if not future.cancelled(): result = None exception: BaseException | None = None + threadlocals.current_cancel_scope = cancel_scope try: result = context.run(func, *args) except BaseException as exc: exception = exc + finally: + del threadlocals.current_cancel_scope if not self.loop.is_closed(): self.loop.call_soon_threadsafe( @@ -831,81 +877,6 @@ _threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar( _threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers") -async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], - *args: object, - cancellable: bool = False, - limiter: CapacityLimiter | None = None, -) -> T_Retval: - await checkpoint() - - # If this is the first run in this event loop thread, set up the necessary variables - try: - idle_workers = _threadpool_idle_workers.get() - workers = _threadpool_workers.get() - except LookupError: - idle_workers = deque() - workers = set() - _threadpool_idle_workers.set(idle_workers) - _threadpool_workers.set(workers) - - async with (limiter or current_default_thread_limiter()): - with CancelScope(shield=not cancellable): - future: asyncio.Future = asyncio.Future() - root_task = find_root_task() - if not idle_workers: - worker = WorkerThread(root_task, workers, idle_workers) - worker.start() - workers.add(worker) - root_task.add_done_callback(worker.stop) - else: - worker = idle_workers.pop() - - # Prune any other workers that have been idle for MAX_IDLE_TIME seconds or longer - now = current_time() - while idle_workers: - if now - idle_workers[0].idle_since < WorkerThread.MAX_IDLE_TIME: - break - - expired_worker = idle_workers.popleft() - expired_worker.root_task.remove_done_callback(expired_worker.stop) - expired_worker.stop() - - context = copy_context() - context.run(sniffio.current_async_library_cvar.set, None) - worker.queue.put_nowait((context, func, args, future)) - return await future - - -def run_sync_from_thread( - func: Callable[..., T_Retval], - *args: object, - loop: asyncio.AbstractEventLoop | None = None, -) -> T_Retval: - @wraps(func) - def wrapper() -> None: - try: - f.set_result(func(*args)) - except BaseException as exc: - f.set_exception(exc) - if not isinstance(exc, Exception): - raise - - f: concurrent.futures.Future[T_Retval] = Future() - loop = loop or threadlocals.loop - loop.call_soon_threadsafe(wrapper) - return f.result() - - -def run_async_from_thread( - func: Callable[..., Awaitable[T_Retval]], *args: object -) -> T_Retval: - f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( - func(*args), threadlocals.loop - ) - return f.result() - - class BlockingPortal(abc.BlockingPortal): def __new__(cls) -> BlockingPortal: return object.__new__(cls) @@ -916,20 +887,16 @@ class BlockingPortal(abc.BlockingPortal): def _spawn_task_from_thread( self, - func: Callable, - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, - future: Future, + future: Future[T_Retval], ) -> None: - run_sync_from_thread( + AsyncIOBackend.run_sync_from_thread( partial(self._task_group.start_soon, name=name), - self._call_func, - func, - args, - kwargs, - future, - loop=self._loop, + (self._call_func, func, args, kwargs, future), + self._loop, ) @@ -951,6 +918,7 @@ class StreamReaderWrapper(abc.ByteReceiveStream): async def aclose(self) -> None: self._stream.feed_eof() + await AsyncIOBackend.checkpoint() @dataclass(eq=False) @@ -963,6 +931,7 @@ class StreamWriterWrapper(abc.ByteSendStream): async def aclose(self) -> None: self._stream.close() + await AsyncIOBackend.checkpoint() @dataclass(eq=False) @@ -973,14 +942,22 @@ class Process(abc.Process): _stderr: StreamReaderWrapper | None async def aclose(self) -> None: - if self._stdin: - await self._stdin.aclose() - if self._stdout: - await self._stdout.aclose() - if self._stderr: - await self._stderr.aclose() + with CancelScope(shield=True): + if self._stdin: + await self._stdin.aclose() + if self._stdout: + await self._stdout.aclose() + if self._stderr: + await self._stderr.aclose() - await self.wait() + try: + await self.wait() + except BaseException: + self.kill() + with CancelScope(shield=True): + await self.wait() + + raise async def wait(self) -> int: return await self._process.wait() @@ -1015,55 +992,17 @@ class Process(abc.Process): return self._stderr -async def open_process( - command: str | bytes | Sequence[str | bytes], - *, - shell: bool, - stdin: int | IO[Any] | None, - stdout: int | IO[Any] | None, - stderr: int | IO[Any] | None, - cwd: str | bytes | PathLike | None = None, - env: Mapping[str, str] | None = None, - start_new_session: bool = False, -) -> Process: - await checkpoint() - if shell: - process = await asyncio.create_subprocess_shell( - cast(Union[str, bytes], command), - stdin=stdin, - stdout=stdout, - stderr=stderr, - cwd=cwd, - env=env, - start_new_session=start_new_session, - ) - else: - process = await asyncio.create_subprocess_exec( - *command, - stdin=stdin, - stdout=stdout, - stderr=stderr, - cwd=cwd, - env=env, - start_new_session=start_new_session, - ) - - stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None - stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None - stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None - return Process(process, stdin_stream, stdout_stream, stderr_stream) - - def _forcibly_shutdown_process_pool_on_exit( workers: set[Process], _task: object ) -> None: """ Forcibly shuts down worker processes belonging to this event loop.""" - child_watcher: asyncio.AbstractChildWatcher | None - try: - child_watcher = asyncio.get_event_loop_policy().get_child_watcher() - except NotImplementedError: - child_watcher = None + child_watcher: asyncio.AbstractChildWatcher | None = None + if sys.version_info < (3, 12): + try: + child_watcher = asyncio.get_event_loop_policy().get_child_watcher() + except NotImplementedError: + pass # Close as much as possible (w/o async/await) to avoid warnings for process in workers: @@ -1078,14 +1017,15 @@ def _forcibly_shutdown_process_pool_on_exit( child_watcher.remove_child_handler(process.pid) -async def _shutdown_process_pool_on_exit(workers: set[Process]) -> None: +async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None: """ Shuts down worker processes belonging to this event loop. - NOTE: this only works when the event loop was started using asyncio.run() or anyio.run(). + NOTE: this only works when the event loop was started using asyncio.run() or + anyio.run(). """ - process: Process + process: abc.Process try: await sleep(math.inf) except asyncio.CancelledError: @@ -1097,16 +1037,6 @@ async def _shutdown_process_pool_on_exit(workers: set[Process]) -> None: await process.aclose() -def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None: - kwargs: dict[str, Any] = ( - {"name": "AnyIO process pool shutdown task"} if _native_task_names else {} - ) - create_task(_shutdown_process_pool_on_exit(workers), **kwargs) - find_root_task().add_done_callback( - partial(_forcibly_shutdown_process_pool_on_exit, workers) - ) - - # # Sockets and networking # @@ -1193,7 +1123,7 @@ class SocketStream(abc.SocketStream): async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: - await checkpoint() + await AsyncIOBackend.checkpoint() if ( not self._protocol.read_event.is_set() @@ -1209,7 +1139,7 @@ class SocketStream(abc.SocketStream): if self._closed: raise ClosedResourceError from None elif self._protocol.exception: - raise self._protocol.exception + raise self._protocol.exception from None else: raise EndOfStream from None @@ -1218,8 +1148,8 @@ class SocketStream(abc.SocketStream): chunk, leftover = chunk[:max_bytes], chunk[max_bytes:] self._protocol.read_queue.appendleft(leftover) - # If the read queue is empty, clear the flag so that the next call will block until - # data is available + # If the read queue is empty, clear the flag so that the next call will + # block until data is available if not self._protocol.read_queue: self._protocol.read_event.clear() @@ -1227,7 +1157,7 @@ class SocketStream(abc.SocketStream): async def send(self, item: bytes) -> None: with self._send_guard: - await checkpoint() + await AsyncIOBackend.checkpoint() if self._closed: raise ClosedResourceError @@ -1263,14 +1193,13 @@ class SocketStream(abc.SocketStream): self._transport.abort() -class UNIXSocketStream(abc.SocketStream): +class _RawSocketMixin: _receive_future: asyncio.Future | None = None _send_future: asyncio.Future | None = None _closing = False def __init__(self, raw_socket: socket.socket): self.__raw_socket = raw_socket - self._loop = get_running_loop() self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") @@ -1284,7 +1213,7 @@ class UNIXSocketStream(abc.SocketStream): loop.remove_reader(self.__raw_socket) f = self._receive_future = asyncio.Future() - self._loop.add_reader(self.__raw_socket, f.set_result, None) + loop.add_reader(self.__raw_socket, f.set_result, None) f.add_done_callback(callback) return f @@ -1294,21 +1223,34 @@ class UNIXSocketStream(abc.SocketStream): loop.remove_writer(self.__raw_socket) f = self._send_future = asyncio.Future() - self._loop.add_writer(self.__raw_socket, f.set_result, None) + loop.add_writer(self.__raw_socket, f.set_result, None) f.add_done_callback(callback) return f + async def aclose(self) -> None: + if not self._closing: + self._closing = True + if self.__raw_socket.fileno() != -1: + self.__raw_socket.close() + + if self._receive_future: + self._receive_future.set_result(None) + if self._send_future: + self._send_future.set_result(None) + + +class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream): async def send_eof(self) -> None: with self._send_guard: self._raw_socket.shutdown(socket.SHUT_WR) async def receive(self, max_bytes: int = 65536) -> bytes: loop = get_running_loop() - await checkpoint() + await AsyncIOBackend.checkpoint() with self._receive_guard: while True: try: - data = self.__raw_socket.recv(max_bytes) + data = self._raw_socket.recv(max_bytes) except BlockingIOError: await self._wait_until_readable(loop) except OSError as exc: @@ -1324,12 +1266,12 @@ class UNIXSocketStream(abc.SocketStream): async def send(self, item: bytes) -> None: loop = get_running_loop() - await checkpoint() + await AsyncIOBackend.checkpoint() with self._send_guard: view = memoryview(item) while view: try: - bytes_sent = self.__raw_socket.send(view) + bytes_sent = self._raw_socket.send(view) except BlockingIOError: await self._wait_until_writable(loop) except OSError as exc: @@ -1348,11 +1290,11 @@ class UNIXSocketStream(abc.SocketStream): loop = get_running_loop() fds = array.array("i") - await checkpoint() + await AsyncIOBackend.checkpoint() with self._receive_guard: while True: try: - message, ancdata, flags, addr = self.__raw_socket.recvmsg( + message, ancdata, flags, addr = self._raw_socket.recvmsg( msglen, socket.CMSG_LEN(maxfds * fds.itemsize) ) except BlockingIOError: @@ -1394,13 +1336,13 @@ class UNIXSocketStream(abc.SocketStream): filenos.append(fd.fileno()) fdarray = array.array("i", filenos) - await checkpoint() + await AsyncIOBackend.checkpoint() with self._send_guard: while True: try: # The ignore can be removed after mypy picks up # https://github.com/python/typeshed/pull/5545 - self.__raw_socket.sendmsg( + self._raw_socket.sendmsg( [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)] ) break @@ -1412,17 +1354,6 @@ class UNIXSocketStream(abc.SocketStream): else: raise BrokenResourceError from exc - async def aclose(self) -> None: - if not self._closing: - self._closing = True - if self.__raw_socket.fileno() != -1: - self.__raw_socket.close() - - if self._receive_future: - self._receive_future.set_result(None) - if self._send_future: - self._send_future.set_result(None) - class TCPSocketListener(abc.SocketListener): _accept_scope: CancelScope | None = None @@ -1442,7 +1373,7 @@ class TCPSocketListener(abc.SocketListener): raise ClosedResourceError with self._accept_guard: - await checkpoint() + await AsyncIOBackend.checkpoint() with CancelScope() as self._accept_scope: try: client_sock, _addr = await self._loop.sock_accept(self._raw_socket) @@ -1492,7 +1423,7 @@ class UNIXSocketListener(abc.SocketListener): self._closed = False async def accept(self) -> abc.SocketStream: - await checkpoint() + await AsyncIOBackend.checkpoint() with self._accept_guard: while True: try: @@ -1542,7 +1473,7 @@ class UDPSocket(abc.UDPSocket): async def receive(self) -> tuple[bytes, IPSockAddrType]: with self._receive_guard: - await checkpoint() + await AsyncIOBackend.checkpoint() # If the buffer is empty, ask for more data if not self._protocol.read_queue and not self._transport.is_closing(): @@ -1559,7 +1490,7 @@ class UDPSocket(abc.UDPSocket): async def send(self, item: UDPPacketType) -> None: with self._send_guard: - await checkpoint() + await AsyncIOBackend.checkpoint() await self._protocol.write_event.wait() if self._closed: raise ClosedResourceError @@ -1590,7 +1521,7 @@ class ConnectedUDPSocket(abc.ConnectedUDPSocket): async def receive(self) -> bytes: with self._receive_guard: - await checkpoint() + await AsyncIOBackend.checkpoint() # If the buffer is empty, ask for more data if not self._protocol.read_queue and not self._transport.is_closing(): @@ -1609,7 +1540,7 @@ class ConnectedUDPSocket(abc.ConnectedUDPSocket): async def send(self, item: bytes) -> None: with self._send_guard: - await checkpoint() + await AsyncIOBackend.checkpoint() await self._protocol.write_event.wait() if self._closed: raise ClosedResourceError @@ -1619,142 +1550,82 @@ class ConnectedUDPSocket(abc.ConnectedUDPSocket): self._transport.sendto(item) -async def connect_tcp( - host: str, port: int, local_addr: tuple[str, int] | None = None -) -> SocketStream: - transport, protocol = cast( - Tuple[asyncio.Transport, StreamProtocol], - await get_running_loop().create_connection( - StreamProtocol, host, port, local_addr=local_addr - ), - ) - transport.pause_reading() - return SocketStream(transport, protocol) - - -async def connect_unix(path: str) -> UNIXSocketStream: - await checkpoint() - loop = get_running_loop() - raw_socket = socket.socket(socket.AF_UNIX) - raw_socket.setblocking(False) - while True: - try: - raw_socket.connect(path) - except BlockingIOError: - f: asyncio.Future = asyncio.Future() - loop.add_writer(raw_socket, f.set_result, None) - f.add_done_callback(lambda _: loop.remove_writer(raw_socket)) - await f - except BaseException: - raw_socket.close() - raise - else: - return UNIXSocketStream(raw_socket) - - -async def create_udp_socket( - family: socket.AddressFamily, - local_address: IPSockAddrType | None, - remote_address: IPSockAddrType | None, - reuse_port: bool, -) -> UDPSocket | ConnectedUDPSocket: - result = await get_running_loop().create_datagram_endpoint( - DatagramProtocol, - local_addr=local_address, - remote_addr=remote_address, - family=family, - reuse_port=reuse_port, - ) - transport = result[0] - protocol = result[1] - if protocol.exception: - transport.close() - raise protocol.exception - - if not remote_address: - return UDPSocket(transport, protocol) - else: - return ConnectedUDPSocket(transport, protocol) +class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket): + async def receive(self) -> UNIXDatagramPacketType: + loop = get_running_loop() + await AsyncIOBackend.checkpoint() + with self._receive_guard: + while True: + try: + data = self._raw_socket.recvfrom(65536) + except BlockingIOError: + await self._wait_until_readable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + return data + async def send(self, item: UNIXDatagramPacketType) -> None: + loop = get_running_loop() + await AsyncIOBackend.checkpoint() + with self._send_guard: + while True: + try: + self._raw_socket.sendto(*item) + except BlockingIOError: + await self._wait_until_writable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + return -async def getaddrinfo( - host: bytes | str, - port: str | int | None, - *, - family: int | AddressFamily = 0, - type: int | SocketKind = 0, - proto: int = 0, - flags: int = 0, -) -> GetAddrInfoReturnType: - # https://github.com/python/typeshed/pull/4304 - result = await get_running_loop().getaddrinfo( - host, port, family=family, type=type, proto=proto, flags=flags - ) - return cast(GetAddrInfoReturnType, result) +class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket): + async def receive(self) -> bytes: + loop = get_running_loop() + await AsyncIOBackend.checkpoint() + with self._receive_guard: + while True: + try: + data = self._raw_socket.recv(65536) + except BlockingIOError: + await self._wait_until_readable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + return data -async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> tuple[str, str]: - return await get_running_loop().getnameinfo(sockaddr, flags) + async def send(self, item: bytes) -> None: + loop = get_running_loop() + await AsyncIOBackend.checkpoint() + with self._send_guard: + while True: + try: + self._raw_socket.send(item) + except BlockingIOError: + await self._wait_until_writable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + return _read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events") _write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events") -async def wait_socket_readable(sock: socket.socket) -> None: - await checkpoint() - try: - read_events = _read_events.get() - except LookupError: - read_events = {} - _read_events.set(read_events) - - if read_events.get(sock): - raise BusyResourceError("reading from") from None - - loop = get_running_loop() - event = read_events[sock] = asyncio.Event() - loop.add_reader(sock, event.set) - try: - await event.wait() - finally: - if read_events.pop(sock, None) is not None: - loop.remove_reader(sock) - readable = True - else: - readable = False - - if not readable: - raise ClosedResourceError - - -async def wait_socket_writable(sock: socket.socket) -> None: - await checkpoint() - try: - write_events = _write_events.get() - except LookupError: - write_events = {} - _write_events.set(write_events) - - if write_events.get(sock): - raise BusyResourceError("writing to") from None - - loop = get_running_loop() - event = write_events[sock] = asyncio.Event() - loop.add_writer(sock.fileno(), event.set) - try: - await event.wait() - finally: - if write_events.pop(sock, None) is not None: - loop.remove_writer(sock) - writable = True - else: - writable = False - - if not writable: - raise ClosedResourceError - - # # Synchronization # @@ -1767,16 +1638,17 @@ class Event(BaseEvent): def __init__(self) -> None: self._event = asyncio.Event() - def set(self) -> DeprecatedAwaitable: + def set(self) -> None: self._event.set() - return DeprecatedAwaitable(self.set) def is_set(self) -> bool: return self._event.is_set() async def wait(self) -> None: - if await self._event.wait(): - await checkpoint() + if self.is_set(): + await AsyncIOBackend.checkpoint() + else: + await self._event.wait() def statistics(self) -> EventStatistics: return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined] @@ -1815,19 +1687,14 @@ class CapacityLimiter(BaseCapacityLimiter): if value < 1: raise ValueError("total_tokens must be >= 1") - old_value = self._total_tokens + waiters_to_notify = max(value - self._total_tokens, 0) self._total_tokens = value - events = [] - for event in self._wait_queue.values(): - if value <= old_value: - break - - if not event.is_set(): - events.append(event) - old_value += 1 - for event in events: + # Notify waiting tasks that they have acquired the limiter + while self._wait_queue and waiters_to_notify: + event = self._wait_queue.popitem(last=False)[1] event.set() + waiters_to_notify -= 1 @property def borrowed_tokens(self) -> int: @@ -1837,11 +1704,10 @@ class CapacityLimiter(BaseCapacityLimiter): def available_tokens(self) -> float: return self._total_tokens - len(self._borrowers) - def acquire_nowait(self) -> DeprecatedAwaitable: + def acquire_nowait(self) -> None: self.acquire_on_behalf_of_nowait(current_task()) - return DeprecatedAwaitable(self.acquire_nowait) - def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: + def acquire_on_behalf_of_nowait(self, borrower: object) -> None: if borrower in self._borrowers: raise RuntimeError( "this borrower is already holding one of this CapacityLimiter's " @@ -1852,13 +1718,12 @@ class CapacityLimiter(BaseCapacityLimiter): raise WouldBlock self._borrowers.add(borrower) - return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait) async def acquire(self) -> None: return await self.acquire_on_behalf_of(current_task()) async def acquire_on_behalf_of(self, borrower: object) -> None: - await checkpoint_if_cancelled() + await AsyncIOBackend.checkpoint_if_cancelled() try: self.acquire_on_behalf_of_nowait(borrower) except WouldBlock: @@ -1873,7 +1738,7 @@ class CapacityLimiter(BaseCapacityLimiter): self._borrowers.add(borrower) else: try: - await cancel_shielded_checkpoint() + await AsyncIOBackend.cancel_shielded_checkpoint() except BaseException: self.release() raise @@ -1906,29 +1771,20 @@ class CapacityLimiter(BaseCapacityLimiter): _default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter") -def current_default_thread_limiter() -> CapacityLimiter: - try: - return _default_thread_limiter.get() - except LookupError: - limiter = CapacityLimiter(40) - _default_thread_limiter.set(limiter) - return limiter - - # # Operating system signals # -class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): - def __init__(self, signals: tuple[int, ...]): +class _SignalReceiver: + def __init__(self, signals: tuple[Signals, ...]): self._signals = signals self._loop = get_running_loop() - self._signal_queue: deque[int] = deque() + self._signal_queue: deque[Signals] = deque() self._future: asyncio.Future = asyncio.Future() - self._handled_signals: set[int] = set() + self._handled_signals: set[Signals] = set() - def _deliver(self, signum: int) -> None: + def _deliver(self, signum: Signals) -> None: self._signal_queue.append(signum) if not self._future.done(): self._future.set_result(None) @@ -1953,8 +1809,8 @@ class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): def __aiter__(self) -> _SignalReceiver: return self - async def __anext__(self) -> int: - await checkpoint() + async def __anext__(self) -> Signals: + await AsyncIOBackend.checkpoint() if not self._signal_queue: self._future = asyncio.Future() await self._future @@ -1962,10 +1818,6 @@ class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): return self._signal_queue.popleft() -def open_signal_receiver(*signals: int) -> _SignalReceiver: - return _SignalReceiver(signals) - - # # Testing and debugging # @@ -1974,69 +1826,47 @@ def open_signal_receiver(*signals: int) -> _SignalReceiver: def _create_task_info(task: asyncio.Task) -> TaskInfo: task_state = _task_states.get(task) if task_state is None: - name = task.get_name() if _native_task_names else None parent_id = None else: - name = task_state.name parent_id = task_state.parent_id - return TaskInfo(id(task), parent_id, name, get_coro(task)) - - -def get_current_task() -> TaskInfo: - return _create_task_info(current_task()) # type: ignore[arg-type] - - -def get_running_tasks() -> list[TaskInfo]: - return [_create_task_info(task) for task in all_tasks() if not task.done()] - - -async def wait_all_tasks_blocked() -> None: - await checkpoint() - this_task = current_task() - while True: - for task in all_tasks(): - if task is this_task: - continue - - if task._fut_waiter is None or task._fut_waiter.done(): # type: ignore[attr-defined] - await sleep(0.1) - break - else: - return + return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro()) class TestRunner(abc.TestRunner): + _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]] + def __init__( self, - debug: bool = False, + *, + debug: bool | None = None, use_uvloop: bool = False, - policy: asyncio.AbstractEventLoopPolicy | None = None, - ): + loop_factory: Callable[[], AbstractEventLoop] | None = None, + ) -> None: + if use_uvloop and loop_factory is None: + import uvloop + + loop_factory = uvloop.new_event_loop + + self._runner = Runner(debug=debug, loop_factory=loop_factory) self._exceptions: list[BaseException] = [] - _maybe_set_event_loop_policy(policy, use_uvloop) - self._loop = asyncio.new_event_loop() - self._loop.set_debug(debug) - self._loop.set_exception_handler(self._exception_handler) - asyncio.set_event_loop(self._loop) - - def _cancel_all_tasks(self) -> None: - to_cancel = all_tasks(self._loop) - if not to_cancel: - return + self._runner_task: asyncio.Task | None = None - for task in to_cancel: - task.cancel() + def __enter__(self) -> TestRunner: + self._runner.__enter__() + self.get_loop().set_exception_handler(self._exception_handler) + return self - self._loop.run_until_complete( - asyncio.gather(*to_cancel, return_exceptions=True) - ) + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self._runner.__exit__(exc_type, exc_val, exc_tb) - for task in to_cancel: - if task.cancelled(): - continue - if task.exception() is not None: - raise cast(BaseException, task.exception()) + def get_loop(self) -> AbstractEventLoop: + return self._runner.get_loop() def _exception_handler( self, loop: asyncio.AbstractEventLoop, context: dict[str, Any] @@ -2053,56 +1883,77 @@ class TestRunner(abc.TestRunner): if len(exceptions) == 1: raise exceptions[0] elif exceptions: - raise ExceptionGroup(exceptions) + raise BaseExceptionGroup( + "Multiple exceptions occurred in asynchronous callbacks", exceptions + ) - def close(self) -> None: - try: - self._cancel_all_tasks() - self._loop.run_until_complete(self._loop.shutdown_asyncgens()) - finally: - asyncio.set_event_loop(None) - self._loop.close() + @staticmethod + async def _run_tests_and_fixtures( + receive_stream: MemoryObjectReceiveStream[ + tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]] + ], + ) -> None: + with receive_stream: + async for coro, future in receive_stream: + try: + retval = await coro + except BaseException as exc: + if not future.cancelled(): + future.set_exception(exc) + else: + if not future.cancelled(): + future.set_result(retval) + + async def _call_in_runner_task( + self, + func: Callable[P, Awaitable[T_Retval]], + *args: P.args, + **kwargs: P.kwargs, + ) -> T_Retval: + if not self._runner_task: + self._send_stream, receive_stream = create_memory_object_stream[ + Tuple[Awaitable[Any], asyncio.Future] + ](1) + self._runner_task = self.get_loop().create_task( + self._run_tests_and_fixtures(receive_stream) + ) + + coro = func(*args, **kwargs) + future: asyncio.Future[T_Retval] = self.get_loop().create_future() + self._send_stream.send_nowait((coro, future)) + return await future def run_asyncgen_fixture( self, fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]], kwargs: dict[str, Any], ) -> Iterable[T_Retval]: - async def fixture_runner() -> None: - agen = fixture_func(**kwargs) - try: - retval = await agen.asend(None) - self._raise_async_exceptions() - except BaseException as exc: - f.set_exception(exc) - return - else: - f.set_result(retval) - - await event.wait() - try: - await agen.asend(None) - except StopAsyncIteration: - pass - else: - await agen.aclose() - raise RuntimeError("Async generator fixture did not stop") - - f = self._loop.create_future() - event = asyncio.Event() - fixture_task = self._loop.create_task(fixture_runner()) - self._loop.run_until_complete(f) - yield f.result() - event.set() - self._loop.run_until_complete(fixture_task) + asyncgen = fixture_func(**kwargs) + fixturevalue: T_Retval = self.get_loop().run_until_complete( + self._call_in_runner_task(asyncgen.asend, None) + ) self._raise_async_exceptions() + yield fixturevalue + + try: + self.get_loop().run_until_complete( + self._call_in_runner_task(asyncgen.asend, None) + ) + except StopAsyncIteration: + self._raise_async_exceptions() + else: + self.get_loop().run_until_complete(asyncgen.aclose()) + raise RuntimeError("Async generator fixture did not stop") + def run_fixture( self, fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]], kwargs: dict[str, Any], ) -> T_Retval: - retval = self._loop.run_until_complete(fixture_func(**kwargs)) + retval = self.get_loop().run_until_complete( + self._call_in_runner_task(fixture_func, **kwargs) + ) self._raise_async_exceptions() return retval @@ -2110,8 +1961,518 @@ class TestRunner(abc.TestRunner): self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] ) -> None: try: - self._loop.run_until_complete(test_func(**kwargs)) + self.get_loop().run_until_complete( + self._call_in_runner_task(test_func, **kwargs) + ) except Exception as exc: self._exceptions.append(exc) self._raise_async_exceptions() + + +class AsyncIOBackend(AsyncBackend): + @classmethod + def run( + cls, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], + kwargs: dict[str, Any], + options: dict[str, Any], + ) -> T_Retval: + @wraps(func) + async def wrapper() -> T_Retval: + task = cast(asyncio.Task, current_task()) + task.set_name(get_callable_name(func)) + _task_states[task] = TaskState(None, None) + + try: + return await func(*args) + finally: + del _task_states[task] + + debug = options.get("debug", False) + loop_factory = options.get("loop_factory", None) + if loop_factory is None and options.get("use_uvloop", False): + import uvloop + + loop_factory = uvloop.new_event_loop + + with Runner(debug=debug, loop_factory=loop_factory) as runner: + return runner.run(wrapper()) + + @classmethod + def current_token(cls) -> object: + return get_running_loop() + + @classmethod + def current_time(cls) -> float: + return get_running_loop().time() + + @classmethod + def cancelled_exception_class(cls) -> type[BaseException]: + return CancelledError + + @classmethod + async def checkpoint(cls) -> None: + await sleep(0) + + @classmethod + async def checkpoint_if_cancelled(cls) -> None: + task = current_task() + if task is None: + return + + try: + cancel_scope = _task_states[task].cancel_scope + except KeyError: + return + + while cancel_scope: + if cancel_scope.cancel_called: + await sleep(0) + elif cancel_scope.shield: + break + else: + cancel_scope = cancel_scope._parent_scope + + @classmethod + async def cancel_shielded_checkpoint(cls) -> None: + with CancelScope(shield=True): + await sleep(0) + + @classmethod + async def sleep(cls, delay: float) -> None: + await sleep(delay) + + @classmethod + def create_cancel_scope( + cls, *, deadline: float = math.inf, shield: bool = False + ) -> CancelScope: + return CancelScope(deadline=deadline, shield=shield) + + @classmethod + def current_effective_deadline(cls) -> float: + try: + cancel_scope = _task_states[ + current_task() # type: ignore[index] + ].cancel_scope + except KeyError: + return math.inf + + deadline = math.inf + while cancel_scope: + deadline = min(deadline, cancel_scope.deadline) + if cancel_scope._cancel_called: + deadline = -math.inf + break + elif cancel_scope.shield: + break + else: + cancel_scope = cancel_scope._parent_scope + + return deadline + + @classmethod + def create_task_group(cls) -> abc.TaskGroup: + return TaskGroup() + + @classmethod + def create_event(cls) -> abc.Event: + return Event() + + @classmethod + def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter: + return CapacityLimiter(total_tokens) + + @classmethod + async def run_sync_in_worker_thread( + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + abandon_on_cancel: bool = False, + limiter: abc.CapacityLimiter | None = None, + ) -> T_Retval: + await cls.checkpoint() + + # If this is the first run in this event loop thread, set up the necessary + # variables + try: + idle_workers = _threadpool_idle_workers.get() + workers = _threadpool_workers.get() + except LookupError: + idle_workers = deque() + workers = set() + _threadpool_idle_workers.set(idle_workers) + _threadpool_workers.set(workers) + + async with limiter or cls.current_default_thread_limiter(): + with CancelScope(shield=not abandon_on_cancel) as scope: + future: asyncio.Future = asyncio.Future() + root_task = find_root_task() + if not idle_workers: + worker = WorkerThread(root_task, workers, idle_workers) + worker.start() + workers.add(worker) + root_task.add_done_callback(worker.stop) + else: + worker = idle_workers.pop() + + # Prune any other workers that have been idle for MAX_IDLE_TIME + # seconds or longer + now = cls.current_time() + while idle_workers: + if ( + now - idle_workers[0].idle_since + < WorkerThread.MAX_IDLE_TIME + ): + break + + expired_worker = idle_workers.popleft() + expired_worker.root_task.remove_done_callback( + expired_worker.stop + ) + expired_worker.stop() + + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, None) + if abandon_on_cancel or scope._parent_scope is None: + worker_scope = scope + else: + worker_scope = scope._parent_scope + + worker.queue.put_nowait((context, func, args, future, worker_scope)) + return await future + + @classmethod + def check_cancelled(cls) -> None: + scope: CancelScope | None = threadlocals.current_cancel_scope + while scope is not None: + if scope.cancel_called: + raise CancelledError(f"Cancelled by cancel scope {id(scope):x}") + + if scope.shield: + return + + scope = scope._parent_scope + + @classmethod + def run_async_from_thread( + cls, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], + token: object, + ) -> T_Retval: + async def task_wrapper(scope: CancelScope) -> T_Retval: + __tracebackhide__ = True + task = cast(asyncio.Task, current_task()) + _task_states[task] = TaskState(None, scope) + scope._tasks.add(task) + try: + return await func(*args) + except CancelledError as exc: + raise concurrent.futures.CancelledError(str(exc)) from None + finally: + scope._tasks.discard(task) + + loop = cast(AbstractEventLoop, token) + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, "asyncio") + wrapper = task_wrapper(threadlocals.current_cancel_scope) + f: concurrent.futures.Future[T_Retval] = context.run( + asyncio.run_coroutine_threadsafe, wrapper, loop + ) + return f.result() + + @classmethod + def run_sync_from_thread( + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + token: object, + ) -> T_Retval: + @wraps(func) + def wrapper() -> None: + try: + sniffio.current_async_library_cvar.set("asyncio") + f.set_result(func(*args)) + except BaseException as exc: + f.set_exception(exc) + if not isinstance(exc, Exception): + raise + + f: concurrent.futures.Future[T_Retval] = Future() + loop = cast(AbstractEventLoop, token) + loop.call_soon_threadsafe(wrapper) + return f.result() + + @classmethod + def create_blocking_portal(cls) -> abc.BlockingPortal: + return BlockingPortal() + + @classmethod + async def open_process( + cls, + command: str | bytes | Sequence[str | bytes], + *, + shell: bool, + stdin: int | IO[Any] | None, + stdout: int | IO[Any] | None, + stderr: int | IO[Any] | None, + cwd: str | bytes | PathLike | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, + ) -> Process: + await cls.checkpoint() + if shell: + process = await asyncio.create_subprocess_shell( + cast("str | bytes", command), + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + else: + process = await asyncio.create_subprocess_exec( + *command, + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + + stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None + stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None + stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None + return Process(process, stdin_stream, stdout_stream, stderr_stream) + + @classmethod + def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None: + create_task( + _shutdown_process_pool_on_exit(workers), + name="AnyIO process pool shutdown task", + ) + find_root_task().add_done_callback( + partial(_forcibly_shutdown_process_pool_on_exit, workers) + ) + + @classmethod + async def connect_tcp( + cls, host: str, port: int, local_address: IPSockAddrType | None = None + ) -> abc.SocketStream: + transport, protocol = cast( + Tuple[asyncio.Transport, StreamProtocol], + await get_running_loop().create_connection( + StreamProtocol, host, port, local_addr=local_address + ), + ) + transport.pause_reading() + return SocketStream(transport, protocol) + + @classmethod + async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream: + await cls.checkpoint() + loop = get_running_loop() + raw_socket = socket.socket(socket.AF_UNIX) + raw_socket.setblocking(False) + while True: + try: + raw_socket.connect(path) + except BlockingIOError: + f: asyncio.Future = asyncio.Future() + loop.add_writer(raw_socket, f.set_result, None) + f.add_done_callback(lambda _: loop.remove_writer(raw_socket)) + await f + except BaseException: + raw_socket.close() + raise + else: + return UNIXSocketStream(raw_socket) + + @classmethod + def create_tcp_listener(cls, sock: socket.socket) -> SocketListener: + return TCPSocketListener(sock) + + @classmethod + def create_unix_listener(cls, sock: socket.socket) -> SocketListener: + return UNIXSocketListener(sock) + + @classmethod + async def create_udp_socket( + cls, + family: AddressFamily, + local_address: IPSockAddrType | None, + remote_address: IPSockAddrType | None, + reuse_port: bool, + ) -> UDPSocket | ConnectedUDPSocket: + transport, protocol = await get_running_loop().create_datagram_endpoint( + DatagramProtocol, + local_addr=local_address, + remote_addr=remote_address, + family=family, + reuse_port=reuse_port, + ) + if protocol.exception: + transport.close() + raise protocol.exception + + if not remote_address: + return UDPSocket(transport, protocol) + else: + return ConnectedUDPSocket(transport, protocol) + + @classmethod + async def create_unix_datagram_socket( # type: ignore[override] + cls, raw_socket: socket.socket, remote_path: str | bytes | None + ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket: + await cls.checkpoint() + loop = get_running_loop() + + if remote_path: + while True: + try: + raw_socket.connect(remote_path) + except BlockingIOError: + f: asyncio.Future = asyncio.Future() + loop.add_writer(raw_socket, f.set_result, None) + f.add_done_callback(lambda _: loop.remove_writer(raw_socket)) + await f + except BaseException: + raw_socket.close() + raise + else: + return ConnectedUNIXDatagramSocket(raw_socket) + else: + return UNIXDatagramSocket(raw_socket) + + @classmethod + async def getaddrinfo( + cls, + host: bytes | str | None, + port: str | int | None, + *, + family: int | AddressFamily = 0, + type: int | SocketKind = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: + return await get_running_loop().getaddrinfo( + host, port, family=family, type=type, proto=proto, flags=flags + ) + + @classmethod + async def getnameinfo( + cls, sockaddr: IPSockAddrType, flags: int = 0 + ) -> tuple[str, str]: + return await get_running_loop().getnameinfo(sockaddr, flags) + + @classmethod + async def wait_socket_readable(cls, sock: socket.socket) -> None: + await cls.checkpoint() + try: + read_events = _read_events.get() + except LookupError: + read_events = {} + _read_events.set(read_events) + + if read_events.get(sock): + raise BusyResourceError("reading from") from None + + loop = get_running_loop() + event = read_events[sock] = asyncio.Event() + loop.add_reader(sock, event.set) + try: + await event.wait() + finally: + if read_events.pop(sock, None) is not None: + loop.remove_reader(sock) + readable = True + else: + readable = False + + if not readable: + raise ClosedResourceError + + @classmethod + async def wait_socket_writable(cls, sock: socket.socket) -> None: + await cls.checkpoint() + try: + write_events = _write_events.get() + except LookupError: + write_events = {} + _write_events.set(write_events) + + if write_events.get(sock): + raise BusyResourceError("writing to") from None + + loop = get_running_loop() + event = write_events[sock] = asyncio.Event() + loop.add_writer(sock.fileno(), event.set) + try: + await event.wait() + finally: + if write_events.pop(sock, None) is not None: + loop.remove_writer(sock) + writable = True + else: + writable = False + + if not writable: + raise ClosedResourceError + + @classmethod + def current_default_thread_limiter(cls) -> CapacityLimiter: + try: + return _default_thread_limiter.get() + except LookupError: + limiter = CapacityLimiter(40) + _default_thread_limiter.set(limiter) + return limiter + + @classmethod + def open_signal_receiver( + cls, *signals: Signals + ) -> ContextManager[AsyncIterator[Signals]]: + return _SignalReceiver(signals) + + @classmethod + def get_current_task(cls) -> TaskInfo: + return _create_task_info(current_task()) # type: ignore[arg-type] + + @classmethod + def get_running_tasks(cls) -> list[TaskInfo]: + return [_create_task_info(task) for task in all_tasks() if not task.done()] + + @classmethod + async def wait_all_tasks_blocked(cls) -> None: + await cls.checkpoint() + this_task = current_task() + while True: + for task in all_tasks(): + if task is this_task: + continue + + waiter = task._fut_waiter # type: ignore[attr-defined] + if waiter is None or waiter.done(): + await sleep(0.1) + break + else: + return + + @classmethod + def create_test_runner(cls, options: dict[str, Any]) -> TestRunner: + return TestRunner(**options) + + +backend_class = AsyncIOBackend diff --git a/contrib/python/anyio/anyio/_backends/_trio.py b/contrib/python/anyio/anyio/_backends/_trio.py index cf28943509..1a47192e30 100644 --- a/contrib/python/anyio/anyio/_backends/_trio.py +++ b/contrib/python/anyio/anyio/_backends/_trio.py @@ -3,41 +3,48 @@ from __future__ import annotations import array import math import socket +import sys +import types +from collections.abc import AsyncIterator, Iterable from concurrent.futures import Future -from contextvars import copy_context from dataclasses import dataclass from functools import partial from io import IOBase from os import PathLike from signal import Signals +from socket import AddressFamily, SocketKind from types import TracebackType from typing import ( IO, - TYPE_CHECKING, Any, AsyncGenerator, - AsyncIterator, Awaitable, Callable, Collection, + ContextManager, Coroutine, Generic, - Iterable, Mapping, NoReturn, Sequence, TypeVar, cast, + overload, ) -import sniffio import trio.from_thread +import trio.lowlevel from outcome import Error, Outcome, Value +from trio.lowlevel import ( + current_root_task, + current_task, + wait_readable, + wait_writable, +) from trio.socket import SocketType as TrioSocketType from trio.to_thread import run_sync from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc -from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable from .._core._eventloop import claim_worker_thread from .._core._exceptions import ( BrokenResourceError, @@ -45,54 +52,42 @@ from .._core._exceptions import ( ClosedResourceError, EndOfStream, ) -from .._core._exceptions import ExceptionGroup as BaseExceptionGroup from .._core._sockets import convert_ipv6_sockaddr +from .._core._streams import create_memory_object_stream from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter from .._core._synchronization import Event as BaseEvent from .._core._synchronization import ResourceGuard from .._core._tasks import CancelScope as BaseCancelScope -from ..abc import IPSockAddrType, UDPPacketType - -if TYPE_CHECKING: - from trio_typing import TaskStatus +from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType +from ..abc._eventloop import AsyncBackend +from ..streams.memory import MemoryObjectSendStream -try: - from trio import lowlevel as trio_lowlevel -except ImportError: - from trio import hazmat as trio_lowlevel # type: ignore[no-redef] - from trio.hazmat import wait_readable, wait_writable +if sys.version_info >= (3, 10): + from typing import ParamSpec else: - from trio.lowlevel import wait_readable, wait_writable + from typing_extensions import ParamSpec -try: - trio_open_process = trio_lowlevel.open_process -except AttributeError: - # isort: off - from trio import ( # type: ignore[attr-defined, no-redef] - open_process as trio_open_process, - ) +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from exceptiongroup import BaseExceptionGroup + from typing_extensions import TypeVarTuple, Unpack +T = TypeVar("T") T_Retval = TypeVar("T_Retval") T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType) +PosArgsT = TypeVarTuple("PosArgsT") +P = ParamSpec("P") # # Event loop # -run = trio.run -current_token = trio.lowlevel.current_trio_token RunVar = trio.lowlevel.RunVar # -# Miscellaneous -# - -sleep = trio.sleep - - -# # Timeouts and cancellation # @@ -117,13 +112,10 @@ class CancelScope(BaseCancelScope): exc_tb: TracebackType | None, ) -> bool | None: # https://github.com/python-trio/trio-typing/pull/79 - return self.__original.__exit__( # type: ignore[func-returns-value] - exc_type, exc_val, exc_tb - ) + return self.__original.__exit__(exc_type, exc_val, exc_tb) - def cancel(self) -> DeprecatedAwaitable: + def cancel(self) -> None: self.__original.cancel() - return DeprecatedAwaitable(self.cancel) @property def deadline(self) -> float: @@ -138,6 +130,10 @@ class CancelScope(BaseCancelScope): return self.__original.cancel_called @property + def cancelled_caught(self) -> bool: + return self.__original.cancelled_caught + + @property def shield(self) -> bool: return self.__original.shield @@ -146,27 +142,15 @@ class CancelScope(BaseCancelScope): self.__original.shield = value -CancelledError = trio.Cancelled -checkpoint = trio.lowlevel.checkpoint -checkpoint_if_cancelled = trio.lowlevel.checkpoint_if_cancelled -cancel_shielded_checkpoint = trio.lowlevel.cancel_shielded_checkpoint -current_effective_deadline = trio.current_effective_deadline -current_time = trio.current_time - - # # Task groups # -class ExceptionGroup(BaseExceptionGroup, trio.MultiError): - pass - - class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self._active = False - self._nursery_manager = trio.open_nursery() + self._nursery_manager = trio.open_nursery(strict_exception_groups=True) self.cancel_scope = None # type: ignore[assignment] async def __aenter__(self) -> TaskGroup: @@ -183,13 +167,21 @@ class TaskGroup(abc.TaskGroup): ) -> bool | None: try: return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) - except trio.MultiError as exc: - raise ExceptionGroup(exc.exceptions) from None + except BaseExceptionGroup as exc: + _, rest = exc.split(trio.Cancelled) + if not rest: + cancelled_exc = trio.Cancelled._create() + raise cancelled_exc from exc + + raise finally: self._active = False def start_soon( - self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + name: object = None, ) -> None: if not self._active: raise RuntimeError( @@ -200,7 +192,7 @@ class TaskGroup(abc.TaskGroup): async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None - ) -> object: + ) -> Any: if not self._active: raise RuntimeError( "This task group is not active; no new tasks can be started." @@ -214,53 +206,6 @@ class TaskGroup(abc.TaskGroup): # -async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], - *args: object, - cancellable: bool = False, - limiter: trio.CapacityLimiter | None = None, -) -> T_Retval: - def wrapper() -> T_Retval: - with claim_worker_thread("trio"): - return func(*args) - - # TODO: remove explicit context copying when trio 0.20 is the minimum requirement - context = copy_context() - context.run(sniffio.current_async_library_cvar.set, None) - return await run_sync( - context.run, wrapper, cancellable=cancellable, limiter=limiter - ) - - -# TODO: remove this workaround when trio 0.20 is the minimum requirement -def run_async_from_thread( - fn: Callable[..., Awaitable[T_Retval]], *args: Any -) -> T_Retval: - async def wrapper() -> T_Retval: - retval: T_Retval - - async def inner() -> None: - nonlocal retval - __tracebackhide__ = True - retval = await fn(*args) - - async with trio.open_nursery() as n: - context.run(n.start_soon, inner) - - __tracebackhide__ = True - return retval # noqa: F821 - - context = copy_context() - context.run(sniffio.current_async_library_cvar.set, "trio") - return trio.from_thread.run(wrapper) - - -def run_sync_from_thread(fn: Callable[..., T_Retval], *args: Any) -> T_Retval: - # TODO: remove explicit context copying when trio 0.20 is the minimum requirement - retval = trio.from_thread.run_sync(copy_context().run, fn, *args) - return cast(T_Retval, retval) - - class BlockingPortal(abc.BlockingPortal): def __new__(cls) -> BlockingPortal: return object.__new__(cls) @@ -271,16 +216,13 @@ class BlockingPortal(abc.BlockingPortal): def _spawn_task_from_thread( self, - func: Callable, - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, - future: Future, + future: Future[T_Retval], ) -> None: - context = copy_context() - context.run(sniffio.current_async_library_cvar.set, "trio") trio.from_thread.run_sync( - context.run, partial(self._task_group.start_soon, name=name), self._call_func, func, @@ -341,14 +283,21 @@ class Process(abc.Process): _stderr: abc.ByteReceiveStream | None async def aclose(self) -> None: - if self._stdin: - await self._stdin.aclose() - if self._stdout: - await self._stdout.aclose() - if self._stderr: - await self._stderr.aclose() + with CancelScope(shield=True): + if self._stdin: + await self._stdin.aclose() + if self._stdout: + await self._stdout.aclose() + if self._stderr: + await self._stderr.aclose() - await self.wait() + try: + await self.wait() + except BaseException: + self.kill() + with CancelScope(shield=True): + await self.wait() + raise async def wait(self) -> int: return await self._process.wait() @@ -383,47 +332,19 @@ class Process(abc.Process): return self._stderr -async def open_process( - command: str | bytes | Sequence[str | bytes], - *, - shell: bool, - stdin: int | IO[Any] | None, - stdout: int | IO[Any] | None, - stderr: int | IO[Any] | None, - cwd: str | bytes | PathLike | None = None, - env: Mapping[str, str] | None = None, - start_new_session: bool = False, -) -> Process: - process = await trio_open_process( # type: ignore[misc] - command, # type: ignore[arg-type] - stdin=stdin, - stdout=stdout, - stderr=stderr, - shell=shell, - cwd=cwd, - env=env, - start_new_session=start_new_session, - ) - stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None - stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None - stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None - return Process(process, stdin_stream, stdout_stream, stderr_stream) - - class _ProcessPoolShutdownInstrument(trio.abc.Instrument): def after_run(self) -> None: super().after_run() -current_default_worker_process_limiter: RunVar = RunVar( +current_default_worker_process_limiter: trio.lowlevel.RunVar = RunVar( "current_default_worker_process_limiter" ) -async def _shutdown_process_pool(workers: set[Process]) -> None: - process: Process +async def _shutdown_process_pool(workers: set[abc.Process]) -> None: try: - await sleep(math.inf) + await trio.sleep(math.inf) except trio.Cancelled: for process in workers: if process.returncode is None: @@ -434,10 +355,6 @@ async def _shutdown_process_pool(workers: set[Process]) -> None: await process.aclose() -def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None: - trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers) - - # # Sockets and networking # @@ -515,7 +432,7 @@ class UNIXSocketStream(SocketStream, abc.UNIXSocketStream): raise ValueError("maxfds must be a positive integer") fds = array.array("i") - await checkpoint() + await trio.lowlevel.checkpoint() with self._receive_guard: while True: try: @@ -555,7 +472,7 @@ class UNIXSocketStream(SocketStream, abc.UNIXSocketStream): filenos.append(fd.fileno()) fdarray = array.array("i", filenos) - await checkpoint() + await trio.lowlevel.checkpoint() with self._send_guard: while True: try: @@ -564,7 +481,7 @@ class UNIXSocketStream(SocketStream, abc.UNIXSocketStream): [ ( socket.SOL_SOCKET, - socket.SCM_RIGHTS, # type: ignore[list-item] + socket.SCM_RIGHTS, fdarray, ) ], @@ -648,76 +565,49 @@ class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocke self._convert_socket_error(exc) -async def connect_tcp( - host: str, port: int, local_address: IPSockAddrType | None = None -) -> SocketStream: - family = socket.AF_INET6 if ":" in host else socket.AF_INET - trio_socket = trio.socket.socket(family) - trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - if local_address: - await trio_socket.bind(local_address) - - try: - await trio_socket.connect((host, port)) - except BaseException: - trio_socket.close() - raise - - return SocketStream(trio_socket) - - -async def connect_unix(path: str) -> UNIXSocketStream: - trio_socket = trio.socket.socket(socket.AF_UNIX) - try: - await trio_socket.connect(path) - except BaseException: - trio_socket.close() - raise - - return UNIXSocketStream(trio_socket) - - -async def create_udp_socket( - family: socket.AddressFamily, - local_address: IPSockAddrType | None, - remote_address: IPSockAddrType | None, - reuse_port: bool, -) -> UDPSocket | ConnectedUDPSocket: - trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) - - if reuse_port: - trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - - if local_address: - await trio_socket.bind(local_address) - - if remote_address: - await trio_socket.connect(remote_address) - return ConnectedUDPSocket(trio_socket) - else: - return UDPSocket(trio_socket) +class UNIXDatagramSocket(_TrioSocketMixin[str], abc.UNIXDatagramSocket): + def __init__(self, trio_socket: TrioSocketType) -> None: + super().__init__(trio_socket) + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + async def receive(self) -> UNIXDatagramPacketType: + with self._receive_guard: + try: + data, addr = await self._trio_socket.recvfrom(65536) + return data, addr + except BaseException as exc: + self._convert_socket_error(exc) -getaddrinfo = trio.socket.getaddrinfo -getnameinfo = trio.socket.getnameinfo + async def send(self, item: UNIXDatagramPacketType) -> None: + with self._send_guard: + try: + await self._trio_socket.sendto(*item) + except BaseException as exc: + self._convert_socket_error(exc) -async def wait_socket_readable(sock: socket.socket) -> None: - try: - await wait_readable(sock) - except trio.ClosedResourceError as exc: - raise ClosedResourceError().with_traceback(exc.__traceback__) from None - except trio.BusyResourceError: - raise BusyResourceError("reading from") from None +class ConnectedUNIXDatagramSocket( + _TrioSocketMixin[str], abc.ConnectedUNIXDatagramSocket +): + def __init__(self, trio_socket: TrioSocketType) -> None: + super().__init__(trio_socket) + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + async def receive(self) -> bytes: + with self._receive_guard: + try: + return await self._trio_socket.recv(65536) + except BaseException as exc: + self._convert_socket_error(exc) -async def wait_socket_writable(sock: socket.socket) -> None: - try: - await wait_writable(sock) - except trio.ClosedResourceError as exc: - raise ClosedResourceError().with_traceback(exc.__traceback__) from None - except trio.BusyResourceError: - raise BusyResourceError("writing to") from None + async def send(self, item: bytes) -> None: + with self._send_guard: + try: + await self._trio_socket.send(item) + except BaseException as exc: + self._convert_socket_error(exc) # @@ -742,19 +632,30 @@ class Event(BaseEvent): orig_statistics = self.__original.statistics() return EventStatistics(tasks_waiting=orig_statistics.tasks_waiting) - def set(self) -> DeprecatedAwaitable: + def set(self) -> None: self.__original.set() - return DeprecatedAwaitable(self.set) class CapacityLimiter(BaseCapacityLimiter): - def __new__(cls, *args: object, **kwargs: object) -> CapacityLimiter: + def __new__( + cls, + total_tokens: float | None = None, + *, + original: trio.CapacityLimiter | None = None, + ) -> CapacityLimiter: return object.__new__(cls) def __init__( - self, *args: Any, original: trio.CapacityLimiter | None = None + self, + total_tokens: float | None = None, + *, + original: trio.CapacityLimiter | None = None, ) -> None: - self.__original = original or trio.CapacityLimiter(*args) + if original is not None: + self.__original = original + else: + assert total_tokens is not None + self.__original = trio.CapacityLimiter(total_tokens) async def __aenter__(self) -> None: return await self.__original.__aenter__() @@ -783,13 +684,11 @@ class CapacityLimiter(BaseCapacityLimiter): def available_tokens(self) -> float: return self.__original.available_tokens - def acquire_nowait(self) -> DeprecatedAwaitable: + def acquire_nowait(self) -> None: self.__original.acquire_nowait() - return DeprecatedAwaitable(self.acquire_nowait) - def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: + def acquire_on_behalf_of_nowait(self, borrower: object) -> None: self.__original.acquire_on_behalf_of_nowait(borrower) - return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait) async def acquire(self) -> None: await self.__original.acquire() @@ -808,23 +707,12 @@ class CapacityLimiter(BaseCapacityLimiter): return CapacityLimiterStatistics( borrowed_tokens=orig.borrowed_tokens, total_tokens=orig.total_tokens, - borrowers=orig.borrowers, + borrowers=tuple(orig.borrowers), tasks_waiting=orig.tasks_waiting, ) -_capacity_limiter_wrapper: RunVar = RunVar("_capacity_limiter_wrapper") - - -def current_default_thread_limiter() -> CapacityLimiter: - try: - return _capacity_limiter_wrapper.get() - except LookupError: - limiter = CapacityLimiter( - original=trio.to_thread.current_default_thread_limiter() - ) - _capacity_limiter_wrapper.set(limiter) - return limiter +_capacity_limiter_wrapper: trio.lowlevel.RunVar = RunVar("_capacity_limiter_wrapper") # @@ -832,7 +720,7 @@ def current_default_thread_limiter() -> CapacityLimiter: # -class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): +class _SignalReceiver: _iterator: AsyncIterator[int] def __init__(self, signals: tuple[Signals, ...]): @@ -859,138 +747,423 @@ class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): return Signals(signum) -def open_signal_receiver(*signals: Signals) -> _SignalReceiver: - return _SignalReceiver(signals) - - # # Testing and debugging # -def get_current_task() -> TaskInfo: - task = trio_lowlevel.current_task() - - parent_id = None - if task.parent_nursery and task.parent_nursery.parent_task: - parent_id = id(task.parent_nursery.parent_task) - - return TaskInfo(id(task), parent_id, task.name, task.coro) - - -def get_running_tasks() -> list[TaskInfo]: - root_task = trio_lowlevel.current_root_task() - task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)] - nurseries = root_task.child_nurseries - while nurseries: - new_nurseries: list[trio.Nursery] = [] - for nursery in nurseries: - for task in nursery.child_tasks: - task_infos.append( - TaskInfo(id(task), id(nursery.parent_task), task.name, task.coro) - ) - new_nurseries.extend(task.child_nurseries) - - nurseries = new_nurseries - - return task_infos - - -def wait_all_tasks_blocked() -> Awaitable[None]: - import trio.testing - - return trio.testing.wait_all_tasks_blocked() - - class TestRunner(abc.TestRunner): def __init__(self, **options: Any) -> None: - from collections import deque from queue import Queue - self._call_queue: Queue[Callable[..., object]] = Queue() - self._result_queue: deque[Outcome] = deque() - self._stop_event: trio.Event | None = None - self._nursery: trio.Nursery | None = None + self._call_queue: Queue[Callable[[], object]] = Queue() + self._send_stream: MemoryObjectSendStream | None = None self._options = options - async def _trio_main(self) -> None: - self._stop_event = trio.Event() - async with trio.open_nursery() as self._nursery: - await self._stop_event.wait() - - async def _call_func( - self, func: Callable[..., Awaitable[object]], args: tuple, kwargs: dict + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, ) -> None: - try: - retval = await func(*args, **kwargs) - except BaseException as exc: - self._result_queue.append(Error(exc)) - else: - self._result_queue.append(Value(retval)) + if self._send_stream: + self._send_stream.close() + while self._send_stream is not None: + self._call_queue.get()() + + async def _run_tests_and_fixtures(self) -> None: + self._send_stream, receive_stream = create_memory_object_stream(1) + with receive_stream: + async for coro, outcome_holder in receive_stream: + try: + retval = await coro + except BaseException as exc: + outcome_holder.append(Error(exc)) + else: + outcome_holder.append(Value(retval)) def _main_task_finished(self, outcome: object) -> None: - self._nursery = None + self._send_stream = None - def _get_nursery(self) -> trio.Nursery: - if self._nursery is None: + def _call_in_runner_task( + self, + func: Callable[P, Awaitable[T_Retval]], + *args: P.args, + **kwargs: P.kwargs, + ) -> T_Retval: + if self._send_stream is None: trio.lowlevel.start_guest_run( - self._trio_main, + self._run_tests_and_fixtures, run_sync_soon_threadsafe=self._call_queue.put, done_callback=self._main_task_finished, **self._options, ) - while self._nursery is None: + while self._send_stream is None: self._call_queue.get()() - return self._nursery - - def _call( - self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object - ) -> T_Retval: - self._get_nursery().start_soon(self._call_func, func, args, kwargs) - while not self._result_queue: + outcome_holder: list[Outcome] = [] + self._send_stream.send_nowait((func(*args, **kwargs), outcome_holder)) + while not outcome_holder: self._call_queue.get()() - outcome = self._result_queue.pop() - return outcome.unwrap() - - def close(self) -> None: - if self._stop_event: - self._stop_event.set() - while self._nursery is not None: - self._call_queue.get()() + return outcome_holder[0].unwrap() def run_asyncgen_fixture( self, fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]], kwargs: dict[str, Any], ) -> Iterable[T_Retval]: - async def fixture_runner(*, task_status: TaskStatus[T_Retval]) -> None: - agen = fixture_func(**kwargs) - retval = await agen.asend(None) - task_status.started(retval) - await teardown_event.wait() - try: - await agen.asend(None) - except StopAsyncIteration: - pass - else: - await agen.aclose() - raise RuntimeError("Async generator fixture did not stop") + asyncgen = fixture_func(**kwargs) + fixturevalue: T_Retval = self._call_in_runner_task(asyncgen.asend, None) - teardown_event = trio.Event() - fixture_value = self._call(lambda: self._get_nursery().start(fixture_runner)) - yield fixture_value - teardown_event.set() + yield fixturevalue + + try: + self._call_in_runner_task(asyncgen.asend, None) + except StopAsyncIteration: + pass + else: + self._call_in_runner_task(asyncgen.aclose) + raise RuntimeError("Async generator fixture did not stop") def run_fixture( self, fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]], kwargs: dict[str, Any], ) -> T_Retval: - return self._call(fixture_func, **kwargs) + return self._call_in_runner_task(fixture_func, **kwargs) def run_test( self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] ) -> None: - self._call(test_func, **kwargs) + self._call_in_runner_task(test_func, **kwargs) + + +class TrioBackend(AsyncBackend): + @classmethod + def run( + cls, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], + kwargs: dict[str, Any], + options: dict[str, Any], + ) -> T_Retval: + return trio.run(func, *args) + + @classmethod + def current_token(cls) -> object: + return trio.lowlevel.current_trio_token() + + @classmethod + def current_time(cls) -> float: + return trio.current_time() + + @classmethod + def cancelled_exception_class(cls) -> type[BaseException]: + return trio.Cancelled + + @classmethod + async def checkpoint(cls) -> None: + await trio.lowlevel.checkpoint() + + @classmethod + async def checkpoint_if_cancelled(cls) -> None: + await trio.lowlevel.checkpoint_if_cancelled() + + @classmethod + async def cancel_shielded_checkpoint(cls) -> None: + await trio.lowlevel.cancel_shielded_checkpoint() + + @classmethod + async def sleep(cls, delay: float) -> None: + await trio.sleep(delay) + + @classmethod + def create_cancel_scope( + cls, *, deadline: float = math.inf, shield: bool = False + ) -> abc.CancelScope: + return CancelScope(deadline=deadline, shield=shield) + + @classmethod + def current_effective_deadline(cls) -> float: + return trio.current_effective_deadline() + + @classmethod + def create_task_group(cls) -> abc.TaskGroup: + return TaskGroup() + + @classmethod + def create_event(cls) -> abc.Event: + return Event() + + @classmethod + def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter: + return CapacityLimiter(total_tokens) + + @classmethod + async def run_sync_in_worker_thread( + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + abandon_on_cancel: bool = False, + limiter: abc.CapacityLimiter | None = None, + ) -> T_Retval: + def wrapper() -> T_Retval: + with claim_worker_thread(TrioBackend, token): + return func(*args) + + token = TrioBackend.current_token() + return await run_sync( + wrapper, + abandon_on_cancel=abandon_on_cancel, + limiter=cast(trio.CapacityLimiter, limiter), + ) + + @classmethod + def check_cancelled(cls) -> None: + trio.from_thread.check_cancelled() + + @classmethod + def run_async_from_thread( + cls, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], + token: object, + ) -> T_Retval: + return trio.from_thread.run(func, *args) + + @classmethod + def run_sync_from_thread( + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + token: object, + ) -> T_Retval: + return trio.from_thread.run_sync(func, *args) + + @classmethod + def create_blocking_portal(cls) -> abc.BlockingPortal: + return BlockingPortal() + + @classmethod + async def open_process( + cls, + command: str | bytes | Sequence[str | bytes], + *, + shell: bool, + stdin: int | IO[Any] | None, + stdout: int | IO[Any] | None, + stderr: int | IO[Any] | None, + cwd: str | bytes | PathLike | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, + ) -> Process: + process = await trio.lowlevel.open_process( # type: ignore[misc] + command, # type: ignore[arg-type] + stdin=stdin, + stdout=stdout, + stderr=stderr, + shell=shell, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None + stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None + stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None + return Process(process, stdin_stream, stdout_stream, stderr_stream) + + @classmethod + def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None: + trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers) + + @classmethod + async def connect_tcp( + cls, host: str, port: int, local_address: IPSockAddrType | None = None + ) -> SocketStream: + family = socket.AF_INET6 if ":" in host else socket.AF_INET + trio_socket = trio.socket.socket(family) + trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if local_address: + await trio_socket.bind(local_address) + + try: + await trio_socket.connect((host, port)) + except BaseException: + trio_socket.close() + raise + + return SocketStream(trio_socket) + + @classmethod + async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream: + trio_socket = trio.socket.socket(socket.AF_UNIX) + try: + await trio_socket.connect(path) + except BaseException: + trio_socket.close() + raise + + return UNIXSocketStream(trio_socket) + + @classmethod + def create_tcp_listener(cls, sock: socket.socket) -> abc.SocketListener: + return TCPSocketListener(sock) + + @classmethod + def create_unix_listener(cls, sock: socket.socket) -> abc.SocketListener: + return UNIXSocketListener(sock) + + @classmethod + async def create_udp_socket( + cls, + family: socket.AddressFamily, + local_address: IPSockAddrType | None, + remote_address: IPSockAddrType | None, + reuse_port: bool, + ) -> UDPSocket | ConnectedUDPSocket: + trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) + + if reuse_port: + trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + if local_address: + await trio_socket.bind(local_address) + + if remote_address: + await trio_socket.connect(remote_address) + return ConnectedUDPSocket(trio_socket) + else: + return UDPSocket(trio_socket) + + @classmethod + @overload + async def create_unix_datagram_socket( + cls, raw_socket: socket.socket, remote_path: None + ) -> abc.UNIXDatagramSocket: + ... + + @classmethod + @overload + async def create_unix_datagram_socket( + cls, raw_socket: socket.socket, remote_path: str | bytes + ) -> abc.ConnectedUNIXDatagramSocket: + ... + + @classmethod + async def create_unix_datagram_socket( + cls, raw_socket: socket.socket, remote_path: str | bytes | None + ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket: + trio_socket = trio.socket.from_stdlib_socket(raw_socket) + + if remote_path: + await trio_socket.connect(remote_path) + return ConnectedUNIXDatagramSocket(trio_socket) + else: + return UNIXDatagramSocket(trio_socket) + + @classmethod + async def getaddrinfo( + cls, + host: bytes | str | None, + port: str | int | None, + *, + family: int | AddressFamily = 0, + type: int | SocketKind = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: + return await trio.socket.getaddrinfo(host, port, family, type, proto, flags) + + @classmethod + async def getnameinfo( + cls, sockaddr: IPSockAddrType, flags: int = 0 + ) -> tuple[str, str]: + return await trio.socket.getnameinfo(sockaddr, flags) + + @classmethod + async def wait_socket_readable(cls, sock: socket.socket) -> None: + try: + await wait_readable(sock) + except trio.ClosedResourceError as exc: + raise ClosedResourceError().with_traceback(exc.__traceback__) from None + except trio.BusyResourceError: + raise BusyResourceError("reading from") from None + + @classmethod + async def wait_socket_writable(cls, sock: socket.socket) -> None: + try: + await wait_writable(sock) + except trio.ClosedResourceError as exc: + raise ClosedResourceError().with_traceback(exc.__traceback__) from None + except trio.BusyResourceError: + raise BusyResourceError("writing to") from None + + @classmethod + def current_default_thread_limiter(cls) -> CapacityLimiter: + try: + return _capacity_limiter_wrapper.get() + except LookupError: + limiter = CapacityLimiter( + original=trio.to_thread.current_default_thread_limiter() + ) + _capacity_limiter_wrapper.set(limiter) + return limiter + + @classmethod + def open_signal_receiver( + cls, *signals: Signals + ) -> ContextManager[AsyncIterator[Signals]]: + return _SignalReceiver(signals) + + @classmethod + def get_current_task(cls) -> TaskInfo: + task = current_task() + + parent_id = None + if task.parent_nursery and task.parent_nursery.parent_task: + parent_id = id(task.parent_nursery.parent_task) + + return TaskInfo(id(task), parent_id, task.name, task.coro) + + @classmethod + def get_running_tasks(cls) -> list[TaskInfo]: + root_task = current_root_task() + assert root_task + task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)] + nurseries = root_task.child_nurseries + while nurseries: + new_nurseries: list[trio.Nursery] = [] + for nursery in nurseries: + for task in nursery.child_tasks: + task_infos.append( + TaskInfo( + id(task), id(nursery.parent_task), task.name, task.coro + ) + ) + new_nurseries.extend(task.child_nurseries) + + nurseries = new_nurseries + + return task_infos + + @classmethod + async def wait_all_tasks_blocked(cls) -> None: + from trio.testing import wait_all_tasks_blocked + + await wait_all_tasks_blocked() + + @classmethod + def create_test_runner(cls, options: dict[str, Any]) -> TestRunner: + return TestRunner(**options) + + +backend_class = TrioBackend diff --git a/contrib/python/anyio/anyio/_core/_compat.py b/contrib/python/anyio/anyio/_core/_compat.py deleted file mode 100644 index 22d29ab8ac..0000000000 --- a/contrib/python/anyio/anyio/_core/_compat.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -from abc import ABCMeta, abstractmethod -from contextlib import AbstractContextManager -from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - AsyncContextManager, - Callable, - ContextManager, - Generator, - Generic, - Iterable, - List, - TypeVar, - Union, - overload, -) -from warnings import warn - -if TYPE_CHECKING: - from ._testing import TaskInfo -else: - TaskInfo = object - -T = TypeVar("T") -AnyDeprecatedAwaitable = Union[ - "DeprecatedAwaitable", - "DeprecatedAwaitableFloat", - "DeprecatedAwaitableList[T]", - TaskInfo, -] - - -@overload -async def maybe_async(__obj: TaskInfo) -> TaskInfo: - ... - - -@overload -async def maybe_async(__obj: DeprecatedAwaitableFloat) -> float: - ... - - -@overload -async def maybe_async(__obj: DeprecatedAwaitableList[T]) -> list[T]: - ... - - -@overload -async def maybe_async(__obj: DeprecatedAwaitable) -> None: - ... - - -async def maybe_async( - __obj: AnyDeprecatedAwaitable[T], -) -> TaskInfo | float | list[T] | None: - """ - Await on the given object if necessary. - - This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and - methods were converted from coroutine functions into regular functions. - - Do **not** try to use this for any other purpose! - - :return: the result of awaiting on the object if coroutine, or the object itself otherwise - - .. versionadded:: 2.2 - - """ - return __obj._unwrap() - - -class _ContextManagerWrapper: - def __init__(self, cm: ContextManager[T]): - self._cm = cm - - async def __aenter__(self) -> T: - return self._cm.__enter__() - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - return self._cm.__exit__(exc_type, exc_val, exc_tb) - - -def maybe_async_cm( - cm: ContextManager[T] | AsyncContextManager[T], -) -> AsyncContextManager[T]: - """ - Wrap a regular context manager as an async one if necessary. - - This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and - methods were changed to return regular context managers instead of async ones. - - :param cm: a regular or async context manager - :return: an async context manager - - .. versionadded:: 2.2 - - """ - if not isinstance(cm, AbstractContextManager): - raise TypeError("Given object is not an context manager") - - return _ContextManagerWrapper(cm) - - -def _warn_deprecation( - awaitable: AnyDeprecatedAwaitable[Any], stacklevel: int = 1 -) -> None: - warn( - f'Awaiting on {awaitable._name}() is deprecated. Use "await ' - f"anyio.maybe_async({awaitable._name}(...)) if you have to support both AnyIO 2.x " - f'and 3.x, or just remove the "await" if you are completely migrating to AnyIO 3+.', - DeprecationWarning, - stacklevel=stacklevel + 1, - ) - - -class DeprecatedAwaitable: - def __init__(self, func: Callable[..., DeprecatedAwaitable]): - self._name = f"{func.__module__}.{func.__qualname__}" - - def __await__(self) -> Generator[None, None, None]: - _warn_deprecation(self) - if False: - yield - - def __reduce__(self) -> tuple[type[None], tuple[()]]: - return type(None), () - - def _unwrap(self) -> None: - return None - - -class DeprecatedAwaitableFloat(float): - def __new__( - cls, x: float, func: Callable[..., DeprecatedAwaitableFloat] - ) -> DeprecatedAwaitableFloat: - return super().__new__(cls, x) - - def __init__(self, x: float, func: Callable[..., DeprecatedAwaitableFloat]): - self._name = f"{func.__module__}.{func.__qualname__}" - - def __await__(self) -> Generator[None, None, float]: - _warn_deprecation(self) - if False: - yield - - return float(self) - - def __reduce__(self) -> tuple[type[float], tuple[float]]: - return float, (float(self),) - - def _unwrap(self) -> float: - return float(self) - - -class DeprecatedAwaitableList(List[T]): - def __init__( - self, - iterable: Iterable[T] = (), - *, - func: Callable[..., DeprecatedAwaitableList[T]], - ): - super().__init__(iterable) - self._name = f"{func.__module__}.{func.__qualname__}" - - def __await__(self) -> Generator[None, None, list[T]]: - _warn_deprecation(self) - if False: - yield - - return list(self) - - def __reduce__(self) -> tuple[type[list[T]], tuple[list[T]]]: - return list, (list(self),) - - def _unwrap(self) -> list[T]: - return list(self) - - -class DeprecatedAsyncContextManager(Generic[T], metaclass=ABCMeta): - @abstractmethod - def __enter__(self) -> T: - pass - - @abstractmethod - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - pass - - async def __aenter__(self) -> T: - warn( - f"Using {self.__class__.__name__} as an async context manager has been deprecated. " - f'Use "async with anyio.maybe_async_cm(yourcontextmanager) as foo:" if you have to ' - f'support both AnyIO 2.x and 3.x, or just remove the "async" from "async with" if ' - f"you are completely migrating to AnyIO 3+.", - DeprecationWarning, - ) - return self.__enter__() - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - return self.__exit__(exc_type, exc_val, exc_tb) diff --git a/contrib/python/anyio/anyio/_core/_eventloop.py b/contrib/python/anyio/anyio/_core/_eventloop.py index ae9864851b..a9c6e82585 100644 --- a/contrib/python/anyio/anyio/_core/_eventloop.py +++ b/contrib/python/anyio/anyio/_core/_eventloop.py @@ -3,30 +3,33 @@ from __future__ import annotations import math import sys import threading +from collections.abc import Awaitable, Callable, Generator from contextlib import contextmanager from importlib import import_module -from typing import ( - Any, - Awaitable, - Callable, - Generator, - TypeVar, -) +from typing import TYPE_CHECKING, Any, TypeVar import sniffio -# This must be updated when new backends are introduced -from ._compat import DeprecatedAwaitableFloat +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + +if TYPE_CHECKING: + from ..abc import AsyncBackend +# This must be updated when new backends are introduced BACKENDS = "asyncio", "trio" T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + threadlocals = threading.local() def run( - func: Callable[..., Awaitable[T_Retval]], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], backend: str = "asyncio", backend_options: dict[str, Any] | None = None, ) -> T_Retval: @@ -37,12 +40,13 @@ def run( :param func: a coroutine function :param args: positional arguments to ``func`` - :param backend: name of the asynchronous event loop implementation – currently either - ``asyncio`` or ``trio`` - :param backend_options: keyword arguments to call the backend ``run()`` implementation with - (documented :ref:`here <backend options>`) + :param backend: name of the asynchronous event loop implementation – currently + either ``asyncio`` or ``trio`` + :param backend_options: keyword arguments to call the backend ``run()`` + implementation with (documented :ref:`here <backend options>`) :return: the return value of the coroutine function - :raises RuntimeError: if an asynchronous event loop is already running in this thread + :raises RuntimeError: if an asynchronous event loop is already running in this + thread :raises LookupError: if the named backend is not found """ @@ -54,18 +58,19 @@ def run( raise RuntimeError(f"Already running {asynclib_name} in this thread") try: - asynclib = import_module(f"..._backends._{backend}", package=__name__) + async_backend = get_async_backend(backend) except ImportError as exc: raise LookupError(f"No such backend: {backend}") from exc token = None if sniffio.current_async_library_cvar.get(None) is None: - # Since we're in control of the event loop, we can cache the name of the async library + # Since we're in control of the event loop, we can cache the name of the async + # library token = sniffio.current_async_library_cvar.set(backend) try: backend_options = backend_options or {} - return asynclib.run(func, *args, **backend_options) + return async_backend.run(func, args, {}, backend_options) finally: if token: sniffio.current_async_library_cvar.reset(token) @@ -78,7 +83,7 @@ async def sleep(delay: float) -> None: :param delay: the duration, in seconds """ - return await get_asynclib().sleep(delay) + return await get_async_backend().sleep(delay) async def sleep_forever() -> None: @@ -97,8 +102,8 @@ async def sleep_until(deadline: float) -> None: """ Pause the current task until the given time. - :param deadline: the absolute time to wake up at (according to the internal monotonic clock of - the event loop) + :param deadline: the absolute time to wake up at (according to the internal + monotonic clock of the event loop) .. versionadded:: 3.1 @@ -107,14 +112,14 @@ async def sleep_until(deadline: float) -> None: await sleep(max(deadline - now, 0)) -def current_time() -> DeprecatedAwaitableFloat: +def current_time() -> float: """ Return the current value of the event loop's internal clock. :return: the clock value (seconds) """ - return DeprecatedAwaitableFloat(get_asynclib().current_time(), current_time) + return get_async_backend().current_time() def get_all_backends() -> tuple[str, ...]: @@ -124,7 +129,7 @@ def get_all_backends() -> tuple[str, ...]: def get_cancelled_exc_class() -> type[BaseException]: """Return the current async library's cancellation exception class.""" - return get_asynclib().CancelledError + return get_async_backend().cancelled_exception_class() # @@ -133,21 +138,26 @@ def get_cancelled_exc_class() -> type[BaseException]: @contextmanager -def claim_worker_thread(backend: str) -> Generator[Any, None, None]: - module = sys.modules["anyio._backends._" + backend] - threadlocals.current_async_module = module +def claim_worker_thread( + backend_class: type[AsyncBackend], token: object +) -> Generator[Any, None, None]: + threadlocals.current_async_backend = backend_class + threadlocals.current_token = token try: yield finally: - del threadlocals.current_async_module + del threadlocals.current_async_backend + del threadlocals.current_token -def get_asynclib(asynclib_name: str | None = None) -> Any: +def get_async_backend(asynclib_name: str | None = None) -> AsyncBackend: if asynclib_name is None: asynclib_name = sniffio.current_async_library() modulename = "anyio._backends._" + asynclib_name try: - return sys.modules[modulename] + module = sys.modules[modulename] except KeyError: - return import_module(modulename) + module = import_module(modulename) + + return getattr(module, "backend_class") diff --git a/contrib/python/anyio/anyio/_core/_exceptions.py b/contrib/python/anyio/anyio/_core/_exceptions.py index 92ccd77a2d..571c3b8531 100644 --- a/contrib/python/anyio/anyio/_core/_exceptions.py +++ b/contrib/python/anyio/anyio/_core/_exceptions.py @@ -1,24 +1,25 @@ from __future__ import annotations -from traceback import format_exception - class BrokenResourceError(Exception): """ - Raised when trying to use a resource that has been rendered unusable due to external causes - (e.g. a send stream whose peer has disconnected). + Raised when trying to use a resource that has been rendered unusable due to external + causes (e.g. a send stream whose peer has disconnected). """ class BrokenWorkerProcess(Exception): """ - Raised by :func:`run_sync_in_process` if the worker process terminates abruptly or otherwise - misbehaves. + Raised by :func:`run_sync_in_process` if the worker process terminates abruptly or + otherwise misbehaves. """ class BusyResourceError(Exception): - """Raised when two tasks are trying to read from or write to the same resource concurrently.""" + """ + Raised when two tasks are trying to read from or write to the same resource + concurrently. + """ def __init__(self, action: str): super().__init__(f"Another task is already {action} this resource") @@ -30,7 +31,8 @@ class ClosedResourceError(Exception): class DelimiterNotFound(Exception): """ - Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the + Raised during + :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the maximum number of bytes has been read without the delimiter being found. """ @@ -41,38 +43,15 @@ class DelimiterNotFound(Exception): class EndOfStream(Exception): - """Raised when trying to read from a stream that has been closed from the other end.""" - - -class ExceptionGroup(BaseException): """ - Raised when multiple exceptions have been raised in a task group. - - :var ~typing.Sequence[BaseException] exceptions: the sequence of exceptions raised together + Raised when trying to read from a stream that has been closed from the other end. """ - SEPARATOR = "----------------------------\n" - - exceptions: list[BaseException] - - def __str__(self) -> str: - tracebacks = [ - "".join(format_exception(type(exc), exc, exc.__traceback__)) - for exc in self.exceptions - ] - return ( - f"{len(self.exceptions)} exceptions were raised in the task group:\n" - f"{self.SEPARATOR}{self.SEPARATOR.join(tracebacks)}" - ) - - def __repr__(self) -> str: - exception_reprs = ", ".join(repr(exc) for exc in self.exceptions) - return f"<{self.__class__.__name__}: {exception_reprs}>" - class IncompleteRead(Exception): """ - Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_exactly` or + Raised during + :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_exactly` or :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the connection is closed before the requested amount of bytes has been read. """ @@ -85,8 +64,8 @@ class IncompleteRead(Exception): class TypedAttributeLookupError(LookupError): """ - Raised by :meth:`~anyio.TypedAttributeProvider.extra` when the given typed attribute is not - found and no default value has been given. + Raised by :meth:`~anyio.TypedAttributeProvider.extra` when the given typed attribute + is not found and no default value has been given. """ diff --git a/contrib/python/anyio/anyio/_core/_fileio.py b/contrib/python/anyio/anyio/_core/_fileio.py index 35e8e8af6c..d054be693d 100644 --- a/contrib/python/anyio/anyio/_core/_fileio.py +++ b/contrib/python/anyio/anyio/_core/_fileio.py @@ -3,6 +3,7 @@ from __future__ import annotations import os import pathlib import sys +from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass from functools import partial from os import PathLike @@ -12,23 +13,14 @@ from typing import ( Any, AnyStr, AsyncIterator, - Callable, + Final, Generic, - Iterable, - Iterator, - Sequence, - cast, overload, ) from .. import to_thread from ..abc import AsyncResource -if sys.version_info >= (3, 8): - from typing import Final -else: - from typing_extensions import Final - if TYPE_CHECKING: from _typeshed import OpenBinaryMode, OpenTextMode, ReadableBuffer, WriteableBuffer else: @@ -39,8 +31,8 @@ class AsyncFile(AsyncResource, Generic[AnyStr]): """ An asynchronous file object. - This class wraps a standard file object and provides async friendly versions of the following - blocking methods (where available on the original file object): + This class wraps a standard file object and provides async friendly versions of the + following blocking methods (where available on the original file object): * read * read1 @@ -57,8 +49,8 @@ class AsyncFile(AsyncResource, Generic[AnyStr]): All other methods are directly passed through. - This class supports the asynchronous context manager protocol which closes the underlying file - at the end of the context block. + This class supports the asynchronous context manager protocol which closes the + underlying file at the end of the context block. This class also supports asynchronous iteration:: @@ -212,22 +204,25 @@ class _PathIterator(AsyncIterator["Path"]): iterator: Iterator[PathLike[str]] async def __anext__(self) -> Path: - nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True) + nextval = await to_thread.run_sync( + next, self.iterator, None, abandon_on_cancel=True + ) if nextval is None: raise StopAsyncIteration from None - return Path(cast("PathLike[str]", nextval)) + return Path(nextval) class Path: """ An asynchronous version of :class:`pathlib.Path`. - This class cannot be substituted for :class:`pathlib.Path` or :class:`pathlib.PurePath`, but - it is compatible with the :class:`os.PathLike` interface. + This class cannot be substituted for :class:`pathlib.Path` or + :class:`pathlib.PurePath`, but it is compatible with the :class:`os.PathLike` + interface. - It implements the Python 3.10 version of :class:`pathlib.Path` interface, except for the - deprecated :meth:`~pathlib.Path.link_to` method. + It implements the Python 3.10 version of :class:`pathlib.Path` interface, except for + the deprecated :meth:`~pathlib.Path.link_to` method. Any methods that do disk I/O need to be awaited on. These methods are: @@ -263,7 +258,8 @@ class Path: * :meth:`~pathlib.Path.write_bytes` * :meth:`~pathlib.Path.write_text` - Additionally, the following methods return an async iterator yielding :class:`~.Path` objects: + Additionally, the following methods return an async iterator yielding + :class:`~.Path` objects: * :meth:`~pathlib.Path.glob` * :meth:`~pathlib.Path.iterdir` @@ -296,26 +292,26 @@ class Path: target = other._path if isinstance(other, Path) else other return self._path.__eq__(target) - def __lt__(self, other: Path) -> bool: + def __lt__(self, other: pathlib.PurePath | Path) -> bool: target = other._path if isinstance(other, Path) else other return self._path.__lt__(target) - def __le__(self, other: Path) -> bool: + def __le__(self, other: pathlib.PurePath | Path) -> bool: target = other._path if isinstance(other, Path) else other return self._path.__le__(target) - def __gt__(self, other: Path) -> bool: + def __gt__(self, other: pathlib.PurePath | Path) -> bool: target = other._path if isinstance(other, Path) else other return self._path.__gt__(target) - def __ge__(self, other: Path) -> bool: + def __ge__(self, other: pathlib.PurePath | Path) -> bool: target = other._path if isinstance(other, Path) else other return self._path.__ge__(target) - def __truediv__(self, other: Any) -> Path: + def __truediv__(self, other: str | PathLike[str]) -> Path: return Path(self._path / other) - def __rtruediv__(self, other: Any) -> Path: + def __rtruediv__(self, other: str | PathLike[str]) -> Path: return Path(other) / self @property @@ -371,13 +367,16 @@ class Path: def match(self, path_pattern: str) -> bool: return self._path.match(path_pattern) - def is_relative_to(self, *other: str | PathLike[str]) -> bool: + def is_relative_to(self, other: str | PathLike[str]) -> bool: try: - self.relative_to(*other) + self.relative_to(other) return True except ValueError: return False + async def is_junction(self) -> bool: + return await to_thread.run_sync(self._path.is_junction) + async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None: func = partial(os.chmod, follow_symlinks=follow_symlinks) return await to_thread.run_sync(func, self._path, mode) @@ -388,19 +387,23 @@ class Path: return cls(path) async def exists(self) -> bool: - return await to_thread.run_sync(self._path.exists, cancellable=True) + return await to_thread.run_sync(self._path.exists, abandon_on_cancel=True) async def expanduser(self) -> Path: - return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True)) + return Path( + await to_thread.run_sync(self._path.expanduser, abandon_on_cancel=True) + ) def glob(self, pattern: str) -> AsyncIterator[Path]: gen = self._path.glob(pattern) return _PathIterator(gen) async def group(self) -> str: - return await to_thread.run_sync(self._path.group, cancellable=True) + return await to_thread.run_sync(self._path.group, abandon_on_cancel=True) - async def hardlink_to(self, target: str | pathlib.Path | Path) -> None: + async def hardlink_to( + self, target: str | bytes | PathLike[str] | PathLike[bytes] + ) -> None: if isinstance(target, Path): target = target._path @@ -415,31 +418,37 @@ class Path: return self._path.is_absolute() async def is_block_device(self) -> bool: - return await to_thread.run_sync(self._path.is_block_device, cancellable=True) + return await to_thread.run_sync( + self._path.is_block_device, abandon_on_cancel=True + ) async def is_char_device(self) -> bool: - return await to_thread.run_sync(self._path.is_char_device, cancellable=True) + return await to_thread.run_sync( + self._path.is_char_device, abandon_on_cancel=True + ) async def is_dir(self) -> bool: - return await to_thread.run_sync(self._path.is_dir, cancellable=True) + return await to_thread.run_sync(self._path.is_dir, abandon_on_cancel=True) async def is_fifo(self) -> bool: - return await to_thread.run_sync(self._path.is_fifo, cancellable=True) + return await to_thread.run_sync(self._path.is_fifo, abandon_on_cancel=True) async def is_file(self) -> bool: - return await to_thread.run_sync(self._path.is_file, cancellable=True) + return await to_thread.run_sync(self._path.is_file, abandon_on_cancel=True) async def is_mount(self) -> bool: - return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True) + return await to_thread.run_sync( + os.path.ismount, self._path, abandon_on_cancel=True + ) def is_reserved(self) -> bool: return self._path.is_reserved() async def is_socket(self) -> bool: - return await to_thread.run_sync(self._path.is_socket, cancellable=True) + return await to_thread.run_sync(self._path.is_socket, abandon_on_cancel=True) async def is_symlink(self) -> bool: - return await to_thread.run_sync(self._path.is_symlink, cancellable=True) + return await to_thread.run_sync(self._path.is_symlink, abandon_on_cancel=True) def iterdir(self) -> AsyncIterator[Path]: gen = self._path.iterdir() @@ -452,7 +461,7 @@ class Path: await to_thread.run_sync(self._path.lchmod, mode) async def lstat(self) -> os.stat_result: - return await to_thread.run_sync(self._path.lstat, cancellable=True) + return await to_thread.run_sync(self._path.lstat, abandon_on_cancel=True) async def mkdir( self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False @@ -495,7 +504,7 @@ class Path: return AsyncFile(fp) async def owner(self) -> str: - return await to_thread.run_sync(self._path.owner, cancellable=True) + return await to_thread.run_sync(self._path.owner, abandon_on_cancel=True) async def read_bytes(self) -> bytes: return await to_thread.run_sync(self._path.read_bytes) @@ -505,12 +514,21 @@ class Path: ) -> str: return await to_thread.run_sync(self._path.read_text, encoding, errors) - def relative_to(self, *other: str | PathLike[str]) -> Path: - return Path(self._path.relative_to(*other)) + if sys.version_info >= (3, 12): + + def relative_to( + self, *other: str | PathLike[str], walk_up: bool = False + ) -> Path: + return Path(self._path.relative_to(*other, walk_up=walk_up)) + + else: + + def relative_to(self, *other: str | PathLike[str]) -> Path: + return Path(self._path.relative_to(*other)) async def readlink(self) -> Path: target = await to_thread.run_sync(os.readlink, self._path) - return Path(cast(str, target)) + return Path(target) async def rename(self, target: str | pathlib.PurePath | Path) -> Path: if isinstance(target, Path): @@ -528,7 +546,7 @@ class Path: async def resolve(self, strict: bool = False) -> Path: func = partial(self._path.resolve, strict=strict) - return Path(await to_thread.run_sync(func, cancellable=True)) + return Path(await to_thread.run_sync(func, abandon_on_cancel=True)) def rglob(self, pattern: str) -> AsyncIterator[Path]: gen = self._path.rglob(pattern) @@ -537,23 +555,21 @@ class Path: async def rmdir(self) -> None: await to_thread.run_sync(self._path.rmdir) - async def samefile( - self, other_path: str | bytes | int | pathlib.Path | Path - ) -> bool: + async def samefile(self, other_path: str | PathLike[str]) -> bool: if isinstance(other_path, Path): other_path = other_path._path return await to_thread.run_sync( - self._path.samefile, other_path, cancellable=True + self._path.samefile, other_path, abandon_on_cancel=True ) async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result: func = partial(os.stat, follow_symlinks=follow_symlinks) - return await to_thread.run_sync(func, self._path, cancellable=True) + return await to_thread.run_sync(func, self._path, abandon_on_cancel=True) async def symlink_to( self, - target: str | pathlib.Path | Path, + target: str | bytes | PathLike[str] | PathLike[bytes], target_is_directory: bool = False, ) -> None: if isinstance(target, Path): @@ -571,6 +587,29 @@ class Path: if not missing_ok: raise + if sys.version_info >= (3, 12): + + async def walk( + self, + top_down: bool = True, + on_error: Callable[[OSError], object] | None = None, + follow_symlinks: bool = False, + ) -> AsyncIterator[tuple[Path, list[str], list[str]]]: + def get_next_value() -> tuple[pathlib.Path, list[str], list[str]] | None: + try: + return next(gen) + except StopIteration: + return None + + gen = self._path.walk(top_down, on_error, follow_symlinks) + while True: + value = await to_thread.run_sync(get_next_value) + if value is None: + return + + root, dirs, paths = value + yield Path(root), dirs, paths + def with_name(self, name: str) -> Path: return Path(self._path.with_name(name)) @@ -580,6 +619,9 @@ class Path: def with_suffix(self, suffix: str) -> Path: return Path(self._path.with_suffix(suffix)) + def with_segments(self, *pathsegments: str | PathLike[str]) -> Path: + return Path(*pathsegments) + async def write_bytes(self, data: bytes) -> int: return await to_thread.run_sync(self._path.write_bytes, data) diff --git a/contrib/python/anyio/anyio/_core/_signals.py b/contrib/python/anyio/anyio/_core/_signals.py index 8ea54af86c..115c749bd9 100644 --- a/contrib/python/anyio/anyio/_core/_signals.py +++ b/contrib/python/anyio/anyio/_core/_signals.py @@ -1,26 +1,25 @@ from __future__ import annotations -from typing import AsyncIterator +from collections.abc import AsyncIterator +from signal import Signals +from typing import ContextManager -from ._compat import DeprecatedAsyncContextManager -from ._eventloop import get_asynclib +from ._eventloop import get_async_backend -def open_signal_receiver( - *signals: int, -) -> DeprecatedAsyncContextManager[AsyncIterator[int]]: +def open_signal_receiver(*signals: Signals) -> ContextManager[AsyncIterator[Signals]]: """ Start receiving operating system signals. :param signals: signals to receive (e.g. ``signal.SIGINT``) - :return: an asynchronous context manager for an asynchronous iterator which yields signal - numbers + :return: an asynchronous context manager for an asynchronous iterator which yields + signal numbers - .. warning:: Windows does not support signals natively so it is best to avoid relying on this - in cross-platform applications. + .. warning:: Windows does not support signals natively so it is best to avoid + relying on this in cross-platform applications. - .. warning:: On asyncio, this permanently replaces any previous signal handler for the given - signals, as set via :meth:`~asyncio.loop.add_signal_handler`. + .. warning:: On asyncio, this permanently replaces any previous signal handler for + the given signals, as set via :meth:`~asyncio.loop.add_signal_handler`. """ - return get_asynclib().open_signal_receiver(*signals) + return get_async_backend().open_signal_receiver(*signals) diff --git a/contrib/python/anyio/anyio/_core/_sockets.py b/contrib/python/anyio/anyio/_core/_sockets.py index e6970bee27..0f0a3142fb 100644 --- a/contrib/python/anyio/anyio/_core/_sockets.py +++ b/contrib/python/anyio/anyio/_core/_sockets.py @@ -1,41 +1,41 @@ from __future__ import annotations +import errno +import os import socket import ssl +import stat import sys +from collections.abc import Awaitable from ipaddress import IPv6Address, ip_address from os import PathLike, chmod -from pathlib import Path from socket import AddressFamily, SocketKind -from typing import Awaitable, List, Tuple, cast, overload +from typing import Any, Literal, cast, overload from .. import to_thread from ..abc import ( ConnectedUDPSocket, + ConnectedUNIXDatagramSocket, IPAddressType, IPSockAddrType, SocketListener, SocketStream, UDPSocket, + UNIXDatagramSocket, UNIXSocketStream, ) from ..streams.stapled import MultiListener from ..streams.tls import TLSStream -from ._eventloop import get_asynclib +from ._eventloop import get_async_backend from ._resources import aclose_forcefully from ._synchronization import Event from ._tasks import create_task_group, move_on_after -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41) # https://bugs.python.org/issue29515 -GetAddrInfoReturnType = List[ - Tuple[AddressFamily, SocketKind, int, str, Tuple[str, int]] -] AnyIPAddressFamily = Literal[ AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6 ] @@ -142,18 +142,21 @@ async def connect_tcp( :param remote_host: the IP address or host name to connect to :param remote_port: port on the target host to connect to - :param local_host: the interface address or name to bind the socket to before connecting + :param local_host: the interface address or name to bind the socket to before + connecting :param tls: ``True`` to do a TLS handshake with the connected stream and return a :class:`~anyio.streams.tls.TLSStream` instead - :param ssl_context: the SSL context object to use (if omitted, a default context is created) - :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake before closing - the stream and requires that the server does this as well. Otherwise, - :exc:`~ssl.SSLEOFError` may be raised during reads from the stream. + :param ssl_context: the SSL context object to use (if omitted, a default context is + created) + :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake + before closing the stream and requires that the server does this as well. + Otherwise, :exc:`~ssl.SSLEOFError` may be raised during reads from the stream. Some protocols, such as HTTP, require this option to be ``False``. See :meth:`~ssl.SSLContext.wrap_socket` for details. - :param tls_hostname: host name to check the server certificate against (defaults to the value - of ``remote_host``) - :param happy_eyeballs_delay: delay (in seconds) before starting the next connection attempt + :param tls_hostname: host name to check the server certificate against (defaults to + the value of ``remote_host``) + :param happy_eyeballs_delay: delay (in seconds) before starting the next connection + attempt :return: a socket stream object if no TLS handshake was done, otherwise a TLS stream :raises OSError: if the connection attempt fails @@ -177,7 +180,7 @@ async def connect_tcp( finally: event.set() - asynclib = get_asynclib() + asynclib = get_async_backend() local_address: IPSockAddrType | None = None family = socket.AF_UNSPEC if local_host: @@ -193,8 +196,8 @@ async def connect_tcp( target_host, remote_port, family=family, type=socket.SOCK_STREAM ) - # Organize the list so that the first address is an IPv6 address (if available) and the - # second one is an IPv4 addresses. The rest can be in whatever order. + # Organize the list so that the first address is an IPv6 address (if available) + # and the second one is an IPv4 addresses. The rest can be in whatever order. v6_found = v4_found = False target_addrs: list[tuple[socket.AddressFamily, str]] = [] for af, *rest, sa in gai_res: @@ -221,7 +224,11 @@ async def connect_tcp( await event.wait() if connected_stream is None: - cause = oserrors[0] if len(oserrors) == 1 else asynclib.ExceptionGroup(oserrors) + cause = ( + oserrors[0] + if len(oserrors) == 1 + else ExceptionGroup("multiple connection attempts failed", oserrors) + ) raise OSError("All connection attempts failed") from cause if tls or tls_hostname or ssl_context: @@ -240,7 +247,7 @@ async def connect_tcp( return connected_stream -async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream: +async def connect_unix(path: str | bytes | PathLike[Any]) -> UNIXSocketStream: """ Connect to the given UNIX socket. @@ -250,8 +257,8 @@ async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream: :return: a socket stream object """ - path = str(Path(path)) - return await get_asynclib().connect_unix(path) + path = os.fspath(path) + return await get_async_backend().connect_unix(path) async def create_tcp_listener( @@ -277,11 +284,11 @@ async def create_tcp_listener( :return: a list of listener objects """ - asynclib = get_asynclib() + asynclib = get_async_backend() backlog = min(backlog, 65536) local_host = str(local_host) if local_host is not None else None gai_res = await getaddrinfo( - local_host, # type: ignore[arg-type] + local_host, local_port, family=family, type=socket.SocketKind.SOCK_STREAM if sys.platform == "win32" else 0, @@ -302,7 +309,8 @@ async def create_tcp_listener( raw_socket = socket.socket(fam) raw_socket.setblocking(False) - # For Windows, enable exclusive address use. For others, enable address reuse. + # For Windows, enable exclusive address use. For others, enable address + # reuse. if sys.platform == "win32": raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) else: @@ -322,7 +330,7 @@ async def create_tcp_listener( raw_socket.bind(sockaddr) raw_socket.listen(backlog) - listener = asynclib.TCPSocketListener(raw_socket) + listener = asynclib.create_tcp_listener(raw_socket) listeners.append(listener) except BaseException: for listener in listeners: @@ -334,7 +342,7 @@ async def create_tcp_listener( async def create_unix_listener( - path: str | PathLike[str], + path: str | bytes | PathLike[Any], *, mode: int | None = None, backlog: int = 65536, @@ -346,29 +354,20 @@ async def create_unix_listener( :param path: path of the socket :param mode: permissions to set on the socket - :param backlog: maximum number of queued incoming connections (up to a maximum of 2**16, or - 65536) + :param backlog: maximum number of queued incoming connections (up to a maximum of + 2**16, or 65536) :return: a listener object .. versionchanged:: 3.0 - If a socket already exists on the file system in the given path, it will be removed first. + If a socket already exists on the file system in the given path, it will be + removed first. """ - path_str = str(path) - path = Path(path) - if path.is_socket(): - path.unlink() - backlog = min(backlog, 65536) - raw_socket = socket.socket(socket.AF_UNIX) - raw_socket.setblocking(False) + raw_socket = await setup_unix_local_socket(path, mode, socket.SOCK_STREAM) try: - await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True) - if mode is not None: - await to_thread.run_sync(chmod, path_str, mode, cancellable=True) - raw_socket.listen(backlog) - return get_asynclib().UNIXSocketListener(raw_socket) + return get_async_backend().create_unix_listener(raw_socket) except BaseException: raw_socket.close() raise @@ -384,15 +383,15 @@ async def create_udp_socket( """ Create a UDP socket. - If ``local_port`` has been given, the socket will be bound to this port on the local + If ``port`` has been given, the socket will be bound to this port on the local machine, making this socket suitable for providing UDP based services. - :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from - ``local_host`` if omitted + :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically + determined from ``local_host`` if omitted :param local_host: IP address or host name of the local interface to bind to :param local_port: local port to bind to - :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port - (not supported on Windows) + :param reuse_port: ``True`` to allow multiple sockets to bind to the same + address/port (not supported on Windows) :return: a UDP socket """ @@ -414,9 +413,10 @@ async def create_udp_socket( else: local_address = ("0.0.0.0", 0) - return await get_asynclib().create_udp_socket( + sock = await get_async_backend().create_udp_socket( family, local_address, None, reuse_port ) + return cast(UDPSocket, sock) async def create_connected_udp_socket( @@ -431,17 +431,17 @@ async def create_connected_udp_socket( """ Create a connected UDP socket. - Connected UDP sockets can only communicate with the specified remote host/port, and any packets - sent from other sources are dropped. + Connected UDP sockets can only communicate with the specified remote host/port, an + any packets sent from other sources are dropped. :param remote_host: remote host to set as the default target :param remote_port: port on the remote host to set as the default target - :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from - ``local_host`` or ``remote_host`` if omitted + :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically + determined from ``local_host`` or ``remote_host`` if omitted :param local_host: IP address or host name of the local interface to bind to :param local_port: local port to bind to - :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port - (not supported on Windows) + :param reuse_port: ``True`` to allow multiple sockets to bind to the same + address/port (not supported on Windows) :return: a connected UDP socket """ @@ -463,25 +463,87 @@ async def create_connected_udp_socket( family = cast(AnyIPAddressFamily, gai_res[0][0]) remote_address = gai_res[0][-1] - return await get_asynclib().create_udp_socket( + sock = await get_async_backend().create_udp_socket( family, local_address, remote_address, reuse_port ) + return cast(ConnectedUDPSocket, sock) + + +async def create_unix_datagram_socket( + *, + local_path: None | str | bytes | PathLike[Any] = None, + local_mode: int | None = None, +) -> UNIXDatagramSocket: + """ + Create a UNIX datagram socket. + + Not available on Windows. + + If ``local_path`` has been given, the socket will be bound to this path, making this + socket suitable for receiving datagrams from other processes. Other processes can + send datagrams to this socket only if ``local_path`` is set. + + If a socket already exists on the file system in the ``local_path``, it will be + removed first. + + :param local_path: the path on which to bind to + :param local_mode: permissions to set on the local socket + :return: a UNIX datagram socket + + """ + raw_socket = await setup_unix_local_socket( + local_path, local_mode, socket.SOCK_DGRAM + ) + return await get_async_backend().create_unix_datagram_socket(raw_socket, None) + + +async def create_connected_unix_datagram_socket( + remote_path: str | bytes | PathLike[Any], + *, + local_path: None | str | bytes | PathLike[Any] = None, + local_mode: int | None = None, +) -> ConnectedUNIXDatagramSocket: + """ + Create a connected UNIX datagram socket. + + Connected datagram sockets can only communicate with the specified remote path. + + If ``local_path`` has been given, the socket will be bound to this path, making + this socket suitable for receiving datagrams from other processes. Other processes + can send datagrams to this socket only if ``local_path`` is set. + + If a socket already exists on the file system in the ``local_path``, it will be + removed first. + + :param remote_path: the path to set as the default target + :param local_path: the path on which to bind to + :param local_mode: permissions to set on the local socket + :return: a connected UNIX datagram socket + + """ + remote_path = os.fspath(remote_path) + raw_socket = await setup_unix_local_socket( + local_path, local_mode, socket.SOCK_DGRAM + ) + return await get_async_backend().create_unix_datagram_socket( + raw_socket, remote_path + ) async def getaddrinfo( - host: bytearray | bytes | str, + host: bytes | str | None, port: str | int | None, *, family: int | AddressFamily = 0, type: int | SocketKind = 0, proto: int = 0, flags: int = 0, -) -> GetAddrInfoReturnType: +) -> list[tuple[AddressFamily, SocketKind, int, str, tuple[str, int]]]: """ Look up a numeric IP address given a host name. - Internationalized domain names are translated according to the (non-transitional) IDNA 2008 - standard. + Internationalized domain names are translated according to the (non-transitional) + IDNA 2008 standard. .. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of (host, port), unlike what :func:`socket.getaddrinfo` does. @@ -500,7 +562,7 @@ async def getaddrinfo( # Handle unicode hostnames if isinstance(host, str): try: - encoded_host = host.encode("ascii") + encoded_host: bytes | None = host.encode("ascii") except UnicodeEncodeError: import idna @@ -508,7 +570,7 @@ async def getaddrinfo( else: encoded_host = host - gai_res = await get_asynclib().getaddrinfo( + gai_res = await get_async_backend().getaddrinfo( encoded_host, port, family=family, type=type, proto=proto, flags=flags ) return [ @@ -528,18 +590,18 @@ def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str .. seealso:: :func:`socket.getnameinfo` """ - return get_asynclib().getnameinfo(sockaddr, flags) + return get_async_backend().getnameinfo(sockaddr, flags) def wait_socket_readable(sock: socket.socket) -> Awaitable[None]: """ Wait until the given socket has data to be read. - This does **NOT** work on Windows when using the asyncio backend with a proactor event loop - (default on py3.8+). + This does **NOT** work on Windows when using the asyncio backend with a proactor + event loop (default on py3.8+). - .. warning:: Only use this on raw sockets that have not been wrapped by any higher level - constructs like socket streams! + .. warning:: Only use this on raw sockets that have not been wrapped by any higher + level constructs like socket streams! :param sock: a socket object :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the @@ -548,18 +610,18 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]: to become readable """ - return get_asynclib().wait_socket_readable(sock) + return get_async_backend().wait_socket_readable(sock) def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: """ Wait until the given socket can be written to. - This does **NOT** work on Windows when using the asyncio backend with a proactor event loop - (default on py3.8+). + This does **NOT** work on Windows when using the asyncio backend with a proactor + event loop (default on py3.8+). - .. warning:: Only use this on raw sockets that have not been wrapped by any higher level - constructs like socket streams! + .. warning:: Only use this on raw sockets that have not been wrapped by any higher + level constructs like socket streams! :param sock: a socket object :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the @@ -568,7 +630,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: to become writable """ - return get_asynclib().wait_socket_writable(sock) + return get_async_backend().wait_socket_writable(sock) # @@ -577,7 +639,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: def convert_ipv6_sockaddr( - sockaddr: tuple[str, int, int, int] | tuple[str, int] + sockaddr: tuple[str, int, int, int] | tuple[str, int], ) -> tuple[str, int]: """ Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format. @@ -592,7 +654,7 @@ def convert_ipv6_sockaddr( """ # This is more complicated than it should be because of MyPy if isinstance(sockaddr, tuple) and len(sockaddr) == 4: - host, port, flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr) + host, port, flowinfo, scope_id = sockaddr if scope_id: # PyPy (as of v7.3.11) leaves the interface name in the result, so # we discard it and only get the scope ID from the end @@ -604,4 +666,51 @@ def convert_ipv6_sockaddr( else: return host, port else: - return cast(Tuple[str, int], sockaddr) + return sockaddr + + +async def setup_unix_local_socket( + path: None | str | bytes | PathLike[Any], + mode: int | None, + socktype: int, +) -> socket.socket: + """ + Create a UNIX local socket object, deleting the socket at the given path if it + exists. + + Not available on Windows. + + :param path: path of the socket + :param mode: permissions to set on the socket + :param socktype: socket.SOCK_STREAM or socket.SOCK_DGRAM + + """ + path_str: str | bytes | None + if path is not None: + path_str = os.fspath(path) + + # Copied from pathlib... + try: + stat_result = os.stat(path) + except OSError as e: + if e.errno not in (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP): + raise + else: + if stat.S_ISSOCK(stat_result.st_mode): + os.unlink(path) + else: + path_str = None + + raw_socket = socket.socket(socket.AF_UNIX, socktype) + raw_socket.setblocking(False) + + if path_str is not None: + try: + await to_thread.run_sync(raw_socket.bind, path_str, abandon_on_cancel=True) + if mode is not None: + await to_thread.run_sync(chmod, path_str, mode, abandon_on_cancel=True) + except BaseException: + raw_socket.close() + raise + + return raw_socket diff --git a/contrib/python/anyio/anyio/_core/_streams.py b/contrib/python/anyio/anyio/_core/_streams.py index 54ea2b2baf..aa6b0c222a 100644 --- a/contrib/python/anyio/anyio/_core/_streams.py +++ b/contrib/python/anyio/anyio/_core/_streams.py @@ -1,7 +1,8 @@ from __future__ import annotations import math -from typing import Any, TypeVar, overload +from typing import Tuple, TypeVar +from warnings import warn from ..streams.memory import ( MemoryObjectReceiveStream, @@ -12,36 +13,40 @@ from ..streams.memory import ( T_Item = TypeVar("T_Item") -@overload -def create_memory_object_stream( - max_buffer_size: float = ..., -) -> tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: - ... - - -@overload -def create_memory_object_stream( - max_buffer_size: float = ..., item_type: type[T_Item] = ... -) -> tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]]: - ... - - -def create_memory_object_stream( - max_buffer_size: float = 0, item_type: type[T_Item] | None = None -) -> tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: +class create_memory_object_stream( + Tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]], +): """ Create a memory object stream. - :param max_buffer_size: number of items held in the buffer until ``send()`` starts blocking - :param item_type: type of item, for marking the streams with the right generic type for - static typing (not used at run time) + The stream's item type can be annotated like + :func:`create_memory_object_stream[T_Item]`. + + :param max_buffer_size: number of items held in the buffer until ``send()`` starts + blocking + :param item_type: old way of marking the streams with the right generic type for + static typing (does nothing on AnyIO 4) + + .. deprecated:: 4.0 + Use ``create_memory_object_stream[YourItemType](...)`` instead. :return: a tuple of (send stream, receive stream) """ - if max_buffer_size != math.inf and not isinstance(max_buffer_size, int): - raise ValueError("max_buffer_size must be either an integer or math.inf") - if max_buffer_size < 0: - raise ValueError("max_buffer_size cannot be negative") - state: MemoryObjectStreamState = MemoryObjectStreamState(max_buffer_size) - return MemoryObjectSendStream(state), MemoryObjectReceiveStream(state) + def __new__( # type: ignore[misc] + cls, max_buffer_size: float = 0, item_type: object = None + ) -> tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]]: + if max_buffer_size != math.inf and not isinstance(max_buffer_size, int): + raise ValueError("max_buffer_size must be either an integer or math.inf") + if max_buffer_size < 0: + raise ValueError("max_buffer_size cannot be negative") + if item_type is not None: + warn( + "The item_type argument has been deprecated in AnyIO 4.0. " + "Use create_memory_object_stream[YourItemType](...) instead.", + DeprecationWarning, + stacklevel=2, + ) + + state = MemoryObjectStreamState[T_Item](max_buffer_size) + return (MemoryObjectSendStream(state), MemoryObjectReceiveStream(state)) diff --git a/contrib/python/anyio/anyio/_core/_subprocesses.py b/contrib/python/anyio/anyio/_core/_subprocesses.py index 1a26ac8c7f..5d5d7b768a 100644 --- a/contrib/python/anyio/anyio/_core/_subprocesses.py +++ b/contrib/python/anyio/anyio/_core/_subprocesses.py @@ -1,19 +1,13 @@ from __future__ import annotations +from collections.abc import AsyncIterable, Mapping, Sequence from io import BytesIO from os import PathLike from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess -from typing import ( - IO, - Any, - AsyncIterable, - Mapping, - Sequence, - cast, -) +from typing import IO, Any, cast from ..abc import Process -from ._eventloop import get_asynclib +from ._eventloop import get_async_backend from ._tasks import create_task_group @@ -33,22 +27,24 @@ async def run_process( .. seealso:: :func:`subprocess.run` - :param command: either a string to pass to the shell, or an iterable of strings containing the - executable name or path and its arguments + :param command: either a string to pass to the shell, or an iterable of strings + containing the executable name or path and its arguments :param input: bytes passed to the standard input of the subprocess - :param stdout: either :data:`subprocess.PIPE` or :data:`subprocess.DEVNULL` - :param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL` or - :data:`subprocess.STDOUT` - :param check: if ``True``, raise :exc:`~subprocess.CalledProcessError` if the process - terminates with a return code other than 0 - :param cwd: If not ``None``, change the working directory to this before running the command - :param env: if not ``None``, this mapping replaces the inherited environment variables from the - parent process - :param start_new_session: if ``true`` the setsid() system call will be made in the child - process prior to the execution of the subprocess. (POSIX only) + :param stdout: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, + a file-like object, or `None` + :param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, + :data:`subprocess.STDOUT`, a file-like object, or `None` + :param check: if ``True``, raise :exc:`~subprocess.CalledProcessError` if the + process terminates with a return code other than 0 + :param cwd: If not ``None``, change the working directory to this before running the + command + :param env: if not ``None``, this mapping replaces the inherited environment + variables from the parent process + :param start_new_session: if ``true`` the setsid() system call will be made in the + child process prior to the execution of the subprocess. (POSIX only) :return: an object representing the completed process - :raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process exits with a - nonzero return code + :raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process + exits with a nonzero return code """ @@ -69,20 +65,18 @@ async def run_process( start_new_session=start_new_session, ) as process: stream_contents: list[bytes | None] = [None, None] - try: - async with create_task_group() as tg: - if process.stdout: - tg.start_soon(drain_stream, process.stdout, 0) - if process.stderr: - tg.start_soon(drain_stream, process.stderr, 1) - if process.stdin and input: - await process.stdin.send(input) - await process.stdin.aclose() - - await process.wait() - except BaseException: - process.kill() - raise + async with create_task_group() as tg: + if process.stdout: + tg.start_soon(drain_stream, process.stdout, 0) + + if process.stderr: + tg.start_soon(drain_stream, process.stderr, 1) + + if process.stdin and input: + await process.stdin.send(input) + await process.stdin.aclose() + + await process.wait() output, errors = stream_contents if check and process.returncode != 0: @@ -106,8 +100,8 @@ async def open_process( .. seealso:: :class:`subprocess.Popen` - :param command: either a string to pass to the shell, or an iterable of strings containing the - executable name or path and its arguments + :param command: either a string to pass to the shell, or an iterable of strings + containing the executable name or path and its arguments :param stdin: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, a file-like object, or ``None`` :param stdout: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, @@ -115,21 +109,32 @@ async def open_process( :param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, :data:`subprocess.STDOUT`, a file-like object, or ``None`` :param cwd: If not ``None``, the working directory is changed before executing - :param env: If env is not ``None``, it must be a mapping that defines the environment - variables for the new process - :param start_new_session: if ``true`` the setsid() system call will be made in the child - process prior to the execution of the subprocess. (POSIX only) + :param env: If env is not ``None``, it must be a mapping that defines the + environment variables for the new process + :param start_new_session: if ``true`` the setsid() system call will be made in the + child process prior to the execution of the subprocess. (POSIX only) :return: an asynchronous process object """ - shell = isinstance(command, str) - return await get_asynclib().open_process( - command, - shell=shell, - stdin=stdin, - stdout=stdout, - stderr=stderr, - cwd=cwd, - env=env, - start_new_session=start_new_session, - ) + if isinstance(command, (str, bytes)): + return await get_async_backend().open_process( + command, + shell=True, + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + else: + return await get_async_backend().open_process( + command, + shell=False, + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) diff --git a/contrib/python/anyio/anyio/_core/_synchronization.py b/contrib/python/anyio/anyio/_core/_synchronization.py index 783570c7ac..b274a31ea2 100644 --- a/contrib/python/anyio/anyio/_core/_synchronization.py +++ b/contrib/python/anyio/anyio/_core/_synchronization.py @@ -1,13 +1,14 @@ from __future__ import annotations +import math from collections import deque from dataclasses import dataclass from types import TracebackType -from warnings import warn + +from sniffio import AsyncLibraryNotFoundError from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled -from ._compat import DeprecatedAwaitable -from ._eventloop import get_asynclib +from ._eventloop import get_async_backend from ._exceptions import BusyResourceError, WouldBlock from ._tasks import CancelScope from ._testing import TaskInfo, get_current_task @@ -27,9 +28,10 @@ class CapacityLimiterStatistics: """ :ivar int borrowed_tokens: number of tokens currently borrowed by tasks :ivar float total_tokens: total number of available tokens - :ivar tuple borrowers: tasks or other objects currently holding tokens borrowed from this - limiter - :ivar int tasks_waiting: number of tasks waiting on :meth:`~.CapacityLimiter.acquire` or + :ivar tuple borrowers: tasks or other objects currently holding tokens borrowed from + this limiter + :ivar int tasks_waiting: number of tasks waiting on + :meth:`~.CapacityLimiter.acquire` or :meth:`~.CapacityLimiter.acquire_on_behalf_of` """ @@ -43,8 +45,8 @@ class CapacityLimiterStatistics: class LockStatistics: """ :ivar bool locked: flag indicating if this lock is locked or not - :ivar ~anyio.TaskInfo owner: task currently holding the lock (or ``None`` if the lock is not - held by any task) + :ivar ~anyio.TaskInfo owner: task currently holding the lock (or ``None`` if the + lock is not held by any task) :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Lock.acquire` """ @@ -57,7 +59,8 @@ class LockStatistics: class ConditionStatistics: """ :ivar int tasks_waiting: number of tasks blocked on :meth:`~.Condition.wait` - :ivar ~anyio.LockStatistics lock_statistics: statistics of the underlying :class:`~.Lock` + :ivar ~anyio.LockStatistics lock_statistics: statistics of the underlying + :class:`~.Lock` """ tasks_waiting: int @@ -76,9 +79,12 @@ class SemaphoreStatistics: class Event: def __new__(cls) -> Event: - return get_asynclib().Event() + try: + return get_async_backend().create_event() + except AsyncLibraryNotFoundError: + return EventAdapter() - def set(self) -> DeprecatedAwaitable: + def set(self) -> None: """Set the flag, notifying all listeners.""" raise NotImplementedError @@ -90,7 +96,8 @@ class Event: """ Wait until the flag has been set. - If the flag has already been set when this method is called, it returns immediately. + If the flag has already been set when this method is called, it returns + immediately. """ raise NotImplementedError @@ -100,6 +107,35 @@ class Event: raise NotImplementedError +class EventAdapter(Event): + _internal_event: Event | None = None + + def __new__(cls) -> EventAdapter: + return object.__new__(cls) + + @property + def _event(self) -> Event: + if self._internal_event is None: + self._internal_event = get_async_backend().create_event() + + return self._internal_event + + def set(self) -> None: + self._event.set() + + def is_set(self) -> bool: + return self._internal_event is not None and self._internal_event.is_set() + + async def wait(self) -> None: + await self._event.wait() + + def statistics(self) -> EventStatistics: + if self._internal_event is None: + return EventStatistics(tasks_waiting=0) + + return self._internal_event.statistics() + + class Lock: _owner_task: TaskInfo | None = None @@ -161,7 +197,7 @@ class Lock: self._owner_task = task - def release(self) -> DeprecatedAwaitable: + def release(self) -> None: """Release the lock.""" if self._owner_task != get_current_task(): raise RuntimeError("The current task is not holding this lock") @@ -172,8 +208,6 @@ class Lock: else: del self._owner_task - return DeprecatedAwaitable(self.release) - def locked(self) -> bool: """Return True if the lock is currently held.""" return self._owner_task is not None @@ -224,10 +258,9 @@ class Condition: self._lock.acquire_nowait() self._owner_task = get_current_task() - def release(self) -> DeprecatedAwaitable: + def release(self) -> None: """Release the underlying lock.""" self._lock.release() - return DeprecatedAwaitable(self.release) def locked(self) -> bool: """Return True if the lock is set.""" @@ -344,7 +377,7 @@ class Semaphore: self._value -= 1 - def release(self) -> DeprecatedAwaitable: + def release(self) -> None: """Increment the semaphore value.""" if self._max_value is not None and self._value == self._max_value: raise ValueError("semaphore released too many times") @@ -354,8 +387,6 @@ class Semaphore: else: self._value += 1 - return DeprecatedAwaitable(self.release) - @property def value(self) -> int: """The current value of the semaphore.""" @@ -377,7 +408,10 @@ class Semaphore: class CapacityLimiter: def __new__(cls, total_tokens: float) -> CapacityLimiter: - return get_asynclib().CapacityLimiter(total_tokens) + try: + return get_async_backend().create_capacity_limiter(total_tokens) + except AsyncLibraryNotFoundError: + return CapacityLimiterAdapter(total_tokens) async def __aenter__(self) -> None: raise NotImplementedError @@ -396,7 +430,8 @@ class CapacityLimiter: The total number of tokens available for borrowing. This is a read-write property. If the total number of tokens is increased, the - proportionate number of tasks waiting on this limiter will be granted their tokens. + proportionate number of tasks waiting on this limiter will be granted their + tokens. .. versionchanged:: 3.0 The property is now writable. @@ -408,14 +443,6 @@ class CapacityLimiter: def total_tokens(self, value: float) -> None: raise NotImplementedError - async def set_total_tokens(self, value: float) -> None: - warn( - "CapacityLimiter.set_total_tokens has been deprecated. Set the value of the" - '"total_tokens" attribute directly.', - DeprecationWarning, - ) - self.total_tokens = value - @property def borrowed_tokens(self) -> int: """The number of tokens that have currently been borrowed.""" @@ -426,16 +453,17 @@ class CapacityLimiter: """The number of tokens currently available to be borrowed""" raise NotImplementedError - def acquire_nowait(self) -> DeprecatedAwaitable: + def acquire_nowait(self) -> None: """ - Acquire a token for the current task without waiting for one to become available. + Acquire a token for the current task without waiting for one to become + available. :raises ~anyio.WouldBlock: if there are no tokens available for borrowing """ raise NotImplementedError - def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: + def acquire_on_behalf_of_nowait(self, borrower: object) -> None: """ Acquire a token without waiting for one to become available. @@ -447,7 +475,8 @@ class CapacityLimiter: async def acquire(self) -> None: """ - Acquire a token for the current task, waiting if necessary for one to become available. + Acquire a token for the current task, waiting if necessary for one to become + available. """ raise NotImplementedError @@ -464,7 +493,9 @@ class CapacityLimiter: def release(self) -> None: """ Release the token held by the current task. - :raises RuntimeError: if the current task has not borrowed a token from this limiter. + + :raises RuntimeError: if the current task has not borrowed a token from this + limiter. """ raise NotImplementedError @@ -473,7 +504,8 @@ class CapacityLimiter: """ Release the token held by the given borrower. - :raises RuntimeError: if the borrower has not borrowed a token from this limiter. + :raises RuntimeError: if the borrower has not borrowed a token from this + limiter. """ raise NotImplementedError @@ -488,96 +520,117 @@ class CapacityLimiter: raise NotImplementedError -def create_lock() -> Lock: - """ - Create an asynchronous lock. +class CapacityLimiterAdapter(CapacityLimiter): + _internal_limiter: CapacityLimiter | None = None - :return: a lock object + def __new__(cls, total_tokens: float) -> CapacityLimiterAdapter: + return object.__new__(cls) - .. deprecated:: 3.0 - Use :class:`~Lock` directly. + def __init__(self, total_tokens: float) -> None: + self.total_tokens = total_tokens - """ - warn("create_lock() is deprecated -- use Lock() directly", DeprecationWarning) - return Lock() + @property + def _limiter(self) -> CapacityLimiter: + if self._internal_limiter is None: + self._internal_limiter = get_async_backend().create_capacity_limiter( + self._total_tokens + ) + return self._internal_limiter -def create_condition(lock: Lock | None = None) -> Condition: - """ - Create an asynchronous condition. + async def __aenter__(self) -> None: + await self._limiter.__aenter__() - :param lock: the lock to base the condition object on - :return: a condition object + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return await self._limiter.__aexit__(exc_type, exc_val, exc_tb) - .. deprecated:: 3.0 - Use :class:`~Condition` directly. + @property + def total_tokens(self) -> float: + if self._internal_limiter is None: + return self._total_tokens - """ - warn( - "create_condition() is deprecated -- use Condition() directly", - DeprecationWarning, - ) - return Condition(lock=lock) + return self._internal_limiter.total_tokens + @total_tokens.setter + def total_tokens(self, value: float) -> None: + if not isinstance(value, int) and value is not math.inf: + raise TypeError("total_tokens must be an int or math.inf") + elif value < 1: + raise ValueError("total_tokens must be >= 1") -def create_event() -> Event: - """ - Create an asynchronous event object. + if self._internal_limiter is None: + self._total_tokens = value + return - :return: an event object + self._limiter.total_tokens = value - .. deprecated:: 3.0 - Use :class:`~Event` directly. + @property + def borrowed_tokens(self) -> int: + if self._internal_limiter is None: + return 0 - """ - warn("create_event() is deprecated -- use Event() directly", DeprecationWarning) - return get_asynclib().Event() + return self._internal_limiter.borrowed_tokens + @property + def available_tokens(self) -> float: + if self._internal_limiter is None: + return self._total_tokens -def create_semaphore(value: int, *, max_value: int | None = None) -> Semaphore: - """ - Create an asynchronous semaphore. + return self._internal_limiter.available_tokens - :param value: the semaphore's initial value - :param max_value: if set, makes this a "bounded" semaphore that raises :exc:`ValueError` if the - semaphore's value would exceed this number - :return: a semaphore object + def acquire_nowait(self) -> None: + self._limiter.acquire_nowait() - .. deprecated:: 3.0 - Use :class:`~Semaphore` directly. + def acquire_on_behalf_of_nowait(self, borrower: object) -> None: + self._limiter.acquire_on_behalf_of_nowait(borrower) - """ - warn( - "create_semaphore() is deprecated -- use Semaphore() directly", - DeprecationWarning, - ) - return Semaphore(value, max_value=max_value) + async def acquire(self) -> None: + await self._limiter.acquire() + async def acquire_on_behalf_of(self, borrower: object) -> None: + await self._limiter.acquire_on_behalf_of(borrower) -def create_capacity_limiter(total_tokens: float) -> CapacityLimiter: - """ - Create a capacity limiter. + def release(self) -> None: + self._limiter.release() - :param total_tokens: the total number of tokens available for borrowing (can be an integer or - :data:`math.inf`) - :return: a capacity limiter object + def release_on_behalf_of(self, borrower: object) -> None: + self._limiter.release_on_behalf_of(borrower) - .. deprecated:: 3.0 - Use :class:`~CapacityLimiter` directly. + def statistics(self) -> CapacityLimiterStatistics: + if self._internal_limiter is None: + return CapacityLimiterStatistics( + borrowed_tokens=0, + total_tokens=self.total_tokens, + borrowers=(), + tasks_waiting=0, + ) - """ - warn( - "create_capacity_limiter() is deprecated -- use CapacityLimiter() directly", - DeprecationWarning, - ) - return get_asynclib().CapacityLimiter(total_tokens) + return self._internal_limiter.statistics() class ResourceGuard: + """ + A context manager for ensuring that a resource is only used by a single task at a + time. + + Entering this context manager while the previous has not exited it yet will trigger + :exc:`BusyResourceError`. + + :param action: the action to guard against (visible in the :exc:`BusyResourceError` + when triggered, e.g. "Another task is already {action} this resource") + + .. versionadded:: 4.1 + """ + __slots__ = "action", "_guarded" - def __init__(self, action: str): - self.action = action + def __init__(self, action: str = "using"): + self.action: str = action self._guarded = False def __enter__(self) -> None: diff --git a/contrib/python/anyio/anyio/_core/_tasks.py b/contrib/python/anyio/anyio/_core/_tasks.py index e9d9c2bd67..2f21ea20b1 100644 --- a/contrib/python/anyio/anyio/_core/_tasks.py +++ b/contrib/python/anyio/anyio/_core/_tasks.py @@ -1,16 +1,12 @@ from __future__ import annotations import math +from collections.abc import Generator +from contextlib import contextmanager from types import TracebackType -from warnings import warn from ..abc._tasks import TaskGroup, TaskStatus -from ._compat import ( - DeprecatedAsyncContextManager, - DeprecatedAwaitable, - DeprecatedAwaitableFloat, -) -from ._eventloop import get_asynclib +from ._eventloop import get_async_backend class _IgnoredTaskStatus(TaskStatus[object]): @@ -21,7 +17,7 @@ class _IgnoredTaskStatus(TaskStatus[object]): TASK_STATUS_IGNORED = _IgnoredTaskStatus() -class CancelScope(DeprecatedAsyncContextManager["CancelScope"]): +class CancelScope: """ Wraps a unit of work that can be made separately cancellable. @@ -32,9 +28,9 @@ class CancelScope(DeprecatedAsyncContextManager["CancelScope"]): def __new__( cls, *, deadline: float = math.inf, shield: bool = False ) -> CancelScope: - return get_asynclib().CancelScope(shield=shield, deadline=deadline) + return get_async_backend().create_cancel_scope(shield=shield, deadline=deadline) - def cancel(self) -> DeprecatedAwaitable: + def cancel(self) -> None: """Cancel this scope immediately.""" raise NotImplementedError @@ -58,6 +54,19 @@ class CancelScope(DeprecatedAsyncContextManager["CancelScope"]): raise NotImplementedError @property + def cancelled_caught(self) -> bool: + """ + ``True`` if this scope suppressed a cancellation exception it itself raised. + + This is typically used to check if any work was interrupted, or to see if the + scope was cancelled due to its deadline being reached. The value will, however, + only be ``True`` if the cancellation was triggered by the scope itself (and not + an outer scope). + + """ + raise NotImplementedError + + @property def shield(self) -> bool: """ ``True`` if this scope is shielded from external cancellation. @@ -83,81 +92,52 @@ class CancelScope(DeprecatedAsyncContextManager["CancelScope"]): raise NotImplementedError -def open_cancel_scope(*, shield: bool = False) -> CancelScope: +@contextmanager +def fail_after( + delay: float | None, shield: bool = False +) -> Generator[CancelScope, None, None]: """ - Open a cancel scope. + Create a context manager which raises a :class:`TimeoutError` if does not finish in + time. - :param shield: ``True`` to shield the cancel scope from external cancellation - :return: a cancel scope - - .. deprecated:: 3.0 - Use :class:`~CancelScope` directly. - - """ - warn( - "open_cancel_scope() is deprecated -- use CancelScope() directly", - DeprecationWarning, - ) - return get_asynclib().CancelScope(shield=shield) - - -class FailAfterContextManager(DeprecatedAsyncContextManager[CancelScope]): - def __init__(self, cancel_scope: CancelScope): - self._cancel_scope = cancel_scope - - def __enter__(self) -> CancelScope: - return self._cancel_scope.__enter__() - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - retval = self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) - if self._cancel_scope.cancel_called: - raise TimeoutError - - return retval - - -def fail_after(delay: float | None, shield: bool = False) -> FailAfterContextManager: - """ - Create a context manager which raises a :class:`TimeoutError` if does not finish in time. - - :param delay: maximum allowed time (in seconds) before raising the exception, or ``None`` to - disable the timeout + :param delay: maximum allowed time (in seconds) before raising the exception, or + ``None`` to disable the timeout :param shield: ``True`` to shield the cancel scope from external cancellation :return: a context manager that yields a cancel scope :rtype: :class:`~typing.ContextManager`\\[:class:`~anyio.CancelScope`\\] """ - deadline = ( - (get_asynclib().current_time() + delay) if delay is not None else math.inf - ) - cancel_scope = get_asynclib().CancelScope(deadline=deadline, shield=shield) - return FailAfterContextManager(cancel_scope) + current_time = get_async_backend().current_time + deadline = (current_time() + delay) if delay is not None else math.inf + with get_async_backend().create_cancel_scope( + deadline=deadline, shield=shield + ) as cancel_scope: + yield cancel_scope + + if cancel_scope.cancelled_caught and current_time() >= cancel_scope.deadline: + raise TimeoutError def move_on_after(delay: float | None, shield: bool = False) -> CancelScope: """ Create a cancel scope with a deadline that expires after the given delay. - :param delay: maximum allowed time (in seconds) before exiting the context block, or ``None`` - to disable the timeout + :param delay: maximum allowed time (in seconds) before exiting the context block, or + ``None`` to disable the timeout :param shield: ``True`` to shield the cancel scope from external cancellation :return: a cancel scope """ deadline = ( - (get_asynclib().current_time() + delay) if delay is not None else math.inf + (get_async_backend().current_time() + delay) if delay is not None else math.inf ) - return get_asynclib().CancelScope(deadline=deadline, shield=shield) + return get_async_backend().create_cancel_scope(deadline=deadline, shield=shield) -def current_effective_deadline() -> DeprecatedAwaitableFloat: +def current_effective_deadline() -> float: """ - Return the nearest deadline among all the cancel scopes effective for the current task. + Return the nearest deadline among all the cancel scopes effective for the current + task. :return: a clock value from the event loop's internal clock (or ``float('inf')`` if there is no deadline in effect, or ``float('-inf')`` if the current scope has @@ -165,9 +145,7 @@ def current_effective_deadline() -> DeprecatedAwaitableFloat: :rtype: float """ - return DeprecatedAwaitableFloat( - get_asynclib().current_effective_deadline(), current_effective_deadline - ) + return get_async_backend().current_effective_deadline() def create_task_group() -> TaskGroup: @@ -177,4 +155,4 @@ def create_task_group() -> TaskGroup: :return: a task group """ - return get_asynclib().TaskGroup() + return get_async_backend().create_task_group() diff --git a/contrib/python/anyio/anyio/_core/_testing.py b/contrib/python/anyio/anyio/_core/_testing.py index c8191b3866..1dae3b193a 100644 --- a/contrib/python/anyio/anyio/_core/_testing.py +++ b/contrib/python/anyio/anyio/_core/_testing.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Any, Awaitable, Generator +from collections.abc import Awaitable, Generator +from typing import Any -from ._compat import DeprecatedAwaitableList, _warn_deprecation -from ._eventloop import get_asynclib +from ._eventloop import get_async_backend class TaskInfo: @@ -45,13 +45,6 @@ class TaskInfo: def __repr__(self) -> str: return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})" - def __await__(self) -> Generator[None, None, TaskInfo]: - _warn_deprecation(self) - if False: - yield - - return self - def _unwrap(self) -> TaskInfo: return self @@ -63,20 +56,19 @@ def get_current_task() -> TaskInfo: :return: a representation of the current task """ - return get_asynclib().get_current_task() + return get_async_backend().get_current_task() -def get_running_tasks() -> DeprecatedAwaitableList[TaskInfo]: +def get_running_tasks() -> list[TaskInfo]: """ Return a list of running tasks in the current event loop. :return: a list of task info objects """ - tasks = get_asynclib().get_running_tasks() - return DeprecatedAwaitableList(tasks, func=get_running_tasks) + return get_async_backend().get_running_tasks() async def wait_all_tasks_blocked() -> None: """Wait until all other tasks are waiting for something.""" - await get_asynclib().wait_all_tasks_blocked() + await get_async_backend().wait_all_tasks_blocked() diff --git a/contrib/python/anyio/anyio/_core/_typedattr.py b/contrib/python/anyio/anyio/_core/_typedattr.py index bf9202eeab..74c6b8fdcb 100644 --- a/contrib/python/anyio/anyio/_core/_typedattr.py +++ b/contrib/python/anyio/anyio/_core/_typedattr.py @@ -1,15 +1,10 @@ from __future__ import annotations -import sys -from typing import Any, Callable, Mapping, TypeVar, overload +from collections.abc import Callable, Mapping +from typing import Any, TypeVar, final, overload from ._exceptions import TypedAttributeLookupError -if sys.version_info >= (3, 8): - from typing import final -else: - from typing_extensions import final - T_Attr = TypeVar("T_Attr") T_Default = TypeVar("T_Default") undefined = object() @@ -44,11 +39,12 @@ class TypedAttributeProvider: @property def extra_attributes(self) -> Mapping[T_Attr, Callable[[], T_Attr]]: """ - A mapping of the extra attributes to callables that return the corresponding values. + A mapping of the extra attributes to callables that return the corresponding + values. - If the provider wraps another provider, the attributes from that wrapper should also be - included in the returned mapping (but the wrapper may override the callables from the - wrapped instance). + If the provider wraps another provider, the attributes from that wrapper should + also be included in the returned mapping (but the wrapper may override the + callables from the wrapped instance). """ return {} @@ -68,10 +64,12 @@ class TypedAttributeProvider: Return the value of the given typed extra attribute. - :param attribute: the attribute (member of a :class:`~TypedAttributeSet`) to look for - :param default: the value that should be returned if no value is found for the attribute - :raises ~anyio.TypedAttributeLookupError: if the search failed and no default value was - given + :param attribute: the attribute (member of a :class:`~TypedAttributeSet`) to + look for + :param default: the value that should be returned if no value is found for the + attribute + :raises ~anyio.TypedAttributeLookupError: if the search failed and no default + value was given """ try: diff --git a/contrib/python/anyio/anyio/abc/__init__.py b/contrib/python/anyio/anyio/abc/__init__.py index 72c34e544e..1ca0fcf746 100644 --- a/contrib/python/anyio/anyio/abc/__init__.py +++ b/contrib/python/anyio/anyio/abc/__init__.py @@ -1,86 +1,53 @@ from __future__ import annotations -__all__ = ( - "AsyncResource", - "IPAddressType", - "IPSockAddrType", - "SocketAttribute", - "SocketStream", - "SocketListener", - "UDPSocket", - "UNIXSocketStream", - "UDPPacketType", - "ConnectedUDPSocket", - "UnreliableObjectReceiveStream", - "UnreliableObjectSendStream", - "UnreliableObjectStream", - "ObjectReceiveStream", - "ObjectSendStream", - "ObjectStream", - "ByteReceiveStream", - "ByteSendStream", - "ByteStream", - "AnyUnreliableByteReceiveStream", - "AnyUnreliableByteSendStream", - "AnyUnreliableByteStream", - "AnyByteReceiveStream", - "AnyByteSendStream", - "AnyByteStream", - "Listener", - "Process", - "Event", - "Condition", - "Lock", - "Semaphore", - "CapacityLimiter", - "CancelScope", - "TaskGroup", - "TaskStatus", - "TestRunner", - "BlockingPortal", -) - from typing import Any -from ._resources import AsyncResource -from ._sockets import ( - ConnectedUDPSocket, - IPAddressType, - IPSockAddrType, - SocketAttribute, - SocketListener, - SocketStream, - UDPPacketType, - UDPSocket, - UNIXSocketStream, -) -from ._streams import ( - AnyByteReceiveStream, - AnyByteSendStream, - AnyByteStream, - AnyUnreliableByteReceiveStream, - AnyUnreliableByteSendStream, - AnyUnreliableByteStream, - ByteReceiveStream, - ByteSendStream, - ByteStream, - Listener, - ObjectReceiveStream, - ObjectSendStream, - ObjectStream, - UnreliableObjectReceiveStream, - UnreliableObjectSendStream, - UnreliableObjectStream, -) -from ._subprocesses import Process -from ._tasks import TaskGroup, TaskStatus -from ._testing import TestRunner +from ._eventloop import AsyncBackend as AsyncBackend +from ._resources import AsyncResource as AsyncResource +from ._sockets import ConnectedUDPSocket as ConnectedUDPSocket +from ._sockets import ConnectedUNIXDatagramSocket as ConnectedUNIXDatagramSocket +from ._sockets import IPAddressType as IPAddressType +from ._sockets import IPSockAddrType as IPSockAddrType +from ._sockets import SocketAttribute as SocketAttribute +from ._sockets import SocketListener as SocketListener +from ._sockets import SocketStream as SocketStream +from ._sockets import UDPPacketType as UDPPacketType +from ._sockets import UDPSocket as UDPSocket +from ._sockets import UNIXDatagramPacketType as UNIXDatagramPacketType +from ._sockets import UNIXDatagramSocket as UNIXDatagramSocket +from ._sockets import UNIXSocketStream as UNIXSocketStream +from ._streams import AnyByteReceiveStream as AnyByteReceiveStream +from ._streams import AnyByteSendStream as AnyByteSendStream +from ._streams import AnyByteStream as AnyByteStream +from ._streams import AnyUnreliableByteReceiveStream as AnyUnreliableByteReceiveStream +from ._streams import AnyUnreliableByteSendStream as AnyUnreliableByteSendStream +from ._streams import AnyUnreliableByteStream as AnyUnreliableByteStream +from ._streams import ByteReceiveStream as ByteReceiveStream +from ._streams import ByteSendStream as ByteSendStream +from ._streams import ByteStream as ByteStream +from ._streams import Listener as Listener +from ._streams import ObjectReceiveStream as ObjectReceiveStream +from ._streams import ObjectSendStream as ObjectSendStream +from ._streams import ObjectStream as ObjectStream +from ._streams import UnreliableObjectReceiveStream as UnreliableObjectReceiveStream +from ._streams import UnreliableObjectSendStream as UnreliableObjectSendStream +from ._streams import UnreliableObjectStream as UnreliableObjectStream +from ._subprocesses import Process as Process +from ._tasks import TaskGroup as TaskGroup +from ._tasks import TaskStatus as TaskStatus +from ._testing import TestRunner as TestRunner # Re-exported here, for backwards compatibility # isort: off -from .._core._synchronization import CapacityLimiter, Condition, Event, Lock, Semaphore -from .._core._tasks import CancelScope -from ..from_thread import BlockingPortal +from .._core._synchronization import ( + CapacityLimiter as CapacityLimiter, + Condition as Condition, + Event as Event, + Lock as Lock, + Semaphore as Semaphore, +) +from .._core._tasks import CancelScope as CancelScope +from ..from_thread import BlockingPortal as BlockingPortal # Re-export imports so they look like they live directly in this package key: str diff --git a/contrib/python/anyio/anyio/abc/_eventloop.py b/contrib/python/anyio/anyio/abc/_eventloop.py new file mode 100644 index 0000000000..4470d83d24 --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_eventloop.py @@ -0,0 +1,392 @@ +from __future__ import annotations + +import math +import sys +from abc import ABCMeta, abstractmethod +from collections.abc import AsyncIterator, Awaitable, Mapping +from os import PathLike +from signal import Signals +from socket import AddressFamily, SocketKind, socket +from typing import ( + IO, + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Sequence, + TypeVar, + overload, +) + +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + +if TYPE_CHECKING: + from typing import Literal + + from .._core._synchronization import CapacityLimiter, Event + from .._core._tasks import CancelScope + from .._core._testing import TaskInfo + from ..from_thread import BlockingPortal + from ._sockets import ( + ConnectedUDPSocket, + ConnectedUNIXDatagramSocket, + IPSockAddrType, + SocketListener, + SocketStream, + UDPSocket, + UNIXDatagramSocket, + UNIXSocketStream, + ) + from ._subprocesses import Process + from ._tasks import TaskGroup + from ._testing import TestRunner + +T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + + +class AsyncBackend(metaclass=ABCMeta): + @classmethod + @abstractmethod + def run( + cls, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], + kwargs: dict[str, Any], + options: dict[str, Any], + ) -> T_Retval: + """ + Run the given coroutine function in an asynchronous event loop. + + The current thread must not be already running an event loop. + + :param func: a coroutine function + :param args: positional arguments to ``func`` + :param kwargs: positional arguments to ``func`` + :param options: keyword arguments to call the backend ``run()`` implementation + with + :return: the return value of the coroutine function + """ + + @classmethod + @abstractmethod + def current_token(cls) -> object: + """ + + :return: + """ + + @classmethod + @abstractmethod + def current_time(cls) -> float: + """ + Return the current value of the event loop's internal clock. + + :return: the clock value (seconds) + """ + + @classmethod + @abstractmethod + def cancelled_exception_class(cls) -> type[BaseException]: + """Return the exception class that is raised in a task if it's cancelled.""" + + @classmethod + @abstractmethod + async def checkpoint(cls) -> None: + """ + Check if the task has been cancelled, and allow rescheduling of other tasks. + + This is effectively the same as running :meth:`checkpoint_if_cancelled` and then + :meth:`cancel_shielded_checkpoint`. + """ + + @classmethod + async def checkpoint_if_cancelled(cls) -> None: + """ + Check if the current task group has been cancelled. + + This will check if the task has been cancelled, but will not allow other tasks + to be scheduled if not. + + """ + if cls.current_effective_deadline() == -math.inf: + await cls.checkpoint() + + @classmethod + async def cancel_shielded_checkpoint(cls) -> None: + """ + Allow the rescheduling of other tasks. + + This will give other tasks the opportunity to run, but without checking if the + current task group has been cancelled, unlike with :meth:`checkpoint`. + + """ + with cls.create_cancel_scope(shield=True): + await cls.sleep(0) + + @classmethod + @abstractmethod + async def sleep(cls, delay: float) -> None: + """ + Pause the current task for the specified duration. + + :param delay: the duration, in seconds + """ + + @classmethod + @abstractmethod + def create_cancel_scope( + cls, *, deadline: float = math.inf, shield: bool = False + ) -> CancelScope: + pass + + @classmethod + @abstractmethod + def current_effective_deadline(cls) -> float: + """ + Return the nearest deadline among all the cancel scopes effective for the + current task. + + :return: + - a clock value from the event loop's internal clock + - ``inf`` if there is no deadline in effect + - ``-inf`` if the current scope has been cancelled + :rtype: float + """ + + @classmethod + @abstractmethod + def create_task_group(cls) -> TaskGroup: + pass + + @classmethod + @abstractmethod + def create_event(cls) -> Event: + pass + + @classmethod + @abstractmethod + def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter: + pass + + @classmethod + @abstractmethod + async def run_sync_in_worker_thread( + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + abandon_on_cancel: bool = False, + limiter: CapacityLimiter | None = None, + ) -> T_Retval: + pass + + @classmethod + @abstractmethod + def check_cancelled(cls) -> None: + pass + + @classmethod + @abstractmethod + def run_async_from_thread( + cls, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + args: tuple[Unpack[PosArgsT]], + token: object, + ) -> T_Retval: + pass + + @classmethod + @abstractmethod + def run_sync_from_thread( + cls, + func: Callable[[Unpack[PosArgsT]], T_Retval], + args: tuple[Unpack[PosArgsT]], + token: object, + ) -> T_Retval: + pass + + @classmethod + @abstractmethod + def create_blocking_portal(cls) -> BlockingPortal: + pass + + @classmethod + @overload + async def open_process( + cls, + command: str | bytes, + *, + shell: Literal[True], + stdin: int | IO[Any] | None, + stdout: int | IO[Any] | None, + stderr: int | IO[Any] | None, + cwd: str | bytes | PathLike[str] | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, + ) -> Process: + pass + + @classmethod + @overload + async def open_process( + cls, + command: Sequence[str | bytes], + *, + shell: Literal[False], + stdin: int | IO[Any] | None, + stdout: int | IO[Any] | None, + stderr: int | IO[Any] | None, + cwd: str | bytes | PathLike[str] | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, + ) -> Process: + pass + + @classmethod + @abstractmethod + async def open_process( + cls, + command: str | bytes | Sequence[str | bytes], + *, + shell: bool, + stdin: int | IO[Any] | None, + stdout: int | IO[Any] | None, + stderr: int | IO[Any] | None, + cwd: str | bytes | PathLike[str] | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, + ) -> Process: + pass + + @classmethod + @abstractmethod + def setup_process_pool_exit_at_shutdown(cls, workers: set[Process]) -> None: + pass + + @classmethod + @abstractmethod + async def connect_tcp( + cls, host: str, port: int, local_address: IPSockAddrType | None = None + ) -> SocketStream: + pass + + @classmethod + @abstractmethod + async def connect_unix(cls, path: str | bytes) -> UNIXSocketStream: + pass + + @classmethod + @abstractmethod + def create_tcp_listener(cls, sock: socket) -> SocketListener: + pass + + @classmethod + @abstractmethod + def create_unix_listener(cls, sock: socket) -> SocketListener: + pass + + @classmethod + @abstractmethod + async def create_udp_socket( + cls, + family: AddressFamily, + local_address: IPSockAddrType | None, + remote_address: IPSockAddrType | None, + reuse_port: bool, + ) -> UDPSocket | ConnectedUDPSocket: + pass + + @classmethod + @overload + async def create_unix_datagram_socket( + cls, raw_socket: socket, remote_path: None + ) -> UNIXDatagramSocket: + ... + + @classmethod + @overload + async def create_unix_datagram_socket( + cls, raw_socket: socket, remote_path: str | bytes + ) -> ConnectedUNIXDatagramSocket: + ... + + @classmethod + @abstractmethod + async def create_unix_datagram_socket( + cls, raw_socket: socket, remote_path: str | bytes | None + ) -> UNIXDatagramSocket | ConnectedUNIXDatagramSocket: + pass + + @classmethod + @abstractmethod + async def getaddrinfo( + cls, + host: bytes | str | None, + port: str | int | None, + *, + family: int | AddressFamily = 0, + type: int | SocketKind = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: + pass + + @classmethod + @abstractmethod + async def getnameinfo( + cls, sockaddr: IPSockAddrType, flags: int = 0 + ) -> tuple[str, str]: + pass + + @classmethod + @abstractmethod + async def wait_socket_readable(cls, sock: socket) -> None: + pass + + @classmethod + @abstractmethod + async def wait_socket_writable(cls, sock: socket) -> None: + pass + + @classmethod + @abstractmethod + def current_default_thread_limiter(cls) -> CapacityLimiter: + pass + + @classmethod + @abstractmethod + def open_signal_receiver( + cls, *signals: Signals + ) -> ContextManager[AsyncIterator[Signals]]: + pass + + @classmethod + @abstractmethod + def get_current_task(cls) -> TaskInfo: + pass + + @classmethod + @abstractmethod + def get_running_tasks(cls) -> list[TaskInfo]: + pass + + @classmethod + @abstractmethod + async def wait_all_tasks_blocked(cls) -> None: + pass + + @classmethod + @abstractmethod + def create_test_runner(cls, options: dict[str, Any]) -> TestRunner: + pass diff --git a/contrib/python/anyio/anyio/abc/_resources.py b/contrib/python/anyio/anyio/abc/_resources.py index e0a283fc98..9693835bae 100644 --- a/contrib/python/anyio/anyio/abc/_resources.py +++ b/contrib/python/anyio/anyio/abc/_resources.py @@ -11,8 +11,8 @@ class AsyncResource(metaclass=ABCMeta): """ Abstract base class for all closeable asynchronous resources. - Works as an asynchronous context manager which returns the instance itself on enter, and calls - :meth:`aclose` on exit. + Works as an asynchronous context manager which returns the instance itself on enter, + and calls :meth:`aclose` on exit. """ async def __aenter__(self: T) -> T: diff --git a/contrib/python/anyio/anyio/abc/_sockets.py b/contrib/python/anyio/anyio/abc/_sockets.py index 6aac5f7c22..b321225a7b 100644 --- a/contrib/python/anyio/anyio/abc/_sockets.py +++ b/contrib/python/anyio/anyio/abc/_sockets.py @@ -2,21 +2,14 @@ from __future__ import annotations import socket from abc import abstractmethod +from collections.abc import Callable, Collection, Mapping from contextlib import AsyncExitStack from io import IOBase from ipaddress import IPv4Address, IPv6Address from socket import AddressFamily -from typing import ( - Any, - Callable, - Collection, - Mapping, - Tuple, - TypeVar, - Union, -) +from types import TracebackType +from typing import Any, Tuple, TypeVar, Union -from .._core._tasks import create_task_group from .._core._typedattr import ( TypedAttributeProvider, TypedAttributeSet, @@ -29,9 +22,23 @@ IPAddressType = Union[str, IPv4Address, IPv6Address] IPSockAddrType = Tuple[str, int] SockAddrType = Union[IPSockAddrType, str] UDPPacketType = Tuple[bytes, IPSockAddrType] +UNIXDatagramPacketType = Tuple[bytes, str] T_Retval = TypeVar("T_Retval") +class _NullAsyncContextManager: + async def __aenter__(self) -> None: + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + class SocketAttribute(TypedAttributeSet): #: the address family of the underlying socket family: AddressFamily = typed_attribute() @@ -70,9 +77,9 @@ class _SocketProvider(TypedAttributeProvider): # Provide local and remote ports for IP based sockets if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6): - attributes[ - SocketAttribute.local_port - ] = lambda: self._raw_socket.getsockname()[1] + attributes[SocketAttribute.local_port] = ( + lambda: self._raw_socket.getsockname()[1] + ) if peername is not None: remote_port = peername[1] attributes[SocketAttribute.remote_port] = lambda: remote_port @@ -100,8 +107,8 @@ class UNIXSocketStream(SocketStream): Send file descriptors along with a message to the peer. :param message: a non-empty bytestring - :param fds: a collection of files (either numeric file descriptors or open file or socket - objects) + :param fds: a collection of files (either numeric file descriptors or open file + or socket objects) """ @abstractmethod @@ -131,9 +138,11 @@ class SocketListener(Listener[SocketStream], _SocketProvider): handler: Callable[[SocketStream], Any], task_group: TaskGroup | None = None, ) -> None: - async with AsyncExitStack() as exit_stack: + from .. import create_task_group + + async with AsyncExitStack() as stack: if task_group is None: - task_group = await exit_stack.enter_async_context(create_task_group()) + task_group = await stack.enter_async_context(create_task_group()) while True: stream = await self.accept() @@ -148,7 +157,10 @@ class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider): """ async def sendto(self, data: bytes, host: str, port: int) -> None: - """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).""" + """ + Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))). + + """ return await self.send((data, (host, port))) @@ -158,3 +170,25 @@ class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider): Supports all relevant extra attributes from :class:`~SocketAttribute`. """ + + +class UNIXDatagramSocket( + UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider +): + """ + Represents an unconnected Unix datagram socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + async def sendto(self, data: bytes, path: str) -> None: + """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, path)).""" + return await self.send((data, path)) + + +class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider): + """ + Represents a connected Unix datagram socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ diff --git a/contrib/python/anyio/anyio/abc/_streams.py b/contrib/python/anyio/anyio/abc/_streams.py index 4fa7ccc9ff..8c638683a4 100644 --- a/contrib/python/anyio/anyio/abc/_streams.py +++ b/contrib/python/anyio/anyio/abc/_streams.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Callable, Generic, TypeVar, Union +from collections.abc import Callable +from typing import Any, Generic, TypeVar, Union from .._core._exceptions import EndOfStream from .._core._typedattr import TypedAttributeProvider @@ -19,11 +20,11 @@ class UnreliableObjectReceiveStream( """ An interface for receiving objects. - This interface makes no guarantees that the received messages arrive in the order in which they - were sent, or that no messages are missed. + This interface makes no guarantees that the received messages arrive in the order in + which they were sent, or that no messages are missed. - Asynchronously iterating over objects of this type will yield objects matching the given type - parameter. + Asynchronously iterating over objects of this type will yield objects matching the + given type parameter. """ def __aiter__(self) -> UnreliableObjectReceiveStream[T_co]: @@ -54,8 +55,8 @@ class UnreliableObjectSendStream( """ An interface for sending objects. - This interface makes no guarantees that the messages sent will reach the recipient(s) in the - same order in which they were sent, or at all. + This interface makes no guarantees that the messages sent will reach the + recipient(s) in the same order in which they were sent, or at all. """ @abstractmethod @@ -75,22 +76,22 @@ class UnreliableObjectStream( UnreliableObjectReceiveStream[T_Item], UnreliableObjectSendStream[T_Item] ): """ - A bidirectional message stream which does not guarantee the order or reliability of message - delivery. + A bidirectional message stream which does not guarantee the order or reliability of + message delivery. """ class ObjectReceiveStream(UnreliableObjectReceiveStream[T_co]): """ - A receive message stream which guarantees that messages are received in the same order in - which they were sent, and that no messages are missed. + A receive message stream which guarantees that messages are received in the same + order in which they were sent, and that no messages are missed. """ class ObjectSendStream(UnreliableObjectSendStream[T_contra]): """ - A send message stream which guarantees that messages are delivered in the same order in which - they were sent, without missing any messages in the middle. + A send message stream which guarantees that messages are delivered in the same order + in which they were sent, without missing any messages in the middle. """ @@ -100,7 +101,8 @@ class ObjectStream( UnreliableObjectStream[T_Item], ): """ - A bidirectional message stream which guarantees the order and reliability of message delivery. + A bidirectional message stream which guarantees the order and reliability of message + delivery. """ @abstractmethod @@ -108,8 +110,8 @@ class ObjectStream( """ Send an end-of-file indication to the peer. - You should not try to send any further data to this stream after calling this method. - This method is idempotent (does nothing on successive calls). + You should not try to send any further data to this stream after calling this + method. This method is idempotent (does nothing on successive calls). """ @@ -117,8 +119,8 @@ class ByteReceiveStream(AsyncResource, TypedAttributeProvider): """ An interface for receiving bytes from a single peer. - Iterating this byte stream will yield a byte string of arbitrary length, but no more than - 65536 bytes. + Iterating this byte stream will yield a byte string of arbitrary length, but no more + than 65536 bytes. """ def __aiter__(self) -> ByteReceiveStream: @@ -135,8 +137,8 @@ class ByteReceiveStream(AsyncResource, TypedAttributeProvider): """ Receive at most ``max_bytes`` bytes from the peer. - .. note:: Implementors of this interface should not return an empty :class:`bytes` object, - and users should ignore them. + .. note:: Implementors of this interface should not return an empty + :class:`bytes` object, and users should ignore them. :param max_bytes: maximum number of bytes to receive :return: the received bytes @@ -164,8 +166,8 @@ class ByteStream(ByteReceiveStream, ByteSendStream): """ Send an end-of-file indication to the peer. - You should not try to send any further data to this stream after calling this method. - This method is idempotent (does nothing on successive calls). + You should not try to send any further data to this stream after calling this + method. This method is idempotent (does nothing on successive calls). """ @@ -190,14 +192,12 @@ class Listener(Generic[T_co], AsyncResource, TypedAttributeProvider): @abstractmethod async def serve( - self, - handler: Callable[[T_co], Any], - task_group: TaskGroup | None = None, + self, handler: Callable[[T_co], Any], task_group: TaskGroup | None = None ) -> None: """ Accept incoming connections as they come in and start tasks to handle them. :param handler: a callable that will be used to handle each accepted connection - :param task_group: the task group that will be used to start tasks for handling each - accepted connection (if omitted, an ad-hoc task group will be created) + :param task_group: the task group that will be used to start tasks for handling + each accepted connection (if omitted, an ad-hoc task group will be created) """ diff --git a/contrib/python/anyio/anyio/abc/_subprocesses.py b/contrib/python/anyio/anyio/abc/_subprocesses.py index 704b44a2dd..ce0564ceac 100644 --- a/contrib/python/anyio/anyio/abc/_subprocesses.py +++ b/contrib/python/anyio/anyio/abc/_subprocesses.py @@ -59,8 +59,8 @@ class Process(AsyncResource): @abstractmethod def returncode(self) -> int | None: """ - The return code of the process. If the process has not yet terminated, this will be - ``None``. + The return code of the process. If the process has not yet terminated, this will + be ``None``. """ @property diff --git a/contrib/python/anyio/anyio/abc/_tasks.py b/contrib/python/anyio/anyio/abc/_tasks.py index e48d3c1e97..7ad4938cb4 100644 --- a/contrib/python/anyio/anyio/abc/_tasks.py +++ b/contrib/python/anyio/anyio/abc/_tasks.py @@ -2,20 +2,21 @@ from __future__ import annotations import sys from abc import ABCMeta, abstractmethod +from collections.abc import Awaitable, Callable from types import TracebackType -from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, overload -from warnings import warn +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload -if sys.version_info >= (3, 8): - from typing import Protocol +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack else: - from typing_extensions import Protocol + from typing_extensions import TypeVarTuple, Unpack if TYPE_CHECKING: - from anyio._core._tasks import CancelScope + from .._core._tasks import CancelScope T_Retval = TypeVar("T_Retval") T_contra = TypeVar("T_contra", contravariant=True) +PosArgsT = TypeVarTuple("PosArgsT") class TaskStatus(Protocol[T_contra]): @@ -45,35 +46,11 @@ class TaskGroup(metaclass=ABCMeta): cancel_scope: CancelScope - async def spawn( - self, - func: Callable[..., Awaitable[Any]], - *args: object, - name: object = None, - ) -> None: - """ - Start a new task in this task group. - - :param func: a coroutine function - :param args: positional arguments to call the function with - :param name: name of the task, for the purposes of introspection and debugging - - .. deprecated:: 3.0 - Use :meth:`start_soon` instead. If your code needs AnyIO 2 compatibility, you - can keep using this until AnyIO 4. - - """ - warn( - 'spawn() is deprecated -- use start_soon() (without the "await") instead', - DeprecationWarning, - ) - self.start_soon(func, *args, name=name) - @abstractmethod def start_soon( self, - func: Callable[..., Awaitable[Any]], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], name: object = None, ) -> None: """ @@ -100,7 +77,8 @@ class TaskGroup(metaclass=ABCMeta): :param args: positional arguments to call the function with :param name: name of the task, for the purposes of introspection and debugging :return: the value passed to ``task_status.started()`` - :raises RuntimeError: if the task finishes without calling ``task_status.started()`` + :raises RuntimeError: if the task finishes without calling + ``task_status.started()`` .. versionadded:: 3.0 """ diff --git a/contrib/python/anyio/anyio/abc/_testing.py b/contrib/python/anyio/anyio/abc/_testing.py index ee2cff5cc3..4d70b9ec6b 100644 --- a/contrib/python/anyio/anyio/abc/_testing.py +++ b/contrib/python/anyio/anyio/abc/_testing.py @@ -2,33 +2,29 @@ from __future__ import annotations import types from abc import ABCMeta, abstractmethod -from collections.abc import AsyncGenerator, Iterable -from typing import Any, Callable, Coroutine, TypeVar +from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable +from typing import Any, TypeVar _T = TypeVar("_T") class TestRunner(metaclass=ABCMeta): """ - Encapsulates a running event loop. Every call made through this object will use the same event - loop. + Encapsulates a running event loop. Every call made through this object will use the + same event loop. """ def __enter__(self) -> TestRunner: return self + @abstractmethod def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: types.TracebackType | None, ) -> bool | None: - self.close() - return None - - @abstractmethod - def close(self) -> None: - """Close the event loop.""" + ... @abstractmethod def run_asyncgen_fixture( diff --git a/contrib/python/anyio/anyio/from_thread.py b/contrib/python/anyio/anyio/from_thread.py index 6b76861c70..4a987031fe 100644 --- a/contrib/python/anyio/anyio/from_thread.py +++ b/contrib/python/anyio/anyio/from_thread.py @@ -1,36 +1,43 @@ from __future__ import annotations +import sys import threading -from asyncio import iscoroutine +from collections.abc import Awaitable, Callable, Generator from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from contextlib import AbstractContextManager, contextmanager +from inspect import isawaitable from types import TracebackType from typing import ( Any, AsyncContextManager, - Awaitable, - Callable, ContextManager, - Generator, Generic, Iterable, TypeVar, cast, overload, ) -from warnings import warn from ._core import _eventloop -from ._core._eventloop import get_asynclib, get_cancelled_exc_class, threadlocals +from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals from ._core._synchronization import Event from ._core._tasks import CancelScope, create_task_group +from .abc import AsyncBackend from .abc._tasks import TaskStatus +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + T_Retval = TypeVar("T_Retval") -T_co = TypeVar("T_co") +T_co = TypeVar("T_co", covariant=True) +PosArgsT = TypeVarTuple("PosArgsT") -def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: +def run( + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT] +) -> T_Retval: """ Call a coroutine function from a worker thread. @@ -40,24 +47,19 @@ def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: """ try: - asynclib = threadlocals.current_async_module + async_backend = threadlocals.current_async_backend + token = threadlocals.current_token except AttributeError: - raise RuntimeError("This function can only be run from an AnyIO worker thread") + raise RuntimeError( + "This function can only be run from an AnyIO worker thread" + ) from None - return asynclib.run_async_from_thread(func, *args) + return async_backend.run_async_from_thread(func, args, token=token) -def run_async_from_thread( - func: Callable[..., Awaitable[T_Retval]], *args: object +def run_sync( + func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] ) -> T_Retval: - warn( - "run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead", - DeprecationWarning, - ) - return run(func, *args) - - -def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: """ Call a function in the event loop thread from a worker thread. @@ -67,24 +69,19 @@ def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: """ try: - asynclib = threadlocals.current_async_module + async_backend = threadlocals.current_async_backend + token = threadlocals.current_token except AttributeError: - raise RuntimeError("This function can only be run from an AnyIO worker thread") - - return asynclib.run_sync_from_thread(func, *args) + raise RuntimeError( + "This function can only be run from an AnyIO worker thread" + ) from None - -def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval: - warn( - "run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead", - DeprecationWarning, - ) - return run_sync(func, *args) + return async_backend.run_sync_from_thread(func, args, token=token) class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager): - _enter_future: Future - _exit_future: Future + _enter_future: Future[T_co] + _exit_future: Future[bool | None] _exit_event: Event _exit_exc_info: tuple[ type[BaseException] | None, BaseException | None, TracebackType | None @@ -120,8 +117,7 @@ class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager): def __enter__(self) -> T_co: self._enter_future = Future() self._exit_future = self._portal.start_task_soon(self.run_async_cm) - cm = self._enter_future.result() - return cast(T_co, cm) + return self._enter_future.result() def __exit__( self, @@ -146,7 +142,7 @@ class BlockingPortal: """An object that lets external threads run code in an asynchronous event loop.""" def __new__(cls) -> BlockingPortal: - return get_asynclib().BlockingPortal() + return get_async_backend().create_blocking_portal() def __init__(self) -> None: self._event_loop_thread_id: int | None = threading.get_ident() @@ -186,8 +182,8 @@ class BlockingPortal: This marks the portal as no longer accepting new calls and exits from :meth:`sleep_until_stopped`. - :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` to let them - finish before returning + :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` + to let them finish before returning """ self._event_loop_thread_id = None @@ -196,9 +192,13 @@ class BlockingPortal: self._task_group.cancel_scope.cancel() async def _call_func( - self, func: Callable, args: tuple, kwargs: dict[str, Any], future: Future + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], + kwargs: dict[str, Any], + future: Future[T_Retval], ) -> None: - def callback(f: Future) -> None: + def callback(f: Future[T_Retval]) -> None: if f.cancelled() and self._event_loop_thread_id not in ( None, threading.get_ident(), @@ -206,17 +206,20 @@ class BlockingPortal: self.call(scope.cancel) try: - retval = func(*args, **kwargs) - if iscoroutine(retval): + retval_or_awaitable = func(*args, **kwargs) + if isawaitable(retval_or_awaitable): with CancelScope() as scope: if future.cancelled(): scope.cancel() else: future.add_done_callback(callback) - retval = await retval + retval = await retval_or_awaitable + else: + retval = retval_or_awaitable except self._cancelled_exc_class: future.cancel() + future.set_running_or_notify_cancel() except BaseException as exc: if not future.cancelled(): future.set_exception(exc) @@ -232,11 +235,11 @@ class BlockingPortal: def _spawn_task_from_thread( self, - func: Callable, - args: tuple, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, - future: Future, + future: Future[T_Retval], ) -> None: """ Spawn a new task using the given callable. @@ -247,22 +250,30 @@ class BlockingPortal: :param args: positional arguments to be passed to the callable :param kwargs: keyword arguments to be passed to the callable :param name: name of the task (will be coerced to a string if not ``None``) - :param future: a future that will resolve to the return value of the callable, or the - exception raised during its execution + :param future: a future that will resolve to the return value of the callable, + or the exception raised during its execution """ raise NotImplementedError @overload - def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: + def call( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], + ) -> T_Retval: ... @overload - def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: + def call( + self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] + ) -> T_Retval: ... def call( - self, func: Callable[..., Awaitable[T_Retval] | T_Retval], *args: object + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + *args: Unpack[PosArgsT], ) -> T_Retval: """ Call the given function in the event loop thread. @@ -270,82 +281,41 @@ class BlockingPortal: If the callable returns a coroutine object, it is awaited on. :param func: any callable - :raises RuntimeError: if the portal is not running or if this method is called from within - the event loop thread + :raises RuntimeError: if the portal is not running or if this method is called + from within the event loop thread """ return cast(T_Retval, self.start_task_soon(func, *args).result()) @overload - def spawn_task( + def start_task_soon( self, - func: Callable[..., Awaitable[T_Retval]], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], name: object = None, ) -> Future[T_Retval]: ... @overload - def spawn_task( - self, func: Callable[..., T_Retval], *args: object, name: object = None - ) -> Future[T_Retval]: - ... - - def spawn_task( - self, - func: Callable[..., Awaitable[T_Retval] | T_Retval], - *args: object, - name: object = None, - ) -> Future[T_Retval]: - """ - Start a task in the portal's task group. - - :param func: the target coroutine function - :param args: positional arguments passed to ``func`` - :param name: name of the task (will be coerced to a string if not ``None``) - :return: a future that resolves with the return value of the callable if the task completes - successfully, or with the exception raised in the task - :raises RuntimeError: if the portal is not running or if this method is called from within - the event loop thread - - .. versionadded:: 2.1 - .. deprecated:: 3.0 - Use :meth:`start_task_soon` instead. If your code needs AnyIO 2 compatibility, you - can keep using this until AnyIO 4. - - """ - warn( - "spawn_task() is deprecated -- use start_task_soon() instead", - DeprecationWarning, - ) - return self.start_task_soon(func, *args, name=name) # type: ignore[arg-type] - - @overload def start_task_soon( self, - func: Callable[..., Awaitable[T_Retval]], - *args: object, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], name: object = None, ) -> Future[T_Retval]: ... - @overload - def start_task_soon( - self, func: Callable[..., T_Retval], *args: object, name: object = None - ) -> Future[T_Retval]: - ... - def start_task_soon( self, - func: Callable[..., Awaitable[T_Retval] | T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + *args: Unpack[PosArgsT], name: object = None, ) -> Future[T_Retval]: """ Start a task in the portal's task group. - The task will be run inside a cancel scope which can be cancelled by cancelling the - returned future. + The task will be run inside a cancel scope which can be cancelled by cancelling + the returned future. :param func: the target function :param args: positional arguments passed to ``func`` @@ -360,13 +330,16 @@ class BlockingPortal: """ self._check_running() - f: Future = Future() + f: Future[T_Retval] = Future() self._spawn_task_from_thread(func, args, {}, name, f) return f def start_task( - self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None - ) -> tuple[Future[Any], Any]: + self, + func: Callable[..., Awaitable[T_Retval]], + *args: object, + name: object = None, + ) -> tuple[Future[T_Retval], Any]: """ Start a task in the portal's task group and wait until it signals for readiness. @@ -378,13 +351,13 @@ class BlockingPortal: :return: a tuple of (future, task_status_value) where the ``task_status_value`` is the value passed to ``task_status.started()`` from within the target function - :rtype: tuple[concurrent.futures.Future[Any], Any] + :rtype: tuple[concurrent.futures.Future[T_Retval], Any] .. versionadded:: 3.0 """ - def task_done(future: Future) -> None: + def task_done(future: Future[T_Retval]) -> None: if not task_status_future.done(): if future.cancelled(): task_status_future.cancel() @@ -410,8 +383,8 @@ class BlockingPortal: """ Wrap an async context manager as a synchronous context manager via this portal. - Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping in the - middle until the synchronous context manager exits. + Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping + in the middle until the synchronous context manager exits. :param cm: an asynchronous context manager :return: a synchronous context manager @@ -422,25 +395,6 @@ class BlockingPortal: return _BlockingAsyncContextManager(cm, self) -def create_blocking_portal() -> BlockingPortal: - """ - Create a portal for running functions in the event loop thread from external threads. - - Use this function in asynchronous code when you need to allow external threads access to the - event loop where your asynchronous code is currently running. - - .. deprecated:: 3.0 - Use :class:`.BlockingPortal` directly. - - """ - warn( - "create_blocking_portal() has been deprecated -- use anyio.from_thread.BlockingPortal() " - "directly", - DeprecationWarning, - ) - return BlockingPortal() - - @contextmanager def start_blocking_portal( backend: str = "asyncio", backend_options: dict[str, Any] | None = None @@ -468,8 +422,8 @@ def start_blocking_portal( future: Future[BlockingPortal] = Future() with ThreadPoolExecutor(1) as executor: run_future = executor.submit( - _eventloop.run, - run_portal, # type: ignore[arg-type] + _eventloop.run, # type: ignore[arg-type] + run_portal, backend=backend, backend_options=backend_options, ) @@ -498,3 +452,25 @@ def start_blocking_portal( pass run_future.result() + + +def check_cancelled() -> None: + """ + Check if the cancel scope of the host task's running the current worker thread has + been cancelled. + + If the host task's current cancel scope has indeed been cancelled, the + backend-specific cancellation exception will be raised. + + :raises RuntimeError: if the current thread was not spawned by + :func:`.to_thread.run_sync` + + """ + try: + async_backend: AsyncBackend = threadlocals.current_async_backend + except AttributeError: + raise RuntimeError( + "This function can only be run from an AnyIO worker thread" + ) from None + + async_backend.check_cancelled() diff --git a/contrib/python/anyio/anyio/lowlevel.py b/contrib/python/anyio/anyio/lowlevel.py index 0e908c6547..a9e10f430a 100644 --- a/contrib/python/anyio/anyio/lowlevel.py +++ b/contrib/python/anyio/anyio/lowlevel.py @@ -1,17 +1,11 @@ from __future__ import annotations import enum -import sys from dataclasses import dataclass -from typing import Any, Generic, TypeVar, overload +from typing import Any, Generic, Literal, TypeVar, overload from weakref import WeakKeyDictionary -from ._core._eventloop import get_asynclib - -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal +from ._core._eventloop import get_async_backend T = TypeVar("T") D = TypeVar("D") @@ -30,7 +24,7 @@ async def checkpoint() -> None: .. versionadded:: 3.0 """ - await get_asynclib().checkpoint() + await get_async_backend().checkpoint() async def checkpoint_if_cancelled() -> None: @@ -42,7 +36,7 @@ async def checkpoint_if_cancelled() -> None: .. versionadded:: 3.0 """ - await get_asynclib().checkpoint_if_cancelled() + await get_async_backend().checkpoint_if_cancelled() async def cancel_shielded_checkpoint() -> None: @@ -58,12 +52,16 @@ async def cancel_shielded_checkpoint() -> None: .. versionadded:: 3.0 """ - await get_asynclib().cancel_shielded_checkpoint() + await get_async_backend().cancel_shielded_checkpoint() def current_token() -> object: - """Return a backend specific token object that can be used to get back to the event loop.""" - return get_asynclib().current_token() + """ + Return a backend specific token object that can be used to get back to the event + loop. + + """ + return get_async_backend().current_token() _run_vars: WeakKeyDictionary[Any, dict[str, Any]] = WeakKeyDictionary() @@ -101,9 +99,7 @@ class RunVar(Generic[T]): _token_wrappers: set[_TokenWrapper] = set() def __init__( - self, - name: str, - default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET, + self, name: str, default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET ): self._name = name self._default = default @@ -111,18 +107,11 @@ class RunVar(Generic[T]): @property def _current_vars(self) -> dict[str, T]: token = current_token() - while True: - try: - return _run_vars[token] - except TypeError: - # Happens when token isn't weak referable (TrioToken). - # This workaround does mean that some memory will leak on Trio until the problem - # is fixed on their end. - token = _TokenWrapper(token) - self._token_wrappers.add(token) - except KeyError: - run_vars = _run_vars[token] = {} - return run_vars + try: + return _run_vars[token] + except KeyError: + run_vars = _run_vars[token] = {} + return run_vars @overload def get(self, default: D) -> T | D: diff --git a/contrib/python/anyio/anyio/pytest_plugin.py b/contrib/python/anyio/anyio/pytest_plugin.py index 044ce6914d..a8dd6f3e3f 100644 --- a/contrib/python/anyio/anyio/pytest_plugin.py +++ b/contrib/python/anyio/anyio/pytest_plugin.py @@ -1,16 +1,19 @@ from __future__ import annotations -from contextlib import contextmanager +from collections.abc import Iterator +from contextlib import ExitStack, contextmanager from inspect import isasyncgenfunction, iscoroutinefunction -from typing import Any, Dict, Generator, Tuple, cast +from typing import Any, Dict, Tuple, cast import pytest import sniffio -from ._core._eventloop import get_all_backends, get_asynclib +from ._core._eventloop import get_all_backends, get_async_backend from .abc import TestRunner _current_runner: TestRunner | None = None +_runner_stack: ExitStack | None = None +_runner_leases = 0 def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]: @@ -26,27 +29,31 @@ def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]: @contextmanager def get_runner( backend_name: str, backend_options: dict[str, Any] -) -> Generator[TestRunner, object, None]: - global _current_runner - if _current_runner: - yield _current_runner - return +) -> Iterator[TestRunner]: + global _current_runner, _runner_leases, _runner_stack + if _current_runner is None: + asynclib = get_async_backend(backend_name) + _runner_stack = ExitStack() + if sniffio.current_async_library_cvar.get(None) is None: + # Since we're in control of the event loop, we can cache the name of the + # async library + token = sniffio.current_async_library_cvar.set(backend_name) + _runner_stack.callback(sniffio.current_async_library_cvar.reset, token) - asynclib = get_asynclib(backend_name) - token = None - if sniffio.current_async_library_cvar.get(None) is None: - # Since we're in control of the event loop, we can cache the name of the async library - token = sniffio.current_async_library_cvar.set(backend_name) + backend_options = backend_options or {} + _current_runner = _runner_stack.enter_context( + asynclib.create_test_runner(backend_options) + ) + _runner_leases += 1 try: - backend_options = backend_options or {} - with asynclib.TestRunner(**backend_options) as runner: - _current_runner = runner - yield runner + yield _current_runner finally: - _current_runner = None - if token: - sniffio.current_async_library_cvar.reset(token) + _runner_leases -= 1 + if not _runner_leases: + assert _runner_stack is not None + _runner_stack.close() + _runner_stack = _current_runner = None def pytest_configure(config: Any) -> None: @@ -69,8 +76,8 @@ def pytest_fixture_setup(fixturedef: Any, request: Any) -> None: else: yield runner.run_fixture(func, kwargs) - # Only apply this to coroutine functions and async generator functions in requests that involve - # the anyio_backend fixture + # Only apply this to coroutine functions and async generator functions in requests + # that involve the anyio_backend fixture func = fixturedef.func if isasyncgenfunction(func) or iscoroutinefunction(func): if "anyio_backend" in request.fixturenames: @@ -121,7 +128,7 @@ def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None: return None -@pytest.fixture(params=get_all_backends()) +@pytest.fixture(scope="module", params=get_all_backends()) def anyio_backend(request: Any) -> Any: return request.param diff --git a/contrib/python/anyio/anyio/streams/buffered.py b/contrib/python/anyio/anyio/streams/buffered.py index 11474c16a9..f5d5e836dd 100644 --- a/contrib/python/anyio/anyio/streams/buffered.py +++ b/contrib/python/anyio/anyio/streams/buffered.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable, Mapping from dataclasses import dataclass, field -from typing import Any, Callable, Mapping +from typing import Any from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead from ..abc import AnyByteReceiveStream, ByteReceiveStream @@ -10,8 +11,8 @@ from ..abc import AnyByteReceiveStream, ByteReceiveStream @dataclass(eq=False) class BufferedByteReceiveStream(ByteReceiveStream): """ - Wraps any bytes-based receive stream and uses a buffer to provide sophisticated receiving - capabilities in the form of a byte stream. + Wraps any bytes-based receive stream and uses a buffer to provide sophisticated + receiving capabilities in the form of a byte stream. """ receive_stream: AnyByteReceiveStream @@ -42,8 +43,8 @@ class BufferedByteReceiveStream(ByteReceiveStream): elif isinstance(self.receive_stream, ByteReceiveStream): return await self.receive_stream.receive(max_bytes) else: - # With a bytes-oriented object stream, we need to handle any surplus bytes we get from - # the receive() call + # With a bytes-oriented object stream, we need to handle any surplus bytes + # we get from the receive() call chunk = await self.receive_stream.receive() if len(chunk) > max_bytes: # Save the surplus bytes in the buffer diff --git a/contrib/python/anyio/anyio/streams/file.py b/contrib/python/anyio/anyio/streams/file.py index 2840d40ab6..f492464267 100644 --- a/contrib/python/anyio/anyio/streams/file.py +++ b/contrib/python/anyio/anyio/streams/file.py @@ -1,9 +1,10 @@ from __future__ import annotations +from collections.abc import Callable, Mapping from io import SEEK_SET, UnsupportedOperation from os import PathLike from pathlib import Path -from typing import Any, BinaryIO, Callable, Mapping, cast +from typing import Any, BinaryIO, cast from .. import ( BrokenResourceError, @@ -130,8 +131,8 @@ class FileWriteStream(_BaseFileStream, ByteSendStream): Create a file write stream by opening the given file for writing. :param path: path of the file to write to - :param append: if ``True``, open the file for appending; if ``False``, any existing file - at the given path will be truncated + :param append: if ``True``, open the file for appending; if ``False``, any + existing file at the given path will be truncated """ mode = "ab" if append else "wb" diff --git a/contrib/python/anyio/anyio/streams/memory.py b/contrib/python/anyio/anyio/streams/memory.py index a6499c13ff..bc2425b76f 100644 --- a/contrib/python/anyio/anyio/streams/memory.py +++ b/contrib/python/anyio/anyio/streams/memory.py @@ -10,9 +10,7 @@ from .. import ( ClosedResourceError, EndOfStream, WouldBlock, - get_cancelled_exc_class, ) -from .._core._compat import DeprecatedAwaitable from ..abc import Event, ObjectReceiveStream, ObjectSendStream from ..lowlevel import checkpoint @@ -27,7 +25,8 @@ class MemoryObjectStreamStatistics(NamedTuple): max_buffer_size: float open_send_streams: int #: number of unclosed clones of the send stream open_receive_streams: int #: number of unclosed clones of the receive stream - tasks_waiting_send: int #: number of tasks blocked on :meth:`MemoryObjectSendStream.send` + #: number of tasks blocked on :meth:`MemoryObjectSendStream.send` + tasks_waiting_send: int #: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive` tasks_waiting_receive: int @@ -104,11 +103,6 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]): try: await receive_event.wait() - except get_cancelled_exc_class(): - # Ignore the immediate cancellation if we already received an item, so as not to - # lose it - if not container: - raise finally: self._state.waiting_receivers.pop(receive_event, None) @@ -121,8 +115,8 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]): """ Create a clone of this receive stream. - Each clone can be closed separately. Only when all clones have been closed will the - receiving end of the memory stream be considered closed by the sending ends. + Each clone can be closed separately. Only when all clones have been closed will + the receiving end of the memory stream be considered closed by the sending ends. :return: the cloned stream @@ -136,8 +130,8 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]): """ Close the stream. - This works the exact same way as :meth:`aclose`, but is provided as a special case for the - benefit of synchronous callbacks. + This works the exact same way as :meth:`aclose`, but is provided as a special + case for the benefit of synchronous callbacks. """ if not self._closed: @@ -179,7 +173,7 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): def __post_init__(self) -> None: self._state.open_send_channels += 1 - def send_nowait(self, item: T_contra) -> DeprecatedAwaitable: + def send_nowait(self, item: T_contra) -> None: """ Send an item immediately if it can be done without waiting. @@ -205,9 +199,19 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): else: raise WouldBlock - return DeprecatedAwaitable(self.send_nowait) - async def send(self, item: T_contra) -> None: + """ + Send an item to the stream. + + If the buffer is full, this method blocks until there is again room in the + buffer or the item can be sent directly to a receiver. + + :param item: the item to send + :raises ~anyio.ClosedResourceError: if this send stream has been closed + :raises ~anyio.BrokenResourceError: if the stream has been closed from the + receiving end + + """ await checkpoint() try: self.send_nowait(item) @@ -218,18 +222,18 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): try: await send_event.wait() except BaseException: - self._state.waiting_senders.pop(send_event, None) # type: ignore[arg-type] + self._state.waiting_senders.pop(send_event, None) raise - if self._state.waiting_senders.pop(send_event, None): # type: ignore[arg-type] - raise BrokenResourceError + if self._state.waiting_senders.pop(send_event, None): + raise BrokenResourceError from None def clone(self) -> MemoryObjectSendStream[T_contra]: """ Create a clone of this send stream. - Each clone can be closed separately. Only when all clones have been closed will the - sending end of the memory stream be considered closed by the receiving ends. + Each clone can be closed separately. Only when all clones have been closed will + the sending end of the memory stream be considered closed by the receiving ends. :return: the cloned stream @@ -243,8 +247,8 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): """ Close the stream. - This works the exact same way as :meth:`aclose`, but is provided as a special case for the - benefit of synchronous callbacks. + This works the exact same way as :meth:`aclose`, but is provided as a special + case for the benefit of synchronous callbacks. """ if not self._closed: diff --git a/contrib/python/anyio/anyio/streams/stapled.py b/contrib/python/anyio/anyio/streams/stapled.py index 1b2862e3ea..80f64a2e8e 100644 --- a/contrib/python/anyio/anyio/streams/stapled.py +++ b/contrib/python/anyio/anyio/streams/stapled.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass -from typing import Any, Callable, Generic, Mapping, Sequence, TypeVar +from typing import Any, Generic, TypeVar from ..abc import ( ByteReceiveStream, @@ -23,8 +24,8 @@ class StapledByteStream(ByteStream): """ Combines two byte streams into a single, bidirectional byte stream. - Extra attributes will be provided from both streams, with the receive stream providing the - values in case of a conflict. + Extra attributes will be provided from both streams, with the receive stream + providing the values in case of a conflict. :param ByteSendStream send_stream: the sending byte stream :param ByteReceiveStream receive_stream: the receiving byte stream @@ -59,8 +60,8 @@ class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]): """ Combines two object streams into a single, bidirectional object stream. - Extra attributes will be provided from both streams, with the receive stream providing the - values in case of a conflict. + Extra attributes will be provided from both streams, with the receive stream + providing the values in case of a conflict. :param ObjectSendStream send_stream: the sending object stream :param ObjectReceiveStream receive_stream: the receiving object stream @@ -95,11 +96,11 @@ class MultiListener(Generic[T_Stream], Listener[T_Stream]): """ Combines multiple listeners into one, serving connections from all of them at once. - Any MultiListeners in the given collection of listeners will have their listeners moved into - this one. + Any MultiListeners in the given collection of listeners will have their listeners + moved into this one. - Extra attributes are provided from each listener, with each successive listener overriding any - conflicting attributes from the previous one. + Extra attributes are provided from each listener, with each successive listener + overriding any conflicting attributes from the previous one. :param listeners: listeners to serve :type listeners: Sequence[Listener[T_Stream]] diff --git a/contrib/python/anyio/anyio/streams/text.py b/contrib/python/anyio/anyio/streams/text.py index bba2d3f7df..f1a11278e3 100644 --- a/contrib/python/anyio/anyio/streams/text.py +++ b/contrib/python/anyio/anyio/streams/text.py @@ -1,8 +1,9 @@ from __future__ import annotations import codecs +from collections.abc import Callable, Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Callable, Mapping +from typing import Any from ..abc import ( AnyByteReceiveStream, @@ -19,16 +20,17 @@ class TextReceiveStream(ObjectReceiveStream[str]): """ Stream wrapper that decodes bytes to strings using the given encoding. - Decoding is done using :class:`~codecs.IncrementalDecoder` which returns any completely - received unicode characters as soon as they come in. + Decoding is done using :class:`~codecs.IncrementalDecoder` which returns any + completely received unicode characters as soon as they come in. :param transport_stream: any bytes-based receive stream - :param encoding: character encoding to use for decoding bytes to strings (defaults to - ``utf-8``) + :param encoding: character encoding to use for decoding bytes to strings (defaults + to ``utf-8``) :param errors: handling scheme for decoding errors (defaults to ``strict``; see the `codecs module documentation`_ for a comprehensive list of options) - .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects + .. _codecs module documentation: + https://docs.python.org/3/library/codecs.html#codec-objects """ transport_stream: AnyByteReceiveStream @@ -62,12 +64,13 @@ class TextSendStream(ObjectSendStream[str]): Sends strings to the wrapped stream as bytes using the given encoding. :param AnyByteSendStream transport_stream: any bytes-based send stream - :param str encoding: character encoding to use for encoding strings to bytes (defaults to - ``utf-8``) - :param str errors: handling scheme for encoding errors (defaults to ``strict``; see the - `codecs module documentation`_ for a comprehensive list of options) + :param str encoding: character encoding to use for encoding strings to bytes + (defaults to ``utf-8``) + :param str errors: handling scheme for encoding errors (defaults to ``strict``; see + the `codecs module documentation`_ for a comprehensive list of options) - .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects + .. _codecs module documentation: + https://docs.python.org/3/library/codecs.html#codec-objects """ transport_stream: AnyByteSendStream @@ -93,19 +96,20 @@ class TextSendStream(ObjectSendStream[str]): @dataclass(eq=False) class TextStream(ObjectStream[str]): """ - A bidirectional stream that decodes bytes to strings on receive and encodes strings to bytes on - send. + A bidirectional stream that decodes bytes to strings on receive and encodes strings + to bytes on send. - Extra attributes will be provided from both streams, with the receive stream providing the - values in case of a conflict. + Extra attributes will be provided from both streams, with the receive stream + providing the values in case of a conflict. :param AnyByteStream transport_stream: any bytes-based stream - :param str encoding: character encoding to use for encoding/decoding strings to/from bytes - (defaults to ``utf-8``) - :param str errors: handling scheme for encoding errors (defaults to ``strict``; see the - `codecs module documentation`_ for a comprehensive list of options) + :param str encoding: character encoding to use for encoding/decoding strings to/from + bytes (defaults to ``utf-8``) + :param str errors: handling scheme for encoding errors (defaults to ``strict``; see + the `codecs module documentation`_ for a comprehensive list of options) - .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects + .. _codecs module documentation: + https://docs.python.org/3/library/codecs.html#codec-objects """ transport_stream: AnyByteStream diff --git a/contrib/python/anyio/anyio/streams/tls.py b/contrib/python/anyio/anyio/streams/tls.py index 9f9e9fd89c..e913eedbbf 100644 --- a/contrib/python/anyio/anyio/streams/tls.py +++ b/contrib/python/anyio/anyio/streams/tls.py @@ -3,9 +3,11 @@ from __future__ import annotations import logging import re import ssl +import sys +from collections.abc import Callable, Mapping from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, Mapping, Tuple, TypeVar +from typing import Any, Tuple, TypeVar from .. import ( BrokenResourceError, @@ -16,7 +18,13 @@ from .. import ( from .._core._typedattr import TypedAttributeSet, typed_attribute from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") _PCTRTT = Tuple[Tuple[str, str], ...] _PCTRTTT = Tuple[_PCTRTT, ...] @@ -31,8 +39,8 @@ class TLSAttribute(TypedAttributeSet): #: the selected cipher cipher: tuple[str, str, int] = typed_attribute() #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` - #: for more information) - peer_certificate: dict[str, str | _PCTRTTT | _PCTRTT] | None = typed_attribute() + # for more information) + peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute() #: the peer certificate in binary form peer_certificate_binary: bytes | None = typed_attribute() #: ``True`` if this is the server side of the connection @@ -90,8 +98,9 @@ class TLSStream(ByteStream): :param hostname: host name of the peer (if host name checking is desired) :param ssl_context: the SSLContext object to use (if not provided, a secure default will be created) - :param standard_compatible: if ``False``, skip the closing handshake when closing the - connection, and don't raise an exception if the peer does the same + :param standard_compatible: if ``False``, skip the closing handshake when + closing the connection, and don't raise an exception if the peer does the + same :raises ~ssl.SSLError: if the TLS handshake fails """ @@ -124,7 +133,7 @@ class TLSStream(ByteStream): return wrapper async def _call_sslobject_method( - self, func: Callable[..., T_Retval], *args: object + self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] ) -> T_Retval: while True: try: @@ -222,7 +231,9 @@ class TLSStream(ByteStream): return { **self.transport_stream.extra_attributes, TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol, - TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding, + TLSAttribute.channel_binding_tls_unique: ( + self._ssl_object.get_channel_binding + ), TLSAttribute.cipher: self._ssl_object.cipher, TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False), TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert( @@ -241,11 +252,12 @@ class TLSStream(ByteStream): @dataclass(eq=False) class TLSListener(Listener[TLSStream]): """ - A convenience listener that wraps another listener and auto-negotiates a TLS session on every - accepted connection. + A convenience listener that wraps another listener and auto-negotiates a TLS session + on every accepted connection. - If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is - called to do whatever post-mortem processing is deemed necessary. + If the TLS handshake times out or raises an exception, + :meth:`handle_handshake_error` is called to do whatever post-mortem processing is + deemed necessary. Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute. @@ -281,7 +293,13 @@ class TLSListener(Listener[TLSStream]): # Log all except cancellation exceptions if not isinstance(exc, get_cancelled_exc_class()): - logging.getLogger(__name__).exception("Error during TLS handshake") + # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using + # any asyncio implementation, so we explicitly pass the exception to log + # (https://github.com/python/cpython/issues/108668). Trio does not have this + # issue because it works around the CPython bug. + logging.getLogger(__name__).exception( + "Error during TLS handshake", exc_info=exc + ) # Only reraise base exceptions and cancellation exceptions if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()): diff --git a/contrib/python/anyio/anyio/to_process.py b/contrib/python/anyio/anyio/to_process.py index 7ba9d44198..1ff06f0b25 100644 --- a/contrib/python/anyio/anyio/to_process.py +++ b/contrib/python/anyio/anyio/to_process.py @@ -5,10 +5,11 @@ import pickle import subprocess import sys from collections import deque +from collections.abc import Callable from importlib.util import module_from_spec, spec_from_file_location -from typing import Callable, TypeVar, cast +from typing import TypeVar, cast -from ._core._eventloop import current_time, get_asynclib, get_cancelled_exc_class +from ._core._eventloop import current_time, get_async_backend, get_cancelled_exc_class from ._core._exceptions import BrokenWorkerProcess from ._core._subprocesses import open_process from ._core._synchronization import CapacityLimiter @@ -17,9 +18,16 @@ from .abc import ByteReceiveStream, ByteSendStream, Process from .lowlevel import RunVar, checkpoint_if_cancelled from .streams.buffered import BufferedByteReceiveStream +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + WORKER_MAX_IDLE_TIME = 300 # 5 minutes T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + _process_pool_workers: RunVar[set[Process]] = RunVar("_process_pool_workers") _process_pool_idle_workers: RunVar[deque[tuple[Process, float]]] = RunVar( "_process_pool_idle_workers" @@ -28,23 +36,24 @@ _default_process_limiter: RunVar[CapacityLimiter] = RunVar("_default_process_lim async def run_sync( - func: Callable[..., T_Retval], - *args: object, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], cancellable: bool = False, limiter: CapacityLimiter | None = None, ) -> T_Retval: """ Call the given function with the given arguments in a worker process. - If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled, - the worker process running it will be abruptly terminated using SIGKILL (or - ``terminateProcess()`` on Windows). + If the ``cancellable`` option is enabled and the task waiting for its completion is + cancelled, the worker process running it will be abruptly terminated using SIGKILL + (or ``terminateProcess()`` on Windows). :param func: a callable :param args: positional arguments for the callable - :param cancellable: ``True`` to allow cancellation of the operation while it's running - :param limiter: capacity limiter to use to limit the total amount of processes running - (if omitted, the default limiter is used) + :param cancellable: ``True`` to allow cancellation of the operation while it's + running + :param limiter: capacity limiter to use to limit the total amount of processes + running (if omitted, the default limiter is used) :return: an awaitable that yields the return value of the function. """ @@ -94,11 +103,11 @@ async def run_sync( idle_workers = deque() _process_pool_workers.set(workers) _process_pool_idle_workers.set(idle_workers) - get_asynclib().setup_process_pool_exit_at_shutdown(workers) + get_async_backend().setup_process_pool_exit_at_shutdown(workers) - async with (limiter or current_default_process_limiter()): - # Pop processes from the pool (starting from the most recently used) until we find one that - # hasn't exited yet + async with limiter or current_default_process_limiter(): + # Pop processes from the pool (starting from the most recently used) until we + # find one that hasn't exited yet process: Process while idle_workers: process, idle_since = idle_workers.pop() @@ -108,22 +117,22 @@ async def run_sync( cast(ByteReceiveStream, process.stdout) ) - # Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME seconds or - # longer + # Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME + # seconds or longer now = current_time() killed_processes: list[Process] = [] while idle_workers: if now - idle_workers[0][1] < WORKER_MAX_IDLE_TIME: break - process, idle_since = idle_workers.popleft() - process.kill() - workers.remove(process) - killed_processes.append(process) + process_to_kill, idle_since = idle_workers.popleft() + process_to_kill.kill() + workers.remove(process_to_kill) + killed_processes.append(process_to_kill) with CancelScope(shield=True): - for process in killed_processes: - await process.aclose() + for killed_process in killed_processes: + await killed_process.aclose() break @@ -172,7 +181,8 @@ async def run_sync( def current_default_process_limiter() -> CapacityLimiter: """ - Return the capacity limiter that is used by default to limit the number of worker processes. + Return the capacity limiter that is used by default to limit the number of worker + processes. :return: a capacity limiter object @@ -214,8 +224,8 @@ def process_worker() -> None: sys.path, main_module_path = args del sys.modules["__main__"] if main_module_path: - # Load the parent's main module but as __mp_main__ instead of __main__ - # (like multiprocessing does) to avoid infinite recursion + # Load the parent's main module but as __mp_main__ instead of + # __main__ (like multiprocessing does) to avoid infinite recursion try: spec = spec_from_file_location("__mp_main__", main_module_path) if spec and spec.loader: diff --git a/contrib/python/anyio/anyio/to_thread.py b/contrib/python/anyio/anyio/to_thread.py index 9315d1ecf1..5070516eb5 100644 --- a/contrib/python/anyio/anyio/to_thread.py +++ b/contrib/python/anyio/anyio/to_thread.py @@ -1,67 +1,69 @@ from __future__ import annotations -from typing import Callable, TypeVar +import sys +from collections.abc import Callable +from typing import TypeVar from warnings import warn -from ._core._eventloop import get_asynclib +from ._core._eventloop import get_async_backend from .abc import CapacityLimiter +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") async def run_sync( - func: Callable[..., T_Retval], - *args: object, - cancellable: bool = False, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], + abandon_on_cancel: bool = False, + cancellable: bool | None = None, limiter: CapacityLimiter | None = None, ) -> T_Retval: """ Call the given function with the given arguments in a worker thread. - If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled, - the thread will still run its course but its return value (or any raised exception) will be - ignored. + If the ``cancellable`` option is enabled and the task waiting for its completion is + cancelled, the thread will still run its course but its return value (or any raised + exception) will be ignored. :param func: a callable :param args: positional arguments for the callable - :param cancellable: ``True`` to allow cancellation of the operation + :param abandon_on_cancel: ``True`` to abandon the thread (leaving it to run + unchecked on own) if the host task is cancelled, ``False`` to ignore + cancellations in the host task until the operation has completed in the worker + thread + :param cancellable: deprecated alias of ``abandon_on_cancel``; will override + ``abandon_on_cancel`` if both parameters are passed :param limiter: capacity limiter to use to limit the total amount of threads running (if omitted, the default limiter is used) :return: an awaitable that yields the return value of the function. """ - return await get_asynclib().run_sync_in_worker_thread( - func, *args, cancellable=cancellable, limiter=limiter + if cancellable is not None: + abandon_on_cancel = cancellable + warn( + "The `cancellable=` keyword argument to `anyio.to_thread.run_sync` is " + "deprecated since AnyIO 4.1.0; use `abandon_on_cancel=` instead", + DeprecationWarning, + stacklevel=2, + ) + + return await get_async_backend().run_sync_in_worker_thread( + func, args, abandon_on_cancel=abandon_on_cancel, limiter=limiter ) -async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], - *args: object, - cancellable: bool = False, - limiter: CapacityLimiter | None = None, -) -> T_Retval: - warn( - "run_sync_in_worker_thread() has been deprecated, use anyio.to_thread.run_sync() instead", - DeprecationWarning, - ) - return await run_sync(func, *args, cancellable=cancellable, limiter=limiter) - - def current_default_thread_limiter() -> CapacityLimiter: """ - Return the capacity limiter that is used by default to limit the number of concurrent threads. + Return the capacity limiter that is used by default to limit the number of + concurrent threads. :return: a capacity limiter object """ - return get_asynclib().current_default_thread_limiter() - - -def current_default_worker_thread_limiter() -> CapacityLimiter: - warn( - "current_default_worker_thread_limiter() has been deprecated, " - "use anyio.to_thread.current_default_thread_limiter() instead", - DeprecationWarning, - ) - return current_default_thread_limiter() + return get_async_backend().current_default_thread_limiter() diff --git a/contrib/python/anyio/ya.make b/contrib/python/anyio/ya.make index f8534a7d6c..7d21184567 100644 --- a/contrib/python/anyio/ya.make +++ b/contrib/python/anyio/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(3.7.1) +VERSION(4.3.0) LICENSE(MIT) @@ -25,7 +25,6 @@ PY_SRCS( anyio/_backends/_asyncio.py anyio/_backends/_trio.py anyio/_core/__init__.py - anyio/_core/_compat.py anyio/_core/_eventloop.py anyio/_core/_exceptions.py anyio/_core/_fileio.py @@ -39,6 +38,7 @@ PY_SRCS( anyio/_core/_testing.py anyio/_core/_typedattr.py anyio/abc/__init__.py + anyio/abc/_eventloop.py anyio/abc/_resources.py anyio/abc/_sockets.py anyio/abc/_streams.py |