summaryrefslogtreecommitdiffstats
path: root/contrib/python/aiohttp
diff options
context:
space:
mode:
authorrobot-piglet <[email protected]>2025-10-16 23:00:05 +0300
committerrobot-piglet <[email protected]>2025-10-16 23:13:00 +0300
commit11655c9ebab3639829f448726c53598a309836d2 (patch)
tree0fb621e064ba793816eee485f5b5b2b0d770c234 /contrib/python/aiohttp
parentab033a87e63230a4a2116281fc1dc2d0072fc894 (diff)
Intermediate changes
commit_hash:736b945c094519c9026454678b074141602f4584
Diffstat (limited to 'contrib/python/aiohttp')
-rw-r--r--contrib/python/aiohttp/.dist-info/METADATA11
-rw-r--r--contrib/python/aiohttp/README.rst3
-rw-r--r--contrib/python/aiohttp/aiohttp/__init__.py82
-rw-r--r--contrib/python/aiohttp/aiohttp/_helpers.pyx22
-rw-r--r--contrib/python/aiohttp/aiohttp/_http_parser.pyx19
-rw-r--r--contrib/python/aiohttp/aiohttp/abc.py31
-rw-r--r--contrib/python/aiohttp/aiohttp/base_protocol.py3
-rw-r--r--contrib/python/aiohttp/aiohttp/client.py408
-rw-r--r--contrib/python/aiohttp/aiohttp/client_exceptions.py79
-rw-r--r--contrib/python/aiohttp/aiohttp/client_proto.py30
-rw-r--r--contrib/python/aiohttp/aiohttp/client_reqrep.py173
-rw-r--r--contrib/python/aiohttp/aiohttp/client_ws.py203
-rw-r--r--contrib/python/aiohttp/aiohttp/compression_utils.py8
-rw-r--r--contrib/python/aiohttp/aiohttp/connector.py241
-rw-r--r--contrib/python/aiohttp/aiohttp/cookiejar.py230
-rw-r--r--contrib/python/aiohttp/aiohttp/helpers.py159
-rw-r--r--contrib/python/aiohttp/aiohttp/http_exceptions.py1
-rw-r--r--contrib/python/aiohttp/aiohttp/http_parser.py49
-rw-r--r--contrib/python/aiohttp/aiohttp/http_websocket.py447
-rw-r--r--contrib/python/aiohttp/aiohttp/http_writer.py7
-rw-r--r--contrib/python/aiohttp/aiohttp/multipart.py80
-rw-r--r--contrib/python/aiohttp/aiohttp/payload.py45
-rw-r--r--contrib/python/aiohttp/aiohttp/payload_streamer.py3
-rw-r--r--contrib/python/aiohttp/aiohttp/pytest_plugin.py52
-rw-r--r--contrib/python/aiohttp/aiohttp/resolver.py107
-rw-r--r--contrib/python/aiohttp/aiohttp/streams.py3
-rw-r--r--contrib/python/aiohttp/aiohttp/test_utils.py162
-rw-r--r--contrib/python/aiohttp/aiohttp/tracing.py72
-rw-r--r--contrib/python/aiohttp/aiohttp/typedefs.py23
-rw-r--r--contrib/python/aiohttp/aiohttp/web.py37
-rw-r--r--contrib/python/aiohttp/aiohttp/web_app.py76
-rw-r--r--contrib/python/aiohttp/aiohttp/web_fileresponse.py141
-rw-r--r--contrib/python/aiohttp/aiohttp/web_middlewares.py7
-rw-r--r--contrib/python/aiohttp/aiohttp/web_protocol.py170
-rw-r--r--contrib/python/aiohttp/aiohttp/web_request.py46
-rw-r--r--contrib/python/aiohttp/aiohttp/web_response.py68
-rw-r--r--contrib/python/aiohttp/aiohttp/web_routedef.py6
-rw-r--r--contrib/python/aiohttp/aiohttp/web_runner.py18
-rw-r--r--contrib/python/aiohttp/aiohttp/web_server.py11
-rw-r--r--contrib/python/aiohttp/aiohttp/web_urldispatcher.py225
-rw-r--r--contrib/python/aiohttp/aiohttp/web_ws.py185
-rw-r--r--contrib/python/aiohttp/patches/04-force-content-type.patch12
-rw-r--r--contrib/python/aiohttp/patches/04-pr11265-support-aiosignal-1.4.patch109
-rw-r--r--contrib/python/aiohttp/patches/05-disable-retries-on-oidemponent-methods.patch11
-rw-r--r--contrib/python/aiohttp/patches/06-mypy-silence.patch61
-rw-r--r--contrib/python/aiohttp/patches/07-dont-throw-at-newline.patch12
-rw-r--r--contrib/python/aiohttp/patches/99-rep-get-running-loop.sh4
-rw-r--r--contrib/python/aiohttp/ya.make3
48 files changed, 2512 insertions, 1443 deletions
diff --git a/contrib/python/aiohttp/.dist-info/METADATA b/contrib/python/aiohttp/.dist-info/METADATA
index cd312649136..7ad8cef0fd3 100644
--- a/contrib/python/aiohttp/.dist-info/METADATA
+++ b/contrib/python/aiohttp/.dist-info/METADATA
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: aiohttp
-Version: 3.9.5
+Version: 3.10.6
Summary: Async http client/server framework (asyncio)
Home-page: https://github.com/aio-libs/aiohttp
Maintainer: aiohttp team <[email protected]>
@@ -28,20 +28,22 @@ Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Internet :: WWW/HTTP
Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE.txt
+Requires-Dist: aiohappyeyeballs >=2.3.0
Requires-Dist: aiosignal >=1.1.2
Requires-Dist: attrs >=17.3.0
Requires-Dist: frozenlist >=1.1.1
Requires-Dist: multidict <7.0,>=4.5
-Requires-Dist: yarl <2.0,>=1.0
+Requires-Dist: yarl <2.0,>=1.12.0
Requires-Dist: async-timeout <5.0,>=4.0 ; python_version < "3.11"
Provides-Extra: speedups
Requires-Dist: brotlicffi ; (platform_python_implementation != "CPython") and extra == 'speedups'
Requires-Dist: Brotli ; (platform_python_implementation == "CPython") and extra == 'speedups'
-Requires-Dist: aiodns ; (sys_platform == "linux" or sys_platform == "darwin") and extra == 'speedups'
+Requires-Dist: aiodns >=3.2.0 ; (sys_platform == "linux" or sys_platform == "darwin") and extra == 'speedups'
==================================
Async http client/server framework
@@ -193,7 +195,7 @@ Communication channels
*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions
-*gitter chat* https://gitter.im/aio-libs/Lobby
+*Matrix*: `#aio-libs:matrix.org <https://matrix.to/#/#aio-libs:matrix.org>`_
We support `Stack Overflow
<https://stackoverflow.com/questions/tagged/aiohttp>`_.
@@ -202,7 +204,6 @@ Please add *aiohttp* tag to your question there.
Requirements
============
-- async-timeout_
- attrs_
- multidict_
- yarl_
diff --git a/contrib/python/aiohttp/README.rst b/contrib/python/aiohttp/README.rst
index 90b7f713577..470ced9b29c 100644
--- a/contrib/python/aiohttp/README.rst
+++ b/contrib/python/aiohttp/README.rst
@@ -148,7 +148,7 @@ Communication channels
*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions
-*gitter chat* https://gitter.im/aio-libs/Lobby
+*Matrix*: `#aio-libs:matrix.org <https://matrix.to/#/#aio-libs:matrix.org>`_
We support `Stack Overflow
<https://stackoverflow.com/questions/tagged/aiohttp>`_.
@@ -157,7 +157,6 @@ Please add *aiohttp* tag to your question there.
Requirements
============
-- async-timeout_
- attrs_
- multidict_
- yarl_
diff --git a/contrib/python/aiohttp/aiohttp/__init__.py b/contrib/python/aiohttp/aiohttp/__init__.py
index e82e790b46a..8830d340940 100644
--- a/contrib/python/aiohttp/aiohttp/__init__.py
+++ b/contrib/python/aiohttp/aiohttp/__init__.py
@@ -1,40 +1,48 @@
-__version__ = "3.9.5"
+__version__ = "3.10.6"
from typing import TYPE_CHECKING, Tuple
from . import hdrs as hdrs
from .client import (
- BaseConnector as BaseConnector,
- ClientConnectionError as ClientConnectionError,
- ClientConnectorCertificateError as ClientConnectorCertificateError,
- ClientConnectorError as ClientConnectorError,
- ClientConnectorSSLError as ClientConnectorSSLError,
- ClientError as ClientError,
- ClientHttpProxyError as ClientHttpProxyError,
- ClientOSError as ClientOSError,
- ClientPayloadError as ClientPayloadError,
- ClientProxyConnectionError as ClientProxyConnectionError,
- ClientRequest as ClientRequest,
- ClientResponse as ClientResponse,
- ClientResponseError as ClientResponseError,
- ClientSession as ClientSession,
- ClientSSLError as ClientSSLError,
- ClientTimeout as ClientTimeout,
- ClientWebSocketResponse as ClientWebSocketResponse,
- ContentTypeError as ContentTypeError,
- Fingerprint as Fingerprint,
- InvalidURL as InvalidURL,
- NamedPipeConnector as NamedPipeConnector,
- RequestInfo as RequestInfo,
- ServerConnectionError as ServerConnectionError,
- ServerDisconnectedError as ServerDisconnectedError,
- ServerFingerprintMismatch as ServerFingerprintMismatch,
- ServerTimeoutError as ServerTimeoutError,
- TCPConnector as TCPConnector,
- TooManyRedirects as TooManyRedirects,
- UnixConnector as UnixConnector,
- WSServerHandshakeError as WSServerHandshakeError,
- request as request,
+ BaseConnector,
+ ClientConnectionError,
+ ClientConnectionResetError,
+ ClientConnectorCertificateError,
+ ClientConnectorError,
+ ClientConnectorSSLError,
+ ClientError,
+ ClientHttpProxyError,
+ ClientOSError,
+ ClientPayloadError,
+ ClientProxyConnectionError,
+ ClientRequest,
+ ClientResponse,
+ ClientResponseError,
+ ClientSession,
+ ClientSSLError,
+ ClientTimeout,
+ ClientWebSocketResponse,
+ ConnectionTimeoutError,
+ ContentTypeError,
+ Fingerprint,
+ InvalidURL,
+ InvalidUrlClientError,
+ InvalidUrlRedirectClientError,
+ NamedPipeConnector,
+ NonHttpUrlClientError,
+ NonHttpUrlRedirectClientError,
+ RedirectClientError,
+ RequestInfo,
+ ServerConnectionError,
+ ServerDisconnectedError,
+ ServerFingerprintMismatch,
+ ServerTimeoutError,
+ SocketTimeoutError,
+ TCPConnector,
+ TooManyRedirects,
+ UnixConnector,
+ WSServerHandshakeError,
+ request,
)
from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar
from .formdata import FormData as FormData
@@ -99,6 +107,7 @@ from .tracing import (
TraceRequestChunkSentParams as TraceRequestChunkSentParams,
TraceRequestEndParams as TraceRequestEndParams,
TraceRequestExceptionParams as TraceRequestExceptionParams,
+ TraceRequestHeadersSentParams as TraceRequestHeadersSentParams,
TraceRequestRedirectParams as TraceRequestRedirectParams,
TraceRequestStartParams as TraceRequestStartParams,
TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams,
@@ -116,6 +125,7 @@ __all__: Tuple[str, ...] = (
# client
"BaseConnector",
"ClientConnectionError",
+ "ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorError",
"ClientConnectorSSLError",
@@ -131,14 +141,21 @@ __all__: Tuple[str, ...] = (
"ClientSession",
"ClientTimeout",
"ClientWebSocketResponse",
+ "ConnectionTimeoutError",
"ContentTypeError",
"Fingerprint",
"InvalidURL",
+ "InvalidUrlClientError",
+ "InvalidUrlRedirectClientError",
+ "NonHttpUrlClientError",
+ "NonHttpUrlRedirectClientError",
+ "RedirectClientError",
"RequestInfo",
"ServerConnectionError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
+ "SocketTimeoutError",
"TCPConnector",
"TooManyRedirects",
"UnixConnector",
@@ -210,6 +227,7 @@ __all__: Tuple[str, ...] = (
"TraceRequestChunkSentParams",
"TraceRequestEndParams",
"TraceRequestExceptionParams",
+ "TraceRequestHeadersSentParams",
"TraceRequestRedirectParams",
"TraceRequestStartParams",
"TraceResponseChunkReceivedParams",
diff --git a/contrib/python/aiohttp/aiohttp/_helpers.pyx b/contrib/python/aiohttp/aiohttp/_helpers.pyx
index 665f367c5de..5f089225dc8 100644
--- a/contrib/python/aiohttp/aiohttp/_helpers.pyx
+++ b/contrib/python/aiohttp/aiohttp/_helpers.pyx
@@ -1,3 +1,6 @@
+
+cdef _sentinel = object()
+
cdef class reify:
"""Use as a class method decorator. It operates almost exactly like
the Python `@property` decorator, but it puts the result of the
@@ -19,17 +22,14 @@ cdef class reify:
return self.wrapped.__doc__
def __get__(self, inst, owner):
- try:
- try:
- return inst._cache[self.name]
- except KeyError:
- val = self.wrapped(inst)
- inst._cache[self.name] = val
- return val
- except AttributeError:
- if inst is None:
- return self
- raise
+ if inst is None:
+ return self
+ cdef dict cache = inst._cache
+ val = cache.get(self.name, _sentinel)
+ if val is _sentinel:
+ val = self.wrapped(inst)
+ cache[self.name] = val
+ return val
def __set__(self, inst, value):
raise AttributeError("reified property is read-only")
diff --git a/contrib/python/aiohttp/aiohttp/_http_parser.pyx b/contrib/python/aiohttp/aiohttp/_http_parser.pyx
index ec6edb2dfec..8e82c0fd77f 100644
--- a/contrib/python/aiohttp/aiohttp/_http_parser.pyx
+++ b/contrib/python/aiohttp/aiohttp/_http_parser.pyx
@@ -47,6 +47,7 @@ include "_headers.pxi"
from aiohttp cimport _find_header
+ALLOWED_UPGRADES = frozenset({"websocket"})
DEF DEFAULT_FREELIST_SIZE = 250
cdef extern from "Python.h":
@@ -417,7 +418,6 @@ cdef class HttpParser:
cdef _on_headers_complete(self):
self._process_header()
- method = http_method_str(self._cparser.method)
should_close = not cparser.llhttp_should_keep_alive(self._cparser)
upgrade = self._cparser.upgrade
chunked = self._cparser.flags & cparser.F_CHUNKED
@@ -425,8 +425,13 @@ cdef class HttpParser:
raw_headers = tuple(self._raw_headers)
headers = CIMultiDictProxy(self._headers)
- if upgrade or self._cparser.method == cparser.HTTP_CONNECT:
- self._upgraded = True
+ if self._cparser.type == cparser.HTTP_REQUEST:
+ allowed = upgrade and headers.get("upgrade", "").lower() in ALLOWED_UPGRADES
+ if allowed or self._cparser.method == cparser.HTTP_CONNECT:
+ self._upgraded = True
+ else:
+ if upgrade and self._cparser.status_code == 101:
+ self._upgraded = True
# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in headers:
@@ -441,6 +446,7 @@ cdef class HttpParser:
encoding = enc
if self._cparser.type == cparser.HTTP_REQUEST:
+ method = http_method_str(self._cparser.method)
msg = _new_request_message(
method, self._path,
self.http_version(), headers, raw_headers,
@@ -565,7 +571,7 @@ cdef class HttpParser:
if self._upgraded:
return messages, True, data[nb:]
else:
- return messages, False, b''
+ return messages, False, b""
def set_upgraded(self, val):
self._upgraded = val
@@ -748,10 +754,7 @@ cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1:
pyparser._last_error = exc
return -1
else:
- if (
- pyparser._cparser.upgrade or
- pyparser._cparser.method == cparser.HTTP_CONNECT
- ):
+ if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT:
return 2
else:
return 0
diff --git a/contrib/python/aiohttp/aiohttp/abc.py b/contrib/python/aiohttp/aiohttp/abc.py
index ee838998997..b309bcffe01 100644
--- a/contrib/python/aiohttp/aiohttp/abc.py
+++ b/contrib/python/aiohttp/aiohttp/abc.py
@@ -1,5 +1,6 @@
import asyncio
import logging
+import socket
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel
@@ -14,12 +15,12 @@ from typing import (
List,
Optional,
Tuple,
+ TypedDict,
)
from multidict import CIMultiDict
from yarl import URL
-from .helpers import get_running_loop
from .typedefs import LooseCookies
if TYPE_CHECKING:
@@ -119,11 +120,35 @@ class AbstractView(ABC):
"""Execute the view handler."""
+class ResolveResult(TypedDict):
+ """Resolve result.
+
+ This is the result returned from an AbstractResolver's
+ resolve method.
+
+ :param hostname: The hostname that was provided.
+ :param host: The IP address that was resolved.
+ :param port: The port that was resolved.
+ :param family: The address family that was resolved.
+ :param proto: The protocol that was resolved.
+ :param flags: The flags that were resolved.
+ """
+
+ hostname: str
+ host: str
+ port: int
+ family: int
+ proto: int
+ flags: int
+
+
class AbstractResolver(ABC):
"""Abstract DNS resolver."""
@abstractmethod
- async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]:
+ async def resolve(
+ self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
+ ) -> List[ResolveResult]:
"""Return IP address for given hostname"""
@abstractmethod
@@ -144,7 +169,7 @@ class AbstractCookieJar(Sized, IterableBase):
"""Abstract Cookie Jar."""
def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
- self._loop = get_running_loop(loop)
+ self._loop = loop or asyncio.get_event_loop()
@abstractmethod
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
diff --git a/contrib/python/aiohttp/aiohttp/base_protocol.py b/contrib/python/aiohttp/aiohttp/base_protocol.py
index dc1f24f99cd..2fc2fa65885 100644
--- a/contrib/python/aiohttp/aiohttp/base_protocol.py
+++ b/contrib/python/aiohttp/aiohttp/base_protocol.py
@@ -1,6 +1,7 @@
import asyncio
from typing import Optional, cast
+from .client_exceptions import ClientConnectionResetError
from .helpers import set_exception
from .tcp_helpers import tcp_nodelay
@@ -85,7 +86,7 @@ class BaseProtocol(asyncio.Protocol):
async def _drain_helper(self) -> None:
if not self.connected:
- raise ConnectionResetError("Connection lost")
+ raise ClientConnectionResetError("Connection lost")
if not self._paused:
return
waiter = self._drain_waiter
diff --git a/contrib/python/aiohttp/aiohttp/client.py b/contrib/python/aiohttp/aiohttp/client.py
index 32d2c3b7119..186105bee9f 100644
--- a/contrib/python/aiohttp/aiohttp/client.py
+++ b/contrib/python/aiohttp/aiohttp/client.py
@@ -9,7 +9,7 @@ import sys
import traceback
import warnings
from contextlib import suppress
-from types import SimpleNamespace, TracebackType
+from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
@@ -27,6 +27,7 @@ from typing import (
Set,
Tuple,
Type,
+ TypedDict,
TypeVar,
Union,
)
@@ -38,25 +39,33 @@ from yarl import URL
from . import hdrs, http, payload
from .abc import AbstractCookieJar
from .client_exceptions import (
- ClientConnectionError as ClientConnectionError,
- ClientConnectorCertificateError as ClientConnectorCertificateError,
- ClientConnectorError as ClientConnectorError,
- ClientConnectorSSLError as ClientConnectorSSLError,
- ClientError as ClientError,
- ClientHttpProxyError as ClientHttpProxyError,
- ClientOSError as ClientOSError,
- ClientPayloadError as ClientPayloadError,
- ClientProxyConnectionError as ClientProxyConnectionError,
- ClientResponseError as ClientResponseError,
- ClientSSLError as ClientSSLError,
- ContentTypeError as ContentTypeError,
- InvalidURL as InvalidURL,
- ServerConnectionError as ServerConnectionError,
- ServerDisconnectedError as ServerDisconnectedError,
- ServerFingerprintMismatch as ServerFingerprintMismatch,
- ServerTimeoutError as ServerTimeoutError,
- TooManyRedirects as TooManyRedirects,
- WSServerHandshakeError as WSServerHandshakeError,
+ ClientConnectionError,
+ ClientConnectionResetError,
+ ClientConnectorCertificateError,
+ ClientConnectorError,
+ ClientConnectorSSLError,
+ ClientError,
+ ClientHttpProxyError,
+ ClientOSError,
+ ClientPayloadError,
+ ClientProxyConnectionError,
+ ClientResponseError,
+ ClientSSLError,
+ ConnectionTimeoutError,
+ ContentTypeError,
+ InvalidURL,
+ InvalidUrlClientError,
+ InvalidUrlRedirectClientError,
+ NonHttpUrlClientError,
+ NonHttpUrlRedirectClientError,
+ RedirectClientError,
+ ServerConnectionError,
+ ServerDisconnectedError,
+ ServerFingerprintMismatch,
+ ServerTimeoutError,
+ SocketTimeoutError,
+ TooManyRedirects,
+ WSServerHandshakeError,
)
from .client_reqrep import (
ClientRequest as ClientRequest,
@@ -67,6 +76,7 @@ from .client_reqrep import (
)
from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse
from .connector import (
+ HTTP_AND_EMPTY_SCHEMA_SET,
BaseConnector as BaseConnector,
NamedPipeConnector as NamedPipeConnector,
TCPConnector as TCPConnector,
@@ -80,7 +90,6 @@ from .helpers import (
TimeoutHandle,
ceil_timeout,
get_env_proxy_for_url,
- get_running_loop,
method_must_be_empty_body,
sentinel,
strip_auth_from_url,
@@ -89,11 +98,12 @@ from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter
from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse
from .streams import FlowControlDataQueue
from .tracing import Trace, TraceConfig
-from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, StrOrURL
+from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, Query, StrOrURL
__all__ = (
# client_exceptions
"ClientConnectionError",
+ "ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorError",
"ClientConnectorSSLError",
@@ -104,12 +114,19 @@ __all__ = (
"ClientProxyConnectionError",
"ClientResponseError",
"ClientSSLError",
+ "ConnectionTimeoutError",
"ContentTypeError",
"InvalidURL",
+ "InvalidUrlClientError",
+ "RedirectClientError",
+ "NonHttpUrlClientError",
+ "InvalidUrlRedirectClientError",
+ "NonHttpUrlRedirectClientError",
"ServerConnectionError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
+ "SocketTimeoutError",
"TooManyRedirects",
"WSServerHandshakeError",
# client_reqrep
@@ -136,6 +153,37 @@ if TYPE_CHECKING:
else:
SSLContext = None
+if sys.version_info >= (3, 11) and TYPE_CHECKING:
+ from typing import Unpack
+
+
+class _RequestOptions(TypedDict, total=False):
+ params: Query
+ data: Any
+ json: Any
+ cookies: Union[LooseCookies, None]
+ headers: Union[LooseHeaders, None]
+ skip_auto_headers: Union[Iterable[str], None]
+ auth: Union[BasicAuth, None]
+ allow_redirects: bool
+ max_redirects: int
+ compress: Union[str, bool, None]
+ chunked: Union[bool, None]
+ expect100: bool
+ raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]]
+ read_until_eof: bool
+ proxy: Union[StrOrURL, None]
+ proxy_auth: Union[BasicAuth, None]
+ timeout: "Union[ClientTimeout, _SENTINEL, int, float, None]"
+ ssl: Union[SSLContext, bool, Fingerprint]
+ server_hostname: Union[str, None]
+ proxy_headers: Union[LooseHeaders, None]
+ trace_request_ctx: Any #Union[Mapping[str, str], None]
+ read_bufsize: Union[int, None]
+ auto_decompress: Union[bool, None]
+ max_line_size: Union[int, None]
+ max_field_size: Union[int, None]
+
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ClientTimeout:
@@ -162,7 +210,10 @@ class ClientTimeout:
# 5 Minute default read timeout
DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60)
-_RetType = TypeVar("_RetType")
+# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
+IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
+
+_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse)
_CharsetResolver = Callable[[ClientResponse, bytes], str]
@@ -237,6 +288,21 @@ class ClientSession:
# We initialise _connector to None immediately, as it's referenced in __del__()
# and could cause issues if an exception occurs during initialisation.
self._connector: Optional[BaseConnector] = None
+
+ if loop is None:
+ if connector is not None:
+ loop = connector._loop
+
+ loop = loop or asyncio.get_event_loop()
+
+ if base_url is None or isinstance(base_url, URL):
+ self._base_url: Optional[URL] = base_url
+ else:
+ self._base_url = URL(base_url)
+ assert (
+ self._base_url.origin() == self._base_url
+ ), "Only absolute URLs without path part are supported"
+
if timeout is sentinel or timeout is None:
self._timeout = DEFAULT_TIMEOUT
if read_timeout is not sentinel:
@@ -272,19 +338,6 @@ class ClientSession:
"conflict, please setup "
"timeout.connect"
)
- if loop is None:
- if connector is not None:
- loop = connector._loop
-
- loop = get_running_loop(loop)
-
- if base_url is None or isinstance(base_url, URL):
- self._base_url: Optional[URL] = base_url
- else:
- self._base_url = URL(base_url)
- assert (
- self._base_url.origin() == self._base_url
- ), "Only absolute URLs without path part are supported"
if connector is None:
connector = TCPConnector(loop=loop)
@@ -369,11 +422,22 @@ class ClientSession:
context["source_traceback"] = self._source_traceback
self._loop.call_exception_handler(context)
- def request(
- self, method: str, url: StrOrURL, **kwargs: Any
- ) -> "_RequestContextManager":
- """Perform HTTP request."""
- return _RequestContextManager(self._request(method, url, **kwargs))
+ if sys.version_info >= (3, 11) and TYPE_CHECKING:
+
+ def request(
+ self,
+ method: str,
+ url: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> "_RequestContextManager": ...
+
+ else:
+
+ def request(
+ self, method: str, url: StrOrURL, **kwargs: Any
+ ) -> "_RequestContextManager":
+ """Perform HTTP request."""
+ return _RequestContextManager(self._request(method, url, **kwargs))
def _build_url(self, str_or_url: StrOrURL) -> URL:
url = URL(str_or_url)
@@ -388,7 +452,7 @@ class ClientSession:
method: str,
str_or_url: StrOrURL,
*,
- params: Optional[Mapping[str, str]] = None,
+ params: Query = None,
data: Any = None,
json: Any = None,
cookies: Optional[LooseCookies] = None,
@@ -397,7 +461,7 @@ class ClientSession:
auth: Optional[BasicAuth] = None,
allow_redirects: bool = True,
max_redirects: int = 10,
- compress: Optional[str] = None,
+ compress: Union[str, bool, None] = None,
chunked: Optional[bool] = None,
expect100: bool = False,
raise_for_status: Union[
@@ -413,7 +477,7 @@ class ClientSession:
ssl: Union[SSLContext, bool, Fingerprint] = True,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
- trace_request_ctx: Optional[SimpleNamespace] = None,
+ trace_request_ctx: Optional[Mapping[str, str]] = None,
read_bufsize: Optional[int] = None,
auto_decompress: Optional[bool] = None,
max_line_size: Optional[int] = None,
@@ -451,7 +515,11 @@ class ClientSession:
try:
url = self._build_url(str_or_url)
except ValueError as e:
- raise InvalidURL(str_or_url) from e
+ raise InvalidUrlClientError(str_or_url) from e
+
+ assert self._connector is not None
+ if url.scheme not in self._connector.allowed_protocol_schema_set:
+ raise NonHttpUrlClientError(url)
skip_headers = set(self._skip_auto_headers)
if skip_auto_headers is not None:
@@ -505,8 +573,19 @@ class ClientSession:
timer = tm.timer()
try:
with timer:
+ # https://www.rfc-editor.org/rfc/rfc9112.html#name-retrying-requests
+ retry_persistent_connection = False #method in IDEMPOTENT_METHODS
while True:
url, auth_from_url = strip_auth_from_url(url)
+ if not url.raw_host:
+ # NOTE: Bail early, otherwise, causes `InvalidURL` through
+ # NOTE: `self._request_class()` below.
+ err_exc_cls = (
+ InvalidUrlRedirectClientError
+ if redirects
+ else InvalidUrlClientError
+ )
+ raise err_exc_cls(url)
if auth and auth_from_url:
raise ValueError(
"Cannot combine AUTH argument with "
@@ -550,7 +629,7 @@ class ClientSession:
url,
params=params,
headers=headers,
- skip_auto_headers=skip_headers,
+ skip_auto_headers=skip_headers if skip_headers else None,
data=data,
cookies=all_cookies,
auth=auth,
@@ -577,13 +656,12 @@ class ClientSession:
real_timeout.connect,
ceil_threshold=real_timeout.ceil_threshold,
):
- assert self._connector is not None
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
)
except asyncio.TimeoutError as exc:
- raise ServerTimeoutError(
- "Connection timeout " "to host {}".format(url)
+ raise ConnectionTimeoutError(
+ f"Connection timeout to host {url}"
) from exc
assert conn.transport is not None
@@ -612,6 +690,11 @@ class ClientSession:
except BaseException:
conn.close()
raise
+ except (ClientOSError, ServerDisconnectedError):
+ if retry_persistent_connection:
+ retry_persistent_connection = False
+ continue
+ raise
except ClientError:
raise
except OSError as exc:
@@ -659,25 +742,35 @@ class ClientSession:
resp.release()
try:
- parsed_url = URL(
+ parsed_redirect_url = URL(
r_url, encoded=not self._requote_redirect_url
)
-
except ValueError as e:
- raise InvalidURL(r_url) from e
+ raise InvalidUrlRedirectClientError(
+ r_url,
+ "Server attempted redirecting to a location that does not look like a URL",
+ ) from e
- scheme = parsed_url.scheme
- if scheme not in ("http", "https", ""):
+ scheme = parsed_redirect_url.scheme
+ if scheme not in HTTP_AND_EMPTY_SCHEMA_SET:
resp.close()
- raise ValueError("Can redirect only to http or https")
+ raise NonHttpUrlRedirectClientError(r_url)
elif not scheme:
- parsed_url = url.join(parsed_url)
+ parsed_redirect_url = url.join(parsed_redirect_url)
+
+ try:
+ redirect_origin = parsed_redirect_url.origin()
+ except ValueError as origin_val_err:
+ raise InvalidUrlRedirectClientError(
+ parsed_redirect_url,
+ "Invalid redirect URL origin",
+ ) from origin_val_err
- if url.origin() != parsed_url.origin():
+ if url.origin() != redirect_origin:
auth = None
headers.pop(hdrs.AUTHORIZATION, None)
- url = parsed_url
+ url = parsed_redirect_url
params = {}
resp.release()
continue
@@ -736,11 +829,11 @@ class ClientSession:
heartbeat: Optional[float] = None,
auth: Optional[BasicAuth] = None,
origin: Optional[str] = None,
- params: Optional[Mapping[str, str]] = None,
+ params: Query = None,
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
- ssl: Union[SSLContext, bool, None, Fingerprint] = True,
+ ssl: Union[SSLContext, bool, Fingerprint] = True,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
@@ -788,11 +881,11 @@ class ClientSession:
heartbeat: Optional[float] = None,
auth: Optional[BasicAuth] = None,
origin: Optional[str] = None,
- params: Optional[Mapping[str, str]] = None,
+ params: Query = None,
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
- ssl: Optional[Union[SSLContext, bool, Fingerprint]] = True,
+ ssl: Union[SSLContext, bool, Fingerprint] = True,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
@@ -828,6 +921,11 @@ class ClientSession:
# For the sake of backward compatibility, if user passes in None, convert it to True
if ssl is None:
+ warnings.warn(
+ "ssl=None is deprecated, please use ssl=True",
+ DeprecationWarning,
+ stacklevel=2,
+ )
ssl = True
ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)
@@ -922,6 +1020,16 @@ class ClientSession:
assert conn is not None
conn_proto = conn.protocol
assert conn_proto is not None
+
+ # For WS connection the read_timeout must be either receive_timeout or greater
+ # None == no timeout, i.e. infinite timeout, so None is the max timeout possible
+ if receive_timeout is None:
+ # Reset regardless
+ conn_proto.read_timeout = receive_timeout
+ elif conn_proto.read_timeout is not None:
+ # If read_timeout was set check which wins
+ conn_proto.read_timeout = max(receive_timeout, conn_proto.read_timeout)
+
transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
@@ -970,61 +1078,111 @@ class ClientSession:
added_names.add(key)
return result
- def get(
- self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
- ) -> "_RequestContextManager":
- """Perform HTTP GET request."""
- return _RequestContextManager(
- self._request(hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs)
- )
+ if sys.version_info >= (3, 11) and TYPE_CHECKING:
+
+ def get(
+ self,
+ url: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> "_RequestContextManager": ...
+
+ def options(
+ self,
+ url: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> "_RequestContextManager": ...
+
+ def head(
+ self,
+ url: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> "_RequestContextManager": ...
+
+ def post(
+ self,
+ url: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> "_RequestContextManager": ...
- def options(
- self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
- ) -> "_RequestContextManager":
- """Perform HTTP OPTIONS request."""
- return _RequestContextManager(
- self._request(
- hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs
+ def put(
+ self,
+ url: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> "_RequestContextManager": ...
+
+ def patch(
+ self,
+ url: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> "_RequestContextManager": ...
+
+ def delete(
+ self,
+ url: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> "_RequestContextManager": ...
+
+ else:
+
+ def get(
+ self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
+ ) -> "_RequestContextManager":
+ """Perform HTTP GET request."""
+ return _RequestContextManager(
+ self._request(
+ hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs
+ )
)
- )
- def head(
- self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any
- ) -> "_RequestContextManager":
- """Perform HTTP HEAD request."""
- return _RequestContextManager(
- self._request(
- hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs
+ def options(
+ self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
+ ) -> "_RequestContextManager":
+ """Perform HTTP OPTIONS request."""
+ return _RequestContextManager(
+ self._request(
+ hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs
+ )
)
- )
- def post(
- self, url: StrOrURL, *, data: Any = None, **kwargs: Any
- ) -> "_RequestContextManager":
- """Perform HTTP POST request."""
- return _RequestContextManager(
- self._request(hdrs.METH_POST, url, data=data, **kwargs)
- )
+ def head(
+ self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any
+ ) -> "_RequestContextManager":
+ """Perform HTTP HEAD request."""
+ return _RequestContextManager(
+ self._request(
+ hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs
+ )
+ )
- def put(
- self, url: StrOrURL, *, data: Any = None, **kwargs: Any
- ) -> "_RequestContextManager":
- """Perform HTTP PUT request."""
- return _RequestContextManager(
- self._request(hdrs.METH_PUT, url, data=data, **kwargs)
- )
+ def post(
+ self, url: StrOrURL, *, data: Any = None, **kwargs: Any
+ ) -> "_RequestContextManager":
+ """Perform HTTP POST request."""
+ return _RequestContextManager(
+ self._request(hdrs.METH_POST, url, data=data, **kwargs)
+ )
- def patch(
- self, url: StrOrURL, *, data: Any = None, **kwargs: Any
- ) -> "_RequestContextManager":
- """Perform HTTP PATCH request."""
- return _RequestContextManager(
- self._request(hdrs.METH_PATCH, url, data=data, **kwargs)
- )
+ def put(
+ self, url: StrOrURL, *, data: Any = None, **kwargs: Any
+ ) -> "_RequestContextManager":
+ """Perform HTTP PUT request."""
+ return _RequestContextManager(
+ self._request(hdrs.METH_PUT, url, data=data, **kwargs)
+ )
- def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager":
- """Perform HTTP DELETE request."""
- return _RequestContextManager(self._request(hdrs.METH_DELETE, url, **kwargs))
+ def patch(
+ self, url: StrOrURL, *, data: Any = None, **kwargs: Any
+ ) -> "_RequestContextManager":
+ """Perform HTTP PATCH request."""
+ return _RequestContextManager(
+ self._request(hdrs.METH_PATCH, url, data=data, **kwargs)
+ )
+
+ def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager":
+ """Perform HTTP DELETE request."""
+ return _RequestContextManager(
+ self._request(hdrs.METH_DELETE, url, **kwargs)
+ )
async def close(self) -> None:
"""Close underlying connector.
@@ -1175,7 +1333,7 @@ class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType
__slots__ = ("_coro", "_resp")
def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None:
- self._coro = coro
+ self._coro: Coroutine["asyncio.Future[Any]", None, _RetType] = coro
def send(self, arg: None) -> "asyncio.Future[Any]":
return self._coro.send(arg)
@@ -1194,12 +1352,8 @@ class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType
return self.__await__()
async def __aenter__(self) -> _RetType:
- self._resp = await self._coro
- return self._resp
-
-
-class _RequestContextManager(_BaseRequestContextManager[ClientResponse]):
- __slots__ = ()
+ self._resp: _RetType = await self._coro
+ return await self._resp.__aenter__()
async def __aexit__(
self,
@@ -1207,25 +1361,11 @@ class _RequestContextManager(_BaseRequestContextManager[ClientResponse]):
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
- # We're basing behavior on the exception as it can be caused by
- # user code unrelated to the status of the connection. If you
- # would like to close a connection you must do that
- # explicitly. Otherwise connection error handling should kick in
- # and close/recycle the connection as required.
- self._resp.release()
- await self._resp.wait_for_close()
-
+ await self._resp.__aexit__(exc_type, exc, tb)
-class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]):
- __slots__ = ()
- async def __aexit__(
- self,
- exc_type: Optional[Type[BaseException]],
- exc: Optional[BaseException],
- tb: Optional[TracebackType],
- ) -> None:
- await self._resp.close()
+_RequestContextManager = _BaseRequestContextManager[ClientResponse]
+_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse]
class _SessionRequestContextManager:
@@ -1265,7 +1405,7 @@ def request(
method: str,
url: StrOrURL,
*,
- params: Optional[Mapping[str, str]] = None,
+ params: Query = None,
data: Any = None,
json: Any = None,
headers: Optional[LooseHeaders] = None,
diff --git a/contrib/python/aiohttp/aiohttp/client_exceptions.py b/contrib/python/aiohttp/aiohttp/client_exceptions.py
index 9b6e44203c8..94991c42477 100644
--- a/contrib/python/aiohttp/aiohttp/client_exceptions.py
+++ b/contrib/python/aiohttp/aiohttp/client_exceptions.py
@@ -2,10 +2,11 @@
import asyncio
import warnings
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union
-from .http_parser import RawResponseMessage
-from .typedefs import LooseHeaders
+from multidict import MultiMapping
+
+from .typedefs import StrOrURL
try:
import ssl
@@ -17,18 +18,22 @@ except ImportError: # pragma: no cover
if TYPE_CHECKING:
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
+ from .http_parser import RawResponseMessage
else:
- RequestInfo = ClientResponse = ConnectionKey = None
+ RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None
__all__ = (
"ClientError",
"ClientConnectionError",
+ "ClientConnectionResetError",
"ClientOSError",
"ClientConnectorError",
"ClientProxyConnectionError",
"ClientSSLError",
"ClientConnectorSSLError",
"ClientConnectorCertificateError",
+ "ConnectionTimeoutError",
+ "SocketTimeoutError",
"ServerConnectionError",
"ServerTimeoutError",
"ServerDisconnectedError",
@@ -39,6 +44,11 @@ __all__ = (
"ContentTypeError",
"ClientPayloadError",
"InvalidURL",
+ "InvalidUrlClientError",
+ "RedirectClientError",
+ "NonHttpUrlClientError",
+ "InvalidUrlRedirectClientError",
+ "NonHttpUrlRedirectClientError",
)
@@ -64,7 +74,7 @@ class ClientResponseError(ClientError):
code: Optional[int] = None,
status: Optional[int] = None,
message: str = "",
- headers: Optional[LooseHeaders] = None,
+ headers: Optional[MultiMapping[str]] = None,
) -> None:
self.request_info = request_info
if code is not None:
@@ -93,7 +103,7 @@ class ClientResponseError(ClientError):
return "{}, message={!r}, url={!r}".format(
self.status,
self.message,
- self.request_info.real_url,
+ str(self.request_info.real_url),
)
def __repr__(self) -> str:
@@ -150,6 +160,10 @@ class ClientConnectionError(ClientError):
"""Base class for client socket errors."""
+class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
+ """ConnectionResetError"""
+
+
class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""
@@ -242,6 +256,14 @@ class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError):
"""Server timeout error."""
+class ConnectionTimeoutError(ServerTimeoutError):
+ """Connection timeout error."""
+
+
+class SocketTimeoutError(ServerTimeoutError):
+ """Socket timeout error."""
+
+
class ServerFingerprintMismatch(ServerConnectionError):
"""SSL certificate does not match expected fingerprint."""
@@ -271,17 +293,52 @@ class InvalidURL(ClientError, ValueError):
# Derive from ValueError for backward compatibility
- def __init__(self, url: Any) -> None:
+ def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None:
# The type of url is not yarl.URL because the exception can be raised
# on URL(url) call
- super().__init__(url)
+ self._url = url
+ self._description = description
+
+ if description:
+ super().__init__(url, description)
+ else:
+ super().__init__(url)
@property
- def url(self) -> Any:
- return self.args[0]
+ def url(self) -> StrOrURL:
+ return self._url
+
+ @property
+ def description(self) -> "str | None":
+ return self._description
def __repr__(self) -> str:
- return f"<{self.__class__.__name__} {self.url}>"
+ return f"<{self.__class__.__name__} {self}>"
+
+ def __str__(self) -> str:
+ if self._description:
+ return f"{self._url} - {self._description}"
+ return str(self._url)
+
+
+class InvalidUrlClientError(InvalidURL):
+ """Invalid URL client error."""
+
+
+class RedirectClientError(ClientError):
+ """Client redirect error."""
+
+
+class NonHttpUrlClientError(ClientError):
+ """Non http URL client error."""
+
+
+class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError):
+ """Invalid URL redirect client error."""
+
+
+class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError):
+ """Non http URL redirect client error."""
class ClientSSLError(ClientConnectorError):
diff --git a/contrib/python/aiohttp/aiohttp/client_proto.py b/contrib/python/aiohttp/aiohttp/client_proto.py
index 723f5aae5f4..8055811e40d 100644
--- a/contrib/python/aiohttp/aiohttp/client_proto.py
+++ b/contrib/python/aiohttp/aiohttp/client_proto.py
@@ -7,7 +7,7 @@ from .client_exceptions import (
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
- ServerTimeoutError,
+ SocketTimeoutError,
)
from .helpers import (
_EXC_SENTINEL,
@@ -50,15 +50,13 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
@property
def should_close(self) -> bool:
- if self._payload is not None and not self._payload.is_eof() or self._upgraded:
- return True
-
return (
self._should_close
+ or (self._payload is not None and not self._payload.is_eof())
or self._upgraded
- or self.exception() is not None
+ or self._exception is not None
or self._payload_parser is not None
- or len(self) > 0
+ or bool(self._buffer)
or bool(self._tail)
)
@@ -224,8 +222,16 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
def start_timeout(self) -> None:
self._reschedule_timeout()
+ @property
+ def read_timeout(self) -> Optional[float]:
+ return self._read_timeout
+
+ @read_timeout.setter
+ def read_timeout(self, read_timeout: Optional[float]) -> None:
+ self._read_timeout = read_timeout
+
def _on_read_timeout(self) -> None:
- exc = ServerTimeoutError("Timeout on reading data from socket")
+ exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
if self._payload is not None:
set_exception(self._payload, exc)
@@ -261,7 +267,15 @@ class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamRe
# closed in this case
self.transport.close()
# should_close is True after the call
- self.set_exception(HttpProcessingError(), underlying_exc)
+ if isinstance(underlying_exc, HttpProcessingError):
+ exc = HttpProcessingError(
+ code=underlying_exc.code,
+ message=underlying_exc.message,
+ headers=underlying_exc.headers,
+ )
+ else:
+ exc = HttpProcessingError()
+ self.set_exception(exc, underlying_exc)
return
self._upgraded = upgraded
diff --git a/contrib/python/aiohttp/aiohttp/client_reqrep.py b/contrib/python/aiohttp/aiohttp/client_reqrep.py
index afe719da16e..aa8f54e67b8 100644
--- a/contrib/python/aiohttp/aiohttp/client_reqrep.py
+++ b/contrib/python/aiohttp/aiohttp/client_reqrep.py
@@ -27,7 +27,7 @@ from typing import (
import attr
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
-from yarl import URL
+from yarl import URL, __version__ as yarl_version
from . import hdrs, helpers, http, multipart, payload
from .abc import AbstractStreamWriter
@@ -67,6 +67,7 @@ from .typedefs import (
JSONDecoder,
LooseCookies,
LooseHeaders,
+ Query,
RawHeaders,
)
@@ -88,6 +89,7 @@ if TYPE_CHECKING:
_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
+_YARL_SUPPORTS_EXTEND_QUERY = tuple(map(int, yarl_version.split(".")[:2])) >= (1, 11)
json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json")
@@ -209,7 +211,7 @@ def _merge_ssl_params(
return ssl
[email protected](auto_attribs=True, slots=True, frozen=True)
[email protected](auto_attribs=True, slots=True, frozen=True, cache_hash=True)
class ConnectionKey:
# the key should contain an information about used proxy / TLS
# to prevent reusing wrong connections from a pool
@@ -245,7 +247,8 @@ class ClientRequest:
hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(),
}
- body = b""
+ # Type of body depends on PAYLOAD_REGISTRY, which is dynamic.
+ body: Any = b""
auth = None
response = None
@@ -262,14 +265,14 @@ class ClientRequest:
method: str,
url: URL,
*,
- params: Optional[Mapping[str, str]] = None,
+ params: Query = None,
headers: Optional[LooseHeaders] = None,
- skip_auto_headers: Iterable[str] = frozenset(),
+ skip_auto_headers: Optional[Iterable[str]] = None,
data: Any = None,
cookies: Optional[LooseCookies] = None,
auth: Optional[BasicAuth] = None,
version: http.HttpVersion = http.HttpVersion11,
- compress: Optional[str] = None,
+ compress: Union[str, bool, None] = None,
chunked: Optional[bool] = None,
expect100: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
@@ -300,10 +303,13 @@ class ClientRequest:
# assert session is not None
self._session = cast("ClientSession", session)
if params:
- q = MultiDict(url.query)
- url2 = url.with_query(params)
- q.extend(url2.query)
- url = url.with_query(q)
+ if _YARL_SUPPORTS_EXTEND_QUERY:
+ url = url.extend_query(params)
+ else:
+ q = MultiDict(url.query)
+ url2 = url.with_query(params)
+ q.extend(url2.query)
+ url = url.with_query(q)
self.original_url = url
self.url = url.with_fragment(None)
self.method = method.upper()
@@ -352,7 +358,12 @@ class ClientRequest:
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
- if writer is not None:
+ if writer is None:
+ return
+ if writer.done():
+ # The writer is already done, so we can reset it immediately.
+ self.__reset_writer()
+ else:
writer.add_done_callback(self.__reset_writer)
def is_ssl(self) -> bool:
@@ -402,8 +413,8 @@ class ClientRequest:
# basic auth info
username, password = url.user, url.password
- if username:
- self.auth = helpers.BasicAuth(username, password or "")
+ if username or password:
+ self.auth = helpers.BasicAuth(username or "", password or "")
def update_version(self, version: Union[http.HttpVersion, str]) -> None:
"""Convert request version to two elements tuple.
@@ -436,7 +447,7 @@ class ClientRequest:
if headers:
if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
- headers = headers.items() # type: ignore[assignment]
+ headers = headers.items()
for key, value in headers: # type: ignore[misc]
# A special case for Host header
@@ -445,12 +456,18 @@ class ClientRequest:
else:
self.headers.add(key, value)
- def update_auto_headers(self, skip_auto_headers: Iterable[str]) -> None:
- self.skip_auto_headers = CIMultiDict(
- (hdr, None) for hdr in sorted(skip_auto_headers)
- )
- used_headers = self.headers.copy()
- used_headers.extend(self.skip_auto_headers) # type: ignore[arg-type]
+ def update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> None:
+ if skip_auto_headers is not None:
+ self.skip_auto_headers = CIMultiDict(
+ (hdr, None) for hdr in sorted(skip_auto_headers)
+ )
+ used_headers = self.headers.copy()
+ used_headers.extend(self.skip_auto_headers) # type: ignore[arg-type]
+ else:
+ # Fast path when there are no headers to skip
+ # which is the most common case.
+ self.skip_auto_headers = CIMultiDict()
+ used_headers = self.headers
for hdr, val in self.DEFAULT_HEADERS.items():
if hdr not in used_headers:
@@ -486,11 +503,12 @@ class ClientRequest:
def update_content_encoding(self, data: Any) -> None:
"""Set request content encoding."""
- if data is None:
+ if not data:
+ # Don't compress an empty body.
+ self.compress = None
return
- enc = self.headers.get(hdrs.CONTENT_ENCODING, "").lower()
- if enc:
+ if self.headers.get(hdrs.CONTENT_ENCODING):
if self.compress:
raise ValueError(
"compress can not be set " "if Content-Encoding header is set"
@@ -566,10 +584,8 @@ class ClientRequest:
# copy payload headers
assert body.headers
- for (key, value) in body.headers.items():
- if key in self.headers:
- continue
- if key in self.skip_auto_headers:
+ for key, value in body.headers.items():
+ if key in self.headers or key in self.skip_auto_headers:
continue
self.headers[key] = value
@@ -592,6 +608,10 @@ class ClientRequest:
raise ValueError("proxy_auth must be None or BasicAuth() tuple")
self.proxy = proxy
self.proxy_auth = proxy_auth
+ if proxy_headers is not None and not isinstance(
+ proxy_headers, (MultiDict, MultiDictProxy)
+ ):
+ proxy_headers = CIMultiDict(proxy_headers)
self.proxy_headers = proxy_headers
def keep_alive(self) -> bool:
@@ -614,11 +634,8 @@ class ClientRequest:
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
- try:
- await writer.drain()
- await self._continue
- except asyncio.CancelledError:
- return
+ await writer.drain()
+ await self._continue
protocol = conn.protocol
assert protocol is not None
@@ -627,10 +644,10 @@ class ClientRequest:
await self.body.write(writer)
else:
if isinstance(self.body, (bytes, bytearray)):
- self.body = (self.body,) # type: ignore[assignment]
+ self.body = (self.body,)
for chunk in self.body:
- await writer.write(chunk) # type: ignore[arg-type]
+ await writer.write(chunk)
except OSError as underlying_exc:
reraised_exc = underlying_exc
@@ -645,7 +662,9 @@ class ClientRequest:
set_exception(protocol, reraised_exc, underlying_exc)
except asyncio.CancelledError:
- await writer.write_eof()
+ # Body hasn't been fully sent, so connection can't be reused.
+ conn.close()
+ raise
except Exception as underlying_exc:
set_exception(
protocol,
@@ -681,16 +700,20 @@ class ClientRequest:
writer = StreamWriter(
protocol,
self.loop,
- on_chunk_sent=functools.partial(
- self._on_chunk_request_sent, self.method, self.url
+ on_chunk_sent=(
+ functools.partial(self._on_chunk_request_sent, self.method, self.url)
+ if self._traces
+ else None
),
- on_headers_sent=functools.partial(
- self._on_headers_request_sent, self.method, self.url
+ on_headers_sent=(
+ functools.partial(self._on_headers_request_sent, self.method, self.url)
+ if self._traces
+ else None
),
)
if self.compress:
- writer.enable_compression(self.compress)
+ writer.enable_compression(self.compress) # type: ignore[arg-type]
if self.chunked is not None:
writer.enable_chunking()
@@ -717,13 +740,20 @@ class ClientRequest:
self.headers[hdrs.CONNECTION] = connection
# status + headers
- status_line = "{0} {1} HTTP/{v.major}.{v.minor}".format(
- self.method, path, v=self.version
- )
+ v = self.version
+ status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"
await writer.write_headers(status_line, self.headers)
+ coro = self.write_bytes(writer, conn)
- self._writer = self.loop.create_task(self.write_bytes(writer, conn))
+ if sys.version_info >= (3, 12):
+ # Optimization for Python 3.12, try to write
+ # bytes immediately to avoid having to schedule
+ # the task on the event loop.
+ task = asyncio.Task(coro, loop=self.loop, eager_start=True)
+ else:
+ task = self.loop.create_task(coro)
+ self._writer = task
response_class = self.response_class
assert response_class is not None
self.response = response_class(
@@ -741,8 +771,15 @@ class ClientRequest:
async def close(self) -> None:
if self._writer is not None:
- with contextlib.suppress(asyncio.CancelledError):
+ try:
await self._writer
+ except asyncio.CancelledError:
+ if (
+ sys.version_info >= (3, 11)
+ and (task := asyncio.current_task())
+ and task.cancelling()
+ ):
+ raise
def terminate(self) -> None:
if self._writer is not None:
@@ -782,6 +819,7 @@ class ClientResponse(HeadersMixin):
# post-init stage allows to not change ctor signature
_closed = True # to allow __del__ for non-initialized properly response
_released = False
+ _in_context = False
__writer = None
def __init__(
@@ -820,9 +858,9 @@ class ClientResponse(HeadersMixin):
# work after the response has finished reading the body.
if session is None:
# TODO: Fix session=None in tests (see ClientRequest.__init__).
- self._resolve_charset: Callable[
- ["ClientResponse", bytes], str
- ] = lambda *_: "utf-8"
+ self._resolve_charset: Callable[["ClientResponse", bytes], str] = (
+ lambda *_: "utf-8"
+ )
else:
self._resolve_charset = session._resolve_charset
if loop.get_debug():
@@ -840,7 +878,12 @@ class ClientResponse(HeadersMixin):
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
- if writer is not None:
+ if writer is None:
+ return
+ if writer.done():
+ # The writer is already done, so we can reset it immediately.
+ self.__reset_writer()
+ else:
writer.add_done_callback(self.__reset_writer)
@reify
@@ -1066,7 +1109,12 @@ class ClientResponse(HeadersMixin):
if not self.ok:
# reason should always be not None for a started response
assert self.reason is not None
- self.release()
+
+ # If we're in a context we can rely on __aexit__() to release as the
+ # exception propagates.
+ if not self._in_context:
+ self.release()
+
raise ClientResponseError(
self.request_info,
self.history,
@@ -1085,7 +1133,15 @@ class ClientResponse(HeadersMixin):
async def _wait_released(self) -> None:
if self._writer is not None:
- await self._writer
+ try:
+ await self._writer
+ except asyncio.CancelledError:
+ if (
+ sys.version_info >= (3, 11)
+ and (task := asyncio.current_task())
+ and task.cancelling()
+ ):
+ raise
self._release_connection()
def _cleanup_writer(self) -> None:
@@ -1101,7 +1157,15 @@ class ClientResponse(HeadersMixin):
async def wait_for_close(self) -> None:
if self._writer is not None:
- await self._writer
+ try:
+ await self._writer
+ except asyncio.CancelledError:
+ if (
+ sys.version_info >= (3, 11)
+ and (task := asyncio.current_task())
+ and task.cancelling()
+ ):
+ raise
self.release()
async def read(self) -> bytes:
@@ -1130,7 +1194,7 @@ class ClientResponse(HeadersMixin):
encoding = mimetype.parameters.get("charset")
if encoding:
- with contextlib.suppress(LookupError):
+ with contextlib.suppress(LookupError, ValueError):
return codecs.lookup(encoding).name
if mimetype.type == "application" and (
@@ -1176,6 +1240,7 @@ class ClientResponse(HeadersMixin):
raise ContentTypeError(
self.request_info,
self.history,
+ status=self.status,
message=(
"Attempt to decode JSON with " "unexpected mimetype: %s" % ctype
),
@@ -1192,6 +1257,7 @@ class ClientResponse(HeadersMixin):
return loads(stripped.decode(encoding))
async def __aenter__(self) -> "ClientResponse":
+ self._in_context = True
return self
async def __aexit__(
@@ -1200,6 +1266,7 @@ class ClientResponse(HeadersMixin):
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
+ self._in_context = False
# similar to _RequestContextManager, we do not need to check
# for exceptions, response object can close connection
# if state is broken
diff --git a/contrib/python/aiohttp/aiohttp/client_ws.py b/contrib/python/aiohttp/aiohttp/client_ws.py
index d9c74a30f52..c6b5da5103b 100644
--- a/contrib/python/aiohttp/aiohttp/client_ws.py
+++ b/contrib/python/aiohttp/aiohttp/client_ws.py
@@ -2,11 +2,12 @@
import asyncio
import sys
-from typing import Any, Optional, cast
+from types import TracebackType
+from typing import Any, Optional, Type, cast
-from .client_exceptions import ClientError
+from .client_exceptions import ClientError, ServerTimeoutError
from .client_reqrep import ClientResponse
-from .helpers import call_later, set_result
+from .helpers import calculate_timeout_when, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
@@ -62,63 +63,123 @@ class ClientWebSocketResponse:
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
+ self._heartbeat_when: float = 0.0
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._loop = loop
- self._waiting: Optional[asyncio.Future[bool]] = None
+ self._waiting: bool = False
+ self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
+ self._ping_task: Optional[asyncio.Task[None]] = None
self._reset_heartbeat()
def _cancel_heartbeat(self) -> None:
- if self._pong_response_cb is not None:
- self._pong_response_cb.cancel()
- self._pong_response_cb = None
-
+ self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
+ if self._ping_task is not None:
+ self._ping_task.cancel()
+ self._ping_task = None
+
+ def _cancel_pong_response_cb(self) -> None:
+ if self._pong_response_cb is not None:
+ self._pong_response_cb.cancel()
+ self._pong_response_cb = None
def _reset_heartbeat(self) -> None:
- self._cancel_heartbeat()
+ if self._heartbeat is None:
+ return
+ self._cancel_pong_response_cb()
+ loop = self._loop
+ assert loop is not None
+ conn = self._conn
+ timeout_ceil_threshold = (
+ conn._connector._timeout_ceil_threshold if conn is not None else 5
+ )
+ now = loop.time()
+ when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
+ self._heartbeat_when = when
+ if self._heartbeat_cb is None:
+ # We do not cancel the previous heartbeat_cb here because
+ # it generates a significant amount of TimerHandle churn
+ # which causes asyncio to rebuild the heap frequently.
+ # Instead _send_heartbeat() will reschedule the next
+ # heartbeat if it fires too early.
+ self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
- if self._heartbeat is not None:
- self._heartbeat_cb = call_later(
- self._send_heartbeat,
- self._heartbeat,
- self._loop,
- timeout_ceil_threshold=self._conn._connector._timeout_ceil_threshold
- if self._conn is not None
- else 5,
+ def _send_heartbeat(self) -> None:
+ self._heartbeat_cb = None
+ loop = self._loop
+ now = loop.time()
+ if now < self._heartbeat_when:
+ # Heartbeat fired too early, reschedule
+ self._heartbeat_cb = loop.call_at(
+ self._heartbeat_when, self._send_heartbeat
)
+ return
- def _send_heartbeat(self) -> None:
- if self._heartbeat is not None and not self._closed:
- # fire-and-forget a task is not perfect but maybe ok for
- # sending ping. Otherwise we need a long-living heartbeat
- # task in the class.
- self._loop.create_task(self._writer.ping())
+ conn = self._conn
+ timeout_ceil_threshold = (
+ conn._connector._timeout_ceil_threshold if conn is not None else 5
+ )
+ when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
+ self._cancel_pong_response_cb()
+ self._pong_response_cb = loop.call_at(when, self._pong_not_received)
- if self._pong_response_cb is not None:
- self._pong_response_cb.cancel()
- self._pong_response_cb = call_later(
- self._pong_not_received,
- self._pong_heartbeat,
- self._loop,
- timeout_ceil_threshold=self._conn._connector._timeout_ceil_threshold
- if self._conn is not None
- else 5,
- )
+ if sys.version_info >= (3, 12):
+ # Optimization for Python 3.12, try to send the ping
+ # immediately to avoid having to schedule
+ # the task on the event loop.
+ ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
+ else:
+ ping_task = loop.create_task(self._writer.ping())
+
+ if not ping_task.done():
+ self._ping_task = ping_task
+ ping_task.add_done_callback(self._ping_task_done)
+ else:
+ self._ping_task_done(ping_task)
+
+ def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
+ """Callback for when the ping task completes."""
+ if not task.cancelled() and (exc := task.exception()):
+ self._handle_ping_pong_exception(exc)
+ self._ping_task = None
def _pong_not_received(self) -> None:
- if not self._closed:
- self._closed = True
- self._close_code = WSCloseCode.ABNORMAL_CLOSURE
- self._exception = asyncio.TimeoutError()
- self._response.close()
+ self._handle_ping_pong_exception(ServerTimeoutError())
+
+ def _handle_ping_pong_exception(self, exc: BaseException) -> None:
+ """Handle exceptions raised during ping/pong processing."""
+ if self._closed:
+ return
+ self._set_closed()
+ self._close_code = WSCloseCode.ABNORMAL_CLOSURE
+ self._exception = exc
+ self._response.close()
+ if self._waiting and not self._closing:
+ self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))
+
+ def _set_closed(self) -> None:
+ """Set the connection to closed.
+
+ Cancel any heartbeat timers and set the closed flag.
+ """
+ self._closed = True
+ self._cancel_heartbeat()
+
+ def _set_closing(self) -> None:
+ """Set the connection to closing.
+
+ Cancel any heartbeat timers and set the closing flag.
+ """
+ self._closing = True
+ self._cancel_heartbeat()
@property
def closed(self) -> bool:
@@ -181,14 +242,15 @@ class ClientWebSocketResponse:
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
# we need to break `receive()` cycle first,
# `close()` may be called from different task
- if self._waiting is not None and not self._closing:
- self._closing = True
+ if self._waiting and not self._closing:
+ assert self._loop is not None
+ self._close_wait = self._loop.create_future()
+ self._set_closing()
self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
- await self._waiting
+ await self._close_wait
if not self._closed:
- self._cancel_heartbeat()
- self._closed = True
+ self._set_closed()
try:
await self._writer.close(code, message)
except asyncio.CancelledError:
@@ -219,7 +281,7 @@ class ClientWebSocketResponse:
self._response.close()
return True
- if msg.type == WSMsgType.CLOSE:
+ if msg.type is WSMsgType.CLOSE:
self._close_code = msg.data
self._response.close()
return True
@@ -227,8 +289,10 @@ class ClientWebSocketResponse:
return False
async def receive(self, timeout: Optional[float] = None) -> WSMessage:
+ receive_timeout = timeout or self._receive_timeout
+
while True:
- if self._waiting is not None:
+ if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")
if self._closed:
@@ -238,15 +302,22 @@ class ClientWebSocketResponse:
return WS_CLOSED_MESSAGE
try:
- self._waiting = self._loop.create_future()
+ self._waiting = True
try:
- async with async_timeout.timeout(timeout or self._receive_timeout):
+ if receive_timeout:
+ # Entering the context manager and creating
+ # Timeout() object can take almost 50% of the
+ # run time in this loop so we avoid it if
+ # there is no read timeout.
+ async with async_timeout.timeout(receive_timeout):
+ msg = await self._reader.read()
+ else:
msg = await self._reader.read()
self._reset_heartbeat()
finally:
- waiter = self._waiting
- self._waiting = None
- set_result(waiter, True)
+ self._waiting = False
+ if self._close_wait:
+ set_result(self._close_wait, None)
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
@@ -255,7 +326,8 @@ class ClientWebSocketResponse:
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
except ClientError:
- self._closed = True
+ # Likely ServerDisconnectedError when connection is lost
+ self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
return WS_CLOSED_MESSAGE
except WebSocketError as exc:
@@ -264,35 +336,35 @@ class ClientWebSocketResponse:
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
- self._closing = True
+ self._set_closing()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
- if msg.type == WSMsgType.CLOSE:
- self._closing = True
+ if msg.type is WSMsgType.CLOSE:
+ self._set_closing()
self._close_code = msg.data
if not self._closed and self._autoclose:
await self.close()
- elif msg.type == WSMsgType.CLOSING:
- self._closing = True
- elif msg.type == WSMsgType.PING and self._autoping:
+ elif msg.type is WSMsgType.CLOSING:
+ self._set_closing()
+ elif msg.type is WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
- elif msg.type == WSMsgType.PONG and self._autoping:
+ elif msg.type is WSMsgType.PONG and self._autoping:
continue
return msg
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
- if msg.type != WSMsgType.TEXT:
+ if msg.type is not WSMsgType.TEXT:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not str")
return cast(str, msg.data)
async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
- if msg.type != WSMsgType.BINARY:
+ if msg.type is not WSMsgType.BINARY:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
return cast(bytes, msg.data)
@@ -313,3 +385,14 @@ class ClientWebSocketResponse:
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
raise StopAsyncIteration
return msg
+
+ async def __aenter__(self) -> "ClientWebSocketResponse":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
diff --git a/contrib/python/aiohttp/aiohttp/compression_utils.py b/contrib/python/aiohttp/aiohttp/compression_utils.py
index 9631d377e9a..ab4a2f1cc84 100644
--- a/contrib/python/aiohttp/aiohttp/compression_utils.py
+++ b/contrib/python/aiohttp/aiohttp/compression_utils.py
@@ -50,9 +50,11 @@ class ZLibCompressor(ZlibBaseHandler):
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
- mode=encoding_to_mode(encoding, suppress_deflate_header)
- if wbits is None
- else wbits,
+ mode=(
+ encoding_to_mode(encoding, suppress_deflate_header)
+ if wbits is None
+ else wbits
+ ),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
diff --git a/contrib/python/aiohttp/aiohttp/connector.py b/contrib/python/aiohttp/aiohttp/connector.py
index f95ebe84c66..8288a0115b7 100644
--- a/contrib/python/aiohttp/aiohttp/connector.py
+++ b/contrib/python/aiohttp/aiohttp/connector.py
@@ -1,6 +1,7 @@
import asyncio
import functools
import random
+import socket
import sys
import traceback
import warnings
@@ -22,6 +23,7 @@ from typing import (
List,
Literal,
Optional,
+ Sequence,
Set,
Tuple,
Type,
@@ -29,10 +31,11 @@ from typing import (
cast,
)
+import aiohappyeyeballs
import attr
from . import hdrs, helpers
-from .abc import AbstractResolver
+from .abc import AbstractResolver, ResolveResult
from .client_exceptions import (
ClientConnectionError,
ClientConnectorCertificateError,
@@ -47,7 +50,7 @@ from .client_exceptions import (
)
from .client_proto import ResponseHandler
from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
-from .helpers import ceil_timeout, get_running_loop, is_ip_address, noop, sentinel
+from .helpers import ceil_timeout, is_ip_address, noop, sentinel
from .locks import EventResultOrError
from .resolver import DefaultResolver
@@ -60,6 +63,14 @@ except ImportError: # pragma: no cover
SSLContext = object # type: ignore[misc,assignment]
+EMPTY_SCHEMA_SET = frozenset({""})
+HTTP_SCHEMA_SET = frozenset({"http", "https"})
+WS_SCHEMA_SET = frozenset({"ws", "wss"})
+
+HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET
+HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET
+
+
__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
@@ -208,6 +219,8 @@ class BaseConnector:
# abort transport after 2 seconds (cleanup broken connections)
_cleanup_closed_period = 2.0
+ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET
+
def __init__(
self,
*,
@@ -229,7 +242,7 @@ class BaseConnector:
if keepalive_timeout is sentinel:
keepalive_timeout = 15.0
- loop = get_running_loop(loop)
+ loop = loop or asyncio.get_event_loop()
self._timeout_ceil_threshold = timeout_ceil_threshold
self._closed = False
@@ -240,9 +253,9 @@ class BaseConnector:
self._limit = limit
self._limit_per_host = limit_per_host
self._acquired: Set[ResponseHandler] = set()
- self._acquired_per_host: DefaultDict[
- ConnectionKey, Set[ResponseHandler]
- ] = defaultdict(set)
+ self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = (
+ defaultdict(set)
+ )
self._keepalive_timeout = cast(float, keepalive_timeout)
self._force_close = force_close
@@ -691,14 +704,14 @@ class BaseConnector:
class _DNSCacheTable:
def __init__(self, ttl: Optional[float] = None) -> None:
- self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]] = {}
+ self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {}
self._timestamps: Dict[Tuple[str, int], float] = {}
self._ttl = ttl
def __contains__(self, host: object) -> bool:
return host in self._addrs_rr
- def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None:
+ def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None:
self._addrs_rr[key] = (cycle(addrs), len(addrs))
if self._ttl is not None:
@@ -714,7 +727,7 @@ class _DNSCacheTable:
self._addrs_rr.clear()
self._timestamps.clear()
- def next_addrs(self, key: Tuple[str, int]) -> List[Dict[str, Any]]:
+ def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]:
loop, length = self._addrs_rr[key]
addrs = list(islice(loop, length))
# Consume one more element to shift internal state of `cycle`
@@ -728,6 +741,35 @@ class _DNSCacheTable:
return self._timestamps[key] + self._ttl < monotonic()
+def _make_ssl_context(verified: bool) -> SSLContext:
+ """Create SSL context.
+
+ This method is not async-friendly and should be called from a thread
+ because it will load certificates from disk and do other blocking I/O.
+ """
+ if ssl is None:
+ # No ssl support
+ return None
+ if verified:
+ return ssl.create_default_context()
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ sslcontext.options |= ssl.OP_NO_SSLv2
+ sslcontext.options |= ssl.OP_NO_SSLv3
+ sslcontext.check_hostname = False
+ sslcontext.verify_mode = ssl.CERT_NONE
+ sslcontext.options |= ssl.OP_NO_COMPRESSION
+ sslcontext.set_default_verify_paths()
+ return sslcontext
+
+
+# The default SSLContext objects are created at import time
+# since they do blocking I/O to load certificates from disk,
+# and imports should always be done before the event loop starts
+# or in a thread.
+_SSL_CONTEXT_VERIFIED = _make_ssl_context(True)
+_SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False)
+
+
class TCPConnector(BaseConnector):
"""TCP connector.
@@ -735,7 +777,7 @@ class TCPConnector(BaseConnector):
fingerprint - Pass the binary sha256
digest of the expected certificate in DER format to verify
that the certificate the server presents matches. See also
- https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning
+ https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning
resolver - Enable DNS lookups and use this
resolver
use_dns_cache - Use memory cache for DNS lookups.
@@ -750,9 +792,15 @@ class TCPConnector(BaseConnector):
limit_per_host - Number of simultaneous connections to one host.
enable_cleanup_closed - Enables clean-up closed ssl transports.
Disabled by default.
+ happy_eyeballs_delay - This is the “Connection Attempt Delay”
+ as defined in RFC 8305. To disable
+ the happy eyeballs algorithm, set to None.
+ interleave - “First Address Family Count” as defined in RFC 8305
loop - Optional event loop.
"""
+ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
+
def __init__(
self,
*,
@@ -760,7 +808,7 @@ class TCPConnector(BaseConnector):
fingerprint: Optional[bytes] = None,
use_dns_cache: bool = True,
ttl_dns_cache: Optional[int] = 10,
- family: int = 0,
+ family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC,
ssl_context: Optional[SSLContext] = None,
ssl: Union[bool, Fingerprint, SSLContext] = True,
local_addr: Optional[Tuple[str, int]] = None,
@@ -772,6 +820,8 @@ class TCPConnector(BaseConnector):
enable_cleanup_closed: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
timeout_ceil_threshold: float = 5,
+ happy_eyeballs_delay: Optional[float] = 0.25,
+ interleave: Optional[int] = None,
):
super().__init__(
keepalive_timeout=keepalive_timeout,
@@ -792,13 +842,19 @@ class TCPConnector(BaseConnector):
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {}
self._family = family
- self._local_addr = local_addr
+ self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr)
+ self._happy_eyeballs_delay = happy_eyeballs_delay
+ self._interleave = interleave
+ self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
def close(self) -> Awaitable[None]:
"""Close all ongoing DNS calls."""
for ev in self._throttle_dns_events.values():
ev.cancel()
+ for t in self._resolve_host_tasks:
+ t.cancel()
+
return super().close()
@property
@@ -823,8 +879,8 @@ class TCPConnector(BaseConnector):
self._cached_hosts.clear()
async def _resolve_host(
- self, host: str, port: int, traces: Optional[List["Trace"]] = None
- ) -> List[Dict[str, Any]]:
+ self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None
+ ) -> List[ResolveResult]:
"""Resolve host and return list of addresses."""
if is_ip_address(host):
return [
@@ -876,11 +932,13 @@ class TCPConnector(BaseConnector):
resolved_host_task = asyncio.create_task(
self._resolve_host_with_throttle(key, host, port, traces)
)
+ self._resolve_host_tasks.add(resolved_host_task)
+ resolved_host_task.add_done_callback(self._resolve_host_tasks.discard)
try:
return await asyncio.shield(resolved_host_task)
except asyncio.CancelledError:
- def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
+ def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()
@@ -892,8 +950,8 @@ class TCPConnector(BaseConnector):
key: Tuple[str, int],
host: str,
port: int,
- traces: Optional[List["Trace"]],
- ) -> List[Dict[str, Any]]:
+ traces: Optional[Sequence["Trace"]],
+ ) -> List[ResolveResult]:
"""Resolve host with a dns events throttle."""
if key in self._throttle_dns_events:
# get event early, before any await (#4014)
@@ -945,29 +1003,6 @@ class TCPConnector(BaseConnector):
return proto
- @staticmethod
- @functools.lru_cache(None)
- def _make_ssl_context(verified: bool) -> SSLContext:
- if verified:
- return ssl.create_default_context()
- else:
- sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
- sslcontext.options |= ssl.OP_NO_SSLv2
- sslcontext.options |= ssl.OP_NO_SSLv3
- sslcontext.check_hostname = False
- sslcontext.verify_mode = ssl.CERT_NONE
- try:
- sslcontext.options |= ssl.OP_NO_COMPRESSION
- except AttributeError as attr_err:
- warnings.warn(
- "{!s}: The Python interpreter is compiled "
- "against OpenSSL < 1.0.0. Ref: "
- "https://docs.python.org/3/library/ssl.html"
- "#ssl.OP_NO_COMPRESSION".format(attr_err),
- )
- sslcontext.set_default_verify_paths()
- return sslcontext
-
def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
"""Logic to get the correct SSL context
@@ -982,25 +1017,25 @@ class TCPConnector(BaseConnector):
3. if verify_ssl is False in req, generate a SSL context that
won't verify
"""
- if req.is_ssl():
- if ssl is None: # pragma: no cover
- raise RuntimeError("SSL is not supported.")
- sslcontext = req.ssl
- if isinstance(sslcontext, ssl.SSLContext):
- return sslcontext
- if sslcontext is not True:
- # not verified or fingerprinted
- return self._make_ssl_context(False)
- sslcontext = self._ssl
- if isinstance(sslcontext, ssl.SSLContext):
- return sslcontext
- if sslcontext is not True:
- # not verified or fingerprinted
- return self._make_ssl_context(False)
- return self._make_ssl_context(True)
- else:
+ if not req.is_ssl():
return None
+ if ssl is None: # pragma: no cover
+ raise RuntimeError("SSL is not supported.")
+ sslcontext = req.ssl
+ if isinstance(sslcontext, ssl.SSLContext):
+ return sslcontext
+ if sslcontext is not True:
+ # not verified or fingerprinted
+ return _SSL_CONTEXT_UNVERIFIED
+ sslcontext = self._ssl
+ if isinstance(sslcontext, ssl.SSLContext):
+ return sslcontext
+ if sslcontext is not True:
+ # not verified or fingerprinted
+ return _SSL_CONTEXT_UNVERIFIED
+ return _SSL_CONTEXT_VERIFIED
+
def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
ret = req.ssl
if isinstance(ret, Fingerprint):
@@ -1013,6 +1048,36 @@ class TCPConnector(BaseConnector):
async def _wrap_create_connection(
self,
*args: Any,
+ addr_infos: List[aiohappyeyeballs.AddrInfoType],
+ req: ClientRequest,
+ timeout: "ClientTimeout",
+ client_error: Type[Exception] = ClientConnectorError,
+ **kwargs: Any,
+ ) -> Tuple[asyncio.Transport, ResponseHandler]:
+ try:
+ async with ceil_timeout(
+ timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
+ ):
+ sock = await aiohappyeyeballs.start_connection(
+ addr_infos=addr_infos,
+ local_addr_infos=self._local_addr_infos,
+ happy_eyeballs_delay=self._happy_eyeballs_delay,
+ interleave=self._interleave,
+ loop=self._loop,
+ )
+ return await self._loop.create_connection(*args, **kwargs, sock=sock)
+ except cert_errors as exc:
+ raise ClientConnectorCertificateError(req.connection_key, exc) from exc
+ except ssl_errors as exc:
+ raise ClientConnectorSSLError(req.connection_key, exc) from exc
+ except OSError as exc:
+ if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
+ raise
+ raise client_error(req.connection_key, exc) from exc
+
+ async def _wrap_existing_connection(
+ self,
+ *args: Any,
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
@@ -1121,13 +1186,11 @@ class TCPConnector(BaseConnector):
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
"""Wrap the raw TCP transport with TLS."""
tls_proto = self._factory() # Create a brand new proto for TLS
-
- # Safety of the `cast()` call here is based on the fact that
- # internally `_get_ssl_context()` only returns `None` when
- # `req.is_ssl()` evaluates to `False` which is never gonna happen
- # in this code path. Of course, it's rather fragile
- # maintainability-wise but this is to be solved separately.
- sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req))
+ sslcontext = self._get_ssl_context(req)
+ if TYPE_CHECKING:
+ # _start_tls_connection is unreachable in the current code path
+ # if sslcontext is None.
+ assert sslcontext is not None
try:
async with ceil_timeout(
@@ -1176,6 +1239,27 @@ class TCPConnector(BaseConnector):
return tls_transport, tls_proto
+ def _convert_hosts_to_addr_infos(
+ self, hosts: List[ResolveResult]
+ ) -> List[aiohappyeyeballs.AddrInfoType]:
+ """Converts the list of hosts to a list of addr_infos.
+
+ The list of hosts is the result of a DNS lookup. The list of
+ addr_infos is the result of a call to `socket.getaddrinfo()`.
+ """
+ addr_infos: List[aiohappyeyeballs.AddrInfoType] = []
+ for hinfo in hosts:
+ host = hinfo["host"]
+ is_ipv6 = ":" in host
+ family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
+ if self._family and self._family != family:
+ continue
+ addr = (host, hinfo["port"], 0, 0) if is_ipv6 else (host, hinfo["port"])
+ addr_infos.append(
+ (family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr)
+ )
+ return addr_infos
+
async def _create_direct_connection(
self,
req: ClientRequest,
@@ -1209,36 +1293,27 @@ class TCPConnector(BaseConnector):
raise ClientConnectorError(req.connection_key, exc) from exc
last_exc: Optional[Exception] = None
-
- for hinfo in hosts:
- host = hinfo["host"]
- port = hinfo["port"]
-
+ addr_infos = self._convert_hosts_to_addr_infos(hosts)
+ while addr_infos:
# Strip trailing dots, certificates contain FQDN without dots.
# See https://github.com/aio-libs/aiohttp/issues/3636
server_hostname = (
- (req.server_hostname or hinfo["hostname"]).rstrip(".")
- if sslcontext
- else None
+ (req.server_hostname or host).rstrip(".") if sslcontext else None
)
try:
transp, proto = await self._wrap_create_connection(
self._factory,
- host,
- port,
timeout=timeout,
ssl=sslcontext,
- family=hinfo["family"],
- proto=hinfo["proto"],
- flags=hinfo["flags"],
+ addr_infos=addr_infos,
server_hostname=server_hostname,
- local_addr=self._local_addr,
req=req,
client_error=client_error,
)
except ClientConnectorError as exc:
last_exc = exc
+ aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave)
continue
if req.is_ssl() and fingerprint:
@@ -1249,6 +1324,10 @@ class TCPConnector(BaseConnector):
if not self._cleanup_closed_disabled:
self._cleanup_closed_transports.append(transp)
last_exc = exc
+ # Remove the bad peer from the list of addr_infos
+ sock: socket.socket = transp.get_extra_info("socket")
+ bad_peer = sock.getpeername()
+ aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer)
continue
return transp, proto
@@ -1367,7 +1446,7 @@ class TCPConnector(BaseConnector):
if not runtime_has_start_tls:
# HTTP proxy with support for upgrade to HTTPS
sslcontext = self._get_ssl_context(req)
- return await self._wrap_create_connection(
+ return await self._wrap_existing_connection(
self._factory,
timeout=timeout,
ssl=sslcontext,
@@ -1401,6 +1480,8 @@ class UnixConnector(BaseConnector):
loop - Optional event loop.
"""
+ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})
+
def __init__(
self,
path: str,
@@ -1457,6 +1538,8 @@ class NamedPipeConnector(BaseConnector):
loop - Optional event loop.
"""
+ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})
+
def __init__(
self,
path: str,
diff --git a/contrib/python/aiohttp/aiohttp/cookiejar.py b/contrib/python/aiohttp/aiohttp/cookiejar.py
index a348f112cb5..c78d5fa7e72 100644
--- a/contrib/python/aiohttp/aiohttp/cookiejar.py
+++ b/contrib/python/aiohttp/aiohttp/cookiejar.py
@@ -2,6 +2,8 @@ import asyncio
import calendar
import contextlib
import datetime
+import heapq
+import itertools
import os # noqa
import pathlib
import pickle
@@ -9,8 +11,7 @@ import re
import time
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie
-from math import ceil
-from typing import ( # noqa
+from typing import (
DefaultDict,
Dict,
Iterable,
@@ -35,6 +36,15 @@ __all__ = ("CookieJar", "DummyCookieJar")
CookieItem = Union[str, "Morsel[str]"]
+# We cache these string methods here as their use is in performance critical code.
+_FORMAT_PATH = "{}/{}".format
+_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format
+
+# The minimum number of scheduled cookie expirations before we start cleaning up
+# the expiration heap. This is a performance optimization to avoid cleaning up the
+# heap too often when there are only a few scheduled expirations.
+_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
+
class CookieJar(AbstractCookieJar):
"""Implements cookie storage adhering to RFC 6265."""
@@ -85,6 +95,9 @@ class CookieJar(AbstractCookieJar):
self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
SimpleCookie
)
+ self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
+ defaultdict(dict)
+ )
self._host_only_cookies: Set[Tuple[str, str]] = set()
self._unsafe = unsafe
self._quote_cookie = quote_cookie
@@ -100,7 +113,7 @@ class CookieJar(AbstractCookieJar):
for url in treat_as_secure_origin
]
self._treat_as_secure_origin = treat_as_secure_origin
- self._next_expiration: float = ceil(time.time())
+ self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
self._expirations: Dict[Tuple[str, str, str], float] = {}
def save(self, file_path: PathLike) -> None:
@@ -115,34 +128,26 @@ class CookieJar(AbstractCookieJar):
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
if predicate is None:
- self._next_expiration = ceil(time.time())
+ self._expire_heap.clear()
self._cookies.clear()
+ self._morsel_cache.clear()
self._host_only_cookies.clear()
self._expirations.clear()
return
- to_del = []
now = time.time()
- for (domain, path), cookie in self._cookies.items():
- for name, morsel in cookie.items():
- key = (domain, path, name)
- if (
- key in self._expirations and self._expirations[key] <= now
- ) or predicate(morsel):
- to_del.append(key)
-
- for domain, path, name in to_del:
- self._host_only_cookies.discard((domain, name))
- key = (domain, path, name)
- if key in self._expirations:
- del self._expirations[(domain, path, name)]
- self._cookies[(domain, path)].pop(name, None)
-
- self._next_expiration = (
- min(*self._expirations.values(), self.SUB_MAX_TIME) + 1
- if self._expirations
- else self.MAX_TIME
- )
+ to_del = [
+ key
+ for (domain, path), cookie in self._cookies.items()
+ for name, morsel in cookie.items()
+ if (
+ (key := (domain, path, name)) in self._expirations
+ and self._expirations[key] <= now
+ )
+ or predicate(morsel)
+ ]
+ if to_del:
+ self._delete_cookies(to_del)
def clear_domain(self, domain: str) -> None:
self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
@@ -153,14 +158,70 @@ class CookieJar(AbstractCookieJar):
yield from val.values()
def __len__(self) -> int:
- return sum(1 for i in self)
+ """Return number of cookies.
+
+ This function does not iterate self to avoid unnecessary expiration
+ checks.
+ """
+ return sum(len(cookie.values()) for cookie in self._cookies.values())
def _do_expiration(self) -> None:
- self.clear(lambda x: False)
+ """Remove expired cookies."""
+ if not (expire_heap_len := len(self._expire_heap)):
+ return
+
+ # If the expiration heap grows larger than the number expirations
+ # times two, we clean it up to avoid keeping expired entries in
+ # the heap and consuming memory. We guard this with a minimum
+ # threshold to avoid cleaning up the heap too often when there are
+ # only a few scheduled expirations.
+ if (
+ expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
+ and expire_heap_len > len(self._expirations) * 2
+ ):
+ # Remove any expired entries from the expiration heap
+ # that do not match the expiration time in the expirations
+ # as it means the cookie has been re-added to the heap
+ # with a different expiration time.
+ self._expire_heap = [
+ entry
+ for entry in self._expire_heap
+ if self._expirations.get(entry[1]) == entry[0]
+ ]
+ heapq.heapify(self._expire_heap)
+
+ now = time.time()
+ to_del: List[Tuple[str, str, str]] = []
+ # Find any expired cookies and add them to the to-delete list
+ while self._expire_heap:
+ when, cookie_key = self._expire_heap[0]
+ if when > now:
+ break
+ heapq.heappop(self._expire_heap)
+ # Check if the cookie hasn't been re-added to the heap
+ # with a different expiration time as it will be removed
+ # later when it reaches the top of the heap and its
+ # expiration time is met.
+ if self._expirations.get(cookie_key) == when:
+ to_del.append(cookie_key)
+
+ if to_del:
+ self._delete_cookies(to_del)
+
+ def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
+ for domain, path, name in to_del:
+ self._host_only_cookies.discard((domain, name))
+ self._cookies[(domain, path)].pop(name, None)
+ self._morsel_cache[(domain, path)].pop(name, None)
+ self._expirations.pop((domain, path, name), None)
def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
- self._next_expiration = min(self._next_expiration, when)
- self._expirations[(domain, path, name)] = when
+ cookie_key = (domain, path, name)
+ if self._expirations.get(cookie_key) == when:
+ # Avoid adding duplicates to the heap
+ return
+ heapq.heappush(self._expire_heap, (when, cookie_key))
+ self._expirations[cookie_key] = when
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
"""Update cookies."""
@@ -182,7 +243,7 @@ class CookieJar(AbstractCookieJar):
domain = cookie["domain"]
# ignore domains with trailing dots
- if domain.endswith("."):
+ if domain and domain[-1] == ".":
domain = ""
del cookie["domain"]
@@ -192,7 +253,7 @@ class CookieJar(AbstractCookieJar):
self._host_only_cookies.add((hostname, name))
domain = cookie["domain"] = hostname
- if domain.startswith("."):
+ if domain and domain[0] == ".":
# Remove leading dot
domain = domain[1:]
cookie["domain"] = domain
@@ -202,7 +263,7 @@ class CookieJar(AbstractCookieJar):
continue
path = cookie["path"]
- if not path or not path.startswith("/"):
+ if not path or path[0] != "/":
# Set the cookie's path to the response path
path = response_url.path
if not path.startswith("/"):
@@ -211,9 +272,9 @@ class CookieJar(AbstractCookieJar):
# Cut everything from the last slash to the end
path = "/" + path[1 : path.rfind("/")]
cookie["path"] = path
+ path = path.rstrip("/")
- max_age = cookie["max-age"]
- if max_age:
+ if max_age := cookie["max-age"]:
try:
delta_seconds = int(max_age)
max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
@@ -221,16 +282,18 @@ class CookieJar(AbstractCookieJar):
except ValueError:
cookie["max-age"] = ""
- else:
- expires = cookie["expires"]
- if expires:
- expire_time = self._parse_date(expires)
- if expire_time:
- self._expire_cookie(expire_time, domain, path, name)
- else:
- cookie["expires"] = ""
+ elif expires := cookie["expires"]:
+ if expire_time := self._parse_date(expires):
+ self._expire_cookie(expire_time, domain, path, name)
+ else:
+ cookie["expires"] = ""
- self._cookies[(domain, path)][name] = cookie
+ key = (domain, path)
+ if self._cookies[key].get(name) != cookie:
+ # Don't blow away the cache if the same
+ # cookie gets set again
+ self._cookies[key][name] = cookie
+ self._morsel_cache[key].pop(name, None)
self._do_expiration()
@@ -256,36 +319,52 @@ class CookieJar(AbstractCookieJar):
request_origin = request_url.origin()
is_not_secure = request_origin not in self._treat_as_secure_origin
- # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
- for cookie in sorted(self, key=lambda c: len(c["path"])):
- name = cookie.key
- domain = cookie["domain"]
+ # Send shared cookie
+ for c in self._cookies[("", "")].values():
+ filtered[c.key] = c.value
- # Send shared cookies
- if not domain:
- filtered[name] = cookie.value
- continue
+ if is_ip_address(hostname):
+ if not self._unsafe:
+ return filtered
+ domains: Iterable[str] = (hostname,)
+ else:
+ # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
+ domains = itertools.accumulate(
+ reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
+ )
- if not self._unsafe and is_ip_address(hostname):
- continue
+ # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
+ paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
+ # Create every combination of (domain, path) pairs.
+ pairs = itertools.product(domains, paths)
+
+ path_len = len(request_url.path)
+ # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
+ for p in pairs:
+ for name, cookie in self._cookies[p].items():
+ domain = cookie["domain"]
- if (domain, name) in self._host_only_cookies:
- if domain != hostname:
+ if (domain, name) in self._host_only_cookies and domain != hostname:
continue
- elif not self._is_domain_match(domain, hostname):
- continue
- if not self._is_path_match(request_url.path, cookie["path"]):
- continue
+ # Skip edge case when the cookie has a trailing slash but request doesn't.
+ if len(cookie["path"]) > path_len:
+ continue
- if is_not_secure and cookie["secure"]:
- continue
+ if is_not_secure and cookie["secure"]:
+ continue
- # It's critical we use the Morsel so the coded_value
- # (based on cookie version) is preserved
- mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
- mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
- filtered[name] = mrsl_val
+ # We already built the Morsel so reuse it here
+ if name in self._morsel_cache[p]:
+ filtered[name] = self._morsel_cache[p][name]
+ continue
+
+ # It's critical we use the Morsel so the coded_value
+ # (based on cookie version) is preserved
+ mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
+ mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
+ self._morsel_cache[p][name] = mrsl_val
+ filtered[name] = mrsl_val
return filtered
@@ -305,25 +384,6 @@ class CookieJar(AbstractCookieJar):
return not is_ip_address(hostname)
- @staticmethod
- def _is_path_match(req_path: str, cookie_path: str) -> bool:
- """Implements path matching adhering to RFC 6265."""
- if not req_path.startswith("/"):
- req_path = "/"
-
- if req_path == cookie_path:
- return True
-
- if not req_path.startswith(cookie_path):
- return False
-
- if cookie_path.endswith("/"):
- return True
-
- non_matching = req_path[len(cookie_path) :]
-
- return non_matching.startswith("/")
-
@classmethod
def _parse_date(cls, date_str: str) -> Optional[int]:
"""Implements date string parsing adhering to RFC 6265."""
diff --git a/contrib/python/aiohttp/aiohttp/helpers.py b/contrib/python/aiohttp/aiohttp/helpers.py
index 284033b7a04..40705b16d71 100644
--- a/contrib/python/aiohttp/aiohttp/helpers.py
+++ b/contrib/python/aiohttp/aiohttp/helpers.py
@@ -14,7 +14,6 @@ import platform
import re
import sys
import time
-import warnings
import weakref
from collections import namedtuple
from contextlib import suppress
@@ -35,7 +34,6 @@ from typing import (
List,
Mapping,
Optional,
- Pattern,
Protocol,
Tuple,
Type,
@@ -52,7 +50,7 @@ from multidict import MultiDict, MultiDictProxy, MultiMapping
from yarl import URL
from . import hdrs
-from .log import client_logger, internal_logger
+from .log import client_logger
if sys.version_info >= (3, 11):
import asyncio as async_timeout
@@ -165,9 +163,9 @@ class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
"""Create BasicAuth from url."""
if not isinstance(url, URL):
raise TypeError("url should be yarl.URL instance")
- if url.user is None:
+ if url.user is None and url.password is None:
return None
- return cls(url.user, url.password or "", encoding=encoding)
+ return cls(url.user or "", url.password or "", encoding=encoding)
def encode(self) -> str:
"""Encode credentials."""
@@ -287,38 +285,6 @@ def proxies_from_env() -> Dict[str, ProxyInfo]:
return ret
-def current_task(
- loop: Optional[asyncio.AbstractEventLoop] = None,
-) -> "Optional[asyncio.Task[Any]]":
- return asyncio.current_task(loop=loop)
-
-
-def get_running_loop(
- loop: Optional[asyncio.AbstractEventLoop] = None,
-) -> asyncio.AbstractEventLoop:
- if loop is None:
- loop = asyncio.get_event_loop()
- if not loop.is_running():
- warnings.warn(
- "The object should be created within an async function",
- DeprecationWarning,
- stacklevel=3,
- )
- if loop.get_debug():
- internal_logger.warning(
- "The object should be created within an async function", stack_info=True
- )
- return loop
-
-
-def isasyncgenfunction(obj: Any) -> bool:
- func = getattr(inspect, "isasyncgenfunction", None)
- if func is not None:
- return func(obj) # type: ignore[no-any-return]
- else:
- return False
-
-
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
"""Get a permitted proxy for the given URL from the env."""
if url.host is not None and proxy_bypass(url.host):
@@ -504,44 +470,51 @@ try:
except ImportError:
pass
-_ipv4_pattern = (
- r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
- r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
-)
-_ipv6_pattern = (
- r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}"
- r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)"
- r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})"
- r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}"
- r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}"
- r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)"
- r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}"
- r":|:(:[A-F0-9]{1,4}){7})$"
-)
-_ipv4_regex = re.compile(_ipv4_pattern)
-_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
-_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii"))
-_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE)
+def is_ipv4_address(host: Optional[Union[str, bytes]]) -> bool:
+ """Check if host looks like an IPv4 address.
-def _is_ip_address(
- regex: Pattern[str], regexb: Pattern[bytes], host: Optional[Union[str, bytes]]
-) -> bool:
- if host is None:
+ This function does not validate that the format is correct, only that
+ the host is a str or bytes, and its all numeric.
+
+ This check is only meant as a heuristic to ensure that
+ a host is not a domain name.
+ """
+ if not host:
return False
+ # For a host to be an ipv4 address, it must be all numeric.
if isinstance(host, str):
- return bool(regex.match(host))
- elif isinstance(host, (bytes, bytearray, memoryview)):
- return bool(regexb.match(host))
- else:
- raise TypeError(f"{host} [{type(host)}] is not a str or bytes")
+ return host.replace(".", "").isdigit()
+ if isinstance(host, (bytes, bytearray, memoryview)):
+ return host.decode("ascii").replace(".", "").isdigit()
+ raise TypeError(f"{host} [{type(host)}] is not a str or bytes")
+
+def is_ipv6_address(host: Optional[Union[str, bytes]]) -> bool:
+ """Check if host looks like an IPv6 address.
-is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
-is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)
+ This function does not validate that the format is correct, only that
+ the host contains a colon and that it is a str or bytes.
+
+ This check is only meant as a heuristic to ensure that
+ a host is not a domain name.
+ """
+ if not host:
+ return False
+ # The host must contain a colon to be an IPv6 address.
+ if isinstance(host, str):
+ return ":" in host
+ if isinstance(host, (bytes, bytearray, memoryview)):
+ return b":" in host
+ raise TypeError(f"{host} [{type(host)}] is not a str or bytes")
def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
+ """Check if host looks like an IP Address.
+
+ This check is only meant as a heuristic to ensure that
+ a host is not a domain name.
+ """
return is_ipv4_address(host) or is_ipv6_address(host)
@@ -619,12 +592,23 @@ def call_later(
loop: asyncio.AbstractEventLoop,
timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
- if timeout is not None and timeout > 0:
- when = loop.time() + timeout
- if timeout > timeout_ceil_threshold:
- when = ceil(when)
- return loop.call_at(when, cb)
- return None
+ if timeout is None or timeout <= 0:
+ return None
+ now = loop.time()
+ when = calculate_timeout_when(now, timeout, timeout_ceil_threshold)
+ return loop.call_at(when, cb)
+
+
+def calculate_timeout_when(
+ loop_time: float,
+ timeout: float,
+ timeout_ceiling_threshold: float,
+) -> float:
+ """Calculate when to execute a timeout."""
+ when = loop_time + timeout
+ if timeout > timeout_ceiling_threshold:
+ return ceil(when)
+ return when
class TimeoutHandle:
@@ -651,7 +635,7 @@ class TimeoutHandle:
def close(self) -> None:
self._callbacks.clear()
- def start(self) -> Optional[asyncio.Handle]:
+ def start(self) -> Optional[asyncio.TimerHandle]:
timeout = self._timeout
if timeout is not None and timeout > 0:
when = self._loop.time() + timeout
@@ -709,7 +693,7 @@ class TimerContext(BaseTimerContext):
raise asyncio.TimeoutError from None
def __enter__(self) -> BaseTimerContext:
- task = current_task(loop=self._loop)
+ task = asyncio.current_task(loop=self._loop)
if task is None:
raise RuntimeError(
@@ -749,7 +733,7 @@ def ceil_timeout(
if delay is None or delay <= 0:
return async_timeout.timeout(None)
- loop = get_running_loop()
+ loop = asyncio.get_running_loop()
now = loop.time()
when = now + delay
if delay > ceil_threshold:
@@ -784,7 +768,8 @@ class HeadersMixin:
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
- return self._content_type # type: ignore[return-value]
+ assert self._content_type is not None
+ return self._content_type
@property
def charset(self) -> Optional[str]:
@@ -792,7 +777,8 @@ class HeadersMixin:
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
- return self._content_dict.get("charset") # type: ignore[union-attr]
+ assert self._content_dict is not None
+ return self._content_dict.get("charset")
@property
def content_length(self) -> Optional[int]:
@@ -818,8 +804,7 @@ class ErrorableProtocol(Protocol):
self,
exc: BaseException,
exc_cause: BaseException = ...,
- ) -> None:
- ... # pragma: no cover
+ ) -> None: ... # pragma: no cover
def set_exception(
@@ -905,12 +890,10 @@ class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
)
@overload # type: ignore[override]
- def __getitem__(self, key: AppKey[_T]) -> _T:
- ...
+ def __getitem__(self, key: AppKey[_T]) -> _T: ...
@overload
- def __getitem__(self, key: str) -> Any:
- ...
+ def __getitem__(self, key: str) -> Any: ...
def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
for mapping in self._maps:
@@ -921,16 +904,13 @@ class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
raise KeyError(key)
@overload # type: ignore[override]
- def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]:
- ...
+ def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...
@overload
- def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]:
- ...
+ def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...
@overload
- def get(self, key: str, default: Any = ...) -> Any:
- ...
+ def get(self, key: str, default: Any = ...) -> Any: ...
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
try:
@@ -993,6 +973,7 @@ def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
return None
def must_be_empty_body(method: str, code: int) -> bool:
"""Check if a request must return an empty body."""
return (
diff --git a/contrib/python/aiohttp/aiohttp/http_exceptions.py b/contrib/python/aiohttp/aiohttp/http_exceptions.py
index 72eac3a3cac..c43ee0d9659 100644
--- a/contrib/python/aiohttp/aiohttp/http_exceptions.py
+++ b/contrib/python/aiohttp/aiohttp/http_exceptions.py
@@ -1,6 +1,5 @@
"""Low-level http related exceptions."""
-
from textwrap import indent
from typing import Optional, Union
diff --git a/contrib/python/aiohttp/aiohttp/http_parser.py b/contrib/python/aiohttp/aiohttp/http_parser.py
index 013511917e8..686a2d02e28 100644
--- a/contrib/python/aiohttp/aiohttp/http_parser.py
+++ b/contrib/python/aiohttp/aiohttp/http_parser.py
@@ -47,7 +47,6 @@ from .http_exceptions import (
TransferEncodingError,
)
from .http_writer import HttpVersion, HttpVersion10
-from .log import internal_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import RawHeaders
@@ -249,7 +248,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
timer: Optional[BaseTimerContext] = None,
code: Optional[int] = None,
method: Optional[str] = None,
- readall: bool = False,
payload_exception: Optional[Type[BaseException]] = None,
response_with_body: bool = True,
read_until_eof: bool = False,
@@ -263,7 +261,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
self.timer = timer
self.code = code
self.method = method
- self.readall = readall
self.payload_exception = payload_exception
self.response_with_body = response_with_body
self.read_until_eof = read_until_eof
@@ -280,8 +277,10 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
)
@abc.abstractmethod
- def parse_message(self, lines: List[bytes]) -> _MsgT:
- pass
+ def parse_message(self, lines: List[bytes]) -> _MsgT: ...
+
+ @abc.abstractmethod
+ def _is_chunked_te(self, te: str) -> bool: ...
def feed_eof(self) -> Optional[_MsgT]:
if self._payload_parser is not None:
@@ -318,6 +317,7 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
start_pos = 0
loop = self.loop
+ should_close = False
while start_pos < data_len:
# read HTTP message (request/response line + headers), \r\n\r\n
@@ -330,6 +330,9 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
continue
if pos >= start_pos:
+ if should_close:
+ raise BadHttpMessage("Data after `Connection: close`")
+
# line found
line = data[start_pos:pos]
if SEP == b"\n": # For lax response parsing
@@ -393,7 +396,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
method=method,
compression=msg.compression,
code=self.code,
- readall=self.readall,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
@@ -413,7 +415,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
payload,
method=msg.method,
compression=msg.compression,
- readall=True,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
@@ -431,7 +432,6 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
method=method,
compression=msg.compression,
code=self.code,
- readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
@@ -442,6 +442,7 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
payload = EMPTY_PAYLOAD
messages.append((msg, payload))
+ should_close = msg.should_close
else:
self._tail = data[start_pos:]
data = EMPTY
@@ -543,10 +544,8 @@ class HttpParser(abc.ABC, Generic[_MsgT]):
# chunking
te = headers.get(hdrs.TRANSFER_ENCODING)
if te is not None:
- if "chunked" == te.lower():
+ if self._is_chunked_te(te):
chunked = True
- else:
- raise BadHttpMessage("Request has invalid `Transfer-Encoding`")
if hdrs.CONTENT_LENGTH in headers:
raise BadHttpMessage(
@@ -656,6 +655,12 @@ class HttpRequestParser(HttpParser[RawRequestMessage]):
url,
)
+ def _is_chunked_te(self, te: str) -> bool:
+ if te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked":
+ return True
+ # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3
+ raise BadHttpMessage("Request has invalid `Transfer-Encoding`")
+
class HttpResponseParser(HttpParser[RawResponseMessage]):
"""Read response status line and headers.
@@ -741,6 +746,10 @@ class HttpResponseParser(HttpParser[RawResponseMessage]):
chunked,
)
+ def _is_chunked_te(self, te: str) -> bool:
+ # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.2
+ return te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked"
+
class HttpPayloadParser:
def __init__(
@@ -751,13 +760,12 @@ class HttpPayloadParser:
compression: Optional[str] = None,
code: Optional[int] = None,
method: Optional[str] = None,
- readall: bool = False,
response_with_body: bool = True,
auto_decompress: bool = True,
lax: bool = False,
) -> None:
self._length = 0
- self._type = ParseState.PARSE_NONE
+ self._type = ParseState.PARSE_UNTIL_EOF
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
self._chunk_size = 0
self._chunk_tail = b""
@@ -779,7 +787,6 @@ class HttpPayloadParser:
self._type = ParseState.PARSE_NONE
real_payload.feed_eof()
self.done = True
-
elif chunked:
self._type = ParseState.PARSE_CHUNKED
elif length is not None:
@@ -788,16 +795,6 @@ class HttpPayloadParser:
if self._length == 0:
real_payload.feed_eof()
self.done = True
- else:
- if readall and code != 204:
- self._type = ParseState.PARSE_UNTIL_EOF
- elif method in ("PUT", "POST"):
- internal_logger.warning( # pragma: no cover
- "Content-Length or Transfer-Encoding header is required"
- )
- self._type = ParseState.PARSE_NONE
- real_payload.feed_eof()
- self.done = True
self.payload = real_payload
@@ -888,13 +885,13 @@ class HttpPayloadParser:
self._chunk_size = 0
self.payload.feed_data(chunk[:required], required)
chunk = chunk[required:]
- if self._lax and chunk.startswith(b"\r"):
- chunk = chunk[1:]
self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
self.payload.end_http_chunk_receiving()
# toss the CRLF at the end of the chunk
if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
+ if self._lax and chunk.startswith(b"\r"):
+ chunk = chunk[1:]
if chunk[: len(SEP)] == SEP:
chunk = chunk[len(SEP) :]
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
diff --git a/contrib/python/aiohttp/aiohttp/http_websocket.py b/contrib/python/aiohttp/aiohttp/http_websocket.py
index 39f2e4a5c15..fb00ebc7d35 100644
--- a/contrib/python/aiohttp/aiohttp/http_websocket.py
+++ b/contrib/python/aiohttp/aiohttp/http_websocket.py
@@ -8,6 +8,7 @@ import re
import sys
import zlib
from enum import IntEnum
+from functools import partial
from struct import Struct
from typing import (
Any,
@@ -24,6 +25,7 @@ from typing import (
)
from .base_protocol import BaseProtocol
+from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS, set_exception
from .streams import DataQueue
@@ -93,6 +95,14 @@ class WSMsgType(IntEnum):
error = ERROR
+MESSAGE_TYPES_WITH_CONTENT: Final = frozenset(
+ {
+ WSMsgType.BINARY,
+ WSMsgType.TEXT,
+ WSMsgType.CONTINUATION,
+ }
+)
+
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
@@ -103,8 +113,10 @@ PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
PACK_CLOSE_CODE = Struct("!H").pack
+PACK_RANDBITS = Struct("!L").pack
MSG_SIZE: Final[int] = 2**14
DEFAULT_LIMIT: Final[int] = 2**16
+MASK_LEN: Final[int] = 4
class WSMessage(NamedTuple):
@@ -294,7 +306,7 @@ class WebSocketReader:
self._frame_opcode: Optional[int] = None
self._frame_payload = bytearray()
- self._tail = b""
+ self._tail: bytes = b""
self._has_mask = False
self._frame_mask: Optional[bytes] = None
self._payload_length = 0
@@ -311,17 +323,101 @@ class WebSocketReader:
return True, data
try:
- return self._feed_data(data)
+ self._feed_data(data)
except Exception as exc:
self._exc = exc
set_exception(self.queue, exc)
return True, b""
- def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
+ return False, b""
+
+ def _feed_data(self, data: bytes) -> None:
for fin, opcode, payload, compressed in self.parse_frame(data):
- if compressed and not self._decompressobj:
- self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
- if opcode == WSMsgType.CLOSE:
+ if opcode in MESSAGE_TYPES_WITH_CONTENT:
+ # load text/binary
+ is_continuation = opcode == WSMsgType.CONTINUATION
+ if not fin:
+ # got partial frame payload
+ if not is_continuation:
+ self._opcode = opcode
+ self._partial += payload
+ if self._max_msg_size and len(self._partial) >= self._max_msg_size:
+ raise WebSocketError(
+ WSCloseCode.MESSAGE_TOO_BIG,
+ "Message size {} exceeds limit {}".format(
+ len(self._partial), self._max_msg_size
+ ),
+ )
+ continue
+
+ has_partial = bool(self._partial)
+ if is_continuation:
+ if self._opcode is None:
+ raise WebSocketError(
+ WSCloseCode.PROTOCOL_ERROR,
+ "Continuation frame for non started message",
+ )
+ opcode = self._opcode
+ self._opcode = None
+ # previous frame was non finished
+ # we should get continuation opcode
+ elif has_partial:
+ raise WebSocketError(
+ WSCloseCode.PROTOCOL_ERROR,
+ "The opcode in non-fin frame is expected "
+ "to be zero, got {!r}".format(opcode),
+ )
+
+ if has_partial:
+ assembled_payload = self._partial + payload
+ self._partial.clear()
+ else:
+ assembled_payload = payload
+
+ if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
+ raise WebSocketError(
+ WSCloseCode.MESSAGE_TOO_BIG,
+ "Message size {} exceeds limit {}".format(
+ len(assembled_payload), self._max_msg_size
+ ),
+ )
+
+ # Decompress process must to be done after all packets
+ # received.
+ if compressed:
+ if not self._decompressobj:
+ self._decompressobj = ZLibDecompressor(
+ suppress_deflate_header=True
+ )
+ payload_merged = self._decompressobj.decompress_sync(
+ assembled_payload + _WS_DEFLATE_TRAILING, self._max_msg_size
+ )
+ if self._decompressobj.unconsumed_tail:
+ left = len(self._decompressobj.unconsumed_tail)
+ raise WebSocketError(
+ WSCloseCode.MESSAGE_TOO_BIG,
+ "Decompressed message size {} exceeds limit {}".format(
+ self._max_msg_size + left, self._max_msg_size
+ ),
+ )
+ else:
+ payload_merged = bytes(assembled_payload)
+
+ if opcode == WSMsgType.TEXT:
+ try:
+ text = payload_merged.decode("utf-8")
+ except UnicodeDecodeError as exc:
+ raise WebSocketError(
+ WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
+ ) from exc
+
+ self.queue.feed_data(WSMessage(WSMsgType.TEXT, text, ""), len(text))
+ continue
+
+ self.queue.feed_data(
+ WSMessage(WSMsgType.BINARY, payload_merged, ""), len(payload_merged)
+ )
+ elif opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
@@ -356,241 +452,145 @@ class WebSocketReader:
WSMessage(WSMsgType.PONG, payload, ""), len(payload)
)
- elif (
- opcode not in (WSMsgType.TEXT, WSMsgType.BINARY)
- and self._opcode is None
- ):
+ else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
- else:
- # load text/binary
- if not fin:
- # got partial frame payload
- if opcode != WSMsgType.CONTINUATION:
- self._opcode = opcode
- self._partial.extend(payload)
- if self._max_msg_size and len(self._partial) >= self._max_msg_size:
- raise WebSocketError(
- WSCloseCode.MESSAGE_TOO_BIG,
- "Message size {} exceeds limit {}".format(
- len(self._partial), self._max_msg_size
- ),
- )
- else:
- # previous frame was non finished
- # we should get continuation opcode
- if self._partial:
- if opcode != WSMsgType.CONTINUATION:
- raise WebSocketError(
- WSCloseCode.PROTOCOL_ERROR,
- "The opcode in non-fin frame is expected "
- "to be zero, got {!r}".format(opcode),
- )
-
- if opcode == WSMsgType.CONTINUATION:
- assert self._opcode is not None
- opcode = self._opcode
- self._opcode = None
-
- self._partial.extend(payload)
- if self._max_msg_size and len(self._partial) >= self._max_msg_size:
- raise WebSocketError(
- WSCloseCode.MESSAGE_TOO_BIG,
- "Message size {} exceeds limit {}".format(
- len(self._partial), self._max_msg_size
- ),
- )
-
- # Decompress process must to be done after all packets
- # received.
- if compressed:
- assert self._decompressobj is not None
- self._partial.extend(_WS_DEFLATE_TRAILING)
- payload_merged = self._decompressobj.decompress_sync(
- self._partial, self._max_msg_size
- )
- if self._decompressobj.unconsumed_tail:
- left = len(self._decompressobj.unconsumed_tail)
- raise WebSocketError(
- WSCloseCode.MESSAGE_TOO_BIG,
- "Decompressed message size {} exceeds limit {}".format(
- self._max_msg_size + left, self._max_msg_size
- ),
- )
- else:
- payload_merged = bytes(self._partial)
-
- self._partial.clear()
-
- if opcode == WSMsgType.TEXT:
- try:
- text = payload_merged.decode("utf-8")
- self.queue.feed_data(
- WSMessage(WSMsgType.TEXT, text, ""), len(text)
- )
- except UnicodeDecodeError as exc:
- raise WebSocketError(
- WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
- ) from exc
- else:
- self.queue.feed_data(
- WSMessage(WSMsgType.BINARY, payload_merged, ""),
- len(payload_merged),
- )
-
- return False, b""
def parse_frame(
self, buf: bytes
) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
"""Return the next frame from the socket."""
- frames = []
+ frames: List[Tuple[bool, Optional[int], bytearray, Optional[bool]]] = []
if self._tail:
buf, self._tail = self._tail + buf, b""
- start_pos = 0
+ start_pos: int = 0
buf_length = len(buf)
while True:
# read header
- if self._state == WSParserState.READ_HEADER:
- if buf_length - start_pos >= 2:
- data = buf[start_pos : start_pos + 2]
- start_pos += 2
- first_byte, second_byte = data
+ if self._state is WSParserState.READ_HEADER:
+ if buf_length - start_pos < 2:
+ break
+ data = buf[start_pos : start_pos + 2]
+ start_pos += 2
+ first_byte, second_byte = data
- fin = (first_byte >> 7) & 1
- rsv1 = (first_byte >> 6) & 1
- rsv2 = (first_byte >> 5) & 1
- rsv3 = (first_byte >> 4) & 1
- opcode = first_byte & 0xF
+ fin = (first_byte >> 7) & 1
+ rsv1 = (first_byte >> 6) & 1
+ rsv2 = (first_byte >> 5) & 1
+ rsv3 = (first_byte >> 4) & 1
+ opcode = first_byte & 0xF
- # frame-fin = %x0 ; more frames of this message follow
- # / %x1 ; final frame of this message
- # frame-rsv1 = %x0 ;
- # 1 bit, MUST be 0 unless negotiated otherwise
- # frame-rsv2 = %x0 ;
- # 1 bit, MUST be 0 unless negotiated otherwise
- # frame-rsv3 = %x0 ;
- # 1 bit, MUST be 0 unless negotiated otherwise
- #
- # Remove rsv1 from this test for deflate development
- if rsv2 or rsv3 or (rsv1 and not self._compress):
- raise WebSocketError(
- WSCloseCode.PROTOCOL_ERROR,
- "Received frame with non-zero reserved bits",
- )
+ # frame-fin = %x0 ; more frames of this message follow
+ # / %x1 ; final frame of this message
+ # frame-rsv1 = %x0 ;
+ # 1 bit, MUST be 0 unless negotiated otherwise
+ # frame-rsv2 = %x0 ;
+ # 1 bit, MUST be 0 unless negotiated otherwise
+ # frame-rsv3 = %x0 ;
+ # 1 bit, MUST be 0 unless negotiated otherwise
+ #
+ # Remove rsv1 from this test for deflate development
+ if rsv2 or rsv3 or (rsv1 and not self._compress):
+ raise WebSocketError(
+ WSCloseCode.PROTOCOL_ERROR,
+ "Received frame with non-zero reserved bits",
+ )
- if opcode > 0x7 and fin == 0:
- raise WebSocketError(
- WSCloseCode.PROTOCOL_ERROR,
- "Received fragmented control frame",
- )
+ if opcode > 0x7 and fin == 0:
+ raise WebSocketError(
+ WSCloseCode.PROTOCOL_ERROR,
+ "Received fragmented control frame",
+ )
- has_mask = (second_byte >> 7) & 1
- length = second_byte & 0x7F
+ has_mask = (second_byte >> 7) & 1
+ length = second_byte & 0x7F
- # Control frames MUST have a payload
- # length of 125 bytes or less
- if opcode > 0x7 and length > 125:
- raise WebSocketError(
- WSCloseCode.PROTOCOL_ERROR,
- "Control frame payload cannot be " "larger than 125 bytes",
- )
+ # Control frames MUST have a payload
+ # length of 125 bytes or less
+ if opcode > 0x7 and length > 125:
+ raise WebSocketError(
+ WSCloseCode.PROTOCOL_ERROR,
+ "Control frame payload cannot be " "larger than 125 bytes",
+ )
- # Set compress status if last package is FIN
- # OR set compress status if this is first fragment
- # Raise error if not first fragment with rsv1 = 0x1
- if self._frame_fin or self._compressed is None:
- self._compressed = True if rsv1 else False
- elif rsv1:
- raise WebSocketError(
- WSCloseCode.PROTOCOL_ERROR,
- "Received frame with non-zero reserved bits",
- )
+ # Set compress status if last package is FIN
+ # OR set compress status if this is first fragment
+ # Raise error if not first fragment with rsv1 = 0x1
+ if self._frame_fin or self._compressed is None:
+ self._compressed = True if rsv1 else False
+ elif rsv1:
+ raise WebSocketError(
+ WSCloseCode.PROTOCOL_ERROR,
+ "Received frame with non-zero reserved bits",
+ )
- self._frame_fin = bool(fin)
- self._frame_opcode = opcode
- self._has_mask = bool(has_mask)
- self._payload_length_flag = length
- self._state = WSParserState.READ_PAYLOAD_LENGTH
- else:
- break
+ self._frame_fin = bool(fin)
+ self._frame_opcode = opcode
+ self._has_mask = bool(has_mask)
+ self._payload_length_flag = length
+ self._state = WSParserState.READ_PAYLOAD_LENGTH
# read payload length
- if self._state == WSParserState.READ_PAYLOAD_LENGTH:
- length = self._payload_length_flag
- if length == 126:
- if buf_length - start_pos >= 2:
- data = buf[start_pos : start_pos + 2]
- start_pos += 2
- length = UNPACK_LEN2(data)[0]
- self._payload_length = length
- self._state = (
- WSParserState.READ_PAYLOAD_MASK
- if self._has_mask
- else WSParserState.READ_PAYLOAD
- )
- else:
+ if self._state is WSParserState.READ_PAYLOAD_LENGTH:
+ length_flag = self._payload_length_flag
+ if length_flag == 126:
+ if buf_length - start_pos < 2:
break
- elif length > 126:
- if buf_length - start_pos >= 8:
- data = buf[start_pos : start_pos + 8]
- start_pos += 8
- length = UNPACK_LEN3(data)[0]
- self._payload_length = length
- self._state = (
- WSParserState.READ_PAYLOAD_MASK
- if self._has_mask
- else WSParserState.READ_PAYLOAD
- )
- else:
+ data = buf[start_pos : start_pos + 2]
+ start_pos += 2
+ self._payload_length = UNPACK_LEN2(data)[0]
+ elif length_flag > 126:
+ if buf_length - start_pos < 8:
break
+ data = buf[start_pos : start_pos + 8]
+ start_pos += 8
+ self._payload_length = UNPACK_LEN3(data)[0]
else:
- self._payload_length = length
- self._state = (
- WSParserState.READ_PAYLOAD_MASK
- if self._has_mask
- else WSParserState.READ_PAYLOAD
- )
+ self._payload_length = length_flag
+
+ self._state = (
+ WSParserState.READ_PAYLOAD_MASK
+ if self._has_mask
+ else WSParserState.READ_PAYLOAD
+ )
# read payload mask
- if self._state == WSParserState.READ_PAYLOAD_MASK:
- if buf_length - start_pos >= 4:
- self._frame_mask = buf[start_pos : start_pos + 4]
- start_pos += 4
- self._state = WSParserState.READ_PAYLOAD
- else:
+ if self._state is WSParserState.READ_PAYLOAD_MASK:
+ if buf_length - start_pos < 4:
break
+ self._frame_mask = buf[start_pos : start_pos + 4]
+ start_pos += 4
+ self._state = WSParserState.READ_PAYLOAD
- if self._state == WSParserState.READ_PAYLOAD:
+ if self._state is WSParserState.READ_PAYLOAD:
length = self._payload_length
payload = self._frame_payload
chunk_len = buf_length - start_pos
if length >= chunk_len:
self._payload_length = length - chunk_len
- payload.extend(buf[start_pos:])
+ payload += buf[start_pos:]
start_pos = buf_length
else:
self._payload_length = 0
- payload.extend(buf[start_pos : start_pos + length])
+ payload += buf[start_pos : start_pos + length]
start_pos = start_pos + length
- if self._payload_length == 0:
- if self._has_mask:
- assert self._frame_mask is not None
- _websocket_mask(self._frame_mask, payload)
+ if self._payload_length != 0:
+ break
- frames.append(
- (self._frame_fin, self._frame_opcode, payload, self._compressed)
- )
+ if self._has_mask:
+ assert self._frame_mask is not None
+ _websocket_mask(self._frame_mask, payload)
- self._frame_payload = bytearray()
- self._state = WSParserState.READ_HEADER
- else:
- break
+ frames.append(
+ (self._frame_fin, self._frame_opcode, payload, self._compressed)
+ )
+ self._frame_payload = bytearray()
+ self._state = WSParserState.READ_HEADER
self._tail = buf[start_pos:]
@@ -612,7 +612,7 @@ class WebSocketWriter:
self.protocol = protocol
self.transport = transport
self.use_mask = use_mask
- self.randrange = random.randrange
+ self.get_random_bits = partial(random.getrandbits, 32)
self.compress = compress
self.notakeover = notakeover
self._closing = False
@@ -625,14 +625,20 @@ class WebSocketWriter:
) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing and not (opcode & WSMsgType.CLOSE):
- raise ConnectionResetError("Cannot write to closing transport")
+ raise ClientConnectionResetError("Cannot write to closing transport")
+ # RSV are the reserved bits in the frame header. They are used to
+ # indicate that the frame is using an extension.
+ # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
rsv = 0
-
# Only compress larger packets (disabled)
# Does small packet needs to be compressed?
# if self.compress and opcode < 8 and len(message) > 124:
if (compress or self.compress) and opcode < 8:
+ # RSV1 (rsv = 0x40) is set for compressed frames
+ # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
+ rsv = 0x40
+
if compress:
# Do not set self._compress if compressing is for this frame
compressobj = self._make_compress_obj(compress)
@@ -651,29 +657,39 @@ class WebSocketWriter:
)
if message.endswith(_WS_DEFLATE_TRAILING):
message = message[:-4]
- rsv = rsv | 0x40
msg_length = len(message)
use_mask = self.use_mask
- if use_mask:
- mask_bit = 0x80
- else:
- mask_bit = 0
+ mask_bit = 0x80 if use_mask else 0
+ # Depending on the message length, the header is assembled differently.
+ # The first byte is reserved for the opcode and the RSV bits.
+ first_byte = 0x80 | rsv | opcode
if msg_length < 126:
- header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
+ header = PACK_LEN1(first_byte, msg_length | mask_bit)
+ header_len = 2
elif msg_length < (1 << 16):
- header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
+ header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
+ header_len = 4
else:
- header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
+ header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
+ header_len = 10
+
+ # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
+ # If we are using a mask, we need to generate it randomly
+ # and apply it to the message before sending it. A mask is
+ # a 32-bit value that is applied to the message using a
+ # bitwise XOR operation. It is used to prevent certain types
+ # of attacks on the websocket protocol. The mask is only used
+ # when aiohttp is acting as a client. Servers do not use a mask.
if use_mask:
- mask_int = self.randrange(0, 0xFFFFFFFF)
- mask = mask_int.to_bytes(4, "big")
+ mask = PACK_RANDBITS(self.get_random_bits())
message = bytearray(message)
_websocket_mask(mask, message)
self._write(header + mask + message)
- self._output_size += len(header) + len(mask) + msg_length
+ self._output_size += header_len + MASK_LEN + msg_length
+
else:
if msg_length > MSG_SIZE:
self._write(header)
@@ -681,11 +697,16 @@ class WebSocketWriter:
else:
self._write(header + message)
- self._output_size += len(header) + msg_length
+ self._output_size += header_len + msg_length
# It is safe to return control to the event loop when using compression
# after this point as we have already sent or buffered all the data.
+ # Once we have written output_size up to the limit, we call the
+ # drain helper which waits for the transport to be ready to accept
+ # more data. This is a flow control mechanism to prevent the buffer
+ # from growing too large. The drain helper will return right away
+ # if the writer is not paused.
if self._output_size > self._limit:
self._output_size = 0
await self.protocol._drain_helper()
@@ -699,7 +720,7 @@ class WebSocketWriter:
def _write(self, data: bytes) -> None:
if self.transport is None or self.transport.is_closing():
- raise ConnectionResetError("Cannot write to closing transport")
+ raise ClientConnectionResetError("Cannot write to closing transport")
self.transport.write(data)
async def pong(self, message: Union[bytes, str] = b"") -> None:
diff --git a/contrib/python/aiohttp/aiohttp/http_writer.py b/contrib/python/aiohttp/aiohttp/http_writer.py
index d6b02e6f566..dc07a358c70 100644
--- a/contrib/python/aiohttp/aiohttp/http_writer.py
+++ b/contrib/python/aiohttp/aiohttp/http_writer.py
@@ -8,6 +8,7 @@ from multidict import CIMultiDict
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
+from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS
@@ -70,9 +71,9 @@ class StreamWriter(AbstractStreamWriter):
size = len(chunk)
self.buffer_size += size
self.output_size += size
- transport = self.transport
- if not self._protocol.connected or transport is None or transport.is_closing():
- raise ConnectionResetError("Cannot write to closing transport")
+ transport = self._protocol.transport
+ if transport is None or transport.is_closing():
+ raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)
async def write(
diff --git a/contrib/python/aiohttp/aiohttp/multipart.py b/contrib/python/aiohttp/aiohttp/multipart.py
index 71fc2654a1c..49c05c5af25 100644
--- a/contrib/python/aiohttp/aiohttp/multipart.py
+++ b/contrib/python/aiohttp/aiohttp/multipart.py
@@ -2,6 +2,7 @@ import base64
import binascii
import json
import re
+import sys
import uuid
import warnings
import zlib
@@ -10,7 +11,6 @@ from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
- AsyncIterator,
Deque,
Dict,
Iterator,
@@ -48,6 +48,13 @@ from .payload import (
)
from .streams import StreamReader
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing import TypeVar
+
+ Self = TypeVar("Self", bound="BodyPartReader")
+
__all__ = (
"MultipartReader",
"MultipartWriter",
@@ -266,6 +273,7 @@ class BodyPartReader:
) -> None:
self.headers = headers
self._boundary = boundary
+ self._boundary_len = len(boundary) + 2 # Boundary + \r\n
self._content = content
self._default_charset = default_charset
self._at_eof = False
@@ -279,8 +287,8 @@ class BodyPartReader:
self._content_eof = 0
self._cache: Dict[str, Any] = {}
- def __aiter__(self) -> AsyncIterator["BodyPartReader"]:
- return self # type: ignore[return-value]
+ def __aiter__(self: Self):
+ return self
async def __anext__(self) -> bytes:
part = await self.next()
@@ -322,6 +330,31 @@ class BodyPartReader:
else:
chunk = await self._read_chunk_from_stream(size)
+ # For the case of base64 data, we must read a fragment of size with a
+ # remainder of 0 by dividing by 4 for string without symbols \n or \r
+ encoding = self.headers.get(CONTENT_TRANSFER_ENCODING)
+ if encoding and encoding.lower() == "base64":
+ stripped_chunk = b"".join(chunk.split())
+ remainder = len(stripped_chunk) % 4
+
+ while remainder != 0 and not self.at_eof():
+ over_chunk_size = 4 - remainder
+ over_chunk = b""
+
+ if self._prev_chunk:
+ over_chunk = self._prev_chunk[:over_chunk_size]
+ self._prev_chunk = self._prev_chunk[len(over_chunk) :]
+
+ if len(over_chunk) != over_chunk_size:
+ over_chunk += await self._content.read(4 - len(over_chunk))
+
+ if not over_chunk:
+ self._at_eof = True
+
+ stripped_chunk += b"".join(over_chunk.split())
+ chunk += over_chunk
+ remainder = len(stripped_chunk) % 4
+
self._read_bytes += len(chunk)
if self._read_bytes == self._length:
self._at_eof = True
@@ -346,15 +379,25 @@ class BodyPartReader:
# Reads content chunk of body part with unknown length.
# The Content-Length header for body part is not necessary.
assert (
- size >= len(self._boundary) + 2
+ size >= self._boundary_len
), "Chunk size must be greater or equal than boundary length + 2"
first_chunk = self._prev_chunk is None
if first_chunk:
self._prev_chunk = await self._content.read(size)
- chunk = await self._content.read(size)
- self._content_eof += int(self._content.at_eof())
- assert self._content_eof < 3, "Reading after EOF"
+ chunk = b""
+ # content.read() may return less than size, so we need to loop to ensure
+ # we have enough data to detect the boundary.
+ while len(chunk) < self._boundary_len:
+ chunk += await self._content.read(size)
+ self._content_eof += int(self._content.at_eof())
+ assert self._content_eof < 3, "Reading after EOF"
+ if self._content_eof:
+ break
+ if len(chunk) > size:
+ self._content.unread_data(chunk[size:])
+ chunk = chunk[:size]
+
assert self._prev_chunk is not None
window = self._prev_chunk + chunk
sub = b"\r\n" + self._boundary
@@ -518,6 +561,8 @@ class BodyPartReader:
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
+ _value: BodyPartReader
+
def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
super().__init__(value, *args, **kwargs)
@@ -530,6 +575,9 @@ class BodyPartReaderPayload(Payload):
if params:
self.set_content_disposition("attachment", True, **params)
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ raise TypeError("Unable to decode.")
+
async def write(self, writer: Any) -> None:
field = self._value
chunk = await field.read_chunk(size=2**16)
@@ -566,10 +614,8 @@ class MultipartReader:
self._at_bof = True
self._unread: List[bytes] = []
- def __aiter__(
- self,
- ) -> AsyncIterator["BodyPartReader"]:
- return self # type: ignore[return-value]
+ def __aiter__(self: Self):
+ return self
async def __anext__(
self,
@@ -749,6 +795,8 @@ _Part = Tuple[Payload, str, str]
class MultipartWriter(Payload):
"""Multipart body writer."""
+ _value: None
+
def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
boundary = boundary if boundary is not None else uuid.uuid4().hex
# The underlying Payload API demands a str (utf-8), not bytes,
@@ -929,6 +977,16 @@ class MultipartWriter(Payload):
total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n'
return total
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return "".join(
+ "--"
+ + self.boundary
+ + "\n"
+ + part._binary_headers.decode(encoding, errors)
+ + part.decode()
+ for part, _e, _te in self._parts
+ )
+
async def write(self, writer: Any, close_boundary: bool = True) -> None:
"""Write body."""
for part, encoding, te_encoding in self._parts:
diff --git a/contrib/python/aiohttp/aiohttp/payload.py b/contrib/python/aiohttp/aiohttp/payload.py
index 6593b05c6f7..27636977774 100644
--- a/contrib/python/aiohttp/aiohttp/payload.py
+++ b/contrib/python/aiohttp/aiohttp/payload.py
@@ -11,7 +11,6 @@ from typing import (
IO,
TYPE_CHECKING,
Any,
- ByteString,
Dict,
Final,
Iterable,
@@ -209,6 +208,13 @@ class Payload(ABC):
)
@abstractmethod
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ """Return string representation of the value.
+
+ This is named decode() to allow compatibility with bytes objects.
+ """
+
+ @abstractmethod
async def write(self, writer: AbstractStreamWriter) -> None:
"""Write payload.
@@ -217,7 +223,11 @@ class Payload(ABC):
class BytesPayload(Payload):
- def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None:
+ _value: bytes
+
+ def __init__(
+ self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any
+ ) -> None:
if not isinstance(value, (bytes, bytearray, memoryview)):
raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")
@@ -241,6 +251,9 @@ class BytesPayload(Payload):
**kwargs,
)
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return self._value.decode(encoding, errors)
+
async def write(self, writer: AbstractStreamWriter) -> None:
await writer.write(self._value)
@@ -282,7 +295,7 @@ class StringIOPayload(StringPayload):
class IOBasePayload(Payload):
- _value: IO[Any]
+ _value: io.IOBase
def __init__(
self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
@@ -306,9 +319,12 @@ class IOBasePayload(Payload):
finally:
await loop.run_in_executor(None, self._value.close)
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return "".join(r.decode(encoding, errors) for r in self._value.readlines())
+
class TextIOPayload(IOBasePayload):
- _value: TextIO
+ _value: io.TextIOBase
def __init__(
self,
@@ -345,6 +361,9 @@ class TextIOPayload(IOBasePayload):
except OSError:
return None
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return self._value.read()
+
async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
@@ -362,6 +381,8 @@ class TextIOPayload(IOBasePayload):
class BytesIOPayload(IOBasePayload):
+ _value: io.BytesIO
+
@property
def size(self) -> int:
position = self._value.tell()
@@ -369,17 +390,27 @@ class BytesIOPayload(IOBasePayload):
self._value.seek(position)
return end - position
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return self._value.read().decode(encoding, errors)
+
class BufferedReaderPayload(IOBasePayload):
+ _value: io.BufferedIOBase
+
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
- except OSError:
+ except (OSError, AttributeError):
# data.fileno() is not supported, e.g.
# io.BufferedReader(io.BytesIO(b'data'))
+ # For some file-like objects (e.g. tarfile), the fileno() attribute may
+ # not exist at all, and will instead raise an AttributeError.
return None
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return self._value.read().decode(encoding, errors)
+
class JsonPayload(BytesPayload):
def __init__(
@@ -416,6 +447,7 @@ else:
class AsyncIterablePayload(Payload):
_iter: Optional[_AsyncIterator] = None
+ _value: _AsyncIterable
def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
if not isinstance(value, AsyncIterable):
@@ -443,6 +475,9 @@ class AsyncIterablePayload(Payload):
except StopAsyncIteration:
self._iter = None
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ raise TypeError("Unable to decode.")
+
class StreamReaderPayload(AsyncIterablePayload):
def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
diff --git a/contrib/python/aiohttp/aiohttp/payload_streamer.py b/contrib/python/aiohttp/aiohttp/payload_streamer.py
index 364f763ae74..831fdc0a77f 100644
--- a/contrib/python/aiohttp/aiohttp/payload_streamer.py
+++ b/contrib/python/aiohttp/aiohttp/payload_streamer.py
@@ -65,6 +65,9 @@ class StreamWrapperPayload(Payload):
async def write(self, writer: AbstractStreamWriter) -> None:
await self._value(writer)
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ raise TypeError("Unable to decode.")
+
@payload_type(streamer)
class StreamPayload(StreamWrapperPayload):
diff --git a/contrib/python/aiohttp/aiohttp/pytest_plugin.py b/contrib/python/aiohttp/aiohttp/pytest_plugin.py
index 5754747bf48..55964ead041 100644
--- a/contrib/python/aiohttp/aiohttp/pytest_plugin.py
+++ b/contrib/python/aiohttp/aiohttp/pytest_plugin.py
@@ -1,13 +1,21 @@
import asyncio
import contextlib
+import inspect
import warnings
-from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Type, Union
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterator,
+ Optional,
+ Protocol,
+ Type,
+ Union,
+)
import pytest
-from aiohttp.helpers import isasyncgenfunction
-from aiohttp.web import Application
-
from .test_utils import (
BaseTestServer,
RawTestServer,
@@ -18,15 +26,35 @@ from .test_utils import (
teardown_test_loop,
unused_port as _unused_port,
)
+from .web import Application
+from .web_protocol import _RequestHandler
try:
import uvloop
except ImportError: # pragma: no cover
uvloop = None # type: ignore[assignment]
-AiohttpClient = Callable[[Union[Application, BaseTestServer]], Awaitable[TestClient]]
-AiohttpRawServer = Callable[[Application], Awaitable[RawTestServer]]
-AiohttpServer = Callable[[Application], Awaitable[TestServer]]
+
+class AiohttpClient(Protocol):
+ def __call__(
+ self,
+ __param: Union[Application, BaseTestServer],
+ *,
+ server_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs: Any
+ ) -> Awaitable[TestClient]: ...
+
+
+class AiohttpServer(Protocol):
+ def __call__(
+ self, app: Application, *, port: Optional[int] = None, **kwargs: Any
+ ) -> Awaitable[TestServer]: ...
+
+
+class AiohttpRawServer(Protocol):
+ def __call__(
+ self, handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
+ ) -> Awaitable[RawTestServer]: ...
def pytest_addoption(parser): # type: ignore[no-untyped-def]
@@ -57,7 +85,7 @@ def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def]
"""
func = fixturedef.func
- if isasyncgenfunction(func):
+ if inspect.isasyncgenfunction(func):
# async generator fixture
is_async_gen = True
elif asyncio.iscoroutinefunction(func):
@@ -262,7 +290,9 @@ def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]:
"""
servers = []
- async def go(app, *, port=None, **kwargs): # type: ignore[no-untyped-def]
+ async def go(
+ app: Application, *, port: Optional[int] = None, **kwargs: Any
+ ) -> TestServer:
server = TestServer(app, port=port)
await server.start_server(loop=loop, **kwargs)
servers.append(server)
@@ -295,7 +325,9 @@ def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawSe
"""
servers = []
- async def go(handler, *, port=None, **kwargs): # type: ignore[no-untyped-def]
+ async def go(
+ handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
+ ) -> RawTestServer:
server = RawTestServer(handler, port=port)
await server.start_server(loop=loop, **kwargs)
servers.append(server)
diff --git a/contrib/python/aiohttp/aiohttp/resolver.py b/contrib/python/aiohttp/aiohttp/resolver.py
index c03230c744e..06855fa13fd 100644
--- a/contrib/python/aiohttp/aiohttp/resolver.py
+++ b/contrib/python/aiohttp/aiohttp/resolver.py
@@ -1,20 +1,25 @@
import asyncio
import socket
-from typing import Any, Dict, List, Optional, Type, Union
+import sys
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
-from .abc import AbstractResolver
-from .helpers import get_running_loop
+from .abc import AbstractResolver, ResolveResult
__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")
+
try:
import aiodns
- # aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
+ aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
except ImportError: # pragma: no cover
- aiodns = None
+ aiodns = None # type: ignore[assignment]
+ aiodns_default = False
+
-aiodns_default = False
+_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
+_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
+_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)
class ThreadedResolver(AbstractResolver):
@@ -25,48 +30,48 @@ class ThreadedResolver(AbstractResolver):
"""
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
- self._loop = get_running_loop(loop)
+ self._loop = loop or asyncio.get_event_loop()
async def resolve(
- self, hostname: str, port: int = 0, family: int = socket.AF_INET
- ) -> List[Dict[str, Any]]:
+ self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
+ ) -> List[ResolveResult]:
infos = await self._loop.getaddrinfo(
- hostname,
+ host,
port,
type=socket.SOCK_STREAM,
family=family,
# flags=socket.AI_ADDRCONFIG,
)
- hosts = []
+ hosts: List[ResolveResult] = []
for family, _, proto, _, address in infos:
if family == socket.AF_INET6:
if len(address) < 3:
# IPv6 is not supported by Python build,
# or IPv6 is not enabled in the host
continue
- if address[3]:
+ if address[3] and _SUPPORTS_SCOPE_ID:
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
# getnameinfo() unconditionally, but performance makes sense.
- host, _port = socket.getnameinfo(
- address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
+ resolved_host, _port = await self._loop.getnameinfo(
+ address, _NAME_SOCKET_FLAGS
)
port = int(_port)
else:
- host, port = address[:2]
+ resolved_host, port = address[:2]
else: # IPv4
assert family == socket.AF_INET
- host, port = address # type: ignore[misc]
+ resolved_host, port = address # type: ignore[misc]
hosts.append(
- {
- "hostname": hostname,
- "host": host,
- "port": port,
- "family": family,
- "proto": proto,
- "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
- }
+ ResolveResult(
+ hostname=host,
+ host=resolved_host,
+ port=port,
+ family=family,
+ proto=proto,
+ flags=_NUMERIC_SOCKET_FLAGS,
+ )
)
return hosts
@@ -87,32 +92,56 @@ class AsyncResolver(AbstractResolver):
if aiodns is None:
raise RuntimeError("Resolver requires aiodns library")
- self._loop = get_running_loop(loop)
- self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs)
+ self._resolver = aiodns.DNSResolver(*args, **kwargs)
if not hasattr(self._resolver, "gethostbyname"):
# aiodns 1.1 is not available, fallback to DNSResolver.query
self.resolve = self._resolve_with_query # type: ignore
async def resolve(
- self, host: str, port: int = 0, family: int = socket.AF_INET
- ) -> List[Dict[str, Any]]:
+ self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
+ ) -> List[ResolveResult]:
try:
- resp = await self._resolver.gethostbyname(host, family)
+ resp = await self._resolver.getaddrinfo(
+ host,
+ port=port,
+ type=socket.SOCK_STREAM,
+ family=family,
+ flags=socket.AI_ADDRCONFIG,
+ )
except aiodns.error.DNSError as exc:
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
raise OSError(msg) from exc
- hosts = []
- for address in resp.addresses:
+ hosts: List[ResolveResult] = []
+ for node in resp.nodes:
+ address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
+ family = node.family
+ if family == socket.AF_INET6:
+ if len(address) > 3 and address[3] and _SUPPORTS_SCOPE_ID:
+ # This is essential for link-local IPv6 addresses.
+ # LL IPv6 is a VERY rare case. Strictly speaking, we should use
+ # getnameinfo() unconditionally, but performance makes sense.
+ result = await self._resolver.getnameinfo(
+ (address[0].decode("ascii"), *address[1:]),
+ _NAME_SOCKET_FLAGS,
+ )
+ resolved_host = result.node
+ else:
+ resolved_host = address[0].decode("ascii")
+ port = address[1]
+ else: # IPv4
+ assert family == socket.AF_INET
+ resolved_host = address[0].decode("ascii")
+ port = address[1]
hosts.append(
- {
- "hostname": host,
- "host": address,
- "port": port,
- "family": family,
- "proto": 0,
- "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
- }
+ ResolveResult(
+ hostname=host,
+ host=resolved_host,
+ port=port,
+ family=family,
+ proto=0,
+ flags=_NUMERIC_SOCKET_FLAGS,
+ )
)
if not hosts:
diff --git a/contrib/python/aiohttp/aiohttp/streams.py b/contrib/python/aiohttp/aiohttp/streams.py
index b9b9c3fd96f..c927cfbb1b3 100644
--- a/contrib/python/aiohttp/aiohttp/streams.py
+++ b/contrib/python/aiohttp/aiohttp/streams.py
@@ -296,6 +296,9 @@ class StreamReader(AsyncStreamReaderMixin):
set_result(waiter, None)
async def _wait(self, func_name: str) -> None:
+ if not self._protocol.connected:
+ raise RuntimeError("Connection closed.")
+
# StreamReader uses a future to link the protocol feed_data() method
# to a read coroutine. Running two read coroutines at the same time
# would have an unexpected behaviour. It would not possible to know
diff --git a/contrib/python/aiohttp/aiohttp/test_utils.py b/contrib/python/aiohttp/aiohttp/test_utils.py
index a36e8599689..01496b6711a 100644
--- a/contrib/python/aiohttp/aiohttp/test_utils.py
+++ b/contrib/python/aiohttp/aiohttp/test_utils.py
@@ -11,17 +11,7 @@ import sys
import warnings
from abc import ABC, abstractmethod
from types import TracebackType
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Iterator,
- List,
- Optional,
- Type,
- Union,
- cast,
-)
+from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Type, cast
from unittest import IsolatedAsyncioTestCase, mock
from aiosignal import Signal
@@ -29,7 +19,11 @@ from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL
import aiohttp
-from aiohttp.client import _RequestContextManager, _WSRequestContextManager
+from aiohttp.client import (
+ _RequestContextManager,
+ _RequestOptions,
+ _WSRequestContextManager,
+)
from . import ClientSession, hdrs
from .abc import AbstractCookieJar
@@ -37,6 +31,7 @@ from .client_reqrep import ClientResponse
from .client_ws import ClientWebSocketResponse
from .helpers import sentinel
from .http import HttpVersion, RawRequestMessage
+from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import StrOrURL
from .web import (
Application,
@@ -55,6 +50,9 @@ if TYPE_CHECKING:
else:
SSLContext = None
+if sys.version_info >= (3, 11) and TYPE_CHECKING:
+ from typing import Unpack
+
REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"
@@ -90,7 +88,7 @@ class BaseTestServer(ABC):
def __init__(
self,
*,
- scheme: Union[str, object] = sentinel,
+ scheme: str = "",
loop: Optional[asyncio.AbstractEventLoop] = None,
host: str = "127.0.0.1",
port: Optional[int] = None,
@@ -121,10 +119,13 @@ class BaseTestServer(ABC):
await self.runner.setup()
if not self.port:
self.port = 0
+ absolute_host = self.host
try:
version = ipaddress.ip_address(self.host).version
except ValueError:
version = 4
+ if version == 6:
+ absolute_host = f"[{self.host}]"
family = socket.AF_INET6 if version == 6 else socket.AF_INET
_sock = self.socket_factory(self.host, self.port, family)
self.host, self.port = _sock.getsockname()[:2]
@@ -135,13 +136,9 @@ class BaseTestServer(ABC):
sockets = server.sockets # type: ignore[attr-defined]
assert sockets is not None
self.port = sockets[0].getsockname()[1]
- if self.scheme is sentinel:
- if self._ssl:
- scheme = "https"
- else:
- scheme = "http"
- self.scheme = scheme
- self._root = URL(f"{self.scheme}://{self.host}:{self.port}")
+ if not self.scheme:
+ self.scheme = "https" if self._ssl else "http"
+ self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")
@abstractmethod # pragma: no cover
async def _make_runner(self, **kwargs: Any) -> BaseRunner:
@@ -222,7 +219,7 @@ class TestServer(BaseTestServer):
self,
app: Application,
*,
- scheme: Union[str, object] = sentinel,
+ scheme: str = "",
host: str = "127.0.0.1",
port: Optional[int] = None,
**kwargs: Any,
@@ -239,7 +236,7 @@ class RawTestServer(BaseTestServer):
self,
handler: _RequestHandler,
*,
- scheme: Union[str, object] = sentinel,
+ scheme: str = "",
host: str = "127.0.0.1",
port: Optional[int] = None,
**kwargs: Any,
@@ -324,45 +321,101 @@ class TestClient:
self._responses.append(resp)
return resp
- def request(
- self, method: str, path: StrOrURL, **kwargs: Any
- ) -> _RequestContextManager:
- """Routes a request to tested http server.
+ if sys.version_info >= (3, 11) and TYPE_CHECKING:
- The interface is identical to aiohttp.ClientSession.request,
- except the loop kwarg is overridden by the instance used by the
- test server.
+ def request(
+ self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
+ ) -> _RequestContextManager: ...
- """
- return _RequestContextManager(self._request(method, path, **kwargs))
+ def get(
+ self,
+ path: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> _RequestContextManager: ...
+
+ def options(
+ self,
+ path: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> _RequestContextManager: ...
+
+ def head(
+ self,
+ path: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> _RequestContextManager: ...
+
+ def post(
+ self,
+ path: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> _RequestContextManager: ...
+
+ def put(
+ self,
+ path: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> _RequestContextManager: ...
+
+ def patch(
+ self,
+ path: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> _RequestContextManager: ...
- def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
- """Perform an HTTP GET request."""
- return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))
+ def delete(
+ self,
+ path: StrOrURL,
+ **kwargs: Unpack[_RequestOptions],
+ ) -> _RequestContextManager: ...
- def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
- """Perform an HTTP POST request."""
- return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))
+ else:
+
+ def request(
+ self, method: str, path: StrOrURL, **kwargs: Any
+ ) -> _RequestContextManager:
+ """Routes a request to tested http server.
+
+ The interface is identical to aiohttp.ClientSession.request,
+ except the loop kwarg is overridden by the instance used by the
+ test server.
+
+ """
+ return _RequestContextManager(self._request(method, path, **kwargs))
+
+ def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+ """Perform an HTTP GET request."""
+ return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))
- def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
- """Perform an HTTP OPTIONS request."""
- return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))
+ def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+ """Perform an HTTP POST request."""
+ return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))
- def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
- """Perform an HTTP HEAD request."""
- return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))
+ def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+ """Perform an HTTP OPTIONS request."""
+ return _RequestContextManager(
+ self._request(hdrs.METH_OPTIONS, path, **kwargs)
+ )
- def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
- """Perform an HTTP PUT request."""
- return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))
+ def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+ """Perform an HTTP HEAD request."""
+ return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))
- def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
- """Perform an HTTP PATCH request."""
- return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))
+ def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+ """Perform an HTTP PUT request."""
+ return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))
- def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
- """Perform an HTTP PATCH request."""
- return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))
+ def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+ """Perform an HTTP PATCH request."""
+ return _RequestContextManager(
+ self._request(hdrs.METH_PATCH, path, **kwargs)
+ )
+
+ def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
+ """Perform an HTTP PATCH request."""
+ return _RequestContextManager(
+ self._request(hdrs.METH_DELETE, path, **kwargs)
+ )
def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
"""Initiate websocket connection.
@@ -582,7 +635,7 @@ def make_mocked_request(
writer: Any = sentinel,
protocol: Any = sentinel,
transport: Any = sentinel,
- payload: Any = sentinel,
+ payload: StreamReader = EMPTY_PAYLOAD,
sslcontext: Optional[SSLContext] = None,
client_max_size: int = 1024**2,
loop: Any = ...,
@@ -651,9 +704,6 @@ def make_mocked_request(
protocol.transport = transport
protocol.writer = writer
- if payload is sentinel:
- payload = mock.Mock()
-
req = Request(
message, payload, protocol, writer, task, loop, client_max_size=client_max_size
)
diff --git a/contrib/python/aiohttp/aiohttp/tracing.py b/contrib/python/aiohttp/aiohttp/tracing.py
index fe3eda9abb7..067a132464e 100644
--- a/contrib/python/aiohttp/aiohttp/tracing.py
+++ b/contrib/python/aiohttp/aiohttp/tracing.py
@@ -1,5 +1,5 @@
from types import SimpleNamespace
-from typing import TYPE_CHECKING, Awaitable, Optional, Protocol, Type, TypeVar
+from typing import TYPE_CHECKING, Awaitable, Mapping, Optional, Protocol, Type, TypeVar
import attr
from aiosignal import Signal
@@ -42,59 +42,29 @@ class TraceConfig:
def __init__(
self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace
) -> None:
- self._on_request_start: _TracingSignal[
- TraceRequestStartParams
- ] = Signal(self)
- self._on_request_chunk_sent: _TracingSignal[
- TraceRequestChunkSentParams
- ] = Signal(self)
- self._on_response_chunk_received: _TracingSignal[
- TraceResponseChunkReceivedParams
- ] = Signal(self)
- self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal(
- self
+ self._on_request_start: _TracingSignal[TraceRequestStartParams] = (
+ Signal(self)
)
- self._on_request_exception: _TracingSignal[
- TraceRequestExceptionParams
- ] = Signal(self)
- self._on_request_redirect: _TracingSignal[
- TraceRequestRedirectParams
- ] = Signal(self)
- self._on_connection_queued_start: _TracingSignal[
- TraceConnectionQueuedStartParams
- ] = Signal(self)
- self._on_connection_queued_end: _TracingSignal[
- TraceConnectionQueuedEndParams
- ] = Signal(self)
- self._on_connection_create_start: _TracingSignal[
- TraceConnectionCreateStartParams
- ] = Signal(self)
- self._on_connection_create_end: _TracingSignal[
- TraceConnectionCreateEndParams
- ] = Signal(self)
- self._on_connection_reuseconn: _TracingSignal[
- TraceConnectionReuseconnParams
- ] = Signal(self)
- self._on_dns_resolvehost_start: _TracingSignal[
- TraceDnsResolveHostStartParams
- ] = Signal(self)
- self._on_dns_resolvehost_end: _TracingSignal[
- TraceDnsResolveHostEndParams
- ] = Signal(self)
- self._on_dns_cache_hit: _TracingSignal[
- TraceDnsCacheHitParams
- ] = Signal(self)
- self._on_dns_cache_miss: _TracingSignal[
- TraceDnsCacheMissParams
- ] = Signal(self)
- self._on_request_headers_sent: _TracingSignal[
- TraceRequestHeadersSentParams
- ] = Signal(self)
+ self._on_request_chunk_sent: _TracingSignal[TraceRequestChunkSentParams] = Signal(self)
+ self._on_response_chunk_received: _TracingSignal[TraceResponseChunkReceivedParams] = Signal(self)
+ self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal(self)
+ self._on_request_exception: _TracingSignal[TraceRequestExceptionParams] = Signal(self)
+ self._on_request_redirect: _TracingSignal[TraceRequestRedirectParams] = Signal(self)
+ self._on_connection_queued_start: _TracingSignal[TraceConnectionQueuedStartParams] = Signal(self)
+ self._on_connection_queued_end: _TracingSignal[TraceConnectionQueuedEndParams] = Signal(self)
+ self._on_connection_create_start: _TracingSignal[TraceConnectionCreateStartParams] = Signal(self)
+ self._on_connection_create_end: _TracingSignal[TraceConnectionCreateEndParams] = Signal(self)
+ self._on_connection_reuseconn: _TracingSignal[TraceConnectionReuseconnParams] = Signal(self)
+ self._on_dns_resolvehost_start: _TracingSignal[TraceDnsResolveHostStartParams] = Signal(self)
+ self._on_dns_resolvehost_end: _TracingSignal[TraceDnsResolveHostEndParams] = Signal(self)
+ self._on_dns_cache_hit: _TracingSignal[TraceDnsCacheHitParams] = (Signal(self))
+ self._on_dns_cache_miss: _TracingSignal[TraceDnsCacheMissParams] = (Signal(self))
+ self._on_request_headers_sent: _TracingSignal[TraceRequestHeadersSentParams] = Signal(self)
self._trace_config_ctx_factory = trace_config_ctx_factory
def trace_config_ctx(
- self, trace_request_ctx: Optional[SimpleNamespace] = None
+ self, trace_request_ctx: Optional[Mapping[str, str]] = None
) -> SimpleNamespace:
"""Return a new trace_config_ctx instance"""
return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx)
@@ -122,7 +92,9 @@ class TraceConfig:
return self._on_request_start
@property
- def on_request_chunk_sent(self) -> "_TracingSignal[TraceRequestChunkSentParams]":
+ def on_request_chunk_sent(
+ self,
+ ) -> "_TracingSignal[TraceRequestChunkSentParams]":
return self._on_request_chunk_sent
@property
diff --git a/contrib/python/aiohttp/aiohttp/typedefs.py b/contrib/python/aiohttp/aiohttp/typedefs.py
index 5e963e1a10e..668d4fc344f 100644
--- a/contrib/python/aiohttp/aiohttp/typedefs.py
+++ b/contrib/python/aiohttp/aiohttp/typedefs.py
@@ -7,6 +7,8 @@ from typing import (
Callable,
Iterable,
Mapping,
+ Protocol,
+ Sequence,
Tuple,
Union,
)
@@ -14,6 +16,18 @@ from typing import (
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy, istr
from yarl import URL
+try:
+ # Available in yarl>=1.10.0
+ from yarl import Query as _Query
+except ImportError: # pragma: no cover
+ SimpleQuery = Union[str, int, float] # pragma: no cover
+ QueryVariable = Union[SimpleQuery, "Sequence[SimpleQuery]"] # pragma: no cover
+ _Query = Union[ # type: ignore[misc] # pragma: no cover
+ None, str, "Mapping[str, QueryVariable]", "Sequence[Tuple[str, QueryVariable]]"
+ ]
+
+Query = _Query
+
DEFAULT_JSON_ENCODER = json.dumps
DEFAULT_JSON_DECODER = json.loads
@@ -34,7 +48,13 @@ else:
Byteish = Union[bytes, bytearray, memoryview]
JSONEncoder = Callable[[Any], str]
JSONDecoder = Callable[[str], Any]
-LooseHeaders = Union[Mapping[Union[str, istr], str], _CIMultiDict, _CIMultiDictProxy]
+LooseHeaders = Union[
+ Mapping[str, str],
+ Mapping[istr, str],
+ _CIMultiDict,
+ _CIMultiDictProxy,
+ Iterable[Tuple[Union[str, istr], str]],
+]
RawHeaders = Tuple[Tuple[bytes, bytes], ...]
StrOrURL = Union[str, URL]
@@ -51,4 +71,5 @@ LooseCookies = Union[
Handler = Callable[["Request"], Awaitable["StreamResponse"]]
Middleware = Callable[["Request", Handler], Awaitable["StreamResponse"]]
+
PathLike = Union[str, "os.PathLike[str]"]
diff --git a/contrib/python/aiohttp/aiohttp/web.py b/contrib/python/aiohttp/aiohttp/web.py
index e9116507f4e..88bf14bf828 100644
--- a/contrib/python/aiohttp/aiohttp/web.py
+++ b/contrib/python/aiohttp/aiohttp/web.py
@@ -7,7 +7,6 @@ import warnings
from argparse import ArgumentParser
from collections.abc import Iterable
from contextlib import suppress
-from functools import partial
from importlib import import_module
from typing import (
Any,
@@ -21,7 +20,6 @@ from typing import (
Union,
cast,
)
-from weakref import WeakSet
from .abc import AbstractAccessLogger
from .helpers import AppKey as AppKey
@@ -320,23 +318,6 @@ async def _run_app(
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
- async def wait(
- starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float
- ) -> None:
- # Wait for pending tasks for a given time limit.
- t = asyncio.current_task()
- assert t is not None
- starting_tasks.add(t)
- with suppress(asyncio.TimeoutError):
- await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout)
-
- async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None:
- t = asyncio.current_task()
- assert t is not None
- exclude.add(t)
- while tasks := asyncio.all_tasks().difference(exclude):
- await asyncio.wait(tasks)
-
# An internal function to actually do all dirty job for application running
if asyncio.iscoroutine(app):
app = await app
@@ -355,12 +336,6 @@ async def _run_app(
)
await runner.setup()
- # On shutdown we want to avoid waiting on tasks which run forever.
- # It's very likely that all tasks which run forever will have been created by
- # the time we have completed the application startup (in runner.setup()),
- # so we just record all running tasks here and exclude them later.
- starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks())
- runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout)
sites: List[BaseSite] = []
@@ -545,10 +520,14 @@ def run_app(
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
finally:
- _cancel_tasks({main_task}, loop)
- _cancel_tasks(asyncio.all_tasks(loop), loop)
- loop.run_until_complete(loop.shutdown_asyncgens())
- loop.close()
+ try:
+ main_task.cancel()
+ with suppress(asyncio.CancelledError):
+ loop.run_until_complete(main_task)
+ finally:
+ _cancel_tasks(asyncio.all_tasks(loop), loop)
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ loop.close()
def main(argv: List[str]) -> None:
diff --git a/contrib/python/aiohttp/aiohttp/web_app.py b/contrib/python/aiohttp/aiohttp/web_app.py
index 8e0e91bcfe3..ab11981a8a5 100644
--- a/contrib/python/aiohttp/aiohttp/web_app.py
+++ b/contrib/python/aiohttp/aiohttp/web_app.py
@@ -1,7 +1,7 @@
import asyncio
import logging
import warnings
-from functools import partial, update_wrapper
+from functools import lru_cache, partial, update_wrapper
from typing import (
TYPE_CHECKING,
Any,
@@ -38,7 +38,7 @@ from .helpers import DEBUG, AppKey
from .http_parser import RawRequestMessage
from .log import web_logger
from .streams import StreamReader
-from .typedefs import Middleware
+from .typedefs import Handler, Middleware
from .web_exceptions import NotAppKeyWarning
from .web_log import AccessLogger
from .web_middlewares import _fix_request_current_app
@@ -76,6 +76,18 @@ else:
_T = TypeVar("_T")
_U = TypeVar("_U")
+_Resource = TypeVar("_Resource", bound=AbstractResource)
+
+
+@lru_cache(None)
+def _build_middlewares(
+ handler: Handler, apps: Tuple["Application", ...]
+) -> Callable[[Request], Awaitable[StreamResponse]]:
+ """Apply middlewares to handler."""
+ for app in apps[::-1]:
+ for m, _ in app._middlewares_handlers: # type: ignore[union-attr]
+ handler = update_wrapper(partial(m, handler=handler), handler) # type: ignore[misc]
+ return handler
class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
@@ -88,6 +100,7 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
"_handler_args",
"_middlewares",
"_middlewares_handlers",
+ "_has_legacy_middlewares",
"_run_middlewares",
"_state",
"_frozen",
@@ -142,6 +155,7 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
self._middlewares_handlers: _MiddlewaresHandlers = None
# initialized on freezing
self._run_middlewares: Optional[bool] = None
+ self._has_legacy_middlewares: bool = True
self._state: Dict[Union[AppKey[Any], str], object] = {}
self._frozen = False
@@ -183,12 +197,10 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
return self is other
@overload # type: ignore[override]
- def __getitem__(self, key: AppKey[_T]) -> _T:
- ...
+ def __getitem__(self, key: AppKey[_T]) -> _T: ...
@overload
- def __getitem__(self, key: str) -> Any:
- ...
+ def __getitem__(self, key: str) -> Any: ...
def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
return self._state[key]
@@ -202,12 +214,10 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
)
@overload # type: ignore[override]
- def __setitem__(self, key: AppKey[_T], value: _T) -> None:
- ...
+ def __setitem__(self, key: AppKey[_T], value: _T) -> None: ...
@overload
- def __setitem__(self, key: str, value: Any) -> None:
- ...
+ def __setitem__(self, key: str, value: Any) -> None: ...
def __setitem__(self, key: Union[str, AppKey[_T]], value: Any) -> None:
self._check_frozen()
@@ -231,17 +241,17 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
return iter(self._state)
+ def __hash__(self) -> int:
+ return id(self)
+
@overload # type: ignore[override]
- def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]:
- ...
+ def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...
@overload
- def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]:
- ...
+ def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: ...
@overload
- def get(self, key: str, default: Any = ...) -> Any:
- ...
+ def get(self, key: str, default: Any = ...) -> Any: ...
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
return self._state.get(key, default)
@@ -290,6 +300,9 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
self._on_shutdown.freeze()
self._on_cleanup.freeze()
self._middlewares_handlers = tuple(self._prepare_middleware())
+ self._has_legacy_middlewares = any(
+ not new_style for _, new_style in self._middlewares_handlers
+ )
# If current app and any subapp do not have middlewares avoid run all
# of the code footprint that it implies, which have a middleware
@@ -334,7 +347,7 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
reg_handler("on_shutdown")
reg_handler("on_cleanup")
- def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource:
+ def add_subapp(self, prefix: str, subapp: "Application") -> PrefixedSubAppResource:
if not isinstance(prefix, str):
raise TypeError("Prefix must be str")
prefix = prefix.rstrip("/")
@@ -344,8 +357,8 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
return self._add_subapp(factory, subapp)
def _add_subapp(
- self, resource_factory: Callable[[], AbstractResource], subapp: "Application"
- ) -> AbstractResource:
+ self, resource_factory: Callable[[], _Resource], subapp: "Application"
+ ) -> _Resource:
if self.frozen:
raise RuntimeError("Cannot add sub application to frozen application")
if subapp.frozen:
@@ -359,7 +372,7 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
subapp._set_loop(self._loop)
return resource
- def add_domain(self, domain: str, subapp: "Application") -> AbstractResource:
+ def add_domain(self, domain: str, subapp: "Application") -> MatchedSubAppResource:
if not isinstance(domain, str):
raise TypeError("Domain must be str")
elif "*" in domain:
@@ -520,29 +533,30 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
match_info.freeze()
- resp = None
request._match_info = match_info
- expect = request.headers.get(hdrs.EXPECT)
- if expect:
+
+ if request.headers.get(hdrs.EXPECT):
resp = await match_info.expect_handler(request)
await request.writer.drain()
+ if resp is not None:
+ return resp
- if resp is None:
- handler = match_info.handler
+ handler = match_info.handler
- if self._run_middlewares:
+ if self._run_middlewares:
+ if not self._has_legacy_middlewares:
+ handler = _build_middlewares(handler, match_info.apps)
+ else:
for app in match_info.apps[::-1]:
for m, new_style in app._middlewares_handlers: # type: ignore[union-attr]
if new_style:
handler = update_wrapper(
- partial(m, handler=handler), handler
+ partial(m, handler=handler), handler # type: ignore[misc]
)
else:
handler = await m(app, handler) # type: ignore[arg-type,assignment]
- resp = await handler(request)
-
- return resp
+ return await handler(request)
def __call__(self) -> "Application":
"""gunicorn compatibility"""
@@ -585,7 +599,7 @@ class CleanupContext(_CleanupContextBase):
await it.__anext__()
except StopAsyncIteration:
pass
- except Exception as exc:
+ except (Exception, asyncio.CancelledError) as exc:
errors.append(exc)
else:
errors.append(RuntimeError(f"{it!r} has more than one 'yield'"))
diff --git a/contrib/python/aiohttp/aiohttp/web_fileresponse.py b/contrib/python/aiohttp/aiohttp/web_fileresponse.py
index 7dbe50f0a5a..f0de75e9f1b 100644
--- a/contrib/python/aiohttp/aiohttp/web_fileresponse.py
+++ b/contrib/python/aiohttp/aiohttp/web_fileresponse.py
@@ -1,7 +1,11 @@
import asyncio
-import mimetypes
import os
import pathlib
+import sys
+from contextlib import suppress
+from mimetypes import MimeTypes
+from stat import S_ISREG
+from types import MappingProxyType
from typing import ( # noqa
IO,
TYPE_CHECKING,
@@ -22,6 +26,8 @@ from .abc import AbstractStreamWriter
from .helpers import ETAG_ANY, ETag, must_be_empty_body
from .typedefs import LooseHeaders, PathLike
from .web_exceptions import (
+ HTTPForbidden,
+ HTTPNotFound,
HTTPNotModified,
HTTPPartialContent,
HTTPPreconditionFailed,
@@ -40,6 +46,35 @@ _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE"))
+CONTENT_TYPES: Final[MimeTypes] = MimeTypes()
+
+if sys.version_info < (3, 9):
+ CONTENT_TYPES.encodings_map[".br"] = "br"
+
+# File extension to IANA encodings map that will be checked in the order defined.
+ENCODING_EXTENSIONS = MappingProxyType(
+ {ext: CONTENT_TYPES.encodings_map[ext] for ext in (".br", ".gz")}
+)
+
+FALLBACK_CONTENT_TYPE = "application/octet-stream"
+
+# Provide additional MIME type/extension pairs to be recognized.
+# https://en.wikipedia.org/wiki/List_of_archive_formats#Compression_only
+ADDITIONAL_CONTENT_TYPES = MappingProxyType(
+ {
+ "application/gzip": ".gz",
+ "application/x-brotli": ".br",
+ "application/x-bzip2": ".bz2",
+ "application/x-compress": ".Z",
+ "application/x-xz": ".xz",
+ }
+)
+
+# Add custom pairs and clear the encodings map so guess_type ignores them.
+CONTENT_TYPES.encodings_map.clear()
+for content_type, extension in ADDITIONAL_CONTENT_TYPES.items():
+ CONTENT_TYPES.add_type(content_type, extension) # type: ignore[attr-defined]
+
class FileResponse(StreamResponse):
"""A response object can be used to send files."""
@@ -101,10 +136,12 @@ class FileResponse(StreamResponse):
return writer
@staticmethod
- def _strong_etag_match(etag_value: str, etags: Tuple[ETag, ...]) -> bool:
+ def _etag_match(etag_value: str, etags: Tuple[ETag, ...], *, weak: bool) -> bool:
if len(etags) == 1 and etags[0].value == ETAG_ANY:
return True
- return any(etag.value == etag_value for etag in etags if not etag.is_weak)
+ return any(
+ etag.value == etag_value for etag in etags if weak or not etag.is_weak
+ )
async def _not_modified(
self, request: "BaseRequest", etag_value: str, last_modified: float
@@ -124,42 +161,60 @@ class FileResponse(StreamResponse):
self.content_length = 0
return await super().prepare(request)
- def _get_file_path_stat_and_gzip(
- self, check_for_gzipped_file: bool
- ) -> Tuple[pathlib.Path, os.stat_result, bool]:
- """Return the file path, stat result, and gzip status.
+ def _get_file_path_stat_encoding(
+ self, accept_encoding: str
+ ) -> Tuple[pathlib.Path, os.stat_result, Optional[str]]:
+ """Return the file path, stat result, and encoding.
+
+ If an uncompressed file is returned, the encoding is set to
+ :py:data:`None`.
This method should be called from a thread executor
since it calls os.stat which may block.
"""
- filepath = self._path
- if check_for_gzipped_file:
- gzip_path = filepath.with_name(filepath.name + ".gz")
- try:
- return gzip_path, gzip_path.stat(), True
- except OSError:
- # Fall through and try the non-gzipped file
- pass
+ file_path = self._path
+ for file_extension, file_encoding in ENCODING_EXTENSIONS.items():
+ if file_encoding not in accept_encoding:
+ continue
+
+ compressed_path = file_path.with_suffix(file_path.suffix + file_extension)
+ with suppress(OSError):
+ # Do not follow symlinks and ignore any non-regular files.
+ st = compressed_path.lstat()
+ if S_ISREG(st.st_mode):
+ return compressed_path, st, file_encoding
- return filepath, filepath.stat(), False
+ # Fallback to the uncompressed file
+ return file_path, file_path.stat(), None
async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
- loop = asyncio.get_event_loop()
+ loop = asyncio.get_running_loop()
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
- check_for_gzipped_file = (
- "gzip" in request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
- )
- filepath, st, gzip = await loop.run_in_executor(
- None, self._get_file_path_stat_and_gzip, check_for_gzipped_file
- )
+ accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
+ try:
+ file_path, st, file_encoding = await loop.run_in_executor(
+ None, self._get_file_path_stat_encoding, accept_encoding
+ )
+ except OSError:
+ # Most likely to be FileNotFoundError or OSError for circular
+ # symlinks in python >= 3.13, so respond with 404.
+ self.set_status(HTTPNotFound.status_code)
+ return await super().prepare(request)
+
+ # Forbid special files like sockets, pipes, devices, etc.
+ if not S_ISREG(st.st_mode):
+ self.set_status(HTTPForbidden.status_code)
+ return await super().prepare(request)
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime
- # https://tools.ietf.org/html/rfc7232#section-6
+ # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
ifmatch = request.if_match
- if ifmatch is not None and not self._strong_etag_match(etag_value, ifmatch):
+ if ifmatch is not None and not self._etag_match(
+ etag_value, ifmatch, weak=False
+ ):
return await self._precondition_failed(request)
unmodsince = request.if_unmodified_since
@@ -170,8 +225,11 @@ class FileResponse(StreamResponse):
):
return await self._precondition_failed(request)
+ # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
ifnonematch = request.if_none_match
- if ifnonematch is not None and self._strong_etag_match(etag_value, ifnonematch):
+ if ifnonematch is not None and self._etag_match(
+ etag_value, ifnonematch, weak=True
+ ):
return await self._not_modified(request, etag_value, last_modified)
modsince = request.if_modified_since
@@ -182,15 +240,6 @@ class FileResponse(StreamResponse):
):
return await self._not_modified(request, etag_value, last_modified)
- if hdrs.CONTENT_TYPE not in self.headers:
- ct, encoding = mimetypes.guess_type(str(filepath))
- if not ct:
- ct = "application/octet-stream"
- should_set_ct = True
- else:
- encoding = "gzip" if gzip else None
- should_set_ct = False
-
status = self._status
file_size = st.st_size
count = file_size
@@ -265,11 +314,16 @@ class FileResponse(StreamResponse):
# return a HTTP 206 for a Range request.
self.set_status(status)
- if should_set_ct:
- self.content_type = ct # type: ignore[assignment]
- if encoding:
- self.headers[hdrs.CONTENT_ENCODING] = encoding
- if gzip:
+ # If the Content-Type header is not already set, guess it based on the
+ # extension of the request path. The encoding returned by guess_type
+ # can be ignored since the map was cleared above.
+ if hdrs.CONTENT_TYPE not in self.headers:
+ self.content_type = (
+ CONTENT_TYPES.guess_type(self._path)[0] or FALLBACK_CONTENT_TYPE
+ )
+
+ if file_encoding:
+ self.headers[hdrs.CONTENT_ENCODING] = file_encoding
self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
# Disable compression if we are already sending
# a compressed file since we don't want to double
@@ -293,7 +347,12 @@ class FileResponse(StreamResponse):
if count == 0 or must_be_empty_body(request.method, self.status):
return await super().prepare(request)
- fobj = await loop.run_in_executor(None, filepath.open, "rb")
+ try:
+ fobj = await loop.run_in_executor(None, file_path.open, "rb")
+ except PermissionError:
+ self.set_status(HTTPForbidden.status_code)
+ return await super().prepare(request)
+
if start: # be aware that start could be None or int=0 here.
offset = start
else:
diff --git a/contrib/python/aiohttp/aiohttp/web_middlewares.py b/contrib/python/aiohttp/aiohttp/web_middlewares.py
index 5da1533c0df..2f1f5f58e6e 100644
--- a/contrib/python/aiohttp/aiohttp/web_middlewares.py
+++ b/contrib/python/aiohttp/aiohttp/web_middlewares.py
@@ -110,7 +110,12 @@ def normalize_path_middleware(
def _fix_request_current_app(app: "Application") -> Middleware:
@middleware
async def impl(request: Request, handler: Handler) -> StreamResponse:
- with request.match_info.set_current_app(app):
+ match_info = request.match_info
+ prev = match_info.current_app
+ match_info.current_app = app
+ try:
return await handler(request)
+ finally:
+ match_info.current_app = prev
return impl
diff --git a/contrib/python/aiohttp/aiohttp/web_protocol.py b/contrib/python/aiohttp/aiohttp/web_protocol.py
index f083b13eb0f..85eb70d5a0b 100644
--- a/contrib/python/aiohttp/aiohttp/web_protocol.py
+++ b/contrib/python/aiohttp/aiohttp/web_protocol.py
@@ -1,5 +1,6 @@
import asyncio
import asyncio.streams
+import sys
import traceback
import warnings
from collections import deque
@@ -26,7 +27,7 @@ import yarl
from .abc import AbstractAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
-from .helpers import ceil_timeout, set_exception
+from .helpers import ceil_timeout
from .http import (
HttpProcessingError,
HttpRequestParser,
@@ -37,7 +38,7 @@ from .http import (
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
-from .web_exceptions import HTTPException
+from .web_exceptions import HTTPException, HTTPInternalServerError
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse
@@ -83,6 +84,9 @@ class PayloadAccessError(Exception):
"""Payload was accessed after response was sent."""
+_PAYLOAD_ACCESS_ERROR = PayloadAccessError()
+
+
@attr.s(auto_attribs=True, frozen=True, slots=True)
class _ErrInfo:
status: int
@@ -133,8 +137,6 @@ class RequestHandler(BaseProtocol):
"""
- KEEPALIVE_RESCHEDULE_DELAY = 1
-
__slots__ = (
"_request_count",
"_keepalive",
@@ -142,12 +144,13 @@ class RequestHandler(BaseProtocol):
"_request_handler",
"_request_factory",
"_tcp_keepalive",
- "_keepalive_time",
+ "_next_keepalive_close_time",
"_keepalive_handle",
"_keepalive_timeout",
"_lingering_time",
"_messages",
"_message_tail",
+ "_handler_waiter",
"_waiter",
"_task_handler",
"_upgrade",
@@ -162,6 +165,7 @@ class RequestHandler(BaseProtocol):
"_force_close",
"_current_request",
"_timeout_ceil_threshold",
+ "_request_in_progress",
)
def __init__(
@@ -195,7 +199,7 @@ class RequestHandler(BaseProtocol):
self._tcp_keepalive = tcp_keepalive
# placeholder to be replaced on keepalive timeout setup
- self._keepalive_time = 0.0
+ self._next_keepalive_close_time = 0.0
self._keepalive_handle: Optional[asyncio.Handle] = None
self._keepalive_timeout = keepalive_timeout
self._lingering_time = float(lingering_time)
@@ -204,6 +208,7 @@ class RequestHandler(BaseProtocol):
self._message_tail = b""
self._waiter: Optional[asyncio.Future[None]] = None
+ self._handler_waiter: Optional[asyncio.Future[None]] = None
self._task_handler: Optional[asyncio.Task[None]] = None
self._upgrade = False
@@ -237,6 +242,7 @@ class RequestHandler(BaseProtocol):
self._close = False
self._force_close = False
+ self._request_in_progress = False
def __repr__(self) -> str:
return "<{} {}>".format(
@@ -259,25 +265,44 @@ class RequestHandler(BaseProtocol):
if self._keepalive_handle is not None:
self._keepalive_handle.cancel()
- if self._waiter:
- self._waiter.cancel()
-
- # wait for handlers
- with suppress(asyncio.CancelledError, asyncio.TimeoutError):
+ # Wait for graceful handler completion
+ if self._request_in_progress:
+ # The future is only created when we are shutting
+ # down while the handler is still processing a request
+ # to avoid creating a future for every request.
+ self._handler_waiter = self._loop.create_future()
+ try:
+ async with ceil_timeout(timeout):
+ await self._handler_waiter
+ except (asyncio.CancelledError, asyncio.TimeoutError):
+ self._handler_waiter = None
+ if (
+ sys.version_info >= (3, 11)
+ and (task := asyncio.current_task())
+ and task.cancelling()
+ ):
+ raise
+ # Then cancel handler and wait
+ try:
async with ceil_timeout(timeout):
if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())
if self._task_handler is not None and not self._task_handler.done():
- await self._task_handler
+ await asyncio.shield(self._task_handler)
+ except (asyncio.CancelledError, asyncio.TimeoutError):
+ if (
+ sys.version_info >= (3, 11)
+ and (task := asyncio.current_task())
+ and task.cancelling()
+ ):
+ raise
# force-close non-idle handler
if self._task_handler is not None:
self._task_handler.cancel()
- if self.transport is not None:
- self.transport.close()
- self.transport = None
+ self.force_close()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
super().connection_made(transport)
@@ -286,22 +311,27 @@ class RequestHandler(BaseProtocol):
if self._tcp_keepalive:
tcp_keepalive(real_transport)
- self._task_handler = self._loop.create_task(self.start())
assert self._manager is not None
self._manager.connection_made(self, real_transport)
+ loop = self._loop
+ if sys.version_info >= (3, 12):
+ task = asyncio.Task(self.start(), loop=loop, eager_start=True)
+ else:
+ task = loop.create_task(self.start())
+ self._task_handler = task
+
def connection_lost(self, exc: Optional[BaseException]) -> None:
if self._manager is None:
return
self._manager.connection_lost(self, exc)
- super().connection_lost(exc)
-
# Grab value before setting _manager to None.
handler_cancellation = self._manager.handler_cancellation
+ self.force_close()
+ super().connection_lost(exc)
self._manager = None
- self._force_close = True
self._request_factory = None
self._request_handler = None
self._request_parser = None
@@ -314,9 +344,6 @@ class RequestHandler(BaseProtocol):
exc = ConnectionResetError("Connection lost")
self._current_request._cancel(exc)
- if self._waiter is not None:
- self._waiter.cancel()
-
if handler_cancellation and self._task_handler is not None:
self._task_handler.cancel()
@@ -421,23 +448,21 @@ class RequestHandler(BaseProtocol):
self.logger.exception(*args, **kw)
def _process_keepalive(self) -> None:
+ self._keepalive_handle = None
if self._force_close or not self._keepalive:
return
- next = self._keepalive_time + self._keepalive_timeout
+ loop = self._loop
+ now = loop.time()
+ close_time = self._next_keepalive_close_time
+ if now <= close_time:
+ # Keep alive close check fired too early, reschedule
+ self._keepalive_handle = loop.call_at(close_time, self._process_keepalive)
+ return
# handler in idle state
- if self._waiter:
- if self._loop.time() > next:
- self.force_close()
- return
-
- # not all request handlers are done,
- # reschedule itself to next second
- self._keepalive_handle = self._loop.call_later(
- self.KEEPALIVE_RESCHEDULE_DELAY,
- self._process_keepalive,
- )
+ if self._waiter and not self._waiter.done():
+ self.force_close()
async def _handle_request(
self,
@@ -445,7 +470,7 @@ class RequestHandler(BaseProtocol):
start_time: float,
request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
) -> Tuple[StreamResponse, bool]:
- assert self._request_handler is not None
+ self._request_in_progress = True
try:
try:
self._current_request = request
@@ -454,16 +479,16 @@ class RequestHandler(BaseProtocol):
self._current_request = None
except HTTPException as exc:
resp = exc
- reset = await self.finish_response(request, resp, start_time)
+ resp, reset = await self.finish_response(request, resp, start_time)
except asyncio.CancelledError:
raise
except asyncio.TimeoutError as exc:
self.log_debug("Request handler timed out.", exc_info=exc)
resp = self.handle_error(request, 504)
- reset = await self.finish_response(request, resp, start_time)
+ resp, reset = await self.finish_response(request, resp, start_time)
except Exception as exc:
resp = self.handle_error(request, 500, exc)
- reset = await self.finish_response(request, resp, start_time)
+ resp, reset = await self.finish_response(request, resp, start_time)
else:
# Deprecation warning (See #2415)
if getattr(resp, "__http_exception__", False):
@@ -474,7 +499,11 @@ class RequestHandler(BaseProtocol):
DeprecationWarning,
)
- reset = await self.finish_response(request, resp, start_time)
+ resp, reset = await self.finish_response(request, resp, start_time)
+ finally:
+ self._request_in_progress = False
+ if self._handler_waiter is not None:
+ self._handler_waiter.set_result(None)
return resp, reset
@@ -488,7 +517,7 @@ class RequestHandler(BaseProtocol):
keep_alive(True) specified.
"""
loop = self._loop
- handler = self._task_handler
+ handler = asyncio.current_task(loop)
assert handler is not None
manager = self._manager
assert manager is not None
@@ -503,8 +532,6 @@ class RequestHandler(BaseProtocol):
# wait for next request
self._waiter = loop.create_future()
await self._waiter
- except asyncio.CancelledError:
- break
finally:
self._waiter = None
@@ -524,12 +551,14 @@ class RequestHandler(BaseProtocol):
request = self._request_factory(message, payload, self, writer, handler)
try:
# a new task is used for copy context vars (#3406)
- task = self._loop.create_task(
- self._handle_request(request, start, request_handler)
- )
+ coro = self._handle_request(request, start, request_handler)
+ if sys.version_info >= (3, 12):
+ task = asyncio.Task(coro, loop=loop, eager_start=True)
+ else:
+ task = loop.create_task(coro)
try:
resp, reset = await task
- except (asyncio.CancelledError, ConnectionError):
+ except ConnectionError:
self.log_debug("Ignored premature client disconnection")
break
@@ -553,27 +582,30 @@ class RequestHandler(BaseProtocol):
now = loop.time()
end_t = now + lingering_time
- with suppress(asyncio.TimeoutError, asyncio.CancelledError):
+ try:
while not payload.is_eof() and now < end_t:
async with ceil_timeout(end_t - now):
# read and ignore
await payload.readany()
now = loop.time()
+ except (asyncio.CancelledError, asyncio.TimeoutError):
+ if (
+ sys.version_info >= (3, 11)
+ and (t := asyncio.current_task())
+ and t.cancelling()
+ ):
+ raise
# if payload still uncompleted
if not payload.is_eof() and not self._force_close:
self.log_debug("Uncompleted request.")
self.close()
- set_exception(payload, PayloadAccessError())
+ payload.set_exception(_PAYLOAD_ACCESS_ERROR)
except asyncio.CancelledError:
- self.log_debug("Ignored premature client disconnection ")
- break
- except RuntimeError as exc:
- if self.debug:
- self.log_exception("Unhandled runtime exception", exc_info=exc)
- self.force_close()
+ self.log_debug("Ignored premature client disconnection")
+ raise
except Exception as exc:
self.log_exception("Unhandled exception", exc_info=exc)
self.force_close()
@@ -584,11 +616,12 @@ class RequestHandler(BaseProtocol):
if self._keepalive and not self._close:
# start keep-alive timer
if keepalive_timeout is not None:
- now = self._loop.time()
- self._keepalive_time = now
+ now = loop.time()
+ close_time = now + keepalive_timeout
+ self._next_keepalive_close_time = close_time
if self._keepalive_handle is None:
self._keepalive_handle = loop.call_at(
- now + keepalive_timeout, self._process_keepalive
+ close_time, self._process_keepalive
)
else:
break
@@ -601,7 +634,7 @@ class RequestHandler(BaseProtocol):
async def finish_response(
self, request: BaseRequest, resp: StreamResponse, start_time: float
- ) -> bool:
+ ) -> Tuple[StreamResponse, bool]:
"""Prepare the response and write_eof, then log access.
This has to
@@ -609,6 +642,7 @@ class RequestHandler(BaseProtocol):
can get exception information. Returns True if the client disconnects
prematurely.
"""
+ request._finish()
if self._request_parser is not None:
self._request_parser.set_upgraded(False)
self._upgrade = False
@@ -619,22 +653,26 @@ class RequestHandler(BaseProtocol):
prepare_meth = resp.prepare
except AttributeError:
if resp is None:
- raise RuntimeError("Missing return " "statement on request handler")
+ self.log_exception("Missing return statement on request handler")
else:
- raise RuntimeError(
- "Web-handler should return "
- "a response instance, "
+ self.log_exception(
+ "Web-handler should return a response instance, "
"got {!r}".format(resp)
)
+ exc = HTTPInternalServerError()
+ resp = Response(
+ status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
+ )
+ prepare_meth = resp.prepare
try:
await prepare_meth(request)
await resp.write_eof()
except ConnectionError:
self.log_access(request, resp, start_time)
- return True
- else:
- self.log_access(request, resp, start_time)
- return False
+ return resp, True
+
+ self.log_access(request, resp, start_time)
+ return resp, False
def handle_error(
self,
diff --git a/contrib/python/aiohttp/aiohttp/web_request.py b/contrib/python/aiohttp/aiohttp/web_request.py
index 4bc670a798c..eca71e4413a 100644
--- a/contrib/python/aiohttp/aiohttp/web_request.py
+++ b/contrib/python/aiohttp/aiohttp/web_request.py
@@ -79,7 +79,7 @@ class FileField:
filename: str
file: io.BufferedReader
content_type: str
- headers: "CIMultiDictProxy[str]"
+ headers: CIMultiDictProxy[str]
_TCHAR: Final[str] = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-"
@@ -99,10 +99,10 @@ _QUOTED_STRING: Final[str] = r'"(?:{quoted_pair}|{qdtext})*"'.format(
qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR
)
-_FORWARDED_PAIR: Final[
- str
-] = r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format(
- token=_TOKEN, quoted_string=_QUOTED_STRING
+_FORWARDED_PAIR: Final[str] = (
+ r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format(
+ token=_TOKEN, quoted_string=_QUOTED_STRING
+ )
)
_QUOTED_PAIR_REPLACE_RE: Final[Pattern[str]] = re.compile(r"\\([\t !-~])")
@@ -169,12 +169,16 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
self._payload_writer = payload_writer
self._payload = payload
- self._headers = message.headers
+ self._headers: CIMultiDictProxy[str] = message.headers
self._method = message.method
self._version = message.version
self._cache: Dict[str, Any] = {}
url = message.url
if url.is_absolute():
+ if scheme is not None:
+ url = url.with_scheme(scheme)
+ if host is not None:
+ url = url.with_host(host)
# absolute URL is given,
# override auto-calculating url, host, and scheme
# all other properties should be good
@@ -184,6 +188,10 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
self._rel_url = url.relative()
else:
self._rel_url = message.url
+ if scheme is not None:
+ self._cache["scheme"] = scheme
+ if host is not None:
+ self._cache["host"] = host
self._post: Optional[MultiDictProxy[Union[str, bytes, FileField]]] = None
self._read_bytes: Optional[bytes] = None
@@ -197,10 +205,6 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
self._transport_sslcontext = transport.get_extra_info("sslcontext")
self._transport_peername = transport.get_extra_info("peername")
- if scheme is not None:
- self._cache["scheme"] = scheme
- if host is not None:
- self._cache["host"] = host
if remote is not None:
self._cache["remote"] = remote
@@ -235,7 +239,8 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
# a copy semantic
dct["headers"] = CIMultiDictProxy(CIMultiDict(headers))
dct["raw_headers"] = tuple(
- (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
+ (k.encode("utf-8"), v.encode("utf-8"))
+ for k, v in dct["headers"].items()
)
message = self._message._replace(**dct)
@@ -481,7 +486,7 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
@reify
def query(self) -> "MultiMapping[str]":
"""A multidict with all the variables in the query string."""
- return MultiDictProxy(self._rel_url.query)
+ return self._rel_url.query
@reify
def query_string(self) -> str:
@@ -492,7 +497,7 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
return self._rel_url.query_string
@reify
- def headers(self) -> "MultiMapping[str]":
+ def headers(self) -> CIMultiDictProxy[str]:
"""A case-insensitive multidict proxy with all headers."""
return self._headers
@@ -819,6 +824,18 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
def _cancel(self, exc: BaseException) -> None:
set_exception(self._payload, exc)
+ def _finish(self) -> None:
+ if self._post is None or self.content_type != "multipart/form-data":
+ return
+
+ # NOTE: Release file descriptors for the
+ # NOTE: `tempfile.Temporaryfile`-created `_io.BufferedRandom`
+ # NOTE: instances of files sent within multipart request body
+ # NOTE: via HTTP POST request.
+ for file_name, file_field_object in self._post.items():
+ if isinstance(file_field_object, FileField):
+ file_field_object.file.close()
+
class Request(BaseRequest):
@@ -898,4 +915,5 @@ class Request(BaseRequest):
if match_info is None:
return
for app in match_info._apps:
- await app.on_response_prepare.send(self, response)
+ if on_response_prepare := app.on_response_prepare:
+ await on_response_prepare.send(self, response)
diff --git a/contrib/python/aiohttp/aiohttp/web_response.py b/contrib/python/aiohttp/aiohttp/web_response.py
index 40d6f01ecaa..4307b2a98c8 100644
--- a/contrib/python/aiohttp/aiohttp/web_response.py
+++ b/contrib/python/aiohttp/aiohttp/web_response.py
@@ -41,6 +41,8 @@ from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11
from .payload import Payload
from .typedefs import JSONEncoder, LooseHeaders
+REASON_PHRASES = {http_status.value: http_status.phrase for http_status in HTTPStatus}
+
__all__ = ("ContentCoding", "StreamResponse", "Response", "json_response")
@@ -52,6 +54,7 @@ else:
BaseClass = collections.abc.MutableMapping
+# TODO(py311): Convert to StrEnum for wider use
class ContentCoding(enum.Enum):
# The content codings that we have support for.
#
@@ -62,6 +65,8 @@ class ContentCoding(enum.Enum):
identity = "identity"
+CONTENT_CODINGS = {coding.value: coding for coding in ContentCoding}
+
############################################################
# HTTP Response classes
############################################################
@@ -71,6 +76,8 @@ class StreamResponse(BaseClass, HeadersMixin):
_length_check = True
+ _body: Union[None, bytes, bytearray, Payload]
+
def __init__(
self,
*,
@@ -97,11 +104,11 @@ class StreamResponse(BaseClass, HeadersMixin):
else:
self._headers = CIMultiDict()
- self.set_status(status, reason)
+ self._set_status(status, reason)
@property
def prepared(self) -> bool:
- return self._payload_writer is not None
+ return self._eof_sent or self._payload_writer is not None
@property
def task(self) -> "Optional[asyncio.Task[None]]":
@@ -131,15 +138,15 @@ class StreamResponse(BaseClass, HeadersMixin):
status: int,
reason: Optional[str] = None,
) -> None:
- assert not self.prepared, (
- "Cannot change the response status code after " "the headers have been sent"
- )
+ assert (
+ not self.prepared
+ ), "Cannot change the response status code after the headers have been sent"
+ self._set_status(status, reason)
+
+ def _set_status(self, status: int, reason: Optional[str]) -> None:
self._status = int(status)
if reason is None:
- try:
- reason = HTTPStatus(self._status).phrase
- except ValueError:
- reason = ""
+ reason = REASON_PHRASES.get(self._status, "")
self._reason = reason
@property
@@ -175,7 +182,7 @@ class StreamResponse(BaseClass, HeadersMixin):
) -> None:
"""Enables response compression encoding."""
# Backwards compatibility for when force was a bool <0.17.
- if type(force) == bool:
+ if isinstance(force, bool):
force = ContentCoding.deflate if force else ContentCoding.identity
warnings.warn(
"Using boolean for force is deprecated #3318", DeprecationWarning
@@ -403,8 +410,8 @@ class StreamResponse(BaseClass, HeadersMixin):
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
- for coding in ContentCoding:
- if coding.value in accept_encoding:
+ for value, coding in CONTENT_CODINGS.items():
+ if value in accept_encoding:
await self._do_start_compression(coding)
return
@@ -499,9 +506,7 @@ class StreamResponse(BaseClass, HeadersMixin):
assert writer is not None
# status line
version = request.version
- status_line = "HTTP/{}.{} {} {}".format(
- version[0], version[1], self._status, self._reason
- )
+ status_line = f"HTTP/{version[0]}.{version[1]} {self._status} {self._reason}"
await writer.write_headers(status_line, self._headers)
async def write(self, data: bytes) -> None:
@@ -650,21 +655,17 @@ class Response(StreamResponse):
return self._body
@body.setter
- def body(self, body: bytes) -> None:
+ def body(self, body: Any) -> None:
if body is None:
- self._body: Optional[bytes] = None
- self._body_payload: bool = False
+ self._body = None
elif isinstance(body, (bytes, bytearray)):
self._body = body
- self._body_payload = False
else:
try:
self._body = body = payload.PAYLOAD_REGISTRY.get(body)
except payload.LookupError:
raise ValueError("Unsupported body type %r" % type(body))
- self._body_payload = True
-
headers = self._headers
# set content-type
@@ -673,7 +674,7 @@ class Response(StreamResponse):
# copy payload headers
if body.headers:
- for (key, value) in body.headers.items():
+ for key, value in body.headers.items():
if key not in headers:
headers[key] = value
@@ -697,7 +698,6 @@ class Response(StreamResponse):
self.charset = "utf-8"
self._body = text.encode(self.charset)
- self._body_payload = False
self._compressed_body = None
@property
@@ -711,7 +711,7 @@ class Response(StreamResponse):
if self._compressed_body is not None:
# Return length of the compressed body
return len(self._compressed_body)
- elif self._body_payload:
+ elif isinstance(self._body, Payload):
# A payload without content length, or a compressed payload
return None
elif self._body is not None:
@@ -736,9 +736,8 @@ class Response(StreamResponse):
if body is not None:
if self._must_be_empty_body:
await super().write_eof()
- elif self._body_payload:
- payload = cast(Payload, body)
- await payload.write(self._payload_writer)
+ elif isinstance(self._body, Payload):
+ await self._body.write(self._payload_writer)
await super().write_eof()
else:
await super().write_eof(cast(bytes, body))
@@ -746,14 +745,13 @@ class Response(StreamResponse):
await super().write_eof()
async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
- if should_remove_content_length(request.method, self.status):
- if hdrs.CONTENT_LENGTH in self._headers:
+ if hdrs.CONTENT_LENGTH in self._headers:
+ if should_remove_content_length(request.method, self.status):
del self._headers[hdrs.CONTENT_LENGTH]
- elif not self._chunked and hdrs.CONTENT_LENGTH not in self._headers:
- if self._body_payload:
- size = cast(Payload, self._body).size
- if size is not None:
- self._headers[hdrs.CONTENT_LENGTH] = str(size)
+ elif not self._chunked:
+ if isinstance(self._body, Payload):
+ if self._body.size is not None:
+ self._headers[hdrs.CONTENT_LENGTH] = str(self._body.size)
else:
body_len = len(self._body) if self._body else "0"
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-7
@@ -765,7 +763,7 @@ class Response(StreamResponse):
return await super()._start(request)
async def _do_start_compression(self, coding: ContentCoding) -> None:
- if self._body_payload or self._chunked:
+ if self._chunked or isinstance(self._body, Payload):
return await super()._do_start_compression(coding)
if coding != ContentCoding.identity:
diff --git a/contrib/python/aiohttp/aiohttp/web_routedef.py b/contrib/python/aiohttp/aiohttp/web_routedef.py
index d79cd32a14a..93802141c56 100644
--- a/contrib/python/aiohttp/aiohttp/web_routedef.py
+++ b/contrib/python/aiohttp/aiohttp/web_routedef.py
@@ -162,12 +162,10 @@ class RouteTableDef(Sequence[AbstractRouteDef]):
return f"<RouteTableDef count={len(self._items)}>"
@overload
- def __getitem__(self, index: int) -> AbstractRouteDef:
- ...
+ def __getitem__(self, index: int) -> AbstractRouteDef: ...
@overload
- def __getitem__(self, index: slice) -> List[AbstractRouteDef]:
- ...
+ def __getitem__(self, index: slice) -> List[AbstractRouteDef]: ...
def __getitem__(self, index): # type: ignore[no-untyped-def]
return self._items[index]
diff --git a/contrib/python/aiohttp/aiohttp/web_runner.py b/contrib/python/aiohttp/aiohttp/web_runner.py
index 19a4441658f..0a237ede2c5 100644
--- a/contrib/python/aiohttp/aiohttp/web_runner.py
+++ b/contrib/python/aiohttp/aiohttp/web_runner.py
@@ -3,7 +3,7 @@ import signal
import socket
import warnings
from abc import ABC, abstractmethod
-from typing import Any, Awaitable, Callable, List, Optional, Set
+from typing import Any, List, Optional, Set
from yarl import URL
@@ -108,7 +108,7 @@ class TCPSite(BaseSite):
@property
def name(self) -> str:
scheme = "https" if self._ssl_context else "http"
- host = "0.0.0.0" if self._host is None else self._host
+ host = "0.0.0.0" if not self._host else self._host
return str(URL.build(scheme=scheme, host=host, port=self._port))
async def start(self) -> None:
@@ -238,14 +238,7 @@ class SockSite(BaseSite):
class BaseRunner(ABC):
- __slots__ = (
- "shutdown_callback",
- "_handle_signals",
- "_kwargs",
- "_server",
- "_sites",
- "_shutdown_timeout",
- )
+ __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")
def __init__(
self,
@@ -254,7 +247,6 @@ class BaseRunner(ABC):
shutdown_timeout: float = 60.0,
**kwargs: Any,
) -> None:
- self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None
self._handle_signals = handle_signals
self._kwargs = kwargs
self._server: Optional[Server] = None
@@ -312,10 +304,6 @@ class BaseRunner(ABC):
await asyncio.sleep(0)
self._server.pre_shutdown()
await self.shutdown()
-
- if self.shutdown_callback:
- await self.shutdown_callback()
-
await self._server.shutdown(self._shutdown_timeout)
await self._cleanup_server()
diff --git a/contrib/python/aiohttp/aiohttp/web_server.py b/contrib/python/aiohttp/aiohttp/web_server.py
index 52faacb164a..973e7c15440 100644
--- a/contrib/python/aiohttp/aiohttp/web_server.py
+++ b/contrib/python/aiohttp/aiohttp/web_server.py
@@ -1,9 +1,9 @@
"""Low level HTTP server."""
+
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa
from .abc import AbstractStreamWriter
-from .helpers import get_running_loop
from .http_parser import RawRequestMessage
from .streams import StreamReader
from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler
@@ -22,7 +22,7 @@ class Server:
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs: Any
) -> None:
- self._loop = get_running_loop(loop)
+ self._loop = loop or asyncio.get_event_loop()
self._connections: Dict[RequestHandler, asyncio.Transport] = {}
self._kwargs = kwargs
self.requests_count = 0
@@ -43,7 +43,12 @@ class Server:
self, handler: RequestHandler, exc: Optional[BaseException] = None
) -> None:
if handler in self._connections:
- del self._connections[handler]
+ if handler._task_handler:
+ handler._task_handler.add_done_callback(
+ lambda f: self._connections.pop(handler, None)
+ )
+ else:
+ del self._connections[handler]
def _make_request(
self,
diff --git a/contrib/python/aiohttp/aiohttp/web_urldispatcher.py b/contrib/python/aiohttp/aiohttp/web_urldispatcher.py
index 954291f6449..89abdc43fa6 100644
--- a/contrib/python/aiohttp/aiohttp/web_urldispatcher.py
+++ b/contrib/python/aiohttp/aiohttp/web_urldispatcher.py
@@ -8,8 +8,8 @@ import inspect
import keyword
import os
import re
+import sys
import warnings
-from contextlib import contextmanager
from functools import wraps
from pathlib import Path
from types import MappingProxyType
@@ -38,7 +38,7 @@ from typing import (
cast,
)
-from yarl import URL, __version__ as yarl_version # type: ignore[attr-defined]
+from yarl import URL, __version__ as yarl_version
from . import hdrs
from .abc import AbstractMatchInfo, AbstractRouter, AbstractView
@@ -78,6 +78,12 @@ if TYPE_CHECKING:
else:
BaseDict = dict
+CIRCULAR_SYMLINK_ERROR = (
+ (OSError,)
+ if sys.version_info < (3, 10) and sys.platform.startswith("win32")
+ else (RuntimeError,) if sys.version_info < (3, 13) else ()
+)
+
YARL_VERSION: Final[Tuple[int, ...]] = tuple(map(int, yarl_version.split(".")[:2]))
HTTP_METHOD_RE: Final[Pattern[str]] = re.compile(
@@ -199,7 +205,7 @@ class AbstractRoute(abc.ABC):
@wraps(handler)
async def handler_wrapper(request: Request) -> StreamResponse:
- result = old_handler(request)
+ result = old_handler(request) # type: ignore[call-arg]
if asyncio.iscoroutine(result):
result = await result
assert isinstance(result, StreamResponse)
@@ -286,8 +292,8 @@ class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo):
assert app is not None
return app
- @contextmanager
- def set_current_app(self, app: "Application") -> Generator[None, None, None]:
+ @current_app.setter
+ def current_app(self, app: "Application") -> None:
if DEBUG: # pragma: no cover
if app not in self._apps:
raise RuntimeError(
@@ -295,12 +301,7 @@ class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo):
self._apps, app
)
)
- prev = self._current_app
self._current_app = app
- try:
- yield
- finally:
- self._current_app = prev
def freeze(self) -> None:
self._frozen = True
@@ -334,6 +335,8 @@ async def _default_expect_handler(request: Request) -> None:
if request.version == HttpVersion11:
if expect.lower() == "100-continue":
await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")
+ # Reset output_size as we haven't started the main body yet.
+ request.writer.output_size = 0
else:
raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect)
@@ -372,7 +375,7 @@ class Resource(AbstractResource):
async def resolve(self, request: Request) -> _Resolve:
allowed_methods: Set[str] = set()
- match_dict = self._match(request.rel_url.raw_path)
+ match_dict = self._match(request.rel_url.path_safe)
if match_dict is None:
return None, allowed_methods
@@ -422,8 +425,7 @@ class PlainResource(Resource):
# string comparison is about 10 times faster than regexp matching
if self._path == path:
return {}
- else:
- return None
+ return None
def raw_match(self, path: str) -> bool:
return self._path == path
@@ -447,6 +449,7 @@ class DynamicResource(Resource):
def __init__(self, path: str, *, name: Optional[str] = None) -> None:
super().__init__(name=name)
+ self._orig_path = path
pattern = ""
formatter = ""
for part in ROUTE_RE.split(path):
@@ -493,13 +496,12 @@ class DynamicResource(Resource):
match = self._pattern.fullmatch(path)
if match is None:
return None
- else:
- return {
- key: _unquote_path(value) for key, value in match.groupdict().items()
- }
+ return {
+ key: _unquote_path_safe(value) for key, value in match.groupdict().items()
+ }
def raw_match(self, path: str) -> bool:
- return self._formatter == path
+ return self._orig_path == path
def get_info(self) -> _InfoDict:
return {"formatter": self._formatter, "pattern": self._pattern}
@@ -557,14 +559,11 @@ class StaticResource(PrefixResource):
) -> None:
super().__init__(prefix, name=name)
try:
- directory = Path(directory)
- if str(directory).startswith("~"):
- directory = Path(os.path.expanduser(str(directory)))
- directory = directory.resolve()
- if not directory.is_dir():
- raise ValueError("Not a directory")
- except (FileNotFoundError, ValueError) as error:
- raise ValueError(f"No directory exists at '{directory}'") from error
+ directory = Path(directory).expanduser().resolve(strict=True)
+ except FileNotFoundError as error:
+ raise ValueError(f"'{directory}' does not exist") from error
+ if not directory.is_dir():
+ raise ValueError(f"'{directory}' is not a directory")
self._directory = directory
self._show_index = show_index
self._chunk_size = chunk_size
@@ -644,7 +643,7 @@ class StaticResource(PrefixResource):
)
async def resolve(self, request: Request) -> _Resolve:
- path = request.rel_url.raw_path
+ path = request.rel_url.path_safe
method = request.method
allowed_methods = set(self._routes)
if not path.startswith(self._prefix2) and path != self._prefix:
@@ -653,7 +652,7 @@ class StaticResource(PrefixResource):
if method not in allowed_methods:
return None, allowed_methods
- match_dict = {"filename": _unquote_path(path[len(self._prefix) + 1 :])}
+ match_dict = {"filename": _unquote_path_safe(path[len(self._prefix) + 1 :])}
return (UrlMappingMatchInfo(match_dict, self._routes[method]), allowed_methods)
def __len__(self) -> int:
@@ -664,59 +663,64 @@ class StaticResource(PrefixResource):
async def _handle(self, request: Request) -> StreamResponse:
rel_url = request.match_info["filename"]
+ filename = Path(rel_url)
+ if filename.anchor:
+ # rel_url is an absolute name like
+ # /static/\\machine_name\c$ or /static/D:\path
+ # where the static dir is totally different
+ raise HTTPForbidden()
+
+ unresolved_path = self._directory.joinpath(filename)
+ loop = asyncio.get_running_loop()
+ return await loop.run_in_executor(
+ None, self._resolve_path_to_response, unresolved_path
+ )
+
+ def _resolve_path_to_response(self, unresolved_path: Path) -> StreamResponse:
+ """Take the unresolved path and query the file system to form a response."""
+ # Check for access outside the root directory. For follow symlinks, URI
+ # cannot traverse out, but symlinks can. Otherwise, no access outside
+ # root is permitted.
try:
- filename = Path(rel_url)
- if filename.anchor:
- # rel_url is an absolute name like
- # /static/\\machine_name\c$ or /static/D:\path
- # where the static dir is totally different
- raise HTTPForbidden()
- unresolved_path = self._directory.joinpath(filename)
if self._follow_symlinks:
normalized_path = Path(os.path.normpath(unresolved_path))
normalized_path.relative_to(self._directory)
- filepath = normalized_path.resolve()
+ file_path = normalized_path.resolve()
else:
- filepath = unresolved_path.resolve()
- filepath.relative_to(self._directory)
- except (ValueError, FileNotFoundError) as error:
- # relatively safe
- raise HTTPNotFound() from error
- except HTTPForbidden:
- raise
- except Exception as error:
- # perm error or other kind!
- request.app.logger.exception(error)
+ file_path = unresolved_path.resolve()
+ file_path.relative_to(self._directory)
+ except (ValueError, *CIRCULAR_SYMLINK_ERROR) as error:
+ # ValueError is raised for the relative check. Circular symlinks
+ # raise here on resolving for python < 3.13.
raise HTTPNotFound() from error
- # on opening a dir, load its contents if allowed
- if filepath.is_dir():
- if self._show_index:
- try:
+ # if path is a directory, return the contents if permitted. Note the
+ # directory check will raise if a segment is not readable.
+ try:
+ if file_path.is_dir():
+ if self._show_index:
return Response(
- text=self._directory_as_html(filepath), content_type="text/html"
+ text=self._directory_as_html(file_path),
+ content_type="text/html",
)
- except PermissionError:
+ else:
raise HTTPForbidden()
- else:
- raise HTTPForbidden()
- elif filepath.is_file():
- return FileResponse(filepath, chunk_size=self._chunk_size)
- else:
- raise HTTPNotFound
+ except PermissionError as error:
+ raise HTTPForbidden() from error
- def _directory_as_html(self, filepath: Path) -> str:
- # returns directory's index as html
+ # Return the file response, which handles all other checks.
+ return FileResponse(file_path, chunk_size=self._chunk_size)
- # sanity check
- assert filepath.is_dir()
+ def _directory_as_html(self, dir_path: Path) -> str:
+ """returns directory's index as html."""
+ assert dir_path.is_dir()
- relative_path_to_dir = filepath.relative_to(self._directory).as_posix()
+ relative_path_to_dir = dir_path.relative_to(self._directory).as_posix()
index_of = f"Index of /{html_escape(relative_path_to_dir)}"
h1 = f"<h1>{index_of}</h1>"
index_list = []
- dir_index = filepath.iterdir()
+ dir_index = dir_path.iterdir()
for _file in sorted(dir_index):
# show file url as relative to static path
rel_path = _file.relative_to(self._directory).as_posix()
@@ -750,13 +754,20 @@ class PrefixedSubAppResource(PrefixResource):
def __init__(self, prefix: str, app: "Application") -> None:
super().__init__(prefix)
self._app = app
- for resource in app.router.resources():
- resource.add_prefix(prefix)
+ self._add_prefix_to_resources(prefix)
def add_prefix(self, prefix: str) -> None:
super().add_prefix(prefix)
- for resource in self._app.router.resources():
+ self._add_prefix_to_resources(prefix)
+
+ def _add_prefix_to_resources(self, prefix: str) -> None:
+ router = self._app.router
+ for resource in router.resources():
+ # Since the canonical path of a resource is about
+ # to change, we need to unindex it and then reindex
+ router.unindex_resource(resource)
resource.add_prefix(prefix)
+ router.index_resource(resource)
def url_for(self, *args: str, **kwargs: str) -> URL:
raise RuntimeError(".url_for() is not supported " "by sub-application root")
@@ -765,11 +776,6 @@ class PrefixedSubAppResource(PrefixResource):
return {"app": self._app, "prefix": self._prefix}
async def resolve(self, request: Request) -> _Resolve:
- if (
- not request.url.raw_path.startswith(self._prefix2)
- and request.url.raw_path != self._prefix
- ):
- return None, set()
match_info = await self._app.router.resolve(request)
match_info.add_app(self._app)
if isinstance(match_info.http_exception, HTTPMethodNotAllowed):
@@ -1015,12 +1021,39 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]):
super().__init__()
self._resources: List[AbstractResource] = []
self._named_resources: Dict[str, AbstractResource] = {}
+ self._resource_index: dict[str, list[AbstractResource]] = {}
+ self._matched_sub_app_resources: List[MatchedSubAppResource] = []
async def resolve(self, request: Request) -> UrlMappingMatchInfo:
- method = request.method
+ resource_index = self._resource_index
allowed_methods: Set[str] = set()
- for resource in self._resources:
+ # Walk the url parts looking for candidates. We walk the url backwards
+ # to ensure the most explicit match is found first. If there are multiple
+ # candidates for a given url part because there are multiple resources
+ # registered for the same canonical path, we resolve them in a linear
+ # fashion to ensure registration order is respected.
+ url_part = request.rel_url.path_safe
+ while url_part:
+ for candidate in resource_index.get(url_part, ()):
+ match_dict, allowed = await candidate.resolve(request)
+ if match_dict is not None:
+ return match_dict
+ else:
+ allowed_methods |= allowed
+ if url_part == "/":
+ break
+ url_part = url_part.rpartition("/")[0] or "/"
+
+ #
+ # We didn't find any candidates, so we'll try the matched sub-app
+ # resources which we have to walk in a linear fashion because they
+ # have regex/wildcard match rules and we cannot index them.
+ #
+ # For most cases we do not expect there to be many of these since
+ # currently they are only added by `add_domain`
+ #
+ for resource in self._matched_sub_app_resources:
match_dict, allowed = await resource.resolve(request)
if match_dict is not None:
return match_dict
@@ -1028,9 +1061,9 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]):
allowed_methods |= allowed
if allowed_methods:
- return MatchInfoError(HTTPMethodNotAllowed(method, allowed_methods))
- else:
- return MatchInfoError(HTTPNotFound())
+ return MatchInfoError(HTTPMethodNotAllowed(request.method, allowed_methods))
+
+ return MatchInfoError(HTTPNotFound())
def __iter__(self) -> Iterator[str]:
return iter(self._named_resources)
@@ -1086,6 +1119,36 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]):
self._named_resources[name] = resource
self._resources.append(resource)
+ if isinstance(resource, MatchedSubAppResource):
+ # We cannot index match sub-app resources because they have match rules
+ self._matched_sub_app_resources.append(resource)
+ else:
+ self.index_resource(resource)
+
+ def _get_resource_index_key(self, resource: AbstractResource) -> str:
+ """Return a key to index the resource in the resource index."""
+ if "{" in (index_key := resource.canonical):
+ # strip at the first { to allow for variables, and than
+ # rpartition at / to allow for variable parts in the path
+ # For example if the canonical path is `/core/locations{tail:.*}`
+ # the index key will be `/core` since index is based on the
+ # url parts split by `/`
+ index_key = index_key.partition("{")[0].rpartition("/")[0]
+ return index_key.rstrip("/") or "/"
+
+ def index_resource(self, resource: AbstractResource) -> None:
+ """Add a resource to the resource index."""
+ resource_key = self._get_resource_index_key(resource)
+ # There may be multiple resources for a canonical path
+ # so we keep them in a list to ensure that registration
+ # order is respected.
+ self._resource_index.setdefault(resource_key, []).append(resource)
+
+ def unindex_resource(self, resource: AbstractResource) -> None:
+ """Remove a resource from the resource index."""
+ resource_key = self._get_resource_index_key(resource)
+ self._resource_index[resource_key].remove(resource)
+
def add_resource(self, path: str, *, name: Optional[str] = None) -> Resource:
if path and not path.startswith("/"):
raise ValueError("path should be started with / or be empty")
@@ -1095,7 +1158,7 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]):
if resource.name == name and resource.raw_match(path):
return cast(Resource, resource)
if not ("{" in path or "}" in path or ROUTE_RE.search(path)):
- resource = PlainResource(_requote_path(path), name=name)
+ resource = PlainResource(path, name=name)
self.register_resource(resource)
return resource
resource = DynamicResource(path, name=name)
@@ -1221,8 +1284,10 @@ def _quote_path(value: str) -> str:
return URL.build(path=value, encoded=False).raw_path
-def _unquote_path(value: str) -> str:
- return URL.build(path=value, encoded=True).path
+def _unquote_path_safe(value: str) -> str:
+ if "%" not in value:
+ return value
+ return value.replace("%2F", "/").replace("%25", "%")
def _requote_path(value: str) -> str:
diff --git a/contrib/python/aiohttp/aiohttp/web_ws.py b/contrib/python/aiohttp/aiohttp/web_ws.py
index 9fe66527539..382223097ea 100644
--- a/contrib/python/aiohttp/aiohttp/web_ws.py
+++ b/contrib/python/aiohttp/aiohttp/web_ws.py
@@ -11,7 +11,7 @@ from multidict import CIMultiDict
from . import hdrs
from .abc import AbstractStreamWriter
-from .helpers import call_later, set_exception, set_result
+from .helpers import calculate_timeout_when, set_exception, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
@@ -81,67 +81,119 @@ class WebSocketResponse(StreamResponse):
self._conn_lost = 0
self._close_code: Optional[int] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
- self._waiting: Optional[asyncio.Future[bool]] = None
+ self._waiting: bool = False
+ self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._timeout = timeout
self._receive_timeout = receive_timeout
self._autoclose = autoclose
self._autoping = autoping
self._heartbeat = heartbeat
+ self._heartbeat_when = 0.0
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._compress = compress
self._max_msg_size = max_msg_size
+ self._ping_task: Optional[asyncio.Task[None]] = None
def _cancel_heartbeat(self) -> None:
- if self._pong_response_cb is not None:
- self._pong_response_cb.cancel()
- self._pong_response_cb = None
-
+ self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
+ if self._ping_task is not None:
+ self._ping_task.cancel()
+ self._ping_task = None
+
+ def _cancel_pong_response_cb(self) -> None:
+ if self._pong_response_cb is not None:
+ self._pong_response_cb.cancel()
+ self._pong_response_cb = None
def _reset_heartbeat(self) -> None:
- self._cancel_heartbeat()
+ if self._heartbeat is None:
+ return
+ self._cancel_pong_response_cb()
+ req = self._req
+ timeout_ceil_threshold = (
+ req._protocol._timeout_ceil_threshold if req is not None else 5
+ )
+ loop = self._loop
+ assert loop is not None
+ now = loop.time()
+ when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
+ self._heartbeat_when = when
+ if self._heartbeat_cb is None:
+ # We do not cancel the previous heartbeat_cb here because
+ # it generates a significant amount of TimerHandle churn
+ # which causes asyncio to rebuild the heap frequently.
+ # Instead _send_heartbeat() will reschedule the next
+ # heartbeat if it fires too early.
+ self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
- if self._heartbeat is not None:
- assert self._loop is not None
- self._heartbeat_cb = call_later(
- self._send_heartbeat,
- self._heartbeat,
- self._loop,
- timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold
- if self._req is not None
- else 5,
+ def _send_heartbeat(self) -> None:
+ self._heartbeat_cb = None
+ loop = self._loop
+ assert loop is not None and self._writer is not None
+ now = loop.time()
+ if now < self._heartbeat_when:
+ # Heartbeat fired too early, reschedule
+ self._heartbeat_cb = loop.call_at(
+ self._heartbeat_when, self._send_heartbeat
)
+ return
- def _send_heartbeat(self) -> None:
- if self._heartbeat is not None and not self._closed:
- assert self._loop is not None
- # fire-and-forget a task is not perfect but maybe ok for
- # sending ping. Otherwise we need a long-living heartbeat
- # task in the class.
- self._loop.create_task(self._writer.ping()) # type: ignore[union-attr]
+ req = self._req
+ timeout_ceil_threshold = (
+ req._protocol._timeout_ceil_threshold if req is not None else 5
+ )
+ when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
+ self._cancel_pong_response_cb()
+ self._pong_response_cb = loop.call_at(when, self._pong_not_received)
- if self._pong_response_cb is not None:
- self._pong_response_cb.cancel()
- self._pong_response_cb = call_later(
- self._pong_not_received,
- self._pong_heartbeat,
- self._loop,
- timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold
- if self._req is not None
- else 5,
- )
+ if sys.version_info >= (3, 12):
+ # Optimization for Python 3.12, try to send the ping
+ # immediately to avoid having to schedule
+ # the task on the event loop.
+ ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
+ else:
+ ping_task = loop.create_task(self._writer.ping())
+
+ if not ping_task.done():
+ self._ping_task = ping_task
+ ping_task.add_done_callback(self._ping_task_done)
+ else:
+ self._ping_task_done(ping_task)
+
+ def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
+ """Callback for when the ping task completes."""
+ if not task.cancelled() and (exc := task.exception()):
+ self._handle_ping_pong_exception(exc)
+ self._ping_task = None
def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
- self._closed = True
- self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
- self._exception = asyncio.TimeoutError()
+ self._handle_ping_pong_exception(asyncio.TimeoutError())
+
+ def _handle_ping_pong_exception(self, exc: BaseException) -> None:
+ """Handle exceptions raised during ping/pong processing."""
+ if self._closed:
+ return
+ self._set_closed()
+ self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
+ self._exception = exc
+ if self._waiting and not self._closing and self._reader is not None:
+ self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))
+
+ def _set_closed(self) -> None:
+ """Set the connection to closed.
+
+ Cancel any heartbeat timers and set the closed flag.
+ """
+ self._closed = True
+ self._cancel_heartbeat()
async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
# make pre-check to don't hide it by do_handshake() exceptions
@@ -366,20 +418,10 @@ class WebSocketResponse(StreamResponse):
if self._writer is None:
raise RuntimeError("Call .prepare() first")
- self._cancel_heartbeat()
- reader = self._reader
- assert reader is not None
-
- # we need to break `receive()` cycle first,
- # `close()` may be called from different task
- if self._waiting is not None and not self._closed:
- reader.feed_data(WS_CLOSING_MESSAGE, 0)
- await self._waiting
-
if self._closed:
return False
+ self._set_closed()
- self._closed = True
try:
await self._writer.close(code, message)
writer = self._payload_writer
@@ -394,12 +436,21 @@ class WebSocketResponse(StreamResponse):
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True
+ reader = self._reader
+ assert reader is not None
+ # we need to break `receive()` cycle before we can call
+ # `reader.read()` as `close()` may be called from different task
+ if self._waiting:
+ assert self._loop is not None
+ assert self._close_wait is None
+ self._close_wait = self._loop.create_future()
+ reader.feed_data(WS_CLOSING_MESSAGE)
+ await self._close_wait
+
if self._closing:
self._close_transport()
return True
- reader = self._reader
- assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
@@ -411,7 +462,7 @@ class WebSocketResponse(StreamResponse):
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True
- if msg.type == WSMsgType.CLOSE:
+ if msg.type is WSMsgType.CLOSE:
self._set_code_close_transport(msg.data)
return True
@@ -423,6 +474,7 @@ class WebSocketResponse(StreamResponse):
"""Set the close code and mark the connection as closing."""
self._closing = True
self._close_code = code
+ self._cancel_heartbeat()
def _set_code_close_transport(self, code: WSCloseCode) -> None:
"""Set the close code and close the transport."""
@@ -440,8 +492,9 @@ class WebSocketResponse(StreamResponse):
loop = self._loop
assert loop is not None
+ receive_timeout = timeout or self._receive_timeout
while True:
- if self._waiting is not None:
+ if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")
if self._closed:
@@ -453,15 +506,22 @@ class WebSocketResponse(StreamResponse):
return WS_CLOSING_MESSAGE
try:
- self._waiting = loop.create_future()
+ self._waiting = True
try:
- async with async_timeout.timeout(timeout or self._receive_timeout):
+ if receive_timeout:
+ # Entering the context manager and creating
+ # Timeout() object can take almost 50% of the
+ # run time in this loop so we avoid it if
+ # there is no read timeout.
+ async with async_timeout.timeout(receive_timeout):
+ msg = await self._reader.read()
+ else:
msg = await self._reader.read()
self._reset_heartbeat()
finally:
- waiter = self._waiting
- set_result(waiter, True)
- self._waiting = None
+ self._waiting = False
+ if self._close_wait:
+ set_result(self._close_wait, None)
except asyncio.TimeoutError:
raise
except EofStream:
@@ -478,7 +538,7 @@ class WebSocketResponse(StreamResponse):
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
- if msg.type == WSMsgType.CLOSE:
+ if msg.type is WSMsgType.CLOSE:
self._set_closing(msg.data)
# Could be closed while awaiting reader.
if not self._closed and self._autoclose:
@@ -487,19 +547,19 @@ class WebSocketResponse(StreamResponse):
# want to drain any pending writes as it will
# likely result writing to a broken pipe.
await self.close(drain=False)
- elif msg.type == WSMsgType.CLOSING:
+ elif msg.type is WSMsgType.CLOSING:
self._set_closing(WSCloseCode.OK)
- elif msg.type == WSMsgType.PING and self._autoping:
+ elif msg.type is WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
- elif msg.type == WSMsgType.PONG and self._autoping:
+ elif msg.type is WSMsgType.PONG and self._autoping:
continue
return msg
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
- if msg.type != WSMsgType.TEXT:
+ if msg.type is not WSMsgType.TEXT:
raise TypeError(
"Received message {}:{!r} is not WSMsgType.TEXT".format(
msg.type, msg.data
@@ -509,7 +569,7 @@ class WebSocketResponse(StreamResponse):
async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
- if msg.type != WSMsgType.BINARY:
+ if msg.type is not WSMsgType.BINARY:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
return cast(bytes, msg.data)
@@ -535,5 +595,6 @@ class WebSocketResponse(StreamResponse):
# web_protocol calls this from connection_lost
# or when the server is shutting down.
self._closing = True
+ self._cancel_heartbeat()
if self._reader is not None:
set_exception(self._reader, exc)
diff --git a/contrib/python/aiohttp/patches/04-force-content-type.patch b/contrib/python/aiohttp/patches/04-force-content-type.patch
new file mode 100644
index 00000000000..44569413307
--- /dev/null
+++ b/contrib/python/aiohttp/patches/04-force-content-type.patch
@@ -0,0 +1,12 @@
+--- contrib/python/aiohttp/aiohttp/web_response.py (ddcb92de87597ba3c0a8961e7fdf04a184c227ce)
++++ contrib/python/aiohttp/aiohttp/web_response.py (0978c4fe84e8994e041f045b1447dd8058efa52c)
+@@ -487,8 +487,7 @@ class StreamResponse(BaseClass, HeadersMixin):
+ # https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-13
+ if hdrs.TRANSFER_ENCODING in headers:
+ del headers[hdrs.TRANSFER_ENCODING]
+- elif self.content_length != 0:
+- # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
++ else:
+ headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream")
+ headers.setdefault(hdrs.DATE, rfc822_formatted_time())
+ headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE)
diff --git a/contrib/python/aiohttp/patches/04-pr11265-support-aiosignal-1.4.patch b/contrib/python/aiohttp/patches/04-pr11265-support-aiosignal-1.4.patch
index 0bfc1cc3a87..3b478e3aca5 100644
--- a/contrib/python/aiohttp/patches/04-pr11265-support-aiosignal-1.4.patch
+++ b/contrib/python/aiohttp/patches/04-pr11265-support-aiosignal-1.4.patch
@@ -1,6 +1,6 @@
--- contrib/python/aiohttp/aiohttp/tracing.py (index)
+++ contrib/python/aiohttp/aiohttp/tracing.py (working tree)
-@@ -12,15 +12,7 @@ if TYPE_CHECKING:
+@@ -12,14 +12,7 @@ if TYPE_CHECKING:
from .client import ClientSession
_ParamT_contra = TypeVar("_ParamT_contra", contravariant=True)
@@ -11,98 +11,83 @@
- __client_session: ClientSession,
- __trace_config_ctx: SimpleNamespace,
- __params: _ParamT_contra,
-- ) -> Awaitable[None]:
-- ...
+- ) -> Awaitable[None]: ...
+ _TracingSignal = Signal[ClientSession, SimpleNamespace, _ParamT_contra]
__all__ = (
-@@ -50,53 +42,53 @@ class TraceConfig:
+@@ -49,54 +42,24 @@ class TraceConfig:
def __init__(
self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace
) -> None:
-- self._on_request_start: Signal[
-- _SignalCallback[TraceRequestStartParams]
-+ self._on_request_start: _TracingSignal[
-+ TraceRequestStartParams
- ] = Signal(self)
+- self._on_request_start: Signal[_SignalCallback[TraceRequestStartParams]] = (
++ self._on_request_start: _TracingSignal[TraceRequestStartParams] = (
+ Signal(self)
+ )
- self._on_request_chunk_sent: Signal[
- _SignalCallback[TraceRequestChunkSentParams]
-+ self._on_request_chunk_sent: _TracingSignal[
-+ TraceRequestChunkSentParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_response_chunk_received: Signal[
- _SignalCallback[TraceResponseChunkReceivedParams]
-+ self._on_response_chunk_received: _TracingSignal[
-+ TraceResponseChunkReceivedParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_request_end: Signal[_SignalCallback[TraceRequestEndParams]] = Signal(
-+ self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal(
- self
- )
+- self
+- )
- self._on_request_exception: Signal[
- _SignalCallback[TraceRequestExceptionParams]
-+ self._on_request_exception: _TracingSignal[
-+ TraceRequestExceptionParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_request_redirect: Signal[
- _SignalCallback[TraceRequestRedirectParams]
-+ self._on_request_redirect: _TracingSignal[
-+ TraceRequestRedirectParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_connection_queued_start: Signal[
- _SignalCallback[TraceConnectionQueuedStartParams]
-+ self._on_connection_queued_start: _TracingSignal[
-+ TraceConnectionQueuedStartParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_connection_queued_end: Signal[
- _SignalCallback[TraceConnectionQueuedEndParams]
-+ self._on_connection_queued_end: _TracingSignal[
-+ TraceConnectionQueuedEndParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_connection_create_start: Signal[
- _SignalCallback[TraceConnectionCreateStartParams]
-+ self._on_connection_create_start: _TracingSignal[
-+ TraceConnectionCreateStartParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_connection_create_end: Signal[
- _SignalCallback[TraceConnectionCreateEndParams]
-+ self._on_connection_create_end: _TracingSignal[
-+ TraceConnectionCreateEndParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_connection_reuseconn: Signal[
- _SignalCallback[TraceConnectionReuseconnParams]
-+ self._on_connection_reuseconn: _TracingSignal[
-+ TraceConnectionReuseconnParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_dns_resolvehost_start: Signal[
- _SignalCallback[TraceDnsResolveHostStartParams]
-+ self._on_dns_resolvehost_start: _TracingSignal[
-+ TraceDnsResolveHostStartParams
- ] = Signal(self)
+- ] = Signal(self)
- self._on_dns_resolvehost_end: Signal[
- _SignalCallback[TraceDnsResolveHostEndParams]
-+ self._on_dns_resolvehost_end: _TracingSignal[
-+ TraceDnsResolveHostEndParams
- ] = Signal(self)
-- self._on_dns_cache_hit: Signal[
-- _SignalCallback[TraceDnsCacheHitParams]
-+ self._on_dns_cache_hit: _TracingSignal[
-+ TraceDnsCacheHitParams
- ] = Signal(self)
-- self._on_dns_cache_miss: Signal[
-- _SignalCallback[TraceDnsCacheMissParams]
-+ self._on_dns_cache_miss: _TracingSignal[
-+ TraceDnsCacheMissParams
- ] = Signal(self)
+- ] = Signal(self)
+- self._on_dns_cache_hit: Signal[_SignalCallback[TraceDnsCacheHitParams]] = (
+- Signal(self)
+- )
+- self._on_dns_cache_miss: Signal[_SignalCallback[TraceDnsCacheMissParams]] = (
+- Signal(self)
+- )
- self._on_request_headers_sent: Signal[
- _SignalCallback[TraceRequestHeadersSentParams]
-+ self._on_request_headers_sent: _TracingSignal[
-+ TraceRequestHeadersSentParams
- ] = Signal(self)
+- ] = Signal(self)
++ self._on_request_chunk_sent: _TracingSignal[TraceRequestChunkSentParams] = Signal(self)
++ self._on_response_chunk_received: _TracingSignal[TraceResponseChunkReceivedParams] = Signal(self)
++ self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal(self)
++ self._on_request_exception: _TracingSignal[TraceRequestExceptionParams] = Signal(self)
++ self._on_request_redirect: _TracingSignal[TraceRequestRedirectParams] = Signal(self)
++ self._on_connection_queued_start: _TracingSignal[TraceConnectionQueuedStartParams] = Signal(self)
++ self._on_connection_queued_end: _TracingSignal[TraceConnectionQueuedEndParams] = Signal(self)
++ self._on_connection_create_start: _TracingSignal[TraceConnectionCreateStartParams] = Signal(self)
++ self._on_connection_create_end: _TracingSignal[TraceConnectionCreateEndParams] = Signal(self)
++ self._on_connection_reuseconn: _TracingSignal[TraceConnectionReuseconnParams] = Signal(self)
++ self._on_dns_resolvehost_start: _TracingSignal[TraceDnsResolveHostStartParams] = Signal(self)
++ self._on_dns_resolvehost_end: _TracingSignal[TraceDnsResolveHostEndParams] = Signal(self)
++ self._on_dns_cache_hit: _TracingSignal[TraceDnsCacheHitParams] = (Signal(self))
++ self._on_dns_cache_miss: _TracingSignal[TraceDnsCacheMissParams] = (Signal(self))
++ self._on_request_headers_sent: _TracingSignal[TraceRequestHeadersSentParams] = Signal(self)
self._trace_config_ctx_factory = trace_config_ctx_factory
-@@ -126,91 +118,89 @@ class TraceConfig:
+
+@@ -125,91 +88,91 @@ class TraceConfig:
self._on_request_headers_sent.freeze()
@property
@@ -111,10 +96,10 @@
return self._on_request_start
@property
-- def on_request_chunk_sent(
-- self,
+ def on_request_chunk_sent(
+ self,
- ) -> "Signal[_SignalCallback[TraceRequestChunkSentParams]]":
-+ def on_request_chunk_sent(self) -> "_TracingSignal[TraceRequestChunkSentParams]":
++ ) -> "_TracingSignal[TraceRequestChunkSentParams]":
return self._on_request_chunk_sent
@property
diff --git a/contrib/python/aiohttp/patches/05-disable-retries-on-oidemponent-methods.patch b/contrib/python/aiohttp/patches/05-disable-retries-on-oidemponent-methods.patch
new file mode 100644
index 00000000000..8bd072eeabe
--- /dev/null
+++ b/contrib/python/aiohttp/patches/05-disable-retries-on-oidemponent-methods.patch
@@ -0,0 +1,11 @@
+--- contrib/python/aiohttp/aiohttp/client.py (index)
++++ contrib/python/aiohttp/aiohttp/client.py (working tree)
+@@ -574,7 +574,7 @@ class ClientSession:
+ try:
+ with timer:
+ # https://www.rfc-editor.org/rfc/rfc9112.html#name-retrying-requests
+- retry_persistent_connection = method in IDEMPOTENT_METHODS
++ retry_persistent_connection = False #method in IDEMPOTENT_METHODS
+ while True:
+ url, auth_from_url = strip_auth_from_url(url)
+ if not url.raw_host:
diff --git a/contrib/python/aiohttp/patches/06-mypy-silence.patch b/contrib/python/aiohttp/patches/06-mypy-silence.patch
new file mode 100644
index 00000000000..d18dca98aa4
--- /dev/null
+++ b/contrib/python/aiohttp/patches/06-mypy-silence.patch
@@ -0,0 +1,61 @@
+--- contrib/python/aiohttp/aiohttp/client.py
++++ contrib/python/aiohttp/aiohttp/client.py
+@@ -174,11 +174,11 @@ class _RequestOptions(TypedDict, total=False):
+ read_until_eof: bool
+ proxy: Union[StrOrURL, None]
+ proxy_auth: Union[BasicAuth, None]
+- timeout: "Union[ClientTimeout, _SENTINEL, None]"
++ timeout: "Union[ClientTimeout, _SENTINEL, int, float, None]"
+ ssl: Union[SSLContext, bool, Fingerprint]
+ server_hostname: Union[str, None]
+ proxy_headers: Union[LooseHeaders, None]
+- trace_request_ctx: Union[Mapping[str, str], None]
++ trace_request_ctx: Any #Union[Mapping[str, str], None]
+ read_bufsize: Union[int, None]
+ auto_decompress: Union[bool, None]
+ max_line_size: Union[int, None]
+--- contrib/python/aiohttp/aiohttp/typedefs.py
++++ contrib/python/aiohttp/aiohttp/typedefs.py
+@@ -69,12 +69,7 @@ LooseCookies = Union[
+ ]
+
+ Handler = Callable[["Request"], Awaitable["StreamResponse"]]
+-
+-
+-class Middleware(Protocol):
+- def __call__(
+- self, request: "Request", handler: Handler
+- ) -> Awaitable["StreamResponse"]: ...
++Middleware = Callable[["Request", Handler], Awaitable["StreamResponse"]]
+
+
+ PathLike = Union[str, "os.PathLike[str]"]
+--- contrib/python/aiohttp/aiohttp/multipart.py
++++ contrib/python/aiohttp/aiohttp/multipart.py
+@@ -287,7 +287,7 @@ class BodyPartReader:
+ self._content_eof = 0
+ self._cache: Dict[str, Any] = {}
+
+- def __aiter__(self: Self) -> Self:
++ def __aiter__(self: Self):
+ return self
+
+ async def __anext__(self) -> bytes:
+@@ -593,7 +593,7 @@ class MultipartReader:
+ response_wrapper_cls = MultipartResponseWrapper
+ #: Multipart reader class, used to handle multipart/* body parts.
+ #: None points to type(self)
+- multipart_reader_cls: Optional[Type["MultipartReader"]] = None
++ multipart_reader_cls = None
+ #: Body part reader class for non multipart/* content types.
+ part_reader_cls = BodyPartReader
+
+@@ -614,7 +614,7 @@ class MultipartReader:
+ self._at_bof = True
+ self._unread: List[bytes] = []
+
+- def __aiter__(self: Self) -> Self:
++ def __aiter__(self: Self):
+ return self
+
+ async def __anext__(
diff --git a/contrib/python/aiohttp/patches/07-dont-throw-at-newline.patch b/contrib/python/aiohttp/patches/07-dont-throw-at-newline.patch
new file mode 100644
index 00000000000..815dd062c17
--- /dev/null
+++ b/contrib/python/aiohttp/patches/07-dont-throw-at-newline.patch
@@ -0,0 +1,12 @@
+# This patch is revert commit dd5bb073107caa1c764158b87fb8482124aad6c1
+--- contrib/python/aiohttp/aiohttp/web_response.py (index)
++++ contrib/python/aiohttp/aiohttp/web_response.py (working tree)
+@@ -147,8 +147,6 @@ class StreamResponse(BaseClass, HeadersMixin):
+ self._status = int(status)
+ if reason is None:
+ reason = REASON_PHRASES.get(self._status, "")
+- elif "\n" in reason:
+- raise ValueError("Reason cannot contain \\n")
+ self._reason = reason
+
+ @property
diff --git a/contrib/python/aiohttp/patches/99-rep-get-running-loop.sh b/contrib/python/aiohttp/patches/99-rep-get-running-loop.sh
new file mode 100644
index 00000000000..b97011e3c67
--- /dev/null
+++ b/contrib/python/aiohttp/patches/99-rep-get-running-loop.sh
@@ -0,0 +1,4 @@
+# This patch may be dropped after python 3.13 upver
+
+find aiohttp -type f -exec sed --in-place 's|loop or asyncio.get_running_loop|loop or asyncio.get_event_loop|g' '{}' ';'
+
diff --git a/contrib/python/aiohttp/ya.make b/contrib/python/aiohttp/ya.make
index 40b3b6faabe..e714d0fd423 100644
--- a/contrib/python/aiohttp/ya.make
+++ b/contrib/python/aiohttp/ya.make
@@ -2,11 +2,12 @@
PY3_LIBRARY()
-VERSION(3.9.5)
+VERSION(3.10.6)
LICENSE(Apache-2.0)
PEERDIR(
+ contrib/python/aiohappyeyeballs
contrib/python/aiosignal
contrib/python/attrs
contrib/python/frozenlist