aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/Twisted/py2/twisted/names/srvconnect.py
blob: 5346808d01b11b917fbb02538f0ef4786da1eb2e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
# -*- test-case-name: twisted.names.test.test_srvconnect -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

from __future__ import absolute_import, division

import random

from zope.interface import implementer

from twisted.internet import error, interfaces
from twisted.names import client, dns
from twisted.names.error import DNSNameError
from twisted.python.compat import nativeString



class _SRVConnector_ClientFactoryWrapper:
    """
    Factory proxy handed to the reactor by L{SRVConnector}.

    Connection-progress callbacks are redirected to the owning connector.
    C{startedConnecting} is forwarded to the wrapped factory with the SRV
    connector substituted for the reactor's connector. Failure and loss
    notifications go to the SRV connector's own C{connectionFailed} /
    C{connectionLost}. Every other attribute is looked up on the wrapped
    factory.
    """
    def __init__(self, connector, wrappedFactory):
        """
        @param connector: The SRV connector that owns this wrapper.
        @param wrappedFactory: The real client factory being proxied.
        """
        self.__connector = connector
        self.__wrappedFactory = wrappedFactory


    def startedConnecting(self, connector):
        # Hide the reactor-level connector: the wrapped factory only ever
        # sees the SRV connector.
        self.__wrappedFactory.startedConnecting(self.__connector)


    def clientConnectionFailed(self, connector, reason):
        self.__connector.connectionFailed(reason)


    def clientConnectionLost(self, connector, reason):
        self.__connector.connectionLost(reason)


    def __getattr__(self, key):
        """
        Delegate any attribute not found on the wrapper to the wrapped factory.

        @raise AttributeError: If the wrapped factory lacks C{key}, or if this
            instance was never initialised (so there is no wrapped factory).
        """
        # Read the wrapped factory straight out of the instance dictionary.
        # Going through ``self.__wrappedFactory`` would re-enter __getattr__
        # (recursing until RecursionError) on an instance whose __init__
        # never ran, e.g. one being rebuilt by copy or pickle.
        try:
            wrappedFactory = self.__dict__[
                '_SRVConnector_ClientFactoryWrapper__wrappedFactory']
        except KeyError:
            raise AttributeError(key)
        return getattr(wrappedFactory, key)



