diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2024-12-17 12:07:28 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2024-12-17 12:18:43 +0300 |
commit | 48bd5f88777f4dc94fd41a7dd22808ed639b985d (patch) | |
tree | 6a899d7cc8bd632073408198260a93d76f99ef32 /contrib/python/Twisted/py3/twisted/internet/endpoints.py | |
parent | 3e05dc5f5c47aa8d220db7b5508cfbd4a0d8919f (diff) | |
download | ydb-48bd5f88777f4dc94fd41a7dd22808ed639b985d.tar.gz |
Intermediate changes
commit_hash:3786c4fc65af12274eea45a3ef9de6050e262ac0
Diffstat (limited to 'contrib/python/Twisted/py3/twisted/internet/endpoints.py')
-rw-r--r-- | contrib/python/Twisted/py3/twisted/internet/endpoints.py | 255 |
1 files changed, 146 insertions, 109 deletions
diff --git a/contrib/python/Twisted/py3/twisted/internet/endpoints.py b/contrib/python/Twisted/py3/twisted/internet/endpoints.py index a98fd2ba43..dfa0cc43ce 100644 --- a/contrib/python/Twisted/py3/twisted/internet/endpoints.py +++ b/contrib/python/Twisted/py3/twisted/internet/endpoints.py @@ -12,21 +12,22 @@ parsed by the L{clientFromString} and L{serverFromString} functions. @since: 10.1 """ +from __future__ import annotations import os import re import socket import warnings -from typing import Optional, Sequence, Type +from typing import Any, Iterable, Optional, Sequence, Type from unicodedata import normalize -from zope.interface import directlyProvides, implementer, provider +from zope.interface import directlyProvides, implementer from constantly import NamedConstant, Names from incremental import Version from twisted.internet import defer, error, fdesc, interfaces, threads -from twisted.internet.abstract import isIPAddress, isIPv6Address +from twisted.internet.abstract import isIPv6Address from twisted.internet.address import ( HostnameAddress, IPv4Address, @@ -37,9 +38,13 @@ from twisted.internet.interfaces import ( IAddress, IHostnameResolver, IHostResolution, + IOpenSSLClientConnectionCreator, + IProtocol, + IProtocolFactory, IReactorPluggableNameResolver, IReactorSocket, IResolutionReceiver, + IStreamClientEndpoint, IStreamClientEndpointStringParserWithReactor, IStreamServerEndpointStringParser, ) @@ -201,14 +206,16 @@ class _WrappingFactory(ClientFactory): # Type is wrong. See https://twistedmatrix.com/trac/ticket/10005#ticket protocol = _WrappingProtocol # type: ignore[assignment] - def __init__(self, wrappedFactory): + def __init__(self, wrappedFactory: IProtocolFactory) -> None: """ @param wrappedFactory: A provider of I{IProtocolFactory} whose buildProtocol method will be called and whose resulting protocol will be wrapped. """ self._wrappedFactory = wrappedFactory - self._onConnection = defer.Deferred(canceller=self._canceller) + self._onConnection: defer.Deferred[IProtocol] = defer.Deferred( + canceller=self._canceller + ) def startedConnecting(self, connector): """ @@ -567,7 +574,14 @@ class TCP4ClientEndpoint: TCP client endpoint with an IPv4 configuration. """ - def __init__(self, reactor, host, port, timeout=30, bindAddress=None): + def __init__( + self, + reactor: Any, + host: str, + port: int, + timeout: float = 30, + bindAddress: str | tuple[bytes | str, int] | None = None, + ) -> None: """ @param reactor: An L{IReactorTCP} provider @@ -591,7 +605,7 @@ class TCP4ClientEndpoint: self._timeout = timeout self._bindAddress = bindAddress - def connect(self, protocolFactory): + def connect(self, protocolFactory: IProtocolFactory) -> Deferred[IProtocol]: """ Implement L{IStreamClientEndpoint.connect} to connect via TCP. """ @@ -665,7 +679,9 @@ class TCP6ClientEndpoint: """ return self._deferToThread(self._getaddrinfo, host, 0, socket.AF_INET6) - def _resolvedHostConnect(self, resolvedHost, protocolFactory): + def _resolvedHostConnect( + self, resolvedHost: str, protocolFactory: IProtocolFactory + ) -> Deferred[IProtocol]: """ Connect to the server using the resolved hostname. """ @@ -774,10 +790,6 @@ class HostnameEndpoint: associated with this endpoint. @type _hostBytes: L{bytes} - @ivar _hostStr: the native-string representation of the hostname passed to - the constructor, used for exception construction - @type _hostStr: native L{str} - @ivar _badHostname: a flag - hopefully false! - indicating that an invalid hostname was passed to the constructor. This might be a textual hostname that isn't valid IDNA, or non-ASCII bytes. @@ -789,8 +801,14 @@ class HostnameEndpoint: _DEFAULT_ATTEMPT_DELAY = 0.3 def __init__( - self, reactor, host, port, timeout=30, bindAddress=None, attemptDelay=None - ): + self, + reactor: Any, + host: str | bytes, + port: int, + timeout: float = 30, + bindAddress: bytes | str | tuple[bytes | str, int] | None = None, + attemptDelay: float | None = None, + ) -> None: """ Create a L{HostnameEndpoint}. @@ -799,7 +817,7 @@ class HostnameEndpoint: L{IReactorPluggableNameResolver} or L{IReactorPluggableResolver}. @param host: A hostname to connect to. - @type host: L{bytes} or L{unicode} + @type host: L{bytes} or L{str} @param port: The port number to connect to. @type port: L{int} @@ -833,9 +851,9 @@ class HostnameEndpoint: [self._badHostname, self._hostBytes, self._hostText] = self._hostAsBytesAndText( host ) - self._hostStr = self._hostBytes if bytes is str else self._hostText self._port = port self._timeout = timeout + if bindAddress is not None: if isinstance(bindAddress, (bytes, str)): bindAddress = (bindAddress, 0) @@ -852,21 +870,25 @@ class HostnameEndpoint: @return: A L{str} """ - if self._badHostname: - # Use the backslash-encoded version of the string passed to the - # constructor, which is already a native string. - host = self._hostStr - elif isIPv6Address(self._hostStr): - host = f"[{self._hostStr}]" - else: - # Convert the bytes representation to a native string to ensure - # that we display the punycoded version of the hostname, which is - # more useful than any IDN version as it can be easily copy-pasted - # into debugging tools. - host = nativeString(self._hostBytes) - return "".join(["<HostnameEndpoint ", host, ":", str(self._port), ">"]) + host = ( + # It the hostname is bad, use the backslash-encoded version of the + # string passed to the constructor, which is already a string. + self._hostText + if self._badHostname + else ( + # Add some square brackets if it's an IPv6 address. + f"[{self._hostText}]" + if isIPv6Address(self._hostText) + # Convert the bytes representation to a native string to ensure + # that we display the punycoded version of the hostname, which is + # more useful than any IDN version as it can be easily copy-pasted + # into debugging tools. + else self._hostBytes.decode("ascii") + ) + ) + return f"<HostnameEndpoint {host}:{self._port}>" - def _getNameResolverAndMaybeWarn(self, reactor): + def _getNameResolverAndMaybeWarn(self, reactor: object) -> IHostnameResolver: """ Retrieve a C{nameResolver} callable and warn the caller's caller that using a reactor which doesn't provide @@ -894,7 +916,7 @@ class HostnameEndpoint: return reactor.nameResolver @staticmethod - def _hostAsBytesAndText(host): + def _hostAsBytesAndText(host: bytes | str) -> tuple[bool, bytes, str]: """ For various reasons (documented in the C{@ivar}'s in the class docstring) we need both a textual and a binary representation of the @@ -906,39 +928,36 @@ class HostnameEndpoint: this up in the future and just operate in terms of text internally. @param host: A hostname to convert. - @type host: L{bytes} or C{str} @return: a 3-tuple of C{(invalid, bytes, text)} where C{invalid} is a boolean indicating the validity of the hostname, C{bytes} is a binary representation of C{host}, and C{text} is a textual representation of C{host}. """ + invalid = False if isinstance(host, bytes): - if isIPAddress(host) or isIPv6Address(host): - return False, host, host.decode("ascii") - else: - try: - return False, host, _idnaText(host) - except UnicodeError: - # Convert the host to _some_ kind of text, to handle below. - host = host.decode("charmap") + hostBytes = host + try: + hostText = _idnaText(hostBytes) + except UnicodeError: + hostText = hostBytes.decode("charmap") + if not isIPv6Address(hostText): + invalid = True else: - host = normalize("NFC", host) - if isIPAddress(host) or isIPv6Address(host): - return False, host.encode("ascii"), host + hostText = normalize("NFC", host) + if isIPv6Address(hostText): + hostBytes = hostText.encode("ascii") else: try: - return False, _idnaBytes(host), host + hostBytes = _idnaBytes(hostText) except UnicodeError: - pass - # `host` has been converted to text by this point either way; it's - # invalid as a hostname, and so may contain unprintable characters and - # such. escape it with backslashes so the user can get _some_ guess as - # to what went wrong. - asciibytes = host.encode("ascii", "backslashreplace") - return True, asciibytes, asciibytes.decode("ascii") + invalid = True + if invalid: + hostBytes = hostText.encode("ascii", "backslashreplace") + hostText = hostBytes.decode("ascii") + return invalid, hostBytes, hostText - def connect(self, protocolFactory): + def connect(self, protocolFactory: IProtocolFactory) -> Deferred[IProtocol]: """ Attempts a connection to each resolved address, and returns a connection which is established first. @@ -952,37 +971,38 @@ class HostnameEndpoint: or fails a connection-related error. """ if self._badHostname: - return defer.fail(ValueError(f"invalid hostname: {self._hostStr}")) + return defer.fail(ValueError(f"invalid hostname: {self._hostText}")) - d = Deferred() - addresses = [] + resolved: Deferred[list[IAddress]] = Deferred() + addresses: list[IAddress] = [] - @provider(IResolutionReceiver) + @implementer(IResolutionReceiver) class EndpointReceiver: @staticmethod - def resolutionBegan(resolutionInProgress): + def resolutionBegan(resolutionInProgress: IHostResolution) -> None: pass @staticmethod - def addressResolved(address): + def addressResolved(address: IAddress) -> None: addresses.append(address) @staticmethod - def resolutionComplete(): - d.callback(addresses) + def resolutionComplete() -> None: + resolved.callback(addresses) self._nameResolver.resolveHostName( - EndpointReceiver, self._hostText, portNumber=self._port + EndpointReceiver(), self._hostText, portNumber=self._port ) - d.addErrback( + resolved.addErrback( lambda ignored: defer.fail( - error.DNSLookupError(f"Couldn't find the hostname '{self._hostStr}'") + error.DNSLookupError(f"Couldn't find the hostname '{self._hostText}'") ) ) - @d.addCallback - def resolvedAddressesToEndpoints(addresses): + def resolvedAddressesToEndpoints( + addresses: Iterable[IAddress], + ) -> Iterable[TCP6ClientEndpoint | TCP4ClientEndpoint]: # Yield an endpoint for every address resolved from the name. for eachAddress in addresses: if isinstance(eachAddress, IPv6Address): @@ -1002,22 +1022,24 @@ class HostnameEndpoint: self._bindAddress, ) - d.addCallback(list) + iterd = resolved.addCallback(resolvedAddressesToEndpoints) + listd = iterd.addCallback(list) - def _canceller(d): + def _canceller(cancelled: Deferred[IProtocol]) -> None: # This canceller must remain defined outside of # `startConnectionAttempts`, because Deferred should not # participate in cycles with their cancellers; that would create a # potentially problematic circular reference and possibly # gc.garbage. - d.errback( + cancelled.errback( error.ConnectingCancelledError( HostnameAddress(self._hostBytes, self._port) ) ) - @d.addCallback - def startConnectionAttempts(endpoints): + def startConnectionAttempts( + endpoints: list[TCP6ClientEndpoint | TCP4ClientEndpoint], + ) -> Deferred[IProtocol]: """ Given a sequence of endpoints obtained via name resolution, start connecting to a new one every C{self._attemptDelay} seconds until @@ -1037,62 +1059,68 @@ class HostnameEndpoint: """ if not endpoints: raise error.DNSLookupError( - f"no results for hostname lookup: {self._hostStr}" + f"no results for hostname lookup: {self._hostText}" ) iterEndpoints = iter(endpoints) - pending = [] - failures = [] - winner = defer.Deferred(canceller=_canceller) + pending: list[defer.Deferred[IProtocol]] = [] + failures: list[Failure] = [] + winner: defer.Deferred[IProtocol] = defer.Deferred(canceller=_canceller) - def checkDone(): - if pending or checkDone.completed or checkDone.endpointsLeft: + checkDoneCompleted = False + checkDoneEndpointsLeft = True + + def checkDone() -> None: + if pending or checkDoneCompleted or checkDoneEndpointsLeft: return winner.errback(failures.pop()) - checkDone.completed = False - checkDone.endpointsLeft = True - @LoopingCall - def iterateEndpoint(): + def iterateEndpoint() -> None: + nonlocal checkDoneEndpointsLeft endpoint = next(iterEndpoints, None) if endpoint is None: # The list of endpoints ends. - checkDone.endpointsLeft = False + checkDoneEndpointsLeft = False checkDone() return eachAttempt = endpoint.connect(protocolFactory) pending.append(eachAttempt) - @eachAttempt.addBoth - def noLongerPending(result): + def noLongerPending(result: IProtocol | Failure) -> IProtocol | Failure: pending.remove(eachAttempt) return result - @eachAttempt.addCallback - def succeeded(result): + successState = eachAttempt.addBoth(noLongerPending) + + def succeeded(result: IProtocol) -> None: winner.callback(result) - @eachAttempt.addErrback + successState.addCallback(succeeded) + def failed(reason): failures.append(reason) checkDone() + successState.addErrback(failed) + iterateEndpoint.clock = self._reactor iterateEndpoint.start(self._attemptDelay) - @winner.addBoth - def cancelRemainingPending(result): - checkDone.completed = True + def cancelRemainingPending( + result: IProtocol | Failure, + ) -> IProtocol | Failure: + nonlocal checkDoneCompleted + checkDoneCompleted = True for remaining in pending[:]: remaining.cancel() if iterateEndpoint.running: iterateEndpoint.stop() return result - return winner + return winner.addBoth(cancelRemainingPending) - return d + return listd.addCallback(startConnectionAttempts) def _fallbackNameResolution(self, host, port): """ @@ -2218,7 +2246,10 @@ class _WrapperServerEndpoint: return self._wrappedEndpoint.listen(self._wrapperFactory(protocolFactory)) -def wrapClientTLS(connectionCreator, wrappedEndpoint): +def wrapClientTLS( + connectionCreator: IOpenSSLClientConnectionCreator, + wrappedEndpoint: IStreamClientEndpoint, +) -> _WrapperEndpoint: """ Wrap an endpoint which upgrades to TLS as soon as the connection is established. @@ -2250,17 +2281,17 @@ def wrapClientTLS(connectionCreator, wrappedEndpoint): def _parseClientTLS( - reactor, - host, - port, - timeout=b"30", - bindAddress=None, - certificate=None, - privateKey=None, - trustRoots=None, - endpoint=None, - **kwargs, -): + reactor: Any, + host: bytes | str, + port: bytes | str, + timeout: bytes | str = b"30", + bindAddress: bytes | str | None = None, + certificate: bytes | str | None = None, + privateKey: bytes | str | None = None, + trustRoots: bytes | str | None = None, + endpoint: bytes | str | None = None, + **kwargs: object, +) -> IStreamClientEndpoint: """ Internal method to construct an endpoint from string parameters. @@ -2303,18 +2334,24 @@ def _parseClientTLS( if isinstance(bindAddress, str) or bindAddress is None else bindAddress.decode("utf-8") ) - port = int(port) - timeout = int(timeout) + portint = int(port) + timeoutint = int(timeout) return wrapClientTLS( optionsForClientTLS( host, trustRoot=_parseTrustRootPath(trustRoots), clientCertificate=_privateCertFromPaths(certificate, privateKey), ), - clientFromString(reactor, endpoint) - if endpoint is not None - else HostnameEndpoint( - reactor, _idnaBytes(host), port, timeout, (bindAddress, 0) + ( + clientFromString(reactor, endpoint) + if endpoint is not None + else HostnameEndpoint( + reactor, + _idnaBytes(host), + portint, + timeoutint, + None if bindAddress is None else (bindAddress, 0), + ) ), ) |