aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/Twisted/py3/twisted/trial/_dist/stream.py
blob: a53fd4ab214b944a7424a0a372744f4d4a4ccc14 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Split byte strings into chunks, send them as streams over AMP, and buffer them on the receiving side.
"""

from itertools import count
from typing import Dict, Iterator, List, TypeVar

from attrs import Factory, define

from twisted.protocols.amp import AMP, Command, Integer, String as Bytes

# Module-level generic type variable.
# NOTE(review): not referenced anywhere else in this module; possibly a
# leftover -- confirm no other module imports it before removing.
T = TypeVar("T")


class StreamOpen(Command):
    """
    Ask the peer to allocate a new, empty stream.

    The response carries the integer identifier of the freshly opened stream
    under the C{streamId} key; later L{StreamWrite} commands refer to it.
    """

    response = [
        (b"streamId", Integer()),
    ]


class StreamWrite(Command):
    """
    Deliver one chunk of bytes to a stream previously opened with
    L{StreamOpen}, identified by C{streamId}.
    """

    arguments = [(b"streamId", Integer()), (b"data", Bytes())]


@define
class StreamReceiver:
    """
    Buffering de-multiplexing byte stream receiver.

    Streams are identified by integers handed out by L{open}.  Chunks written
    to a stream with L{write} are buffered in arrival order until L{finish}
    removes the stream and returns everything that was buffered for it.
    """

    # Source of stream identifiers.  Built per instance via Factory: a plain
    # ``count()`` default would be evaluated once at class creation and the
    # same iterator would then be shared by every receiver.
    _counter: Iterator[int] = Factory(count)
    # Buffered chunks for each currently open stream, keyed by stream id.
    _streams: Dict[int, List[bytes]] = Factory(dict)

    def open(self) -> int:
        """
        Open a new, empty stream.

        @return: The new stream's identifier, unique within this receiver.
        """
        newId = next(self._counter)
        self._streams[newId] = []
        return newId

    def write(self, streamId: int, chunk: bytes) -> None:
        """
        Append a chunk of data to an open stream.

        @param streamId: An identifier previously returned by L{open}.

        @param chunk: The bytes to buffer.

        @raise KeyError: If there is no such open stream (including one that
            has already been finished).
        """
        self._streams[streamId].append(chunk)

    def finish(self, streamId: int) -> List[bytes]:
        """
        Close an open stream so it may receive no further data, and return all
        of its buffered contents.

        @param streamId: An identifier previously returned by L{open}.

        @return: The chunks written to the stream, in the order received.

        @raise KeyError: If there is no such open stream.
        """
        return self._streams.pop(streamId)


def chunk(data: bytes, chunkSize: int) -> Iterator[bytes]:
    """
    Break a byte string into pieces of no more than ``chunkSize`` length.

    @param data: The byte string.  An empty string yields no pieces.

    @param chunkSize: The maximum length of the resulting pieces.  All pieces
        except possibly the last will be this length.  Must be at least 1.

    @return: The pieces.

    @raise ValueError: If ``chunkSize`` is less than 1 (raised when iteration
        starts, since this is a generator).
    """
    # A non-positive size would never advance through ``data`` and so would
    # produce pieces forever; reject it up front.
    if chunkSize < 1:
        raise ValueError(f"chunkSize must be at least 1, not {chunkSize!r}")
    for pos in range(0, len(data), chunkSize):
        yield data[pos : pos + chunkSize]


async def stream(amp: AMP, chunks: Iterator[bytes]) -> int:
    """
    Send each chunk, in order, as one stream over the given AMP connection.

    A stream is first opened with L{StreamOpen}; every chunk is then delivered
    with its own L{StreamWrite}, each awaited before the next is sent.

    @param amp: The connection to send over.

    @param chunks: The byte chunks making up the stream's content.

    @return: The identifier of the stream over which the chunks were sent.
    """
    opened = await amp.callRemote(StreamOpen)
    streamId = opened["streamId"]
    # Narrows the type of the response value for static checkers.
    assert isinstance(streamId, int)

    for piece in chunks:
        await amp.callRemote(StreamWrite, streamId=streamId, data=piece)

    return streamId