aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/Twisted/py3/twisted/internet/endpoints.py
diff options
context:
space:
mode:
authorrobot-piglet <robot-piglet@yandex-team.com>2024-12-17 12:07:28 +0300
committerrobot-piglet <robot-piglet@yandex-team.com>2024-12-17 12:18:43 +0300
commit48bd5f88777f4dc94fd41a7dd22808ed639b985d (patch)
tree6a899d7cc8bd632073408198260a93d76f99ef32 /contrib/python/Twisted/py3/twisted/internet/endpoints.py
parent3e05dc5f5c47aa8d220db7b5508cfbd4a0d8919f (diff)
downloadydb-48bd5f88777f4dc94fd41a7dd22808ed639b985d.tar.gz
Intermediate changes
commit_hash:3786c4fc65af12274eea45a3ef9de6050e262ac0
Diffstat (limited to 'contrib/python/Twisted/py3/twisted/internet/endpoints.py')
-rw-r--r--contrib/python/Twisted/py3/twisted/internet/endpoints.py255
1 files changed, 146 insertions, 109 deletions
diff --git a/contrib/python/Twisted/py3/twisted/internet/endpoints.py b/contrib/python/Twisted/py3/twisted/internet/endpoints.py
index a98fd2ba43..dfa0cc43ce 100644
--- a/contrib/python/Twisted/py3/twisted/internet/endpoints.py
+++ b/contrib/python/Twisted/py3/twisted/internet/endpoints.py
@@ -12,21 +12,22 @@ parsed by the L{clientFromString} and L{serverFromString} functions.
@since: 10.1
"""
+from __future__ import annotations
import os
import re
import socket
import warnings
-from typing import Optional, Sequence, Type
+from typing import Any, Iterable, Optional, Sequence, Type
from unicodedata import normalize
-from zope.interface import directlyProvides, implementer, provider
+from zope.interface import directlyProvides, implementer
from constantly import NamedConstant, Names
from incremental import Version
from twisted.internet import defer, error, fdesc, interfaces, threads
-from twisted.internet.abstract import isIPAddress, isIPv6Address
+from twisted.internet.abstract import isIPv6Address
from twisted.internet.address import (
HostnameAddress,
IPv4Address,
@@ -37,9 +38,13 @@ from twisted.internet.interfaces import (
IAddress,
IHostnameResolver,
IHostResolution,
+ IOpenSSLClientConnectionCreator,
+ IProtocol,
+ IProtocolFactory,
IReactorPluggableNameResolver,
IReactorSocket,
IResolutionReceiver,
+ IStreamClientEndpoint,
IStreamClientEndpointStringParserWithReactor,
IStreamServerEndpointStringParser,
)
@@ -201,14 +206,16 @@ class _WrappingFactory(ClientFactory):
# Type is wrong. See https://twistedmatrix.com/trac/ticket/10005#ticket
protocol = _WrappingProtocol # type: ignore[assignment]
- def __init__(self, wrappedFactory):
+ def __init__(self, wrappedFactory: IProtocolFactory) -> None:
"""
@param wrappedFactory: A provider of I{IProtocolFactory} whose
buildProtocol method will be called and whose resulting protocol
will be wrapped.
"""
self._wrappedFactory = wrappedFactory
- self._onConnection = defer.Deferred(canceller=self._canceller)
+ self._onConnection: defer.Deferred[IProtocol] = defer.Deferred(
+ canceller=self._canceller
+ )
def startedConnecting(self, connector):
"""
@@ -567,7 +574,14 @@ class TCP4ClientEndpoint:
TCP client endpoint with an IPv4 configuration.
"""
- def __init__(self, reactor, host, port, timeout=30, bindAddress=None):
+ def __init__(
+ self,
+ reactor: Any,
+ host: str,
+ port: int,
+ timeout: float = 30,
+ bindAddress: str | tuple[bytes | str, int] | None = None,
+ ) -> None:
"""
@param reactor: An L{IReactorTCP} provider
@@ -591,7 +605,7 @@ class TCP4ClientEndpoint:
self._timeout = timeout
self._bindAddress = bindAddress
- def connect(self, protocolFactory):
+ def connect(self, protocolFactory: IProtocolFactory) -> Deferred[IProtocol]:
"""
Implement L{IStreamClientEndpoint.connect} to connect via TCP.
"""
@@ -665,7 +679,9 @@ class TCP6ClientEndpoint:
"""
return self._deferToThread(self._getaddrinfo, host, 0, socket.AF_INET6)
- def _resolvedHostConnect(self, resolvedHost, protocolFactory):
+ def _resolvedHostConnect(
+ self, resolvedHost: str, protocolFactory: IProtocolFactory
+ ) -> Deferred[IProtocol]:
"""
Connect to the server using the resolved hostname.
"""
@@ -774,10 +790,6 @@ class HostnameEndpoint:
associated with this endpoint.
@type _hostBytes: L{bytes}
- @ivar _hostStr: the native-string representation of the hostname passed to
- the constructor, used for exception construction
- @type _hostStr: native L{str}
-
@ivar _badHostname: a flag - hopefully false! - indicating that an invalid
hostname was passed to the constructor. This might be a textual
hostname that isn't valid IDNA, or non-ASCII bytes.
@@ -789,8 +801,14 @@ class HostnameEndpoint:
_DEFAULT_ATTEMPT_DELAY = 0.3
def __init__(
- self, reactor, host, port, timeout=30, bindAddress=None, attemptDelay=None
- ):
+ self,
+ reactor: Any,
+ host: str | bytes,
+ port: int,
+ timeout: float = 30,
+ bindAddress: bytes | str | tuple[bytes | str, int] | None = None,
+ attemptDelay: float | None = None,
+ ) -> None:
"""
Create a L{HostnameEndpoint}.
@@ -799,7 +817,7 @@ class HostnameEndpoint:
L{IReactorPluggableNameResolver} or L{IReactorPluggableResolver}.
@param host: A hostname to connect to.
- @type host: L{bytes} or L{unicode}
+ @type host: L{bytes} or L{str}
@param port: The port number to connect to.
@type port: L{int}
@@ -833,9 +851,9 @@ class HostnameEndpoint:
[self._badHostname, self._hostBytes, self._hostText] = self._hostAsBytesAndText(
host
)
- self._hostStr = self._hostBytes if bytes is str else self._hostText
self._port = port
self._timeout = timeout
+
if bindAddress is not None:
if isinstance(bindAddress, (bytes, str)):
bindAddress = (bindAddress, 0)
@@ -852,21 +870,25 @@ class HostnameEndpoint:
@return: A L{str}
"""
- if self._badHostname:
- # Use the backslash-encoded version of the string passed to the
- # constructor, which is already a native string.
- host = self._hostStr
- elif isIPv6Address(self._hostStr):
- host = f"[{self._hostStr}]"
- else:
- # Convert the bytes representation to a native string to ensure
- # that we display the punycoded version of the hostname, which is
- # more useful than any IDN version as it can be easily copy-pasted
- # into debugging tools.
- host = nativeString(self._hostBytes)
- return "".join(["<HostnameEndpoint ", host, ":", str(self._port), ">"])
+ host = (
+ # It the hostname is bad, use the backslash-encoded version of the
+ # string passed to the constructor, which is already a string.
+ self._hostText
+ if self._badHostname
+ else (
+ # Add some square brackets if it's an IPv6 address.
+ f"[{self._hostText}]"
+ if isIPv6Address(self._hostText)
+ # Convert the bytes representation to a native string to ensure
+ # that we display the punycoded version of the hostname, which is
+ # more useful than any IDN version as it can be easily copy-pasted
+ # into debugging tools.
+ else self._hostBytes.decode("ascii")
+ )
+ )
+ return f"<HostnameEndpoint {host}:{self._port}>"
- def _getNameResolverAndMaybeWarn(self, reactor):
+ def _getNameResolverAndMaybeWarn(self, reactor: object) -> IHostnameResolver:
"""
Retrieve a C{nameResolver} callable and warn the caller's
caller that using a reactor which doesn't provide
@@ -894,7 +916,7 @@ class HostnameEndpoint:
return reactor.nameResolver
@staticmethod
- def _hostAsBytesAndText(host):
+ def _hostAsBytesAndText(host: bytes | str) -> tuple[bool, bytes, str]:
"""
For various reasons (documented in the C{@ivar}'s in the class
docstring) we need both a textual and a binary representation of the
@@ -906,39 +928,36 @@ class HostnameEndpoint:
this up in the future and just operate in terms of text internally.
@param host: A hostname to convert.
- @type host: L{bytes} or C{str}
@return: a 3-tuple of C{(invalid, bytes, text)} where C{invalid} is a
boolean indicating the validity of the hostname, C{bytes} is a
binary representation of C{host}, and C{text} is a textual
representation of C{host}.
"""
+ invalid = False
if isinstance(host, bytes):
- if isIPAddress(host) or isIPv6Address(host):
- return False, host, host.decode("ascii")
- else:
- try:
- return False, host, _idnaText(host)
- except UnicodeError:
- # Convert the host to _some_ kind of text, to handle below.
- host = host.decode("charmap")
+ hostBytes = host
+ try:
+ hostText = _idnaText(hostBytes)
+ except UnicodeError:
+ hostText = hostBytes.decode("charmap")
+ if not isIPv6Address(hostText):
+ invalid = True
else:
- host = normalize("NFC", host)
- if isIPAddress(host) or isIPv6Address(host):
- return False, host.encode("ascii"), host
+ hostText = normalize("NFC", host)
+ if isIPv6Address(hostText):
+ hostBytes = hostText.encode("ascii")
else:
try:
- return False, _idnaBytes(host), host
+ hostBytes = _idnaBytes(hostText)
except UnicodeError:
- pass
- # `host` has been converted to text by this point either way; it's
- # invalid as a hostname, and so may contain unprintable characters and
- # such. escape it with backslashes so the user can get _some_ guess as
- # to what went wrong.
- asciibytes = host.encode("ascii", "backslashreplace")
- return True, asciibytes, asciibytes.decode("ascii")
+ invalid = True
+ if invalid:
+ hostBytes = hostText.encode("ascii", "backslashreplace")
+ hostText = hostBytes.decode("ascii")
+ return invalid, hostBytes, hostText
- def connect(self, protocolFactory):
+ def connect(self, protocolFactory: IProtocolFactory) -> Deferred[IProtocol]:
"""
Attempts a connection to each resolved address, and returns a
connection which is established first.
@@ -952,37 +971,38 @@ class HostnameEndpoint:
or fails a connection-related error.
"""
if self._badHostname:
- return defer.fail(ValueError(f"invalid hostname: {self._hostStr}"))
+ return defer.fail(ValueError(f"invalid hostname: {self._hostText}"))
- d = Deferred()
- addresses = []
+ resolved: Deferred[list[IAddress]] = Deferred()
+ addresses: list[IAddress] = []
- @provider(IResolutionReceiver)
+ @implementer(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
- def resolutionBegan(resolutionInProgress):
+ def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
@staticmethod
- def addressResolved(address):
+ def addressResolved(address: IAddress) -> None:
addresses.append(address)
@staticmethod
- def resolutionComplete():
- d.callback(addresses)
+ def resolutionComplete() -> None:
+ resolved.callback(addresses)
self._nameResolver.resolveHostName(
- EndpointReceiver, self._hostText, portNumber=self._port
+ EndpointReceiver(), self._hostText, portNumber=self._port
)
- d.addErrback(
+ resolved.addErrback(
lambda ignored: defer.fail(
- error.DNSLookupError(f"Couldn't find the hostname '{self._hostStr}'")
+ error.DNSLookupError(f"Couldn't find the hostname '{self._hostText}'")
)
)
- @d.addCallback
- def resolvedAddressesToEndpoints(addresses):
+ def resolvedAddressesToEndpoints(
+ addresses: Iterable[IAddress],
+ ) -> Iterable[TCP6ClientEndpoint | TCP4ClientEndpoint]:
# Yield an endpoint for every address resolved from the name.
for eachAddress in addresses:
if isinstance(eachAddress, IPv6Address):
@@ -1002,22 +1022,24 @@ class HostnameEndpoint:
self._bindAddress,
)
- d.addCallback(list)
+ iterd = resolved.addCallback(resolvedAddressesToEndpoints)
+ listd = iterd.addCallback(list)
- def _canceller(d):
+ def _canceller(cancelled: Deferred[IProtocol]) -> None:
# This canceller must remain defined outside of
# `startConnectionAttempts`, because Deferred should not
# participate in cycles with their cancellers; that would create a
# potentially problematic circular reference and possibly
# gc.garbage.
- d.errback(
+ cancelled.errback(
error.ConnectingCancelledError(
HostnameAddress(self._hostBytes, self._port)
)
)
- @d.addCallback
- def startConnectionAttempts(endpoints):
+ def startConnectionAttempts(
+ endpoints: list[TCP6ClientEndpoint | TCP4ClientEndpoint],
+ ) -> Deferred[IProtocol]:
"""
Given a sequence of endpoints obtained via name resolution, start
connecting to a new one every C{self._attemptDelay} seconds until
@@ -1037,62 +1059,68 @@ class HostnameEndpoint:
"""
if not endpoints:
raise error.DNSLookupError(
- f"no results for hostname lookup: {self._hostStr}"
+ f"no results for hostname lookup: {self._hostText}"
)
iterEndpoints = iter(endpoints)
- pending = []
- failures = []
- winner = defer.Deferred(canceller=_canceller)
+ pending: list[defer.Deferred[IProtocol]] = []
+ failures: list[Failure] = []
+ winner: defer.Deferred[IProtocol] = defer.Deferred(canceller=_canceller)
- def checkDone():
- if pending or checkDone.completed or checkDone.endpointsLeft:
+ checkDoneCompleted = False
+ checkDoneEndpointsLeft = True
+
+ def checkDone() -> None:
+ if pending or checkDoneCompleted or checkDoneEndpointsLeft:
return
winner.errback(failures.pop())
- checkDone.completed = False
- checkDone.endpointsLeft = True
-
@LoopingCall
- def iterateEndpoint():
+ def iterateEndpoint() -> None:
+ nonlocal checkDoneEndpointsLeft
endpoint = next(iterEndpoints, None)
if endpoint is None:
# The list of endpoints ends.
- checkDone.endpointsLeft = False
+ checkDoneEndpointsLeft = False
checkDone()
return
eachAttempt = endpoint.connect(protocolFactory)
pending.append(eachAttempt)
- @eachAttempt.addBoth
- def noLongerPending(result):
+ def noLongerPending(result: IProtocol | Failure) -> IProtocol | Failure:
pending.remove(eachAttempt)
return result
- @eachAttempt.addCallback
- def succeeded(result):
+ successState = eachAttempt.addBoth(noLongerPending)
+
+ def succeeded(result: IProtocol) -> None:
winner.callback(result)
- @eachAttempt.addErrback
+ successState.addCallback(succeeded)
+
def failed(reason):
failures.append(reason)
checkDone()
+ successState.addErrback(failed)
+
iterateEndpoint.clock = self._reactor
iterateEndpoint.start(self._attemptDelay)
- @winner.addBoth
- def cancelRemainingPending(result):
- checkDone.completed = True
+ def cancelRemainingPending(
+ result: IProtocol | Failure,
+ ) -> IProtocol | Failure:
+ nonlocal checkDoneCompleted
+ checkDoneCompleted = True
for remaining in pending[:]:
remaining.cancel()
if iterateEndpoint.running:
iterateEndpoint.stop()
return result
- return winner
+ return winner.addBoth(cancelRemainingPending)
- return d
+ return listd.addCallback(startConnectionAttempts)
def _fallbackNameResolution(self, host, port):
"""
@@ -2218,7 +2246,10 @@ class _WrapperServerEndpoint:
return self._wrappedEndpoint.listen(self._wrapperFactory(protocolFactory))
-def wrapClientTLS(connectionCreator, wrappedEndpoint):
+def wrapClientTLS(
+ connectionCreator: IOpenSSLClientConnectionCreator,
+ wrappedEndpoint: IStreamClientEndpoint,
+) -> _WrapperEndpoint:
"""
Wrap an endpoint which upgrades to TLS as soon as the connection is
established.
@@ -2250,17 +2281,17 @@ def wrapClientTLS(connectionCreator, wrappedEndpoint):
def _parseClientTLS(
- reactor,
- host,
- port,
- timeout=b"30",
- bindAddress=None,
- certificate=None,
- privateKey=None,
- trustRoots=None,
- endpoint=None,
- **kwargs,
-):
+ reactor: Any,
+ host: bytes | str,
+ port: bytes | str,
+ timeout: bytes | str = b"30",
+ bindAddress: bytes | str | None = None,
+ certificate: bytes | str | None = None,
+ privateKey: bytes | str | None = None,
+ trustRoots: bytes | str | None = None,
+ endpoint: bytes | str | None = None,
+ **kwargs: object,
+) -> IStreamClientEndpoint:
"""
Internal method to construct an endpoint from string parameters.
@@ -2303,18 +2334,24 @@ def _parseClientTLS(
if isinstance(bindAddress, str) or bindAddress is None
else bindAddress.decode("utf-8")
)
- port = int(port)
- timeout = int(timeout)
+ portint = int(port)
+ timeoutint = int(timeout)
return wrapClientTLS(
optionsForClientTLS(
host,
trustRoot=_parseTrustRootPath(trustRoots),
clientCertificate=_privateCertFromPaths(certificate, privateKey),
),
- clientFromString(reactor, endpoint)
- if endpoint is not None
- else HostnameEndpoint(
- reactor, _idnaBytes(host), port, timeout, (bindAddress, 0)
+ (
+ clientFromString(reactor, endpoint)
+ if endpoint is not None
+ else HostnameEndpoint(
+ reactor,
+ _idnaBytes(host),
+ portint,
+ timeoutint,
+ None if bindAddress is None else (bindAddress, 0),
+ )
),
)