aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/Twisted/py3/twisted/application/_client_service.py
blob: 32b6c4c52349bc87223d902667e311e46265b1ea (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
# -*- test-case-name: twisted.application.test.test_internet,twisted.test.test_application,twisted.test.test_cooperator -*-

"""
Implementation of L{twisted.application.internet.ClientService}, particularly
its U{automat <https://automat.readthedocs.org/>} state machine.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from random import random as _goodEnoughRandom
from typing import Callable, Optional, Protocol as TypingProtocol, TypeVar, Union

from zope.interface import implementer

from automat import TypeMachineBuilder, pep614

from twisted.application.service import Service
from twisted.internet.defer import (
    CancelledError,
    Deferred,
    fail,
    maybeDeferred,
    succeed,
)
from twisted.internet.interfaces import (
    IAddress,
    IDelayedCall,
    IProtocol,
    IProtocolFactory,
    IReactorTime,
    IStreamClientEndpoint,
    ITransport,
)
from twisted.logger import Logger
from twisted.python.failure import Failure

T = TypeVar("T")


def _maybeGlobalReactor(maybeReactor: Optional[T]) -> T:
    """
    @return: the argument, or the global reactor if the argument is L{None}.
    """
    if maybeReactor is None:
        from twisted.internet import reactor

        return reactor  # type:ignore[return-value]
    else:
        return maybeReactor


class _Client(TypingProtocol):
    """
    The inputs of the L{ClientService} connection state machine.

    L{makeMachine} passes this protocol to automat's L{TypeMachineBuilder},
    and each method below is used as a machine input (for example
    C{_Client.start}).  The object produced by the built machine implements
    this protocol, so calling one of its methods delivers that input.  The
    method bodies here are intentionally empty; only their names and
    signatures matter.
    """

    def start(self) -> None:
        """
        Start this L{ClientService}, initiating the connection retry loop.
        """

    def stop(self) -> Deferred[None]:
        """
        Stop trying to connect and disconnect any current connection.

        @return: a L{Deferred} that fires when all outstanding connections are
            closed and all in-progress connection attempts halted.
        """

    def _connectionMade(self, protocol: _ReconnectingProtocolProxy) -> None:
        """
        A connection has been made.

        @param protocol: The protocol of the connection.
        """

    def _connectionFailed(self, failure: Failure) -> None:
        """
        Deliver connection failures to any L{ClientService.whenConnected}
        L{Deferred}s that have met their failAfterFailures threshold.

        @param failure: the Failure to fire the L{Deferred}s with.
        """

    def _reconnect(self, failure: Optional[Failure] = None) -> None:
        """
        The wait between connection attempts is done.

        @param failure: optional; the retry timer scheduled in this module
            delivers this input with no arguments.
        """

    def _clientDisconnected(self, failure: Optional[Failure] = None) -> None:
        """
        The current connection has been disconnected.

        @param failure: the C{reason} relayed from the wrapped protocol's
            C{connectionLost}, if any.
        """

    def whenConnected(
        self, /, failAfterFailures: Optional[int] = None
    ) -> Deferred[IProtocol]:
        """
        Retrieve the currently-connected L{Protocol}, or the next one to
        connect.

        @param failAfterFailures: number of connection failures after which the
            Deferred will deliver a Failure (None means the Deferred will only
            fail if/when the service is stopped).  Set this to 1 to make the
            very first connection failure signal an error.  Use 2 to allow one
            failure but signal an error if the subsequent retry then fails.

        @return: a Deferred that fires with a protocol produced by the factory
            passed to C{__init__}.  It may:

                - fire with L{IProtocol}

                - fail with L{CancelledError} when the service is stopped

                - fail with e.g.
                  L{DNSLookupError<twisted.internet.error.DNSLookupError>} or
                  L{ConnectionRefusedError<twisted.internet.error.ConnectionRefusedError>}
                  when the number of consecutive failed connection attempts
                  equals the value of "failAfterFailures"
        """


@implementer(IProtocol)
class _ReconnectingProtocolProxy:
    """
    A proxy for a Protocol to provide connectionLost notification to a client
    connection service, in support of reconnecting when connections are lost.

    Attributes not defined on the proxy are looked up on the wrapped
    protocol.
    """

    def __init__(
        self, protocol: IProtocol, lostNotification: Callable[[Failure], None]
    ) -> None:
        """
        Create a L{_ReconnectingProtocolProxy}.

        @param protocol: the application-provided L{interfaces.IProtocol}
            provider.
        @type protocol: provider of L{interfaces.IProtocol} which may
            additionally provide L{interfaces.IHalfCloseableProtocol} and
            L{interfaces.IFileDescriptorReceiver}.

        @param lostNotification: a 1-argument callable to invoke with the
            C{reason} when the connection is lost.
        """
        self._protocol = protocol
        self._lostNotification = lostNotification

    def makeConnection(self, transport: ITransport) -> None:
        """
        Remember C{transport}, so the service can close it later, and forward
        the connection to the wrapped protocol.

        @param transport: the transport of the new connection.
        """
        self._transport = transport
        self._protocol.makeConnection(transport)

    def connectionLost(self, reason: Failure) -> None:
        """
        The connection was lost.  Relay this information.

        @param reason: The reason the connection was lost.

        @return: the underlying protocol's result
        """
        try:
            return self._protocol.connectionLost(reason)
        finally:
            # Notify even if the wrapped protocol's handler raised.
            self._lostNotification(reason)

    def __getattr__(self, item: str) -> object:
        """
        Delegate attributes not found on the proxy to the wrapped protocol.

        @raise AttributeError: for C{_protocol} itself.  That name is missing
            only if C{__init__} never ran (for example, an instance created
            by L{copy.copy} or by unpickling).  Without this guard, the
            C{self._protocol} lookup below would re-enter C{__getattr__} and
            recurse until L{RecursionError}.
        """
        if item == "_protocol":
            raise AttributeError(item)
        return getattr(self._protocol, item)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} wrapping {self._protocol!r}>"


@implementer(IProtocolFactory)
class _DisconnectFactory:
    """
    A L{_DisconnectFactory} is a proxy for L{IProtocolFactory} that catches
    C{connectionLost} notifications and relays them.

    Every protocol it builds is wrapped in a L{_ReconnectingProtocolProxy}
    that calls C{protocolDisconnected} when its connection is lost.  Any
    other attribute is looked up on the wrapped factory.
    """

    def __init__(
        self,
        protocolFactory: IProtocolFactory,
        protocolDisconnected: Callable[[Failure], None],
    ) -> None:
        """
        @param protocolFactory: the factory that builds the real protocols.

        @param protocolDisconnected: a 1-argument callable invoked with the
            C{reason} whenever a protocol built by this factory loses its
            connection.
        """
        self._protocolFactory = protocolFactory
        self._protocolDisconnected = protocolDisconnected

    def buildProtocol(self, addr: IAddress) -> Optional[IProtocol]:
        """
        Create a L{_ReconnectingProtocolProxy} with the disconnect-notification
        callback we were called with.

        @param addr: The address the connection is coming from.

        @return: a L{_ReconnectingProtocolProxy} for a protocol produced by
            C{self._protocolFactory}, or L{None} if that factory declined to
            build one.
        """
        built = self._protocolFactory.buildProtocol(addr)
        if built is None:
            return None
        return _ReconnectingProtocolProxy(built, self._protocolDisconnected)

    def __getattr__(self, item: str) -> object:
        """
        Delegate attributes not found on this proxy to the wrapped factory.

        @raise AttributeError: for C{_protocolFactory} itself.  That name is
            missing only if C{__init__} never ran (for example, an instance
            created by L{copy.copy} or by unpickling).  Without this guard,
            the lookup below would re-enter C{__getattr__} and recurse until
            L{RecursionError}.
        """
        if item == "_protocolFactory":
            raise AttributeError(item)
        return getattr(self._protocolFactory, item)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} wrapping {self._protocolFactory!r}>"


def _deinterface(o: object) -> None:
    """
    Remove the special runtime attributes set by L{implementer} so that a class
    can proxy through those attributes with C{__getattr__} and thereby forward
    optionally-provided interfaces by the delegated class.
    """
    for zopeSpecial in ["__providedBy__", "__provides__", "__implemented__"]:
        delattr(o, zopeSpecial)


# Remove the zope.interface attributes that @implementer put on the two proxy
# classes, so lookups of those names on instances fall through to their
# __getattr__ and reflect whatever the wrapped factory/protocol provides.
_deinterface(_DisconnectFactory)
_deinterface(_ReconnectingProtocolProxy)


@dataclass
class _Core:
    """
    Configuration plus mutable bookkeeping shared by every state of the
    L{ClientService} state machine.
    """

    # required parameters
    endpoint: IStreamClientEndpoint
    factory: IProtocolFactory
    timeoutForAttempt: Callable[[int], float]
    clock: IReactorTime
    prepareConnection: Optional[Callable[[IProtocol], object]]

    # internal state
    # Deferreds handed out by waitForStop() and fired by finishStopping().
    stopWaiters: list[Deferred[None]] = field(default_factory=list)
    # Pending whenConnected() results, each paired with the remaining
    # failAfterFailures count (None: only fail when the service stops).
    awaitingConnected: list[tuple[Deferred[IProtocol], Optional[int]]] = field(
        default_factory=list
    )
    # Number of retries scheduled since the last successful connection.
    failedAttempts: int = 0
    log: Logger = Logger()

    def waitForStop(self) -> Deferred[None]:
        """
        Create, record and return a L{Deferred} that L{finishStopping} will
        fire with L{None}.
        """
        waiter: Deferred[None] = Deferred()
        self.stopWaiters.append(waiter)
        return waiter

    def unawait(self, value: Union[IProtocol, Failure]) -> None:
        """
        Fire every pending C{whenConnected} L{Deferred} with C{value}.

        The pending list is replaced by a fresh one before any callback runs.
        Waiters registered by those callbacks are therefore kept, not fired.
        """
        pending = self.awaitingConnected
        self.awaitingConnected = []
        for waiter, _budget in pending:
            waiter.callback(value)

    def cancelConnectWaiters(self) -> None:
        """
        Fail every pending C{whenConnected} L{Deferred} with
        L{CancelledError}.
        """
        self.unawait(Failure(CancelledError()))

    def finishStopping(self) -> None:
        """
        Fire, with L{None}, every L{Deferred} returned by L{waitForStop} so
        far.
        """
        pending = self.stopWaiters
        self.stopWaiters = []
        for waiter in pending:
            waiter.callback(None)


def makeMachine() -> Callable[[_Core], _Client]:
    """
    Define the L{ClientService} connection state machine.

    @return: a callable that takes a L{_Core} (configuration plus shared
        mutable state) and returns an object implementing L{_Client}, whose
        methods are the machine's inputs.
    """
    machine = TypeMachineBuilder(_Client, _Core)

    def waitForRetry(
        c: _Client, s: _Core, failure: Optional[Failure] = None
    ) -> IDelayedCall:
        """
        Data factory for C{Waiting}: count the failure and schedule
        C{_reconnect} after the delay chosen by the retry policy.

        @return: the scheduled call, which C{stop} cancels.
        """
        s.failedAttempts += 1
        delay = s.timeoutForAttempt(s.failedAttempts)
        s.log.info(
            "Scheduling retry {attempt} to connect {endpoint} in {delay} seconds.",
            attempt=s.failedAttempts,
            endpoint=s.endpoint,
            delay=delay,
        )
        return s.clock.callLater(delay, c._reconnect)

    def rememberConnection(
        c: _Client, s: _Core, protocol: _ReconnectingProtocolProxy
    ) -> _ReconnectingProtocolProxy:
        """
        Data factory for C{Connected}: reset the failure count and fire every
        pending C{whenConnected} L{Deferred} with the unwrapped protocol.

        @return: the connected protocol proxy.
        """
        s.failedAttempts = 0
        s.unawait(protocol._protocol)
        return protocol

    def attemptConnection(
        c: _Client, s: _Core, failure: Optional[Failure] = None
    ) -> Deferred[_ReconnectingProtocolProxy]:
        """
        Data factory for C{Connecting}: start a connection attempt.

        The outcome is fed back into the machine.  Success, after the
        optional C{prepareConnection} step, becomes C{_connectionMade}.  A
        failure anywhere in the chain becomes C{_connectionFailed}.

        @return: the attempt's L{Deferred}, which C{stop} cancels.
        """
        factoryProxy = _DisconnectFactory(s.factory, c._clientDisconnected)
        connecting: Deferred[IProtocol] = s.endpoint.connect(factoryProxy)

        def prepare(
            protocol: _ReconnectingProtocolProxy,
        ) -> Deferred[_ReconnectingProtocolProxy]:
            # Run the user's prepareConnection hook (if any), then pass the
            # proxy itself on, discarding the hook's result.
            if s.prepareConnection is not None:
                return maybeDeferred(s.prepareConnection, protocol).addCallback(
                    lambda _: protocol
                )
            return succeed(protocol)

        # endpoint.connect() is actually generic on the type of the protocol,
        # but this is not expressible via zope.interface, so we have to cast
        # https://github.com/Shoobx/mypy-zope/issues/95
        connectingProxy: Deferred[_ReconnectingProtocolProxy]
        connectingProxy = connecting  # type:ignore[assignment]
        (
            connectingProxy.addCallback(prepare)
            .addCallback(c._connectionMade)
            .addErrback(c._connectionFailed)
        )
        return connectingProxy

    # States.  The optional second argument to machine.state() is the data
    # factory shown above.  Its return value is what the handlers below
    # receive as their third argument for transitions out of that state
    # (e.g. connectingStop gets attemptConnection's Deferred), except where a
    # transition is declared with nodata=True.
    Init = machine.state("Init")
    Connecting = machine.state("Connecting", attemptConnection)
    Stopped = machine.state("Stopped")
    Waiting = machine.state("Waiting", waitForRetry)
    Connected = machine.state("Connected", rememberConnection)
    Disconnecting = machine.state("Disconnecting")
    Restarting = machine.state("Restarting")

    # Behavior-less state transitions:
    Init.upon(_Client.start).to(Connecting).returns(None)
    Connecting.upon(_Client.start).loop().returns(None)
    Connecting.upon(_Client._connectionMade).to(Connected).returns(None)
    Waiting.upon(_Client.start).loop().returns(None)
    Waiting.upon(_Client._reconnect).to(Connecting).returns(None)
    Connected.upon(_Client._connectionFailed).to(Waiting).returns(None)
    Connected.upon(_Client.start).loop().returns(None)
    Connected.upon(_Client._clientDisconnected).to(Waiting).returns(None)
    # start() while still tearing down: reconnect once the old connection is
    # gone (see restartDone).
    Disconnecting.upon(_Client.start).to(Restarting).returns(None)
    Restarting.upon(_Client.start).to(Restarting).returns(None)
    Stopped.upon(_Client.start).to(Connecting).returns(None)

    # Behavior-full state transitions:
    @pep614(Init.upon(_Client.stop).to(Stopped))
    @pep614(Stopped.upon(_Client.stop).to(Stopped))
    def immediateStop(c: _Client, s: _Core) -> Deferred[None]:
        # Nothing to tear down.
        return succeed(None)

    @pep614(Connecting.upon(_Client.stop).to(Disconnecting))
    def connectingStop(
        c: _Client, s: _Core, attempt: Deferred[_ReconnectingProtocolProxy]
    ) -> Deferred[None]:
        # Cancelling the attempt is expected to surface as _connectionFailed,
        # which moves Disconnecting to Stopped and fires this waiter.
        waited = s.waitForStop()
        attempt.cancel()
        return waited

    @pep614(Connecting.upon(_Client._connectionFailed, nodata=True).to(Waiting))
    def failedWhenConnecting(c: _Client, s: _Core, failure: Failure) -> None:
        # Charge one failure to every whenConnected() waiter that has a limit.
        # Those that reach their limit are failed with this failure; waiters
        # without a limit (None) keep waiting.
        ready = []
        notReady: list[tuple[Deferred[IProtocol], Optional[int]]] = []
        for w, remaining in s.awaitingConnected:
            if remaining is None:
                notReady.append((w, remaining))
            elif remaining <= 1:
                ready.append(w)
            else:
                notReady.append((w, remaining - 1))
        s.awaitingConnected = notReady
        for w in ready:
            w.callback(failure)

    @pep614(Waiting.upon(_Client.stop).to(Stopped))
    def stop(c: _Client, s: _Core, futureRetry: IDelayedCall) -> Deferred[None]:
        # No connection exists: cancel the pending retry and finish at once.
        waited = s.waitForStop()
        s.cancelConnectWaiters()
        futureRetry.cancel()
        s.finishStopping()
        return waited

    @pep614(Connected.upon(_Client.stop).to(Disconnecting))
    def stopWhileConnected(
        c: _Client, s: _Core, protocol: _ReconnectingProtocolProxy
    ) -> Deferred[None]:
        # Closing the transport should lead to connectionLost, and from there
        # to _clientDisconnected, which completes the stop in Disconnecting.
        waited = s.waitForStop()
        protocol._transport.loseConnection()
        return waited

    @pep614(Connected.upon(_Client.whenConnected).loop())
    def whenConnectedWhenConnected(
        c: _Client,
        s: _Core,
        protocol: _ReconnectingProtocolProxy,
        failAfterFailures: Optional[int] = None,
    ) -> Deferred[IProtocol]:
        # Already connected: hand out the unwrapped application protocol.
        return succeed(protocol._protocol)

    @pep614(Disconnecting.upon(_Client.stop).loop())
    @pep614(Restarting.upon(_Client.stop).to(Disconnecting))
    def discoStop(c: _Client, s: _Core) -> Deferred[None]:
        # Teardown is already in progress; just wait for it.
        return s.waitForStop()

    @pep614(Disconnecting.upon(_Client._connectionFailed).to(Stopped))
    @pep614(Disconnecting.upon(_Client._clientDisconnected).to(Stopped))
    def disconnectingFinished(
        c: _Client, s: _Core, failure: Optional[Failure] = None
    ) -> None:
        s.cancelConnectWaiters()
        s.finishStopping()

    @pep614(Connecting.upon(_Client.whenConnected, nodata=True).loop())
    @pep614(Waiting.upon(_Client.whenConnected, nodata=True).loop())
    @pep614(Init.upon(_Client.whenConnected).to(Init))
    @pep614(Restarting.upon(_Client.whenConnected).to(Restarting))
    @pep614(Disconnecting.upon(_Client.whenConnected).to(Disconnecting))
    def awaitingConnection(
        c: _Client, s: _Core, failAfterFailures: Optional[int] = None
    ) -> Deferred[IProtocol]:
        # Not connected yet: queue a Deferred for the next connection (or
        # failure/cancellation).
        result: Deferred[IProtocol] = Deferred()
        s.awaitingConnected.append((result, failAfterFailures))
        return result

    @pep614(Restarting.upon(_Client._clientDisconnected).to(Connecting))
    def restartDone(c: _Client, s: _Core, failure: Optional[Failure] = None) -> None:
        # The old connection is gone: release stop() waiters, then the
        # transition into Connecting starts a new attempt.
        s.finishStopping()

    @pep614(Stopped.upon(_Client.whenConnected).to(Stopped))
    def notGoingToConnect(
        c: _Client, s: _Core, failAfterFailures: Optional[int] = None
    ) -> Deferred[IProtocol]:
        return fail(CancelledError())

    return machine.build()


def backoffPolicy(
    initialDelay: float = 1.0,
    maxDelay: float = 60.0,
    factor: float = 1.5,
    jitter: Callable[[], float] = _goodEnoughRandom,
) -> Callable[[int], float]:
    """
    Build an exponential-backoff retry policy for L{ClientService}.

    For an attempt count C{n}, the returned policy computes
    C{min(initialDelay * factor ** min(n, 100), maxDelay) + jitter()}.  If
    the power overflows, it returns C{maxDelay + jitter()} instead.

    @since: 16.1.0

    @param initialDelay: Base delay in seconds, i.e. the delay (before
        jitter) returned for an attempt count of 0.  Default: 1.0.
    @type initialDelay: L{float}

    @param maxDelay: Upper bound, in seconds, on the delay before jitter is
        added.  The largest possible result is therefore this value plus the
        largest value C{jitter()} can return.  Default: 60.0.
    @type maxDelay: L{float}

    @param factor: Multiplier applied once per attempt counted.
        Default: 1.5.
    @type factor: L{float}

    @param jitter: A 0-argument callable whose result is added to every
        delay.  Default: C{random.random}, a pseudorandom float in [0, 1).
    @type jitter: 0-argument callable returning L{float}

    @return: a 1-argument callable that, given an attempt count, returns the
        number of seconds to wait as a float.
    @rtype: see L{ClientService.__init__}'s C{retryPolicy} argument.
    """

    def policy(attempt: int) -> float:
        # The exponent is capped at 100, so growth stops there no matter how
        # many attempts have been made.
        exponent = min(100, attempt)
        try:
            grown = initialDelay * (factor**exponent)
        except OverflowError:
            base = maxDelay
        else:
            base = min(grown, maxDelay)
        return base + jitter()

    return policy


# Retry policy used when ClientService is given none: backoffPolicy() with
# its default parameters.
_defaultPolicy = backoffPolicy()
# Machine factory built once at import time; every ClientService calls it
# with its own _Core.
ClientMachine = makeMachine()


class ClientService(Service):
    """
    A service that keeps a single outgoing connection to a client endpoint.
    When a connection attempt fails, or an established connection is lost,
    it reconnects after a delay chosen by a retry policy.

    @since: 16.1.0
    """

    _log = Logger()

    def __init__(
        self,
        endpoint: IStreamClientEndpoint,
        factory: IProtocolFactory,
        retryPolicy: Optional[Callable[[int], float]] = None,
        clock: Optional[IReactorTime] = None,
        prepareConnection: Optional[Callable[[IProtocol], object]] = None,
    ):
        """
        @param endpoint: The L{stream client endpoint
            <interfaces.IStreamClientEndpoint>} provider to connect to once
            the service starts.

        @param factory: The L{protocol factory <interfaces.IProtocolFactory>}
            used to build a client protocol for each connection.

        @param retryPolicy: Decides how long L{ClientService} waits between
            attempts to connect to C{endpoint}.  It is called with the number
            of failed connection attempts made in a row (an L{int}) and
            returns the number of seconds to wait before the next attempt.
            Defaults to L{backoffPolicy} with its default arguments.

        @param clock: The clock used to schedule reconnection attempts; this
            is mostly useful for substituting a fake clock in tests.  Defaults
            to the global reactor.

        @param prepareConnection: A single-argument L{callable}, called once
            with the L{protocol <interfaces.IProtocol>} of every new
            connection.  It may return a L{Deferred}.  Use it to prepare the
            protocol (e.g. authenticate) or to validate it (check its health).

            The connection is rejected if C{prepareConnection} raises, or if
            the L{Deferred} it returns fails.  A rejected connection is never
            given to L{Deferred}s returned by L{whenConnected}.  Instead,
            L{ClientService} treats it as a failed connection attempt, which
            increments the counter passed to C{retryPolicy}.

            L{Deferred}s returned by L{whenConnected} do not fire until any
            L{Deferred} returned by C{prepareConnection} has fired.  A
            successful return value is otherwise consumed and ignored.

            Present Since Twisted 18.7.0
        """
        if retryPolicy is None:
            retryPolicy = _defaultPolicy
        core = _Core(
            endpoint,
            factory,
            retryPolicy,
            _maybeGlobalReactor(clock),
            prepareConnection=prepareConnection,
            log=self._log,
        )
        self._machine: _Client = ClientMachine(core)

    def whenConnected(
        self, failAfterFailures: Optional[int] = None
    ) -> Deferred[IProtocol]:
        """
        Get the protocol of the current connection, or of the next connection
        to be made.

        @param failAfterFailures: How many connection failures to tolerate
            before the Deferred delivers a Failure.  L{None} means it fails
            only if/when the service is stopped.  Use 1 to signal an error on
            the very first connection failure.  Use 2 to allow one failure
            but signal an error if the following retry also fails.
        @type failAfterFailures: L{int} or None

        @return: a Deferred that fires with a protocol built by the factory
            passed to C{__init__}
        @rtype: L{Deferred} that may:

            - fire with L{IProtocol}

            - fail with L{CancelledError} when the service is stopped

            - fail with e.g.
              L{DNSLookupError<twisted.internet.error.DNSLookupError>} or
              L{ConnectionRefusedError<twisted.internet.error.ConnectionRefusedError>}
              once the number of consecutive failed connection attempts
              equals the value of "failAfterFailures"
        """
        connected = self._machine.whenConnected(failAfterFailures)
        return connected

    def startService(self) -> None:
        """
        Start this L{ClientService} and begin the connect-and-retry loop.

        If the service is already running, only log a warning.
        """
        if self.running:
            self._log.warn("Duplicate ClientService.startService {log_source}")
        else:
            super().startService()
            self._machine.start()

    def stopService(self) -> Deferred[None]:
        """
        Stop reconnecting and close any existing connection.

        @return: a L{Deferred} that fires once every outstanding connection
            is closed and every in-progress connection attempt has stopped.
        """
        super().stopService()
        stopped = self._machine.stop()
        return stopped