diff options
| author | robot-piglet <[email protected]> | 2026-05-29 19:54:22 +0300 |
|---|---|---|
| committer | robot-piglet <[email protected]> | 2026-05-29 20:17:28 +0300 |
| commit | 87d2566e1f5bf842627474dcc2bfd79af3e76aad (patch) | |
| tree | e42e45c942f5e4c2399afa771cec8c018d4339fe /contrib/python/clickhouse-connect | |
| parent | 1af78d28ba237044096f076bfc6f35c68c52753f (diff) | |
Intermediate changes
commit_hash:52f35c8ce278d258055acfd71e7acb09b857e4ea
Diffstat (limited to 'contrib/python/clickhouse-connect')
74 files changed, 5670 insertions, 3413 deletions
diff --git a/contrib/python/clickhouse-connect/.dist-info/METADATA b/contrib/python/clickhouse-connect/.dist-info/METADATA index c70f24f991e..4a4168b9536 100644 --- a/contrib/python/clickhouse-connect/.dist-info/METADATA +++ b/contrib/python/clickhouse-connect/.dist-info/METADATA @@ -1,27 +1,26 @@ Metadata-Version: 2.4 Name: clickhouse-connect -Version: 0.15.1 +Version: 1.0.0 Summary: ClickHouse Database Core Driver for Python, Pandas, and Superset Home-page: https://github.com/ClickHouse/clickhouse-connect Author: ClickHouse Inc. Author-email: [email protected] -License: Apache License 2.0 +License: Apache-2.0 Keywords: clickhouse,superset,sqlalchemy,http,driver -Classifier: Development Status :: 4 - Beta +Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License -Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 Classifier: Programming Language :: Python :: 3.12 Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: 3.14 -Requires-Python: >=3.9,<3.15 +Requires-Python: >=3.10,<3.15 Description-Content-Type: text/markdown License-File: LICENSE Requires-Dist: certifi Requires-Dist: urllib3>=1.26 -Requires-Dist: pytz +Requires-Dist: tzdata; sys_platform == "win32" Requires-Dist: zstandard; python_version < "3.14" Requires-Dist: zstandard>=0.25.0; python_version >= "3.14" Requires-Dist: lz4; python_version < "3.14" @@ -31,7 +30,7 @@ Requires-Dist: sqlalchemy<3.0,>=1.4.40; extra == "sqlalchemy" Provides-Extra: numpy Requires-Dist: numpy; extra == "numpy" Provides-Extra: pandas -Requires-Dist: pandas<3; extra == "pandas" +Requires-Dist: pandas<4,>=2; extra == "pandas" Provides-Extra: polars Requires-Dist: polars>=1.0; extra == "polars" Provides-Extra: arrow @@ -41,6 +40,10 @@ Provides-Extra: orjson Requires-Dist: orjson; extra == "orjson" Provides-Extra: tzlocal Requires-Dist: tzlocal>=4.0; extra == "tzlocal" +Provides-Extra: tzdata +Requires-Dist: tzdata; extra == "tzdata" +Provides-Extra: async +Requires-Dist: aiohttp>=3.8.0; extra == "async" Dynamic: author Dynamic: author-email Dynamic: classifier @@ -59,7 +62,7 @@ Dynamic: summary A high performance core database driver for connecting ClickHouse to Python, Pandas, and Superset -* Pandas DataFrames (numpy and arrow-backed). Pandas 2.x and above only, 1.x is deprecated and will be dropped in 1.0. +* Pandas DataFrames (numpy and arrow-backed). Requires pandas 2.0 or later. * Numpy Arrays * PyArrow Tables * Polars DataFrames @@ -74,8 +77,11 @@ ClickHouse Connect currently uses the ClickHouse HTTP interface for maximum comp pip install clickhouse-connect ``` -ClickHouse Connect requires Python 3.9 or higher. We officially test against Python 3.10 through 3.14. -Python 3.9 is deprecated and support will be removed entirely in 1.0. +ClickHouse Connect requires Python 3.10 or higher. + +#### Upgrading from 0.x + +The 1.0 release includes breaking changes. If you are upgrading from a 0.15.x or earlier release, see [MIGRATION.md](MIGRATION.md) for a guide to the changes and their replacements. ### Superset Connectivity @@ -110,13 +116,15 @@ are not implemented. The dialect is best suited for SQLAlchemy Core usage and Su ### Asyncio Support -ClickHouse Connect provides an `AsyncClient` for use in `asyncio` environments. -See the [run_async example](./examples/run_async.py) for more details. +ClickHouse Connect provides native async support using aiohttp. To use the async client, +install the optional async dependency: + +``` +pip install clickhouse-connect[async] +``` -The current `AsyncClient` is a thread-pool executor wrapper around the synchronous client and is deprecated. -In 1.0.0 it will be replaced by a fully native async implementation. The API surface is the same, -with one difference: you will no longer be able to create a sync client first and pass it to the -`AsyncClient` constructor. Instead, use `clickhouse_connect.get_async_client()` directly. +Then create a client with `clickhouse_connect.get_async_client()`. See the +[run_async example](./examples/run_async.py) for more details. ### Complete Documentation diff --git a/contrib/python/clickhouse-connect/README.md b/contrib/python/clickhouse-connect/README.md index 7d0a5977159..243dac6c5c4 100644 --- a/contrib/python/clickhouse-connect/README.md +++ b/contrib/python/clickhouse-connect/README.md @@ -2,7 +2,7 @@ A high performance core database driver for connecting ClickHouse to Python, Pandas, and Superset -* Pandas DataFrames (numpy and arrow-backed). Pandas 2.x and above only, 1.x is deprecated and will be dropped in 1.0. +* Pandas DataFrames (numpy and arrow-backed). Requires pandas 2.0 or later. * Numpy Arrays * PyArrow Tables * Polars DataFrames @@ -17,8 +17,11 @@ ClickHouse Connect currently uses the ClickHouse HTTP interface for maximum comp pip install clickhouse-connect ``` -ClickHouse Connect requires Python 3.9 or higher. We officially test against Python 3.10 through 3.14. -Python 3.9 is deprecated and support will be removed entirely in 1.0. +ClickHouse Connect requires Python 3.10 or higher. + +#### Upgrading from 0.x + +The 1.0 release includes breaking changes. If you are upgrading from a 0.15.x or earlier release, see [MIGRATION.md](MIGRATION.md) for a guide to the changes and their replacements. ### Superset Connectivity @@ -53,13 +56,15 @@ are not implemented. The dialect is best suited for SQLAlchemy Core usage and Su ### Asyncio Support -ClickHouse Connect provides an `AsyncClient` for use in `asyncio` environments. -See the [run_async example](./examples/run_async.py) for more details. +ClickHouse Connect provides native async support using aiohttp. To use the async client, +install the optional async dependency: + +``` +pip install clickhouse-connect[async] +``` -The current `AsyncClient` is a thread-pool executor wrapper around the synchronous client and is deprecated. -In 1.0.0 it will be replaced by a fully native async implementation. The API surface is the same, -with one difference: you will no longer be able to create a sync client first and pass it to the -`AsyncClient` constructor. Instead, use `clickhouse_connect.get_async_client()` directly. +Then create a client with `clickhouse_connect.get_async_client()`. See the +[run_async example](./examples/run_async.py) for more details. ### Complete Documentation diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/__init__.py b/contrib/python/clickhouse-connect/clickhouse_connect/__init__.py index 64029f2132f..b61ccbf6b92 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/__init__.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/__init__.py @@ -1,19 +1,14 @@ import sys -import warnings -from clickhouse_connect.driver import create_client, create_async_client +if sys.version_info < (3, 10): # noqa: UP036 + raise RuntimeError("clickhouse-connect 1.0+ requires Python 3.10 or later. Python 3.9 users should pin to clickhouse-connect<1.0.") +from clickhouse_connect._version import version as __version__ +from clickhouse_connect.driver import create_async_client, create_client -if sys.version_info < (3, 10): - warnings.warn( - "Python 3.9 support is deprecated and will be removed in a future release. " - "This version of clickhouse-connect may stop working with Python 3.9 unexpectedly.", - DeprecationWarning, - stacklevel=2 - ) +__all__ = ["__version__", "driver_name", "get_client", "get_async_client"] - -driver_name = 'clickhousedb' +driver_name = "clickhousedb" get_client = create_client get_async_client = create_async_client diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/__version__.py b/contrib/python/clickhouse-connect/clickhouse_connect/__version__.py deleted file mode 100644 index 77e69eaa222..00000000000 --- a/contrib/python/clickhouse-connect/clickhouse_connect/__version__.py +++ /dev/null @@ -1 +0,0 @@ -version = "0.15.1" diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/_version.py b/contrib/python/clickhouse-connect/clickhouse_connect/_version.py new file mode 100644 index 00000000000..11a716ec1fe --- /dev/null +++ b/contrib/python/clickhouse-connect/clickhouse_connect/_version.py @@ -0,0 +1 @@ +version = "1.0.0" diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/__init__.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/__init__.py index 5d505cf7de2..9762c64a969 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/__init__.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/__init__.py @@ -1,10 +1,9 @@ from clickhouse_connect import driver_name from clickhouse_connect.cc_sqlalchemy.datatypes.base import schema_types from clickhouse_connect.cc_sqlalchemy.sql import final, sample -from clickhouse_connect.cc_sqlalchemy.sql.clauses import array_join, ArrayJoin, ch_join, ClickHouseJoin +from clickhouse_connect.cc_sqlalchemy.sql.clauses import ArrayJoin, ClickHouseJoin, array_join, ch_join -# pylint: disable=invalid-name dialect_name = driver_name ischema_names = schema_types -__all__ = ['dialect_name', 'ischema_names', 'array_join', 'ArrayJoin', 'ch_join', 'ClickHouseJoin', 'final', 'sample'] +__all__ = ["dialect_name", "ischema_names", "array_join", "ArrayJoin", "ch_join", "ClickHouseJoin", "final", "sample"] diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/__init__.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/__init__.py index f364badd886..a83eefd4c7f 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/__init__.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/__init__.py @@ -1 +1 @@ -import clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes +import clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes # noqa: F401 diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/base.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/base.py index d79b4e52a3c..a4c1d4a94f2 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/base.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/base.py @@ -1,9 +1,8 @@ import logging -from typing import Dict, Type, Optional from sqlalchemy.exc import CompileError -from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef, EMPTY_TYPE_DEF +from clickhouse_connect.datatypes.base import EMPTY_TYPE_DEF, ClickHouseType, TypeDef from clickhouse_connect.datatypes.registry import parse_name, type_map from clickhouse_connect.driver.binding import str_query_value @@ -15,11 +14,12 @@ class ChSqlaType: A SQLAlchemy TypeEngine that wraps a ClickHouseType. We don't extend TypeEngine directly, instead all concrete subclasses will inherit from TypeEngine. """ + ch_type: ClickHouseType = None generic_type: None _ch_type_cls = None _instance = None - _instance_cache: Dict[TypeDef, 'ChSqlaType'] = None + _instance_cache: dict[TypeDef, "ChSqlaType"] = None def __init_subclass__(cls): """ @@ -31,7 +31,7 @@ class ChSqlaType: try: cls._ch_type_cls = type_map[base] except KeyError: - logger.warning('Attempted to register SQLAlchemy type without corresponding ClickHouse Type') + logger.warning("Attempted to register SQLAlchemy type without corresponding ClickHouse Type") return schema_types.append(base) sqla_type_map[base] = cls @@ -106,8 +106,7 @@ class ChSqlaType: """ return self.name - # pylint: disable=unused-argument - def _with_collation(self, collation: Optional[str]) -> "ChSqlaType": + def _with_collation(self, collation: str | None) -> "ChSqlaType": """ SQLAlchemy 2.x compatibility: TypeEngine declares this abstract to support text types that can carry a collation. ClickHouse types in this dialect @@ -124,7 +123,7 @@ class CaseInsensitiveDict(dict): return super().__getitem__(item.lower()) -sqla_type_map: Dict[str, Type[ChSqlaType]] = CaseInsensitiveDict() +sqla_type_map: dict[str, type[ChSqlaType]] = CaseInsensitiveDict() schema_types = [] @@ -138,7 +137,7 @@ def sqla_type_from_name(name: str) -> ChSqlaType: try: type_cls = sqla_type_map[base] except KeyError: - err_str = f'Unrecognized ClickHouse type base: {base} name: {name}' + err_str = f"Unrecognized ClickHouse type base: {base} name: {name}" logger.error(err_str) raise CompileError(err_str) from KeyError return type_cls.build(type_def) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/sqltypes.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/sqltypes.py index 4f30b02c7a3..a11e4c1ee04 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/sqltypes.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/datatypes/sqltypes.py @@ -1,14 +1,32 @@ -import pytz +from collections.abc import Sequence from enum import Enum as PyEnum -from typing import Type, Union, Sequence -from sqlalchemy.types import Integer, Float, Numeric, Boolean as SqlaBoolean, \ - UserDefinedType, String as SqlaString, DateTime as SqlaDateTime, Date as SqlaDate, Interval from sqlalchemy.exc import ArgumentError +from sqlalchemy.types import ( + Boolean as SqlaBoolean, +) +from sqlalchemy.types import ( + Date as SqlaDate, +) +from sqlalchemy.types import ( + DateTime as SqlaDateTime, +) +from sqlalchemy.types import ( + Float, + Integer, + Interval, + Numeric, + UserDefinedType, +) +from sqlalchemy.types import ( + String as SqlaString, +) from clickhouse_connect.cc_sqlalchemy.datatypes.base import ChSqlaType -from clickhouse_connect.datatypes.base import TypeDef, NULLABLE_TYPE_DEF, LC_TYPE_DEF, EMPTY_TYPE_DEF -from clickhouse_connect.datatypes.numeric import Enum8 as ChEnum8, Enum16 as ChEnum16 +from clickhouse_connect.datatypes.base import EMPTY_TYPE_DEF, LC_TYPE_DEF, NULLABLE_TYPE_DEF, TypeDef +from clickhouse_connect.datatypes.numeric import Enum8 as ChEnum8 +from clickhouse_connect.datatypes.numeric import Enum16 as ChEnum16 +from clickhouse_connect.driver import tzutil from clickhouse_connect.driver.common import decimal_prec @@ -73,12 +91,11 @@ class Float64(ChSqlaType, Float): class Bool(ChSqlaType, SqlaBoolean): - def __init__(self, type_def: TypeDef = EMPTY_TYPE_DEF): + def __init__(self, type_def: TypeDef = EMPTY_TYPE_DEF, **kwargs): ChSqlaType.__init__(self, type_def) - SqlaBoolean.__init__(self) + SqlaBoolean.__init__(self, **kwargs) -# pylint: disable=too-many-ancestors class Boolean(Bool): pass @@ -100,14 +117,13 @@ class Decimal(ChSqlaType, Numeric): else: precision, scale = type_def.values elif not precision or scale < 0 or scale > precision: - raise ArgumentError('Invalid precision or scale for ClickHouse Decimal type') + raise ArgumentError("Invalid precision or scale for ClickHouse Decimal type") else: type_def = TypeDef(values=(precision, scale)) ChSqlaType.__init__(self, type_def) Numeric.__init__(self, precision, scale) -# pylint: disable=duplicate-code class Decimal32(Decimal): dec_size = 32 @@ -128,8 +144,13 @@ class Enum(ChSqlaType, UserDefinedType): _size = 16 python_type = str - def __init__(self, enum: Type[PyEnum] = None, keys: Sequence[str] = None, values: Sequence[int] = None, - type_def: TypeDef = None): + def __init__( + self, + enum: type[PyEnum] = None, + keys: Sequence[str] = None, + values: Sequence[int] = None, + type_def: TypeDef = None, + ): """ Construct a ClickHouse enum either from a Python Enum or parallel lists of keys and value. Note that Python enums do not support empty strings as keys, so the alternate keys/values must be used in that case @@ -143,7 +164,7 @@ class Enum(ChSqlaType, UserDefinedType): keys = [e.name for e in enum] values = [e.value for e in enum] self._validate(keys, values) - if self.__class__.__name__ == 'Enum': + if self.__class__.__name__ == "Enum": if max(values) <= 127 and min(values) >= -128: self._ch_type_cls = ChEnum8 else: @@ -155,15 +176,15 @@ class Enum(ChSqlaType, UserDefinedType): def _validate(cls, keys: Sequence, values: Sequence): bad_key = next((x for x in keys if not isinstance(x, str)), None) if bad_key: - raise ArgumentError(f'ClickHouse enum key {bad_key} is not a string') + raise ArgumentError(f"ClickHouse enum key {bad_key} is not a string") bad_value = next((x for x in values if not isinstance(x, int)), None) if bad_value: - raise ArgumentError(f'ClickHouse enum value {bad_value} is not an integer') + raise ArgumentError(f"ClickHouse enum value {bad_value} is not an integer") value_min = -(2 ** (cls._size - 1)) value_max = 2 ** (cls._size - 1) - 1 bad_value = next((x for x in values if x < value_min or x > value_max), None) if bad_value: - raise ArgumentError(f'Clickhouse enum value {bad_value} is out of range') + raise ArgumentError(f"Clickhouse enum value {bad_value} is out of range") class Enum8(Enum): @@ -239,12 +260,14 @@ class DateTime(ChSqlaType, SqlaDateTime): def __init__(self, tz: str = None, type_def: TypeDef = None): """ Date time constructor with optional ClickHouse timezone parameter if not constructed with TypeDef - :param tz: Timezone string as defined in pytz + :param tz: IANA timezone key (e.g. "UTC", "America/New_York"). Resolved via the standard + library zoneinfo module. On platforms without system zoneinfo data (notably + Windows), install the tzdata package. :param type_def: TypeDef from parse_name function """ if not type_def: if tz: - pytz.timezone(tz) + tzutil.resolve_zone(tz) type_def = TypeDef(values=(f"'{tz}'",)) else: type_def = EMPTY_TYPE_DEF @@ -257,23 +280,24 @@ class DateTime64(ChSqlaType, SqlaDateTime): """ Date time constructor with precision and timezone parameters if not constructed with TypeDef :param precision: Usually 3/6/9 for mill/micro/nanosecond precision on ClickHouse side - :param tz: Timezone string as defined in pytz + :param tz: IANA timezone key (e.g. "UTC", "America/New_York"). Resolved via the standard + library zoneinfo module. On platforms without system zoneinfo data (notably + Windows), install the tzdata package. :param type_def: TypeDef from parse_name function """ if not type_def: if tz: - pytz.timezone(tz) + tzutil.resolve_zone(tz) type_def = TypeDef(values=(precision, f"'{tz}'")) else: type_def = TypeDef(values=(precision,)) prec = type_def.values[0] if len(type_def.values) else None if not isinstance(prec, int) or prec < 0 or prec > 9: - raise ArgumentError(f'Invalid precision value {prec} for ClickHouse DateTime64') + raise ArgumentError(f"Invalid precision value {prec} for ClickHouse DateTime64") ChSqlaType.__init__(self, type_def) SqlaDateTime.__init__(self) -# pylint: disable=too-many-ancestors class Time(ChSqlaType, Interval): """ Represents the ClickHouse Time type, which corresponds to a timedelta. @@ -296,7 +320,6 @@ class Time(ChSqlaType, Interval): return None -# pylint: disable=too-many-ancestors class Time64(ChSqlaType, Interval): """ Represents the ClickHouse Time64 type with configurable precision. @@ -317,9 +340,7 @@ class Time64(ChSqlaType, Interval): precision = 3 if precision not in (3, 6, 9): - raise ArgumentError( - f"Invalid precision value {precision} for ClickHouse Time64. Must be 3, 6, or 9." - ) + raise ArgumentError(f"Invalid precision value {precision} for ClickHouse Time64. Must be 3, 6, or 9.") type_def = TypeDef(values=(precision,)) else: precision = type_def.values[0] if len(type_def.values) > 0 else 3 @@ -344,7 +365,7 @@ class Nullable: type with a Nullable wrapper """ - def __new__(cls, element: Union[ChSqlaType, Type[ChSqlaType]]): + def __new__(cls, element: ChSqlaType | type[ChSqlaType]): """ Actually returns an instance of the enclosed type with a Nullable wrapper. If element is an instance, constructs a new instance with a copied TypeDef plus the Nullable wrapper. If element is just a type, @@ -354,9 +375,9 @@ class Nullable: if callable(element): return element(type_def=NULLABLE_TYPE_DEF) if element.low_card: - raise ArgumentError('Low Cardinality type cannot be Nullable') + raise ArgumentError("Low Cardinality type cannot be Nullable") orig = element.type_def - wrappers = orig if 'Nullable' in orig.wrappers else orig.wrappers + ('Nullable',) + wrappers = orig if "Nullable" in orig.wrappers else orig.wrappers + ("Nullable",) return element.__class__(type_def=TypeDef(wrappers, orig.keys, orig.values)) @@ -366,24 +387,24 @@ class LowCardinality: type with a LowCardinality wrapper """ - def __new__(cls, element: Union[ChSqlaType, Type[ChSqlaType]]): + def __new__(cls, element: ChSqlaType | type[ChSqlaType]): + """ + Actually returns an instance of the enclosed type with a LowCardinality wrapper. If element is an instance, + constructs a new instance with a copied TypeDef plus the LowCardinality wrapper. If element is just a type, + constructs a new element of that type with only the LowCardinality wrapper. + :param element: ChSqlaType instance or class to wrap """ - Actually returns an instance of the enclosed type with a LowCardinality wrapper. If element is an instance, - constructs a new instance with a copied TypeDef plus the LowCardinality wrapper. If element is just a type, - constructs a new element of that type with only the LowCardinality wrapper. - :param element: ChSqlaType instance or class to wrap - """ if callable(element): return element(type_def=LC_TYPE_DEF) orig = element.type_def - wrappers = orig if 'LowCardinality' in orig.wrappers else ('LowCardinality',) + orig.wrappers + wrappers = orig if "LowCardinality" in orig.wrappers else ("LowCardinality",) + orig.wrappers return element.__class__(type_def=TypeDef(wrappers, orig.keys, orig.values)) class Array(ChSqlaType, UserDefinedType): python_type = list - def __init__(self, element: Union[ChSqlaType, Type[ChSqlaType]] = None, type_def: TypeDef = None): + def __init__(self, element: ChSqlaType | type[ChSqlaType] = None, type_def: TypeDef = None): """ Array constructor that can take a wrapped Array type if not constructed from a TypeDef :param element: ChSqlaType instance or class to wrap @@ -399,8 +420,12 @@ class Array(ChSqlaType, UserDefinedType): class Map(ChSqlaType, UserDefinedType): python_type = dict - def __init__(self, key_type: Union[ChSqlaType, Type[ChSqlaType]] = None, - value_type: Union[ChSqlaType, Type[ChSqlaType]] = None, type_def: TypeDef = None): + def __init__( + self, + key_type: ChSqlaType | type[ChSqlaType] = None, + value_type: ChSqlaType | type[ChSqlaType] = None, + type_def: TypeDef = None, + ): """ Map constructor that can take a wrapped key/values types if not constructed from a TypeDef :param key_type: ChSqlaType instance or class to use as keys for the Map @@ -419,12 +444,16 @@ class Map(ChSqlaType, UserDefinedType): class Tuple(ChSqlaType, UserDefinedType): python_type = tuple - def __init__(self, elements: Sequence[Union[ChSqlaType, Type[ChSqlaType]]] = None, type_def: TypeDef = None): + def __init__( + self, + elements: Sequence[ChSqlaType | type[ChSqlaType]] = None, + type_def: TypeDef = None, + ): + """ + Tuple constructor that can take a list of element types if not constructed from a TypeDef + :param elements: sequence of ChSqlaType instance or class to use as tuple element types + :param type_def: TypeDef from parse_name function """ - Tuple constructor that can take a list of element types if not constructed from a TypeDef - :param elements: sequence of ChSqlaType instance or class to use as tuple element types - :param type_def: TypeDef from parse_name function - """ if not type_def: values = [et() if callable(et) else et for et in elements] type_def = TypeDef(values=tuple(v.name for v in values)) @@ -435,6 +464,7 @@ class JSON(ChSqlaType, UserDefinedType): """ Note this isn't currently supported for insert/select, only table definitions """ + python_type = None @@ -442,14 +472,19 @@ class Nested(ChSqlaType, UserDefinedType): """ Note this isn't currently supported for insert/select, only table definitions """ - python_type = None + python_type = None class SimpleAggregateFunction(ChSqlaType, UserDefinedType): python_type = None - def __init__(self, name: str = None, element: Union[ChSqlaType, Type[ChSqlaType]] = None, type_def: TypeDef = None): + def __init__( + self, + name: str = None, + element: ChSqlaType | type[ChSqlaType] = None, + type_def: TypeDef = None, + ): """ Constructor that can take the SimpleAggregateFunction name and wrapped type if not constructed from a TypeDef :param name: Aggregate function name @@ -459,7 +494,12 @@ class SimpleAggregateFunction(ChSqlaType, UserDefinedType): if not type_def: if callable(element): element = element() - type_def = TypeDef(values=(name, element.name,)) + type_def = TypeDef( + values=( + name, + element.name, + ) + ) super().__init__(type_def) @@ -467,6 +507,7 @@ class AggregateFunction(ChSqlaType, UserDefinedType): """ Note this isn't currently supported for insert/select, only table definitions """ + python_type = None def __init__(self, *params, type_def: TypeDef = None): diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/ddl/custom.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/ddl/custom.py index 3b45b85641c..b29e8e6b26a 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/ddl/custom.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/ddl/custom.py @@ -1,17 +1,23 @@ -from sqlalchemy.sql.ddl import DDL from sqlalchemy.exc import ArgumentError +from sqlalchemy.sql.ddl import DDL -from clickhouse_connect.driver.binding import quote_identifier +from clickhouse_connect.driver.binding import format_str, quote_identifier -# pylint: disable=too-many-ancestors,abstract-method class CreateDatabase(DDL): """ SqlAlchemy DDL statement that is essentially an alternative to the built in CreateSchema DDL class """ - # pylint: disable-msg=too-many-arguments - def __init__(self, name: str, engine: str = None, zoo_path: str = None, shard_name: str = '{shard}', - replica_name: str = '{replica}', exists_ok: bool = False): + + def __init__( + self, + name: str, + engine: str = None, + zoo_path: str = None, + shard_name: str = "{shard}", + replica_name: str = "{replica}", + exists_ok: bool = False, + ): """ :param name: Database name :param engine: Database ClickHouse engine type @@ -19,22 +25,22 @@ class CreateDatabase(DDL): :param shard_name: Clickhouse shard name for Replicated database engine :param replica_name: Replica name for Replicated database engine """ - if engine and engine not in ('Ordinary', 'Atomic', 'Lazy', 'Replicated'): - raise ArgumentError(f'Unrecognized engine type {engine}') + if engine and engine not in ("Ordinary", "Atomic", "Lazy", "Replicated"): + raise ArgumentError(f"Unrecognized engine type {engine}") stmt = f"CREATE DATABASE {'IF NOT EXISTS ' if exists_ok else ''}{quote_identifier(name)}" if engine: - stmt += f' Engine {engine}' - if engine == 'Replicated': + stmt += f" Engine {engine}" + if engine == "Replicated": if not zoo_path: - raise ArgumentError('zoo_path is required for Replicated Database Engine') - stmt += f" ('{zoo_path}', '{shard_name}', '{replica_name}'" + raise ArgumentError("zoo_path is required for Replicated Database Engine") + stmt += f" ({format_str(zoo_path)}, {format_str(shard_name)}, {format_str(replica_name)})" super().__init__(stmt) -# pylint: disable=too-many-ancestors,abstract-method class DropDatabase(DDL): """ Alternative DDL statement for built in SqlAlchemy DropSchema DDL class """ + def __init__(self, name: str, missing_ok: bool = False): super().__init__(f"DROP DATABASE {'IF EXISTS ' if missing_ok else ''}{quote_identifier(name)}") diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/ddl/tableengine.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/ddl/tableengine.py index f035270a57c..9e5aaeba3a5 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/ddl/tableengine.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/ddl/tableengine.py @@ -1,5 +1,5 @@ import logging -from typing import Type, Sequence, Optional, Dict +from collections.abc import Sequence from sqlalchemy.exc import ArgumentError, SQLAlchemyError from sqlalchemy.sql.base import SchemaEventTarget @@ -7,7 +7,7 @@ from sqlalchemy.sql.visitors import Visitable logger = logging.getLogger(__name__) -engine_map: Dict[str, Type['TableEngine']] = {} +engine_map: dict[str, type["TableEngine"]] = {} def tuple_expr(expr_name, value): @@ -18,11 +18,11 @@ def tuple_expr(expr_name, value): :return: formatted parameter string """ if value is None: - return '' - v = f'{expr_name.strip()}' + return "" + v = f"{expr_name.strip()}" if isinstance(value, (tuple, list)): return f" {v} ({','.join(value)})" - return f'{v} {value}' + return f"{v} {value}" class TableEngine(SchemaEventTarget, Visitable): @@ -30,6 +30,7 @@ class TableEngine(SchemaEventTarget, Visitable): SqlAlchemy Schema element to support ClickHouse table engines. At the moment provides no real functionality other than the CREATE TABLE argument string """ + arg_names = () quoted_args = set() optional_args = set() @@ -39,40 +40,40 @@ class TableEngine(SchemaEventTarget, Visitable): engine_map[cls.__name__] = cls def __init__(self, kwargs): - # pylint: disable=no-value-for-parameter + Visitable.__init__(self) self.name = self.__class__.__name__ - te_name = f'{self.name} Table Engine' + te_name = f"{self.name} Table Engine" engine_args = [] for arg_name in self.arg_names: v = kwargs.pop(arg_name, None) if v is None: if arg_name in self.optional_args: continue - raise ValueError(f'Required engine parameter {arg_name} not provided for {te_name}') + raise ValueError(f"Required engine parameter {arg_name} not provided for {te_name}") if arg_name in self.quoted_args: engine_args.append(f"'{v}'") else: engine_args.append(v) if engine_args: - self.arg_str = f'({", ".join(engine_args)})' + self.arg_str = f"({', '.join(engine_args)})" params = [] for param_name in self.eng_params: v = kwargs.pop(param_name, None) if v is not None: - params.append(tuple_expr(param_name.upper().replace('_', ' '), v)) + params.append(tuple_expr(param_name.upper().replace("_", " "), v)) - self.full_engine = 'Engine ' + self.name + self.full_engine = "Engine " + self.name if engine_args: - self.full_engine += f'({", ".join(engine_args)})' + self.full_engine += f"({', '.join(engine_args)})" if params: - self.full_engine += ' ' + ' '.join(params) + self.full_engine += " " + " ".join(params) def compile(self): return self.full_engine def check_primary_keys(self, primary_keys: Sequence): - raise SQLAlchemyError(f'Table Engine {self.name} does not support primary keys') + raise SQLAlchemyError(f"Table Engine {self.name} does not support primary keys") def _set_parent(self, parent, **_kwargs): parent.engine = self @@ -103,47 +104,53 @@ class Set(TableEngine): class Dictionary(TableEngine): - arg_names = ['dictionary'] + arg_names = ["dictionary"] - # pylint: disable=unused-argument def __init__(self, dictionary: str = None): super().__init__(locals()) class Merge(TableEngine): - arg_names = ['db_name, tables_regexp'] + arg_names = ["db_name, tables_regexp"] - # pylint: disable=unused-argument def __init__(self, db_name: str = None, tables_regexp: str = None): super().__init__(locals()) class File(TableEngine): - arg_names = ['fmt'] + arg_names = ["fmt"] - # pylint: disable=unused-argument def __init__(self, fmt: str = None): super().__init__(locals()) class Distributed(TableEngine): - arg_names = ['cluster', 'database', 'table', 'sharding_key', 'policy_name'] - optional_args = {'sharding_key', 'policy_name'} + arg_names = ["cluster", "database", "table", "sharding_key", "policy_name"] + optional_args = {"sharding_key", "policy_name"} - # pylint: disable=unused-argument - def __init__(self, cluster: str = None, database: str = None, table=None, - sharding_key: str = None, policy_name: str = None): + def __init__( + self, + cluster: str = None, + database: str = None, + table=None, + sharding_key: str = None, + policy_name: str = None, + ): super().__init__(locals()) class MergeTree(TableEngine): eng_params = ["order_by", "partition_by", "primary_key", "sample_by"] - # pylint: disable=unused-argument - def __init__(self, order_by: str = None, primary_key: str = None, - partition_by: str = None, sample_by: str = None): + def __init__( + self, + order_by: str = None, + primary_key: str = None, + partition_by: str = None, + sample_by: str = None, + ): if not order_by and not primary_key: - raise ArgumentError(None, 'Either PRIMARY KEY or ORDER BY must be specified') + raise ArgumentError(None, "Either PRIMARY KEY or ORDER BY must be specified") super().__init__(locals()) @@ -160,66 +167,94 @@ class AggregatingMergeTree(MergeTree): class ReplacingMergeTree(TableEngine): - arg_names = ['ver'] + arg_names = ["ver"] optional_args = set(arg_names) eng_params = MergeTree.eng_params - # pylint: disable=unused-argument - def __init__(self, ver: str = None, order_by: str = None, primary_key: str = None, - partition_by: str = None, sample_by: str = None): + def __init__( + self, + ver: str = None, + order_by: str = None, + primary_key: str = None, + partition_by: str = None, + sample_by: str = None, + ): if not order_by and not primary_key: - raise ArgumentError(None, 'Either PRIMARY KEY or ORDER BY must be specified') + raise ArgumentError(None, "Either PRIMARY KEY or ORDER BY must be specified") super().__init__(locals()) class CollapsingMergeTree(TableEngine): - arg_names = ['sign'] + arg_names = ["sign"] eng_params = MergeTree.eng_params - # pylint: disable=unused-argument - def __init__(self, sign: str = None, order_by: str = None, primary_key: str = None, - partition_by: str = None, sample_by: str = None): + def __init__( + self, + sign: str = None, + order_by: str = None, + primary_key: str = None, + partition_by: str = None, + sample_by: str = None, + ): if not order_by and not primary_key: - raise ArgumentError(None, 'Either PRIMARY KEY or ORDER BY must be specified') + raise ArgumentError(None, "Either PRIMARY KEY or ORDER BY must be specified") super().__init__(locals()) class VersionedCollapsingMergeTree(TableEngine): - arg_names = ['sign', 'version'] + arg_names = ["sign", "version"] eng_params = MergeTree.eng_params - # pylint: disable=unused-argument - def __init__(self, sign: str = None, version: str = None, order_by: str = None, primary_key: str = None, - partition_by: str = None, sample_by: str = None): + def __init__( + self, + sign: str = None, + version: str = None, + order_by: str = None, + primary_key: str = None, + partition_by: str = None, + sample_by: str = None, + ): if not order_by and not primary_key: - raise ArgumentError(None, 'Either PRIMARY KEY or ORDER BY must be specified') + raise ArgumentError(None, "Either PRIMARY KEY or ORDER BY must be specified") super().__init__(locals()) class GraphiteMergeTree(TableEngine): - arg_names = ['config_section'] + arg_names = ["config_section"] quoted_args = set(arg_names) eng_params = MergeTree.eng_params - # pylint: disable=unused-argument - def __init__(self, config_section: str = None, version: str = None, order_by: str = None, primary_key: str = None, - partition_by: str = None, sample_by: str = None): + def __init__( + self, + config_section: str = None, + version: str = None, + order_by: str = None, + primary_key: str = None, + partition_by: str = None, + sample_by: str = None, + ): if not order_by and not primary_key: - raise ArgumentError(None, 'Either PRIMARY KEY or ORDER BY must be specified') + raise ArgumentError(None, "Either PRIMARY KEY or ORDER BY must be specified") super().__init__(locals()) class ReplicatedMergeTree(TableEngine): - arg_names = ['zk_path', 'replica'] + arg_names = ["zk_path", "replica"] quoted_args = set(arg_names) optional_args = quoted_args eng_params = MergeTree.eng_params - # pylint: disable=unused-argument - def __init__(self, order_by: str = None, primary_key: str = None, partition_by: str = None, sample_by: str = None, - zk_path: str = None, replica: str = None): + def __init__( + self, + order_by: str = None, + primary_key: str = None, + partition_by: str = None, + sample_by: str = None, + zk_path: str = None, + replica: str = None, + ): if not order_by and not primary_key: - raise ArgumentError(None, 'Either PRIMARY KEY or ORDER BY must be specified') + raise ArgumentError(None, "Either PRIMARY KEY or ORDER BY must be specified") super().__init__(locals()) @@ -237,7 +272,6 @@ class ReplicatedReplacingMergeTree(TableEngine): optional_args = {"zk_path", "replica", "ver"} eng_params = MergeTree.eng_params - # pylint: disable=unused-argument def __init__( self, ver: str = None, @@ -259,7 +293,6 @@ class ReplicatedCollapsingMergeTree(TableEngine): optional_args = {"zk_path", "replica"} eng_params = MergeTree.eng_params - # pylint: disable=unused-argument def __init__( self, sign: str = None, @@ -281,7 +314,6 @@ class ReplicatedVersionedCollapsingMergeTree(TableEngine): optional_args = {"zk_path", "replica"} eng_params = MergeTree.eng_params - # pylint: disable=unused-argument def __init__( self, sign: str = None, @@ -304,7 +336,6 @@ class ReplicatedGraphiteMergeTree(TableEngine): optional_args = {"zk_path", "replica"} eng_params = MergeTree.eng_params - # pylint: disable=unused-argument def __init__( self, config_section: str = None, @@ -340,7 +371,7 @@ class SharedGraphiteMergeTree(GraphiteMergeTree): pass -def build_engine(full_engine: str) -> Optional[TableEngine]: +def build_engine(full_engine: str) -> TableEngine | None: """ Factory function to create TableEngine class from ClickHouse full_engine expression :param full_engine @@ -348,12 +379,12 @@ def build_engine(full_engine: str) -> Optional[TableEngine]: """ if not full_engine: return None - name = full_engine.split(' ')[0].split('(')[0] + name = full_engine.split(" ")[0].split("(")[0] try: engine_cls = engine_map[name] except KeyError: - if not name.startswith('System'): - logger.warning('Engine %s not found', name) + if not name.startswith("System"): + logger.warning("Engine %s not found", name) return None engine = engine_cls.__new__(engine_cls) engine.name = name diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/dialect.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/dialect.py index 4a415914714..e6c8b9ddf49 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/dialect.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/dialect.py @@ -2,25 +2,24 @@ from sqlalchemy import text from sqlalchemy.engine.default import DefaultDialect from clickhouse_connect import dbapi - +from clickhouse_connect.cc_sqlalchemy import dialect_name, ischema_names from clickhouse_connect.cc_sqlalchemy.inspector import ChInspector from clickhouse_connect.cc_sqlalchemy.sql import full_table -from clickhouse_connect.cc_sqlalchemy.sql.ddlcompiler import ChDDLCompiler from clickhouse_connect.cc_sqlalchemy.sql.compiler import ChStatementCompiler -from clickhouse_connect.cc_sqlalchemy import ischema_names, dialect_name +from clickhouse_connect.cc_sqlalchemy.sql.ddlcompiler import ChDDLCompiler from clickhouse_connect.cc_sqlalchemy.sql.preparer import ChIdentifierPreparer -from clickhouse_connect.driver.binding import quote_identifier, format_str +from clickhouse_connect.driver.binding import format_str, quote_identifier -# pylint: disable=too-many-public-methods,no-self-use,unused-argument class ClickHouseDialect(DefaultDialect): """ See :py:class:`sqlalchemy.engine.interfaces` """ + name = dialect_name - driver = 'connect' + driver = "connect" - default_schema_name = 'default' + default_schema_name = "default" supports_native_decimal = True supports_native_boolean = True supports_statement_cache = False @@ -35,13 +34,13 @@ class ClickHouseDialect(DefaultDialect): inspector = ChInspector # SQA 1 compatibility - # pylint: disable=method-hidden + @classmethod def dbapi(cls): return dbapi # SQA 2 compatibility - # pylint: disable=method-hidden + @classmethod def import_dbapi(cls): return dbapi @@ -50,23 +49,21 @@ class ClickHouseDialect(DefaultDialect): pass def get_schema_names(self, connection, **_): - return [row.name for row in connection.execute('SHOW DATABASES')] + return [row.name for row in connection.execute(text("SHOW DATABASES"))] @staticmethod def has_database(connection, db_name): - return (connection.execute(text('SELECT name FROM system.databases ' + - f'WHERE name = {format_str(db_name)}'))).rowcount > 0 + return (connection.execute(text(f"SELECT name FROM system.databases WHERE name = {format_str(db_name)}"))).rowcount > 0 def get_table_names(self, connection, schema=None, **kw): - cmd = 'SHOW TABLES' + cmd = "SHOW TABLES" if schema: - cmd += ' FROM ' + quote_identifier(schema) - return [row.name for row in connection.execute(cmd)] + cmd += " FROM " + quote_identifier(schema) + return [row.name for row in connection.execute(text(cmd))] def get_primary_keys(self, connection, table_name, schema=None, **kw): return [] - # pylint: disable=arguments-renamed def get_pk_constraint(self, connection, table_name, schema=None, **kw): return [] @@ -95,14 +92,13 @@ class ClickHouseDialect(DefaultDialect): return [] def has_table(self, connection, table_name, schema=None, **_kw): - result = connection.execute(text(f'EXISTS TABLE {full_table(table_name, schema)}')) + result = connection.execute(text(f"EXISTS TABLE {full_table(table_name, schema)}")) row = result.fetchone() return row[0] == 1 def has_sequence(self, connection, sequence_name, schema=None, **_kw): return False - # pylint: disable=duplicate-code def do_begin_twophase(self, connection, xid): raise NotImplementedError diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/inspector.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/inspector.py index 53616d00f2a..1f0040e47dd 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/inspector.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/inspector.py @@ -1,55 +1,57 @@ import sqlalchemy.schema as sa_schema - +from sqlalchemy import text from sqlalchemy.engine.reflection import Inspector from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy import text +from clickhouse_connect.cc_sqlalchemy import dialect_name as dn from clickhouse_connect.cc_sqlalchemy.datatypes.base import sqla_type_from_name from clickhouse_connect.cc_sqlalchemy.ddl.tableengine import build_engine from clickhouse_connect.cc_sqlalchemy.sql import full_table -from clickhouse_connect.cc_sqlalchemy import dialect_name as dn -ch_col_args = ('default_type', 'codec_expression', 'ttl_expression') +ch_col_args = ("default_type", "codec_expression", "ttl_expression") def get_engine(connection, table_name, schema=None): - result_set = connection.execute(text( - f"SELECT engine_full FROM system.tables WHERE database = '{schema}' and name = '{table_name}'")) + result_set = connection.execute( + text("SELECT engine_full FROM system.tables WHERE database = :schema AND name = :table_name"), + {"schema": schema, "table_name": table_name}, + ) row = next(result_set, None) if not row: - raise NoResultFound(f'Table {schema}.{table_name} does not exist') + raise NoResultFound(f"Table {schema}.{table_name} does not exist") return build_engine(row.engine_full) class ChInspector(Inspector): - def reflect_table(self, table, include_columns, exclude_columns, *_args, **_kwargs): schema = table.schema for col in self.get_columns(table.name, schema): - name = col.pop('name') + name = col.pop("name") if (include_columns and name not in include_columns) or (exclude_columns and name in exclude_columns): continue - col_type = col.pop('type') - col_args = {f'{dn}_{key}' if key in ch_col_args else key: value for key, value in col.items() if value} + col_type = col.pop("type") + col_args = {f"{dn}_{key}" if key in ch_col_args else key: value for key, value in col.items() if value} table.append_column(sa_schema.Column(name, col_type, **col_args)) table.engine = get_engine(self.bind, table.name, schema) def get_columns(self, table_name, schema=None, **_kwargs): table_id = full_table(table_name, schema) - result_set = self.bind.execute(text(f'DESCRIBE TABLE {table_id}')) + result_set = self.bind.execute(text(f"DESCRIBE TABLE {table_id}")) if not result_set: - raise NoResultFound(f'Table {full_table} does not exist') + raise NoResultFound(f"Table {full_table} does not exist") columns = [] for row in result_set: - sqla_type = sqla_type_from_name(row.type.replace('\n', '')) - col = {'name': row.name, - 'type': sqla_type, - 'nullable': sqla_type.nullable, - 'autoincrement': False, - 'default': row.default_expression, - 'default_type': row.default_type, - 'comment': row.comment, - 'codec_expression': row.codec_expression, - 'ttl_expression': row.ttl_expression} + sqla_type = sqla_type_from_name(row.type.replace("\n", "")) + col = { + "name": row.name, + "type": sqla_type, + "nullable": sqla_type.nullable, + "autoincrement": False, + "default": row.default_expression, + "default_type": row.default_type, + "comment": row.comment, + "codec_expression": row.codec_expression, + "ttl_expression": row.ttl_expression, + } columns.append(col) return columns diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/__init__.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/__init__.py index f115d34f574..88db89e3aa3 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/__init__.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/__init__.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - from sqlalchemy import Table from sqlalchemy.sql.selectable import FromClause, Select @@ -10,17 +8,17 @@ from clickhouse_connect.driver.binding import quote_identifier _CH_MODIFIER_DIALECT = "_ch_modifier" -def full_table(table_name: str, schema: Optional[str] = None) -> str: - if table_name.startswith('(') or '.' in table_name or not schema: +def full_table(table_name: str, schema: str | None = None) -> str: + if table_name.startswith("(") or "." in table_name or not schema: return quote_identifier(table_name) - return f'{quote_identifier(schema)}.{quote_identifier(table_name)}' + return f"{quote_identifier(schema)}.{quote_identifier(table_name)}" def format_table(table: Table): return full_table(table.name, table.schema) -def _resolve_target(select_stmt: Select, table: Optional[FromClause], method_name: str) -> FromClause: +def _resolve_target(select_stmt: Select, table: FromClause | None, method_name: str) -> FromClause: """Resolve the target FROM clause for ClickHouse modifiers (FINAL/SAMPLE).""" if not isinstance(select_stmt, Select): raise TypeError(f"{method_name}() expects a SQLAlchemy Select instance") @@ -31,10 +29,7 @@ def _resolve_target(select_stmt: Select, table: Optional[FromClause], method_nam if not froms: raise ValueError(f"{method_name}() requires a table to apply the {method_name.upper()} modifier.") if len(froms) > 1: - raise ValueError( - f"{method_name}() is ambiguous for statements with multiple FROM clauses. " - "Specify the table explicitly." - ) + raise ValueError(f"{method_name}() is ambiguous for statements with multiple FROM clauses. Specify the table explicitly.") target = froms[0] if not isinstance(target, FromClause): @@ -50,8 +45,7 @@ def _target_cache_key(target: FromClause) -> str: return target.name -# pylint: disable=protected-access -def final(select_stmt: Select, table: Optional[FromClause] = None) -> Select: +def final(select_stmt: Select, table: FromClause | None = None) -> Select: """Apply the ClickHouse FINAL modifier to a select statement. FINAL forces ClickHouse to merge data parts before returning results, @@ -77,14 +71,14 @@ def final(select_stmt: Select, table: Optional[FromClause] = None) -> Select: return new_stmt -def _select_final(self: Select, table: Optional[FromClause] = None) -> Select: +def _select_final(self: Select, table: FromClause | None = None) -> Select: """ Select.final() convenience wrapper around the module-level final() helper. """ return final(self, table=table) -def sample(select_stmt: Select, sample_value: Union[str, int, float], table: Optional[FromClause] = None) -> Select: +def sample(select_stmt: Select, sample_value: str | int | float, table: FromClause | None = None) -> Select: """Apply the ClickHouse SAMPLE modifier to a select statement. Args: @@ -99,16 +93,14 @@ def sample(select_stmt: Select, sample_value: Union[str, int, float], table: Opt target = _resolve_target(select_stmt, table, "sample") hint_key = _target_cache_key(target) - new_stmt = select_stmt.with_statement_hint( - f"SAMPLE:{hint_key}:{sample_value}", dialect_name=_CH_MODIFIER_DIALECT - ) + new_stmt = select_stmt.with_statement_hint(f"SAMPLE:{hint_key}:{sample_value}", dialect_name=_CH_MODIFIER_DIALECT) ch_sample = dict(getattr(select_stmt, "_ch_sample", {})) ch_sample[target] = sample_value new_stmt._ch_sample = ch_sample return new_stmt -def _select_sample(self: Select, sample_value: Union[str, int, float], table: Optional[FromClause] = None) -> Select: +def _select_sample(self: Select, sample_value: str | int | float, table: FromClause | None = None) -> Select: """ Select.sample() convenience wrapper around the module-level sample() helper. """ diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/clauses.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/clauses.py index e7c164074ad..d6e5578b7a7 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/clauses.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/clauses.py @@ -1,5 +1,3 @@ -from typing import Optional - from sqlalchemy import and_, true from sqlalchemy.sql.base import Immutable from sqlalchemy.sql.selectable import FromClause, Join @@ -17,7 +15,7 @@ def _normalize_array_columns(array_column, alias): elif isinstance(alias, (list, tuple)): aliases = list(alias) if len(aliases) != len(columns): - raise ValueError(f"Length of alias list ({len(aliases)}) must match " f"length of array_column list ({len(columns)})") + raise ValueError(f"Length of alias list ({len(aliases)}) must match length of array_column list ({len(columns)})") else: raise ValueError("alias must be a list when array_column is a list") else: @@ -29,7 +27,6 @@ def _normalize_array_columns(array_column, alias): return list(zip(columns, aliases)) -# pylint: disable=protected-access,too-many-ancestors,abstract-method,unused-argument class ArrayJoin(Immutable, FromClause): """Represents ClickHouse ARRAY JOIN clause. @@ -91,6 +88,7 @@ class ArrayJoin(Immutable, FromClause): This ensures that when queries are cloned (e.g., for subqueries, unions, or CTEs), the left FromClause and array column references are properly deep-cloned. """ + def _default_clone(elem, **kwargs): return elem @@ -98,10 +96,7 @@ class ArrayJoin(Immutable, FromClause): clone = _default_clone self.left = clone(self.left, **kw) - self.array_columns = [ - (clone(col, **kw), alias) - for col, alias in self.array_columns - ] + self.array_columns = [(clone(col, **kw), alias) for col, alias in self.array_columns] def array_join(left, array_column, alias=None, is_left=False): @@ -195,7 +190,6 @@ def _build_using_onclause(left, right, using): return and_(*conditions) if len(conditions) > 1 else conditions[0] -# pylint: disable=too-many-ancestors,abstract-method class ClickHouseJoin(Join): """A SQLAlchemy Join subclass that supports ClickHouse-specific join features. @@ -231,8 +225,18 @@ class ClickHouseJoin(Join): ("using_columns", InternalTraversal.dp_string_list), ] - def __init__(self, left, right, onclause=None, isouter=False, full=False, - strictness=None, distribution=None, _is_cross=False, using=None): + def __init__( + self, + left, + right, + onclause=None, + isouter=False, + full=False, + strictness=None, + distribution=None, + _is_cross=False, + using=None, + ): if strictness is not None: strictness = strictness.upper() if distribution is not None: @@ -257,8 +261,8 @@ def ch_join( full=False, cross=False, using=None, - strictness: Optional[str] = None, - distribution: Optional[str] = None, + strictness: str | None = None, + distribution: str | None = None, ): """Create a ClickHouse JOIN with optional strictness, distribution, and USING support. @@ -286,5 +290,14 @@ def ch_join( if using is not None: raise ValueError("cross=True conflicts with using") onclause = true() - return ClickHouseJoin(left, right, onclause, isouter, full, - strictness, distribution, _is_cross=cross, using=using) + return ClickHouseJoin( + left, + right, + onclause, + isouter, + full, + strictness, + distribution, + _is_cross=cross, + using=using, + ) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/compiler.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/compiler.py index 65d1c2c2b2f..9443cee553e 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/compiler.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/compiler.py @@ -7,7 +7,6 @@ from clickhouse_connect.cc_sqlalchemy.datatypes.base import ChSqlaType from clickhouse_connect.cc_sqlalchemy.sql import format_table -# pylint: disable=too-many-return-statements def _resolve_ch_type_name(sqla_type): """Resolve a SQLAlchemy type instance to a ClickHouse type name string. @@ -40,10 +39,7 @@ def _resolve_ch_type_name(sqla_type): return "String" -# pylint: disable=arguments-differ class ChStatementCompiler(SQLCompiler): - - # pylint: disable=attribute-defined-outside-init,unused-argument def visit_delete(self, delete_stmt, visiting_cte=None, **kw): table = delete_stmt.table text = f"DELETE FROM {format_table(table)}" @@ -59,7 +55,6 @@ class ChStatementCompiler(SQLCompiler): return text - # pylint: disable=protected-access def visit_values(self, element, asfrom=False, from_linter=None, visiting_cte=None, **kw): """Compile a VALUES clause using ClickHouse's VALUES table function syntax. @@ -74,15 +69,12 @@ class ChStatementCompiler(SQLCompiler): if getattr(element, "_independent_ctes", None): self._dispatch_independent_ctes(element, kw) - structure = ", ".join( - f"{col.name} {_resolve_ch_type_name(col.type)}" - for col in element.columns - ) + structure = ", ".join(f"{col.name} {_resolve_ch_type_name(col.type)}" for col in element.columns) kw.setdefault("literal_binds", element.literal_binds) tuples = ", ".join( self.process( - elements.Tuple(types=element._column_types, *elem).self_group(), + elements.Tuple(types=element._column_types, *elem).self_group(), # noqa: B026 **kw, ) for chunk in element._data @@ -107,15 +99,11 @@ class ChStatementCompiler(SQLCompiler): if from_linter: # SA 2.x has _de_clone(); SA 1.4 doesn't key = element._de_clone() if hasattr(element, "_de_clone") else element - from_linter.froms[key] = ( - name if name is not None else "(unnamed VALUES element)" - ) + from_linter.froms[key] = name if name is not None else "(unnamed VALUES element)" if visiting_cte is not None and visiting_cte.element is element: if element._is_lateral: - raise CompileError( - "Can't use a LATERAL VALUES expression inside of a CTE" - ) + raise CompileError("Can't use a LATERAL VALUES expression inside of a CTE") v = f"SELECT * FROM {v}" elif name: kw["include_table"] = False @@ -201,7 +189,6 @@ class ChStatementCompiler(SQLCompiler): def update_from_clause(self, update_stmt, from_table, extra_froms, from_hints, **kw): raise NotImplementedError("ClickHouse doesn't support UPDATE with FROM clause") - # pylint: disable=unused-argument def visit_empty_set_expr(self, element_types, **kw): return "SELECT 1 WHERE 1=0" @@ -213,7 +200,6 @@ class ChStatementCompiler(SQLCompiler): kw["_ch_group_by"] = True return super().group_by_clause(select, **kw) - # pylint: disable=protected-access def visit_label( self, label, @@ -235,7 +221,6 @@ class ChStatementCompiler(SQLCompiler): **kw, ) - # pylint: disable=protected-access def _compose_select_body(self, text, select, compile_state, inner_columns, froms, byfrom, toplevel, kwargs): ch_final = getattr(select, "_ch_final", set()) ch_sample = getattr(select, "_ch_sample", {}) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/ddlcompiler.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/ddlcompiler.py index 8a2180c4662..e514b3ee95e 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/ddlcompiler.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/ddlcompiler.py @@ -1,24 +1,23 @@ from sqlalchemy import Column from sqlalchemy.sql.compiler import DDLCompiler -from clickhouse_connect.cc_sqlalchemy.sql import format_table +from clickhouse_connect.cc_sqlalchemy.sql import format_table from clickhouse_connect.driver.binding import quote_identifier class ChDDLCompiler(DDLCompiler): - def visit_create_schema(self, create, **_): - return f'CREATE DATABASE {quote_identifier(create.element)}' + return f"CREATE DATABASE {quote_identifier(create.element)}" def visit_drop_schema(self, drop, **_): - return f'DROP DATABASE {quote_identifier(drop.element)}' + return f"DROP DATABASE {quote_identifier(drop.element)}" def visit_create_table(self, create, **_): table = create.element - text = f'CREATE TABLE {format_table(table)} (' - text += ', '.join([self.get_column_specification(c.element) for c in create.columns]) - return text + ') ' + table.engine.compile() + text = f"CREATE TABLE {format_table(table)} (" + text += ", ".join([self.get_column_specification(c.element) for c in create.columns]) + return text + ") " + table.engine.compile() def get_column_specification(self, column: Column, **_): - text = f'{quote_identifier(column.name)} {column.type.compile()}' + text = f"{quote_identifier(column.name)} {column.type.compile()}" return text diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/preparer.py b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/preparer.py index f53a2bde371..b22c2b32c61 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/preparer.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/sql/preparer.py @@ -4,7 +4,6 @@ from clickhouse_connect.driver.binding import quote_identifier class ChIdentifierPreparer(IdentifierPreparer): - quote_identifier = staticmethod(quote_identifier) def _requires_quotes(self, _value): diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/common.py b/contrib/python/clickhouse-connect/clickhouse_connect/common.py index 04f43524105..8ee96e15ef4 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/common.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/common.py @@ -1,18 +1,19 @@ import getpass import sys +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Sequence, Optional, Dict +from typing import Any -from clickhouse_connect import __version__ +from clickhouse_connect._version import version as _version_string from clickhouse_connect.driver.exceptions import ProgrammingError def version() -> str: - return __version__.version + return _version_string def format_error(msg: str) -> str: - max_size = _common_settings['max_error_size'].value + max_size = _common_settings["max_error_size"].value if max_size: return msg[:max_size] return msg @@ -23,41 +24,40 @@ class CommonSetting: name: str options: Sequence[Any] default: Any - value: Optional[Any] = None + value: Any | None = None -_common_settings: Dict[str, CommonSetting] = {} +_common_settings: dict[str, CommonSetting] = {} -def build_client_name(client_name: str) -> str: - product_name = get_setting('product_name') - product_name = product_name.strip() + ' ' if product_name else '' - client_name = client_name.strip() + ' ' if client_name else '' - py_version = sys.version.split(' ', maxsplit=1)[0] - os_user = '' - if get_setting('send_os_user'): +def build_client_name(client_name: str | None) -> str: + product_name = get_setting("product_name") + product_name = product_name.strip() + " " if product_name else "" + client_name = client_name.strip() + " " if client_name else "" + py_version = sys.version.split(" ", maxsplit=1)[0] + os_user = "" + if get_setting("send_os_user"): try: - os_user = f'; os_user:{getpass.getuser()}' - except Exception: # pylint: disable=broad-except + os_user = f"; os_user:{getpass.getuser()}" + except Exception: pass - full_name = (f'{client_name}{product_name}clickhouse-connect/{version()}' + - f' (lv:py/{py_version}; mode:sync; os:{sys.platform}{os_user})') - return full_name.encode('ascii', 'ignore').decode() + full_name = f"{client_name}{product_name}clickhouse-connect/{version()} (lv:py/{py_version}; mode:sync; os:{sys.platform}{os_user})" + return full_name.encode("ascii", "ignore").decode() def get_setting(name: str) -> Any: setting = _common_settings.get(name) if setting is None: - raise ProgrammingError(f'Unrecognized common setting {name}') + raise ProgrammingError(f"Unrecognized common setting {name}") return setting.value if setting.value is not None else setting.default def set_setting(name: str, value: Any) -> None: setting = _common_settings.get(name) if setting is None: - raise ProgrammingError(f'Unrecognized common setting {name}') + raise ProgrammingError(f"Unrecognized common setting {name}") if setting.options and value not in setting.options: - raise ProgrammingError(f'Unrecognized option {value} for setting {name})') + raise ProgrammingError(f"Unrecognized option {value} for setting {name})") if value == setting.default: setting.value = None else: @@ -68,30 +68,25 @@ def _init_common(name: str, options: Sequence[Any], default: Any) -> None: _common_settings[name] = CommonSetting(name, options, default) -_init_common('autogenerate_session_id', (True, False), True) -_init_common('autogenerate_query_id', (True, False), True) -_init_common('dict_parameter_format', ('json', 'map'), 'json') -_init_common('invalid_setting_action', ('send', 'drop', 'error'), 'error') -_init_common('max_connection_age', (), 10 * 60) # Max time in seconds to keep reusing a database TCP connection -_init_common('product_name', (), '') # Product name used as part of client identification for ClickHouse query_log -_init_common('readonly', (0, 1), 0) # Implied "read_only" ClickHouse settings for versions prior to 19.17 -_init_common('send_os_user', (True, False), True) +_init_common("autogenerate_session_id", (True, False), True) +_init_common("autogenerate_query_id", (True, False), True) +_init_common("dict_parameter_format", ("json", "map"), "json") +_init_common("invalid_setting_action", ("send", "drop", "error"), "error") +_init_common("max_connection_age", (), 10 * 60) # Max time in seconds to keep reusing a database TCP connection +_init_common("product_name", (), "") # Product name used as part of client identification for ClickHouse query_log +_init_common("readonly", (0, 1), 0) # Implied "read_only" ClickHouse settings for versions prior to 19.17 +_init_common("send_os_user", (True, False), True) # Include integration tags (library name/version) in the User-Agent, e.g.: # pandas/2.2.5; polars/0.20.x; sqlalchemy/2.0.x. These tags are only included # when using relevant API methods. -_init_common('send_integration_tags', (True, False), True) +_init_common("send_integration_tags", (True, False), True) # Use the client protocol version This is needed for DateTime timezone columns but breaks with current version of # chproxy -_init_common('use_protocol_version', (True, False), True) +_init_common("use_protocol_version", (True, False), True) -_init_common('max_error_size', (), 1024) +_init_common("max_error_size", (), 1024) # HTTP raw data buffer for streaming queries. This should not be reduced below 64KB to ensure compatibility with LZ4 compression -_init_common('http_buffer_size', (), 10 * 1024 * 1024) - -# If True and using pandas 2.x, preserves the datetime64/timedelta64 -# dtype resolution (e.g., 's', 'ms', 'us', 'ns'). If False (or on -# pandas <2.x), coerces to nanosecond ('ns') resolution for compatibility. -_init_common('preserve_pandas_datetime_resolution', (True, False), False) +_init_common("http_buffer_size", (), 10 * 1024 * 1024) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/__init__.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/__init__.py index 09d640e2c05..855869e7c86 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/__init__.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/__init__.py @@ -1,11 +1,11 @@ import clickhouse_connect.datatypes.container +import clickhouse_connect.datatypes.dynamic +import clickhouse_connect.datatypes.geometric import clickhouse_connect.datatypes.network import clickhouse_connect.datatypes.numeric +import clickhouse_connect.datatypes.postinit +import clickhouse_connect.datatypes.registry import clickhouse_connect.datatypes.special import clickhouse_connect.datatypes.string import clickhouse_connect.datatypes.temporal -import clickhouse_connect.datatypes.geometric -import clickhouse_connect.datatypes.vector -import clickhouse_connect.datatypes.dynamic -import clickhouse_connect.datatypes.registry -import clickhouse_connect.datatypes.postinit +import clickhouse_connect.datatypes.vector # noqa: F401 diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/base.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/base.py index 521ff2d1fbc..220a2dda579 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/base.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/base.py @@ -1,19 +1,19 @@ import array import logging - from abc import ABC +from collections.abc import Collection, MutableSequence, Sequence from math import log -from typing import NamedTuple, Dict, Type, Any, Sequence, MutableSequence, Union, Collection +from typing import Any, NamedTuple -from clickhouse_connect.driver.common import array_type, int_size, write_array, write_uint64, low_card_version -from clickhouse_connect.driver.context import BaseQueryContext from clickhouse_connect.driver import ctypes as driver_ctypes +from clickhouse_connect.driver import options +from clickhouse_connect.driver.common import array_type, int_size, low_card_version, write_array, write_uint64 +from clickhouse_connect.driver.context import BaseQueryContext from clickhouse_connect.driver.ctypes import data_conv from clickhouse_connect.driver.exceptions import NotSupportedError from clickhouse_connect.driver.insert import InsertContext from clickhouse_connect.driver.query import QueryContext from clickhouse_connect.driver.types import ByteSource -from clickhouse_connect.driver import options logger = logging.getLogger(__name__) ch_read_formats = {} @@ -24,46 +24,56 @@ class TypeDef(NamedTuple): """ Immutable tuple that contains all additional information needed to construct a particular ClickHouseType """ + wrappers: tuple = () keys: tuple = () values: tuple = () @property def arg_str(self): - return f"({', '.join(str(v) for v in self.values)})" if self.values else '' + return f"({', '.join(str(v) for v in self.values)})" if self.values else "" -class ClickHouseType(ABC): +class ClickHouseType(ABC): # noqa: B024 """ Base class for all ClickHouseType objects. """ - __slots__ = 'nullable', 'low_card', 'wrappers', 'type_def', '__dict__' - _name_suffix = '' - encoding = 'utf8' - np_type = 'O' # Default to Numpy Object type + + __slots__ = "nullable", "low_card", "wrappers", "type_def", "__dict__" + _name_suffix = "" + encoding = "utf8" + np_type = "O" # Default to Numpy Object type nano_divisor = 0 # Only relevant for date like objects - pd_datetime_res = "ns" # Default date-like resolution for pd byte_size = 0 - valid_formats = 'native' + valid_formats = "native" python_type = None base_type = None + @property + def _null_time_unit(self): + """Extract the time unit from np_type, e.g. 'datetime64[s]' -> 's'.""" + start = self.np_type.find("[") + end = self.np_type.find("]") + if start != -1 and end != -1: + return self.np_type[start + 1 : end] + return "ns" + def __init_subclass__(cls, registered: bool = True): if registered: cls.base_type = cls.__name__ type_map[cls.base_type] = cls @classmethod - def build(cls: Type['ClickHouseType'], type_def: TypeDef): + def build(cls: type["ClickHouseType"], type_def: TypeDef): return cls(type_def) @classmethod - def _active_format(cls, fmt_map: Dict[Type['ClickHouseType'], str], ctx: BaseQueryContext): + def _active_format(cls, fmt_map: dict[type["ClickHouseType"], str], ctx: BaseQueryContext): ctx_fmt = ctx.active_fmt(cls.base_type) if ctx_fmt: return ctx_fmt - return fmt_map.get(cls, 'native') + return fmt_map.get(cls, "native") @classmethod def read_format(cls, ctx: BaseQueryContext): @@ -80,8 +90,8 @@ class ClickHouseType(ABC): """ self.type_def = type_def self.wrappers = type_def.wrappers - self.low_card = 'LowCardinality' in self.wrappers - self.nullable = 'Nullable' in self.wrappers + self.low_card = "LowCardinality" in self.wrappers + self.nullable = "Nullable" in self.wrappers def __eq__(self, other): return other.__class__ == self.__class__ and self.type_def == other.type_def @@ -91,9 +101,9 @@ class ClickHouseType(ABC): @property def name(self): - name = f'{self.base_type}{self._name_suffix}' + name = f"{self.base_type}{self._name_suffix}" for wrapper in reversed(self.wrappers): - name = f'{wrapper}({name})' + name = f"{wrapper}({name})" return name @property @@ -138,7 +148,7 @@ class ClickHouseType(ABC): if self.low_card: v = source.read_uint64() if v != low_card_version: - logger.warning('Unexpected low cardinality version %d reading type %s', v, self.name) + logger.warning("Unexpected low cardinality version %d reading type %s", v, self.name) return v return None @@ -180,11 +190,13 @@ class ClickHouseType(ABC): # The binary methods are really abstract, but they aren't implemented for container classes which # delegate binary operations to their elements - # pylint: disable=no-self-use - def _read_column_binary(self, - _source: ByteSource, - _num_rows: int, _ctx: QueryContext, - _read_state: Any) -> Union[Sequence, MutableSequence]: + def _read_column_binary( + self, + _source: ByteSource, + _num_rows: int, + _ctx: QueryContext, + _read_state: Any, + ) -> Sequence | MutableSequence: """ Lowest level read method for ClickHouseType native data columns :param _source: Native protocol binary read buffer @@ -196,7 +208,7 @@ class ClickHouseType(ABC): def _finalize_column(self, column: Sequence, _ctx: QueryContext) -> Sequence: return column - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): # noqa: B027 """ Lowest level write method for ClickHouseType data columns :param column: Python data column @@ -230,12 +242,11 @@ class ClickHouseType(ABC): dest += bytes([1 if x is None else 0 for x in column]) self._write_column_binary(column, dest, ctx) - # pylint: disable=no-member def _read_low_card_column(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any): if num_rows == 0: return [] key_data = source.read_uint64() - key_sz = 2 ** (key_data & 0xff) + key_sz = 2 ** (key_data & 0xFF) index_cnt = source.read_uint64() index = self._read_column_binary(source, index_cnt, ctx, read_state) key_cnt = source.read_uint64() @@ -295,9 +306,9 @@ class ClickHouseType(ABC): EMPTY_TYPE_DEF = TypeDef() -NULLABLE_TYPE_DEF = TypeDef(wrappers=('Nullable',)) -LC_TYPE_DEF = TypeDef(wrappers=('LowCardinality',)) -type_map: Dict[str, Type[ClickHouseType]] = {} +NULLABLE_TYPE_DEF = TypeDef(wrappers=("Nullable",)) +LC_TYPE_DEF = TypeDef(wrappers=("LowCardinality",)) +type_map: dict[str, type[ClickHouseType]] = {} class ArrayType(ClickHouseType, ABC, registered=False): @@ -306,18 +317,19 @@ class ArrayType(ClickHouseType, ABC, registered=False): arrays can only be used for ClickHouse types that can be translated into UInt64 (and smaller) integers or Float32/64 """ + _signed = True _array_type = None _struct_type = None - valid_formats = 'string', 'native' + valid_formats = "string", "native" python_type = int def __init_subclass__(cls, registered: bool = True): super().__init_subclass__(registered) - if cls._array_type in ('i', 'I') and int_size == 2: - cls._array_type = 'L' if cls._array_type.isupper() else 'l' + if cls._array_type in ("i", "I") and int_size == 2: + cls._array_type = "L" if cls._array_type.isupper() else "l" if isinstance(cls._array_type, str) and cls._array_type: - cls._struct_type = '<' + cls._array_type + cls._struct_type = "<" + cls._array_type cls.byte_size = array.array(cls._array_type).itemsize def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): @@ -334,7 +346,7 @@ class ArrayType(ClickHouseType, ABC, registered=False): return super()._build_lc_column(index, keys, ctx) def _finalize_column(self, column: Sequence, ctx: QueryContext) -> Sequence: - if self.read_format(ctx) == 'string': + if self.read_format(ctx) == "string": return [str(x) for x in column] if ctx.use_extended_dtypes and self.nullable: return options.pd.array(column, dtype=self.base_type) @@ -342,7 +354,7 @@ class ArrayType(ClickHouseType, ABC, registered=False): return options.np.array(column, dtype=self.np_type) return column - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): if len(column) and self.nullable: column = [0 if x is None else x for x in column] write_array(self._array_type, column, dest, ctx.column_name) @@ -360,12 +372,13 @@ class UnsupportedType(ClickHouseType, ABC, registered=False): Base class for ClickHouse types that can't be serialized/deserialized into Python types. Mostly useful just for DDL statements """ + def __init__(self, type_def: TypeDef): super().__init__(type_def) self._name_suffix = type_def.arg_str def _read_column_binary(self, source: Sequence, num_rows: int, ctx: QueryContext, read_state: Any): - raise NotSupportedError(f'{self.name} deserialization not supported') + raise NotSupportedError(f"{self.name} deserialization not supported") - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): - raise NotSupportedError(f'{self.name} serialization not supported') + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): + raise NotSupportedError(f"{self.name} serialization not supported") diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/container.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/container.py index 5908c75765b..88e87540695 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/container.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/container.py @@ -1,21 +1,23 @@ import array import logging -from typing import Sequence, Collection, Any +from collections.abc import Collection, Sequence +from typing import Any +from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef +from clickhouse_connect.datatypes.registry import get_from_name +from clickhouse_connect.driver.binding import quote_identifier +from clickhouse_connect.driver.common import first_value, must_swap +from clickhouse_connect.driver.ctypes import data_conv from clickhouse_connect.driver.insert import InsertContext from clickhouse_connect.driver.query import QueryContext -from clickhouse_connect.driver.binding import quote_identifier from clickhouse_connect.driver.types import ByteSource from clickhouse_connect.json_impl import any_to_json -from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef -from clickhouse_connect.driver.common import must_swap, first_value -from clickhouse_connect.datatypes.registry import get_from_name logger = logging.getLogger(__name__) class Array(ClickHouseType): - __slots__ = ('element_type', '_insert_name') + __slots__ = ("element_type", "_insert_name") python_type = list @property @@ -25,8 +27,8 @@ class Array(ClickHouseType): def __init__(self, type_def: TypeDef): super().__init__(type_def) self.element_type = get_from_name(type_def.values[0]) - self._name_suffix = f'({self.element_type.name})' - self._insert_name = f'Array({self.element_type.insert_name})' + self._name_suffix = f"({self.element_type.name})" + self._insert_name = f"Array({self.element_type.insert_name})" def read_column_prefix(self, source: ByteSource, ctx: QueryContext): return self.element_type.read_column_prefix(source, ctx) @@ -39,7 +41,6 @@ class Array(ClickHouseType): total += self.element_type.data_size(x) return total // len(sample) + 8 - # pylint: disable=too-many-locals def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any): final_type = self.element_type depth = 1 @@ -49,7 +50,7 @@ class Array(ClickHouseType): level_size = num_rows offset_sizes = [] for _ in range(depth): - level_offsets = source.read_array('Q', level_size) + level_offsets = source.read_array("Q", level_size) offset_sizes.append(level_offsets) level_size = level_offsets[-1] if level_offsets else 0 if level_size: @@ -61,7 +62,7 @@ class Array(ClickHouseType): data = [] last = 0 for x in offset_range: - data.append(column[last: x]) + data.append(column[last:x]) last = x column = data return column @@ -78,7 +79,7 @@ class Array(ClickHouseType): for _ in range(depth): total = 0 data = [] - offsets = array.array('Q') + offsets = array.array("Q") for x in column: total += len(x) offsets.append(total) @@ -91,9 +92,9 @@ class Array(ClickHouseType): class Tuple(ClickHouseType): - _slots = 'element_names', 'element_types', '_insert_name' + _slots = "element_names", "element_types", "_insert_name" python_type = tuple - valid_formats = 'tuple', 'dict', 'json', 'native' # native is 'tuple' for unnamed tuples, and dict for named tuples + valid_formats = "tuple", "dict", "json", "native" # native is 'tuple' for unnamed tuples, and dict for named tuples @property def insert_name(self): @@ -108,8 +109,9 @@ class Tuple(ClickHouseType): else: self._name_suffix = type_def.arg_str if self.element_names: - self._insert_name = \ + self._insert_name = ( f"Tuple({', '.join(quote_identifier(k) + ' ' + v.insert_name for k, v in zip(type_def.keys, self.element_types))})" + ) else: self._insert_name = f"Tuple({', '.join(v.insert_name for v in self.element_types)})" @@ -136,12 +138,12 @@ class Tuple(ClickHouseType): for ix, e_type in enumerate(self.element_types): column = e_type.read_column_data(source, num_rows, ctx, read_state[ix]) columns.append(column) - if e_names and self.read_format(ctx) != 'tuple': + if e_names and self.read_format(ctx) != "tuple": dicts = [{} for _ in range(num_rows)] for ix, x in enumerate(dicts): for y, key in enumerate(e_names): x[key] = columns[y][ix] - if self.read_format(ctx) == 'json': + if self.read_format(ctx) == "json": to_json = any_to_json return [to_json(x) for x in dicts] return dicts @@ -169,7 +171,7 @@ class Tuple(ClickHouseType): class Map(ClickHouseType): - _slots = 'key_type', 'value_type', '_insert_name' + _slots = "key_type", "value_type", "_insert_name" python_type = dict @property @@ -181,7 +183,7 @@ class Map(ClickHouseType): self.key_type = get_from_name(type_def.values[0]) self.value_type = get_from_name(type_def.values[1]) self._name_suffix = type_def.arg_str - self._insert_name = f'Map({self.key_type.insert_name}, {self.value_type.insert_name})' + self._insert_name = f"Map({self.key_type.insert_name}, {self.value_type.insert_name})" def _data_size(self, sample: Collection) -> int: total = 0 @@ -197,19 +199,16 @@ class Map(ClickHouseType): value_state = self.value_type.read_column_prefix(source, ctx) return key_state, value_state - # pylint: disable=too-many-locals def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any): - offsets = source.read_array('Q', num_rows) + offsets = source.read_array("Q", num_rows) total_rows = 0 if len(offsets) == 0 else offsets[-1] keys = self.key_type.read_column_data(source, total_rows, ctx, read_state[0]) values = self.value_type.read_column_data(source, total_rows, ctx, read_state[1]) - all_pairs = tuple(zip(keys, values)) column = [] - app = column.append - last = 0 + prev = 0 for offset in offsets: - app(dict(all_pairs[last: offset])) - last = offset + column.append(dict(zip(keys[prev:offset], values[prev:offset]))) + prev = offset return column def write_column_prefix(self, dest: bytearray): @@ -217,24 +216,13 @@ class Map(ClickHouseType): self.value_type.write_column_prefix(dest) def write_column_data(self, column: Sequence, dest: bytearray, ctx: InsertContext): - offsets = array.array('Q') - keys = [] - values = [] - total = 0 - for v in column: - total += len(v) - offsets.append(total) - keys.extend(v.keys()) - values.extend(v.values()) - if must_swap: - offsets.byteswap() - dest += offsets.tobytes() + keys, values = data_conv.build_map_columns(column, dest) self.key_type.write_column_data(keys, dest, ctx) self.value_type.write_column_data(values, dest, ctx) class Nested(ClickHouseType): - __slots__ = 'tuple_array', 'element_names', 'element_types' + __slots__ = "tuple_array", "element_names", "element_types" python_type = Sequence[dict] def __init__(self, type_def): @@ -242,7 +230,7 @@ class Nested(ClickHouseType): self.element_names = type_def.keys self.tuple_array = get_from_name(f"Array(Tuple({','.join(type_def.values)}))") self.element_types = self.tuple_array.element_type.element_types - cols = [f'{x[0]} {x[1].name}' for x in zip(type_def.keys, self.element_types)] + cols = [f"{x[0]} {x[1].name}" for x in zip(type_def.keys, self.element_types)] self._name_suffix = f"({', '.join(cols)})" def _data_size(self, sample: Collection) -> int: diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/dynamic.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/dynamic.py index 1b6668d096b..150fa1ced1f 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/dynamic.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/dynamic.py @@ -1,13 +1,14 @@ import logging from collections import namedtuple -from typing import List, Tuple, Sequence, Collection, Any +from collections.abc import Collection, Sequence +from typing import Any from urllib.parse import unquote from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef from clickhouse_connect.datatypes.registry import get_from_name from clickhouse_connect.datatypes.string import String from clickhouse_connect.driver.bytesource import ByteArraySource -from clickhouse_connect.driver.common import unescape_identifier, first_value, write_uint64 +from clickhouse_connect.driver.common import first_value, unescape_identifier, write_uint64 from clickhouse_connect.driver.ctypes import data_conv from clickhouse_connect.driver.errors import handle_error from clickhouse_connect.driver.exceptions import DataError, InternalError @@ -19,19 +20,19 @@ from clickhouse_connect.json_impl import any_to_json SHARED_DATA_TYPE: ClickHouseType STRING_DATA_TYPE: ClickHouseType SHARED_VARIANT_TYPE: ClickHouseType -_JSON_NULL = b'null' -_JSON_NULL_STR = 'null' +_JSON_NULL = b"null" +_JSON_NULL_STR = "null" logger = logging.getLogger(__name__) json_serialization_format = 0x1 -VariantState = namedtuple('VariantState', 'discriminator_mode element_states') +VariantState = namedtuple("VariantState", "discriminator_mode element_states") -def _json_path_segments(path: str) -> List[str]: - segments = path.split('.') - if '%' in path: +def _json_path_segments(path: str) -> list[str]: + segments = path.split(".") + if "%" in path: return [unquote(segment) for segment in segments] return segments @@ -54,7 +55,7 @@ class SharedDataString(String): return source.read_str_col(num_rows, None) -TypedVariant = namedtuple('TypedVariant', 'value type_name') +TypedVariant = namedtuple("TypedVariant", "value type_name") def typed_variant(value: Any, type_name: str) -> TypedVariant: @@ -78,7 +79,7 @@ def typed_variant(value: Any, type_name: str) -> TypedVariant: client.insert('my_table', data, column_names=['variant_col']) """ if value is None: - raise DataError('Use None directly instead of typed_variant for null Variant values') + raise DataError("Use None directly instead of typed_variant for null Variant values") try: return TypedVariant(value, get_from_name(type_name).name) except InternalError: @@ -86,12 +87,12 @@ def typed_variant(value: Any, type_name: str) -> TypedVariant: class Variant(ClickHouseType): - __slots__ = ('element_types', '_python_map', '_name_index') + __slots__ = ("element_types", "_python_map", "_name_index") python_type = object def __init__(self, type_def: TypeDef): super().__init__(type_def) - self.element_types: List[ClickHouseType] = [get_from_name(name) for name in type_def.values] + self.element_types: list[ClickHouseType] = [get_from_name(name) for name in type_def.values] self._name_suffix = f"({', '.join(ch_type.name for ch_type in self.element_types)})" self._build_dispatch() @@ -109,7 +110,7 @@ class Variant(ClickHouseType): self._python_map = {pt: idx for pt, idx in seen.items() if pt not in collisions} self._name_index = {etype.name: i for i, etype in enumerate(self.element_types)} - def _resolve_disc(self, v: Any) -> Tuple[int, Any]: + def _resolve_disc(self, v: Any) -> tuple[int, Any]: if isinstance(v, TypedVariant): idx = self._name_index.get(v.type_name) if idx is None: @@ -126,8 +127,7 @@ class Variant(ClickHouseType): element_states = [e_type.read_column_prefix(source, ctx) for e_type in self.element_types] return VariantState(discriminator_mode, element_states) - def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, - read_state: VariantState) -> Sequence: + def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: VariantState) -> Sequence: return read_variant_column(source, num_rows, ctx, self.element_types, read_state.element_states) def write_column_prefix(self, dest: bytearray): @@ -136,7 +136,7 @@ class Variant(ClickHouseType): e_type.write_column_prefix(dest) def write_column_data(self, column: Sequence, dest: bytearray, ctx: InsertContext): - sub_columns: List[list] = [[] for _ in range(len(self.element_types))] + sub_columns: list[list] = [[] for _ in range(len(self.element_types))] discriminators = bytearray() for v in column: if v is None: @@ -175,19 +175,21 @@ class Variant(ClickHouseType): return (total_data_size // len(sample)) + 1 -def read_variant_column(source: ByteSource, - num_rows: int, - ctx: QueryContext, - variant_types: List[ClickHouseType], - element_states: List[Any]) -> Sequence: +def read_variant_column( + source: ByteSource, + num_rows: int, + ctx: QueryContext, + variant_types: list[ClickHouseType], + element_states: list[Any], +) -> Sequence: v_count = len(variant_types) - discriminators = source.read_array('B', num_rows) + discriminators = source.read_array("B", num_rows) # We have to count up how many of each discriminator there are in the block to read the sub columns correctly disc_rows = [0] * v_count for disc in discriminators: if disc != 255: disc_rows[disc] += 1 - sub_columns: List[Sequence] = [[]] * v_count + sub_columns: list[Sequence] = [[]] * v_count # Read all the sub-columns for ix in range(v_count): if disc_rows[ix] > 0: @@ -206,7 +208,7 @@ def read_variant_column(source: ByteSource, return col -DynamicState = namedtuple('DynamicState', 'struct_version variant_types variant_states') +DynamicState = namedtuple("DynamicState", "struct_version variant_types variant_states") def read_dynamic_prefix(_, source: ByteSource, ctx: QueryContext) -> DynamicState: @@ -214,12 +216,14 @@ def read_dynamic_prefix(_, source: ByteSource, ctx: QueryContext) -> DynamicStat if struct_version == 1: source.read_leb128() # max dynamic types, we ignore this value elif struct_version != 2: - raise DataError('Unrecognized dynamic structure version') + raise DataError("Unrecognized dynamic structure version") num_variants = source.read_leb128() variant_types = [get_from_name(source.read_leb128_str()) for _ in range(num_variants)] - variant_types.append(SHARED_VARIANT_TYPE) + variant_types.append(SHARED_VARIANT_TYPE) # noqa: F821 (undefined-name) + # replicate the sort after appending SharedVariant + variant_types.sort(key=lambda t: t.name) if source.read_uint64() != 0: # discriminator format, currently only 0 is recognized - raise DataError('Unexpected discriminator format in Variant column prefix') + raise DataError("Unexpected discriminator format in Variant column prefix") variant_states = [e_type.read_column_prefix(source, ctx) for e_type in variant_types] return DynamicState(struct_version, variant_types, variant_states) @@ -230,15 +234,20 @@ class Dynamic(ClickHouseType): @property def insert_name(self): - return 'String' + return "String" def __init__(self, type_def: TypeDef): super().__init__(type_def) - if type_def.keys and type_def.keys[0] == 'max_types': - self._name_suffix = f'(max_types={type_def.values[0]})' + if type_def.keys and type_def.keys[0] == "max_types": + self._name_suffix = f"(max_types={type_def.values[0]})" - def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, - read_state: DynamicState) -> Sequence: + def _read_column_binary( + self, + source: ByteSource, + num_rows: int, + ctx: QueryContext, + read_state: DynamicState, + ) -> Sequence: return read_variant_column(source, num_rows, ctx, read_state.variant_types, read_state.variant_states) def write_column_data(self, column: Sequence, dest: bytearray, ctx: InsertContext): @@ -264,7 +273,7 @@ def write_json(ch_type: ClickHouseType, column: Sequence, dest: bytearray, ctx: first = first_value(column, ch_type.nullable) write_col = column encoding = ctx.encoding or ch_type.encoding - if not isinstance(first, str) and ch_type.write_format(ctx) != 'string': + if not isinstance(first, str) and ch_type.write_format(ctx) != "string": to_json = any_to_json if ch_type.nullable: write_col = [_JSON_NULL if v is None else to_json(v) for v in column] @@ -279,10 +288,10 @@ def write_json(ch_type: ClickHouseType, column: Sequence, dest: bytearray, ctx: def write_str_values(ch_type: ClickHouseType, column: Sequence, dest: bytearray, ctx: InsertContext): encoding = ctx.encoding or ch_type.encoding - col = [''] * len(column) + col = [""] * len(column) for ix, v in enumerate(column): if v is None: - col[ix] = 'NULL' + col[ix] = "NULL" else: col[ix] = str(v) handle_error(data_conv.write_str_col(col, False, encoding, dest), ctx) @@ -315,36 +324,36 @@ STANDARD_DISCRIMINATOR_TYPES = { # Known fixed payload sizes for BinaryTypeIndex values outside STANDARD_DISCRIMINATOR_TYPES. # Used to validate variant-encoded data in the printable ASCII overlap range (0x20+). _EXTENDED_PAYLOAD_SIZE = { - 0x0F: 2, # Date (UInt16) - 0x10: 4, # Date32 (Int32) - 0x11: 4, # DateTimeUTC (UInt32) - 0x13: 8, # DateTime64UTC (Int64) + 0x0F: 2, # Date (UInt16) + 0x10: 4, # Date32 (Int32) + 0x11: 4, # DateTimeUTC (UInt32) + 0x13: 8, # DateTime64UTC (Int64) 0x1D: 16, # UUID - 0x28: 4, # IPv4 + 0x28: 4, # IPv4 0x29: 16, # IPv6 - 0x31: 2, # BFloat16 + 0x31: 2, # BFloat16 } # Expected payload sizes for fixed-size discriminator types. # Used to validate that binary data is actually variant-encoded vs a plain string # whose first byte happens to collide with a discriminator value. _DISCRIMINATOR_PAYLOAD_SIZE = { - 0x00: 0, # Nothing - 0x01: 1, # UInt8 - 0x02: 2, # UInt16 - 0x03: 4, # UInt32 - 0x04: 8, # UInt64 + 0x00: 0, # Nothing + 0x01: 1, # UInt8 + 0x02: 2, # UInt16 + 0x03: 4, # UInt32 + 0x04: 8, # UInt64 0x05: 16, # UInt128 0x06: 32, # UInt256 - 0x07: 1, # Int8 - 0x08: 2, # Int16 - 0x09: 4, # Int32 - 0x0A: 8, # Int64 + 0x07: 1, # Int8 + 0x08: 2, # Int16 + 0x09: 4, # Int32 + 0x0A: 8, # Int64 0x0B: 16, # Int128 0x0C: 32, # Int256 - 0x0D: 4, # Float32 - 0x0E: 8, # Float64 - 0x2D: 1, # Bool + 0x0D: 4, # Float32 + 0x0E: 8, # Float64 + 0x2D: 1, # Bool # String (0x15) is variable-length and validated separately } @@ -396,7 +405,6 @@ def _decode_variant(binary_data: bytes, ctx: QueryContext, validate_length: bool result = value_type.read_column_data(byte_source, 1, ctx, read_state) return result[0] if result else None - # pylint: disable=broad-exception-caught except Exception as e: logger.debug("Variant decode failed: %s", e) return binary_data @@ -416,7 +424,6 @@ def decode_shared_data_value(binary_data: bytes, ctx: QueryContext): return _decode_variant(binary_data, ctx) -# pylint: disable=too-many-return-statements, too-many-branches def decode_shared_variant_value(binary_data: bytes, ctx: QueryContext): """Decode a value from a Dynamic column's shared variant. @@ -482,7 +489,7 @@ class SharedVariant(String): class JSON(ClickHouseType): __slots__ = "typed_paths", "typed_types", "skips" python_type = dict - valid_formats = 'string', 'native' + valid_formats = "string", "native" _data_size = json_sample_size write_column_data = write_json shared_data_type: ClickHouseType @@ -499,45 +506,45 @@ class JSON(ClickHouseType): skips = [] parts = [] for key, value in zip(type_def.keys, type_def.values): - if key == 'max_dynamic_paths': + if key == "max_dynamic_paths": try: self.max_dynamic_paths = int(value) - parts.append(f'{key} = {value}') + parts.append(f"{key} = {value}") continue except ValueError: pass - if key == 'max_dynamic_types': + if key == "max_dynamic_types": try: self.max_dynamic_types = int(value) - parts.append(f'{key} = {value}') + parts.append(f"{key} = {value}") continue except ValueError: pass - if key == 'SKIP': - if value.startswith('REGEXP'): - value = 'REGEXP ' + value[6:] + if key == "SKIP": + if value.startswith("REGEXP"): + value = "REGEXP " + value[6:] else: if not value.startswith("`"): - value = f'`{value}`' + value = f"`{value}`" skips.append(value) else: key = unescape_identifier(key) typed_paths.append(key) typed_types.append(get_from_name(value)) - key = f'`{key}`' - parts.append(f'{key} {value}') + key = f"`{key}`" + parts.append(f"{key} {value}") if typed_paths: self.typed_paths = typed_paths self.typed_types = typed_types if skips: self.skips = skips if parts: - self._name_suffix = f'({", ".join(parts)})' + self._name_suffix = f"({', '.join(parts)})" @property def insert_name(self): if json_serialization_format == 0: - return 'String' + return "String" return super().insert_name def write_column_prefix(self, dest: bytearray): @@ -549,22 +556,24 @@ class JSON(ClickHouseType): if serialize_version == 0: source.read_leb128() # max dynamic types, we ignore this value elif serialize_version != 2: - raise DataError(f'Unrecognized json structure version: {serialize_version} column: `{ctx.column_name}`') + raise DataError(f"Unrecognized json structure version: {serialize_version} column: `{ctx.column_name}`") dynamic_path_cnt = source.read_leb128() dynamic_paths = [source.read_leb128_str() for _ in range(dynamic_path_cnt)] typed_states = [typed.read_column_prefix(source, ctx) for typed in self.typed_types] dynamic_states = [read_dynamic_prefix(self, source, ctx) for _ in range(dynamic_path_cnt)] - shared_state = SHARED_DATA_TYPE.read_column_prefix(source, ctx) + shared_state = SHARED_DATA_TYPE.read_column_prefix(source, ctx) # noqa: F821 (undefined-name) return JSONState(serialize_version, dynamic_paths, typed_states, dynamic_states, shared_state) - # pylint: disable=too-many-locals def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: JSONState): - typed_columns = [ch_type.read_column_data(source, num_rows, ctx, read_state) - for ch_type, read_state in zip(self.typed_types, read_state.typed_states)] + typed_columns = [ + ch_type.read_column_data(source, num_rows, ctx, read_state) + for ch_type, read_state in zip(self.typed_types, read_state.typed_states) + ] dynamic_columns = [ read_variant_column(source, num_rows, ctx, dynamic_state.variant_types, dynamic_state.variant_states) - for dynamic_state in read_state.dynamic_states] - shared_columns = SHARED_DATA_TYPE.read_column_data(source, num_rows, ctx, read_state.shared_state) + for dynamic_state in read_state.dynamic_states + ] + shared_columns = SHARED_DATA_TYPE.read_column_data(source, num_rows, ctx, read_state.shared_state) # noqa: F821 (undefined-name) col = [] for row_num in range(num_rows): top = {} @@ -582,6 +591,6 @@ class JSON(ClickHouseType): if value is not None: _nest_value(top, key, value) col.append(top) - if self.read_format(ctx) == 'string': + if self.read_format(ctx) == "string": return [any_to_json(v) for v in col] return col diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/format.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/format.py index 8c73934d284..fa09e2c9fdd 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/format.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/format.py @@ -1,10 +1,10 @@ import re +from collections.abc import Sequence -from typing import Dict, Type, Sequence, Optional - -from clickhouse_connect.datatypes.base import ClickHouseType, type_map, ch_read_formats, ch_write_formats +from clickhouse_connect.datatypes.base import ClickHouseType, ch_read_formats, ch_write_formats, type_map from clickhouse_connect.driver.exceptions import ProgrammingError + def set_default_formats(*args, **kwargs): fmt_map = format_map(_convert_arguments(*args, **kwargs)) ch_read_formats.update(fmt_map) @@ -42,7 +42,7 @@ def clear_read_format(pattern: str): ch_read_formats.pop(ch_type, None) -def format_map(fmt_map: Optional[Dict[str, str]]) -> Dict[Type[ClickHouseType], str]: +def format_map(fmt_map: dict[str, str] | None) -> dict[type[ClickHouseType], str]: if not fmt_map: return {} final_map = {} @@ -52,22 +52,22 @@ def format_map(fmt_map: Optional[Dict[str, str]]) -> Dict[Type[ClickHouseType], return final_map -def _convert_arguments(*args, **kwargs) -> Dict[str, str]: +def _convert_arguments(*args, **kwargs) -> dict[str, str]: fmt_map = {} try: for x in range(0, len(args), 2): fmt_map[args[x]] = args[x + 1] except (IndexError, TypeError, ValueError) as ex: - raise ProgrammingError('Invalid type/format arguments for format method') from ex + raise ProgrammingError("Invalid type/format arguments for format method") from ex fmt_map.update(kwargs) return fmt_map -def _matching_types(pattern: str, fmt: str = None) -> Sequence[Type[ClickHouseType]]: - re_pattern = re.compile(pattern.replace('*', '.*'), re.IGNORECASE) +def _matching_types(pattern: str, fmt: str = None) -> Sequence[type[ClickHouseType]]: + re_pattern = re.compile(pattern.replace("*", ".*"), re.IGNORECASE) matches = [ch_type for type_name, ch_type in type_map.items() if re_pattern.match(type_name)] if not matches: - raise ProgrammingError(f'Unrecognized ClickHouse type {pattern} when setting formats') + raise ProgrammingError(f"Unrecognized ClickHouse type {pattern} when setting formats") if fmt: invalid = [ch_type.__name__ for ch_type in matches if fmt not in ch_type.valid_formats] if invalid: diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/geometric.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/geometric.py index 04c94eba561..f7773dd34b3 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/geometric.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/geometric.py @@ -1,4 +1,5 @@ -from typing import Sequence, Any +from collections.abc import Sequence +from typing import Any from clickhouse_connect.datatypes.base import ClickHouseType from clickhouse_connect.driver.insert import InsertContext @@ -10,6 +11,8 @@ RING_DATA_TYPE: ClickHouseType POLYGON_DATA_TYPE: ClickHouseType MULTI_POLYGON_DATA_TYPE: ClickHouseType +# ruff: noqa: F821 (Undefine name) + class Point(ClickHouseType): def write_column(self, column: Sequence, dest: bytearray, ctx: InsertContext): @@ -40,7 +43,7 @@ class Polygon(ClickHouseType): def read_column_prefix(self, source: ByteSource, ctx: QueryContext): return POLYGON_DATA_TYPE.read_column_prefix(source, ctx) - def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state:Any) -> Sequence: + def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any) -> Sequence: return POLYGON_DATA_TYPE.read_column_data(source, num_rows, ctx, read_state) @@ -51,7 +54,7 @@ class MultiPolygon(ClickHouseType): def read_column_prefix(self, source: ByteSource, ctx: QueryContext): return MULTI_POLYGON_DATA_TYPE.read_column_prefix(source, ctx) - def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state:Any) -> Sequence: + def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any) -> Sequence: return MULTI_POLYGON_DATA_TYPE.read_column_data(source, num_rows, ctx, read_state) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/network.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/network.py index 7f13820f5b2..e0e751cf390 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/network.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/network.py @@ -1,39 +1,39 @@ import socket -from ipaddress import ip_address, IPv4Address, IPv6Address -from typing import Union, MutableSequence, Sequence, Any +from collections.abc import MutableSequence, Sequence +from ipaddress import IPv4Address, IPv6Address, ip_address +from typing import Any from clickhouse_connect.datatypes.base import ClickHouseType -from clickhouse_connect.driver.common import write_array, int_size, first_value +from clickhouse_connect.driver.common import first_value, int_size, write_array +from clickhouse_connect.driver.ctypes import data_conv from clickhouse_connect.driver.insert import InsertContext from clickhouse_connect.driver.query import QueryContext from clickhouse_connect.driver.types import ByteSource -from clickhouse_connect.driver.ctypes import data_conv -IPV4_V6_MASK = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff' -V6_NULL = bytes(b'\x00' * 16) +IPV4_V6_MASK = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" +V6_NULL = bytes(b"\x00" * 16) -# pylint: disable=protected-access class IPv4(ClickHouseType): - _array_type = 'L' if int_size == 2 else 'I' - valid_formats = 'string', 'native', 'int' + _array_type = "L" if int_size == 2 else "I" + valid_formats = "string", "native", "int" python_type = IPv4Address byte_size = 4 def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): - if self.read_format(ctx) == 'int': + if self.read_format(ctx) == "int": return source.read_array(self._array_type, num_rows) - if self.read_format(ctx) == 'string': + if self.read_format(ctx) == "string": column = source.read_array(self._array_type, num_rows) - return [socket.inet_ntoa(x.to_bytes(4, 'big')) for x in column] + return [socket.inet_ntoa(x.to_bytes(4, "big")) for x in column] return data_conv.read_ipv4_col(source, num_rows) - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): first = first_value(column, self.nullable) if isinstance(first, str): fixed = 24, 16, 8, 0 - # pylint: disable=consider-using-generator - column = [(sum([int(b) << fixed[ix] for ix, b in enumerate(x.split('.'))])) if x else 0 for x in column] + + column = [(sum([int(b) << fixed[ix] for ix, b in enumerate(x.split("."))])) if x else 0 for x in column] else: if self.nullable: column = [x._ip if x else 0 for x in column] @@ -45,21 +45,20 @@ class IPv4(ClickHouseType): fmt = self.read_format(ctx) if ctx.use_none: return None - if fmt == 'string': - return '0.0.0.0' - if fmt == 'int': + if fmt == "string": + return "0.0.0.0" + if fmt == "int": return 0 return None -# pylint: disable=protected-access class IPv6(ClickHouseType): - valid_formats = 'string', 'native' + valid_formats = "string", "native" python_type = IPv6Address byte_size = 16 def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): - if self.read_format(ctx) == 'string': + if self.read_format(ctx) == "string": return self._read_binary_str(source, num_rows) return self._read_binary_ip(source, num_rows) @@ -67,12 +66,12 @@ class IPv6(ClickHouseType): def _read_binary_ip(source: ByteSource, num_rows: int) -> list[IPv6Address]: """Read IPv6 addresses in native format, always returning IPv6Address objects.""" fast_ip_v6 = IPv6Address.__new__ - with_scope_id = '_scope_id' in IPv6Address.__slots__ + with_scope_id = "_scope_id" in IPv6Address.__slots__ new_col = [] app = new_col.append ifb = int.from_bytes for _ in range(num_rows): - int_value = ifb(source.read_bytes(16), 'big') + int_value = ifb(source.read_bytes(16), "big") ipv6 = fast_ip_v6(IPv6Address) ipv6._ip = int_value if with_scope_id: @@ -95,7 +94,7 @@ class IPv6(ClickHouseType): def _write_column_binary( self, - column: Union[Sequence, MutableSequence], + column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext, ): @@ -108,9 +107,7 @@ class IPv6(ClickHouseType): try: addr = ip_address(value) except ValueError as e: - raise ValueError( - f"Failed to parse '{value}' as a valid IP address for column '{ctx.column_name}'" - ) from e + raise ValueError(f"Failed to parse '{value}' as a valid IP address for column '{ctx.column_name}'") from e # Now handle parsed object if isinstance(addr, IPv6Address): @@ -122,4 +119,4 @@ class IPv6(ClickHouseType): def _active_null(self, ctx): if ctx.use_none: return None - return '::' if self.read_format(ctx) == 'string' else V6_NULL + return "::" if self.read_format(ctx) == "string" else V6_NULL diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/numeric.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/numeric.py index 8cea7a4bbbd..37b0d83ed4c 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/numeric.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/numeric.py @@ -1,24 +1,28 @@ +import array import decimal -from typing import Union, Type, Sequence, MutableSequence, Any import struct -import array - -from math import nan, isnan, isinf +from collections.abc import MutableSequence, Sequence +from math import isinf, isnan, nan +from typing import Any -from clickhouse_connect.datatypes.base import TypeDef, ArrayType, ClickHouseType -from clickhouse_connect.driver.common import array_type, write_array, decimal_size, decimal_prec, first_value +from clickhouse_connect.datatypes.base import ArrayType, ClickHouseType, TypeDef from clickhouse_connect.driver import ctypes as driver_ctypes +from clickhouse_connect.driver import options +from clickhouse_connect.driver.common import array_type, decimal_prec, decimal_size, first_value, write_array from clickhouse_connect.driver.ctypes import data_conv from clickhouse_connect.driver.insert import InsertContext -from clickhouse_connect.driver import options from clickhouse_connect.driver.query import QueryContext from clickhouse_connect.driver.types import ByteSource class IntBase(ArrayType, registered=False): - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): if len(column) == 0: return + np = options.np + if np is not None and isinstance(column, np.ndarray) and column.dtype.kind in ("i", "u"): + data_conv.write_native_col(self._array_type, column, dest, ctx.column_name) + return if self.nullable: first = next((x for x in column if x is not None), None) if isinstance(first, int): @@ -31,76 +35,75 @@ class IntBase(ArrayType, registered=False): column = [0 if x is None or isnan(x) or isinf(x) else int(x) for x in column] elif not isinstance(column[0], int): column = [int(x) for x in column] - write_array(self._array_type, column, dest) + data_conv.write_native_col(self._array_type, column, dest, ctx.column_name) class Int8(IntBase): - _array_type = 'b' - np_type = 'b' + _array_type = "b" + np_type = "b" class UInt8(IntBase): - _array_type = 'B' - np_type = 'B' + _array_type = "B" + np_type = "B" class Int16(IntBase): - _array_type = 'h' - np_type = '<i2' + _array_type = "h" + np_type = "<i2" class UInt16(IntBase): - _array_type = 'H' - np_type = '<u2' + _array_type = "H" + np_type = "<u2" class Int32(IntBase): - _array_type = 'i' - np_type = '<i4' + _array_type = "i" + np_type = "<i4" class UInt32(IntBase): - _array_type = 'I' - np_type = '<u4' + _array_type = "I" + np_type = "<u4" class Int64(IntBase): - _array_type = 'q' - np_type = '<i8' + _array_type = "q" + np_type = "<i8" class UInt64(IntBase): - valid_formats = 'signed', 'native' - _array_type = 'Q' - np_type = '<u8' + valid_formats = "signed", "native" + _array_type = "Q" + np_type = "<u8" python_type = int def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): fmt = self.read_format(ctx) if ctx.use_numpy: - np_type = '<q' if fmt == 'signed' else '<u8' + np_type = "<q" if fmt == "signed" else "<u8" return driver_ctypes.numpy_conv.read_numpy_array(source, np_type, num_rows) - arr_type = 'q' if fmt == 'signed' else 'Q' + arr_type = "q" if fmt == "signed" else "Q" return source.read_array(arr_type, num_rows) def _read_nullable_column(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any) -> Sequence: - return data_conv.read_nullable_array(source, 'q' if self.read_format(ctx) == 'signed' else 'Q', - num_rows, self._active_null(ctx)) + return data_conv.read_nullable_array(source, "q" if self.read_format(ctx) == "signed" else "Q", num_rows, self._active_null(ctx)) def _finalize_column(self, column: Sequence, ctx: QueryContext) -> Sequence: fmt = self.read_format(ctx) - if fmt == 'string': + if fmt == "string": return [str(x) for x in column] if ctx.use_extended_dtypes and self.nullable: - return options.pd.array(column, dtype='Int64' if fmt == 'signed' else 'UInt64') + return options.pd.array(column, dtype="Int64" if fmt == "signed" else "UInt64") if ctx.use_numpy and self.nullable and (not ctx.use_none): - return options.np.array(column, dtype='<q' if fmt == 'signed' else '<u8') + return options.np.array(column, dtype="<q" if fmt == "signed" else "<u8") return column class BigInt(ClickHouseType, registered=False): _signed = True - valid_formats = 'string', 'native' + valid_formats = "string", "native" python_type = int def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): @@ -109,43 +112,42 @@ class BigInt(ClickHouseType, registered=False): column = [] app = column.append ifb = int.from_bytes - if self.read_format(ctx) == 'string': + if self.read_format(ctx) == "string": for _ in range(num_rows): - app(str(ifb(source.read_bytes(sz), 'little', signed=signed))) + app(str(ifb(source.read_bytes(sz), "little", signed=signed))) else: for _ in range(num_rows): - app(ifb(source.read_bytes(sz), 'little', signed=signed)) + app(ifb(source.read_bytes(sz), "little", signed=signed)) return column - # pylint: disable=too-many-branches - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): if len(column) == 0: return first = first_value(column, self.nullable) sz = self.byte_size signed = self._signed - empty = bytes(b'\x00' * sz) + empty = bytes(b"\x00" * sz) ext = dest.extend - if isinstance(first, str) or self.write_format(ctx) == 'string': + if isinstance(first, str) or self.write_format(ctx) == "string": if self.nullable: for x in column: if x: - ext(int(x).to_bytes(sz, 'little', signed=signed)) + ext(int(x).to_bytes(sz, "little", signed=signed)) else: ext(empty) else: for x in column: - ext(int(x).to_bytes(sz, 'little', signed=signed)) + ext(int(x).to_bytes(sz, "little", signed=signed)) else: if self.nullable: for x in column: if x: - ext(x.to_bytes(sz, 'little', signed=signed)) + ext(x.to_bytes(sz, "little", signed=signed)) else: ext(empty) else: for x in column: - ext(x.to_bytes(sz, 'little', signed=signed)) + ext(x.to_bytes(sz, "little", signed=signed)) class Int128(BigInt): @@ -169,11 +171,11 @@ class UInt256(BigInt): class Float(ArrayType, registered=False): - _array_type = 'f' + _array_type = "f" python_type = float def _finalize_column(self, column: Sequence, ctx: QueryContext) -> Sequence: - if self.read_format(ctx) == 'string': + if self.read_format(ctx) == "string": return [str(x) for x in column] if ctx.use_numpy and self.nullable and (not ctx.use_none): return options.np.array(column, dtype=self.np_type) @@ -188,9 +190,13 @@ class Float(ArrayType, registered=False): return nan return 0.0 - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): if len(column) == 0: return + np = options.np + if np is not None and isinstance(column, np.ndarray) and column.dtype.kind == "f": + data_conv.write_native_col(self._array_type, column, dest, ctx.column_name) + return if self.nullable: first = next((x for x in column if x is not None), None) if not isinstance(first, float): @@ -199,16 +205,16 @@ class Float(ArrayType, registered=False): column = [0 if x is None else x for x in column] elif not isinstance(column[0], float): column = [float(x) for x in column] - write_array(self._array_type, column, dest) + data_conv.write_native_col(self._array_type, column, dest, ctx.column_name) class Float32(Float): - np_type = '<f4' + np_type = "<f4" class Float64(Float): - _array_type = 'd' - np_type = '<f8' + _array_type = "d" + np_type = "<f8" class BFloat16(ArrayType): @@ -242,9 +248,7 @@ class BFloat16(ArrayType): write_array(self._array_type, vals, dest, ctx.column_name) - def _read_column_binary( - self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any - ): + def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): if ctx.use_numpy: arr16 = driver_ctypes.numpy_conv.read_numpy_array(source, "<u2", num_rows) return (arr16.astype(options.np.uint32) << options.np.uint32(16)).view(options.np.float32) @@ -252,17 +256,13 @@ class BFloat16(ArrayType): raw = source.read_array(self._array_type, num_rows) return [struct.unpack("<f", struct.pack("<I", v << 16))[0] for v in raw] - def _read_nullable_column( - self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any - ): + def _read_nullable_column(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): null_map = source.read_bytes(num_rows) if ctx.use_numpy: arr16 = driver_ctypes.numpy_conv.read_numpy_array(source, "<u2", num_rows) floats = (arr16.astype(options.np.uint32) << options.np.uint32(16)).view(options.np.float32) - return data_conv.build_nullable_column( - floats, null_map, self._active_null(ctx) - ) + return data_conv.build_nullable_column(floats, null_map, self._active_null(ctx)) raw = source.read_array(self._array_type, num_rows) floats = [struct.unpack("<f", struct.pack("<I", v << 16))[0] for v in raw] @@ -286,7 +286,7 @@ class BFloat16(ArrayType): class Bool(ClickHouseType): - np_type = '?' + np_type = "?" python_type = bool byte_size = 1 @@ -300,7 +300,7 @@ class Bool(ClickHouseType): return column def _write_column_binary(self, column, dest, ctx): - write_array('B', [1 if x else 0 for x in column], dest, ctx.column_name) + write_array("B", [1 if x else 0 for x in column], dest, ctx.column_name) class Boolean(Bool): @@ -308,9 +308,9 @@ class Boolean(Bool): class Enum(ClickHouseType): - __slots__ = '_name_map', '_int_map' - _array_type = 'b' - valid_formats = 'native', 'int' + __slots__ = "_name_map", "_int_map" + _array_type = "b" + valid_formats = "native", "int" python_type = str def __init__(self, type_def: TypeDef): @@ -318,17 +318,17 @@ class Enum(ClickHouseType): escaped_keys = [key.replace("'", "\\'") for key in type_def.keys] self._name_map = dict(zip(type_def.keys, type_def.values)) self._int_map = dict(zip(type_def.values, type_def.keys)) - val_str = ', '.join(f"'{key}' = {value}" for key, value in zip(escaped_keys, type_def.values)) - self._name_suffix = f'({val_str})' + val_str = ", ".join(f"'{key}' = {value}" for key, value in zip(escaped_keys, type_def.values)) + self._name_suffix = f"({val_str})" def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): column = source.read_array(self._array_type, num_rows) - if self.read_format(ctx) == 'int': + if self.read_format(ctx) == "int": return column lookup = self._int_map.get return [lookup(x, None) for x in column] - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx:InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): first = first_value(column, self.nullable) if first is None or not isinstance(first, str): if self.nullable: @@ -340,22 +340,22 @@ class Enum(ClickHouseType): class Enum8(Enum): - _array_type = 'b' + _array_type = "b" byte_size = 1 class Enum16(Enum): - _array_type = 'h' + _array_type = "h" byte_size = 2 class Decimal(ClickHouseType): - __slots__ = 'prec', 'scale', '_mult', '_zeros', 'byte_size', '_array_type' + __slots__ = "prec", "scale", "_mult", "_zeros", "byte_size", "_array_type" python_type = decimal.Decimal dec_size = 0 @classmethod - def build(cls: Type['Decimal'], type_def: TypeDef): + def build(cls: type["Decimal"], type_def: TypeDef): size = cls.dec_size if size == 0: prec = type_def.values[0] @@ -371,31 +371,21 @@ class Decimal(ClickHouseType): super().__init__(type_def) self.prec = prec self.scale = scale - self._mult = 10 ** scale + self._mult = 10**scale self.byte_size = size // 8 self._zeros = bytes([0] * self.byte_size) - self._name_suffix = f'({prec}, {scale})' + self._name_suffix = f"({prec}, {scale})" self._array_type = array_type(self.byte_size, True) def _read_column_binary(self, source: ByteSource, num_rows: int, _ctx: QueryContext, _read_state: Any): column = source.read_array(self._array_type, num_rows) dec = decimal.Decimal scale = self.scale - prec = self.prec if scale == 0: - return [dec(str(x)) for x in column] - new_col = [] - app = new_col.append - for x in column: - if x >= 0: - digits = str(x).rjust(prec, '0') - app(dec(f'{digits[:-scale]}.{digits[-scale:]}')) - else: - digits = str(-x).rjust(prec, '0') - app(dec(f'-{digits[:-scale]}.{digits[-scale:]}')) - return new_col + return [dec(x) for x in column] + return [dec(x).scaleb(-scale) for x in column] - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx:InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): with decimal.localcontext() as dec_ctx: dec_ctx.prec = self.prec dec = decimal.Decimal @@ -408,35 +398,33 @@ class Decimal(ClickHouseType): def _active_null(self, ctx: QueryContext): if ctx.use_none: return None - digits = str('0').rjust(self.prec, '0') + digits = "0".rjust(self.prec, "0") scale = self.scale - return decimal.Decimal(f'{digits[:-scale]}.{digits[-scale:]}') + return decimal.Decimal(f"{digits[:-scale]}.{digits[-scale:]}") class BigDecimal(Decimal, registered=False): def _read_column_binary(self, source: ByteSource, num_rows: int, _ctx: QueryContext, _read_state: Any): dec = decimal.Decimal scale = self.scale - prec = self.prec column = [] app = column.append sz = self.byte_size ifb = int.from_bytes if scale == 0: for _ in range(num_rows): - app(dec(str(ifb(source.read_bytes(sz), 'little', signed=True)))) + app(dec(ifb(source.read_bytes(sz), "little", signed=True))) return column - for _ in range(num_rows): - x = ifb(source.read_bytes(sz), 'little', signed=True) - if x >= 0: - digits = str(x).rjust(prec, '0') - app(dec(f'{digits[:-scale]}.{digits[-scale:]}')) - else: - digits = str(-x).rjust(prec, '0') - app(dec(f'-{digits[:-scale]}.{digits[-scale:]}')) + # localcontext with ctx.prec = self.prec is required because scaleb() + # rounds to context precision. Default prec is 28 which would silently + # truncate Decimal128 (prec up to 38) and Decimal256 (prec up to 76) values. + with decimal.localcontext() as ctx: + ctx.prec = self.prec + for _ in range(num_rows): + app(dec(ifb(source.read_bytes(sz), "little", signed=True)).scaleb(-scale)) return column - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, _ctx): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, _ctx): with decimal.localcontext() as ctx: ctx.prec = self.prec mult = decimal.Decimal(f"{self._mult}.{'0' * self.scale}") @@ -445,10 +433,10 @@ class BigDecimal(Decimal, registered=False): if self.nullable: v = self._zeros for x in column: - dest += v if not x else itb(int(decimal.Decimal(str(x)) * mult), sz, 'little', signed=True) + dest += v if not x else itb(int(decimal.Decimal(str(x)) * mult), sz, "little", signed=True) else: for x in column: - dest += itb(int(decimal.Decimal(str(x)) * mult), sz, 'little', signed=True) + dest += itb(int(decimal.Decimal(str(x)) * mult), sz, "little", signed=True) class Decimal32(Decimal): diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/postinit.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/postinit.py index dbb29e56138..0f41d678816 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/postinit.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/postinit.py @@ -1,23 +1,23 @@ -from clickhouse_connect.datatypes import registry, dynamic, geometric +from clickhouse_connect.datatypes import dynamic, geometric, registry from clickhouse_connect.datatypes.base import TypeDef from clickhouse_connect.datatypes.container import Map -dynamic.STRING_DATA_TYPE = registry.get_from_name('String') +dynamic.STRING_DATA_TYPE = registry.get_from_name("String") # Build a private Map(String, String) for JSON shared data decoding. # We must NOT reuse the cached registry instance because we replace # value_type with SharedDataString (reads raw bytes, encoding=None). # Mutating the cached instance would break all normal Map(String, String) columns. -_shared_map = Map(TypeDef((), (), ('String', 'String'))) +_shared_map = Map(TypeDef((), (), ("String", "String"))) _shared_map.value_type = dynamic.SharedDataString(dynamic.STRING_DATA_TYPE.type_def) dynamic.SHARED_DATA_TYPE = _shared_map dynamic.SHARED_VARIANT_TYPE = dynamic.SharedVariant(dynamic.STRING_DATA_TYPE.type_def) -point = 'Tuple(Float64, Float64)' -ring = f'Array({point})' -polygon = f'Array({ring})' -multi_polygon = f'Array({polygon})' +point = "Tuple(Float64, Float64)" +ring = f"Array({point})" +polygon = f"Array({ring})" +multi_polygon = f"Array({polygon})" geometric.POINT_DATA_TYPE = registry.get_from_name(point) geometric.RING_DATA_TYPE = registry.get_from_name(ring) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/registry.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/registry.py index 6fb20d8af6c..3f5e45866c8 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/registry.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/registry.py @@ -1,15 +1,14 @@ import logging -from typing import Tuple, Dict -from clickhouse_connect.datatypes.base import TypeDef, ClickHouseType, type_map +from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef, type_map from clickhouse_connect.driver.exceptions import InternalError -from clickhouse_connect.driver.parser import parse_enum, parse_callable, parse_columns +from clickhouse_connect.driver.parser import parse_callable, parse_columns, parse_enum logger = logging.getLogger(__name__) -type_cache: Dict[str, ClickHouseType] = {} +type_cache: dict[str, ClickHouseType] = {} -def parse_name(name: str) -> Tuple[str, str, TypeDef]: +def parse_name(name: str) -> tuple[str, str, TypeDef]: """ Converts a ClickHouse type name into the base class and the definition (TypeDef) needed for any additional instantiation @@ -20,34 +19,34 @@ def parse_name(name: str) -> Tuple[str, str, TypeDef]: base = name wrappers = [] keys = tuple() - if base.startswith('LowCardinality'): - wrappers.append('LowCardinality') + if base.startswith("LowCardinality"): + wrappers.append("LowCardinality") base = base[15:-1] - if base.startswith('Nullable'): - wrappers.append('Nullable') + if base.startswith("Nullable"): + wrappers.append("Nullable") base = base[9:-1] - if base.startswith('Enum'): + if base.startswith("Enum"): keys, values = parse_enum(base) - base = base[:base.find('(')] - elif base.startswith('Nested'): + base = base[: base.find("(")] + elif base.startswith("Nested"): keys, values = parse_columns(base[6:]) - base = 'Nested' - elif base.startswith('Tuple'): + base = "Nested" + elif base.startswith("Tuple"): keys, values = parse_columns(base[5:]) - base = 'Tuple' - elif base.startswith('Variant'): + base = "Tuple" + elif base.startswith("Variant"): keys, values = parse_columns(base[7:]) - base = 'Variant' - elif base.startswith('JSON') and len(base) > 4 and base[4] == '(': + base = "Variant" + elif base.startswith("JSON") and len(base) > 4 and base[4] == "(": keys, values = parse_columns(base[4:]) - base = 'JSON' - elif base == 'Point': - values = ('Float64', 'Float64') + base = "JSON" + elif base == "Point": + values = ("Float64", "Float64") else: try: base, values, _ = parse_callable(base) except IndexError: - raise InternalError(f'Can not parse ClickHouse data type: {name}') from None + raise InternalError(f"Can not parse ClickHouse data type: {name}") from None return base, name, TypeDef(tuple(wrappers), keys, values) @@ -63,7 +62,7 @@ def get_from_name(name: str) -> ClickHouseType: try: ch_type = type_map[base].build(type_def) except KeyError: - err_str = f'Unrecognized ClickHouse type base: {base} name: {name}' + err_str = f"Unrecognized ClickHouse type base: {base} name: {name}" logger.error(err_str) raise InternalError(err_str) from None type_cache[name] = ch_type diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/special.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/special.py index 2c2843527d0..7b030fb3662 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/special.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/special.py @@ -1,7 +1,8 @@ -from typing import Union, Sequence, MutableSequence, Any +from collections.abc import MutableSequence, Sequence +from typing import Any from uuid import UUID as PYUUID -from clickhouse_connect.datatypes.base import TypeDef, ClickHouseType, ArrayType, UnsupportedType +from clickhouse_connect.datatypes.base import ArrayType, ClickHouseType, TypeDef, UnsupportedType from clickhouse_connect.datatypes.registry import get_from_name from clickhouse_connect.driver.common import first_value from clickhouse_connect.driver.ctypes import data_conv @@ -9,56 +10,55 @@ from clickhouse_connect.driver.insert import InsertContext from clickhouse_connect.driver.query import QueryContext from clickhouse_connect.driver.types import ByteSource -empty_uuid_b = bytes(b'\x00' * 16) +empty_uuid_b = bytes(b"\x00" * 16) class UUID(ClickHouseType): python_type = PYUUID - valid_formats = 'string', 'native' - np_type = 'U36' + valid_formats = "string", "native" + np_type = "U36" byte_size = 16 def python_null(self, ctx): - return '' if self.read_format(ctx) == 'string' else PYUUID(int=0) + return "" if self.read_format(ctx) == "string" else PYUUID(int=0) def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): - if self.read_format(ctx) == 'string': + if self.read_format(ctx) == "string": return self._read_binary_str(source, num_rows) return data_conv.read_uuid_col(source, num_rows) @staticmethod def _read_binary_str(source: ByteSource, num_rows: int): - v = source.read_array('Q', num_rows * 2) + v = source.read_array("Q", num_rows * 2) column = [] app = column.append for i in range(num_rows): ix = i << 1 - x = f'{(v[ix] << 64 | v[ix + 1]):032x}' - app(f'{x[:8]}-{x[8:12]}-{x[12:16]}-{x[16:20]}-{x[20:]}') + x = f"{(v[ix] << 64 | v[ix + 1]):032x}" + app(f"{x[:8]}-{x[8:12]}-{x[12:16]}-{x[16:20]}-{x[20:]}") return column - # pylint: disable=too-many-branches - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): first = first_value(column, self.nullable) empty = empty_uuid_b - if isinstance(first, str) or self.write_format(ctx) == 'string': + if isinstance(first, str) or self.write_format(ctx) == "string": for v in column: if v: - x = int(v.replace('-', ''), 16) - dest += (x >> 64).to_bytes(8, 'little') + (x & 0xffffffffffffffff).to_bytes(8, 'little') + x = int(v.replace("-", ""), 16) + dest += (x >> 64).to_bytes(8, "little") + (x & 0xFFFFFFFFFFFFFFFF).to_bytes(8, "little") else: dest += empty elif isinstance(first, int): for x in column: if x: - dest += (x >> 64).to_bytes(8, 'little') + (x & 0xffffffffffffffff).to_bytes(8, 'little') + dest += (x >> 64).to_bytes(8, "little") + (x & 0xFFFFFFFFFFFFFFFF).to_bytes(8, "little") else: dest += empty elif isinstance(first, PYUUID): for v in column: if v: x = v.int - dest += (x >> 64).to_bytes(8, 'little') + (x & 0xffffffffffffffff).to_bytes(8, 'little') + dest += (x >> 64).to_bytes(8, "little") + (x & 0xFFFFFFFFFFFFFFFF).to_bytes(8, "little") else: dest += empty elif isinstance(first, (bytes, bytearray, memoryview)): @@ -72,18 +72,18 @@ class UUID(ClickHouseType): class Nothing(ArrayType): - _array_type = 'b' + _array_type = "b" def __init__(self, type_def: TypeDef): super().__init__(type_def) self.nullable = True - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, _ctx): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, _ctx): dest += bytes(0x30 for _ in range(len(column))) class SimpleAggregateFunction(ClickHouseType): - _slots = ('element_type',) + _slots = ("element_type",) def __init__(self, type_def: TypeDef): super().__init__(type_def) @@ -106,7 +106,7 @@ class SimpleAggregateFunction(ClickHouseType): def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any): return self.element_type.read_column_data(source, num_rows, ctx, read_state) - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): self.element_type.write_column_data(column, dest, ctx) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/string.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/string.py index 7e71e6a6c7f..b06aa0f1b93 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/string.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/string.py @@ -1,22 +1,22 @@ -from typing import Sequence, MutableSequence, Union, Collection, Any +from collections.abc import Collection, MutableSequence, Sequence +from typing import Any +from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef +from clickhouse_connect.driver import options from clickhouse_connect.driver.common import first_value from clickhouse_connect.driver.ctypes import data_conv - -from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef from clickhouse_connect.driver.errors import handle_error from clickhouse_connect.driver.insert import InsertContext from clickhouse_connect.driver.query import QueryContext from clickhouse_connect.driver.types import ByteSource -from clickhouse_connect.driver import options class String(ClickHouseType): python_type = str - valid_formats = 'bytes', 'native' + valid_formats = "bytes", "native" def _active_encoding(self, ctx): - if self.read_format(ctx) == 'bytes': + if self.read_format(ctx) == "bytes": return None if ctx.encoding: return ctx.encoding @@ -27,7 +27,7 @@ class String(ClickHouseType): return 0 total = 0 for x in sample: - if x: + if isinstance(x, (str, bytes)): total += len(x) return total // len(sample) + 1 @@ -38,13 +38,13 @@ class String(ClickHouseType): return source.read_str_col(num_rows, self._active_encoding(ctx), True, self._active_null(ctx)) def _finalize_column(self, column: Sequence, ctx: QueryContext) -> Sequence: - if ctx.use_extended_dtypes and self.read_format(ctx) == 'native': + if ctx.use_extended_dtypes and self.read_format(ctx) == "native": return options.pd.array(column, dtype=options.pd.StringDtype()) if ctx.use_numpy and ctx.max_str_len: - return options.np.array(column, dtype=f'<U{ctx.max_str_len}') + return options.np.array(column, dtype=f"<U{ctx.max_str_len}") return column - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): encoding = None if not isinstance(first_value(column, self.nullable), bytes): encoding = ctx.encoding or self.encoding @@ -53,49 +53,48 @@ class String(ClickHouseType): def _active_null(self, ctx): if ctx.use_none: return None - if self.read_format(ctx) == 'bytes': - return bytes() - return '' + if self.read_format(ctx) == "bytes": + return b"" + return "" class FixedString(ClickHouseType): python_type = str - valid_formats = 'string', 'native' + valid_formats = "string", "native" def __init__(self, type_def: TypeDef): super().__init__(type_def) self.byte_size = type_def.values[0] self._name_suffix = type_def.arg_str - self._empty_bytes = bytes(b'\x00' * self.byte_size) + self._empty_bytes = bytes(b"\x00" * self.byte_size) def _active_null(self, ctx: QueryContext): if ctx.use_none: return None - return self._empty_bytes if self.read_format(ctx) == 'native' else '' + return self._empty_bytes if self.read_format(ctx) == "native" else "" @property def np_type(self): - return f'<U{self.byte_size}' + return f"<U{self.byte_size}" def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): - if self.read_format(ctx) == 'string': - return source.read_fixed_str_col(self.byte_size, num_rows, ctx.encoding or self.encoding ) + if self.read_format(ctx) == "string": + return source.read_fixed_str_col(self.byte_size, num_rows, ctx.encoding or self.encoding) return source.read_bytes_col(self.byte_size, num_rows) def _finalize_column(self, column: Sequence, ctx: QueryContext) -> Sequence: - if ctx.use_extended_dtypes and self.read_format(ctx) == 'string': + if ctx.use_extended_dtypes and self.read_format(ctx) == "string": return options.pd.array(column, dtype=options.pd.StringDtype()) return column - # pylint: disable=too-many-branches,duplicate-code - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): ext = dest.extend sz = self.byte_size empty = bytes((0,) * sz) str_enc = str.encode enc = ctx.encoding or self.encoding first = first_value(column, self.nullable) - if isinstance(first, str) or self.write_format(ctx) == 'string': + if isinstance(first, str) or self.write_format(ctx) == "string": if self.nullable: for x in column: if x is None: @@ -106,9 +105,9 @@ class FixedString(ClickHouseType): except UnicodeEncodeError: b = empty if len(b) > sz: - raise ctx.data_error(f'UTF-8 encoded FixedString value {b.hex(" ")} exceeds column size {sz}') + raise ctx.data_error(f"UTF-8 encoded FixedString value {b.hex(' ')} exceeds column size {sz}") ext(b) - ext(empty[:sz - len(b)]) + ext(empty[: sz - len(b)]) else: for x in column: try: @@ -116,19 +115,19 @@ class FixedString(ClickHouseType): except UnicodeEncodeError: b = empty if len(b) > sz: - raise ctx.data_error(f'UTF-8 encoded FixedString value {b.hex(" ")} exceeds column size {sz}') + raise ctx.data_error(f"UTF-8 encoded FixedString value {b.hex(' ')} exceeds column size {sz}") ext(b) - ext(empty[:sz - len(b)]) + ext(empty[: sz - len(b)]) elif self.nullable: for b in column: if not b: ext(empty) elif len(b) != sz: - raise ctx.data_error(f'Fixed String binary value {b.hex(" ")} does not match column size {sz}') + raise ctx.data_error(f"Fixed String binary value {b.hex(' ')} does not match column size {sz}") else: ext(b) else: for b in column: if len(b) != sz: - raise ctx.data_error(f'Fixed String binary value {b.hex(" ")} does not match column size {sz}') + raise ctx.data_error(f"Fixed String binary value {b.hex(' ')} does not match column size {sz}") ext(b) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/temporal.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/temporal.py index 7ad9b1757c0..0f8ed6ad43f 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/temporal.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/temporal.py @@ -1,52 +1,52 @@ -import pytz +from __future__ import annotations import array -from datetime import date, datetime, tzinfo, timedelta, time - -from typing import Union, Sequence, MutableSequence, Any, NamedTuple, Optional -from abc import abstractmethod import re +import zoneinfo +from abc import abstractmethod +from collections.abc import MutableSequence, Sequence +from datetime import date, datetime, time, timedelta, tzinfo +from typing import TYPE_CHECKING, Any, NamedTuple -from clickhouse_connect.datatypes.base import TypeDef, ClickHouseType -from clickhouse_connect.common import get_setting -from clickhouse_connect.driver import tzutil -from clickhouse_connect.driver.common import write_array, np_date_types, int_size, first_value -from clickhouse_connect.driver.exceptions import ProgrammingError +if TYPE_CHECKING: + import numpy + +from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef from clickhouse_connect.driver import ctypes as driver_ctypes +from clickhouse_connect.driver import options, tzutil +from clickhouse_connect.driver.common import first_value, int_size, np_date_types, write_array from clickhouse_connect.driver.ctypes import data_conv +from clickhouse_connect.driver.exceptions import ProgrammingError from clickhouse_connect.driver.insert import InsertContext from clickhouse_connect.driver.query import QueryContext from clickhouse_connect.driver.types import ByteSource -from clickhouse_connect.driver import options epoch_start_date = date(1970, 1, 1) epoch_start_datetime = datetime(1970, 1, 1) class Date(ClickHouseType): - _array_type = 'H' - np_type = 'datetime64[D]' + _array_type = "H" + np_type = "datetime64[D]" nano_divisor = 86400 * 1000000000 - valid_formats = 'native', 'int' + valid_formats = "native", "int" python_type = date byte_size = 2 @property def pandas_dtype(self): - if options.IS_PANDAS_2 and get_setting("preserve_pandas_datetime_resolution"): - return "datetime64[s]" - return f"datetime64[{self.pd_datetime_res}]" + return "datetime64[s]" - def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state:Any): - if self.read_format(ctx) == 'int': + def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): + if self.read_format(ctx) == "int": return source.read_array(self._array_type, num_rows) if ctx.use_numpy: - return driver_ctypes.numpy_conv.read_numpy_array(source, '<u2', num_rows).astype(self.np_type) + return driver_ctypes.numpy_conv.read_numpy_array(source, "<u2", num_rows).astype(self.np_type) return data_conv.read_date_col(source, num_rows) - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): first = first_value(column, self.nullable) - if isinstance(first, int) or self.write_format(ctx) == 'int': + if isinstance(first, int) or self.write_format(ctx) == "int": if self.nullable: column = [x if x else 0 for x in column] else: @@ -63,27 +63,24 @@ class Date(ClickHouseType): def _active_null(self, ctx: QueryContext): fmt = self.read_format(ctx) if ctx.use_extended_dtypes: - return options.pd.NA if fmt == 'int' else options.pd.NaT + return options.pd.NA if fmt == "int" else options.pd.NaT if ctx.use_none: return None - if fmt == 'int': + if fmt == "int": return 0 if ctx.use_numpy: - return options.np.datetime64(0, self.pd_datetime_res) + return options.np.datetime64(0, self._null_time_unit) return epoch_start_date - # pylint: disable=too-many-return-statements def _finalize_column(self, column: Sequence, ctx: QueryContext) -> Sequence: - if self.read_format(ctx) == 'int': + if self.read_format(ctx) == "int": return column if ctx.use_numpy and self.nullable and not ctx.use_none: return options.np.array(column, dtype=self.np_type) if ctx.use_extended_dtypes: - if isinstance(column, options.np.ndarray) and options.np.issubdtype( - column.dtype, options.np.datetime64 - ): + if isinstance(column, options.np.ndarray) and options.np.issubdtype(column.dtype, options.np.datetime64): return column.astype(self.pandas_dtype) if isinstance(column, options.pd.DatetimeIndex): @@ -94,64 +91,54 @@ class Date(ClickHouseType): return naive.tz_localize("UTC").tz_convert(column.tz) if self.nullable and isinstance(column, list): - return options.np.array([None if options.pd.isna(s) else s for s in column]).astype( - self.pandas_dtype - ) + return options.np.array([None if options.pd.isna(s) else s for s in column]).astype(self.pandas_dtype) - return options.pd.to_datetime(column, errors="coerce").to_numpy( - dtype=self.pandas_dtype, copy=False - ) + return options.pd.to_datetime(column, errors="coerce").to_numpy(dtype=self.pandas_dtype, copy=False) return column class Date32(Date): byte_size = 4 - _array_type = 'l' if int_size == 2 else 'i' + _array_type = "l" if int_size == 2 else "i" def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any): if ctx.use_numpy: - return driver_ctypes.numpy_conv.read_numpy_array(source, '<i4', num_rows).astype(self.np_type) - if self.read_format(ctx) == 'int': + return driver_ctypes.numpy_conv.read_numpy_array(source, "<i4", num_rows).astype(self.np_type) + if self.read_format(ctx) == "int": return source.read_array(self._array_type, num_rows) return data_conv.read_date32_col(source, num_rows) class DateTimeBase(ClickHouseType, registered=False): - __slots__ = ('tzinfo',) - valid_formats = 'native', 'int' + __slots__ = ("tzinfo",) + valid_formats = "native", "int" python_type = datetime @property def pandas_dtype(self): """Sets dtype for pandas datetime objects""" - if options.IS_PANDAS_2 and get_setting("preserve_pandas_datetime_resolution"): - return "datetime64[s]" - return f"datetime64[{self.pd_datetime_res}]" + return "datetime64[s]" def _active_null(self, ctx: QueryContext): fmt = self.read_format(ctx) if ctx.use_extended_dtypes: - return options.pd.NA if fmt == 'int' else options.pd.NaT + return options.pd.NA if fmt == "int" else options.pd.NaT if ctx.use_none: return None - if self.read_format(ctx) == 'int': + if self.read_format(ctx) == "int": return 0 if ctx.use_numpy: - return options.np.datetime64(0, self.pd_datetime_res) + return options.np.datetime64(0, self._null_time_unit) return epoch_start_datetime def _finalize_column(self, column: Sequence, ctx: QueryContext) -> Sequence: - """Ensure every datetime-like column is at nanosecond resolution, preserving any tz.""" if ctx.use_extended_dtypes: - if isinstance(column, options.np.ndarray) and options.np.issubdtype( - column.dtype, options.np.datetime64 - ): + if isinstance(column, options.np.ndarray) and options.np.issubdtype(column.dtype, options.np.datetime64): return column.astype(self.pandas_dtype) if isinstance(column, options.pd.DatetimeIndex) or ( - isinstance(column, list) - and hasattr(next((s for s in column if not options.pd.isna(s)), None), "tz") + isinstance(column, list) and hasattr(next((s for s in column if not options.pd.isna(s)), None), "tz") ): if isinstance(column, list): column = options.pd.DatetimeIndex(column) @@ -162,20 +149,16 @@ class DateTimeBase(ClickHouseType, registered=False): naive_ns = column.tz_convert("UTC").tz_localize(None).astype(self.pandas_dtype) tz_aware_result = naive_ns.tz_localize("UTC").tz_convert(column.tz) - return ( - options.pd.array(tz_aware_result) if self.nullable else tz_aware_result - ) + return options.pd.array(tz_aware_result) if self.nullable else tz_aware_result if self.nullable: - return options.pd.array( - [None if options.pd.isna(s) else s for s in column], dtype=self.pandas_dtype - ) + return options.pd.array([None if options.pd.isna(s) else s for s in column], dtype=self.pandas_dtype) return column class DateTime(DateTimeBase): - _array_type = 'L' if int_size == 2 else 'I' - np_type = 'datetime64[s]' + _array_type = "L" if int_size == 2 else "I" + np_type = "datetime64[s]" nano_divisor = 1000000000 byte_size = 4 @@ -183,24 +166,28 @@ class DateTime(DateTimeBase): super().__init__(type_def) self._name_suffix = type_def.arg_str if len(type_def.values) > 0: - self.tzinfo = pytz.timezone(type_def.values[0][1:-1]) + tz_name = type_def.values[0][1:-1] + try: + self.tzinfo = tzutil.resolve_zone(tz_name) + except zoneinfo.ZoneInfoNotFoundError as ex: + raise ProgrammingError(f"Column timezone {tz_name} is not recognized; {tzutil.TZDATA_HINT}") from ex else: self.tzinfo = None def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any) -> Sequence: - if self.read_format(ctx) == 'int': + if self.read_format(ctx) == "int": return source.read_array(self._array_type, num_rows) active_tz = ctx.active_tz(self.tzinfo) if ctx.use_numpy: - np_array = driver_ctypes.numpy_conv.read_numpy_array(source, '<u4', num_rows).astype(self.np_type) + np_array = driver_ctypes.numpy_conv.read_numpy_array(source, "<u4", num_rows).astype(self.np_type) if ctx.as_pandas and active_tz: - return options.pd.DatetimeIndex(np_array, tz='UTC').tz_convert(active_tz) + return options.pd.DatetimeIndex(np_array, tz="UTC").tz_convert(active_tz) return np_array return data_conv.read_datetime_col(source, num_rows, active_tz) - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): first = first_value(column, self.nullable) - if isinstance(first, int) or self.write_format(ctx) == 'int': + if isinstance(first, int) or self.write_format(ctx) == "int": if self.nullable: column = [x if x else 0 for x in column] else: @@ -212,77 +199,67 @@ class DateTime(DateTimeBase): class DateTime64(DateTimeBase): - __slots__ = 'scale', 'prec', 'unit' + __slots__ = "scale", "prec", "unit" byte_size = 8 def __init__(self, type_def: TypeDef): super().__init__(type_def) self._name_suffix = type_def.arg_str self.scale = type_def.values[0] - self.prec = 10 ** self.scale + self.prec = 10**self.scale self.unit = np_date_types.get(self.scale) if len(type_def.values) > 1: - self.tzinfo = pytz.timezone(type_def.values[1][1:-1]) + tz_name = type_def.values[1][1:-1] + try: + self.tzinfo = tzutil.resolve_zone(tz_name) + except zoneinfo.ZoneInfoNotFoundError as ex: + raise ProgrammingError(f"Column timezone {tz_name} is not recognized; {tzutil.TZDATA_HINT}") from ex else: self.tzinfo = None @property def pandas_dtype(self): """Sets dtype for pandas datetime objects""" - if options.IS_PANDAS_2 and get_setting("preserve_pandas_datetime_resolution"): - return f"datetime64{self.unit}" - return f"datetime64[{self.pd_datetime_res}]" + return f"datetime64{self.unit}" @property def np_type(self): if self.unit: - return f'datetime64{self.unit}' - raise ProgrammingError(f'Cannot use {self.name} as a numpy or Pandas datatype. Only milliseconds(3), ' + - 'microseconds(6), or nanoseconds(9) are supported for numpy based queries.') + return f"datetime64{self.unit}" + raise ProgrammingError( + f"Cannot use {self.name} as a numpy or Pandas datatype. Only milliseconds(3), " + + "microseconds(6), or nanoseconds(9) are supported for numpy based queries." + ) @property def nano_divisor(self): return 1000000000 // self.prec def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any) -> Sequence: - if self.read_format(ctx) == 'int': - return source.read_array('q', num_rows) + if self.read_format(ctx) == "int": + return source.read_array("q", num_rows) active_tz = ctx.active_tz(self.tzinfo) if ctx.use_numpy: np_array = driver_ctypes.numpy_conv.read_numpy_array(source, self.np_type, num_rows) if ctx.as_pandas and active_tz: - return options.pd.DatetimeIndex(np_array, tz='UTC').tz_convert(active_tz) + return options.pd.DatetimeIndex(np_array, tz="UTC").tz_convert(active_tz) return np_array - column = source.read_array('q', num_rows) + column = source.read_array("q", num_rows) if active_tz: return self._read_binary_tz(column, active_tz) return self._read_binary_naive(column) def _read_binary_tz(self, column: Sequence, tz_info: tzinfo): - new_col = [] - app = new_col.append - dt_from = datetime.fromtimestamp - prec = self.prec - for ticks in column: - seconds = ticks // prec - dt_sec = dt_from(seconds, tz_info) - app(dt_sec.replace(microsecond=((ticks - seconds * prec) * 1000000) // prec)) - return new_col + if tzutil.is_utc_timezone(tz_info): + return data_conv.read_datetime64_naive_col(column, self.prec, tz_info) + return data_conv.read_datetime64_tz_col(column, self.prec, tz_info) def _read_binary_naive(self, column: Sequence): - new_col = [] - app = new_col.append - dt_from = tzutil.utcfromtimestamp - prec = self.prec - for ticks in column: - seconds = ticks // prec - dt_sec = dt_from(seconds) - app(dt_sec.replace(microsecond=((ticks - seconds * prec) * 1000000) // prec)) - return new_col + return data_conv.read_datetime64_naive_col(column, self.prec) - def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext): + def _write_column_binary(self, column: Sequence | MutableSequence, dest: bytearray, ctx: InsertContext): first = first_value(column, self.nullable) - if isinstance(first, int) or self.write_format(ctx) == 'int': + if isinstance(first, int) or self.write_format(ctx) == "int": if self.nullable: column = [x if x else 0 for x in column] elif isinstance(first, str): @@ -300,11 +277,10 @@ class DateTime64(DateTimeBase): else: prec = self.prec if self.nullable: - column = [((int(x.timestamp()) * 1000000 + x.microsecond) * prec) // 1000000 if x else 0 - for x in column] + column = [((int(x.timestamp()) * 1000000 + x.microsecond) * prec) // 1000000 if x else 0 for x in column] else: column = [((int(x.timestamp()) * 1000000 + x.microsecond) * prec) // 1000000 for x in column] - write_array('q', column, dest, ctx.column_name) + write_array("q", column, dest, ctx.column_name) class _HMSParts(NamedTuple): @@ -313,7 +289,7 @@ class _HMSParts(NamedTuple): hours: int minutes: int seconds: int - frac: Optional[str] + frac: str | None is_negative: bool @@ -369,9 +345,7 @@ class TimeBase(ClickHouseType, registered=False): fmt = self.read_format(ctx) if ctx.use_numpy: - return options.np.array( - [self._ticks_to_np_timedelta(t) for t in ticks], dtype=self.np_type - ) + return options.np.array([self._ticks_to_np_timedelta(t) for t in ticks], dtype=self.np_type) if fmt == "int": return ticks @@ -405,17 +379,11 @@ class TimeBase(ClickHouseType, registered=False): seconds = int(match["seconds"]) if hours > 999: - raise ValueError( - f"Hours out of range; cannot exceed 999: got {hours} in '{time_str}'" - ) + raise ValueError(f"Hours out of range; cannot exceed 999: got {hours} in '{time_str}'") if not 0 <= minutes < 60: - raise ValueError( - f"Minutes out of range; must be 0-59: got {minutes} in '{time_str}'" - ) + raise ValueError(f"Minutes out of range; must be 0-59: got {minutes} in '{time_str}'") if not 0 <= seconds < 60: - raise ValueError( - f"Seconds out of range; must be 0-59: got {seconds} in '{time_str}'" - ) + raise ValueError(f"Seconds out of range; must be 0-59: got {seconds} in '{time_str}'") return _HMSParts( hours=hours, @@ -466,11 +434,9 @@ class TimeBase(ClickHouseType, registered=False): def _validate_time_obj_range(self, ticks: int) -> None: """Ensure ticks can form a valid datetime.time object.""" if not self.min_time_ticks <= ticks <= self.max_time_ticks: - raise ValueError( - f"Ticks value {ticks} is outside valid range for datetime.time object." - ) + raise ValueError(f"Ticks value {ticks} is outside valid range for datetime.time object.") - def _numerical_to_ticks(self, value: Union[int, float, "np.int64"]) -> int: + def _numerical_to_ticks(self, value: int | float | numpy.int64) -> int: """Convert numerical value to ticks, with range validation.""" value = int(value) self._validate_standard_range(value, value) @@ -494,26 +460,20 @@ class TimeBase(ClickHouseType, registered=False): @property def pandas_dtype(self): - """Sets dtype for pandas datetime objects""" - if options.IS_PANDAS_2 and get_setting("preserve_pandas_datetime_resolution"): - return "timedelta64[s]" - return f"timedelta64[{self.pd_datetime_res}]" + """Sets dtype for pandas timedelta objects""" + return "timedelta64[s]" def _finalize_column(self, column: Sequence, ctx: QueryContext) -> Sequence: """Finalize column data based on context requirements.""" if ctx.use_extended_dtypes: - if isinstance(column, options.np.ndarray) and options.np.issubdtype( - column.dtype, options.np.timedelta64 - ): + if isinstance(column, options.np.ndarray) and options.np.issubdtype(column.dtype, options.np.timedelta64): return column.astype(self.pandas_dtype) if isinstance(column, options.pd.TimedeltaIndex): return column.astype(self.pandas_dtype) if self.nullable: - return options.np.array([None if options.pd.isna(s) else s for s in column]).astype( - self.pandas_dtype - ) + return options.np.array([None if options.pd.isna(s) else s for s in column]).astype(self.pandas_dtype) return column def _build_lc_column(self, index: Sequence, keys: array.array, ctx: QueryContext): @@ -529,7 +489,7 @@ class TimeBase(ClickHouseType, registered=False): raise NotImplementedError @abstractmethod - def _timedelta_to_ticks(self, td: Union[timedelta, "np.timedelta64"]) -> int: + def _timedelta_to_ticks(self, td: timedelta | numpy.timedelta64) -> int: """Convert a timedelta into integer ticks.""" raise NotImplementedError @@ -621,7 +581,7 @@ class Time(TimeBase): return f"{sign}{h:03d}:{m:02d}:{s:02d}" - def _timedelta_to_ticks(self, td: Union[timedelta, "np.timedelta64"]) -> int: + def _timedelta_to_ticks(self, td: timedelta | numpy.timedelta64) -> int: """Convert timedelta to ticks (seconds), flooring fractional seconds.""" if isinstance(td, timedelta): total = int(td.total_seconds()) @@ -664,19 +624,14 @@ class Time64(TimeBase): self._name_suffix = type_def.arg_str self.scale = type_def.values[0] if self.scale not in (3, 6, 9): - raise ProgrammingError( - f"Unsupported Time64 scale {self.scale}; " - "only 3, 6, or 9 are allowed for NumPy." - ) + raise ProgrammingError(f"Unsupported Time64 scale {self.scale}; only 3, 6, or 9 are allowed for NumPy.") self.precision = 10**self.scale self.unit = np_date_types.get(self.scale) @property def pandas_dtype(self): - """Sets dtype for pandas datetime objects""" - if options.IS_PANDAS_2 and get_setting("preserve_pandas_datetime_resolution"): - return f"timedelta64{self.unit}" - return f"timedelta64[{self.pd_datetime_res}]" + """Sets dtype for pandas timedelta objects""" + return f"timedelta64{self.unit}" @property def max_time_ticks(self) -> int: @@ -698,9 +653,7 @@ class Time64(TimeBase): """Parse string format 'HHH:MM:SS[.fff]' to ticks with sub-second precision.""" parts = self._parse_core(time_str) frac_ticks = int((parts.frac or "").ljust(self.scale, "0")[: self.scale]) - ticks = ( - parts.hours * 3600 + parts.minutes * 60 + parts.seconds - ) * self.precision + frac_ticks + ticks = (parts.hours * 3600 + parts.minutes * 60 + parts.seconds) * self.precision + frac_ticks if parts.is_negative: ticks = -ticks self._validate_standard_range(ticks, time_str) @@ -718,12 +671,10 @@ class Time64(TimeBase): return f"{sign}{h:03d}:{m:02d}:{s:02d}{frac_str}" - def _timedelta_to_ticks(self, td: Union[timedelta, "np.timedelta64"]) -> int: + def _timedelta_to_ticks(self, td: timedelta | numpy.timedelta64) -> int: """Convert timedelta to ticks with sub-second precision.""" if isinstance(td, timedelta): - total_us = ( - int(td.total_seconds()) * self._MICROS_PER_SECOND + td.microseconds - ) + total_us = int(td.total_seconds()) * self._MICROS_PER_SECOND + td.microseconds ticks = (total_us * self.precision) // self._MICROS_PER_SECOND else: ticks = td.astype("timedelta64[s]").astype(int) @@ -742,7 +693,7 @@ class Time64(TimeBase): return -td if neg else td - def _ticks_to_np_timedelta(self, ticks: int) -> "np.timedelta64": + def _ticks_to_np_timedelta(self, ticks: int) -> numpy.timedelta64: """Convert ticks to numpy timedelta64 with nanosecond precision.""" res_map = {3: "ms", 6: "us", 9: "ns"} @@ -750,9 +701,7 @@ class Time64(TimeBase): def _time_to_ticks(self, t: time) -> int: """Convert time to ticks with sub-second precision.""" - total_us = ( - t.hour * 3600 + t.minute * 60 + t.second - ) * self._MICROS_PER_SECOND + t.microsecond + total_us = (t.hour * 3600 + t.minute * 60 + t.second) * self._MICROS_PER_SECOND + t.microsecond ticks = (total_us * self.precision) // self._MICROS_PER_SECOND self._validate_time_obj_range(ticks) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/vector.py b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/vector.py index a495985c86b..0d666941f74 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/vector.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/datatypes/vector.py @@ -1,13 +1,19 @@ +from __future__ import annotations + import logging +from collections.abc import Sequence from math import ceil, nan from struct import pack, unpack -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import numpy from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef from clickhouse_connect.datatypes.registry import get_from_name +from clickhouse_connect.driver import options from clickhouse_connect.driver.ctypes import data_conv from clickhouse_connect.driver.insert import InsertContext -from clickhouse_connect.driver import options from clickhouse_connect.driver.query import QueryContext from clickhouse_connect.driver.types import ByteSource @@ -178,11 +184,11 @@ class QBit(ClickHouseType): planes_uint8 = planes_uint8.reshape(self._bits_per_element, -1) # 2. Unpack bits to get the boolean/integer matrix - bits_matrix: "np.ndarray" = options.np.unpackbits(planes_uint8, axis=1, bitorder="little") + bits_matrix = options.np.unpackbits(planes_uint8, axis=1, bitorder="little") # 3. Trim padding if necessary - if bits_matrix.shape[1] != self.dimension: # pylint: disable=no-member - bits_matrix = bits_matrix[:, : self.dimension] # pylint: disable=invalid-sequence-index + if bits_matrix.shape[1] != self.dimension: + bits_matrix = bits_matrix[:, : self.dimension] # 4. Reconstruct the integer words if self.element_type == "Float64": @@ -246,7 +252,7 @@ class QBit(ClickHouseType): return tuple(bit_planes) - def _transpose_row_numpy(self, vector: "np.ndarray") -> tuple: + def _transpose_row_numpy(self, vector: numpy.ndarray) -> tuple: """Fast path for numpy arrays using vectorized operations.""" # Cast to int view if self.element_type == "BFloat16": diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/__init__.py b/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/__init__.py index ea792b49683..84a9a8c3fcc 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/__init__.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/__init__.py @@ -1,28 +1,29 @@ -from typing import Optional - from clickhouse_connect.dbapi.connection import Connection - -apilevel = '2.0' # PEP 249 DB API level -threadsafety = 2 # PEP 249 Threads may share the module and connections. -paramstyle = 'pyformat' # PEP 249 Python extended format codes, e.g. ...WHERE name=%(name)s +apilevel = "2.0" # PEP 249 DB API level +threadsafety = 2 # PEP 249 Threads may share the module and connections. +paramstyle = "pyformat" # PEP 249 Python extended format codes, e.g. ...WHERE name=%(name)s class Error(Exception): pass -def connect(host: Optional[str] = None, - database: Optional[str] = None, - username: Optional[str] = '', - password: Optional[str] = '', - port: Optional[int] = None, - **kwargs): - secure = kwargs.pop('secure', False) - return Connection(host=host, - database=database, - username=username, - password=password, - port=port, - secure=secure, - **kwargs) +def connect( + host: str | None = None, + database: str | None = None, + username: str | None = "", + password: str | None = "", + port: int | None = None, + **kwargs, +): + secure = kwargs.pop("secure", False) + return Connection( + host=host, + database=database, + username=username, + password=password, + port=port, + secure=secure, + **kwargs, + ) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/connection.py b/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/connection.py index 2c5bff6741e..3dba3c54aef 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/connection.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/connection.py @@ -1,5 +1,3 @@ -from typing import Union - from clickhouse_connect.dbapi.cursor import Cursor from clickhouse_connect.driver import create_client from clickhouse_connect.driver.query import QueryResult @@ -9,26 +7,30 @@ class Connection: """ See :ref:`https://peps.python.org/pep-0249/` """ - # pylint: disable=too-many-arguments - def __init__(self, - dsn: str = None, - username: str = '', - password: str = '', - host: str = None, - database: str = None, - interface: str = None, - port: int = 0, - secure: Union[bool, str] = False, - **kwargs): - self.client = create_client(host=host, - username=username, - password=password, - database=database, - interface=interface, - port=port, - secure=secure, - dsn=dsn, - generic_args=kwargs) + + def __init__( + self, + dsn: str = None, + username: str = "", + password: str = "", + host: str = None, + database: str = None, + interface: str = None, + port: int = 0, + secure: bool | str = False, + **kwargs, + ): + self.client = create_client( + host=host, + username=username, + password=password, + database=database, + interface=interface, + port=port, + secure=secure, + dsn=dsn, + generic_args=kwargs, + ) self.client._add_integration_tag("sqlalchemy") self.timezone = self.client.server_tz diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/cursor.py b/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/cursor.py index 279c67002eb..9ad0431d6ed 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/cursor.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/dbapi/cursor.py @@ -1,23 +1,21 @@ import logging import re - -from typing import Optional, Sequence, List, Dict +from collections.abc import Sequence from clickhouse_connect.datatypes.registry import get_from_name +from clickhouse_connect.driver import Client from clickhouse_connect.driver.common import unescape_identifier from clickhouse_connect.driver.exceptions import ProgrammingError -from clickhouse_connect.driver import Client from clickhouse_connect.driver.parser import parse_callable from clickhouse_connect.driver.query import remove_sql_comments logger = logging.getLogger(__name__) -insert_re = re.compile(r'^\s*INSERT\s+INTO\s+(.*$)', re.IGNORECASE) -str_type = get_from_name('String') -int_type = get_from_name('Int32') +insert_re = re.compile(r"^\s*INSERT\s+INTO\s+(.*$)", re.IGNORECASE) +str_type = get_from_name("String") +int_type = get_from_name("Int32") -# pylint: disable=too-many-instance-attributes class Cursor: """ See :ref:`https://peps.python.org/pep-0249/` @@ -26,16 +24,16 @@ class Cursor: def __init__(self, client: Client): self.client = client self.arraysize = 1 - self.data: Optional[Sequence] = None + self.data: Sequence | None = None self.names = [] self.types = [] self._rowcount = 0 - self._summary: List[Dict[str, str]] = [] + self._summary: list[dict[str, str]] = [] self._ix = 0 def check_valid(self): if self.data is None: - raise ProgrammingError('Cursor is not valid') + raise ProgrammingError("Cursor is not valid") @property def description(self): @@ -46,7 +44,7 @@ class Cursor: return self._rowcount @property - def summary(self) -> List[Dict[str, str]]: + def summary(self) -> list[dict[str, str]]: return self._summary def close(self): @@ -59,7 +57,7 @@ class Cursor: # parameters, Python's % operator in finalize_query handles the # unescaping automatically. When there are no parameters, # finalize_query short-circuits, so we must unescape here. - operation = operation.replace('%%', '%') + operation = operation.replace("%%", "%") query_result = self.client.query(operation, parameters) self.data = query_result.result_set self._rowcount = len(self.data) @@ -72,7 +70,7 @@ class Cursor: self.names = query_result.column_names self.types = [x.name for x in query_result.column_types] elif self.data: - self.names = [f'col_{x}' for x in range(len(self.data[0]))] + self.names = [f"col_{x}" for x in range(len(self.data[0]))] self.types = [x.__class__ for x in self.data[0]] else: stripped = operation.strip().rstrip(";").strip() @@ -87,14 +85,14 @@ class Cursor: if not match: return False temp = match.group(1) - table_end = min(temp.find(' '), temp.find('(')) + table_end = min(temp.find(" "), temp.find("(")) table = temp[:table_end].strip() temp = temp[table_end:].strip() - if temp[0] == '(': + if temp[0] == "(": _, op_columns, temp = parse_callable(temp) else: op_columns = None - if 'VALUES' not in temp.upper(): + if "VALUES" not in temp.upper(): return False col_names = list(data[0].keys()) if op_columns and {unescape_identifier(x) for x in op_columns} != set(col_names): @@ -114,14 +112,18 @@ class Cursor: self.data.extend(query_result.result_set) if self.names or self.types: if query_result.column_names != self.names: - logger.warning('Inconsistent column names %s : %s for operation %s in cursor executemany', - self.names, query_result.column_names, operation) + logger.warning( + "Inconsistent column names %s : %s for operation %s in cursor executemany", + self.names, + query_result.column_names, + operation, + ) else: self.names = query_result.column_names self.types = query_result.column_types self._summary.append(query_result.summary) except TypeError as ex: - raise ProgrammingError(f'Invalid parameters {parameters} passed to cursor executemany') from ex + raise ProgrammingError(f"Invalid parameters {parameters} passed to cursor executemany") from ex self._rowcount = len(self.data) # Need to reset cursor _ix after performing an execute @@ -129,7 +131,7 @@ class Cursor: def fetchall(self): self.check_valid() - ret = self.data[self._ix:] + ret = self.data[self._ix :] self._ix = self._rowcount return ret @@ -152,7 +154,7 @@ class Cursor: return [] end = min(self._ix + size, self._rowcount) - ret = self.data[self._ix: end] + ret = self.data[self._ix : end] self._ix = end return ret diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/__init__.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/__init__.py index aac2f85e2b8..d0fe9f1d8f6 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/__init__.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/__init__.py @@ -1,34 +1,101 @@ -import asyncio -import warnings -from concurrent.futures import ThreadPoolExecutor +from __future__ import annotations + from inspect import signature -from typing import Optional, Union, Dict, Any -from urllib.parse import urlparse, parse_qs +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qs, urlparse -import clickhouse_connect.driver.ctypes +import clickhouse_connect.driver.ctypes # noqa: F401 -- side-effect import from clickhouse_connect.driver.client import Client -from clickhouse_connect.driver.common import dict_copy from clickhouse_connect.driver.exceptions import ProgrammingError from clickhouse_connect.driver.httpclient import HttpClient -from clickhouse_connect.driver.asyncclient import AsyncClient, DefaultThreadPoolExecutor, NEW_THREAD_POOL_EXECUTOR -__all__ = ['Client', 'AsyncClient', 'create_client', 'create_async_client'] +if TYPE_CHECKING: + from clickhouse_connect.driver.asyncclient import AsyncClient + +__all__ = ["Client", "AsyncClient", "create_client", "create_async_client"] + + +def __getattr__(name): + if name == "AsyncClient": + try: + from clickhouse_connect.driver.asyncclient import AsyncClient + except ModuleNotFoundError as ex: + if ex.name == "aiohttp" or (ex.name and ex.name.startswith("aiohttp.")): + raise ImportError("Async support requires aiohttp. Install with: pip install clickhouse-connect[async]") from ex + raise + return AsyncClient + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def default_port(interface: str, secure: bool) -> int: + """Get default port for the given interface.""" + if interface.startswith("http"): + return 8443 if secure else 8123 + raise ValueError("Unrecognized ClickHouse interface") -# pylint: disable=too-many-arguments,too-many-locals,too-many-branches -def create_client(*, - host: Optional[str] = None, - username: Optional[str] = None, - password: str = '', - access_token: Optional[str] = None, - database: str = '__default__', - interface: Optional[str] = None, - port: int = 0, - secure: Union[bool, str] = False, - dsn: Optional[str] = None, - settings: Optional[Dict[str, Any]] = None, - generic_args: Optional[Dict[str, Any]] = None, - **kwargs) -> Client: +def _parse_connection_params( + host: str | None, + username: str | None, + password: str, + port: int, + database: str, + interface: str | None, + secure: bool | str, + dsn: str | None, + kwargs: dict[str, Any], +) -> tuple[str, str | None, str, int, str, str]: + """Parse and normalize connection parameters including DSN parsing.""" + if dsn: + parsed = urlparse(dsn) + username = username or parsed.username + password = password or parsed.password + host = host or parsed.hostname + port = port or parsed.port + if parsed.path and (not database or database == "__default__"): + database = parsed.path[1:].split("/")[0] + database = database or parsed.path + for k, v in parse_qs(parsed.query).items(): + kwargs[k] = v[0] + use_tls = str(secure).lower() == "true" or interface == "https" or (not interface and str(port) in ("443", "8443")) + if not host: + host = "localhost" + if not interface: + interface = "https" if use_tls else "http" + port = port or default_port(interface, use_tls) + if username is None and "user" in kwargs: + username = kwargs.pop("user") + if username is None and "user_name" in kwargs: + username = kwargs.pop("user_name") + if password and username is None: + username = "default" + if "compression" in kwargs and "compress" not in kwargs: + kwargs["compress"] = kwargs.pop("compression") + + return host, username, password, port, database, interface + + +def _validate_access_token(access_token: str | None, username: str | None, password: str) -> None: + """Validate that access_token and username/password are not both provided.""" + if access_token and (username or password != ""): + raise ProgrammingError("Cannot use both access_token and username/password") + + +def create_client( + *, + host: str | None = None, + username: str | None = None, + password: str = "", + access_token: str | None = None, + database: str = "__default__", + interface: str | None = None, + port: int = 0, + secure: bool | str = False, + dsn: str | None = None, + settings: dict[str, Any] | None = None, + generic_args: dict[str, Any] | None = None, + **kwargs, +) -> Client: """ The preferred method to get a ClickHouse Connect Client instance @@ -82,86 +149,61 @@ def create_client(*, :param tz_source Controls how the client determines the fallback timezone for DateTime columns without an explicit timezone. "auto" (default) auto-detects based on DST safety of server timezone. "server" always uses the server timezone. "local" always uses the local timezone. - :param apply_server_timezone Deprecated. Use tz_source instead. :param tz_mode Controls timezone-aware behavior for UTC DateTime columns. "naive_utc" (default) returns naive UTC timestamps. "aware" forces timezone-aware UTC datetimes. "schema" returns datetimes that match the server's column definition which means timezone-aware when the column defines a timezone and naive for bare DateTime columns. - :param utc_tz_aware Deprecated. Use tz_mode instead. :param autogenerate_session_id If set, this will override the 'autogenerate_session_id' common setting. :param form_encode_query_params If True, query parameters will be sent as form-encoded data in the request body instead of as URL parameters. This is useful for queries with large parameter sets that might exceed URL length limits. Only available for query operations (not inserts). Default: False :return: ClickHouse Connect Client instance """ - if dsn: - parsed = urlparse(dsn) - username = username or parsed.username - password = password or parsed.password - host = host or parsed.hostname - port = port or parsed.port - if parsed.path and (not database or database == '__default__'): - database = parsed.path[1:].split('/')[0] - database = database or parsed.path - for k, v in parse_qs(parsed.query).items(): - kwargs[k] = v[0] - use_tls = str(secure).lower() == 'true' or interface == 'https' or (not interface and str(port) in ('443', '8443')) - if not host: - host = 'localhost' - if not interface: - interface = 'https' if use_tls else 'http' - port = port or default_port(interface, use_tls) - if access_token and (username or password != ''): - raise ProgrammingError('Cannot use both access_token and username/password') - if username is None and 'user' in kwargs: - username = kwargs.pop('user') - if username is None and 'user_name' in kwargs: - username = kwargs.pop('user_name') - if password and username is None: - username = 'default' - if 'compression' in kwargs and 'compress' not in kwargs: - kwargs['compress'] = kwargs.pop('compression') + host, username, password, port, database, interface = _parse_connection_params( + host, username, password, port, database, interface, secure, dsn, kwargs + ) + _validate_access_token(access_token, username, password) + settings = settings or {} - if interface.startswith('http'): + if interface.startswith("http"): if generic_args: client_params = signature(HttpClient).parameters for name, value in generic_args.items(): if name in client_params: kwargs[name] = value - elif name == 'compression': - if 'compress' not in kwargs: - kwargs['compress'] = value + elif name == "compression": + if "compress" not in kwargs: + kwargs["compress"] = value else: - if name.startswith('ch_'): + if name.startswith("ch_"): name = name[3:] settings[name] = value - return HttpClient(interface, host, port, username, password, database, access_token, - settings=settings, **kwargs) - raise ProgrammingError(f'Unrecognized client type {interface}') + return HttpClient(interface, host, port, username, password, database, access_token, settings=settings, **kwargs) + raise ProgrammingError(f"Unrecognized client type {interface}") -def default_port(interface: str, secure: bool): - if interface.startswith('http'): - return 8443 if secure else 8123 - raise ValueError('Unrecognized ClickHouse interface') - - -async def create_async_client(*, - host: Optional[str] = None, - username: Optional[str] = None, - password: str = '', - database: str = '__default__', - interface: Optional[str] = None, - port: int = 0, - secure: Union[bool, str] = False, - dsn: Optional[str] = None, - settings: Optional[Dict[str, Any]] = None, - generic_args: Optional[Dict[str, Any]] = None, - executor_threads: int = 0, - executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR, - **kwargs) -> AsyncClient: +async def create_async_client( + *, + host: str | None = None, + username: str | None = None, + password: str = "", + access_token: str | None = None, + database: str = "__default__", + interface: str | None = None, + port: int = 0, + secure: bool | str = False, + dsn: str | None = None, + settings: dict[str, Any] | None = None, + generic_args: dict[str, Any] | None = None, + connector_limit: int = 100, + connector_limit_per_host: int = 20, + keepalive_timeout: float = 30.0, + **kwargs, +) -> AsyncClient: """ The preferred method to get an async ClickHouse Connect Client instance. + Requires the async extra: pip install clickhouse-connect[async] + For sync version, see create_client. Unlike sync version, the 'autogenerate_session_id' setting by default is False. @@ -169,6 +211,7 @@ async def create_async_client(*, :param host: The hostname or IP address of the ClickHouse server. If not set, localhost will be used. :param username: The ClickHouse username. If not set, the default ClickHouse user will be used. :param password: The password for username. + :param access_token: JWT access token. :param database: The default database for the connection. If not set, ClickHouse Connect will use the default database for username. :param interface: Must be http or https. Defaults to http, or to https if port is set to 8443 or 443 @@ -180,12 +223,10 @@ async def create_async_client(*, :param settings: ClickHouse server settings to be used with the session/every request :param generic_args: Used internally to parse DBAPI connection strings into keyword arguments and ClickHouse settings. It is not recommended to use this parameter externally - :param executor_threads: 'max_worker' threads used by the client ThreadPoolExecutor. If not set, the default - of 4 + detected CPU cores will be used - :param executor: Optional `ThreadPoolExecutor` to use for async operations. If not set, a new `ThreadPoolExecutor` - will be created with the number of threads specified by `executor_threads`. If set to `None` it will use the - default executor of the event loop. - :param kwargs -- Recognized keyword arguments (used by the HTTP client), see below + :param connector_limit: Maximum number of allowable connections to the server + :param connector_limit_per_host: Maximum number of connections per host + :param keepalive_timeout: Time limit on idle keepalive connections + :param kwargs -- Recognized keyword arguments (used by the async HTTP client), see below :param compress: Enable compression for ClickHouse HTTP inserts and query results. True will select the preferred compression method (lz4). A str of 'lz4', 'zstd', 'brotli', or 'gzip' can be used to use a specific compression type @@ -194,7 +235,6 @@ async def create_async_client(*, :param send_receive_timeout: Read timeout in seconds for http connection :param client_name: client_name prepended to the HTTP User Agent header. Set this to track client queries in the ClickHouse system.query_log. - :param send_progress: Deprecated, has no effect. Previous functionality is now automatically determined :param verify: Verify the server certificate in secure/https mode :param ca_cert: If verify is True, the file path to Certificate Authority root to validate ClickHouse server certificate, in .pem format. Ignored if verify is False. This is not necessary if the ClickHouse server @@ -206,8 +246,6 @@ async def create_async_client(*, is not included the Client Certificate key file :param session_id ClickHouse session id. If not specified and the common setting 'autogenerate_session_id' is True, the client will generate a UUID1 session id - :param pool_mgr Optional urllib3 PoolManager for this client. Useful for creating separate connection - pools for multiple client endpoints for applications with many clients :param http_proxy http proxy address. Equivalent to setting the HTTP_PROXY environment variable :param https_proxy https proxy address. Equivalent to setting the HTTPS_PROXY environment variable :param server_host_name This is the server host name that will be checked against a TLS certificate for @@ -216,37 +254,64 @@ async def create_async_client(*, :param tz_source Controls how the client determines the fallback timezone for DateTime columns without an explicit timezone. "auto" (default) auto-detects based on DST safety of server timezone. "server" always uses the server timezone. "local" always uses the local timezone. - :param apply_server_timezone Deprecated. Use tz_source instead. :param tz_mode Controls timezone-aware behavior for UTC DateTime columns. "naive_utc" (default) returns naive UTC timestamps. "aware" forces timezone-aware UTC datetimes. "schema" returns datetimes that match the server's column definition which means timezone-aware when the column defines a timezone and naive for bare DateTime columns. - :param utc_tz_aware Deprecated. Use tz_mode instead. :param autogenerate_session_id If set, this will override the 'autogenerate_session_id' common setting. :param form_encode_query_params If True, query parameters will be sent as form-encoded data in the request body instead of as URL parameters. This is useful for queries with large parameter sets that might exceed URL length limits. Only available for query operations (not inserts). Default: False - :return: ClickHouse Connect Client instance + :return: ClickHouse Connect AsyncClient instance """ + try: + from clickhouse_connect.driver.asyncclient import AsyncClient as _AsyncClient + except ModuleNotFoundError as ex: + if ex.name == "aiohttp" or (ex.name and ex.name.startswith("aiohttp.")): + raise ImportError("Async support requires aiohttp. Install with: pip install clickhouse-connect[async]") from ex + raise + + if "pool_mgr" in kwargs: + raise ProgrammingError( + "pool_mgr is not supported by the async client. " + "Use connector_limit and connector_limit_per_host to configure connection pooling." + ) - warnings.warn( - "The current async client is a thread-pool wrapper around the sync client. " - "A fully native async client is available for testing as a prerelease: " - "pip install 'clickhouse-connect[async]==0.12.0rc1'. " - "This prerelease branch is based on 0.11.0 and is gathering feedback ahead of 1.0.0, " - "where it will become the default async implementation. It is a drop-in replacement " - "with the same API surface. The main line includes additional updates that the native " - "client will receive when merged into 1.0.0.", - FutureWarning, - stacklevel=2, + host, username, password, port, database, interface = _parse_connection_params( + host, username, password, port, database, interface, secure, dsn, kwargs ) + _validate_access_token(access_token, username, password) - def _create_client(): - if 'autogenerate_session_id' not in kwargs: - kwargs['autogenerate_session_id'] = False - return create_client(host=host, username=username, password=password, database=database, interface=interface, - port=port, secure=secure, dsn=dsn, settings=settings, generic_args=generic_args, **kwargs) + settings = settings or {} + if generic_args: + client_params = signature(_AsyncClient).parameters + for name, value in generic_args.items(): + if name in client_params: + kwargs[name] = value + elif name == "compression": + if "compress" not in kwargs: + kwargs["compress"] = value + else: + if name.startswith("ch_"): + name = name[3:] + settings[name] = value + + if "autogenerate_session_id" not in kwargs: + kwargs["autogenerate_session_id"] = False - loop = asyncio.get_running_loop() - _client = await loop.run_in_executor(None, _create_client) - return AsyncClient(client=_client, executor_threads=executor_threads, executor=executor) + client = _AsyncClient( + interface=interface, + host=host, + port=port, + username=username, + password=password, + database=database, + access_token=access_token, + settings=settings, + connector_limit=connector_limit, + connector_limit_per_host=connector_limit_per_host, + keepalive_timeout=keepalive_timeout, + **kwargs, + ) + await client._initialize() + return client diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/asyncclient.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/asyncclient.py index 3f046042c96..7e8c80187e7 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/asyncclient.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/asyncclient.py @@ -1,404 +1,1230 @@ +from __future__ import annotations + import asyncio +import gzip import io +import json import logging -import os -from concurrent.futures.thread import ThreadPoolExecutor -from datetime import tzinfo -from typing import Literal, Optional, Union, Dict, Any, Sequence, Iterable, Generator, BinaryIO, TYPE_CHECKING +import re +import ssl +import sys +import time +import uuid +import zlib +import zoneinfo +from base64 import b64encode +from collections.abc import Awaitable, Callable, Generator, Iterable, Sequence +from datetime import timezone, tzinfo +from importlib import import_module +from importlib.metadata import version as dist_version +from typing import TYPE_CHECKING, Any, BinaryIO -from clickhouse_connect.driver.client import Client -from clickhouse_connect.driver.query import TzMode -from clickhouse_connect.driver.common import StreamContext -from clickhouse_connect.driver.httpclient import HttpClient -from clickhouse_connect.driver.external import ExternalData -from clickhouse_connect.driver.query import QueryContext, QueryResult -from clickhouse_connect.driver.summary import QuerySummary -from clickhouse_connect.datatypes.base import ClickHouseType -from clickhouse_connect.driver.insert import InsertContext +import aiohttp +import lz4.frame +import zstandard if TYPE_CHECKING: - import numpy import pandas + import polars import pyarrow +from clickhouse_connect import common +from clickhouse_connect.datatypes import dynamic as dynamic_module +from clickhouse_connect.datatypes.base import ClickHouseType +from clickhouse_connect.datatypes.registry import get_from_name +from clickhouse_connect.driver import httputil, options, tzutil +from clickhouse_connect.driver.asyncqueue import EOF_SENTINEL, AsyncSyncQueue +from clickhouse_connect.driver.binding import bind_query, quote_identifier +from clickhouse_connect.driver.client import Client, _apply_arrow_tz_policy +from clickhouse_connect.driver.common import StreamContext, coerce_bool, dict_copy +from clickhouse_connect.driver.compression import available_compression +from clickhouse_connect.driver.constants import CH_VERSION_WITH_PROTOCOL, PROTOCOL_VERSION_WITH_LOW_CARD +from clickhouse_connect.driver.ctypes import RespBuffCls +from clickhouse_connect.driver.exceptions import DatabaseError, DataError, OperationalError, ProgrammingError +from clickhouse_connect.driver.external import ExternalData +from clickhouse_connect.driver.insert import InsertContext +from clickhouse_connect.driver.models import ColumnDef, SettingDef +from clickhouse_connect.driver.options import check_arrow, check_numpy, check_pandas, check_polars +from clickhouse_connect.driver.query import QueryContext, QueryResult, TzMode, TzSource, arrow_buffer +from clickhouse_connect.driver.streaming import StreamingFileAdapter, StreamingInsertSource, StreamingResponseSource +from clickhouse_connect.driver.summary import QuerySummary +from clickhouse_connect.driver.transform import NativeTransform + logger = logging.getLogger(__name__) +columns_only_re = re.compile(r"LIMIT 0\s*$", re.IGNORECASE) +ex_header = "X-ClickHouse-Exception-Code" +ex_tag_header = "X-ClickHouse-Exception-Tag" +if "br" in available_compression: + import brotli +else: + brotli = None -class DefaultThreadPoolExecutor: - pass +def decompress_response(data: bytes, encoding: str | None) -> bytes: + """Decompress response data based on Content-Encoding header.""" -# Sentinel value to preserve default behavior and also allow passing `None` -NEW_THREAD_POOL_EXECUTOR = DefaultThreadPoolExecutor() + if not encoding or encoding == "identity": + return data + if encoding == "lz4": + lz4_decom = lz4.frame.LZ4FrameDecompressor() + return lz4_decom.decompress(data, len(data)) + if encoding == "zstd": + zstd_decom = zstandard.ZstdDecompressor() + return zstd_decom.stream_reader(io.BytesIO(data)).read() + if encoding == "br": + if brotli is not None: + return brotli.decompress(data) + raise OperationalError("Brotli compression requested but not installed.") + if encoding == "gzip": + return gzip.decompress(data) + if encoding == "deflate": + return zlib.decompress(data) + raise OperationalError(f"Unsupported compression type: '{encoding}'. Supported compression: {', '.join(available_compression)}") -# pylint: disable=too-many-public-methods,too-many-instance-attributes,too-many-arguments,too-many-positional-arguments,too-many-locals -class AsyncClient: - """ - AsyncClient is a wrapper around the ClickHouse Client object that allows for async calls to the ClickHouse server. - Internally, each of the methods that uses IO is wrapped in a call to EventLoop.run_in_executor. - """ - def __init__(self, - *, - client: Client, - executor_threads: int = 0, - executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR): - if isinstance(client, HttpClient): - client.headers['User-Agent'] = client.headers['User-Agent'].replace('mode:sync;', 'mode:async;') - self.client = client - if executor_threads == 0: - executor_threads = min(32, (os.cpu_count() or 1) + 4) # Mimic the default behavior - if executor is NEW_THREAD_POOL_EXECUTOR: - self.new_executor = True - self.executor = ThreadPoolExecutor(max_workers=executor_threads) - else: - if executor_threads != 0: - logger.warning('executor_threads parameter is ignored when passing an executor object') +class BytesSource: + """Wrapper to make bytes compatible with ResponseBuffer expectations.""" + + def __init__(self, data: bytes): + self.data = data + self.gen = self._make_generator() - self.new_executor = False - self.executor = executor + def _make_generator(self): + yield self.data - def set_client_setting(self, key: str, value: Any) -> None: + def close(self): + """No-op close method for compatibility.""" + + +class AsyncClient(Client): + valid_transport_settings = { + "database", + "buffer_size", + "session_id", + "compress", + "decompress", + "session_timeout", + "session_check", + "query_id", + "quota_key", + "wait_end_of_query", + "client_protocol_version", + "role", + } + optional_transport_settings = { + "send_progress_in_http_headers", + "http_headers_progress_interval_ms", + "enable_http_compression", + } + + def __init__( + self, + interface: str, + host: str, + port: int, + username: str | None = None, + password: str | None = None, + database: str | None = None, + access_token: str | None = None, + compress: bool | str = True, + connect_timeout: int = 10, + send_receive_timeout: int = 300, + client_name: str | None = None, + verify: bool | str = True, + ca_cert: str | None = None, + client_cert: str | None = None, + client_cert_key: str | None = None, + http_proxy: str | None = None, + https_proxy: str | None = None, + server_host_name: str | None = None, + tls_mode: str | None = None, + proxy_path: str = "", + connector_limit: int = 100, + connector_limit_per_host: int = 20, + keepalive_timeout: float = 30.0, + session_id: str | None = None, + settings: dict[str, Any] | None = None, + query_limit: int = 0, + query_retries: int = 2, + tz_source: TzSource | None = None, + tz_mode: TzMode | None = None, + show_clickhouse_errors: bool | None = None, + autogenerate_session_id: bool | None = None, + autogenerate_query_id: bool | None = None, + form_encode_query_params: bool = False, + rename_response_column: str | None = None, + ): """ - Set a clickhouse setting for the client after initialization. If a setting is not recognized by ClickHouse, - or the setting is identified as "read_only", this call will either throw a Programming exception or attempt - to send the setting anyway based on the common setting 'invalid_setting_action'. - :param key: ClickHouse setting name - :param value: ClickHouse setting value + Async HTTP Client using aiohttp. Initialization is handled via _initialize(). """ - self.client.set_client_setting(key=key, value=value) + proxy_path = proxy_path.lstrip("/") + if proxy_path: + proxy_path = "/" + proxy_path + self.uri = f"{interface}://{host}:{port}{proxy_path}" + self.url = self.uri + self.form_encode_query_params = form_encode_query_params + self._rename_response_column = rename_response_column + self._initial_settings = settings + self.headers = {} + + if interface == "https": + if isinstance(verify, str) and verify.lower() == "proxy": + verify = True + tls_mode = tls_mode or "proxy" + + # Priority: access_token > mutual TLS > basic auth + if client_cert and (tls_mode is None or tls_mode == "mutual"): + if not username: + raise ProgrammingError("username parameter is required for Mutual TLS authentication") + self.headers["X-ClickHouse-User"] = username + self.headers["X-ClickHouse-SSL-Certificate-Auth"] = "on" + elif access_token: + self.headers["Authorization"] = f"Bearer {access_token}" + elif username and (not client_cert or tls_mode in ("strict", "proxy")): + credentials = b64encode(f"{username}:{password}".encode()).decode() + self.headers["Authorization"] = f"Basic {credentials}" + + self.headers["User-Agent"] = common.build_client_name(client_name) + # Prevent aiohttp from automatically requesting compressed responses + # We'll manually set Accept-Encoding when compression is desired + self.headers["Accept-Encoding"] = "identity" + self._send_receive_timeout = send_receive_timeout + + connect_timeout_val = float(connect_timeout) if connect_timeout is not None else None + send_receive_timeout_val = float(send_receive_timeout) if send_receive_timeout is not None else None + + self._timeout = aiohttp.ClientTimeout( + total=None, + connect=connect_timeout_val, + sock_connect=connect_timeout_val, + sock_read=send_receive_timeout_val, + ) + connector_limit_per_host = min(connector_limit_per_host, connector_limit) + + proxy_url = None + if http_proxy: + if not http_proxy.startswith("http://") and not http_proxy.startswith("https://"): + proxy_url = f"http://{http_proxy}" + else: + proxy_url = http_proxy + elif https_proxy: + if not https_proxy.startswith("http://") and not https_proxy.startswith("https://"): + proxy_url = f"http://{https_proxy}" + else: + proxy_url = https_proxy + else: + scheme = "https" if self.url.startswith("https://") else "http" + env_proxy = httputil.check_env_proxy(scheme, host, port) + if env_proxy: + if not env_proxy.startswith("http://") and not env_proxy.startswith("https://"): + proxy_url = f"http://{env_proxy}" + else: + proxy_url = env_proxy + + ssl_context = None + if interface == "https": + ssl_context = ssl.create_default_context() + ssl_verify = verify if isinstance(verify, bool) else coerce_bool(verify) + if not ssl_verify: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif ca_cert: + ssl_context.load_verify_locations(ca_cert) + if client_cert: + ssl_context.load_cert_chain(client_cert, client_cert_key) - def get_client_setting(self, key: str) -> Optional[str]: + self._ssl_context = ssl_context + self._proxy_url = proxy_url + self._connector_kwargs = { + "limit": connector_limit, + "limit_per_host": connector_limit_per_host, + "keepalive_timeout": keepalive_timeout, + "force_close": False, + "ssl": ssl_context, + } + # enable_cleanup_closed is only needed for Python < 3.12.7 or == 3.13.0 + # The underlying SSL connection leak was fixed in 3.12.7 and 3.13.1+ + # https://github.com/python/cpython/pull/118960 + if sys.version_info < (3, 12, 7) or sys.version_info[:3] == (3, 13, 0): + self._connector_kwargs["enable_cleanup_closed"] = True + + self._session = None + self._read_format = "Native" + self._write_format = "Native" + self._transform = NativeTransform() + self._client_settings = {} + self._initialized = False + self._reported_libs = set() + self._last_pool_reset = None + self.headers["User-Agent"] = self.headers["User-Agent"].replace("mode:sync;", "mode:async;") + + # Store aiohttp-specific params for deferred initialization + self._compress_param = compress + self._session_id_param = session_id + self._autogenerate_session_id_param = autogenerate_session_id + self._autogenerate_query_id = ( + common.get_setting("autogenerate_query_id") if autogenerate_query_id is None else autogenerate_query_id + ) + self._active_session = None + self._send_progress = None + self._progress_interval = None + + # Call parent init with autoconnect=False to set up config without blocking I/O + super().__init__( + database=database, + query_limit=query_limit, + uri=self.uri, + query_retries=query_retries, + server_host_name=server_host_name, + tz_source=tz_source, + tz_mode=tz_mode, + show_clickhouse_errors=show_clickhouse_errors, + autoconnect=False, + ) + + async def _initialize(self): """ - :param key: The setting key - :return: The string value of the setting, if it exists, or None + Async equivalent of Client._init_common_settings. + Fetches server version, timezone, and settings. """ - return self.client.get_client_setting(key=key) + if not self._session: + connector = aiohttp.TCPConnector(**self._connector_kwargs) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=self._timeout, + headers=self.headers, + trust_env=False, + auto_decompress=False, + skip_auto_headers={"Accept-Encoding"}, + ) + + if self._initialized: + return + + try: + tz_source = self._deferred_tz_source + + self.server_tz, self._dst_safe = timezone.utc, True + row = await self.command("SELECT version(), timezone()", use_database=False) + self.server_version, server_tz_str = tuple(row) + try: + server_tz = tzutil.resolve_zone(server_tz_str) + server_tz, self._dst_safe = tzutil.normalize_timezone(server_tz) + self.server_tz = server_tz + except zoneinfo.ZoneInfoNotFoundError: + logger.warning( + "Server timezone %s could not be resolved, falling back to UTC; %s", + server_tz_str, + tzutil.TZDATA_HINT, + ) + if tz_source == "auto": + self._apply_server_tz = self._dst_safe + else: + self._apply_server_tz = tz_source == "server" + + if not self._apply_server_tz and not tzutil.local_tz_dst_safe: + logger.warning("local timezone %s may return unexpected times due to Daylight Savings Time", tzutil.local_tz.tzname(None)) + + readonly = "readonly" + if not self.min_version("19.17"): + readonly = common.get_setting("readonly") + + server_settings = await self.query(f"SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000") + self.server_settings = {row["name"]: SettingDef(**row) for row in server_settings.named_results()} + + if self.min_version(CH_VERSION_WITH_PROTOCOL) and common.get_setting("use_protocol_version"): + try: + test_data = await self.raw_query( + "SELECT 1 AS check", fmt="Native", settings={"client_protocol_version": PROTOCOL_VERSION_WITH_LOW_CARD} + ) + if test_data[8:16] == b"\x01\x01\x05check": + self.protocol_version = PROTOCOL_VERSION_WITH_LOW_CARD + except Exception: + pass - def set_access_token(self, access_token: str) -> None: + cancel_setting = self._setting_status("cancel_http_readonly_queries_on_client_close") + if ( + cancel_setting.is_writable + and not cancel_setting.is_set + and "cancel_http_readonly_queries_on_client_close" not in (self._initial_settings or {}) + ): + self._client_settings["cancel_http_readonly_queries_on_client_close"] = "1" + + if self._initial_settings: + for key, value in self._initial_settings.items(): + self.set_client_setting(key, value) + + compress = self._compress_param + if coerce_bool(compress): + compression = ",".join(available_compression) + self.write_compression = available_compression[0] + elif compress and compress not in ("False", "false", "0"): + if compress not in available_compression: + raise ProgrammingError(f"Unsupported compression method {compress}") + compression = compress + self.write_compression = compress + else: + compression = None + + comp_setting = self._setting_status("enable_http_compression") + self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable + if comp_setting.is_set or comp_setting.is_writable: + self.compression = compression + + session_id = self._session_id_param + autogenerate_session_id = self._autogenerate_session_id_param + + if autogenerate_session_id is None: + autogenerate_session_id = common.get_setting("autogenerate_session_id") + + if session_id: + self.set_client_setting("session_id", session_id) + elif self.get_client_setting("session_id"): + pass + elif autogenerate_session_id: + self.set_client_setting("session_id", str(uuid.uuid4())) + + send_setting = self._setting_status("send_progress_in_http_headers") + self._send_progress = not send_setting.is_set and send_setting.is_writable + if (send_setting.is_set or send_setting.is_writable) and self._setting_status("http_headers_progress_interval_ms").is_writable: + self._progress_interval = str(min(120000, max(10000, (self._send_receive_timeout - 5) * 1000))) + + if self._setting_status("date_time_input_format").is_writable: + self.set_client_setting("date_time_input_format", "best_effort") + if ( + self._setting_status("allow_experimental_json_type").is_set + and self._setting_status("cast_string_to_dynamic_use_inference").is_writable + ): + self.set_client_setting("cast_string_to_dynamic_use_inference", "1") + if self.min_version("24.8") and not self.min_version("24.10"): + dynamic_module.json_serialization_format = 0 + + self._initialized = True + except Exception: + if self._session and not self._session.closed: + await self._session.close() + self._session = None + raise + + async def __aenter__(self): + """Async context manager entry.""" + if not self._initialized: + await self._initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() + return False + + async def close(self): # type: ignore[override] + if self._session: + await self._session.close() + + async def close_connections(self): # type: ignore[override] + """Close all pooled connections and recreate session""" + if self._session: + await self._session.close() + connector = aiohttp.TCPConnector(**self._connector_kwargs) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=self._timeout, + headers=self.headers, + trust_env=False, + auto_decompress=False, + skip_auto_headers={"Accept-Encoding"}, + ) + + def set_client_setting(self, key, value): + str_value = self._validate_setting(key, value, common.get_setting("invalid_setting_action")) + if str_value is not None: + self._client_settings[key] = str_value + + def get_client_setting(self, key) -> str | None: + return self._client_settings.get(key) + + def set_access_token(self, access_token: str): + auth_header = self.headers.get("Authorization") + if auth_header and not auth_header.startswith("Bearer"): + raise ProgrammingError("Cannot set access token when a different auth type is used") + self.headers["Authorization"] = f"Bearer {access_token}" + if self._session: + self._session.headers["Authorization"] = f"Bearer {access_token}" + + def _prep_query(self, context: QueryContext): + final_query = super()._prep_query(context) + if context.is_insert: + return final_query + fmt = f"\n FORMAT {self._read_format}" + if isinstance(final_query, bytes): + return final_query + fmt.encode() + return final_query + fmt + + async def _query_with_context(self, context: QueryContext) -> QueryResult: # type: ignore[override] + headers = {} + params = {} + if self.database: + params["database"] = self.database + if self.protocol_version: + params["client_protocol_version"] = self.protocol_version + context.block_info = True + params.update(self._validate_settings(context.settings)) + context.rename_response_column = self._rename_response_column + + if not context.is_insert and columns_only_re.search(context.uncommented_query): + fmt_json_query = f"{context.final_query}\n FORMAT JSON" + fields = {"query": fmt_json_query} + fields.update(context.bind_params) + + if self.form_encode_query_params: + files = {} + if context.external_data: + params.update(context.external_data.query_params) + files.update(context.external_data.form_data) + + for k, v in fields.items(): + files[k] = (None, str(v)) + response = await self._raw_request(None, params, headers, files=files, retries=self.query_retries) + elif context.external_data: + params.update(context.bind_params) + params.update(context.external_data.query_params) + params["query"] = fmt_json_query + response = await self._raw_request(None, params, headers, files=context.external_data.form_data, retries=self.query_retries) + else: + params.update(context.bind_params) + response = await self._raw_request(fmt_json_query, params, headers, retries=self.query_retries) + + body = await response.read() + encoding = response.headers.get("Content-Encoding") + loop = asyncio.get_running_loop() + + def decompress_and_parse_json(): + if encoding: + decompressed_body = decompress_response(body, encoding) + else: + decompressed_body = body + return json.loads(decompressed_body) + + # Offload to executor + json_result = await loop.run_in_executor(None, decompress_and_parse_json) + + names: list[str] = [] + types: list[ClickHouseType] = [] + renamer = context.column_renamer + for col in json_result["meta"]: + name = col["name"] + if renamer is not None: + try: + name = renamer(name) + except Exception as e: + logger.debug("Failed to rename col '%s'. Skipping rename. Error: %s", name, e) + names.append(name) + types.append(get_from_name(col["type"])) + return QueryResult([], None, tuple(names), tuple(types)) + + if self.compression: + headers["Accept-Encoding"] = self.compression + if self._send_comp_setting: + params["enable_http_compression"] = "1" + + final_query = self._prep_query(context) + + files = None + data = None + + if self.form_encode_query_params: + fields = {"query": final_query} + fields.update(context.bind_params) + + files = {} + if context.external_data: + params.update(context.external_data.query_params) + files.update(context.external_data.form_data) + + for k, v in fields.items(): + files[k] = (None, str(v)) + elif context.external_data: + params.update(context.bind_params) + params.update(context.external_data.query_params) + params["query"] = final_query + files = context.external_data.form_data + else: + params.update(context.bind_params) + data = final_query + headers["Content-Type"] = "text/plain; charset=utf-8" + + headers = dict_copy(headers, context.transport_settings) + + response = await self._raw_request( + data, + params, + headers, + files=files, + server_wait=not context.streaming, + stream=True, + retries=self.query_retries, + ) + encoding = response.headers.get("Content-Encoding") + tz_header = response.headers.get("X-ClickHouse-Timezone") + exception_tag = response.headers.get(ex_tag_header) + + loop = asyncio.get_running_loop() + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) + + def parse_streaming(): + """Parse response from streaming queue (runs in executor).""" + # Wrap streaming source with ResponseBuffer. The streaming source provides a + # .gen property that yields decompressed chunks. + byte_source = RespBuffCls(streaming_source) + context.set_response_tz(self._check_tz_change(tz_header)) + result = self._transform.parse_response(byte_source, context) + + # For Pandas/Numpy, we must materialize in the executor because the resulting objects + # (DataFrame, Array) are fully in-memory structures. + # For standard queries, we return a lazy QueryResult. Accessing .result_set on the event loop + # will raise a ProgrammingError (deadlock check), encouraging usage of .rows_stream. + if not context.streaming: + if context.as_pandas and hasattr(result, "df_result"): + _ = result.df_result + elif context.use_numpy and hasattr(result, "np_result"): + _ = result.np_result + elif isinstance(result, QueryResult): + _ = result.result_set + + return result + + # Run parser in executor (pulls from queue, decompresses & parses) + try: + query_result = await loop.run_in_executor(None, parse_streaming) + except Exception: + await streaming_source.aclose() + raise + query_result.summary = self._summary(response) + + # Attach streaming_source to query_result.source to ensure it gets closed + # when the query result is closed (e.g. by StreamContext.__exit__) + query_result.source = streaming_source + + return query_result + + async def query( # type: ignore[override] + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + column_oriented: bool | None = None, + use_numpy: bool | None = None, + max_str_len: int | None = None, + context: QueryContext | None = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> QueryResult: """ - Set the ClickHouse access token for the client - :param access_token: Access token string + Main query method for SELECT, DESCRIBE and other SQL statements that return a result matrix. For + parameters, see the create_query_context method + :return: QueryResult -- data and metadata from response """ - self.client.set_access_token(access_token) + if query and query.lower().strip().startswith("select __connect_version__"): + return QueryResult( + [[f"ClickHouse Connect v.{common.version()} ⓒ ClickHouse Inc."]], None, ("connect_version",), (get_from_name("String"),) + ) + if not context: + context = self.create_query_context( + query=query, + parameters=parameters, + settings=settings, + query_formats=query_formats, + column_formats=column_formats, + encoding=encoding, + use_none=use_none, + column_oriented=column_oriented, + use_numpy=use_numpy, + max_str_len=max_str_len, + query_tz=query_tz, + column_tzs=column_tzs, + external_data=external_data, + transport_settings=transport_settings, + tz_mode=tz_mode, + ) + + if context.is_command: + response = await self.command( + query, + parameters=context.parameters, + settings=context.settings, + external_data=context.external_data, + transport_settings=context.transport_settings, + ) + if isinstance(response, QuerySummary): + return response.as_query_result() + return QueryResult([response] if isinstance(response, list) else [[response]]) + + return await self._query_with_context(context) - def min_version(self, version_str: str) -> bool: + async def query_column_block_stream( # type: ignore[override] + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + context: QueryContext | None = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> StreamContext: """ - Determine whether the connected server is at least the submitted version - For Altinity Stable versions like 22.8.15.25.altinitystable - the last condition in the first list comprehension expression is added - :param version_str: A version string consisting of up to 4 integers delimited by dots - :return: True if version_str is greater than the server_version, False if less than + Async version of query_column_block_stream. + Returns a StreamContext that yields column-oriented blocks. """ - return self.client.min_version(version_str) + return (await self._context_query(locals(), use_numpy=False, streaming=True)).column_block_stream - async def close(self) -> None: + async def query_row_block_stream( # type: ignore[override] + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + context: QueryContext | None = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> StreamContext: """ - Subclass implementation to close the connection to the server/deallocate the client + Async version of query_row_block_stream. + Returns a StreamContext that yields row-oriented blocks. """ - self.client.close() - - if self.new_executor: - await asyncio.to_thread(self.executor.shutdown, True) + return (await self._context_query(locals(), use_numpy=False, streaming=True)).row_block_stream - async def query(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> QueryResult: + async def query_rows_stream( # type: ignore[override] + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + context: QueryContext | None = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> StreamContext: """ - Main query method for SELECT, DESCRIBE and other SQL statements that return a result matrix. - For parameters, see the create_query_context method. - :return: QueryResult -- data and metadata from response + Async version of query_rows_stream. + Returns a StreamContext that yields individual rows. """ + return (await self._context_query(locals(), use_numpy=False, streaming=True)).rows_stream - def _query(): - return self.client.query(query=query, parameters=parameters, settings=settings, query_formats=query_formats, - column_formats=column_formats, encoding=encoding, use_none=use_none, - column_oriented=column_oriented, use_numpy=use_numpy, max_str_len=max_str_len, - context=context, query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) + async def query_np( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + max_str_len: int | None = None, + context: QueryContext | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ): + check_numpy() + self._add_integration_tag("numpy") + return (await self._context_query(locals(), use_numpy=True)).np_result - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query) - return result + async def query_np_stream( # type: ignore[override] + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + max_str_len: int | None = None, + context: QueryContext | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> StreamContext: + check_numpy() + self._add_integration_tag("numpy") + return (await self._context_query(locals(), use_numpy=True, streaming=True)).np_stream - async def query_column_block_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None, - ) -> StreamContext: + async def query_df( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + max_str_len: int | None = None, + use_na_values: bool | None = None, + query_tz: str | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + context: QueryContext | None = None, + external_data: ExternalData | None = None, + use_extended_dtypes: bool | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ): + check_pandas() + self._add_integration_tag("pandas") + return (await self._context_query(locals(), use_numpy=True, as_pandas=True)).df_result + + async def query_df_stream( # type: ignore[override] + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + max_str_len: int | None = None, + use_na_values: bool | None = None, + query_tz: str | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + context: QueryContext | None = None, + external_data: ExternalData | None = None, + use_extended_dtypes: bool | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> StreamContext: + check_pandas() + self._add_integration_tag("pandas") + return (await self._context_query(locals(), use_numpy=True, as_pandas=True, streaming=True)).df_stream + + async def _context_query(self, lcls: dict, **overrides): # type: ignore[override] + """ + Helper method to create query context and execute query. + Matches sync client pattern for consistency. """ - Variation of main query method that returns a stream of column oriented blocks. - For parameters, see the create_query_context method. - :return: StreamContext -- Iterable stream context that returns column oriented blocks + kwargs = lcls.copy() + kwargs.pop("self") + kwargs.update(overrides) + return await self._query_with_context(self.create_query_context(**kwargs)) + + async def command( # type: ignore[override] + self, + cmd, + parameters: Sequence | dict[str, Any] | None = None, + data: str | bytes | None = None, + settings: dict | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> str | int | Sequence[str] | QuerySummary: + """ + See BaseClient doc_string for this method """ + cmd, bind_params = bind_query(cmd, parameters, self.server_tz) + params = bind_params.copy() + headers = {} + payload = None + files = None - def _query_column_block_stream(): - return self.client.query_column_block_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) + if external_data: + if data: + raise ProgrammingError("Cannot combine command data with external data") from None + files = external_data.form_data + params.update(external_data.query_params) + elif isinstance(data, str): + headers["Content-Type"] = "text/plain; charset=utf-8" + payload = data.encode() + elif isinstance(data, bytes): + headers["Content-Type"] = "application/octet-stream" + payload = data + + if payload is None and not cmd: + raise ProgrammingError("Command sent without query or recognized data") from None + + if payload or files: + params["query"] = cmd + else: + payload = cmd + + if use_database and self.database: + params["database"] = self.database + params.update(self._validate_settings(settings or {})) + headers = dict_copy(headers, transport_settings) + method = "POST" if payload or files else "GET" + response = await self._raw_request(payload, params, headers, files=files, method=method, server_wait=False) + body = await response.read() + encoding = response.headers.get("Content-Encoding") + summary = self._summary(response) + + if not body: + return QuerySummary(summary) loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_column_block_stream) - return result - async def query_row_block_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: + def decompress_and_decode(): + if encoding: + decompressed_body = decompress_response(body, encoding) + else: + decompressed_body = body + try: + result = decompressed_body.decode()[:-1].split("\t") + if len(result) == 1: + try: + return int(result[0]) + except ValueError: + return result[0] + return result + except UnicodeDecodeError: + return str(decompressed_body) + + return await loop.run_in_executor(None, decompress_and_decode) + + async def ping(self) -> bool: # type: ignore[override] + try: + url = f"{self.url}/ping" + timeout = aiohttp.ClientTimeout(total=3.0) + async with self._session.get(url, timeout=timeout) as response: + return 200 <= response.status < 300 + except (aiohttp.ClientError, asyncio.TimeoutError): + logger.debug("ping failed", exc_info=True) + return False + + async def raw_query( # type: ignore[override] + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> bytes: """ - Variation of main query method that returns a stream of row oriented blocks. - For parameters, see the create_query_context method. - :return: StreamContext -- Iterable stream context that returns blocks of rows + See BaseClient doc_string for this method """ + body, params, headers, files = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) + if transport_settings: + headers = dict_copy(headers, transport_settings) - def _query_row_block_stream(): - return self.client.query_row_block_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) + response = await self._raw_request(body, params, headers=headers, files=files, retries=self.query_retries) + response_data = await response.read() + encoding = response.headers.get("Content-Encoding") - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_row_block_stream) - return result + if encoding: + loop = asyncio.get_running_loop() + response_data = await loop.run_in_executor(None, decompress_response, response_data, encoding) + + return response_data - async def query_rows_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: + async def raw_stream( # type: ignore[override] + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> StreamContext: + + body, params, headers, files = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, stream=True, server_wait=False, retries=self.query_retries + ) + + async def byte_iterator(): + async for chunk in response.content.iter_any(): + yield chunk + + return StreamContext(response, byte_iterator()) + + def _prep_raw_query(self, query, parameters, settings, fmt, use_database, external_data): """ - Variation of main query method that returns a stream of row oriented blocks. - For parameters, see the create_query_context method. - :return: StreamContext -- Iterable stream context that returns blocks of rows + Prepare raw query for execution. + + Note: Unlike sync client which returns (body, params, fields), this async version + returns (body, params, headers, files) because aiohttp requires headers to be + configured before the request() call, while urllib3 can add them during request. """ + if fmt: + query += f"\n FORMAT {fmt}" - def _query_rows_stream(): - return self.client.query_rows_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) + final_query, bind_params = bind_query(query, parameters, self.server_tz) + params = self._validate_settings(settings or {}) + if use_database and self.database: + params["database"] = self.database - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_rows_stream) - return result + headers = {} + files = None + body = None + + if external_data and not self.form_encode_query_params and isinstance(final_query, bytes): + raise ProgrammingError("Binary query cannot be placed in URL when using External Data; enable form encoding.") + + if self.form_encode_query_params: + files = {} + files["query"] = (None, final_query if isinstance(final_query, str) else final_query.decode()) + for k, v in bind_params.items(): + files[k] = (None, str(v)) - async def raw_query(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: str = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> bytes: + if external_data: + params.update(external_data.query_params) + files.update(external_data.form_data) + + body = None + elif external_data: + params.update(bind_params) + params["query"] = final_query + params.update(external_data.query_params) + files = external_data.form_data + body = None + else: + params.update(bind_params) + body = final_query.encode() if isinstance(final_query, str) else final_query + + return body, params, headers, files + + async def insert( # type: ignore[override] + self, + table: str | None = None, + data: Sequence[Sequence[Any]] | None = None, + column_names: str | Iterable[str] = "*", + database: str | None = None, + column_types: Sequence[ClickHouseType] | None = None, + column_type_names: Sequence[str] | None = None, + column_oriented: bool = False, + settings: dict[str, Any] | None = None, + context: InsertContext | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: + """ + Method to insert multiple rows/data matrix of native Python objects. If context is specified arguments + other than data are ignored + :param table: Target table + :param data: Sequence of sequences of Python data + :param column_names: Ordered list of column names or '*' if column types should be retrieved from the + ClickHouse table definition + :param database: Target database -- will use client default database if not specified. + :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from + the server + :param column_type_names: ClickHouse column type names. If set then column data does not need to be + retrieved from the server + :param column_oriented: If true the data is already "pivoted" in column form + :param settings: Optional dictionary of ClickHouse settings (key/string values) + :param context: Optional reusable insert context to allow repeated inserts into the same table with + different data batches + :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) + :return: QuerySummary with summary information, throws exception if insert fails + """ + if (context is None or context.empty) and data is None: + raise ProgrammingError("No data specified for insert") from None + if context is None: + context = await self.create_insert_context( + table, + column_names, + database, + column_types, + column_type_names, + column_oriented, + settings, + transport_settings=transport_settings, + ) + if data is not None: + if not context.empty: + raise ProgrammingError("Attempting to insert new data with non-empty insert context") from None + context.data = data + return await self.data_insert(context) + + async def query_arrow( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + use_strings: bool | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ): """ - Query method that simply returns the raw ClickHouse format bytes. + Query method using the ClickHouse Arrow format to return a PyArrow table :param query: Query statement/format string :param parameters: Optional dictionary used to format the query :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param fmt: ClickHouse output format - :param use_database Send the database parameter to ClickHouse so the command will be executed in the client - database context - :param external_data External data to send with the query + :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) + :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: bytes representing raw ClickHouse return value based on format + :return: PyArrow.Table """ + check_arrow() + self._add_integration_tag("arrow") + settings = self._update_arrow_settings(settings, use_strings) - def _raw_query(): - return self.client.raw_query(query=query, parameters=parameters, settings=settings, fmt=fmt, - use_database=use_database, external_data=external_data, - transport_settings=transport_settings) + body, params, headers, files = self._prep_raw_query( + query, + parameters, + settings, + fmt="ArrowStream", + use_database=True, + external_data=external_data, + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, + params, + headers=headers, + files=files, + stream=True, + server_wait=False, + retries=self.query_retries, + ) + encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _raw_query) - return result + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) + + def parse_arrow_stream(): + file_adapter = StreamingFileAdapter(streaming_source) + reader = options.arrow.ipc.open_stream(file_adapter) + table = reader.read_all() + return _apply_arrow_tz_policy(table, self.tz_mode) + + try: + return await loop.run_in_executor(None, parse_arrow_stream) + finally: + await streaming_source.aclose() - async def raw_stream(self, query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: str = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase: + async def query_arrow_stream( # type: ignore[override] + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + use_strings: bool | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> StreamContext: """ - Query method that returns the result as an io.IOBase iterator. + Query method that returns the results as a stream of Arrow record batches. + :param query: Query statement/format string :param parameters: Optional dictionary used to format the query :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param fmt: ClickHouse output format - :param use_database Send the database parameter to ClickHouse so the command will be executed in the client - database context - :param external_data External data to send with the query + :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) + :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: io.IOBase stream/iterator for the result + :return: StreamContext that yields PyArrow RecordBatch objects asynchronously """ + check_arrow() + self._add_integration_tag("arrow") + settings = self._update_arrow_settings(settings, use_strings) - def _raw_stream(): - return self.client.raw_stream(query=query, parameters=parameters, settings=settings, fmt=fmt, - use_database=use_database, external_data=external_data, transport_settings=transport_settings) + body, params, headers, files = self._prep_raw_query( + query, parameters, settings, fmt="ArrowStream", use_database=True, external_data=external_data + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, stream=True, server_wait=False, retries=self.query_retries + ) + encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _raw_stream) - return result + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) - async def query_np(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> 'numpy.ndarray': - """ - Query method that returns the results as a numpy array. - For parameter values, see the create_query_context method. - :return: Numpy array representing the result set - """ + queue = AsyncSyncQueue(maxsize=10) - def _query_np(): - return self.client.query_np(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, encoding=encoding, - use_none=use_none, max_str_len=max_str_len, context=context, - external_data=external_data, transport_settings=transport_settings) + class _ArrowStreamSource: + def __init__(self, source, q): + self._source = source + self._queue = q - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_np) - return result + async def aclose(self): + self._queue.shutdown() + await self._source.aclose() - async def query_np_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: - """ - Query method that returns the results as a stream of numpy arrays. - For parameter values, see the create_query_context method. - :return: Generator that yield a numpy array per block representing the result set - """ + def close(self): + self._queue.shutdown() + self._source.close() - def _query_np_stream(): - return self.client.query_np_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, max_str_len=max_str_len, - context=context, external_data=external_data, transport_settings=transport_settings) + def parse_arrow_streaming(): + """Parse Arrow stream incrementally in executor (off event loop).""" + try: + file_adapter = StreamingFileAdapter(streaming_source) + reader = options.arrow.ipc.open_stream(file_adapter) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_np_stream) - return result + for batch in reader: + try: + batch = _apply_arrow_tz_policy(batch, self.tz_mode) + queue.sync_q.put(batch) + except RuntimeError: + return - async def query_df(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> 'pandas.DataFrame': - """ - Query method that results the results as a pandas dataframe. - For parameter values, see the create_query_context method. - :return: Pandas dataframe representing the result set - """ + try: + queue.sync_q.put(EOF_SENTINEL) + except RuntimeError: + return + except Exception as e: + try: + queue.sync_q.put(e) + except Exception: + pass + finally: + queue.shutdown() - def _query_df(): - return self.client.query_df(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, encoding=encoding, - use_none=use_none, max_str_len=max_str_len, use_na_values=use_na_values, - query_tz=query_tz, column_tzs=column_tzs, tz_mode=tz_mode, - utc_tz_aware=utc_tz_aware, context=context, - external_data=external_data, use_extended_dtypes=use_extended_dtypes, - transport_settings=transport_settings) + loop.run_in_executor(None, parse_arrow_streaming) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df) - return result + async def arrow_batch_generator(): + """Async generator that yields record batches without blocking event loop.""" + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item + + return StreamContext(_ArrowStreamSource(streaming_source, queue), arrow_batch_generator()) async def query_df_arrow( self, query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + use_strings: bool | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, dataframe_library: str = "pandas", - ) -> Union["pd.DataFrame", "pl.DataFrame"]: + ) -> pandas.DataFrame | polars.DataFrame: """ Query method using the ClickHouse Arrow format to return a DataFrame with PyArrow dtype backend. This provides better performance and memory efficiency @@ -413,73 +1239,51 @@ class AsyncClient: :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") :return: DataFrame (pandas or polars based on dataframe_library parameter) """ + check_arrow() - def _query_df_arrow(): - return self.client.query_df_arrow( - query=query, - parameters=parameters, - settings=settings, - use_strings=use_strings, - external_data=external_data, - transport_settings=transport_settings, - dataframe_library=dataframe_library - ) + if dataframe_library == "pandas": + check_pandas() + self._add_integration_tag("pandas") - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df_arrow) - return result + def converter(table: pyarrow.Table) -> pandas.DataFrame: + table = _apply_arrow_tz_policy(table, self.tz_mode) + return table.to_pandas(types_mapper=options.pd.ArrowDtype, safe=False) - async def query_df_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: - """ - Query method that returns the results as a StreamContext. - For parameter values, see the create_query_context method. - :return: Generator that yields a Pandas dataframe per block representing the result set - """ + elif dataframe_library == "polars": + check_polars() + self._add_integration_tag("polars") - def _query_df_stream(): - return self.client.query_df_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, - use_none=use_none, max_str_len=max_str_len, use_na_values=use_na_values, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, context=context, - external_data=external_data, use_extended_dtypes=use_extended_dtypes, - transport_settings=transport_settings) + def converter(table: pyarrow.Table) -> polars.DataFrame: + table = _apply_arrow_tz_policy(table, self.tz_mode) + return options.pl.from_arrow(table) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df_stream) - return result + else: + raise ValueError(f"dataframe_library must be 'pandas' or 'polars', got '{dataframe_library}'") + + arrow_table = await self.query_arrow( + query=query, + parameters=parameters, + settings=settings, + use_strings=use_strings, + external_data=external_data, + transport_settings=transport_settings, + ) - async def query_df_arrow_stream( + return converter(arrow_table) + + async def query_df_arrow_stream( # type: ignore[override] self, query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas" + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + use_strings: bool | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + dataframe_library: str = "pandas", ) -> StreamContext: """ Query method that returns the results as a stream of DataFrames with PyArrow dtype backend. - Each DataFrame represents a block from the ClickHouse response. + Each DataFrame represents a record batch from the ClickHouse response. :param query: Query statement/format string :param parameters: Optional dictionary used to format the query @@ -488,273 +1292,108 @@ class AsyncClient: :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") - :return: StreamContext that yields DataFrames (pandas or polars based on dataframe_library parameter) + :return: StreamContext that yields DataFrames asynchronously (pandas or polars based on dataframe_library parameter) """ + check_arrow() + if dataframe_library == "pandas": + check_pandas() + self._add_integration_tag("pandas") - def _query_df_arrow_stream(): - return self.client.query_df_arrow_stream( - query=query, - parameters=parameters, - settings=settings, - use_strings=use_strings, - external_data=external_data, - transport_settings=transport_settings, - dataframe_library=dataframe_library - ) + def converter(table: pyarrow.Table) -> pandas.DataFrame: + table = _apply_arrow_tz_policy(table, self.tz_mode) + return table.to_pandas(types_mapper=options.pd.ArrowDtype, safe=False) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df_arrow_stream) - return result + elif dataframe_library == "polars": + check_polars() + self._add_integration_tag("polars") - def create_query_context(self, - query: Optional[Union[str, bytes]] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = False, - max_str_len: Optional[int] = 0, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - tz_mode: Optional[TzMode] = None, - use_na_values: Optional[bool] = None, - streaming: bool = False, - as_pandas: bool = False, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None) -> QueryContext: - """ - Creates or updates a reusable QueryContext object - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param query_formats: See QueryContext __init__ docstring - :param column_formats: See QueryContext __init__ docstring - :param encoding: See QueryContext __init__ docstring - :param use_none: Use None for ClickHouse NULL instead of default values. Note that using None in Numpy - arrays will force the numpy array dtype to 'object', which is often inefficient. This effect also - will impact the performance of Pandas dataframes. - :param column_oriented: Deprecated. Controls orientation of the QueryResult result_set property - :param use_numpy: Return QueryResult columns as one-dimensional numpy arrays - :param max_str_len: Limit returned ClickHouse String values to this length, which allows a Numpy - structured array even with ClickHouse variable length String columns. If 0, Numpy arrays for - String columns will always be object arrays - :param context: An existing QueryContext to be updated with any provided parameter values - :param query_tz Either a string or a pytz tzinfo object. (Strings will be converted to tzinfo objects). - Values for any DateTime or DateTime64 column in the query will be converted to Python datetime.datetime - objects with the selected timezone - :param column_tzs A dictionary of column names to tzinfo objects (or strings that will be converted to - tzinfo objects). The timezone will be applied to datetime objects returned in the query - :param use_na_values: Deprecated alias for use_advanced_dtypes - :param as_pandas Return the result columns as pandas.Series objects - :param streaming Marker used to correctly configure streaming queries - :param external_data ClickHouse "external data" to send with query - :param use_extended_dtypes: Only relevant to Pandas Dataframe queries. Use Pandas "missing types", such as - pandas.NA and pandas.NaT for ClickHouse NULL values, as well as extended Pandas dtypes such as IntegerArray - and StringArray. Defaulted to True for query_df methods - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: Reusable QueryContext - """ + def converter(table: pyarrow.Table) -> polars.DataFrame: + table = _apply_arrow_tz_policy(table, self.tz_mode) + return options.pl.from_arrow(table) - return self.client.create_query_context(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, - column_oriented=column_oriented, - use_numpy=use_numpy, max_str_len=max_str_len, context=context, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - use_na_values=use_na_values, - streaming=streaming, as_pandas=as_pandas, - external_data=external_data, - use_extended_dtypes=use_extended_dtypes, - transport_settings=transport_settings) + else: + raise ValueError(f"dataframe_library must be 'pandas' or 'polars', got '{dataframe_library}'") + settings = self._update_arrow_settings(settings, use_strings) - async def query_arrow(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> 'pyarrow.Table': - """ - Query method using the ClickHouse Arrow format to return a PyArrow table - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data ClickHouse "external data" to send with query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: PyArrow.Table - """ + body, params, headers, files = self._prep_raw_query( + query, parameters, settings, fmt="ArrowStream", use_database=True, external_data=external_data + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) - def _query_arrow(): - return self.client.query_arrow(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings) + response = await self._raw_request( + body, params, headers=headers, files=files, stream=True, server_wait=False, retries=self.query_retries + ) + encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_arrow) - return result + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) - async def query_arrow_stream(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: - """ - Query method that returns the results as a stream of Arrow tables - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data ClickHouse "external data" to send with query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: Generator that yields a PyArrow.Table for per block representing the result set - """ - - def _query_arrow_stream(): - return self.client.query_arrow_stream(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings) + queue = AsyncSyncQueue(maxsize=10) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_arrow_stream) - return result + class _ArrowDFStreamSource: + def __init__(self, source, q): + self._source = source + self._queue = q - async def command(self, - cmd: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - data: Union[str, bytes] = None, - settings: Optional[Dict[str, Any]] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> Union[str, int, Sequence[str], QuerySummary]: - """ - Client method that returns a single value instead of a result set - :param cmd: ClickHouse query/command as a python format string - :param parameters: Optional dictionary of key/values pairs to be formatted - :param data: Optional 'data' for the command (for INSERT INTO in particular) - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_database: Send the database parameter to ClickHouse so the command will be executed in the client - database context. Otherwise, no database will be specified with the command. This is useful for determining - the default user database - :param external_data ClickHouse "external data" to send with command/query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: Decoded response from ClickHouse as either a string, int, or sequence of strings, or QuerySummary - if no data returned - """ + async def aclose(self): + self._queue.shutdown() + await self._source.aclose() - def _command(): - return self.client.command(cmd=cmd, parameters=parameters, data=data, settings=settings, - use_database=use_database, external_data=external_data, - transport_settings=transport_settings) + def close(self): + self._queue.shutdown() + self._source.close() - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _command) - return result + def parse_and_convert_streaming(): + """Parse Arrow stream and convert to DataFrames in executor (off event loop).""" + try: + file_adapter = StreamingFileAdapter(streaming_source) - async def ping(self) -> bool: - """ - Validate the connection, does not throw an Exception (see debug logs) - :return: ClickHouse server is up and reachable - """ + # PyArrow reads incrementally from adapter (which pulls from queue) + reader = options.arrow.ipc.open_stream(file_adapter) - def _ping(): - return self.client.ping() + for batch in reader: + try: + queue.sync_q.put(converter(batch)) + except RuntimeError: + return - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _ping) - return result + try: + queue.sync_q.put(EOF_SENTINEL) + except RuntimeError: + return + except Exception as e: + try: + queue.sync_q.put(e) + except Exception: + pass + finally: + queue.shutdown() - async def insert(self, - table: Optional[str] = None, - data: Sequence[Sequence[Any]] = None, - column_names: Union[str, Iterable[str]] = '*', - database: Optional[str] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - context: InsertContext = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: - """ - Method to insert multiple rows/data matrix of native Python objects. If context is specified arguments - other than data are ignored - :param table: Target table - :param data: Sequence of sequences of Python data - :param column_names: Ordered list of column names or '*' if column types should be retrieved from the - ClickHouse table definition - :param database: Target database -- will use client default database if not specified. - :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from - the server - :param column_type_names: ClickHouse column type names. If set then column data does not need to be - retrieved from the server - :param column_oriented: If true the data is already "pivoted" in column form - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param context: Optional reusable insert context to allow repeated inserts into the same table with - different data batches - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: QuerySummary with summary information, throws exception if insert fails - """ + loop.run_in_executor(None, parse_and_convert_streaming) - def _insert(): - return self.client.insert(table=table, data=data, column_names=column_names, database=database, - column_types=column_types, column_type_names=column_type_names, - column_oriented=column_oriented, settings=settings, context=context, - transport_settings=transport_settings) + async def df_generator(): + """Async generator that yields DataFrames without blocking event loop.""" + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert) - return result + return StreamContext(_ArrowDFStreamSource(streaming_source, queue), df_generator()) - async def insert_df(self, table: str = None, - df=None, - database: Optional[str] = None, - settings: Optional[Dict] = None, - column_names: Optional[Sequence[str]] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - context: InsertContext = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: - """ - Insert a pandas DataFrame into ClickHouse. If context is specified arguments other than df are ignored - :param table: ClickHouse table - :param df: two-dimensional pandas dataframe - :param database: Optional ClickHouse database - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param column_names: An optional list of ClickHouse column names. If not set, the DataFrame column names - will be used - :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from - the server - :param column_type_names: ClickHouse column type names. If set then column data does not need to be - retrieved from the server - :param context: Optional reusable insert context to allow repeated inserts into the same table with - different data batches - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: QuerySummary with summary information, throws exception if insert fails - """ - - def _insert_df(): - return self.client.insert_df(table=table, df=df, database=database, settings=settings, - column_names=column_names, - column_types=column_types, column_type_names=column_type_names, - context=context, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert_df) - return result - - async def insert_arrow(self, table: str, - arrow_table, database: str = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + async def insert_arrow( # type: ignore[override] + self, + table: str, + arrow_table, + database: str | None = None, + settings: dict | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: """ Insert a PyArrow table DataFrame into ClickHouse using raw Arrow format :param table: ClickHouse table @@ -762,24 +1401,23 @@ class AsyncClient: :param database: Optional ClickHouse database :param settings: Optional dictionary of ClickHouse settings (key/string values) :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: QuerySummary with summary information, throws exception if insert fails """ + check_arrow() + self._add_integration_tag("arrow") + full_table = table if "." in table or not database else f"{database}.{table}" + compression = self.write_compression if self.write_compression in ("zstd", "lz4") else None + column_names, insert_block = arrow_buffer(arrow_table, compression) + if hasattr(insert_block, "to_pybytes"): + insert_block = insert_block.to_pybytes() + return await self.raw_insert(full_table, column_names, insert_block, settings, "Arrow", transport_settings) - def _insert_arrow(): - return self.client.insert_arrow(table=table, arrow_table=arrow_table, database=database, - settings=settings, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert_arrow) - return result - - async def insert_df_arrow( + async def insert_df_arrow( # type: ignore[override] self, table: str, - df: Union["pd.DataFrame", "pl.DataFrame"], - database: Optional[str] = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None, + df: pandas.DataFrame | polars.DataFrame, + database: str | None = None, + settings: dict | None = None, + transport_settings: dict[str, str] | None = None, ) -> QuerySummary: """ Insert a pandas DataFrame with PyArrow backend or a polars DataFrame into ClickHouse using Arrow format. @@ -796,30 +1434,54 @@ class AsyncClient: :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) :return: QuerySummary with summary information, throws exception if insert fails """ + check_arrow() - def _insert_df_arrow(): - return self.client.insert_df_arrow( - table=table, - df=df, - database=database, - settings=settings, - transport_settings=transport_settings, - ) + if options.pd is not None and isinstance(df, options.pd.DataFrame): + df_lib = "pandas" + elif options.pl is not None and isinstance(df, options.pl.DataFrame): + df_lib = "polars" + else: + if options.pd is None and options.pl is None: + raise ImportError("A DataFrame library (pandas or polars) must be installed to use insert_df_arrow.") + raise TypeError(f"df must be either a pandas DataFrame or polars DataFrame, got {type(df).__name__}") - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert_df_arrow) - return result + if df_lib == "pandas": + non_arrow_cols = [col for col, dtype in df.dtypes.items() if not isinstance(dtype, options.pd.ArrowDtype)] + if non_arrow_cols: + raise ProgrammingError( + f"insert_df_arrow requires all columns to use PyArrow dtypes. Non-Arrow columns found: [{', '.join(non_arrow_cols)}]. " + ) + try: + arrow_table = options.arrow.Table.from_pandas(df, preserve_index=False) + except Exception as e: + raise DataError(f"Failed to convert pandas DataFrame to Arrow table: {e}") from e + else: + try: + arrow_table = df.to_arrow() + except Exception as e: + raise DataError(f"Failed to convert polars DataFrame to Arrow table: {e}") from e + + self._add_integration_tag(df_lib) + return await self.insert_arrow( + table=table, + arrow_table=arrow_table, + database=database, + settings=settings, + transport_settings=transport_settings, + ) - async def create_insert_context(self, - table: str, - column_names: Optional[Union[str, Sequence[str]]] = None, - database: Optional[str] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - data: Optional[Sequence[Sequence[Any]]] = None, - transport_settings: Optional[Dict[str, str]] = None) -> InsertContext: + async def create_insert_context( # type: ignore[override] + self, + table: str, + column_names: str | Sequence[str] | None = None, + database: str | None = None, + column_types: Sequence[ClickHouseType] | None = None, + column_type_names: Sequence[str] | None = None, + column_oriented: bool = False, + settings: dict[str, Any] | None = None, + data: Sequence[Sequence[Any]] | None = None, + transport_settings: dict[str, str] | None = None, + ) -> InsertContext: """ Builds a reusable insert context to hold state for a duration of an insert :param table: Target table @@ -836,60 +1498,406 @@ class AsyncClient: :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) :return: Reusable insert context """ + full_table = table + if "." not in table: + if database: + full_table = f"{quote_identifier(database)}.{quote_identifier(table)}" + else: + full_table = quote_identifier(table) + column_defs = [] + if column_types is None and column_type_names is None: + describe_result = await self.query(f"DESCRIBE TABLE {full_table}", settings=settings) + column_defs = [ + ColumnDef(**row) for row in describe_result.named_results() if row["default_type"] not in ("ALIAS", "MATERIALIZED") + ] + if column_names is None or isinstance(column_names, str) and column_names == "*": + column_names = [cd.name for cd in column_defs] + column_types = [cd.ch_type for cd in column_defs] + elif isinstance(column_names, str): + column_names = [column_names] + if len(column_names) == 0: + raise ValueError("Column names must be specified for insert") + if not column_types: + if column_type_names: + column_types = [get_from_name(name) for name in column_type_names] + else: + column_map = {d.name: d for d in column_defs} + try: + column_types = [column_map[name].ch_type for name in column_names] + except KeyError as ex: + raise ProgrammingError(f"Unrecognized column {ex} in table {table}") from None + if len(column_names) != len(column_types): + raise ProgrammingError("Column names do not match column types") from None + return InsertContext( + full_table, + column_names, + column_types, + column_oriented=column_oriented, + settings=settings, + transport_settings=transport_settings, + data=data, + ) + + async def data_insert(self, context: InsertContext) -> QuerySummary: # type: ignore[override] + """ + See BaseClient doc_string for this method. + + Uses true streaming via reverse bridge pattern: + - Sync producer (serializer) runs in executor, puts blocks in queue + - Async consumer (network) pulls from queue and yields to aiohttp + - Bounded queue provides backpressure to prevent memory bloat + """ + if context.empty: + logger.debug("No data included in insert, skipping") + return QuerySummary() - def _create_insert_context(): - return self.client.create_insert_context(table=table, column_names=column_names, database=database, - column_types=column_types, column_type_names=column_type_names, - column_oriented=column_oriented, settings=settings, data=data, - transport_settings=transport_settings) + if context.compression is None: + context.compression = self.write_compression loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _create_insert_context) - return result - async def data_insert(self, context: InsertContext) -> QuerySummary: + active_source = StreamingInsertSource(transform=self._transform, context=context, loop=loop, maxsize=10) + active_source.start_producer() + + async def rebuild_body(): + nonlocal active_source + await active_source.close(timeout=None) + context.current_row = 0 + context.current_block = 0 + active_source = StreamingInsertSource(transform=self._transform, context=context, loop=loop, maxsize=10) + active_source.start_producer() + return active_source.async_generator() + + headers = {"Content-Type": "application/octet-stream"} + if context.compression: + headers["Content-Encoding"] = context.compression + + params = {} + if self.database: + params["database"] = self.database + params.update(self._validate_settings(context.settings)) + headers = dict_copy(headers, context.transport_settings) + + try: + response = await self._raw_request( + active_source.async_generator(), + params, + headers=headers, + server_wait=False, + retry_body=rebuild_body, + ) + logger.debug("Context insert response code: %d", response.status) + except Exception: + await active_source.close() + + if context.insert_exception: + ex = context.insert_exception + context.insert_exception = None + raise ex from None + raise + finally: + await active_source.close() + context.data = None + + return QuerySummary(self._summary(response)) + + async def insert_df( # type: ignore[override] + self, + table: str | None = None, + df=None, + database: str | None = None, + settings: dict | None = None, + column_names: Sequence[str] | None = None, + column_types: Sequence[ClickHouseType] | None = None, + column_type_names: Sequence[str] | None = None, + context: InsertContext | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: + """ + Insert a pandas DataFrame into ClickHouse. If context is specified arguments other than df are ignored + :param table: ClickHouse table + :param df: two-dimensional pandas dataframe + :param database: Optional ClickHouse database + :param settings: Optional dictionary of ClickHouse settings (key/string values) + :param column_names: An optional list of ClickHouse column names. If not set, the DataFrame column names + will be used + :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from + the server + :param column_type_names: ClickHouse column type names. If set then column data does not need to be + retrieved from the server + :param context: Optional reusable insert context to allow repeated inserts into the same table with + different data batches + :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) + :return: QuerySummary with summary information, throws exception if insert fails + """ + check_pandas() + self._add_integration_tag("pandas") + if context is None: + if column_names is None: + column_names = df.columns + elif len(column_names) != len(df.columns): + raise ProgrammingError("DataFrame column count does not match insert_columns") from None + return await self.insert( + table, + df, + column_names, + database, + column_types=column_types, + column_type_names=column_type_names, + settings=settings, + transport_settings=transport_settings, + context=context, + ) + + async def raw_insert( # type: ignore[override] + self, + table: str | None = None, + column_names: Sequence[str] | None = None, + insert_block: str | bytes | Generator[bytes, None, None] | BinaryIO | None = None, + settings: dict | None = None, + fmt: str | None = None, + compression: str | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: """ - Subclass implementation of the data insert - :context: InsertContext parameter object - :return: No return, throws an exception if the insert fails + See BaseClient doc_string for this method """ + params = {} + headers = {"Content-Type": "application/octet-stream"} + if compression: + headers["Content-Encoding"] = compression - def _data_insert(): - return self.client.data_insert(context=context) + if table: + cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else "" + fmt_str = fmt if fmt else self._write_format + query = f"INSERT INTO {table}{cols} FORMAT {fmt_str}" + if not compression and isinstance(insert_block, str): + insert_block = query + "\n" + insert_block + elif not compression and isinstance(insert_block, (bytes, bytearray, BinaryIO)): + insert_block = (query + "\n").encode() + insert_block + else: + params["query"] = query - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _data_insert) - return result + if self.database: + params["database"] = self.database + params.update(self._validate_settings(settings or {})) + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request(insert_block, params, headers, server_wait=False) + logger.debug("Raw insert response code: %d", response.status) + return QuerySummary(self._summary(response)) - async def raw_insert(self, table: str, - column_names: Optional[Sequence[str]] = None, - insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None, - settings: Optional[Dict] = None, - fmt: Optional[str] = None, - compression: Optional[str] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + def _add_integration_tag(self, name: str): """ - Insert data already formatted in a bytes object - :param table: Table name (whether qualified with the database name or not) - :param column_names: Sequence of column names - :param insert_block: Binary or string data already in a recognized ClickHouse format - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param compression: Recognized ClickHouse `Accept-Encoding` header compression value - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :param fmt: Valid clickhouse format + Dynamically adds a product (like pandas or sqlalchemy) to the User-Agent string details section. """ + if not common.get_setting("send_integration_tags") or name in self._reported_libs: + return - def _raw_insert(): - return self.client.raw_insert(table=table, column_names=column_names, insert_block=insert_block, - settings=settings, fmt=fmt, compression=compression, - transport_settings=transport_settings) + try: + ver = "unknown" + try: + ver = dist_version(name) + except Exception: + try: + mod = import_module(name) + ver = getattr(mod, "__version__", "unknown") + except Exception: + pass - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _raw_insert) - return result + product_info = f"{name}/{ver}" - async def __aenter__(self) -> "AsyncClient": - return self + ua = self.headers.get("User-Agent", "") + start = ua.find("(") + if start == -1: + return + end = ua.find(")", start + 1) + if end == -1: + return - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - await self.close() + details = ua[start + 1 : end].strip() + + if product_info in details: + self._reported_libs.add(name) + return + + new_details = f"{product_info}; {details}" if details else product_info + new_ua = f"{ua[: start + 1]}{new_details}{ua[end:]}" + self.headers["User-Agent"] = new_ua.strip() + if self._session: + self._session.headers["User-Agent"] = new_ua.strip() + + self._reported_libs.add(name) + logger.debug("Added '%s' to User-Agent", product_info) + + except Exception as e: + logger.debug("Problem adding '%s' to User-Agent: %s", name, e) + + async def _error_handler(self, response: aiohttp.ClientResponse, retried: bool = False): + """ + Handles HTTP errors. Tries to be robust and provide maximum context. + """ + try: + body = "" + try: + raw_body = await response.read() + encoding = response.headers.get("Content-Encoding") + + if encoding: + loop = asyncio.get_running_loop() + + def decompress_and_decode(): + decompressed = decompress_response(raw_body, encoding) + return common.format_error(decompressed.decode(errors="backslashreplace")).strip() + + body = await loop.run_in_executor(None, decompress_and_decode) + else: + loop = asyncio.get_running_loop() + body = await loop.run_in_executor(None, lambda: common.format_error(raw_body.decode(errors="backslashreplace")).strip()) + except Exception: + logger.warning("Failed to read error response body", exc_info=True) + + if self.show_clickhouse_errors: + err_code = response.headers.get(ex_header) + if err_code: + err_str = f"Received ClickHouse exception, code: {err_code}" + else: + err_str = f"HTTP driver received HTTP status {response.status}" + + if body: + err_str = f"{err_str}, server response: {body}" + else: + err_str = "The ClickHouse server returned an error" + + err_str = f"{err_str} (for url {self.url})" + + finally: + response.close() + + raise OperationalError(err_str) if retried else DatabaseError(err_str) from None + + async def _raw_request( + self, + data, + params, + headers=None, + files=None, + method="POST", + stream=False, + server_wait=True, + retries: int = 0, + retry_body: Callable[[], Awaitable[Any]] | None = None, + ) -> aiohttp.ClientResponse: + if self._session is None: + raise ProgrammingError( + "Session not initialized. Use 'async with get_async_client(...)' or call 'await client._initialize()' first." + ) + + reset_seconds = common.get_setting("max_connection_age") + if reset_seconds: + now = time.time() + if self._last_pool_reset is None: + self._last_pool_reset = now + elif self._last_pool_reset < now - reset_seconds: + logger.debug("connection expiration - resetting connection pool") + await self.close_connections() + self._last_pool_reset = now + + final_params = dict_copy(self._client_settings, params) + if server_wait: + final_params.setdefault("wait_end_of_query", "1") + if self._send_progress: + final_params.setdefault("send_progress_in_http_headers", "1") + if self._progress_interval: + final_params.setdefault("http_headers_progress_interval_ms", self._progress_interval) + if self._autogenerate_query_id and "query_id" not in final_params: + final_params["query_id"] = str(uuid.uuid4()) + + req_headers = dict_copy(self.headers, headers) + if self.server_host_name: + req_headers["Host"] = self.server_host_name + query_session = final_params.get("session_id") + attempts = 0 + + while True: + attempts += 1 + + if query_session: + if query_session == self._active_session: + raise ProgrammingError( + "Attempt to execute concurrent queries within the same session. " + "Please use a separate client instance per concurrent query." + ) + self._active_session = query_session + + try: + # Construct full URL (aiohttp doesn't have base_url) + url = f"{self.url}/" + request_kwargs = {"method": method, "url": url, "params": final_params, "headers": req_headers} + if hasattr(self, "_proxy_url") and self._proxy_url: + request_kwargs["proxy"] = self._proxy_url + if files: + # IMPORTANT: Must set content_type on text fields to force multipart/form-data encoding + # Without content_type, aiohttp uses application/x-www-form-urlencoded + form = aiohttp.FormData() + for field_name, field_value in files.items(): + if isinstance(field_value, tuple): + if field_value[0] is None: + form.add_field(field_name, str(field_value[1]), content_type="text/plain") + else: + filename = field_value[0] + file_data = field_value[1] + content_type = field_value[2] if len(field_value) > 2 else None + form.add_field(field_name, file_data, filename=filename, content_type=content_type) + else: + form.add_field(field_name, field_value, content_type="text/plain") + request_kwargs["data"] = form + elif isinstance(data, dict): + request_kwargs["data"] = data + else: + request_kwargs["data"] = data + + response = await self._session.request(**request_kwargs) + if 200 <= response.status < 300 and not response.headers.get(ex_header): + return response + + if response.status in (429, 503, 504): + if attempts > retries: + await self._error_handler(response, retried=True) + else: + logger.debug("Retrying request with status code %s (attempt %s/%s)", response.status, attempts, retries + 1) + await asyncio.sleep(0.1 * attempts) + response.close() + continue + await self._error_handler(response) + + except aiohttp.ServerConnectionError as e: + msg = str(e) + if "Connection reset" in msg or "Remote end closed" in msg or "Cannot connect" in msg or "Server disconnected" in msg: + if attempts == 1: + if retry_body is not None: + data = await retry_body() + logger.debug("Retrying after connection error with rebuilt body") + continue + if data is None or isinstance(data, (bytes, bytearray, str, dict)): + logger.debug("Retrying after connection error from remote host") + continue + raise OperationalError(f"Network Error: {msg}") from e + + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + raise OperationalError(f"Network Error: {str(e)}") from e + + finally: + if query_session: + self._active_session = None + + @staticmethod + def _summary(response: aiohttp.ClientResponse): + summary = {} + if "X-ClickHouse-Summary" in response.headers: + try: + summary = json.loads(response.headers["X-ClickHouse-Summary"]) + except json.JSONDecodeError: + pass + summary["query_id"] = response.headers.get("X-ClickHouse-Query-Id", "") + return summary diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/asyncqueue.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/asyncqueue.py new file mode 100644 index 00000000000..a318f59cea1 --- /dev/null +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/asyncqueue.py @@ -0,0 +1,205 @@ +import asyncio +import threading +from collections import deque +from typing import Generic, TypeVar + +from clickhouse_connect.driver.exceptions import ProgrammingError + +__all__ = ["AsyncSyncQueue", "Empty", "Full", "EOF_SENTINEL"] + +T = TypeVar("T") + +EOF_SENTINEL = object() + + +class AsyncSyncQueue(Generic[T]): + """High-performance bridge between AsyncIO and Threading.""" + + def __init__(self, maxsize: int = 100): + self._maxsize = maxsize + self._queue: deque[T] = deque() + self._shutdown = False + self._loop: asyncio.AbstractEventLoop | None = None + + self._lock = threading.Lock() + + self._sync_not_empty = threading.Condition(self._lock) + self._sync_not_full = threading.Condition(self._lock) + + self._async_getters: deque[asyncio.Future] = deque() + self._async_putters: deque[asyncio.Future] = deque() + + self.sync_q = _SyncQueueInterface(self) + self.async_q = _AsyncQueueInterface(self) + + def _bind_loop(self): + """Lazy-bind to the running loop on first async access.""" + if self._loop is None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + pass + + def _check_deadlock(self): + """Check if blocking would cause a deadlock on the event loop.""" + if self._loop is None: + return + + try: + current_loop = asyncio.get_running_loop() + if current_loop is self._loop: + raise ProgrammingError( + "Deadlock detected: Synchronous blocking operation called on event loop thread. " + "This usually happens when iterating a stream synchronously (e.g., 'for row in result') " + "instead of asynchronously ('async for row in result') inside an async function." + ) + except RuntimeError: + pass + + @staticmethod + def _safe_set_result(fut: asyncio.Future): + """Set result on a future only if it hasn't been cancelled or resolved. + + This runs on the event loop thread after being scheduled via + call_soon_threadsafe. Between scheduling and execution the future + may have been cancelled (e.g. by Task.cancel()), so the done() + check must happen here, not at schedule time. + """ + if not fut.done(): + fut.set_result(None) + + def _wakeup_async_waiter(self, waiter_queue: deque[asyncio.Future]): + """Helper: Wake up the next async waiter in the queue safely.""" + while waiter_queue: + fut = waiter_queue.popleft() + if not fut.done(): + self._loop.call_soon_threadsafe(self._safe_set_result, fut) + break + + def shutdown(self): + """Terminates the queue. All readers will receive EOF_SENTINEL.""" + with self._lock: + self._shutdown = True + + self._sync_not_empty.notify_all() + self._sync_not_full.notify_all() + + if self._loop and not self._loop.is_closed(): + for fut in list(self._async_getters): + if not fut.done(): + self._loop.call_soon_threadsafe(self._safe_set_result, fut) + for fut in list(self._async_putters): + if not fut.done(): + self._loop.call_soon_threadsafe(self._safe_set_result, fut) + self._async_getters.clear() + self._async_putters.clear() + + @property + def qsize(self) -> int: + with self._lock: + return len(self._queue) + + +class _SyncQueueInterface(Generic[T]): + def __init__(self, parent: AsyncSyncQueue[T]): + self._p = parent + + def get(self, block: bool = True, timeout: float | None = None) -> T: + with self._p._lock: + while not self._p._queue and not self._p._shutdown: + if not block: + raise Empty() + + self._p._check_deadlock() + if not self._p._sync_not_empty.wait(timeout): + raise Empty() + + if not self._p._queue and self._p._shutdown: + return EOF_SENTINEL + + item = self._p._queue.popleft() + self._p._sync_not_full.notify() + self._p._wakeup_async_waiter(self._p._async_putters) + + return item + + def put(self, item: T, block: bool = True, timeout: float | None = None) -> None: + with self._p._lock: + if self._p._shutdown: + raise RuntimeError("Queue is shutdown") + + while self._p._maxsize > 0 and len(self._p._queue) >= self._p._maxsize: + if not block: + raise Full() + + self._p._check_deadlock() + if not self._p._sync_not_full.wait(timeout): + raise Full() + if self._p._shutdown: + raise RuntimeError("Queue is shutdown") + + self._p._queue.append(item) + + self._p._sync_not_empty.notify() + self._p._wakeup_async_waiter(self._p._async_getters) + + +class _AsyncQueueInterface(Generic[T]): + def __init__(self, parent: AsyncSyncQueue[T]): + self._p = parent + + async def get(self) -> T: + self._p._bind_loop() + while True: + with self._p._lock: + if self._p._queue: + item = self._p._queue.popleft() + self._p._sync_not_full.notify() + self._p._wakeup_async_waiter(self._p._async_putters) + return item + + if self._p._shutdown: + return EOF_SENTINEL + + fut = self._p._loop.create_future() + self._p._async_getters.append(fut) + + try: + await fut + except asyncio.CancelledError: + with self._p._lock: + if fut in self._p._async_getters: + self._p._async_getters.remove(fut) + raise + + async def put(self, item: T) -> None: + self._p._bind_loop() + while True: + with self._p._lock: + if self._p._shutdown: + raise RuntimeError("Queue is shutdown") + + if self._p._maxsize <= 0 or len(self._p._queue) < self._p._maxsize: + self._p._queue.append(item) + self._p._sync_not_empty.notify() + self._p._wakeup_async_waiter(self._p._async_getters) + return + + fut = self._p._loop.create_future() + self._p._async_putters.append(fut) + + try: + await fut + except asyncio.CancelledError: + with self._p._lock: + if fut in self._p._async_putters: + self._p._async_putters.remove(fut) + raise + + +class Empty(Exception): # noqa: N818 + pass + + +class Full(Exception): # noqa: N818 + pass diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/binding.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/binding.py index 97cbb41da19..9c63713e6dc 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/binding.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/binding.py @@ -1,11 +1,11 @@ import ipaddress import re import uuid -from datetime import tzinfo, datetime, date +import zoneinfo +from collections.abc import Sequence +from datetime import date, datetime, timezone, tzinfo from enum import Enum -from typing import Optional, Union, Sequence, Dict, Any, Tuple - -import pytz +from typing import Any from clickhouse_connect import common from clickhouse_connect.driver import tzutil @@ -13,20 +13,20 @@ from clickhouse_connect.driver.common import dict_copy from clickhouse_connect.driver.parser import parse_callable from clickhouse_connect.json_impl import any_to_json -BS = '\\' -must_escape = (BS, '\'', '`', '\t', '\n') -external_bind_re = re.compile(r'\{(\w+):([^}]+)\}') +BS = "\\" +must_escape = (BS, "'", "`", "\t", "\n") +external_bind_re = re.compile(r"\{(\w+):([^}]+)\}") class DT64Param: def __init__(self, value: datetime): self.value = value - def format(self, tz: tzinfo, top_level:bool) -> str: + def format(self, tz: tzinfo, top_level: bool) -> str: value = self.value if tz: value = value.astimezone(tz) - s = value.strftime('%Y-%m-%d %H:%M:%S.%f') + s = value.strftime("%Y-%m-%d %H:%M:%S.%f") if top_level: return s return f"'{s}'" @@ -34,23 +34,22 @@ class DT64Param: def quote_identifier(identifier: str): first_char = identifier[0] - if first_char in ('`', '"') and identifier[-1] == first_char: + if first_char in ("`", '"') and identifier[-1] == first_char: # Identifier is already quoted, assume that it's valid return identifier - return f'`{escape_str(identifier)}`' + return f"`{escape_str(identifier)}`" -def finalize_query(query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]], - server_tz: Optional[tzinfo] = None) -> str: +def finalize_query(query: str, parameters: Sequence | dict[str, Any] | None, server_tz: tzinfo | None = None) -> str: query = query.rstrip(";") if not parameters: return query - if hasattr(parameters, 'items'): + if hasattr(parameters, "items"): return query % {k: format_query_value(v, server_tz) for k, v in parameters.items()} return query % tuple(format_query_value(v, server_tz) for v in parameters) -def _extract_tz_from_type(type_str: str) -> Optional[tzinfo]: +def _extract_tz_from_type(type_str: str) -> tzinfo | None: """Extract timezone from a ClickHouse type hint like DateTime64(6, 'UTC'). Handles LowCardinality/Nullable wrappers and container types @@ -69,8 +68,8 @@ def _extract_tz_from_type(type_str: str) -> Optional[tzinfo]: for v in values: if isinstance(v, str) and v.startswith("'") and v.endswith("'"): try: - return pytz.timezone(v[1:-1]) - except pytz.UnknownTimeZoneError: + return tzutil.resolve_zone(v[1:-1]) + except zoneinfo.ZoneInfoNotFoundError: return None return None @@ -82,13 +81,15 @@ def _extract_tz_from_type(type_str: str) -> Optional[tzinfo]: return tz return None - except Exception: # pylint: disable=broad-exception-caught + except Exception: return None -# pylint: disable=too-many-locals,too-many-branches -def bind_query(query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]], - server_tz: Optional[tzinfo] = None) -> Tuple[str, Dict[str, str]]: +def bind_query( + query: str, + parameters: Sequence | dict[str, Any] | None, + server_tz: tzinfo | None = None, +) -> tuple[str, dict[str, str]]: query = query.rstrip(";") if not parameters: return query, {} @@ -97,13 +98,13 @@ def bind_query(query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] if isinstance(parameters, dict): params_copy = dict_copy(parameters) - binary_binds = {k: v for k, v in params_copy.items() if k.startswith('$') and k.endswith('$') and len(k) > 1} + binary_binds = {k: v for k, v in params_copy.items() if k.startswith("$") and k.endswith("$") and len(k) > 1} for key in binary_binds.keys(): del params_copy[key] final_params = {} for k, v in params_copy.items(): - if k.endswith('_64'): + if k.endswith("_64"): if isinstance(v, datetime): k = k[:-3] v = DT64Param(v) @@ -142,7 +143,7 @@ def bind_query(query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] break binary_indexes[item_index + len(key)] = key, v item_index += len(key) - query = b'' + query = b"" start = 0 for loc in sorted(binary_indexes.keys()): key, value = binary_indexes[loc] @@ -157,11 +158,10 @@ def format_str(value: str): def escape_str(value: str): - return ''.join(f'{BS}{c}' if c in must_escape else c for c in value) + return "".join(f"{BS}{c}" if c in must_escape else c for c in value) -# pylint: disable=too-many-return-statements -def format_query_value(value: Any, server_tz: tzinfo = pytz.UTC): +def format_query_value(value: Any, server_tz: tzinfo = timezone.utc): """ Format Python values in a ClickHouse query :param value: Python object @@ -169,7 +169,7 @@ def format_query_value(value: Any, server_tz: tzinfo = pytz.UTC): :return: Literal string for python value """ if value is None: - return 'NULL' + return "NULL" if isinstance(value, str): return format_str(value) if isinstance(value, DT64Param): @@ -185,10 +185,9 @@ def format_query_value(value: Any, server_tz: tzinfo = pytz.UTC): if isinstance(value, tuple): return f"({', '.join(str_query_value(x, server_tz) for x in value)})" if isinstance(value, dict): - if common.get_setting('dict_parameter_format') == 'json': + if common.get_setting("dict_parameter_format") == "json": return format_str(any_to_json(value).decode()) - pairs = [str_query_value(k, server_tz) + ':' + str_query_value(v, server_tz) - for k, v in value.items()] + pairs = [str_query_value(k, server_tz) + ":" + str_query_value(v, server_tz) for k, v in value.items()] return f"{{{', '.join(pairs)}}}" if isinstance(value, Enum): return format_query_value(value.value, server_tz) @@ -197,12 +196,11 @@ def format_query_value(value: Any, server_tz: tzinfo = pytz.UTC): return value -def str_query_value(value: Any, server_tz: tzinfo = pytz.UTC): +def str_query_value(value: Any, server_tz: tzinfo = timezone.utc): return str(format_query_value(value, server_tz)) -# pylint: disable=too-many-branches -def format_bind_value(value: Any, server_tz: tzinfo = pytz.UTC, top_level: bool = True): +def format_bind_value(value: Any, server_tz: tzinfo = timezone.utc, top_level: bool = True): """ Format Python values in a ClickHouse query :param value: Python object @@ -215,7 +213,7 @@ def format_bind_value(value: Any, server_tz: tzinfo = pytz.UTC, top_level: bool return format_bind_value(x, server_tz, False) if value is None: - return '\\N' + return "\\N" if isinstance(value, str): if top_level: # At the top levels, strings must not be surrounded by quotes @@ -225,7 +223,7 @@ def format_bind_value(value: Any, server_tz: tzinfo = pytz.UTC, top_level: bool return value.format(server_tz, top_level) if isinstance(value, datetime): value = value.astimezone(server_tz) - val = value.strftime('%Y-%m-%d %H:%M:%S') + val = value.strftime("%Y-%m-%d %H:%M:%S") if top_level: return val return f"'{val}'" @@ -238,10 +236,9 @@ def format_bind_value(value: Any, server_tz: tzinfo = pytz.UTC, top_level: bool if isinstance(value, tuple): return f"({', '.join(recurse(x) for x in value)})" if isinstance(value, dict): - if common.get_setting('dict_parameter_format') == 'json': + if common.get_setting("dict_parameter_format") == "json": return any_to_json(value).decode() - pairs = [recurse(k) + ':' + recurse(v) - for k, v in value.items()] + pairs = [recurse(k) + ":" + recurse(v) for k, v in value.items()] return f"{{{', '.join(pairs)}}}" if isinstance(value, Enum): return recurse(value.value) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/buffer.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/buffer.py index 4bfca8e3384..97c01670b09 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/buffer.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/buffer.py @@ -1,16 +1,16 @@ -import sys import array -from typing import Any, Iterable +import sys +from collections.abc import Iterable +from typing import Any from clickhouse_connect.driver.exceptions import StreamCompleteException from clickhouse_connect.driver.types import ByteSource -must_swap = sys.byteorder == 'big' +must_swap = sys.byteorder == "big" -# pylint: disable=too-many-instance-attributes class ResponseBuffer(ByteSource): - slots = 'slice_sz', 'buf_loc', 'end', 'gen', 'buffer', 'slice' + slots = "slice_sz", "buf_loc", "end", "gen", "buffer", "slice" def __init__(self, source): self.slice_sz = 4096 @@ -18,7 +18,7 @@ class ResponseBuffer(ByteSource): self.buf_sz = 0 self.source = source self.gen = source.gen - self.buffer = bytes() + self.buffer = b"" self.exception_tag = getattr(source, "exception_tag", None) if self.exception_tag: tag_bytes = self.exception_tag.encode() @@ -56,9 +56,9 @@ class ResponseBuffer(ByteSource): def read_bytes(self, sz: int): if self.buf_loc + sz <= self.buf_sz: self.buf_loc += sz - return self.buffer[self.buf_loc - sz: self.buf_loc] + return self.buffer[self.buf_loc - sz : self.buf_loc] # Create a temporary buffer that bridges two or more source chunks - bridge = bytearray(self.buffer[self.buf_loc: self.buf_sz]) + bridge = bytearray(self.buffer[self.buf_loc : self.buf_sz]) self.buf_loc = 0 self.buf_sz = 0 while len(bridge) < sz: @@ -99,7 +99,7 @@ class ResponseBuffer(ByteSource): shift = 0 while True: b = self.read_byte() - sz += ((b & 0x7f) << shift) + sz += (b & 0x7F) << shift if (b & 0x80) == 0: return sz shift += 7 @@ -109,13 +109,15 @@ class ResponseBuffer(ByteSource): return self.read_bytes(sz).decode() def read_uint64(self) -> int: - return int.from_bytes(self.read_bytes(8), 'little', signed=False) + return int.from_bytes(self.read_bytes(8), "little", signed=False) - def read_str_col(self, - num_rows: int, - encoding: str, - nullable: bool = False, - null_obj: Any = None) -> Iterable[str]: + def read_str_col( + self, + num_rows: int, + encoding: str, + nullable: bool = False, + null_obj: Any = None, + ) -> Iterable[str]: column = [] app = column.append null_map = self.read_bytes(num_rows) if nullable else None @@ -124,7 +126,7 @@ class ResponseBuffer(ByteSource): shift = 0 while True: b = self.read_byte() - sz += ((b & 0x7f) << shift) + sz += (b & 0x7F) << shift if (b & 0x80) == 0: break shift += 7 @@ -142,7 +144,7 @@ class ResponseBuffer(ByteSource): def read_bytes_col(self, sz: int, num_rows: int) -> Iterable[bytes]: source = self.read_bytes(sz * num_rows) - return [bytes(source[x:x+sz]) for x in range(0, sz * num_rows, sz)] + return [bytes(source[x : x + sz]) for x in range(0, sz * num_rows, sz)] def read_fixed_str_col(self, sz: int, num_rows: int, encoding: str) -> Iterable[str]: source = self.read_bytes(sz * num_rows) @@ -150,9 +152,9 @@ class ResponseBuffer(ByteSource): app = column.append for ix in range(0, sz * num_rows, sz): try: - app(str(source[ix: ix + sz], encoding).rstrip('\x00')) + app(str(source[ix : ix + sz], encoding).rstrip("\x00")) except UnicodeDecodeError: - app(source[ix: ix + sz].hex()) + app(source[ix : ix + sz].hex()) return column def read_array(self, array_type: str, num_rows: int) -> Iterable[Any]: diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/bytesource.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/bytesource.py index 3af433711f4..3d98b6dd394 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/bytesource.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/bytesource.py @@ -73,7 +73,6 @@ class ByteArraySource(ByteSource): def read_float64(self) -> float: return struct.unpack("<d", self.read_bytes(8))[0] - # pylint: disable=too-many-return-statements def read_array(self, array_type: str, num_rows: int): # type: ignore if array_type == "B": return [self.read_byte() for _ in range(num_rows)] diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/client.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/client.py index 234f0e46724..c88af65c345 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/client.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/client.py @@ -1,23 +1,16 @@ +from __future__ import annotations + import io import logging -import warnings from abc import ABC, abstractmethod -from datetime import tzinfo +from collections.abc import Generator, Iterable, Sequence +from datetime import timezone, tzinfo from typing import ( TYPE_CHECKING, Any, BinaryIO, - Dict, - Generator, - Iterable, - Literal, - Optional, - Sequence, - Union, ) - -import pytz -from pytz.exceptions import UnknownTimeZoneError +from zoneinfo import ZoneInfoNotFoundError from clickhouse_connect import common from clickhouse_connect.common import version @@ -51,15 +44,12 @@ from clickhouse_connect.driver.options import ( check_polars, ) from clickhouse_connect.driver.query import ( - _APPLY_SERVER_TZ_TO_TZ_SOURCE, - _TZ_MODE_TO_UTC_TZ_AWARE, + _VALID_TZ_MODES, _VALID_TZ_SOURCES, QueryContext, QueryResult, TzMode, TzSource, - _resolve_tz_mode, - _resolve_tz_source, arrow_buffer, to_arrow, to_arrow_batches, @@ -69,22 +59,23 @@ from clickhouse_connect.driver.summary import QuerySummary if TYPE_CHECKING: import numpy import pandas + import polars import pyarrow io.DEFAULT_BUFFER_SIZE = 1024 * 256 logger = logging.getLogger(__name__) -arrow_str_setting = 'output_format_arrow_string_as_string' +arrow_str_setting = "output_format_arrow_string_as_string" -def _strip_utc_timezone_from_arrow(table: "arrow.Table") -> "arrow.Table": +def _strip_utc_timezone_from_arrow(table: pyarrow.Table) -> pyarrow.Table: """Strip UTC timezone from timestamp columns in Arrow table. This ensures naive datetimes are returned when the server timezone is UTC - and utc_tz_aware is False (the default). + and tz_mode is 'naive_utc' (the default). Only UTC-equivalent timezones (UTC, Etc/UTC, GMT, etc.) are stripped. Non-UTC timezones carry important offset information and are always - preserved regardless of utc_tz_aware setting. + preserved regardless of tz_mode setting. """ new_fields = [] needs_cast = False @@ -99,7 +90,7 @@ def _strip_utc_timezone_from_arrow(table: "arrow.Table") -> "arrow.Table": return table -def _apply_arrow_tz_policy(table: "options.arrow.Table", tz_mode: str) -> "options.arrow.Table": +def _apply_arrow_tz_policy(table: pyarrow.Table, tz_mode: str) -> pyarrow.Table: """Apply the tz_mode policy to an Arrow table before conversion. Handles UTC stripping when tz_mode is "naive_utc" and warns when @@ -118,14 +109,13 @@ def _apply_arrow_tz_policy(table: "options.arrow.Table", tz_mode: str) -> "optio return table -# pylint: disable=too-many-lines -# pylint: disable=too-many-public-methods,too-many-arguments,too-many-positional-arguments,too-many-instance-attributes class Client(ABC): """ Base ClickHouse Connect client """ - compression: str = None - write_compression: str = None + + compression: str | None = None + write_compression: str | None = None protocol_version = 0 valid_transport_settings = set() optional_transport_settings = set() @@ -143,63 +133,25 @@ class Client(ABC): @tz_source.setter def tz_source(self, value: TzSource): if value not in _VALID_TZ_SOURCES: - raise ProgrammingError( - f'tz_source must be "auto", "server", or "local", got "{value}"' - ) + raise ProgrammingError(f'tz_source must be "auto", "server", or "local", got "{value}"') self._tz_source = value if value == "auto": self._apply_server_tz = self._dst_safe else: self._apply_server_tz = value == "server" - @property - def apply_server_timezone(self) -> bool: - """Deprecated: use tz_source instead.""" - warnings.warn( - "apply_server_timezone is deprecated and will be removed in 1.0. " - "Use tz_source instead.", - DeprecationWarning, - stacklevel=2, - ) - return self._apply_server_tz - - @apply_server_timezone.setter - def apply_server_timezone(self, value: Union[bool, str]): - """Deprecated: use tz_source instead.""" - warnings.warn( - "apply_server_timezone is deprecated and will be removed in 1.0. " - "Use tz_source instead.", - DeprecationWarning, - stacklevel=2, - ) - if value not in _APPLY_SERVER_TZ_TO_TZ_SOURCE: - raise ProgrammingError( - f"apply_server_timezone must be True, False, or 'always', got \"{value}\"" - ) - self.tz_source = _APPLY_SERVER_TZ_TO_TZ_SOURCE[value] - - @property - def utc_tz_aware(self) -> Union[bool, Literal["schema"]]: - """Deprecated: use tz_mode instead.""" - warnings.warn( - "utc_tz_aware is deprecated and will be removed in 1.0. " - "Use tz_mode instead.", - DeprecationWarning, - stacklevel=2, - ) - return _TZ_MODE_TO_UTC_TZ_AWARE[self.tz_mode] - - def __init__(self, - database: str, - query_limit: int, - uri: str, - query_retries: int, - server_host_name: Optional[str], - tz_source: Optional[TzSource] = None, - tz_mode: Optional[TzMode] = None, - show_clickhouse_errors: Optional[bool] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - apply_server_timezone: Optional[Union[str, bool]] = None): + def __init__( + self, + database: str | None, + query_limit: int, + uri: str, + query_retries: int, + server_host_name: str | None, + tz_source: TzSource | None = None, + tz_mode: TzMode | None = None, + show_clickhouse_errors: bool | None = None, + autoconnect: bool = True, + ): """ Shared initialization of ClickHouse Connect client :param database: database name @@ -212,82 +164,103 @@ class Client(ABC): naive UTC timestamps. "aware" forces timezone-aware UTC datetimes. "schema" returns datetimes that match the server's column definition which means timezone-aware when the column defines a timezone and naive for bare DateTime columns. - :param utc_tz_aware: Deprecated. Use tz_mode instead. - :param apply_server_timezone: Deprecated. Use tz_source instead. + :param autoconnect: If True, immediately connect to server and fetch settings. If False, + defer connection to _connect() method. Used by async clients to avoid blocking I/O in __init__. """ self.query_limit = coerce_int(query_limit) self.query_retries = coerce_int(query_retries) - if database and not database == '__default__': + if database and not database == "__default__": self.database = database if show_clickhouse_errors is not None: self.show_clickhouse_errors = coerce_bool(show_clickhouse_errors) self.server_host_name = server_host_name self.uri = uri - self.tz_mode = _resolve_tz_mode(tz_mode, utc_tz_aware) - resolved_tz_source = _resolve_tz_source(tz_source, apply_server_timezone) + self.tz_mode = tz_mode if tz_mode is not None else "naive_utc" + if self.tz_mode not in _VALID_TZ_MODES: + raise ProgrammingError(f'tz_mode must be "naive_utc", "aware", or "schema", got "{self.tz_mode}"') + resolved_tz_source = tz_source if tz_source is not None else "auto" + if resolved_tz_source not in _VALID_TZ_SOURCES: + raise ProgrammingError(f'tz_source must be "auto", "server", or "local", got "{resolved_tz_source}"') self._tz_source = resolved_tz_source - self._init_common_settings(resolved_tz_source) + + # Initialize attributes that will be set during connection + self.server_version = None + self.server_tz = timezone.utc + self.server_settings = {} + + if autoconnect: + self._init_common_settings(resolved_tz_source) + else: + # Store for deferred async initialization + self._deferred_tz_source = resolved_tz_source def _init_common_settings(self, tz_source: TzSource): - self.server_tz, self._dst_safe = pytz.UTC, True - self.server_version, server_tz = \ - tuple(self.command('SELECT version(), timezone()', use_database=False)) + self.server_tz, self._dst_safe = timezone.utc, True + self.server_version, server_tz = tuple(self.command("SELECT version(), timezone()", use_database=False)) try: - server_tz = pytz.timezone(server_tz) - server_tz, self._dst_safe = tzutil.normalize_timezone(server_tz) - if tz_source == "auto": - self._apply_server_tz = self._dst_safe - else: - self._apply_server_tz = tz_source == "server" - self.server_tz = server_tz - except UnknownTimeZoneError: - logger.warning('Warning, server is using an unrecognized timezone %s, will use UTC default', server_tz) + server_tz_info = tzutil.resolve_zone(server_tz) + server_tz_info, self._dst_safe = tzutil.normalize_timezone(server_tz_info) + self.server_tz = server_tz_info + except ZoneInfoNotFoundError: + logger.warning( + "Server timezone %s could not be resolved, falling back to UTC; %s", + server_tz, + tzutil.TZDATA_HINT, + ) + if tz_source == "auto": + self._apply_server_tz = self._dst_safe + else: + self._apply_server_tz = tz_source == "server" if not self._apply_server_tz and not tzutil.local_tz_dst_safe: - logger.warning('local timezone %s may return unexpected times due to Daylight Savings Time/' + - 'Summer Time differences', tzutil.local_tz.tzname(None)) - readonly = 'readonly' - if not self.min_version('19.17'): - readonly = common.get_setting('readonly') - server_settings = self.query(f'SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000') - self.server_settings = {row['name']: SettingDef(**row) for row in server_settings.named_results()} + logger.warning( + "local timezone %s may return unexpected times due to Daylight Savings Time/" + "Summer Time differences", + tzutil.local_tz.tzname(None), + ) + readonly = "readonly" + if not self.min_version("19.17"): + readonly = common.get_setting("readonly") + server_settings = self.query(f"SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000") + self.server_settings = {row["name"]: SettingDef(**row) for row in server_settings.named_results()} - if self.min_version(CH_VERSION_WITH_PROTOCOL) and common.get_setting('use_protocol_version'): + if self.min_version(CH_VERSION_WITH_PROTOCOL) and common.get_setting("use_protocol_version"): # Unfortunately we have to validate that the client protocol version is actually used by ClickHouse # since the query parameter could be stripped off (in particular, by CHProxy) - test_data = self.raw_query('SELECT 1 AS check', fmt='Native', settings={ - 'client_protocol_version': PROTOCOL_VERSION_WITH_LOW_CARD - }) - if test_data[8:16] == b'\x01\x01\x05check': + test_data = self.raw_query( + "SELECT 1 AS check", fmt="Native", settings={"client_protocol_version": PROTOCOL_VERSION_WITH_LOW_CARD} + ) + if test_data[8:16] == b"\x01\x01\x05check": self.protocol_version = PROTOCOL_VERSION_WITH_LOW_CARD - if self._setting_status('date_time_input_format').is_writable: - self.set_client_setting('date_time_input_format', 'best_effort') - if self._setting_status('allow_experimental_json_type').is_set and \ - self._setting_status('cast_string_to_dynamic_use_inference').is_writable: - self.set_client_setting('cast_string_to_dynamic_use_inference', '1') - if self.min_version('24.8') and not self.min_version('24.10'): + if self._setting_status("date_time_input_format").is_writable: + self.set_client_setting("date_time_input_format", "best_effort") + if ( + self._setting_status("allow_experimental_json_type").is_set + and self._setting_status("cast_string_to_dynamic_use_inference").is_writable + ): + self.set_client_setting("cast_string_to_dynamic_use_inference", "1") + if self.min_version("24.8") and not self.min_version("24.10"): dynamic_module.json_serialization_format = 0 - def _validate_settings(self, settings: Optional[Dict[str, Any]]) -> Dict[str, str]: + def _validate_settings(self, settings: dict[str, Any] | None) -> dict[str, str]: """ This strips any ClickHouse settings that are not recognized or are read only. :param settings: Dictionary of setting name and values :return: A filtered dictionary of settings with values rendered as strings """ validated = {} - invalid_action = common.get_setting('invalid_setting_action') + invalid_action = common.get_setting("invalid_setting_action") for key, value in settings.items(): str_value = self._validate_setting(key, value, invalid_action) if str_value is not None: validated[key] = value return validated - def _validate_setting(self, key: str, value: Any, invalid_action: str) -> Optional[str]: + def _validate_setting(self, key: str, value: Any, invalid_action: str) -> str | None: str_value = str(value) if value is True: - str_value = '1' + str_value = "1" elif value is False: - str_value = '0' + str_value = "0" if key not in self.valid_transport_settings: setting_def = self.server_settings.get(key) current_setting = self.get_client_setting(key) @@ -299,37 +272,41 @@ class Client(ABC): if setting_def is None or setting_def.readonly: if key in self.optional_transport_settings: return None - if invalid_action == 'send': - logger.warning('Attempting to send unrecognized or readonly setting %s', key) - elif invalid_action == 'drop': - logger.warning('Dropping unrecognized or readonly settings %s', key) + if invalid_action == "send": + logger.warning("Attempting to send unrecognized or readonly setting %s", key) + elif invalid_action == "drop": + logger.warning("Dropping unrecognized or readonly settings %s", key) return None else: - raise ProgrammingError(f'Setting {key} is unknown or readonly') from None + raise ProgrammingError(f"Setting {key} is unknown or readonly") from None return str_value def _setting_status(self, key: str) -> SettingStatus: comp_setting = self.server_settings.get(key) if not comp_setting: return SettingStatus(False, False) - return SettingStatus(comp_setting.value != '0', comp_setting.readonly != 1) + return SettingStatus(comp_setting.value != "0", comp_setting.readonly != 1) def _prep_query(self, context: QueryContext): if context.is_select and not context.has_limit and self.query_limit: - limit = f'\n LIMIT {self.query_limit}' + limit = f"\n LIMIT {self.query_limit}" if isinstance(context.query, bytes): return context.final_query + limit.encode() return context.final_query + limit return context.final_query - def _check_tz_change(self, new_tz) -> Optional[tzinfo]: + def _check_tz_change(self, new_tz) -> tzinfo | None: if new_tz: try: - new_tzinfo = pytz.timezone(new_tz) + new_tzinfo = tzutil.resolve_zone(new_tz) if new_tzinfo != self.server_tz: return new_tzinfo - except UnknownTimeZoneError: - logger.warning('Unrecognized timezone %s received from ClickHouse', new_tz) + except ZoneInfoNotFoundError: + logger.warning( + "Unrecognized timezone %s received from ClickHouse; %s", + new_tz, + tzutil.TZDATA_HINT, + ) return None @abstractmethod @@ -347,7 +324,7 @@ class Client(ABC): """ @abstractmethod - def get_client_setting(self, key: str) -> Optional[str]: + def get_client_setting(self, key: str) -> str | None: """ :param key: The setting key :return: The string value of the setting, if it exists, or None @@ -360,62 +337,66 @@ class Client(ABC): :param access_token: Access token string """ - # pylint: disable=unused-argument,too-many-locals - def query(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> QueryResult: + def query( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + column_oriented: bool | None = None, + use_numpy: bool | None = None, + max_str_len: int | None = None, + context: QueryContext = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> QueryResult: """ Main query method for SELECT, DESCRIBE and other SQL statements that return a result matrix. For parameters, see the create_query_context method :return: QueryResult -- data and metadata from response """ - if query and query.lower().strip().startswith('select __connect_version__'): - return QueryResult([[f'ClickHouse Connect v.{version()} ⓒ ClickHouse Inc.']], None, - ('connect_version',), (get_from_name('String'),)) + if query and query.lower().strip().startswith("select __connect_version__"): + return QueryResult( + [[f"ClickHouse Connect v.{version()} ⓒ ClickHouse Inc."]], None, ("connect_version",), (get_from_name("String"),) + ) kwargs = locals().copy() - del kwargs['self'] + del kwargs["self"] query_context = self.create_query_context(**kwargs) if query_context.is_command: - response = self.command(query, - parameters=query_context.parameters, - settings=query_context.settings, - external_data=query_context.external_data, - transport_settings=query_context.transport_settings) + response = self.command( + query, + parameters=query_context.parameters, + settings=query_context.settings, + external_data=query_context.external_data, + transport_settings=query_context.transport_settings, + ) if isinstance(response, QuerySummary): return response.as_query_result() return QueryResult([response] if isinstance(response, list) else [[response]]) return self._query_with_context(query_context) - def query_column_block_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: + def query_column_block_stream( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + context: QueryContext = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> StreamContext: """ Variation of main query method that returns a stream of column oriented blocks. For parameters, see the create_query_context method. @@ -423,21 +404,22 @@ class Client(ABC): """ return self._context_query(locals(), use_numpy=False, streaming=True).column_block_stream - def query_row_block_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: + def query_row_block_stream( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + context: QueryContext = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> StreamContext: """ Variation of main query method that returns a stream of row oriented blocks. For parameters, see the create_query_context method. @@ -445,21 +427,22 @@ class Client(ABC): """ return self._context_query(locals(), use_numpy=False, streaming=True).row_block_stream - def query_rows_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: + def query_rows_stream( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + context: QueryContext = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> StreamContext: """ Variation of main query method that returns a stream of row oriented blocks. For parameters, see the create_query_context method. @@ -468,13 +451,16 @@ class Client(ABC): return self._context_query(locals(), use_numpy=False, streaming=True).rows_stream @abstractmethod - def raw_query(self, query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: str = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> bytes: + def raw_query( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> bytes: """ Query method that simply returns the raw ClickHouse format bytes :param query: Query statement/format string @@ -489,39 +475,43 @@ class Client(ABC): """ @abstractmethod - def raw_stream(self, query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: str = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase: + def raw_stream( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> io.IOBase | StreamContext: + """ + Query method that returns the result as a stream iterator. + :param query: Query statement/format string + :param parameters: Optional dictionary used to format the query + :param settings: Optional dictionary of ClickHouse settings (key/string values) + :param fmt: ClickHouse output format + :param use_database Send the database parameter to ClickHouse so the command will be executed in the client + database context. + :param external_data: External data to send with the query. + :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) + :return: io.IOBase (sync) or StreamContext (async) - both support iteration over raw bytes """ - Query method that returns the result as an io.IOBase iterator - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param fmt: ClickHouse output format - :param use_database Send the database parameter to ClickHouse so the command will be executed in the client - database context. - :param external_data: External data to send with the query. - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: io.IOBase stream/iterator for the result - """ - # pylint: disable=duplicate-code,unused-argument - def query_np(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> 'numpy.ndarray': + def query_np( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + max_str_len: int | None = None, + context: QueryContext = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> numpy.ndarray: """ Query method that returns the results as a numpy array. For parameter values, see the create_query_context method @@ -531,19 +521,20 @@ class Client(ABC): self._add_integration_tag("numpy") return self._context_query(locals(), use_numpy=True).np_result - # pylint: disable=duplicate-code,too-many-arguments,unused-argument - def query_np_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: + def query_np_stream( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + max_str_len: int | None = None, + context: QueryContext = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> StreamContext: """ Query method that returns the results as a stream of numpy arrays. For parameter values, see the create_query_context method @@ -553,25 +544,25 @@ class Client(ABC): self._add_integration_tag("numpy") return self._context_query(locals(), use_numpy=True, streaming=True).np_stream - # pylint: disable=duplicate-code,unused-argument - def query_df(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> 'pandas.DataFrame': + def query_df( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + max_str_len: int | None = None, + use_na_values: bool | None = None, + query_tz: str | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + context: QueryContext = None, + external_data: ExternalData | None = None, + use_extended_dtypes: bool | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> pandas.DataFrame: """ Query method that results the results as a pandas dataframe. For parameter values, see the create_query_context method @@ -581,25 +572,25 @@ class Client(ABC): self._add_integration_tag("pandas") return self._context_query(locals(), use_numpy=True, as_pandas=True).df_result - # pylint: disable=duplicate-code,unused-argument - def query_df_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: + def query_df_stream( + self, + query: str | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + max_str_len: int | None = None, + use_na_values: bool | None = None, + query_tz: str | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + context: QueryContext = None, + external_data: ExternalData | None = None, + use_extended_dtypes: bool | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> StreamContext: """ Query method that returns the results as a StreamContext. For parameter values, see the create_query_context method @@ -607,32 +598,31 @@ class Client(ABC): """ check_pandas() self._add_integration_tag("pandas") - return self._context_query(locals(), use_numpy=True, - as_pandas=True, - streaming=True).df_stream + return self._context_query(locals(), use_numpy=True, as_pandas=True, streaming=True).df_stream - def create_query_context(self, - query: Optional[Union[str, bytes]] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = False, - max_str_len: Optional[int] = 0, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - use_na_values: Optional[bool] = None, - streaming: bool = False, - as_pandas: bool = False, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> QueryContext: + def create_query_context( + self, + query: str | bytes | None = None, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_none: bool | None = None, + column_oriented: bool | None = None, + use_numpy: bool | None = False, + max_str_len: int | None = 0, + context: QueryContext | None = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + use_na_values: bool | None = None, + streaming: bool = False, + as_pandas: bool = False, + external_data: ExternalData | None = None, + use_extended_dtypes: bool | None = None, + transport_settings: dict[str, str] | None = None, + tz_mode: TzMode | None = None, + ) -> QueryContext: """ Creates or updates a reusable QueryContext object :param query: Query statement/format string @@ -650,7 +640,7 @@ class Client(ABC): structured array even with ClickHouse variable length String columns. If 0, Numpy arrays for String columns will always be object arrays :param context: An existing QueryContext to be updated with any provided parameter values - :param query_tz: Either a string or a pytz tzinfo object. (Strings will be converted to tzinfo objects). + :param query_tz: Either a string IANA timezone name or a tzinfo object (strings are resolved via zoneinfo). Values for any DateTime or DateTime64 column in the query will be converted to Python datetime.datetime objects with the selected timezone. :param column_tzs: A dictionary of column names to tzinfo objects (or strings that will be converted to @@ -658,7 +648,6 @@ class Client(ABC): :param tz_mode: Override the client default for handling UTC results. "aware" forces timezone-aware UTC datetimes, "naive_utc" returns naive UTC datetimes, and "schema" returns datetimes matching the server's column definition. - :param utc_tz_aware: Deprecated. Use tz_mode instead. :param use_na_values: Deprecated alias for use_advanced_dtypes :param as_pandas Return the result columns as pandas.Series objects :param streaming Marker used to correctly configure streaming queries @@ -669,64 +658,67 @@ class Client(ABC): :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) :return: Reusable QueryContext """ - if tz_mode is not None or utc_tz_aware is not None: - resolved_tz_mode = _resolve_tz_mode(tz_mode, utc_tz_aware) - else: - resolved_tz_mode = self.tz_mode + resolved_tz_mode = tz_mode if tz_mode is not None else self.tz_mode if context: - return context.updated_copy(query=query, - parameters=parameters, - settings=settings, - query_formats=query_formats, - column_formats=column_formats, - encoding=encoding, - server_tz=self.server_tz, - use_none=use_none, - column_oriented=column_oriented, - use_numpy=use_numpy, - max_str_len=max_str_len, - query_tz=query_tz, - column_tzs=column_tzs, - tz_mode=resolved_tz_mode, - as_pandas=as_pandas, - use_extended_dtypes=use_extended_dtypes, - streaming=streaming, - external_data=external_data, - transport_settings=transport_settings) + return context.updated_copy( + query=query, + parameters=parameters, + settings=settings, + query_formats=query_formats, + column_formats=column_formats, + encoding=encoding, + server_tz=self.server_tz, + use_none=use_none, + column_oriented=column_oriented, + use_numpy=use_numpy, + max_str_len=max_str_len, + query_tz=query_tz, + column_tzs=column_tzs, + tz_mode=resolved_tz_mode, + as_pandas=as_pandas, + use_extended_dtypes=use_extended_dtypes, + streaming=streaming, + external_data=external_data, + transport_settings=transport_settings, + ) if use_numpy and max_str_len is None: max_str_len = 0 if use_extended_dtypes is None: use_extended_dtypes = use_na_values if as_pandas and use_extended_dtypes is None: use_extended_dtypes = True - return QueryContext(query=query, - parameters=parameters, - settings=settings, - query_formats=query_formats, - column_formats=column_formats, - encoding=encoding, - server_tz=self.server_tz, - use_none=use_none, - column_oriented=column_oriented, - use_numpy=use_numpy, - max_str_len=max_str_len, - query_tz=query_tz, - column_tzs=column_tzs, - tz_mode=resolved_tz_mode, - use_extended_dtypes=use_extended_dtypes, - as_pandas=as_pandas, - streaming=streaming, - apply_server_tz=self._apply_server_tz, - external_data=external_data, - transport_settings=transport_settings) + return QueryContext( + query=query, + parameters=parameters, + settings=settings, + query_formats=query_formats, + column_formats=column_formats, + encoding=encoding, + server_tz=self.server_tz, + use_none=use_none, + column_oriented=column_oriented, + use_numpy=use_numpy, + max_str_len=max_str_len, + query_tz=query_tz, + column_tzs=column_tzs, + tz_mode=resolved_tz_mode, + use_extended_dtypes=use_extended_dtypes, + as_pandas=as_pandas, + streaming=streaming, + apply_server_tz=self._apply_server_tz, + external_data=external_data, + transport_settings=transport_settings, + ) - def query_arrow(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> 'pyarrow.Table': + def query_arrow( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + use_strings: bool | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> pyarrow.Table: """ Query method using the ClickHouse Arrow format to return a PyArrow table :param query: Query statement/format string @@ -740,20 +732,26 @@ class Client(ABC): check_arrow() self._add_integration_tag("arrow") settings = self._update_arrow_settings(settings, use_strings) - return to_arrow(self.raw_query(query, - parameters, - settings, - fmt='Arrow', - external_data=external_data, - transport_settings=transport_settings)) + return to_arrow( + self.raw_query( + query, + parameters, + settings, + fmt="Arrow", + external_data=external_data, + transport_settings=transport_settings, + ) + ) - def query_arrow_stream(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: + def query_arrow_stream( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + use_strings: bool | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> StreamContext: """ Query method that returns the results as a stream of Arrow tables :param query: Query statement/format string @@ -767,22 +765,27 @@ class Client(ABC): check_arrow() self._add_integration_tag("arrow") settings = self._update_arrow_settings(settings, use_strings) - return to_arrow_batches(self.raw_stream(query, - parameters, - settings, - fmt='ArrowStream', - external_data=external_data, - transport_settings=transport_settings)) + return to_arrow_batches( + self.raw_stream( + query, + parameters, + settings, + fmt="ArrowStream", + external_data=external_data, + transport_settings=transport_settings, + ) + ) - def query_df_arrow(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas" - ) -> Union["pd.DataFrame", "pl.DataFrame"]: + def query_df_arrow( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + use_strings: bool | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + dataframe_library: str = "pandas", + ) -> pandas.DataFrame | polars.DataFrame: """ Query method using the ClickHouse Arrow format to return a DataFrame with PyArrow dtype backend. This provides better performance and memory efficiency @@ -802,10 +805,8 @@ class Client(ABC): if dataframe_library == "pandas": check_pandas() self._add_integration_tag("pandas") - if not options.IS_PANDAS_2: - raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") - def converter(table: options.arrow.Table) -> options.pd.DataFrame: + def converter(table: pyarrow.Table) -> pandas.DataFrame: table = _apply_arrow_tz_policy(table, self.tz_mode) return table.to_pandas(types_mapper=options.pd.ArrowDtype, safe=False) @@ -813,7 +814,7 @@ class Client(ABC): check_polars() self._add_integration_tag("polars") - def converter(table: options.arrow.Table) -> options.pl.DataFrame: + def converter(table: pyarrow.Table) -> polars.DataFrame: table = _apply_arrow_tz_policy(table, self.tz_mode) return options.pl.from_arrow(table) @@ -831,14 +832,16 @@ class Client(ABC): return converter(arrow_table) - def query_df_arrow_stream(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas") -> StreamContext: + def query_df_arrow_stream( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + use_strings: bool | None = None, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + dataframe_library: str = "pandas", + ) -> StreamContext: """ Query method that returns the results as a stream of DataFrames with PyArrow dtype backend. Each DataFrame represents a block from the ClickHouse response. @@ -856,17 +859,15 @@ class Client(ABC): if dataframe_library == "pandas": check_pandas() self._add_integration_tag("pandas") - if not options.IS_PANDAS_2: - raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") - def converter(table: "options.arrow.Table") -> "options.pd.DataFrame": + def converter(table: pyarrow.Table) -> pandas.DataFrame: table = _apply_arrow_tz_policy(table, self.tz_mode) return table.to_pandas(types_mapper=options.pd.ArrowDtype, safe=False) elif dataframe_library == "polars": check_polars() self._add_integration_tag("polars") - def converter(table: options.arrow.Table) -> options.pl.DataFrame: + def converter(table: pyarrow.Table) -> polars.DataFrame: table = _apply_arrow_tz_policy(table, self.tz_mode) return options.pl.from_arrow(table) else: @@ -883,31 +884,31 @@ class Client(ABC): return StreamContext(raw_stream, df_generator()) - def _update_arrow_settings(self, - settings: Optional[Dict[str, Any]], - use_strings: Optional[bool]) -> Dict[str, Any]: + def _update_arrow_settings(self, settings: dict[str, Any] | None, use_strings: bool | None) -> dict[str, Any]: settings = dict_copy(settings) if self.database: - settings['database'] = self.database + settings["database"] = self.database str_status = self._setting_status(arrow_str_setting) if use_strings is None: if str_status.is_writable and not str_status.is_set: - settings[arrow_str_setting] = '1' # Default to returning strings if possible + settings[arrow_str_setting] = "1" # Default to returning strings if possible elif use_strings != str_status.is_set: if not str_status.is_writable: - raise OperationalError(f'Cannot change readonly {arrow_str_setting} to {use_strings}') - settings[arrow_str_setting] = '1' if use_strings else '0' + raise OperationalError(f"Cannot change readonly {arrow_str_setting} to {use_strings}") + settings[arrow_str_setting] = "1" if use_strings else "0" return settings @abstractmethod - def command(self, - cmd: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - data: Union[str, bytes] = None, - settings: Dict[str, Any] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> Union[str, int, Sequence[str], QuerySummary]: + def command( + self, + cmd: str, + parameters: Sequence | dict[str, Any] | None = None, + data: str | bytes = None, + settings: dict[str, Any] = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> str | int | Sequence[str] | QuerySummary: """ Client method that returns a single value instead of a result set :param cmd: ClickHouse query/command as a python format string @@ -930,17 +931,19 @@ class Client(ABC): :return: ClickHouse server is up and reachable """ - def insert(self, - table: Optional[str] = None, - data: Sequence[Sequence[Any]] = None, - column_names: Union[str, Iterable[str]] = '*', - database: Optional[str] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - context: InsertContext = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + def insert( + self, + table: str | None = None, + data: Sequence[Sequence[Any]] = None, + column_names: str | Iterable[str] = "*", + database: str | None = None, + column_types: Sequence[ClickHouseType] = None, + column_type_names: Sequence[str] = None, + column_oriented: bool = False, + settings: dict[str, Any] | None = None, + context: InsertContext = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: """ Method to insert multiple rows/data matrix of native Python objects. If context is specified arguments other than data are ignored @@ -961,31 +964,36 @@ class Client(ABC): :return: QuerySummary with summary information, throws exception if insert fails """ if (context is None or context.empty) and data is None: - raise ProgrammingError('No data specified for insert') from None + raise ProgrammingError("No data specified for insert") from None if context is None: - context = self.create_insert_context(table, - column_names, - database, - column_types, - column_type_names, - column_oriented, - settings, - transport_settings=transport_settings) + context = self.create_insert_context( + table, + column_names, + database, + column_types, + column_type_names, + column_oriented, + settings, + transport_settings=transport_settings, + ) if data is not None: if not context.empty: - raise ProgrammingError('Attempting to insert new data with non-empty insert context') from None + raise ProgrammingError("Attempting to insert new data with non-empty insert context") from None context.data = data return self.data_insert(context) - def insert_df(self, table: str = None, - df=None, - database: Optional[str] = None, - settings: Optional[Dict] = None, - column_names: Optional[Sequence[str]] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - context: InsertContext = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + def insert_df( + self, + table: str = None, + df=None, + database: str | None = None, + settings: dict | None = None, + column_names: Sequence[str] | None = None, + column_types: Sequence[ClickHouseType] = None, + column_type_names: Sequence[str] = None, + context: InsertContext = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: """ Insert a pandas DataFrame into ClickHouse. If context is specified arguments other than df are ignored :param table: ClickHouse table @@ -1009,22 +1017,27 @@ class Client(ABC): if column_names is None: column_names = df.columns elif len(column_names) != len(df.columns): - raise ProgrammingError('DataFrame column count does not match insert_columns') from None - return self.insert(table, - df, - column_names, - database, - column_types=column_types, - column_type_names=column_type_names, - settings=settings, - transport_settings=transport_settings, - context=context) + raise ProgrammingError("DataFrame column count does not match insert_columns") from None + return self.insert( + table, + df, + column_names, + database, + column_types=column_types, + column_type_names=column_type_names, + settings=settings, + transport_settings=transport_settings, + context=context, + ) - def insert_arrow(self, table: str, - arrow_table, - database: str = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + def insert_arrow( + self, + table: str, + arrow_table, + database: str = None, + settings: dict | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: """ Insert a PyArrow table DataFrame into ClickHouse using raw Arrow format :param table: ClickHouse table @@ -1035,25 +1048,27 @@ class Client(ABC): """ check_arrow() self._add_integration_tag("arrow") - full_table = table if '.' in table or not database else f'{database}.{table}' - compression = self.write_compression if self.write_compression in ('zstd', 'lz4') else None + full_table = table if "." in table or not database else f"{database}.{table}" + compression = self.write_compression if self.write_compression in ("zstd", "lz4") else None column_names, insert_block = arrow_buffer(arrow_table, compression) - return self.raw_insert(full_table, column_names, insert_block, settings, 'Arrow', transport_settings) + return self.raw_insert(full_table, column_names, insert_block, settings, "Arrow", transport_settings) - def insert_df_arrow(self, - table: str, - df: Union["pd.DataFrame", "pl.DataFrame"], - database: Optional[str] = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + def insert_df_arrow( + self, + table: str, + df: pandas.DataFrame | polars.DataFrame, + database: str | None = None, + settings: dict | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: """ Insert a pandas DataFrame with PyArrow backend or a polars DataFrame into ClickHouse using Arrow format. This method is optimized for DataFrames that already use Arrow format, providing better performance than the standard insert_df method. - + Validation is performed and an exception will be raised if this requirement is not met. Polars DataFrames are natively Arrow-based and don't require additional validation. - + :param table: ClickHouse table name :param df: Pandas DataFrame with PyArrow dtype backend or Polars DataFrame :param database: Optional ClickHouse database name @@ -1073,9 +1088,6 @@ class Client(ABC): raise TypeError(f"df must be either a pandas DataFrame or polars DataFrame, got {type(df).__name__}") if df_lib == "pandas": - if not options.IS_PANDAS_2: - raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") - non_arrow_cols = [col for col, dtype in df.dtypes.items() if not isinstance(dtype, options.pd.ArrowDtype)] if non_arrow_cols: raise ProgrammingError( @@ -1100,16 +1112,18 @@ class Client(ABC): transport_settings=transport_settings, ) - def create_insert_context(self, - table: str, - column_names: Optional[Union[str, Sequence[str]]] = None, - database: Optional[str] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - data: Optional[Sequence[Sequence[Any]]] = None, - transport_settings: Optional[Dict[str, str]] = None) -> InsertContext: + def create_insert_context( + self, + table: str, + column_names: str | Sequence[str] | None = None, + database: str | None = None, + column_types: Sequence[ClickHouseType] = None, + column_type_names: Sequence[str] = None, + column_oriented: bool = False, + settings: dict[str, Any] | None = None, + data: Sequence[Sequence[Any]] | None = None, + transport_settings: dict[str, str] | None = None, + ) -> InsertContext: """ Builds a reusable insert context to hold state for a duration of an insert :param table: Target table @@ -1127,23 +1141,24 @@ class Client(ABC): :return: Reusable insert context """ full_table = table - if '.' not in table: + if "." not in table: if database: - full_table = f'{quote_identifier(database)}.{quote_identifier(table)}' + full_table = f"{quote_identifier(database)}.{quote_identifier(table)}" else: full_table = quote_identifier(table) column_defs = [] if column_types is None and column_type_names is None: - describe_result = self.query(f'DESCRIBE TABLE {full_table}', settings=settings) - column_defs = [ColumnDef(**row) for row in describe_result.named_results() - if row['default_type'] not in ('ALIAS', 'MATERIALIZED')] - if column_names is None or isinstance(column_names, str) and column_names == '*': + describe_result = self.query(f"DESCRIBE TABLE {full_table}", settings=settings) + column_defs = [ + ColumnDef(**row) for row in describe_result.named_results() if row["default_type"] not in ("ALIAS", "MATERIALIZED") + ] + if column_names is None or isinstance(column_names, str) and column_names == "*": column_names = [cd.name for cd in column_defs] column_types = [cd.ch_type for cd in column_defs] elif isinstance(column_names, str): column_names = [column_names] if len(column_names) == 0: - raise ValueError('Column names must be specified for insert') + raise ValueError("Column names must be specified for insert") if not column_types: if column_type_names: column_types = [get_from_name(name) for name in column_type_names] @@ -1152,16 +1167,18 @@ class Client(ABC): try: column_types = [column_map[name].ch_type for name in column_names] except KeyError as ex: - raise ProgrammingError(f'Unrecognized column {ex} in table {table}') from None + raise ProgrammingError(f"Unrecognized column {ex} in table {table}") from None if len(column_names) != len(column_types): - raise ProgrammingError('Column names do not match column types') from None - return InsertContext(full_table, - column_names, - column_types, - column_oriented=column_oriented, - settings=settings, - transport_settings=transport_settings, - data=data) + raise ProgrammingError("Column names do not match column types") from None + return InsertContext( + full_table, + column_names, + column_types, + column_oriented=column_oriented, + settings=settings, + transport_settings=transport_settings, + data=data, + ) def min_version(self, version_str: str) -> bool: """ @@ -1172,13 +1189,14 @@ class Client(ABC): :return: True if version_str is greater than the server_version, False if less than """ try: - server_parts = [int(x) for x in self.server_version.split('.') if x.isnumeric()] + server_parts = [int(x) for x in self.server_version.split(".") if x.isnumeric()] server_parts.extend([0] * (4 - len(server_parts))) - version_parts = [int(x) for x in version_str.split('.')] + version_parts = [int(x) for x in version_str.split(".")] version_parts.extend([0] * (4 - len(version_parts))) except ValueError: - logger.warning('Server %s or requested version %s does not match format of numbers separated by dots', - self.server_version, version_str) + logger.warning( + "Server %s or requested version %s does not match format of numbers separated by dots", self.server_version, version_str + ) return False for x, y in zip(server_parts, version_parts): if x > y: @@ -1187,7 +1205,6 @@ class Client(ABC): return False return True - # pylint: disable=no-self-use def _add_integration_tag(self, name: str) -> None: """Transport hook to surface 3rd party lib integration info (default: no-op).""" return @@ -1201,13 +1218,16 @@ class Client(ABC): """ @abstractmethod - def raw_insert(self, table: str, - column_names: Optional[Sequence[str]] = None, - insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None, - settings: Optional[Dict] = None, - fmt: Optional[str] = None, - compression: Optional[str] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + def raw_insert( + self, + table: str, + column_names: Sequence[str] | None = None, + insert_block: str | bytes | Generator[bytes, None, None] | BinaryIO = None, + settings: dict | None = None, + fmt: str | None = None, + compression: str | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: """ Insert data already formatted in a bytes object :param table: Table name (whether qualified with the database name or not) @@ -1233,11 +1253,11 @@ class Client(ABC): def _context_query(self, lcls: dict, **overrides): kwargs = lcls.copy() - kwargs.pop('self') + kwargs.pop("self") kwargs.update(overrides) - return self._query_with_context((self.create_query_context(**kwargs))) + return self._query_with_context(self.create_query_context(**kwargs)) - def __enter__(self) -> 'Client': + def __enter__(self) -> Client: return self def __exit__(self, exc_type, exc_value, exc_traceback) -> None: diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/common.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/common.py index 15ff2bdf4b6..a4acbd292df 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/common.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/common.py @@ -1,27 +1,27 @@ import array +import asyncio import struct import sys +from collections.abc import Callable, Generator, MutableSequence, Sequence +from typing import Any -from typing import Any, Sequence, MutableSequence, Dict, Optional, Union, Generator, Callable - -from clickhouse_connect.driver.exceptions import ProgrammingError, StreamClosedError, DataError +from clickhouse_connect.driver.exceptions import DataError, ProgrammingError, StreamClosedError from clickhouse_connect.driver.types import Closable -# pylint: disable=invalid-name -must_swap = sys.byteorder == 'big' -int_size = array.array('i').itemsize +must_swap = sys.byteorder == "big" +int_size = array.array("i").itemsize low_card_version = 1 -array_map = {1: 'b', 2: 'h', 4: 'i', 8: 'q'} +array_map = {1: "b", 2: "h", 4: "i", 8: "q"} decimal_prec = {32: 9, 64: 18, 128: 38, 256: 79} if int_size == 2: - array_map[4] = 'l' + array_map[4] = "l" array_sizes = {v: k for k, v in array_map.items()} -array_sizes['f'] = 4 -array_sizes['d'] = 8 -np_date_types = {0: '[s]', 3: '[ms]', 6: '[us]', 9: '[ns]'} +array_sizes["f"] = 4 +array_sizes["d"] = 8 +np_date_types = {0: "[s]", 3: "[ms]", 6: "[us]", 9: "[ns]"} def array_type(size: int, signed: bool): @@ -38,7 +38,7 @@ def array_type(size: int, signed: bool): return code if signed else code.upper() -def write_array(code: str, column: Sequence, dest: MutableSequence, col_name: Optional[str]=None): +def write_array(code: str, column: Sequence, dest: MutableSequence, col_name: str | None = None): """ Write a column of native Python data matching the array.array code :param code: Python array.array code matching the column data type @@ -47,14 +47,17 @@ def write_array(code: str, column: Sequence, dest: MutableSequence, col_name: Op :param col_name: Optional column name for error tracking """ try: - buff = struct.Struct(f'<{len(column)}{code}') + buff = struct.Struct(f"<{len(column)}{code}") dest += buff.pack(*column) except (TypeError, OverflowError, struct.error) as ex: - col_msg = '' - if col_name: - col_msg = f' for source column `{col_name}`' - raise DataError(f'Unable to create Python array{col_msg}. This is usually caused by trying to insert None ' + - 'values into a ClickHouse column that is not Nullable') from ex + col_msg = f" for column `{col_name}`" if col_name else "" + if isinstance(ex, OverflowError): + error_detail = "value out of range" + elif isinstance(ex, TypeError): + error_detail = "type mismatch (usually None in non-Nullable column)" + else: + error_detail = type(ex).__name__ + raise DataError(f"Unable to create native array{col_msg}: {error_detail}") from ex def write_uint64(value: int, dest: MutableSequence): @@ -63,7 +66,7 @@ def write_uint64(value: int, dest: MutableSequence): :param value: UInt64 value to write :param dest: Destination byte buffer """ - dest.extend(value.to_bytes(8, 'little')) + dest.extend(value.to_bytes(8, "little")) def write_leb128(value: int, dest: MutableSequence): @@ -73,7 +76,7 @@ def write_leb128(value: int, dest: MutableSequence): :param dest: Target buffer """ while True: - b = value & 0x7f + b = value & 0x7F value >>= 7 if value == 0: dest.append(b) @@ -88,7 +91,7 @@ def decimal_size(prec: int): :return: Required bit size """ if prec < 1 or prec > 79: - raise ArithmeticError(f'Invalid precision {prec} for ClickHouse Decimal type') + raise ArithmeticError(f"Invalid precision {prec} for ClickHouse Decimal type") if prec < 10: return 32 if prec < 19: @@ -99,19 +102,19 @@ def decimal_size(prec: int): def unescape_identifier(x: str) -> str: - if x.startswith('`') and x.endswith('`'): + if x.startswith("`") and x.endswith("`"): return x[1:-1] return x -def dict_copy(source: Dict = None, update: Optional[Dict] = None) -> Dict: +def dict_copy(source: dict = None, update: dict | None = None) -> dict: copy = source.copy() if source else {} if update: copy.update(update) return copy -def dict_add(source: Dict, key: str, value: Any) -> Dict: +def dict_add(source: dict, key: str, value: Any) -> dict: if value is not None: source[key] = value return source @@ -121,19 +124,19 @@ def empty_gen(): yield from () -def coerce_int(val: Optional[Union[str, int]]) -> int: +def coerce_int(val: str | int | None) -> int: if not val: return 0 return int(val) -def coerce_bool(val: Optional[Union[str, bool]]) -> bool: +def coerce_bool(val: str | bool | None) -> bool: if not val: return False - return val is True or (isinstance(val, str) and val.lower() in ('true', '1', 'y', 'yes')) + return val is True or (isinstance(val, str) and val.lower() in ("true", "1", "y", "yes")) -def first_value(column: Sequence, nullable:bool = True): +def first_value(column: Sequence, nullable: bool = True): if nullable: return next((x for x in column if x is not None), None) if len(column): @@ -147,9 +150,10 @@ class SliceView(Sequence): https://gist.github.com/mathieucaroff/0cf094325fb5294fb54c6a577f05a2c1 Also see the discussion on SO: https://stackoverflow.com/questions/3485475/can-i-create-a-view-on-a-python-list """ - slots = ('_source', '_range') - def __init__(self, source: Sequence, source_slice: Optional[slice] = None): + slots = ("_source", "_range") + + def __init__(self, source: Sequence, source_slice: slice | None = None): if isinstance(source, SliceView): self._source = source._source self._range = source._range[source_slice] @@ -174,7 +178,7 @@ class SliceView(Sequence): def __repr__(self): r = self._range - return f'SliceView({self._source[slice(r.start, r.stop, r.step)]})' + return f"SliceView({self._source[slice(r.start, r.stop, r.step)]})" def __eq__(self, other): if self is other: @@ -190,9 +194,11 @@ class SliceView(Sequence): class StreamContext: """ Wraps a generator and its "source" in a Context. This ensures that the source will be "closed" even if the - generator is not fully consumed or there is an exception during consumption + generator is not fully consumed or there is an exception during consumption. Supports both synchronous and + asynchronous usage. """ - __slots__ = 'source', 'gen', '_in_context' + + __slots__ = "source", "gen", "_in_context" def __init__(self, source: Closable, gen: Generator): self.source = source @@ -204,7 +210,7 @@ class StreamContext: def __next__(self): if not self._in_context: - raise ProgrammingError('Stream should be used within a context') + raise ProgrammingError("Stream should be used within a context") return next(self.gen) def __enter__(self): @@ -218,8 +224,59 @@ class StreamContext: self.source.close() self.gen = None -# pylint: disable=too-many-return-statements -def get_rename_method(method: Optional[str]) -> Optional[Callable[[str], str]]: + def __aiter__(self): + return self + + async def __anext__(self): + if not self._in_context: + raise ProgrammingError("Stream should be used within a context") + try: + if hasattr(self.gen, "__anext__"): + return await self.gen.__anext__() + + def _next_wrapper(): + try: + return True, self.gen.__next__() + except StopIteration: + return False, None + + loop = asyncio.get_running_loop() + has_value, value = await loop.run_in_executor(None, _next_wrapper) + if not has_value: + raise StopAsyncIteration from None + return value + except (StopAsyncIteration, StopIteration): + raise StopAsyncIteration from None + except Exception as ex: + if not isinstance(ex, StreamClosedError): + self._in_context = False + if hasattr(self.source, "close"): + if hasattr(self.source.close, "__await__"): + await self.source.close() + else: + self.source.close() + self.gen = None + raise ex + + async def __aenter__(self): + if not self.gen: + raise StreamClosedError + self._in_context = True + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._in_context = False + if hasattr(self.source, "aclose"): + await self.source.aclose() + elif hasattr(self.source, "close"): + if hasattr(self.source.close, "__await__"): + await self.source.close() + else: + self.source.close() + self.gen = None + + +def get_rename_method(method: str | None) -> Callable[[str], str] | None: def _to_camel(s: str) -> str: if not s: return "" diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/compression.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/compression.py index db69ae3f040..802b6e815c2 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/compression.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/compression.py @@ -1,6 +1,5 @@ import zlib from abc import abstractmethod -from typing import Union import lz4 import lz4.frame @@ -12,11 +11,11 @@ except ImportError: brotli = None -available_compression = ['lz4', 'zstd'] +available_compression = ["lz4", "zstd"] if brotli: - available_compression.append('br') -available_compression.extend(['gzip', 'deflate']) + available_compression.append("br") +available_compression.extend(["gzip", "deflate"]) comp_map = {} @@ -26,14 +25,14 @@ class Compressor: comp_map[tag] = cls() if thread_safe else cls @abstractmethod - def compress_block(self, block) -> Union[bytes, bytearray]: + def compress_block(self, block) -> bytes | bytearray: return block def flush(self): pass -class GzipCompressor(Compressor, tag='gzip', thread_safe=False): +class GzipCompressor(Compressor, tag="gzip", thread_safe=False): def __init__(self, level: int = 6, wbits: int = 31): self.zlib_obj = zlib.compressobj(level=level, wbits=wbits) @@ -44,7 +43,7 @@ class GzipCompressor(Compressor, tag='gzip', thread_safe=False): return self.zlib_obj.flush() -class Lz4Compressor(Compressor, tag='lz4', thread_safe=False): +class Lz4Compressor(Compressor, tag="lz4", thread_safe=False): def __init__(self): self.comp = lz4.frame.LZ4FrameCompressor() @@ -54,12 +53,12 @@ class Lz4Compressor(Compressor, tag='lz4', thread_safe=False): return output + self.comp.flush() -class ZstdCompressor(Compressor, tag='zstd'): +class ZstdCompressor(Compressor, tag="zstd"): def compress_block(self, block): return zstandard.compress(block) -class BrotliCompressor(Compressor, tag='br'): +class BrotliCompressor(Compressor, tag="br"): def compress_block(self, block): return brotli.compress(block) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/constants.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/constants.py index a242e559b94..eda8818908a 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/constants.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/constants.py @@ -1,2 +1,2 @@ PROTOCOL_VERSION_WITH_LOW_CARD = 54405 -CH_VERSION_WITH_PROTOCOL = '23.2.1.2537' +CH_VERSION_WITH_PROTOCOL = "23.2.1.2537" diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/context.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/context.py index b5d4fbeab5f..50b8315a0db 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/context.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/context.py @@ -1,40 +1,40 @@ import logging import re -from typing import Optional, Dict, Union, Any, Callable +from collections.abc import Callable +from typing import Any logger = logging.getLogger(__name__) _empty_map = {} -# pylint: disable=too-many-instance-attributes class BaseQueryContext: - - def __init__(self, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_extended_dtypes: bool = False, - use_numpy: bool = False, - transport_settings: Optional[Dict[str, str]] = None): + def __init__( + self, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + use_extended_dtypes: bool = False, + use_numpy: bool = False, + transport_settings: dict[str, str] | None = None, + ): self.settings = settings or {} if query_formats is None: self.type_formats = _empty_map else: - self.type_formats = {re.compile(type_name.replace('*', '.*'), re.IGNORECASE): fmt - for type_name, fmt in query_formats.items()} + self.type_formats = {re.compile(type_name.replace("*", ".*"), re.IGNORECASE): fmt for type_name, fmt in query_formats.items()} if column_formats is None: self.col_simple_formats = _empty_map self.col_type_formats = _empty_map else: - self.col_simple_formats = {col_name: fmt for col_name, fmt in column_formats.items() if - isinstance(fmt, str)} + self.col_simple_formats = {col_name: fmt for col_name, fmt in column_formats.items() if isinstance(fmt, str)} self.col_type_formats = {} for col_name, fmt in column_formats.items(): if not isinstance(fmt, str): - self.col_type_formats[col_name] = {re.compile(type_name.replace('*', '.*'), re.IGNORECASE): fmt - for type_name, fmt in fmt.items()} + self.col_type_formats[col_name] = { + re.compile(type_name.replace("*", ".*"), re.IGNORECASE): fmt for type_name, fmt in fmt.items() + } self.query_formats = query_formats or {} self.column_formats = column_formats or {} self.transport_settings = transport_settings @@ -44,7 +44,7 @@ class BaseQueryContext: self.use_extended_dtypes = use_extended_dtypes self._active_col_fmt = None self._active_col_type_fmts = _empty_map - self.column_renamer: Optional[Callable[[str], str]] = None + self.column_renamer: Callable[[str], str] | None = None def start_column(self, name: str): self.column_name = name diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/ctypes.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/ctypes.py index 6bb9e35f7a0..00ddcf6cd53 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/ctypes.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/ctypes.py @@ -13,38 +13,35 @@ data_conv = pydc # numpy_conv is resolved lazily via __getattr__ to avoid eagerly importing numpy -# pylint: disable=import-outside-toplevel,global-statement - def connect_c_modules(): - if not coerce_bool(os.environ.get('CLICKHOUSE_CONNECT_USE_C', True)): - logger.info('ClickHouse Connect C optimizations disabled') + if not coerce_bool(os.environ.get("CLICKHOUSE_CONNECT_USE_C", True)): + logger.info("ClickHouse Connect C optimizations disabled") return global RespBuffCls, data_conv try: - from clickhouse_connect.driverc.buffer import ResponseBuffer as CResponseBuffer import clickhouse_connect.driverc.dataconv as cdc + from clickhouse_connect.driverc.buffer import ResponseBuffer as CResponseBuffer data_conv = cdc RespBuffCls = CResponseBuffer - logger.debug('Successfully imported ClickHouse Connect C data optimizations') + logger.debug("Successfully imported ClickHouse Connect C data optimizations") except ImportError as ex: - logger.warning('Unable to connect optimized C data functions [%s], falling back to pure Python', - str(ex)) + logger.warning("Unable to connect optimized C data functions [%s], falling back to pure Python", str(ex)) def _resolve_numpy_conv(): if "numpy_conv" in globals(): return - if coerce_bool(os.environ.get('CLICKHOUSE_CONNECT_USE_C', True)): + if coerce_bool(os.environ.get("CLICKHOUSE_CONNECT_USE_C", True)): try: import clickhouse_connect.driverc.npconv as cnc + globals()["numpy_conv"] = cnc - logger.debug('Successfully import ClickHouse Connect C/Numpy optimizations') + logger.debug("Successfully import ClickHouse Connect C/Numpy optimizations") return except ImportError as ex: - logger.debug('Unable to connect ClickHouse Connect C to Numpy API [%s], falling back to pure Python', - str(ex)) + logger.debug("Unable to connect ClickHouse Connect C to Numpy API [%s], falling back to pure Python", str(ex)) globals()["numpy_conv"] = pync diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/dataconv.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/dataconv.py index bea56f2a1dd..5a6d1786f8a 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/dataconv.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/dataconv.py @@ -1,40 +1,40 @@ import array - -from datetime import datetime, date, tzinfo +from collections.abc import Sequence +from datetime import date, datetime, tzinfo from ipaddress import IPv4Address -from typing import Sequence, Optional, Any +from typing import Any from uuid import UUID, SafeUUID -from clickhouse_connect.driver import tzutil -from clickhouse_connect.driver.common import int_size +from clickhouse_connect.driver import options, tzutil +from clickhouse_connect.driver.common import int_size, must_swap, write_array from clickhouse_connect.driver.errors import NONE_IN_NULLABLE_COLUMN from clickhouse_connect.driver.types import ByteSource -from clickhouse_connect.driver import options - MONTH_DAYS = (0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365) MONTH_DAYS_LEAP = (0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366) def read_ipv4_col(source: ByteSource, num_rows: int): - column = source.read_array('I', num_rows) + column = source.read_array("I", num_rows) fast_ip_v4 = IPv4Address.__new__ new_col = [] app = new_col.append for x in column: ipv4 = fast_ip_v4(IPv4Address) - ipv4._ip = x # pylint: disable=protected-access + ipv4._ip = x app(ipv4) return new_col -def read_datetime_col(source: ByteSource, num_rows: int, tz_info: Optional[tzinfo]): - src_array = source.read_array('I', num_rows) +def read_datetime_col(source: ByteSource, num_rows: int, tz_info: tzinfo | None): + src_array = source.read_array("I", num_rows) if tz_info is None: - fts = tzutil.utcfromtimestamp - return [fts(ts) for ts in src_array] - fts = datetime.fromtimestamp - return [fts(ts, tz_info) for ts in src_array] + return [tzutil.utcfromtimestamp(ts) for ts in src_array] + elif tzutil.is_utc_timezone(tz_info): + return [tzutil.utc_equivalent_tzaware_datetime(ts, 0, tz_info) for ts in src_array] + else: + fts = datetime.fromtimestamp + return [fts(ts, tz_info) for ts in src_array] def epoch_days_to_date(days: int) -> date: @@ -53,17 +53,52 @@ def epoch_days_to_date(days: int) -> date: def read_date_col(source: ByteSource, num_rows: int): - column = source.read_array('H', num_rows) + column = source.read_array("H", num_rows) return [epoch_days_to_date(x) for x in column] def read_date32_col(source: ByteSource, num_rows: int): - column = source.read_array('l' if int_size == 2 else 'i', num_rows) + column = source.read_array("l" if int_size == 2 else "i", num_rows) return [epoch_days_to_date(x) for x in column] +def read_datetime64_naive_col(column: Sequence, prec: int, tz: tzinfo | None = None): + """Read DateTime64 column using epoch arithmetic, for naive UTC or UTC-equivalent timezones. + + When tz is None, the result is naive. When tz is a UTC-equivalent timezone, the + same arithmetic path is used and the tz is attached to the constructed datetime. + """ + result = [] + for ticks in column: + seconds, fractional_ticks = divmod(ticks, prec) + microseconds = (fractional_ticks * 1000000) // prec + if tz is None: + dt = tzutil.utcfromtimestamp_with_microseconds(seconds, microseconds) + else: + dt = tzutil.utc_equivalent_tzaware_datetime(seconds, microseconds, tz) + result.append(dt) + return result + + +def read_datetime64_tz_col(column: Sequence, prec: int, tz_info: tzinfo): + """Read DateTime64 column with non-UTC timezone conversion. + + Constructs datetime objects with the specified timezone and microseconds. + """ + result = [] + dt_from = datetime.fromtimestamp + for ticks in column: + seconds, fractional_ticks = divmod(ticks, prec) + microseconds = (fractional_ticks * 1000000) // prec + v = dt_from(seconds, tz_info) + if microseconds != 0: + v = v.replace(microsecond=microseconds) + result.append(v) + return result + + def read_uuid_col(source: ByteSource, num_rows: int): - v = source.read_array('Q', num_rows * 2) + v = source.read_array("Q", num_rows * 2) empty_uuid = UUID(int=0) new_uuid = UUID.__new__ unsafe = SafeUUID.unsafe @@ -77,8 +112,8 @@ def read_uuid_col(source: ByteSource, num_rows: int): app(empty_uuid) else: fast_uuid = new_uuid(UUID) - oset(fast_uuid, 'int', int_value) - oset(fast_uuid, 'is_safe', unsafe) + oset(fast_uuid, "int", int_value) + oset(fast_uuid, "is_safe", unsafe) app(fast_uuid) return column @@ -111,10 +146,10 @@ def to_numpy_array(column: Sequence): def pivot(data: Sequence[Sequence], start_row: int, end_row: int) -> Sequence[Sequence]: - return tuple(zip(*data[start_row: end_row])) + return tuple(zip(*data[start_row:end_row])) -def write_str_col(column: Sequence, nullable: bool, encoding: Optional[str], dest: bytearray) -> int: +def write_str_col(column: Sequence, nullable: bool, encoding: str | None, dest: bytearray) -> int: app = dest.append for x in column: if not x: @@ -128,7 +163,7 @@ def write_str_col(column: Sequence, nullable: bool, encoding: Optional[str], des x = bytes(x) sz = len(x) while True: - b = sz & 0x7f + b = sz & 0x7F sz >>= 7 if sz == 0: app(b) @@ -136,3 +171,34 @@ def write_str_col(column: Sequence, nullable: bool, encoding: Optional[str], des app(0x80 | b) dest += x return 0 + + +def write_native_col(code: str, column: Sequence, dest: bytearray, col_name: object = None) -> int: + """ + Pure Python fallback for write_native_col. + Delegates to write_array which uses struct.pack. + """ + write_array(code, column, dest, col_name) + return 0 + + +def build_map_columns(column: Sequence, dest: bytearray): + """ + Pure Python fallback for build_map_columns. + Flattens dicts into keys/values lists and writes UInt64 offsets into dest. + """ + offsets = array.array("Q") + total = 0 + for v in column: + total += len(v) + offsets.append(total) + if must_swap: + offsets.byteswap() + dest += offsets.tobytes() + keys = [] + values = [] + for v in column: + for k, val in v.items(): + keys.append(k) + values.append(val) + return keys, values diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/ddl.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/ddl.py index 65cd24ed78f..d6137d124b4 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/ddl.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/ddl.py @@ -1,13 +1,20 @@ -from typing import NamedTuple, Sequence +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, NamedTuple from clickhouse_connect.datatypes.base import ClickHouseType from clickhouse_connect.driver.options import check_arrow +if TYPE_CHECKING: + import pyarrow + class TableColumnDef(NamedTuple): """ Simplified ClickHouse Table Column definition for DDL """ + name: str ch_type: ClickHouseType expr_type: str = None @@ -15,9 +22,9 @@ class TableColumnDef(NamedTuple): @property def col_expr(self): - expr = f'{self.name} {self.ch_type.name}' + expr = f"{self.name} {self.ch_type.name}" if self.expr_type: - expr += f' {self.expr_type} {self.expr}' + expr += f" {self.expr_type} {self.expr}" return expr @@ -25,11 +32,11 @@ def create_table(table_name: str, columns: Sequence[TableColumnDef], engine: str stmt = f"CREATE TABLE {table_name} ({', '.join(col.col_expr for col in columns)}) ENGINE {engine} " if engine_params: for key, value in engine_params.items(): - stmt += f' {key} {value}' + stmt += f" {key} {value}" return stmt -def _arrow_type_to_ch(arrow_type: "pa.DataType") -> str: # pylint: disable=too-many-return-statements,too-many-branches +def _arrow_type_to_ch(arrow_type: pyarrow.DataType) -> str: """ Best-effort mapping from common PyArrow types to ClickHouse type names. @@ -42,64 +49,64 @@ def _arrow_type_to_ch(arrow_type: "pa.DataType") -> str: # pylint: disable=too-m # Signed ints if pat.is_int8(arrow_type): - return 'Int8' + return "Int8" if pat.is_int16(arrow_type): - return 'Int16' + return "Int16" if pat.is_int32(arrow_type): - return 'Int32' + return "Int32" if pat.is_int64(arrow_type): - return 'Int64' + return "Int64" # Unsigned ints if pat.is_uint8(arrow_type): - return 'UInt8' + return "UInt8" if pat.is_uint16(arrow_type): - return 'UInt16' + return "UInt16" if pat.is_uint32(arrow_type): - return 'UInt32' + return "UInt32" if pat.is_uint64(arrow_type): - return 'UInt64' + return "UInt64" # Floats if pat.is_float16(arrow_type) or pat.is_float32(arrow_type): - return 'Float32' + return "Float32" if pat.is_float64(arrow_type): - return 'Float64' + return "Float64" # Boolean if pat.is_boolean(arrow_type): - return 'Bool' + return "Bool" # Dates if pat.is_date32(arrow_type): - return 'Date32' + return "Date32" if pat.is_date64(arrow_type): - return 'DateTime64(3)' + return "DateTime64(3)" # Timestamps → DateTime / DateTime64 if pat.is_timestamp(arrow_type): - unit = getattr(arrow_type, 'unit', 's') - tz = getattr(arrow_type, 'tz', None) + unit = getattr(arrow_type, "unit", "s") + tz = getattr(arrow_type, "tz", None) - if unit == 's': - base = 'DateTime' + if unit == "s": + base = "DateTime" if tz: return f"DateTime('{tz}')" return base - scale_map = {'ms': 3, 'us': 6, 'ns': 9} + scale_map = {"ms": 3, "us": 6, "ns": 9} scale = scale_map.get(unit, 3) if tz: return f"DateTime64({scale}, '{tz}')" - return f'DateTime64({scale})' + return f"DateTime64({scale})" # Strings (this covers pa.string(), pa.large_string()) if pat.is_string(arrow_type) or pat.is_large_string(arrow_type): - return 'String' + return "String" # for any currently unsupported type, we raise so it’s clear that # this Arrow type isn’t supported by the helper yet. - raise TypeError(f'Unsupported Arrow type for automatic mapping: {arrow_type!r}') + raise TypeError(f"Unsupported Arrow type for automatic mapping: {arrow_type!r}") class _DDLType: @@ -110,14 +117,15 @@ class _DDLType: so we'll wrap the ClickHouse type name in this tiny object instead of constructing full ClickHouseType instances here. """ + def __init__(self, name: str): self.name = name -def arrow_schema_to_column_defs(schema: "pa.Schema") -> list[TableColumnDef]: +def arrow_schema_to_column_defs(schema: pyarrow.Schema) -> list[TableColumnDef]: """ Convert a PyArrow Schema into a list of TableColumnDef objects. - + This helper uses an *optimistic non-null* strategy: it always produces non-nullable ClickHouse types, even though Arrow fields are nullable by default. @@ -130,7 +138,7 @@ def arrow_schema_to_column_defs(schema: "pa.Schema") -> list[TableColumnDef]: pa = check_arrow() if not isinstance(schema, pa.Schema): - raise TypeError(f'Expected pyarrow.Schema, got {type(schema)!r}') + raise TypeError(f"Expected pyarrow.Schema, got {type(schema)!r}") col_defs: list[TableColumnDef] = [] for field in schema: @@ -146,7 +154,7 @@ def arrow_schema_to_column_defs(schema: "pa.Schema") -> list[TableColumnDef]: def create_table_from_arrow_schema( table_name: str, - schema: "pa.Schema", + schema: pyarrow.Schema, engine: str, engine_params: dict, ) -> str: diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/errors.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/errors.py index e5b3b6466ec..59313f4cdfc 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/errors.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/errors.py @@ -1,17 +1,16 @@ from clickhouse_connect.driver.context import BaseQueryContext from clickhouse_connect.driver.exceptions import DataError - # Error codes used in the Cython API NO_ERROR = 0 NONE_IN_NULLABLE_COLUMN = 1 -error_messages = {NONE_IN_NULLABLE_COLUMN: 'Invalid None value in non-Nullable column'} +error_messages = {NONE_IN_NULLABLE_COLUMN: "Invalid None value in non-Nullable column"} def handle_error(error_num: int, ctx: BaseQueryContext): if error_num > 0: msg = error_messages[error_num] if ctx.column_name: - msg = f'{msg}, column name: `{ctx.column_name}`' + msg = f"{msg}, column name: `{ctx.column_name}`" raise DataError(msg) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/exceptions.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/exceptions.py index 1cd41f99333..84009a319ab 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/exceptions.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/exceptions.py @@ -9,8 +9,7 @@ class ClickHouseError(Exception): """Exception related to operation with ClickHouse.""" -# pylint: disable=redefined-builtin -class Warning(Warning, ClickHouseError): +class Warning(Warning, ClickHouseError): # noqa: N818 """Exception raised for important warnings like data truncations while inserting, etc.""" @@ -73,12 +72,12 @@ class StreamClosedError(ProgrammingError): """Exception raised when a stream operation is executed on a closed stream.""" def __init__(self): - super().__init__('Executing a streaming operation on a closed stream') + super().__init__("Executing a streaming operation on a closed stream") -class StreamCompleteException(Exception): - """ Internal exception used to indicate the end of a ClickHouse query result stream.""" +class StreamCompleteException(Exception): # noqa: N818 + """Internal exception used to indicate the end of a ClickHouse query result stream.""" class StreamFailureError(Exception): - """ Stream failed unexpectedly """ + """Stream failed unexpectedly""" diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/external.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/external.py index 842ec7d59aa..81636593c5e 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/external.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/external.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Sequence, Dict, Union +from collections.abc import Sequence from pathlib import Path from clickhouse_connect.driver.exceptions import ProgrammingError @@ -8,119 +8,116 @@ logger = logging.getLogger(__name__) class ExternalFile: - # pylint: disable=too-many-branches - def __init__(self, - file_path: Optional[str] = None, - file_name: Optional[str] = None, - data: Optional[bytes] = None, - fmt: Optional[str] = None, - types: Optional[Union[str, Sequence[str]]] = None, - structure: Optional[Union[str, Sequence[str]]] = None, - mime_type: Optional[str] = None): + def __init__( + self, + file_path: str | None = None, + file_name: str | None = None, + data: bytes | None = None, + fmt: str | None = None, + types: str | Sequence[str] | None = None, + structure: str | Sequence[str] | None = None, + mime_type: str | None = None, + ): if file_path: if data: - raise ProgrammingError('Only data or file_path should be specified for external data, not both') + raise ProgrammingError("Only data or file_path should be specified for external data, not both") try: - with open(file_path, 'rb') as file: + with open(file_path, "rb") as file: self.data = file.read() except OSError as ex: - raise ProgrammingError(f'Failed to open file {file_path} for external data') from ex + raise ProgrammingError(f"Failed to open file {file_path} for external data") from ex path_name = Path(file_path).name - path_base = path_name.rsplit('.', maxsplit=1)[0] + path_base = path_name.rsplit(".", maxsplit=1)[0] if not file_name: self.name = path_base self.file_name = path_name else: - self.name = file_name.rsplit('.', maxsplit=1)[0] + self.name = file_name.rsplit(".", maxsplit=1)[0] self.file_name = file_name if file_name != path_name and path_base != self.name: - logger.warning('External data name %s and file_path %s use different names', file_name, path_name) + logger.warning("External data name %s and file_path %s use different names", file_name, path_name) elif data is not None: if not file_name: - raise ProgrammingError('Name is required for query external data') + raise ProgrammingError("Name is required for query external data") self.data = data - self.name = file_name.rsplit('.', maxsplit=1)[0] + self.name = file_name.rsplit(".", maxsplit=1)[0] self.file_name = file_name else: - raise ProgrammingError('Either data or file_path must be specified for external data') + raise ProgrammingError("Either data or file_path must be specified for external data") self.structure = None self.types = None if types: if structure: - raise ProgrammingError('Only types or structure should be specified for external data, not both') + raise ProgrammingError("Only types or structure should be specified for external data, not both") if isinstance(types, str): self.types = types else: - self.types = ','.join(types) + self.types = ",".join(types) elif structure: if isinstance(structure, str): self.structure = structure else: - self.structure = ','.join(structure) + self.structure = ",".join(structure) self.fmt = fmt - self.mime_type = mime_type or 'application/octet-stream' + self.mime_type = mime_type or "application/octet-stream" @property def form_data(self) -> tuple: return self.file_name, self.data, self.mime_type @property - def query_params(self) -> Dict[str, str]: + def query_params(self) -> dict[str, str]: params = {} - for name, value in (('format', self.fmt), - ('structure', self.structure), - ('types', self.types)): + for name, value in (("format", self.fmt), ("structure", self.structure), ("types", self.types)): if value: - params[f'{self.name}_{name}'] = value + params[f"{self.name}_{name}"] = value return params class ExternalData: - def __init__(self, - file_path: Optional[str] = None, - file_name: Optional[str] = None, - data: Optional[bytes] = None, - fmt: Optional[str] = None, - types: Optional[Union[str, Sequence[str]]] = None, - structure: Optional[Union[str, Sequence[str]]] = None, - mime_type: Optional[str] = None): + def __init__( + self, + file_path: str | None = None, + file_name: str | None = None, + data: bytes | None = None, + fmt: str | None = None, + types: str | Sequence[str] | None = None, + structure: str | Sequence[str] | None = None, + mime_type: str | None = None, + ): self.files: list[ExternalFile] = [] if file_path or data is not None: - first_file = ExternalFile(file_path=file_path, - file_name=file_name, - data=data, - fmt=fmt, - types=types, - structure=structure, - mime_type=mime_type) + first_file = ExternalFile( + file_path=file_path, file_name=file_name, data=data, fmt=fmt, types=types, structure=structure, mime_type=mime_type + ) self.files.append(first_file) - def add_file(self, - file_path: Optional[str] = None, - file_name: Optional[str] = None, - data: Optional[bytes] = None, - fmt: Optional[str] = None, - types: Optional[Union[str, Sequence[str]]] = None, - structure: Optional[Union[str, Sequence[str]]] = None, - mime_type: Optional[str] = None): - self.files.append(ExternalFile(file_path=file_path, - file_name=file_name, - data=data, - fmt=fmt, - types=types, - structure=structure, - mime_type=mime_type)) + def add_file( + self, + file_path: str | None = None, + file_name: str | None = None, + data: bytes | None = None, + fmt: str | None = None, + types: str | Sequence[str] | None = None, + structure: str | Sequence[str] | None = None, + mime_type: str | None = None, + ): + self.files.append( + ExternalFile( + file_path=file_path, file_name=file_name, data=data, fmt=fmt, types=types, structure=structure, mime_type=mime_type + ) + ) @property - def form_data(self) -> Dict[str, tuple]: + def form_data(self) -> dict[str, tuple]: if not self.files: - raise ProgrammingError('No external files set for external data') + raise ProgrammingError("No external files set for external data") return {file.name: file.form_data for file in self.files} @property - def query_params(self) -> Dict[str, str]: + def query_params(self) -> dict[str, str]: if not self.files: - raise ProgrammingError('No external files set for external data') + raise ProgrammingError("No external files set for external data") params = {} for file in self.files: params.update(file.query_params) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py index f456e3c9a12..576b3fd59ba 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/httpclient.py @@ -3,10 +3,11 @@ import json import logging import re import uuid +from base64 import b64encode +from collections.abc import Callable, Generator, Sequence from importlib import import_module from importlib.metadata import version as dist_version -from base64 import b64encode -from typing import Literal, Optional, Dict, Any, Sequence, Union, List, Callable, Generator, BinaryIO +from typing import Any, BinaryIO from urllib.parse import urlencode from urllib3 import Timeout @@ -17,129 +18,143 @@ from urllib3.response import HTTPResponse from clickhouse_connect import common from clickhouse_connect.datatypes import registry from clickhouse_connect.datatypes.base import ClickHouseType +from clickhouse_connect.driver.binding import bind_query, quote_identifier from clickhouse_connect.driver.client import Client -from clickhouse_connect.driver.common import dict_copy, coerce_bool, coerce_int, dict_add +from clickhouse_connect.driver.common import coerce_bool, coerce_int, dict_add, dict_copy from clickhouse_connect.driver.compression import available_compression from clickhouse_connect.driver.ctypes import RespBuffCls from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError, ProgrammingError from clickhouse_connect.driver.external import ExternalData -from clickhouse_connect.driver.httputil import ResponseSource, get_pool_manager, get_response_data, \ - default_pool_manager, get_proxy_manager, all_managers, check_env_proxy, check_conn_expiration +from clickhouse_connect.driver.httputil import ( + ResponseSource, + all_managers, + check_conn_expiration, + check_env_proxy, + default_pool_manager, + get_pool_manager, + get_proxy_manager, + get_response_data, +) from clickhouse_connect.driver.insert import InsertContext -from clickhouse_connect.driver.query import QueryResult, QueryContext, TzSource -from clickhouse_connect.driver.binding import quote_identifier, bind_query +from clickhouse_connect.driver.query import QueryContext, QueryResult, TzSource from clickhouse_connect.driver.summary import QuerySummary from clickhouse_connect.driver.transform import NativeTransform logger = logging.getLogger(__name__) -columns_only_re = re.compile(r'LIMIT 0\s*$', re.IGNORECASE) -ex_header = 'X-ClickHouse-Exception-Code' -ex_tag_header = 'X-ClickHouse-Exception-Tag' +columns_only_re = re.compile(r"LIMIT 0\s*$", re.IGNORECASE) +ex_header = "X-ClickHouse-Exception-Code" +ex_tag_header = "X-ClickHouse-Exception-Tag" -# pylint: disable=too-many-instance-attributes class HttpClient(Client): params = {} - valid_transport_settings = {'database', 'buffer_size', 'session_id', - 'compress', 'decompress', 'session_timeout', - 'session_check', 'query_id', 'quota_key', - 'wait_end_of_query', 'client_protocol_version', - 'role'} - optional_transport_settings = {'send_progress_in_http_headers', - 'http_headers_progress_interval_ms', - 'enable_http_compression'} + valid_transport_settings = { + "database", + "buffer_size", + "session_id", + "compress", + "decompress", + "session_timeout", + "session_check", + "query_id", + "quota_key", + "wait_end_of_query", + "client_protocol_version", + "role", + } + optional_transport_settings = {"send_progress_in_http_headers", "http_headers_progress_interval_ms", "enable_http_compression"} _owns_pool_manager = False # R0917: too-many-positional-arguments - # pylint: disable=too-many-arguments,R0917,too-many-locals,too-many-branches,too-many-statements,unused-argument - def __init__(self, - interface: str, - host: str, - port: int, - username: str, - password: str, - database: str, - access_token: Optional[str] = None, - compress: Union[bool, str] = True, - query_limit: int = 0, - query_retries: int = 2, - connect_timeout: int = 10, - send_receive_timeout: int = 300, - client_name: Optional[str] = None, - verify: Union[bool, str] = True, - ca_cert: Optional[str] = None, - client_cert: Optional[str] = None, - client_cert_key: Optional[str] = None, - session_id: Optional[str] = None, - settings: Optional[Dict[str, Any]] = None, - pool_mgr: Optional[PoolManager] = None, - http_proxy: Optional[str] = None, - https_proxy: Optional[str] = None, - server_host_name: Optional[str] = None, - tz_source: Optional[TzSource] = None, - tz_mode: Optional[str] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - apply_server_timezone: Optional[Union[str, bool]] = None, - show_clickhouse_errors: Optional[bool] = None, - autogenerate_session_id: Optional[bool] = None, - autogenerate_query_id: Optional[bool] = None, - tls_mode: Optional[str] = None, - proxy_path: str = '', - form_encode_query_params: bool = False, - rename_response_column: Optional[str] = None): + + def __init__( + self, + interface: str, + host: str, + port: int, + username: str, + password: str, + database: str, + access_token: str | None = None, + compress: bool | str = True, + query_limit: int = 0, + query_retries: int = 2, + connect_timeout: int = 10, + send_receive_timeout: int = 300, + client_name: str | None = None, + verify: bool | str = True, + ca_cert: str | None = None, + client_cert: str | None = None, + client_cert_key: str | None = None, + session_id: str | None = None, + settings: dict[str, Any] | None = None, + pool_mgr: PoolManager | None = None, + http_proxy: str | None = None, + https_proxy: str | None = None, + server_host_name: str | None = None, + tz_source: TzSource | None = None, + tz_mode: str | None = None, + show_clickhouse_errors: bool | None = None, + autogenerate_session_id: bool | None = None, + autogenerate_query_id: bool | None = None, + tls_mode: str | None = None, + proxy_path: str = "", + form_encode_query_params: bool = False, + rename_response_column: str | None = None, + ): """ Create an HTTP ClickHouse Connect client See clickhouse_connect.get_client for parameters """ - proxy_path = proxy_path.lstrip('/') + proxy_path = proxy_path.lstrip("/") if proxy_path: - proxy_path = '/' + proxy_path - self.url = f'{interface}://{host}:{port}{proxy_path}' + proxy_path = "/" + proxy_path + self.url = f"{interface}://{host}:{port}{proxy_path}" self.headers = {} self.form_encode_query_params = form_encode_query_params self.params = dict_copy(HttpClient.params) ch_settings = dict_copy(settings, self.params) self.http = pool_mgr - if interface == 'https': - if isinstance(verify, str) and verify.lower() == 'proxy': + if interface == "https": + if isinstance(verify, str) and verify.lower() == "proxy": verify = True - tls_mode = tls_mode or 'proxy' + tls_mode = tls_mode or "proxy" if not https_proxy: - https_proxy = check_env_proxy('https', host, port) + https_proxy = check_env_proxy("https", host, port) verify = coerce_bool(verify) - if client_cert and (tls_mode is None or tls_mode == 'mutual'): + if client_cert and (tls_mode is None or tls_mode == "mutual"): if not username: - raise ProgrammingError('username parameter is required for Mutual TLS authentication') - self.headers['X-ClickHouse-User'] = username - self.headers['X-ClickHouse-SSL-Certificate-Auth'] = 'on' - # pylint: disable=too-many-boolean-expressions + raise ProgrammingError("username parameter is required for Mutual TLS authentication") + self.headers["X-ClickHouse-User"] = username + self.headers["X-ClickHouse-SSL-Certificate-Auth"] = "on" + if not self.http and (server_host_name or ca_cert or client_cert or not verify or https_proxy): - options = {'verify': verify} - dict_add(options, 'ca_cert', ca_cert) - dict_add(options, 'client_cert', client_cert) - dict_add(options, 'client_cert_key', client_cert_key) + options = {"verify": verify} + dict_add(options, "ca_cert", ca_cert) + dict_add(options, "client_cert", client_cert) + dict_add(options, "client_cert_key", client_cert_key) if server_host_name: - if options['verify']: - options['assert_hostname'] = server_host_name - options['server_hostname'] = server_host_name + if options["verify"]: + options["assert_hostname"] = server_host_name + options["server_hostname"] = server_host_name self.http = get_pool_manager(https_proxy=https_proxy, **options) self._owns_pool_manager = True if not self.http: if not http_proxy: - http_proxy = check_env_proxy('http', host, port) + http_proxy = check_env_proxy("http", host, port) if http_proxy: self.http = get_proxy_manager(host, http_proxy) else: self.http = default_pool_manager() if access_token: - self.headers['Authorization'] = f'Bearer {access_token}' - elif (not client_cert or tls_mode in ('strict', 'proxy')) and username: - self.headers['Authorization'] = 'Basic ' + b64encode(f'{username}:{password}'.encode()).decode() + self.headers["Authorization"] = f"Bearer {access_token}" + elif (not client_cert or tls_mode in ("strict", "proxy")) and username: + self.headers["Authorization"] = "Basic " + b64encode(f"{username}:{password}".encode()).decode() self._reported_libs = set() - self.headers['User-Agent'] = common.build_client_name(client_name) - self._read_format = self._write_format = 'Native' + self.headers["User-Agent"] = common.build_client_name(client_name) + self._read_format = self._write_format = "Native" self._transform = NativeTransform() # There are use cases when the client needs to disable timeouts. @@ -156,75 +171,78 @@ class HttpClient(Client): self._rename_response_column = rename_response_column # allow to override the global autogenerate_session_id setting via the constructor params - _autogenerate_session_id = common.get_setting('autogenerate_session_id') \ - if autogenerate_session_id is None \ - else autogenerate_session_id + _autogenerate_session_id = ( + common.get_setting("autogenerate_session_id") if autogenerate_session_id is None else autogenerate_session_id + ) if session_id: - ch_settings['session_id'] = session_id - elif 'session_id' not in ch_settings and _autogenerate_session_id: - ch_settings['session_id'] = str(uuid.uuid4()) + ch_settings["session_id"] = session_id + elif "session_id" not in ch_settings and _autogenerate_session_id: + ch_settings["session_id"] = str(uuid.uuid4()) # allow to override the global autogenerate_query_id setting via the constructor params - self._autogenerate_query_id = common.get_setting('autogenerate_query_id') \ - if autogenerate_query_id is None \ - else autogenerate_query_id + self._autogenerate_query_id = ( + common.get_setting("autogenerate_query_id") if autogenerate_query_id is None else autogenerate_query_id + ) if coerce_bool(compress): - compression = ','.join(available_compression) + compression = ",".join(available_compression) self.write_compression = available_compression[0] - elif compress and compress not in ('False', 'false', '0'): + elif compress and compress not in ("False", "false", "0"): if compress not in available_compression: - raise ProgrammingError(f'Unsupported compression method {compress}') + raise ProgrammingError(f"Unsupported compression method {compress}") compression = compress self.write_compression = compress else: compression = None - super().__init__(database=database, - uri=self.url, - query_limit=query_limit, - query_retries=query_retries, - server_host_name=server_host_name, - tz_source=tz_source, - tz_mode=tz_mode, - utc_tz_aware=utc_tz_aware, - apply_server_timezone=apply_server_timezone, - show_clickhouse_errors=show_clickhouse_errors) + super().__init__( + database=database, + uri=self.url, + query_limit=query_limit, + query_retries=query_retries, + server_host_name=server_host_name, + tz_source=tz_source, + tz_mode=tz_mode, + show_clickhouse_errors=show_clickhouse_errors, + autoconnect=True, + ) self.params = dict_copy(self.params, self._validate_settings(ch_settings)) cancel_setting = self._setting_status("cancel_http_readonly_queries_on_client_close") - if cancel_setting.is_writable and not cancel_setting.is_set and \ - "cancel_http_readonly_queries_on_client_close" not in (settings or {}): + if ( + cancel_setting.is_writable + and not cancel_setting.is_set + and "cancel_http_readonly_queries_on_client_close" not in (settings or {}) + ): self.params["cancel_http_readonly_queries_on_client_close"] = "1" - comp_setting = self._setting_status('enable_http_compression') + comp_setting = self._setting_status("enable_http_compression") self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable if comp_setting.is_set or comp_setting.is_writable: self.compression = compression - send_setting = self._setting_status('send_progress_in_http_headers') + send_setting = self._setting_status("send_progress_in_http_headers") self._send_progress = not send_setting.is_set and send_setting.is_writable - if (send_setting.is_set or send_setting.is_writable) and \ - self._setting_status('http_headers_progress_interval_ms').is_writable: + if (send_setting.is_set or send_setting.is_writable) and self._setting_status("http_headers_progress_interval_ms").is_writable: self._progress_interval = str(min(120000, max(10000, (send_receive_timeout - 5) * 1000))) def set_client_setting(self, key: str, value: Any) -> None: - str_value = self._validate_setting(key, value, common.get_setting('invalid_setting_action')) + str_value = self._validate_setting(key, value, common.get_setting("invalid_setting_action")) if str_value is not None: self.params[key] = str_value - def get_client_setting(self, key: str) -> Optional[str]: + def get_client_setting(self, key: str) -> str | None: return self.params.get(key) def set_access_token(self, access_token: str) -> None: - auth_header = self.headers.get('Authorization') - if auth_header and not auth_header.startswith('Bearer'): - raise ProgrammingError('Cannot set access token when a different auth type is used') - self.headers['Authorization'] = f'Bearer {access_token}' + auth_header = self.headers.get("Authorization") + if auth_header and not auth_header.startswith("Bearer"): + raise ProgrammingError("Cannot set access token when a different auth type is used") + self.headers["Authorization"] = f"Bearer {access_token}" def _prep_query(self, context: QueryContext): final_query = super()._prep_query(context) if context.is_insert: return final_query - fmt = f'\n FORMAT {self._read_format}' + fmt = f"\n FORMAT {self._read_format}" if isinstance(final_query, bytes): return final_query + fmt.encode() return final_query + fmt @@ -233,84 +251,85 @@ class HttpClient(Client): headers = {} params = {} if self.database: - params['database'] = self.database + params["database"] = self.database if self.protocol_version: - params['client_protocol_version'] = self.protocol_version + params["client_protocol_version"] = self.protocol_version context.block_info = True params.update(self._validate_settings(context.settings)) context.rename_response_column = self._rename_response_column if not context.is_insert and columns_only_re.search(context.uncommented_query): # Mirror normal query behavior for form encoding and external data - fmt_json_query = f'{context.final_query}\n FORMAT JSON' + fmt_json_query = f"{context.final_query}\n FORMAT JSON" if self.form_encode_query_params: - fields = {'query': fmt_json_query} + fields = {"query": fmt_json_query} fields.update(context.bind_params) if context.external_data: # Deal with form encoding + external data params.update(context.external_data.query_params) fields.update(context.external_data.form_data) - response = self._raw_request(bytes(), params, headers, retries=self.query_retries, fields=fields) + response = self._raw_request(b"", params, headers, retries=self.query_retries, fields=fields) elif context.external_data: # Deal with external data without form encoding fields = context.external_data.form_data params.update(context.bind_params) params.update(context.external_data.query_params) - params['query'] = fmt_json_query - response = self._raw_request(bytes(), params, headers, retries=self.query_retries, fields=fields) + params["query"] = fmt_json_query + response = self._raw_request(b"", params, headers, retries=self.query_retries, fields=fields) else: # Legacy behavior (plain body, bind params in URL) params.update(context.bind_params) - response = self._raw_request(fmt_json_query, - params, headers, retries=self.query_retries) + response = self._raw_request(fmt_json_query, params, headers, retries=self.query_retries) json_result = json.loads(response.data) # ClickHouse will respond with a JSON object of meta, data, and some other objects # We just grab the column names and column types from the metadata sub object - names: List[str] = [] - types: List[ClickHouseType] = [] + names: list[str] = [] + types: list[ClickHouseType] = [] renamer = context.column_renamer - for col in json_result['meta']: - name = col['name'] + for col in json_result["meta"]: + name = col["name"] if renamer is not None: try: name = renamer(name) - except Exception as e: # pylint: disable=broad-exception-caught + except Exception as e: logger.debug("Failed to rename col '%s'. Skipping rename. Error: %s", name, e) names.append(name) - types.append(registry.get_from_name(col['type'])) + types.append(registry.get_from_name(col["type"])) return QueryResult([], None, tuple(names), tuple(types)) if self.compression: - headers['Accept-Encoding'] = self.compression + headers["Accept-Encoding"] = self.compression if self._send_comp_setting: - params['enable_http_compression'] = '1' + params["enable_http_compression"] = "1" final_query = self._prep_query(context) fields = {} # Setup additional query parameters and body if self.form_encode_query_params: - body = bytes() - fields['query'] = final_query + body = b"" + fields["query"] = final_query fields.update(context.bind_params) if context.external_data: params.update(context.external_data.query_params) fields.update(context.external_data.form_data) elif context.external_data: params.update(context.bind_params) - body = bytes() - params['query'] = final_query + body = b"" + params["query"] = final_query params.update(context.external_data.query_params) fields = context.external_data.form_data else: params.update(context.bind_params) body = final_query fields = None - headers['Content-Type'] = 'text/plain; charset=utf-8' - response = self._raw_request(body, - params, - dict_copy(headers, context.transport_settings), - stream=True, - retries=self.query_retries, - fields=fields, - server_wait=not context.streaming) + headers["Content-Type"] = "text/plain; charset=utf-8" + response = self._raw_request( + body, + params, + dict_copy(headers, context.transport_settings), + stream=True, + retries=self.query_retries, + fields=fields, + server_wait=not context.streaming, + ) exception_tag = response.headers.get(ex_tag_header) - byte_source = RespBuffCls(ResponseSource(response, exception_tag=exception_tag)) # pylint: disable=not-callable - context.set_response_tz(self._check_tz_change(response.headers.get('X-ClickHouse-Timezone'))) + byte_source = RespBuffCls(ResponseSource(response, exception_tag=exception_tag)) + context.set_response_tz(self._check_tz_change(response.headers.get("X-ClickHouse-Timezone"))) query_result = self._transform.parse_response(byte_source, context) query_result.summary = self._summary(response) return query_result @@ -320,7 +339,7 @@ class HttpClient(Client): See BaseClient doc_string for this method """ if context.empty: - logger.debug('No data included in insert, skipping') + logger.debug("No data included in insert, skipping") return QuerySummary() def error_handler(resp: HTTPResponse): @@ -331,75 +350,92 @@ class HttpClient(Client): raise ex self._error_handler(resp) - headers = {'Content-Type': 'application/octet-stream'} + headers = {"Content-Type": "application/octet-stream"} if context.compression is None: context.compression = self.write_compression if context.compression: - headers['Content-Encoding'] = context.compression + headers["Content-Encoding"] = context.compression block_gen = self._transform.build_insert(context) + def rebuild_block_gen(): + context.current_row = 0 + context.current_block = 0 + return self._transform.build_insert(context) + params = {} if self.database: - params['database'] = self.database + params["database"] = self.database params.update(self._validate_settings(context.settings)) headers = dict_copy(headers, context.transport_settings) try: - response = self._raw_request(block_gen, params, headers, error_handler=error_handler, server_wait=False) - logger.debug('Context insert response code: %d, content: %s', response.status, response.data) + response = self._raw_request( + block_gen, + params, + headers, + error_handler=error_handler, + server_wait=False, + retry_body=rebuild_block_gen, + ) + logger.debug("Context insert response code: %d, content: %s", response.status, response.data) return QuerySummary(self._summary(response)) finally: context.data = None - def raw_insert(self, table: str = None, - column_names: Optional[Sequence[str]] = None, - insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None, - settings: Optional[Dict] = None, - fmt: Optional[str] = None, - compression: Optional[str] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + def raw_insert( + self, + table: str = None, + column_names: Sequence[str] | None = None, + insert_block: str | bytes | Generator[bytes, None, None] | BinaryIO = None, + settings: dict | None = None, + fmt: str | None = None, + compression: str | None = None, + transport_settings: dict[str, str] | None = None, + ) -> QuerySummary: """ See BaseClient doc_string for this method """ params = {} - headers = {'Content-Type': 'application/octet-stream'} + headers = {"Content-Type": "application/octet-stream"} if compression: - headers['Content-Encoding'] = compression + headers["Content-Encoding"] = compression if table: - cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else '' - query = f'INSERT INTO {table}{cols} FORMAT {fmt if fmt else self._write_format}' + cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else "" + query = f"INSERT INTO {table}{cols} FORMAT {fmt if fmt else self._write_format}" if not compression and isinstance(insert_block, str): - insert_block = query + '\n' + insert_block + insert_block = query + "\n" + insert_block elif not compression and isinstance(insert_block, (bytes, bytearray, BinaryIO)): - insert_block = (query + '\n').encode() + insert_block + insert_block = (query + "\n").encode() + insert_block else: - params['query'] = query + params["query"] = query if self.database: - params['database'] = self.database + params["database"] = self.database params.update(self._validate_settings(settings or {})) headers = dict_copy(headers, transport_settings) response = self._raw_request(insert_block, params, headers, server_wait=False) - logger.debug('Raw insert response code: %d, content: %s', response.status, response.data) + logger.debug("Raw insert response code: %d, content: %s", response.status, response.data) return QuerySummary(self._summary(response)) @staticmethod def _summary(response: HTTPResponse): summary = {} - if 'X-ClickHouse-Summary' in response.headers: + if "X-ClickHouse-Summary" in response.headers: try: - summary = json.loads(response.headers['X-ClickHouse-Summary']) + summary = json.loads(response.headers["X-ClickHouse-Summary"]) except json.JSONDecodeError: pass - summary['query_id'] = response.headers.get('X-ClickHouse-Query-Id', '') + summary["query_id"] = response.headers.get("X-ClickHouse-Query-Id", "") return summary - def command(self, - cmd: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - data: Union[str, bytes] = None, - settings: Optional[Dict] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> Union[str, int, Sequence[str], QuerySummary]: + def command( + self, + cmd: str, + parameters: Sequence | dict[str, Any] | None = None, + data: str | bytes = None, + settings: dict | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> str | int | Sequence[str] | QuerySummary: """ See BaseClient doc_string for this method """ @@ -409,30 +445,30 @@ class HttpClient(Client): fields = None if external_data: if data: - raise ProgrammingError('Cannot combine command data with external data') from None + raise ProgrammingError("Cannot combine command data with external data") from None fields = external_data.form_data params.update(external_data.query_params) elif isinstance(data, str): - headers['Content-Type'] = 'text/plain; charset=utf-8' + headers["Content-Type"] = "text/plain; charset=utf-8" payload = data.encode() elif isinstance(data, bytes): - headers['Content-Type'] = 'application/octet-stream' + headers["Content-Type"] = "application/octet-stream" payload = data if payload is None and not cmd: - raise ProgrammingError('Command sent without query or recognized data') from None + raise ProgrammingError("Command sent without query or recognized data") from None if payload or fields: - params['query'] = cmd + params["query"] = cmd else: payload = cmd if use_database and self.database: - params['database'] = self.database + params["database"] = self.database params.update(self._validate_settings(settings or {})) headers = dict_copy(headers, transport_settings) - method = 'POST' if payload or fields else 'GET' + method = "POST" if payload or fields else "GET" response = self._raw_request(payload, params, headers, method, fields=fields, server_wait=False) if response.data: try: - result = response.data.decode()[:-1].split('\t') + result = response.data.decode()[:-1].split("\t") if len(result) == 1: try: return int(result[0]) @@ -449,113 +485,105 @@ class HttpClient(Client): """ try: body = "" - # Always try to read the response body for context. try: - # get_response_data reads body and decodes it for the error message raw_body = get_response_data(response) - body = common.format_error( - raw_body.decode(errors="backslashreplace") - ).strip() - except Exception: # pylint: disable=broad-except - # If we can't read or decode the body, we'll proceed without it + body = common.format_error(raw_body.decode(errors="backslashreplace")).strip() + except Exception: logger.warning("Failed to read error response body", exc_info=True) - # Build the error message if self.show_clickhouse_errors: err_code = response.headers.get(ex_header) if err_code: - # Prioritize the specific ClickHouse exception code if it exists. err_str = f"Received ClickHouse exception, code: {err_code}" else: - # Otherwise, just use the generic HTTP status err_str = f"HTTP driver received HTTP status {response.status}" if body: - # Always append the body if it exists err_str = f"{err_str}, server response: {body}" else: - # Simple message for when detailed errors are disabled err_str = "The ClickHouse server returned an error" - # Add the URL for additional context err_str = f"{err_str} (for url {self.url})" finally: - # Ensure closed response to prevent resource leaks response.close() - # Raise the appropriate exception class raise OperationalError(err_str) if retried else DatabaseError(err_str) from None - def _raw_request(self, - data, - params: Dict[str, str], - headers: Optional[Dict[str, Any]] = None, - method: str = 'POST', - retries: int = 0, - stream: bool = False, - server_wait: bool = True, - fields: Optional[Dict[str, tuple]] = None, - error_handler: Callable = None) -> HTTPResponse: + def _raw_request( + self, + data, + params: dict[str, str], + headers: dict[str, Any] | None = None, + method: str = "POST", + retries: int = 0, + stream: bool = False, + server_wait: bool = True, + fields: dict[str, tuple] | None = None, + error_handler: Callable = None, + retry_body: Callable[[], Any] | None = None, + ) -> HTTPResponse: if isinstance(data, str): data = data.encode() headers = dict_copy(self.headers, headers) attempts = 0 final_params = {} if server_wait: - final_params['wait_end_of_query'] = '1' + final_params["wait_end_of_query"] = "1" # We can't actually read the progress headers, but we enable them so ClickHouse sends something # to keep the connection alive when waiting for long-running queries and (2) to get summary information # if not streaming if self._send_progress: - final_params['send_progress_in_http_headers'] = '1' + final_params["send_progress_in_http_headers"] = "1" if self._progress_interval: - final_params['http_headers_progress_interval_ms'] = self._progress_interval + final_params["http_headers_progress_interval_ms"] = self._progress_interval final_params = dict_copy(self.params, final_params) final_params = dict_copy(final_params, params) if self._autogenerate_query_id and "query_id" not in final_params: final_params["query_id"] = str(uuid.uuid4()) - url = f'{self.url}?{urlencode(final_params)}' - kwargs = { - 'headers': headers, - 'timeout': self.timeout, - 'retries': self.http_retries, - 'preload_content': not stream - } + url = f"{self.url}?{urlencode(final_params)}" + kwargs = {"headers": headers, "timeout": self.timeout, "retries": self.http_retries, "preload_content": not stream} if self.server_host_name: - kwargs['assert_same_host'] = False - kwargs['headers'].update({'Host': self.server_host_name}) + kwargs["assert_same_host"] = False + kwargs["headers"].update({"Host": self.server_host_name}) if fields: - kwargs['fields'] = fields + kwargs["fields"] = fields else: - kwargs['body'] = data + kwargs["body"] = data check_conn_expiration(self.http) - query_session = final_params.get('session_id') + query_session = final_params.get("session_id") while True: attempts += 1 if query_session: if query_session == self._active_session: - raise ProgrammingError('Attempt to execute concurrent queries within the same session.' + - 'Please use a separate client instance per thread/process.') + raise ProgrammingError( + "Attempt to execute concurrent queries within the same session. " + + "Please use a separate client instance per thread/process." + ) # There is a race condition here when using multiprocessing -- in that case the server will # throw an error instead, but in most cases this more helpful error will be thrown first self._active_session = query_session try: response = self.http.request(method, url, **kwargs) except HTTPError as ex: - if isinstance(ex.__context__, ConnectionResetError): + if isinstance(ex.__context__, ConnectionResetError) and attempts == 1: # The server closed the connection, probably because the Keep Alive has expired # We should be safe to retry, as ClickHouse should not have processed anything on a connection # that it killed. We also only retry this once, as multiple disconnects are unlikely to be # related to the Keep Alive settings - if attempts == 1: - logger.debug('Retrying remotely closed connection') + body = kwargs.get("body") + if retry_body is not None: + kwargs["body"] = retry_body() + logger.debug("Retrying remotely closed connection with rebuilt body") + continue + if body is None or isinstance(body, (bytes, bytearray, str)): + logger.debug("Retrying remotely closed connection") continue - logger.warning('Unexpected Http Driver Exception') - err_url = f' ({self.url})' if self.show_clickhouse_errors else '' - raise OperationalError(f'Error {ex} executing HTTP request attempt {attempts}{err_url}') from ex + logger.warning("Unexpected Http Driver Exception") + err_url = f" ({self.url})" if self.show_clickhouse_errors else "" + raise OperationalError(f"Error {ex} executing HTTP request attempt {attempts}{err_url}") from ex finally: if query_session: self._active_session = None # Make sure we always clear this @@ -564,67 +592,75 @@ class HttpClient(Client): if response.status in (429, 503, 504): if attempts > retries: self._error_handler(response, True) - logger.debug('Retrying requests with status code %d', response.status) + logger.debug("Retrying requests with status code %d", response.status) elif error_handler: error_handler(response) else: self._error_handler(response) - def raw_query(self, query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: str = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> bytes: + def raw_query( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> bytes: """ See BaseClient doc_string for this method """ body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) return self._raw_request(body, params, fields=fields, headers=transport_settings).data - def raw_stream(self, query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: str = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase: + def raw_stream( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> io.IOBase: """ See BaseClient doc_string for this method """ body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) - return self._raw_request(body, params, fields=fields, stream=True, server_wait=False, - headers=transport_settings) + return self._raw_request(body, params, fields=fields, stream=True, server_wait=False, headers=transport_settings) - def _prep_raw_query(self, query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]], - settings: Optional[Dict[str, Any]], - fmt: str, - use_database: bool, - external_data: Optional[ExternalData]): + def _prep_raw_query( + self, + query: str, + parameters: Sequence | dict[str, Any] | None, + settings: dict[str, Any] | None, + fmt: str, + use_database: bool, + external_data: ExternalData | None, + ): if fmt: - query += f'\n FORMAT {fmt}' + query += f"\n FORMAT {fmt}" final_query, bind_params = bind_query(query, parameters, self.server_tz) params = self._validate_settings(settings or {}) if use_database and self.database: - params['database'] = self.database + params["database"] = self.database fields = {} # Setup query body if external_data and not self.form_encode_query_params and isinstance(final_query, bytes): raise ProgrammingError("Binary query cannot be placed in URL when using External Data; enable form encoding.") # Setup additional query parameters and body if self.form_encode_query_params: - body = bytes() - fields['query'] = final_query + body = b"" + fields["query"] = final_query fields.update(bind_params) if external_data: params.update(external_data.query_params) fields.update(external_data.form_data) elif external_data: params.update(bind_params) - body = bytes() - params['query'] = final_query + body = b"" + params["query"] = final_query params.update(external_data.query_params) fields = external_data.form_data else: @@ -633,7 +669,6 @@ class HttpClient(Client): fields = None return body, params, fields - # pylint: disable=broad-exception-caught def _add_integration_tag(self, name: str): """ Dynamically adds a product (like pandas or sqlalchemy) to the User-Agent string details section. @@ -683,10 +718,10 @@ class HttpClient(Client): See BaseClient doc_string for this method """ try: - response = self.http.request('GET', f'{self.url}/ping', timeout=3, preload_content=True) + response = self.http.request("GET", f"{self.url}/ping", timeout=3, preload_content=True) return 200 <= response.status < 300 except HTTPError: - logger.debug('ping failed', exc_info=True) + logger.debug("ping failed", exc_info=True) return False def close_connections(self) -> None: diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/httputil.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/httputil.py index 961f9c02304..ac0e28ad3dd 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/httputil.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/httputil.py @@ -3,11 +3,12 @@ import http.client import logging import multiprocessing import os -import sys import socket +import sys import time from collections import deque -from typing import Dict, Any, Optional, Tuple, Callable +from collections.abc import Callable +from typing import Any import certifi import lz4.frame @@ -16,8 +17,8 @@ import zstandard from urllib3.poolmanager import PoolManager, ProxyManager from urllib3.response import HTTPResponse -from clickhouse_connect.driver.exceptions import ProgrammingError, OperationalError from clickhouse_connect import common +from clickhouse_connect.driver.exceptions import OperationalError, ProgrammingError logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # Increase this number just to be safe when ClickHouse is returning progress headers -http.client._MAXHEADERS = 10000 # pylint: disable=protected-access +http.client._MAXHEADERS = 10000 DEFAULT_KEEP_INTERVAL = 30 DEFAULT_KEEP_COUNT = 3 @@ -37,10 +38,10 @@ core_socket_options = [ (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), (SOCKET_TCP, socket.TCP_NODELAY, 1), (socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 256), - (socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 256) + (socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 256), ] -logging.getLogger('urllib3').setLevel(logging.WARNING) +logging.getLogger("urllib3").setLevel(logging.WARNING) _proxy_managers = {} all_managers = {} @@ -51,67 +52,72 @@ def close_managers(): manager.clear() -# pylint: disable=no-member,too-many-arguments,too-many-branches -def get_pool_manager_options(keep_interval: int = DEFAULT_KEEP_INTERVAL, - keep_count: int = DEFAULT_KEEP_COUNT, - keep_idle: int = DEFAULT_KEEP_IDLE, - ca_cert: Optional[str] = None, - verify: bool = True, - client_cert: Optional[str] = None, - client_cert_key: Optional[str] = None, - **options) -> Dict[str, Any]: +def get_pool_manager_options( + keep_interval: int = DEFAULT_KEEP_INTERVAL, + keep_count: int = DEFAULT_KEEP_COUNT, + keep_idle: int = DEFAULT_KEEP_IDLE, + ca_cert: str | None = None, + verify: bool = True, + client_cert: str | None = None, + client_cert_key: str | None = None, + **options, +) -> dict[str, Any]: socket_options = core_socket_options.copy() - if getattr(socket, 'TCP_KEEPINTVL', None) is not None: + if getattr(socket, "TCP_KEEPINTVL", None) is not None: socket_options.append((SOCKET_TCP, socket.TCP_KEEPINTVL, keep_interval)) - if getattr(socket, 'TCP_KEEPCNT', None) is not None: + if getattr(socket, "TCP_KEEPCNT", None) is not None: socket_options.append((SOCKET_TCP, socket.TCP_KEEPCNT, keep_count)) - if getattr(socket, 'TCP_KEEPIDLE', None) is not None: + if getattr(socket, "TCP_KEEPIDLE", None) is not None: socket_options.append((SOCKET_TCP, socket.TCP_KEEPIDLE, keep_idle)) - if sys.platform == 'darwin': - socket_options.append((SOCKET_TCP, getattr(socket, 'TCP_KEEPALIVE', 0x10), keep_interval)) - options['maxsize'] = options.get('maxsize', 8) - options['retries'] = options.get('retries', 1) - if ca_cert == 'certifi': + if sys.platform == "darwin": + socket_options.append((SOCKET_TCP, getattr(socket, "TCP_KEEPALIVE", 0x10), keep_interval)) + options["maxsize"] = options.get("maxsize", 8) + options["retries"] = options.get("retries", 1) + if ca_cert == "certifi": ca_cert = certifi.where() - options['cert_reqs'] = 'CERT_REQUIRED' if verify else 'CERT_NONE' + options["cert_reqs"] = "CERT_REQUIRED" if verify else "CERT_NONE" if ca_cert: - options['ca_certs'] = ca_cert + options["ca_certs"] = ca_cert if client_cert: - options['cert_file'] = client_cert + options["cert_file"] = client_cert if client_cert_key: - options['key_file'] = client_cert_key - options['socket_options'] = socket_options - options['block'] = options.get('block', False) + options["key_file"] = client_cert_key + options["socket_options"] = socket_options + options["block"] = options.get("block", False) return options -def get_pool_manager(keep_interval: int = DEFAULT_KEEP_INTERVAL, - keep_count: int = DEFAULT_KEEP_COUNT, - keep_idle: int = DEFAULT_KEEP_IDLE, - ca_cert: Optional[str] = None, - verify: bool = True, - client_cert: Optional[str] = None, - client_cert_key: Optional[str] = None, - http_proxy: Optional[str] = None, - https_proxy: Optional[str] = None, - **options): - options = get_pool_manager_options(keep_interval, - keep_count, - keep_idle, - ca_cert, - verify, - client_cert, - client_cert_key, - **options) +def get_pool_manager( + keep_interval: int = DEFAULT_KEEP_INTERVAL, + keep_count: int = DEFAULT_KEEP_COUNT, + keep_idle: int = DEFAULT_KEEP_IDLE, + ca_cert: str | None = None, + verify: bool = True, + client_cert: str | None = None, + client_cert_key: str | None = None, + http_proxy: str | None = None, + https_proxy: str | None = None, + **options, +): + options = get_pool_manager_options( + keep_interval, + keep_count, + keep_idle, + ca_cert, + verify, + client_cert, + client_cert_key, + **options, + ) if http_proxy: if https_proxy: - raise ProgrammingError('Only one of http_proxy or https_proxy should be specified') - if not http_proxy.startswith('http'): - http_proxy = f'http://{http_proxy}' + raise ProgrammingError("Only one of http_proxy or https_proxy should be specified") + if not http_proxy.startswith("http"): + http_proxy = f"http://{http_proxy}" manager = ProxyManager(http_proxy, **options) elif https_proxy: - if not https_proxy.startswith('http'): - https_proxy = f'https://{https_proxy}' + if not https_proxy.startswith("http"): + https_proxy = f"https://{https_proxy}" manager = ProxyManager(https_proxy, **options) else: manager = PoolManager(**options) @@ -120,18 +126,18 @@ def get_pool_manager(keep_interval: int = DEFAULT_KEEP_INTERVAL, def check_conn_expiration(manager: PoolManager): - reset_seconds = common.get_setting('max_connection_age') + reset_seconds = common.get_setting("max_connection_age") if reset_seconds: last_reset = all_managers.get(manager, 0) now = int(time.time()) if last_reset < now - reset_seconds: - logger.debug('connection expiration') + logger.debug("connection expiration") manager.clear() all_managers[manager] = now def get_proxy_manager(host: str, http_proxy): - key = f'{host}__{http_proxy}' + key = f"{host}__{http_proxy}" if key in _proxy_managers: return _proxy_managers[key] proxy_manager = get_pool_manager(http_proxy=http_proxy) @@ -140,41 +146,41 @@ def get_proxy_manager(host: str, http_proxy): def get_response_data(response: HTTPResponse) -> bytes: - encoding = response.headers.get('content-encoding', None) - if encoding == 'zstd': + encoding = response.headers.get("content-encoding", None) + if encoding == "zstd": try: zstd_decom = zstandard.ZstdDecompressor() return zstd_decom.stream_reader(response.data).read() except zstandard.ZstdError: pass - if encoding == 'lz4': + if encoding == "lz4": lz4_decom = lz4.frame.LZ4FrameDecompressor() return lz4_decom.decompress(response.data, len(response.data)) return response.data -def check_env_proxy(scheme: str, host: str, port: int) -> Optional[str]: - env_var = f'{scheme}_proxy'.lower() +def check_env_proxy(scheme: str, host: str, port: int) -> str | None: + env_var = f"{scheme}_proxy".lower() proxy = os.environ.get(env_var) if not proxy: proxy = os.environ.get(env_var.upper()) if not proxy: return None - no_proxy = os.environ.get('no_proxy') + no_proxy = os.environ.get("no_proxy") if not no_proxy: - no_proxy = os.environ.get('NO_PROXY') + no_proxy = os.environ.get("NO_PROXY") if not no_proxy: return proxy - if no_proxy == '*': + if no_proxy == "*": return None # Wildcard no proxy means don't actually proxy anything host = host.lower() - for name in no_proxy.split(','): + for name in no_proxy.split(","): name = name.strip() if name: - name = name.lstrip('.').lower() - if name in (host, f'{host}:{port}'): + name = name.lstrip(".").lower() + if name in (host, f"{host}:{port}"): return None # Host or host/port matches - if host.endswith('.' + name): + if host.endswith("." + name): return None # Domain matches return proxy @@ -183,31 +189,30 @@ _default_pool_manager = get_pool_manager() def default_pool_manager(): - if multiprocessing.current_process().name == 'MainProcess': + if multiprocessing.current_process().name == "MainProcess": return _default_pool_manager # PoolManagers don't seem to be safe for some multiprocessing environments, always return a new one return get_pool_manager() -# pylint: disable=too-many-statements class ResponseSource: - def __init__(self, response: HTTPResponse, chunk_size: int = 1024 * 1024, exception_tag: Optional[str] = None): + def __init__(self, response: HTTPResponse, chunk_size: int = 1024 * 1024, exception_tag: str | None = None): self.response = response self.exception_tag = exception_tag - compression = response.headers.get('content-encoding') - decompress:Optional[Callable] = None - if compression == 'zstd': + compression = response.headers.get("content-encoding") + decompress: Callable | None = None + if compression == "zstd": zstd_decom = zstandard.ZstdDecompressor().decompressobj() - def zstd_decompress(c: deque) -> Tuple[bytes, int]: + def zstd_decompress(c: deque) -> tuple[bytes, int]: chunk = c.popleft() return zstd_decom.decompress(chunk), len(chunk) decompress = zstd_decompress - elif compression == 'lz4': + elif compression == "lz4": lz4_decom = lz4.frame.LZ4FrameDecompressor() - def lz_decompress(c: deque) -> Tuple[Optional[bytes], int]: + def lz_decompress(c: deque) -> tuple[bytes | None, int]: read_amt = 0 data = c.popleft() read_amt += len(data) @@ -221,7 +226,7 @@ class ResponseSource: decompress = lz_decompress - buffer_size = common.get_setting('http_buffer_size') + buffer_size = common.get_setting("http_buffer_size") def buffered(): chunks = deque() @@ -234,11 +239,11 @@ class ResponseSource: while not done: chunk = None try: - chunk = next(read_gen, None) # Always try to read at least one chunk if there are any left - except Exception as ex: # pylint: disable=broad-except + chunk = next(read_gen, None) # Always try to read at least one chunk if there are any left + except Exception as ex: # Store the exception for potential re-raising if no data was received read_error = ex - logger.warning('unexpected failure to read next chunk', exc_info=True) + logger.warning("unexpected failure to read next chunk", exc_info=True) if not chunk: done = True break diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/insert.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/insert.py index 2151eca50cb..bcb361afaac 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/insert.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/insert.py @@ -1,19 +1,19 @@ import logging +from collections.abc import Generator, Iterable, Sequence from math import log -from typing import Iterable, Sequence, Optional, Any, Dict, NamedTuple, Generator, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, NamedTuple +from clickhouse_connect.driver import options from clickhouse_connect.driver.binding import quote_identifier - -from clickhouse_connect.driver.ctypes import data_conv from clickhouse_connect.driver.context import BaseQueryContext -from clickhouse_connect.driver import options -from clickhouse_connect.driver.exceptions import ProgrammingError, DataError +from clickhouse_connect.driver.ctypes import data_conv +from clickhouse_connect.driver.exceptions import DataError, ProgrammingError if TYPE_CHECKING: from clickhouse_connect.datatypes.base import ClickHouseType logger = logging.getLogger(__name__) -DEFAULT_BLOCK_BYTES = 1 << 21 # Try to generate blocks between 1MB and 2MB in raw size +DEFAULT_BLOCK_BYTES = 1 << 21 # Try to generate blocks between 1MB and 2MB in raw size class InsertBlock(NamedTuple): @@ -21,29 +21,29 @@ class InsertBlock(NamedTuple): column_count: int row_count: int column_names: Iterable[str] - column_types: Iterable['ClickHouseType'] + column_types: Iterable["ClickHouseType"] column_data: Iterable[Sequence[Any]] -# pylint: disable=too-many-instance-attributes class InsertContext(BaseQueryContext): """ Reusable Argument/parameter object for inserts. """ - # pylint: disable=too-many-arguments - def __init__(self, - table: str, - column_names: Sequence[str], - column_types: Sequence['ClickHouseType'], - data: Any = None, - column_oriented: Optional[bool] = None, - settings: Optional[Dict[str, Any]] = None, - compression: Optional[Union[str, bool]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - block_size: Optional[int] = None, - transport_settings: Optional[Dict[str, str]] = None): + def __init__( + self, + table: str, + column_names: Sequence[str], + column_types: Sequence["ClickHouseType"], + data: Any = None, + column_oriented: bool | None = None, + settings: dict[str, Any] | None = None, + compression: str | bool | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + block_size: int | None = None, + transport_settings: dict[str, str] | None = None, + ): super().__init__(settings, query_formats, column_formats, transport_settings=transport_settings) self.table = table self.column_names = column_names @@ -92,7 +92,7 @@ class InsertContext(BaseQueryContext): self.column_count = len(data[0]) if self.row_count and self.column_count: if self.column_count != len(self.column_names): - raise ProgrammingError('Insert data column count does not match column names') + raise ProgrammingError("Insert data column count does not match column names") self._data = data self.block_row_count = self._calc_block_size() @@ -118,7 +118,7 @@ class InsertContext(BaseQueryContext): sample = [data[j][i] for j in range(0, self.row_count, sample_freq)] d_size = d_type.data_size(sample) row_size += d_size - shift_size = (21 - int(log(row_size, 2))) + shift_size = 21 - int(log(row_size, 2)) return 1 if shift_size < 0 else 1 << (21 - int(log(row_size, 2))) def next_block(self) -> Generator[InsertBlock, None, None]: @@ -129,9 +129,9 @@ class InsertContext(BaseQueryContext): return if self.current_block == 0: cols = f" ({', '.join([quote_identifier(x) for x in self.column_names])})" - prefix = f'INSERT INTO {self.table}{cols} FORMAT Native\n'.encode() + prefix = f"INSERT INTO {self.table}{cols} FORMAT Native\n".encode() else: - prefix = bytes() + prefix = b"" self.current_block += 1 data = self._next_block_data(self.current_row, block_end) yield InsertBlock(prefix, self.column_count, row_count, self.column_names, self.column_types, data) @@ -140,7 +140,7 @@ class InsertContext(BaseQueryContext): def _column_block_data(self, block_start, block_end): if block_start == 0 and self.row_count <= block_end: return self._block_columns # Optimization if we don't need to break up the block - return [col[block_start: block_end] for col in self._block_columns] + return [col[block_start:block_end] for col in self._block_columns] def _row_block_data(self, block_start, block_end): return data_conv.pivot(self._block_rows, block_start, block_end) @@ -150,37 +150,44 @@ class InsertContext(BaseQueryContext): for df_col_name, col_name, ch_type in zip(df.columns, self.column_names, self.column_types): df_col = df[df_col_name] d_type_kind = df_col.dtype.kind - if ch_type.python_type == int: - if d_type_kind == 'f': - df_col = df_col.round().astype(ch_type.base_type, copy=False) - elif d_type_kind in ('i', 'u') and not df_col.hasnans: + if ch_type.python_type is int: + if d_type_kind == "f": + df_col = df_col.round().astype(ch_type.base_type) + elif d_type_kind in ("i", "u") and not df_col.hasnans: data.append(df_col.to_list()) continue - elif 'datetime' in ch_type.np_type and (options.pd_time_test(df_col) or 'datetime64[ns' in str(df_col.dtype)): - div = ch_type.nano_divisor - data.append([None if options.pd.isnull(x) else x.value // div for x in df_col]) - self.column_formats[col_name] = 'int' + elif "datetime" in ch_type.np_type and (options.pd_time_test(df_col) or "datetime64" in str(df_col.dtype)): + np_col = df_col.to_numpy(dtype=ch_type.np_type) + int_col = np_col.astype("int64") + if df_col.hasnans: + nat_mask = options.pd.isnull(df_col).to_numpy() + int_list = int_col.tolist() + data.append([None if nat_mask[i] else int_list[i] for i in range(len(int_list))]) + else: + data.append(int_col.tolist()) + self.column_formats[col_name] = "int" continue if ch_type.nullable: - if d_type_kind == 'O': - # This is ugly, but the multiple replaces seem required as a result of this bug: - # https://github.com/pandas-dev/pandas/issues/29024 - df_col = df_col.replace({options.pd.NaT: None}).replace({options.np.nan: None}) - elif 'Float' in ch_type.base_type: + if d_type_kind == "O" or ch_type.np_type == "O": + data.append(df_col.to_numpy(dtype=object, na_value=None)) + continue + if "Float" in ch_type.base_type: data.append([None if options.pd.isnull(x) else x for x in df_col]) continue - else: - df_col = df_col.replace({options.np.nan: None}) - data.append(df_col.to_numpy(copy=False)) + df_col = df_col.replace({options.np.nan: None}) + if ch_type.np_type == "O": + data.append(df_col.to_numpy(dtype=object, na_value=None)) + else: + data.append(df_col.to_numpy(copy=False)) return data def _convert_numpy(self, np_array): if np_array.dtype.names is None: - if 'date' in str(np_array.dtype): + if "date" in str(np_array.dtype): for col_name, col_type in zip(self.column_names, self.column_types): - if 'date' in col_type.np_type: - self.column_formats[col_name] = 'int' - return np_array.astype('int').tolist() + if "date" in col_type.np_type: + self.column_formats[col_name] = "int" + return np_array.astype("int").tolist() for col_type in self.column_types: if col_type.byte_size == 0 or col_type.byte_size > np_array.dtype.itemsize: return np_array.tolist() @@ -193,8 +200,8 @@ class InsertContext(BaseQueryContext): data = [np_array[col_name] for col_name in np_array.dtype.names] for ix, (col_name, col_type) in enumerate(zip(self.column_names, self.column_types)): d_type = data[ix].dtype - if 'date' in str(d_type) and 'date' in col_type.np_type: - self.column_formats[col_name] = 'int' + if "date" in str(d_type) and "date" in col_type.np_type: + self.column_formats[col_name] = "int" data[ix] = data[ix].astype(int).tolist() elif col_type.byte_size == 0 or col_type.byte_size > d_type.itemsize: data[ix] = data[ix].tolist() diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/models.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/models.py index 38407d1c63c..7d626f8efed 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/models.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/models.py @@ -7,6 +7,7 @@ class ColumnDef(NamedTuple): """ ClickHouse column definition from DESCRIBE TABLE command """ + name: str type: str default_type: str @@ -17,7 +18,7 @@ class ColumnDef(NamedTuple): @property def type_name(self): - return self.type.replace('\n', '').strip() + return self.type.replace("\n", "").strip() @property def ch_type(self): @@ -28,6 +29,7 @@ class SettingDef(NamedTuple): """ ClickHouse setting definition from system.settings table """ + name: str value: str readonly: int @@ -37,5 +39,6 @@ class SettingStatus(NamedTuple): """ Get the setting "status" from a ClickHouse server setting """ + is_set: bool is_writable: bool diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/npquery.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/npquery.py index 16cafcea68b..1c68ed71566 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/npquery.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/npquery.py @@ -1,28 +1,29 @@ -import logging import itertools -from typing import Generator, Sequence, Tuple +import logging +from collections.abc import Generator, Sequence -from clickhouse_connect.driver.common import empty_gen, StreamContext +from clickhouse_connect.driver import options +from clickhouse_connect.driver.common import StreamContext, empty_gen from clickhouse_connect.driver.exceptions import StreamClosedError from clickhouse_connect.driver.types import Closable -from clickhouse_connect.driver import options logger = logging.getLogger(__name__) -# pylint: disable=too-many-instance-attributes class NumpyResult(Closable): - def __init__(self, - block_gen: Generator[Sequence, None, None] = None, - column_names: Tuple = (), - column_types: Tuple = (), - d_types: Sequence = (), - source: Closable = None): + def __init__( + self, + block_gen: Generator[Sequence, None, None] = None, + column_names: tuple = (), + column_types: tuple = (), + d_types: Sequence = (), + source: Closable = None, + ): self.column_names = column_names self.column_types = column_types self.np_types = d_types self.source = source - self.query_id = '' + self.query_id = "" self.summary = {} self._block_gen = block_gen or empty_gen() self._numpy_result = None @@ -101,9 +102,9 @@ class NumpyResult(Closable): chains = [chain(b) for b in zip(*bg)] new_df_series = [] for c in chains: - series = [options.pd.Series(piece, copy=False) for piece in c if len(piece) > 0] + series = [options.pd.Series(piece) for piece in c if len(piece) > 0] if len(series) > 0: - new_df_series.append(options.pd.concat(series, copy=False, ignore_index=True)) + new_df_series.append(options.pd.concat(series, ignore_index=True)) self._df_result = options.pd.DataFrame(dict(zip(self.column_names, new_df_series))) self.close() return self diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/options.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/options.py index b4bfb31c1ec..e39da43db7c 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/options.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/options.py @@ -1,16 +1,21 @@ -import warnings - from clickhouse_connect.driver.exceptions import NotSupportedError # Attributes resolved lazily by __getattr__ / _resolve_* functions: -# np, pd, arrow, pl, pd_time_test, pd_extended_dtypes, PANDAS_VERSION, IS_PANDAS_2 +# np, pd, arrow, pl, pd_time_test -# pylint: disable=import-outside-toplevel -_PANDAS_ATTRS = frozenset({"pd", "pd_time_test", "pd_extended_dtypes", "PANDAS_VERSION", "IS_PANDAS_2"}) +_PANDAS_ATTRS = frozenset({"pd", "pd_time_test"}) _ALL_LAZY = frozenset({"np", "arrow", "pl"}) | _PANDAS_ATTRS +def _pd_time_test(arr_or_dtype): + """Check whether a Series or dtype is datetime64 or timedelta64.""" + kind = getattr(arr_or_dtype, "kind", None) + if kind is None: + kind = getattr(getattr(arr_or_dtype, "dtype", None), "kind", None) + return kind in ("M", "m") + + def _resolve_numpy(): if "np" in globals(): return @@ -28,41 +33,15 @@ def _resolve_pandas(): try: import pandas - globals()["pd"] = pandas version = tuple(map(int, pandas.__version__.split(".")[:2])) - globals()["PANDAS_VERSION"] = version - is_v2 = version >= (2, 0) - globals()["IS_PANDAS_2"] = is_v2 - globals()["pd_extended_dtypes"] = not pandas.__version__.startswith("0") - if not is_v2: - warnings.warn( - "clickhouse-connect support for pandas 1.x is deprecated and will be removed in v1.0.0. " - "Please upgrade to pandas 2.x or later.", - DeprecationWarning, - stacklevel=2, - ) - try: - from pandas.core.dtypes.common import ( - is_datetime64_dtype, - is_timedelta64_dtype, + if version < (2, 0): + raise NotSupportedError( + f"clickhouse-connect requires pandas 2.0 or later, found {pandas.__version__}. Please upgrade: pip install --upgrade pandas" ) - - def combined_test(arr_or_dtype): - return is_datetime64_dtype(arr_or_dtype) or is_timedelta64_dtype(arr_or_dtype) - - globals()["pd_time_test"] = combined_test - except ImportError: - try: - from pandas.core.dtypes.common import is_datetime_or_timedelta_dtype - - globals()["pd_time_test"] = is_datetime_or_timedelta_dtype - except ImportError as ex: - raise NotSupportedError("pandas version does not contain expected test for temporal types") from ex + globals()["pd"] = pandas + globals()["pd_time_test"] = _pd_time_test except ImportError: globals()["pd"] = None - globals()["PANDAS_VERSION"] = None - globals()["IS_PANDAS_2"] = None - globals()["pd_extended_dtypes"] = False globals()["pd_time_test"] = None @@ -106,7 +85,6 @@ def __dir__(): return list(globals().keys()) + list(_ALL_LAZY - globals().keys()) -# pylint: disable=redefined-outer-name def check_numpy(): _resolve_numpy() np = globals()["np"] diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/parser.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/parser.py index 02bdc03cd5f..99de3f56cc6 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/parser.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/parser.py @@ -1,10 +1,7 @@ -from typing import Union, Tuple - from clickhouse_connect.driver.common import unescape_identifier -# pylint: disable=too-many-branches -def parse_callable(expr) -> Tuple[str, Tuple[Union[str, int], ...], str]: +def parse_callable(expr) -> tuple[str, tuple[str | int, ...], str]: """ Parses a single level ClickHouse optionally 'callable' function/identifier. The identifier is returned as the first value in the response tuple. If the expression is callable -- i.e. an identifier followed by 0 or more @@ -19,16 +16,16 @@ def parse_callable(expr) -> Tuple[str, Tuple[Union[str, int], ...], str]: :return: Tuple of the identifier, a tuple of arguments, and remaining text """ expr = expr.strip() - pos = expr.find('(') - space = expr.find(' ') + pos = expr.find("(") + space = expr.find(" ") if pos == -1 and space == -1: - return expr, (), '' + return expr, (), "" if space != -1 and (pos == -1 or space < pos): return expr[:space], (), expr[space:].strip() name = expr[:pos] pos += 1 # Skip first paren values = [] - value = '' + value = "" in_str = False level = 0 @@ -45,39 +42,39 @@ def parse_callable(expr) -> Tuple[str, Tuple[Union[str, int], ...], str]: value += char if char == "'": in_str = False - elif char == '\\' and expr[pos] == "'" and expr[pos:pos + 4] != "' = " and expr[pos:pos + 2] != "')": + elif char == "\\" and expr[pos] == "'" and expr[pos : pos + 4] != "' = " and expr[pos : pos + 2] != "')": value += expr[pos] pos += 1 else: if level == 0: - if char == ' ': + if char == " ": space = pos temp_char = expr[space] - while temp_char == ' ': + while temp_char == " ": space += 1 temp_char = expr[space] if not value or temp_char in "()',=><0": char = temp_char pos = space + 1 - if char == ',': + if char == ",": add_value() - value = '' + value = "" continue - if char == ')': + if char == ")": break - if char == "'" and (not value or 'Enum' in value): + if char == "'" and (not value or "Enum" in value): in_str = True - elif char == '(': + elif char == "(": level += 1 - elif char == ')' and level: + elif char == ")" and level: level -= 1 value += char - if value != '': + if value != "": add_value() return name, tuple(values), expr[pos:].strip() -def parse_enum(expr) -> Tuple[Tuple[str], Tuple[int]]: +def parse_enum(expr) -> tuple[tuple[str], tuple[int]]: """ Parse a ClickHouse enum definition expression of the form ('key1' = 1, 'key2' = 2) :param expr: ClickHouse enum expression/arguments @@ -85,7 +82,7 @@ def parse_enum(expr) -> Tuple[Tuple[str], Tuple[int]]: """ keys = [] values = [] - pos = expr.find('(') + 1 + pos = expr.find("(") + 1 in_key = False key = [] value = [] @@ -94,20 +91,20 @@ def parse_enum(expr) -> Tuple[Tuple[str], Tuple[int]]: pos += 1 if in_key: if char == "'": - keys.append(''.join(key)) + keys.append("".join(key)) key = [] in_key = False - elif char == '\\' and expr[pos] == "'" and expr[pos:pos + 4] != "' = " and expr[pos:] != "')": + elif char == "\\" and expr[pos] == "'" and expr[pos : pos + 4] != "' = " and expr[pos:] != "')": key.append(expr[pos]) pos += 1 else: key.append(char) - elif char not in (' ', '='): - if char == ',': - values.append(int(''.join(value))) + elif char not in (" ", "="): + if char == ",": + values.append(int("".join(value))) value = [] - elif char == ')': - values.append(int(''.join(value))) + elif char == ")": + values.append(int("".join(value))) break elif char == "'" and not value: in_key = True @@ -129,7 +126,7 @@ def parse_columns(expr: str): pos = 1 named = False level = 0 - label = '' + label = "" quote = None while True: char = expr[pos] @@ -137,30 +134,30 @@ def parse_columns(expr: str): if quote: if char == quote: quote = None - elif char == '\\' and expr[pos] == "'" and expr[pos:pos + 4] != "' = " and expr[pos:pos + 2] != "')": + elif char == "\\" and expr[pos] == "'" and expr[pos : pos + 4] != "' = " and expr[pos : pos + 2] != "')": label += expr[pos] pos += 1 else: if level == 0: - if char in (' ', '='): + if char in (" ", "="): if label and not named: names.append(unescape_identifier(label)) - label = '' + label = "" named = True - char = '' - elif char == ',': + char = "" + elif char == ",": columns.append(label) named = False - label = '' + label = "" continue - elif char == ')': + elif char == ")": columns.append(label) break - if char in ("'", '`') and (not label or 'Enum' in label): + if char in ("'", "`") and (not label or "Enum" in label): quote = char - elif char == '(': + elif char == "(": level += 1 - elif char == ')': + elif char == ")": level -= 1 label += char return tuple(names), tuple(columns) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/query.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/query.py index 5ce5f8bc03e..38f6abec512 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/query.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/query.py @@ -1,23 +1,19 @@ import logging import re -import warnings -import pytz - +from collections.abc import Generator, Sequence +from datetime import timezone, tzinfo from io import IOBase -from typing import Any, Literal, Tuple, Dict, Sequence, Optional, Union, Generator, BinaryIO, TYPE_CHECKING -from datetime import tzinfo - -from pytz.exceptions import UnknownTimeZoneError +from typing import TYPE_CHECKING, Any, BinaryIO, Literal +from zoneinfo import ZoneInfoNotFoundError from clickhouse_connect.driver import tzutil from clickhouse_connect.driver.binding import bind_query -from clickhouse_connect.driver.common import dict_copy, empty_gen, StreamContext, get_rename_method +from clickhouse_connect.driver.common import StreamContext, dict_copy, empty_gen, get_rename_method +from clickhouse_connect.driver.context import BaseQueryContext +from clickhouse_connect.driver.exceptions import ProgrammingError, StreamClosedError from clickhouse_connect.driver.external import ExternalData -from clickhouse_connect.driver.types import Matrix, Closable -from clickhouse_connect.driver.exceptions import StreamClosedError, ProgrammingError -from clickhouse_connect.driver import options from clickhouse_connect.driver.options import check_arrow -from clickhouse_connect.driver.context import BaseQueryContext +from clickhouse_connect.driver.types import Closable, Matrix if TYPE_CHECKING: from clickhouse_connect.datatypes.base import ClickHouseType @@ -27,163 +23,48 @@ logger = logging.getLogger(__name__) TzMode = Literal["naive_utc", "aware", "schema"] TzSource = Literal["auto", "server", "local"] -_UTC_TZ_AWARE_TO_TZ_MODE: Dict[Union[bool, str], TzMode] = { - False: "naive_utc", - True: "aware", - "schema": "schema", -} - _VALID_TZ_MODES = {"naive_utc", "aware", "schema"} -_TZ_MODE_TO_UTC_TZ_AWARE: Dict[str, Union[bool, Literal["schema"]]] = { - "naive_utc": False, - "aware": True, - "schema": "schema", -} - -_APPLY_SERVER_TZ_TO_TZ_SOURCE: Dict[Union[bool, str, None], TzSource] = { - None: "auto", - True: "server", - False: "local", - "always": "server", -} - -_TZ_SOURCE_TO_APPLY_SERVER_TZ: Dict[str, Optional[Union[bool, str]]] = { - "auto": None, - "server": True, - "local": False, -} - _VALID_TZ_SOURCES = {"auto", "server", "local"} -# Mapping for string booleans that may arrive via URL params -_STR_BOOL_MAP = {"true": True, "false": False, "1": True, "0": False} - -def _resolve_tz_mode( - tz_mode: Optional[TzMode] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, -) -> TzMode: - """Resolve tz_mode from either the new ``tz_mode`` or deprecated ``utc_tz_aware`` parameter. +commands = "CREATE|ALTER|SYSTEM|GRANT|REVOKE|CHECK|DETACH|ATTACH|DROP|DELETE|KILL|OPTIMIZE|SET|RENAME|TRUNCATE|USE|UPDATE" - Returns the canonical TzMode string. Raises ``ProgrammingError`` on conflicts or - invalid values. - """ - if tz_mode is not None and utc_tz_aware is not None: - raise ProgrammingError( - "Cannot specify both 'tz_mode' and 'utc_tz_aware'. " - "Use 'tz_mode' only; 'utc_tz_aware' is deprecated." - ) - - if utc_tz_aware is not None: - # Coerce string booleans from URL params (e.g. "true" -> True) - if isinstance(utc_tz_aware, str) and utc_tz_aware.lower() in _STR_BOOL_MAP: - utc_tz_aware = _STR_BOOL_MAP[utc_tz_aware.lower()] - - if utc_tz_aware not in _UTC_TZ_AWARE_TO_TZ_MODE: - raise ProgrammingError( - f'utc_tz_aware must be True, False, or "schema", got "{utc_tz_aware}"' - ) - warnings.warn( - "utc_tz_aware is deprecated and will be removed in 1.0. " - "Use tz_mode='naive_utc' | 'aware' | 'schema' instead.", - DeprecationWarning, - stacklevel=3, - ) - return _UTC_TZ_AWARE_TO_TZ_MODE[utc_tz_aware] +limit_re = re.compile(r"\s+LIMIT($|\s)", re.IGNORECASE) +select_re = re.compile(r"(^|\s)SELECT\s", re.IGNORECASE) +insert_re = re.compile(r"(^|\s)INSERT\s*INTO", re.IGNORECASE) +command_re = re.compile(r"(^\s*)(" + commands + r")\s", re.IGNORECASE) - if tz_mode is not None: - if tz_mode not in _VALID_TZ_MODES: - raise ProgrammingError( - f'tz_mode must be "naive_utc", "aware", or "schema", got "{tz_mode}"' - ) - return tz_mode - return "naive_utc" - - -def _resolve_tz_source( - tz_source: Optional[TzSource] = None, - apply_server_timezone: Optional[Union[str, bool]] = None, -) -> TzSource: - """Resolve tz_source from either the new ``tz_source`` or deprecated ``apply_server_timezone`` parameter. - - Returns the canonical TzSource string. Raises ``ProgrammingError`` on conflicts or - invalid values. - """ - if tz_source is not None and apply_server_timezone is not None: - raise ProgrammingError( - "Cannot specify both 'tz_source' and 'apply_server_timezone'. " - "Use 'tz_source' only; 'apply_server_timezone' is deprecated." - ) - - if apply_server_timezone is not None: - # Coerce string booleans from URL params (e.g. "true" -> True) - if isinstance(apply_server_timezone, str) and apply_server_timezone.lower() in _STR_BOOL_MAP: - apply_server_timezone = _STR_BOOL_MAP[apply_server_timezone.lower()] - - if apply_server_timezone not in _APPLY_SERVER_TZ_TO_TZ_SOURCE: - raise ProgrammingError( - f"apply_server_timezone must be None, True, False, or 'always', " - f'got "{apply_server_timezone}"' - ) - warnings.warn( - "apply_server_timezone is deprecated and will be removed in 1.0. " - "Use tz_source='auto' | 'server' | 'local' instead.", - DeprecationWarning, - stacklevel=3, - ) - return _APPLY_SERVER_TZ_TO_TZ_SOURCE[apply_server_timezone] - - if tz_source is not None: - if tz_source not in _VALID_TZ_SOURCES: - raise ProgrammingError( - f'tz_source must be "auto", "server", or "local", got "{tz_source}"' - ) - return tz_source - - return "auto" - - -commands = 'CREATE|ALTER|SYSTEM|GRANT|REVOKE|CHECK|DETACH|ATTACH|DROP|DELETE|KILL|' + \ - 'OPTIMIZE|SET|RENAME|TRUNCATE|USE|UPDATE' - -limit_re = re.compile(r'\s+LIMIT($|\s)', re.IGNORECASE) -select_re = re.compile(r'(^|\s)SELECT\s', re.IGNORECASE) -insert_re = re.compile(r'(^|\s)INSERT\s*INTO', re.IGNORECASE) -command_re = re.compile(r'(^\s*)(' + commands + r')\s', re.IGNORECASE) - - -# pylint: disable=too-many-instance-attributes class QueryContext(BaseQueryContext): """ Argument/parameter object for queries. This context is used to set thread/query specific formats """ - # pylint: disable=duplicate-code,too-many-arguments,too-many-positional-arguments,too-many-locals - def __init__(self, - query: Union[str, bytes] = '', - parameters: Optional[Dict[str, Any]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - server_tz: tzinfo = pytz.UTC, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = None, - max_str_len: Optional[int] = 0, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - use_extended_dtypes: Optional[bool] = None, - as_pandas: bool = False, - streaming: bool = False, - apply_server_tz: bool = False, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - rename_response_column: Optional[str] = None, - tz_mode: Optional[TzMode] = None): + def __init__( + self, + query: str | bytes = "", + parameters: dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + server_tz: tzinfo = timezone.utc, + use_none: bool | None = None, + column_oriented: bool | None = None, + use_numpy: bool | None = None, + max_str_len: int | None = 0, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + use_extended_dtypes: bool | None = None, + as_pandas: bool = False, + streaming: bool = False, + apply_server_tz: bool = False, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + rename_response_column: str | None = None, + tz_mode: TzMode | None = None, + ): """ Initializes various configuration settings for the query context @@ -205,7 +86,7 @@ class QueryContext(BaseQueryContext): :param max_str_len Limit returned ClickHouse String values to this length, which allows a Numpy structured array even with ClickHouse variable length String columns. If 0, Numpy arrays for String columns will always be object arrays - :param query_tz Either a string or a pytz tzinfo object. (Strings will be converted to tzinfo objects). + :param query_tz Either a string IANA timezone name or a tzinfo object (strings are resolved via zoneinfo). Values for any DateTime or DateTime64 column in the query will be converted to Python datetime.datetime objects with the selected timezone :param column_tzs A dictionary of column names to tzinfo objects (or strings that will be converted to @@ -214,15 +95,16 @@ class QueryContext(BaseQueryContext): naive UTC timestamps. "aware" forces timezone-aware UTC datetimes. "schema" returns datetimes that match the server's column definition which means timezone-aware when the column schema defines a timezone (e.g. DateTime('UTC')) and naive for bare DateTime columns. - :param utc_tz_aware Deprecated. Use tz_mode instead. """ - super().__init__(settings, - query_formats, - column_formats, - encoding, - use_extended_dtypes if use_extended_dtypes is not None else False, - use_numpy if use_numpy is not None else False, - transport_settings=transport_settings) + super().__init__( + settings, + query_formats, + column_formats, + encoding, + use_extended_dtypes if use_extended_dtypes is not None else False, + use_numpy if use_numpy is not None else False, + transport_settings=transport_settings, + ) self.query = query self.parameters = parameters or {} self.use_none = True if use_none is None else use_none @@ -232,38 +114,42 @@ class QueryContext(BaseQueryContext): self.server_tz = server_tz self.apply_server_tz = apply_server_tz self.external_data = external_data - self.tz_mode = _resolve_tz_mode(tz_mode, utc_tz_aware) + self.tz_mode = tz_mode if tz_mode is not None else "naive_utc" + if self.tz_mode not in _VALID_TZ_MODES: + raise ProgrammingError(f'tz_mode must be "naive_utc", "aware", or "schema", got "{self.tz_mode}"') if isinstance(query_tz, str): try: - query_tz = pytz.timezone(query_tz) - except UnknownTimeZoneError as ex: - raise ProgrammingError(f'query_tz {query_tz} is not recognized') from ex + query_tz = tzutil.resolve_zone(query_tz) + except ZoneInfoNotFoundError as ex: + raise ProgrammingError(f"query_tz {query_tz} is not recognized; {tzutil.TZDATA_HINT}") from ex self.query_tz = query_tz if column_tzs is not None: - for col_name, timezone in column_tzs.items(): - if isinstance(timezone, str): + resolved_column_tzs = {} + for col_name, col_tz in column_tzs.items(): + if isinstance(col_tz, str): try: - timezone = pytz.timezone(timezone) - column_tzs[col_name] = timezone - except UnknownTimeZoneError as ex: - raise ProgrammingError(f'column_tz {timezone} is not recognized') from ex + resolved_column_tzs[col_name] = tzutil.resolve_zone(col_tz) + except ZoneInfoNotFoundError as ex: + raise ProgrammingError(f"column_tz {col_tz} is not recognized; {tzutil.TZDATA_HINT}") from ex + else: + resolved_column_tzs[col_name] = col_tz + column_tzs = resolved_column_tzs self.column_tzs = column_tzs self.column_tz = None self.response_tz = None self.block_info = False self.as_pandas = as_pandas - self.use_pandas_na = as_pandas and options.pd_extended_dtypes self.streaming = streaming - self._rename_response_column: Optional[str] = rename_response_column + self._rename_response_column: str | None = rename_response_column self.column_renamer = get_rename_method(rename_response_column) self._update_query() @property - def rename_response_column(self) -> Optional[str]: + def rename_response_column(self) -> str | None: return self._rename_response_column @rename_response_column.setter - def rename_response_column(self, method: Optional[str]): + def rename_response_column(self, method: str | None): self._rename_response_column = method self.column_renamer = get_rename_method(method) @@ -283,18 +169,7 @@ class QueryContext(BaseQueryContext): def is_command(self) -> bool: return command_re.search(self.uncommented_query) is not None - @property - def utc_tz_aware(self) -> Union[bool, Literal["schema"]]: - """Deprecated: use tz_mode instead.""" - warnings.warn( - "utc_tz_aware is deprecated and will be removed in 1.0. " - "Use tz_mode instead.", - DeprecationWarning, - stacklevel=2, - ) - return _TZ_MODE_TO_UTC_TZ_AWARE[self.tz_mode] - - def set_parameters(self, parameters: Dict[str, Any]): + def set_parameters(self, parameters: dict[str, Any]): self.parameters = parameters self._update_query() @@ -314,7 +189,7 @@ class QueryContext(BaseQueryContext): else: self.column_tz = None - def active_tz(self, datatype_tz: Optional[tzinfo]): + def active_tz(self, datatype_tz: tzinfo | None): if self.tz_mode == "schema": return self.column_tz or datatype_tz if self.column_tz: @@ -333,36 +208,33 @@ class QueryContext(BaseQueryContext): return None return active_tz - # pylint disable=too-many-positional-arguments - def updated_copy(self, - query: Optional[Union[str, bytes]] = None, - parameters: Optional[Dict[str, Any]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - server_tz: Optional[tzinfo] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = None, - max_str_len: Optional[int] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - use_extended_dtypes: Optional[bool] = None, - as_pandas: bool = False, - streaming: bool = False, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - rename_response_column: Optional[str] = None, - tz_mode: Optional[TzMode] = None) -> 'QueryContext': + def updated_copy( + self, + query: str | bytes | None = None, + parameters: dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + query_formats: dict[str, str] | None = None, + column_formats: dict[str, str | dict[str, str]] | None = None, + encoding: str | None = None, + server_tz: tzinfo | None = None, + use_none: bool | None = None, + column_oriented: bool | None = None, + use_numpy: bool | None = None, + max_str_len: int | None = None, + query_tz: str | tzinfo | None = None, + column_tzs: dict[str, str | tzinfo] | None = None, + use_extended_dtypes: bool | None = None, + as_pandas: bool = False, + streaming: bool = False, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + rename_response_column: str | None = None, + tz_mode: TzMode | None = None, + ) -> "QueryContext": """ Creates Query context copy with parameters overridden/updated as appropriate. """ - if tz_mode is not None or utc_tz_aware is not None: - resolved_tz_mode = _resolve_tz_mode(tz_mode, utc_tz_aware) - else: - resolved_tz_mode = self.tz_mode + resolved_tz_mode = tz_mode if tz_mode is not None else self.tz_mode return QueryContext( query=query or self.query, parameters=dict_copy(self.parameters, parameters), @@ -401,16 +273,17 @@ class QueryResult(Closable): Wrapper class for query return values and metadata """ - # pylint: disable=too-many-arguments - def __init__(self, - result_set: Matrix = None, - block_gen: Generator[Matrix, None, None] = None, - column_names: Tuple[str, ...] = (), - column_types: Tuple['ClickHouseType', ...] = (), - column_oriented: bool = False, - source: Closable = None, - query_id: str = None, - summary: Dict[str, Any] = None): + def __init__( + self, + result_set: Matrix = None, + block_gen: Generator[Matrix, None, None] = None, + column_names: tuple[str, ...] = (), + column_types: tuple["ClickHouseType", ...] = (), + column_oriented: bool = False, + source: Closable = None, + query_id: str = None, + summary: dict[str, Any] = None, + ): self._result_rows = result_set self._result_columns = None self._block_gen = block_gen or empty_gen() @@ -431,12 +304,20 @@ class QueryResult(Closable): @property def result_columns(self) -> Matrix: if self._result_columns is None: - result = [[] for _ in range(len(self.column_names))] - with self.column_block_stream as stream: - for block in stream: - for base, added in zip(result, block): - base.extend(added) - self._result_columns = result + # If rows are already materialized and stream is closed, transpose from rows + # This happens when async client eagerly materializes result_rows + if self._result_rows is not None and self._block_gen is None: + if self._result_rows: + self._result_columns = list(map(list, zip(*self._result_rows))) + else: + self._result_columns = [[] for _ in range(len(self.column_names))] + else: + result = [[] for _ in range(len(self.column_names))] + with self.column_block_stream as stream: + for block in stream: + for base, added in zip(result, block): + base.extend(added) + self._result_columns = result return self._result_columns @property @@ -451,7 +332,7 @@ class QueryResult(Closable): @property def query_id(self) -> str: - query_id = self.summary.get('query_id') + query_id = self.summary.get("query_id") if query_id: return query_id return self._query_id @@ -494,7 +375,7 @@ class QueryResult(Closable): return len(self.result_set) @property - def first_item(self) -> Dict[str, Any]: + def first_item(self) -> dict[str, Any]: if self.column_oriented: return {name: col[0] for name, col in zip(self.column_names, self.result_set)} return dict(zip(self.column_names, self.result_set[0])) @@ -530,7 +411,7 @@ def remove_sql_comments(sql: str) -> str: # if the 2nd group (capturing comments) is not None, it means we have captured a # non-quoted, actual comment string, so return nothing to remove the comment if match.group(2): - return '' + return "" # Otherwise we've actually captured a quoted string, so return it return match.group(1) @@ -549,10 +430,10 @@ def to_arrow_batches(buffer: IOBase) -> StreamContext: return StreamContext(buffer, reader) -def arrow_buffer(table, compression: Optional[str] = None) -> Tuple[Sequence[str], Union[bytes, BinaryIO]]: +def arrow_buffer(table, compression: str | None = None) -> tuple[Sequence[str], bytes | BinaryIO]: pyarrow = check_arrow() write_options = None - if compression in ('zstd', 'lz4'): + if compression in ("zstd", "lz4"): write_options = pyarrow.ipc.IpcWriteOptions(compression=pyarrow.Codec(compression=compression)) sink = pyarrow.BufferOutputStream() with pyarrow.RecordBatchFileWriter(sink, table.schema, options=write_options) as writer: diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/streaming.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/streaming.py new file mode 100644 index 00000000000..c246c7e3b2a --- /dev/null +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/streaming.py @@ -0,0 +1,318 @@ +import asyncio +import logging +import threading +import zlib +from collections.abc import Iterator + +import lz4.frame +import zstandard + +from clickhouse_connect.driver.asyncqueue import EOF_SENTINEL, AsyncSyncQueue +from clickhouse_connect.driver.compression import available_compression +from clickhouse_connect.driver.exceptions import OperationalError +from clickhouse_connect.driver.types import Closable + +logger = logging.getLogger(__name__) + +__all__ = ["StreamingResponseSource", "StreamingFileAdapter", "StreamingInsertSource"] + +if "br" in available_compression: + import brotli +else: + brotli = None + + +class StreamingResponseSource(Closable): + """Streaming source that feeds chunks from async producer to sync consumer.""" + + READ_BUFFER_SIZE = 1024 * 1024 + + def __init__(self, response, encoding: str | None = None, exception_tag: str | None = None): + self.response = response + self.encoding = encoding + self.exception_tag = exception_tag + + # maxsize=10 means max ~10 socket reads buffered + self.queue = AsyncSyncQueue(maxsize=10) + + self._decompressor = None + self._decompressor_initialized = False + + # Multiple accesses to .gen must return the same generator, not create new ones + self._gen_cache = None + + self._producer_task = None + self._producer_started = threading.Event() + self._producer_error: Exception | None = None + self._producer_completed = False + + async def start_producer(self, loop: asyncio.AbstractEventLoop): + """Start the async producer task. + Must be called from the event loop thread before consuming. + """ + + async def producer(): + """Async producer: reads chunks from response, feeds queue.""" + data_sent = False + try: + while True: + chunk = await self.response.content.read(self.READ_BUFFER_SIZE) + if not chunk: + break + data_sent = True + await self.queue.async_q.put(chunk) + + await self.queue.async_q.put(EOF_SENTINEL) + self._producer_completed = True + + except Exception as e: + logger.error("Producer error while streaming response: %s", e, exc_info=True) + if not data_sent: + e = OperationalError("Failed to read response data from server") + self._producer_error = e + + try: + await self.queue.async_q.put(e) + except RuntimeError: + pass + + finally: + self.queue.shutdown() + + self._producer_task = loop.create_task(producer()) + self._producer_started.set() + + @property + def gen(self) -> Iterator[bytes]: + """Generator that yields decompressed chunks. + + CRITICAL: Returns cached generator to prevent multiple generators + from competing to read from the same queue. + """ + if self._gen_cache is not None: + return self._gen_cache + + self._gen_cache = self._create_generator() + return self._gen_cache + + def _create_generator(self) -> Iterator[bytes]: + """Creates the actual generator function.""" + if not self._producer_started.wait(timeout=5.0): + raise RuntimeError("Producer failed to start within timeout") + + if self.encoding and not self._decompressor_initialized: + self._decompressor_initialized = True + try: + self._decompressor = self._create_decompressor(self.encoding) + except Exception as e: + logger.error("Failed to create decompressor for %s: %s", self.encoding, e) + raise + + while True: + chunk = self.queue.sync_q.get() + + if chunk is EOF_SENTINEL: + if self._decompressor: + try: + if hasattr(self._decompressor, "flush"): + final = self._decompressor.flush() + if final: + yield final + except Exception as e: + logger.error("Error flushing decompressor: %s", e, exc_info=True) + raise + break + + if isinstance(chunk, Exception): + raise chunk + + if self._decompressor: + try: + if hasattr(self._decompressor, "decompress"): + decompressed = self._decompressor.decompress(chunk) + else: + decompressed = self._decompressor.process(chunk) + if decompressed: + yield decompressed + except Exception as e: + logger.error("Decompression error: %s", e, exc_info=True) + raise + else: + yield chunk + + @staticmethod + def _create_decompressor(encoding: str): + """Create incremental decompressor for encoding.""" + if encoding == "gzip": + return zlib.decompressobj(16 + zlib.MAX_WBITS) + + if encoding == "deflate": + return zlib.decompressobj() + + if encoding == "br": + if brotli is not None: + return brotli.Decompressor() + raise ImportError("brotli compression requires 'brotli' package. Install with: pip install brotli") + + if encoding == "zstd": + return zstandard.ZstdDecompressor().decompressobj() + + if encoding == "lz4": + return lz4.frame.LZ4FrameDecompressor() + + raise ValueError(f"Unsupported compression encoding: {encoding}") + + async def aclose(self): + """Async cleanup resources""" + self.queue.shutdown() + + if self._producer_task and not self._producer_task.done(): + self._producer_task.cancel() + try: + await self._producer_task + except asyncio.CancelledError: + pass + except Exception: + pass + + if self.response and not self.response.closed: + if not self._producer_completed: + self.response.close() + await asyncio.sleep(0.05) + + def close(self): + """Synchronous cleanup resources""" + self.queue.shutdown() + + if self._producer_task and not self._producer_task.done(): + self._producer_task.cancel() + + if self.response and not self.response.closed: + if not self._producer_completed: + self.response.close() + + +class StreamingFileAdapter: + """File-like adapter for PyArrow streaming.""" + + def __init__(self, streaming_source): + self.streaming_source = streaming_source + self.gen = streaming_source.gen + self.buffer = b"" + self.closed = False + self.eof = False + + def read(self, size: int = -1) -> bytes: + """Read up to size bytes from stream""" + if self.closed or self.eof: + return b"" + + if size != -1 and len(self.buffer) >= size: + result = self.buffer[:size] + self.buffer = self.buffer[size:] + return result + + chunks = [self.buffer] if self.buffer else [] + current_len = len(self.buffer) + self.buffer = b"" + + while (size == -1 or current_len < size) and not self.eof: + try: + chunk = next(self.gen) + if chunk: + chunks.append(chunk) + current_len += len(chunk) + else: + self.eof = True + break + except StopIteration: + self.eof = True + break + + full_data = b"".join(chunks) + + if size == -1 or len(full_data) <= size: + return full_data + + result = full_data[:size] + self.buffer = full_data[size:] + return result + + def close(self): + self.closed = True + + +class StreamingInsertSource: + """Streaming source for async inserts (reverse bridge)""" + + def __init__(self, transform, context, loop: asyncio.AbstractEventLoop, maxsize: int = 10): + self.transform = transform + self.context = context + self.loop = loop + self.queue = AsyncSyncQueue(maxsize=maxsize) + self._producer_future = None + self._started = False + + def start_producer(self): + if self._started: + raise RuntimeError("Producer already started") + self._started = True + + def producer(): + try: + for block in self.transform.build_insert(self.context): + self.queue.sync_q.put(block) + + self.queue.sync_q.put(EOF_SENTINEL) + + except Exception as e: + logger.error("Insert producer error: %s", e, exc_info=True) + try: + self.queue.sync_q.put(e) + except Exception: + pass + finally: + self.queue.shutdown() + + self._producer_future = self.loop.run_in_executor(None, producer) + + async def async_generator(self): + """Async generator that yields blocks for aiohttp streaming.""" + if not self._started: + raise RuntimeError("Producer not started, call start_producer() first") + + try: + while True: + chunk = await self.queue.async_q.get() + + if chunk is EOF_SENTINEL: + break + + if isinstance(chunk, Exception): + raise chunk + + yield chunk + + except Exception as e: + logger.error("Insert consumer error: %s", e, exc_info=True) + raise + finally: + if self._producer_future and not self._producer_future.done(): + try: + await self._producer_future + except Exception: + pass + + async def close(self, timeout: float | None = 1.0): + """Shut down the queue and wait for the producer thread to terminate. Pass ``timeout=None`` to wait without a deadline.""" + self.queue.shutdown() + if self._producer_future and not self._producer_future.done(): + try: + if timeout is None: + await self._producer_future + else: + await asyncio.wait_for(asyncio.shield(self._producer_future), timeout=timeout) + except asyncio.TimeoutError: + logger.warning("Insert producer did not finish within timeout") + except Exception: + pass diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/summary.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/summary.py index ef152cad769..8938bb82ed1 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/summary.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/summary.py @@ -1,33 +1,30 @@ -from typing import Optional - from clickhouse_connect.datatypes.registry import get_from_name - from clickhouse_connect.driver.query import QueryResult class QuerySummary: summary = {} - def __init__(self, summary: Optional[dict] = None): + def __init__(self, summary: dict | None = None): if summary is not None: self.summary = summary @property def written_rows(self) -> int: - return int(self.summary.get('written_rows', 0)) + return int(self.summary.get("written_rows", 0)) def written_bytes(self) -> int: - return int(self.summary.get('written_bytes', 0)) + return int(self.summary.get("written_bytes", 0)) def query_id(self) -> str: - return self.summary.get('query_id', '') + return self.summary.get("query_id", "") def as_query_result(self) -> QueryResult: data = [] column_names = [] column_types = [] - str_type = get_from_name('String') - int_type = get_from_name('Int64') + str_type = get_from_name("String") + int_type = get_from_name("Int64") for key, value in self.summary.items(): column_names.append(key) if value.isnumeric(): diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/tools.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/tools.py index 42480858d6c..ab66531d621 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/tools.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/tools.py @@ -1,33 +1,78 @@ -from typing import Optional, Sequence, Dict, Any +import asyncio +from collections.abc import Sequence +from typing import Any from clickhouse_connect.driver import Client -from clickhouse_connect.driver.summary import QuerySummary from clickhouse_connect.driver.binding import quote_identifier +from clickhouse_connect.driver.summary import QuerySummary + + +def insert_file( + client: Client, + table: str, + file_path: str, + fmt: str | None = None, + column_names: Sequence[str] | None = None, + database: str | None = None, + settings: dict[str, Any] | None = None, + compression: str | None = None, +) -> QuerySummary: + if not database and table[0] not in ("`", "'") and table.find(".") > 0: + full_table = table + elif database: + full_table = f"{quote_identifier(database)}.{quote_identifier(table)}" + else: + full_table = quote_identifier(table) + if not fmt: + fmt = "CSV" if column_names else "CSVWithNames" + if compression is None: + if file_path.endswith(".gzip") or file_path.endswith(".gz"): + compression = "gzip" + with open(file_path, "rb") as file: + return client.raw_insert( + full_table, + column_names=column_names, + insert_block=file, + fmt=fmt, + settings=settings, + compression=compression, + ) + +async def insert_file_async( + client, + table: str, + file_path: str, + fmt: str | None = None, + column_names: Sequence[str] | None = None, + database: str | None = None, + settings: dict[str, Any] | None = None, + compression: str | None = None, +) -> QuerySummary: -def insert_file(client: Client, - table: str, - file_path: str, - fmt: Optional[str] = None, - column_names: Optional[Sequence[str]] = None, - database: Optional[str] = None, - settings: Optional[Dict[str, Any]] = None, - compression: Optional[str] = None) -> QuerySummary: - if not database and table[0] not in ('`', "'") and table.find('.') > 0: + if not database and table[0] not in ("`", "'") and table.find(".") > 0: full_table = table elif database: - full_table = f'{quote_identifier(database)}.{quote_identifier(table)}' + full_table = f"{quote_identifier(database)}.{quote_identifier(table)}" else: full_table = quote_identifier(table) if not fmt: - fmt = 'CSV' if column_names else 'CSVWithNames' + fmt = "CSV" if column_names else "CSVWithNames" if compression is None: - if file_path.endswith('.gzip') or file_path.endswith('.gz'): - compression = 'gzip' - with open(file_path, 'rb') as file: - return client.raw_insert(full_table, - column_names=column_names, - insert_block=file, - fmt=fmt, - settings=settings, - compression=compression) + if file_path.endswith(".gzip") or file_path.endswith(".gz"): + compression = "gzip" + + def read_file(): + with open(file_path, "rb") as file: + return file.read() + + file_data = await asyncio.to_thread(read_file) + + return await client.raw_insert( + full_table, + column_names=column_names, + insert_block=file_data, + fmt=fmt, + settings=settings, + compression=compression, + ) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/transform.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/transform.py index ce4e9b0a6ff..5970277b5ee 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/transform.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/transform.py @@ -1,14 +1,13 @@ import logging -from typing import Union from clickhouse_connect.datatypes import registry from clickhouse_connect.driver.common import write_leb128 +from clickhouse_connect.driver.compression import get_compressor from clickhouse_connect.driver.exceptions import StreamCompleteException, StreamFailureError from clickhouse_connect.driver.insert import InsertContext from clickhouse_connect.driver.npquery import NumpyResult -from clickhouse_connect.driver.query import QueryResult, QueryContext +from clickhouse_connect.driver.query import QueryContext, QueryResult from clickhouse_connect.driver.types import ByteSource -from clickhouse_connect.driver.compression import get_compressor _EMPTY_CTX = QueryContext() @@ -16,9 +15,8 @@ logger = logging.getLogger(__name__) class NativeTransform: - # pylint: disable=too-many-locals, too-many-statements, too-many-branches @staticmethod - def parse_response(source: ByteSource, context: QueryContext = _EMPTY_CTX) -> Union[NumpyResult, QueryResult]: + def parse_response(source: ByteSource, context: QueryContext = _EMPTY_CTX) -> NumpyResult | QueryResult: names = [] col_types = [] block_num = 0 @@ -71,6 +69,20 @@ class NativeTransform: if not error_msg: error_msg = extract_error_message(source.last_message) raise StreamFailureError(error_msg) from None + raise StreamFailureError("Stream ended unexpectedly (connection closed by server)") from ex + + # Handle async streaming errors (ClientPayloadError from aiohttp) + if ex.__class__.__name__ == "ClientPayloadError": + if source.last_message: + error_msg = None + exception_tag = getattr(source, "exception_tag", None) + if exception_tag: + error_msg = extract_exception_with_tag(source.last_message, exception_tag) + if not error_msg: + error_msg = extract_error_message(source.last_message) + raise StreamFailureError(error_msg) from None + raise StreamFailureError("Stream failed during read (connection closed by server)") from ex + raise block_num += 1 return result_block @@ -88,7 +100,7 @@ class NativeTransform: yield next_block if context.use_numpy: - res_types = [col.dtype if hasattr(col, 'dtype') else 'O' for col in first_block] + res_types = [col.dtype if hasattr(col, "dtype") else "O" for col in first_block] return NumpyResult(gen(), tuple(names), tuple(col_types), res_types, source) return QueryResult(None, gen(), tuple(names), tuple(col_types), context.column_oriented, source) @@ -112,15 +124,14 @@ class NativeTransform: context.start_column(col_name) try: col_type.write_column(data, output, context) - except Exception as ex: # pylint: disable=broad-except + except Exception as ex: # This is hideous, but some low level serializations can fail while streaming # the insert if the user has included bad data in the column. We need to ensure that the # insert fails (using garbage data) to avoid a partial insert, and use the context to # propagate the correct exception to the user - logger.error('Error serializing column `%s` into data type `%s`', - col_name, col_type.name, exc_info=True) + logger.error("Error serializing column `%s` into data type `%s`", col_name, col_type.name, exc_info=True) context.insert_exception = ex - yield 'INTERNAL EXCEPTION WHILE SERIALIZING'.encode() + yield b"INTERNAL EXCEPTION WHILE SERIALIZING" return yield compressor.compress_block(output) footer = compressor.flush() @@ -130,8 +141,7 @@ class NativeTransform: return chunk_gen() -# pylint: disable=too-many-return-statements,too-many-branches -def extract_exception_with_tag(message: bytes, exception_tag: str) -> Union[str, None]: +def extract_exception_with_tag(message: bytes, exception_tag: str) -> str | None: """Extract exception message from the new format with exception tag. Server v25.11+. Format: __exception__<TAG>\\r\\n<error message>\\r\\n<message_length> <TAG>__exception__\\r\\n @@ -185,18 +195,18 @@ def extract_exception_with_tag(message: bytes, exception_tag: str) -> Union[str, try: return error_message.decode("utf-8", errors="replace").strip() - except Exception: # pylint: disable=broad-except + except Exception: return error_message.decode("latin-1", errors="replace").strip() def extract_error_message(message: bytes) -> str: if len(message) > 1024: message = message[-1024:] - error_start = message.find('Code: '.encode()) + error_start = message.find(b"Code: ") if error_start != -1: message = message[error_start:] try: message_str = message.decode() except UnicodeError: - message_str = f'unrecognized data found in stream: `{message.hex()[128:]}`' + message_str = f"unrecognized data found in stream: `{message.hex()[128:]}`" return message_str diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/types.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/types.py index 6a425e5e1f5..f54b5e86e60 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/types.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/types.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Sequence, Any +from collections.abc import Sequence +from typing import Any Matrix = Sequence[Sequence[Any]] @@ -11,7 +12,7 @@ class Closable(ABC): class ByteSource(Closable): - last_message:bytes = None + last_message: bytes = None @abstractmethod def read_leb128(self) -> int: diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driver/tzutil.py b/contrib/python/clickhouse-connect/clickhouse_connect/driver/tzutil.py index ab3124f46e5..612620a07db 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driver/tzutil.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driver/tzutil.py @@ -1,8 +1,6 @@ import os -from datetime import datetime, tzinfo -from typing import Optional, Tuple, Union - -import pytz +import zoneinfo +from datetime import datetime, timezone, tzinfo tzlocal = None try: @@ -14,49 +12,232 @@ except ImportError: # timezone, but if someone insists on using the local timezone we will try to convert. The problem is we # never have anything but an epoch timestamp returned from ClickHouse, so attempts to convert times when the # local timezone is "DST" aware (like 'CEST' vs 'CET') will be wrong approximately half the time -local_tz: pytz.timezone +local_tz: tzinfo local_tz_dst_safe: bool = False -# Timezone names that are equivalent to UTC -UTC_EQUIVALENTS = ('UTC', 'Etc/UTC', 'GMT', 'Universal', 'GMT-0', 'Zulu', 'Greenwich', 'UCT') +# Zero-offset IANA timezone aliases that are semantically UTC. Listing every alias lets +# resolve_zone() short-circuit these names without needing a system zoneinfo database, matching +# the behavior pytz provided by bundling its own tz data. +UTC_EQUIVALENTS = ( + "UTC", + "Etc/UTC", + "UCT", + "Etc/UCT", + "GMT", + "Etc/GMT", + "GMT0", + "GMT-0", + "GMT+0", + "Etc/GMT0", + "Etc/GMT-0", + "Etc/GMT+0", + "Universal", + "Etc/Universal", + "Zulu", + "Etc/Zulu", + "Greenwich", + "Etc/Greenwich", +) + +# Appended to error/warning messages when a named IANA zone cannot be resolved. On systems without +# a system zoneinfo database (slim containers, Windows without tzdata), users can install the tzdata +# extra to get the IANA zone data. +TZDATA_HINT = "install the tzdata package (e.g. `pip install clickhouse-connect[tzdata]`) if no system zoneinfo database is available" + + +def resolve_zone(tz_name: str) -> tzinfo: + """Resolve an IANA timezone name to a tzinfo. + + Short-circuits UTC-equivalent names to datetime.timezone.utc so that representing UTC + does not require an IANA zoneinfo database to be available on the host. Other names are + resolved via zoneinfo.ZoneInfo and will raise ZoneInfoNotFoundError if the host has + no system zoneinfo and the tzdata package is not installed. + """ + if tz_name in UTC_EQUIVALENTS: + return timezone.utc + try: + return zoneinfo.ZoneInfo(tz_name) + except ValueError as ex: + # ZoneInfo raises ValueError for empty strings, absolute paths, and non-normalized + # keys; funnel those into ZoneInfoNotFoundError so callers only need one except clause. + raise zoneinfo.ZoneInfoNotFoundError(str(ex)) from ex + +def normalize_timezone(tz: tzinfo) -> tuple[tzinfo, bool]: + # ZoneInfo exposes the IANA key on `.key`; fall back to tzname(None) for other tzinfo + # subclasses (datetime.timezone, fixed offsets). pytz used to return the IANA name from + # tzname(None), but ZoneInfo returns None, which would collapse every named zone into the + # "unsafe" fallback branch. + tz_key = getattr(tz, "key", None) or tz.tzname(None) -def normalize_timezone(timezone: pytz.timezone) -> Tuple[pytz.timezone, bool]: - if timezone.tzname(None) in UTC_EQUIVALENTS: - return pytz.UTC, True + if tz_key in UTC_EQUIVALENTS: + return timezone.utc, True - if timezone.tzname(None) in pytz.common_timezones: - return timezone, True + if tz_key in zoneinfo.available_timezones(): + return tz, True if tzlocal is not None: # Maybe we can use the tzlocal module to get a safe timezone local_name = tzlocal.get_localzone_name() - if local_name in pytz.common_timezones: - return pytz.timezone(local_name), True + if local_name in zoneinfo.available_timezones(): + return zoneinfo.ZoneInfo(local_name), True - return timezone, False + return tz, False -def is_utc_timezone(tz: Optional[Union[tzinfo, str]]) -> bool: +def is_utc_timezone(tz: tzinfo | str | None) -> bool: """Check if timezone is UTC or an equivalent (Etc/UTC, GMT, etc.). - This handles the issue where pytz.timezone('Etc/UTC') != pytz.UTC despite + This handles the issue where zoneinfo.ZoneInfo('Etc/UTC') != zoneinfo.ZoneInfo("UTC") despite being semantically equivalent. Also accepts timezone name strings. """ if tz is None: return False if isinstance(tz, str): return tz in UTC_EQUIVALENTS - if tz == pytz.UTC: + if tz is timezone.utc: return True return tz.tzname(None) in UTC_EQUIVALENTS -def utcfromtimestamp(ts: float) -> datetime: - return datetime.fromtimestamp(ts, tz=pytz.UTC).replace(tzinfo=None) +def utc_equivalent_tzaware_datetime(ts: int, microseconds: int, tz_info: tzinfo) -> datetime: + """Build a UTC-equivalent timezone-aware datetime via epoch arithmetic. + + For UTC-equivalent timezones (UTC, Etc/UTC, GMT, etc.), construct the datetime + using epoch arithmetic rather than datetime.fromtimestamp(), then attach the + timezone. This avoids timezone conversion machinery that's unnecessary for UTC. + + Sub-second precision must be supplied via the microseconds argument; the ts + value is interpreted as integer seconds. + + Args: + ts: Integer Unix timestamp (seconds since epoch) + microseconds: Microsecond component (0-999999) + tz_info: A UTC-equivalent timezone object + + Returns: + Timezone-aware datetime in the specified timezone + """ + seconds = int(ts) + + days = seconds // 86400 + secs_in_day = seconds % 86400 + + year, month, day = _epoch_days_to_date_components(days) + + hour = secs_in_day // 3600 + secs_in_day %= 3600 + minute = secs_in_day // 60 + second = secs_in_day % 60 + + return datetime(year, month, day, hour, minute, second, microseconds, tzinfo=tz_info) + + +def utcfromtimestamp_with_microseconds(ts: int, microseconds: int = 0) -> datetime: + """Convert integer Unix timestamp to naive UTC datetime with explicit microseconds. + + More efficient than calling utcfromtimestamp() and then .replace(microsecond=...) + because it constructs the datetime once with all components. + + Args: + ts: Integer Unix timestamp (seconds since epoch) + microseconds: Microsecond component (0-999999) + + Returns: + Naive UTC datetime with specified microseconds + """ + seconds = int(ts) + + days = seconds // 86400 + secs_in_day = seconds % 86400 + + year, month, day = _epoch_days_to_date_components(days) + + hour = secs_in_day // 3600 + secs_in_day %= 3600 + minute = secs_in_day // 60 + second = secs_in_day % 60 + + return datetime(year, month, day, hour, minute, second, microseconds) + + +def utcfromtimestamp(ts: int) -> datetime: + """Convert integer Unix timestamp to naive UTC datetime via epoch arithmetic. + + Avoids the expensive datetime.fromtimestamp() + replace() round-trip. Sub-second + precision is not supported; pass an integer number of seconds. For sub-second + inputs, use utcfromtimestamp_with_microseconds. + """ + seconds = int(ts) + + days = seconds // 86400 + secs_in_day = seconds % 86400 + + year, month, day = _epoch_days_to_date_components(days) + + hour = secs_in_day // 3600 + secs_in_day %= 3600 + minute = secs_in_day // 60 + second = secs_in_day % 60 + + return datetime(year, month, day, hour, minute, second, 0) + + +_MONTH_DAYS = (0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365) +_MONTH_DAYS_LEAP = (0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366) + + +def _epoch_days_to_date_components(days: int) -> tuple[int, int, int]: + """Convert days since epoch to (year, month, day). + + This is a pure Python implementation of the same algorithm as + the Cython epoch_days_to_date, but returns components instead of a date object. + """ + if 0 <= days < 47482: + cycles = (days + 365) // 1461 + rem = (days + 365) - cycles * 1461 + years = rem // 365 + rem -= years * 365 + year = (cycles << 2) + years + 1969 + if years == 4: + return year - 1, 12, 31 + if years == 3: + m_list = _MONTH_DAYS_LEAP + else: + m_list = _MONTH_DAYS + else: + cycles400 = (days + 134774) // 146097 + rem = days + 134774 - (cycles400 * 146097) + cycles100 = rem // 36524 + rem -= cycles100 * 36524 + cycles = rem // 1461 + rem -= cycles * 1461 + years = rem // 365 + rem -= years * 365 + year = (cycles << 2) + cycles400 * 400 + cycles100 * 100 + years + 1601 + if years == 4 or cycles100 == 4: + return year - 1, 12, 31 + if years == 3 and year % 100 != 0: + m_list = _MONTH_DAYS_LEAP + else: + m_list = _MONTH_DAYS + + month = (rem + 24) >> 5 + prev = m_list[month] + while rem < prev: + month -= 1 + prev = m_list[month] + + return year, month + 1, rem + 1 - prev + + +def _detect_local_tz() -> tzinfo: + env_tz = os.environ.get("TZ") + if env_tz: + try: + return resolve_zone(env_tz) + except zoneinfo.ZoneInfoNotFoundError: + pass + return datetime.now().astimezone().tzinfo -try: - local_tz = pytz.timezone(os.environ.get('TZ', '')) -except pytz.UnknownTimeZoneError: - local_tz = datetime.now().astimezone().tzinfo -local_tz, local_tz_dst_safe = normalize_timezone(local_tz) +local_tz, local_tz_dst_safe = normalize_timezone(_detect_local_tz()) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driverc/buffer.pyx b/contrib/python/clickhouse-connect/clickhouse_connect/driverc/buffer.pyx index 3b347245076..1e8bc8e0ca2 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driverc/buffer.pyx +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driverc/buffer.pyx @@ -159,18 +159,29 @@ cdef class ResponseBuffer: cdef unsigned char b cdef char* buf while x < num_rows: - sz = 0 - shift = 0 - while 1: - if self.buf_loc < self.buf_sz: - b = self.buffer[self.buf_loc] - self.buf_loc += 1 - else: - b = self._read_byte_load() - sz += ((b & 0x7f) << shift) - if (b & 0x80) == 0: - break - shift += 7 + # Fast path: 1-byte varint covers most string lengths < 128 + if self.buf_loc < self.buf_sz: + b = self.buffer[self.buf_loc] + self.buf_loc += 1 + else: + b = self._read_byte_load() + + if (b & 0x80) == 0: + sz = b + else: + sz = b & 0x7f + shift = 7 + while 1: + if self.buf_loc < self.buf_sz: + b = self.buffer[self.buf_loc] + self.buf_loc += 1 + else: + b = self._read_byte_load() + sz += ((b & 0x7f) << shift) + if (b & 0x80) == 0: + break + shift += 7 + buf = self.read_bytes_c(sz) if encoding: try: @@ -194,21 +205,29 @@ cdef class ResponseBuffer: cdef char * null_map = <char *> PyMem_Malloc(<size_t> num_rows) memcpy(<void *> null_map, <void *> self.read_bytes_c(num_rows), num_rows) for x in range(num_rows): + # Fast path: 1-byte varint covers most string lengths < 128 if self.buf_loc < self.buf_sz: b = self.buffer[self.buf_loc] self.buf_loc += 1 else: b = self._read_byte_load() - shift = 0 - sz = b & 0x7f - while b & 0x80: - shift += 7 - if self.buf_loc < self.buf_sz: - b = self.buffer[self.buf_loc] - self.buf_loc += 1 - else: - b = self._read_byte_load() - sz += ((b & 0x7f) << shift) + + if (b & 0x80) == 0: + sz = b + else: + sz = b & 0x7f + shift = 7 + while 1: + if self.buf_loc < self.buf_sz: + b = self.buffer[self.buf_loc] + self.buf_loc += 1 + else: + b = self._read_byte_load() + sz += ((b & 0x7f) << shift) + if (b & 0x80) == 0: + break + shift += 7 + buf = self.read_bytes_c(sz) if null_map[x]: v = null_obj diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/driverc/dataconv.pyx b/contrib/python/clickhouse-connect/clickhouse_connect/driverc/dataconv.pyx index 60d3df6fad7..2718ae46454 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/driverc/dataconv.pyx +++ b/contrib/python/clickhouse-connect/clickhouse_connect/driverc/dataconv.pyx @@ -5,22 +5,29 @@ import array from datetime import datetime, date import cython +import sys from .buffer cimport ResponseBuffer from cpython cimport Py_INCREF, Py_DECREF -from cpython.buffer cimport PyBUF_READ +from cpython.buffer cimport PyBUF_READ, PyObject_GetBuffer, PyBuffer_Release, PyBUF_SIMPLE, Py_buffer from cpython.mem cimport PyMem_Free, PyMem_Malloc from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM -from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_Resize +from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_Resize, PyByteArray_AS_STRING from cpython.memoryview cimport PyMemoryView_FromMemory +from cpython.datetime cimport datetime_new, import_datetime from cython.view cimport array as cvarray from ipaddress import IPv4Address from uuid import UUID, SafeUUID from libc.string cimport memcpy from datetime import tzinfo -from clickhouse_connect.driver import tzutil +from clickhouse_connect.driver import tzutil, options +from clickhouse_connect.driver.common import must_swap from clickhouse_connect.driver.errors import NONE_IN_NULLABLE_COLUMN +from clickhouse_connect.driver.exceptions import DataError + +# Initialize datetime C API for direct object construction +import_datetime() @cython.boundscheck(False) @cython.wraparound(False) @@ -52,14 +59,23 @@ def read_datetime_col(ResponseBuffer buffer, unsigned long long num_rows, tzinfo cdef char * loc = buffer.read_bytes_c(4 * num_rows) cdef object column = PyTuple_New(num_rows), v if tzinfo is None: - fts = tzutil.utcfromtimestamp + # Fast path: naive UTC, construct datetime directly via C API + while x < num_rows: + v = _epoch_to_datetime((<unsigned int*>loc)[0], 0, None) + PyTuple_SET_ITEM(column, x, v) + Py_INCREF(v) + loc += 4 + x += 1 + elif tzutil.is_utc_timezone(tzinfo): + # Fast path: UTC-equivalent timezone, direct C API construction with tzinfo while x < num_rows: - v = fts((<unsigned int*>loc)[0]) + v = _epoch_to_datetime((<unsigned int*>loc)[0], 0, tzinfo) PyTuple_SET_ITEM(column, x, v) Py_INCREF(v) loc += 4 x += 1 else: + # Slow path: non-UTC timezone, requires fromtimestamp for DST-aware conversion fts = datetime.fromtimestamp while x < num_rows: v = fts((<unsigned int*>loc)[0], tzinfo) @@ -145,7 +161,7 @@ cpdef inline object epoch_days_to_date(int days): year = (cycles << 2) + cycles400 * 400 + cycles100 * 100 + years + 1601 if years == 4 or cycles100 == 4: return date(year - 1, 12, 31) - if years == 3 and (year == 2000 or year % 100 != 0): + if years == 3 and year % 100 != 0: m_list = MONTH_DAYS_LEAP else: m_list = MONTH_DAYS @@ -157,6 +173,130 @@ cpdef inline object epoch_days_to_date(int days): return date(year, month + 1, rem + 1 - prev) [email protected](True) [email protected](False) [email protected](False) +cdef inline void _epoch_days_to_components_c(int days, int* out_year, int* out_month, int* out_day): + """Convert days since epoch to (year, month, day) components without allocating objects. + + Low-level helper that computes calendar components directly, reusing the fast + epoch_days_to_date algorithm but returning components as output parameters. + + Args: + days: Days since epoch + out_year, out_month, out_day: Pointers to store results + """ + cdef int years, month, year, cycles400, cycles100, cycles, rem + cdef unsigned short prev + cdef unsigned short* m_list + + if 0 <= days < 47482: + cycles = (days + 365) // 1461 + rem = (days + 365) - cycles * 1461 + years = rem // 365 + rem -= years * 365 + year = (cycles << 2) + years + 1969 + if years == 4: + out_year[0] = year - 1 + out_month[0] = 12 + out_day[0] = 31 + return + if years == 3: + m_list = MONTH_DAYS_LEAP + else: + m_list = MONTH_DAYS + else: + cycles400 = (days + 134774) // 146097 + rem = days + 134774 - (cycles400 * 146097) + cycles100 = rem // 36524 + rem -= cycles100 * 36524 + cycles = rem // 1461 + rem -= cycles * 1461 + years = rem // 365 + rem -= years * 365 + year = (cycles << 2) + cycles400 * 400 + cycles100 * 100 + years + 1601 + if years == 4 or cycles100 == 4: + out_year[0] = year - 1 + out_month[0] = 12 + out_day[0] = 31 + return + if years == 3 and year % 100 != 0: + m_list = MONTH_DAYS_LEAP + else: + m_list = MONTH_DAYS + + month = (rem + 24) >> 5 + prev = m_list[month] + while rem < prev: + month -= 1 + prev = m_list[month] + + out_year[0] = year + out_month[0] = month + 1 + out_day[0] = rem + 1 - prev + + [email protected](True) [email protected](False) [email protected](False) +cpdef inline tuple epoch_seconds_to_components(long long seconds): + """Convert epoch seconds to (year, month, day, hour, minute, second, microsecond). + + This decomposes a Unix timestamp into datetime components without creating + intermediate objects. Handles both positive and negative epoch values correctly. + + Args: + seconds: Unix timestamp (seconds since 1970-01-01 00:00:00 UTC) + + Returns: + Tuple of (year, month, day, hour, minute, second, microsecond) + """ + cdef long long days, secs_in_day + cdef int hour, minute, second + cdef int year, month, day + + days = seconds // 86400 + secs_in_day = seconds - days * 86400 + if secs_in_day < 0: + secs_in_day += 86400 + days -= 1 + + _epoch_days_to_components_c(days, &year, &month, &day) + + hour = secs_in_day // 3600 + secs_in_day %= 3600 + minute = secs_in_day // 60 + second = secs_in_day % 60 + + return (year, month, day, hour, minute, second, 0) + + +cdef inline object _epoch_to_datetime(long long seconds, int microseconds, object tz): + """Construct datetime directly from epoch seconds via C API, bypassing tuple + Python constructor. + + Uses cpython.datetime.datetime_new which calls PyDateTimeAPI factory directly, + avoiding intermediate tuple allocation and Python-level datetime(...) overhead. + """ + cdef long long days, secs_in_day + cdef int hour, minute, second + cdef int year, month, day + + days = seconds // 86400 + secs_in_day = seconds - days * 86400 + if secs_in_day < 0: + secs_in_day += 86400 + days -= 1 + + _epoch_days_to_components_c(days, &year, &month, &day) + + hour = secs_in_day // 3600 + secs_in_day %= 3600 + minute = secs_in_day // 60 + second = secs_in_day % 60 + + return datetime_new(year, month, day, hour, minute, second, microseconds, tz, 0) + + @cython.boundscheck(False) @cython.wraparound(False) def read_uuid_col(ResponseBuffer buffer, unsigned long long num_rows): @@ -181,6 +321,84 @@ def read_uuid_col(ResponseBuffer buffer, unsigned long long num_rows): @cython.boundscheck(False) @cython.wraparound(False) +def read_datetime64_naive_col(object column: Sequence, unsigned long long prec, tz: tzinfo = None): + """Read DateTime64 column using epoch arithmetic, for naive UTC or UTC-equivalent timezones. + + Constructs datetime objects directly from epoch seconds components via the + CPython datetime C API. When tz is None, the result is naive. When tz is a + UTC-equivalent timezone (UTC, Etc/UTC, GMT, etc.), the same arithmetic path + is used and the tz is attached to the constructed datetime. + + Args: + column: Sequence of integer ticks + prec: Precision divisor (10**scale) + tz: Optional UTC-equivalent timezone to attach. Must be None or UTC-equivalent + (no DST or offset conversion is performed by this function). + + Returns: + Tuple of datetime objects with microseconds, naive when tz is None + """ + cdef unsigned long long x = 0 + cdef unsigned long long num_rows = len(column) + cdef object result = PyTuple_New(num_rows), v + cdef long long ticks, seconds, fractional_ticks + cdef unsigned long long microseconds + cdef long long prec_signed = <long long>prec + + for x in range(num_rows): + ticks = column[x] + seconds = ticks // prec_signed + fractional_ticks = ticks - seconds * prec_signed + microseconds = (fractional_ticks * 1000000) // prec_signed + + v = _epoch_to_datetime(seconds, microseconds, tz) + PyTuple_SET_ITEM(result, x, v) + Py_INCREF(v) + + return result + + [email protected](False) [email protected](False) +def read_datetime64_tz_col(object column: Sequence, unsigned long long prec, tzinfo: tzinfo): + """Read DateTime64 column with timezone conversion using per-row fromtimestamp. + + This handles non-UTC timezone conversion where DST-aware logic is necessary. + The loop is in Cython for speed of the per-row datetime construction. + + Args: + column: Sequence of integer ticks + prec: Precision divisor (10**scale) + tzinfo: Target timezone object + + Returns: + List of datetime objects with specified timezone and microseconds + """ + cdef unsigned long long x = 0 + cdef unsigned long long num_rows = len(column) + cdef object result = PyTuple_New(num_rows), v + cdef long long ticks, seconds, fractional_ticks + cdef unsigned long long microseconds + cdef object dt_from = datetime.fromtimestamp + cdef long long prec_signed = <long long>prec + + for x in range(num_rows): + ticks = column[x] + seconds = ticks // prec_signed + fractional_ticks = ticks - seconds * prec_signed + microseconds = (fractional_ticks * 1000000) // prec_signed + + v = dt_from(seconds, tzinfo) + if microseconds != 0: + v = v.replace(microsecond=microseconds) + PyTuple_SET_ITEM(result, x, v) + Py_INCREF(v) + + return result + + [email protected](False) [email protected](False) def read_nullable_array(ResponseBuffer buffer, array_type: str, unsigned long long num_rows, object null_obj): if num_rows == 0: return [] @@ -303,3 +521,117 @@ def write_str_col(column: Sequence, nullable: bool, encoding: Optional[str], des mv.release() PyMem_Free(<void *>temp_buff) return 0 + + +# Mapping of struct format codes to expected numpy dtype kind +_code_to_kind = { + 'b': 'i', 'h': 'i', 'i': 'i', 'l': 'i', 'q': 'i', + 'B': 'u', 'H': 'u', 'I': 'u', 'L': 'u', 'Q': 'u', + 'f': 'f', 'd': 'f', +} + + [email protected](False) [email protected](False) +def write_native_col(str code, column, bytearray dest, object col_name=None) -> int: + """ + Write a column of fixed-width values directly into dest bytearray. + Fast-paths C-contiguous numpy arrays with matching dtype via memcpy. + Falls back to struct.pack for Python sequences. + """ + cdef Py_ssize_t old_size = PyByteArray_GET_SIZE(dest) + cdef Py_buffer view + cdef object dtype + cdef str byteorder + cdef str expected_kind + cdef object np + cdef Py_ssize_t num_rows + + # Numpy fast path: check if array is 1-D, contiguous, matching kind+size, little-endian + np = options.np + if np is not None and isinstance(column, np.ndarray): + dtype = column.dtype + byteorder = dtype.byteorder + expected_kind = _code_to_kind.get(code, None) + expected_size = array.array(code).itemsize + + # Check all safety conditions for memcpy + if (column.ndim == 1 and + column.flags['C_CONTIGUOUS'] and + expected_kind is not None and + dtype.kind == expected_kind and + dtype.itemsize == expected_size and + dtype.kind != 'O' and # no object arrays + (byteorder == '<' or + (byteorder in ('=', '|') and sys.byteorder == 'little'))): + + # All checks passed so do direct memcpy + PyObject_GetBuffer(column, &view, PyBUF_SIMPLE) + try: + PyByteArray_Resize(dest, old_size + view.len) + memcpy(PyByteArray_AS_STRING(dest) + old_size, view.buf, view.len) + finally: + PyBuffer_Release(&view) + return 0 + + # General fallback: struct.pack with C-level argument unpacking, then append to dest. + num_rows = len(column) + try: + dest += struct.Struct(f"<{num_rows}{code}").pack(*column) + except (TypeError, OverflowError, struct.error) as ex: + col_msg = f" for column `{str(col_name)}`" if col_name else "" + error_detail = type(ex).__name__ + if isinstance(ex, OverflowError): + error_detail = "value out of range" + elif isinstance(ex, TypeError): + error_detail = "type mismatch (usually None in non-Nullable column)" + raise DataError( + f"Unable to create native array{col_msg}: {error_detail}" + ) from ex + return 0 + + +cdef inline unsigned long long _bswap_uint64(unsigned long long v): + """Byte-swap a 64-bit unsigned integer for big-endian systems.""" + return (((v & 0xFF) << 56) | (((v >> 8) & 0xFF) << 48) | + (((v >> 16) & 0xFF) << 40) | (((v >> 24) & 0xFF) << 32) | + (((v >> 32) & 0xFF) << 24) | (((v >> 40) & 0xFF) << 16) | + (((v >> 48) & 0xFF) << 8) | ((v >> 56) & 0xFF)) + + [email protected](False) [email protected](False) +def build_map_columns(column, bytearray dest): + """ + Flatten a column of dicts into (keys, values) lists and write UInt64 offsets into dest. + Uses two-pass strategy: first compute offsets, pre-allocate lists, then fill by index. + """ + cdef unsigned long long num_rows = len(column) + cdef unsigned long long total = 0, ix = 0, old_size + cdef Py_ssize_t offset_bytes = num_rows * 8 + cdef char* dest_ptr + cdef unsigned long long offset_value + cdef unsigned long long i + + # First pass: compute offsets and total entry count, write into dest via memcpy (safe for alignment) + old_size = PyByteArray_GET_SIZE(dest) + PyByteArray_Resize(dest, old_size + offset_bytes) + dest_ptr = PyByteArray_AS_STRING(dest) + old_size + + for i, v in enumerate(column): + total += len(v) + offset_value = total + if must_swap: + offset_value = _bswap_uint64(offset_value) + memcpy(dest_ptr + i * 8, &offset_value, 8) + + # Pre-allocate lists at exact size, second pass fills by index + keys = [None] * total + values = [None] * total + for v in column: + for k, val in v.items(): + keys[ix] = k + values[ix] = val + ix += 1 + + return keys, values diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/entry_points.py b/contrib/python/clickhouse-connect/clickhouse_connect/entry_points.py index 9981e98a390..8d4b0518d0c 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/entry_points.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/entry_points.py @@ -2,37 +2,35 @@ # This script is used for validating installed entrypoints. Note that it fails on Python 3.7 import sys - from importlib.metadata import PackageNotFoundError, distribution -EXPECTED_EPS = {'sqlalchemy.dialects:clickhousedb', - 'sqlalchemy.dialects:clickhousedb.connect'} +EXPECTED_EPS = {"sqlalchemy.dialects:clickhousedb", "sqlalchemy.dialects:clickhousedb.connect"} def validate_entrypoints(): expected_eps = EXPECTED_EPS.copy() try: - dist = distribution('clickhouse-connect') + dist = distribution("clickhouse-connect") except PackageNotFoundError: - print ('\nClickHouse Connect package not found in this Python installation') + print("\nClickHouse Connect package not found in this Python installation") return -1 print() for entry_point in dist.entry_points: - name = f'{entry_point.group}:{entry_point.name}' - print(f' {name}={entry_point.value}') + name = f"{entry_point.group}:{entry_point.name}" + print(f" {name}={entry_point.value}") try: expected_eps.remove(name) except KeyError: - print (f'\nUnexpected entry point {name} found') + print(f"\nUnexpected entry point {name} found") return -1 if expected_eps: print() for name in expected_eps: - print (f'Did not find expected ep {name}') + print(f"Did not find expected ep {name}") return -1 - print ('\nEntrypoints correctly installed') + print("\nEntrypoints correctly installed") return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(validate_entrypoints()) diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/json_impl.py b/contrib/python/clickhouse-connect/clickhouse_connect/json_impl.py index 686ddf39555..a05f55aea5c 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/json_impl.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/json_impl.py @@ -1,11 +1,12 @@ -import logging import json as py_json +import logging from collections import OrderedDict from typing import Any try: import orjson - any_to_json = orjson.dumps # pylint: disable=no-member + + any_to_json = orjson.dumps except ImportError: orjson = None @@ -13,36 +14,36 @@ try: import ujson def _ujson_to_json(obj: Any) -> bytes: - return ujson.dumps(obj).encode() # pylint: disable=c-extension-no-member + return ujson.dumps(obj).encode() except ImportError: ujson = None _ujson_to_json = None def _pyjson_to_json(obj: Any) -> bytes: - return py_json.dumps(obj, separators=(',', ':')).encode() + return py_json.dumps(obj, separators=(",", ":")).encode() logger = logging.getLogger(__name__) _to_json = OrderedDict() -_to_json['orjson'] = orjson.dumps if orjson else None # pylint: disable=no-member -_to_json['ujson'] = _ujson_to_json if ujson else None -_to_json['python'] = _pyjson_to_json +_to_json["orjson"] = orjson.dumps if orjson else None +_to_json["ujson"] = _ujson_to_json if ujson else None +_to_json["python"] = _pyjson_to_json any_to_json = _pyjson_to_json def set_json_library(impl: str = None): - global any_to_json # pylint: disable=global-statement + global any_to_json if impl: func = _to_json.get(impl) if func: any_to_json = func return - raise NotImplementedError(f'JSON library {impl} is not supported') + raise NotImplementedError(f"JSON library {impl} is not supported") for library, func in _to_json.items(): if func: - logger.debug('Using %s library for writing JSON byte strings', library) + logger.debug("Using %s library for writing JSON byte strings", library) any_to_json = func break diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/tools/datagen.py b/contrib/python/clickhouse-connect/clickhouse_connect/tools/datagen.py index f956d382a71..6a231ee58ea 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/tools/datagen.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/tools/datagen.py @@ -1,20 +1,19 @@ import struct import uuid +from collections.abc import Callable, Sequence +from datetime import date, datetime, timedelta, timezone, tzinfo from decimal import Decimal as PyDecimal from ipaddress import IPv4Address, IPv6Address -from random import random, choice -from typing import Sequence, Union, NamedTuple, Callable, Type, Dict -from datetime import date, datetime, timedelta, tzinfo - -import pytz +from random import choice, random +from typing import NamedTuple from clickhouse_connect.datatypes.base import ClickHouseType -from clickhouse_connect.datatypes.container import Array, Tuple, Map, Nested +from clickhouse_connect.datatypes.container import Array, Map, Nested, Tuple from clickhouse_connect.datatypes.network import IPv4, IPv6 -from clickhouse_connect.datatypes.numeric import BigInt, Float32, Float64, Enum, Bool, Boolean, Decimal +from clickhouse_connect.datatypes.numeric import BigInt, Bool, Boolean, Decimal, Enum, Float32, Float64 from clickhouse_connect.datatypes.registry import get_from_name from clickhouse_connect.datatypes.special import UUID -from clickhouse_connect.datatypes.string import String, FixedString +from clickhouse_connect.datatypes.string import FixedString, String from clickhouse_connect.datatypes.temporal import Date, Date32, DateTime, DateTime64 from clickhouse_connect.driver import tzutil from clickhouse_connect.driver.common import array_sizes @@ -29,14 +28,15 @@ class RandomValueDef(NamedTuple): """ Parameter object to control the generation of random data values for testing """ - server_tz: tzinfo = pytz.UTC + + server_tz: tzinfo = timezone.utc null_pct: float = 0.15 str_len: int = 200 arr_len: int = 12 ascii_only: bool = False -def random_col_data(ch_type: Union[str, ClickHouseType], cnt: int, col_def: RandomValueDef = RandomValueDef()): +def random_col_data(ch_type: str | ClickHouseType, cnt: int, col_def: RandomValueDef = RandomValueDef()): # noqa: B008 """ Generate a column of random data for insert tests :param ch_type: ClickHouseType or ClickHouse type name @@ -53,7 +53,6 @@ def random_col_data(ch_type: Union[str, ClickHouseType], cnt: int, col_def: Rand return tuple(gen() for _ in range(cnt)) -# pylint: disable=too-many-return-statements,too-many-branches,protected-access def random_value_gen(ch_type: ClickHouseType, col_def: RandomValueDef): """ Returns a generator function of random values of the requested ClickHouseType @@ -63,7 +62,7 @@ def random_value_gen(ch_type: ClickHouseType, col_def: RandomValueDef): """ if ch_type.__class__ in gen_map: return gen_map[ch_type.__class__] - if isinstance(ch_type, BigInt) or ch_type.python_type == int: + if isinstance(ch_type, BigInt) or ch_type.python_type is int: if isinstance(ch_type, BigInt): sz = 2 ** (ch_type.byte_size * 8) signed = ch_type._signed @@ -92,19 +91,19 @@ def random_value_gen(ch_type: ClickHouseType, col_def: RandomValueDef): return lambda: random_ascii_str(col_def.str_len) return lambda: random_utf8_str(col_def.str_len) if isinstance(ch_type, FixedString): - return lambda: bytes((int(random() * 256) for _ in range(ch_type.byte_size))) + return lambda: bytes(int(random() * 256) for _ in range(ch_type.byte_size)) if isinstance(ch_type, DateTime): - if col_def.server_tz == pytz.UTC: + if tzutil.is_utc_timezone(col_def.server_tz): return random_datetime - timezone = col_def.server_tz - return lambda: random_datetime_tz(timezone) + tz = col_def.server_tz + return lambda: random_datetime_tz(tz) if isinstance(ch_type, DateTime64): prec = ch_type.prec - if col_def.server_tz == pytz.UTC: + if tzutil.is_utc_timezone(col_def.server_tz): return lambda: random_datetime64(prec) - timezone = col_def.server_tz - return lambda: random_datetime64_tz(prec, timezone) - raise ValueError(f'Invalid ClickHouse type {ch_type.name} for random column data') + tz = col_def.server_tz + return lambda: random_datetime64_tz(prec, tz) + raise ValueError(f"Invalid ClickHouse type {ch_type.name} for random column data") def random_float(): @@ -113,15 +112,15 @@ def random_float(): def random_float32(): f64 = (random() * random() * 65536) / (random() * (random() * 256 - 128)) - return struct.unpack('f', struct.pack('f', f64))[0] + return struct.unpack("f", struct.pack("f", f64))[0] def random_decimal(prec: int, scale: int): - digits = ''.join(str(int(random() * 12000000000)) for _ in range(prec // 10 + 1)).rjust(prec, '0')[:prec] - sign = '' if ord(digits[0]) & 0x01 else '-' + digits = "".join(str(int(random() * 12000000000)) for _ in range(prec // 10 + 1)).rjust(prec, "0")[:prec] + sign = "" if ord(digits[0]) & 0x01 else "-" if scale == 0: - return PyDecimal(f'{sign}{digits}') - return PyDecimal(f'{sign}{digits[:-scale]}.{digits[-scale:]}') + return PyDecimal(f"{sign}{digits}") + return PyDecimal(f"{sign}{digits[:-scale]}.{digits[-scale:]}") def random_tuple(element_types: Sequence[ClickHouseType], col_def): @@ -135,24 +134,24 @@ def random_map(key_type, value_type, sz: int, col_def): def random_datetime(): - return dt_from_ts(int(random() * 2 ** 32)).replace(microsecond=0) + return dt_from_ts(int(random() * 2**32)).replace(microsecond=0) def random_datetime_tz(timezone: tzinfo): - return dt_from_ts_tz(int(random() * 2 ** 32), timezone).replace(microsecond=0) + return dt_from_ts_tz(int(random() * 2**32), timezone).replace(microsecond=0) def random_ascii_str(max_len: int = 200, min_len: int = 0): - return ''.join((chr(int(random() * 95) + 32) for _ in range(int(random() * (max_len - min_len)) + min_len))) + return "".join(chr(int(random() * 95) + 32) for _ in range(int(random() * (max_len - min_len)) + min_len)) def random_utf8_str(max_len: int = 200): random_chars = [chr(int(random() * 65000) + 32) for _ in range(int(random() * max_len))] - return ''.join((c for c in random_chars if c.isprintable())) + return "".join(c for c in random_chars if c.isprintable()) def fixed_len_ascii_str(str_len: int = 200): - return ''.join((chr(int(random() * 95) + 32) for _ in range(str_len))) + return "".join(chr(int(random() * 95) + 32) for _ in range(str_len)) # Only accepts precisions in multiples of 3 because others are extremely unlikely to be actually used @@ -179,11 +178,15 @@ def random_datetime64_tz(prec: int, timezone: tzinfo): def random_ipv6(): if random() > 0.2: # multiple randoms because of random float multiply limitations - ip_int = (int(random() * 4294967296) << 96) | (int(random() * 4294967296)) | ( - int(random() * 4294967296) << 32) | ( int(random() * 4294967296) << 64) + ip_int = ( + (int(random() * 4294967296) << 96) + | (int(random() * 4294967296)) + | (int(random() * 4294967296) << 32) + | (int(random() * 4294967296) << 64) + ) return IPv6Address(ip_int) # Return mapped IPv4 as IPv6 - ipv4_int = int(random() * 2 ** 32) + ipv4_int = int(random() * 2**32) return IPv6Address(f"::ffff:{IPv4Address(ipv4_int)}") @@ -198,7 +201,7 @@ def random_nested(keys: Sequence[str], types: Sequence[ClickHouseType], col_def: return row -gen_map: Dict[Type[ClickHouseType], Callable] = { +gen_map: dict[type[ClickHouseType], Callable] = { Float64: random_float, Float32: random_float32, Date: lambda: epoch_date + timedelta(days=int(random() * 65536)), @@ -206,6 +209,6 @@ gen_map: Dict[Type[ClickHouseType], Callable] = { UUID: uuid.uuid4, IPv4: lambda: IPv4Address(int(random() * 4294967296)), IPv6: random_ipv6, - Boolean: lambda: random() > .5, - Bool: lambda: random() > .5 + Boolean: lambda: random() > 0.5, + Bool: lambda: random() > 0.5, } diff --git a/contrib/python/clickhouse-connect/clickhouse_connect/tools/testing.py b/contrib/python/clickhouse-connect/clickhouse_connect/tools/testing.py index 56ced725016..b5b5274a919 100644 --- a/contrib/python/clickhouse-connect/clickhouse_connect/tools/testing.py +++ b/contrib/python/clickhouse-connect/clickhouse_connect/tools/testing.py @@ -1,32 +1,36 @@ -from typing import Sequence, Optional, Union, Dict, Any +from collections.abc import Sequence +from typing import Any from clickhouse_connect.driver import Client from clickhouse_connect.driver.binding import quote_identifier, str_query_value class TableContext: - def __init__(self, client: Client, - table: str, - columns: Union[str, Sequence[str]], - column_types: Optional[Sequence[str]] = None, - engine: str = 'MergeTree', - order_by: str = None, - settings: Optional[Dict[str, Any]] = None): + def __init__( + self, + client: Client, + table: str, + columns: str | Sequence[str], + column_types: Sequence[str] | None = None, + engine: str = "MergeTree", + order_by: str = None, + settings: dict[str, Any] | None = None, + ): self.client = client - if '.' in table: + if "." in table: self.table = table else: self.table = quote_identifier(table) self.settings = settings if isinstance(columns, str): - columns = columns.split(',') + columns = columns.split(",") if column_types is None: self.column_names = [] self.column_types = [] for col in columns: col = col.strip() - ix = col.find(' ') - self.column_types.append(col[ix + 1:].strip()) + ix = col.find(" ") + self.column_types.append(col[ix + 1 :].strip()) self.column_names.append(quote_identifier(col[:ix].strip())) else: self.column_names = [quote_identifier(name) for name in columns] @@ -35,20 +39,20 @@ class TableContext: self.order_by = self.column_names[0] if order_by is None else order_by def __enter__(self): - if self.client.min_version('19'): - self.client.command(f'DROP TABLE IF EXISTS {self.table}') + if self.client.min_version("19"): + self.client.command(f"DROP TABLE IF EXISTS {self.table}") else: - self.client.command(f'DROP TABLE IF EXISTS {self.table} SYNC') - col_defs = ','.join(f'{quote_identifier(name)} {col_type}' for name, col_type in zip(self.column_names, self.column_types)) - create_cmd = f'CREATE TABLE {self.table} ({col_defs}) ENGINE {self.engine} ORDER BY {self.order_by}' + self.client.command(f"DROP TABLE IF EXISTS {self.table} SYNC") + col_defs = ",".join(f"{quote_identifier(name)} {col_type}" for name, col_type in zip(self.column_names, self.column_types)) + create_cmd = f"CREATE TABLE {self.table} ({col_defs}) ENGINE {self.engine} ORDER BY {self.order_by}" if self.settings: - create_cmd += ' SETTINGS ' + create_cmd += " SETTINGS " for key, value in self.settings.items(): - create_cmd += f'{key} = {str_query_value(value)}, ' - if create_cmd.endswith(', '): + create_cmd += f"{key} = {str_query_value(value)}, " + if create_cmd.endswith(", "): create_cmd = create_cmd[:-2] self.client.command(create_cmd) return self def __exit__(self, exc_type, exc_val, exc_tb): - self.client.command(f'DROP TABLE IF EXISTS {self.table}') + self.client.command(f"DROP TABLE IF EXISTS {self.table}") diff --git a/contrib/python/clickhouse-connect/ya.make b/contrib/python/clickhouse-connect/ya.make index 592abf5f673..f8f72d103b4 100644 --- a/contrib/python/clickhouse-connect/ya.make +++ b/contrib/python/clickhouse-connect/ya.make @@ -2,14 +2,13 @@ PY3_LIBRARY() -VERSION(0.15.1) +VERSION(1.0.0) LICENSE(Apache-2.0) PEERDIR( contrib/python/certifi contrib/python/lz4 - contrib/python/pytz contrib/python/urllib3 contrib/python/zstandard ) @@ -24,12 +23,13 @@ NO_LINT() NO_CHECK_IMPORTS( clickhouse_connect.cc_sqlalchemy.* + clickhouse_connect.driver.asyncclient ) PY_SRCS( TOP_LEVEL clickhouse_connect/__init__.py - clickhouse_connect/__version__.py + clickhouse_connect/_version.py clickhouse_connect/cc_sqlalchemy/__init__.py clickhouse_connect/cc_sqlalchemy/datatypes/__init__.py clickhouse_connect/cc_sqlalchemy/datatypes/base.py @@ -64,6 +64,7 @@ PY_SRCS( clickhouse_connect/dbapi/cursor.py clickhouse_connect/driver/__init__.py clickhouse_connect/driver/asyncclient.py + clickhouse_connect/driver/asyncqueue.py clickhouse_connect/driver/binding.py clickhouse_connect/driver/buffer.py clickhouse_connect/driver/bytesource.py @@ -87,6 +88,7 @@ PY_SRCS( clickhouse_connect/driver/options.py clickhouse_connect/driver/parser.py clickhouse_connect/driver/query.py + clickhouse_connect/driver/streaming.py clickhouse_connect/driver/summary.py clickhouse_connect/driver/tools.py clickhouse_connect/driver/transform.py |
