diff options
| author | robot-piglet <[email protected]> | 2025-06-22 18:50:56 +0300 |
|---|---|---|
| committer | robot-piglet <[email protected]> | 2025-06-22 19:04:42 +0300 |
| commit | c7cbc6d480c5488ff6e921c709680fd2c1340a10 (patch) | |
| tree | 10843f44b67c0fb5717ad555556064095f701d8c /contrib/python/Twisted/py3/twisted/conch | |
| parent | 26d391cdb94d2ce5efc8d0cc5cea7607dc363c0b (diff) | |
Intermediate changes
commit_hash:28750b74281710ec1ab5bdc2403c8ab24bdd164b
Diffstat (limited to 'contrib/python/Twisted/py3/twisted/conch')
14 files changed, 444 insertions, 335 deletions
diff --git a/contrib/python/Twisted/py3/twisted/conch/client/connect.py b/contrib/python/Twisted/py3/twisted/conch/client/connect.py index f21f16768bb..1683e7f0704 100644 --- a/contrib/python/Twisted/py3/twisted/conch/client/connect.py +++ b/contrib/python/Twisted/py3/twisted/conch/client/connect.py @@ -1,24 +1,46 @@ # Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. -# +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure + +if TYPE_CHECKING: + from twisted.conch.client.options import ConchOptions + from twisted.conch.ssh.userauth import SSHUserAuthClient + from twisted.conch.client import direct -connectTypes = {"direct": direct.connect} +connectTypes: dict[ + str, + Callable[[str, int, ConchOptions, direct._VHK, SSHUserAuthClient], Deferred[None]], +] = { + "direct": direct.connect, +} -def connect(host, port, options, verifyHostKey, userAuthObject): +def connect( + host: str, + port: int, + options: ConchOptions, + verifyHostKey: direct._VHK, + userAuthObject: SSHUserAuthClient, +) -> Deferred[None]: useConnects = ["direct"] - return _ebConnect( - None, useConnects, host, port, options, verifyHostKey, userAuthObject - ) - - -def _ebConnect(f, useConnects, host, port, options, vhk, uao): - if not useConnects: - return f - connectType = useConnects.pop(0) - f = connectTypes[connectType] - d = f(host, port, options, vhk, uao) - d.addErrback(_ebConnect, useConnects, host, port, options, vhk, uao) - return d + + def _ebConnect(interimResult: Failure | None, /) -> Deferred[None] | None | Failure: + if not useConnects: + return interimResult + connectType = useConnects.pop(0) + f = connectTypes[connectType] + d = f(host, port, options, verifyHostKey, userAuthObject) + d.addErrback(_ebConnect) + return d + + start: Deferred[None] = Deferred() + start.callback(None) + start.addCallback(_ebConnect) + return start diff --git a/contrib/python/Twisted/py3/twisted/conch/client/default.py b/contrib/python/Twisted/py3/twisted/conch/client/default.py index daf4cf33719..7038f8c0107 100644 --- a/contrib/python/Twisted/py3/twisted/conch/client/default.py +++ b/contrib/python/Twisted/py3/twisted/conch/client/default.py @@ -17,12 +17,15 @@ import io import os import sys from base64 import decodebytes +from typing import TYPE_CHECKING from twisted.conch.client import agent from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile from twisted.conch.error import ConchError from twisted.conch.ssh import common, keys, userauth +from twisted.conch.ssh.transport import SSHClientTransport from twisted.internet import defer, protocol, reactor +from twisted.internet.defer import Deferred from twisted.python.compat import nativeString from twisted.python.filepath import FilePath @@ -36,7 +39,9 @@ _open = open _input = input -def verifyHostKey(transport, host, pubKey, fingerprint): +def verifyHostKey( + transport: SSHClientTransport, host: bytes, pubKey: bytes, fingerprint: str +) -> Deferred[bool]: """ Verify a host's key. @@ -56,26 +61,29 @@ def verifyHostKey(transport, host, pubKey, fingerprint): equivalent that could be used. @param host: Due to a bug in L{SSHClientTransport.verifyHostKey}, this is - always the dotted-quad IP address of the host being connected to. - @type host: L{str} + always the dotted-quad IP address of the host being connected to. @param transport: the client transport which is attempting to connect to - the given host. - @type transport: L{SSHClientTransport} + the given host. @param fingerprint: the fingerprint of the given public key, in - xx:xx:xx:... format. This is ignored in favor of getting the fingerprint - from the key itself. - @type fingerprint: L{str} + xx:xx:xx:... format. This is ignored in favor of getting the + fingerprint from the key itself. @param pubKey: The public key of the server being connected to. - @type pubKey: L{str} - @return: a L{Deferred} which fires with C{1} if the key was successfully - verified, or fails if the key could not be successfully verified. Failure - types may include L{HostKeyChanged}, L{UserRejectedKey}, L{IOError} or - L{KeyboardInterrupt}. + @return: a L{Deferred} which fires with C{True} if the key was successfully + verified, or fails if the key could not be successfully verified. + Failure types may include L{HostKeyChanged}, L{UserRejectedKey}, + L{IOError} or L{KeyboardInterrupt}. """ + if TYPE_CHECKING: + # this is just a structured assumption that we are making about the + # transport's factory; behind a TYPE_CHECKING flag because we use some + # test fakes and don't want to nail down the type that much. + from twisted.conch.client.direct import SSHClientFactory + + assert isinstance(transport.factory, SSHClientFactory) actualHost = transport.factory.options["host"] actualKey = keys.Key.fromString(pubKey) kh = KnownHostsFile.fromPath( diff --git a/contrib/python/Twisted/py3/twisted/conch/client/direct.py b/contrib/python/Twisted/py3/twisted/conch/client/direct.py index d9f4828ec5f..33fd1d2df46 100644 --- a/contrib/python/Twisted/py3/twisted/conch/client/direct.py +++ b/contrib/python/Twisted/py3/twisted/conch/client/direct.py @@ -1,51 +1,83 @@ # Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable from twisted.conch import error from twisted.conch.ssh import transport from twisted.internet import defer, protocol, reactor +from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.defer import Deferred, maybeDeferred +from twisted.internet.interfaces import ( + IAddress, + IConnector, + IListeningPort, + IReactorTCP, +) +from twisted.python.failure import Failure + +if TYPE_CHECKING: + from twisted.conch.client.options import ConchOptions + from twisted.conch.ssh.userauth import SSHUserAuthClient class SSHClientFactory(protocol.ClientFactory): - def __init__(self, d, options, verifyHostKey, userAuthObject): - self.d = d + def __init__( + self, + d: Deferred[None], + options: ConchOptions, + verifyHostKey: _VHK, + userAuthObject: SSHUserAuthClient, + ) -> None: + self.d: Deferred[None] | None = d self.options = options self.verifyHostKey = verifyHostKey self.userAuthObject = userAuthObject - def clientConnectionLost(self, connector, reason): + def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: if self.options["reconnect"]: connector.connect() - def clientConnectionFailed(self, connector, reason): + def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: if self.d is None: return d, self.d = self.d, None d.errback(reason) - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> SSHClientTransport: trans = SSHClientTransport(self) if self.options["ciphers"]: trans.supportedCiphers = self.options["ciphers"] if self.options["macs"]: trans.supportedMACs = self.options["macs"] if self.options["compress"]: - trans.supportedCompressions[0:1] = ["zlib"] + trans.supportedCompressions[0:1] = [b"zlib"] if self.options["host-key-algorithms"]: trans.supportedPublicKeys = self.options["host-key-algorithms"] return trans class SSHClientTransport(transport.SSHClientTransport): - def __init__(self, factory): + # pre-mypy LSP violation + factory: SSHClientFactory # type:ignore[assignment] + + def __init__(self, factory: SSHClientFactory) -> None: self.factory = factory - self.unixServer = None + self.unixServer: None | IListeningPort = None - def connectionLost(self, reason): + def connectionLost(self, reason: Failure | None = None) -> None: if self.unixServer: - d = self.unixServer.stopListening() - self.unixServer = None + # The C{unixServer} attribute is untested, and it's not entirely + # clear that it does anything at all. It appears to be a vestigial + # attempt to support something like OpenSSH's ControlMaster client + # option; at some point we should either document and test it, or + # remove it. + + # https://github.com/twisted/twisted/issues/12418 + d = maybeDeferred(self.unixServer.stopListening) # pragma: no cover + self.unixServer = None # pragma: no cover else: d = defer.succeed(None) d.addCallback( @@ -75,9 +107,15 @@ class SSHClientTransport(transport.SSHClientTransport): if alwaysDisplay: # XXX what should happen here? print(message) - def verifyHostKey(self, pubKey, fingerprint): + def verifyHostKey(self, pubKey: bytes, fingerprint: str) -> Deferred[bool]: + transport = self.transport + assert transport is not None + peer = transport.getPeer() + assert isinstance( + peer, (IPv4Address, IPv6Address) + ), "Address must have a host to verify against." return self.factory.verifyHostKey( - self, self.transport.getPeer().host, pubKey, fingerprint + self, peer.host.encode("utf-8"), pubKey, fingerprint ) def setService(self, service): @@ -91,8 +129,17 @@ class SSHClientTransport(transport.SSHClientTransport): self.requestService(self.factory.userAuthObject) -def connect(host, port, options, verifyHostKey, userAuthObject): - d = defer.Deferred() +_VHK = Callable[[SSHClientTransport, bytes, bytes, str], Deferred[bool]] + + +def connect( + host: str, + port: int, + options: ConchOptions, + verifyHostKey: _VHK, + userAuthObject: SSHUserAuthClient, +) -> Deferred[None]: + d: Deferred[None] = defer.Deferred() factory = SSHClientFactory(d, options, verifyHostKey, userAuthObject) - reactor.connectTCP(host, port, factory) + IReactorTCP(reactor).connectTCP(host, port, factory) return d diff --git a/contrib/python/Twisted/py3/twisted/conch/client/knownhosts.py b/contrib/python/Twisted/py3/twisted/conch/client/knownhosts.py index 44118512bd2..1aa4b477c27 100644 --- a/contrib/python/Twisted/py3/twisted/conch/client/knownhosts.py +++ b/contrib/python/Twisted/py3/twisted/conch/client/knownhosts.py @@ -15,7 +15,7 @@ import sys from binascii import Error as DecodeError, a2b_base64, b2a_base64 from contextlib import closing from hashlib import sha1 -from typing import IO, Callable, Literal +from typing import IO, Callable, Iterable, Literal from zope.interface import implementer @@ -33,31 +33,27 @@ from twisted.python.util import FancyEqMixin log = Logger() -def _b64encode(s): +def _b64encode(s: bytes) -> bytes: """ Encode a binary string as base64 with no trailing newline. @param s: The string to encode. - @type s: L{bytes} @return: The base64-encoded string. - @rtype: L{bytes} """ return b2a_base64(s).strip() -def _extractCommon(string): +def _extractCommon(string: bytes) -> tuple[bytes, bytes, Key, bytes | None]: """ Extract common elements of base64 keys from an entry in a hosts file. @param string: A known hosts file entry (a single line). - @type string: L{bytes} @return: a 4-tuple of hostname data (L{bytes}), ssh key type (L{bytes}), key (L{Key}), and comment (L{bytes} or L{None}). The hostname data is simply the beginning of the line up to the first occurrence of whitespace. - @rtype: L{tuple} """ elements = string.split(None, 2) if len(elements) != 3: @@ -86,26 +82,25 @@ class _BaseEntry: @type publicKey: L{twisted.conch.ssh.keys.Key} @ivar comment: Trailing garbage after the key line. - @type comment: L{bytes} + @type comment: L{bytes} or C{None} """ - def __init__(self, keyType, publicKey, comment): + def __init__(self, keyType: bytes, publicKey: Key, comment: bytes | None) -> None: self.keyType = keyType self.publicKey = publicKey self.comment = comment - def matchesKey(self, keyObject): + def matchesKey(self, keyObject: Key) -> bool: """ Check to see if this entry matches a given key object. @param keyObject: A public key object to check. - @type keyObject: L{Key} @return: C{True} if this entry's key matches C{keyObject}, C{False} otherwise. - @rtype: L{bool} """ - return self.publicKey == keyObject + result = self.publicKey == keyObject + return result @implementer(IKnownHostEntry) @@ -118,7 +113,11 @@ class PlainEntry(_BaseEntry): """ def __init__( - self, hostnames: list[bytes], keyType: bytes, publicKey: Key, comment: bytes + self, + hostnames: list[bytes], + keyType: bytes, + publicKey: Key, + comment: bytes | None, ): self._hostnames: list[bytes] = hostnames super().__init__(keyType, publicKey, comment) @@ -188,26 +187,28 @@ class UnparsedEntry: parsed; therefore it matches no keys and no hosts. """ - def __init__(self, string): + keyType: None = None + + def __init__(self, string: bytes) -> None: """ Create an unparsed entry from a line in a known_hosts file which cannot otherwise be parsed. """ self._string = string - def matchesHost(self, hostname): + def matchesHost(self, hostname: bytes) -> bool: """ Always returns False. """ return False - def matchesKey(self, key): + def matchesKey(self, key: Key) -> bool: """ Always returns False. """ return False - def toString(self): + def toString(self) -> bytes: """ Returns the input line, without its newline if one was given. @@ -218,18 +219,15 @@ class UnparsedEntry: return self._string.rstrip(b"\n") -def _hmacedString(key, string): +def _hmacedString(key: bytes, string: bytes | str) -> bytes: """ Return the SHA-1 HMAC hash of the given key and string. @param key: The HMAC key. - @type key: L{bytes} @param string: The string to be hashed. - @type string: L{bytes} @return: The keyed hash value. - @rtype: L{bytes} """ hash = hmac.HMAC(key, digestmod=sha1) if isinstance(string, str): @@ -298,7 +296,7 @@ class HashedEntry(_BaseEntry, FancyEqMixin): self = cls(a2b_base64(hostSalt), a2b_base64(hostHash), keyType, key, comment) return self - def matchesHost(self, hostname): + def matchesHost(self, hostname: bytes) -> bool: """ Implement L{IKnownHostEntry.matchesHost} to compare the hash of the input to the stored hash. @@ -315,7 +313,7 @@ class HashedEntry(_BaseEntry, FancyEqMixin): _hmacedString(self._hostSalt, hostname), self._hostHash ) - def toString(self): + def toString(self) -> bytes: """ Implement L{IKnownHostEntry.toString} by base64-encoding the salt, host hash, and key. @@ -373,7 +371,7 @@ class KnownHostsFile: """ return self._savePath - def iterentries(self): + def iterentries(self) -> Iterable[IKnownHostEntry]: """ Iterate over the host entries in this file. @@ -404,25 +402,22 @@ class KnownHostsFile: entry = UnparsedEntry(line) yield entry - def hasHostKey(self, hostname, key): + def hasHostKey(self, hostname: bytes, key: Key) -> bool: """ Check for an entry with matching hostname and key. @param hostname: A hostname or IP address literal to check for. - @type hostname: L{bytes} @param key: The public key to check for. - @type key: L{Key} - @return: C{True} if the given hostname and key are present in this file, - C{False} if they are not. - @rtype: L{bool} + @return: C{True} if the given hostname and key are present in this + file, C{False} if they are not. @raise HostKeyChanged: if the host key found for the given hostname does not match the given key. """ for lineidx, entry in enumerate(self.iterentries(), -len(self._added)): - if entry.matchesHost(hostname) and entry.keyType == key.sshType(): + if entry.keyType == key.sshType() and entry.matchesHost(hostname): if entry.matchesKey(key): return True else: @@ -569,7 +564,7 @@ class ConsoleUI: console, to be used during key verification. """ - def __init__(self, opener: Callable[[], IO[bytes]]): + def __init__(self, opener: Callable[[], IO[bytes]]) -> None: """ @param opener: A no-argument callable which should open a console binary-mode file-like object to be used for reading and writing. diff --git a/contrib/python/Twisted/py3/twisted/conch/endpoints.py b/contrib/python/Twisted/py3/twisted/conch/endpoints.py index 3269532acd1..966669edec7 100644 --- a/contrib/python/Twisted/py3/twisted/conch/endpoints.py +++ b/contrib/python/Twisted/py3/twisted/conch/endpoints.py @@ -31,13 +31,18 @@ from twisted.conch.ssh.connection import SSHConnection from twisted.conch.ssh.keys import Key from twisted.conch.ssh.transport import SSHClientTransport from twisted.conch.ssh.userauth import SSHUserAuthClient +from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import CancelledError, Deferred, succeed from twisted.internet.endpoints import TCP4ClientEndpoint, connectProtocol from twisted.internet.error import ConnectionDone, ProcessTerminated -from twisted.internet.interfaces import IStreamClientEndpoint +from twisted.internet.interfaces import ( + IReactorTCP, + IStreamClientEndpoint, + ITCPTransport, +) from twisted.internet.protocol import Factory from twisted.logger import Logger -from twisted.python.compat import nativeString, networkString +from twisted.python.compat import nativeString from twisted.python.failure import Failure from twisted.python.filepath import FilePath @@ -297,8 +302,8 @@ class _UserAuth(SSHUserAuthClient): authentication, and delegating authentication to an agent. """ - password = None - keys = None + password: bytes | None = None + keys: list[Key] | None = None agent = None def getPublicKey(self): @@ -414,16 +419,23 @@ class _CommandTransport(SSHClientTransport): _hostKeyFailure = None - _userauth = None + _userauth: _UserAuth | None = None - def __init__(self, creator): + def __init__(self, creator: _NewConnectionHelper) -> None: """ @param creator: The L{_NewConnectionHelper} that created this connection. - - @type creator: L{_NewConnectionHelper}. """ - self.connectionReady = Deferred(lambda d: self.transport.abortConnection()) + + def cancelReady(d: Deferred[None]) -> None: + transport = ITCPTransport(self.transport, None) + # adaptation is papering over an annoying type-punning issue here, + # we more or less have to run over an abortable transport, so not + # testing the negative branch. + if transport is not None: # pragma: no branch + transport.abortConnection() + + self.connectionReady: Deferred[None] = Deferred(cancelReady) # Clear the reference to that deferred to help the garbage collector # and to signal to other parts of this implementation (in particular # connectionLost) that it has already been fired and does not need to @@ -436,7 +448,7 @@ class _CommandTransport(SSHClientTransport): self.connectionReady.addBoth(readyFired) self.creator = creator - def verifyHostKey(self, hostKey, fingerprint): + def verifyHostKey(self, hostKey: Key, fingerprint: bytes) -> Deferred[bool]: """ Ask the L{KnownHostsFile} provider available on the factory which created this protocol this protocol to verify the given host key. @@ -445,30 +457,29 @@ class _CommandTransport(SSHClientTransport): L{KnownHostsFile.verifyHostKey}. """ hostname = self.creator.hostname - ip = networkString(self.transport.getPeer().host) - + transport = self.transport + assert transport is not None + peer = transport.getPeer() + assert isinstance(peer, (IPv4Address, IPv6Address)) + ip = peer.host.encode("ascii") self._state = b"SECURING" - d = self.creator.knownHosts.verifyHostKey( + return self.creator.knownHosts.verifyHostKey( self.creator.ui, hostname, ip, Key.fromString(hostKey) - ) - d.addErrback(self._saveHostKeyFailure) - return d + ).addErrback(self._saveHostKeyFailure) - def _saveHostKeyFailure(self, reason): + def _saveHostKeyFailure(self, reason: Failure) -> Failure: """ When host key verification fails, record the reason for the failure in order to fire a L{Deferred} with it later. @param reason: The cause of the host key verification failure. - @type reason: L{Failure} @return: C{reason} - @rtype: L{Failure} """ self._hostKeyFailure = reason return reason - def connectionSecure(self): + def connectionSecure(self) -> None: """ When the connection is secure, start the authentication process. """ @@ -481,17 +492,18 @@ class _CommandTransport(SSHClientTransport): if self.creator.keys: self._userauth.keys = list(self.creator.keys) - if self.creator.agentEndpoint is not None: - d = self._userauth.connectToAgent(self.creator.agentEndpoint) - else: - d = succeed(None) + d = ( + succeed(None) + if self.creator.agentEndpoint is None + else self._userauth.connectToAgent(self.creator.agentEndpoint) + ) def maybeGotAgent(ignored): self.requestService(self._userauth) d.addBoth(maybeGotAgent) - def connectionLost(self, reason): + def connectionLost(self, reason: Failure | None = None) -> None: """ When the underlying connection to the SSH server is lost, if there were any connection setup errors, propagate them. Also, clean up the @@ -529,7 +541,7 @@ class SSHCommandClientEndpoint: command invocations over a single SSH connection. """ - def __init__(self, creator, command): + def __init__(self, creator: _ISSHConnectionCreator, command: bytes) -> None: """ @param creator: An L{_ISSHConnectionCreator} provider which will be used to set up the SSH connection which will be used to run a @@ -550,17 +562,17 @@ class SSHCommandClientEndpoint: @classmethod def newConnection( cls, - reactor, - command, - username, - hostname, - port=None, - keys=None, - password=None, - agentEndpoint=None, - knownHosts=None, - ui=None, - ): + reactor: IReactorTCP, + command: bytes, + username: bytes, + hostname: bytes, + port: int | None = None, + keys: list[Key] | None = None, + password: bytes | None = None, + agentEndpoint: IStreamClientEndpoint | None = None, + knownHosts: str | None = None, + ui: ConsoleUI | None = None, + ) -> SSHCommandClientEndpoint: """ Create and return a new endpoint which will try to create a new connection to an SSH server and run a command over it. It will also @@ -569,44 +581,36 @@ class SSHCommandClientEndpoint: L{Deferred} is cancelled. @param reactor: The reactor to use to establish the connection. - @type reactor: L{IReactorTCP} provider @param command: See L{__init__}'s C{command} argument. @param username: The username with which to authenticate to the SSH server. - @type username: L{bytes} @param hostname: The hostname of the SSH server. - @type hostname: L{bytes} @param port: The port number of the SSH server. By default, the standard SSH port number is used. - @type port: L{int} @param keys: Private keys with which to authenticate to the SSH server, if key authentication is to be attempted (otherwise L{None}). - @type keys: L{list} of L{Key} @param password: The password with which to authenticate to the SSH server, if password authentication is to be attempted (otherwise L{None}). - @type password: L{bytes} or L{None} @param agentEndpoint: An L{IStreamClientEndpoint} provider which may be used to connect to an SSH agent, if one is to be used to help with authentication. - @type agentEndpoint: L{IStreamClientEndpoint} provider - @param knownHosts: The currently known host keys, used to check the - host key presented by the server we actually connect to. - @type knownHosts: L{KnownHostsFile} + @param knownHosts: The path to the currently known host keys file, used + to check the host key presented by the server we actually connect + to. @param ui: An object for interacting with users to make decisions about whether to accept the server host keys. If L{None}, a L{ConsoleUI} connected to /dev/tty will be used; if /dev/tty is unavailable, an object which answers C{b"no"} to all prompts will be used. - @type ui: L{None} or L{ConsoleUI} @return: A new instance of C{cls} (probably L{SSHCommandClientEndpoint}). @@ -707,17 +711,19 @@ class _NewConnectionHelper: _KNOWN_HOSTS = _KNOWN_HOSTS port = 22 + knownHosts: KnownHostsFile + def __init__( self, reactor: Any, - hostname: str, - port: int, - command: str, - username: str, - keys: str, - password: str, - agentEndpoint: str, - knownHosts: str | None, + hostname: bytes, + port: int | None, + command: bytes, + username: bytes, + keys: list[Key] | None, + password: bytes | None, + agentEndpoint: IStreamClientEndpoint | None, + knownHosts: str | None | KnownHostsFile, ui: ConsoleUI | None, tty: FilePath[bytes] | FilePath[str] = FilePath(b"/dev/tty"), ): @@ -733,12 +739,13 @@ class _NewConnectionHelper: self.port = port self.command = command self.username = username - self.keys = keys + self.keys = [] if keys is None else keys self.password = password self.agentEndpoint = agentEndpoint - if knownHosts is None: - knownHosts = self._knownHosts() - self.knownHosts = knownHosts + if isinstance(knownHosts, KnownHostsFile): + self.knownHosts = knownHosts + else: + self.knownHosts = self._knownHosts(knownHosts) if ui is None: ui = ConsoleUI(self._opener) @@ -760,14 +767,19 @@ class _NewConnectionHelper: return BytesIO(b"no") @classmethod - def _knownHosts(cls): + def _knownHosts(cls, path: str | None = None) -> KnownHostsFile: """ - @return: A L{KnownHostsFile} instance pointed at the user's personal I{known hosts} file. @rtype: L{KnownHostsFile} """ - return KnownHostsFile.fromPath(FilePath(expanduser(cls._KNOWN_HOSTS))) + if path is None: # pragma: no branch + # negative branch untested because this fallback path requires user + # configuration that tests shouldn't be messing with + # directly. (This should be factored out for better testability in + # terms of coverage.) + path = expanduser(cls._KNOWN_HOSTS) + return KnownHostsFile.fromPath(FilePath(path)) def secureConnection(self): """ diff --git a/contrib/python/Twisted/py3/twisted/conch/error.py b/contrib/python/Twisted/py3/twisted/conch/error.py index a923b9a4c4a..fcf2c1d2f43 100644 --- a/contrib/python/Twisted/py3/twisted/conch/error.py +++ b/contrib/python/Twisted/py3/twisted/conch/error.py @@ -90,7 +90,7 @@ class HostKeyChanged(Exception): """ def __init__(self, offendingEntry, path, lineno): - Exception.__init__(self) + Exception.__init__(self, offendingEntry, path, lineno) self.offendingEntry = offendingEntry self.path = path self.lineno = lineno diff --git a/contrib/python/Twisted/py3/twisted/conch/interfaces.py b/contrib/python/Twisted/py3/twisted/conch/interfaces.py index 965519b0ea5..59fe7145f05 100644 --- a/contrib/python/Twisted/py3/twisted/conch/interfaces.py +++ b/contrib/python/Twisted/py3/twisted/conch/interfaces.py @@ -370,6 +370,12 @@ class IKnownHostEntry(Interface): @since: 8.2 """ + keyType: bytes | None = Attribute( + """ + The SSH key type identifier for this key. + """ + ) + def matchesKey(key: Key) -> bool: """ Return True if this entry matches the given Key object, False diff --git a/contrib/python/Twisted/py3/twisted/conch/scripts/conch.py b/contrib/python/Twisted/py3/twisted/conch/scripts/conch.py index f3e5479bd91..74e3a9f9fe0 100644 --- a/contrib/python/Twisted/py3/twisted/conch/scripts/conch.py +++ b/contrib/python/Twisted/py3/twisted/conch/scripts/conch.py @@ -16,7 +16,7 @@ import signal import struct import sys import tty -from typing import List, Tuple +from typing import Any, List, Tuple from twisted.conch.client import connect, default from twisted.conch.client.options import ConchOptions @@ -113,7 +113,7 @@ class ClientOptions(ConchOptions): # Rest of code in "run" -options = None +options: Any = None conn = None exitStatus = 0 old = None @@ -198,20 +198,21 @@ def _stopReactor(): pass -def doConnect(): +def doConnect() -> None: if "@" in options["host"]: options["user"], options["host"] = options["host"].split("@", 1) if not options.identitys: options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"] - host = options["host"] + if not options["user"]: options["user"] = getpass.getuser() if not options["port"]: options["port"] = 22 else: options["port"] = int(options["port"]) - host = options["host"] - port = options["port"] + + host: str = options["host"] + port: int = options["port"] vhk = default.verifyHostKey if not options["host-key-algorithms"]: options["host-key-algorithms"] = default.getHostKeyAlgorithms(host, options) diff --git a/contrib/python/Twisted/py3/twisted/conch/ssh/_kex.py b/contrib/python/Twisted/py3/twisted/conch/ssh/_kex.py index c23acec219c..0b04f1f5f6b 100644 --- a/contrib/python/Twisted/py3/twisted/conch/ssh/_kex.py +++ b/contrib/python/Twisted/py3/twisted/conch/ssh/_kex.py @@ -6,26 +6,37 @@ SSH key exchange handling. """ +from __future__ import annotations from hashlib import sha1, sha256, sha384, sha512 +from typing import TYPE_CHECKING, Protocol from zope.interface import Attribute, Interface, implementer from twisted.conch import error +if TYPE_CHECKING: + # NB: Not a real attribute at runtime. + from hashlib import _Hash + + +class _HashFactory(Protocol): + def __call__(self, data: bytes = ...) -> _Hash: + ... + class _IKexAlgorithm(Interface): """ An L{_IKexAlgorithm} describes a key exchange algorithm. """ - preference = Attribute( + preference: int = Attribute( "An L{int} giving the preference of the algorithm when negotiating " "key exchange. Algorithms with lower precedence values are more " "preferred." ) - hashProcessor = Attribute( + hashProcessor: _HashFactory = Attribute( "A callable hash algorithm constructor (e.g. C{hashlib.sha256}) " "suitable for use with this key exchange algorithm." ) @@ -175,7 +186,7 @@ class _DHGroup14SHA1: # Which ECDH hash function to use is dependent on the size. -_kexAlgorithms = { +_kexAlgorithms: dict[bytes, _IKexAlgorithm] = { b"curve25519-sha256": _Curve25519SHA256(), b"[email protected]": _Curve25519SHA256LibSSH(), b"diffie-hellman-group-exchange-sha256": _DHGroupExchangeSHA256(), @@ -187,7 +198,7 @@ _kexAlgorithms = { } -def getKex(kexAlgorithm): +def getKex(kexAlgorithm: bytes) -> _IKexAlgorithm: """ Get a description of a named key exchange algorithm. @@ -201,53 +212,47 @@ def getKex(kexAlgorithm): @raises ConchError: if the key exchange algorithm is not found. """ if kexAlgorithm not in _kexAlgorithms: - raise error.ConchError(f"Unsupported key exchange algorithm: {kexAlgorithm}") + raise error.ConchError(f"Unsupported key exchange algorithm: {kexAlgorithm!r}") return _kexAlgorithms[kexAlgorithm] -def isEllipticCurve(kexAlgorithm): +def isEllipticCurve(kexAlgorithm: bytes) -> bool: """ Returns C{True} if C{kexAlgorithm} is an elliptic curve. @param kexAlgorithm: The key exchange algorithm name. - @type kexAlgorithm: C{str} - @return: C{True} if C{kexAlgorithm} is an elliptic curve, - otherwise C{False}. - @rtype: C{bool} + @return: C{True} if C{kexAlgorithm} is an elliptic curve, otherwise + C{False}. """ return _IEllipticCurveExchangeKexAlgorithm.providedBy(getKex(kexAlgorithm)) -def isFixedGroup(kexAlgorithm): +def isFixedGroup(kexAlgorithm: bytes) -> bool: """ Returns C{True} if C{kexAlgorithm} has a fixed prime / generator group. @param kexAlgorithm: The key exchange algorithm name. - @type kexAlgorithm: L{bytes} @return: C{True} if C{kexAlgorithm} has a fixed prime / generator group, otherwise C{False}. - @rtype: L{bool} """ return _IFixedGroupKexAlgorithm.providedBy(getKex(kexAlgorithm)) -def getHashProcessor(kexAlgorithm): +def getHashProcessor(kexAlgorithm: bytes) -> _HashFactory: """ Get the hash algorithm callable to use in key exchange. @param kexAlgorithm: The key exchange algorithm name. - @type kexAlgorithm: L{bytes} @return: A callable hash algorithm constructor (e.g. C{hashlib.sha256}). - @rtype: C{callable} """ kex = getKex(kexAlgorithm) return kex.hashProcessor -def getDHGeneratorAndPrime(kexAlgorithm): +def getDHGeneratorAndPrime(kexAlgorithm: bytes) -> tuple[int, int]: """ Get the generator and the prime to use in key exchange. @@ -257,17 +262,16 @@ def getDHGeneratorAndPrime(kexAlgorithm): @return: A L{tuple} containing L{int} generator and L{int} prime. @rtype: L{tuple} """ - kex = getKex(kexAlgorithm) + kex = _IFixedGroupKexAlgorithm(getKex(kexAlgorithm)) return kex.generator, kex.prime -def getSupportedKeyExchanges(): +def getSupportedKeyExchanges() -> list[bytes]: """ Get a list of supported key exchange algorithm names in order of preference. @return: A C{list} of supported key exchange algorithm names. - @rtype: C{list} of L{bytes} """ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec diff --git a/contrib/python/Twisted/py3/twisted/conch/ssh/common.py b/contrib/python/Twisted/py3/twisted/conch/ssh/common.py index 8bb6a286c3b..8d01ab14e50 100644 --- a/contrib/python/Twisted/py3/twisted/conch/ssh/common.py +++ b/contrib/python/Twisted/py3/twisted/conch/ssh/common.py @@ -4,12 +4,11 @@ """ Common functions for the SSH classes. - -Maintainer: Paul Swartz """ - +from __future__ import annotations import struct +from typing import Sequence, overload from cryptography.utils import int_to_bytes @@ -19,7 +18,7 @@ from twisted.python.versions import Version __all__ = ["NS", "getNS", "MP", "getMP", "ffs"] -def NS(t): +def NS(t: bytes | str) -> bytes: """ net string """ @@ -28,7 +27,7 @@ def NS(t): return struct.pack("!L", len(t)) + t -def getNS(s, count=1): +def getNS(s: bytes, count: int = 1) -> Sequence[bytes]: """ get net string """ @@ -41,7 +40,7 @@ def getNS(s, count=1): return tuple(ns) + (s[c:],) -def MP(number): +def MP(number: int) -> bytes: if number == 0: return b"\000" * 4 assert number > 0 @@ -51,7 +50,17 @@ def MP(number): return struct.pack(">L", len(bn)) + bn -def getMP(data, count=1): +@overload +def getMP(data: bytes) -> tuple[int, bytes]: + ... + + +@overload +def getMP(data: bytes, count: int) -> Sequence[int | bytes]: + ... + + +def getMP(data: bytes, count: int = 1) -> Sequence[int | bytes]: """ Get multiple precision integer out of the string. A multiple precision integer is stored as a 4-byte length followed by length bytes of the diff --git a/contrib/python/Twisted/py3/twisted/conch/ssh/keys.py b/contrib/python/Twisted/py3/twisted/conch/ssh/keys.py index e0e4a4b2c54..e52608df9aa 100644 --- a/contrib/python/Twisted/py3/twisted/conch/ssh/keys.py +++ b/contrib/python/Twisted/py3/twisted/conch/ssh/keys.py @@ -1675,20 +1675,17 @@ class Key: values = (data["p"], data["q"], data["g"], data["y"], data["x"]) return common.NS(self.sshType()) + b"".join(map(common.MP, values)) - def sign(self, data, signatureType=None): + def sign(self, data: bytes, signatureType: bytes | None = None) -> bytes: """ Sign some data with this key. SECSH-TRANS RFC 4253 Section 6.6. - @type data: L{bytes} @param data: The data to sign. - @type signatureType: L{bytes} @param signatureType: The SSH public key algorithm name to sign this - data with, or L{None} to use a reasonable default for the key. + data with, or L{None} to use a reasonable default for the key. - @rtype: L{bytes} @return: A signature for the given data. """ keyType = self.type() @@ -1702,7 +1699,7 @@ class Key: hashAlgorithm = self._getHashAlgorithm(signatureType) if hashAlgorithm is None: raise BadSignatureAlgorithmError( - f"public key signature algorithm {signatureType} is not " + f"public key signature algorithm {signatureType!r} is not " f"defined for {keyType} keys" ) @@ -1726,21 +1723,13 @@ class Key: rb = int_to_bytes(r) sb = int_to_bytes(s) - # Int_to_bytes returns rb[0] as a str in python2 - # and an as int in python3 - if type(rb[0]) is str: - rcomp = ord(rb[0]) - else: - rcomp = rb[0] + rcomp = rb[0] # If the MSB is set, prepend a null byte for correct formatting. if rcomp & 0x80: rb = b"\x00" + rb - if type(sb[0]) is str: - scomp = ord(sb[0]) - else: - scomp = sb[0] + scomp = sb[0] if scomp & 0x80: sb = b"\x00" + sb diff --git a/contrib/python/Twisted/py3/twisted/conch/ssh/service.py b/contrib/python/Twisted/py3/twisted/conch/ssh/service.py index 7d0d41c4aed..acfd40ee6a7 100644 --- a/contrib/python/Twisted/py3/twisted/conch/ssh/service.py +++ b/contrib/python/Twisted/py3/twisted/conch/ssh/service.py @@ -7,18 +7,22 @@ are ssh-userauth and ssh-connection. Maintainer: Paul Swartz """ +from __future__ import annotations -from typing import Dict +from typing import TYPE_CHECKING, Dict from twisted.logger import Logger +if TYPE_CHECKING: + from twisted.conch.ssh.transport import SSHTransportBase + class SSHService: # this is the ssh name for the service: name: bytes = None # type:ignore[assignment] protocolMessages: Dict[int, str] = {} # map #'s -> protocol names - transport = None # gets set later + transport: SSHTransportBase | None = None # gets set later _log = Logger() diff --git a/contrib/python/Twisted/py3/twisted/conch/ssh/transport.py b/contrib/python/Twisted/py3/twisted/conch/ssh/transport.py index 545c010f76e..323236f4720 100644 --- a/contrib/python/Twisted/py3/twisted/conch/ssh/transport.py +++ b/contrib/python/Twisted/py3/twisted/conch/ssh/transport.py @@ -17,7 +17,7 @@ import struct import types import zlib from hashlib import md5, sha1, sha256, sha384, sha512 -from typing import Any, Callable, Dict, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple, Union from cryptography.exceptions import UnsupportedAlgorithm from cryptography.hazmat.backends import default_backend @@ -29,14 +29,19 @@ from typing_extensions import Literal from twisted import __version__ as twisted_version from twisted.conch.ssh import _kex, address, keys from twisted.conch.ssh.common import MP, NS, ffs, getMP, getNS +from twisted.conch.ssh.service import SSHService from twisted.internet import defer, protocol from twisted.logger import Logger from twisted.python import randbytes from twisted.python.compat import iterbytes, networkString +from twisted.python.failure import Failure # This import is needed if SHA256 hashing is used. # from twisted.python.compat import nativeString +if TYPE_CHECKING: + from twisted.conch.ssh.factory import SSHFactory + def _mpFromBytes(data): """Make an SSH multiple-precision integer from big-endian L{bytes}. @@ -311,8 +316,8 @@ def _getSupportedCiphers(): class SSHTransportBase(protocol.Protocol): """ - Protocol supporting basic SSH functionality: sending/receiving packets - and message dispatch. To connect to or run a server, you must use + Protocol supporting basic SSH functionality: sending/receiving packets and + message dispatch. To connect to or run a server, you must use SSHClientTransport or SSHServerTransport. @ivar protocolVersion: A string representing the version of the SSH @@ -321,23 +326,22 @@ class SSHTransportBase(protocol.Protocol): @ivar version: A string representing the version of the server or client. Currently defaults to 'Twisted'. - @ivar comment: An optional string giving more information about the - server or client. + @ivar comment: An optional string giving more information about the server + or client. @ivar supportedCiphers: A list of strings representing the encryption algorithms supported, in order from most-preferred to least. @ivar supportedMACs: A list of strings representing the message authentication codes (hashes) supported, in order from most-preferred - to least. Both this and supportedCiphers can include 'none' to use - no encryption or authentication, but that must be done manually, + to least. Both this and supportedCiphers can include 'none' to use no + encryption or authentication, but that must be done manually, - @ivar supportedKeyExchanges: A list of strings representing the - key exchanges supported, in order from most-preferred to least. + @ivar supportedKeyExchanges: A list of strings representing the key + exchanges supported, in order from most-preferred to least. - @ivar supportedPublicKeys: A list of strings representing the - public key algorithms supported, in order from most-preferred to - least. + @ivar supportedPublicKeys: A list of strings representing the public key + algorithms supported, in order from most-preferred to least. @ivar supportedCompressions: A list of strings representing compression types supported, from most-preferred to least. @@ -350,16 +354,16 @@ class SSHTransportBase(protocol.Protocol): @ivar isClient: A boolean indicating whether this is a client or server. - @ivar gotVersion: A boolean indicating whether we have received the - version string from the other side. + @ivar gotVersion: A boolean indicating whether we have received the version + string from the other side. @ivar buf: Data we've received but hasn't been parsed into a packet. @ivar outgoingPacketSequence: the sequence number of the next packet we will send. - @ivar incomingPacketSequence: the sequence number of the next packet we - are expecting from the other side. + @ivar incomingPacketSequence: the sequence number of the next packet we are + expecting from the other side. @ivar outgoingCompression: an object supporting the .compress(str) and .flush() methods, or None if there is no outgoing compression. Used to @@ -391,8 +395,8 @@ class SSHTransportBase(protocol.Protocol): part of the key exchange, sessionID is used to generate the various encryption and authentication keys. - @ivar service: an SSHService instance, or None. If it's set to an object, - it's the currently running service. + @ivar service: an L{SSHService} instance, or None. If it's set to an + object, it's the currently running service. @ivar kexAlg: the agreed-upon key exchange algorithm. @@ -476,8 +480,8 @@ class SSHTransportBase(protocol.Protocol): incomingPacketSequence = 0 outgoingCompression = None incomingCompression = None - sessionID = None - service = None + sessionID: bytes | None = None + service: SSHService | None = None # There is no key exchange activity in progress. _KEY_EXCHANGE_NONE = "_KEY_EXCHANGE_NONE" @@ -507,7 +511,13 @@ class SSHTransportBase(protocol.Protocol): _peerSupportsExtensions = False peerExtensions: Dict[bytes, bytes] = {} - def connectionLost(self, reason): + factory: SSHFactory + + # Set by twisted.conch.ssh.userauth.SSHUserAuthServer._cbFinishedAuth + avatar: object + logoutFunction: Callable[[], None] + + def connectionLost(self, reason: Failure | None = None) -> None: """ When the underlying connection is closed, stop the running service (if any), and log out the avatar (if any). @@ -1171,41 +1181,35 @@ class SSHTransportBase(protocol.Protocol): prefix = struct.pack(">L", len(secret)) return prefix + secret - def _getKey(self, c, sharedSecret, exchangeHash): + def _getKey(self, c: bytes, sharedSecret: bytes, exchangeHash: bytes) -> bytes: """ Get one of the keys for authentication/encryption. - @type c: L{bytes} @param c: The letter identifying which key this is. - @type sharedSecret: L{bytes} @param sharedSecret: The shared secret K. - @type exchangeHash: L{bytes} @param exchangeHash: The hash H from key exchange. - @rtype: L{bytes} @return: The derived key. """ hashProcessor = _kex.getHashProcessor(self.kexAlg) - k1 = hashProcessor(sharedSecret + exchangeHash + c + self.sessionID) - k1 = k1.digest() + assert self.sessionID is not None, "session ID must already have been assigned" + k1 = hashProcessor(sharedSecret + exchangeHash + c + self.sessionID).digest() k2 = hashProcessor(sharedSecret + exchangeHash + k1).digest() k3 = hashProcessor(sharedSecret + exchangeHash + k1 + k2).digest() k4 = hashProcessor(sharedSecret + exchangeHash + k1 + k2 + k3).digest() return k1 + k2 + k3 + k4 - def _keySetup(self, sharedSecret, exchangeHash): + def _keySetup(self, sharedSecret: bytes, exchangeHash: bytes) -> None: """ - Set up the keys for the connection and sends MSG_NEWKEYS when - finished, + Set up the keys for the connection and sends MSG_NEWKEYS when finished. @param sharedSecret: a secret string agreed upon using a Diffie- - Hellman exchange, so it is only shared between - the server and the client. - @type sharedSecret: L{str} + Hellman exchange, so it is only shared between the server and the + client. + @param exchangeHash: A hash of various data known by both sides. - @type exchangeHash: L{str} """ if not self.sessionID: self.sessionID = exchangeHash @@ -1442,7 +1446,7 @@ class SSHServerTransport(SSHTransportBase): isClient = False ignoreNextPacket = 0 - def _getHostKeys(self, keyAlg): + def _getHostKeys(self, keyAlg: bytes) -> tuple[keys.Key, keys.Key]: """ Get the public and private host keys corresponding to the given public key signature algorithm. @@ -1467,7 +1471,7 @@ class SSHServerTransport(SSHTransportBase): keyFormat = keyAlg return self.factory.publicKeys[keyFormat], self.factory.privateKeys[keyFormat] - def ssh_KEXINIT(self, packet): + def ssh_KEXINIT(self, packet: bytes) -> None: """ Called when we receive a MSG_KEXINIT message. For a description of the packet, see SSHTransportBase.ssh_KEXINIT(). Additionally, @@ -1487,29 +1491,26 @@ class SSHServerTransport(SSHTransportBase): ): self.ignoreNextPacket = True # Guess was wrong - def _ssh_KEX_ECDH_INIT(self, packet): + def _ssh_KEX_ECDH_INIT(self, packet: bytes) -> None: """ - Called from L{ssh_KEX_DH_GEX_REQUEST_OLD} to handle - elliptic curve key exchanges. + Called from L{ssh_KEX_DH_GEX_REQUEST_OLD} to handle elliptic curve key + exchanges. Payload:: string client Elliptic Curve Diffie-Hellman public key Just like L{_ssh_KEXDH_INIT} this message type is also not dispatched - directly. Extra check to determine if this is really KEX_ECDH_INIT - is required. + directly. Extra check to determine if this is really KEX_ECDH_INIT is + required. - First we load the host's public/private keys. - Then we generate the ECDH public/private keypair for the given curve. - With that we generate the shared secret key. - Then we compute the hash to sign and send back to the client - Along with the server's public key and the ECDH public key. + First we load the host's public/private keys. Then we generate the + ECDH public/private keypair for the given curve. With that we generate + the shared secret key. Then we compute the hash to sign and send back + to the client Along with the server's public key and the ECDH public + key. - @type packet: L{bytes} @param packet: The message data. - - @return: None. """ # Get the raw client public key. pktPub, packet = getNS(packet) @@ -1547,7 +1548,7 @@ class SSHServerTransport(SSHTransportBase): ) self._keySetup(sharedSecret, exchangeHash) - def _ssh_KEXDH_INIT(self, packet): + def _ssh_KEXDH_INIT(self, packet: bytes) -> None: """ Called to handle the beginning of a non-group key exchange. @@ -1588,7 +1589,7 @@ class SSHServerTransport(SSHTransportBase): ) self._keySetup(sharedSecret, exchangeHash) - def ssh_KEX_DH_GEX_REQUEST_OLD(self, packet): + def ssh_KEX_DH_GEX_REQUEST_OLD(self, packet: bytes) -> None: """ This represents different key exchange methods that share the same integer value. If the message is determined to be a KEXDH_INIT, @@ -1625,7 +1626,7 @@ class SSHServerTransport(SSHTransportBase): self._startEphemeralDH() self.sendPacket(MSG_KEX_DH_GEX_GROUP, MP(self.p) + MP(self.g)) - def ssh_KEX_DH_GEX_REQUEST(self, packet): + def ssh_KEX_DH_GEX_REQUEST(self, packet: bytes) -> None: """ Called when we receive a MSG_KEX_DH_GEX_REQUEST message. Payload:: integer minimum @@ -1651,7 +1652,7 @@ class SSHServerTransport(SSHTransportBase): self._startEphemeralDH() self.sendPacket(MSG_KEX_DH_GEX_GROUP, MP(self.p) + MP(self.g)) - def ssh_KEX_DH_GEX_INIT(self, packet): + def ssh_KEX_DH_GEX_INIT(self, packet: bytes) -> None: """ Called when we get a MSG_KEX_DH_GEX_INIT message. Payload:: integer e (client DH public key) @@ -1693,7 +1694,7 @@ class SSHServerTransport(SSHTransportBase): ) self._keySetup(sharedSecret, exchangeHash) - def _keySetup(self, sharedSecret, exchangeHash): + def _keySetup(self, sharedSecret: bytes, exchangeHash: bytes) -> None: """ See SSHTransportBase._keySetup(). """ @@ -1709,13 +1710,12 @@ class SSHServerTransport(SSHTransportBase): [(b"server-sig-algs", b",".join(self.supportedPublicKeys))] ) - def ssh_NEWKEYS(self, packet): + def ssh_NEWKEYS(self, packet: bytes) -> None: """ Called when we get a MSG_NEWKEYS message. No payload. When we get this, the keys have been set on both sides, and we start using them to encrypt and authenticate the connection. - @type packet: L{bytes} @param packet: The message data. """ if packet != b"": @@ -1723,7 +1723,7 @@ class SSHServerTransport(SSHTransportBase): return self._newKeys() - def ssh_SERVICE_REQUEST(self, packet): + def ssh_SERVICE_REQUEST(self, packet: bytes) -> None: """ Called when we get a MSG_SERVICE_REQUEST message. Payload:: string serviceName @@ -1906,7 +1906,8 @@ class SSHClientTransport(SSHTransportBase): d.addCallback(_continue_KEX_ECDH_REPLY, hostKey, pubKey, signature) d.addErrback( lambda unused: self.sendDisconnect( - DISCONNECT_HOST_KEY_NOT_VERIFIABLE, b"bad host key" + DISCONNECT_HOST_KEY_NOT_VERIFIABLE, + f"bad host key [ecdh] {unused}".encode("utf-8"), ) ) return d @@ -2122,7 +2123,7 @@ class SSHClientTransport(SSHTransportBase): ) self.setService(self.instance) - def requestService(self, instance): + def requestService(self, instance: SSHService) -> None: """ Request that a service be run over this transport. diff --git a/contrib/python/Twisted/py3/twisted/conch/ssh/userauth.py b/contrib/python/Twisted/py3/twisted/conch/ssh/userauth.py index 310f5f09f2e..0d24df00f92 100644 --- a/contrib/python/Twisted/py3/twisted/conch/ssh/userauth.py +++ b/contrib/python/Twisted/py3/twisted/conch/ssh/userauth.py @@ -8,20 +8,28 @@ Currently implemented authentication types are public-key and password. Maintainer: Paul Swartz """ - +from __future__ import annotations import struct +from typing import Callable, Tuple, Type from twisted.conch import error, interfaces from twisted.conch.ssh import keys, service, transport from twisted.conch.ssh.common import NS, getNS +from twisted.conch.ssh.keys import Key from twisted.cred import credentials from twisted.cred.error import UnauthorizedLogin from twisted.internet import defer, reactor +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IReactorTime from twisted.logger import Logger from twisted.python import failure from twisted.python.compat import nativeString +_ConchPortalTuple = Tuple[ + Type[interfaces.IConchUser], interfaces.IConchUser, Callable[[], None] +] + class SSHUserAuthServer(service.SSHService): """ @@ -72,7 +80,7 @@ class SSHUserAuthServer(service.SSHService): attemptsBeforeDisconnect = 20 # 20 login attempts before a disconnect passwordDelay = 1 # number of seconds to delay on a failed password - clock = reactor + clock: IReactorTime = IReactorTime(reactor) interfaceToMethod = { credentials.ISSHPrivateKey: b"publickey", credentials.IUsernamePassword: b"password", @@ -124,37 +132,40 @@ class SSHUserAuthServer(service.SSHService): transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, b"you took too long" ) - def tryAuth(self, kind, user, data): + def tryAuth( + self, kind: bytes, user: bytes, data: bytes + ) -> Deferred[_ConchPortalTuple]: """ Try to authenticate the user with the given method. Dispatches to a auth_* method. @param kind: the authentication method to try. - @type kind: L{bytes} + @param user: the username the client is authenticating with. - @type user: L{bytes} + @param data: authentication specific data sent by the client. - @type data: L{bytes} + @return: A Deferred called back if the method succeeded, or erred back if it failed. - @rtype: C{defer.Deferred} """ self._log.debug("{user!r} trying auth {kind!r}", user=user, kind=kind) if kind not in self.supportedAuthentications: return defer.fail(error.ConchError("unsupported authentication, failing")) - kind = nativeString(kind.replace(b"-", b"_")) - f = getattr(self, f"auth_{kind}", None) - if f: + strkind = kind.replace(b"-", b"_").decode("ascii") + f: Callable[[bytes], Deferred[_ConchPortalTuple] | None] | None = getattr( + self, f"auth_{strkind}", None + ) + if f is not None: ret = f(data) - if not ret: + if ret is None: return defer.fail( - error.ConchError(f"{kind} return None instead of a Deferred") + error.ConchError(f"{strkind} return None instead of a Deferred") ) else: return ret - return defer.fail(error.ConchError(f"bad auth type: {kind}")) + return defer.fail(error.ConchError(f"bad auth type: {strkind}")) - def ssh_USERAUTH_REQUEST(self, packet): + def ssh_USERAUTH_REQUEST(self, packet: bytes) -> Deferred[_ConchPortalTuple] | None: """ The client has requested authentication. Payload:: string user @@ -173,19 +184,21 @@ class SSHUserAuthServer(service.SSHService): d = self.tryAuth(method, user, rest) if not d: self._ebBadAuth(failure.Failure(error.ConchError("auth returned none"))) - return - d.addCallback(self._cbFinishedAuth) - d.addErrback(self._ebMaybeBadAuth) - d.addErrback(self._ebBadAuth) - return d + return None + return ( + d.addCallback(self._cbFinishedAuth) + .addErrback(self._ebMaybeBadAuth) + .addErrback(self._ebBadAuth) + ) - def _cbFinishedAuth(self, result): + def _cbFinishedAuth(self, result: _ConchPortalTuple) -> None: """ The callback when user has successfully been authenticated. For a description of the arguments, see L{twisted.cred.portal.Portal.login}. We start the service requested by the user. """ (interface, avatar, logout) = result + assert self.transport is not None self.transport.avatar = avatar self.transport.logoutFunction = logout service = self.transport.factory.getService(self.transport, self.nextService) @@ -249,7 +262,7 @@ class SSHUserAuthServer(service.SSHService): MSG_USERAUTH_FAILURE, NS(b",".join(self.supportedAuthentications)) + b"\x00" ) - def auth_publickey(self, packet): + def auth_publickey(self, packet: bytes) -> Deferred[_ConchPortalTuple]: """ Public key authentication. Payload:: byte has signature @@ -262,6 +275,7 @@ class SSHUserAuthServer(service.SSHService): hasSig = ord(packet[0:1]) algName, blob, rest = getNS(packet[1:], 2) + result: Deferred[_ConchPortalTuple] try: keys.Key.fromString(blob) except keys.BadKeyError: @@ -271,6 +285,8 @@ class SSHUserAuthServer(service.SSHService): signature = hasSig and getNS(rest)[0] or None if hasSig: + assert self.transport is not None, "must have transport for auth" + assert self.transport.sessionID is not None, "must have session for auth" b = ( NS(self.transport.sessionID) + bytes((MSG_USERAUTH_REQUEST,)) @@ -282,19 +298,21 @@ class SSHUserAuthServer(service.SSHService): + NS(blob) ) c = credentials.SSHPrivateKey(self.user, algName, blob, b, signature) - return self.portal.login(c, None, interfaces.IConchUser) + result = self.portal.login(c, None, interfaces.IConchUser) else: c = credentials.SSHPrivateKey(self.user, algName, blob, None, None) - return self.portal.login(c, None, interfaces.IConchUser).addErrback( + result = self.portal.login(c, None, interfaces.IConchUser).addErrback( self._ebCheckKey, packet[1:] ) + return result - def _ebCheckKey(self, reason, packet): + def _ebCheckKey(self, reason: failure.Failure, packet: bytes) -> failure.Failure: """ Called back if the user did not sent a signature. If reason is error.ValidPublicKey then this key is valid for the user to authenticate with. Send MSG_USERAUTH_PK_OK. """ + assert self.transport is not None reason.trap(error.ValidPublicKey) # if we make it here, it means that the publickey is valid self.transport.sendPacket(MSG_USERAUTH_PK_OK, packet) @@ -331,64 +349,67 @@ class SSHUserAuthClient(service.SSHService): making callbacks for more information when necessary. @ivar name: the name of this service: 'ssh-userauth' - @type name: L{str} + @ivar preferredOrder: a list of authentication methods that should be used first, in order of preference, if supported by the server - @type preferredOrder: L{list} + @ivar user: the name of the user to authenticate as - @type user: L{bytes} + @ivar instance: the service to start after authentication has finished - @type instance: L{service.SSHService} - @ivar authenticatedWith: a list of strings of authentication methods we've tried - @type authenticatedWith: L{list} of L{bytes} + + @ivar authenticatedWith: a list of strings of authentication methods we've + tried + @ivar triedPublicKeys: a list of public key objects that we've tried to authenticate with - @type triedPublicKeys: L{list} of L{Key} + @ivar lastPublicKey: the last public key object we've tried to authenticate with - @type lastPublicKey: L{Key} """ - name = b"ssh-userauth" - preferredOrder = [b"publickey", b"password", b"keyboard-interactive"] + name: bytes = b"ssh-userauth" + preferredOrder: list[bytes] = [ + b"publickey", + b"password", + b"keyboard-interactive", + ] - def __init__(self, user, instance): + def __init__(self, user: bytes, instance: service.SSHService): self.user = user self.instance = instance - def serviceStarted(self): - self.authenticatedWith = [] - self.triedPublicKeys = [] - self.lastPublicKey = None + def serviceStarted(self) -> None: + self.authenticatedWith: list[bytes] = [] + self.triedPublicKeys: list[Key] = [] + self.lastPublicKey: Key | None = None self.askForAuth(b"none", b"") - def askForAuth(self, kind, extraData): + def askForAuth(self, kind: bytes, extraData: bytes) -> None: """ - Send a MSG_USERAUTH_REQUEST. + Send a C{MSG_USERAUTH_REQUEST}. @param kind: the authentication method to try. - @type kind: L{bytes} + @param extraData: method-specific data to go in the packet - @type extraData: L{bytes} """ + assert self.transport is not None self.lastAuth = kind self.transport.sendPacket( MSG_USERAUTH_REQUEST, NS(self.user) + NS(self.instance.name) + NS(kind) + extraData, ) - def tryAuth(self, kind): + def tryAuth(self, kind: bytes) -> None | Deferred[bool]: """ Dispatch to an authentication method. @param kind: the authentication method @type kind: L{bytes} """ - kind = nativeString(kind.replace(b"-", b"_")) + strkind = kind.replace(b"-", b"_").decode("ascii") self._log.debug("trying to auth with {kind}", kind=kind) - f = getattr(self, "auth_" + kind, None) - if f: - return f() + f: Callable[[], Deferred[bool]] | None = getattr(self, "auth_" + strkind, None) + return f() if f is not None else None def _ebAuth(self, ignored, *args): """ @@ -597,19 +618,15 @@ class SSHUserAuthClient(service.SSHService): data += NS(r.encode("UTF8")) self.transport.sendPacket(MSG_USERAUTH_INFO_RESPONSE, data) - def auth_publickey(self): + def auth_publickey(self) -> Deferred[bool]: """ Try to authenticate with a public key. Ask the user for a public key; if the user has one, send the request to the server and return True. Otherwise, return False. - - @rtype: L{bool} """ - d = defer.maybeDeferred(self.getPublicKey) - d.addBoth(self._cbGetPublicKey) - return d + return defer.maybeDeferred(self.getPublicKey).addBoth(self._cbGetPublicKey) - def _cbGetPublicKey(self, publicKey): + def _cbGetPublicKey(self, publicKey: Key | failure.Failure | None) -> bool: if not isinstance(publicKey, keys.Key): # failure or None publicKey = None if publicKey is not None: @@ -623,13 +640,15 @@ class SSHUserAuthClient(service.SSHService): else: return False - def auth_password(self): + # Section defining C{auth_}-prefixed methods begins here: they must each be + # defined with the signature (() -> bool), as described by + # L{SSHUserAuthClient.tryAuth}. + + def auth_password(self) -> bool: """ Try to authenticate with a password. Ask the user for a password. If the user will return a password, return True. Otherwise, return False. - - @rtype: L{bool} """ d = self.getPassword() if d: @@ -638,83 +657,75 @@ class SSHUserAuthClient(service.SSHService): else: # returned None, don't do password auth return False - def auth_keyboard_interactive(self): + def auth_keyboard_interactive(self) -> bool: """ Try to authenticate with keyboard-interactive authentication. Send the request to the server and return True. - - @rtype: L{bool} """ self._log.debug("authing with keyboard-interactive") self.askForAuth(b"keyboard-interactive", NS(b"") + NS(b"")) return True - def _cbPassword(self, password): + # Section defining C{auth_}-prefixed methods ends here. + + def _cbPassword(self, password: bytes) -> None: """ Called back when the user gives a password. Send the request to the server. @param password: the password the user entered - @type password: L{bytes} """ self.askForAuth(b"password", b"\x00" + NS(password)) - def signData(self, publicKey, signData): + def signData(self, publicKey: keys.Key, signData: bytes) -> Deferred[bytes] | None: """ Sign the given data with the given public key. - By default, this will call getPrivateKey to get the private key, - then sign the data using Key.sign(). + By default, this will call getPrivateKey to get the private key, then + sign the data using Key.sign(). This method is factored out so that it can be overridden to use alternate methods, such as a key agent. @param publicKey: The public key object returned from L{getPublicKey} - @type publicKey: L{keys.Key} @param signData: the data to be signed by the private key. - @type signData: L{bytes} + @return: a Deferred that's called back with the signature - @rtype: L{defer.Deferred} """ key = self.getPrivateKey() if not key: - return + return None return key.addCallback(self._cbSignData, signData) - def _cbSignData(self, privateKey, signData): + def _cbSignData(self, privateKey: keys.Key, signData: bytes) -> bytes: """ - Called back when the private key is returned. Sign the data and - return the signature. + Called back when the private key is returned. Sign the data and return + the signature. @param privateKey: the private key object - @type privateKey: L{keys.Key} + @param signData: the data to be signed by the private key. - @type signData: L{bytes} + @return: the signature - @rtype: L{bytes} """ return privateKey.sign(signData) - def getPublicKey(self): + def getPublicKey(self) -> Key | None: """ Return a public key for the user. If no more public keys are available, return L{None}. This implementation always returns L{None}. Override it in a subclass to actually find and return a public key object. - - @rtype: L{Key} or L{None} """ return None - def getPrivateKey(self): + def getPrivateKey(self) -> Deferred[Key]: """ Return a L{Deferred} that will be called back with the private key object corresponding to the last public key from getPublicKey(). If the private key is not available, errback on the Deferred. - - @rtype: L{Deferred} called back with L{Key} """ return defer.fail(NotImplementedError()) |
