aboutsummaryrefslogtreecommitdiffstats
path: root/tools/black_linter/bin/__main__.py
blob: 8a6afb8cb61f4174251a01a9efe2ec54b3e460bf (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import contextlib
import io
import logging
import os
import sys
import time
from pathlib import Path

import yalibrary.term.console as console
from library.python.testing.custom_linter_util import linter_params, reporter

import black
import black.files
import black.report
import black.mode

# Module-level logger for this linter.
logger = logging.getLogger(__name__)

# Maximum number of diff lines (after the "---"/"+++" header is dropped) kept in
# a report message; longer diffs are cut and point to the full diff in stdout.
SNIPPET_LINES_LIMIT = 20


@contextlib.contextmanager
def collect_stdout(stream):
    """Temporarily redirect ``sys.stdout`` to *stream*.

    The current stdout is flushed before switching so pending output is not
    mixed into *stream*. The original ``sys.stdout`` is restored on exit even
    when the body raises; callers (e.g. ``run_black_safe``) catch exceptions
    and keep running, so leaking the redirect would silently swallow all
    later stdout output.

    Args:
        stream: text stream that should receive stdout writes.

    Yields:
        *stream* itself.
    """
    sys.stdout.flush()
    old = sys.stdout
    sys.stdout = stream
    try:
        yield stream
    finally:
        # Restore first so a failing flush cannot leave stdout redirected.
        sys.stdout = old
        stream.flush()


def run_black(filename, wr_back, mode, report, fast):
    """Run ``black.reformat_one`` on *filename* and return what it printed to stdout.

    Args:
        filename: path of the file to check.
        wr_back: ``black.WriteBack`` value (check-only or diff output).
        mode: ``black.Mode`` with the formatting options.
        report: ``black.report.Report`` that accumulates change/failure counts.
        fast: forwarded to black's ``fast`` flag.

    Returns:
        The captured stdout decoded as UTF-8. Undecodable bytes are replaced
        rather than raising.
    """
    # black prints the diff to stdout; capture it in an in-memory byte buffer.
    # NOTE(review): a TextIOWrapper over BytesIO (instead of io.StringIO) is
    # presumably required because black writes through sys.stdout.buffer --
    # confirm before simplifying.
    captured = io.BytesIO()
    stream = io.TextIOWrapper(
        buffer=captured,
        encoding=sys.stdout.encoding,
        write_through=True,
    )

    with collect_stdout(stream):
        black.reformat_one(
            Path(filename),
            fast=fast,
            write_back=wr_back,
            mode=mode,
            report=report,
        )

    # The captured bytes are not guaranteed to be valid UTF-8 (e.g. a diff of a
    # source file in another encoding). A strict decode would raise out of this
    # function, so replace undecodable bytes instead.
    return captured.getvalue().decode("utf-8", errors="replace")


def run_black_safe(filename, wr_back, mode, report):
    """Run black on *filename*, falling back to fast mode if the first run raises.

    The first attempt runs with ``fast=False``. If it raises:
    - the failure is logged;
    - the counters on *report* are reset so the retry starts from a clean slate;
    - black is run again with ``fast=True``.

    An exception from the retry propagates to the caller.

    Returns:
        The stdout captured by ``run_black`` for the successful attempt.
    """
    try:
        return run_black(filename, wr_back, mode, report, fast=False)
    except Exception:
        # The non-fast run failed (the old comment wrongly said "fast mode").
        # Record why instead of discarding the error silently, then drop the
        # partially collected report stats and retry in fast mode.
        logger.warning(
            "black failed on %s with fast=False, retrying with fast=True",
            filename,
            exc_info=True,
        )
        report.change_count = 0
        report.same_count = 0
        report.failure_count = 0

        return run_black(filename, wr_back, mode, report, fast=True)


def process_file(filename, config):
    """Check one file with black and return a diff snippet for the lint report.

    Args:
        filename: path of the file to check.
        config: options dict as returned by ``black.parse_pyproject_toml``.
            Only ``line_length`` and ``skip_string_normalization`` are read.

    Returns:
        The captured black output with its first two lines (the ``---``/``+++``
        diff header) removed, truncated to ``SNIPPET_LINES_LIMIT`` lines.
        An empty string is returned when nothing was captured. A non-empty
        capture is also written in full (ANSI codes stripped) to stdout.
    """
    logger.debug("Check %s", filename)

    report = black.report.Report(
        check=True,
        quiet=True,
    )
    mode = black.Mode(
        # Fall back to black's own default when the config omits line_length;
        # forwarding None would override that default with an invalid value.
        line_length=config.get("line_length", black.DEFAULT_LINE_LENGTH),
        string_normalization=not config.get("skip_string_normalization"),
    )
    wr_back_without_diff = black.WriteBack.from_configuration(check=True, diff=False)
    # Fast path for runs with fix_style option or without errors.
    error_msg = run_black_safe(filename, wr_back_without_diff, mode, report)
    if report.change_count:
        # black runs 15x+ slower if diff is requested, even for files w/o actual diff.
        # Rerun black in case of found error.
        wr_back_with_diff = black.WriteBack.from_configuration(check=True, diff=True)
        error_msg = run_black_safe(filename, wr_back_with_diff, mode, report)

    if error_msg:
        sys.stdout.write(console.strip_ansi_codes(error_msg))
        # The diff's line endings need not match os.linesep (presumably they
        # follow the checked file's newline style). Splitting on os.linesep
        # could then yield a single "line", which the header strip below would
        # erase, and the file would wrongly pass. Normalise CRLF and split on LF.
        lines = error_msg.replace("\r\n", "\n").split("\n")
        # strip diff header with "+++" "---" lines
        lines = lines[2:]
        if len(lines) > SNIPPET_LINES_LIMIT:
            lines = lines[:SNIPPET_LINES_LIMIT]
            lines += ["[[rst]]..[truncated].. see full diff in the stdout file in the logsdir"]
        error_msg = os.linesep.join(lines)
    return error_msg


def main():
    """Check every file from the linter params with black and dump the lint report.

    The black options are read from the first config path in the params. Each
    file is reported as FAIL with its diff snippet, or GOOD with an empty
    message, together with the time spent on it.
    """
    params = linter_params.get_params()

    # Raise the blib2to3 parser driver logger to WARNING so its lower-level
    # records are dropped.
    black_parser_logger = logging.getLogger("blib2to3.pgen2.driver")
    black_parser_logger.setLevel(logging.WARNING)

    style_config_path = params.configs[0]
    black_config = black.parse_pyproject_toml(style_config_path)

    report = reporter.LintReport()
    for file_name in params.files:
        # perf_counter is monotonic, so elapsed cannot go negative or jump if
        # the wall clock is adjusted during the run (time.time() can).
        start_time = time.perf_counter()
        error = process_file(file_name, black_config)
        elapsed = time.perf_counter() - start_time

        status = reporter.LintStatus.FAIL if error else reporter.LintStatus.GOOD
        report.add(file_name, status, error or "", elapsed=elapsed)

    report.dump(params.report_file)


# Entry point when this file is executed directly (it is the package's __main__).
if __name__ == "__main__":
    main()