1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
|
import logging
import re
from typing import Optional, Sequence, List, Dict
from clickhouse_connect.datatypes.registry import get_from_name
from clickhouse_connect.driver.common import unescape_identifier
from clickhouse_connect.driver.exceptions import ProgrammingError
from clickhouse_connect.driver import Client
from clickhouse_connect.driver.parser import parse_callable
from clickhouse_connect.driver.query import remove_sql_comments
# Module-level logger named after this module
logger = logging.getLogger(__name__)
# Matches a statement that starts with INSERT INTO (case-insensitive, leading whitespace allowed) and captures
# everything that follows.  No DOTALL flag is set, so '.' stops at a newline and a statement whose text after
# INSERT INTO spans multiple lines will not match.
insert_re = re.compile(r'^\s*INSERT\s+INTO\s+(.*$)', re.IGNORECASE)
# Type objects looked up by name from the datatype registry.
# NOTE(review): neither is referenced in the visible portion of this file -- confirm whether other modules import them.
str_type = get_from_name('String')
int_type = get_from_name('Int32')
# pylint: disable=too-many-instance-attributes
class Cursor:
    """
    DB API 2.0 style cursor that runs operations through a ClickHouse Connect client and buffers the
    resulting rows in memory for the fetch* methods.

    See :ref:`https://peps.python.org/pep-0249/`
    """

    def __init__(self, client: 'Client'):
        """
        :param client: Client used for every query and insert issued through this cursor
        """
        self.client = client
        self.arraysize = 1
        # Rows of the most recent operation; None until something is executed and again after close()
        self.data: Optional[Sequence] = None
        self.names = []
        self.types = []
        self._rowcount = 0
        # One entry is appended per query executed through this cursor
        self._summary: List[Dict[str, str]] = []
        # Index into self.data of the next row the fetch* methods will return
        self._ix = 0

    def _set_result(self, rows) -> None:
        """
        Install a new buffered result set and rewind the fetch position so that the row count, the
        data and the read index can never disagree.

        :param rows: Sequence of result rows
        """
        self.data = rows
        self._rowcount = len(rows)
        self._ix = 0

    def check_valid(self):
        """
        :raises ProgrammingError: If there is no result set (nothing executed yet, or the cursor was closed)
        """
        if self.data is None:
            raise ProgrammingError('Cursor is not valid')

    @property
    def description(self):
        """
        :return: One 7-item tuple per result column; only the name and the type are populated
        """
        return [(n, t, None, None, None, None, True) for n, t in zip(self.names, self.types)]

    @property
    def rowcount(self):
        """
        :return: Number of rows buffered by the most recent operation
        """
        return self._rowcount

    @property
    def summary(self) -> List[Dict[str, str]]:
        """
        :return: Summaries of the queries executed through this cursor, in execution order
        """
        return self._summary

    def close(self):
        """
        Drop the buffered result set; fetch* calls raise ProgrammingError until the next execute
        """
        self.data = None

    def execute(self, operation: str, parameters=None):
        """
        Run a single operation and buffer its result set

        :param operation: Statement text handed to the client's query method
        :param parameters: Optional parameters handed to the client's query method unchanged
        """
        query_result = self.client.query(operation, parameters)
        # Also rewinds the fetch position, otherwise a reused cursor would keep the old offset
        self._set_result(query_result.result_set)
        self._summary.append(query_result.summary)
        if query_result.column_names:
            self.names = query_result.column_names
            self.types = [x.name for x in query_result.column_types]
        elif self.data:
            # No column metadata: synthesize names and use the Python classes of the first row's values
            self.names = [f'col_{x}' for x in range(len(self.data[0]))]
            self.types = [x.__class__ for x in self.data[0]]
        else:
            # Nothing describes this result, so do not keep reporting the previous operation's columns
            self.names = []
            self.types = []

    def _try_bulk_insert(self, operation: str, data) -> bool:
        """
        Attempt to run an 'INSERT INTO <table> [(columns)] VALUES' statement as one client insert
        instead of one query per parameter row.

        :param operation: Statement text
        :param data: Non-empty indexable collection of parameter rows
        :return: True if the rows were inserted in bulk, False if the caller should execute row by row
        """
        match = insert_re.match(remove_sql_comments(operation))
        if not match:
            return False
        remainder = match.group(1)
        # The table name ends at the first space or opening parenthesis.  str.find returns -1 when the
        # character is absent, so those results must be discarded before taking the minimum.
        boundaries = [pos for pos in (remainder.find(' '), remainder.find('(')) if pos >= 0]
        if not boundaries:
            return False
        table_end = min(boundaries)
        table = remainder[:table_end].strip()
        remainder = remainder[table_end:].strip()
        if not table or not remainder:
            return False
        if remainder[0] == '(':
            _, op_columns, remainder = parse_callable(remainder)
        else:
            op_columns = None
        if 'VALUES' not in remainder.upper():
            return False
        first_row = data[0]
        if not isinstance(first_row, dict):
            return False  # Column names can only be derived from dictionary rows
        col_names = list(first_row.keys())
        col_set = set(col_names)
        if op_columns and {unescape_identifier(x) for x in op_columns} != col_set:
            return False  # Data sent in doesn't match the columns in the insert statement
        data_values = []
        for row in data:
            if not isinstance(row, dict) or row.keys() != col_set:
                return False  # Rows are not uniform, so they can't share one column list
            # Look values up by name so rows with a different key order stay aligned with col_names
            data_values.append([row[name] for name in col_names])
        self.client.insert(table, data_values, col_names)
        # An insert yields no rows; keep the row count in step with the (empty) buffer
        self._set_result([])
        self.names = []
        self.types = []
        return True

    def executemany(self, operation, parameters):
        """
        Run an operation once per parameter row, using a single bulk insert when the operation is a
        suitable INSERT statement.  Result rows of all executions are concatenated.

        :param operation: Statement text
        :param parameters: Collection of parameter rows; nothing is done if it is empty or None
        :raises ProgrammingError: If a TypeError occurs while iterating or executing the parameter rows
        """
        if not parameters or self._try_bulk_insert(operation, parameters):
            return
        self._set_result([])
        # Start from a clean slate so column names of an earlier operation are not compared against
        self.names = []
        self.types = []
        try:
            for param_row in parameters:
                query_result = self.client.query(operation, param_row)
                self.data.extend(query_result.result_set)
                if self.names or self.types:
                    if query_result.column_names != self.names:
                        logger.warning('Inconsistent column names %s : %s for operation %s in cursor executemany',
                                       self.names, query_result.column_names, operation)
                else:
                    self.names = query_result.column_names
                    # Store type names, matching what execute() reports through description
                    self.types = [x.name for x in query_result.column_types]
                self._summary.append(query_result.summary)
        except TypeError as ex:
            raise ProgrammingError(f'Invalid parameters {parameters} passed to cursor executemany') from ex
        finally:
            # Keep the row count consistent with the buffer even if execution stopped part way
            self._rowcount = len(self.data)

    def fetchall(self):
        """
        :return: All rows that have not been fetched yet (an empty sequence once exhausted)
        :raises ProgrammingError: If the cursor has no result set
        """
        self.check_valid()
        # Avoid copying the whole buffer in the common case where nothing has been fetched yet
        ret = self.data if self._ix == 0 else self.data[self._ix:]
        self._ix = self._rowcount
        return ret

    def fetchone(self):
        """
        :return: The next unfetched row, or None when the result set is exhausted
        :raises ProgrammingError: If the cursor has no result set
        """
        self.check_valid()
        if self._ix >= self._rowcount:
            return None
        val = self.data[self._ix]
        self._ix += 1
        return val

    def fetchmany(self, size: int = -1):
        """
        :param size: Maximum number of rows to return.  A negative value (the default) or None returns
          every remaining row, which is what this method has always done when called without arguments.
        :return: Up to size unfetched rows (an empty sequence once exhausted)
        :raises ProgrammingError: If the cursor has no result set
        """
        self.check_valid()
        remaining = max(self._rowcount - self._ix, 0)
        count = remaining if size is None or size < 0 else min(size, remaining)
        end = self._ix + count
        ret = self.data[self._ix: end]
        self._ix = end
        return ret

    def nextset(self):
        """
        Multiple result sets are not supported

        :raises NotImplementedError: Always
        """
        raise NotImplementedError

    def callproc(self, *args, **kwargs):
        """
        Stored procedure calls are not supported

        :raises NotImplementedError: Always
        """
        raise NotImplementedError
|