aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/Twisted/py3/twisted/conch/client/knownhosts.py
blob: 44118512bd209f194ef2eecf2fe31762a88f4efa (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
# -*- test-case-name: twisted.conch.test.test_knownhosts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
An implementation of the OpenSSH known_hosts database.

@since: 8.2
"""

from __future__ import annotations

import hmac
import sys
from binascii import Error as DecodeError, a2b_base64, b2a_base64
from contextlib import closing
from hashlib import sha1
from typing import IO, Callable, Literal

from zope.interface import implementer

from twisted.conch.error import HostKeyChanged, InvalidEntry, UserRejectedKey
from twisted.conch.interfaces import IKnownHostEntry
from twisted.conch.ssh.keys import BadKeyError, FingerprintFormats, Key
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.logger import Logger
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
from twisted.python.randbytes import secureRandom
from twisted.python.util import FancyEqMixin

log = Logger()


def _b64encode(s):
    """
    Encode a binary string as base64 with no trailing newline.

    @param s: The string to encode.
    @type s: L{bytes}

    @return: The base64-encoded string.
    @rtype: L{bytes}
    """
    return b2a_base64(s).strip()


def _extractCommon(string):
    """
    Pull apart the pieces of a known hosts line that every entry format
    shares: the host field, the key type, the key itself and an optional
    comment.

    @param string: One line from a known hosts file.
    @type string: L{bytes}

    @return: a 4-tuple C{(hostnames, keyType, key, comment)}.  C{hostnames}
        (L{bytes}) is everything before the first run of whitespace,
        C{keyType} (L{bytes}) is the second field, C{key} is the L{Key}
        built from the base64-decoded third field, and C{comment} is the
        rest of the line with trailing newlines removed (L{bytes}), or
        L{None} when nothing follows the key.
    @rtype: L{tuple}

    @raise InvalidEntry: if the line does not contain at least three
        whitespace-separated fields.  Errors from base64 decoding or from
        L{Key.fromString} are not caught here and propagate to the caller.
    """
    fields = string.split(None, 2)
    if len(fields) != 3:
        raise InvalidEntry()
    hostnames, keyType, remainder = fields

    # The third field is the encoded key, optionally followed by free-form
    # comment text.
    keyString, *trailing = remainder.split(None, 1)
    comment = trailing[0].rstrip(b"\n") if trailing else None

    decodedKey = a2b_base64(keyString)
    key = Key.fromString(decodedKey)
    return hostnames, keyType, key, comment


class _BaseEntry:
    """
    Abstract base of both hashed and non-hashed entry objects, since they
    represent keys and key types the same way.

    @ivar keyType: The type of the key; either ssh-dss or ssh-rsa.
    @type keyType: L{bytes}

    @ivar publicKey: The server public key indicated by this line.
    @type publicKey: L{twisted.conch.ssh.keys.Key}

    @ivar comment: Trailing garbage after the key line.
    @type comment: L{bytes}
    """

    def __init__(self, keyType, publicKey, comment):
        self.keyType = keyType
        self.publicKey = publicKey
        self.comment = comment

    def matchesKey(self, keyObject):
        """
        Check to see if this entry matches a given key object.

        @param keyObject: A public key object to check.
        @type keyObject: L{Key}

        @return: C{True} if this entry's key matches C{keyObject}, C{False}
            otherwise.
        @rtype: L{bool}
        """
        return self.publicKey == keyObject


@implementer(IKnownHostEntry)
class PlainEntry(_BaseEntry):
    """
    A L{PlainEntry} is a representation of a plain-text entry in a known_hosts
    file, i.e. one whose host names are stored unhashed.

    @ivar _hostnames: the list of all host-names associated with this entry.
    """

    def __init__(
        self,
        hostnames: list[bytes],
        keyType: bytes,
        publicKey: Key,
        comment: bytes | None,
    ) -> None:
        """
        @param hostnames: every host name (or IP address literal) this entry
            applies to.

        @param keyType: the SSH key type name for C{publicKey}.

        @param publicKey: the server public key recorded by this entry.

        @param comment: the text following the key on the entry's line, or
            L{None} if the line had no comment.
        """
        self._hostnames: list[bytes] = hostnames
        super().__init__(keyType, publicKey, comment)

    @classmethod
    def fromString(cls, string: bytes) -> PlainEntry:
        """
        Parse a plain-text entry in a known_hosts file, and return a
        corresponding L{PlainEntry}.

        @param string: a space-separated string formatted like "hostname
            key-type base64-key-data comment".  The hostname field may hold
            several comma-separated names; the comment is optional.

        @raise DecodeError: if the key is not validly encoded as base64.

        @raise InvalidEntry: if the entry does not have the right number of
            elements and is therefore invalid.

        @raise BadKeyError: if the key, once decoded from base64, is not
            actually an SSH key.

        @return: an IKnownHostEntry representing the hostname and key in the
            input line.

        @rtype: L{PlainEntry}
        """
        hostnames, keyType, key, comment = _extractCommon(string)
        # The host field is a comma-separated list of names.
        return cls(hostnames.split(b","), keyType, key, comment)

    def matchesHost(self, hostname: bytes | str) -> bool:
        """
        Check to see if this entry matches a given hostname.

        The comparison is an exact membership test against the names stored
        in this entry.

        @param hostname: A hostname or IP address literal to check against this
            entry.  A L{str} is encoded as UTF-8 before comparison.

        @return: C{True} if this entry is for the given hostname or IP address,
            C{False} otherwise.
        """
        if isinstance(hostname, str):
            hostname = hostname.encode("utf-8")
        return hostname in self._hostnames

    def toString(self) -> bytes:
        """
        Implement L{IKnownHostEntry.toString} by recording the comma-separated
        hostnames, key type, and base-64 encoded key, followed by the comment
        when there is one.

        @return: The string representation of this entry, with unhashed hostname
            information.
        """
        fields = [
            b",".join(self._hostnames),
            self.keyType,
            _b64encode(self.publicKey.blob()),
        ]
        # Entries parsed from a line without a comment carry None here.
        if self.comment is not None:
            fields.append(self.comment)
        return b" ".join(fields)


@implementer(IKnownHostEntry)
class UnparsedEntry:
    """
    A placeholder for a line of a L{KnownHostsFile} that could not be parsed
    as a real entry.  It never matches any host or key; it only remembers the
    original text.
    """

    def __init__(self, string):
        """
        Remember the raw text of a known_hosts line that could not be parsed.

        @param string: The line exactly as it was read.
        @type string: L{bytes}
        """
        self._string = string

    def matchesHost(self, hostname):
        """
        An unparsed line is never associated with a host.

        @return: C{False}, unconditionally.
        """
        return False

    def matchesKey(self, key):
        """
        An unparsed line is never associated with a key.

        @return: C{False}, unconditionally.
        """
        return False

    def toString(self):
        """
        Give back the text this entry was created from, minus any trailing
        newline characters.

        @return: The original line with trailing newlines removed.
        @rtype: L{bytes}
        """
        rawLine = self._string
        return rawLine.rstrip(b"\n")


def _hmacedString(key, string):
    """
    Return the SHA-1 HMAC hash of the given key and string.

    @param key: The HMAC key.
    @type key: L{bytes}

    @param string: The string to be hashed.
    @type string: L{bytes}

    @return: The keyed hash value.
    @rtype: L{bytes}
    """
    hash = hmac.HMAC(key, digestmod=sha1)
    if isinstance(string, str):
        string = string.encode("utf-8")
    hash.update(string)
    return hash.digest()


@implementer(IKnownHostEntry)
class HashedEntry(_BaseEntry, FancyEqMixin):
    """
    An entry of a known_hosts file whose host name is not stored directly but
    as a salted hash.

    @ivar _hostSalt: the salt used as the HMAC key when hashing a host name.

    @ivar _hostHash: the stored hash of the host name.

    @cvar MAGIC: the prefix that marks a line of a known_hosts file as hashed
        rather than plain text.
    """

    MAGIC = b"|1|"

    compareAttributes = ("_hostSalt", "_hostHash", "keyType", "publicKey", "comment")

    def __init__(
        self,
        hostSalt: bytes,
        hostHash: bytes,
        keyType: bytes,
        publicKey: Key,
        comment: bytes | None,
    ) -> None:
        super().__init__(keyType, publicKey, comment)
        self._hostSalt = hostSalt
        self._hostHash = hostHash

    @classmethod
    def fromString(cls, string: bytes) -> HashedEntry:
        """
        Build a L{HashedEntry} from one line of a known_hosts file.

        @param string: A complete single line from a I{known_hosts} file,
            formatted as defined by OpenSSH.  The first L{MAGIC}-sized
            prefix of the host field is discarded without being inspected.

        @raise DecodeError: if the key, the salt, or the host hash is not
            valid base64.

        @raise InvalidEntry: if the entry does not have the right number of
            elements, or if the host field does not consist of exactly a
            salt and a hash separated by C{|}.

        @raise BadKeyError: if the key, once decoded from base64, is not
            actually an SSH key.

        @return: The new L{HashedEntry}, populated from C{string}.
        """
        hostField, keyType, key, comment = _extractCommon(string)
        encodedParts = hostField[len(cls.MAGIC) :].split(b"|")
        if len(encodedParts) != 2:
            raise InvalidEntry()
        salt, digest = map(a2b_base64, encodedParts)
        return cls(salt, digest, keyType, key, comment)

    def matchesHost(self, hostname):
        """
        Implement L{IKnownHostEntry.matchesHost} by hashing the candidate
        with this entry's salt and comparing the result to the stored hash.

        @param hostname: A hostname or IP address literal to check against this
            entry.
        @type hostname: L{bytes}

        @return: C{True} if this entry is for the given hostname or IP address,
            C{False} otherwise.
        @rtype: L{bool}
        """
        candidateHash = _hmacedString(self._hostSalt, hostname)
        return hmac.compare_digest(candidateHash, self._hostHash)

    def toString(self):
        """
        Implement L{IKnownHostEntry.toString}: emit the magic prefix, the
        base64 salt and host hash joined by C{|}, the key type, the base64
        key and, if present, the comment.

        @return: The string representation of this entry, with the hostname part
            hashed.
        @rtype: L{bytes}
        """
        encodedSalt = _b64encode(self._hostSalt)
        encodedHash = _b64encode(self._hostHash)
        pieces = [
            self.MAGIC + encodedSalt + b"|" + encodedHash,
            self.keyType,
            _b64encode(self.publicKey.blob()),
        ]
        if self.comment is not None:
            pieces.append(self.comment)
        return b" ".join(pieces)


class KnownHostsFile:
    """
    An in-memory view of an OpenSSH-format known hosts file (typically
    C{~/.ssh/known_hosts}), together with entries queued for writing.

    @ivar _added: L{IKnownHostEntry} providers that were added to this
        object but have not been written to disk yet.

    @ivar _clobber: Whether whatever is currently stored at the save path
        should be ignored.  When C{True}, existing entries are not read and
        the next save overwrites the file.  When C{False}, existing entries
        are read from the save path and saving appends to it.
    @type _clobber: L{bool}

    @ivar _savePath: See C{savePath} parameter of L{__init__}.
    """

    def __init__(self, savePath: FilePath[str]) -> None:
        """
        Create a new, empty KnownHostsFile.

        An instance created this way disregards (and on save overwrites)
        anything already stored at C{savePath}; use
        L{KnownHostsFile.fromPath} to keep the existing contents.

        @param savePath: The L{FilePath} to which to save new entries.
        @type savePath: L{FilePath}
        """
        self._savePath = savePath
        self._clobber = True
        self._added: list[IKnownHostEntry] = []

    @property
    def savePath(self) -> FilePath[str]:
        """
        @see: C{savePath} parameter of L{__init__}
        """
        return self._savePath

    def iterentries(self):
        """
        Iterate over the host entries in this file.

        Entries added in memory but not yet saved come first.  Unless this
        instance is set to clobber the save path, each line of the file
        follows; lines that cannot be parsed are produced as
        L{UnparsedEntry} objects.  If the file cannot be opened, only the
        unsaved entries are produced.

        @return: An iterable the elements of which provide L{IKnownHostEntry}.
        @rtype: iterable of L{IKnownHostEntry} providers
        """
        yield from self._added

        if self._clobber:
            return

        try:
            hostsFile = self._savePath.open()
        except OSError:
            return

        with hostsFile:
            for rawLine in hostsFile:
                # Hashed lines are recognised by their magic prefix;
                # everything else is treated as a plain-text entry.
                parse = (
                    HashedEntry.fromString
                    if rawLine.startswith(HashedEntry.MAGIC)
                    else PlainEntry.fromString
                )
                try:
                    parsed = parse(rawLine)
                except (DecodeError, InvalidEntry, BadKeyError):
                    parsed = UnparsedEntry(rawLine)
                yield parsed

    def hasHostKey(self, hostname, key):
        """
        Check for an entry with matching hostname and key.

        Only entries whose key type equals the type of C{key} are
        considered.

        @param hostname: A hostname or IP address literal to check for.
        @type hostname: L{bytes}

        @param key: The public key to check for.
        @type key: L{Key}

        @return: C{True} if the given hostname and key are present in this file,
            C{False} if they are not.
        @rtype: L{bool}

        @raise HostKeyChanged: if the host key found for the given hostname
            does not match the given key.
        """
        # Unsaved entries are iterated first; starting the count at
        # -len(self._added) gives them negative positions, so that position
        # 0 is the first line of the file itself.
        for position, candidate in enumerate(self.iterentries(), -len(self._added)):
            if not candidate.matchesHost(hostname):
                continue
            if candidate.keyType != key.sshType():
                continue
            if candidate.matchesKey(key):
                return True
            if position < 0:
                # The conflicting entry exists only in memory, so there is
                # no file location to report.
                raise HostKeyChanged(candidate, None, None)
            # position is 0-based but HostKeyChanged.lineno is 1-based.
            raise HostKeyChanged(candidate, self._savePath, position + 1)
        return False

    def verifyHostKey(
        self, ui: ConsoleUI, hostname: bytes, ip: bytes, key: Key
    ) -> Deferred[bool]:
        """
        Verify the given host key for the given IP and host, asking for
        confirmation from, and notifying, the given UI about changes to this
        file.

        If the key is already known for C{hostname} but not for C{ip}, the
        user is warned and an entry for C{ip} is added and saved.  If the key
        is not known for C{hostname}, the user is prompted; on acceptance,
        entries for both C{hostname} and C{ip} are added and saved.

        @param ui: The user interface used to prompt and warn the user.

        @param hostname: The hostname that the user requested to connect to.

        @param ip: The string representation of the IP address that is actually
            being connected to.

        @param key: The public key of the server.

        @return: a L{Deferred} that fires with True when the key has been
            verified, or fires with an errback when the key either cannot be
            verified or has changed.
        @rtype: L{Deferred}
        """

        def onUserAnswer(accepted: bool) -> bool:
            if not accepted:
                raise UserRejectedKey()
            self.addHostKey(hostname, key)
            self.addHostKey(ip, key)
            self.save()
            return accepted

        def askUser() -> Deferred[bool]:
            algorithm: str = key.type()
            if algorithm == "EC":
                algorithm = "ECDSA"

            hostText = nativeString(hostname)
            ipText = nativeString(ip)
            fingerprint = key.fingerprint(format=FingerprintFormats.SHA256_BASE64)
            question = (
                "The authenticity of host '%s (%s)' "
                "can't be established.\n"
                "%s key fingerprint is SHA256:%s.\n"
                "Are you sure you want to continue connecting (yes/no)? "
                % (hostText, ipText, algorithm, fingerprint)
            )
            answer = ui.prompt(question.encode(sys.getdefaultencoding()))
            return answer.addCallback(onUserAnswer)

        def onLookup(known: bool) -> bool | Deferred[bool]:
            if not known:
                return askUser()
            if not self.hasHostKey(ip, key):
                notice = (
                    f"Warning: Permanently added the {key.type()} host key"
                    f" for IP address '{ip.decode()}' to the list of known"
                    " hosts.\n"
                )
                ui.warn(notice.encode("utf-8"))
                self.addHostKey(ip, key)
                self.save()
            return known

        # defer.execute turns an exception from hasHostKey (such as
        # HostKeyChanged) into a failed Deferred.
        lookup = defer.execute(self.hasHostKey, hostname, key)
        return lookup.addCallback(onLookup)

    def addHostKey(self, hostname: bytes, key: Key) -> HashedEntry:
        """
        Add a new L{HashedEntry} to the key database.

        The entry is only queued in memory; call L{KnownHostsFile.save} to
        persist it.

        @param hostname: A hostname or IP address literal to associate with the
            new entry.
        @type hostname: L{bytes}

        @param key: The public key to associate with the new entry.
        @type key: L{Key}

        @return: The L{HashedEntry} that was added.
        @rtype: L{HashedEntry}
        """
        # Each entry gets its own freshly generated 20-byte salt.
        hostSalt = secureRandom(20)
        sshType = key.sshType()
        hostHash = _hmacedString(hostSalt, hostname)
        newEntry = HashedEntry(hostSalt, hostHash, sshType, key, None)
        self._added.append(newEntry)
        return newEntry

    def save(self) -> None:
        """
        Write the not-yet-saved entries of this L{KnownHostsFile} to its save
        path, creating the parent directory first if it does not exist.

        The file is overwritten if this instance is set to clobber it and
        appended to otherwise; after saving, later saves always append.
        """
        parentDir = self._savePath.parent()
        if not parentDir.isdir():
            parentDir.makedirs()

        mode: Literal["a", "w"]
        if self._clobber:
            mode = "w"
        else:
            mode = "a"

        with self._savePath.open(mode) as hostsFileObj:
            pending = self._added
            if pending:
                # One entry per line, each terminated by a newline.
                hostsFileObj.write(
                    b"".join(entry.toString() + b"\n" for entry in pending)
                )
                self._added = []
        self._clobber = False

    @classmethod
    def fromPath(cls, path: FilePath[str]) -> KnownHostsFile:
        """
        Create a new L{KnownHostsFile}, potentially reading existing known
        hosts information from the given file.

        @param path: A path object to use for both reading contents from and
            later saving to.  If no file exists at this path, it is not an
            error; a L{KnownHostsFile} with no entries is returned.

        @return: A L{KnownHostsFile} initialized with entries from C{path}.
        """
        instance = cls(path)
        # Keep whatever is already stored at the path instead of replacing it.
        instance._clobber = False
        return instance


class ConsoleUI:
    """
    A console-backed user interface for key verification: it can ask the
    user a yes/no question and can display a notification.
    """

    def __init__(self, opener: Callable[[], IO[bytes]]):
        """
        @param opener: A no-argument callable which should open a console
            binary-mode file-like object to be used for reading and writing.
            It is stored as the C{opener} attribute and called once per
            prompt or warning.
        @type opener: callable taking no arguments and returning a read/write
            file-like object
        """
        self.opener = opener

    def prompt(self, text: bytes) -> Deferred[bool]:
        """
        Write the given text as a prompt to the console output, then read a
        result from the console input.

        The answer is compared case-insensitively with surrounding
        whitespace removed.  Any answer other than 'yes', 'no' or an empty
        line causes the user to be asked again.

        @param text: Something to present to a user to solicit a yes or no
            response.
        @type text: L{bytes}

        @return: a L{Deferred} which fires with L{True} when the user answers
            'yes' and L{False} when the user answers 'no' or gives an empty
            answer.  It may errback if there were any I/O errors.
        """

        def ask(_ignored):
            with closing(self.opener()) as console:
                console.write(text)
                reply = console.readline().strip().lower()
                while reply not in (b"yes", b"no", b""):
                    console.write(b"Please type 'yes' or 'no': ")
                    reply = console.readline().strip().lower()
                return reply == b"yes"

        # Running the console I/O from a callback means any exception it
        # raises becomes a failure on the returned Deferred.
        return defer.succeed(None).addCallback(ask)

    def warn(self, text: bytes) -> None:
        """
        Notify the user (non-interactively) of the provided text, by writing it
        to the console.

        Failures while opening or writing to the console are logged rather
        than raised.

        @param text: Some information the user is to be made aware of.
        """
        try:
            with closing(self.opener()) as console:
                console.write(text)
        except Exception:
            log.failure("Failed to write to console")