aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/cursor.py
blob: b8f23452ac6922713dd45c86201787bf5fd735e6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import logging
import re

from typing import Optional, Sequence

from clickhouse_connect.datatypes.registry import get_from_name
from clickhouse_connect.driver.common import unescape_identifier
from clickhouse_connect.driver.exceptions import ProgrammingError
from clickhouse_connect.driver import Client
from clickhouse_connect.driver.parser import parse_callable
from clickhouse_connect.driver.query import remove_sql_comments

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Matches a statement that begins (after optional whitespace) with INSERT INTO, case-insensitively,
# and captures the rest of the statement as group 1.  '.' does not match a newline here, so a
# statement with an embedded line break after INSERT INTO will not match.
insert_re = re.compile(r'^\s*INSERT\s+INTO\s+(.*$)', re.IGNORECASE)
# ClickHouse type objects looked up by name from the datatypes registry.
# NOTE(review): neither name is referenced in the visible part of this module -- confirm they are still needed.
str_type = get_from_name('String')
int_type = get_from_name('Int32')


class Cursor:
    """
    Minimal DB-API 2.0 cursor on top of a ClickHouse Connect client.

    Every successful ``execute``/``executemany`` call stores the returned rows on the cursor
    (``self.data``); the fetch methods then walk through those rows using an internal read position.

    See https://peps.python.org/pep-0249/
    """

    def __init__(self, client: Client):
        """
        :param client: ClickHouse Connect client used to run every query and insert
        """
        self.client = client
        self.arraysize = 1
        # None means "no result available" (nothing executed yet, or the cursor was closed)
        self.data: Optional[Sequence] = None
        self.names = []
        self.types = []
        # Number of rows in self.data; also the upper bound used by the fetch methods
        self._rowcount = 0
        # Index of the next row the fetch methods will return
        self._ix = 0

    def check_valid(self):
        """
        Ensure the cursor holds a result set.

        :raises ProgrammingError: if nothing has been executed or the cursor has been closed
        """
        if self.data is None:
            raise ProgrammingError('Cursor is not valid')

    @property
    def description(self):
        """
        PEP 249 style column descriptions for the most recent operation: one 7-tuple
        ``(name, type_code, None, None, None, None, True)`` per column.
        """
        return [(n, t, None, None, None, None, True) for n, t in zip(self.names, self.types)]

    @property
    def rowcount(self):
        """Number of rows produced by the most recent operation."""
        return self._rowcount

    def close(self):
        """Drop the current result set; subsequent fetch calls raise ProgrammingError."""
        self.data = None

    def execute(self, operation: str, parameters=None):
        """
        Run a single statement and make its rows available to the fetch methods.

        :param operation: SQL text, passed unchanged to ``client.query``
        :param parameters: optional bind parameters, passed unchanged to ``client.query``
        """
        query_result = self.client.query(operation, parameters)
        self.data = query_result.result_set
        self._rowcount = len(self.data)
        # A new result set must always be read from its first row, even on a reused cursor
        self._ix = 0
        if query_result.column_names:
            self.names = query_result.column_names
            self.types = [x.name for x in query_result.column_types]
        elif self.data:
            # No column metadata: synthesize names and use the Python classes of the first row
            self.names = [f'col_{x}' for x in range(len(self.data[0]))]
            self.types = [x.__class__ for x in self.data[0]]
        else:
            # Nothing describes this result; do not keep the previous statement's columns around
            self.names = []
            self.types = []

    def _try_bulk_insert(self, operation: str, data):
        """
        Attempt to run an ``INSERT INTO <table> [(columns)] VALUES ...`` statement as one client insert.

        :param operation: SQL text of the statement
        :param data: non-empty sequence of parameter rows; the bulk path requires mapping rows
          keyed by column name
        :return: True if the rows were sent with ``client.insert``; False if the statement or the
          rows do not fit the bulk path and the caller should fall back to row-by-row execution
        """
        match = insert_re.match(remove_sql_comments(operation))
        if not match:
            return False
        temp = match.group(1)
        # str.find returns -1 when the character is absent; a plain min() over both results would
        # then pick -1 and chop the last character off the statement instead of finding the table
        candidates = [ix for ix in (temp.find(' '), temp.find('(')) if ix >= 0]
        if not candidates:
            return False
        table_end = min(candidates)
        table = temp[:table_end].strip()
        temp = temp[table_end:].strip()
        if not table or not temp:
            return False
        if temp.startswith('('):
            _, op_columns, temp = parse_callable(temp)
        else:
            op_columns = None
        if 'VALUES' not in temp.upper():
            return False
        first_row = data[0]
        if not hasattr(first_row, 'keys'):
            return False  # Positional parameter rows are handled by the row-by-row path
        col_names = list(first_row.keys())
        if op_columns and {unescape_identifier(x) for x in op_columns} != set(col_names):
            return False  # Data sent in doesn't match the columns in the insert statement
        try:
            # Look values up by column name so rows whose keys are ordered differently still line up
            data_values = [[row[name] for name in col_names] for row in data]
        except (KeyError, TypeError):
            return False  # Rows are not uniform mappings; let the row-by-row path deal with them
        self.client.insert(table, data_values, col_names)
        # An insert yields no rows to fetch; keep the read state consistent with the empty result
        self.data = []
        self._rowcount = 0
        self._ix = 0
        return True

    def executemany(self, operation, parameters):
        """
        Run the statement once per parameter row, accumulating all returned rows.

        Suitable INSERT statements with mapping rows are sent as a single bulk insert instead.
        Nothing is done when ``parameters`` is empty.

        :param operation: SQL text
        :param parameters: sequence of parameter rows
        :raises ProgrammingError: if a TypeError occurs while iterating or applying the parameters
        """
        if not parameters or self._try_bulk_insert(operation, parameters):
            return
        # Start from a clean slate so columns and read position of an earlier statement do not leak in
        self.data = []
        self.names = []
        self.types = []
        self._rowcount = 0
        self._ix = 0
        try:
            for param_row in parameters:
                query_result = self.client.query(operation, param_row)
                self.data.extend(query_result.result_set)
                if self.names or self.types:
                    if query_result.column_names != self.names:
                        logger.warning('Inconsistent column names %s : %s for operation %s in cursor executemany',
                                       self.names, query_result.column_names, operation)
                else:
                    self.names = query_result.column_names
                    # Store type names, as execute() does, so description is consistent
                    self.types = [x.name for x in query_result.column_types or ()]
        except TypeError as ex:
            raise ProgrammingError(f'Invalid parameters {parameters} passed to cursor executemany') from ex
        self._rowcount = len(self.data)

    def fetchall(self):
        """
        Return all rows that have not been fetched yet and mark the result set as exhausted.

        :raises ProgrammingError: if the cursor holds no result set
        """
        self.check_valid()
        # Nothing consumed yet: hand back the stored rows without copying them
        ret = self.data if self._ix == 0 else self.data[self._ix:]
        self._ix = self._rowcount
        return ret

    def fetchone(self):
        """
        Return the next row, or None when every row has been fetched.

        :raises ProgrammingError: if the cursor holds no result set
        """
        self.check_valid()
        if self._ix >= self._rowcount:
            return None
        val = self.data[self._ix]
        self._ix += 1
        return val

    def fetchmany(self, size: int = -1):
        """
        Return up to ``size`` of the rows that have not been fetched yet.

        :param size: maximum number of rows to return.  A negative value (the default) returns
          every remaining row; this keeps the behavior of earlier versions, so ``arraysize`` is
          not consulted.
        :raises ProgrammingError: if the cursor holds no result set
        """
        self.check_valid()
        remaining = max(self._rowcount - self._ix, 0)
        count = remaining if size < 0 else min(size, remaining)
        end = self._ix + count
        ret = self.data[self._ix: end]
        self._ix = end
        return ret

    def nextset(self):
        """Multiple result sets are not supported."""
        raise NotImplementedError

    def callproc(self, *args, **kwargs):
        """Stored procedures are not supported."""
        raise NotImplementedError