aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/Twisted/py3/twisted/protocols/haproxy/_wrapper.py
blob: 935dbfa9e20ebe879e82f4f593f92853633cfb05 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- test-case-name: twisted.protocols.haproxy.test.test_wrapper -*-

# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Protocol wrapper that provides HAProxy PROXY protocol support.
"""
from typing import Optional, Union

from twisted.internet import interfaces
from twisted.internet.endpoints import _WrapperServerEndpoint
from twisted.protocols import policies
from . import _info
from ._exceptions import InvalidProxyHeader
from ._v1parser import V1Parser
from ._v2parser import V2Parser


class HAProxyProtocolWrapper(policies.ProtocolWrapper):
    """
    A Protocol wrapper that provides HAProxy support.

    This protocol reads the PROXY stream header, v1 or v2, parses the provided
    connection data, and modifies the behavior of getPeer and getHost to return
    the data provided by the PROXY header.

    The header may arrive split across several C{dataReceived} calls. Bytes
    that are still consistent with the start of a v1 or v2 header are buffered
    until there are enough of them to choose a parser, rather than the
    connection being dropped.
    """

    def __init__(
        self, factory: policies.WrappingFactory, wrappedProtocol: interfaces.IProtocol
    ):
        super().__init__(factory, wrappedProtocol)
        # Parsed PROXY header; None until a complete header has been parsed.
        self._proxyInfo: Optional[_info.ProxyInfo] = None
        # Header parser, chosen from the first bytes of the stream.
        self._parser: Union[V2Parser, V1Parser, None] = None
        # Bytes received before a parser could be chosen. Always shorter than
        # 16 bytes: with 16 or more bytes dataReceived either picks a parser
        # or drops the connection.
        self._pendingPrefix: bytes = b""

    def dataReceived(self, data: bytes) -> None:
        """
        Consume the PROXY header, then pass every later byte to the wrapped
        protocol.

        Any bytes that follow the header in the same chunk are delivered to
        the wrapped protocol immediately. The connection is lost if the stream
        cannot start with a v1 or v2 PROXY header, or if the selected parser
        raises L{InvalidProxyHeader}.

        @param data: Bytes read from the underlying transport.
        """
        if self._proxyInfo is not None:
            return self.wrappedProtocol.dataReceived(data)

        parser = self._parser
        if parser is None:
            data = self._pendingPrefix + data
            if (
                len(data) >= 16
                and data[:12] == V2Parser.PREFIX
                and ord(data[12:13]) & 0b11110000 == 0x20
            ):
                self._parser = parser = V2Parser()
            elif len(data) >= 8 and data[:5] == V1Parser.PROXYSTR:
                self._parser = parser = V1Parser()
            elif self._couldBecomeHeader(data):
                # TCP may deliver the header in several pieces; wait for more
                # bytes instead of rejecting a valid but fragmented header.
                self._pendingPrefix = data
                return None
            else:
                self.loseConnection()
                return None
            self._pendingPrefix = b""

        try:
            proxyInfo, remaining = parser.feed(data)
        except InvalidProxyHeader:
            self.loseConnection()
            return None
        # Deliberately outside the try: an InvalidProxyHeader raised by the
        # wrapped protocol must not be mistaken for a malformed PROXY header.
        self._proxyInfo = proxyInfo
        if remaining:
            self.wrappedProtocol.dataReceived(remaining)
        return None

    @staticmethod
    def _couldBecomeHeader(data: bytes) -> bool:
        """
        Check whether C{data} is too short to choose a parser but still
        matches the beginning of a v1 or v2 PROXY header.

        @param data: All bytes received so far, none of them parsed yet.
        @return: C{True} if more bytes should be awaited before deciding.
        """
        if len(data) < 16 and V2Parser.PREFIX.startswith(data[:12]):
            # The 13th byte holds the protocol version in its high nibble.
            if len(data) < 13 or ord(data[12:13]) & 0b11110000 == 0x20:
                return True
        return len(data) < 8 and V1Parser.PROXYSTR.startswith(data[:5])

    def getPeer(self) -> interfaces.IAddress:
        """
        Return the source address from the PROXY header.

        Falls back to the underlying transport's peer when no header has been
        parsed yet or the header carried no source address.
        """
        if self._proxyInfo and self._proxyInfo.source:
            return self._proxyInfo.source
        assert self.transport
        return self.transport.getPeer()

    def getHost(self) -> interfaces.IAddress:
        """
        Return the destination address from the PROXY header.

        Falls back to the underlying transport's host when no header has been
        parsed yet or the header carried no destination address.
        """
        if self._proxyInfo and self._proxyInfo.destination:
            return self._proxyInfo.destination
        assert self.transport
        return self.transport.getHost()


class HAProxyWrappingFactory(policies.WrappingFactory):
    """
    Wrapping factory whose protocols are L{HAProxyProtocolWrapper} instances,
    adding PROXY protocol support to the connections it serves.
    """

    protocol = HAProxyProtocolWrapper

    def logPrefix(self) -> str:
        """
        Build a log prefix from the wrapped factory, suffixed with
        C{" (PROXY)"} to show that the PROXY protocol is in use.

        The wrapped factory's own C{logPrefix()} is used when it provides
        L{ILoggingContext}; otherwise its class name is used.

        @rtype: C{str}
        """
        inner = self.wrappedFactory
        providesContext = interfaces.ILoggingContext.providedBy(inner)
        base = inner.logPrefix() if providesContext else inner.__class__.__name__
        return f"{base} (PROXY)"


def proxyEndpoint(
    wrappedEndpoint: interfaces.IStreamServerEndpoint,
) -> _WrapperServerEndpoint:
    """
    Build a listening endpoint on top of C{wrappedEndpoint} that uses
    L{HAProxyWrappingFactory} as its wrapper factory. The transport's
    C{getHost} and C{getPeer} then report the addresses carried in the PROXY
    header instead of those of the underlying connection.

    @param wrappedEndpoint: The underlying listening endpoint.
    @type wrappedEndpoint: L{IStreamServerEndpoint}

    @return: a new listening endpoint that speaks the PROXY protocol.
    @rtype: L{IStreamServerEndpoint}
    """
    wrapperFactory = HAProxyWrappingFactory
    return _WrapperServerEndpoint(wrappedEndpoint, wrapperFactory)