diff options
author | zverevgeny <zverevgeny@ydb.tech> | 2024-12-25 19:59:33 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-25 19:59:33 +0300 |
commit | cf03f2ffe85bb63c57177cfbf166746750d0da2d (patch) | |
tree | 417303116595621c1577a9e5487c49f1cad4ef38 | |
parent | 61f38fcfdda136ef9494381db941b2c94f8539fc (diff) | |
download | ydb-cf03f2ffe85bb63c57177cfbf166746750d0da2d.tar.gz |
rewrite suite_tests parser (#12949)
-rw-r--r-- | ydb/tests/functional/suite_tests/test_base.py | 167 | ||||
-rw-r--r-- | ydb/tests/functional/suite_tests/test_sql_logic.py | 5 |
2 files changed, 73 insertions, 99 deletions
diff --git a/ydb/tests/functional/suite_tests/test_base.py b/ydb/tests/functional/suite_tests/test_base.py index 47b68a1f72..d3a51489b1 100644 --- a/ydb/tests/functional/suite_tests/test_base.py +++ b/ydb/tests/functional/suite_tests/test_base.py @@ -2,7 +2,6 @@ import itertools import json import abc -import collections import os import random import string @@ -41,21 +40,53 @@ def mute_sdk_loggers(): mute_sdk_loggers() -@enum.unique -class StatementTypes(enum.Enum): - Skipped = 'statement skipped' - Ok = 'statement ok' - Error = 'statement error' - Query = 'statement query' - StreamQuery = 'statement stream query' - ImportTableData = 'statement import table data' +class StatementDefinition: + @enum.unique + class Type(enum.Enum): + Skipped = 'statement skipped' + Ok = 'statement ok' + Error = 'statement error' + Query = 'statement query' + StreamQuery = 'statement stream query' + ImportTableData = 'statement import table data' + def __init__(self, suite: str, at_line: int, type: Type, text: [str]): + self.suite_name = suite + self.at_line = at_line + self.s_type = type + self.text = text -def get_statement_type(line): - for s_type in list(StatementTypes): - if s_type.value in line.lower(): - return s_type - raise RuntimeError("Can't find statement type for line %s" % line) + def __str__(self): + return f'''StatementDefinition: + suite: {self.suite_name} + line: {self.at_line} + type: {self.s_type} + text: +''' + '\n'.join([f' {row}' for row in self.text.split('\n')]) + + @staticmethod + def _parse_statement_type(statement_line: str) -> Type: + for t in list(StatementDefinition.Type): + if t.value in statement_line.lower(): + return t + return None + + @staticmethod + def parse(suite: str, at_line: int, lines: list[str]): + if not lines or not lines[0]: + raise RuntimeError(f'Invalid statement in {suite}, at line: {at_line}') + type = StatementDefinition._parse_statement_type(lines[0]) + if type is None: + raise RuntimeError(f'Unknown statement type in {suite}, at line: {at_line}') + lines.pop(0) + at_line += 1 + statement_lines = [] + for line in lines: + if line.startswith('side effect: '): # side effects are not supported yet + pass + else: + statement_lines.append(line) + return StatementDefinition(suite, at_line, type, "\n".join(statement_lines)) def get_token(length=10): @@ -67,12 +98,6 @@ def get_source_path(*args): return os.path.join(arcadia_root, test_source_path(os.path.join(*args))) -def is_empty_line(line): - if line.split(): - return False - return True - - def get_lines(suite_path): with open(suite_path) as reader: for line_idx, line in enumerate(reader.readlines()): @@ -97,79 +122,31 @@ def get_test_suites(directory): return suites -def get_single_statement(lines): +def split_by_statement(lines): statement_lines = [] + statement_start_line_idx = 0 for line_idx, line in lines: - if is_empty_line(line): - statement = "\n".join(statement_lines) - return statement - statement_lines.append(line) - return "\n".join(statement_lines) - - -class ParsedStatement(collections.namedtuple('ParsedStatement', ["at_line", "s_type", "suite_name", "text"])): - def get_fields(self): - return self._fields - - def __str__(self): - result = ["", "Parsed Statement"] - for field in self.get_fields(): - value = str(getattr(self, field)) - if field != 'text': - result.append(' ' * 4 + '%s: %s,' % (field, value)) - else: - result.append(' ' * 4 + '%s:' % field) - result.extend([' ' * 8 + row for row in value.split('\n')]) - return "\n".join(result) + if line: + if line.startswith("statement "): + statement_start_line_idx = line_idx + statement_lines = [line] + elif statement_lines: + statement_lines.append(line) + else: + if statement_lines: + yield (statement_start_line_idx, statement_lines) + statement_lines = [] + if statement_lines: + yield (statement_start_line_idx, statement_lines) def get_statements(suite_path, suite_name): - lines = get_lines(suite_path) - for line_idx, line in lines: - if is_empty_line(line) or not is_statement_definition(line): - # empty line or junk lines - continue - text = get_single_statement(lines) - yield ParsedStatement( - line_idx, - get_statement_type(line), + for statement_start_line_idx, statement_lines in split_by_statement(get_lines(suite_path)): + yield StatementDefinition.parse( suite_name, - text) - - -def is_side_effect(statement_line): - return statement_line.startswith('side effect: ') - - -def parse_side_effect(se_line): - pieces = se_line.split(':') - if len(pieces) < 3: - raise RuntimeError("Invalid side effect description: %s" % se_line) - se_type = pieces[1].strip() - se_description = ':'.join(pieces[2:]) - se_description = se_description.strip() - - return se_type, se_description - - -def get_statement_and_side_effects(statement_text): - statement_lines = statement_text.split('\n') - side_effects = {} - filtered = [] - for statement_line in statement_lines: - if not is_side_effect(statement_line): - filtered.append(statement_line) - continue - - se_type, se_description = parse_side_effect(statement_line) - - side_effects[se_type] = se_description - - return '\n'.join(filtered), side_effects - - -def is_statement_definition(line): - return line.startswith("statement") + statement_start_line_idx, + statement_lines, + ) def patch_yql_statement(lines_or_statement, table_path_prefix): @@ -307,12 +284,12 @@ class BaseSuiteRunner(object): def assert_statement(self, parsed_statement): start_time = time.time() from_type = { - StatementTypes.Ok: self.assert_statement_ok, - StatementTypes.Query: self.assert_statement_query, - StatementTypes.StreamQuery: self.assert_statement_stream_query, - StatementTypes.Error: (lambda x: x), - StatementTypes.ImportTableData: self.assert_statement_import_table_data, - StatementTypes.Skipped: lambda x: x + StatementDefinition.Type.Ok: self.assert_statement_ok, + StatementDefinition.Type.Query: self.assert_statement_query, + StatementDefinition.Type.StreamQuery: self.assert_statement_stream_query, + StatementDefinition.Type.Error: (lambda x: x), + StatementDefinition.Type.ImportTableData: self.assert_statement_import_table_data, + StatementDefinition.Type.Skipped: lambda x: x } assert_method = from_type.get(parsed_statement.s_type) assert_method(parsed_statement) @@ -329,10 +306,8 @@ class BaseSuiteRunner(object): ) def assert_statement_error(self, statement): - # not supported yet - statement_text, side_effects = get_statement_and_side_effects(statement.text) assert_that( - lambda: self.execute_query(statement_text), + lambda: self.execute_query(statement.text), raises( ydb.Error ) diff --git a/ydb/tests/functional/suite_tests/test_sql_logic.py b/ydb/tests/functional/suite_tests/test_sql_logic.py index 83859f8dba..0304d1f2c8 100644 --- a/ydb/tests/functional/suite_tests/test_sql_logic.py +++ b/ydb/tests/functional/suite_tests/test_sql_logic.py @@ -5,7 +5,7 @@ import sqlite3 import pytest from hamcrest import assert_that, raises -from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute, get_statement_and_side_effects +from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute """ This module is a specific runner of sqllogic tests. Test suites for this @@ -38,8 +38,7 @@ class TestSQLLogic(BaseSuiteRunner): safe_execute(lambda: self.__execute_sqlitedb(statement.text), statement) def assert_statement_error(self, statement): - statement_text, side_effects = get_statement_and_side_effects(statement.text) - assert_that(lambda: self.__execute_sqlitedb(statement_text), raises(sqlite3.Error), str(statement)) + assert_that(lambda: self.__execute_sqlitedb(statement.text), raises(sqlite3.Error), str(statement)) super(TestSQLLogic, self).assert_statement_error(statement) def get_query_and_output(self, statement_text): |