aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/clickhouse-connect/clickhouse_connect/driver/context.py
blob: 7984fbeebbe84f3b1d0b99ecb4f160a6affc423a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import logging
import re
from datetime import datetime, tzinfo
from typing import Optional, Dict, Union, Any

import pytz

logger = logging.getLogger(__name__)

# Shared, intentionally empty mapping used as the placeholder for the format
# lookups of BaseQueryContext when no formats are configured. Every context
# instance may hold this very same object, so it must never be mutated.
_empty_map = {}


# pylint: disable=too-many-instance-attributes
class BaseQueryContext:
    """Per-query settings plus the rules that choose a read format for each column.

    Formats are resolved by :meth:`active_fmt` with this precedence:

    1. a (non-empty) plain format string configured for the current column,
    2. the current column's mapping of type-name pattern -> format,
    3. the query-wide mapping of type-name pattern -> format.

    A type-name pattern is a literal type name in which ``*`` is a wildcard for
    any run of characters. Matching is case-insensitive and anchored at the
    start of the type name only (``re.Pattern.match``).
    """

    # Local timezone; assigned at import time by the module-level
    # ``_init_context_cls`` function rather than in ``__init__``.
    local_tz: tzinfo

    def __init__(self,
                 settings: Optional[Dict[str, Any]] = None,
                 query_formats: Optional[Dict[str, str]] = None,
                 column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None,
                 encoding: Optional[str] = None,
                 use_extended_dtypes: bool = False,
                 use_numpy: bool = False):
        """
        :param settings: settings for the query; defaults to a new empty dict.
        :param query_formats: mapping of type-name pattern to format, applied to
            every column whose type name matches.
        :param column_formats: mapping of column name to either a format string
            (applied to the whole column) or a mapping of type-name pattern to
            format (applied to matching types within that column).
        :param encoding: stored unchanged on the instance.
        :param use_extended_dtypes: stored unchanged on the instance.
        :param use_numpy: stored unchanged on the instance.
        """
        self.settings = settings or {}
        if query_formats is None:
            self.type_formats = _empty_map
        else:
            self.type_formats = self._compile_type_patterns(query_formats)
        if column_formats is None:
            self.col_simple_formats = _empty_map
            self.col_type_formats = _empty_map
        else:
            self.col_simple_formats = {col_name: fmt for col_name, fmt in column_formats.items()
                                       if isinstance(fmt, str)}
            # Non-string values are per-column {type pattern: format} mappings.
            self.col_type_formats = {col_name: self._compile_type_patterns(type_fmts)
                                     for col_name, type_fmts in column_formats.items()
                                     if not isinstance(type_fmts, str)}
        self.query_formats = query_formats or {}
        self.column_formats = column_formats or {}
        self.encoding = encoding
        self.use_numpy = use_numpy
        self.use_extended_dtypes = use_extended_dtypes
        # No column is active until start_column() is called.
        self._active_col_fmt = None
        self._active_col_type_fmts = _empty_map

    @staticmethod
    def _compile_type_patterns(formats: Dict[str, str]) -> Dict[Any, str]:
        """Compile ``{type_name: fmt}`` into ``{compiled_pattern: fmt}``.

        The type name is matched literally except that ``*`` matches any run of
        characters; matching ignores case.
        """
        # Escape first so characters such as '(' ')' in names like
        # 'Array(String)' match themselves instead of forming a regex group
        # (which never matched, and an unbalanced '(' raised re.error);
        # then turn the escaped '*' back into the wildcard.
        return {re.compile(re.escape(type_name).replace(r'\*', '.*'), re.IGNORECASE): fmt
                for type_name, fmt in formats.items()}

    def start_column(self, name: str):
        """Make ``name`` the current column for subsequent :meth:`active_fmt` calls."""
        self._active_col_fmt = self.col_simple_formats.get(name)
        self._active_col_type_fmts = self.col_type_formats.get(name, _empty_map)

    def active_fmt(self, ch_type):
        """Return the format for type name string ``ch_type`` in the current column.

        A non-empty column-wide format string wins. Otherwise the first matching
        pattern in the column's type mapping is used, then in the query-wide
        mapping. Returns None when nothing matches.
        """
        if self._active_col_fmt:
            return self._active_col_fmt
        for patterns in (self._active_col_type_fmts, self.type_formats):
            for type_pattern, fmt in patterns.items():
                if type_pattern.match(ch_type):
                    return fmt
        return None


def _init_context_cls():
    """Detect the local timezone and store it on ``BaseQueryContext.local_tz``.

    The tzinfo reported by ``datetime.now().astimezone()`` is used as-is, except
    that a zone whose name is one of the common UTC/GMT aliases is replaced by
    ``pytz.UTC``.
    """
    # NOTE(review): astimezone() with no argument yields a fixed-offset tzinfo,
    # so a DST change after import is not reflected here — confirm acceptable.
    detected_tz = datetime.now().astimezone().tzinfo
    utc_aliases = ('UTC', 'GMT', 'Universal', 'GMT-0', 'Zulu', 'Greenwich')
    reports_utc = detected_tz.tzname(datetime.now()) in utc_aliases
    BaseQueryContext.local_tz = pytz.UTC if reports_utc else detected_tz


# Resolve the local timezone once, when the module is imported.
_init_context_cls()