aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/clickhouse-connect/clickhouse_connect/driver/transform.py
blob: b5ae795c9ad1ab8cec4199e6ecd4762aa9935ae5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import logging
from typing import Union

from clickhouse_connect.datatypes import registry
from clickhouse_connect.driver.common import write_leb128
from clickhouse_connect.driver.exceptions import StreamCompleteException, StreamFailureError
from clickhouse_connect.driver.insert import InsertContext
from clickhouse_connect.driver.npquery import NumpyResult
from clickhouse_connect.driver.query import QueryResult, QueryContext
from clickhouse_connect.driver.types import ByteSource
from clickhouse_connect.driver.compression import get_compressor

# Default context for NativeTransform.parse_response when the caller does not pass one.
# NOTE(review): this single instance is shared by every such call, and parse_response calls
# context.start_column() on it — confirm QueryContext keeps no per-query state there.
_EMPTY_CTX = QueryContext()

# Module logger; build_insert reports column serialization failures through it.
logger = logging.getLogger(__name__)


class NativeTransform:
    """Convert between ClickHouse Native-format byte streams and Python result/insert objects."""

    # pylint: disable=too-many-locals
    @staticmethod
    def parse_response(source: ByteSource, context: QueryContext = _EMPTY_CTX) -> Union[NumpyResult, QueryResult]:
        """Parse a Native-format response stream into a lazily evaluated result.

        The first block is decoded immediately so that column names and types are known when the
        result object is built; every later block is decoded on demand as the result's block
        generator is consumed.

        :param source: Byte stream positioned at the start of the response body.
        :param context: Query context; ``block_info``, ``use_numpy`` and ``column_oriented`` are
            read here and ``start_column`` is called before each non-empty column is decoded.
        :return: A ``NumpyResult`` when ``context.use_numpy`` is set, otherwise a ``QueryResult``.
            An empty result is returned if the stream ends before the first block.
        :raises StreamFailureError: If the stream ends in the middle of a block and the source
            holds a last message, which is converted into the error text.
        """
        column_names = []
        column_types = []
        blocks_read = 0

        def read_block():
            # Decode the next block into a list of columns, or return None when the stream ends
            # cleanly before a new block starts.
            nonlocal blocks_read
            columns = []
            try:
                try:
                    if context.block_info:
                        source.read_bytes(8)  # skip the 8-byte block info prefix
                    col_count = source.read_leb128()
                except StreamCompleteException:
                    # Running out of data at a block boundary is the normal end of the response
                    return None
                row_count = source.read_leb128()
                for idx in range(col_count):
                    col_name = source.read_leb128_str()
                    type_name = source.read_leb128_str()
                    if blocks_read != 0:
                        # Every block repeats the column header; reuse the types resolved from block 0
                        col_type = column_types[idx]
                    else:
                        column_names.append(col_name)
                        col_type = registry.get_from_name(type_name)
                        column_types.append(col_type)
                    if row_count != 0:
                        context.start_column(col_name)
                        columns.append(col_type.read_column(source, row_count, context))
                    else:
                        columns.append(())
            except Exception as exc:
                source.close()
                # Running out of data mid-block may mean ClickHouse wrote an error into the response
                if isinstance(exc, StreamCompleteException) and source.last_message:
                    raise StreamFailureError(extract_error_message(source.last_message)) from None
                raise
            blocks_read += 1
            return columns

        first_block = read_block()
        if first_block is None:
            if context.use_numpy:
                return NumpyResult()
            return QueryResult([])

        def all_blocks():
            # Replay the eagerly decoded first block, then keep decoding until read_block() returns None
            yield first_block
            yield from iter(read_block, None)

        if not context.use_numpy:
            return QueryResult(None, all_blocks(), tuple(column_names), tuple(column_types),
                               context.column_oriented, source)
        dtypes = ['O' if not hasattr(col, 'dtype') else col.dtype for col in first_block]
        return NumpyResult(all_blocks(), tuple(column_names), tuple(column_types), dtypes, source)

    @staticmethod
    def build_insert(context: InsertContext):
        """Return a generator of byte chunks that serialize the insert context's blocks.

        Each block from ``context.next_block()`` is written as its prefix, the LEB128 column and
        row counts, then for every column its name, its insert type name and its data. The block
        is passed through the context's compressor, and any bytes returned by the compressor's
        final flush are yielded last.

        If a column fails to serialize, the error is logged and stored on
        ``context.insert_exception``. A deliberately invalid chunk is then yielded and generation
        stops, so the insert fails instead of committing partial data.
        """
        compressor = get_compressor(context.compression)

        def append_str(value: str, buffer: bytearray):
            # Length-prefixed string: LEB128 byte count followed by the encoded bytes
            encoded = value.encode()
            write_leb128(len(encoded), buffer)
            buffer += encoded

        def chunks():
            for block in context.next_block():
                buffer = bytearray()
                buffer += block.prefix
                write_leb128(block.column_count, buffer)
                write_leb128(block.row_count, buffer)
                for col_name, col_type, data in zip(block.column_names, block.column_types, block.column_data):
                    append_str(col_name, buffer)
                    append_str(col_type.insert_name, buffer)
                    context.start_column(col_name)
                    try:
                        col_type.write_column(data, buffer, context)
                    except Exception as exc:  # pylint: disable=broad-except
                        # Bad user data can make low-level serialization fail mid-stream. Send garbage
                        # so the server rejects the whole insert (no partial insert), and hand the
                        # real exception back to the caller through the context.
                        logger.error('Error serializing column `%s` into data type `%s`',
                                     col_name, col_type.name, exc_info=True)
                        context.insert_exception = exc
                        yield 'INTERNAL EXCEPTION WHILE SERIALIZING'.encode()
                        return
                yield compressor.compress_block(buffer)
            footer = compressor.flush()
            if footer:
                yield footer

        return chunks()


def extract_error_message(message: bytes) -> str:
    """Extract readable error text from the last bytes received on a response stream.

    Only the final 1024 bytes of ``message`` are considered. If they contain the ClickHouse
    error marker ``Code: ``, everything before the marker is discarded. The remainder is
    decoded as UTF-8 (the ``bytes.decode`` default). If decoding fails, a placeholder message
    containing a hex dump of at most the first 64 bytes is returned instead.

    :param message: Raw bytes most recently read from the stream.
    :return: The decoded error text, or a description of the undecodable data.
    """
    truncated = len(message) > 1024
    if truncated:
        message = message[-1024:]
    error_start = message.find('Code: '.encode())
    if error_start != -1:
        message = message[error_start:]
    elif truncated:
        # Cutting the tail may have split a multi-byte UTF-8 character; drop its orphaned
        # continuation bytes (0b10xxxxxx, at most 3) so valid text still decodes.
        skip = 0
        while skip < 3 and (message[skip] & 0xC0) == 0x80:
            skip += 1
        message = message[skip:]
    try:
        message_str = message.decode()
    except UnicodeError:
        # Show the *first* 64 bytes (128 hex digits); slicing with [128:] used to drop them
        # and produced an empty dump for short data.
        message_str = f'unrecognized data found in stream: `{message.hex()[:128]}`'
    return message_str