diff options
author | shmel1k <shmel1k@ydb.tech> | 2023-11-26 18:16:14 +0300 |
---|---|---|
committer | shmel1k <shmel1k@ydb.tech> | 2023-11-26 18:43:30 +0300 |
commit | b8cf9e88f4c5c64d9406af533d8948deb050d695 (patch) | |
tree | 218eb61fb3c3b96ec08b4d8cdfef383104a87d63 /contrib/python/Twisted/py2/twisted/protocols | |
parent | 523f645a83a0ec97a0332dbc3863bb354c92a328 (diff) | |
download | ydb-b8cf9e88f4c5c64d9406af533d8948deb050d695.tar.gz |
add kikimr_configure
Diffstat (limited to 'contrib/python/Twisted/py2/twisted/protocols')
28 files changed, 13971 insertions, 0 deletions
diff --git a/contrib/python/Twisted/py2/twisted/protocols/__init__.py b/contrib/python/Twisted/py2/twisted/protocols/__init__.py new file mode 100644 index 0000000000..b04f3ec798 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Twisted Protocols: A collection of internet protocol implementations. +""" + +from incremental import Version +from twisted.python.deprecate import deprecatedModuleAttribute + + +deprecatedModuleAttribute( + Version('Twisted', 17, 9, 0), + "There is no replacement for this module.", + "twisted.protocols", "dict") diff --git a/contrib/python/Twisted/py2/twisted/protocols/amp.py b/contrib/python/Twisted/py2/twisted/protocols/amp.py new file mode 100644 index 0000000000..322d633b68 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/amp.py @@ -0,0 +1,2897 @@ +# -*- test-case-name: twisted.test.test_amp -*- +# Copyright (c) 2005 Divmod, Inc. +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +This module implements AMP, the Asynchronous Messaging Protocol. + +AMP is a protocol for sending multiple asynchronous request/response pairs over +the same connection. Requests and responses are both collections of key/value +pairs. + +AMP is a very simple protocol which is not an application. This module is a +"protocol construction kit" of sorts; it attempts to be the simplest wire-level +implementation of Deferreds. AMP provides the following base-level features: + + - Asynchronous request/response handling (hence the name) + + - Requests and responses are both key/value pairs + + - Binary transfer of all data: all data is length-prefixed. Your + application will never need to worry about quoting. + + - Command dispatching (like HTTP Verbs): the protocol is extensible, and + multiple AMP sub-protocols can be grouped together easily. + +The protocol implementation also provides a few additional features which are +not part of the core wire protocol, but are nevertheless very useful: + + - Tight TLS integration, with an included StartTLS command. + + - Handshaking to other protocols: because AMP has well-defined message + boundaries and maintains all incoming and outgoing requests for you, you + can start a connection over AMP and then switch to another protocol. + This makes it ideal for firewall-traversal applications where you may + have only one forwarded port but multiple applications that want to use + it. + +Using AMP with Twisted is simple. Each message is a command, with a response. +You begin by defining a command type. Commands specify their input and output +in terms of the types that they expect to see in the request and response +key-value pairs. Here's an example of a command that adds two integers, 'a' +and 'b':: + + class Sum(amp.Command): + arguments = [('a', amp.Integer()), + ('b', amp.Integer())] + response = [('total', amp.Integer())] + +Once you have specified a command, you need to make it part of a protocol, and +define a responder for it. Here's a 'JustSum' protocol that includes a +responder for our 'Sum' command:: + + class JustSum(amp.AMP): + def sum(self, a, b): + total = a + b + print 'Did a sum: %d + %d = %d' % (a, b, total) + return {'total': total} + Sum.responder(sum) + +Later, when you want to actually do a sum, the following expression will return +a L{Deferred} which will fire with the result:: + + ClientCreator(reactor, amp.AMP).connectTCP(...).addCallback( + lambda p: p.callRemote(Sum, a=13, b=81)).addCallback( + lambda result: result['total']) + +Command responders may also return Deferreds, causing the response to be +sent only once the Deferred fires:: + + class DelayedSum(amp.AMP): + def slowSum(self, a, b): + total = a + b + result = defer.Deferred() + reactor.callLater(3, result.callback, {'total': total}) + return result + Sum.responder(slowSum) + +This is transparent to the caller. + +You can also define the propagation of specific errors in AMP. For example, +for the slightly more complicated case of division, we might have to deal with +division by zero:: + + class Divide(amp.Command): + arguments = [('numerator', amp.Integer()), + ('denominator', amp.Integer())] + response = [('result', amp.Float())] + errors = {ZeroDivisionError: 'ZERO_DIVISION'} + +The 'errors' mapping here tells AMP that if a responder to Divide emits a +L{ZeroDivisionError}, then the other side should be informed that an error of +the type 'ZERO_DIVISION' has occurred. Writing a responder which takes +advantage of this is very simple - just raise your exception normally:: + + class JustDivide(amp.AMP): + def divide(self, numerator, denominator): + result = numerator / denominator + print 'Divided: %d / %d = %d' % (numerator, denominator, total) + return {'result': result} + Divide.responder(divide) + +On the client side, the errors mapping will be used to determine what the +'ZERO_DIVISION' error means, and translated into an asynchronous exception, +which can be handled normally as any L{Deferred} would be:: + + def trapZero(result): + result.trap(ZeroDivisionError) + print "Divided by zero: returning INF" + return 1e1000 + ClientCreator(reactor, amp.AMP).connectTCP(...).addCallback( + lambda p: p.callRemote(Divide, numerator=1234, + denominator=0) + ).addErrback(trapZero) + +For a complete, runnable example of both of these commands, see the files in +the Twisted repository:: + + doc/core/examples/ampserver.py + doc/core/examples/ampclient.py + +On the wire, AMP is a protocol which uses 2-byte lengths to prefix keys and +values, and empty keys to separate messages:: + + <2-byte length><key><2-byte length><value> + <2-byte length><key><2-byte length><value> + ... + <2-byte length><key><2-byte length><value> + <NUL><NUL> # Empty Key == End of Message + +And so on. Because it's tedious to refer to lengths and NULs constantly, the +documentation will refer to packets as if they were newline delimited, like +so:: + + C: _command: sum + C: _ask: ef639e5c892ccb54 + C: a: 13 + C: b: 81 + + S: _answer: ef639e5c892ccb54 + S: total: 94 + +Notes: + +In general, the order of keys is arbitrary. Specific uses of AMP may impose an +ordering requirement, but unless this is specified explicitly, any ordering may +be generated and any ordering must be accepted. This applies to the +command-related keys I{_command} and I{_ask} as well as any other keys. + +Values are limited to the maximum encodable size in a 16-bit length, 65535 +bytes. + +Keys are limited to the maximum encodable size in a 8-bit length, 255 bytes. +Note that we still use 2-byte lengths to encode keys. This small redundancy +has several features: + + - If an implementation becomes confused and starts emitting corrupt data, + or gets keys confused with values, many common errors will be signalled + immediately instead of delivering obviously corrupt packets. + + - A single NUL will separate every key, and a double NUL separates + messages. This provides some redundancy when debugging traffic dumps. + + - NULs will be present at regular intervals along the protocol, providing + some padding for otherwise braindead C implementations of the protocol, + so that <stdio.h> string functions will see the NUL and stop. + + - This makes it possible to run an AMP server on a port also used by a + plain-text protocol, and easily distinguish between non-AMP clients (like + web browsers) which issue non-NUL as the first byte, and AMP clients, + which always issue NUL as the first byte. + +@var MAX_VALUE_LENGTH: The maximum length of a message. +@type MAX_VALUE_LENGTH: L{int} + +@var ASK: Marker for an Ask packet. +@type ASK: L{bytes} + +@var ANSWER: Marker for an Answer packet. +@type ANSWER: L{bytes} + +@var COMMAND: Marker for a Command packet. +@type COMMAND: L{bytes} + +@var ERROR: Marker for an AMP box of error type. +@type ERROR: L{bytes} + +@var ERROR_CODE: Marker for an AMP box containing the code of an error. +@type ERROR_CODE: L{bytes} + +@var ERROR_DESCRIPTION: Marker for an AMP box containing the description of the + error. +@type ERROR_DESCRIPTION: L{bytes} +""" + +from __future__ import absolute_import, division + +__metaclass__ = type + +import types, warnings + +from io import BytesIO +from struct import pack +import decimal, datetime +from functools import partial +from itertools import count + +from zope.interface import Interface, implementer + +from twisted.python.reflect import accumulateClassDict +from twisted.python.failure import Failure +from twisted.python._tzhelper import ( + FixedOffsetTimeZone as _FixedOffsetTZInfo, UTC as utc +) + +from twisted.python import log, filepath + +from twisted.internet.interfaces import IFileDescriptorReceiver +from twisted.internet.main import CONNECTION_LOST +from twisted.internet.error import PeerVerifyError, ConnectionLost +from twisted.internet.error import ConnectionClosed +from twisted.internet.defer import Deferred, maybeDeferred, fail +from twisted.protocols.basic import Int16StringReceiver, StatefulStringProtocol +from twisted.python.compat import ( + iteritems, unicode, nativeString, intToBytes, _PY3, long, +) + +try: + from twisted.internet import ssl +except ImportError: + ssl = None + +if ssl and not ssl.supported: + ssl = None + +if ssl is not None: + from twisted.internet.ssl import (CertificateOptions, Certificate, DN, + KeyPair) + + + +__all__ = [ + 'AMP', + 'ANSWER', + 'ASK', + 'AmpBox', + 'AmpError', + 'AmpList', + 'Argument', + 'BadLocalReturn', + 'BinaryBoxProtocol', + 'Boolean', + 'Box', + 'BoxDispatcher', + 'COMMAND', + 'Command', + 'CommandLocator', + 'Decimal', + 'Descriptor', + 'ERROR', + 'ERROR_CODE', + 'ERROR_DESCRIPTION', + 'Float', + 'IArgumentType', + 'IBoxReceiver', + 'IBoxSender', + 'IResponderLocator', + 'IncompatibleVersions', + 'Integer', + 'InvalidSignature', + 'ListOf', + 'MAX_KEY_LENGTH', + 'MAX_VALUE_LENGTH', + 'MalformedAmpBox', + 'NoEmptyBoxes', + 'OnlyOneTLS', + 'PROTOCOL_ERRORS', + 'PYTHON_KEYWORDS', + 'Path', + 'ProtocolSwitchCommand', + 'ProtocolSwitched', + 'QuitBox', + 'RemoteAmpError', + 'SimpleStringLocator', + 'StartTLS', + 'String', + 'TooLong', + 'UNHANDLED_ERROR_CODE', + 'UNKNOWN_ERROR_CODE', + 'UnhandledCommand', + 'utc', + 'Unicode', + 'UnknownRemoteError', + 'parse', + 'parseString', +] + + + +ASK = b'_ask' +ANSWER = b'_answer' +COMMAND = b'_command' +ERROR = b'_error' +ERROR_CODE = b'_error_code' +ERROR_DESCRIPTION = b'_error_description' +UNKNOWN_ERROR_CODE = b'UNKNOWN' +UNHANDLED_ERROR_CODE = b'UNHANDLED' + +MAX_KEY_LENGTH = 0xff +MAX_VALUE_LENGTH = 0xffff + + + +class IArgumentType(Interface): + """ + An L{IArgumentType} can serialize a Python object into an AMP box and + deserialize information from an AMP box back into a Python object. + + @since: 9.0 + """ + def fromBox(name, strings, objects, proto): + """ + Given an argument name and an AMP box containing serialized values, + extract one or more Python objects and add them to the C{objects} + dictionary. + + @param name: The name associated with this argument. Most commonly + this is the key which can be used to find a serialized value in + C{strings}. + @type name: C{bytes} + + @param strings: The AMP box from which to extract one or more + values. + @type strings: C{dict} + + @param objects: The output dictionary to populate with the value for + this argument. The key used will be derived from C{name}. It may + differ; in Python 3, for example, the key will be a Unicode/native + string. See L{_wireNameToPythonIdentifier}. + @type objects: C{dict} + + @param proto: The protocol instance which received the AMP box being + interpreted. Most likely this is an instance of L{AMP}, but + this is not guaranteed. + + @return: L{None} + """ + + + def toBox(name, strings, objects, proto): + """ + Given an argument name and a dictionary containing structured Python + objects, serialize values into one or more strings and add them to + the C{strings} dictionary. + + @param name: The name associated with this argument. Most commonly + this is the key in C{strings} to associate with a C{bytes} giving + the serialized form of that object. + @type name: C{bytes} + + @param strings: The AMP box into which to insert one or more strings. + @type strings: C{dict} + + @param objects: The input dictionary from which to extract Python + objects to serialize. The key used will be derived from C{name}. + It may differ; in Python 3, for example, the key will be a + Unicode/native string. See L{_wireNameToPythonIdentifier}. + @type objects: C{dict} + + @param proto: The protocol instance which will send the AMP box once + it is fully populated. Most likely this is an instance of + L{AMP}, but this is not guaranteed. + + @return: L{None} + """ + + + +class IBoxSender(Interface): + """ + A transport which can send L{AmpBox} objects. + """ + + def sendBox(box): + """ + Send an L{AmpBox}. + + @raise ProtocolSwitched: if the underlying protocol has been + switched. + + @raise ConnectionLost: if the underlying connection has already been + lost. + """ + + def unhandledError(failure): + """ + An unhandled error occurred in response to a box. Log it + appropriately. + + @param failure: a L{Failure} describing the error that occurred. + """ + + + +class IBoxReceiver(Interface): + """ + An application object which can receive L{AmpBox} objects and dispatch them + appropriately. + """ + + def startReceivingBoxes(boxSender): + """ + The L{IBoxReceiver.ampBoxReceived} method will start being called; + boxes may be responded to by responding to the given L{IBoxSender}. + + @param boxSender: an L{IBoxSender} provider. + """ + + + def ampBoxReceived(box): + """ + A box was received from the transport; dispatch it appropriately. + """ + + + def stopReceivingBoxes(reason): + """ + No further boxes will be received on this connection. + + @type reason: L{Failure} + """ + + + +class IResponderLocator(Interface): + """ + An application object which can look up appropriate responder methods for + AMP commands. + """ + + def locateResponder(name): + """ + Locate a responder method appropriate for the named command. + + @param name: the wire-level name (commandName) of the AMP command to be + responded to. + @type name: C{bytes} + + @return: a 1-argument callable that takes an L{AmpBox} with argument + values for the given command, and returns an L{AmpBox} containing + argument values for the named command, or a L{Deferred} that fires the + same. + """ + + + +class AmpError(Exception): + """ + Base class of all Amp-related exceptions. + """ + + + +class ProtocolSwitched(Exception): + """ + Connections which have been switched to other protocols can no longer + accept traffic at the AMP level. This is raised when you try to send it. + """ + + + +class OnlyOneTLS(AmpError): + """ + This is an implementation limitation; TLS may only be started once per + connection. + """ + + + +class NoEmptyBoxes(AmpError): + """ + You can't have empty boxes on the connection. This is raised when you + receive or attempt to send one. + """ + + + +class InvalidSignature(AmpError): + """ + You didn't pass all the required arguments. + """ + + + +class TooLong(AmpError): + """ + One of the protocol's length limitations was violated. + + @ivar isKey: true if the string being encoded in a key position, false if + it was in a value position. + + @ivar isLocal: Was the string encoded locally, or received too long from + the network? (It's only physically possible to encode "too long" values on + the network for keys.) + + @ivar value: The string that was too long. + + @ivar keyName: If the string being encoded was in a value position, what + key was it being encoded for? + """ + + def __init__(self, isKey, isLocal, value, keyName=None): + AmpError.__init__(self) + self.isKey = isKey + self.isLocal = isLocal + self.value = value + self.keyName = keyName + + + def __repr__(self): + hdr = self.isKey and "key" or "value" + if not self.isKey: + hdr += ' ' + repr(self.keyName) + lcl = self.isLocal and "local" or "remote" + return "%s %s too long: %d" % (lcl, hdr, len(self.value)) + + + +class BadLocalReturn(AmpError): + """ + A bad value was returned from a local command; we were unable to coerce it. + """ + def __init__(self, message, enclosed): + AmpError.__init__(self) + self.message = message + self.enclosed = enclosed + + + def __repr__(self): + return self.message + " " + self.enclosed.getBriefTraceback() + + __str__ = __repr__ + + + +class RemoteAmpError(AmpError): + """ + This error indicates that something went wrong on the remote end of the + connection, and the error was serialized and transmitted to you. + """ + def __init__(self, errorCode, description, fatal=False, local=None): + """Create a remote error with an error code and description. + + @param errorCode: the AMP error code of this error. + @type errorCode: C{bytes} + + @param description: some text to show to the user. + @type description: C{str} + + @param fatal: a boolean, true if this error should terminate the + connection. + + @param local: a local Failure, if one exists. + """ + if local: + localwhat = ' (local)' + othertb = local.getBriefTraceback() + else: + localwhat = '' + othertb = '' + + # Backslash-escape errorCode. Python 3.5 can do this natively + # ("backslashescape") but Python 2.7 and Python 3.4 can't. + if _PY3: + errorCodeForMessage = "".join( + "\\x%2x" % (c,) if c >= 0x80 else chr(c) + for c in errorCode) + else: + errorCodeForMessage = "".join( + "\\x%2x" % (ord(c),) if ord(c) >= 0x80 else c + for c in errorCode) + + if othertb: + message = "Code<%s>%s: %s\n%s" % ( + errorCodeForMessage, localwhat, description, othertb) + else: + message = "Code<%s>%s: %s" % ( + errorCodeForMessage, localwhat, description) + + super(RemoteAmpError, self).__init__(message) + self.local = local + self.errorCode = errorCode + self.description = description + self.fatal = fatal + + + +class UnknownRemoteError(RemoteAmpError): + """ + This means that an error whose type we can't identify was raised from the + other side. + """ + def __init__(self, description): + errorCode = UNKNOWN_ERROR_CODE + RemoteAmpError.__init__(self, errorCode, description) + + + +class MalformedAmpBox(AmpError): + """ + This error indicates that the wire-level protocol was malformed. + """ + + + +class UnhandledCommand(AmpError): + """ + A command received via amp could not be dispatched. + """ + + + +class IncompatibleVersions(AmpError): + """ + It was impossible to negotiate a compatible version of the protocol with + the other end of the connection. + """ + + +PROTOCOL_ERRORS = {UNHANDLED_ERROR_CODE: UnhandledCommand} + +class AmpBox(dict): + """ + I am a packet in the AMP protocol, much like a regular bytes:bytes dictionary. + """ + __slots__ = [] # be like a regular dictionary, don't magically + # acquire a __dict__... + + + def __init__(self, *args, **kw): + """ + Initialize a new L{AmpBox}. + + In Python 3, keyword arguments MUST be Unicode/native strings whereas + in Python 2 they could be either byte strings or Unicode strings. + + However, all keys of an L{AmpBox} MUST be byte strings, or possible to + transparently coerce into byte strings (i.e. Python 2). + + In Python 3, therefore, native string keys are coerced to byte strings + by encoding as ASCII. This can result in C{UnicodeEncodeError} being + raised. + + @param args: See C{dict}, but all keys and values should be C{bytes}. + On Python 3, native strings may be used as keys provided they + contain only ASCII characters. + + @param kw: See C{dict}, but all keys and values should be C{bytes}. + On Python 3, native strings may be used as keys provided they + contain only ASCII characters. + + @raise UnicodeEncodeError: When a native string key cannot be coerced + to an ASCII byte string (Python 3 only). + """ + super(AmpBox, self).__init__(*args, **kw) + if _PY3: + nonByteNames = [n for n in self if not isinstance(n, bytes)] + for nonByteName in nonByteNames: + byteName = nonByteName.encode("ascii") + self[byteName] = self.pop(nonByteName) + + + def copy(self): + """ + Return another AmpBox just like me. + """ + newBox = self.__class__() + newBox.update(self) + return newBox + + + def serialize(self): + """ + Convert me into a wire-encoded string. + + @return: a C{bytes} encoded according to the rules described in the + module docstring. + """ + i = sorted(iteritems(self)) + L = [] + w = L.append + for k, v in i: + if type(k) == unicode: + raise TypeError("Unicode key not allowed: %r" % k) + if type(v) == unicode: + raise TypeError( + "Unicode value for key %r not allowed: %r" % (k, v)) + if len(k) > MAX_KEY_LENGTH: + raise TooLong(True, True, k, None) + if len(v) > MAX_VALUE_LENGTH: + raise TooLong(False, True, v, k) + for kv in k, v: + w(pack("!H", len(kv))) + w(kv) + w(pack("!H", 0)) + return b''.join(L) + + + def _sendTo(self, proto): + """ + Serialize and send this box to an Amp instance. By the time it is being + sent, several keys are required. I must have exactly ONE of:: + + _ask + _answer + _error + + If the '_ask' key is set, then the '_command' key must also be + set. + + @param proto: an AMP instance. + """ + proto.sendBox(self) + + def __repr__(self): + return 'AmpBox(%s)' % (dict.__repr__(self),) + +# amp.Box => AmpBox + +Box = AmpBox + +class QuitBox(AmpBox): + """ + I am an AmpBox that, upon being sent, terminates the connection. + """ + __slots__ = [] + + + def __repr__(self): + return 'QuitBox(**%s)' % (super(QuitBox, self).__repr__(),) + + + def _sendTo(self, proto): + """ + Immediately call loseConnection after sending. + """ + super(QuitBox, self)._sendTo(proto) + proto.transport.loseConnection() + + + +class _SwitchBox(AmpBox): + """ + Implementation detail of ProtocolSwitchCommand: I am an AmpBox which sets + up state for the protocol to switch. + """ + + # DON'T set __slots__ here; we do have an attribute. + + def __init__(self, innerProto, **kw): + """ + Create a _SwitchBox with the protocol to switch to after being sent. + + @param innerProto: the protocol instance to switch to. + @type innerProto: an IProtocol provider. + """ + super(_SwitchBox, self).__init__(**kw) + self.innerProto = innerProto + + + def __repr__(self): + return '_SwitchBox(%r, **%s)' % (self.innerProto, + dict.__repr__(self),) + + + def _sendTo(self, proto): + """ + Send me; I am the last box on the connection. All further traffic will be + over the new protocol. + """ + super(_SwitchBox, self)._sendTo(proto) + proto._lockForSwitch() + proto._switchTo(self.innerProto) + + + +@implementer(IBoxReceiver) +class BoxDispatcher: + """ + A L{BoxDispatcher} dispatches '_ask', '_answer', and '_error' L{AmpBox}es, + both incoming and outgoing, to their appropriate destinations. + + Outgoing commands are converted into L{Deferred}s and outgoing boxes, and + associated tracking state to fire those L{Deferred} when '_answer' boxes + come back. Incoming '_answer' and '_error' boxes are converted into + callbacks and errbacks on those L{Deferred}s, respectively. + + Incoming '_ask' boxes are converted into method calls on a supplied method + locator. + + @ivar _outstandingRequests: a dictionary mapping request IDs to + L{Deferred}s which were returned for those requests. + + @ivar locator: an object with a L{CommandLocator.locateResponder} method + that locates a responder function that takes a Box and returns a result + (either a Box or a Deferred which fires one). + + @ivar boxSender: an object which can send boxes, via the L{_sendBoxCommand} + method, such as an L{AMP} instance. + @type boxSender: L{IBoxSender} + """ + + _failAllReason = None + _outstandingRequests = None + _counter = long(0) + boxSender = None + + def __init__(self, locator): + self._outstandingRequests = {} + self.locator = locator + + + def startReceivingBoxes(self, boxSender): + """ + The given boxSender is going to start calling boxReceived on this + L{BoxDispatcher}. + + @param boxSender: The L{IBoxSender} to send command responses to. + """ + self.boxSender = boxSender + + + def stopReceivingBoxes(self, reason): + """ + No further boxes will be received here. Terminate all currently + outstanding command deferreds with the given reason. + """ + self.failAllOutgoing(reason) + + + def failAllOutgoing(self, reason): + """ + Call the errback on all outstanding requests awaiting responses. + + @param reason: the Failure instance to pass to those errbacks. + """ + self._failAllReason = reason + OR = self._outstandingRequests.items() + self._outstandingRequests = None # we can never send another request + for key, value in OR: + value.errback(reason) + + + def _nextTag(self): + """ + Generate protocol-local serial numbers for _ask keys. + + @return: a string that has not yet been used on this connection. + """ + self._counter += 1 + return (b'%x' % (self._counter,)) + + + def _sendBoxCommand(self, command, box, requiresAnswer=True): + """ + Send a command across the wire with the given C{amp.Box}. + + Mutate the given box to give it any additional keys (_command, _ask) + required for the command and request/response machinery, then send it. + + If requiresAnswer is True, returns a C{Deferred} which fires when a + response is received. The C{Deferred} is fired with an C{amp.Box} on + success, or with an C{amp.RemoteAmpError} if an error is received. + + If the Deferred fails and the error is not handled by the caller of + this method, the failure will be logged and the connection dropped. + + @param command: a C{bytes}, the name of the command to issue. + + @param box: an AmpBox with the arguments for the command. + + @param requiresAnswer: a boolean. Defaults to True. If True, return a + Deferred which will fire when the other side responds to this command. + If False, return None and do not ask the other side for acknowledgement. + + @return: a Deferred which fires the AmpBox that holds the response to + this command, or None, as specified by requiresAnswer. + + @raise ProtocolSwitched: if the protocol has been switched. + """ + if self._failAllReason is not None: + if requiresAnswer: + return fail(self._failAllReason) + else: + return None + box[COMMAND] = command + tag = self._nextTag() + if requiresAnswer: + box[ASK] = tag + box._sendTo(self.boxSender) + if requiresAnswer: + result = self._outstandingRequests[tag] = Deferred() + else: + result = None + return result + + + def callRemoteString(self, command, requiresAnswer=True, **kw): + """ + This is a low-level API, designed only for optimizing simple messages + for which the overhead of parsing is too great. + + @param command: a C{bytes} naming the command. + + @param kw: arguments to the amp box. + + @param requiresAnswer: a boolean. Defaults to True. If True, return a + Deferred which will fire when the other side responds to this command. + If False, return None and do not ask the other side for acknowledgement. + + @return: a Deferred which fires the AmpBox that holds the response to + this command, or None, as specified by requiresAnswer. + """ + box = Box(kw) + return self._sendBoxCommand(command, box, requiresAnswer) + + + def callRemote(self, commandType, *a, **kw): + """ + This is the primary high-level API for sending messages via AMP. Invoke it + with a command and appropriate arguments to send a message to this + connection's peer. + + @param commandType: a subclass of Command. + @type commandType: L{type} + + @param a: Positional (special) parameters taken by the command. + Positional parameters will typically not be sent over the wire. The + only command included with AMP which uses positional parameters is + L{ProtocolSwitchCommand}, which takes the protocol that will be + switched to as its first argument. + + @param kw: Keyword arguments taken by the command. These are the + arguments declared in the command's 'arguments' attribute. They will + be encoded and sent to the peer as arguments for the L{commandType}. + + @return: If L{commandType} has a C{requiresAnswer} attribute set to + L{False}, then return L{None}. Otherwise, return a L{Deferred} which + fires with a dictionary of objects representing the result of this + call. Additionally, this L{Deferred} may fail with an exception + representing a connection failure, with L{UnknownRemoteError} if the + other end of the connection fails for an unknown reason, or with any + error specified as a key in L{commandType}'s C{errors} dictionary. + """ + + # XXX this takes command subclasses and not command objects on purpose. + # There's really no reason to have all this back-and-forth between + # command objects and the protocol, and the extra object being created + # (the Command instance) is pointless. Command is kind of like + # Interface, and should be more like it. + + # In other words, the fact that commandType is instantiated here is an + # implementation detail. Don't rely on it. + + try: + co = commandType(*a, **kw) + except: + return fail() + return co._doCommand(self) + + + def unhandledError(self, failure): + """ + This is a terminal callback called after application code has had a + chance to quash any errors. + """ + return self.boxSender.unhandledError(failure) + + + def _answerReceived(self, box): + """ + An AMP box was received that answered a command previously sent with + L{callRemote}. + + @param box: an AmpBox with a value for its L{ANSWER} key. + """ + question = self._outstandingRequests.pop(box[ANSWER]) + question.addErrback(self.unhandledError) + question.callback(box) + + + def _errorReceived(self, box): + """ + An AMP box was received that answered a command previously sent with + L{callRemote}, with an error. + + @param box: an L{AmpBox} with a value for its L{ERROR}, L{ERROR_CODE}, + and L{ERROR_DESCRIPTION} keys. + """ + question = self._outstandingRequests.pop(box[ERROR]) + question.addErrback(self.unhandledError) + errorCode = box[ERROR_CODE] + description = box[ERROR_DESCRIPTION] + if isinstance(description, bytes): + description = description.decode("utf-8", "replace") + if errorCode in PROTOCOL_ERRORS: + exc = PROTOCOL_ERRORS[errorCode](errorCode, description) + else: + exc = RemoteAmpError(errorCode, description) + question.errback(Failure(exc)) + + + def _commandReceived(self, box): + """ + @param box: an L{AmpBox} with a value for its L{COMMAND} and L{ASK} + keys. + """ + def formatAnswer(answerBox): + answerBox[ANSWER] = box[ASK] + return answerBox + def formatError(error): + if error.check(RemoteAmpError): + code = error.value.errorCode + desc = error.value.description + if isinstance(desc, unicode): + desc = desc.encode("utf-8", "replace") + if error.value.fatal: + errorBox = QuitBox() + else: + errorBox = AmpBox() + else: + errorBox = QuitBox() + log.err(error) # here is where server-side logging happens + # if the error isn't handled + code = UNKNOWN_ERROR_CODE + desc = b"Unknown Error" + errorBox[ERROR] = box[ASK] + errorBox[ERROR_DESCRIPTION] = desc + errorBox[ERROR_CODE] = code + return errorBox + deferred = self.dispatchCommand(box) + if ASK in box: + deferred.addCallbacks(formatAnswer, formatError) + deferred.addCallback(self._safeEmit) + deferred.addErrback(self.unhandledError) + + + def ampBoxReceived(self, box): + """ + An AmpBox was received, representing a command, or an answer to a + previously issued command (either successful or erroneous). Respond to + it according to its contents. + + @param box: an AmpBox + + @raise NoEmptyBoxes: when a box is received that does not contain an + '_answer', '_command' / '_ask', or '_error' key; i.e. one which does not + fit into the command / response protocol defined by AMP. + """ + if ANSWER in box: + self._answerReceived(box) + elif ERROR in box: + self._errorReceived(box) + elif COMMAND in box: + self._commandReceived(box) + else: + raise NoEmptyBoxes(box) + + + def _safeEmit(self, aBox): + """ + Emit a box, ignoring L{ProtocolSwitched} and L{ConnectionLost} errors + which cannot be usefully handled. + """ + try: + aBox._sendTo(self.boxSender) + except (ProtocolSwitched, ConnectionLost): + pass + + + def dispatchCommand(self, box): + """ + A box with a _command key was received. + + Dispatch it to a local handler call it. + + @param proto: an AMP instance. + @param box: an AmpBox to be dispatched. + """ + cmd = box[COMMAND] + responder = self.locator.locateResponder(cmd) + if responder is None: + description = "Unhandled Command: %r" % (cmd,) + return fail(RemoteAmpError( + UNHANDLED_ERROR_CODE, + description, + False, + local=Failure(UnhandledCommand()))) + return maybeDeferred(responder, box) + + + +@implementer(IResponderLocator) +class CommandLocator: + """ + A L{CommandLocator} is a collection of responders to AMP L{Command}s, with + the help of the L{Command.responder} decorator. + """ + + class __metaclass__(type): + """ + This metaclass keeps track of all of the Command.responder-decorated + methods defined since the last CommandLocator subclass was defined. It + assumes (usually correctly, but unfortunately not necessarily so) that + those commands responders were all declared as methods of the class + being defined. Note that this list can be incorrect if users use the + Command.responder decorator outside the context of a CommandLocator + class declaration. + + Command responders defined on subclasses are given precedence over + those inherited from a base class. + + The Command.responder decorator explicitly cooperates with this + metaclass. + """ + + _currentClassCommands = [] + def __new__(cls, name, bases, attrs): + commands = cls._currentClassCommands[:] + cls._currentClassCommands[:] = [] + cd = attrs['_commandDispatch'] = {} + subcls = type.__new__(cls, name, bases, attrs) + ancestors = list(subcls.__mro__[1:]) + ancestors.reverse() + for ancestor in ancestors: + cd.update(getattr(ancestor, '_commandDispatch', {})) + for commandClass, responderFunc in commands: + cd[commandClass.commandName] = (commandClass, responderFunc) + if (bases and ( + subcls.lookupFunction != CommandLocator.lookupFunction)): + def locateResponder(self, name): + warnings.warn( + "Override locateResponder, not lookupFunction.", + category=PendingDeprecationWarning, + stacklevel=2) + return self.lookupFunction(name) + subcls.locateResponder = locateResponder + return subcls + + + def _wrapWithSerialization(self, aCallable, command): + """ + Wrap aCallable with its command's argument de-serialization + and result serialization logic. + + @param aCallable: a callable with a 'command' attribute, designed to be + called with keyword arguments. + + @param command: the command class whose serialization to use. + + @return: a 1-arg callable which, when invoked with an AmpBox, will + deserialize the argument list and invoke appropriate user code for the + callable's command, returning a Deferred which fires with the result or + fails with an error. + """ + def doit(box): + kw = command.parseArguments(box, self) + def checkKnownErrors(error): + key = error.trap(*command.allErrors) + code = command.allErrors[key] + desc = str(error.value) + return Failure(RemoteAmpError( + code, desc, key in command.fatalErrors, local=error)) + def makeResponseFor(objects): + try: + return command.makeResponse(objects, self) + except: + # let's helpfully log this. + originalFailure = Failure() + raise BadLocalReturn( + "%r returned %r and %r could not serialize it" % ( + aCallable, + objects, + command), + originalFailure) + return maybeDeferred(aCallable, **kw).addCallback( + makeResponseFor).addErrback( + checkKnownErrors) + return doit + + + def lookupFunction(self, name): + """ + Deprecated synonym for L{CommandLocator.locateResponder} + """ + if self.__class__.lookupFunction != CommandLocator.lookupFunction: + return CommandLocator.locateResponder(self, name) + else: + warnings.warn("Call locateResponder, not lookupFunction.", + category=PendingDeprecationWarning, + stacklevel=2) + return self.locateResponder(name) + + + def locateResponder(self, name): + """ + Locate a callable to invoke when executing the named command. + + @param name: the normalized name (from the wire) of the command. + @type name: C{bytes} + + @return: a 1-argument function that takes a Box and returns a box or a + Deferred which fires a Box, for handling the command identified by the + given name, or None, if no appropriate responder can be found. + """ + # Try to find a high-level method to invoke, and if we can't find one, + # fall back to a low-level one. + cd = self._commandDispatch + if name in cd: + commandClass, responderFunc = cd[name] + if _PY3: + responderMethod = types.MethodType( + responderFunc, self) + else: + responderMethod = types.MethodType( + responderFunc, self, self.__class__) + return self._wrapWithSerialization(responderMethod, commandClass) + + + +if _PY3: + # Python 3 ignores the __metaclass__ attribute and has instead new syntax + # for setting the metaclass. Unfortunately it's not valid Python 2 syntax + # so we work-around it by recreating CommandLocator using the metaclass + # here. + CommandLocator = CommandLocator.__metaclass__( + "CommandLocator", (CommandLocator, ), {}) + + + +@implementer(IResponderLocator) +class SimpleStringLocator(object): + """ + Implement the L{AMP.locateResponder} method to do simple, string-based + dispatch. + """ + + baseDispatchPrefix = b'amp_' + + def locateResponder(self, name): + """ + Locate a callable to invoke when executing the named command. + + @return: a function with the name C{"amp_" + name} on the same + instance, or None if no such function exists. + This function will then be called with the L{AmpBox} itself as an + argument. + + @param name: the normalized name (from the wire) of the command. + @type name: C{bytes} + """ + fName = nativeString(self.baseDispatchPrefix + name.upper()) + return getattr(self, fName, None) + + + +PYTHON_KEYWORDS = [ + 'and', 'del', 'for', 'is', 'raise', 'assert', 'elif', 'from', 'lambda', + 'return', 'break', 'else', 'global', 'not', 'try', 'class', 'except', + 'if', 'or', 'while', 'continue', 'exec', 'import', 'pass', 'yield', + 'def', 'finally', 'in', 'print'] + + + +def _wireNameToPythonIdentifier(key): + """ + (Private) Normalize an argument name from the wire for use with Python + code. If the return value is going to be a python keyword it will be + capitalized. If it contains any dashes they will be replaced with + underscores. + + The rationale behind this method is that AMP should be an inherently + multi-language protocol, so message keys may contain all manner of bizarre + bytes. This is not a complete solution; there are still forms of arguments + that this implementation will be unable to parse. However, Python + identifiers share a huge raft of properties with identifiers from many + other languages, so this is a 'good enough' effort for now. We deal + explicitly with dashes because that is the most likely departure: Lisps + commonly use dashes to separate method names, so protocols initially + implemented in a lisp amp dialect may use dashes in argument or command + names. + + @param key: a C{bytes}, looking something like 'foo-bar-baz' or 'from' + @type key: C{bytes} + + @return: a native string which is a valid python identifier, looking + something like 'foo_bar_baz' or 'From'. + """ + lkey = nativeString(key.replace(b"-", b"_")) + if lkey in PYTHON_KEYWORDS: + return lkey.title() + return lkey + + + +@implementer(IArgumentType) +class Argument: + """ + Base-class of all objects that take values from Amp packets and convert + them into objects for Python functions. + + This implementation of L{IArgumentType} provides several higher-level + hooks for subclasses to override. See L{toString} and L{fromString} + which will be used to define the behavior of L{IArgumentType.toBox} and + L{IArgumentType.fromBox}, respectively. + """ + + optional = False + + + def __init__(self, optional=False): + """ + Create an Argument. + + @param optional: a boolean indicating whether this argument can be + omitted in the protocol. + """ + self.optional = optional + + + def retrieve(self, d, name, proto): + """ + Retrieve the given key from the given dictionary, removing it if found. + + @param d: a dictionary. + + @param name: a key in I{d}. + + @param proto: an instance of an AMP. + + @raise KeyError: if I am not optional and no value was found. + + @return: d[name]. + """ + if self.optional: + value = d.get(name) + if value is not None: + del d[name] + else: + value = d.pop(name) + return value + + + def fromBox(self, name, strings, objects, proto): + """ + Populate an 'out' dictionary with mapping names to Python values + decoded from an 'in' AmpBox mapping strings to string values. + + @param name: the argument name to retrieve + @type name: C{bytes} + + @param strings: The AmpBox to read string(s) from, a mapping of + argument names to string values. + @type strings: AmpBox + + @param objects: The dictionary to write object(s) to, a mapping of + names to Python objects. Keys will be native strings. + @type objects: dict + + @param proto: an AMP instance. + """ + st = self.retrieve(strings, name, proto) + nk = _wireNameToPythonIdentifier(name) + if self.optional and st is None: + objects[nk] = None + else: + objects[nk] = self.fromStringProto(st, proto) + + + def toBox(self, name, strings, objects, proto): + """ + Populate an 'out' AmpBox with strings encoded from an 'in' dictionary + mapping names to Python values. + + @param name: the argument name to retrieve + @type name: C{bytes} + + @param strings: The AmpBox to write string(s) to, a mapping of + argument names to string values. + @type strings: AmpBox + + @param objects: The dictionary to read object(s) from, a mapping of + names to Python objects. Keys should be native strings. + + @type objects: dict + + @param proto: the protocol we are converting for. + @type proto: AMP + """ + obj = self.retrieve(objects, _wireNameToPythonIdentifier(name), proto) + if self.optional and obj is None: + # strings[name] = None + pass + else: + strings[name] = self.toStringProto(obj, proto) + + + def fromStringProto(self, inString, proto): + """ + Convert a string to a Python value. + + @param inString: the string to convert. + @type inString: C{bytes} + + @param proto: the protocol we are converting for. + @type proto: AMP + + @return: a Python object. + """ + return self.fromString(inString) + + + def toStringProto(self, inObject, proto): + """ + Convert a Python object to a string. + + @param inObject: the object to convert. + + @param proto: the protocol we are converting for. + @type proto: AMP + """ + return self.toString(inObject) + + + def fromString(self, inString): + """ + Convert a string to a Python object. Subclasses must implement this. + + @param inString: the string to convert. + @type inString: C{bytes} + + @return: the decoded value from C{inString} + """ + + + def toString(self, inObject): + """ + Convert a Python object into a string for passing over the network. + + @param inObject: an object of the type that this Argument is intended + to deal with. + + @return: the wire encoding of inObject + @rtype: C{bytes} + """ + + + +class Integer(Argument): + """ + Encode any integer values of any size on the wire as the string + representation. + + Example: C{123} becomes C{"123"} + """ + fromString = int + def toString(self, inObject): + return intToBytes(inObject) + + + +class String(Argument): + """ + Don't do any conversion at all; just pass through 'str'. + """ + def toString(self, inObject): + return inObject + + def fromString(self, inString): + return inString + + + +class Float(Argument): + """ + Encode floating-point values on the wire as their repr. + """ + fromString = float + + def toString(self, inString): + if not isinstance(inString, float): + raise ValueError("Bad float value %r" % (inString,)) + return str(inString).encode('ascii') + + + +class Boolean(Argument): + """ + Encode True or False as "True" or "False" on the wire. + """ + def fromString(self, inString): + if inString == b'True': + return True + elif inString == b'False': + return False + else: + raise TypeError("Bad boolean value: %r" % (inString,)) + + + def toString(self, inObject): + if inObject: + return b'True' + else: + return b'False' + + + +class Unicode(String): + """ + Encode a unicode string on the wire as UTF-8. + """ + + def toString(self, inObject): + return String.toString(self, inObject.encode('utf-8')) + + + def fromString(self, inString): + return String.fromString(self, inString).decode('utf-8') + + + +class Path(Unicode): + """ + Encode and decode L{filepath.FilePath} instances as paths on the wire. + + This is really intended for use with subprocess communication tools: + exchanging pathnames on different machines over a network is not generally + meaningful, but neither is it disallowed; you can use this to communicate + about NFS paths, for example. + """ + def fromString(self, inString): + return filepath.FilePath(Unicode.fromString(self, inString)) + + + def toString(self, inObject): + return Unicode.toString(self, inObject.asTextMode().path) + + + +class ListOf(Argument): + """ + Encode and decode lists of instances of a single other argument type. + + For example, if you want to pass:: + + [3, 7, 9, 15] + + You can create an argument like this:: + + ListOf(Integer()) + + The serialized form of the entire list is subject to the limit imposed by + L{MAX_VALUE_LENGTH}. List elements are represented as 16-bit length + prefixed strings. The argument type passed to the L{ListOf} initializer is + responsible for producing the serialized form of each element. + + @ivar elementType: The L{Argument} instance used to encode and decode list + elements (note, not an arbitrary L{IArgumentType} implementation: + arguments must be implemented using only the C{fromString} and + C{toString} methods, not the C{fromBox} and C{toBox} methods). + + @param optional: a boolean indicating whether this argument can be + omitted in the protocol. + + @since: 10.0 + """ + def __init__(self, elementType, optional=False): + self.elementType = elementType + Argument.__init__(self, optional) + + + def fromString(self, inString): + """ + Convert the serialized form of a list of instances of some type back + into that list. + """ + strings = [] + parser = Int16StringReceiver() + parser.stringReceived = strings.append + parser.dataReceived(inString) + elementFromString = self.elementType.fromString + return [elementFromString(string) for string in strings] + + + def toString(self, inObject): + """ + Serialize the given list of objects to a single string. + """ + strings = [] + for obj in inObject: + serialized = self.elementType.toString(obj) + strings.append(pack('!H', len(serialized))) + strings.append(serialized) + return b''.join(strings) + + + +class AmpList(Argument): + """ + Convert a list of dictionaries into a list of AMP boxes on the wire. + + For example, if you want to pass:: + + [{'a': 7, 'b': u'hello'}, {'a': 9, 'b': u'goodbye'}] + + You might use an AmpList like this in your arguments or response list:: + + AmpList([('a', Integer()), + ('b', Unicode())]) + """ + def __init__(self, subargs, optional=False): + """ + Create an AmpList. + + @param subargs: a list of 2-tuples of ('name', argument) describing the + schema of the dictionaries in the sequence of amp boxes. + @type subargs: A C{list} of (C{bytes}, L{Argument}) tuples. + + @param optional: a boolean indicating whether this argument can be + omitted in the protocol. + """ + assert all(isinstance(name, bytes) for name, _ in subargs), ( + "AmpList should be defined with a list of (name, argument) " + "tuples where `name' is a byte string, got: %r" % (subargs, )) + self.subargs = subargs + Argument.__init__(self, optional) + + + def fromStringProto(self, inString, proto): + boxes = parseString(inString) + values = [_stringsToObjects(box, self.subargs, proto) + for box in boxes] + return values + + + def toStringProto(self, inObject, proto): + return b''.join([_objectsToStrings( + objects, self.subargs, Box(), proto + ).serialize() for objects in inObject]) + + + +class Descriptor(Integer): + """ + Encode and decode file descriptors for exchange over a UNIX domain socket. + + This argument type requires an AMP connection set up over an + L{IUNIXTransport<twisted.internet.interfaces.IUNIXTransport>} provider (for + example, the kind of connection created by + L{IReactorUNIX.connectUNIX<twisted.internet.interfaces.IReactorUNIX.connectUNIX>} + and L{UNIXClientEndpoint<twisted.internet.endpoints.UNIXClientEndpoint>}). + + There is no correspondence between the integer value of the file descriptor + on the sending and receiving sides, therefore an alternate approach is taken + to matching up received descriptors with particular L{Descriptor} + parameters. The argument is encoded to an ordinal (unique per connection) + for inclusion in the AMP command or response box. The descriptor itself is + sent using + L{IUNIXTransport.sendFileDescriptor<twisted.internet.interfaces.IUNIXTransport.sendFileDescriptor>}. + The receiver uses the order in which file descriptors are received and the + ordinal value to come up with the received copy of the descriptor. + """ + def fromStringProto(self, inString, proto): + """ + Take a unique identifier associated with a file descriptor which must + have been received by now and use it to look up that descriptor in a + dictionary where they are kept. + + @param inString: The base representation (as a byte string) of an + ordinal indicating which file descriptor corresponds to this usage + of this argument. + @type inString: C{str} + + @param proto: The protocol used to receive this descriptor. This + protocol must be connected via a transport providing + L{IUNIXTransport<twisted.internet.interfaces.IUNIXTransport>}. + @type proto: L{BinaryBoxProtocol} + + @return: The file descriptor represented by C{inString}. + @rtype: C{int} + """ + return proto._getDescriptor(int(inString)) + + + def toStringProto(self, inObject, proto): + """ + Send C{inObject}, an integer file descriptor, over C{proto}'s connection + and return a unique identifier which will allow the receiver to + associate the file descriptor with this argument. + + @param inObject: A file descriptor to duplicate over an AMP connection + as the value for this argument. + @type inObject: C{int} + + @param proto: The protocol which will be used to send this descriptor. + This protocol must be connected via a transport providing + L{IUNIXTransport<twisted.internet.interfaces.IUNIXTransport>}. + + @return: A byte string which can be used by the receiver to reconstruct + the file descriptor. + @type: C{str} + """ + identifier = proto._sendFileDescriptor(inObject) + outString = Integer.toStringProto(self, identifier, proto) + return outString + + + +class Command: + """ + Subclass me to specify an AMP Command. + + @cvar arguments: A list of 2-tuples of (name, Argument-subclass-instance), + specifying the names and values of the parameters which are required for + this command. + + @cvar response: A list like L{arguments}, but instead used for the return + value. + + @cvar errors: A mapping of subclasses of L{Exception} to wire-protocol tags + for errors represented as L{str}s. Responders which raise keys from + this dictionary will have the error translated to the corresponding tag + on the wire. + Invokers which receive Deferreds from invoking this command with + L{BoxDispatcher.callRemote} will potentially receive Failures with keys + from this mapping as their value. + This mapping is inherited; if you declare a command which handles + C{FooError} as 'FOO_ERROR', then subclass it and specify C{BarError} as + 'BAR_ERROR', responders to the subclass may raise either C{FooError} or + C{BarError}, and invokers must be able to deal with either of those + exceptions. + + @cvar fatalErrors: like 'errors', but errors in this list will always + terminate the connection, despite being of a recognizable error type. + + @cvar commandType: The type of Box used to issue commands; useful only for + protocol-modifying behavior like startTLS or protocol switching. Defaults + to a plain vanilla L{Box}. + + @cvar responseType: The type of Box used to respond to this command; only + useful for protocol-modifying behavior like startTLS or protocol switching. + Defaults to a plain vanilla L{Box}. + + @ivar requiresAnswer: a boolean; defaults to True. Set it to False on your + subclass if you want callRemote to return None. Note: this is a hint only + to the client side of the protocol. The return-type of a command responder + method must always be a dictionary adhering to the contract specified by + L{response}, because clients are always free to request a response if they + want one. + """ + + class __metaclass__(type): + """ + Metaclass hack to establish reverse-mappings for 'errors' and + 'fatalErrors' as class vars. + """ + def __new__(cls, name, bases, attrs): + reverseErrors = attrs['reverseErrors'] = {} + er = attrs['allErrors'] = {} + if 'commandName' not in attrs: + if _PY3: + attrs['commandName'] = name.encode("ascii") + else: + attrs['commandName'] = name + newtype = type.__new__(cls, name, bases, attrs) + + if not isinstance(newtype.commandName, bytes): + raise TypeError( + "Command names must be byte strings, got: %r" + % (newtype.commandName, )) + for name, _ in newtype.arguments: + if not isinstance(name, bytes): + raise TypeError( + "Argument names must be byte strings, got: %r" + % (name, )) + for name, _ in newtype.response: + if not isinstance(name, bytes): + raise TypeError( + "Response names must be byte strings, got: %r" + % (name, )) + + errors = {} + fatalErrors = {} + accumulateClassDict(newtype, 'errors', errors) + accumulateClassDict(newtype, 'fatalErrors', fatalErrors) + + if not isinstance(newtype.errors, dict): + newtype.errors = dict(newtype.errors) + if not isinstance(newtype.fatalErrors, dict): + newtype.fatalErrors = dict(newtype.fatalErrors) + + for v, k in iteritems(errors): + reverseErrors[k] = v + er[v] = k + for v, k in iteritems(fatalErrors): + reverseErrors[k] = v + er[v] = k + + for _, name in iteritems(newtype.errors): + if not isinstance(name, bytes): + raise TypeError( + "Error names must be byte strings, got: %r" + % (name, )) + for _, name in iteritems(newtype.fatalErrors): + if not isinstance(name, bytes): + raise TypeError( + "Fatal error names must be byte strings, got: %r" + % (name, )) + + return newtype + + arguments = [] + response = [] + extra = [] + errors = {} + fatalErrors = {} + + commandType = Box + responseType = Box + + requiresAnswer = True + + + def __init__(self, **kw): + """ + Create an instance of this command with specified values for its + parameters. + + In Python 3, keyword arguments MUST be Unicode/native strings whereas + in Python 2 they could be either byte strings or Unicode strings. + + A L{Command}'s arguments are defined in its schema using C{bytes} + names. The values for those arguments are plucked from the keyword + arguments using the name returned from L{_wireNameToPythonIdentifier}. + In other words, keyword arguments should be named using the + Python-side equivalent of the on-wire (C{bytes}) name. + + @param kw: a dict containing an appropriate value for each name + specified in the L{arguments} attribute of my class. + + @raise InvalidSignature: if you forgot any required arguments. + """ + self.structured = kw + forgotten = [] + for name, arg in self.arguments: + pythonName = _wireNameToPythonIdentifier(name) + if pythonName not in self.structured and not arg.optional: + forgotten.append(pythonName) + if forgotten: + raise InvalidSignature("forgot %s for %s" % ( + ', '.join(forgotten), self.commandName)) + forgotten = [] + + + def makeResponse(cls, objects, proto): + """ + Serialize a mapping of arguments using this L{Command}'s + response schema. + + @param objects: a dict with keys matching the names specified in + self.response, having values of the types that the Argument objects in + self.response can format. + + @param proto: an L{AMP}. + + @return: an L{AmpBox}. + """ + try: + responseType = cls.responseType() + except: + return fail() + return _objectsToStrings(objects, cls.response, responseType, proto) + makeResponse = classmethod(makeResponse) + + + def makeArguments(cls, objects, proto): + """ + Serialize a mapping of arguments using this L{Command}'s + argument schema. + + @param objects: a dict with keys similar to the names specified in + self.arguments, having values of the types that the Argument objects in + self.arguments can parse. + + @param proto: an L{AMP}. + + @return: An instance of this L{Command}'s C{commandType}. + """ + allowedNames = set() + for (argName, ignored) in cls.arguments: + allowedNames.add(_wireNameToPythonIdentifier(argName)) + + for intendedArg in objects: + if intendedArg not in allowedNames: + raise InvalidSignature( + "%s is not a valid argument" % (intendedArg,)) + return _objectsToStrings(objects, cls.arguments, cls.commandType(), + proto) + makeArguments = classmethod(makeArguments) + + + def parseResponse(cls, box, protocol): + """ + Parse a mapping of serialized arguments using this + L{Command}'s response schema. + + @param box: A mapping of response-argument names to the + serialized forms of those arguments. + @param protocol: The L{AMP} protocol. + + @return: A mapping of response-argument names to the parsed + forms. + """ + return _stringsToObjects(box, cls.response, protocol) + parseResponse = classmethod(parseResponse) + + + def parseArguments(cls, box, protocol): + """ + Parse a mapping of serialized arguments using this + L{Command}'s argument schema. + + @param box: A mapping of argument names to the seralized forms + of those arguments. + @param protocol: The L{AMP} protocol. + + @return: A mapping of argument names to the parsed forms. + """ + return _stringsToObjects(box, cls.arguments, protocol) + parseArguments = classmethod(parseArguments) + + + def responder(cls, methodfunc): + """ + Declare a method to be a responder for a particular command. + + This is a decorator. + + Use like so:: + + class MyCommand(Command): + arguments = [('a', ...), ('b', ...)] + + class MyProto(AMP): + def myFunMethod(self, a, b): + ... + MyCommand.responder(myFunMethod) + + Notes: Although decorator syntax is not used within Twisted, this + function returns its argument and is therefore safe to use with + decorator syntax. + + This is not thread safe. Don't declare AMP subclasses in other + threads. Don't declare responders outside the scope of AMP subclasses; + the behavior is undefined. + + @param methodfunc: A function which will later become a method, which + has a keyword signature compatible with this command's L{argument} list + and returns a dictionary with a set of keys compatible with this + command's L{response} list. + + @return: the methodfunc parameter. + """ + CommandLocator._currentClassCommands.append((cls, methodfunc)) + return methodfunc + responder = classmethod(responder) + + + # Our only instance method + def _doCommand(self, proto): + """ + Encode and send this Command to the given protocol. + + @param proto: an AMP, representing the connection to send to. + + @return: a Deferred which will fire or error appropriately when the + other side responds to the command (or error if the connection is lost + before it is responded to). + """ + + def _massageError(error): + error.trap(RemoteAmpError) + rje = error.value + errorType = self.reverseErrors.get(rje.errorCode, + UnknownRemoteError) + return Failure(errorType(rje.description)) + + d = proto._sendBoxCommand(self.commandName, + self.makeArguments(self.structured, proto), + self.requiresAnswer) + + if self.requiresAnswer: + d.addCallback(self.parseResponse, proto) + d.addErrback(_massageError) + + return d + + + +if _PY3: + # Python 3 ignores the __metaclass__ attribute and has instead new syntax + # for setting the metaclass. Unfortunately it's not valid Python 2 syntax + # so we work-around it by recreating Command using the metaclass here. + Command = Command.__metaclass__("Command", (Command, ), {}) + + + +class _NoCertificate: + """ + This is for peers which don't want to use a local certificate. Used by + AMP because AMP's internal language is all about certificates and this + duck-types in the appropriate place; this API isn't really stable though, + so it's not exposed anywhere public. + + For clients, it will use ephemeral DH keys, or whatever the default is for + certificate-less clients in OpenSSL. For servers, it will generate a + temporary self-signed certificate with garbage values in the DN and use + that. + """ + + def __init__(self, client): + """ + Create a _NoCertificate which either is or isn't for the client side of + the connection. + + @param client: True if we are a client and should truly have no + certificate and be anonymous, False if we are a server and actually + have to generate a temporary certificate. + + @type client: bool + """ + self.client = client + + + def options(self, *authorities): + """ + Behaves like L{twisted.internet.ssl.PrivateCertificate.options}(). + """ + if not self.client: + # do some crud with sslverify to generate a temporary self-signed + # certificate. This is SLOOOWWWWW so it is only in the absolute + # worst, most naive case. + + # We have to do this because OpenSSL will not let both the server + # and client be anonymous. + sharedDN = DN(CN='TEMPORARY CERTIFICATE') + key = KeyPair.generate() + cr = key.certificateRequest(sharedDN) + sscrd = key.signCertificateRequest(sharedDN, cr, lambda dn: True, 1) + cert = key.newCertificate(sscrd) + return cert.options(*authorities) + options = dict() + if authorities: + options.update(dict(verify=True, + requireCertificate=True, + caCerts=[auth.original for auth in authorities])) + occo = CertificateOptions(**options) + return occo + + + +class _TLSBox(AmpBox): + """ + I am an AmpBox that, upon being sent, initiates a TLS connection. + """ + __slots__ = [] + + def __init__(self): + if ssl is None: + raise RemoteAmpError(b"TLS_ERROR", "TLS not available") + AmpBox.__init__(self) + + + def _keyprop(k, default): + return property(lambda self: self.get(k, default)) + + + # These properties are described in startTLS + certificate = _keyprop(b'tls_localCertificate', _NoCertificate(False)) + verify = _keyprop(b'tls_verifyAuthorities', None) + + def _sendTo(self, proto): + """ + Send my encoded value to the protocol, then initiate TLS. + """ + ab = AmpBox(self) + for k in [b'tls_localCertificate', + b'tls_verifyAuthorities']: + ab.pop(k, None) + ab._sendTo(proto) + proto._startTLS(self.certificate, self.verify) + + + +class _LocalArgument(String): + """ + Local arguments are never actually relayed across the wire. This is just a + shim so that StartTLS can pretend to have some arguments: if arguments + acquire documentation properties, replace this with something nicer later. + """ + + def fromBox(self, name, strings, objects, proto): + pass + + + +class StartTLS(Command): + """ + Use, or subclass, me to implement a command that starts TLS. + + Callers of StartTLS may pass several special arguments, which affect the + TLS negotiation: + + - tls_localCertificate: This is a + twisted.internet.ssl.PrivateCertificate which will be used to secure + the side of the connection it is returned on. + + - tls_verifyAuthorities: This is a list of + twisted.internet.ssl.Certificate objects that will be used as the + certificate authorities to verify our peer's certificate. + + Each of those special parameters may also be present as a key in the + response dictionary. + """ + + arguments = [(b"tls_localCertificate", _LocalArgument(optional=True)), + (b"tls_verifyAuthorities", _LocalArgument(optional=True))] + + response = [(b"tls_localCertificate", _LocalArgument(optional=True)), + (b"tls_verifyAuthorities", _LocalArgument(optional=True))] + + responseType = _TLSBox + + def __init__(self, **kw): + """ + Create a StartTLS command. (This is private. Use AMP.callRemote.) + + @param tls_localCertificate: the PrivateCertificate object to use to + secure the connection. If it's None, or unspecified, an ephemeral DH + key is used instead. + + @param tls_verifyAuthorities: a list of Certificate objects which + represent root certificates to verify our peer with. + """ + if ssl is None: + raise RuntimeError("TLS not available.") + self.certificate = kw.pop('tls_localCertificate', _NoCertificate(True)) + self.authorities = kw.pop('tls_verifyAuthorities', None) + Command.__init__(self, **kw) + + + def _doCommand(self, proto): + """ + When a StartTLS command is sent, prepare to start TLS, but don't actually + do it; wait for the acknowledgement, then initiate the TLS handshake. + """ + d = Command._doCommand(self, proto) + proto._prepareTLS(self.certificate, self.authorities) + # XXX before we get back to user code we are going to start TLS... + def actuallystart(response): + proto._startTLS(self.certificate, self.authorities) + return response + d.addCallback(actuallystart) + return d + + + +class ProtocolSwitchCommand(Command): + """ + Use this command to switch from something Amp-derived to a different + protocol mid-connection. This can be useful to use amp as the + connection-startup negotiation phase. Since TLS is a different layer + entirely, you can use Amp to negotiate the security parameters of your + connection, then switch to a different protocol, and the connection will + remain secured. + """ + + def __init__(self, _protoToSwitchToFactory, **kw): + """ + Create a ProtocolSwitchCommand. + + @param _protoToSwitchToFactory: a ProtocolFactory which will generate + the Protocol to switch to. + + @param kw: Keyword arguments, encoded and handled normally as + L{Command} would. + """ + + self.protoToSwitchToFactory = _protoToSwitchToFactory + super(ProtocolSwitchCommand, self).__init__(**kw) + + + def makeResponse(cls, innerProto, proto): + return _SwitchBox(innerProto) + makeResponse = classmethod(makeResponse) + + + def _doCommand(self, proto): + """ + When we emit a ProtocolSwitchCommand, lock the protocol, but don't actually + switch to the new protocol unless an acknowledgement is received. If + an error is received, switch back. + """ + d = super(ProtocolSwitchCommand, self)._doCommand(proto) + proto._lockForSwitch() + def switchNow(ign): + innerProto = self.protoToSwitchToFactory.buildProtocol( + proto.transport.getPeer()) + proto._switchTo(innerProto, self.protoToSwitchToFactory) + return ign + def handle(ign): + proto._unlockFromSwitch() + self.protoToSwitchToFactory.clientConnectionFailed( + None, Failure(CONNECTION_LOST)) + return ign + return d.addCallbacks(switchNow, handle) + + + +@implementer(IFileDescriptorReceiver) +class _DescriptorExchanger(object): + """ + L{_DescriptorExchanger} is a mixin for L{BinaryBoxProtocol} which adds + support for receiving file descriptors, a feature offered by + L{IUNIXTransport<twisted.internet.interfaces.IUNIXTransport>}. + + @ivar _descriptors: Temporary storage for all file descriptors received. + Values in this dictionary are the file descriptors (as integers). Keys + in this dictionary are ordinals giving the order in which each + descriptor was received. The ordering information is used to allow + L{Descriptor} to determine which is the correct descriptor for any + particular usage of that argument type. + @type _descriptors: C{dict} + + @ivar _sendingDescriptorCounter: A no-argument callable which returns the + ordinals, starting from 0. This is used to construct values for + C{_sendFileDescriptor}. + + @ivar _receivingDescriptorCounter: A no-argument callable which returns the + ordinals, starting from 0. This is used to construct values for + C{fileDescriptorReceived}. + """ + + def __init__(self): + self._descriptors = {} + self._getDescriptor = self._descriptors.pop + self._sendingDescriptorCounter = partial(next, count()) + self._receivingDescriptorCounter = partial(next, count()) + + + def _sendFileDescriptor(self, descriptor): + """ + Assign and return the next ordinal to the given descriptor after sending + the descriptor over this protocol's transport. + """ + self.transport.sendFileDescriptor(descriptor) + return self._sendingDescriptorCounter() + + + def fileDescriptorReceived(self, descriptor): + """ + Collect received file descriptors to be claimed later by L{Descriptor}. + + @param descriptor: The received file descriptor. + @type descriptor: C{int} + """ + self._descriptors[self._receivingDescriptorCounter()] = descriptor + + + +@implementer(IBoxSender) +class BinaryBoxProtocol(StatefulStringProtocol, Int16StringReceiver, + _DescriptorExchanger): + """ + A protocol for receiving L{AmpBox}es - key/value pairs - via length-prefixed + strings. A box is composed of: + + - any number of key-value pairs, described by: + - a 2-byte network-endian packed key length (of which the first + byte must be null, and the second must be non-null: i.e. the + value of the length must be 1-255) + - a key, comprised of that many bytes + - a 2-byte network-endian unsigned value length (up to the maximum + of 65535) + - a value, comprised of that many bytes + - 2 null bytes + + In other words, an even number of strings prefixed with packed unsigned + 16-bit integers, and then a 0-length string to indicate the end of the box. + + This protocol also implements 2 extra private bits of functionality related + to the byte boundaries between messages; it can start TLS between two given + boxes or switch to an entirely different protocol. However, due to some + tricky elements of the implementation, the public interface to this + functionality is L{ProtocolSwitchCommand} and L{StartTLS}. + + @ivar _keyLengthLimitExceeded: A flag which is only true when the + connection is being closed because a key length prefix which was longer + than allowed by the protocol was received. + + @ivar boxReceiver: an L{IBoxReceiver} provider, whose + L{IBoxReceiver.ampBoxReceived} method will be invoked for each + L{AmpBox} that is received. + """ + + _justStartedTLS = False + _startingTLSBuffer = None + _locked = False + _currentKey = None + _currentBox = None + + _keyLengthLimitExceeded = False + + hostCertificate = None + noPeerCertificate = False # for tests + innerProtocol = None + innerProtocolClientFactory = None + + def __init__(self, boxReceiver): + _DescriptorExchanger.__init__(self) + self.boxReceiver = boxReceiver + + + def _switchTo(self, newProto, clientFactory=None): + """ + Switch this BinaryBoxProtocol's transport to a new protocol. You need + to do this 'simultaneously' on both ends of a connection; the easiest + way to do this is to use a subclass of ProtocolSwitchCommand. + + @param newProto: the new protocol instance to switch to. + + @param clientFactory: the ClientFactory to send the + L{twisted.internet.protocol.ClientFactory.clientConnectionLost} + notification to. + """ + # All the data that Int16Receiver has not yet dealt with belongs to our + # new protocol: luckily it's keeping that in a handy (although + # ostensibly internal) variable for us: + newProtoData = self.recvd + # We're quite possibly in the middle of a 'dataReceived' loop in + # Int16StringReceiver: let's make sure that the next iteration, the + # loop will break and not attempt to look at something that isn't a + # length prefix. + self.recvd = '' + # Finally, do the actual work of setting up the protocol and delivering + # its first chunk of data, if one is available. + self.innerProtocol = newProto + self.innerProtocolClientFactory = clientFactory + newProto.makeConnection(self.transport) + if newProtoData: + newProto.dataReceived(newProtoData) + + + def sendBox(self, box): + """ + Send a amp.Box to my peer. + + Note: transport.write is never called outside of this method. + + @param box: an AmpBox. + + @raise ProtocolSwitched: if the protocol has previously been switched. + + @raise ConnectionLost: if the connection has previously been lost. + """ + if self._locked: + raise ProtocolSwitched( + "This connection has switched: no AMP traffic allowed.") + if self.transport is None: + raise ConnectionLost() + if self._startingTLSBuffer is not None: + self._startingTLSBuffer.append(box) + else: + self.transport.write(box.serialize()) + + + def makeConnection(self, transport): + """ + Notify L{boxReceiver} that it is about to receive boxes from this + protocol by invoking L{IBoxReceiver.startReceivingBoxes}. + """ + self.transport = transport + self.boxReceiver.startReceivingBoxes(self) + self.connectionMade() + + + def dataReceived(self, data): + """ + Either parse incoming data as L{AmpBox}es or relay it to our nested + protocol. + """ + if self._justStartedTLS: + self._justStartedTLS = False + # If we already have an inner protocol, then we don't deliver data to + # the protocol parser any more; we just hand it off. + if self.innerProtocol is not None: + self.innerProtocol.dataReceived(data) + return + return Int16StringReceiver.dataReceived(self, data) + + + def connectionLost(self, reason): + """ + The connection was lost; notify any nested protocol. + """ + if self.innerProtocol is not None: + self.innerProtocol.connectionLost(reason) + if self.innerProtocolClientFactory is not None: + self.innerProtocolClientFactory.clientConnectionLost(None, reason) + if self._keyLengthLimitExceeded: + failReason = Failure(TooLong(True, False, None, None)) + elif reason.check(ConnectionClosed) and self._justStartedTLS: + # We just started TLS and haven't received any data. This means + # the other connection didn't like our cert (although they may not + # have told us why - later Twisted should make 'reason' into a TLS + # error.) + failReason = PeerVerifyError( + "Peer rejected our certificate for an unknown reason.") + else: + failReason = reason + self.boxReceiver.stopReceivingBoxes(failReason) + + + # The longest key allowed + _MAX_KEY_LENGTH = 255 + + # The longest value allowed (this is somewhat redundant, as longer values + # cannot be encoded - ah well). + _MAX_VALUE_LENGTH = 65535 + + # The first thing received is a key. + MAX_LENGTH = _MAX_KEY_LENGTH + + def proto_init(self, string): + """ + String received in the 'init' state. + """ + self._currentBox = AmpBox() + return self.proto_key(string) + + + def proto_key(self, string): + """ + String received in the 'key' state. If the key is empty, a complete + box has been received. + """ + if string: + self._currentKey = string + self.MAX_LENGTH = self._MAX_VALUE_LENGTH + return 'value' + else: + self.boxReceiver.ampBoxReceived(self._currentBox) + self._currentBox = None + return 'init' + + + def proto_value(self, string): + """ + String received in the 'value' state. + """ + self._currentBox[self._currentKey] = string + self._currentKey = None + self.MAX_LENGTH = self._MAX_KEY_LENGTH + return 'key' + + + def lengthLimitExceeded(self, length): + """ + The key length limit was exceeded. Disconnect the transport and make + sure a meaningful exception is reported. + """ + self._keyLengthLimitExceeded = True + self.transport.loseConnection() + + + def _lockForSwitch(self): + """ + Lock this binary protocol so that no further boxes may be sent. This + is used when sending a request to switch underlying protocols. You + probably want to subclass ProtocolSwitchCommand rather than calling + this directly. + """ + self._locked = True + + + def _unlockFromSwitch(self): + """ + Unlock this locked binary protocol so that further boxes may be sent + again. This is used after an attempt to switch protocols has failed + for some reason. + """ + if self.innerProtocol is not None: + raise ProtocolSwitched("Protocol already switched. Cannot unlock.") + self._locked = False + + + def _prepareTLS(self, certificate, verifyAuthorities): + """ + Used by StartTLSCommand to put us into the state where we don't + actually send things that get sent, instead we buffer them. see + L{_sendBoxCommand}. + """ + self._startingTLSBuffer = [] + if self.hostCertificate is not None: + raise OnlyOneTLS( + "Previously authenticated connection between %s and %s " + "is trying to re-establish as %s" % ( + self.hostCertificate, + self.peerCertificate, + (certificate, verifyAuthorities))) + + + def _startTLS(self, certificate, verifyAuthorities): + """ + Used by TLSBox to initiate the SSL handshake. + + @param certificate: a L{twisted.internet.ssl.PrivateCertificate} for + use locally. + + @param verifyAuthorities: L{twisted.internet.ssl.Certificate} instances + representing certificate authorities which will verify our peer. + """ + self.hostCertificate = certificate + self._justStartedTLS = True + if verifyAuthorities is None: + verifyAuthorities = () + self.transport.startTLS(certificate.options(*verifyAuthorities)) + stlsb = self._startingTLSBuffer + if stlsb is not None: + self._startingTLSBuffer = None + for box in stlsb: + self.sendBox(box) + + + def _getPeerCertificate(self): + if self.noPeerCertificate: + return None + return Certificate.peerFromTransport(self.transport) + peerCertificate = property(_getPeerCertificate) + + + def unhandledError(self, failure): + """ + The buck stops here. This error was completely unhandled, time to + terminate the connection. + """ + log.err( + failure, + "Amp server or network failure unhandled by client application. " + "Dropping connection! To avoid, add errbacks to ALL remote " + "commands!") + if self.transport is not None: + self.transport.loseConnection() + + + def _defaultStartTLSResponder(self): + """ + The default TLS responder doesn't specify any certificate or anything. + + From a security perspective, it's little better than a plain-text + connection - but it is still a *bit* better, so it's included for + convenience. + + You probably want to override this by providing your own StartTLS.responder. + """ + return {} + StartTLS.responder(_defaultStartTLSResponder) + + + +class AMP(BinaryBoxProtocol, BoxDispatcher, + CommandLocator, SimpleStringLocator): + """ + This protocol is an AMP connection. See the module docstring for protocol + details. + """ + + _ampInitialized = False + + def __init__(self, boxReceiver=None, locator=None): + # For backwards compatibility. When AMP did not separate parsing logic + # (L{BinaryBoxProtocol}), request-response logic (L{BoxDispatcher}) and + # command routing (L{CommandLocator}), it did not have a constructor. + # Now it does, so old subclasses might have defined their own that did + # not upcall. If this flag isn't set, we'll call the constructor in + # makeConnection before anything actually happens. + self._ampInitialized = True + if boxReceiver is None: + boxReceiver = self + if locator is None: + locator = self + BoxDispatcher.__init__(self, locator) + BinaryBoxProtocol.__init__(self, boxReceiver) + + + def locateResponder(self, name): + """ + Unify the implementations of L{CommandLocator} and + L{SimpleStringLocator} to perform both kinds of dispatch, preferring + L{CommandLocator}. + + @type name: C{bytes} + """ + firstResponder = CommandLocator.locateResponder(self, name) + if firstResponder is not None: + return firstResponder + secondResponder = SimpleStringLocator.locateResponder(self, name) + return secondResponder + + + def __repr__(self): + """ + A verbose string representation which gives us information about this + AMP connection. + """ + if self.innerProtocol is not None: + innerRepr = ' inner %r' % (self.innerProtocol,) + else: + innerRepr = '' + return '<%s%s at 0x%x>' % ( + self.__class__.__name__, innerRepr, id(self)) + + + def makeConnection(self, transport): + """ + Emit a helpful log message when the connection is made. + """ + if not self._ampInitialized: + # See comment in the constructor re: backward compatibility. I + # should probably emit a deprecation warning here. + AMP.__init__(self) + # Save these so we can emit a similar log message in L{connectionLost}. + self._transportPeer = transport.getPeer() + self._transportHost = transport.getHost() + log.msg("%s connection established (HOST:%s PEER:%s)" % ( + self.__class__.__name__, + self._transportHost, + self._transportPeer)) + BinaryBoxProtocol.makeConnection(self, transport) + + + def connectionLost(self, reason): + """ + Emit a helpful log message when the connection is lost. + """ + log.msg("%s connection lost (HOST:%s PEER:%s)" % + (self.__class__.__name__, + self._transportHost, + self._transportPeer)) + BinaryBoxProtocol.connectionLost(self, reason) + self.transport = None + + + +class _ParserHelper: + """ + A box receiver which records all boxes received. + """ + def __init__(self): + self.boxes = [] + + + def getPeer(self): + return 'string' + + + def getHost(self): + return 'string' + + disconnecting = False + + + def startReceivingBoxes(self, sender): + """ + No initialization is required. + """ + + + def ampBoxReceived(self, box): + self.boxes.append(box) + + + # Synchronous helpers + def parse(cls, fileObj): + """ + Parse some amp data stored in a file. + + @param fileObj: a file-like object. + + @return: a list of AmpBoxes encoded in the given file. + """ + parserHelper = cls() + bbp = BinaryBoxProtocol(boxReceiver=parserHelper) + bbp.makeConnection(parserHelper) + bbp.dataReceived(fileObj.read()) + return parserHelper.boxes + parse = classmethod(parse) + + + def parseString(cls, data): + """ + Parse some amp data stored in a string. + + @param data: a str holding some amp-encoded data. + + @return: a list of AmpBoxes encoded in the given string. + """ + return cls.parse(BytesIO(data)) + parseString = classmethod(parseString) + + + +parse = _ParserHelper.parse +parseString = _ParserHelper.parseString + +def _stringsToObjects(strings, arglist, proto): + """ + Convert an AmpBox to a dictionary of python objects, converting through a + given arglist. + + @param strings: an AmpBox (or dict of strings) + + @param arglist: a list of 2-tuples of strings and Argument objects, as + described in L{Command.arguments}. + + @param proto: an L{AMP} instance. + + @return: the converted dictionary mapping names to argument objects. + """ + objects = {} + myStrings = strings.copy() + for argname, argparser in arglist: + argparser.fromBox(argname, myStrings, objects, proto) + return objects + + + +def _objectsToStrings(objects, arglist, strings, proto): + """ + Convert a dictionary of python objects to an AmpBox, converting through a + given arglist. + + @param objects: a dict mapping names to python objects + + @param arglist: a list of 2-tuples of strings and Argument objects, as + described in L{Command.arguments}. + + @param strings: [OUT PARAMETER] An object providing the L{dict} + interface which will be populated with serialized data. + + @param proto: an L{AMP} instance. + + @return: The converted dictionary mapping names to encoded argument + strings (identical to C{strings}). + """ + myObjects = objects.copy() + for argname, argparser in arglist: + argparser.toBox(argname, strings, myObjects, proto) + return strings + + + +class Decimal(Argument): + """ + Encodes C{decimal.Decimal} instances. + + There are several ways in which a decimal value might be encoded. + + Special values are encoded as special strings:: + + - Positive infinity is encoded as C{"Infinity"} + - Negative infinity is encoded as C{"-Infinity"} + - Quiet not-a-number is encoded as either C{"NaN"} or C{"-NaN"} + - Signalling not-a-number is encoded as either C{"sNaN"} or C{"-sNaN"} + + Normal values are encoded using the base ten string representation, using + engineering notation to indicate magnitude without precision, and "normal" + digits to indicate precision. For example:: + + - C{"1"} represents the value I{1} with precision to one place. + - C{"-1"} represents the value I{-1} with precision to one place. + - C{"1.0"} represents the value I{1} with precision to two places. + - C{"10"} represents the value I{10} with precision to two places. + - C{"1E+2"} represents the value I{10} with precision to one place. + - C{"1E-1"} represents the value I{0.1} with precision to one place. + - C{"1.5E+2"} represents the value I{15} with precision to two places. + + U{http://speleotrove.com/decimal/} should be considered the authoritative + specification for the format. + """ + + def fromString(self, inString): + inString = nativeString(inString) + return decimal.Decimal(inString) + + def toString(self, inObject): + """ + Serialize a C{decimal.Decimal} instance to the specified wire format. + """ + if isinstance(inObject, decimal.Decimal): + # Hopefully decimal.Decimal.__str__ actually does what we want. + return str(inObject).encode("ascii") + raise ValueError( + "amp.Decimal can only encode instances of decimal.Decimal") + + + +class DateTime(Argument): + """ + Encodes C{datetime.datetime} instances. + + Wire format: '%04i-%02i-%02iT%02i:%02i:%02i.%06i%s%02i:%02i'. Fields in + order are: year, month, day, hour, minute, second, microsecond, timezone + direction (+ or -), timezone hour, timezone minute. Encoded string is + always exactly 32 characters long. This format is compatible with ISO 8601, + but that does not mean all ISO 8601 dates can be accepted. + + Also, note that the datetime module's notion of a "timezone" can be + complex, but the wire format includes only a fixed offset, so the + conversion is not lossless. A lossless transmission of a C{datetime} instance + is not feasible since the receiving end would require a Python interpreter. + + @ivar _positions: A sequence of slices giving the positions of various + interesting parts of the wire format. + """ + + _positions = [ + slice(0, 4), slice(5, 7), slice(8, 10), # year, month, day + slice(11, 13), slice(14, 16), slice(17, 19), # hour, minute, second + slice(20, 26), # microsecond + # intentionally skip timezone direction, as it is not an integer + slice(27, 29), slice(30, 32) # timezone hour, timezone minute + ] + + def fromString(self, s): + """ + Parse a string containing a date and time in the wire format into a + C{datetime.datetime} instance. + """ + s = nativeString(s) + + if len(s) != 32: + raise ValueError('invalid date format %r' % (s,)) + + values = [int(s[p]) for p in self._positions] + sign = s[26] + timezone = _FixedOffsetTZInfo.fromSignHoursMinutes(sign, *values[7:]) + values[7:] = [timezone] + return datetime.datetime(*values) + + + def toString(self, i): + """ + Serialize a C{datetime.datetime} instance to a string in the specified + wire format. + """ + offset = i.utcoffset() + if offset is None: + raise ValueError( + 'amp.DateTime cannot serialize naive datetime instances. ' + 'You may find amp.utc useful.') + + minutesOffset = (offset.days * 86400 + offset.seconds) // 60 + + if minutesOffset > 0: + sign = '+' + else: + sign = '-' + + # strftime has no way to format the microseconds, or put a ':' in the + # timezone. Surprise! + + # Python 3.4 cannot do % interpolation on byte strings so we pack into + # an explicitly Unicode string then encode as ASCII. + packed = u'%04i-%02i-%02iT%02i:%02i:%02i.%06i%s%02i:%02i' % ( + i.year, + i.month, + i.day, + i.hour, + i.minute, + i.second, + i.microsecond, + sign, + abs(minutesOffset) // 60, + abs(minutesOffset) % 60) + + return packed.encode("ascii") diff --git a/contrib/python/Twisted/py2/twisted/protocols/basic.py b/contrib/python/Twisted/py2/twisted/protocols/basic.py new file mode 100644 index 0000000000..adecfd30ce --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/basic.py @@ -0,0 +1,953 @@ +# -*- test-case-name: twisted.protocols.test.test_basic -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + + +""" +Basic protocols, such as line-oriented, netstring, and int prefixed strings. +""" + +from __future__ import absolute_import, division + +# System imports +import re +from struct import pack, unpack, calcsize +from io import BytesIO +import math + +from zope.interface import implementer + +# Twisted imports +from twisted.python.compat import _PY3 +from twisted.internet import protocol, defer, interfaces +from twisted.python import log + + +# Unfortunately we cannot use regular string formatting on Python 3; see +# http://bugs.python.org/issue3982 for details. +if _PY3: + def _formatNetstring(data): + return b''.join([str(len(data)).encode("ascii"), b':', data, b',']) +else: + def _formatNetstring(data): + return b'%d:%s,' % (len(data), data) +_formatNetstring.__doc__ = """ +Convert some C{bytes} into netstring format. + +@param data: C{bytes} that will be reformatted. +""" + + + +DEBUG = 0 + +class NetstringParseError(ValueError): + """ + The incoming data is not in valid Netstring format. + """ + + + +class IncompleteNetstring(Exception): + """ + Not enough data to complete a netstring. + """ + + +class NetstringReceiver(protocol.Protocol): + """ + A protocol that sends and receives netstrings. + + See U{http://cr.yp.to/proto/netstrings.txt} for the specification of + netstrings. Every netstring starts with digits that specify the length + of the data. This length specification is separated from the data by + a colon. The data is terminated with a comma. + + Override L{stringReceived} to handle received netstrings. This + method is called with the netstring payload as a single argument + whenever a complete netstring is received. + + Security features: + 1. Messages are limited in size, useful if you don't want + someone sending you a 500MB netstring (change C{self.MAX_LENGTH} + to the maximum length you wish to accept). + 2. The connection is lost if an illegal message is received. + + @ivar MAX_LENGTH: Defines the maximum length of netstrings that can be + received. + @type MAX_LENGTH: C{int} + + @ivar _LENGTH: A pattern describing all strings that contain a netstring + length specification. Examples for length specifications are C{b'0:'}, + C{b'12:'}, and C{b'179:'}. C{b'007:'} is not a valid length + specification, since leading zeros are not allowed. + @type _LENGTH: C{re.Match} + + @ivar _LENGTH_PREFIX: A pattern describing all strings that contain + the first part of a netstring length specification (without the + trailing comma). Examples are '0', '12', and '179'. '007' does not + start a netstring length specification, since leading zeros are + not allowed. + @type _LENGTH_PREFIX: C{re.Match} + + @ivar _PARSING_LENGTH: Indicates that the C{NetstringReceiver} is in + the state of parsing the length portion of a netstring. + @type _PARSING_LENGTH: C{int} + + @ivar _PARSING_PAYLOAD: Indicates that the C{NetstringReceiver} is in + the state of parsing the payload portion (data and trailing comma) + of a netstring. + @type _PARSING_PAYLOAD: C{int} + + @ivar brokenPeer: Indicates if the connection is still functional + @type brokenPeer: C{int} + + @ivar _state: Indicates if the protocol is consuming the length portion + (C{PARSING_LENGTH}) or the payload (C{PARSING_PAYLOAD}) of a netstring + @type _state: C{int} + + @ivar _remainingData: Holds the chunk of data that has not yet been consumed + @type _remainingData: C{string} + + @ivar _payload: Holds the payload portion of a netstring including the + trailing comma + @type _payload: C{BytesIO} + + @ivar _expectedPayloadSize: Holds the payload size plus one for the trailing + comma. + @type _expectedPayloadSize: C{int} + """ + MAX_LENGTH = 99999 + _LENGTH = re.compile(br'(0|[1-9]\d*)(:)') + + _LENGTH_PREFIX = re.compile(br'(0|[1-9]\d*)$') + + # Some error information for NetstringParseError instances. + _MISSING_LENGTH = ("The received netstring does not start with a " + "length specification.") + _OVERFLOW = ("The length specification of the received netstring " + "cannot be represented in Python - it causes an " + "OverflowError!") + _TOO_LONG = ("The received netstring is longer than the maximum %s " + "specified by self.MAX_LENGTH") + _MISSING_COMMA = "The received netstring is not terminated by a comma." + + # The following constants are used for determining if the NetstringReceiver + # is parsing the length portion of a netstring, or the payload. + _PARSING_LENGTH, _PARSING_PAYLOAD = range(2) + + def makeConnection(self, transport): + """ + Initializes the protocol. + """ + protocol.Protocol.makeConnection(self, transport) + self._remainingData = b"" + self._currentPayloadSize = 0 + self._payload = BytesIO() + self._state = self._PARSING_LENGTH + self._expectedPayloadSize = 0 + self.brokenPeer = 0 + + + def sendString(self, string): + """ + Sends a netstring. + + Wraps up C{string} by adding length information and a + trailing comma; writes the result to the transport. + + @param string: The string to send. The necessary framing (length + prefix, etc) will be added. + @type string: C{bytes} + """ + self.transport.write(_formatNetstring(string)) + + + def dataReceived(self, data): + """ + Receives some characters of a netstring. + + Whenever a complete netstring is received, this method extracts + its payload and calls L{stringReceived} to process it. + + @param data: A chunk of data representing a (possibly partial) + netstring + @type data: C{bytes} + """ + self._remainingData += data + while self._remainingData: + try: + self._consumeData() + except IncompleteNetstring: + break + except NetstringParseError: + self._handleParseError() + break + + + def stringReceived(self, string): + """ + Override this for notification when each complete string is received. + + @param string: The complete string which was received with all + framing (length prefix, etc) removed. + @type string: C{bytes} + + @raise NotImplementedError: because the method has to be implemented + by the child class. + """ + raise NotImplementedError() + + + def _maxLengthSize(self): + """ + Calculate and return the string size of C{self.MAX_LENGTH}. + + @return: The size of the string representation for C{self.MAX_LENGTH} + @rtype: C{float} + """ + return math.ceil(math.log10(self.MAX_LENGTH)) + 1 + + + def _consumeData(self): + """ + Consumes the content of C{self._remainingData}. + + @raise IncompleteNetstring: if C{self._remainingData} does not + contain enough data to complete the current netstring. + @raise NetstringParseError: if the received data do not + form a valid netstring. + """ + if self._state == self._PARSING_LENGTH: + self._consumeLength() + self._prepareForPayloadConsumption() + if self._state == self._PARSING_PAYLOAD: + self._consumePayload() + + + def _consumeLength(self): + """ + Consumes the length portion of C{self._remainingData}. + + @raise IncompleteNetstring: if C{self._remainingData} contains + a partial length specification (digits without trailing + comma). + @raise NetstringParseError: if the received data do not form a valid + netstring. + """ + lengthMatch = self._LENGTH.match(self._remainingData) + if not lengthMatch: + self._checkPartialLengthSpecification() + raise IncompleteNetstring() + self._processLength(lengthMatch) + + + def _checkPartialLengthSpecification(self): + """ + Makes sure that the received data represents a valid number. + + Checks if C{self._remainingData} represents a number smaller or + equal to C{self.MAX_LENGTH}. + + @raise NetstringParseError: if C{self._remainingData} is no + number or is too big (checked by L{_extractLength}). + """ + partialLengthMatch = self._LENGTH_PREFIX.match(self._remainingData) + if not partialLengthMatch: + raise NetstringParseError(self._MISSING_LENGTH) + lengthSpecification = (partialLengthMatch.group(1)) + self._extractLength(lengthSpecification) + + + def _processLength(self, lengthMatch): + """ + Processes the length definition of a netstring. + + Extracts and stores in C{self._expectedPayloadSize} the number + representing the netstring size. Removes the prefix + representing the length specification from + C{self._remainingData}. + + @raise NetstringParseError: if the received netstring does not + start with a number or the number is bigger than + C{self.MAX_LENGTH}. + @param lengthMatch: A regular expression match object matching + a netstring length specification + @type lengthMatch: C{re.Match} + """ + endOfNumber = lengthMatch.end(1) + startOfData = lengthMatch.end(2) + lengthString = self._remainingData[:endOfNumber] + # Expect payload plus trailing comma: + self._expectedPayloadSize = self._extractLength(lengthString) + 1 + self._remainingData = self._remainingData[startOfData:] + + + def _extractLength(self, lengthAsString): + """ + Attempts to extract the length information of a netstring. + + @raise NetstringParseError: if the number is bigger than + C{self.MAX_LENGTH}. + @param lengthAsString: A chunk of data starting with a length + specification + @type lengthAsString: C{bytes} + @return: The length of the netstring + @rtype: C{int} + """ + self._checkStringSize(lengthAsString) + length = int(lengthAsString) + if length > self.MAX_LENGTH: + raise NetstringParseError(self._TOO_LONG % (self.MAX_LENGTH,)) + return length + + + def _checkStringSize(self, lengthAsString): + """ + Checks the sanity of lengthAsString. + + Checks if the size of the length specification exceeds the + size of the string representing self.MAX_LENGTH. If this is + not the case, the number represented by lengthAsString is + certainly bigger than self.MAX_LENGTH, and a + NetstringParseError can be raised. + + This method should make sure that netstrings with extremely + long length specifications are refused before even attempting + to convert them to an integer (which might trigger a + MemoryError). + """ + if len(lengthAsString) > self._maxLengthSize(): + raise NetstringParseError(self._TOO_LONG % (self.MAX_LENGTH,)) + + + def _prepareForPayloadConsumption(self): + """ + Sets up variables necessary for consuming the payload of a netstring. + """ + self._state = self._PARSING_PAYLOAD + self._currentPayloadSize = 0 + self._payload.seek(0) + self._payload.truncate() + + + def _consumePayload(self): + """ + Consumes the payload portion of C{self._remainingData}. + + If the payload is complete, checks for the trailing comma and + processes the payload. If not, raises an L{IncompleteNetstring} + exception. + + @raise IncompleteNetstring: if the payload received so far + contains fewer characters than expected. + @raise NetstringParseError: if the payload does not end with a + comma. + """ + self._extractPayload() + if self._currentPayloadSize < self._expectedPayloadSize: + raise IncompleteNetstring() + self._checkForTrailingComma() + self._state = self._PARSING_LENGTH + self._processPayload() + + + def _extractPayload(self): + """ + Extracts payload information from C{self._remainingData}. + + Splits C{self._remainingData} at the end of the netstring. The + first part becomes C{self._payload}, the second part is stored + in C{self._remainingData}. + + If the netstring is not yet complete, the whole content of + C{self._remainingData} is moved to C{self._payload}. + """ + if self._payloadComplete(): + remainingPayloadSize = (self._expectedPayloadSize - + self._currentPayloadSize) + self._payload.write(self._remainingData[:remainingPayloadSize]) + self._remainingData = self._remainingData[remainingPayloadSize:] + self._currentPayloadSize = self._expectedPayloadSize + else: + self._payload.write(self._remainingData) + self._currentPayloadSize += len(self._remainingData) + self._remainingData = b"" + + + def _payloadComplete(self): + """ + Checks if enough data have been received to complete the netstring. + + @return: C{True} iff the received data contain at least as many + characters as specified in the length section of the + netstring + @rtype: C{bool} + """ + return (len(self._remainingData) + self._currentPayloadSize >= + self._expectedPayloadSize) + + + def _processPayload(self): + """ + Processes the actual payload with L{stringReceived}. + + Strips C{self._payload} of the trailing comma and calls + L{stringReceived} with the result. + """ + self.stringReceived(self._payload.getvalue()[:-1]) + + + def _checkForTrailingComma(self): + """ + Checks if the netstring has a trailing comma at the expected position. + + @raise NetstringParseError: if the last payload character is + anything but a comma. + """ + if self._payload.getvalue()[-1:] != b",": + raise NetstringParseError(self._MISSING_COMMA) + + + def _handleParseError(self): + """ + Terminates the connection and sets the flag C{self.brokenPeer}. + """ + self.transport.loseConnection() + self.brokenPeer = 1 + + + +class LineOnlyReceiver(protocol.Protocol): + """ + A protocol that receives only lines. + + This is purely a speed optimisation over LineReceiver, for the + cases that raw mode is known to be unnecessary. + + @cvar delimiter: The line-ending delimiter to use. By default this is + C{b'\\r\\n'}. + @cvar MAX_LENGTH: The maximum length of a line to allow (If a + sent line is longer than this, the connection is dropped). + Default is 16384. + """ + _buffer = b'' + delimiter = b'\r\n' + MAX_LENGTH = 16384 + + def dataReceived(self, data): + """ + Translates bytes into lines, and calls lineReceived. + """ + lines = (self._buffer+data).split(self.delimiter) + self._buffer = lines.pop(-1) + for line in lines: + if self.transport.disconnecting: + # this is necessary because the transport may be told to lose + # the connection by a line within a larger packet, and it is + # important to disregard all the lines in that packet following + # the one that told it to close. + return + if len(line) > self.MAX_LENGTH: + return self.lineLengthExceeded(line) + else: + self.lineReceived(line) + if len(self._buffer) > self.MAX_LENGTH: + return self.lineLengthExceeded(self._buffer) + + + def lineReceived(self, line): + """ + Override this for when each line is received. + + @param line: The line which was received with the delimiter removed. + @type line: C{bytes} + """ + raise NotImplementedError + + + def sendLine(self, line): + """ + Sends a line to the other end of the connection. + + @param line: The line to send, not including the delimiter. + @type line: C{bytes} + """ + return self.transport.writeSequence((line, self.delimiter)) + + + def lineLengthExceeded(self, line): + """ + Called when the maximum line length has been reached. + Override if it needs to be dealt with in some special way. + """ + return self.transport.loseConnection() + + + +class _PauseableMixin: + paused = False + + def pauseProducing(self): + self.paused = True + self.transport.pauseProducing() + + + def resumeProducing(self): + self.paused = False + self.transport.resumeProducing() + self.dataReceived(b'') + + + def stopProducing(self): + self.paused = True + self.transport.stopProducing() + + + +class LineReceiver(protocol.Protocol, _PauseableMixin): + """ + A protocol that receives lines and/or raw data, depending on mode. + + In line mode, each line that's received becomes a callback to + L{lineReceived}. In raw data mode, each chunk of raw data becomes a + callback to L{LineReceiver.rawDataReceived}. + The L{setLineMode} and L{setRawMode} methods switch between the two modes. + + This is useful for line-oriented protocols such as IRC, HTTP, POP, etc. + + @cvar delimiter: The line-ending delimiter to use. By default this is + C{b'\\r\\n'}. + @cvar MAX_LENGTH: The maximum length of a line to allow (If a + sent line is longer than this, the connection is dropped). + Default is 16384. + """ + line_mode = 1 + _buffer = b'' + _busyReceiving = False + delimiter = b'\r\n' + MAX_LENGTH = 16384 + + def clearLineBuffer(self): + """ + Clear buffered data. + + @return: All of the cleared buffered data. + @rtype: C{bytes} + """ + b, self._buffer = self._buffer, b"" + return b + + + def dataReceived(self, data): + """ + Protocol.dataReceived. + Translates bytes into lines, and calls lineReceived (or + rawDataReceived, depending on mode.) + """ + if self._busyReceiving: + self._buffer += data + return + + try: + self._busyReceiving = True + self._buffer += data + while self._buffer and not self.paused: + if self.line_mode: + try: + line, self._buffer = self._buffer.split( + self.delimiter, 1) + except ValueError: + if len(self._buffer) >= (self.MAX_LENGTH + + len(self.delimiter)): + line, self._buffer = self._buffer, b'' + return self.lineLengthExceeded(line) + return + else: + lineLength = len(line) + if lineLength > self.MAX_LENGTH: + exceeded = line + self.delimiter + self._buffer + self._buffer = b'' + return self.lineLengthExceeded(exceeded) + why = self.lineReceived(line) + if (why or self.transport and + self.transport.disconnecting): + return why + else: + data = self._buffer + self._buffer = b'' + why = self.rawDataReceived(data) + if why: + return why + finally: + self._busyReceiving = False + + + def setLineMode(self, extra=b''): + """ + Sets the line-mode of this receiver. + + If you are calling this from a rawDataReceived callback, + you can pass in extra unhandled data, and that data will + be parsed for lines. Further data received will be sent + to lineReceived rather than rawDataReceived. + + Do not pass extra data if calling this function from + within a lineReceived callback. + """ + self.line_mode = 1 + if extra: + return self.dataReceived(extra) + + + def setRawMode(self): + """ + Sets the raw mode of this receiver. + Further data received will be sent to rawDataReceived rather + than lineReceived. + """ + self.line_mode = 0 + + + def rawDataReceived(self, data): + """ + Override this for when raw data is received. + """ + raise NotImplementedError + + + def lineReceived(self, line): + """ + Override this for when each line is received. + + @param line: The line which was received with the delimiter removed. + @type line: C{bytes} + """ + raise NotImplementedError + + + def sendLine(self, line): + """ + Sends a line to the other end of the connection. + + @param line: The line to send, not including the delimiter. + @type line: C{bytes} + """ + return self.transport.write(line + self.delimiter) + + + def lineLengthExceeded(self, line): + """ + Called when the maximum line length has been reached. + Override if it needs to be dealt with in some special way. + + The argument 'line' contains the remainder of the buffer, starting + with (at least some part) of the line which is too long. This may + be more than one line, or may be only the initial portion of the + line. + """ + return self.transport.loseConnection() + + + +class StringTooLongError(AssertionError): + """ + Raised when trying to send a string too long for a length prefixed + protocol. + """ + + + +class _RecvdCompatHack(object): + """ + Emulates the to-be-deprecated C{IntNStringReceiver.recvd} attribute. + + The C{recvd} attribute was where the working buffer for buffering and + parsing netstrings was kept. It was updated each time new data arrived and + each time some of that data was parsed and delivered to application code. + The piecemeal updates to its string value were expensive and have been + removed from C{IntNStringReceiver} in the normal case. However, for + applications directly reading this attribute, this descriptor restores that + behavior. It only copies the working buffer when necessary (ie, when + accessed). This avoids the cost for applications not using the data. + + This is a custom descriptor rather than a property, because we still need + the default __set__ behavior in both new-style and old-style subclasses. + """ + def __get__(self, oself, type=None): + return oself._unprocessed[oself._compatibilityOffset:] + + + +class IntNStringReceiver(protocol.Protocol, _PauseableMixin): + """ + Generic class for length prefixed protocols. + + @ivar _unprocessed: bytes received, but not yet broken up into messages / + sent to stringReceived. _compatibilityOffset must be updated when this + value is updated so that the C{recvd} attribute can be generated + correctly. + @type _unprocessed: C{bytes} + + @ivar structFormat: format used for struct packing/unpacking. Define it in + subclass. + @type structFormat: C{str} + + @ivar prefixLength: length of the prefix, in bytes. Define it in subclass, + using C{struct.calcsize(structFormat)} + @type prefixLength: C{int} + + @ivar _compatibilityOffset: the offset within C{_unprocessed} to the next + message to be parsed. (used to generate the recvd attribute) + @type _compatibilityOffset: C{int} + """ + + MAX_LENGTH = 99999 + _unprocessed = b"" + _compatibilityOffset = 0 + + # Backwards compatibility support for applications which directly touch the + # "internal" parse buffer. + recvd = _RecvdCompatHack() + + def stringReceived(self, string): + """ + Override this for notification when each complete string is received. + + @param string: The complete string which was received with all + framing (length prefix, etc) removed. + @type string: C{bytes} + """ + raise NotImplementedError + + + def lengthLimitExceeded(self, length): + """ + Callback invoked when a length prefix greater than C{MAX_LENGTH} is + received. The default implementation disconnects the transport. + Override this. + + @param length: The length prefix which was received. + @type length: C{int} + """ + self.transport.loseConnection() + + + def dataReceived(self, data): + """ + Convert int prefixed strings into calls to stringReceived. + """ + # Try to minimize string copying (via slices) by keeping one buffer + # containing all the data we have so far and a separate offset into that + # buffer. + alldata = self._unprocessed + data + currentOffset = 0 + prefixLength = self.prefixLength + fmt = self.structFormat + self._unprocessed = alldata + + while len(alldata) >= (currentOffset + prefixLength) and not self.paused: + messageStart = currentOffset + prefixLength + length, = unpack(fmt, alldata[currentOffset:messageStart]) + if length > self.MAX_LENGTH: + self._unprocessed = alldata + self._compatibilityOffset = currentOffset + self.lengthLimitExceeded(length) + return + messageEnd = messageStart + length + if len(alldata) < messageEnd: + break + + # Here we have to slice the working buffer so we can send just the + # netstring into the stringReceived callback. + packet = alldata[messageStart:messageEnd] + currentOffset = messageEnd + self._compatibilityOffset = currentOffset + self.stringReceived(packet) + + # Check to see if the backwards compat "recvd" attribute got written + # to by application code. If so, drop the current data buffer and + # switch to the new buffer given by that attribute's value. + if 'recvd' in self.__dict__: + alldata = self.__dict__.pop('recvd') + self._unprocessed = alldata + self._compatibilityOffset = currentOffset = 0 + if alldata: + continue + return + + # Slice off all the data that has been processed, avoiding holding onto + # memory to store it, and update the compatibility attributes to reflect + # that change. + self._unprocessed = alldata[currentOffset:] + self._compatibilityOffset = 0 + + + def sendString(self, string): + """ + Send a prefixed string to the other end of the connection. + + @param string: The string to send. The necessary framing (length + prefix, etc) will be added. + @type string: C{bytes} + """ + if len(string) >= 2 ** (8 * self.prefixLength): + raise StringTooLongError( + "Try to send %s bytes whereas maximum is %s" % ( + len(string), 2 ** (8 * self.prefixLength))) + self.transport.write( + pack(self.structFormat, len(string)) + string) + + + +class Int32StringReceiver(IntNStringReceiver): + """ + A receiver for int32-prefixed strings. + + An int32 string is a string prefixed by 4 bytes, the 32-bit length of + the string encoded in network byte order. + + This class publishes the same interface as NetstringReceiver. + """ + structFormat = "!I" + prefixLength = calcsize(structFormat) + + + +class Int16StringReceiver(IntNStringReceiver): + """ + A receiver for int16-prefixed strings. + + An int16 string is a string prefixed by 2 bytes, the 16-bit length of + the string encoded in network byte order. + + This class publishes the same interface as NetstringReceiver. + """ + structFormat = "!H" + prefixLength = calcsize(structFormat) + + + +class Int8StringReceiver(IntNStringReceiver): + """ + A receiver for int8-prefixed strings. + + An int8 string is a string prefixed by 1 byte, the 8-bit length of + the string. + + This class publishes the same interface as NetstringReceiver. + """ + structFormat = "!B" + prefixLength = calcsize(structFormat) + + + +class StatefulStringProtocol: + """ + A stateful string protocol. + + This is a mixin for string protocols (L{Int32StringReceiver}, + L{NetstringReceiver}) which translates L{stringReceived} into a callback + (prefixed with C{'proto_'}) depending on state. + + The state C{'done'} is special; if a C{proto_*} method returns it, the + connection will be closed immediately. + + @ivar state: Current state of the protocol. Defaults to C{'init'}. + @type state: C{str} + """ + + state = 'init' + + def stringReceived(self, string): + """ + Choose a protocol phase function and call it. + + Call back to the appropriate protocol phase; this begins with + the function C{proto_init} and moves on to C{proto_*} depending on + what each C{proto_*} function returns. (For example, if + C{self.proto_init} returns 'foo', then C{self.proto_foo} will be the + next function called when a protocol message is received. + """ + try: + pto = 'proto_' + self.state + statehandler = getattr(self, pto) + except AttributeError: + log.msg('callback', self.state, 'not found') + else: + self.state = statehandler(string) + if self.state == 'done': + self.transport.loseConnection() + + + +@implementer(interfaces.IProducer) +class FileSender: + """ + A producer that sends the contents of a file to a consumer. + + This is a helper for protocols that, at some point, will take a + file-like object, read its contents, and write them out to the network, + optionally performing some transformation on the bytes in between. + """ + + CHUNK_SIZE = 2 ** 14 + + lastSent = '' + deferred = None + + def beginFileTransfer(self, file, consumer, transform=None): + """ + Begin transferring a file + + @type file: Any file-like object + @param file: The file object to read data from + + @type consumer: Any implementor of IConsumer + @param consumer: The object to write data to + + @param transform: A callable taking one string argument and returning + the same. All bytes read from the file are passed through this before + being written to the consumer. + + @rtype: C{Deferred} + @return: A deferred whose callback will be invoked when the file has + been completely written to the consumer. The last byte written to the + consumer is passed to the callback. + """ + self.file = file + self.consumer = consumer + self.transform = transform + + self.deferred = deferred = defer.Deferred() + self.consumer.registerProducer(self, False) + return deferred + + + def resumeProducing(self): + chunk = '' + if self.file: + chunk = self.file.read(self.CHUNK_SIZE) + if not chunk: + self.file = None + self.consumer.unregisterProducer() + if self.deferred: + self.deferred.callback(self.lastSent) + self.deferred = None + return + + if self.transform: + chunk = self.transform(chunk) + self.consumer.write(chunk) + self.lastSent = chunk[-1:] + + + def pauseProducing(self): + pass + + + def stopProducing(self): + if self.deferred: + self.deferred.errback( + Exception("Consumer asked us to stop producing")) + self.deferred = None diff --git a/contrib/python/Twisted/py2/twisted/protocols/dict.py b/contrib/python/Twisted/py2/twisted/protocols/dict.py new file mode 100644 index 0000000000..d7976411bc --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/dict.py @@ -0,0 +1,415 @@ +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + + +""" +Dict client protocol implementation. + +@author: Pavel Pergamenshchik +""" + +from twisted.protocols import basic +from twisted.internet import defer, protocol +from twisted.python import log +from io import BytesIO + +def parseParam(line): + """Chew one dqstring or atom from beginning of line and return (param, remaningline)""" + if line == b'': + return (None, b'') + elif line[0:1] != b'"': # atom + mode = 1 + else: # dqstring + mode = 2 + res = b"" + io = BytesIO(line) + if mode == 2: # skip the opening quote + io.read(1) + while 1: + a = io.read(1) + if a == b'"': + if mode == 2: + io.read(1) # skip the separating space + return (res, io.read()) + elif a == b'\\': + a = io.read(1) + if a == b'': + return (None, line) # unexpected end of string + elif a == b'': + if mode == 1: + return (res, io.read()) + else: + return (None, line) # unexpected end of string + elif a == b' ': + if mode == 1: + return (res, io.read()) + res += a + + + +def makeAtom(line): + """Munch a string into an 'atom'""" + # FIXME: proper quoting + return filter(lambda x: not (x in map(chr, range(33)+[34, 39, 92])), line) + + + +def makeWord(s): + mustquote = range(33)+[34, 39, 92] + result = [] + for c in s: + if ord(c) in mustquote: + result.append(b"\\") + result.append(c) + s = b"".join(result) + return s + + + +def parseText(line): + if len(line) == 1 and line == b'.': + return None + else: + if len(line) > 1 and line[0:2] == b'..': + line = line[1:] + return line + + + +class Definition: + """A word definition""" + def __init__(self, name, db, dbdesc, text): + self.name = name + self.db = db + self.dbdesc = dbdesc + self.text = text # list of strings not terminated by newline + + + +class DictClient(basic.LineReceiver): + """dict (RFC2229) client""" + + data = None # multiline data + MAX_LENGTH = 1024 + state = None + mode = None + result = None + factory = None + + def __init__(self): + self.data = None + self.result = None + + + def connectionMade(self): + self.state = "conn" + self.mode = "command" + + + def sendLine(self, line): + """Throw up if the line is longer than 1022 characters""" + if len(line) > self.MAX_LENGTH - 2: + raise ValueError("DictClient tried to send a too long line") + basic.LineReceiver.sendLine(self, line) + + + def lineReceived(self, line): + try: + line = line.decode("utf-8") + except UnicodeError: # garbage received, skip + return + if self.mode == "text": # we are receiving textual data + code = "text" + else: + if len(line) < 4: + log.msg("DictClient got invalid line from server -- %s" % line) + self.protocolError("Invalid line from server") + self.transport.LoseConnection() + return + code = int(line[:3]) + line = line[4:] + method = getattr(self, 'dictCode_%s_%s' % (code, self.state), self.dictCode_default) + method(line) + + + def dictCode_default(self, line): + """Unknown message""" + log.msg("DictClient got unexpected message from server -- %s" % line) + self.protocolError("Unexpected server message") + self.transport.loseConnection() + + + def dictCode_221_ready(self, line): + """We are about to get kicked off, do nothing""" + pass + + + def dictCode_220_conn(self, line): + """Greeting message""" + self.state = "ready" + self.dictConnected() + + + def dictCode_530_conn(self): + self.protocolError("Access denied") + self.transport.loseConnection() + + + def dictCode_420_conn(self): + self.protocolError("Server temporarily unavailable") + self.transport.loseConnection() + + + def dictCode_421_conn(self): + self.protocolError("Server shutting down at operator request") + self.transport.loseConnection() + + + def sendDefine(self, database, word): + """Send a dict DEFINE command""" + assert self.state == "ready", "DictClient.sendDefine called when not in ready state" + self.result = None # these two are just in case. In "ready" state, result and data + self.data = None # should be None + self.state = "define" + command = "DEFINE %s %s" % (makeAtom(database.encode("UTF-8")), makeWord(word.encode("UTF-8"))) + self.sendLine(command) + + + def sendMatch(self, database, strategy, word): + """Send a dict MATCH command""" + assert self.state == "ready", "DictClient.sendMatch called when not in ready state" + self.result = None + self.data = None + self.state = "match" + command = "MATCH %s %s %s" % (makeAtom(database), makeAtom(strategy), makeAtom(word)) + self.sendLine(command.encode("UTF-8")) + + def dictCode_550_define(self, line): + """Invalid database""" + self.mode = "ready" + self.defineFailed("Invalid database") + + + def dictCode_550_match(self, line): + """Invalid database""" + self.mode = "ready" + self.matchFailed("Invalid database") + + + def dictCode_551_match(self, line): + """Invalid strategy""" + self.mode = "ready" + self.matchFailed("Invalid strategy") + + + def dictCode_552_define(self, line): + """No match""" + self.mode = "ready" + self.defineFailed("No match") + + + def dictCode_552_match(self, line): + """No match""" + self.mode = "ready" + self.matchFailed("No match") + + + def dictCode_150_define(self, line): + """n definitions retrieved""" + self.result = [] + + + def dictCode_151_define(self, line): + """Definition text follows""" + self.mode = "text" + (word, line) = parseParam(line) + (db, line) = parseParam(line) + (dbdesc, line) = parseParam(line) + if not (word and db and dbdesc): + self.protocolError("Invalid server response") + self.transport.loseConnection() + else: + self.result.append(Definition(word, db, dbdesc, [])) + self.data = [] + + + def dictCode_152_match(self, line): + """n matches found, text follows""" + self.mode = "text" + self.result = [] + self.data = [] + + + def dictCode_text_define(self, line): + """A line of definition text received""" + res = parseText(line) + if res == None: + self.mode = "command" + self.result[-1].text = self.data + self.data = None + else: + self.data.append(line) + + + def dictCode_text_match(self, line): + """One line of match text received""" + def l(s): + p1, t = parseParam(s) + p2, t = parseParam(t) + return (p1, p2) + res = parseText(line) + if res == None: + self.mode = "command" + self.result = map(l, self.data) + self.data = None + else: + self.data.append(line) + + + def dictCode_250_define(self, line): + """ok""" + t = self.result + self.result = None + self.state = "ready" + self.defineDone(t) + + + def dictCode_250_match(self, line): + """ok""" + t = self.result + self.result = None + self.state = "ready" + self.matchDone(t) + + + def protocolError(self, reason): + """override to catch unexpected dict protocol conditions""" + pass + + + def dictConnected(self): + """override to be notified when the server is ready to accept commands""" + pass + + + def defineFailed(self, reason): + """override to catch reasonable failure responses to DEFINE""" + pass + + + def defineDone(self, result): + """override to catch successful DEFINE""" + pass + + + def matchFailed(self, reason): + """override to catch resonable failure responses to MATCH""" + pass + + + def matchDone(self, result): + """override to catch successful MATCH""" + pass + + + +class InvalidResponse(Exception): + pass + + + +class DictLookup(DictClient): + """Utility class for a single dict transaction. To be used with DictLookupFactory""" + + def protocolError(self, reason): + if not self.factory.done: + self.factory.d.errback(InvalidResponse(reason)) + self.factory.clientDone() + + + def dictConnected(self): + if self.factory.queryType == "define": + self.sendDefine(*self.factory.param) + elif self.factory.queryType == "match": + self.sendMatch(*self.factory.param) + + + def defineFailed(self, reason): + self.factory.d.callback([]) + self.factory.clientDone() + self.transport.loseConnection() + + + def defineDone(self, result): + self.factory.d.callback(result) + self.factory.clientDone() + self.transport.loseConnection() + + + def matchFailed(self, reason): + self.factory.d.callback([]) + self.factory.clientDone() + self.transport.loseConnection() + + + def matchDone(self, result): + self.factory.d.callback(result) + self.factory.clientDone() + self.transport.loseConnection() + + + +class DictLookupFactory(protocol.ClientFactory): + """Utility factory for a single dict transaction""" + protocol = DictLookup + done = None + + def __init__(self, queryType, param, d): + self.queryType = queryType + self.param = param + self.d = d + self.done = 0 + + + def clientDone(self): + """Called by client when done.""" + self.done = 1 + del self.d + + + def clientConnectionFailed(self, connector, error): + self.d.errback(error) + + + def clientConnectionLost(self, connector, error): + if not self.done: + self.d.errback(error) + + + def buildProtocol(self, addr): + p = self.protocol() + p.factory = self + return p + + + +def define(host, port, database, word): + """Look up a word using a dict server""" + d = defer.Deferred() + factory = DictLookupFactory("define", (database, word), d) + + from twisted.internet import reactor + reactor.connectTCP(host, port, factory) + return d + + + +def match(host, port, database, strategy, word): + """Match a word using a dict server""" + d = defer.Deferred() + factory = DictLookupFactory("match", (database, strategy, word), d) + + from twisted.internet import reactor + reactor.connectTCP(host, port, factory) + return d + diff --git a/contrib/python/Twisted/py2/twisted/protocols/finger.py b/contrib/python/Twisted/py2/twisted/protocols/finger.py new file mode 100644 index 0000000000..101f29b4f0 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/finger.py @@ -0,0 +1,42 @@ +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + + +"""The Finger User Information Protocol (RFC 1288)""" + +from twisted.protocols import basic + +class Finger(basic.LineReceiver): + + def lineReceived(self, line): + parts = line.split() + if not parts: + parts = [b''] + if len(parts) == 1: + slash_w = 0 + else: + slash_w = 1 + user = parts[-1] + if b'@' in user: + hostPlace = user.rfind(b'@') + user = user[:hostPlace] + host = user[hostPlace+1:] + return self.forwardQuery(slash_w, user, host) + if user: + return self.getUser(slash_w, user) + else: + return self.getDomain(slash_w) + + def _refuseMessage(self, message): + self.transport.write(message + b"\n") + self.transport.loseConnection() + + def forwardQuery(self, slash_w, user, host): + self._refuseMessage(b'Finger forwarding service denied') + + def getDomain(self, slash_w): + self._refuseMessage(b'Finger online list denied') + + def getUser(self, slash_w, user): + self.transport.write(b'Login: ' + user + b'\n') + self._refuseMessage(b'No such user') diff --git a/contrib/python/Twisted/py2/twisted/protocols/ftp.py b/contrib/python/Twisted/py2/twisted/protocols/ftp.py new file mode 100644 index 0000000000..0c7171a070 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/ftp.py @@ -0,0 +1,3374 @@ +# -*- test-case-name: twisted.test.test_ftp -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +An FTP protocol implementation +""" + +# System Imports +import os +import time +import re +import stat +import errno +import fnmatch + +try: + import pwd + import grp +except ImportError: + pwd = grp = None + +from zope.interface import Interface, implementer + +# Twisted Imports +from twisted import copyright +from twisted.internet import reactor, interfaces, protocol, error, defer +from twisted.protocols import basic, policies + +from twisted.python import log, failure, filepath +from twisted.python.compat import range, unicode +from twisted.cred import error as cred_error, portal, credentials, checkers + +# constants +# response codes + +RESTART_MARKER_REPLY = "100" +SERVICE_READY_IN_N_MINUTES = "120" +DATA_CNX_ALREADY_OPEN_START_XFR = "125" +FILE_STATUS_OK_OPEN_DATA_CNX = "150" + +CMD_OK = "200.1" +TYPE_SET_OK = "200.2" +ENTERING_PORT_MODE = "200.3" +CMD_NOT_IMPLMNTD_SUPERFLUOUS = "202" +SYS_STATUS_OR_HELP_REPLY = "211.1" +FEAT_OK = '211.2' +DIR_STATUS = "212" +FILE_STATUS = "213" +HELP_MSG = "214" +NAME_SYS_TYPE = "215" +SVC_READY_FOR_NEW_USER = "220.1" +WELCOME_MSG = "220.2" +SVC_CLOSING_CTRL_CNX = "221.1" +GOODBYE_MSG = "221.2" +DATA_CNX_OPEN_NO_XFR_IN_PROGRESS = "225" +CLOSING_DATA_CNX = "226.1" +TXFR_COMPLETE_OK = "226.2" +ENTERING_PASV_MODE = "227" +ENTERING_EPSV_MODE = "229" +USR_LOGGED_IN_PROCEED = "230.1" # v1 of code 230 +GUEST_LOGGED_IN_PROCEED = "230.2" # v2 of code 230 +REQ_FILE_ACTN_COMPLETED_OK = "250" +PWD_REPLY = "257.1" +MKD_REPLY = "257.2" + +USR_NAME_OK_NEED_PASS = "331.1" # v1 of Code 331 +GUEST_NAME_OK_NEED_EMAIL = "331.2" # v2 of code 331 +NEED_ACCT_FOR_LOGIN = "332" +REQ_FILE_ACTN_PENDING_FURTHER_INFO = "350" + +SVC_NOT_AVAIL_CLOSING_CTRL_CNX = "421.1" +TOO_MANY_CONNECTIONS = "421.2" +CANT_OPEN_DATA_CNX = "425" +CNX_CLOSED_TXFR_ABORTED = "426" +REQ_ACTN_ABRTD_FILE_UNAVAIL = "450" +REQ_ACTN_ABRTD_LOCAL_ERR = "451" +REQ_ACTN_ABRTD_INSUFF_STORAGE = "452" + +SYNTAX_ERR = "500" +SYNTAX_ERR_IN_ARGS = "501" +CMD_NOT_IMPLMNTD = "502.1" +OPTS_NOT_IMPLEMENTED = '502.2' +BAD_CMD_SEQ = "503" +CMD_NOT_IMPLMNTD_FOR_PARAM = "504" +NOT_LOGGED_IN = "530.1" # v1 of code 530 - please log in +AUTH_FAILURE = "530.2" # v2 of code 530 - authorization failure +NEED_ACCT_FOR_STOR = "532" +FILE_NOT_FOUND = "550.1" # no such file or directory +PERMISSION_DENIED = "550.2" # permission denied +ANON_USER_DENIED = "550.3" # anonymous users can't alter filesystem +IS_NOT_A_DIR = "550.4" # rmd called on a path that is not a directory +REQ_ACTN_NOT_TAKEN = "550.5" +FILE_EXISTS = "550.6" +IS_A_DIR = "550.7" +PAGE_TYPE_UNK = "551" +EXCEEDED_STORAGE_ALLOC = "552" +FILENAME_NOT_ALLOWED = "553" + + +RESPONSE = { + # -- 100's -- + # TODO: this must be fixed + RESTART_MARKER_REPLY: '110 MARK yyyy-mmmm', + SERVICE_READY_IN_N_MINUTES: '120 service ready in %s minutes', + DATA_CNX_ALREADY_OPEN_START_XFR: '125 Data connection already open, ' + 'starting transfer', + FILE_STATUS_OK_OPEN_DATA_CNX: '150 File status okay; about to open ' + 'data connection.', + + # -- 200's -- + CMD_OK: '200 Command OK', + TYPE_SET_OK: '200 Type set to %s.', + ENTERING_PORT_MODE: '200 PORT OK', + CMD_NOT_IMPLMNTD_SUPERFLUOUS: '202 Command not implemented, ' + 'superfluous at this site', + SYS_STATUS_OR_HELP_REPLY: '211 System status reply', + FEAT_OK: ['211-Features:', '211 End'], + DIR_STATUS: '212 %s', + FILE_STATUS: '213 %s', + HELP_MSG: '214 help: %s', + NAME_SYS_TYPE: '215 UNIX Type: L8', + WELCOME_MSG: "220 %s", + SVC_READY_FOR_NEW_USER: '220 Service ready', + SVC_CLOSING_CTRL_CNX: '221 Service closing control ' + 'connection', + GOODBYE_MSG: '221 Goodbye.', + DATA_CNX_OPEN_NO_XFR_IN_PROGRESS: '225 data connection open, no ' + 'transfer in progress', + CLOSING_DATA_CNX: '226 Abort successful', + TXFR_COMPLETE_OK: '226 Transfer Complete.', + ENTERING_PASV_MODE: '227 Entering Passive Mode (%s).', + # Where is EPSV defined in the RFCs? + ENTERING_EPSV_MODE: '229 Entering Extended Passive Mode ' + '(|||%s|).', + USR_LOGGED_IN_PROCEED: '230 User logged in, proceed', + GUEST_LOGGED_IN_PROCEED: '230 Anonymous login ok, access ' + 'restrictions apply.', + # i.e. CWD completed OK + REQ_FILE_ACTN_COMPLETED_OK: '250 Requested File Action Completed ' + 'OK', + PWD_REPLY: '257 "%s"', + MKD_REPLY: '257 "%s" created', + + # -- 300's -- + USR_NAME_OK_NEED_PASS: '331 Password required for %s.', + GUEST_NAME_OK_NEED_EMAIL: '331 Guest login ok, type your email ' + 'address as password.', + NEED_ACCT_FOR_LOGIN: '332 Need account for login.', + + REQ_FILE_ACTN_PENDING_FURTHER_INFO: '350 Requested file action pending ' + 'further information.', + + # -- 400's -- + SVC_NOT_AVAIL_CLOSING_CTRL_CNX: '421 Service not available, closing ' + 'control connection.', + TOO_MANY_CONNECTIONS: '421 Too many users right now, try ' + 'again in a few minutes.', + CANT_OPEN_DATA_CNX: "425 Can't open data connection.", + CNX_CLOSED_TXFR_ABORTED: '426 Transfer aborted. Data ' + 'connection closed.', + + REQ_ACTN_ABRTD_FILE_UNAVAIL: '450 Requested action aborted. ' + 'File unavailable.', + REQ_ACTN_ABRTD_LOCAL_ERR: '451 Requested action aborted. ' + 'Local error in processing.', + REQ_ACTN_ABRTD_INSUFF_STORAGE: '452 Requested action aborted. ' + 'Insufficient storage.', + + # -- 500's -- + SYNTAX_ERR: "500 Syntax error: %s", + SYNTAX_ERR_IN_ARGS: '501 syntax error in argument(s) %s.', + CMD_NOT_IMPLMNTD: "502 Command '%s' not implemented", + OPTS_NOT_IMPLEMENTED: "502 Option '%s' not implemented.", + BAD_CMD_SEQ: '503 Incorrect sequence of commands: ' + '%s', + CMD_NOT_IMPLMNTD_FOR_PARAM: "504 Not implemented for parameter " + "'%s'.", + NOT_LOGGED_IN: '530 Please login with USER and PASS.', + AUTH_FAILURE: '530 Sorry, Authentication failed.', + NEED_ACCT_FOR_STOR: '532 Need an account for storing ' + 'files', + FILE_NOT_FOUND: '550 %s: No such file or directory.', + PERMISSION_DENIED: '550 %s: Permission denied.', + ANON_USER_DENIED: '550 Anonymous users are forbidden to ' + 'change the filesystem', + IS_NOT_A_DIR: '550 Cannot rmd, %s is not a ' + 'directory', + FILE_EXISTS: '550 %s: File exists', + IS_A_DIR: '550 %s: is a directory', + REQ_ACTN_NOT_TAKEN: '550 Requested action not taken: %s', + PAGE_TYPE_UNK: '551 Page type unknown', + EXCEEDED_STORAGE_ALLOC: '552 Requested file action aborted, ' + 'exceeded file storage allocation', + FILENAME_NOT_ALLOWED: '553 Requested action not taken, file ' + 'name not allowed' +} + + + +class InvalidPath(Exception): + """ + Internal exception used to signify an error during parsing a path. + """ + + + +def toSegments(cwd, path): + """ + Normalize a path, as represented by a list of strings each + representing one segment of the path. + """ + if path.startswith('/'): + segs = [] + else: + segs = cwd[:] + + for s in path.split('/'): + if s == '.' or s == '': + continue + elif s == '..': + if segs: + segs.pop() + else: + raise InvalidPath(cwd, path) + elif '\0' in s or '/' in s: + raise InvalidPath(cwd, path) + else: + segs.append(s) + return segs + + + +def errnoToFailure(e, path): + """ + Map C{OSError} and C{IOError} to standard FTP errors. + """ + if e == errno.ENOENT: + return defer.fail(FileNotFoundError(path)) + elif e == errno.EACCES or e == errno.EPERM: + return defer.fail(PermissionDeniedError(path)) + elif e == errno.ENOTDIR: + return defer.fail(IsNotADirectoryError(path)) + elif e == errno.EEXIST: + return defer.fail(FileExistsError(path)) + elif e == errno.EISDIR: + return defer.fail(IsADirectoryError(path)) + else: + return defer.fail() + + + +_testTranslation = fnmatch.translate('TEST') + + + +def _isGlobbingExpression(segments=None): + """ + Helper for checking if a FTPShell `segments` contains a wildcard Unix + expression. + + Only filename globbing is supported. + This means that wildcards can only be presents in the last element of + `segments`. + + @type segments: C{list} + @param segments: List of path elements as used by the FTP server protocol. + + @rtype: Boolean + @return: True if `segments` contains a globbing expression. + """ + if not segments: + return False + + # To check that something is a glob expression, we convert it to + # Regular Expression. + # We compare it to the translation of a known non-glob expression. + # If the result is the same as the original expression then it contains no + # globbing expression. + globCandidate = segments[-1] + globTranslations = fnmatch.translate(globCandidate) + nonGlobTranslations = _testTranslation.replace('TEST', globCandidate, 1) + + if nonGlobTranslations == globTranslations: + return False + else: + return True + + + +class FTPCmdError(Exception): + """ + Generic exception for FTP commands. + """ + def __init__(self, *msg): + Exception.__init__(self, *msg) + self.errorMessage = msg + + + def response(self): + """ + Generate a FTP response message for this error. + """ + return RESPONSE[self.errorCode] % self.errorMessage + + + +class FileNotFoundError(FTPCmdError): + """ + Raised when trying to access a non existent file or directory. + """ + errorCode = FILE_NOT_FOUND + + + +class AnonUserDeniedError(FTPCmdError): + """ + Raised when an anonymous user issues a command that will alter the + filesystem + """ + + errorCode = ANON_USER_DENIED + + + +class PermissionDeniedError(FTPCmdError): + """ + Raised when access is attempted to a resource to which access is + not allowed. + """ + errorCode = PERMISSION_DENIED + + + +class IsNotADirectoryError(FTPCmdError): + """ + Raised when RMD is called on a path that isn't a directory. + """ + errorCode = IS_NOT_A_DIR + + + +class FileExistsError(FTPCmdError): + """ + Raised when attempted to override an existing resource. + """ + errorCode = FILE_EXISTS + + + +class IsADirectoryError(FTPCmdError): + """ + Raised when DELE is called on a path that is a directory. + """ + errorCode = IS_A_DIR + + + +class CmdSyntaxError(FTPCmdError): + """ + Raised when a command syntax is wrong. + """ + errorCode = SYNTAX_ERR + + + +class CmdArgSyntaxError(FTPCmdError): + """ + Raised when a command is called with wrong value or a wrong number of + arguments. + """ + errorCode = SYNTAX_ERR_IN_ARGS + + + +class CmdNotImplementedError(FTPCmdError): + """ + Raised when an unimplemented command is given to the server. + """ + errorCode = CMD_NOT_IMPLMNTD + + + +class CmdNotImplementedForArgError(FTPCmdError): + """ + Raised when the handling of a parameter for a command is not implemented by + the server. + """ + errorCode = CMD_NOT_IMPLMNTD_FOR_PARAM + + + +class FTPError(Exception): + pass + + + +class PortConnectionError(Exception): + pass + + + +class BadCmdSequenceError(FTPCmdError): + """ + Raised when a client sends a series of commands in an illogical sequence. + """ + errorCode = BAD_CMD_SEQ + + + +class AuthorizationError(FTPCmdError): + """ + Raised when client authentication fails. + """ + errorCode = AUTH_FAILURE + + + +def debugDeferred(self, *_): + log.msg('debugDeferred(): %s' % str(_), debug=True) + + + +# -- DTP Protocol -- + + +_months = [ + None, + 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', + 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + + + +@implementer(interfaces.IConsumer) +class DTP(protocol.Protocol, object): + isConnected = False + + _cons = None + _onConnLost = None + _buffer = None + _encoding = 'latin-1' + + def connectionMade(self): + self.isConnected = True + self.factory.deferred.callback(None) + self._buffer = [] + + def connectionLost(self, reason): + self.isConnected = False + if self._onConnLost is not None: + self._onConnLost.callback(None) + + def sendLine(self, line): + """ + Send a line to data channel. + + @param line: The line to be sent. + @type line: L{bytes} + """ + self.transport.write(line + b'\r\n') + + + def _formatOneListResponse(self, name, size, directory, permissions, + hardlinks, modified, owner, group): + """ + Helper method to format one entry's info into a text entry like: + 'drwxrwxrwx 0 user group 0 Jan 01 1970 filename.txt' + + @param name: C{bytes} name of the entry (file or directory or link) + @param size: C{int} size of the entry + @param directory: evals to C{bool} - whether the entry is a directory + @param permissions: L{twisted.python.filepath.Permissions} object + representing that entry's permissions + @param hardlinks: C{int} number of hardlinks + @param modified: C{float} - entry's last modified time in seconds + since the epoch + @param owner: C{str} username of the owner + @param group: C{str} group name of the owner + + @return: C{str} in the requisite format + """ + def formatDate(mtime): + now = time.gmtime() + info = { + 'month': _months[mtime.tm_mon], + 'day': mtime.tm_mday, + 'year': mtime.tm_year, + 'hour': mtime.tm_hour, + 'minute': mtime.tm_min + } + if now.tm_year != mtime.tm_year: + return '%(month)s %(day)02d %(year)5d' % info + else: + return '%(month)s %(day)02d %(hour)02d:%(minute)02d' % info + + format = ('%(directory)s%(permissions)s%(hardlinks)4d ' + '%(owner)-9s %(group)-9s %(size)15d %(date)12s ' + ) + + msg = (format % { + 'directory': directory and 'd' or '-', + 'permissions': permissions.shorthand(), + 'hardlinks': hardlinks, + 'owner': owner[:8], + 'group': group[:8], + 'size': size, + 'date': formatDate(time.gmtime(modified)), + }).encode(self._encoding) + return msg + name + + + def sendListResponse(self, name, response): + self.sendLine(self._formatOneListResponse(name, *response)) + + # Proxy IConsumer to our transport + def registerProducer(self, producer, streaming): + return self.transport.registerProducer(producer, streaming) + + def unregisterProducer(self): + self.transport.unregisterProducer() + self.transport.loseConnection() + + def write(self, data): + if self.isConnected: + return self.transport.write(data) + raise Exception("Crap damn crap damn crap damn") + + + # Pretend to be a producer, too. + def _conswrite(self, bytes): + try: + self._cons.write(bytes) + except: + self._onConnLost.errback() + + def dataReceived(self, bytes): + if self._cons is not None: + self._conswrite(bytes) + else: + self._buffer.append(bytes) + + def _unregConsumer(self, ignored): + self._cons.unregisterProducer() + self._cons = None + del self._onConnLost + return ignored + + def registerConsumer(self, cons): + assert self._cons is None + self._cons = cons + self._cons.registerProducer(self, True) + for chunk in self._buffer: + self._conswrite(chunk) + self._buffer = None + if self.isConnected: + self._onConnLost = d = defer.Deferred() + d.addBoth(self._unregConsumer) + return d + else: + self._cons.unregisterProducer() + self._cons = None + return defer.succeed(None) + + def resumeProducing(self): + self.transport.resumeProducing() + + def pauseProducing(self): + self.transport.pauseProducing() + + def stopProducing(self): + self.transport.stopProducing() + + + +class DTPFactory(protocol.ClientFactory): + """ + Client factory for I{data transfer process} protocols. + + @ivar peerCheck: perform checks to make sure the ftp-pi's peer is the same + as the dtp's + @ivar pi: a reference to this factory's protocol interpreter + + @ivar _state: Indicates the current state of the DTPFactory. Initially, + this is L{_IN_PROGRESS}. If the connection fails or times out, it is + L{_FAILED}. If the connection succeeds before the timeout, it is + L{_FINISHED}. + + @cvar _IN_PROGRESS: Token to signal that connection is active. + @type _IN_PROGRESS: L{object}. + + @cvar _FAILED: Token to signal that connection has failed. + @type _FAILED: L{object}. + + @cvar _FINISHED: Token to signal that connection was successfully closed. + @type _FINISHED: L{object}. + """ + + _IN_PROGRESS = object() + _FAILED = object() + _FINISHED = object() + + _state = _IN_PROGRESS + + # -- configuration variables -- + peerCheck = False + + # -- class variables -- + def __init__(self, pi, peerHost=None, reactor=None): + """ + Constructor + + @param pi: this factory's protocol interpreter + @param peerHost: if peerCheck is True, this is the tuple that the + generated instance will use to perform security checks + """ + self.pi = pi + self.peerHost = peerHost # from FTP.transport.peerHost() + # deferred will fire when instance is connected + self.deferred = defer.Deferred() + self.delayedCall = None + if reactor is None: + from twisted.internet import reactor + self._reactor = reactor + + + def buildProtocol(self, addr): + log.msg('DTPFactory.buildProtocol', debug=True) + + if self._state is not self._IN_PROGRESS: + return None + self._state = self._FINISHED + + self.cancelTimeout() + p = DTP() + p.factory = self + p.pi = self.pi + self.pi.dtpInstance = p + return p + + + def stopFactory(self): + log.msg('dtpFactory.stopFactory', debug=True) + self.cancelTimeout() + + + def timeoutFactory(self): + log.msg('timed out waiting for DTP connection') + if self._state is not self._IN_PROGRESS: + return + self._state = self._FAILED + + d = self.deferred + self.deferred = None + d.errback( + PortConnectionError(defer.TimeoutError("DTPFactory timeout"))) + + + def cancelTimeout(self): + if self.delayedCall is not None and self.delayedCall.active(): + log.msg('cancelling DTP timeout', debug=True) + self.delayedCall.cancel() + + + def setTimeout(self, seconds): + log.msg('DTPFactory.setTimeout set to %s seconds' % seconds) + self.delayedCall = self._reactor.callLater( + seconds, self.timeoutFactory) + + + def clientConnectionFailed(self, connector, reason): + if self._state is not self._IN_PROGRESS: + return + self._state = self._FAILED + d = self.deferred + self.deferred = None + d.errback(PortConnectionError(reason)) + + + +# -- FTP-PI (Protocol Interpreter) -- + +class ASCIIConsumerWrapper(object): + def __init__(self, cons): + self.cons = cons + self.registerProducer = cons.registerProducer + self.unregisterProducer = cons.unregisterProducer + + assert os.linesep == "\r\n" or len(os.linesep) == 1, ( + "Unsupported platform (yea right like this even exists)") + + if os.linesep == "\r\n": + self.write = cons.write + + def write(self, bytes): + return self.cons.write(bytes.replace(os.linesep, "\r\n")) + + + +@implementer(interfaces.IConsumer) +class FileConsumer(object): + """ + A consumer for FTP input that writes data to a file. + + @ivar fObj: a file object opened for writing, used to write data received. + @type fObj: C{file} + """ + def __init__(self, fObj): + self.fObj = fObj + + + def registerProducer(self, producer, streaming): + self.producer = producer + assert streaming + + + def unregisterProducer(self): + self.producer = None + self.fObj.close() + + + def write(self, bytes): + self.fObj.write(bytes) + + + +class FTPOverflowProtocol(basic.LineReceiver): + """FTP mini-protocol for when there are too many connections.""" + _encoding = 'latin-1' + + def connectionMade(self): + self.sendLine(RESPONSE[TOO_MANY_CONNECTIONS].encode(self._encoding)) + self.transport.loseConnection() + + + +class FTP(basic.LineReceiver, policies.TimeoutMixin, object): + """ + Protocol Interpreter for the File Transfer Protocol + + @ivar state: The current server state. One of L{UNAUTH}, + L{INAUTH}, L{AUTHED}, L{RENAMING}. + + @ivar shell: The connected avatar + @ivar binary: The transfer mode. If false, ASCII. + @ivar dtpFactory: Generates a single DTP for this session + @ivar dtpPort: Port returned from listenTCP + @ivar listenFactory: A callable with the signature of + L{twisted.internet.interfaces.IReactorTCP.listenTCP} which will be used + to create Ports for passive connections (mainly for testing). + + @ivar passivePortRange: iterator used as source of passive port numbers. + @type passivePortRange: C{iterator} + + @cvar UNAUTH: Command channel is not yet authenticated. + @type UNAUTH: L{int} + + @cvar INAUTH: Command channel is in the process of being authenticated. + @type INAUTH: L{int} + + @cvar AUTHED: Command channel was successfully authenticated. + @type AUTHED: L{int} + + @cvar RENAMING: Command channel is between the renaming command sequence. + @type RENAMING: L{int} + """ + + disconnected = False + + # States an FTP can be in + UNAUTH, INAUTH, AUTHED, RENAMING = range(4) + + # how long the DTP waits for a connection + dtpTimeout = 10 + + portal = None + shell = None + dtpFactory = None + dtpPort = None + dtpInstance = None + binary = True + PUBLIC_COMMANDS = ['FEAT', 'QUIT'] + FEATURES = ['FEAT', 'MDTM', 'PASV', 'SIZE', 'TYPE A;I'] + + passivePortRange = range(0, 1) + + listenFactory = reactor.listenTCP + _encoding = 'latin-1' + + def reply(self, key, *args): + msg = RESPONSE[key] % args + self.sendLine(msg) + + + def sendLine(self, line): + """ + (Private) Encodes and sends a line + + @param line: L{bytes} or L{unicode} + """ + if isinstance(line, unicode): + line = line.encode(self._encoding) + super(FTP, self).sendLine(line) + + + def connectionMade(self): + self.state = self.UNAUTH + self.setTimeout(self.timeOut) + self.reply(WELCOME_MSG, self.factory.welcomeMessage) + + def connectionLost(self, reason): + # if we have a DTP protocol instance running and + # we lose connection to the client's PI, kill the + # DTP connection and close the port + if self.dtpFactory: + self.cleanupDTP() + self.setTimeout(None) + if hasattr(self.shell, 'logout') and self.shell.logout is not None: + self.shell.logout() + self.shell = None + self.transport = None + + def timeoutConnection(self): + self.transport.loseConnection() + + def lineReceived(self, line): + self.resetTimeout() + self.pauseProducing() + if bytes != str: + line = line.decode(self._encoding) + + def processFailed(err): + if err.check(FTPCmdError): + self.sendLine(err.value.response()) + elif (err.check(TypeError) and any(( + msg in err.value.args[0] for msg in ( + 'takes exactly', 'required positional argument')))): + self.reply(SYNTAX_ERR, "%s requires an argument." % (cmd,)) + else: + log.msg("Unexpected FTP error") + log.err(err) + self.reply(REQ_ACTN_NOT_TAKEN, "internal server error") + + def processSucceeded(result): + if isinstance(result, tuple): + self.reply(*result) + elif result is not None: + self.reply(result) + + def allDone(ignored): + if not self.disconnected: + self.resumeProducing() + + spaceIndex = line.find(' ') + if spaceIndex != -1: + cmd = line[:spaceIndex] + args = (line[spaceIndex + 1:],) + else: + cmd = line + args = () + d = defer.maybeDeferred(self.processCommand, cmd, *args) + d.addCallbacks(processSucceeded, processFailed) + d.addErrback(log.err) + + # XXX It burnsss + # LineReceiver doesn't let you resumeProducing inside + # lineReceived atm + from twisted.internet import reactor + reactor.callLater(0, d.addBoth, allDone) + + + def processCommand(self, cmd, *params): + + def call_ftp_command(command): + method = getattr(self, "ftp_" + command, None) + if method is not None: + return method(*params) + return defer.fail(CmdNotImplementedError(command)) + + cmd = cmd.upper() + + if cmd in self.PUBLIC_COMMANDS: + return call_ftp_command(cmd) + + elif self.state == self.UNAUTH: + if cmd == 'USER': + return self.ftp_USER(*params) + elif cmd == 'PASS': + return BAD_CMD_SEQ, "USER required before PASS" + else: + return NOT_LOGGED_IN + + elif self.state == self.INAUTH: + if cmd == 'PASS': + return self.ftp_PASS(*params) + else: + return BAD_CMD_SEQ, "PASS required after USER" + + elif self.state == self.AUTHED: + return call_ftp_command(cmd) + + elif self.state == self.RENAMING: + if cmd == 'RNTO': + return self.ftp_RNTO(*params) + else: + return BAD_CMD_SEQ, "RNTO required after RNFR" + + + def getDTPPort(self, factory): + """ + Return a port for passive access, using C{self.passivePortRange} + attribute. + """ + for portn in self.passivePortRange: + try: + dtpPort = self.listenFactory(portn, factory) + except error.CannotListenError: + continue + else: + return dtpPort + raise error.CannotListenError( + '', portn, + "No port available in range %s" % (self.passivePortRange,)) + + + def ftp_USER(self, username): + """ + First part of login. Get the username the peer wants to + authenticate as. + """ + if not username: + return defer.fail(CmdSyntaxError('USER requires an argument')) + + self._user = username + self.state = self.INAUTH + if (self.factory.allowAnonymous and + self._user == self.factory.userAnonymous): + return GUEST_NAME_OK_NEED_EMAIL + else: + return (USR_NAME_OK_NEED_PASS, username) + + # TODO: add max auth try before timeout from ip... + # TODO: need to implement minimal ABOR command + + def ftp_PASS(self, password): + """ + Second part of login. Get the password the peer wants to + authenticate with. + """ + if (self.factory.allowAnonymous and + self._user == self.factory.userAnonymous): + # anonymous login + creds = credentials.Anonymous() + reply = GUEST_LOGGED_IN_PROCEED + else: + # user login + creds = credentials.UsernamePassword(self._user, password) + reply = USR_LOGGED_IN_PROCEED + del self._user + + def _cbLogin(result): + (interface, avatar, logout) = result + assert interface is IFTPShell, "The realm is busted, jerk." + self.shell = avatar + self.logout = logout + self.workingDirectory = [] + self.state = self.AUTHED + return reply + + def _ebLogin(failure): + failure.trap( + cred_error.UnauthorizedLogin, cred_error.UnhandledCredentials) + self.state = self.UNAUTH + raise AuthorizationError + + d = self.portal.login(creds, None, IFTPShell) + d.addCallbacks(_cbLogin, _ebLogin) + return d + + + def ftp_PASV(self): + """ + Request for a passive connection + + from the rfc:: + + This command requests the server-DTP to \"listen\" on a data port + (which is not its default data port) and to wait for a connection + rather than initiate one upon receipt of a transfer command. The + response to this command includes the host and port address this + server is listening on. + """ + # if we have a DTP port set up, lose it. + if self.dtpFactory is not None: + # cleanupDTP sets dtpFactory to none. Later we'll do + # cleanup here or something. + self.cleanupDTP() + self.dtpFactory = DTPFactory(pi=self) + self.dtpFactory.setTimeout(self.dtpTimeout) + self.dtpPort = self.getDTPPort(self.dtpFactory) + + host = self.transport.getHost().host + port = self.dtpPort.getHost().port + self.reply(ENTERING_PASV_MODE, encodeHostPort(host, port)) + return self.dtpFactory.deferred.addCallback(lambda ign: None) + + + def ftp_PORT(self, address): + addr = tuple(map(int, address.split(','))) + ip = '%d.%d.%d.%d' % tuple(addr[:4]) + port = addr[4] << 8 | addr[5] + + # if we have a DTP port set up, lose it. + if self.dtpFactory is not None: + self.cleanupDTP() + + self.dtpFactory = DTPFactory( + pi=self, peerHost=self.transport.getPeer().host) + self.dtpFactory.setTimeout(self.dtpTimeout) + self.dtpPort = reactor.connectTCP(ip, port, self.dtpFactory) + + def connected(ignored): + return ENTERING_PORT_MODE + + def connFailed(err): + err.trap(PortConnectionError) + return CANT_OPEN_DATA_CNX + + return self.dtpFactory.deferred.addCallbacks(connected, connFailed) + + + def _encodeName(self, name): + """ + Encode C{name} to be sent over the wire. + + This encodes L{unicode} objects as UTF-8 and leaves L{bytes} as-is. + + As described by U{RFC 3659 section + 2.2<https://tools.ietf.org/html/rfc3659#section-2.2>}:: + + Various FTP commands take pathnames as arguments, or return + pathnames in responses. When the MLST command is supported, as + indicated in the response to the FEAT command, pathnames are to be + transferred in one of the following two formats. + + pathname = utf-8-name / raw + utf-8-name = <a UTF-8 encoded Unicode string> + raw = <any string that is not a valid UTF-8 encoding> + + Which format is used is at the option of the user-PI or server-PI + sending the pathname. + + @param name: Name to be encoded. + @type name: L{bytes} or L{unicode} + + @return: Wire format of C{name}. + @rtype: L{bytes} + """ + if isinstance(name, unicode): + return name.encode('utf-8') + return name + + + def ftp_LIST(self, path=''): + """ This command causes a list to be sent from the server to the + passive DTP. If the pathname specifies a directory or other + group of files, the server should transfer a list of files + in the specified directory. If the pathname specifies a + file then the server should send current information on the + file. A null argument implies the user's current working or + default directory. + """ + # XXX: why is this check different from ftp_RETR/ftp_STOR? See #4180 + if self.dtpInstance is None or not self.dtpInstance.isConnected: + return defer.fail( + BadCmdSequenceError('must send PORT or PASV before RETR')) + + # Various clients send flags like -L or -al etc. We just ignore them. + if path.lower() in ['-a', '-l', '-la', '-al']: + path = '' + + def gotListing(results): + self.reply(DATA_CNX_ALREADY_OPEN_START_XFR) + for (name, attrs) in results: + name = self._encodeName(name) + self.dtpInstance.sendListResponse(name, attrs) + self.dtpInstance.transport.loseConnection() + return (TXFR_COMPLETE_OK,) + + try: + segments = toSegments(self.workingDirectory, path) + except InvalidPath: + return defer.fail(FileNotFoundError(path)) + + d = self.shell.list( + segments, + ('size', 'directory', 'permissions', 'hardlinks', + 'modified', 'owner', 'group')) + d.addCallback(gotListing) + return d + + + def ftp_NLST(self, path): + """ + This command causes a directory listing to be sent from the server to + the client. The pathname should specify a directory or other + system-specific file group descriptor. An empty path implies the + current working directory. If the path is non-existent, send nothing. + If the path is to a file, send only the file name. + + @type path: C{str} + @param path: The path for which a directory listing should be returned. + + @rtype: L{Deferred} + @return: a L{Deferred} which will be fired when the listing request + is finished. + """ + # XXX: why is this check different from ftp_RETR/ftp_STOR? See #4180 + if self.dtpInstance is None or not self.dtpInstance.isConnected: + return defer.fail( + BadCmdSequenceError('must send PORT or PASV before RETR')) + + try: + segments = toSegments(self.workingDirectory, path) + except InvalidPath: + return defer.fail(FileNotFoundError(path)) + + def cbList(results, glob): + """ + Send, line by line, each matching file in the directory listing, + and then close the connection. + + @type results: A C{list} of C{tuple}. The first element of each + C{tuple} is a C{str} and the second element is a C{list}. + @param results: The names of the files in the directory. + + @param glob: A shell-style glob through which to filter results + (see U{http://docs.python.org/2/library/fnmatch.html}), or + L{None} for no filtering. + @type glob: L{str} or L{None} + + @return: A C{tuple} containing the status code for a successful + transfer. + @rtype: C{tuple} + """ + self.reply(DATA_CNX_ALREADY_OPEN_START_XFR) + for (name, ignored) in results: + if not glob or (glob and fnmatch.fnmatch(name, glob)): + name = self._encodeName(name) + self.dtpInstance.sendLine(name) + self.dtpInstance.transport.loseConnection() + return (TXFR_COMPLETE_OK,) + + def listErr(results): + """ + RFC 959 specifies that an NLST request may only return directory + listings. Thus, send nothing and just close the connection. + + @type results: L{Failure} + @param results: The L{Failure} wrapping a L{FileNotFoundError} that + occurred while trying to list the contents of a nonexistent + directory. + + @returns: A C{tuple} containing the status code for a successful + transfer. + @rtype: C{tuple} + """ + self.dtpInstance.transport.loseConnection() + return (TXFR_COMPLETE_OK,) + + if _isGlobbingExpression(segments): + # Remove globbing expression from path + # and keep to be used for filtering. + glob = segments.pop() + else: + glob = None + + d = self.shell.list(segments) + d.addCallback(cbList, glob) + # self.shell.list will generate an error if the path is invalid + d.addErrback(listErr) + return d + + + def ftp_CWD(self, path): + try: + segments = toSegments(self.workingDirectory, path) + except InvalidPath: + # XXX Eh, what to fail with here? + return defer.fail(FileNotFoundError(path)) + + def accessGranted(result): + self.workingDirectory = segments + return (REQ_FILE_ACTN_COMPLETED_OK,) + + return self.shell.access(segments).addCallback(accessGranted) + + + def ftp_CDUP(self): + return self.ftp_CWD('..') + + + def ftp_PWD(self): + return (PWD_REPLY, '/' + '/'.join(self.workingDirectory)) + + + def ftp_RETR(self, path): + """ + This command causes the content of a file to be sent over the data + transfer channel. If the path is to a folder, an error will be raised. + + @type path: C{str} + @param path: The path to the file which should be transferred over the + data transfer channel. + + @rtype: L{Deferred} + @return: a L{Deferred} which will be fired when the transfer is done. + """ + if self.dtpInstance is None: + raise BadCmdSequenceError('PORT or PASV required before RETR') + + try: + newsegs = toSegments(self.workingDirectory, path) + except InvalidPath: + return defer.fail(FileNotFoundError(path)) + + # XXX For now, just disable the timeout. Later we'll want to + # leave it active and have the DTP connection reset it + # periodically. + self.setTimeout(None) + + # Put it back later + def enableTimeout(result): + self.setTimeout(self.factory.timeOut) + return result + + # And away she goes + if not self.binary: + cons = ASCIIConsumerWrapper(self.dtpInstance) + else: + cons = self.dtpInstance + + def cbSent(result): + return (TXFR_COMPLETE_OK,) + + def ebSent(err): + log.msg("Unexpected error attempting to transmit file to client:") + log.err(err) + if err.check(FTPCmdError): + return err + return (CNX_CLOSED_TXFR_ABORTED,) + + def cbOpened(file): + # Tell them what to doooo + if self.dtpInstance.isConnected: + self.reply(DATA_CNX_ALREADY_OPEN_START_XFR) + else: + self.reply(FILE_STATUS_OK_OPEN_DATA_CNX) + + d = file.send(cons) + d.addCallbacks(cbSent, ebSent) + return d + + def ebOpened(err): + if not err.check( + PermissionDeniedError, FileNotFoundError, + IsADirectoryError): + log.msg( + "Unexpected error attempting to open file for " + "transmission:") + log.err(err) + if err.check(FTPCmdError): + return (err.value.errorCode, '/'.join(newsegs)) + return (FILE_NOT_FOUND, '/'.join(newsegs)) + + d = self.shell.openForReading(newsegs) + d.addCallbacks(cbOpened, ebOpened) + d.addBoth(enableTimeout) + + # Pass back Deferred that fires when the transfer is done + return d + + + def ftp_STOR(self, path): + """ + STORE (STOR) + + This command causes the server-DTP to accept the data + transferred via the data connection and to store the data as + a file at the server site. If the file specified in the + pathname exists at the server site, then its contents shall + be replaced by the data being transferred. A new file is + created at the server site if the file specified in the + pathname does not already exist. + """ + if self.dtpInstance is None: + raise BadCmdSequenceError('PORT or PASV required before STOR') + + try: + newsegs = toSegments(self.workingDirectory, path) + except InvalidPath: + return defer.fail(FileNotFoundError(path)) + + # XXX For now, just disable the timeout. Later we'll want to + # leave it active and have the DTP connection reset it + # periodically. + self.setTimeout(None) + + # Put it back later + def enableTimeout(result): + self.setTimeout(self.factory.timeOut) + return result + + def cbOpened(file): + """ + File was open for reading. Launch the data transfer channel via + the file consumer. + """ + d = file.receive() + d.addCallback(cbConsumer) + d.addCallback(lambda ignored: file.close()) + d.addCallbacks(cbSent, ebSent) + return d + + def ebOpened(err): + """ + Called when failed to open the file for reading. + + For known errors, return the FTP error code. + For all other, return a file not found error. + """ + if isinstance(err.value, FTPCmdError): + return (err.value.errorCode, '/'.join(newsegs)) + log.err(err, "Unexpected error received while opening file:") + return (FILE_NOT_FOUND, '/'.join(newsegs)) + + def cbConsumer(cons): + """ + Called after the file was opended for reading. + + Prepare the data transfer channel and send the response + to the command channel. + """ + if not self.binary: + cons = ASCIIConsumerWrapper(cons) + + d = self.dtpInstance.registerConsumer(cons) + + # Tell them what to doooo + if self.dtpInstance.isConnected: + self.reply(DATA_CNX_ALREADY_OPEN_START_XFR) + else: + self.reply(FILE_STATUS_OK_OPEN_DATA_CNX) + + return d + + def cbSent(result): + """ + Called from data transport when tranfer is done. + """ + return (TXFR_COMPLETE_OK,) + + def ebSent(err): + """ + Called from data transport when there are errors during the + transfer. + """ + log.err(err, "Unexpected error received during transfer:") + if err.check(FTPCmdError): + return err + return (CNX_CLOSED_TXFR_ABORTED,) + + d = self.shell.openForWriting(newsegs) + d.addCallbacks(cbOpened, ebOpened) + d.addBoth(enableTimeout) + + # Pass back Deferred that fires when the transfer is done + return d + + + def ftp_SIZE(self, path): + """ + File SIZE + + The FTP command, SIZE OF FILE (SIZE), is used to obtain the transfer + size of a file from the server-FTP process. This is the exact number + of octets (8 bit bytes) that would be transmitted over the data + connection should that file be transmitted. This value will change + depending on the current STRUcture, MODE, and TYPE of the data + connection or of a data connection that would be created were one + created now. Thus, the result of the SIZE command is dependent on + the currently established STRU, MODE, and TYPE parameters. + + The SIZE command returns how many octets would be transferred if the + file were to be transferred using the current transfer structure, + mode, and type. This command is normally used in conjunction with + the RESTART (REST) command when STORing a file to a remote server in + STREAM mode, to determine the restart point. The server-PI might + need to read the partially transferred file, do any appropriate + conversion, and count the number of octets that would be generated + when sending the file in order to correctly respond to this command. + Estimates of the file transfer size MUST NOT be returned; only + precise information is acceptable. + + http://tools.ietf.org/html/rfc3659 + """ + try: + newsegs = toSegments(self.workingDirectory, path) + except InvalidPath: + return defer.fail(FileNotFoundError(path)) + + def cbStat(result): + (size,) = result + return (FILE_STATUS, str(size)) + + return self.shell.stat(newsegs, ('size',)).addCallback(cbStat) + + + def ftp_MDTM(self, path): + """ + File Modification Time (MDTM) + + The FTP command, MODIFICATION TIME (MDTM), can be used to determine + when a file in the server NVFS was last modified. This command has + existed in many FTP servers for many years, as an adjunct to the REST + command for STREAM mode, thus is widely available. However, where + supported, the "modify" fact that can be provided in the result from + the new MLST command is recommended as a superior alternative. + + http://tools.ietf.org/html/rfc3659 + """ + try: + newsegs = toSegments(self.workingDirectory, path) + except InvalidPath: + return defer.fail(FileNotFoundError(path)) + + def cbStat(result): + (modified,) = result + return ( + FILE_STATUS, + time.strftime('%Y%m%d%H%M%S', time.gmtime(modified))) + + return self.shell.stat(newsegs, ('modified',)).addCallback(cbStat) + + + def ftp_TYPE(self, type): + """ + REPRESENTATION TYPE (TYPE) + + The argument specifies the representation type as described + in the Section on Data Representation and Storage. Several + types take a second parameter. The first parameter is + denoted by a single Telnet character, as is the second + Format parameter for ASCII and EBCDIC; the second parameter + for local byte is a decimal integer to indicate Bytesize. + The parameters are separated by a <SP> (Space, ASCII code + 32). + """ + p = type.upper() + if p: + f = getattr(self, 'type_' + p[0], None) + if f is not None: + return f(p[1:]) + return self.type_UNKNOWN(p) + return (SYNTAX_ERR,) + + def type_A(self, code): + if code == '' or code == 'N': + self.binary = False + return (TYPE_SET_OK, 'A' + code) + else: + return defer.fail(CmdArgSyntaxError(code)) + + def type_I(self, code): + if code == '': + self.binary = True + return (TYPE_SET_OK, 'I') + else: + return defer.fail(CmdArgSyntaxError(code)) + + def type_UNKNOWN(self, code): + return defer.fail(CmdNotImplementedForArgError(code)) + + + def ftp_SYST(self): + return NAME_SYS_TYPE + + + def ftp_STRU(self, structure): + p = structure.upper() + if p == 'F': + return (CMD_OK,) + return defer.fail(CmdNotImplementedForArgError(structure)) + + + def ftp_MODE(self, mode): + p = mode.upper() + if p == 'S': + return (CMD_OK,) + return defer.fail(CmdNotImplementedForArgError(mode)) + + + def ftp_MKD(self, path): + try: + newsegs = toSegments(self.workingDirectory, path) + except InvalidPath: + return defer.fail(FileNotFoundError(path)) + return self.shell.makeDirectory(newsegs).addCallback( + lambda ign: (MKD_REPLY, path)) + + + def ftp_RMD(self, path): + try: + newsegs = toSegments(self.workingDirectory, path) + except InvalidPath: + return defer.fail(FileNotFoundError(path)) + return self.shell.removeDirectory(newsegs).addCallback( + lambda ign: (REQ_FILE_ACTN_COMPLETED_OK,)) + + + def ftp_DELE(self, path): + try: + newsegs = toSegments(self.workingDirectory, path) + except InvalidPath: + return defer.fail(FileNotFoundError(path)) + return self.shell.removeFile(newsegs).addCallback( + lambda ign: (REQ_FILE_ACTN_COMPLETED_OK,)) + + + def ftp_NOOP(self): + return (CMD_OK,) + + + def ftp_RNFR(self, fromName): + self._fromName = fromName + self.state = self.RENAMING + return (REQ_FILE_ACTN_PENDING_FURTHER_INFO,) + + + def ftp_RNTO(self, toName): + fromName = self._fromName + del self._fromName + self.state = self.AUTHED + + try: + fromsegs = toSegments(self.workingDirectory, fromName) + tosegs = toSegments(self.workingDirectory, toName) + except InvalidPath: + return defer.fail(FileNotFoundError(fromName)) + return self.shell.rename(fromsegs, tosegs).addCallback( + lambda ign: (REQ_FILE_ACTN_COMPLETED_OK,)) + + + def ftp_FEAT(self): + """ + Advertise the features supported by the server. + + http://tools.ietf.org/html/rfc2389 + """ + self.sendLine(RESPONSE[FEAT_OK][0]) + for feature in self.FEATURES: + self.sendLine(' ' + feature) + self.sendLine(RESPONSE[FEAT_OK][1]) + + def ftp_OPTS(self, option): + """ + Handle OPTS command. + + http://tools.ietf.org/html/draft-ietf-ftpext-utf-8-option-00 + """ + return self.reply(OPTS_NOT_IMPLEMENTED, option) + + def ftp_QUIT(self): + self.reply(GOODBYE_MSG) + self.transport.loseConnection() + self.disconnected = True + + def cleanupDTP(self): + """ + Call when DTP connection exits + """ + log.msg('cleanupDTP', debug=True) + + log.msg(self.dtpPort) + dtpPort, self.dtpPort = self.dtpPort, None + if interfaces.IListeningPort.providedBy(dtpPort): + dtpPort.stopListening() + elif interfaces.IConnector.providedBy(dtpPort): + dtpPort.disconnect() + else: + assert False, ( + "dtpPort should be an IListeningPort or IConnector, " + "instead is %r" % (dtpPort,)) + + self.dtpFactory.stopFactory() + self.dtpFactory = None + + if self.dtpInstance is not None: + self.dtpInstance = None + + + +class FTPFactory(policies.LimitTotalConnectionsFactory): + """ + A factory for producing ftp protocol instances + + @ivar timeOut: the protocol interpreter's idle timeout time in seconds, + default is 600 seconds. + + @ivar passivePortRange: value forwarded to C{protocol.passivePortRange}. + @type passivePortRange: C{iterator} + """ + protocol = FTP + overflowProtocol = FTPOverflowProtocol + allowAnonymous = True + userAnonymous = 'anonymous' + timeOut = 600 + + welcomeMessage = "Twisted %s FTP Server" % (copyright.version,) + + passivePortRange = range(0, 1) + + def __init__(self, portal=None, userAnonymous='anonymous'): + self.portal = portal + self.userAnonymous = userAnonymous + self.instances = [] + + def buildProtocol(self, addr): + p = policies.LimitTotalConnectionsFactory.buildProtocol(self, addr) + if p is not None: + p.wrappedProtocol.portal = self.portal + p.wrappedProtocol.timeOut = self.timeOut + p.wrappedProtocol.passivePortRange = self.passivePortRange + return p + + def stopFactory(self): + # make sure ftp instance's timeouts are set to None + # to avoid reactor complaints + [p.setTimeout(None) for p in self.instances if p.timeOut is not None] + policies.LimitTotalConnectionsFactory.stopFactory(self) + + + +# -- Cred Objects -- + +class IFTPShell(Interface): + """ + An abstraction of the shell commands used by the FTP protocol for + a given user account. + + All path names must be absolute. + """ + + def makeDirectory(path): + """ + Create a directory. + + @param path: The path, as a list of segments, to create + @type path: C{list} of C{unicode} + + @return: A Deferred which fires when the directory has been + created, or which fails if the directory cannot be created. + """ + + + def removeDirectory(path): + """ + Remove a directory. + + @param path: The path, as a list of segments, to remove + @type path: C{list} of C{unicode} + + @return: A Deferred which fires when the directory has been + removed, or which fails if the directory cannot be removed. + """ + + + def removeFile(path): + """ + Remove a file. + + @param path: The path, as a list of segments, to remove + @type path: C{list} of C{unicode} + + @return: A Deferred which fires when the file has been + removed, or which fails if the file cannot be removed. + """ + + + def rename(fromPath, toPath): + """ + Rename a file or directory. + + @param fromPath: The current name of the path. + @type fromPath: C{list} of C{unicode} + + @param toPath: The desired new name of the path. + @type toPath: C{list} of C{unicode} + + @return: A Deferred which fires when the path has been + renamed, or which fails if the path cannot be renamed. + """ + + + def access(path): + """ + Determine whether access to the given path is allowed. + + @param path: The path, as a list of segments + + @return: A Deferred which fires with None if access is allowed + or which fails with a specific exception type if access is + denied. + """ + + + def stat(path, keys=()): + """ + Retrieve information about the given path. + + This is like list, except it will never return results about + child paths. + """ + + + def list(path, keys=()): + """ + Retrieve information about the given path. + + If the path represents a non-directory, the result list should + have only one entry with information about that non-directory. + Otherwise, the result list should have an element for each + child of the directory. + + @param path: The path, as a list of segments, to list + @type path: C{list} of C{unicode} or C{bytes} + + @param keys: A tuple of keys desired in the resulting + dictionaries. + + @return: A Deferred which fires with a list of (name, list), + where the name is the name of the entry as a unicode string or + bytes and each list contains values corresponding to the requested + keys. The following are possible elements of keys, and the + values which should be returned for them: + + - C{'size'}: size in bytes, as an integer (this is kinda required) + + - C{'directory'}: boolean indicating the type of this entry + + - C{'permissions'}: a bitvector (see os.stat(foo).st_mode) + + - C{'hardlinks'}: Number of hard links to this entry + + - C{'modified'}: number of seconds since the epoch since entry was + modified + + - C{'owner'}: string indicating the user owner of this entry + + - C{'group'}: string indicating the group owner of this entry + """ + + + def openForReading(path): + """ + @param path: The path, as a list of segments, to open + @type path: C{list} of C{unicode} + + @rtype: C{Deferred} which will fire with L{IReadFile} + """ + + + def openForWriting(path): + """ + @param path: The path, as a list of segments, to open + @type path: C{list} of C{unicode} + + @rtype: C{Deferred} which will fire with L{IWriteFile} + """ + + + +class IReadFile(Interface): + """ + A file out of which bytes may be read. + """ + + def send(consumer): + """ + Produce the contents of the given path to the given consumer. This + method may only be invoked once on each provider. + + @type consumer: C{IConsumer} + + @return: A Deferred which fires when the file has been + consumed completely. + """ + + + +class IWriteFile(Interface): + """ + A file into which bytes may be written. + """ + + def receive(): + """ + Create a consumer which will write to this file. This method may + only be invoked once on each provider. + + @rtype: C{Deferred} of C{IConsumer} + """ + + def close(): + """ + Perform any post-write work that needs to be done. This method may + only be invoked once on each provider, and will always be invoked + after receive(). + + @rtype: C{Deferred} of anything: the value is ignored. The FTP client + will not see their upload request complete until this Deferred has + been fired. + """ + + + +def _getgroups(uid): + """ + Return the primary and supplementary groups for the given UID. + + @type uid: C{int} + """ + result = [] + pwent = pwd.getpwuid(uid) + + result.append(pwent.pw_gid) + + for grent in grp.getgrall(): + if pwent.pw_name in grent.gr_mem: + result.append(grent.gr_gid) + + return result + + + +def _testPermissions(uid, gid, spath, mode='r'): + """ + checks to see if uid has proper permissions to access path with mode + + @type uid: C{int} + @param uid: numeric user id + + @type gid: C{int} + @param gid: numeric group id + + @type spath: C{str} + @param spath: the path on the server to test + + @type mode: C{str} + @param mode: 'r' or 'w' (read or write) + + @rtype: C{bool} + @return: True if the given credentials have the specified form of + access to the given path + """ + if mode == 'r': + usr = stat.S_IRUSR + grp = stat.S_IRGRP + oth = stat.S_IROTH + amode = os.R_OK + elif mode == 'w': + usr = stat.S_IWUSR + grp = stat.S_IWGRP + oth = stat.S_IWOTH + amode = os.W_OK + else: + raise ValueError("Invalid mode %r: must specify 'r' or 'w'" % (mode,)) + + access = False + if os.path.exists(spath): + if uid == 0: + access = True + else: + s = os.stat(spath) + if usr & s.st_mode and uid == s.st_uid: + access = True + elif grp & s.st_mode and gid in _getgroups(uid): + access = True + elif oth & s.st_mode: + access = True + + if access: + if not os.access(spath, amode): + access = False + log.msg( + "Filesystem grants permission to UID %d but it is " + "inaccessible to me running as UID %d" % ( + uid, os.getuid())) + return access + + + +@implementer(IFTPShell) +class FTPAnonymousShell(object): + """ + An anonymous implementation of IFTPShell + + @type filesystemRoot: L{twisted.python.filepath.FilePath} + @ivar filesystemRoot: The path which is considered the root of + this shell. + """ + def __init__(self, filesystemRoot): + self.filesystemRoot = filesystemRoot + + + def _path(self, path): + return self.filesystemRoot.descendant(path) + + + def makeDirectory(self, path): + return defer.fail(AnonUserDeniedError()) + + + def removeDirectory(self, path): + return defer.fail(AnonUserDeniedError()) + + + def removeFile(self, path): + return defer.fail(AnonUserDeniedError()) + + + def rename(self, fromPath, toPath): + return defer.fail(AnonUserDeniedError()) + + + def receive(self, path): + path = self._path(path) + return defer.fail(AnonUserDeniedError()) + + + def openForReading(self, path): + """ + Open C{path} for reading. + + @param path: The path, as a list of segments, to open. + @type path: C{list} of C{unicode} + @return: A L{Deferred} is returned that will fire with an object + implementing L{IReadFile} if the file is successfully opened. If + C{path} is a directory, or if an exception is raised while trying + to open the file, the L{Deferred} will fire with an error. + """ + p = self._path(path) + if p.isdir(): + # Normally, we would only check for EISDIR in open, but win32 + # returns EACCES in this case, so we check before + return defer.fail(IsADirectoryError(path)) + try: + f = p.open('r') + except (IOError, OSError) as e: + return errnoToFailure(e.errno, path) + except: + return defer.fail() + else: + return defer.succeed(_FileReader(f)) + + + def openForWriting(self, path): + """ + Reject write attempts by anonymous users with + L{PermissionDeniedError}. + """ + return defer.fail(PermissionDeniedError("STOR not allowed")) + + + def access(self, path): + p = self._path(path) + if not p.exists(): + # Again, win32 doesn't report a sane error after, so let's fail + # early if we can + return defer.fail(FileNotFoundError(path)) + # For now, just see if we can os.listdir() it + try: + p.listdir() + except (IOError, OSError) as e: + return errnoToFailure(e.errno, path) + except: + return defer.fail() + else: + return defer.succeed(None) + + + def stat(self, path, keys=()): + p = self._path(path) + if p.isdir(): + try: + statResult = self._statNode(p, keys) + except (IOError, OSError) as e: + return errnoToFailure(e.errno, path) + except: + return defer.fail() + else: + return defer.succeed(statResult) + else: + return self.list(path, keys).addCallback(lambda res: res[0][1]) + + + def list(self, path, keys=()): + """ + Return the list of files at given C{path}, adding C{keys} stat + informations if specified. + + @param path: the directory or file to check. + @type path: C{str} + + @param keys: the list of desired metadata + @type keys: C{list} of C{str} + """ + filePath = self._path(path) + if filePath.isdir(): + entries = filePath.listdir() + fileEntries = [filePath.child(p) for p in entries] + elif filePath.isfile(): + entries = [ + os.path.join(*filePath.segmentsFrom(self.filesystemRoot))] + fileEntries = [filePath] + else: + return defer.fail(FileNotFoundError(path)) + + results = [] + for fileName, filePath in zip(entries, fileEntries): + ent = [] + results.append((fileName, ent)) + if keys: + try: + ent.extend(self._statNode(filePath, keys)) + except (IOError, OSError) as e: + return errnoToFailure(e.errno, fileName) + except: + return defer.fail() + + return defer.succeed(results) + + + def _statNode(self, filePath, keys): + """ + Shortcut method to get stat info on a node. + + @param filePath: the node to stat. + @type filePath: C{filepath.FilePath} + + @param keys: the stat keys to get. + @type keys: C{iterable} + """ + filePath.restat() + return [getattr(self, '_stat_' + k)(filePath) for k in keys] + + + def _stat_size(self, fp): + """ + Get the filepath's size as an int + + @param fp: L{twisted.python.filepath.FilePath} + @return: C{int} representing the size + """ + return fp.getsize() + + + def _stat_permissions(self, fp): + """ + Get the filepath's permissions object + + @param fp: L{twisted.python.filepath.FilePath} + @return: L{twisted.python.filepath.Permissions} of C{fp} + """ + return fp.getPermissions() + + + def _stat_hardlinks(self, fp): + """ + Get the number of hardlinks for the filepath - if the number of + hardlinks is not yet implemented (say in Windows), just return 0 since + stat-ing a file in Windows seems to return C{st_nlink=0}. + + (Reference: + U{http://stackoverflow.com/questions/5275731/os-stat-on-windows}) + + @param fp: L{twisted.python.filepath.FilePath} + @return: C{int} representing the number of hardlinks + """ + try: + return fp.getNumberOfHardLinks() + except NotImplementedError: + return 0 + + + def _stat_modified(self, fp): + """ + Get the filepath's last modified date + + @param fp: L{twisted.python.filepath.FilePath} + @return: C{int} as seconds since the epoch + """ + return fp.getModificationTime() + + + def _stat_owner(self, fp): + """ + Get the filepath's owner's username. If this is not implemented + (say in Windows) return the string "0" since stat-ing a file in + Windows seems to return C{st_uid=0}. + + (Reference: + U{http://stackoverflow.com/questions/5275731/os-stat-on-windows}) + + @param fp: L{twisted.python.filepath.FilePath} + @return: C{str} representing the owner's username + """ + try: + userID = fp.getUserID() + except NotImplementedError: + return "0" + else: + if pwd is not None: + try: + return pwd.getpwuid(userID)[0] + except KeyError: + pass + return str(userID) + + + def _stat_group(self, fp): + """ + Get the filepath's owner's group. If this is not implemented + (say in Windows) return the string "0" since stat-ing a file in + Windows seems to return C{st_gid=0}. + + (Reference: + U{http://stackoverflow.com/questions/5275731/os-stat-on-windows}) + + @param fp: L{twisted.python.filepath.FilePath} + @return: C{str} representing the owner's group + """ + try: + groupID = fp.getGroupID() + except NotImplementedError: + return "0" + else: + if grp is not None: + try: + return grp.getgrgid(groupID)[0] + except KeyError: + pass + return str(groupID) + + + def _stat_directory(self, fp): + """ + Get whether the filepath is a directory + + @param fp: L{twisted.python.filepath.FilePath} + @return: C{bool} + """ + return fp.isdir() + + + +@implementer(IReadFile) +class _FileReader(object): + def __init__(self, fObj): + self.fObj = fObj + self._send = False + + def _close(self, passthrough): + self._send = True + self.fObj.close() + return passthrough + + def send(self, consumer): + assert not self._send, ( + "Can only call IReadFile.send *once* per instance") + self._send = True + d = basic.FileSender().beginFileTransfer(self.fObj, consumer) + d.addBoth(self._close) + return d + + + +class FTPShell(FTPAnonymousShell): + """ + An authenticated implementation of L{IFTPShell}. + """ + + def makeDirectory(self, path): + p = self._path(path) + try: + p.makedirs() + except (IOError, OSError) as e: + return errnoToFailure(e.errno, path) + except: + return defer.fail() + else: + return defer.succeed(None) + + + def removeDirectory(self, path): + p = self._path(path) + if p.isfile(): + # Win32 returns the wrong errno when rmdir is called on a file + # instead of a directory, so as we have the info here, let's fail + # early with a pertinent error + return defer.fail(IsNotADirectoryError(path)) + try: + os.rmdir(p.path) + except (IOError, OSError) as e: + return errnoToFailure(e.errno, path) + except: + return defer.fail() + else: + return defer.succeed(None) + + + def removeFile(self, path): + p = self._path(path) + if p.isdir(): + # Win32 returns the wrong errno when remove is called on a + # directory instead of a file, so as we have the info here, + # let's fail early with a pertinent error + return defer.fail(IsADirectoryError(path)) + try: + p.remove() + except (IOError, OSError) as e: + return errnoToFailure(e.errno, path) + except: + return defer.fail() + else: + return defer.succeed(None) + + + def rename(self, fromPath, toPath): + fp = self._path(fromPath) + tp = self._path(toPath) + try: + os.rename(fp.path, tp.path) + except (IOError, OSError) as e: + return errnoToFailure(e.errno, fromPath) + except: + return defer.fail() + else: + return defer.succeed(None) + + + def openForWriting(self, path): + """ + Open C{path} for writing. + + @param path: The path, as a list of segments, to open. + @type path: C{list} of C{unicode} + @return: A L{Deferred} is returned that will fire with an object + implementing L{IWriteFile} if the file is successfully opened. If + C{path} is a directory, or if an exception is raised while trying + to open the file, the L{Deferred} will fire with an error. + """ + p = self._path(path) + if p.isdir(): + # Normally, we would only check for EISDIR in open, but win32 + # returns EACCES in this case, so we check before + return defer.fail(IsADirectoryError(path)) + try: + fObj = p.open('w') + except (IOError, OSError) as e: + return errnoToFailure(e.errno, path) + except: + return defer.fail() + return defer.succeed(_FileWriter(fObj)) + + + +@implementer(IWriteFile) +class _FileWriter(object): + def __init__(self, fObj): + self.fObj = fObj + self._receive = False + + def receive(self): + assert not self._receive, ( + "Can only call IWriteFile.receive *once* per instance") + self._receive = True + # FileConsumer will close the file object + return defer.succeed(FileConsumer(self.fObj)) + + def close(self): + return defer.succeed(None) + + + +@implementer(portal.IRealm) +class BaseFTPRealm: + """ + Base class for simple FTP realms which provides an easy hook for specifying + the home directory for each user. + """ + def __init__(self, anonymousRoot): + self.anonymousRoot = filepath.FilePath(anonymousRoot) + + + def getHomeDirectory(self, avatarId): + """ + Return a L{FilePath} representing the home directory of the given + avatar. Override this in a subclass. + + @param avatarId: A user identifier returned from a credentials checker. + @type avatarId: C{str} + + @rtype: L{FilePath} + """ + raise NotImplementedError( + "%r did not override getHomeDirectory" % (self.__class__,)) + + + def requestAvatar(self, avatarId, mind, *interfaces): + for iface in interfaces: + if iface is IFTPShell: + if avatarId is checkers.ANONYMOUS: + avatar = FTPAnonymousShell(self.anonymousRoot) + else: + avatar = FTPShell(self.getHomeDirectory(avatarId)) + return (IFTPShell, avatar, + getattr(avatar, 'logout', lambda: None)) + raise NotImplementedError( + "Only IFTPShell interface is supported by this realm") + + + +class FTPRealm(BaseFTPRealm): + """ + @type anonymousRoot: L{twisted.python.filepath.FilePath} + @ivar anonymousRoot: Root of the filesystem to which anonymous + users will be granted access. + + @type userHome: L{filepath.FilePath} + @ivar userHome: Root of the filesystem containing user home directories. + """ + def __init__(self, anonymousRoot, userHome='/home'): + BaseFTPRealm.__init__(self, anonymousRoot) + self.userHome = filepath.FilePath(userHome) + + + def getHomeDirectory(self, avatarId): + """ + Use C{avatarId} as a single path segment to construct a child of + C{self.userHome} and return that child. + """ + return self.userHome.child(avatarId) + + + +class SystemFTPRealm(BaseFTPRealm): + """ + L{SystemFTPRealm} uses system user account information to decide what the + home directory for a particular avatarId is. + + This works on POSIX but probably is not reliable on Windows. + """ + def getHomeDirectory(self, avatarId): + """ + Return the system-defined home directory of the system user account + with the name C{avatarId}. + """ + path = os.path.expanduser('~' + avatarId) + if path.startswith('~'): + raise cred_error.UnauthorizedLogin() + return filepath.FilePath(path) + + + +# --- FTP CLIENT ------------------------------------------------------------- + +#### +# And now for the client... + +# Notes: +# * Reference: http://cr.yp.to/ftp.html +# * FIXME: Does not support pipelining (which is not supported by all +# servers anyway). This isn't a functionality limitation, just a +# small performance issue. +# * Only has a rudimentary understanding of FTP response codes (although +# the full response is passed to the caller if they so choose). +# * Assumes that USER and PASS should always be sent +# * Always sets TYPE I (binary mode) +# * Doesn't understand any of the weird, obscure TELNET stuff (\377...) +# * FIXME: Doesn't share any code with the FTPServer + +class ConnectionLost(FTPError): + pass + + + +class CommandFailed(FTPError): + pass + + + +class BadResponse(FTPError): + pass + + + +class UnexpectedResponse(FTPError): + pass + + + +class UnexpectedData(FTPError): + pass + + + +class FTPCommand: + def __init__(self, text=None, public=0): + self.text = text + self.deferred = defer.Deferred() + self.ready = 1 + self.public = public + self.transferDeferred = None + + def fail(self, failure): + if self.public: + self.deferred.errback(failure) + + + +class ProtocolWrapper(protocol.Protocol): + def __init__(self, original, deferred): + self.original = original + self.deferred = deferred + + def makeConnection(self, transport): + self.original.makeConnection(transport) + + def dataReceived(self, data): + self.original.dataReceived(data) + + def connectionLost(self, reason): + self.original.connectionLost(reason) + # Signal that transfer has completed + self.deferred.callback(None) + + + +class IFinishableConsumer(interfaces.IConsumer): + """ + A Consumer for producers that finish. + + @since: 11.0 + """ + + def finish(): + """ + The producer has finished producing. + """ + + + +@implementer(IFinishableConsumer) +class SenderProtocol(protocol.Protocol): + def __init__(self): + # Fired upon connection + self.connectedDeferred = defer.Deferred() + + # Fired upon disconnection + self.deferred = defer.Deferred() + + # Protocol stuff + def dataReceived(self, data): + raise UnexpectedData( + "Received data from the server on a " + "send-only data-connection" + ) + + def makeConnection(self, transport): + protocol.Protocol.makeConnection(self, transport) + self.connectedDeferred.callback(self) + + def connectionLost(self, reason): + if reason.check(error.ConnectionDone): + self.deferred.callback('connection done') + else: + self.deferred.errback(reason) + + # IFinishableConsumer stuff + def write(self, data): + self.transport.write(data) + + def registerProducer(self, producer, streaming): + """ + Register the given producer with our transport. + """ + self.transport.registerProducer(producer, streaming) + + def unregisterProducer(self): + """ + Unregister the previously registered producer. + """ + self.transport.unregisterProducer() + + def finish(self): + self.transport.loseConnection() + + + +def decodeHostPort(line): + """ + Decode an FTP response specifying a host and port. + + @return: a 2-tuple of (host, port). + """ + abcdef = re.sub('[^0-9, ]', '', line) + parsed = [int(p.strip()) for p in abcdef.split(',')] + for x in parsed: + if x < 0 or x > 255: + raise ValueError("Out of range", line, x) + a, b, c, d, e, f = parsed + host = "%s.%s.%s.%s" % (a, b, c, d) + port = (int(e) << 8) + int(f) + return host, port + + + +def encodeHostPort(host, port): + numbers = host.split('.') + [str(port >> 8), str(port % 256)] + return ','.join(numbers) + + + +def _unwrapFirstError(failure): + failure.trap(defer.FirstError) + return failure.value.subFailure + + + +class FTPDataPortFactory(protocol.ServerFactory): + """ + Factory for data connections that use the PORT command + + (i.e. "active" transfers) + """ + noisy = 0 + + def buildProtocol(self, addr): + # This is a bit hackish -- we already have a Protocol instance, + # so just return it instead of making a new one + # FIXME: Reject connections from the wrong address/port + # (potential security problem) + self.protocol.factory = self + self.port.loseConnection() + return self.protocol + + + +class FTPClientBasic(basic.LineReceiver): + """ + Foundations of an FTP client. + """ + debug = False + _encoding = 'latin-1' + + def __init__(self): + self.actionQueue = [] + self.greeting = None + self.nextDeferred = defer.Deferred().addCallback(self._cb_greeting) + self.nextDeferred.addErrback(self.fail) + self.response = [] + self._failed = 0 + + def fail(self, error): + """ + Give an error to any queued deferreds. + """ + self._fail(error) + + def _fail(self, error): + """ + Errback all queued deferreds. + """ + if self._failed: + # We're recursing; bail out here for simplicity + return error + self._failed = 1 + if self.nextDeferred: + try: + self.nextDeferred.errback(failure.Failure( + ConnectionLost('FTP connection lost', error))) + except defer.AlreadyCalledError: + pass + for ftpCommand in self.actionQueue: + ftpCommand.fail(failure.Failure( + ConnectionLost('FTP connection lost', error))) + return error + + def _cb_greeting(self, greeting): + self.greeting = greeting + + + def sendLine(self, line): + """ + Sends a line, unless line is None. + + @param line: Line to send + @type line: L{bytes} or L{unicode} + """ + if line is None: + return + elif isinstance(line, unicode): + line = line.encode(self._encoding) + basic.LineReceiver.sendLine(self, line) + + + def sendNextCommand(self): + """ + (Private) Processes the next command in the queue. + """ + ftpCommand = self.popCommandQueue() + if ftpCommand is None: + self.nextDeferred = None + return + if not ftpCommand.ready: + self.actionQueue.insert(0, ftpCommand) + reactor.callLater(1.0, self.sendNextCommand) + self.nextDeferred = None + return + + # FIXME: this if block doesn't belong in FTPClientBasic, it belongs in + # FTPClient. + if ftpCommand.text == 'PORT': + self.generatePortCommand(ftpCommand) + + if self.debug: + log.msg('<-- %s' % ftpCommand.text) + self.nextDeferred = ftpCommand.deferred + self.sendLine(ftpCommand.text) + + def queueCommand(self, ftpCommand): + """ + Add an FTPCommand object to the queue. + + If it's the only thing in the queue, and we are connected and we aren't + waiting for a response of an earlier command, the command will be sent + immediately. + + @param ftpCommand: an L{FTPCommand} + """ + self.actionQueue.append(ftpCommand) + if (len(self.actionQueue) == 1 and self.transport is not None and + self.nextDeferred is None): + self.sendNextCommand() + + def queueStringCommand(self, command, public=1): + """ + Queues a string to be issued as an FTP command + + @param command: string of an FTP command to queue + @param public: a flag intended for internal use by FTPClient. Don't + change it unless you know what you're doing. + + @return: a L{Deferred} that will be called when the response to the + command has been received. + """ + ftpCommand = FTPCommand(command, public) + self.queueCommand(ftpCommand) + return ftpCommand.deferred + + def popCommandQueue(self): + """ + Return the front element of the command queue, or None if empty. + """ + if self.actionQueue: + return self.actionQueue.pop(0) + else: + return None + + def queueLogin(self, username, password): + """ + Login: send the username, send the password. + + If the password is L{None}, the PASS command won't be sent. Also, if + the response to the USER command has a response code of 230 (User + logged in), then PASS won't be sent either. + """ + # Prepare the USER command + deferreds = [] + userDeferred = self.queueStringCommand('USER ' + username, public=0) + deferreds.append(userDeferred) + + # Prepare the PASS command (if a password is given) + if password is not None: + passwordCmd = FTPCommand('PASS ' + password, public=0) + self.queueCommand(passwordCmd) + deferreds.append(passwordCmd.deferred) + + # Avoid sending PASS if the response to USER is 230. + # (ref: http://cr.yp.to/ftp/user.html#user) + def cancelPasswordIfNotNeeded(response): + if response[0].startswith('230'): + # No password needed! + self.actionQueue.remove(passwordCmd) + return response + userDeferred.addCallback(cancelPasswordIfNotNeeded) + + # Error handling. + for deferred in deferreds: + # If something goes wrong, call fail + deferred.addErrback(self.fail) + # But also swallow the error, so we don't cause spurious errors + deferred.addErrback(lambda x: None) + + def lineReceived(self, line): + """ + (Private) Parses the response messages from the FTP server. + """ + # Add this line to the current response + if bytes != str: + line = line.decode(self._encoding) + + if self.debug: + log.msg('--> %s' % line) + self.response.append(line) + + # Bail out if this isn't the last line of a response + # The last line of response starts with 3 digits followed by a space + codeIsValid = re.match(r'\d{3} ', line) + if not codeIsValid: + return + + code = line[0:3] + + # Ignore marks + if code[0] == '1': + return + + # Check that we were expecting a response + if self.nextDeferred is None: + self.fail(UnexpectedResponse(self.response)) + return + + # Reset the response + response = self.response + self.response = [] + + # Look for a success or error code, and call the appropriate callback + if code[0] in ('2', '3'): + # Success + self.nextDeferred.callback(response) + elif code[0] in ('4', '5'): + # Failure + self.nextDeferred.errback(failure.Failure(CommandFailed(response))) + else: + # This shouldn't happen unless something screwed up. + log.msg('Server sent invalid response code %s' % (code,)) + self.nextDeferred.errback(failure.Failure(BadResponse(response))) + + # Run the next command + self.sendNextCommand() + + def connectionLost(self, reason): + self._fail(reason) + + + +class _PassiveConnectionFactory(protocol.ClientFactory): + noisy = False + + def __init__(self, protoInstance): + self.protoInstance = protoInstance + + def buildProtocol(self, ignored): + self.protoInstance.factory = self + return self.protoInstance + + def clientConnectionFailed(self, connector, reason): + e = FTPError('Connection Failed', reason) + self.protoInstance.deferred.errback(e) + + + +class FTPClient(FTPClientBasic): + """ + L{FTPClient} is a client implementation of the FTP protocol which + exposes FTP commands as methods which return L{Deferred}s. + + Each command method returns a L{Deferred} which is called back when a + successful response code (2xx or 3xx) is received from the server or + which is error backed if an error response code (4xx or 5xx) is received + from the server or if a protocol violation occurs. If an error response + code is received, the L{Deferred} fires with a L{Failure} wrapping a + L{CommandFailed} instance. The L{CommandFailed} instance is created + with a list of the response lines received from the server. + + See U{RFC 959<http://www.ietf.org/rfc/rfc959.txt>} for error code + definitions. + + Both active and passive transfers are supported. + + @ivar passive: See description in __init__. + """ + connectFactory = reactor.connectTCP + + def __init__(self, username='anonymous', + password='twisted@twistedmatrix.com', + passive=1): + """ + Constructor. + + I will login as soon as I receive the welcome message from the server. + + @param username: FTP username + @param password: FTP password + @param passive: flag that controls if I use active or passive data + connections. You can also change this after construction by + assigning to C{self.passive}. + """ + FTPClientBasic.__init__(self) + self.queueLogin(username, password) + + self.passive = passive + + def fail(self, error): + """ + Disconnect, and also give an error to any queued deferreds. + """ + self.transport.loseConnection() + self._fail(error) + + def receiveFromConnection(self, commands, protocol): + """ + Retrieves a file or listing generated by the given command, + feeding it to the given protocol. + + @param commands: list of strings of FTP commands to execute then + receive the results of (e.g. C{LIST}, C{RETR}) + @param protocol: A L{Protocol} B{instance} e.g. an + L{FTPFileListProtocol}, or something that can be adapted to one. + Typically this will be an L{IConsumer} implementation. + + @return: L{Deferred}. + """ + protocol = interfaces.IProtocol(protocol) + wrapper = ProtocolWrapper(protocol, defer.Deferred()) + return self._openDataConnection(commands, wrapper) + + def queueLogin(self, username, password): + """ + Login: send the username, send the password, and + set retrieval mode to binary + """ + FTPClientBasic.queueLogin(self, username, password) + d = self.queueStringCommand('TYPE I', public=0) + # If something goes wrong, call fail + d.addErrback(self.fail) + # But also swallow the error, so we don't cause spurious errors + d.addErrback(lambda x: None) + + def sendToConnection(self, commands): + """ + XXX + + @return: A tuple of two L{Deferred}s: + - L{Deferred} L{IFinishableConsumer}. You must call + the C{finish} method on the IFinishableConsumer when the + file is completely transferred. + - L{Deferred} list of control-connection responses. + """ + s = SenderProtocol() + r = self._openDataConnection(commands, s) + return (s.connectedDeferred, r) + + def _openDataConnection(self, commands, protocol): + """ + This method returns a DeferredList. + """ + cmds = [FTPCommand(command, public=1) for command in commands] + cmdsDeferred = defer.DeferredList( + [cmd.deferred for cmd in cmds], + fireOnOneErrback=True, consumeErrors=True) + cmdsDeferred.addErrback(_unwrapFirstError) + + if self.passive: + # Hack: use a mutable object to sneak a variable out of the + # scope of doPassive + _mutable = [None] + + def doPassive(response): + """Connect to the port specified in the response to PASV""" + host, port = decodeHostPort(response[-1][4:]) + + f = _PassiveConnectionFactory(protocol) + _mutable[0] = self.connectFactory(host, port, f) + + pasvCmd = FTPCommand('PASV') + self.queueCommand(pasvCmd) + pasvCmd.deferred.addCallback(doPassive).addErrback(self.fail) + + results = [cmdsDeferred, pasvCmd.deferred, protocol.deferred] + d = defer.DeferredList( + results, fireOnOneErrback=True, consumeErrors=True) + d.addErrback(_unwrapFirstError) + + # Ensure the connection is always closed + def close(x, m=_mutable): + m[0] and m[0].disconnect() + return x + d.addBoth(close) + + else: + # We just place a marker command in the queue, and will fill in + # the host and port numbers later (see generatePortCommand) + portCmd = FTPCommand('PORT') + + # Ok, now we jump through a few hoops here. + # This is the problem: a transfer is not to be trusted as complete + # until we get both the "226 Transfer complete" message on the + # control connection, and the data socket is closed. Thus, we use + # a DeferredList to make sure we only fire the callback at the + # right time. + + portCmd.transferDeferred = protocol.deferred + portCmd.protocol = protocol + portCmd.deferred.addErrback(portCmd.transferDeferred.errback) + self.queueCommand(portCmd) + + # Create dummy functions for the next callback to call. + # These will also be replaced with real functions in + # generatePortCommand. + portCmd.loseConnection = lambda result: result + portCmd.fail = lambda error: error + + # Ensure that the connection always gets closed + cmdsDeferred.addErrback(lambda e, pc=portCmd: pc.fail(e) or e) + + results = [ + cmdsDeferred, portCmd.deferred, portCmd.transferDeferred] + d = defer.DeferredList( + results, fireOnOneErrback=True, consumeErrors=True) + d.addErrback(_unwrapFirstError) + + for cmd in cmds: + self.queueCommand(cmd) + return d + + def generatePortCommand(self, portCmd): + """ + (Private) Generates the text of a given PORT command. + """ + + # The problem is that we don't create the listening port until we need + # it for various reasons, and so we have to muck about to figure out + # what interface and port it's listening on, and then finally we can + # create the text of the PORT command to send to the FTP server. + + # FIXME: This method is far too ugly. + + # FIXME: The best solution is probably to only create the data port + # once per FTPClient, and just recycle it for each new download. + # This should be ok, because we don't pipeline commands. + + # Start listening on a port + factory = FTPDataPortFactory() + factory.protocol = portCmd.protocol + listener = reactor.listenTCP(0, factory) + factory.port = listener + + # Ensure we close the listening port if something goes wrong + def listenerFail(error, listener=listener): + if listener.connected: + listener.loseConnection() + return error + portCmd.fail = listenerFail + + # Construct crufty FTP magic numbers that represent host & port + host = self.transport.getHost().host + port = listener.getHost().port + portCmd.text = 'PORT ' + encodeHostPort(host, port) + + def escapePath(self, path): + """ + Returns a FTP escaped path (replace newlines with nulls). + """ + # Escape newline characters + return path.replace('\n', '\0') + + def retrieveFile(self, path, protocol, offset=0): + """ + Retrieve a file from the given path + + This method issues the 'RETR' FTP command. + + The file is fed into the given Protocol instance. The data connection + will be passive if self.passive is set. + + @param path: path to file that you wish to receive. + @param protocol: a L{Protocol} instance. + @param offset: offset to start downloading from + + @return: L{Deferred} + """ + cmds = ['RETR ' + self.escapePath(path)] + if offset: + cmds.insert(0, ('REST ' + str(offset))) + return self.receiveFromConnection(cmds, protocol) + + retr = retrieveFile + + def storeFile(self, path, offset=0): + """ + Store a file at the given path. + + This method issues the 'STOR' FTP command. + + @return: A tuple of two L{Deferred}s: + - L{Deferred} L{IFinishableConsumer}. You must call + the C{finish} method on the IFinishableConsumer when the + file is completely transferred. + - L{Deferred} list of control-connection responses. + """ + cmds = ['STOR ' + self.escapePath(path)] + if offset: + cmds.insert(0, ('REST ' + str(offset))) + return self.sendToConnection(cmds) + + stor = storeFile + + + def rename(self, pathFrom, pathTo): + """ + Rename a file. + + This method issues the I{RNFR}/I{RNTO} command sequence to rename + C{pathFrom} to C{pathTo}. + + @param: pathFrom: the absolute path to the file to be renamed + @type pathFrom: C{str} + + @param: pathTo: the absolute path to rename the file to. + @type pathTo: C{str} + + @return: A L{Deferred} which fires when the rename operation has + succeeded or failed. If it succeeds, the L{Deferred} is called + back with a two-tuple of lists. The first list contains the + responses to the I{RNFR} command. The second list contains the + responses to the I{RNTO} command. If either I{RNFR} or I{RNTO} + fails, the L{Deferred} is errbacked with L{CommandFailed} or + L{BadResponse}. + @rtype: L{Deferred} + + @since: 8.2 + """ + renameFrom = self.queueStringCommand( + 'RNFR ' + self.escapePath(pathFrom)) + renameTo = self.queueStringCommand('RNTO ' + self.escapePath(pathTo)) + + fromResponse = [] + + # Use a separate Deferred for the ultimate result so that Deferred + # chaining can't interfere with its result. + result = defer.Deferred() + # Bundle up all the responses + result.addCallback(lambda toResponse: (fromResponse, toResponse)) + + def ebFrom(failure): + # Make sure the RNTO doesn't run if the RNFR failed. + self.popCommandQueue() + result.errback(failure) + + # Save the RNFR response to pass to the result Deferred later + renameFrom.addCallbacks(fromResponse.extend, ebFrom) + + # Hook up the RNTO to the result Deferred as well + renameTo.chainDeferred(result) + + return result + + + def list(self, path, protocol): + """ + Retrieve a file listing into the given protocol instance. + + This method issues the 'LIST' FTP command. + + @param path: path to get a file listing for. + @param protocol: a L{Protocol} instance, probably a + L{FTPFileListProtocol} instance. It can cope with most common file + listing formats. + + @return: L{Deferred} + """ + if path is None: + path = '' + return self.receiveFromConnection( + ['LIST ' + self.escapePath(path)], protocol) + + + def nlst(self, path, protocol): + """ + Retrieve a short file listing into the given protocol instance. + + This method issues the 'NLST' FTP command. + + NLST (should) return a list of filenames, one per line. + + @param path: path to get short file listing for. + @param protocol: a L{Protocol} instance. + """ + if path is None: + path = '' + return self.receiveFromConnection( + ['NLST ' + self.escapePath(path)], protocol) + + + def cwd(self, path): + """ + Issues the CWD (Change Working Directory) command. + + @return: a L{Deferred} that will be called when done. + """ + return self.queueStringCommand('CWD ' + self.escapePath(path)) + + + def makeDirectory(self, path): + """ + Make a directory + + This method issues the MKD command. + + @param path: The path to the directory to create. + @type path: C{str} + + @return: A L{Deferred} which fires when the server responds. If the + directory is created, the L{Deferred} is called back with the + server response. If the server response indicates the directory + was not created, the L{Deferred} is errbacked with a L{Failure} + wrapping L{CommandFailed} or L{BadResponse}. + @rtype: L{Deferred} + + @since: 8.2 + """ + return self.queueStringCommand('MKD ' + self.escapePath(path)) + + + def removeFile(self, path): + """ + Delete a file on the server. + + L{removeFile} issues a I{DELE} command to the server to remove the + indicated file. Note that this command cannot remove a directory. + + @param path: The path to the file to delete. May be relative to the + current dir. + @type path: C{str} + + @return: A L{Deferred} which fires when the server responds. On error, + it is errbacked with either L{CommandFailed} or L{BadResponse}. On + success, it is called back with a list of response lines. + @rtype: L{Deferred} + + @since: 8.2 + """ + return self.queueStringCommand('DELE ' + self.escapePath(path)) + + + def removeDirectory(self, path): + """ + Delete a directory on the server. + + L{removeDirectory} issues a I{RMD} command to the server to remove the + indicated directory. Described in RFC959. + + @param path: The path to the directory to delete. May be relative to + the current working directory. + @type path: C{str} + + @return: A L{Deferred} which fires when the server responds. On error, + it is errbacked with either L{CommandFailed} or L{BadResponse}. On + success, it is called back with a list of response lines. + @rtype: L{Deferred} + + @since: 11.1 + """ + return self.queueStringCommand('RMD ' + self.escapePath(path)) + + + def cdup(self): + """ + Issues the CDUP (Change Directory UP) command. + + @return: a L{Deferred} that will be called when done. + """ + return self.queueStringCommand('CDUP') + + + def pwd(self): + """ + Issues the PWD (Print Working Directory) command. + + The L{getDirectory} does the same job but automatically parses the + result. + + @return: a L{Deferred} that will be called when done. It is up to the + caller to interpret the response, but the L{parsePWDResponse} + method in this module should work. + """ + return self.queueStringCommand('PWD') + + + def getDirectory(self): + """ + Returns the current remote directory. + + @return: a L{Deferred} that will be called back with a C{str} giving + the remote directory or which will errback with L{CommandFailed} + if an error response is returned. + """ + def cbParse(result): + try: + # The only valid code is 257 + if int(result[0].split(' ', 1)[0]) != 257: + raise ValueError + except (IndexError, ValueError): + return failure.Failure(CommandFailed(result)) + path = parsePWDResponse(result[0]) + if path is None: + return failure.Failure(CommandFailed(result)) + return path + return self.pwd().addCallback(cbParse) + + + def quit(self): + """ + Issues the I{QUIT} command. + + @return: A L{Deferred} that fires when the server acknowledges the + I{QUIT} command. The transport should not be disconnected until + this L{Deferred} fires. + """ + return self.queueStringCommand('QUIT') + + + +class FTPFileListProtocol(basic.LineReceiver): + """ + Parser for standard FTP file listings + + This is the evil required to match:: + + -rw-r--r-- 1 root other 531 Jan 29 03:26 README + + If you need different evil for a wacky FTP server, you can + override either C{fileLinePattern} or C{parseDirectoryLine()}. + + It populates the instance attribute self.files, which is a list containing + dicts with the following keys (examples from the above line): + - filetype: e.g. 'd' for directories, or '-' for an ordinary file + - perms: e.g. 'rw-r--r--' + - nlinks: e.g. 1 + - owner: e.g. 'root' + - group: e.g. 'other' + - size: e.g. 531 + - date: e.g. 'Jan 29 03:26' + - filename: e.g. 'README' + - linktarget: e.g. 'some/file' + + Note that the 'date' value will be formatted differently depending on the + date. Check U{http://cr.yp.to/ftp.html} if you really want to try to parse + it. + + It also matches the following:: + -rw-r--r-- 1 root other 531 Jan 29 03:26 I HAVE\\ SPACE + - filename: e.g. 'I HAVE SPACE' + + -rw-r--r-- 1 root other 531 Jan 29 03:26 LINK -> TARGET + - filename: e.g. 'LINK' + - linktarget: e.g. 'TARGET' + + -rw-r--r-- 1 root other 531 Jan 29 03:26 N S -> L S + - filename: e.g. 'N S' + - linktarget: e.g. 'L S' + + @ivar files: list of dicts describing the files in this listing + """ + fileLinePattern = re.compile( + r'^(?P<filetype>.)(?P<perms>.{9})\s+(?P<nlinks>\d*)\s*' + r'(?P<owner>\S+)\s+(?P<group>\S+)\s+(?P<size>\d+)\s+' + r'(?P<date>...\s+\d+\s+[\d:]+)\s+(?P<filename>.{1,}?)' + r'( -> (?P<linktarget>[^\r]*))?\r?$' + ) + delimiter = b'\n' + _encoding = 'latin-1' + + def __init__(self): + self.files = [] + + def lineReceived(self, line): + if bytes != str: + line = line.decode(self._encoding) + d = self.parseDirectoryLine(line) + if d is None: + self.unknownLine(line) + else: + self.addFile(d) + + def parseDirectoryLine(self, line): + """ + Return a dictionary of fields, or None if line cannot be parsed. + + @param line: line of text expected to contain a directory entry + @type line: str + + @return: dict + """ + match = self.fileLinePattern.match(line) + if match is None: + return None + else: + d = match.groupdict() + d['filename'] = d['filename'].replace(r'\ ', ' ') + d['nlinks'] = int(d['nlinks']) + d['size'] = int(d['size']) + if d['linktarget']: + d['linktarget'] = d['linktarget'].replace(r'\ ', ' ') + return d + + def addFile(self, info): + """ + Append file information dictionary to the list of known files. + + Subclasses can override or extend this method to handle file + information differently without affecting the parsing of data + from the server. + + @param info: dictionary containing the parsed representation + of the file information + @type info: dict + """ + self.files.append(info) + + def unknownLine(self, line): + """ + Deal with received lines which could not be parsed as file + information. + + Subclasses can override this to perform any special processing + needed. + + @param line: unparsable line as received + @type line: str + """ + pass + + + +def parsePWDResponse(response): + """ + Returns the path from a response to a PWD command. + + Responses typically look like:: + + 257 "/home/andrew" is current directory. + + For this example, I will return C{'/home/andrew'}. + + If I can't find the path, I return L{None}. + """ + match = re.search('"(.*)"', response) + if match: + return match.groups()[0] + else: + return None diff --git a/contrib/python/Twisted/py2/twisted/protocols/haproxy/__init__.py b/contrib/python/Twisted/py2/twisted/protocols/haproxy/__init__.py new file mode 100644 index 0000000000..c238b7ac1a --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/haproxy/__init__.py @@ -0,0 +1,13 @@ +# -*- test-case-name: twisted.protocols.haproxy.test -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +HAProxy PROXY protocol implementations. +""" + +from ._wrapper import proxyEndpoint + +__all__ = [ + 'proxyEndpoint', +] diff --git a/contrib/python/Twisted/py2/twisted/protocols/haproxy/_exceptions.py b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_exceptions.py new file mode 100644 index 0000000000..8633ca4eec --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_exceptions.py @@ -0,0 +1,52 @@ +# -*- test-case-name: twisted.protocols.haproxy.test -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +HAProxy specific exceptions. +""" + +import contextlib +import sys + +from twisted.python import compat + + +class InvalidProxyHeader(Exception): + """ + The provided PROXY protocol header is invalid. + """ + + + +class InvalidNetworkProtocol(InvalidProxyHeader): + """ + The network protocol was not one of TCP4 TCP6 or UNKNOWN. + """ + + + +class MissingAddressData(InvalidProxyHeader): + """ + The address data is missing or incomplete. + """ + + + +@contextlib.contextmanager +def convertError(sourceType, targetType): + """ + Convert an error into a different error type. + + @param sourceType: The type of exception that should be caught and + converted. + @type sourceType: L{Exception} + + @param targetType: The type of exception to which the original should be + converted. + @type targetType: L{Exception} + """ + try: + yield None + except sourceType: + compat.reraise(targetType(), sys.exc_info()[-1]) diff --git a/contrib/python/Twisted/py2/twisted/protocols/haproxy/_info.py b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_info.py new file mode 100644 index 0000000000..489d7b2cee --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_info.py @@ -0,0 +1,36 @@ +# -*- test-case-name: twisted.protocols.haproxy.test -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +IProxyInfo implementation. +""" + +from zope.interface import implementer + +from ._interfaces import IProxyInfo + + +@implementer(IProxyInfo) +class ProxyInfo(object): + """ + A data container for parsed PROXY protocol information. + + @ivar header: The raw header bytes extracted from the connection. + @type header: bytes + @ivar source: The connection source address. + @type source: L{twisted.internet.interfaces.IAddress} + @ivar destination: The connection destination address. + @type destination: L{twisted.internet.interfaces.IAddress} + """ + + __slots__ = ( + 'header', + 'source', + 'destination', + ) + + def __init__(self, header, source, destination): + self.header = header + self.source = source + self.destination = destination diff --git a/contrib/python/Twisted/py2/twisted/protocols/haproxy/_interfaces.py b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_interfaces.py new file mode 100644 index 0000000000..3453ecf37b --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_interfaces.py @@ -0,0 +1,64 @@ +# -*- test-case-name: twisted.protocols.haproxy.test -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Interfaces used by the PROXY protocol modules. +""" + +import zope.interface + + +class IProxyInfo(zope.interface.Interface): + """ + Data container for PROXY protocol header data. + """ + + header = zope.interface.Attribute( + "The raw byestring that represents the PROXY protocol header.", + ) + source = zope.interface.Attribute( + "An L{twisted.internet.interfaces.IAddress} representing the " + "connection source." + ) + destination = zope.interface.Attribute( + "An L{twisted.internet.interfaces.IAddress} representing the " + "connection destination." + ) + + + +class IProxyParser(zope.interface.Interface): + """ + Streaming parser that handles PROXY protocol headers. + """ + + def feed(self, data): + """ + Consume a chunk of data and attempt to parse it. + + @param data: A bytestring. + @type data: bytes + + @return: A two-tuple containing, in order, an L{IProxyInfo} and any + bytes fed to the parser that followed the end of the header. Both + of these values are None until a complete header is parsed. + + @raises InvalidProxyHeader: If the bytes fed to the parser create an + invalid PROXY header. + """ + + + def parse(self, line): + """ + Parse a bytestring as a full PROXY protocol header line. + + @param line: A bytestring that represents a valid HAProxy PROXY + protocol header line. + @type line: bytes + + @return: An L{IProxyInfo} containing the parsed data. + + @raises InvalidProxyHeader: If the bytestring does not represent a + valid PROXY header. + """ diff --git a/contrib/python/Twisted/py2/twisted/protocols/haproxy/_parser.py b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_parser.py new file mode 100644 index 0000000000..35ef29fbb7 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_parser.py @@ -0,0 +1,71 @@ +# -*- test-case-name: twisted.protocols.haproxy.test.test_parser -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Parser for 'haproxy:' string endpoint. +""" + +from zope.interface import implementer +from twisted.plugin import IPlugin + +from twisted.internet.endpoints import ( + quoteStringArgument, serverFromString, IStreamServerEndpointStringParser +) +from twisted.python.compat import iteritems + +from . import proxyEndpoint + + +def unparseEndpoint(args, kwargs): + """ + Un-parse the already-parsed args and kwargs back into endpoint syntax. + + @param args: C{:}-separated arguments + @type args: L{tuple} of native L{str} + + @param kwargs: C{:} and then C{=}-separated keyword arguments + + @type arguments: L{tuple} of native L{str} + + @return: a string equivalent to the original format which this was parsed + as. + @rtype: native L{str} + """ + + description = ':'.join( + [quoteStringArgument(str(arg)) for arg in args] + + sorted(['%s=%s' % (quoteStringArgument(str(key)), + quoteStringArgument(str(value))) + for key, value in iteritems(kwargs) + ])) + return description + + + +@implementer(IPlugin, IStreamServerEndpointStringParser) +class HAProxyServerParser(object): + """ + Stream server endpoint string parser for the HAProxyServerEndpoint type. + + @ivar prefix: See L{IStreamServerEndpointStringParser.prefix}. + """ + prefix = "haproxy" + + def parseStreamServer(self, reactor, *args, **kwargs): + """ + Parse a stream server endpoint from a reactor and string-only arguments + and keyword arguments. + + @param reactor: The reactor. + + @param args: The parsed string arguments. + + @param kwargs: The parsed keyword arguments. + + @return: a stream server endpoint + @rtype: L{IStreamServerEndpoint} + """ + subdescription = unparseEndpoint(args, kwargs) + wrappedEndpoint = serverFromString(reactor, subdescription) + return proxyEndpoint(wrappedEndpoint) diff --git a/contrib/python/Twisted/py2/twisted/protocols/haproxy/_v1parser.py b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_v1parser.py new file mode 100644 index 0000000000..b17099f3cc --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_v1parser.py @@ -0,0 +1,143 @@ +# -*- test-case-name: twisted.protocols.haproxy.test.test_v1parser -*- + +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +IProxyParser implementation for version one of the PROXY protocol. +""" + +from zope.interface import implementer +from twisted.internet import address + +from ._exceptions import ( + convertError, InvalidProxyHeader, InvalidNetworkProtocol, + MissingAddressData +) +from . import _info +from . import _interfaces + + + +@implementer(_interfaces.IProxyParser) +class V1Parser(object): + """ + PROXY protocol version one header parser. + + Version one of the PROXY protocol is a human readable format represented + by a single, newline delimited binary string that contains all of the + relevant source and destination data. + """ + + PROXYSTR = b'PROXY' + UNKNOWN_PROTO = b'UNKNOWN' + TCP4_PROTO = b'TCP4' + TCP6_PROTO = b'TCP6' + ALLOWED_NET_PROTOS = ( + TCP4_PROTO, + TCP6_PROTO, + UNKNOWN_PROTO, + ) + NEWLINE = b'\r\n' + + def __init__(self): + self.buffer = b'' + + + def feed(self, data): + """ + Consume a chunk of data and attempt to parse it. + + @param data: A bytestring. + @type data: L{bytes} + + @return: A two-tuple containing, in order, a + L{_interfaces.IProxyInfo} and any bytes fed to the + parser that followed the end of the header. Both of these values + are None until a complete header is parsed. + + @raises InvalidProxyHeader: If the bytes fed to the parser create an + invalid PROXY header. + """ + self.buffer += data + if len(self.buffer) > 107 and self.NEWLINE not in self.buffer: + raise InvalidProxyHeader() + lines = (self.buffer).split(self.NEWLINE, 1) + if not len(lines) > 1: + return (None, None) + self.buffer = b'' + remaining = lines.pop() + header = lines.pop() + info = self.parse(header) + return (info, remaining) + + + @classmethod + def parse(cls, line): + """ + Parse a bytestring as a full PROXY protocol header line. + + @param line: A bytestring that represents a valid HAProxy PROXY + protocol header line. + @type line: bytes + + @return: A L{_interfaces.IProxyInfo} containing the parsed data. + + @raises InvalidProxyHeader: If the bytestring does not represent a + valid PROXY header. + + @raises InvalidNetworkProtocol: When no protocol can be parsed or is + not one of the allowed values. + + @raises MissingAddressData: When the protocol is TCP* but the header + does not contain a complete set of addresses and ports. + """ + originalLine = line + proxyStr = None + networkProtocol = None + sourceAddr = None + sourcePort = None + destAddr = None + destPort = None + + with convertError(ValueError, InvalidProxyHeader): + proxyStr, line = line.split(b' ', 1) + + if proxyStr != cls.PROXYSTR: + raise InvalidProxyHeader() + + with convertError(ValueError, InvalidNetworkProtocol): + networkProtocol, line = line.split(b' ', 1) + + if networkProtocol not in cls.ALLOWED_NET_PROTOS: + raise InvalidNetworkProtocol() + + if networkProtocol == cls.UNKNOWN_PROTO: + + return _info.ProxyInfo(originalLine, None, None) + + with convertError(ValueError, MissingAddressData): + sourceAddr, line = line.split(b' ', 1) + + with convertError(ValueError, MissingAddressData): + destAddr, line = line.split(b' ', 1) + + with convertError(ValueError, MissingAddressData): + sourcePort, line = line.split(b' ', 1) + + with convertError(ValueError, MissingAddressData): + destPort = line.split(b' ')[0] + + if networkProtocol == cls.TCP4_PROTO: + + return _info.ProxyInfo( + originalLine, + address.IPv4Address('TCP', sourceAddr, int(sourcePort)), + address.IPv4Address('TCP', destAddr, int(destPort)), + ) + + return _info.ProxyInfo( + originalLine, + address.IPv6Address('TCP', sourceAddr, int(sourcePort)), + address.IPv6Address('TCP', destAddr, int(destPort)), + ) diff --git a/contrib/python/Twisted/py2/twisted/protocols/haproxy/_v2parser.py b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_v2parser.py new file mode 100644 index 0000000000..94c495ffe2 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_v2parser.py @@ -0,0 +1,215 @@ +# -*- test-case-name: twisted.protocols.haproxy.test.test_v2parser -*- + +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +IProxyParser implementation for version two of the PROXY protocol. +""" + +import binascii +import struct + +from constantly import Values, ValueConstant + +from zope.interface import implementer +from twisted.internet import address +from twisted.python import compat + +from ._exceptions import ( + convertError, InvalidProxyHeader, InvalidNetworkProtocol, + MissingAddressData +) +from . import _info +from . import _interfaces + +class NetFamily(Values): + """ + Values for the 'family' field. + """ + UNSPEC = ValueConstant(0x00) + INET = ValueConstant(0x10) + INET6 = ValueConstant(0x20) + UNIX = ValueConstant(0x30) + + + +class NetProtocol(Values): + """ + Values for 'protocol' field. + """ + UNSPEC = ValueConstant(0) + STREAM = ValueConstant(1) + DGRAM = ValueConstant(2) + + +_HIGH = 0b11110000 +_LOW = 0b00001111 +_LOCALCOMMAND = 'LOCAL' +_PROXYCOMMAND = 'PROXY' + +@implementer(_interfaces.IProxyParser) +class V2Parser(object): + """ + PROXY protocol version two header parser. + + Version two of the PROXY protocol is a binary format. + """ + + PREFIX = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A' + VERSIONS = [32] + COMMANDS = {0: _LOCALCOMMAND, 1: _PROXYCOMMAND} + ADDRESSFORMATS = { + # TCP4 + 17: '!4s4s2H', + 18: '!4s4s2H', + # TCP6 + 33: '!16s16s2H', + 34: '!16s16s2H', + # UNIX + 49: '!108s108s', + 50: '!108s108s', + } + + def __init__(self): + self.buffer = b'' + + + def feed(self, data): + """ + Consume a chunk of data and attempt to parse it. + + @param data: A bytestring. + @type data: bytes + + @return: A two-tuple containing, in order, a L{_interfaces.IProxyInfo} + and any bytes fed to the parser that followed the end of the + header. Both of these values are None until a complete header is + parsed. + + @raises InvalidProxyHeader: If the bytes fed to the parser create an + invalid PROXY header. + """ + self.buffer += data + if len(self.buffer) < 16: + raise InvalidProxyHeader() + + size = struct.unpack('!H', self.buffer[14:16])[0] + 16 + if len(self.buffer) < size: + return (None, None) + + header, remaining = self.buffer[:size], self.buffer[size:] + self.buffer = b'' + info = self.parse(header) + return (info, remaining) + + + @staticmethod + def _bytesToIPv4(bytestring): + """ + Convert packed 32-bit IPv4 address bytes into a dotted-quad ASCII bytes + representation of that address. + + @param bytestring: 4 octets representing an IPv4 address. + @type bytestring: L{bytes} + + @return: a dotted-quad notation IPv4 address. + @rtype: L{bytes} + """ + return b'.'.join( + ('%i' % (ord(b),)).encode('ascii') + for b in compat.iterbytes(bytestring) + ) + + + @staticmethod + def _bytesToIPv6(bytestring): + """ + Convert packed 128-bit IPv6 address bytes into a colon-separated ASCII + bytes representation of that address. + + @param bytestring: 16 octets representing an IPv6 address. + @type bytestring: L{bytes} + + @return: a dotted-quad notation IPv6 address. + @rtype: L{bytes} + """ + hexString = binascii.b2a_hex(bytestring) + return b':'.join( + ('%x' % (int(hexString[b:b+4], 16),)).encode('ascii') + for b in range(0, 32, 4) + ) + + + @classmethod + def parse(cls, line): + """ + Parse a bytestring as a full PROXY protocol header. + + @param line: A bytestring that represents a valid HAProxy PROXY + protocol version 2 header. + @type line: bytes + + @return: A L{_interfaces.IProxyInfo} containing the + parsed data. + + @raises InvalidProxyHeader: If the bytestring does not represent a + valid PROXY header. + """ + prefix = line[:12] + addrInfo = None + with convertError(IndexError, InvalidProxyHeader): + # Use single value slices to ensure bytestring values are returned + # instead of int in PY3. + versionCommand = ord(line[12:13]) + familyProto = ord(line[13:14]) + + if prefix != cls.PREFIX: + raise InvalidProxyHeader() + + version, command = versionCommand & _HIGH, versionCommand & _LOW + if version not in cls.VERSIONS or command not in cls.COMMANDS: + raise InvalidProxyHeader() + + if cls.COMMANDS[command] == _LOCALCOMMAND: + return _info.ProxyInfo(line, None, None) + + family, netproto = familyProto & _HIGH, familyProto & _LOW + with convertError(ValueError, InvalidNetworkProtocol): + family = NetFamily.lookupByValue(family) + netproto = NetProtocol.lookupByValue(netproto) + if ( + family is NetFamily.UNSPEC or + netproto is NetProtocol.UNSPEC + ): + return _info.ProxyInfo(line, None, None) + + addressFormat = cls.ADDRESSFORMATS[familyProto] + addrInfo = line[16:16+struct.calcsize(addressFormat)] + if family is NetFamily.UNIX: + with convertError(struct.error, MissingAddressData): + source, dest = struct.unpack(addressFormat, addrInfo) + return _info.ProxyInfo( + line, + address.UNIXAddress(source.rstrip(b'\x00')), + address.UNIXAddress(dest.rstrip(b'\x00')), + ) + + addrType = 'TCP' + if netproto is NetProtocol.DGRAM: + addrType = 'UDP' + addrCls = address.IPv4Address + addrParser = cls._bytesToIPv4 + if family is NetFamily.INET6: + addrCls = address.IPv6Address + addrParser = cls._bytesToIPv6 + + with convertError(struct.error, MissingAddressData): + info = struct.unpack(addressFormat, addrInfo) + source, dest, sPort, dPort = info + + return _info.ProxyInfo( + line, + addrCls(addrType, addrParser(source), sPort), + addrCls(addrType, addrParser(dest), dPort), + ) diff --git a/contrib/python/Twisted/py2/twisted/protocols/haproxy/_wrapper.py b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_wrapper.py new file mode 100644 index 0000000000..a6e98892f3 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/haproxy/_wrapper.py @@ -0,0 +1,106 @@ +# -*- test-case-name: twisted.protocols.haproxy.test.test_wrapper -*- + +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Protocol wrapper that provides HAProxy PROXY protocol support. +""" + +from twisted.protocols import policies +from twisted.internet import interfaces +from twisted.internet.endpoints import _WrapperServerEndpoint + +from ._exceptions import InvalidProxyHeader +from ._v1parser import V1Parser +from ._v2parser import V2Parser + + + +class HAProxyProtocolWrapper(policies.ProtocolWrapper, object): + """ + A Protocol wrapper that provides HAProxy support. + + This protocol reads the PROXY stream header, v1 or v2, parses the provided + connection data, and modifies the behavior of getPeer and getHost to return + the data provided by the PROXY header. + """ + + def __init__(self, factory, wrappedProtocol): + policies.ProtocolWrapper.__init__(self, factory, wrappedProtocol) + self._proxyInfo = None + self._parser = None + + + def dataReceived(self, data): + if self._proxyInfo is not None: + return self.wrappedProtocol.dataReceived(data) + + if self._parser is None: + if ( + len(data) >= 16 and + data[:12] == V2Parser.PREFIX and + ord(data[12:13]) & 0b11110000 == 0x20 + ): + self._parser = V2Parser() + elif len(data) >= 8 and data[:5] == V1Parser.PROXYSTR: + self._parser = V1Parser() + else: + self.loseConnection() + return None + + try: + self._proxyInfo, remaining = self._parser.feed(data) + if remaining: + self.wrappedProtocol.dataReceived(remaining) + except InvalidProxyHeader: + self.loseConnection() + + + def getPeer(self): + if self._proxyInfo and self._proxyInfo.source: + return self._proxyInfo.source + return self.transport.getPeer() + + + def getHost(self): + if self._proxyInfo and self._proxyInfo.destination: + return self._proxyInfo.destination + return self.transport.getHost() + + + +class HAProxyWrappingFactory(policies.WrappingFactory): + """ + A Factory wrapper that adds PROXY protocol support to connections. + """ + protocol = HAProxyProtocolWrapper + + def logPrefix(self): + """ + Annotate the wrapped factory's log prefix with some text indicating + the PROXY protocol is in use. + + @rtype: C{str} + """ + if interfaces.ILoggingContext.providedBy(self.wrappedFactory): + logPrefix = self.wrappedFactory.logPrefix() + else: + logPrefix = self.wrappedFactory.__class__.__name__ + return "%s (PROXY)" % (logPrefix,) + + + +def proxyEndpoint(wrappedEndpoint): + """ + Wrap an endpoint with PROXY protocol support, so that the transport's + C{getHost} and C{getPeer} methods reflect the attributes of the proxied + connection rather than the underlying connection. + + @param wrappedEndpoint: The underlying listening endpoint. + @type wrappedEndpoint: L{IStreamServerEndpoint} + + @return: a new listening endpoint that speaks the PROXY protocol. + @rtype: L{IStreamServerEndpoint} + """ + return _WrapperServerEndpoint(wrappedEndpoint, HAProxyWrappingFactory) diff --git a/contrib/python/Twisted/py2/twisted/protocols/htb.py b/contrib/python/Twisted/py2/twisted/protocols/htb.py new file mode 100644 index 0000000000..22a9299bc6 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/htb.py @@ -0,0 +1,295 @@ +# -*- test-case-name: twisted.test.test_htb -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + + +""" +Hierarchical Token Bucket traffic shaping. + +Patterned after U{Martin Devera's Hierarchical Token Bucket traffic +shaper for the Linux kernel<http://luxik.cdi.cz/~devik/qos/htb/>}. + +@seealso: U{HTB Linux queuing discipline manual - user guide + <http://luxik.cdi.cz/~devik/qos/htb/manual/userg.htm>} +@seealso: U{Token Bucket Filter in Linux Advanced Routing & Traffic Control + HOWTO<http://lartc.org/howto/lartc.qdisc.classless.html#AEN682>} +""" + + +# TODO: Investigate whether we should be using os.times()[-1] instead of +# time.time. time.time, it has been pointed out, can go backwards. Is +# the same true of os.times? +from time import time +from zope.interface import implementer, Interface + +from twisted.protocols import pcp + + +class Bucket: + """ + Implementation of a Token bucket. + + A bucket can hold a certain number of tokens and it drains over time. + + @cvar maxburst: The maximum number of tokens that the bucket can + hold at any given time. If this is L{None}, the bucket has + an infinite size. + @type maxburst: C{int} + @cvar rate: The rate at which the bucket drains, in number + of tokens per second. If the rate is L{None}, the bucket + drains instantaneously. + @type rate: C{int} + """ + + maxburst = None + rate = None + + _refcount = 0 + + def __init__(self, parentBucket=None): + """ + Create a L{Bucket} that may have a parent L{Bucket}. + + @param parentBucket: If a parent Bucket is specified, + all L{add} and L{drip} operations on this L{Bucket} + will be applied on the parent L{Bucket} as well. + @type parentBucket: L{Bucket} + """ + self.content = 0 + self.parentBucket = parentBucket + self.lastDrip = time() + + + def add(self, amount): + """ + Adds tokens to the L{Bucket} and its C{parentBucket}. + + This will add as many of the C{amount} tokens as will fit into both + this L{Bucket} and its C{parentBucket}. + + @param amount: The number of tokens to try to add. + @type amount: C{int} + + @returns: The number of tokens that actually fit. + @returntype: C{int} + """ + self.drip() + if self.maxburst is None: + allowable = amount + else: + allowable = min(amount, self.maxburst - self.content) + + if self.parentBucket is not None: + allowable = self.parentBucket.add(allowable) + self.content += allowable + return allowable + + + def drip(self): + """ + Let some of the bucket drain. + + The L{Bucket} drains at the rate specified by the class + variable C{rate}. + + @returns: C{True} if the bucket is empty after this drip. + @returntype: C{bool} + """ + if self.parentBucket is not None: + self.parentBucket.drip() + + if self.rate is None: + self.content = 0 + else: + now = time() + deltaTime = now - self.lastDrip + deltaTokens = deltaTime * self.rate + self.content = max(0, self.content - deltaTokens) + self.lastDrip = now + return self.content == 0 + + +class IBucketFilter(Interface): + def getBucketFor(*somethings, **some_kw): + """ + Return a L{Bucket} corresponding to the provided parameters. + + @returntype: L{Bucket} + """ + +@implementer(IBucketFilter) +class HierarchicalBucketFilter: + """ + Filter things into buckets that can be nested. + + @cvar bucketFactory: Class of buckets to make. + @type bucketFactory: L{Bucket} + @cvar sweepInterval: Seconds between sweeping out the bucket cache. + @type sweepInterval: C{int} + """ + bucketFactory = Bucket + sweepInterval = None + + def __init__(self, parentFilter=None): + self.buckets = {} + self.parentFilter = parentFilter + self.lastSweep = time() + + def getBucketFor(self, *a, **kw): + """ + Find or create a L{Bucket} corresponding to the provided parameters. + + Any parameters are passed on to L{getBucketKey}, from them it + decides which bucket you get. + + @returntype: L{Bucket} + """ + if ((self.sweepInterval is not None) + and ((time() - self.lastSweep) > self.sweepInterval)): + self.sweep() + + if self.parentFilter: + parentBucket = self.parentFilter.getBucketFor(self, *a, **kw) + else: + parentBucket = None + + key = self.getBucketKey(*a, **kw) + bucket = self.buckets.get(key) + if bucket is None: + bucket = self.bucketFactory(parentBucket) + self.buckets[key] = bucket + return bucket + + def getBucketKey(self, *a, **kw): + """ + Construct a key based on the input parameters to choose a L{Bucket}. + + The default implementation returns the same key for all + arguments. Override this method to provide L{Bucket} selection. + + @returns: Something to be used as a key in the bucket cache. + """ + return None + + def sweep(self): + """ + Remove empty buckets. + """ + for key, bucket in self.buckets.items(): + bucket_is_empty = bucket.drip() + if (bucket._refcount == 0) and bucket_is_empty: + del self.buckets[key] + + self.lastSweep = time() + + +class FilterByHost(HierarchicalBucketFilter): + """ + A Hierarchical Bucket filter with a L{Bucket} for each host. + """ + sweepInterval = 60 * 20 + + def getBucketKey(self, transport): + return transport.getPeer()[1] + + +class FilterByServer(HierarchicalBucketFilter): + """ + A Hierarchical Bucket filter with a L{Bucket} for each service. + """ + sweepInterval = None + + def getBucketKey(self, transport): + return transport.getHost()[2] + + +class ShapedConsumer(pcp.ProducerConsumerProxy): + """ + Wraps a C{Consumer} and shapes the rate at which it receives data. + """ + # Providing a Pull interface means I don't have to try to schedule + # traffic with callLaters. + iAmStreaming = False + + def __init__(self, consumer, bucket): + pcp.ProducerConsumerProxy.__init__(self, consumer) + self.bucket = bucket + self.bucket._refcount += 1 + + def _writeSomeData(self, data): + # In practice, this actually results in obscene amounts of + # overhead, as a result of generating lots and lots of packets + # with twelve-byte payloads. We may need to do a version of + # this with scheduled writes after all. + amount = self.bucket.add(len(data)) + return pcp.ProducerConsumerProxy._writeSomeData(self, data[:amount]) + + def stopProducing(self): + pcp.ProducerConsumerProxy.stopProducing(self) + self.bucket._refcount -= 1 + + +class ShapedTransport(ShapedConsumer): + """ + Wraps a C{Transport} and shapes the rate at which it receives data. + + This is a L{ShapedConsumer} with a little bit of magic to provide for + the case where the consumer it wraps is also a C{Transport} and people + will be attempting to access attributes this does not proxy as a + C{Consumer} (e.g. C{loseConnection}). + """ + # Ugh. We only wanted to filter IConsumer, not ITransport. + + iAmStreaming = False + def __getattr__(self, name): + # Because people will be doing things like .getPeer and + # .loseConnection on me. + return getattr(self.consumer, name) + + +class ShapedProtocolFactory: + """ + Dispense C{Protocols} with traffic shaping on their transports. + + Usage:: + + myserver = SomeFactory() + myserver.protocol = ShapedProtocolFactory(myserver.protocol, + bucketFilter) + + Where C{SomeServerFactory} is a L{twisted.internet.protocol.Factory}, and + C{bucketFilter} is an instance of L{HierarchicalBucketFilter}. + """ + def __init__(self, protoClass, bucketFilter): + """ + Tell me what to wrap and where to get buckets. + + @param protoClass: The class of C{Protocol} this will generate + wrapped instances of. + @type protoClass: L{Protocol<twisted.internet.interfaces.IProtocol>} + class + @param bucketFilter: The filter which will determine how + traffic is shaped. + @type bucketFilter: L{HierarchicalBucketFilter}. + """ + # More precisely, protoClass can be any callable that will return + # instances of something that implements IProtocol. + self.protocol = protoClass + self.bucketFilter = bucketFilter + + def __call__(self, *a, **kw): + """ + Make a C{Protocol} instance with a shaped transport. + + Any parameters will be passed on to the protocol's initializer. + + @returns: A C{Protocol} instance with a L{ShapedTransport}. + """ + proto = self.protocol(*a, **kw) + origMakeConnection = proto.makeConnection + def makeConnection(transport): + bucket = self.bucketFilter.getBucketFor(transport) + shapedTransport = ShapedTransport(transport, bucket) + return origMakeConnection(shapedTransport) + proto.makeConnection = makeConnection + return proto diff --git a/contrib/python/Twisted/py2/twisted/protocols/ident.py b/contrib/python/Twisted/py2/twisted/protocols/ident.py new file mode 100644 index 0000000000..69128b4326 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/ident.py @@ -0,0 +1,255 @@ +# -*- test-case-name: twisted.test.test_ident -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Ident protocol implementation. +""" + +import struct + +from twisted.internet import defer +from twisted.protocols import basic +from twisted.python import log, failure + +_MIN_PORT = 1 +_MAX_PORT = 2 ** 16 - 1 + +class IdentError(Exception): + """ + Can't determine connection owner; reason unknown. + """ + + identDescription = 'UNKNOWN-ERROR' + + def __str__(self): + return self.identDescription + + + +class NoUser(IdentError): + """ + The connection specified by the port pair is not currently in use or + currently not owned by an identifiable entity. + """ + identDescription = 'NO-USER' + + + +class InvalidPort(IdentError): + """ + Either the local or foreign port was improperly specified. This should + be returned if either or both of the port ids were out of range (TCP + port numbers are from 1-65535), negative integers, reals or in any + fashion not recognized as a non-negative integer. + """ + identDescription = 'INVALID-PORT' + + + +class HiddenUser(IdentError): + """ + The server was able to identify the user of this port, but the + information was not returned at the request of the user. + """ + identDescription = 'HIDDEN-USER' + + + +class IdentServer(basic.LineOnlyReceiver): + """ + The Identification Protocol (a.k.a., "ident", a.k.a., "the Ident + Protocol") provides a means to determine the identity of a user of a + particular TCP connection. Given a TCP port number pair, it returns a + character string which identifies the owner of that connection on the + server's system. + + Server authors should subclass this class and override the lookup method. + The default implementation returns an UNKNOWN-ERROR response for every + query. + """ + + def lineReceived(self, line): + parts = line.split(',') + if len(parts) != 2: + self.invalidQuery() + else: + try: + portOnServer, portOnClient = map(int, parts) + except ValueError: + self.invalidQuery() + else: + if _MIN_PORT <= portOnServer <= _MAX_PORT and _MIN_PORT <= portOnClient <= _MAX_PORT: + self.validQuery(portOnServer, portOnClient) + else: + self._ebLookup(failure.Failure(InvalidPort()), portOnServer, portOnClient) + + + def invalidQuery(self): + self.transport.loseConnection() + + + def validQuery(self, portOnServer, portOnClient): + """ + Called when a valid query is received to look up and deliver the + response. + + @param portOnServer: The server port from the query. + @param portOnClient: The client port from the query. + """ + serverAddr = self.transport.getHost().host, portOnServer + clientAddr = self.transport.getPeer().host, portOnClient + defer.maybeDeferred(self.lookup, serverAddr, clientAddr + ).addCallback(self._cbLookup, portOnServer, portOnClient + ).addErrback(self._ebLookup, portOnServer, portOnClient + ) + + + def _cbLookup(self, result, sport, cport): + (sysName, userId) = result + self.sendLine('%d, %d : USERID : %s : %s' % (sport, cport, sysName, userId)) + + + def _ebLookup(self, failure, sport, cport): + if failure.check(IdentError): + self.sendLine('%d, %d : ERROR : %s' % (sport, cport, failure.value)) + else: + log.err(failure) + self.sendLine('%d, %d : ERROR : %s' % (sport, cport, IdentError(failure.value))) + + + def lookup(self, serverAddress, clientAddress): + """ + Lookup user information about the specified address pair. + + Return value should be a two-tuple of system name and username. + Acceptable values for the system name may be found online at:: + + U{http://www.iana.org/assignments/operating-system-names} + + This method may also raise any IdentError subclass (or IdentError + itself) to indicate user information will not be provided for the + given query. + + A Deferred may also be returned. + + @param serverAddress: A two-tuple representing the server endpoint + of the address being queried. The first element is a string holding + a dotted-quad IP address. The second element is an integer + representing the port. + + @param clientAddress: Like I{serverAddress}, but represents the + client endpoint of the address being queried. + """ + raise IdentError() + + + +class ProcServerMixin: + """Implements lookup() to grab entries for responses from /proc/net/tcp + """ + + SYSTEM_NAME = 'LINUX' + + try: + from pwd import getpwuid + def getUsername(self, uid, getpwuid=getpwuid): + return getpwuid(uid)[0] + del getpwuid + except ImportError: + def getUsername(self, uid): + raise IdentError() + + + def entries(self): + with open('/proc/net/tcp') as f: + f.readline() + for L in f: + yield L.strip() + + + def dottedQuadFromHexString(self, hexstr): + return '.'.join(map(str, struct.unpack('4B', struct.pack('=L', int(hexstr, 16))))) + + + def unpackAddress(self, packed): + addr, port = packed.split(':') + addr = self.dottedQuadFromHexString(addr) + port = int(port, 16) + return addr, port + + + def parseLine(self, line): + parts = line.strip().split() + localAddr, localPort = self.unpackAddress(parts[1]) + remoteAddr, remotePort = self.unpackAddress(parts[2]) + uid = int(parts[7]) + return (localAddr, localPort), (remoteAddr, remotePort), uid + + + def lookup(self, serverAddress, clientAddress): + for ent in self.entries(): + localAddr, remoteAddr, uid = self.parseLine(ent) + if remoteAddr == clientAddress and localAddr[1] == serverAddress[1]: + return (self.SYSTEM_NAME, self.getUsername(uid)) + + raise NoUser() + + + +class IdentClient(basic.LineOnlyReceiver): + + errorTypes = (IdentError, NoUser, InvalidPort, HiddenUser) + + def __init__(self): + self.queries = [] + + + def lookup(self, portOnServer, portOnClient): + """ + Lookup user information about the specified address pair. + """ + self.queries.append((defer.Deferred(), portOnServer, portOnClient)) + if len(self.queries) > 1: + return self.queries[-1][0] + + self.sendLine('%d, %d' % (portOnServer, portOnClient)) + return self.queries[-1][0] + + + def lineReceived(self, line): + if not self.queries: + log.msg("Unexpected server response: %r" % (line,)) + else: + d, _, _ = self.queries.pop(0) + self.parseResponse(d, line) + if self.queries: + self.sendLine('%d, %d' % (self.queries[0][1], self.queries[0][2])) + + + def connectionLost(self, reason): + for q in self.queries: + q[0].errback(IdentError(reason)) + self.queries = [] + + + def parseResponse(self, deferred, line): + parts = line.split(':', 2) + if len(parts) != 3: + deferred.errback(IdentError(line)) + else: + ports, type, addInfo = map(str.strip, parts) + if type == 'ERROR': + for et in self.errorTypes: + if et.identDescription == addInfo: + deferred.errback(et(line)) + return + deferred.errback(IdentError(line)) + else: + deferred.callback((type, addInfo)) + + + +__all__ = ['IdentError', 'NoUser', 'InvalidPort', 'HiddenUser', + 'IdentServer', 'IdentClient', + 'ProcServerMixin'] diff --git a/contrib/python/Twisted/py2/twisted/protocols/loopback.py b/contrib/python/Twisted/py2/twisted/protocols/loopback.py new file mode 100644 index 0000000000..9d7beb0ce4 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/loopback.py @@ -0,0 +1,385 @@ +# -*- test-case-name: twisted.test.test_loopback -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Testing support for protocols -- loopback between client and server. +""" + +from __future__ import division, absolute_import + +# system imports +import tempfile + +from zope.interface import implementer + +# Twisted Imports +from twisted.protocols import policies +from twisted.internet import interfaces, protocol, main, defer +from twisted.internet.task import deferLater +from twisted.python import failure +from twisted.internet.interfaces import IAddress + + +class _LoopbackQueue(object): + """ + Trivial wrapper around a list to give it an interface like a queue, which + the addition of also sending notifications by way of a Deferred whenever + the list has something added to it. + """ + + _notificationDeferred = None + disconnect = False + + def __init__(self): + self._queue = [] + + + def put(self, v): + self._queue.append(v) + if self._notificationDeferred is not None: + d, self._notificationDeferred = self._notificationDeferred, None + d.callback(None) + + + def __nonzero__(self): + return bool(self._queue) + __bool__ = __nonzero__ + + + def get(self): + return self._queue.pop(0) + + + +@implementer(IAddress) +class _LoopbackAddress(object): + pass + + + +@implementer(interfaces.ITransport, interfaces.IConsumer) +class _LoopbackTransport(object): + disconnecting = False + producer = None + + # ITransport + def __init__(self, q): + self.q = q + + def write(self, data): + if not isinstance(data, bytes): + raise TypeError("Can only write bytes to ITransport") + self.q.put(data) + + def writeSequence(self, iovec): + self.q.put(b''.join(iovec)) + + def loseConnection(self): + self.q.disconnect = True + self.q.put(None) + + + def abortConnection(self): + """ + Abort the connection. Same as L{loseConnection}. + """ + self.loseConnection() + + + def getPeer(self): + return _LoopbackAddress() + + def getHost(self): + return _LoopbackAddress() + + # IConsumer + def registerProducer(self, producer, streaming): + assert self.producer is None + self.producer = producer + self.streamingProducer = streaming + self._pollProducer() + + def unregisterProducer(self): + assert self.producer is not None + self.producer = None + + def _pollProducer(self): + if self.producer is not None and not self.streamingProducer: + self.producer.resumeProducing() + + + +def identityPumpPolicy(queue, target): + """ + L{identityPumpPolicy} is a policy which delivers each chunk of data written + to the given queue as-is to the target. + + This isn't a particularly realistic policy. + + @see: L{loopbackAsync} + """ + while queue: + bytes = queue.get() + if bytes is None: + break + target.dataReceived(bytes) + + + +def collapsingPumpPolicy(queue, target): + """ + L{collapsingPumpPolicy} is a policy which collapses all outstanding chunks + into a single string and delivers it to the target. + + @see: L{loopbackAsync} + """ + bytes = [] + while queue: + chunk = queue.get() + if chunk is None: + break + bytes.append(chunk) + if bytes: + target.dataReceived(b''.join(bytes)) + + + +def loopbackAsync(server, client, pumpPolicy=identityPumpPolicy): + """ + Establish a connection between C{server} and C{client} then transfer data + between them until the connection is closed. This is often useful for + testing a protocol. + + @param server: The protocol instance representing the server-side of this + connection. + + @param client: The protocol instance representing the client-side of this + connection. + + @param pumpPolicy: When either C{server} or C{client} writes to its + transport, the string passed in is added to a queue of data for the + other protocol. Eventually, C{pumpPolicy} will be called with one such + queue and the corresponding protocol object. The pump policy callable + is responsible for emptying the queue and passing the strings it + contains to the given protocol's C{dataReceived} method. The signature + of C{pumpPolicy} is C{(queue, protocol)}. C{queue} is an object with a + C{get} method which will return the next string written to the + transport, or L{None} if the transport has been disconnected, and which + evaluates to C{True} if and only if there are more items to be + retrieved via C{get}. + + @return: A L{Deferred} which fires when the connection has been closed and + both sides have received notification of this. + """ + serverToClient = _LoopbackQueue() + clientToServer = _LoopbackQueue() + + server.makeConnection(_LoopbackTransport(serverToClient)) + client.makeConnection(_LoopbackTransport(clientToServer)) + + return _loopbackAsyncBody( + server, serverToClient, client, clientToServer, pumpPolicy) + + + +def _loopbackAsyncBody(server, serverToClient, client, clientToServer, + pumpPolicy): + """ + Transfer bytes from the output queue of each protocol to the input of the other. + + @param server: The protocol instance representing the server-side of this + connection. + + @param serverToClient: The L{_LoopbackQueue} holding the server's output. + + @param client: The protocol instance representing the client-side of this + connection. + + @param clientToServer: The L{_LoopbackQueue} holding the client's output. + + @param pumpPolicy: See L{loopbackAsync}. + + @return: A L{Deferred} which fires when the connection has been closed and + both sides have received notification of this. + """ + def pump(source, q, target): + sent = False + if q: + pumpPolicy(q, target) + sent = True + if sent and not q: + # A write buffer has now been emptied. Give any producer on that + # side an opportunity to produce more data. + source.transport._pollProducer() + + return sent + + while 1: + disconnect = clientSent = serverSent = False + + # Deliver the data which has been written. + serverSent = pump(server, serverToClient, client) + clientSent = pump(client, clientToServer, server) + + if not clientSent and not serverSent: + # Neither side wrote any data. Wait for some new data to be added + # before trying to do anything further. + d = defer.Deferred() + clientToServer._notificationDeferred = d + serverToClient._notificationDeferred = d + d.addCallback( + _loopbackAsyncContinue, + server, serverToClient, client, clientToServer, pumpPolicy) + return d + if serverToClient.disconnect: + # The server wants to drop the connection. Flush any remaining + # data it has. + disconnect = True + pump(server, serverToClient, client) + elif clientToServer.disconnect: + # The client wants to drop the connection. Flush any remaining + # data it has. + disconnect = True + pump(client, clientToServer, server) + if disconnect: + # Someone wanted to disconnect, so okay, the connection is gone. + server.connectionLost(failure.Failure(main.CONNECTION_DONE)) + client.connectionLost(failure.Failure(main.CONNECTION_DONE)) + return defer.succeed(None) + + + +def _loopbackAsyncContinue(ignored, server, serverToClient, client, + clientToServer, pumpPolicy): + # Clear the Deferred from each message queue, since it has already fired + # and cannot be used again. + clientToServer._notificationDeferred = None + serverToClient._notificationDeferred = None + + # Schedule some more byte-pushing to happen. This isn't done + # synchronously because no actual transport can re-enter dataReceived as + # a result of calling write, and doing this synchronously could result + # in that. + from twisted.internet import reactor + return deferLater( + reactor, 0, + _loopbackAsyncBody, + server, serverToClient, client, clientToServer, pumpPolicy) + + + +@implementer(interfaces.ITransport, interfaces.IConsumer) +class LoopbackRelay: + buffer = b'' + shouldLose = 0 + disconnecting = 0 + producer = None + + def __init__(self, target, logFile=None): + self.target = target + self.logFile = logFile + + def write(self, data): + self.buffer = self.buffer + data + if self.logFile: + self.logFile.write("loopback writing %s\n" % repr(data)) + + def writeSequence(self, iovec): + self.write(b"".join(iovec)) + + def clearBuffer(self): + if self.shouldLose == -1: + return + + if self.producer: + self.producer.resumeProducing() + if self.buffer: + if self.logFile: + self.logFile.write("loopback receiving %s\n" % repr(self.buffer)) + buffer = self.buffer + self.buffer = b'' + self.target.dataReceived(buffer) + if self.shouldLose == 1: + self.shouldLose = -1 + self.target.connectionLost(failure.Failure(main.CONNECTION_DONE)) + + def loseConnection(self): + if self.shouldLose != -1: + self.shouldLose = 1 + + def getHost(self): + return 'loopback' + + def getPeer(self): + return 'loopback' + + def registerProducer(self, producer, streaming): + self.producer = producer + + def unregisterProducer(self): + self.producer = None + + def logPrefix(self): + return 'Loopback(%r)' % (self.target.__class__.__name__,) + + + +class LoopbackClientFactory(protocol.ClientFactory): + + def __init__(self, protocol): + self.disconnected = 0 + self.deferred = defer.Deferred() + self.protocol = protocol + + def buildProtocol(self, addr): + return self.protocol + + def clientConnectionLost(self, connector, reason): + self.disconnected = 1 + self.deferred.callback(None) + + +class _FireOnClose(policies.ProtocolWrapper): + def __init__(self, protocol, factory): + policies.ProtocolWrapper.__init__(self, protocol, factory) + self.deferred = defer.Deferred() + + def connectionLost(self, reason): + policies.ProtocolWrapper.connectionLost(self, reason) + self.deferred.callback(None) + + +def loopbackTCP(server, client, port=0, noisy=True): + """Run session between server and client protocol instances over TCP.""" + from twisted.internet import reactor + f = policies.WrappingFactory(protocol.Factory()) + serverWrapper = _FireOnClose(f, server) + f.noisy = noisy + f.buildProtocol = lambda addr: serverWrapper + serverPort = reactor.listenTCP(port, f, interface='127.0.0.1') + clientF = LoopbackClientFactory(client) + clientF.noisy = noisy + reactor.connectTCP('127.0.0.1', serverPort.getHost().port, clientF) + d = clientF.deferred + d.addCallback(lambda x: serverWrapper.deferred) + d.addCallback(lambda x: serverPort.stopListening()) + return d + + +def loopbackUNIX(server, client, noisy=True): + """Run session between server and client protocol instances over UNIX socket.""" + path = tempfile.mktemp() + from twisted.internet import reactor + f = policies.WrappingFactory(protocol.Factory()) + serverWrapper = _FireOnClose(f, server) + f.noisy = noisy + f.buildProtocol = lambda addr: serverWrapper + serverPort = reactor.listenUNIX(path, f) + clientF = LoopbackClientFactory(client) + clientF.noisy = noisy + reactor.connectUNIX(path, clientF) + d = clientF.deferred + d.addCallback(lambda x: serverWrapper.deferred) + d.addCallback(lambda x: serverPort.stopListening()) + return d diff --git a/contrib/python/Twisted/py2/twisted/protocols/memcache.py b/contrib/python/Twisted/py2/twisted/protocols/memcache.py new file mode 100644 index 0000000000..6fd666bfd0 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/memcache.py @@ -0,0 +1,766 @@ +# -*- test-case-name: twisted.test.test_memcache -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Memcache client protocol. Memcached is a caching server, storing data in the +form of pairs key/value, and memcache is the protocol to talk with it. + +To connect to a server, create a factory for L{MemCacheProtocol}:: + + from twisted.internet import reactor, protocol + from twisted.protocols.memcache import MemCacheProtocol, DEFAULT_PORT + d = protocol.ClientCreator(reactor, MemCacheProtocol + ).connectTCP("localhost", DEFAULT_PORT) + def doSomething(proto): + # Here you call the memcache operations + return proto.set("mykey", "a lot of data") + d.addCallback(doSomething) + reactor.run() + +All the operations of the memcache protocol are present, but +L{MemCacheProtocol.set} and L{MemCacheProtocol.get} are the more important. + +See U{http://code.sixapart.com/svn/memcached/trunk/server/doc/protocol.txt} for +more information about the protocol. +""" + +from __future__ import absolute_import, division + +from collections import deque + +from twisted.protocols.basic import LineReceiver +from twisted.protocols.policies import TimeoutMixin +from twisted.internet.defer import Deferred, fail, TimeoutError +from twisted.python import log +from twisted.python.compat import ( + intToBytes, iteritems, nativeString, networkString) + + + +DEFAULT_PORT = 11211 + + + +class NoSuchCommand(Exception): + """ + Exception raised when a non existent command is called. + """ + + + +class ClientError(Exception): + """ + Error caused by an invalid client call. + """ + + + +class ServerError(Exception): + """ + Problem happening on the server. + """ + + + +class Command(object): + """ + Wrap a client action into an object, that holds the values used in the + protocol. + + @ivar _deferred: the L{Deferred} object that will be fired when the result + arrives. + @type _deferred: L{Deferred} + + @ivar command: name of the command sent to the server. + @type command: L{bytes} + """ + + def __init__(self, command, **kwargs): + """ + Create a command. + + @param command: the name of the command. + @type command: L{bytes} + + @param kwargs: this values will be stored as attributes of the object + for future use + """ + self.command = command + self._deferred = Deferred() + for k, v in kwargs.items(): + setattr(self, k, v) + + + def success(self, value): + """ + Shortcut method to fire the underlying deferred. + """ + self._deferred.callback(value) + + + def fail(self, error): + """ + Make the underlying deferred fails. + """ + self._deferred.errback(error) + + + +class MemCacheProtocol(LineReceiver, TimeoutMixin): + """ + MemCache protocol: connect to a memcached server to store/retrieve values. + + @ivar persistentTimeOut: the timeout period used to wait for a response. + @type persistentTimeOut: L{int} + + @ivar _current: current list of requests waiting for an answer from the + server. + @type _current: L{deque} of L{Command} + + @ivar _lenExpected: amount of data expected in raw mode, when reading for + a value. + @type _lenExpected: L{int} + + @ivar _getBuffer: current buffer of data, used to store temporary data + when reading in raw mode. + @type _getBuffer: L{list} + + @ivar _bufferLength: the total amount of bytes in C{_getBuffer}. + @type _bufferLength: L{int} + + @ivar _disconnected: indicate if the connectionLost has been called or not. + @type _disconnected: L{bool} + """ + MAX_KEY_LENGTH = 250 + _disconnected = False + + def __init__(self, timeOut=60): + """ + Create the protocol. + + @param timeOut: the timeout to wait before detecting that the + connection is dead and close it. It's expressed in seconds. + @type timeOut: L{int} + """ + self._current = deque() + self._lenExpected = None + self._getBuffer = None + self._bufferLength = None + self.persistentTimeOut = self.timeOut = timeOut + + + def _cancelCommands(self, reason): + """ + Cancel all the outstanding commands, making them fail with C{reason}. + """ + while self._current: + cmd = self._current.popleft() + cmd.fail(reason) + + + def timeoutConnection(self): + """ + Close the connection in case of timeout. + """ + self._cancelCommands(TimeoutError("Connection timeout")) + self.transport.loseConnection() + + + def connectionLost(self, reason): + """ + Cause any outstanding commands to fail. + """ + self._disconnected = True + self._cancelCommands(reason) + LineReceiver.connectionLost(self, reason) + + + def sendLine(self, line): + """ + Override sendLine to add a timeout to response. + """ + if not self._current: + self.setTimeout(self.persistentTimeOut) + LineReceiver.sendLine(self, line) + + + def rawDataReceived(self, data): + """ + Collect data for a get. + """ + self.resetTimeout() + self._getBuffer.append(data) + self._bufferLength += len(data) + if self._bufferLength >= self._lenExpected + 2: + data = b"".join(self._getBuffer) + buf = data[:self._lenExpected] + rem = data[self._lenExpected + 2:] + val = buf + self._lenExpected = None + self._getBuffer = None + self._bufferLength = None + cmd = self._current[0] + if cmd.multiple: + flags, cas = cmd.values[cmd.currentKey] + cmd.values[cmd.currentKey] = (flags, cas, val) + else: + cmd.value = val + self.setLineMode(rem) + + + def cmd_STORED(self): + """ + Manage a success response to a set operation. + """ + self._current.popleft().success(True) + + + def cmd_NOT_STORED(self): + """ + Manage a specific 'not stored' response to a set operation: this is not + an error, but some condition wasn't met. + """ + self._current.popleft().success(False) + + + def cmd_END(self): + """ + This the end token to a get or a stat operation. + """ + cmd = self._current.popleft() + if cmd.command == b"get": + if cmd.multiple: + values = {key: val[::2] for key, val in iteritems(cmd.values)} + cmd.success(values) + else: + cmd.success((cmd.flags, cmd.value)) + elif cmd.command == b"gets": + if cmd.multiple: + cmd.success(cmd.values) + else: + cmd.success((cmd.flags, cmd.cas, cmd.value)) + elif cmd.command == b"stats": + cmd.success(cmd.values) + else: + raise RuntimeError( + "Unexpected END response to %s command" % + (nativeString(cmd.command),)) + + + def cmd_NOT_FOUND(self): + """ + Manage error response for incr/decr/delete. + """ + self._current.popleft().success(False) + + + def cmd_VALUE(self, line): + """ + Prepare the reading a value after a get. + """ + cmd = self._current[0] + if cmd.command == b"get": + key, flags, length = line.split() + cas = b"" + else: + key, flags, length, cas = line.split() + self._lenExpected = int(length) + self._getBuffer = [] + self._bufferLength = 0 + if cmd.multiple: + if key not in cmd.keys: + raise RuntimeError("Unexpected commands answer.") + cmd.currentKey = key + cmd.values[key] = [int(flags), cas] + else: + if cmd.key != key: + raise RuntimeError("Unexpected commands answer.") + cmd.flags = int(flags) + cmd.cas = cas + self.setRawMode() + + + def cmd_STAT(self, line): + """ + Reception of one stat line. + """ + cmd = self._current[0] + key, val = line.split(b" ", 1) + cmd.values[key] = val + + + def cmd_VERSION(self, versionData): + """ + Read version token. + """ + self._current.popleft().success(versionData) + + + def cmd_ERROR(self): + """ + A non-existent command has been sent. + """ + log.err("Non-existent command sent.") + cmd = self._current.popleft() + cmd.fail(NoSuchCommand()) + + + def cmd_CLIENT_ERROR(self, errText): + """ + An invalid input as been sent. + """ + errText = repr(errText) + log.err("Invalid input: " + errText) + cmd = self._current.popleft() + cmd.fail(ClientError(errText)) + + + def cmd_SERVER_ERROR(self, errText): + """ + An error has happened server-side. + """ + errText = repr(errText) + log.err("Server error: " + errText) + cmd = self._current.popleft() + cmd.fail(ServerError(errText)) + + + def cmd_DELETED(self): + """ + A delete command has completed successfully. + """ + self._current.popleft().success(True) + + + def cmd_OK(self): + """ + The last command has been completed. + """ + self._current.popleft().success(True) + + + def cmd_EXISTS(self): + """ + A C{checkAndSet} update has failed. + """ + self._current.popleft().success(False) + + + def lineReceived(self, line): + """ + Receive line commands from the server. + """ + self.resetTimeout() + token = line.split(b" ", 1)[0] + # First manage standard commands without space + cmd = getattr(self, "cmd_" + nativeString(token), None) + if cmd is not None: + args = line.split(b" ", 1)[1:] + if args: + cmd(args[0]) + else: + cmd() + else: + # Then manage commands with space in it + line = line.replace(b" ", b"_") + cmd = getattr(self, "cmd_" + nativeString(line), None) + if cmd is not None: + cmd() + else: + # Increment/Decrement response + cmd = self._current.popleft() + val = int(line) + cmd.success(val) + if not self._current: + # No pending request, remove timeout + self.setTimeout(None) + + + def increment(self, key, val=1): + """ + Increment the value of C{key} by given value (default to 1). + C{key} must be consistent with an int. Return the new value. + + @param key: the key to modify. + @type key: L{bytes} + + @param val: the value to increment. + @type val: L{int} + + @return: a deferred with will be called back with the new value + associated with the key (after the increment). + @rtype: L{Deferred} + """ + return self._incrdecr(b"incr", key, val) + + + def decrement(self, key, val=1): + """ + Decrement the value of C{key} by given value (default to 1). + C{key} must be consistent with an int. Return the new value, coerced to + 0 if negative. + + @param key: the key to modify. + @type key: L{bytes} + + @param val: the value to decrement. + @type val: L{int} + + @return: a deferred with will be called back with the new value + associated with the key (after the decrement). + @rtype: L{Deferred} + """ + return self._incrdecr(b"decr", key, val) + + + def _incrdecr(self, cmd, key, val): + """ + Internal wrapper for incr/decr. + """ + if self._disconnected: + return fail(RuntimeError("not connected")) + if not isinstance(key, bytes): + return fail(ClientError( + "Invalid type for key: %s, expecting bytes" % (type(key),))) + if len(key) > self.MAX_KEY_LENGTH: + return fail(ClientError("Key too long")) + fullcmd = b" ".join([cmd, key, intToBytes(int(val))]) + self.sendLine(fullcmd) + cmdObj = Command(cmd, key=key) + self._current.append(cmdObj) + return cmdObj._deferred + + + def replace(self, key, val, flags=0, expireTime=0): + """ + Replace the given C{key}. It must already exist in the server. + + @param key: the key to replace. + @type key: L{bytes} + + @param val: the new value associated with the key. + @type val: L{bytes} + + @param flags: the flags to store with the key. + @type flags: L{int} + + @param expireTime: if different from 0, the relative time in seconds + when the key will be deleted from the store. + @type expireTime: L{int} + + @return: a deferred that will fire with C{True} if the operation has + succeeded, and C{False} with the key didn't previously exist. + @rtype: L{Deferred} + """ + return self._set(b"replace", key, val, flags, expireTime, b"") + + + def add(self, key, val, flags=0, expireTime=0): + """ + Add the given C{key}. It must not exist in the server. + + @param key: the key to add. + @type key: L{bytes} + + @param val: the value associated with the key. + @type val: L{bytes} + + @param flags: the flags to store with the key. + @type flags: L{int} + + @param expireTime: if different from 0, the relative time in seconds + when the key will be deleted from the store. + @type expireTime: L{int} + + @return: a deferred that will fire with C{True} if the operation has + succeeded, and C{False} with the key already exists. + @rtype: L{Deferred} + """ + return self._set(b"add", key, val, flags, expireTime, b"") + + + def set(self, key, val, flags=0, expireTime=0): + """ + Set the given C{key}. + + @param key: the key to set. + @type key: L{bytes} + + @param val: the value associated with the key. + @type val: L{bytes} + + @param flags: the flags to store with the key. + @type flags: L{int} + + @param expireTime: if different from 0, the relative time in seconds + when the key will be deleted from the store. + @type expireTime: L{int} + + @return: a deferred that will fire with C{True} if the operation has + succeeded. + @rtype: L{Deferred} + """ + return self._set(b"set", key, val, flags, expireTime, b"") + + + def checkAndSet(self, key, val, cas, flags=0, expireTime=0): + """ + Change the content of C{key} only if the C{cas} value matches the + current one associated with the key. Use this to store a value which + hasn't been modified since last time you fetched it. + + @param key: The key to set. + @type key: L{bytes} + + @param val: The value associated with the key. + @type val: L{bytes} + + @param cas: Unique 64-bit value returned by previous call of C{get}. + @type cas: L{bytes} + + @param flags: The flags to store with the key. + @type flags: L{int} + + @param expireTime: If different from 0, the relative time in seconds + when the key will be deleted from the store. + @type expireTime: L{int} + + @return: A deferred that will fire with C{True} if the operation has + succeeded, C{False} otherwise. + @rtype: L{Deferred} + """ + return self._set(b"cas", key, val, flags, expireTime, cas) + + + def _set(self, cmd, key, val, flags, expireTime, cas): + """ + Internal wrapper for setting values. + """ + if self._disconnected: + return fail(RuntimeError("not connected")) + if not isinstance(key, bytes): + return fail(ClientError( + "Invalid type for key: %s, expecting bytes" % (type(key),))) + if len(key) > self.MAX_KEY_LENGTH: + return fail(ClientError("Key too long")) + if not isinstance(val, bytes): + return fail(ClientError( + "Invalid type for value: %s, expecting bytes" % + (type(val),))) + if cas: + cas = b" " + cas + length = len(val) + fullcmd = b" ".join([ + cmd, key, + networkString("%d %d %d" % (flags, expireTime, length))]) + cas + self.sendLine(fullcmd) + self.sendLine(val) + cmdObj = Command(cmd, key=key, flags=flags, length=length) + self._current.append(cmdObj) + return cmdObj._deferred + + + def append(self, key, val): + """ + Append given data to the value of an existing key. + + @param key: The key to modify. + @type key: L{bytes} + + @param val: The value to append to the current value associated with + the key. + @type val: L{bytes} + + @return: A deferred that will fire with C{True} if the operation has + succeeded, C{False} otherwise. + @rtype: L{Deferred} + """ + # Even if flags and expTime values are ignored, we have to pass them + return self._set(b"append", key, val, 0, 0, b"") + + + def prepend(self, key, val): + """ + Prepend given data to the value of an existing key. + + @param key: The key to modify. + @type key: L{bytes} + + @param val: The value to prepend to the current value associated with + the key. + @type val: L{bytes} + + @return: A deferred that will fire with C{True} if the operation has + succeeded, C{False} otherwise. + @rtype: L{Deferred} + """ + # Even if flags and expTime values are ignored, we have to pass them + return self._set(b"prepend", key, val, 0, 0, b"") + + + def get(self, key, withIdentifier=False): + """ + Get the given C{key}. It doesn't support multiple keys. If + C{withIdentifier} is set to C{True}, the command issued is a C{gets}, + that will return the current identifier associated with the value. This + identifier has to be used when issuing C{checkAndSet} update later, + using the corresponding method. + + @param key: The key to retrieve. + @type key: L{bytes} + + @param withIdentifier: If set to C{True}, retrieve the current + identifier along with the value and the flags. + @type withIdentifier: L{bool} + + @return: A deferred that will fire with the tuple (flags, value) if + C{withIdentifier} is C{False}, or (flags, cas identifier, value) + if C{True}. If the server indicates there is no value + associated with C{key}, the returned value will be L{None} and + the returned flags will be C{0}. + @rtype: L{Deferred} + """ + return self._get([key], withIdentifier, False) + + + def getMultiple(self, keys, withIdentifier=False): + """ + Get the given list of C{keys}. If C{withIdentifier} is set to C{True}, + the command issued is a C{gets}, that will return the identifiers + associated with each values. This identifier has to be used when + issuing C{checkAndSet} update later, using the corresponding method. + + @param keys: The keys to retrieve. + @type keys: L{list} of L{bytes} + + @param withIdentifier: If set to C{True}, retrieve the identifiers + along with the values and the flags. + @type withIdentifier: L{bool} + + @return: A deferred that will fire with a dictionary with the elements + of C{keys} as keys and the tuples (flags, value) as values if + C{withIdentifier} is C{False}, or (flags, cas identifier, value) if + C{True}. If the server indicates there is no value associated with + C{key}, the returned values will be L{None} and the returned flags + will be C{0}. + @rtype: L{Deferred} + + @since: 9.0 + """ + return self._get(keys, withIdentifier, True) + + + def _get(self, keys, withIdentifier, multiple): + """ + Helper method for C{get} and C{getMultiple}. + """ + keys = list(keys) + if self._disconnected: + return fail(RuntimeError("not connected")) + for key in keys: + if not isinstance(key, bytes): + return fail(ClientError( + "Invalid type for key: %s, expecting bytes" % + (type(key),))) + if len(key) > self.MAX_KEY_LENGTH: + return fail(ClientError("Key too long")) + if withIdentifier: + cmd = b"gets" + else: + cmd = b"get" + fullcmd = b" ".join([cmd] + keys) + self.sendLine(fullcmd) + if multiple: + values = dict([(key, (0, b"", None)) for key in keys]) + cmdObj = Command(cmd, keys=keys, values=values, multiple=True) + else: + cmdObj = Command(cmd, key=keys[0], value=None, flags=0, cas=b"", + multiple=False) + self._current.append(cmdObj) + return cmdObj._deferred + + + def stats(self, arg=None): + """ + Get some stats from the server. It will be available as a dict. + + @param arg: An optional additional string which will be sent along + with the I{stats} command. The interpretation of this value by + the server is left undefined by the memcache protocol + specification. + @type arg: L{None} or L{bytes} + + @return: a deferred that will fire with a L{dict} of the available + statistics. + @rtype: L{Deferred} + """ + if arg: + cmd = b"stats " + arg + else: + cmd = b"stats" + if self._disconnected: + return fail(RuntimeError("not connected")) + self.sendLine(cmd) + cmdObj = Command(b"stats", values={}) + self._current.append(cmdObj) + return cmdObj._deferred + + + def version(self): + """ + Get the version of the server. + + @return: a deferred that will fire with the string value of the + version. + @rtype: L{Deferred} + """ + if self._disconnected: + return fail(RuntimeError("not connected")) + self.sendLine(b"version") + cmdObj = Command(b"version") + self._current.append(cmdObj) + return cmdObj._deferred + + + def delete(self, key): + """ + Delete an existing C{key}. + + @param key: the key to delete. + @type key: L{bytes} + + @return: a deferred that will be called back with C{True} if the key + was successfully deleted, or C{False} if not. + @rtype: L{Deferred} + """ + if self._disconnected: + return fail(RuntimeError("not connected")) + if not isinstance(key, bytes): + return fail(ClientError( + "Invalid type for key: %s, expecting bytes" % (type(key),))) + self.sendLine(b"delete " + key) + cmdObj = Command(b"delete", key=key) + self._current.append(cmdObj) + return cmdObj._deferred + + + def flushAll(self): + """ + Flush all cached values. + + @return: a deferred that will be called back with C{True} when the + operation has succeeded. + @rtype: L{Deferred} + """ + if self._disconnected: + return fail(RuntimeError("not connected")) + self.sendLine(b"flush_all") + cmdObj = Command(b"flush_all") + self._current.append(cmdObj) + return cmdObj._deferred + + + +__all__ = ["MemCacheProtocol", "DEFAULT_PORT", "NoSuchCommand", "ClientError", + "ServerError"] diff --git a/contrib/python/Twisted/py2/twisted/protocols/pcp.py b/contrib/python/Twisted/py2/twisted/protocols/pcp.py new file mode 100644 index 0000000000..43f20ec410 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/pcp.py @@ -0,0 +1,203 @@ +# -*- test-case-name: twisted.test.test_pcp -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Producer-Consumer Proxy. +""" + +from zope.interface import implementer + +from twisted.internet import interfaces + + +@implementer(interfaces.IProducer, interfaces.IConsumer) +class BasicProducerConsumerProxy: + """ + I can act as a man in the middle between any Producer and Consumer. + + @ivar producer: the Producer I subscribe to. + @type producer: L{IProducer<interfaces.IProducer>} + @ivar consumer: the Consumer I publish to. + @type consumer: L{IConsumer<interfaces.IConsumer>} + @ivar paused: As a Producer, am I paused? + @type paused: bool + """ + consumer = None + producer = None + producerIsStreaming = None + iAmStreaming = True + outstandingPull = False + paused = False + stopped = False + + def __init__(self, consumer): + self._buffer = [] + if consumer is not None: + self.consumer = consumer + consumer.registerProducer(self, self.iAmStreaming) + + # Producer methods: + + def pauseProducing(self): + self.paused = True + if self.producer: + self.producer.pauseProducing() + + def resumeProducing(self): + self.paused = False + if self._buffer: + # TODO: Check to see if consumer supports writeSeq. + self.consumer.write(''.join(self._buffer)) + self._buffer[:] = [] + else: + if not self.iAmStreaming: + self.outstandingPull = True + + if self.producer is not None: + self.producer.resumeProducing() + + def stopProducing(self): + if self.producer is not None: + self.producer.stopProducing() + if self.consumer is not None: + del self.consumer + + # Consumer methods: + + def write(self, data): + if self.paused or (not self.iAmStreaming and not self.outstandingPull): + # We could use that fifo queue here. + self._buffer.append(data) + + elif self.consumer is not None: + self.consumer.write(data) + self.outstandingPull = False + + def finish(self): + if self.consumer is not None: + self.consumer.finish() + self.unregisterProducer() + + def registerProducer(self, producer, streaming): + self.producer = producer + self.producerIsStreaming = streaming + + def unregisterProducer(self): + if self.producer is not None: + del self.producer + del self.producerIsStreaming + if self.consumer: + self.consumer.unregisterProducer() + + def __repr__(self): + return '<%s@%x around %s>' % (self.__class__, id(self), self.consumer) + + +class ProducerConsumerProxy(BasicProducerConsumerProxy): + """ProducerConsumerProxy with a finite buffer. + + When my buffer fills up, I have my parent Producer pause until my buffer + has room in it again. + """ + # Copies much from abstract.FileDescriptor + bufferSize = 2**2**2**2 + + producerPaused = False + unregistered = False + + def pauseProducing(self): + # Does *not* call up to ProducerConsumerProxy to relay the pause + # message through to my parent Producer. + self.paused = True + + def resumeProducing(self): + self.paused = False + if self._buffer: + data = ''.join(self._buffer) + bytesSent = self._writeSomeData(data) + if bytesSent < len(data): + unsent = data[bytesSent:] + assert not self.iAmStreaming, ( + "Streaming producer did not write all its data.") + self._buffer[:] = [unsent] + else: + self._buffer[:] = [] + else: + bytesSent = 0 + + if (self.unregistered and bytesSent and not self._buffer and + self.consumer is not None): + self.consumer.unregisterProducer() + + if not self.iAmStreaming: + self.outstandingPull = not bytesSent + + if self.producer is not None: + bytesBuffered = sum([len(s) for s in self._buffer]) + # TODO: You can see here the potential for high and low + # watermarks, where bufferSize would be the high mark when we + # ask the upstream producer to pause, and we wouldn't have + # it resume again until it hit the low mark. Or if producer + # is Pull, maybe we'd like to pull from it as much as necessary + # to keep our buffer full to the low mark, so we're never caught + # without something to send. + if self.producerPaused and (bytesBuffered < self.bufferSize): + # Now that our buffer is empty, + self.producerPaused = False + self.producer.resumeProducing() + elif self.outstandingPull: + # I did not have any data to write in response to a pull, + # so I'd better pull some myself. + self.producer.resumeProducing() + + def write(self, data): + if self.paused or (not self.iAmStreaming and not self.outstandingPull): + # We could use that fifo queue here. + self._buffer.append(data) + + elif self.consumer is not None: + assert not self._buffer, ( + "Writing fresh data to consumer before my buffer is empty!") + # I'm going to use _writeSomeData here so that there is only one + # path to self.consumer.write. But it doesn't actually make sense, + # if I am streaming, for some data to not be all data. But maybe I + # am not streaming, but I am writing here anyway, because there was + # an earlier request for data which was not answered. + bytesSent = self._writeSomeData(data) + self.outstandingPull = False + if not bytesSent == len(data): + assert not self.iAmStreaming, ( + "Streaming producer did not write all its data.") + self._buffer.append(data[bytesSent:]) + + if (self.producer is not None) and self.producerIsStreaming: + bytesBuffered = sum([len(s) for s in self._buffer]) + if bytesBuffered >= self.bufferSize: + + self.producer.pauseProducing() + self.producerPaused = True + + def registerProducer(self, producer, streaming): + self.unregistered = False + BasicProducerConsumerProxy.registerProducer(self, producer, streaming) + if not streaming: + producer.resumeProducing() + + def unregisterProducer(self): + if self.producer is not None: + del self.producer + del self.producerIsStreaming + self.unregistered = True + if self.consumer and not self._buffer: + self.consumer.unregisterProducer() + + def _writeSomeData(self, data): + """Write as much of this data as possible. + + @returns: The number of bytes written. + """ + if self.consumer is None: + return 0 + self.consumer.write(data) + return len(data) diff --git a/contrib/python/Twisted/py2/twisted/protocols/policies.py b/contrib/python/Twisted/py2/twisted/protocols/policies.py new file mode 100644 index 0000000000..5b8830aa86 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/policies.py @@ -0,0 +1,751 @@ +# -*- test-case-name: twisted.test.test_policies -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Resource limiting policies. + +@seealso: See also L{twisted.protocols.htb} for rate limiting. +""" + +from __future__ import division, absolute_import + +# system imports +import sys + +from zope.interface import directlyProvides, providedBy + +# twisted imports +from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory +from twisted.internet import error +from twisted.internet.interfaces import ILoggingContext +from twisted.python import log + + +def _wrappedLogPrefix(wrapper, wrapped): + """ + Compute a log prefix for a wrapper and the object it wraps. + + @rtype: C{str} + """ + if ILoggingContext.providedBy(wrapped): + logPrefix = wrapped.logPrefix() + else: + logPrefix = wrapped.__class__.__name__ + return "%s (%s)" % (logPrefix, wrapper.__class__.__name__) + + + +class ProtocolWrapper(Protocol): + """ + Wraps protocol instances and acts as their transport as well. + + @ivar wrappedProtocol: An L{IProtocol<twisted.internet.interfaces.IProtocol>} + provider to which L{IProtocol<twisted.internet.interfaces.IProtocol>} + method calls onto this L{ProtocolWrapper} will be proxied. + + @ivar factory: The L{WrappingFactory} which created this + L{ProtocolWrapper}. + """ + + disconnecting = 0 + + def __init__(self, factory, wrappedProtocol): + self.wrappedProtocol = wrappedProtocol + self.factory = factory + + + def logPrefix(self): + """ + Use a customized log prefix mentioning both the wrapped protocol and + the current one. + """ + return _wrappedLogPrefix(self, self.wrappedProtocol) + + + def makeConnection(self, transport): + """ + When a connection is made, register this wrapper with its factory, + save the real transport, and connect the wrapped protocol to this + L{ProtocolWrapper} to intercept any transport calls it makes. + """ + directlyProvides(self, providedBy(transport)) + Protocol.makeConnection(self, transport) + self.factory.registerProtocol(self) + self.wrappedProtocol.makeConnection(self) + + + # Transport relaying + + def write(self, data): + self.transport.write(data) + + + def writeSequence(self, data): + self.transport.writeSequence(data) + + + def loseConnection(self): + self.disconnecting = 1 + self.transport.loseConnection() + + + def getPeer(self): + return self.transport.getPeer() + + + def getHost(self): + return self.transport.getHost() + + + def registerProducer(self, producer, streaming): + self.transport.registerProducer(producer, streaming) + + + def unregisterProducer(self): + self.transport.unregisterProducer() + + + def stopConsuming(self): + self.transport.stopConsuming() + + + def __getattr__(self, name): + return getattr(self.transport, name) + + + # Protocol relaying + + def dataReceived(self, data): + self.wrappedProtocol.dataReceived(data) + + + def connectionLost(self, reason): + self.factory.unregisterProtocol(self) + self.wrappedProtocol.connectionLost(reason) + + # Breaking reference cycle between self and wrappedProtocol. + self.wrappedProtocol = None + + +class WrappingFactory(ClientFactory): + """ + Wraps a factory and its protocols, and keeps track of them. + """ + + protocol = ProtocolWrapper + + def __init__(self, wrappedFactory): + self.wrappedFactory = wrappedFactory + self.protocols = {} + + + def logPrefix(self): + """ + Generate a log prefix mentioning both the wrapped factory and this one. + """ + return _wrappedLogPrefix(self, self.wrappedFactory) + + + def doStart(self): + self.wrappedFactory.doStart() + ClientFactory.doStart(self) + + + def doStop(self): + self.wrappedFactory.doStop() + ClientFactory.doStop(self) + + + def startedConnecting(self, connector): + self.wrappedFactory.startedConnecting(connector) + + + def clientConnectionFailed(self, connector, reason): + self.wrappedFactory.clientConnectionFailed(connector, reason) + + + def clientConnectionLost(self, connector, reason): + self.wrappedFactory.clientConnectionLost(connector, reason) + + + def buildProtocol(self, addr): + return self.protocol(self, self.wrappedFactory.buildProtocol(addr)) + + + def registerProtocol(self, p): + """ + Called by protocol to register itself. + """ + self.protocols[p] = 1 + + + def unregisterProtocol(self, p): + """ + Called by protocols when they go away. + """ + del self.protocols[p] + + + +class ThrottlingProtocol(ProtocolWrapper): + """ + Protocol for L{ThrottlingFactory}. + """ + + # wrap API for tracking bandwidth + + def write(self, data): + self.factory.registerWritten(len(data)) + ProtocolWrapper.write(self, data) + + + def writeSequence(self, seq): + self.factory.registerWritten(sum(map(len, seq))) + ProtocolWrapper.writeSequence(self, seq) + + + def dataReceived(self, data): + self.factory.registerRead(len(data)) + ProtocolWrapper.dataReceived(self, data) + + + def registerProducer(self, producer, streaming): + self.producer = producer + ProtocolWrapper.registerProducer(self, producer, streaming) + + + def unregisterProducer(self): + del self.producer + ProtocolWrapper.unregisterProducer(self) + + + def throttleReads(self): + self.transport.pauseProducing() + + + def unthrottleReads(self): + self.transport.resumeProducing() + + + def throttleWrites(self): + if hasattr(self, "producer"): + self.producer.pauseProducing() + + + def unthrottleWrites(self): + if hasattr(self, "producer"): + self.producer.resumeProducing() + + + +class ThrottlingFactory(WrappingFactory): + """ + Throttles bandwidth and number of connections. + + Write bandwidth will only be throttled if there is a producer + registered. + """ + + protocol = ThrottlingProtocol + + def __init__(self, wrappedFactory, maxConnectionCount=sys.maxsize, + readLimit=None, writeLimit=None): + WrappingFactory.__init__(self, wrappedFactory) + self.connectionCount = 0 + self.maxConnectionCount = maxConnectionCount + self.readLimit = readLimit # max bytes we should read per second + self.writeLimit = writeLimit # max bytes we should write per second + self.readThisSecond = 0 + self.writtenThisSecond = 0 + self.unthrottleReadsID = None + self.checkReadBandwidthID = None + self.unthrottleWritesID = None + self.checkWriteBandwidthID = None + + + def callLater(self, period, func): + """ + Wrapper around + L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>} + for test purpose. + """ + from twisted.internet import reactor + return reactor.callLater(period, func) + + + def registerWritten(self, length): + """ + Called by protocol to tell us more bytes were written. + """ + self.writtenThisSecond += length + + + def registerRead(self, length): + """ + Called by protocol to tell us more bytes were read. + """ + self.readThisSecond += length + + + def checkReadBandwidth(self): + """ + Checks if we've passed bandwidth limits. + """ + if self.readThisSecond > self.readLimit: + self.throttleReads() + throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0 + self.unthrottleReadsID = self.callLater(throttleTime, + self.unthrottleReads) + self.readThisSecond = 0 + self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth) + + + def checkWriteBandwidth(self): + if self.writtenThisSecond > self.writeLimit: + self.throttleWrites() + throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0 + self.unthrottleWritesID = self.callLater(throttleTime, + self.unthrottleWrites) + # reset for next round + self.writtenThisSecond = 0 + self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth) + + + def throttleReads(self): + """ + Throttle reads on all protocols. + """ + log.msg("Throttling reads on %s" % self) + for p in self.protocols.keys(): + p.throttleReads() + + + def unthrottleReads(self): + """ + Stop throttling reads on all protocols. + """ + self.unthrottleReadsID = None + log.msg("Stopped throttling reads on %s" % self) + for p in self.protocols.keys(): + p.unthrottleReads() + + + def throttleWrites(self): + """ + Throttle writes on all protocols. + """ + log.msg("Throttling writes on %s" % self) + for p in self.protocols.keys(): + p.throttleWrites() + + + def unthrottleWrites(self): + """ + Stop throttling writes on all protocols. + """ + self.unthrottleWritesID = None + log.msg("Stopped throttling writes on %s" % self) + for p in self.protocols.keys(): + p.unthrottleWrites() + + + def buildProtocol(self, addr): + if self.connectionCount == 0: + if self.readLimit is not None: + self.checkReadBandwidth() + if self.writeLimit is not None: + self.checkWriteBandwidth() + + if self.connectionCount < self.maxConnectionCount: + self.connectionCount += 1 + return WrappingFactory.buildProtocol(self, addr) + else: + log.msg("Max connection count reached!") + return None + + + def unregisterProtocol(self, p): + WrappingFactory.unregisterProtocol(self, p) + self.connectionCount -= 1 + if self.connectionCount == 0: + if self.unthrottleReadsID is not None: + self.unthrottleReadsID.cancel() + if self.checkReadBandwidthID is not None: + self.checkReadBandwidthID.cancel() + if self.unthrottleWritesID is not None: + self.unthrottleWritesID.cancel() + if self.checkWriteBandwidthID is not None: + self.checkWriteBandwidthID.cancel() + + + +class SpewingProtocol(ProtocolWrapper): + def dataReceived(self, data): + log.msg("Received: %r" % data) + ProtocolWrapper.dataReceived(self,data) + + def write(self, data): + log.msg("Sending: %r" % data) + ProtocolWrapper.write(self,data) + + + +class SpewingFactory(WrappingFactory): + protocol = SpewingProtocol + + + +class LimitConnectionsByPeer(WrappingFactory): + + maxConnectionsPerPeer = 5 + + def startFactory(self): + self.peerConnections = {} + + def buildProtocol(self, addr): + peerHost = addr[0] + connectionCount = self.peerConnections.get(peerHost, 0) + if connectionCount >= self.maxConnectionsPerPeer: + return None + self.peerConnections[peerHost] = connectionCount + 1 + return WrappingFactory.buildProtocol(self, addr) + + def unregisterProtocol(self, p): + peerHost = p.getPeer()[1] + self.peerConnections[peerHost] -= 1 + if self.peerConnections[peerHost] == 0: + del self.peerConnections[peerHost] + + +class LimitTotalConnectionsFactory(ServerFactory): + """ + Factory that limits the number of simultaneous connections. + + @type connectionCount: C{int} + @ivar connectionCount: number of current connections. + @type connectionLimit: C{int} or L{None} + @cvar connectionLimit: maximum number of connections. + @type overflowProtocol: L{Protocol} or L{None} + @cvar overflowProtocol: Protocol to use for new connections when + connectionLimit is exceeded. If L{None} (the default value), excess + connections will be closed immediately. + """ + connectionCount = 0 + connectionLimit = None + overflowProtocol = None + + def buildProtocol(self, addr): + if (self.connectionLimit is None or + self.connectionCount < self.connectionLimit): + # Build the normal protocol + wrappedProtocol = self.protocol() + elif self.overflowProtocol is None: + # Just drop the connection + return None + else: + # Too many connections, so build the overflow protocol + wrappedProtocol = self.overflowProtocol() + + wrappedProtocol.factory = self + protocol = ProtocolWrapper(self, wrappedProtocol) + self.connectionCount += 1 + return protocol + + def registerProtocol(self, p): + pass + + def unregisterProtocol(self, p): + self.connectionCount -= 1 + + + +class TimeoutProtocol(ProtocolWrapper): + """ + Protocol that automatically disconnects when the connection is idle. + """ + + def __init__(self, factory, wrappedProtocol, timeoutPeriod): + """ + Constructor. + + @param factory: An L{TimeoutFactory}. + @param wrappedProtocol: A L{Protocol} to wrapp. + @param timeoutPeriod: Number of seconds to wait for activity before + timing out. + """ + ProtocolWrapper.__init__(self, factory, wrappedProtocol) + self.timeoutCall = None + self.timeoutPeriod = None + self.setTimeout(timeoutPeriod) + + + def setTimeout(self, timeoutPeriod=None): + """ + Set a timeout. + + This will cancel any existing timeouts. + + @param timeoutPeriod: If not L{None}, change the timeout period. + Otherwise, use the existing value. + """ + self.cancelTimeout() + self.timeoutPeriod = timeoutPeriod + if timeoutPeriod is not None: + self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc) + + + def cancelTimeout(self): + """ + Cancel the timeout. + + If the timeout was already cancelled, this does nothing. + """ + self.timeoutPeriod = None + if self.timeoutCall: + try: + self.timeoutCall.cancel() + except (error.AlreadyCalled, error.AlreadyCancelled): + pass + self.timeoutCall = None + + + def resetTimeout(self): + """ + Reset the timeout, usually because some activity just happened. + """ + if self.timeoutCall: + self.timeoutCall.reset(self.timeoutPeriod) + + + def write(self, data): + self.resetTimeout() + ProtocolWrapper.write(self, data) + + + def writeSequence(self, seq): + self.resetTimeout() + ProtocolWrapper.writeSequence(self, seq) + + + def dataReceived(self, data): + self.resetTimeout() + ProtocolWrapper.dataReceived(self, data) + + + def connectionLost(self, reason): + self.cancelTimeout() + ProtocolWrapper.connectionLost(self, reason) + + + def timeoutFunc(self): + """ + This method is called when the timeout is triggered. + + By default it calls I{loseConnection}. Override this if you want + something else to happen. + """ + self.loseConnection() + + + +class TimeoutFactory(WrappingFactory): + """ + Factory for TimeoutWrapper. + """ + protocol = TimeoutProtocol + + + def __init__(self, wrappedFactory, timeoutPeriod=30*60): + self.timeoutPeriod = timeoutPeriod + WrappingFactory.__init__(self, wrappedFactory) + + + def buildProtocol(self, addr): + return self.protocol(self, self.wrappedFactory.buildProtocol(addr), + timeoutPeriod=self.timeoutPeriod) + + + def callLater(self, period, func): + """ + Wrapper around + L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>} + for test purpose. + """ + from twisted.internet import reactor + return reactor.callLater(period, func) + + + +class TrafficLoggingProtocol(ProtocolWrapper): + + def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None, + number=0): + """ + @param factory: factory which created this protocol. + @type factory: L{protocol.Factory}. + @param wrappedProtocol: the underlying protocol. + @type wrappedProtocol: C{protocol.Protocol}. + @param logfile: file opened for writing used to write log messages. + @type logfile: C{file} + @param lengthLimit: maximum size of the datareceived logged. + @type lengthLimit: C{int} + @param number: identifier of the connection. + @type number: C{int}. + """ + ProtocolWrapper.__init__(self, factory, wrappedProtocol) + self.logfile = logfile + self.lengthLimit = lengthLimit + self._number = number + + + def _log(self, line): + self.logfile.write(line + '\n') + self.logfile.flush() + + + def _mungeData(self, data): + if self.lengthLimit and len(data) > self.lengthLimit: + data = data[:self.lengthLimit - 12] + '<... elided>' + return data + + + # IProtocol + def connectionMade(self): + self._log('*') + return ProtocolWrapper.connectionMade(self) + + + def dataReceived(self, data): + self._log('C %d: %r' % (self._number, self._mungeData(data))) + return ProtocolWrapper.dataReceived(self, data) + + + def connectionLost(self, reason): + self._log('C %d: %r' % (self._number, reason)) + return ProtocolWrapper.connectionLost(self, reason) + + + # ITransport + def write(self, data): + self._log('S %d: %r' % (self._number, self._mungeData(data))) + return ProtocolWrapper.write(self, data) + + + def writeSequence(self, iovec): + self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec])) + return ProtocolWrapper.writeSequence(self, iovec) + + + def loseConnection(self): + self._log('S %d: *' % (self._number,)) + return ProtocolWrapper.loseConnection(self) + + + +class TrafficLoggingFactory(WrappingFactory): + protocol = TrafficLoggingProtocol + + _counter = 0 + + def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None): + self.logfilePrefix = logfilePrefix + self.lengthLimit = lengthLimit + WrappingFactory.__init__(self, wrappedFactory) + + + def open(self, name): + return open(name, 'w') + + + def buildProtocol(self, addr): + self._counter += 1 + logfile = self.open(self.logfilePrefix + '-' + str(self._counter)) + return self.protocol(self, self.wrappedFactory.buildProtocol(addr), + logfile, self.lengthLimit, self._counter) + + + def resetCounter(self): + """ + Reset the value of the counter used to identify connections. + """ + self._counter = 0 + + + +class TimeoutMixin: + """ + Mixin for protocols which wish to timeout connections. + + Protocols that mix this in have a single timeout, set using L{setTimeout}. + When the timeout is hit, L{timeoutConnection} is called, which, by + default, closes the connection. + + @cvar timeOut: The number of seconds after which to timeout the connection. + """ + timeOut = None + + __timeoutCall = None + + def callLater(self, period, func): + """ + Wrapper around + L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>} + for test purpose. + """ + from twisted.internet import reactor + return reactor.callLater(period, func) + + + def resetTimeout(self): + """ + Reset the timeout count down. + + If the connection has already timed out, then do nothing. If the + timeout has been cancelled (probably using C{setTimeout(None)}), also + do nothing. + + It's often a good idea to call this when the protocol has received + some meaningful input from the other end of the connection. "I've got + some data, they're still there, reset the timeout". + """ + if self.__timeoutCall is not None and self.timeOut is not None: + self.__timeoutCall.reset(self.timeOut) + + def setTimeout(self, period): + """ + Change the timeout period + + @type period: C{int} or L{None} + @param period: The period, in seconds, to change the timeout to, or + L{None} to disable the timeout. + """ + prev = self.timeOut + self.timeOut = period + + if self.__timeoutCall is not None: + if period is None: + try: + self.__timeoutCall.cancel() + except (error.AlreadyCancelled, error.AlreadyCalled): + # Do nothing if the call was already consumed. + pass + self.__timeoutCall = None + else: + self.__timeoutCall.reset(period) + elif period is not None: + self.__timeoutCall = self.callLater(period, self.__timedOut) + + return prev + + def __timedOut(self): + self.__timeoutCall = None + self.timeoutConnection() + + def timeoutConnection(self): + """ + Called when the connection times out. + + Override to define behavior other than dropping the connection. + """ + self.transport.loseConnection() diff --git a/contrib/python/Twisted/py2/twisted/protocols/portforward.py b/contrib/python/Twisted/py2/twisted/protocols/portforward.py new file mode 100644 index 0000000000..a3c39549ae --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/portforward.py @@ -0,0 +1,99 @@ + +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +A simple port forwarder. +""" + +# Twisted imports +from twisted.internet import protocol +from twisted.python import log + +class Proxy(protocol.Protocol): + noisy = True + + peer = None + + def setPeer(self, peer): + self.peer = peer + + + def connectionLost(self, reason): + if self.peer is not None: + self.peer.transport.loseConnection() + self.peer = None + elif self.noisy: + log.msg("Unable to connect to peer: %s" % (reason,)) + + + def dataReceived(self, data): + self.peer.transport.write(data) + + + +class ProxyClient(Proxy): + def connectionMade(self): + self.peer.setPeer(self) + + # Wire this and the peer transport together to enable + # flow control (this stops connections from filling + # this proxy memory when one side produces data at a + # higher rate than the other can consume). + self.transport.registerProducer(self.peer.transport, True) + self.peer.transport.registerProducer(self.transport, True) + + # We're connected, everybody can read to their hearts content. + self.peer.transport.resumeProducing() + + + +class ProxyClientFactory(protocol.ClientFactory): + + protocol = ProxyClient + + def setServer(self, server): + self.server = server + + + def buildProtocol(self, *args, **kw): + prot = protocol.ClientFactory.buildProtocol(self, *args, **kw) + prot.setPeer(self.server) + return prot + + + def clientConnectionFailed(self, connector, reason): + self.server.transport.loseConnection() + + + +class ProxyServer(Proxy): + + clientProtocolFactory = ProxyClientFactory + reactor = None + + def connectionMade(self): + # Don't read anything from the connecting client until we have + # somewhere to send it to. + self.transport.pauseProducing() + + client = self.clientProtocolFactory() + client.setServer(self) + + if self.reactor is None: + from twisted.internet import reactor + self.reactor = reactor + self.reactor.connectTCP(self.factory.host, self.factory.port, client) + + + +class ProxyFactory(protocol.Factory): + """ + Factory for port forwarder. + """ + + protocol = ProxyServer + + def __init__(self, host, port): + self.host = host + self.port = port diff --git a/contrib/python/Twisted/py2/twisted/protocols/postfix.py b/contrib/python/Twisted/py2/twisted/protocols/postfix.py new file mode 100644 index 0000000000..445b88cb05 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/postfix.py @@ -0,0 +1,158 @@ +# -*- test-case-name: twisted.test.test_postfix -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Postfix mail transport agent related protocols. +""" + +import sys +try: + # Python 2 + from UserDict import UserDict +except ImportError: + # Python 3 + from collections import UserDict + +try: + # Python 2 + from urllib import quote as _quote, unquote as _unquote +except ImportError: + # Python 3 + from urllib.parse import quote as _quote, unquote as _unquote + +from twisted.protocols import basic +from twisted.protocols import policies +from twisted.internet import protocol, defer +from twisted.python import log +from twisted.python.compat import unicode + +# urllib's quote functions just happen to match +# the postfix semantics. +def quote(s): + quoted = _quote(s) + if isinstance(quoted, unicode): + quoted = quoted.encode("ascii") + return quoted + + + +def unquote(s): + if isinstance(s, bytes): + s = s.decode("ascii") + quoted = _unquote(s) + return quoted.encode("ascii") + + + +class PostfixTCPMapServer(basic.LineReceiver, policies.TimeoutMixin): + """ + Postfix mail transport agent TCP map protocol implementation. + + Receive requests for data matching given key via lineReceived, + asks it's factory for the data with self.factory.get(key), and + returns the data to the requester. None means no entry found. + + You can use postfix's postmap to test the map service:: + + /usr/sbin/postmap -q KEY tcp:localhost:4242 + + """ + + timeout = 600 + delimiter = b'\n' + + def connectionMade(self): + self.setTimeout(self.timeout) + + + + def sendCode(self, code, message=b''): + """ + Send an SMTP-like code with a message. + """ + self.sendLine(str(code).encode("ascii") + b' ' + message) + + + + def lineReceived(self, line): + self.resetTimeout() + try: + request, params = line.split(None, 1) + except ValueError: + request = line + params = None + try: + f = getattr(self, u'do_' + request.decode("ascii")) + except AttributeError: + self.sendCode(400, b'unknown command') + else: + try: + f(params) + except: + excInfo = str(sys.exc_info()[1]).encode("ascii") + self.sendCode(400, b'Command ' + request + b' failed: ' + + excInfo) + + + + def do_get(self, key): + if key is None: + self.sendCode(400, b"Command 'get' takes 1 parameters.") + else: + d = defer.maybeDeferred(self.factory.get, key) + d.addCallbacks(self._cbGot, self._cbNot) + d.addErrback(log.err) + + + + def _cbNot(self, fail): + msg = fail.getErrorMessage().encode("ascii") + self.sendCode(400, msg) + + + + def _cbGot(self, value): + if value is None: + self.sendCode(500) + else: + self.sendCode(200, quote(value)) + + + + def do_put(self, keyAndValue): + if keyAndValue is None: + self.sendCode(400, b"Command 'put' takes 2 parameters.") + else: + try: + key, value = keyAndValue.split(None, 1) + except ValueError: + self.sendCode(400, b"Command 'put' takes 2 parameters.") + else: + self.sendCode(500, b'put is not implemented yet.') + + + +class PostfixTCPMapDictServerFactory(UserDict, protocol.ServerFactory): + """ + An in-memory dictionary factory for PostfixTCPMapServer. + """ + + protocol = PostfixTCPMapServer + + + +class PostfixTCPMapDeferringDictServerFactory(protocol.ServerFactory): + """ + An in-memory dictionary factory for PostfixTCPMapServer. + """ + + protocol = PostfixTCPMapServer + + def __init__(self, data=None): + self.data = {} + if data is not None: + self.data.update(data) + + def get(self, key): + return defer.succeed(self.data.get(key)) diff --git a/contrib/python/Twisted/py2/twisted/protocols/shoutcast.py b/contrib/python/Twisted/py2/twisted/protocols/shoutcast.py new file mode 100644 index 0000000000..e2be938995 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/shoutcast.py @@ -0,0 +1,111 @@ +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Chop up shoutcast stream into MP3s and metadata, if available. +""" + +from twisted.web import http +from twisted import copyright + + +class ShoutcastClient(http.HTTPClient): + """ + Shoutcast HTTP stream. + + Modes can be 'length', 'meta' and 'mp3'. + + See U{http://www.smackfu.com/stuff/programming/shoutcast.html} + for details on the protocol. + """ + + userAgent = "Twisted Shoutcast client " + copyright.version + + def __init__(self, path="/"): + self.path = path + self.got_metadata = False + self.metaint = None + self.metamode = "mp3" + self.databuffer = "" + + def connectionMade(self): + self.sendCommand("GET", self.path) + self.sendHeader("User-Agent", self.userAgent) + self.sendHeader("Icy-MetaData", "1") + self.endHeaders() + + def lineReceived(self, line): + # fix shoutcast crappiness + if not self.firstLine and line: + if len(line.split(": ", 1)) == 1: + line = line.replace(":", ": ", 1) + http.HTTPClient.lineReceived(self, line) + + def handleHeader(self, key, value): + if key.lower() == 'icy-metaint': + self.metaint = int(value) + self.got_metadata = True + + def handleEndHeaders(self): + # Lets check if we got metadata, and set the + # appropriate handleResponsePart method. + if self.got_metadata: + # if we have metadata, then it has to be parsed out of the data stream + self.handleResponsePart = self.handleResponsePart_with_metadata + else: + # otherwise, all the data is MP3 data + self.handleResponsePart = self.gotMP3Data + + def handleResponsePart_with_metadata(self, data): + self.databuffer += data + while self.databuffer: + stop = getattr(self, "handle_%s" % self.metamode)() + if stop: + return + + def handle_length(self): + self.remaining = ord(self.databuffer[0]) * 16 + self.databuffer = self.databuffer[1:] + self.metamode = "meta" + + def handle_mp3(self): + if len(self.databuffer) > self.metaint: + self.gotMP3Data(self.databuffer[:self.metaint]) + self.databuffer = self.databuffer[self.metaint:] + self.metamode = "length" + else: + return 1 + + def handle_meta(self): + if len(self.databuffer) >= self.remaining: + if self.remaining: + data = self.databuffer[:self.remaining] + self.gotMetaData(self.parseMetadata(data)) + self.databuffer = self.databuffer[self.remaining:] + self.metamode = "mp3" + else: + return 1 + + def parseMetadata(self, data): + meta = [] + for chunk in data.split(';'): + chunk = chunk.strip().replace("\x00", "") + if not chunk: + continue + key, value = chunk.split('=', 1) + if value.startswith("'") and value.endswith("'"): + value = value[1:-1] + meta.append((key, value)) + return meta + + def gotMetaData(self, metadata): + """Called with a list of (key, value) pairs of metadata, + if metadata is available on the server. + + Will only be called on non-empty metadata. + """ + raise NotImplementedError("implement in subclass") + + def gotMP3Data(self, data): + """Called with chunk of MP3 data.""" + raise NotImplementedError("implement in subclass") diff --git a/contrib/python/Twisted/py2/twisted/protocols/sip.py b/contrib/python/Twisted/py2/twisted/protocols/sip.py new file mode 100644 index 0000000000..6a8429ed14 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/sip.py @@ -0,0 +1,1294 @@ +# -*- test-case-name: twisted.test.test_sip -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Session Initialization Protocol. + +Documented in RFC 2543. +[Superseded by 3261] +""" + +import socket +import time +import warnings + +from zope.interface import implementer, Interface +from collections import OrderedDict + +from twisted import cred +from twisted.internet import protocol, defer, reactor +from twisted.protocols import basic +from twisted.python import log +from twisted.python.compat import _PY3, iteritems, unicode + +PORT = 5060 + +# SIP headers have short forms +shortHeaders = {"call-id": "i", + "contact": "m", + "content-encoding": "e", + "content-length": "l", + "content-type": "c", + "from": "f", + "subject": "s", + "to": "t", + "via": "v", + } + +longHeaders = {} +for k, v in shortHeaders.items(): + longHeaders[v] = k +del k, v + +statusCodes = { + 100: "Trying", + 180: "Ringing", + 181: "Call Is Being Forwarded", + 182: "Queued", + 183: "Session Progress", + + 200: "OK", + + 300: "Multiple Choices", + 301: "Moved Permanently", + 302: "Moved Temporarily", + 303: "See Other", + 305: "Use Proxy", + 380: "Alternative Service", + + 400: "Bad Request", + 401: "Unauthorized", + 402: "Payment Required", + 403: "Forbidden", + 404: "Not Found", + 405: "Method Not Allowed", + 406: "Not Acceptable", + 407: "Proxy Authentication Required", + 408: "Request Timeout", + 409: "Conflict", # Not in RFC3261 + 410: "Gone", + 411: "Length Required", # Not in RFC3261 + 413: "Request Entity Too Large", + 414: "Request-URI Too Large", + 415: "Unsupported Media Type", + 416: "Unsupported URI Scheme", + 420: "Bad Extension", + 421: "Extension Required", + 423: "Interval Too Brief", + 480: "Temporarily Unavailable", + 481: "Call/Transaction Does Not Exist", + 482: "Loop Detected", + 483: "Too Many Hops", + 484: "Address Incomplete", + 485: "Ambiguous", + 486: "Busy Here", + 487: "Request Terminated", + 488: "Not Acceptable Here", + 491: "Request Pending", + 493: "Undecipherable", + + 500: "Internal Server Error", + 501: "Not Implemented", + 502: "Bad Gateway", # No donut + 503: "Service Unavailable", + 504: "Server Time-out", + 505: "SIP Version not supported", + 513: "Message Too Large", + + 600: "Busy Everywhere", + 603: "Decline", + 604: "Does not exist anywhere", + 606: "Not Acceptable", +} + +specialCases = { + 'cseq': 'CSeq', + 'call-id': 'Call-ID', + 'www-authenticate': 'WWW-Authenticate', +} + + +def dashCapitalize(s): + """ + Capitalize a string, making sure to treat '-' as a word separator + """ + return '-'.join([ x.capitalize() for x in s.split('-')]) + + + +def unq(s): + if s[0] == s[-1] == '"': + return s[1:-1] + return s + + + +_absent = object() + +class Via(object): + """ + A L{Via} is a SIP Via header, representing a segment of the path taken by + the request. + + See RFC 3261, sections 8.1.1.7, 18.2.2, and 20.42. + + @ivar transport: Network protocol used for this leg. (Probably either "TCP" + or "UDP".) + @type transport: C{str} + @ivar branch: Unique identifier for this request. + @type branch: C{str} + @ivar host: Hostname or IP for this leg. + @type host: C{str} + @ivar port: Port used for this leg. + @type port C{int}, or None. + @ivar rportRequested: Whether to request RFC 3581 client processing or not. + @type rportRequested: C{bool} + @ivar rportValue: Servers wishing to honor requests for RFC 3581 processing + should set this parameter to the source port the request was received + from. + @type rportValue: C{int}, or None. + + @ivar ttl: Time-to-live for requests on multicast paths. + @type ttl: C{int}, or None. + @ivar maddr: The destination multicast address, if any. + @type maddr: C{str}, or None. + @ivar hidden: Obsolete in SIP 2.0. + @type hidden: C{bool} + @ivar otherParams: Any other parameters in the header. + @type otherParams: C{dict} + """ + + def __init__(self, host, port=PORT, transport="UDP", ttl=None, + hidden=False, received=None, rport=_absent, branch=None, + maddr=None, **kw): + """ + Set parameters of this Via header. All arguments correspond to + attributes of the same name. + + To maintain compatibility with old SIP + code, the 'rport' argument is used to determine the values of + C{rportRequested} and C{rportValue}. If None, C{rportRequested} is set + to True. (The deprecated method for doing this is to pass True.) If an + integer, C{rportValue} is set to the given value. + + Any arguments not explicitly named here are collected into the + C{otherParams} dict. + """ + self.transport = transport + self.host = host + self.port = port + self.ttl = ttl + self.hidden = hidden + self.received = received + if rport is True: + warnings.warn( + "rport=True is deprecated since Twisted 9.0.", + DeprecationWarning, + stacklevel=2) + self.rportValue = None + self.rportRequested = True + elif rport is None: + self.rportValue = None + self.rportRequested = True + elif rport is _absent: + self.rportValue = None + self.rportRequested = False + else: + self.rportValue = rport + self.rportRequested = False + + self.branch = branch + self.maddr = maddr + self.otherParams = kw + + + def _getrport(self): + """ + Returns the rport value expected by the old SIP code. + """ + if self.rportRequested == True: + return True + elif self.rportValue is not None: + return self.rportValue + else: + return None + + + def _setrport(self, newRPort): + """ + L{Base._fixupNAT} sets C{rport} directly, so this method sets + C{rportValue} based on that. + + @param newRPort: The new rport value. + @type newRPort: C{int} + """ + self.rportValue = newRPort + self.rportRequested = False + + rport = property(_getrport, _setrport) + + def toString(self): + """ + Serialize this header for use in a request or response. + """ + s = "SIP/2.0/%s %s:%s" % (self.transport, self.host, self.port) + if self.hidden: + s += ";hidden" + for n in "ttl", "branch", "maddr", "received": + value = getattr(self, n) + if value is not None: + s += ";%s=%s" % (n, value) + if self.rportRequested: + s += ";rport" + elif self.rportValue is not None: + s += ";rport=%s" % (self.rport,) + + etc = sorted(self.otherParams.items()) + for k, v in etc: + if v is None: + s += ";" + k + else: + s += ";%s=%s" % (k, v) + return s + + + +def parseViaHeader(value): + """ + Parse a Via header. + + @return: The parsed version of this header. + @rtype: L{Via} + """ + parts = value.split(";") + sent, params = parts[0], parts[1:] + protocolinfo, by = sent.split(" ", 1) + by = by.strip() + result = {} + pname, pversion, transport = protocolinfo.split("/") + if pname != "SIP" or pversion != "2.0": + raise ValueError("wrong protocol or version: %r" % (value,)) + result["transport"] = transport + if ":" in by: + host, port = by.split(":") + result["port"] = int(port) + result["host"] = host + else: + result["host"] = by + for p in params: + # It's the comment-striping dance! + p = p.strip().split(" ", 1) + if len(p) == 1: + p, comment = p[0], "" + else: + p, comment = p + if p == "hidden": + result["hidden"] = True + continue + parts = p.split("=", 1) + if len(parts) == 1: + name, value = parts[0], None + else: + name, value = parts + if name in ("rport", "ttl"): + value = int(value) + result[name] = value + return Via(**result) + + + +class URL: + """ + A SIP URL. + """ + + def __init__(self, host, username=None, password=None, port=None, + transport=None, usertype=None, method=None, + ttl=None, maddr=None, tag=None, other=None, headers=None): + self.username = username + self.host = host + self.password = password + self.port = port + self.transport = transport + self.usertype = usertype + self.method = method + self.tag = tag + self.ttl = ttl + self.maddr = maddr + if other == None: + self.other = [] + else: + self.other = other + if headers == None: + self.headers = {} + else: + self.headers = headers + + + def toString(self): + l = []; w = l.append + w("sip:") + if self.username != None: + w(self.username) + if self.password != None: + w(":%s" % self.password) + w("@") + w(self.host) + if self.port != None: + w(":%d" % self.port) + if self.usertype != None: + w(";user=%s" % self.usertype) + for n in ("transport", "ttl", "maddr", "method", "tag"): + v = getattr(self, n) + if v != None: + w(";%s=%s" % (n, v)) + for v in self.other: + w(";%s" % v) + if self.headers: + w("?") + w("&".join([("%s=%s" % (specialCases.get(h) or dashCapitalize(h), v)) for (h, v) in self.headers.items()])) + return "".join(l) + + + def __str__(self): + return self.toString() + + + def __repr__(self): + return '<URL %s:%s@%s:%r/%s>' % (self.username, self.password, self.host, self.port, self.transport) + + + +def parseURL(url, host=None, port=None): + """ + Return string into URL object. + + URIs are of form 'sip:user@example.com'. + """ + d = {} + if not url.startswith("sip:"): + raise ValueError("unsupported scheme: " + url[:4]) + parts = url[4:].split(";") + userdomain, params = parts[0], parts[1:] + udparts = userdomain.split("@", 1) + if len(udparts) == 2: + userpass, hostport = udparts + upparts = userpass.split(":", 1) + if len(upparts) == 1: + d["username"] = upparts[0] + else: + d["username"] = upparts[0] + d["password"] = upparts[1] + else: + hostport = udparts[0] + hpparts = hostport.split(":", 1) + if len(hpparts) == 1: + d["host"] = hpparts[0] + else: + d["host"] = hpparts[0] + d["port"] = int(hpparts[1]) + if host != None: + d["host"] = host + if port != None: + d["port"] = port + for p in params: + if p == params[-1] and "?" in p: + d["headers"] = h = {} + p, headers = p.split("?", 1) + for header in headers.split("&"): + k, v = header.split("=") + h[k] = v + nv = p.split("=", 1) + if len(nv) == 1: + d.setdefault("other", []).append(p) + continue + name, value = nv + if name == "user": + d["usertype"] = value + elif name in ("transport", "ttl", "maddr", "method", "tag"): + if name == "ttl": + value = int(value) + d[name] = value + else: + d.setdefault("other", []).append(p) + return URL(**d) + + + +def cleanRequestURL(url): + """ + Clean a URL from a Request line. + """ + url.transport = None + url.maddr = None + url.ttl = None + url.headers = {} + + + +def parseAddress(address, host=None, port=None, clean=0): + """ + Return (name, uri, params) for From/To/Contact header. + + @param clean: remove unnecessary info, usually for From and To headers. + """ + address = address.strip() + # Simple 'sip:foo' case + if address.startswith("sip:"): + return "", parseURL(address, host=host, port=port), {} + params = {} + name, url = address.split("<", 1) + name = name.strip() + if name.startswith('"'): + name = name[1:] + if name.endswith('"'): + name = name[:-1] + url, paramstring = url.split(">", 1) + url = parseURL(url, host=host, port=port) + paramstring = paramstring.strip() + if paramstring: + for l in paramstring.split(";"): + if not l: + continue + k, v = l.split("=") + params[k] = v + if clean: + # RFC 2543 6.21 + url.ttl = None + url.headers = {} + url.transport = None + url.maddr = None + return name, url, params + + + +class SIPError(Exception): + def __init__(self, code, phrase=None): + if phrase is None: + phrase = statusCodes[code] + Exception.__init__(self, "SIP error (%d): %s" % (code, phrase)) + self.code = code + self.phrase = phrase + + + +class RegistrationError(SIPError): + """ + Registration was not possible. + """ + + + +class Message: + """ + A SIP message. + """ + + length = None + + def __init__(self): + self.headers = OrderedDict() # Map name to list of values + self.body = "" + self.finished = 0 + + + def addHeader(self, name, value): + name = name.lower() + name = longHeaders.get(name, name) + if name == "content-length": + self.length = int(value) + self.headers.setdefault(name,[]).append(value) + + + def bodyDataReceived(self, data): + self.body += data + + + def creationFinished(self): + if (self.length != None) and (self.length != len(self.body)): + raise ValueError("wrong body length") + self.finished = 1 + + + def toString(self): + s = "%s\r\n" % self._getHeaderLine() + for n, vs in self.headers.items(): + for v in vs: + s += "%s: %s\r\n" % (specialCases.get(n) or dashCapitalize(n), v) + s += "\r\n" + s += self.body + return s + + + def _getHeaderLine(self): + raise NotImplementedError + + + +class Request(Message): + """ + A Request for a URI + """ + + def __init__(self, method, uri, version="SIP/2.0"): + Message.__init__(self) + self.method = method + if isinstance(uri, URL): + self.uri = uri + else: + self.uri = parseURL(uri) + cleanRequestURL(self.uri) + + + def __repr__(self): + return "<SIP Request %d:%s %s>" % (id(self), self.method, self.uri.toString()) + + + def _getHeaderLine(self): + return "%s %s SIP/2.0" % (self.method, self.uri.toString()) + + + +class Response(Message): + """ + A Response to a URI Request + """ + + def __init__(self, code, phrase=None, version="SIP/2.0"): + Message.__init__(self) + self.code = code + if phrase == None: + phrase = statusCodes[code] + self.phrase = phrase + + + def __repr__(self): + return "<SIP Response %d:%s>" % (id(self), self.code) + + + def _getHeaderLine(self): + return "SIP/2.0 %s %s" % (self.code, self.phrase) + + + +class MessagesParser(basic.LineReceiver): + """ + A SIP messages parser. + + Expects dataReceived, dataDone repeatedly, + in that order. Shouldn't be connected to actual transport. + """ + + version = "SIP/2.0" + acceptResponses = 1 + acceptRequests = 1 + state = "firstline" # Or "headers", "body" or "invalid" + + debug = 0 + + def __init__(self, messageReceivedCallback): + self.messageReceived = messageReceivedCallback + self.reset() + + + def reset(self, remainingData=""): + self.state = "firstline" + self.length = None # Body length + self.bodyReceived = 0 # How much of the body we received + self.message = None + self.header = None + self.setLineMode(remainingData) + + + def invalidMessage(self): + self.state = "invalid" + self.setRawMode() + + + def dataDone(self): + """ + Clear out any buffered data that may be hanging around. + """ + self.clearLineBuffer() + if self.state == "firstline": + return + if self.state != "body": + self.reset() + return + if self.length == None: + # No content-length header, so end of data signals message done + self.messageDone() + elif self.length < self.bodyReceived: + # Aborted in the middle + self.reset() + else: + # We have enough data and message wasn't finished? something is wrong + raise RuntimeError("this should never happen") + + + def dataReceived(self, data): + try: + if isinstance(data, unicode): + data = data.encode("utf-8") + basic.LineReceiver.dataReceived(self, data) + except: + log.err() + self.invalidMessage() + + + def handleFirstLine(self, line): + """ + Expected to create self.message. + """ + raise NotImplementedError + + + def lineLengthExceeded(self, line): + self.invalidMessage() + + + def lineReceived(self, line): + if _PY3 and isinstance(line, bytes): + line = line.decode("utf-8") + + if self.state == "firstline": + while line.startswith("\n") or line.startswith("\r"): + line = line[1:] + if not line: + return + try: + a, b, c = line.split(" ", 2) + except ValueError: + self.invalidMessage() + return + if a == "SIP/2.0" and self.acceptResponses: + # Response + try: + code = int(b) + except ValueError: + self.invalidMessage() + return + self.message = Response(code, c) + elif c == "SIP/2.0" and self.acceptRequests: + self.message = Request(a, b) + else: + self.invalidMessage() + return + self.state = "headers" + return + else: + assert self.state == "headers" + if line: + # Multiline header + if line.startswith(" ") or line.startswith("\t"): + name, value = self.header + self.header = name, (value + line.lstrip()) + else: + # New header + if self.header: + self.message.addHeader(*self.header) + self.header = None + try: + name, value = line.split(":", 1) + except ValueError: + self.invalidMessage() + return + self.header = name, value.lstrip() + # XXX we assume content-length won't be multiline + if name.lower() == "content-length": + try: + self.length = int(value.lstrip()) + except ValueError: + self.invalidMessage() + return + else: + # CRLF, we now have message body until self.length bytes, + # or if no length was given, until there is no more data + # from the connection sending us data. + self.state = "body" + if self.header: + self.message.addHeader(*self.header) + self.header = None + if self.length == 0: + self.messageDone() + return + self.setRawMode() + + + def messageDone(self, remainingData=""): + assert self.state == "body" + self.message.creationFinished() + self.messageReceived(self.message) + self.reset(remainingData) + + + def rawDataReceived(self, data): + assert self.state in ("body", "invalid") + if _PY3 and isinstance(data, bytes): + data = data.decode("utf-8") + if self.state == "invalid": + return + if self.length == None: + self.message.bodyDataReceived(data) + else: + dataLen = len(data) + expectedLen = self.length - self.bodyReceived + if dataLen > expectedLen: + self.message.bodyDataReceived(data[:expectedLen]) + self.messageDone(data[expectedLen:]) + return + else: + self.bodyReceived += dataLen + self.message.bodyDataReceived(data) + if self.bodyReceived == self.length: + self.messageDone() + + + +class Base(protocol.DatagramProtocol): + """ + Base class for SIP clients and servers. + """ + + PORT = PORT + debug = False + + def __init__(self): + self.messages = [] + self.parser = MessagesParser(self.addMessage) + + + def addMessage(self, msg): + self.messages.append(msg) + + + def datagramReceived(self, data, addr): + self.parser.dataReceived(data) + self.parser.dataDone() + for m in self.messages: + self._fixupNAT(m, addr) + if self.debug: + log.msg("Received %r from %r" % (m.toString(), addr)) + if isinstance(m, Request): + self.handle_request(m, addr) + else: + self.handle_response(m, addr) + self.messages[:] = [] + + + def _fixupNAT(self, message, sourcePeer): + # RFC 2543 6.40.2, + (srcHost, srcPort) = sourcePeer + senderVia = parseViaHeader(message.headers["via"][0]) + if senderVia.host != srcHost: + senderVia.received = srcHost + if senderVia.port != srcPort: + senderVia.rport = srcPort + message.headers["via"][0] = senderVia.toString() + elif senderVia.rport == True: + senderVia.received = srcHost + senderVia.rport = srcPort + message.headers["via"][0] = senderVia.toString() + + + def deliverResponse(self, responseMessage): + """ + Deliver response. + + Destination is based on topmost Via header. + """ + destVia = parseViaHeader(responseMessage.headers["via"][0]) + # XXX we don't do multicast yet + host = destVia.received or destVia.host + port = destVia.rport or destVia.port or self.PORT + destAddr = URL(host=host, port=port) + self.sendMessage(destAddr, responseMessage) + + + def responseFromRequest(self, code, request): + """ + Create a response to a request message. + """ + response = Response(code) + for name in ("via", "to", "from", "call-id", "cseq"): + response.headers[name] = request.headers.get(name, [])[:] + + return response + + + def sendMessage(self, destURL, message): + """ + Send a message. + + @param destURL: C{URL}. This should be a *physical* URL, not a logical one. + @param message: The message to send. + """ + if destURL.transport not in ("udp", None): + raise RuntimeError("only UDP currently supported") + if self.debug: + log.msg("Sending %r to %r" % (message.toString(), destURL)) + data = message.toString() + if isinstance(data, unicode): + data = data.encode("utf-8") + self.transport.write(data, (destURL.host, destURL.port or self.PORT)) + + + def handle_request(self, message, addr): + """ + Override to define behavior for requests received + + @type message: C{Message} + @type addr: C{tuple} + """ + raise NotImplementedError + + + def handle_response(self, message, addr): + """ + Override to define behavior for responses received. + + @type message: C{Message} + @type addr: C{tuple} + """ + raise NotImplementedError + + + +class IContact(Interface): + """ + A user of a registrar or proxy + """ + + + +class Registration: + def __init__(self, secondsToExpiry, contactURL): + self.secondsToExpiry = secondsToExpiry + self.contactURL = contactURL + + + +class IRegistry(Interface): + """ + Allows registration of logical->physical URL mapping. + """ + + def registerAddress(domainURL, logicalURL, physicalURL): + """ + Register the physical address of a logical URL. + + @return: Deferred of C{Registration} or failure with RegistrationError. + """ + + + def unregisterAddress(domainURL, logicalURL, physicalURL): + """ + Unregister the physical address of a logical URL. + + @return: Deferred of C{Registration} or failure with RegistrationError. + """ + + + def getRegistrationInfo(logicalURL): + """ + Get registration info for logical URL. + + @return: Deferred of C{Registration} object or failure of LookupError. + """ + + + +class ILocator(Interface): + """ + Allow looking up physical address for logical URL. + """ + + def getAddress(logicalURL): + """ + Return physical URL of server for logical URL of user. + + @param logicalURL: a logical C{URL}. + @return: Deferred which becomes URL or fails with LookupError. + """ + + + +class Proxy(Base): + """ + SIP proxy. + """ + + PORT = PORT + + locator = None # Object implementing ILocator + + def __init__(self, host=None, port=PORT): + """ + Create new instance. + + @param host: our hostname/IP as set in Via headers. + @param port: our port as set in Via headers. + """ + self.host = host or socket.getfqdn() + self.port = port + Base.__init__(self) + + + def getVia(self): + """ + Return value of Via header for this proxy. + """ + return Via(host=self.host, port=self.port) + + + def handle_request(self, message, addr): + # Send immediate 100/trying message before processing + #self.deliverResponse(self.responseFromRequest(100, message)) + f = getattr(self, "handle_%s_request" % message.method, None) + if f is None: + f = self.handle_request_default + try: + d = f(message, addr) + except SIPError as e: + self.deliverResponse(self.responseFromRequest(e.code, message)) + except: + log.err() + self.deliverResponse(self.responseFromRequest(500, message)) + else: + if d is not None: + d.addErrback(lambda e: + self.deliverResponse(self.responseFromRequest(e.code, message)) + ) + + + def handle_request_default(self, message, sourcePeer): + """ + Default request handler. + + Default behaviour for OPTIONS and unknown methods for proxies + is to forward message on to the client. + + Since at the moment we are stateless proxy, that's basically + everything. + """ + (srcHost, srcPort) = sourcePeer + def _mungContactHeader(uri, message): + message.headers['contact'][0] = uri.toString() + return self.sendMessage(uri, message) + + viaHeader = self.getVia() + if viaHeader.toString() in message.headers["via"]: + # Must be a loop, so drop message + log.msg("Dropping looped message.") + return + + message.headers["via"].insert(0, viaHeader.toString()) + name, uri, tags = parseAddress(message.headers["to"][0], clean=1) + + # This is broken and needs refactoring to use cred + d = self.locator.getAddress(uri) + d.addCallback(self.sendMessage, message) + d.addErrback(self._cantForwardRequest, message) + + + def _cantForwardRequest(self, error, message): + error.trap(LookupError) + del message.headers["via"][0] # This'll be us + self.deliverResponse(self.responseFromRequest(404, message)) + + + def deliverResponse(self, responseMessage): + """ + Deliver response. + + Destination is based on topmost Via header. + """ + destVia = parseViaHeader(responseMessage.headers["via"][0]) + # XXX we don't do multicast yet + host = destVia.received or destVia.host + port = destVia.rport or destVia.port or self.PORT + + destAddr = URL(host=host, port=port) + self.sendMessage(destAddr, responseMessage) + + + def responseFromRequest(self, code, request): + """ + Create a response to a request message. + """ + response = Response(code) + for name in ("via", "to", "from", "call-id", "cseq"): + response.headers[name] = request.headers.get(name, [])[:] + return response + + + def handle_response(self, message, addr): + """ + Default response handler. + """ + v = parseViaHeader(message.headers["via"][0]) + if (v.host, v.port) != (self.host, self.port): + # We got a message not intended for us? + # XXX note this check breaks if we have multiple external IPs + # yay for suck protocols + log.msg("Dropping incorrectly addressed message") + return + del message.headers["via"][0] + if not message.headers["via"]: + # This message is addressed to us + self.gotResponse(message, addr) + return + self.deliverResponse(message) + + + def gotResponse(self, message, addr): + """ + Called with responses that are addressed at this server. + """ + pass + + + +class IAuthorizer(Interface): + def getChallenge(peer): + """ + Generate a challenge the client may respond to. + + @type peer: C{tuple} + @param peer: The client's address + + @rtype: C{str} + @return: The challenge string + """ + + + def decode(response): + """ + Create a credentials object from the given response. + + @type response: C{str} + """ + + + +class RegisterProxy(Proxy): + """ + A proxy that allows registration for a specific domain. + + Unregistered users won't be handled. + """ + + portal = None + + registry = None # Should implement IRegistry + + authorizers = {} + + def __init__(self, *args, **kw): + Proxy.__init__(self, *args, **kw) + self.liveChallenges = {} + + + def handle_ACK_request(self, message, host_port): + # XXX + # ACKs are a client's way of indicating they got the last message + # Responding to them is not a good idea. + # However, we should keep track of terminal messages and re-transmit + # if no ACK is received. + (host, port) = host_port + pass + + + def handle_REGISTER_request(self, message, host_port): + """ + Handle a registration request. + + Currently registration is not proxied. + """ + (host, port) = host_port + if self.portal is None: + # There is no portal. Let anyone in. + self.register(message, host, port) + else: + # There is a portal. Check for credentials. + if "authorization" not in message.headers: + return self.unauthorized(message, host, port) + else: + return self.login(message, host, port) + + + def unauthorized(self, message, host, port): + m = self.responseFromRequest(401, message) + for (scheme, auth) in iteritems(self.authorizers): + chal = auth.getChallenge((host, port)) + if chal is None: + value = '%s realm="%s"' % (scheme.title(), self.host) + else: + value = '%s %s,realm="%s"' % (scheme.title(), chal, self.host) + m.headers.setdefault('www-authenticate', []).append(value) + self.deliverResponse(m) + + + def login(self, message, host, port): + parts = message.headers['authorization'][0].split(None, 1) + a = self.authorizers.get(parts[0].lower()) + if a: + try: + c = a.decode(parts[1]) + except SIPError: + raise + except: + log.err() + self.deliverResponse(self.responseFromRequest(500, message)) + else: + c.username += '@' + self.host + self.portal.login(c, None, IContact + ).addCallback(self._cbLogin, message, host, port + ).addErrback(self._ebLogin, message, host, port + ).addErrback(log.err + ) + else: + self.deliverResponse(self.responseFromRequest(501, message)) + + + def _cbLogin(self, i_a_l, message, host, port): + # It's stateless, matey. What a joke. + (i, a, l) = i_a_l + self.register(message, host, port) + + + def _ebLogin(self, failure, message, host, port): + failure.trap(cred.error.UnauthorizedLogin) + self.unauthorized(message, host, port) + + + def register(self, message, host, port): + """ + Allow all users to register + """ + name, toURL, params = parseAddress(message.headers["to"][0], clean=1) + contact = None + if "contact" in message.headers: + contact = message.headers["contact"][0] + + if message.headers.get("expires", [None])[0] == "0": + self.unregister(message, toURL, contact) + else: + # XXX Check expires on appropriate URL, and pass it to registry + # instead of having registry hardcode it. + if contact is not None: + name, contactURL, params = parseAddress(contact, host=host, port=port) + d = self.registry.registerAddress(message.uri, toURL, contactURL) + else: + d = self.registry.getRegistrationInfo(toURL) + d.addCallbacks(self._cbRegister, self._ebRegister, + callbackArgs=(message,), + errbackArgs=(message,) + ) + + + def _cbRegister(self, registration, message): + response = self.responseFromRequest(200, message) + if registration.contactURL != None: + response.addHeader("contact", registration.contactURL.toString()) + response.addHeader("expires", "%d" % registration.secondsToExpiry) + response.addHeader("content-length", "0") + self.deliverResponse(response) + + + def _ebRegister(self, error, message): + error.trap(RegistrationError, LookupError) + # XXX return error message, and alter tests to deal with + # this, currently tests assume no message sent on failure + + + def unregister(self, message, toURL, contact): + try: + expires = int(message.headers["expires"][0]) + except ValueError: + self.deliverResponse(self.responseFromRequest(400, message)) + else: + if expires == 0: + if contact == "*": + contactURL = "*" + else: + name, contactURL, params = parseAddress(contact) + d = self.registry.unregisterAddress(message.uri, toURL, contactURL) + d.addCallback(self._cbUnregister, message + ).addErrback(self._ebUnregister, message + ) + + + def _cbUnregister(self, registration, message): + msg = self.responseFromRequest(200, message) + msg.headers.setdefault('contact', []).append(registration.contactURL.toString()) + msg.addHeader("expires", "0") + self.deliverResponse(msg) + + + def _ebUnregister(self, registration, message): + pass + + + +@implementer(IRegistry, ILocator) +class InMemoryRegistry: + """ + A simplistic registry for a specific domain. + """ + def __init__(self, domain): + self.domain = domain # The domain we handle registration for + self.users = {} # Map username to (IDelayedCall for expiry, address URI) + + + def getAddress(self, userURI): + if userURI.host != self.domain: + return defer.fail(LookupError("unknown domain")) + if userURI.username in self.users: + dc, url = self.users[userURI.username] + return defer.succeed(url) + else: + return defer.fail(LookupError("no such user")) + + + def getRegistrationInfo(self, userURI): + if userURI.host != self.domain: + return defer.fail(LookupError("unknown domain")) + if userURI.username in self.users: + dc, url = self.users[userURI.username] + return defer.succeed(Registration(int(dc.getTime() - time.time()), url)) + else: + return defer.fail(LookupError("no such user")) + + + def _expireRegistration(self, username): + try: + dc, url = self.users[username] + except KeyError: + return defer.fail(LookupError("no such user")) + else: + dc.cancel() + del self.users[username] + return defer.succeed(Registration(0, url)) + + + def registerAddress(self, domainURL, logicalURL, physicalURL): + if domainURL.host != self.domain: + log.msg("Registration for domain we don't handle.") + return defer.fail(RegistrationError(404)) + if logicalURL.host != self.domain: + log.msg("Registration for domain we don't handle.") + return defer.fail(RegistrationError(404)) + if logicalURL.username in self.users: + dc, old = self.users[logicalURL.username] + dc.reset(3600) + else: + dc = reactor.callLater(3600, self._expireRegistration, logicalURL.username) + log.msg("Registered %s at %s" % (logicalURL.toString(), physicalURL.toString())) + self.users[logicalURL.username] = (dc, physicalURL) + return defer.succeed(Registration(int(dc.getTime() - time.time()), physicalURL)) + + + def unregisterAddress(self, domainURL, logicalURL, physicalURL): + return self._expireRegistration(logicalURL.username) diff --git a/contrib/python/Twisted/py2/twisted/protocols/socks.py b/contrib/python/Twisted/py2/twisted/protocols/socks.py new file mode 100644 index 0000000000..a52c09b669 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/socks.py @@ -0,0 +1,255 @@ +# -*- test-case-name: twisted.test.test_socks -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Implementation of the SOCKSv4 protocol. +""" + +# python imports +import struct +import string +import socket +import time + +# twisted imports +from twisted.internet import reactor, protocol, defer +from twisted.python import log + + +class SOCKSv4Outgoing(protocol.Protocol): + def __init__(self, socks): + self.socks=socks + + + def connectionMade(self): + peer = self.transport.getPeer() + self.socks.makeReply(90, 0, port=peer.port, ip=peer.host) + self.socks.otherConn=self + + + def connectionLost(self, reason): + self.socks.transport.loseConnection() + + + def dataReceived(self, data): + self.socks.write(data) + + + def write(self,data): + self.socks.log(self,data) + self.transport.write(data) + + + +class SOCKSv4Incoming(protocol.Protocol): + def __init__(self,socks): + self.socks=socks + self.socks.otherConn=self + + + def connectionLost(self, reason): + self.socks.transport.loseConnection() + + + def dataReceived(self,data): + self.socks.write(data) + + + def write(self, data): + self.socks.log(self,data) + self.transport.write(data) + + + +class SOCKSv4(protocol.Protocol): + """ + An implementation of the SOCKSv4 protocol. + + @type logging: L{str} or L{None} + @ivar logging: If not L{None}, the name of the logfile to which connection + information will be written. + + @type reactor: object providing L{twisted.internet.interfaces.IReactorTCP} + @ivar reactor: The reactor used to create connections. + + @type buf: L{str} + @ivar buf: Part of a SOCKSv4 connection request. + + @type otherConn: C{SOCKSv4Incoming}, C{SOCKSv4Outgoing} or L{None} + @ivar otherConn: Until the connection has been established, C{otherConn} is + L{None}. After that, it is the proxy-to-destination protocol instance + along which the client's connection is being forwarded. + """ + def __init__(self, logging=None, reactor=reactor): + self.logging = logging + self.reactor = reactor + + + def connectionMade(self): + self.buf = b"" + self.otherConn = None + + + def dataReceived(self, data): + """ + Called whenever data is received. + + @type data: L{bytes} + @param data: Part or all of a SOCKSv4 packet. + """ + if self.otherConn: + self.otherConn.write(data) + return + self.buf = self.buf + data + completeBuffer = self.buf + if b"\000" in self.buf[8:]: + head, self.buf = self.buf[:8], self.buf[8:] + version, code, port = struct.unpack("!BBH", head[:4]) + user, self.buf = self.buf.split(b"\000", 1) + if head[4:7] == b"\000\000\000" and head[7:8] != b"\000": + # An IP address of the form 0.0.0.X, where X is non-zero, + # signifies that this is a SOCKSv4a packet. + # If the complete packet hasn't been received, restore the + # buffer and wait for it. + if b"\000" not in self.buf: + self.buf = completeBuffer + return + server, self.buf = self.buf.split(b"\000", 1) + d = self.reactor.resolve(server) + d.addCallback(self._dataReceived2, user, + version, code, port) + d.addErrback(lambda result, self = self: self.makeReply(91)) + return + else: + server = socket.inet_ntoa(head[4:8]) + + self._dataReceived2(server, user, version, code, port) + + + def _dataReceived2(self, server, user, version, code, port): + """ + The second half of the SOCKS connection setup. For a SOCKSv4 packet this + is after the server address has been extracted from the header. For a + SOCKSv4a packet this is after the host name has been resolved. + + @type server: L{str} + @param server: The IP address of the destination, represented as a + dotted quad. + + @type user: L{str} + @param user: The username associated with the connection. + + @type version: L{int} + @param version: The SOCKS protocol version number. + + @type code: L{int} + @param code: The comand code. 1 means establish a TCP/IP stream + connection, and 2 means establish a TCP/IP port binding. + + @type port: L{int} + @param port: The port number associated with the connection. + """ + assert version == 4, "Bad version code: %s" % version + if not self.authorize(code, server, port, user): + self.makeReply(91) + return + if code == 1: # CONNECT + d = self.connectClass(server, port, SOCKSv4Outgoing, self) + d.addErrback(lambda result, self = self: self.makeReply(91)) + elif code == 2: # BIND + d = self.listenClass(0, SOCKSv4IncomingFactory, self, server) + d.addCallback(lambda x, + self = self: self.makeReply(90, 0, x[1], x[0])) + else: + raise RuntimeError("Bad Connect Code: %s" % (code,)) + assert self.buf == b"", "hmm, still stuff in buffer... %s" % repr( + self.buf) + + + def connectionLost(self, reason): + if self.otherConn: + self.otherConn.transport.loseConnection() + + + def authorize(self,code,server,port,user): + log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user)) + return 1 + + + def connectClass(self, host, port, klass, *args): + return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port) + + + def listenClass(self, port, klass, *args): + serv = reactor.listenTCP(port, klass(*args)) + return defer.succeed(serv.getHost()[1:]) + + + def makeReply(self,reply,version=0,port=0,ip="0.0.0.0"): + self.transport.write(struct.pack("!BBH",version,reply,port)+socket.inet_aton(ip)) + if reply!=90: self.transport.loseConnection() + + + def write(self,data): + self.log(self,data) + self.transport.write(data) + + + def log(self,proto,data): + if not self.logging: return + peer = self.transport.getPeer() + their_peer = self.otherConn.transport.getPeer() + f=open(self.logging,"a") + f.write("%s\t%s:%d %s %s:%d\n"%(time.ctime(), + peer.host,peer.port, + ((proto==self and '<') or '>'), + their_peer.host,their_peer.port)) + while data: + p,data=data[:16],data[16:] + f.write(string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ') + f.write((16-len(p))*3*' ') + for c in p: + if len(repr(c))>3: f.write('.') + else: f.write(c) + f.write('\n') + f.write('\n') + f.close() + + + +class SOCKSv4Factory(protocol.Factory): + """ + A factory for a SOCKSv4 proxy. + + Constructor accepts one argument, a log file name. + """ + def __init__(self, log): + self.logging = log + + + def buildProtocol(self, addr): + return SOCKSv4(self.logging, reactor) + + + +class SOCKSv4IncomingFactory(protocol.Factory): + """ + A utility class for building protocols for incoming connections. + """ + def __init__(self, socks, ip): + self.socks = socks + self.ip = ip + + + def buildProtocol(self, addr): + if addr[0] == self.ip: + self.ip = "" + self.socks.makeReply(90, 0) + return SOCKSv4Incoming(self.socks) + elif self.ip == "": + return None + else: + self.socks.makeReply(91, 0) + self.ip = "" + return None diff --git a/contrib/python/Twisted/py2/twisted/protocols/stateful.py b/contrib/python/Twisted/py2/twisted/protocols/stateful.py new file mode 100644 index 0000000000..cd2b7cfe70 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/stateful.py @@ -0,0 +1,49 @@ +# -*- test-case-name: twisted.test.test_stateful -*- + +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + + +from twisted.internet import protocol + +from io import BytesIO + +class StatefulProtocol(protocol.Protocol): + """A Protocol that stores state for you. + + state is a pair (function, num_bytes). When num_bytes bytes of data arrives + from the network, function is called. It is expected to return the next + state or None to keep same state. Initial state is returned by + getInitialState (override it). + """ + _sful_data = None, None, 0 + + def makeConnection(self, transport): + protocol.Protocol.makeConnection(self, transport) + self._sful_data = self.getInitialState(), BytesIO(), 0 + + def getInitialState(self): + raise NotImplementedError + + def dataReceived(self, data): + state, buffer, offset = self._sful_data + buffer.seek(0, 2) + buffer.write(data) + blen = buffer.tell() # how many bytes total is in the buffer + buffer.seek(offset) + while blen - offset >= state[1]: + d = buffer.read(state[1]) + offset += state[1] + next = state[0](d) + if self.transport.disconnecting: # XXX: argh stupid hack borrowed right from LineReceiver + return # dataReceived won't be called again, so who cares about consistent state + if next: + state = next + if offset != 0: + b = buffer.read() + buffer.seek(0) + buffer.truncate() + buffer.write(b) + offset = 0 + self._sful_data = state, buffer, offset + diff --git a/contrib/python/Twisted/py2/twisted/protocols/tls.py b/contrib/python/Twisted/py2/twisted/protocols/tls.py new file mode 100644 index 0000000000..52cd498aa9 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/tls.py @@ -0,0 +1,830 @@ +# -*- test-case-name: twisted.protocols.test.test_tls,twisted.internet.test.test_tls,twisted.test.test_sslverify -*- +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +""" +Implementation of a TLS transport (L{ISSLTransport}) as an +L{IProtocol<twisted.internet.interfaces.IProtocol>} layered on top of any +L{ITransport<twisted.internet.interfaces.ITransport>} implementation, based on +U{OpenSSL<http://www.openssl.org>}'s memory BIO features. + +L{TLSMemoryBIOFactory} is a L{WrappingFactory} which wraps protocols created by +the factory it wraps with L{TLSMemoryBIOProtocol}. L{TLSMemoryBIOProtocol} +intercedes between the underlying transport and the wrapped protocol to +implement SSL and TLS. Typical usage of this module looks like this:: + + from twisted.protocols.tls import TLSMemoryBIOFactory + from twisted.internet.protocol import ServerFactory + from twisted.internet.ssl import PrivateCertificate + from twisted.internet import reactor + + from someapplication import ApplicationProtocol + + serverFactory = ServerFactory() + serverFactory.protocol = ApplicationProtocol + certificate = PrivateCertificate.loadPEM(certPEMData) + contextFactory = certificate.options() + tlsFactory = TLSMemoryBIOFactory(contextFactory, False, serverFactory) + reactor.listenTCP(12345, tlsFactory) + reactor.run() + +This API offers somewhat more flexibility than +L{twisted.internet.interfaces.IReactorSSL}; for example, a +L{TLSMemoryBIOProtocol} instance can use another instance of +L{TLSMemoryBIOProtocol} as its transport, yielding TLS over TLS - useful to +implement onion routing. It can also be used to run TLS over unusual +transports, such as UNIX sockets and stdio. +""" + +from __future__ import division, absolute_import + +from OpenSSL.SSL import Error, ZeroReturnError, WantReadError +from OpenSSL.SSL import TLSv1_METHOD, Context, Connection + +try: + Connection(Context(TLSv1_METHOD), None) +except TypeError as e: + if str(e) != "argument must be an int, or have a fileno() method.": + raise + raise ImportError("twisted.protocols.tls requires pyOpenSSL 0.10 or newer.") + +from zope.interface import implementer, providedBy, directlyProvides + +from twisted.python.compat import unicode +from twisted.python.failure import Failure +from twisted.internet.interfaces import ( + ISystemHandle, INegotiated, IPushProducer, ILoggingContext, + IOpenSSLServerConnectionCreator, IOpenSSLClientConnectionCreator, + IProtocolNegotiationFactory, IHandshakeListener +) +from twisted.internet.main import CONNECTION_LOST +from twisted.internet._producer_helpers import _PullToPush +from twisted.internet.protocol import Protocol +from twisted.internet._sslverify import _setAcceptableProtocols +from twisted.protocols.policies import ProtocolWrapper, WrappingFactory + + +@implementer(IPushProducer) +class _ProducerMembrane(object): + """ + Stand-in for producer registered with a L{TLSMemoryBIOProtocol} transport. + + Ensures that producer pause/resume events from the undelying transport are + coordinated with pause/resume events from the TLS layer. + + @ivar _producer: The application-layer producer. + """ + + _producerPaused = False + + def __init__(self, producer): + self._producer = producer + + + def pauseProducing(self): + """ + C{pauseProducing} the underlying producer, if it's not paused. + """ + if self._producerPaused: + return + self._producerPaused = True + self._producer.pauseProducing() + + + def resumeProducing(self): + """ + C{resumeProducing} the underlying producer, if it's paused. + """ + if not self._producerPaused: + return + self._producerPaused = False + self._producer.resumeProducing() + + + def stopProducing(self): + """ + C{stopProducing} the underlying producer. + + There is only a single source for this event, so it's simply passed + on. + """ + self._producer.stopProducing() + + +@implementer(ISystemHandle, INegotiated) +class TLSMemoryBIOProtocol(ProtocolWrapper): + """ + L{TLSMemoryBIOProtocol} is a protocol wrapper which uses OpenSSL via a + memory BIO to encrypt bytes written to it before sending them on to the + underlying transport and decrypts bytes received from the underlying + transport before delivering them to the wrapped protocol. + + In addition to producer events from the underlying transport, the need to + wait for reads before a write can proceed means the L{TLSMemoryBIOProtocol} + may also want to pause a producer. Pause/resume events are therefore + merged using the L{_ProducerMembrane} wrapper. Non-streaming (pull) + producers are supported by wrapping them with L{_PullToPush}. + + @ivar _tlsConnection: The L{OpenSSL.SSL.Connection} instance which is + encrypted and decrypting this connection. + + @ivar _lostTLSConnection: A flag indicating whether connection loss has + already been dealt with (C{True}) or not (C{False}). TLS disconnection + is distinct from the underlying connection being lost. + + @ivar _appSendBuffer: application-level (cleartext) data that is waiting to + be transferred to the TLS buffer, but can't be because the TLS + connection is handshaking. + @type _appSendBuffer: L{list} of L{bytes} + + @ivar _connectWrapped: A flag indicating whether or not to call + C{makeConnection} on the wrapped protocol. This is for the reactor's + L{twisted.internet.interfaces.ITLSTransport.startTLS} implementation, + since it has a protocol which it has already called C{makeConnection} + on, and which has no interest in a new transport. See #3821. + + @ivar _handshakeDone: A flag indicating whether or not the handshake is + known to have completed successfully (C{True}) or not (C{False}). This + is used to control error reporting behavior. If the handshake has not + completed, the underlying L{OpenSSL.SSL.Error} will be passed to the + application's C{connectionLost} method. If it has completed, any + unexpected L{OpenSSL.SSL.Error} will be turned into a + L{ConnectionLost}. This is weird; however, it is simply an attempt at + a faithful re-implementation of the behavior provided by + L{twisted.internet.ssl}. + + @ivar _reason: If an unexpected L{OpenSSL.SSL.Error} occurs which causes + the connection to be lost, it is saved here. If appropriate, this may + be used as the reason passed to the application protocol's + C{connectionLost} method. + + @ivar _producer: The current producer registered via C{registerProducer}, + or L{None} if no producer has been registered or a previous one was + unregistered. + + @ivar _aborted: C{abortConnection} has been called. No further data will + be received to the wrapped protocol's C{dataReceived}. + @type _aborted: L{bool} + """ + + _reason = None + _handshakeDone = False + _lostTLSConnection = False + _producer = None + _aborted = False + + def __init__(self, factory, wrappedProtocol, _connectWrapped=True): + ProtocolWrapper.__init__(self, factory, wrappedProtocol) + self._connectWrapped = _connectWrapped + + + def getHandle(self): + """ + Return the L{OpenSSL.SSL.Connection} object being used to encrypt and + decrypt this connection. + + This is done for the benefit of L{twisted.internet.ssl.Certificate}'s + C{peerFromTransport} and C{hostFromTransport} methods only. A + different system handle may be returned by future versions of this + method. + """ + return self._tlsConnection + + + def makeConnection(self, transport): + """ + Connect this wrapper to the given transport and initialize the + necessary L{OpenSSL.SSL.Connection} with a memory BIO. + """ + self._tlsConnection = self.factory._createConnection(self) + self._appSendBuffer = [] + + # Add interfaces provided by the transport we are wrapping: + for interface in providedBy(transport): + directlyProvides(self, interface) + + # Intentionally skip ProtocolWrapper.makeConnection - it might call + # wrappedProtocol.makeConnection, which we want to make conditional. + Protocol.makeConnection(self, transport) + self.factory.registerProtocol(self) + if self._connectWrapped: + # Now that the TLS layer is initialized, notify the application of + # the connection. + ProtocolWrapper.makeConnection(self, transport) + + # Now that we ourselves have a transport (initialized by the + # ProtocolWrapper.makeConnection call above), kick off the TLS + # handshake. + self._checkHandshakeStatus() + + + def _checkHandshakeStatus(self): + """ + Ask OpenSSL to proceed with a handshake in progress. + + Initially, this just sends the ClientHello; after some bytes have been + stuffed in to the C{Connection} object by C{dataReceived}, it will then + respond to any C{Certificate} or C{KeyExchange} messages. + """ + # The connection might already be aborted (eg. by a callback during + # connection setup), so don't even bother trying to handshake in that + # case. + if self._aborted: + return + try: + self._tlsConnection.do_handshake() + except WantReadError: + self._flushSendBIO() + except Error: + self._tlsShutdownFinished(Failure()) + else: + self._handshakeDone = True + if IHandshakeListener.providedBy(self.wrappedProtocol): + self.wrappedProtocol.handshakeCompleted() + + + def _flushSendBIO(self): + """ + Read any bytes out of the send BIO and write them to the underlying + transport. + """ + try: + bytes = self._tlsConnection.bio_read(2 ** 15) + except WantReadError: + # There may be nothing in the send BIO right now. + pass + else: + self.transport.write(bytes) + + + def _flushReceiveBIO(self): + """ + Try to receive any application-level bytes which are now available + because of a previous write into the receive BIO. This will take + care of delivering any application-level bytes which are received to + the protocol, as well as handling of the various exceptions which + can come from trying to get such bytes. + """ + # Keep trying this until an error indicates we should stop or we + # close the connection. Looping is necessary to make sure we + # process all of the data which was put into the receive BIO, as + # there is no guarantee that a single recv call will do it all. + while not self._lostTLSConnection: + try: + bytes = self._tlsConnection.recv(2 ** 15) + except WantReadError: + # The newly received bytes might not have been enough to produce + # any application data. + break + except ZeroReturnError: + # TLS has shut down and no more TLS data will be received over + # this connection. + self._shutdownTLS() + # Passing in None means the user protocol's connnectionLost + # will get called with reason from underlying transport: + self._tlsShutdownFinished(None) + except Error: + # Something went pretty wrong. For example, this might be a + # handshake failure during renegotiation (because there were no + # shared ciphers, because a certificate failed to verify, etc). + # TLS can no longer proceed. + failure = Failure() + self._tlsShutdownFinished(failure) + else: + if not self._aborted: + ProtocolWrapper.dataReceived(self, bytes) + + # The received bytes might have generated a response which needs to be + # sent now. For example, the handshake involves several round-trip + # exchanges without ever producing application-bytes. + self._flushSendBIO() + + + def dataReceived(self, bytes): + """ + Deliver any received bytes to the receive BIO and then read and deliver + to the application any application-level data which becomes available + as a result of this. + """ + # Let OpenSSL know some bytes were just received. + self._tlsConnection.bio_write(bytes) + + # If we are still waiting for the handshake to complete, try to + # complete the handshake with the bytes we just received. + if not self._handshakeDone: + self._checkHandshakeStatus() + + # If the handshake still isn't finished, then we've nothing left to + # do. + if not self._handshakeDone: + return + + # If we've any pending writes, this read may have un-blocked them, so + # attempt to unbuffer them into the OpenSSL layer. + if self._appSendBuffer: + self._unbufferPendingWrites() + + # Since the handshake is complete, the wire-level bytes we just + # processed might turn into some application-level bytes; try to pull + # those out. + self._flushReceiveBIO() + + + def _shutdownTLS(self): + """ + Initiate, or reply to, the shutdown handshake of the TLS layer. + """ + try: + shutdownSuccess = self._tlsConnection.shutdown() + except Error: + # Mid-handshake, a call to shutdown() can result in a + # WantWantReadError, or rather an SSL_ERR_WANT_READ; but pyOpenSSL + # doesn't allow us to get at the error. See: + # https://github.com/pyca/pyopenssl/issues/91 + shutdownSuccess = False + self._flushSendBIO() + if shutdownSuccess: + # Both sides have shutdown, so we can start closing lower-level + # transport. This will also happen if we haven't started + # negotiation at all yet, in which case shutdown succeeds + # immediately. + self.transport.loseConnection() + + + def _tlsShutdownFinished(self, reason): + """ + Called when TLS connection has gone away; tell underlying transport to + disconnect. + + @param reason: a L{Failure} whose value is an L{Exception} if we want to + report that failure through to the wrapped protocol's + C{connectionLost}, or L{None} if the C{reason} that + C{connectionLost} should receive should be coming from the + underlying transport. + @type reason: L{Failure} or L{None} + """ + if reason is not None: + # Squash an EOF in violation of the TLS protocol into + # ConnectionLost, so that applications which might run over + # multiple protocols can recognize its type. + if tuple(reason.value.args[:2]) == (-1, 'Unexpected EOF'): + reason = Failure(CONNECTION_LOST) + if self._reason is None: + self._reason = reason + self._lostTLSConnection = True + # We may need to send a TLS alert regarding the nature of the shutdown + # here (for example, why a handshake failed), so always flush our send + # buffer before telling our lower-level transport to go away. + self._flushSendBIO() + # Using loseConnection causes the application protocol's + # connectionLost method to be invoked non-reentrantly, which is always + # a nice feature. However, for error cases (reason != None) we might + # want to use abortConnection when it becomes available. The + # loseConnection call is basically tested by test_handshakeFailure. + # At least one side will need to do it or the test never finishes. + self.transport.loseConnection() + + + def connectionLost(self, reason): + """ + Handle the possible repetition of calls to this method (due to either + the underlying transport going away or due to an error at the TLS + layer) and make sure the base implementation only gets invoked once. + """ + if not self._lostTLSConnection: + # Tell the TLS connection that it's not going to get any more data + # and give it a chance to finish reading. + self._tlsConnection.bio_shutdown() + self._flushReceiveBIO() + self._lostTLSConnection = True + reason = self._reason or reason + self._reason = None + self.connected = False + ProtocolWrapper.connectionLost(self, reason) + + # Breaking reference cycle between self._tlsConnection and self. + self._tlsConnection = None + + + def loseConnection(self): + """ + Send a TLS close alert and close the underlying connection. + """ + if self.disconnecting or not self.connected: + return + # If connection setup has not finished, OpenSSL 1.0.2f+ will not shut + # down the connection until we write some data to the connection which + # allows the handshake to complete. However, since no data should be + # written after loseConnection, this means we'll be stuck forever + # waiting for shutdown to complete. Instead, we simply abort the + # connection without trying to shut down cleanly: + if not self._handshakeDone and not self._appSendBuffer: + self.abortConnection() + self.disconnecting = True + if not self._appSendBuffer and self._producer is None: + self._shutdownTLS() + + + def abortConnection(self): + """ + Tear down TLS state so that if the connection is aborted mid-handshake + we don't deliver any further data from the application. + """ + self._aborted = True + self.disconnecting = True + self._shutdownTLS() + self.transport.abortConnection() + + + def failVerification(self, reason): + """ + Abort the connection during connection setup, giving a reason that + certificate verification failed. + + @param reason: The reason that the verification failed; reported to the + application protocol's C{connectionLost} method. + @type reason: L{Failure} + """ + self._reason = reason + self.abortConnection() + + + def write(self, bytes): + """ + Process the given application bytes and send any resulting TLS traffic + which arrives in the send BIO. + + If C{loseConnection} was called, subsequent calls to C{write} will + drop the bytes on the floor. + """ + if isinstance(bytes, unicode): + raise TypeError("Must write bytes to a TLS transport, not unicode.") + # Writes after loseConnection are not supported, unless a producer has + # been registered, in which case writes can happen until the producer + # is unregistered: + if self.disconnecting and self._producer is None: + return + self._write(bytes) + + + def _bufferedWrite(self, octets): + """ + Put the given octets into L{TLSMemoryBIOProtocol._appSendBuffer}, and + tell any listening producer that it should pause because we are now + buffering. + """ + self._appSendBuffer.append(octets) + if self._producer is not None: + self._producer.pauseProducing() + + + def _unbufferPendingWrites(self): + """ + Un-buffer all waiting writes in L{TLSMemoryBIOProtocol._appSendBuffer}. + """ + pendingWrites, self._appSendBuffer = self._appSendBuffer, [] + for eachWrite in pendingWrites: + self._write(eachWrite) + + if self._appSendBuffer: + # If OpenSSL ran out of buffer space in the Connection on our way + # through the loop earlier and re-buffered any of our outgoing + # writes, then we're done; don't consider any future work. + return + + if self._producer is not None: + # If we have a registered producer, let it know that we have some + # more buffer space. + self._producer.resumeProducing() + return + + if self.disconnecting: + # Finally, if we have no further buffered data, no producer wants + # to send us more data in the future, and the application told us + # to end the stream, initiate a TLS shutdown. + self._shutdownTLS() + + + def _write(self, bytes): + """ + Process the given application bytes and send any resulting TLS traffic + which arrives in the send BIO. + + This may be called by C{dataReceived} with bytes that were buffered + before C{loseConnection} was called, which is why this function + doesn't check for disconnection but accepts the bytes regardless. + """ + if self._lostTLSConnection: + return + + # A TLS payload is 16kB max + bufferSize = 2 ** 14 + + # How far into the input we've gotten so far + alreadySent = 0 + + while alreadySent < len(bytes): + toSend = bytes[alreadySent:alreadySent + bufferSize] + try: + sent = self._tlsConnection.send(toSend) + except WantReadError: + self._bufferedWrite(bytes[alreadySent:]) + break + except Error: + # Pretend TLS connection disconnected, which will trigger + # disconnect of underlying transport. The error will be passed + # to the application protocol's connectionLost method. The + # other SSL implementation doesn't, but losing helpful + # debugging information is a bad idea. + self._tlsShutdownFinished(Failure()) + break + else: + # We've successfully handed off the bytes to the OpenSSL + # Connection object. + alreadySent += sent + # See if OpenSSL wants to hand any bytes off to the underlying + # transport as a result. + self._flushSendBIO() + + + def writeSequence(self, iovec): + """ + Write a sequence of application bytes by joining them into one string + and passing them to L{write}. + """ + self.write(b"".join(iovec)) + + + def getPeerCertificate(self): + return self._tlsConnection.get_peer_certificate() + + + @property + def negotiatedProtocol(self): + """ + @see: L{INegotiated.negotiatedProtocol} + """ + protocolName = None + + try: + # If ALPN is not implemented that's ok, NPN might be. + protocolName = self._tlsConnection.get_alpn_proto_negotiated() + except (NotImplementedError, AttributeError): + pass + + if protocolName not in (b'', None): + # A protocol was selected using ALPN. + return protocolName + + try: + protocolName = self._tlsConnection.get_next_proto_negotiated() + except (NotImplementedError, AttributeError): + pass + + if protocolName != b'': + return protocolName + + return None + + + def registerProducer(self, producer, streaming): + # If we've already disconnected, nothing to do here: + if self._lostTLSConnection: + producer.stopProducing() + return + + # If we received a non-streaming producer, wrap it so it becomes a + # streaming producer: + if not streaming: + producer = streamingProducer = _PullToPush(producer, self) + producer = _ProducerMembrane(producer) + # This will raise an exception if a producer is already registered: + self.transport.registerProducer(producer, True) + self._producer = producer + # If we received a non-streaming producer, we need to start the + # streaming wrapper: + if not streaming: + streamingProducer.startStreaming() + + + def unregisterProducer(self): + # If we have no producer, we don't need to do anything here. + if self._producer is None: + return + + # If we received a non-streaming producer, we need to stop the + # streaming wrapper: + if isinstance(self._producer._producer, _PullToPush): + self._producer._producer.stopStreaming() + self._producer = None + self._producerPaused = False + self.transport.unregisterProducer() + if self.disconnecting and not self._appSendBuffer: + self._shutdownTLS() + + + +@implementer(IOpenSSLClientConnectionCreator, IOpenSSLServerConnectionCreator) +class _ContextFactoryToConnectionFactory(object): + """ + Adapter wrapping a L{twisted.internet.interfaces.IOpenSSLContextFactory} + into a L{IOpenSSLClientConnectionCreator} or + L{IOpenSSLServerConnectionCreator}. + + See U{https://twistedmatrix.com/trac/ticket/7215} for work that should make + this unnecessary. + """ + + def __init__(self, oldStyleContextFactory): + """ + Construct a L{_ContextFactoryToConnectionFactory} with a + L{twisted.internet.interfaces.IOpenSSLContextFactory}. + + Immediately call C{getContext} on C{oldStyleContextFactory} in order to + force advance parameter checking, since old-style context factories + don't actually check that their arguments to L{OpenSSL} are correct. + + @param oldStyleContextFactory: A factory that can produce contexts. + @type oldStyleContextFactory: + L{twisted.internet.interfaces.IOpenSSLContextFactory} + """ + oldStyleContextFactory.getContext() + self._oldStyleContextFactory = oldStyleContextFactory + + + def _connectionForTLS(self, protocol): + """ + Create an L{OpenSSL.SSL.Connection} object. + + @param protocol: The protocol initiating a TLS connection. + @type protocol: L{TLSMemoryBIOProtocol} + + @return: a connection + @rtype: L{OpenSSL.SSL.Connection} + """ + context = self._oldStyleContextFactory.getContext() + return Connection(context, None) + + + def serverConnectionForTLS(self, protocol): + """ + Construct an OpenSSL server connection from the wrapped old-style + context factory. + + @note: Since old-style context factories don't distinguish between + clients and servers, this is exactly the same as + L{_ContextFactoryToConnectionFactory.clientConnectionForTLS}. + + @param protocol: The protocol initiating a TLS connection. + @type protocol: L{TLSMemoryBIOProtocol} + + @return: a connection + @rtype: L{OpenSSL.SSL.Connection} + """ + return self._connectionForTLS(protocol) + + + def clientConnectionForTLS(self, protocol): + """ + Construct an OpenSSL server connection from the wrapped old-style + context factory. + + @note: Since old-style context factories don't distinguish between + clients and servers, this is exactly the same as + L{_ContextFactoryToConnectionFactory.serverConnectionForTLS}. + + @param protocol: The protocol initiating a TLS connection. + @type protocol: L{TLSMemoryBIOProtocol} + + @return: a connection + @rtype: L{OpenSSL.SSL.Connection} + """ + return self._connectionForTLS(protocol) + + + +class TLSMemoryBIOFactory(WrappingFactory): + """ + L{TLSMemoryBIOFactory} adds TLS to connections. + + @ivar _creatorInterface: the interface which L{_connectionCreator} is + expected to implement. + @type _creatorInterface: L{zope.interface.interfaces.IInterface} + + @ivar _connectionCreator: a callable which creates an OpenSSL Connection + object. + @type _connectionCreator: 1-argument callable taking + L{TLSMemoryBIOProtocol} and returning L{OpenSSL.SSL.Connection}. + """ + protocol = TLSMemoryBIOProtocol + + noisy = False # disable unnecessary logging. + + def __init__(self, contextFactory, isClient, wrappedFactory): + """ + Create a L{TLSMemoryBIOFactory}. + + @param contextFactory: Configuration parameters used to create an + OpenSSL connection. In order of preference, what you should pass + here should be: + + 1. L{twisted.internet.ssl.CertificateOptions} (if you're + writing a server) or the result of + L{twisted.internet.ssl.optionsForClientTLS} (if you're + writing a client). If you want security you should really + use one of these. + + 2. If you really want to implement something yourself, supply a + provider of L{IOpenSSLClientConnectionCreator} or + L{IOpenSSLServerConnectionCreator}. + + 3. If you really have to, supply a + L{twisted.internet.ssl.ContextFactory}. This will likely be + deprecated at some point so please upgrade to the new + interfaces. + + @type contextFactory: L{IOpenSSLClientConnectionCreator} or + L{IOpenSSLServerConnectionCreator}, or, for compatibility with + older code, anything implementing + L{twisted.internet.interfaces.IOpenSSLContextFactory}. See + U{https://twistedmatrix.com/trac/ticket/7215} for information on + the upcoming deprecation of passing a + L{twisted.internet.ssl.ContextFactory} here. + + @param isClient: Is this a factory for TLS client connections; in other + words, those that will send a C{ClientHello} greeting? L{True} if + so, L{False} otherwise. This flag determines what interface is + expected of C{contextFactory}. If L{True}, C{contextFactory} + should provide L{IOpenSSLClientConnectionCreator}; otherwise it + should provide L{IOpenSSLServerConnectionCreator}. + @type isClient: L{bool} + + @param wrappedFactory: A factory which will create the + application-level protocol. + @type wrappedFactory: L{twisted.internet.interfaces.IProtocolFactory} + """ + WrappingFactory.__init__(self, wrappedFactory) + if isClient: + creatorInterface = IOpenSSLClientConnectionCreator + else: + creatorInterface = IOpenSSLServerConnectionCreator + self._creatorInterface = creatorInterface + if not creatorInterface.providedBy(contextFactory): + contextFactory = _ContextFactoryToConnectionFactory(contextFactory) + self._connectionCreator = contextFactory + + + def logPrefix(self): + """ + Annotate the wrapped factory's log prefix with some text indicating TLS + is in use. + + @rtype: C{str} + """ + if ILoggingContext.providedBy(self.wrappedFactory): + logPrefix = self.wrappedFactory.logPrefix() + else: + logPrefix = self.wrappedFactory.__class__.__name__ + return "%s (TLS)" % (logPrefix,) + + + def _applyProtocolNegotiation(self, connection): + """ + Applies ALPN/NPN protocol neogitation to the connection, if the factory + supports it. + + @param connection: The OpenSSL connection object to have ALPN/NPN added + to it. + @type connection: L{OpenSSL.SSL.Connection} + + @return: Nothing + @rtype: L{None} + """ + if IProtocolNegotiationFactory.providedBy(self.wrappedFactory): + protocols = self.wrappedFactory.acceptableProtocols() + context = connection.get_context() + _setAcceptableProtocols(context, protocols) + + return + + + def _createConnection(self, tlsProtocol): + """ + Create an OpenSSL connection and set it up good. + + @param tlsProtocol: The protocol which is establishing the connection. + @type tlsProtocol: L{TLSMemoryBIOProtocol} + + @return: an OpenSSL connection object for C{tlsProtocol} to use + @rtype: L{OpenSSL.SSL.Connection} + """ + connectionCreator = self._connectionCreator + if self._creatorInterface is IOpenSSLClientConnectionCreator: + connection = connectionCreator.clientConnectionForTLS(tlsProtocol) + self._applyProtocolNegotiation(connection) + connection.set_connect_state() + else: + connection = connectionCreator.serverConnectionForTLS(tlsProtocol) + self._applyProtocolNegotiation(connection) + connection.set_accept_state() + return connection diff --git a/contrib/python/Twisted/py2/twisted/protocols/wire.py b/contrib/python/Twisted/py2/twisted/protocols/wire.py new file mode 100644 index 0000000000..0e647b3235 --- /dev/null +++ b/contrib/python/Twisted/py2/twisted/protocols/wire.py @@ -0,0 +1,124 @@ +# Copyright (c) Twisted Matrix Laboratories. +# See LICENSE for details. + +"""Implement standard (and unused) TCP protocols. + +These protocols are either provided by inetd, or are not provided at all. +""" + +from __future__ import absolute_import, division + +import time +import struct + +from zope.interface import implementer + +from twisted.internet import protocol, interfaces + + + +class Echo(protocol.Protocol): + """ + As soon as any data is received, write it back (RFC 862). + """ + + def dataReceived(self, data): + self.transport.write(data) + + + +class Discard(protocol.Protocol): + """ + Discard any received data (RFC 863). + """ + + def dataReceived(self, data): + # I'm ignoring you, nyah-nyah + pass + + + +@implementer(interfaces.IProducer) +class Chargen(protocol.Protocol): + """ + Generate repeating noise (RFC 864). + """ + noise = b'@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~ !"#$%&?' + + def connectionMade(self): + self.transport.registerProducer(self, 0) + + + def resumeProducing(self): + self.transport.write(self.noise) + + + def pauseProducing(self): + pass + + + def stopProducing(self): + pass + + + +class QOTD(protocol.Protocol): + """ + Return a quote of the day (RFC 865). + """ + + def connectionMade(self): + self.transport.write(self.getQuote()) + self.transport.loseConnection() + + + def getQuote(self): + """ + Return a quote. May be overrriden in subclasses. + """ + return b"An apple a day keeps the doctor away.\r\n" + + + +class Who(protocol.Protocol): + """ + Return list of active users (RFC 866) + """ + + def connectionMade(self): + self.transport.write(self.getUsers()) + self.transport.loseConnection() + + + def getUsers(self): + """ + Return active users. Override in subclasses. + """ + return b"root\r\n" + + + +class Daytime(protocol.Protocol): + """ + Send back the daytime in ASCII form (RFC 867). + """ + + def connectionMade(self): + self.transport.write(time.asctime(time.gmtime(time.time())) + b'\r\n') + self.transport.loseConnection() + + + +class Time(protocol.Protocol): + """ + Send back the time in machine readable form (RFC 868). + """ + + def connectionMade(self): + # is this correct only for 32-bit machines? + result = struct.pack("!i", int(time.time())) + self.transport.write(result) + self.transport.loseConnection() + + +__all__ = ["Echo", "Discard", "Chargen", "QOTD", "Who", "Daytime", "Time"] |