diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2024-06-04 13:38:02 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2024-06-04 13:45:53 +0300 |
commit | 4106eea04f837a5213898cf991d9db841a841870 (patch) | |
tree | ac3f4ba4526c6a0e893d9930fc148a5c313fbe3d | |
parent | 75de3354ad96635806c975cd9b0dbddd4036ff80 (diff) | |
download | ydb-4106eea04f837a5213898cf991d9db841a841870.tar.gz |
Intermediate changes
-rw-r--r-- | contrib/python/anyio/.dist-info/METADATA | 2 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/_backends/_asyncio.py | 79 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/_backends/_trio.py | 42 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/_core/_eventloop.py | 15 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/_core/_fileio.py | 24 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/_core/_sockets.py | 15 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/_core/_testing.py | 12 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/_core/_typedattr.py | 10 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/abc/_eventloop.py | 8 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/abc/_resources.py | 2 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/abc/_tasks.py | 6 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/abc/_testing.py | 3 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/from_thread.py | 70 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/lowlevel.py | 6 | ||||
-rw-r--r-- | contrib/python/anyio/anyio/streams/memory.py | 52 | ||||
-rw-r--r-- | contrib/python/anyio/ya.make | 2 |
16 files changed, 230 insertions, 118 deletions
diff --git a/contrib/python/anyio/.dist-info/METADATA b/contrib/python/anyio/.dist-info/METADATA index e02715ca28..be13c8aa0f 100644 --- a/contrib/python/anyio/.dist-info/METADATA +++ b/contrib/python/anyio/.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: anyio -Version: 4.3.0 +Version: 4.4.0 Summary: High level compatibility layer for multiple asynchronous event loop implementations Author-email: Alex Grönholm <alex.gronholm@nextday.fi> License: MIT diff --git a/contrib/python/anyio/anyio/_backends/_asyncio.py b/contrib/python/anyio/anyio/_backends/_asyncio.py index 2699bf8146..43b7cb0e0c 100644 --- a/contrib/python/anyio/anyio/_backends/_asyncio.py +++ b/contrib/python/anyio/anyio/_backends/_asyncio.py @@ -7,6 +7,7 @@ import math import socket import sys import threading +import weakref from asyncio import ( AbstractEventLoop, CancelledError, @@ -488,7 +489,7 @@ class CancelScope(BaseCancelScope): if task is not current and (task is self._host_task or _task_started(task)): waiter = task._fut_waiter # type: ignore[attr-defined] if not isinstance(waiter, asyncio.Future) or not waiter.done(): - self._cancel_calls += 1 + origin._cancel_calls += 1 if sys.version_info >= (3, 9): task.cancel(f"Cancelled by cancel scope {id(origin):x}") else: @@ -596,14 +597,14 @@ class TaskState: itself because there are no guarantees about its implementation. """ - __slots__ = "parent_id", "cancel_scope" + __slots__ = "parent_id", "cancel_scope", "__weakref__" def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None): self.parent_id = parent_id self.cancel_scope = cancel_scope -_task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState] +_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary() # @@ -620,9 +621,10 @@ class _AsyncioTaskStatus(abc.TaskStatus): try: self._future.set_result(value) except asyncio.InvalidStateError: - raise RuntimeError( - "called 'started' twice on the same task status" - ) from None + if not self._future.cancelled(): + raise RuntimeError( + "called 'started' twice on the same task status" + ) from None task = cast(asyncio.Task, current_task()) _task_states[task].parent_id = self._parent_id @@ -713,6 +715,12 @@ class TaskGroup(abc.TaskGroup): exc = e if exc is not None: + # The future can only be in the cancelled state if the host task was + # cancelled, so return immediately instead of adding one more + # CancelledError to the exceptions list + if task_status_future is not None and task_status_future.cancelled(): + return + if task_status_future is None or task_status_future.done(): if not isinstance(exc, CancelledError): self._exceptions.append(exc) @@ -1047,6 +1055,7 @@ class StreamProtocol(asyncio.Protocol): read_event: asyncio.Event write_event: asyncio.Event exception: Exception | None = None + is_at_eof: bool = False def connection_made(self, transport: asyncio.BaseTransport) -> None: self.read_queue = deque() @@ -1068,6 +1077,7 @@ class StreamProtocol(asyncio.Protocol): self.read_event.set() def eof_received(self) -> bool | None: + self.is_at_eof = True self.read_event.set() return True @@ -1123,15 +1133,16 @@ class SocketStream(abc.SocketStream): async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: - await AsyncIOBackend.checkpoint() - if ( not self._protocol.read_event.is_set() and not self._transport.is_closing() + and not self._protocol.is_at_eof ): self._transport.resume_reading() await self._protocol.read_event.wait() self._transport.pause_reading() + else: + await AsyncIOBackend.checkpoint() try: chunk = self._protocol.read_queue.popleft() @@ -1651,7 +1662,7 @@ class Event(BaseEvent): await self._event.wait() def statistics(self) -> EventStatistics: - return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined] + return EventStatistics(len(self._event._waiters)) class CapacityLimiter(BaseCapacityLimiter): @@ -1751,7 +1762,7 @@ class CapacityLimiter(BaseCapacityLimiter): self._borrowers.remove(borrower) except KeyError: raise RuntimeError( - "this borrower isn't holding any of this CapacityLimiter's " "tokens" + "this borrower isn't holding any of this CapacityLimiter's tokens" ) from None # Notify the next task in line if this limiter has free capacity now @@ -1823,14 +1834,36 @@ class _SignalReceiver: # -def _create_task_info(task: asyncio.Task) -> TaskInfo: - task_state = _task_states.get(task) - if task_state is None: - parent_id = None - else: - parent_id = task_state.parent_id +class AsyncIOTaskInfo(TaskInfo): + def __init__(self, task: asyncio.Task): + task_state = _task_states.get(task) + if task_state is None: + parent_id = None + else: + parent_id = task_state.parent_id - return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro()) + super().__init__(id(task), parent_id, task.get_name(), task.get_coro()) + self._task = weakref.ref(task) + + def has_pending_cancellation(self) -> bool: + if not (task := self._task()): + # If the task isn't around anymore, it won't have a pending cancellation + return False + + if sys.version_info >= (3, 11): + if task.cancelling(): + return True + elif ( + isinstance(task._fut_waiter, asyncio.Future) + and task._fut_waiter.cancelled() + ): + return True + + if task_state := _task_states.get(task): + if cancel_scope := task_state.cancel_scope: + return cancel_scope.cancel_called or cancel_scope._parent_cancelled() + + return False class TestRunner(abc.TestRunner): @@ -1887,13 +1920,13 @@ class TestRunner(abc.TestRunner): "Multiple exceptions occurred in asynchronous callbacks", exceptions ) - @staticmethod async def _run_tests_and_fixtures( + self, receive_stream: MemoryObjectReceiveStream[ tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]] ], ) -> None: - with receive_stream: + with receive_stream, self._send_stream: async for coro, future in receive_stream: try: retval = await coro @@ -1990,7 +2023,7 @@ class AsyncIOBackend(AsyncBackend): finally: del _task_states[task] - debug = options.get("debug", False) + debug = options.get("debug", None) loop_factory = options.get("loop_factory", None) if loop_factory is None and options.get("use_uvloop", False): import uvloop @@ -2448,11 +2481,11 @@ class AsyncIOBackend(AsyncBackend): @classmethod def get_current_task(cls) -> TaskInfo: - return _create_task_info(current_task()) # type: ignore[arg-type] + return AsyncIOTaskInfo(current_task()) # type: ignore[arg-type] @classmethod - def get_running_tasks(cls) -> list[TaskInfo]: - return [_create_task_info(task) for task in all_tasks() if not task.done()] + def get_running_tasks(cls) -> Sequence[TaskInfo]: + return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()] @classmethod async def wait_all_tasks_blocked(cls) -> None: diff --git a/contrib/python/anyio/anyio/_backends/_trio.py b/contrib/python/anyio/anyio/_backends/_trio.py index 1a47192e30..cf6f3db789 100644 --- a/contrib/python/anyio/anyio/_backends/_trio.py +++ b/contrib/python/anyio/anyio/_backends/_trio.py @@ -5,6 +5,7 @@ import math import socket import sys import types +import weakref from collections.abc import AsyncIterator, Iterable from concurrent.futures import Future from dataclasses import dataclass @@ -839,6 +840,24 @@ class TestRunner(abc.TestRunner): self._call_in_runner_task(test_func, **kwargs) +class TrioTaskInfo(TaskInfo): + def __init__(self, task: trio.lowlevel.Task): + parent_id = None + if task.parent_nursery and task.parent_nursery.parent_task: + parent_id = id(task.parent_nursery.parent_task) + + super().__init__(id(task), parent_id, task.name, task.coro) + self._task = weakref.proxy(task) + + def has_pending_cancellation(self) -> bool: + try: + return self._task._cancel_status.effectively_cancelled + except ReferenceError: + # If the task is no longer around, it surely doesn't have a cancellation + # pending + return False + + class TrioBackend(AsyncBackend): @classmethod def run( @@ -1040,15 +1059,13 @@ class TrioBackend(AsyncBackend): @overload async def create_unix_datagram_socket( cls, raw_socket: socket.socket, remote_path: None - ) -> abc.UNIXDatagramSocket: - ... + ) -> abc.UNIXDatagramSocket: ... @classmethod @overload async def create_unix_datagram_socket( cls, raw_socket: socket.socket, remote_path: str | bytes - ) -> abc.ConnectedUNIXDatagramSocket: - ... + ) -> abc.ConnectedUNIXDatagramSocket: ... @classmethod async def create_unix_datagram_socket( @@ -1127,28 +1144,19 @@ class TrioBackend(AsyncBackend): @classmethod def get_current_task(cls) -> TaskInfo: task = current_task() - - parent_id = None - if task.parent_nursery and task.parent_nursery.parent_task: - parent_id = id(task.parent_nursery.parent_task) - - return TaskInfo(id(task), parent_id, task.name, task.coro) + return TrioTaskInfo(task) @classmethod - def get_running_tasks(cls) -> list[TaskInfo]: + def get_running_tasks(cls) -> Sequence[TaskInfo]: root_task = current_root_task() assert root_task - task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)] + task_infos = [TrioTaskInfo(root_task)] nurseries = root_task.child_nurseries while nurseries: new_nurseries: list[trio.Nursery] = [] for nursery in nurseries: for task in nursery.child_tasks: - task_infos.append( - TaskInfo( - id(task), id(nursery.parent_task), task.name, task.coro - ) - ) + task_infos.append(TrioTaskInfo(task)) new_nurseries.extend(task.child_nurseries) nurseries = new_nurseries diff --git a/contrib/python/anyio/anyio/_core/_eventloop.py b/contrib/python/anyio/anyio/_core/_eventloop.py index a9c6e82585..6dcb458981 100644 --- a/contrib/python/anyio/anyio/_core/_eventloop.py +++ b/contrib/python/anyio/anyio/_core/_eventloop.py @@ -25,6 +25,7 @@ T_Retval = TypeVar("T_Retval") PosArgsT = TypeVarTuple("PosArgsT") threadlocals = threading.local() +loaded_backends: dict[str, type[AsyncBackend]] = {} def run( @@ -150,14 +151,16 @@ def claim_worker_thread( del threadlocals.current_token -def get_async_backend(asynclib_name: str | None = None) -> AsyncBackend: +def get_async_backend(asynclib_name: str | None = None) -> type[AsyncBackend]: if asynclib_name is None: asynclib_name = sniffio.current_async_library() - modulename = "anyio._backends._" + asynclib_name + # We use our own dict instead of sys.modules to get the already imported back-end + # class because the appropriate modules in sys.modules could potentially be only + # partially initialized try: - module = sys.modules[modulename] + return loaded_backends[asynclib_name] except KeyError: - module = import_module(modulename) - - return getattr(module, "backend_class") + module = import_module(f"anyio._backends._{asynclib_name}") + loaded_backends[asynclib_name] = module.backend_class + return module.backend_class diff --git a/contrib/python/anyio/anyio/_core/_fileio.py b/contrib/python/anyio/anyio/_core/_fileio.py index d054be693d..df2057fe34 100644 --- a/contrib/python/anyio/anyio/_core/_fileio.py +++ b/contrib/python/anyio/anyio/_core/_fileio.py @@ -100,12 +100,10 @@ class AsyncFile(AsyncResource, Generic[AnyStr]): return await to_thread.run_sync(self._fp.readinto1, b) @overload - async def write(self: AsyncFile[bytes], b: ReadableBuffer) -> int: - ... + async def write(self: AsyncFile[bytes], b: ReadableBuffer) -> int: ... @overload - async def write(self: AsyncFile[str], b: str) -> int: - ... + async def write(self: AsyncFile[str], b: str) -> int: ... async def write(self, b: ReadableBuffer | str) -> int: return await to_thread.run_sync(self._fp.write, b) @@ -113,12 +111,10 @@ class AsyncFile(AsyncResource, Generic[AnyStr]): @overload async def writelines( self: AsyncFile[bytes], lines: Iterable[ReadableBuffer] - ) -> None: - ... + ) -> None: ... @overload - async def writelines(self: AsyncFile[str], lines: Iterable[str]) -> None: - ... + async def writelines(self: AsyncFile[str], lines: Iterable[str]) -> None: ... async def writelines(self, lines: Iterable[ReadableBuffer] | Iterable[str]) -> None: return await to_thread.run_sync(self._fp.writelines, lines) @@ -146,8 +142,7 @@ async def open_file( newline: str | None = ..., closefd: bool = ..., opener: Callable[[str, int], int] | None = ..., -) -> AsyncFile[bytes]: - ... +) -> AsyncFile[bytes]: ... @overload @@ -160,8 +155,7 @@ async def open_file( newline: str | None = ..., closefd: bool = ..., opener: Callable[[str, int], int] | None = ..., -) -> AsyncFile[str]: - ... +) -> AsyncFile[str]: ... async def open_file( @@ -476,8 +470,7 @@ class Path: encoding: str | None = ..., errors: str | None = ..., newline: str | None = ..., - ) -> AsyncFile[bytes]: - ... + ) -> AsyncFile[bytes]: ... @overload async def open( @@ -487,8 +480,7 @@ class Path: encoding: str | None = ..., errors: str | None = ..., newline: str | None = ..., - ) -> AsyncFile[str]: - ... + ) -> AsyncFile[str]: ... async def open( self, diff --git a/contrib/python/anyio/anyio/_core/_sockets.py b/contrib/python/anyio/anyio/_core/_sockets.py index 0f0a3142fb..5e09cdbf0f 100644 --- a/contrib/python/anyio/anyio/_core/_sockets.py +++ b/contrib/python/anyio/anyio/_core/_sockets.py @@ -53,8 +53,7 @@ async def connect_tcp( tls_standard_compatible: bool = ..., tls_hostname: str, happy_eyeballs_delay: float = ..., -) -> TLSStream: - ... +) -> TLSStream: ... # ssl_context given @@ -68,8 +67,7 @@ async def connect_tcp( tls_standard_compatible: bool = ..., tls_hostname: str | None = ..., happy_eyeballs_delay: float = ..., -) -> TLSStream: - ... +) -> TLSStream: ... # tls=True @@ -84,8 +82,7 @@ async def connect_tcp( tls_standard_compatible: bool = ..., tls_hostname: str | None = ..., happy_eyeballs_delay: float = ..., -) -> TLSStream: - ... +) -> TLSStream: ... # tls=False @@ -100,8 +97,7 @@ async def connect_tcp( tls_standard_compatible: bool = ..., tls_hostname: str | None = ..., happy_eyeballs_delay: float = ..., -) -> SocketStream: - ... +) -> SocketStream: ... # No TLS arguments @@ -112,8 +108,7 @@ async def connect_tcp( *, local_host: IPAddressType | None = ..., happy_eyeballs_delay: float = ..., -) -> SocketStream: - ... +) -> SocketStream: ... async def connect_tcp( diff --git a/contrib/python/anyio/anyio/_core/_testing.py b/contrib/python/anyio/anyio/_core/_testing.py index 1dae3b193a..9e28b22766 100644 --- a/contrib/python/anyio/anyio/_core/_testing.py +++ b/contrib/python/anyio/anyio/_core/_testing.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Awaitable, Generator -from typing import Any +from typing import Any, cast from ._eventloop import get_async_backend @@ -45,8 +45,12 @@ class TaskInfo: def __repr__(self) -> str: return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})" - def _unwrap(self) -> TaskInfo: - return self + def has_pending_cancellation(self) -> bool: + """ + Return ``True`` if the task has a cancellation pending, ``False`` otherwise. + + """ + return False def get_current_task() -> TaskInfo: @@ -66,7 +70,7 @@ def get_running_tasks() -> list[TaskInfo]: :return: a list of task info objects """ - return get_async_backend().get_running_tasks() + return cast("list[TaskInfo]", get_async_backend().get_running_tasks()) async def wait_all_tasks_blocked() -> None: diff --git a/contrib/python/anyio/anyio/_core/_typedattr.py b/contrib/python/anyio/anyio/_core/_typedattr.py index 74c6b8fdcb..f358a448cb 100644 --- a/contrib/python/anyio/anyio/_core/_typedattr.py +++ b/contrib/python/anyio/anyio/_core/_typedattr.py @@ -50,12 +50,10 @@ class TypedAttributeProvider: return {} @overload - def extra(self, attribute: T_Attr) -> T_Attr: - ... + def extra(self, attribute: T_Attr) -> T_Attr: ... @overload - def extra(self, attribute: T_Attr, default: T_Default) -> T_Attr | T_Default: - ... + def extra(self, attribute: T_Attr, default: T_Default) -> T_Attr | T_Default: ... @final def extra(self, attribute: Any, default: object = undefined) -> object: @@ -73,9 +71,11 @@ class TypedAttributeProvider: """ try: - return self.extra_attributes[attribute]() + getter = self.extra_attributes[attribute] except KeyError: if default is undefined: raise TypedAttributeLookupError("Attribute not found") from None else: return default + + return getter() diff --git a/contrib/python/anyio/anyio/abc/_eventloop.py b/contrib/python/anyio/anyio/abc/_eventloop.py index 4470d83d24..a50afefaa0 100644 --- a/contrib/python/anyio/anyio/abc/_eventloop.py +++ b/contrib/python/anyio/anyio/abc/_eventloop.py @@ -303,15 +303,13 @@ class AsyncBackend(metaclass=ABCMeta): @overload async def create_unix_datagram_socket( cls, raw_socket: socket, remote_path: None - ) -> UNIXDatagramSocket: - ... + ) -> UNIXDatagramSocket: ... @classmethod @overload async def create_unix_datagram_socket( cls, raw_socket: socket, remote_path: str | bytes - ) -> ConnectedUNIXDatagramSocket: - ... + ) -> ConnectedUNIXDatagramSocket: ... @classmethod @abstractmethod @@ -378,7 +376,7 @@ class AsyncBackend(metaclass=ABCMeta): @classmethod @abstractmethod - def get_running_tasks(cls) -> list[TaskInfo]: + def get_running_tasks(cls) -> Sequence[TaskInfo]: pass @classmethod diff --git a/contrib/python/anyio/anyio/abc/_resources.py b/contrib/python/anyio/anyio/abc/_resources.py index 9693835bae..10df115a7b 100644 --- a/contrib/python/anyio/anyio/abc/_resources.py +++ b/contrib/python/anyio/anyio/abc/_resources.py @@ -15,6 +15,8 @@ class AsyncResource(metaclass=ABCMeta): and calls :meth:`aclose` on exit. """ + __slots__ = () + async def __aenter__(self: T) -> T: return self diff --git a/contrib/python/anyio/anyio/abc/_tasks.py b/contrib/python/anyio/anyio/abc/_tasks.py index 7ad4938cb4..88aecf3833 100644 --- a/contrib/python/anyio/anyio/abc/_tasks.py +++ b/contrib/python/anyio/anyio/abc/_tasks.py @@ -21,12 +21,10 @@ PosArgsT = TypeVarTuple("PosArgsT") class TaskStatus(Protocol[T_contra]): @overload - def started(self: TaskStatus[None]) -> None: - ... + def started(self: TaskStatus[None]) -> None: ... @overload - def started(self, value: T_contra) -> None: - ... + def started(self, value: T_contra) -> None: ... def started(self, value: T_contra | None = None) -> None: """ diff --git a/contrib/python/anyio/anyio/abc/_testing.py b/contrib/python/anyio/anyio/abc/_testing.py index 4d70b9ec6b..7c50ed76dc 100644 --- a/contrib/python/anyio/anyio/abc/_testing.py +++ b/contrib/python/anyio/anyio/abc/_testing.py @@ -23,8 +23,7 @@ class TestRunner(metaclass=ABCMeta): exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: types.TracebackType | None, - ) -> bool | None: - ... + ) -> bool | None: ... @abstractmethod def run_asyncgen_fixture( diff --git a/contrib/python/anyio/anyio/from_thread.py b/contrib/python/anyio/anyio/from_thread.py index 4a987031fe..88a854bb91 100644 --- a/contrib/python/anyio/anyio/from_thread.py +++ b/contrib/python/anyio/anyio/from_thread.py @@ -5,6 +5,7 @@ import threading from collections.abc import Awaitable, Callable, Generator from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from contextlib import AbstractContextManager, contextmanager +from dataclasses import dataclass, field from inspect import isawaitable from types import TracebackType from typing import ( @@ -261,14 +262,12 @@ class BlockingPortal: self, func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT], - ) -> T_Retval: - ... + ) -> T_Retval: ... @overload def call( self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] - ) -> T_Retval: - ... + ) -> T_Retval: ... def call( self, @@ -293,8 +292,7 @@ class BlockingPortal: func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT], name: object = None, - ) -> Future[T_Retval]: - ... + ) -> Future[T_Retval]: ... @overload def start_task_soon( @@ -302,8 +300,7 @@ class BlockingPortal: func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT], name: object = None, - ) -> Future[T_Retval]: - ... + ) -> Future[T_Retval]: ... def start_task_soon( self, @@ -395,6 +392,63 @@ class BlockingPortal: return _BlockingAsyncContextManager(cm, self) +@dataclass +class BlockingPortalProvider: + """ + A manager for a blocking portal. Used as a context manager. The first thread to + enter this context manager causes a blocking portal to be started with the specific + parameters, and the last thread to exit causes the portal to be shut down. Thus, + there will be exactly one blocking portal running in this context as long as at + least one thread has entered this context manager. + + The parameters are the same as for :func:`~anyio.run`. + + :param backend: name of the backend + :param backend_options: backend options + + .. versionadded:: 4.4 + """ + + backend: str = "asyncio" + backend_options: dict[str, Any] | None = None + _lock: threading.Lock = field(init=False, default_factory=threading.Lock) + _leases: int = field(init=False, default=0) + _portal: BlockingPortal = field(init=False) + _portal_cm: AbstractContextManager[BlockingPortal] | None = field( + init=False, default=None + ) + + def __enter__(self) -> BlockingPortal: + with self._lock: + if self._portal_cm is None: + self._portal_cm = start_blocking_portal( + self.backend, self.backend_options + ) + self._portal = self._portal_cm.__enter__() + + self._leases += 1 + return self._portal + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + portal_cm: AbstractContextManager[BlockingPortal] | None = None + with self._lock: + assert self._portal_cm + assert self._leases > 0 + self._leases -= 1 + if not self._leases: + portal_cm = self._portal_cm + self._portal_cm = None + del self._portal + + if portal_cm: + portal_cm.__exit__(None, None, None) + + @contextmanager def start_blocking_portal( backend: str = "asyncio", backend_options: dict[str, Any] | None = None diff --git a/contrib/python/anyio/anyio/lowlevel.py b/contrib/python/anyio/anyio/lowlevel.py index a9e10f430a..14c7668cb3 100644 --- a/contrib/python/anyio/anyio/lowlevel.py +++ b/contrib/python/anyio/anyio/lowlevel.py @@ -114,12 +114,10 @@ class RunVar(Generic[T]): return run_vars @overload - def get(self, default: D) -> T | D: - ... + def get(self, default: D) -> T | D: ... @overload - def get(self) -> T: - ... + def get(self) -> T: ... def get( self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET diff --git a/contrib/python/anyio/anyio/streams/memory.py b/contrib/python/anyio/anyio/streams/memory.py index bc2425b76f..6840e6242f 100644 --- a/contrib/python/anyio/anyio/streams/memory.py +++ b/contrib/python/anyio/anyio/streams/memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from collections import OrderedDict, deque from dataclasses import dataclass, field from types import TracebackType @@ -11,6 +12,7 @@ from .. import ( EndOfStream, WouldBlock, ) +from .._core._testing import TaskInfo, get_current_task from ..abc import Event, ObjectReceiveStream, ObjectSendStream from ..lowlevel import checkpoint @@ -32,12 +34,18 @@ class MemoryObjectStreamStatistics(NamedTuple): @dataclass(eq=False) +class MemoryObjectItemReceiver(Generic[T_Item]): + task_info: TaskInfo = field(init=False, default_factory=get_current_task) + item: T_Item = field(init=False) + + +@dataclass(eq=False) class MemoryObjectStreamState(Generic[T_Item]): max_buffer_size: float = field() buffer: deque[T_Item] = field(init=False, default_factory=deque) open_send_channels: int = field(init=False, default=0) open_receive_channels: int = field(init=False, default=0) - waiting_receivers: OrderedDict[Event, list[T_Item]] = field( + waiting_receivers: OrderedDict[Event, MemoryObjectItemReceiver[T_Item]] = field( init=False, default_factory=OrderedDict ) waiting_senders: OrderedDict[Event, T_Item] = field( @@ -98,17 +106,17 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]): except WouldBlock: # Add ourselves in the queue receive_event = Event() - container: list[T_co] = [] - self._state.waiting_receivers[receive_event] = container + receiver = MemoryObjectItemReceiver[T_co]() + self._state.waiting_receivers[receive_event] = receiver try: await receive_event.wait() finally: self._state.waiting_receivers.pop(receive_event, None) - if container: - return container[0] - else: + try: + return receiver.item + except AttributeError: raise EndOfStream def clone(self) -> MemoryObjectReceiveStream[T_co]: @@ -164,6 +172,14 @@ class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]): ) -> None: self.close() + def __del__(self) -> None: + if not self._closed: + warnings.warn( + f"Unclosed <{self.__class__.__name__}>", + ResourceWarning, + source=self, + ) + @dataclass(eq=False) class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): @@ -190,11 +206,14 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): if not self._state.open_receive_channels: raise BrokenResourceError - if self._state.waiting_receivers: - receive_event, container = self._state.waiting_receivers.popitem(last=False) - container.append(item) - receive_event.set() - elif len(self._state.buffer) < self._state.max_buffer_size: + while self._state.waiting_receivers: + receive_event, receiver = self._state.waiting_receivers.popitem(last=False) + if not receiver.task_info.has_pending_cancellation(): + receiver.item = item + receive_event.set() + return + + if len(self._state.buffer) < self._state.max_buffer_size: self._state.buffer.append(item) else: raise WouldBlock @@ -225,7 +244,8 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): self._state.waiting_senders.pop(send_event, None) raise - if self._state.waiting_senders.pop(send_event, None): + if send_event in self._state.waiting_senders: + del self._state.waiting_senders[send_event] raise BrokenResourceError from None def clone(self) -> MemoryObjectSendStream[T_contra]: @@ -281,3 +301,11 @@ class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): exc_tb: TracebackType | None, ) -> None: self.close() + + def __del__(self) -> None: + if not self._closed: + warnings.warn( + f"Unclosed <{self.__class__.__name__}>", + ResourceWarning, + source=self, + ) diff --git a/contrib/python/anyio/ya.make b/contrib/python/anyio/ya.make index 7d21184567..9062121337 100644 --- a/contrib/python/anyio/ya.make +++ b/contrib/python/anyio/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(4.3.0) +VERSION(4.4.0) LICENSE(MIT) |