author     robot-piglet <robot-piglet@yandex-team.com>  2024-05-25 22:39:06 +0300
committer  robot-piglet <robot-piglet@yandex-team.com>  2024-05-25 22:48:55 +0300
commit     7ddcd63286aa8b7a10c462f4d29507790855d43b (patch)
tree       ff8411bc6137aa568453f070adb7a508072841fd
parent     93dedde5f347af8bf358ce6a131662ed784220f4 (diff)
download   ydb-7ddcd63286aa8b7a10c462f4d29507790855d43b.tar.gz
Intermediate changes
-rw-r--r--  contrib/python/anyio/.dist-info/METADATA                37
-rw-r--r--  contrib/python/anyio/anyio/__init__.py                  215
-rw-r--r--  contrib/python/anyio/anyio/_backends/_asyncio.py        1903
-rw-r--r--  contrib/python/anyio/anyio/_backends/_trio.py           841
-rw-r--r--  contrib/python/anyio/anyio/_core/_compat.py             217
-rw-r--r--  contrib/python/anyio/anyio/_core/_eventloop.py          74
-rw-r--r--  contrib/python/anyio/anyio/_core/_exceptions.py         51
-rw-r--r--  contrib/python/anyio/anyio/_core/_fileio.py             148
-rw-r--r--  contrib/python/anyio/anyio/_core/_signals.py            25
-rw-r--r--  contrib/python/anyio/anyio/_core/_sockets.py            265
-rw-r--r--  contrib/python/anyio/anyio/_core/_streams.py            59
-rw-r--r--  contrib/python/anyio/anyio/_core/_subprocesses.py       113
-rw-r--r--  contrib/python/anyio/anyio/_core/_synchronization.py    249
-rw-r--r--  contrib/python/anyio/anyio/_core/_tasks.py              112
-rw-r--r--  contrib/python/anyio/anyio/_core/_testing.py            22
-rw-r--r--  contrib/python/anyio/anyio/_core/_typedattr.py          28
-rw-r--r--  contrib/python/anyio/anyio/abc/__init__.py              119
-rw-r--r--  contrib/python/anyio/anyio/abc/_eventloop.py            392
-rw-r--r--  contrib/python/anyio/anyio/abc/_resources.py            4
-rw-r--r--  contrib/python/anyio/anyio/abc/_sockets.py              70
-rw-r--r--  contrib/python/anyio/anyio/abc/_streams.py              54
-rw-r--r--  contrib/python/anyio/anyio/abc/_subprocesses.py         4
-rw-r--r--  contrib/python/anyio/anyio/abc/_tasks.py                44
-rw-r--r--  contrib/python/anyio/anyio/abc/_testing.py              16
-rw-r--r--  contrib/python/anyio/anyio/from_thread.py               240
-rw-r--r--  contrib/python/anyio/anyio/lowlevel.py                  45
-rw-r--r--  contrib/python/anyio/anyio/pytest_plugin.py             53
-rw-r--r--  contrib/python/anyio/anyio/streams/buffered.py          11
-rw-r--r--  contrib/python/anyio/anyio/streams/file.py              7
-rw-r--r--  contrib/python/anyio/anyio/streams/memory.py            48
-rw-r--r--  contrib/python/anyio/anyio/streams/stapled.py           19
-rw-r--r--  contrib/python/anyio/anyio/streams/text.py              44
-rw-r--r--  contrib/python/anyio/anyio/streams/tls.py               42
-rw-r--r--  contrib/python/anyio/anyio/to_process.py                60
-rw-r--r--  contrib/python/anyio/anyio/to_thread.py                 72
-rw-r--r--  contrib/python/anyio/ya.make                            4
36 files changed, 3240 insertions, 2467 deletions
diff --git a/contrib/python/anyio/.dist-info/METADATA b/contrib/python/anyio/.dist-info/METADATA
index 5e46476e02..e02715ca28 100644
--- a/contrib/python/anyio/.dist-info/METADATA
+++ b/contrib/python/anyio/.dist-info/METADATA
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: anyio
-Version: 3.7.1
+Version: 4.3.0
Summary: High level compatibility layer for multiple asynchronous event loop implementations
Author-email: Alex Grönholm <alex.gronholm@nextday.fi>
License: MIT
@@ -15,36 +15,35 @@ Classifier: Framework :: AnyIO
Classifier: Typing :: Typed
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
-Requires-Python: >=3.7
+Classifier: Programming Language :: Python :: 3.12
+Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE
-Requires-Dist: idna (>=2.8)
-Requires-Dist: sniffio (>=1.1)
-Requires-Dist: exceptiongroup ; python_version < "3.11"
-Requires-Dist: typing-extensions ; python_version < "3.8"
+Requires-Dist: idna >=2.8
+Requires-Dist: sniffio >=1.1
+Requires-Dist: exceptiongroup >=1.0.2 ; python_version < "3.11"
+Requires-Dist: typing-extensions >=4.1 ; python_version < "3.11"
Provides-Extra: doc
Requires-Dist: packaging ; extra == 'doc'
-Requires-Dist: Sphinx ; extra == 'doc'
-Requires-Dist: sphinx-rtd-theme (>=1.2.2) ; extra == 'doc'
-Requires-Dist: sphinxcontrib-jquery ; extra == 'doc'
-Requires-Dist: sphinx-autodoc-typehints (>=1.2.0) ; extra == 'doc'
+Requires-Dist: Sphinx >=7 ; extra == 'doc'
+Requires-Dist: sphinx-rtd-theme ; extra == 'doc'
+Requires-Dist: sphinx-autodoc-typehints >=1.2.0 ; extra == 'doc'
Provides-Extra: test
Requires-Dist: anyio[trio] ; extra == 'test'
-Requires-Dist: coverage[toml] (>=4.5) ; extra == 'test'
-Requires-Dist: hypothesis (>=4.0) ; extra == 'test'
-Requires-Dist: psutil (>=5.9) ; extra == 'test'
-Requires-Dist: pytest (>=7.0) ; extra == 'test'
-Requires-Dist: pytest-mock (>=3.6.1) ; extra == 'test'
+Requires-Dist: coverage[toml] >=7 ; extra == 'test'
+Requires-Dist: exceptiongroup >=1.2.0 ; extra == 'test'
+Requires-Dist: hypothesis >=4.0 ; extra == 'test'
+Requires-Dist: psutil >=5.9 ; extra == 'test'
+Requires-Dist: pytest >=7.0 ; extra == 'test'
+Requires-Dist: pytest-mock >=3.6.1 ; extra == 'test'
Requires-Dist: trustme ; extra == 'test'
-Requires-Dist: uvloop (>=0.17) ; (python_version < "3.12" and platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test'
-Requires-Dist: mock (>=4) ; (python_version < "3.8") and extra == 'test'
+Requires-Dist: uvloop >=0.17 ; (platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test'
Provides-Extra: trio
-Requires-Dist: trio (<0.22) ; extra == 'trio'
+Requires-Dist: trio >=0.23 ; extra == 'trio'
.. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg
:target: https://github.com/agronholm/anyio/actions/workflows/test.yml
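
The metadata changes above raise the Python floor from 3.7 to 3.8 and move the vendored copy to the anyio 4.x line, with the trio extra now requiring trio >= 0.23. A hypothetical downstream guard checking both constraints at startup (the assertion messages are illustrative, not part of this change):

    import sys
    from importlib.metadata import version  # stdlib on the new >=3.8 floor

    assert sys.version_info >= (3, 8), "anyio 4.3.0 dropped Python 3.7"
    assert version("anyio").startswith("4."), "expected the anyio 4.x line"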
diff --git a/contrib/python/anyio/anyio/__init__.py b/contrib/python/anyio/anyio/__init__.py
index 29fb3561e4..7bfe231645 100644
--- a/contrib/python/anyio/anyio/__init__.py
+++ b/contrib/python/anyio/anyio/__init__.py
@@ -1,165 +1,72 @@
from __future__ import annotations
-__all__ = (
- "maybe_async",
- "maybe_async_cm",
- "run",
- "sleep",
- "sleep_forever",
- "sleep_until",
- "current_time",
- "get_all_backends",
- "get_cancelled_exc_class",
- "BrokenResourceError",
- "BrokenWorkerProcess",
- "BusyResourceError",
- "ClosedResourceError",
- "DelimiterNotFound",
- "EndOfStream",
- "ExceptionGroup",
- "IncompleteRead",
- "TypedAttributeLookupError",
- "WouldBlock",
- "AsyncFile",
- "Path",
- "open_file",
- "wrap_file",
- "aclose_forcefully",
- "open_signal_receiver",
- "connect_tcp",
- "connect_unix",
- "create_tcp_listener",
- "create_unix_listener",
- "create_udp_socket",
- "create_connected_udp_socket",
- "getaddrinfo",
- "getnameinfo",
- "wait_socket_readable",
- "wait_socket_writable",
- "create_memory_object_stream",
- "run_process",
- "open_process",
- "create_lock",
- "CapacityLimiter",
- "CapacityLimiterStatistics",
- "Condition",
- "ConditionStatistics",
- "Event",
- "EventStatistics",
- "Lock",
- "LockStatistics",
- "Semaphore",
- "SemaphoreStatistics",
- "create_condition",
- "create_event",
- "create_semaphore",
- "create_capacity_limiter",
- "open_cancel_scope",
- "fail_after",
- "move_on_after",
- "current_effective_deadline",
- "TASK_STATUS_IGNORED",
- "CancelScope",
- "create_task_group",
- "TaskInfo",
- "get_current_task",
- "get_running_tasks",
- "wait_all_tasks_blocked",
- "run_sync_in_worker_thread",
- "run_async_from_thread",
- "run_sync_from_thread",
- "current_default_worker_thread_limiter",
- "create_blocking_portal",
- "start_blocking_portal",
- "typed_attribute",
- "TypedAttributeSet",
- "TypedAttributeProvider",
-)
-
from typing import Any
-from ._core._compat import maybe_async, maybe_async_cm
-from ._core._eventloop import (
- current_time,
- get_all_backends,
- get_cancelled_exc_class,
- run,
- sleep,
- sleep_forever,
- sleep_until,
-)
-from ._core._exceptions import (
- BrokenResourceError,
- BrokenWorkerProcess,
- BusyResourceError,
- ClosedResourceError,
- DelimiterNotFound,
- EndOfStream,
- ExceptionGroup,
- IncompleteRead,
- TypedAttributeLookupError,
- WouldBlock,
-)
-from ._core._fileio import AsyncFile, Path, open_file, wrap_file
-from ._core._resources import aclose_forcefully
-from ._core._signals import open_signal_receiver
+from ._core._eventloop import current_time as current_time
+from ._core._eventloop import get_all_backends as get_all_backends
+from ._core._eventloop import get_cancelled_exc_class as get_cancelled_exc_class
+from ._core._eventloop import run as run
+from ._core._eventloop import sleep as sleep
+from ._core._eventloop import sleep_forever as sleep_forever
+from ._core._eventloop import sleep_until as sleep_until
+from ._core._exceptions import BrokenResourceError as BrokenResourceError
+from ._core._exceptions import BrokenWorkerProcess as BrokenWorkerProcess
+from ._core._exceptions import BusyResourceError as BusyResourceError
+from ._core._exceptions import ClosedResourceError as ClosedResourceError
+from ._core._exceptions import DelimiterNotFound as DelimiterNotFound
+from ._core._exceptions import EndOfStream as EndOfStream
+from ._core._exceptions import IncompleteRead as IncompleteRead
+from ._core._exceptions import TypedAttributeLookupError as TypedAttributeLookupError
+from ._core._exceptions import WouldBlock as WouldBlock
+from ._core._fileio import AsyncFile as AsyncFile
+from ._core._fileio import Path as Path
+from ._core._fileio import open_file as open_file
+from ._core._fileio import wrap_file as wrap_file
+from ._core._resources import aclose_forcefully as aclose_forcefully
+from ._core._signals import open_signal_receiver as open_signal_receiver
+from ._core._sockets import connect_tcp as connect_tcp
+from ._core._sockets import connect_unix as connect_unix
+from ._core._sockets import create_connected_udp_socket as create_connected_udp_socket
from ._core._sockets import (
- connect_tcp,
- connect_unix,
- create_connected_udp_socket,
- create_tcp_listener,
- create_udp_socket,
- create_unix_listener,
- getaddrinfo,
- getnameinfo,
- wait_socket_readable,
- wait_socket_writable,
+ create_connected_unix_datagram_socket as create_connected_unix_datagram_socket,
)
-from ._core._streams import create_memory_object_stream
-from ._core._subprocesses import open_process, run_process
+from ._core._sockets import create_tcp_listener as create_tcp_listener
+from ._core._sockets import create_udp_socket as create_udp_socket
+from ._core._sockets import create_unix_datagram_socket as create_unix_datagram_socket
+from ._core._sockets import create_unix_listener as create_unix_listener
+from ._core._sockets import getaddrinfo as getaddrinfo
+from ._core._sockets import getnameinfo as getnameinfo
+from ._core._sockets import wait_socket_readable as wait_socket_readable
+from ._core._sockets import wait_socket_writable as wait_socket_writable
+from ._core._streams import create_memory_object_stream as create_memory_object_stream
+from ._core._subprocesses import open_process as open_process
+from ._core._subprocesses import run_process as run_process
+from ._core._synchronization import CapacityLimiter as CapacityLimiter
from ._core._synchronization import (
- CapacityLimiter,
- CapacityLimiterStatistics,
- Condition,
- ConditionStatistics,
- Event,
- EventStatistics,
- Lock,
- LockStatistics,
- Semaphore,
- SemaphoreStatistics,
- create_capacity_limiter,
- create_condition,
- create_event,
- create_lock,
- create_semaphore,
-)
-from ._core._tasks import (
- TASK_STATUS_IGNORED,
- CancelScope,
- create_task_group,
- current_effective_deadline,
- fail_after,
- move_on_after,
- open_cancel_scope,
-)
-from ._core._testing import (
- TaskInfo,
- get_current_task,
- get_running_tasks,
- wait_all_tasks_blocked,
-)
-from ._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute
-
-# Re-exported here, for backwards compatibility
-# isort: off
-from .to_thread import current_default_worker_thread_limiter, run_sync_in_worker_thread
-from .from_thread import (
- create_blocking_portal,
- run_async_from_thread,
- run_sync_from_thread,
- start_blocking_portal,
+ CapacityLimiterStatistics as CapacityLimiterStatistics,
)
+from ._core._synchronization import Condition as Condition
+from ._core._synchronization import ConditionStatistics as ConditionStatistics
+from ._core._synchronization import Event as Event
+from ._core._synchronization import EventStatistics as EventStatistics
+from ._core._synchronization import Lock as Lock
+from ._core._synchronization import LockStatistics as LockStatistics
+from ._core._synchronization import ResourceGuard as ResourceGuard
+from ._core._synchronization import Semaphore as Semaphore
+from ._core._synchronization import SemaphoreStatistics as SemaphoreStatistics
+from ._core._tasks import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED
+from ._core._tasks import CancelScope as CancelScope
+from ._core._tasks import create_task_group as create_task_group
+from ._core._tasks import current_effective_deadline as current_effective_deadline
+from ._core._tasks import fail_after as fail_after
+from ._core._tasks import move_on_after as move_on_after
+from ._core._testing import TaskInfo as TaskInfo
+from ._core._testing import get_current_task as get_current_task
+from ._core._testing import get_running_tasks as get_running_tasks
+from ._core._testing import wait_all_tasks_blocked as wait_all_tasks_blocked
+from ._core._typedattr import TypedAttributeProvider as TypedAttributeProvider
+from ._core._typedattr import TypedAttributeSet as TypedAttributeSet
+from ._core._typedattr import typed_attribute as typed_attribute
# Re-export imports so they look like they live directly in this package
key: str
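
The rewritten __init__.py drops the __all__ tuple in favor of the redundant-alias idiom: under PEP 484's re-export rules, "from mod import name as name" marks name as an intentionally public re-export, so type checkers running with --no-implicit-reexport keep accepting these symbols. A minimal sketch of the idiom with hypothetical module names:

    # pkg/_impl.py
    def helper() -> str:
        return "ok"

    # pkg/__init__.py -- the "as"-alias marks helper as re-exported
    from ._impl import helper as helper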
diff --git a/contrib/python/anyio/anyio/_backends/_asyncio.py b/contrib/python/anyio/anyio/_backends/_asyncio.py
index bfdb4ea7e1..2699bf8146 100644
--- a/contrib/python/anyio/anyio/_backends/_asyncio.py
+++ b/contrib/python/anyio/anyio/_backends/_asyncio.py
@@ -6,23 +6,34 @@ import concurrent.futures
import math
import socket
import sys
+import threading
+from asyncio import (
+ AbstractEventLoop,
+ CancelledError,
+ all_tasks,
+ create_task,
+ current_task,
+ get_running_loop,
+ sleep,
+)
from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
from collections import OrderedDict, deque
+from collections.abc import AsyncIterator, Generator, Iterable
from concurrent.futures import Future
+from contextlib import suppress
from contextvars import Context, copy_context
from dataclasses import dataclass
from functools import partial, wraps
from inspect import (
CORO_RUNNING,
CORO_SUSPENDED,
- GEN_RUNNING,
- GEN_SUSPENDED,
getcoroutinestate,
- getgeneratorstate,
+ iscoroutine,
)
from io import IOBase
from os import PathLike
from queue import Queue
+from signal import Signals
from socket import AddressFamily, SocketKind
from threading import Thread
from types import TracebackType
@@ -33,15 +44,13 @@ from typing import (
Awaitable,
Callable,
Collection,
+ ContextManager,
Coroutine,
- Generator,
- Iterable,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
- Union,
cast,
)
from weakref import WeakKeyDictionary
@@ -49,7 +58,6 @@ from weakref import WeakKeyDictionary
import sniffio
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
-from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable
from .._core._eventloop import claim_worker_thread, threadlocals
from .._core._exceptions import (
BrokenResourceError,
@@ -58,40 +66,220 @@ from .._core._exceptions import (
EndOfStream,
WouldBlock,
)
-from .._core._exceptions import ExceptionGroup as BaseExceptionGroup
-from .._core._sockets import GetAddrInfoReturnType, convert_ipv6_sockaddr
+from .._core._sockets import convert_ipv6_sockaddr
+from .._core._streams import create_memory_object_stream
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
from .._core._synchronization import Event as BaseEvent
from .._core._synchronization import ResourceGuard
from .._core._tasks import CancelScope as BaseCancelScope
-from ..abc import IPSockAddrType, UDPPacketType
+from ..abc import (
+ AsyncBackend,
+ IPSockAddrType,
+ SocketListener,
+ UDPPacketType,
+ UNIXDatagramPacketType,
+)
from ..lowlevel import RunVar
+from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
-if sys.version_info >= (3, 8):
-
- def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
- return task.get_coro()
+if sys.version_info >= (3, 10):
+ from typing import ParamSpec
+else:
+ from typing_extensions import ParamSpec
+if sys.version_info >= (3, 11):
+ from asyncio import Runner
+ from typing import TypeVarTuple, Unpack
else:
+ import contextvars
+ import enum
+ import signal
+ from asyncio import coroutines, events, exceptions, tasks
+
+ from exceptiongroup import BaseExceptionGroup
+ from typing_extensions import TypeVarTuple, Unpack
+
+ class _State(enum.Enum):
+ CREATED = "created"
+ INITIALIZED = "initialized"
+ CLOSED = "closed"
+
+ class Runner:
+ # Copied from CPython 3.11
+ def __init__(
+ self,
+ *,
+ debug: bool | None = None,
+ loop_factory: Callable[[], AbstractEventLoop] | None = None,
+ ):
+ self._state = _State.CREATED
+ self._debug = debug
+ self._loop_factory = loop_factory
+ self._loop: AbstractEventLoop | None = None
+ self._context = None
+ self._interrupt_count = 0
+ self._set_event_loop = False
+
+ def __enter__(self) -> Runner:
+ self._lazy_init()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException],
+ exc_val: BaseException,
+ exc_tb: TracebackType,
+ ) -> None:
+ self.close()
+
+ def close(self) -> None:
+ """Shutdown and close event loop."""
+ if self._state is not _State.INITIALIZED:
+ return
+ try:
+ loop = self._loop
+ _cancel_all_tasks(loop)
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ if hasattr(loop, "shutdown_default_executor"):
+ loop.run_until_complete(loop.shutdown_default_executor())
+ else:
+ loop.run_until_complete(_shutdown_default_executor(loop))
+ finally:
+ if self._set_event_loop:
+ events.set_event_loop(None)
+ loop.close()
+ self._loop = None
+ self._state = _State.CLOSED
+
+ def get_loop(self) -> AbstractEventLoop:
+ """Return embedded event loop."""
+ self._lazy_init()
+ return self._loop
+
+ def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
+ """Run a coroutine inside the embedded event loop."""
+ if not coroutines.iscoroutine(coro):
+ raise ValueError(f"a coroutine was expected, got {coro!r}")
+
+ if events._get_running_loop() is not None:
+ # fail fast with short traceback
+ raise RuntimeError(
+ "Runner.run() cannot be called from a running event loop"
+ )
+
+ self._lazy_init()
- def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
- return task._coro
+ if context is None:
+ context = self._context
+ task = context.run(self._loop.create_task, coro)
+ if (
+ threading.current_thread() is threading.main_thread()
+ and signal.getsignal(signal.SIGINT) is signal.default_int_handler
+ ):
+ sigint_handler = partial(self._on_sigint, main_task=task)
+ try:
+ signal.signal(signal.SIGINT, sigint_handler)
+ except ValueError:
+ # `signal.signal` may throw if `threading.main_thread` does
+ # not support signals (e.g. embedded interpreter with signals
+ # not registered - see gh-91880)
+ sigint_handler = None
+ else:
+ sigint_handler = None
-from asyncio import all_tasks, create_task, current_task, get_running_loop
-from asyncio import run as native_run
+ self._interrupt_count = 0
+ try:
+ return self._loop.run_until_complete(task)
+ except exceptions.CancelledError:
+ if self._interrupt_count > 0:
+ uncancel = getattr(task, "uncancel", None)
+ if uncancel is not None and uncancel() == 0:
+ raise KeyboardInterrupt()
+ raise # CancelledError
+ finally:
+ if (
+ sigint_handler is not None
+ and signal.getsignal(signal.SIGINT) is sigint_handler
+ ):
+ signal.signal(signal.SIGINT, signal.default_int_handler)
+ def _lazy_init(self) -> None:
+ if self._state is _State.CLOSED:
+ raise RuntimeError("Runner is closed")
+ if self._state is _State.INITIALIZED:
+ return
+ if self._loop_factory is None:
+ self._loop = events.new_event_loop()
+ if not self._set_event_loop:
+ # Call set_event_loop only once to avoid calling
+ # attach_loop multiple times on child watchers
+ events.set_event_loop(self._loop)
+ self._set_event_loop = True
+ else:
+ self._loop = self._loop_factory()
+ if self._debug is not None:
+ self._loop.set_debug(self._debug)
+ self._context = contextvars.copy_context()
+ self._state = _State.INITIALIZED
+
+ def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
+ self._interrupt_count += 1
+ if self._interrupt_count == 1 and not main_task.done():
+ main_task.cancel()
+ # wakeup loop if it is blocked by select() with long timeout
+ self._loop.call_soon_threadsafe(lambda: None)
+ return
+ raise KeyboardInterrupt()
-def _get_task_callbacks(task: asyncio.Task) -> Iterable[Callable]:
- return [cb for cb, context in task._callbacks]
+ def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
+ to_cancel = tasks.all_tasks(loop)
+ if not to_cancel:
+ return
+ for task in to_cancel:
+ task.cancel()
-T_Retval = TypeVar("T_Retval")
-T_contra = TypeVar("T_contra", contravariant=True)
+ loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
-# Check whether there is native support for task names in asyncio (3.8+)
-_native_task_names = hasattr(asyncio.Task, "get_name")
+ for task in to_cancel:
+ if task.cancelled():
+ continue
+ if task.exception() is not None:
+ loop.call_exception_handler(
+ {
+ "message": "unhandled exception during asyncio.run() shutdown",
+ "exception": task.exception(),
+ "task": task,
+ }
+ )
+ async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
+ """Schedule the shutdown of the default executor."""
+
+ def _do_shutdown(future: asyncio.futures.Future) -> None:
+ try:
+ loop._default_executor.shutdown(wait=True) # type: ignore[attr-defined]
+ loop.call_soon_threadsafe(future.set_result, None)
+ except Exception as ex:
+ loop.call_soon_threadsafe(future.set_exception, ex)
+
+ loop._executor_shutdown_called = True
+ if loop._default_executor is None:
+ return
+ future = loop.create_future()
+ thread = threading.Thread(target=_do_shutdown, args=(future,))
+ thread.start()
+ try:
+ await future
+ finally:
+ thread.join()
+
+
+T_Retval = TypeVar("T_Retval")
+T_contra = TypeVar("T_contra", contravariant=True)
+PosArgsT = TypeVarTuple("PosArgsT")
+P = ParamSpec("P")
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
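
The block above vendors the Runner class from CPython 3.11's asyncio so that older interpreters get the same event-loop lifecycle; on 3.11+ the real asyncio.Runner is imported instead. A short usage sketch against the stdlib equivalent:

    import asyncio

    async def main() -> str:
        await asyncio.sleep(0)
        return "done"

    # Same driving pattern as the vendored Runner above (3.11+ stdlib shown)
    with asyncio.Runner(debug=True) as runner:
        print(runner.run(main()))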
@@ -104,7 +292,8 @@ def find_root_task() -> asyncio.Task:
# Look for a task that has been started via run_until_complete()
for task in all_tasks():
if task._callbacks and not task.done():
- for cb in _get_task_callbacks(task):
+ callbacks = [cb for cb, context in task._callbacks]
+ for cb in callbacks:
if (
cb is _run_until_complete_cb
or getattr(cb, "__module__", None) == "uvloop.loop"
@@ -136,87 +325,22 @@ def get_callable_name(func: Callable) -> str:
# Event loop
#
-_run_vars = (
- WeakKeyDictionary()
-) # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any]
-
-current_token = get_running_loop
+_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
def _task_started(task: asyncio.Task) -> bool:
"""Return ``True`` if the task has been started and has not finished."""
- coro = cast(Coroutine[Any, Any, Any], get_coro(task))
try:
- return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
+ return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
except AttributeError:
- try:
- return getgeneratorstate(cast(Generator, coro)) in (
- GEN_RUNNING,
- GEN_SUSPENDED,
- )
- except AttributeError:
- # task coro is async_generator_asend https://bugs.python.org/issue37771
- raise Exception(f"Cannot determine if task {task} has started or not")
-
-
-def _maybe_set_event_loop_policy(
- policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool
-) -> None:
- # On CPython, use uvloop when possible if no other policy has been given and if not
- # explicitly disabled
- if policy is None and use_uvloop and sys.implementation.name == "cpython":
- try:
- import uvloop
- except ImportError:
- pass
- else:
- # Test for missing shutdown_default_executor() (uvloop 0.14.0 and earlier)
- if not hasattr(
- asyncio.AbstractEventLoop, "shutdown_default_executor"
- ) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"):
- policy = uvloop.EventLoopPolicy()
-
- if policy is not None:
- asyncio.set_event_loop_policy(policy)
-
-
-def run(
- func: Callable[..., Awaitable[T_Retval]],
- *args: object,
- debug: bool = False,
- use_uvloop: bool = False,
- policy: asyncio.AbstractEventLoopPolicy | None = None,
-) -> T_Retval:
- @wraps(func)
- async def wrapper() -> T_Retval:
- task = cast(asyncio.Task, current_task())
- task_state = TaskState(None, get_callable_name(func), None)
- _task_states[task] = task_state
- if _native_task_names:
- task.set_name(task_state.name)
-
- try:
- return await func(*args)
- finally:
- del _task_states[task]
-
- _maybe_set_event_loop_policy(policy, use_uvloop)
- return native_run(wrapper(), debug=debug)
-
-
-#
-# Miscellaneous
-#
-
-sleep = asyncio.sleep
+ # task coro is async_generator_asend https://bugs.python.org/issue37771
+ raise Exception(f"Cannot determine if task {task} has started or not") from None
#
# Timeouts and cancellation
#
-CancelledError = asyncio.CancelledError
-
class CancelScope(BaseCancelScope):
def __new__(
@@ -228,14 +352,16 @@ class CancelScope(BaseCancelScope):
self._deadline = deadline
self._shield = shield
self._parent_scope: CancelScope | None = None
+ self._child_scopes: set[CancelScope] = set()
self._cancel_called = False
+ self._cancelled_caught = False
self._active = False
self._timeout_handle: asyncio.TimerHandle | None = None
self._cancel_handle: asyncio.Handle | None = None
self._tasks: set[asyncio.Task] = set()
self._host_task: asyncio.Task | None = None
- self._timeout_expired = False
self._cancel_calls: int = 0
+ self._cancelling: int | None = None
def __enter__(self) -> CancelScope:
if self._active:
@@ -248,19 +374,23 @@ class CancelScope(BaseCancelScope):
try:
task_state = _task_states[host_task]
except KeyError:
- task_name = host_task.get_name() if _native_task_names else None
- task_state = TaskState(None, task_name, self)
+ task_state = TaskState(None, self)
_task_states[host_task] = task_state
else:
self._parent_scope = task_state.cancel_scope
task_state.cancel_scope = self
+ if self._parent_scope is not None:
+ self._parent_scope._child_scopes.add(self)
+ self._parent_scope._tasks.remove(host_task)
self._timeout()
self._active = True
+ if sys.version_info >= (3, 11):
+ self._cancelling = self._host_task.cancelling()
# Start cancelling the host task if the scope was cancelled before entering
if self._cancel_called:
- self._deliver_cancellation()
+ self._deliver_cancellation(self)
return self
@@ -292,56 +422,60 @@ class CancelScope(BaseCancelScope):
self._timeout_handle = None
self._tasks.remove(self._host_task)
+ if self._parent_scope is not None:
+ self._parent_scope._child_scopes.remove(self)
+ self._parent_scope._tasks.add(self._host_task)
host_task_state.cancel_scope = self._parent_scope
- # Restart the cancellation effort in the farthest directly cancelled parent scope if this
- # one was shielded
- if self._shield:
- self._deliver_cancellation_to_parent()
+ # Restart the cancellation effort in the closest directly cancelled parent
+ # scope if this one was shielded
+ self._restart_cancellation_in_parent()
- if exc_val is not None:
- exceptions = (
- exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val]
- )
- if all(isinstance(exc, CancelledError) for exc in exceptions):
- if self._timeout_expired:
- return self._uncancel()
- elif not self._cancel_called:
- # Task was cancelled natively
- return None
- elif not self._parent_cancelled():
- # This scope was directly cancelled
- return self._uncancel()
+ if self._cancel_called and exc_val is not None:
+ for exc in iterate_exceptions(exc_val):
+ if isinstance(exc, CancelledError):
+ self._cancelled_caught = self._uncancel(exc)
+ if self._cancelled_caught:
+ break
+
+ return self._cancelled_caught
return None
- def _uncancel(self) -> bool:
- if sys.version_info < (3, 11) or self._host_task is None:
+ def _uncancel(self, cancelled_exc: CancelledError) -> bool:
+ if sys.version_info < (3, 9) or self._host_task is None:
self._cancel_calls = 0
return True
- # Uncancel all AnyIO cancellations
- for i in range(self._cancel_calls):
- self._host_task.uncancel()
+ # Undo all cancellations done by this scope
+ if self._cancelling is not None:
+ while self._cancel_calls:
+ self._cancel_calls -= 1
+ if self._host_task.uncancel() <= self._cancelling:
+ return True
self._cancel_calls = 0
- return not self._host_task.cancelling()
+ return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args
def _timeout(self) -> None:
if self._deadline != math.inf:
loop = get_running_loop()
if loop.time() >= self._deadline:
- self._timeout_expired = True
self.cancel()
else:
self._timeout_handle = loop.call_at(self._deadline, self._timeout)
- def _deliver_cancellation(self) -> None:
+ def _deliver_cancellation(self, origin: CancelScope) -> bool:
"""
Deliver cancellation to directly contained tasks and nested cancel scopes.
- Schedule another run at the end if we still have tasks eligible for cancellation.
+ Schedule another run at the end if we still have tasks eligible for
+ cancellation.
+
+ :param origin: the cancel scope that originated the cancellation
+ :return: ``True`` if the delivery needs to be retried on the next cycle
+
"""
should_retry = False
current = current_task()
@@ -349,37 +483,46 @@ class CancelScope(BaseCancelScope):
if task._must_cancel: # type: ignore[attr-defined]
continue
- # The task is eligible for cancellation if it has started and is not in a cancel
- # scope shielded from this one
- cancel_scope = _task_states[task].cancel_scope
- while cancel_scope is not self:
- if cancel_scope is None or cancel_scope._shield:
- break
- else:
- cancel_scope = cancel_scope._parent_scope
- else:
- should_retry = True
- if task is not current and (
- task is self._host_task or _task_started(task)
- ):
+ # The task is eligible for cancellation if it has started
+ should_retry = True
+ if task is not current and (task is self._host_task or _task_started(task)):
+ waiter = task._fut_waiter # type: ignore[attr-defined]
+ if not isinstance(waiter, asyncio.Future) or not waiter.done():
self._cancel_calls += 1
- task.cancel()
+ if sys.version_info >= (3, 9):
+ task.cancel(f"Cancelled by cancel scope {id(origin):x}")
+ else:
+ task.cancel()
+
+ # Deliver cancellation to child scopes that aren't shielded or running their own
+ # cancellation callbacks
+ for scope in self._child_scopes:
+ if not scope._shield and not scope.cancel_called:
+ should_retry = scope._deliver_cancellation(origin) or should_retry
# Schedule another callback if there are still tasks left
- if should_retry:
- self._cancel_handle = get_running_loop().call_soon(
- self._deliver_cancellation
- )
- else:
- self._cancel_handle = None
+ if origin is self:
+ if should_retry:
+ self._cancel_handle = get_running_loop().call_soon(
+ self._deliver_cancellation, origin
+ )
+ else:
+ self._cancel_handle = None
- def _deliver_cancellation_to_parent(self) -> None:
- """Start cancellation effort in the farthest directly cancelled parent scope"""
+ return should_retry
+
+ def _restart_cancellation_in_parent(self) -> None:
+ """
+ Restart the cancellation effort in the closest directly cancelled parent scope.
+
+ """
scope = self._parent_scope
- scope_to_cancel: CancelScope | None = None
while scope is not None:
- if scope._cancel_called and scope._cancel_handle is None:
- scope_to_cancel = scope
+ if scope._cancel_called:
+ if scope._cancel_handle is None:
+ scope._deliver_cancellation(scope)
+
+ break
# No point in looking beyond any shielded scope
if scope._shield:
@@ -387,9 +530,6 @@ class CancelScope(BaseCancelScope):
scope = scope._parent_scope
- if scope_to_cancel is not None:
- scope_to_cancel._deliver_cancellation()
-
def _parent_cancelled(self) -> bool:
# Check whether any parent has been cancelled
cancel_scope = self._parent_scope
@@ -401,7 +541,7 @@ class CancelScope(BaseCancelScope):
return False
- def cancel(self) -> DeprecatedAwaitable:
+ def cancel(self) -> None:
if not self._cancel_called:
if self._timeout_handle:
self._timeout_handle.cancel()
@@ -409,9 +549,7 @@ class CancelScope(BaseCancelScope):
self._cancel_called = True
if self._host_task is not None:
- self._deliver_cancellation()
-
- return DeprecatedAwaitable(self.cancel)
+ self._deliver_cancellation(self)
@property
def deadline(self) -> float:
@@ -432,6 +570,10 @@ class CancelScope(BaseCancelScope):
return self._cancel_called
@property
+ def cancelled_caught(self) -> bool:
+ return self._cancelled_caught
+
+ @property
def shield(self) -> bool:
return self._shield
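
The new cancelled_caught property replaces the old _timeout_expired bookkeeping and reports whether the scope absorbed its own cancellation. It is most visible through anyio's public timeout helpers, as in this sketch:

    import anyio

    async def fetch_with_timeout() -> None:
        with anyio.move_on_after(0.1) as scope:
            await anyio.sleep(1)  # cancelled when the deadline expires
        if scope.cancelled_caught:
            print("timed out; using fallback")

    anyio.run(fetch_with_timeout)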
@@ -440,59 +582,7 @@ class CancelScope(BaseCancelScope):
if self._shield != value:
self._shield = value
if not value:
- self._deliver_cancellation_to_parent()
-
-
-async def checkpoint() -> None:
- await sleep(0)
-
-
-async def checkpoint_if_cancelled() -> None:
- task = current_task()
- if task is None:
- return
-
- try:
- cancel_scope = _task_states[task].cancel_scope
- except KeyError:
- return
-
- while cancel_scope:
- if cancel_scope.cancel_called:
- await sleep(0)
- elif cancel_scope.shield:
- break
- else:
- cancel_scope = cancel_scope._parent_scope
-
-
-async def cancel_shielded_checkpoint() -> None:
- with CancelScope(shield=True):
- await sleep(0)
-
-
-def current_effective_deadline() -> float:
- try:
- cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index]
- except KeyError:
- return math.inf
-
- deadline = math.inf
- while cancel_scope:
- deadline = min(deadline, cancel_scope.deadline)
- if cancel_scope._cancel_called:
- deadline = -math.inf
- break
- elif cancel_scope.shield:
- break
- else:
- cancel_scope = cancel_scope._parent_scope
-
- return deadline
-
-
-def current_time() -> float:
- return get_running_loop().time()
+ self._restart_cancellation_in_parent()
#
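
The module-level checkpoint(), current_effective_deadline() and current_time() helpers removed above now live on the backend class; application code reaches the same functionality through anyio's public API, e.g.:

    import anyio
    import anyio.lowlevel

    async def demo() -> None:
        await anyio.lowlevel.checkpoint()  # cooperative yield to the scheduler
        print("loop time:", anyio.current_time())
        print("deadline:", anyio.current_effective_deadline())

    anyio.run(demo)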
@@ -502,20 +592,14 @@ def current_time() -> float:
class TaskState:
"""
- Encapsulates auxiliary task information that cannot be added to the Task instance itself
- because there are no guarantees about its implementation.
+ Encapsulates auxiliary task information that cannot be added to the Task instance
+ itself because there are no guarantees about its implementation.
"""
- __slots__ = "parent_id", "name", "cancel_scope"
+ __slots__ = "parent_id", "cancel_scope"
- def __init__(
- self,
- parent_id: int | None,
- name: str | None,
- cancel_scope: CancelScope | None,
- ):
+ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
self.parent_id = parent_id
- self.name = name
self.cancel_scope = cancel_scope
@@ -527,12 +611,6 @@ _task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, Task
#
-class ExceptionGroup(BaseExceptionGroup):
- def __init__(self, exceptions: list[BaseException]):
- super().__init__()
- self.exceptions = exceptions
-
-
class _AsyncioTaskStatus(abc.TaskStatus):
def __init__(self, future: asyncio.Future, parent_id: int):
self._future = future
@@ -550,11 +628,22 @@ class _AsyncioTaskStatus(abc.TaskStatus):
_task_states[task].parent_id = self._parent_id
+def iterate_exceptions(
+ exception: BaseException,
+) -> Generator[BaseException, None, None]:
+ if isinstance(exception, BaseExceptionGroup):
+ for exc in exception.exceptions:
+ yield from iterate_exceptions(exc)
+ else:
+ yield exception
+
+
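
iterate_exceptions() walks arbitrarily nested exception groups and yields the leaf exceptions; the CancelScope.__exit__ logic above uses it to find the CancelledError belonging to this scope. A standalone illustration in this module's context (on Python < 3.11, BaseExceptionGroup comes from the exceptiongroup backport declared in the metadata):

    import sys

    if sys.version_info < (3, 11):
        from exceptiongroup import BaseExceptionGroup

    group = BaseExceptionGroup(
        "outer", [ValueError("a"), BaseExceptionGroup("inner", [KeyError("b")])]
    )
    print([type(exc).__name__ for exc in iterate_exceptions(group)])
    # -> ['ValueError', 'KeyError']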
class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
self._active = False
self._exceptions: list[BaseException] = []
+ self._tasks: set[asyncio.Task] = set()
async def __aenter__(self) -> TaskGroup:
self.cancel_scope.__enter__()
@@ -570,98 +659,49 @@ class TaskGroup(abc.TaskGroup):
ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
if exc_val is not None:
self.cancel_scope.cancel()
- self._exceptions.append(exc_val)
+ if not isinstance(exc_val, CancelledError):
+ self._exceptions.append(exc_val)
- while self.cancel_scope._tasks:
+ cancelled_exc_while_waiting_tasks: CancelledError | None = None
+ while self._tasks:
try:
- await asyncio.wait(self.cancel_scope._tasks)
- except asyncio.CancelledError:
+ await asyncio.wait(self._tasks)
+ except CancelledError as exc:
+ # This task was cancelled natively; reraise the CancelledError later
+ # unless this task was already interrupted by another exception
self.cancel_scope.cancel()
+ if cancelled_exc_while_waiting_tasks is None:
+ cancelled_exc_while_waiting_tasks = exc
self._active = False
- if not self.cancel_scope._parent_cancelled():
- exceptions = self._filter_cancellation_errors(self._exceptions)
- else:
- exceptions = self._exceptions
+ if self._exceptions:
+ raise BaseExceptionGroup(
+ "unhandled errors in a TaskGroup", self._exceptions
+ )
- try:
- if len(exceptions) > 1:
- if all(
- isinstance(e, CancelledError) and not e.args for e in exceptions
- ):
- # Tasks were cancelled natively, without a cancellation message
- raise CancelledError
- else:
- raise ExceptionGroup(exceptions)
- elif exceptions and exceptions[0] is not exc_val:
- raise exceptions[0]
- except BaseException as exc:
- # Clear the context here, as it can only be done in-flight.
- # If the context is not cleared, it can result in recursive tracebacks (see #145).
- exc.__context__ = None
- raise
+ # Raise the CancelledError received while waiting for child tasks to exit,
+ # unless the context manager itself was previously exited with another
+ # exception, or if any of the child tasks raised an exception other than
+ # CancelledError
+ if cancelled_exc_while_waiting_tasks:
+ if exc_val is None or ignore_exception:
+ raise cancelled_exc_while_waiting_tasks
return ignore_exception
- @staticmethod
- def _filter_cancellation_errors(
- exceptions: Sequence[BaseException],
- ) -> list[BaseException]:
- filtered_exceptions: list[BaseException] = []
- for exc in exceptions:
- if isinstance(exc, ExceptionGroup):
- new_exceptions = TaskGroup._filter_cancellation_errors(exc.exceptions)
- if len(new_exceptions) > 1:
- filtered_exceptions.append(exc)
- elif len(new_exceptions) == 1:
- filtered_exceptions.append(new_exceptions[0])
- elif new_exceptions:
- new_exc = ExceptionGroup(new_exceptions)
- new_exc.__cause__ = exc.__cause__
- new_exc.__context__ = exc.__context__
- new_exc.__traceback__ = exc.__traceback__
- filtered_exceptions.append(new_exc)
- elif not isinstance(exc, CancelledError) or exc.args:
- filtered_exceptions.append(exc)
-
- return filtered_exceptions
-
- async def _run_wrapped_task(
- self, coro: Coroutine, task_status_future: asyncio.Future | None
- ) -> None:
- # This is the code path for Python 3.7 on which asyncio freaks out if a task
- # raises a BaseException.
- __traceback_hide__ = __tracebackhide__ = True # noqa: F841
- task = cast(asyncio.Task, current_task())
- try:
- await coro
- except BaseException as exc:
- if task_status_future is None or task_status_future.done():
- self._exceptions.append(exc)
- self.cancel_scope.cancel()
- else:
- task_status_future.set_exception(exc)
- else:
- if task_status_future is not None and not task_status_future.done():
- task_status_future.set_exception(
- RuntimeError("Child exited without calling task_status.started()")
- )
- finally:
- if task in self.cancel_scope._tasks:
- self.cancel_scope._tasks.remove(task)
- del _task_states[task]
-
def _spawn(
self,
- func: Callable[..., Awaitable[Any]],
- args: tuple,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
+ args: tuple[Unpack[PosArgsT]],
name: object,
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
- # This is the code path for Python 3.8+
- assert _task in self.cancel_scope._tasks
- self.cancel_scope._tasks.remove(_task)
+ task_state = _task_states[_task]
+ assert task_state.cancel_scope is not None
+ assert _task in task_state.cancel_scope._tasks
+ task_state.cancel_scope._tasks.remove(_task)
+ self._tasks.remove(task)
del _task_states[_task]
try:
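
With this rework, child-task failures surface as a BaseExceptionGroup titled "unhandled errors in a TaskGroup" instead of anyio 3's bespoke ExceptionGroup. A sketch of catching one with the 3.11+ except* syntax (older interpreters can use exceptiongroup.catch() instead):

    import anyio

    async def boom() -> None:
        raise ValueError("child failed")

    async def main() -> None:
        try:
            async with anyio.create_task_group() as tg:
                tg.start_soon(boom)
        except* ValueError as eg:  # matches inside the raised group
            print("caught:", eg.exceptions)

    anyio.run(main)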
@@ -674,8 +714,11 @@ class TaskGroup(abc.TaskGroup):
if exc is not None:
if task_status_future is None or task_status_future.done():
- self._exceptions.append(exc)
- self.cancel_scope.cancel()
+ if not isinstance(exc, CancelledError):
+ self._exceptions.append(exc)
+
+ if not self.cancel_scope._parent_cancelled():
+ self.cancel_scope.cancel()
else:
task_status_future.set_exception(exc)
elif task_status_future is not None and not task_status_future.done():
@@ -688,11 +731,6 @@ class TaskGroup(abc.TaskGroup):
"This task group is not active; no new tasks can be started."
)
- options: dict[str, Any] = {}
- name = get_callable_name(func) if name is None else str(name)
- if _native_task_names:
- options["name"] = name
-
kwargs = {}
if task_status_future:
parent_id = id(current_task())
@@ -703,46 +741,52 @@ class TaskGroup(abc.TaskGroup):
parent_id = id(self.cancel_scope._host_task)
coro = func(*args, **kwargs)
- if not asyncio.iscoroutine(coro):
+ if not iscoroutine(coro):
+ prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
raise TypeError(
- f"Expected an async function, but {func} appears to be synchronous"
+ f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
+ f"the return value ({coro!r}) is not a coroutine object"
)
- foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame")
- if foreign_coro or sys.version_info < (3, 8):
- coro = self._run_wrapped_task(coro, task_status_future)
-
- task = create_task(coro, **options)
- if not foreign_coro and sys.version_info >= (3, 8):
- task.add_done_callback(task_done)
+ name = get_callable_name(func) if name is None else str(name)
+ task = create_task(coro, name=name)
+ task.add_done_callback(task_done)
# Make the spawned task inherit the task group's cancel scope
_task_states[task] = TaskState(
- parent_id=parent_id, name=name, cancel_scope=self.cancel_scope
+ parent_id=parent_id, cancel_scope=self.cancel_scope
)
self.cancel_scope._tasks.add(task)
+ self._tasks.add(task)
return task
def start_soon(
- self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
+ self,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
+ *args: Unpack[PosArgsT],
+ name: object = None,
) -> None:
self._spawn(func, args, name)
async def start(
self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
- ) -> None:
+ ) -> Any:
future: asyncio.Future = asyncio.Future()
task = self._spawn(func, args, name, future)
- # If the task raises an exception after sending a start value without a switch point
- # between, the task group is cancelled and this method never proceeds to process the
- # completed future. That's why we have to have a shielded cancel scope here.
- with CancelScope(shield=True):
- try:
- return await future
- except CancelledError:
- task.cancel()
- raise
+ # If the task raises an exception after sending a start value without a switch
+ # point between, the task group is cancelled and this method never proceeds to
+ # process the completed future. That's why we have to have a shielded cancel
+ # scope here.
+ try:
+ return await future
+ except CancelledError:
+ # Cancel the task and wait for it to exit before returning
+ task.cancel()
+ with CancelScope(shield=True), suppress(CancelledError):
+ await task
+
+ raise
#
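
TaskGroup.start() now cancels the spawned task and waits for it inside a shielded scope before re-raising, so a cancelled start() can no longer leak a half-started child. Typical usage, where the child reports readiness through task_status:

    import anyio
    from anyio import TASK_STATUS_IGNORED
    from anyio.abc import TaskStatus

    async def server(*, task_status: TaskStatus[int] = TASK_STATUS_IGNORED) -> None:
        task_status.started(8080)  # hand a value (e.g. a bound port) to start()
        await anyio.sleep_forever()

    async def main() -> None:
        async with anyio.create_task_group() as tg:
            port = await tg.start(server)
            print("ready on", port)
            tg.cancel_scope.cancel()

    anyio.run(main)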
@@ -767,15 +811,15 @@ class WorkerThread(Thread):
self.idle_workers = idle_workers
self.loop = root_task._loop
self.queue: Queue[
- tuple[Context, Callable, tuple, asyncio.Future] | None
+ tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
] = Queue(2)
- self.idle_since = current_time()
+ self.idle_since = AsyncIOBackend.current_time()
self.stopping = False
def _report_result(
self, future: asyncio.Future, result: Any, exc: BaseException | None
) -> None:
- self.idle_since = current_time()
+ self.idle_since = AsyncIOBackend.current_time()
if not self.stopping:
self.idle_workers.append(self)
@@ -791,22 +835,24 @@ class WorkerThread(Thread):
future.set_result(result)
def run(self) -> None:
- with claim_worker_thread("asyncio"):
- threadlocals.loop = self.loop
+ with claim_worker_thread(AsyncIOBackend, self.loop):
while True:
item = self.queue.get()
if item is None:
# Shutdown command received
return
- context, func, args, future = item
+ context, func, args, future, cancel_scope = item
if not future.cancelled():
result = None
exception: BaseException | None = None
+ threadlocals.current_cancel_scope = cancel_scope
try:
result = context.run(func, *args)
except BaseException as exc:
exception = exc
+ finally:
+ del threadlocals.current_cancel_scope
if not self.loop.is_closed():
self.loop.call_soon_threadsafe(
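
WorkerThread now receives the caller's cancel scope with each work item and exposes it as threadlocals.current_cancel_scope while the function runs. This machinery sits behind anyio.to_thread.run_sync(); a usage sketch:

    import time

    import anyio
    import anyio.to_thread

    def blocking(seconds: float) -> str:
        time.sleep(seconds)  # ordinary blocking call, kept off the event loop
        return "done"

    async def main() -> None:
        print(await anyio.to_thread.run_sync(blocking, 0.1))

    anyio.run(main)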
@@ -831,81 +877,6 @@ _threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
-async def run_sync_in_worker_thread(
- func: Callable[..., T_Retval],
- *args: object,
- cancellable: bool = False,
- limiter: CapacityLimiter | None = None,
-) -> T_Retval:
- await checkpoint()
-
- # If this is the first run in this event loop thread, set up the necessary variables
- try:
- idle_workers = _threadpool_idle_workers.get()
- workers = _threadpool_workers.get()
- except LookupError:
- idle_workers = deque()
- workers = set()
- _threadpool_idle_workers.set(idle_workers)
- _threadpool_workers.set(workers)
-
- async with (limiter or current_default_thread_limiter()):
- with CancelScope(shield=not cancellable):
- future: asyncio.Future = asyncio.Future()
- root_task = find_root_task()
- if not idle_workers:
- worker = WorkerThread(root_task, workers, idle_workers)
- worker.start()
- workers.add(worker)
- root_task.add_done_callback(worker.stop)
- else:
- worker = idle_workers.pop()
-
- # Prune any other workers that have been idle for MAX_IDLE_TIME seconds or longer
- now = current_time()
- while idle_workers:
- if now - idle_workers[0].idle_since < WorkerThread.MAX_IDLE_TIME:
- break
-
- expired_worker = idle_workers.popleft()
- expired_worker.root_task.remove_done_callback(expired_worker.stop)
- expired_worker.stop()
-
- context = copy_context()
- context.run(sniffio.current_async_library_cvar.set, None)
- worker.queue.put_nowait((context, func, args, future))
- return await future
-
-
-def run_sync_from_thread(
- func: Callable[..., T_Retval],
- *args: object,
- loop: asyncio.AbstractEventLoop | None = None,
-) -> T_Retval:
- @wraps(func)
- def wrapper() -> None:
- try:
- f.set_result(func(*args))
- except BaseException as exc:
- f.set_exception(exc)
- if not isinstance(exc, Exception):
- raise
-
- f: concurrent.futures.Future[T_Retval] = Future()
- loop = loop or threadlocals.loop
- loop.call_soon_threadsafe(wrapper)
- return f.result()
-
-
-def run_async_from_thread(
- func: Callable[..., Awaitable[T_Retval]], *args: object
-) -> T_Retval:
- f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
- func(*args), threadlocals.loop
- )
- return f.result()
-
-
class BlockingPortal(abc.BlockingPortal):
def __new__(cls) -> BlockingPortal:
return object.__new__(cls)
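
The free functions run_sync_in_worker_thread(), run_sync_from_thread() and run_async_from_thread() removed above moved onto the AsyncIOBackend class; user code now goes through anyio.to_thread and anyio.from_thread. A sketch of calling into an event loop from a plain thread via a blocking portal:

    import anyio
    from anyio.from_thread import start_blocking_portal

    async def async_add(a: int, b: int) -> int:
        return a + b

    # Runs an event loop in a background thread and proxies calls into it
    with start_blocking_portal() as portal:
        print(portal.call(async_add, 1, 2))  # -> 3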
@@ -916,20 +887,16 @@ class BlockingPortal(abc.BlockingPortal):
def _spawn_task_from_thread(
self,
- func: Callable,
- args: tuple,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+ args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
name: object,
- future: Future,
+ future: Future[T_Retval],
) -> None:
- run_sync_from_thread(
+ AsyncIOBackend.run_sync_from_thread(
partial(self._task_group.start_soon, name=name),
- self._call_func,
- func,
- args,
- kwargs,
- future,
- loop=self._loop,
+ (self._call_func, func, args, kwargs, future),
+ self._loop,
)
@@ -951,6 +918,7 @@ class StreamReaderWrapper(abc.ByteReceiveStream):
async def aclose(self) -> None:
self._stream.feed_eof()
+ await AsyncIOBackend.checkpoint()
@dataclass(eq=False)
@@ -963,6 +931,7 @@ class StreamWriterWrapper(abc.ByteSendStream):
async def aclose(self) -> None:
self._stream.close()
+ await AsyncIOBackend.checkpoint()
@dataclass(eq=False)
@@ -973,14 +942,22 @@ class Process(abc.Process):
_stderr: StreamReaderWrapper | None
async def aclose(self) -> None:
- if self._stdin:
- await self._stdin.aclose()
- if self._stdout:
- await self._stdout.aclose()
- if self._stderr:
- await self._stderr.aclose()
+ with CancelScope(shield=True):
+ if self._stdin:
+ await self._stdin.aclose()
+ if self._stdout:
+ await self._stdout.aclose()
+ if self._stderr:
+ await self._stderr.aclose()
- await self.wait()
+ try:
+ await self.wait()
+ except BaseException:
+ self.kill()
+ with CancelScope(shield=True):
+ await self.wait()
+
+ raise
async def wait(self) -> int:
return await self._process.wait()
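
Process.aclose() now closes the standard streams inside a shielded scope and, if waiting is interrupted, kills the child and reaps it before re-raising, so cancellation cannot leave a zombie process. This wrapper backs anyio.run_process(); a sketch (POSIX echo assumed):

    import anyio

    async def main() -> None:
        result = await anyio.run_process(["echo", "hello"])
        print(result.returncode, result.stdout)

    anyio.run(main)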
@@ -1015,55 +992,17 @@ class Process(abc.Process):
return self._stderr
-async def open_process(
- command: str | bytes | Sequence[str | bytes],
- *,
- shell: bool,
- stdin: int | IO[Any] | None,
- stdout: int | IO[Any] | None,
- stderr: int | IO[Any] | None,
- cwd: str | bytes | PathLike | None = None,
- env: Mapping[str, str] | None = None,
- start_new_session: bool = False,
-) -> Process:
- await checkpoint()
- if shell:
- process = await asyncio.create_subprocess_shell(
- cast(Union[str, bytes], command),
- stdin=stdin,
- stdout=stdout,
- stderr=stderr,
- cwd=cwd,
- env=env,
- start_new_session=start_new_session,
- )
- else:
- process = await asyncio.create_subprocess_exec(
- *command,
- stdin=stdin,
- stdout=stdout,
- stderr=stderr,
- cwd=cwd,
- env=env,
- start_new_session=start_new_session,
- )
-
- stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
- stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
- stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
- return Process(process, stdin_stream, stdout_stream, stderr_stream)
-
-
def _forcibly_shutdown_process_pool_on_exit(
workers: set[Process], _task: object
) -> None:
"""
Forcibly shuts down worker processes belonging to this event loop."""
- child_watcher: asyncio.AbstractChildWatcher | None
- try:
- child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
- except NotImplementedError:
- child_watcher = None
+ child_watcher: asyncio.AbstractChildWatcher | None = None
+ if sys.version_info < (3, 12):
+ try:
+ child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
+ except NotImplementedError:
+ pass
# Close as much as possible (w/o async/await) to avoid warnings
for process in workers:
@@ -1078,14 +1017,15 @@ def _forcibly_shutdown_process_pool_on_exit(
child_watcher.remove_child_handler(process.pid)
-async def _shutdown_process_pool_on_exit(workers: set[Process]) -> None:
+async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
"""
Shuts down worker processes belonging to this event loop.
- NOTE: this only works when the event loop was started using asyncio.run() or anyio.run().
+ NOTE: this only works when the event loop was started using asyncio.run() or
+ anyio.run().
"""
- process: Process
+ process: abc.Process
try:
await sleep(math.inf)
except asyncio.CancelledError:
@@ -1097,16 +1037,6 @@ async def _shutdown_process_pool_on_exit(workers: set[Process]) -> None:
await process.aclose()
-def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None:
- kwargs: dict[str, Any] = (
- {"name": "AnyIO process pool shutdown task"} if _native_task_names else {}
- )
- create_task(_shutdown_process_pool_on_exit(workers), **kwargs)
- find_root_task().add_done_callback(
- partial(_forcibly_shutdown_process_pool_on_exit, workers)
- )
-
-
#
# Sockets and networking
#
@@ -1193,7 +1123,7 @@ class SocketStream(abc.SocketStream):
async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
if (
not self._protocol.read_event.is_set()
@@ -1209,7 +1139,7 @@ class SocketStream(abc.SocketStream):
if self._closed:
raise ClosedResourceError from None
elif self._protocol.exception:
- raise self._protocol.exception
+ raise self._protocol.exception from None
else:
raise EndOfStream from None
@@ -1218,8 +1148,8 @@ class SocketStream(abc.SocketStream):
chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
self._protocol.read_queue.appendleft(leftover)
- # If the read queue is empty, clear the flag so that the next call will block until
- # data is available
+ # If the read queue is empty, clear the flag so that the next call will
+ # block until data is available
if not self._protocol.read_queue:
self._protocol.read_event.clear()
@@ -1227,7 +1157,7 @@ class SocketStream(abc.SocketStream):
async def send(self, item: bytes) -> None:
with self._send_guard:
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
if self._closed:
raise ClosedResourceError
@@ -1263,14 +1193,13 @@ class SocketStream(abc.SocketStream):
self._transport.abort()
-class UNIXSocketStream(abc.SocketStream):
+class _RawSocketMixin:
_receive_future: asyncio.Future | None = None
_send_future: asyncio.Future | None = None
_closing = False
def __init__(self, raw_socket: socket.socket):
self.__raw_socket = raw_socket
- self._loop = get_running_loop()
self._receive_guard = ResourceGuard("reading from")
self._send_guard = ResourceGuard("writing to")
@@ -1284,7 +1213,7 @@ class UNIXSocketStream(abc.SocketStream):
loop.remove_reader(self.__raw_socket)
f = self._receive_future = asyncio.Future()
- self._loop.add_reader(self.__raw_socket, f.set_result, None)
+ loop.add_reader(self.__raw_socket, f.set_result, None)
f.add_done_callback(callback)
return f
@@ -1294,21 +1223,34 @@ class UNIXSocketStream(abc.SocketStream):
loop.remove_writer(self.__raw_socket)
f = self._send_future = asyncio.Future()
- self._loop.add_writer(self.__raw_socket, f.set_result, None)
+ loop.add_writer(self.__raw_socket, f.set_result, None)
f.add_done_callback(callback)
return f
+ async def aclose(self) -> None:
+ if not self._closing:
+ self._closing = True
+ if self.__raw_socket.fileno() != -1:
+ self.__raw_socket.close()
+
+ if self._receive_future:
+ self._receive_future.set_result(None)
+ if self._send_future:
+ self._send_future.set_result(None)
+
+
+class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
async def send_eof(self) -> None:
with self._send_guard:
self._raw_socket.shutdown(socket.SHUT_WR)
async def receive(self, max_bytes: int = 65536) -> bytes:
loop = get_running_loop()
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
with self._receive_guard:
while True:
try:
- data = self.__raw_socket.recv(max_bytes)
+ data = self._raw_socket.recv(max_bytes)
except BlockingIOError:
await self._wait_until_readable(loop)
except OSError as exc:
@@ -1324,12 +1266,12 @@ class UNIXSocketStream(abc.SocketStream):
async def send(self, item: bytes) -> None:
loop = get_running_loop()
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
with self._send_guard:
view = memoryview(item)
while view:
try:
- bytes_sent = self.__raw_socket.send(view)
+ bytes_sent = self._raw_socket.send(view)
except BlockingIOError:
await self._wait_until_writable(loop)
except OSError as exc:
@@ -1348,11 +1290,11 @@ class UNIXSocketStream(abc.SocketStream):
loop = get_running_loop()
fds = array.array("i")
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
with self._receive_guard:
while True:
try:
- message, ancdata, flags, addr = self.__raw_socket.recvmsg(
+ message, ancdata, flags, addr = self._raw_socket.recvmsg(
msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
)
except BlockingIOError:
@@ -1394,13 +1336,13 @@ class UNIXSocketStream(abc.SocketStream):
filenos.append(fd.fileno())
fdarray = array.array("i", filenos)
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
with self._send_guard:
while True:
try:
# The ignore can be removed after mypy picks up
# https://github.com/python/typeshed/pull/5545
- self.__raw_socket.sendmsg(
+ self._raw_socket.sendmsg(
[message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
)
break
@@ -1412,17 +1354,6 @@ class UNIXSocketStream(abc.SocketStream):
else:
raise BrokenResourceError from exc
- async def aclose(self) -> None:
- if not self._closing:
- self._closing = True
- if self.__raw_socket.fileno() != -1:
- self.__raw_socket.close()
-
- if self._receive_future:
- self._receive_future.set_result(None)
- if self._send_future:
- self._send_future.set_result(None)
-
class TCPSocketListener(abc.SocketListener):
_accept_scope: CancelScope | None = None
@@ -1442,7 +1373,7 @@ class TCPSocketListener(abc.SocketListener):
raise ClosedResourceError
with self._accept_guard:
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
with CancelScope() as self._accept_scope:
try:
client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
@@ -1492,7 +1423,7 @@ class UNIXSocketListener(abc.SocketListener):
self._closed = False
async def accept(self) -> abc.SocketStream:
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
with self._accept_guard:
while True:
try:
@@ -1542,7 +1473,7 @@ class UDPSocket(abc.UDPSocket):
async def receive(self) -> tuple[bytes, IPSockAddrType]:
with self._receive_guard:
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
# If the buffer is empty, ask for more data
if not self._protocol.read_queue and not self._transport.is_closing():
@@ -1559,7 +1490,7 @@ class UDPSocket(abc.UDPSocket):
async def send(self, item: UDPPacketType) -> None:
with self._send_guard:
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
await self._protocol.write_event.wait()
if self._closed:
raise ClosedResourceError
@@ -1590,7 +1521,7 @@ class ConnectedUDPSocket(abc.ConnectedUDPSocket):
async def receive(self) -> bytes:
with self._receive_guard:
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
# If the buffer is empty, ask for more data
if not self._protocol.read_queue and not self._transport.is_closing():
@@ -1609,7 +1540,7 @@ class ConnectedUDPSocket(abc.ConnectedUDPSocket):
async def send(self, item: bytes) -> None:
with self._send_guard:
- await checkpoint()
+ await AsyncIOBackend.checkpoint()
await self._protocol.write_event.wait()
if self._closed:
raise ClosedResourceError
@@ -1619,142 +1550,82 @@ class ConnectedUDPSocket(abc.ConnectedUDPSocket):
self._transport.sendto(item)
-async def connect_tcp(
- host: str, port: int, local_addr: tuple[str, int] | None = None
-) -> SocketStream:
- transport, protocol = cast(
- Tuple[asyncio.Transport, StreamProtocol],
- await get_running_loop().create_connection(
- StreamProtocol, host, port, local_addr=local_addr
- ),
- )
- transport.pause_reading()
- return SocketStream(transport, protocol)
-
-
-async def connect_unix(path: str) -> UNIXSocketStream:
- await checkpoint()
- loop = get_running_loop()
- raw_socket = socket.socket(socket.AF_UNIX)
- raw_socket.setblocking(False)
- while True:
- try:
- raw_socket.connect(path)
- except BlockingIOError:
- f: asyncio.Future = asyncio.Future()
- loop.add_writer(raw_socket, f.set_result, None)
- f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
- await f
- except BaseException:
- raw_socket.close()
- raise
- else:
- return UNIXSocketStream(raw_socket)
-
-
-async def create_udp_socket(
- family: socket.AddressFamily,
- local_address: IPSockAddrType | None,
- remote_address: IPSockAddrType | None,
- reuse_port: bool,
-) -> UDPSocket | ConnectedUDPSocket:
- result = await get_running_loop().create_datagram_endpoint(
- DatagramProtocol,
- local_addr=local_address,
- remote_addr=remote_address,
- family=family,
- reuse_port=reuse_port,
- )
- transport = result[0]
- protocol = result[1]
- if protocol.exception:
- transport.close()
- raise protocol.exception
-
- if not remote_address:
- return UDPSocket(transport, protocol)
- else:
- return ConnectedUDPSocket(transport, protocol)
+class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
+ async def receive(self) -> UNIXDatagramPacketType:
+ loop = get_running_loop()
+ await AsyncIOBackend.checkpoint()
+ with self._receive_guard:
+ while True:
+ try:
+ data = self._raw_socket.recvfrom(65536)
+ except BlockingIOError:
+ await self._wait_until_readable(loop)
+ except OSError as exc:
+ if self._closing:
+ raise ClosedResourceError from None
+ else:
+ raise BrokenResourceError from exc
+ else:
+ return data
+ async def send(self, item: UNIXDatagramPacketType) -> None:
+ loop = get_running_loop()
+ await AsyncIOBackend.checkpoint()
+ with self._send_guard:
+ while True:
+ try:
+ self._raw_socket.sendto(*item)
+ except BlockingIOError:
+ await self._wait_until_writable(loop)
+ except OSError as exc:
+ if self._closing:
+ raise ClosedResourceError from None
+ else:
+ raise BrokenResourceError from exc
+ else:
+ return
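
The new UNIX datagram classes reuse the non-blocking retry loop already used by the stream sockets: attempt the syscall, park until the event loop reports readiness on BlockingIOError, and map OSError to ClosedResourceError or BrokenResourceError depending on whether the close was local. The core of that loop, reduced to a runnable sketch (POSIX only; recv_retrying is a hypothetical helper name):

from __future__ import annotations

import asyncio
import socket

async def recv_retrying(sock: socket.socket) -> bytes:
    loop = asyncio.get_running_loop()
    while True:
        try:
            return sock.recv(65536)
        except BlockingIOError:
            # Park until the event loop marks the socket readable
            fut: asyncio.Future = loop.create_future()
            loop.add_reader(sock, fut.set_result, None)
            try:
                await fut
            finally:
                loop.remove_reader(sock)

async def main() -> None:
    a, b = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
    a.setblocking(False)
    b.setblocking(False)
    asyncio.get_running_loop().call_later(0.05, b.send, b"ping")
    assert await recv_retrying(a) == b"ping"
    a.close()
    b.close()

asyncio.run(main())
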
-async def getaddrinfo(
- host: bytes | str,
- port: str | int | None,
- *,
- family: int | AddressFamily = 0,
- type: int | SocketKind = 0,
- proto: int = 0,
- flags: int = 0,
-) -> GetAddrInfoReturnType:
- # https://github.com/python/typeshed/pull/4304
- result = await get_running_loop().getaddrinfo(
- host, port, family=family, type=type, proto=proto, flags=flags
- )
- return cast(GetAddrInfoReturnType, result)
+class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
+ async def receive(self) -> bytes:
+ loop = get_running_loop()
+ await AsyncIOBackend.checkpoint()
+ with self._receive_guard:
+ while True:
+ try:
+ data = self._raw_socket.recv(65536)
+ except BlockingIOError:
+ await self._wait_until_readable(loop)
+ except OSError as exc:
+ if self._closing:
+ raise ClosedResourceError from None
+ else:
+ raise BrokenResourceError from exc
+ else:
+ return data
-async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> tuple[str, str]:
- return await get_running_loop().getnameinfo(sockaddr, flags)
+ async def send(self, item: bytes) -> None:
+ loop = get_running_loop()
+ await AsyncIOBackend.checkpoint()
+ with self._send_guard:
+ while True:
+ try:
+ self._raw_socket.send(item)
+ except BlockingIOError:
+ await self._wait_until_writable(loop)
+ except OSError as exc:
+ if self._closing:
+ raise ClosedResourceError from None
+ else:
+ raise BrokenResourceError from exc
+ else:
+ return
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
-async def wait_socket_readable(sock: socket.socket) -> None:
- await checkpoint()
- try:
- read_events = _read_events.get()
- except LookupError:
- read_events = {}
- _read_events.set(read_events)
-
- if read_events.get(sock):
- raise BusyResourceError("reading from") from None
-
- loop = get_running_loop()
- event = read_events[sock] = asyncio.Event()
- loop.add_reader(sock, event.set)
- try:
- await event.wait()
- finally:
- if read_events.pop(sock, None) is not None:
- loop.remove_reader(sock)
- readable = True
- else:
- readable = False
-
- if not readable:
- raise ClosedResourceError
-
-
-async def wait_socket_writable(sock: socket.socket) -> None:
- await checkpoint()
- try:
- write_events = _write_events.get()
- except LookupError:
- write_events = {}
- _write_events.set(write_events)
-
- if write_events.get(sock):
- raise BusyResourceError("writing to") from None
-
- loop = get_running_loop()
- event = write_events[sock] = asyncio.Event()
- loop.add_writer(sock.fileno(), event.set)
- try:
- await event.wait()
- finally:
- if write_events.pop(sock, None) is not None:
- loop.remove_writer(sock)
- writable = True
- else:
- writable = False
-
- if not writable:
- raise ClosedResourceError
-
-
#
# Synchronization
#
@@ -1767,16 +1638,17 @@ class Event(BaseEvent):
def __init__(self) -> None:
self._event = asyncio.Event()
- def set(self) -> DeprecatedAwaitable:
+ def set(self) -> None:
self._event.set()
- return DeprecatedAwaitable(self.set)
def is_set(self) -> bool:
return self._event.is_set()
async def wait(self) -> None:
- if await self._event.wait():
- await checkpoint()
+ if self.is_set():
+ await AsyncIOBackend.checkpoint()
+ else:
+ await self._event.wait()
def statistics(self) -> EventStatistics:
return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined]
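
Event.wait() is also restructured: if the event is already set it performs only a checkpoint, instead of awaiting the underlying asyncio.Event and inspecting its return value as the old code did. Usage through the public API is unchanged; a short sketch:

import anyio

async def main() -> None:
    event = anyio.Event()

    async def setter() -> None:
        await anyio.sleep(0.1)
        event.set()  # wakes every current and future waiter

    async with anyio.create_task_group() as tg:
        tg.start_soon(setter)
        await event.wait()  # parks until set() is called
    await event.wait()      # already set: reduces to a checkpoint

anyio.run(main)
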
@@ -1815,19 +1687,14 @@ class CapacityLimiter(BaseCapacityLimiter):
if value < 1:
raise ValueError("total_tokens must be >= 1")
- old_value = self._total_tokens
+ waiters_to_notify = max(value - self._total_tokens, 0)
self._total_tokens = value
- events = []
- for event in self._wait_queue.values():
- if value <= old_value:
- break
-
- if not event.is_set():
- events.append(event)
- old_value += 1
- for event in events:
+ # Notify waiting tasks that they have acquired the limiter
+ while self._wait_queue and waiters_to_notify:
+ event = self._wait_queue.popitem(last=False)[1]
event.set()
+ waiters_to_notify -= 1
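
The rewritten setter computes up front how many waiters the new value can admit and pops exactly that many events from the FIFO wait queue, replacing the old loop that counted against the previous total. Raising total_tokens at runtime therefore admits queued tasks in arrival order; a usage sketch:

import anyio

async def main() -> None:
    limiter = anyio.CapacityLimiter(1)

    async def worker(n: int) -> None:
        async with limiter:
            print(f"worker {n} acquired a token")
            await anyio.sleep(0.1)

    async with anyio.create_task_group() as tg:
        for n in range(3):
            tg.start_soon(worker, n)
        await anyio.sleep(0.05)   # let workers 1 and 2 queue up
        limiter.total_tokens = 3  # admits both queued waiters at once

anyio.run(main)
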
@property
def borrowed_tokens(self) -> int:
@@ -1837,11 +1704,10 @@ class CapacityLimiter(BaseCapacityLimiter):
def available_tokens(self) -> float:
return self._total_tokens - len(self._borrowers)
- def acquire_nowait(self) -> DeprecatedAwaitable:
+ def acquire_nowait(self) -> None:
self.acquire_on_behalf_of_nowait(current_task())
- return DeprecatedAwaitable(self.acquire_nowait)
- def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
if borrower in self._borrowers:
raise RuntimeError(
"this borrower is already holding one of this CapacityLimiter's "
@@ -1852,13 +1718,12 @@ class CapacityLimiter(BaseCapacityLimiter):
raise WouldBlock
self._borrowers.add(borrower)
- return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait)
async def acquire(self) -> None:
return await self.acquire_on_behalf_of(current_task())
async def acquire_on_behalf_of(self, borrower: object) -> None:
- await checkpoint_if_cancelled()
+ await AsyncIOBackend.checkpoint_if_cancelled()
try:
self.acquire_on_behalf_of_nowait(borrower)
except WouldBlock:
@@ -1873,7 +1738,7 @@ class CapacityLimiter(BaseCapacityLimiter):
self._borrowers.add(borrower)
else:
try:
- await cancel_shielded_checkpoint()
+ await AsyncIOBackend.cancel_shielded_checkpoint()
except BaseException:
self.release()
raise
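
acquire_on_behalf_of() keeps the usual fast-path/slow-path shape, now spelled via AsyncIOBackend: poll for a pending cancellation, try the non-blocking acquisition, park on an event when that raises WouldBlock, and on the fast path take a cancel-shielded checkpoint so a cancellation landing right there cannot strand the freshly acquired token. The same protocol on a toy single-token lock (a sketch, not anyio's actual class):

from __future__ import annotations

import anyio
from anyio.lowlevel import cancel_shielded_checkpoint, checkpoint_if_cancelled

class TinyLock:
    def __init__(self) -> None:
        self._held = False
        self._waiters: list[anyio.Event] = []

    def acquire_nowait(self) -> None:
        if self._held:
            raise anyio.WouldBlock
        self._held = True

    async def acquire(self) -> None:
        await checkpoint_if_cancelled()  # honor a pending cancellation first
        try:
            self.acquire_nowait()
        except anyio.WouldBlock:
            event = anyio.Event()
            self._waiters.append(event)
            await event.wait()  # ownership is handed over by release()
        else:
            try:
                await cancel_shielded_checkpoint()
            except BaseException:
                self.release()  # don't strand the token we just took
                raise

    def release(self) -> None:
        if self._waiters:
            self._waiters.pop(0).set()  # hand the token to the next waiter
        else:
            self._held = False

async def main() -> None:
    lock = TinyLock()
    await lock.acquire()
    lock.release()

anyio.run(main)
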
@@ -1906,29 +1771,20 @@ class CapacityLimiter(BaseCapacityLimiter):
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
-def current_default_thread_limiter() -> CapacityLimiter:
- try:
- return _default_thread_limiter.get()
- except LookupError:
- limiter = CapacityLimiter(40)
- _default_thread_limiter.set(limiter)
- return limiter
-
-
#
# Operating system signals
#
-class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]):
- def __init__(self, signals: tuple[int, ...]):
+class _SignalReceiver:
+ def __init__(self, signals: tuple[Signals, ...]):
self._signals = signals
self._loop = get_running_loop()
- self._signal_queue: deque[int] = deque()
+ self._signal_queue: deque[Signals] = deque()
self._future: asyncio.Future = asyncio.Future()
- self._handled_signals: set[int] = set()
+ self._handled_signals: set[Signals] = set()
- def _deliver(self, signum: int) -> None:
+ def _deliver(self, signum: Signals) -> None:
self._signal_queue.append(signum)
if not self._future.done():
self._future.set_result(None)
@@ -1953,8 +1809,8 @@ class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]):
def __aiter__(self) -> _SignalReceiver:
return self
- async def __anext__(self) -> int:
- await checkpoint()
+ async def __anext__(self) -> Signals:
+ await AsyncIOBackend.checkpoint()
if not self._signal_queue:
self._future = asyncio.Future()
await self._future
@@ -1962,10 +1818,6 @@ class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]):
return self._signal_queue.popleft()
-def open_signal_receiver(*signals: int) -> _SignalReceiver:
- return _SignalReceiver(signals)
-
-
#
# Testing and debugging
#
@@ -1974,69 +1826,47 @@ def open_signal_receiver(*signals: int) -> _SignalReceiver:
def _create_task_info(task: asyncio.Task) -> TaskInfo:
task_state = _task_states.get(task)
if task_state is None:
- name = task.get_name() if _native_task_names else None
parent_id = None
else:
- name = task_state.name
parent_id = task_state.parent_id
- return TaskInfo(id(task), parent_id, name, get_coro(task))
-
-
-def get_current_task() -> TaskInfo:
- return _create_task_info(current_task()) # type: ignore[arg-type]
-
-
-def get_running_tasks() -> list[TaskInfo]:
- return [_create_task_info(task) for task in all_tasks() if not task.done()]
-
-
-async def wait_all_tasks_blocked() -> None:
- await checkpoint()
- this_task = current_task()
- while True:
- for task in all_tasks():
- if task is this_task:
- continue
-
- if task._fut_waiter is None or task._fut_waiter.done(): # type: ignore[attr-defined]
- await sleep(0.1)
- break
- else:
- return
+ return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
class TestRunner(abc.TestRunner):
+ _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]
+
def __init__(
self,
- debug: bool = False,
+ *,
+ debug: bool | None = None,
use_uvloop: bool = False,
- policy: asyncio.AbstractEventLoopPolicy | None = None,
- ):
+ loop_factory: Callable[[], AbstractEventLoop] | None = None,
+ ) -> None:
+ if use_uvloop and loop_factory is None:
+ import uvloop
+
+ loop_factory = uvloop.new_event_loop
+
+ self._runner = Runner(debug=debug, loop_factory=loop_factory)
self._exceptions: list[BaseException] = []
- _maybe_set_event_loop_policy(policy, use_uvloop)
- self._loop = asyncio.new_event_loop()
- self._loop.set_debug(debug)
- self._loop.set_exception_handler(self._exception_handler)
- asyncio.set_event_loop(self._loop)
-
- def _cancel_all_tasks(self) -> None:
- to_cancel = all_tasks(self._loop)
- if not to_cancel:
- return
+ self._runner_task: asyncio.Task | None = None
- for task in to_cancel:
- task.cancel()
+ def __enter__(self) -> TestRunner:
+ self._runner.__enter__()
+ self.get_loop().set_exception_handler(self._exception_handler)
+ return self
- self._loop.run_until_complete(
- asyncio.gather(*to_cancel, return_exceptions=True)
- )
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ self._runner.__exit__(exc_type, exc_val, exc_tb)
- for task in to_cancel:
- if task.cancelled():
- continue
- if task.exception() is not None:
- raise cast(BaseException, task.exception())
+ def get_loop(self) -> AbstractEventLoop:
+ return self._runner.get_loop()
def _exception_handler(
self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
@@ -2053,56 +1883,77 @@ class TestRunner(abc.TestRunner):
if len(exceptions) == 1:
raise exceptions[0]
elif exceptions:
- raise ExceptionGroup(exceptions)
+ raise BaseExceptionGroup(
+ "Multiple exceptions occurred in asynchronous callbacks", exceptions
+ )
- def close(self) -> None:
- try:
- self._cancel_all_tasks()
- self._loop.run_until_complete(self._loop.shutdown_asyncgens())
- finally:
- asyncio.set_event_loop(None)
- self._loop.close()
+ @staticmethod
+ async def _run_tests_and_fixtures(
+ receive_stream: MemoryObjectReceiveStream[
+ tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
+ ],
+ ) -> None:
+ with receive_stream:
+ async for coro, future in receive_stream:
+ try:
+ retval = await coro
+ except BaseException as exc:
+ if not future.cancelled():
+ future.set_exception(exc)
+ else:
+ if not future.cancelled():
+ future.set_result(retval)
+
+ async def _call_in_runner_task(
+ self,
+ func: Callable[P, Awaitable[T_Retval]],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> T_Retval:
+ if not self._runner_task:
+ self._send_stream, receive_stream = create_memory_object_stream[
+ Tuple[Awaitable[Any], asyncio.Future]
+ ](1)
+ self._runner_task = self.get_loop().create_task(
+ self._run_tests_and_fixtures(receive_stream)
+ )
+
+ coro = func(*args, **kwargs)
+ future: asyncio.Future[T_Retval] = self.get_loop().create_future()
+ self._send_stream.send_nowait((coro, future))
+ return await future
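
This dispatcher is the heart of the 4.x test runner: instead of driving each coroutine through run_until_complete() in isolation, one long-lived task consumes (coroutine, future) pairs from a memory object stream, so every fixture and test runs in the same task and shares its context. Reduced to a standalone sketch with hypothetical names:

import asyncio
from typing import Any, Awaitable

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream

async def dispatcher(
    receive: "MemoryObjectReceiveStream[tuple[Awaitable[Any], asyncio.Future]]",
) -> None:
    with receive:
        async for coro, future in receive:
            try:
                retval = await coro
            except BaseException as exc:
                if not future.cancelled():
                    future.set_exception(exc)
            else:
                if not future.cancelled():
                    future.set_result(retval)

async def main() -> None:
    send, receive = anyio.create_memory_object_stream[
        tuple[Awaitable[Any], asyncio.Future]
    ](1)
    loop = asyncio.get_running_loop()
    task = loop.create_task(dispatcher(receive))

    async def probe() -> str:
        return "ran inside the dispatcher task"

    future: asyncio.Future = loop.create_future()
    send.send_nowait((probe(), future))
    print(await future)
    send.close()  # ends the async for, letting the dispatcher exit
    await task

asyncio.run(main())
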
def run_asyncgen_fixture(
self,
fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
kwargs: dict[str, Any],
) -> Iterable[T_Retval]:
- async def fixture_runner() -> None:
- agen = fixture_func(**kwargs)
- try:
- retval = await agen.asend(None)
- self._raise_async_exceptions()
- except BaseException as exc:
- f.set_exception(exc)
- return
- else:
- f.set_result(retval)
-
- await event.wait()
- try:
- await agen.asend(None)
- except StopAsyncIteration:
- pass
- else:
- await agen.aclose()
- raise RuntimeError("Async generator fixture did not stop")
-
- f = self._loop.create_future()
- event = asyncio.Event()
- fixture_task = self._loop.create_task(fixture_runner())
- self._loop.run_until_complete(f)
- yield f.result()
- event.set()
- self._loop.run_until_complete(fixture_task)
+ asyncgen = fixture_func(**kwargs)
+ fixturevalue: T_Retval = self.get_loop().run_until_complete(
+ self._call_in_runner_task(asyncgen.asend, None)
+ )
self._raise_async_exceptions()
+ yield fixturevalue
+
+ try:
+ self.get_loop().run_until_complete(
+ self._call_in_runner_task(asyncgen.asend, None)
+ )
+ except StopAsyncIteration:
+ self._raise_async_exceptions()
+ else:
+ self.get_loop().run_until_complete(asyncgen.aclose())
+ raise RuntimeError("Async generator fixture did not stop")
+
def run_fixture(
self,
fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
kwargs: dict[str, Any],
) -> T_Retval:
- retval = self._loop.run_until_complete(fixture_func(**kwargs))
+ retval = self.get_loop().run_until_complete(
+ self._call_in_runner_task(fixture_func, **kwargs)
+ )
self._raise_async_exceptions()
return retval
@@ -2110,8 +1961,518 @@ class TestRunner(abc.TestRunner):
self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
) -> None:
try:
- self._loop.run_until_complete(test_func(**kwargs))
+ self.get_loop().run_until_complete(
+ self._call_in_runner_task(test_func, **kwargs)
+ )
except Exception as exc:
self._exceptions.append(exc)
self._raise_async_exceptions()
+
+
+class AsyncIOBackend(AsyncBackend):
+ @classmethod
+ def run(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+ args: tuple[Unpack[PosArgsT]],
+ kwargs: dict[str, Any],
+ options: dict[str, Any],
+ ) -> T_Retval:
+ @wraps(func)
+ async def wrapper() -> T_Retval:
+ task = cast(asyncio.Task, current_task())
+ task.set_name(get_callable_name(func))
+ _task_states[task] = TaskState(None, None)
+
+ try:
+ return await func(*args)
+ finally:
+ del _task_states[task]
+
+ debug = options.get("debug", False)
+ loop_factory = options.get("loop_factory", None)
+ if loop_factory is None and options.get("use_uvloop", False):
+ import uvloop
+
+ loop_factory = uvloop.new_event_loop
+
+ with Runner(debug=debug, loop_factory=loop_factory) as runner:
+ return runner.run(wrapper())
+
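
run() is rebuilt on top of Runner (asyncio.Runner where available, with a fallback on older Pythons), so debug, loop_factory and use_uvloop all funnel into a single code path. From the public API those options arrive through backend_options, for example:

import anyio

async def main() -> str:
    return "hello"

anyio.run(main)                                      # stock event loop
anyio.run(main, backend_options={"debug": True})     # loop debug mode
# Requires the optional uvloop package:
# anyio.run(main, backend_options={"use_uvloop": True})
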
+ @classmethod
+ def current_token(cls) -> object:
+ return get_running_loop()
+
+ @classmethod
+ def current_time(cls) -> float:
+ return get_running_loop().time()
+
+ @classmethod
+ def cancelled_exception_class(cls) -> type[BaseException]:
+ return CancelledError
+
+ @classmethod
+ async def checkpoint(cls) -> None:
+ await sleep(0)
+
+ @classmethod
+ async def checkpoint_if_cancelled(cls) -> None:
+ task = current_task()
+ if task is None:
+ return
+
+ try:
+ cancel_scope = _task_states[task].cancel_scope
+ except KeyError:
+ return
+
+ while cancel_scope:
+ if cancel_scope.cancel_called:
+ await sleep(0)
+ elif cancel_scope.shield:
+ break
+ else:
+ cancel_scope = cancel_scope._parent_scope
+
+ @classmethod
+ async def cancel_shielded_checkpoint(cls) -> None:
+ with CancelScope(shield=True):
+ await sleep(0)
+
+ @classmethod
+ async def sleep(cls, delay: float) -> None:
+ await sleep(delay)
+
+ @classmethod
+ def create_cancel_scope(
+ cls, *, deadline: float = math.inf, shield: bool = False
+ ) -> CancelScope:
+ return CancelScope(deadline=deadline, shield=shield)
+
+ @classmethod
+ def current_effective_deadline(cls) -> float:
+ try:
+ cancel_scope = _task_states[
+ current_task() # type: ignore[index]
+ ].cancel_scope
+ except KeyError:
+ return math.inf
+
+ deadline = math.inf
+ while cancel_scope:
+ deadline = min(deadline, cancel_scope.deadline)
+ if cancel_scope._cancel_called:
+ deadline = -math.inf
+ break
+ elif cancel_scope.shield:
+ break
+ else:
+ cancel_scope = cancel_scope._parent_scope
+
+ return deadline
+
+ @classmethod
+ def create_task_group(cls) -> abc.TaskGroup:
+ return TaskGroup()
+
+ @classmethod
+ def create_event(cls) -> abc.Event:
+ return Event()
+
+ @classmethod
+ def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
+ return CapacityLimiter(total_tokens)
+
+ @classmethod
+ async def run_sync_in_worker_thread(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
+ args: tuple[Unpack[PosArgsT]],
+ abandon_on_cancel: bool = False,
+ limiter: abc.CapacityLimiter | None = None,
+ ) -> T_Retval:
+ await cls.checkpoint()
+
+ # If this is the first run in this event loop thread, set up the necessary
+ # variables
+ try:
+ idle_workers = _threadpool_idle_workers.get()
+ workers = _threadpool_workers.get()
+ except LookupError:
+ idle_workers = deque()
+ workers = set()
+ _threadpool_idle_workers.set(idle_workers)
+ _threadpool_workers.set(workers)
+
+ async with limiter or cls.current_default_thread_limiter():
+ with CancelScope(shield=not abandon_on_cancel) as scope:
+ future: asyncio.Future = asyncio.Future()
+ root_task = find_root_task()
+ if not idle_workers:
+ worker = WorkerThread(root_task, workers, idle_workers)
+ worker.start()
+ workers.add(worker)
+ root_task.add_done_callback(worker.stop)
+ else:
+ worker = idle_workers.pop()
+
+ # Prune any other workers that have been idle for MAX_IDLE_TIME
+ # seconds or longer
+ now = cls.current_time()
+ while idle_workers:
+ if (
+ now - idle_workers[0].idle_since
+ < WorkerThread.MAX_IDLE_TIME
+ ):
+ break
+
+ expired_worker = idle_workers.popleft()
+ expired_worker.root_task.remove_done_callback(
+ expired_worker.stop
+ )
+ expired_worker.stop()
+
+ context = copy_context()
+ context.run(sniffio.current_async_library_cvar.set, None)
+ if abandon_on_cancel or scope._parent_scope is None:
+ worker_scope = scope
+ else:
+ worker_scope = scope._parent_scope
+
+ worker.queue.put_nowait((context, func, args, future, worker_scope))
+ return await future
+
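
run_sync_in_worker_thread() now manages its own pool of WorkerThread objects: an idle worker is reused when available, workers idle past WorkerThread.MAX_IDLE_TIME are pruned, and the cancel scope stays shielded unless the caller opted into abandon_on_cancel. The public entry point is anyio.to_thread.run_sync():

import time

import anyio
import anyio.to_thread

async def main() -> None:
    # Runs in a pooled worker thread; by default the call is shielded
    # from cancellation until the function returns.
    await anyio.to_thread.run_sync(time.sleep, 0.1)

    # abandon_on_cancel=True lets a cancelled caller return immediately,
    # leaving the thread to finish in the background.
    await anyio.to_thread.run_sync(time.sleep, 0.1, abandon_on_cancel=True)

anyio.run(main)
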
+ @classmethod
+ def check_cancelled(cls) -> None:
+ scope: CancelScope | None = threadlocals.current_cancel_scope
+ while scope is not None:
+ if scope.cancel_called:
+ raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")
+
+ if scope.shield:
+ return
+
+ scope = scope._parent_scope
+
+ @classmethod
+ def run_async_from_thread(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+ args: tuple[Unpack[PosArgsT]],
+ token: object,
+ ) -> T_Retval:
+ async def task_wrapper(scope: CancelScope) -> T_Retval:
+ __tracebackhide__ = True
+ task = cast(asyncio.Task, current_task())
+ _task_states[task] = TaskState(None, scope)
+ scope._tasks.add(task)
+ try:
+ return await func(*args)
+ except CancelledError as exc:
+ raise concurrent.futures.CancelledError(str(exc)) from None
+ finally:
+ scope._tasks.discard(task)
+
+ loop = cast(AbstractEventLoop, token)
+ context = copy_context()
+ context.run(sniffio.current_async_library_cvar.set, "asyncio")
+ wrapper = task_wrapper(threadlocals.current_cancel_scope)
+ f: concurrent.futures.Future[T_Retval] = context.run(
+ asyncio.run_coroutine_threadsafe, wrapper, loop
+ )
+ return f.result()
+
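
run_async_from_thread() wraps the coroutine so the submitting worker thread's cancel scope is joined, then hands it to the loop with asyncio.run_coroutine_threadsafe(). Through the public API this is anyio.from_thread.run(), valid only inside a worker thread spawned by anyio:

import anyio
import anyio.from_thread
import anyio.to_thread

async def greet(name: str) -> str:
    await anyio.sleep(0)
    return f"hello, {name}"

def blocking_part() -> str:
    # Re-enters the event loop from the worker thread
    return anyio.from_thread.run(greet, "world")

async def main() -> None:
    print(await anyio.to_thread.run_sync(blocking_part))

anyio.run(main)
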
+ @classmethod
+ def run_sync_from_thread(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
+ args: tuple[Unpack[PosArgsT]],
+ token: object,
+ ) -> T_Retval:
+ @wraps(func)
+ def wrapper() -> None:
+ try:
+ sniffio.current_async_library_cvar.set("asyncio")
+ f.set_result(func(*args))
+ except BaseException as exc:
+ f.set_exception(exc)
+ if not isinstance(exc, Exception):
+ raise
+
+ f: concurrent.futures.Future[T_Retval] = Future()
+ loop = cast(AbstractEventLoop, token)
+ loop.call_soon_threadsafe(wrapper)
+ return f.result()
+
+ @classmethod
+ def create_blocking_portal(cls) -> abc.BlockingPortal:
+ return BlockingPortal()
+
+ @classmethod
+ async def open_process(
+ cls,
+ command: str | bytes | Sequence[str | bytes],
+ *,
+ shell: bool,
+ stdin: int | IO[Any] | None,
+ stdout: int | IO[Any] | None,
+ stderr: int | IO[Any] | None,
+ cwd: str | bytes | PathLike | None = None,
+ env: Mapping[str, str] | None = None,
+ start_new_session: bool = False,
+ ) -> Process:
+ await cls.checkpoint()
+ if shell:
+ process = await asyncio.create_subprocess_shell(
+ cast("str | bytes", command),
+ stdin=stdin,
+ stdout=stdout,
+ stderr=stderr,
+ cwd=cwd,
+ env=env,
+ start_new_session=start_new_session,
+ )
+ else:
+ process = await asyncio.create_subprocess_exec(
+ *command,
+ stdin=stdin,
+ stdout=stdout,
+ stderr=stderr,
+ cwd=cwd,
+ env=env,
+ start_new_session=start_new_session,
+ )
+
+ stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
+ stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
+ stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
+ return Process(process, stdin_stream, stdout_stream, stderr_stream)
+
+ @classmethod
+ def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
+ create_task(
+ _shutdown_process_pool_on_exit(workers),
+ name="AnyIO process pool shutdown task",
+ )
+ find_root_task().add_done_callback(
+ partial(_forcibly_shutdown_process_pool_on_exit, workers)
+ )
+
+ @classmethod
+ async def connect_tcp(
+ cls, host: str, port: int, local_address: IPSockAddrType | None = None
+ ) -> abc.SocketStream:
+ transport, protocol = cast(
+ Tuple[asyncio.Transport, StreamProtocol],
+ await get_running_loop().create_connection(
+ StreamProtocol, host, port, local_addr=local_address
+ ),
+ )
+ transport.pause_reading()
+ return SocketStream(transport, protocol)
+
+ @classmethod
+ async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
+ await cls.checkpoint()
+ loop = get_running_loop()
+ raw_socket = socket.socket(socket.AF_UNIX)
+ raw_socket.setblocking(False)
+ while True:
+ try:
+ raw_socket.connect(path)
+ except BlockingIOError:
+ f: asyncio.Future = asyncio.Future()
+ loop.add_writer(raw_socket, f.set_result, None)
+ f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
+ await f
+ except BaseException:
+ raw_socket.close()
+ raise
+ else:
+ return UNIXSocketStream(raw_socket)
+
+ @classmethod
+ def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
+ return TCPSocketListener(sock)
+
+ @classmethod
+ def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
+ return UNIXSocketListener(sock)
+
+ @classmethod
+ async def create_udp_socket(
+ cls,
+ family: AddressFamily,
+ local_address: IPSockAddrType | None,
+ remote_address: IPSockAddrType | None,
+ reuse_port: bool,
+ ) -> UDPSocket | ConnectedUDPSocket:
+ transport, protocol = await get_running_loop().create_datagram_endpoint(
+ DatagramProtocol,
+ local_addr=local_address,
+ remote_addr=remote_address,
+ family=family,
+ reuse_port=reuse_port,
+ )
+ if protocol.exception:
+ transport.close()
+ raise protocol.exception
+
+ if not remote_address:
+ return UDPSocket(transport, protocol)
+ else:
+ return ConnectedUDPSocket(transport, protocol)
+
+ @classmethod
+ async def create_unix_datagram_socket( # type: ignore[override]
+ cls, raw_socket: socket.socket, remote_path: str | bytes | None
+ ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
+ await cls.checkpoint()
+ loop = get_running_loop()
+
+ if remote_path:
+ while True:
+ try:
+ raw_socket.connect(remote_path)
+ except BlockingIOError:
+ f: asyncio.Future = asyncio.Future()
+ loop.add_writer(raw_socket, f.set_result, None)
+ f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
+ await f
+ except BaseException:
+ raw_socket.close()
+ raise
+ else:
+ return ConnectedUNIXDatagramSocket(raw_socket)
+ else:
+ return UNIXDatagramSocket(raw_socket)
+
+ @classmethod
+ async def getaddrinfo(
+ cls,
+ host: bytes | str | None,
+ port: str | int | None,
+ *,
+ family: int | AddressFamily = 0,
+ type: int | SocketKind = 0,
+ proto: int = 0,
+ flags: int = 0,
+ ) -> list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ tuple[str, int] | tuple[str, int, int, int],
+ ]
+ ]:
+ return await get_running_loop().getaddrinfo(
+ host, port, family=family, type=type, proto=proto, flags=flags
+ )
+
+ @classmethod
+ async def getnameinfo(
+ cls, sockaddr: IPSockAddrType, flags: int = 0
+ ) -> tuple[str, str]:
+ return await get_running_loop().getnameinfo(sockaddr, flags)
+
+ @classmethod
+ async def wait_socket_readable(cls, sock: socket.socket) -> None:
+ await cls.checkpoint()
+ try:
+ read_events = _read_events.get()
+ except LookupError:
+ read_events = {}
+ _read_events.set(read_events)
+
+ if read_events.get(sock):
+ raise BusyResourceError("reading from") from None
+
+ loop = get_running_loop()
+ event = read_events[sock] = asyncio.Event()
+ loop.add_reader(sock, event.set)
+ try:
+ await event.wait()
+ finally:
+ if read_events.pop(sock, None) is not None:
+ loop.remove_reader(sock)
+ readable = True
+ else:
+ readable = False
+
+ if not readable:
+ raise ClosedResourceError
+
+ @classmethod
+ async def wait_socket_writable(cls, sock: socket.socket) -> None:
+ await cls.checkpoint()
+ try:
+ write_events = _write_events.get()
+ except LookupError:
+ write_events = {}
+ _write_events.set(write_events)
+
+ if write_events.get(sock):
+ raise BusyResourceError("writing to") from None
+
+ loop = get_running_loop()
+ event = write_events[sock] = asyncio.Event()
+ loop.add_writer(sock.fileno(), event.set)
+ try:
+ await event.wait()
+ finally:
+ if write_events.pop(sock, None) is not None:
+ loop.remove_writer(sock)
+ writable = True
+ else:
+ writable = False
+
+ if not writable:
+ raise ClosedResourceError
+
+ @classmethod
+ def current_default_thread_limiter(cls) -> CapacityLimiter:
+ try:
+ return _default_thread_limiter.get()
+ except LookupError:
+ limiter = CapacityLimiter(40)
+ _default_thread_limiter.set(limiter)
+ return limiter
+
+ @classmethod
+ def open_signal_receiver(
+ cls, *signals: Signals
+ ) -> ContextManager[AsyncIterator[Signals]]:
+ return _SignalReceiver(signals)
+
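
open_signal_receiver() becomes a backend classmethod returning the same _SignalReceiver, now typed as a context manager over an async iterator of Signals. Usage through the public API (POSIX only; the self-delivery via raise_signal is just for the demo):

import signal

import anyio

async def main() -> None:
    with anyio.open_signal_receiver(signal.SIGUSR1) as signals:
        signal.raise_signal(signal.SIGUSR1)  # deliver one to ourselves
        async for signum in signals:
            print(f"received {signum!r}")
            break

anyio.run(main)
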
+ @classmethod
+ def get_current_task(cls) -> TaskInfo:
+ return _create_task_info(current_task()) # type: ignore[arg-type]
+
+ @classmethod
+ def get_running_tasks(cls) -> list[TaskInfo]:
+ return [_create_task_info(task) for task in all_tasks() if not task.done()]
+
+ @classmethod
+ async def wait_all_tasks_blocked(cls) -> None:
+ await cls.checkpoint()
+ this_task = current_task()
+ while True:
+ for task in all_tasks():
+ if task is this_task:
+ continue
+
+ waiter = task._fut_waiter # type: ignore[attr-defined]
+ if waiter is None or waiter.done():
+ await sleep(0.1)
+ break
+ else:
+ return
+
+ @classmethod
+ def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
+ return TestRunner(**options)
+
+
+backend_class = AsyncIOBackend
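
The module-level functions that previously formed the backend interface are gone; the module now exposes everything through the class and exports a single backend_class attribute. A sketch of how a consumer can resolve it, modelled on anyio's own get_async_backend() lookup (treat the details as illustrative):

from importlib import import_module

def load_backend(name: str = "asyncio"):
    # anyio imports anyio._backends._asyncio / _trio and reads the
    # module's backend_class attribute instead of module-level functions.
    module = import_module(f"anyio._backends._{name}")
    return module.backend_class

AsyncIOBackend = load_backend("asyncio")
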
diff --git a/contrib/python/anyio/anyio/_backends/_trio.py b/contrib/python/anyio/anyio/_backends/_trio.py
index cf28943509..1a47192e30 100644
--- a/contrib/python/anyio/anyio/_backends/_trio.py
+++ b/contrib/python/anyio/anyio/_backends/_trio.py
@@ -3,41 +3,48 @@ from __future__ import annotations
import array
import math
import socket
+import sys
+import types
+from collections.abc import AsyncIterator, Iterable
from concurrent.futures import Future
-from contextvars import copy_context
from dataclasses import dataclass
from functools import partial
from io import IOBase
from os import PathLike
from signal import Signals
+from socket import AddressFamily, SocketKind
from types import TracebackType
from typing import (
IO,
- TYPE_CHECKING,
Any,
AsyncGenerator,
- AsyncIterator,
Awaitable,
Callable,
Collection,
+ ContextManager,
Coroutine,
Generic,
- Iterable,
Mapping,
NoReturn,
Sequence,
TypeVar,
cast,
+ overload,
)
-import sniffio
import trio.from_thread
+import trio.lowlevel
from outcome import Error, Outcome, Value
+from trio.lowlevel import (
+ current_root_task,
+ current_task,
+ wait_readable,
+ wait_writable,
+)
from trio.socket import SocketType as TrioSocketType
from trio.to_thread import run_sync
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
-from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable
from .._core._eventloop import claim_worker_thread
from .._core._exceptions import (
BrokenResourceError,
@@ -45,54 +52,42 @@ from .._core._exceptions import (
ClosedResourceError,
EndOfStream,
)
-from .._core._exceptions import ExceptionGroup as BaseExceptionGroup
from .._core._sockets import convert_ipv6_sockaddr
+from .._core._streams import create_memory_object_stream
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
from .._core._synchronization import Event as BaseEvent
from .._core._synchronization import ResourceGuard
from .._core._tasks import CancelScope as BaseCancelScope
-from ..abc import IPSockAddrType, UDPPacketType
-
-if TYPE_CHECKING:
- from trio_typing import TaskStatus
+from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType
+from ..abc._eventloop import AsyncBackend
+from ..streams.memory import MemoryObjectSendStream
-try:
- from trio import lowlevel as trio_lowlevel
-except ImportError:
- from trio import hazmat as trio_lowlevel # type: ignore[no-redef]
- from trio.hazmat import wait_readable, wait_writable
+if sys.version_info >= (3, 10):
+ from typing import ParamSpec
else:
- from trio.lowlevel import wait_readable, wait_writable
+ from typing_extensions import ParamSpec
-try:
- trio_open_process = trio_lowlevel.open_process
-except AttributeError:
- # isort: off
- from trio import ( # type: ignore[attr-defined, no-redef]
- open_process as trio_open_process,
- )
+if sys.version_info >= (3, 11):
+ from typing import TypeVarTuple, Unpack
+else:
+ from exceptiongroup import BaseExceptionGroup
+ from typing_extensions import TypeVarTuple, Unpack
+T = TypeVar("T")
T_Retval = TypeVar("T_Retval")
T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType)
+PosArgsT = TypeVarTuple("PosArgsT")
+P = ParamSpec("P")
#
# Event loop
#
-run = trio.run
-current_token = trio.lowlevel.current_trio_token
RunVar = trio.lowlevel.RunVar
#
-# Miscellaneous
-#
-
-sleep = trio.sleep
-
-
-#
# Timeouts and cancellation
#
@@ -117,13 +112,10 @@ class CancelScope(BaseCancelScope):
exc_tb: TracebackType | None,
) -> bool | None:
# https://github.com/python-trio/trio-typing/pull/79
- return self.__original.__exit__( # type: ignore[func-returns-value]
- exc_type, exc_val, exc_tb
- )
+ return self.__original.__exit__(exc_type, exc_val, exc_tb)
- def cancel(self) -> DeprecatedAwaitable:
+ def cancel(self) -> None:
self.__original.cancel()
- return DeprecatedAwaitable(self.cancel)
@property
def deadline(self) -> float:
@@ -138,6 +130,10 @@ class CancelScope(BaseCancelScope):
return self.__original.cancel_called
@property
+ def cancelled_caught(self) -> bool:
+ return self.__original.cancelled_caught
+
+ @property
def shield(self) -> bool:
return self.__original.shield
@@ -146,27 +142,15 @@ class CancelScope(BaseCancelScope):
self.__original.shield = value
-CancelledError = trio.Cancelled
-checkpoint = trio.lowlevel.checkpoint
-checkpoint_if_cancelled = trio.lowlevel.checkpoint_if_cancelled
-cancel_shielded_checkpoint = trio.lowlevel.cancel_shielded_checkpoint
-current_effective_deadline = trio.current_effective_deadline
-current_time = trio.current_time
-
-
#
# Task groups
#
-class ExceptionGroup(BaseExceptionGroup, trio.MultiError):
- pass
-
-
class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self._active = False
- self._nursery_manager = trio.open_nursery()
+ self._nursery_manager = trio.open_nursery(strict_exception_groups=True)
self.cancel_scope = None # type: ignore[assignment]
async def __aenter__(self) -> TaskGroup:
@@ -183,13 +167,21 @@ class TaskGroup(abc.TaskGroup):
) -> bool | None:
try:
return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb)
- except trio.MultiError as exc:
- raise ExceptionGroup(exc.exceptions) from None
+ except BaseExceptionGroup as exc:
+ _, rest = exc.split(trio.Cancelled)
+ if not rest:
+ cancelled_exc = trio.Cancelled._create()
+ raise cancelled_exc from exc
+
+ raise
finally:
self._active = False
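
With strict_exception_groups=True the nursery always raises a BaseExceptionGroup; __aexit__ uses split() to recognize a group made up solely of trio.Cancelled and collapses it back into one Cancelled so enclosing cancel scopes still see plain cancellation. split() in isolation (built-in on Python 3.11+, or via the exceptiongroup backport):

group = BaseExceptionGroup("demo", [ValueError("a"), TypeError("b")])
matched, rest = group.split(ValueError)
print(matched.exceptions)  # (ValueError('a'),)
print(rest.exceptions)     # (TypeError('b'),)

only_values, nothing = group.split((ValueError, TypeError))
print(nothing)             # None: every exception matched
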
def start_soon(
- self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
+ self,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
+ *args: Unpack[PosArgsT],
+ name: object = None,
) -> None:
if not self._active:
raise RuntimeError(
@@ -200,7 +192,7 @@ class TaskGroup(abc.TaskGroup):
async def start(
self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
- ) -> object:
+ ) -> Any:
if not self._active:
raise RuntimeError(
"This task group is not active; no new tasks can be started."
@@ -214,53 +206,6 @@ class TaskGroup(abc.TaskGroup):
#
-async def run_sync_in_worker_thread(
- func: Callable[..., T_Retval],
- *args: object,
- cancellable: bool = False,
- limiter: trio.CapacityLimiter | None = None,
-) -> T_Retval:
- def wrapper() -> T_Retval:
- with claim_worker_thread("trio"):
- return func(*args)
-
- # TODO: remove explicit context copying when trio 0.20 is the minimum requirement
- context = copy_context()
- context.run(sniffio.current_async_library_cvar.set, None)
- return await run_sync(
- context.run, wrapper, cancellable=cancellable, limiter=limiter
- )
-
-
-# TODO: remove this workaround when trio 0.20 is the minimum requirement
-def run_async_from_thread(
- fn: Callable[..., Awaitable[T_Retval]], *args: Any
-) -> T_Retval:
- async def wrapper() -> T_Retval:
- retval: T_Retval
-
- async def inner() -> None:
- nonlocal retval
- __tracebackhide__ = True
- retval = await fn(*args)
-
- async with trio.open_nursery() as n:
- context.run(n.start_soon, inner)
-
- __tracebackhide__ = True
- return retval # noqa: F821
-
- context = copy_context()
- context.run(sniffio.current_async_library_cvar.set, "trio")
- return trio.from_thread.run(wrapper)
-
-
-def run_sync_from_thread(fn: Callable[..., T_Retval], *args: Any) -> T_Retval:
- # TODO: remove explicit context copying when trio 0.20 is the minimum requirement
- retval = trio.from_thread.run_sync(copy_context().run, fn, *args)
- return cast(T_Retval, retval)
-
-
class BlockingPortal(abc.BlockingPortal):
def __new__(cls) -> BlockingPortal:
return object.__new__(cls)
@@ -271,16 +216,13 @@ class BlockingPortal(abc.BlockingPortal):
def _spawn_task_from_thread(
self,
- func: Callable,
- args: tuple,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+ args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
name: object,
- future: Future,
+ future: Future[T_Retval],
) -> None:
- context = copy_context()
- context.run(sniffio.current_async_library_cvar.set, "trio")
trio.from_thread.run_sync(
- context.run,
partial(self._task_group.start_soon, name=name),
self._call_func,
func,
@@ -341,14 +283,21 @@ class Process(abc.Process):
_stderr: abc.ByteReceiveStream | None
async def aclose(self) -> None:
- if self._stdin:
- await self._stdin.aclose()
- if self._stdout:
- await self._stdout.aclose()
- if self._stderr:
- await self._stderr.aclose()
+ with CancelScope(shield=True):
+ if self._stdin:
+ await self._stdin.aclose()
+ if self._stdout:
+ await self._stdout.aclose()
+ if self._stderr:
+ await self._stderr.aclose()
- await self.wait()
+ try:
+ await self.wait()
+ except BaseException:
+ self.kill()
+ with CancelScope(shield=True):
+ await self.wait()
+ raise
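
Process.aclose() now shields the stream teardown and, if waiting for the child gets interrupted, kills the process and reaps it under a shield so a cancelled cleanup cannot leave a zombie behind. The same shape expressed against the public anyio API (a sketch, not the trio backend itself):

import sys

import anyio
import anyio.abc

async def close_process(process: anyio.abc.Process) -> None:
    with anyio.CancelScope(shield=True):
        if process.stdin:
            await process.stdin.aclose()
    try:
        await process.wait()
    except BaseException:
        process.kill()  # interrupted while waiting: don't leave a zombie
        with anyio.CancelScope(shield=True):
            await process.wait()
        raise

async def main() -> None:
    process = await anyio.open_process([sys.executable, "-c", "print('hi')"])
    await close_process(process)
    print(process.returncode)

anyio.run(main)
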
async def wait(self) -> int:
return await self._process.wait()
@@ -383,47 +332,19 @@ class Process(abc.Process):
return self._stderr
-async def open_process(
- command: str | bytes | Sequence[str | bytes],
- *,
- shell: bool,
- stdin: int | IO[Any] | None,
- stdout: int | IO[Any] | None,
- stderr: int | IO[Any] | None,
- cwd: str | bytes | PathLike | None = None,
- env: Mapping[str, str] | None = None,
- start_new_session: bool = False,
-) -> Process:
- process = await trio_open_process( # type: ignore[misc]
- command, # type: ignore[arg-type]
- stdin=stdin,
- stdout=stdout,
- stderr=stderr,
- shell=shell,
- cwd=cwd,
- env=env,
- start_new_session=start_new_session,
- )
- stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None
- stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None
- stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None
- return Process(process, stdin_stream, stdout_stream, stderr_stream)
-
-
class _ProcessPoolShutdownInstrument(trio.abc.Instrument):
def after_run(self) -> None:
super().after_run()
-current_default_worker_process_limiter: RunVar = RunVar(
+current_default_worker_process_limiter: trio.lowlevel.RunVar = RunVar(
"current_default_worker_process_limiter"
)
-async def _shutdown_process_pool(workers: set[Process]) -> None:
- process: Process
+async def _shutdown_process_pool(workers: set[abc.Process]) -> None:
try:
- await sleep(math.inf)
+ await trio.sleep(math.inf)
except trio.Cancelled:
for process in workers:
if process.returncode is None:
@@ -434,10 +355,6 @@ async def _shutdown_process_pool(workers: set[Process]) -> None:
await process.aclose()
-def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None:
- trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers)
-
-
#
# Sockets and networking
#
@@ -515,7 +432,7 @@ class UNIXSocketStream(SocketStream, abc.UNIXSocketStream):
raise ValueError("maxfds must be a positive integer")
fds = array.array("i")
- await checkpoint()
+ await trio.lowlevel.checkpoint()
with self._receive_guard:
while True:
try:
@@ -555,7 +472,7 @@ class UNIXSocketStream(SocketStream, abc.UNIXSocketStream):
filenos.append(fd.fileno())
fdarray = array.array("i", filenos)
- await checkpoint()
+ await trio.lowlevel.checkpoint()
with self._send_guard:
while True:
try:
@@ -564,7 +481,7 @@ class UNIXSocketStream(SocketStream, abc.UNIXSocketStream):
[
(
socket.SOL_SOCKET,
- socket.SCM_RIGHTS, # type: ignore[list-item]
+ socket.SCM_RIGHTS,
fdarray,
)
],
@@ -648,76 +565,49 @@ class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocke
self._convert_socket_error(exc)
-async def connect_tcp(
- host: str, port: int, local_address: IPSockAddrType | None = None
-) -> SocketStream:
- family = socket.AF_INET6 if ":" in host else socket.AF_INET
- trio_socket = trio.socket.socket(family)
- trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
- if local_address:
- await trio_socket.bind(local_address)
-
- try:
- await trio_socket.connect((host, port))
- except BaseException:
- trio_socket.close()
- raise
-
- return SocketStream(trio_socket)
-
-
-async def connect_unix(path: str) -> UNIXSocketStream:
- trio_socket = trio.socket.socket(socket.AF_UNIX)
- try:
- await trio_socket.connect(path)
- except BaseException:
- trio_socket.close()
- raise
-
- return UNIXSocketStream(trio_socket)
-
-
-async def create_udp_socket(
- family: socket.AddressFamily,
- local_address: IPSockAddrType | None,
- remote_address: IPSockAddrType | None,
- reuse_port: bool,
-) -> UDPSocket | ConnectedUDPSocket:
- trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM)
-
- if reuse_port:
- trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-
- if local_address:
- await trio_socket.bind(local_address)
-
- if remote_address:
- await trio_socket.connect(remote_address)
- return ConnectedUDPSocket(trio_socket)
- else:
- return UDPSocket(trio_socket)
+class UNIXDatagramSocket(_TrioSocketMixin[str], abc.UNIXDatagramSocket):
+ def __init__(self, trio_socket: TrioSocketType) -> None:
+ super().__init__(trio_socket)
+ self._receive_guard = ResourceGuard("reading from")
+ self._send_guard = ResourceGuard("writing to")
+ async def receive(self) -> UNIXDatagramPacketType:
+ with self._receive_guard:
+ try:
+ data, addr = await self._trio_socket.recvfrom(65536)
+ return data, addr
+ except BaseException as exc:
+ self._convert_socket_error(exc)
-getaddrinfo = trio.socket.getaddrinfo
-getnameinfo = trio.socket.getnameinfo
+ async def send(self, item: UNIXDatagramPacketType) -> None:
+ with self._send_guard:
+ try:
+ await self._trio_socket.sendto(*item)
+ except BaseException as exc:
+ self._convert_socket_error(exc)
-async def wait_socket_readable(sock: socket.socket) -> None:
- try:
- await wait_readable(sock)
- except trio.ClosedResourceError as exc:
- raise ClosedResourceError().with_traceback(exc.__traceback__) from None
- except trio.BusyResourceError:
- raise BusyResourceError("reading from") from None
+class ConnectedUNIXDatagramSocket(
+ _TrioSocketMixin[str], abc.ConnectedUNIXDatagramSocket
+):
+ def __init__(self, trio_socket: TrioSocketType) -> None:
+ super().__init__(trio_socket)
+ self._receive_guard = ResourceGuard("reading from")
+ self._send_guard = ResourceGuard("writing to")
+ async def receive(self) -> bytes:
+ with self._receive_guard:
+ try:
+ return await self._trio_socket.recv(65536)
+ except BaseException as exc:
+ self._convert_socket_error(exc)
-async def wait_socket_writable(sock: socket.socket) -> None:
- try:
- await wait_writable(sock)
- except trio.ClosedResourceError as exc:
- raise ClosedResourceError().with_traceback(exc.__traceback__) from None
- except trio.BusyResourceError:
- raise BusyResourceError("writing to") from None
+ async def send(self, item: bytes) -> None:
+ with self._send_guard:
+ try:
+ await self._trio_socket.send(item)
+ except BaseException as exc:
+ self._convert_socket_error(exc)
#
@@ -742,19 +632,30 @@ class Event(BaseEvent):
orig_statistics = self.__original.statistics()
return EventStatistics(tasks_waiting=orig_statistics.tasks_waiting)
- def set(self) -> DeprecatedAwaitable:
+ def set(self) -> None:
self.__original.set()
- return DeprecatedAwaitable(self.set)
class CapacityLimiter(BaseCapacityLimiter):
- def __new__(cls, *args: object, **kwargs: object) -> CapacityLimiter:
+ def __new__(
+ cls,
+ total_tokens: float | None = None,
+ *,
+ original: trio.CapacityLimiter | None = None,
+ ) -> CapacityLimiter:
return object.__new__(cls)
def __init__(
- self, *args: Any, original: trio.CapacityLimiter | None = None
+ self,
+ total_tokens: float | None = None,
+ *,
+ original: trio.CapacityLimiter | None = None,
) -> None:
- self.__original = original or trio.CapacityLimiter(*args)
+ if original is not None:
+ self.__original = original
+ else:
+ assert total_tokens is not None
+ self.__original = trio.CapacityLimiter(total_tokens)
async def __aenter__(self) -> None:
return await self.__original.__aenter__()
@@ -783,13 +684,11 @@ class CapacityLimiter(BaseCapacityLimiter):
def available_tokens(self) -> float:
return self.__original.available_tokens
- def acquire_nowait(self) -> DeprecatedAwaitable:
+ def acquire_nowait(self) -> None:
self.__original.acquire_nowait()
- return DeprecatedAwaitable(self.acquire_nowait)
- def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
self.__original.acquire_on_behalf_of_nowait(borrower)
- return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait)
async def acquire(self) -> None:
await self.__original.acquire()
@@ -808,23 +707,12 @@ class CapacityLimiter(BaseCapacityLimiter):
return CapacityLimiterStatistics(
borrowed_tokens=orig.borrowed_tokens,
total_tokens=orig.total_tokens,
- borrowers=orig.borrowers,
+ borrowers=tuple(orig.borrowers),
tasks_waiting=orig.tasks_waiting,
)
-_capacity_limiter_wrapper: RunVar = RunVar("_capacity_limiter_wrapper")
-
-
-def current_default_thread_limiter() -> CapacityLimiter:
- try:
- return _capacity_limiter_wrapper.get()
- except LookupError:
- limiter = CapacityLimiter(
- original=trio.to_thread.current_default_thread_limiter()
- )
- _capacity_limiter_wrapper.set(limiter)
- return limiter
+_capacity_limiter_wrapper: trio.lowlevel.RunVar = RunVar("_capacity_limiter_wrapper")
#
@@ -832,7 +720,7 @@ def current_default_thread_limiter() -> CapacityLimiter:
#
-class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]):
+class _SignalReceiver:
_iterator: AsyncIterator[int]
def __init__(self, signals: tuple[Signals, ...]):
@@ -859,138 +747,423 @@ class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]):
return Signals(signum)
-def open_signal_receiver(*signals: Signals) -> _SignalReceiver:
- return _SignalReceiver(signals)
-
-
#
# Testing and debugging
#
-def get_current_task() -> TaskInfo:
- task = trio_lowlevel.current_task()
-
- parent_id = None
- if task.parent_nursery and task.parent_nursery.parent_task:
- parent_id = id(task.parent_nursery.parent_task)
-
- return TaskInfo(id(task), parent_id, task.name, task.coro)
-
-
-def get_running_tasks() -> list[TaskInfo]:
- root_task = trio_lowlevel.current_root_task()
- task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)]
- nurseries = root_task.child_nurseries
- while nurseries:
- new_nurseries: list[trio.Nursery] = []
- for nursery in nurseries:
- for task in nursery.child_tasks:
- task_infos.append(
- TaskInfo(id(task), id(nursery.parent_task), task.name, task.coro)
- )
- new_nurseries.extend(task.child_nurseries)
-
- nurseries = new_nurseries
-
- return task_infos
-
-
-def wait_all_tasks_blocked() -> Awaitable[None]:
- import trio.testing
-
- return trio.testing.wait_all_tasks_blocked()
-
-
class TestRunner(abc.TestRunner):
def __init__(self, **options: Any) -> None:
- from collections import deque
from queue import Queue
- self._call_queue: Queue[Callable[..., object]] = Queue()
- self._result_queue: deque[Outcome] = deque()
- self._stop_event: trio.Event | None = None
- self._nursery: trio.Nursery | None = None
+ self._call_queue: Queue[Callable[[], object]] = Queue()
+ self._send_stream: MemoryObjectSendStream | None = None
self._options = options
- async def _trio_main(self) -> None:
- self._stop_event = trio.Event()
- async with trio.open_nursery() as self._nursery:
- await self._stop_event.wait()
-
- async def _call_func(
- self, func: Callable[..., Awaitable[object]], args: tuple, kwargs: dict
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: types.TracebackType | None,
) -> None:
- try:
- retval = await func(*args, **kwargs)
- except BaseException as exc:
- self._result_queue.append(Error(exc))
- else:
- self._result_queue.append(Value(retval))
+ if self._send_stream:
+ self._send_stream.close()
+ while self._send_stream is not None:
+ self._call_queue.get()()
+
+ async def _run_tests_and_fixtures(self) -> None:
+ self._send_stream, receive_stream = create_memory_object_stream(1)
+ with receive_stream:
+ async for coro, outcome_holder in receive_stream:
+ try:
+ retval = await coro
+ except BaseException as exc:
+ outcome_holder.append(Error(exc))
+ else:
+ outcome_holder.append(Value(retval))
def _main_task_finished(self, outcome: object) -> None:
- self._nursery = None
+ self._send_stream = None
- def _get_nursery(self) -> trio.Nursery:
- if self._nursery is None:
+ def _call_in_runner_task(
+ self,
+ func: Callable[P, Awaitable[T_Retval]],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> T_Retval:
+ if self._send_stream is None:
trio.lowlevel.start_guest_run(
- self._trio_main,
+ self._run_tests_and_fixtures,
run_sync_soon_threadsafe=self._call_queue.put,
done_callback=self._main_task_finished,
**self._options,
)
- while self._nursery is None:
+ while self._send_stream is None:
self._call_queue.get()()
- return self._nursery
-
- def _call(
- self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
- ) -> T_Retval:
- self._get_nursery().start_soon(self._call_func, func, args, kwargs)
- while not self._result_queue:
+ outcome_holder: list[Outcome] = []
+ self._send_stream.send_nowait((func(*args, **kwargs), outcome_holder))
+ while not outcome_holder:
self._call_queue.get()()
- outcome = self._result_queue.pop()
- return outcome.unwrap()
-
- def close(self) -> None:
- if self._stop_event:
- self._stop_event.set()
- while self._nursery is not None:
- self._call_queue.get()()
+ return outcome_holder[0].unwrap()
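
The trio TestRunner keeps the guest-run model but swaps the nursery-plus-result-queue plumbing for the same memory-stream dispatcher as the asyncio runner; the host thread pumps the call queue until an outcome appears. The pumping loop in isolation (trio guest mode, hypothetical names):

import queue

import trio

def run_in_guest(async_fn):
    call_queue: queue.Queue = queue.Queue()
    done: list = []
    trio.lowlevel.start_guest_run(
        async_fn,
        run_sync_soon_threadsafe=call_queue.put,
        done_callback=done.append,  # receives an outcome.Outcome
    )
    while not done:
        call_queue.get()()  # execute the guest's scheduled callbacks
    return done[0].unwrap()

async def sample() -> str:
    await trio.sleep(0.01)
    return "ran under a guest run"

print(run_in_guest(sample))
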
def run_asyncgen_fixture(
self,
fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
kwargs: dict[str, Any],
) -> Iterable[T_Retval]:
- async def fixture_runner(*, task_status: TaskStatus[T_Retval]) -> None:
- agen = fixture_func(**kwargs)
- retval = await agen.asend(None)
- task_status.started(retval)
- await teardown_event.wait()
- try:
- await agen.asend(None)
- except StopAsyncIteration:
- pass
- else:
- await agen.aclose()
- raise RuntimeError("Async generator fixture did not stop")
+ asyncgen = fixture_func(**kwargs)
+ fixturevalue: T_Retval = self._call_in_runner_task(asyncgen.asend, None)
- teardown_event = trio.Event()
- fixture_value = self._call(lambda: self._get_nursery().start(fixture_runner))
- yield fixture_value
- teardown_event.set()
+ yield fixturevalue
+
+ try:
+ self._call_in_runner_task(asyncgen.asend, None)
+ except StopAsyncIteration:
+ pass
+ else:
+ self._call_in_runner_task(asyncgen.aclose)
+ raise RuntimeError("Async generator fixture did not stop")
def run_fixture(
self,
fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
kwargs: dict[str, Any],
) -> T_Retval:
- return self._call(fixture_func, **kwargs)
+ return self._call_in_runner_task(fixture_func, **kwargs)
def run_test(
self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
) -> None:
- self._call(test_func, **kwargs)
+ self._call_in_runner_task(test_func, **kwargs)
+
+
+class TrioBackend(AsyncBackend):
+ @classmethod
+ def run(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+ args: tuple[Unpack[PosArgsT]],
+ kwargs: dict[str, Any],
+ options: dict[str, Any],
+ ) -> T_Retval:
+ return trio.run(func, *args)
+
+ @classmethod
+ def current_token(cls) -> object:
+ return trio.lowlevel.current_trio_token()
+
+ @classmethod
+ def current_time(cls) -> float:
+ return trio.current_time()
+
+ @classmethod
+ def cancelled_exception_class(cls) -> type[BaseException]:
+ return trio.Cancelled
+
+ @classmethod
+ async def checkpoint(cls) -> None:
+ await trio.lowlevel.checkpoint()
+
+ @classmethod
+ async def checkpoint_if_cancelled(cls) -> None:
+ await trio.lowlevel.checkpoint_if_cancelled()
+
+ @classmethod
+ async def cancel_shielded_checkpoint(cls) -> None:
+ await trio.lowlevel.cancel_shielded_checkpoint()
+
+ @classmethod
+ async def sleep(cls, delay: float) -> None:
+ await trio.sleep(delay)
+
+ @classmethod
+ def create_cancel_scope(
+ cls, *, deadline: float = math.inf, shield: bool = False
+ ) -> abc.CancelScope:
+ return CancelScope(deadline=deadline, shield=shield)
+
+ @classmethod
+ def current_effective_deadline(cls) -> float:
+ return trio.current_effective_deadline()
+
+ @classmethod
+ def create_task_group(cls) -> abc.TaskGroup:
+ return TaskGroup()
+
+ @classmethod
+ def create_event(cls) -> abc.Event:
+ return Event()
+
+ @classmethod
+ def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
+ return CapacityLimiter(total_tokens)
+
+ @classmethod
+ async def run_sync_in_worker_thread(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
+ args: tuple[Unpack[PosArgsT]],
+ abandon_on_cancel: bool = False,
+ limiter: abc.CapacityLimiter | None = None,
+ ) -> T_Retval:
+ def wrapper() -> T_Retval:
+ with claim_worker_thread(TrioBackend, token):
+ return func(*args)
+
+ token = TrioBackend.current_token()
+ return await run_sync(
+ wrapper,
+ abandon_on_cancel=abandon_on_cancel,
+ limiter=cast(trio.CapacityLimiter, limiter),
+ )
+
+ @classmethod
+ def check_cancelled(cls) -> None:
+ trio.from_thread.check_cancelled()
+
+ @classmethod
+ def run_async_from_thread(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+ args: tuple[Unpack[PosArgsT]],
+ token: object,
+ ) -> T_Retval:
+ return trio.from_thread.run(func, *args)
+
+ @classmethod
+ def run_sync_from_thread(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
+ args: tuple[Unpack[PosArgsT]],
+ token: object,
+ ) -> T_Retval:
+ return trio.from_thread.run_sync(func, *args)
+
+ @classmethod
+ def create_blocking_portal(cls) -> abc.BlockingPortal:
+ return BlockingPortal()
+
+ @classmethod
+ async def open_process(
+ cls,
+ command: str | bytes | Sequence[str | bytes],
+ *,
+ shell: bool,
+ stdin: int | IO[Any] | None,
+ stdout: int | IO[Any] | None,
+ stderr: int | IO[Any] | None,
+ cwd: str | bytes | PathLike | None = None,
+ env: Mapping[str, str] | None = None,
+ start_new_session: bool = False,
+ ) -> Process:
+ process = await trio.lowlevel.open_process( # type: ignore[misc]
+ command, # type: ignore[arg-type]
+ stdin=stdin,
+ stdout=stdout,
+ stderr=stderr,
+ shell=shell,
+ cwd=cwd,
+ env=env,
+ start_new_session=start_new_session,
+ )
+ stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None
+ stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None
+ stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None
+ return Process(process, stdin_stream, stdout_stream, stderr_stream)
+
+ @classmethod
+ def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
+ trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers)
+
+ @classmethod
+ async def connect_tcp(
+ cls, host: str, port: int, local_address: IPSockAddrType | None = None
+ ) -> SocketStream:
+ family = socket.AF_INET6 if ":" in host else socket.AF_INET
+ trio_socket = trio.socket.socket(family)
+ trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ if local_address:
+ await trio_socket.bind(local_address)
+
+ try:
+ await trio_socket.connect((host, port))
+ except BaseException:
+ trio_socket.close()
+ raise
+
+ return SocketStream(trio_socket)
+
+ @classmethod
+ async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
+ trio_socket = trio.socket.socket(socket.AF_UNIX)
+ try:
+ await trio_socket.connect(path)
+ except BaseException:
+ trio_socket.close()
+ raise
+
+ return UNIXSocketStream(trio_socket)
+
+ @classmethod
+ def create_tcp_listener(cls, sock: socket.socket) -> abc.SocketListener:
+ return TCPSocketListener(sock)
+
+ @classmethod
+ def create_unix_listener(cls, sock: socket.socket) -> abc.SocketListener:
+ return UNIXSocketListener(sock)
+
+ @classmethod
+ async def create_udp_socket(
+ cls,
+ family: socket.AddressFamily,
+ local_address: IPSockAddrType | None,
+ remote_address: IPSockAddrType | None,
+ reuse_port: bool,
+ ) -> UDPSocket | ConnectedUDPSocket:
+ trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM)
+
+ if reuse_port:
+ trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+
+ if local_address:
+ await trio_socket.bind(local_address)
+
+ if remote_address:
+ await trio_socket.connect(remote_address)
+ return ConnectedUDPSocket(trio_socket)
+ else:
+ return UDPSocket(trio_socket)
+
+ @classmethod
+ @overload
+ async def create_unix_datagram_socket(
+ cls, raw_socket: socket.socket, remote_path: None
+ ) -> abc.UNIXDatagramSocket:
+ ...
+
+ @classmethod
+ @overload
+ async def create_unix_datagram_socket(
+ cls, raw_socket: socket.socket, remote_path: str | bytes
+ ) -> abc.ConnectedUNIXDatagramSocket:
+ ...
+
+ @classmethod
+ async def create_unix_datagram_socket(
+ cls, raw_socket: socket.socket, remote_path: str | bytes | None
+ ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
+ trio_socket = trio.socket.from_stdlib_socket(raw_socket)
+
+ if remote_path:
+ await trio_socket.connect(remote_path)
+ return ConnectedUNIXDatagramSocket(trio_socket)
+ else:
+ return UNIXDatagramSocket(trio_socket)
+
+ @classmethod
+ async def getaddrinfo(
+ cls,
+ host: bytes | str | None,
+ port: str | int | None,
+ *,
+ family: int | AddressFamily = 0,
+ type: int | SocketKind = 0,
+ proto: int = 0,
+ flags: int = 0,
+ ) -> list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ tuple[str, int] | tuple[str, int, int, int],
+ ]
+ ]:
+ return await trio.socket.getaddrinfo(host, port, family, type, proto, flags)
+
+ @classmethod
+ async def getnameinfo(
+ cls, sockaddr: IPSockAddrType, flags: int = 0
+ ) -> tuple[str, str]:
+ return await trio.socket.getnameinfo(sockaddr, flags)
+
+ @classmethod
+ async def wait_socket_readable(cls, sock: socket.socket) -> None:
+ try:
+ await wait_readable(sock)
+ except trio.ClosedResourceError as exc:
+ raise ClosedResourceError().with_traceback(exc.__traceback__) from None
+ except trio.BusyResourceError:
+ raise BusyResourceError("reading from") from None
+
+ @classmethod
+ async def wait_socket_writable(cls, sock: socket.socket) -> None:
+ try:
+ await wait_writable(sock)
+ except trio.ClosedResourceError as exc:
+ raise ClosedResourceError().with_traceback(exc.__traceback__) from None
+ except trio.BusyResourceError:
+ raise BusyResourceError("writing to") from None
+
+ @classmethod
+ def current_default_thread_limiter(cls) -> CapacityLimiter:
+ try:
+ return _capacity_limiter_wrapper.get()
+ except LookupError:
+ limiter = CapacityLimiter(
+ original=trio.to_thread.current_default_thread_limiter()
+ )
+ _capacity_limiter_wrapper.set(limiter)
+ return limiter
+
+ @classmethod
+ def open_signal_receiver(
+ cls, *signals: Signals
+ ) -> ContextManager[AsyncIterator[Signals]]:
+ return _SignalReceiver(signals)
+
+ @classmethod
+ def get_current_task(cls) -> TaskInfo:
+ task = current_task()
+
+ parent_id = None
+ if task.parent_nursery and task.parent_nursery.parent_task:
+ parent_id = id(task.parent_nursery.parent_task)
+
+ return TaskInfo(id(task), parent_id, task.name, task.coro)
+
+ @classmethod
+ def get_running_tasks(cls) -> list[TaskInfo]:
+ root_task = current_root_task()
+ assert root_task
+ task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)]
+ nurseries = root_task.child_nurseries
+ while nurseries:
+ new_nurseries: list[trio.Nursery] = []
+ for nursery in nurseries:
+ for task in nursery.child_tasks:
+ task_infos.append(
+ TaskInfo(
+ id(task), id(nursery.parent_task), task.name, task.coro
+ )
+ )
+ new_nurseries.extend(task.child_nurseries)
+
+ nurseries = new_nurseries
+
+ return task_infos
+
+ @classmethod
+ async def wait_all_tasks_blocked(cls) -> None:
+ from trio.testing import wait_all_tasks_blocked
+
+ await wait_all_tasks_blocked()
+
+ @classmethod
+ def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
+ return TestRunner(**options)
+
+
+backend_class = TrioBackend
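The module-level ``backend_class`` attribute is the hook that the rewritten ``get_async_backend()`` (see the ``_eventloop.py`` diff below) uses to locate the trio implementation. A minimal sketch of driving it through the public API, assuming ``trio`` is installed alongside ``anyio``:

    import anyio

    async def main() -> None:
        # Dispatched to trio.sleep() via TrioBackend.sleep() above.
        await anyio.sleep(0.1)
        print("done")

    # anyio.run() imports anyio._backends._trio and picks up its
    # module-level `backend_class` attribute (TrioBackend).
    anyio.run(main, backend="trio")
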
diff --git a/contrib/python/anyio/anyio/_core/_compat.py b/contrib/python/anyio/anyio/_core/_compat.py
deleted file mode 100644
index 22d29ab8ac..0000000000
--- a/contrib/python/anyio/anyio/_core/_compat.py
+++ /dev/null
@@ -1,217 +0,0 @@
-from __future__ import annotations
-
-from abc import ABCMeta, abstractmethod
-from contextlib import AbstractContextManager
-from types import TracebackType
-from typing import (
- TYPE_CHECKING,
- Any,
- AsyncContextManager,
- Callable,
- ContextManager,
- Generator,
- Generic,
- Iterable,
- List,
- TypeVar,
- Union,
- overload,
-)
-from warnings import warn
-
-if TYPE_CHECKING:
- from ._testing import TaskInfo
-else:
- TaskInfo = object
-
-T = TypeVar("T")
-AnyDeprecatedAwaitable = Union[
- "DeprecatedAwaitable",
- "DeprecatedAwaitableFloat",
- "DeprecatedAwaitableList[T]",
- TaskInfo,
-]
-
-
-@overload
-async def maybe_async(__obj: TaskInfo) -> TaskInfo:
- ...
-
-
-@overload
-async def maybe_async(__obj: DeprecatedAwaitableFloat) -> float:
- ...
-
-
-@overload
-async def maybe_async(__obj: DeprecatedAwaitableList[T]) -> list[T]:
- ...
-
-
-@overload
-async def maybe_async(__obj: DeprecatedAwaitable) -> None:
- ...
-
-
-async def maybe_async(
- __obj: AnyDeprecatedAwaitable[T],
-) -> TaskInfo | float | list[T] | None:
- """
- Await on the given object if necessary.
-
- This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and
- methods were converted from coroutine functions into regular functions.
-
- Do **not** try to use this for any other purpose!
-
- :return: the result of awaiting on the object if coroutine, or the object itself otherwise
-
- .. versionadded:: 2.2
-
- """
- return __obj._unwrap()
-
-
-class _ContextManagerWrapper:
- def __init__(self, cm: ContextManager[T]):
- self._cm = cm
-
- async def __aenter__(self) -> T:
- return self._cm.__enter__()
-
- async def __aexit__(
- self,
- exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> bool | None:
- return self._cm.__exit__(exc_type, exc_val, exc_tb)
-
-
-def maybe_async_cm(
- cm: ContextManager[T] | AsyncContextManager[T],
-) -> AsyncContextManager[T]:
- """
- Wrap a regular context manager as an async one if necessary.
-
- This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and
- methods were changed to return regular context managers instead of async ones.
-
- :param cm: a regular or async context manager
- :return: an async context manager
-
- .. versionadded:: 2.2
-
- """
- if not isinstance(cm, AbstractContextManager):
- raise TypeError("Given object is not an context manager")
-
- return _ContextManagerWrapper(cm)
-
-
-def _warn_deprecation(
- awaitable: AnyDeprecatedAwaitable[Any], stacklevel: int = 1
-) -> None:
- warn(
- f'Awaiting on {awaitable._name}() is deprecated. Use "await '
- f"anyio.maybe_async({awaitable._name}(...)) if you have to support both AnyIO 2.x "
- f'and 3.x, or just remove the "await" if you are completely migrating to AnyIO 3+.',
- DeprecationWarning,
- stacklevel=stacklevel + 1,
- )
-
-
-class DeprecatedAwaitable:
- def __init__(self, func: Callable[..., DeprecatedAwaitable]):
- self._name = f"{func.__module__}.{func.__qualname__}"
-
- def __await__(self) -> Generator[None, None, None]:
- _warn_deprecation(self)
- if False:
- yield
-
- def __reduce__(self) -> tuple[type[None], tuple[()]]:
- return type(None), ()
-
- def _unwrap(self) -> None:
- return None
-
-
-class DeprecatedAwaitableFloat(float):
- def __new__(
- cls, x: float, func: Callable[..., DeprecatedAwaitableFloat]
- ) -> DeprecatedAwaitableFloat:
- return super().__new__(cls, x)
-
- def __init__(self, x: float, func: Callable[..., DeprecatedAwaitableFloat]):
- self._name = f"{func.__module__}.{func.__qualname__}"
-
- def __await__(self) -> Generator[None, None, float]:
- _warn_deprecation(self)
- if False:
- yield
-
- return float(self)
-
- def __reduce__(self) -> tuple[type[float], tuple[float]]:
- return float, (float(self),)
-
- def _unwrap(self) -> float:
- return float(self)
-
-
-class DeprecatedAwaitableList(List[T]):
- def __init__(
- self,
- iterable: Iterable[T] = (),
- *,
- func: Callable[..., DeprecatedAwaitableList[T]],
- ):
- super().__init__(iterable)
- self._name = f"{func.__module__}.{func.__qualname__}"
-
- def __await__(self) -> Generator[None, None, list[T]]:
- _warn_deprecation(self)
- if False:
- yield
-
- return list(self)
-
- def __reduce__(self) -> tuple[type[list[T]], tuple[list[T]]]:
- return list, (list(self),)
-
- def _unwrap(self) -> list[T]:
- return list(self)
-
-
-class DeprecatedAsyncContextManager(Generic[T], metaclass=ABCMeta):
- @abstractmethod
- def __enter__(self) -> T:
- pass
-
- @abstractmethod
- def __exit__(
- self,
- exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> bool | None:
- pass
-
- async def __aenter__(self) -> T:
- warn(
- f"Using {self.__class__.__name__} as an async context manager has been deprecated. "
- f'Use "async with anyio.maybe_async_cm(yourcontextmanager) as foo:" if you have to '
- f'support both AnyIO 2.x and 3.x, or just remove the "async" from "async with" if '
- f"you are completely migrating to AnyIO 3+.",
- DeprecationWarning,
- )
- return self.__enter__()
-
- async def __aexit__(
- self,
- exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> bool | None:
- return self.__exit__(exc_type, exc_val, exc_tb)
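The deleted ``_compat.py`` shim existed only to bridge AnyIO 2.x call sites, which still awaited functions such as ``current_time()``, over to 3.x where those became plain functions. With the shim gone, the transitional spelling must be dropped entirely; a before/after sketch:

    import anyio

    async def main() -> None:
        # Transitional AnyIO 2.x/3.x style, removed along with _compat.py:
        #     now = await anyio.maybe_async(anyio.current_time())
        # AnyIO 4.x: current_time() is a plain function returning a float.
        now = anyio.current_time()
        print(f"event loop clock: {now:.3f}s")

    anyio.run(main)
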
diff --git a/contrib/python/anyio/anyio/_core/_eventloop.py b/contrib/python/anyio/anyio/_core/_eventloop.py
index ae9864851b..a9c6e82585 100644
--- a/contrib/python/anyio/anyio/_core/_eventloop.py
+++ b/contrib/python/anyio/anyio/_core/_eventloop.py
@@ -3,30 +3,33 @@ from __future__ import annotations
import math
import sys
import threading
+from collections.abc import Awaitable, Callable, Generator
from contextlib import contextmanager
from importlib import import_module
-from typing import (
- Any,
- Awaitable,
- Callable,
- Generator,
- TypeVar,
-)
+from typing import TYPE_CHECKING, Any, TypeVar
import sniffio
-# This must be updated when new backends are introduced
-from ._compat import DeprecatedAwaitableFloat
+if sys.version_info >= (3, 11):
+ from typing import TypeVarTuple, Unpack
+else:
+ from typing_extensions import TypeVarTuple, Unpack
+
+if TYPE_CHECKING:
+ from ..abc import AsyncBackend
+# This must be updated when new backends are introduced
BACKENDS = "asyncio", "trio"
T_Retval = TypeVar("T_Retval")
+PosArgsT = TypeVarTuple("PosArgsT")
+
threadlocals = threading.local()
def run(
- func: Callable[..., Awaitable[T_Retval]],
- *args: object,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+ *args: Unpack[PosArgsT],
backend: str = "asyncio",
backend_options: dict[str, Any] | None = None,
) -> T_Retval:
@@ -37,12 +40,13 @@ def run(
:param func: a coroutine function
:param args: positional arguments to ``func``
- :param backend: name of the asynchronous event loop implementation – currently either
- ``asyncio`` or ``trio``
- :param backend_options: keyword arguments to call the backend ``run()`` implementation with
- (documented :ref:`here <backend options>`)
+ :param backend: name of the asynchronous event loop implementation – currently
+ either ``asyncio`` or ``trio``
+ :param backend_options: keyword arguments to call the backend ``run()``
+ implementation with (documented :ref:`here <backend options>`)
:return: the return value of the coroutine function
- :raises RuntimeError: if an asynchronous event loop is already running in this thread
+ :raises RuntimeError: if an asynchronous event loop is already running in this
+ thread
:raises LookupError: if the named backend is not found
"""
@@ -54,18 +58,19 @@ def run(
raise RuntimeError(f"Already running {asynclib_name} in this thread")
try:
- asynclib = import_module(f"..._backends._{backend}", package=__name__)
+ async_backend = get_async_backend(backend)
except ImportError as exc:
raise LookupError(f"No such backend: {backend}") from exc
token = None
if sniffio.current_async_library_cvar.get(None) is None:
- # Since we're in control of the event loop, we can cache the name of the async library
+ # Since we're in control of the event loop, we can cache the name of the async
+ # library
token = sniffio.current_async_library_cvar.set(backend)
try:
backend_options = backend_options or {}
- return asynclib.run(func, *args, **backend_options)
+ return async_backend.run(func, args, {}, backend_options)
finally:
if token:
sniffio.current_async_library_cvar.reset(token)
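``run()`` now forwards the positional arguments as a tuple (``async_backend.run(func, args, {}, backend_options)``) and types them with a ``TypeVarTuple``, so a type checker can match ``*args`` against the coroutine function's signature; the runtime calling convention is unchanged. A short sketch:

    import anyio

    async def greet(name: str, times: int) -> str:
        for _ in range(times):
            await anyio.sleep(0)
        return f"hello, {name}"

    # *args are validated against greet's parameters by static type
    # checkers thanks to Unpack[PosArgsT]; the call itself is as before.
    print(anyio.run(greet, "world", 2))
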
@@ -78,7 +83,7 @@ async def sleep(delay: float) -> None:
:param delay: the duration, in seconds
"""
- return await get_asynclib().sleep(delay)
+ return await get_async_backend().sleep(delay)
async def sleep_forever() -> None:
@@ -97,8 +102,8 @@ async def sleep_until(deadline: float) -> None:
"""
Pause the current task until the given time.
- :param deadline: the absolute time to wake up at (according to the internal monotonic clock of
- the event loop)
+ :param deadline: the absolute time to wake up at (according to the internal
+ monotonic clock of the event loop)
.. versionadded:: 3.1
@@ -107,14 +112,14 @@ async def sleep_until(deadline: float) -> None:
await sleep(max(deadline - now, 0))
-def current_time() -> DeprecatedAwaitableFloat:
+def current_time() -> float:
"""
Return the current value of the event loop's internal clock.
:return: the clock value (seconds)
"""
- return DeprecatedAwaitableFloat(get_asynclib().current_time(), current_time)
+ return get_async_backend().current_time()
def get_all_backends() -> tuple[str, ...]:
@@ -124,7 +129,7 @@ def get_all_backends() -> tuple[str, ...]:
def get_cancelled_exc_class() -> type[BaseException]:
"""Return the current async library's cancellation exception class."""
- return get_asynclib().CancelledError
+ return get_async_backend().cancelled_exception_class()
#
@@ -133,21 +138,26 @@ def get_cancelled_exc_class() -> type[BaseException]:
@contextmanager
-def claim_worker_thread(backend: str) -> Generator[Any, None, None]:
- module = sys.modules["anyio._backends._" + backend]
- threadlocals.current_async_module = module
+def claim_worker_thread(
+ backend_class: type[AsyncBackend], token: object
+) -> Generator[Any, None, None]:
+ threadlocals.current_async_backend = backend_class
+ threadlocals.current_token = token
try:
yield
finally:
- del threadlocals.current_async_module
+ del threadlocals.current_async_backend
+ del threadlocals.current_token
-def get_asynclib(asynclib_name: str | None = None) -> Any:
+def get_async_backend(asynclib_name: str | None = None) -> AsyncBackend:
if asynclib_name is None:
asynclib_name = sniffio.current_async_library()
modulename = "anyio._backends._" + asynclib_name
try:
- return sys.modules[modulename]
+ module = sys.modules[modulename]
except KeyError:
- return import_module(modulename)
+ module = import_module(modulename)
+
+ return getattr(module, "backend_class")
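Callers of ``get_async_backend()`` now receive an ``AsyncBackend`` class rather than a backend module, and the public helpers above are thin delegations to its classmethods. For example:

    import anyio

    print(anyio.get_all_backends())  # ('asyncio', 'trio')

    async def main() -> None:
        # Resolved via AsyncBackend.cancelled_exception_class() rather than
        # a module-level CancelledError attribute.
        print(anyio.get_cancelled_exc_class())

    anyio.run(main)
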
diff --git a/contrib/python/anyio/anyio/_core/_exceptions.py b/contrib/python/anyio/anyio/_core/_exceptions.py
index 92ccd77a2d..571c3b8531 100644
--- a/contrib/python/anyio/anyio/_core/_exceptions.py
+++ b/contrib/python/anyio/anyio/_core/_exceptions.py
@@ -1,24 +1,25 @@
from __future__ import annotations
-from traceback import format_exception
-
class BrokenResourceError(Exception):
"""
- Raised when trying to use a resource that has been rendered unusable due to external causes
- (e.g. a send stream whose peer has disconnected).
+ Raised when trying to use a resource that has been rendered unusable due to external
+ causes (e.g. a send stream whose peer has disconnected).
"""
class BrokenWorkerProcess(Exception):
"""
- Raised by :func:`run_sync_in_process` if the worker process terminates abruptly or otherwise
- misbehaves.
+ Raised by :func:`run_sync_in_process` if the worker process terminates abruptly or
+ otherwise misbehaves.
"""
class BusyResourceError(Exception):
- """Raised when two tasks are trying to read from or write to the same resource concurrently."""
+ """
+ Raised when two tasks are trying to read from or write to the same resource
+ concurrently.
+ """
def __init__(self, action: str):
super().__init__(f"Another task is already {action} this resource")
@@ -30,7 +31,8 @@ class ClosedResourceError(Exception):
class DelimiterNotFound(Exception):
"""
- Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the
+ Raised during
+ :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the
maximum number of bytes has been read without the delimiter being found.
"""
@@ -41,38 +43,15 @@ class DelimiterNotFound(Exception):
class EndOfStream(Exception):
- """Raised when trying to read from a stream that has been closed from the other end."""
-
-
-class ExceptionGroup(BaseException):
"""
- Raised when multiple exceptions have been raised in a task group.
-
- :var ~typing.Sequence[BaseException] exceptions: the sequence of exceptions raised together
+ Raised when trying to read from a stream that has been closed from the other end.
"""
- SEPARATOR = "----------------------------\n"
-
- exceptions: list[BaseException]
-
- def __str__(self) -> str:
- tracebacks = [
- "".join(format_exception(type(exc), exc, exc.__traceback__))
- for exc in self.exceptions
- ]
- return (
- f"{len(self.exceptions)} exceptions were raised in the task group:\n"
- f"{self.SEPARATOR}{self.SEPARATOR.join(tracebacks)}"
- )
-
- def __repr__(self) -> str:
- exception_reprs = ", ".join(repr(exc) for exc in self.exceptions)
- return f"<{self.__class__.__name__}: {exception_reprs}>"
-
class IncompleteRead(Exception):
"""
- Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_exactly` or
+ Raised during
+ :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_exactly` or
:meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the
connection is closed before the requested amount of bytes has been read.
"""
@@ -85,8 +64,8 @@ class IncompleteRead(Exception):
class TypedAttributeLookupError(LookupError):
"""
- Raised by :meth:`~anyio.TypedAttributeProvider.extra` when the given typed attribute is not
- found and no default value has been given.
+ Raised by :meth:`~anyio.TypedAttributeProvider.extra` when the given typed attribute
+ is not found and no default value has been given.
"""
diff --git a/contrib/python/anyio/anyio/_core/_fileio.py b/contrib/python/anyio/anyio/_core/_fileio.py
index 35e8e8af6c..d054be693d 100644
--- a/contrib/python/anyio/anyio/_core/_fileio.py
+++ b/contrib/python/anyio/anyio/_core/_fileio.py
@@ -3,6 +3,7 @@ from __future__ import annotations
import os
import pathlib
import sys
+from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass
from functools import partial
from os import PathLike
@@ -12,23 +13,14 @@ from typing import (
Any,
AnyStr,
AsyncIterator,
- Callable,
+ Final,
Generic,
- Iterable,
- Iterator,
- Sequence,
- cast,
overload,
)
from .. import to_thread
from ..abc import AsyncResource
-if sys.version_info >= (3, 8):
- from typing import Final
-else:
- from typing_extensions import Final
-
if TYPE_CHECKING:
from _typeshed import OpenBinaryMode, OpenTextMode, ReadableBuffer, WriteableBuffer
else:
@@ -39,8 +31,8 @@ class AsyncFile(AsyncResource, Generic[AnyStr]):
"""
An asynchronous file object.
- This class wraps a standard file object and provides async friendly versions of the following
- blocking methods (where available on the original file object):
+ This class wraps a standard file object and provides async friendly versions of the
+ following blocking methods (where available on the original file object):
* read
* read1
@@ -57,8 +49,8 @@ class AsyncFile(AsyncResource, Generic[AnyStr]):
All other methods are directly passed through.
- This class supports the asynchronous context manager protocol which closes the underlying file
- at the end of the context block.
+ This class supports the asynchronous context manager protocol which closes the
+ underlying file at the end of the context block.
This class also supports asynchronous iteration::
@@ -212,22 +204,25 @@ class _PathIterator(AsyncIterator["Path"]):
iterator: Iterator[PathLike[str]]
async def __anext__(self) -> Path:
- nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True)
+ nextval = await to_thread.run_sync(
+ next, self.iterator, None, abandon_on_cancel=True
+ )
if nextval is None:
raise StopAsyncIteration from None
- return Path(cast("PathLike[str]", nextval))
+ return Path(nextval)
class Path:
"""
An asynchronous version of :class:`pathlib.Path`.
- This class cannot be substituted for :class:`pathlib.Path` or :class:`pathlib.PurePath`, but
- it is compatible with the :class:`os.PathLike` interface.
+ This class cannot be substituted for :class:`pathlib.Path` or
+ :class:`pathlib.PurePath`, but it is compatible with the :class:`os.PathLike`
+ interface.
- It implements the Python 3.10 version of :class:`pathlib.Path` interface, except for the
- deprecated :meth:`~pathlib.Path.link_to` method.
+ It implements the Python 3.10 version of :class:`pathlib.Path` interface, except for
+ the deprecated :meth:`~pathlib.Path.link_to` method.
Any methods that do disk I/O need to be awaited on. These methods are:
@@ -263,7 +258,8 @@ class Path:
* :meth:`~pathlib.Path.write_bytes`
* :meth:`~pathlib.Path.write_text`
- Additionally, the following methods return an async iterator yielding :class:`~.Path` objects:
+ Additionally, the following methods return an async iterator yielding
+ :class:`~.Path` objects:
* :meth:`~pathlib.Path.glob`
* :meth:`~pathlib.Path.iterdir`
@@ -296,26 +292,26 @@ class Path:
target = other._path if isinstance(other, Path) else other
return self._path.__eq__(target)
- def __lt__(self, other: Path) -> bool:
+ def __lt__(self, other: pathlib.PurePath | Path) -> bool:
target = other._path if isinstance(other, Path) else other
return self._path.__lt__(target)
- def __le__(self, other: Path) -> bool:
+ def __le__(self, other: pathlib.PurePath | Path) -> bool:
target = other._path if isinstance(other, Path) else other
return self._path.__le__(target)
- def __gt__(self, other: Path) -> bool:
+ def __gt__(self, other: pathlib.PurePath | Path) -> bool:
target = other._path if isinstance(other, Path) else other
return self._path.__gt__(target)
- def __ge__(self, other: Path) -> bool:
+ def __ge__(self, other: pathlib.PurePath | Path) -> bool:
target = other._path if isinstance(other, Path) else other
return self._path.__ge__(target)
- def __truediv__(self, other: Any) -> Path:
+ def __truediv__(self, other: str | PathLike[str]) -> Path:
return Path(self._path / other)
- def __rtruediv__(self, other: Any) -> Path:
+ def __rtruediv__(self, other: str | PathLike[str]) -> Path:
return Path(other) / self
@property
@@ -371,13 +367,16 @@ class Path:
def match(self, path_pattern: str) -> bool:
return self._path.match(path_pattern)
- def is_relative_to(self, *other: str | PathLike[str]) -> bool:
+ def is_relative_to(self, other: str | PathLike[str]) -> bool:
try:
- self.relative_to(*other)
+ self.relative_to(other)
return True
except ValueError:
return False
+ async def is_junction(self) -> bool:
+ return await to_thread.run_sync(self._path.is_junction)
+
async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None:
func = partial(os.chmod, follow_symlinks=follow_symlinks)
return await to_thread.run_sync(func, self._path, mode)
@@ -388,19 +387,23 @@ class Path:
return cls(path)
async def exists(self) -> bool:
- return await to_thread.run_sync(self._path.exists, cancellable=True)
+ return await to_thread.run_sync(self._path.exists, abandon_on_cancel=True)
async def expanduser(self) -> Path:
- return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True))
+ return Path(
+ await to_thread.run_sync(self._path.expanduser, abandon_on_cancel=True)
+ )
def glob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.glob(pattern)
return _PathIterator(gen)
async def group(self) -> str:
- return await to_thread.run_sync(self._path.group, cancellable=True)
+ return await to_thread.run_sync(self._path.group, abandon_on_cancel=True)
- async def hardlink_to(self, target: str | pathlib.Path | Path) -> None:
+ async def hardlink_to(
+ self, target: str | bytes | PathLike[str] | PathLike[bytes]
+ ) -> None:
if isinstance(target, Path):
target = target._path
@@ -415,31 +418,37 @@ class Path:
return self._path.is_absolute()
async def is_block_device(self) -> bool:
- return await to_thread.run_sync(self._path.is_block_device, cancellable=True)
+ return await to_thread.run_sync(
+ self._path.is_block_device, abandon_on_cancel=True
+ )
async def is_char_device(self) -> bool:
- return await to_thread.run_sync(self._path.is_char_device, cancellable=True)
+ return await to_thread.run_sync(
+ self._path.is_char_device, abandon_on_cancel=True
+ )
async def is_dir(self) -> bool:
- return await to_thread.run_sync(self._path.is_dir, cancellable=True)
+ return await to_thread.run_sync(self._path.is_dir, abandon_on_cancel=True)
async def is_fifo(self) -> bool:
- return await to_thread.run_sync(self._path.is_fifo, cancellable=True)
+ return await to_thread.run_sync(self._path.is_fifo, abandon_on_cancel=True)
async def is_file(self) -> bool:
- return await to_thread.run_sync(self._path.is_file, cancellable=True)
+ return await to_thread.run_sync(self._path.is_file, abandon_on_cancel=True)
async def is_mount(self) -> bool:
- return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True)
+ return await to_thread.run_sync(
+ os.path.ismount, self._path, abandon_on_cancel=True
+ )
def is_reserved(self) -> bool:
return self._path.is_reserved()
async def is_socket(self) -> bool:
- return await to_thread.run_sync(self._path.is_socket, cancellable=True)
+ return await to_thread.run_sync(self._path.is_socket, abandon_on_cancel=True)
async def is_symlink(self) -> bool:
- return await to_thread.run_sync(self._path.is_symlink, cancellable=True)
+ return await to_thread.run_sync(self._path.is_symlink, abandon_on_cancel=True)
def iterdir(self) -> AsyncIterator[Path]:
gen = self._path.iterdir()
@@ -452,7 +461,7 @@ class Path:
await to_thread.run_sync(self._path.lchmod, mode)
async def lstat(self) -> os.stat_result:
- return await to_thread.run_sync(self._path.lstat, cancellable=True)
+ return await to_thread.run_sync(self._path.lstat, abandon_on_cancel=True)
async def mkdir(
self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False
@@ -495,7 +504,7 @@ class Path:
return AsyncFile(fp)
async def owner(self) -> str:
- return await to_thread.run_sync(self._path.owner, cancellable=True)
+ return await to_thread.run_sync(self._path.owner, abandon_on_cancel=True)
async def read_bytes(self) -> bytes:
return await to_thread.run_sync(self._path.read_bytes)
@@ -505,12 +514,21 @@ class Path:
) -> str:
return await to_thread.run_sync(self._path.read_text, encoding, errors)
- def relative_to(self, *other: str | PathLike[str]) -> Path:
- return Path(self._path.relative_to(*other))
+ if sys.version_info >= (3, 12):
+
+ def relative_to(
+ self, *other: str | PathLike[str], walk_up: bool = False
+ ) -> Path:
+ return Path(self._path.relative_to(*other, walk_up=walk_up))
+
+ else:
+
+ def relative_to(self, *other: str | PathLike[str]) -> Path:
+ return Path(self._path.relative_to(*other))
async def readlink(self) -> Path:
target = await to_thread.run_sync(os.readlink, self._path)
- return Path(cast(str, target))
+ return Path(target)
async def rename(self, target: str | pathlib.PurePath | Path) -> Path:
if isinstance(target, Path):
@@ -528,7 +546,7 @@ class Path:
async def resolve(self, strict: bool = False) -> Path:
func = partial(self._path.resolve, strict=strict)
- return Path(await to_thread.run_sync(func, cancellable=True))
+ return Path(await to_thread.run_sync(func, abandon_on_cancel=True))
def rglob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.rglob(pattern)
@@ -537,23 +555,21 @@ class Path:
async def rmdir(self) -> None:
await to_thread.run_sync(self._path.rmdir)
- async def samefile(
- self, other_path: str | bytes | int | pathlib.Path | Path
- ) -> bool:
+ async def samefile(self, other_path: str | PathLike[str]) -> bool:
if isinstance(other_path, Path):
other_path = other_path._path
return await to_thread.run_sync(
- self._path.samefile, other_path, cancellable=True
+ self._path.samefile, other_path, abandon_on_cancel=True
)
async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result:
func = partial(os.stat, follow_symlinks=follow_symlinks)
- return await to_thread.run_sync(func, self._path, cancellable=True)
+ return await to_thread.run_sync(func, self._path, abandon_on_cancel=True)
async def symlink_to(
self,
- target: str | pathlib.Path | Path,
+ target: str | bytes | PathLike[str] | PathLike[bytes],
target_is_directory: bool = False,
) -> None:
if isinstance(target, Path):
@@ -571,6 +587,29 @@ class Path:
if not missing_ok:
raise
+ if sys.version_info >= (3, 12):
+
+ async def walk(
+ self,
+ top_down: bool = True,
+ on_error: Callable[[OSError], object] | None = None,
+ follow_symlinks: bool = False,
+ ) -> AsyncIterator[tuple[Path, list[str], list[str]]]:
+ def get_next_value() -> tuple[pathlib.Path, list[str], list[str]] | None:
+ try:
+ return next(gen)
+ except StopIteration:
+ return None
+
+ gen = self._path.walk(top_down, on_error, follow_symlinks)
+ while True:
+ value = await to_thread.run_sync(get_next_value)
+ if value is None:
+ return
+
+ root, dirs, paths = value
+ yield Path(root), dirs, paths
+
def with_name(self, name: str) -> Path:
return Path(self._path.with_name(name))
@@ -580,6 +619,9 @@ class Path:
def with_suffix(self, suffix: str) -> Path:
return Path(self._path.with_suffix(suffix))
+ def with_segments(self, *pathsegments: str | PathLike[str]) -> Path:
+ return Path(*pathsegments)
+
async def write_bytes(self, data: bytes) -> int:
return await to_thread.run_sync(self._path.write_bytes, data)
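Besides the rename of ``cancellable=`` to ``abandon_on_cancel=`` in every ``to_thread.run_sync()`` call, ``anyio.Path`` gains Python 3.12-era methods (``is_junction()``, ``walk()``, ``with_segments()``). A small usage sketch; the file path is arbitrary and assumes a writable POSIX temp directory:

    import anyio

    async def main() -> None:
        path = anyio.Path("/tmp/anyio-demo.txt")
        await path.write_text("hello")   # pathlib I/O runs in a worker thread
        print(await path.exists())       # exists() abandons the thread on cancel
        print(await path.read_text())
        await path.unlink()

    anyio.run(main)
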
diff --git a/contrib/python/anyio/anyio/_core/_signals.py b/contrib/python/anyio/anyio/_core/_signals.py
index 8ea54af86c..115c749bd9 100644
--- a/contrib/python/anyio/anyio/_core/_signals.py
+++ b/contrib/python/anyio/anyio/_core/_signals.py
@@ -1,26 +1,25 @@
from __future__ import annotations
-from typing import AsyncIterator
+from collections.abc import AsyncIterator
+from signal import Signals
+from typing import ContextManager
-from ._compat import DeprecatedAsyncContextManager
-from ._eventloop import get_asynclib
+from ._eventloop import get_async_backend
-def open_signal_receiver(
- *signals: int,
-) -> DeprecatedAsyncContextManager[AsyncIterator[int]]:
+def open_signal_receiver(*signals: Signals) -> ContextManager[AsyncIterator[Signals]]:
"""
Start receiving operating system signals.
:param signals: signals to receive (e.g. ``signal.SIGINT``)
- :return: an asynchronous context manager for an asynchronous iterator which yields signal
- numbers
+ :return: an asynchronous context manager for an asynchronous iterator which yields
+ signal numbers
- .. warning:: Windows does not support signals natively so it is best to avoid relying on this
- in cross-platform applications.
+ .. warning:: Windows does not support signals natively so it is best to avoid
+ relying on this in cross-platform applications.
- .. warning:: On asyncio, this permanently replaces any previous signal handler for the given
- signals, as set via :meth:`~asyncio.loop.add_signal_handler`.
+ .. warning:: On asyncio, this permanently replaces any previous signal handler for
+ the given signals, as set via :meth:`~asyncio.loop.add_signal_handler`.
"""
- return get_asynclib().open_signal_receiver(*signals)
+ return get_async_backend().open_signal_receiver(*signals)
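Usage of the reworked signal receiver follows directly from the docstring: a regular context manager wrapping an async iterator of ``Signals`` members. A sketch, assuming a POSIX platform:

    import signal

    import anyio

    async def main() -> None:
        with anyio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as signals:
            async for signum in signals:
                print("received", signum.name)
                break  # handle one signal, then exit

    anyio.run(main)
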
diff --git a/contrib/python/anyio/anyio/_core/_sockets.py b/contrib/python/anyio/anyio/_core/_sockets.py
index e6970bee27..0f0a3142fb 100644
--- a/contrib/python/anyio/anyio/_core/_sockets.py
+++ b/contrib/python/anyio/anyio/_core/_sockets.py
@@ -1,41 +1,41 @@
from __future__ import annotations
+import errno
+import os
import socket
import ssl
+import stat
import sys
+from collections.abc import Awaitable
from ipaddress import IPv6Address, ip_address
from os import PathLike, chmod
-from pathlib import Path
from socket import AddressFamily, SocketKind
-from typing import Awaitable, List, Tuple, cast, overload
+from typing import Any, Literal, cast, overload
from .. import to_thread
from ..abc import (
ConnectedUDPSocket,
+ ConnectedUNIXDatagramSocket,
IPAddressType,
IPSockAddrType,
SocketListener,
SocketStream,
UDPSocket,
+ UNIXDatagramSocket,
UNIXSocketStream,
)
from ..streams.stapled import MultiListener
from ..streams.tls import TLSStream
-from ._eventloop import get_asynclib
+from ._eventloop import get_async_backend
from ._resources import aclose_forcefully
from ._synchronization import Event
from ._tasks import create_task_group, move_on_after
-if sys.version_info >= (3, 8):
- from typing import Literal
-else:
- from typing_extensions import Literal
+if sys.version_info < (3, 11):
+ from exceptiongroup import ExceptionGroup
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41) # https://bugs.python.org/issue29515
-GetAddrInfoReturnType = List[
- Tuple[AddressFamily, SocketKind, int, str, Tuple[str, int]]
-]
AnyIPAddressFamily = Literal[
AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
]
@@ -142,18 +142,21 @@ async def connect_tcp(
:param remote_host: the IP address or host name to connect to
:param remote_port: port on the target host to connect to
- :param local_host: the interface address or name to bind the socket to before connecting
+ :param local_host: the interface address or name to bind the socket to before
+ connecting
:param tls: ``True`` to do a TLS handshake with the connected stream and return a
:class:`~anyio.streams.tls.TLSStream` instead
- :param ssl_context: the SSL context object to use (if omitted, a default context is created)
- :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake before closing
- the stream and requires that the server does this as well. Otherwise,
- :exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
+ :param ssl_context: the SSL context object to use (if omitted, a default context is
+ created)
+ :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake
+ before closing the stream and requires that the server does this as well.
+ Otherwise, :exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
Some protocols, such as HTTP, require this option to be ``False``.
See :meth:`~ssl.SSLContext.wrap_socket` for details.
- :param tls_hostname: host name to check the server certificate against (defaults to the value
- of ``remote_host``)
- :param happy_eyeballs_delay: delay (in seconds) before starting the next connection attempt
+ :param tls_hostname: host name to check the server certificate against (defaults to
+ the value of ``remote_host``)
+ :param happy_eyeballs_delay: delay (in seconds) before starting the next connection
+ attempt
:return: a socket stream object if no TLS handshake was done, otherwise a TLS stream
:raises OSError: if the connection attempt fails
@@ -177,7 +180,7 @@ async def connect_tcp(
finally:
event.set()
- asynclib = get_asynclib()
+ asynclib = get_async_backend()
local_address: IPSockAddrType | None = None
family = socket.AF_UNSPEC
if local_host:
@@ -193,8 +196,8 @@ async def connect_tcp(
target_host, remote_port, family=family, type=socket.SOCK_STREAM
)
- # Organize the list so that the first address is an IPv6 address (if available) and the
- # second one is an IPv4 addresses. The rest can be in whatever order.
+ # Organize the list so that the first address is an IPv6 address (if available)
+ # and the second one is an IPv4 address. The rest can be in whatever order.
v6_found = v4_found = False
target_addrs: list[tuple[socket.AddressFamily, str]] = []
for af, *rest, sa in gai_res:
@@ -221,7 +224,11 @@ async def connect_tcp(
await event.wait()
if connected_stream is None:
- cause = oserrors[0] if len(oserrors) == 1 else asynclib.ExceptionGroup(oserrors)
+ cause = (
+ oserrors[0]
+ if len(oserrors) == 1
+ else ExceptionGroup("multiple connection attempts failed", oserrors)
+ )
raise OSError("All connection attempts failed") from cause
if tls or tls_hostname or ssl_context:
@@ -240,7 +247,7 @@ async def connect_tcp(
return connected_stream
-async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream:
+async def connect_unix(path: str | bytes | PathLike[Any]) -> UNIXSocketStream:
"""
Connect to the given UNIX socket.
@@ -250,8 +257,8 @@ async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream:
:return: a socket stream object
"""
- path = str(Path(path))
- return await get_asynclib().connect_unix(path)
+ path = os.fspath(path)
+ return await get_async_backend().connect_unix(path)
async def create_tcp_listener(
@@ -277,11 +284,11 @@ async def create_tcp_listener(
:return: a list of listener objects
"""
- asynclib = get_asynclib()
+ asynclib = get_async_backend()
backlog = min(backlog, 65536)
local_host = str(local_host) if local_host is not None else None
gai_res = await getaddrinfo(
- local_host, # type: ignore[arg-type]
+ local_host,
local_port,
family=family,
type=socket.SocketKind.SOCK_STREAM if sys.platform == "win32" else 0,
@@ -302,7 +309,8 @@ async def create_tcp_listener(
raw_socket = socket.socket(fam)
raw_socket.setblocking(False)
- # For Windows, enable exclusive address use. For others, enable address reuse.
+ # For Windows, enable exclusive address use. For others, enable address
+ # reuse.
if sys.platform == "win32":
raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
else:
@@ -322,7 +330,7 @@ async def create_tcp_listener(
raw_socket.bind(sockaddr)
raw_socket.listen(backlog)
- listener = asynclib.TCPSocketListener(raw_socket)
+ listener = asynclib.create_tcp_listener(raw_socket)
listeners.append(listener)
except BaseException:
for listener in listeners:
@@ -334,7 +342,7 @@ async def create_tcp_listener(
async def create_unix_listener(
- path: str | PathLike[str],
+ path: str | bytes | PathLike[Any],
*,
mode: int | None = None,
backlog: int = 65536,
@@ -346,29 +354,20 @@ async def create_unix_listener(
:param path: path of the socket
:param mode: permissions to set on the socket
- :param backlog: maximum number of queued incoming connections (up to a maximum of 2**16, or
- 65536)
+ :param backlog: maximum number of queued incoming connections (up to a maximum of
+ 2**16, or 65536)
:return: a listener object
.. versionchanged:: 3.0
- If a socket already exists on the file system in the given path, it will be removed first.
+ If a socket already exists on the file system in the given path, it will be
+ removed first.
"""
- path_str = str(path)
- path = Path(path)
- if path.is_socket():
- path.unlink()
-
backlog = min(backlog, 65536)
- raw_socket = socket.socket(socket.AF_UNIX)
- raw_socket.setblocking(False)
+ raw_socket = await setup_unix_local_socket(path, mode, socket.SOCK_STREAM)
try:
- await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True)
- if mode is not None:
- await to_thread.run_sync(chmod, path_str, mode, cancellable=True)
-
raw_socket.listen(backlog)
- return get_asynclib().UNIXSocketListener(raw_socket)
+ return get_async_backend().create_unix_listener(raw_socket)
except BaseException:
raw_socket.close()
raise
@@ -384,15 +383,15 @@ async def create_udp_socket(
"""
Create a UDP socket.
- If ``local_port`` has been given, the socket will be bound to this port on the local
+ If ``local_port`` has been given, the socket will be bound to this port on the local
machine, making this socket suitable for providing UDP based services.
- :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
- ``local_host`` if omitted
+ :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically
+ determined from ``local_host`` if omitted
:param local_host: IP address or host name of the local interface to bind to
:param local_port: local port to bind to
- :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
- (not supported on Windows)
+ :param reuse_port: ``True`` to allow multiple sockets to bind to the same
+ address/port (not supported on Windows)
:return: a UDP socket
"""
@@ -414,9 +413,10 @@ async def create_udp_socket(
else:
local_address = ("0.0.0.0", 0)
- return await get_asynclib().create_udp_socket(
+ sock = await get_async_backend().create_udp_socket(
family, local_address, None, reuse_port
)
+ return cast(UDPSocket, sock)
async def create_connected_udp_socket(
@@ -431,17 +431,17 @@ async def create_connected_udp_socket(
"""
Create a connected UDP socket.
- Connected UDP sockets can only communicate with the specified remote host/port, and any packets
- sent from other sources are dropped.
+ Connected UDP sockets can only communicate with the specified remote host/port, and
+ any packets sent from other sources are dropped.
:param remote_host: remote host to set as the default target
:param remote_port: port on the remote host to set as the default target
- :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
- ``local_host`` or ``remote_host`` if omitted
+ :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically
+ determined from ``local_host`` or ``remote_host`` if omitted
:param local_host: IP address or host name of the local interface to bind to
:param local_port: local port to bind to
- :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
- (not supported on Windows)
+ :param reuse_port: ``True`` to allow multiple sockets to bind to the same
+ address/port (not supported on Windows)
:return: a connected UDP socket
"""
@@ -463,25 +463,87 @@ async def create_connected_udp_socket(
family = cast(AnyIPAddressFamily, gai_res[0][0])
remote_address = gai_res[0][-1]
- return await get_asynclib().create_udp_socket(
+ sock = await get_async_backend().create_udp_socket(
family, local_address, remote_address, reuse_port
)
+ return cast(ConnectedUDPSocket, sock)
+
+
+async def create_unix_datagram_socket(
+ *,
+ local_path: None | str | bytes | PathLike[Any] = None,
+ local_mode: int | None = None,
+) -> UNIXDatagramSocket:
+ """
+ Create a UNIX datagram socket.
+
+ Not available on Windows.
+
+ If ``local_path`` has been given, the socket will be bound to this path, making this
+ socket suitable for receiving datagrams from other processes. Other processes can
+ send datagrams to this socket only if ``local_path`` is set.
+
+ If a socket already exists on the file system at the given ``local_path``, it
+ will be removed first.
+
+ :param local_path: the path to bind to
+ :param local_mode: permissions to set on the local socket
+ :return: a UNIX datagram socket
+
+ """
+ raw_socket = await setup_unix_local_socket(
+ local_path, local_mode, socket.SOCK_DGRAM
+ )
+ return await get_async_backend().create_unix_datagram_socket(raw_socket, None)
+
+
+async def create_connected_unix_datagram_socket(
+ remote_path: str | bytes | PathLike[Any],
+ *,
+ local_path: None | str | bytes | PathLike[Any] = None,
+ local_mode: int | None = None,
+) -> ConnectedUNIXDatagramSocket:
+ """
+ Create a connected UNIX datagram socket.
+
+ Connected datagram sockets can only communicate with the specified remote path.
+
+ If ``local_path`` has been given, the socket will be bound to this path, making
+ this socket suitable for receiving datagrams from other processes. Other processes
+ can send datagrams to this socket only if ``local_path`` is set.
+
+ If a socket already exists on the file system at the given ``local_path``, it
+ will be removed first.
+
+ :param remote_path: the path to set as the default target
+ :param local_path: the path to bind to
+ :param local_mode: permissions to set on the local socket
+ :return: a connected UNIX datagram socket
+
+ """
+ remote_path = os.fspath(remote_path)
+ raw_socket = await setup_unix_local_socket(
+ local_path, local_mode, socket.SOCK_DGRAM
+ )
+ return await get_async_backend().create_unix_datagram_socket(
+ raw_socket, remote_path
+ )
async def getaddrinfo(
- host: bytearray | bytes | str,
+ host: bytes | str | None,
port: str | int | None,
*,
family: int | AddressFamily = 0,
type: int | SocketKind = 0,
proto: int = 0,
flags: int = 0,
-) -> GetAddrInfoReturnType:
+) -> list[tuple[AddressFamily, SocketKind, int, str, tuple[str, int]]]:
"""
Look up a numeric IP address given a host name.
- Internationalized domain names are translated according to the (non-transitional) IDNA 2008
- standard.
+ Internationalized domain names are translated according to the (non-transitional)
+ IDNA 2008 standard.
.. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of
(host, port), unlike what :func:`socket.getaddrinfo` does.
@@ -500,7 +562,7 @@ async def getaddrinfo(
# Handle unicode hostnames
if isinstance(host, str):
try:
- encoded_host = host.encode("ascii")
+ encoded_host: bytes | None = host.encode("ascii")
except UnicodeEncodeError:
import idna
@@ -508,7 +570,7 @@ async def getaddrinfo(
else:
encoded_host = host
- gai_res = await get_asynclib().getaddrinfo(
+ gai_res = await get_async_backend().getaddrinfo(
encoded_host, port, family=family, type=type, proto=proto, flags=flags
)
return [
@@ -528,18 +590,18 @@ def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str
.. seealso:: :func:`socket.getnameinfo`
"""
- return get_asynclib().getnameinfo(sockaddr, flags)
+ return get_async_backend().getnameinfo(sockaddr, flags)
def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
"""
Wait until the given socket has data to be read.
- This does **NOT** work on Windows when using the asyncio backend with a proactor event loop
- (default on py3.8+).
+ This does **NOT** work on Windows when using the asyncio backend with a proactor
+ event loop (default on py3.8+).
- .. warning:: Only use this on raw sockets that have not been wrapped by any higher level
- constructs like socket streams!
+ .. warning:: Only use this on raw sockets that have not been wrapped by any higher
+ level constructs like socket streams!
:param sock: a socket object
:raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
@@ -548,18 +610,18 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
to become readable
"""
- return get_asynclib().wait_socket_readable(sock)
+ return get_async_backend().wait_socket_readable(sock)
def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
"""
Wait until the given socket can be written to.
- This does **NOT** work on Windows when using the asyncio backend with a proactor event loop
- (default on py3.8+).
+ This does **NOT** work on Windows when using the asyncio backend with a proactor
+ event loop (default on py3.8+).
- .. warning:: Only use this on raw sockets that have not been wrapped by any higher level
- constructs like socket streams!
+ .. warning:: Only use this on raw sockets that have not been wrapped by any higher
+ level constructs like socket streams!
:param sock: a socket object
:raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
@@ -568,7 +630,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
to become writable
"""
- return get_asynclib().wait_socket_writable(sock)
+ return get_async_backend().wait_socket_writable(sock)
#
@@ -577,7 +639,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
def convert_ipv6_sockaddr(
- sockaddr: tuple[str, int, int, int] | tuple[str, int]
+ sockaddr: tuple[str, int, int, int] | tuple[str, int],
) -> tuple[str, int]:
"""
Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format.
@@ -592,7 +654,7 @@ def convert_ipv6_sockaddr(
"""
# This is more complicated than it should be because of MyPy
if isinstance(sockaddr, tuple) and len(sockaddr) == 4:
- host, port, flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr)
+ host, port, flowinfo, scope_id = sockaddr
if scope_id:
# PyPy (as of v7.3.11) leaves the interface name in the result, so
# we discard it and only get the scope ID from the end
@@ -604,4 +666,51 @@ def convert_ipv6_sockaddr(
else:
return host, port
else:
- return cast(Tuple[str, int], sockaddr)
+ return sockaddr
+
+
+async def setup_unix_local_socket(
+ path: None | str | bytes | PathLike[Any],
+ mode: int | None,
+ socktype: int,
+) -> socket.socket:
+ """
+ Create a UNIX local socket object, deleting the socket at the given path if it
+ exists.
+
+ Not available on Windows.
+
+ :param path: path of the socket
+ :param mode: permissions to set on the socket
+ :param socktype: socket.SOCK_STREAM or socket.SOCK_DGRAM
+
+ """
+ path_str: str | bytes | None
+ if path is not None:
+ path_str = os.fspath(path)
+
+ # Copied from pathlib...
+ try:
+ stat_result = os.stat(path)
+ except OSError as e:
+ if e.errno not in (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP):
+ raise
+ else:
+ if stat.S_ISSOCK(stat_result.st_mode):
+ os.unlink(path)
+ else:
+ path_str = None
+
+ raw_socket = socket.socket(socket.AF_UNIX, socktype)
+ raw_socket.setblocking(False)
+
+ if path_str is not None:
+ try:
+ await to_thread.run_sync(raw_socket.bind, path_str, abandon_on_cancel=True)
+ if mode is not None:
+ await to_thread.run_sync(chmod, path_str, mode, abandon_on_cancel=True)
+ except BaseException:
+ raw_socket.close()
+ raise
+
+ return raw_socket
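The new UNIX datagram helpers both funnel through ``setup_unix_local_socket()`` above, which unlinks a stale socket file before binding. A hedged sketch of a bound receiver and a connected sender; the socket path is arbitrary and assumes a POSIX host:

    import anyio

    async def main() -> None:
        async with await anyio.create_unix_datagram_socket(
            local_path="/tmp/anyio-demo.sock"
        ) as server:
            async with await anyio.create_connected_unix_datagram_socket(
                "/tmp/anyio-demo.sock"
            ) as client:
                await client.send(b"ping")
                data, sender = await server.receive()
                print(data, "from", repr(sender))

    anyio.run(main)
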
diff --git a/contrib/python/anyio/anyio/_core/_streams.py b/contrib/python/anyio/anyio/_core/_streams.py
index 54ea2b2baf..aa6b0c222a 100644
--- a/contrib/python/anyio/anyio/_core/_streams.py
+++ b/contrib/python/anyio/anyio/_core/_streams.py
@@ -1,7 +1,8 @@
from __future__ import annotations
import math
-from typing import Any, TypeVar, overload
+from typing import Tuple, TypeVar
+from warnings import warn
from ..streams.memory import (
MemoryObjectReceiveStream,
@@ -12,36 +13,40 @@ from ..streams.memory import (
T_Item = TypeVar("T_Item")
-@overload
-def create_memory_object_stream(
- max_buffer_size: float = ...,
-) -> tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]:
- ...
-
-
-@overload
-def create_memory_object_stream(
- max_buffer_size: float = ..., item_type: type[T_Item] = ...
-) -> tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]]:
- ...
-
-
-def create_memory_object_stream(
- max_buffer_size: float = 0, item_type: type[T_Item] | None = None
-) -> tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]:
+class create_memory_object_stream(
+ Tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]],
+):
"""
Create a memory object stream.
- :param max_buffer_size: number of items held in the buffer until ``send()`` starts blocking
- :param item_type: type of item, for marking the streams with the right generic type for
- static typing (not used at run time)
+ The stream's item type can be annotated like
+ :func:`create_memory_object_stream[T_Item]`.
+
+ :param max_buffer_size: number of items held in the buffer until ``send()`` starts
+ blocking
+ :param item_type: old way of marking the streams with the right generic type for
+ static typing (does nothing on AnyIO 4)
+
+ .. deprecated:: 4.0
+ Use ``create_memory_object_stream[YourItemType](...)`` instead.
:return: a tuple of (send stream, receive stream)
"""
- if max_buffer_size != math.inf and not isinstance(max_buffer_size, int):
- raise ValueError("max_buffer_size must be either an integer or math.inf")
- if max_buffer_size < 0:
- raise ValueError("max_buffer_size cannot be negative")
- state: MemoryObjectStreamState = MemoryObjectStreamState(max_buffer_size)
- return MemoryObjectSendStream(state), MemoryObjectReceiveStream(state)
+ def __new__( # type: ignore[misc]
+ cls, max_buffer_size: float = 0, item_type: object = None
+ ) -> tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]]:
+ if max_buffer_size != math.inf and not isinstance(max_buffer_size, int):
+ raise ValueError("max_buffer_size must be either an integer or math.inf")
+ if max_buffer_size < 0:
+ raise ValueError("max_buffer_size cannot be negative")
+ if item_type is not None:
+ warn(
+ "The item_type argument has been deprecated in AnyIO 4.0. "
+ "Use create_memory_object_stream[YourItemType](...) instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ state = MemoryObjectStreamState[T_Item](max_buffer_size)
+ return (MemoryObjectSendStream(state), MemoryObjectReceiveStream(state))
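``create_memory_object_stream`` is now a class subclassing ``Tuple[...]`` purely so it can be subscripted with the item type; instantiating it still returns the (send, receive) pair, and ``item_type=`` merely warns. The new spelling:

    import anyio

    async def main() -> None:
        # AnyIO 4 style; the old create_memory_object_stream(item_type=int)
        # now triggers a DeprecationWarning.
        send, receive = anyio.create_memory_object_stream[int](max_buffer_size=1)
        async with send, receive:
            await send.send(42)
            print(await receive.receive())

    anyio.run(main)
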
diff --git a/contrib/python/anyio/anyio/_core/_subprocesses.py b/contrib/python/anyio/anyio/_core/_subprocesses.py
index 1a26ac8c7f..5d5d7b768a 100644
--- a/contrib/python/anyio/anyio/_core/_subprocesses.py
+++ b/contrib/python/anyio/anyio/_core/_subprocesses.py
@@ -1,19 +1,13 @@
from __future__ import annotations
+from collections.abc import AsyncIterable, Mapping, Sequence
from io import BytesIO
from os import PathLike
from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess
-from typing import (
- IO,
- Any,
- AsyncIterable,
- Mapping,
- Sequence,
- cast,
-)
+from typing import IO, Any, cast
from ..abc import Process
-from ._eventloop import get_asynclib
+from ._eventloop import get_async_backend
from ._tasks import create_task_group
@@ -33,22 +27,24 @@ async def run_process(
.. seealso:: :func:`subprocess.run`
- :param command: either a string to pass to the shell, or an iterable of strings containing the
- executable name or path and its arguments
+ :param command: either a string to pass to the shell, or an iterable of strings
+ containing the executable name or path and its arguments
:param input: bytes passed to the standard input of the subprocess
- :param stdout: either :data:`subprocess.PIPE` or :data:`subprocess.DEVNULL`
- :param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL` or
- :data:`subprocess.STDOUT`
- :param check: if ``True``, raise :exc:`~subprocess.CalledProcessError` if the process
- terminates with a return code other than 0
- :param cwd: If not ``None``, change the working directory to this before running the command
- :param env: if not ``None``, this mapping replaces the inherited environment variables from the
- parent process
- :param start_new_session: if ``true`` the setsid() system call will be made in the child
- process prior to the execution of the subprocess. (POSIX only)
+ :param stdout: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
+ a file-like object, or ``None``
+ :param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
+ :data:`subprocess.STDOUT`, a file-like object, or ``None``
+ :param check: if ``True``, raise :exc:`~subprocess.CalledProcessError` if the
+ process terminates with a return code other than 0
+ :param cwd: If not ``None``, change the working directory to this before running the
+ command
+ :param env: if not ``None``, this mapping replaces the inherited environment
+ variables from the parent process
+ :param start_new_session: if ``True``, the setsid() system call will be made in the
+ child process prior to the execution of the subprocess. (POSIX only)
:return: an object representing the completed process
- :raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process exits with a
- nonzero return code
+ :raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process
+ exits with a nonzero return code
"""
@@ -69,20 +65,18 @@ async def run_process(
start_new_session=start_new_session,
) as process:
stream_contents: list[bytes | None] = [None, None]
- try:
- async with create_task_group() as tg:
- if process.stdout:
- tg.start_soon(drain_stream, process.stdout, 0)
- if process.stderr:
- tg.start_soon(drain_stream, process.stderr, 1)
- if process.stdin and input:
- await process.stdin.send(input)
- await process.stdin.aclose()
-
- await process.wait()
- except BaseException:
- process.kill()
- raise
+ async with create_task_group() as tg:
+ if process.stdout:
+ tg.start_soon(drain_stream, process.stdout, 0)
+
+ if process.stderr:
+ tg.start_soon(drain_stream, process.stderr, 1)
+
+ if process.stdin and input:
+ await process.stdin.send(input)
+ await process.stdin.aclose()
+
+ await process.wait()
output, errors = stream_contents
if check and process.returncode != 0:
@@ -106,8 +100,8 @@ async def open_process(
.. seealso:: :class:`subprocess.Popen`
- :param command: either a string to pass to the shell, or an iterable of strings containing the
- executable name or path and its arguments
+ :param command: either a string to pass to the shell, or an iterable of strings
+ containing the executable name or path and its arguments
:param stdin: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, a
file-like object, or ``None``
:param stdout: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
@@ -115,21 +109,32 @@ async def open_process(
:param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
:data:`subprocess.STDOUT`, a file-like object, or ``None``
:param cwd: If not ``None``, the working directory is changed before executing
- :param env: If env is not ``None``, it must be a mapping that defines the environment
- variables for the new process
- :param start_new_session: if ``true`` the setsid() system call will be made in the child
- process prior to the execution of the subprocess. (POSIX only)
+ :param env: If env is not ``None``, it must be a mapping that defines the
+ environment variables for the new process
+    :param start_new_session: if ``True``, the setsid() system call will be made in
+        the child process prior to the execution of the subprocess. (POSIX only)
:return: an asynchronous process object
"""
- shell = isinstance(command, str)
- return await get_asynclib().open_process(
- command,
- shell=shell,
- stdin=stdin,
- stdout=stdout,
- stderr=stderr,
- cwd=cwd,
- env=env,
- start_new_session=start_new_session,
- )
+ if isinstance(command, (str, bytes)):
+ return await get_async_backend().open_process(
+ command,
+ shell=True,
+ stdin=stdin,
+ stdout=stdout,
+ stderr=stderr,
+ cwd=cwd,
+ env=env,
+ start_new_session=start_new_session,
+ )
+ else:
+ return await get_async_backend().open_process(
+ command,
+ shell=False,
+ stdin=stdin,
+ stdout=stdout,
+ stderr=stderr,
+ cwd=cwd,
+ env=env,
+ start_new_session=start_new_session,
+ )
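As rewritten above, ``open_process()`` (and therefore ``run_process()``) now branches on the command type itself: a ``str``/``bytes`` command implies ``shell=True``, while a sequence of arguments is executed directly. A usage sketch (the commands are illustrative):

    import anyio
    from anyio import run_process

    async def main() -> None:
        # A sequence of arguments is executed directly (shell=False).
        result = await run_process(["echo", "hello"])
        print(result.stdout)

        # A plain string is now always passed to the shell (shell=True).
        result = await run_process("echo $HOME", check=True)
        print(result.returncode)

    anyio.run(main)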
diff --git a/contrib/python/anyio/anyio/_core/_synchronization.py b/contrib/python/anyio/anyio/_core/_synchronization.py
index 783570c7ac..b274a31ea2 100644
--- a/contrib/python/anyio/anyio/_core/_synchronization.py
+++ b/contrib/python/anyio/anyio/_core/_synchronization.py
@@ -1,13 +1,14 @@
from __future__ import annotations
+import math
from collections import deque
from dataclasses import dataclass
from types import TracebackType
-from warnings import warn
+
+from sniffio import AsyncLibraryNotFoundError
from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
-from ._compat import DeprecatedAwaitable
-from ._eventloop import get_asynclib
+from ._eventloop import get_async_backend
from ._exceptions import BusyResourceError, WouldBlock
from ._tasks import CancelScope
from ._testing import TaskInfo, get_current_task
@@ -27,9 +28,10 @@ class CapacityLimiterStatistics:
"""
:ivar int borrowed_tokens: number of tokens currently borrowed by tasks
:ivar float total_tokens: total number of available tokens
- :ivar tuple borrowers: tasks or other objects currently holding tokens borrowed from this
- limiter
- :ivar int tasks_waiting: number of tasks waiting on :meth:`~.CapacityLimiter.acquire` or
+ :ivar tuple borrowers: tasks or other objects currently holding tokens borrowed from
+ this limiter
+ :ivar int tasks_waiting: number of tasks waiting on
+ :meth:`~.CapacityLimiter.acquire` or
:meth:`~.CapacityLimiter.acquire_on_behalf_of`
"""
@@ -43,8 +45,8 @@ class CapacityLimiterStatistics:
class LockStatistics:
"""
:ivar bool locked: flag indicating if this lock is locked or not
- :ivar ~anyio.TaskInfo owner: task currently holding the lock (or ``None`` if the lock is not
- held by any task)
+ :ivar ~anyio.TaskInfo owner: task currently holding the lock (or ``None`` if the
+ lock is not held by any task)
:ivar int tasks_waiting: number of tasks waiting on :meth:`~.Lock.acquire`
"""
@@ -57,7 +59,8 @@ class LockStatistics:
class ConditionStatistics:
"""
:ivar int tasks_waiting: number of tasks blocked on :meth:`~.Condition.wait`
- :ivar ~anyio.LockStatistics lock_statistics: statistics of the underlying :class:`~.Lock`
+ :ivar ~anyio.LockStatistics lock_statistics: statistics of the underlying
+ :class:`~.Lock`
"""
tasks_waiting: int
@@ -76,9 +79,12 @@ class SemaphoreStatistics:
class Event:
def __new__(cls) -> Event:
- return get_asynclib().Event()
+ try:
+ return get_async_backend().create_event()
+ except AsyncLibraryNotFoundError:
+ return EventAdapter()
- def set(self) -> DeprecatedAwaitable:
+ def set(self) -> None:
"""Set the flag, notifying all listeners."""
raise NotImplementedError
@@ -90,7 +96,8 @@ class Event:
"""
Wait until the flag has been set.
- If the flag has already been set when this method is called, it returns immediately.
+ If the flag has already been set when this method is called, it returns
+ immediately.
"""
raise NotImplementedError
@@ -100,6 +107,35 @@ class Event:
raise NotImplementedError
+class EventAdapter(Event):
+ _internal_event: Event | None = None
+
+ def __new__(cls) -> EventAdapter:
+ return object.__new__(cls)
+
+ @property
+ def _event(self) -> Event:
+ if self._internal_event is None:
+ self._internal_event = get_async_backend().create_event()
+
+ return self._internal_event
+
+ def set(self) -> None:
+ self._event.set()
+
+ def is_set(self) -> bool:
+ return self._internal_event is not None and self._internal_event.is_set()
+
+ async def wait(self) -> None:
+ await self._event.wait()
+
+ def statistics(self) -> EventStatistics:
+ if self._internal_event is None:
+ return EventStatistics(tasks_waiting=0)
+
+ return self._internal_event.statistics()
+
+
class Lock:
_owner_task: TaskInfo | None = None
@@ -161,7 +197,7 @@ class Lock:
self._owner_task = task
- def release(self) -> DeprecatedAwaitable:
+ def release(self) -> None:
"""Release the lock."""
if self._owner_task != get_current_task():
raise RuntimeError("The current task is not holding this lock")
@@ -172,8 +208,6 @@ class Lock:
else:
del self._owner_task
- return DeprecatedAwaitable(self.release)
-
def locked(self) -> bool:
"""Return True if the lock is currently held."""
return self._owner_task is not None
@@ -224,10 +258,9 @@ class Condition:
self._lock.acquire_nowait()
self._owner_task = get_current_task()
- def release(self) -> DeprecatedAwaitable:
+ def release(self) -> None:
"""Release the underlying lock."""
self._lock.release()
- return DeprecatedAwaitable(self.release)
def locked(self) -> bool:
"""Return True if the lock is set."""
@@ -344,7 +377,7 @@ class Semaphore:
self._value -= 1
- def release(self) -> DeprecatedAwaitable:
+ def release(self) -> None:
"""Increment the semaphore value."""
if self._max_value is not None and self._value == self._max_value:
raise ValueError("semaphore released too many times")
@@ -354,8 +387,6 @@ class Semaphore:
else:
self._value += 1
- return DeprecatedAwaitable(self.release)
-
@property
def value(self) -> int:
"""The current value of the semaphore."""
@@ -377,7 +408,10 @@ class Semaphore:
class CapacityLimiter:
def __new__(cls, total_tokens: float) -> CapacityLimiter:
- return get_asynclib().CapacityLimiter(total_tokens)
+ try:
+ return get_async_backend().create_capacity_limiter(total_tokens)
+ except AsyncLibraryNotFoundError:
+ return CapacityLimiterAdapter(total_tokens)
async def __aenter__(self) -> None:
raise NotImplementedError
@@ -396,7 +430,8 @@ class CapacityLimiter:
The total number of tokens available for borrowing.
This is a read-write property. If the total number of tokens is increased, the
- proportionate number of tasks waiting on this limiter will be granted their tokens.
+ proportionate number of tasks waiting on this limiter will be granted their
+ tokens.
.. versionchanged:: 3.0
The property is now writable.
@@ -408,14 +443,6 @@ class CapacityLimiter:
def total_tokens(self, value: float) -> None:
raise NotImplementedError
- async def set_total_tokens(self, value: float) -> None:
- warn(
- "CapacityLimiter.set_total_tokens has been deprecated. Set the value of the"
- '"total_tokens" attribute directly.',
- DeprecationWarning,
- )
- self.total_tokens = value
-
@property
def borrowed_tokens(self) -> int:
"""The number of tokens that have currently been borrowed."""
@@ -426,16 +453,17 @@ class CapacityLimiter:
"""The number of tokens currently available to be borrowed"""
raise NotImplementedError
- def acquire_nowait(self) -> DeprecatedAwaitable:
+ def acquire_nowait(self) -> None:
"""
- Acquire a token for the current task without waiting for one to become available.
+ Acquire a token for the current task without waiting for one to become
+ available.
:raises ~anyio.WouldBlock: if there are no tokens available for borrowing
"""
raise NotImplementedError
- def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
"""
Acquire a token without waiting for one to become available.
@@ -447,7 +475,8 @@ class CapacityLimiter:
async def acquire(self) -> None:
"""
- Acquire a token for the current task, waiting if necessary for one to become available.
+ Acquire a token for the current task, waiting if necessary for one to become
+ available.
"""
raise NotImplementedError
@@ -464,7 +493,9 @@ class CapacityLimiter:
def release(self) -> None:
"""
Release the token held by the current task.
- :raises RuntimeError: if the current task has not borrowed a token from this limiter.
+
+ :raises RuntimeError: if the current task has not borrowed a token from this
+ limiter.
"""
raise NotImplementedError
@@ -473,7 +504,8 @@ class CapacityLimiter:
"""
Release the token held by the given borrower.
- :raises RuntimeError: if the borrower has not borrowed a token from this limiter.
+ :raises RuntimeError: if the borrower has not borrowed a token from this
+ limiter.
"""
raise NotImplementedError
@@ -488,96 +520,117 @@ class CapacityLimiter:
raise NotImplementedError
-def create_lock() -> Lock:
- """
- Create an asynchronous lock.
+class CapacityLimiterAdapter(CapacityLimiter):
+ _internal_limiter: CapacityLimiter | None = None
- :return: a lock object
+ def __new__(cls, total_tokens: float) -> CapacityLimiterAdapter:
+ return object.__new__(cls)
- .. deprecated:: 3.0
- Use :class:`~Lock` directly.
+ def __init__(self, total_tokens: float) -> None:
+ self.total_tokens = total_tokens
- """
- warn("create_lock() is deprecated -- use Lock() directly", DeprecationWarning)
- return Lock()
+ @property
+ def _limiter(self) -> CapacityLimiter:
+ if self._internal_limiter is None:
+ self._internal_limiter = get_async_backend().create_capacity_limiter(
+ self._total_tokens
+ )
+ return self._internal_limiter
-def create_condition(lock: Lock | None = None) -> Condition:
- """
- Create an asynchronous condition.
+ async def __aenter__(self) -> None:
+ await self._limiter.__aenter__()
- :param lock: the lock to base the condition object on
- :return: a condition object
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> bool | None:
+ return await self._limiter.__aexit__(exc_type, exc_val, exc_tb)
- .. deprecated:: 3.0
- Use :class:`~Condition` directly.
+ @property
+ def total_tokens(self) -> float:
+ if self._internal_limiter is None:
+ return self._total_tokens
- """
- warn(
- "create_condition() is deprecated -- use Condition() directly",
- DeprecationWarning,
- )
- return Condition(lock=lock)
+ return self._internal_limiter.total_tokens
+ @total_tokens.setter
+ def total_tokens(self, value: float) -> None:
+ if not isinstance(value, int) and value is not math.inf:
+ raise TypeError("total_tokens must be an int or math.inf")
+ elif value < 1:
+ raise ValueError("total_tokens must be >= 1")
-def create_event() -> Event:
- """
- Create an asynchronous event object.
+ if self._internal_limiter is None:
+ self._total_tokens = value
+ return
- :return: an event object
+ self._limiter.total_tokens = value
- .. deprecated:: 3.0
- Use :class:`~Event` directly.
+ @property
+ def borrowed_tokens(self) -> int:
+ if self._internal_limiter is None:
+ return 0
- """
- warn("create_event() is deprecated -- use Event() directly", DeprecationWarning)
- return get_asynclib().Event()
+ return self._internal_limiter.borrowed_tokens
+ @property
+ def available_tokens(self) -> float:
+ if self._internal_limiter is None:
+ return self._total_tokens
-def create_semaphore(value: int, *, max_value: int | None = None) -> Semaphore:
- """
- Create an asynchronous semaphore.
+ return self._internal_limiter.available_tokens
- :param value: the semaphore's initial value
- :param max_value: if set, makes this a "bounded" semaphore that raises :exc:`ValueError` if the
- semaphore's value would exceed this number
- :return: a semaphore object
+ def acquire_nowait(self) -> None:
+ self._limiter.acquire_nowait()
- .. deprecated:: 3.0
- Use :class:`~Semaphore` directly.
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
+ self._limiter.acquire_on_behalf_of_nowait(borrower)
- """
- warn(
- "create_semaphore() is deprecated -- use Semaphore() directly",
- DeprecationWarning,
- )
- return Semaphore(value, max_value=max_value)
+ async def acquire(self) -> None:
+ await self._limiter.acquire()
+ async def acquire_on_behalf_of(self, borrower: object) -> None:
+ await self._limiter.acquire_on_behalf_of(borrower)
-def create_capacity_limiter(total_tokens: float) -> CapacityLimiter:
- """
- Create a capacity limiter.
+ def release(self) -> None:
+ self._limiter.release()
- :param total_tokens: the total number of tokens available for borrowing (can be an integer or
- :data:`math.inf`)
- :return: a capacity limiter object
+ def release_on_behalf_of(self, borrower: object) -> None:
+ self._limiter.release_on_behalf_of(borrower)
- .. deprecated:: 3.0
- Use :class:`~CapacityLimiter` directly.
+ def statistics(self) -> CapacityLimiterStatistics:
+ if self._internal_limiter is None:
+ return CapacityLimiterStatistics(
+ borrowed_tokens=0,
+ total_tokens=self.total_tokens,
+ borrowers=(),
+ tasks_waiting=0,
+ )
- """
- warn(
- "create_capacity_limiter() is deprecated -- use CapacityLimiter() directly",
- DeprecationWarning,
- )
- return get_asynclib().CapacityLimiter(total_tokens)
+ return self._internal_limiter.statistics()
class ResourceGuard:
+ """
+ A context manager for ensuring that a resource is only used by a single task at a
+ time.
+
+ Entering this context manager while the previous has not exited it yet will trigger
+ :exc:`BusyResourceError`.
+
+ :param action: the action to guard against (visible in the :exc:`BusyResourceError`
+ when triggered, e.g. "Another task is already {action} this resource")
+
+ .. versionadded:: 4.1
+ """
+
__slots__ = "action", "_guarded"
- def __init__(self, action: str):
- self.action = action
+ def __init__(self, action: str = "using"):
+ self.action: str = action
self._guarded = False
def __enter__(self) -> None:
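The adapter classes introduced in this file exist so that ``Event()`` and ``CapacityLimiter()`` can be constructed while no async library is running (when sniffio raises ``AsyncLibraryNotFoundError``); the backend-specific object is then created lazily on first use. A sketch of the pattern this enables:

    import anyio

    # Constructed at import time, outside any event loop; the adapter defers
    # creating the backend-specific event until it is first used.
    shutdown_event = anyio.Event()

    async def waiter() -> None:
        await shutdown_event.wait()
        print("shutdown requested")

    async def main() -> None:
        async with anyio.create_task_group() as tg:
            tg.start_soon(waiter)
            shutdown_event.set()

    anyio.run(main)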
diff --git a/contrib/python/anyio/anyio/_core/_tasks.py b/contrib/python/anyio/anyio/_core/_tasks.py
index e9d9c2bd67..2f21ea20b1 100644
--- a/contrib/python/anyio/anyio/_core/_tasks.py
+++ b/contrib/python/anyio/anyio/_core/_tasks.py
@@ -1,16 +1,12 @@
from __future__ import annotations
import math
+from collections.abc import Generator
+from contextlib import contextmanager
from types import TracebackType
-from warnings import warn
from ..abc._tasks import TaskGroup, TaskStatus
-from ._compat import (
- DeprecatedAsyncContextManager,
- DeprecatedAwaitable,
- DeprecatedAwaitableFloat,
-)
-from ._eventloop import get_asynclib
+from ._eventloop import get_async_backend
class _IgnoredTaskStatus(TaskStatus[object]):
@@ -21,7 +17,7 @@ class _IgnoredTaskStatus(TaskStatus[object]):
TASK_STATUS_IGNORED = _IgnoredTaskStatus()
-class CancelScope(DeprecatedAsyncContextManager["CancelScope"]):
+class CancelScope:
"""
Wraps a unit of work that can be made separately cancellable.
@@ -32,9 +28,9 @@ class CancelScope(DeprecatedAsyncContextManager["CancelScope"]):
def __new__(
cls, *, deadline: float = math.inf, shield: bool = False
) -> CancelScope:
- return get_asynclib().CancelScope(shield=shield, deadline=deadline)
+ return get_async_backend().create_cancel_scope(shield=shield, deadline=deadline)
- def cancel(self) -> DeprecatedAwaitable:
+ def cancel(self) -> None:
"""Cancel this scope immediately."""
raise NotImplementedError
@@ -58,6 +54,19 @@ class CancelScope(DeprecatedAsyncContextManager["CancelScope"]):
raise NotImplementedError
@property
+ def cancelled_caught(self) -> bool:
+ """
+        ``True`` if this scope suppressed a cancellation exception that it itself raised.
+
+ This is typically used to check if any work was interrupted, or to see if the
+ scope was cancelled due to its deadline being reached. The value will, however,
+ only be ``True`` if the cancellation was triggered by the scope itself (and not
+ an outer scope).
+
+ """
+ raise NotImplementedError
+
+ @property
def shield(self) -> bool:
"""
``True`` if this scope is shielded from external cancellation.
@@ -83,81 +92,52 @@ class CancelScope(DeprecatedAsyncContextManager["CancelScope"]):
raise NotImplementedError
-def open_cancel_scope(*, shield: bool = False) -> CancelScope:
+@contextmanager
+def fail_after(
+ delay: float | None, shield: bool = False
+) -> Generator[CancelScope, None, None]:
"""
- Open a cancel scope.
+    Create a context manager which raises a :class:`TimeoutError` if it does not
+    finish in time.
- :param shield: ``True`` to shield the cancel scope from external cancellation
- :return: a cancel scope
-
- .. deprecated:: 3.0
- Use :class:`~CancelScope` directly.
-
- """
- warn(
- "open_cancel_scope() is deprecated -- use CancelScope() directly",
- DeprecationWarning,
- )
- return get_asynclib().CancelScope(shield=shield)
-
-
-class FailAfterContextManager(DeprecatedAsyncContextManager[CancelScope]):
- def __init__(self, cancel_scope: CancelScope):
- self._cancel_scope = cancel_scope
-
- def __enter__(self) -> CancelScope:
- return self._cancel_scope.__enter__()
-
- def __exit__(
- self,
- exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> bool | None:
- retval = self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
- if self._cancel_scope.cancel_called:
- raise TimeoutError
-
- return retval
-
-
-def fail_after(delay: float | None, shield: bool = False) -> FailAfterContextManager:
- """
- Create a context manager which raises a :class:`TimeoutError` if does not finish in time.
-
- :param delay: maximum allowed time (in seconds) before raising the exception, or ``None`` to
- disable the timeout
+ :param delay: maximum allowed time (in seconds) before raising the exception, or
+ ``None`` to disable the timeout
:param shield: ``True`` to shield the cancel scope from external cancellation
:return: a context manager that yields a cancel scope
:rtype: :class:`~typing.ContextManager`\\[:class:`~anyio.CancelScope`\\]
"""
- deadline = (
- (get_asynclib().current_time() + delay) if delay is not None else math.inf
- )
- cancel_scope = get_asynclib().CancelScope(deadline=deadline, shield=shield)
- return FailAfterContextManager(cancel_scope)
+ current_time = get_async_backend().current_time
+ deadline = (current_time() + delay) if delay is not None else math.inf
+ with get_async_backend().create_cancel_scope(
+ deadline=deadline, shield=shield
+ ) as cancel_scope:
+ yield cancel_scope
+
+ if cancel_scope.cancelled_caught and current_time() >= cancel_scope.deadline:
+ raise TimeoutError
def move_on_after(delay: float | None, shield: bool = False) -> CancelScope:
"""
Create a cancel scope with a deadline that expires after the given delay.
- :param delay: maximum allowed time (in seconds) before exiting the context block, or ``None``
- to disable the timeout
+ :param delay: maximum allowed time (in seconds) before exiting the context block, or
+ ``None`` to disable the timeout
:param shield: ``True`` to shield the cancel scope from external cancellation
:return: a cancel scope
"""
deadline = (
- (get_asynclib().current_time() + delay) if delay is not None else math.inf
+ (get_async_backend().current_time() + delay) if delay is not None else math.inf
)
- return get_asynclib().CancelScope(deadline=deadline, shield=shield)
+ return get_async_backend().create_cancel_scope(deadline=deadline, shield=shield)
-def current_effective_deadline() -> DeprecatedAwaitableFloat:
+def current_effective_deadline() -> float:
"""
- Return the nearest deadline among all the cancel scopes effective for the current task.
+ Return the nearest deadline among all the cancel scopes effective for the current
+ task.
:return: a clock value from the event loop's internal clock (or ``float('inf')`` if
there is no deadline in effect, or ``float('-inf')`` if the current scope has
@@ -165,9 +145,7 @@ def current_effective_deadline() -> DeprecatedAwaitableFloat:
:rtype: float
"""
- return DeprecatedAwaitableFloat(
- get_asynclib().current_effective_deadline(), current_effective_deadline
- )
+ return get_async_backend().current_effective_deadline()
def create_task_group() -> TaskGroup:
@@ -177,4 +155,4 @@ def create_task_group() -> TaskGroup:
:return: a task group
"""
- return get_asynclib().TaskGroup()
+ return get_async_backend().create_task_group()
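With ``fail_after()`` now a generator-based context manager, the :exc:`TimeoutError` is raised only when the scope's own deadline cancellation was caught, which is what the new ``cancelled_caught`` property above reports. Behaviour sketch:

    import anyio

    async def main() -> None:
        try:
            with anyio.fail_after(0.1):
                await anyio.sleep(1)
        except TimeoutError:
            print("timed out")

        # move_on_after() still yields a plain cancel scope to inspect.
        with anyio.move_on_after(0.1) as scope:
            await anyio.sleep(1)
        print("cancelled_caught:", scope.cancelled_caught)

    anyio.run(main)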
diff --git a/contrib/python/anyio/anyio/_core/_testing.py b/contrib/python/anyio/anyio/_core/_testing.py
index c8191b3866..1dae3b193a 100644
--- a/contrib/python/anyio/anyio/_core/_testing.py
+++ b/contrib/python/anyio/anyio/_core/_testing.py
@@ -1,9 +1,9 @@
from __future__ import annotations
-from typing import Any, Awaitable, Generator
+from collections.abc import Awaitable, Generator
+from typing import Any
-from ._compat import DeprecatedAwaitableList, _warn_deprecation
-from ._eventloop import get_asynclib
+from ._eventloop import get_async_backend
class TaskInfo:
@@ -45,13 +45,6 @@ class TaskInfo:
def __repr__(self) -> str:
return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})"
- def __await__(self) -> Generator[None, None, TaskInfo]:
- _warn_deprecation(self)
- if False:
- yield
-
- return self
-
def _unwrap(self) -> TaskInfo:
return self
@@ -63,20 +56,19 @@ def get_current_task() -> TaskInfo:
:return: a representation of the current task
"""
- return get_asynclib().get_current_task()
+ return get_async_backend().get_current_task()
-def get_running_tasks() -> DeprecatedAwaitableList[TaskInfo]:
+def get_running_tasks() -> list[TaskInfo]:
"""
Return a list of running tasks in the current event loop.
:return: a list of task info objects
"""
- tasks = get_asynclib().get_running_tasks()
- return DeprecatedAwaitableList(tasks, func=get_running_tasks)
+ return get_async_backend().get_running_tasks()
async def wait_all_tasks_blocked() -> None:
"""Wait until all other tasks are waiting for something."""
- await get_asynclib().wait_all_tasks_blocked()
+ await get_async_backend().wait_all_tasks_blocked()
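After the removals above, ``get_current_task()`` and ``get_running_tasks()`` return plain, non-awaitable values; code from the AnyIO 3 transition period that ``await``-ed them must drop the await. Sketch:

    import anyio
    from anyio import get_current_task, get_running_tasks

    async def main() -> None:
        # anyio 4.x: plain return values, no longer awaitable
        print(get_current_task().name)
        print("running tasks:", len(get_running_tasks()))

    anyio.run(main)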
diff --git a/contrib/python/anyio/anyio/_core/_typedattr.py b/contrib/python/anyio/anyio/_core/_typedattr.py
index bf9202eeab..74c6b8fdcb 100644
--- a/contrib/python/anyio/anyio/_core/_typedattr.py
+++ b/contrib/python/anyio/anyio/_core/_typedattr.py
@@ -1,15 +1,10 @@
from __future__ import annotations
-import sys
-from typing import Any, Callable, Mapping, TypeVar, overload
+from collections.abc import Callable, Mapping
+from typing import Any, TypeVar, final, overload
from ._exceptions import TypedAttributeLookupError
-if sys.version_info >= (3, 8):
- from typing import final
-else:
- from typing_extensions import final
-
T_Attr = TypeVar("T_Attr")
T_Default = TypeVar("T_Default")
undefined = object()
@@ -44,11 +39,12 @@ class TypedAttributeProvider:
@property
def extra_attributes(self) -> Mapping[T_Attr, Callable[[], T_Attr]]:
"""
- A mapping of the extra attributes to callables that return the corresponding values.
+ A mapping of the extra attributes to callables that return the corresponding
+ values.
- If the provider wraps another provider, the attributes from that wrapper should also be
- included in the returned mapping (but the wrapper may override the callables from the
- wrapped instance).
+ If the provider wraps another provider, the attributes from that wrapper should
+ also be included in the returned mapping (but the wrapper may override the
+ callables from the wrapped instance).
"""
return {}
@@ -68,10 +64,12 @@ class TypedAttributeProvider:
Return the value of the given typed extra attribute.
- :param attribute: the attribute (member of a :class:`~TypedAttributeSet`) to look for
- :param default: the value that should be returned if no value is found for the attribute
- :raises ~anyio.TypedAttributeLookupError: if the search failed and no default value was
- given
+ :param attribute: the attribute (member of a :class:`~TypedAttributeSet`) to
+ look for
+ :param default: the value that should be returned if no value is found for the
+ attribute
+ :raises ~anyio.TypedAttributeLookupError: if the search failed and no default
+ value was given
"""
try:
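For reference, the reflowed ``extra()`` contract above is typically exercised like this on a stream object (host and attribute choice are illustrative):

    import anyio
    from anyio.abc import SocketAttribute

    async def main() -> None:
        stream = await anyio.connect_tcp("example.org", 80)
        async with stream:
            # With a default given, extra() returns it instead of raising
            # TypedAttributeLookupError.
            print("remote port:", stream.extra(SocketAttribute.remote_port, None))

    anyio.run(main)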
diff --git a/contrib/python/anyio/anyio/abc/__init__.py b/contrib/python/anyio/anyio/abc/__init__.py
index 72c34e544e..1ca0fcf746 100644
--- a/contrib/python/anyio/anyio/abc/__init__.py
+++ b/contrib/python/anyio/anyio/abc/__init__.py
@@ -1,86 +1,53 @@
from __future__ import annotations
-__all__ = (
- "AsyncResource",
- "IPAddressType",
- "IPSockAddrType",
- "SocketAttribute",
- "SocketStream",
- "SocketListener",
- "UDPSocket",
- "UNIXSocketStream",
- "UDPPacketType",
- "ConnectedUDPSocket",
- "UnreliableObjectReceiveStream",
- "UnreliableObjectSendStream",
- "UnreliableObjectStream",
- "ObjectReceiveStream",
- "ObjectSendStream",
- "ObjectStream",
- "ByteReceiveStream",
- "ByteSendStream",
- "ByteStream",
- "AnyUnreliableByteReceiveStream",
- "AnyUnreliableByteSendStream",
- "AnyUnreliableByteStream",
- "AnyByteReceiveStream",
- "AnyByteSendStream",
- "AnyByteStream",
- "Listener",
- "Process",
- "Event",
- "Condition",
- "Lock",
- "Semaphore",
- "CapacityLimiter",
- "CancelScope",
- "TaskGroup",
- "TaskStatus",
- "TestRunner",
- "BlockingPortal",
-)
-
from typing import Any
-from ._resources import AsyncResource
-from ._sockets import (
- ConnectedUDPSocket,
- IPAddressType,
- IPSockAddrType,
- SocketAttribute,
- SocketListener,
- SocketStream,
- UDPPacketType,
- UDPSocket,
- UNIXSocketStream,
-)
-from ._streams import (
- AnyByteReceiveStream,
- AnyByteSendStream,
- AnyByteStream,
- AnyUnreliableByteReceiveStream,
- AnyUnreliableByteSendStream,
- AnyUnreliableByteStream,
- ByteReceiveStream,
- ByteSendStream,
- ByteStream,
- Listener,
- ObjectReceiveStream,
- ObjectSendStream,
- ObjectStream,
- UnreliableObjectReceiveStream,
- UnreliableObjectSendStream,
- UnreliableObjectStream,
-)
-from ._subprocesses import Process
-from ._tasks import TaskGroup, TaskStatus
-from ._testing import TestRunner
+from ._eventloop import AsyncBackend as AsyncBackend
+from ._resources import AsyncResource as AsyncResource
+from ._sockets import ConnectedUDPSocket as ConnectedUDPSocket
+from ._sockets import ConnectedUNIXDatagramSocket as ConnectedUNIXDatagramSocket
+from ._sockets import IPAddressType as IPAddressType
+from ._sockets import IPSockAddrType as IPSockAddrType
+from ._sockets import SocketAttribute as SocketAttribute
+from ._sockets import SocketListener as SocketListener
+from ._sockets import SocketStream as SocketStream
+from ._sockets import UDPPacketType as UDPPacketType
+from ._sockets import UDPSocket as UDPSocket
+from ._sockets import UNIXDatagramPacketType as UNIXDatagramPacketType
+from ._sockets import UNIXDatagramSocket as UNIXDatagramSocket
+from ._sockets import UNIXSocketStream as UNIXSocketStream
+from ._streams import AnyByteReceiveStream as AnyByteReceiveStream
+from ._streams import AnyByteSendStream as AnyByteSendStream
+from ._streams import AnyByteStream as AnyByteStream
+from ._streams import AnyUnreliableByteReceiveStream as AnyUnreliableByteReceiveStream
+from ._streams import AnyUnreliableByteSendStream as AnyUnreliableByteSendStream
+from ._streams import AnyUnreliableByteStream as AnyUnreliableByteStream
+from ._streams import ByteReceiveStream as ByteReceiveStream
+from ._streams import ByteSendStream as ByteSendStream
+from ._streams import ByteStream as ByteStream
+from ._streams import Listener as Listener
+from ._streams import ObjectReceiveStream as ObjectReceiveStream
+from ._streams import ObjectSendStream as ObjectSendStream
+from ._streams import ObjectStream as ObjectStream
+from ._streams import UnreliableObjectReceiveStream as UnreliableObjectReceiveStream
+from ._streams import UnreliableObjectSendStream as UnreliableObjectSendStream
+from ._streams import UnreliableObjectStream as UnreliableObjectStream
+from ._subprocesses import Process as Process
+from ._tasks import TaskGroup as TaskGroup
+from ._tasks import TaskStatus as TaskStatus
+from ._testing import TestRunner as TestRunner
# Re-exported here, for backwards compatibility
# isort: off
-from .._core._synchronization import CapacityLimiter, Condition, Event, Lock, Semaphore
-from .._core._tasks import CancelScope
-from ..from_thread import BlockingPortal
+from .._core._synchronization import (
+ CapacityLimiter as CapacityLimiter,
+ Condition as Condition,
+ Event as Event,
+ Lock as Lock,
+ Semaphore as Semaphore,
+)
+from .._core._tasks import CancelScope as CancelScope
+from ..from_thread import BlockingPortal as BlockingPortal
# Re-export imports so they look like they live directly in this package
key: str
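Replacing the ``__all__`` tuple, the redundant-looking ``from ._streams import ByteStream as ByteStream`` spelling is the explicit re-export form recognized by type checkers (for example mypy with implicit re-exports disabled). The idiom, shown in a hypothetical package:

    # mypkg/__init__.py -- names aliased to themselves are treated as
    # intentional public re-exports by type checkers such as mypy when
    # implicit_reexport is turned off.
    from ._impl import Widget as Widget
    from ._impl import make_widget as make_widget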
diff --git a/contrib/python/anyio/anyio/abc/_eventloop.py b/contrib/python/anyio/anyio/abc/_eventloop.py
new file mode 100644
index 0000000000..4470d83d24
--- /dev/null
+++ b/contrib/python/anyio/anyio/abc/_eventloop.py
@@ -0,0 +1,392 @@
+from __future__ import annotations
+
+import math
+import sys
+from abc import ABCMeta, abstractmethod
+from collections.abc import AsyncIterator, Awaitable, Mapping
+from os import PathLike
+from signal import Signals
+from socket import AddressFamily, SocketKind, socket
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ ContextManager,
+ Sequence,
+ TypeVar,
+ overload,
+)
+
+if sys.version_info >= (3, 11):
+ from typing import TypeVarTuple, Unpack
+else:
+ from typing_extensions import TypeVarTuple, Unpack
+
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from .._core._synchronization import CapacityLimiter, Event
+ from .._core._tasks import CancelScope
+ from .._core._testing import TaskInfo
+ from ..from_thread import BlockingPortal
+ from ._sockets import (
+ ConnectedUDPSocket,
+ ConnectedUNIXDatagramSocket,
+ IPSockAddrType,
+ SocketListener,
+ SocketStream,
+ UDPSocket,
+ UNIXDatagramSocket,
+ UNIXSocketStream,
+ )
+ from ._subprocesses import Process
+ from ._tasks import TaskGroup
+ from ._testing import TestRunner
+
+T_Retval = TypeVar("T_Retval")
+PosArgsT = TypeVarTuple("PosArgsT")
+
+
+class AsyncBackend(metaclass=ABCMeta):
+ @classmethod
+ @abstractmethod
+ def run(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+ args: tuple[Unpack[PosArgsT]],
+ kwargs: dict[str, Any],
+ options: dict[str, Any],
+ ) -> T_Retval:
+ """
+ Run the given coroutine function in an asynchronous event loop.
+
+ The current thread must not be already running an event loop.
+
+ :param func: a coroutine function
+ :param args: positional arguments to ``func``
+        :param kwargs: keyword arguments to ``func``
+ :param options: keyword arguments to call the backend ``run()`` implementation
+ with
+ :return: the return value of the coroutine function
+ """
+
+ @classmethod
+ @abstractmethod
+ def current_token(cls) -> object:
+ """
+
+ :return:
+ """
+
+ @classmethod
+ @abstractmethod
+ def current_time(cls) -> float:
+ """
+ Return the current value of the event loop's internal clock.
+
+ :return: the clock value (seconds)
+ """
+
+ @classmethod
+ @abstractmethod
+ def cancelled_exception_class(cls) -> type[BaseException]:
+ """Return the exception class that is raised in a task if it's cancelled."""
+
+ @classmethod
+ @abstractmethod
+ async def checkpoint(cls) -> None:
+ """
+ Check if the task has been cancelled, and allow rescheduling of other tasks.
+
+ This is effectively the same as running :meth:`checkpoint_if_cancelled` and then
+ :meth:`cancel_shielded_checkpoint`.
+ """
+
+ @classmethod
+ async def checkpoint_if_cancelled(cls) -> None:
+ """
+ Check if the current task group has been cancelled.
+
+ This will check if the task has been cancelled, but will not allow other tasks
+ to be scheduled if not.
+
+ """
+ if cls.current_effective_deadline() == -math.inf:
+ await cls.checkpoint()
+
+ @classmethod
+ async def cancel_shielded_checkpoint(cls) -> None:
+ """
+ Allow the rescheduling of other tasks.
+
+ This will give other tasks the opportunity to run, but without checking if the
+ current task group has been cancelled, unlike with :meth:`checkpoint`.
+
+ """
+ with cls.create_cancel_scope(shield=True):
+ await cls.sleep(0)
+
+ @classmethod
+ @abstractmethod
+ async def sleep(cls, delay: float) -> None:
+ """
+ Pause the current task for the specified duration.
+
+ :param delay: the duration, in seconds
+ """
+
+ @classmethod
+ @abstractmethod
+ def create_cancel_scope(
+ cls, *, deadline: float = math.inf, shield: bool = False
+ ) -> CancelScope:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def current_effective_deadline(cls) -> float:
+ """
+ Return the nearest deadline among all the cancel scopes effective for the
+ current task.
+
+ :return:
+ - a clock value from the event loop's internal clock
+ - ``inf`` if there is no deadline in effect
+ - ``-inf`` if the current scope has been cancelled
+ :rtype: float
+ """
+
+ @classmethod
+ @abstractmethod
+ def create_task_group(cls) -> TaskGroup:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def create_event(cls) -> Event:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def run_sync_in_worker_thread(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
+ args: tuple[Unpack[PosArgsT]],
+ abandon_on_cancel: bool = False,
+ limiter: CapacityLimiter | None = None,
+ ) -> T_Retval:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def check_cancelled(cls) -> None:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def run_async_from_thread(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+ args: tuple[Unpack[PosArgsT]],
+ token: object,
+ ) -> T_Retval:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def run_sync_from_thread(
+ cls,
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
+ args: tuple[Unpack[PosArgsT]],
+ token: object,
+ ) -> T_Retval:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def create_blocking_portal(cls) -> BlockingPortal:
+ pass
+
+ @classmethod
+ @overload
+ async def open_process(
+ cls,
+ command: str | bytes,
+ *,
+ shell: Literal[True],
+ stdin: int | IO[Any] | None,
+ stdout: int | IO[Any] | None,
+ stderr: int | IO[Any] | None,
+ cwd: str | bytes | PathLike[str] | None = None,
+ env: Mapping[str, str] | None = None,
+ start_new_session: bool = False,
+ ) -> Process:
+ pass
+
+ @classmethod
+ @overload
+ async def open_process(
+ cls,
+ command: Sequence[str | bytes],
+ *,
+ shell: Literal[False],
+ stdin: int | IO[Any] | None,
+ stdout: int | IO[Any] | None,
+ stderr: int | IO[Any] | None,
+ cwd: str | bytes | PathLike[str] | None = None,
+ env: Mapping[str, str] | None = None,
+ start_new_session: bool = False,
+ ) -> Process:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def open_process(
+ cls,
+ command: str | bytes | Sequence[str | bytes],
+ *,
+ shell: bool,
+ stdin: int | IO[Any] | None,
+ stdout: int | IO[Any] | None,
+ stderr: int | IO[Any] | None,
+ cwd: str | bytes | PathLike[str] | None = None,
+ env: Mapping[str, str] | None = None,
+ start_new_session: bool = False,
+ ) -> Process:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def setup_process_pool_exit_at_shutdown(cls, workers: set[Process]) -> None:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def connect_tcp(
+ cls, host: str, port: int, local_address: IPSockAddrType | None = None
+ ) -> SocketStream:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def connect_unix(cls, path: str | bytes) -> UNIXSocketStream:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def create_tcp_listener(cls, sock: socket) -> SocketListener:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def create_unix_listener(cls, sock: socket) -> SocketListener:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def create_udp_socket(
+ cls,
+ family: AddressFamily,
+ local_address: IPSockAddrType | None,
+ remote_address: IPSockAddrType | None,
+ reuse_port: bool,
+ ) -> UDPSocket | ConnectedUDPSocket:
+ pass
+
+ @classmethod
+ @overload
+ async def create_unix_datagram_socket(
+ cls, raw_socket: socket, remote_path: None
+ ) -> UNIXDatagramSocket:
+ ...
+
+ @classmethod
+ @overload
+ async def create_unix_datagram_socket(
+ cls, raw_socket: socket, remote_path: str | bytes
+ ) -> ConnectedUNIXDatagramSocket:
+ ...
+
+ @classmethod
+ @abstractmethod
+ async def create_unix_datagram_socket(
+ cls, raw_socket: socket, remote_path: str | bytes | None
+ ) -> UNIXDatagramSocket | ConnectedUNIXDatagramSocket:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def getaddrinfo(
+ cls,
+ host: bytes | str | None,
+ port: str | int | None,
+ *,
+ family: int | AddressFamily = 0,
+ type: int | SocketKind = 0,
+ proto: int = 0,
+ flags: int = 0,
+ ) -> list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ tuple[str, int] | tuple[str, int, int, int],
+ ]
+ ]:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def getnameinfo(
+ cls, sockaddr: IPSockAddrType, flags: int = 0
+ ) -> tuple[str, str]:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def wait_socket_readable(cls, sock: socket) -> None:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def wait_socket_writable(cls, sock: socket) -> None:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def current_default_thread_limiter(cls) -> CapacityLimiter:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def open_signal_receiver(
+ cls, *signals: Signals
+ ) -> ContextManager[AsyncIterator[Signals]]:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def get_current_task(cls) -> TaskInfo:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def get_running_tasks(cls) -> list[TaskInfo]:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def wait_all_tasks_blocked(cls) -> None:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
+ pass
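Everything on ``AsyncBackend`` is a classmethod, so the core modules call through the class object returned by ``anyio._core._eventloop.get_async_backend()`` without ever instantiating a backend. Calling-convention sketch:

    from anyio._core._eventloop import get_async_backend

    async def pause(delay: float) -> None:
        backend = get_async_backend()  # returns the AsyncBackend subclass itself
        await backend.sleep(delay)     # classmethod call, no instance involved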
diff --git a/contrib/python/anyio/anyio/abc/_resources.py b/contrib/python/anyio/anyio/abc/_resources.py
index e0a283fc98..9693835bae 100644
--- a/contrib/python/anyio/anyio/abc/_resources.py
+++ b/contrib/python/anyio/anyio/abc/_resources.py
@@ -11,8 +11,8 @@ class AsyncResource(metaclass=ABCMeta):
"""
Abstract base class for all closeable asynchronous resources.
- Works as an asynchronous context manager which returns the instance itself on enter, and calls
- :meth:`aclose` on exit.
+ Works as an asynchronous context manager which returns the instance itself on enter,
+ and calls :meth:`aclose` on exit.
"""
async def __aenter__(self: T) -> T:
diff --git a/contrib/python/anyio/anyio/abc/_sockets.py b/contrib/python/anyio/anyio/abc/_sockets.py
index 6aac5f7c22..b321225a7b 100644
--- a/contrib/python/anyio/anyio/abc/_sockets.py
+++ b/contrib/python/anyio/anyio/abc/_sockets.py
@@ -2,21 +2,14 @@ from __future__ import annotations
import socket
from abc import abstractmethod
+from collections.abc import Callable, Collection, Mapping
from contextlib import AsyncExitStack
from io import IOBase
from ipaddress import IPv4Address, IPv6Address
from socket import AddressFamily
-from typing import (
- Any,
- Callable,
- Collection,
- Mapping,
- Tuple,
- TypeVar,
- Union,
-)
+from types import TracebackType
+from typing import Any, Tuple, TypeVar, Union
-from .._core._tasks import create_task_group
from .._core._typedattr import (
TypedAttributeProvider,
TypedAttributeSet,
@@ -29,9 +22,23 @@ IPAddressType = Union[str, IPv4Address, IPv6Address]
IPSockAddrType = Tuple[str, int]
SockAddrType = Union[IPSockAddrType, str]
UDPPacketType = Tuple[bytes, IPSockAddrType]
+UNIXDatagramPacketType = Tuple[bytes, str]
T_Retval = TypeVar("T_Retval")
+class _NullAsyncContextManager:
+ async def __aenter__(self) -> None:
+ pass
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> bool | None:
+ return None
+
+
class SocketAttribute(TypedAttributeSet):
#: the address family of the underlying socket
family: AddressFamily = typed_attribute()
@@ -70,9 +77,9 @@ class _SocketProvider(TypedAttributeProvider):
# Provide local and remote ports for IP based sockets
if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
- attributes[
- SocketAttribute.local_port
- ] = lambda: self._raw_socket.getsockname()[1]
+ attributes[SocketAttribute.local_port] = (
+ lambda: self._raw_socket.getsockname()[1]
+ )
if peername is not None:
remote_port = peername[1]
attributes[SocketAttribute.remote_port] = lambda: remote_port
@@ -100,8 +107,8 @@ class UNIXSocketStream(SocketStream):
Send file descriptors along with a message to the peer.
:param message: a non-empty bytestring
- :param fds: a collection of files (either numeric file descriptors or open file or socket
- objects)
+ :param fds: a collection of files (either numeric file descriptors or open file
+ or socket objects)
"""
@abstractmethod
@@ -131,9 +138,11 @@ class SocketListener(Listener[SocketStream], _SocketProvider):
handler: Callable[[SocketStream], Any],
task_group: TaskGroup | None = None,
) -> None:
- async with AsyncExitStack() as exit_stack:
+ from .. import create_task_group
+
+ async with AsyncExitStack() as stack:
if task_group is None:
- task_group = await exit_stack.enter_async_context(create_task_group())
+ task_group = await stack.enter_async_context(create_task_group())
while True:
stream = await self.accept()
@@ -148,7 +157,10 @@ class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
"""
async def sendto(self, data: bytes, host: str, port: int) -> None:
- """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port)))."""
+ """
+ Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).
+
+ """
return await self.send((data, (host, port)))
@@ -158,3 +170,25 @@ class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
Supports all relevant extra attributes from :class:`~SocketAttribute`.
"""
+
+
+class UNIXDatagramSocket(
+ UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider
+):
+ """
+ Represents an unconnected Unix datagram socket.
+
+ Supports all relevant extra attributes from :class:`~SocketAttribute`.
+ """
+
+ async def sendto(self, data: bytes, path: str) -> None:
+ """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, path))."""
+ return await self.send((data, path))
+
+
+class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider):
+ """
+ Represents a connected Unix datagram socket.
+
+ Supports all relevant extra attributes from :class:`~SocketAttribute`.
+ """
diff --git a/contrib/python/anyio/anyio/abc/_streams.py b/contrib/python/anyio/anyio/abc/_streams.py
index 4fa7ccc9ff..8c638683a4 100644
--- a/contrib/python/anyio/anyio/abc/_streams.py
+++ b/contrib/python/anyio/anyio/abc/_streams.py
@@ -1,7 +1,8 @@
from __future__ import annotations
from abc import abstractmethod
-from typing import Any, Callable, Generic, TypeVar, Union
+from collections.abc import Callable
+from typing import Any, Generic, TypeVar, Union
from .._core._exceptions import EndOfStream
from .._core._typedattr import TypedAttributeProvider
@@ -19,11 +20,11 @@ class UnreliableObjectReceiveStream(
"""
An interface for receiving objects.
- This interface makes no guarantees that the received messages arrive in the order in which they
- were sent, or that no messages are missed.
+ This interface makes no guarantees that the received messages arrive in the order in
+ which they were sent, or that no messages are missed.
- Asynchronously iterating over objects of this type will yield objects matching the given type
- parameter.
+ Asynchronously iterating over objects of this type will yield objects matching the
+ given type parameter.
"""
def __aiter__(self) -> UnreliableObjectReceiveStream[T_co]:
@@ -54,8 +55,8 @@ class UnreliableObjectSendStream(
"""
An interface for sending objects.
- This interface makes no guarantees that the messages sent will reach the recipient(s) in the
- same order in which they were sent, or at all.
+ This interface makes no guarantees that the messages sent will reach the
+ recipient(s) in the same order in which they were sent, or at all.
"""
@abstractmethod
@@ -75,22 +76,22 @@ class UnreliableObjectStream(
UnreliableObjectReceiveStream[T_Item], UnreliableObjectSendStream[T_Item]
):
"""
- A bidirectional message stream which does not guarantee the order or reliability of message
- delivery.
+ A bidirectional message stream which does not guarantee the order or reliability of
+ message delivery.
"""
class ObjectReceiveStream(UnreliableObjectReceiveStream[T_co]):
"""
- A receive message stream which guarantees that messages are received in the same order in
- which they were sent, and that no messages are missed.
+ A receive message stream which guarantees that messages are received in the same
+ order in which they were sent, and that no messages are missed.
"""
class ObjectSendStream(UnreliableObjectSendStream[T_contra]):
"""
- A send message stream which guarantees that messages are delivered in the same order in which
- they were sent, without missing any messages in the middle.
+ A send message stream which guarantees that messages are delivered in the same order
+ in which they were sent, without missing any messages in the middle.
"""
@@ -100,7 +101,8 @@ class ObjectStream(
UnreliableObjectStream[T_Item],
):
"""
- A bidirectional message stream which guarantees the order and reliability of message delivery.
+ A bidirectional message stream which guarantees the order and reliability of message
+ delivery.
"""
@abstractmethod
@@ -108,8 +110,8 @@ class ObjectStream(
"""
Send an end-of-file indication to the peer.
- You should not try to send any further data to this stream after calling this method.
- This method is idempotent (does nothing on successive calls).
+ You should not try to send any further data to this stream after calling this
+ method. This method is idempotent (does nothing on successive calls).
"""
@@ -117,8 +119,8 @@ class ByteReceiveStream(AsyncResource, TypedAttributeProvider):
"""
An interface for receiving bytes from a single peer.
- Iterating this byte stream will yield a byte string of arbitrary length, but no more than
- 65536 bytes.
+ Iterating this byte stream will yield a byte string of arbitrary length, but no more
+ than 65536 bytes.
"""
def __aiter__(self) -> ByteReceiveStream:
@@ -135,8 +137,8 @@ class ByteReceiveStream(AsyncResource, TypedAttributeProvider):
"""
Receive at most ``max_bytes`` bytes from the peer.
- .. note:: Implementors of this interface should not return an empty :class:`bytes` object,
- and users should ignore them.
+ .. note:: Implementors of this interface should not return an empty
+ :class:`bytes` object, and users should ignore them.
:param max_bytes: maximum number of bytes to receive
:return: the received bytes
@@ -164,8 +166,8 @@ class ByteStream(ByteReceiveStream, ByteSendStream):
"""
Send an end-of-file indication to the peer.
- You should not try to send any further data to this stream after calling this method.
- This method is idempotent (does nothing on successive calls).
+ You should not try to send any further data to this stream after calling this
+ method. This method is idempotent (does nothing on successive calls).
"""
@@ -190,14 +192,12 @@ class Listener(Generic[T_co], AsyncResource, TypedAttributeProvider):
@abstractmethod
async def serve(
- self,
- handler: Callable[[T_co], Any],
- task_group: TaskGroup | None = None,
+ self, handler: Callable[[T_co], Any], task_group: TaskGroup | None = None
) -> None:
"""
Accept incoming connections as they come in and start tasks to handle them.
:param handler: a callable that will be used to handle each accepted connection
- :param task_group: the task group that will be used to start tasks for handling each
- accepted connection (if omitted, an ad-hoc task group will be created)
+ :param task_group: the task group that will be used to start tasks for handling
+ each accepted connection (if omitted, an ad-hoc task group will be created)
"""
diff --git a/contrib/python/anyio/anyio/abc/_subprocesses.py b/contrib/python/anyio/anyio/abc/_subprocesses.py
index 704b44a2dd..ce0564ceac 100644
--- a/contrib/python/anyio/anyio/abc/_subprocesses.py
+++ b/contrib/python/anyio/anyio/abc/_subprocesses.py
@@ -59,8 +59,8 @@ class Process(AsyncResource):
@abstractmethod
def returncode(self) -> int | None:
"""
- The return code of the process. If the process has not yet terminated, this will be
- ``None``.
+ The return code of the process. If the process has not yet terminated, this will
+ be ``None``.
"""
@property
diff --git a/contrib/python/anyio/anyio/abc/_tasks.py b/contrib/python/anyio/anyio/abc/_tasks.py
index e48d3c1e97..7ad4938cb4 100644
--- a/contrib/python/anyio/anyio/abc/_tasks.py
+++ b/contrib/python/anyio/anyio/abc/_tasks.py
@@ -2,20 +2,21 @@ from __future__ import annotations
import sys
from abc import ABCMeta, abstractmethod
+from collections.abc import Awaitable, Callable
from types import TracebackType
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, overload
-from warnings import warn
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload
-if sys.version_info >= (3, 8):
- from typing import Protocol
+if sys.version_info >= (3, 11):
+ from typing import TypeVarTuple, Unpack
else:
- from typing_extensions import Protocol
+ from typing_extensions import TypeVarTuple, Unpack
if TYPE_CHECKING:
- from anyio._core._tasks import CancelScope
+ from .._core._tasks import CancelScope
T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True)
+PosArgsT = TypeVarTuple("PosArgsT")
class TaskStatus(Protocol[T_contra]):
@@ -45,35 +46,11 @@ class TaskGroup(metaclass=ABCMeta):
cancel_scope: CancelScope
- async def spawn(
- self,
- func: Callable[..., Awaitable[Any]],
- *args: object,
- name: object = None,
- ) -> None:
- """
- Start a new task in this task group.
-
- :param func: a coroutine function
- :param args: positional arguments to call the function with
- :param name: name of the task, for the purposes of introspection and debugging
-
- .. deprecated:: 3.0
- Use :meth:`start_soon` instead. If your code needs AnyIO 2 compatibility, you
- can keep using this until AnyIO 4.
-
- """
- warn(
- 'spawn() is deprecated -- use start_soon() (without the "await") instead',
- DeprecationWarning,
- )
- self.start_soon(func, *args, name=name)
-
@abstractmethod
def start_soon(
self,
- func: Callable[..., Awaitable[Any]],
- *args: object,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
+ *args: Unpack[PosArgsT],
name: object = None,
) -> None:
"""
@@ -100,7 +77,8 @@ class TaskGroup(metaclass=ABCMeta):
:param args: positional arguments to call the function with
:param name: name of the task, for the purposes of introspection and debugging
:return: the value passed to ``task_status.started()``
- :raises RuntimeError: if the task finishes without calling ``task_status.started()``
+ :raises RuntimeError: if the task finishes without calling
+ ``task_status.started()``
.. versionadded:: 3.0
"""
diff --git a/contrib/python/anyio/anyio/abc/_testing.py b/contrib/python/anyio/anyio/abc/_testing.py
index ee2cff5cc3..4d70b9ec6b 100644
--- a/contrib/python/anyio/anyio/abc/_testing.py
+++ b/contrib/python/anyio/anyio/abc/_testing.py
@@ -2,33 +2,29 @@ from __future__ import annotations
import types
from abc import ABCMeta, abstractmethod
-from collections.abc import AsyncGenerator, Iterable
-from typing import Any, Callable, Coroutine, TypeVar
+from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable
+from typing import Any, TypeVar
_T = TypeVar("_T")
class TestRunner(metaclass=ABCMeta):
"""
- Encapsulates a running event loop. Every call made through this object will use the same event
- loop.
+ Encapsulates a running event loop. Every call made through this object will use the
+ same event loop.
"""
def __enter__(self) -> TestRunner:
return self
+ @abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> bool | None:
- self.close()
- return None
-
- @abstractmethod
- def close(self) -> None:
- """Close the event loop."""
+ ...
@abstractmethod
def run_asyncgen_fixture(
diff --git a/contrib/python/anyio/anyio/from_thread.py b/contrib/python/anyio/anyio/from_thread.py
index 6b76861c70..4a987031fe 100644
--- a/contrib/python/anyio/anyio/from_thread.py
+++ b/contrib/python/anyio/anyio/from_thread.py
@@ -1,36 +1,43 @@
from __future__ import annotations
+import sys
import threading
-from asyncio import iscoroutine
+from collections.abc import Awaitable, Callable, Generator
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from contextlib import AbstractContextManager, contextmanager
+from inspect import isawaitable
from types import TracebackType
from typing import (
Any,
AsyncContextManager,
- Awaitable,
- Callable,
ContextManager,
- Generator,
Generic,
Iterable,
TypeVar,
cast,
overload,
)
-from warnings import warn
from ._core import _eventloop
-from ._core._eventloop import get_asynclib, get_cancelled_exc_class, threadlocals
+from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals
from ._core._synchronization import Event
from ._core._tasks import CancelScope, create_task_group
+from .abc import AsyncBackend
from .abc._tasks import TaskStatus
+if sys.version_info >= (3, 11):
+ from typing import TypeVarTuple, Unpack
+else:
+ from typing_extensions import TypeVarTuple, Unpack
+
T_Retval = TypeVar("T_Retval")
-T_co = TypeVar("T_co")
+T_co = TypeVar("T_co", covariant=True)
+PosArgsT = TypeVarTuple("PosArgsT")
-def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval:
+def run(
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT]
+) -> T_Retval:
"""
Call a coroutine function from a worker thread.
@@ -40,24 +47,19 @@ def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval:
"""
try:
- asynclib = threadlocals.current_async_module
+ async_backend = threadlocals.current_async_backend
+ token = threadlocals.current_token
except AttributeError:
- raise RuntimeError("This function can only be run from an AnyIO worker thread")
+ raise RuntimeError(
+ "This function can only be run from an AnyIO worker thread"
+ ) from None
- return asynclib.run_async_from_thread(func, *args)
+ return async_backend.run_async_from_thread(func, args, token=token)
-def run_async_from_thread(
- func: Callable[..., Awaitable[T_Retval]], *args: object
+def run_sync(
+ func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
) -> T_Retval:
- warn(
- "run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead",
- DeprecationWarning,
- )
- return run(func, *args)
-
-
-def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval:
"""
Call a function in the event loop thread from a worker thread.
@@ -67,24 +69,19 @@ def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval:
"""
try:
- asynclib = threadlocals.current_async_module
+ async_backend = threadlocals.current_async_backend
+ token = threadlocals.current_token
except AttributeError:
- raise RuntimeError("This function can only be run from an AnyIO worker thread")
-
- return asynclib.run_sync_from_thread(func, *args)
+ raise RuntimeError(
+ "This function can only be run from an AnyIO worker thread"
+ ) from None
-
-def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval:
- warn(
- "run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead",
- DeprecationWarning,
- )
- return run_sync(func, *args)
+ return async_backend.run_sync_from_thread(func, args, token=token)
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
- _enter_future: Future
- _exit_future: Future
+ _enter_future: Future[T_co]
+ _exit_future: Future[bool | None]
_exit_event: Event
_exit_exc_info: tuple[
type[BaseException] | None, BaseException | None, TracebackType | None
@@ -120,8 +117,7 @@ class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
def __enter__(self) -> T_co:
self._enter_future = Future()
self._exit_future = self._portal.start_task_soon(self.run_async_cm)
- cm = self._enter_future.result()
- return cast(T_co, cm)
+ return self._enter_future.result()
def __exit__(
self,
@@ -146,7 +142,7 @@ class BlockingPortal:
"""An object that lets external threads run code in an asynchronous event loop."""
def __new__(cls) -> BlockingPortal:
- return get_asynclib().BlockingPortal()
+ return get_async_backend().create_blocking_portal()
def __init__(self) -> None:
self._event_loop_thread_id: int | None = threading.get_ident()
@@ -186,8 +182,8 @@ class BlockingPortal:
This marks the portal as no longer accepting new calls and exits from
:meth:`sleep_until_stopped`.
- :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` to let them
- finish before returning
+ :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
+ to let them finish before returning
"""
self._event_loop_thread_id = None
@@ -196,9 +192,13 @@ class BlockingPortal:
self._task_group.cancel_scope.cancel()
async def _call_func(
- self, func: Callable, args: tuple, kwargs: dict[str, Any], future: Future
+ self,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+ args: tuple[Unpack[PosArgsT]],
+ kwargs: dict[str, Any],
+ future: Future[T_Retval],
) -> None:
- def callback(f: Future) -> None:
+ def callback(f: Future[T_Retval]) -> None:
if f.cancelled() and self._event_loop_thread_id not in (
None,
threading.get_ident(),
@@ -206,17 +206,20 @@ class BlockingPortal:
self.call(scope.cancel)
try:
- retval = func(*args, **kwargs)
- if iscoroutine(retval):
+ retval_or_awaitable = func(*args, **kwargs)
+ if isawaitable(retval_or_awaitable):
with CancelScope() as scope:
if future.cancelled():
scope.cancel()
else:
future.add_done_callback(callback)
- retval = await retval
+ retval = await retval_or_awaitable
+ else:
+ retval = retval_or_awaitable
except self._cancelled_exc_class:
future.cancel()
+ future.set_running_or_notify_cancel()
except BaseException as exc:
if not future.cancelled():
future.set_exception(exc)
@@ -232,11 +235,11 @@ class BlockingPortal:
def _spawn_task_from_thread(
self,
- func: Callable,
- args: tuple,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+ args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
name: object,
- future: Future,
+ future: Future[T_Retval],
) -> None:
"""
Spawn a new task using the given callable.
@@ -247,22 +250,30 @@ class BlockingPortal:
:param args: positional arguments to be passed to the callable
:param kwargs: keyword arguments to be passed to the callable
:param name: name of the task (will be coerced to a string if not ``None``)
- :param future: a future that will resolve to the return value of the callable, or the
- exception raised during its execution
+ :param future: a future that will resolve to the return value of the callable,
+ or the exception raised during its execution
"""
raise NotImplementedError
@overload
- def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval:
+ def call(
+ self,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+ *args: Unpack[PosArgsT],
+ ) -> T_Retval:
...
@overload
- def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval:
+ def call(
+ self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
+ ) -> T_Retval:
...
def call(
- self, func: Callable[..., Awaitable[T_Retval] | T_Retval], *args: object
+ self,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+ *args: Unpack[PosArgsT],
) -> T_Retval:
"""
Call the given function in the event loop thread.
@@ -270,82 +281,41 @@ class BlockingPortal:
If the callable returns a coroutine object, it is awaited on.
:param func: any callable
- :raises RuntimeError: if the portal is not running or if this method is called from within
- the event loop thread
+ :raises RuntimeError: if the portal is not running or if this method is called
+ from within the event loop thread
"""
return cast(T_Retval, self.start_task_soon(func, *args).result())
@overload
- def spawn_task(
+ def start_task_soon(
self,
- func: Callable[..., Awaitable[T_Retval]],
- *args: object,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+ *args: Unpack[PosArgsT],
name: object = None,
) -> Future[T_Retval]:
...
@overload
- def spawn_task(
- self, func: Callable[..., T_Retval], *args: object, name: object = None
- ) -> Future[T_Retval]:
- ...
-
- def spawn_task(
- self,
- func: Callable[..., Awaitable[T_Retval] | T_Retval],
- *args: object,
- name: object = None,
- ) -> Future[T_Retval]:
- """
- Start a task in the portal's task group.
-
- :param func: the target coroutine function
- :param args: positional arguments passed to ``func``
- :param name: name of the task (will be coerced to a string if not ``None``)
- :return: a future that resolves with the return value of the callable if the task completes
- successfully, or with the exception raised in the task
- :raises RuntimeError: if the portal is not running or if this method is called from within
- the event loop thread
-
- .. versionadded:: 2.1
- .. deprecated:: 3.0
- Use :meth:`start_task_soon` instead. If your code needs AnyIO 2 compatibility, you
- can keep using this until AnyIO 4.
-
- """
- warn(
- "spawn_task() is deprecated -- use start_task_soon() instead",
- DeprecationWarning,
- )
- return self.start_task_soon(func, *args, name=name) # type: ignore[arg-type]
-
- @overload
def start_task_soon(
self,
- func: Callable[..., Awaitable[T_Retval]],
- *args: object,
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
+ *args: Unpack[PosArgsT],
name: object = None,
) -> Future[T_Retval]:
...
- @overload
- def start_task_soon(
- self, func: Callable[..., T_Retval], *args: object, name: object = None
- ) -> Future[T_Retval]:
- ...
-
def start_task_soon(
self,
- func: Callable[..., Awaitable[T_Retval] | T_Retval],
- *args: object,
+ func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+ *args: Unpack[PosArgsT],
name: object = None,
) -> Future[T_Retval]:
"""
Start a task in the portal's task group.
- The task will be run inside a cancel scope which can be cancelled by cancelling the
- returned future.
+ The task will be run inside a cancel scope which can be cancelled by cancelling
+ the returned future.
:param func: the target function
:param args: positional arguments passed to ``func``
@@ -360,13 +330,16 @@ class BlockingPortal:
"""
self._check_running()
- f: Future = Future()
+ f: Future[T_Retval] = Future()
self._spawn_task_from_thread(func, args, {}, name, f)
return f
def start_task(
- self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
- ) -> tuple[Future[Any], Any]:
+ self,
+ func: Callable[..., Awaitable[T_Retval]],
+ *args: object,
+ name: object = None,
+ ) -> tuple[Future[T_Retval], Any]:
"""
Start a task in the portal's task group and wait until it signals for readiness.
@@ -378,13 +351,13 @@ class BlockingPortal:
:return: a tuple of (future, task_status_value) where the ``task_status_value``
is the value passed to ``task_status.started()`` from within the target
function
- :rtype: tuple[concurrent.futures.Future[Any], Any]
+ :rtype: tuple[concurrent.futures.Future[T_Retval], Any]
.. versionadded:: 3.0
"""
- def task_done(future: Future) -> None:
+ def task_done(future: Future[T_Retval]) -> None:
if not task_status_future.done():
if future.cancelled():
task_status_future.cancel()
@@ -410,8 +383,8 @@ class BlockingPortal:
"""
Wrap an async context manager as a synchronous context manager via this portal.
- Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping in the
- middle until the synchronous context manager exits.
+ Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
+ in the middle until the synchronous context manager exits.
:param cm: an asynchronous context manager
:return: a synchronous context manager
@@ -422,25 +395,6 @@ class BlockingPortal:
return _BlockingAsyncContextManager(cm, self)
-def create_blocking_portal() -> BlockingPortal:
- """
- Create a portal for running functions in the event loop thread from external threads.
-
- Use this function in asynchronous code when you need to allow external threads access to the
- event loop where your asynchronous code is currently running.
-
- .. deprecated:: 3.0
- Use :class:`.BlockingPortal` directly.
-
- """
- warn(
- "create_blocking_portal() has been deprecated -- use anyio.from_thread.BlockingPortal() "
- "directly",
- DeprecationWarning,
- )
- return BlockingPortal()
-
-
@contextmanager
def start_blocking_portal(
backend: str = "asyncio", backend_options: dict[str, Any] | None = None
@@ -468,8 +422,8 @@ def start_blocking_portal(
future: Future[BlockingPortal] = Future()
with ThreadPoolExecutor(1) as executor:
run_future = executor.submit(
- _eventloop.run,
- run_portal, # type: ignore[arg-type]
+ _eventloop.run, # type: ignore[arg-type]
+ run_portal,
backend=backend,
backend_options=backend_options,
)
@@ -498,3 +452,25 @@ def start_blocking_portal(
pass
run_future.result()
+
+
+def check_cancelled() -> None:
+ """
+    Check if the cancel scope of the host task running the current worker thread has
+ been cancelled.
+
+ If the host task's current cancel scope has indeed been cancelled, the
+ backend-specific cancellation exception will be raised.
+
+ :raises RuntimeError: if the current thread was not spawned by
+ :func:`.to_thread.run_sync`
+
+ """
+ try:
+ async_backend: AsyncBackend = threadlocals.current_async_backend
+ except AttributeError:
+ raise RuntimeError(
+ "This function can only be run from an AnyIO worker thread"
+ ) from None
+
+ async_backend.check_cancelled()
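
Illustrative sketch, not part of the diff: the new ``Unpack[PosArgsT]`` signatures give ``from_thread.run()`` and ``run_sync()`` end-to-end typing, and the deprecated ``run_async_from_thread()``/``run_sync_from_thread()`` aliases are removed outright. A minimal round trip between a worker thread and the event loop (``fetch`` and ``blocking_work`` are hypothetical names):

    import anyio
    import anyio.from_thread
    import anyio.to_thread

    async def fetch(url: str) -> str:
        # hypothetical coroutine that runs in the event loop thread
        await anyio.sleep(0.1)
        return url.upper()

    def blocking_work() -> str:
        # runs in a worker thread; from_thread.run() hops back into the loop
        return anyio.from_thread.run(fetch, "https://example.org")

    async def main() -> None:
        print(await anyio.to_thread.run_sync(blocking_work))

    anyio.run(main)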
diff --git a/contrib/python/anyio/anyio/lowlevel.py b/contrib/python/anyio/anyio/lowlevel.py
index 0e908c6547..a9e10f430a 100644
--- a/contrib/python/anyio/anyio/lowlevel.py
+++ b/contrib/python/anyio/anyio/lowlevel.py
@@ -1,17 +1,11 @@
from __future__ import annotations
import enum
-import sys
from dataclasses import dataclass
-from typing import Any, Generic, TypeVar, overload
+from typing import Any, Generic, Literal, TypeVar, overload
from weakref import WeakKeyDictionary
-from ._core._eventloop import get_asynclib
-
-if sys.version_info >= (3, 8):
- from typing import Literal
-else:
- from typing_extensions import Literal
+from ._core._eventloop import get_async_backend
T = TypeVar("T")
D = TypeVar("D")
@@ -30,7 +24,7 @@ async def checkpoint() -> None:
.. versionadded:: 3.0
"""
- await get_asynclib().checkpoint()
+ await get_async_backend().checkpoint()
async def checkpoint_if_cancelled() -> None:
@@ -42,7 +36,7 @@ async def checkpoint_if_cancelled() -> None:
.. versionadded:: 3.0
"""
- await get_asynclib().checkpoint_if_cancelled()
+ await get_async_backend().checkpoint_if_cancelled()
async def cancel_shielded_checkpoint() -> None:
@@ -58,12 +52,16 @@ async def cancel_shielded_checkpoint() -> None:
.. versionadded:: 3.0
"""
- await get_asynclib().cancel_shielded_checkpoint()
+ await get_async_backend().cancel_shielded_checkpoint()
def current_token() -> object:
- """Return a backend specific token object that can be used to get back to the event loop."""
- return get_asynclib().current_token()
+ """
+    Return a backend-specific token object that can be used to get back to the event
+ loop.
+
+ """
+ return get_async_backend().current_token()
_run_vars: WeakKeyDictionary[Any, dict[str, Any]] = WeakKeyDictionary()
@@ -101,9 +99,7 @@ class RunVar(Generic[T]):
_token_wrappers: set[_TokenWrapper] = set()
def __init__(
- self,
- name: str,
- default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET,
+ self, name: str, default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
):
self._name = name
self._default = default
@@ -111,18 +107,11 @@ class RunVar(Generic[T]):
@property
def _current_vars(self) -> dict[str, T]:
token = current_token()
- while True:
- try:
- return _run_vars[token]
- except TypeError:
- # Happens when token isn't weak referable (TrioToken).
- # This workaround does mean that some memory will leak on Trio until the problem
- # is fixed on their end.
- token = _TokenWrapper(token)
- self._token_wrappers.add(token)
- except KeyError:
- run_vars = _run_vars[token] = {}
- return run_vars
+ try:
+ return _run_vars[token]
+ except KeyError:
+ run_vars = _run_vars[token] = {}
+ return run_vars
@overload
def get(self, default: D) -> T | D:
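
Illustrative sketch, not part of the diff: the ``_TokenWrapper`` fallback is dropped above, presumably because ``TrioToken`` is now weak-referenceable upstream, so each event loop token keys its own ``_run_vars`` dict directly. Typical ``RunVar`` use looks like this (variable names are assumptions):

    import anyio
    from anyio.lowlevel import RunVar

    _counter: RunVar[int] = RunVar("_counter", default=0)

    async def bump() -> int:
        # values are scoped to the current event loop run,
        # never shared across separate anyio.run() calls
        _counter.set(_counter.get() + 1)
        return _counter.get()

    print(anyio.run(bump))  # 1 on every fresh run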
diff --git a/contrib/python/anyio/anyio/pytest_plugin.py b/contrib/python/anyio/anyio/pytest_plugin.py
index 044ce6914d..a8dd6f3e3f 100644
--- a/contrib/python/anyio/anyio/pytest_plugin.py
+++ b/contrib/python/anyio/anyio/pytest_plugin.py
@@ -1,16 +1,19 @@
from __future__ import annotations
-from contextlib import contextmanager
+from collections.abc import Iterator
+from contextlib import ExitStack, contextmanager
from inspect import isasyncgenfunction, iscoroutinefunction
-from typing import Any, Dict, Generator, Tuple, cast
+from typing import Any, Dict, Tuple, cast
import pytest
import sniffio
-from ._core._eventloop import get_all_backends, get_asynclib
+from ._core._eventloop import get_all_backends, get_async_backend
from .abc import TestRunner
_current_runner: TestRunner | None = None
+_runner_stack: ExitStack | None = None
+_runner_leases = 0
def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]:
@@ -26,27 +29,31 @@ def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]:
@contextmanager
def get_runner(
backend_name: str, backend_options: dict[str, Any]
-) -> Generator[TestRunner, object, None]:
- global _current_runner
- if _current_runner:
- yield _current_runner
- return
+) -> Iterator[TestRunner]:
+ global _current_runner, _runner_leases, _runner_stack
+ if _current_runner is None:
+ asynclib = get_async_backend(backend_name)
+ _runner_stack = ExitStack()
+ if sniffio.current_async_library_cvar.get(None) is None:
+ # Since we're in control of the event loop, we can cache the name of the
+ # async library
+ token = sniffio.current_async_library_cvar.set(backend_name)
+ _runner_stack.callback(sniffio.current_async_library_cvar.reset, token)
- asynclib = get_asynclib(backend_name)
- token = None
- if sniffio.current_async_library_cvar.get(None) is None:
- # Since we're in control of the event loop, we can cache the name of the async library
- token = sniffio.current_async_library_cvar.set(backend_name)
+ backend_options = backend_options or {}
+ _current_runner = _runner_stack.enter_context(
+ asynclib.create_test_runner(backend_options)
+ )
+ _runner_leases += 1
try:
- backend_options = backend_options or {}
- with asynclib.TestRunner(**backend_options) as runner:
- _current_runner = runner
- yield runner
+ yield _current_runner
finally:
- _current_runner = None
- if token:
- sniffio.current_async_library_cvar.reset(token)
+ _runner_leases -= 1
+ if not _runner_leases:
+ assert _runner_stack is not None
+ _runner_stack.close()
+ _runner_stack = _current_runner = None
def pytest_configure(config: Any) -> None:
@@ -69,8 +76,8 @@ def pytest_fixture_setup(fixturedef: Any, request: Any) -> None:
else:
yield runner.run_fixture(func, kwargs)
- # Only apply this to coroutine functions and async generator functions in requests that involve
- # the anyio_backend fixture
+ # Only apply this to coroutine functions and async generator functions in requests
+ # that involve the anyio_backend fixture
func = fixturedef.func
if isasyncgenfunction(func) or iscoroutinefunction(func):
if "anyio_backend" in request.fixturenames:
@@ -121,7 +128,7 @@ def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
return None
-@pytest.fixture(params=get_all_backends())
+@pytest.fixture(scope="module", params=get_all_backends())
def anyio_backend(request: Any) -> Any:
return request.param
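
Illustrative sketch, not part of the diff: with the ``anyio_backend`` fixture now module-scoped and the test runner reference-counted via ``_runner_leases``, tests and async fixtures in one module share a single event loop. A test module might interact with this as follows (assumes the anyio pytest plugin is active):

    import anyio
    import pytest

    @pytest.fixture
    def anyio_backend():
        # override the module-scoped parametrized fixture
        # to run these tests on asyncio only
        return "asyncio"

    @pytest.mark.anyio
    async def test_checkpoint() -> None:
        await anyio.sleep(0)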
diff --git a/contrib/python/anyio/anyio/streams/buffered.py b/contrib/python/anyio/anyio/streams/buffered.py
index 11474c16a9..f5d5e836dd 100644
--- a/contrib/python/anyio/anyio/streams/buffered.py
+++ b/contrib/python/anyio/anyio/streams/buffered.py
@@ -1,7 +1,8 @@
from __future__ import annotations
+from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
-from typing import Any, Callable, Mapping
+from typing import Any
from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead
from ..abc import AnyByteReceiveStream, ByteReceiveStream
@@ -10,8 +11,8 @@ from ..abc import AnyByteReceiveStream, ByteReceiveStream
@dataclass(eq=False)
class BufferedByteReceiveStream(ByteReceiveStream):
"""
- Wraps any bytes-based receive stream and uses a buffer to provide sophisticated receiving
- capabilities in the form of a byte stream.
+ Wraps any bytes-based receive stream and uses a buffer to provide sophisticated
+ receiving capabilities in the form of a byte stream.
"""
receive_stream: AnyByteReceiveStream
@@ -42,8 +43,8 @@ class BufferedByteReceiveStream(ByteReceiveStream):
elif isinstance(self.receive_stream, ByteReceiveStream):
return await self.receive_stream.receive(max_bytes)
else:
- # With a bytes-oriented object stream, we need to handle any surplus bytes we get from
- # the receive() call
+ # With a bytes-oriented object stream, we need to handle any surplus bytes
+ # we get from the receive() call
chunk = await self.receive_stream.receive()
if len(chunk) > max_bytes:
# Save the surplus bytes in the buffer
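
Illustrative sketch, not part of the diff: the surplus-bytes branch rewrapped above is what lets ``BufferedByteReceiveStream`` wrap a bytes-oriented object stream. A loopback example, assuming the ``receive_until()`` helper keeps its current signature:

    import anyio
    from anyio.streams.buffered import BufferedByteReceiveStream

    async def main() -> None:
        send, receive = anyio.create_memory_object_stream(max_buffer_size=10)
        buffered = BufferedByteReceiveStream(receive)
        await send.send(b"hello\nwor")
        await send.send(b"ld\n")
        # bytes past each delimiter stay in the buffer for the next call
        assert await buffered.receive_until(b"\n", max_bytes=100) == b"hello"
        assert await buffered.receive_until(b"\n", max_bytes=100) == b"world"

    anyio.run(main)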
diff --git a/contrib/python/anyio/anyio/streams/file.py b/contrib/python/anyio/anyio/streams/file.py
index 2840d40ab6..f492464267 100644
--- a/contrib/python/anyio/anyio/streams/file.py
+++ b/contrib/python/anyio/anyio/streams/file.py
@@ -1,9 +1,10 @@
from __future__ import annotations
+from collections.abc import Callable, Mapping
from io import SEEK_SET, UnsupportedOperation
from os import PathLike
from pathlib import Path
-from typing import Any, BinaryIO, Callable, Mapping, cast
+from typing import Any, BinaryIO, cast
from .. import (
BrokenResourceError,
@@ -130,8 +131,8 @@ class FileWriteStream(_BaseFileStream, ByteSendStream):
Create a file write stream by opening the given file for writing.
:param path: path of the file to write to
- :param append: if ``True``, open the file for appending; if ``False``, any existing file
- at the given path will be truncated
+ :param append: if ``True``, open the file for appending; if ``False``, any
+ existing file at the given path will be truncated
"""
mode = "ab" if append else "wb"
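
Illustrative sketch, not part of the diff: the reworded ``append`` parameter above behaves as documented; the path used here is hypothetical:

    import anyio
    from anyio.streams.file import FileWriteStream

    async def main() -> None:
        # append=False (the default) truncates any existing file at the path
        async with await FileWriteStream.from_path("/tmp/demo.bin") as stream:
            await stream.send(b"hello")

    anyio.run(main)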
diff --git a/contrib/python/anyio/anyio/streams/memory.py b/contrib/python/anyio/anyio/streams/memory.py
index a6499c13ff..bc2425b76f 100644
--- a/contrib/python/anyio/anyio/streams/memory.py
+++ b/contrib/python/anyio/anyio/streams/memory.py
@@ -10,9 +10,7 @@ from .. import (
ClosedResourceError,
EndOfStream,
WouldBlock,
- get_cancelled_exc_class,
)
-from .._core._compat import DeprecatedAwaitable
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
from ..lowlevel import checkpoint
@@ -27,7 +25,8 @@ class MemoryObjectStreamStatistics(NamedTuple):
max_buffer_size: float
open_send_streams: int #: number of unclosed clones of the send stream
open_receive_streams: int #: number of unclosed clones of the receive stream
- tasks_waiting_send: int #: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
+ #: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
+ tasks_waiting_send: int
#: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive`
tasks_waiting_receive: int
@@ -104,11 +103,6 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
try:
await receive_event.wait()
- except get_cancelled_exc_class():
- # Ignore the immediate cancellation if we already received an item, so as not to
- # lose it
- if not container:
- raise
finally:
self._state.waiting_receivers.pop(receive_event, None)
@@ -121,8 +115,8 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
"""
Create a clone of this receive stream.
- Each clone can be closed separately. Only when all clones have been closed will the
- receiving end of the memory stream be considered closed by the sending ends.
+ Each clone can be closed separately. Only when all clones have been closed will
+ the receiving end of the memory stream be considered closed by the sending ends.
:return: the cloned stream
@@ -136,8 +130,8 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
"""
Close the stream.
- This works the exact same way as :meth:`aclose`, but is provided as a special case for the
- benefit of synchronous callbacks.
+ This works the exact same way as :meth:`aclose`, but is provided as a special
+ case for the benefit of synchronous callbacks.
"""
if not self._closed:
@@ -179,7 +173,7 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
def __post_init__(self) -> None:
self._state.open_send_channels += 1
- def send_nowait(self, item: T_contra) -> DeprecatedAwaitable:
+ def send_nowait(self, item: T_contra) -> None:
"""
Send an item immediately if it can be done without waiting.
@@ -205,9 +199,19 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
else:
raise WouldBlock
- return DeprecatedAwaitable(self.send_nowait)
-
async def send(self, item: T_contra) -> None:
+ """
+ Send an item to the stream.
+
+ If the buffer is full, this method blocks until there is again room in the
+ buffer or the item can be sent directly to a receiver.
+
+ :param item: the item to send
+ :raises ~anyio.ClosedResourceError: if this send stream has been closed
+ :raises ~anyio.BrokenResourceError: if the stream has been closed from the
+ receiving end
+
+ """
await checkpoint()
try:
self.send_nowait(item)
@@ -218,18 +222,18 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
try:
await send_event.wait()
except BaseException:
- self._state.waiting_senders.pop(send_event, None) # type: ignore[arg-type]
+ self._state.waiting_senders.pop(send_event, None)
raise
- if self._state.waiting_senders.pop(send_event, None): # type: ignore[arg-type]
- raise BrokenResourceError
+ if self._state.waiting_senders.pop(send_event, None):
+ raise BrokenResourceError from None
def clone(self) -> MemoryObjectSendStream[T_contra]:
"""
Create a clone of this send stream.
- Each clone can be closed separately. Only when all clones have been closed will the
- sending end of the memory stream be considered closed by the receiving ends.
+ Each clone can be closed separately. Only when all clones have been closed will
+ the sending end of the memory stream be considered closed by the receiving ends.
:return: the cloned stream
@@ -243,8 +247,8 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
"""
Close the stream.
- This works the exact same way as :meth:`aclose`, but is provided as a special case for the
- benefit of synchronous callbacks.
+ This works the exact same way as :meth:`aclose`, but is provided as a special
+ case for the benefit of synchronous callbacks.
"""
if not self._closed:
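
Illustrative sketch, not part of the diff: ``send_nowait()`` now returns plain ``None`` instead of a ``DeprecatedAwaitable``, and a send to a closed receiving end raises ``BrokenResourceError from None``. A minimal round trip:

    import anyio

    async def main() -> None:
        send, receive = anyio.create_memory_object_stream(max_buffer_size=1)
        async with send, receive:
            send.send_nowait("hi")           # no longer awaitable in 4.x
            print(await receive.receive())   # "hi"
            print(send.statistics().tasks_waiting_send)  # 0

    anyio.run(main)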
diff --git a/contrib/python/anyio/anyio/streams/stapled.py b/contrib/python/anyio/anyio/streams/stapled.py
index 1b2862e3ea..80f64a2e8e 100644
--- a/contrib/python/anyio/anyio/streams/stapled.py
+++ b/contrib/python/anyio/anyio/streams/stapled.py
@@ -1,7 +1,8 @@
from __future__ import annotations
+from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
-from typing import Any, Callable, Generic, Mapping, Sequence, TypeVar
+from typing import Any, Generic, TypeVar
from ..abc import (
ByteReceiveStream,
@@ -23,8 +24,8 @@ class StapledByteStream(ByteStream):
"""
Combines two byte streams into a single, bidirectional byte stream.
- Extra attributes will be provided from both streams, with the receive stream providing the
- values in case of a conflict.
+ Extra attributes will be provided from both streams, with the receive stream
+ providing the values in case of a conflict.
:param ByteSendStream send_stream: the sending byte stream
:param ByteReceiveStream receive_stream: the receiving byte stream
@@ -59,8 +60,8 @@ class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]):
"""
Combines two object streams into a single, bidirectional object stream.
- Extra attributes will be provided from both streams, with the receive stream providing the
- values in case of a conflict.
+ Extra attributes will be provided from both streams, with the receive stream
+ providing the values in case of a conflict.
:param ObjectSendStream send_stream: the sending object stream
:param ObjectReceiveStream receive_stream: the receiving object stream
@@ -95,11 +96,11 @@ class MultiListener(Generic[T_Stream], Listener[T_Stream]):
"""
Combines multiple listeners into one, serving connections from all of them at once.
- Any MultiListeners in the given collection of listeners will have their listeners moved into
- this one.
+ Any MultiListeners in the given collection of listeners will have their listeners
+ moved into this one.
- Extra attributes are provided from each listener, with each successive listener overriding any
- conflicting attributes from the previous one.
+ Extra attributes are provided from each listener, with each successive listener
+ overriding any conflicting attributes from the previous one.
:param listeners: listeners to serve
:type listeners: Sequence[Listener[T_Stream]]
diff --git a/contrib/python/anyio/anyio/streams/text.py b/contrib/python/anyio/anyio/streams/text.py
index bba2d3f7df..f1a11278e3 100644
--- a/contrib/python/anyio/anyio/streams/text.py
+++ b/contrib/python/anyio/anyio/streams/text.py
@@ -1,8 +1,9 @@
from __future__ import annotations
import codecs
+from collections.abc import Callable, Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Callable, Mapping
+from typing import Any
from ..abc import (
AnyByteReceiveStream,
@@ -19,16 +20,17 @@ class TextReceiveStream(ObjectReceiveStream[str]):
"""
Stream wrapper that decodes bytes to strings using the given encoding.
- Decoding is done using :class:`~codecs.IncrementalDecoder` which returns any completely
- received unicode characters as soon as they come in.
+ Decoding is done using :class:`~codecs.IncrementalDecoder` which returns any
+ completely received unicode characters as soon as they come in.
:param transport_stream: any bytes-based receive stream
- :param encoding: character encoding to use for decoding bytes to strings (defaults to
- ``utf-8``)
+ :param encoding: character encoding to use for decoding bytes to strings (defaults
+ to ``utf-8``)
:param errors: handling scheme for decoding errors (defaults to ``strict``; see the
`codecs module documentation`_ for a comprehensive list of options)
- .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects
+ .. _codecs module documentation:
+ https://docs.python.org/3/library/codecs.html#codec-objects
"""
transport_stream: AnyByteReceiveStream
@@ -62,12 +64,13 @@ class TextSendStream(ObjectSendStream[str]):
Sends strings to the wrapped stream as bytes using the given encoding.
:param AnyByteSendStream transport_stream: any bytes-based send stream
- :param str encoding: character encoding to use for encoding strings to bytes (defaults to
- ``utf-8``)
- :param str errors: handling scheme for encoding errors (defaults to ``strict``; see the
- `codecs module documentation`_ for a comprehensive list of options)
+ :param str encoding: character encoding to use for encoding strings to bytes
+ (defaults to ``utf-8``)
+ :param str errors: handling scheme for encoding errors (defaults to ``strict``; see
+ the `codecs module documentation`_ for a comprehensive list of options)
- .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects
+ .. _codecs module documentation:
+ https://docs.python.org/3/library/codecs.html#codec-objects
"""
transport_stream: AnyByteSendStream
@@ -93,19 +96,20 @@ class TextSendStream(ObjectSendStream[str]):
@dataclass(eq=False)
class TextStream(ObjectStream[str]):
"""
- A bidirectional stream that decodes bytes to strings on receive and encodes strings to bytes on
- send.
+ A bidirectional stream that decodes bytes to strings on receive and encodes strings
+ to bytes on send.
- Extra attributes will be provided from both streams, with the receive stream providing the
- values in case of a conflict.
+ Extra attributes will be provided from both streams, with the receive stream
+ providing the values in case of a conflict.
:param AnyByteStream transport_stream: any bytes-based stream
- :param str encoding: character encoding to use for encoding/decoding strings to/from bytes
- (defaults to ``utf-8``)
- :param str errors: handling scheme for encoding errors (defaults to ``strict``; see the
- `codecs module documentation`_ for a comprehensive list of options)
+ :param str encoding: character encoding to use for encoding/decoding strings to/from
+ bytes (defaults to ``utf-8``)
+ :param str errors: handling scheme for encoding errors (defaults to ``strict``; see
+ the `codecs module documentation`_ for a comprehensive list of options)
- .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects
+ .. _codecs module documentation:
+ https://docs.python.org/3/library/codecs.html#codec-objects
"""
transport_stream: AnyByteStream
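
Illustrative sketch, not part of the diff: the ``StapledObjectStream`` and ``TextStream`` wrappers whose docstrings were rewrapped above compose naturally. A loopback that decodes what it encodes:

    import anyio
    from anyio.streams.stapled import StapledObjectStream
    from anyio.streams.text import TextStream

    async def main() -> None:
        send, receive = anyio.create_memory_object_stream(max_buffer_size=10)
        loopback = StapledObjectStream(send, receive)  # bidirectional bytes stream
        text = TextStream(loopback, encoding="utf-8")
        await text.send("grön")          # encoded to UTF-8 on the way out
        print(await text.receive())      # decoded back to "grön"

    anyio.run(main)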
diff --git a/contrib/python/anyio/anyio/streams/tls.py b/contrib/python/anyio/anyio/streams/tls.py
index 9f9e9fd89c..e913eedbbf 100644
--- a/contrib/python/anyio/anyio/streams/tls.py
+++ b/contrib/python/anyio/anyio/streams/tls.py
@@ -3,9 +3,11 @@ from __future__ import annotations
import logging
import re
import ssl
+import sys
+from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import wraps
-from typing import Any, Callable, Mapping, Tuple, TypeVar
+from typing import Any, Tuple, TypeVar
from .. import (
BrokenResourceError,
@@ -16,7 +18,13 @@ from .. import (
from .._core._typedattr import TypedAttributeSet, typed_attribute
from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
+if sys.version_info >= (3, 11):
+ from typing import TypeVarTuple, Unpack
+else:
+ from typing_extensions import TypeVarTuple, Unpack
+
T_Retval = TypeVar("T_Retval")
+PosArgsT = TypeVarTuple("PosArgsT")
_PCTRTT = Tuple[Tuple[str, str], ...]
_PCTRTTT = Tuple[_PCTRTT, ...]
@@ -31,8 +39,8 @@ class TLSAttribute(TypedAttributeSet):
#: the selected cipher
cipher: tuple[str, str, int] = typed_attribute()
#: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
- #: for more information)
- peer_certificate: dict[str, str | _PCTRTTT | _PCTRTT] | None = typed_attribute()
+ # for more information)
+ peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
#: the peer certificate in binary form
peer_certificate_binary: bytes | None = typed_attribute()
#: ``True`` if this is the server side of the connection
@@ -90,8 +98,9 @@ class TLSStream(ByteStream):
:param hostname: host name of the peer (if host name checking is desired)
:param ssl_context: the SSLContext object to use (if not provided, a secure
default will be created)
- :param standard_compatible: if ``False``, skip the closing handshake when closing the
- connection, and don't raise an exception if the peer does the same
+ :param standard_compatible: if ``False``, skip the closing handshake when
+ closing the connection, and don't raise an exception if the peer does the
+ same
:raises ~ssl.SSLError: if the TLS handshake fails
"""
@@ -124,7 +133,7 @@ class TLSStream(ByteStream):
return wrapper
async def _call_sslobject_method(
- self, func: Callable[..., T_Retval], *args: object
+ self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
) -> T_Retval:
while True:
try:
@@ -222,7 +231,9 @@ class TLSStream(ByteStream):
return {
**self.transport_stream.extra_attributes,
TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
- TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding,
+ TLSAttribute.channel_binding_tls_unique: (
+ self._ssl_object.get_channel_binding
+ ),
TLSAttribute.cipher: self._ssl_object.cipher,
TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
@@ -241,11 +252,12 @@ class TLSStream(ByteStream):
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
"""
- A convenience listener that wraps another listener and auto-negotiates a TLS session on every
- accepted connection.
+ A convenience listener that wraps another listener and auto-negotiates a TLS session
+ on every accepted connection.
- If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is
- called to do whatever post-mortem processing is deemed necessary.
+ If the TLS handshake times out or raises an exception,
+ :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
+ deemed necessary.
Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
@@ -281,7 +293,13 @@ class TLSListener(Listener[TLSStream]):
# Log all except cancellation exceptions
if not isinstance(exc, get_cancelled_exc_class()):
- logging.getLogger(__name__).exception("Error during TLS handshake")
+ # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
+ # any asyncio implementation, so we explicitly pass the exception to log
+ # (https://github.com/python/cpython/issues/108668). Trio does not have this
+ # issue because it works around the CPython bug.
+ logging.getLogger(__name__).exception(
+ "Error during TLS handshake", exc_info=exc
+ )
# Only reraise base exceptions and cancellation exceptions
if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
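
Illustrative sketch, not part of the diff: ``TLSStream.wrap()`` remains the client-side entry point; the new ``Unpack``-typed ``_call_sslobject_method`` is internal. A handshake sketch (network access and hostname are assumptions of the example):

    import anyio
    from anyio.streams.tls import TLSAttribute, TLSStream

    async def main() -> None:
        tcp = await anyio.connect_tcp("example.org", 443)
        tls = await TLSStream.wrap(tcp, hostname="example.org")
        # e.g. ('TLS_AES_256_GCM_SHA384', 'TLSv1.3', 256)
        print(tls.extra(TLSAttribute.cipher))
        await tls.aclose()

    anyio.run(main)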
diff --git a/contrib/python/anyio/anyio/to_process.py b/contrib/python/anyio/anyio/to_process.py
index 7ba9d44198..1ff06f0b25 100644
--- a/contrib/python/anyio/anyio/to_process.py
+++ b/contrib/python/anyio/anyio/to_process.py
@@ -5,10 +5,11 @@ import pickle
import subprocess
import sys
from collections import deque
+from collections.abc import Callable
from importlib.util import module_from_spec, spec_from_file_location
-from typing import Callable, TypeVar, cast
+from typing import TypeVar, cast
-from ._core._eventloop import current_time, get_asynclib, get_cancelled_exc_class
+from ._core._eventloop import current_time, get_async_backend, get_cancelled_exc_class
from ._core._exceptions import BrokenWorkerProcess
from ._core._subprocesses import open_process
from ._core._synchronization import CapacityLimiter
@@ -17,9 +18,16 @@ from .abc import ByteReceiveStream, ByteSendStream, Process
from .lowlevel import RunVar, checkpoint_if_cancelled
from .streams.buffered import BufferedByteReceiveStream
+if sys.version_info >= (3, 11):
+ from typing import TypeVarTuple, Unpack
+else:
+ from typing_extensions import TypeVarTuple, Unpack
+
WORKER_MAX_IDLE_TIME = 300 # 5 minutes
T_Retval = TypeVar("T_Retval")
+PosArgsT = TypeVarTuple("PosArgsT")
+
_process_pool_workers: RunVar[set[Process]] = RunVar("_process_pool_workers")
_process_pool_idle_workers: RunVar[deque[tuple[Process, float]]] = RunVar(
"_process_pool_idle_workers"
@@ -28,23 +36,24 @@ _default_process_limiter: RunVar[CapacityLimiter] = RunVar("_default_process_lim
async def run_sync(
- func: Callable[..., T_Retval],
- *args: object,
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
+ *args: Unpack[PosArgsT],
cancellable: bool = False,
limiter: CapacityLimiter | None = None,
) -> T_Retval:
"""
Call the given function with the given arguments in a worker process.
- If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled,
- the worker process running it will be abruptly terminated using SIGKILL (or
- ``terminateProcess()`` on Windows).
+ If the ``cancellable`` option is enabled and the task waiting for its completion is
+ cancelled, the worker process running it will be abruptly terminated using SIGKILL
+ (or ``terminateProcess()`` on Windows).
:param func: a callable
:param args: positional arguments for the callable
- :param cancellable: ``True`` to allow cancellation of the operation while it's running
- :param limiter: capacity limiter to use to limit the total amount of processes running
- (if omitted, the default limiter is used)
+ :param cancellable: ``True`` to allow cancellation of the operation while it's
+ running
+    :param limiter: capacity limiter to use to limit the total number of processes
+ running (if omitted, the default limiter is used)
:return: an awaitable that yields the return value of the function.
"""
@@ -94,11 +103,11 @@ async def run_sync(
idle_workers = deque()
_process_pool_workers.set(workers)
_process_pool_idle_workers.set(idle_workers)
- get_asynclib().setup_process_pool_exit_at_shutdown(workers)
+ get_async_backend().setup_process_pool_exit_at_shutdown(workers)
- async with (limiter or current_default_process_limiter()):
- # Pop processes from the pool (starting from the most recently used) until we find one that
- # hasn't exited yet
+ async with limiter or current_default_process_limiter():
+ # Pop processes from the pool (starting from the most recently used) until we
+ # find one that hasn't exited yet
process: Process
while idle_workers:
process, idle_since = idle_workers.pop()
@@ -108,22 +117,22 @@ async def run_sync(
cast(ByteReceiveStream, process.stdout)
)
- # Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME seconds or
- # longer
+ # Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME
+ # seconds or longer
now = current_time()
killed_processes: list[Process] = []
while idle_workers:
if now - idle_workers[0][1] < WORKER_MAX_IDLE_TIME:
break
- process, idle_since = idle_workers.popleft()
- process.kill()
- workers.remove(process)
- killed_processes.append(process)
+ process_to_kill, idle_since = idle_workers.popleft()
+ process_to_kill.kill()
+ workers.remove(process_to_kill)
+ killed_processes.append(process_to_kill)
with CancelScope(shield=True):
- for process in killed_processes:
- await process.aclose()
+ for killed_process in killed_processes:
+ await killed_process.aclose()
break
@@ -172,7 +181,8 @@ async def run_sync(
def current_default_process_limiter() -> CapacityLimiter:
"""
- Return the capacity limiter that is used by default to limit the number of worker processes.
+ Return the capacity limiter that is used by default to limit the number of worker
+ processes.
:return: a capacity limiter object
@@ -214,8 +224,8 @@ def process_worker() -> None:
sys.path, main_module_path = args
del sys.modules["__main__"]
if main_module_path:
- # Load the parent's main module but as __mp_main__ instead of __main__
- # (like multiprocessing does) to avoid infinite recursion
+ # Load the parent's main module but as __mp_main__ instead of
+ # __main__ (like multiprocessing does) to avoid infinite recursion
try:
spec = spec_from_file_location("__mp_main__", main_module_path)
if spec and spec.loader:
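
Illustrative sketch, not part of the diff: because the worker re-imports the parent's main module as ``__mp_main__`` (see the comment rewrapped above), anything passed to ``to_process.run_sync()`` must be picklable and importable, and scripts need the usual main guard:

    import os
    import anyio
    import anyio.to_process

    def cpu_bound(n: int) -> int:
        # runs in a separate worker process, so it must be importable/picklable
        return sum(i * i for i in range(n))

    async def main() -> None:
        result = await anyio.to_process.run_sync(cpu_bound, 10_000)
        print(result, "computed outside pid", os.getpid())

    if __name__ == "__main__":
        anyio.run(main)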
diff --git a/contrib/python/anyio/anyio/to_thread.py b/contrib/python/anyio/anyio/to_thread.py
index 9315d1ecf1..5070516eb5 100644
--- a/contrib/python/anyio/anyio/to_thread.py
+++ b/contrib/python/anyio/anyio/to_thread.py
@@ -1,67 +1,69 @@
from __future__ import annotations
-from typing import Callable, TypeVar
+import sys
+from collections.abc import Callable
+from typing import TypeVar
from warnings import warn
-from ._core._eventloop import get_asynclib
+from ._core._eventloop import get_async_backend
from .abc import CapacityLimiter
+if sys.version_info >= (3, 11):
+ from typing import TypeVarTuple, Unpack
+else:
+ from typing_extensions import TypeVarTuple, Unpack
+
T_Retval = TypeVar("T_Retval")
+PosArgsT = TypeVarTuple("PosArgsT")
async def run_sync(
- func: Callable[..., T_Retval],
- *args: object,
- cancellable: bool = False,
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
+ *args: Unpack[PosArgsT],
+ abandon_on_cancel: bool = False,
+ cancellable: bool | None = None,
limiter: CapacityLimiter | None = None,
) -> T_Retval:
"""
Call the given function with the given arguments in a worker thread.
- If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled,
- the thread will still run its course but its return value (or any raised exception) will be
- ignored.
+ If the ``cancellable`` option is enabled and the task waiting for its completion is
+ cancelled, the thread will still run its course but its return value (or any raised
+ exception) will be ignored.
:param func: a callable
:param args: positional arguments for the callable
- :param cancellable: ``True`` to allow cancellation of the operation
+ :param abandon_on_cancel: ``True`` to abandon the thread (leaving it to run
+        unchecked on its own) if the host task is cancelled, ``False`` to ignore
+ cancellations in the host task until the operation has completed in the worker
+ thread
+ :param cancellable: deprecated alias of ``abandon_on_cancel``; will override
+ ``abandon_on_cancel`` if both parameters are passed
    :param limiter: capacity limiter to use to limit the total number of threads running
(if omitted, the default limiter is used)
:return: an awaitable that yields the return value of the function.
"""
- return await get_asynclib().run_sync_in_worker_thread(
- func, *args, cancellable=cancellable, limiter=limiter
+ if cancellable is not None:
+ abandon_on_cancel = cancellable
+ warn(
+ "The `cancellable=` keyword argument to `anyio.to_thread.run_sync` is "
+ "deprecated since AnyIO 4.1.0; use `abandon_on_cancel=` instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return await get_async_backend().run_sync_in_worker_thread(
+ func, args, abandon_on_cancel=abandon_on_cancel, limiter=limiter
)
-async def run_sync_in_worker_thread(
- func: Callable[..., T_Retval],
- *args: object,
- cancellable: bool = False,
- limiter: CapacityLimiter | None = None,
-) -> T_Retval:
- warn(
- "run_sync_in_worker_thread() has been deprecated, use anyio.to_thread.run_sync() instead",
- DeprecationWarning,
- )
- return await run_sync(func, *args, cancellable=cancellable, limiter=limiter)
-
-
def current_default_thread_limiter() -> CapacityLimiter:
"""
- Return the capacity limiter that is used by default to limit the number of concurrent threads.
+ Return the capacity limiter that is used by default to limit the number of
+ concurrent threads.
:return: a capacity limiter object
"""
- return get_asynclib().current_default_thread_limiter()
-
-
-def current_default_worker_thread_limiter() -> CapacityLimiter:
- warn(
- "current_default_worker_thread_limiter() has been deprecated, "
- "use anyio.to_thread.current_default_thread_limiter() instead",
- DeprecationWarning,
- )
- return current_default_thread_limiter()
+ return get_async_backend().current_default_thread_limiter()
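
Illustrative sketch, not part of the diff: ``abandon_on_cancel=`` replaces ``cancellable=``, which now only emits a DeprecationWarning and overrides it when passed. The difference shows up under a timeout:

    import time
    import anyio
    import anyio.to_thread

    async def main() -> None:
        with anyio.move_on_after(0.5):
            # the thread keeps running to completion, but the host task is
            # released as soon as the cancel scope expires
            await anyio.to_thread.run_sync(time.sleep, 2, abandon_on_cancel=True)
        print("host task resumed after ~0.5s")

    anyio.run(main)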
diff --git a/contrib/python/anyio/ya.make b/contrib/python/anyio/ya.make
index f8534a7d6c..7d21184567 100644
--- a/contrib/python/anyio/ya.make
+++ b/contrib/python/anyio/ya.make
@@ -2,7 +2,7 @@
PY3_LIBRARY()
-VERSION(3.7.1)
+VERSION(4.3.0)
LICENSE(MIT)
@@ -25,7 +25,6 @@ PY_SRCS(
anyio/_backends/_asyncio.py
anyio/_backends/_trio.py
anyio/_core/__init__.py
- anyio/_core/_compat.py
anyio/_core/_eventloop.py
anyio/_core/_exceptions.py
anyio/_core/_fileio.py
@@ -39,6 +38,7 @@ PY_SRCS(
anyio/_core/_testing.py
anyio/_core/_typedattr.py
anyio/abc/__init__.py
+ anyio/abc/_eventloop.py
anyio/abc/_resources.py
anyio/abc/_sockets.py
anyio/abc/_streams.py