aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/tests/postgresql/common/__init__.py
blob: e4238ef56a359c960bb7cc3d5d556b4a37fca15d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import sys
import logging
from pathlib import Path
import subprocess
from .differ import Differ


# Logger shared by every helper in this module.
LOGGER = logging.getLogger(__name__)


def setup_logger():
    """Configure root logging: DEBUG and above, ``LEVEL: message`` lines, to stderr.

    This is a thin wrapper over ``logging.basicConfig``. That function does
    nothing if the root logger already has handlers.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(levelname)s: %(message)s',
        # NOTE(review): datefmt is only consulted when the format string
        # contains %(asctime)s, so it currently has no visible effect.
        datefmt='%Y-%m-%d %H:%M:%S',
        stream=sys.stderr,
    )


# Logging is configured as a side effect of importing this package.
setup_logger()


def find_sql_tests(path):
    """Discover SQL test cases among the ``*.sql`` entries directly inside *path*.

    An entry is skipped, with a warning, if it is not a regular file.
    It is also skipped if ``get_out_files`` yields no expected-output file for it.

    Returns:
        A list of ``(name, (sql_file, out_files))`` tuples. ``name`` is the SQL
        file's stem and ``out_files`` is a non-empty list of expected-output
        paths. The list is sorted by SQL file path so that test collection
        (and any parametrization built from it) has the same order on every
        run. ``Path.glob`` itself does not specify an order.
    """
    tests = []

    # sorted(): make the test order independent of directory-listing order.
    for sql_file in sorted(Path(path).glob('*.sql')):
        if not sql_file.is_file():
            LOGGER.warning("'%s' is not a file", sql_file.absolute())
            continue

        out_files = list(get_out_files(sql_file))
        if not out_files:
            LOGGER.warning("No .out files found for '%s'", sql_file.absolute())
            continue

        tests.append((sql_file.stem, (sql_file, out_files)))

    return tests


def load_init_scripts_for_testcase(testcase_name, init_scripts_cfg, init_scripts_dir):
    """Return the init scripts configured for *testcase_name*.

    *init_scripts_cfg* is a text file with entries of the form
    ``<testcase>: <script> [<script> ...]``. Script names are given without
    the ``.sql`` extension and are resolved against *init_scripts_dir*, which
    may be a ``str`` or a ``Path``.

    Parsing rules:
    - Only the first entry whose test-case field matches is used.
    - Blank lines are skipped.
    - Any other line that does not split into exactly two ``:``-separated
      fields is logged and ignored.

    Returns:
        A list of ``Path`` objects in configuration order. Only scripts that
        exist as ``<name>.sql`` in *init_scripts_dir* are included; missing
        ones are logged and skipped. The list is empty when the test case has
        no entry.
    """
    init_scripts_dir = Path(init_scripts_dir)

    script_names = None
    with open(init_scripts_cfg, 'r') as cfg:
        for lineno, line in enumerate(cfg, 1):
            stripped = line.strip()
            if not stripped:
                continue

            cfgline = stripped.split(':')
            if len(cfgline) != 2:
                LOGGER.info("Bad line %d in init scripts configuration '%s'", lineno, init_scripts_cfg)
                continue

            if cfgline[0].strip() == testcase_name:
                script_names = cfgline[1].split()
                break

    if script_names is None:
        return []

    avail_scripts = frozenset(s.stem for s in init_scripts_dir.glob("*.sql"))

    scripts = []
    for name in script_names:
        if name not in avail_scripts:
            LOGGER.info("Init script '%s' for '%s' not found in '%s'", name, testcase_name, init_scripts_dir)
            continue
        # Append the extension instead of calling with_suffix(). with_suffix()
        # would replace everything after an existing dot in the name
        # (e.g. "a.b" -> "a.sql").
        scripts.append(init_scripts_dir / (name + ".sql"))

    if scripts:
        LOGGER.debug("Init scripts: %s", ", ".join(s.stem for s in scripts))

    return scripts


def run_sql_test(sql, out, tmp_path, runner, udfs, init_scripts_cfg, init_scripts_dir):
    """Run one SQL test case through *runner* and check its output.

    The runner is invoked as ``runner --datadir <tmp_path> [--udf <udf> ...]``
    with SQL fed on stdin. It is first run once for each init script
    configured for the test case (see ``load_init_scripts_for_testcase``),
    then for *sql* itself. The test passes as soon as the runner's stdout
    for *sql* matches one of the files in *out* exactly.

    Args:
        sql: Path of the SQL file under test; its stem is the test-case name.
        out: Non-empty sequence of expected-output file paths, tried in order.
        tmp_path: Value passed to the runner as ``--datadir``.
        runner: Program to execute (first element of the argument list).
        udfs: Iterable of values, each passed as ``--udf <value>``.
        init_scripts_cfg: Path of the init-scripts configuration file.
        init_scripts_dir: Directory containing the init ``.sql`` scripts.

    Raises:
        subprocess.CalledProcessError: If any runner invocation exits non-zero.
        AssertionError: If *out* is empty or no expected-output file matches.
            The message carries up to 1024 characters of the smallest diff.
    """
    # Fail early and clearly instead of hitting IndexError on out[0]
    # after the runner has already been executed.
    assert out, f"No expected-output files given for {sql}"

    args = [runner, "--datadir", tmp_path]
    for udf in udfs:
        args.extend(("--udf", udf))

    LOGGER.debug("Loading init scripts for '%s' from '%s'", sql.stem, init_scripts_cfg)
    init_scripts = load_init_scripts_for_testcase(sql.stem, init_scripts_cfg, Path(init_scripts_dir))

    if init_scripts:
        LOGGER.debug("Executing init scripts for '%s'", sql.stem)
        for script in init_scripts:
            LOGGER.debug("Executing init script '%s'", script.name)
            with open(script, 'rb') as f:
                # Only success matters: check=True raises on a non-zero exit,
                # and the script's stdout is captured and discarded.
                subprocess.run(args,
                               stdin=f, stdout=subprocess.PIPE, stderr=sys.stderr, check=True)

    LOGGER.debug("Running %s '%s' -> [%s]", runner, sql, ', '.join("'{}'".format(a) for a in out))
    with open(sql, 'rb') as f:
        pi = subprocess.run(args,
                            stdin=f, stdout=subprocess.PIPE, stderr=sys.stderr, check=True)

    min_diff = sys.maxsize
    best_match = out[0]
    best_diff = []

    for out_file in out:
        with open(out_file, 'rb') as f:
            out_data = f.read()

        # NOTE(review): Differ.diff is assumed to return a sized sequence of
        # bytes lines (they are decoded below). An empty result is treated as
        # an exact match.
        last_diff = Differ.diff(pi.stdout, out_data)
        diff_len = len(last_diff)

        if diff_len == 0:
            return

        # Remember the closest candidate so it can be reported on failure.
        if diff_len < min_diff:
            min_diff = diff_len
            best_match = out_file
            best_diff = last_diff

    LOGGER.info("No exact match for '%s'. Best match is '%s'", sql, best_match)
    for line in best_diff:
        LOGGER.debug("%s", line.decode('utf8', errors='replace').rstrip('\n'))

    # We need assert to fail the test properly. errors='replace' keeps
    # non-UTF-8 runner output from turning this failure into a
    # UnicodeDecodeError that hides the actual diff.
    diff_text = ''.join(d.decode('utf8', errors='replace') for d in best_diff)
    assert min_diff == 0, \
        f"pgrun output does not match out-file for {sql}. Diff:\n" + diff_text[:1024]


def get_out_files(sql_file):
    """Yield the expected-output files that exist for *sql_file*.

    Candidates are checked in this order:
    - ``<stem>.out``, yielded if present.
    - The numbered alternatives ``<stem>_1.out`` ... ``<stem>_9.out``, yielded
      in order until the first missing one. For example, ``_3`` is never
      checked when ``_2`` is absent.

    Existence is checked lazily as the generator is consumed.
    """
    primary = sql_file.with_suffix('.out')
    if primary.is_file():
        yield primary

    stem = sql_file.stem
    for index in range(1, 10):
        candidate = primary.with_stem(f'{stem}_{index}')
        if candidate.is_file():
            yield candidate
        else:
            break