aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/clickhouse-connect/clickhouse_connect/driver/dataconv.py
blob: 29c96a9a66f327c0d3e5260ef7ab2079f31b19cb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import array
from datetime import datetime, date, timedelta, tzinfo
from ipaddress import IPv4Address
from typing import Sequence, Optional, Any
from uuid import UUID, SafeUUID

from clickhouse_connect.driver.common import int_size
from clickhouse_connect.driver.types import ByteSource
from clickhouse_connect.driver.options import np


# Cumulative number of days elapsed before the start of each month
# (index 0 = January) in a common year; the final entry is the year length.
MONTH_DAYS = (
    0, 31, 59, 90, 120, 151,
    181, 212, 243, 273, 304, 334,
    365,
)
# The same cumulative table for a leap year, where February has 29 days.
MONTH_DAYS_LEAP = (
    0, 31, 60, 91, 121, 152,
    182, 213, 244, 274, 305, 335,
    366,
)


def read_ipv4_col(source: ByteSource, num_rows: int):
    """Read *num_rows* packed IPv4 addresses and wrap them as ``IPv4Address``.

    Values are read with array type code ``'I'``. Instances are built with
    ``IPv4Address.__new__`` and their ``_ip`` slot is filled directly. This
    skips the public constructor's argument parsing.
    """
    blank_address = IPv4Address.__new__

    def _wrap(packed):
        address = blank_address(IPv4Address)
        address._ip = packed  # pylint: disable=protected-access
        return address

    return [_wrap(packed) for packed in source.read_array('I', num_rows)]


def read_datetime_col(source: ByteSource, num_rows: int, tz_info: Optional[tzinfo]):
    """Read *num_rows* DateTime values as ``datetime`` objects.

    Each value is read with array type code ``'I'`` (unsigned int). It is
    interpreted as a count of seconds since 1970-01-01 00:00:00 UTC.

    :param source: byte source the packed timestamps are read from
    :param num_rows: number of values to read
    :param tz_info: when ``None``, return naive datetimes holding the UTC wall
        time; otherwise return aware datetimes expressed in ``tz_info``
    :return: list with one ``datetime`` per row
    """
    src_array = source.read_array('I', num_rows)
    if tz_info is None:
        # datetime.utcfromtimestamp is deprecated since Python 3.12. Adding the
        # offset to a naive epoch gives the same naive UTC value without the
        # deprecated call.
        epoch = datetime(1970, 1, 1)
        return [epoch + timedelta(seconds=ts) for ts in src_array]
    fts = datetime.fromtimestamp
    return [fts(ts, tz_info) for ts in src_array]


def epoch_days_to_date(days: int) -> date:
    """Convert a signed count of days since 1970-01-01 into a ``date``.

    The count is rebased onto 1601-01-01. It is then decomposed into 400-year,
    100-year, 4-year and 1-year Gregorian cycles.

    :param days: number of days relative to 1970-01-01 (may be negative)
    :return: the corresponding proleptic Gregorian ``date``
    """
    # 134774 is the number of days from 1601-01-01 to 1970-01-01.
    cycles400, rem = divmod(days + 134774, 146097)  # days per 400 years
    cycles100, rem = divmod(rem, 36524)  # days per century without a leap century year
    cycles, rem = divmod(rem, 1461)  # days per 4-year cycle
    years, rem = divmod(rem, 365)
    year = (cycles << 2) + cycles400 * 400 + cycles100 * 100 + years + 1601
    # A quotient of 4 only occurs on the extra leap day that closes a 4-year
    # or 400-year cycle, which is December 31 of the preceding year.
    if years == 4 or cycles100 == 4:
        return date(year - 1, 12, 31)
    # years == 3 is the last (divisible-by-4) year of a 4-year cycle. Apply the
    # full Gregorian rule: a century year is leap only if divisible by 400.
    is_leap = years == 3 and (year % 400 == 0 or year % 100 != 0)
    m_list = MONTH_DAYS_LEAP if is_leap else MONTH_DAYS
    # (rem + 24) >> 5 never underestimates the zero-based month index, so
    # stepping backwards until rem falls inside the month is sufficient.
    month = (rem + 24) >> 5
    while rem < m_list[month]:
        month -= 1
    return date(year, month + 1, rem + 1 - m_list[month])


def read_date_col(source: ByteSource, num_rows: int):
    """Read *num_rows* day counts (array code ``'H'``) and convert each to a ``date``."""
    day_counts = source.read_array('H', num_rows)
    return list(map(epoch_days_to_date, day_counts))


