summaryrefslogtreecommitdiffstats
path: root/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py
diff options
context:
space:
mode:
authorrobot-piglet <[email protected]>2025-08-01 00:01:09 +0300
committerrobot-piglet <[email protected]>2025-08-01 00:11:46 +0300
commit75fd1fc757cc04e434a65784ae4ba6e28350878d (patch)
treedef4a4c6e8a93c0f37b563a6bb86bc7936fc3912 /contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py
parentf5d4ccd1e8d8054636ee31f953767a529801fcbf (diff)
Intermediate changes
commit_hash:11a36b37f1d393ab351897e8a0b5bf4de5871fe0
Diffstat (limited to 'contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py')
-rw-r--r--contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py56
1 files changed, 39 insertions, 17 deletions
diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py
index 793ca3f953d..c055c639675 100644
--- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py
+++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py
@@ -54,6 +54,7 @@ class HttpClient(Client):
username: str,
password: str,
database: str,
+ access_token: Optional[str] = None,
compress: Union[bool, str] = True,
query_limit: int = 0,
query_retries: int = 2,
@@ -73,12 +74,16 @@ class HttpClient(Client):
apply_server_timezone: Optional[Union[str, bool]] = None,
show_clickhouse_errors: Optional[bool] = None,
autogenerate_session_id: Optional[bool] = None,
- tls_mode: Optional[str] = None):
+ tls_mode: Optional[str] = None,
+ proxy_path: str = ''):
"""
Create an HTTP ClickHouse Connect client
See clickhouse_connect.get_client for parameters
"""
- self.url = f'{interface}://{host}:{port}'
+ proxy_path = proxy_path.lstrip('/')
+ if proxy_path:
+ proxy_path = '/' + proxy_path
+ self.url = f'{interface}://{host}:{port}{proxy_path}'
self.headers = {}
self.params = dict_copy(HttpClient.params)
ch_settings = dict_copy(settings, self.params)
@@ -115,8 +120,11 @@ class HttpClient(Client):
else:
self.http = default_pool_manager()
- if (not client_cert or tls_mode in ('strict', 'proxy')) and username:
+ if access_token:
+ self.headers['Authorization'] = f'Bearer {access_token}'
+ elif (not client_cert or tls_mode in ('strict', 'proxy')) and username:
self.headers['Authorization'] = 'Basic ' + b64encode(f'{username}:{password}'.encode()).decode()
+
self.headers['User-Agent'] = common.build_client_name(client_name)
self._read_format = self._write_format = 'Native'
self._transform = NativeTransform()
@@ -180,6 +188,12 @@ class HttpClient(Client):
def get_client_setting(self, key) -> Optional[str]:
return self.params.get(key)
+ def set_access_token(self, access_token: str):
+ auth_header = self.headers.get('Authorization')
+ if auth_header and not auth_header.startswith('Bearer'):
+ raise ProgrammingError('Cannot set access token when a different auth type is used')
+ self.headers['Authorization'] = f'Bearer {access_token}'
+
def _prep_query(self, context: QueryContext):
final_query = super()._prep_query(context)
if context.is_insert:
@@ -228,7 +242,7 @@ class HttpClient(Client):
headers['Content-Type'] = 'text/plain; charset=utf-8'
response = self._raw_request(body,
params,
- headers,
+ dict_copy(headers, context.transport_settings),
stream=True,
retries=self.query_retries,
fields=fields,
@@ -266,7 +280,7 @@ class HttpClient(Client):
if self.database:
params['database'] = self.database
params.update(self._validate_settings(context.settings))
-
+ headers = dict_copy(headers, context.transport_settings)
response = self._raw_request(block_gen, params, headers, error_handler=error_handler, server_wait=False)
logger.debug('Context insert response code: %d, content: %s', response.status, response.data)
context.data = None
@@ -277,7 +291,8 @@ class HttpClient(Client):
insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None,
settings: Optional[Dict] = None,
fmt: Optional[str] = None,
- compression: Optional[str] = None) -> QuerySummary:
+ compression: Optional[str] = None,
+ transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary:
"""
See BaseClient doc_string for this method
"""
@@ -297,6 +312,7 @@ class HttpClient(Client):
if self.database:
params['database'] = self.database
params.update(self._validate_settings(settings or {}))
+ headers = dict_copy(headers, transport_settings)
response = self._raw_request(insert_block, params, headers, server_wait=False)
logger.debug('Raw insert response code: %d, content: %s', response.status, response.data)
return QuerySummary(self._summary(response))
@@ -318,7 +334,8 @@ class HttpClient(Client):
data: Union[str, bytes] = None,
settings: Optional[Dict] = None,
use_database: int = True,
- external_data: Optional[ExternalData] = None) -> Union[str, int, Sequence[str], QuerySummary]:
+ external_data: Optional[ExternalData] = None,
+ transport_settings: Optional[Dict[str, str]] = None) -> Union[str, int, Sequence[str], QuerySummary]:
"""
See BaseClient doc_string for this method
"""
@@ -346,7 +363,7 @@ class HttpClient(Client):
if use_database and self.database:
params['database'] = self.database
params.update(self._validate_settings(settings or {}))
-
+ headers = dict_copy(headers, transport_settings)
method = 'POST' if payload or fields else 'GET'
response = self._raw_request(payload, params, headers, method, fields=fields, server_wait=False)
if response.data:
@@ -398,16 +415,18 @@ class HttpClient(Client):
data = data.encode()
headers = dict_copy(self.headers, headers)
attempts = 0
+ final_params = {}
if server_wait:
- params['wait_end_of_query'] = '1'
+ final_params['wait_end_of_query'] = '1'
# We can't actually read the progress headers, but we enable them so ClickHouse sends something
# to keep the connection alive when waiting for long-running queries and (2) to get summary information
# if not streaming
if self._send_progress:
- params['send_progress_in_http_headers'] = '1'
+ final_params['send_progress_in_http_headers'] = '1'
if self._progress_interval:
- params['http_headers_progress_interval_ms'] = self._progress_interval
- final_params = dict_copy(self.params, params)
+ final_params['http_headers_progress_interval_ms'] = self._progress_interval
+ final_params = dict_copy(self.params, final_params)
+ final_params = dict_copy(final_params, params)
url = f'{self.url}?{urlencode(final_params)}'
kwargs = {
'headers': headers,
@@ -466,24 +485,27 @@ class HttpClient(Client):
settings: Optional[Dict[str, Any]] = None,
fmt: str = None,
use_database: bool = True,
- external_data: Optional[ExternalData] = None) -> bytes:
+ external_data: Optional[ExternalData] = None,
+ transport_settings: Optional[Dict[str, str]] = None) -> bytes:
"""
See BaseClient doc_string for this method
"""
body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
- return self._raw_request(body, params, fields=fields).data
+ return self._raw_request(body, params, fields=fields, headers=transport_settings).data
def raw_stream(self, query: str,
parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
settings: Optional[Dict[str, Any]] = None,
fmt: str = None,
use_database: bool = True,
- external_data: Optional[ExternalData] = None) -> io.IOBase:
+ external_data: Optional[ExternalData] = None,
+ transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase:
"""
See BaseClient doc_string for this method
"""
body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
- return self._raw_request(body, params, fields=fields, stream=True, server_wait=False)
+ return self._raw_request(body, params, fields=fields, stream=True, server_wait=False,
+ headers=transport_settings)
def _prep_raw_query(self, query: str,
parameters: Optional[Union[Sequence, Dict[str, Any]]],
@@ -515,7 +537,7 @@ class HttpClient(Client):
See BaseClient doc_string for this method
"""
try:
- response = self.http.request('GET', f'{self.url}/ping', timeout=3)
+ response = self.http.request('GET', f'{self.url}/ping', timeout=3, preload_content=True)
return 200 <= response.status < 300
except HTTPError:
logger.debug('ping failed', exc_info=True)