@implementer(interfaces.IConnector)
class SRVConnector:
    """
    A connector that looks up DNS SRV records.

    RFC 2782 details how SRV records should be interpreted and selected
    for subsequent connection attempts. The algorithm for using the records'
    priority and weight is implemented in L{pickServer}.

    @ivar servers: List of candidate server records for future connection
        attempts.
    @type servers: L{list} of L{dns.Record_SRV}

    @ivar orderedServers: List of server records that have already been tried
        in this round of connection attempts.
    @type orderedServers: L{list} of L{dns.Record_SRV}

    @ivar stopAfterDNS: Set by L{stopConnecting} when no underlying connector
        exists yet; the next L{_reallyConnect} clears it and returns without
        connecting.
    """

    stopAfterDNS = 0

    def __init__(self, reactor, service, domain, factory,
                 protocol='tcp', connectFuncName='connectTCP',
                 connectFuncArgs=(),
                 connectFuncKwArgs=None,
                 defaultPort=None,
                 ):
        """
        @param reactor: The object whose C{connectFuncName} method is called
            to make each connection attempt.

        @param service: The service name. It forms the C{_service} label of
            the SRV query and is used as the port when no SRV records exist.

        @param domain: The domain to connect to.  If passed as a text
            string, it will be encoded using C{idna} encoding.
        @type domain: L{bytes} or L{str}

        @param factory: The client factory that is notified of connection
            progress.

        @param protocol: The C{_protocol} label of the SRV query.

        @param connectFuncName: Name of the reactor method used to connect.

        @param connectFuncArgs: Extra positional arguments passed to the
            connect method after host, port and factory.

        @param connectFuncKwArgs: Extra keyword arguments passed to the
            connect method. L{None} (the default) means no extra keyword
            arguments.
        @type connectFuncKwArgs: L{dict} or L{None}

        @param defaultPort: Optional default port number to be used when SRV
            lookup fails and the service name is unknown. This should be the
            port number associated with the service name as defined by the IANA
            registry.
        @type defaultPort: L{int}
        """
        self.reactor = reactor
        self.service = service
        self.domain = None if domain is None else dns.domainString(domain)
        self.factory = factory

        self.protocol = protocol
        self.connectFuncName = connectFuncName
        self.connectFuncArgs = connectFuncArgs
        # Use a fresh dict per instance: a shared mutable default would leak
        # any mutation of one connector's connectFuncKwArgs into every other
        # connector created with the default.
        if connectFuncKwArgs is None:
            connectFuncKwArgs = {}
        self.connectFuncKwArgs = connectFuncKwArgs
        self._defaultPort = defaultPort

        self.connector = None
        self.servers = None
        # list of servers already used in this round:
        self.orderedServers = None


    def connect(self):
        """
        Start connection to remote server.

        If no candidate servers are known, an SRV lookup for
        C{_service._protocol.domain} is started and the connection attempt is
        made when it completes. Otherwise, if no underlying connector exists
        yet, the next candidate is tried; if one exists, it is asked to
        connect again.
        """
        self.factory.doStart()
        self.factory.startedConnecting(self)

        if not self.servers:
            if self.domain is None:
                self.connectionFailed(
                    error.DNSLookupError("Domain is not defined."),
                )
                return
            d = client.lookupService('_%s._%s.%s' % (
                nativeString(self.service),
                nativeString(self.protocol),
                nativeString(self.domain)),
            )
            d.addCallbacks(self._cbGotServers, self._ebGotServers)
            d.addCallback(lambda ignored: self._reallyConnect())
            if self._defaultPort:
                d.addErrback(self._ebServiceUnknown)
            d.addErrback(self.connectionFailed)
        elif self.connector is None:
            self._reallyConnect()
        else:
            self.connector.connect()


    def _ebGotServers(self, failure):
        """
        Treat an NXDOMAIN answer as an empty SRV result.

        Any failure other than L{DNSNameError} is passed on unchanged.
        """
        failure.trap(DNSNameError)

        # Some DNS servers reply with NXDOMAIN when in fact there are
        # just no SRV records for that domain. Act as if we just got an
        # empty response and use fallback.

        self.servers = []
        self.orderedServers = []


    def _cbGotServers(self, result):
        """
        Store the SRV records from a lookup result as this round's candidates.

        @param result: A C{(answers, authority, additional)} tuple.

        @raise error.DNSLookupError: If the only answer is an SRV record whose
            target is C{.}, meaning the service is explicitly unavailable.
        """
        answers, auth, add = result
        if (len(answers) == 1 and answers[0].type == dns.SRV
                and answers[0].payload
                and answers[0].payload.target == dns.Name(b'.')):
            # decidedly not available
            raise error.DNSLookupError("Service %s not available for domain %s."
                                       % (repr(self.service), repr(self.domain)))

        self.servers = []
        self.orderedServers = []
        for a in answers:
            if a.type != dns.SRV or not a.payload:
                continue

            self.orderedServers.append(a.payload)


    def _ebServiceUnknown(self, failure):
        """
        Connect to the default port when the service name is unknown.

        If no SRV records were found, the service name will be passed as the
        port. If resolving the name fails with
        L{error.ServiceNameUnknownError}, a final attempt is done using the
        default port.
        """
        failure.trap(error.ServiceNameUnknownError)
        self.servers = [dns.Record_SRV(0, 0, self._defaultPort, self.domain)]
        self.orderedServers = []
        # Connect directly rather than re-entering connect(). connect() would
        # call factory.doStart() and startedConnecting() a second time for
        # this one attempt, leaving doStart/doStop unbalanced. It would also
        # reuse an underlying connector left from an earlier attempt instead
        # of trying the default port.
        self._reallyConnect()


    def pickServer(self):
        """
        Pick the next server.

        This selects the next server from the list of SRV records according
        to their priority and weight values, as set out by the default
        algorithm specified in RFC 2782.

        At the beginning of a round, L{servers} is populated with
        L{orderedServers}, and the latter is made empty. L{servers}
        is the list of candidates, and L{orderedServers} is the list of servers
        that have already been tried.

        First, all records are ordered by priority and weight in ascending
        order. Then for each priority level, a running sum is calculated
        over the sorted list of records for that priority. Then a random value
        between 0 and the final sum is compared to each record in order. The
        first record that is greater than or equal to that random value is
        chosen and removed from the list of candidates for this round.

        @return: A tuple of target hostname and port from the chosen DNS SRV
            record. When there are no SRV records at all, the domain and
            C{service} are returned instead, so the second element is then
            the service name.
        @rtype: L{tuple} of native L{str} and L{int} (or the service name)
        """
        assert self.servers is not None
        assert self.orderedServers is not None

        if not self.servers and not self.orderedServers:
            # no SRV record, fall back..
            return nativeString(self.domain), self.service

        if not self.servers and self.orderedServers:
            # start new round
            self.servers = self.orderedServers
            self.orderedServers = []

        assert self.servers

        self.servers.sort(key=lambda record: (record.priority, record.weight))
        minPriority = self.servers[0].priority

        # After sorting, the lowest-priority records form a prefix of the
        # list, so index counts positions within that prefix.
        index = 0
        weightSum = 0
        weightIndex = []
        for x in self.servers:
            if x.priority == minPriority:
                weightSum += x.weight
                weightIndex.append((index, weightSum))
                index += 1

        rand = random.randint(0, weightSum)
        for index, weight in weightIndex:
            if weight >= rand:
                chosen = self.servers[index]
                del self.servers[index]
                self.orderedServers.append(chosen)

                return str(chosen.target), chosen.port

        raise RuntimeError(
            'Impossible %s pickServer result.' % (self.__class__.__name__,))


    def _reallyConnect(self):
        """
        Hand the next picked server to the reactor's connect method.

        If L{stopConnecting} was called while no underlying connector existed,
        the flag is cleared and no connection is made.
        """
        if self.stopAfterDNS:
            self.stopAfterDNS = 0
            return

        self.host, self.port = self.pickServer()
        assert self.host is not None, 'Must have a host to connect to.'
        assert self.port is not None, 'Must have a port to connect to.'

        connectFunc = getattr(self.reactor, self.connectFuncName)
        self.connector = connectFunc(
            self.host, self.port,
            _SRVConnector_ClientFactoryWrapper(self, self.factory),
            *self.connectFuncArgs, **self.connectFuncKwArgs)


    def stopConnecting(self):
        """Stop attempting to connect."""
        if self.connector:
            self.connector.stopConnecting()
        else:
            # No reactor connector yet: remember to skip the pending attempt.
            self.stopAfterDNS = 1


    def disconnect(self):
        """Disconnect, whatever our state is."""
        if self.connector is not None:
            self.connector.disconnect()
        else:
            self.stopConnecting()


    def getDestination(self):
        """
        Return the destination of the underlying reactor connector.

        Only valid once a connection attempt has been handed to the reactor.
        """
        assert self.connector
        return self.connector.getDestination()


    def connectionFailed(self, reason):
        """
        Tell the factory that the connection attempt failed, then call its
        C{doStop}.
        """
        self.factory.clientConnectionFailed(self, reason)
        self.factory.doStop()


    def connectionLost(self, reason):
        """
        Tell the factory that the connection was lost, then call its C{doStop}.
        """
        self.factory.clientConnectionLost(self, reason)
        self.factory.doStop()