summaryrefslogtreecommitdiffstats
path: root/contrib/python/clickhouse-connect/clickhouse_connect/cc_sqlalchemy/alembic/impl.py
blob: 771692ad015596c1530f7ac03e6b6308295424f5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
from __future__ import annotations

from types import SimpleNamespace
from typing import Any

from alembic.ddl.impl import DefaultImpl
from alembic.util import CommandError
from sqlalchemy import Column, MetaData, String, Table, text
from sqlalchemy.sql.dml import Delete, Update

from clickhouse_connect.cc_sqlalchemy.datatypes.base import ChSqlaType, sqla_type_from_name
from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import Array as ChSqlaArray
from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import Enum as ChSqlaEnum
from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import Map as ChSqlaMap
from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import Nullable
from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import Tuple as ChSqlaTuple
from clickhouse_connect.cc_sqlalchemy.ddl.tableengine import MergeTree
from clickhouse_connect.cc_sqlalchemy.sql import full_table
from clickhouse_connect.cc_sqlalchemy.sql.ddlcompiler import (
    ClickHouseDDLHelper,
    column_specification,
)
from clickhouse_connect.driver.binding import quote_identifier


def _render_ch_type(type_obj):
    """Render a ChSqlaType as valid Python source for autogen migrations"""
    wrappers = type_obj.type_def.wrappers
    if isinstance(type_obj, ChSqlaEnum):
        keys = list(type_obj.type_def.keys)
        values = list(type_obj.type_def.values)
        rendered = f"{type_obj.__class__.__name__}(keys={keys!r}, values={values!r})"
    elif isinstance(type_obj, ChSqlaArray):
        rendered = f"Array({_render_inner(type_obj.type_def.values[0])})"
    elif isinstance(type_obj, ChSqlaMap):
        key, value = type_obj.type_def.values
        rendered = f"Map({_render_inner(key)}, {_render_inner(value)})"
    elif isinstance(type_obj, ChSqlaTuple):
        elements = ", ".join(_render_inner(v) for v in type_obj.type_def.values)
        rendered = f"Tuple({elements})"
    else:
        return str(type_obj.name)
    for wrapper in reversed(wrappers):
        rendered = f"{wrapper}({rendered})"
    return rendered


def _render_inner(name):
    return _render_ch_type(sqla_type_from_name(name))