def read_date32_col(source: ByteSource, num_rows: int):
    """Read *num_rows* signed day counts and convert each to a ``date``."""
    # Use 'l' instead of 'i' when int_size is 2. Presumably int_size is the
    # byte width of array code 'i', so 'l' keeps a 32-bit read; TODO confirm.
    type_code = 'i' if int_size != 2 else 'l'
    day_counts = source.read_array(type_code, num_rows)
    return list(map(epoch_days_to_date, day_counts))


def read_uuid_col(source: ByteSource, num_rows: int):
    """Read *num_rows* UUIDs, each stored as two unsigned 64-bit words.

    The first word of each pair supplies the high 64 bits and the second the
    low 64 bits. A zero value reuses one shared ``UUID(int=0)``. Other values
    are built via ``UUID.__new__``, with the ``int`` and ``is_safe`` slots set
    directly, which bypasses ``UUID.__init__``.
    """
    words = source.read_array('Q', num_rows * 2)
    zero_uuid = UUID(int=0)
    blank_uuid = UUID.__new__
    set_slot = object.__setattr__
    unsafe_flag = SafeUUID.unsafe
    result = []
    for pos in range(0, num_rows * 2, 2):
        value = (words[pos] << 64) | words[pos + 1]
        if not value:
            result.append(zero_uuid)
            continue
        uid = blank_uuid(UUID)
        set_slot(uid, 'int', value)
        set_slot(uid, 'is_safe', unsafe_flag)
        result.append(uid)
    return result


def read_nullable_array(source: ByteSource, array_type: str, num_rows: int, null_obj: Any):
    """Read a null map followed by *num_rows* values of *array_type*.

    The null map (one byte per row) is read before the values. Rows whose map
    byte is non-zero are replaced by *null_obj*.
    """
    nulls = source.read_bytes(num_rows)
    values = source.read_array(array_type, num_rows)
    return [values[pos] if not nulls[pos] else null_obj for pos in range(num_rows)]


def build_nullable_column(source: Sequence, null_map: bytes, null_obj: Any):
    """Return *source* as a list, with *null_obj* wherever *null_map* is non-zero."""
    return [
        null_obj if null_map[pos] != 0 else value
        for pos, value in enumerate(source)
    ]


def build_lc_nullable_column(index: Sequence, keys: array.array, null_obj: Any):
    """Build a nullable column by looking up each key in a value dictionary.

    :param index: distinct values, addressed by key
    :param keys: one key per row; key ``0`` marks a null row
    :param null_obj: value substituted for null rows
    :return: list with one value per key
    """
    # A comprehension replaces the manual append loop. Because key 0 always
    # maps to null_obj, index[0] is never read.
    return [null_obj if key == 0 else index[key] for key in keys]


def to_numpy_array(column: Sequence):
    """Copy *column* into a new one-dimensional NumPy array of ``object`` dtype.

    :param column: values to copy
    :return: ``numpy.ndarray`` of shape ``(len(column),)`` and dtype ``object``
    """
    # Use the builtin ``object``: the ``np.object`` alias was removed in NumPy
    # 1.24 and raises AttributeError there.
    arr = np.empty((len(column),), dtype=object)
    arr[:] = column
    return arr


def pivot(data: Sequence[Sequence], start_row: int, end_row: int) -> Sequence[Sequence]:
    """Transpose rows ``data[start_row:end_row]`` into a tuple of column tuples.

    As with ``zip``, the column count is that of the shortest selected row.
    """
    selected_rows = data[start_row:end_row]
    columns = zip(*selected_rows)
    return tuple(columns)


def write_str_col(column: Sequence, encoding: Optional[str], dest: bytearray):
    """Append each value of *column* to *dest* as a length-prefixed byte string.

    Each length is written as an unsigned LEB128 varint (7 bits per byte,
    continuation bit 0x80 on every byte but the last), followed by the bytes.
    Falsy values (e.g. ``''`` or ``None``) become a single zero length byte.
    When *encoding* is truthy, values are encoded with it first. Otherwise they
    are appended as-is and so must already be bytes-like.
    """
    append = dest.append
    for value in column:
        if not value:
            append(0)
            continue
        if encoding:
            value = value.encode(encoding)
        remaining = len(value)
        while remaining > 0x7f:
            append(0x80 | (remaining & 0x7f))
            remaining >>= 7
        append(remaining)
        dest += value