aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/clickhouse-connect/clickhouse_connect/driver/transform.py
blob: 3181b9e5e46267ab2652cfc06a0e52b018a49a8f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import logging
from typing import Union

from clickhouse_connect.datatypes import registry
from clickhouse_connect.driver.common import write_leb128
from clickhouse_connect.driver.exceptions import StreamCompleteException, StreamFailureError
from clickhouse_connect.driver.insert import InsertContext
from clickhouse_connect.driver.npquery import NumpyResult
from clickhouse_connect.driver.query import QueryResult, QueryContext
from clickhouse_connect.driver.types import ByteSource
from clickhouse_connect.driver.compression import get_compressor

logger = logging.getLogger(__name__)

# Default context for NativeTransform.parse_response when the caller supplies none.
# NOTE(review): parse_response calls context.start_column() on it, so any state it holds is shared across calls.
_EMPTY_CTX = QueryContext()


class NativeTransform:
    """Read and write the block-structured (Native) byte format used for query responses and inserts."""

    # pylint: disable=too-many-locals
    @staticmethod
    def parse_response(source: ByteSource, context: QueryContext = _EMPTY_CTX) -> Union[NumpyResult, QueryResult]:
        """
        Parse a block-structured response stream into a lazily evaluated query result.

        The first block is read eagerly so that column names/types (and, for numpy results, the column
        dtypes) are known; subsequent blocks are read on demand by the generator handed to the result.

        :param source: ByteSource positioned at the start of the response body
        :param context: QueryContext controlling block info handling, numpy vs. Python results and
          column orientation
        :return: NumpyResult when context.use_numpy is set, otherwise QueryResult.  An empty result is
          returned when the stream ends before the first block.
        :raises StreamFailureError: when the stream ends in the middle of a block; the message is taken
          from the last data received, since this is typically the server reporting an error
        """
        names = []
        col_types = []
        block_num = 0

        def get_block():
            nonlocal block_num
            result_block = []
            try:
                try:
                    if context.block_info:
                        # Skip the 8 bytes of block info the context says precede each block
                        source.read_bytes(8)
                    num_cols = source.read_leb128()
                except StreamCompleteException:
                    # Stream ended cleanly on a block boundary -- no more blocks
                    return None
                num_rows = source.read_leb128()
                for col_num in range(num_cols):
                    name = source.read_leb128_str()
                    type_name = source.read_leb128_str()
                    if block_num == 0:
                        # Column names/types are collected from the first block and reused afterwards
                        names.append(name)
                        col_type = registry.get_from_name(type_name)
                        col_types.append(col_type)
                    else:
                        col_type = col_types[col_num]
                    if num_rows == 0:
                        result_block.append(tuple())
                    else:
                        context.start_column(name)
                        column = col_type.read_column(source, num_rows, context)
                        result_block.append(column)
            except Exception as ex:
                source.close()
                if isinstance(ex, StreamCompleteException):
                    # We ran out of data before it was expected, this could be ClickHouse reporting an error
                    # in the response
                    message = NativeTransform._stream_error_message(source.last_message)
                    raise StreamFailureError(message) from None
                raise
            block_num += 1
            return result_block

        first_block = get_block()
        if first_block is None:
            return NumpyResult() if context.use_numpy else QueryResult([])

        def gen():
            yield first_block
            while True:
                next_block = get_block()
                if next_block is None:
                    return
                yield next_block

        if context.use_numpy:
            res_types = [col.dtype if hasattr(col, 'dtype') else 'O' for col in first_block]
            return NumpyResult(gen(), tuple(names), tuple(col_types), res_types, source)
        return QueryResult(None, gen(), tuple(names), tuple(col_types), context.column_oriented, source)

    @staticmethod
    def build_insert(context: InsertContext):
        """
        Build a generator of byte chunks encoding the insert data held by ``context``.

        Each block from ``context.next_block()`` is serialized (prefix, column count, row count, then
        per-column name, type name and data) and passed through the compressor selected by
        ``context.compression``.  The compressor's footer, if any, is yielded last.

        If a column fails to serialize, the error is logged, stored in ``context.insert_exception``, an
        uncompressed marker is yielded so that the server rejects the insert, and generation stops.

        :param context: InsertContext supplying the blocks and compression setting
        :return: generator of bytes-like chunks
        """
        compressor = get_compressor(context.compression)

        def chunk_gen():
            for block in context.next_block():
                output = bytearray()
                output += block.prefix
                write_leb128(block.column_count, output)
                write_leb128(block.row_count, output)
                for col_name, col_type, data in zip(block.column_names, block.column_types, block.column_data):
                    NativeTransform._write_leb128_str(col_name, output)
                    NativeTransform._write_leb128_str(col_type.name, output)
                    context.start_column(col_name)
                    try:
                        col_type.write_column(data, output, context)
                    except Exception as ex:  # pylint: disable=broad-except
                        # This is hideous, but some low level serializations can fail while streaming
                        # the insert if the user has included bad data in the column.  We need to ensure that the
                        # insert fails (using garbage data) to avoid a partial insert, and use the context to
                        # propagate the correct exception to the user
                        logger.error('Error serializing column `%s` into data type `%s`',
                                     col_name, col_type.name, exc_info=True)
                        context.insert_exception = ex
                        yield 'INTERNAL EXCEPTION WHILE SERIALIZING'.encode()
                        return
                yield compressor.compress_block(output)
            footer = compressor.flush()
            if footer:
                yield footer

        return chunk_gen()

    @staticmethod
    def _write_leb128_str(value: str, output: bytearray):
        """
        Append ``value`` to ``output`` as UTF-8 bytes prefixed by their LEB128-encoded byte length.

        The prefix must be the encoded byte count, not ``len(value)``: the two differ for any
        non-ASCII character, and a wrong prefix corrupts the rest of the stream.
        """
        encoded = value.encode()
        write_leb128(len(encoded), output)
        output.extend(encoded)

    @staticmethod
    def _stream_error_message(message) -> str:
        """
        Reduce the last data received before a truncated stream to a useful error message.

        Keeps at most the final 1024 characters and, if present, starts the message at the
        ``Code: `` marker of a server error report.  A missing (None) or empty message yields ''.
        """
        if not message:
            return ''
        if len(message) > 1024:
            message = message[-1024:]
        error_start = message.find('Code: ')
        if error_start != -1:
            message = message[error_start:]
        return message