aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/clickhouse-connect/clickhouse_connect/driver/buffer.py
blob: b50b9bb678226947a5dbc57b648bb7e99858c2a1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import sys
import array
from typing import Any, Iterable

from clickhouse_connect.driver.exceptions import StreamCompleteException
from clickhouse_connect.driver.types import ByteSource

# True on big-endian hosts.  Numeric column data is presumably little-endian
# on the wire (as read_uint64 assumes), so arrays loaded via array.frombytes
# must be byte-swapped when this is set.
must_swap = 'big' == sys.byteorder


class ResponseBuffer(ByteSource):
    """Buffered reader over the byte chunks yielded by ``source.gen``.

    Implements the ``ByteSource`` read primitives (single bytes, fixed-size
    byte runs, LEB128 varints and whole columns) on top of a generator of
    chunks.  Only the current chunk is retained; reads that cross a chunk
    boundary are assembled into a temporary ``bytearray``.
    """
    # NOTE(review): this is a plain class attribute, not ``__slots__``, so it
    # has no effect on instances (it also lists 'end'/'slice', which are never
    # assigned, and omits 'buf_sz'/'source').  Kept unchanged so attribute
    # semantics stay the same; confirm before converting to ``__slots__``.
    slots = 'slice_sz', 'buf_loc', 'end', 'gen', 'buffer', 'slice'

    def __init__(self, source):
        """Wrap ``source``; its ``gen`` attribute supplies the byte chunks."""
        self.slice_sz = 4096  # NOTE(review): not read by any method of this class
        self.buf_loc = 0  # read position within self.buffer
        self.buf_sz = 0  # count of valid bytes in self.buffer; 0 means nothing buffered
        self.source = source
        self.gen = source.gen
        self.buffer = bytes()

    def read_bytes(self, sz: int):
        """Consume and return the next ``sz`` bytes.

        If the request fits in the current chunk, a slice of that chunk is
        returned.  Otherwise the rest of the current chunk is combined with as
        many further chunks as needed into a ``bytearray``; any excess from the
        last chunk pulled becomes the new current chunk.

        :raises StreamCompleteException: if the chunk generator is exhausted
            (or yields an empty chunk) before ``sz`` bytes are available.
        """
        if self.buf_loc + sz <= self.buf_sz:
            self.buf_loc += sz
            return self.buffer[self.buf_loc - sz: self.buf_loc]
        # Create a temporary buffer that bridges two or more source chunks
        bridge = bytearray(self.buffer[self.buf_loc: self.buf_sz])
        self.buf_loc = 0
        self.buf_sz = 0
        while len(bridge) < sz:
            chunk = next(self.gen, None)
            if not chunk:
                raise StreamCompleteException
            x = len(chunk)
            if len(bridge) + x <= sz:
                bridge.extend(chunk)
            else:
                # Take only what is still needed; keep this chunk buffered so
                # the remaining bytes are served by subsequent reads.
                tail = sz - len(bridge)
                bridge.extend(chunk[:tail])
                self.buffer = chunk
                self.buf_sz = x
                self.buf_loc = tail
        return bridge

    def read_byte(self) -> int:
        """Consume and return the next byte as an int.

        :raises StreamCompleteException: if no further chunk is available.
        """
        if self.buf_loc < self.buf_sz:
            self.buf_loc += 1
            return self.buffer[self.buf_loc - 1]
        self.buf_sz = 0
        self.buf_loc = 0
        chunk = next(self.gen, None)
        if not chunk:
            raise StreamCompleteException
        x = len(chunk)
        # A one-byte chunk is fully consumed here, so there is nothing to buffer.
        if x > 1:
            self.buffer = chunk
            self.buf_loc = 1
            self.buf_sz = x
        return chunk[0]

    def read_leb128(self) -> int:
        """Read an unsigned LEB128 varint (7 data bits per byte, low group first)."""
        sz = 0
        shift = 0
        while True:
            b = self.read_byte()
            sz += ((b & 0x7f) << shift)
            if (b & 0x80) == 0:  # high bit clear marks the final byte
                return sz
            shift += 7

    def read_leb128_str(self) -> str:
        """Read a LEB128 length prefix, then that many bytes decoded as UTF-8."""
        sz = self.read_leb128()
        return self.read_bytes(sz).decode()

    def read_uint64(self) -> int:
        """Read 8 bytes as a little-endian unsigned integer."""
        return int.from_bytes(self.read_bytes(8), 'little', signed=False)

    def read_str_col(self,
                     num_rows: int,
                     encoding: str,
                     nullable: bool = False,
                     null_obj: Any = None) -> Iterable[str]:
        """Read a column of ``num_rows`` LEB128 length-prefixed strings.

        :param num_rows: number of values to read.
        :param encoding: encoding used to decode each value; if falsy, the raw
            values are returned as ``bytes``.
        :param nullable: if true, a null map of ``num_rows`` bytes precedes the
            values; rows whose null-map byte is non-zero become ``null_obj``
            (their length-prefixed payload is still consumed).
        :param null_obj: value substituted for null rows.
        :return: list of decoded strings (or ``bytes``); values that fail to
            decode are returned as their hex representation.
        """
        column = []
        app = column.append
        null_map = self.read_bytes(num_rows) if nullable else None
        for ix in range(num_rows):
            # Same decoding as read_leb128(), inlined (presumably to save a
            # method call per row).
            sz = 0
            shift = 0
            while True:
                b = self.read_byte()
                sz += ((b & 0x7f) << shift)
                if (b & 0x80) == 0:
                    break
                shift += 7
            x = self.read_bytes(sz)
            if null_map and null_map[ix]:
                app(null_obj)
            elif encoding:
                try:
                    app(x.decode(encoding))
                except UnicodeDecodeError:
                    app(x.hex())
            else:
                # read_bytes() yields a chunk slice, or a bytearray when the value
                # spans chunks; normalize so every raw value has the same type.
                app(bytes(x))
        return column

    def read_bytes_col(self, sz: int, num_rows: int) -> Iterable[bytes]:
        """Read ``num_rows`` fixed-width values of ``sz`` bytes each."""
        source = self.read_bytes(sz * num_rows)
        if sz == 0:
            # range() cannot step by zero; every value is simply empty.
            return [b''] * num_rows
        return [bytes(source[x:x + sz]) for x in range(0, sz * num_rows, sz)]

    def read_fixed_str_col(self, sz: int, num_rows: int, encoding: str) -> Iterable[str]:
        """Read ``num_rows`` fixed-width strings of ``sz`` bytes each.

        Each value is decoded with ``encoding`` and trailing NUL characters are
        stripped; values that fail to decode are returned as hex strings.
        """
        source = self.read_bytes(sz * num_rows)
        if sz == 0:
            # range() cannot step by zero; every value decodes to ''.
            return [''] * num_rows
        column = []
        app = column.append
        for ix in range(0, sz * num_rows, sz):
            try:
                app(str(source[ix: ix + sz], encoding).rstrip('\x00'))
            except UnicodeDecodeError:
                app(source[ix: ix + sz].hex())
        return column

    def read_array(self, array_type: str, num_rows: int) -> Iterable[Any]:
        """Read ``num_rows`` items into an ``array.array`` of ``array_type``.

        Items are byte-swapped on big-endian hosts (see ``must_swap``).
        """
        column = array.array(array_type)
        sz = column.itemsize * num_rows
        b = self.read_bytes(sz)
        column.frombytes(b)
        if must_swap:
            column.byteswap()
        return column

    @property
    def last_message(self):
        """Decode the current buffered chunk as text, or None if it is empty.

        Undecodable bytes are replaced rather than raising, since the chunk
        may contain non-text stream data alongside any readable message.
        """
        if len(self.buffer) == 0:
            return None
        return self.buffer.decode(errors='replace')

    def close(self):
        """Close the underlying source once; later calls do nothing."""
        if self.source:
            self.source.close()
            self.source = None