class ClickHouseImpl(DefaultImpl):
    __dialect__ = "clickhousedb"
    transactional_ddl = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.context_opts.get("include_schemas") and not self.context_opts.get("version_table_schema") and self.connection is not None:
            current_database = self.connection.execute(text("SELECT currentDatabase()")).scalar()
            if current_database:
                self.context_opts["version_table_schema"] = current_database

    def version_table_impl(
        self,
        *,
        version_table: str,
        version_table_schema: str | None,
        version_table_pk: bool,
        **_kw: Any,
    ) -> Table:
        return Table(
            version_table,
            MetaData(),
            Column("version_num", String(32), nullable=False),
            MergeTree(order_by="version_num"),
            schema=version_table_schema,
        )

    def _exec(
        self,
        construct,
        execution_options=None,
        multiparams=None,
        params=None,
    ) -> Any:
        if isinstance(construct, Update) and self._is_version_table_construct(construct):
            return self._exec_version_update(construct, execution_options)
        if isinstance(construct, Delete) and self._is_version_table_construct(construct):
            return self._exec_version_delete(construct, execution_options)
        return super()._exec(
            construct,
            execution_options=execution_options,
            multiparams=multiparams,
            params=params or {},
        )

    def add_column(
        self,
        table_name: str,
        column: Column,
        *,
        schema: str | None = None,
        if_not_exists: bool | None = None,
        **kw: Any,
    ) -> None:
        sql = [
            "ALTER TABLE",
            full_table(table_name, schema),
            "ADD COLUMN",
        ]
        if if_not_exists:
            sql.append("IF NOT EXISTS")
        sql.append(column_specification(self.dialect, column))
        after = kw.get("after") or ClickHouseDDLHelper.get_option(column, "after")
        if after:
            sql.extend(["AFTER", quote_identifier(after)])
        settings = ClickHouseDDLHelper.render_settings(kw.get("clickhouse_settings"))
        if settings:
            sql.extend(["SETTINGS", settings])
        self._exec(text(" ".join(sql)))

    def drop_column(
        self,
        table_name: str,
        column: Column,
        *,
        schema: str | None = None,
        if_exists: bool | None = None,
        **kw: Any,
    ) -> None:
        sql = ["ALTER TABLE", full_table(table_name, schema), "DROP COLUMN"]
        if if_exists:
            sql.append("IF EXISTS")
        sql.append(quote_identifier(column.name))
        settings = ClickHouseDDLHelper.render_settings(kw.get("clickhouse_settings"))
        if settings:
            sql.extend(["SETTINGS", settings])
        self._exec(text(" ".join(sql)))

    def create_table_comment(self, table: Table) -> None:
        self._exec(text(self._comment_table_sql(table, table.comment)))

    def drop_table_comment(self, table: Table) -> None:
        self._exec(text(self._comment_table_sql(table, None)))

    def alter_column(
        self,
        table_name: str,
        column_name: str,
        *,
        nullable: bool | None = None,
        server_default=False,
        name: str | None = None,
        type_=None,
        schema: str | None = None,
        autoincrement: bool | None = None,
        comment=False,
        existing_comment: str | None = None,
        existing_type=None,
        existing_server_default=None,
        existing_nullable: bool | None = None,
        existing_autoincrement: bool | None = None,
        if_exists: bool | None = None,
        **kw: Any,
    ) -> None:
        if autoincrement is not None or existing_autoincrement is not None:
            return
        if name is not None:
            rename_sql = ["ALTER TABLE", full_table(table_name, schema), "RENAME COLUMN"]
            if if_exists:
                rename_sql.append("IF EXISTS")
            rename_sql.extend([quote_identifier(column_name), "TO", quote_identifier(name)])
            self._exec(text(" ".join(rename_sql)))
            column_name = name

        settings = ClickHouseDDLHelper.render_settings(kw.get("clickhouse_settings"))
        will_modify = nullable is not None or server_default is not False or type_ is not None

        if comment is not False and not will_modify:
            self._exec(text(self._comment_column_sql(table_name, column_name, comment, schema, settings)))

        if not will_modify:
            return

        if type_ is not None:
            effective_type = type_
        else:
            effective_type = existing_type
        if effective_type is None:
            raise CommandError(f"ClickHouse alter_column requires existing_type for {table_name}.{column_name}")
        if nullable is not None:
            effective_type = self._set_type_nullable(effective_type, nullable)

        sql = [
            "ALTER TABLE",
            full_table(table_name, schema),
            "MODIFY COLUMN",
        ]
        if if_exists:
            sql.append("IF EXISTS")
        sql.append(
            column_specification(
                self.dialect,
                Column(
                    column_name,
                    effective_type,
                    server_default=None if server_default is False else server_default,
                    comment=existing_comment if comment is False else comment,
                ),
            )
        )
        if settings:
            sql.extend(["SETTINGS", settings])
        self._exec(text(" ".join(sql)))

    def compare_type(self, inspector_column, metadata_column) -> bool:
        inspector_type = inspector_column.type
        metadata_type = metadata_column.type
        explicit_nullable = ClickHouseDDLHelper.explicit_column_nullable(metadata_column)
        if explicit_nullable is None and isinstance(inspector_type, ChSqlaType) and isinstance(metadata_type, ChSqlaType):
            inspector_type = ClickHouseDDLHelper.without_nullable(inspector_type)
            metadata_type = ClickHouseDDLHelper.without_nullable(metadata_type)
        else:
            metadata_type = ClickHouseDDLHelper.effective_column_type(metadata_column)
        inspector_type = self._normalize_type_name(inspector_type)
        metadata_type = self._normalize_type_name(metadata_type)
        return inspector_type != metadata_type

    def compare_server_default(
        self,
        inspector_column,
        metadata_column,
        rendered_metadata_default,
        rendered_inspector_default,
    ):
        return self._normalize_default(rendered_inspector_default) != self._normalize_default(rendered_metadata_default)

    def render_type(self, type_obj, autogen_context):
        if not isinstance(type_obj, ChSqlaType):
            return False
        return _render_ch_type(type_obj)

    def _exec_version_update(self, construct: Update, execution_options=None):
        # Alembic emits a normal SQLAlchemy Update here, but ClickHouse version tracking
        # needs insert + mutation delete semantics. SQLAlchemy does not expose a stable
        # public API for these values across versions, so this depends on the current
        # Update internals.
        values = construct._values
        if not values:
            raise CommandError("ClickHouse Alembic version update is missing values")
        version_value = self._compile_clause(list(values.values())[0])
        where_clause = self._compile_version_where(construct)
        self._exec(text(f"INSERT INTO {self._version_table_name} (version_num) VALUES ({version_value})"))
        self._exec(text(f"ALTER TABLE {self._version_table_name} DELETE WHERE {where_clause} SETTINGS mutations_sync = 2"))
        return SimpleNamespace(rowcount=1)

    def _exec_version_delete(self, construct: Delete, execution_options=None):
        where_clause = self._compile_version_where(construct)
        return super()._exec(
            text(f"ALTER TABLE {self._version_table_name} DELETE WHERE {where_clause} SETTINGS mutations_sync = 2"),
            execution_options=execution_options,
        )

    @property
    def _version_table_name(self) -> str:
        schema = self.context_opts.get("version_table_schema")
        table = self.context_opts.get("version_table", "alembic_version")
        if schema:
            return f"{quote_identifier(schema)}.{quote_identifier(table)}"
        return quote_identifier(table)

    def _is_version_table_construct(self, construct) -> bool:
        table = getattr(construct, "table", None)
        if table is None:
            return False
        if table.name != self.context_opts.get("version_table", "alembic_version"):
            return False
        expected_schema = self.context_opts.get("version_table_schema")
        # Alembic captures version_table_schema before ClickHouseImpl.__init__
        # has a chance to set it, so the _version Table may have schema=None
        # while context_opts has the auto-detected database name.
        if table.schema == expected_schema:
            return True
        if table.schema is None and expected_schema is not None:
            return True
        return False

    def _compile_version_where(self, construct) -> str:
        predicates = []
        for expression in construct._where_criteria:
            # SQLAlchemy does not provide a public helper for pulling these predicates
            # back apart, so this relies on the current binary expression structure.
            column_name = getattr(getattr(expression, "left", None), "name", None)
            if not column_name:
                predicates.append(self._compile_clause(expression))
                continue
            right = self._compile_clause(expression.right)
            predicates.append(f"{quote_identifier(column_name)} = {right}")
        return " AND ".join(predicates)

    def _compile_clause(self, clause) -> str:
        return str(
            clause.compile(
                dialect=self.dialect,
                compile_kwargs={"literal_binds": True},
            )
        )

    def _comment_column_sql(
        self,
        table_name: str,
        column_name: str,
        comment: str | None,
        schema: str | None,
        settings: str,
    ) -> str:
        sql = [
            "ALTER TABLE",
            full_table(table_name, schema),
            "COMMENT COLUMN",
            quote_identifier(column_name),
            ClickHouseDDLHelper.render_comment(comment),
        ]
        if settings:
            sql.extend(["SETTINGS", settings])
        return " ".join(sql)

    @staticmethod
    def _comment_table_sql(table: Table, comment: str | None) -> str:
        return " ".join(
            [
                "ALTER TABLE",
                full_table(table.name, table.schema),
                "MODIFY COMMENT",
                ClickHouseDDLHelper.render_comment(comment),
            ]
        )

    @staticmethod
    def _normalize_default(default: str | None) -> str | None:
        if default is None:
            return None
        return default.strip()

    @staticmethod
    def _normalize_type_name(type_: Any) -> str:
        if hasattr(type_, "name"):
            return str(type_.name).replace(" ", "")
        return str(type_).replace(" ", "")

    @staticmethod
    def _set_type_nullable(type_: Any, nullable: bool):
        if isinstance(type_, type) and issubclass(type_, ChSqlaType):
            type_ = type_()
        if not isinstance(type_, ChSqlaType):
            return type_
        if nullable:
            if type_.nullable:
                return type_

            return Nullable(type_)
        if not type_.nullable:
            return type_
        wrappers = tuple(wrapper for wrapper in type_.type_def.wrappers if wrapper != "Nullable")
        return type_.__class__(type_def=type_.type_def.__class__(wrappers, type_.type_def.keys, type_.type_def.values))