diff options
| author | robot-piglet <[email protected]> | 2025-10-16 23:00:05 +0300 |
|---|---|---|
| committer | robot-piglet <[email protected]> | 2025-10-16 23:13:00 +0300 |
| commit | 11655c9ebab3639829f448726c53598a309836d2 (patch) | |
| tree | 0fb621e064ba793816eee485f5b5b2b0d770c234 /contrib/python/aiohttp | |
| parent | ab033a87e63230a4a2116281fc1dc2d0072fc894 (diff) | |
Intermediate changes
commit_hash:736b945c094519c9026454678b074141602f4584
Diffstat (limited to 'contrib/python/aiohttp')
48 files changed, 2512 insertions, 1443 deletions
diff --git a/contrib/python/aiohttp/.dist-info/METADATA b/contrib/python/aiohttp/.dist-info/METADATA index cd312649136..7ad8cef0fd3 100644 --- a/contrib/python/aiohttp/.dist-info/METADATA +++ b/contrib/python/aiohttp/.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: aiohttp -Version: 3.9.5 +Version: 3.10.6 Summary: Async http client/server framework (asyncio) Home-page: https://github.com/aio-libs/aiohttp Maintainer: aiohttp team <[email protected]> @@ -28,20 +28,22 @@ Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 Classifier: Topic :: Internet :: WWW/HTTP Requires-Python: >=3.8 Description-Content-Type: text/x-rst License-File: LICENSE.txt +Requires-Dist: aiohappyeyeballs >=2.3.0 Requires-Dist: aiosignal >=1.1.2 Requires-Dist: attrs >=17.3.0 Requires-Dist: frozenlist >=1.1.1 Requires-Dist: multidict <7.0,>=4.5 -Requires-Dist: yarl <2.0,>=1.0 +Requires-Dist: yarl <2.0,>=1.12.0 Requires-Dist: async-timeout <5.0,>=4.0 ; python_version < "3.11" Provides-Extra: speedups Requires-Dist: brotlicffi ; (platform_python_implementation != "CPython") and extra == 'speedups' Requires-Dist: Brotli ; (platform_python_implementation == "CPython") and extra == 'speedups' -Requires-Dist: aiodns ; (sys_platform == "linux" or sys_platform == "darwin") and extra == 'speedups' +Requires-Dist: aiodns >=3.2.0 ; (sys_platform == "linux" or sys_platform == "darwin") and extra == 'speedups' ================================== Async http client/server framework @@ -193,7 +195,7 @@ Communication channels *aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions -*gitter chat* https://gitter.im/aio-libs/Lobby +*Matrix*: `#aio-libs:matrix.org <https://matrix.to/#/#aio-libs:matrix.org>`_ We support `Stack Overflow <https://stackoverflow.com/questions/tagged/aiohttp>`_. @@ -202,7 +204,6 @@ Please add *aiohttp* tag to your question there. Requirements ============ -- async-timeout_ - attrs_ - multidict_ - yarl_ diff --git a/contrib/python/aiohttp/README.rst b/contrib/python/aiohttp/README.rst index 90b7f713577..470ced9b29c 100644 --- a/contrib/python/aiohttp/README.rst +++ b/contrib/python/aiohttp/README.rst @@ -148,7 +148,7 @@ Communication channels *aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions -*gitter chat* https://gitter.im/aio-libs/Lobby +*Matrix*: `#aio-libs:matrix.org <https://matrix.to/#/#aio-libs:matrix.org>`_ We support `Stack Overflow <https://stackoverflow.com/questions/tagged/aiohttp>`_. @@ -157,7 +157,6 @@ Please add *aiohttp* tag to your question there. Requirements ============ -- async-timeout_ - attrs_ - multidict_ - yarl_ diff --git a/contrib/python/aiohttp/aiohttp/__init__.py b/contrib/python/aiohttp/aiohttp/__init__.py index e82e790b46a..8830d340940 100644 --- a/contrib/python/aiohttp/aiohttp/__init__.py +++ b/contrib/python/aiohttp/aiohttp/__init__.py @@ -1,40 +1,48 @@ -__version__ = "3.9.5" +__version__ = "3.10.6" from typing import TYPE_CHECKING, Tuple from . import hdrs as hdrs from .client import ( - BaseConnector as BaseConnector, - ClientConnectionError as ClientConnectionError, - ClientConnectorCertificateError as ClientConnectorCertificateError, - ClientConnectorError as ClientConnectorError, - ClientConnectorSSLError as ClientConnectorSSLError, - ClientError as ClientError, - ClientHttpProxyError as ClientHttpProxyError, - ClientOSError as ClientOSError, - ClientPayloadError as ClientPayloadError, - ClientProxyConnectionError as ClientProxyConnectionError, - ClientRequest as ClientRequest, - ClientResponse as ClientResponse, - ClientResponseError as ClientResponseError, - ClientSession as ClientSession, - ClientSSLError as ClientSSLError, - ClientTimeout as ClientTimeout, - ClientWebSocketResponse as ClientWebSocketResponse, - ContentTypeError as ContentTypeError, - Fingerprint as Fingerprint, - InvalidURL as InvalidURL, - NamedPipeConnector as NamedPipeConnector, - RequestInfo as RequestInfo, - ServerConnectionError as ServerConnectionError, - ServerDisconnectedError as ServerDisconnectedError, - ServerFingerprintMismatch as ServerFingerprintMismatch, - ServerTimeoutError as ServerTimeoutError, - TCPConnector as TCPConnector, - TooManyRedirects as TooManyRedirects, - UnixConnector as UnixConnector, - WSServerHandshakeError as WSServerHandshakeError, - request as request, + BaseConnector, + ClientConnectionError, + ClientConnectionResetError, + ClientConnectorCertificateError, + ClientConnectorError, + ClientConnectorSSLError, + ClientError, + ClientHttpProxyError, + ClientOSError, + ClientPayloadError, + ClientProxyConnectionError, + ClientRequest, + ClientResponse, + ClientResponseError, + ClientSession, + ClientSSLError, + ClientTimeout, + ClientWebSocketResponse, + ConnectionTimeoutError, + ContentTypeError, + Fingerprint, + InvalidURL, + InvalidUrlClientError, + InvalidUrlRedirectClientError, + NamedPipeConnector, + NonHttpUrlClientError, + NonHttpUrlRedirectClientError, + RedirectClientError, + RequestInfo, + ServerConnectionError, + ServerDisconnectedError, + ServerFingerprintMismatch, + ServerTimeoutError, + SocketTimeoutError, + TCPConnector, + TooManyRedirects, + UnixConnector, + WSServerHandshakeError, + request, ) from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar from .formdata import FormData as FormData @@ -99,6 +107,7 @@ from .tracing import ( TraceRequestChunkSentParams as TraceRequestChunkSentParams, TraceRequestEndParams as TraceRequestEndParams, TraceRequestExceptionParams as TraceRequestExceptionParams, + TraceRequestHeadersSentParams as TraceRequestHeadersSentParams, TraceRequestRedirectParams as TraceRequestRedirectParams, TraceRequestStartParams as TraceRequestStartParams, TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams, @@ -116,6 +125,7 @@ __all__: Tuple[str, ...] = ( # client "BaseConnector", "ClientConnectionError", + "ClientConnectionResetError", "ClientConnectorCertificateError", "ClientConnectorError", "ClientConnectorSSLError", @@ -131,14 +141,21 @@ __all__: Tuple[str, ...] = ( "ClientSession", "ClientTimeout", "ClientWebSocketResponse", + "ConnectionTimeoutError", "ContentTypeError", "Fingerprint", "InvalidURL", + "InvalidUrlClientError", + "InvalidUrlRedirectClientError", + "NonHttpUrlClientError", + "NonHttpUrlRedirectClientError", + "RedirectClientError", "RequestInfo", "ServerConnectionError", "ServerDisconnectedError", "ServerFingerprintMismatch", "ServerTimeoutError", + "SocketTimeoutError", "TCPConnector", "TooManyRedirects", "UnixConnector", @@ -210,6 +227,7 @@ __all__: Tuple[str, ...] = ( "TraceRequestChunkSentParams", "TraceRequestEndParams", "TraceRequestExceptionParams", + "TraceRequestHeadersSentParams", "TraceRequestRedirectParams", "TraceRequestStartParams", "TraceResponseChunkReceivedParams", diff --git a/contrib/python/aiohttp/aiohttp/_helpers.pyx b/contrib/python/aiohttp/aiohttp/_helpers.pyx index 665f367c5de..5f089225dc8 100644 --- a/contrib/python/aiohttp/aiohttp/_helpers.pyx +++ b/contrib/python/aiohttp/aiohttp/_helpers.pyx @@ -1,3 +1,6 @@ + +cdef _sentinel = object() + cdef class reify: """Use as a class method decorator. It operates almost exactly like the Python `@property` decorator, but it puts the result of the @@ -19,17 +22,14 @@ cdef class reify: return self.wrapped.__doc__ def __get__(self, inst, owner): - try: - try: - return inst._cache[self.name] - except KeyError: - val = self.wrapped(inst) - inst._cache[self.name] = val - return val - except AttributeError: - if inst is None: - return self - raise + if inst is None: + return self + cdef dict cache = inst._cache + val = cache.get(self.name, _sentinel) + if val is _sentinel: + val = self.wrapped(inst) + cache[self.name] = val + return val def __set__(self, inst, value): raise AttributeError("reified property is read-only") diff --git a/contrib/python/aiohttp/aiohttp/_http_parser.pyx b/contrib/python/aiohttp/aiohttp/_http_parser.pyx index ec6edb2dfec..8e82c0fd77f 100644 --- a/contrib/python/aiohttp/aiohttp/_http_parser.pyx +++ b/contrib/python/aiohttp/aiohttp/_http_parser.pyx @@ -47,6 +47,7 @@ include "_headers.pxi" from aiohttp cimport _find_header +ALLOWED_UPGRADES = frozenset({"websocket"}) DEF DEFAULT_FREELIST_SIZE = 250 cdef extern from "Python.h": @@ -417,7 +418,6 @@ cdef class HttpParser: cdef _on_headers_complete(self): self._process_header() - method = http_method_str(self._cparser.method) should_close = not cparser.llhttp_should_keep_alive(self._cparser) upgrade = self._cparser.upgrade chunked = self._cparser.flags & cparser.F_CHUNKED @@ -425,8 +425,13 @@ cdef class HttpParser: raw_headers = tuple(self._raw_headers) headers = CIMultiDictProxy(self._headers) - if upgrade or self._cparser.method == cparser.HTTP_CONNECT: - self._upgraded = True + if self._cparser.type == cparser.HTTP_REQUEST: + allowed = upgrade and headers.get("upgrade", "").lower() in ALLOWED_UPGRADES + if allowed or self._cparser.method == cparser.HTTP_CONNECT: + self._upgraded = True + else: + if upgrade and self._cparser.status_code == 101: + self._upgraded = True # do not support old websocket spec if SEC_WEBSOCKET_KEY1 in headers: @@ -441,6 +446,7 @@ cdef class HttpParser: encoding = enc if self._cparser.type == cparser.HTTP_REQUEST: + method = http_method_str(self._cparser.method) msg = _new_request_message( method, self._path, self.http_version(), headers, raw_headers, @@ -565,7 +571,7 @@ cdef class HttpParser: if self._upgraded: return messages, True, data[nb:] else: - return messages, False, b'' + return messages, False, b"" def set_upgraded(self, val): self._upgraded = val @@ -748,10 +754,7 @@ cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1: pyparser._last_error = exc return -1 else: - if ( - pyparser._cparser.upgrade or - pyparser._cparser.method == cparser.HTTP_CONNECT - ): + if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT: return 2 else: return 0 diff --git a/contrib/python/aiohttp/aiohttp/abc.py b/contrib/python/aiohttp/aiohttp/abc.py index ee838998997..b309bcffe01 100644 --- a/contrib/python/aiohttp/aiohttp/abc.py +++ b/contrib/python/aiohttp/aiohttp/abc.py @@ -1,5 +1,6 @@ import asyncio import logging +import socket from abc import ABC, abstractmethod from collections.abc import Sized from http.cookies import BaseCookie, Morsel @@ -14,12 +15,12 @@ from typing import ( List, Optional, Tuple, + TypedDict, ) from multidict import CIMultiDict from yarl import URL -from .helpers import get_running_loop from .typedefs import LooseCookies if TYPE_CHECKING: @@ -119,11 +120,35 @@ class AbstractView(ABC): """Execute the view handler.""" +class ResolveResult(TypedDict): + """Resolve result. + + This is the result returned from an AbstractResolver's + resolve method. + + :param hostname: The hostname that was provided. + :param host: The IP address that was resolved. + :param port: The port that was resolved. + :param family: The address family that was resolved. + :param proto: The protocol that was resolved. + :param flags: The flags that were resolved. + """ + + hostname: str + host: str + port: int + family: int + proto: int + flags: int + + class AbstractResolver(ABC): """Abstract DNS resolver.""" @abstractmethod - async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]: + async def resolve( + self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET + ) -> List[ResolveResult]: """Return IP address for given hostname""" @abstractmethod @@ -144,7 +169,7 @@ class AbstractCookieJar(Sized, IterableBase): """Abstract Cookie Jar.""" def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: - self._loop = get_running_loop(loop) + self._loop = loop or asyncio.get_event_loop() @abstractmethod def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: diff --git a/contrib/python/aiohttp/aiohttp/base_protocol.py b/contrib/python/aiohttp/aiohttp/base_protocol.py index dc1f24f99cd..2fc2fa65885 100644 --- a/contrib/python/aiohttp/aiohttp/base_protocol.py +++ b/contrib/python/aiohttp/aiohttp/base_protocol.py @@ -1,6 +1,7 @@ import asyncio from typing import Optional, cast +from .client_exceptions import ClientConnectionResetError from .helpers import set_exception from .tcp_helpers import tcp_nodelay @@ -85,7 +86,7 @@ class BaseProtocol(asyncio.Protocol): async def _drain_helper(self) -> None: if not self.connected: - raise ConnectionResetError("Connection lost") + raise ClientConnectionResetError("Connection lost") if not self._paused: return waiter = self._drain_waiter diff --git a/contrib/python/aiohttp/aiohttp/client.py b/contrib/python/aiohttp/aiohttp/client.py index 32d2c3b7119..186105bee9f 100644 --- a/contrib/python/aiohttp/aiohttp/client.py +++ b/contrib/python/aiohttp/aiohttp/client.py @@ -9,7 +9,7 @@ import sys import traceback import warnings from contextlib import suppress -from types import SimpleNamespace, TracebackType +from types import TracebackType from typing import ( TYPE_CHECKING, Any, @@ -27,6 +27,7 @@ from typing import ( Set, Tuple, Type, + TypedDict, TypeVar, Union, ) @@ -38,25 +39,33 @@ from yarl import URL from . import hdrs, http, payload from .abc import AbstractCookieJar from .client_exceptions import ( - ClientConnectionError as ClientConnectionError, - ClientConnectorCertificateError as ClientConnectorCertificateError, - ClientConnectorError as ClientConnectorError, - ClientConnectorSSLError as ClientConnectorSSLError, - ClientError as ClientError, - ClientHttpProxyError as ClientHttpProxyError, - ClientOSError as ClientOSError, - ClientPayloadError as ClientPayloadError, - ClientProxyConnectionError as ClientProxyConnectionError, - ClientResponseError as ClientResponseError, - ClientSSLError as ClientSSLError, - ContentTypeError as ContentTypeError, - InvalidURL as InvalidURL, - ServerConnectionError as ServerConnectionError, - ServerDisconnectedError as ServerDisconnectedError, - ServerFingerprintMismatch as ServerFingerprintMismatch, - ServerTimeoutError as ServerTimeoutError, - TooManyRedirects as TooManyRedirects, - WSServerHandshakeError as WSServerHandshakeError, + ClientConnectionError, + ClientConnectionResetError, + ClientConnectorCertificateError, + ClientConnectorError, + ClientConnectorSSLError, + ClientError, + ClientHttpProxyError, + ClientOSError, + ClientPayloadError, + ClientProxyConnectionError, + ClientResponseError, + ClientSSLError, + ConnectionTimeoutError, + ContentTypeError, + InvalidURL, + InvalidUrlClientError, + InvalidUrlRedirectClientError, + NonHttpUrlClientError, + NonHttpUrlRedirectClientError, + RedirectClientError, + ServerConnectionError, + ServerDisconnectedError, + ServerFingerprintMismatch, + ServerTimeoutError, + SocketTimeoutError, + TooManyRedirects, + WSServerHandshakeError, ) from .client_reqrep import ( ClientRequest as ClientRequest, @@ -67,6 +76,7 @@ from .client_reqrep import ( ) from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse from .connector import ( + HTTP_AND_EMPTY_SCHEMA_SET, BaseConnector as BaseConnector, NamedPipeConnector as NamedPipeConnector, TCPConnector as TCPConnector, @@ -80,7 +90,6 @@ from .helpers import ( TimeoutHandle, ceil_timeout, get_env_proxy_for_url, - get_running_loop, method_must_be_empty_body, sentinel, strip_auth_from_url, @@ -89,11 +98,12 @@ from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse from .streams import FlowControlDataQueue from .tracing import Trace, TraceConfig -from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, StrOrURL +from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, Query, StrOrURL __all__ = ( # client_exceptions "ClientConnectionError", + "ClientConnectionResetError", "ClientConnectorCertificateError", "ClientConnectorError", "ClientConnectorSSLError", @@ -104,12 +114,19 @@ __all__ = ( "ClientProxyConnectionError", "ClientResponseError", "ClientSSLError", + "ConnectionTimeoutError", "ContentTypeError", "InvalidURL", + "InvalidUrlClientError", + "RedirectClientError", + "NonHttpUrlClientError", + "InvalidUrlRedirectClientError", + "NonHttpUrlRedirectClientError", "ServerConnectionError", "ServerDisconnectedError", "ServerFingerprintMismatch", "ServerTimeoutError", + "SocketTimeoutError", "TooManyRedirects", "WSServerHandshakeError", # client_reqrep @@ -136,6 +153,37 @@ if TYPE_CHECKING: else: SSLContext = None +if sys.version_info >= (3, 11) and TYPE_CHECKING: + from typing import Unpack + + +class _RequestOptions(TypedDict, total=False): + params: Query + data: Any + json: Any + cookies: Union[LooseCookies, None] + headers: Union[LooseHeaders, None] + skip_auto_headers: Union[Iterable[str], None] + auth: Union[BasicAuth, None] + allow_redirects: bool + max_redirects: int + compress: Union[str, bool, None] + chunked: Union[bool, None] + expect100: bool + raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]] + read_until_eof: bool + proxy: Union[StrOrURL, None] + proxy_auth: Union[BasicAuth, None] + timeout: "Union[ClientTimeout, _SENTINEL, int, float, None]" + ssl: Union[SSLContext, bool, Fingerprint] + server_hostname: Union[str, None] + proxy_headers: Union[LooseHeaders, None] + trace_request_ctx: Any #Union[Mapping[str, str], None] + read_bufsize: Union[int, None] + auto_decompress: Union[bool, None] + max_line_size: Union[int, None] + max_field_size: Union[int, None] + @attr.s(auto_attribs=True, frozen=True, slots=True) class ClientTimeout: @@ -162,7 +210,10 @@ class ClientTimeout: # 5 Minute default read timeout DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60) -_RetType = TypeVar("_RetType") +# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 +IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) + +_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse) _CharsetResolver = Callable[[ClientResponse, bytes], str] @@ -237,6 +288,21 @@ class ClientSession: # We initialise _connector to None immediately, as it's referenced in __del__() # and could cause issues if an exception occurs during initialisation. self._connector: Optional[BaseConnector] = None + + if loop is None: + if connector is not None: + loop = connector._loop + + loop = loop or asyncio.get_event_loop() + + if base_url is None or isinstance(base_url, URL): + self._base_url: Optional[URL] = base_url + else: + self._base_url = URL(base_url) + assert ( + self._base_url.origin() == self._base_url + ), "Only absolute URLs without path part are supported" + if timeout is sentinel or timeout is None: self._timeout = DEFAULT_TIMEOUT if read_timeout is not sentinel: @@ -272,19 +338,6 @@ class ClientSession: "conflict, please setup " "timeout.connect" ) - if loop is None: - if connector is not None: - loop = connector._loop - - loop = get_running_loop(loop) - - if base_url is None or isinstance(base_url, URL): - self._base_url: Optional[URL] = base_url - else: - self._base_url = URL(base_url) - assert ( - self._base_url.origin() == self._base_url - ), "Only absolute URLs without path part are supported" if connector is None: connector = TCPConnector(loop=loop) @@ -369,11 +422,22 @@ class ClientSession: context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) - def request( - self, method: str, url: StrOrURL, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP request.""" - return _RequestContextManager(self._request(method, url, **kwargs)) + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + def request( + self, + method: str, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + else: + + def request( + self, method: str, url: StrOrURL, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP request.""" + return _RequestContextManager(self._request(method, url, **kwargs)) def _build_url(self, str_or_url: StrOrURL) -> URL: url = URL(str_or_url) @@ -388,7 +452,7 @@ class ClientSession: method: str, str_or_url: StrOrURL, *, - params: Optional[Mapping[str, str]] = None, + params: Query = None, data: Any = None, json: Any = None, cookies: Optional[LooseCookies] = None, @@ -397,7 +461,7 @@ class ClientSession: auth: Optional[BasicAuth] = None, allow_redirects: bool = True, max_redirects: int = 10, - compress: Optional[str] = None, + compress: Union[str, bool, None] = None, chunked: Optional[bool] = None, expect100: bool = False, raise_for_status: Union[ @@ -413,7 +477,7 @@ class ClientSession: ssl: Union[SSLContext, bool, Fingerprint] = True, server_hostname: Optional[str] = None, proxy_headers: Optional[LooseHeaders] = None, - trace_request_ctx: Optional[SimpleNamespace] = None, + trace_request_ctx: Optional[Mapping[str, str]] = None, read_bufsize: Optional[int] = None, auto_decompress: Optional[bool] = None, max_line_size: Optional[int] = None, @@ -451,7 +515,11 @@ class ClientSession: try: url = self._build_url(str_or_url) except ValueError as e: - raise InvalidURL(str_or_url) from e + raise InvalidUrlClientError(str_or_url) from e + + assert self._connector is not None + if url.scheme not in self._connector.allowed_protocol_schema_set: + raise NonHttpUrlClientError(url) skip_headers = set(self._skip_auto_headers) if skip_auto_headers is not None: @@ -505,8 +573,19 @@ class ClientSession: timer = tm.timer() try: with timer: + # https://www.rfc-editor.org/rfc/rfc9112.html#name-retrying-requests + retry_persistent_connection = False #method in IDEMPOTENT_METHODS while True: url, auth_from_url = strip_auth_from_url(url) + if not url.raw_host: + # NOTE: Bail early, otherwise, causes `InvalidURL` through + # NOTE: `self._request_class()` below. + err_exc_cls = ( + InvalidUrlRedirectClientError + if redirects + else InvalidUrlClientError + ) + raise err_exc_cls(url) if auth and auth_from_url: raise ValueError( "Cannot combine AUTH argument with " @@ -550,7 +629,7 @@ class ClientSession: url, params=params, headers=headers, - skip_auto_headers=skip_headers, + skip_auto_headers=skip_headers if skip_headers else None, data=data, cookies=all_cookies, auth=auth, @@ -577,13 +656,12 @@ class ClientSession: real_timeout.connect, ceil_threshold=real_timeout.ceil_threshold, ): - assert self._connector is not None conn = await self._connector.connect( req, traces=traces, timeout=real_timeout ) except asyncio.TimeoutError as exc: - raise ServerTimeoutError( - "Connection timeout " "to host {}".format(url) + raise ConnectionTimeoutError( + f"Connection timeout to host {url}" ) from exc assert conn.transport is not None @@ -612,6 +690,11 @@ class ClientSession: except BaseException: conn.close() raise + except (ClientOSError, ServerDisconnectedError): + if retry_persistent_connection: + retry_persistent_connection = False + continue + raise except ClientError: raise except OSError as exc: @@ -659,25 +742,35 @@ class ClientSession: resp.release() try: - parsed_url = URL( + parsed_redirect_url = URL( r_url, encoded=not self._requote_redirect_url ) - except ValueError as e: - raise InvalidURL(r_url) from e + raise InvalidUrlRedirectClientError( + r_url, + "Server attempted redirecting to a location that does not look like a URL", + ) from e - scheme = parsed_url.scheme - if scheme not in ("http", "https", ""): + scheme = parsed_redirect_url.scheme + if scheme not in HTTP_AND_EMPTY_SCHEMA_SET: resp.close() - raise ValueError("Can redirect only to http or https") + raise NonHttpUrlRedirectClientError(r_url) elif not scheme: - parsed_url = url.join(parsed_url) + parsed_redirect_url = url.join(parsed_redirect_url) + + try: + redirect_origin = parsed_redirect_url.origin() + except ValueError as origin_val_err: + raise InvalidUrlRedirectClientError( + parsed_redirect_url, + "Invalid redirect URL origin", + ) from origin_val_err - if url.origin() != parsed_url.origin(): + if url.origin() != redirect_origin: auth = None headers.pop(hdrs.AUTHORIZATION, None) - url = parsed_url + url = parsed_redirect_url params = {} resp.release() continue @@ -736,11 +829,11 @@ class ClientSession: heartbeat: Optional[float] = None, auth: Optional[BasicAuth] = None, origin: Optional[str] = None, - params: Optional[Mapping[str, str]] = None, + params: Query = None, headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, bool, None, Fingerprint] = True, + ssl: Union[SSLContext, bool, Fingerprint] = True, verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, @@ -788,11 +881,11 @@ class ClientSession: heartbeat: Optional[float] = None, auth: Optional[BasicAuth] = None, origin: Optional[str] = None, - params: Optional[Mapping[str, str]] = None, + params: Query = None, headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Optional[Union[SSLContext, bool, Fingerprint]] = True, + ssl: Union[SSLContext, bool, Fingerprint] = True, verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, @@ -828,6 +921,11 @@ class ClientSession: # For the sake of backward compatibility, if user passes in None, convert it to True if ssl is None: + warnings.warn( + "ssl=None is deprecated, please use ssl=True", + DeprecationWarning, + stacklevel=2, + ) ssl = True ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) @@ -922,6 +1020,16 @@ class ClientSession: assert conn is not None conn_proto = conn.protocol assert conn_proto is not None + + # For WS connection the read_timeout must be either receive_timeout or greater + # None == no timeout, i.e. infinite timeout, so None is the max timeout possible + if receive_timeout is None: + # Reset regardless + conn_proto.read_timeout = receive_timeout + elif conn_proto.read_timeout is not None: + # If read_timeout was set check which wins + conn_proto.read_timeout = max(receive_timeout, conn_proto.read_timeout) + transport = conn.transport assert transport is not None reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue( @@ -970,61 +1078,111 @@ class ClientSession: added_names.add(key) return result - def get( - self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP GET request.""" - return _RequestContextManager( - self._request(hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs) - ) + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + def get( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def options( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def head( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def post( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... - def options( - self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP OPTIONS request.""" - return _RequestContextManager( - self._request( - hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs + def put( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def patch( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def delete( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + else: + + def get( + self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP GET request.""" + return _RequestContextManager( + self._request( + hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs + ) ) - ) - def head( - self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP HEAD request.""" - return _RequestContextManager( - self._request( - hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs + def options( + self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP OPTIONS request.""" + return _RequestContextManager( + self._request( + hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs + ) ) - ) - def post( - self, url: StrOrURL, *, data: Any = None, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP POST request.""" - return _RequestContextManager( - self._request(hdrs.METH_POST, url, data=data, **kwargs) - ) + def head( + self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP HEAD request.""" + return _RequestContextManager( + self._request( + hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs + ) + ) - def put( - self, url: StrOrURL, *, data: Any = None, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP PUT request.""" - return _RequestContextManager( - self._request(hdrs.METH_PUT, url, data=data, **kwargs) - ) + def post( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP POST request.""" + return _RequestContextManager( + self._request(hdrs.METH_POST, url, data=data, **kwargs) + ) - def patch( - self, url: StrOrURL, *, data: Any = None, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP PATCH request.""" - return _RequestContextManager( - self._request(hdrs.METH_PATCH, url, data=data, **kwargs) - ) + def put( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP PUT request.""" + return _RequestContextManager( + self._request(hdrs.METH_PUT, url, data=data, **kwargs) + ) - def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager": - """Perform HTTP DELETE request.""" - return _RequestContextManager(self._request(hdrs.METH_DELETE, url, **kwargs)) + def patch( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP PATCH request.""" + return _RequestContextManager( + self._request(hdrs.METH_PATCH, url, data=data, **kwargs) + ) + + def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager": + """Perform HTTP DELETE request.""" + return _RequestContextManager( + self._request(hdrs.METH_DELETE, url, **kwargs) + ) async def close(self) -> None: """Close underlying connector. @@ -1175,7 +1333,7 @@ class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType __slots__ = ("_coro", "_resp") def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None: - self._coro = coro + self._coro: Coroutine["asyncio.Future[Any]", None, _RetType] = coro def send(self, arg: None) -> "asyncio.Future[Any]": return self._coro.send(arg) @@ -1194,12 +1352,8 @@ class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType return self.__await__() async def __aenter__(self) -> _RetType: - self._resp = await self._coro - return self._resp - - -class _RequestContextManager(_BaseRequestContextManager[ClientResponse]): - __slots__ = () + self._resp: _RetType = await self._coro + return await self._resp.__aenter__() async def __aexit__( self, @@ -1207,25 +1361,11 @@ class _RequestContextManager(_BaseRequestContextManager[ClientResponse]): exc: Optional[BaseException], tb: Optional[TracebackType], ) -> None: - # We're basing behavior on the exception as it can be caused by - # user code unrelated to the status of the connection. If you - # would like to close a connection you must do that - # explicitly. Otherwise connection error handling should kick in - # and close/recycle the connection as required. - self._resp.release() - await self._resp.wait_for_close() - + await self._resp.__aexit__(exc_type, exc, tb) -class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]): - __slots__ = () - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], - ) -> None: - await self._resp.close() +_RequestContextManager = _BaseRequestContextManager[ClientResponse] +_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse] class _SessionRequestContextManager: @@ -1265,7 +1405,7 @@ def request( method: str, url: StrOrURL, *, - params: Optional[Mapping[str, str]] = None, + params: Query = None, data: Any = None, json: Any = None, headers: Optional[LooseHeaders] = None, diff --git a/contrib/python/aiohttp/aiohttp/client_exceptions.py b/contrib/python/aiohttp/aiohttp/client_exceptions.py index 9b6e44203c8..94991c42477 100644 --- a/contrib/python/aiohttp/aiohttp/client_exceptions.py +++ b/contrib/python/aiohttp/aiohttp/client_exceptions.py @@ -2,10 +2,11 @@ import asyncio import warnings -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union -from .http_parser import RawResponseMessage -from .typedefs import LooseHeaders +from multidict import MultiMapping + +from .typedefs import StrOrURL try: import ssl @@ -17,18 +18,22 @@ except ImportError: # pragma: no cover if TYPE_CHECKING: from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo + from .http_parser import RawResponseMessage else: - RequestInfo = ClientResponse = ConnectionKey = None + RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None __all__ = ( "ClientError", "ClientConnectionError", + "ClientConnectionResetError", "ClientOSError", "ClientConnectorError", "ClientProxyConnectionError", "ClientSSLError", "ClientConnectorSSLError", "ClientConnectorCertificateError", + "ConnectionTimeoutError", + "SocketTimeoutError", "ServerConnectionError", "ServerTimeoutError", "ServerDisconnectedError", @@ -39,6 +44,11 @@ __all__ = ( "ContentTypeError", "ClientPayloadError", "InvalidURL", + "InvalidUrlClientError", + "RedirectClientError", + "NonHttpUrlClientError", + "InvalidUrlRedirectClientError", + "NonHttpUrlRedirectClientError", ) @@ -64,7 +74,7 @@ class ClientResponseError(ClientError): code: Optional[int] = None, status: Optional[int] = None, message: str = "", - headers: Optional[LooseHeaders] = None, + headers: Optional[MultiMapping[str]] = None, ) -> None: self.request_info = request_info if code is not None: @@ -93,7 +103,7 @@ class ClientResponseError(ClientError): return "{}, message={!r}, url={!r}".format( self.status, self.message, - self.request_info.real_url, + str(self.request_info.real_url), ) def __repr__(self) -> str: @@ -150,6 +160,10 @@ class ClientConnectionError(ClientError): """Base class for client socket errors.""" +class ClientConnectionResetError(ClientConnectionError, ConnectionResetError): + """ConnectionResetError""" + + class ClientOSError(ClientConnectionError, OSError): """OSError error.""" @@ -242,6 +256,14 @@ class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError): """Server timeout error.""" +class ConnectionTimeoutError(ServerTimeoutError): + """Connection timeout error.""" + + +class SocketTimeoutError(ServerTimeoutError): + """Socket timeout error.""" + + class ServerFingerprintMismatch(ServerConnectionError): """SSL certificate does not match expected fingerprint.""" @@ -271,17 +293,52 @@ class InvalidURL(ClientError, ValueError): # Derive from ValueError for backward compatibility - def __init__(self, url: Any) -> None: + def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None: # The type of url is not yarl.URL because the exception can be raised # on URL(url) call - super().__init__(url) + self._url = url + self._description = description + + if description: + super().__init__(url, description) + else: + super().__init__(url) @property - def url(self) -> Any: - return self.args[0] + def url(self) -> StrOrURL: + return self._url + + @property + def description(self) -> "str | None": + return self._description def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.url}>" + return f"<{self.__class__.__name__} {self}>" + + def __str__(self) -> str: + if self._description: + return f"{self._url} - {self._description}" + return str(self._url) + + +class InvalidUrlClientError(InvalidURL): + """Invalid URL client error.""" + + +class RedirectClientError(ClientError): + """Client redirect error.""" + + +class NonHttpUrlClientError(ClientError): + """Non http URL client error.""" + + +class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError): + """Invalid URL redirect client error.""" + + +class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError): + """Non http URL redirect client error.""" class ClientSSLError(ClientConnectorError): diff --git a/contrib/python/aiohttp/aiohttp/client_proto.py b/contrib/python/aiohttp/aiohttp/client_proto.py index 723f5aae5f4..8055811e40d 100644 --- a/contrib/python/aiohttp/aiohttp/client_proto.py +++ b/contrib/python/aiohttp/aiohttp/client_proto.py @@ -7,7 +7,7 @@ from .client_exceptions import ( ClientOSError, ClientPayloadError, ServerDisconnectedError, - ServerTimeoutError, + SocketTimeoutError, ) from .helpers import ( _EXC_SENTINEL, @@ -50,15 +50,13 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe @property def should_close(self) -> bool: - if self._payload is not None and not self._payload.is_eof() or self._upgraded: - return True - return ( self._should_close + or (self._payload is not None and not self._payload.is_eof()) or self._upgraded - or self.exception() is not None + or self._exception is not None or self._payload_parser is not None - or len(self) > 0 + or bool(self._buffer) or bool(self._tail) ) @@ -224,8 +222,16 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe def start_timeout(self) -> None: self._reschedule_timeout() + @property + def read_timeout(self) -> Optional[float]: + return self._read_timeout + + @read_timeout.setter + def read_timeout(self, read_timeout: Optional[float]) -> None: + self._read_timeout = read_timeout + def _on_read_timeout(self) -> None: - exc = ServerTimeoutError("Timeout on reading data from socket") + exc = SocketTimeoutError("Timeout on reading data from socket") self.set_exception(exc) if self._payload is not None: set_exception(self._payload, exc) @@ -261,7 +267,15 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe # closed in this case self.transport.close() # should_close is True after the call - self.set_exception(HttpProcessingError(), underlying_exc) + if isinstance(underlying_exc, HttpProcessingError): + exc = HttpProcessingError( + code=underlying_exc.code, + message=underlying_exc.message, + headers=underlying_exc.headers, + ) + else: + exc = HttpProcessingError() + self.set_exception(exc, underlying_exc) return self._upgraded = upgraded diff --git a/contrib/python/aiohttp/aiohttp/client_reqrep.py b/contrib/python/aiohttp/aiohttp/client_reqrep.py index afe719da16e..aa8f54e67b8 100644 --- a/contrib/python/aiohttp/aiohttp/client_reqrep.py +++ b/contrib/python/aiohttp/aiohttp/client_reqrep.py @@ -27,7 +27,7 @@ from typing import ( import attr from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy -from yarl import URL +from yarl import URL, __version__ as yarl_version from . import hdrs, helpers, http, multipart, payload from .abc import AbstractStreamWriter @@ -67,6 +67,7 @@ from .typedefs import ( JSONDecoder, LooseCookies, LooseHeaders, + Query, RawHeaders, ) @@ -88,6 +89,7 @@ if TYPE_CHECKING: _CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]") +_YARL_SUPPORTS_EXTEND_QUERY = tuple(map(int, yarl_version.split(".")[:2])) >= (1, 11) json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json") @@ -209,7 +211,7 @@ def _merge_ssl_params( return ssl [email protected](auto_attribs=True, slots=True, frozen=True) [email protected](auto_attribs=True, slots=True, frozen=True, cache_hash=True) class ConnectionKey: # the key should contain an information about used proxy / TLS # to prevent reusing wrong connections from a pool @@ -245,7 +247,8 @@ class ClientRequest: hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(), } - body = b"" + # Type of body depends on PAYLOAD_REGISTRY, which is dynamic. + body: Any = b"" auth = None response = None @@ -262,14 +265,14 @@ class ClientRequest: method: str, url: URL, *, - params: Optional[Mapping[str, str]] = None, + params: Query = None, headers: Optional[LooseHeaders] = None, - skip_auto_headers: Iterable[str] = frozenset(), + skip_auto_headers: Optional[Iterable[str]] = None, data: Any = None, cookies: Optional[LooseCookies] = None, auth: Optional[BasicAuth] = None, version: http.HttpVersion = http.HttpVersion11, - compress: Optional[str] = None, + compress: Union[str, bool, None] = None, chunked: Optional[bool] = None, expect100: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, @@ -300,10 +303,13 @@ class ClientRequest: # assert session is not None self._session = cast("ClientSession", session) if params: - q = MultiDict(url.query) - url2 = url.with_query(params) - q.extend(url2.query) - url = url.with_query(q) + if _YARL_SUPPORTS_EXTEND_QUERY: + url = url.extend_query(params) + else: + q = MultiDict(url.query) + url2 = url.with_query(params) + q.extend(url2.query) + url = url.with_query(q) self.original_url = url self.url = url.with_fragment(None) self.method = method.upper() @@ -352,7 +358,12 @@ class ClientRequest: if self.__writer is not None: self.__writer.remove_done_callback(self.__reset_writer) self.__writer = writer - if writer is not None: + if writer is None: + return + if writer.done(): + # The writer is already done, so we can reset it immediately. + self.__reset_writer() + else: writer.add_done_callback(self.__reset_writer) def is_ssl(self) -> bool: @@ -402,8 +413,8 @@ class ClientRequest: # basic auth info username, password = url.user, url.password - if username: - self.auth = helpers.BasicAuth(username, password or "") + if username or password: + self.auth = helpers.BasicAuth(username or "", password or "") def update_version(self, version: Union[http.HttpVersion, str]) -> None: """Convert request version to two elements tuple. @@ -436,7 +447,7 @@ class ClientRequest: if headers: if isinstance(headers, (dict, MultiDictProxy, MultiDict)): - headers = headers.items() # type: ignore[assignment] + headers = headers.items() for key, value in headers: # type: ignore[misc] # A special case for Host header @@ -445,12 +456,18 @@ class ClientRequest: else: self.headers.add(key, value) - def update_auto_headers(self, skip_auto_headers: Iterable[str]) -> None: - self.skip_auto_headers = CIMultiDict( - (hdr, None) for hdr in sorted(skip_auto_headers) - ) - used_headers = self.headers.copy() - used_headers.extend(self.skip_auto_headers) # type: ignore[arg-type] + def update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> None: + if skip_auto_headers is not None: + self.skip_auto_headers = CIMultiDict( + (hdr, None) for hdr in sorted(skip_auto_headers) + ) + used_headers = self.headers.copy() + used_headers.extend(self.skip_auto_headers) # type: ignore[arg-type] + else: + # Fast path when there are no headers to skip + # which is the most common case. + self.skip_auto_headers = CIMultiDict() + used_headers = self.headers for hdr, val in self.DEFAULT_HEADERS.items(): if hdr not in used_headers: @@ -486,11 +503,12 @@ class ClientRequest: def update_content_encoding(self, data: Any) -> None: """Set request content encoding.""" - if data is None: + if not data: + # Don't compress an empty body. + self.compress = None return - enc = self.headers.get(hdrs.CONTENT_ENCODING, "").lower() - if enc: + if self.headers.get(hdrs.CONTENT_ENCODING): if self.compress: raise ValueError( "compress can not be set " "if Content-Encoding header is set" @@ -566,10 +584,8 @@ class ClientRequest: # copy payload headers assert body.headers - for (key, value) in body.headers.items(): - if key in self.headers: - continue - if key in self.skip_auto_headers: + for key, value in body.headers.items(): + if key in self.headers or key in self.skip_auto_headers: continue self.headers[key] = value @@ -592,6 +608,10 @@ class ClientRequest: raise ValueError("proxy_auth must be None or BasicAuth() tuple") self.proxy = proxy self.proxy_auth = proxy_auth + if proxy_headers is not None and not isinstance( + proxy_headers, (MultiDict, MultiDictProxy) + ): + proxy_headers = CIMultiDict(proxy_headers) self.proxy_headers = proxy_headers def keep_alive(self) -> bool: @@ -614,11 +634,8 @@ class ClientRequest: """Support coroutines that yields bytes objects.""" # 100 response if self._continue is not None: - try: - await writer.drain() - await self._continue - except asyncio.CancelledError: - return + await writer.drain() + await self._continue protocol = conn.protocol assert protocol is not None @@ -627,10 +644,10 @@ class ClientRequest: await self.body.write(writer) else: if isinstance(self.body, (bytes, bytearray)): - self.body = (self.body,) # type: ignore[assignment] + self.body = (self.body,) for chunk in self.body: - await writer.write(chunk) # type: ignore[arg-type] + await writer.write(chunk) except OSError as underlying_exc: reraised_exc = underlying_exc @@ -645,7 +662,9 @@ class ClientRequest: set_exception(protocol, reraised_exc, underlying_exc) except asyncio.CancelledError: - await writer.write_eof() + # Body hasn't been fully sent, so connection can't be reused. + conn.close() + raise except Exception as underlying_exc: set_exception( protocol, @@ -681,16 +700,20 @@ class ClientRequest: writer = StreamWriter( protocol, self.loop, - on_chunk_sent=functools.partial( - self._on_chunk_request_sent, self.method, self.url + on_chunk_sent=( + functools.partial(self._on_chunk_request_sent, self.method, self.url) + if self._traces + else None ), - on_headers_sent=functools.partial( - self._on_headers_request_sent, self.method, self.url + on_headers_sent=( + functools.partial(self._on_headers_request_sent, self.method, self.url) + if self._traces + else None ), ) if self.compress: - writer.enable_compression(self.compress) + writer.enable_compression(self.compress) # type: ignore[arg-type] if self.chunked is not None: writer.enable_chunking() @@ -717,13 +740,20 @@ class ClientRequest: self.headers[hdrs.CONNECTION] = connection # status + headers - status_line = "{0} {1} HTTP/{v.major}.{v.minor}".format( - self.method, path, v=self.version - ) + v = self.version + status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}" await writer.write_headers(status_line, self.headers) + coro = self.write_bytes(writer, conn) - self._writer = self.loop.create_task(self.write_bytes(writer, conn)) + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to write + # bytes immediately to avoid having to schedule + # the task on the event loop. + task = asyncio.Task(coro, loop=self.loop, eager_start=True) + else: + task = self.loop.create_task(coro) + self._writer = task response_class = self.response_class assert response_class is not None self.response = response_class( @@ -741,8 +771,15 @@ class ClientRequest: async def close(self) -> None: if self._writer is not None: - with contextlib.suppress(asyncio.CancelledError): + try: await self._writer + except asyncio.CancelledError: + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise def terminate(self) -> None: if self._writer is not None: @@ -782,6 +819,7 @@ class ClientResponse(HeadersMixin): # post-init stage allows to not change ctor signature _closed = True # to allow __del__ for non-initialized properly response _released = False + _in_context = False __writer = None def __init__( @@ -820,9 +858,9 @@ class ClientResponse(HeadersMixin): # work after the response has finished reading the body. if session is None: # TODO: Fix session=None in tests (see ClientRequest.__init__). - self._resolve_charset: Callable[ - ["ClientResponse", bytes], str - ] = lambda *_: "utf-8" + self._resolve_charset: Callable[["ClientResponse", bytes], str] = ( + lambda *_: "utf-8" + ) else: self._resolve_charset = session._resolve_charset if loop.get_debug(): @@ -840,7 +878,12 @@ class ClientResponse(HeadersMixin): if self.__writer is not None: self.__writer.remove_done_callback(self.__reset_writer) self.__writer = writer - if writer is not None: + if writer is None: + return + if writer.done(): + # The writer is already done, so we can reset it immediately. + self.__reset_writer() + else: writer.add_done_callback(self.__reset_writer) @reify @@ -1066,7 +1109,12 @@ class ClientResponse(HeadersMixin): if not self.ok: # reason should always be not None for a started response assert self.reason is not None - self.release() + + # If we're in a context we can rely on __aexit__() to release as the + # exception propagates. + if not self._in_context: + self.release() + raise ClientResponseError( self.request_info, self.history, @@ -1085,7 +1133,15 @@ class ClientResponse(HeadersMixin): async def _wait_released(self) -> None: if self._writer is not None: - await self._writer + try: + await self._writer + except asyncio.CancelledError: + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise self._release_connection() def _cleanup_writer(self) -> None: @@ -1101,7 +1157,15 @@ class ClientResponse(HeadersMixin): async def wait_for_close(self) -> None: if self._writer is not None: - await self._writer + try: + await self._writer + except asyncio.CancelledError: + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise self.release() async def read(self) -> bytes: @@ -1130,7 +1194,7 @@ class ClientResponse(HeadersMixin): encoding = mimetype.parameters.get("charset") if encoding: - with contextlib.suppress(LookupError): + with contextlib.suppress(LookupError, ValueError): return codecs.lookup(encoding).name if mimetype.type == "application" and ( @@ -1176,6 +1240,7 @@ class ClientResponse(HeadersMixin): raise ContentTypeError( self.request_info, self.history, + status=self.status, message=( "Attempt to decode JSON with " "unexpected mimetype: %s" % ctype ), @@ -1192,6 +1257,7 @@ class ClientResponse(HeadersMixin): return loads(stripped.decode(encoding)) async def __aenter__(self) -> "ClientResponse": + self._in_context = True return self async def __aexit__( @@ -1200,6 +1266,7 @@ class ClientResponse(HeadersMixin): exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: + self._in_context = False # similar to _RequestContextManager, we do not need to check # for exceptions, response object can close connection # if state is broken diff --git a/contrib/python/aiohttp/aiohttp/client_ws.py b/contrib/python/aiohttp/aiohttp/client_ws.py index d9c74a30f52..c6b5da5103b 100644 --- a/contrib/python/aiohttp/aiohttp/client_ws.py +++ b/contrib/python/aiohttp/aiohttp/client_ws.py @@ -2,11 +2,12 @@ import asyncio import sys -from typing import Any, Optional, cast +from types import TracebackType +from typing import Any, Optional, Type, cast -from .client_exceptions import ClientError +from .client_exceptions import ClientError, ServerTimeoutError from .client_reqrep import ClientResponse -from .helpers import call_later, set_result +from .helpers import calculate_timeout_when, set_result from .http import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, @@ -62,63 +63,123 @@ class ClientWebSocketResponse: self._autoping = autoping self._heartbeat = heartbeat self._heartbeat_cb: Optional[asyncio.TimerHandle] = None + self._heartbeat_when: float = 0.0 if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb: Optional[asyncio.TimerHandle] = None self._loop = loop - self._waiting: Optional[asyncio.Future[bool]] = None + self._waiting: bool = False + self._close_wait: Optional[asyncio.Future[None]] = None self._exception: Optional[BaseException] = None self._compress = compress self._client_notakeover = client_notakeover + self._ping_task: Optional[asyncio.Task[None]] = None self._reset_heartbeat() def _cancel_heartbeat(self) -> None: - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = None - + self._cancel_pong_response_cb() if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None + if self._ping_task is not None: + self._ping_task.cancel() + self._ping_task = None + + def _cancel_pong_response_cb(self) -> None: + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None def _reset_heartbeat(self) -> None: - self._cancel_heartbeat() + if self._heartbeat is None: + return + self._cancel_pong_response_cb() + loop = self._loop + assert loop is not None + conn = self._conn + timeout_ceil_threshold = ( + conn._connector._timeout_ceil_threshold if conn is not None else 5 + ) + now = loop.time() + when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) + self._heartbeat_when = when + if self._heartbeat_cb is None: + # We do not cancel the previous heartbeat_cb here because + # it generates a significant amount of TimerHandle churn + # which causes asyncio to rebuild the heap frequently. + # Instead _send_heartbeat() will reschedule the next + # heartbeat if it fires too early. + self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) - if self._heartbeat is not None: - self._heartbeat_cb = call_later( - self._send_heartbeat, - self._heartbeat, - self._loop, - timeout_ceil_threshold=self._conn._connector._timeout_ceil_threshold - if self._conn is not None - else 5, + def _send_heartbeat(self) -> None: + self._heartbeat_cb = None + loop = self._loop + now = loop.time() + if now < self._heartbeat_when: + # Heartbeat fired too early, reschedule + self._heartbeat_cb = loop.call_at( + self._heartbeat_when, self._send_heartbeat ) + return - def _send_heartbeat(self) -> None: - if self._heartbeat is not None and not self._closed: - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. - self._loop.create_task(self._writer.ping()) + conn = self._conn + timeout_ceil_threshold = ( + conn._connector._timeout_ceil_threshold if conn is not None else 5 + ) + when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) + self._cancel_pong_response_cb() + self._pong_response_cb = loop.call_at(when, self._pong_not_received) - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = call_later( - self._pong_not_received, - self._pong_heartbeat, - self._loop, - timeout_ceil_threshold=self._conn._connector._timeout_ceil_threshold - if self._conn is not None - else 5, - ) + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to send the ping + # immediately to avoid having to schedule + # the task on the event loop. + ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True) + else: + ping_task = loop.create_task(self._writer.ping()) + + if not ping_task.done(): + self._ping_task = ping_task + ping_task.add_done_callback(self._ping_task_done) + else: + self._ping_task_done(ping_task) + + def _ping_task_done(self, task: "asyncio.Task[None]") -> None: + """Callback for when the ping task completes.""" + if not task.cancelled() and (exc := task.exception()): + self._handle_ping_pong_exception(exc) + self._ping_task = None def _pong_not_received(self) -> None: - if not self._closed: - self._closed = True - self._close_code = WSCloseCode.ABNORMAL_CLOSURE - self._exception = asyncio.TimeoutError() - self._response.close() + self._handle_ping_pong_exception(ServerTimeoutError()) + + def _handle_ping_pong_exception(self, exc: BaseException) -> None: + """Handle exceptions raised during ping/pong processing.""" + if self._closed: + return + self._set_closed() + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + self._exception = exc + self._response.close() + if self._waiting and not self._closing: + self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None)) + + def _set_closed(self) -> None: + """Set the connection to closed. + + Cancel any heartbeat timers and set the closed flag. + """ + self._closed = True + self._cancel_heartbeat() + + def _set_closing(self) -> None: + """Set the connection to closing. + + Cancel any heartbeat timers and set the closing flag. + """ + self._closing = True + self._cancel_heartbeat() @property def closed(self) -> bool: @@ -181,14 +242,15 @@ class ClientWebSocketResponse: async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool: # we need to break `receive()` cycle first, # `close()` may be called from different task - if self._waiting is not None and not self._closing: - self._closing = True + if self._waiting and not self._closing: + assert self._loop is not None + self._close_wait = self._loop.create_future() + self._set_closing() self._reader.feed_data(WS_CLOSING_MESSAGE, 0) - await self._waiting + await self._close_wait if not self._closed: - self._cancel_heartbeat() - self._closed = True + self._set_closed() try: await self._writer.close(code, message) except asyncio.CancelledError: @@ -219,7 +281,7 @@ class ClientWebSocketResponse: self._response.close() return True - if msg.type == WSMsgType.CLOSE: + if msg.type is WSMsgType.CLOSE: self._close_code = msg.data self._response.close() return True @@ -227,8 +289,10 @@ class ClientWebSocketResponse: return False async def receive(self, timeout: Optional[float] = None) -> WSMessage: + receive_timeout = timeout or self._receive_timeout + while True: - if self._waiting is not None: + if self._waiting: raise RuntimeError("Concurrent call to receive() is not allowed") if self._closed: @@ -238,15 +302,22 @@ class ClientWebSocketResponse: return WS_CLOSED_MESSAGE try: - self._waiting = self._loop.create_future() + self._waiting = True try: - async with async_timeout.timeout(timeout or self._receive_timeout): + if receive_timeout: + # Entering the context manager and creating + # Timeout() object can take almost 50% of the + # run time in this loop so we avoid it if + # there is no read timeout. + async with async_timeout.timeout(receive_timeout): + msg = await self._reader.read() + else: msg = await self._reader.read() self._reset_heartbeat() finally: - waiter = self._waiting - self._waiting = None - set_result(waiter, True) + self._waiting = False + if self._close_wait: + set_result(self._close_wait, None) except (asyncio.CancelledError, asyncio.TimeoutError): self._close_code = WSCloseCode.ABNORMAL_CLOSURE raise @@ -255,7 +326,8 @@ class ClientWebSocketResponse: await self.close() return WSMessage(WSMsgType.CLOSED, None, None) except ClientError: - self._closed = True + # Likely ServerDisconnectedError when connection is lost + self._set_closed() self._close_code = WSCloseCode.ABNORMAL_CLOSURE return WS_CLOSED_MESSAGE except WebSocketError as exc: @@ -264,35 +336,35 @@ class ClientWebSocketResponse: return WSMessage(WSMsgType.ERROR, exc, None) except Exception as exc: self._exception = exc - self._closing = True + self._set_closing() self._close_code = WSCloseCode.ABNORMAL_CLOSURE await self.close() return WSMessage(WSMsgType.ERROR, exc, None) - if msg.type == WSMsgType.CLOSE: - self._closing = True + if msg.type is WSMsgType.CLOSE: + self._set_closing() self._close_code = msg.data if not self._closed and self._autoclose: await self.close() - elif msg.type == WSMsgType.CLOSING: - self._closing = True - elif msg.type == WSMsgType.PING and self._autoping: + elif msg.type is WSMsgType.CLOSING: + self._set_closing() + elif msg.type is WSMsgType.PING and self._autoping: await self.pong(msg.data) continue - elif msg.type == WSMsgType.PONG and self._autoping: + elif msg.type is WSMsgType.PONG and self._autoping: continue return msg async def receive_str(self, *, timeout: Optional[float] = None) -> str: msg = await self.receive(timeout) - if msg.type != WSMsgType.TEXT: + if msg.type is not WSMsgType.TEXT: raise TypeError(f"Received message {msg.type}:{msg.data!r} is not str") return cast(str, msg.data) async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: msg = await self.receive(timeout) - if msg.type != WSMsgType.BINARY: + if msg.type is not WSMsgType.BINARY: raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes") return cast(bytes, msg.data) @@ -313,3 +385,14 @@ class ClientWebSocketResponse: if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): raise StopAsyncIteration return msg + + async def __aenter__(self) -> "ClientWebSocketResponse": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() diff --git a/contrib/python/aiohttp/aiohttp/compression_utils.py b/contrib/python/aiohttp/aiohttp/compression_utils.py index 9631d377e9a..ab4a2f1cc84 100644 --- a/contrib/python/aiohttp/aiohttp/compression_utils.py +++ b/contrib/python/aiohttp/aiohttp/compression_utils.py @@ -50,9 +50,11 @@ class ZLibCompressor(ZlibBaseHandler): max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, ): super().__init__( - mode=encoding_to_mode(encoding, suppress_deflate_header) - if wbits is None - else wbits, + mode=( + encoding_to_mode(encoding, suppress_deflate_header) + if wbits is None + else wbits + ), executor=executor, max_sync_chunk_size=max_sync_chunk_size, ) diff --git a/contrib/python/aiohttp/aiohttp/connector.py b/contrib/python/aiohttp/aiohttp/connector.py index f95ebe84c66..8288a0115b7 100644 --- a/contrib/python/aiohttp/aiohttp/connector.py +++ b/contrib/python/aiohttp/aiohttp/connector.py @@ -1,6 +1,7 @@ import asyncio import functools import random +import socket import sys import traceback import warnings @@ -22,6 +23,7 @@ from typing import ( List, Literal, Optional, + Sequence, Set, Tuple, Type, @@ -29,10 +31,11 @@ from typing import ( cast, ) +import aiohappyeyeballs import attr from . import hdrs, helpers -from .abc import AbstractResolver +from .abc import AbstractResolver, ResolveResult from .client_exceptions import ( ClientConnectionError, ClientConnectorCertificateError, @@ -47,7 +50,7 @@ from .client_exceptions import ( ) from .client_proto import ResponseHandler from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params -from .helpers import ceil_timeout, get_running_loop, is_ip_address, noop, sentinel +from .helpers import ceil_timeout, is_ip_address, noop, sentinel from .locks import EventResultOrError from .resolver import DefaultResolver @@ -60,6 +63,14 @@ except ImportError: # pragma: no cover SSLContext = object # type: ignore[misc,assignment] +EMPTY_SCHEMA_SET = frozenset({""}) +HTTP_SCHEMA_SET = frozenset({"http", "https"}) +WS_SCHEMA_SET = frozenset({"ws", "wss"}) + +HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET +HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET + + __all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector") @@ -208,6 +219,8 @@ class BaseConnector: # abort transport after 2 seconds (cleanup broken connections) _cleanup_closed_period = 2.0 + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET + def __init__( self, *, @@ -229,7 +242,7 @@ class BaseConnector: if keepalive_timeout is sentinel: keepalive_timeout = 15.0 - loop = get_running_loop(loop) + loop = loop or asyncio.get_event_loop() self._timeout_ceil_threshold = timeout_ceil_threshold self._closed = False @@ -240,9 +253,9 @@ class BaseConnector: self._limit = limit self._limit_per_host = limit_per_host self._acquired: Set[ResponseHandler] = set() - self._acquired_per_host: DefaultDict[ - ConnectionKey, Set[ResponseHandler] - ] = defaultdict(set) + self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = ( + defaultdict(set) + ) self._keepalive_timeout = cast(float, keepalive_timeout) self._force_close = force_close @@ -691,14 +704,14 @@ class BaseConnector: class _DNSCacheTable: def __init__(self, ttl: Optional[float] = None) -> None: - self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]] = {} + self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {} self._timestamps: Dict[Tuple[str, int], float] = {} self._ttl = ttl def __contains__(self, host: object) -> bool: return host in self._addrs_rr - def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None: + def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None: self._addrs_rr[key] = (cycle(addrs), len(addrs)) if self._ttl is not None: @@ -714,7 +727,7 @@ class _DNSCacheTable: self._addrs_rr.clear() self._timestamps.clear() - def next_addrs(self, key: Tuple[str, int]) -> List[Dict[str, Any]]: + def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]: loop, length = self._addrs_rr[key] addrs = list(islice(loop, length)) # Consume one more element to shift internal state of `cycle` @@ -728,6 +741,35 @@ class _DNSCacheTable: return self._timestamps[key] + self._ttl < monotonic() +def _make_ssl_context(verified: bool) -> SSLContext: + """Create SSL context. + + This method is not async-friendly and should be called from a thread + because it will load certificates from disk and do other blocking I/O. + """ + if ssl is None: + # No ssl support + return None + if verified: + return ssl.create_default_context() + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.check_hostname = False + sslcontext.verify_mode = ssl.CERT_NONE + sslcontext.options |= ssl.OP_NO_COMPRESSION + sslcontext.set_default_verify_paths() + return sslcontext + + +# The default SSLContext objects are created at import time +# since they do blocking I/O to load certificates from disk, +# and imports should always be done before the event loop starts +# or in a thread. +_SSL_CONTEXT_VERIFIED = _make_ssl_context(True) +_SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False) + + class TCPConnector(BaseConnector): """TCP connector. @@ -735,7 +777,7 @@ class TCPConnector(BaseConnector): fingerprint - Pass the binary sha256 digest of the expected certificate in DER format to verify that the certificate the server presents matches. See also - https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning + https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning resolver - Enable DNS lookups and use this resolver use_dns_cache - Use memory cache for DNS lookups. @@ -750,9 +792,15 @@ class TCPConnector(BaseConnector): limit_per_host - Number of simultaneous connections to one host. enable_cleanup_closed - Enables clean-up closed ssl transports. Disabled by default. + happy_eyeballs_delay - This is the “Connection Attempt Delay” + as defined in RFC 8305. To disable + the happy eyeballs algorithm, set to None. + interleave - “First Address Family Count” as defined in RFC 8305 loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) + def __init__( self, *, @@ -760,7 +808,7 @@ class TCPConnector(BaseConnector): fingerprint: Optional[bytes] = None, use_dns_cache: bool = True, ttl_dns_cache: Optional[int] = 10, - family: int = 0, + family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC, ssl_context: Optional[SSLContext] = None, ssl: Union[bool, Fingerprint, SSLContext] = True, local_addr: Optional[Tuple[str, int]] = None, @@ -772,6 +820,8 @@ class TCPConnector(BaseConnector): enable_cleanup_closed: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, timeout_ceil_threshold: float = 5, + happy_eyeballs_delay: Optional[float] = 0.25, + interleave: Optional[int] = None, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -792,13 +842,19 @@ class TCPConnector(BaseConnector): self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {} self._family = family - self._local_addr = local_addr + self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr) + self._happy_eyeballs_delay = happy_eyeballs_delay + self._interleave = interleave + self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() def close(self) -> Awaitable[None]: """Close all ongoing DNS calls.""" for ev in self._throttle_dns_events.values(): ev.cancel() + for t in self._resolve_host_tasks: + t.cancel() + return super().close() @property @@ -823,8 +879,8 @@ class TCPConnector(BaseConnector): self._cached_hosts.clear() async def _resolve_host( - self, host: str, port: int, traces: Optional[List["Trace"]] = None - ) -> List[Dict[str, Any]]: + self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None + ) -> List[ResolveResult]: """Resolve host and return list of addresses.""" if is_ip_address(host): return [ @@ -876,11 +932,13 @@ class TCPConnector(BaseConnector): resolved_host_task = asyncio.create_task( self._resolve_host_with_throttle(key, host, port, traces) ) + self._resolve_host_tasks.add(resolved_host_task) + resolved_host_task.add_done_callback(self._resolve_host_tasks.discard) try: return await asyncio.shield(resolved_host_task) except asyncio.CancelledError: - def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: + def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None: with suppress(Exception, asyncio.CancelledError): fut.result() @@ -892,8 +950,8 @@ class TCPConnector(BaseConnector): key: Tuple[str, int], host: str, port: int, - traces: Optional[List["Trace"]], - ) -> List[Dict[str, Any]]: + traces: Optional[Sequence["Trace"]], + ) -> List[ResolveResult]: """Resolve host with a dns events throttle.""" if key in self._throttle_dns_events: # get event early, before any await (#4014) @@ -945,29 +1003,6 @@ class TCPConnector(BaseConnector): return proto - @staticmethod - @functools.lru_cache(None) - def _make_ssl_context(verified: bool) -> SSLContext: - if verified: - return ssl.create_default_context() - else: - sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.check_hostname = False - sslcontext.verify_mode = ssl.CERT_NONE - try: - sslcontext.options |= ssl.OP_NO_COMPRESSION - except AttributeError as attr_err: - warnings.warn( - "{!s}: The Python interpreter is compiled " - "against OpenSSL < 1.0.0. Ref: " - "https://docs.python.org/3/library/ssl.html" - "#ssl.OP_NO_COMPRESSION".format(attr_err), - ) - sslcontext.set_default_verify_paths() - return sslcontext - def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: """Logic to get the correct SSL context @@ -982,25 +1017,25 @@ class TCPConnector(BaseConnector): 3. if verify_ssl is False in req, generate a SSL context that won't verify """ - if req.is_ssl(): - if ssl is None: # pragma: no cover - raise RuntimeError("SSL is not supported.") - sslcontext = req.ssl - if isinstance(sslcontext, ssl.SSLContext): - return sslcontext - if sslcontext is not True: - # not verified or fingerprinted - return self._make_ssl_context(False) - sslcontext = self._ssl - if isinstance(sslcontext, ssl.SSLContext): - return sslcontext - if sslcontext is not True: - # not verified or fingerprinted - return self._make_ssl_context(False) - return self._make_ssl_context(True) - else: + if not req.is_ssl(): return None + if ssl is None: # pragma: no cover + raise RuntimeError("SSL is not supported.") + sslcontext = req.ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not True: + # not verified or fingerprinted + return _SSL_CONTEXT_UNVERIFIED + sslcontext = self._ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not True: + # not verified or fingerprinted + return _SSL_CONTEXT_UNVERIFIED + return _SSL_CONTEXT_VERIFIED + def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: ret = req.ssl if isinstance(ret, Fingerprint): @@ -1013,6 +1048,36 @@ class TCPConnector(BaseConnector): async def _wrap_create_connection( self, *args: Any, + addr_infos: List[aiohappyeyeballs.AddrInfoType], + req: ClientRequest, + timeout: "ClientTimeout", + client_error: Type[Exception] = ClientConnectorError, + **kwargs: Any, + ) -> Tuple[asyncio.Transport, ResponseHandler]: + try: + async with ceil_timeout( + timeout.sock_connect, ceil_threshold=timeout.ceil_threshold + ): + sock = await aiohappyeyeballs.start_connection( + addr_infos=addr_infos, + local_addr_infos=self._local_addr_infos, + happy_eyeballs_delay=self._happy_eyeballs_delay, + interleave=self._interleave, + loop=self._loop, + ) + return await self._loop.create_connection(*args, **kwargs, sock=sock) + except cert_errors as exc: + raise ClientConnectorCertificateError(req.connection_key, exc) from exc + except ssl_errors as exc: + raise ClientConnectorSSLError(req.connection_key, exc) from exc + except OSError as exc: + if exc.errno is None and isinstance(exc, asyncio.TimeoutError): + raise + raise client_error(req.connection_key, exc) from exc + + async def _wrap_existing_connection( + self, + *args: Any, req: ClientRequest, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, @@ -1121,13 +1186,11 @@ class TCPConnector(BaseConnector): ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: """Wrap the raw TCP transport with TLS.""" tls_proto = self._factory() # Create a brand new proto for TLS - - # Safety of the `cast()` call here is based on the fact that - # internally `_get_ssl_context()` only returns `None` when - # `req.is_ssl()` evaluates to `False` which is never gonna happen - # in this code path. Of course, it's rather fragile - # maintainability-wise but this is to be solved separately. - sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req)) + sslcontext = self._get_ssl_context(req) + if TYPE_CHECKING: + # _start_tls_connection is unreachable in the current code path + # if sslcontext is None. + assert sslcontext is not None try: async with ceil_timeout( @@ -1176,6 +1239,27 @@ class TCPConnector(BaseConnector): return tls_transport, tls_proto + def _convert_hosts_to_addr_infos( + self, hosts: List[ResolveResult] + ) -> List[aiohappyeyeballs.AddrInfoType]: + """Converts the list of hosts to a list of addr_infos. + + The list of hosts is the result of a DNS lookup. The list of + addr_infos is the result of a call to `socket.getaddrinfo()`. + """ + addr_infos: List[aiohappyeyeballs.AddrInfoType] = [] + for hinfo in hosts: + host = hinfo["host"] + is_ipv6 = ":" in host + family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + if self._family and self._family != family: + continue + addr = (host, hinfo["port"], 0, 0) if is_ipv6 else (host, hinfo["port"]) + addr_infos.append( + (family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr) + ) + return addr_infos + async def _create_direct_connection( self, req: ClientRequest, @@ -1209,36 +1293,27 @@ class TCPConnector(BaseConnector): raise ClientConnectorError(req.connection_key, exc) from exc last_exc: Optional[Exception] = None - - for hinfo in hosts: - host = hinfo["host"] - port = hinfo["port"] - + addr_infos = self._convert_hosts_to_addr_infos(hosts) + while addr_infos: # Strip trailing dots, certificates contain FQDN without dots. # See https://github.com/aio-libs/aiohttp/issues/3636 server_hostname = ( - (req.server_hostname or hinfo["hostname"]).rstrip(".") - if sslcontext - else None + (req.server_hostname or host).rstrip(".") if sslcontext else None ) try: transp, proto = await self._wrap_create_connection( self._factory, - host, - port, timeout=timeout, ssl=sslcontext, - family=hinfo["family"], - proto=hinfo["proto"], - flags=hinfo["flags"], + addr_infos=addr_infos, server_hostname=server_hostname, - local_addr=self._local_addr, req=req, client_error=client_error, ) except ClientConnectorError as exc: last_exc = exc + aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave) continue if req.is_ssl() and fingerprint: @@ -1249,6 +1324,10 @@ class TCPConnector(BaseConnector): if not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transp) last_exc = exc + # Remove the bad peer from the list of addr_infos + sock: socket.socket = transp.get_extra_info("socket") + bad_peer = sock.getpeername() + aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer) continue return transp, proto @@ -1367,7 +1446,7 @@ class TCPConnector(BaseConnector): if not runtime_has_start_tls: # HTTP proxy with support for upgrade to HTTPS sslcontext = self._get_ssl_context(req) - return await self._wrap_create_connection( + return await self._wrap_existing_connection( self._factory, timeout=timeout, ssl=sslcontext, @@ -1401,6 +1480,8 @@ class UnixConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"}) + def __init__( self, path: str, @@ -1457,6 +1538,8 @@ class NamedPipeConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"}) + def __init__( self, path: str, diff --git a/contrib/python/aiohttp/aiohttp/cookiejar.py b/contrib/python/aiohttp/aiohttp/cookiejar.py index a348f112cb5..c78d5fa7e72 100644 --- a/contrib/python/aiohttp/aiohttp/cookiejar.py +++ b/contrib/python/aiohttp/aiohttp/cookiejar.py @@ -2,6 +2,8 @@ import asyncio import calendar import contextlib import datetime +import heapq +import itertools import os # noqa import pathlib import pickle @@ -9,8 +11,7 @@ import re import time from collections import defaultdict from http.cookies import BaseCookie, Morsel, SimpleCookie -from math import ceil -from typing import ( # noqa +from typing import ( DefaultDict, Dict, Iterable, @@ -35,6 +36,15 @@ __all__ = ("CookieJar", "DummyCookieJar") CookieItem = Union[str, "Morsel[str]"] +# We cache these string methods here as their use is in performance critical code. +_FORMAT_PATH = "{}/{}".format +_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format + +# The minimum number of scheduled cookie expirations before we start cleaning up +# the expiration heap. This is a performance optimization to avoid cleaning up the +# heap too often when there are only a few scheduled expirations. +_MIN_SCHEDULED_COOKIE_EXPIRATION = 100 + class CookieJar(AbstractCookieJar): """Implements cookie storage adhering to RFC 6265.""" @@ -85,6 +95,9 @@ class CookieJar(AbstractCookieJar): self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict( SimpleCookie ) + self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = ( + defaultdict(dict) + ) self._host_only_cookies: Set[Tuple[str, str]] = set() self._unsafe = unsafe self._quote_cookie = quote_cookie @@ -100,7 +113,7 @@ class CookieJar(AbstractCookieJar): for url in treat_as_secure_origin ] self._treat_as_secure_origin = treat_as_secure_origin - self._next_expiration: float = ceil(time.time()) + self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = [] self._expirations: Dict[Tuple[str, str, str], float] = {} def save(self, file_path: PathLike) -> None: @@ -115,34 +128,26 @@ class CookieJar(AbstractCookieJar): def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: if predicate is None: - self._next_expiration = ceil(time.time()) + self._expire_heap.clear() self._cookies.clear() + self._morsel_cache.clear() self._host_only_cookies.clear() self._expirations.clear() return - to_del = [] now = time.time() - for (domain, path), cookie in self._cookies.items(): - for name, morsel in cookie.items(): - key = (domain, path, name) - if ( - key in self._expirations and self._expirations[key] <= now - ) or predicate(morsel): - to_del.append(key) - - for domain, path, name in to_del: - self._host_only_cookies.discard((domain, name)) - key = (domain, path, name) - if key in self._expirations: - del self._expirations[(domain, path, name)] - self._cookies[(domain, path)].pop(name, None) - - self._next_expiration = ( - min(*self._expirations.values(), self.SUB_MAX_TIME) + 1 - if self._expirations - else self.MAX_TIME - ) + to_del = [ + key + for (domain, path), cookie in self._cookies.items() + for name, morsel in cookie.items() + if ( + (key := (domain, path, name)) in self._expirations + and self._expirations[key] <= now + ) + or predicate(morsel) + ] + if to_del: + self._delete_cookies(to_del) def clear_domain(self, domain: str) -> None: self.clear(lambda x: self._is_domain_match(domain, x["domain"])) @@ -153,14 +158,70 @@ class CookieJar(AbstractCookieJar): yield from val.values() def __len__(self) -> int: - return sum(1 for i in self) + """Return number of cookies. + + This function does not iterate self to avoid unnecessary expiration + checks. + """ + return sum(len(cookie.values()) for cookie in self._cookies.values()) def _do_expiration(self) -> None: - self.clear(lambda x: False) + """Remove expired cookies.""" + if not (expire_heap_len := len(self._expire_heap)): + return + + # If the expiration heap grows larger than the number expirations + # times two, we clean it up to avoid keeping expired entries in + # the heap and consuming memory. We guard this with a minimum + # threshold to avoid cleaning up the heap too often when there are + # only a few scheduled expirations. + if ( + expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION + and expire_heap_len > len(self._expirations) * 2 + ): + # Remove any expired entries from the expiration heap + # that do not match the expiration time in the expirations + # as it means the cookie has been re-added to the heap + # with a different expiration time. + self._expire_heap = [ + entry + for entry in self._expire_heap + if self._expirations.get(entry[1]) == entry[0] + ] + heapq.heapify(self._expire_heap) + + now = time.time() + to_del: List[Tuple[str, str, str]] = [] + # Find any expired cookies and add them to the to-delete list + while self._expire_heap: + when, cookie_key = self._expire_heap[0] + if when > now: + break + heapq.heappop(self._expire_heap) + # Check if the cookie hasn't been re-added to the heap + # with a different expiration time as it will be removed + # later when it reaches the top of the heap and its + # expiration time is met. + if self._expirations.get(cookie_key) == when: + to_del.append(cookie_key) + + if to_del: + self._delete_cookies(to_del) + + def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None: + for domain, path, name in to_del: + self._host_only_cookies.discard((domain, name)) + self._cookies[(domain, path)].pop(name, None) + self._morsel_cache[(domain, path)].pop(name, None) + self._expirations.pop((domain, path, name), None) def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None: - self._next_expiration = min(self._next_expiration, when) - self._expirations[(domain, path, name)] = when + cookie_key = (domain, path, name) + if self._expirations.get(cookie_key) == when: + # Avoid adding duplicates to the heap + return + heapq.heappush(self._expire_heap, (when, cookie_key)) + self._expirations[cookie_key] = when def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: """Update cookies.""" @@ -182,7 +243,7 @@ class CookieJar(AbstractCookieJar): domain = cookie["domain"] # ignore domains with trailing dots - if domain.endswith("."): + if domain and domain[-1] == ".": domain = "" del cookie["domain"] @@ -192,7 +253,7 @@ class CookieJar(AbstractCookieJar): self._host_only_cookies.add((hostname, name)) domain = cookie["domain"] = hostname - if domain.startswith("."): + if domain and domain[0] == ".": # Remove leading dot domain = domain[1:] cookie["domain"] = domain @@ -202,7 +263,7 @@ class CookieJar(AbstractCookieJar): continue path = cookie["path"] - if not path or not path.startswith("/"): + if not path or path[0] != "/": # Set the cookie's path to the response path path = response_url.path if not path.startswith("/"): @@ -211,9 +272,9 @@ class CookieJar(AbstractCookieJar): # Cut everything from the last slash to the end path = "/" + path[1 : path.rfind("/")] cookie["path"] = path + path = path.rstrip("/") - max_age = cookie["max-age"] - if max_age: + if max_age := cookie["max-age"]: try: delta_seconds = int(max_age) max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME) @@ -221,16 +282,18 @@ class CookieJar(AbstractCookieJar): except ValueError: cookie["max-age"] = "" - else: - expires = cookie["expires"] - if expires: - expire_time = self._parse_date(expires) - if expire_time: - self._expire_cookie(expire_time, domain, path, name) - else: - cookie["expires"] = "" + elif expires := cookie["expires"]: + if expire_time := self._parse_date(expires): + self._expire_cookie(expire_time, domain, path, name) + else: + cookie["expires"] = "" - self._cookies[(domain, path)][name] = cookie + key = (domain, path) + if self._cookies[key].get(name) != cookie: + # Don't blow away the cache if the same + # cookie gets set again + self._cookies[key][name] = cookie + self._morsel_cache[key].pop(name, None) self._do_expiration() @@ -256,36 +319,52 @@ class CookieJar(AbstractCookieJar): request_origin = request_url.origin() is_not_secure = request_origin not in self._treat_as_secure_origin - # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4 - for cookie in sorted(self, key=lambda c: len(c["path"])): - name = cookie.key - domain = cookie["domain"] + # Send shared cookie + for c in self._cookies[("", "")].values(): + filtered[c.key] = c.value - # Send shared cookies - if not domain: - filtered[name] = cookie.value - continue + if is_ip_address(hostname): + if not self._unsafe: + return filtered + domains: Iterable[str] = (hostname,) + else: + # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com") + domains = itertools.accumulate( + reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED + ) - if not self._unsafe and is_ip_address(hostname): - continue + # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar") + paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH) + # Create every combination of (domain, path) pairs. + pairs = itertools.product(domains, paths) + + path_len = len(request_url.path) + # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4 + for p in pairs: + for name, cookie in self._cookies[p].items(): + domain = cookie["domain"] - if (domain, name) in self._host_only_cookies: - if domain != hostname: + if (domain, name) in self._host_only_cookies and domain != hostname: continue - elif not self._is_domain_match(domain, hostname): - continue - if not self._is_path_match(request_url.path, cookie["path"]): - continue + # Skip edge case when the cookie has a trailing slash but request doesn't. + if len(cookie["path"]) > path_len: + continue - if is_not_secure and cookie["secure"]: - continue + if is_not_secure and cookie["secure"]: + continue - # It's critical we use the Morsel so the coded_value - # (based on cookie version) is preserved - mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel())) - mrsl_val.set(cookie.key, cookie.value, cookie.coded_value) - filtered[name] = mrsl_val + # We already built the Morsel so reuse it here + if name in self._morsel_cache[p]: + filtered[name] = self._morsel_cache[p][name] + continue + + # It's critical we use the Morsel so the coded_value + # (based on cookie version) is preserved + mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel())) + mrsl_val.set(cookie.key, cookie.value, cookie.coded_value) + self._morsel_cache[p][name] = mrsl_val + filtered[name] = mrsl_val return filtered @@ -305,25 +384,6 @@ class CookieJar(AbstractCookieJar): return not is_ip_address(hostname) - @staticmethod - def _is_path_match(req_path: str, cookie_path: str) -> bool: - """Implements path matching adhering to RFC 6265.""" - if not req_path.startswith("/"): - req_path = "/" - - if req_path == cookie_path: - return True - - if not req_path.startswith(cookie_path): - return False - - if cookie_path.endswith("/"): - return True - - non_matching = req_path[len(cookie_path) :] - - return non_matching.startswith("/") - @classmethod def _parse_date(cls, date_str: str) -> Optional[int]: """Implements date string parsing adhering to RFC 6265.""" diff --git a/contrib/python/aiohttp/aiohttp/helpers.py b/contrib/python/aiohttp/aiohttp/helpers.py index 284033b7a04..40705b16d71 100644 --- a/contrib/python/aiohttp/aiohttp/helpers.py +++ b/contrib/python/aiohttp/aiohttp/helpers.py @@ -14,7 +14,6 @@ import platform import re import sys import time -import warnings import weakref from collections import namedtuple from contextlib import suppress @@ -35,7 +34,6 @@ from typing import ( List, Mapping, Optional, - Pattern, Protocol, Tuple, Type, @@ -52,7 +50,7 @@ from multidict import MultiDict, MultiDictProxy, MultiMapping from yarl import URL from . import hdrs -from .log import client_logger, internal_logger +from .log import client_logger if sys.version_info >= (3, 11): import asyncio as async_timeout @@ -165,9 +163,9 @@ class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])): """Create BasicAuth from url.""" if not isinstance(url, URL): raise TypeError("url should be yarl.URL instance") - if url.user is None: + if url.user is None and url.password is None: return None - return cls(url.user, url.password or "", encoding=encoding) + return cls(url.user or "", url.password or "", encoding=encoding) def encode(self) -> str: """Encode credentials.""" @@ -287,38 +285,6 @@ def proxies_from_env() -> Dict[str, ProxyInfo]: return ret -def current_task( - loop: Optional[asyncio.AbstractEventLoop] = None, -) -> "Optional[asyncio.Task[Any]]": - return asyncio.current_task(loop=loop) - - -def get_running_loop( - loop: Optional[asyncio.AbstractEventLoop] = None, -) -> asyncio.AbstractEventLoop: - if loop is None: - loop = asyncio.get_event_loop() - if not loop.is_running(): - warnings.warn( - "The object should be created within an async function", - DeprecationWarning, - stacklevel=3, - ) - if loop.get_debug(): - internal_logger.warning( - "The object should be created within an async function", stack_info=True - ) - return loop - - -def isasyncgenfunction(obj: Any) -> bool: - func = getattr(inspect, "isasyncgenfunction", None) - if func is not None: - return func(obj) # type: ignore[no-any-return] - else: - return False - - def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: """Get a permitted proxy for the given URL from the env.""" if url.host is not None and proxy_bypass(url.host): @@ -504,44 +470,51 @@ try: except ImportError: pass -_ipv4_pattern = ( - r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}" - r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$" -) -_ipv6_pattern = ( - r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}" - r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)" - r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})" - r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}" - r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}" - r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)" - r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}" - r":|:(:[A-F0-9]{1,4}){7})$" -) -_ipv4_regex = re.compile(_ipv4_pattern) -_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE) -_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii")) -_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE) +def is_ipv4_address(host: Optional[Union[str, bytes]]) -> bool: + """Check if host looks like an IPv4 address. -def _is_ip_address( - regex: Pattern[str], regexb: Pattern[bytes], host: Optional[Union[str, bytes]] -) -> bool: - if host is None: + This function does not validate that the format is correct, only that + the host is a str or bytes, and its all numeric. + + This check is only meant as a heuristic to ensure that + a host is not a domain name. + """ + if not host: return False + # For a host to be an ipv4 address, it must be all numeric. if isinstance(host, str): - return bool(regex.match(host)) - elif isinstance(host, (bytes, bytearray, memoryview)): - return bool(regexb.match(host)) - else: - raise TypeError(f"{host} [{type(host)}] is not a str or bytes") + return host.replace(".", "").isdigit() + if isinstance(host, (bytes, bytearray, memoryview)): + return host.decode("ascii").replace(".", "").isdigit() + raise TypeError(f"{host} [{type(host)}] is not a str or bytes") + +def is_ipv6_address(host: Optional[Union[str, bytes]]) -> bool: + """Check if host looks like an IPv6 address. -is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb) -is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb) + This function does not validate that the format is correct, only that + the host contains a colon and that it is a str or bytes. + + This check is only meant as a heuristic to ensure that + a host is not a domain name. + """ + if not host: + return False + # The host must contain a colon to be an IPv6 address. + if isinstance(host, str): + return ":" in host + if isinstance(host, (bytes, bytearray, memoryview)): + return b":" in host + raise TypeError(f"{host} [{type(host)}] is not a str or bytes") def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool: + """Check if host looks like an IP Address. + + This check is only meant as a heuristic to ensure that + a host is not a domain name. + """ return is_ipv4_address(host) or is_ipv6_address(host) @@ -619,12 +592,23 @@ def call_later( loop: asyncio.AbstractEventLoop, timeout_ceil_threshold: float = 5, ) -> Optional[asyncio.TimerHandle]: - if timeout is not None and timeout > 0: - when = loop.time() + timeout - if timeout > timeout_ceil_threshold: - when = ceil(when) - return loop.call_at(when, cb) - return None + if timeout is None or timeout <= 0: + return None + now = loop.time() + when = calculate_timeout_when(now, timeout, timeout_ceil_threshold) + return loop.call_at(when, cb) + + +def calculate_timeout_when( + loop_time: float, + timeout: float, + timeout_ceiling_threshold: float, +) -> float: + """Calculate when to execute a timeout.""" + when = loop_time + timeout + if timeout > timeout_ceiling_threshold: + return ceil(when) + return when class TimeoutHandle: @@ -651,7 +635,7 @@ class TimeoutHandle: def close(self) -> None: self._callbacks.clear() - def start(self) -> Optional[asyncio.Handle]: + def start(self) -> Optional[asyncio.TimerHandle]: timeout = self._timeout if timeout is not None and timeout > 0: when = self._loop.time() + timeout @@ -709,7 +693,7 @@ class TimerContext(BaseTimerContext): raise asyncio.TimeoutError from None def __enter__(self) -> BaseTimerContext: - task = current_task(loop=self._loop) + task = asyncio.current_task(loop=self._loop) if task is None: raise RuntimeError( @@ -749,7 +733,7 @@ def ceil_timeout( if delay is None or delay <= 0: return async_timeout.timeout(None) - loop = get_running_loop() + loop = asyncio.get_running_loop() now = loop.time() when = now + delay if delay > ceil_threshold: @@ -784,7 +768,8 @@ class HeadersMixin: raw = self._headers.get(hdrs.CONTENT_TYPE) if self._stored_content_type != raw: self._parse_content_type(raw) - return self._content_type # type: ignore[return-value] + assert self._content_type is not None + return self._content_type @property def charset(self) -> Optional[str]: @@ -792,7 +777,8 @@ class HeadersMixin: raw = self._headers.get(hdrs.CONTENT_TYPE) if self._stored_content_type != raw: self._parse_content_type(raw) - return self._content_dict.get("charset") # type: ignore[union-attr] + assert self._content_dict is not None + return self._content_dict.get("charset") @property def content_length(self) -> Optional[int]: @@ -818,8 +804,7 @@ class ErrorableProtocol(Protocol): self, exc: BaseException, exc_cause: BaseException = ..., - ) -> None: - ... # pragma: no cover + ) -> None: ... # pragma: no cover def set_exception( @@ -905,12 +890,10 @@ class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]): ) @overload # type: ignore[override] - def __getitem__(self, key: AppKey[_T]) -> _T: - ... + def __getitem__(self, key: AppKey[_T]) -> _T: ... @overload - def __getitem__(self, key: str) -> Any: - ... + def __getitem__(self, key: str) -> Any: ... def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: for mapping in self._maps: @@ -921,16 +904,13 @@ class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]): raise KeyError(key) @overload # type: ignore[override] - def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: - ... + def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ... @overload - def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: - ... + def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ... @overload - def get(self, key: str, default: Any = ...) -> Any: - ... + def get(self, key: str, default: Any = ...) -> Any: ... def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any: try: @@ -993,6 +973,7 @@ def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]: return None [email protected]_cache def must_be_empty_body(method: str, code: int) -> bool: """Check if a request must return an empty body.""" return ( diff --git a/contrib/python/aiohttp/aiohttp/http_exceptions.py b/contrib/python/aiohttp/aiohttp/http_exceptions.py index 72eac3a3cac..c43ee0d9659 100644 --- a/contrib/python/aiohttp/aiohttp/http_exceptions.py +++ b/contrib/python/aiohttp/aiohttp/http_exceptions.py @@ -1,6 +1,5 @@ """Low-level http related exceptions.""" - from textwrap import indent from typing import Optional, Union diff --git a/contrib/python/aiohttp/aiohttp/http_parser.py b/contrib/python/aiohttp/aiohttp/http_parser.py index 013511917e8..686a2d02e28 100644 --- a/contrib/python/aiohttp/aiohttp/http_parser.py +++ b/contrib/python/aiohttp/aiohttp/http_parser.py @@ -47,7 +47,6 @@ from .http_exceptions import ( TransferEncodingError, ) from .http_writer import HttpVersion, HttpVersion10 -from .log import internal_logger from .streams import EMPTY_PAYLOAD, StreamReader from .typedefs import RawHeaders @@ -249,7 +248,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]): timer: Optional[BaseTimerContext] = None, code: Optional[int] = None, method: Optional[str] = None, - readall: bool = False, payload_exception: Optional[Type[BaseException]] = None, response_with_body: bool = True, read_until_eof: bool = False, @@ -263,7 +261,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]): self.timer = timer self.code = code self.method = method - self.readall = readall self.payload_exception = payload_exception self.response_with_body = response_with_body self.read_until_eof = read_until_eof @@ -280,8 +277,10 @@ class HttpParser(abc.ABC, Generic[_MsgT]): ) @abc.abstractmethod - def parse_message(self, lines: List[bytes]) -> _MsgT: - pass + def parse_message(self, lines: List[bytes]) -> _MsgT: ... + + @abc.abstractmethod + def _is_chunked_te(self, te: str) -> bool: ... def feed_eof(self) -> Optional[_MsgT]: if self._payload_parser is not None: @@ -318,6 +317,7 @@ class HttpParser(abc.ABC, Generic[_MsgT]): start_pos = 0 loop = self.loop + should_close = False while start_pos < data_len: # read HTTP message (request/response line + headers), \r\n\r\n @@ -330,6 +330,9 @@ class HttpParser(abc.ABC, Generic[_MsgT]): continue if pos >= start_pos: + if should_close: + raise BadHttpMessage("Data after `Connection: close`") + # line found line = data[start_pos:pos] if SEP == b"\n": # For lax response parsing @@ -393,7 +396,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]): method=method, compression=msg.compression, code=self.code, - readall=self.readall, response_with_body=self.response_with_body, auto_decompress=self._auto_decompress, lax=self.lax, @@ -413,7 +415,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]): payload, method=msg.method, compression=msg.compression, - readall=True, auto_decompress=self._auto_decompress, lax=self.lax, ) @@ -431,7 +432,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]): method=method, compression=msg.compression, code=self.code, - readall=True, response_with_body=self.response_with_body, auto_decompress=self._auto_decompress, lax=self.lax, @@ -442,6 +442,7 @@ class HttpParser(abc.ABC, Generic[_MsgT]): payload = EMPTY_PAYLOAD messages.append((msg, payload)) + should_close = msg.should_close else: self._tail = data[start_pos:] data = EMPTY @@ -543,10 +544,8 @@ class HttpParser(abc.ABC, Generic[_MsgT]): # chunking te = headers.get(hdrs.TRANSFER_ENCODING) if te is not None: - if "chunked" == te.lower(): + if self._is_chunked_te(te): chunked = True - else: - raise BadHttpMessage("Request has invalid `Transfer-Encoding`") if hdrs.CONTENT_LENGTH in headers: raise BadHttpMessage( @@ -656,6 +655,12 @@ class HttpRequestParser(HttpParser[RawRequestMessage]): url, ) + def _is_chunked_te(self, te: str) -> bool: + if te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked": + return True + # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3 + raise BadHttpMessage("Request has invalid `Transfer-Encoding`") + class HttpResponseParser(HttpParser[RawResponseMessage]): """Read response status line and headers. @@ -741,6 +746,10 @@ class HttpResponseParser(HttpParser[RawResponseMessage]): chunked, ) + def _is_chunked_te(self, te: str) -> bool: + # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.2 + return te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked" + class HttpPayloadParser: def __init__( @@ -751,13 +760,12 @@ class HttpPayloadParser: compression: Optional[str] = None, code: Optional[int] = None, method: Optional[str] = None, - readall: bool = False, response_with_body: bool = True, auto_decompress: bool = True, lax: bool = False, ) -> None: self._length = 0 - self._type = ParseState.PARSE_NONE + self._type = ParseState.PARSE_UNTIL_EOF self._chunk = ChunkState.PARSE_CHUNKED_SIZE self._chunk_size = 0 self._chunk_tail = b"" @@ -779,7 +787,6 @@ class HttpPayloadParser: self._type = ParseState.PARSE_NONE real_payload.feed_eof() self.done = True - elif chunked: self._type = ParseState.PARSE_CHUNKED elif length is not None: @@ -788,16 +795,6 @@ class HttpPayloadParser: if self._length == 0: real_payload.feed_eof() self.done = True - else: - if readall and code != 204: - self._type = ParseState.PARSE_UNTIL_EOF - elif method in ("PUT", "POST"): - internal_logger.warning( # pragma: no cover - "Content-Length or Transfer-Encoding header is required" - ) - self._type = ParseState.PARSE_NONE - real_payload.feed_eof() - self.done = True self.payload = real_payload @@ -888,13 +885,13 @@ class HttpPayloadParser: self._chunk_size = 0 self.payload.feed_data(chunk[:required], required) chunk = chunk[required:] - if self._lax and chunk.startswith(b"\r"): - chunk = chunk[1:] self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF self.payload.end_http_chunk_receiving() # toss the CRLF at the end of the chunk if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF: + if self._lax and chunk.startswith(b"\r"): + chunk = chunk[1:] if chunk[: len(SEP)] == SEP: chunk = chunk[len(SEP) :] self._chunk = ChunkState.PARSE_CHUNKED_SIZE diff --git a/contrib/python/aiohttp/aiohttp/http_websocket.py b/contrib/python/aiohttp/aiohttp/http_websocket.py index 39f2e4a5c15..fb00ebc7d35 100644 --- a/contrib/python/aiohttp/aiohttp/http_websocket.py +++ b/contrib/python/aiohttp/aiohttp/http_websocket.py @@ -8,6 +8,7 @@ import re import sys import zlib from enum import IntEnum +from functools import partial from struct import Struct from typing import ( Any, @@ -24,6 +25,7 @@ from typing import ( ) from .base_protocol import BaseProtocol +from .client_exceptions import ClientConnectionResetError from .compression_utils import ZLibCompressor, ZLibDecompressor from .helpers import NO_EXTENSIONS, set_exception from .streams import DataQueue @@ -93,6 +95,14 @@ class WSMsgType(IntEnum): error = ERROR +MESSAGE_TYPES_WITH_CONTENT: Final = frozenset( + { + WSMsgType.BINARY, + WSMsgType.TEXT, + WSMsgType.CONTINUATION, + } +) + WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" @@ -103,8 +113,10 @@ PACK_LEN1 = Struct("!BB").pack PACK_LEN2 = Struct("!BBH").pack PACK_LEN3 = Struct("!BBQ").pack PACK_CLOSE_CODE = Struct("!H").pack +PACK_RANDBITS = Struct("!L").pack MSG_SIZE: Final[int] = 2**14 DEFAULT_LIMIT: Final[int] = 2**16 +MASK_LEN: Final[int] = 4 class WSMessage(NamedTuple): @@ -294,7 +306,7 @@ class WebSocketReader: self._frame_opcode: Optional[int] = None self._frame_payload = bytearray() - self._tail = b"" + self._tail: bytes = b"" self._has_mask = False self._frame_mask: Optional[bytes] = None self._payload_length = 0 @@ -311,17 +323,101 @@ class WebSocketReader: return True, data try: - return self._feed_data(data) + self._feed_data(data) except Exception as exc: self._exc = exc set_exception(self.queue, exc) return True, b"" - def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: + return False, b"" + + def _feed_data(self, data: bytes) -> None: for fin, opcode, payload, compressed in self.parse_frame(data): - if compressed and not self._decompressobj: - self._decompressobj = ZLibDecompressor(suppress_deflate_header=True) - if opcode == WSMsgType.CLOSE: + if opcode in MESSAGE_TYPES_WITH_CONTENT: + # load text/binary + is_continuation = opcode == WSMsgType.CONTINUATION + if not fin: + # got partial frame payload + if not is_continuation: + self._opcode = opcode + self._partial += payload + if self._max_msg_size and len(self._partial) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size + ), + ) + continue + + has_partial = bool(self._partial) + if is_continuation: + if self._opcode is None: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Continuation frame for non started message", + ) + opcode = self._opcode + self._opcode = None + # previous frame was non finished + # we should get continuation opcode + elif has_partial: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "The opcode in non-fin frame is expected " + "to be zero, got {!r}".format(opcode), + ) + + if has_partial: + assembled_payload = self._partial + payload + self._partial.clear() + else: + assembled_payload = payload + + if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(assembled_payload), self._max_msg_size + ), + ) + + # Decompress process must to be done after all packets + # received. + if compressed: + if not self._decompressobj: + self._decompressobj = ZLibDecompressor( + suppress_deflate_header=True + ) + payload_merged = self._decompressobj.decompress_sync( + assembled_payload + _WS_DEFLATE_TRAILING, self._max_msg_size + ) + if self._decompressobj.unconsumed_tail: + left = len(self._decompressobj.unconsumed_tail) + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Decompressed message size {} exceeds limit {}".format( + self._max_msg_size + left, self._max_msg_size + ), + ) + else: + payload_merged = bytes(assembled_payload) + + if opcode == WSMsgType.TEXT: + try: + text = payload_merged.decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + + self.queue.feed_data(WSMessage(WSMsgType.TEXT, text, ""), len(text)) + continue + + self.queue.feed_data( + WSMessage(WSMsgType.BINARY, payload_merged, ""), len(payload_merged) + ) + elif opcode == WSMsgType.CLOSE: if len(payload) >= 2: close_code = UNPACK_CLOSE_CODE(payload[:2])[0] if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: @@ -356,241 +452,145 @@ class WebSocketReader: WSMessage(WSMsgType.PONG, payload, ""), len(payload) ) - elif ( - opcode not in (WSMsgType.TEXT, WSMsgType.BINARY) - and self._opcode is None - ): + else: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" ) - else: - # load text/binary - if not fin: - # got partial frame payload - if opcode != WSMsgType.CONTINUATION: - self._opcode = opcode - self._partial.extend(payload) - if self._max_msg_size and len(self._partial) >= self._max_msg_size: - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(self._partial), self._max_msg_size - ), - ) - else: - # previous frame was non finished - # we should get continuation opcode - if self._partial: - if opcode != WSMsgType.CONTINUATION: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "The opcode in non-fin frame is expected " - "to be zero, got {!r}".format(opcode), - ) - - if opcode == WSMsgType.CONTINUATION: - assert self._opcode is not None - opcode = self._opcode - self._opcode = None - - self._partial.extend(payload) - if self._max_msg_size and len(self._partial) >= self._max_msg_size: - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(self._partial), self._max_msg_size - ), - ) - - # Decompress process must to be done after all packets - # received. - if compressed: - assert self._decompressobj is not None - self._partial.extend(_WS_DEFLATE_TRAILING) - payload_merged = self._decompressobj.decompress_sync( - self._partial, self._max_msg_size - ) - if self._decompressobj.unconsumed_tail: - left = len(self._decompressobj.unconsumed_tail) - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - "Decompressed message size {} exceeds limit {}".format( - self._max_msg_size + left, self._max_msg_size - ), - ) - else: - payload_merged = bytes(self._partial) - - self._partial.clear() - - if opcode == WSMsgType.TEXT: - try: - text = payload_merged.decode("utf-8") - self.queue.feed_data( - WSMessage(WSMsgType.TEXT, text, ""), len(text) - ) - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" - ) from exc - else: - self.queue.feed_data( - WSMessage(WSMsgType.BINARY, payload_merged, ""), - len(payload_merged), - ) - - return False, b"" def parse_frame( self, buf: bytes ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]: """Return the next frame from the socket.""" - frames = [] + frames: List[Tuple[bool, Optional[int], bytearray, Optional[bool]]] = [] if self._tail: buf, self._tail = self._tail + buf, b"" - start_pos = 0 + start_pos: int = 0 buf_length = len(buf) while True: # read header - if self._state == WSParserState.READ_HEADER: - if buf_length - start_pos >= 2: - data = buf[start_pos : start_pos + 2] - start_pos += 2 - first_byte, second_byte = data + if self._state is WSParserState.READ_HEADER: + if buf_length - start_pos < 2: + break + data = buf[start_pos : start_pos + 2] + start_pos += 2 + first_byte, second_byte = data - fin = (first_byte >> 7) & 1 - rsv1 = (first_byte >> 6) & 1 - rsv2 = (first_byte >> 5) & 1 - rsv3 = (first_byte >> 4) & 1 - opcode = first_byte & 0xF + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xF - # frame-fin = %x0 ; more frames of this message follow - # / %x1 ; final frame of this message - # frame-rsv1 = %x0 ; - # 1 bit, MUST be 0 unless negotiated otherwise - # frame-rsv2 = %x0 ; - # 1 bit, MUST be 0 unless negotiated otherwise - # frame-rsv3 = %x0 ; - # 1 bit, MUST be 0 unless negotiated otherwise - # - # Remove rsv1 from this test for deflate development - if rsv2 or rsv3 or (rsv1 and not self._compress): - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Received frame with non-zero reserved bits", - ) + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # + # Remove rsv1 from this test for deflate development + if rsv2 or rsv3 or (rsv1 and not self._compress): + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) - if opcode > 0x7 and fin == 0: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Received fragmented control frame", - ) + if opcode > 0x7 and fin == 0: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received fragmented control frame", + ) - has_mask = (second_byte >> 7) & 1 - length = second_byte & 0x7F + has_mask = (second_byte >> 7) & 1 + length = second_byte & 0x7F - # Control frames MUST have a payload - # length of 125 bytes or less - if opcode > 0x7 and length > 125: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Control frame payload cannot be " "larger than 125 bytes", - ) + # Control frames MUST have a payload + # length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Control frame payload cannot be " "larger than 125 bytes", + ) - # Set compress status if last package is FIN - # OR set compress status if this is first fragment - # Raise error if not first fragment with rsv1 = 0x1 - if self._frame_fin or self._compressed is None: - self._compressed = True if rsv1 else False - elif rsv1: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Received frame with non-zero reserved bits", - ) + # Set compress status if last package is FIN + # OR set compress status if this is first fragment + # Raise error if not first fragment with rsv1 = 0x1 + if self._frame_fin or self._compressed is None: + self._compressed = True if rsv1 else False + elif rsv1: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) - self._frame_fin = bool(fin) - self._frame_opcode = opcode - self._has_mask = bool(has_mask) - self._payload_length_flag = length - self._state = WSParserState.READ_PAYLOAD_LENGTH - else: - break + self._frame_fin = bool(fin) + self._frame_opcode = opcode + self._has_mask = bool(has_mask) + self._payload_length_flag = length + self._state = WSParserState.READ_PAYLOAD_LENGTH # read payload length - if self._state == WSParserState.READ_PAYLOAD_LENGTH: - length = self._payload_length_flag - if length == 126: - if buf_length - start_pos >= 2: - data = buf[start_pos : start_pos + 2] - start_pos += 2 - length = UNPACK_LEN2(data)[0] - self._payload_length = length - self._state = ( - WSParserState.READ_PAYLOAD_MASK - if self._has_mask - else WSParserState.READ_PAYLOAD - ) - else: + if self._state is WSParserState.READ_PAYLOAD_LENGTH: + length_flag = self._payload_length_flag + if length_flag == 126: + if buf_length - start_pos < 2: break - elif length > 126: - if buf_length - start_pos >= 8: - data = buf[start_pos : start_pos + 8] - start_pos += 8 - length = UNPACK_LEN3(data)[0] - self._payload_length = length - self._state = ( - WSParserState.READ_PAYLOAD_MASK - if self._has_mask - else WSParserState.READ_PAYLOAD - ) - else: + data = buf[start_pos : start_pos + 2] + start_pos += 2 + self._payload_length = UNPACK_LEN2(data)[0] + elif length_flag > 126: + if buf_length - start_pos < 8: break + data = buf[start_pos : start_pos + 8] + start_pos += 8 + self._payload_length = UNPACK_LEN3(data)[0] else: - self._payload_length = length - self._state = ( - WSParserState.READ_PAYLOAD_MASK - if self._has_mask - else WSParserState.READ_PAYLOAD - ) + self._payload_length = length_flag + + self._state = ( + WSParserState.READ_PAYLOAD_MASK + if self._has_mask + else WSParserState.READ_PAYLOAD + ) # read payload mask - if self._state == WSParserState.READ_PAYLOAD_MASK: - if buf_length - start_pos >= 4: - self._frame_mask = buf[start_pos : start_pos + 4] - start_pos += 4 - self._state = WSParserState.READ_PAYLOAD - else: + if self._state is WSParserState.READ_PAYLOAD_MASK: + if buf_length - start_pos < 4: break + self._frame_mask = buf[start_pos : start_pos + 4] + start_pos += 4 + self._state = WSParserState.READ_PAYLOAD - if self._state == WSParserState.READ_PAYLOAD: + if self._state is WSParserState.READ_PAYLOAD: length = self._payload_length payload = self._frame_payload chunk_len = buf_length - start_pos if length >= chunk_len: self._payload_length = length - chunk_len - payload.extend(buf[start_pos:]) + payload += buf[start_pos:] start_pos = buf_length else: self._payload_length = 0 - payload.extend(buf[start_pos : start_pos + length]) + payload += buf[start_pos : start_pos + length] start_pos = start_pos + length - if self._payload_length == 0: - if self._has_mask: - assert self._frame_mask is not None - _websocket_mask(self._frame_mask, payload) + if self._payload_length != 0: + break - frames.append( - (self._frame_fin, self._frame_opcode, payload, self._compressed) - ) + if self._has_mask: + assert self._frame_mask is not None + _websocket_mask(self._frame_mask, payload) - self._frame_payload = bytearray() - self._state = WSParserState.READ_HEADER - else: - break + frames.append( + (self._frame_fin, self._frame_opcode, payload, self._compressed) + ) + self._frame_payload = bytearray() + self._state = WSParserState.READ_HEADER self._tail = buf[start_pos:] @@ -612,7 +612,7 @@ class WebSocketWriter: self.protocol = protocol self.transport = transport self.use_mask = use_mask - self.randrange = random.randrange + self.get_random_bits = partial(random.getrandbits, 32) self.compress = compress self.notakeover = notakeover self._closing = False @@ -625,14 +625,20 @@ class WebSocketWriter: ) -> None: """Send a frame over the websocket with message as its payload.""" if self._closing and not (opcode & WSMsgType.CLOSE): - raise ConnectionResetError("Cannot write to closing transport") + raise ClientConnectionResetError("Cannot write to closing transport") + # RSV are the reserved bits in the frame header. They are used to + # indicate that the frame is using an extension. + # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 rsv = 0 - # Only compress larger packets (disabled) # Does small packet needs to be compressed? # if self.compress and opcode < 8 and len(message) > 124: if (compress or self.compress) and opcode < 8: + # RSV1 (rsv = 0x40) is set for compressed frames + # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 + rsv = 0x40 + if compress: # Do not set self._compress if compressing is for this frame compressobj = self._make_compress_obj(compress) @@ -651,29 +657,39 @@ class WebSocketWriter: ) if message.endswith(_WS_DEFLATE_TRAILING): message = message[:-4] - rsv = rsv | 0x40 msg_length = len(message) use_mask = self.use_mask - if use_mask: - mask_bit = 0x80 - else: - mask_bit = 0 + mask_bit = 0x80 if use_mask else 0 + # Depending on the message length, the header is assembled differently. + # The first byte is reserved for the opcode and the RSV bits. + first_byte = 0x80 | rsv | opcode if msg_length < 126: - header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit) + header = PACK_LEN1(first_byte, msg_length | mask_bit) + header_len = 2 elif msg_length < (1 << 16): - header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length) + header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length) + header_len = 4 else: - header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length) + header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length) + header_len = 10 + + # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3 + # If we are using a mask, we need to generate it randomly + # and apply it to the message before sending it. A mask is + # a 32-bit value that is applied to the message using a + # bitwise XOR operation. It is used to prevent certain types + # of attacks on the websocket protocol. The mask is only used + # when aiohttp is acting as a client. Servers do not use a mask. if use_mask: - mask_int = self.randrange(0, 0xFFFFFFFF) - mask = mask_int.to_bytes(4, "big") + mask = PACK_RANDBITS(self.get_random_bits()) message = bytearray(message) _websocket_mask(mask, message) self._write(header + mask + message) - self._output_size += len(header) + len(mask) + msg_length + self._output_size += header_len + MASK_LEN + msg_length + else: if msg_length > MSG_SIZE: self._write(header) @@ -681,11 +697,16 @@ class WebSocketWriter: else: self._write(header + message) - self._output_size += len(header) + msg_length + self._output_size += header_len + msg_length # It is safe to return control to the event loop when using compression # after this point as we have already sent or buffered all the data. + # Once we have written output_size up to the limit, we call the + # drain helper which waits for the transport to be ready to accept + # more data. This is a flow control mechanism to prevent the buffer + # from growing too large. The drain helper will return right away + # if the writer is not paused. if self._output_size > self._limit: self._output_size = 0 await self.protocol._drain_helper() @@ -699,7 +720,7 @@ class WebSocketWriter: def _write(self, data: bytes) -> None: if self.transport is None or self.transport.is_closing(): - raise ConnectionResetError("Cannot write to closing transport") + raise ClientConnectionResetError("Cannot write to closing transport") self.transport.write(data) async def pong(self, message: Union[bytes, str] = b"") -> None: diff --git a/contrib/python/aiohttp/aiohttp/http_writer.py b/contrib/python/aiohttp/aiohttp/http_writer.py index d6b02e6f566..dc07a358c70 100644 --- a/contrib/python/aiohttp/aiohttp/http_writer.py +++ b/contrib/python/aiohttp/aiohttp/http_writer.py @@ -8,6 +8,7 @@ from multidict import CIMultiDict from .abc import AbstractStreamWriter from .base_protocol import BaseProtocol +from .client_exceptions import ClientConnectionResetError from .compression_utils import ZLibCompressor from .helpers import NO_EXTENSIONS @@ -70,9 +71,9 @@ class StreamWriter(AbstractStreamWriter): size = len(chunk) self.buffer_size += size self.output_size += size - transport = self.transport - if not self._protocol.connected or transport is None or transport.is_closing(): - raise ConnectionResetError("Cannot write to closing transport") + transport = self._protocol.transport + if transport is None or transport.is_closing(): + raise ClientConnectionResetError("Cannot write to closing transport") transport.write(chunk) async def write( diff --git a/contrib/python/aiohttp/aiohttp/multipart.py b/contrib/python/aiohttp/aiohttp/multipart.py index 71fc2654a1c..49c05c5af25 100644 --- a/contrib/python/aiohttp/aiohttp/multipart.py +++ b/contrib/python/aiohttp/aiohttp/multipart.py @@ -2,6 +2,7 @@ import base64 import binascii import json import re +import sys import uuid import warnings import zlib @@ -10,7 +11,6 @@ from types import TracebackType from typing import ( TYPE_CHECKING, Any, - AsyncIterator, Deque, Dict, Iterator, @@ -48,6 +48,13 @@ from .payload import ( ) from .streams import StreamReader +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing import TypeVar + + Self = TypeVar("Self", bound="BodyPartReader") + __all__ = ( "MultipartReader", "MultipartWriter", @@ -266,6 +273,7 @@ class BodyPartReader: ) -> None: self.headers = headers self._boundary = boundary + self._boundary_len = len(boundary) + 2 # Boundary + \r\n self._content = content self._default_charset = default_charset self._at_eof = False @@ -279,8 +287,8 @@ class BodyPartReader: self._content_eof = 0 self._cache: Dict[str, Any] = {} - def __aiter__(self) -> AsyncIterator["BodyPartReader"]: - return self # type: ignore[return-value] + def __aiter__(self: Self): + return self async def __anext__(self) -> bytes: part = await self.next() @@ -322,6 +330,31 @@ class BodyPartReader: else: chunk = await self._read_chunk_from_stream(size) + # For the case of base64 data, we must read a fragment of size with a + # remainder of 0 by dividing by 4 for string without symbols \n or \r + encoding = self.headers.get(CONTENT_TRANSFER_ENCODING) + if encoding and encoding.lower() == "base64": + stripped_chunk = b"".join(chunk.split()) + remainder = len(stripped_chunk) % 4 + + while remainder != 0 and not self.at_eof(): + over_chunk_size = 4 - remainder + over_chunk = b"" + + if self._prev_chunk: + over_chunk = self._prev_chunk[:over_chunk_size] + self._prev_chunk = self._prev_chunk[len(over_chunk) :] + + if len(over_chunk) != over_chunk_size: + over_chunk += await self._content.read(4 - len(over_chunk)) + + if not over_chunk: + self._at_eof = True + + stripped_chunk += b"".join(over_chunk.split()) + chunk += over_chunk + remainder = len(stripped_chunk) % 4 + self._read_bytes += len(chunk) if self._read_bytes == self._length: self._at_eof = True @@ -346,15 +379,25 @@ class BodyPartReader: # Reads content chunk of body part with unknown length. # The Content-Length header for body part is not necessary. assert ( - size >= len(self._boundary) + 2 + size >= self._boundary_len ), "Chunk size must be greater or equal than boundary length + 2" first_chunk = self._prev_chunk is None if first_chunk: self._prev_chunk = await self._content.read(size) - chunk = await self._content.read(size) - self._content_eof += int(self._content.at_eof()) - assert self._content_eof < 3, "Reading after EOF" + chunk = b"" + # content.read() may return less than size, so we need to loop to ensure + # we have enough data to detect the boundary. + while len(chunk) < self._boundary_len: + chunk += await self._content.read(size) + self._content_eof += int(self._content.at_eof()) + assert self._content_eof < 3, "Reading after EOF" + if self._content_eof: + break + if len(chunk) > size: + self._content.unread_data(chunk[size:]) + chunk = chunk[:size] + assert self._prev_chunk is not None window = self._prev_chunk + chunk sub = b"\r\n" + self._boundary @@ -518,6 +561,8 @@ class BodyPartReader: @payload_type(BodyPartReader, order=Order.try_first) class BodyPartReaderPayload(Payload): + _value: BodyPartReader + def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None: super().__init__(value, *args, **kwargs) @@ -530,6 +575,9 @@ class BodyPartReaderPayload(Payload): if params: self.set_content_disposition("attachment", True, **params) + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + raise TypeError("Unable to decode.") + async def write(self, writer: Any) -> None: field = self._value chunk = await field.read_chunk(size=2**16) @@ -566,10 +614,8 @@ class MultipartReader: self._at_bof = True self._unread: List[bytes] = [] - def __aiter__( - self, - ) -> AsyncIterator["BodyPartReader"]: - return self # type: ignore[return-value] + def __aiter__(self: Self): + return self async def __anext__( self, @@ -749,6 +795,8 @@ _Part = Tuple[Payload, str, str] class MultipartWriter(Payload): """Multipart body writer.""" + _value: None + def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None: boundary = boundary if boundary is not None else uuid.uuid4().hex # The underlying Payload API demands a str (utf-8), not bytes, @@ -929,6 +977,16 @@ class MultipartWriter(Payload): total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n' return total + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + return "".join( + "--" + + self.boundary + + "\n" + + part._binary_headers.decode(encoding, errors) + + part.decode() + for part, _e, _te in self._parts + ) + async def write(self, writer: Any, close_boundary: bool = True) -> None: """Write body.""" for part, encoding, te_encoding in self._parts: diff --git a/contrib/python/aiohttp/aiohttp/payload.py b/contrib/python/aiohttp/aiohttp/payload.py index 6593b05c6f7..27636977774 100644 --- a/contrib/python/aiohttp/aiohttp/payload.py +++ b/contrib/python/aiohttp/aiohttp/payload.py @@ -11,7 +11,6 @@ from typing import ( IO, TYPE_CHECKING, Any, - ByteString, Dict, Final, Iterable, @@ -209,6 +208,13 @@ class Payload(ABC): ) @abstractmethod + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + """Return string representation of the value. + + This is named decode() to allow compatibility with bytes objects. + """ + + @abstractmethod async def write(self, writer: AbstractStreamWriter) -> None: """Write payload. @@ -217,7 +223,11 @@ class Payload(ABC): class BytesPayload(Payload): - def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None: + _value: bytes + + def __init__( + self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any + ) -> None: if not isinstance(value, (bytes, bytearray, memoryview)): raise TypeError(f"value argument must be byte-ish, not {type(value)!r}") @@ -241,6 +251,9 @@ class BytesPayload(Payload): **kwargs, ) + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + return self._value.decode(encoding, errors) + async def write(self, writer: AbstractStreamWriter) -> None: await writer.write(self._value) @@ -282,7 +295,7 @@ class StringIOPayload(StringPayload): class IOBasePayload(Payload): - _value: IO[Any] + _value: io.IOBase def __init__( self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any @@ -306,9 +319,12 @@ class IOBasePayload(Payload): finally: await loop.run_in_executor(None, self._value.close) + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + return "".join(r.decode(encoding, errors) for r in self._value.readlines()) + class TextIOPayload(IOBasePayload): - _value: TextIO + _value: io.TextIOBase def __init__( self, @@ -345,6 +361,9 @@ class TextIOPayload(IOBasePayload): except OSError: return None + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + return self._value.read() + async def write(self, writer: AbstractStreamWriter) -> None: loop = asyncio.get_event_loop() try: @@ -362,6 +381,8 @@ class TextIOPayload(IOBasePayload): class BytesIOPayload(IOBasePayload): + _value: io.BytesIO + @property def size(self) -> int: position = self._value.tell() @@ -369,17 +390,27 @@ class BytesIOPayload(IOBasePayload): self._value.seek(position) return end - position + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + return self._value.read().decode(encoding, errors) + class BufferedReaderPayload(IOBasePayload): + _value: io.BufferedIOBase + @property def size(self) -> Optional[int]: try: return os.fstat(self._value.fileno()).st_size - self._value.tell() - except OSError: + except (OSError, AttributeError): # data.fileno() is not supported, e.g. # io.BufferedReader(io.BytesIO(b'data')) + # For some file-like objects (e.g. tarfile), the fileno() attribute may + # not exist at all, and will instead raise an AttributeError. return None + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + return self._value.read().decode(encoding, errors) + class JsonPayload(BytesPayload): def __init__( @@ -416,6 +447,7 @@ else: class AsyncIterablePayload(Payload): _iter: Optional[_AsyncIterator] = None + _value: _AsyncIterable def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None: if not isinstance(value, AsyncIterable): @@ -443,6 +475,9 @@ class AsyncIterablePayload(Payload): except StopAsyncIteration: self._iter = None + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + raise TypeError("Unable to decode.") + class StreamReaderPayload(AsyncIterablePayload): def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None: diff --git a/contrib/python/aiohttp/aiohttp/payload_streamer.py b/contrib/python/aiohttp/aiohttp/payload_streamer.py index 364f763ae74..831fdc0a77f 100644 --- a/contrib/python/aiohttp/aiohttp/payload_streamer.py +++ b/contrib/python/aiohttp/aiohttp/payload_streamer.py @@ -65,6 +65,9 @@ class StreamWrapperPayload(Payload): async def write(self, writer: AbstractStreamWriter) -> None: await self._value(writer) + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + raise TypeError("Unable to decode.") + @payload_type(streamer) class StreamPayload(StreamWrapperPayload): diff --git a/contrib/python/aiohttp/aiohttp/pytest_plugin.py b/contrib/python/aiohttp/aiohttp/pytest_plugin.py index 5754747bf48..55964ead041 100644 --- a/contrib/python/aiohttp/aiohttp/pytest_plugin.py +++ b/contrib/python/aiohttp/aiohttp/pytest_plugin.py @@ -1,13 +1,21 @@ import asyncio import contextlib +import inspect import warnings -from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Type, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterator, + Optional, + Protocol, + Type, + Union, +) import pytest -from aiohttp.helpers import isasyncgenfunction -from aiohttp.web import Application - from .test_utils import ( BaseTestServer, RawTestServer, @@ -18,15 +26,35 @@ from .test_utils import ( teardown_test_loop, unused_port as _unused_port, ) +from .web import Application +from .web_protocol import _RequestHandler try: import uvloop except ImportError: # pragma: no cover uvloop = None # type: ignore[assignment] -AiohttpClient = Callable[[Union[Application, BaseTestServer]], Awaitable[TestClient]] -AiohttpRawServer = Callable[[Application], Awaitable[RawTestServer]] -AiohttpServer = Callable[[Application], Awaitable[TestServer]] + +class AiohttpClient(Protocol): + def __call__( + self, + __param: Union[Application, BaseTestServer], + *, + server_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any + ) -> Awaitable[TestClient]: ... + + +class AiohttpServer(Protocol): + def __call__( + self, app: Application, *, port: Optional[int] = None, **kwargs: Any + ) -> Awaitable[TestServer]: ... + + +class AiohttpRawServer(Protocol): + def __call__( + self, handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any + ) -> Awaitable[RawTestServer]: ... def pytest_addoption(parser): # type: ignore[no-untyped-def] @@ -57,7 +85,7 @@ def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def] """ func = fixturedef.func - if isasyncgenfunction(func): + if inspect.isasyncgenfunction(func): # async generator fixture is_async_gen = True elif asyncio.iscoroutinefunction(func): @@ -262,7 +290,9 @@ def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]: """ servers = [] - async def go(app, *, port=None, **kwargs): # type: ignore[no-untyped-def] + async def go( + app: Application, *, port: Optional[int] = None, **kwargs: Any + ) -> TestServer: server = TestServer(app, port=port) await server.start_server(loop=loop, **kwargs) servers.append(server) @@ -295,7 +325,9 @@ def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawSe """ servers = [] - async def go(handler, *, port=None, **kwargs): # type: ignore[no-untyped-def] + async def go( + handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any + ) -> RawTestServer: server = RawTestServer(handler, port=port) await server.start_server(loop=loop, **kwargs) servers.append(server) diff --git a/contrib/python/aiohttp/aiohttp/resolver.py b/contrib/python/aiohttp/aiohttp/resolver.py index c03230c744e..06855fa13fd 100644 --- a/contrib/python/aiohttp/aiohttp/resolver.py +++ b/contrib/python/aiohttp/aiohttp/resolver.py @@ -1,20 +1,25 @@ import asyncio import socket -from typing import Any, Dict, List, Optional, Type, Union +import sys +from typing import Any, Dict, List, Optional, Tuple, Type, Union -from .abc import AbstractResolver -from .helpers import get_running_loop +from .abc import AbstractResolver, ResolveResult __all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver") + try: import aiodns - # aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname') + aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo") except ImportError: # pragma: no cover - aiodns = None + aiodns = None # type: ignore[assignment] + aiodns_default = False + -aiodns_default = False +_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV +_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV +_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0) class ThreadedResolver(AbstractResolver): @@ -25,48 +30,48 @@ class ThreadedResolver(AbstractResolver): """ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: - self._loop = get_running_loop(loop) + self._loop = loop or asyncio.get_event_loop() async def resolve( - self, hostname: str, port: int = 0, family: int = socket.AF_INET - ) -> List[Dict[str, Any]]: + self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET + ) -> List[ResolveResult]: infos = await self._loop.getaddrinfo( - hostname, + host, port, type=socket.SOCK_STREAM, family=family, # flags=socket.AI_ADDRCONFIG, ) - hosts = [] + hosts: List[ResolveResult] = [] for family, _, proto, _, address in infos: if family == socket.AF_INET6: if len(address) < 3: # IPv6 is not supported by Python build, # or IPv6 is not enabled in the host continue - if address[3]: + if address[3] and _SUPPORTS_SCOPE_ID: # This is essential for link-local IPv6 addresses. # LL IPv6 is a VERY rare case. Strictly speaking, we should use # getnameinfo() unconditionally, but performance makes sense. - host, _port = socket.getnameinfo( - address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV + resolved_host, _port = await self._loop.getnameinfo( + address, _NAME_SOCKET_FLAGS ) port = int(_port) else: - host, port = address[:2] + resolved_host, port = address[:2] else: # IPv4 assert family == socket.AF_INET - host, port = address # type: ignore[misc] + resolved_host, port = address # type: ignore[misc] hosts.append( - { - "hostname": hostname, - "host": host, - "port": port, - "family": family, - "proto": proto, - "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV, - } + ResolveResult( + hostname=host, + host=resolved_host, + port=port, + family=family, + proto=proto, + flags=_NUMERIC_SOCKET_FLAGS, + ) ) return hosts @@ -87,32 +92,56 @@ class AsyncResolver(AbstractResolver): if aiodns is None: raise RuntimeError("Resolver requires aiodns library") - self._loop = get_running_loop(loop) - self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs) + self._resolver = aiodns.DNSResolver(*args, **kwargs) if not hasattr(self._resolver, "gethostbyname"): # aiodns 1.1 is not available, fallback to DNSResolver.query self.resolve = self._resolve_with_query # type: ignore async def resolve( - self, host: str, port: int = 0, family: int = socket.AF_INET - ) -> List[Dict[str, Any]]: + self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET + ) -> List[ResolveResult]: try: - resp = await self._resolver.gethostbyname(host, family) + resp = await self._resolver.getaddrinfo( + host, + port=port, + type=socket.SOCK_STREAM, + family=family, + flags=socket.AI_ADDRCONFIG, + ) except aiodns.error.DNSError as exc: msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed" raise OSError(msg) from exc - hosts = [] - for address in resp.addresses: + hosts: List[ResolveResult] = [] + for node in resp.nodes: + address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr + family = node.family + if family == socket.AF_INET6: + if len(address) > 3 and address[3] and _SUPPORTS_SCOPE_ID: + # This is essential for link-local IPv6 addresses. + # LL IPv6 is a VERY rare case. Strictly speaking, we should use + # getnameinfo() unconditionally, but performance makes sense. + result = await self._resolver.getnameinfo( + (address[0].decode("ascii"), *address[1:]), + _NAME_SOCKET_FLAGS, + ) + resolved_host = result.node + else: + resolved_host = address[0].decode("ascii") + port = address[1] + else: # IPv4 + assert family == socket.AF_INET + resolved_host = address[0].decode("ascii") + port = address[1] hosts.append( - { - "hostname": host, - "host": address, - "port": port, - "family": family, - "proto": 0, - "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV, - } + ResolveResult( + hostname=host, + host=resolved_host, + port=port, + family=family, + proto=0, + flags=_NUMERIC_SOCKET_FLAGS, + ) ) if not hosts: diff --git a/contrib/python/aiohttp/aiohttp/streams.py b/contrib/python/aiohttp/aiohttp/streams.py index b9b9c3fd96f..c927cfbb1b3 100644 --- a/contrib/python/aiohttp/aiohttp/streams.py +++ b/contrib/python/aiohttp/aiohttp/streams.py @@ -296,6 +296,9 @@ class StreamReader(AsyncStreamReaderMixin): set_result(waiter, None) async def _wait(self, func_name: str) -> None: + if not self._protocol.connected: + raise RuntimeError("Connection closed.") + # StreamReader uses a future to link the protocol feed_data() method # to a read coroutine. Running two read coroutines at the same time # would have an unexpected behaviour. It would not possible to know diff --git a/contrib/python/aiohttp/aiohttp/test_utils.py b/contrib/python/aiohttp/aiohttp/test_utils.py index a36e8599689..01496b6711a 100644 --- a/contrib/python/aiohttp/aiohttp/test_utils.py +++ b/contrib/python/aiohttp/aiohttp/test_utils.py @@ -11,17 +11,7 @@ import sys import warnings from abc import ABC, abstractmethod from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterator, - List, - Optional, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Type, cast from unittest import IsolatedAsyncioTestCase, mock from aiosignal import Signal @@ -29,7 +19,11 @@ from multidict import CIMultiDict, CIMultiDictProxy from yarl import URL import aiohttp -from aiohttp.client import _RequestContextManager, _WSRequestContextManager +from aiohttp.client import ( + _RequestContextManager, + _RequestOptions, + _WSRequestContextManager, +) from . import ClientSession, hdrs from .abc import AbstractCookieJar @@ -37,6 +31,7 @@ from .client_reqrep import ClientResponse from .client_ws import ClientWebSocketResponse from .helpers import sentinel from .http import HttpVersion, RawRequestMessage +from .streams import EMPTY_PAYLOAD, StreamReader from .typedefs import StrOrURL from .web import ( Application, @@ -55,6 +50,9 @@ if TYPE_CHECKING: else: SSLContext = None +if sys.version_info >= (3, 11) and TYPE_CHECKING: + from typing import Unpack + REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" @@ -90,7 +88,7 @@ class BaseTestServer(ABC): def __init__( self, *, - scheme: Union[str, object] = sentinel, + scheme: str = "", loop: Optional[asyncio.AbstractEventLoop] = None, host: str = "127.0.0.1", port: Optional[int] = None, @@ -121,10 +119,13 @@ class BaseTestServer(ABC): await self.runner.setup() if not self.port: self.port = 0 + absolute_host = self.host try: version = ipaddress.ip_address(self.host).version except ValueError: version = 4 + if version == 6: + absolute_host = f"[{self.host}]" family = socket.AF_INET6 if version == 6 else socket.AF_INET _sock = self.socket_factory(self.host, self.port, family) self.host, self.port = _sock.getsockname()[:2] @@ -135,13 +136,9 @@ class BaseTestServer(ABC): sockets = server.sockets # type: ignore[attr-defined] assert sockets is not None self.port = sockets[0].getsockname()[1] - if self.scheme is sentinel: - if self._ssl: - scheme = "https" - else: - scheme = "http" - self.scheme = scheme - self._root = URL(f"{self.scheme}://{self.host}:{self.port}") + if not self.scheme: + self.scheme = "https" if self._ssl else "http" + self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}") @abstractmethod # pragma: no cover async def _make_runner(self, **kwargs: Any) -> BaseRunner: @@ -222,7 +219,7 @@ class TestServer(BaseTestServer): self, app: Application, *, - scheme: Union[str, object] = sentinel, + scheme: str = "", host: str = "127.0.0.1", port: Optional[int] = None, **kwargs: Any, @@ -239,7 +236,7 @@ class RawTestServer(BaseTestServer): self, handler: _RequestHandler, *, - scheme: Union[str, object] = sentinel, + scheme: str = "", host: str = "127.0.0.1", port: Optional[int] = None, **kwargs: Any, @@ -324,45 +321,101 @@ class TestClient: self._responses.append(resp) return resp - def request( - self, method: str, path: StrOrURL, **kwargs: Any - ) -> _RequestContextManager: - """Routes a request to tested http server. + if sys.version_info >= (3, 11) and TYPE_CHECKING: - The interface is identical to aiohttp.ClientSession.request, - except the loop kwarg is overridden by the instance used by the - test server. + def request( + self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions] + ) -> _RequestContextManager: ... - """ - return _RequestContextManager(self._request(method, path, **kwargs)) + def get( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def options( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def head( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def post( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def put( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def patch( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... - def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP GET request.""" - return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) + def delete( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... - def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP POST request.""" - return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) + else: + + def request( + self, method: str, path: StrOrURL, **kwargs: Any + ) -> _RequestContextManager: + """Routes a request to tested http server. + + The interface is identical to aiohttp.ClientSession.request, + except the loop kwarg is overridden by the instance used by the + test server. + + """ + return _RequestContextManager(self._request(method, path, **kwargs)) + + def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP GET request.""" + return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) - def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP OPTIONS request.""" - return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs)) + def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP POST request.""" + return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) - def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP HEAD request.""" - return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) + def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP OPTIONS request.""" + return _RequestContextManager( + self._request(hdrs.METH_OPTIONS, path, **kwargs) + ) - def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP PUT request.""" - return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs)) + def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP HEAD request.""" + return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) - def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP PATCH request.""" - return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs)) + def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PUT request.""" + return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs)) - def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP PATCH request.""" - return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs)) + def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PATCH request.""" + return _RequestContextManager( + self._request(hdrs.METH_PATCH, path, **kwargs) + ) + + def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PATCH request.""" + return _RequestContextManager( + self._request(hdrs.METH_DELETE, path, **kwargs) + ) def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager: """Initiate websocket connection. @@ -582,7 +635,7 @@ def make_mocked_request( writer: Any = sentinel, protocol: Any = sentinel, transport: Any = sentinel, - payload: Any = sentinel, + payload: StreamReader = EMPTY_PAYLOAD, sslcontext: Optional[SSLContext] = None, client_max_size: int = 1024**2, loop: Any = ..., @@ -651,9 +704,6 @@ def make_mocked_request( protocol.transport = transport protocol.writer = writer - if payload is sentinel: - payload = mock.Mock() - req = Request( message, payload, protocol, writer, task, loop, client_max_size=client_max_size ) diff --git a/contrib/python/aiohttp/aiohttp/tracing.py b/contrib/python/aiohttp/aiohttp/tracing.py index fe3eda9abb7..067a132464e 100644 --- a/contrib/python/aiohttp/aiohttp/tracing.py +++ b/contrib/python/aiohttp/aiohttp/tracing.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from typing import TYPE_CHECKING, Awaitable, Optional, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Awaitable, Mapping, Optional, Protocol, Type, TypeVar import attr from aiosignal import Signal @@ -42,59 +42,29 @@ class TraceConfig: def __init__( self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace ) -> None: - self._on_request_start: _TracingSignal[ - TraceRequestStartParams - ] = Signal(self) - self._on_request_chunk_sent: _TracingSignal[ - TraceRequestChunkSentParams - ] = Signal(self) - self._on_response_chunk_received: _TracingSignal[ - TraceResponseChunkReceivedParams - ] = Signal(self) - self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal( - self + self._on_request_start: _TracingSignal[TraceRequestStartParams] = ( + Signal(self) ) - self._on_request_exception: _TracingSignal[ - TraceRequestExceptionParams - ] = Signal(self) - self._on_request_redirect: _TracingSignal[ - TraceRequestRedirectParams - ] = Signal(self) - self._on_connection_queued_start: _TracingSignal[ - TraceConnectionQueuedStartParams - ] = Signal(self) - self._on_connection_queued_end: _TracingSignal[ - TraceConnectionQueuedEndParams - ] = Signal(self) - self._on_connection_create_start: _TracingSignal[ - TraceConnectionCreateStartParams - ] = Signal(self) - self._on_connection_create_end: _TracingSignal[ - TraceConnectionCreateEndParams - ] = Signal(self) - self._on_connection_reuseconn: _TracingSignal[ - TraceConnectionReuseconnParams - ] = Signal(self) - self._on_dns_resolvehost_start: _TracingSignal[ - TraceDnsResolveHostStartParams - ] = Signal(self) - self._on_dns_resolvehost_end: _TracingSignal[ - TraceDnsResolveHostEndParams - ] = Signal(self) - self._on_dns_cache_hit: _TracingSignal[ - TraceDnsCacheHitParams - ] = Signal(self) - self._on_dns_cache_miss: _TracingSignal[ - TraceDnsCacheMissParams - ] = Signal(self) - self._on_request_headers_sent: _TracingSignal[ - TraceRequestHeadersSentParams - ] = Signal(self) + self._on_request_chunk_sent: _TracingSignal[TraceRequestChunkSentParams] = Signal(self) + self._on_response_chunk_received: _TracingSignal[TraceResponseChunkReceivedParams] = Signal(self) + self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal(self) + self._on_request_exception: _TracingSignal[TraceRequestExceptionParams] = Signal(self) + self._on_request_redirect: _TracingSignal[TraceRequestRedirectParams] = Signal(self) + self._on_connection_queued_start: _TracingSignal[TraceConnectionQueuedStartParams] = Signal(self) + self._on_connection_queued_end: _TracingSignal[TraceConnectionQueuedEndParams] = Signal(self) + self._on_connection_create_start: _TracingSignal[TraceConnectionCreateStartParams] = Signal(self) + self._on_connection_create_end: _TracingSignal[TraceConnectionCreateEndParams] = Signal(self) + self._on_connection_reuseconn: _TracingSignal[TraceConnectionReuseconnParams] = Signal(self) + self._on_dns_resolvehost_start: _TracingSignal[TraceDnsResolveHostStartParams] = Signal(self) + self._on_dns_resolvehost_end: _TracingSignal[TraceDnsResolveHostEndParams] = Signal(self) + self._on_dns_cache_hit: _TracingSignal[TraceDnsCacheHitParams] = (Signal(self)) + self._on_dns_cache_miss: _TracingSignal[TraceDnsCacheMissParams] = (Signal(self)) + self._on_request_headers_sent: _TracingSignal[TraceRequestHeadersSentParams] = Signal(self) self._trace_config_ctx_factory = trace_config_ctx_factory def trace_config_ctx( - self, trace_request_ctx: Optional[SimpleNamespace] = None + self, trace_request_ctx: Optional[Mapping[str, str]] = None ) -> SimpleNamespace: """Return a new trace_config_ctx instance""" return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx) @@ -122,7 +92,9 @@ class TraceConfig: return self._on_request_start @property - def on_request_chunk_sent(self) -> "_TracingSignal[TraceRequestChunkSentParams]": + def on_request_chunk_sent( + self, + ) -> "_TracingSignal[TraceRequestChunkSentParams]": return self._on_request_chunk_sent @property diff --git a/contrib/python/aiohttp/aiohttp/typedefs.py b/contrib/python/aiohttp/aiohttp/typedefs.py index 5e963e1a10e..668d4fc344f 100644 --- a/contrib/python/aiohttp/aiohttp/typedefs.py +++ b/contrib/python/aiohttp/aiohttp/typedefs.py @@ -7,6 +7,8 @@ from typing import ( Callable, Iterable, Mapping, + Protocol, + Sequence, Tuple, Union, ) @@ -14,6 +16,18 @@ from typing import ( from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy, istr from yarl import URL +try: + # Available in yarl>=1.10.0 + from yarl import Query as _Query +except ImportError: # pragma: no cover + SimpleQuery = Union[str, int, float] # pragma: no cover + QueryVariable = Union[SimpleQuery, "Sequence[SimpleQuery]"] # pragma: no cover + _Query = Union[ # type: ignore[misc] # pragma: no cover + None, str, "Mapping[str, QueryVariable]", "Sequence[Tuple[str, QueryVariable]]" + ] + +Query = _Query + DEFAULT_JSON_ENCODER = json.dumps DEFAULT_JSON_DECODER = json.loads @@ -34,7 +48,13 @@ else: Byteish = Union[bytes, bytearray, memoryview] JSONEncoder = Callable[[Any], str] JSONDecoder = Callable[[str], Any] -LooseHeaders = Union[Mapping[Union[str, istr], str], _CIMultiDict, _CIMultiDictProxy] +LooseHeaders = Union[ + Mapping[str, str], + Mapping[istr, str], + _CIMultiDict, + _CIMultiDictProxy, + Iterable[Tuple[Union[str, istr], str]], +] RawHeaders = Tuple[Tuple[bytes, bytes], ...] StrOrURL = Union[str, URL] @@ -51,4 +71,5 @@ LooseCookies = Union[ Handler = Callable[["Request"], Awaitable["StreamResponse"]] Middleware = Callable[["Request", Handler], Awaitable["StreamResponse"]] + PathLike = Union[str, "os.PathLike[str]"] diff --git a/contrib/python/aiohttp/aiohttp/web.py b/contrib/python/aiohttp/aiohttp/web.py index e9116507f4e..88bf14bf828 100644 --- a/contrib/python/aiohttp/aiohttp/web.py +++ b/contrib/python/aiohttp/aiohttp/web.py @@ -7,7 +7,6 @@ import warnings from argparse import ArgumentParser from collections.abc import Iterable from contextlib import suppress -from functools import partial from importlib import import_module from typing import ( Any, @@ -21,7 +20,6 @@ from typing import ( Union, cast, ) -from weakref import WeakSet from .abc import AbstractAccessLogger from .helpers import AppKey as AppKey @@ -320,23 +318,6 @@ async def _run_app( reuse_port: Optional[bool] = None, handler_cancellation: bool = False, ) -> None: - async def wait( - starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float - ) -> None: - # Wait for pending tasks for a given time limit. - t = asyncio.current_task() - assert t is not None - starting_tasks.add(t) - with suppress(asyncio.TimeoutError): - await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout) - - async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: - t = asyncio.current_task() - assert t is not None - exclude.add(t) - while tasks := asyncio.all_tasks().difference(exclude): - await asyncio.wait(tasks) - # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): app = await app @@ -355,12 +336,6 @@ async def _run_app( ) await runner.setup() - # On shutdown we want to avoid waiting on tasks which run forever. - # It's very likely that all tasks which run forever will have been created by - # the time we have completed the application startup (in runner.setup()), - # so we just record all running tasks here and exclude them later. - starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks()) - runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout) sites: List[BaseSite] = [] @@ -545,10 +520,14 @@ def run_app( except (GracefulExit, KeyboardInterrupt): # pragma: no cover pass finally: - _cancel_tasks({main_task}, loop) - _cancel_tasks(asyncio.all_tasks(loop), loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - loop.close() + try: + main_task.cancel() + with suppress(asyncio.CancelledError): + loop.run_until_complete(main_task) + finally: + _cancel_tasks(asyncio.all_tasks(loop), loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() def main(argv: List[str]) -> None: diff --git a/contrib/python/aiohttp/aiohttp/web_app.py b/contrib/python/aiohttp/aiohttp/web_app.py index 8e0e91bcfe3..ab11981a8a5 100644 --- a/contrib/python/aiohttp/aiohttp/web_app.py +++ b/contrib/python/aiohttp/aiohttp/web_app.py @@ -1,7 +1,7 @@ import asyncio import logging import warnings -from functools import partial, update_wrapper +from functools import lru_cache, partial, update_wrapper from typing import ( TYPE_CHECKING, Any, @@ -38,7 +38,7 @@ from .helpers import DEBUG, AppKey from .http_parser import RawRequestMessage from .log import web_logger from .streams import StreamReader -from .typedefs import Middleware +from .typedefs import Handler, Middleware from .web_exceptions import NotAppKeyWarning from .web_log import AccessLogger from .web_middlewares import _fix_request_current_app @@ -76,6 +76,18 @@ else: _T = TypeVar("_T") _U = TypeVar("_U") +_Resource = TypeVar("_Resource", bound=AbstractResource) + + +@lru_cache(None) +def _build_middlewares( + handler: Handler, apps: Tuple["Application", ...] +) -> Callable[[Request], Awaitable[StreamResponse]]: + """Apply middlewares to handler.""" + for app in apps[::-1]: + for m, _ in app._middlewares_handlers: # type: ignore[union-attr] + handler = update_wrapper(partial(m, handler=handler), handler) # type: ignore[misc] + return handler class Application(MutableMapping[Union[str, AppKey[Any]], Any]): @@ -88,6 +100,7 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): "_handler_args", "_middlewares", "_middlewares_handlers", + "_has_legacy_middlewares", "_run_middlewares", "_state", "_frozen", @@ -142,6 +155,7 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): self._middlewares_handlers: _MiddlewaresHandlers = None # initialized on freezing self._run_middlewares: Optional[bool] = None + self._has_legacy_middlewares: bool = True self._state: Dict[Union[AppKey[Any], str], object] = {} self._frozen = False @@ -183,12 +197,10 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): return self is other @overload # type: ignore[override] - def __getitem__(self, key: AppKey[_T]) -> _T: - ... + def __getitem__(self, key: AppKey[_T]) -> _T: ... @overload - def __getitem__(self, key: str) -> Any: - ... + def __getitem__(self, key: str) -> Any: ... def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: return self._state[key] @@ -202,12 +214,10 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): ) @overload # type: ignore[override] - def __setitem__(self, key: AppKey[_T], value: _T) -> None: - ... + def __setitem__(self, key: AppKey[_T], value: _T) -> None: ... @overload - def __setitem__(self, key: str, value: Any) -> None: - ... + def __setitem__(self, key: str, value: Any) -> None: ... def __setitem__(self, key: Union[str, AppKey[_T]], value: Any) -> None: self._check_frozen() @@ -231,17 +241,17 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]: return iter(self._state) + def __hash__(self) -> int: + return id(self) + @overload # type: ignore[override] - def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: - ... + def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ... @overload - def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: - ... + def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: ... @overload - def get(self, key: str, default: Any = ...) -> Any: - ... + def get(self, key: str, default: Any = ...) -> Any: ... def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any: return self._state.get(key, default) @@ -290,6 +300,9 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): self._on_shutdown.freeze() self._on_cleanup.freeze() self._middlewares_handlers = tuple(self._prepare_middleware()) + self._has_legacy_middlewares = any( + not new_style for _, new_style in self._middlewares_handlers + ) # If current app and any subapp do not have middlewares avoid run all # of the code footprint that it implies, which have a middleware @@ -334,7 +347,7 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): reg_handler("on_shutdown") reg_handler("on_cleanup") - def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource: + def add_subapp(self, prefix: str, subapp: "Application") -> PrefixedSubAppResource: if not isinstance(prefix, str): raise TypeError("Prefix must be str") prefix = prefix.rstrip("/") @@ -344,8 +357,8 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): return self._add_subapp(factory, subapp) def _add_subapp( - self, resource_factory: Callable[[], AbstractResource], subapp: "Application" - ) -> AbstractResource: + self, resource_factory: Callable[[], _Resource], subapp: "Application" + ) -> _Resource: if self.frozen: raise RuntimeError("Cannot add sub application to frozen application") if subapp.frozen: @@ -359,7 +372,7 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): subapp._set_loop(self._loop) return resource - def add_domain(self, domain: str, subapp: "Application") -> AbstractResource: + def add_domain(self, domain: str, subapp: "Application") -> MatchedSubAppResource: if not isinstance(domain, str): raise TypeError("Domain must be str") elif "*" in domain: @@ -520,29 +533,30 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): match_info.freeze() - resp = None request._match_info = match_info - expect = request.headers.get(hdrs.EXPECT) - if expect: + + if request.headers.get(hdrs.EXPECT): resp = await match_info.expect_handler(request) await request.writer.drain() + if resp is not None: + return resp - if resp is None: - handler = match_info.handler + handler = match_info.handler - if self._run_middlewares: + if self._run_middlewares: + if not self._has_legacy_middlewares: + handler = _build_middlewares(handler, match_info.apps) + else: for app in match_info.apps[::-1]: for m, new_style in app._middlewares_handlers: # type: ignore[union-attr] if new_style: handler = update_wrapper( - partial(m, handler=handler), handler + partial(m, handler=handler), handler # type: ignore[misc] ) else: handler = await m(app, handler) # type: ignore[arg-type,assignment] - resp = await handler(request) - - return resp + return await handler(request) def __call__(self) -> "Application": """gunicorn compatibility""" @@ -585,7 +599,7 @@ class CleanupContext(_CleanupContextBase): await it.__anext__() except StopAsyncIteration: pass - except Exception as exc: + except (Exception, asyncio.CancelledError) as exc: errors.append(exc) else: errors.append(RuntimeError(f"{it!r} has more than one 'yield'")) diff --git a/contrib/python/aiohttp/aiohttp/web_fileresponse.py b/contrib/python/aiohttp/aiohttp/web_fileresponse.py index 7dbe50f0a5a..f0de75e9f1b 100644 --- a/contrib/python/aiohttp/aiohttp/web_fileresponse.py +++ b/contrib/python/aiohttp/aiohttp/web_fileresponse.py @@ -1,7 +1,11 @@ import asyncio -import mimetypes import os import pathlib +import sys +from contextlib import suppress +from mimetypes import MimeTypes +from stat import S_ISREG +from types import MappingProxyType from typing import ( # noqa IO, TYPE_CHECKING, @@ -22,6 +26,8 @@ from .abc import AbstractStreamWriter from .helpers import ETAG_ANY, ETag, must_be_empty_body from .typedefs import LooseHeaders, PathLike from .web_exceptions import ( + HTTPForbidden, + HTTPNotFound, HTTPNotModified, HTTPPartialContent, HTTPPreconditionFailed, @@ -40,6 +46,35 @@ _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]] NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE")) +CONTENT_TYPES: Final[MimeTypes] = MimeTypes() + +if sys.version_info < (3, 9): + CONTENT_TYPES.encodings_map[".br"] = "br" + +# File extension to IANA encodings map that will be checked in the order defined. +ENCODING_EXTENSIONS = MappingProxyType( + {ext: CONTENT_TYPES.encodings_map[ext] for ext in (".br", ".gz")} +) + +FALLBACK_CONTENT_TYPE = "application/octet-stream" + +# Provide additional MIME type/extension pairs to be recognized. +# https://en.wikipedia.org/wiki/List_of_archive_formats#Compression_only +ADDITIONAL_CONTENT_TYPES = MappingProxyType( + { + "application/gzip": ".gz", + "application/x-brotli": ".br", + "application/x-bzip2": ".bz2", + "application/x-compress": ".Z", + "application/x-xz": ".xz", + } +) + +# Add custom pairs and clear the encodings map so guess_type ignores them. +CONTENT_TYPES.encodings_map.clear() +for content_type, extension in ADDITIONAL_CONTENT_TYPES.items(): + CONTENT_TYPES.add_type(content_type, extension) # type: ignore[attr-defined] + class FileResponse(StreamResponse): """A response object can be used to send files.""" @@ -101,10 +136,12 @@ class FileResponse(StreamResponse): return writer @staticmethod - def _strong_etag_match(etag_value: str, etags: Tuple[ETag, ...]) -> bool: + def _etag_match(etag_value: str, etags: Tuple[ETag, ...], *, weak: bool) -> bool: if len(etags) == 1 and etags[0].value == ETAG_ANY: return True - return any(etag.value == etag_value for etag in etags if not etag.is_weak) + return any( + etag.value == etag_value for etag in etags if weak or not etag.is_weak + ) async def _not_modified( self, request: "BaseRequest", etag_value: str, last_modified: float @@ -124,42 +161,60 @@ class FileResponse(StreamResponse): self.content_length = 0 return await super().prepare(request) - def _get_file_path_stat_and_gzip( - self, check_for_gzipped_file: bool - ) -> Tuple[pathlib.Path, os.stat_result, bool]: - """Return the file path, stat result, and gzip status. + def _get_file_path_stat_encoding( + self, accept_encoding: str + ) -> Tuple[pathlib.Path, os.stat_result, Optional[str]]: + """Return the file path, stat result, and encoding. + + If an uncompressed file is returned, the encoding is set to + :py:data:`None`. This method should be called from a thread executor since it calls os.stat which may block. """ - filepath = self._path - if check_for_gzipped_file: - gzip_path = filepath.with_name(filepath.name + ".gz") - try: - return gzip_path, gzip_path.stat(), True - except OSError: - # Fall through and try the non-gzipped file - pass + file_path = self._path + for file_extension, file_encoding in ENCODING_EXTENSIONS.items(): + if file_encoding not in accept_encoding: + continue + + compressed_path = file_path.with_suffix(file_path.suffix + file_extension) + with suppress(OSError): + # Do not follow symlinks and ignore any non-regular files. + st = compressed_path.lstat() + if S_ISREG(st.st_mode): + return compressed_path, st, file_encoding - return filepath, filepath.stat(), False + # Fallback to the uncompressed file + return file_path, file_path.stat(), None async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() # Encoding comparisons should be case-insensitive # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1 - check_for_gzipped_file = ( - "gzip" in request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() - ) - filepath, st, gzip = await loop.run_in_executor( - None, self._get_file_path_stat_and_gzip, check_for_gzipped_file - ) + accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() + try: + file_path, st, file_encoding = await loop.run_in_executor( + None, self._get_file_path_stat_encoding, accept_encoding + ) + except OSError: + # Most likely to be FileNotFoundError or OSError for circular + # symlinks in python >= 3.13, so respond with 404. + self.set_status(HTTPNotFound.status_code) + return await super().prepare(request) + + # Forbid special files like sockets, pipes, devices, etc. + if not S_ISREG(st.st_mode): + self.set_status(HTTPForbidden.status_code) + return await super().prepare(request) etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}" last_modified = st.st_mtime - # https://tools.ietf.org/html/rfc7232#section-6 + # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2 ifmatch = request.if_match - if ifmatch is not None and not self._strong_etag_match(etag_value, ifmatch): + if ifmatch is not None and not self._etag_match( + etag_value, ifmatch, weak=False + ): return await self._precondition_failed(request) unmodsince = request.if_unmodified_since @@ -170,8 +225,11 @@ class FileResponse(StreamResponse): ): return await self._precondition_failed(request) + # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2 ifnonematch = request.if_none_match - if ifnonematch is not None and self._strong_etag_match(etag_value, ifnonematch): + if ifnonematch is not None and self._etag_match( + etag_value, ifnonematch, weak=True + ): return await self._not_modified(request, etag_value, last_modified) modsince = request.if_modified_since @@ -182,15 +240,6 @@ class FileResponse(StreamResponse): ): return await self._not_modified(request, etag_value, last_modified) - if hdrs.CONTENT_TYPE not in self.headers: - ct, encoding = mimetypes.guess_type(str(filepath)) - if not ct: - ct = "application/octet-stream" - should_set_ct = True - else: - encoding = "gzip" if gzip else None - should_set_ct = False - status = self._status file_size = st.st_size count = file_size @@ -265,11 +314,16 @@ class FileResponse(StreamResponse): # return a HTTP 206 for a Range request. self.set_status(status) - if should_set_ct: - self.content_type = ct # type: ignore[assignment] - if encoding: - self.headers[hdrs.CONTENT_ENCODING] = encoding - if gzip: + # If the Content-Type header is not already set, guess it based on the + # extension of the request path. The encoding returned by guess_type + # can be ignored since the map was cleared above. + if hdrs.CONTENT_TYPE not in self.headers: + self.content_type = ( + CONTENT_TYPES.guess_type(self._path)[0] or FALLBACK_CONTENT_TYPE + ) + + if file_encoding: + self.headers[hdrs.CONTENT_ENCODING] = file_encoding self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING # Disable compression if we are already sending # a compressed file since we don't want to double @@ -293,7 +347,12 @@ class FileResponse(StreamResponse): if count == 0 or must_be_empty_body(request.method, self.status): return await super().prepare(request) - fobj = await loop.run_in_executor(None, filepath.open, "rb") + try: + fobj = await loop.run_in_executor(None, file_path.open, "rb") + except PermissionError: + self.set_status(HTTPForbidden.status_code) + return await super().prepare(request) + if start: # be aware that start could be None or int=0 here. offset = start else: diff --git a/contrib/python/aiohttp/aiohttp/web_middlewares.py b/contrib/python/aiohttp/aiohttp/web_middlewares.py index 5da1533c0df..2f1f5f58e6e 100644 --- a/contrib/python/aiohttp/aiohttp/web_middlewares.py +++ b/contrib/python/aiohttp/aiohttp/web_middlewares.py @@ -110,7 +110,12 @@ def normalize_path_middleware( def _fix_request_current_app(app: "Application") -> Middleware: @middleware async def impl(request: Request, handler: Handler) -> StreamResponse: - with request.match_info.set_current_app(app): + match_info = request.match_info + prev = match_info.current_app + match_info.current_app = app + try: return await handler(request) + finally: + match_info.current_app = prev return impl diff --git a/contrib/python/aiohttp/aiohttp/web_protocol.py b/contrib/python/aiohttp/aiohttp/web_protocol.py index f083b13eb0f..85eb70d5a0b 100644 --- a/contrib/python/aiohttp/aiohttp/web_protocol.py +++ b/contrib/python/aiohttp/aiohttp/web_protocol.py @@ -1,5 +1,6 @@ import asyncio import asyncio.streams +import sys import traceback import warnings from collections import deque @@ -26,7 +27,7 @@ import yarl from .abc import AbstractAccessLogger, AbstractStreamWriter from .base_protocol import BaseProtocol -from .helpers import ceil_timeout, set_exception +from .helpers import ceil_timeout from .http import ( HttpProcessingError, HttpRequestParser, @@ -37,7 +38,7 @@ from .http import ( from .log import access_logger, server_logger from .streams import EMPTY_PAYLOAD, StreamReader from .tcp_helpers import tcp_keepalive -from .web_exceptions import HTTPException +from .web_exceptions import HTTPException, HTTPInternalServerError from .web_log import AccessLogger from .web_request import BaseRequest from .web_response import Response, StreamResponse @@ -83,6 +84,9 @@ class PayloadAccessError(Exception): """Payload was accessed after response was sent.""" +_PAYLOAD_ACCESS_ERROR = PayloadAccessError() + + @attr.s(auto_attribs=True, frozen=True, slots=True) class _ErrInfo: status: int @@ -133,8 +137,6 @@ class RequestHandler(BaseProtocol): """ - KEEPALIVE_RESCHEDULE_DELAY = 1 - __slots__ = ( "_request_count", "_keepalive", @@ -142,12 +144,13 @@ class RequestHandler(BaseProtocol): "_request_handler", "_request_factory", "_tcp_keepalive", - "_keepalive_time", + "_next_keepalive_close_time", "_keepalive_handle", "_keepalive_timeout", "_lingering_time", "_messages", "_message_tail", + "_handler_waiter", "_waiter", "_task_handler", "_upgrade", @@ -162,6 +165,7 @@ class RequestHandler(BaseProtocol): "_force_close", "_current_request", "_timeout_ceil_threshold", + "_request_in_progress", ) def __init__( @@ -195,7 +199,7 @@ class RequestHandler(BaseProtocol): self._tcp_keepalive = tcp_keepalive # placeholder to be replaced on keepalive timeout setup - self._keepalive_time = 0.0 + self._next_keepalive_close_time = 0.0 self._keepalive_handle: Optional[asyncio.Handle] = None self._keepalive_timeout = keepalive_timeout self._lingering_time = float(lingering_time) @@ -204,6 +208,7 @@ class RequestHandler(BaseProtocol): self._message_tail = b"" self._waiter: Optional[asyncio.Future[None]] = None + self._handler_waiter: Optional[asyncio.Future[None]] = None self._task_handler: Optional[asyncio.Task[None]] = None self._upgrade = False @@ -237,6 +242,7 @@ class RequestHandler(BaseProtocol): self._close = False self._force_close = False + self._request_in_progress = False def __repr__(self) -> str: return "<{} {}>".format( @@ -259,25 +265,44 @@ class RequestHandler(BaseProtocol): if self._keepalive_handle is not None: self._keepalive_handle.cancel() - if self._waiter: - self._waiter.cancel() - - # wait for handlers - with suppress(asyncio.CancelledError, asyncio.TimeoutError): + # Wait for graceful handler completion + if self._request_in_progress: + # The future is only created when we are shutting + # down while the handler is still processing a request + # to avoid creating a future for every request. + self._handler_waiter = self._loop.create_future() + try: + async with ceil_timeout(timeout): + await self._handler_waiter + except (asyncio.CancelledError, asyncio.TimeoutError): + self._handler_waiter = None + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise + # Then cancel handler and wait + try: async with ceil_timeout(timeout): if self._current_request is not None: self._current_request._cancel(asyncio.CancelledError()) if self._task_handler is not None and not self._task_handler.done(): - await self._task_handler + await asyncio.shield(self._task_handler) + except (asyncio.CancelledError, asyncio.TimeoutError): + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise # force-close non-idle handler if self._task_handler is not None: self._task_handler.cancel() - if self.transport is not None: - self.transport.close() - self.transport = None + self.force_close() def connection_made(self, transport: asyncio.BaseTransport) -> None: super().connection_made(transport) @@ -286,22 +311,27 @@ class RequestHandler(BaseProtocol): if self._tcp_keepalive: tcp_keepalive(real_transport) - self._task_handler = self._loop.create_task(self.start()) assert self._manager is not None self._manager.connection_made(self, real_transport) + loop = self._loop + if sys.version_info >= (3, 12): + task = asyncio.Task(self.start(), loop=loop, eager_start=True) + else: + task = loop.create_task(self.start()) + self._task_handler = task + def connection_lost(self, exc: Optional[BaseException]) -> None: if self._manager is None: return self._manager.connection_lost(self, exc) - super().connection_lost(exc) - # Grab value before setting _manager to None. handler_cancellation = self._manager.handler_cancellation + self.force_close() + super().connection_lost(exc) self._manager = None - self._force_close = True self._request_factory = None self._request_handler = None self._request_parser = None @@ -314,9 +344,6 @@ class RequestHandler(BaseProtocol): exc = ConnectionResetError("Connection lost") self._current_request._cancel(exc) - if self._waiter is not None: - self._waiter.cancel() - if handler_cancellation and self._task_handler is not None: self._task_handler.cancel() @@ -421,23 +448,21 @@ class RequestHandler(BaseProtocol): self.logger.exception(*args, **kw) def _process_keepalive(self) -> None: + self._keepalive_handle = None if self._force_close or not self._keepalive: return - next = self._keepalive_time + self._keepalive_timeout + loop = self._loop + now = loop.time() + close_time = self._next_keepalive_close_time + if now <= close_time: + # Keep alive close check fired too early, reschedule + self._keepalive_handle = loop.call_at(close_time, self._process_keepalive) + return # handler in idle state - if self._waiter: - if self._loop.time() > next: - self.force_close() - return - - # not all request handlers are done, - # reschedule itself to next second - self._keepalive_handle = self._loop.call_later( - self.KEEPALIVE_RESCHEDULE_DELAY, - self._process_keepalive, - ) + if self._waiter and not self._waiter.done(): + self.force_close() async def _handle_request( self, @@ -445,7 +470,7 @@ class RequestHandler(BaseProtocol): start_time: float, request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]], ) -> Tuple[StreamResponse, bool]: - assert self._request_handler is not None + self._request_in_progress = True try: try: self._current_request = request @@ -454,16 +479,16 @@ class RequestHandler(BaseProtocol): self._current_request = None except HTTPException as exc: resp = exc - reset = await self.finish_response(request, resp, start_time) + resp, reset = await self.finish_response(request, resp, start_time) except asyncio.CancelledError: raise except asyncio.TimeoutError as exc: self.log_debug("Request handler timed out.", exc_info=exc) resp = self.handle_error(request, 504) - reset = await self.finish_response(request, resp, start_time) + resp, reset = await self.finish_response(request, resp, start_time) except Exception as exc: resp = self.handle_error(request, 500, exc) - reset = await self.finish_response(request, resp, start_time) + resp, reset = await self.finish_response(request, resp, start_time) else: # Deprecation warning (See #2415) if getattr(resp, "__http_exception__", False): @@ -474,7 +499,11 @@ class RequestHandler(BaseProtocol): DeprecationWarning, ) - reset = await self.finish_response(request, resp, start_time) + resp, reset = await self.finish_response(request, resp, start_time) + finally: + self._request_in_progress = False + if self._handler_waiter is not None: + self._handler_waiter.set_result(None) return resp, reset @@ -488,7 +517,7 @@ class RequestHandler(BaseProtocol): keep_alive(True) specified. """ loop = self._loop - handler = self._task_handler + handler = asyncio.current_task(loop) assert handler is not None manager = self._manager assert manager is not None @@ -503,8 +532,6 @@ class RequestHandler(BaseProtocol): # wait for next request self._waiter = loop.create_future() await self._waiter - except asyncio.CancelledError: - break finally: self._waiter = None @@ -524,12 +551,14 @@ class RequestHandler(BaseProtocol): request = self._request_factory(message, payload, self, writer, handler) try: # a new task is used for copy context vars (#3406) - task = self._loop.create_task( - self._handle_request(request, start, request_handler) - ) + coro = self._handle_request(request, start, request_handler) + if sys.version_info >= (3, 12): + task = asyncio.Task(coro, loop=loop, eager_start=True) + else: + task = loop.create_task(coro) try: resp, reset = await task - except (asyncio.CancelledError, ConnectionError): + except ConnectionError: self.log_debug("Ignored premature client disconnection") break @@ -553,27 +582,30 @@ class RequestHandler(BaseProtocol): now = loop.time() end_t = now + lingering_time - with suppress(asyncio.TimeoutError, asyncio.CancelledError): + try: while not payload.is_eof() and now < end_t: async with ceil_timeout(end_t - now): # read and ignore await payload.readany() now = loop.time() + except (asyncio.CancelledError, asyncio.TimeoutError): + if ( + sys.version_info >= (3, 11) + and (t := asyncio.current_task()) + and t.cancelling() + ): + raise # if payload still uncompleted if not payload.is_eof() and not self._force_close: self.log_debug("Uncompleted request.") self.close() - set_exception(payload, PayloadAccessError()) + payload.set_exception(_PAYLOAD_ACCESS_ERROR) except asyncio.CancelledError: - self.log_debug("Ignored premature client disconnection ") - break - except RuntimeError as exc: - if self.debug: - self.log_exception("Unhandled runtime exception", exc_info=exc) - self.force_close() + self.log_debug("Ignored premature client disconnection") + raise except Exception as exc: self.log_exception("Unhandled exception", exc_info=exc) self.force_close() @@ -584,11 +616,12 @@ class RequestHandler(BaseProtocol): if self._keepalive and not self._close: # start keep-alive timer if keepalive_timeout is not None: - now = self._loop.time() - self._keepalive_time = now + now = loop.time() + close_time = now + keepalive_timeout + self._next_keepalive_close_time = close_time if self._keepalive_handle is None: self._keepalive_handle = loop.call_at( - now + keepalive_timeout, self._process_keepalive + close_time, self._process_keepalive ) else: break @@ -601,7 +634,7 @@ class RequestHandler(BaseProtocol): async def finish_response( self, request: BaseRequest, resp: StreamResponse, start_time: float - ) -> bool: + ) -> Tuple[StreamResponse, bool]: """Prepare the response and write_eof, then log access. This has to @@ -609,6 +642,7 @@ class RequestHandler(BaseProtocol): can get exception information. Returns True if the client disconnects prematurely. """ + request._finish() if self._request_parser is not None: self._request_parser.set_upgraded(False) self._upgrade = False @@ -619,22 +653,26 @@ class RequestHandler(BaseProtocol): prepare_meth = resp.prepare except AttributeError: if resp is None: - raise RuntimeError("Missing return " "statement on request handler") + self.log_exception("Missing return statement on request handler") else: - raise RuntimeError( - "Web-handler should return " - "a response instance, " + self.log_exception( + "Web-handler should return a response instance, " "got {!r}".format(resp) ) + exc = HTTPInternalServerError() + resp = Response( + status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers + ) + prepare_meth = resp.prepare try: await prepare_meth(request) await resp.write_eof() except ConnectionError: self.log_access(request, resp, start_time) - return True - else: - self.log_access(request, resp, start_time) - return False + return resp, True + + self.log_access(request, resp, start_time) + return resp, False def handle_error( self, diff --git a/contrib/python/aiohttp/aiohttp/web_request.py b/contrib/python/aiohttp/aiohttp/web_request.py index 4bc670a798c..eca71e4413a 100644 --- a/contrib/python/aiohttp/aiohttp/web_request.py +++ b/contrib/python/aiohttp/aiohttp/web_request.py @@ -79,7 +79,7 @@ class FileField: filename: str file: io.BufferedReader content_type: str - headers: "CIMultiDictProxy[str]" + headers: CIMultiDictProxy[str] _TCHAR: Final[str] = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-" @@ -99,10 +99,10 @@ _QUOTED_STRING: Final[str] = r'"(?:{quoted_pair}|{qdtext})*"'.format( qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR ) -_FORWARDED_PAIR: Final[ - str -] = r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format( - token=_TOKEN, quoted_string=_QUOTED_STRING +_FORWARDED_PAIR: Final[str] = ( + r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format( + token=_TOKEN, quoted_string=_QUOTED_STRING + ) ) _QUOTED_PAIR_REPLACE_RE: Final[Pattern[str]] = re.compile(r"\\([\t !-~])") @@ -169,12 +169,16 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): self._payload_writer = payload_writer self._payload = payload - self._headers = message.headers + self._headers: CIMultiDictProxy[str] = message.headers self._method = message.method self._version = message.version self._cache: Dict[str, Any] = {} url = message.url if url.is_absolute(): + if scheme is not None: + url = url.with_scheme(scheme) + if host is not None: + url = url.with_host(host) # absolute URL is given, # override auto-calculating url, host, and scheme # all other properties should be good @@ -184,6 +188,10 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): self._rel_url = url.relative() else: self._rel_url = message.url + if scheme is not None: + self._cache["scheme"] = scheme + if host is not None: + self._cache["host"] = host self._post: Optional[MultiDictProxy[Union[str, bytes, FileField]]] = None self._read_bytes: Optional[bytes] = None @@ -197,10 +205,6 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): self._transport_sslcontext = transport.get_extra_info("sslcontext") self._transport_peername = transport.get_extra_info("peername") - if scheme is not None: - self._cache["scheme"] = scheme - if host is not None: - self._cache["host"] = host if remote is not None: self._cache["remote"] = remote @@ -235,7 +239,8 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): # a copy semantic dct["headers"] = CIMultiDictProxy(CIMultiDict(headers)) dct["raw_headers"] = tuple( - (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items() + (k.encode("utf-8"), v.encode("utf-8")) + for k, v in dct["headers"].items() ) message = self._message._replace(**dct) @@ -481,7 +486,7 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): @reify def query(self) -> "MultiMapping[str]": """A multidict with all the variables in the query string.""" - return MultiDictProxy(self._rel_url.query) + return self._rel_url.query @reify def query_string(self) -> str: @@ -492,7 +497,7 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): return self._rel_url.query_string @reify - def headers(self) -> "MultiMapping[str]": + def headers(self) -> CIMultiDictProxy[str]: """A case-insensitive multidict proxy with all headers.""" return self._headers @@ -819,6 +824,18 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): def _cancel(self, exc: BaseException) -> None: set_exception(self._payload, exc) + def _finish(self) -> None: + if self._post is None or self.content_type != "multipart/form-data": + return + + # NOTE: Release file descriptors for the + # NOTE: `tempfile.Temporaryfile`-created `_io.BufferedRandom` + # NOTE: instances of files sent within multipart request body + # NOTE: via HTTP POST request. + for file_name, file_field_object in self._post.items(): + if isinstance(file_field_object, FileField): + file_field_object.file.close() + class Request(BaseRequest): @@ -898,4 +915,5 @@ class Request(BaseRequest): if match_info is None: return for app in match_info._apps: - await app.on_response_prepare.send(self, response) + if on_response_prepare := app.on_response_prepare: + await on_response_prepare.send(self, response) diff --git a/contrib/python/aiohttp/aiohttp/web_response.py b/contrib/python/aiohttp/aiohttp/web_response.py index 40d6f01ecaa..4307b2a98c8 100644 --- a/contrib/python/aiohttp/aiohttp/web_response.py +++ b/contrib/python/aiohttp/aiohttp/web_response.py @@ -41,6 +41,8 @@ from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11 from .payload import Payload from .typedefs import JSONEncoder, LooseHeaders +REASON_PHRASES = {http_status.value: http_status.phrase for http_status in HTTPStatus} + __all__ = ("ContentCoding", "StreamResponse", "Response", "json_response") @@ -52,6 +54,7 @@ else: BaseClass = collections.abc.MutableMapping +# TODO(py311): Convert to StrEnum for wider use class ContentCoding(enum.Enum): # The content codings that we have support for. # @@ -62,6 +65,8 @@ class ContentCoding(enum.Enum): identity = "identity" +CONTENT_CODINGS = {coding.value: coding for coding in ContentCoding} + ############################################################ # HTTP Response classes ############################################################ @@ -71,6 +76,8 @@ class StreamResponse(BaseClass, HeadersMixin): _length_check = True + _body: Union[None, bytes, bytearray, Payload] + def __init__( self, *, @@ -97,11 +104,11 @@ class StreamResponse(BaseClass, HeadersMixin): else: self._headers = CIMultiDict() - self.set_status(status, reason) + self._set_status(status, reason) @property def prepared(self) -> bool: - return self._payload_writer is not None + return self._eof_sent or self._payload_writer is not None @property def task(self) -> "Optional[asyncio.Task[None]]": @@ -131,15 +138,15 @@ class StreamResponse(BaseClass, HeadersMixin): status: int, reason: Optional[str] = None, ) -> None: - assert not self.prepared, ( - "Cannot change the response status code after " "the headers have been sent" - ) + assert ( + not self.prepared + ), "Cannot change the response status code after the headers have been sent" + self._set_status(status, reason) + + def _set_status(self, status: int, reason: Optional[str]) -> None: self._status = int(status) if reason is None: - try: - reason = HTTPStatus(self._status).phrase - except ValueError: - reason = "" + reason = REASON_PHRASES.get(self._status, "") self._reason = reason @property @@ -175,7 +182,7 @@ class StreamResponse(BaseClass, HeadersMixin): ) -> None: """Enables response compression encoding.""" # Backwards compatibility for when force was a bool <0.17. - if type(force) == bool: + if isinstance(force, bool): force = ContentCoding.deflate if force else ContentCoding.identity warnings.warn( "Using boolean for force is deprecated #3318", DeprecationWarning @@ -403,8 +410,8 @@ class StreamResponse(BaseClass, HeadersMixin): # Encoding comparisons should be case-insensitive # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1 accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() - for coding in ContentCoding: - if coding.value in accept_encoding: + for value, coding in CONTENT_CODINGS.items(): + if value in accept_encoding: await self._do_start_compression(coding) return @@ -499,9 +506,7 @@ class StreamResponse(BaseClass, HeadersMixin): assert writer is not None # status line version = request.version - status_line = "HTTP/{}.{} {} {}".format( - version[0], version[1], self._status, self._reason - ) + status_line = f"HTTP/{version[0]}.{version[1]} {self._status} {self._reason}" await writer.write_headers(status_line, self._headers) async def write(self, data: bytes) -> None: @@ -650,21 +655,17 @@ class Response(StreamResponse): return self._body @body.setter - def body(self, body: bytes) -> None: + def body(self, body: Any) -> None: if body is None: - self._body: Optional[bytes] = None - self._body_payload: bool = False + self._body = None elif isinstance(body, (bytes, bytearray)): self._body = body - self._body_payload = False else: try: self._body = body = payload.PAYLOAD_REGISTRY.get(body) except payload.LookupError: raise ValueError("Unsupported body type %r" % type(body)) - self._body_payload = True - headers = self._headers # set content-type @@ -673,7 +674,7 @@ class Response(StreamResponse): # copy payload headers if body.headers: - for (key, value) in body.headers.items(): + for key, value in body.headers.items(): if key not in headers: headers[key] = value @@ -697,7 +698,6 @@ class Response(StreamResponse): self.charset = "utf-8" self._body = text.encode(self.charset) - self._body_payload = False self._compressed_body = None @property @@ -711,7 +711,7 @@ class Response(StreamResponse): if self._compressed_body is not None: # Return length of the compressed body return len(self._compressed_body) - elif self._body_payload: + elif isinstance(self._body, Payload): # A payload without content length, or a compressed payload return None elif self._body is not None: @@ -736,9 +736,8 @@ class Response(StreamResponse): if body is not None: if self._must_be_empty_body: await super().write_eof() - elif self._body_payload: - payload = cast(Payload, body) - await payload.write(self._payload_writer) + elif isinstance(self._body, Payload): + await self._body.write(self._payload_writer) await super().write_eof() else: await super().write_eof(cast(bytes, body)) @@ -746,14 +745,13 @@ class Response(StreamResponse): await super().write_eof() async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: - if should_remove_content_length(request.method, self.status): - if hdrs.CONTENT_LENGTH in self._headers: + if hdrs.CONTENT_LENGTH in self._headers: + if should_remove_content_length(request.method, self.status): del self._headers[hdrs.CONTENT_LENGTH] - elif not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: - if self._body_payload: - size = cast(Payload, self._body).size - if size is not None: - self._headers[hdrs.CONTENT_LENGTH] = str(size) + elif not self._chunked: + if isinstance(self._body, Payload): + if self._body.size is not None: + self._headers[hdrs.CONTENT_LENGTH] = str(self._body.size) else: body_len = len(self._body) if self._body else "0" # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-7 @@ -765,7 +763,7 @@ class Response(StreamResponse): return await super()._start(request) async def _do_start_compression(self, coding: ContentCoding) -> None: - if self._body_payload or self._chunked: + if self._chunked or isinstance(self._body, Payload): return await super()._do_start_compression(coding) if coding != ContentCoding.identity: diff --git a/contrib/python/aiohttp/aiohttp/web_routedef.py b/contrib/python/aiohttp/aiohttp/web_routedef.py index d79cd32a14a..93802141c56 100644 --- a/contrib/python/aiohttp/aiohttp/web_routedef.py +++ b/contrib/python/aiohttp/aiohttp/web_routedef.py @@ -162,12 +162,10 @@ class RouteTableDef(Sequence[AbstractRouteDef]): return f"<RouteTableDef count={len(self._items)}>" @overload - def __getitem__(self, index: int) -> AbstractRouteDef: - ... + def __getitem__(self, index: int) -> AbstractRouteDef: ... @overload - def __getitem__(self, index: slice) -> List[AbstractRouteDef]: - ... + def __getitem__(self, index: slice) -> List[AbstractRouteDef]: ... def __getitem__(self, index): # type: ignore[no-untyped-def] return self._items[index] diff --git a/contrib/python/aiohttp/aiohttp/web_runner.py b/contrib/python/aiohttp/aiohttp/web_runner.py index 19a4441658f..0a237ede2c5 100644 --- a/contrib/python/aiohttp/aiohttp/web_runner.py +++ b/contrib/python/aiohttp/aiohttp/web_runner.py @@ -3,7 +3,7 @@ import signal import socket import warnings from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, List, Optional, Set +from typing import Any, List, Optional, Set from yarl import URL @@ -108,7 +108,7 @@ class TCPSite(BaseSite): @property def name(self) -> str: scheme = "https" if self._ssl_context else "http" - host = "0.0.0.0" if self._host is None else self._host + host = "0.0.0.0" if not self._host else self._host return str(URL.build(scheme=scheme, host=host, port=self._port)) async def start(self) -> None: @@ -238,14 +238,7 @@ class SockSite(BaseSite): class BaseRunner(ABC): - __slots__ = ( - "shutdown_callback", - "_handle_signals", - "_kwargs", - "_server", - "_sites", - "_shutdown_timeout", - ) + __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout") def __init__( self, @@ -254,7 +247,6 @@ class BaseRunner(ABC): shutdown_timeout: float = 60.0, **kwargs: Any, ) -> None: - self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None self._handle_signals = handle_signals self._kwargs = kwargs self._server: Optional[Server] = None @@ -312,10 +304,6 @@ class BaseRunner(ABC): await asyncio.sleep(0) self._server.pre_shutdown() await self.shutdown() - - if self.shutdown_callback: - await self.shutdown_callback() - await self._server.shutdown(self._shutdown_timeout) await self._cleanup_server() diff --git a/contrib/python/aiohttp/aiohttp/web_server.py b/contrib/python/aiohttp/aiohttp/web_server.py index 52faacb164a..973e7c15440 100644 --- a/contrib/python/aiohttp/aiohttp/web_server.py +++ b/contrib/python/aiohttp/aiohttp/web_server.py @@ -1,9 +1,9 @@ """Low level HTTP server.""" + import asyncio from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa from .abc import AbstractStreamWriter -from .helpers import get_running_loop from .http_parser import RawRequestMessage from .streams import StreamReader from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler @@ -22,7 +22,7 @@ class Server: loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any ) -> None: - self._loop = get_running_loop(loop) + self._loop = loop or asyncio.get_event_loop() self._connections: Dict[RequestHandler, asyncio.Transport] = {} self._kwargs = kwargs self.requests_count = 0 @@ -43,7 +43,12 @@ class Server: self, handler: RequestHandler, exc: Optional[BaseException] = None ) -> None: if handler in self._connections: - del self._connections[handler] + if handler._task_handler: + handler._task_handler.add_done_callback( + lambda f: self._connections.pop(handler, None) + ) + else: + del self._connections[handler] def _make_request( self, diff --git a/contrib/python/aiohttp/aiohttp/web_urldispatcher.py b/contrib/python/aiohttp/aiohttp/web_urldispatcher.py index 954291f6449..89abdc43fa6 100644 --- a/contrib/python/aiohttp/aiohttp/web_urldispatcher.py +++ b/contrib/python/aiohttp/aiohttp/web_urldispatcher.py @@ -8,8 +8,8 @@ import inspect import keyword import os import re +import sys import warnings -from contextlib import contextmanager from functools import wraps from pathlib import Path from types import MappingProxyType @@ -38,7 +38,7 @@ from typing import ( cast, ) -from yarl import URL, __version__ as yarl_version # type: ignore[attr-defined] +from yarl import URL, __version__ as yarl_version from . import hdrs from .abc import AbstractMatchInfo, AbstractRouter, AbstractView @@ -78,6 +78,12 @@ if TYPE_CHECKING: else: BaseDict = dict +CIRCULAR_SYMLINK_ERROR = ( + (OSError,) + if sys.version_info < (3, 10) and sys.platform.startswith("win32") + else (RuntimeError,) if sys.version_info < (3, 13) else () +) + YARL_VERSION: Final[Tuple[int, ...]] = tuple(map(int, yarl_version.split(".")[:2])) HTTP_METHOD_RE: Final[Pattern[str]] = re.compile( @@ -199,7 +205,7 @@ class AbstractRoute(abc.ABC): @wraps(handler) async def handler_wrapper(request: Request) -> StreamResponse: - result = old_handler(request) + result = old_handler(request) # type: ignore[call-arg] if asyncio.iscoroutine(result): result = await result assert isinstance(result, StreamResponse) @@ -286,8 +292,8 @@ class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo): assert app is not None return app - @contextmanager - def set_current_app(self, app: "Application") -> Generator[None, None, None]: + @current_app.setter + def current_app(self, app: "Application") -> None: if DEBUG: # pragma: no cover if app not in self._apps: raise RuntimeError( @@ -295,12 +301,7 @@ class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo): self._apps, app ) ) - prev = self._current_app self._current_app = app - try: - yield - finally: - self._current_app = prev def freeze(self) -> None: self._frozen = True @@ -334,6 +335,8 @@ async def _default_expect_handler(request: Request) -> None: if request.version == HttpVersion11: if expect.lower() == "100-continue": await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") + # Reset output_size as we haven't started the main body yet. + request.writer.output_size = 0 else: raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect) @@ -372,7 +375,7 @@ class Resource(AbstractResource): async def resolve(self, request: Request) -> _Resolve: allowed_methods: Set[str] = set() - match_dict = self._match(request.rel_url.raw_path) + match_dict = self._match(request.rel_url.path_safe) if match_dict is None: return None, allowed_methods @@ -422,8 +425,7 @@ class PlainResource(Resource): # string comparison is about 10 times faster than regexp matching if self._path == path: return {} - else: - return None + return None def raw_match(self, path: str) -> bool: return self._path == path @@ -447,6 +449,7 @@ class DynamicResource(Resource): def __init__(self, path: str, *, name: Optional[str] = None) -> None: super().__init__(name=name) + self._orig_path = path pattern = "" formatter = "" for part in ROUTE_RE.split(path): @@ -493,13 +496,12 @@ class DynamicResource(Resource): match = self._pattern.fullmatch(path) if match is None: return None - else: - return { - key: _unquote_path(value) for key, value in match.groupdict().items() - } + return { + key: _unquote_path_safe(value) for key, value in match.groupdict().items() + } def raw_match(self, path: str) -> bool: - return self._formatter == path + return self._orig_path == path def get_info(self) -> _InfoDict: return {"formatter": self._formatter, "pattern": self._pattern} @@ -557,14 +559,11 @@ class StaticResource(PrefixResource): ) -> None: super().__init__(prefix, name=name) try: - directory = Path(directory) - if str(directory).startswith("~"): - directory = Path(os.path.expanduser(str(directory))) - directory = directory.resolve() - if not directory.is_dir(): - raise ValueError("Not a directory") - except (FileNotFoundError, ValueError) as error: - raise ValueError(f"No directory exists at '{directory}'") from error + directory = Path(directory).expanduser().resolve(strict=True) + except FileNotFoundError as error: + raise ValueError(f"'{directory}' does not exist") from error + if not directory.is_dir(): + raise ValueError(f"'{directory}' is not a directory") self._directory = directory self._show_index = show_index self._chunk_size = chunk_size @@ -644,7 +643,7 @@ class StaticResource(PrefixResource): ) async def resolve(self, request: Request) -> _Resolve: - path = request.rel_url.raw_path + path = request.rel_url.path_safe method = request.method allowed_methods = set(self._routes) if not path.startswith(self._prefix2) and path != self._prefix: @@ -653,7 +652,7 @@ class StaticResource(PrefixResource): if method not in allowed_methods: return None, allowed_methods - match_dict = {"filename": _unquote_path(path[len(self._prefix) + 1 :])} + match_dict = {"filename": _unquote_path_safe(path[len(self._prefix) + 1 :])} return (UrlMappingMatchInfo(match_dict, self._routes[method]), allowed_methods) def __len__(self) -> int: @@ -664,59 +663,64 @@ class StaticResource(PrefixResource): async def _handle(self, request: Request) -> StreamResponse: rel_url = request.match_info["filename"] + filename = Path(rel_url) + if filename.anchor: + # rel_url is an absolute name like + # /static/\\machine_name\c$ or /static/D:\path + # where the static dir is totally different + raise HTTPForbidden() + + unresolved_path = self._directory.joinpath(filename) + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, self._resolve_path_to_response, unresolved_path + ) + + def _resolve_path_to_response(self, unresolved_path: Path) -> StreamResponse: + """Take the unresolved path and query the file system to form a response.""" + # Check for access outside the root directory. For follow symlinks, URI + # cannot traverse out, but symlinks can. Otherwise, no access outside + # root is permitted. try: - filename = Path(rel_url) - if filename.anchor: - # rel_url is an absolute name like - # /static/\\machine_name\c$ or /static/D:\path - # where the static dir is totally different - raise HTTPForbidden() - unresolved_path = self._directory.joinpath(filename) if self._follow_symlinks: normalized_path = Path(os.path.normpath(unresolved_path)) normalized_path.relative_to(self._directory) - filepath = normalized_path.resolve() + file_path = normalized_path.resolve() else: - filepath = unresolved_path.resolve() - filepath.relative_to(self._directory) - except (ValueError, FileNotFoundError) as error: - # relatively safe - raise HTTPNotFound() from error - except HTTPForbidden: - raise - except Exception as error: - # perm error or other kind! - request.app.logger.exception(error) + file_path = unresolved_path.resolve() + file_path.relative_to(self._directory) + except (ValueError, *CIRCULAR_SYMLINK_ERROR) as error: + # ValueError is raised for the relative check. Circular symlinks + # raise here on resolving for python < 3.13. raise HTTPNotFound() from error - # on opening a dir, load its contents if allowed - if filepath.is_dir(): - if self._show_index: - try: + # if path is a directory, return the contents if permitted. Note the + # directory check will raise if a segment is not readable. + try: + if file_path.is_dir(): + if self._show_index: return Response( - text=self._directory_as_html(filepath), content_type="text/html" + text=self._directory_as_html(file_path), + content_type="text/html", ) - except PermissionError: + else: raise HTTPForbidden() - else: - raise HTTPForbidden() - elif filepath.is_file(): - return FileResponse(filepath, chunk_size=self._chunk_size) - else: - raise HTTPNotFound + except PermissionError as error: + raise HTTPForbidden() from error - def _directory_as_html(self, filepath: Path) -> str: - # returns directory's index as html + # Return the file response, which handles all other checks. + return FileResponse(file_path, chunk_size=self._chunk_size) - # sanity check - assert filepath.is_dir() + def _directory_as_html(self, dir_path: Path) -> str: + """returns directory's index as html.""" + assert dir_path.is_dir() - relative_path_to_dir = filepath.relative_to(self._directory).as_posix() + relative_path_to_dir = dir_path.relative_to(self._directory).as_posix() index_of = f"Index of /{html_escape(relative_path_to_dir)}" h1 = f"<h1>{index_of}</h1>" index_list = [] - dir_index = filepath.iterdir() + dir_index = dir_path.iterdir() for _file in sorted(dir_index): # show file url as relative to static path rel_path = _file.relative_to(self._directory).as_posix() @@ -750,13 +754,20 @@ class PrefixedSubAppResource(PrefixResource): def __init__(self, prefix: str, app: "Application") -> None: super().__init__(prefix) self._app = app - for resource in app.router.resources(): - resource.add_prefix(prefix) + self._add_prefix_to_resources(prefix) def add_prefix(self, prefix: str) -> None: super().add_prefix(prefix) - for resource in self._app.router.resources(): + self._add_prefix_to_resources(prefix) + + def _add_prefix_to_resources(self, prefix: str) -> None: + router = self._app.router + for resource in router.resources(): + # Since the canonical path of a resource is about + # to change, we need to unindex it and then reindex + router.unindex_resource(resource) resource.add_prefix(prefix) + router.index_resource(resource) def url_for(self, *args: str, **kwargs: str) -> URL: raise RuntimeError(".url_for() is not supported " "by sub-application root") @@ -765,11 +776,6 @@ class PrefixedSubAppResource(PrefixResource): return {"app": self._app, "prefix": self._prefix} async def resolve(self, request: Request) -> _Resolve: - if ( - not request.url.raw_path.startswith(self._prefix2) - and request.url.raw_path != self._prefix - ): - return None, set() match_info = await self._app.router.resolve(request) match_info.add_app(self._app) if isinstance(match_info.http_exception, HTTPMethodNotAllowed): @@ -1015,12 +1021,39 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]): super().__init__() self._resources: List[AbstractResource] = [] self._named_resources: Dict[str, AbstractResource] = {} + self._resource_index: dict[str, list[AbstractResource]] = {} + self._matched_sub_app_resources: List[MatchedSubAppResource] = [] async def resolve(self, request: Request) -> UrlMappingMatchInfo: - method = request.method + resource_index = self._resource_index allowed_methods: Set[str] = set() - for resource in self._resources: + # Walk the url parts looking for candidates. We walk the url backwards + # to ensure the most explicit match is found first. If there are multiple + # candidates for a given url part because there are multiple resources + # registered for the same canonical path, we resolve them in a linear + # fashion to ensure registration order is respected. + url_part = request.rel_url.path_safe + while url_part: + for candidate in resource_index.get(url_part, ()): + match_dict, allowed = await candidate.resolve(request) + if match_dict is not None: + return match_dict + else: + allowed_methods |= allowed + if url_part == "/": + break + url_part = url_part.rpartition("/")[0] or "/" + + # + # We didn't find any candidates, so we'll try the matched sub-app + # resources which we have to walk in a linear fashion because they + # have regex/wildcard match rules and we cannot index them. + # + # For most cases we do not expect there to be many of these since + # currently they are only added by `add_domain` + # + for resource in self._matched_sub_app_resources: match_dict, allowed = await resource.resolve(request) if match_dict is not None: return match_dict @@ -1028,9 +1061,9 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]): allowed_methods |= allowed if allowed_methods: - return MatchInfoError(HTTPMethodNotAllowed(method, allowed_methods)) - else: - return MatchInfoError(HTTPNotFound()) + return MatchInfoError(HTTPMethodNotAllowed(request.method, allowed_methods)) + + return MatchInfoError(HTTPNotFound()) def __iter__(self) -> Iterator[str]: return iter(self._named_resources) @@ -1086,6 +1119,36 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]): self._named_resources[name] = resource self._resources.append(resource) + if isinstance(resource, MatchedSubAppResource): + # We cannot index match sub-app resources because they have match rules + self._matched_sub_app_resources.append(resource) + else: + self.index_resource(resource) + + def _get_resource_index_key(self, resource: AbstractResource) -> str: + """Return a key to index the resource in the resource index.""" + if "{" in (index_key := resource.canonical): + # strip at the first { to allow for variables, and than + # rpartition at / to allow for variable parts in the path + # For example if the canonical path is `/core/locations{tail:.*}` + # the index key will be `/core` since index is based on the + # url parts split by `/` + index_key = index_key.partition("{")[0].rpartition("/")[0] + return index_key.rstrip("/") or "/" + + def index_resource(self, resource: AbstractResource) -> None: + """Add a resource to the resource index.""" + resource_key = self._get_resource_index_key(resource) + # There may be multiple resources for a canonical path + # so we keep them in a list to ensure that registration + # order is respected. + self._resource_index.setdefault(resource_key, []).append(resource) + + def unindex_resource(self, resource: AbstractResource) -> None: + """Remove a resource from the resource index.""" + resource_key = self._get_resource_index_key(resource) + self._resource_index[resource_key].remove(resource) + def add_resource(self, path: str, *, name: Optional[str] = None) -> Resource: if path and not path.startswith("/"): raise ValueError("path should be started with / or be empty") @@ -1095,7 +1158,7 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]): if resource.name == name and resource.raw_match(path): return cast(Resource, resource) if not ("{" in path or "}" in path or ROUTE_RE.search(path)): - resource = PlainResource(_requote_path(path), name=name) + resource = PlainResource(path, name=name) self.register_resource(resource) return resource resource = DynamicResource(path, name=name) @@ -1221,8 +1284,10 @@ def _quote_path(value: str) -> str: return URL.build(path=value, encoded=False).raw_path -def _unquote_path(value: str) -> str: - return URL.build(path=value, encoded=True).path +def _unquote_path_safe(value: str) -> str: + if "%" not in value: + return value + return value.replace("%2F", "/").replace("%25", "%") def _requote_path(value: str) -> str: diff --git a/contrib/python/aiohttp/aiohttp/web_ws.py b/contrib/python/aiohttp/aiohttp/web_ws.py index 9fe66527539..382223097ea 100644 --- a/contrib/python/aiohttp/aiohttp/web_ws.py +++ b/contrib/python/aiohttp/aiohttp/web_ws.py @@ -11,7 +11,7 @@ from multidict import CIMultiDict from . import hdrs from .abc import AbstractStreamWriter -from .helpers import call_later, set_exception, set_result +from .helpers import calculate_timeout_when, set_exception, set_result from .http import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, @@ -81,67 +81,119 @@ class WebSocketResponse(StreamResponse): self._conn_lost = 0 self._close_code: Optional[int] = None self._loop: Optional[asyncio.AbstractEventLoop] = None - self._waiting: Optional[asyncio.Future[bool]] = None + self._waiting: bool = False + self._close_wait: Optional[asyncio.Future[None]] = None self._exception: Optional[BaseException] = None self._timeout = timeout self._receive_timeout = receive_timeout self._autoclose = autoclose self._autoping = autoping self._heartbeat = heartbeat + self._heartbeat_when = 0.0 self._heartbeat_cb: Optional[asyncio.TimerHandle] = None if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb: Optional[asyncio.TimerHandle] = None self._compress = compress self._max_msg_size = max_msg_size + self._ping_task: Optional[asyncio.Task[None]] = None def _cancel_heartbeat(self) -> None: - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = None - + self._cancel_pong_response_cb() if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None + if self._ping_task is not None: + self._ping_task.cancel() + self._ping_task = None + + def _cancel_pong_response_cb(self) -> None: + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None def _reset_heartbeat(self) -> None: - self._cancel_heartbeat() + if self._heartbeat is None: + return + self._cancel_pong_response_cb() + req = self._req + timeout_ceil_threshold = ( + req._protocol._timeout_ceil_threshold if req is not None else 5 + ) + loop = self._loop + assert loop is not None + now = loop.time() + when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) + self._heartbeat_when = when + if self._heartbeat_cb is None: + # We do not cancel the previous heartbeat_cb here because + # it generates a significant amount of TimerHandle churn + # which causes asyncio to rebuild the heap frequently. + # Instead _send_heartbeat() will reschedule the next + # heartbeat if it fires too early. + self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) - if self._heartbeat is not None: - assert self._loop is not None - self._heartbeat_cb = call_later( - self._send_heartbeat, - self._heartbeat, - self._loop, - timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold - if self._req is not None - else 5, + def _send_heartbeat(self) -> None: + self._heartbeat_cb = None + loop = self._loop + assert loop is not None and self._writer is not None + now = loop.time() + if now < self._heartbeat_when: + # Heartbeat fired too early, reschedule + self._heartbeat_cb = loop.call_at( + self._heartbeat_when, self._send_heartbeat ) + return - def _send_heartbeat(self) -> None: - if self._heartbeat is not None and not self._closed: - assert self._loop is not None - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. - self._loop.create_task(self._writer.ping()) # type: ignore[union-attr] + req = self._req + timeout_ceil_threshold = ( + req._protocol._timeout_ceil_threshold if req is not None else 5 + ) + when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) + self._cancel_pong_response_cb() + self._pong_response_cb = loop.call_at(when, self._pong_not_received) - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = call_later( - self._pong_not_received, - self._pong_heartbeat, - self._loop, - timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold - if self._req is not None - else 5, - ) + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to send the ping + # immediately to avoid having to schedule + # the task on the event loop. + ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True) + else: + ping_task = loop.create_task(self._writer.ping()) + + if not ping_task.done(): + self._ping_task = ping_task + ping_task.add_done_callback(self._ping_task_done) + else: + self._ping_task_done(ping_task) + + def _ping_task_done(self, task: "asyncio.Task[None]") -> None: + """Callback for when the ping task completes.""" + if not task.cancelled() and (exc := task.exception()): + self._handle_ping_pong_exception(exc) + self._ping_task = None def _pong_not_received(self) -> None: if self._req is not None and self._req.transport is not None: - self._closed = True - self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) - self._exception = asyncio.TimeoutError() + self._handle_ping_pong_exception(asyncio.TimeoutError()) + + def _handle_ping_pong_exception(self, exc: BaseException) -> None: + """Handle exceptions raised during ping/pong processing.""" + if self._closed: + return + self._set_closed() + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + self._exception = exc + if self._waiting and not self._closing and self._reader is not None: + self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None)) + + def _set_closed(self) -> None: + """Set the connection to closed. + + Cancel any heartbeat timers and set the closed flag. + """ + self._closed = True + self._cancel_heartbeat() async def prepare(self, request: BaseRequest) -> AbstractStreamWriter: # make pre-check to don't hide it by do_handshake() exceptions @@ -366,20 +418,10 @@ class WebSocketResponse(StreamResponse): if self._writer is None: raise RuntimeError("Call .prepare() first") - self._cancel_heartbeat() - reader = self._reader - assert reader is not None - - # we need to break `receive()` cycle first, - # `close()` may be called from different task - if self._waiting is not None and not self._closed: - reader.feed_data(WS_CLOSING_MESSAGE, 0) - await self._waiting - if self._closed: return False + self._set_closed() - self._closed = True try: await self._writer.close(code, message) writer = self._payload_writer @@ -394,12 +436,21 @@ class WebSocketResponse(StreamResponse): self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) return True + reader = self._reader + assert reader is not None + # we need to break `receive()` cycle before we can call + # `reader.read()` as `close()` may be called from different task + if self._waiting: + assert self._loop is not None + assert self._close_wait is None + self._close_wait = self._loop.create_future() + reader.feed_data(WS_CLOSING_MESSAGE) + await self._close_wait + if self._closing: self._close_transport() return True - reader = self._reader - assert reader is not None try: async with async_timeout.timeout(self._timeout): msg = await reader.read() @@ -411,7 +462,7 @@ class WebSocketResponse(StreamResponse): self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) return True - if msg.type == WSMsgType.CLOSE: + if msg.type is WSMsgType.CLOSE: self._set_code_close_transport(msg.data) return True @@ -423,6 +474,7 @@ class WebSocketResponse(StreamResponse): """Set the close code and mark the connection as closing.""" self._closing = True self._close_code = code + self._cancel_heartbeat() def _set_code_close_transport(self, code: WSCloseCode) -> None: """Set the close code and close the transport.""" @@ -440,8 +492,9 @@ class WebSocketResponse(StreamResponse): loop = self._loop assert loop is not None + receive_timeout = timeout or self._receive_timeout while True: - if self._waiting is not None: + if self._waiting: raise RuntimeError("Concurrent call to receive() is not allowed") if self._closed: @@ -453,15 +506,22 @@ class WebSocketResponse(StreamResponse): return WS_CLOSING_MESSAGE try: - self._waiting = loop.create_future() + self._waiting = True try: - async with async_timeout.timeout(timeout or self._receive_timeout): + if receive_timeout: + # Entering the context manager and creating + # Timeout() object can take almost 50% of the + # run time in this loop so we avoid it if + # there is no read timeout. + async with async_timeout.timeout(receive_timeout): + msg = await self._reader.read() + else: msg = await self._reader.read() self._reset_heartbeat() finally: - waiter = self._waiting - set_result(waiter, True) - self._waiting = None + self._waiting = False + if self._close_wait: + set_result(self._close_wait, None) except asyncio.TimeoutError: raise except EofStream: @@ -478,7 +538,7 @@ class WebSocketResponse(StreamResponse): await self.close() return WSMessage(WSMsgType.ERROR, exc, None) - if msg.type == WSMsgType.CLOSE: + if msg.type is WSMsgType.CLOSE: self._set_closing(msg.data) # Could be closed while awaiting reader. if not self._closed and self._autoclose: @@ -487,19 +547,19 @@ class WebSocketResponse(StreamResponse): # want to drain any pending writes as it will # likely result writing to a broken pipe. await self.close(drain=False) - elif msg.type == WSMsgType.CLOSING: + elif msg.type is WSMsgType.CLOSING: self._set_closing(WSCloseCode.OK) - elif msg.type == WSMsgType.PING and self._autoping: + elif msg.type is WSMsgType.PING and self._autoping: await self.pong(msg.data) continue - elif msg.type == WSMsgType.PONG and self._autoping: + elif msg.type is WSMsgType.PONG and self._autoping: continue return msg async def receive_str(self, *, timeout: Optional[float] = None) -> str: msg = await self.receive(timeout) - if msg.type != WSMsgType.TEXT: + if msg.type is not WSMsgType.TEXT: raise TypeError( "Received message {}:{!r} is not WSMsgType.TEXT".format( msg.type, msg.data @@ -509,7 +569,7 @@ class WebSocketResponse(StreamResponse): async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: msg = await self.receive(timeout) - if msg.type != WSMsgType.BINARY: + if msg.type is not WSMsgType.BINARY: raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes") return cast(bytes, msg.data) @@ -535,5 +595,6 @@ class WebSocketResponse(StreamResponse): # web_protocol calls this from connection_lost # or when the server is shutting down. self._closing = True + self._cancel_heartbeat() if self._reader is not None: set_exception(self._reader, exc) diff --git a/contrib/python/aiohttp/patches/04-force-content-type.patch b/contrib/python/aiohttp/patches/04-force-content-type.patch new file mode 100644 index 00000000000..44569413307 --- /dev/null +++ b/contrib/python/aiohttp/patches/04-force-content-type.patch @@ -0,0 +1,12 @@ +--- contrib/python/aiohttp/aiohttp/web_response.py (ddcb92de87597ba3c0a8961e7fdf04a184c227ce) ++++ contrib/python/aiohttp/aiohttp/web_response.py (0978c4fe84e8994e041f045b1447dd8058efa52c) +@@ -487,8 +487,7 @@ class StreamResponse(BaseClass, HeadersMixin): + # https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-13 + if hdrs.TRANSFER_ENCODING in headers: + del headers[hdrs.TRANSFER_ENCODING] +- elif self.content_length != 0: +- # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5 ++ else: + headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream") + headers.setdefault(hdrs.DATE, rfc822_formatted_time()) + headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE) diff --git a/contrib/python/aiohttp/patches/04-pr11265-support-aiosignal-1.4.patch b/contrib/python/aiohttp/patches/04-pr11265-support-aiosignal-1.4.patch index 0bfc1cc3a87..3b478e3aca5 100644 --- a/contrib/python/aiohttp/patches/04-pr11265-support-aiosignal-1.4.patch +++ b/contrib/python/aiohttp/patches/04-pr11265-support-aiosignal-1.4.patch @@ -1,6 +1,6 @@ --- contrib/python/aiohttp/aiohttp/tracing.py (index) +++ contrib/python/aiohttp/aiohttp/tracing.py (working tree) -@@ -12,15 +12,7 @@ if TYPE_CHECKING: +@@ -12,14 +12,7 @@ if TYPE_CHECKING: from .client import ClientSession _ParamT_contra = TypeVar("_ParamT_contra", contravariant=True) @@ -11,98 +11,83 @@ - __client_session: ClientSession, - __trace_config_ctx: SimpleNamespace, - __params: _ParamT_contra, -- ) -> Awaitable[None]: -- ... +- ) -> Awaitable[None]: ... + _TracingSignal = Signal[ClientSession, SimpleNamespace, _ParamT_contra] __all__ = ( -@@ -50,53 +42,53 @@ class TraceConfig: +@@ -49,54 +42,24 @@ class TraceConfig: def __init__( self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace ) -> None: -- self._on_request_start: Signal[ -- _SignalCallback[TraceRequestStartParams] -+ self._on_request_start: _TracingSignal[ -+ TraceRequestStartParams - ] = Signal(self) +- self._on_request_start: Signal[_SignalCallback[TraceRequestStartParams]] = ( ++ self._on_request_start: _TracingSignal[TraceRequestStartParams] = ( + Signal(self) + ) - self._on_request_chunk_sent: Signal[ - _SignalCallback[TraceRequestChunkSentParams] -+ self._on_request_chunk_sent: _TracingSignal[ -+ TraceRequestChunkSentParams - ] = Signal(self) +- ] = Signal(self) - self._on_response_chunk_received: Signal[ - _SignalCallback[TraceResponseChunkReceivedParams] -+ self._on_response_chunk_received: _TracingSignal[ -+ TraceResponseChunkReceivedParams - ] = Signal(self) +- ] = Signal(self) - self._on_request_end: Signal[_SignalCallback[TraceRequestEndParams]] = Signal( -+ self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal( - self - ) +- self +- ) - self._on_request_exception: Signal[ - _SignalCallback[TraceRequestExceptionParams] -+ self._on_request_exception: _TracingSignal[ -+ TraceRequestExceptionParams - ] = Signal(self) +- ] = Signal(self) - self._on_request_redirect: Signal[ - _SignalCallback[TraceRequestRedirectParams] -+ self._on_request_redirect: _TracingSignal[ -+ TraceRequestRedirectParams - ] = Signal(self) +- ] = Signal(self) - self._on_connection_queued_start: Signal[ - _SignalCallback[TraceConnectionQueuedStartParams] -+ self._on_connection_queued_start: _TracingSignal[ -+ TraceConnectionQueuedStartParams - ] = Signal(self) +- ] = Signal(self) - self._on_connection_queued_end: Signal[ - _SignalCallback[TraceConnectionQueuedEndParams] -+ self._on_connection_queued_end: _TracingSignal[ -+ TraceConnectionQueuedEndParams - ] = Signal(self) +- ] = Signal(self) - self._on_connection_create_start: Signal[ - _SignalCallback[TraceConnectionCreateStartParams] -+ self._on_connection_create_start: _TracingSignal[ -+ TraceConnectionCreateStartParams - ] = Signal(self) +- ] = Signal(self) - self._on_connection_create_end: Signal[ - _SignalCallback[TraceConnectionCreateEndParams] -+ self._on_connection_create_end: _TracingSignal[ -+ TraceConnectionCreateEndParams - ] = Signal(self) +- ] = Signal(self) - self._on_connection_reuseconn: Signal[ - _SignalCallback[TraceConnectionReuseconnParams] -+ self._on_connection_reuseconn: _TracingSignal[ -+ TraceConnectionReuseconnParams - ] = Signal(self) +- ] = Signal(self) - self._on_dns_resolvehost_start: Signal[ - _SignalCallback[TraceDnsResolveHostStartParams] -+ self._on_dns_resolvehost_start: _TracingSignal[ -+ TraceDnsResolveHostStartParams - ] = Signal(self) +- ] = Signal(self) - self._on_dns_resolvehost_end: Signal[ - _SignalCallback[TraceDnsResolveHostEndParams] -+ self._on_dns_resolvehost_end: _TracingSignal[ -+ TraceDnsResolveHostEndParams - ] = Signal(self) -- self._on_dns_cache_hit: Signal[ -- _SignalCallback[TraceDnsCacheHitParams] -+ self._on_dns_cache_hit: _TracingSignal[ -+ TraceDnsCacheHitParams - ] = Signal(self) -- self._on_dns_cache_miss: Signal[ -- _SignalCallback[TraceDnsCacheMissParams] -+ self._on_dns_cache_miss: _TracingSignal[ -+ TraceDnsCacheMissParams - ] = Signal(self) +- ] = Signal(self) +- self._on_dns_cache_hit: Signal[_SignalCallback[TraceDnsCacheHitParams]] = ( +- Signal(self) +- ) +- self._on_dns_cache_miss: Signal[_SignalCallback[TraceDnsCacheMissParams]] = ( +- Signal(self) +- ) - self._on_request_headers_sent: Signal[ - _SignalCallback[TraceRequestHeadersSentParams] -+ self._on_request_headers_sent: _TracingSignal[ -+ TraceRequestHeadersSentParams - ] = Signal(self) +- ] = Signal(self) ++ self._on_request_chunk_sent: _TracingSignal[TraceRequestChunkSentParams] = Signal(self) ++ self._on_response_chunk_received: _TracingSignal[TraceResponseChunkReceivedParams] = Signal(self) ++ self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal(self) ++ self._on_request_exception: _TracingSignal[TraceRequestExceptionParams] = Signal(self) ++ self._on_request_redirect: _TracingSignal[TraceRequestRedirectParams] = Signal(self) ++ self._on_connection_queued_start: _TracingSignal[TraceConnectionQueuedStartParams] = Signal(self) ++ self._on_connection_queued_end: _TracingSignal[TraceConnectionQueuedEndParams] = Signal(self) ++ self._on_connection_create_start: _TracingSignal[TraceConnectionCreateStartParams] = Signal(self) ++ self._on_connection_create_end: _TracingSignal[TraceConnectionCreateEndParams] = Signal(self) ++ self._on_connection_reuseconn: _TracingSignal[TraceConnectionReuseconnParams] = Signal(self) ++ self._on_dns_resolvehost_start: _TracingSignal[TraceDnsResolveHostStartParams] = Signal(self) ++ self._on_dns_resolvehost_end: _TracingSignal[TraceDnsResolveHostEndParams] = Signal(self) ++ self._on_dns_cache_hit: _TracingSignal[TraceDnsCacheHitParams] = (Signal(self)) ++ self._on_dns_cache_miss: _TracingSignal[TraceDnsCacheMissParams] = (Signal(self)) ++ self._on_request_headers_sent: _TracingSignal[TraceRequestHeadersSentParams] = Signal(self) self._trace_config_ctx_factory = trace_config_ctx_factory -@@ -126,91 +118,89 @@ class TraceConfig: + +@@ -125,91 +88,91 @@ class TraceConfig: self._on_request_headers_sent.freeze() @property @@ -111,10 +96,10 @@ return self._on_request_start @property -- def on_request_chunk_sent( -- self, + def on_request_chunk_sent( + self, - ) -> "Signal[_SignalCallback[TraceRequestChunkSentParams]]": -+ def on_request_chunk_sent(self) -> "_TracingSignal[TraceRequestChunkSentParams]": ++ ) -> "_TracingSignal[TraceRequestChunkSentParams]": return self._on_request_chunk_sent @property diff --git a/contrib/python/aiohttp/patches/05-disable-retries-on-oidemponent-methods.patch b/contrib/python/aiohttp/patches/05-disable-retries-on-oidemponent-methods.patch new file mode 100644 index 00000000000..8bd072eeabe --- /dev/null +++ b/contrib/python/aiohttp/patches/05-disable-retries-on-oidemponent-methods.patch @@ -0,0 +1,11 @@ +--- contrib/python/aiohttp/aiohttp/client.py (index) ++++ contrib/python/aiohttp/aiohttp/client.py (working tree) +@@ -574,7 +574,7 @@ class ClientSession: + try: + with timer: + # https://www.rfc-editor.org/rfc/rfc9112.html#name-retrying-requests +- retry_persistent_connection = method in IDEMPOTENT_METHODS ++ retry_persistent_connection = False #method in IDEMPOTENT_METHODS + while True: + url, auth_from_url = strip_auth_from_url(url) + if not url.raw_host: diff --git a/contrib/python/aiohttp/patches/06-mypy-silence.patch b/contrib/python/aiohttp/patches/06-mypy-silence.patch new file mode 100644 index 00000000000..d18dca98aa4 --- /dev/null +++ b/contrib/python/aiohttp/patches/06-mypy-silence.patch @@ -0,0 +1,61 @@ +--- contrib/python/aiohttp/aiohttp/client.py ++++ contrib/python/aiohttp/aiohttp/client.py +@@ -174,11 +174,11 @@ class _RequestOptions(TypedDict, total=False): + read_until_eof: bool + proxy: Union[StrOrURL, None] + proxy_auth: Union[BasicAuth, None] +- timeout: "Union[ClientTimeout, _SENTINEL, None]" ++ timeout: "Union[ClientTimeout, _SENTINEL, int, float, None]" + ssl: Union[SSLContext, bool, Fingerprint] + server_hostname: Union[str, None] + proxy_headers: Union[LooseHeaders, None] +- trace_request_ctx: Union[Mapping[str, str], None] ++ trace_request_ctx: Any #Union[Mapping[str, str], None] + read_bufsize: Union[int, None] + auto_decompress: Union[bool, None] + max_line_size: Union[int, None] +--- contrib/python/aiohttp/aiohttp/typedefs.py ++++ contrib/python/aiohttp/aiohttp/typedefs.py +@@ -69,12 +69,7 @@ LooseCookies = Union[ + ] + + Handler = Callable[["Request"], Awaitable["StreamResponse"]] +- +- +-class Middleware(Protocol): +- def __call__( +- self, request: "Request", handler: Handler +- ) -> Awaitable["StreamResponse"]: ... ++Middleware = Callable[["Request", Handler], Awaitable["StreamResponse"]] + + + PathLike = Union[str, "os.PathLike[str]"] +--- contrib/python/aiohttp/aiohttp/multipart.py ++++ contrib/python/aiohttp/aiohttp/multipart.py +@@ -287,7 +287,7 @@ class BodyPartReader: + self._content_eof = 0 + self._cache: Dict[str, Any] = {} + +- def __aiter__(self: Self) -> Self: ++ def __aiter__(self: Self): + return self + + async def __anext__(self) -> bytes: +@@ -593,7 +593,7 @@ class MultipartReader: + response_wrapper_cls = MultipartResponseWrapper + #: Multipart reader class, used to handle multipart/* body parts. + #: None points to type(self) +- multipart_reader_cls: Optional[Type["MultipartReader"]] = None ++ multipart_reader_cls = None + #: Body part reader class for non multipart/* content types. + part_reader_cls = BodyPartReader + +@@ -614,7 +614,7 @@ class MultipartReader: + self._at_bof = True + self._unread: List[bytes] = [] + +- def __aiter__(self: Self) -> Self: ++ def __aiter__(self: Self): + return self + + async def __anext__( diff --git a/contrib/python/aiohttp/patches/07-dont-throw-at-newline.patch b/contrib/python/aiohttp/patches/07-dont-throw-at-newline.patch new file mode 100644 index 00000000000..815dd062c17 --- /dev/null +++ b/contrib/python/aiohttp/patches/07-dont-throw-at-newline.patch @@ -0,0 +1,12 @@ +# This patch is revert commit dd5bb073107caa1c764158b87fb8482124aad6c1 +--- contrib/python/aiohttp/aiohttp/web_response.py (index) ++++ contrib/python/aiohttp/aiohttp/web_response.py (working tree) +@@ -147,8 +147,6 @@ class StreamResponse(BaseClass, HeadersMixin): + self._status = int(status) + if reason is None: + reason = REASON_PHRASES.get(self._status, "") +- elif "\n" in reason: +- raise ValueError("Reason cannot contain \\n") + self._reason = reason + + @property diff --git a/contrib/python/aiohttp/patches/99-rep-get-running-loop.sh b/contrib/python/aiohttp/patches/99-rep-get-running-loop.sh new file mode 100644 index 00000000000..b97011e3c67 --- /dev/null +++ b/contrib/python/aiohttp/patches/99-rep-get-running-loop.sh @@ -0,0 +1,4 @@ +# This patch may be dropped after python 3.13 upver + +find aiohttp -type f -exec sed --in-place 's|loop or asyncio.get_running_loop|loop or asyncio.get_event_loop|g' '{}' ';' + diff --git a/contrib/python/aiohttp/ya.make b/contrib/python/aiohttp/ya.make index 40b3b6faabe..e714d0fd423 100644 --- a/contrib/python/aiohttp/ya.make +++ b/contrib/python/aiohttp/ya.make @@ -2,11 +2,12 @@ PY3_LIBRARY() -VERSION(3.9.5) +VERSION(3.10.6) LICENSE(Apache-2.0) PEERDIR( + contrib/python/aiohappyeyeballs contrib/python/aiosignal contrib/python/attrs contrib/python/frozenlist |
