diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2023-11-12 21:25:31 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2023-11-12 21:39:54 +0300 |
commit | d28c55ab25cc8cedab8a5f4736c0d66e88b3da95 (patch) | |
tree | 73d373709b74fa2baaa4fe02a40a77c0a5baf6b7 /contrib/python | |
parent | 35b17f4f3b6e0ed855e7e47d3f1eb57470388a2c (diff) | |
download | ydb-d28c55ab25cc8cedab8a5f4736c0d66e88b3da95.tar.gz |
Intermediate changes
Diffstat (limited to 'contrib/python')
140 files changed, 29884 insertions, 0 deletions
diff --git a/contrib/python/anyio/.dist-info/METADATA b/contrib/python/anyio/.dist-info/METADATA new file mode 100644 index 0000000000..5e46476e02 --- /dev/null +++ b/contrib/python/anyio/.dist-info/METADATA @@ -0,0 +1,105 @@ +Metadata-Version: 2.1 +Name: anyio +Version: 3.7.1 +Summary: High level compatibility layer for multiple asynchronous event loop implementations +Author-email: Alex Grönholm <alex.gronholm@nextday.fi> +License: MIT +Project-URL: Documentation, https://anyio.readthedocs.io/en/latest/ +Project-URL: Changelog, https://anyio.readthedocs.io/en/stable/versionhistory.html +Project-URL: Source code, https://github.com/agronholm/anyio +Project-URL: Issue tracker, https://github.com/agronholm/anyio/issues +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Framework :: AnyIO +Classifier: Typing :: Typed +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: idna (>=2.8) +Requires-Dist: sniffio (>=1.1) +Requires-Dist: exceptiongroup ; python_version < "3.11" +Requires-Dist: typing-extensions ; python_version < "3.8" +Provides-Extra: doc +Requires-Dist: packaging ; extra == 'doc' +Requires-Dist: Sphinx ; extra == 'doc' +Requires-Dist: sphinx-rtd-theme (>=1.2.2) ; extra == 'doc' +Requires-Dist: sphinxcontrib-jquery ; extra == 'doc' +Requires-Dist: sphinx-autodoc-typehints (>=1.2.0) ; extra == 'doc' +Provides-Extra: test +Requires-Dist: anyio[trio] ; extra == 'test' +Requires-Dist: coverage[toml] (>=4.5) ; extra == 'test' +Requires-Dist: hypothesis (>=4.0) ; extra == 'test' +Requires-Dist: psutil (>=5.9) ; extra == 'test' +Requires-Dist: pytest (>=7.0) ; extra == 'test' +Requires-Dist: pytest-mock (>=3.6.1) ; extra == 'test' +Requires-Dist: trustme ; extra == 'test' +Requires-Dist: uvloop (>=0.17) ; (python_version < "3.12" and platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test' +Requires-Dist: mock (>=4) ; (python_version < "3.8") and extra == 'test' +Provides-Extra: trio +Requires-Dist: trio (<0.22) ; extra == 'trio' + +.. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg + :target: https://github.com/agronholm/anyio/actions/workflows/test.yml + :alt: Build Status +.. image:: https://coveralls.io/repos/github/agronholm/anyio/badge.svg?branch=master + :target: https://coveralls.io/github/agronholm/anyio?branch=master + :alt: Code Coverage +.. image:: https://readthedocs.org/projects/anyio/badge/?version=latest + :target: https://anyio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation +.. image:: https://badges.gitter.im/gitterHQ/gitter.svg + :target: https://gitter.im/python-trio/AnyIO + :alt: Gitter chat + +AnyIO is an asynchronous networking and concurrency library that works on top of either asyncio_ or +trio_. It implements trio-like `structured concurrency`_ (SC) on top of asyncio and works in harmony +with the native SC of trio itself. + +Applications and libraries written against AnyIO's API will run unmodified on either asyncio_ or +trio_. AnyIO can also be adopted into a library or application incrementally – bit by bit, no full +refactoring necessary. It will blend in with the native libraries of your chosen backend. + +Documentation +------------- + +View full documentation at: https://anyio.readthedocs.io/ + +Features +-------- + +AnyIO offers the following functionality: + +* Task groups (nurseries_ in trio terminology) +* High-level networking (TCP, UDP and UNIX sockets) + + * `Happy eyeballs`_ algorithm for TCP connections (more robust than that of asyncio on Python + 3.8) + * async/await style UDP sockets (unlike asyncio where you still have to use Transports and + Protocols) + +* A versatile API for byte streams and object streams +* Inter-task synchronization and communication (locks, conditions, events, semaphores, object + streams) +* Worker threads +* Subprocesses +* Asynchronous file I/O (using worker threads) +* Signal handling + +AnyIO also comes with its own pytest_ plugin which also supports asynchronous fixtures. +It even works with the popular Hypothesis_ library. + +.. _asyncio: https://docs.python.org/3/library/asyncio.html +.. _trio: https://github.com/python-trio/trio +.. _structured concurrency: https://en.wikipedia.org/wiki/Structured_concurrency +.. _nurseries: https://trio.readthedocs.io/en/stable/reference-core.html#nurseries-and-spawning +.. _Happy eyeballs: https://en.wikipedia.org/wiki/Happy_Eyeballs +.. _pytest: https://docs.pytest.org/en/latest/ +.. _Hypothesis: https://hypothesis.works/ diff --git a/contrib/python/anyio/.dist-info/entry_points.txt b/contrib/python/anyio/.dist-info/entry_points.txt new file mode 100644 index 0000000000..44dd9bdc30 --- /dev/null +++ b/contrib/python/anyio/.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[pytest11] +anyio = anyio.pytest_plugin diff --git a/contrib/python/anyio/.dist-info/top_level.txt b/contrib/python/anyio/.dist-info/top_level.txt new file mode 100644 index 0000000000..c77c069ecc --- /dev/null +++ b/contrib/python/anyio/.dist-info/top_level.txt @@ -0,0 +1 @@ +anyio diff --git a/contrib/python/anyio/LICENSE b/contrib/python/anyio/LICENSE new file mode 100644 index 0000000000..104eebf5a3 --- /dev/null +++ b/contrib/python/anyio/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2018 Alex Grönholm + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/contrib/python/anyio/README.rst b/contrib/python/anyio/README.rst new file mode 100644 index 0000000000..35afc7e312 --- /dev/null +++ b/contrib/python/anyio/README.rst @@ -0,0 +1,57 @@ +.. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg + :target: https://github.com/agronholm/anyio/actions/workflows/test.yml + :alt: Build Status +.. image:: https://coveralls.io/repos/github/agronholm/anyio/badge.svg?branch=master + :target: https://coveralls.io/github/agronholm/anyio?branch=master + :alt: Code Coverage +.. image:: https://readthedocs.org/projects/anyio/badge/?version=latest + :target: https://anyio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation +.. image:: https://badges.gitter.im/gitterHQ/gitter.svg + :target: https://gitter.im/python-trio/AnyIO + :alt: Gitter chat + +AnyIO is an asynchronous networking and concurrency library that works on top of either asyncio_ or +trio_. It implements trio-like `structured concurrency`_ (SC) on top of asyncio and works in harmony +with the native SC of trio itself. + +Applications and libraries written against AnyIO's API will run unmodified on either asyncio_ or +trio_. AnyIO can also be adopted into a library or application incrementally – bit by bit, no full +refactoring necessary. It will blend in with the native libraries of your chosen backend. + +Documentation +------------- + +View full documentation at: https://anyio.readthedocs.io/ + +Features +-------- + +AnyIO offers the following functionality: + +* Task groups (nurseries_ in trio terminology) +* High-level networking (TCP, UDP and UNIX sockets) + + * `Happy eyeballs`_ algorithm for TCP connections (more robust than that of asyncio on Python + 3.8) + * async/await style UDP sockets (unlike asyncio where you still have to use Transports and + Protocols) + +* A versatile API for byte streams and object streams +* Inter-task synchronization and communication (locks, conditions, events, semaphores, object + streams) +* Worker threads +* Subprocesses +* Asynchronous file I/O (using worker threads) +* Signal handling + +AnyIO also comes with its own pytest_ plugin which also supports asynchronous fixtures. +It even works with the popular Hypothesis_ library. + +.. _asyncio: https://docs.python.org/3/library/asyncio.html +.. _trio: https://github.com/python-trio/trio +.. _structured concurrency: https://en.wikipedia.org/wiki/Structured_concurrency +.. _nurseries: https://trio.readthedocs.io/en/stable/reference-core.html#nurseries-and-spawning +.. _Happy eyeballs: https://en.wikipedia.org/wiki/Happy_Eyeballs +.. _pytest: https://docs.pytest.org/en/latest/ +.. _Hypothesis: https://hypothesis.works/ diff --git a/contrib/python/anyio/anyio/__init__.py b/contrib/python/anyio/anyio/__init__.py new file mode 100644 index 0000000000..29fb3561e4 --- /dev/null +++ b/contrib/python/anyio/anyio/__init__.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +__all__ = ( + "maybe_async", + "maybe_async_cm", + "run", + "sleep", + "sleep_forever", + "sleep_until", + "current_time", + "get_all_backends", + "get_cancelled_exc_class", + "BrokenResourceError", + "BrokenWorkerProcess", + "BusyResourceError", + "ClosedResourceError", + "DelimiterNotFound", + "EndOfStream", + "ExceptionGroup", + "IncompleteRead", + "TypedAttributeLookupError", + "WouldBlock", + "AsyncFile", + "Path", + "open_file", + "wrap_file", + "aclose_forcefully", + "open_signal_receiver", + "connect_tcp", + "connect_unix", + "create_tcp_listener", + "create_unix_listener", + "create_udp_socket", + "create_connected_udp_socket", + "getaddrinfo", + "getnameinfo", + "wait_socket_readable", + "wait_socket_writable", + "create_memory_object_stream", + "run_process", + "open_process", + "create_lock", + "CapacityLimiter", + "CapacityLimiterStatistics", + "Condition", + "ConditionStatistics", + "Event", + "EventStatistics", + "Lock", + "LockStatistics", + "Semaphore", + "SemaphoreStatistics", + "create_condition", + "create_event", + "create_semaphore", + "create_capacity_limiter", + "open_cancel_scope", + "fail_after", + "move_on_after", + "current_effective_deadline", + "TASK_STATUS_IGNORED", + "CancelScope", + "create_task_group", + "TaskInfo", + "get_current_task", + "get_running_tasks", + "wait_all_tasks_blocked", + "run_sync_in_worker_thread", + "run_async_from_thread", + "run_sync_from_thread", + "current_default_worker_thread_limiter", + "create_blocking_portal", + "start_blocking_portal", + "typed_attribute", + "TypedAttributeSet", + "TypedAttributeProvider", +) + +from typing import Any + +from ._core._compat import maybe_async, maybe_async_cm +from ._core._eventloop import ( + current_time, + get_all_backends, + get_cancelled_exc_class, + run, + sleep, + sleep_forever, + sleep_until, +) +from ._core._exceptions import ( + BrokenResourceError, + BrokenWorkerProcess, + BusyResourceError, + ClosedResourceError, + DelimiterNotFound, + EndOfStream, + ExceptionGroup, + IncompleteRead, + TypedAttributeLookupError, + WouldBlock, +) +from ._core._fileio import AsyncFile, Path, open_file, wrap_file +from ._core._resources import aclose_forcefully +from ._core._signals import open_signal_receiver +from ._core._sockets import ( + connect_tcp, + connect_unix, + create_connected_udp_socket, + create_tcp_listener, + create_udp_socket, + create_unix_listener, + getaddrinfo, + getnameinfo, + wait_socket_readable, + wait_socket_writable, +) +from ._core._streams import create_memory_object_stream +from ._core._subprocesses import open_process, run_process +from ._core._synchronization import ( + CapacityLimiter, + CapacityLimiterStatistics, + Condition, + ConditionStatistics, + Event, + EventStatistics, + Lock, + LockStatistics, + Semaphore, + SemaphoreStatistics, + create_capacity_limiter, + create_condition, + create_event, + create_lock, + create_semaphore, +) +from ._core._tasks import ( + TASK_STATUS_IGNORED, + CancelScope, + create_task_group, + current_effective_deadline, + fail_after, + move_on_after, + open_cancel_scope, +) +from ._core._testing import ( + TaskInfo, + get_current_task, + get_running_tasks, + wait_all_tasks_blocked, +) +from ._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute + +# Re-exported here, for backwards compatibility +# isort: off +from .to_thread import current_default_worker_thread_limiter, run_sync_in_worker_thread +from .from_thread import ( + create_blocking_portal, + run_async_from_thread, + run_sync_from_thread, + start_blocking_portal, +) + +# Re-export imports so they look like they live directly in this package +key: str +value: Any +for key, value in list(locals().items()): + if getattr(value, "__module__", "").startswith("anyio."): + value.__module__ = __name__ diff --git a/contrib/python/anyio/anyio/_backends/__init__.py b/contrib/python/anyio/anyio/_backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/anyio/anyio/_backends/__init__.py diff --git a/contrib/python/anyio/anyio/_backends/_asyncio.py b/contrib/python/anyio/anyio/_backends/_asyncio.py new file mode 100644 index 0000000000..bfdb4ea7e1 --- /dev/null +++ b/contrib/python/anyio/anyio/_backends/_asyncio.py @@ -0,0 +1,2117 @@ +from __future__ import annotations + +import array +import asyncio +import concurrent.futures +import math +import socket +import sys +from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined] +from collections import OrderedDict, deque +from concurrent.futures import Future +from contextvars import Context, copy_context +from dataclasses import dataclass +from functools import partial, wraps +from inspect import ( + CORO_RUNNING, + CORO_SUSPENDED, + GEN_RUNNING, + GEN_SUSPENDED, + getcoroutinestate, + getgeneratorstate, +) +from io import IOBase +from os import PathLike +from queue import Queue +from socket import AddressFamily, SocketKind +from threading import Thread +from types import TracebackType +from typing import ( + IO, + Any, + AsyncGenerator, + Awaitable, + Callable, + Collection, + Coroutine, + Generator, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) +from weakref import WeakKeyDictionary + +import sniffio + +from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc +from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable +from .._core._eventloop import claim_worker_thread, threadlocals +from .._core._exceptions import ( + BrokenResourceError, + BusyResourceError, + ClosedResourceError, + EndOfStream, + WouldBlock, +) +from .._core._exceptions import ExceptionGroup as BaseExceptionGroup +from .._core._sockets import GetAddrInfoReturnType, convert_ipv6_sockaddr +from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter +from .._core._synchronization import Event as BaseEvent +from .._core._synchronization import ResourceGuard +from .._core._tasks import CancelScope as BaseCancelScope +from ..abc import IPSockAddrType, UDPPacketType +from ..lowlevel import RunVar + +if sys.version_info >= (3, 8): + + def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: + return task.get_coro() + +else: + + def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: + return task._coro + + +from asyncio import all_tasks, create_task, current_task, get_running_loop +from asyncio import run as native_run + + +def _get_task_callbacks(task: asyncio.Task) -> Iterable[Callable]: + return [cb for cb, context in task._callbacks] + + +T_Retval = TypeVar("T_Retval") +T_contra = TypeVar("T_contra", contravariant=True) + +# Check whether there is native support for task names in asyncio (3.8+) +_native_task_names = hasattr(asyncio.Task, "get_name") + + +_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task") + + +def find_root_task() -> asyncio.Task: + root_task = _root_task.get(None) + if root_task is not None and not root_task.done(): + return root_task + + # Look for a task that has been started via run_until_complete() + for task in all_tasks(): + if task._callbacks and not task.done(): + for cb in _get_task_callbacks(task): + if ( + cb is _run_until_complete_cb + or getattr(cb, "__module__", None) == "uvloop.loop" + ): + _root_task.set(task) + return task + + # Look up the topmost task in the AnyIO task tree, if possible + task = cast(asyncio.Task, current_task()) + state = _task_states.get(task) + if state: + cancel_scope = state.cancel_scope + while cancel_scope and cancel_scope._parent_scope is not None: + cancel_scope = cancel_scope._parent_scope + + if cancel_scope is not None: + return cast(asyncio.Task, cancel_scope._host_task) + + return task + + +def get_callable_name(func: Callable) -> str: + module = getattr(func, "__module__", None) + qualname = getattr(func, "__qualname__", None) + return ".".join([x for x in (module, qualname) if x]) + + +# +# Event loop +# + +_run_vars = ( + WeakKeyDictionary() +) # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] + +current_token = get_running_loop + + +def _task_started(task: asyncio.Task) -> bool: + """Return ``True`` if the task has been started and has not finished.""" + coro = cast(Coroutine[Any, Any, Any], get_coro(task)) + try: + return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED) + except AttributeError: + try: + return getgeneratorstate(cast(Generator, coro)) in ( + GEN_RUNNING, + GEN_SUSPENDED, + ) + except AttributeError: + # task coro is async_genenerator_asend https://bugs.python.org/issue37771 + raise Exception(f"Cannot determine if task {task} has started or not") + + +def _maybe_set_event_loop_policy( + policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool +) -> None: + # On CPython, use uvloop when possible if no other policy has been given and if not + # explicitly disabled + if policy is None and use_uvloop and sys.implementation.name == "cpython": + try: + import uvloop + except ImportError: + pass + else: + # Test for missing shutdown_default_executor() (uvloop 0.14.0 and earlier) + if not hasattr( + asyncio.AbstractEventLoop, "shutdown_default_executor" + ) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"): + policy = uvloop.EventLoopPolicy() + + if policy is not None: + asyncio.set_event_loop_policy(policy) + + +def run( + func: Callable[..., Awaitable[T_Retval]], + *args: object, + debug: bool = False, + use_uvloop: bool = False, + policy: asyncio.AbstractEventLoopPolicy | None = None, +) -> T_Retval: + @wraps(func) + async def wrapper() -> T_Retval: + task = cast(asyncio.Task, current_task()) + task_state = TaskState(None, get_callable_name(func), None) + _task_states[task] = task_state + if _native_task_names: + task.set_name(task_state.name) + + try: + return await func(*args) + finally: + del _task_states[task] + + _maybe_set_event_loop_policy(policy, use_uvloop) + return native_run(wrapper(), debug=debug) + + +# +# Miscellaneous +# + +sleep = asyncio.sleep + + +# +# Timeouts and cancellation +# + +CancelledError = asyncio.CancelledError + + +class CancelScope(BaseCancelScope): + def __new__( + cls, *, deadline: float = math.inf, shield: bool = False + ) -> CancelScope: + return object.__new__(cls) + + def __init__(self, deadline: float = math.inf, shield: bool = False): + self._deadline = deadline + self._shield = shield + self._parent_scope: CancelScope | None = None + self._cancel_called = False + self._active = False + self._timeout_handle: asyncio.TimerHandle | None = None + self._cancel_handle: asyncio.Handle | None = None + self._tasks: set[asyncio.Task] = set() + self._host_task: asyncio.Task | None = None + self._timeout_expired = False + self._cancel_calls: int = 0 + + def __enter__(self) -> CancelScope: + if self._active: + raise RuntimeError( + "Each CancelScope may only be used for a single 'with' block" + ) + + self._host_task = host_task = cast(asyncio.Task, current_task()) + self._tasks.add(host_task) + try: + task_state = _task_states[host_task] + except KeyError: + task_name = host_task.get_name() if _native_task_names else None + task_state = TaskState(None, task_name, self) + _task_states[host_task] = task_state + else: + self._parent_scope = task_state.cancel_scope + task_state.cancel_scope = self + + self._timeout() + self._active = True + + # Start cancelling the host task if the scope was cancelled before entering + if self._cancel_called: + self._deliver_cancellation() + + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + if not self._active: + raise RuntimeError("This cancel scope is not active") + if current_task() is not self._host_task: + raise RuntimeError( + "Attempted to exit cancel scope in a different task than it was " + "entered in" + ) + + assert self._host_task is not None + host_task_state = _task_states.get(self._host_task) + if host_task_state is None or host_task_state.cancel_scope is not self: + raise RuntimeError( + "Attempted to exit a cancel scope that isn't the current tasks's " + "current cancel scope" + ) + + self._active = False + if self._timeout_handle: + self._timeout_handle.cancel() + self._timeout_handle = None + + self._tasks.remove(self._host_task) + + host_task_state.cancel_scope = self._parent_scope + + # Restart the cancellation effort in the farthest directly cancelled parent scope if this + # one was shielded + if self._shield: + self._deliver_cancellation_to_parent() + + if exc_val is not None: + exceptions = ( + exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val] + ) + if all(isinstance(exc, CancelledError) for exc in exceptions): + if self._timeout_expired: + return self._uncancel() + elif not self._cancel_called: + # Task was cancelled natively + return None + elif not self._parent_cancelled(): + # This scope was directly cancelled + return self._uncancel() + + return None + + def _uncancel(self) -> bool: + if sys.version_info < (3, 11) or self._host_task is None: + self._cancel_calls = 0 + return True + + # Uncancel all AnyIO cancellations + for i in range(self._cancel_calls): + self._host_task.uncancel() + + self._cancel_calls = 0 + return not self._host_task.cancelling() + + def _timeout(self) -> None: + if self._deadline != math.inf: + loop = get_running_loop() + if loop.time() >= self._deadline: + self._timeout_expired = True + self.cancel() + else: + self._timeout_handle = loop.call_at(self._deadline, self._timeout) + + def _deliver_cancellation(self) -> None: + """ + Deliver cancellation to directly contained tasks and nested cancel scopes. + + Schedule another run at the end if we still have tasks eligible for cancellation. + """ + should_retry = False + current = current_task() + for task in self._tasks: + if task._must_cancel: # type: ignore[attr-defined] + continue + + # The task is eligible for cancellation if it has started and is not in a cancel + # scope shielded from this one + cancel_scope = _task_states[task].cancel_scope + while cancel_scope is not self: + if cancel_scope is None or cancel_scope._shield: + break + else: + cancel_scope = cancel_scope._parent_scope + else: + should_retry = True + if task is not current and ( + task is self._host_task or _task_started(task) + ): + self._cancel_calls += 1 + task.cancel() + + # Schedule another callback if there are still tasks left + if should_retry: + self._cancel_handle = get_running_loop().call_soon( + self._deliver_cancellation + ) + else: + self._cancel_handle = None + + def _deliver_cancellation_to_parent(self) -> None: + """Start cancellation effort in the farthest directly cancelled parent scope""" + scope = self._parent_scope + scope_to_cancel: CancelScope | None = None + while scope is not None: + if scope._cancel_called and scope._cancel_handle is None: + scope_to_cancel = scope + + # No point in looking beyond any shielded scope + if scope._shield: + break + + scope = scope._parent_scope + + if scope_to_cancel is not None: + scope_to_cancel._deliver_cancellation() + + def _parent_cancelled(self) -> bool: + # Check whether any parent has been cancelled + cancel_scope = self._parent_scope + while cancel_scope is not None and not cancel_scope._shield: + if cancel_scope._cancel_called: + return True + else: + cancel_scope = cancel_scope._parent_scope + + return False + + def cancel(self) -> DeprecatedAwaitable: + if not self._cancel_called: + if self._timeout_handle: + self._timeout_handle.cancel() + self._timeout_handle = None + + self._cancel_called = True + if self._host_task is not None: + self._deliver_cancellation() + + return DeprecatedAwaitable(self.cancel) + + @property + def deadline(self) -> float: + return self._deadline + + @deadline.setter + def deadline(self, value: float) -> None: + self._deadline = float(value) + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + + if self._active and not self._cancel_called: + self._timeout() + + @property + def cancel_called(self) -> bool: + return self._cancel_called + + @property + def shield(self) -> bool: + return self._shield + + @shield.setter + def shield(self, value: bool) -> None: + if self._shield != value: + self._shield = value + if not value: + self._deliver_cancellation_to_parent() + + +async def checkpoint() -> None: + await sleep(0) + + +async def checkpoint_if_cancelled() -> None: + task = current_task() + if task is None: + return + + try: + cancel_scope = _task_states[task].cancel_scope + except KeyError: + return + + while cancel_scope: + if cancel_scope.cancel_called: + await sleep(0) + elif cancel_scope.shield: + break + else: + cancel_scope = cancel_scope._parent_scope + + +async def cancel_shielded_checkpoint() -> None: + with CancelScope(shield=True): + await sleep(0) + + +def current_effective_deadline() -> float: + try: + cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index] + except KeyError: + return math.inf + + deadline = math.inf + while cancel_scope: + deadline = min(deadline, cancel_scope.deadline) + if cancel_scope._cancel_called: + deadline = -math.inf + break + elif cancel_scope.shield: + break + else: + cancel_scope = cancel_scope._parent_scope + + return deadline + + +def current_time() -> float: + return get_running_loop().time() + + +# +# Task states +# + + +class TaskState: + """ + Encapsulates auxiliary task information that cannot be added to the Task instance itself + because there are no guarantees about its implementation. + """ + + __slots__ = "parent_id", "name", "cancel_scope" + + def __init__( + self, + parent_id: int | None, + name: str | None, + cancel_scope: CancelScope | None, + ): + self.parent_id = parent_id + self.name = name + self.cancel_scope = cancel_scope + + +_task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState] + + +# +# Task groups +# + + +class ExceptionGroup(BaseExceptionGroup): + def __init__(self, exceptions: list[BaseException]): + super().__init__() + self.exceptions = exceptions + + +class _AsyncioTaskStatus(abc.TaskStatus): + def __init__(self, future: asyncio.Future, parent_id: int): + self._future = future + self._parent_id = parent_id + + def started(self, value: T_contra | None = None) -> None: + try: + self._future.set_result(value) + except asyncio.InvalidStateError: + raise RuntimeError( + "called 'started' twice on the same task status" + ) from None + + task = cast(asyncio.Task, current_task()) + _task_states[task].parent_id = self._parent_id + + +class TaskGroup(abc.TaskGroup): + def __init__(self) -> None: + self.cancel_scope: CancelScope = CancelScope() + self._active = False + self._exceptions: list[BaseException] = [] + + async def __aenter__(self) -> TaskGroup: + self.cancel_scope.__enter__() + self._active = True + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb) + if exc_val is not None: + self.cancel_scope.cancel() + self._exceptions.append(exc_val) + + while self.cancel_scope._tasks: + try: + await asyncio.wait(self.cancel_scope._tasks) + except asyncio.CancelledError: + self.cancel_scope.cancel() + + self._active = False + if not self.cancel_scope._parent_cancelled(): + exceptions = self._filter_cancellation_errors(self._exceptions) + else: + exceptions = self._exceptions + + try: + if len(exceptions) > 1: + if all( + isinstance(e, CancelledError) and not e.args for e in exceptions + ): + # Tasks were cancelled natively, without a cancellation message + raise CancelledError + else: + raise ExceptionGroup(exceptions) + elif exceptions and exceptions[0] is not exc_val: + raise exceptions[0] + except BaseException as exc: + # Clear the context here, as it can only be done in-flight. + # If the context is not cleared, it can result in recursive tracebacks (see #145). + exc.__context__ = None + raise + + return ignore_exception + + @staticmethod + def _filter_cancellation_errors( + exceptions: Sequence[BaseException], + ) -> list[BaseException]: + filtered_exceptions: list[BaseException] = [] + for exc in exceptions: + if isinstance(exc, ExceptionGroup): + new_exceptions = TaskGroup._filter_cancellation_errors(exc.exceptions) + if len(new_exceptions) > 1: + filtered_exceptions.append(exc) + elif len(new_exceptions) == 1: + filtered_exceptions.append(new_exceptions[0]) + elif new_exceptions: + new_exc = ExceptionGroup(new_exceptions) + new_exc.__cause__ = exc.__cause__ + new_exc.__context__ = exc.__context__ + new_exc.__traceback__ = exc.__traceback__ + filtered_exceptions.append(new_exc) + elif not isinstance(exc, CancelledError) or exc.args: + filtered_exceptions.append(exc) + + return filtered_exceptions + + async def _run_wrapped_task( + self, coro: Coroutine, task_status_future: asyncio.Future | None + ) -> None: + # This is the code path for Python 3.7 on which asyncio freaks out if a task + # raises a BaseException. + __traceback_hide__ = __tracebackhide__ = True # noqa: F841 + task = cast(asyncio.Task, current_task()) + try: + await coro + except BaseException as exc: + if task_status_future is None or task_status_future.done(): + self._exceptions.append(exc) + self.cancel_scope.cancel() + else: + task_status_future.set_exception(exc) + else: + if task_status_future is not None and not task_status_future.done(): + task_status_future.set_exception( + RuntimeError("Child exited without calling task_status.started()") + ) + finally: + if task in self.cancel_scope._tasks: + self.cancel_scope._tasks.remove(task) + del _task_states[task] + + def _spawn( + self, + func: Callable[..., Awaitable[Any]], + args: tuple, + name: object, + task_status_future: asyncio.Future | None = None, + ) -> asyncio.Task: + def task_done(_task: asyncio.Task) -> None: + # This is the code path for Python 3.8+ + assert _task in self.cancel_scope._tasks + self.cancel_scope._tasks.remove(_task) + del _task_states[_task] + + try: + exc = _task.exception() + except CancelledError as e: + while isinstance(e.__context__, CancelledError): + e = e.__context__ + + exc = e + + if exc is not None: + if task_status_future is None or task_status_future.done(): + self._exceptions.append(exc) + self.cancel_scope.cancel() + else: + task_status_future.set_exception(exc) + elif task_status_future is not None and not task_status_future.done(): + task_status_future.set_exception( + RuntimeError("Child exited without calling task_status.started()") + ) + + if not self._active: + raise RuntimeError( + "This task group is not active; no new tasks can be started." + ) + + options: dict[str, Any] = {} + name = get_callable_name(func) if name is None else str(name) + if _native_task_names: + options["name"] = name + + kwargs = {} + if task_status_future: + parent_id = id(current_task()) + kwargs["task_status"] = _AsyncioTaskStatus( + task_status_future, id(self.cancel_scope._host_task) + ) + else: + parent_id = id(self.cancel_scope._host_task) + + coro = func(*args, **kwargs) + if not asyncio.iscoroutine(coro): + raise TypeError( + f"Expected an async function, but {func} appears to be synchronous" + ) + + foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame") + if foreign_coro or sys.version_info < (3, 8): + coro = self._run_wrapped_task(coro, task_status_future) + + task = create_task(coro, **options) + if not foreign_coro and sys.version_info >= (3, 8): + task.add_done_callback(task_done) + + # Make the spawned task inherit the task group's cancel scope + _task_states[task] = TaskState( + parent_id=parent_id, name=name, cancel_scope=self.cancel_scope + ) + self.cancel_scope._tasks.add(task) + return task + + def start_soon( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> None: + self._spawn(func, args, name) + + async def start( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> None: + future: asyncio.Future = asyncio.Future() + task = self._spawn(func, args, name, future) + + # If the task raises an exception after sending a start value without a switch point + # between, the task group is cancelled and this method never proceeds to process the + # completed future. That's why we have to have a shielded cancel scope here. + with CancelScope(shield=True): + try: + return await future + except CancelledError: + task.cancel() + raise + + +# +# Threads +# + +_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]] + + +class WorkerThread(Thread): + MAX_IDLE_TIME = 10 # seconds + + def __init__( + self, + root_task: asyncio.Task, + workers: set[WorkerThread], + idle_workers: deque[WorkerThread], + ): + super().__init__(name="AnyIO worker thread") + self.root_task = root_task + self.workers = workers + self.idle_workers = idle_workers + self.loop = root_task._loop + self.queue: Queue[ + tuple[Context, Callable, tuple, asyncio.Future] | None + ] = Queue(2) + self.idle_since = current_time() + self.stopping = False + + def _report_result( + self, future: asyncio.Future, result: Any, exc: BaseException | None + ) -> None: + self.idle_since = current_time() + if not self.stopping: + self.idle_workers.append(self) + + if not future.cancelled(): + if exc is not None: + if isinstance(exc, StopIteration): + new_exc = RuntimeError("coroutine raised StopIteration") + new_exc.__cause__ = exc + exc = new_exc + + future.set_exception(exc) + else: + future.set_result(result) + + def run(self) -> None: + with claim_worker_thread("asyncio"): + threadlocals.loop = self.loop + while True: + item = self.queue.get() + if item is None: + # Shutdown command received + return + + context, func, args, future = item + if not future.cancelled(): + result = None + exception: BaseException | None = None + try: + result = context.run(func, *args) + except BaseException as exc: + exception = exc + + if not self.loop.is_closed(): + self.loop.call_soon_threadsafe( + self._report_result, future, result, exception + ) + + self.queue.task_done() + + def stop(self, f: asyncio.Task | None = None) -> None: + self.stopping = True + self.queue.put_nowait(None) + self.workers.discard(self) + try: + self.idle_workers.remove(self) + except ValueError: + pass + + +_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar( + "_threadpool_idle_workers" +) +_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers") + + +async def run_sync_in_worker_thread( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + await checkpoint() + + # If this is the first run in this event loop thread, set up the necessary variables + try: + idle_workers = _threadpool_idle_workers.get() + workers = _threadpool_workers.get() + except LookupError: + idle_workers = deque() + workers = set() + _threadpool_idle_workers.set(idle_workers) + _threadpool_workers.set(workers) + + async with (limiter or current_default_thread_limiter()): + with CancelScope(shield=not cancellable): + future: asyncio.Future = asyncio.Future() + root_task = find_root_task() + if not idle_workers: + worker = WorkerThread(root_task, workers, idle_workers) + worker.start() + workers.add(worker) + root_task.add_done_callback(worker.stop) + else: + worker = idle_workers.pop() + + # Prune any other workers that have been idle for MAX_IDLE_TIME seconds or longer + now = current_time() + while idle_workers: + if now - idle_workers[0].idle_since < WorkerThread.MAX_IDLE_TIME: + break + + expired_worker = idle_workers.popleft() + expired_worker.root_task.remove_done_callback(expired_worker.stop) + expired_worker.stop() + + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, None) + worker.queue.put_nowait((context, func, args, future)) + return await future + + +def run_sync_from_thread( + func: Callable[..., T_Retval], + *args: object, + loop: asyncio.AbstractEventLoop | None = None, +) -> T_Retval: + @wraps(func) + def wrapper() -> None: + try: + f.set_result(func(*args)) + except BaseException as exc: + f.set_exception(exc) + if not isinstance(exc, Exception): + raise + + f: concurrent.futures.Future[T_Retval] = Future() + loop = loop or threadlocals.loop + loop.call_soon_threadsafe(wrapper) + return f.result() + + +def run_async_from_thread( + func: Callable[..., Awaitable[T_Retval]], *args: object +) -> T_Retval: + f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( + func(*args), threadlocals.loop + ) + return f.result() + + +class BlockingPortal(abc.BlockingPortal): + def __new__(cls) -> BlockingPortal: + return object.__new__(cls) + + def __init__(self) -> None: + super().__init__() + self._loop = get_running_loop() + + def _spawn_task_from_thread( + self, + func: Callable, + args: tuple, + kwargs: dict[str, Any], + name: object, + future: Future, + ) -> None: + run_sync_from_thread( + partial(self._task_group.start_soon, name=name), + self._call_func, + func, + args, + kwargs, + future, + loop=self._loop, + ) + + +# +# Subprocesses +# + + +@dataclass(eq=False) +class StreamReaderWrapper(abc.ByteReceiveStream): + _stream: asyncio.StreamReader + + async def receive(self, max_bytes: int = 65536) -> bytes: + data = await self._stream.read(max_bytes) + if data: + return data + else: + raise EndOfStream + + async def aclose(self) -> None: + self._stream.feed_eof() + + +@dataclass(eq=False) +class StreamWriterWrapper(abc.ByteSendStream): + _stream: asyncio.StreamWriter + + async def send(self, item: bytes) -> None: + self._stream.write(item) + await self._stream.drain() + + async def aclose(self) -> None: + self._stream.close() + + +@dataclass(eq=False) +class Process(abc.Process): + _process: asyncio.subprocess.Process + _stdin: StreamWriterWrapper | None + _stdout: StreamReaderWrapper | None + _stderr: StreamReaderWrapper | None + + async def aclose(self) -> None: + if self._stdin: + await self._stdin.aclose() + if self._stdout: + await self._stdout.aclose() + if self._stderr: + await self._stderr.aclose() + + await self.wait() + + async def wait(self) -> int: + return await self._process.wait() + + def terminate(self) -> None: + self._process.terminate() + + def kill(self) -> None: + self._process.kill() + + def send_signal(self, signal: int) -> None: + self._process.send_signal(signal) + + @property + def pid(self) -> int: + return self._process.pid + + @property + def returncode(self) -> int | None: + return self._process.returncode + + @property + def stdin(self) -> abc.ByteSendStream | None: + return self._stdin + + @property + def stdout(self) -> abc.ByteReceiveStream | None: + return self._stdout + + @property + def stderr(self) -> abc.ByteReceiveStream | None: + return self._stderr + + +async def open_process( + command: str | bytes | Sequence[str | bytes], + *, + shell: bool, + stdin: int | IO[Any] | None, + stdout: int | IO[Any] | None, + stderr: int | IO[Any] | None, + cwd: str | bytes | PathLike | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, +) -> Process: + await checkpoint() + if shell: + process = await asyncio.create_subprocess_shell( + cast(Union[str, bytes], command), + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + else: + process = await asyncio.create_subprocess_exec( + *command, + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + + stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None + stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None + stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None + return Process(process, stdin_stream, stdout_stream, stderr_stream) + + +def _forcibly_shutdown_process_pool_on_exit( + workers: set[Process], _task: object +) -> None: + """ + Forcibly shuts down worker processes belonging to this event loop.""" + child_watcher: asyncio.AbstractChildWatcher | None + try: + child_watcher = asyncio.get_event_loop_policy().get_child_watcher() + except NotImplementedError: + child_watcher = None + + # Close as much as possible (w/o async/await) to avoid warnings + for process in workers: + if process.returncode is None: + continue + + process._stdin._stream._transport.close() # type: ignore[union-attr] + process._stdout._stream._transport.close() # type: ignore[union-attr] + process._stderr._stream._transport.close() # type: ignore[union-attr] + process.kill() + if child_watcher: + child_watcher.remove_child_handler(process.pid) + + +async def _shutdown_process_pool_on_exit(workers: set[Process]) -> None: + """ + Shuts down worker processes belonging to this event loop. + + NOTE: this only works when the event loop was started using asyncio.run() or anyio.run(). + + """ + process: Process + try: + await sleep(math.inf) + except asyncio.CancelledError: + for process in workers: + if process.returncode is None: + process.kill() + + for process in workers: + await process.aclose() + + +def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None: + kwargs: dict[str, Any] = ( + {"name": "AnyIO process pool shutdown task"} if _native_task_names else {} + ) + create_task(_shutdown_process_pool_on_exit(workers), **kwargs) + find_root_task().add_done_callback( + partial(_forcibly_shutdown_process_pool_on_exit, workers) + ) + + +# +# Sockets and networking +# + + +class StreamProtocol(asyncio.Protocol): + read_queue: deque[bytes] + read_event: asyncio.Event + write_event: asyncio.Event + exception: Exception | None = None + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self.read_queue = deque() + self.read_event = asyncio.Event() + self.write_event = asyncio.Event() + self.write_event.set() + cast(asyncio.Transport, transport).set_write_buffer_limits(0) + + def connection_lost(self, exc: Exception | None) -> None: + if exc: + self.exception = BrokenResourceError() + self.exception.__cause__ = exc + + self.read_event.set() + self.write_event.set() + + def data_received(self, data: bytes) -> None: + self.read_queue.append(data) + self.read_event.set() + + def eof_received(self) -> bool | None: + self.read_event.set() + return True + + def pause_writing(self) -> None: + self.write_event = asyncio.Event() + + def resume_writing(self) -> None: + self.write_event.set() + + +class DatagramProtocol(asyncio.DatagramProtocol): + read_queue: deque[tuple[bytes, IPSockAddrType]] + read_event: asyncio.Event + write_event: asyncio.Event + exception: Exception | None = None + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self.read_queue = deque(maxlen=100) # arbitrary value + self.read_event = asyncio.Event() + self.write_event = asyncio.Event() + self.write_event.set() + + def connection_lost(self, exc: Exception | None) -> None: + self.read_event.set() + self.write_event.set() + + def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None: + addr = convert_ipv6_sockaddr(addr) + self.read_queue.append((data, addr)) + self.read_event.set() + + def error_received(self, exc: Exception) -> None: + self.exception = exc + + def pause_writing(self) -> None: + self.write_event.clear() + + def resume_writing(self) -> None: + self.write_event.set() + + +class SocketStream(abc.SocketStream): + def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol): + self._transport = transport + self._protocol = protocol + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + self._closed = False + + @property + def _raw_socket(self) -> socket.socket: + return self._transport.get_extra_info("socket") + + async def receive(self, max_bytes: int = 65536) -> bytes: + with self._receive_guard: + await checkpoint() + + if ( + not self._protocol.read_event.is_set() + and not self._transport.is_closing() + ): + self._transport.resume_reading() + await self._protocol.read_event.wait() + self._transport.pause_reading() + + try: + chunk = self._protocol.read_queue.popleft() + except IndexError: + if self._closed: + raise ClosedResourceError from None + elif self._protocol.exception: + raise self._protocol.exception + else: + raise EndOfStream from None + + if len(chunk) > max_bytes: + # Split the oversized chunk + chunk, leftover = chunk[:max_bytes], chunk[max_bytes:] + self._protocol.read_queue.appendleft(leftover) + + # If the read queue is empty, clear the flag so that the next call will block until + # data is available + if not self._protocol.read_queue: + self._protocol.read_event.clear() + + return chunk + + async def send(self, item: bytes) -> None: + with self._send_guard: + await checkpoint() + + if self._closed: + raise ClosedResourceError + elif self._protocol.exception is not None: + raise self._protocol.exception + + try: + self._transport.write(item) + except RuntimeError as exc: + if self._transport.is_closing(): + raise BrokenResourceError from exc + else: + raise + + await self._protocol.write_event.wait() + + async def send_eof(self) -> None: + try: + self._transport.write_eof() + except OSError: + pass + + async def aclose(self) -> None: + if not self._transport.is_closing(): + self._closed = True + try: + self._transport.write_eof() + except OSError: + pass + + self._transport.close() + await sleep(0) + self._transport.abort() + + +class UNIXSocketStream(abc.SocketStream): + _receive_future: asyncio.Future | None = None + _send_future: asyncio.Future | None = None + _closing = False + + def __init__(self, raw_socket: socket.socket): + self.__raw_socket = raw_socket + self._loop = get_running_loop() + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + + @property + def _raw_socket(self) -> socket.socket: + return self.__raw_socket + + def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: + def callback(f: object) -> None: + del self._receive_future + loop.remove_reader(self.__raw_socket) + + f = self._receive_future = asyncio.Future() + self._loop.add_reader(self.__raw_socket, f.set_result, None) + f.add_done_callback(callback) + return f + + def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: + def callback(f: object) -> None: + del self._send_future + loop.remove_writer(self.__raw_socket) + + f = self._send_future = asyncio.Future() + self._loop.add_writer(self.__raw_socket, f.set_result, None) + f.add_done_callback(callback) + return f + + async def send_eof(self) -> None: + with self._send_guard: + self._raw_socket.shutdown(socket.SHUT_WR) + + async def receive(self, max_bytes: int = 65536) -> bytes: + loop = get_running_loop() + await checkpoint() + with self._receive_guard: + while True: + try: + data = self.__raw_socket.recv(max_bytes) + except BlockingIOError: + await self._wait_until_readable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + if not data: + raise EndOfStream + + return data + + async def send(self, item: bytes) -> None: + loop = get_running_loop() + await checkpoint() + with self._send_guard: + view = memoryview(item) + while view: + try: + bytes_sent = self.__raw_socket.send(view) + except BlockingIOError: + await self._wait_until_writable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + view = view[bytes_sent:] + + async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]: + if not isinstance(msglen, int) or msglen < 0: + raise ValueError("msglen must be a non-negative integer") + if not isinstance(maxfds, int) or maxfds < 1: + raise ValueError("maxfds must be a positive integer") + + loop = get_running_loop() + fds = array.array("i") + await checkpoint() + with self._receive_guard: + while True: + try: + message, ancdata, flags, addr = self.__raw_socket.recvmsg( + msglen, socket.CMSG_LEN(maxfds * fds.itemsize) + ) + except BlockingIOError: + await self._wait_until_readable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + if not message and not ancdata: + raise EndOfStream + + break + + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS: + raise RuntimeError( + f"Received unexpected ancillary data; message = {message!r}, " + f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}" + ) + + fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + + return message, list(fds) + + async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None: + if not message: + raise ValueError("message must not be empty") + if not fds: + raise ValueError("fds must not be empty") + + loop = get_running_loop() + filenos: list[int] = [] + for fd in fds: + if isinstance(fd, int): + filenos.append(fd) + elif isinstance(fd, IOBase): + filenos.append(fd.fileno()) + + fdarray = array.array("i", filenos) + await checkpoint() + with self._send_guard: + while True: + try: + # The ignore can be removed after mypy picks up + # https://github.com/python/typeshed/pull/5545 + self.__raw_socket.sendmsg( + [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)] + ) + break + except BlockingIOError: + await self._wait_until_writable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + + async def aclose(self) -> None: + if not self._closing: + self._closing = True + if self.__raw_socket.fileno() != -1: + self.__raw_socket.close() + + if self._receive_future: + self._receive_future.set_result(None) + if self._send_future: + self._send_future.set_result(None) + + +class TCPSocketListener(abc.SocketListener): + _accept_scope: CancelScope | None = None + _closed = False + + def __init__(self, raw_socket: socket.socket): + self.__raw_socket = raw_socket + self._loop = cast(asyncio.BaseEventLoop, get_running_loop()) + self._accept_guard = ResourceGuard("accepting connections from") + + @property + def _raw_socket(self) -> socket.socket: + return self.__raw_socket + + async def accept(self) -> abc.SocketStream: + if self._closed: + raise ClosedResourceError + + with self._accept_guard: + await checkpoint() + with CancelScope() as self._accept_scope: + try: + client_sock, _addr = await self._loop.sock_accept(self._raw_socket) + except asyncio.CancelledError: + # Workaround for https://bugs.python.org/issue41317 + try: + self._loop.remove_reader(self._raw_socket) + except (ValueError, NotImplementedError): + pass + + if self._closed: + raise ClosedResourceError from None + + raise + finally: + self._accept_scope = None + + client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + transport, protocol = await self._loop.connect_accepted_socket( + StreamProtocol, client_sock + ) + return SocketStream(transport, protocol) + + async def aclose(self) -> None: + if self._closed: + return + + self._closed = True + if self._accept_scope: + # Workaround for https://bugs.python.org/issue41317 + try: + self._loop.remove_reader(self._raw_socket) + except (ValueError, NotImplementedError): + pass + + self._accept_scope.cancel() + await sleep(0) + + self._raw_socket.close() + + +class UNIXSocketListener(abc.SocketListener): + def __init__(self, raw_socket: socket.socket): + self.__raw_socket = raw_socket + self._loop = get_running_loop() + self._accept_guard = ResourceGuard("accepting connections from") + self._closed = False + + async def accept(self) -> abc.SocketStream: + await checkpoint() + with self._accept_guard: + while True: + try: + client_sock, _ = self.__raw_socket.accept() + client_sock.setblocking(False) + return UNIXSocketStream(client_sock) + except BlockingIOError: + f: asyncio.Future = asyncio.Future() + self._loop.add_reader(self.__raw_socket, f.set_result, None) + f.add_done_callback( + lambda _: self._loop.remove_reader(self.__raw_socket) + ) + await f + except OSError as exc: + if self._closed: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + + async def aclose(self) -> None: + self._closed = True + self.__raw_socket.close() + + @property + def _raw_socket(self) -> socket.socket: + return self.__raw_socket + + +class UDPSocket(abc.UDPSocket): + def __init__( + self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol + ): + self._transport = transport + self._protocol = protocol + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + self._closed = False + + @property + def _raw_socket(self) -> socket.socket: + return self._transport.get_extra_info("socket") + + async def aclose(self) -> None: + if not self._transport.is_closing(): + self._closed = True + self._transport.close() + + async def receive(self) -> tuple[bytes, IPSockAddrType]: + with self._receive_guard: + await checkpoint() + + # If the buffer is empty, ask for more data + if not self._protocol.read_queue and not self._transport.is_closing(): + self._protocol.read_event.clear() + await self._protocol.read_event.wait() + + try: + return self._protocol.read_queue.popleft() + except IndexError: + if self._closed: + raise ClosedResourceError from None + else: + raise BrokenResourceError from None + + async def send(self, item: UDPPacketType) -> None: + with self._send_guard: + await checkpoint() + await self._protocol.write_event.wait() + if self._closed: + raise ClosedResourceError + elif self._transport.is_closing(): + raise BrokenResourceError + else: + self._transport.sendto(*item) + + +class ConnectedUDPSocket(abc.ConnectedUDPSocket): + def __init__( + self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol + ): + self._transport = transport + self._protocol = protocol + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + self._closed = False + + @property + def _raw_socket(self) -> socket.socket: + return self._transport.get_extra_info("socket") + + async def aclose(self) -> None: + if not self._transport.is_closing(): + self._closed = True + self._transport.close() + + async def receive(self) -> bytes: + with self._receive_guard: + await checkpoint() + + # If the buffer is empty, ask for more data + if not self._protocol.read_queue and not self._transport.is_closing(): + self._protocol.read_event.clear() + await self._protocol.read_event.wait() + + try: + packet = self._protocol.read_queue.popleft() + except IndexError: + if self._closed: + raise ClosedResourceError from None + else: + raise BrokenResourceError from None + + return packet[0] + + async def send(self, item: bytes) -> None: + with self._send_guard: + await checkpoint() + await self._protocol.write_event.wait() + if self._closed: + raise ClosedResourceError + elif self._transport.is_closing(): + raise BrokenResourceError + else: + self._transport.sendto(item) + + +async def connect_tcp( + host: str, port: int, local_addr: tuple[str, int] | None = None +) -> SocketStream: + transport, protocol = cast( + Tuple[asyncio.Transport, StreamProtocol], + await get_running_loop().create_connection( + StreamProtocol, host, port, local_addr=local_addr + ), + ) + transport.pause_reading() + return SocketStream(transport, protocol) + + +async def connect_unix(path: str) -> UNIXSocketStream: + await checkpoint() + loop = get_running_loop() + raw_socket = socket.socket(socket.AF_UNIX) + raw_socket.setblocking(False) + while True: + try: + raw_socket.connect(path) + except BlockingIOError: + f: asyncio.Future = asyncio.Future() + loop.add_writer(raw_socket, f.set_result, None) + f.add_done_callback(lambda _: loop.remove_writer(raw_socket)) + await f + except BaseException: + raw_socket.close() + raise + else: + return UNIXSocketStream(raw_socket) + + +async def create_udp_socket( + family: socket.AddressFamily, + local_address: IPSockAddrType | None, + remote_address: IPSockAddrType | None, + reuse_port: bool, +) -> UDPSocket | ConnectedUDPSocket: + result = await get_running_loop().create_datagram_endpoint( + DatagramProtocol, + local_addr=local_address, + remote_addr=remote_address, + family=family, + reuse_port=reuse_port, + ) + transport = result[0] + protocol = result[1] + if protocol.exception: + transport.close() + raise protocol.exception + + if not remote_address: + return UDPSocket(transport, protocol) + else: + return ConnectedUDPSocket(transport, protocol) + + +async def getaddrinfo( + host: bytes | str, + port: str | int | None, + *, + family: int | AddressFamily = 0, + type: int | SocketKind = 0, + proto: int = 0, + flags: int = 0, +) -> GetAddrInfoReturnType: + # https://github.com/python/typeshed/pull/4304 + result = await get_running_loop().getaddrinfo( + host, port, family=family, type=type, proto=proto, flags=flags + ) + return cast(GetAddrInfoReturnType, result) + + +async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> tuple[str, str]: + return await get_running_loop().getnameinfo(sockaddr, flags) + + +_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events") +_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events") + + +async def wait_socket_readable(sock: socket.socket) -> None: + await checkpoint() + try: + read_events = _read_events.get() + except LookupError: + read_events = {} + _read_events.set(read_events) + + if read_events.get(sock): + raise BusyResourceError("reading from") from None + + loop = get_running_loop() + event = read_events[sock] = asyncio.Event() + loop.add_reader(sock, event.set) + try: + await event.wait() + finally: + if read_events.pop(sock, None) is not None: + loop.remove_reader(sock) + readable = True + else: + readable = False + + if not readable: + raise ClosedResourceError + + +async def wait_socket_writable(sock: socket.socket) -> None: + await checkpoint() + try: + write_events = _write_events.get() + except LookupError: + write_events = {} + _write_events.set(write_events) + + if write_events.get(sock): + raise BusyResourceError("writing to") from None + + loop = get_running_loop() + event = write_events[sock] = asyncio.Event() + loop.add_writer(sock.fileno(), event.set) + try: + await event.wait() + finally: + if write_events.pop(sock, None) is not None: + loop.remove_writer(sock) + writable = True + else: + writable = False + + if not writable: + raise ClosedResourceError + + +# +# Synchronization +# + + +class Event(BaseEvent): + def __new__(cls) -> Event: + return object.__new__(cls) + + def __init__(self) -> None: + self._event = asyncio.Event() + + def set(self) -> DeprecatedAwaitable: + self._event.set() + return DeprecatedAwaitable(self.set) + + def is_set(self) -> bool: + return self._event.is_set() + + async def wait(self) -> None: + if await self._event.wait(): + await checkpoint() + + def statistics(self) -> EventStatistics: + return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined] + + +class CapacityLimiter(BaseCapacityLimiter): + _total_tokens: float = 0 + + def __new__(cls, total_tokens: float) -> CapacityLimiter: + return object.__new__(cls) + + def __init__(self, total_tokens: float): + self._borrowers: set[Any] = set() + self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict() + self.total_tokens = total_tokens + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.release() + + @property + def total_tokens(self) -> float: + return self._total_tokens + + @total_tokens.setter + def total_tokens(self, value: float) -> None: + if not isinstance(value, int) and not math.isinf(value): + raise TypeError("total_tokens must be an int or math.inf") + if value < 1: + raise ValueError("total_tokens must be >= 1") + + old_value = self._total_tokens + self._total_tokens = value + events = [] + for event in self._wait_queue.values(): + if value <= old_value: + break + + if not event.is_set(): + events.append(event) + old_value += 1 + + for event in events: + event.set() + + @property + def borrowed_tokens(self) -> int: + return len(self._borrowers) + + @property + def available_tokens(self) -> float: + return self._total_tokens - len(self._borrowers) + + def acquire_nowait(self) -> DeprecatedAwaitable: + self.acquire_on_behalf_of_nowait(current_task()) + return DeprecatedAwaitable(self.acquire_nowait) + + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: + if borrower in self._borrowers: + raise RuntimeError( + "this borrower is already holding one of this CapacityLimiter's " + "tokens" + ) + + if self._wait_queue or len(self._borrowers) >= self._total_tokens: + raise WouldBlock + + self._borrowers.add(borrower) + return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait) + + async def acquire(self) -> None: + return await self.acquire_on_behalf_of(current_task()) + + async def acquire_on_behalf_of(self, borrower: object) -> None: + await checkpoint_if_cancelled() + try: + self.acquire_on_behalf_of_nowait(borrower) + except WouldBlock: + event = asyncio.Event() + self._wait_queue[borrower] = event + try: + await event.wait() + except BaseException: + self._wait_queue.pop(borrower, None) + raise + + self._borrowers.add(borrower) + else: + try: + await cancel_shielded_checkpoint() + except BaseException: + self.release() + raise + + def release(self) -> None: + self.release_on_behalf_of(current_task()) + + def release_on_behalf_of(self, borrower: object) -> None: + try: + self._borrowers.remove(borrower) + except KeyError: + raise RuntimeError( + "this borrower isn't holding any of this CapacityLimiter's " "tokens" + ) from None + + # Notify the next task in line if this limiter has free capacity now + if self._wait_queue and len(self._borrowers) < self._total_tokens: + event = self._wait_queue.popitem(last=False)[1] + event.set() + + def statistics(self) -> CapacityLimiterStatistics: + return CapacityLimiterStatistics( + self.borrowed_tokens, + self.total_tokens, + tuple(self._borrowers), + len(self._wait_queue), + ) + + +_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter") + + +def current_default_thread_limiter() -> CapacityLimiter: + try: + return _default_thread_limiter.get() + except LookupError: + limiter = CapacityLimiter(40) + _default_thread_limiter.set(limiter) + return limiter + + +# +# Operating system signals +# + + +class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): + def __init__(self, signals: tuple[int, ...]): + self._signals = signals + self._loop = get_running_loop() + self._signal_queue: deque[int] = deque() + self._future: asyncio.Future = asyncio.Future() + self._handled_signals: set[int] = set() + + def _deliver(self, signum: int) -> None: + self._signal_queue.append(signum) + if not self._future.done(): + self._future.set_result(None) + + def __enter__(self) -> _SignalReceiver: + for sig in set(self._signals): + self._loop.add_signal_handler(sig, self._deliver, sig) + self._handled_signals.add(sig) + + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + for sig in self._handled_signals: + self._loop.remove_signal_handler(sig) + return None + + def __aiter__(self) -> _SignalReceiver: + return self + + async def __anext__(self) -> int: + await checkpoint() + if not self._signal_queue: + self._future = asyncio.Future() + await self._future + + return self._signal_queue.popleft() + + +def open_signal_receiver(*signals: int) -> _SignalReceiver: + return _SignalReceiver(signals) + + +# +# Testing and debugging +# + + +def _create_task_info(task: asyncio.Task) -> TaskInfo: + task_state = _task_states.get(task) + if task_state is None: + name = task.get_name() if _native_task_names else None + parent_id = None + else: + name = task_state.name + parent_id = task_state.parent_id + + return TaskInfo(id(task), parent_id, name, get_coro(task)) + + +def get_current_task() -> TaskInfo: + return _create_task_info(current_task()) # type: ignore[arg-type] + + +def get_running_tasks() -> list[TaskInfo]: + return [_create_task_info(task) for task in all_tasks() if not task.done()] + + +async def wait_all_tasks_blocked() -> None: + await checkpoint() + this_task = current_task() + while True: + for task in all_tasks(): + if task is this_task: + continue + + if task._fut_waiter is None or task._fut_waiter.done(): # type: ignore[attr-defined] + await sleep(0.1) + break + else: + return + + +class TestRunner(abc.TestRunner): + def __init__( + self, + debug: bool = False, + use_uvloop: bool = False, + policy: asyncio.AbstractEventLoopPolicy | None = None, + ): + self._exceptions: list[BaseException] = [] + _maybe_set_event_loop_policy(policy, use_uvloop) + self._loop = asyncio.new_event_loop() + self._loop.set_debug(debug) + self._loop.set_exception_handler(self._exception_handler) + asyncio.set_event_loop(self._loop) + + def _cancel_all_tasks(self) -> None: + to_cancel = all_tasks(self._loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + self._loop.run_until_complete( + asyncio.gather(*to_cancel, return_exceptions=True) + ) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + raise cast(BaseException, task.exception()) + + def _exception_handler( + self, loop: asyncio.AbstractEventLoop, context: dict[str, Any] + ) -> None: + if isinstance(context.get("exception"), Exception): + self._exceptions.append(context["exception"]) + else: + loop.default_exception_handler(context) + + def _raise_async_exceptions(self) -> None: + # Re-raise any exceptions raised in asynchronous callbacks + if self._exceptions: + exceptions, self._exceptions = self._exceptions, [] + if len(exceptions) == 1: + raise exceptions[0] + elif exceptions: + raise ExceptionGroup(exceptions) + + def close(self) -> None: + try: + self._cancel_all_tasks() + self._loop.run_until_complete(self._loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + self._loop.close() + + def run_asyncgen_fixture( + self, + fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]], + kwargs: dict[str, Any], + ) -> Iterable[T_Retval]: + async def fixture_runner() -> None: + agen = fixture_func(**kwargs) + try: + retval = await agen.asend(None) + self._raise_async_exceptions() + except BaseException as exc: + f.set_exception(exc) + return + else: + f.set_result(retval) + + await event.wait() + try: + await agen.asend(None) + except StopAsyncIteration: + pass + else: + await agen.aclose() + raise RuntimeError("Async generator fixture did not stop") + + f = self._loop.create_future() + event = asyncio.Event() + fixture_task = self._loop.create_task(fixture_runner()) + self._loop.run_until_complete(f) + yield f.result() + event.set() + self._loop.run_until_complete(fixture_task) + self._raise_async_exceptions() + + def run_fixture( + self, + fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]], + kwargs: dict[str, Any], + ) -> T_Retval: + retval = self._loop.run_until_complete(fixture_func(**kwargs)) + self._raise_async_exceptions() + return retval + + def run_test( + self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] + ) -> None: + try: + self._loop.run_until_complete(test_func(**kwargs)) + except Exception as exc: + self._exceptions.append(exc) + + self._raise_async_exceptions() diff --git a/contrib/python/anyio/anyio/_backends/_trio.py b/contrib/python/anyio/anyio/_backends/_trio.py new file mode 100644 index 0000000000..cf28943509 --- /dev/null +++ b/contrib/python/anyio/anyio/_backends/_trio.py @@ -0,0 +1,996 @@ +from __future__ import annotations + +import array +import math +import socket +from concurrent.futures import Future +from contextvars import copy_context +from dataclasses import dataclass +from functools import partial +from io import IOBase +from os import PathLike +from signal import Signals +from types import TracebackType +from typing import ( + IO, + TYPE_CHECKING, + Any, + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Collection, + Coroutine, + Generic, + Iterable, + Mapping, + NoReturn, + Sequence, + TypeVar, + cast, +) + +import sniffio +import trio.from_thread +from outcome import Error, Outcome, Value +from trio.socket import SocketType as TrioSocketType +from trio.to_thread import run_sync + +from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc +from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable +from .._core._eventloop import claim_worker_thread +from .._core._exceptions import ( + BrokenResourceError, + BusyResourceError, + ClosedResourceError, + EndOfStream, +) +from .._core._exceptions import ExceptionGroup as BaseExceptionGroup +from .._core._sockets import convert_ipv6_sockaddr +from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter +from .._core._synchronization import Event as BaseEvent +from .._core._synchronization import ResourceGuard +from .._core._tasks import CancelScope as BaseCancelScope +from ..abc import IPSockAddrType, UDPPacketType + +if TYPE_CHECKING: + from trio_typing import TaskStatus + +try: + from trio import lowlevel as trio_lowlevel +except ImportError: + from trio import hazmat as trio_lowlevel # type: ignore[no-redef] + from trio.hazmat import wait_readable, wait_writable +else: + from trio.lowlevel import wait_readable, wait_writable + +try: + trio_open_process = trio_lowlevel.open_process +except AttributeError: + # isort: off + from trio import ( # type: ignore[attr-defined, no-redef] + open_process as trio_open_process, + ) + +T_Retval = TypeVar("T_Retval") +T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType) + + +# +# Event loop +# + +run = trio.run +current_token = trio.lowlevel.current_trio_token +RunVar = trio.lowlevel.RunVar + + +# +# Miscellaneous +# + +sleep = trio.sleep + + +# +# Timeouts and cancellation +# + + +class CancelScope(BaseCancelScope): + def __new__( + cls, original: trio.CancelScope | None = None, **kwargs: object + ) -> CancelScope: + return object.__new__(cls) + + def __init__(self, original: trio.CancelScope | None = None, **kwargs: Any) -> None: + self.__original = original or trio.CancelScope(**kwargs) + + def __enter__(self) -> CancelScope: + self.__original.__enter__() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + # https://github.com/python-trio/trio-typing/pull/79 + return self.__original.__exit__( # type: ignore[func-returns-value] + exc_type, exc_val, exc_tb + ) + + def cancel(self) -> DeprecatedAwaitable: + self.__original.cancel() + return DeprecatedAwaitable(self.cancel) + + @property + def deadline(self) -> float: + return self.__original.deadline + + @deadline.setter + def deadline(self, value: float) -> None: + self.__original.deadline = value + + @property + def cancel_called(self) -> bool: + return self.__original.cancel_called + + @property + def shield(self) -> bool: + return self.__original.shield + + @shield.setter + def shield(self, value: bool) -> None: + self.__original.shield = value + + +CancelledError = trio.Cancelled +checkpoint = trio.lowlevel.checkpoint +checkpoint_if_cancelled = trio.lowlevel.checkpoint_if_cancelled +cancel_shielded_checkpoint = trio.lowlevel.cancel_shielded_checkpoint +current_effective_deadline = trio.current_effective_deadline +current_time = trio.current_time + + +# +# Task groups +# + + +class ExceptionGroup(BaseExceptionGroup, trio.MultiError): + pass + + +class TaskGroup(abc.TaskGroup): + def __init__(self) -> None: + self._active = False + self._nursery_manager = trio.open_nursery() + self.cancel_scope = None # type: ignore[assignment] + + async def __aenter__(self) -> TaskGroup: + self._active = True + self._nursery = await self._nursery_manager.__aenter__() + self.cancel_scope = CancelScope(self._nursery.cancel_scope) + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + try: + return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) + except trio.MultiError as exc: + raise ExceptionGroup(exc.exceptions) from None + finally: + self._active = False + + def start_soon( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> None: + if not self._active: + raise RuntimeError( + "This task group is not active; no new tasks can be started." + ) + + self._nursery.start_soon(func, *args, name=name) + + async def start( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> object: + if not self._active: + raise RuntimeError( + "This task group is not active; no new tasks can be started." + ) + + return await self._nursery.start(func, *args, name=name) + + +# +# Threads +# + + +async def run_sync_in_worker_thread( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: trio.CapacityLimiter | None = None, +) -> T_Retval: + def wrapper() -> T_Retval: + with claim_worker_thread("trio"): + return func(*args) + + # TODO: remove explicit context copying when trio 0.20 is the minimum requirement + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, None) + return await run_sync( + context.run, wrapper, cancellable=cancellable, limiter=limiter + ) + + +# TODO: remove this workaround when trio 0.20 is the minimum requirement +def run_async_from_thread( + fn: Callable[..., Awaitable[T_Retval]], *args: Any +) -> T_Retval: + async def wrapper() -> T_Retval: + retval: T_Retval + + async def inner() -> None: + nonlocal retval + __tracebackhide__ = True + retval = await fn(*args) + + async with trio.open_nursery() as n: + context.run(n.start_soon, inner) + + __tracebackhide__ = True + return retval # noqa: F821 + + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, "trio") + return trio.from_thread.run(wrapper) + + +def run_sync_from_thread(fn: Callable[..., T_Retval], *args: Any) -> T_Retval: + # TODO: remove explicit context copying when trio 0.20 is the minimum requirement + retval = trio.from_thread.run_sync(copy_context().run, fn, *args) + return cast(T_Retval, retval) + + +class BlockingPortal(abc.BlockingPortal): + def __new__(cls) -> BlockingPortal: + return object.__new__(cls) + + def __init__(self) -> None: + super().__init__() + self._token = trio.lowlevel.current_trio_token() + + def _spawn_task_from_thread( + self, + func: Callable, + args: tuple, + kwargs: dict[str, Any], + name: object, + future: Future, + ) -> None: + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, "trio") + trio.from_thread.run_sync( + context.run, + partial(self._task_group.start_soon, name=name), + self._call_func, + func, + args, + kwargs, + future, + trio_token=self._token, + ) + + +# +# Subprocesses +# + + +@dataclass(eq=False) +class ReceiveStreamWrapper(abc.ByteReceiveStream): + _stream: trio.abc.ReceiveStream + + async def receive(self, max_bytes: int | None = None) -> bytes: + try: + data = await self._stream.receive_some(max_bytes) + except trio.ClosedResourceError as exc: + raise ClosedResourceError from exc.__cause__ + except trio.BrokenResourceError as exc: + raise BrokenResourceError from exc.__cause__ + + if data: + return data + else: + raise EndOfStream + + async def aclose(self) -> None: + await self._stream.aclose() + + +@dataclass(eq=False) +class SendStreamWrapper(abc.ByteSendStream): + _stream: trio.abc.SendStream + + async def send(self, item: bytes) -> None: + try: + await self._stream.send_all(item) + except trio.ClosedResourceError as exc: + raise ClosedResourceError from exc.__cause__ + except trio.BrokenResourceError as exc: + raise BrokenResourceError from exc.__cause__ + + async def aclose(self) -> None: + await self._stream.aclose() + + +@dataclass(eq=False) +class Process(abc.Process): + _process: trio.Process + _stdin: abc.ByteSendStream | None + _stdout: abc.ByteReceiveStream | None + _stderr: abc.ByteReceiveStream | None + + async def aclose(self) -> None: + if self._stdin: + await self._stdin.aclose() + if self._stdout: + await self._stdout.aclose() + if self._stderr: + await self._stderr.aclose() + + await self.wait() + + async def wait(self) -> int: + return await self._process.wait() + + def terminate(self) -> None: + self._process.terminate() + + def kill(self) -> None: + self._process.kill() + + def send_signal(self, signal: Signals) -> None: + self._process.send_signal(signal) + + @property + def pid(self) -> int: + return self._process.pid + + @property + def returncode(self) -> int | None: + return self._process.returncode + + @property + def stdin(self) -> abc.ByteSendStream | None: + return self._stdin + + @property + def stdout(self) -> abc.ByteReceiveStream | None: + return self._stdout + + @property + def stderr(self) -> abc.ByteReceiveStream | None: + return self._stderr + + +async def open_process( + command: str | bytes | Sequence[str | bytes], + *, + shell: bool, + stdin: int | IO[Any] | None, + stdout: int | IO[Any] | None, + stderr: int | IO[Any] | None, + cwd: str | bytes | PathLike | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, +) -> Process: + process = await trio_open_process( # type: ignore[misc] + command, # type: ignore[arg-type] + stdin=stdin, + stdout=stdout, + stderr=stderr, + shell=shell, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None + stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None + stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None + return Process(process, stdin_stream, stdout_stream, stderr_stream) + + +class _ProcessPoolShutdownInstrument(trio.abc.Instrument): + def after_run(self) -> None: + super().after_run() + + +current_default_worker_process_limiter: RunVar = RunVar( + "current_default_worker_process_limiter" +) + + +async def _shutdown_process_pool(workers: set[Process]) -> None: + process: Process + try: + await sleep(math.inf) + except trio.Cancelled: + for process in workers: + if process.returncode is None: + process.kill() + + with CancelScope(shield=True): + for process in workers: + await process.aclose() + + +def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None: + trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers) + + +# +# Sockets and networking +# + + +class _TrioSocketMixin(Generic[T_SockAddr]): + def __init__(self, trio_socket: TrioSocketType) -> None: + self._trio_socket = trio_socket + self._closed = False + + def _check_closed(self) -> None: + if self._closed: + raise ClosedResourceError + if self._trio_socket.fileno() < 0: + raise BrokenResourceError + + @property + def _raw_socket(self) -> socket.socket: + return self._trio_socket._sock # type: ignore[attr-defined] + + async def aclose(self) -> None: + if self._trio_socket.fileno() >= 0: + self._closed = True + self._trio_socket.close() + + def _convert_socket_error(self, exc: BaseException) -> NoReturn: + if isinstance(exc, trio.ClosedResourceError): + raise ClosedResourceError from exc + elif self._trio_socket.fileno() < 0 and self._closed: + raise ClosedResourceError from None + elif isinstance(exc, OSError): + raise BrokenResourceError from exc + else: + raise exc + + +class SocketStream(_TrioSocketMixin, abc.SocketStream): + def __init__(self, trio_socket: TrioSocketType) -> None: + super().__init__(trio_socket) + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + + async def receive(self, max_bytes: int = 65536) -> bytes: + with self._receive_guard: + try: + data = await self._trio_socket.recv(max_bytes) + except BaseException as exc: + self._convert_socket_error(exc) + + if data: + return data + else: + raise EndOfStream + + async def send(self, item: bytes) -> None: + with self._send_guard: + view = memoryview(item) + while view: + try: + bytes_sent = await self._trio_socket.send(view) + except BaseException as exc: + self._convert_socket_error(exc) + + view = view[bytes_sent:] + + async def send_eof(self) -> None: + self._trio_socket.shutdown(socket.SHUT_WR) + + +class UNIXSocketStream(SocketStream, abc.UNIXSocketStream): + async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]: + if not isinstance(msglen, int) or msglen < 0: + raise ValueError("msglen must be a non-negative integer") + if not isinstance(maxfds, int) or maxfds < 1: + raise ValueError("maxfds must be a positive integer") + + fds = array.array("i") + await checkpoint() + with self._receive_guard: + while True: + try: + message, ancdata, flags, addr = await self._trio_socket.recvmsg( + msglen, socket.CMSG_LEN(maxfds * fds.itemsize) + ) + except BaseException as exc: + self._convert_socket_error(exc) + else: + if not message and not ancdata: + raise EndOfStream + + break + + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS: + raise RuntimeError( + f"Received unexpected ancillary data; message = {message!r}, " + f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}" + ) + + fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + + return message, list(fds) + + async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None: + if not message: + raise ValueError("message must not be empty") + if not fds: + raise ValueError("fds must not be empty") + + filenos: list[int] = [] + for fd in fds: + if isinstance(fd, int): + filenos.append(fd) + elif isinstance(fd, IOBase): + filenos.append(fd.fileno()) + + fdarray = array.array("i", filenos) + await checkpoint() + with self._send_guard: + while True: + try: + await self._trio_socket.sendmsg( + [message], + [ + ( + socket.SOL_SOCKET, + socket.SCM_RIGHTS, # type: ignore[list-item] + fdarray, + ) + ], + ) + break + except BaseException as exc: + self._convert_socket_error(exc) + + +class TCPSocketListener(_TrioSocketMixin, abc.SocketListener): + def __init__(self, raw_socket: socket.socket): + super().__init__(trio.socket.from_stdlib_socket(raw_socket)) + self._accept_guard = ResourceGuard("accepting connections from") + + async def accept(self) -> SocketStream: + with self._accept_guard: + try: + trio_socket, _addr = await self._trio_socket.accept() + except BaseException as exc: + self._convert_socket_error(exc) + + trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return SocketStream(trio_socket) + + +class UNIXSocketListener(_TrioSocketMixin, abc.SocketListener): + def __init__(self, raw_socket: socket.socket): + super().__init__(trio.socket.from_stdlib_socket(raw_socket)) + self._accept_guard = ResourceGuard("accepting connections from") + + async def accept(self) -> UNIXSocketStream: + with self._accept_guard: + try: + trio_socket, _addr = await self._trio_socket.accept() + except BaseException as exc: + self._convert_socket_error(exc) + + return UNIXSocketStream(trio_socket) + + +class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket): + def __init__(self, trio_socket: TrioSocketType) -> None: + super().__init__(trio_socket) + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + + async def receive(self) -> tuple[bytes, IPSockAddrType]: + with self._receive_guard: + try: + data, addr = await self._trio_socket.recvfrom(65536) + return data, convert_ipv6_sockaddr(addr) + except BaseException as exc: + self._convert_socket_error(exc) + + async def send(self, item: UDPPacketType) -> None: + with self._send_guard: + try: + await self._trio_socket.sendto(*item) + except BaseException as exc: + self._convert_socket_error(exc) + + +class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket): + def __init__(self, trio_socket: TrioSocketType) -> None: + super().__init__(trio_socket) + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + + async def receive(self) -> bytes: + with self._receive_guard: + try: + return await self._trio_socket.recv(65536) + except BaseException as exc: + self._convert_socket_error(exc) + + async def send(self, item: bytes) -> None: + with self._send_guard: + try: + await self._trio_socket.send(item) + except BaseException as exc: + self._convert_socket_error(exc) + + +async def connect_tcp( + host: str, port: int, local_address: IPSockAddrType | None = None +) -> SocketStream: + family = socket.AF_INET6 if ":" in host else socket.AF_INET + trio_socket = trio.socket.socket(family) + trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if local_address: + await trio_socket.bind(local_address) + + try: + await trio_socket.connect((host, port)) + except BaseException: + trio_socket.close() + raise + + return SocketStream(trio_socket) + + +async def connect_unix(path: str) -> UNIXSocketStream: + trio_socket = trio.socket.socket(socket.AF_UNIX) + try: + await trio_socket.connect(path) + except BaseException: + trio_socket.close() + raise + + return UNIXSocketStream(trio_socket) + + +async def create_udp_socket( + family: socket.AddressFamily, + local_address: IPSockAddrType | None, + remote_address: IPSockAddrType | None, + reuse_port: bool, +) -> UDPSocket | ConnectedUDPSocket: + trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) + + if reuse_port: + trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + if local_address: + await trio_socket.bind(local_address) + + if remote_address: + await trio_socket.connect(remote_address) + return ConnectedUDPSocket(trio_socket) + else: + return UDPSocket(trio_socket) + + +getaddrinfo = trio.socket.getaddrinfo +getnameinfo = trio.socket.getnameinfo + + +async def wait_socket_readable(sock: socket.socket) -> None: + try: + await wait_readable(sock) + except trio.ClosedResourceError as exc: + raise ClosedResourceError().with_traceback(exc.__traceback__) from None + except trio.BusyResourceError: + raise BusyResourceError("reading from") from None + + +async def wait_socket_writable(sock: socket.socket) -> None: + try: + await wait_writable(sock) + except trio.ClosedResourceError as exc: + raise ClosedResourceError().with_traceback(exc.__traceback__) from None + except trio.BusyResourceError: + raise BusyResourceError("writing to") from None + + +# +# Synchronization +# + + +class Event(BaseEvent): + def __new__(cls) -> Event: + return object.__new__(cls) + + def __init__(self) -> None: + self.__original = trio.Event() + + def is_set(self) -> bool: + return self.__original.is_set() + + async def wait(self) -> None: + return await self.__original.wait() + + def statistics(self) -> EventStatistics: + orig_statistics = self.__original.statistics() + return EventStatistics(tasks_waiting=orig_statistics.tasks_waiting) + + def set(self) -> DeprecatedAwaitable: + self.__original.set() + return DeprecatedAwaitable(self.set) + + +class CapacityLimiter(BaseCapacityLimiter): + def __new__(cls, *args: object, **kwargs: object) -> CapacityLimiter: + return object.__new__(cls) + + def __init__( + self, *args: Any, original: trio.CapacityLimiter | None = None + ) -> None: + self.__original = original or trio.CapacityLimiter(*args) + + async def __aenter__(self) -> None: + return await self.__original.__aenter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.__original.__aexit__(exc_type, exc_val, exc_tb) + + @property + def total_tokens(self) -> float: + return self.__original.total_tokens + + @total_tokens.setter + def total_tokens(self, value: float) -> None: + self.__original.total_tokens = value + + @property + def borrowed_tokens(self) -> int: + return self.__original.borrowed_tokens + + @property + def available_tokens(self) -> float: + return self.__original.available_tokens + + def acquire_nowait(self) -> DeprecatedAwaitable: + self.__original.acquire_nowait() + return DeprecatedAwaitable(self.acquire_nowait) + + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: + self.__original.acquire_on_behalf_of_nowait(borrower) + return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait) + + async def acquire(self) -> None: + await self.__original.acquire() + + async def acquire_on_behalf_of(self, borrower: object) -> None: + await self.__original.acquire_on_behalf_of(borrower) + + def release(self) -> None: + return self.__original.release() + + def release_on_behalf_of(self, borrower: object) -> None: + return self.__original.release_on_behalf_of(borrower) + + def statistics(self) -> CapacityLimiterStatistics: + orig = self.__original.statistics() + return CapacityLimiterStatistics( + borrowed_tokens=orig.borrowed_tokens, + total_tokens=orig.total_tokens, + borrowers=orig.borrowers, + tasks_waiting=orig.tasks_waiting, + ) + + +_capacity_limiter_wrapper: RunVar = RunVar("_capacity_limiter_wrapper") + + +def current_default_thread_limiter() -> CapacityLimiter: + try: + return _capacity_limiter_wrapper.get() + except LookupError: + limiter = CapacityLimiter( + original=trio.to_thread.current_default_thread_limiter() + ) + _capacity_limiter_wrapper.set(limiter) + return limiter + + +# +# Signal handling +# + + +class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): + _iterator: AsyncIterator[int] + + def __init__(self, signals: tuple[Signals, ...]): + self._signals = signals + + def __enter__(self) -> _SignalReceiver: + self._cm = trio.open_signal_receiver(*self._signals) + self._iterator = self._cm.__enter__() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return self._cm.__exit__(exc_type, exc_val, exc_tb) + + def __aiter__(self) -> _SignalReceiver: + return self + + async def __anext__(self) -> Signals: + signum = await self._iterator.__anext__() + return Signals(signum) + + +def open_signal_receiver(*signals: Signals) -> _SignalReceiver: + return _SignalReceiver(signals) + + +# +# Testing and debugging +# + + +def get_current_task() -> TaskInfo: + task = trio_lowlevel.current_task() + + parent_id = None + if task.parent_nursery and task.parent_nursery.parent_task: + parent_id = id(task.parent_nursery.parent_task) + + return TaskInfo(id(task), parent_id, task.name, task.coro) + + +def get_running_tasks() -> list[TaskInfo]: + root_task = trio_lowlevel.current_root_task() + task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)] + nurseries = root_task.child_nurseries + while nurseries: + new_nurseries: list[trio.Nursery] = [] + for nursery in nurseries: + for task in nursery.child_tasks: + task_infos.append( + TaskInfo(id(task), id(nursery.parent_task), task.name, task.coro) + ) + new_nurseries.extend(task.child_nurseries) + + nurseries = new_nurseries + + return task_infos + + +def wait_all_tasks_blocked() -> Awaitable[None]: + import trio.testing + + return trio.testing.wait_all_tasks_blocked() + + +class TestRunner(abc.TestRunner): + def __init__(self, **options: Any) -> None: + from collections import deque + from queue import Queue + + self._call_queue: Queue[Callable[..., object]] = Queue() + self._result_queue: deque[Outcome] = deque() + self._stop_event: trio.Event | None = None + self._nursery: trio.Nursery | None = None + self._options = options + + async def _trio_main(self) -> None: + self._stop_event = trio.Event() + async with trio.open_nursery() as self._nursery: + await self._stop_event.wait() + + async def _call_func( + self, func: Callable[..., Awaitable[object]], args: tuple, kwargs: dict + ) -> None: + try: + retval = await func(*args, **kwargs) + except BaseException as exc: + self._result_queue.append(Error(exc)) + else: + self._result_queue.append(Value(retval)) + + def _main_task_finished(self, outcome: object) -> None: + self._nursery = None + + def _get_nursery(self) -> trio.Nursery: + if self._nursery is None: + trio.lowlevel.start_guest_run( + self._trio_main, + run_sync_soon_threadsafe=self._call_queue.put, + done_callback=self._main_task_finished, + **self._options, + ) + while self._nursery is None: + self._call_queue.get()() + + return self._nursery + + def _call( + self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object + ) -> T_Retval: + self._get_nursery().start_soon(self._call_func, func, args, kwargs) + while not self._result_queue: + self._call_queue.get()() + + outcome = self._result_queue.pop() + return outcome.unwrap() + + def close(self) -> None: + if self._stop_event: + self._stop_event.set() + while self._nursery is not None: + self._call_queue.get()() + + def run_asyncgen_fixture( + self, + fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]], + kwargs: dict[str, Any], + ) -> Iterable[T_Retval]: + async def fixture_runner(*, task_status: TaskStatus[T_Retval]) -> None: + agen = fixture_func(**kwargs) + retval = await agen.asend(None) + task_status.started(retval) + await teardown_event.wait() + try: + await agen.asend(None) + except StopAsyncIteration: + pass + else: + await agen.aclose() + raise RuntimeError("Async generator fixture did not stop") + + teardown_event = trio.Event() + fixture_value = self._call(lambda: self._get_nursery().start(fixture_runner)) + yield fixture_value + teardown_event.set() + + def run_fixture( + self, + fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]], + kwargs: dict[str, Any], + ) -> T_Retval: + return self._call(fixture_func, **kwargs) + + def run_test( + self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] + ) -> None: + self._call(test_func, **kwargs) diff --git a/contrib/python/anyio/anyio/_core/__init__.py b/contrib/python/anyio/anyio/_core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/anyio/anyio/_core/__init__.py diff --git a/contrib/python/anyio/anyio/_core/_compat.py b/contrib/python/anyio/anyio/_core/_compat.py new file mode 100644 index 0000000000..22d29ab8ac --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_compat.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from contextlib import AbstractContextManager +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + Callable, + ContextManager, + Generator, + Generic, + Iterable, + List, + TypeVar, + Union, + overload, +) +from warnings import warn + +if TYPE_CHECKING: + from ._testing import TaskInfo +else: + TaskInfo = object + +T = TypeVar("T") +AnyDeprecatedAwaitable = Union[ + "DeprecatedAwaitable", + "DeprecatedAwaitableFloat", + "DeprecatedAwaitableList[T]", + TaskInfo, +] + + +@overload +async def maybe_async(__obj: TaskInfo) -> TaskInfo: + ... + + +@overload +async def maybe_async(__obj: DeprecatedAwaitableFloat) -> float: + ... + + +@overload +async def maybe_async(__obj: DeprecatedAwaitableList[T]) -> list[T]: + ... + + +@overload +async def maybe_async(__obj: DeprecatedAwaitable) -> None: + ... + + +async def maybe_async( + __obj: AnyDeprecatedAwaitable[T], +) -> TaskInfo | float | list[T] | None: + """ + Await on the given object if necessary. + + This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and + methods were converted from coroutine functions into regular functions. + + Do **not** try to use this for any other purpose! + + :return: the result of awaiting on the object if coroutine, or the object itself otherwise + + .. versionadded:: 2.2 + + """ + return __obj._unwrap() + + +class _ContextManagerWrapper: + def __init__(self, cm: ContextManager[T]): + self._cm = cm + + async def __aenter__(self) -> T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return self._cm.__exit__(exc_type, exc_val, exc_tb) + + +def maybe_async_cm( + cm: ContextManager[T] | AsyncContextManager[T], +) -> AsyncContextManager[T]: + """ + Wrap a regular context manager as an async one if necessary. + + This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and + methods were changed to return regular context managers instead of async ones. + + :param cm: a regular or async context manager + :return: an async context manager + + .. versionadded:: 2.2 + + """ + if not isinstance(cm, AbstractContextManager): + raise TypeError("Given object is not an context manager") + + return _ContextManagerWrapper(cm) + + +def _warn_deprecation( + awaitable: AnyDeprecatedAwaitable[Any], stacklevel: int = 1 +) -> None: + warn( + f'Awaiting on {awaitable._name}() is deprecated. Use "await ' + f"anyio.maybe_async({awaitable._name}(...)) if you have to support both AnyIO 2.x " + f'and 3.x, or just remove the "await" if you are completely migrating to AnyIO 3+.', + DeprecationWarning, + stacklevel=stacklevel + 1, + ) + + +class DeprecatedAwaitable: + def __init__(self, func: Callable[..., DeprecatedAwaitable]): + self._name = f"{func.__module__}.{func.__qualname__}" + + def __await__(self) -> Generator[None, None, None]: + _warn_deprecation(self) + if False: + yield + + def __reduce__(self) -> tuple[type[None], tuple[()]]: + return type(None), () + + def _unwrap(self) -> None: + return None + + +class DeprecatedAwaitableFloat(float): + def __new__( + cls, x: float, func: Callable[..., DeprecatedAwaitableFloat] + ) -> DeprecatedAwaitableFloat: + return super().__new__(cls, x) + + def __init__(self, x: float, func: Callable[..., DeprecatedAwaitableFloat]): + self._name = f"{func.__module__}.{func.__qualname__}" + + def __await__(self) -> Generator[None, None, float]: + _warn_deprecation(self) + if False: + yield + + return float(self) + + def __reduce__(self) -> tuple[type[float], tuple[float]]: + return float, (float(self),) + + def _unwrap(self) -> float: + return float(self) + + +class DeprecatedAwaitableList(List[T]): + def __init__( + self, + iterable: Iterable[T] = (), + *, + func: Callable[..., DeprecatedAwaitableList[T]], + ): + super().__init__(iterable) + self._name = f"{func.__module__}.{func.__qualname__}" + + def __await__(self) -> Generator[None, None, list[T]]: + _warn_deprecation(self) + if False: + yield + + return list(self) + + def __reduce__(self) -> tuple[type[list[T]], tuple[list[T]]]: + return list, (list(self),) + + def _unwrap(self) -> list[T]: + return list(self) + + +class DeprecatedAsyncContextManager(Generic[T], metaclass=ABCMeta): + @abstractmethod + def __enter__(self) -> T: + pass + + @abstractmethod + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + pass + + async def __aenter__(self) -> T: + warn( + f"Using {self.__class__.__name__} as an async context manager has been deprecated. " + f'Use "async with anyio.maybe_async_cm(yourcontextmanager) as foo:" if you have to ' + f'support both AnyIO 2.x and 3.x, or just remove the "async" from "async with" if ' + f"you are completely migrating to AnyIO 3+.", + DeprecationWarning, + ) + return self.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return self.__exit__(exc_type, exc_val, exc_tb) diff --git a/contrib/python/anyio/anyio/_core/_eventloop.py b/contrib/python/anyio/anyio/_core/_eventloop.py new file mode 100644 index 0000000000..ae9864851b --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_eventloop.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import math +import sys +import threading +from contextlib import contextmanager +from importlib import import_module +from typing import ( + Any, + Awaitable, + Callable, + Generator, + TypeVar, +) + +import sniffio + +# This must be updated when new backends are introduced +from ._compat import DeprecatedAwaitableFloat + +BACKENDS = "asyncio", "trio" + +T_Retval = TypeVar("T_Retval") +threadlocals = threading.local() + + +def run( + func: Callable[..., Awaitable[T_Retval]], + *args: object, + backend: str = "asyncio", + backend_options: dict[str, Any] | None = None, +) -> T_Retval: + """ + Run the given coroutine function in an asynchronous event loop. + + The current thread must not be already running an event loop. + + :param func: a coroutine function + :param args: positional arguments to ``func`` + :param backend: name of the asynchronous event loop implementation – currently either + ``asyncio`` or ``trio`` + :param backend_options: keyword arguments to call the backend ``run()`` implementation with + (documented :ref:`here <backend options>`) + :return: the return value of the coroutine function + :raises RuntimeError: if an asynchronous event loop is already running in this thread + :raises LookupError: if the named backend is not found + + """ + try: + asynclib_name = sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + pass + else: + raise RuntimeError(f"Already running {asynclib_name} in this thread") + + try: + asynclib = import_module(f"..._backends._{backend}", package=__name__) + except ImportError as exc: + raise LookupError(f"No such backend: {backend}") from exc + + token = None + if sniffio.current_async_library_cvar.get(None) is None: + # Since we're in control of the event loop, we can cache the name of the async library + token = sniffio.current_async_library_cvar.set(backend) + + try: + backend_options = backend_options or {} + return asynclib.run(func, *args, **backend_options) + finally: + if token: + sniffio.current_async_library_cvar.reset(token) + + +async def sleep(delay: float) -> None: + """ + Pause the current task for the specified duration. + + :param delay: the duration, in seconds + + """ + return await get_asynclib().sleep(delay) + + +async def sleep_forever() -> None: + """ + Pause the current task until it's cancelled. + + This is a shortcut for ``sleep(math.inf)``. + + .. versionadded:: 3.1 + + """ + await sleep(math.inf) + + +async def sleep_until(deadline: float) -> None: + """ + Pause the current task until the given time. + + :param deadline: the absolute time to wake up at (according to the internal monotonic clock of + the event loop) + + .. versionadded:: 3.1 + + """ + now = current_time() + await sleep(max(deadline - now, 0)) + + +def current_time() -> DeprecatedAwaitableFloat: + """ + Return the current value of the event loop's internal clock. + + :return: the clock value (seconds) + + """ + return DeprecatedAwaitableFloat(get_asynclib().current_time(), current_time) + + +def get_all_backends() -> tuple[str, ...]: + """Return a tuple of the names of all built-in backends.""" + return BACKENDS + + +def get_cancelled_exc_class() -> type[BaseException]: + """Return the current async library's cancellation exception class.""" + return get_asynclib().CancelledError + + +# +# Private API +# + + +@contextmanager +def claim_worker_thread(backend: str) -> Generator[Any, None, None]: + module = sys.modules["anyio._backends._" + backend] + threadlocals.current_async_module = module + try: + yield + finally: + del threadlocals.current_async_module + + +def get_asynclib(asynclib_name: str | None = None) -> Any: + if asynclib_name is None: + asynclib_name = sniffio.current_async_library() + + modulename = "anyio._backends._" + asynclib_name + try: + return sys.modules[modulename] + except KeyError: + return import_module(modulename) diff --git a/contrib/python/anyio/anyio/_core/_exceptions.py b/contrib/python/anyio/anyio/_core/_exceptions.py new file mode 100644 index 0000000000..92ccd77a2d --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_exceptions.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from traceback import format_exception + + +class BrokenResourceError(Exception): + """ + Raised when trying to use a resource that has been rendered unusable due to external causes + (e.g. a send stream whose peer has disconnected). + """ + + +class BrokenWorkerProcess(Exception): + """ + Raised by :func:`run_sync_in_process` if the worker process terminates abruptly or otherwise + misbehaves. + """ + + +class BusyResourceError(Exception): + """Raised when two tasks are trying to read from or write to the same resource concurrently.""" + + def __init__(self, action: str): + super().__init__(f"Another task is already {action} this resource") + + +class ClosedResourceError(Exception): + """Raised when trying to use a resource that has been closed.""" + + +class DelimiterNotFound(Exception): + """ + Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the + maximum number of bytes has been read without the delimiter being found. + """ + + def __init__(self, max_bytes: int) -> None: + super().__init__( + f"The delimiter was not found among the first {max_bytes} bytes" + ) + + +class EndOfStream(Exception): + """Raised when trying to read from a stream that has been closed from the other end.""" + + +class ExceptionGroup(BaseException): + """ + Raised when multiple exceptions have been raised in a task group. + + :var ~typing.Sequence[BaseException] exceptions: the sequence of exceptions raised together + """ + + SEPARATOR = "----------------------------\n" + + exceptions: list[BaseException] + + def __str__(self) -> str: + tracebacks = [ + "".join(format_exception(type(exc), exc, exc.__traceback__)) + for exc in self.exceptions + ] + return ( + f"{len(self.exceptions)} exceptions were raised in the task group:\n" + f"{self.SEPARATOR}{self.SEPARATOR.join(tracebacks)}" + ) + + def __repr__(self) -> str: + exception_reprs = ", ".join(repr(exc) for exc in self.exceptions) + return f"<{self.__class__.__name__}: {exception_reprs}>" + + +class IncompleteRead(Exception): + """ + Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_exactly` or + :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the + connection is closed before the requested amount of bytes has been read. + """ + + def __init__(self) -> None: + super().__init__( + "The stream was closed before the read operation could be completed" + ) + + +class TypedAttributeLookupError(LookupError): + """ + Raised by :meth:`~anyio.TypedAttributeProvider.extra` when the given typed attribute is not + found and no default value has been given. + """ + + +class WouldBlock(Exception): + """Raised by ``X_nowait`` functions if ``X()`` would block.""" diff --git a/contrib/python/anyio/anyio/_core/_fileio.py b/contrib/python/anyio/anyio/_core/_fileio.py new file mode 100644 index 0000000000..35e8e8af6c --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_fileio.py @@ -0,0 +1,603 @@ +from __future__ import annotations + +import os +import pathlib +import sys +from dataclasses import dataclass +from functools import partial +from os import PathLike +from typing import ( + IO, + TYPE_CHECKING, + Any, + AnyStr, + AsyncIterator, + Callable, + Generic, + Iterable, + Iterator, + Sequence, + cast, + overload, +) + +from .. import to_thread +from ..abc import AsyncResource + +if sys.version_info >= (3, 8): + from typing import Final +else: + from typing_extensions import Final + +if TYPE_CHECKING: + from _typeshed import OpenBinaryMode, OpenTextMode, ReadableBuffer, WriteableBuffer +else: + ReadableBuffer = OpenBinaryMode = OpenTextMode = WriteableBuffer = object + + +class AsyncFile(AsyncResource, Generic[AnyStr]): + """ + An asynchronous file object. + + This class wraps a standard file object and provides async friendly versions of the following + blocking methods (where available on the original file object): + + * read + * read1 + * readline + * readlines + * readinto + * readinto1 + * write + * writelines + * truncate + * seek + * tell + * flush + + All other methods are directly passed through. + + This class supports the asynchronous context manager protocol which closes the underlying file + at the end of the context block. + + This class also supports asynchronous iteration:: + + async with await open_file(...) as f: + async for line in f: + print(line) + """ + + def __init__(self, fp: IO[AnyStr]) -> None: + self._fp: Any = fp + + def __getattr__(self, name: str) -> object: + return getattr(self._fp, name) + + @property + def wrapped(self) -> IO[AnyStr]: + """The wrapped file object.""" + return self._fp + + async def __aiter__(self) -> AsyncIterator[AnyStr]: + while True: + line = await self.readline() + if line: + yield line + else: + break + + async def aclose(self) -> None: + return await to_thread.run_sync(self._fp.close) + + async def read(self, size: int = -1) -> AnyStr: + return await to_thread.run_sync(self._fp.read, size) + + async def read1(self: AsyncFile[bytes], size: int = -1) -> bytes: + return await to_thread.run_sync(self._fp.read1, size) + + async def readline(self) -> AnyStr: + return await to_thread.run_sync(self._fp.readline) + + async def readlines(self) -> list[AnyStr]: + return await to_thread.run_sync(self._fp.readlines) + + async def readinto(self: AsyncFile[bytes], b: WriteableBuffer) -> bytes: + return await to_thread.run_sync(self._fp.readinto, b) + + async def readinto1(self: AsyncFile[bytes], b: WriteableBuffer) -> bytes: + return await to_thread.run_sync(self._fp.readinto1, b) + + @overload + async def write(self: AsyncFile[bytes], b: ReadableBuffer) -> int: + ... + + @overload + async def write(self: AsyncFile[str], b: str) -> int: + ... + + async def write(self, b: ReadableBuffer | str) -> int: + return await to_thread.run_sync(self._fp.write, b) + + @overload + async def writelines( + self: AsyncFile[bytes], lines: Iterable[ReadableBuffer] + ) -> None: + ... + + @overload + async def writelines(self: AsyncFile[str], lines: Iterable[str]) -> None: + ... + + async def writelines(self, lines: Iterable[ReadableBuffer] | Iterable[str]) -> None: + return await to_thread.run_sync(self._fp.writelines, lines) + + async def truncate(self, size: int | None = None) -> int: + return await to_thread.run_sync(self._fp.truncate, size) + + async def seek(self, offset: int, whence: int | None = os.SEEK_SET) -> int: + return await to_thread.run_sync(self._fp.seek, offset, whence) + + async def tell(self) -> int: + return await to_thread.run_sync(self._fp.tell) + + async def flush(self) -> None: + return await to_thread.run_sync(self._fp.flush) + + +@overload +async def open_file( + file: str | PathLike[str] | int, + mode: OpenBinaryMode, + buffering: int = ..., + encoding: str | None = ..., + errors: str | None = ..., + newline: str | None = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] | None = ..., +) -> AsyncFile[bytes]: + ... + + +@overload +async def open_file( + file: str | PathLike[str] | int, + mode: OpenTextMode = ..., + buffering: int = ..., + encoding: str | None = ..., + errors: str | None = ..., + newline: str | None = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] | None = ..., +) -> AsyncFile[str]: + ... + + +async def open_file( + file: str | PathLike[str] | int, + mode: str = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: Callable[[str, int], int] | None = None, +) -> AsyncFile[Any]: + """ + Open a file asynchronously. + + The arguments are exactly the same as for the builtin :func:`open`. + + :return: an asynchronous file object + + """ + fp = await to_thread.run_sync( + open, file, mode, buffering, encoding, errors, newline, closefd, opener + ) + return AsyncFile(fp) + + +def wrap_file(file: IO[AnyStr]) -> AsyncFile[AnyStr]: + """ + Wrap an existing file as an asynchronous file. + + :param file: an existing file-like object + :return: an asynchronous file object + + """ + return AsyncFile(file) + + +@dataclass(eq=False) +class _PathIterator(AsyncIterator["Path"]): + iterator: Iterator[PathLike[str]] + + async def __anext__(self) -> Path: + nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True) + if nextval is None: + raise StopAsyncIteration from None + + return Path(cast("PathLike[str]", nextval)) + + +class Path: + """ + An asynchronous version of :class:`pathlib.Path`. + + This class cannot be substituted for :class:`pathlib.Path` or :class:`pathlib.PurePath`, but + it is compatible with the :class:`os.PathLike` interface. + + It implements the Python 3.10 version of :class:`pathlib.Path` interface, except for the + deprecated :meth:`~pathlib.Path.link_to` method. + + Any methods that do disk I/O need to be awaited on. These methods are: + + * :meth:`~pathlib.Path.absolute` + * :meth:`~pathlib.Path.chmod` + * :meth:`~pathlib.Path.cwd` + * :meth:`~pathlib.Path.exists` + * :meth:`~pathlib.Path.expanduser` + * :meth:`~pathlib.Path.group` + * :meth:`~pathlib.Path.hardlink_to` + * :meth:`~pathlib.Path.home` + * :meth:`~pathlib.Path.is_block_device` + * :meth:`~pathlib.Path.is_char_device` + * :meth:`~pathlib.Path.is_dir` + * :meth:`~pathlib.Path.is_fifo` + * :meth:`~pathlib.Path.is_file` + * :meth:`~pathlib.Path.is_mount` + * :meth:`~pathlib.Path.lchmod` + * :meth:`~pathlib.Path.lstat` + * :meth:`~pathlib.Path.mkdir` + * :meth:`~pathlib.Path.open` + * :meth:`~pathlib.Path.owner` + * :meth:`~pathlib.Path.read_bytes` + * :meth:`~pathlib.Path.read_text` + * :meth:`~pathlib.Path.readlink` + * :meth:`~pathlib.Path.rename` + * :meth:`~pathlib.Path.replace` + * :meth:`~pathlib.Path.rmdir` + * :meth:`~pathlib.Path.samefile` + * :meth:`~pathlib.Path.stat` + * :meth:`~pathlib.Path.touch` + * :meth:`~pathlib.Path.unlink` + * :meth:`~pathlib.Path.write_bytes` + * :meth:`~pathlib.Path.write_text` + + Additionally, the following methods return an async iterator yielding :class:`~.Path` objects: + + * :meth:`~pathlib.Path.glob` + * :meth:`~pathlib.Path.iterdir` + * :meth:`~pathlib.Path.rglob` + """ + + __slots__ = "_path", "__weakref__" + + __weakref__: Any + + def __init__(self, *args: str | PathLike[str]) -> None: + self._path: Final[pathlib.Path] = pathlib.Path(*args) + + def __fspath__(self) -> str: + return self._path.__fspath__() + + def __str__(self) -> str: + return self._path.__str__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.as_posix()!r})" + + def __bytes__(self) -> bytes: + return self._path.__bytes__() + + def __hash__(self) -> int: + return self._path.__hash__() + + def __eq__(self, other: object) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__eq__(target) + + def __lt__(self, other: Path) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__lt__(target) + + def __le__(self, other: Path) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__le__(target) + + def __gt__(self, other: Path) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__gt__(target) + + def __ge__(self, other: Path) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__ge__(target) + + def __truediv__(self, other: Any) -> Path: + return Path(self._path / other) + + def __rtruediv__(self, other: Any) -> Path: + return Path(other) / self + + @property + def parts(self) -> tuple[str, ...]: + return self._path.parts + + @property + def drive(self) -> str: + return self._path.drive + + @property + def root(self) -> str: + return self._path.root + + @property + def anchor(self) -> str: + return self._path.anchor + + @property + def parents(self) -> Sequence[Path]: + return tuple(Path(p) for p in self._path.parents) + + @property + def parent(self) -> Path: + return Path(self._path.parent) + + @property + def name(self) -> str: + return self._path.name + + @property + def suffix(self) -> str: + return self._path.suffix + + @property + def suffixes(self) -> list[str]: + return self._path.suffixes + + @property + def stem(self) -> str: + return self._path.stem + + async def absolute(self) -> Path: + path = await to_thread.run_sync(self._path.absolute) + return Path(path) + + def as_posix(self) -> str: + return self._path.as_posix() + + def as_uri(self) -> str: + return self._path.as_uri() + + def match(self, path_pattern: str) -> bool: + return self._path.match(path_pattern) + + def is_relative_to(self, *other: str | PathLike[str]) -> bool: + try: + self.relative_to(*other) + return True + except ValueError: + return False + + async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None: + func = partial(os.chmod, follow_symlinks=follow_symlinks) + return await to_thread.run_sync(func, self._path, mode) + + @classmethod + async def cwd(cls) -> Path: + path = await to_thread.run_sync(pathlib.Path.cwd) + return cls(path) + + async def exists(self) -> bool: + return await to_thread.run_sync(self._path.exists, cancellable=True) + + async def expanduser(self) -> Path: + return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True)) + + def glob(self, pattern: str) -> AsyncIterator[Path]: + gen = self._path.glob(pattern) + return _PathIterator(gen) + + async def group(self) -> str: + return await to_thread.run_sync(self._path.group, cancellable=True) + + async def hardlink_to(self, target: str | pathlib.Path | Path) -> None: + if isinstance(target, Path): + target = target._path + + await to_thread.run_sync(os.link, target, self) + + @classmethod + async def home(cls) -> Path: + home_path = await to_thread.run_sync(pathlib.Path.home) + return cls(home_path) + + def is_absolute(self) -> bool: + return self._path.is_absolute() + + async def is_block_device(self) -> bool: + return await to_thread.run_sync(self._path.is_block_device, cancellable=True) + + async def is_char_device(self) -> bool: + return await to_thread.run_sync(self._path.is_char_device, cancellable=True) + + async def is_dir(self) -> bool: + return await to_thread.run_sync(self._path.is_dir, cancellable=True) + + async def is_fifo(self) -> bool: + return await to_thread.run_sync(self._path.is_fifo, cancellable=True) + + async def is_file(self) -> bool: + return await to_thread.run_sync(self._path.is_file, cancellable=True) + + async def is_mount(self) -> bool: + return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True) + + def is_reserved(self) -> bool: + return self._path.is_reserved() + + async def is_socket(self) -> bool: + return await to_thread.run_sync(self._path.is_socket, cancellable=True) + + async def is_symlink(self) -> bool: + return await to_thread.run_sync(self._path.is_symlink, cancellable=True) + + def iterdir(self) -> AsyncIterator[Path]: + gen = self._path.iterdir() + return _PathIterator(gen) + + def joinpath(self, *args: str | PathLike[str]) -> Path: + return Path(self._path.joinpath(*args)) + + async def lchmod(self, mode: int) -> None: + await to_thread.run_sync(self._path.lchmod, mode) + + async def lstat(self) -> os.stat_result: + return await to_thread.run_sync(self._path.lstat, cancellable=True) + + async def mkdir( + self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False + ) -> None: + await to_thread.run_sync(self._path.mkdir, mode, parents, exist_ok) + + @overload + async def open( + self, + mode: OpenBinaryMode, + buffering: int = ..., + encoding: str | None = ..., + errors: str | None = ..., + newline: str | None = ..., + ) -> AsyncFile[bytes]: + ... + + @overload + async def open( + self, + mode: OpenTextMode = ..., + buffering: int = ..., + encoding: str | None = ..., + errors: str | None = ..., + newline: str | None = ..., + ) -> AsyncFile[str]: + ... + + async def open( + self, + mode: str = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> AsyncFile[Any]: + fp = await to_thread.run_sync( + self._path.open, mode, buffering, encoding, errors, newline + ) + return AsyncFile(fp) + + async def owner(self) -> str: + return await to_thread.run_sync(self._path.owner, cancellable=True) + + async def read_bytes(self) -> bytes: + return await to_thread.run_sync(self._path.read_bytes) + + async def read_text( + self, encoding: str | None = None, errors: str | None = None + ) -> str: + return await to_thread.run_sync(self._path.read_text, encoding, errors) + + def relative_to(self, *other: str | PathLike[str]) -> Path: + return Path(self._path.relative_to(*other)) + + async def readlink(self) -> Path: + target = await to_thread.run_sync(os.readlink, self._path) + return Path(cast(str, target)) + + async def rename(self, target: str | pathlib.PurePath | Path) -> Path: + if isinstance(target, Path): + target = target._path + + await to_thread.run_sync(self._path.rename, target) + return Path(target) + + async def replace(self, target: str | pathlib.PurePath | Path) -> Path: + if isinstance(target, Path): + target = target._path + + await to_thread.run_sync(self._path.replace, target) + return Path(target) + + async def resolve(self, strict: bool = False) -> Path: + func = partial(self._path.resolve, strict=strict) + return Path(await to_thread.run_sync(func, cancellable=True)) + + def rglob(self, pattern: str) -> AsyncIterator[Path]: + gen = self._path.rglob(pattern) + return _PathIterator(gen) + + async def rmdir(self) -> None: + await to_thread.run_sync(self._path.rmdir) + + async def samefile( + self, other_path: str | bytes | int | pathlib.Path | Path + ) -> bool: + if isinstance(other_path, Path): + other_path = other_path._path + + return await to_thread.run_sync( + self._path.samefile, other_path, cancellable=True + ) + + async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result: + func = partial(os.stat, follow_symlinks=follow_symlinks) + return await to_thread.run_sync(func, self._path, cancellable=True) + + async def symlink_to( + self, + target: str | pathlib.Path | Path, + target_is_directory: bool = False, + ) -> None: + if isinstance(target, Path): + target = target._path + + await to_thread.run_sync(self._path.symlink_to, target, target_is_directory) + + async def touch(self, mode: int = 0o666, exist_ok: bool = True) -> None: + await to_thread.run_sync(self._path.touch, mode, exist_ok) + + async def unlink(self, missing_ok: bool = False) -> None: + try: + await to_thread.run_sync(self._path.unlink) + except FileNotFoundError: + if not missing_ok: + raise + + def with_name(self, name: str) -> Path: + return Path(self._path.with_name(name)) + + def with_stem(self, stem: str) -> Path: + return Path(self._path.with_name(stem + self._path.suffix)) + + def with_suffix(self, suffix: str) -> Path: + return Path(self._path.with_suffix(suffix)) + + async def write_bytes(self, data: bytes) -> int: + return await to_thread.run_sync(self._path.write_bytes, data) + + async def write_text( + self, + data: str, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> int: + # Path.write_text() does not support the "newline" parameter before Python 3.10 + def sync_write_text() -> int: + with self._path.open( + "w", encoding=encoding, errors=errors, newline=newline + ) as fp: + return fp.write(data) + + return await to_thread.run_sync(sync_write_text) + + +PathLike.register(Path) diff --git a/contrib/python/anyio/anyio/_core/_resources.py b/contrib/python/anyio/anyio/_core/_resources.py new file mode 100644 index 0000000000..b9a5344aef --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_resources.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ..abc import AsyncResource +from ._tasks import CancelScope + + +async def aclose_forcefully(resource: AsyncResource) -> None: + """ + Close an asynchronous resource in a cancelled scope. + + Doing this closes the resource without waiting on anything. + + :param resource: the resource to close + + """ + with CancelScope() as scope: + scope.cancel() + await resource.aclose() diff --git a/contrib/python/anyio/anyio/_core/_signals.py b/contrib/python/anyio/anyio/_core/_signals.py new file mode 100644 index 0000000000..8ea54af86c --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_signals.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import AsyncIterator + +from ._compat import DeprecatedAsyncContextManager +from ._eventloop import get_asynclib + + +def open_signal_receiver( + *signals: int, +) -> DeprecatedAsyncContextManager[AsyncIterator[int]]: + """ + Start receiving operating system signals. + + :param signals: signals to receive (e.g. ``signal.SIGINT``) + :return: an asynchronous context manager for an asynchronous iterator which yields signal + numbers + + .. warning:: Windows does not support signals natively so it is best to avoid relying on this + in cross-platform applications. + + .. warning:: On asyncio, this permanently replaces any previous signal handler for the given + signals, as set via :meth:`~asyncio.loop.add_signal_handler`. + + """ + return get_asynclib().open_signal_receiver(*signals) diff --git a/contrib/python/anyio/anyio/_core/_sockets.py b/contrib/python/anyio/anyio/_core/_sockets.py new file mode 100644 index 0000000000..e6970bee27 --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_sockets.py @@ -0,0 +1,607 @@ +from __future__ import annotations + +import socket +import ssl +import sys +from ipaddress import IPv6Address, ip_address +from os import PathLike, chmod +from pathlib import Path +from socket import AddressFamily, SocketKind +from typing import Awaitable, List, Tuple, cast, overload + +from .. import to_thread +from ..abc import ( + ConnectedUDPSocket, + IPAddressType, + IPSockAddrType, + SocketListener, + SocketStream, + UDPSocket, + UNIXSocketStream, +) +from ..streams.stapled import MultiListener +from ..streams.tls import TLSStream +from ._eventloop import get_asynclib +from ._resources import aclose_forcefully +from ._synchronization import Event +from ._tasks import create_task_group, move_on_after + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41) # https://bugs.python.org/issue29515 + +GetAddrInfoReturnType = List[ + Tuple[AddressFamily, SocketKind, int, str, Tuple[str, int]] +] +AnyIPAddressFamily = Literal[ + AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6 +] +IPAddressFamily = Literal[AddressFamily.AF_INET, AddressFamily.AF_INET6] + + +# tls_hostname given +@overload +async def connect_tcp( + remote_host: IPAddressType, + remote_port: int, + *, + local_host: IPAddressType | None = ..., + ssl_context: ssl.SSLContext | None = ..., + tls_standard_compatible: bool = ..., + tls_hostname: str, + happy_eyeballs_delay: float = ..., +) -> TLSStream: + ... + + +# ssl_context given +@overload +async def connect_tcp( + remote_host: IPAddressType, + remote_port: int, + *, + local_host: IPAddressType | None = ..., + ssl_context: ssl.SSLContext, + tls_standard_compatible: bool = ..., + tls_hostname: str | None = ..., + happy_eyeballs_delay: float = ..., +) -> TLSStream: + ... + + +# tls=True +@overload +async def connect_tcp( + remote_host: IPAddressType, + remote_port: int, + *, + local_host: IPAddressType | None = ..., + tls: Literal[True], + ssl_context: ssl.SSLContext | None = ..., + tls_standard_compatible: bool = ..., + tls_hostname: str | None = ..., + happy_eyeballs_delay: float = ..., +) -> TLSStream: + ... + + +# tls=False +@overload +async def connect_tcp( + remote_host: IPAddressType, + remote_port: int, + *, + local_host: IPAddressType | None = ..., + tls: Literal[False], + ssl_context: ssl.SSLContext | None = ..., + tls_standard_compatible: bool = ..., + tls_hostname: str | None = ..., + happy_eyeballs_delay: float = ..., +) -> SocketStream: + ... + + +# No TLS arguments +@overload +async def connect_tcp( + remote_host: IPAddressType, + remote_port: int, + *, + local_host: IPAddressType | None = ..., + happy_eyeballs_delay: float = ..., +) -> SocketStream: + ... + + +async def connect_tcp( + remote_host: IPAddressType, + remote_port: int, + *, + local_host: IPAddressType | None = None, + tls: bool = False, + ssl_context: ssl.SSLContext | None = None, + tls_standard_compatible: bool = True, + tls_hostname: str | None = None, + happy_eyeballs_delay: float = 0.25, +) -> SocketStream | TLSStream: + """ + Connect to a host using the TCP protocol. + + This function implements the stateless version of the Happy Eyeballs algorithm (RFC + 6555). If ``remote_host`` is a host name that resolves to multiple IP addresses, + each one is tried until one connection attempt succeeds. If the first attempt does + not connected within 250 milliseconds, a second attempt is started using the next + address in the list, and so on. On IPv6 enabled systems, an IPv6 address (if + available) is tried first. + + When the connection has been established, a TLS handshake will be done if either + ``ssl_context`` or ``tls_hostname`` is not ``None``, or if ``tls`` is ``True``. + + :param remote_host: the IP address or host name to connect to + :param remote_port: port on the target host to connect to + :param local_host: the interface address or name to bind the socket to before connecting + :param tls: ``True`` to do a TLS handshake with the connected stream and return a + :class:`~anyio.streams.tls.TLSStream` instead + :param ssl_context: the SSL context object to use (if omitted, a default context is created) + :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake before closing + the stream and requires that the server does this as well. Otherwise, + :exc:`~ssl.SSLEOFError` may be raised during reads from the stream. + Some protocols, such as HTTP, require this option to be ``False``. + See :meth:`~ssl.SSLContext.wrap_socket` for details. + :param tls_hostname: host name to check the server certificate against (defaults to the value + of ``remote_host``) + :param happy_eyeballs_delay: delay (in seconds) before starting the next connection attempt + :return: a socket stream object if no TLS handshake was done, otherwise a TLS stream + :raises OSError: if the connection attempt fails + + """ + # Placed here due to https://github.com/python/mypy/issues/7057 + connected_stream: SocketStream | None = None + + async def try_connect(remote_host: str, event: Event) -> None: + nonlocal connected_stream + try: + stream = await asynclib.connect_tcp(remote_host, remote_port, local_address) + except OSError as exc: + oserrors.append(exc) + return + else: + if connected_stream is None: + connected_stream = stream + tg.cancel_scope.cancel() + else: + await stream.aclose() + finally: + event.set() + + asynclib = get_asynclib() + local_address: IPSockAddrType | None = None + family = socket.AF_UNSPEC + if local_host: + gai_res = await getaddrinfo(str(local_host), None) + family, *_, local_address = gai_res[0] + + target_host = str(remote_host) + try: + addr_obj = ip_address(remote_host) + except ValueError: + # getaddrinfo() will raise an exception if name resolution fails + gai_res = await getaddrinfo( + target_host, remote_port, family=family, type=socket.SOCK_STREAM + ) + + # Organize the list so that the first address is an IPv6 address (if available) and the + # second one is an IPv4 addresses. The rest can be in whatever order. + v6_found = v4_found = False + target_addrs: list[tuple[socket.AddressFamily, str]] = [] + for af, *rest, sa in gai_res: + if af == socket.AF_INET6 and not v6_found: + v6_found = True + target_addrs.insert(0, (af, sa[0])) + elif af == socket.AF_INET and not v4_found and v6_found: + v4_found = True + target_addrs.insert(1, (af, sa[0])) + else: + target_addrs.append((af, sa[0])) + else: + if isinstance(addr_obj, IPv6Address): + target_addrs = [(socket.AF_INET6, addr_obj.compressed)] + else: + target_addrs = [(socket.AF_INET, addr_obj.compressed)] + + oserrors: list[OSError] = [] + async with create_task_group() as tg: + for i, (af, addr) in enumerate(target_addrs): + event = Event() + tg.start_soon(try_connect, addr, event) + with move_on_after(happy_eyeballs_delay): + await event.wait() + + if connected_stream is None: + cause = oserrors[0] if len(oserrors) == 1 else asynclib.ExceptionGroup(oserrors) + raise OSError("All connection attempts failed") from cause + + if tls or tls_hostname or ssl_context: + try: + return await TLSStream.wrap( + connected_stream, + server_side=False, + hostname=tls_hostname or str(remote_host), + ssl_context=ssl_context, + standard_compatible=tls_standard_compatible, + ) + except BaseException: + await aclose_forcefully(connected_stream) + raise + + return connected_stream + + +async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream: + """ + Connect to the given UNIX socket. + + Not available on Windows. + + :param path: path to the socket + :return: a socket stream object + + """ + path = str(Path(path)) + return await get_asynclib().connect_unix(path) + + +async def create_tcp_listener( + *, + local_host: IPAddressType | None = None, + local_port: int = 0, + family: AnyIPAddressFamily = socket.AddressFamily.AF_UNSPEC, + backlog: int = 65536, + reuse_port: bool = False, +) -> MultiListener[SocketStream]: + """ + Create a TCP socket listener. + + :param local_port: port number to listen on + :param local_host: IP address of the interface to listen on. If omitted, listen on + all IPv4 and IPv6 interfaces. To listen on all interfaces on a specific address + family, use ``0.0.0.0`` for IPv4 or ``::`` for IPv6. + :param family: address family (used if ``local_host`` was omitted) + :param backlog: maximum number of queued incoming connections (up to a maximum of + 2**16, or 65536) + :param reuse_port: ``True`` to allow multiple sockets to bind to the same + address/port (not supported on Windows) + :return: a list of listener objects + + """ + asynclib = get_asynclib() + backlog = min(backlog, 65536) + local_host = str(local_host) if local_host is not None else None + gai_res = await getaddrinfo( + local_host, # type: ignore[arg-type] + local_port, + family=family, + type=socket.SocketKind.SOCK_STREAM if sys.platform == "win32" else 0, + flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG, + ) + listeners: list[SocketListener] = [] + try: + # The set() is here to work around a glibc bug: + # https://sourceware.org/bugzilla/show_bug.cgi?id=14969 + sockaddr: tuple[str, int] | tuple[str, int, int, int] + for fam, kind, *_, sockaddr in sorted(set(gai_res)): + # Workaround for an uvloop bug where we don't get the correct scope ID for + # IPv6 link-local addresses when passing type=socket.SOCK_STREAM to + # getaddrinfo(): https://github.com/MagicStack/uvloop/issues/539 + if sys.platform != "win32" and kind is not SocketKind.SOCK_STREAM: + continue + + raw_socket = socket.socket(fam) + raw_socket.setblocking(False) + + # For Windows, enable exclusive address use. For others, enable address reuse. + if sys.platform == "win32": + raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + else: + raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + if reuse_port: + raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + # If only IPv6 was requested, disable dual stack operation + if fam == socket.AF_INET6: + raw_socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) + + # Workaround for #554 + if "%" in sockaddr[0]: + addr, scope_id = sockaddr[0].split("%", 1) + sockaddr = (addr, sockaddr[1], 0, int(scope_id)) + + raw_socket.bind(sockaddr) + raw_socket.listen(backlog) + listener = asynclib.TCPSocketListener(raw_socket) + listeners.append(listener) + except BaseException: + for listener in listeners: + await listener.aclose() + + raise + + return MultiListener(listeners) + + +async def create_unix_listener( + path: str | PathLike[str], + *, + mode: int | None = None, + backlog: int = 65536, +) -> SocketListener: + """ + Create a UNIX socket listener. + + Not available on Windows. + + :param path: path of the socket + :param mode: permissions to set on the socket + :param backlog: maximum number of queued incoming connections (up to a maximum of 2**16, or + 65536) + :return: a listener object + + .. versionchanged:: 3.0 + If a socket already exists on the file system in the given path, it will be removed first. + + """ + path_str = str(path) + path = Path(path) + if path.is_socket(): + path.unlink() + + backlog = min(backlog, 65536) + raw_socket = socket.socket(socket.AF_UNIX) + raw_socket.setblocking(False) + try: + await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True) + if mode is not None: + await to_thread.run_sync(chmod, path_str, mode, cancellable=True) + + raw_socket.listen(backlog) + return get_asynclib().UNIXSocketListener(raw_socket) + except BaseException: + raw_socket.close() + raise + + +async def create_udp_socket( + family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC, + *, + local_host: IPAddressType | None = None, + local_port: int = 0, + reuse_port: bool = False, +) -> UDPSocket: + """ + Create a UDP socket. + + If ``local_port`` has been given, the socket will be bound to this port on the local + machine, making this socket suitable for providing UDP based services. + + :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from + ``local_host`` if omitted + :param local_host: IP address or host name of the local interface to bind to + :param local_port: local port to bind to + :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port + (not supported on Windows) + :return: a UDP socket + + """ + if family is AddressFamily.AF_UNSPEC and not local_host: + raise ValueError('Either "family" or "local_host" must be given') + + if local_host: + gai_res = await getaddrinfo( + str(local_host), + local_port, + family=family, + type=socket.SOCK_DGRAM, + flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG, + ) + family = cast(AnyIPAddressFamily, gai_res[0][0]) + local_address = gai_res[0][-1] + elif family is AddressFamily.AF_INET6: + local_address = ("::", 0) + else: + local_address = ("0.0.0.0", 0) + + return await get_asynclib().create_udp_socket( + family, local_address, None, reuse_port + ) + + +async def create_connected_udp_socket( + remote_host: IPAddressType, + remote_port: int, + *, + family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC, + local_host: IPAddressType | None = None, + local_port: int = 0, + reuse_port: bool = False, +) -> ConnectedUDPSocket: + """ + Create a connected UDP socket. + + Connected UDP sockets can only communicate with the specified remote host/port, and any packets + sent from other sources are dropped. + + :param remote_host: remote host to set as the default target + :param remote_port: port on the remote host to set as the default target + :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from + ``local_host`` or ``remote_host`` if omitted + :param local_host: IP address or host name of the local interface to bind to + :param local_port: local port to bind to + :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port + (not supported on Windows) + :return: a connected UDP socket + + """ + local_address = None + if local_host: + gai_res = await getaddrinfo( + str(local_host), + local_port, + family=family, + type=socket.SOCK_DGRAM, + flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG, + ) + family = cast(AnyIPAddressFamily, gai_res[0][0]) + local_address = gai_res[0][-1] + + gai_res = await getaddrinfo( + str(remote_host), remote_port, family=family, type=socket.SOCK_DGRAM + ) + family = cast(AnyIPAddressFamily, gai_res[0][0]) + remote_address = gai_res[0][-1] + + return await get_asynclib().create_udp_socket( + family, local_address, remote_address, reuse_port + ) + + +async def getaddrinfo( + host: bytearray | bytes | str, + port: str | int | None, + *, + family: int | AddressFamily = 0, + type: int | SocketKind = 0, + proto: int = 0, + flags: int = 0, +) -> GetAddrInfoReturnType: + """ + Look up a numeric IP address given a host name. + + Internationalized domain names are translated according to the (non-transitional) IDNA 2008 + standard. + + .. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of + (host, port), unlike what :func:`socket.getaddrinfo` does. + + :param host: host name + :param port: port number + :param family: socket family (`'AF_INET``, ...) + :param type: socket type (``SOCK_STREAM``, ...) + :param proto: protocol number + :param flags: flags to pass to upstream ``getaddrinfo()`` + :return: list of tuples containing (family, type, proto, canonname, sockaddr) + + .. seealso:: :func:`socket.getaddrinfo` + + """ + # Handle unicode hostnames + if isinstance(host, str): + try: + encoded_host = host.encode("ascii") + except UnicodeEncodeError: + import idna + + encoded_host = idna.encode(host, uts46=True) + else: + encoded_host = host + + gai_res = await get_asynclib().getaddrinfo( + encoded_host, port, family=family, type=type, proto=proto, flags=flags + ) + return [ + (family, type, proto, canonname, convert_ipv6_sockaddr(sockaddr)) + for family, type, proto, canonname, sockaddr in gai_res + ] + + +def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str, str]]: + """ + Look up the host name of an IP address. + + :param sockaddr: socket address (e.g. (ipaddress, port) for IPv4) + :param flags: flags to pass to upstream ``getnameinfo()`` + :return: a tuple of (host name, service name) + + .. seealso:: :func:`socket.getnameinfo` + + """ + return get_asynclib().getnameinfo(sockaddr, flags) + + +def wait_socket_readable(sock: socket.socket) -> Awaitable[None]: + """ + Wait until the given socket has data to be read. + + This does **NOT** work on Windows when using the asyncio backend with a proactor event loop + (default on py3.8+). + + .. warning:: Only use this on raw sockets that have not been wrapped by any higher level + constructs like socket streams! + + :param sock: a socket object + :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the + socket to become readable + :raises ~anyio.BusyResourceError: if another task is already waiting for the socket + to become readable + + """ + return get_asynclib().wait_socket_readable(sock) + + +def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: + """ + Wait until the given socket can be written to. + + This does **NOT** work on Windows when using the asyncio backend with a proactor event loop + (default on py3.8+). + + .. warning:: Only use this on raw sockets that have not been wrapped by any higher level + constructs like socket streams! + + :param sock: a socket object + :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the + socket to become writable + :raises ~anyio.BusyResourceError: if another task is already waiting for the socket + to become writable + + """ + return get_asynclib().wait_socket_writable(sock) + + +# +# Private API +# + + +def convert_ipv6_sockaddr( + sockaddr: tuple[str, int, int, int] | tuple[str, int] +) -> tuple[str, int]: + """ + Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format. + + If the scope ID is nonzero, it is added to the address, separated with ``%``. + Otherwise the flow id and scope id are simply cut off from the tuple. + Any other kinds of socket addresses are returned as-is. + + :param sockaddr: the result of :meth:`~socket.socket.getsockname` + :return: the converted socket address + + """ + # This is more complicated than it should be because of MyPy + if isinstance(sockaddr, tuple) and len(sockaddr) == 4: + host, port, flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr) + if scope_id: + # PyPy (as of v7.3.11) leaves the interface name in the result, so + # we discard it and only get the scope ID from the end + # (https://foss.heptapod.net/pypy/pypy/-/issues/3938) + host = host.split("%")[0] + + # Add scope_id to the address + return f"{host}%{scope_id}", port + else: + return host, port + else: + return cast(Tuple[str, int], sockaddr) diff --git a/contrib/python/anyio/anyio/_core/_streams.py b/contrib/python/anyio/anyio/_core/_streams.py new file mode 100644 index 0000000000..54ea2b2baf --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_streams.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import math +from typing import Any, TypeVar, overload + +from ..streams.memory import ( + MemoryObjectReceiveStream, + MemoryObjectSendStream, + MemoryObjectStreamState, +) + +T_Item = TypeVar("T_Item") + + +@overload +def create_memory_object_stream( + max_buffer_size: float = ..., +) -> tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: + ... + + +@overload +def create_memory_object_stream( + max_buffer_size: float = ..., item_type: type[T_Item] = ... +) -> tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]]: + ... + + +def create_memory_object_stream( + max_buffer_size: float = 0, item_type: type[T_Item] | None = None +) -> tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: + """ + Create a memory object stream. + + :param max_buffer_size: number of items held in the buffer until ``send()`` starts blocking + :param item_type: type of item, for marking the streams with the right generic type for + static typing (not used at run time) + :return: a tuple of (send stream, receive stream) + + """ + if max_buffer_size != math.inf and not isinstance(max_buffer_size, int): + raise ValueError("max_buffer_size must be either an integer or math.inf") + if max_buffer_size < 0: + raise ValueError("max_buffer_size cannot be negative") + + state: MemoryObjectStreamState = MemoryObjectStreamState(max_buffer_size) + return MemoryObjectSendStream(state), MemoryObjectReceiveStream(state) diff --git a/contrib/python/anyio/anyio/_core/_subprocesses.py b/contrib/python/anyio/anyio/_core/_subprocesses.py new file mode 100644 index 0000000000..1a26ac8c7f --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_subprocesses.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from io import BytesIO +from os import PathLike +from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess +from typing import ( + IO, + Any, + AsyncIterable, + Mapping, + Sequence, + cast, +) + +from ..abc import Process +from ._eventloop import get_asynclib +from ._tasks import create_task_group + + +async def run_process( + command: str | bytes | Sequence[str | bytes], + *, + input: bytes | None = None, + stdout: int | IO[Any] | None = PIPE, + stderr: int | IO[Any] | None = PIPE, + check: bool = True, + cwd: str | bytes | PathLike[str] | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, +) -> CompletedProcess[bytes]: + """ + Run an external command in a subprocess and wait until it completes. + + .. seealso:: :func:`subprocess.run` + + :param command: either a string to pass to the shell, or an iterable of strings containing the + executable name or path and its arguments + :param input: bytes passed to the standard input of the subprocess + :param stdout: either :data:`subprocess.PIPE` or :data:`subprocess.DEVNULL` + :param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL` or + :data:`subprocess.STDOUT` + :param check: if ``True``, raise :exc:`~subprocess.CalledProcessError` if the process + terminates with a return code other than 0 + :param cwd: If not ``None``, change the working directory to this before running the command + :param env: if not ``None``, this mapping replaces the inherited environment variables from the + parent process + :param start_new_session: if ``true`` the setsid() system call will be made in the child + process prior to the execution of the subprocess. (POSIX only) + :return: an object representing the completed process + :raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process exits with a + nonzero return code + + """ + + async def drain_stream(stream: AsyncIterable[bytes], index: int) -> None: + buffer = BytesIO() + async for chunk in stream: + buffer.write(chunk) + + stream_contents[index] = buffer.getvalue() + + async with await open_process( + command, + stdin=PIPE if input else DEVNULL, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) as process: + stream_contents: list[bytes | None] = [None, None] + try: + async with create_task_group() as tg: + if process.stdout: + tg.start_soon(drain_stream, process.stdout, 0) + if process.stderr: + tg.start_soon(drain_stream, process.stderr, 1) + if process.stdin and input: + await process.stdin.send(input) + await process.stdin.aclose() + + await process.wait() + except BaseException: + process.kill() + raise + + output, errors = stream_contents + if check and process.returncode != 0: + raise CalledProcessError(cast(int, process.returncode), command, output, errors) + + return CompletedProcess(command, cast(int, process.returncode), output, errors) + + +async def open_process( + command: str | bytes | Sequence[str | bytes], + *, + stdin: int | IO[Any] | None = PIPE, + stdout: int | IO[Any] | None = PIPE, + stderr: int | IO[Any] | None = PIPE, + cwd: str | bytes | PathLike[str] | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, +) -> Process: + """ + Start an external command in a subprocess. + + .. seealso:: :class:`subprocess.Popen` + + :param command: either a string to pass to the shell, or an iterable of strings containing the + executable name or path and its arguments + :param stdin: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, a + file-like object, or ``None`` + :param stdout: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, + a file-like object, or ``None`` + :param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, + :data:`subprocess.STDOUT`, a file-like object, or ``None`` + :param cwd: If not ``None``, the working directory is changed before executing + :param env: If env is not ``None``, it must be a mapping that defines the environment + variables for the new process + :param start_new_session: if ``true`` the setsid() system call will be made in the child + process prior to the execution of the subprocess. (POSIX only) + :return: an asynchronous process object + + """ + shell = isinstance(command, str) + return await get_asynclib().open_process( + command, + shell=shell, + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) diff --git a/contrib/python/anyio/anyio/_core/_synchronization.py b/contrib/python/anyio/anyio/_core/_synchronization.py new file mode 100644 index 0000000000..783570c7ac --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_synchronization.py @@ -0,0 +1,596 @@ +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from types import TracebackType +from warnings import warn + +from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled +from ._compat import DeprecatedAwaitable +from ._eventloop import get_asynclib +from ._exceptions import BusyResourceError, WouldBlock +from ._tasks import CancelScope +from ._testing import TaskInfo, get_current_task + + +@dataclass(frozen=True) +class EventStatistics: + """ + :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Event.wait` + """ + + tasks_waiting: int + + +@dataclass(frozen=True) +class CapacityLimiterStatistics: + """ + :ivar int borrowed_tokens: number of tokens currently borrowed by tasks + :ivar float total_tokens: total number of available tokens + :ivar tuple borrowers: tasks or other objects currently holding tokens borrowed from this + limiter + :ivar int tasks_waiting: number of tasks waiting on :meth:`~.CapacityLimiter.acquire` or + :meth:`~.CapacityLimiter.acquire_on_behalf_of` + """ + + borrowed_tokens: int + total_tokens: float + borrowers: tuple[object, ...] + tasks_waiting: int + + +@dataclass(frozen=True) +class LockStatistics: + """ + :ivar bool locked: flag indicating if this lock is locked or not + :ivar ~anyio.TaskInfo owner: task currently holding the lock (or ``None`` if the lock is not + held by any task) + :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Lock.acquire` + """ + + locked: bool + owner: TaskInfo | None + tasks_waiting: int + + +@dataclass(frozen=True) +class ConditionStatistics: + """ + :ivar int tasks_waiting: number of tasks blocked on :meth:`~.Condition.wait` + :ivar ~anyio.LockStatistics lock_statistics: statistics of the underlying :class:`~.Lock` + """ + + tasks_waiting: int + lock_statistics: LockStatistics + + +@dataclass(frozen=True) +class SemaphoreStatistics: + """ + :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Semaphore.acquire` + + """ + + tasks_waiting: int + + +class Event: + def __new__(cls) -> Event: + return get_asynclib().Event() + + def set(self) -> DeprecatedAwaitable: + """Set the flag, notifying all listeners.""" + raise NotImplementedError + + def is_set(self) -> bool: + """Return ``True`` if the flag is set, ``False`` if not.""" + raise NotImplementedError + + async def wait(self) -> None: + """ + Wait until the flag has been set. + + If the flag has already been set when this method is called, it returns immediately. + + """ + raise NotImplementedError + + def statistics(self) -> EventStatistics: + """Return statistics about the current state of this event.""" + raise NotImplementedError + + +class Lock: + _owner_task: TaskInfo | None = None + + def __init__(self) -> None: + self._waiters: deque[tuple[TaskInfo, Event]] = deque() + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.release() + + async def acquire(self) -> None: + """Acquire the lock.""" + await checkpoint_if_cancelled() + try: + self.acquire_nowait() + except WouldBlock: + task = get_current_task() + event = Event() + token = task, event + self._waiters.append(token) + try: + await event.wait() + except BaseException: + if not event.is_set(): + self._waiters.remove(token) + elif self._owner_task == task: + self.release() + + raise + + assert self._owner_task == task + else: + try: + await cancel_shielded_checkpoint() + except BaseException: + self.release() + raise + + def acquire_nowait(self) -> None: + """ + Acquire the lock, without blocking. + + :raises ~anyio.WouldBlock: if the operation would block + + """ + task = get_current_task() + if self._owner_task == task: + raise RuntimeError("Attempted to acquire an already held Lock") + + if self._owner_task is not None: + raise WouldBlock + + self._owner_task = task + + def release(self) -> DeprecatedAwaitable: + """Release the lock.""" + if self._owner_task != get_current_task(): + raise RuntimeError("The current task is not holding this lock") + + if self._waiters: + self._owner_task, event = self._waiters.popleft() + event.set() + else: + del self._owner_task + + return DeprecatedAwaitable(self.release) + + def locked(self) -> bool: + """Return True if the lock is currently held.""" + return self._owner_task is not None + + def statistics(self) -> LockStatistics: + """ + Return statistics about the current state of this lock. + + .. versionadded:: 3.0 + """ + return LockStatistics(self.locked(), self._owner_task, len(self._waiters)) + + +class Condition: + _owner_task: TaskInfo | None = None + + def __init__(self, lock: Lock | None = None): + self._lock = lock or Lock() + self._waiters: deque[Event] = deque() + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.release() + + def _check_acquired(self) -> None: + if self._owner_task != get_current_task(): + raise RuntimeError("The current task is not holding the underlying lock") + + async def acquire(self) -> None: + """Acquire the underlying lock.""" + await self._lock.acquire() + self._owner_task = get_current_task() + + def acquire_nowait(self) -> None: + """ + Acquire the underlying lock, without blocking. + + :raises ~anyio.WouldBlock: if the operation would block + + """ + self._lock.acquire_nowait() + self._owner_task = get_current_task() + + def release(self) -> DeprecatedAwaitable: + """Release the underlying lock.""" + self._lock.release() + return DeprecatedAwaitable(self.release) + + def locked(self) -> bool: + """Return True if the lock is set.""" + return self._lock.locked() + + def notify(self, n: int = 1) -> None: + """Notify exactly n listeners.""" + self._check_acquired() + for _ in range(n): + try: + event = self._waiters.popleft() + except IndexError: + break + + event.set() + + def notify_all(self) -> None: + """Notify all the listeners.""" + self._check_acquired() + for event in self._waiters: + event.set() + + self._waiters.clear() + + async def wait(self) -> None: + """Wait for a notification.""" + await checkpoint() + event = Event() + self._waiters.append(event) + self.release() + try: + await event.wait() + except BaseException: + if not event.is_set(): + self._waiters.remove(event) + + raise + finally: + with CancelScope(shield=True): + await self.acquire() + + def statistics(self) -> ConditionStatistics: + """ + Return statistics about the current state of this condition. + + .. versionadded:: 3.0 + """ + return ConditionStatistics(len(self._waiters), self._lock.statistics()) + + +class Semaphore: + def __init__(self, initial_value: int, *, max_value: int | None = None): + if not isinstance(initial_value, int): + raise TypeError("initial_value must be an integer") + if initial_value < 0: + raise ValueError("initial_value must be >= 0") + if max_value is not None: + if not isinstance(max_value, int): + raise TypeError("max_value must be an integer or None") + if max_value < initial_value: + raise ValueError( + "max_value must be equal to or higher than initial_value" + ) + + self._value = initial_value + self._max_value = max_value + self._waiters: deque[Event] = deque() + + async def __aenter__(self) -> Semaphore: + await self.acquire() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.release() + + async def acquire(self) -> None: + """Decrement the semaphore value, blocking if necessary.""" + await checkpoint_if_cancelled() + try: + self.acquire_nowait() + except WouldBlock: + event = Event() + self._waiters.append(event) + try: + await event.wait() + except BaseException: + if not event.is_set(): + self._waiters.remove(event) + else: + self.release() + + raise + else: + try: + await cancel_shielded_checkpoint() + except BaseException: + self.release() + raise + + def acquire_nowait(self) -> None: + """ + Acquire the underlying lock, without blocking. + + :raises ~anyio.WouldBlock: if the operation would block + + """ + if self._value == 0: + raise WouldBlock + + self._value -= 1 + + def release(self) -> DeprecatedAwaitable: + """Increment the semaphore value.""" + if self._max_value is not None and self._value == self._max_value: + raise ValueError("semaphore released too many times") + + if self._waiters: + self._waiters.popleft().set() + else: + self._value += 1 + + return DeprecatedAwaitable(self.release) + + @property + def value(self) -> int: + """The current value of the semaphore.""" + return self._value + + @property + def max_value(self) -> int | None: + """The maximum value of the semaphore.""" + return self._max_value + + def statistics(self) -> SemaphoreStatistics: + """ + Return statistics about the current state of this semaphore. + + .. versionadded:: 3.0 + """ + return SemaphoreStatistics(len(self._waiters)) + + +class CapacityLimiter: + def __new__(cls, total_tokens: float) -> CapacityLimiter: + return get_asynclib().CapacityLimiter(total_tokens) + + async def __aenter__(self) -> None: + raise NotImplementedError + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + raise NotImplementedError + + @property + def total_tokens(self) -> float: + """ + The total number of tokens available for borrowing. + + This is a read-write property. If the total number of tokens is increased, the + proportionate number of tasks waiting on this limiter will be granted their tokens. + + .. versionchanged:: 3.0 + The property is now writable. + + """ + raise NotImplementedError + + @total_tokens.setter + def total_tokens(self, value: float) -> None: + raise NotImplementedError + + async def set_total_tokens(self, value: float) -> None: + warn( + "CapacityLimiter.set_total_tokens has been deprecated. Set the value of the" + '"total_tokens" attribute directly.', + DeprecationWarning, + ) + self.total_tokens = value + + @property + def borrowed_tokens(self) -> int: + """The number of tokens that have currently been borrowed.""" + raise NotImplementedError + + @property + def available_tokens(self) -> float: + """The number of tokens currently available to be borrowed""" + raise NotImplementedError + + def acquire_nowait(self) -> DeprecatedAwaitable: + """ + Acquire a token for the current task without waiting for one to become available. + + :raises ~anyio.WouldBlock: if there are no tokens available for borrowing + + """ + raise NotImplementedError + + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: + """ + Acquire a token without waiting for one to become available. + + :param borrower: the entity borrowing a token + :raises ~anyio.WouldBlock: if there are no tokens available for borrowing + + """ + raise NotImplementedError + + async def acquire(self) -> None: + """ + Acquire a token for the current task, waiting if necessary for one to become available. + + """ + raise NotImplementedError + + async def acquire_on_behalf_of(self, borrower: object) -> None: + """ + Acquire a token, waiting if necessary for one to become available. + + :param borrower: the entity borrowing a token + + """ + raise NotImplementedError + + def release(self) -> None: + """ + Release the token held by the current task. + :raises RuntimeError: if the current task has not borrowed a token from this limiter. + + """ + raise NotImplementedError + + def release_on_behalf_of(self, borrower: object) -> None: + """ + Release the token held by the given borrower. + + :raises RuntimeError: if the borrower has not borrowed a token from this limiter. + + """ + raise NotImplementedError + + def statistics(self) -> CapacityLimiterStatistics: + """ + Return statistics about the current state of this limiter. + + .. versionadded:: 3.0 + + """ + raise NotImplementedError + + +def create_lock() -> Lock: + """ + Create an asynchronous lock. + + :return: a lock object + + .. deprecated:: 3.0 + Use :class:`~Lock` directly. + + """ + warn("create_lock() is deprecated -- use Lock() directly", DeprecationWarning) + return Lock() + + +def create_condition(lock: Lock | None = None) -> Condition: + """ + Create an asynchronous condition. + + :param lock: the lock to base the condition object on + :return: a condition object + + .. deprecated:: 3.0 + Use :class:`~Condition` directly. + + """ + warn( + "create_condition() is deprecated -- use Condition() directly", + DeprecationWarning, + ) + return Condition(lock=lock) + + +def create_event() -> Event: + """ + Create an asynchronous event object. + + :return: an event object + + .. deprecated:: 3.0 + Use :class:`~Event` directly. + + """ + warn("create_event() is deprecated -- use Event() directly", DeprecationWarning) + return get_asynclib().Event() + + +def create_semaphore(value: int, *, max_value: int | None = None) -> Semaphore: + """ + Create an asynchronous semaphore. + + :param value: the semaphore's initial value + :param max_value: if set, makes this a "bounded" semaphore that raises :exc:`ValueError` if the + semaphore's value would exceed this number + :return: a semaphore object + + .. deprecated:: 3.0 + Use :class:`~Semaphore` directly. + + """ + warn( + "create_semaphore() is deprecated -- use Semaphore() directly", + DeprecationWarning, + ) + return Semaphore(value, max_value=max_value) + + +def create_capacity_limiter(total_tokens: float) -> CapacityLimiter: + """ + Create a capacity limiter. + + :param total_tokens: the total number of tokens available for borrowing (can be an integer or + :data:`math.inf`) + :return: a capacity limiter object + + .. deprecated:: 3.0 + Use :class:`~CapacityLimiter` directly. + + """ + warn( + "create_capacity_limiter() is deprecated -- use CapacityLimiter() directly", + DeprecationWarning, + ) + return get_asynclib().CapacityLimiter(total_tokens) + + +class ResourceGuard: + __slots__ = "action", "_guarded" + + def __init__(self, action: str): + self.action = action + self._guarded = False + + def __enter__(self) -> None: + if self._guarded: + raise BusyResourceError(self.action) + + self._guarded = True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + self._guarded = False + return None diff --git a/contrib/python/anyio/anyio/_core/_tasks.py b/contrib/python/anyio/anyio/_core/_tasks.py new file mode 100644 index 0000000000..e9d9c2bd67 --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_tasks.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import math +from types import TracebackType +from warnings import warn + +from ..abc._tasks import TaskGroup, TaskStatus +from ._compat import ( + DeprecatedAsyncContextManager, + DeprecatedAwaitable, + DeprecatedAwaitableFloat, +) +from ._eventloop import get_asynclib + + +class _IgnoredTaskStatus(TaskStatus[object]): + def started(self, value: object = None) -> None: + pass + + +TASK_STATUS_IGNORED = _IgnoredTaskStatus() + + +class CancelScope(DeprecatedAsyncContextManager["CancelScope"]): + """ + Wraps a unit of work that can be made separately cancellable. + + :param deadline: The time (clock value) when this scope is cancelled automatically + :param shield: ``True`` to shield the cancel scope from external cancellation + """ + + def __new__( + cls, *, deadline: float = math.inf, shield: bool = False + ) -> CancelScope: + return get_asynclib().CancelScope(shield=shield, deadline=deadline) + + def cancel(self) -> DeprecatedAwaitable: + """Cancel this scope immediately.""" + raise NotImplementedError + + @property + def deadline(self) -> float: + """ + The time (clock value) when this scope is cancelled automatically. + + Will be ``float('inf')`` if no timeout has been set. + + """ + raise NotImplementedError + + @deadline.setter + def deadline(self, value: float) -> None: + raise NotImplementedError + + @property + def cancel_called(self) -> bool: + """``True`` if :meth:`cancel` has been called.""" + raise NotImplementedError + + @property + def shield(self) -> bool: + """ + ``True`` if this scope is shielded from external cancellation. + + While a scope is shielded, it will not receive cancellations from outside. + + """ + raise NotImplementedError + + @shield.setter + def shield(self, value: bool) -> None: + raise NotImplementedError + + def __enter__(self) -> CancelScope: + raise NotImplementedError + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + raise NotImplementedError + + +def open_cancel_scope(*, shield: bool = False) -> CancelScope: + """ + Open a cancel scope. + + :param shield: ``True`` to shield the cancel scope from external cancellation + :return: a cancel scope + + .. deprecated:: 3.0 + Use :class:`~CancelScope` directly. + + """ + warn( + "open_cancel_scope() is deprecated -- use CancelScope() directly", + DeprecationWarning, + ) + return get_asynclib().CancelScope(shield=shield) + + +class FailAfterContextManager(DeprecatedAsyncContextManager[CancelScope]): + def __init__(self, cancel_scope: CancelScope): + self._cancel_scope = cancel_scope + + def __enter__(self) -> CancelScope: + return self._cancel_scope.__enter__() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + retval = self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) + if self._cancel_scope.cancel_called: + raise TimeoutError + + return retval + + +def fail_after(delay: float | None, shield: bool = False) -> FailAfterContextManager: + """ + Create a context manager which raises a :class:`TimeoutError` if does not finish in time. + + :param delay: maximum allowed time (in seconds) before raising the exception, or ``None`` to + disable the timeout + :param shield: ``True`` to shield the cancel scope from external cancellation + :return: a context manager that yields a cancel scope + :rtype: :class:`~typing.ContextManager`\\[:class:`~anyio.CancelScope`\\] + + """ + deadline = ( + (get_asynclib().current_time() + delay) if delay is not None else math.inf + ) + cancel_scope = get_asynclib().CancelScope(deadline=deadline, shield=shield) + return FailAfterContextManager(cancel_scope) + + +def move_on_after(delay: float | None, shield: bool = False) -> CancelScope: + """ + Create a cancel scope with a deadline that expires after the given delay. + + :param delay: maximum allowed time (in seconds) before exiting the context block, or ``None`` + to disable the timeout + :param shield: ``True`` to shield the cancel scope from external cancellation + :return: a cancel scope + + """ + deadline = ( + (get_asynclib().current_time() + delay) if delay is not None else math.inf + ) + return get_asynclib().CancelScope(deadline=deadline, shield=shield) + + +def current_effective_deadline() -> DeprecatedAwaitableFloat: + """ + Return the nearest deadline among all the cancel scopes effective for the current task. + + :return: a clock value from the event loop's internal clock (or ``float('inf')`` if + there is no deadline in effect, or ``float('-inf')`` if the current scope has + been cancelled) + :rtype: float + + """ + return DeprecatedAwaitableFloat( + get_asynclib().current_effective_deadline(), current_effective_deadline + ) + + +def create_task_group() -> TaskGroup: + """ + Create a task group. + + :return: a task group + + """ + return get_asynclib().TaskGroup() diff --git a/contrib/python/anyio/anyio/_core/_testing.py b/contrib/python/anyio/anyio/_core/_testing.py new file mode 100644 index 0000000000..c8191b3866 --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_testing.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import Any, Awaitable, Generator + +from ._compat import DeprecatedAwaitableList, _warn_deprecation +from ._eventloop import get_asynclib + + +class TaskInfo: + """ + Represents an asynchronous task. + + :ivar int id: the unique identifier of the task + :ivar parent_id: the identifier of the parent task, if any + :vartype parent_id: Optional[int] + :ivar str name: the description of the task (if any) + :ivar ~collections.abc.Coroutine coro: the coroutine object of the task + """ + + __slots__ = "_name", "id", "parent_id", "name", "coro" + + def __init__( + self, + id: int, + parent_id: int | None, + name: str | None, + coro: Generator[Any, Any, Any] | Awaitable[Any], + ): + func = get_current_task + self._name = f"{func.__module__}.{func.__qualname__}" + self.id: int = id + self.parent_id: int | None = parent_id + self.name: str | None = name + self.coro: Generator[Any, Any, Any] | Awaitable[Any] = coro + + def __eq__(self, other: object) -> bool: + if isinstance(other, TaskInfo): + return self.id == other.id + + return NotImplemented + + def __hash__(self) -> int: + return hash(self.id) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})" + + def __await__(self) -> Generator[None, None, TaskInfo]: + _warn_deprecation(self) + if False: + yield + + return self + + def _unwrap(self) -> TaskInfo: + return self + + +def get_current_task() -> TaskInfo: + """ + Return the current task. + + :return: a representation of the current task + + """ + return get_asynclib().get_current_task() + + +def get_running_tasks() -> DeprecatedAwaitableList[TaskInfo]: + """ + Return a list of running tasks in the current event loop. + + :return: a list of task info objects + + """ + tasks = get_asynclib().get_running_tasks() + return DeprecatedAwaitableList(tasks, func=get_running_tasks) + + +async def wait_all_tasks_blocked() -> None: + """Wait until all other tasks are waiting for something.""" + await get_asynclib().wait_all_tasks_blocked() diff --git a/contrib/python/anyio/anyio/_core/_typedattr.py b/contrib/python/anyio/anyio/_core/_typedattr.py new file mode 100644 index 0000000000..bf9202eeab --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_typedattr.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import sys +from typing import Any, Callable, Mapping, TypeVar, overload + +from ._exceptions import TypedAttributeLookupError + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + +T_Attr = TypeVar("T_Attr") +T_Default = TypeVar("T_Default") +undefined = object() + + +def typed_attribute() -> Any: + """Return a unique object, used to mark typed attributes.""" + return object() + + +class TypedAttributeSet: + """ + Superclass for typed attribute collections. + + Checks that every public attribute of every subclass has a type annotation. + """ + + def __init_subclass__(cls) -> None: + annotations: dict[str, Any] = getattr(cls, "__annotations__", {}) + for attrname in dir(cls): + if not attrname.startswith("_") and attrname not in annotations: + raise TypeError( + f"Attribute {attrname!r} is missing its type annotation" + ) + + super().__init_subclass__() + + +class TypedAttributeProvider: + """Base class for classes that wish to provide typed extra attributes.""" + + @property + def extra_attributes(self) -> Mapping[T_Attr, Callable[[], T_Attr]]: + """ + A mapping of the extra attributes to callables that return the corresponding values. + + If the provider wraps another provider, the attributes from that wrapper should also be + included in the returned mapping (but the wrapper may override the callables from the + wrapped instance). + + """ + return {} + + @overload + def extra(self, attribute: T_Attr) -> T_Attr: + ... + + @overload + def extra(self, attribute: T_Attr, default: T_Default) -> T_Attr | T_Default: + ... + + @final + def extra(self, attribute: Any, default: object = undefined) -> object: + """ + extra(attribute, default=undefined) + + Return the value of the given typed extra attribute. + + :param attribute: the attribute (member of a :class:`~TypedAttributeSet`) to look for + :param default: the value that should be returned if no value is found for the attribute + :raises ~anyio.TypedAttributeLookupError: if the search failed and no default value was + given + + """ + try: + return self.extra_attributes[attribute]() + except KeyError: + if default is undefined: + raise TypedAttributeLookupError("Attribute not found") from None + else: + return default diff --git a/contrib/python/anyio/anyio/abc/__init__.py b/contrib/python/anyio/anyio/abc/__init__.py new file mode 100644 index 0000000000..72c34e544e --- /dev/null +++ b/contrib/python/anyio/anyio/abc/__init__.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +__all__ = ( + "AsyncResource", + "IPAddressType", + "IPSockAddrType", + "SocketAttribute", + "SocketStream", + "SocketListener", + "UDPSocket", + "UNIXSocketStream", + "UDPPacketType", + "ConnectedUDPSocket", + "UnreliableObjectReceiveStream", + "UnreliableObjectSendStream", + "UnreliableObjectStream", + "ObjectReceiveStream", + "ObjectSendStream", + "ObjectStream", + "ByteReceiveStream", + "ByteSendStream", + "ByteStream", + "AnyUnreliableByteReceiveStream", + "AnyUnreliableByteSendStream", + "AnyUnreliableByteStream", + "AnyByteReceiveStream", + "AnyByteSendStream", + "AnyByteStream", + "Listener", + "Process", + "Event", + "Condition", + "Lock", + "Semaphore", + "CapacityLimiter", + "CancelScope", + "TaskGroup", + "TaskStatus", + "TestRunner", + "BlockingPortal", +) + +from typing import Any + +from ._resources import AsyncResource +from ._sockets import ( + ConnectedUDPSocket, + IPAddressType, + IPSockAddrType, + SocketAttribute, + SocketListener, + SocketStream, + UDPPacketType, + UDPSocket, + UNIXSocketStream, +) +from ._streams import ( + AnyByteReceiveStream, + AnyByteSendStream, + AnyByteStream, + AnyUnreliableByteReceiveStream, + AnyUnreliableByteSendStream, + AnyUnreliableByteStream, + ByteReceiveStream, + ByteSendStream, + ByteStream, + Listener, + ObjectReceiveStream, + ObjectSendStream, + ObjectStream, + UnreliableObjectReceiveStream, + UnreliableObjectSendStream, + UnreliableObjectStream, +) +from ._subprocesses import Process +from ._tasks import TaskGroup, TaskStatus +from ._testing import TestRunner + +# Re-exported here, for backwards compatibility +# isort: off +from .._core._synchronization import CapacityLimiter, Condition, Event, Lock, Semaphore +from .._core._tasks import CancelScope +from ..from_thread import BlockingPortal + +# Re-export imports so they look like they live directly in this package +key: str +value: Any +for key, value in list(locals().items()): + if getattr(value, "__module__", "").startswith("anyio.abc."): + value.__module__ = __name__ diff --git a/contrib/python/anyio/anyio/abc/_resources.py b/contrib/python/anyio/anyio/abc/_resources.py new file mode 100644 index 0000000000..e0a283fc98 --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_resources.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from types import TracebackType +from typing import TypeVar + +T = TypeVar("T") + + +class AsyncResource(metaclass=ABCMeta): + """ + Abstract base class for all closeable asynchronous resources. + + Works as an asynchronous context manager which returns the instance itself on enter, and calls + :meth:`aclose` on exit. + """ + + async def __aenter__(self: T) -> T: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() + + @abstractmethod + async def aclose(self) -> None: + """Close the resource.""" diff --git a/contrib/python/anyio/anyio/abc/_sockets.py b/contrib/python/anyio/anyio/abc/_sockets.py new file mode 100644 index 0000000000..6aac5f7c22 --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_sockets.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import socket +from abc import abstractmethod +from contextlib import AsyncExitStack +from io import IOBase +from ipaddress import IPv4Address, IPv6Address +from socket import AddressFamily +from typing import ( + Any, + Callable, + Collection, + Mapping, + Tuple, + TypeVar, + Union, +) + +from .._core._tasks import create_task_group +from .._core._typedattr import ( + TypedAttributeProvider, + TypedAttributeSet, + typed_attribute, +) +from ._streams import ByteStream, Listener, UnreliableObjectStream +from ._tasks import TaskGroup + +IPAddressType = Union[str, IPv4Address, IPv6Address] +IPSockAddrType = Tuple[str, int] +SockAddrType = Union[IPSockAddrType, str] +UDPPacketType = Tuple[bytes, IPSockAddrType] +T_Retval = TypeVar("T_Retval") + + +class SocketAttribute(TypedAttributeSet): + #: the address family of the underlying socket + family: AddressFamily = typed_attribute() + #: the local socket address of the underlying socket + local_address: SockAddrType = typed_attribute() + #: for IP addresses, the local port the underlying socket is bound to + local_port: int = typed_attribute() + #: the underlying stdlib socket object + raw_socket: socket.socket = typed_attribute() + #: the remote address the underlying socket is connected to + remote_address: SockAddrType = typed_attribute() + #: for IP addresses, the remote port the underlying socket is connected to + remote_port: int = typed_attribute() + + +class _SocketProvider(TypedAttributeProvider): + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + from .._core._sockets import convert_ipv6_sockaddr as convert + + attributes: dict[Any, Callable[[], Any]] = { + SocketAttribute.family: lambda: self._raw_socket.family, + SocketAttribute.local_address: lambda: convert( + self._raw_socket.getsockname() + ), + SocketAttribute.raw_socket: lambda: self._raw_socket, + } + try: + peername: tuple[str, int] | None = convert(self._raw_socket.getpeername()) + except OSError: + peername = None + + # Provide the remote address for connected sockets + if peername is not None: + attributes[SocketAttribute.remote_address] = lambda: peername + + # Provide local and remote ports for IP based sockets + if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6): + attributes[ + SocketAttribute.local_port + ] = lambda: self._raw_socket.getsockname()[1] + if peername is not None: + remote_port = peername[1] + attributes[SocketAttribute.remote_port] = lambda: remote_port + + return attributes + + @property + @abstractmethod + def _raw_socket(self) -> socket.socket: + pass + + +class SocketStream(ByteStream, _SocketProvider): + """ + Transports bytes over a socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + +class UNIXSocketStream(SocketStream): + @abstractmethod + async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None: + """ + Send file descriptors along with a message to the peer. + + :param message: a non-empty bytestring + :param fds: a collection of files (either numeric file descriptors or open file or socket + objects) + """ + + @abstractmethod + async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]: + """ + Receive file descriptors along with a message from the peer. + + :param msglen: length of the message to expect from the peer + :param maxfds: maximum number of file descriptors to expect from the peer + :return: a tuple of (message, file descriptors) + """ + + +class SocketListener(Listener[SocketStream], _SocketProvider): + """ + Listens to incoming socket connections. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + @abstractmethod + async def accept(self) -> SocketStream: + """Accept an incoming connection.""" + + async def serve( + self, + handler: Callable[[SocketStream], Any], + task_group: TaskGroup | None = None, + ) -> None: + async with AsyncExitStack() as exit_stack: + if task_group is None: + task_group = await exit_stack.enter_async_context(create_task_group()) + + while True: + stream = await self.accept() + task_group.start_soon(handler, stream) + + +class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider): + """ + Represents an unconnected UDP socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + async def sendto(self, data: bytes, host: str, port: int) -> None: + """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).""" + return await self.send((data, (host, port))) + + +class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider): + """ + Represents an connected UDP socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ diff --git a/contrib/python/anyio/anyio/abc/_streams.py b/contrib/python/anyio/anyio/abc/_streams.py new file mode 100644 index 0000000000..4fa7ccc9ff --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_streams.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Callable, Generic, TypeVar, Union + +from .._core._exceptions import EndOfStream +from .._core._typedattr import TypedAttributeProvider +from ._resources import AsyncResource +from ._tasks import TaskGroup + +T_Item = TypeVar("T_Item") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +class UnreliableObjectReceiveStream( + Generic[T_co], AsyncResource, TypedAttributeProvider +): + """ + An interface for receiving objects. + + This interface makes no guarantees that the received messages arrive in the order in which they + were sent, or that no messages are missed. + + Asynchronously iterating over objects of this type will yield objects matching the given type + parameter. + """ + + def __aiter__(self) -> UnreliableObjectReceiveStream[T_co]: + return self + + async def __anext__(self) -> T_co: + try: + return await self.receive() + except EndOfStream: + raise StopAsyncIteration + + @abstractmethod + async def receive(self) -> T_co: + """ + Receive the next item. + + :raises ~anyio.ClosedResourceError: if the receive stream has been explicitly + closed + :raises ~anyio.EndOfStream: if this stream has been closed from the other end + :raises ~anyio.BrokenResourceError: if this stream has been rendered unusable + due to external causes + """ + + +class UnreliableObjectSendStream( + Generic[T_contra], AsyncResource, TypedAttributeProvider +): + """ + An interface for sending objects. + + This interface makes no guarantees that the messages sent will reach the recipient(s) in the + same order in which they were sent, or at all. + """ + + @abstractmethod + async def send(self, item: T_contra) -> None: + """ + Send an item to the peer(s). + + :param item: the item to send + :raises ~anyio.ClosedResourceError: if the send stream has been explicitly + closed + :raises ~anyio.BrokenResourceError: if this stream has been rendered unusable + due to external causes + """ + + +class UnreliableObjectStream( + UnreliableObjectReceiveStream[T_Item], UnreliableObjectSendStream[T_Item] +): + """ + A bidirectional message stream which does not guarantee the order or reliability of message + delivery. + """ + + +class ObjectReceiveStream(UnreliableObjectReceiveStream[T_co]): + """ + A receive message stream which guarantees that messages are received in the same order in + which they were sent, and that no messages are missed. + """ + + +class ObjectSendStream(UnreliableObjectSendStream[T_contra]): + """ + A send message stream which guarantees that messages are delivered in the same order in which + they were sent, without missing any messages in the middle. + """ + + +class ObjectStream( + ObjectReceiveStream[T_Item], + ObjectSendStream[T_Item], + UnreliableObjectStream[T_Item], +): + """ + A bidirectional message stream which guarantees the order and reliability of message delivery. + """ + + @abstractmethod + async def send_eof(self) -> None: + """ + Send an end-of-file indication to the peer. + + You should not try to send any further data to this stream after calling this method. + This method is idempotent (does nothing on successive calls). + """ + + +class ByteReceiveStream(AsyncResource, TypedAttributeProvider): + """ + An interface for receiving bytes from a single peer. + + Iterating this byte stream will yield a byte string of arbitrary length, but no more than + 65536 bytes. + """ + + def __aiter__(self) -> ByteReceiveStream: + return self + + async def __anext__(self) -> bytes: + try: + return await self.receive() + except EndOfStream: + raise StopAsyncIteration + + @abstractmethod + async def receive(self, max_bytes: int = 65536) -> bytes: + """ + Receive at most ``max_bytes`` bytes from the peer. + + .. note:: Implementors of this interface should not return an empty :class:`bytes` object, + and users should ignore them. + + :param max_bytes: maximum number of bytes to receive + :return: the received bytes + :raises ~anyio.EndOfStream: if this stream has been closed from the other end + """ + + +class ByteSendStream(AsyncResource, TypedAttributeProvider): + """An interface for sending bytes to a single peer.""" + + @abstractmethod + async def send(self, item: bytes) -> None: + """ + Send the given bytes to the peer. + + :param item: the bytes to send + """ + + +class ByteStream(ByteReceiveStream, ByteSendStream): + """A bidirectional byte stream.""" + + @abstractmethod + async def send_eof(self) -> None: + """ + Send an end-of-file indication to the peer. + + You should not try to send any further data to this stream after calling this method. + This method is idempotent (does nothing on successive calls). + """ + + +#: Type alias for all unreliable bytes-oriented receive streams. +AnyUnreliableByteReceiveStream = Union[ + UnreliableObjectReceiveStream[bytes], ByteReceiveStream +] +#: Type alias for all unreliable bytes-oriented send streams. +AnyUnreliableByteSendStream = Union[UnreliableObjectSendStream[bytes], ByteSendStream] +#: Type alias for all unreliable bytes-oriented streams. +AnyUnreliableByteStream = Union[UnreliableObjectStream[bytes], ByteStream] +#: Type alias for all bytes-oriented receive streams. +AnyByteReceiveStream = Union[ObjectReceiveStream[bytes], ByteReceiveStream] +#: Type alias for all bytes-oriented send streams. +AnyByteSendStream = Union[ObjectSendStream[bytes], ByteSendStream] +#: Type alias for all bytes-oriented streams. +AnyByteStream = Union[ObjectStream[bytes], ByteStream] + + +class Listener(Generic[T_co], AsyncResource, TypedAttributeProvider): + """An interface for objects that let you accept incoming connections.""" + + @abstractmethod + async def serve( + self, + handler: Callable[[T_co], Any], + task_group: TaskGroup | None = None, + ) -> None: + """ + Accept incoming connections as they come in and start tasks to handle them. + + :param handler: a callable that will be used to handle each accepted connection + :param task_group: the task group that will be used to start tasks for handling each + accepted connection (if omitted, an ad-hoc task group will be created) + """ diff --git a/contrib/python/anyio/anyio/abc/_subprocesses.py b/contrib/python/anyio/anyio/abc/_subprocesses.py new file mode 100644 index 0000000000..704b44a2dd --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_subprocesses.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from abc import abstractmethod +from signal import Signals + +from ._resources import AsyncResource +from ._streams import ByteReceiveStream, ByteSendStream + + +class Process(AsyncResource): + """An asynchronous version of :class:`subprocess.Popen`.""" + + @abstractmethod + async def wait(self) -> int: + """ + Wait until the process exits. + + :return: the exit code of the process + """ + + @abstractmethod + def terminate(self) -> None: + """ + Terminates the process, gracefully if possible. + + On Windows, this calls ``TerminateProcess()``. + On POSIX systems, this sends ``SIGTERM`` to the process. + + .. seealso:: :meth:`subprocess.Popen.terminate` + """ + + @abstractmethod + def kill(self) -> None: + """ + Kills the process. + + On Windows, this calls ``TerminateProcess()``. + On POSIX systems, this sends ``SIGKILL`` to the process. + + .. seealso:: :meth:`subprocess.Popen.kill` + """ + + @abstractmethod + def send_signal(self, signal: Signals) -> None: + """ + Send a signal to the subprocess. + + .. seealso:: :meth:`subprocess.Popen.send_signal` + + :param signal: the signal number (e.g. :data:`signal.SIGHUP`) + """ + + @property + @abstractmethod + def pid(self) -> int: + """The process ID of the process.""" + + @property + @abstractmethod + def returncode(self) -> int | None: + """ + The return code of the process. If the process has not yet terminated, this will be + ``None``. + """ + + @property + @abstractmethod + def stdin(self) -> ByteSendStream | None: + """The stream for the standard input of the process.""" + + @property + @abstractmethod + def stdout(self) -> ByteReceiveStream | None: + """The stream for the standard output of the process.""" + + @property + @abstractmethod + def stderr(self) -> ByteReceiveStream | None: + """The stream for the standard error output of the process.""" diff --git a/contrib/python/anyio/anyio/abc/_tasks.py b/contrib/python/anyio/anyio/abc/_tasks.py new file mode 100644 index 0000000000..e48d3c1e97 --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_tasks.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import sys +from abc import ABCMeta, abstractmethod +from types import TracebackType +from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, overload +from warnings import warn + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + +if TYPE_CHECKING: + from anyio._core._tasks import CancelScope + +T_Retval = TypeVar("T_Retval") +T_contra = TypeVar("T_contra", contravariant=True) + + +class TaskStatus(Protocol[T_contra]): + @overload + def started(self: TaskStatus[None]) -> None: + ... + + @overload + def started(self, value: T_contra) -> None: + ... + + def started(self, value: T_contra | None = None) -> None: + """ + Signal that the task has started. + + :param value: object passed back to the starter of the task + """ + + +class TaskGroup(metaclass=ABCMeta): + """ + Groups several asynchronous tasks together. + + :ivar cancel_scope: the cancel scope inherited by all child tasks + :vartype cancel_scope: CancelScope + """ + + cancel_scope: CancelScope + + async def spawn( + self, + func: Callable[..., Awaitable[Any]], + *args: object, + name: object = None, + ) -> None: + """ + Start a new task in this task group. + + :param func: a coroutine function + :param args: positional arguments to call the function with + :param name: name of the task, for the purposes of introspection and debugging + + .. deprecated:: 3.0 + Use :meth:`start_soon` instead. If your code needs AnyIO 2 compatibility, you + can keep using this until AnyIO 4. + + """ + warn( + 'spawn() is deprecated -- use start_soon() (without the "await") instead', + DeprecationWarning, + ) + self.start_soon(func, *args, name=name) + + @abstractmethod + def start_soon( + self, + func: Callable[..., Awaitable[Any]], + *args: object, + name: object = None, + ) -> None: + """ + Start a new task in this task group. + + :param func: a coroutine function + :param args: positional arguments to call the function with + :param name: name of the task, for the purposes of introspection and debugging + + .. versionadded:: 3.0 + """ + + @abstractmethod + async def start( + self, + func: Callable[..., Awaitable[Any]], + *args: object, + name: object = None, + ) -> Any: + """ + Start a new task and wait until it signals for readiness. + + :param func: a coroutine function + :param args: positional arguments to call the function with + :param name: name of the task, for the purposes of introspection and debugging + :return: the value passed to ``task_status.started()`` + :raises RuntimeError: if the task finishes without calling ``task_status.started()`` + + .. versionadded:: 3.0 + """ + + @abstractmethod + async def __aenter__(self) -> TaskGroup: + """Enter the task group context and allow starting new tasks.""" + + @abstractmethod + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + """Exit the task group context waiting for all tasks to finish.""" diff --git a/contrib/python/anyio/anyio/abc/_testing.py b/contrib/python/anyio/anyio/abc/_testing.py new file mode 100644 index 0000000000..ee2cff5cc3 --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_testing.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import types +from abc import ABCMeta, abstractmethod +from collections.abc import AsyncGenerator, Iterable +from typing import Any, Callable, Coroutine, TypeVar + +_T = TypeVar("_T") + + +class TestRunner(metaclass=ABCMeta): + """ + Encapsulates a running event loop. Every call made through this object will use the same event + loop. + """ + + def __enter__(self) -> TestRunner: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> bool | None: + self.close() + return None + + @abstractmethod + def close(self) -> None: + """Close the event loop.""" + + @abstractmethod + def run_asyncgen_fixture( + self, + fixture_func: Callable[..., AsyncGenerator[_T, Any]], + kwargs: dict[str, Any], + ) -> Iterable[_T]: + """ + Run an async generator fixture. + + :param fixture_func: the fixture function + :param kwargs: keyword arguments to call the fixture function with + :return: an iterator yielding the value yielded from the async generator + """ + + @abstractmethod + def run_fixture( + self, + fixture_func: Callable[..., Coroutine[Any, Any, _T]], + kwargs: dict[str, Any], + ) -> _T: + """ + Run an async fixture. + + :param fixture_func: the fixture function + :param kwargs: keyword arguments to call the fixture function with + :return: the return value of the fixture function + """ + + @abstractmethod + def run_test( + self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] + ) -> None: + """ + Run an async test function. + + :param test_func: the test function + :param kwargs: keyword arguments to call the test function with + """ diff --git a/contrib/python/anyio/anyio/from_thread.py b/contrib/python/anyio/anyio/from_thread.py new file mode 100644 index 0000000000..6b76861c70 --- /dev/null +++ b/contrib/python/anyio/anyio/from_thread.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +import threading +from asyncio import iscoroutine +from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from contextlib import AbstractContextManager, contextmanager +from types import TracebackType +from typing import ( + Any, + AsyncContextManager, + Awaitable, + Callable, + ContextManager, + Generator, + Generic, + Iterable, + TypeVar, + cast, + overload, +) +from warnings import warn + +from ._core import _eventloop +from ._core._eventloop import get_asynclib, get_cancelled_exc_class, threadlocals +from ._core._synchronization import Event +from ._core._tasks import CancelScope, create_task_group +from .abc._tasks import TaskStatus + +T_Retval = TypeVar("T_Retval") +T_co = TypeVar("T_co") + + +def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: + """ + Call a coroutine function from a worker thread. + + :param func: a coroutine function + :param args: positional arguments for the callable + :return: the return value of the coroutine function + + """ + try: + asynclib = threadlocals.current_async_module + except AttributeError: + raise RuntimeError("This function can only be run from an AnyIO worker thread") + + return asynclib.run_async_from_thread(func, *args) + + +def run_async_from_thread( + func: Callable[..., Awaitable[T_Retval]], *args: object +) -> T_Retval: + warn( + "run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead", + DeprecationWarning, + ) + return run(func, *args) + + +def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: + """ + Call a function in the event loop thread from a worker thread. + + :param func: a callable + :param args: positional arguments for the callable + :return: the return value of the callable + + """ + try: + asynclib = threadlocals.current_async_module + except AttributeError: + raise RuntimeError("This function can only be run from an AnyIO worker thread") + + return asynclib.run_sync_from_thread(func, *args) + + +def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval: + warn( + "run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead", + DeprecationWarning, + ) + return run_sync(func, *args) + + +class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager): + _enter_future: Future + _exit_future: Future + _exit_event: Event + _exit_exc_info: tuple[ + type[BaseException] | None, BaseException | None, TracebackType | None + ] = (None, None, None) + + def __init__(self, async_cm: AsyncContextManager[T_co], portal: BlockingPortal): + self._async_cm = async_cm + self._portal = portal + + async def run_async_cm(self) -> bool | None: + try: + self._exit_event = Event() + value = await self._async_cm.__aenter__() + except BaseException as exc: + self._enter_future.set_exception(exc) + raise + else: + self._enter_future.set_result(value) + + try: + # Wait for the sync context manager to exit. + # This next statement can raise `get_cancelled_exc_class()` if + # something went wrong in a task group in this async context + # manager. + await self._exit_event.wait() + finally: + # In case of cancellation, it could be that we end up here before + # `_BlockingAsyncContextManager.__exit__` is called, and an + # `_exit_exc_info` has been set. + result = await self._async_cm.__aexit__(*self._exit_exc_info) + return result + + def __enter__(self) -> T_co: + self._enter_future = Future() + self._exit_future = self._portal.start_task_soon(self.run_async_cm) + cm = self._enter_future.result() + return cast(T_co, cm) + + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: + self._exit_exc_info = __exc_type, __exc_value, __traceback + self._portal.call(self._exit_event.set) + return self._exit_future.result() + + +class _BlockingPortalTaskStatus(TaskStatus): + def __init__(self, future: Future): + self._future = future + + def started(self, value: object = None) -> None: + self._future.set_result(value) + + +class BlockingPortal: + """An object that lets external threads run code in an asynchronous event loop.""" + + def __new__(cls) -> BlockingPortal: + return get_asynclib().BlockingPortal() + + def __init__(self) -> None: + self._event_loop_thread_id: int | None = threading.get_ident() + self._stop_event = Event() + self._task_group = create_task_group() + self._cancelled_exc_class = get_cancelled_exc_class() + + async def __aenter__(self) -> BlockingPortal: + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.stop() + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + def _check_running(self) -> None: + if self._event_loop_thread_id is None: + raise RuntimeError("This portal is not running") + if self._event_loop_thread_id == threading.get_ident(): + raise RuntimeError( + "This method cannot be called from the event loop thread" + ) + + async def sleep_until_stopped(self) -> None: + """Sleep until :meth:`stop` is called.""" + await self._stop_event.wait() + + async def stop(self, cancel_remaining: bool = False) -> None: + """ + Signal the portal to shut down. + + This marks the portal as no longer accepting new calls and exits from + :meth:`sleep_until_stopped`. + + :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` to let them + finish before returning + + """ + self._event_loop_thread_id = None + self._stop_event.set() + if cancel_remaining: + self._task_group.cancel_scope.cancel() + + async def _call_func( + self, func: Callable, args: tuple, kwargs: dict[str, Any], future: Future + ) -> None: + def callback(f: Future) -> None: + if f.cancelled() and self._event_loop_thread_id not in ( + None, + threading.get_ident(), + ): + self.call(scope.cancel) + + try: + retval = func(*args, **kwargs) + if iscoroutine(retval): + with CancelScope() as scope: + if future.cancelled(): + scope.cancel() + else: + future.add_done_callback(callback) + + retval = await retval + except self._cancelled_exc_class: + future.cancel() + except BaseException as exc: + if not future.cancelled(): + future.set_exception(exc) + + # Let base exceptions fall through + if not isinstance(exc, Exception): + raise + else: + if not future.cancelled(): + future.set_result(retval) + finally: + scope = None # type: ignore[assignment] + + def _spawn_task_from_thread( + self, + func: Callable, + args: tuple, + kwargs: dict[str, Any], + name: object, + future: Future, + ) -> None: + """ + Spawn a new task using the given callable. + + Implementors must ensure that the future is resolved when the task finishes. + + :param func: a callable + :param args: positional arguments to be passed to the callable + :param kwargs: keyword arguments to be passed to the callable + :param name: name of the task (will be coerced to a string if not ``None``) + :param future: a future that will resolve to the return value of the callable, or the + exception raised during its execution + + """ + raise NotImplementedError + + @overload + def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: + ... + + @overload + def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: + ... + + def call( + self, func: Callable[..., Awaitable[T_Retval] | T_Retval], *args: object + ) -> T_Retval: + """ + Call the given function in the event loop thread. + + If the callable returns a coroutine object, it is awaited on. + + :param func: any callable + :raises RuntimeError: if the portal is not running or if this method is called from within + the event loop thread + + """ + return cast(T_Retval, self.start_task_soon(func, *args).result()) + + @overload + def spawn_task( + self, + func: Callable[..., Awaitable[T_Retval]], + *args: object, + name: object = None, + ) -> Future[T_Retval]: + ... + + @overload + def spawn_task( + self, func: Callable[..., T_Retval], *args: object, name: object = None + ) -> Future[T_Retval]: + ... + + def spawn_task( + self, + func: Callable[..., Awaitable[T_Retval] | T_Retval], + *args: object, + name: object = None, + ) -> Future[T_Retval]: + """ + Start a task in the portal's task group. + + :param func: the target coroutine function + :param args: positional arguments passed to ``func`` + :param name: name of the task (will be coerced to a string if not ``None``) + :return: a future that resolves with the return value of the callable if the task completes + successfully, or with the exception raised in the task + :raises RuntimeError: if the portal is not running or if this method is called from within + the event loop thread + + .. versionadded:: 2.1 + .. deprecated:: 3.0 + Use :meth:`start_task_soon` instead. If your code needs AnyIO 2 compatibility, you + can keep using this until AnyIO 4. + + """ + warn( + "spawn_task() is deprecated -- use start_task_soon() instead", + DeprecationWarning, + ) + return self.start_task_soon(func, *args, name=name) # type: ignore[arg-type] + + @overload + def start_task_soon( + self, + func: Callable[..., Awaitable[T_Retval]], + *args: object, + name: object = None, + ) -> Future[T_Retval]: + ... + + @overload + def start_task_soon( + self, func: Callable[..., T_Retval], *args: object, name: object = None + ) -> Future[T_Retval]: + ... + + def start_task_soon( + self, + func: Callable[..., Awaitable[T_Retval] | T_Retval], + *args: object, + name: object = None, + ) -> Future[T_Retval]: + """ + Start a task in the portal's task group. + + The task will be run inside a cancel scope which can be cancelled by cancelling the + returned future. + + :param func: the target function + :param args: positional arguments passed to ``func`` + :param name: name of the task (will be coerced to a string if not ``None``) + :return: a future that resolves with the return value of the callable if the + task completes successfully, or with the exception raised in the task + :raises RuntimeError: if the portal is not running or if this method is called + from within the event loop thread + :rtype: concurrent.futures.Future[T_Retval] + + .. versionadded:: 3.0 + + """ + self._check_running() + f: Future = Future() + self._spawn_task_from_thread(func, args, {}, name, f) + return f + + def start_task( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> tuple[Future[Any], Any]: + """ + Start a task in the portal's task group and wait until it signals for readiness. + + This method works the same way as :meth:`.abc.TaskGroup.start`. + + :param func: the target function + :param args: positional arguments passed to ``func`` + :param name: name of the task (will be coerced to a string if not ``None``) + :return: a tuple of (future, task_status_value) where the ``task_status_value`` + is the value passed to ``task_status.started()`` from within the target + function + :rtype: tuple[concurrent.futures.Future[Any], Any] + + .. versionadded:: 3.0 + + """ + + def task_done(future: Future) -> None: + if not task_status_future.done(): + if future.cancelled(): + task_status_future.cancel() + elif future.exception(): + task_status_future.set_exception(future.exception()) + else: + exc = RuntimeError( + "Task exited without calling task_status.started()" + ) + task_status_future.set_exception(exc) + + self._check_running() + task_status_future: Future = Future() + task_status = _BlockingPortalTaskStatus(task_status_future) + f: Future = Future() + f.add_done_callback(task_done) + self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f) + return f, task_status_future.result() + + def wrap_async_context_manager( + self, cm: AsyncContextManager[T_co] + ) -> ContextManager[T_co]: + """ + Wrap an async context manager as a synchronous context manager via this portal. + + Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping in the + middle until the synchronous context manager exits. + + :param cm: an asynchronous context manager + :return: a synchronous context manager + + .. versionadded:: 2.1 + + """ + return _BlockingAsyncContextManager(cm, self) + + +def create_blocking_portal() -> BlockingPortal: + """ + Create a portal for running functions in the event loop thread from external threads. + + Use this function in asynchronous code when you need to allow external threads access to the + event loop where your asynchronous code is currently running. + + .. deprecated:: 3.0 + Use :class:`.BlockingPortal` directly. + + """ + warn( + "create_blocking_portal() has been deprecated -- use anyio.from_thread.BlockingPortal() " + "directly", + DeprecationWarning, + ) + return BlockingPortal() + + +@contextmanager +def start_blocking_portal( + backend: str = "asyncio", backend_options: dict[str, Any] | None = None +) -> Generator[BlockingPortal, Any, None]: + """ + Start a new event loop in a new thread and run a blocking portal in its main task. + + The parameters are the same as for :func:`~anyio.run`. + + :param backend: name of the backend + :param backend_options: backend options + :return: a context manager that yields a blocking portal + + .. versionchanged:: 3.0 + Usage as a context manager is now required. + + """ + + async def run_portal() -> None: + async with BlockingPortal() as portal_: + if future.set_running_or_notify_cancel(): + future.set_result(portal_) + await portal_.sleep_until_stopped() + + future: Future[BlockingPortal] = Future() + with ThreadPoolExecutor(1) as executor: + run_future = executor.submit( + _eventloop.run, + run_portal, # type: ignore[arg-type] + backend=backend, + backend_options=backend_options, + ) + try: + wait( + cast(Iterable[Future], [run_future, future]), + return_when=FIRST_COMPLETED, + ) + except BaseException: + future.cancel() + run_future.cancel() + raise + + if future.done(): + portal = future.result() + cancel_remaining_tasks = False + try: + yield portal + except BaseException: + cancel_remaining_tasks = True + raise + finally: + try: + portal.call(portal.stop, cancel_remaining_tasks) + except RuntimeError: + pass + + run_future.result() diff --git a/contrib/python/anyio/anyio/lowlevel.py b/contrib/python/anyio/anyio/lowlevel.py new file mode 100644 index 0000000000..0e908c6547 --- /dev/null +++ b/contrib/python/anyio/anyio/lowlevel.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import enum +import sys +from dataclasses import dataclass +from typing import Any, Generic, TypeVar, overload +from weakref import WeakKeyDictionary + +from ._core._eventloop import get_asynclib + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +T = TypeVar("T") +D = TypeVar("D") + + +async def checkpoint() -> None: + """ + Check for cancellation and allow the scheduler to switch to another task. + + Equivalent to (but more efficient than):: + + await checkpoint_if_cancelled() + await cancel_shielded_checkpoint() + + + .. versionadded:: 3.0 + + """ + await get_asynclib().checkpoint() + + +async def checkpoint_if_cancelled() -> None: + """ + Enter a checkpoint if the enclosing cancel scope has been cancelled. + + This does not allow the scheduler to switch to a different task. + + .. versionadded:: 3.0 + + """ + await get_asynclib().checkpoint_if_cancelled() + + +async def cancel_shielded_checkpoint() -> None: + """ + Allow the scheduler to switch to another task but without checking for cancellation. + + Equivalent to (but potentially more efficient than):: + + with CancelScope(shield=True): + await checkpoint() + + + .. versionadded:: 3.0 + + """ + await get_asynclib().cancel_shielded_checkpoint() + + +def current_token() -> object: + """Return a backend specific token object that can be used to get back to the event loop.""" + return get_asynclib().current_token() + + +_run_vars: WeakKeyDictionary[Any, dict[str, Any]] = WeakKeyDictionary() +_token_wrappers: dict[Any, _TokenWrapper] = {} + + +@dataclass(frozen=True) +class _TokenWrapper: + __slots__ = "_token", "__weakref__" + _token: object + + +class _NoValueSet(enum.Enum): + NO_VALUE_SET = enum.auto() + + +class RunvarToken(Generic[T]): + __slots__ = "_var", "_value", "_redeemed" + + def __init__(self, var: RunVar[T], value: T | Literal[_NoValueSet.NO_VALUE_SET]): + self._var = var + self._value: T | Literal[_NoValueSet.NO_VALUE_SET] = value + self._redeemed = False + + +class RunVar(Generic[T]): + """ + Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop. + """ + + __slots__ = "_name", "_default" + + NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET + + _token_wrappers: set[_TokenWrapper] = set() + + def __init__( + self, + name: str, + default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET, + ): + self._name = name + self._default = default + + @property + def _current_vars(self) -> dict[str, T]: + token = current_token() + while True: + try: + return _run_vars[token] + except TypeError: + # Happens when token isn't weak referable (TrioToken). + # This workaround does mean that some memory will leak on Trio until the problem + # is fixed on their end. + token = _TokenWrapper(token) + self._token_wrappers.add(token) + except KeyError: + run_vars = _run_vars[token] = {} + return run_vars + + @overload + def get(self, default: D) -> T | D: + ... + + @overload + def get(self) -> T: + ... + + def get( + self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET + ) -> T | D: + try: + return self._current_vars[self._name] + except KeyError: + if default is not RunVar.NO_VALUE_SET: + return default + elif self._default is not RunVar.NO_VALUE_SET: + return self._default + + raise LookupError( + f'Run variable "{self._name}" has no value and no default set' + ) + + def set(self, value: T) -> RunvarToken[T]: + current_vars = self._current_vars + token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET)) + current_vars[self._name] = value + return token + + def reset(self, token: RunvarToken[T]) -> None: + if token._var is not self: + raise ValueError("This token does not belong to this RunVar") + + if token._redeemed: + raise ValueError("This token has already been used") + + if token._value is _NoValueSet.NO_VALUE_SET: + try: + del self._current_vars[self._name] + except KeyError: + pass + else: + self._current_vars[self._name] = token._value + + token._redeemed = True + + def __repr__(self) -> str: + return f"<RunVar name={self._name!r}>" diff --git a/contrib/python/anyio/anyio/py.typed b/contrib/python/anyio/anyio/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/anyio/anyio/py.typed diff --git a/contrib/python/anyio/anyio/pytest_plugin.py b/contrib/python/anyio/anyio/pytest_plugin.py new file mode 100644 index 0000000000..044ce6914d --- /dev/null +++ b/contrib/python/anyio/anyio/pytest_plugin.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from contextlib import contextmanager +from inspect import isasyncgenfunction, iscoroutinefunction +from typing import Any, Dict, Generator, Tuple, cast + +import pytest +import sniffio + +from ._core._eventloop import get_all_backends, get_asynclib +from .abc import TestRunner + +_current_runner: TestRunner | None = None + + +def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]: + if isinstance(backend, str): + return backend, {} + elif isinstance(backend, tuple) and len(backend) == 2: + if isinstance(backend[0], str) and isinstance(backend[1], dict): + return cast(Tuple[str, Dict[str, Any]], backend) + + raise TypeError("anyio_backend must be either a string or tuple of (string, dict)") + + +@contextmanager +def get_runner( + backend_name: str, backend_options: dict[str, Any] +) -> Generator[TestRunner, object, None]: + global _current_runner + if _current_runner: + yield _current_runner + return + + asynclib = get_asynclib(backend_name) + token = None + if sniffio.current_async_library_cvar.get(None) is None: + # Since we're in control of the event loop, we can cache the name of the async library + token = sniffio.current_async_library_cvar.set(backend_name) + + try: + backend_options = backend_options or {} + with asynclib.TestRunner(**backend_options) as runner: + _current_runner = runner + yield runner + finally: + _current_runner = None + if token: + sniffio.current_async_library_cvar.reset(token) + + +def pytest_configure(config: Any) -> None: + config.addinivalue_line( + "markers", + "anyio: mark the (coroutine function) test to be run " + "asynchronously via anyio.", + ) + + +def pytest_fixture_setup(fixturedef: Any, request: Any) -> None: + def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def] + backend_name, backend_options = extract_backend_and_options(anyio_backend) + if has_backend_arg: + kwargs["anyio_backend"] = anyio_backend + + with get_runner(backend_name, backend_options) as runner: + if isasyncgenfunction(func): + yield from runner.run_asyncgen_fixture(func, kwargs) + else: + yield runner.run_fixture(func, kwargs) + + # Only apply this to coroutine functions and async generator functions in requests that involve + # the anyio_backend fixture + func = fixturedef.func + if isasyncgenfunction(func) or iscoroutinefunction(func): + if "anyio_backend" in request.fixturenames: + has_backend_arg = "anyio_backend" in fixturedef.argnames + fixturedef.func = wrapper + if not has_backend_arg: + fixturedef.argnames += ("anyio_backend",) + + +@pytest.hookimpl(tryfirst=True) +def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None: + if collector.istestfunction(obj, name): + inner_func = obj.hypothesis.inner_test if hasattr(obj, "hypothesis") else obj + if iscoroutinefunction(inner_func): + marker = collector.get_closest_marker("anyio") + own_markers = getattr(obj, "pytestmark", ()) + if marker or any(marker.name == "anyio" for marker in own_markers): + pytest.mark.usefixtures("anyio_backend")(obj) + + +@pytest.hookimpl(tryfirst=True) +def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None: + def run_with_hypothesis(**kwargs: Any) -> None: + with get_runner(backend_name, backend_options) as runner: + runner.run_test(original_func, kwargs) + + backend = pyfuncitem.funcargs.get("anyio_backend") + if backend: + backend_name, backend_options = extract_backend_and_options(backend) + + if hasattr(pyfuncitem.obj, "hypothesis"): + # Wrap the inner test function unless it's already wrapped + original_func = pyfuncitem.obj.hypothesis.inner_test + if original_func.__qualname__ != run_with_hypothesis.__qualname__: + if iscoroutinefunction(original_func): + pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis + + return None + + if iscoroutinefunction(pyfuncitem.obj): + funcargs = pyfuncitem.funcargs + testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames} + with get_runner(backend_name, backend_options) as runner: + runner.run_test(pyfuncitem.obj, testargs) + + return True + + return None + + +@pytest.fixture(params=get_all_backends()) +def anyio_backend(request: Any) -> Any: + return request.param + + +@pytest.fixture +def anyio_backend_name(anyio_backend: Any) -> str: + if isinstance(anyio_backend, str): + return anyio_backend + else: + return anyio_backend[0] + + +@pytest.fixture +def anyio_backend_options(anyio_backend: Any) -> dict[str, Any]: + if isinstance(anyio_backend, str): + return {} + else: + return anyio_backend[1] diff --git a/contrib/python/anyio/anyio/streams/__init__.py b/contrib/python/anyio/anyio/streams/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/anyio/anyio/streams/__init__.py diff --git a/contrib/python/anyio/anyio/streams/buffered.py b/contrib/python/anyio/anyio/streams/buffered.py new file mode 100644 index 0000000000..11474c16a9 --- /dev/null +++ b/contrib/python/anyio/anyio/streams/buffered.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Mapping + +from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead +from ..abc import AnyByteReceiveStream, ByteReceiveStream + + +@dataclass(eq=False) +class BufferedByteReceiveStream(ByteReceiveStream): + """ + Wraps any bytes-based receive stream and uses a buffer to provide sophisticated receiving + capabilities in the form of a byte stream. + """ + + receive_stream: AnyByteReceiveStream + _buffer: bytearray = field(init=False, default_factory=bytearray) + _closed: bool = field(init=False, default=False) + + async def aclose(self) -> None: + await self.receive_stream.aclose() + self._closed = True + + @property + def buffer(self) -> bytes: + """The bytes currently in the buffer.""" + return bytes(self._buffer) + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return self.receive_stream.extra_attributes + + async def receive(self, max_bytes: int = 65536) -> bytes: + if self._closed: + raise ClosedResourceError + + if self._buffer: + chunk = bytes(self._buffer[:max_bytes]) + del self._buffer[:max_bytes] + return chunk + elif isinstance(self.receive_stream, ByteReceiveStream): + return await self.receive_stream.receive(max_bytes) + else: + # With a bytes-oriented object stream, we need to handle any surplus bytes we get from + # the receive() call + chunk = await self.receive_stream.receive() + if len(chunk) > max_bytes: + # Save the surplus bytes in the buffer + self._buffer.extend(chunk[max_bytes:]) + return chunk[:max_bytes] + else: + return chunk + + async def receive_exactly(self, nbytes: int) -> bytes: + """ + Read exactly the given amount of bytes from the stream. + + :param nbytes: the number of bytes to read + :return: the bytes read + :raises ~anyio.IncompleteRead: if the stream was closed before the requested + amount of bytes could be read from the stream + + """ + while True: + remaining = nbytes - len(self._buffer) + if remaining <= 0: + retval = self._buffer[:nbytes] + del self._buffer[:nbytes] + return bytes(retval) + + try: + if isinstance(self.receive_stream, ByteReceiveStream): + chunk = await self.receive_stream.receive(remaining) + else: + chunk = await self.receive_stream.receive() + except EndOfStream as exc: + raise IncompleteRead from exc + + self._buffer.extend(chunk) + + async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes: + """ + Read from the stream until the delimiter is found or max_bytes have been read. + + :param delimiter: the marker to look for in the stream + :param max_bytes: maximum number of bytes that will be read before raising + :exc:`~anyio.DelimiterNotFound` + :return: the bytes read (not including the delimiter) + :raises ~anyio.IncompleteRead: if the stream was closed before the delimiter + was found + :raises ~anyio.DelimiterNotFound: if the delimiter is not found within the + bytes read up to the maximum allowed + + """ + delimiter_size = len(delimiter) + offset = 0 + while True: + # Check if the delimiter can be found in the current buffer + index = self._buffer.find(delimiter, offset) + if index >= 0: + found = self._buffer[:index] + del self._buffer[: index + len(delimiter) :] + return bytes(found) + + # Check if the buffer is already at or over the limit + if len(self._buffer) >= max_bytes: + raise DelimiterNotFound(max_bytes) + + # Read more data into the buffer from the socket + try: + data = await self.receive_stream.receive() + except EndOfStream as exc: + raise IncompleteRead from exc + + # Move the offset forward and add the new data to the buffer + offset = max(len(self._buffer) - delimiter_size + 1, 0) + self._buffer.extend(data) diff --git a/contrib/python/anyio/anyio/streams/file.py b/contrib/python/anyio/anyio/streams/file.py new file mode 100644 index 0000000000..2840d40ab6 --- /dev/null +++ b/contrib/python/anyio/anyio/streams/file.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from io import SEEK_SET, UnsupportedOperation +from os import PathLike +from pathlib import Path +from typing import Any, BinaryIO, Callable, Mapping, cast + +from .. import ( + BrokenResourceError, + ClosedResourceError, + EndOfStream, + TypedAttributeSet, + to_thread, + typed_attribute, +) +from ..abc import ByteReceiveStream, ByteSendStream + + +class FileStreamAttribute(TypedAttributeSet): + #: the open file descriptor + file: BinaryIO = typed_attribute() + #: the path of the file on the file system, if available (file must be a real file) + path: Path = typed_attribute() + #: the file number, if available (file must be a real file or a TTY) + fileno: int = typed_attribute() + + +class _BaseFileStream: + def __init__(self, file: BinaryIO): + self._file = file + + async def aclose(self) -> None: + await to_thread.run_sync(self._file.close) + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + attributes: dict[Any, Callable[[], Any]] = { + FileStreamAttribute.file: lambda: self._file, + } + + if hasattr(self._file, "name"): + attributes[FileStreamAttribute.path] = lambda: Path(self._file.name) + + try: + self._file.fileno() + except UnsupportedOperation: + pass + else: + attributes[FileStreamAttribute.fileno] = lambda: self._file.fileno() + + return attributes + + +class FileReadStream(_BaseFileStream, ByteReceiveStream): + """ + A byte stream that reads from a file in the file system. + + :param file: a file that has been opened for reading in binary mode + + .. versionadded:: 3.0 + """ + + @classmethod + async def from_path(cls, path: str | PathLike[str]) -> FileReadStream: + """ + Create a file read stream by opening the given file. + + :param path: path of the file to read from + + """ + file = await to_thread.run_sync(Path(path).open, "rb") + return cls(cast(BinaryIO, file)) + + async def receive(self, max_bytes: int = 65536) -> bytes: + try: + data = await to_thread.run_sync(self._file.read, max_bytes) + except ValueError: + raise ClosedResourceError from None + except OSError as exc: + raise BrokenResourceError from exc + + if data: + return data + else: + raise EndOfStream + + async def seek(self, position: int, whence: int = SEEK_SET) -> int: + """ + Seek the file to the given position. + + .. seealso:: :meth:`io.IOBase.seek` + + .. note:: Not all file descriptors are seekable. + + :param position: position to seek the file to + :param whence: controls how ``position`` is interpreted + :return: the new absolute position + :raises OSError: if the file is not seekable + + """ + return await to_thread.run_sync(self._file.seek, position, whence) + + async def tell(self) -> int: + """ + Return the current stream position. + + .. note:: Not all file descriptors are seekable. + + :return: the current absolute position + :raises OSError: if the file is not seekable + + """ + return await to_thread.run_sync(self._file.tell) + + +class FileWriteStream(_BaseFileStream, ByteSendStream): + """ + A byte stream that writes to a file in the file system. + + :param file: a file that has been opened for writing in binary mode + + .. versionadded:: 3.0 + """ + + @classmethod + async def from_path( + cls, path: str | PathLike[str], append: bool = False + ) -> FileWriteStream: + """ + Create a file write stream by opening the given file for writing. + + :param path: path of the file to write to + :param append: if ``True``, open the file for appending; if ``False``, any existing file + at the given path will be truncated + + """ + mode = "ab" if append else "wb" + file = await to_thread.run_sync(Path(path).open, mode) + return cls(cast(BinaryIO, file)) + + async def send(self, item: bytes) -> None: + try: + await to_thread.run_sync(self._file.write, item) + except ValueError: + raise ClosedResourceError from None + except OSError as exc: + raise BrokenResourceError from exc diff --git a/contrib/python/anyio/anyio/streams/memory.py b/contrib/python/anyio/anyio/streams/memory.py new file mode 100644 index 0000000000..a6499c13ff --- /dev/null +++ b/contrib/python/anyio/anyio/streams/memory.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from collections import OrderedDict, deque +from dataclasses import dataclass, field +from types import TracebackType +from typing import Generic, NamedTuple, TypeVar + +from .. import ( + BrokenResourceError, + ClosedResourceError, + EndOfStream, + WouldBlock, + get_cancelled_exc_class, +) +from .._core._compat import DeprecatedAwaitable +from ..abc import Event, ObjectReceiveStream, ObjectSendStream +from ..lowlevel import checkpoint + +T_Item = TypeVar("T_Item") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +class MemoryObjectStreamStatistics(NamedTuple): + current_buffer_used: int #: number of items stored in the buffer + #: maximum number of items that can be stored on this stream (or :data:`math.inf`) + max_buffer_size: float + open_send_streams: int #: number of unclosed clones of the send stream + open_receive_streams: int #: number of unclosed clones of the receive stream + tasks_waiting_send: int #: number of tasks blocked on :meth:`MemoryObjectSendStream.send` + #: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive` + tasks_waiting_receive: int + + +@dataclass(eq=False) +class MemoryObjectStreamState(Generic[T_Item]): + max_buffer_size: float = field() + buffer: deque[T_Item] = field(init=False, default_factory=deque) + open_send_channels: int = field(init=False, default=0) + open_receive_channels: int = field(init=False, default=0) + waiting_receivers: OrderedDict[Event, list[T_Item]] = field( + init=False, default_factory=OrderedDict + ) + waiting_senders: OrderedDict[Event, T_Item] = field( + init=False, default_factory=OrderedDict + ) + + def statistics(self) -> MemoryObjectStreamStatistics: + return MemoryObjectStreamStatistics( + len(self.buffer), + self.max_buffer_size, + self.open_send_channels, + self.open_receive_channels, + len(self.waiting_senders), + len(self.waiting_receivers), + ) + + +@dataclass(eq=False) +class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]): + _state: MemoryObjectStreamState[T_co] + _closed: bool = field(init=False, default=False) + + def __post_init__(self) -> None: + self._state.open_receive_channels += 1 + + def receive_nowait(self) -> T_co: + """ + Receive the next item if it can be done without waiting. + + :return: the received item + :raises ~anyio.ClosedResourceError: if this send stream has been closed + :raises ~anyio.EndOfStream: if the buffer is empty and this stream has been + closed from the sending end + :raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks + waiting to send + + """ + if self._closed: + raise ClosedResourceError + + if self._state.waiting_senders: + # Get the item from the next sender + send_event, item = self._state.waiting_senders.popitem(last=False) + self._state.buffer.append(item) + send_event.set() + + if self._state.buffer: + return self._state.buffer.popleft() + elif not self._state.open_send_channels: + raise EndOfStream + + raise WouldBlock + + async def receive(self) -> T_co: + await checkpoint() + try: + return self.receive_nowait() + except WouldBlock: + # Add ourselves in the queue + receive_event = Event() + container: list[T_co] = [] + self._state.waiting_receivers[receive_event] = container + + try: + await receive_event.wait() + except get_cancelled_exc_class(): + # Ignore the immediate cancellation if we already received an item, so as not to + # lose it + if not container: + raise + finally: + self._state.waiting_receivers.pop(receive_event, None) + + if container: + return container[0] + else: + raise EndOfStream + + def clone(self) -> MemoryObjectReceiveStream[T_co]: + """ + Create a clone of this receive stream. + + Each clone can be closed separately. Only when all clones have been closed will the + receiving end of the memory stream be considered closed by the sending ends. + + :return: the cloned stream + + """ + if self._closed: + raise ClosedResourceError + + return MemoryObjectReceiveStream(_state=self._state) + + def close(self) -> None: + """ + Close the stream. + + This works the exact same way as :meth:`aclose`, but is provided as a special case for the + benefit of synchronous callbacks. + + """ + if not self._closed: + self._closed = True + self._state.open_receive_channels -= 1 + if self._state.open_receive_channels == 0: + send_events = list(self._state.waiting_senders.keys()) + for event in send_events: + event.set() + + async def aclose(self) -> None: + self.close() + + def statistics(self) -> MemoryObjectStreamStatistics: + """ + Return statistics about the current state of this stream. + + .. versionadded:: 3.0 + """ + return self._state.statistics() + + def __enter__(self) -> MemoryObjectReceiveStream[T_co]: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + +@dataclass(eq=False) +class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): + _state: MemoryObjectStreamState[T_contra] + _closed: bool = field(init=False, default=False) + + def __post_init__(self) -> None: + self._state.open_send_channels += 1 + + def send_nowait(self, item: T_contra) -> DeprecatedAwaitable: + """ + Send an item immediately if it can be done without waiting. + + :param item: the item to send + :raises ~anyio.ClosedResourceError: if this send stream has been closed + :raises ~anyio.BrokenResourceError: if the stream has been closed from the + receiving end + :raises ~anyio.WouldBlock: if the buffer is full and there are no tasks waiting + to receive + + """ + if self._closed: + raise ClosedResourceError + if not self._state.open_receive_channels: + raise BrokenResourceError + + if self._state.waiting_receivers: + receive_event, container = self._state.waiting_receivers.popitem(last=False) + container.append(item) + receive_event.set() + elif len(self._state.buffer) < self._state.max_buffer_size: + self._state.buffer.append(item) + else: + raise WouldBlock + + return DeprecatedAwaitable(self.send_nowait) + + async def send(self, item: T_contra) -> None: + await checkpoint() + try: + self.send_nowait(item) + except WouldBlock: + # Wait until there's someone on the receiving end + send_event = Event() + self._state.waiting_senders[send_event] = item + try: + await send_event.wait() + except BaseException: + self._state.waiting_senders.pop(send_event, None) # type: ignore[arg-type] + raise + + if self._state.waiting_senders.pop(send_event, None): # type: ignore[arg-type] + raise BrokenResourceError + + def clone(self) -> MemoryObjectSendStream[T_contra]: + """ + Create a clone of this send stream. + + Each clone can be closed separately. Only when all clones have been closed will the + sending end of the memory stream be considered closed by the receiving ends. + + :return: the cloned stream + + """ + if self._closed: + raise ClosedResourceError + + return MemoryObjectSendStream(_state=self._state) + + def close(self) -> None: + """ + Close the stream. + + This works the exact same way as :meth:`aclose`, but is provided as a special case for the + benefit of synchronous callbacks. + + """ + if not self._closed: + self._closed = True + self._state.open_send_channels -= 1 + if self._state.open_send_channels == 0: + receive_events = list(self._state.waiting_receivers.keys()) + self._state.waiting_receivers.clear() + for event in receive_events: + event.set() + + async def aclose(self) -> None: + self.close() + + def statistics(self) -> MemoryObjectStreamStatistics: + """ + Return statistics about the current state of this stream. + + .. versionadded:: 3.0 + """ + return self._state.statistics() + + def __enter__(self) -> MemoryObjectSendStream[T_contra]: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() diff --git a/contrib/python/anyio/anyio/streams/stapled.py b/contrib/python/anyio/anyio/streams/stapled.py new file mode 100644 index 0000000000..1b2862e3ea --- /dev/null +++ b/contrib/python/anyio/anyio/streams/stapled.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Generic, Mapping, Sequence, TypeVar + +from ..abc import ( + ByteReceiveStream, + ByteSendStream, + ByteStream, + Listener, + ObjectReceiveStream, + ObjectSendStream, + ObjectStream, + TaskGroup, +) + +T_Item = TypeVar("T_Item") +T_Stream = TypeVar("T_Stream") + + +@dataclass(eq=False) +class StapledByteStream(ByteStream): + """ + Combines two byte streams into a single, bidirectional byte stream. + + Extra attributes will be provided from both streams, with the receive stream providing the + values in case of a conflict. + + :param ByteSendStream send_stream: the sending byte stream + :param ByteReceiveStream receive_stream: the receiving byte stream + """ + + send_stream: ByteSendStream + receive_stream: ByteReceiveStream + + async def receive(self, max_bytes: int = 65536) -> bytes: + return await self.receive_stream.receive(max_bytes) + + async def send(self, item: bytes) -> None: + await self.send_stream.send(item) + + async def send_eof(self) -> None: + await self.send_stream.aclose() + + async def aclose(self) -> None: + await self.send_stream.aclose() + await self.receive_stream.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.send_stream.extra_attributes, + **self.receive_stream.extra_attributes, + } + + +@dataclass(eq=False) +class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]): + """ + Combines two object streams into a single, bidirectional object stream. + + Extra attributes will be provided from both streams, with the receive stream providing the + values in case of a conflict. + + :param ObjectSendStream send_stream: the sending object stream + :param ObjectReceiveStream receive_stream: the receiving object stream + """ + + send_stream: ObjectSendStream[T_Item] + receive_stream: ObjectReceiveStream[T_Item] + + async def receive(self) -> T_Item: + return await self.receive_stream.receive() + + async def send(self, item: T_Item) -> None: + await self.send_stream.send(item) + + async def send_eof(self) -> None: + await self.send_stream.aclose() + + async def aclose(self) -> None: + await self.send_stream.aclose() + await self.receive_stream.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.send_stream.extra_attributes, + **self.receive_stream.extra_attributes, + } + + +@dataclass(eq=False) +class MultiListener(Generic[T_Stream], Listener[T_Stream]): + """ + Combines multiple listeners into one, serving connections from all of them at once. + + Any MultiListeners in the given collection of listeners will have their listeners moved into + this one. + + Extra attributes are provided from each listener, with each successive listener overriding any + conflicting attributes from the previous one. + + :param listeners: listeners to serve + :type listeners: Sequence[Listener[T_Stream]] + """ + + listeners: Sequence[Listener[T_Stream]] + + def __post_init__(self) -> None: + listeners: list[Listener[T_Stream]] = [] + for listener in self.listeners: + if isinstance(listener, MultiListener): + listeners.extend(listener.listeners) + del listener.listeners[:] # type: ignore[attr-defined] + else: + listeners.append(listener) + + self.listeners = listeners + + async def serve( + self, handler: Callable[[T_Stream], Any], task_group: TaskGroup | None = None + ) -> None: + from .. import create_task_group + + async with create_task_group() as tg: + for listener in self.listeners: + tg.start_soon(listener.serve, handler, task_group) + + async def aclose(self) -> None: + for listener in self.listeners: + await listener.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + attributes: dict = {} + for listener in self.listeners: + attributes.update(listener.extra_attributes) + + return attributes diff --git a/contrib/python/anyio/anyio/streams/text.py b/contrib/python/anyio/anyio/streams/text.py new file mode 100644 index 0000000000..bba2d3f7df --- /dev/null +++ b/contrib/python/anyio/anyio/streams/text.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import codecs +from dataclasses import InitVar, dataclass, field +from typing import Any, Callable, Mapping + +from ..abc import ( + AnyByteReceiveStream, + AnyByteSendStream, + AnyByteStream, + ObjectReceiveStream, + ObjectSendStream, + ObjectStream, +) + + +@dataclass(eq=False) +class TextReceiveStream(ObjectReceiveStream[str]): + """ + Stream wrapper that decodes bytes to strings using the given encoding. + + Decoding is done using :class:`~codecs.IncrementalDecoder` which returns any completely + received unicode characters as soon as they come in. + + :param transport_stream: any bytes-based receive stream + :param encoding: character encoding to use for decoding bytes to strings (defaults to + ``utf-8``) + :param errors: handling scheme for decoding errors (defaults to ``strict``; see the + `codecs module documentation`_ for a comprehensive list of options) + + .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects + """ + + transport_stream: AnyByteReceiveStream + encoding: InitVar[str] = "utf-8" + errors: InitVar[str] = "strict" + _decoder: codecs.IncrementalDecoder = field(init=False) + + def __post_init__(self, encoding: str, errors: str) -> None: + decoder_class = codecs.getincrementaldecoder(encoding) + self._decoder = decoder_class(errors=errors) + + async def receive(self) -> str: + while True: + chunk = await self.transport_stream.receive() + decoded = self._decoder.decode(chunk) + if decoded: + return decoded + + async def aclose(self) -> None: + await self.transport_stream.aclose() + self._decoder.reset() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return self.transport_stream.extra_attributes + + +@dataclass(eq=False) +class TextSendStream(ObjectSendStream[str]): + """ + Sends strings to the wrapped stream as bytes using the given encoding. + + :param AnyByteSendStream transport_stream: any bytes-based send stream + :param str encoding: character encoding to use for encoding strings to bytes (defaults to + ``utf-8``) + :param str errors: handling scheme for encoding errors (defaults to ``strict``; see the + `codecs module documentation`_ for a comprehensive list of options) + + .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects + """ + + transport_stream: AnyByteSendStream + encoding: InitVar[str] = "utf-8" + errors: str = "strict" + _encoder: Callable[..., tuple[bytes, int]] = field(init=False) + + def __post_init__(self, encoding: str) -> None: + self._encoder = codecs.getencoder(encoding) + + async def send(self, item: str) -> None: + encoded = self._encoder(item, self.errors)[0] + await self.transport_stream.send(encoded) + + async def aclose(self) -> None: + await self.transport_stream.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return self.transport_stream.extra_attributes + + +@dataclass(eq=False) +class TextStream(ObjectStream[str]): + """ + A bidirectional stream that decodes bytes to strings on receive and encodes strings to bytes on + send. + + Extra attributes will be provided from both streams, with the receive stream providing the + values in case of a conflict. + + :param AnyByteStream transport_stream: any bytes-based stream + :param str encoding: character encoding to use for encoding/decoding strings to/from bytes + (defaults to ``utf-8``) + :param str errors: handling scheme for encoding errors (defaults to ``strict``; see the + `codecs module documentation`_ for a comprehensive list of options) + + .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects + """ + + transport_stream: AnyByteStream + encoding: InitVar[str] = "utf-8" + errors: InitVar[str] = "strict" + _receive_stream: TextReceiveStream = field(init=False) + _send_stream: TextSendStream = field(init=False) + + def __post_init__(self, encoding: str, errors: str) -> None: + self._receive_stream = TextReceiveStream( + self.transport_stream, encoding=encoding, errors=errors + ) + self._send_stream = TextSendStream( + self.transport_stream, encoding=encoding, errors=errors + ) + + async def receive(self) -> str: + return await self._receive_stream.receive() + + async def send(self, item: str) -> None: + await self._send_stream.send(item) + + async def send_eof(self) -> None: + await self.transport_stream.send_eof() + + async def aclose(self) -> None: + await self._send_stream.aclose() + await self._receive_stream.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self._send_stream.extra_attributes, + **self._receive_stream.extra_attributes, + } diff --git a/contrib/python/anyio/anyio/streams/tls.py b/contrib/python/anyio/anyio/streams/tls.py new file mode 100644 index 0000000000..9f9e9fd89c --- /dev/null +++ b/contrib/python/anyio/anyio/streams/tls.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +import logging +import re +import ssl +from dataclasses import dataclass +from functools import wraps +from typing import Any, Callable, Mapping, Tuple, TypeVar + +from .. import ( + BrokenResourceError, + EndOfStream, + aclose_forcefully, + get_cancelled_exc_class, +) +from .._core._typedattr import TypedAttributeSet, typed_attribute +from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup + +T_Retval = TypeVar("T_Retval") +_PCTRTT = Tuple[Tuple[str, str], ...] +_PCTRTTT = Tuple[_PCTRTT, ...] + + +class TLSAttribute(TypedAttributeSet): + """Contains Transport Layer Security related attributes.""" + + #: the selected ALPN protocol + alpn_protocol: str | None = typed_attribute() + #: the channel binding for type ``tls-unique`` + channel_binding_tls_unique: bytes = typed_attribute() + #: the selected cipher + cipher: tuple[str, str, int] = typed_attribute() + #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` + #: for more information) + peer_certificate: dict[str, str | _PCTRTTT | _PCTRTT] | None = typed_attribute() + #: the peer certificate in binary form + peer_certificate_binary: bytes | None = typed_attribute() + #: ``True`` if this is the server side of the connection + server_side: bool = typed_attribute() + #: ciphers shared by the client during the TLS handshake (``None`` if this is the + #: client side) + shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute() + #: the :class:`~ssl.SSLObject` used for encryption + ssl_object: ssl.SSLObject = typed_attribute() + #: ``True`` if this stream does (and expects) a closing TLS handshake when the + #: stream is being closed + standard_compatible: bool = typed_attribute() + #: the TLS protocol version (e.g. ``TLSv1.2``) + tls_version: str = typed_attribute() + + +@dataclass(eq=False) +class TLSStream(ByteStream): + """ + A stream wrapper that encrypts all sent data and decrypts received data. + + This class has no public initializer; use :meth:`wrap` instead. + All extra attributes from :class:`~TLSAttribute` are supported. + + :var AnyByteStream transport_stream: the wrapped stream + + """ + + transport_stream: AnyByteStream + standard_compatible: bool + _ssl_object: ssl.SSLObject + _read_bio: ssl.MemoryBIO + _write_bio: ssl.MemoryBIO + + @classmethod + async def wrap( + cls, + transport_stream: AnyByteStream, + *, + server_side: bool | None = None, + hostname: str | None = None, + ssl_context: ssl.SSLContext | None = None, + standard_compatible: bool = True, + ) -> TLSStream: + """ + Wrap an existing stream with Transport Layer Security. + + This performs a TLS handshake with the peer. + + :param transport_stream: a bytes-transporting stream to wrap + :param server_side: ``True`` if this is the server side of the connection, + ``False`` if this is the client side (if omitted, will be set to ``False`` + if ``hostname`` has been provided, ``False`` otherwise). Used only to create + a default context when an explicit context has not been provided. + :param hostname: host name of the peer (if host name checking is desired) + :param ssl_context: the SSLContext object to use (if not provided, a secure + default will be created) + :param standard_compatible: if ``False``, skip the closing handshake when closing the + connection, and don't raise an exception if the peer does the same + :raises ~ssl.SSLError: if the TLS handshake fails + + """ + if server_side is None: + server_side = not hostname + + if not ssl_context: + purpose = ( + ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH + ) + ssl_context = ssl.create_default_context(purpose) + + # Re-enable detection of unexpected EOFs if it was disabled by Python + if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): + ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF + + bio_in = ssl.MemoryBIO() + bio_out = ssl.MemoryBIO() + ssl_object = ssl_context.wrap_bio( + bio_in, bio_out, server_side=server_side, server_hostname=hostname + ) + wrapper = cls( + transport_stream=transport_stream, + standard_compatible=standard_compatible, + _ssl_object=ssl_object, + _read_bio=bio_in, + _write_bio=bio_out, + ) + await wrapper._call_sslobject_method(ssl_object.do_handshake) + return wrapper + + async def _call_sslobject_method( + self, func: Callable[..., T_Retval], *args: object + ) -> T_Retval: + while True: + try: + result = func(*args) + except ssl.SSLWantReadError: + try: + # Flush any pending writes first + if self._write_bio.pending: + await self.transport_stream.send(self._write_bio.read()) + + data = await self.transport_stream.receive() + except EndOfStream: + self._read_bio.write_eof() + except OSError as exc: + self._read_bio.write_eof() + self._write_bio.write_eof() + raise BrokenResourceError from exc + else: + self._read_bio.write(data) + except ssl.SSLWantWriteError: + await self.transport_stream.send(self._write_bio.read()) + except ssl.SSLSyscallError as exc: + self._read_bio.write_eof() + self._write_bio.write_eof() + raise BrokenResourceError from exc + except ssl.SSLError as exc: + self._read_bio.write_eof() + self._write_bio.write_eof() + if ( + isinstance(exc, ssl.SSLEOFError) + or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror + ): + if self.standard_compatible: + raise BrokenResourceError from exc + else: + raise EndOfStream from None + + raise + else: + # Flush any pending writes first + if self._write_bio.pending: + await self.transport_stream.send(self._write_bio.read()) + + return result + + async def unwrap(self) -> tuple[AnyByteStream, bytes]: + """ + Does the TLS closing handshake. + + :return: a tuple of (wrapped byte stream, bytes left in the read buffer) + + """ + await self._call_sslobject_method(self._ssl_object.unwrap) + self._read_bio.write_eof() + self._write_bio.write_eof() + return self.transport_stream, self._read_bio.read() + + async def aclose(self) -> None: + if self.standard_compatible: + try: + await self.unwrap() + except BaseException: + await aclose_forcefully(self.transport_stream) + raise + + await self.transport_stream.aclose() + + async def receive(self, max_bytes: int = 65536) -> bytes: + data = await self._call_sslobject_method(self._ssl_object.read, max_bytes) + if not data: + raise EndOfStream + + return data + + async def send(self, item: bytes) -> None: + await self._call_sslobject_method(self._ssl_object.write, item) + + async def send_eof(self) -> None: + tls_version = self.extra(TLSAttribute.tls_version) + match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version) + if match: + major, minor = int(match.group(1)), int(match.group(2) or 0) + if (major, minor) < (1, 3): + raise NotImplementedError( + f"send_eof() requires at least TLSv1.3; current " + f"session uses {tls_version}" + ) + + raise NotImplementedError( + "send_eof() has not yet been implemented for TLS streams" + ) + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.transport_stream.extra_attributes, + TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol, + TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding, + TLSAttribute.cipher: self._ssl_object.cipher, + TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False), + TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert( + True + ), + TLSAttribute.server_side: lambda: self._ssl_object.server_side, + TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers() + if self._ssl_object.server_side + else None, + TLSAttribute.standard_compatible: lambda: self.standard_compatible, + TLSAttribute.ssl_object: lambda: self._ssl_object, + TLSAttribute.tls_version: self._ssl_object.version, + } + + +@dataclass(eq=False) +class TLSListener(Listener[TLSStream]): + """ + A convenience listener that wraps another listener and auto-negotiates a TLS session on every + accepted connection. + + If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is + called to do whatever post-mortem processing is deemed necessary. + + Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute. + + :param Listener listener: the listener to wrap + :param ssl_context: the SSL context object + :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap` + :param handshake_timeout: time limit for the TLS handshake + (passed to :func:`~anyio.fail_after`) + """ + + listener: Listener[Any] + ssl_context: ssl.SSLContext + standard_compatible: bool = True + handshake_timeout: float = 30 + + @staticmethod + async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None: + """ + Handle an exception raised during the TLS handshake. + + This method does 3 things: + + #. Forcefully closes the original stream + #. Logs the exception (unless it was a cancellation exception) using the + ``anyio.streams.tls`` logger + #. Reraises the exception if it was a base exception or a cancellation exception + + :param exc: the exception + :param stream: the original stream + + """ + await aclose_forcefully(stream) + + # Log all except cancellation exceptions + if not isinstance(exc, get_cancelled_exc_class()): + logging.getLogger(__name__).exception("Error during TLS handshake") + + # Only reraise base exceptions and cancellation exceptions + if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()): + raise + + async def serve( + self, + handler: Callable[[TLSStream], Any], + task_group: TaskGroup | None = None, + ) -> None: + @wraps(handler) + async def handler_wrapper(stream: AnyByteStream) -> None: + from .. import fail_after + + try: + with fail_after(self.handshake_timeout): + wrapped_stream = await TLSStream.wrap( + stream, + ssl_context=self.ssl_context, + standard_compatible=self.standard_compatible, + ) + except BaseException as exc: + await self.handle_handshake_error(exc, stream) + else: + await handler(wrapped_stream) + + await self.listener.serve(handler_wrapper, task_group) + + async def aclose(self) -> None: + await self.listener.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + TLSAttribute.standard_compatible: lambda: self.standard_compatible, + } diff --git a/contrib/python/anyio/anyio/to_process.py b/contrib/python/anyio/anyio/to_process.py new file mode 100644 index 0000000000..7ba9d44198 --- /dev/null +++ b/contrib/python/anyio/anyio/to_process.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import os +import pickle +import subprocess +import sys +from collections import deque +from importlib.util import module_from_spec, spec_from_file_location +from typing import Callable, TypeVar, cast + +from ._core._eventloop import current_time, get_asynclib, get_cancelled_exc_class +from ._core._exceptions import BrokenWorkerProcess +from ._core._subprocesses import open_process +from ._core._synchronization import CapacityLimiter +from ._core._tasks import CancelScope, fail_after +from .abc import ByteReceiveStream, ByteSendStream, Process +from .lowlevel import RunVar, checkpoint_if_cancelled +from .streams.buffered import BufferedByteReceiveStream + +WORKER_MAX_IDLE_TIME = 300 # 5 minutes + +T_Retval = TypeVar("T_Retval") +_process_pool_workers: RunVar[set[Process]] = RunVar("_process_pool_workers") +_process_pool_idle_workers: RunVar[deque[tuple[Process, float]]] = RunVar( + "_process_pool_idle_workers" +) +_default_process_limiter: RunVar[CapacityLimiter] = RunVar("_default_process_limiter") + + +async def run_sync( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + """ + Call the given function with the given arguments in a worker process. + + If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled, + the worker process running it will be abruptly terminated using SIGKILL (or + ``terminateProcess()`` on Windows). + + :param func: a callable + :param args: positional arguments for the callable + :param cancellable: ``True`` to allow cancellation of the operation while it's running + :param limiter: capacity limiter to use to limit the total amount of processes running + (if omitted, the default limiter is used) + :return: an awaitable that yields the return value of the function. + + """ + + async def send_raw_command(pickled_cmd: bytes) -> object: + try: + await stdin.send(pickled_cmd) + response = await buffered.receive_until(b"\n", 50) + status, length = response.split(b" ") + if status not in (b"RETURN", b"EXCEPTION"): + raise RuntimeError( + f"Worker process returned unexpected response: {response!r}" + ) + + pickled_response = await buffered.receive_exactly(int(length)) + except BaseException as exc: + workers.discard(process) + try: + process.kill() + with CancelScope(shield=True): + await process.aclose() + except ProcessLookupError: + pass + + if isinstance(exc, get_cancelled_exc_class()): + raise + else: + raise BrokenWorkerProcess from exc + + retval = pickle.loads(pickled_response) + if status == b"EXCEPTION": + assert isinstance(retval, BaseException) + raise retval + else: + return retval + + # First pickle the request before trying to reserve a worker process + await checkpoint_if_cancelled() + request = pickle.dumps(("run", func, args), protocol=pickle.HIGHEST_PROTOCOL) + + # If this is the first run in this event loop thread, set up the necessary variables + try: + workers = _process_pool_workers.get() + idle_workers = _process_pool_idle_workers.get() + except LookupError: + workers = set() + idle_workers = deque() + _process_pool_workers.set(workers) + _process_pool_idle_workers.set(idle_workers) + get_asynclib().setup_process_pool_exit_at_shutdown(workers) + + async with (limiter or current_default_process_limiter()): + # Pop processes from the pool (starting from the most recently used) until we find one that + # hasn't exited yet + process: Process + while idle_workers: + process, idle_since = idle_workers.pop() + if process.returncode is None: + stdin = cast(ByteSendStream, process.stdin) + buffered = BufferedByteReceiveStream( + cast(ByteReceiveStream, process.stdout) + ) + + # Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME seconds or + # longer + now = current_time() + killed_processes: list[Process] = [] + while idle_workers: + if now - idle_workers[0][1] < WORKER_MAX_IDLE_TIME: + break + + process, idle_since = idle_workers.popleft() + process.kill() + workers.remove(process) + killed_processes.append(process) + + with CancelScope(shield=True): + for process in killed_processes: + await process.aclose() + + break + + workers.remove(process) + else: + command = [sys.executable, "-u", "-m", __name__] + process = await open_process( + command, stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) + try: + stdin = cast(ByteSendStream, process.stdin) + buffered = BufferedByteReceiveStream( + cast(ByteReceiveStream, process.stdout) + ) + with fail_after(20): + message = await buffered.receive(6) + + if message != b"READY\n": + raise BrokenWorkerProcess( + f"Worker process returned unexpected response: {message!r}" + ) + + main_module_path = getattr(sys.modules["__main__"], "__file__", None) + pickled = pickle.dumps( + ("init", sys.path, main_module_path), + protocol=pickle.HIGHEST_PROTOCOL, + ) + await send_raw_command(pickled) + except (BrokenWorkerProcess, get_cancelled_exc_class()): + raise + except BaseException as exc: + process.kill() + raise BrokenWorkerProcess( + "Error during worker process initialization" + ) from exc + + workers.add(process) + + with CancelScope(shield=not cancellable): + try: + return cast(T_Retval, await send_raw_command(request)) + finally: + if process in workers: + idle_workers.append((process, current_time())) + + +def current_default_process_limiter() -> CapacityLimiter: + """ + Return the capacity limiter that is used by default to limit the number of worker processes. + + :return: a capacity limiter object + + """ + try: + return _default_process_limiter.get() + except LookupError: + limiter = CapacityLimiter(os.cpu_count() or 2) + _default_process_limiter.set(limiter) + return limiter + + +def process_worker() -> None: + # Redirect standard streams to os.devnull so that user code won't interfere with the + # parent-worker communication + stdin = sys.stdin + stdout = sys.stdout + sys.stdin = open(os.devnull) + sys.stdout = open(os.devnull, "w") + + stdout.buffer.write(b"READY\n") + while True: + retval = exception = None + try: + command, *args = pickle.load(stdin.buffer) + except EOFError: + return + except BaseException as exc: + exception = exc + else: + if command == "run": + func, args = args + try: + retval = func(*args) + except BaseException as exc: + exception = exc + elif command == "init": + main_module_path: str | None + sys.path, main_module_path = args + del sys.modules["__main__"] + if main_module_path: + # Load the parent's main module but as __mp_main__ instead of __main__ + # (like multiprocessing does) to avoid infinite recursion + try: + spec = spec_from_file_location("__mp_main__", main_module_path) + if spec and spec.loader: + main = module_from_spec(spec) + spec.loader.exec_module(main) + sys.modules["__main__"] = main + except BaseException as exc: + exception = exc + + try: + if exception is not None: + status = b"EXCEPTION" + pickled = pickle.dumps(exception, pickle.HIGHEST_PROTOCOL) + else: + status = b"RETURN" + pickled = pickle.dumps(retval, pickle.HIGHEST_PROTOCOL) + except BaseException as exc: + exception = exc + status = b"EXCEPTION" + pickled = pickle.dumps(exc, pickle.HIGHEST_PROTOCOL) + + stdout.buffer.write(b"%s %d\n" % (status, len(pickled))) + stdout.buffer.write(pickled) + + # Respect SIGTERM + if isinstance(exception, SystemExit): + raise exception + + +if __name__ == "__main__": + process_worker() diff --git a/contrib/python/anyio/anyio/to_thread.py b/contrib/python/anyio/anyio/to_thread.py new file mode 100644 index 0000000000..9315d1ecf1 --- /dev/null +++ b/contrib/python/anyio/anyio/to_thread.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Callable, TypeVar +from warnings import warn + +from ._core._eventloop import get_asynclib +from .abc import CapacityLimiter + +T_Retval = TypeVar("T_Retval") + + +async def run_sync( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + """ + Call the given function with the given arguments in a worker thread. + + If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled, + the thread will still run its course but its return value (or any raised exception) will be + ignored. + + :param func: a callable + :param args: positional arguments for the callable + :param cancellable: ``True`` to allow cancellation of the operation + :param limiter: capacity limiter to use to limit the total amount of threads running + (if omitted, the default limiter is used) + :return: an awaitable that yields the return value of the function. + + """ + return await get_asynclib().run_sync_in_worker_thread( + func, *args, cancellable=cancellable, limiter=limiter + ) + + +async def run_sync_in_worker_thread( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + warn( + "run_sync_in_worker_thread() has been deprecated, use anyio.to_thread.run_sync() instead", + DeprecationWarning, + ) + return await run_sync(func, *args, cancellable=cancellable, limiter=limiter) + + +def current_default_thread_limiter() -> CapacityLimiter: + """ + Return the capacity limiter that is used by default to limit the number of concurrent threads. + + :return: a capacity limiter object + + """ + return get_asynclib().current_default_thread_limiter() + + +def current_default_worker_thread_limiter() -> CapacityLimiter: + warn( + "current_default_worker_thread_limiter() has been deprecated, " + "use anyio.to_thread.current_default_thread_limiter() instead", + DeprecationWarning, + ) + return current_default_thread_limiter() diff --git a/contrib/python/anyio/ya.make b/contrib/python/anyio/ya.make new file mode 100644 index 0000000000..f8534a7d6c --- /dev/null +++ b/contrib/python/anyio/ya.make @@ -0,0 +1,70 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +VERSION(3.7.1) + +LICENSE(MIT) + +PEERDIR( + contrib/python/idna + contrib/python/sniffio +) + +NO_LINT() + +NO_CHECK_IMPORTS( + anyio._backends._trio + anyio.pytest_plugin +) + +PY_SRCS( + TOP_LEVEL + anyio/__init__.py + anyio/_backends/__init__.py + anyio/_backends/_asyncio.py + anyio/_backends/_trio.py + anyio/_core/__init__.py + anyio/_core/_compat.py + anyio/_core/_eventloop.py + anyio/_core/_exceptions.py + anyio/_core/_fileio.py + anyio/_core/_resources.py + anyio/_core/_signals.py + anyio/_core/_sockets.py + anyio/_core/_streams.py + anyio/_core/_subprocesses.py + anyio/_core/_synchronization.py + anyio/_core/_tasks.py + anyio/_core/_testing.py + anyio/_core/_typedattr.py + anyio/abc/__init__.py + anyio/abc/_resources.py + anyio/abc/_sockets.py + anyio/abc/_streams.py + anyio/abc/_subprocesses.py + anyio/abc/_tasks.py + anyio/abc/_testing.py + anyio/from_thread.py + anyio/lowlevel.py + anyio/pytest_plugin.py + anyio/streams/__init__.py + anyio/streams/buffered.py + anyio/streams/file.py + anyio/streams/memory.py + anyio/streams/stapled.py + anyio/streams/text.py + anyio/streams/tls.py + anyio/to_process.py + anyio/to_thread.py +) + +RESOURCE_FILES( + PREFIX contrib/python/anyio/ + .dist-info/METADATA + .dist-info/entry_points.txt + .dist-info/top_level.txt + anyio/py.typed +) + +END() diff --git a/contrib/python/h11/.dist-info/METADATA b/contrib/python/h11/.dist-info/METADATA new file mode 100644 index 0000000000..cf12a82f19 --- /dev/null +++ b/contrib/python/h11/.dist-info/METADATA @@ -0,0 +1,193 @@ +Metadata-Version: 2.1 +Name: h11 +Version: 0.14.0 +Summary: A pure-Python, bring-your-own-I/O implementation of HTTP/1.1 +Home-page: https://github.com/python-hyper/h11 +Author: Nathaniel J. Smith +Author-email: njs@pobox.com +License: MIT +Classifier: Development Status :: 3 - Alpha +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Topic :: Internet :: WWW/HTTP +Classifier: Topic :: System :: Networking +Requires-Python: >=3.7 +License-File: LICENSE.txt +Requires-Dist: typing-extensions ; python_version < "3.8" + +h11 +=== + +.. image:: https://travis-ci.org/python-hyper/h11.svg?branch=master + :target: https://travis-ci.org/python-hyper/h11 + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-hyper/h11/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-hyper/h11 + :alt: Test coverage + +.. image:: https://readthedocs.org/projects/h11/badge/?version=latest + :target: http://h11.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +This is a little HTTP/1.1 library written from scratch in Python, +heavily inspired by `hyper-h2 <https://hyper-h2.readthedocs.io/>`_. + +It's a "bring-your-own-I/O" library; h11 contains no IO code +whatsoever. This means you can hook h11 up to your favorite network +API, and that could be anything you want: synchronous, threaded, +asynchronous, or your own implementation of `RFC 6214 +<https://tools.ietf.org/html/rfc6214>`_ -- h11 won't judge you. +(Compare this to the current state of the art, where every time a `new +network API <https://trio.readthedocs.io/>`_ comes along then someone +gets to start over reimplementing the entire HTTP protocol from +scratch.) Cory Benfield made an `excellent blog post describing the +benefits of this approach +<https://lukasa.co.uk/2015/10/The_New_Hyper/>`_, or if you like video +then here's his `PyCon 2016 talk on the same theme +<https://www.youtube.com/watch?v=7cC3_jGwl_U>`_. + +This also means that h11 is not immediately useful out of the box: +it's a toolkit for building programs that speak HTTP, not something +that could directly replace ``requests`` or ``twisted.web`` or +whatever. But h11 makes it much easier to implement something like +``requests`` or ``twisted.web``. + +At a high level, working with h11 goes like this: + +1) First, create an ``h11.Connection`` object to track the state of a + single HTTP/1.1 connection. + +2) When you read data off the network, pass it to + ``conn.receive_data(...)``; you'll get back a list of objects + representing high-level HTTP "events". + +3) When you want to send a high-level HTTP event, create the + corresponding "event" object and pass it to ``conn.send(...)``; + this will give you back some bytes that you can then push out + through the network. + +For example, a client might instantiate and then send a +``h11.Request`` object, then zero or more ``h11.Data`` objects for the +request body (e.g., if this is a POST), and then a +``h11.EndOfMessage`` to indicate the end of the message. Then the +server would then send back a ``h11.Response``, some ``h11.Data``, and +its own ``h11.EndOfMessage``. If either side violates the protocol, +you'll get a ``h11.ProtocolError`` exception. + +h11 is suitable for implementing both servers and clients, and has a +pleasantly symmetric API: the events you send as a client are exactly +the ones that you receive as a server and vice-versa. + +`Here's an example of a tiny HTTP client +<https://github.com/python-hyper/h11/blob/master/examples/basic-client.py>`_ + +It also has `a fine manual <https://h11.readthedocs.io/>`_. + +FAQ +--- + +*Whyyyyy?* + +I wanted to play with HTTP in `Curio +<https://curio.readthedocs.io/en/latest/tutorial.html>`__ and `Trio +<https://trio.readthedocs.io>`__, which at the time didn't have any +HTTP libraries. So I thought, no big deal, Python has, like, a dozen +different implementations of HTTP, surely I can find one that's +reusable. I didn't find one, but I did find Cory's call-to-arms +blog-post. So I figured, well, fine, if I have to implement HTTP from +scratch, at least I can make sure no-one *else* has to ever again. + +*Should I use it?* + +Maybe. You should be aware that it's a very young project. But, it's +feature complete and has an exhaustive test-suite and complete docs, +so the next step is for people to try using it and see how it goes +:-). If you do then please let us know -- if nothing else we'll want +to talk to you before making any incompatible changes! + +*What are the features/limitations?* + +Roughly speaking, it's trying to be a robust, complete, and non-hacky +implementation of the first "chapter" of the HTTP/1.1 spec: `RFC 7230: +HTTP/1.1 Message Syntax and Routing +<https://tools.ietf.org/html/rfc7230>`_. That is, it mostly focuses on +implementing HTTP at the level of taking bytes on and off the wire, +and the headers related to that, and tries to be anal about spec +conformance. It doesn't know about higher-level concerns like URL +routing, conditional GETs, cross-origin cookie policies, or content +negotiation. But it does know how to take care of framing, +cross-version differences in keep-alive handling, and the "obsolete +line folding" rule, so you can focus your energies on the hard / +interesting parts for your application, and it tries to support the +full specification in the sense that any useful HTTP/1.1 conformant +application should be able to use h11. + +It's pure Python, and has no dependencies outside of the standard +library. + +It has a test suite with 100.0% coverage for both statements and +branches. + +Currently it supports Python 3 (testing on 3.7-3.10) and PyPy 3. +The last Python 2-compatible version was h11 0.11.x. +(Originally it had a Cython wrapper for `http-parser +<https://github.com/nodejs/http-parser>`_ and a beautiful nested state +machine implemented with ``yield from`` to postprocess the output. But +I had to take these out -- the new *parser* needs fewer lines-of-code +than the old *parser wrapper*, is written in pure Python, uses no +exotic language syntax, and has more features. It's sad, really; that +old state machine was really slick. I just need a few sentences here +to mourn that.) + +I don't know how fast it is. I haven't benchmarked or profiled it yet, +so it's probably got a few pointless hot spots, and I've been trying +to err on the side of simplicity and robustness instead of +micro-optimization. But at the architectural level I tried hard to +avoid fundamentally bad decisions, e.g., I believe that all the +parsing algorithms remain linear-time even in the face of pathological +input like slowloris, and there are no byte-by-byte loops. (I also +believe that it maintains bounded memory usage in the face of +arbitrary/pathological input.) + +The whole library is ~800 lines-of-code. You can read and understand +the whole thing in less than an hour. Most of the energy invested in +this so far has been spent on trying to keep things simple by +minimizing special-cases and ad hoc state manipulation; even though it +is now quite small and simple, I'm still annoyed that I haven't +figured out how to make it even smaller and simpler. (Unfortunately, +HTTP does not lend itself to simplicity.) + +The API is ~feature complete and I don't expect the general outlines +to change much, but you can't judge an API's ergonomics until you +actually document and use it, so I'd expect some changes in the +details. + +*How do I try it?* + +.. code-block:: sh + + $ pip install h11 + $ git clone git@github.com:python-hyper/h11 + $ cd h11/examples + $ python basic-client.py + +and go from there. + +*License?* + +MIT + +*Code of conduct?* + +Contributors are requested to follow our `code of conduct +<https://github.com/python-hyper/h11/blob/master/CODE_OF_CONDUCT.md>`_ in +all project spaces. diff --git a/contrib/python/h11/.dist-info/top_level.txt b/contrib/python/h11/.dist-info/top_level.txt new file mode 100644 index 0000000000..0d24def711 --- /dev/null +++ b/contrib/python/h11/.dist-info/top_level.txt @@ -0,0 +1 @@ +h11 diff --git a/contrib/python/h11/LICENSE.txt b/contrib/python/h11/LICENSE.txt new file mode 100644 index 0000000000..8f080eae84 --- /dev/null +++ b/contrib/python/h11/LICENSE.txt @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2016 Nathaniel J. Smith <njs@pobox.com> and other contributors + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/contrib/python/h11/README.rst b/contrib/python/h11/README.rst new file mode 100644 index 0000000000..56e277e3d1 --- /dev/null +++ b/contrib/python/h11/README.rst @@ -0,0 +1,168 @@ +h11 +=== + +.. image:: https://travis-ci.org/python-hyper/h11.svg?branch=master + :target: https://travis-ci.org/python-hyper/h11 + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-hyper/h11/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-hyper/h11 + :alt: Test coverage + +.. image:: https://readthedocs.org/projects/h11/badge/?version=latest + :target: http://h11.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +This is a little HTTP/1.1 library written from scratch in Python, +heavily inspired by `hyper-h2 <https://hyper-h2.readthedocs.io/>`_. + +It's a "bring-your-own-I/O" library; h11 contains no IO code +whatsoever. This means you can hook h11 up to your favorite network +API, and that could be anything you want: synchronous, threaded, +asynchronous, or your own implementation of `RFC 6214 +<https://tools.ietf.org/html/rfc6214>`_ -- h11 won't judge you. +(Compare this to the current state of the art, where every time a `new +network API <https://trio.readthedocs.io/>`_ comes along then someone +gets to start over reimplementing the entire HTTP protocol from +scratch.) Cory Benfield made an `excellent blog post describing the +benefits of this approach +<https://lukasa.co.uk/2015/10/The_New_Hyper/>`_, or if you like video +then here's his `PyCon 2016 talk on the same theme +<https://www.youtube.com/watch?v=7cC3_jGwl_U>`_. + +This also means that h11 is not immediately useful out of the box: +it's a toolkit for building programs that speak HTTP, not something +that could directly replace ``requests`` or ``twisted.web`` or +whatever. But h11 makes it much easier to implement something like +``requests`` or ``twisted.web``. + +At a high level, working with h11 goes like this: + +1) First, create an ``h11.Connection`` object to track the state of a + single HTTP/1.1 connection. + +2) When you read data off the network, pass it to + ``conn.receive_data(...)``; you'll get back a list of objects + representing high-level HTTP "events". + +3) When you want to send a high-level HTTP event, create the + corresponding "event" object and pass it to ``conn.send(...)``; + this will give you back some bytes that you can then push out + through the network. + +For example, a client might instantiate and then send a +``h11.Request`` object, then zero or more ``h11.Data`` objects for the +request body (e.g., if this is a POST), and then a +``h11.EndOfMessage`` to indicate the end of the message. Then the +server would then send back a ``h11.Response``, some ``h11.Data``, and +its own ``h11.EndOfMessage``. If either side violates the protocol, +you'll get a ``h11.ProtocolError`` exception. + +h11 is suitable for implementing both servers and clients, and has a +pleasantly symmetric API: the events you send as a client are exactly +the ones that you receive as a server and vice-versa. + +`Here's an example of a tiny HTTP client +<https://github.com/python-hyper/h11/blob/master/examples/basic-client.py>`_ + +It also has `a fine manual <https://h11.readthedocs.io/>`_. + +FAQ +--- + +*Whyyyyy?* + +I wanted to play with HTTP in `Curio +<https://curio.readthedocs.io/en/latest/tutorial.html>`__ and `Trio +<https://trio.readthedocs.io>`__, which at the time didn't have any +HTTP libraries. So I thought, no big deal, Python has, like, a dozen +different implementations of HTTP, surely I can find one that's +reusable. I didn't find one, but I did find Cory's call-to-arms +blog-post. So I figured, well, fine, if I have to implement HTTP from +scratch, at least I can make sure no-one *else* has to ever again. + +*Should I use it?* + +Maybe. You should be aware that it's a very young project. But, it's +feature complete and has an exhaustive test-suite and complete docs, +so the next step is for people to try using it and see how it goes +:-). If you do then please let us know -- if nothing else we'll want +to talk to you before making any incompatible changes! + +*What are the features/limitations?* + +Roughly speaking, it's trying to be a robust, complete, and non-hacky +implementation of the first "chapter" of the HTTP/1.1 spec: `RFC 7230: +HTTP/1.1 Message Syntax and Routing +<https://tools.ietf.org/html/rfc7230>`_. That is, it mostly focuses on +implementing HTTP at the level of taking bytes on and off the wire, +and the headers related to that, and tries to be anal about spec +conformance. It doesn't know about higher-level concerns like URL +routing, conditional GETs, cross-origin cookie policies, or content +negotiation. But it does know how to take care of framing, +cross-version differences in keep-alive handling, and the "obsolete +line folding" rule, so you can focus your energies on the hard / +interesting parts for your application, and it tries to support the +full specification in the sense that any useful HTTP/1.1 conformant +application should be able to use h11. + +It's pure Python, and has no dependencies outside of the standard +library. + +It has a test suite with 100.0% coverage for both statements and +branches. + +Currently it supports Python 3 (testing on 3.7-3.10) and PyPy 3. +The last Python 2-compatible version was h11 0.11.x. +(Originally it had a Cython wrapper for `http-parser +<https://github.com/nodejs/http-parser>`_ and a beautiful nested state +machine implemented with ``yield from`` to postprocess the output. But +I had to take these out -- the new *parser* needs fewer lines-of-code +than the old *parser wrapper*, is written in pure Python, uses no +exotic language syntax, and has more features. It's sad, really; that +old state machine was really slick. I just need a few sentences here +to mourn that.) + +I don't know how fast it is. I haven't benchmarked or profiled it yet, +so it's probably got a few pointless hot spots, and I've been trying +to err on the side of simplicity and robustness instead of +micro-optimization. But at the architectural level I tried hard to +avoid fundamentally bad decisions, e.g., I believe that all the +parsing algorithms remain linear-time even in the face of pathological +input like slowloris, and there are no byte-by-byte loops. (I also +believe that it maintains bounded memory usage in the face of +arbitrary/pathological input.) + +The whole library is ~800 lines-of-code. You can read and understand +the whole thing in less than an hour. Most of the energy invested in +this so far has been spent on trying to keep things simple by +minimizing special-cases and ad hoc state manipulation; even though it +is now quite small and simple, I'm still annoyed that I haven't +figured out how to make it even smaller and simpler. (Unfortunately, +HTTP does not lend itself to simplicity.) + +The API is ~feature complete and I don't expect the general outlines +to change much, but you can't judge an API's ergonomics until you +actually document and use it, so I'd expect some changes in the +details. + +*How do I try it?* + +.. code-block:: sh + + $ pip install h11 + $ git clone git@github.com:python-hyper/h11 + $ cd h11/examples + $ python basic-client.py + +and go from there. + +*License?* + +MIT + +*Code of conduct?* + +Contributors are requested to follow our `code of conduct +<https://github.com/python-hyper/h11/blob/master/CODE_OF_CONDUCT.md>`_ in +all project spaces. diff --git a/contrib/python/h11/h11/__init__.py b/contrib/python/h11/h11/__init__.py new file mode 100644 index 0000000000..989e92c345 --- /dev/null +++ b/contrib/python/h11/h11/__init__.py @@ -0,0 +1,62 @@ +# A highish-level implementation of the HTTP/1.1 wire protocol (RFC 7230), +# containing no networking code at all, loosely modelled on hyper-h2's generic +# implementation of HTTP/2 (and in particular the h2.connection.H2Connection +# class). There's still a bunch of subtle details you need to get right if you +# want to make this actually useful, because it doesn't implement all the +# semantics to check that what you're asking to write to the wire is sensible, +# but at least it gets you out of dealing with the wire itself. + +from h11._connection import Connection, NEED_DATA, PAUSED +from h11._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from h11._state import ( + CLIENT, + CLOSED, + DONE, + ERROR, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) +from h11._util import LocalProtocolError, ProtocolError, RemoteProtocolError +from h11._version import __version__ + +PRODUCT_ID = "python-h11/" + __version__ + + +__all__ = ( + "Connection", + "NEED_DATA", + "PAUSED", + "ConnectionClosed", + "Data", + "EndOfMessage", + "Event", + "InformationalResponse", + "Request", + "Response", + "CLIENT", + "CLOSED", + "DONE", + "ERROR", + "IDLE", + "MUST_CLOSE", + "SEND_BODY", + "SEND_RESPONSE", + "SERVER", + "SWITCHED_PROTOCOL", + "ProtocolError", + "LocalProtocolError", + "RemoteProtocolError", +) diff --git a/contrib/python/h11/h11/_abnf.py b/contrib/python/h11/h11/_abnf.py new file mode 100644 index 0000000000..933587fba2 --- /dev/null +++ b/contrib/python/h11/h11/_abnf.py @@ -0,0 +1,132 @@ +# We use native strings for all the re patterns, to take advantage of string +# formatting, and then convert to bytestrings when compiling the final re +# objects. + +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#whitespace +# OWS = *( SP / HTAB ) +# ; optional whitespace +OWS = r"[ \t]*" + +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#rule.token.separators +# token = 1*tchar +# +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" +# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" +# / DIGIT / ALPHA +# ; any VCHAR, except delimiters +token = r"[-!#$%&'*+.^_`|~0-9a-zA-Z]+" + +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#header.fields +# field-name = token +field_name = token + +# The standard says: +# +# field-value = *( field-content / obs-fold ) +# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +# field-vchar = VCHAR / obs-text +# obs-fold = CRLF 1*( SP / HTAB ) +# ; obsolete line folding +# ; see Section 3.2.4 +# +# https://tools.ietf.org/html/rfc5234#appendix-B.1 +# +# VCHAR = %x21-7E +# ; visible (printing) characters +# +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#rule.quoted-string +# obs-text = %x80-FF +# +# However, the standard definition of field-content is WRONG! It disallows +# fields containing a single visible character surrounded by whitespace, +# e.g. "foo a bar". +# +# See: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 +# +# So our definition of field_content attempts to fix it up... +# +# Also, we allow lots of control characters, because apparently people assume +# that they're legal in practice (e.g., google analytics makes cookies with +# \x01 in them!): +# https://github.com/python-hyper/h11/issues/57 +# We still don't allow NUL or whitespace, because those are often treated as +# meta-characters and letting them through can lead to nasty issues like SSRF. +vchar = r"[\x21-\x7e]" +vchar_or_obs_text = r"[^\x00\s]" +field_vchar = vchar_or_obs_text +field_content = r"{field_vchar}+(?:[ \t]+{field_vchar}+)*".format(**globals()) + +# We handle obs-fold at a different level, and our fixed-up field_content +# already grows to swallow the whole value, so ? instead of * +field_value = r"({field_content})?".format(**globals()) + +# header-field = field-name ":" OWS field-value OWS +header_field = ( + r"(?P<field_name>{field_name})" + r":" + r"{OWS}" + r"(?P<field_value>{field_value})" + r"{OWS}".format(**globals()) +) + +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#request.line +# +# request-line = method SP request-target SP HTTP-version CRLF +# method = token +# HTTP-version = HTTP-name "/" DIGIT "." DIGIT +# HTTP-name = %x48.54.54.50 ; "HTTP", case-sensitive +# +# request-target is complicated (see RFC 7230 sec 5.3) -- could be path, full +# URL, host+port (for connect), or even "*", but in any case we are guaranteed +# that it contists of the visible printing characters. +method = token +request_target = r"{vchar}+".format(**globals()) +http_version = r"HTTP/(?P<http_version>[0-9]\.[0-9])" +request_line = ( + r"(?P<method>{method})" + r" " + r"(?P<target>{request_target})" + r" " + r"{http_version}".format(**globals()) +) + +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#status.line +# +# status-line = HTTP-version SP status-code SP reason-phrase CRLF +# status-code = 3DIGIT +# reason-phrase = *( HTAB / SP / VCHAR / obs-text ) +status_code = r"[0-9]{3}" +reason_phrase = r"([ \t]|{vchar_or_obs_text})*".format(**globals()) +status_line = ( + r"{http_version}" + r" " + r"(?P<status_code>{status_code})" + # However, there are apparently a few too many servers out there that just + # leave out the reason phrase: + # https://github.com/scrapy/scrapy/issues/345#issuecomment-281756036 + # https://github.com/seanmonstar/httparse/issues/29 + # so make it optional. ?: is a non-capturing group. + r"(?: (?P<reason>{reason_phrase}))?".format(**globals()) +) + +HEXDIG = r"[0-9A-Fa-f]" +# Actually +# +# chunk-size = 1*HEXDIG +# +# but we impose an upper-limit to avoid ridiculosity. len(str(2**64)) == 20 +chunk_size = r"({HEXDIG}){{1,20}}".format(**globals()) +# Actually +# +# chunk-ext = *( ";" chunk-ext-name [ "=" chunk-ext-val ] ) +# +# but we aren't parsing the things so we don't really care. +chunk_ext = r";.*" +chunk_header = ( + r"(?P<chunk_size>{chunk_size})" + r"(?P<chunk_ext>{chunk_ext})?" + r"{OWS}\r\n".format( + **globals() + ) # Even though the specification does not allow for extra whitespaces, + # we are lenient with trailing whitespaces because some servers on the wild use it. +) diff --git a/contrib/python/h11/h11/_connection.py b/contrib/python/h11/h11/_connection.py new file mode 100644 index 0000000000..d175270759 --- /dev/null +++ b/contrib/python/h11/h11/_connection.py @@ -0,0 +1,633 @@ +# This contains the main Connection class. Everything in h11 revolves around +# this. +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union + +from ._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from ._headers import get_comma_header, has_expect_100_continue, set_comma_header +from ._readers import READERS, ReadersType +from ._receivebuffer import ReceiveBuffer +from ._state import ( + _SWITCH_CONNECT, + _SWITCH_UPGRADE, + CLIENT, + ConnectionState, + DONE, + ERROR, + MIGHT_SWITCH_PROTOCOL, + SEND_BODY, + SERVER, + SWITCHED_PROTOCOL, +) +from ._util import ( # Import the internal things we need + LocalProtocolError, + RemoteProtocolError, + Sentinel, +) +from ._writers import WRITERS, WritersType + +# Everything in __all__ gets re-exported as part of the h11 public API. +__all__ = ["Connection", "NEED_DATA", "PAUSED"] + + +class NEED_DATA(Sentinel, metaclass=Sentinel): + pass + + +class PAUSED(Sentinel, metaclass=Sentinel): + pass + + +# If we ever have this much buffered without it making a complete parseable +# event, we error out. The only time we really buffer is when reading the +# request/response line + headers together, so this is effectively the limit on +# the size of that. +# +# Some precedents for defaults: +# - node.js: 80 * 1024 +# - tomcat: 8 * 1024 +# - IIS: 16 * 1024 +# - Apache: <8 KiB per line> +DEFAULT_MAX_INCOMPLETE_EVENT_SIZE = 16 * 1024 + +# RFC 7230's rules for connection lifecycles: +# - If either side says they want to close the connection, then the connection +# must close. +# - HTTP/1.1 defaults to keep-alive unless someone says Connection: close +# - HTTP/1.0 defaults to close unless both sides say Connection: keep-alive +# (and even this is a mess -- e.g. if you're implementing a proxy then +# sending Connection: keep-alive is forbidden). +# +# We simplify life by simply not supporting keep-alive with HTTP/1.0 peers. So +# our rule is: +# - If someone says Connection: close, we will close +# - If someone uses HTTP/1.0, we will close. +def _keep_alive(event: Union[Request, Response]) -> bool: + connection = get_comma_header(event.headers, b"connection") + if b"close" in connection: + return False + if getattr(event, "http_version", b"1.1") < b"1.1": + return False + return True + + +def _body_framing( + request_method: bytes, event: Union[Request, Response] +) -> Tuple[str, Union[Tuple[()], Tuple[int]]]: + # Called when we enter SEND_BODY to figure out framing information for + # this body. + # + # These are the only two events that can trigger a SEND_BODY state: + assert type(event) in (Request, Response) + # Returns one of: + # + # ("content-length", count) + # ("chunked", ()) + # ("http/1.0", ()) + # + # which are (lookup key, *args) for constructing body reader/writer + # objects. + # + # Reference: https://tools.ietf.org/html/rfc7230#section-3.3.3 + # + # Step 1: some responses always have an empty body, regardless of what the + # headers say. + if type(event) is Response: + if ( + event.status_code in (204, 304) + or request_method == b"HEAD" + or (request_method == b"CONNECT" and 200 <= event.status_code < 300) + ): + return ("content-length", (0,)) + # Section 3.3.3 also lists another case -- responses with status_code + # < 200. For us these are InformationalResponses, not Responses, so + # they can't get into this function in the first place. + assert event.status_code >= 200 + + # Step 2: check for Transfer-Encoding (T-E beats C-L): + transfer_encodings = get_comma_header(event.headers, b"transfer-encoding") + if transfer_encodings: + assert transfer_encodings == [b"chunked"] + return ("chunked", ()) + + # Step 3: check for Content-Length + content_lengths = get_comma_header(event.headers, b"content-length") + if content_lengths: + return ("content-length", (int(content_lengths[0]),)) + + # Step 4: no applicable headers; fallback/default depends on type + if type(event) is Request: + return ("content-length", (0,)) + else: + return ("http/1.0", ()) + + +################################################################ +# +# The main Connection class +# +################################################################ + + +class Connection: + """An object encapsulating the state of an HTTP connection. + + Args: + our_role: If you're implementing a client, pass :data:`h11.CLIENT`. If + you're implementing a server, pass :data:`h11.SERVER`. + + max_incomplete_event_size (int): + The maximum number of bytes we're willing to buffer of an + incomplete event. In practice this mostly sets a limit on the + maximum size of the request/response line + headers. If this is + exceeded, then :meth:`next_event` will raise + :exc:`RemoteProtocolError`. + + """ + + def __init__( + self, + our_role: Type[Sentinel], + max_incomplete_event_size: int = DEFAULT_MAX_INCOMPLETE_EVENT_SIZE, + ) -> None: + self._max_incomplete_event_size = max_incomplete_event_size + # State and role tracking + if our_role not in (CLIENT, SERVER): + raise ValueError("expected CLIENT or SERVER, not {!r}".format(our_role)) + self.our_role = our_role + self.their_role: Type[Sentinel] + if our_role is CLIENT: + self.their_role = SERVER + else: + self.their_role = CLIENT + self._cstate = ConnectionState() + + # Callables for converting data->events or vice-versa given the + # current state + self._writer = self._get_io_object(self.our_role, None, WRITERS) + self._reader = self._get_io_object(self.their_role, None, READERS) + + # Holds any unprocessed received data + self._receive_buffer = ReceiveBuffer() + # If this is true, then it indicates that the incoming connection was + # closed *after* the end of whatever's in self._receive_buffer: + self._receive_buffer_closed = False + + # Extra bits of state that don't fit into the state machine. + # + # These two are only used to interpret framing headers for figuring + # out how to read/write response bodies. their_http_version is also + # made available as a convenient public API. + self.their_http_version: Optional[bytes] = None + self._request_method: Optional[bytes] = None + # This is pure flow-control and doesn't at all affect the set of legal + # transitions, so no need to bother ConnectionState with it: + self.client_is_waiting_for_100_continue = False + + @property + def states(self) -> Dict[Type[Sentinel], Type[Sentinel]]: + """A dictionary like:: + + {CLIENT: <client state>, SERVER: <server state>} + + See :ref:`state-machine` for details. + + """ + return dict(self._cstate.states) + + @property + def our_state(self) -> Type[Sentinel]: + """The current state of whichever role we are playing. See + :ref:`state-machine` for details. + """ + return self._cstate.states[self.our_role] + + @property + def their_state(self) -> Type[Sentinel]: + """The current state of whichever role we are NOT playing. See + :ref:`state-machine` for details. + """ + return self._cstate.states[self.their_role] + + @property + def they_are_waiting_for_100_continue(self) -> bool: + return self.their_role is CLIENT and self.client_is_waiting_for_100_continue + + def start_next_cycle(self) -> None: + """Attempt to reset our connection state for a new request/response + cycle. + + If both client and server are in :data:`DONE` state, then resets them + both to :data:`IDLE` state in preparation for a new request/response + cycle on this same connection. Otherwise, raises a + :exc:`LocalProtocolError`. + + See :ref:`keepalive-and-pipelining`. + + """ + old_states = dict(self._cstate.states) + self._cstate.start_next_cycle() + self._request_method = None + # self.their_http_version gets left alone, since it presumably lasts + # beyond a single request/response cycle + assert not self.client_is_waiting_for_100_continue + self._respond_to_state_changes(old_states) + + def _process_error(self, role: Type[Sentinel]) -> None: + old_states = dict(self._cstate.states) + self._cstate.process_error(role) + self._respond_to_state_changes(old_states) + + def _server_switch_event(self, event: Event) -> Optional[Type[Sentinel]]: + if type(event) is InformationalResponse and event.status_code == 101: + return _SWITCH_UPGRADE + if type(event) is Response: + if ( + _SWITCH_CONNECT in self._cstate.pending_switch_proposals + and 200 <= event.status_code < 300 + ): + return _SWITCH_CONNECT + return None + + # All events go through here + def _process_event(self, role: Type[Sentinel], event: Event) -> None: + # First, pass the event through the state machine to make sure it + # succeeds. + old_states = dict(self._cstate.states) + if role is CLIENT and type(event) is Request: + if event.method == b"CONNECT": + self._cstate.process_client_switch_proposal(_SWITCH_CONNECT) + if get_comma_header(event.headers, b"upgrade"): + self._cstate.process_client_switch_proposal(_SWITCH_UPGRADE) + server_switch_event = None + if role is SERVER: + server_switch_event = self._server_switch_event(event) + self._cstate.process_event(role, type(event), server_switch_event) + + # Then perform the updates triggered by it. + + if type(event) is Request: + self._request_method = event.method + + if role is self.their_role and type(event) in ( + Request, + Response, + InformationalResponse, + ): + event = cast(Union[Request, Response, InformationalResponse], event) + self.their_http_version = event.http_version + + # Keep alive handling + # + # RFC 7230 doesn't really say what one should do if Connection: close + # shows up on a 1xx InformationalResponse. I think the idea is that + # this is not supposed to happen. In any case, if it does happen, we + # ignore it. + if type(event) in (Request, Response) and not _keep_alive( + cast(Union[Request, Response], event) + ): + self._cstate.process_keep_alive_disabled() + + # 100-continue + if type(event) is Request and has_expect_100_continue(event): + self.client_is_waiting_for_100_continue = True + if type(event) in (InformationalResponse, Response): + self.client_is_waiting_for_100_continue = False + if role is CLIENT and type(event) in (Data, EndOfMessage): + self.client_is_waiting_for_100_continue = False + + self._respond_to_state_changes(old_states, event) + + def _get_io_object( + self, + role: Type[Sentinel], + event: Optional[Event], + io_dict: Union[ReadersType, WritersType], + ) -> Optional[Callable[..., Any]]: + # event may be None; it's only used when entering SEND_BODY + state = self._cstate.states[role] + if state is SEND_BODY: + # Special case: the io_dict has a dict of reader/writer factories + # that depend on the request/response framing. + framing_type, args = _body_framing( + cast(bytes, self._request_method), cast(Union[Request, Response], event) + ) + return io_dict[SEND_BODY][framing_type](*args) # type: ignore[index] + else: + # General case: the io_dict just has the appropriate reader/writer + # for this state + return io_dict.get((role, state)) # type: ignore[return-value] + + # This must be called after any action that might have caused + # self._cstate.states to change. + def _respond_to_state_changes( + self, + old_states: Dict[Type[Sentinel], Type[Sentinel]], + event: Optional[Event] = None, + ) -> None: + # Update reader/writer + if self.our_state != old_states[self.our_role]: + self._writer = self._get_io_object(self.our_role, event, WRITERS) + if self.their_state != old_states[self.their_role]: + self._reader = self._get_io_object(self.their_role, event, READERS) + + @property + def trailing_data(self) -> Tuple[bytes, bool]: + """Data that has been received, but not yet processed, represented as + a tuple with two elements, where the first is a byte-string containing + the unprocessed data itself, and the second is a bool that is True if + the receive connection was closed. + + See :ref:`switching-protocols` for discussion of why you'd want this. + """ + return (bytes(self._receive_buffer), self._receive_buffer_closed) + + def receive_data(self, data: bytes) -> None: + """Add data to our internal receive buffer. + + This does not actually do any processing on the data, just stores + it. To trigger processing, you have to call :meth:`next_event`. + + Args: + data (:term:`bytes-like object`): + The new data that was just received. + + Special case: If *data* is an empty byte-string like ``b""``, + then this indicates that the remote side has closed the + connection (end of file). Normally this is convenient, because + standard Python APIs like :meth:`file.read` or + :meth:`socket.recv` use ``b""`` to indicate end-of-file, while + other failures to read are indicated using other mechanisms + like raising :exc:`TimeoutError`. When using such an API you + can just blindly pass through whatever you get from ``read`` + to :meth:`receive_data`, and everything will work. + + But, if you have an API where reading an empty string is a + valid non-EOF condition, then you need to be aware of this and + make sure to check for such strings and avoid passing them to + :meth:`receive_data`. + + Returns: + Nothing, but after calling this you should call :meth:`next_event` + to parse the newly received data. + + Raises: + RuntimeError: + Raised if you pass an empty *data*, indicating EOF, and then + pass a non-empty *data*, indicating more data that somehow + arrived after the EOF. + + (Calling ``receive_data(b"")`` multiple times is fine, + and equivalent to calling it once.) + + """ + if data: + if self._receive_buffer_closed: + raise RuntimeError("received close, then received more data?") + self._receive_buffer += data + else: + self._receive_buffer_closed = True + + def _extract_next_receive_event( + self, + ) -> Union[Event, Type[NEED_DATA], Type[PAUSED]]: + state = self.their_state + # We don't pause immediately when they enter DONE, because even in + # DONE state we can still process a ConnectionClosed() event. But + # if we have data in our buffer, then we definitely aren't getting + # a ConnectionClosed() immediately and we need to pause. + if state is DONE and self._receive_buffer: + return PAUSED + if state is MIGHT_SWITCH_PROTOCOL or state is SWITCHED_PROTOCOL: + return PAUSED + assert self._reader is not None + event = self._reader(self._receive_buffer) + if event is None: + if not self._receive_buffer and self._receive_buffer_closed: + # In some unusual cases (basically just HTTP/1.0 bodies), EOF + # triggers an actual protocol event; in that case, we want to + # return that event, and then the state will change and we'll + # get called again to generate the actual ConnectionClosed(). + if hasattr(self._reader, "read_eof"): + event = self._reader.read_eof() # type: ignore[attr-defined] + else: + event = ConnectionClosed() + if event is None: + event = NEED_DATA + return event # type: ignore[no-any-return] + + def next_event(self) -> Union[Event, Type[NEED_DATA], Type[PAUSED]]: + """Parse the next event out of our receive buffer, update our internal + state, and return it. + + This is a mutating operation -- think of it like calling :func:`next` + on an iterator. + + Returns: + : One of three things: + + 1) An event object -- see :ref:`events`. + + 2) The special constant :data:`NEED_DATA`, which indicates that + you need to read more data from your socket and pass it to + :meth:`receive_data` before this method will be able to return + any more events. + + 3) The special constant :data:`PAUSED`, which indicates that we + are not in a state where we can process incoming data (usually + because the peer has finished their part of the current + request/response cycle, and you have not yet called + :meth:`start_next_cycle`). See :ref:`flow-control` for details. + + Raises: + RemoteProtocolError: + The peer has misbehaved. You should close the connection + (possibly after sending some kind of 4xx response). + + Once this method returns :class:`ConnectionClosed` once, then all + subsequent calls will also return :class:`ConnectionClosed`. + + If this method raises any exception besides :exc:`RemoteProtocolError` + then that's a bug -- if it happens please file a bug report! + + If this method raises any exception then it also sets + :attr:`Connection.their_state` to :data:`ERROR` -- see + :ref:`error-handling` for discussion. + + """ + + if self.their_state is ERROR: + raise RemoteProtocolError("Can't receive data when peer state is ERROR") + try: + event = self._extract_next_receive_event() + if event not in [NEED_DATA, PAUSED]: + self._process_event(self.their_role, cast(Event, event)) + if event is NEED_DATA: + if len(self._receive_buffer) > self._max_incomplete_event_size: + # 431 is "Request header fields too large" which is pretty + # much the only situation where we can get here + raise RemoteProtocolError( + "Receive buffer too long", error_status_hint=431 + ) + if self._receive_buffer_closed: + # We're still trying to complete some event, but that's + # never going to happen because no more data is coming + raise RemoteProtocolError("peer unexpectedly closed connection") + return event + except BaseException as exc: + self._process_error(self.their_role) + if isinstance(exc, LocalProtocolError): + exc._reraise_as_remote_protocol_error() + else: + raise + + def send(self, event: Event) -> Optional[bytes]: + """Convert a high-level event into bytes that can be sent to the peer, + while updating our internal state machine. + + Args: + event: The :ref:`event <events>` to send. + + Returns: + If ``type(event) is ConnectionClosed``, then returns + ``None``. Otherwise, returns a :term:`bytes-like object`. + + Raises: + LocalProtocolError: + Sending this event at this time would violate our + understanding of the HTTP/1.1 protocol. + + If this method raises any exception then it also sets + :attr:`Connection.our_state` to :data:`ERROR` -- see + :ref:`error-handling` for discussion. + + """ + data_list = self.send_with_data_passthrough(event) + if data_list is None: + return None + else: + return b"".join(data_list) + + def send_with_data_passthrough(self, event: Event) -> Optional[List[bytes]]: + """Identical to :meth:`send`, except that in situations where + :meth:`send` returns a single :term:`bytes-like object`, this instead + returns a list of them -- and when sending a :class:`Data` event, this + list is guaranteed to contain the exact object you passed in as + :attr:`Data.data`. See :ref:`sendfile` for discussion. + + """ + if self.our_state is ERROR: + raise LocalProtocolError("Can't send data when our state is ERROR") + try: + if type(event) is Response: + event = self._clean_up_response_headers_for_sending(event) + # We want to call _process_event before calling the writer, + # because if someone tries to do something invalid then this will + # give a sensible error message, while our writers all just assume + # they will only receive valid events. But, _process_event might + # change self._writer. So we have to do a little dance: + writer = self._writer + self._process_event(self.our_role, event) + if type(event) is ConnectionClosed: + return None + else: + # In any situation where writer is None, process_event should + # have raised ProtocolError + assert writer is not None + data_list: List[bytes] = [] + writer(event, data_list.append) + return data_list + except: + self._process_error(self.our_role) + raise + + def send_failed(self) -> None: + """Notify the state machine that we failed to send the data it gave + us. + + This causes :attr:`Connection.our_state` to immediately become + :data:`ERROR` -- see :ref:`error-handling` for discussion. + + """ + self._process_error(self.our_role) + + # When sending a Response, we take responsibility for a few things: + # + # - Sometimes you MUST set Connection: close. We take care of those + # times. (You can also set it yourself if you want, and if you do then + # we'll respect that and close the connection at the right time. But you + # don't have to worry about that unless you want to.) + # + # - The user has to set Content-Length if they want it. Otherwise, for + # responses that have bodies (e.g. not HEAD), then we will automatically + # select the right mechanism for streaming a body of unknown length, + # which depends on depending on the peer's HTTP version. + # + # This function's *only* responsibility is making sure headers are set up + # right -- everything downstream just looks at the headers. There are no + # side channels. + def _clean_up_response_headers_for_sending(self, response: Response) -> Response: + assert type(response) is Response + + headers = response.headers + need_close = False + + # HEAD requests need some special handling: they always act like they + # have Content-Length: 0, and that's how _body_framing treats + # them. But their headers are supposed to match what we would send if + # the request was a GET. (Technically there is one deviation allowed: + # we're allowed to leave out the framing headers -- see + # https://tools.ietf.org/html/rfc7231#section-4.3.2 . But it's just as + # easy to get them right.) + method_for_choosing_headers = cast(bytes, self._request_method) + if method_for_choosing_headers == b"HEAD": + method_for_choosing_headers = b"GET" + framing_type, _ = _body_framing(method_for_choosing_headers, response) + if framing_type in ("chunked", "http/1.0"): + # This response has a body of unknown length. + # If our peer is HTTP/1.1, we use Transfer-Encoding: chunked + # If our peer is HTTP/1.0, we use no framing headers, and close the + # connection afterwards. + # + # Make sure to clear Content-Length (in principle user could have + # set both and then we ignored Content-Length b/c + # Transfer-Encoding overwrote it -- this would be naughty of them, + # but the HTTP spec says that if our peer does this then we have + # to fix it instead of erroring out, so we'll accord the user the + # same respect). + headers = set_comma_header(headers, b"content-length", []) + if self.their_http_version is None or self.their_http_version < b"1.1": + # Either we never got a valid request and are sending back an + # error (their_http_version is None), so we assume the worst; + # or else we did get a valid HTTP/1.0 request, so we know that + # they don't understand chunked encoding. + headers = set_comma_header(headers, b"transfer-encoding", []) + # This is actually redundant ATM, since currently we + # unconditionally disable keep-alive when talking to HTTP/1.0 + # peers. But let's be defensive just in case we add + # Connection: keep-alive support later: + if self._request_method != b"HEAD": + need_close = True + else: + headers = set_comma_header(headers, b"transfer-encoding", [b"chunked"]) + + if not self._cstate.keep_alive or need_close: + # Make sure Connection: close is set + connection = set(get_comma_header(headers, b"connection")) + connection.discard(b"keep-alive") + connection.add(b"close") + headers = set_comma_header(headers, b"connection", sorted(connection)) + + return Response( + headers=headers, + status_code=response.status_code, + http_version=response.http_version, + reason=response.reason, + ) diff --git a/contrib/python/h11/h11/_events.py b/contrib/python/h11/h11/_events.py new file mode 100644 index 0000000000..075bf8a469 --- /dev/null +++ b/contrib/python/h11/h11/_events.py @@ -0,0 +1,369 @@ +# High level events that make up HTTP/1.1 conversations. Loosely inspired by +# the corresponding events in hyper-h2: +# +# http://python-hyper.org/h2/en/stable/api.html#events +# +# Don't subclass these. Stuff will break. + +import re +from abc import ABC +from dataclasses import dataclass, field +from typing import Any, cast, Dict, List, Tuple, Union + +from ._abnf import method, request_target +from ._headers import Headers, normalize_and_validate +from ._util import bytesify, LocalProtocolError, validate + +# Everything in __all__ gets re-exported as part of the h11 public API. +__all__ = [ + "Event", + "Request", + "InformationalResponse", + "Response", + "Data", + "EndOfMessage", + "ConnectionClosed", +] + +method_re = re.compile(method.encode("ascii")) +request_target_re = re.compile(request_target.encode("ascii")) + + +class Event(ABC): + """ + Base class for h11 events. + """ + + __slots__ = () + + +@dataclass(init=False, frozen=True) +class Request(Event): + """The beginning of an HTTP request. + + Fields: + + .. attribute:: method + + An HTTP method, e.g. ``b"GET"`` or ``b"POST"``. Always a byte + string. :term:`Bytes-like objects <bytes-like object>` and native + strings containing only ascii characters will be automatically + converted to byte strings. + + .. attribute:: target + + The target of an HTTP request, e.g. ``b"/index.html"``, or one of the + more exotic formats described in `RFC 7320, section 5.3 + <https://tools.ietf.org/html/rfc7230#section-5.3>`_. Always a byte + string. :term:`Bytes-like objects <bytes-like object>` and native + strings containing only ascii characters will be automatically + converted to byte strings. + + .. attribute:: headers + + Request headers, represented as a list of (name, value) pairs. See + :ref:`the header normalization rules <headers-format>` for details. + + .. attribute:: http_version + + The HTTP protocol version, represented as a byte string like + ``b"1.1"``. See :ref:`the HTTP version normalization rules + <http_version-format>` for details. + + """ + + __slots__ = ("method", "headers", "target", "http_version") + + method: bytes + headers: Headers + target: bytes + http_version: bytes + + def __init__( + self, + *, + method: Union[bytes, str], + headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]], + target: Union[bytes, str], + http_version: Union[bytes, str] = b"1.1", + _parsed: bool = False, + ) -> None: + super().__init__() + if isinstance(headers, Headers): + object.__setattr__(self, "headers", headers) + else: + object.__setattr__( + self, "headers", normalize_and_validate(headers, _parsed=_parsed) + ) + if not _parsed: + object.__setattr__(self, "method", bytesify(method)) + object.__setattr__(self, "target", bytesify(target)) + object.__setattr__(self, "http_version", bytesify(http_version)) + else: + object.__setattr__(self, "method", method) + object.__setattr__(self, "target", target) + object.__setattr__(self, "http_version", http_version) + + # "A server MUST respond with a 400 (Bad Request) status code to any + # HTTP/1.1 request message that lacks a Host header field and to any + # request message that contains more than one Host header field or a + # Host header field with an invalid field-value." + # -- https://tools.ietf.org/html/rfc7230#section-5.4 + host_count = 0 + for name, value in self.headers: + if name == b"host": + host_count += 1 + if self.http_version == b"1.1" and host_count == 0: + raise LocalProtocolError("Missing mandatory Host: header") + if host_count > 1: + raise LocalProtocolError("Found multiple Host: headers") + + validate(method_re, self.method, "Illegal method characters") + validate(request_target_re, self.target, "Illegal target characters") + + # This is an unhashable type. + __hash__ = None # type: ignore + + +@dataclass(init=False, frozen=True) +class _ResponseBase(Event): + __slots__ = ("headers", "http_version", "reason", "status_code") + + headers: Headers + http_version: bytes + reason: bytes + status_code: int + + def __init__( + self, + *, + headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]], + status_code: int, + http_version: Union[bytes, str] = b"1.1", + reason: Union[bytes, str] = b"", + _parsed: bool = False, + ) -> None: + super().__init__() + if isinstance(headers, Headers): + object.__setattr__(self, "headers", headers) + else: + object.__setattr__( + self, "headers", normalize_and_validate(headers, _parsed=_parsed) + ) + if not _parsed: + object.__setattr__(self, "reason", bytesify(reason)) + object.__setattr__(self, "http_version", bytesify(http_version)) + if not isinstance(status_code, int): + raise LocalProtocolError("status code must be integer") + # Because IntEnum objects are instances of int, but aren't + # duck-compatible (sigh), see gh-72. + object.__setattr__(self, "status_code", int(status_code)) + else: + object.__setattr__(self, "reason", reason) + object.__setattr__(self, "http_version", http_version) + object.__setattr__(self, "status_code", status_code) + + self.__post_init__() + + def __post_init__(self) -> None: + pass + + # This is an unhashable type. + __hash__ = None # type: ignore + + +@dataclass(init=False, frozen=True) +class InformationalResponse(_ResponseBase): + """An HTTP informational response. + + Fields: + + .. attribute:: status_code + + The status code of this response, as an integer. For an + :class:`InformationalResponse`, this is always in the range [100, + 200). + + .. attribute:: headers + + Request headers, represented as a list of (name, value) pairs. See + :ref:`the header normalization rules <headers-format>` for + details. + + .. attribute:: http_version + + The HTTP protocol version, represented as a byte string like + ``b"1.1"``. See :ref:`the HTTP version normalization rules + <http_version-format>` for details. + + .. attribute:: reason + + The reason phrase of this response, as a byte string. For example: + ``b"OK"``, or ``b"Not Found"``. + + """ + + def __post_init__(self) -> None: + if not (100 <= self.status_code < 200): + raise LocalProtocolError( + "InformationalResponse status_code should be in range " + "[100, 200), not {}".format(self.status_code) + ) + + # This is an unhashable type. + __hash__ = None # type: ignore + + +@dataclass(init=False, frozen=True) +class Response(_ResponseBase): + """The beginning of an HTTP response. + + Fields: + + .. attribute:: status_code + + The status code of this response, as an integer. For an + :class:`Response`, this is always in the range [200, + 1000). + + .. attribute:: headers + + Request headers, represented as a list of (name, value) pairs. See + :ref:`the header normalization rules <headers-format>` for details. + + .. attribute:: http_version + + The HTTP protocol version, represented as a byte string like + ``b"1.1"``. See :ref:`the HTTP version normalization rules + <http_version-format>` for details. + + .. attribute:: reason + + The reason phrase of this response, as a byte string. For example: + ``b"OK"``, or ``b"Not Found"``. + + """ + + def __post_init__(self) -> None: + if not (200 <= self.status_code < 1000): + raise LocalProtocolError( + "Response status_code should be in range [200, 1000), not {}".format( + self.status_code + ) + ) + + # This is an unhashable type. + __hash__ = None # type: ignore + + +@dataclass(init=False, frozen=True) +class Data(Event): + """Part of an HTTP message body. + + Fields: + + .. attribute:: data + + A :term:`bytes-like object` containing part of a message body. Or, if + using the ``combine=False`` argument to :meth:`Connection.send`, then + any object that your socket writing code knows what to do with, and for + which calling :func:`len` returns the number of bytes that will be + written -- see :ref:`sendfile` for details. + + .. attribute:: chunk_start + + A marker that indicates whether this data object is from the start of a + chunked transfer encoding chunk. This field is ignored when when a Data + event is provided to :meth:`Connection.send`: it is only valid on + events emitted from :meth:`Connection.next_event`. You probably + shouldn't use this attribute at all; see + :ref:`chunk-delimiters-are-bad` for details. + + .. attribute:: chunk_end + + A marker that indicates whether this data object is the last for a + given chunked transfer encoding chunk. This field is ignored when when + a Data event is provided to :meth:`Connection.send`: it is only valid + on events emitted from :meth:`Connection.next_event`. You probably + shouldn't use this attribute at all; see + :ref:`chunk-delimiters-are-bad` for details. + + """ + + __slots__ = ("data", "chunk_start", "chunk_end") + + data: bytes + chunk_start: bool + chunk_end: bool + + def __init__( + self, data: bytes, chunk_start: bool = False, chunk_end: bool = False + ) -> None: + object.__setattr__(self, "data", data) + object.__setattr__(self, "chunk_start", chunk_start) + object.__setattr__(self, "chunk_end", chunk_end) + + # This is an unhashable type. + __hash__ = None # type: ignore + + +# XX FIXME: "A recipient MUST ignore (or consider as an error) any fields that +# are forbidden to be sent in a trailer, since processing them as if they were +# present in the header section might bypass external security filters." +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#chunked.trailer.part +# Unfortunately, the list of forbidden fields is long and vague :-/ +@dataclass(init=False, frozen=True) +class EndOfMessage(Event): + """The end of an HTTP message. + + Fields: + + .. attribute:: headers + + Default value: ``[]`` + + Any trailing headers attached to this message, represented as a list of + (name, value) pairs. See :ref:`the header normalization rules + <headers-format>` for details. + + Must be empty unless ``Transfer-Encoding: chunked`` is in use. + + """ + + __slots__ = ("headers",) + + headers: Headers + + def __init__( + self, + *, + headers: Union[ + Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]], None + ] = None, + _parsed: bool = False, + ) -> None: + super().__init__() + if headers is None: + headers = Headers([]) + elif not isinstance(headers, Headers): + headers = normalize_and_validate(headers, _parsed=_parsed) + + object.__setattr__(self, "headers", headers) + + # This is an unhashable type. + __hash__ = None # type: ignore + + +@dataclass(frozen=True) +class ConnectionClosed(Event): + """This event indicates that the sender has closed their outgoing + connection. + + Note that this does not necessarily mean that they can't *receive* further + data, because TCP connections are composed to two one-way channels which + can be closed independently. See :ref:`closing` for details. + + No fields. + """ + + pass diff --git a/contrib/python/h11/h11/_headers.py b/contrib/python/h11/h11/_headers.py new file mode 100644 index 0000000000..b97d020b63 --- /dev/null +++ b/contrib/python/h11/h11/_headers.py @@ -0,0 +1,278 @@ +import re +from typing import AnyStr, cast, List, overload, Sequence, Tuple, TYPE_CHECKING, Union + +from ._abnf import field_name, field_value +from ._util import bytesify, LocalProtocolError, validate + +if TYPE_CHECKING: + from ._events import Request + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + + +# Facts +# ----- +# +# Headers are: +# keys: case-insensitive ascii +# values: mixture of ascii and raw bytes +# +# "Historically, HTTP has allowed field content with text in the ISO-8859-1 +# charset [ISO-8859-1], supporting other charsets only through use of +# [RFC2047] encoding. In practice, most HTTP header field values use only a +# subset of the US-ASCII charset [USASCII]. Newly defined header fields SHOULD +# limit their field values to US-ASCII octets. A recipient SHOULD treat other +# octets in field content (obs-text) as opaque data." +# And it deprecates all non-ascii values +# +# Leading/trailing whitespace in header names is forbidden +# +# Values get leading/trailing whitespace stripped +# +# Content-Disposition actually needs to contain unicode semantically; to +# accomplish this it has a terrifically weird way of encoding the filename +# itself as ascii (and even this still has lots of cross-browser +# incompatibilities) +# +# Order is important: +# "a proxy MUST NOT change the order of these field values when forwarding a +# message" +# (and there are several headers where the order indicates a preference) +# +# Multiple occurences of the same header: +# "A sender MUST NOT generate multiple header fields with the same field name +# in a message unless either the entire field value for that header field is +# defined as a comma-separated list [or the header is Set-Cookie which gets a +# special exception]" - RFC 7230. (cookies are in RFC 6265) +# +# So every header aside from Set-Cookie can be merged by b", ".join if it +# occurs repeatedly. But, of course, they can't necessarily be split by +# .split(b","), because quoting. +# +# Given all this mess (case insensitive, duplicates allowed, order is +# important, ...), there doesn't appear to be any standard way to handle +# headers in Python -- they're almost like dicts, but... actually just +# aren't. For now we punt and just use a super simple representation: headers +# are a list of pairs +# +# [(name1, value1), (name2, value2), ...] +# +# where all entries are bytestrings, names are lowercase and have no +# leading/trailing whitespace, and values are bytestrings with no +# leading/trailing whitespace. Searching and updating are done via naive O(n) +# methods. +# +# Maybe a dict-of-lists would be better? + +_content_length_re = re.compile(rb"[0-9]+") +_field_name_re = re.compile(field_name.encode("ascii")) +_field_value_re = re.compile(field_value.encode("ascii")) + + +class Headers(Sequence[Tuple[bytes, bytes]]): + """ + A list-like interface that allows iterating over headers as byte-pairs + of (lowercased-name, value). + + Internally we actually store the representation as three-tuples, + including both the raw original casing, in order to preserve casing + over-the-wire, and the lowercased name, for case-insensitive comparisions. + + r = Request( + method="GET", + target="/", + headers=[("Host", "example.org"), ("Connection", "keep-alive")], + http_version="1.1", + ) + assert r.headers == [ + (b"host", b"example.org"), + (b"connection", b"keep-alive") + ] + assert r.headers.raw_items() == [ + (b"Host", b"example.org"), + (b"Connection", b"keep-alive") + ] + """ + + __slots__ = "_full_items" + + def __init__(self, full_items: List[Tuple[bytes, bytes, bytes]]) -> None: + self._full_items = full_items + + def __bool__(self) -> bool: + return bool(self._full_items) + + def __eq__(self, other: object) -> bool: + return list(self) == list(other) # type: ignore + + def __len__(self) -> int: + return len(self._full_items) + + def __repr__(self) -> str: + return "<Headers(%s)>" % repr(list(self)) + + def __getitem__(self, idx: int) -> Tuple[bytes, bytes]: # type: ignore[override] + _, name, value = self._full_items[idx] + return (name, value) + + def raw_items(self) -> List[Tuple[bytes, bytes]]: + return [(raw_name, value) for raw_name, _, value in self._full_items] + + +HeaderTypes = Union[ + List[Tuple[bytes, bytes]], + List[Tuple[bytes, str]], + List[Tuple[str, bytes]], + List[Tuple[str, str]], +] + + +@overload +def normalize_and_validate(headers: Headers, _parsed: Literal[True]) -> Headers: + ... + + +@overload +def normalize_and_validate(headers: HeaderTypes, _parsed: Literal[False]) -> Headers: + ... + + +@overload +def normalize_and_validate( + headers: Union[Headers, HeaderTypes], _parsed: bool = False +) -> Headers: + ... + + +def normalize_and_validate( + headers: Union[Headers, HeaderTypes], _parsed: bool = False +) -> Headers: + new_headers = [] + seen_content_length = None + saw_transfer_encoding = False + for name, value in headers: + # For headers coming out of the parser, we can safely skip some steps, + # because it always returns bytes and has already run these regexes + # over the data: + if not _parsed: + name = bytesify(name) + value = bytesify(value) + validate(_field_name_re, name, "Illegal header name {!r}", name) + validate(_field_value_re, value, "Illegal header value {!r}", value) + assert isinstance(name, bytes) + assert isinstance(value, bytes) + + raw_name = name + name = name.lower() + if name == b"content-length": + lengths = {length.strip() for length in value.split(b",")} + if len(lengths) != 1: + raise LocalProtocolError("conflicting Content-Length headers") + value = lengths.pop() + validate(_content_length_re, value, "bad Content-Length") + if seen_content_length is None: + seen_content_length = value + new_headers.append((raw_name, name, value)) + elif seen_content_length != value: + raise LocalProtocolError("conflicting Content-Length headers") + elif name == b"transfer-encoding": + # "A server that receives a request message with a transfer coding + # it does not understand SHOULD respond with 501 (Not + # Implemented)." + # https://tools.ietf.org/html/rfc7230#section-3.3.1 + if saw_transfer_encoding: + raise LocalProtocolError( + "multiple Transfer-Encoding headers", error_status_hint=501 + ) + # "All transfer-coding names are case-insensitive" + # -- https://tools.ietf.org/html/rfc7230#section-4 + value = value.lower() + if value != b"chunked": + raise LocalProtocolError( + "Only Transfer-Encoding: chunked is supported", + error_status_hint=501, + ) + saw_transfer_encoding = True + new_headers.append((raw_name, name, value)) + else: + new_headers.append((raw_name, name, value)) + return Headers(new_headers) + + +def get_comma_header(headers: Headers, name: bytes) -> List[bytes]: + # Should only be used for headers whose value is a list of + # comma-separated, case-insensitive values. + # + # The header name `name` is expected to be lower-case bytes. + # + # Connection: meets these criteria (including cast insensitivity). + # + # Content-Length: technically is just a single value (1*DIGIT), but the + # standard makes reference to implementations that do multiple values, and + # using this doesn't hurt. Ditto, case insensitivity doesn't things either + # way. + # + # Transfer-Encoding: is more complex (allows for quoted strings), so + # splitting on , is actually wrong. For example, this is legal: + # + # Transfer-Encoding: foo; options="1,2", chunked + # + # and should be parsed as + # + # foo; options="1,2" + # chunked + # + # but this naive function will parse it as + # + # foo; options="1 + # 2" + # chunked + # + # However, this is okay because the only thing we are going to do with + # any Transfer-Encoding is reject ones that aren't just "chunked", so + # both of these will be treated the same anyway. + # + # Expect: the only legal value is the literal string + # "100-continue". Splitting on commas is harmless. Case insensitive. + # + out: List[bytes] = [] + for _, found_name, found_raw_value in headers._full_items: + if found_name == name: + found_raw_value = found_raw_value.lower() + for found_split_value in found_raw_value.split(b","): + found_split_value = found_split_value.strip() + if found_split_value: + out.append(found_split_value) + return out + + +def set_comma_header(headers: Headers, name: bytes, new_values: List[bytes]) -> Headers: + # The header name `name` is expected to be lower-case bytes. + # + # Note that when we store the header we use title casing for the header + # names, in order to match the conventional HTTP header style. + # + # Simply calling `.title()` is a blunt approach, but it's correct + # here given the cases where we're using `set_comma_header`... + # + # Connection, Content-Length, Transfer-Encoding. + new_headers: List[Tuple[bytes, bytes]] = [] + for found_raw_name, found_name, found_raw_value in headers._full_items: + if found_name != name: + new_headers.append((found_raw_name, found_raw_value)) + for new_value in new_values: + new_headers.append((name.title(), new_value)) + return normalize_and_validate(new_headers) + + +def has_expect_100_continue(request: "Request") -> bool: + # https://tools.ietf.org/html/rfc7231#section-5.1.1 + # "A server that receives a 100-continue expectation in an HTTP/1.0 request + # MUST ignore that expectation." + if request.http_version < b"1.1": + return False + expect = get_comma_header(request.headers, b"expect") + return b"100-continue" in expect diff --git a/contrib/python/h11/h11/_readers.py b/contrib/python/h11/h11/_readers.py new file mode 100644 index 0000000000..08a9574da4 --- /dev/null +++ b/contrib/python/h11/h11/_readers.py @@ -0,0 +1,247 @@ +# Code to read HTTP data +# +# Strategy: each reader is a callable which takes a ReceiveBuffer object, and +# either: +# 1) consumes some of it and returns an Event +# 2) raises a LocalProtocolError (for consistency -- e.g. we call validate() +# and it might raise a LocalProtocolError, so simpler just to always use +# this) +# 3) returns None, meaning "I need more data" +# +# If they have a .read_eof attribute, then this will be called if an EOF is +# received -- but this is optional. Either way, the actual ConnectionClosed +# event will be generated afterwards. +# +# READERS is a dict describing how to pick a reader. It maps states to either: +# - a reader +# - or, for body readers, a dict of per-framing reader factories + +import re +from typing import Any, Callable, Dict, Iterable, NoReturn, Optional, Tuple, Type, Union + +from ._abnf import chunk_header, header_field, request_line, status_line +from ._events import Data, EndOfMessage, InformationalResponse, Request, Response +from ._receivebuffer import ReceiveBuffer +from ._state import ( + CLIENT, + CLOSED, + DONE, + IDLE, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, +) +from ._util import LocalProtocolError, RemoteProtocolError, Sentinel, validate + +__all__ = ["READERS"] + +header_field_re = re.compile(header_field.encode("ascii")) +obs_fold_re = re.compile(rb"[ \t]+") + + +def _obsolete_line_fold(lines: Iterable[bytes]) -> Iterable[bytes]: + it = iter(lines) + last: Optional[bytes] = None + for line in it: + match = obs_fold_re.match(line) + if match: + if last is None: + raise LocalProtocolError("continuation line at start of headers") + if not isinstance(last, bytearray): + # Cast to a mutable type, avoiding copy on append to ensure O(n) time + last = bytearray(last) + last += b" " + last += line[match.end() :] + else: + if last is not None: + yield last + last = line + if last is not None: + yield last + + +def _decode_header_lines( + lines: Iterable[bytes], +) -> Iterable[Tuple[bytes, bytes]]: + for line in _obsolete_line_fold(lines): + matches = validate(header_field_re, line, "illegal header line: {!r}", line) + yield (matches["field_name"], matches["field_value"]) + + +request_line_re = re.compile(request_line.encode("ascii")) + + +def maybe_read_from_IDLE_client(buf: ReceiveBuffer) -> Optional[Request]: + lines = buf.maybe_extract_lines() + if lines is None: + if buf.is_next_line_obviously_invalid_request_line(): + raise LocalProtocolError("illegal request line") + return None + if not lines: + raise LocalProtocolError("no request line received") + matches = validate( + request_line_re, lines[0], "illegal request line: {!r}", lines[0] + ) + return Request( + headers=list(_decode_header_lines(lines[1:])), _parsed=True, **matches + ) + + +status_line_re = re.compile(status_line.encode("ascii")) + + +def maybe_read_from_SEND_RESPONSE_server( + buf: ReceiveBuffer, +) -> Union[InformationalResponse, Response, None]: + lines = buf.maybe_extract_lines() + if lines is None: + if buf.is_next_line_obviously_invalid_request_line(): + raise LocalProtocolError("illegal request line") + return None + if not lines: + raise LocalProtocolError("no response line received") + matches = validate(status_line_re, lines[0], "illegal status line: {!r}", lines[0]) + http_version = ( + b"1.1" if matches["http_version"] is None else matches["http_version"] + ) + reason = b"" if matches["reason"] is None else matches["reason"] + status_code = int(matches["status_code"]) + class_: Union[Type[InformationalResponse], Type[Response]] = ( + InformationalResponse if status_code < 200 else Response + ) + return class_( + headers=list(_decode_header_lines(lines[1:])), + _parsed=True, + status_code=status_code, + reason=reason, + http_version=http_version, + ) + + +class ContentLengthReader: + def __init__(self, length: int) -> None: + self._length = length + self._remaining = length + + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: + if self._remaining == 0: + return EndOfMessage() + data = buf.maybe_extract_at_most(self._remaining) + if data is None: + return None + self._remaining -= len(data) + return Data(data=data) + + def read_eof(self) -> NoReturn: + raise RemoteProtocolError( + "peer closed connection without sending complete message body " + "(received {} bytes, expected {})".format( + self._length - self._remaining, self._length + ) + ) + + +chunk_header_re = re.compile(chunk_header.encode("ascii")) + + +class ChunkedReader: + def __init__(self) -> None: + self._bytes_in_chunk = 0 + # After reading a chunk, we have to throw away the trailing \r\n; if + # this is >0 then we discard that many bytes before resuming regular + # de-chunkification. + self._bytes_to_discard = 0 + self._reading_trailer = False + + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: + if self._reading_trailer: + lines = buf.maybe_extract_lines() + if lines is None: + return None + return EndOfMessage(headers=list(_decode_header_lines(lines))) + if self._bytes_to_discard > 0: + data = buf.maybe_extract_at_most(self._bytes_to_discard) + if data is None: + return None + self._bytes_to_discard -= len(data) + if self._bytes_to_discard > 0: + return None + # else, fall through and read some more + assert self._bytes_to_discard == 0 + if self._bytes_in_chunk == 0: + # We need to refill our chunk count + chunk_header = buf.maybe_extract_next_line() + if chunk_header is None: + return None + matches = validate( + chunk_header_re, + chunk_header, + "illegal chunk header: {!r}", + chunk_header, + ) + # XX FIXME: we discard chunk extensions. Does anyone care? + self._bytes_in_chunk = int(matches["chunk_size"], base=16) + if self._bytes_in_chunk == 0: + self._reading_trailer = True + return self(buf) + chunk_start = True + else: + chunk_start = False + assert self._bytes_in_chunk > 0 + data = buf.maybe_extract_at_most(self._bytes_in_chunk) + if data is None: + return None + self._bytes_in_chunk -= len(data) + if self._bytes_in_chunk == 0: + self._bytes_to_discard = 2 + chunk_end = True + else: + chunk_end = False + return Data(data=data, chunk_start=chunk_start, chunk_end=chunk_end) + + def read_eof(self) -> NoReturn: + raise RemoteProtocolError( + "peer closed connection without sending complete message body " + "(incomplete chunked read)" + ) + + +class Http10Reader: + def __call__(self, buf: ReceiveBuffer) -> Optional[Data]: + data = buf.maybe_extract_at_most(999999999) + if data is None: + return None + return Data(data=data) + + def read_eof(self) -> EndOfMessage: + return EndOfMessage() + + +def expect_nothing(buf: ReceiveBuffer) -> None: + if buf: + raise LocalProtocolError("Got data when expecting EOF") + return None + + +ReadersType = Dict[ + Union[Type[Sentinel], Tuple[Type[Sentinel], Type[Sentinel]]], + Union[Callable[..., Any], Dict[str, Callable[..., Any]]], +] + +READERS: ReadersType = { + (CLIENT, IDLE): maybe_read_from_IDLE_client, + (SERVER, IDLE): maybe_read_from_SEND_RESPONSE_server, + (SERVER, SEND_RESPONSE): maybe_read_from_SEND_RESPONSE_server, + (CLIENT, DONE): expect_nothing, + (CLIENT, MUST_CLOSE): expect_nothing, + (CLIENT, CLOSED): expect_nothing, + (SERVER, DONE): expect_nothing, + (SERVER, MUST_CLOSE): expect_nothing, + (SERVER, CLOSED): expect_nothing, + SEND_BODY: { + "chunked": ChunkedReader, + "content-length": ContentLengthReader, + "http/1.0": Http10Reader, + }, +} diff --git a/contrib/python/h11/h11/_receivebuffer.py b/contrib/python/h11/h11/_receivebuffer.py new file mode 100644 index 0000000000..e5c4e08a56 --- /dev/null +++ b/contrib/python/h11/h11/_receivebuffer.py @@ -0,0 +1,153 @@ +import re +import sys +from typing import List, Optional, Union + +__all__ = ["ReceiveBuffer"] + + +# Operations we want to support: +# - find next \r\n or \r\n\r\n (\n or \n\n are also acceptable), +# or wait until there is one +# - read at-most-N bytes +# Goals: +# - on average, do this fast +# - worst case, do this in O(n) where n is the number of bytes processed +# Plan: +# - store bytearray, offset, how far we've searched for a separator token +# - use the how-far-we've-searched data to avoid rescanning +# - while doing a stream of uninterrupted processing, advance offset instead +# of constantly copying +# WARNING: +# - I haven't benchmarked or profiled any of this yet. +# +# Note that starting in Python 3.4, deleting the initial n bytes from a +# bytearray is amortized O(n), thanks to some excellent work by Antoine +# Martin: +# +# https://bugs.python.org/issue19087 +# +# This means that if we only supported 3.4+, we could get rid of the code here +# involving self._start and self.compress, because it's doing exactly the same +# thing that bytearray now does internally. +# +# BUT unfortunately, we still support 2.7, and reading short segments out of a +# long buffer MUST be O(bytes read) to avoid DoS issues, so we can't actually +# delete this code. Yet: +# +# https://pythonclock.org/ +# +# (Two things to double-check first though: make sure PyPy also has the +# optimization, and benchmark to make sure it's a win, since we do have a +# slightly clever thing where we delay calling compress() until we've +# processed a whole event, which could in theory be slightly more efficient +# than the internal bytearray support.) +blank_line_regex = re.compile(b"\n\r?\n", re.MULTILINE) + + +class ReceiveBuffer: + def __init__(self) -> None: + self._data = bytearray() + self._next_line_search = 0 + self._multiple_lines_search = 0 + + def __iadd__(self, byteslike: Union[bytes, bytearray]) -> "ReceiveBuffer": + self._data += byteslike + return self + + def __bool__(self) -> bool: + return bool(len(self)) + + def __len__(self) -> int: + return len(self._data) + + # for @property unprocessed_data + def __bytes__(self) -> bytes: + return bytes(self._data) + + def _extract(self, count: int) -> bytearray: + # extracting an initial slice of the data buffer and return it + out = self._data[:count] + del self._data[:count] + + self._next_line_search = 0 + self._multiple_lines_search = 0 + + return out + + def maybe_extract_at_most(self, count: int) -> Optional[bytearray]: + """ + Extract a fixed number of bytes from the buffer. + """ + out = self._data[:count] + if not out: + return None + + return self._extract(count) + + def maybe_extract_next_line(self) -> Optional[bytearray]: + """ + Extract the first line, if it is completed in the buffer. + """ + # Only search in buffer space that we've not already looked at. + search_start_index = max(0, self._next_line_search - 1) + partial_idx = self._data.find(b"\r\n", search_start_index) + + if partial_idx == -1: + self._next_line_search = len(self._data) + return None + + # + 2 is to compensate len(b"\r\n") + idx = partial_idx + 2 + + return self._extract(idx) + + def maybe_extract_lines(self) -> Optional[List[bytearray]]: + """ + Extract everything up to the first blank line, and return a list of lines. + """ + # Handle the case where we have an immediate empty line. + if self._data[:1] == b"\n": + self._extract(1) + return [] + + if self._data[:2] == b"\r\n": + self._extract(2) + return [] + + # Only search in buffer space that we've not already looked at. + match = blank_line_regex.search(self._data, self._multiple_lines_search) + if match is None: + self._multiple_lines_search = max(0, len(self._data) - 2) + return None + + # Truncate the buffer and return it. + idx = match.span(0)[-1] + out = self._extract(idx) + lines = out.split(b"\n") + + for line in lines: + if line.endswith(b"\r"): + del line[-1] + + assert lines[-2] == lines[-1] == b"" + + del lines[-2:] + + return lines + + # In theory we should wait until `\r\n` before starting to validate + # incoming data. However it's interesting to detect (very) invalid data + # early given they might not even contain `\r\n` at all (hence only + # timeout will get rid of them). + # This is not a 100% effective detection but more of a cheap sanity check + # allowing for early abort in some useful cases. + # This is especially interesting when peer is messing up with HTTPS and + # sent us a TLS stream where we were expecting plain HTTP given all + # versions of TLS so far start handshake with a 0x16 message type code. + def is_next_line_obviously_invalid_request_line(self) -> bool: + try: + # HTTP header line must not contain non-printable characters + # and should not start with a space + return self._data[0] < 0x21 + except IndexError: + return False diff --git a/contrib/python/h11/h11/_state.py b/contrib/python/h11/h11/_state.py new file mode 100644 index 0000000000..3593430a74 --- /dev/null +++ b/contrib/python/h11/h11/_state.py @@ -0,0 +1,367 @@ +################################################################ +# The core state machine +################################################################ +# +# Rule 1: everything that affects the state machine and state transitions must +# live here in this file. As much as possible goes into the table-based +# representation, but for the bits that don't quite fit, the actual code and +# state must nonetheless live here. +# +# Rule 2: this file does not know about what role we're playing; it only knows +# about HTTP request/response cycles in the abstract. This ensures that we +# don't cheat and apply different rules to local and remote parties. +# +# +# Theory of operation +# =================== +# +# Possibly the simplest way to think about this is that we actually have 5 +# different state machines here. Yes, 5. These are: +# +# 1) The client state, with its complicated automaton (see the docs) +# 2) The server state, with its complicated automaton (see the docs) +# 3) The keep-alive state, with possible states {True, False} +# 4) The SWITCH_CONNECT state, with possible states {False, True} +# 5) The SWITCH_UPGRADE state, with possible states {False, True} +# +# For (3)-(5), the first state listed is the initial state. +# +# (1)-(3) are stored explicitly in member variables. The last +# two are stored implicitly in the pending_switch_proposals set as: +# (state of 4) == (_SWITCH_CONNECT in pending_switch_proposals) +# (state of 5) == (_SWITCH_UPGRADE in pending_switch_proposals) +# +# And each of these machines has two different kinds of transitions: +# +# a) Event-triggered +# b) State-triggered +# +# Event triggered is the obvious thing that you'd think it is: some event +# happens, and if it's the right event at the right time then a transition +# happens. But there are somewhat complicated rules for which machines can +# "see" which events. (As a rule of thumb, if a machine "sees" an event, this +# means two things: the event can affect the machine, and if the machine is +# not in a state where it expects that event then it's an error.) These rules +# are: +# +# 1) The client machine sees all h11.events objects emitted by the client. +# +# 2) The server machine sees all h11.events objects emitted by the server. +# +# It also sees the client's Request event. +# +# And sometimes, server events are annotated with a _SWITCH_* event. For +# example, we can have a (Response, _SWITCH_CONNECT) event, which is +# different from a regular Response event. +# +# 3) The keep-alive machine sees the process_keep_alive_disabled() event +# (which is derived from Request/Response events), and this event +# transitions it from True -> False, or from False -> False. There's no way +# to transition back. +# +# 4&5) The _SWITCH_* machines transition from False->True when we get a +# Request that proposes the relevant type of switch (via +# process_client_switch_proposals), and they go from True->False when we +# get a Response that has no _SWITCH_* annotation. +# +# So that's event-triggered transitions. +# +# State-triggered transitions are less standard. What they do here is couple +# the machines together. The way this works is, when certain *joint* +# configurations of states are achieved, then we automatically transition to a +# new *joint* state. So, for example, if we're ever in a joint state with +# +# client: DONE +# keep-alive: False +# +# then the client state immediately transitions to: +# +# client: MUST_CLOSE +# +# This is fundamentally different from an event-based transition, because it +# doesn't matter how we arrived at the {client: DONE, keep-alive: False} state +# -- maybe the client transitioned SEND_BODY -> DONE, or keep-alive +# transitioned True -> False. Either way, once this precondition is satisfied, +# this transition is immediately triggered. +# +# What if two conflicting state-based transitions get enabled at the same +# time? In practice there's only one case where this arises (client DONE -> +# MIGHT_SWITCH_PROTOCOL versus DONE -> MUST_CLOSE), and we resolve it by +# explicitly prioritizing the DONE -> MIGHT_SWITCH_PROTOCOL transition. +# +# Implementation +# -------------- +# +# The event-triggered transitions for the server and client machines are all +# stored explicitly in a table. Ditto for the state-triggered transitions that +# involve just the server and client state. +# +# The transitions for the other machines, and the state-triggered transitions +# that involve the other machines, are written out as explicit Python code. +# +# It'd be nice if there were some cleaner way to do all this. This isn't +# *too* terrible, but I feel like it could probably be better. +# +# WARNING +# ------- +# +# The script that generates the state machine diagrams for the docs knows how +# to read out the EVENT_TRIGGERED_TRANSITIONS and STATE_TRIGGERED_TRANSITIONS +# tables. But it can't automatically read the transitions that are written +# directly in Python code. So if you touch those, you need to also update the +# script to keep it in sync! +from typing import cast, Dict, Optional, Set, Tuple, Type, Union + +from ._events import * +from ._util import LocalProtocolError, Sentinel + +# Everything in __all__ gets re-exported as part of the h11 public API. +__all__ = [ + "CLIENT", + "SERVER", + "IDLE", + "SEND_RESPONSE", + "SEND_BODY", + "DONE", + "MUST_CLOSE", + "CLOSED", + "MIGHT_SWITCH_PROTOCOL", + "SWITCHED_PROTOCOL", + "ERROR", +] + + +class CLIENT(Sentinel, metaclass=Sentinel): + pass + + +class SERVER(Sentinel, metaclass=Sentinel): + pass + + +# States +class IDLE(Sentinel, metaclass=Sentinel): + pass + + +class SEND_RESPONSE(Sentinel, metaclass=Sentinel): + pass + + +class SEND_BODY(Sentinel, metaclass=Sentinel): + pass + + +class DONE(Sentinel, metaclass=Sentinel): + pass + + +class MUST_CLOSE(Sentinel, metaclass=Sentinel): + pass + + +class CLOSED(Sentinel, metaclass=Sentinel): + pass + + +class ERROR(Sentinel, metaclass=Sentinel): + pass + + +# Switch types +class MIGHT_SWITCH_PROTOCOL(Sentinel, metaclass=Sentinel): + pass + + +class SWITCHED_PROTOCOL(Sentinel, metaclass=Sentinel): + pass + + +class _SWITCH_UPGRADE(Sentinel, metaclass=Sentinel): + pass + + +class _SWITCH_CONNECT(Sentinel, metaclass=Sentinel): + pass + + +EventTransitionType = Dict[ + Type[Sentinel], + Dict[ + Type[Sentinel], + Dict[Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]], Type[Sentinel]], + ], +] + +EVENT_TRIGGERED_TRANSITIONS: EventTransitionType = { + CLIENT: { + IDLE: {Request: SEND_BODY, ConnectionClosed: CLOSED}, + SEND_BODY: {Data: SEND_BODY, EndOfMessage: DONE}, + DONE: {ConnectionClosed: CLOSED}, + MUST_CLOSE: {ConnectionClosed: CLOSED}, + CLOSED: {ConnectionClosed: CLOSED}, + MIGHT_SWITCH_PROTOCOL: {}, + SWITCHED_PROTOCOL: {}, + ERROR: {}, + }, + SERVER: { + IDLE: { + ConnectionClosed: CLOSED, + Response: SEND_BODY, + # Special case: server sees client Request events, in this form + (Request, CLIENT): SEND_RESPONSE, + }, + SEND_RESPONSE: { + InformationalResponse: SEND_RESPONSE, + Response: SEND_BODY, + (InformationalResponse, _SWITCH_UPGRADE): SWITCHED_PROTOCOL, + (Response, _SWITCH_CONNECT): SWITCHED_PROTOCOL, + }, + SEND_BODY: {Data: SEND_BODY, EndOfMessage: DONE}, + DONE: {ConnectionClosed: CLOSED}, + MUST_CLOSE: {ConnectionClosed: CLOSED}, + CLOSED: {ConnectionClosed: CLOSED}, + SWITCHED_PROTOCOL: {}, + ERROR: {}, + }, +} + +StateTransitionType = Dict[ + Tuple[Type[Sentinel], Type[Sentinel]], Dict[Type[Sentinel], Type[Sentinel]] +] + +# NB: there are also some special-case state-triggered transitions hard-coded +# into _fire_state_triggered_transitions below. +STATE_TRIGGERED_TRANSITIONS: StateTransitionType = { + # (Client state, Server state) -> new states + # Protocol negotiation + (MIGHT_SWITCH_PROTOCOL, SWITCHED_PROTOCOL): {CLIENT: SWITCHED_PROTOCOL}, + # Socket shutdown + (CLOSED, DONE): {SERVER: MUST_CLOSE}, + (CLOSED, IDLE): {SERVER: MUST_CLOSE}, + (ERROR, DONE): {SERVER: MUST_CLOSE}, + (DONE, CLOSED): {CLIENT: MUST_CLOSE}, + (IDLE, CLOSED): {CLIENT: MUST_CLOSE}, + (DONE, ERROR): {CLIENT: MUST_CLOSE}, +} + + +class ConnectionState: + def __init__(self) -> None: + # Extra bits of state that don't quite fit into the state model. + + # If this is False then it enables the automatic DONE -> MUST_CLOSE + # transition. Don't set this directly; call .keep_alive_disabled() + self.keep_alive = True + + # This is a subset of {UPGRADE, CONNECT}, containing the proposals + # made by the client for switching protocols. + self.pending_switch_proposals: Set[Type[Sentinel]] = set() + + self.states: Dict[Type[Sentinel], Type[Sentinel]] = {CLIENT: IDLE, SERVER: IDLE} + + def process_error(self, role: Type[Sentinel]) -> None: + self.states[role] = ERROR + self._fire_state_triggered_transitions() + + def process_keep_alive_disabled(self) -> None: + self.keep_alive = False + self._fire_state_triggered_transitions() + + def process_client_switch_proposal(self, switch_event: Type[Sentinel]) -> None: + self.pending_switch_proposals.add(switch_event) + self._fire_state_triggered_transitions() + + def process_event( + self, + role: Type[Sentinel], + event_type: Type[Event], + server_switch_event: Optional[Type[Sentinel]] = None, + ) -> None: + _event_type: Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]] = event_type + if server_switch_event is not None: + assert role is SERVER + if server_switch_event not in self.pending_switch_proposals: + raise LocalProtocolError( + "Received server {} event without a pending proposal".format( + server_switch_event + ) + ) + _event_type = (event_type, server_switch_event) + if server_switch_event is None and _event_type is Response: + self.pending_switch_proposals = set() + self._fire_event_triggered_transitions(role, _event_type) + # Special case: the server state does get to see Request + # events. + if _event_type is Request: + assert role is CLIENT + self._fire_event_triggered_transitions(SERVER, (Request, CLIENT)) + self._fire_state_triggered_transitions() + + def _fire_event_triggered_transitions( + self, + role: Type[Sentinel], + event_type: Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]], + ) -> None: + state = self.states[role] + try: + new_state = EVENT_TRIGGERED_TRANSITIONS[role][state][event_type] + except KeyError: + event_type = cast(Type[Event], event_type) + raise LocalProtocolError( + "can't handle event type {} when role={} and state={}".format( + event_type.__name__, role, self.states[role] + ) + ) from None + self.states[role] = new_state + + def _fire_state_triggered_transitions(self) -> None: + # We apply these rules repeatedly until converging on a fixed point + while True: + start_states = dict(self.states) + + # It could happen that both these special-case transitions are + # enabled at the same time: + # + # DONE -> MIGHT_SWITCH_PROTOCOL + # DONE -> MUST_CLOSE + # + # For example, this will always be true of a HTTP/1.0 client + # requesting CONNECT. If this happens, the protocol switch takes + # priority. From there the client will either go to + # SWITCHED_PROTOCOL, in which case it's none of our business when + # they close the connection, or else the server will deny the + # request, in which case the client will go back to DONE and then + # from there to MUST_CLOSE. + if self.pending_switch_proposals: + if self.states[CLIENT] is DONE: + self.states[CLIENT] = MIGHT_SWITCH_PROTOCOL + + if not self.pending_switch_proposals: + if self.states[CLIENT] is MIGHT_SWITCH_PROTOCOL: + self.states[CLIENT] = DONE + + if not self.keep_alive: + for role in (CLIENT, SERVER): + if self.states[role] is DONE: + self.states[role] = MUST_CLOSE + + # Tabular state-triggered transitions + joint_state = (self.states[CLIENT], self.states[SERVER]) + changes = STATE_TRIGGERED_TRANSITIONS.get(joint_state, {}) + self.states.update(changes) + + if self.states == start_states: + # Fixed point reached + return + + def start_next_cycle(self) -> None: + if self.states != {CLIENT: DONE, SERVER: DONE}: + raise LocalProtocolError( + "not in a reusable state. self.states={}".format(self.states) + ) + # Can't reach DONE/DONE with any of these active, but still, let's be + # sure. + assert self.keep_alive + assert not self.pending_switch_proposals + self.states = {CLIENT: IDLE, SERVER: IDLE} diff --git a/contrib/python/h11/h11/_util.py b/contrib/python/h11/h11/_util.py new file mode 100644 index 0000000000..6718445290 --- /dev/null +++ b/contrib/python/h11/h11/_util.py @@ -0,0 +1,135 @@ +from typing import Any, Dict, NoReturn, Pattern, Tuple, Type, TypeVar, Union + +__all__ = [ + "ProtocolError", + "LocalProtocolError", + "RemoteProtocolError", + "validate", + "bytesify", +] + + +class ProtocolError(Exception): + """Exception indicating a violation of the HTTP/1.1 protocol. + + This as an abstract base class, with two concrete base classes: + :exc:`LocalProtocolError`, which indicates that you tried to do something + that HTTP/1.1 says is illegal, and :exc:`RemoteProtocolError`, which + indicates that the remote peer tried to do something that HTTP/1.1 says is + illegal. See :ref:`error-handling` for details. + + In addition to the normal :exc:`Exception` features, it has one attribute: + + .. attribute:: error_status_hint + + This gives a suggestion as to what status code a server might use if + this error occurred as part of a request. + + For a :exc:`RemoteProtocolError`, this is useful as a suggestion for + how you might want to respond to a misbehaving peer, if you're + implementing a server. + + For a :exc:`LocalProtocolError`, this can be taken as a suggestion for + how your peer might have responded to *you* if h11 had allowed you to + continue. + + The default is 400 Bad Request, a generic catch-all for protocol + violations. + + """ + + def __init__(self, msg: str, error_status_hint: int = 400) -> None: + if type(self) is ProtocolError: + raise TypeError("tried to directly instantiate ProtocolError") + Exception.__init__(self, msg) + self.error_status_hint = error_status_hint + + +# Strategy: there are a number of public APIs where a LocalProtocolError can +# be raised (send(), all the different event constructors, ...), and only one +# public API where RemoteProtocolError can be raised +# (receive_data()). Therefore we always raise LocalProtocolError internally, +# and then receive_data will translate this into a RemoteProtocolError. +# +# Internally: +# LocalProtocolError is the generic "ProtocolError". +# Externally: +# LocalProtocolError is for local errors and RemoteProtocolError is for +# remote errors. +class LocalProtocolError(ProtocolError): + def _reraise_as_remote_protocol_error(self) -> NoReturn: + # After catching a LocalProtocolError, use this method to re-raise it + # as a RemoteProtocolError. This method must be called from inside an + # except: block. + # + # An easy way to get an equivalent RemoteProtocolError is just to + # modify 'self' in place. + self.__class__ = RemoteProtocolError # type: ignore + # But the re-raising is somewhat non-trivial -- you might think that + # now that we've modified the in-flight exception object, that just + # doing 'raise' to re-raise it would be enough. But it turns out that + # this doesn't work, because Python tracks the exception type + # (exc_info[0]) separately from the exception object (exc_info[1]), + # and we only modified the latter. So we really do need to re-raise + # the new type explicitly. + # On py3, the traceback is part of the exception object, so our + # in-place modification preserved it and we can just re-raise: + raise self + + +class RemoteProtocolError(ProtocolError): + pass + + +def validate( + regex: Pattern[bytes], data: bytes, msg: str = "malformed data", *format_args: Any +) -> Dict[str, bytes]: + match = regex.fullmatch(data) + if not match: + if format_args: + msg = msg.format(*format_args) + raise LocalProtocolError(msg) + return match.groupdict() + + +# Sentinel values +# +# - Inherit identity-based comparison and hashing from object +# - Have a nice repr +# - Have a *bonus property*: type(sentinel) is sentinel +# +# The bonus property is useful if you want to take the return value from +# next_event() and do some sort of dispatch based on type(event). + +_T_Sentinel = TypeVar("_T_Sentinel", bound="Sentinel") + + +class Sentinel(type): + def __new__( + cls: Type[_T_Sentinel], + name: str, + bases: Tuple[type, ...], + namespace: Dict[str, Any], + **kwds: Any + ) -> _T_Sentinel: + assert bases == (Sentinel,) + v = super().__new__(cls, name, bases, namespace, **kwds) + v.__class__ = v # type: ignore + return v + + def __repr__(self) -> str: + return self.__name__ + + +# Used for methods, request targets, HTTP versions, header names, and header +# values. Accepts ascii-strings, or bytes/bytearray/memoryview/..., and always +# returns bytes. +def bytesify(s: Union[bytes, bytearray, memoryview, int, str]) -> bytes: + # Fast-path: + if type(s) is bytes: + return s + if isinstance(s, str): + s = s.encode("ascii") + if isinstance(s, int): + raise TypeError("expected bytes-like object, not int") + return bytes(s) diff --git a/contrib/python/h11/h11/_version.py b/contrib/python/h11/h11/_version.py new file mode 100644 index 0000000000..4c89113056 --- /dev/null +++ b/contrib/python/h11/h11/_version.py @@ -0,0 +1,16 @@ +# This file must be kept very simple, because it is consumed from several +# places -- it is imported by h11/__init__.py, execfile'd by setup.py, etc. + +# We use a simple scheme: +# 1.0.0 -> 1.0.0+dev -> 1.1.0 -> 1.1.0+dev +# where the +dev versions are never released into the wild, they're just what +# we stick into the VCS in between releases. +# +# This is compatible with PEP 440: +# http://legacy.python.org/dev/peps/pep-0440/ +# via the use of the "local suffix" "+dev", which is disallowed on index +# servers and causes 1.0.0+dev to sort after plain 1.0.0, which is what we +# want. (Contrast with the special suffix 1.0.0.dev, which sorts *before* +# 1.0.0.) + +__version__ = "0.14.0" diff --git a/contrib/python/h11/h11/_writers.py b/contrib/python/h11/h11/_writers.py new file mode 100644 index 0000000000..939cdb912a --- /dev/null +++ b/contrib/python/h11/h11/_writers.py @@ -0,0 +1,145 @@ +# Code to read HTTP data +# +# Strategy: each writer takes an event + a write-some-bytes function, which is +# calls. +# +# WRITERS is a dict describing how to pick a reader. It maps states to either: +# - a writer +# - or, for body writers, a dict of framin-dependent writer factories + +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +from ._events import Data, EndOfMessage, Event, InformationalResponse, Request, Response +from ._headers import Headers +from ._state import CLIENT, IDLE, SEND_BODY, SEND_RESPONSE, SERVER +from ._util import LocalProtocolError, Sentinel + +__all__ = ["WRITERS"] + +Writer = Callable[[bytes], Any] + + +def write_headers(headers: Headers, write: Writer) -> None: + # "Since the Host field-value is critical information for handling a + # request, a user agent SHOULD generate Host as the first header field + # following the request-line." - RFC 7230 + raw_items = headers._full_items + for raw_name, name, value in raw_items: + if name == b"host": + write(b"%s: %s\r\n" % (raw_name, value)) + for raw_name, name, value in raw_items: + if name != b"host": + write(b"%s: %s\r\n" % (raw_name, value)) + write(b"\r\n") + + +def write_request(request: Request, write: Writer) -> None: + if request.http_version != b"1.1": + raise LocalProtocolError("I only send HTTP/1.1") + write(b"%s %s HTTP/1.1\r\n" % (request.method, request.target)) + write_headers(request.headers, write) + + +# Shared between InformationalResponse and Response +def write_any_response( + response: Union[InformationalResponse, Response], write: Writer +) -> None: + if response.http_version != b"1.1": + raise LocalProtocolError("I only send HTTP/1.1") + status_bytes = str(response.status_code).encode("ascii") + # We don't bother sending ascii status messages like "OK"; they're + # optional and ignored by the protocol. (But the space after the numeric + # status code is mandatory.) + # + # XX FIXME: could at least make an effort to pull out the status message + # from stdlib's http.HTTPStatus table. Or maybe just steal their enums + # (either by import or copy/paste). We already accept them as status codes + # since they're of type IntEnum < int. + write(b"HTTP/1.1 %s %s\r\n" % (status_bytes, response.reason)) + write_headers(response.headers, write) + + +class BodyWriter: + def __call__(self, event: Event, write: Writer) -> None: + if type(event) is Data: + self.send_data(event.data, write) + elif type(event) is EndOfMessage: + self.send_eom(event.headers, write) + else: # pragma: no cover + assert False + + def send_data(self, data: bytes, write: Writer) -> None: + pass + + def send_eom(self, headers: Headers, write: Writer) -> None: + pass + + +# +# These are all careful not to do anything to 'data' except call len(data) and +# write(data). This allows us to transparently pass-through funny objects, +# like placeholder objects referring to files on disk that will be sent via +# sendfile(2). +# +class ContentLengthWriter(BodyWriter): + def __init__(self, length: int) -> None: + self._length = length + + def send_data(self, data: bytes, write: Writer) -> None: + self._length -= len(data) + if self._length < 0: + raise LocalProtocolError("Too much data for declared Content-Length") + write(data) + + def send_eom(self, headers: Headers, write: Writer) -> None: + if self._length != 0: + raise LocalProtocolError("Too little data for declared Content-Length") + if headers: + raise LocalProtocolError("Content-Length and trailers don't mix") + + +class ChunkedWriter(BodyWriter): + def send_data(self, data: bytes, write: Writer) -> None: + # if we encoded 0-length data in the naive way, it would look like an + # end-of-message. + if not data: + return + write(b"%x\r\n" % len(data)) + write(data) + write(b"\r\n") + + def send_eom(self, headers: Headers, write: Writer) -> None: + write(b"0\r\n") + write_headers(headers, write) + + +class Http10Writer(BodyWriter): + def send_data(self, data: bytes, write: Writer) -> None: + write(data) + + def send_eom(self, headers: Headers, write: Writer) -> None: + if headers: + raise LocalProtocolError("can't send trailers to HTTP/1.0 client") + # no need to close the socket ourselves, that will be taken care of by + # Connection: close machinery + + +WritersType = Dict[ + Union[Tuple[Type[Sentinel], Type[Sentinel]], Type[Sentinel]], + Union[ + Dict[str, Type[BodyWriter]], + Callable[[Union[InformationalResponse, Response], Writer], None], + Callable[[Request, Writer], None], + ], +] + +WRITERS: WritersType = { + (CLIENT, IDLE): write_request, + (SERVER, IDLE): write_any_response, + (SERVER, SEND_RESPONSE): write_any_response, + SEND_BODY: { + "chunked": ChunkedWriter, + "content-length": ContentLengthWriter, + "http/1.0": Http10Writer, + }, +} diff --git a/contrib/python/h11/h11/py.typed b/contrib/python/h11/h11/py.typed new file mode 100644 index 0000000000..f5642f79f2 --- /dev/null +++ b/contrib/python/h11/h11/py.typed @@ -0,0 +1 @@ +Marker diff --git a/contrib/python/h11/ya.make b/contrib/python/h11/ya.make new file mode 100644 index 0000000000..48fcc1a654 --- /dev/null +++ b/contrib/python/h11/ya.make @@ -0,0 +1,33 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +VERSION(0.14.0) + +LICENSE(MIT) + +NO_LINT() + +PY_SRCS( + TOP_LEVEL + h11/__init__.py + h11/_abnf.py + h11/_connection.py + h11/_events.py + h11/_headers.py + h11/_readers.py + h11/_receivebuffer.py + h11/_state.py + h11/_util.py + h11/_version.py + h11/_writers.py +) + +RESOURCE_FILES( + PREFIX contrib/python/h11/ + .dist-info/METADATA + .dist-info/top_level.txt + h11/py.typed +) + +END() diff --git a/contrib/python/httpcore/.dist-info/METADATA b/contrib/python/httpcore/.dist-info/METADATA new file mode 100644 index 0000000000..3776738caf --- /dev/null +++ b/contrib/python/httpcore/.dist-info/METADATA @@ -0,0 +1,547 @@ +Metadata-Version: 2.1 +Name: httpcore +Version: 0.18.0 +Summary: A minimal low-level HTTP client. +Project-URL: Documentation, https://www.encode.io/httpcore +Project-URL: Homepage, https://www.encode.io/httpcore/ +Project-URL: Source, https://github.com/encode/httpcore +Author-email: Tom Christie <tom@tomchristie.com> +License-Expression: BSD-3-Clause +License-File: LICENSE.md +Classifier: Development Status :: 3 - Alpha +Classifier: Environment :: Web Environment +Classifier: Framework :: AsyncIO +Classifier: Framework :: Trio +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Topic :: Internet :: WWW/HTTP +Requires-Python: >=3.8 +Requires-Dist: anyio<5.0,>=3.0 +Requires-Dist: certifi +Requires-Dist: h11<0.15,>=0.13 +Requires-Dist: sniffio==1.* +Provides-Extra: http2 +Requires-Dist: h2<5,>=3; extra == 'http2' +Provides-Extra: socks +Requires-Dist: socksio==1.*; extra == 'socks' +Description-Content-Type: text/markdown + +# HTTP Core + +[![Test Suite](https://github.com/encode/httpcore/workflows/Test%20Suite/badge.svg)](https://github.com/encode/httpcore/actions) +[![Package version](https://badge.fury.io/py/httpcore.svg)](https://pypi.org/project/httpcore/) + +> *Do one thing, and do it well.* + +The HTTP Core package provides a minimal low-level HTTP client, which does +one thing only. Sending HTTP requests. + +It does not provide any high level model abstractions over the API, +does not handle redirects, multipart uploads, building authentication headers, +transparent HTTP caching, URL parsing, session cookie handling, +content or charset decoding, handling JSON, environment based configuration +defaults, or any of that Jazz. + +Some things HTTP Core does do: + +* Sending HTTP requests. +* Thread-safe / task-safe connection pooling. +* HTTP(S) proxy & SOCKS proxy support. +* Supports HTTP/1.1 and HTTP/2. +* Provides both sync and async interfaces. +* Async backend support for `asyncio` and `trio`. + +## Requirements + +Python 3.8+ + +## Installation + +For HTTP/1.1 only support, install with: + +```shell +$ pip install httpcore +``` + +For HTTP/1.1 and HTTP/2 support, install with: + +```shell +$ pip install httpcore[http2] +``` + +For SOCKS proxy support, install with: + +```shell +$ pip install httpcore[socks] +``` + +# Sending requests + +Send an HTTP request: + +```python +import httpcore + +response = httpcore.request("GET", "https://www.example.com/") + +print(response) +# <Response [200]> +print(response.status) +# 200 +print(response.headers) +# [(b'Accept-Ranges', b'bytes'), (b'Age', b'557328'), (b'Cache-Control', b'max-age=604800'), ...] +print(response.content) +# b'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>\n\n<meta charset="utf-8"/>\n ...' +``` + +The top-level `httpcore.request()` function is provided for convenience. In practice whenever you're working with `httpcore` you'll want to use the connection pooling functionality that it provides. + +```python +import httpcore + +http = httpcore.ConnectionPool() +response = http.request("GET", "https://www.example.com/") +``` + +Once you're ready to get going, [head over to the documentation](https://www.encode.io/httpcore/). + +## Motivation + +You *probably* don't want to be using HTTP Core directly. It might make sense if +you're writing something like a proxy service in Python, and you just want +something at the lowest possible level, but more typically you'll want to use +a higher level client library, such as `httpx`. + +The motivation for `httpcore` is: + +* To provide a reusable low-level client library, that other packages can then build on top of. +* To provide a *really clear interface split* between the networking code and client logic, + so that each is easier to understand and reason about in isolation. +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). + +## 0.18.0 (September 8th, 2023) + +- Add support for HTTPS proxies. (#745, #786) +- Drop Python 3.7 support. (#727) +- Handle `sni_hostname` extension with SOCKS proxy. (#774) +- Handle HTTP/1.1 half-closed connections gracefully. (#641) +- Change the type of `Extensions` from `Mapping[Str, Any]` to `MutableMapping[Str, Any]`. (#762) + +## 0.17.3 (July 5th, 2023) + +- Support async cancellations, ensuring that the connection pool is left in a clean state when cancellations occur. (#726) +- The networking backend interface has [been added to the public API](https://www.encode.io/httpcore/network-backends). Some classes which were previously private implementation detail are now part of the top-level public API. (#699) +- Graceful handling of HTTP/2 GoAway frames, with requests being transparently retried on a new connection. (#730) +- Add exceptions when a synchronous `trace callback` is passed to an asynchronous request or an asynchronous `trace callback` is passed to a synchronous request. (#717) +- Drop Python 3.7 support. (#727) + +## 0.17.2 (May 23th, 2023) + +- Add `socket_options` argument to `ConnectionPool` and `HTTProxy` classes. (#668) +- Improve logging with per-module logger names. (#690) +- Add `sni_hostname` request extension. (#696) +- Resolve race condition during import of `anyio` package. (#692) +- Enable TCP_NODELAY for all synchronous sockets. (#651) + +## 0.17.1 (May 17th, 2023) + +- If 'retries' is set, then allow retries if an SSL handshake error occurs. (#669) +- Improve correctness of tracebacks on network exceptions, by raising properly chained exceptions. (#678) +- Prevent connection-hanging behaviour when HTTP/2 connections are closed by a server-sent 'GoAway' frame. (#679) +- Fix edge-case exception when removing requests from the connection pool. (#680) +- Fix pool timeout edge-case. (#688) + +## 0.17.0 (March 16th, 2023) + +- Add DEBUG level logging. (#648) +- Respect HTTP/2 max concurrent streams when settings updates are sent by server. (#652) +- Increase the allowable HTTP header size to 100kB. (#647) +- Add `retries` option to SOCKS proxy classes. (#643) + +## 0.16.3 (December 20th, 2022) + +- Allow `ws` and `wss` schemes. Allows us to properly support websocket upgrade connections. (#625) +- Forwarding HTTP proxies use a connection-per-remote-host. Required by some proxy implementations. (#637) +- Don't raise `RuntimeError` when closing a connection pool with active connections. Removes some error cases when cancellations are used. (#631) +- Lazy import `anyio`, so that it's no longer a hard dependancy, and isn't imported if unused. (#639) + +## 0.16.2 (November 25th, 2022) + +- Revert 'Fix async cancellation behaviour', which introduced race conditions. (#627) +- Raise `RuntimeError` if attempting to us UNIX domain sockets on Windows. (#619) + +## 0.16.1 (November 17th, 2022) + +- Fix HTTP/1.1 interim informational responses, such as "100 Continue". (#605) + +## 0.16.0 (October 11th, 2022) + +- Support HTTP/1.1 informational responses. (#581) +- Fix async cancellation behaviour. (#580) +- Support `h11` 0.14. (#579) + +## 0.15.0 (May 17th, 2022) + +- Drop Python 3.6 support (#535) +- Ensure HTTP proxy CONNECT requests include `timeout` configuration. (#506) +- Switch to explicit `typing.Optional` for type hints. (#513) +- For `trio` map OSError exceptions to `ConnectError`. (#543) + +## 0.14.7 (February 4th, 2022) + +- Requests which raise a PoolTimeout need to be removed from the pool queue. (#502) +- Fix AttributeError that happened when Socks5Connection were terminated. (#501) + +## 0.14.6 (February 1st, 2022) + +- Fix SOCKS support for `http://` URLs. (#492) +- Resolve race condition around exceptions during streaming a response. (#491) + +## 0.14.5 (January 18th, 2022) + +- SOCKS proxy support. (#478) +- Add proxy_auth argument to HTTPProxy. (#481) +- Improve error message on 'RemoteProtocolError' exception when server disconnects without sending a response. (#479) + +## 0.14.4 (January 5th, 2022) + +- Support HTTP/2 on HTTPS tunnelling proxies. (#468) +- Fix proxy headers missing on HTTP forwarding. (#456) +- Only instantiate SSL context if required. (#457) +- More robust HTTP/2 handling. (#253, #439, #440, #441) + +## 0.14.3 (November 17th, 2021) + +- Fix race condition when removing closed connections from the pool. (#437) + +## 0.14.2 (November 16th, 2021) + +- Failed connections no longer remain in the pool. (Pull #433) + +## 0.14.1 (November 12th, 2021) + +- `max_connections` becomes optional. (Pull #429) +- `certifi` is now included in the install dependancies. (Pull #428) +- `h2` is now strictly optional. (Pull #428) + +## 0.14.0 (November 11th, 2021) + +The 0.14 release is a complete reworking of `httpcore`, comprehensively addressing some underlying issues in the connection pooling, as well as substantially redesigning the API to be more user friendly. + +Some of the lower-level API design also makes the components more easily testable in isolation, and the package now has 100% test coverage. + +See [discussion #419](https://github.com/encode/httpcore/discussions/419) for a little more background. + +There's some other neat bits in there too, such as the "trace" extension, which gives a hook into inspecting the internal events that occur during the request/response cycle. This extension is needed for the HTTPX cli, in order to... + +* Log the point at which the connection is established, and the IP/port on which it is made. +* Determine if the outgoing request should log as HTTP/1.1 or HTTP/2, rather than having to assume it's HTTP/2 if the --http2 flag was passed. (Which may not actually be true.) +* Log SSL version info / certificate info. + +Note that `curio` support is not currently available in 0.14.0. If you're using `httpcore` with `curio` please get in touch, so we can assess if we ought to prioritize it as a feature or not. + +## 0.13.7 (September 13th, 2021) + +- Fix broken error messaging when URL scheme is missing, or a non HTTP(S) scheme is used. (Pull #403) + +## 0.13.6 (June 15th, 2021) + +### Fixed + +- Close sockets when read or write timeouts occur. (Pull #365) + +## 0.13.5 (June 14th, 2021) + +### Fixed + +- Resolved niggles with AnyIO EOF behaviours. (Pull #358, #362) + +## 0.13.4 (June 9th, 2021) + +### Added + +- Improved error messaging when URL scheme is missing, or a non HTTP(S) scheme is used. (Pull #354) + +### Fixed + +- Switched to `anyio` as the default backend implementation when running with `asyncio`. Resolves some awkward [TLS timeout issues](https://github.com/encode/httpx/discussions/1511). + +## 0.13.3 (May 6th, 2021) + +### Added + +- Support HTTP/2 prior knowledge, using `httpcore.SyncConnectionPool(http1=False)`. (Pull #333) + +### Fixed + +- Handle cases where environment does not provide `select.poll` support. (Pull #331) + +## 0.13.2 (April 29th, 2021) + +### Added + +- Improve error message for specific case of `RemoteProtocolError` where server disconnects without sending a response. (Pull #313) + +## 0.13.1 (April 28th, 2021) + +### Fixed + +- More resiliant testing for closed connections. (Pull #311) +- Don't raise exceptions on ungraceful connection closes. (Pull #310) + +## 0.13.0 (April 21st, 2021) + +The 0.13 release updates the core API in order to match the HTTPX Transport API, +introduced in HTTPX 0.18 onwards. + +An example of making requests with the new interface is: + +```python +with httpcore.SyncConnectionPool() as http: + status_code, headers, stream, extensions = http.handle_request( + method=b'GET', + url=(b'https', b'example.org', 443, b'/'), + headers=[(b'host', b'example.org'), (b'user-agent', b'httpcore')] + stream=httpcore.ByteStream(b''), + extensions={} + ) + body = stream.read() + print(status_code, body) +``` + +### Changed + +- The `.request()` method is now `handle_request()`. (Pull #296) +- The `.arequest()` method is now `.handle_async_request()`. (Pull #296) +- The `headers` argument is no longer optional. (Pull #296) +- The `stream` argument is no longer optional. (Pull #296) +- The `ext` argument is now named `extensions`, and is no longer optional. (Pull #296) +- The `"reason"` extension keyword is now named `"reason_phrase"`. (Pull #296) +- The `"reason_phrase"` and `"http_version"` extensions now use byte strings for their values. (Pull #296) +- The `httpcore.PlainByteStream()` class becomes `httpcore.ByteStream()`. (Pull #296) + +### Added + +- Streams now support a `.read()` interface. (Pull #296) + +### Fixed + +- Task cancellation no longer leaks connections from the connection pool. (Pull #305) + +## 0.12.3 (December 7th, 2020) + +### Fixed + +- Abort SSL connections on close rather than waiting for remote EOF when using `asyncio`. (Pull #167) +- Fix exception raised in case of connect timeouts when using the `anyio` backend. (Pull #236) +- Fix `Host` header precedence for `:authority` in HTTP/2. (Pull #241, #243) +- Handle extra edge case when detecting for socket readability when using `asyncio`. (Pull #242, #244) +- Fix `asyncio` SSL warning when using proxy tunneling. (Pull #249) + +## 0.12.2 (November 20th, 2020) + +### Fixed + +- Properly wrap connect errors on the asyncio backend. (Pull #235) +- Fix `ImportError` occurring on Python 3.9 when using the HTTP/1.1 sync client in a multithreaded context. (Pull #237) + +## 0.12.1 (November 7th, 2020) + +### Added + +- Add connect retries. (Pull #221) + +### Fixed + +- Tweak detection of dropped connections, resolving an issue with open files limits on Linux. (Pull #185) +- Avoid leaking connections when establishing an HTTP tunnel to a proxy has failed. (Pull #223) +- Properly wrap OS errors when using `trio`. (Pull #225) + +## 0.12.0 (October 6th, 2020) + +### Changed + +- HTTP header casing is now preserved, rather than always sent in lowercase. (#216 and python-hyper/h11#104) + +### Added + +- Add Python 3.9 to officially supported versions. + +### Fixed + +- Gracefully handle a stdlib asyncio bug when a connection is closed while it is in a paused-for-reading state. (#201) + +## 0.11.1 (September 28nd, 2020) + +### Fixed + +- Add await to async semaphore release() coroutine (#197) +- Drop incorrect curio classifier (#192) + +## 0.11.0 (September 22nd, 2020) + +The Transport API with 0.11.0 has a couple of significant changes. + +Firstly we've moved changed the request interface in order to allow extensions, which will later enable us to support features +such as trailing headers, HTTP/2 server push, and CONNECT/Upgrade connections. + +The interface changes from: + +```python +def request(method, url, headers, stream, timeout): + return (http_version, status_code, reason, headers, stream) +``` + +To instead including an optional dictionary of extensions on the request and response: + +```python +def request(method, url, headers, stream, ext): + return (status_code, headers, stream, ext) +``` + +Having an open-ended extensions point will allow us to add later support for various optional features, that wouldn't otherwise be supported without these API changes. + +In particular: + +* Trailing headers support. +* HTTP/2 Server Push +* sendfile. +* Exposing raw connection on CONNECT, Upgrade, HTTP/2 bi-di streaming. +* Exposing debug information out of the API, including template name, template context. + +Currently extensions are limited to: + +* request: `timeout` - Optional. Timeout dictionary. +* response: `http_version` - Optional. Include the HTTP version used on the response. +* response: `reason` - Optional. Include the reason phrase used on the response. Only valid with HTTP/1.*. + +See https://github.com/encode/httpx/issues/1274#issuecomment-694884553 for the history behind this. + +Secondly, the async version of `request` is now namespaced as `arequest`. + +This allows concrete transports to support both sync and async implementations on the same class. + +### Added + +- Add curio support. (Pull #168) +- Add anyio support, with `backend="anyio"`. (Pull #169) + +### Changed + +- Update the Transport API to use 'ext' for optional extensions. (Pull #190) +- Update the Transport API to use `.request` and `.arequest` so implementations can support both sync and async. (Pull #189) + +## 0.10.2 (August 20th, 2020) + +### Added + +- Added Unix Domain Socket support. (Pull #139) + +### Fixed + +- Always include the port on proxy CONNECT requests. (Pull #154) +- Fix `max_keepalive_connections` configuration. (Pull #153) +- Fixes behaviour in HTTP/1.1 where server disconnects can be used to signal the end of the response body. (Pull #164) + +## 0.10.1 (August 7th, 2020) + +- Include `max_keepalive_connections` on `AsyncHTTPProxy`/`SyncHTTPProxy` classes. + +## 0.10.0 (August 7th, 2020) + +The most notable change in the 0.10.0 release is that HTTP/2 support is now fully optional. + +Use either `pip install httpcore` for HTTP/1.1 support only, or `pip install httpcore[http2]` for HTTP/1.1 and HTTP/2 support. + +### Added + +- HTTP/2 support becomes optional. (Pull #121, #130) +- Add `local_address=...` support. (Pull #100, #134) +- Add `PlainByteStream`, `IteratorByteStream`, `AsyncIteratorByteStream`. The `AsyncByteSteam` and `SyncByteStream` classes are now pure interface classes. (#133) +- Add `LocalProtocolError`, `RemoteProtocolError` exceptions. (Pull #129) +- Add `UnsupportedProtocol` exception. (Pull #128) +- Add `.get_connection_info()` method. (Pull #102, #137) +- Add better TRACE logs. (Pull #101) + +### Changed + +- `max_keepalive` is deprecated in favour of `max_keepalive_connections`. (Pull #140) + +### Fixed + +- Improve handling of server disconnects. (Pull #112) + +## 0.9.1 (May 27th, 2020) + +### Fixed + +- Proper host resolution for sync case, including IPv6 support. (Pull #97) +- Close outstanding connections when connection pool is closed. (Pull #98) + +## 0.9.0 (May 21th, 2020) + +### Changed + +- URL port becomes an `Optional[int]` instead of `int`. (Pull #92) + +### Fixed + +- Honor HTTP/2 max concurrent streams settings. (Pull #89, #90) +- Remove incorrect debug log. (Pull #83) + +## 0.8.4 (May 11th, 2020) + +### Added + +- Logging via HTTPCORE_LOG_LEVEL and HTTPX_LOG_LEVEL environment variables +and TRACE level logging. (Pull #79) + +### Fixed + +- Reuse of connections on HTTP/2 in close concurrency situations. (Pull #81) + +## 0.8.3 (May 6rd, 2020) + +### Fixed + +- Include `Host` and `Accept` headers on proxy "CONNECT" requests. +- De-duplicate any headers also contained in proxy_headers. +- HTTP/2 flag not being passed down to proxy connections. + +## 0.8.2 (May 3rd, 2020) + +### Fixed + +- Fix connections using proxy forwarding requests not being added to the +connection pool properly. (Pull #70) + +## 0.8.1 (April 30th, 2020) + +### Changed + +- Allow inherintance of both `httpcore.AsyncByteStream`, `httpcore.SyncByteStream` without type conflicts. + +## 0.8.0 (April 30th, 2020) + +### Fixed + +- Fixed tunnel proxy support. + +### Added + +- New `TimeoutException` base class. + +## 0.7.0 (March 5th, 2020) + +- First integration with HTTPX. diff --git a/contrib/python/httpcore/.dist-info/top_level.txt b/contrib/python/httpcore/.dist-info/top_level.txt new file mode 100644 index 0000000000..613e43507b --- /dev/null +++ b/contrib/python/httpcore/.dist-info/top_level.txt @@ -0,0 +1,4 @@ +httpcore +httpcore/_async +httpcore/_backends +httpcore/_sync diff --git a/contrib/python/httpcore/LICENSE.md b/contrib/python/httpcore/LICENSE.md new file mode 100644 index 0000000000..311b2b56c5 --- /dev/null +++ b/contrib/python/httpcore/LICENSE.md @@ -0,0 +1,27 @@ +Copyright © 2020, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/contrib/python/httpcore/README.md b/contrib/python/httpcore/README.md new file mode 100644 index 0000000000..66a2150016 --- /dev/null +++ b/contrib/python/httpcore/README.md @@ -0,0 +1,91 @@ +# HTTP Core + +[![Test Suite](https://github.com/encode/httpcore/workflows/Test%20Suite/badge.svg)](https://github.com/encode/httpcore/actions) +[![Package version](https://badge.fury.io/py/httpcore.svg)](https://pypi.org/project/httpcore/) + +> *Do one thing, and do it well.* + +The HTTP Core package provides a minimal low-level HTTP client, which does +one thing only. Sending HTTP requests. + +It does not provide any high level model abstractions over the API, +does not handle redirects, multipart uploads, building authentication headers, +transparent HTTP caching, URL parsing, session cookie handling, +content or charset decoding, handling JSON, environment based configuration +defaults, or any of that Jazz. + +Some things HTTP Core does do: + +* Sending HTTP requests. +* Thread-safe / task-safe connection pooling. +* HTTP(S) proxy & SOCKS proxy support. +* Supports HTTP/1.1 and HTTP/2. +* Provides both sync and async interfaces. +* Async backend support for `asyncio` and `trio`. + +## Requirements + +Python 3.8+ + +## Installation + +For HTTP/1.1 only support, install with: + +```shell +$ pip install httpcore +``` + +For HTTP/1.1 and HTTP/2 support, install with: + +```shell +$ pip install httpcore[http2] +``` + +For SOCKS proxy support, install with: + +```shell +$ pip install httpcore[socks] +``` + +# Sending requests + +Send an HTTP request: + +```python +import httpcore + +response = httpcore.request("GET", "https://www.example.com/") + +print(response) +# <Response [200]> +print(response.status) +# 200 +print(response.headers) +# [(b'Accept-Ranges', b'bytes'), (b'Age', b'557328'), (b'Cache-Control', b'max-age=604800'), ...] +print(response.content) +# b'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>\n\n<meta charset="utf-8"/>\n ...' +``` + +The top-level `httpcore.request()` function is provided for convenience. In practice whenever you're working with `httpcore` you'll want to use the connection pooling functionality that it provides. + +```python +import httpcore + +http = httpcore.ConnectionPool() +response = http.request("GET", "https://www.example.com/") +``` + +Once you're ready to get going, [head over to the documentation](https://www.encode.io/httpcore/). + +## Motivation + +You *probably* don't want to be using HTTP Core directly. It might make sense if +you're writing something like a proxy service in Python, and you just want +something at the lowest possible level, but more typically you'll want to use +a higher level client library, such as `httpx`. + +The motivation for `httpcore` is: + +* To provide a reusable low-level client library, that other packages can then build on top of. +* To provide a *really clear interface split* between the networking code and client logic, + so that each is easier to understand and reason about in isolation. diff --git a/contrib/python/httpcore/httpcore/__init__.py b/contrib/python/httpcore/httpcore/__init__.py new file mode 100644 index 0000000000..65abe9716a --- /dev/null +++ b/contrib/python/httpcore/httpcore/__init__.py @@ -0,0 +1,139 @@ +from ._api import request, stream +from ._async import ( + AsyncConnectionInterface, + AsyncConnectionPool, + AsyncHTTP2Connection, + AsyncHTTP11Connection, + AsyncHTTPConnection, + AsyncHTTPProxy, + AsyncSOCKSProxy, +) +from ._backends.base import ( + SOCKET_OPTION, + AsyncNetworkBackend, + AsyncNetworkStream, + NetworkBackend, + NetworkStream, +) +from ._backends.mock import AsyncMockBackend, AsyncMockStream, MockBackend, MockStream +from ._backends.sync import SyncBackend +from ._exceptions import ( + ConnectError, + ConnectionNotAvailable, + ConnectTimeout, + LocalProtocolError, + NetworkError, + PoolTimeout, + ProtocolError, + ProxyError, + ReadError, + ReadTimeout, + RemoteProtocolError, + TimeoutException, + UnsupportedProtocol, + WriteError, + WriteTimeout, +) +from ._models import URL, Origin, Request, Response +from ._ssl import default_ssl_context +from ._sync import ( + ConnectionInterface, + ConnectionPool, + HTTP2Connection, + HTTP11Connection, + HTTPConnection, + HTTPProxy, + SOCKSProxy, +) + +# The 'httpcore.AnyIOBackend' class is conditional on 'anyio' being installed. +try: + from ._backends.anyio import AnyIOBackend +except ImportError: # pragma: nocover + + class AnyIOBackend: # type: ignore + def __init__(self, *args, **kwargs): # type: ignore + msg = ( + "Attempted to use 'httpcore.AnyIOBackend' but 'anyio' is not installed." + ) + raise RuntimeError(msg) + + +# The 'httpcore.TrioBackend' class is conditional on 'trio' being installed. +try: + from ._backends.trio import TrioBackend +except ImportError: # pragma: nocover + + class TrioBackend: # type: ignore + def __init__(self, *args, **kwargs): # type: ignore + msg = "Attempted to use 'httpcore.TrioBackend' but 'trio' is not installed." + raise RuntimeError(msg) + + +__all__ = [ + # top-level requests + "request", + "stream", + # models + "Origin", + "URL", + "Request", + "Response", + # async + "AsyncHTTPConnection", + "AsyncConnectionPool", + "AsyncHTTPProxy", + "AsyncHTTP11Connection", + "AsyncHTTP2Connection", + "AsyncConnectionInterface", + "AsyncSOCKSProxy", + # sync + "HTTPConnection", + "ConnectionPool", + "HTTPProxy", + "HTTP11Connection", + "HTTP2Connection", + "ConnectionInterface", + "SOCKSProxy", + # network backends, implementations + "SyncBackend", + "AnyIOBackend", + "TrioBackend", + # network backends, mock implementations + "AsyncMockBackend", + "AsyncMockStream", + "MockBackend", + "MockStream", + # network backends, interface + "AsyncNetworkStream", + "AsyncNetworkBackend", + "NetworkStream", + "NetworkBackend", + # util + "default_ssl_context", + "SOCKET_OPTION", + # exceptions + "ConnectionNotAvailable", + "ProxyError", + "ProtocolError", + "LocalProtocolError", + "RemoteProtocolError", + "UnsupportedProtocol", + "TimeoutException", + "PoolTimeout", + "ConnectTimeout", + "ReadTimeout", + "WriteTimeout", + "NetworkError", + "ConnectError", + "ReadError", + "WriteError", +] + +__version__ = "0.18.0" + + +__locals = locals() +for __name in __all__: + if not __name.startswith("__"): + setattr(__locals[__name], "__module__", "httpcore") # noqa diff --git a/contrib/python/httpcore/httpcore/_api.py b/contrib/python/httpcore/httpcore/_api.py new file mode 100644 index 0000000000..854235f5f6 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_api.py @@ -0,0 +1,92 @@ +from contextlib import contextmanager +from typing import Iterator, Optional, Union + +from ._models import URL, Extensions, HeaderTypes, Response +from ._sync.connection_pool import ConnectionPool + + +def request( + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterator[bytes], None] = None, + extensions: Optional[Extensions] = None, +) -> Response: + """ + Sends an HTTP request, returning the response. + + ``` + response = httpcore.request("GET", "https://www.example.com/") + ``` + + Arguments: + method: The HTTP method for the request. Typically one of `"GET"`, + `"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`. + url: The URL of the HTTP request. Either as an instance of `httpcore.URL`, + or as str/bytes. + headers: The HTTP request headers. Either as a dictionary of str/bytes, + or as a list of two-tuples of str/bytes. + content: The content of the request body. Either as bytes, + or as a bytes iterator. + extensions: A dictionary of optional extra information included on the request. + Possible keys include `"timeout"`. + + Returns: + An instance of `httpcore.Response`. + """ + with ConnectionPool() as pool: + return pool.request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + + +@contextmanager +def stream( + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterator[bytes], None] = None, + extensions: Optional[Extensions] = None, +) -> Iterator[Response]: + """ + Sends an HTTP request, returning the response within a content manager. + + ``` + with httpcore.stream("GET", "https://www.example.com/") as response: + ... + ``` + + When using the `stream()` function, the body of the response will not be + automatically read. If you want to access the response body you should + either use `content = response.read()`, or `for chunk in response.iter_content()`. + + Arguments: + method: The HTTP method for the request. Typically one of `"GET"`, + `"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`. + url: The URL of the HTTP request. Either as an instance of `httpcore.URL`, + or as str/bytes. + headers: The HTTP request headers. Either as a dictionary of str/bytes, + or as a list of two-tuples of str/bytes. + content: The content of the request body. Either as bytes, + or as a bytes iterator. + extensions: A dictionary of optional extra information included on the request. + Possible keys include `"timeout"`. + + Returns: + An instance of `httpcore.Response`. + """ + with ConnectionPool() as pool: + with pool.stream( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) as response: + yield response diff --git a/contrib/python/httpcore/httpcore/_async/__init__.py b/contrib/python/httpcore/httpcore/_async/__init__.py new file mode 100644 index 0000000000..88dc7f01e1 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/__init__.py @@ -0,0 +1,39 @@ +from .connection import AsyncHTTPConnection +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .http_proxy import AsyncHTTPProxy +from .interfaces import AsyncConnectionInterface + +try: + from .http2 import AsyncHTTP2Connection +except ImportError: # pragma: nocover + + class AsyncHTTP2Connection: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use http2 support, but the `h2` package is not " + "installed. Use 'pip install httpcore[http2]'." + ) + + +try: + from .socks_proxy import AsyncSOCKSProxy +except ImportError: # pragma: nocover + + class AsyncSOCKSProxy: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use SOCKS support, but the `socksio` package is not " + "installed. Use 'pip install httpcore[socks]'." + ) + + +__all__ = [ + "AsyncHTTPConnection", + "AsyncConnectionPool", + "AsyncHTTPProxy", + "AsyncHTTP11Connection", + "AsyncHTTP2Connection", + "AsyncConnectionInterface", + "AsyncSOCKSProxy", +] diff --git a/contrib/python/httpcore/httpcore/_async/connection.py b/contrib/python/httpcore/httpcore/_async/connection.py new file mode 100644 index 0000000000..45ee22a63d --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/connection.py @@ -0,0 +1,222 @@ +import itertools +import logging +import ssl +from types import TracebackType +from typing import Iterable, Iterator, Optional, Type + +from .._backends.auto import AutoBackend +from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream +from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface + +RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. + + +logger = logging.getLogger("httpcore.connection") + + +def exponential_backoff(factor: float) -> Iterator[float]: + """ + Generate a geometric sequence that has a ratio of 2 and starts with 0. + + For example: + - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...` + - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...` + """ + yield 0 + for n in itertools.count(): + yield factor * 2**n + + +class AsyncHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + origin: Origin, + ssl_context: Optional[ssl.SSLContext] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + self._origin = origin + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._network_backend: AsyncNetworkBackend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._connection: Optional[AsyncConnectionInterface] = None + self._connect_failed: bool = False + self._request_lock = AsyncLock() + self._socket_options = socket_options + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection to {self._origin}" + ) + + async with self._request_lock: + if self._connection is None: + try: + stream = await self._connect(request) + + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except Exception as exc: + self._connect_failed = True + raise exc + elif not self._connection.is_available(): + raise ConnectionNotAvailable() + + return await self._connection.handle_async_request(request) + + async def _connect(self, request: Request) -> AsyncNetworkStream: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + retries_left = self._retries + delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) + + while True: + try: + if self._uds is None: + kwargs = { + "host": self._origin.host.decode("ascii"), + "port": self._origin.port, + "local_address": self._local_address, + "timeout": timeout, + "socket_options": self._socket_options, + } + async with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = await self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + else: + kwargs = { + "path": self._uds, + "timeout": timeout, + "socket_options": self._socket_options, + } + async with Trace( + "connect_unix_socket", logger, request, kwargs + ) as trace: + stream = await self._network_backend.connect_unix_socket( + **kwargs + ) + trace.return_value = stream + + if self._origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + return stream + except (ConnectError, ConnectTimeout): + if retries_left <= 0: + raise + retries_left -= 1 + delay = next(delays) + async with Trace("retry", logger, request, kwargs) as trace: + await self._network_backend.sleep(delay) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + async def aclose(self) -> None: + if self._connection is not None: + async with Trace("close", logger, None, {}): + await self._connection.aclose() + + def is_available(self) -> bool: + if self._connection is None: + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + async def __aenter__(self) -> "AsyncHTTPConnection": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.aclose() diff --git a/contrib/python/httpcore/httpcore/_async/connection_pool.py b/contrib/python/httpcore/httpcore/_async/connection_pool.py new file mode 100644 index 0000000000..ddc0510e60 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/connection_pool.py @@ -0,0 +1,356 @@ +import ssl +import sys +from types import TracebackType +from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type + +from .._backends.auto import AutoBackend +from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol +from .._models import Origin, Request, Response +from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation +from .connection import AsyncHTTPConnection +from .interfaces import AsyncConnectionInterface, AsyncRequestInterface + + +class RequestStatus: + def __init__(self, request: Request): + self.request = request + self.connection: Optional[AsyncConnectionInterface] = None + self._connection_acquired = AsyncEvent() + + def set_connection(self, connection: AsyncConnectionInterface) -> None: + assert self.connection is None + self.connection = connection + self._connection_acquired.set() + + def unset_connection(self) -> None: + assert self.connection is not None + self.connection = None + self._connection_acquired = AsyncEvent() + + async def wait_for_connection( + self, timeout: Optional[float] = None + ) -> AsyncConnectionInterface: + if self.connection is None: + await self._connection_acquired.wait(timeout=timeout) + assert self.connection is not None + return self.connection + + +class AsyncConnectionPool(AsyncRequestInterface): + """ + A connection pool for making HTTP requests. + """ + + def __init__( + self, + ssl_context: Optional[ssl.SSLContext] = None, + max_connections: Optional[int] = 10, + max_keepalive_connections: Optional[int] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish a + connection. + local_address: Local address to connect from. Can also be used to connect + using a particular address family. Using `local_address="0.0.0.0"` + will connect using an `AF_INET` address (IPv4), while using + `local_address="::"` will connect using an `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. + socket_options: Socket options that have to be included + in the TCP socket when the connection was established. + """ + self._ssl_context = ssl_context + + self._max_connections = ( + sys.maxsize if max_connections is None else max_connections + ) + self._max_keepalive_connections = ( + sys.maxsize + if max_keepalive_connections is None + else max_keepalive_connections + ) + self._max_keepalive_connections = min( + self._max_connections, self._max_keepalive_connections + ) + + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._pool: List[AsyncConnectionInterface] = [] + self._requests: List[RequestStatus] = [] + self._pool_lock = AsyncLock() + self._network_backend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._socket_options = socket_options + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + return AsyncHTTPConnection( + origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + retries=self._retries, + local_address=self._local_address, + uds=self._uds, + network_backend=self._network_backend, + socket_options=self._socket_options, + ) + + @property + def connections(self) -> List[AsyncConnectionInterface]: + """ + Return a list of the connections currently in the pool. + + For example: + + ```python + >>> pool.connections + [ + <AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>, + <AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]> , + <AsyncHTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>, + ] + ``` + """ + return list(self._pool) + + async def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool: + """ + Attempt to provide a connection that can handle the given origin. + """ + origin = status.request.url.origin + + # If there are queued requests in front of us, then don't acquire a + # connection. We handle requests strictly in order. + waiting = [s for s in self._requests if s.connection is None] + if waiting and waiting[0] is not status: + return False + + # Reuse an existing connection if one is currently available. + for idx, connection in enumerate(self._pool): + if connection.can_handle_request(origin) and connection.is_available(): + self._pool.pop(idx) + self._pool.insert(0, connection) + status.set_connection(connection) + return True + + # If the pool is currently full, attempt to close one idle connection. + if len(self._pool) >= self._max_connections: + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.is_idle(): + await connection.aclose() + self._pool.pop(idx) + break + + # If the pool is still full, then we cannot acquire a connection. + if len(self._pool) >= self._max_connections: + return False + + # Otherwise create a new connection. + connection = self.create_connection(origin) + self._pool.insert(0, connection) + status.set_connection(connection) + return True + + async def _close_expired_connections(self) -> None: + """ + Clean up the connection pool by closing off any connections that have expired. + """ + # Close any connections that have expired their keep-alive time. + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.has_expired(): + await connection.aclose() + self._pool.pop(idx) + + # If the pool size exceeds the maximum number of allowed keep-alive connections, + # then close off idle connections as required. + pool_size = len(self._pool) + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.is_idle() and pool_size > self._max_keepalive_connections: + await connection.aclose() + self._pool.pop(idx) + pool_size -= 1 + + async def handle_async_request(self, request: Request) -> Response: + """ + Send an HTTP request, and return an HTTP response. + + This is the core implementation that is called into by `.request()` or `.stream()`. + """ + scheme = request.url.scheme.decode() + if scheme == "": + raise UnsupportedProtocol( + "Request URL is missing an 'http://' or 'https://' protocol." + ) + if scheme not in ("http", "https", "ws", "wss"): + raise UnsupportedProtocol( + f"Request URL has an unsupported protocol '{scheme}://'." + ) + + status = RequestStatus(request) + + async with self._pool_lock: + self._requests.append(status) + await self._close_expired_connections() + await self._attempt_to_acquire_connection(status) + + while True: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("pool", None) + try: + connection = await status.wait_for_connection(timeout=timeout) + except BaseException as exc: + # If we timeout here, or if the task is cancelled, then make + # sure to remove the request from the queue before bubbling + # up the exception. + async with self._pool_lock: + # Ensure only remove when task exists. + if status in self._requests: + self._requests.remove(status) + raise exc + + try: + response = await connection.handle_async_request(request) + except ConnectionNotAvailable: + # The ConnectionNotAvailable exception is a special case, that + # indicates we need to retry the request on a new connection. + # + # The most common case where this can occur is when multiple + # requests are queued waiting for a single connection, which + # might end up as an HTTP/2 connection, but which actually ends + # up as HTTP/1.1. + async with self._pool_lock: + # Maintain our position in the request queue, but reset the + # status so that the request becomes queued again. + status.unset_connection() + await self._attempt_to_acquire_connection(status) + except BaseException as exc: + with AsyncShieldCancellation(): + await self.response_closed(status) + raise exc + else: + break + + # When we return the response, we wrap the stream in a special class + # that handles notifying the connection pool once the response + # has been released. + assert isinstance(response.stream, AsyncIterable) + return Response( + status=response.status, + headers=response.headers, + content=ConnectionPoolByteStream(response.stream, self, status), + extensions=response.extensions, + ) + + async def response_closed(self, status: RequestStatus) -> None: + """ + This method acts as a callback once the request/response cycle is complete. + + It is called into from the `ConnectionPoolByteStream.aclose()` method. + """ + assert status.connection is not None + connection = status.connection + + async with self._pool_lock: + # Update the state of the connection pool. + if status in self._requests: + self._requests.remove(status) + + if connection.is_closed() and connection in self._pool: + self._pool.remove(connection) + + # Since we've had a response closed, it's possible we'll now be able + # to service one or more requests that are currently pending. + for status in self._requests: + if status.connection is None: + acquired = await self._attempt_to_acquire_connection(status) + # If we could not acquire a connection for a queued request + # then we don't need to check anymore requests that are + # queued later behind it. + if not acquired: + break + + # Housekeeping. + await self._close_expired_connections() + + async def aclose(self) -> None: + """ + Close any connections in the pool. + """ + async with self._pool_lock: + for connection in self._pool: + await connection.aclose() + self._pool = [] + self._requests = [] + + async def __aenter__(self) -> "AsyncConnectionPool": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.aclose() + + +class ConnectionPoolByteStream: + """ + A wrapper around the response byte stream, that additionally handles + notifying the connection pool when the response has been closed. + """ + + def __init__( + self, + stream: AsyncIterable[bytes], + pool: AsyncConnectionPool, + status: RequestStatus, + ) -> None: + self._stream = stream + self._pool = pool + self._status = status + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for part in self._stream: + yield part + + async def aclose(self) -> None: + try: + if hasattr(self._stream, "aclose"): + await self._stream.aclose() + finally: + with AsyncShieldCancellation(): + await self._pool.response_closed(self._status) diff --git a/contrib/python/httpcore/httpcore/_async/http11.py b/contrib/python/httpcore/httpcore/_async/http11.py new file mode 100644 index 0000000000..32fa3a6f23 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/http11.py @@ -0,0 +1,343 @@ +import enum +import logging +import time +from types import TracebackType +from typing import ( + AsyncIterable, + AsyncIterator, + List, + Optional, + Tuple, + Type, + Union, + cast, +) + +import h11 + +from .._backends.base import AsyncNetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, + WriteError, + map_exceptions, +) +from .._models import Origin, Request, Response +from .._synchronization import AsyncLock, AsyncShieldCancellation +from .._trace import Trace +from .interfaces import AsyncConnectionInterface + +logger = logging.getLogger("httpcore.http11") + + +# A subset of `h11.Event` types supported by `_send_event` +H11SendEvent = Union[ + h11.Request, + h11.Data, + h11.EndOfMessage, +] + + +class HTTPConnectionState(enum.IntEnum): + NEW = 0 + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class AsyncHTTP11Connection(AsyncConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024 + + def __init__( + self, + origin: Origin, + stream: AsyncNetworkStream, + keepalive_expiry: Optional[float] = None, + ) -> None: + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: Optional[float] = keepalive_expiry + self._expire_at: Optional[float] = None + self._state = HTTPConnectionState.NEW + self._state_lock = AsyncLock() + self._request_count = 0 + self._h11_state = h11.Connection( + our_role=h11.CLIENT, + max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, + ) + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + async with self._state_lock: + if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): + self._request_count += 1 + self._state = HTTPConnectionState.ACTIVE + self._expire_at = None + else: + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request} + try: + async with Trace( + "send_request_headers", logger, request, kwargs + ) as trace: + await self._send_request_headers(**kwargs) + async with Trace("send_request_body", logger, request, kwargs) as trace: + await self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we supress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well formed HTTP + # error response. + pass + + async with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + ( + http_version, + status, + reason_phrase, + headers, + ) = await self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + return Response( + status=status, + headers=headers, + content=HTTP11ConnectionByteStream(self, request), + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": self._network_stream, + }, + ) + except BaseException as exc: + with AsyncShieldCancellation(): + async with Trace("response_closed", logger, request) as trace: + await self._response_closed() + raise exc + + # Sending the request... + + async def _send_request_headers(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): + event = h11.Request( + method=request.method, + target=request.url.target, + headers=request.headers, + ) + await self._send_event(event, timeout=timeout) + + async def _send_request_body(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + assert isinstance(request.stream, AsyncIterable) + async for chunk in request.stream: + event = h11.Data(data=chunk) + await self._send_event(event, timeout=timeout) + + await self._send_event(h11.EndOfMessage(), timeout=timeout) + + async def _send_event( + self, event: h11.Event, timeout: Optional[float] = None + ) -> None: + bytes_to_send = self._h11_state.send(event) + if bytes_to_send is not None: + await self._network_stream.write(bytes_to_send, timeout=timeout) + + # Receiving the response... + + async def _receive_response_headers( + self, request: Request + ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = await self._receive_event(timeout=timeout) + if isinstance(event, h11.Response): + break + if ( + isinstance(event, h11.InformationalResponse) + and event.status_code == 101 + ): + break + + http_version = b"HTTP/" + event.http_version + + # h11 version 0.11+ supports a `raw_items` interface to get the + # raw header casing, rather than the enforced lowercase headers. + headers = event.headers.raw_items() + + return http_version, event.status_code, event.reason, headers + + async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = await self._receive_event(timeout=timeout) + if isinstance(event, h11.Data): + yield bytes(event.data) + elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): + break + + async def _receive_event( + self, timeout: Optional[float] = None + ) -> Union[h11.Event, Type[h11.PAUSED]]: + while True: + with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): + event = self._h11_state.next_event() + + if event is h11.NEED_DATA: + data = await self._network_stream.read( + self.READ_NUM_BYTES, timeout=timeout + ) + + # If we feed this case through h11 we'll raise an exception like: + # + # httpcore.RemoteProtocolError: can't handle event type + # ConnectionClosed when role=SERVER and state=SEND_RESPONSE + # + # Which is accurate, but not very informative from an end-user + # perspective. Instead we handle this case distinctly and treat + # it as a ConnectError. + if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE: + msg = "Server disconnected without sending a response." + raise RemoteProtocolError(msg) + + self._h11_state.receive_data(data) + else: + # mypy fails to narrow the type in the above if statement above + return cast(Union[h11.Event, Type[h11.PAUSED]], event) + + async def _response_closed(self) -> None: + async with self._state_lock: + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + await self.aclose() + + # Once the connection is no longer required... + + async def aclose(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._state = HTTPConnectionState.CLOSED + await self._network_stream.aclose() + + # The AsyncConnectionInterface methods provide information about the state of + # the connection, allowing for a connection pooling implementation to + # determine when to reuse and when to close the connection... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + # Note that HTTP/1.1 connections in the "NEW" state are not treated as + # being "available". The control flow which created the connection will + # be able to send an outgoing request, but the connection will not be + # acquired from the connection pool for any other request. + return self._state == HTTPConnectionState.IDLE + + def has_expired(self) -> bool: + now = time.monotonic() + keepalive_expired = self._expire_at is not None and now > self._expire_at + + # If the HTTP connection is idle but the socket is readable, then the + # only valid state is that the socket is about to return b"", indicating + # a server-initiated disconnect. + server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + + return keepalive_expired or server_disconnected + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/1.1, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + async def __aenter__(self) -> "AsyncHTTP11Connection": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.aclose() + + +class HTTP11ConnectionByteStream: + def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + self._closed = False + + async def __aiter__(self) -> AsyncIterator[bytes]: + kwargs = {"request": self._request} + try: + async with Trace("receive_response_body", logger, self._request, kwargs): + async for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. + with AsyncShieldCancellation(): + await self.aclose() + raise exc + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + async with Trace("response_closed", logger, self._request): + await self._connection._response_closed() diff --git a/contrib/python/httpcore/httpcore/_async/http2.py b/contrib/python/httpcore/httpcore/_async/http2.py new file mode 100644 index 0000000000..8dc776ffa0 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/http2.py @@ -0,0 +1,589 @@ +import enum +import logging +import time +import types +import typing + +import h2.config +import h2.connection +import h2.events +import h2.exceptions +import h2.settings + +from .._backends.base import AsyncNetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, +) +from .._models import Origin, Request, Response +from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation +from .._trace import Trace +from .interfaces import AsyncConnectionInterface + +logger = logging.getLogger("httpcore.http2") + + +def has_body_headers(request: Request) -> bool: + return any( + k.lower() == b"content-length" or k.lower() == b"transfer-encoding" + for k, v in request.headers + ) + + +class HTTPConnectionState(enum.IntEnum): + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class AsyncHTTP2Connection(AsyncConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + CONFIG = h2.config.H2Configuration(validate_inbound_headers=False) + + def __init__( + self, + origin: Origin, + stream: AsyncNetworkStream, + keepalive_expiry: typing.Optional[float] = None, + ): + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: typing.Optional[float] = keepalive_expiry + self._h2_state = h2.connection.H2Connection(config=self.CONFIG) + self._state = HTTPConnectionState.IDLE + self._expire_at: typing.Optional[float] = None + self._request_count = 0 + self._init_lock = AsyncLock() + self._state_lock = AsyncLock() + self._read_lock = AsyncLock() + self._write_lock = AsyncLock() + self._sent_connection_init = False + self._used_all_stream_ids = False + self._connection_error = False + + # Mapping from stream ID to response stream events. + self._events: typing.Dict[ + int, + typing.Union[ + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ], + ] = {} + + # Connection terminated events are stored as state since + # we need to handle them for all streams. + self._connection_terminated: typing.Optional[ + h2.events.ConnectionTerminated + ] = None + + self._read_exception: typing.Optional[Exception] = None + self._write_exception: typing.Optional[Exception] = None + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + # This cannot occur in normal operation, since the connection pool + # will only send requests on connections that handle them. + # It's in place simply for resilience as a guard against incorrect + # usage, for anyone working directly with httpcore connections. + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + async with self._state_lock: + if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE): + self._request_count += 1 + self._expire_at = None + self._state = HTTPConnectionState.ACTIVE + else: + raise ConnectionNotAvailable() + + async with self._init_lock: + if not self._sent_connection_init: + try: + kwargs = {"request": request} + async with Trace("send_connection_init", logger, request, kwargs): + await self._send_connection_init(**kwargs) + except BaseException as exc: + with AsyncShieldCancellation(): + await self.aclose() + raise exc + + self._sent_connection_init = True + + # Initially start with just 1 until the remote server provides + # its max_concurrent_streams value + self._max_streams = 1 + + local_settings_max_streams = ( + self._h2_state.local_settings.max_concurrent_streams + ) + self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams) + + for _ in range(local_settings_max_streams - self._max_streams): + await self._max_streams_semaphore.acquire() + + await self._max_streams_semaphore.acquire() + + try: + stream_id = self._h2_state.get_next_available_stream_id() + self._events[stream_id] = [] + except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover + self._used_all_stream_ids = True + self._request_count -= 1 + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request, "stream_id": stream_id} + async with Trace("send_request_headers", logger, request, kwargs): + await self._send_request_headers(request=request, stream_id=stream_id) + async with Trace("send_request_body", logger, request, kwargs): + await self._send_request_body(request=request, stream_id=stream_id) + async with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + status, headers = await self._receive_response( + request=request, stream_id=stream_id + ) + trace.return_value = (status, headers) + + return Response( + status=status, + headers=headers, + content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), + extensions={ + "http_version": b"HTTP/2", + "network_stream": self._network_stream, + "stream_id": stream_id, + }, + ) + except BaseException as exc: # noqa: PIE786 + with AsyncShieldCancellation(): + kwargs = {"stream_id": stream_id} + async with Trace("response_closed", logger, request, kwargs): + await self._response_closed(stream_id=stream_id) + + if isinstance(exc, h2.exceptions.ProtocolError): + # One case where h2 can raise a protocol error is when a + # closed frame has been seen by the state machine. + # + # This happens when one stream is reading, and encounters + # a GOAWAY event. Other flows of control may then raise + # a protocol error at any point they interact with the 'h2_state'. + # + # In this case we'll have stored the event, and should raise + # it as a RemoteProtocolError. + if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) + # If h2 raises a protocol error in some other state then we + # must somehow have made a protocol violation. + raise LocalProtocolError(exc) # pragma: nocover + + raise exc + + async def _send_connection_init(self, request: Request) -> None: + """ + The HTTP/2 connection requires some initial setup before we can start + using individual request/response streams on it. + """ + # Need to set these manually here instead of manipulating via + # __setitem__() otherwise the H2Connection will emit SettingsUpdate + # frames in addition to sending the undesired defaults. + self._h2_state.local_settings = h2.settings.Settings( + client=True, + initial_values={ + # Disable PUSH_PROMISE frames from the server since we don't do anything + # with them for now. Maybe when we support caching? + h2.settings.SettingCodes.ENABLE_PUSH: 0, + # These two are taken from h2 for safe defaults + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100, + h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536, + }, + ) + + # Some websites (*cough* Yahoo *cough*) balk at this setting being + # present in the initial handshake since it's not defined in the original + # RFC despite the RFC mandating ignoring settings you don't know about. + del self._h2_state.local_settings[ + h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL + ] + + self._h2_state.initiate_connection() + self._h2_state.increment_flow_control_window(2**24) + await self._write_outgoing_data(request) + + # Sending the request... + + async def _send_request_headers(self, request: Request, stream_id: int) -> None: + """ + Send the request headers to a given stream ID. + """ + end_stream = not has_body_headers(request) + + # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'. + # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require + # HTTP/1.1 style headers, and map them appropriately if we end up on + # an HTTP/2 connection. + authority = [v for k, v in request.headers if k.lower() == b"host"][0] + + headers = [ + (b":method", request.method), + (b":authority", authority), + (b":scheme", request.url.scheme), + (b":path", request.url.target), + ] + [ + (k.lower(), v) + for k, v in request.headers + if k.lower() + not in ( + b"host", + b"transfer-encoding", + ) + ] + + self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) + self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id) + await self._write_outgoing_data(request) + + async def _send_request_body(self, request: Request, stream_id: int) -> None: + """ + Iterate over the request body sending it to a given stream ID. + """ + if not has_body_headers(request): + return + + assert isinstance(request.stream, typing.AsyncIterable) + async for data in request.stream: + await self._send_stream_data(request, stream_id, data) + await self._send_end_stream(request, stream_id) + + async def _send_stream_data( + self, request: Request, stream_id: int, data: bytes + ) -> None: + """ + Send a single chunk of data in one or more data frames. + """ + while data: + max_flow = await self._wait_for_outgoing_flow(request, stream_id) + chunk_size = min(len(data), max_flow) + chunk, data = data[:chunk_size], data[chunk_size:] + self._h2_state.send_data(stream_id, chunk) + await self._write_outgoing_data(request) + + async def _send_end_stream(self, request: Request, stream_id: int) -> None: + """ + Send an empty data frame on on a given stream ID with the END_STREAM flag set. + """ + self._h2_state.end_stream(stream_id) + await self._write_outgoing_data(request) + + # Receiving the response... + + async def _receive_response( + self, request: Request, stream_id: int + ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]: + """ + Return the response status code and headers for a given stream ID. + """ + while True: + event = await self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.ResponseReceived): + break + + status_code = 200 + headers = [] + for k, v in event.headers: + if k == b":status": + status_code = int(v.decode("ascii", errors="ignore")) + elif not k.startswith(b":"): + headers.append((k, v)) + + return (status_code, headers) + + async def _receive_response_body( + self, request: Request, stream_id: int + ) -> typing.AsyncIterator[bytes]: + """ + Iterator that returns the bytes of the response body for a given stream ID. + """ + while True: + event = await self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.DataReceived): + amount = event.flow_controlled_length + self._h2_state.acknowledge_received_data(amount, stream_id) + await self._write_outgoing_data(request) + yield event.data + elif isinstance(event, h2.events.StreamEnded): + break + + async def _receive_stream_event( + self, request: Request, stream_id: int + ) -> typing.Union[ + h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded + ]: + """ + Return the next available event for a given stream ID. + + Will read more data from the network if required. + """ + while not self._events.get(stream_id): + await self._receive_events(request, stream_id) + event = self._events[stream_id].pop(0) + if isinstance(event, h2.events.StreamReset): + raise RemoteProtocolError(event) + return event + + async def _receive_events( + self, request: Request, stream_id: typing.Optional[int] = None + ) -> None: + """ + Read some data from the network until we see one or more events + for a given stream ID. + """ + async with self._read_lock: + if self._connection_terminated is not None: + last_stream_id = self._connection_terminated.last_stream_id + if stream_id and last_stream_id and stream_id > last_stream_id: + self._request_count -= 1 + raise ConnectionNotAvailable() + raise RemoteProtocolError(self._connection_terminated) + + # This conditional is a bit icky. We don't want to block reading if we've + # actually got an event to return for a given stream. We need to do that + # check *within* the atomic read lock. Though it also need to be optional, + # because when we call it from `_wait_for_outgoing_flow` we *do* want to + # block until we've available flow control, event when we have events + # pending for the stream ID we're attempting to send on. + if stream_id is None or not self._events.get(stream_id): + events = await self._read_incoming_data(request) + for event in events: + if isinstance(event, h2.events.RemoteSettingsChanged): + async with Trace( + "receive_remote_settings", logger, request + ) as trace: + await self._receive_remote_settings_change(event) + trace.return_value = event + + elif isinstance( + event, + ( + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ), + ): + if event.stream_id in self._events: + self._events[event.stream_id].append(event) + + elif isinstance(event, h2.events.ConnectionTerminated): + self._connection_terminated = event + + await self._write_outgoing_data(request) + + async def _receive_remote_settings_change(self, event: h2.events.Event) -> None: + max_concurrent_streams = event.changed_settings.get( + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS + ) + if max_concurrent_streams: + new_max_streams = min( + max_concurrent_streams.new_value, + self._h2_state.local_settings.max_concurrent_streams, + ) + if new_max_streams and new_max_streams != self._max_streams: + while new_max_streams > self._max_streams: + await self._max_streams_semaphore.release() + self._max_streams += 1 + while new_max_streams < self._max_streams: + await self._max_streams_semaphore.acquire() + self._max_streams -= 1 + + async def _response_closed(self, stream_id: int) -> None: + await self._max_streams_semaphore.release() + del self._events[stream_id] + async with self._state_lock: + if self._connection_terminated and not self._events: + await self.aclose() + + elif self._state == HTTPConnectionState.ACTIVE and not self._events: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + if self._used_all_stream_ids: # pragma: nocover + await self.aclose() + + async def aclose(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._h2_state.close_connection() + self._state = HTTPConnectionState.CLOSED + await self._network_stream.aclose() + + # Wrappers around network read/write operations... + + async def _read_incoming_data( + self, request: Request + ) -> typing.List[h2.events.Event]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + if self._read_exception is not None: + raise self._read_exception # pragma: nocover + + try: + data = await self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + except Exception as exc: + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future reads. + # (For example, this means that a single read timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._read_exception = exc + self._connection_error = True + raise exc + + events: typing.List[h2.events.Event] = self._h2_state.receive_data(data) + + return events + + async def _write_outgoing_data(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + async with self._write_lock: + data_to_send = self._h2_state.data_to_send() + + if self._write_exception is not None: + raise self._write_exception # pragma: nocover + + try: + await self._network_stream.write(data_to_send, timeout) + except Exception as exc: # pragma: nocover + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future write. + # (For example, this means that a single write timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._write_exception = exc + self._connection_error = True + raise exc + + # Flow control... + + async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int: + """ + Returns the maximum allowable outgoing flow for a given stream. + + If the allowable flow is zero, then waits on the network until + WindowUpdated frames have increased the flow rate. + https://tools.ietf.org/html/rfc7540#section-6.9 + """ + local_flow: int = self._h2_state.local_flow_control_window(stream_id) + max_frame_size: int = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + while flow == 0: + await self._receive_events(request) + local_flow = self._h2_state.local_flow_control_window(stream_id) + max_frame_size = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + return flow + + # Interface for connection pooling... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + return ( + self._state != HTTPConnectionState.CLOSED + and not self._connection_error + and not self._used_all_stream_ids + and not ( + self._h2_state.state_machine.state + == h2.connection.ConnectionState.CLOSED + ) + ) + + def has_expired(self) -> bool: + now = time.monotonic() + return self._expire_at is not None and now > self._expire_at + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/2, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + async def __aenter__(self) -> "AsyncHTTP2Connection": + return self + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[types.TracebackType] = None, + ) -> None: + await self.aclose() + + +class HTTP2ConnectionByteStream: + def __init__( + self, connection: AsyncHTTP2Connection, request: Request, stream_id: int + ) -> None: + self._connection = connection + self._request = request + self._stream_id = stream_id + self._closed = False + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + kwargs = {"request": self._request, "stream_id": self._stream_id} + try: + async with Trace("receive_response_body", logger, self._request, kwargs): + async for chunk in self._connection._receive_response_body( + request=self._request, stream_id=self._stream_id + ): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. + with AsyncShieldCancellation(): + await self.aclose() + raise exc + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + kwargs = {"stream_id": self._stream_id} + async with Trace("response_closed", logger, self._request, kwargs): + await self._connection._response_closed(stream_id=self._stream_id) diff --git a/contrib/python/httpcore/httpcore/_async/http_proxy.py b/contrib/python/httpcore/httpcore/_async/http_proxy.py new file mode 100644 index 0000000000..4aa7d8741a --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/http_proxy.py @@ -0,0 +1,368 @@ +import logging +import ssl +from base64 import b64encode +from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union + +from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend +from .._exceptions import ProxyError +from .._models import ( + URL, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, +) +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from .connection import AsyncHTTPConnection +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface + +HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]] +HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]] + + +logger = logging.getLogger("httpcore.proxy") + + +def merge_headers( + default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, + override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, +) -> List[Tuple[bytes, bytes]]: + """ + Append default_headers and override_headers, de-duplicating if a key exists + in both cases. + """ + default_headers = [] if default_headers is None else list(default_headers) + override_headers = [] if override_headers is None else list(override_headers) + has_override = set(key.lower() for key, value in override_headers) + default_headers = [ + (key, value) + for key, value in default_headers + if key.lower() not in has_override + ] + return default_headers + override_headers + + +def build_auth_header(username: bytes, password: bytes) -> bytes: + userpass = username + b":" + password + return b"Basic " + b64encode(userpass) + + +class AsyncHTTPProxy(AsyncConnectionPool): + """ + A connection pool that sends requests via an HTTP proxy. + """ + + def __init__( + self, + proxy_url: Union[URL, bytes, str], + proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None, + proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None, + ssl_context: Optional[ssl.SSLContext] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + max_connections: Optional[int] = 10, + max_keepalive_connections: Optional[int] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + proxy_auth: Any proxy authentication as a two-tuple of + (username, password). May be either bytes or ascii-only str. + proxy_headers: Any HTTP headers to use for the proxy requests. + For example `{"Proxy-Authorization": "Basic <username>:<password>"}`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. + """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + local_address=local_address, + uds=uds, + socket_options=socket_options, + ) + + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if ( + self._proxy_url.scheme == b"http" and proxy_ssl_context is not None + ): # pragma: no cover + raise RuntimeError( + "The `proxy_ssl_context` argument is not allowed for the http scheme" + ) + + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + if proxy_auth is not None: + username = enforce_bytes(proxy_auth[0], name="proxy_auth") + password = enforce_bytes(proxy_auth[1], name="proxy_auth") + authorization = build_auth_header(username, password) + self._proxy_headers = [ + (b"Proxy-Authorization", authorization) + ] + self._proxy_headers + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + if origin.scheme == b"http": + return AsyncForwardHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + proxy_ssl_context=self._proxy_ssl_context, + ) + return AsyncTunnelHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + ssl_context=self._ssl_context, + proxy_ssl_context=self._proxy_ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class AsyncForwardHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None, + keepalive_expiry: Optional[float] = None, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + ) -> None: + self._connection = AsyncHTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._remote_origin = remote_origin + + async def handle_async_request(self, request: Request) -> Response: + headers = merge_headers(self._proxy_headers, request.headers) + url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=bytes(request.url), + ) + proxy_request = Request( + method=request.method, + url=url, + headers=headers, + content=request.stream, + extensions=request.extensions, + ) + return await self._connection.handle_async_request(proxy_request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + await self._connection.aclose() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + +class AsyncTunnelHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + ssl_context: Optional[ssl.SSLContext] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + self._connection: AsyncConnectionInterface = AsyncHTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._connect_lock = AsyncLock() + self._connected = False + + async def handle_async_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) + + async with self._connect_lock: + if not self._connected: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, + ) + connect_headers = merge_headers( + [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers + ) + connect_request = Request( + method=b"CONNECT", + url=connect_url, + headers=connect_headers, + extensions=request.extensions, + ) + connect_response = await self._connection.handle_async_request( + connect_request + ) + + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + await self._connection.aclose() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + + # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + return await self._connection.handle_async_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + await self._connection.aclose() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/contrib/python/httpcore/httpcore/_async/interfaces.py b/contrib/python/httpcore/httpcore/_async/interfaces.py new file mode 100644 index 0000000000..c998dd2763 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/interfaces.py @@ -0,0 +1,135 @@ +from contextlib import asynccontextmanager +from typing import AsyncIterator, Optional, Union + +from .._models import ( + URL, + Extensions, + HeaderTypes, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, + include_request_headers, +) + + +class AsyncRequestInterface: + async def request( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, AsyncIterator[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> Response: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = await self.handle_async_request(request) + try: + await response.aread() + finally: + await response.aclose() + return response + + @asynccontextmanager + async def stream( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, AsyncIterator[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> AsyncIterator[Response]: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = await self.handle_async_request(request) + try: + yield response + finally: + await response.aclose() + + async def handle_async_request(self, request: Request) -> Response: + raise NotImplementedError() # pragma: nocover + + +class AsyncConnectionInterface(AsyncRequestInterface): + async def aclose(self) -> None: + raise NotImplementedError() # pragma: nocover + + def info(self) -> str: + raise NotImplementedError() # pragma: nocover + + def can_handle_request(self, origin: Origin) -> bool: + raise NotImplementedError() # pragma: nocover + + def is_available(self) -> bool: + """ + Return `True` if the connection is currently able to accept an + outgoing request. + + An HTTP/1.1 connection will only be available if it is currently idle. + + An HTTP/2 connection will be available so long as the stream ID space is + not yet exhausted, and the connection is not in an error state. + + While the connection is being established we may not yet know if it is going + to result in an HTTP/1.1 or HTTP/2 connection. The connection should be + treated as being available, but might ultimately raise `NewConnectionRequired` + required exceptions if multiple requests are attempted over a connection + that ends up being established as HTTP/1.1. + """ + raise NotImplementedError() # pragma: nocover + + def has_expired(self) -> bool: + """ + Return `True` if the connection is in a state where it should be closed. + + This either means that the connection is idle and it has passed the + expiry time on its keep-alive, or that server has sent an EOF. + """ + raise NotImplementedError() # pragma: nocover + + def is_idle(self) -> bool: + """ + Return `True` if the connection is currently idle. + """ + raise NotImplementedError() # pragma: nocover + + def is_closed(self) -> bool: + """ + Return `True` if the connection has been closed. + + Used when a response is closed to determine if the connection may be + returned to the connection pool or not. + """ + raise NotImplementedError() # pragma: nocover diff --git a/contrib/python/httpcore/httpcore/_async/socks_proxy.py b/contrib/python/httpcore/httpcore/_async/socks_proxy.py new file mode 100644 index 0000000000..08a065d6d1 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/socks_proxy.py @@ -0,0 +1,342 @@ +import logging +import ssl +import typing + +from socksio import socks5 + +from .._backends.auto import AutoBackend +from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream +from .._exceptions import ConnectionNotAvailable, ProxyError +from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface + +logger = logging.getLogger("httpcore.socks") + + +AUTH_METHODS = { + b"\x00": "NO AUTHENTICATION REQUIRED", + b"\x01": "GSSAPI", + b"\x02": "USERNAME/PASSWORD", + b"\xff": "NO ACCEPTABLE METHODS", +} + +REPLY_CODES = { + b"\x00": "Succeeded", + b"\x01": "General SOCKS server failure", + b"\x02": "Connection not allowed by ruleset", + b"\x03": "Network unreachable", + b"\x04": "Host unreachable", + b"\x05": "Connection refused", + b"\x06": "TTL expired", + b"\x07": "Command not supported", + b"\x08": "Address type not supported", +} + + +async def _init_socks5_connection( + stream: AsyncNetworkStream, + *, + host: bytes, + port: int, + auth: typing.Optional[typing.Tuple[bytes, bytes]] = None, +) -> None: + conn = socks5.SOCKS5Connection() + + # Auth method request + auth_method = ( + socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED + if auth is None + else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD + ) + conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method])) + outgoing_bytes = conn.data_to_send() + await stream.write(outgoing_bytes) + + # Auth method response + incoming_bytes = await stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socks5.SOCKS5AuthReply) + if response.method != auth_method: + requested = AUTH_METHODS.get(auth_method, "UNKNOWN") + responded = AUTH_METHODS.get(response.method, "UNKNOWN") + raise ProxyError( + f"Requested {requested} from proxy server, but got {responded}." + ) + + if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD: + # Username/password request + assert auth is not None + username, password = auth + conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password)) + outgoing_bytes = conn.data_to_send() + await stream.write(outgoing_bytes) + + # Username/password response + incoming_bytes = await stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socks5.SOCKS5UsernamePasswordReply) + if not response.success: + raise ProxyError("Invalid username/password") + + # Connect request + conn.send( + socks5.SOCKS5CommandRequest.from_address( + socks5.SOCKS5Command.CONNECT, (host, port) + ) + ) + outgoing_bytes = conn.data_to_send() + await stream.write(outgoing_bytes) + + # Connect response + incoming_bytes = await stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socks5.SOCKS5Reply) + if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED: + reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN") + raise ProxyError(f"Proxy Server could not connect: {reply_code}.") + + +class AsyncSOCKSProxy(AsyncConnectionPool): + """ + A connection pool that sends requests via an HTTP proxy. + """ + + def __init__( + self, + proxy_url: typing.Union[URL, bytes, str], + proxy_auth: typing.Optional[ + typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]] + ] = None, + ssl_context: typing.Optional[ssl.SSLContext] = None, + max_connections: typing.Optional[int] = 10, + max_keepalive_connections: typing.Optional[int] = None, + keepalive_expiry: typing.Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + network_backend: typing.Optional[AsyncNetworkBackend] = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. + """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + ) + self._ssl_context = ssl_context + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if proxy_auth is not None: + username, password = proxy_auth + username_bytes = enforce_bytes(username, name="proxy_auth") + password_bytes = enforce_bytes(password, name="proxy_auth") + self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = ( + username_bytes, + password_bytes, + ) + else: + self._proxy_auth = None + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + return AsyncSocks5Connection( + proxy_origin=self._proxy_url.origin, + remote_origin=origin, + proxy_auth=self._proxy_auth, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class AsyncSocks5Connection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None, + ssl_context: typing.Optional[ssl.SSLContext] = None, + keepalive_expiry: typing.Optional[float] = None, + http1: bool = True, + http2: bool = False, + network_backend: typing.Optional[AsyncNetworkBackend] = None, + ) -> None: + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._proxy_auth = proxy_auth + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + + self._network_backend: AsyncNetworkBackend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._connect_lock = AsyncLock() + self._connection: typing.Optional[AsyncConnectionInterface] = None + self._connect_failed = False + + async def handle_async_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + async with self._connect_lock: + if self._connection is None: + try: + # Connect to the proxy + kwargs = { + "host": self._proxy_origin.host.decode("ascii"), + "port": self._proxy_origin.port, + "timeout": timeout, + } + with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = await self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + + # Connect to the remote host using socks5 + kwargs = { + "stream": stream, + "host": self._remote_origin.host.decode("ascii"), + "port": self._remote_origin.port, + "auth": self._proxy_auth, + } + with Trace( + "setup_socks5_connection", logger, request, kwargs + ) as trace: + await _init_socks5_connection(**kwargs) + trace.return_value = stream + + # Upgrade the stream to SSL + if self._remote_origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ( + ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ) + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or ( + self._http2 and not self._http1 + ): # pragma: nocover + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except Exception as exc: + self._connect_failed = True + raise exc + elif not self._connection.is_available(): # pragma: nocover + raise ConnectionNotAvailable() + + return await self._connection.handle_async_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + if self._connection is not None: + await self._connection.aclose() + + def is_available(self) -> bool: + if self._connection is None: # pragma: nocover + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._remote_origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: # pragma: nocover + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/contrib/python/httpcore/httpcore/_backends/__init__.py b/contrib/python/httpcore/httpcore/_backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/__init__.py diff --git a/contrib/python/httpcore/httpcore/_backends/anyio.py b/contrib/python/httpcore/httpcore/_backends/anyio.py new file mode 100644 index 0000000000..1ed5228dbd --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/anyio.py @@ -0,0 +1,145 @@ +import ssl +import typing + +import anyio + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._utils import is_socket_readable +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class AnyIOStream(AsyncNetworkStream): + def __init__(self, stream: anyio.abc.ByteStream) -> None: + self._stream = stream + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + exc_map = { + TimeoutError: ReadTimeout, + anyio.BrokenResourceError: ReadError, + anyio.ClosedResourceError: ReadError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + try: + return await self._stream.receive(max_bytes=max_bytes) + except anyio.EndOfStream: # pragma: nocover + return b"" + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + if not buffer: + return + + exc_map = { + TimeoutError: WriteTimeout, + anyio.BrokenResourceError: WriteError, + anyio.ClosedResourceError: WriteError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + await self._stream.send(item=buffer) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> AsyncNetworkStream: + exc_map = { + TimeoutError: ConnectTimeout, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + try: + with anyio.fail_after(timeout): + ssl_stream = await anyio.streams.tls.TLSStream.wrap( + self._stream, + ssl_context=ssl_context, + hostname=server_hostname, + standard_compatible=False, + server_side=False, + ) + except Exception as exc: # pragma: nocover + await self.aclose() + raise exc + return AnyIOStream(ssl_stream) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object": + return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None) + if info == "client_addr": + return self._stream.extra(anyio.abc.SocketAttribute.local_address, None) + if info == "server_addr": + return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None) + if info == "socket": + return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + if info == "is_readable": + sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + return is_socket_readable(sock) + return None + + +class AnyIOBackend(AsyncNetworkBackend): + async def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + if socket_options is None: + socket_options = [] # pragma: no cover + exc_map = { + TimeoutError: ConnectTimeout, + OSError: ConnectError, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + stream: anyio.abc.ByteStream = await anyio.connect_tcp( + remote_host=host, + remote_port=port, + local_host=local_address, + ) + # By default TCP sockets opened in `asyncio` include TCP_NODELAY. + for option in socket_options: + stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return AnyIOStream(stream) + + async def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: # pragma: nocover + if socket_options is None: + socket_options = [] + exc_map = { + TimeoutError: ConnectTimeout, + OSError: ConnectError, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + stream: anyio.abc.ByteStream = await anyio.connect_unix(path) + for option in socket_options: + stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return AnyIOStream(stream) + + async def sleep(self, seconds: float) -> None: + await anyio.sleep(seconds) # pragma: nocover diff --git a/contrib/python/httpcore/httpcore/_backends/auto.py b/contrib/python/httpcore/httpcore/_backends/auto.py new file mode 100644 index 0000000000..b612ba071c --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/auto.py @@ -0,0 +1,52 @@ +import typing +from typing import Optional + +import sniffio + +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class AutoBackend(AsyncNetworkBackend): + async def _init_backend(self) -> None: + if not (hasattr(self, "_backend")): + backend = sniffio.current_async_library() + if backend == "trio": + from .trio import TrioBackend + + self._backend: AsyncNetworkBackend = TrioBackend() + else: + from .anyio import AnyIOBackend + + self._backend = AnyIOBackend() + + async def connect_tcp( + self, + host: str, + port: int, + timeout: Optional[float] = None, + local_address: Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + await self._init_backend() + return await self._backend.connect_tcp( + host, + port, + timeout=timeout, + local_address=local_address, + socket_options=socket_options, + ) + + async def connect_unix_socket( + self, + path: str, + timeout: Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: # pragma: nocover + await self._init_backend() + return await self._backend.connect_unix_socket( + path, timeout=timeout, socket_options=socket_options + ) + + async def sleep(self, seconds: float) -> None: # pragma: nocover + await self._init_backend() + return await self._backend.sleep(seconds) diff --git a/contrib/python/httpcore/httpcore/_backends/base.py b/contrib/python/httpcore/httpcore/_backends/base.py new file mode 100644 index 0000000000..6cadedb5f9 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/base.py @@ -0,0 +1,103 @@ +import ssl +import time +import typing + +SOCKET_OPTION = typing.Union[ + typing.Tuple[int, int, int], + typing.Tuple[int, int, typing.Union[bytes, bytearray]], + typing.Tuple[int, int, None, int], +] + + +class NetworkStream: + def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: + raise NotImplementedError() # pragma: nocover + + def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: + raise NotImplementedError() # pragma: nocover + + def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> "NetworkStream": + raise NotImplementedError() # pragma: nocover + + def get_extra_info(self, info: str) -> typing.Any: + return None # pragma: nocover + + +class NetworkBackend: + def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + raise NotImplementedError() # pragma: nocover + + def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + raise NotImplementedError() # pragma: nocover + + def sleep(self, seconds: float) -> None: + time.sleep(seconds) # pragma: nocover + + +class AsyncNetworkStream: + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + raise NotImplementedError() # pragma: nocover + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + raise NotImplementedError() # pragma: nocover + + async def aclose(self) -> None: + raise NotImplementedError() # pragma: nocover + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> "AsyncNetworkStream": + raise NotImplementedError() # pragma: nocover + + def get_extra_info(self, info: str) -> typing.Any: + return None # pragma: nocover + + +class AsyncNetworkBackend: + async def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + raise NotImplementedError() # pragma: nocover + + async def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + raise NotImplementedError() # pragma: nocover + + async def sleep(self, seconds: float) -> None: + raise NotImplementedError() # pragma: nocover diff --git a/contrib/python/httpcore/httpcore/_backends/mock.py b/contrib/python/httpcore/httpcore/_backends/mock.py new file mode 100644 index 0000000000..f7aefebf51 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/mock.py @@ -0,0 +1,142 @@ +import ssl +import typing +from typing import Optional + +from .._exceptions import ReadError +from .base import ( + SOCKET_OPTION, + AsyncNetworkBackend, + AsyncNetworkStream, + NetworkBackend, + NetworkStream, +) + + +class MockSSLObject: + def __init__(self, http2: bool): + self._http2 = http2 + + def selected_alpn_protocol(self) -> str: + return "h2" if self._http2 else "http/1.1" + + +class MockStream(NetworkStream): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + self._closed = False + + def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + if self._closed: + raise ReadError("Connection closed") + if not self._buffer: + return b"" + return self._buffer.pop(0) + + def write(self, buffer: bytes, timeout: Optional[float] = None) -> None: + pass + + def close(self) -> None: + self._closed = True + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: Optional[str] = None, + timeout: Optional[float] = None, + ) -> NetworkStream: + return self + + def get_extra_info(self, info: str) -> typing.Any: + return MockSSLObject(http2=self._http2) if info == "ssl_object" else None + + def __repr__(self) -> str: + return "<httpcore.MockStream>" + + +class MockBackend(NetworkBackend): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + + def connect_tcp( + self, + host: str, + port: int, + timeout: Optional[float] = None, + local_address: Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + return MockStream(list(self._buffer), http2=self._http2) + + def connect_unix_socket( + self, + path: str, + timeout: Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + return MockStream(list(self._buffer), http2=self._http2) + + def sleep(self, seconds: float) -> None: + pass + + +class AsyncMockStream(AsyncNetworkStream): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + self._closed = False + + async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + if self._closed: + raise ReadError("Connection closed") + if not self._buffer: + return b"" + return self._buffer.pop(0) + + async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None: + pass + + async def aclose(self) -> None: + self._closed = True + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: Optional[str] = None, + timeout: Optional[float] = None, + ) -> AsyncNetworkStream: + return self + + def get_extra_info(self, info: str) -> typing.Any: + return MockSSLObject(http2=self._http2) if info == "ssl_object" else None + + def __repr__(self) -> str: + return "<httpcore.AsyncMockStream>" + + +class AsyncMockBackend(AsyncNetworkBackend): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + + async def connect_tcp( + self, + host: str, + port: int, + timeout: Optional[float] = None, + local_address: Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + return AsyncMockStream(list(self._buffer), http2=self._http2) + + async def connect_unix_socket( + self, + path: str, + timeout: Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + return AsyncMockStream(list(self._buffer), http2=self._http2) + + async def sleep(self, seconds: float) -> None: + pass diff --git a/contrib/python/httpcore/httpcore/_backends/sync.py b/contrib/python/httpcore/httpcore/_backends/sync.py new file mode 100644 index 0000000000..f2dbd32afa --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/sync.py @@ -0,0 +1,245 @@ +import socket +import ssl +import sys +import typing +from functools import partial + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ExceptionMapping, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._utils import is_socket_readable +from .base import SOCKET_OPTION, NetworkBackend, NetworkStream + + +class TLSinTLSStream(NetworkStream): # pragma: no cover + """ + Because the standard `SSLContext.wrap_socket` method does + not work for `SSLSocket` objects, we need this class + to implement TLS stream using an underlying `SSLObject` + instance in order to support TLS on top of TLS. + """ + + # Defined in RFC 8449 + TLS_RECORD_SIZE = 16384 + + def __init__( + self, + sock: socket.socket, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ): + self._sock = sock + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + + self.ssl_obj = ssl_context.wrap_bio( + incoming=self._incoming, + outgoing=self._outgoing, + server_hostname=server_hostname, + ) + + self._sock.settimeout(timeout) + self._perform_io(self.ssl_obj.do_handshake) + + def _perform_io( + self, + func: typing.Callable[..., typing.Any], + ) -> typing.Any: + ret = None + + while True: + errno = None + try: + ret = func() + except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e: + errno = e.errno + + self._sock.sendall(self._outgoing.read()) + + if errno == ssl.SSL_ERROR_WANT_READ: + buf = self._sock.recv(self.TLS_RECORD_SIZE) + + if buf: + self._incoming.write(buf) + else: + self._incoming.write_eof() + if errno is None: + return ret + + def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: + exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + return typing.cast( + bytes, self._perform_io(partial(self.ssl_obj.read, max_bytes)) + ) + + def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: + exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + while buffer: + nsent = self._perform_io(partial(self.ssl_obj.write, buffer)) + buffer = buffer[nsent:] + + def close(self) -> None: + self._sock.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> "NetworkStream": + raise NotImplementedError() + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object": + return self.ssl_obj + if info == "client_addr": + return self._sock.getsockname() + if info == "server_addr": + return self._sock.getpeername() + if info == "socket": + return self._sock + if info == "is_readable": + return is_socket_readable(self._sock) + return None + + +class SyncStream(NetworkStream): + def __init__(self, sock: socket.socket) -> None: + self._sock = sock + + def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: + exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + return self._sock.recv(max_bytes) + + def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: + if not buffer: + return + + exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} + with map_exceptions(exc_map): + while buffer: + self._sock.settimeout(timeout) + n = self._sock.send(buffer) + buffer = buffer[n:] + + def close(self) -> None: + self._sock.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> NetworkStream: + if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover + raise RuntimeError( + "Attempted to add a TLS layer on top of the existing " + "TLS stream, which is not supported by httpcore package" + ) + + exc_map: ExceptionMapping = { + socket.timeout: ConnectTimeout, + OSError: ConnectError, + } + with map_exceptions(exc_map): + try: + if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover + # If the underlying socket has already been upgraded + # to the TLS layer (i.e. is an instance of SSLSocket), + # we need some additional smarts to support TLS-in-TLS. + return TLSinTLSStream( + self._sock, ssl_context, server_hostname, timeout + ) + else: + self._sock.settimeout(timeout) + sock = ssl_context.wrap_socket( + self._sock, server_hostname=server_hostname + ) + except Exception as exc: # pragma: nocover + self.close() + raise exc + return SyncStream(sock) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket): + return self._sock._sslobj # type: ignore + if info == "client_addr": + return self._sock.getsockname() + if info == "server_addr": + return self._sock.getpeername() + if info == "socket": + return self._sock + if info == "is_readable": + return is_socket_readable(self._sock) + return None + + +class SyncBackend(NetworkBackend): + def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + # Note that we automatically include `TCP_NODELAY` + # in addition to any other custom socket options. + if socket_options is None: + socket_options = [] # pragma: no cover + address = (host, port) + source_address = None if local_address is None else (local_address, 0) + exc_map: ExceptionMapping = { + socket.timeout: ConnectTimeout, + OSError: ConnectError, + } + + with map_exceptions(exc_map): + sock = socket.create_connection( + address, + timeout, + source_address=source_address, + ) + for option in socket_options: + sock.setsockopt(*option) # pragma: no cover + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return SyncStream(sock) + + def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: # pragma: nocover + if sys.platform == "win32": + raise RuntimeError( + "Attempted to connect to a UNIX socket on a Windows system." + ) + if socket_options is None: + socket_options = [] + + exc_map: ExceptionMapping = { + socket.timeout: ConnectTimeout, + OSError: ConnectError, + } + with map_exceptions(exc_map): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + for option in socket_options: + sock.setsockopt(*option) + sock.settimeout(timeout) + sock.connect(path) + return SyncStream(sock) diff --git a/contrib/python/httpcore/httpcore/_backends/trio.py b/contrib/python/httpcore/httpcore/_backends/trio.py new file mode 100644 index 0000000000..b1626d28e2 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/trio.py @@ -0,0 +1,161 @@ +import ssl +import typing + +import trio + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ExceptionMapping, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class TrioStream(AsyncNetworkStream): + def __init__(self, stream: trio.abc.Stream) -> None: + self._stream = stream + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ReadTimeout, + trio.BrokenResourceError: ReadError, + trio.ClosedResourceError: ReadError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + data: bytes = await self._stream.receive_some(max_bytes=max_bytes) + return data + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + if not buffer: + return + + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: WriteTimeout, + trio.BrokenResourceError: WriteError, + trio.ClosedResourceError: WriteError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + await self._stream.send_all(data=buffer) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> AsyncNetworkStream: + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + } + ssl_stream = trio.SSLStream( + self._stream, + ssl_context=ssl_context, + server_hostname=server_hostname, + https_compatible=True, + server_side=False, + ) + with map_exceptions(exc_map): + try: + with trio.fail_after(timeout_or_inf): + await ssl_stream.do_handshake() + except Exception as exc: # pragma: nocover + await self.aclose() + raise exc + return TrioStream(ssl_stream) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object" and isinstance(self._stream, trio.SSLStream): + # Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__. + # Tracked at https://github.com/python-trio/trio/issues/542 + return self._stream._ssl_object # type: ignore[attr-defined] + if info == "client_addr": + return self._get_socket_stream().socket.getsockname() + if info == "server_addr": + return self._get_socket_stream().socket.getpeername() + if info == "socket": + stream = self._stream + while isinstance(stream, trio.SSLStream): + stream = stream.transport_stream + assert isinstance(stream, trio.SocketStream) + return stream.socket + if info == "is_readable": + socket = self.get_extra_info("socket") + return socket.is_readable() + return None + + def _get_socket_stream(self) -> trio.SocketStream: + stream = self._stream + while isinstance(stream, trio.SSLStream): + stream = stream.transport_stream + assert isinstance(stream, trio.SocketStream) + return stream + + +class TrioBackend(AsyncNetworkBackend): + async def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + # By default for TCP sockets, trio enables TCP_NODELAY. + # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream + if socket_options is None: + socket_options = [] # pragma: no cover + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + OSError: ConnectError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + stream: trio.abc.Stream = await trio.open_tcp_stream( + host=host, port=port, local_address=local_address + ) + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return TrioStream(stream) + + async def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: # pragma: nocover + if socket_options is None: + socket_options = [] + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + OSError: ConnectError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + stream: trio.abc.Stream = await trio.open_unix_socket(path) + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return TrioStream(stream) + + async def sleep(self, seconds: float) -> None: + await trio.sleep(seconds) # pragma: nocover diff --git a/contrib/python/httpcore/httpcore/_exceptions.py b/contrib/python/httpcore/httpcore/_exceptions.py new file mode 100644 index 0000000000..81e7fc61dd --- /dev/null +++ b/contrib/python/httpcore/httpcore/_exceptions.py @@ -0,0 +1,81 @@ +import contextlib +from typing import Iterator, Mapping, Type + +ExceptionMapping = Mapping[Type[Exception], Type[Exception]] + + +@contextlib.contextmanager +def map_exceptions(map: ExceptionMapping) -> Iterator[None]: + try: + yield + except Exception as exc: # noqa: PIE786 + for from_exc, to_exc in map.items(): + if isinstance(exc, from_exc): + raise to_exc(exc) from exc + raise # pragma: nocover + + +class ConnectionNotAvailable(Exception): + pass + + +class ProxyError(Exception): + pass + + +class UnsupportedProtocol(Exception): + pass + + +class ProtocolError(Exception): + pass + + +class RemoteProtocolError(ProtocolError): + pass + + +class LocalProtocolError(ProtocolError): + pass + + +# Timeout errors + + +class TimeoutException(Exception): + pass + + +class PoolTimeout(TimeoutException): + pass + + +class ConnectTimeout(TimeoutException): + pass + + +class ReadTimeout(TimeoutException): + pass + + +class WriteTimeout(TimeoutException): + pass + + +# Network errors + + +class NetworkError(Exception): + pass + + +class ConnectError(NetworkError): + pass + + +class ReadError(NetworkError): + pass + + +class WriteError(NetworkError): + pass diff --git a/contrib/python/httpcore/httpcore/_models.py b/contrib/python/httpcore/httpcore/_models.py new file mode 100644 index 0000000000..11bfcd84f0 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_models.py @@ -0,0 +1,484 @@ +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, + Union, +) +from urllib.parse import urlparse + +# Functions for typechecking... + + +HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]] +HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]] +HeaderTypes = Union[HeadersAsSequence, HeadersAsMapping, None] + +Extensions = MutableMapping[str, Any] + + +def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes: + """ + Any arguments that are ultimately represented as bytes can be specified + either as bytes or as strings. + + However we enforce that any string arguments must only contain characters in + the plain ASCII range. chr(0)...chr(127). If you need to use characters + outside that range then be precise, and use a byte-wise argument. + """ + if isinstance(value, str): + try: + return value.encode("ascii") + except UnicodeEncodeError: + raise TypeError(f"{name} strings may not include unicode characters.") + elif isinstance(value, bytes): + return value + + seen_type = type(value).__name__ + raise TypeError(f"{name} must be bytes or str, but got {seen_type}.") + + +def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL": + """ + Type check for URL parameters. + """ + if isinstance(value, (bytes, str)): + return URL(value) + elif isinstance(value, URL): + return value + + seen_type = type(value).__name__ + raise TypeError(f"{name} must be a URL, bytes, or str, but got {seen_type}.") + + +def enforce_headers( + value: Union[HeadersAsMapping, HeadersAsSequence, None] = None, *, name: str +) -> List[Tuple[bytes, bytes]]: + """ + Convienence function that ensure all items in request or response headers + are either bytes or strings in the plain ASCII range. + """ + if value is None: + return [] + elif isinstance(value, Mapping): + return [ + ( + enforce_bytes(k, name="header name"), + enforce_bytes(v, name="header value"), + ) + for k, v in value.items() + ] + elif isinstance(value, Sequence): + return [ + ( + enforce_bytes(k, name="header name"), + enforce_bytes(v, name="header value"), + ) + for k, v in value + ] + + seen_type = type(value).__name__ + raise TypeError( + f"{name} must be a mapping or sequence of two-tuples, but got {seen_type}." + ) + + +def enforce_stream( + value: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None], *, name: str +) -> Union[Iterable[bytes], AsyncIterable[bytes]]: + if value is None: + return ByteStream(b"") + elif isinstance(value, bytes): + return ByteStream(value) + return value + + +# * https://tools.ietf.org/html/rfc3986#section-3.2.3 +# * https://url.spec.whatwg.org/#url-miscellaneous +# * https://url.spec.whatwg.org/#scheme-state +DEFAULT_PORTS = { + b"ftp": 21, + b"http": 80, + b"https": 443, + b"ws": 80, + b"wss": 443, +} + + +def include_request_headers( + headers: List[Tuple[bytes, bytes]], + *, + url: "URL", + content: Union[None, bytes, Iterable[bytes], AsyncIterable[bytes]], +) -> List[Tuple[bytes, bytes]]: + headers_set = set(k.lower() for k, v in headers) + + if b"host" not in headers_set: + default_port = DEFAULT_PORTS.get(url.scheme) + if url.port is None or url.port == default_port: + header_value = url.host + else: + header_value = b"%b:%d" % (url.host, url.port) + headers = [(b"Host", header_value)] + headers + + if ( + content is not None + and b"content-length" not in headers_set + and b"transfer-encoding" not in headers_set + ): + if isinstance(content, bytes): + content_length = str(len(content)).encode("ascii") + headers += [(b"Content-Length", content_length)] + else: + headers += [(b"Transfer-Encoding", b"chunked")] # pragma: nocover + + return headers + + +# Interfaces for byte streams... + + +class ByteStream: + """ + A container for non-streaming content, and that supports both sync and async + stream iteration. + """ + + def __init__(self, content: bytes) -> None: + self._content = content + + def __iter__(self) -> Iterator[bytes]: + yield self._content + + async def __aiter__(self) -> AsyncIterator[bytes]: + yield self._content + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{len(self._content)} bytes]>" + + +class Origin: + def __init__(self, scheme: bytes, host: bytes, port: int) -> None: + self.scheme = scheme + self.host = host + self.port = port + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, Origin) + and self.scheme == other.scheme + and self.host == other.host + and self.port == other.port + ) + + def __str__(self) -> str: + scheme = self.scheme.decode("ascii") + host = self.host.decode("ascii") + port = str(self.port) + return f"{scheme}://{host}:{port}" + + +class URL: + """ + Represents the URL against which an HTTP request may be made. + + The URL may either be specified as a plain string, for convienence: + + ```python + url = httpcore.URL("https://www.example.com/") + ``` + + Or be constructed with explicitily pre-parsed components: + + ```python + url = httpcore.URL(scheme=b'https', host=b'www.example.com', port=None, target=b'/') + ``` + + Using this second more explicit style allows integrations that are using + `httpcore` to pass through URLs that have already been parsed in order to use + libraries such as `rfc-3986` rather than relying on the stdlib. It also ensures + that URL parsing is treated identically at both the networking level and at any + higher layers of abstraction. + + The four components are important here, as they allow the URL to be precisely + specified in a pre-parsed format. They also allow certain types of request to + be created that could not otherwise be expressed. + + For example, an HTTP request to `http://www.example.com/` forwarded via a proxy + at `http://localhost:8080`... + + ```python + # Constructs an HTTP request with a complete URL as the target: + # GET https://www.example.com/ HTTP/1.1 + url = httpcore.URL( + scheme=b'http', + host=b'localhost', + port=8080, + target=b'https://www.example.com/' + ) + request = httpcore.Request( + method="GET", + url=url + ) + ``` + + Another example is constructing an `OPTIONS *` request... + + ```python + # Constructs an 'OPTIONS *' HTTP request: + # OPTIONS * HTTP/1.1 + url = httpcore.URL(scheme=b'https', host=b'www.example.com', target=b'*') + request = httpcore.Request(method="OPTIONS", url=url) + ``` + + This kind of request is not possible to formulate with a URL string, + because the `/` delimiter is always used to demark the target from the + host/port portion of the URL. + + For convenience, string-like arguments may be specified either as strings or + as bytes. However, once a request is being issue over-the-wire, the URL + components are always ultimately required to be a bytewise representation. + + In order to avoid any ambiguity over character encodings, when strings are used + as arguments, they must be strictly limited to the ASCII range `chr(0)`-`chr(127)`. + If you require a bytewise representation that is outside this range you must + handle the character encoding directly, and pass a bytes instance. + """ + + def __init__( + self, + url: Union[bytes, str] = "", + *, + scheme: Union[bytes, str] = b"", + host: Union[bytes, str] = b"", + port: Optional[int] = None, + target: Union[bytes, str] = b"", + ) -> None: + """ + Parameters: + url: The complete URL as a string or bytes. + scheme: The URL scheme as a string or bytes. + Typically either `"http"` or `"https"`. + host: The URL host as a string or bytes. Such as `"www.example.com"`. + port: The port to connect to. Either an integer or `None`. + target: The target of the HTTP request. Such as `"/items?search=red"`. + """ + if url: + parsed = urlparse(enforce_bytes(url, name="url")) + self.scheme = parsed.scheme + self.host = parsed.hostname or b"" + self.port = parsed.port + self.target = (parsed.path or b"/") + ( + b"?" + parsed.query if parsed.query else b"" + ) + else: + self.scheme = enforce_bytes(scheme, name="scheme") + self.host = enforce_bytes(host, name="host") + self.port = port + self.target = enforce_bytes(target, name="target") + + @property + def origin(self) -> Origin: + default_port = { + b"http": 80, + b"https": 443, + b"ws": 80, + b"wss": 443, + b"socks5": 1080, + }[self.scheme] + return Origin( + scheme=self.scheme, host=self.host, port=self.port or default_port + ) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, URL) + and other.scheme == self.scheme + and other.host == self.host + and other.port == self.port + and other.target == self.target + ) + + def __bytes__(self) -> bytes: + if self.port is None: + return b"%b://%b%b" % (self.scheme, self.host, self.target) + return b"%b://%b:%d%b" % (self.scheme, self.host, self.port, self.target) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(scheme={self.scheme!r}, " + f"host={self.host!r}, port={self.port!r}, target={self.target!r})" + ) + + +class Request: + """ + An HTTP request. + """ + + def __init__( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> None: + """ + Parameters: + method: The HTTP request method, either as a string or bytes. + For example: `GET`. + url: The request URL, either as a `URL` instance, or as a string or bytes. + For example: `"https://www.example.com".` + headers: The HTTP request headers. + content: The content of the response body. + extensions: A dictionary of optional extra information included on + the request. Possible keys include `"timeout"`, and `"trace"`. + """ + self.method: bytes = enforce_bytes(method, name="method") + self.url: URL = enforce_url(url, name="url") + self.headers: List[Tuple[bytes, bytes]] = enforce_headers( + headers, name="headers" + ) + self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream( + content, name="content" + ) + self.extensions = {} if extensions is None else extensions + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.method!r}]>" + + +class Response: + """ + An HTTP response. + """ + + def __init__( + self, + status: int, + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> None: + """ + Parameters: + status: The HTTP status code of the response. For example `200`. + headers: The HTTP response headers. + content: The content of the response body. + extensions: A dictionary of optional extra information included on + the responseself.Possible keys include `"http_version"`, + `"reason_phrase"`, and `"network_stream"`. + """ + self.status: int = status + self.headers: List[Tuple[bytes, bytes]] = enforce_headers( + headers, name="headers" + ) + self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream( + content, name="content" + ) + self.extensions = {} if extensions is None else extensions + + self._stream_consumed = False + + @property + def content(self) -> bytes: + if not hasattr(self, "_content"): + if isinstance(self.stream, Iterable): + raise RuntimeError( + "Attempted to access 'response.content' on a streaming response. " + "Call 'response.read()' first." + ) + else: + raise RuntimeError( + "Attempted to access 'response.content' on a streaming response. " + "Call 'await response.aread()' first." + ) + return self._content + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.status}]>" + + # Sync interface... + + def read(self) -> bytes: + if not isinstance(self.stream, Iterable): # pragma: nocover + raise RuntimeError( + "Attempted to read an asynchronous response using 'response.read()'. " + "You should use 'await response.aread()' instead." + ) + if not hasattr(self, "_content"): + self._content = b"".join([part for part in self.iter_stream()]) + return self._content + + def iter_stream(self) -> Iterator[bytes]: + if not isinstance(self.stream, Iterable): # pragma: nocover + raise RuntimeError( + "Attempted to stream an asynchronous response using 'for ... in " + "response.iter_stream()'. " + "You should use 'async for ... in response.aiter_stream()' instead." + ) + if self._stream_consumed: + raise RuntimeError( + "Attempted to call 'for ... in response.iter_stream()' more than once." + ) + self._stream_consumed = True + for chunk in self.stream: + yield chunk + + def close(self) -> None: + if not isinstance(self.stream, Iterable): # pragma: nocover + raise RuntimeError( + "Attempted to close an asynchronous response using 'response.close()'. " + "You should use 'await response.aclose()' instead." + ) + if hasattr(self.stream, "close"): + self.stream.close() + + # Async interface... + + async def aread(self) -> bytes: + if not isinstance(self.stream, AsyncIterable): # pragma: nocover + raise RuntimeError( + "Attempted to read an synchronous response using " + "'await response.aread()'. " + "You should use 'response.read()' instead." + ) + if not hasattr(self, "_content"): + self._content = b"".join([part async for part in self.aiter_stream()]) + return self._content + + async def aiter_stream(self) -> AsyncIterator[bytes]: + if not isinstance(self.stream, AsyncIterable): # pragma: nocover + raise RuntimeError( + "Attempted to stream an synchronous response using 'async for ... in " + "response.aiter_stream()'. " + "You should use 'for ... in response.iter_stream()' instead." + ) + if self._stream_consumed: + raise RuntimeError( + "Attempted to call 'async for ... in response.aiter_stream()' " + "more than once." + ) + self._stream_consumed = True + async for chunk in self.stream: + yield chunk + + async def aclose(self) -> None: + if not isinstance(self.stream, AsyncIterable): # pragma: nocover + raise RuntimeError( + "Attempted to close a synchronous response using " + "'await response.aclose()'. " + "You should use 'response.close()' instead." + ) + if hasattr(self.stream, "aclose"): + await self.stream.aclose() diff --git a/contrib/python/httpcore/httpcore/_ssl.py b/contrib/python/httpcore/httpcore/_ssl.py new file mode 100644 index 0000000000..c99c5a6794 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_ssl.py @@ -0,0 +1,9 @@ +import ssl + +import certifi + + +def default_ssl_context() -> ssl.SSLContext: + context = ssl.create_default_context() + context.load_verify_locations(certifi.where()) + return context diff --git a/contrib/python/httpcore/httpcore/_sync/__init__.py b/contrib/python/httpcore/httpcore/_sync/__init__.py new file mode 100644 index 0000000000..b476d76d9a --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/__init__.py @@ -0,0 +1,39 @@ +from .connection import HTTPConnection +from .connection_pool import ConnectionPool +from .http11 import HTTP11Connection +from .http_proxy import HTTPProxy +from .interfaces import ConnectionInterface + +try: + from .http2 import HTTP2Connection +except ImportError: # pragma: nocover + + class HTTP2Connection: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use http2 support, but the `h2` package is not " + "installed. Use 'pip install httpcore[http2]'." + ) + + +try: + from .socks_proxy import SOCKSProxy +except ImportError: # pragma: nocover + + class SOCKSProxy: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use SOCKS support, but the `socksio` package is not " + "installed. Use 'pip install httpcore[socks]'." + ) + + +__all__ = [ + "HTTPConnection", + "ConnectionPool", + "HTTPProxy", + "HTTP11Connection", + "HTTP2Connection", + "ConnectionInterface", + "SOCKSProxy", +] diff --git a/contrib/python/httpcore/httpcore/_sync/connection.py b/contrib/python/httpcore/httpcore/_sync/connection.py new file mode 100644 index 0000000000..81e4172a21 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/connection.py @@ -0,0 +1,222 @@ +import itertools +import logging +import ssl +from types import TracebackType +from typing import Iterable, Iterator, Optional, Type + +from .._backends.sync import SyncBackend +from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream +from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import Lock +from .._trace import Trace +from .http11 import HTTP11Connection +from .interfaces import ConnectionInterface + +RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. + + +logger = logging.getLogger("httpcore.connection") + + +def exponential_backoff(factor: float) -> Iterator[float]: + """ + Generate a geometric sequence that has a ratio of 2 and starts with 0. + + For example: + - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...` + - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...` + """ + yield 0 + for n in itertools.count(): + yield factor * 2**n + + +class HTTPConnection(ConnectionInterface): + def __init__( + self, + origin: Origin, + ssl_context: Optional[ssl.SSLContext] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[NetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + self._origin = origin + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._network_backend: NetworkBackend = ( + SyncBackend() if network_backend is None else network_backend + ) + self._connection: Optional[ConnectionInterface] = None + self._connect_failed: bool = False + self._request_lock = Lock() + self._socket_options = socket_options + + def handle_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection to {self._origin}" + ) + + with self._request_lock: + if self._connection is None: + try: + stream = self._connect(request) + + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except Exception as exc: + self._connect_failed = True + raise exc + elif not self._connection.is_available(): + raise ConnectionNotAvailable() + + return self._connection.handle_request(request) + + def _connect(self, request: Request) -> NetworkStream: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + retries_left = self._retries + delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) + + while True: + try: + if self._uds is None: + kwargs = { + "host": self._origin.host.decode("ascii"), + "port": self._origin.port, + "local_address": self._local_address, + "timeout": timeout, + "socket_options": self._socket_options, + } + with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + else: + kwargs = { + "path": self._uds, + "timeout": timeout, + "socket_options": self._socket_options, + } + with Trace( + "connect_unix_socket", logger, request, kwargs + ) as trace: + stream = self._network_backend.connect_unix_socket( + **kwargs + ) + trace.return_value = stream + + if self._origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + return stream + except (ConnectError, ConnectTimeout): + if retries_left <= 0: + raise + retries_left -= 1 + delay = next(delays) + with Trace("retry", logger, request, kwargs) as trace: + self._network_backend.sleep(delay) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def close(self) -> None: + if self._connection is not None: + with Trace("close", logger, None, {}): + self._connection.close() + + def is_available(self) -> bool: + if self._connection is None: + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + def __enter__(self) -> "HTTPConnection": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + self.close() diff --git a/contrib/python/httpcore/httpcore/_sync/connection_pool.py b/contrib/python/httpcore/httpcore/_sync/connection_pool.py new file mode 100644 index 0000000000..dbcaff1fcf --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/connection_pool.py @@ -0,0 +1,356 @@ +import ssl +import sys +from types import TracebackType +from typing import Iterable, Iterator, Iterable, List, Optional, Type + +from .._backends.sync import SyncBackend +from .._backends.base import SOCKET_OPTION, NetworkBackend +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol +from .._models import Origin, Request, Response +from .._synchronization import Event, Lock, ShieldCancellation +from .connection import HTTPConnection +from .interfaces import ConnectionInterface, RequestInterface + + +class RequestStatus: + def __init__(self, request: Request): + self.request = request + self.connection: Optional[ConnectionInterface] = None + self._connection_acquired = Event() + + def set_connection(self, connection: ConnectionInterface) -> None: + assert self.connection is None + self.connection = connection + self._connection_acquired.set() + + def unset_connection(self) -> None: + assert self.connection is not None + self.connection = None + self._connection_acquired = Event() + + def wait_for_connection( + self, timeout: Optional[float] = None + ) -> ConnectionInterface: + if self.connection is None: + self._connection_acquired.wait(timeout=timeout) + assert self.connection is not None + return self.connection + + +class ConnectionPool(RequestInterface): + """ + A connection pool for making HTTP requests. + """ + + def __init__( + self, + ssl_context: Optional[ssl.SSLContext] = None, + max_connections: Optional[int] = 10, + max_keepalive_connections: Optional[int] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[NetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish a + connection. + local_address: Local address to connect from. Can also be used to connect + using a particular address family. Using `local_address="0.0.0.0"` + will connect using an `AF_INET` address (IPv4), while using + `local_address="::"` will connect using an `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. + socket_options: Socket options that have to be included + in the TCP socket when the connection was established. + """ + self._ssl_context = ssl_context + + self._max_connections = ( + sys.maxsize if max_connections is None else max_connections + ) + self._max_keepalive_connections = ( + sys.maxsize + if max_keepalive_connections is None + else max_keepalive_connections + ) + self._max_keepalive_connections = min( + self._max_connections, self._max_keepalive_connections + ) + + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._pool: List[ConnectionInterface] = [] + self._requests: List[RequestStatus] = [] + self._pool_lock = Lock() + self._network_backend = ( + SyncBackend() if network_backend is None else network_backend + ) + self._socket_options = socket_options + + def create_connection(self, origin: Origin) -> ConnectionInterface: + return HTTPConnection( + origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + retries=self._retries, + local_address=self._local_address, + uds=self._uds, + network_backend=self._network_backend, + socket_options=self._socket_options, + ) + + @property + def connections(self) -> List[ConnectionInterface]: + """ + Return a list of the connections currently in the pool. + + For example: + + ```python + >>> pool.connections + [ + <HTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>, + <HTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]> , + <HTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>, + ] + ``` + """ + return list(self._pool) + + def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool: + """ + Attempt to provide a connection that can handle the given origin. + """ + origin = status.request.url.origin + + # If there are queued requests in front of us, then don't acquire a + # connection. We handle requests strictly in order. + waiting = [s for s in self._requests if s.connection is None] + if waiting and waiting[0] is not status: + return False + + # Reuse an existing connection if one is currently available. + for idx, connection in enumerate(self._pool): + if connection.can_handle_request(origin) and connection.is_available(): + self._pool.pop(idx) + self._pool.insert(0, connection) + status.set_connection(connection) + return True + + # If the pool is currently full, attempt to close one idle connection. + if len(self._pool) >= self._max_connections: + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.is_idle(): + connection.close() + self._pool.pop(idx) + break + + # If the pool is still full, then we cannot acquire a connection. + if len(self._pool) >= self._max_connections: + return False + + # Otherwise create a new connection. + connection = self.create_connection(origin) + self._pool.insert(0, connection) + status.set_connection(connection) + return True + + def _close_expired_connections(self) -> None: + """ + Clean up the connection pool by closing off any connections that have expired. + """ + # Close any connections that have expired their keep-alive time. + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.has_expired(): + connection.close() + self._pool.pop(idx) + + # If the pool size exceeds the maximum number of allowed keep-alive connections, + # then close off idle connections as required. + pool_size = len(self._pool) + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.is_idle() and pool_size > self._max_keepalive_connections: + connection.close() + self._pool.pop(idx) + pool_size -= 1 + + def handle_request(self, request: Request) -> Response: + """ + Send an HTTP request, and return an HTTP response. + + This is the core implementation that is called into by `.request()` or `.stream()`. + """ + scheme = request.url.scheme.decode() + if scheme == "": + raise UnsupportedProtocol( + "Request URL is missing an 'http://' or 'https://' protocol." + ) + if scheme not in ("http", "https", "ws", "wss"): + raise UnsupportedProtocol( + f"Request URL has an unsupported protocol '{scheme}://'." + ) + + status = RequestStatus(request) + + with self._pool_lock: + self._requests.append(status) + self._close_expired_connections() + self._attempt_to_acquire_connection(status) + + while True: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("pool", None) + try: + connection = status.wait_for_connection(timeout=timeout) + except BaseException as exc: + # If we timeout here, or if the task is cancelled, then make + # sure to remove the request from the queue before bubbling + # up the exception. + with self._pool_lock: + # Ensure only remove when task exists. + if status in self._requests: + self._requests.remove(status) + raise exc + + try: + response = connection.handle_request(request) + except ConnectionNotAvailable: + # The ConnectionNotAvailable exception is a special case, that + # indicates we need to retry the request on a new connection. + # + # The most common case where this can occur is when multiple + # requests are queued waiting for a single connection, which + # might end up as an HTTP/2 connection, but which actually ends + # up as HTTP/1.1. + with self._pool_lock: + # Maintain our position in the request queue, but reset the + # status so that the request becomes queued again. + status.unset_connection() + self._attempt_to_acquire_connection(status) + except BaseException as exc: + with ShieldCancellation(): + self.response_closed(status) + raise exc + else: + break + + # When we return the response, we wrap the stream in a special class + # that handles notifying the connection pool once the response + # has been released. + assert isinstance(response.stream, Iterable) + return Response( + status=response.status, + headers=response.headers, + content=ConnectionPoolByteStream(response.stream, self, status), + extensions=response.extensions, + ) + + def response_closed(self, status: RequestStatus) -> None: + """ + This method acts as a callback once the request/response cycle is complete. + + It is called into from the `ConnectionPoolByteStream.close()` method. + """ + assert status.connection is not None + connection = status.connection + + with self._pool_lock: + # Update the state of the connection pool. + if status in self._requests: + self._requests.remove(status) + + if connection.is_closed() and connection in self._pool: + self._pool.remove(connection) + + # Since we've had a response closed, it's possible we'll now be able + # to service one or more requests that are currently pending. + for status in self._requests: + if status.connection is None: + acquired = self._attempt_to_acquire_connection(status) + # If we could not acquire a connection for a queued request + # then we don't need to check anymore requests that are + # queued later behind it. + if not acquired: + break + + # Housekeeping. + self._close_expired_connections() + + def close(self) -> None: + """ + Close any connections in the pool. + """ + with self._pool_lock: + for connection in self._pool: + connection.close() + self._pool = [] + self._requests = [] + + def __enter__(self) -> "ConnectionPool": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + self.close() + + +class ConnectionPoolByteStream: + """ + A wrapper around the response byte stream, that additionally handles + notifying the connection pool when the response has been closed. + """ + + def __init__( + self, + stream: Iterable[bytes], + pool: ConnectionPool, + status: RequestStatus, + ) -> None: + self._stream = stream + self._pool = pool + self._status = status + + def __iter__(self) -> Iterator[bytes]: + for part in self._stream: + yield part + + def close(self) -> None: + try: + if hasattr(self._stream, "close"): + self._stream.close() + finally: + with ShieldCancellation(): + self._pool.response_closed(self._status) diff --git a/contrib/python/httpcore/httpcore/_sync/http11.py b/contrib/python/httpcore/httpcore/_sync/http11.py new file mode 100644 index 0000000000..0cc100e3ff --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/http11.py @@ -0,0 +1,343 @@ +import enum +import logging +import time +from types import TracebackType +from typing import ( + Iterable, + Iterator, + List, + Optional, + Tuple, + Type, + Union, + cast, +) + +import h11 + +from .._backends.base import NetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, + WriteError, + map_exceptions, +) +from .._models import Origin, Request, Response +from .._synchronization import Lock, ShieldCancellation +from .._trace import Trace +from .interfaces import ConnectionInterface + +logger = logging.getLogger("httpcore.http11") + + +# A subset of `h11.Event` types supported by `_send_event` +H11SendEvent = Union[ + h11.Request, + h11.Data, + h11.EndOfMessage, +] + + +class HTTPConnectionState(enum.IntEnum): + NEW = 0 + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class HTTP11Connection(ConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024 + + def __init__( + self, + origin: Origin, + stream: NetworkStream, + keepalive_expiry: Optional[float] = None, + ) -> None: + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: Optional[float] = keepalive_expiry + self._expire_at: Optional[float] = None + self._state = HTTPConnectionState.NEW + self._state_lock = Lock() + self._request_count = 0 + self._h11_state = h11.Connection( + our_role=h11.CLIENT, + max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, + ) + + def handle_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + with self._state_lock: + if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): + self._request_count += 1 + self._state = HTTPConnectionState.ACTIVE + self._expire_at = None + else: + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request} + try: + with Trace( + "send_request_headers", logger, request, kwargs + ) as trace: + self._send_request_headers(**kwargs) + with Trace("send_request_body", logger, request, kwargs) as trace: + self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we supress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well formed HTTP + # error response. + pass + + with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + ( + http_version, + status, + reason_phrase, + headers, + ) = self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + return Response( + status=status, + headers=headers, + content=HTTP11ConnectionByteStream(self, request), + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": self._network_stream, + }, + ) + except BaseException as exc: + with ShieldCancellation(): + with Trace("response_closed", logger, request) as trace: + self._response_closed() + raise exc + + # Sending the request... + + def _send_request_headers(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): + event = h11.Request( + method=request.method, + target=request.url.target, + headers=request.headers, + ) + self._send_event(event, timeout=timeout) + + def _send_request_body(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + assert isinstance(request.stream, Iterable) + for chunk in request.stream: + event = h11.Data(data=chunk) + self._send_event(event, timeout=timeout) + + self._send_event(h11.EndOfMessage(), timeout=timeout) + + def _send_event( + self, event: h11.Event, timeout: Optional[float] = None + ) -> None: + bytes_to_send = self._h11_state.send(event) + if bytes_to_send is not None: + self._network_stream.write(bytes_to_send, timeout=timeout) + + # Receiving the response... + + def _receive_response_headers( + self, request: Request + ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = self._receive_event(timeout=timeout) + if isinstance(event, h11.Response): + break + if ( + isinstance(event, h11.InformationalResponse) + and event.status_code == 101 + ): + break + + http_version = b"HTTP/" + event.http_version + + # h11 version 0.11+ supports a `raw_items` interface to get the + # raw header casing, rather than the enforced lowercase headers. + headers = event.headers.raw_items() + + return http_version, event.status_code, event.reason, headers + + def _receive_response_body(self, request: Request) -> Iterator[bytes]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = self._receive_event(timeout=timeout) + if isinstance(event, h11.Data): + yield bytes(event.data) + elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): + break + + def _receive_event( + self, timeout: Optional[float] = None + ) -> Union[h11.Event, Type[h11.PAUSED]]: + while True: + with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): + event = self._h11_state.next_event() + + if event is h11.NEED_DATA: + data = self._network_stream.read( + self.READ_NUM_BYTES, timeout=timeout + ) + + # If we feed this case through h11 we'll raise an exception like: + # + # httpcore.RemoteProtocolError: can't handle event type + # ConnectionClosed when role=SERVER and state=SEND_RESPONSE + # + # Which is accurate, but not very informative from an end-user + # perspective. Instead we handle this case distinctly and treat + # it as a ConnectError. + if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE: + msg = "Server disconnected without sending a response." + raise RemoteProtocolError(msg) + + self._h11_state.receive_data(data) + else: + # mypy fails to narrow the type in the above if statement above + return cast(Union[h11.Event, Type[h11.PAUSED]], event) + + def _response_closed(self) -> None: + with self._state_lock: + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + self.close() + + # Once the connection is no longer required... + + def close(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._state = HTTPConnectionState.CLOSED + self._network_stream.close() + + # The ConnectionInterface methods provide information about the state of + # the connection, allowing for a connection pooling implementation to + # determine when to reuse and when to close the connection... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + # Note that HTTP/1.1 connections in the "NEW" state are not treated as + # being "available". The control flow which created the connection will + # be able to send an outgoing request, but the connection will not be + # acquired from the connection pool for any other request. + return self._state == HTTPConnectionState.IDLE + + def has_expired(self) -> bool: + now = time.monotonic() + keepalive_expired = self._expire_at is not None and now > self._expire_at + + # If the HTTP connection is idle but the socket is readable, then the + # only valid state is that the socket is about to return b"", indicating + # a server-initiated disconnect. + server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + + return keepalive_expired or server_disconnected + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/1.1, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + def __enter__(self) -> "HTTP11Connection": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + self.close() + + +class HTTP11ConnectionByteStream: + def __init__(self, connection: HTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + self._closed = False + + def __iter__(self) -> Iterator[bytes]: + kwargs = {"request": self._request} + try: + with Trace("receive_response_body", logger, self._request, kwargs): + for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. + with ShieldCancellation(): + self.close() + raise exc + + def close(self) -> None: + if not self._closed: + self._closed = True + with Trace("response_closed", logger, self._request): + self._connection._response_closed() diff --git a/contrib/python/httpcore/httpcore/_sync/http2.py b/contrib/python/httpcore/httpcore/_sync/http2.py new file mode 100644 index 0000000000..d141d459a5 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/http2.py @@ -0,0 +1,589 @@ +import enum +import logging +import time +import types +import typing + +import h2.config +import h2.connection +import h2.events +import h2.exceptions +import h2.settings + +from .._backends.base import NetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, +) +from .._models import Origin, Request, Response +from .._synchronization import Lock, Semaphore, ShieldCancellation +from .._trace import Trace +from .interfaces import ConnectionInterface + +logger = logging.getLogger("httpcore.http2") + + +def has_body_headers(request: Request) -> bool: + return any( + k.lower() == b"content-length" or k.lower() == b"transfer-encoding" + for k, v in request.headers + ) + + +class HTTPConnectionState(enum.IntEnum): + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class HTTP2Connection(ConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + CONFIG = h2.config.H2Configuration(validate_inbound_headers=False) + + def __init__( + self, + origin: Origin, + stream: NetworkStream, + keepalive_expiry: typing.Optional[float] = None, + ): + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: typing.Optional[float] = keepalive_expiry + self._h2_state = h2.connection.H2Connection(config=self.CONFIG) + self._state = HTTPConnectionState.IDLE + self._expire_at: typing.Optional[float] = None + self._request_count = 0 + self._init_lock = Lock() + self._state_lock = Lock() + self._read_lock = Lock() + self._write_lock = Lock() + self._sent_connection_init = False + self._used_all_stream_ids = False + self._connection_error = False + + # Mapping from stream ID to response stream events. + self._events: typing.Dict[ + int, + typing.Union[ + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ], + ] = {} + + # Connection terminated events are stored as state since + # we need to handle them for all streams. + self._connection_terminated: typing.Optional[ + h2.events.ConnectionTerminated + ] = None + + self._read_exception: typing.Optional[Exception] = None + self._write_exception: typing.Optional[Exception] = None + + def handle_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + # This cannot occur in normal operation, since the connection pool + # will only send requests on connections that handle them. + # It's in place simply for resilience as a guard against incorrect + # usage, for anyone working directly with httpcore connections. + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + with self._state_lock: + if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE): + self._request_count += 1 + self._expire_at = None + self._state = HTTPConnectionState.ACTIVE + else: + raise ConnectionNotAvailable() + + with self._init_lock: + if not self._sent_connection_init: + try: + kwargs = {"request": request} + with Trace("send_connection_init", logger, request, kwargs): + self._send_connection_init(**kwargs) + except BaseException as exc: + with ShieldCancellation(): + self.close() + raise exc + + self._sent_connection_init = True + + # Initially start with just 1 until the remote server provides + # its max_concurrent_streams value + self._max_streams = 1 + + local_settings_max_streams = ( + self._h2_state.local_settings.max_concurrent_streams + ) + self._max_streams_semaphore = Semaphore(local_settings_max_streams) + + for _ in range(local_settings_max_streams - self._max_streams): + self._max_streams_semaphore.acquire() + + self._max_streams_semaphore.acquire() + + try: + stream_id = self._h2_state.get_next_available_stream_id() + self._events[stream_id] = [] + except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover + self._used_all_stream_ids = True + self._request_count -= 1 + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request, "stream_id": stream_id} + with Trace("send_request_headers", logger, request, kwargs): + self._send_request_headers(request=request, stream_id=stream_id) + with Trace("send_request_body", logger, request, kwargs): + self._send_request_body(request=request, stream_id=stream_id) + with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + status, headers = self._receive_response( + request=request, stream_id=stream_id + ) + trace.return_value = (status, headers) + + return Response( + status=status, + headers=headers, + content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), + extensions={ + "http_version": b"HTTP/2", + "network_stream": self._network_stream, + "stream_id": stream_id, + }, + ) + except BaseException as exc: # noqa: PIE786 + with ShieldCancellation(): + kwargs = {"stream_id": stream_id} + with Trace("response_closed", logger, request, kwargs): + self._response_closed(stream_id=stream_id) + + if isinstance(exc, h2.exceptions.ProtocolError): + # One case where h2 can raise a protocol error is when a + # closed frame has been seen by the state machine. + # + # This happens when one stream is reading, and encounters + # a GOAWAY event. Other flows of control may then raise + # a protocol error at any point they interact with the 'h2_state'. + # + # In this case we'll have stored the event, and should raise + # it as a RemoteProtocolError. + if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) + # If h2 raises a protocol error in some other state then we + # must somehow have made a protocol violation. + raise LocalProtocolError(exc) # pragma: nocover + + raise exc + + def _send_connection_init(self, request: Request) -> None: + """ + The HTTP/2 connection requires some initial setup before we can start + using individual request/response streams on it. + """ + # Need to set these manually here instead of manipulating via + # __setitem__() otherwise the H2Connection will emit SettingsUpdate + # frames in addition to sending the undesired defaults. + self._h2_state.local_settings = h2.settings.Settings( + client=True, + initial_values={ + # Disable PUSH_PROMISE frames from the server since we don't do anything + # with them for now. Maybe when we support caching? + h2.settings.SettingCodes.ENABLE_PUSH: 0, + # These two are taken from h2 for safe defaults + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100, + h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536, + }, + ) + + # Some websites (*cough* Yahoo *cough*) balk at this setting being + # present in the initial handshake since it's not defined in the original + # RFC despite the RFC mandating ignoring settings you don't know about. + del self._h2_state.local_settings[ + h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL + ] + + self._h2_state.initiate_connection() + self._h2_state.increment_flow_control_window(2**24) + self._write_outgoing_data(request) + + # Sending the request... + + def _send_request_headers(self, request: Request, stream_id: int) -> None: + """ + Send the request headers to a given stream ID. + """ + end_stream = not has_body_headers(request) + + # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'. + # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require + # HTTP/1.1 style headers, and map them appropriately if we end up on + # an HTTP/2 connection. + authority = [v for k, v in request.headers if k.lower() == b"host"][0] + + headers = [ + (b":method", request.method), + (b":authority", authority), + (b":scheme", request.url.scheme), + (b":path", request.url.target), + ] + [ + (k.lower(), v) + for k, v in request.headers + if k.lower() + not in ( + b"host", + b"transfer-encoding", + ) + ] + + self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) + self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id) + self._write_outgoing_data(request) + + def _send_request_body(self, request: Request, stream_id: int) -> None: + """ + Iterate over the request body sending it to a given stream ID. + """ + if not has_body_headers(request): + return + + assert isinstance(request.stream, typing.Iterable) + for data in request.stream: + self._send_stream_data(request, stream_id, data) + self._send_end_stream(request, stream_id) + + def _send_stream_data( + self, request: Request, stream_id: int, data: bytes + ) -> None: + """ + Send a single chunk of data in one or more data frames. + """ + while data: + max_flow = self._wait_for_outgoing_flow(request, stream_id) + chunk_size = min(len(data), max_flow) + chunk, data = data[:chunk_size], data[chunk_size:] + self._h2_state.send_data(stream_id, chunk) + self._write_outgoing_data(request) + + def _send_end_stream(self, request: Request, stream_id: int) -> None: + """ + Send an empty data frame on on a given stream ID with the END_STREAM flag set. + """ + self._h2_state.end_stream(stream_id) + self._write_outgoing_data(request) + + # Receiving the response... + + def _receive_response( + self, request: Request, stream_id: int + ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]: + """ + Return the response status code and headers for a given stream ID. + """ + while True: + event = self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.ResponseReceived): + break + + status_code = 200 + headers = [] + for k, v in event.headers: + if k == b":status": + status_code = int(v.decode("ascii", errors="ignore")) + elif not k.startswith(b":"): + headers.append((k, v)) + + return (status_code, headers) + + def _receive_response_body( + self, request: Request, stream_id: int + ) -> typing.Iterator[bytes]: + """ + Iterator that returns the bytes of the response body for a given stream ID. + """ + while True: + event = self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.DataReceived): + amount = event.flow_controlled_length + self._h2_state.acknowledge_received_data(amount, stream_id) + self._write_outgoing_data(request) + yield event.data + elif isinstance(event, h2.events.StreamEnded): + break + + def _receive_stream_event( + self, request: Request, stream_id: int + ) -> typing.Union[ + h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded + ]: + """ + Return the next available event for a given stream ID. + + Will read more data from the network if required. + """ + while not self._events.get(stream_id): + self._receive_events(request, stream_id) + event = self._events[stream_id].pop(0) + if isinstance(event, h2.events.StreamReset): + raise RemoteProtocolError(event) + return event + + def _receive_events( + self, request: Request, stream_id: typing.Optional[int] = None + ) -> None: + """ + Read some data from the network until we see one or more events + for a given stream ID. + """ + with self._read_lock: + if self._connection_terminated is not None: + last_stream_id = self._connection_terminated.last_stream_id + if stream_id and last_stream_id and stream_id > last_stream_id: + self._request_count -= 1 + raise ConnectionNotAvailable() + raise RemoteProtocolError(self._connection_terminated) + + # This conditional is a bit icky. We don't want to block reading if we've + # actually got an event to return for a given stream. We need to do that + # check *within* the atomic read lock. Though it also need to be optional, + # because when we call it from `_wait_for_outgoing_flow` we *do* want to + # block until we've available flow control, event when we have events + # pending for the stream ID we're attempting to send on. + if stream_id is None or not self._events.get(stream_id): + events = self._read_incoming_data(request) + for event in events: + if isinstance(event, h2.events.RemoteSettingsChanged): + with Trace( + "receive_remote_settings", logger, request + ) as trace: + self._receive_remote_settings_change(event) + trace.return_value = event + + elif isinstance( + event, + ( + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ), + ): + if event.stream_id in self._events: + self._events[event.stream_id].append(event) + + elif isinstance(event, h2.events.ConnectionTerminated): + self._connection_terminated = event + + self._write_outgoing_data(request) + + def _receive_remote_settings_change(self, event: h2.events.Event) -> None: + max_concurrent_streams = event.changed_settings.get( + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS + ) + if max_concurrent_streams: + new_max_streams = min( + max_concurrent_streams.new_value, + self._h2_state.local_settings.max_concurrent_streams, + ) + if new_max_streams and new_max_streams != self._max_streams: + while new_max_streams > self._max_streams: + self._max_streams_semaphore.release() + self._max_streams += 1 + while new_max_streams < self._max_streams: + self._max_streams_semaphore.acquire() + self._max_streams -= 1 + + def _response_closed(self, stream_id: int) -> None: + self._max_streams_semaphore.release() + del self._events[stream_id] + with self._state_lock: + if self._connection_terminated and not self._events: + self.close() + + elif self._state == HTTPConnectionState.ACTIVE and not self._events: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + if self._used_all_stream_ids: # pragma: nocover + self.close() + + def close(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._h2_state.close_connection() + self._state = HTTPConnectionState.CLOSED + self._network_stream.close() + + # Wrappers around network read/write operations... + + def _read_incoming_data( + self, request: Request + ) -> typing.List[h2.events.Event]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + if self._read_exception is not None: + raise self._read_exception # pragma: nocover + + try: + data = self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + except Exception as exc: + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future reads. + # (For example, this means that a single read timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._read_exception = exc + self._connection_error = True + raise exc + + events: typing.List[h2.events.Event] = self._h2_state.receive_data(data) + + return events + + def _write_outgoing_data(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + with self._write_lock: + data_to_send = self._h2_state.data_to_send() + + if self._write_exception is not None: + raise self._write_exception # pragma: nocover + + try: + self._network_stream.write(data_to_send, timeout) + except Exception as exc: # pragma: nocover + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future write. + # (For example, this means that a single write timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._write_exception = exc + self._connection_error = True + raise exc + + # Flow control... + + def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int: + """ + Returns the maximum allowable outgoing flow for a given stream. + + If the allowable flow is zero, then waits on the network until + WindowUpdated frames have increased the flow rate. + https://tools.ietf.org/html/rfc7540#section-6.9 + """ + local_flow: int = self._h2_state.local_flow_control_window(stream_id) + max_frame_size: int = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + while flow == 0: + self._receive_events(request) + local_flow = self._h2_state.local_flow_control_window(stream_id) + max_frame_size = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + return flow + + # Interface for connection pooling... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + return ( + self._state != HTTPConnectionState.CLOSED + and not self._connection_error + and not self._used_all_stream_ids + and not ( + self._h2_state.state_machine.state + == h2.connection.ConnectionState.CLOSED + ) + ) + + def has_expired(self) -> bool: + now = time.monotonic() + return self._expire_at is not None and now > self._expire_at + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/2, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + def __enter__(self) -> "HTTP2Connection": + return self + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[types.TracebackType] = None, + ) -> None: + self.close() + + +class HTTP2ConnectionByteStream: + def __init__( + self, connection: HTTP2Connection, request: Request, stream_id: int + ) -> None: + self._connection = connection + self._request = request + self._stream_id = stream_id + self._closed = False + + def __iter__(self) -> typing.Iterator[bytes]: + kwargs = {"request": self._request, "stream_id": self._stream_id} + try: + with Trace("receive_response_body", logger, self._request, kwargs): + for chunk in self._connection._receive_response_body( + request=self._request, stream_id=self._stream_id + ): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. + with ShieldCancellation(): + self.close() + raise exc + + def close(self) -> None: + if not self._closed: + self._closed = True + kwargs = {"stream_id": self._stream_id} + with Trace("response_closed", logger, self._request, kwargs): + self._connection._response_closed(stream_id=self._stream_id) diff --git a/contrib/python/httpcore/httpcore/_sync/http_proxy.py b/contrib/python/httpcore/httpcore/_sync/http_proxy.py new file mode 100644 index 0000000000..6acac9a7cd --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/http_proxy.py @@ -0,0 +1,368 @@ +import logging +import ssl +from base64 import b64encode +from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union + +from .._backends.base import SOCKET_OPTION, NetworkBackend +from .._exceptions import ProxyError +from .._models import ( + URL, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, +) +from .._ssl import default_ssl_context +from .._synchronization import Lock +from .._trace import Trace +from .connection import HTTPConnection +from .connection_pool import ConnectionPool +from .http11 import HTTP11Connection +from .interfaces import ConnectionInterface + +HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]] +HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]] + + +logger = logging.getLogger("httpcore.proxy") + + +def merge_headers( + default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, + override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, +) -> List[Tuple[bytes, bytes]]: + """ + Append default_headers and override_headers, de-duplicating if a key exists + in both cases. + """ + default_headers = [] if default_headers is None else list(default_headers) + override_headers = [] if override_headers is None else list(override_headers) + has_override = set(key.lower() for key, value in override_headers) + default_headers = [ + (key, value) + for key, value in default_headers + if key.lower() not in has_override + ] + return default_headers + override_headers + + +def build_auth_header(username: bytes, password: bytes) -> bytes: + userpass = username + b":" + password + return b"Basic " + b64encode(userpass) + + +class HTTPProxy(ConnectionPool): + """ + A connection pool that sends requests via an HTTP proxy. + """ + + def __init__( + self, + proxy_url: Union[URL, bytes, str], + proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None, + proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None, + ssl_context: Optional[ssl.SSLContext] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + max_connections: Optional[int] = 10, + max_keepalive_connections: Optional[int] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[NetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + proxy_auth: Any proxy authentication as a two-tuple of + (username, password). May be either bytes or ascii-only str. + proxy_headers: Any HTTP headers to use for the proxy requests. + For example `{"Proxy-Authorization": "Basic <username>:<password>"}`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. + """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + local_address=local_address, + uds=uds, + socket_options=socket_options, + ) + + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if ( + self._proxy_url.scheme == b"http" and proxy_ssl_context is not None + ): # pragma: no cover + raise RuntimeError( + "The `proxy_ssl_context` argument is not allowed for the http scheme" + ) + + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + if proxy_auth is not None: + username = enforce_bytes(proxy_auth[0], name="proxy_auth") + password = enforce_bytes(proxy_auth[1], name="proxy_auth") + authorization = build_auth_header(username, password) + self._proxy_headers = [ + (b"Proxy-Authorization", authorization) + ] + self._proxy_headers + + def create_connection(self, origin: Origin) -> ConnectionInterface: + if origin.scheme == b"http": + return ForwardHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + proxy_ssl_context=self._proxy_ssl_context, + ) + return TunnelHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + ssl_context=self._ssl_context, + proxy_ssl_context=self._proxy_ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class ForwardHTTPConnection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None, + keepalive_expiry: Optional[float] = None, + network_backend: Optional[NetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + ) -> None: + self._connection = HTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._remote_origin = remote_origin + + def handle_request(self, request: Request) -> Response: + headers = merge_headers(self._proxy_headers, request.headers) + url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=bytes(request.url), + ) + proxy_request = Request( + method=request.method, + url=url, + headers=headers, + content=request.stream, + extensions=request.extensions, + ) + return self._connection.handle_request(proxy_request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + def close(self) -> None: + self._connection.close() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + +class TunnelHTTPConnection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + ssl_context: Optional[ssl.SSLContext] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + network_backend: Optional[NetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + self._connection: ConnectionInterface = HTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._connect_lock = Lock() + self._connected = False + + def handle_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) + + with self._connect_lock: + if not self._connected: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, + ) + connect_headers = merge_headers( + [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers + ) + connect_request = Request( + method=b"CONNECT", + url=connect_url, + headers=connect_headers, + extensions=request.extensions, + ) + connect_response = self._connection.handle_request( + connect_request + ) + + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + self._connection.close() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + + # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + return self._connection.handle_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + def close(self) -> None: + self._connection.close() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/contrib/python/httpcore/httpcore/_sync/interfaces.py b/contrib/python/httpcore/httpcore/_sync/interfaces.py new file mode 100644 index 0000000000..5e95be1ec7 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/interfaces.py @@ -0,0 +1,135 @@ +from contextlib import contextmanager +from typing import Iterator, Optional, Union + +from .._models import ( + URL, + Extensions, + HeaderTypes, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, + include_request_headers, +) + + +class RequestInterface: + def request( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterator[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> Response: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = self.handle_request(request) + try: + response.read() + finally: + response.close() + return response + + @contextmanager + def stream( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterator[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> Iterator[Response]: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = self.handle_request(request) + try: + yield response + finally: + response.close() + + def handle_request(self, request: Request) -> Response: + raise NotImplementedError() # pragma: nocover + + +class ConnectionInterface(RequestInterface): + def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + def info(self) -> str: + raise NotImplementedError() # pragma: nocover + + def can_handle_request(self, origin: Origin) -> bool: + raise NotImplementedError() # pragma: nocover + + def is_available(self) -> bool: + """ + Return `True` if the connection is currently able to accept an + outgoing request. + + An HTTP/1.1 connection will only be available if it is currently idle. + + An HTTP/2 connection will be available so long as the stream ID space is + not yet exhausted, and the connection is not in an error state. + + While the connection is being established we may not yet know if it is going + to result in an HTTP/1.1 or HTTP/2 connection. The connection should be + treated as being available, but might ultimately raise `NewConnectionRequired` + required exceptions if multiple requests are attempted over a connection + that ends up being established as HTTP/1.1. + """ + raise NotImplementedError() # pragma: nocover + + def has_expired(self) -> bool: + """ + Return `True` if the connection is in a state where it should be closed. + + This either means that the connection is idle and it has passed the + expiry time on its keep-alive, or that server has sent an EOF. + """ + raise NotImplementedError() # pragma: nocover + + def is_idle(self) -> bool: + """ + Return `True` if the connection is currently idle. + """ + raise NotImplementedError() # pragma: nocover + + def is_closed(self) -> bool: + """ + Return `True` if the connection has been closed. + + Used when a response is closed to determine if the connection may be + returned to the connection pool or not. + """ + raise NotImplementedError() # pragma: nocover diff --git a/contrib/python/httpcore/httpcore/_sync/socks_proxy.py b/contrib/python/httpcore/httpcore/_sync/socks_proxy.py new file mode 100644 index 0000000000..502e4d7fef --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/socks_proxy.py @@ -0,0 +1,342 @@ +import logging +import ssl +import typing + +from socksio import socks5 + +from .._backends.sync import SyncBackend +from .._backends.base import NetworkBackend, NetworkStream +from .._exceptions import ConnectionNotAvailable, ProxyError +from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url +from .._ssl import default_ssl_context +from .._synchronization import Lock +from .._trace import Trace +from .connection_pool import ConnectionPool +from .http11 import HTTP11Connection +from .interfaces import ConnectionInterface + +logger = logging.getLogger("httpcore.socks") + + +AUTH_METHODS = { + b"\x00": "NO AUTHENTICATION REQUIRED", + b"\x01": "GSSAPI", + b"\x02": "USERNAME/PASSWORD", + b"\xff": "NO ACCEPTABLE METHODS", +} + +REPLY_CODES = { + b"\x00": "Succeeded", + b"\x01": "General SOCKS server failure", + b"\x02": "Connection not allowed by ruleset", + b"\x03": "Network unreachable", + b"\x04": "Host unreachable", + b"\x05": "Connection refused", + b"\x06": "TTL expired", + b"\x07": "Command not supported", + b"\x08": "Address type not supported", +} + + +def _init_socks5_connection( + stream: NetworkStream, + *, + host: bytes, + port: int, + auth: typing.Optional[typing.Tuple[bytes, bytes]] = None, +) -> None: + conn = socks5.SOCKS5Connection() + + # Auth method request + auth_method = ( + socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED + if auth is None + else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD + ) + conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method])) + outgoing_bytes = conn.data_to_send() + stream.write(outgoing_bytes) + + # Auth method response + incoming_bytes = stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socks5.SOCKS5AuthReply) + if response.method != auth_method: + requested = AUTH_METHODS.get(auth_method, "UNKNOWN") + responded = AUTH_METHODS.get(response.method, "UNKNOWN") + raise ProxyError( + f"Requested {requested} from proxy server, but got {responded}." + ) + + if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD: + # Username/password request + assert auth is not None + username, password = auth + conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password)) + outgoing_bytes = conn.data_to_send() + stream.write(outgoing_bytes) + + # Username/password response + incoming_bytes = stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socks5.SOCKS5UsernamePasswordReply) + if not response.success: + raise ProxyError("Invalid username/password") + + # Connect request + conn.send( + socks5.SOCKS5CommandRequest.from_address( + socks5.SOCKS5Command.CONNECT, (host, port) + ) + ) + outgoing_bytes = conn.data_to_send() + stream.write(outgoing_bytes) + + # Connect response + incoming_bytes = stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socks5.SOCKS5Reply) + if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED: + reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN") + raise ProxyError(f"Proxy Server could not connect: {reply_code}.") + + +class SOCKSProxy(ConnectionPool): + """ + A connection pool that sends requests via an HTTP proxy. + """ + + def __init__( + self, + proxy_url: typing.Union[URL, bytes, str], + proxy_auth: typing.Optional[ + typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]] + ] = None, + ssl_context: typing.Optional[ssl.SSLContext] = None, + max_connections: typing.Optional[int] = 10, + max_keepalive_connections: typing.Optional[int] = None, + keepalive_expiry: typing.Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + network_backend: typing.Optional[NetworkBackend] = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. + """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + ) + self._ssl_context = ssl_context + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if proxy_auth is not None: + username, password = proxy_auth + username_bytes = enforce_bytes(username, name="proxy_auth") + password_bytes = enforce_bytes(password, name="proxy_auth") + self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = ( + username_bytes, + password_bytes, + ) + else: + self._proxy_auth = None + + def create_connection(self, origin: Origin) -> ConnectionInterface: + return Socks5Connection( + proxy_origin=self._proxy_url.origin, + remote_origin=origin, + proxy_auth=self._proxy_auth, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class Socks5Connection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None, + ssl_context: typing.Optional[ssl.SSLContext] = None, + keepalive_expiry: typing.Optional[float] = None, + http1: bool = True, + http2: bool = False, + network_backend: typing.Optional[NetworkBackend] = None, + ) -> None: + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._proxy_auth = proxy_auth + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + + self._network_backend: NetworkBackend = ( + SyncBackend() if network_backend is None else network_backend + ) + self._connect_lock = Lock() + self._connection: typing.Optional[ConnectionInterface] = None + self._connect_failed = False + + def handle_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + with self._connect_lock: + if self._connection is None: + try: + # Connect to the proxy + kwargs = { + "host": self._proxy_origin.host.decode("ascii"), + "port": self._proxy_origin.port, + "timeout": timeout, + } + with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + + # Connect to the remote host using socks5 + kwargs = { + "stream": stream, + "host": self._remote_origin.host.decode("ascii"), + "port": self._remote_origin.port, + "auth": self._proxy_auth, + } + with Trace( + "setup_socks5_connection", logger, request, kwargs + ) as trace: + _init_socks5_connection(**kwargs) + trace.return_value = stream + + # Upgrade the stream to SSL + if self._remote_origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ( + ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ) + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or ( + self._http2 and not self._http1 + ): # pragma: nocover + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except Exception as exc: + self._connect_failed = True + raise exc + elif not self._connection.is_available(): # pragma: nocover + raise ConnectionNotAvailable() + + return self._connection.handle_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + def close(self) -> None: + if self._connection is not None: + self._connection.close() + + def is_available(self) -> bool: + if self._connection is None: # pragma: nocover + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._remote_origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: # pragma: nocover + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/contrib/python/httpcore/httpcore/_synchronization.py b/contrib/python/httpcore/httpcore/_synchronization.py new file mode 100644 index 0000000000..bae27c1b11 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_synchronization.py @@ -0,0 +1,279 @@ +import threading +from types import TracebackType +from typing import Optional, Type + +import sniffio + +from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions + +# Our async synchronization primatives use either 'anyio' or 'trio' depending +# on if they're running under asyncio or trio. + +try: + import trio +except ImportError: # pragma: nocover + trio = None # type: ignore + +try: + import anyio +except ImportError: # pragma: nocover + anyio = None # type: ignore + + +class AsyncLock: + def __init__(self) -> None: + self._backend = "" + + def setup(self) -> None: + """ + Detect if we're running under 'asyncio' or 'trio' and create + a lock with the correct implementation. + """ + self._backend = sniffio.current_async_library() + if self._backend == "trio": + if trio is None: # pragma: nocover + raise RuntimeError( + "Running under trio, requires the 'trio' package to be installed." + ) + self._trio_lock = trio.Lock() + else: + if anyio is None: # pragma: nocover + raise RuntimeError( + "Running under asyncio requires the 'anyio' package to be installed." + ) + self._anyio_lock = anyio.Lock() + + async def __aenter__(self) -> "AsyncLock": + if not self._backend: + self.setup() + + if self._backend == "trio": + await self._trio_lock.acquire() + else: + await self._anyio_lock.acquire() + + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + if self._backend == "trio": + self._trio_lock.release() + else: + self._anyio_lock.release() + + +class AsyncEvent: + def __init__(self) -> None: + self._backend = "" + + def setup(self) -> None: + """ + Detect if we're running under 'asyncio' or 'trio' and create + a lock with the correct implementation. + """ + self._backend = sniffio.current_async_library() + if self._backend == "trio": + if trio is None: # pragma: nocover + raise RuntimeError( + "Running under trio requires the 'trio' package to be installed." + ) + self._trio_event = trio.Event() + else: + if anyio is None: # pragma: nocover + raise RuntimeError( + "Running under asyncio requires the 'anyio' package to be installed." + ) + self._anyio_event = anyio.Event() + + def set(self) -> None: + if not self._backend: + self.setup() + + if self._backend == "trio": + self._trio_event.set() + else: + self._anyio_event.set() + + async def wait(self, timeout: Optional[float] = None) -> None: + if not self._backend: + self.setup() + + if self._backend == "trio": + if trio is None: # pragma: nocover + raise RuntimeError( + "Running under trio requires the 'trio' package to be installed." + ) + + trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout} + timeout_or_inf = float("inf") if timeout is None else timeout + with map_exceptions(trio_exc_map): + with trio.fail_after(timeout_or_inf): + await self._trio_event.wait() + else: + if anyio is None: # pragma: nocover + raise RuntimeError( + "Running under asyncio requires the 'anyio' package to be installed." + ) + + anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout} + with map_exceptions(anyio_exc_map): + with anyio.fail_after(timeout): + await self._anyio_event.wait() + + +class AsyncSemaphore: + def __init__(self, bound: int) -> None: + self._bound = bound + self._backend = "" + + def setup(self) -> None: + """ + Detect if we're running under 'asyncio' or 'trio' and create + a semaphore with the correct implementation. + """ + self._backend = sniffio.current_async_library() + if self._backend == "trio": + if trio is None: # pragma: nocover + raise RuntimeError( + "Running under trio requires the 'trio' package to be installed." + ) + + self._trio_semaphore = trio.Semaphore( + initial_value=self._bound, max_value=self._bound + ) + else: + if anyio is None: # pragma: nocover + raise RuntimeError( + "Running under asyncio requires the 'anyio' package to be installed." + ) + + self._anyio_semaphore = anyio.Semaphore( + initial_value=self._bound, max_value=self._bound + ) + + async def acquire(self) -> None: + if not self._backend: + self.setup() + + if self._backend == "trio": + await self._trio_semaphore.acquire() + else: + await self._anyio_semaphore.acquire() + + async def release(self) -> None: + if self._backend == "trio": + self._trio_semaphore.release() + else: + self._anyio_semaphore.release() + + +class AsyncShieldCancellation: + # For certain portions of our codebase where we're dealing with + # closing connections during exception handling we want to shield + # the operation from being cancelled. + # + # with AsyncShieldCancellation(): + # ... # clean-up operations, shielded from cancellation. + + def __init__(self) -> None: + """ + Detect if we're running under 'asyncio' or 'trio' and create + a shielded scope with the correct implementation. + """ + self._backend = sniffio.current_async_library() + + if self._backend == "trio": + if trio is None: # pragma: nocover + raise RuntimeError( + "Running under trio requires the 'trio' package to be installed." + ) + + self._trio_shield = trio.CancelScope(shield=True) + else: + if anyio is None: # pragma: nocover + raise RuntimeError( + "Running under asyncio requires the 'anyio' package to be installed." + ) + + self._anyio_shield = anyio.CancelScope(shield=True) + + def __enter__(self) -> "AsyncShieldCancellation": + if self._backend == "trio": + self._trio_shield.__enter__() + else: + self._anyio_shield.__enter__() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + if self._backend == "trio": + self._trio_shield.__exit__(exc_type, exc_value, traceback) + else: + self._anyio_shield.__exit__(exc_type, exc_value, traceback) + + +# Our thread-based synchronization primitives... + + +class Lock: + def __init__(self) -> None: + self._lock = threading.Lock() + + def __enter__(self) -> "Lock": + self._lock.acquire() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + self._lock.release() + + +class Event: + def __init__(self) -> None: + self._event = threading.Event() + + def set(self) -> None: + self._event.set() + + def wait(self, timeout: Optional[float] = None) -> None: + if not self._event.wait(timeout=timeout): + raise PoolTimeout() # pragma: nocover + + +class Semaphore: + def __init__(self, bound: int) -> None: + self._semaphore = threading.Semaphore(value=bound) + + def acquire(self) -> None: + self._semaphore.acquire() + + def release(self) -> None: + self._semaphore.release() + + +class ShieldCancellation: + # Thread-synchronous codebases don't support cancellation semantics. + # We have this class because we need to mirror the async and sync + # cases within our package, but it's just a no-op. + def __enter__(self) -> "ShieldCancellation": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + pass diff --git a/contrib/python/httpcore/httpcore/_trace.py b/contrib/python/httpcore/httpcore/_trace.py new file mode 100644 index 0000000000..b122a53e88 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_trace.py @@ -0,0 +1,105 @@ +import inspect +import logging +from types import TracebackType +from typing import Any, Dict, Optional, Type + +from ._models import Request + + +class Trace: + def __init__( + self, + name: str, + logger: logging.Logger, + request: Optional[Request] = None, + kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + self.name = name + self.logger = logger + self.trace_extension = ( + None if request is None else request.extensions.get("trace") + ) + self.debug = self.logger.isEnabledFor(logging.DEBUG) + self.kwargs = kwargs or {} + self.return_value: Any = None + self.should_trace = self.debug or self.trace_extension is not None + self.prefix = self.logger.name.split(".")[-1] + + def trace(self, name: str, info: Dict[str, Any]) -> None: + if self.trace_extension is not None: + prefix_and_name = f"{self.prefix}.{name}" + ret = self.trace_extension(prefix_and_name, info) + if inspect.iscoroutine(ret): # pragma: no cover + raise TypeError( + "If you are using a synchronous interface, " + "the callback of the `trace` extension should " + "be a normal function instead of an asynchronous function." + ) + + if self.debug: + if not info or "return_value" in info and info["return_value"] is None: + message = name + else: + args = " ".join([f"{key}={value!r}" for key, value in info.items()]) + message = f"{name} {args}" + self.logger.debug(message) + + def __enter__(self) -> "Trace": + if self.should_trace: + info = self.kwargs + self.trace(f"{self.name}.started", info) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + if self.should_trace: + if exc_value is None: + info = {"return_value": self.return_value} + self.trace(f"{self.name}.complete", info) + else: + info = {"exception": exc_value} + self.trace(f"{self.name}.failed", info) + + async def atrace(self, name: str, info: Dict[str, Any]) -> None: + if self.trace_extension is not None: + prefix_and_name = f"{self.prefix}.{name}" + coro = self.trace_extension(prefix_and_name, info) + if not inspect.iscoroutine(coro): # pragma: no cover + raise TypeError( + "If you're using an asynchronous interface, " + "the callback of the `trace` extension should " + "be an asynchronous function rather than a normal function." + ) + await coro + + if self.debug: + if not info or "return_value" in info and info["return_value"] is None: + message = name + else: + args = " ".join([f"{key}={value!r}" for key, value in info.items()]) + message = f"{name} {args}" + self.logger.debug(message) + + async def __aenter__(self) -> "Trace": + if self.should_trace: + info = self.kwargs + await self.atrace(f"{self.name}.started", info) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + if self.should_trace: + if exc_value is None: + info = {"return_value": self.return_value} + await self.atrace(f"{self.name}.complete", info) + else: + info = {"exception": exc_value} + await self.atrace(f"{self.name}.failed", info) diff --git a/contrib/python/httpcore/httpcore/_utils.py b/contrib/python/httpcore/httpcore/_utils.py new file mode 100644 index 0000000000..df5dea8fe4 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_utils.py @@ -0,0 +1,36 @@ +import select +import socket +import sys +import typing + + +def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool: + """ + Return whether a socket, as identifed by its file descriptor, is readable. + "A socket is readable" means that the read buffer isn't empty, i.e. that calling + .recv() on it would immediately return some data. + """ + # NOTE: we want check for readability without actually attempting to read, because + # we don't want to block forever if it's not readable. + + # In the case that the socket no longer exists, or cannot return a file + # descriptor, we treat it as being readable, as if it the next read operation + # on it is ready to return the terminating `b""`. + sock_fd = None if sock is None else sock.fileno() + if sock_fd is None or sock_fd < 0: # pragma: nocover + return True + + # The implementation below was stolen from: + # https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471-L478 + # See also: https://github.com/encode/httpcore/pull/193#issuecomment-703129316 + + # Use select.select on Windows, and when poll is unavailable and select.poll + # everywhere else. (E.g. When eventlet is in use. See #327) + if ( + sys.platform == "win32" or getattr(select, "poll", None) is None + ): # pragma: nocover + rready, _, _ = select.select([sock_fd], [], [], 0) + return bool(rready) + p = select.poll() + p.register(sock_fd, select.POLLIN) + return bool(p.poll(0)) diff --git a/contrib/python/httpcore/httpcore/py.typed b/contrib/python/httpcore/httpcore/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/httpcore/httpcore/py.typed diff --git a/contrib/python/httpcore/ya.make b/contrib/python/httpcore/ya.make new file mode 100644 index 0000000000..e8516afe10 --- /dev/null +++ b/contrib/python/httpcore/ya.make @@ -0,0 +1,68 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +VERSION(0.18.0) + +LICENSE(BSD-3-Clause) + +PEERDIR( + contrib/python/anyio + contrib/python/certifi + contrib/python/h11 + contrib/python/sniffio +) + +NO_LINT() + +NO_CHECK_IMPORTS( + httpcore._async.http2 + httpcore._async.socks_proxy + httpcore._backends.trio + httpcore._sync.http2 + httpcore._sync.socks_proxy +) + +PY_SRCS( + TOP_LEVEL + httpcore/__init__.py + httpcore/_api.py + httpcore/_async/__init__.py + httpcore/_async/connection.py + httpcore/_async/connection_pool.py + httpcore/_async/http11.py + httpcore/_async/http2.py + httpcore/_async/http_proxy.py + httpcore/_async/interfaces.py + httpcore/_async/socks_proxy.py + httpcore/_backends/__init__.py + httpcore/_backends/anyio.py + httpcore/_backends/auto.py + httpcore/_backends/base.py + httpcore/_backends/mock.py + httpcore/_backends/sync.py + httpcore/_backends/trio.py + httpcore/_exceptions.py + httpcore/_models.py + httpcore/_ssl.py + httpcore/_sync/__init__.py + httpcore/_sync/connection.py + httpcore/_sync/connection_pool.py + httpcore/_sync/http11.py + httpcore/_sync/http2.py + httpcore/_sync/http_proxy.py + httpcore/_sync/interfaces.py + httpcore/_sync/socks_proxy.py + httpcore/_synchronization.py + httpcore/_trace.py + httpcore/_utils.py +) + +RESOURCE_FILES( + PREFIX contrib/python/httpcore/ + .dist-info/METADATA + .dist-info/top_level.txt + httpcore/py.typed +) + +END() diff --git a/contrib/python/httpx/.dist-info/METADATA b/contrib/python/httpx/.dist-info/METADATA new file mode 100644 index 0000000000..f3a6d509cd --- /dev/null +++ b/contrib/python/httpx/.dist-info/METADATA @@ -0,0 +1,216 @@ +Metadata-Version: 2.1 +Name: httpx +Version: 0.25.0 +Summary: The next generation HTTP client. +Project-URL: Changelog, https://github.com/encode/httpx/blob/master/CHANGELOG.md +Project-URL: Documentation, https://www.python-httpx.org +Project-URL: Homepage, https://github.com/encode/httpx +Project-URL: Source, https://github.com/encode/httpx +Author-email: Tom Christie <tom@tomchristie.com> +License-Expression: BSD-3-Clause +License-File: LICENSE.md +Classifier: Development Status :: 4 - Beta +Classifier: Environment :: Web Environment +Classifier: Framework :: AsyncIO +Classifier: Framework :: Trio +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Topic :: Internet :: WWW/HTTP +Requires-Python: >=3.8 +Requires-Dist: certifi +Requires-Dist: httpcore<0.19.0,>=0.18.0 +Requires-Dist: idna +Requires-Dist: sniffio +Provides-Extra: brotli +Requires-Dist: brotli; platform_python_implementation == 'CPython' and extra == 'brotli' +Requires-Dist: brotlicffi; platform_python_implementation != 'CPython' and extra == 'brotli' +Provides-Extra: cli +Requires-Dist: click==8.*; extra == 'cli' +Requires-Dist: pygments==2.*; extra == 'cli' +Requires-Dist: rich<14,>=10; extra == 'cli' +Provides-Extra: http2 +Requires-Dist: h2<5,>=3; extra == 'http2' +Provides-Extra: socks +Requires-Dist: socksio==1.*; extra == 'socks' +Description-Content-Type: text/markdown + +<p align="center"> + <a href="https://www.python-httpx.org/"><img width="350" height="208" src="https://raw.githubusercontent.com/encode/httpx/master/docs/img/butterfly.png" alt='HTTPX'></a> +</p> + +<p align="center"><strong>HTTPX</strong> <em>- A next-generation HTTP client for Python.</em></p> + +<p align="center"> +<a href="https://github.com/encode/httpx/actions"> + <img src="https://github.com/encode/httpx/workflows/Test%20Suite/badge.svg" alt="Test Suite"> +</a> +<a href="https://pypi.org/project/httpx/"> + <img src="https://badge.fury.io/py/httpx.svg" alt="Package version"> +</a> +</p> + +HTTPX is a fully featured HTTP client library for Python 3. It includes **an integrated +command line client**, has support for both **HTTP/1.1 and HTTP/2**, and provides both **sync +and async APIs**. + +--- + +Install HTTPX using pip: + +```shell +$ pip install httpx +``` + +Now, let's get started: + +```pycon +>>> import httpx +>>> r = httpx.get('https://www.example.org/') +>>> r +<Response [200 OK]> +>>> r.status_code +200 +>>> r.headers['content-type'] +'text/html; charset=UTF-8' +>>> r.text +'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>...' +``` + +Or, using the command-line client. + +```shell +$ pip install 'httpx[cli]' # The command line client is an optional dependency. +``` + +Which now allows us to use HTTPX directly from the command-line... + +<p align="center"> + <img width="700" src="https://raw.githubusercontent.com/encode/httpx/master/docs/img/httpx-help.png" alt='httpx --help'> +</p> + +Sending a request... + +<p align="center"> + <img width="700" src="https://raw.githubusercontent.com/encode/httpx/master/docs/img/httpx-request.png" alt='httpx http://httpbin.org/json'> +</p> + +## Features + +HTTPX builds on the well-established usability of `requests`, and gives you: + +* A broadly [requests-compatible API](https://www.python-httpx.org/compatibility/). +* An integrated command-line client. +* HTTP/1.1 [and HTTP/2 support](https://www.python-httpx.org/http2/). +* Standard synchronous interface, but with [async support if you need it](https://www.python-httpx.org/async/). +* Ability to make requests directly to [WSGI applications](https://www.python-httpx.org/advanced/#calling-into-python-web-apps) or [ASGI applications](https://www.python-httpx.org/async/#calling-into-python-web-apps). +* Strict timeouts everywhere. +* Fully type annotated. +* 100% test coverage. + +Plus all the standard features of `requests`... + +* International Domains and URLs +* Keep-Alive & Connection Pooling +* Sessions with Cookie Persistence +* Browser-style SSL Verification +* Basic/Digest Authentication +* Elegant Key/Value Cookies +* Automatic Decompression +* Automatic Content Decoding +* Unicode Response Bodies +* Multipart File Uploads +* HTTP(S) Proxy Support +* Connection Timeouts +* Streaming Downloads +* .netrc Support +* Chunked Requests + +## Installation + +Install with pip: + +```shell +$ pip install httpx +``` + +Or, to include the optional HTTP/2 support, use: + +```shell +$ pip install httpx[http2] +``` + +HTTPX requires Python 3.8+. + +## Documentation + +Project documentation is available at [https://www.python-httpx.org/](https://www.python-httpx.org/). + +For a run-through of all the basics, head over to the [QuickStart](https://www.python-httpx.org/quickstart/). + +For more advanced topics, see the [Advanced Usage](https://www.python-httpx.org/advanced/) section, the [async support](https://www.python-httpx.org/async/) section, or the [HTTP/2](https://www.python-httpx.org/http2/) section. + +The [Developer Interface](https://www.python-httpx.org/api/) provides a comprehensive API reference. + +To find out about tools that integrate with HTTPX, see [Third Party Packages](https://www.python-httpx.org/third_party_packages/). + +## Contribute + +If you want to contribute with HTTPX check out the [Contributing Guide](https://www.python-httpx.org/contributing/) to learn how to start. + +## Dependencies + +The HTTPX project relies on these excellent libraries: + +* `httpcore` - The underlying transport implementation for `httpx`. + * `h11` - HTTP/1.1 support. +* `certifi` - SSL certificates. +* `idna` - Internationalized domain name support. +* `sniffio` - Async library autodetection. + +As well as these optional installs: + +* `h2` - HTTP/2 support. *(Optional, with `httpx[http2]`)* +* `socksio` - SOCKS proxy support. *(Optional, with `httpx[socks]`)* +* `rich` - Rich terminal support. *(Optional, with `httpx[cli]`)* +* `click` - Command line client support. *(Optional, with `httpx[cli]`)* +* `brotli` or `brotlicffi` - Decoding for "brotli" compressed responses. *(Optional, with `httpx[brotli]`)* + +A huge amount of credit is due to `requests` for the API layout that +much of this work follows, as well as to `urllib3` for plenty of design +inspiration around the lower-level networking details. + +--- + +<p align="center"><i>HTTPX is <a href="https://github.com/encode/httpx/blob/master/LICENSE.md">BSD licensed</a> code.<br/>Designed & crafted with care.</i><br/>— 🦋 —</p> + +## Release Information + +### Removed + +* Drop support for Python 3.7. (#2813) + +### Added + +* Support HTTPS proxies. (#2845) +* Change the type of `Extensions` from `Mapping[Str, Any]` to `MutableMapping[Str, Any]`. (#2803) +* Add `socket_options` argument to `httpx.HTTPTransport` and `httpx.AsyncHTTPTransport` classes. (#2716) +* The `Response.raise_for_status()` method now returns the response instance. For example: `data = httpx.get('...').raise_for_status().json()`. (#2776) + +### Fixed + +* Return `500` error response instead of exceptions when `raise_app_exceptions=False` is set on `ASGITransport`. (#2669) +* Ensure all `WSGITransport` environs have a `SERVER_PROTOCOL`. (#2708) +* Always encode forward slashes as `%2F` in query parameters (#2723) +* Use Mozilla documentation instead of `httpstatuses.com` for HTTP error reference (#2768) + + +--- + +[Full changelog](https://github.com/encode/httpx/blob/master/CHANGELOG.md) diff --git a/contrib/python/httpx/.dist-info/entry_points.txt b/contrib/python/httpx/.dist-info/entry_points.txt new file mode 100644 index 0000000000..8ae96007f7 --- /dev/null +++ b/contrib/python/httpx/.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +httpx = httpx:main diff --git a/contrib/python/httpx/.dist-info/top_level.txt b/contrib/python/httpx/.dist-info/top_level.txt new file mode 100644 index 0000000000..c180eb2f8e --- /dev/null +++ b/contrib/python/httpx/.dist-info/top_level.txt @@ -0,0 +1,2 @@ +httpx +httpx/_transports diff --git a/contrib/python/httpx/LICENSE.md b/contrib/python/httpx/LICENSE.md new file mode 100644 index 0000000000..ab79d16a3f --- /dev/null +++ b/contrib/python/httpx/LICENSE.md @@ -0,0 +1,12 @@ +Copyright © 2019, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/contrib/python/httpx/README.md b/contrib/python/httpx/README.md new file mode 100644 index 0000000000..62fb295d17 --- /dev/null +++ b/contrib/python/httpx/README.md @@ -0,0 +1,148 @@ +<p align="center"> + <a href="https://www.python-httpx.org/"><img width="350" height="208" src="https://raw.githubusercontent.com/encode/httpx/master/docs/img/butterfly.png" alt='HTTPX'></a> +</p> + +<p align="center"><strong>HTTPX</strong> <em>- A next-generation HTTP client for Python.</em></p> + +<p align="center"> +<a href="https://github.com/encode/httpx/actions"> + <img src="https://github.com/encode/httpx/workflows/Test%20Suite/badge.svg" alt="Test Suite"> +</a> +<a href="https://pypi.org/project/httpx/"> + <img src="https://badge.fury.io/py/httpx.svg" alt="Package version"> +</a> +</p> + +HTTPX is a fully featured HTTP client library for Python 3. It includes **an integrated +command line client**, has support for both **HTTP/1.1 and HTTP/2**, and provides both **sync +and async APIs**. + +--- + +Install HTTPX using pip: + +```shell +$ pip install httpx +``` + +Now, let's get started: + +```pycon +>>> import httpx +>>> r = httpx.get('https://www.example.org/') +>>> r +<Response [200 OK]> +>>> r.status_code +200 +>>> r.headers['content-type'] +'text/html; charset=UTF-8' +>>> r.text +'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>...' +``` + +Or, using the command-line client. + +```shell +$ pip install 'httpx[cli]' # The command line client is an optional dependency. +``` + +Which now allows us to use HTTPX directly from the command-line... + +<p align="center"> + <img width="700" src="docs/img/httpx-help.png" alt='httpx --help'> +</p> + +Sending a request... + +<p align="center"> + <img width="700" src="docs/img/httpx-request.png" alt='httpx http://httpbin.org/json'> +</p> + +## Features + +HTTPX builds on the well-established usability of `requests`, and gives you: + +* A broadly [requests-compatible API](https://www.python-httpx.org/compatibility/). +* An integrated command-line client. +* HTTP/1.1 [and HTTP/2 support](https://www.python-httpx.org/http2/). +* Standard synchronous interface, but with [async support if you need it](https://www.python-httpx.org/async/). +* Ability to make requests directly to [WSGI applications](https://www.python-httpx.org/advanced/#calling-into-python-web-apps) or [ASGI applications](https://www.python-httpx.org/async/#calling-into-python-web-apps). +* Strict timeouts everywhere. +* Fully type annotated. +* 100% test coverage. + +Plus all the standard features of `requests`... + +* International Domains and URLs +* Keep-Alive & Connection Pooling +* Sessions with Cookie Persistence +* Browser-style SSL Verification +* Basic/Digest Authentication +* Elegant Key/Value Cookies +* Automatic Decompression +* Automatic Content Decoding +* Unicode Response Bodies +* Multipart File Uploads +* HTTP(S) Proxy Support +* Connection Timeouts +* Streaming Downloads +* .netrc Support +* Chunked Requests + +## Installation + +Install with pip: + +```shell +$ pip install httpx +``` + +Or, to include the optional HTTP/2 support, use: + +```shell +$ pip install httpx[http2] +``` + +HTTPX requires Python 3.8+. + +## Documentation + +Project documentation is available at [https://www.python-httpx.org/](https://www.python-httpx.org/). + +For a run-through of all the basics, head over to the [QuickStart](https://www.python-httpx.org/quickstart/). + +For more advanced topics, see the [Advanced Usage](https://www.python-httpx.org/advanced/) section, the [async support](https://www.python-httpx.org/async/) section, or the [HTTP/2](https://www.python-httpx.org/http2/) section. + +The [Developer Interface](https://www.python-httpx.org/api/) provides a comprehensive API reference. + +To find out about tools that integrate with HTTPX, see [Third Party Packages](https://www.python-httpx.org/third_party_packages/). + +## Contribute + +If you want to contribute with HTTPX check out the [Contributing Guide](https://www.python-httpx.org/contributing/) to learn how to start. + +## Dependencies + +The HTTPX project relies on these excellent libraries: + +* `httpcore` - The underlying transport implementation for `httpx`. + * `h11` - HTTP/1.1 support. +* `certifi` - SSL certificates. +* `idna` - Internationalized domain name support. +* `sniffio` - Async library autodetection. + +As well as these optional installs: + +* `h2` - HTTP/2 support. *(Optional, with `httpx[http2]`)* +* `socksio` - SOCKS proxy support. *(Optional, with `httpx[socks]`)* +* `rich` - Rich terminal support. *(Optional, with `httpx[cli]`)* +* `click` - Command line client support. *(Optional, with `httpx[cli]`)* +* `brotli` or `brotlicffi` - Decoding for "brotli" compressed responses. *(Optional, with `httpx[brotli]`)* + +A huge amount of credit is due to `requests` for the API layout that +much of this work follows, as well as to `urllib3` for plenty of design +inspiration around the lower-level networking details. + +--- + +<p align="center"><i>HTTPX is <a href="https://github.com/encode/httpx/blob/master/LICENSE.md">BSD licensed</a> code.<br/>Designed & crafted with care.</i><br/>— 🦋 —</p> diff --git a/contrib/python/httpx/httpx/__init__.py b/contrib/python/httpx/httpx/__init__.py new file mode 100644 index 0000000000..f61112f8b2 --- /dev/null +++ b/contrib/python/httpx/httpx/__init__.py @@ -0,0 +1,138 @@ +from .__version__ import __description__, __title__, __version__ +from ._api import delete, get, head, options, patch, post, put, request, stream +from ._auth import Auth, BasicAuth, DigestAuth, NetRCAuth +from ._client import USE_CLIENT_DEFAULT, AsyncClient, Client +from ._config import Limits, Proxy, Timeout, create_ssl_context +from ._content import ByteStream +from ._exceptions import ( + CloseError, + ConnectError, + ConnectTimeout, + CookieConflict, + DecodingError, + HTTPError, + HTTPStatusError, + InvalidURL, + LocalProtocolError, + NetworkError, + PoolTimeout, + ProtocolError, + ProxyError, + ReadError, + ReadTimeout, + RemoteProtocolError, + RequestError, + RequestNotRead, + ResponseNotRead, + StreamClosed, + StreamConsumed, + StreamError, + TimeoutException, + TooManyRedirects, + TransportError, + UnsupportedProtocol, + WriteError, + WriteTimeout, +) +from ._models import Cookies, Headers, Request, Response +from ._status_codes import codes +from ._transports.asgi import ASGITransport +from ._transports.base import AsyncBaseTransport, BaseTransport +from ._transports.default import AsyncHTTPTransport, HTTPTransport +from ._transports.mock import MockTransport +from ._transports.wsgi import WSGITransport +from ._types import AsyncByteStream, SyncByteStream +from ._urls import URL, QueryParams + +try: + from ._main import main +except ImportError: # pragma: no cover + + def main() -> None: # type: ignore + import sys + + print( + "The httpx command line client could not run because the required " + "dependencies were not installed.\nMake sure you've installed " + "everything with: pip install 'httpx[cli]'" + ) + sys.exit(1) + + +__all__ = [ + "__description__", + "__title__", + "__version__", + "ASGITransport", + "AsyncBaseTransport", + "AsyncByteStream", + "AsyncClient", + "AsyncHTTPTransport", + "Auth", + "BaseTransport", + "BasicAuth", + "ByteStream", + "Client", + "CloseError", + "codes", + "ConnectError", + "ConnectTimeout", + "CookieConflict", + "Cookies", + "create_ssl_context", + "DecodingError", + "delete", + "DigestAuth", + "get", + "head", + "Headers", + "HTTPError", + "HTTPStatusError", + "HTTPTransport", + "InvalidURL", + "Limits", + "LocalProtocolError", + "main", + "MockTransport", + "NetRCAuth", + "NetworkError", + "options", + "patch", + "PoolTimeout", + "post", + "ProtocolError", + "Proxy", + "ProxyError", + "put", + "QueryParams", + "ReadError", + "ReadTimeout", + "RemoteProtocolError", + "request", + "Request", + "RequestError", + "RequestNotRead", + "Response", + "ResponseNotRead", + "stream", + "StreamClosed", + "StreamConsumed", + "StreamError", + "SyncByteStream", + "Timeout", + "TimeoutException", + "TooManyRedirects", + "TransportError", + "UnsupportedProtocol", + "URL", + "USE_CLIENT_DEFAULT", + "WriteError", + "WriteTimeout", + "WSGITransport", +] + + +__locals = locals() +for __name in __all__: + if not __name.startswith("__"): + setattr(__locals[__name], "__module__", "httpx") # noqa diff --git a/contrib/python/httpx/httpx/__version__.py b/contrib/python/httpx/httpx/__version__.py new file mode 100644 index 0000000000..bfa421ad60 --- /dev/null +++ b/contrib/python/httpx/httpx/__version__.py @@ -0,0 +1,3 @@ +__title__ = "httpx" +__description__ = "A next generation HTTP client, for Python 3." +__version__ = "0.25.0" diff --git a/contrib/python/httpx/httpx/_api.py b/contrib/python/httpx/httpx/_api.py new file mode 100644 index 0000000000..571289cf2b --- /dev/null +++ b/contrib/python/httpx/httpx/_api.py @@ -0,0 +1,445 @@ +import typing +from contextlib import contextmanager + +from ._client import Client +from ._config import DEFAULT_TIMEOUT_CONFIG +from ._models import Response +from ._types import ( + AuthTypes, + CertTypes, + CookieTypes, + HeaderTypes, + ProxiesTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, + URLTypes, + VerifyTypes, +) + + +def request( + method: str, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + trust_env: bool = True, +) -> Response: + """ + Sends an HTTP request. + + **Parameters:** + + * **method** - HTTP method for the new `Request` object: `GET`, `OPTIONS`, + `HEAD`, `POST`, `PUT`, `PATCH`, or `DELETE`. + * **url** - URL for the new `Request` object. + * **params** - *(optional)* Query parameters to include in the URL, as a + string, dictionary, or sequence of two-tuples. + * **content** - *(optional)* Binary content to include in the body of the + request, as bytes or a byte iterator. + * **data** - *(optional)* Form data to include in the body of the request, + as a dictionary. + * **files** - *(optional)* A dictionary of upload files to include in the + body of the request. + * **json** - *(optional)* A JSON serializable object to include in the body + of the request. + * **headers** - *(optional)* Dictionary of HTTP headers to include in the + request. + * **cookies** - *(optional)* Dictionary of Cookie items to include in the + request. + * **auth** - *(optional)* An authentication class to use when sending the + request. + * **proxies** - *(optional)* A dictionary mapping proxy keys to proxy URLs. + * **timeout** - *(optional)* The timeout configuration to use when sending + the request. + * **follow_redirects** - *(optional)* Enables or disables HTTP redirects. + * **verify** - *(optional)* SSL certificates (a.k.a CA bundle) used to + verify the identity of requested hosts. Either `True` (default CA bundle), + a path to an SSL certificate file, an `ssl.SSLContext`, or `False` + (which will disable verification). + * **cert** - *(optional)* An SSL certificate used by the requested host + to authenticate the client. Either a path to an SSL certificate file, or + two-tuple of (certificate file, key file), or a three-tuple of (certificate + file, key file, password). + * **trust_env** - *(optional)* Enables or disables usage of environment + variables for configuration. + + **Returns:** `Response` + + Usage: + + ``` + >>> import httpx + >>> response = httpx.request('GET', 'https://httpbin.org/get') + >>> response + <Response [200 OK]> + ``` + """ + with Client( + cookies=cookies, + proxies=proxies, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + return client.request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + ) + + +@contextmanager +def stream( + method: str, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + trust_env: bool = True, +) -> typing.Iterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + with Client( + cookies=cookies, + proxies=proxies, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + with client.stream( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + ) as response: + yield response + + +def get( + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `GET` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `GET` requests should not include a request body. + """ + return request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def options( + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends an `OPTIONS` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `OPTIONS` requests should not include a request body. + """ + return request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def head( + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `HEAD` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `HEAD` requests should not include a request body. + """ + return request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def post( + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `POST` request. + + **Parameters**: See `httpx.request`. + """ + return request( + "POST", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def put( + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `PUT` request. + + **Parameters**: See `httpx.request`. + """ + return request( + "PUT", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def patch( + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `PATCH` request. + + **Parameters**: See `httpx.request`. + """ + return request( + "PATCH", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def delete( + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `DELETE` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `DELETE` requests should not include a request body. + """ + return request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) diff --git a/contrib/python/httpx/httpx/_auth.py b/contrib/python/httpx/httpx/_auth.py new file mode 100644 index 0000000000..1d7385d573 --- /dev/null +++ b/contrib/python/httpx/httpx/_auth.py @@ -0,0 +1,347 @@ +import hashlib +import netrc +import os +import re +import time +import typing +from base64 import b64encode +from urllib.request import parse_http_list + +from ._exceptions import ProtocolError +from ._models import Request, Response +from ._utils import to_bytes, to_str, unquote + +if typing.TYPE_CHECKING: # pragma: no cover + from hashlib import _Hash + + +class Auth: + """ + Base class for all authentication schemes. + + To implement a custom authentication scheme, subclass `Auth` and override + the `.auth_flow()` method. + + If the authentication scheme does I/O such as disk access or network calls, or uses + synchronization primitives such as locks, you should override `.sync_auth_flow()` + and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized + implementations that will be used by `Client` and `AsyncClient` respectively. + """ + + requires_request_body = False + requires_response_body = False + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + """ + Execute the authentication flow. + + To dispatch a request, `yield` it: + + ``` + yield request + ``` + + The client will `.send()` the response back into the flow generator. You can + access it like so: + + ``` + response = yield request + ``` + + A `return` (or reaching the end of the generator) will result in the + client returning the last response obtained from the server. + + You can dispatch as many requests as is necessary. + """ + yield request + + def sync_auth_flow( + self, request: Request + ) -> typing.Generator[Request, Response, None]: + """ + Execute the authentication flow synchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O and/or uses concurrency primitives. + """ + if self.requires_request_body: + request.read() + + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + response.read() + + try: + request = flow.send(response) + except StopIteration: + break + + async def async_auth_flow( + self, request: Request + ) -> typing.AsyncGenerator[Request, Response]: + """ + Execute the authentication flow asynchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O and/or uses concurrency primitives. + """ + if self.requires_request_body: + await request.aread() + + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + await response.aread() + + try: + request = flow.send(response) + except StopIteration: + break + + +class FunctionAuth(Auth): + """ + Allows the 'auth' argument to be passed as a simple callable function, + that takes the request, and returns a new, modified request. + """ + + def __init__(self, func: typing.Callable[[Request], Request]) -> None: + self._func = func + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + yield self._func(request) + + +class BasicAuth(Auth): + """ + Allows the 'auth' argument to be passed as a (username, password) pair, + and uses HTTP Basic authentication. + """ + + def __init__( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ): + self._auth_header = self._build_auth_header(username, password) + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + request.headers["Authorization"] = self._auth_header + yield request + + def _build_auth_header( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ) -> str: + userpass = b":".join((to_bytes(username), to_bytes(password))) + token = b64encode(userpass).decode() + return f"Basic {token}" + + +class NetRCAuth(Auth): + """ + Use a 'netrc' file to lookup basic auth credentials based on the url host. + """ + + def __init__(self, file: typing.Optional[str] = None): + self._netrc_info = netrc.netrc(file) + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + auth_info = self._netrc_info.authenticators(request.url.host) + if auth_info is None or not auth_info[2]: + # The netrc file did not have authentication credentials for this host. + yield request + else: + # Build a basic auth header with credentials from the netrc file. + request.headers["Authorization"] = self._build_auth_header( + username=auth_info[0], password=auth_info[2] + ) + yield request + + def _build_auth_header( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ) -> str: + userpass = b":".join((to_bytes(username), to_bytes(password))) + token = b64encode(userpass).decode() + return f"Basic {token}" + + +class DigestAuth(Auth): + _ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable[[bytes], "_Hash"]] = { + "MD5": hashlib.md5, + "MD5-SESS": hashlib.md5, + "SHA": hashlib.sha1, + "SHA-SESS": hashlib.sha1, + "SHA-256": hashlib.sha256, + "SHA-256-SESS": hashlib.sha256, + "SHA-512": hashlib.sha512, + "SHA-512-SESS": hashlib.sha512, + } + + def __init__( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ) -> None: + self._username = to_bytes(username) + self._password = to_bytes(password) + self._last_challenge: typing.Optional[_DigestAuthChallenge] = None + self._nonce_count = 1 + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + if self._last_challenge: + request.headers["Authorization"] = self._build_auth_header( + request, self._last_challenge + ) + + response = yield request + + if response.status_code != 401 or "www-authenticate" not in response.headers: + # If the response is not a 401 then we don't + # need to build an authenticated request. + return + + for auth_header in response.headers.get_list("www-authenticate"): + if auth_header.lower().startswith("digest "): + break + else: + # If the response does not include a 'WWW-Authenticate: Digest ...' + # header, then we don't need to build an authenticated request. + return + + self._last_challenge = self._parse_challenge(request, response, auth_header) + self._nonce_count = 1 + + request.headers["Authorization"] = self._build_auth_header( + request, self._last_challenge + ) + yield request + + def _parse_challenge( + self, request: Request, response: Response, auth_header: str + ) -> "_DigestAuthChallenge": + """ + Returns a challenge from a Digest WWW-Authenticate header. + These take the form of: + `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"` + """ + scheme, _, fields = auth_header.partition(" ") + + # This method should only ever have been called with a Digest auth header. + assert scheme.lower() == "digest" + + header_dict: typing.Dict[str, str] = {} + for field in parse_http_list(fields): + key, value = field.strip().split("=", 1) + header_dict[key] = unquote(value) + + try: + realm = header_dict["realm"].encode() + nonce = header_dict["nonce"].encode() + algorithm = header_dict.get("algorithm", "MD5") + opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None + qop = header_dict["qop"].encode() if "qop" in header_dict else None + return _DigestAuthChallenge( + realm=realm, nonce=nonce, algorithm=algorithm, opaque=opaque, qop=qop + ) + except KeyError as exc: + message = "Malformed Digest WWW-Authenticate header" + raise ProtocolError(message, request=request) from exc + + def _build_auth_header( + self, request: Request, challenge: "_DigestAuthChallenge" + ) -> str: + hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()] + + def digest(data: bytes) -> bytes: + return hash_func(data).hexdigest().encode() + + A1 = b":".join((self._username, challenge.realm, self._password)) + + path = request.url.raw_path + A2 = b":".join((request.method.encode(), path)) + # TODO: implement auth-int + HA2 = digest(A2) + + nc_value = b"%08x" % self._nonce_count + cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce) + self._nonce_count += 1 + + HA1 = digest(A1) + if challenge.algorithm.lower().endswith("-sess"): + HA1 = digest(b":".join((HA1, challenge.nonce, cnonce))) + + qop = self._resolve_qop(challenge.qop, request=request) + if qop is None: + digest_data = [HA1, challenge.nonce, HA2] + else: + digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2] + key_digest = b":".join(digest_data) + + format_args = { + "username": self._username, + "realm": challenge.realm, + "nonce": challenge.nonce, + "uri": path, + "response": digest(b":".join((HA1, key_digest))), + "algorithm": challenge.algorithm.encode(), + } + if challenge.opaque: + format_args["opaque"] = challenge.opaque + if qop: + format_args["qop"] = b"auth" + format_args["nc"] = nc_value + format_args["cnonce"] = cnonce + + return "Digest " + self._get_header_value(format_args) + + def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes: + s = str(nonce_count).encode() + s += nonce + s += time.ctime().encode() + s += os.urandom(8) + + return hashlib.sha1(s).hexdigest()[:16].encode() + + def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str: + NON_QUOTED_FIELDS = ("algorithm", "qop", "nc") + QUOTED_TEMPLATE = '{}="{}"' + NON_QUOTED_TEMPLATE = "{}={}" + + header_value = "" + for i, (field, value) in enumerate(header_fields.items()): + if i > 0: + header_value += ", " + template = ( + QUOTED_TEMPLATE + if field not in NON_QUOTED_FIELDS + else NON_QUOTED_TEMPLATE + ) + header_value += template.format(field, to_str(value)) + + return header_value + + def _resolve_qop( + self, qop: typing.Optional[bytes], request: Request + ) -> typing.Optional[bytes]: + if qop is None: + return None + qops = re.split(b", ?", qop) + if b"auth" in qops: + return b"auth" + + if qops == [b"auth-int"]: + raise NotImplementedError("Digest auth-int support is not yet implemented") + + message = f'Unexpected qop value "{qop!r}" in digest auth' + raise ProtocolError(message, request=request) + + +class _DigestAuthChallenge(typing.NamedTuple): + realm: bytes + nonce: bytes + algorithm: str + opaque: typing.Optional[bytes] + qop: typing.Optional[bytes] diff --git a/contrib/python/httpx/httpx/_client.py b/contrib/python/httpx/httpx/_client.py new file mode 100644 index 0000000000..cb475e0204 --- /dev/null +++ b/contrib/python/httpx/httpx/_client.py @@ -0,0 +1,2006 @@ +import datetime +import enum +import logging +import typing +import warnings +from contextlib import asynccontextmanager, contextmanager +from types import TracebackType + +from .__version__ import __version__ +from ._auth import Auth, BasicAuth, FunctionAuth +from ._config import ( + DEFAULT_LIMITS, + DEFAULT_MAX_REDIRECTS, + DEFAULT_TIMEOUT_CONFIG, + Limits, + Proxy, + Timeout, +) +from ._decoders import SUPPORTED_DECODERS +from ._exceptions import ( + InvalidURL, + RemoteProtocolError, + TooManyRedirects, + request_context, +) +from ._models import Cookies, Headers, Request, Response +from ._status_codes import codes +from ._transports.asgi import ASGITransport +from ._transports.base import AsyncBaseTransport, BaseTransport +from ._transports.default import AsyncHTTPTransport, HTTPTransport +from ._transports.wsgi import WSGITransport +from ._types import ( + AsyncByteStream, + AuthTypes, + CertTypes, + CookieTypes, + HeaderTypes, + ProxiesTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestExtensions, + RequestFiles, + SyncByteStream, + TimeoutTypes, + URLTypes, + VerifyTypes, +) +from ._urls import URL, QueryParams +from ._utils import ( + Timer, + URLPattern, + get_environment_proxies, + is_https_redirect, + same_origin, +) + +# The type annotation for @classmethod and context managers here follows PEP 484 +# https://www.python.org/dev/peps/pep-0484/#annotating-instance-and-class-methods +T = typing.TypeVar("T", bound="Client") +U = typing.TypeVar("U", bound="AsyncClient") + + +class UseClientDefault: + """ + For some parameters such as `auth=...` and `timeout=...` we need to be able + to indicate the default "unset" state, in a way that is distinctly different + to using `None`. + + The default "unset" state indicates that whatever default is set on the + client should be used. This is different to setting `None`, which + explicitly disables the parameter, possibly overriding a client default. + + For example we use `timeout=USE_CLIENT_DEFAULT` in the `request()` signature. + Omitting the `timeout` parameter will send a request using whatever default + timeout has been configured on the client. Including `timeout=None` will + ensure no timeout is used. + + Note that user code shouldn't need to use the `USE_CLIENT_DEFAULT` constant, + but it is used internally when a parameter is not included. + """ + + +USE_CLIENT_DEFAULT = UseClientDefault() + + +logger = logging.getLogger("httpx") + +USER_AGENT = f"python-httpx/{__version__}" +ACCEPT_ENCODING = ", ".join( + [key for key in SUPPORTED_DECODERS.keys() if key != "identity"] +) + + +class ClientState(enum.Enum): + # UNOPENED: + # The client has been instantiated, but has not been used to send a request, + # or been opened by entering the context of a `with` block. + UNOPENED = 1 + # OPENED: + # The client has either sent a request, or is within a `with` block. + OPENED = 2 + # CLOSED: + # The client has either exited the `with` block, or `close()` has + # been called explicitly. + CLOSED = 3 + + +class BoundSyncStream(SyncByteStream): + """ + A byte stream that is bound to a given response instance, and that + ensures the `response.elapsed` is set once the response is closed. + """ + + def __init__( + self, stream: SyncByteStream, response: Response, timer: Timer + ) -> None: + self._stream = stream + self._response = response + self._timer = timer + + def __iter__(self) -> typing.Iterator[bytes]: + for chunk in self._stream: + yield chunk + + def close(self) -> None: + seconds = self._timer.sync_elapsed() + self._response.elapsed = datetime.timedelta(seconds=seconds) + self._stream.close() + + +class BoundAsyncStream(AsyncByteStream): + """ + An async byte stream that is bound to a given response instance, and that + ensures the `response.elapsed` is set once the response is closed. + """ + + def __init__( + self, stream: AsyncByteStream, response: Response, timer: Timer + ) -> None: + self._stream = stream + self._response = response + self._timer = timer + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async for chunk in self._stream: + yield chunk + + async def aclose(self) -> None: + seconds = await self._timer.async_elapsed() + self._response.elapsed = datetime.timedelta(seconds=seconds) + await self._stream.aclose() + + +EventHook = typing.Callable[..., typing.Any] + + +class BaseClient: + def __init__( + self, + *, + auth: typing.Optional[AuthTypes] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + event_hooks: typing.Optional[ + typing.Mapping[str, typing.List[EventHook]] + ] = None, + base_url: URLTypes = "", + trust_env: bool = True, + default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8", + ): + event_hooks = {} if event_hooks is None else event_hooks + + self._base_url = self._enforce_trailing_slash(URL(base_url)) + + self._auth = self._build_auth(auth) + self._params = QueryParams(params) + self.headers = Headers(headers) + self._cookies = Cookies(cookies) + self._timeout = Timeout(timeout) + self.follow_redirects = follow_redirects + self.max_redirects = max_redirects + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + self._trust_env = trust_env + self._default_encoding = default_encoding + self._state = ClientState.UNOPENED + + @property + def is_closed(self) -> bool: + """ + Check if the client being closed + """ + return self._state == ClientState.CLOSED + + @property + def trust_env(self) -> bool: + return self._trust_env + + def _enforce_trailing_slash(self, url: URL) -> URL: + if url.raw_path.endswith(b"/"): + return url + return url.copy_with(raw_path=url.raw_path + b"/") + + def _get_proxy_map( + self, proxies: typing.Optional[ProxiesTypes], allow_env_proxies: bool + ) -> typing.Dict[str, typing.Optional[Proxy]]: + if proxies is None: + if allow_env_proxies: + return { + key: None if url is None else Proxy(url=url) + for key, url in get_environment_proxies().items() + } + return {} + if isinstance(proxies, dict): + new_proxies = {} + for key, value in proxies.items(): + proxy = Proxy(url=value) if isinstance(value, (str, URL)) else value + new_proxies[str(key)] = proxy + return new_proxies + else: + proxy = Proxy(url=proxies) if isinstance(proxies, (str, URL)) else proxies + return {"all://": proxy} + + @property + def timeout(self) -> Timeout: + return self._timeout + + @timeout.setter + def timeout(self, timeout: TimeoutTypes) -> None: + self._timeout = Timeout(timeout) + + @property + def event_hooks(self) -> typing.Dict[str, typing.List[EventHook]]: + return self._event_hooks + + @event_hooks.setter + def event_hooks( + self, event_hooks: typing.Dict[str, typing.List[EventHook]] + ) -> None: + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + + @property + def auth(self) -> typing.Optional[Auth]: + """ + Authentication class used when none is passed at the request-level. + + See also [Authentication][0]. + + [0]: /quickstart/#authentication + """ + return self._auth + + @auth.setter + def auth(self, auth: AuthTypes) -> None: + self._auth = self._build_auth(auth) + + @property + def base_url(self) -> URL: + """ + Base URL to use when sending requests with relative URLs. + """ + return self._base_url + + @base_url.setter + def base_url(self, url: URLTypes) -> None: + self._base_url = self._enforce_trailing_slash(URL(url)) + + @property + def headers(self) -> Headers: + """ + HTTP headers to include when sending requests. + """ + return self._headers + + @headers.setter + def headers(self, headers: HeaderTypes) -> None: + client_headers = Headers( + { + b"Accept": b"*/*", + b"Accept-Encoding": ACCEPT_ENCODING.encode("ascii"), + b"Connection": b"keep-alive", + b"User-Agent": USER_AGENT.encode("ascii"), + } + ) + client_headers.update(headers) + self._headers = client_headers + + @property + def cookies(self) -> Cookies: + """ + Cookie values to include when sending requests. + """ + return self._cookies + + @cookies.setter + def cookies(self, cookies: CookieTypes) -> None: + self._cookies = Cookies(cookies) + + @property + def params(self) -> QueryParams: + """ + Query parameters to include in the URL when sending requests. + """ + return self._params + + @params.setter + def params(self, params: QueryParamTypes) -> None: + self._params = QueryParams(params) + + def build_request( + self, + method: str, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Request: + """ + Build and return a request instance. + + * The `params`, `headers` and `cookies` arguments + are merged with any values set on the client. + * The `url` argument is merged with any `base_url` set on the client. + + See also: [Request instances][0] + + [0]: /advanced/#request-instances + """ + url = self._merge_url(url) + headers = self._merge_headers(headers) + cookies = self._merge_cookies(cookies) + params = self._merge_queryparams(params) + extensions = {} if extensions is None else extensions + if "timeout" not in extensions: + timeout = ( + self.timeout + if isinstance(timeout, UseClientDefault) + else Timeout(timeout) + ) + extensions = dict(**extensions, timeout=timeout.as_dict()) + return Request( + method, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + extensions=extensions, + ) + + def _merge_url(self, url: URLTypes) -> URL: + """ + Merge a URL argument together with any 'base_url' on the client, + to create the URL used for the outgoing request. + """ + merge_url = URL(url) + if merge_url.is_relative_url: + # To merge URLs we always append to the base URL. To get this + # behaviour correct we always ensure the base URL ends in a '/' + # separator, and strip any leading '/' from the merge URL. + # + # So, eg... + # + # >>> client = Client(base_url="https://www.example.com/subpath") + # >>> client.base_url + # URL('https://www.example.com/subpath/') + # >>> client.build_request("GET", "/path").url + # URL('https://www.example.com/subpath/path') + merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/") + return self.base_url.copy_with(raw_path=merge_raw_path) + return merge_url + + def _merge_cookies( + self, cookies: typing.Optional[CookieTypes] = None + ) -> typing.Optional[CookieTypes]: + """ + Merge a cookies argument together with any cookies on the client, + to create the cookies used for the outgoing request. + """ + if cookies or self.cookies: + merged_cookies = Cookies(self.cookies) + merged_cookies.update(cookies) + return merged_cookies + return cookies + + def _merge_headers( + self, headers: typing.Optional[HeaderTypes] = None + ) -> typing.Optional[HeaderTypes]: + """ + Merge a headers argument together with any headers on the client, + to create the headers used for the outgoing request. + """ + merged_headers = Headers(self.headers) + merged_headers.update(headers) + return merged_headers + + def _merge_queryparams( + self, params: typing.Optional[QueryParamTypes] = None + ) -> typing.Optional[QueryParamTypes]: + """ + Merge a queryparams argument together with any queryparams on the client, + to create the queryparams used for the outgoing request. + """ + if params or self.params: + merged_queryparams = QueryParams(self.params) + return merged_queryparams.merge(params) + return params + + def _build_auth(self, auth: typing.Optional[AuthTypes]) -> typing.Optional[Auth]: + if auth is None: + return None + elif isinstance(auth, tuple): + return BasicAuth(username=auth[0], password=auth[1]) + elif isinstance(auth, Auth): + return auth + elif callable(auth): + return FunctionAuth(func=auth) + else: + raise TypeError(f'Invalid "auth" argument: {auth!r}') + + def _build_request_auth( + self, + request: Request, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + ) -> Auth: + auth = ( + self._auth if isinstance(auth, UseClientDefault) else self._build_auth(auth) + ) + + if auth is not None: + return auth + + username, password = request.url.username, request.url.password + if username or password: + return BasicAuth(username=username, password=password) + + return Auth() + + def _build_redirect_request(self, request: Request, response: Response) -> Request: + """ + Given a request and a redirect response, return a new request that + should be used to effect the redirect. + """ + method = self._redirect_method(request, response) + url = self._redirect_url(request, response) + headers = self._redirect_headers(request, url, method) + stream = self._redirect_stream(request, method) + cookies = Cookies(self.cookies) + return Request( + method=method, + url=url, + headers=headers, + cookies=cookies, + stream=stream, + extensions=request.extensions, + ) + + def _redirect_method(self, request: Request, response: Response) -> str: + """ + When being redirected we may want to change the method of the request + based on certain specs or browser behavior. + """ + method = request.method + + # https://tools.ietf.org/html/rfc7231#section-6.4.4 + if response.status_code == codes.SEE_OTHER and method != "HEAD": + method = "GET" + + # Do what the browsers do, despite standards... + # Turn 302s into GETs. + if response.status_code == codes.FOUND and method != "HEAD": + method = "GET" + + # If a POST is responded to with a 301, turn it into a GET. + # This bizarre behaviour is explained in 'requests' issue 1704. + if response.status_code == codes.MOVED_PERMANENTLY and method == "POST": + method = "GET" + + return method + + def _redirect_url(self, request: Request, response: Response) -> URL: + """ + Return the URL for the redirect to follow. + """ + location = response.headers["Location"] + + try: + url = URL(location) + except InvalidURL as exc: + raise RemoteProtocolError( + f"Invalid URL in location header: {exc}.", request=request + ) from None + + # Handle malformed 'Location' headers that are "absolute" form, have no host. + # See: https://github.com/encode/httpx/issues/771 + if url.scheme and not url.host: + url = url.copy_with(host=request.url.host) + + # Facilitate relative 'Location' headers, as allowed by RFC 7231. + # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') + if url.is_relative_url: + url = request.url.join(url) + + # Attach previous fragment if needed (RFC 7231 7.1.2) + if request.url.fragment and not url.fragment: + url = url.copy_with(fragment=request.url.fragment) + + return url + + def _redirect_headers(self, request: Request, url: URL, method: str) -> Headers: + """ + Return the headers that should be used for the redirect request. + """ + headers = Headers(request.headers) + + if not same_origin(url, request.url): + if not is_https_redirect(request.url, url): + # Strip Authorization headers when responses are redirected + # away from the origin. (Except for direct HTTP to HTTPS redirects.) + headers.pop("Authorization", None) + + # Update the Host header. + headers["Host"] = url.netloc.decode("ascii") + + if method != request.method and method == "GET": + # If we've switch to a 'GET' request, then strip any headers which + # are only relevant to the request body. + headers.pop("Content-Length", None) + headers.pop("Transfer-Encoding", None) + + # We should use the client cookie store to determine any cookie header, + # rather than whatever was on the original outgoing request. + headers.pop("Cookie", None) + + return headers + + def _redirect_stream( + self, request: Request, method: str + ) -> typing.Optional[typing.Union[SyncByteStream, AsyncByteStream]]: + """ + Return the body that should be used for the redirect request. + """ + if method != request.method and method == "GET": + return None + + return request.stream + + +class Client(BaseClient): + """ + An HTTP client, with connection pooling, HTTP/2, redirects, cookie persistence, etc. + + It can be shared between threads. + + Usage: + + ```python + >>> client = httpx.Client() + >>> response = client.get('https://example.org') + ``` + + **Parameters:** + + * **auth** - *(optional)* An authentication class to use when sending + requests. + * **params** - *(optional)* Query parameters to include in request URLs, as + a string, dictionary, or sequence of two-tuples. + * **headers** - *(optional)* Dictionary of HTTP headers to include when + sending requests. + * **cookies** - *(optional)* Dictionary of Cookie items to include when + sending requests. + * **verify** - *(optional)* SSL certificates (a.k.a CA bundle) used to + verify the identity of requested hosts. Either `True` (default CA bundle), + a path to an SSL certificate file, an `ssl.SSLContext`, or `False` + (which will disable verification). + * **cert** - *(optional)* An SSL certificate used by the requested host + to authenticate the client. Either a path to an SSL certificate file, or + two-tuple of (certificate file, key file), or a three-tuple of (certificate + file, key file, password). + * **proxies** - *(optional)* A dictionary mapping proxy keys to proxy + URLs. + * **timeout** - *(optional)* The timeout configuration to use when sending + requests. + * **limits** - *(optional)* The limits configuration to use. + * **max_redirects** - *(optional)* The maximum number of redirect responses + that should be followed. + * **base_url** - *(optional)* A URL to use as the base when building + request URLs. + * **transport** - *(optional)* A transport class to use for sending requests + over the network. + * **app** - *(optional)* An WSGI application to send requests to, + rather than sending actual network requests. + * **trust_env** - *(optional)* Enables or disables usage of environment + variables for configuration. + * **default_encoding** - *(optional)* The default encoding to use for decoding + response text, if no charset information is included in a response Content-Type + header. Set to a callable for automatic character set detection. Default: "utf-8". + """ + + def __init__( + self, + *, + auth: typing.Optional[AuthTypes] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + proxies: typing.Optional[ProxiesTypes] = None, + mounts: typing.Optional[typing.Mapping[str, BaseTransport]] = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + limits: Limits = DEFAULT_LIMITS, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + event_hooks: typing.Optional[ + typing.Mapping[str, typing.List[EventHook]] + ] = None, + base_url: URLTypes = "", + transport: typing.Optional[BaseTransport] = None, + app: typing.Optional[typing.Callable[..., typing.Any]] = None, + trust_env: bool = True, + default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8", + ): + super().__init__( + auth=auth, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + follow_redirects=follow_redirects, + max_redirects=max_redirects, + event_hooks=event_hooks, + base_url=base_url, + trust_env=trust_env, + default_encoding=default_encoding, + ) + + if http2: + try: + import h2 # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using http2=True, but the 'h2' package is not installed. " + "Make sure to install httpx using `pip install httpx[http2]`." + ) from None + + allow_env_proxies = trust_env and app is None and transport is None + proxy_map = self._get_proxy_map(proxies, allow_env_proxies) + + self._transport = self._init_transport( + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + transport=transport, + app=app, + trust_env=trust_env, + ) + self._mounts: typing.Dict[URLPattern, typing.Optional[BaseTransport]] = { + URLPattern(key): None + if proxy is None + else self._init_proxy_transport( + proxy, + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + trust_env=trust_env, + ) + for key, proxy in proxy_map.items() + } + if mounts is not None: + self._mounts.update( + {URLPattern(key): transport for key, transport in mounts.items()} + ) + + self._mounts = dict(sorted(self._mounts.items())) + + def _init_transport( + self, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + transport: typing.Optional[BaseTransport] = None, + app: typing.Optional[typing.Callable[..., typing.Any]] = None, + trust_env: bool = True, + ) -> BaseTransport: + if transport is not None: + return transport + + if app is not None: + return WSGITransport(app=app) + + return HTTPTransport( + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + trust_env=trust_env, + ) + + def _init_proxy_transport( + self, + proxy: Proxy, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + trust_env: bool = True, + ) -> BaseTransport: + return HTTPTransport( + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + trust_env=trust_env, + proxy=proxy, + ) + + def _transport_for_url(self, url: URL) -> BaseTransport: + """ + Returns the transport instance that should be used for a given URL. + This will either be the standard connection pool, or a proxy. + """ + for pattern, transport in self._mounts.items(): + if pattern.matches(url): + return self._transport if transport is None else transport + + return self._transport + + def request( + self, + method: str, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Build and send a request. + + Equivalent to: + + ```python + request = client.build_request(...) + response = client.send(request, ...) + ``` + + See `Client.build_request()`, `Client.send()` and + [Merging of configuration][0] for how the various parameters + are merged with client-level configuration. + + [0]: /advanced/#merging-of-configuration + """ + if cookies is not None: + message = ( + "Setting per-request cookies=<...> is being deprecated, because " + "the expected behaviour on cookie persistence is ambiguous. Set " + "cookies directly on the client instance instead." + ) + warnings.warn(message, DeprecationWarning) + + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + return self.send(request, auth=auth, follow_redirects=follow_redirects) + + @contextmanager + def stream( + self, + method: str, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> typing.Iterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + response = self.send( + request=request, + auth=auth, + follow_redirects=follow_redirects, + stream=True, + ) + try: + yield response + finally: + response.close() + + def send( + self, + request: Request, + *, + stream: bool = False, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + ) -> Response: + """ + Send a request. + + The request is sent as-is, unmodified. + + Typically you'll want to build one with `Client.build_request()` + so that any client-level configuration is merged into the request, + but passing an explicit `httpx.Request()` is supported as well. + + See also: [Request instances][0] + + [0]: /advanced/#request-instances + """ + if self._state == ClientState.CLOSED: + raise RuntimeError("Cannot send a request, as the client has been closed.") + + self._state = ClientState.OPENED + follow_redirects = ( + self.follow_redirects + if isinstance(follow_redirects, UseClientDefault) + else follow_redirects + ) + + auth = self._build_request_auth(request, auth) + + response = self._send_handling_auth( + request, + auth=auth, + follow_redirects=follow_redirects, + history=[], + ) + try: + if not stream: + response.read() + + return response + + except BaseException as exc: + response.close() + raise exc + + def _send_handling_auth( + self, + request: Request, + auth: Auth, + follow_redirects: bool, + history: typing.List[Response], + ) -> Response: + auth_flow = auth.sync_auth_flow(request) + try: + request = next(auth_flow) + + while True: + response = self._send_handling_redirects( + request, + follow_redirects=follow_redirects, + history=history, + ) + try: + try: + next_request = auth_flow.send(response) + except StopIteration: + return response + + response.history = list(history) + response.read() + request = next_request + history.append(response) + + except BaseException as exc: + response.close() + raise exc + finally: + auth_flow.close() + + def _send_handling_redirects( + self, + request: Request, + follow_redirects: bool, + history: typing.List[Response], + ) -> Response: + while True: + if len(history) > self.max_redirects: + raise TooManyRedirects( + "Exceeded maximum allowed redirects.", request=request + ) + + for hook in self._event_hooks["request"]: + hook(request) + + response = self._send_single_request(request) + try: + for hook in self._event_hooks["response"]: + hook(response) + response.history = list(history) + + if not response.has_redirect_location: + return response + + request = self._build_redirect_request(request, response) + history = history + [response] + + if follow_redirects: + response.read() + else: + response.next_request = request + return response + + except BaseException as exc: + response.close() + raise exc + + def _send_single_request(self, request: Request) -> Response: + """ + Sends a single request, without handling any redirections. + """ + transport = self._transport_for_url(request.url) + timer = Timer() + timer.sync_start() + + if not isinstance(request.stream, SyncByteStream): + raise RuntimeError( + "Attempted to send an async request with a sync Client instance." + ) + + with request_context(request=request): + response = transport.handle_request(request) + + assert isinstance(response.stream, SyncByteStream) + + response.request = request + response.stream = BoundSyncStream( + response.stream, response=response, timer=timer + ) + self.cookies.extract_cookies(response) + response.default_encoding = self._default_encoding + + logger.info( + 'HTTP Request: %s %s "%s %d %s"', + request.method, + request.url, + response.http_version, + response.status_code, + response.reason_phrase, + ) + + return response + + def get( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `GET` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def options( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send an `OPTIONS` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def head( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `HEAD` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def post( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `POST` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "POST", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def put( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `PUT` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "PUT", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def patch( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `PATCH` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "PATCH", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def delete( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `DELETE` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def close(self) -> None: + """ + Close transport and proxies. + """ + if self._state != ClientState.CLOSED: + self._state = ClientState.CLOSED + + self._transport.close() + for transport in self._mounts.values(): + if transport is not None: + transport.close() + + def __enter__(self: T) -> T: + if self._state != ClientState.UNOPENED: + msg = { + ClientState.OPENED: "Cannot open a client instance more than once.", + ClientState.CLOSED: "Cannot reopen a client instance, once it has been closed.", + }[self._state] + raise RuntimeError(msg) + + self._state = ClientState.OPENED + + self._transport.__enter__() + for transport in self._mounts.values(): + if transport is not None: + transport.__enter__() + return self + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + self._state = ClientState.CLOSED + + self._transport.__exit__(exc_type, exc_value, traceback) + for transport in self._mounts.values(): + if transport is not None: + transport.__exit__(exc_type, exc_value, traceback) + + +class AsyncClient(BaseClient): + """ + An asynchronous HTTP client, with connection pooling, HTTP/2, redirects, + cookie persistence, etc. + + Usage: + + ```python + >>> async with httpx.AsyncClient() as client: + >>> response = await client.get('https://example.org') + ``` + + **Parameters:** + + * **auth** - *(optional)* An authentication class to use when sending + requests. + * **params** - *(optional)* Query parameters to include in request URLs, as + a string, dictionary, or sequence of two-tuples. + * **headers** - *(optional)* Dictionary of HTTP headers to include when + sending requests. + * **cookies** - *(optional)* Dictionary of Cookie items to include when + sending requests. + * **verify** - *(optional)* SSL certificates (a.k.a CA bundle) used to + verify the identity of requested hosts. Either `True` (default CA bundle), + a path to an SSL certificate file, an `ssl.SSLContext`, or `False` + (which will disable verification). + * **cert** - *(optional)* An SSL certificate used by the requested host + to authenticate the client. Either a path to an SSL certificate file, or + two-tuple of (certificate file, key file), or a three-tuple of (certificate + file, key file, password). + * **http2** - *(optional)* A boolean indicating if HTTP/2 support should be + enabled. Defaults to `False`. + * **proxies** - *(optional)* A dictionary mapping HTTP protocols to proxy + URLs. + * **timeout** - *(optional)* The timeout configuration to use when sending + requests. + * **limits** - *(optional)* The limits configuration to use. + * **max_redirects** - *(optional)* The maximum number of redirect responses + that should be followed. + * **base_url** - *(optional)* A URL to use as the base when building + request URLs. + * **transport** - *(optional)* A transport class to use for sending requests + over the network. + * **app** - *(optional)* An ASGI application to send requests to, + rather than sending actual network requests. + * **trust_env** - *(optional)* Enables or disables usage of environment + variables for configuration. + * **default_encoding** - *(optional)* The default encoding to use for decoding + response text, if no charset information is included in a response Content-Type + header. Set to a callable for automatic character set detection. Default: "utf-8". + """ + + def __init__( + self, + *, + auth: typing.Optional[AuthTypes] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + proxies: typing.Optional[ProxiesTypes] = None, + mounts: typing.Optional[typing.Mapping[str, AsyncBaseTransport]] = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + limits: Limits = DEFAULT_LIMITS, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + event_hooks: typing.Optional[ + typing.Mapping[str, typing.List[typing.Callable[..., typing.Any]]] + ] = None, + base_url: URLTypes = "", + transport: typing.Optional[AsyncBaseTransport] = None, + app: typing.Optional[typing.Callable[..., typing.Any]] = None, + trust_env: bool = True, + default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8", + ): + super().__init__( + auth=auth, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + follow_redirects=follow_redirects, + max_redirects=max_redirects, + event_hooks=event_hooks, + base_url=base_url, + trust_env=trust_env, + default_encoding=default_encoding, + ) + + if http2: + try: + import h2 # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using http2=True, but the 'h2' package is not installed. " + "Make sure to install httpx using `pip install httpx[http2]`." + ) from None + + allow_env_proxies = trust_env and app is None and transport is None + proxy_map = self._get_proxy_map(proxies, allow_env_proxies) + + self._transport = self._init_transport( + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + transport=transport, + app=app, + trust_env=trust_env, + ) + + self._mounts: typing.Dict[URLPattern, typing.Optional[AsyncBaseTransport]] = { + URLPattern(key): None + if proxy is None + else self._init_proxy_transport( + proxy, + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + trust_env=trust_env, + ) + for key, proxy in proxy_map.items() + } + if mounts is not None: + self._mounts.update( + {URLPattern(key): transport for key, transport in mounts.items()} + ) + self._mounts = dict(sorted(self._mounts.items())) + + def _init_transport( + self, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + transport: typing.Optional[AsyncBaseTransport] = None, + app: typing.Optional[typing.Callable[..., typing.Any]] = None, + trust_env: bool = True, + ) -> AsyncBaseTransport: + if transport is not None: + return transport + + if app is not None: + return ASGITransport(app=app) + + return AsyncHTTPTransport( + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + trust_env=trust_env, + ) + + def _init_proxy_transport( + self, + proxy: Proxy, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + trust_env: bool = True, + ) -> AsyncBaseTransport: + return AsyncHTTPTransport( + verify=verify, + cert=cert, + http2=http2, + limits=limits, + trust_env=trust_env, + proxy=proxy, + ) + + def _transport_for_url(self, url: URL) -> AsyncBaseTransport: + """ + Returns the transport instance that should be used for a given URL. + This will either be the standard connection pool, or a proxy. + """ + for pattern, transport in self._mounts.items(): + if pattern.matches(url): + return self._transport if transport is None else transport + + return self._transport + + async def request( + self, + method: str, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Build and send a request. + + Equivalent to: + + ```python + request = client.build_request(...) + response = await client.send(request, ...) + ``` + + See `AsyncClient.build_request()`, `AsyncClient.send()` + and [Merging of configuration][0] for how the various parameters + are merged with client-level configuration. + + [0]: /advanced/#merging-of-configuration + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + return await self.send(request, auth=auth, follow_redirects=follow_redirects) + + @asynccontextmanager + async def stream( + self, + method: str, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> typing.AsyncIterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + response = await self.send( + request=request, + auth=auth, + follow_redirects=follow_redirects, + stream=True, + ) + try: + yield response + finally: + await response.aclose() + + async def send( + self, + request: Request, + *, + stream: bool = False, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + ) -> Response: + """ + Send a request. + + The request is sent as-is, unmodified. + + Typically you'll want to build one with `AsyncClient.build_request()` + so that any client-level configuration is merged into the request, + but passing an explicit `httpx.Request()` is supported as well. + + See also: [Request instances][0] + + [0]: /advanced/#request-instances + """ + if self._state == ClientState.CLOSED: + raise RuntimeError("Cannot send a request, as the client has been closed.") + + self._state = ClientState.OPENED + follow_redirects = ( + self.follow_redirects + if isinstance(follow_redirects, UseClientDefault) + else follow_redirects + ) + + auth = self._build_request_auth(request, auth) + + response = await self._send_handling_auth( + request, + auth=auth, + follow_redirects=follow_redirects, + history=[], + ) + try: + if not stream: + await response.aread() + + return response + + except BaseException as exc: # pragma: no cover + await response.aclose() + raise exc + + async def _send_handling_auth( + self, + request: Request, + auth: Auth, + follow_redirects: bool, + history: typing.List[Response], + ) -> Response: + auth_flow = auth.async_auth_flow(request) + try: + request = await auth_flow.__anext__() + + while True: + response = await self._send_handling_redirects( + request, + follow_redirects=follow_redirects, + history=history, + ) + try: + try: + next_request = await auth_flow.asend(response) + except StopAsyncIteration: + return response + + response.history = list(history) + await response.aread() + request = next_request + history.append(response) + + except BaseException as exc: + await response.aclose() + raise exc + finally: + await auth_flow.aclose() + + async def _send_handling_redirects( + self, + request: Request, + follow_redirects: bool, + history: typing.List[Response], + ) -> Response: + while True: + if len(history) > self.max_redirects: + raise TooManyRedirects( + "Exceeded maximum allowed redirects.", request=request + ) + + for hook in self._event_hooks["request"]: + await hook(request) + + response = await self._send_single_request(request) + try: + for hook in self._event_hooks["response"]: + await hook(response) + + response.history = list(history) + + if not response.has_redirect_location: + return response + + request = self._build_redirect_request(request, response) + history = history + [response] + + if follow_redirects: + await response.aread() + else: + response.next_request = request + return response + + except BaseException as exc: + await response.aclose() + raise exc + + async def _send_single_request(self, request: Request) -> Response: + """ + Sends a single request, without handling any redirections. + """ + transport = self._transport_for_url(request.url) + timer = Timer() + await timer.async_start() + + if not isinstance(request.stream, AsyncByteStream): + raise RuntimeError( + "Attempted to send an sync request with an AsyncClient instance." + ) + + with request_context(request=request): + response = await transport.handle_async_request(request) + + assert isinstance(response.stream, AsyncByteStream) + response.request = request + response.stream = BoundAsyncStream( + response.stream, response=response, timer=timer + ) + self.cookies.extract_cookies(response) + response.default_encoding = self._default_encoding + + logger.info( + 'HTTP Request: %s %s "%s %d %s"', + request.method, + request.url, + response.http_version, + response.status_code, + response.reason_phrase, + ) + + return response + + async def get( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `GET` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def options( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send an `OPTIONS` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def head( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `HEAD` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def post( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `POST` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "POST", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def put( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `PUT` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "PUT", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def patch( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `PATCH` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "PATCH", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def delete( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `DELETE` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def aclose(self) -> None: + """ + Close transport and proxies. + """ + if self._state != ClientState.CLOSED: + self._state = ClientState.CLOSED + + await self._transport.aclose() + for proxy in self._mounts.values(): + if proxy is not None: + await proxy.aclose() + + async def __aenter__(self: U) -> U: + if self._state != ClientState.UNOPENED: + msg = { + ClientState.OPENED: "Cannot open a client instance more than once.", + ClientState.CLOSED: "Cannot reopen a client instance, once it has been closed.", + }[self._state] + raise RuntimeError(msg) + + self._state = ClientState.OPENED + + await self._transport.__aenter__() + for proxy in self._mounts.values(): + if proxy is not None: + await proxy.__aenter__() + return self + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + self._state = ClientState.CLOSED + + await self._transport.__aexit__(exc_type, exc_value, traceback) + for proxy in self._mounts.values(): + if proxy is not None: + await proxy.__aexit__(exc_type, exc_value, traceback) diff --git a/contrib/python/httpx/httpx/_compat.py b/contrib/python/httpx/httpx/_compat.py new file mode 100644 index 0000000000..a271c6b800 --- /dev/null +++ b/contrib/python/httpx/httpx/_compat.py @@ -0,0 +1,43 @@ +""" +The _compat module is used for code which requires branching between different +Python environments. It is excluded from the code coverage checks. +""" +import ssl +import sys + +# Brotli support is optional +# The C bindings in `brotli` are recommended for CPython. +# The CFFI bindings in `brotlicffi` are recommended for PyPy and everything else. +try: + import brotlicffi as brotli +except ImportError: # pragma: no cover + try: + import brotli + except ImportError: + brotli = None + +if sys.version_info >= (3, 10) or ( + sys.version_info >= (3, 8) and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0, 7) +): + + def set_minimum_tls_version_1_2(context: ssl.SSLContext) -> None: + # The OP_NO_SSL* and OP_NO_TLS* become deprecated in favor of + # 'SSLContext.minimum_version' from Python 3.7 onwards, however + # this attribute is not available unless the ssl module is compiled + # with OpenSSL 1.1.0g or newer. + # https://docs.python.org/3.10/library/ssl.html#ssl.SSLContext.minimum_version + # https://docs.python.org/3.7/library/ssl.html#ssl.SSLContext.minimum_version + context.minimum_version = ssl.TLSVersion.TLSv1_2 + +else: + + def set_minimum_tls_version_1_2(context: ssl.SSLContext) -> None: + # If 'minimum_version' isn't available, we configure these options with + # the older deprecated variants. + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_TLSv1 + context.options |= ssl.OP_NO_TLSv1_1 + + +__all__ = ["brotli", "set_minimum_tls_version_1_2"] diff --git a/contrib/python/httpx/httpx/_config.py b/contrib/python/httpx/httpx/_config.py new file mode 100644 index 0000000000..8d4e03add5 --- /dev/null +++ b/contrib/python/httpx/httpx/_config.py @@ -0,0 +1,378 @@ +import logging +import os +import ssl +import sys +import typing +from pathlib import Path + +import certifi + +from ._compat import set_minimum_tls_version_1_2 +from ._models import Headers +from ._types import CertTypes, HeaderTypes, TimeoutTypes, URLTypes, VerifyTypes +from ._urls import URL +from ._utils import get_ca_bundle_from_env + +DEFAULT_CIPHERS = ":".join( + [ + "ECDHE+AESGCM", + "ECDHE+CHACHA20", + "DHE+AESGCM", + "DHE+CHACHA20", + "ECDH+AESGCM", + "DH+AESGCM", + "ECDH+AES", + "DH+AES", + "RSA+AESGCM", + "RSA+AES", + "!aNULL", + "!eNULL", + "!MD5", + "!DSS", + ] +) + + +logger = logging.getLogger("httpx") + + +class UnsetType: + pass # pragma: no cover + + +UNSET = UnsetType() + + +def create_ssl_context( + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + trust_env: bool = True, + http2: bool = False, +) -> ssl.SSLContext: + return SSLConfig( + cert=cert, verify=verify, trust_env=trust_env, http2=http2 + ).ssl_context + + +class SSLConfig: + """ + SSL Configuration. + """ + + DEFAULT_CA_BUNDLE_PATH = certifi.where() + if callable(DEFAULT_CA_BUNDLE_PATH): + DEFAULT_CA_BUNDLE_PATH = staticmethod(DEFAULT_CA_BUNDLE_PATH) + else: + DEFAULT_CA_BUNDLE_PATH = Path(DEFAULT_CA_BUNDLE_PATH) + + def __init__( + self, + *, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + trust_env: bool = True, + http2: bool = False, + ): + self.cert = cert + self.verify = verify + self.trust_env = trust_env + self.http2 = http2 + self.ssl_context = self.load_ssl_context() + + def load_ssl_context(self) -> ssl.SSLContext: + logger.debug( + "load_ssl_context verify=%r cert=%r trust_env=%r http2=%r", + self.verify, + self.cert, + self.trust_env, + self.http2, + ) + + if self.verify: + return self.load_ssl_context_verify() + return self.load_ssl_context_no_verify() + + def load_ssl_context_no_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for unverified connections. + """ + context = self._create_default_ssl_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + self._load_client_certs(context) + return context + + def load_ssl_context_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for verified connections. + """ + if self.trust_env and self.verify is True: + ca_bundle = get_ca_bundle_from_env() + if ca_bundle is not None: + self.verify = ca_bundle + + if isinstance(self.verify, ssl.SSLContext): + # Allow passing in our own SSLContext object that's pre-configured. + context = self.verify + self._load_client_certs(context) + return context + elif isinstance(self.verify, bool): + ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH + elif Path(self.verify).exists(): + ca_bundle_path = Path(self.verify) + else: + raise IOError( + "Could not find a suitable TLS CA certificate bundle, " + "invalid path: {}".format(self.verify) + ) + + context = self._create_default_ssl_context() + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + + # Signal to server support for PHA in TLS 1.3. Raises an + # AttributeError if only read-only access is implemented. + if sys.version_info >= (3, 8): # pragma: no cover + try: + context.post_handshake_auth = True + except AttributeError: # pragma: no cover + pass + + # Disable using 'commonName' for SSLContext.check_hostname + # when the 'subjectAltName' extension isn't available. + try: + context.hostname_checks_common_name = False + except AttributeError: # pragma: no cover + pass + + if callable(ca_bundle_path): + logger.debug("load_verify_locations cafile=%r", ca_bundle_path) + context.load_verify_locations(cafile=ca_bundle_path) + elif ca_bundle_path.is_file(): + cafile = str(ca_bundle_path) + logger.debug("load_verify_locations cafile=%r", cafile) + context.load_verify_locations(cafile=cafile) + elif ca_bundle_path.is_dir(): + capath = str(ca_bundle_path) + logger.debug("load_verify_locations capath=%r", capath) + context.load_verify_locations(capath=capath) + + self._load_client_certs(context) + + return context + + def _create_default_ssl_context(self) -> ssl.SSLContext: + """ + Creates the default SSLContext object that's used for both verified + and unverified connections. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + set_minimum_tls_version_1_2(context) + context.options |= ssl.OP_NO_COMPRESSION + context.set_ciphers(DEFAULT_CIPHERS) + + if ssl.HAS_ALPN: + alpn_idents = ["http/1.1", "h2"] if self.http2 else ["http/1.1"] + context.set_alpn_protocols(alpn_idents) + + if sys.version_info >= (3, 8): # pragma: no cover + keylogfile = os.environ.get("SSLKEYLOGFILE") + if keylogfile and self.trust_env: + context.keylog_filename = keylogfile + + return context + + def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None: + """ + Loads client certificates into our SSLContext object + """ + if self.cert is not None: + if isinstance(self.cert, str): + ssl_context.load_cert_chain(certfile=self.cert) + elif isinstance(self.cert, tuple) and len(self.cert) == 2: + ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1]) + elif isinstance(self.cert, tuple) and len(self.cert) == 3: + ssl_context.load_cert_chain( + certfile=self.cert[0], + keyfile=self.cert[1], + password=self.cert[2], # type: ignore + ) + + +class Timeout: + """ + Timeout configuration. + + **Usage**: + + Timeout(None) # No timeouts. + Timeout(5.0) # 5s timeout on all operations. + Timeout(None, connect=5.0) # 5s timeout on connect, no other timeouts. + Timeout(5.0, connect=10.0) # 10s timeout on connect. 5s timeout elsewhere. + Timeout(5.0, pool=None) # No timeout on acquiring connection from pool. + # 5s timeout elsewhere. + """ + + def __init__( + self, + timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, + *, + connect: typing.Union[None, float, UnsetType] = UNSET, + read: typing.Union[None, float, UnsetType] = UNSET, + write: typing.Union[None, float, UnsetType] = UNSET, + pool: typing.Union[None, float, UnsetType] = UNSET, + ): + if isinstance(timeout, Timeout): + # Passed as a single explicit Timeout. + assert connect is UNSET + assert read is UNSET + assert write is UNSET + assert pool is UNSET + self.connect = timeout.connect # type: typing.Optional[float] + self.read = timeout.read # type: typing.Optional[float] + self.write = timeout.write # type: typing.Optional[float] + self.pool = timeout.pool # type: typing.Optional[float] + elif isinstance(timeout, tuple): + # Passed as a tuple. + self.connect = timeout[0] + self.read = timeout[1] + self.write = None if len(timeout) < 3 else timeout[2] + self.pool = None if len(timeout) < 4 else timeout[3] + elif not ( + isinstance(connect, UnsetType) + or isinstance(read, UnsetType) + or isinstance(write, UnsetType) + or isinstance(pool, UnsetType) + ): + self.connect = connect + self.read = read + self.write = write + self.pool = pool + else: + if isinstance(timeout, UnsetType): + raise ValueError( + "httpx.Timeout must either include a default, or set all " + "four parameters explicitly." + ) + self.connect = timeout if isinstance(connect, UnsetType) else connect + self.read = timeout if isinstance(read, UnsetType) else read + self.write = timeout if isinstance(write, UnsetType) else write + self.pool = timeout if isinstance(pool, UnsetType) else pool + + def as_dict(self) -> typing.Dict[str, typing.Optional[float]]: + return { + "connect": self.connect, + "read": self.read, + "write": self.write, + "pool": self.pool, + } + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.connect == other.connect + and self.read == other.read + and self.write == other.write + and self.pool == other.pool + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + if len({self.connect, self.read, self.write, self.pool}) == 1: + return f"{class_name}(timeout={self.connect})" + return ( + f"{class_name}(connect={self.connect}, " + f"read={self.read}, write={self.write}, pool={self.pool})" + ) + + +class Limits: + """ + Configuration for limits to various client behaviors. + + **Parameters:** + + * **max_connections** - The maximum number of concurrent connections that may be + established. + * **max_keepalive_connections** - Allow the connection pool to maintain + keep-alive connections below this point. Should be less than or equal + to `max_connections`. + * **keepalive_expiry** - Time limit on idle keep-alive connections in seconds. + """ + + def __init__( + self, + *, + max_connections: typing.Optional[int] = None, + max_keepalive_connections: typing.Optional[int] = None, + keepalive_expiry: typing.Optional[float] = 5.0, + ): + self.max_connections = max_connections + self.max_keepalive_connections = max_keepalive_connections + self.keepalive_expiry = keepalive_expiry + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.max_connections == other.max_connections + and self.max_keepalive_connections == other.max_keepalive_connections + and self.keepalive_expiry == other.keepalive_expiry + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return ( + f"{class_name}(max_connections={self.max_connections}, " + f"max_keepalive_connections={self.max_keepalive_connections}, " + f"keepalive_expiry={self.keepalive_expiry})" + ) + + +class Proxy: + def __init__( + self, + url: URLTypes, + *, + ssl_context: typing.Optional[ssl.SSLContext] = None, + auth: typing.Optional[typing.Tuple[str, str]] = None, + headers: typing.Optional[HeaderTypes] = None, + ): + url = URL(url) + headers = Headers(headers) + + if url.scheme not in ("http", "https", "socks5"): + raise ValueError(f"Unknown scheme for proxy URL {url!r}") + + if url.username or url.password: + # Remove any auth credentials from the URL. + auth = (url.username, url.password) + url = url.copy_with(username=None, password=None) + + self.url = url + self.auth = auth + self.headers = headers + self.ssl_context = ssl_context + + @property + def raw_auth(self) -> typing.Optional[typing.Tuple[bytes, bytes]]: + # The proxy authentication as raw bytes. + return ( + None + if self.auth is None + else (self.auth[0].encode("utf-8"), self.auth[1].encode("utf-8")) + ) + + def __repr__(self) -> str: + # The authentication is represented with the password component masked. + auth = (self.auth[0], "********") if self.auth else None + + # Build a nice concise representation. + url_str = f"{str(self.url)!r}" + auth_str = f", auth={auth!r}" if auth else "" + headers_str = f", headers={dict(self.headers)!r}" if self.headers else "" + return f"Proxy({url_str}{auth_str}{headers_str})" + + +DEFAULT_TIMEOUT_CONFIG = Timeout(timeout=5.0) +DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20) +DEFAULT_MAX_REDIRECTS = 20 diff --git a/contrib/python/httpx/httpx/_content.py b/contrib/python/httpx/httpx/_content.py new file mode 100644 index 0000000000..b16e12d954 --- /dev/null +++ b/contrib/python/httpx/httpx/_content.py @@ -0,0 +1,238 @@ +import inspect +import warnings +from json import dumps as json_dumps +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + Tuple, + Union, +) +from urllib.parse import urlencode + +from ._exceptions import StreamClosed, StreamConsumed +from ._multipart import MultipartStream +from ._types import ( + AsyncByteStream, + RequestContent, + RequestData, + RequestFiles, + ResponseContent, + SyncByteStream, +) +from ._utils import peek_filelike_length, primitive_value_to_str + + +class ByteStream(AsyncByteStream, SyncByteStream): + def __init__(self, stream: bytes) -> None: + self._stream = stream + + def __iter__(self) -> Iterator[bytes]: + yield self._stream + + async def __aiter__(self) -> AsyncIterator[bytes]: + yield self._stream + + +class IteratorByteStream(SyncByteStream): + CHUNK_SIZE = 65_536 + + def __init__(self, stream: Iterable[bytes]): + self._stream = stream + self._is_stream_consumed = False + self._is_generator = inspect.isgenerator(stream) + + def __iter__(self) -> Iterator[bytes]: + if self._is_stream_consumed and self._is_generator: + raise StreamConsumed() + + self._is_stream_consumed = True + if hasattr(self._stream, "read"): + # File-like interfaces should use 'read' directly. + chunk = self._stream.read(self.CHUNK_SIZE) + while chunk: + yield chunk + chunk = self._stream.read(self.CHUNK_SIZE) + else: + # Otherwise iterate. + for part in self._stream: + yield part + + +class AsyncIteratorByteStream(AsyncByteStream): + CHUNK_SIZE = 65_536 + + def __init__(self, stream: AsyncIterable[bytes]): + self._stream = stream + self._is_stream_consumed = False + self._is_generator = inspect.isasyncgen(stream) + + async def __aiter__(self) -> AsyncIterator[bytes]: + if self._is_stream_consumed and self._is_generator: + raise StreamConsumed() + + self._is_stream_consumed = True + if hasattr(self._stream, "aread"): + # File-like interfaces should use 'aread' directly. + chunk = await self._stream.aread(self.CHUNK_SIZE) + while chunk: + yield chunk + chunk = await self._stream.aread(self.CHUNK_SIZE) + else: + # Otherwise iterate. + async for part in self._stream: + yield part + + +class UnattachedStream(AsyncByteStream, SyncByteStream): + """ + If a request or response is serialized using pickle, then it is no longer + attached to a stream for I/O purposes. Any stream operations should result + in `httpx.StreamClosed`. + """ + + def __iter__(self) -> Iterator[bytes]: + raise StreamClosed() + + async def __aiter__(self) -> AsyncIterator[bytes]: + raise StreamClosed() + yield b"" # pragma: no cover + + +def encode_content( + content: Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] +) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]: + if isinstance(content, (bytes, str)): + body = content.encode("utf-8") if isinstance(content, str) else content + content_length = len(body) + headers = {"Content-Length": str(content_length)} if body else {} + return headers, ByteStream(body) + + elif isinstance(content, Iterable) and not isinstance(content, dict): + # `not isinstance(content, dict)` is a bit oddly specific, but it + # catches a case that's easy for users to make in error, and would + # otherwise pass through here, like any other bytes-iterable, + # because `dict` happens to be iterable. See issue #2491. + content_length_or_none = peek_filelike_length(content) + + if content_length_or_none is None: + headers = {"Transfer-Encoding": "chunked"} + else: + headers = {"Content-Length": str(content_length_or_none)} + return headers, IteratorByteStream(content) # type: ignore + + elif isinstance(content, AsyncIterable): + headers = {"Transfer-Encoding": "chunked"} + return headers, AsyncIteratorByteStream(content) + + raise TypeError(f"Unexpected type for 'content', {type(content)!r}") + + +def encode_urlencoded_data( + data: RequestData, +) -> Tuple[Dict[str, str], ByteStream]: + plain_data = [] + for key, value in data.items(): + if isinstance(value, (list, tuple)): + plain_data.extend([(key, primitive_value_to_str(item)) for item in value]) + else: + plain_data.append((key, primitive_value_to_str(value))) + body = urlencode(plain_data, doseq=True).encode("utf-8") + content_length = str(len(body)) + content_type = "application/x-www-form-urlencoded" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_multipart_data( + data: RequestData, files: RequestFiles, boundary: Optional[bytes] +) -> Tuple[Dict[str, str], MultipartStream]: + multipart = MultipartStream(data=data, files=files, boundary=boundary) + headers = multipart.get_headers() + return headers, multipart + + +def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]: + body = text.encode("utf-8") + content_length = str(len(body)) + content_type = "text/plain; charset=utf-8" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]: + body = html.encode("utf-8") + content_length = str(len(body)) + content_type = "text/html; charset=utf-8" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]: + body = json_dumps(json).encode("utf-8") + content_length = str(len(body)) + content_type = "application/json" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_request( + content: Optional[RequestContent] = None, + data: Optional[RequestData] = None, + files: Optional[RequestFiles] = None, + json: Optional[Any] = None, + boundary: Optional[bytes] = None, +) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]: + """ + Handles encoding the given `content`, `data`, `files`, and `json`, + returning a two-tuple of (<headers>, <stream>). + """ + if data is not None and not isinstance(data, Mapping): + # We prefer to separate `content=<bytes|str|byte iterator|bytes aiterator>` + # for raw request content, and `data=<form data>` for url encoded or + # multipart form content. + # + # However for compat with requests, we *do* still support + # `data=<bytes...>` usages. We deal with that case here, treating it + # as if `content=<...>` had been supplied instead. + message = "Use 'content=<...>' to upload raw bytes/text content." + warnings.warn(message, DeprecationWarning) + return encode_content(data) + + if content is not None: + return encode_content(content) + elif files: + return encode_multipart_data(data or {}, files, boundary) + elif data: + return encode_urlencoded_data(data) + elif json is not None: + return encode_json(json) + + return {}, ByteStream(b"") + + +def encode_response( + content: Optional[ResponseContent] = None, + text: Optional[str] = None, + html: Optional[str] = None, + json: Optional[Any] = None, +) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]: + """ + Handles encoding the given `content`, returning a two-tuple of + (<headers>, <stream>). + """ + if content is not None: + return encode_content(content) + elif text is not None: + return encode_text(text) + elif html is not None: + return encode_html(html) + elif json is not None: + return encode_json(json) + + return {}, ByteStream(b"") diff --git a/contrib/python/httpx/httpx/_decoders.py b/contrib/python/httpx/httpx/_decoders.py new file mode 100644 index 0000000000..500ce7ffc3 --- /dev/null +++ b/contrib/python/httpx/httpx/_decoders.py @@ -0,0 +1,324 @@ +""" +Handlers for Content-Encoding. + +See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding +""" +import codecs +import io +import typing +import zlib + +from ._compat import brotli +from ._exceptions import DecodingError + + +class ContentDecoder: + def decode(self, data: bytes) -> bytes: + raise NotImplementedError() # pragma: no cover + + def flush(self) -> bytes: + raise NotImplementedError() # pragma: no cover + + +class IdentityDecoder(ContentDecoder): + """ + Handle unencoded data. + """ + + def decode(self, data: bytes) -> bytes: + return data + + def flush(self) -> bytes: + return b"" + + +class DeflateDecoder(ContentDecoder): + """ + Handle 'deflate' decoding. + + See: https://stackoverflow.com/questions/1838699 + """ + + def __init__(self) -> None: + self.first_attempt = True + self.decompressor = zlib.decompressobj() + + def decode(self, data: bytes) -> bytes: + was_first_attempt = self.first_attempt + self.first_attempt = False + try: + return self.decompressor.decompress(data) + except zlib.error as exc: + if was_first_attempt: + self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS) + return self.decode(data) + raise DecodingError(str(exc)) from exc + + def flush(self) -> bytes: + try: + return self.decompressor.flush() + except zlib.error as exc: # pragma: no cover + raise DecodingError(str(exc)) from exc + + +class GZipDecoder(ContentDecoder): + """ + Handle 'gzip' decoding. + + See: https://stackoverflow.com/questions/1838699 + """ + + def __init__(self) -> None: + self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16) + + def decode(self, data: bytes) -> bytes: + try: + return self.decompressor.decompress(data) + except zlib.error as exc: + raise DecodingError(str(exc)) from exc + + def flush(self) -> bytes: + try: + return self.decompressor.flush() + except zlib.error as exc: # pragma: no cover + raise DecodingError(str(exc)) from exc + + +class BrotliDecoder(ContentDecoder): + """ + Handle 'brotli' decoding. + + Requires `pip install brotlipy`. See: https://brotlipy.readthedocs.io/ + or `pip install brotli`. See https://github.com/google/brotli + Supports both 'brotlipy' and 'Brotli' packages since they share an import + name. The top branches are for 'brotlipy' and bottom branches for 'Brotli' + """ + + def __init__(self) -> None: + if brotli is None: # pragma: no cover + raise ImportError( + "Using 'BrotliDecoder', but neither of the 'brotlicffi' or 'brotli' " + "packages have been installed. " + "Make sure to install httpx using `pip install httpx[brotli]`." + ) from None + + self.decompressor = brotli.Decompressor() + self.seen_data = False + self._decompress: typing.Callable[[bytes], bytes] + if hasattr(self.decompressor, "decompress"): + # The 'brotlicffi' package. + self._decompress = self.decompressor.decompress # pragma: no cover + else: + # The 'brotli' package. + self._decompress = self.decompressor.process # pragma: no cover + + def decode(self, data: bytes) -> bytes: + if not data: + return b"" + self.seen_data = True + try: + return self._decompress(data) + except brotli.error as exc: + raise DecodingError(str(exc)) from exc + + def flush(self) -> bytes: + if not self.seen_data: + return b"" + try: + if hasattr(self.decompressor, "finish"): + # Only available in the 'brotlicffi' package. + + # As the decompressor decompresses eagerly, this + # will never actually emit any data. However, it will potentially throw + # errors if a truncated or damaged data stream has been used. + self.decompressor.finish() # pragma: no cover + return b"" + except brotli.error as exc: # pragma: no cover + raise DecodingError(str(exc)) from exc + + +class MultiDecoder(ContentDecoder): + """ + Handle the case where multiple encodings have been applied. + """ + + def __init__(self, children: typing.Sequence[ContentDecoder]) -> None: + """ + 'children' should be a sequence of decoders in the order in which + each was applied. + """ + # Note that we reverse the order for decoding. + self.children = list(reversed(children)) + + def decode(self, data: bytes) -> bytes: + for child in self.children: + data = child.decode(data) + return data + + def flush(self) -> bytes: + data = b"" + for child in self.children: + data = child.decode(data) + child.flush() + return data + + +class ByteChunker: + """ + Handles returning byte content in fixed-size chunks. + """ + + def __init__(self, chunk_size: typing.Optional[int] = None) -> None: + self._buffer = io.BytesIO() + self._chunk_size = chunk_size + + def decode(self, content: bytes) -> typing.List[bytes]: + if self._chunk_size is None: + return [content] if content else [] + + self._buffer.write(content) + if self._buffer.tell() >= self._chunk_size: + value = self._buffer.getvalue() + chunks = [ + value[i : i + self._chunk_size] + for i in range(0, len(value), self._chunk_size) + ] + if len(chunks[-1]) == self._chunk_size: + self._buffer.seek(0) + self._buffer.truncate() + return chunks + else: + self._buffer.seek(0) + self._buffer.write(chunks[-1]) + self._buffer.truncate() + return chunks[:-1] + else: + return [] + + def flush(self) -> typing.List[bytes]: + value = self._buffer.getvalue() + self._buffer.seek(0) + self._buffer.truncate() + return [value] if value else [] + + +class TextChunker: + """ + Handles returning text content in fixed-size chunks. + """ + + def __init__(self, chunk_size: typing.Optional[int] = None) -> None: + self._buffer = io.StringIO() + self._chunk_size = chunk_size + + def decode(self, content: str) -> typing.List[str]: + if self._chunk_size is None: + return [content] + + self._buffer.write(content) + if self._buffer.tell() >= self._chunk_size: + value = self._buffer.getvalue() + chunks = [ + value[i : i + self._chunk_size] + for i in range(0, len(value), self._chunk_size) + ] + if len(chunks[-1]) == self._chunk_size: + self._buffer.seek(0) + self._buffer.truncate() + return chunks + else: + self._buffer.seek(0) + self._buffer.write(chunks[-1]) + self._buffer.truncate() + return chunks[:-1] + else: + return [] + + def flush(self) -> typing.List[str]: + value = self._buffer.getvalue() + self._buffer.seek(0) + self._buffer.truncate() + return [value] if value else [] + + +class TextDecoder: + """ + Handles incrementally decoding bytes into text + """ + + def __init__(self, encoding: str = "utf-8"): + self.decoder = codecs.getincrementaldecoder(encoding)(errors="replace") + + def decode(self, data: bytes) -> str: + return self.decoder.decode(data) + + def flush(self) -> str: + return self.decoder.decode(b"", True) + + +class LineDecoder: + """ + Handles incrementally reading lines from text. + + Has the same behaviour as the stdllib splitlines, but handling the input iteratively. + """ + + def __init__(self) -> None: + self.buffer: typing.List[str] = [] + self.trailing_cr: bool = False + + def decode(self, text: str) -> typing.List[str]: + # See https://docs.python.org/3/library/stdtypes.html#str.splitlines + NEWLINE_CHARS = "\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029" + + # We always push a trailing `\r` into the next decode iteration. + if self.trailing_cr: + text = "\r" + text + self.trailing_cr = False + if text.endswith("\r"): + self.trailing_cr = True + text = text[:-1] + + if not text: + return [] + + trailing_newline = text[-1] in NEWLINE_CHARS + lines = text.splitlines() + + if len(lines) == 1 and not trailing_newline: + # No new lines, buffer the input and continue. + self.buffer.append(lines[0]) + return [] + + if self.buffer: + # Include any existing buffer in the first portion of the + # splitlines result. + lines = ["".join(self.buffer) + lines[0]] + lines[1:] + self.buffer = [] + + if not trailing_newline: + # If the last segment of splitlines is not newline terminated, + # then drop it from our output and start a new buffer. + self.buffer = [lines.pop()] + + return lines + + def flush(self) -> typing.List[str]: + if not self.buffer and not self.trailing_cr: + return [] + + lines = ["".join(self.buffer)] + self.buffer = [] + self.trailing_cr = False + return lines + + +SUPPORTED_DECODERS = { + "identity": IdentityDecoder, + "gzip": GZipDecoder, + "deflate": DeflateDecoder, + "br": BrotliDecoder, +} + + +if brotli is None: + SUPPORTED_DECODERS.pop("br") # pragma: no cover diff --git a/contrib/python/httpx/httpx/_exceptions.py b/contrib/python/httpx/httpx/_exceptions.py new file mode 100644 index 0000000000..24a4f8aba3 --- /dev/null +++ b/contrib/python/httpx/httpx/_exceptions.py @@ -0,0 +1,343 @@ +""" +Our exception hierarchy: + +* HTTPError + x RequestError + + TransportError + - TimeoutException + · ConnectTimeout + · ReadTimeout + · WriteTimeout + · PoolTimeout + - NetworkError + · ConnectError + · ReadError + · WriteError + · CloseError + - ProtocolError + · LocalProtocolError + · RemoteProtocolError + - ProxyError + - UnsupportedProtocol + + DecodingError + + TooManyRedirects + x HTTPStatusError +* InvalidURL +* CookieConflict +* StreamError + x StreamConsumed + x StreamClosed + x ResponseNotRead + x RequestNotRead +""" +import contextlib +import typing + +if typing.TYPE_CHECKING: + from ._models import Request, Response # pragma: no cover + + +class HTTPError(Exception): + """ + Base class for `RequestError` and `HTTPStatusError`. + + Useful for `try...except` blocks when issuing a request, + and then calling `.raise_for_status()`. + + For example: + + ``` + try: + response = httpx.get("https://www.example.com") + response.raise_for_status() + except httpx.HTTPError as exc: + print(f"HTTP Exception for {exc.request.url} - {exc}") + ``` + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + self._request: typing.Optional["Request"] = None + + @property + def request(self) -> "Request": + if self._request is None: + raise RuntimeError("The .request property has not been set.") + return self._request + + @request.setter + def request(self, request: "Request") -> None: + self._request = request + + +class RequestError(HTTPError): + """ + Base class for all exceptions that may occur when issuing a `.request()`. + """ + + def __init__( + self, message: str, *, request: typing.Optional["Request"] = None + ) -> None: + super().__init__(message) + # At the point an exception is raised we won't typically have a request + # instance to associate it with. + # + # The 'request_context' context manager is used within the Client and + # Response methods in order to ensure that any raised exceptions + # have a `.request` property set on them. + self._request = request + + +class TransportError(RequestError): + """ + Base class for all exceptions that occur at the level of the Transport API. + """ + + +# Timeout exceptions... + + +class TimeoutException(TransportError): + """ + The base class for timeout errors. + + An operation has timed out. + """ + + +class ConnectTimeout(TimeoutException): + """ + Timed out while connecting to the host. + """ + + +class ReadTimeout(TimeoutException): + """ + Timed out while receiving data from the host. + """ + + +class WriteTimeout(TimeoutException): + """ + Timed out while sending data to the host. + """ + + +class PoolTimeout(TimeoutException): + """ + Timed out waiting to acquire a connection from the pool. + """ + + +# Core networking exceptions... + + +class NetworkError(TransportError): + """ + The base class for network-related errors. + + An error occurred while interacting with the network. + """ + + +class ReadError(NetworkError): + """ + Failed to receive data from the network. + """ + + +class WriteError(NetworkError): + """ + Failed to send data through the network. + """ + + +class ConnectError(NetworkError): + """ + Failed to establish a connection. + """ + + +class CloseError(NetworkError): + """ + Failed to close a connection. + """ + + +# Other transport exceptions... + + +class ProxyError(TransportError): + """ + An error occurred while establishing a proxy connection. + """ + + +class UnsupportedProtocol(TransportError): + """ + Attempted to make a request to an unsupported protocol. + + For example issuing a request to `ftp://www.example.com`. + """ + + +class ProtocolError(TransportError): + """ + The protocol was violated. + """ + + +class LocalProtocolError(ProtocolError): + """ + A protocol was violated by the client. + + For example if the user instantiated a `Request` instance explicitly, + failed to include the mandatory `Host:` header, and then issued it directly + using `client.send()`. + """ + + +class RemoteProtocolError(ProtocolError): + """ + The protocol was violated by the server. + + For example, returning malformed HTTP. + """ + + +# Other request exceptions... + + +class DecodingError(RequestError): + """ + Decoding of the response failed, due to a malformed encoding. + """ + + +class TooManyRedirects(RequestError): + """ + Too many redirects. + """ + + +# Client errors + + +class HTTPStatusError(HTTPError): + """ + The response had an error HTTP status of 4xx or 5xx. + + May be raised when calling `response.raise_for_status()` + """ + + def __init__( + self, message: str, *, request: "Request", response: "Response" + ) -> None: + super().__init__(message) + self.request = request + self.response = response + + +class InvalidURL(Exception): + """ + URL is improperly formed or cannot be parsed. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class CookieConflict(Exception): + """ + Attempted to lookup a cookie by name, but multiple cookies existed. + + Can occur when calling `response.cookies.get(...)`. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +# Stream exceptions... + +# These may occur as the result of a programming error, by accessing +# the request/response stream in an invalid manner. + + +class StreamError(RuntimeError): + """ + The base class for stream exceptions. + + The developer made an error in accessing the request stream in + an invalid way. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class StreamConsumed(StreamError): + """ + Attempted to read or stream content, but the content has already + been streamed. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream some content, but the content has " + "already been streamed. For requests, this could be due to passing " + "a generator as request content, and then receiving a redirect " + "response or a secondary request as part of an authentication flow." + "For responses, this could be due to attempting to stream the response " + "content more than once." + ) + super().__init__(message) + + +class StreamClosed(StreamError): + """ + Attempted to read or stream response content, but the request has been + closed. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream content, but the stream has " "been closed." + ) + super().__init__(message) + + +class ResponseNotRead(StreamError): + """ + Attempted to access streaming response content, without having called `read()`. + """ + + def __init__(self) -> None: + message = "Attempted to access streaming response content, without having called `read()`." + super().__init__(message) + + +class RequestNotRead(StreamError): + """ + Attempted to access streaming request content, without having called `read()`. + """ + + def __init__(self) -> None: + message = "Attempted to access streaming request content, without having called `read()`." + super().__init__(message) + + +@contextlib.contextmanager +def request_context( + request: typing.Optional["Request"] = None, +) -> typing.Iterator[None]: + """ + A context manager that can be used to attach the given request context + to any `RequestError` exceptions that are raised within the block. + """ + try: + yield + except RequestError as exc: + if request is not None: + exc.request = request + raise exc diff --git a/contrib/python/httpx/httpx/_main.py b/contrib/python/httpx/httpx/_main.py new file mode 100644 index 0000000000..7c12ce841d --- /dev/null +++ b/contrib/python/httpx/httpx/_main.py @@ -0,0 +1,506 @@ +import functools +import json +import sys +import typing + +import click +import httpcore +import pygments.lexers +import pygments.util +import rich.console +import rich.markup +import rich.progress +import rich.syntax +import rich.table + +from ._client import Client +from ._exceptions import RequestError +from ._models import Response +from ._status_codes import codes + + +def print_help() -> None: + console = rich.console.Console() + + console.print("[bold]HTTPX :butterfly:", justify="center") + console.print() + console.print("A next generation HTTP client.", justify="center") + console.print() + console.print( + "Usage: [bold]httpx[/bold] [cyan]<URL> [OPTIONS][/cyan] ", justify="left" + ) + console.print() + + table = rich.table.Table.grid(padding=1, pad_edge=True) + table.add_column("Parameter", no_wrap=True, justify="left", style="bold") + table.add_column("Description") + table.add_row( + "-m, --method [cyan]METHOD", + "Request method, such as GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD.\n" + "[Default: GET, or POST if a request body is included]", + ) + table.add_row( + "-p, --params [cyan]<NAME VALUE> ...", + "Query parameters to include in the request URL.", + ) + table.add_row( + "-c, --content [cyan]TEXT", "Byte content to include in the request body." + ) + table.add_row( + "-d, --data [cyan]<NAME VALUE> ...", "Form data to include in the request body." + ) + table.add_row( + "-f, --files [cyan]<NAME FILENAME> ...", + "Form files to include in the request body.", + ) + table.add_row("-j, --json [cyan]TEXT", "JSON data to include in the request body.") + table.add_row( + "-h, --headers [cyan]<NAME VALUE> ...", + "Include additional HTTP headers in the request.", + ) + table.add_row( + "--cookies [cyan]<NAME VALUE> ...", "Cookies to include in the request." + ) + table.add_row( + "--auth [cyan]<USER PASS>", + "Username and password to include in the request. Specify '-' for the password to use " + "a password prompt. Note that using --verbose/-v will expose the Authorization " + "header, including the password encoding in a trivially reversible format.", + ) + + table.add_row( + "--proxies [cyan]URL", + "Send the request via a proxy. Should be the URL giving the proxy address.", + ) + + table.add_row( + "--timeout [cyan]FLOAT", + "Timeout value to use for network operations, such as establishing the connection, " + "reading some data, etc... [Default: 5.0]", + ) + + table.add_row("--follow-redirects", "Automatically follow redirects.") + table.add_row("--no-verify", "Disable SSL verification.") + table.add_row( + "--http2", "Send the request using HTTP/2, if the remote server supports it." + ) + + table.add_row( + "--download [cyan]FILE", + "Save the response content as a file, rather than displaying it.", + ) + + table.add_row("-v, --verbose", "Verbose output. Show request as well as response.") + table.add_row("--help", "Show this message and exit.") + console.print(table) + + +def get_lexer_for_response(response: Response) -> str: + content_type = response.headers.get("Content-Type") + if content_type is not None: + mime_type, _, _ = content_type.partition(";") + try: + return typing.cast( + str, pygments.lexers.get_lexer_for_mimetype(mime_type.strip()).name + ) + except pygments.util.ClassNotFound: # pragma: no cover + pass + return "" # pragma: no cover + + +def format_request_headers(request: httpcore.Request, http2: bool = False) -> str: + version = "HTTP/2" if http2 else "HTTP/1.1" + headers = [ + (name.lower() if http2 else name, value) for name, value in request.headers + ] + method = request.method.decode("ascii") + target = request.url.target.decode("ascii") + lines = [f"{method} {target} {version}"] + [ + f"{name.decode('ascii')}: {value.decode('ascii')}" for name, value in headers + ] + return "\n".join(lines) + + +def format_response_headers( + http_version: bytes, + status: int, + reason_phrase: typing.Optional[bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], +) -> str: + version = http_version.decode("ascii") + reason = ( + codes.get_reason_phrase(status) + if reason_phrase is None + else reason_phrase.decode("ascii") + ) + lines = [f"{version} {status} {reason}"] + [ + f"{name.decode('ascii')}: {value.decode('ascii')}" for name, value in headers + ] + return "\n".join(lines) + + +def print_request_headers(request: httpcore.Request, http2: bool = False) -> None: + console = rich.console.Console() + http_text = format_request_headers(request, http2=http2) + syntax = rich.syntax.Syntax(http_text, "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + syntax = rich.syntax.Syntax("", "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + + +def print_response_headers( + http_version: bytes, + status: int, + reason_phrase: typing.Optional[bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], +) -> None: + console = rich.console.Console() + http_text = format_response_headers(http_version, status, reason_phrase, headers) + syntax = rich.syntax.Syntax(http_text, "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + syntax = rich.syntax.Syntax("", "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + + +def print_response(response: Response) -> None: + console = rich.console.Console() + lexer_name = get_lexer_for_response(response) + if lexer_name: + if lexer_name.lower() == "json": + try: + data = response.json() + text = json.dumps(data, indent=4) + except ValueError: # pragma: no cover + text = response.text + else: + text = response.text + + syntax = rich.syntax.Syntax(text, lexer_name, theme="ansi_dark", word_wrap=True) + console.print(syntax) + else: + console.print(f"<{len(response.content)} bytes of binary data>") + + +_PCTRTT = typing.Tuple[typing.Tuple[str, str], ...] +_PCTRTTT = typing.Tuple[_PCTRTT, ...] +_PeerCertRetDictType = typing.Dict[str, typing.Union[str, _PCTRTTT, _PCTRTT]] + + +def format_certificate(cert: _PeerCertRetDictType) -> str: # pragma: no cover + lines = [] + for key, value in cert.items(): + if isinstance(value, (list, tuple)): + lines.append(f"* {key}:") + for item in value: + if key in ("subject", "issuer"): + for sub_item in item: + lines.append(f"* {sub_item[0]}: {sub_item[1]!r}") + elif isinstance(item, tuple) and len(item) == 2: + lines.append(f"* {item[0]}: {item[1]!r}") + else: + lines.append(f"* {item!r}") + else: + lines.append(f"* {key}: {value!r}") + return "\n".join(lines) + + +def trace( + name: str, info: typing.Mapping[str, typing.Any], verbose: bool = False +) -> None: + console = rich.console.Console() + if name == "connection.connect_tcp.started" and verbose: + host = info["host"] + console.print(f"* Connecting to {host!r}") + elif name == "connection.connect_tcp.complete" and verbose: + stream = info["return_value"] + server_addr = stream.get_extra_info("server_addr") + console.print(f"* Connected to {server_addr[0]!r} on port {server_addr[1]}") + elif name == "connection.start_tls.complete" and verbose: # pragma: no cover + stream = info["return_value"] + ssl_object = stream.get_extra_info("ssl_object") + version = ssl_object.version() + cipher = ssl_object.cipher() + server_cert = ssl_object.getpeercert() + alpn = ssl_object.selected_alpn_protocol() + console.print(f"* SSL established using {version!r} / {cipher[0]!r}") + console.print(f"* Selected ALPN protocol: {alpn!r}") + if server_cert: + console.print("* Server certificate:") + console.print(format_certificate(server_cert)) + elif name == "http11.send_request_headers.started" and verbose: + request = info["request"] + print_request_headers(request, http2=False) + elif name == "http2.send_request_headers.started" and verbose: # pragma: no cover + request = info["request"] + print_request_headers(request, http2=True) + elif name == "http11.receive_response_headers.complete": + http_version, status, reason_phrase, headers = info["return_value"] + print_response_headers(http_version, status, reason_phrase, headers) + elif name == "http2.receive_response_headers.complete": # pragma: no cover + status, headers = info["return_value"] + http_version = b"HTTP/2" + reason_phrase = None + print_response_headers(http_version, status, reason_phrase, headers) + + +def download_response(response: Response, download: typing.BinaryIO) -> None: + console = rich.console.Console() + console.print() + content_length = response.headers.get("Content-Length") + with rich.progress.Progress( + "[progress.description]{task.description}", + "[progress.percentage]{task.percentage:>3.0f}%", + rich.progress.BarColumn(bar_width=None), + rich.progress.DownloadColumn(), + rich.progress.TransferSpeedColumn(), + ) as progress: + description = f"Downloading [bold]{rich.markup.escape(download.name)}" + download_task = progress.add_task( + description, + total=int(content_length or 0), + start=content_length is not None, + ) + for chunk in response.iter_bytes(): + download.write(chunk) + progress.update(download_task, completed=response.num_bytes_downloaded) + + +def validate_json( + ctx: click.Context, + param: typing.Union[click.Option, click.Parameter], + value: typing.Any, +) -> typing.Any: + if value is None: + return None + + try: + return json.loads(value) + except json.JSONDecodeError: # pragma: no cover + raise click.BadParameter("Not valid JSON") + + +def validate_auth( + ctx: click.Context, + param: typing.Union[click.Option, click.Parameter], + value: typing.Any, +) -> typing.Any: + if value == (None, None): + return None + + username, password = value + if password == "-": # pragma: no cover + password = click.prompt("Password", hide_input=True) + return (username, password) + + +def handle_help( + ctx: click.Context, + param: typing.Union[click.Option, click.Parameter], + value: typing.Any, +) -> None: + if not value or ctx.resilient_parsing: + return + + print_help() + ctx.exit() + + +@click.command(add_help_option=False) +@click.argument("url", type=str) +@click.option( + "--method", + "-m", + "method", + type=str, + help=( + "Request method, such as GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD. " + "[Default: GET, or POST if a request body is included]" + ), +) +@click.option( + "--params", + "-p", + "params", + type=(str, str), + multiple=True, + help="Query parameters to include in the request URL.", +) +@click.option( + "--content", + "-c", + "content", + type=str, + help="Byte content to include in the request body.", +) +@click.option( + "--data", + "-d", + "data", + type=(str, str), + multiple=True, + help="Form data to include in the request body.", +) +@click.option( + "--files", + "-f", + "files", + type=(str, click.File(mode="rb")), + multiple=True, + help="Form files to include in the request body.", +) +@click.option( + "--json", + "-j", + "json", + type=str, + callback=validate_json, + help="JSON data to include in the request body.", +) +@click.option( + "--headers", + "-h", + "headers", + type=(str, str), + multiple=True, + help="Include additional HTTP headers in the request.", +) +@click.option( + "--cookies", + "cookies", + type=(str, str), + multiple=True, + help="Cookies to include in the request.", +) +@click.option( + "--auth", + "auth", + type=(str, str), + default=(None, None), + callback=validate_auth, + help=( + "Username and password to include in the request. " + "Specify '-' for the password to use a password prompt. " + "Note that using --verbose/-v will expose the Authorization header, " + "including the password encoding in a trivially reversible format." + ), +) +@click.option( + "--proxies", + "proxies", + type=str, + default=None, + help="Send the request via a proxy. Should be the URL giving the proxy address.", +) +@click.option( + "--timeout", + "timeout", + type=float, + default=5.0, + help=( + "Timeout value to use for network operations, such as establishing the " + "connection, reading some data, etc... [Default: 5.0]" + ), +) +@click.option( + "--follow-redirects", + "follow_redirects", + is_flag=True, + default=False, + help="Automatically follow redirects.", +) +@click.option( + "--no-verify", + "verify", + is_flag=True, + default=True, + help="Disable SSL verification.", +) +@click.option( + "--http2", + "http2", + type=bool, + is_flag=True, + default=False, + help="Send the request using HTTP/2, if the remote server supports it.", +) +@click.option( + "--download", + type=click.File("wb"), + help="Save the response content as a file, rather than displaying it.", +) +@click.option( + "--verbose", + "-v", + type=bool, + is_flag=True, + default=False, + help="Verbose. Show request as well as response.", +) +@click.option( + "--help", + is_flag=True, + is_eager=True, + expose_value=False, + callback=handle_help, + help="Show this message and exit.", +) +def main( + url: str, + method: str, + params: typing.List[typing.Tuple[str, str]], + content: str, + data: typing.List[typing.Tuple[str, str]], + files: typing.List[typing.Tuple[str, click.File]], + json: str, + headers: typing.List[typing.Tuple[str, str]], + cookies: typing.List[typing.Tuple[str, str]], + auth: typing.Optional[typing.Tuple[str, str]], + proxies: str, + timeout: float, + follow_redirects: bool, + verify: bool, + http2: bool, + download: typing.Optional[typing.BinaryIO], + verbose: bool, +) -> None: + """ + An HTTP command line client. + Sends a request and displays the response. + """ + if not method: + method = "POST" if content or data or files or json else "GET" + + try: + with Client( + proxies=proxies, + timeout=timeout, + verify=verify, + http2=http2, + ) as client: + with client.stream( + method, + url, + params=list(params), + content=content, + data=dict(data), + files=files, # type: ignore + json=json, + headers=headers, + cookies=dict(cookies), + auth=auth, + follow_redirects=follow_redirects, + extensions={"trace": functools.partial(trace, verbose=verbose)}, + ) as response: + if download is not None: + download_response(response, download) + else: + response.read() + if response.content: + print_response(response) + + except RequestError as exc: + console = rich.console.Console() + console.print(f"[red]{type(exc).__name__}[/red]: {exc}") + sys.exit(1) + + sys.exit(0 if response.is_success else 1) diff --git a/contrib/python/httpx/httpx/_models.py b/contrib/python/httpx/httpx/_models.py new file mode 100644 index 0000000000..e1e45cf06b --- /dev/null +++ b/contrib/python/httpx/httpx/_models.py @@ -0,0 +1,1209 @@ +import datetime +import email.message +import json as jsonlib +import typing +import urllib.request +from collections.abc import Mapping +from http.cookiejar import Cookie, CookieJar + +from ._content import ByteStream, UnattachedStream, encode_request, encode_response +from ._decoders import ( + SUPPORTED_DECODERS, + ByteChunker, + ContentDecoder, + IdentityDecoder, + LineDecoder, + MultiDecoder, + TextChunker, + TextDecoder, +) +from ._exceptions import ( + CookieConflict, + HTTPStatusError, + RequestNotRead, + ResponseNotRead, + StreamClosed, + StreamConsumed, + request_context, +) +from ._multipart import get_multipart_boundary_from_content_type +from ._status_codes import codes +from ._types import ( + AsyncByteStream, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestExtensions, + RequestFiles, + ResponseContent, + ResponseExtensions, + SyncByteStream, +) +from ._urls import URL +from ._utils import ( + guess_json_utf, + is_known_encoding, + normalize_header_key, + normalize_header_value, + obfuscate_sensitive_headers, + parse_content_type_charset, + parse_header_links, +) + + +class Headers(typing.MutableMapping[str, str]): + """ + HTTP headers, as a case-insensitive multi-dict. + """ + + def __init__( + self, + headers: typing.Optional[HeaderTypes] = None, + encoding: typing.Optional[str] = None, + ) -> None: + if headers is None: + self._list = [] # type: typing.List[typing.Tuple[bytes, bytes, bytes]] + elif isinstance(headers, Headers): + self._list = list(headers._list) + elif isinstance(headers, Mapping): + self._list = [ + ( + normalize_header_key(k, lower=False, encoding=encoding), + normalize_header_key(k, lower=True, encoding=encoding), + normalize_header_value(v, encoding), + ) + for k, v in headers.items() + ] + else: + self._list = [ + ( + normalize_header_key(k, lower=False, encoding=encoding), + normalize_header_key(k, lower=True, encoding=encoding), + normalize_header_value(v, encoding), + ) + for k, v in headers + ] + + self._encoding = encoding + + @property + def encoding(self) -> str: + """ + Header encoding is mandated as ascii, but we allow fallbacks to utf-8 + or iso-8859-1. + """ + if self._encoding is None: + for encoding in ["ascii", "utf-8"]: + for key, value in self.raw: + try: + key.decode(encoding) + value.decode(encoding) + except UnicodeDecodeError: + break + else: + # The else block runs if 'break' did not occur, meaning + # all values fitted the encoding. + self._encoding = encoding + break + else: + # The ISO-8859-1 encoding covers all 256 code points in a byte, + # so will never raise decode errors. + self._encoding = "iso-8859-1" + return self._encoding + + @encoding.setter + def encoding(self, value: str) -> None: + self._encoding = value + + @property + def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: + """ + Returns a list of the raw header items, as byte pairs. + """ + return [(raw_key, value) for raw_key, _, value in self._list] + + def keys(self) -> typing.KeysView[str]: + return {key.decode(self.encoding): None for _, key, value in self._list}.keys() + + def values(self) -> typing.ValuesView[str]: + values_dict: typing.Dict[str, str] = {} + for _, key, value in self._list: + str_key = key.decode(self.encoding) + str_value = value.decode(self.encoding) + if str_key in values_dict: + values_dict[str_key] += f", {str_value}" + else: + values_dict[str_key] = str_value + return values_dict.values() + + def items(self) -> typing.ItemsView[str, str]: + """ + Return `(key, value)` items of headers. Concatenate headers + into a single comma separated value when a key occurs multiple times. + """ + values_dict: typing.Dict[str, str] = {} + for _, key, value in self._list: + str_key = key.decode(self.encoding) + str_value = value.decode(self.encoding) + if str_key in values_dict: + values_dict[str_key] += f", {str_value}" + else: + values_dict[str_key] = str_value + return values_dict.items() + + def multi_items(self) -> typing.List[typing.Tuple[str, str]]: + """ + Return a list of `(key, value)` pairs of headers. Allow multiple + occurrences of the same key without concatenating into a single + comma separated value. + """ + return [ + (key.decode(self.encoding), value.decode(self.encoding)) + for _, key, value in self._list + ] + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + """ + Return a header value. If multiple occurrences of the header occur + then concatenate them together with commas. + """ + try: + return self[key] + except KeyError: + return default + + def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]: + """ + Return a list of all header values for a given key. + If `split_commas=True` is passed, then any comma separated header + values are split into multiple return strings. + """ + get_header_key = key.lower().encode(self.encoding) + + values = [ + item_value.decode(self.encoding) + for _, item_key, item_value in self._list + if item_key.lower() == get_header_key + ] + + if not split_commas: + return values + + split_values = [] + for value in values: + split_values.extend([item.strip() for item in value.split(",")]) + return split_values + + def update(self, headers: typing.Optional[HeaderTypes] = None) -> None: # type: ignore + headers = Headers(headers) + for key in headers.keys(): + if key in self: + self.pop(key) + self._list.extend(headers._list) + + def copy(self) -> "Headers": + return Headers(self, encoding=self.encoding) + + def __getitem__(self, key: str) -> str: + """ + Return a single header value. + + If there are multiple headers with the same key, then we concatenate + them with commas. See: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + normalized_key = key.lower().encode(self.encoding) + + items = [ + header_value.decode(self.encoding) + for _, header_key, header_value in self._list + if header_key == normalized_key + ] + + if items: + return ", ".join(items) + + raise KeyError(key) + + def __setitem__(self, key: str, value: str) -> None: + """ + Set the header `key` to `value`, removing any duplicate entries. + Retains insertion order. + """ + set_key = key.encode(self._encoding or "utf-8") + set_value = value.encode(self._encoding or "utf-8") + lookup_key = set_key.lower() + + found_indexes = [ + idx + for idx, (_, item_key, _) in enumerate(self._list) + if item_key == lookup_key + ] + + for idx in reversed(found_indexes[1:]): + del self._list[idx] + + if found_indexes: + idx = found_indexes[0] + self._list[idx] = (set_key, lookup_key, set_value) + else: + self._list.append((set_key, lookup_key, set_value)) + + def __delitem__(self, key: str) -> None: + """ + Remove the header `key`. + """ + del_key = key.lower().encode(self.encoding) + + pop_indexes = [ + idx + for idx, (_, item_key, _) in enumerate(self._list) + if item_key.lower() == del_key + ] + + if not pop_indexes: + raise KeyError(key) + + for idx in reversed(pop_indexes): + del self._list[idx] + + def __contains__(self, key: typing.Any) -> bool: + header_key = key.lower().encode(self.encoding) + return header_key in [key for _, key, _ in self._list] + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._list) + + def __eq__(self, other: typing.Any) -> bool: + try: + other_headers = Headers(other) + except ValueError: + return False + + self_list = [(key, value) for _, key, value in self._list] + other_list = [(key, value) for _, key, value in other_headers._list] + return sorted(self_list) == sorted(other_list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + + encoding_str = "" + if self.encoding != "ascii": + encoding_str = f", encoding={self.encoding!r}" + + as_list = list(obfuscate_sensitive_headers(self.multi_items())) + as_dict = dict(as_list) + + no_duplicate_keys = len(as_dict) == len(as_list) + if no_duplicate_keys: + return f"{class_name}({as_dict!r}{encoding_str})" + return f"{class_name}({as_list!r}{encoding_str})" + + +class Request: + def __init__( + self, + method: typing.Union[str, bytes], + url: typing.Union["URL", str], + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + stream: typing.Union[SyncByteStream, AsyncByteStream, None] = None, + extensions: typing.Optional[RequestExtensions] = None, + ): + self.method = ( + method.decode("ascii").upper() + if isinstance(method, bytes) + else method.upper() + ) + self.url = URL(url) + if params is not None: + self.url = self.url.copy_merge_params(params=params) + self.headers = Headers(headers) + self.extensions = {} if extensions is None else extensions + + if cookies: + Cookies(cookies).set_cookie_header(self) + + if stream is None: + content_type: typing.Optional[str] = self.headers.get("content-type") + headers, stream = encode_request( + content=content, + data=data, + files=files, + json=json, + boundary=get_multipart_boundary_from_content_type( + content_type=content_type.encode(self.headers.encoding) + if content_type + else None + ), + ) + self._prepare(headers) + self.stream = stream + # Load the request body, except for streaming content. + if isinstance(stream, ByteStream): + self.read() + else: + # There's an important distinction between `Request(content=...)`, + # and `Request(stream=...)`. + # + # Using `content=...` implies automatically populated `Host` and content + # headers, of either `Content-Length: ...` or `Transfer-Encoding: chunked`. + # + # Using `stream=...` will not automatically include *any* auto-populated headers. + # + # As an end-user you don't really need `stream=...`. It's only + # useful when: + # + # * Preserving the request stream when copying requests, eg for redirects. + # * Creating request instances on the *server-side* of the transport API. + self.stream = stream + + def _prepare(self, default_headers: typing.Dict[str, str]) -> None: + for key, value in default_headers.items(): + # Ignore Transfer-Encoding if the Content-Length has been set explicitly. + if key.lower() == "transfer-encoding" and "Content-Length" in self.headers: + continue + self.headers.setdefault(key, value) + + auto_headers: typing.List[typing.Tuple[bytes, bytes]] = [] + + has_host = "Host" in self.headers + has_content_length = ( + "Content-Length" in self.headers or "Transfer-Encoding" in self.headers + ) + + if not has_host and self.url.host: + auto_headers.append((b"Host", self.url.netloc)) + if not has_content_length and self.method in ("POST", "PUT", "PATCH"): + auto_headers.append((b"Content-Length", b"0")) + + self.headers = Headers(auto_headers + self.headers.raw) + + @property + def content(self) -> bytes: + if not hasattr(self, "_content"): + raise RequestNotRead() + return self._content + + def read(self) -> bytes: + """ + Read and return the request content. + """ + if not hasattr(self, "_content"): + assert isinstance(self.stream, typing.Iterable) + self._content = b"".join(self.stream) + if not isinstance(self.stream, ByteStream): + # If a streaming request has been read entirely into memory, then + # we can replace the stream with a raw bytes implementation, + # to ensure that any non-replayable streams can still be used. + self.stream = ByteStream(self._content) + return self._content + + async def aread(self) -> bytes: + """ + Read and return the request content. + """ + if not hasattr(self, "_content"): + assert isinstance(self.stream, typing.AsyncIterable) + self._content = b"".join([part async for part in self.stream]) + if not isinstance(self.stream, ByteStream): + # If a streaming request has been read entirely into memory, then + # we can replace the stream with a raw bytes implementation, + # to ensure that any non-replayable streams can still be used. + self.stream = ByteStream(self._content) + return self._content + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + url = str(self.url) + return f"<{class_name}({self.method!r}, {url!r})>" + + def __getstate__(self) -> typing.Dict[str, typing.Any]: + return { + name: value + for name, value in self.__dict__.items() + if name not in ["extensions", "stream"] + } + + def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None: + for name, value in state.items(): + setattr(self, name, value) + self.extensions = {} + self.stream = UnattachedStream() + + +class Response: + def __init__( + self, + status_code: int, + *, + headers: typing.Optional[HeaderTypes] = None, + content: typing.Optional[ResponseContent] = None, + text: typing.Optional[str] = None, + html: typing.Optional[str] = None, + json: typing.Any = None, + stream: typing.Union[SyncByteStream, AsyncByteStream, None] = None, + request: typing.Optional[Request] = None, + extensions: typing.Optional[ResponseExtensions] = None, + history: typing.Optional[typing.List["Response"]] = None, + default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8", + ): + self.status_code = status_code + self.headers = Headers(headers) + + self._request: typing.Optional[Request] = request + + # When follow_redirects=False and a redirect is received, + # the client will set `response.next_request`. + self.next_request: typing.Optional[Request] = None + + self.extensions: ResponseExtensions = {} if extensions is None else extensions + self.history = [] if history is None else list(history) + + self.is_closed = False + self.is_stream_consumed = False + + self.default_encoding = default_encoding + + if stream is None: + headers, stream = encode_response(content, text, html, json) + self._prepare(headers) + self.stream = stream + if isinstance(stream, ByteStream): + # Load the response body, except for streaming content. + self.read() + else: + # There's an important distinction between `Response(content=...)`, + # and `Response(stream=...)`. + # + # Using `content=...` implies automatically populated content headers, + # of either `Content-Length: ...` or `Transfer-Encoding: chunked`. + # + # Using `stream=...` will not automatically include any content headers. + # + # As an end-user you don't really need `stream=...`. It's only + # useful when creating response instances having received a stream + # from the transport API. + self.stream = stream + + self._num_bytes_downloaded = 0 + + def _prepare(self, default_headers: typing.Dict[str, str]) -> None: + for key, value in default_headers.items(): + # Ignore Transfer-Encoding if the Content-Length has been set explicitly. + if key.lower() == "transfer-encoding" and "content-length" in self.headers: + continue + self.headers.setdefault(key, value) + + @property + def elapsed(self) -> datetime.timedelta: + """ + Returns the time taken for the complete request/response + cycle to complete. + """ + if not hasattr(self, "_elapsed"): + raise RuntimeError( + "'.elapsed' may only be accessed after the response " + "has been read or closed." + ) + return self._elapsed + + @elapsed.setter + def elapsed(self, elapsed: datetime.timedelta) -> None: + self._elapsed = elapsed + + @property + def request(self) -> Request: + """ + Returns the request instance associated to the current response. + """ + if self._request is None: + raise RuntimeError( + "The request instance has not been set on this response." + ) + return self._request + + @request.setter + def request(self, value: Request) -> None: + self._request = value + + @property + def http_version(self) -> str: + try: + http_version: bytes = self.extensions["http_version"] + except KeyError: + return "HTTP/1.1" + else: + return http_version.decode("ascii", errors="ignore") + + @property + def reason_phrase(self) -> str: + try: + reason_phrase: bytes = self.extensions["reason_phrase"] + except KeyError: + return codes.get_reason_phrase(self.status_code) + else: + return reason_phrase.decode("ascii", errors="ignore") + + @property + def url(self) -> URL: + """ + Returns the URL for which the request was made. + """ + return self.request.url + + @property + def content(self) -> bytes: + if not hasattr(self, "_content"): + raise ResponseNotRead() + return self._content + + @property + def text(self) -> str: + if not hasattr(self, "_text"): + content = self.content + if not content: + self._text = "" + else: + decoder = TextDecoder(encoding=self.encoding or "utf-8") + self._text = "".join([decoder.decode(self.content), decoder.flush()]) + return self._text + + @property + def encoding(self) -> typing.Optional[str]: + """ + Return an encoding to use for decoding the byte content into text. + The priority for determining this is given by... + + * `.encoding = <>` has been set explicitly. + * The encoding as specified by the charset parameter in the Content-Type header. + * The encoding as determined by `default_encoding`, which may either be + a string like "utf-8" indicating the encoding to use, or may be a callable + which enables charset autodetection. + """ + if not hasattr(self, "_encoding"): + encoding = self.charset_encoding + if encoding is None or not is_known_encoding(encoding): + if isinstance(self.default_encoding, str): + encoding = self.default_encoding + elif hasattr(self, "_content"): + encoding = self.default_encoding(self._content) + self._encoding = encoding or "utf-8" + return self._encoding + + @encoding.setter + def encoding(self, value: str) -> None: + self._encoding = value + + @property + def charset_encoding(self) -> typing.Optional[str]: + """ + Return the encoding, as specified by the Content-Type header. + """ + content_type = self.headers.get("Content-Type") + if content_type is None: + return None + + return parse_content_type_charset(content_type) + + def _get_content_decoder(self) -> ContentDecoder: + """ + Returns a decoder instance which can be used to decode the raw byte + content, depending on the Content-Encoding used in the response. + """ + if not hasattr(self, "_decoder"): + decoders: typing.List[ContentDecoder] = [] + values = self.headers.get_list("content-encoding", split_commas=True) + for value in values: + value = value.strip().lower() + try: + decoder_cls = SUPPORTED_DECODERS[value] + decoders.append(decoder_cls()) + except KeyError: + continue + + if len(decoders) == 1: + self._decoder = decoders[0] + elif len(decoders) > 1: + self._decoder = MultiDecoder(children=decoders) + else: + self._decoder = IdentityDecoder() + + return self._decoder + + @property + def is_informational(self) -> bool: + """ + A property which is `True` for 1xx status codes, `False` otherwise. + """ + return codes.is_informational(self.status_code) + + @property + def is_success(self) -> bool: + """ + A property which is `True` for 2xx status codes, `False` otherwise. + """ + return codes.is_success(self.status_code) + + @property + def is_redirect(self) -> bool: + """ + A property which is `True` for 3xx status codes, `False` otherwise. + + Note that not all responses with a 3xx status code indicate a URL redirect. + + Use `response.has_redirect_location` to determine responses with a properly + formed URL redirection. + """ + return codes.is_redirect(self.status_code) + + @property + def is_client_error(self) -> bool: + """ + A property which is `True` for 4xx status codes, `False` otherwise. + """ + return codes.is_client_error(self.status_code) + + @property + def is_server_error(self) -> bool: + """ + A property which is `True` for 5xx status codes, `False` otherwise. + """ + return codes.is_server_error(self.status_code) + + @property + def is_error(self) -> bool: + """ + A property which is `True` for 4xx and 5xx status codes, `False` otherwise. + """ + return codes.is_error(self.status_code) + + @property + def has_redirect_location(self) -> bool: + """ + Returns True for 3xx responses with a properly formed URL redirection, + `False` otherwise. + """ + return ( + self.status_code + in ( + # 301 (Cacheable redirect. Method may change to GET.) + codes.MOVED_PERMANENTLY, + # 302 (Uncacheable redirect. Method may change to GET.) + codes.FOUND, + # 303 (Client should make a GET or HEAD request.) + codes.SEE_OTHER, + # 307 (Equiv. 302, but retain method) + codes.TEMPORARY_REDIRECT, + # 308 (Equiv. 301, but retain method) + codes.PERMANENT_REDIRECT, + ) + and "Location" in self.headers + ) + + def raise_for_status(self) -> "Response": + """ + Raise the `HTTPStatusError` if one occurred. + """ + request = self._request + if request is None: + raise RuntimeError( + "Cannot call `raise_for_status` as the request " + "instance has not been set on this response." + ) + + if self.is_success: + return self + + if self.has_redirect_location: + message = ( + "{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n" + "Redirect location: '{0.headers[location]}'\n" + "For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}" + ) + else: + message = ( + "{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n" + "For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}" + ) + + status_class = self.status_code // 100 + error_types = { + 1: "Informational response", + 3: "Redirect response", + 4: "Client error", + 5: "Server error", + } + error_type = error_types.get(status_class, "Invalid status code") + message = message.format(self, error_type=error_type) + raise HTTPStatusError(message, request=request, response=self) + + def json(self, **kwargs: typing.Any) -> typing.Any: + if self.charset_encoding is None and self.content and len(self.content) > 3: + encoding = guess_json_utf(self.content) + if encoding is not None: + return jsonlib.loads(self.content.decode(encoding), **kwargs) + return jsonlib.loads(self.text, **kwargs) + + @property + def cookies(self) -> "Cookies": + if not hasattr(self, "_cookies"): + self._cookies = Cookies() + self._cookies.extract_cookies(self) + return self._cookies + + @property + def links(self) -> typing.Dict[typing.Optional[str], typing.Dict[str, str]]: + """ + Returns the parsed header links of the response, if any + """ + header = self.headers.get("link") + ldict = {} + if header: + links = parse_header_links(header) + for link in links: + key = link.get("rel") or link.get("url") + ldict[key] = link + return ldict + + @property + def num_bytes_downloaded(self) -> int: + return self._num_bytes_downloaded + + def __repr__(self) -> str: + return f"<Response [{self.status_code} {self.reason_phrase}]>" + + def __getstate__(self) -> typing.Dict[str, typing.Any]: + return { + name: value + for name, value in self.__dict__.items() + if name not in ["extensions", "stream", "is_closed", "_decoder"] + } + + def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None: + for name, value in state.items(): + setattr(self, name, value) + self.is_closed = True + self.extensions = {} + self.stream = UnattachedStream() + + def read(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "_content"): + self._content = b"".join(self.iter_bytes()) + return self._content + + def iter_bytes( + self, chunk_size: typing.Optional[int] = None + ) -> typing.Iterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. + """ + if hasattr(self, "_content"): + chunk_size = len(self._content) if chunk_size is None else chunk_size + for i in range(0, len(self._content), max(chunk_size, 1)): + yield self._content[i : i + chunk_size] + else: + decoder = self._get_content_decoder() + chunker = ByteChunker(chunk_size=chunk_size) + with request_context(request=self._request): + for raw_bytes in self.iter_raw(): + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() + for chunk in chunker.decode(decoded): + yield chunk # pragma: no cover + for chunk in chunker.flush(): + yield chunk + + def iter_text( + self, chunk_size: typing.Optional[int] = None + ) -> typing.Iterator[str]: + """ + A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + decoder = TextDecoder(encoding=self.encoding or "utf-8") + chunker = TextChunker(chunk_size=chunk_size) + with request_context(request=self._request): + for byte_content in self.iter_bytes(): + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk + text_content = decoder.flush() + for chunk in chunker.decode(text_content): + yield chunk + for chunk in chunker.flush(): + yield chunk + + def iter_lines(self) -> typing.Iterator[str]: + decoder = LineDecoder() + with request_context(request=self._request): + for text in self.iter_text(): + for line in decoder.decode(text): + yield line + for line in decoder.flush(): + yield line + + def iter_raw( + self, chunk_size: typing.Optional[int] = None + ) -> typing.Iterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise StreamClosed() + if not isinstance(self.stream, SyncByteStream): + raise RuntimeError("Attempted to call a sync iterator on an async stream.") + + self.is_stream_consumed = True + self._num_bytes_downloaded = 0 + chunker = ByteChunker(chunk_size=chunk_size) + + with request_context(request=self._request): + for raw_stream_bytes in self.stream: + self._num_bytes_downloaded += len(raw_stream_bytes) + for chunk in chunker.decode(raw_stream_bytes): + yield chunk + + for chunk in chunker.flush(): + yield chunk + + self.close() + + def close(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + if not isinstance(self.stream, SyncByteStream): + raise RuntimeError("Attempted to call an sync close on an async stream.") + + if not self.is_closed: + self.is_closed = True + with request_context(request=self._request): + self.stream.close() + + async def aread(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "_content"): + self._content = b"".join([part async for part in self.aiter_bytes()]) + return self._content + + async def aiter_bytes( + self, chunk_size: typing.Optional[int] = None + ) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. + """ + if hasattr(self, "_content"): + chunk_size = len(self._content) if chunk_size is None else chunk_size + for i in range(0, len(self._content), max(chunk_size, 1)): + yield self._content[i : i + chunk_size] + else: + decoder = self._get_content_decoder() + chunker = ByteChunker(chunk_size=chunk_size) + with request_context(request=self._request): + async for raw_bytes in self.aiter_raw(): + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() + for chunk in chunker.decode(decoded): + yield chunk # pragma: no cover + for chunk in chunker.flush(): + yield chunk + + async def aiter_text( + self, chunk_size: typing.Optional[int] = None + ) -> typing.AsyncIterator[str]: + """ + A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + decoder = TextDecoder(encoding=self.encoding or "utf-8") + chunker = TextChunker(chunk_size=chunk_size) + with request_context(request=self._request): + async for byte_content in self.aiter_bytes(): + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk + text_content = decoder.flush() + for chunk in chunker.decode(text_content): + yield chunk + for chunk in chunker.flush(): + yield chunk + + async def aiter_lines(self) -> typing.AsyncIterator[str]: + decoder = LineDecoder() + with request_context(request=self._request): + async for text in self.aiter_text(): + for line in decoder.decode(text): + yield line + for line in decoder.flush(): + yield line + + async def aiter_raw( + self, chunk_size: typing.Optional[int] = None + ) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise StreamClosed() + if not isinstance(self.stream, AsyncByteStream): + raise RuntimeError("Attempted to call an async iterator on an sync stream.") + + self.is_stream_consumed = True + self._num_bytes_downloaded = 0 + chunker = ByteChunker(chunk_size=chunk_size) + + with request_context(request=self._request): + async for raw_stream_bytes in self.stream: + self._num_bytes_downloaded += len(raw_stream_bytes) + for chunk in chunker.decode(raw_stream_bytes): + yield chunk + + for chunk in chunker.flush(): + yield chunk + + await self.aclose() + + async def aclose(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + if not isinstance(self.stream, AsyncByteStream): + raise RuntimeError("Attempted to call an async close on an sync stream.") + + if not self.is_closed: + self.is_closed = True + with request_context(request=self._request): + await self.stream.aclose() + + +class Cookies(typing.MutableMapping[str, str]): + """ + HTTP Cookies, as a mutable mapping. + """ + + def __init__(self, cookies: typing.Optional[CookieTypes] = None) -> None: + if cookies is None or isinstance(cookies, dict): + self.jar = CookieJar() + if isinstance(cookies, dict): + for key, value in cookies.items(): + self.set(key, value) + elif isinstance(cookies, list): + self.jar = CookieJar() + for key, value in cookies: + self.set(key, value) + elif isinstance(cookies, Cookies): + self.jar = CookieJar() + for cookie in cookies.jar: + self.jar.set_cookie(cookie) + else: + self.jar = cookies + + def extract_cookies(self, response: Response) -> None: + """ + Loads any cookies based on the response `Set-Cookie` headers. + """ + urllib_response = self._CookieCompatResponse(response) + urllib_request = self._CookieCompatRequest(response.request) + + self.jar.extract_cookies(urllib_response, urllib_request) # type: ignore + + def set_cookie_header(self, request: Request) -> None: + """ + Sets an appropriate 'Cookie:' HTTP header on the `Request`. + """ + urllib_request = self._CookieCompatRequest(request) + self.jar.add_cookie_header(urllib_request) + + def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None: + """ + Set a cookie value by name. May optionally include domain and path. + """ + kwargs = { + "version": 0, + "name": name, + "value": value, + "port": None, + "port_specified": False, + "domain": domain, + "domain_specified": bool(domain), + "domain_initial_dot": domain.startswith("."), + "path": path, + "path_specified": bool(path), + "secure": False, + "expires": None, + "discard": True, + "comment": None, + "comment_url": None, + "rest": {"HttpOnly": None}, + "rfc2109": False, + } + cookie = Cookie(**kwargs) # type: ignore + self.jar.set_cookie(cookie) + + def get( # type: ignore + self, + name: str, + default: typing.Optional[str] = None, + domain: typing.Optional[str] = None, + path: typing.Optional[str] = None, + ) -> typing.Optional[str]: + """ + Get a cookie by name. May optionally include domain and path + in order to specify exactly which cookie to retrieve. + """ + value = None + for cookie in self.jar: + if cookie.name == name: + if domain is None or cookie.domain == domain: + if path is None or cookie.path == path: + if value is not None: + message = f"Multiple cookies exist with name={name}" + raise CookieConflict(message) + value = cookie.value + + if value is None: + return default + return value + + def delete( + self, + name: str, + domain: typing.Optional[str] = None, + path: typing.Optional[str] = None, + ) -> None: + """ + Delete a cookie by name. May optionally include domain and path + in order to specify exactly which cookie to delete. + """ + if domain is not None and path is not None: + return self.jar.clear(domain, path, name) + + remove = [ + cookie + for cookie in self.jar + if cookie.name == name + and (domain is None or cookie.domain == domain) + and (path is None or cookie.path == path) + ] + + for cookie in remove: + self.jar.clear(cookie.domain, cookie.path, cookie.name) + + def clear( + self, domain: typing.Optional[str] = None, path: typing.Optional[str] = None + ) -> None: + """ + Delete all cookies. Optionally include a domain and path in + order to only delete a subset of all the cookies. + """ + args = [] + if domain is not None: + args.append(domain) + if path is not None: + assert domain is not None + args.append(path) + self.jar.clear(*args) + + def update(self, cookies: typing.Optional[CookieTypes] = None) -> None: # type: ignore + cookies = Cookies(cookies) + for cookie in cookies.jar: + self.jar.set_cookie(cookie) + + def __setitem__(self, name: str, value: str) -> None: + return self.set(name, value) + + def __getitem__(self, name: str) -> str: + value = self.get(name) + if value is None: + raise KeyError(name) + return value + + def __delitem__(self, name: str) -> None: + return self.delete(name) + + def __len__(self) -> int: + return len(self.jar) + + def __iter__(self) -> typing.Iterator[str]: + return (cookie.name for cookie in self.jar) + + def __bool__(self) -> bool: + for _ in self.jar: + return True + return False + + def __repr__(self) -> str: + cookies_repr = ", ".join( + [ + f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />" + for cookie in self.jar + ] + ) + + return f"<Cookies[{cookies_repr}]>" + + class _CookieCompatRequest(urllib.request.Request): + """ + Wraps a `Request` instance up in a compatibility interface suitable + for use with `CookieJar` operations. + """ + + def __init__(self, request: Request) -> None: + super().__init__( + url=str(request.url), + headers=dict(request.headers), + method=request.method, + ) + self.request = request + + def add_unredirected_header(self, key: str, value: str) -> None: + super().add_unredirected_header(key, value) + self.request.headers[key] = value + + class _CookieCompatResponse: + """ + Wraps a `Request` instance up in a compatibility interface suitable + for use with `CookieJar` operations. + """ + + def __init__(self, response: Response): + self.response = response + + def info(self) -> email.message.Message: + info = email.message.Message() + for key, value in self.response.headers.multi_items(): + # Note that setting `info[key]` here is an "append" operation, + # not a "replace" operation. + # https://docs.python.org/3/library/email.compat32-message.html#email.message.Message.__setitem__ + info[key] = value + return info diff --git a/contrib/python/httpx/httpx/_multipart.py b/contrib/python/httpx/httpx/_multipart.py new file mode 100644 index 0000000000..446f4ad2df --- /dev/null +++ b/contrib/python/httpx/httpx/_multipart.py @@ -0,0 +1,267 @@ +import binascii +import io +import os +import typing +from pathlib import Path + +from ._types import ( + AsyncByteStream, + FileContent, + FileTypes, + RequestData, + RequestFiles, + SyncByteStream, +) +from ._utils import ( + format_form_param, + guess_content_type, + peek_filelike_length, + primitive_value_to_str, + to_bytes, +) + + +def get_multipart_boundary_from_content_type( + content_type: typing.Optional[bytes], +) -> typing.Optional[bytes]: + if not content_type or not content_type.startswith(b"multipart/form-data"): + return None + # parse boundary according to + # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1 + if b";" in content_type: + for section in content_type.split(b";"): + if section.strip().lower().startswith(b"boundary="): + return section.strip()[len(b"boundary=") :].strip(b'"') + return None + + +class DataField: + """ + A single form field item, within a multipart form field. + """ + + def __init__( + self, name: str, value: typing.Union[str, bytes, int, float, None] + ) -> None: + if not isinstance(name, str): + raise TypeError( + f"Invalid type for name. Expected str, got {type(name)}: {name!r}" + ) + if value is not None and not isinstance(value, (str, bytes, int, float)): + raise TypeError( + f"Invalid type for value. Expected primitive type, got {type(value)}: {value!r}" + ) + self.name = name + self.value: typing.Union[str, bytes] = ( + value if isinstance(value, bytes) else primitive_value_to_str(value) + ) + + def render_headers(self) -> bytes: + if not hasattr(self, "_headers"): + name = format_form_param("name", self.name) + self._headers = b"".join( + [b"Content-Disposition: form-data; ", name, b"\r\n\r\n"] + ) + + return self._headers + + def render_data(self) -> bytes: + if not hasattr(self, "_data"): + self._data = to_bytes(self.value) + + return self._data + + def get_length(self) -> int: + headers = self.render_headers() + data = self.render_data() + return len(headers) + len(data) + + def render(self) -> typing.Iterator[bytes]: + yield self.render_headers() + yield self.render_data() + + +class FileField: + """ + A single file field item, within a multipart form field. + """ + + CHUNK_SIZE = 64 * 1024 + + def __init__(self, name: str, value: FileTypes) -> None: + self.name = name + + fileobj: FileContent + + headers: typing.Dict[str, str] = {} + content_type: typing.Optional[str] = None + + # This large tuple based API largely mirror's requests' API + # It would be good to think of better APIs for this that we could include in httpx 2.0 + # since variable length tuples (especially of 4 elements) are quite unwieldly + if isinstance(value, tuple): + if len(value) == 2: + # neither the 3rd parameter (content_type) nor the 4th (headers) was included + filename, fileobj = value # type: ignore + elif len(value) == 3: + filename, fileobj, content_type = value # type: ignore + else: + # all 4 parameters included + filename, fileobj, content_type, headers = value # type: ignore + else: + filename = Path(str(getattr(value, "name", "upload"))).name + fileobj = value + + if content_type is None: + content_type = guess_content_type(filename) + + has_content_type_header = any("content-type" in key.lower() for key in headers) + if content_type is not None and not has_content_type_header: + # note that unlike requests, we ignore the content_type + # provided in the 3rd tuple element if it is also included in the headers + # requests does the opposite (it overwrites the header with the 3rd tuple element) + headers["Content-Type"] = content_type + + if isinstance(fileobj, io.StringIO): + raise TypeError( + "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'." + ) + if isinstance(fileobj, io.TextIOBase): + raise TypeError( + "Multipart file uploads must be opened in binary mode, not text mode." + ) + + self.filename = filename + self.file = fileobj + self.headers = headers + + def get_length(self) -> typing.Optional[int]: + headers = self.render_headers() + + if isinstance(self.file, (str, bytes)): + return len(headers) + len(to_bytes(self.file)) + + file_length = peek_filelike_length(self.file) + + # If we can't determine the filesize without reading it into memory, + # then return `None` here, to indicate an unknown file length. + if file_length is None: + return None + + return len(headers) + file_length + + def render_headers(self) -> bytes: + if not hasattr(self, "_headers"): + parts = [ + b"Content-Disposition: form-data; ", + format_form_param("name", self.name), + ] + if self.filename: + filename = format_form_param("filename", self.filename) + parts.extend([b"; ", filename]) + for header_name, header_value in self.headers.items(): + key, val = f"\r\n{header_name}: ".encode(), header_value.encode() + parts.extend([key, val]) + parts.append(b"\r\n\r\n") + self._headers = b"".join(parts) + + return self._headers + + def render_data(self) -> typing.Iterator[bytes]: + if isinstance(self.file, (str, bytes)): + yield to_bytes(self.file) + return + + if hasattr(self.file, "seek"): + try: + self.file.seek(0) + except io.UnsupportedOperation: + pass + + chunk = self.file.read(self.CHUNK_SIZE) + while chunk: + yield to_bytes(chunk) + chunk = self.file.read(self.CHUNK_SIZE) + + def render(self) -> typing.Iterator[bytes]: + yield self.render_headers() + yield from self.render_data() + + +class MultipartStream(SyncByteStream, AsyncByteStream): + """ + Request content as streaming multipart encoded form data. + """ + + def __init__( + self, + data: RequestData, + files: RequestFiles, + boundary: typing.Optional[bytes] = None, + ) -> None: + if boundary is None: + boundary = binascii.hexlify(os.urandom(16)) + + self.boundary = boundary + self.content_type = "multipart/form-data; boundary=%s" % boundary.decode( + "ascii" + ) + self.fields = list(self._iter_fields(data, files)) + + def _iter_fields( + self, data: RequestData, files: RequestFiles + ) -> typing.Iterator[typing.Union[FileField, DataField]]: + for name, value in data.items(): + if isinstance(value, (tuple, list)): + for item in value: + yield DataField(name=name, value=item) + else: + yield DataField(name=name, value=value) + + file_items = files.items() if isinstance(files, typing.Mapping) else files + for name, value in file_items: + yield FileField(name=name, value=value) + + def iter_chunks(self) -> typing.Iterator[bytes]: + for field in self.fields: + yield b"--%s\r\n" % self.boundary + yield from field.render() + yield b"\r\n" + yield b"--%s--\r\n" % self.boundary + + def get_content_length(self) -> typing.Optional[int]: + """ + Return the length of the multipart encoded content, or `None` if + any of the files have a length that cannot be determined upfront. + """ + boundary_length = len(self.boundary) + length = 0 + + for field in self.fields: + field_length = field.get_length() + if field_length is None: + return None + + length += 2 + boundary_length + 2 # b"--{boundary}\r\n" + length += field_length + length += 2 # b"\r\n" + + length += 2 + boundary_length + 4 # b"--{boundary}--\r\n" + return length + + # Content stream interface. + + def get_headers(self) -> typing.Dict[str, str]: + content_length = self.get_content_length() + content_type = self.content_type + if content_length is None: + return {"Transfer-Encoding": "chunked", "Content-Type": content_type} + return {"Content-Length": str(content_length), "Content-Type": content_type} + + def __iter__(self) -> typing.Iterator[bytes]: + for chunk in self.iter_chunks(): + yield chunk + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + for chunk in self.iter_chunks(): + yield chunk diff --git a/contrib/python/httpx/httpx/_status_codes.py b/contrib/python/httpx/httpx/_status_codes.py new file mode 100644 index 0000000000..671c30e1b8 --- /dev/null +++ b/contrib/python/httpx/httpx/_status_codes.py @@ -0,0 +1,158 @@ +from enum import IntEnum + + +class codes(IntEnum): + """HTTP status codes and reason phrases + + Status codes from the following RFCs are all observed: + + * RFC 7231: Hypertext Transfer Protocol (HTTP/1.1), obsoletes 2616 + * RFC 6585: Additional HTTP Status Codes + * RFC 3229: Delta encoding in HTTP + * RFC 4918: HTTP Extensions for WebDAV, obsoletes 2518 + * RFC 5842: Binding Extensions to WebDAV + * RFC 7238: Permanent Redirect + * RFC 2295: Transparent Content Negotiation in HTTP + * RFC 2774: An HTTP Extension Framework + * RFC 7540: Hypertext Transfer Protocol Version 2 (HTTP/2) + * RFC 2324: Hyper Text Coffee Pot Control Protocol (HTCPCP/1.0) + * RFC 7725: An HTTP Status Code to Report Legal Obstacles + * RFC 8297: An HTTP Status Code for Indicating Hints + * RFC 8470: Using Early Data in HTTP + """ + + def __new__(cls, value: int, phrase: str = "") -> "codes": + obj = int.__new__(cls, value) + obj._value_ = value + + obj.phrase = phrase # type: ignore[attr-defined] + return obj + + def __str__(self) -> str: + return str(self.value) + + @classmethod + def get_reason_phrase(cls, value: int) -> str: + try: + return codes(value).phrase # type: ignore + except ValueError: + return "" + + @classmethod + def is_informational(cls, value: int) -> bool: + """ + Returns `True` for 1xx status codes, `False` otherwise. + """ + return 100 <= value <= 199 + + @classmethod + def is_success(cls, value: int) -> bool: + """ + Returns `True` for 2xx status codes, `False` otherwise. + """ + return 200 <= value <= 299 + + @classmethod + def is_redirect(cls, value: int) -> bool: + """ + Returns `True` for 3xx status codes, `False` otherwise. + """ + return 300 <= value <= 399 + + @classmethod + def is_client_error(cls, value: int) -> bool: + """ + Returns `True` for 4xx status codes, `False` otherwise. + """ + return 400 <= value <= 499 + + @classmethod + def is_server_error(cls, value: int) -> bool: + """ + Returns `True` for 5xx status codes, `False` otherwise. + """ + return 500 <= value <= 599 + + @classmethod + def is_error(cls, value: int) -> bool: + """ + Returns `True` for 4xx or 5xx status codes, `False` otherwise. + """ + return 400 <= value <= 599 + + # informational + CONTINUE = 100, "Continue" + SWITCHING_PROTOCOLS = 101, "Switching Protocols" + PROCESSING = 102, "Processing" + EARLY_HINTS = 103, "Early Hints" + + # success + OK = 200, "OK" + CREATED = 201, "Created" + ACCEPTED = 202, "Accepted" + NON_AUTHORITATIVE_INFORMATION = 203, "Non-Authoritative Information" + NO_CONTENT = 204, "No Content" + RESET_CONTENT = 205, "Reset Content" + PARTIAL_CONTENT = 206, "Partial Content" + MULTI_STATUS = 207, "Multi-Status" + ALREADY_REPORTED = 208, "Already Reported" + IM_USED = 226, "IM Used" + + # redirection + MULTIPLE_CHOICES = 300, "Multiple Choices" + MOVED_PERMANENTLY = 301, "Moved Permanently" + FOUND = 302, "Found" + SEE_OTHER = 303, "See Other" + NOT_MODIFIED = 304, "Not Modified" + USE_PROXY = 305, "Use Proxy" + TEMPORARY_REDIRECT = 307, "Temporary Redirect" + PERMANENT_REDIRECT = 308, "Permanent Redirect" + + # client error + BAD_REQUEST = 400, "Bad Request" + UNAUTHORIZED = 401, "Unauthorized" + PAYMENT_REQUIRED = 402, "Payment Required" + FORBIDDEN = 403, "Forbidden" + NOT_FOUND = 404, "Not Found" + METHOD_NOT_ALLOWED = 405, "Method Not Allowed" + NOT_ACCEPTABLE = 406, "Not Acceptable" + PROXY_AUTHENTICATION_REQUIRED = 407, "Proxy Authentication Required" + REQUEST_TIMEOUT = 408, "Request Timeout" + CONFLICT = 409, "Conflict" + GONE = 410, "Gone" + LENGTH_REQUIRED = 411, "Length Required" + PRECONDITION_FAILED = 412, "Precondition Failed" + REQUEST_ENTITY_TOO_LARGE = 413, "Request Entity Too Large" + REQUEST_URI_TOO_LONG = 414, "Request-URI Too Long" + UNSUPPORTED_MEDIA_TYPE = 415, "Unsupported Media Type" + REQUESTED_RANGE_NOT_SATISFIABLE = 416, "Requested Range Not Satisfiable" + EXPECTATION_FAILED = 417, "Expectation Failed" + IM_A_TEAPOT = 418, "I'm a teapot" + MISDIRECTED_REQUEST = 421, "Misdirected Request" + UNPROCESSABLE_ENTITY = 422, "Unprocessable Entity" + LOCKED = 423, "Locked" + FAILED_DEPENDENCY = 424, "Failed Dependency" + TOO_EARLY = 425, "Too Early" + UPGRADE_REQUIRED = 426, "Upgrade Required" + PRECONDITION_REQUIRED = 428, "Precondition Required" + TOO_MANY_REQUESTS = 429, "Too Many Requests" + REQUEST_HEADER_FIELDS_TOO_LARGE = 431, "Request Header Fields Too Large" + UNAVAILABLE_FOR_LEGAL_REASONS = 451, "Unavailable For Legal Reasons" + + # server errors + INTERNAL_SERVER_ERROR = 500, "Internal Server Error" + NOT_IMPLEMENTED = 501, "Not Implemented" + BAD_GATEWAY = 502, "Bad Gateway" + SERVICE_UNAVAILABLE = 503, "Service Unavailable" + GATEWAY_TIMEOUT = 504, "Gateway Timeout" + HTTP_VERSION_NOT_SUPPORTED = 505, "HTTP Version Not Supported" + VARIANT_ALSO_NEGOTIATES = 506, "Variant Also Negotiates" + INSUFFICIENT_STORAGE = 507, "Insufficient Storage" + LOOP_DETECTED = 508, "Loop Detected" + NOT_EXTENDED = 510, "Not Extended" + NETWORK_AUTHENTICATION_REQUIRED = 511, "Network Authentication Required" + + +# Include lower-case styles for `requests` compatibility. +for code in codes: + setattr(codes, code._name_.lower(), int(code)) diff --git a/contrib/python/httpx/httpx/_transports/__init__.py b/contrib/python/httpx/httpx/_transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/__init__.py diff --git a/contrib/python/httpx/httpx/_transports/asgi.py b/contrib/python/httpx/httpx/_transports/asgi.py new file mode 100644 index 0000000000..f67f0fbd5b --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/asgi.py @@ -0,0 +1,179 @@ +import typing + +import sniffio + +from .._models import Request, Response +from .._types import AsyncByteStream +from .base import AsyncBaseTransport + +if typing.TYPE_CHECKING: # pragma: no cover + import asyncio + + import trio + + Event = typing.Union[asyncio.Event, trio.Event] + + +_Message = typing.Dict[str, typing.Any] +_Receive = typing.Callable[[], typing.Awaitable[_Message]] +_Send = typing.Callable[ + [typing.Dict[str, typing.Any]], typing.Coroutine[None, None, None] +] +_ASGIApp = typing.Callable[ + [typing.Dict[str, typing.Any], _Receive, _Send], typing.Coroutine[None, None, None] +] + + +def create_event() -> "Event": + if sniffio.current_async_library() == "trio": + import trio + + return trio.Event() + else: + import asyncio + + return asyncio.Event() + + +class ASGIResponseStream(AsyncByteStream): + def __init__(self, body: typing.List[bytes]) -> None: + self._body = body + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + yield b"".join(self._body) + + +class ASGITransport(AsyncBaseTransport): + """ + A custom AsyncTransport that handles sending requests directly to an ASGI app. + The simplest way to use this functionality is to use the `app` argument. + + ``` + client = httpx.AsyncClient(app=app) + ``` + + Alternatively, you can setup the transport instance explicitly. + This allows you to include any additional configuration arguments specific + to the ASGITransport class: + + ``` + transport = httpx.ASGITransport( + app=app, + root_path="/submount", + client=("1.2.3.4", 123) + ) + client = httpx.AsyncClient(transport=transport) + ``` + + Arguments: + + * `app` - The ASGI application. + * `raise_app_exceptions` - Boolean indicating if exceptions in the application + should be raised. Default to `True`. Can be set to `False` for use cases + such as testing the content of a client 500 response. + * `root_path` - The root path on which the ASGI application should be mounted. + * `client` - A two-tuple indicating the client IP and port of incoming requests. + ``` + """ + + def __init__( + self, + app: _ASGIApp, + raise_app_exceptions: bool = True, + root_path: str = "", + client: typing.Tuple[str, int] = ("127.0.0.1", 123), + ) -> None: + self.app = app + self.raise_app_exceptions = raise_app_exceptions + self.root_path = root_path + self.client = client + + async def handle_async_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + # ASGI scope. + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path, + "query_string": request.url.query, + "server": (request.url.host, request.url.port), + "client": self.client, + "root_path": self.root_path, + } + + # Request. + request_body_chunks = request.stream.__aiter__() + request_complete = False + + # Response. + status_code = None + response_headers = None + body_parts = [] + response_started = False + response_complete = create_event() + + # ASGI callables. + + async def receive() -> typing.Dict[str, typing.Any]: + nonlocal request_complete + + if request_complete: + await response_complete.wait() + return {"type": "http.disconnect"} + + try: + body = await request_body_chunks.__anext__() + except StopAsyncIteration: + request_complete = True + return {"type": "http.request", "body": b"", "more_body": False} + return {"type": "http.request", "body": body, "more_body": True} + + async def send(message: typing.Dict[str, typing.Any]) -> None: + nonlocal status_code, response_headers, response_started + + if message["type"] == "http.response.start": + assert not response_started + + status_code = message["status"] + response_headers = message.get("headers", []) + response_started = True + + elif message["type"] == "http.response.body": + assert not response_complete.is_set() + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body and request.method != "HEAD": + body_parts.append(body) + + if not more_body: + response_complete.set() + + try: + await self.app(scope, receive, send) + except Exception: # noqa: PIE-786 + if self.raise_app_exceptions: + raise + + response_complete.set() + if status_code is None: + status_code = 500 + if response_headers is None: + response_headers = {} + + assert response_complete.is_set() + assert status_code is not None + assert response_headers is not None + + stream = ASGIResponseStream(body_parts) + + return Response(status_code, headers=response_headers, stream=stream) diff --git a/contrib/python/httpx/httpx/_transports/base.py b/contrib/python/httpx/httpx/_transports/base.py new file mode 100644 index 0000000000..f6fdfe6943 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/base.py @@ -0,0 +1,82 @@ +import typing +from types import TracebackType + +from .._models import Request, Response + +T = typing.TypeVar("T", bound="BaseTransport") +A = typing.TypeVar("A", bound="AsyncBaseTransport") + + +class BaseTransport: + def __enter__(self: T) -> T: + return self + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + self.close() + + def handle_request(self, request: Request) -> Response: + """ + Send a single HTTP request and return a response. + + Developers shouldn't typically ever need to call into this API directly, + since the Client class provides all the higher level user-facing API + niceties. + + In order to properly release any network resources, the response + stream should *either* be consumed immediately, with a call to + `response.stream.read()`, or else the `handle_request` call should + be followed with a try/finally block to ensuring the stream is + always closed. + + Example usage: + + with httpx.HTTPTransport() as transport: + req = httpx.Request( + method=b"GET", + url=(b"https", b"www.example.com", 443, b"/"), + headers=[(b"Host", b"www.example.com")], + ) + resp = transport.handle_request(req) + body = resp.stream.read() + print(resp.status_code, resp.headers, body) + + + Takes a `Request` instance as the only argument. + + Returns a `Response` instance. + """ + raise NotImplementedError( + "The 'handle_request' method must be implemented." + ) # pragma: no cover + + def close(self) -> None: + pass + + +class AsyncBaseTransport: + async def __aenter__(self: A) -> A: + return self + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + await self.aclose() + + async def handle_async_request( + self, + request: Request, + ) -> Response: + raise NotImplementedError( + "The 'handle_async_request' method must be implemented." + ) # pragma: no cover + + async def aclose(self) -> None: + pass diff --git a/contrib/python/httpx/httpx/_transports/default.py b/contrib/python/httpx/httpx/_transports/default.py new file mode 100644 index 0000000000..7dba5b8208 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/default.py @@ -0,0 +1,378 @@ +""" +Custom transports, with nicely configured defaults. + +The following additional keyword arguments are currently supported by httpcore... + +* uds: str +* local_address: str +* retries: int + +Example usages... + +# Disable HTTP/2 on a single specific domain. +mounts = { + "all://": httpx.HTTPTransport(http2=True), + "all://*example.org": httpx.HTTPTransport() +} + +# Using advanced httpcore configuration, with connection retries. +transport = httpx.HTTPTransport(retries=1) +client = httpx.Client(transport=transport) + +# Using advanced httpcore configuration, with unix domain sockets. +transport = httpx.HTTPTransport(uds="socket.uds") +client = httpx.Client(transport=transport) +""" +import contextlib +import typing +from types import TracebackType + +import httpcore + +from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context +from .._exceptions import ( + ConnectError, + ConnectTimeout, + LocalProtocolError, + NetworkError, + PoolTimeout, + ProtocolError, + ProxyError, + ReadError, + ReadTimeout, + RemoteProtocolError, + TimeoutException, + UnsupportedProtocol, + WriteError, + WriteTimeout, +) +from .._models import Request, Response +from .._types import AsyncByteStream, CertTypes, SyncByteStream, VerifyTypes +from .base import AsyncBaseTransport, BaseTransport + +T = typing.TypeVar("T", bound="HTTPTransport") +A = typing.TypeVar("A", bound="AsyncHTTPTransport") + +SOCKET_OPTION = typing.Union[ + typing.Tuple[int, int, int], + typing.Tuple[int, int, typing.Union[bytes, bytearray]], + typing.Tuple[int, int, None, int], +] + + +@contextlib.contextmanager +def map_httpcore_exceptions() -> typing.Iterator[None]: + try: + yield + except Exception as exc: # noqa: PIE-786 + mapped_exc = None + + for from_exc, to_exc in HTTPCORE_EXC_MAP.items(): + if not isinstance(exc, from_exc): + continue + # We want to map to the most specific exception we can find. + # Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to + # `httpx.ReadTimeout`, not just `httpx.TimeoutException`. + if mapped_exc is None or issubclass(to_exc, mapped_exc): + mapped_exc = to_exc + + if mapped_exc is None: # pragma: no cover + raise + + message = str(exc) + raise mapped_exc(message) from exc + + +HTTPCORE_EXC_MAP = { + httpcore.TimeoutException: TimeoutException, + httpcore.ConnectTimeout: ConnectTimeout, + httpcore.ReadTimeout: ReadTimeout, + httpcore.WriteTimeout: WriteTimeout, + httpcore.PoolTimeout: PoolTimeout, + httpcore.NetworkError: NetworkError, + httpcore.ConnectError: ConnectError, + httpcore.ReadError: ReadError, + httpcore.WriteError: WriteError, + httpcore.ProxyError: ProxyError, + httpcore.UnsupportedProtocol: UnsupportedProtocol, + httpcore.ProtocolError: ProtocolError, + httpcore.LocalProtocolError: LocalProtocolError, + httpcore.RemoteProtocolError: RemoteProtocolError, +} + + +class ResponseStream(SyncByteStream): + def __init__(self, httpcore_stream: typing.Iterable[bytes]): + self._httpcore_stream = httpcore_stream + + def __iter__(self) -> typing.Iterator[bytes]: + with map_httpcore_exceptions(): + for part in self._httpcore_stream: + yield part + + def close(self) -> None: + if hasattr(self._httpcore_stream, "close"): + self._httpcore_stream.close() + + +class HTTPTransport(BaseTransport): + def __init__( + self, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + trust_env: bool = True, + proxy: typing.Optional[Proxy] = None, + uds: typing.Optional[str] = None, + local_address: typing.Optional[str] = None, + retries: int = 0, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> None: + ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + + if proxy is None: + self._pool = httpcore.ConnectionPool( + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + uds=uds, + local_address=local_address, + retries=retries, + socket_options=socket_options, + ) + elif proxy.url.scheme in ("http", "https"): + self._pool = httpcore.HTTPProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + proxy_headers=proxy.headers.raw, + ssl_context=ssl_context, + proxy_ssl_context=proxy.ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + socket_options=socket_options, + ) + elif proxy.url.scheme == "socks5": + try: + import socksio # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using SOCKS proxy, but the 'socksio' package is not installed. " + "Make sure to install httpx using `pip install httpx[socks]`." + ) from None + + self._pool = httpcore.SOCKSProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + ) + else: # pragma: no cover + raise ValueError( + f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}." + ) + + def __enter__(self: T) -> T: # Use generics for subclass support. + self._pool.__enter__() + return self + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + with map_httpcore_exceptions(): + self._pool.__exit__(exc_type, exc_value, traceback) + + def handle_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, SyncByteStream) + + req = httpcore.Request( + method=request.method, + url=httpcore.URL( + scheme=request.url.raw_scheme, + host=request.url.raw_host, + port=request.url.port, + target=request.url.raw_path, + ), + headers=request.headers.raw, + content=request.stream, + extensions=request.extensions, + ) + with map_httpcore_exceptions(): + resp = self._pool.handle_request(req) + + assert isinstance(resp.stream, typing.Iterable) + + return Response( + status_code=resp.status, + headers=resp.headers, + stream=ResponseStream(resp.stream), + extensions=resp.extensions, + ) + + def close(self) -> None: + self._pool.close() + + +class AsyncResponseStream(AsyncByteStream): + def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]): + self._httpcore_stream = httpcore_stream + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + with map_httpcore_exceptions(): + async for part in self._httpcore_stream: + yield part + + async def aclose(self) -> None: + if hasattr(self._httpcore_stream, "aclose"): + await self._httpcore_stream.aclose() + + +class AsyncHTTPTransport(AsyncBaseTransport): + def __init__( + self, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + trust_env: bool = True, + proxy: typing.Optional[Proxy] = None, + uds: typing.Optional[str] = None, + local_address: typing.Optional[str] = None, + retries: int = 0, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> None: + ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + + if proxy is None: + self._pool = httpcore.AsyncConnectionPool( + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + uds=uds, + local_address=local_address, + retries=retries, + socket_options=socket_options, + ) + elif proxy.url.scheme in ("http", "https"): + self._pool = httpcore.AsyncHTTPProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + proxy_headers=proxy.headers.raw, + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + socket_options=socket_options, + ) + elif proxy.url.scheme == "socks5": + try: + import socksio # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using SOCKS proxy, but the 'socksio' package is not installed. " + "Make sure to install httpx using `pip install httpx[socks]`." + ) from None + + self._pool = httpcore.AsyncSOCKSProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + ) + else: # pragma: no cover + raise ValueError( + f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}." + ) + + async def __aenter__(self: A) -> A: # Use generics for subclass support. + await self._pool.__aenter__() + return self + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + with map_httpcore_exceptions(): + await self._pool.__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + req = httpcore.Request( + method=request.method, + url=httpcore.URL( + scheme=request.url.raw_scheme, + host=request.url.raw_host, + port=request.url.port, + target=request.url.raw_path, + ), + headers=request.headers.raw, + content=request.stream, + extensions=request.extensions, + ) + with map_httpcore_exceptions(): + resp = await self._pool.handle_async_request(req) + + assert isinstance(resp.stream, typing.AsyncIterable) + + return Response( + status_code=resp.status, + headers=resp.headers, + stream=AsyncResponseStream(resp.stream), + extensions=resp.extensions, + ) + + async def aclose(self) -> None: + await self._pool.aclose() diff --git a/contrib/python/httpx/httpx/_transports/mock.py b/contrib/python/httpx/httpx/_transports/mock.py new file mode 100644 index 0000000000..82043da2d9 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/mock.py @@ -0,0 +1,38 @@ +import typing + +from .._models import Request, Response +from .base import AsyncBaseTransport, BaseTransport + +SyncHandler = typing.Callable[[Request], Response] +AsyncHandler = typing.Callable[[Request], typing.Coroutine[None, None, Response]] + + +class MockTransport(AsyncBaseTransport, BaseTransport): + def __init__(self, handler: typing.Union[SyncHandler, AsyncHandler]) -> None: + self.handler = handler + + def handle_request( + self, + request: Request, + ) -> Response: + request.read() + response = self.handler(request) + if not isinstance(response, Response): # pragma: no cover + raise TypeError("Cannot use an async handler in a sync Client") + return response + + async def handle_async_request( + self, + request: Request, + ) -> Response: + await request.aread() + response = self.handler(request) + + # Allow handler to *optionally* be an `async` function. + # If it is, then the `response` variable need to be awaited to actually + # return the result. + + if not isinstance(response, Response): + response = await response + + return response diff --git a/contrib/python/httpx/httpx/_transports/wsgi.py b/contrib/python/httpx/httpx/_transports/wsgi.py new file mode 100644 index 0000000000..a23d42c414 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/wsgi.py @@ -0,0 +1,144 @@ +import io +import itertools +import sys +import typing + +from .._models import Request, Response +from .._types import SyncByteStream +from .base import BaseTransport + +if typing.TYPE_CHECKING: + from _typeshed import OptExcInfo # pragma: no cover + from _typeshed.wsgi import WSGIApplication # pragma: no cover + +_T = typing.TypeVar("_T") + + +def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]: + body = iter(body) + for chunk in body: + if chunk: + return itertools.chain([chunk], body) + return [] + + +class WSGIByteStream(SyncByteStream): + def __init__(self, result: typing.Iterable[bytes]) -> None: + self._close = getattr(result, "close", None) + self._result = _skip_leading_empty_chunks(result) + + def __iter__(self) -> typing.Iterator[bytes]: + for part in self._result: + yield part + + def close(self) -> None: + if self._close is not None: + self._close() + + +class WSGITransport(BaseTransport): + """ + A custom transport that handles sending requests directly to an WSGI app. + The simplest way to use this functionality is to use the `app` argument. + + ``` + client = httpx.Client(app=app) + ``` + + Alternatively, you can setup the transport instance explicitly. + This allows you to include any additional configuration arguments specific + to the WSGITransport class: + + ``` + transport = httpx.WSGITransport( + app=app, + script_name="/submount", + remote_addr="1.2.3.4" + ) + client = httpx.Client(transport=transport) + ``` + + Arguments: + + * `app` - The WSGI application. + * `raise_app_exceptions` - Boolean indicating if exceptions in the application + should be raised. Default to `True`. Can be set to `False` for use cases + such as testing the content of a client 500 response. + * `script_name` - The root path on which the WSGI application should be mounted. + * `remote_addr` - A string indicating the client IP of incoming requests. + ``` + """ + + def __init__( + self, + app: "WSGIApplication", + raise_app_exceptions: bool = True, + script_name: str = "", + remote_addr: str = "127.0.0.1", + wsgi_errors: typing.Optional[typing.TextIO] = None, + ) -> None: + self.app = app + self.raise_app_exceptions = raise_app_exceptions + self.script_name = script_name + self.remote_addr = remote_addr + self.wsgi_errors = wsgi_errors + + def handle_request(self, request: Request) -> Response: + request.read() + wsgi_input = io.BytesIO(request.content) + + port = request.url.port or {"http": 80, "https": 443}[request.url.scheme] + environ = { + "wsgi.version": (1, 0), + "wsgi.url_scheme": request.url.scheme, + "wsgi.input": wsgi_input, + "wsgi.errors": self.wsgi_errors or sys.stderr, + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "REQUEST_METHOD": request.method, + "SCRIPT_NAME": self.script_name, + "PATH_INFO": request.url.path, + "QUERY_STRING": request.url.query.decode("ascii"), + "SERVER_NAME": request.url.host, + "SERVER_PORT": str(port), + "SERVER_PROTOCOL": "HTTP/1.1", + "REMOTE_ADDR": self.remote_addr, + } + for header_key, header_value in request.headers.raw: + key = header_key.decode("ascii").upper().replace("-", "_") + if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"): + key = "HTTP_" + key + environ[key] = header_value.decode("ascii") + + seen_status = None + seen_response_headers = None + seen_exc_info = None + + def start_response( + status: str, + response_headers: typing.List[typing.Tuple[str, str]], + exc_info: typing.Optional["OptExcInfo"] = None, + ) -> typing.Callable[[bytes], typing.Any]: + nonlocal seen_status, seen_response_headers, seen_exc_info + seen_status = status + seen_response_headers = response_headers + seen_exc_info = exc_info + return lambda _: None + + result = self.app(environ, start_response) + + stream = WSGIByteStream(result) + + assert seen_status is not None + assert seen_response_headers is not None + if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions: + raise seen_exc_info[1] + + status_code = int(seen_status.split()[0]) + headers = [ + (key.encode("ascii"), value.encode("ascii")) + for key, value in seen_response_headers + ] + + return Response(status_code, headers=headers, stream=stream) diff --git a/contrib/python/httpx/httpx/_types.py b/contrib/python/httpx/httpx/_types.py new file mode 100644 index 0000000000..83cf35a32a --- /dev/null +++ b/contrib/python/httpx/httpx/_types.py @@ -0,0 +1,133 @@ +""" +Type definitions for type checking purposes. +""" + +import ssl +from http.cookiejar import CookieJar +from typing import ( + IO, + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) + +if TYPE_CHECKING: # pragma: no cover + from ._auth import Auth # noqa: F401 + from ._config import Proxy, Timeout # noqa: F401 + from ._models import Cookies, Headers, Request # noqa: F401 + from ._urls import URL, QueryParams # noqa: F401 + + +PrimitiveData = Optional[Union[str, int, float, bool]] + +RawURL = NamedTuple( + "RawURL", + [ + ("raw_scheme", bytes), + ("raw_host", bytes), + ("port", Optional[int]), + ("raw_path", bytes), + ], +) + +URLTypes = Union["URL", str] + +QueryParamTypes = Union[ + "QueryParams", + Mapping[str, Union[PrimitiveData, Sequence[PrimitiveData]]], + List[Tuple[str, PrimitiveData]], + Tuple[Tuple[str, PrimitiveData], ...], + str, + bytes, +] + +HeaderTypes = Union[ + "Headers", + Mapping[str, str], + Mapping[bytes, bytes], + Sequence[Tuple[str, str]], + Sequence[Tuple[bytes, bytes]], +] + +CookieTypes = Union["Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]] + +CertTypes = Union[ + # certfile + str, + # (certfile, keyfile) + Tuple[str, Optional[str]], + # (certfile, keyfile, password) + Tuple[str, Optional[str], Optional[str]], +] +VerifyTypes = Union[str, bool, ssl.SSLContext] +TimeoutTypes = Union[ + Optional[float], + Tuple[Optional[float], Optional[float], Optional[float], Optional[float]], + "Timeout", +] +ProxiesTypes = Union[URLTypes, "Proxy", Dict[URLTypes, Union[None, URLTypes, "Proxy"]]] + +AuthTypes = Union[ + Tuple[Union[str, bytes], Union[str, bytes]], + Callable[["Request"], "Request"], + "Auth", +] + +RequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] +ResponseContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] +ResponseExtensions = MutableMapping[str, Any] + +RequestData = Mapping[str, Any] + +FileContent = Union[IO[bytes], bytes, str] +FileTypes = Union[ + # file (or bytes) + FileContent, + # (filename, file (or bytes)) + Tuple[Optional[str], FileContent], + # (filename, file (or bytes), content_type) + Tuple[Optional[str], FileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], +] +RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] + +RequestExtensions = MutableMapping[str, Any] + + +class SyncByteStream: + def __iter__(self) -> Iterator[bytes]: + raise NotImplementedError( + "The '__iter__' method must be implemented." + ) # pragma: no cover + yield b"" # pragma: no cover + + def close(self) -> None: + """ + Subclasses can override this method to release any network resources + after a request/response cycle is complete. + """ + + +class AsyncByteStream: + async def __aiter__(self) -> AsyncIterator[bytes]: + raise NotImplementedError( + "The '__aiter__' method must be implemented." + ) # pragma: no cover + yield b"" # pragma: no cover + + async def aclose(self) -> None: + pass diff --git a/contrib/python/httpx/httpx/_urlparse.py b/contrib/python/httpx/httpx/_urlparse.py new file mode 100644 index 0000000000..e1ba8dcdb7 --- /dev/null +++ b/contrib/python/httpx/httpx/_urlparse.py @@ -0,0 +1,464 @@ +""" +An implementation of `urlparse` that provides URL validation and normalization +as described by RFC3986. + +We rely on this implementation rather than the one in Python's stdlib, because: + +* It provides more complete URL validation. +* It properly differentiates between an empty querystring and an absent querystring, + to distinguish URLs with a trailing '?'. +* It handles scheme, hostname, port, and path normalization. +* It supports IDNA hostnames, normalizing them to their encoded form. +* The API supports passing individual components, as well as the complete URL string. + +Previously we relied on the excellent `rfc3986` package to handle URL parsing and +validation, but this module provides a simpler alternative, with less indirection +required. +""" +import ipaddress +import re +import typing + +import idna + +from ._exceptions import InvalidURL + +MAX_URL_LENGTH = 65536 + +# https://datatracker.ietf.org/doc/html/rfc3986.html#section-2.3 +UNRESERVED_CHARACTERS = ( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" +) +SUB_DELIMS = "!$&'()*+,;=" + +PERCENT_ENCODED_REGEX = re.compile("%[A-Fa-f0-9]{2}") + + +# {scheme}: (optional) +# //{authority} (optional) +# {path} +# ?{query} (optional) +# #{fragment} (optional) +URL_REGEX = re.compile( + ( + r"(?:(?P<scheme>{scheme}):)?" + r"(?://(?P<authority>{authority}))?" + r"(?P<path>{path})" + r"(?:\?(?P<query>{query}))?" + r"(?:#(?P<fragment>{fragment}))?" + ).format( + scheme="([a-zA-Z][a-zA-Z0-9+.-]*)?", + authority="[^/?#]*", + path="[^?#]*", + query="[^#]*", + fragment=".*", + ) +) + +# {userinfo}@ (optional) +# {host} +# :{port} (optional) +AUTHORITY_REGEX = re.compile( + ( + r"(?:(?P<userinfo>{userinfo})@)?" r"(?P<host>{host})" r":?(?P<port>{port})?" + ).format( + userinfo="[^@]*", # Any character sequence not including '@'. + host="(\\[.*\\]|[^:]*)", # Either any character sequence not including ':', + # or an IPv6 address enclosed within square brackets. + port=".*", # Any character sequence. + ) +) + + +# If we call urlparse with an individual component, then we need to regex +# validate that component individually. +# Note that we're duplicating the same strings as above. Shock! Horror!! +COMPONENT_REGEX = { + "scheme": re.compile("([a-zA-Z][a-zA-Z0-9+.-]*)?"), + "authority": re.compile("[^/?#]*"), + "path": re.compile("[^?#]*"), + "query": re.compile("[^#]*"), + "fragment": re.compile(".*"), + "userinfo": re.compile("[^@]*"), + "host": re.compile("(\\[.*\\]|[^:]*)"), + "port": re.compile(".*"), +} + + +# We use these simple regexs as a first pass before handing off to +# the stdlib 'ipaddress' module for IP address validation. +IPv4_STYLE_HOSTNAME = re.compile(r"^[0-9]+.[0-9]+.[0-9]+.[0-9]+$") +IPv6_STYLE_HOSTNAME = re.compile(r"^\[.*\]$") + + +class ParseResult(typing.NamedTuple): + scheme: str + userinfo: str + host: str + port: typing.Optional[int] + path: str + query: typing.Optional[str] + fragment: typing.Optional[str] + + @property + def authority(self) -> str: + return "".join( + [ + f"{self.userinfo}@" if self.userinfo else "", + f"[{self.host}]" if ":" in self.host else self.host, + f":{self.port}" if self.port is not None else "", + ] + ) + + @property + def netloc(self) -> str: + return "".join( + [ + f"[{self.host}]" if ":" in self.host else self.host, + f":{self.port}" if self.port is not None else "", + ] + ) + + def copy_with(self, **kwargs: typing.Optional[str]) -> "ParseResult": + if not kwargs: + return self + + defaults = { + "scheme": self.scheme, + "authority": self.authority, + "path": self.path, + "query": self.query, + "fragment": self.fragment, + } + defaults.update(kwargs) + return urlparse("", **defaults) + + def __str__(self) -> str: + authority = self.authority + return "".join( + [ + f"{self.scheme}:" if self.scheme else "", + f"//{authority}" if authority else "", + self.path, + f"?{self.query}" if self.query is not None else "", + f"#{self.fragment}" if self.fragment is not None else "", + ] + ) + + +def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult: + # Initial basic checks on allowable URLs. + # --------------------------------------- + + # Hard limit the maximum allowable URL length. + if len(url) > MAX_URL_LENGTH: + raise InvalidURL("URL too long") + + # If a URL includes any ASCII control characters including \t, \r, \n, + # then treat it as invalid. + if any(char.isascii() and not char.isprintable() for char in url): + raise InvalidURL("Invalid non-printable ASCII character in URL") + + # Some keyword arguments require special handling. + # ------------------------------------------------ + + # Coerce "port" to a string, if it is provided as an integer. + if "port" in kwargs: + port = kwargs["port"] + kwargs["port"] = str(port) if isinstance(port, int) else port + + # Replace "netloc" with "host and "port". + if "netloc" in kwargs: + netloc = kwargs.pop("netloc") or "" + kwargs["host"], _, kwargs["port"] = netloc.partition(":") + + # Replace "username" and/or "password" with "userinfo". + if "username" in kwargs or "password" in kwargs: + username = quote(kwargs.pop("username", "") or "") + password = quote(kwargs.pop("password", "") or "") + kwargs["userinfo"] = f"{username}:{password}" if password else username + + # Replace "raw_path" with "path" and "query". + if "raw_path" in kwargs: + raw_path = kwargs.pop("raw_path") or "" + kwargs["path"], seperator, kwargs["query"] = raw_path.partition("?") + if not seperator: + kwargs["query"] = None + + # Ensure that IPv6 "host" addresses are always escaped with "[...]". + if "host" in kwargs: + host = kwargs.get("host") or "" + if ":" in host and not (host.startswith("[") and host.endswith("]")): + kwargs["host"] = f"[{host}]" + + # If any keyword arguments are provided, ensure they are valid. + # ------------------------------------------------------------- + + for key, value in kwargs.items(): + if value is not None: + if len(value) > MAX_URL_LENGTH: + raise InvalidURL(f"URL component '{key}' too long") + + # If a component includes any ASCII control characters including \t, \r, \n, + # then treat it as invalid. + if any(char.isascii() and not char.isprintable() for char in value): + raise InvalidURL( + f"Invalid non-printable ASCII character in URL component '{key}'" + ) + + # Ensure that keyword arguments match as a valid regex. + if not COMPONENT_REGEX[key].fullmatch(value): + raise InvalidURL(f"Invalid URL component '{key}'") + + # The URL_REGEX will always match, but may have empty components. + url_match = URL_REGEX.match(url) + assert url_match is not None + url_dict = url_match.groupdict() + + # * 'scheme', 'authority', and 'path' may be empty strings. + # * 'query' may be 'None', indicating no trailing "?" portion. + # Any string including the empty string, indicates a trailing "?". + # * 'fragment' may be 'None', indicating no trailing "#" portion. + # Any string including the empty string, indicates a trailing "#". + scheme = kwargs.get("scheme", url_dict["scheme"]) or "" + authority = kwargs.get("authority", url_dict["authority"]) or "" + path = kwargs.get("path", url_dict["path"]) or "" + query = kwargs.get("query", url_dict["query"]) + fragment = kwargs.get("fragment", url_dict["fragment"]) + + # The AUTHORITY_REGEX will always match, but may have empty components. + authority_match = AUTHORITY_REGEX.match(authority) + assert authority_match is not None + authority_dict = authority_match.groupdict() + + # * 'userinfo' and 'host' may be empty strings. + # * 'port' may be 'None'. + userinfo = kwargs.get("userinfo", authority_dict["userinfo"]) or "" + host = kwargs.get("host", authority_dict["host"]) or "" + port = kwargs.get("port", authority_dict["port"]) + + # Normalize and validate each component. + # We end up with a parsed representation of the URL, + # with components that are plain ASCII bytestrings. + parsed_scheme: str = scheme.lower() + parsed_userinfo: str = quote(userinfo, safe=SUB_DELIMS + ":") + parsed_host: str = encode_host(host) + parsed_port: typing.Optional[int] = normalize_port(port, scheme) + + has_scheme = parsed_scheme != "" + has_authority = ( + parsed_userinfo != "" or parsed_host != "" or parsed_port is not None + ) + validate_path(path, has_scheme=has_scheme, has_authority=has_authority) + if has_authority: + path = normalize_path(path) + + # The GEN_DELIMS set is... : / ? # [ ] @ + # These do not need to be percent-quoted unless they serve as delimiters for the + # specific component. + + # For 'path' we need to drop ? and # from the GEN_DELIMS set. + parsed_path: str = quote(path, safe=SUB_DELIMS + ":/[]@") + # For 'query' we need to drop '#' from the GEN_DELIMS set. + # We also exclude '/' because it is more robust to replace it with a percent + # encoding despite it not being a requirement of the spec. + parsed_query: typing.Optional[str] = ( + None if query is None else quote(query, safe=SUB_DELIMS + ":?[]@") + ) + # For 'fragment' we can include all of the GEN_DELIMS set. + parsed_fragment: typing.Optional[str] = ( + None if fragment is None else quote(fragment, safe=SUB_DELIMS + ":/?#[]@") + ) + + # The parsed ASCII bytestrings are our canonical form. + # All properties of the URL are derived from these. + return ParseResult( + parsed_scheme, + parsed_userinfo, + parsed_host, + parsed_port, + parsed_path, + parsed_query, + parsed_fragment, + ) + + +def encode_host(host: str) -> str: + if not host: + return "" + + elif IPv4_STYLE_HOSTNAME.match(host): + # Validate IPv4 hostnames like #.#.#.# + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # IPv4address = dec-octet "." dec-octet "." dec-octet "." dec-octet + try: + ipaddress.IPv4Address(host) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv4 address: {host!r}") + return host + + elif IPv6_STYLE_HOSTNAME.match(host): + # Validate IPv6 hostnames like [...] + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # "A host identified by an Internet Protocol literal address, version 6 + # [RFC3513] or later, is distinguished by enclosing the IP literal + # within square brackets ("[" and "]"). This is the only place where + # square bracket characters are allowed in the URI syntax." + try: + ipaddress.IPv6Address(host[1:-1]) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv6 address: {host!r}") + return host[1:-1] + + elif host.isascii(): + # Regular ASCII hostnames + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # reg-name = *( unreserved / pct-encoded / sub-delims ) + return quote(host.lower(), safe=SUB_DELIMS) + + # IDNA hostnames + try: + return idna.encode(host.lower()).decode("ascii") + except idna.IDNAError: + raise InvalidURL(f"Invalid IDNA hostname: {host!r}") + + +def normalize_port( + port: typing.Optional[typing.Union[str, int]], scheme: str +) -> typing.Optional[int]: + # From https://tools.ietf.org/html/rfc3986#section-3.2.3 + # + # "A scheme may define a default port. For example, the "http" scheme + # defines a default port of "80", corresponding to its reserved TCP + # port number. The type of port designated by the port number (e.g., + # TCP, UDP, SCTP) is defined by the URI scheme. URI producers and + # normalizers should omit the port component and its ":" delimiter if + # port is empty or if its value would be the same as that of the + # scheme's default." + if port is None or port == "": + return None + + try: + port_as_int = int(port) + except ValueError: + raise InvalidURL(f"Invalid port: {port!r}") + + # See https://url.spec.whatwg.org/#url-miscellaneous + default_port = {"ftp": 21, "http": 80, "https": 443, "ws": 80, "wss": 443}.get( + scheme + ) + if port_as_int == default_port: + return None + return port_as_int + + +def validate_path(path: str, has_scheme: bool, has_authority: bool) -> None: + """ + Path validation rules that depend on if the URL contains a scheme or authority component. + + See https://datatracker.ietf.org/doc/html/rfc3986.html#section-3.3 + """ + if has_authority: + # > If a URI contains an authority component, then the path component + # > must either be empty or begin with a slash ("/") character." + if path and not path.startswith("/"): + raise InvalidURL("For absolute URLs, path must be empty or begin with '/'") + else: + # > If a URI does not contain an authority component, then the path cannot begin + # > with two slash characters ("//"). + if path.startswith("//"): + raise InvalidURL( + "URLs with no authority component cannot have a path starting with '//'" + ) + # > In addition, a URI reference (Section 4.1) may be a relative-path reference, in which + # > case the first path segment cannot contain a colon (":") character. + if path.startswith(":") and not has_scheme: + raise InvalidURL( + "URLs with no scheme component cannot have a path starting with ':'" + ) + + +def normalize_path(path: str) -> str: + """ + Drop "." and ".." segments from a URL path. + + For example: + + normalize_path("/path/./to/somewhere/..") == "/path/to" + """ + # https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4 + components = path.split("/") + output: typing.List[str] = [] + for component in components: + if component == ".": + pass + elif component == "..": + if output and output != [""]: + output.pop() + else: + output.append(component) + return "/".join(output) + + +def percent_encode(char: str) -> str: + """ + Replace a single character with the percent-encoded representation. + + Characters outside the ASCII range are represented with their a percent-encoded + representation of their UTF-8 byte sequence. + + For example: + + percent_encode(" ") == "%20" + """ + return "".join([f"%{byte:02x}" for byte in char.encode("utf-8")]).upper() + + +def is_safe(string: str, safe: str = "/") -> bool: + """ + Determine if a given string is already quote-safe. + """ + NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe + "%" + + # All characters must already be non-escaping or '%' + for char in string: + if char not in NON_ESCAPED_CHARS: + return False + + # Any '%' characters must be valid '%xx' escape sequences. + return string.count("%") == len(PERCENT_ENCODED_REGEX.findall(string)) + + +def quote(string: str, safe: str = "/") -> str: + """ + Use percent-encoding to quote a string if required. + """ + if is_safe(string, safe=safe): + return string + + NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe + return "".join( + [char if char in NON_ESCAPED_CHARS else percent_encode(char) for char in string] + ) + + +def urlencode(items: typing.List[typing.Tuple[str, str]]) -> str: + # We can use a much simpler version of the stdlib urlencode here because + # we don't need to handle a bunch of different typing cases, such as bytes vs str. + # + # https://github.com/python/cpython/blob/b2f7b2ef0b5421e01efb8c7bee2ef95d3bab77eb/Lib/urllib/parse.py#L926 + # + # Note that we use '%20' encoding for spaces. and '%2F for '/'. + # This is slightly different than `requests`, but is the behaviour that browsers use. + # + # See + # - https://github.com/encode/httpx/issues/2536 + # - https://github.com/encode/httpx/issues/2721 + # - https://docs.python.org/3/library/urllib.parse.html#urllib.parse.urlencode + return "&".join([quote(k, safe="") + "=" + quote(v, safe="") for k, v in items]) diff --git a/contrib/python/httpx/httpx/_urls.py b/contrib/python/httpx/httpx/_urls.py new file mode 100644 index 0000000000..b023941b62 --- /dev/null +++ b/contrib/python/httpx/httpx/_urls.py @@ -0,0 +1,642 @@ +import typing +from urllib.parse import parse_qs, unquote + +import idna + +from ._types import QueryParamTypes, RawURL, URLTypes +from ._urlparse import urlencode, urlparse +from ._utils import primitive_value_to_str + + +class URL: + """ + url = httpx.URL("HTTPS://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink") + + assert url.scheme == "https" + assert url.username == "jo@email.com" + assert url.password == "a secret" + assert url.userinfo == b"jo%40email.com:a%20secret" + assert url.host == "müller.de" + assert url.raw_host == b"xn--mller-kva.de" + assert url.port == 1234 + assert url.netloc == b"xn--mller-kva.de:1234" + assert url.path == "/pa th" + assert url.query == b"?search=ab" + assert url.raw_path == b"/pa%20th?search=ab" + assert url.fragment == "anchorlink" + + The components of a URL are broken down like this: + + https://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink + [scheme] [ username ] [password] [ host ][port][ path ] [ query ] [fragment] + [ userinfo ] [ netloc ][ raw_path ] + + Note that: + + * `url.scheme` is normalized to always be lowercased. + + * `url.host` is normalized to always be lowercased. Internationalized domain + names are represented in unicode, without IDNA encoding applied. For instance: + + url = httpx.URL("http://ä¸å›½.icom.museum") + assert url.host == "ä¸å›½.icom.museum" + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.host == "ä¸å›½.icom.museum" + + * `url.raw_host` is normalized to always be lowercased, and is IDNA encoded. + + url = httpx.URL("http://ä¸å›½.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + * `url.port` is either None or an integer. URLs that include the default port for + "http", "https", "ws", "wss", and "ftp" schemes have their port normalized to `None`. + + assert httpx.URL("http://example.com") == httpx.URL("http://example.com:80") + assert httpx.URL("http://example.com").port is None + assert httpx.URL("http://example.com:80").port is None + + * `url.userinfo` is raw bytes, without URL escaping. Usually you'll want to work with + `url.username` and `url.password` instead, which handle the URL escaping. + + * `url.raw_path` is raw bytes of both the path and query, without URL escaping. + This portion is used as the target when constructing HTTP requests. Usually you'll + want to work with `url.path` instead. + + * `url.query` is raw bytes, without URL escaping. A URL query string portion can only + be properly URL escaped when decoding the parameter names and values themselves. + """ + + def __init__( + self, url: typing.Union["URL", str] = "", **kwargs: typing.Any + ) -> None: + if kwargs: + allowed = { + "scheme": str, + "username": str, + "password": str, + "userinfo": bytes, + "host": str, + "port": int, + "netloc": bytes, + "path": str, + "query": bytes, + "raw_path": bytes, + "fragment": str, + "params": object, + } + + # Perform type checking for all supported keyword arguments. + for key, value in kwargs.items(): + if key not in allowed: + message = f"{key!r} is an invalid keyword argument for URL()" + raise TypeError(message) + if value is not None and not isinstance(value, allowed[key]): + expected = allowed[key].__name__ + seen = type(value).__name__ + message = f"Argument {key!r} must be {expected} but got {seen}" + raise TypeError(message) + if isinstance(value, bytes): + kwargs[key] = value.decode("ascii") + + if "params" in kwargs: + # Replace any "params" keyword with the raw "query" instead. + # + # Ensure that empty params use `kwargs["query"] = None` rather + # than `kwargs["query"] = ""`, so that generated URLs do not + # include an empty trailing "?". + params = kwargs.pop("params") + kwargs["query"] = None if not params else str(QueryParams(params)) + + if isinstance(url, str): + self._uri_reference = urlparse(url, **kwargs) + elif isinstance(url, URL): + self._uri_reference = url._uri_reference.copy_with(**kwargs) + else: + raise TypeError( + f"Invalid type for url. Expected str or httpx.URL, got {type(url)}: {url!r}" + ) + + @property + def scheme(self) -> str: + """ + The URL scheme, such as "http", "https". + Always normalised to lowercase. + """ + return self._uri_reference.scheme + + @property + def raw_scheme(self) -> bytes: + """ + The raw bytes representation of the URL scheme, such as b"http", b"https". + Always normalised to lowercase. + """ + return self._uri_reference.scheme.encode("ascii") + + @property + def userinfo(self) -> bytes: + """ + The URL userinfo as a raw bytestring. + For example: b"jo%40email.com:a%20secret". + """ + return self._uri_reference.userinfo.encode("ascii") + + @property + def username(self) -> str: + """ + The URL username as a string, with URL decoding applied. + For example: "jo@email.com" + """ + userinfo = self._uri_reference.userinfo + return unquote(userinfo.partition(":")[0]) + + @property + def password(self) -> str: + """ + The URL password as a string, with URL decoding applied. + For example: "a secret" + """ + userinfo = self._uri_reference.userinfo + return unquote(userinfo.partition(":")[2]) + + @property + def host(self) -> str: + """ + The URL host as a string. + Always normalized to lowercase, with IDNA hosts decoded into unicode. + + Examples: + + url = httpx.URL("http://www.EXAMPLE.org") + assert url.host == "www.example.org" + + url = httpx.URL("http://ä¸å›½.icom.museum") + assert url.host == "ä¸å›½.icom.museum" + + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.host == "ä¸å›½.icom.museum" + + url = httpx.URL("https://[::ffff:192.168.0.1]") + assert url.host == "::ffff:192.168.0.1" + """ + host: str = self._uri_reference.host + + if host.startswith("xn--"): + host = idna.decode(host) + + return host + + @property + def raw_host(self) -> bytes: + """ + The raw bytes representation of the URL host. + Always normalized to lowercase, and IDNA encoded. + + Examples: + + url = httpx.URL("http://www.EXAMPLE.org") + assert url.raw_host == b"www.example.org" + + url = httpx.URL("http://ä¸å›½.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + url = httpx.URL("https://[::ffff:192.168.0.1]") + assert url.raw_host == b"::ffff:192.168.0.1" + """ + return self._uri_reference.host.encode("ascii") + + @property + def port(self) -> typing.Optional[int]: + """ + The URL port as an integer. + + Note that the URL class performs port normalization as per the WHATWG spec. + Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always + treated as `None`. + + For example: + + assert httpx.URL("http://www.example.com") == httpx.URL("http://www.example.com:80") + assert httpx.URL("http://www.example.com:80").port is None + """ + return self._uri_reference.port + + @property + def netloc(self) -> bytes: + """ + Either `<host>` or `<host>:<port>` as bytes. + Always normalized to lowercase, and IDNA encoded. + + This property may be used for generating the value of a request + "Host" header. + """ + return self._uri_reference.netloc.encode("ascii") + + @property + def path(self) -> str: + """ + The URL path as a string. Excluding the query string, and URL decoded. + + For example: + + url = httpx.URL("https://example.com/pa%20th") + assert url.path == "/pa th" + """ + path = self._uri_reference.path or "/" + return unquote(path) + + @property + def query(self) -> bytes: + """ + The URL query string, as raw bytes, excluding the leading b"?". + + This is necessarily a bytewise interface, because we cannot + perform URL decoding of this representation until we've parsed + the keys and values into a QueryParams instance. + + For example: + + url = httpx.URL("https://example.com/?filter=some%20search%20terms") + assert url.query == b"filter=some%20search%20terms" + """ + query = self._uri_reference.query or "" + return query.encode("ascii") + + @property + def params(self) -> "QueryParams": + """ + The URL query parameters, neatly parsed and packaged into an immutable + multidict representation. + """ + return QueryParams(self._uri_reference.query) + + @property + def raw_path(self) -> bytes: + """ + The complete URL path and query string as raw bytes. + Used as the target when constructing HTTP requests. + + For example: + + GET /users?search=some%20text HTTP/1.1 + Host: www.example.org + Connection: close + """ + path = self._uri_reference.path or "/" + if self._uri_reference.query is not None: + path += "?" + self._uri_reference.query + return path.encode("ascii") + + @property + def fragment(self) -> str: + """ + The URL fragments, as used in HTML anchors. + As a string, without the leading '#'. + """ + return unquote(self._uri_reference.fragment or "") + + @property + def raw(self) -> RawURL: + """ + Provides the (scheme, host, port, target) for the outgoing request. + + In older versions of `httpx` this was used in the low-level transport API. + We no longer use `RawURL`, and this property will be deprecated in a future release. + """ + return RawURL( + self.raw_scheme, + self.raw_host, + self.port, + self.raw_path, + ) + + @property + def is_absolute_url(self) -> bool: + """ + Return `True` for absolute URLs such as 'http://example.com/path', + and `False` for relative URLs such as '/path'. + """ + # We don't use `.is_absolute` from `rfc3986` because it treats + # URLs with a fragment portion as not absolute. + # What we actually care about is if the URL provides + # a scheme and hostname to which connections should be made. + return bool(self._uri_reference.scheme and self._uri_reference.host) + + @property + def is_relative_url(self) -> bool: + """ + Return `False` for absolute URLs such as 'http://example.com/path', + and `True` for relative URLs such as '/path'. + """ + return not self.is_absolute_url + + def copy_with(self, **kwargs: typing.Any) -> "URL": + """ + Copy this URL, returning a new URL with some components altered. + Accepts the same set of parameters as the components that are made + available via properties on the `URL` class. + + For example: + + url = httpx.URL("https://www.example.com").copy_with(username="jo@gmail.com", password="a secret") + assert url == "https://jo%40email.com:a%20secret@www.example.com" + """ + return URL(self, **kwargs) + + def copy_set_param(self, key: str, value: typing.Any = None) -> "URL": + return self.copy_with(params=self.params.set(key, value)) + + def copy_add_param(self, key: str, value: typing.Any = None) -> "URL": + return self.copy_with(params=self.params.add(key, value)) + + def copy_remove_param(self, key: str) -> "URL": + return self.copy_with(params=self.params.remove(key)) + + def copy_merge_params(self, params: QueryParamTypes) -> "URL": + return self.copy_with(params=self.params.merge(params)) + + def join(self, url: URLTypes) -> "URL": + """ + Return an absolute URL, using this URL as the base. + + Eg. + + url = httpx.URL("https://www.example.com/test") + url = url.join("/new/path") + assert url == "https://www.example.com/new/path" + """ + from urllib.parse import urljoin + + return URL(urljoin(str(self), str(URL(url)))) + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, (URL, str)) and str(self) == str(URL(other)) + + def __str__(self) -> str: + return str(self._uri_reference) + + def __repr__(self) -> str: + scheme, userinfo, host, port, path, query, fragment = self._uri_reference + + if ":" in userinfo: + # Mask any password component. + userinfo = f'{userinfo.split(":")[0]}:[secure]' + + authority = "".join( + [ + f"{userinfo}@" if userinfo else "", + f"[{host}]" if ":" in host else host, + f":{port}" if port is not None else "", + ] + ) + url = "".join( + [ + f"{self.scheme}:" if scheme else "", + f"//{authority}" if authority else "", + path, + f"?{query}" if query is not None else "", + f"#{fragment}" if fragment is not None else "", + ] + ) + + return f"{self.__class__.__name__}({url!r})" + + +class QueryParams(typing.Mapping[str, str]): + """ + URL query parameters, as a multi-dict. + """ + + def __init__( + self, *args: typing.Optional[QueryParamTypes], **kwargs: typing.Any + ) -> None: + assert len(args) < 2, "Too many arguments." + assert not (args and kwargs), "Cannot mix named and unnamed arguments." + + value = args[0] if args else kwargs + + if value is None or isinstance(value, (str, bytes)): + value = value.decode("ascii") if isinstance(value, bytes) else value + self._dict = parse_qs(value, keep_blank_values=True) + elif isinstance(value, QueryParams): + self._dict = {k: list(v) for k, v in value._dict.items()} + else: + dict_value: typing.Dict[typing.Any, typing.List[typing.Any]] = {} + if isinstance(value, (list, tuple)): + # Convert list inputs like: + # [("a", "123"), ("a", "456"), ("b", "789")] + # To a dict representation, like: + # {"a": ["123", "456"], "b": ["789"]} + for item in value: + dict_value.setdefault(item[0], []).append(item[1]) + else: + # Convert dict inputs like: + # {"a": "123", "b": ["456", "789"]} + # To dict inputs where values are always lists, like: + # {"a": ["123"], "b": ["456", "789"]} + dict_value = { + k: list(v) if isinstance(v, (list, tuple)) else [v] + for k, v in value.items() + } + + # Ensure that keys and values are neatly coerced to strings. + # We coerce values `True` and `False` to JSON-like "true" and "false" + # representations, and coerce `None` values to the empty string. + self._dict = { + str(k): [primitive_value_to_str(item) for item in v] + for k, v in dict_value.items() + } + + def keys(self) -> typing.KeysView[str]: + """ + Return all the keys in the query params. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.keys()) == ["a", "b"] + """ + return self._dict.keys() + + def values(self) -> typing.ValuesView[str]: + """ + Return all the values in the query params. If a key occurs more than once + only the first item for that key is returned. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.values()) == ["123", "789"] + """ + return {k: v[0] for k, v in self._dict.items()}.values() + + def items(self) -> typing.ItemsView[str, str]: + """ + Return all items in the query params. If a key occurs more than once + only the first item for that key is returned. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.items()) == [("a", "123"), ("b", "789")] + """ + return {k: v[0] for k, v in self._dict.items()}.items() + + def multi_items(self) -> typing.List[typing.Tuple[str, str]]: + """ + Return all items in the query params. Allow duplicate keys to occur. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")] + """ + multi_items: typing.List[typing.Tuple[str, str]] = [] + for k, v in self._dict.items(): + multi_items.extend([(k, i) for i in v]) + return multi_items + + def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + """ + Get a value from the query param for a given key. If the key occurs + more than once, then only the first value is returned. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert q.get("a") == "123" + """ + if key in self._dict: + return self._dict[str(key)][0] + return default + + def get_list(self, key: str) -> typing.List[str]: + """ + Get all values from the query param for a given key. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert q.get_list("a") == ["123", "456"] + """ + return list(self._dict.get(str(key), [])) + + def set(self, key: str, value: typing.Any = None) -> "QueryParams": + """ + Return a new QueryParams instance, setting the value of a key. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.set("a", "456") + assert q == httpx.QueryParams("a=456") + """ + q = QueryParams() + q._dict = dict(self._dict) + q._dict[str(key)] = [primitive_value_to_str(value)] + return q + + def add(self, key: str, value: typing.Any = None) -> "QueryParams": + """ + Return a new QueryParams instance, setting or appending the value of a key. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.add("a", "456") + assert q == httpx.QueryParams("a=123&a=456") + """ + q = QueryParams() + q._dict = dict(self._dict) + q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)] + return q + + def remove(self, key: str) -> "QueryParams": + """ + Return a new QueryParams instance, removing the value of a key. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.remove("a") + assert q == httpx.QueryParams("") + """ + q = QueryParams() + q._dict = dict(self._dict) + q._dict.pop(str(key), None) + return q + + def merge(self, params: typing.Optional[QueryParamTypes] = None) -> "QueryParams": + """ + Return a new QueryParams instance, updated with. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.merge({"b": "456"}) + assert q == httpx.QueryParams("a=123&b=456") + + q = httpx.QueryParams("a=123") + q = q.merge({"a": "456", "b": "789"}) + assert q == httpx.QueryParams("a=456&b=789") + """ + q = QueryParams(params) + q._dict = {**self._dict, **q._dict} + return q + + def __getitem__(self, key: typing.Any) -> str: + return self._dict[key][0] + + def __contains__(self, key: typing.Any) -> bool: + return key in self._dict + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __bool__(self) -> bool: + return bool(self._dict) + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, self.__class__): + return False + return sorted(self.multi_items()) == sorted(other.multi_items()) + + def __str__(self) -> str: + """ + Note that we use '%20' encoding for spaces, and treat '/' as a safe + character. + + See https://github.com/encode/httpx/issues/2536 and + https://docs.python.org/3/library/urllib.parse.html#urllib.parse.urlencode + """ + return urlencode(self.multi_items()) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + query_string = str(self) + return f"{class_name}({query_string!r})" + + def update(self, params: typing.Optional[QueryParamTypes] = None) -> None: + raise RuntimeError( + "QueryParams are immutable since 0.18.0. " + "Use `q = q.merge(...)` to create an updated copy." + ) + + def __setitem__(self, key: str, value: str) -> None: + raise RuntimeError( + "QueryParams are immutable since 0.18.0. " + "Use `q = q.set(key, value)` to create an updated copy." + ) diff --git a/contrib/python/httpx/httpx/_utils.py b/contrib/python/httpx/httpx/_utils.py new file mode 100644 index 0000000000..1775b1a1ef --- /dev/null +++ b/contrib/python/httpx/httpx/_utils.py @@ -0,0 +1,477 @@ +import codecs +import email.message +import ipaddress +import mimetypes +import os +import re +import time +import typing +from pathlib import Path +from urllib.request import getproxies + +import sniffio + +from ._types import PrimitiveData + +if typing.TYPE_CHECKING: # pragma: no cover + from ._urls import URL + + +_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"} +_HTML5_FORM_ENCODING_REPLACEMENTS.update( + {chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B} +) +_HTML5_FORM_ENCODING_RE = re.compile( + r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()]) +) + + +def normalize_header_key( + value: typing.Union[str, bytes], + lower: bool, + encoding: typing.Optional[str] = None, +) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header key. + """ + if isinstance(value, bytes): + bytes_value = value + else: + bytes_value = value.encode(encoding or "ascii") + + return bytes_value.lower() if lower else bytes_value + + +def normalize_header_value( + value: typing.Union[str, bytes], encoding: typing.Optional[str] = None +) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header value. + """ + if isinstance(value, bytes): + return value + return value.encode(encoding or "ascii") + + +def primitive_value_to_str(value: "PrimitiveData") -> str: + """ + Coerce a primitive data type into a string value. + + Note that we prefer JSON-style 'true'/'false' for boolean values here. + """ + if value is True: + return "true" + elif value is False: + return "false" + elif value is None: + return "" + return str(value) + + +def is_known_encoding(encoding: str) -> bool: + """ + Return `True` if `encoding` is a known codec. + """ + try: + codecs.lookup(encoding) + except LookupError: + return False + return True + + +def format_form_param(name: str, value: str) -> bytes: + """ + Encode a name/value pair within a multipart form. + """ + + def replacer(match: typing.Match[str]) -> str: + return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)] + + value = _HTML5_FORM_ENCODING_RE.sub(replacer, value) + return f'{name}="{value}"'.encode() + + +# Null bytes; no need to recreate these on each call to guess_json_utf +_null = b"\x00" +_null2 = _null * 2 +_null3 = _null * 3 + + +def guess_json_utf(data: bytes) -> typing.Optional[str]: + # JSON always starts with two ASCII characters, so detection is as + # easy as counting the nulls and from their location and count + # determine the encoding. Also detect a BOM, if present. + sample = data[:4] + if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE): + return "utf-32" # BOM included + if sample[:3] == codecs.BOM_UTF8: + return "utf-8-sig" # BOM included, MS style (discouraged) + if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE): + return "utf-16" # BOM included + nullcount = sample.count(_null) + if nullcount == 0: + return "utf-8" + if nullcount == 2: + if sample[::2] == _null2: # 1st and 3rd are null + return "utf-16-be" + if sample[1::2] == _null2: # 2nd and 4th are null + return "utf-16-le" + # Did not detect 2 valid UTF-16 ascii-range characters + if nullcount == 3: + if sample[:3] == _null3: + return "utf-32-be" + if sample[1:] == _null3: + return "utf-32-le" + # Did not detect a valid UTF-32 ascii-range character + return None + + +def get_ca_bundle_from_env() -> typing.Optional[str]: + if "SSL_CERT_FILE" in os.environ: + ssl_file = Path(os.environ["SSL_CERT_FILE"]) + if ssl_file.is_file(): + return str(ssl_file) + if "SSL_CERT_DIR" in os.environ: + ssl_path = Path(os.environ["SSL_CERT_DIR"]) + if ssl_path.is_dir(): + return str(ssl_path) + return None + + +def parse_header_links(value: str) -> typing.List[typing.Dict[str, str]]: + """ + Returns a list of parsed link headers, for more info see: + https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link + The generic syntax of those is: + Link: < uri-reference >; param1=value1; param2="value2" + So for instance: + Link; '<http:/.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;' + would return + [ + {"url": "http:/.../front.jpeg", "type": "image/jpeg"}, + {"url": "http://.../back.jpeg"}, + ] + :param value: HTTP Link entity-header field + :return: list of parsed link headers + """ + links: typing.List[typing.Dict[str, str]] = [] + replace_chars = " '\"" + value = value.strip(replace_chars) + if not value: + return links + for val in re.split(", *<", value): + try: + url, params = val.split(";", 1) + except ValueError: + url, params = val, "" + link = {"url": url.strip("<> '\"")} + for param in params.split(";"): + try: + key, value = param.split("=") + except ValueError: + break + link[key.strip(replace_chars)] = value.strip(replace_chars) + links.append(link) + return links + + +def parse_content_type_charset(content_type: str) -> typing.Optional[str]: + # We used to use `cgi.parse_header()` here, but `cgi` became a dead battery. + # See: https://peps.python.org/pep-0594/#cgi + msg = email.message.Message() + msg["content-type"] = content_type + return msg.get_content_charset(failobj=None) + + +SENSITIVE_HEADERS = {"authorization", "proxy-authorization"} + + +def obfuscate_sensitive_headers( + items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]] +) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]: + for k, v in items: + if to_str(k.lower()) in SENSITIVE_HEADERS: + v = to_bytes_or_str("[secure]", match_type_of=v) + yield k, v + + +def port_or_default(url: "URL") -> typing.Optional[int]: + if url.port is not None: + return url.port + return {"http": 80, "https": 443}.get(url.scheme) + + +def same_origin(url: "URL", other: "URL") -> bool: + """ + Return 'True' if the given URLs share the same origin. + """ + return ( + url.scheme == other.scheme + and url.host == other.host + and port_or_default(url) == port_or_default(other) + ) + + +def is_https_redirect(url: "URL", location: "URL") -> bool: + """ + Return 'True' if 'location' is a HTTPS upgrade of 'url' + """ + if url.host != location.host: + return False + + return ( + url.scheme == "http" + and port_or_default(url) == 80 + and location.scheme == "https" + and port_or_default(location) == 443 + ) + + +def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]: + """Gets proxy information from the environment""" + + # urllib.request.getproxies() falls back on System + # Registry and Config for proxies on Windows and macOS. + # We don't want to propagate non-HTTP proxies into + # our configuration such as 'TRAVIS_APT_PROXY'. + proxy_info = getproxies() + mounts: typing.Dict[str, typing.Optional[str]] = {} + + for scheme in ("http", "https", "all"): + if proxy_info.get(scheme): + hostname = proxy_info[scheme] + mounts[f"{scheme}://"] = ( + hostname if "://" in hostname else f"http://{hostname}" + ) + + no_proxy_hosts = [host.strip() for host in proxy_info.get("no", "").split(",")] + for hostname in no_proxy_hosts: + # See https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html for details + # on how names in `NO_PROXY` are handled. + if hostname == "*": + # If NO_PROXY=* is used or if "*" occurs as any one of the comma + # separated hostnames, then we should just bypass any information + # from HTTP_PROXY, HTTPS_PROXY, ALL_PROXY, and always ignore + # proxies. + return {} + elif hostname: + # NO_PROXY=.google.com is marked as "all://*.google.com, + # which disables "www.google.com" but not "google.com" + # NO_PROXY=google.com is marked as "all://*google.com, + # which disables "www.google.com" and "google.com". + # (But not "wwwgoogle.com") + # NO_PROXY can include domains, IPv6, IPv4 addresses and "localhost" + # NO_PROXY=example.com,::1,localhost,192.168.0.0/16 + if is_ipv4_hostname(hostname): + mounts[f"all://{hostname}"] = None + elif is_ipv6_hostname(hostname): + mounts[f"all://[{hostname}]"] = None + elif hostname.lower() == "localhost": + mounts[f"all://{hostname}"] = None + else: + mounts[f"all://*{hostname}"] = None + + return mounts + + +def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes: + return value.encode(encoding) if isinstance(value, str) else value + + +def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str: + return value if isinstance(value, str) else value.decode(encoding) + + +def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr: + return value if isinstance(match_type_of, str) else value.encode() + + +def unquote(value: str) -> str: + return value[1:-1] if value[0] == value[-1] == '"' else value + + +def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]: + if filename: + return mimetypes.guess_type(filename)[0] or "application/octet-stream" + return None + + +def peek_filelike_length(stream: typing.Any) -> typing.Optional[int]: + """ + Given a file-like stream object, return its length in number of bytes + without reading it into memory. + """ + try: + # Is it an actual file? + fd = stream.fileno() + # Yup, seems to be an actual file. + length = os.fstat(fd).st_size + except (AttributeError, OSError): + # No... Maybe it's something that supports random access, like `io.BytesIO`? + try: + # Assuming so, go to end of stream to figure out its length, + # then put it back in place. + offset = stream.tell() + length = stream.seek(0, os.SEEK_END) + stream.seek(offset) + except (AttributeError, OSError): + # Not even that? Sorry, we're doomed... + return None + + return length + + +class Timer: + async def _get_time(self) -> float: + library = sniffio.current_async_library() + if library == "trio": + import trio + + return trio.current_time() + elif library == "curio": # pragma: no cover + import curio + + return typing.cast(float, await curio.clock()) + + import asyncio + + return asyncio.get_event_loop().time() + + def sync_start(self) -> None: + self.started = time.perf_counter() + + async def async_start(self) -> None: + self.started = await self._get_time() + + def sync_elapsed(self) -> float: + now = time.perf_counter() + return now - self.started + + async def async_elapsed(self) -> float: + now = await self._get_time() + return now - self.started + + +class URLPattern: + """ + A utility class currently used for making lookups against proxy keys... + + # Wildcard matching... + >>> pattern = URLPattern("all://") + >>> pattern.matches(httpx.URL("http://example.com")) + True + + # Witch scheme matching... + >>> pattern = URLPattern("https://") + >>> pattern.matches(httpx.URL("https://example.com")) + True + >>> pattern.matches(httpx.URL("http://example.com")) + False + + # With domain matching... + >>> pattern = URLPattern("https://example.com") + >>> pattern.matches(httpx.URL("https://example.com")) + True + >>> pattern.matches(httpx.URL("http://example.com")) + False + >>> pattern.matches(httpx.URL("https://other.com")) + False + + # Wildcard scheme, with domain matching... + >>> pattern = URLPattern("all://example.com") + >>> pattern.matches(httpx.URL("https://example.com")) + True + >>> pattern.matches(httpx.URL("http://example.com")) + True + >>> pattern.matches(httpx.URL("https://other.com")) + False + + # With port matching... + >>> pattern = URLPattern("https://example.com:1234") + >>> pattern.matches(httpx.URL("https://example.com:1234")) + True + >>> pattern.matches(httpx.URL("https://example.com")) + False + """ + + def __init__(self, pattern: str) -> None: + from ._urls import URL + + if pattern and ":" not in pattern: + raise ValueError( + f"Proxy keys should use proper URL forms rather " + f"than plain scheme strings. " + f'Instead of "{pattern}", use "{pattern}://"' + ) + + url = URL(pattern) + self.pattern = pattern + self.scheme = "" if url.scheme == "all" else url.scheme + self.host = "" if url.host == "*" else url.host + self.port = url.port + if not url.host or url.host == "*": + self.host_regex: typing.Optional[typing.Pattern[str]] = None + elif url.host.startswith("*."): + # *.example.com should match "www.example.com", but not "example.com" + domain = re.escape(url.host[2:]) + self.host_regex = re.compile(f"^.+\\.{domain}$") + elif url.host.startswith("*"): + # *example.com should match "www.example.com" and "example.com" + domain = re.escape(url.host[1:]) + self.host_regex = re.compile(f"^(.+\\.)?{domain}$") + else: + # example.com should match "example.com" but not "www.example.com" + domain = re.escape(url.host) + self.host_regex = re.compile(f"^{domain}$") + + def matches(self, other: "URL") -> bool: + if self.scheme and self.scheme != other.scheme: + return False + if ( + self.host + and self.host_regex is not None + and not self.host_regex.match(other.host) + ): + return False + if self.port is not None and self.port != other.port: + return False + return True + + @property + def priority(self) -> typing.Tuple[int, int, int]: + """ + The priority allows URLPattern instances to be sortable, so that + we can match from most specific to least specific. + """ + # URLs with a port should take priority over URLs without a port. + port_priority = 0 if self.port is not None else 1 + # Longer hostnames should match first. + host_priority = -len(self.host) + # Longer schemes should match first. + scheme_priority = -len(self.scheme) + return (port_priority, host_priority, scheme_priority) + + def __hash__(self) -> int: + return hash(self.pattern) + + def __lt__(self, other: "URLPattern") -> bool: + return self.priority < other.priority + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, URLPattern) and self.pattern == other.pattern + + +def is_ipv4_hostname(hostname: str) -> bool: + try: + ipaddress.IPv4Address(hostname.split("/")[0]) + except Exception: + return False + return True + + +def is_ipv6_hostname(hostname: str) -> bool: + try: + ipaddress.IPv6Address(hostname.split("/")[0]) + except Exception: + return False + return True diff --git a/contrib/python/httpx/httpx/py.typed b/contrib/python/httpx/httpx/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/httpx/httpx/py.typed diff --git a/contrib/python/httpx/ya.make b/contrib/python/httpx/ya.make new file mode 100644 index 0000000000..850e354ef0 --- /dev/null +++ b/contrib/python/httpx/ya.make @@ -0,0 +1,58 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +VERSION(0.25.0) + +LICENSE(BSD-3-Clause) + +PEERDIR( + contrib/python/certifi + contrib/python/httpcore + contrib/python/idna + contrib/python/sniffio +) + +NO_LINT() + +NO_CHECK_IMPORTS( + httpx._main +) + +PY_SRCS( + TOP_LEVEL + httpx/__init__.py + httpx/__version__.py + httpx/_api.py + httpx/_auth.py + httpx/_client.py + httpx/_compat.py + httpx/_config.py + httpx/_content.py + httpx/_decoders.py + httpx/_exceptions.py + httpx/_main.py + httpx/_models.py + httpx/_multipart.py + httpx/_status_codes.py + httpx/_transports/__init__.py + httpx/_transports/asgi.py + httpx/_transports/base.py + httpx/_transports/default.py + httpx/_transports/mock.py + httpx/_transports/wsgi.py + httpx/_types.py + httpx/_urlparse.py + httpx/_urls.py + httpx/_utils.py +) + +RESOURCE_FILES( + PREFIX contrib/python/httpx/ + .dist-info/METADATA + .dist-info/entry_points.txt + .dist-info/top_level.txt + httpx/py.typed +) + +END() diff --git a/contrib/python/sniffio/.dist-info/METADATA b/contrib/python/sniffio/.dist-info/METADATA new file mode 100644 index 0000000000..22520c72af --- /dev/null +++ b/contrib/python/sniffio/.dist-info/METADATA @@ -0,0 +1,102 @@ +Metadata-Version: 2.1 +Name: sniffio +Version: 1.3.0 +Summary: Sniff out which async library your code is running under +Home-page: https://github.com/python-trio/sniffio +Author: Nathaniel J. Smith +Author-email: njs@pobox.com +License: MIT OR Apache-2.0 +Keywords: async,trio,asyncio +Classifier: License :: OSI Approved :: MIT License +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Framework :: Trio +Classifier: Framework :: AsyncIO +Classifier: Operating System :: POSIX :: Linux +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: Microsoft :: Windows +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Intended Audience :: Developers +Classifier: Development Status :: 5 - Production/Stable +Requires-Python: >=3.7 +License-File: LICENSE +License-File: LICENSE.APACHE2 +License-File: LICENSE.MIT + +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://sniffio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +.. image:: https://img.shields.io/pypi/v/sniffio.svg + :target: https://pypi.org/project/sniffio + :alt: Latest PyPi version + +.. image:: https://img.shields.io/conda/vn/conda-forge/sniffio.svg + :target: https://anaconda.org/conda-forge/sniffio + :alt: Latest conda-forge version + +.. image:: https://travis-ci.org/python-trio/sniffio.svg?branch=master + :target: https://travis-ci.org/python-trio/sniffio + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-trio/sniffio/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-trio/sniffio + :alt: Test coverage + +================================================================= +sniffio: Sniff out which async library your code is running under +================================================================= + +You're writing a library. You've decided to be ambitious, and support +multiple async I/O packages, like `Trio +<https://trio.readthedocs.io>`__, and `asyncio +<https://docs.python.org/3/library/asyncio.html>`__, and ... You've +written a bunch of clever code to handle all the differences. But... +how do you know *which* piece of clever code to run? + +This is a tiny package whose only purpose is to let you detect which +async library your code is running under. + +* Documentation: https://sniffio.readthedocs.io + +* Bug tracker and source code: https://github.com/python-trio/sniffio + +* License: MIT or Apache License 2.0, your choice + +* Contributor guide: https://trio.readthedocs.io/en/latest/contributing.html + +* Code of conduct: Contributors are requested to follow our `code of + conduct + <https://trio.readthedocs.io/en/latest/code-of-conduct.html>`_ + in all project spaces. + +This library is maintained by the Trio project, as a service to the +async Python community as a whole. + + +Quickstart +---------- + +.. code-block:: python3 + + from sniffio import current_async_library + import trio + import asyncio + + async def print_library(): + library = current_async_library() + print("This is:", library) + + # Prints "This is trio" + trio.run(print_library) + + # Prints "This is asyncio" + asyncio.run(print_library()) + +For more details, including how to add support to new async libraries, +`please peruse our fine manual <https://sniffio.readthedocs.io>`__. diff --git a/contrib/python/sniffio/.dist-info/top_level.txt b/contrib/python/sniffio/.dist-info/top_level.txt new file mode 100644 index 0000000000..01c650244d --- /dev/null +++ b/contrib/python/sniffio/.dist-info/top_level.txt @@ -0,0 +1 @@ +sniffio diff --git a/contrib/python/sniffio/LICENSE b/contrib/python/sniffio/LICENSE new file mode 100644 index 0000000000..51f3442917 --- /dev/null +++ b/contrib/python/sniffio/LICENSE @@ -0,0 +1,3 @@ +This software is made available under the terms of *either* of the +licenses found in LICENSE.APACHE2 or LICENSE.MIT. Contributions to are +made under the terms of *both* these licenses. diff --git a/contrib/python/sniffio/LICENSE.APACHE2 b/contrib/python/sniffio/LICENSE.APACHE2 new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/contrib/python/sniffio/LICENSE.APACHE2 @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrib/python/sniffio/LICENSE.MIT b/contrib/python/sniffio/LICENSE.MIT new file mode 100644 index 0000000000..b8bb971859 --- /dev/null +++ b/contrib/python/sniffio/LICENSE.MIT @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/contrib/python/sniffio/README.rst b/contrib/python/sniffio/README.rst new file mode 100644 index 0000000000..2a62cea5a0 --- /dev/null +++ b/contrib/python/sniffio/README.rst @@ -0,0 +1,76 @@ +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://sniffio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +.. image:: https://img.shields.io/pypi/v/sniffio.svg + :target: https://pypi.org/project/sniffio + :alt: Latest PyPi version + +.. image:: https://img.shields.io/conda/vn/conda-forge/sniffio.svg + :target: https://anaconda.org/conda-forge/sniffio + :alt: Latest conda-forge version + +.. image:: https://travis-ci.org/python-trio/sniffio.svg?branch=master + :target: https://travis-ci.org/python-trio/sniffio + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-trio/sniffio/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-trio/sniffio + :alt: Test coverage + +================================================================= +sniffio: Sniff out which async library your code is running under +================================================================= + +You're writing a library. You've decided to be ambitious, and support +multiple async I/O packages, like `Trio +<https://trio.readthedocs.io>`__, and `asyncio +<https://docs.python.org/3/library/asyncio.html>`__, and ... You've +written a bunch of clever code to handle all the differences. But... +how do you know *which* piece of clever code to run? + +This is a tiny package whose only purpose is to let you detect which +async library your code is running under. + +* Documentation: https://sniffio.readthedocs.io + +* Bug tracker and source code: https://github.com/python-trio/sniffio + +* License: MIT or Apache License 2.0, your choice + +* Contributor guide: https://trio.readthedocs.io/en/latest/contributing.html + +* Code of conduct: Contributors are requested to follow our `code of + conduct + <https://trio.readthedocs.io/en/latest/code-of-conduct.html>`_ + in all project spaces. + +This library is maintained by the Trio project, as a service to the +async Python community as a whole. + + +Quickstart +---------- + +.. code-block:: python3 + + from sniffio import current_async_library + import trio + import asyncio + + async def print_library(): + library = current_async_library() + print("This is:", library) + + # Prints "This is trio" + trio.run(print_library) + + # Prints "This is asyncio" + asyncio.run(print_library()) + +For more details, including how to add support to new async libraries, +`please peruse our fine manual <https://sniffio.readthedocs.io>`__. diff --git a/contrib/python/sniffio/sniffio/__init__.py b/contrib/python/sniffio/sniffio/__init__.py new file mode 100644 index 0000000000..fb3364d7f1 --- /dev/null +++ b/contrib/python/sniffio/sniffio/__init__.py @@ -0,0 +1,15 @@ +"""Top-level package for sniffio.""" + +__all__ = [ + "current_async_library", "AsyncLibraryNotFoundError", + "current_async_library_cvar" +] + +from ._version import __version__ + +from ._impl import ( + current_async_library, + AsyncLibraryNotFoundError, + current_async_library_cvar, + thread_local, +) diff --git a/contrib/python/sniffio/sniffio/_impl.py b/contrib/python/sniffio/sniffio/_impl.py new file mode 100644 index 0000000000..c1a7bbf218 --- /dev/null +++ b/contrib/python/sniffio/sniffio/_impl.py @@ -0,0 +1,95 @@ +from contextvars import ContextVar +from typing import Optional +import sys +import threading + +current_async_library_cvar = ContextVar( + "current_async_library_cvar", default=None +) # type: ContextVar[Optional[str]] + + +class _ThreadLocal(threading.local): + # Since threading.local provides no explicit mechanism is for setting + # a default for a value, a custom class with a class attribute is used + # instead. + name = None # type: Optional[str] + + +thread_local = _ThreadLocal() + + +class AsyncLibraryNotFoundError(RuntimeError): + pass + + +def current_async_library() -> str: + """Detect which async library is currently running. + + The following libraries are currently supported: + + ================ =========== ============================ + Library Requires Magic string + ================ =========== ============================ + **Trio** Trio v0.6+ ``"trio"`` + **Curio** - ``"curio"`` + **asyncio** ``"asyncio"`` + **Trio-asyncio** v0.8.2+ ``"trio"`` or ``"asyncio"``, + depending on current mode + ================ =========== ============================ + + Returns: + A string like ``"trio"``. + + Raises: + AsyncLibraryNotFoundError: if called from synchronous context, + or if the current async library was not recognized. + + Examples: + + .. code-block:: python3 + + from sniffio import current_async_library + + async def generic_sleep(seconds): + library = current_async_library() + if library == "trio": + import trio + await trio.sleep(seconds) + elif library == "asyncio": + import asyncio + await asyncio.sleep(seconds) + # ... and so on ... + else: + raise RuntimeError(f"Unsupported library {library!r}") + + """ + value = thread_local.name + if value is not None: + return value + + value = current_async_library_cvar.get() + if value is not None: + return value + + # Need to sniff for asyncio + if "asyncio" in sys.modules: + import asyncio + try: + current_task = asyncio.current_task # type: ignore[attr-defined] + except AttributeError: + current_task = asyncio.Task.current_task # type: ignore[attr-defined] + try: + if current_task() is not None: + return "asyncio" + except RuntimeError: + pass + + # Sniff for curio (for now) + if 'curio' in sys.modules: + from curio.meta import curio_running + if curio_running(): + return 'curio' + + raise AsyncLibraryNotFoundError( + "unknown async library, or not in async context" + ) diff --git a/contrib/python/sniffio/sniffio/_version.py b/contrib/python/sniffio/sniffio/_version.py new file mode 100644 index 0000000000..5a5f906bbf --- /dev/null +++ b/contrib/python/sniffio/sniffio/_version.py @@ -0,0 +1,3 @@ +# This file is imported from __init__.py and exec'd from setup.py + +__version__ = "1.3.0" diff --git a/contrib/python/sniffio/sniffio/py.typed b/contrib/python/sniffio/sniffio/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/sniffio/sniffio/py.typed diff --git a/contrib/python/sniffio/ya.make b/contrib/python/sniffio/ya.make new file mode 100644 index 0000000000..d0e376d4ca --- /dev/null +++ b/contrib/python/sniffio/ya.make @@ -0,0 +1,25 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +VERSION(1.3.0) + +LICENSE(Apache-2.0 AND MIT) + +NO_LINT() + +PY_SRCS( + TOP_LEVEL + sniffio/__init__.py + sniffio/_impl.py + sniffio/_version.py +) + +RESOURCE_FILES( + PREFIX contrib/python/sniffio/ + .dist-info/METADATA + .dist-info/top_level.txt + sniffio/py.typed +) + +END() |