diff options
| author | robot-piglet <[email protected]> | 2025-08-01 00:01:09 +0300 |
|---|---|---|
| committer | robot-piglet <[email protected]> | 2025-08-01 00:11:46 +0300 |
| commit | 75fd1fc757cc04e434a65784ae4ba6e28350878d (patch) | |
| tree | def4a4c6e8a93c0f37b563a6bb86bc7936fc3912 /contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py | |
| parent | f5d4ccd1e8d8054636ee31f953767a529801fcbf (diff) | |
Intermediate changes
commit_hash:11a36b37f1d393ab351897e8a0b5bf4de5871fe0
Diffstat (limited to 'contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py')
| -rw-r--r-- | contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py | 56 |
1 files changed, 39 insertions, 17 deletions
diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py index 793ca3f953d..c055c639675 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py @@ -54,6 +54,7 @@ class HttpClient(Client): username: str, password: str, database: str, + access_token: Optional[str] = None, compress: Union[bool, str] = True, query_limit: int = 0, query_retries: int = 2, @@ -73,12 +74,16 @@ class HttpClient(Client): apply_server_timezone: Optional[Union[str, bool]] = None, show_clickhouse_errors: Optional[bool] = None, autogenerate_session_id: Optional[bool] = None, - tls_mode: Optional[str] = None): + tls_mode: Optional[str] = None, + proxy_path: str = ''): """ Create an HTTP ClickHouse Connect client See clickhouse_connect.get_client for parameters """ - self.url = f'{interface}://{host}:{port}' + proxy_path = proxy_path.lstrip('/') + if proxy_path: + proxy_path = '/' + proxy_path + self.url = f'{interface}://{host}:{port}{proxy_path}' self.headers = {} self.params = dict_copy(HttpClient.params) ch_settings = dict_copy(settings, self.params) @@ -115,8 +120,11 @@ class HttpClient(Client): else: self.http = default_pool_manager() - if (not client_cert or tls_mode in ('strict', 'proxy')) and username: + if access_token: + self.headers['Authorization'] = f'Bearer {access_token}' + elif (not client_cert or tls_mode in ('strict', 'proxy')) and username: self.headers['Authorization'] = 'Basic ' + b64encode(f'{username}:{password}'.encode()).decode() + self.headers['User-Agent'] = common.build_client_name(client_name) self._read_format = self._write_format = 'Native' self._transform = NativeTransform() @@ -180,6 +188,12 @@ class HttpClient(Client): def get_client_setting(self, key) -> Optional[str]: return self.params.get(key) + def set_access_token(self, access_token: str): + auth_header = self.headers.get('Authorization') + if auth_header and not auth_header.startswith('Bearer'): + raise ProgrammingError('Cannot set access token when a different auth type is used') + self.headers['Authorization'] = f'Bearer {access_token}' + def _prep_query(self, context: QueryContext): final_query = super()._prep_query(context) if context.is_insert: @@ -228,7 +242,7 @@ class HttpClient(Client): headers['Content-Type'] = 'text/plain; charset=utf-8' response = self._raw_request(body, params, - headers, + dict_copy(headers, context.transport_settings), stream=True, retries=self.query_retries, fields=fields, @@ -266,7 +280,7 @@ class HttpClient(Client): if self.database: params['database'] = self.database params.update(self._validate_settings(context.settings)) - + headers = dict_copy(headers, context.transport_settings) response = self._raw_request(block_gen, params, headers, error_handler=error_handler, server_wait=False) logger.debug('Context insert response code: %d, content: %s', response.status, response.data) context.data = None @@ -277,7 +291,8 @@ class HttpClient(Client): insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None, settings: Optional[Dict] = None, fmt: Optional[str] = None, - compression: Optional[str] = None) -> QuerySummary: + compression: Optional[str] = None, + transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: """ See BaseClient doc_string for this method """ @@ -297,6 +312,7 @@ class HttpClient(Client): if self.database: params['database'] = self.database params.update(self._validate_settings(settings or {})) + headers = dict_copy(headers, transport_settings) response = self._raw_request(insert_block, params, headers, server_wait=False) logger.debug('Raw insert response code: %d, content: %s', response.status, response.data) return QuerySummary(self._summary(response)) @@ -318,7 +334,8 @@ class HttpClient(Client): data: Union[str, bytes] = None, settings: Optional[Dict] = None, use_database: int = True, - external_data: Optional[ExternalData] = None) -> Union[str, int, Sequence[str], QuerySummary]: + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None) -> Union[str, int, Sequence[str], QuerySummary]: """ See BaseClient doc_string for this method """ @@ -346,7 +363,7 @@ class HttpClient(Client): if use_database and self.database: params['database'] = self.database params.update(self._validate_settings(settings or {})) - + headers = dict_copy(headers, transport_settings) method = 'POST' if payload or fields else 'GET' response = self._raw_request(payload, params, headers, method, fields=fields, server_wait=False) if response.data: @@ -398,16 +415,18 @@ class HttpClient(Client): data = data.encode() headers = dict_copy(self.headers, headers) attempts = 0 + final_params = {} if server_wait: - params['wait_end_of_query'] = '1' + final_params['wait_end_of_query'] = '1' # We can't actually read the progress headers, but we enable them so ClickHouse sends something # to keep the connection alive when waiting for long-running queries and (2) to get summary information # if not streaming if self._send_progress: - params['send_progress_in_http_headers'] = '1' + final_params['send_progress_in_http_headers'] = '1' if self._progress_interval: - params['http_headers_progress_interval_ms'] = self._progress_interval - final_params = dict_copy(self.params, params) + final_params['http_headers_progress_interval_ms'] = self._progress_interval + final_params = dict_copy(self.params, final_params) + final_params = dict_copy(final_params, params) url = f'{self.url}?{urlencode(final_params)}' kwargs = { 'headers': headers, @@ -466,24 +485,27 @@ class HttpClient(Client): settings: Optional[Dict[str, Any]] = None, fmt: str = None, use_database: bool = True, - external_data: Optional[ExternalData] = None) -> bytes: + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None) -> bytes: """ See BaseClient doc_string for this method """ body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) - return self._raw_request(body, params, fields=fields).data + return self._raw_request(body, params, fields=fields, headers=transport_settings).data def raw_stream(self, query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, settings: Optional[Dict[str, Any]] = None, fmt: str = None, use_database: bool = True, - external_data: Optional[ExternalData] = None) -> io.IOBase: + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase: """ See BaseClient doc_string for this method """ body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) - return self._raw_request(body, params, fields=fields, stream=True, server_wait=False) + return self._raw_request(body, params, fields=fields, stream=True, server_wait=False, + headers=transport_settings) def _prep_raw_query(self, query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]], @@ -515,7 +537,7 @@ class HttpClient(Client): See BaseClient doc_string for this method """ try: - response = self.http.request('GET', f'{self.url}/ping', timeout=3) + response = self.http.request('GET', f'{self.url}/ping', timeout=3, preload_content=True) return 200 <= response.status < 300 except HTTPError: logger.debug('ping failed', exc_info=True) |
