aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorzverevgeny <zverevgeny@ydb.tech>2024-12-25 19:59:33 +0300
committerGitHub <noreply@github.com>2024-12-25 19:59:33 +0300
commitcf03f2ffe85bb63c57177cfbf166746750d0da2d (patch)
tree417303116595621c1577a9e5487c49f1cad4ef38
parent61f38fcfdda136ef9494381db941b2c94f8539fc (diff)
downloadydb-cf03f2ffe85bb63c57177cfbf166746750d0da2d.tar.gz
rewrite suite_tests parser (#12949)
-rw-r--r--ydb/tests/functional/suite_tests/test_base.py167
-rw-r--r--ydb/tests/functional/suite_tests/test_sql_logic.py5
2 files changed, 73 insertions, 99 deletions
diff --git a/ydb/tests/functional/suite_tests/test_base.py b/ydb/tests/functional/suite_tests/test_base.py
index 47b68a1f72..d3a51489b1 100644
--- a/ydb/tests/functional/suite_tests/test_base.py
+++ b/ydb/tests/functional/suite_tests/test_base.py
@@ -2,7 +2,6 @@
import itertools
import json
import abc
-import collections
import os
import random
import string
@@ -41,21 +40,53 @@ def mute_sdk_loggers():
mute_sdk_loggers()
-@enum.unique
-class StatementTypes(enum.Enum):
- Skipped = 'statement skipped'
- Ok = 'statement ok'
- Error = 'statement error'
- Query = 'statement query'
- StreamQuery = 'statement stream query'
- ImportTableData = 'statement import table data'
+class StatementDefinition:
+ @enum.unique
+ class Type(enum.Enum):
+ Skipped = 'statement skipped'
+ Ok = 'statement ok'
+ Error = 'statement error'
+ Query = 'statement query'
+ StreamQuery = 'statement stream query'
+ ImportTableData = 'statement import table data'
+ def __init__(self, suite: str, at_line: int, type: Type, text: [str]):
+ self.suite_name = suite
+ self.at_line = at_line
+ self.s_type = type
+ self.text = text
-def get_statement_type(line):
- for s_type in list(StatementTypes):
- if s_type.value in line.lower():
- return s_type
- raise RuntimeError("Can't find statement type for line %s" % line)
+ def __str__(self):
+ return f'''StatementDefinition:
+ suite: {self.suite_name}
+ line: {self.at_line}
+ type: {self.s_type}
+ text:
+''' + '\n'.join([f' {row}' for row in self.text.split('\n')])
+
+ @staticmethod
+ def _parse_statement_type(statement_line: str) -> Type:
+ for t in list(StatementDefinition.Type):
+ if t.value in statement_line.lower():
+ return t
+ return None
+
+ @staticmethod
+ def parse(suite: str, at_line: int, lines: list[str]):
+ if not lines or not lines[0]:
+ raise RuntimeError(f'Invalid statement in {suite}, at line: {at_line}')
+ type = StatementDefinition._parse_statement_type(lines[0])
+ if type is None:
+ raise RuntimeError(f'Unknown statement type in {suite}, at line: {at_line}')
+ lines.pop(0)
+ at_line += 1
+ statement_lines = []
+ for line in lines:
+ if line.startswith('side effect: '): # side effects are not supported yet
+ pass
+ else:
+ statement_lines.append(line)
+ return StatementDefinition(suite, at_line, type, "\n".join(statement_lines))
def get_token(length=10):
@@ -67,12 +98,6 @@ def get_source_path(*args):
return os.path.join(arcadia_root, test_source_path(os.path.join(*args)))
-def is_empty_line(line):
- if line.split():
- return False
- return True
-
-
def get_lines(suite_path):
with open(suite_path) as reader:
for line_idx, line in enumerate(reader.readlines()):
@@ -97,79 +122,31 @@ def get_test_suites(directory):
return suites
-def get_single_statement(lines):
+def split_by_statement(lines):
statement_lines = []
+ statement_start_line_idx = 0
for line_idx, line in lines:
- if is_empty_line(line):
- statement = "\n".join(statement_lines)
- return statement
- statement_lines.append(line)
- return "\n".join(statement_lines)
-
-
-class ParsedStatement(collections.namedtuple('ParsedStatement', ["at_line", "s_type", "suite_name", "text"])):
- def get_fields(self):
- return self._fields
-
- def __str__(self):
- result = ["", "Parsed Statement"]
- for field in self.get_fields():
- value = str(getattr(self, field))
- if field != 'text':
- result.append(' ' * 4 + '%s: %s,' % (field, value))
- else:
- result.append(' ' * 4 + '%s:' % field)
- result.extend([' ' * 8 + row for row in value.split('\n')])
- return "\n".join(result)
+ if line:
+ if line.startswith("statement "):
+ statement_start_line_idx = line_idx
+ statement_lines = [line]
+ elif statement_lines:
+ statement_lines.append(line)
+ else:
+ if statement_lines:
+ yield (statement_start_line_idx, statement_lines)
+ statement_lines = []
+ if statement_lines:
+ yield (statement_start_line_idx, statement_lines)
def get_statements(suite_path, suite_name):
- lines = get_lines(suite_path)
- for line_idx, line in lines:
- if is_empty_line(line) or not is_statement_definition(line):
- # empty line or junk lines
- continue
- text = get_single_statement(lines)
- yield ParsedStatement(
- line_idx,
- get_statement_type(line),
+ for statement_start_line_idx, statement_lines in split_by_statement(get_lines(suite_path)):
+ yield StatementDefinition.parse(
suite_name,
- text)
-
-
-def is_side_effect(statement_line):
- return statement_line.startswith('side effect: ')
-
-
-def parse_side_effect(se_line):
- pieces = se_line.split(':')
- if len(pieces) < 3:
- raise RuntimeError("Invalid side effect description: %s" % se_line)
- se_type = pieces[1].strip()
- se_description = ':'.join(pieces[2:])
- se_description = se_description.strip()
-
- return se_type, se_description
-
-
-def get_statement_and_side_effects(statement_text):
- statement_lines = statement_text.split('\n')
- side_effects = {}
- filtered = []
- for statement_line in statement_lines:
- if not is_side_effect(statement_line):
- filtered.append(statement_line)
- continue
-
- se_type, se_description = parse_side_effect(statement_line)
-
- side_effects[se_type] = se_description
-
- return '\n'.join(filtered), side_effects
-
-
-def is_statement_definition(line):
- return line.startswith("statement")
+ statement_start_line_idx,
+ statement_lines,
+ )
def patch_yql_statement(lines_or_statement, table_path_prefix):
@@ -307,12 +284,12 @@ class BaseSuiteRunner(object):
def assert_statement(self, parsed_statement):
start_time = time.time()
from_type = {
- StatementTypes.Ok: self.assert_statement_ok,
- StatementTypes.Query: self.assert_statement_query,
- StatementTypes.StreamQuery: self.assert_statement_stream_query,
- StatementTypes.Error: (lambda x: x),
- StatementTypes.ImportTableData: self.assert_statement_import_table_data,
- StatementTypes.Skipped: lambda x: x
+ StatementDefinition.Type.Ok: self.assert_statement_ok,
+ StatementDefinition.Type.Query: self.assert_statement_query,
+ StatementDefinition.Type.StreamQuery: self.assert_statement_stream_query,
+ StatementDefinition.Type.Error: (lambda x: x),
+ StatementDefinition.Type.ImportTableData: self.assert_statement_import_table_data,
+ StatementDefinition.Type.Skipped: lambda x: x
}
assert_method = from_type.get(parsed_statement.s_type)
assert_method(parsed_statement)
@@ -329,10 +306,8 @@ class BaseSuiteRunner(object):
)
def assert_statement_error(self, statement):
- # not supported yet
- statement_text, side_effects = get_statement_and_side_effects(statement.text)
assert_that(
- lambda: self.execute_query(statement_text),
+ lambda: self.execute_query(statement.text),
raises(
ydb.Error
)
diff --git a/ydb/tests/functional/suite_tests/test_sql_logic.py b/ydb/tests/functional/suite_tests/test_sql_logic.py
index 83859f8dba..0304d1f2c8 100644
--- a/ydb/tests/functional/suite_tests/test_sql_logic.py
+++ b/ydb/tests/functional/suite_tests/test_sql_logic.py
@@ -5,7 +5,7 @@ import sqlite3
import pytest
from hamcrest import assert_that, raises
-from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute, get_statement_and_side_effects
+from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute
"""
This module is a specific runner of sqllogic tests. Test suites for this
@@ -38,8 +38,7 @@ class TestSQLLogic(BaseSuiteRunner):
safe_execute(lambda: self.__execute_sqlitedb(statement.text), statement)
def assert_statement_error(self, statement):
- statement_text, side_effects = get_statement_and_side_effects(statement.text)
- assert_that(lambda: self.__execute_sqlitedb(statement_text), raises(sqlite3.Error), str(statement))
+ assert_that(lambda: self.__execute_sqlitedb(statement.text), raises(sqlite3.Error), str(statement))
super(TestSQLLogic, self).assert_statement_error(statement)
def get_query_and_output(self, statement_text):