diff options
author | Maxim Yurchuk <maxim-yurchuk@ydb.tech> | 2024-11-20 17:37:57 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-20 17:37:57 +0000 |
commit | f76323e9b295c15751e51e3443aa47a36bee8023 (patch) | |
tree | 4113c8cad473a33e0f746966e0cf087252fa1d7a /yql/essentials/tests/common | |
parent | 753ecb8d410a4cb459c26f3a0082fb2d1724fe63 (diff) | |
parent | a7b9a6afea2a9d7a7bfac4c5eb4c1a8e60adb9e6 (diff) | |
download | ydb-f76323e9b295c15751e51e3443aa47a36bee8023.tar.gz |
Merge pull request #11788 from ydb-platform/mergelibs-241120-1113
Library import 241120-1113
Diffstat (limited to 'yql/essentials/tests/common')
13 files changed, 2254 insertions, 0 deletions
# === yql/essentials/tests/common/test_framework/conftest.py ===
# Re-export the pytest fixtures so suites that depend on this framework pick
# them up automatically; each import is optional because not every build
# links the corresponding runner module.
try:
    from yql_http_file_server import yql_http_file_server
except ImportError:
    yql_http_file_server = None

try:
    from solomon_runner import solomon
except ImportError:
    solomon = None

# bunch of useless statements for linter happiness
# (otherwise it complains about unused names)
assert yql_http_file_server is yql_http_file_server
assert solomon is solomon


# === yql/essentials/tests/common/test_framework/solomon_runner.py ===
import os
import pytest
import requests


class SolomonWrapper(object):
    """Thin HTTP client for a Solomon monitoring stub used by tests.

    The wrapper is valid only when a URL was supplied; callers are expected
    to check ``is_valid()`` before issuing requests.
    """

    def __init__(self, url, endpoint):
        self._url = url
        self._endpoint = endpoint
        self.table_prefix = ""

    def is_valid(self):
        """True when a Solomon URL is configured."""
        return self._url is not None

    def cleanup(self):
        """Ask the stub to drop all accumulated metrics."""
        resp = requests.post(self._url + "/cleanup")
        resp.raise_for_status()

    def get_metrics(self):
        """Fetch the raw metrics dump for the fixed test project/cluster/service."""
        resp = requests.get(self._url + "/metrics?project=my_project&cluster=my_cluster&service=my_service")
        resp.raise_for_status()
        return resp.text

    def prepare_program(self, program, program_file, res_dir, lang='sql'):
        # No rewriting is needed for Solomon runs; pass the program through.
        return program, program_file

    @property
    def url(self):
        return self._url

    @property
    def endpoint(self):
        return self._endpoint


@pytest.fixture(scope='module')
def solomon(request):
    """Module-scoped fixture; configuration comes from the environment."""
    return SolomonWrapper(os.environ.get("SOLOMON_URL"),
                          os.environ.get("SOLOMON_ENDPOINT"))
# === yql/essentials/tests/common/test_framework/test_file_common.py ===
import codecs
import os
import pytest
import re
import cyson

import yql.essentials.providers.common.proto.gateways_config_pb2 as gateways_config_pb2

from google.protobuf import text_format
from yql_utils import execute_sql, get_supported_providers, get_tables, get_files, get_http_files, \
    get_pragmas, KSV_ATTR, is_xfail, get_param, YQLExecResult, yql_binary_path
from yqlrun import YQLRun

from test_utils import get_parameters_json, DATA_PATH, replace_vars


def get_gateways_config(http_files, yql_http_file_server, force_blocks=False, is_hybrid=False):
    """Build a textual TGatewaysConfig for a test run, or None when no
    customization (http files, forced blocks, hybrid execution) is needed."""
    config = None

    if http_files or force_blocks or is_hybrid:
        config_message = gateways_config_pb2.TGatewaysConfig()
        if http_files:
            # Map the http_test:// scheme onto the per-test file server.
            schema = config_message.Fs.CustomSchemes.add()
            schema.Pattern = 'http_test://(.*)'
            schema.TargetUrl = yql_http_file_server.compose_http_link('$1')
        if force_blocks:
            config_message.SqlCore.TranslationFlags.extend(['EmitAggApply'])
            flags = config_message.YqlCore.Flags.add()
            flags.Name = 'UseBlocks'
        if is_hybrid:
            activate_hybrid = config_message.Yt.DefaultSettings.add()
            activate_hybrid.Name = "HybridDqExecution"
            activate_hybrid.Value = "1"
            deactivate_dq = config_message.Dq.DefaultSettings.add()
            deactivate_dq.Name = "AnalyzeQuery"
            deactivate_dq.Value = "0"
        config = text_format.MessageToString(config_message)

    return config


def is_hybrid(provider):
    """True for the hybrid (YT+DQ) provider name."""
    return provider == 'hybrid'


def check_provider(provider, config):
    """Skip the test when the suite config does not list this provider."""
    if provider not in get_supported_providers(config):
        pytest.skip('%s provider is not supported here' % provider)


def get_sql_query(provider, suite, case, config):
    """Read the case's .sql file, prepend suite pragmas and apply
    platform/provider skip rules. Returns the final query text."""
    pragmas = get_pragmas(config)

    if get_param('TARGET_PLATFORM'):
        if "yson" in case or "regexp" in case or "match" in case:
            pytest.skip('yson/match/regexp is not supported on non-default target platform')

    if get_param('TARGET_PLATFORM') and is_xfail(config):
        pytest.skip('xfail is not supported on non-default target platform')

    program_sql = os.path.join(DATA_PATH, suite, '%s.sql' % case)

    with codecs.open(program_sql, encoding='utf-8') as program_file_descr:
        sql_query = program_file_descr.read()
        if get_param('TARGET_PLATFORM'):
            if "Yson::" in sql_query:
                pytest.skip('yson udf is not supported on non-default target platform')
        # An embedded marker like "ytfile can not" opts the case out of a provider.
        if (provider + 'file can not' in sql_query) or (is_hybrid(provider) and ('ytfile can not' in sql_query)):
            pytest.skip(provider + ' can not execute this')

    pragmas.append(sql_query)
    sql_query = ';\n'.join(pragmas)
    if 'Python' in sql_query or 'Javascript' in sql_query:
        pytest.skip('ScriptUdf')

    assert 'UseBlocks' not in sql_query, 'UseBlocks should not be used directly, only via ForceBlocks'

    return sql_query


def run_file_no_cache(provider, suite, case, cfg, config, yql_http_file_server, yqlrun_binary=None, extra_args=None, force_blocks=False):
    """Execute one test case via yqlrun and return (YQLExecResult, tables_res).

    For xfail cases stderr is replaced with None so it is not canonized.
    NOTE: ``extra_args`` default changed from a shared mutable ``[]`` to the
    ``None`` sentinel (normalized below) — backward compatible for callers.
    """
    if extra_args is None:
        extra_args = []
    check_provider(provider, config)

    sql_query = get_sql_query(provider, suite, case, config)
    sql_query = replace_vars(sql_query, "yqlrun_var")

    xfail = is_xfail(config)

    in_tables, out_tables = get_tables(suite, config, DATA_PATH, def_attr=KSV_ATTR)
    files = get_files(suite, config, DATA_PATH)
    http_files = get_http_files(suite, config, DATA_PATH)
    http_files_urls = yql_http_file_server.register_files({}, http_files)

    for table in in_tables:
        # "document" tables keep scripts in content, others in the attr string.
        if cyson.loads(table.attr).get("type") == "document":
            content = table.content
        else:
            content = table.attr
        if 'Python' in content or 'Javascript' in content:
            pytest.skip('ScriptUdf')

    parameters = get_parameters_json(suite, config)

    yqlrun = YQLRun(
        prov=provider,
        keep_temp=not re.search(r"yt\.ReleaseTempData", sql_query),
        binary=yqlrun_binary,
        gateway_config=get_gateways_config(http_files, yql_http_file_server, force_blocks=force_blocks, is_hybrid=is_hybrid(provider)),
        extra_args=extra_args,
        udfs_dir=yql_binary_path('yql/essentials/tests/common/test_framework/udfs_deps')
    )

    res, tables_res = execute_sql(
        yqlrun,
        program=sql_query,
        input_tables=in_tables,
        output_tables=out_tables,
        files=files,
        urls=http_files_urls,
        check_error=not xfail,
        verbose=True,
        parameters=parameters)

    fixed_stderr = res.std_err
    if xfail:
        assert res.execution_result.exit_code != 0
        custom_error = re.search(r"/\* custom error:(.*)\*/", sql_query)
        if custom_error:
            err_string = custom_error.group(1)
            assert res.std_err.find(err_string) != -1
        fixed_stderr = None

    fixed_result = YQLExecResult(res.std_out,
                                 fixed_stderr,
                                 res.results,
                                 res.results_file,
                                 res.opt,
                                 res.opt_file,
                                 res.plan,
                                 res.plan_file,
                                 res.program,
                                 res.execution_result,
                                 None)

    return fixed_result, tables_res


def run_file(provider, suite, case, cfg, config, yql_http_file_server, yqlrun_binary=None, extra_args=None, force_blocks=False):
    """Memoized wrapper over run_file_no_cache keyed by (suite, case, cfg).

    NOTE: ``extra_args`` mutable-default fixed the same way as above; ``None``
    is forwarded and normalized by run_file_no_cache.
    """
    if (suite, case, cfg) not in run_file.cache:
        run_file.cache[(suite, case, cfg)] = run_file_no_cache(provider, suite, case, cfg, config, yql_http_file_server, yqlrun_binary, extra_args, force_blocks=force_blocks)

    return run_file.cache[(suite, case, cfg)]


run_file.cache = {}


# === yql/essentials/tests/common/test_framework/test_utils.py (module header) ===
import json
import os
import re
import yatest.common

from yql_utils import get_param as yql_get_param
from google.protobuf import text_format
import yql.essentials.providers.common.proto.gateways_config_pb2 as gateways_config_pb2

DATA_PATH = yatest.common.source_path('yql/essentials/tests/sql/suites')
# Tool paths are optional: outside of a full ya-make build they are absent
# and the BaseException guard leaves them as None.
try:
    SQLRUN_PATH = yatest.common.binary_path('yql/essentials/tools/sql2yql/sql2yql')
except BaseException:
    SQLRUN_PATH = None

try:
    YQLRUN_PATH = yatest.common.binary_path('contrib/ydb/library/yql/tools/yqlrun/yqlrun')
except BaseException:
    YQLRUN_PATH = None


def get_sql_flags():
    """Translation flags: defaults from gateways.conf plus YQL_SQL_FLAGS extras."""
    gateway_config = gateways_config_pb2.TGatewaysConfig()

    with open(yatest.common.source_path('yql/essentials/cfg/tests/gateways.conf')) as conf:
        text_format.Merge(conf.read(), gateway_config)

    if yql_get_param('SQL_FLAGS'):
        gateway_config.SqlCore.TranslationFlags.extend(yql_get_param('SQL_FLAGS').split(','))
    return gateway_config.SqlCore.TranslationFlags


try:
    SQL_FLAGS = get_sql_flags()
except BaseException:
    SQL_FLAGS = None


def recursive_glob(root, begin_template=None, end_template=None):
    """Yield paths (relative to *root*) of files whose names match the
    optional prefix/suffix templates."""
    for parent, _dirs, filenames in os.walk(root):
        for filename in filenames:
            if begin_template is not None and not filename.startswith(begin_template):
                continue
            if end_template is not None and not filename.endswith(end_template):
                continue
            yield os.path.relpath(os.path.join(parent, filename), root)


def pytest_generate_tests_by_template(template, metafunc):
    """Parametrize (suite, case) pairs from files matching *template* under DATA_PATH."""
    argvalues = []

    for suite in [d for d in os.listdir(DATA_PATH) if os.path.isdir(os.path.join(DATA_PATH, d))]:
        cases = sorted(p[:-len(template)]
                       for p in recursive_glob(os.path.join(DATA_PATH, suite), end_template=template))
        argvalues.extend((suite, case) for case in cases)

    metafunc.parametrize(['suite', 'case'], argvalues)


def pytest_generate_tests_for_run(metafunc, template='.sql', suites=None, currentPart=0, partsCount=1, data_path=None):
    """Parametrize (suite, case, cfg) triples, sharded by hash into
    *partsCount* parts of which this invocation takes *currentPart*."""
    if data_path is None:
        data_path = DATA_PATH
    argvalues = []

    if not suites:
        suites = sorted(d for d in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, d)))

    for suite in suites:
        suite_dir = os.path.join(data_path, suite)
        # .sql's
        for case in sorted(p[:-len(template)]
                           for p in recursive_glob(suite_dir, end_template=template)):
            with open(os.path.join(suite_dir, case + template)) as program:
                if 'do not execute' in program.read():
                    continue

            # .cfg's
            configs = [
                cfg_file.replace(case + '-', '').replace('.cfg', '')
                for cfg_file in recursive_glob(suite_dir, begin_template=case + '-', end_template='.cfg')
            ]
            if os.path.exists(suite_dir + '/' + case + '.cfg'):
                configs.append('')
            for cfg in sorted(configs):
                if hash((suite, case, cfg)) % partsCount == currentPart:
                    argvalues.append((suite, case, cfg))
            if not configs and hash((suite, case, 'default.txt')) % partsCount == currentPart:
                argvalues.append((suite, case, 'default.txt'))

    metafunc.parametrize(
        ['suite', 'case', 'cfg'],
        argvalues,
    )


def pytest_generate_tests_for_part(metafunc, currentPart, partsCount):
    return pytest_generate_tests_for_run(metafunc, currentPart=currentPart, partsCount=partsCount)


def get_cfg_file(cfg, case):
    """Map a cfg key to its file name: '' -> <case>.cfg,
    'default.txt' -> default.cfg, anything else -> <case>-<cfg>.cfg."""
    if not cfg:
        return case + '.cfg'
    if cfg == 'default.txt':
        return 'default.cfg'
    return case + '-' + cfg + '.cfg'


def validate_cfg(result):
    """Assert that every parsed .cfg line starts with a known command."""
    known = (
        "in",
        "in2",
        "out",
        "udf",
        "providers",
        "res",
        "canonize_peephole",
        "canonize_lineage",
        "peephole_use_blocks",
        "with_final_result_issues",
        "xfail",
        "pragma",
        "canonize_yt",
        "file",
        "http_file",
        "yt_file",
        "os",
        "param",
    )
    for entry in result:
        assert entry[0] in known, "Unknown command in .cfg: %s" % (entry[0])


def get_config(suite, case, cfg, data_path=None):
    """Parse the case's .cfg file, inheriting a whitelisted subset of
    commands from the suite's default.cfg when present."""
    if data_path is None:
        data_path = DATA_PATH
    result = []
    try:
        default_cfg = get_cfg_file('default.txt', case)
        inherit = ['canonize_peephole', 'canonize_lineage', 'peephole_use_blocks']
        with open(os.path.join(data_path, suite, default_cfg)) as cfg_file_content:
            result = [line.strip().split() for line in cfg_file_content.readlines() if line.strip() and line.strip().split()[0]]
            validate_cfg(result)
            result = [r for r in result if r[0] in inherit]
    except IOError:
        pass
    cfg_file = get_cfg_file(cfg, case)
    with open(os.path.join(data_path, suite, cfg_file)) as cfg_file_content:
        result = [line.strip().split() for line in cfg_file_content.readlines() if line.strip()] + result
    validate_cfg(result)
    return result
def load_json_file_strip_comments(path):
    """Read *path*, dropping lines that start with '#' (pseudo-comments in JSON)."""
    with open(path) as handle:
        return '\n'.join([line for line in handle.readlines() if not line.startswith('#')])


def get_parameters_files(suite, config):
    """Collect (name, absolute path) pairs from 'param NAME FILE' cfg lines."""
    found = []
    for entry in config:
        if len(entry) != 3 or not entry[0] == "param":
            continue
        found.append((entry[1], os.path.join(DATA_PATH, suite, entry[2])))
    return found


def get_parameters_json(suite, config):
    """Load every declared parameter file into the {'Data': ...} envelope."""
    data = {}
    for name, path in get_parameters_files(suite, config):
        data[name] = {'Data': json.loads(load_json_file_strip_comments(path))}
    return data


def output_dir(name):
    """Ensure (and return) a named subdirectory of the test output path."""
    output_dir = yatest.common.output_path(name)
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    return output_dir


def run_sql_on_mr(name, query, kikimr):
    """Run *query* through yql-exec on a kikimr instance; returns the paths
    of the optimized-expression and results files."""
    out_dir = output_dir(name)
    opt_file = os.path.join(out_dir, 'opt.yql')
    results_file = os.path.join(out_dir, 'results.yson')

    try:
        kikimr(
            'yql-exec -d 1 -P %s --sql --run --optimize -i /dev/stdin --oexpr %s --oresults %s' % (
                kikimr.yql_pool_id,
                opt_file,
                results_file
            ),
            stdin=query
        )
    except yatest.common.ExecutionError as e:
        runyqljob_result = e.execution_result
        assert 0, 'yql-exec finished with error: \n\n%s \n\non program: \n\n%s' % (
            runyqljob_result.std_err,
            query
        )
    return opt_file, results_file


def normalize_table(csv, fields_order=None):
    '''
    :param csv: table content
    :param fields_order: normal order of fields (default: 'key', 'subkey', 'value')
    :return: normalized table content
    '''
    if not csv.strip():
        return ''

    rows = csv.splitlines()
    headers = rows[0].strip().split(';')

    if fields_order is None:
        if len(set(headers)) < len(headers):
            # we have duplicates in case of joining tables, let's just cut headers and return as is
            return '\n'.join(rows[1:])
        fields_order = headers

    # Unknown field requested -> fall back to a deterministic sorted order.
    if any(field not in headers for field in fields_order):
        fields_order = sorted(headers)

    translator = {field: headers.index(field) for field in fields_order}

    def normalize_cell(cell):
        # Postgres-style booleans and floats are folded to a canonical form.
        if cell == 't':
            return 'true'
        if cell == 'f':
            return 'false'
        if '.' in cell:
            try:
                value = float(cell)
            except ValueError:
                return cell
            return str(str(int(value)) if value.is_integer() else value)
        return cell

    normalized = ''
    for row in rows[1:]:
        cells = row.strip().split(';')
        reordered = [normalize_cell(cells[translator[field]]) for field in fields_order]
        normalized += '\n' + ';'.join(reordered)

    return normalized.strip()


def replace_vars(sql_query, var_tag):
    """
    Sql can contain comment like /* yt_local_var: VAR_NAME=VAR_VALUE */
    it will replace VAR_NAME with VAR_VALUE within sql query
    """
    declarations = re.findall(r"\/\* {}: (.*)=(.*) \*\/".format(var_tag), sql_query)
    for var_name, var_value in declarations:
        sql_query = re.sub(re.escape(var_name.strip()), var_value.strip(), sql_query)
    return sql_query
# === yql/essentials/tests/common/test_framework/udfs_deps/ya.make ===
# UDF modules the SQL test suites depend on at run time.
SET(
    UDFS
    yql/essentials/udfs/common/datetime2
    yql/essentials/udfs/common/digest
    yql/essentials/udfs/common/file
    yql/essentials/udfs/common/hyperloglog
    yql/essentials/udfs/common/pire
    yql/essentials/udfs/common/protobuf
    yql/essentials/udfs/common/re2
    yql/essentials/udfs/common/set
    yql/essentials/udfs/common/stat
    yql/essentials/udfs/common/topfreq
    yql/essentials/udfs/common/top
    yql/essentials/udfs/common/string
    yql/essentials/udfs/common/histogram
    yql/essentials/udfs/common/json2
    yql/essentials/udfs/common/yson2
    yql/essentials/udfs/common/math
    yql/essentials/udfs/common/url_base
    yql/essentials/udfs/common/unicode_base
    yql/essentials/udfs/common/streaming
    yql/essentials/udfs/examples/callables
    yql/essentials/udfs/examples/dicts
    yql/essentials/udfs/examples/dummylog
    yql/essentials/udfs/examples/lists
    yql/essentials/udfs/examples/structs
    yql/essentials/udfs/examples/type_inspection
    yql/essentials/udfs/logs/dsv
    yql/essentials/udfs/test/simple
    yql/essentials/udfs/test/test_import
)

# hyperscan builds only with clang on Linux.
IF (OS_LINUX AND CLANG)
    SET(
        UDFS
        ${UDFS}
        yql/essentials/udfs/common/hyperscan
    )
ENDIF()

PACKAGE()

# NOTE(review): deps are pulled in for every sanitizer type except
# "undefined" — presumably UDFs are unusable under ubsan; confirm intent.
IF (SANITIZER_TYPE != "undefined")

PEERDIR(
    ${UDFS}
)

ENDIF()

END()


# === yql/essentials/tests/common/test_framework/ya.make ===
PY23_LIBRARY()

# Framework modules importable at top level by the test suites.
PY_SRCS(
    TOP_LEVEL
    solomon_runner.py
    yql_utils.py
    yql_ports.py
    yqlrun.py
    yql_http_file_server.py
    test_utils.py
    test_file_common.py
)

# conftest goes into its own namespace so pytest can locate it.
PY_SRCS(
    NAMESPACE ydb_library_yql_test_framework
    conftest.py
)

PEERDIR(
    contrib/python/requests
    contrib/python/six
    contrib/python/urllib3
    library/python/cyson
    yql/essentials/core/file_storage/proto
    yql/essentials/providers/common/proto
)

END()

RECURSE(
    udfs_deps
)
# === yql/essentials/tests/common/test_framework/yql_http_file_server.py ===
import io
import os
import pytest
import threading
import shutil

import six.moves.BaseHTTPServer as BaseHTTPServer
import six.moves.socketserver as socketserver

from yql_ports import get_yql_port, release_yql_port


# handler is created on each request
# store state in server
class YqlHttpRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
    """Serves files registered on the owning YqlHttpFileServer, either from
    disk (file_paths) or from in-memory blobs (file_contents)."""

    def get_requested_filename(self):
        return self.path.lstrip('/')

    def do_GET(self):
        body = self.send_head(self.get_requested_filename())
        if body:
            try:
                shutil.copyfileobj(body, self.wfile)
            finally:
                body.close()

    def do_HEAD(self):
        body = self.send_head(self.get_requested_filename())
        if body:
            body.close()

    def get_file_and_size(self, filename):
        """Return (readable, size); (None, 0) when the name is unknown."""
        try:
            path = self.server.file_paths[filename]
            handle = open(path, 'rb')
            size = os.fstat(handle.fileno())[6]
            return (handle, size)
        except KeyError:
            try:
                content = self.server.file_contents[filename]
                return (io.BytesIO(content), len(content))
            except KeyError:
                return (None, 0)

        return (None, 0)

    def send_head(self, filename):
        """Send response headers; return the open body stream, or None on
        404 / ETag match (304)."""
        (body, size) = self.get_file_and_size(filename)

        if not body:
            self.send_error(404, "File %s not found" % filename)
            return None

        if self.server.etag is not None:
            if self.headers.get('If-None-Match', None) == self.server.etag:
                self.send_response(304)
                self.end_headers()
                body.close()
                return None

        self.send_response(200)

        if self.server.etag is not None:
            self.send_header("ETag", self.server.etag)

        self.send_header("Content-type", 'application/octet-stream')
        self.send_header("Content-Length", size)
        self.end_headers()
        return body


class YqlHttpFileServer(socketserver.TCPServer, object):
    """Per-test HTTP file server bound to a port from the YQL port manager.

    Binding is deferred until start() so the port can be configured first.
    """

    def __init__(self):
        self.http_server_port = get_yql_port('YqlHttpFileServer')
        super(YqlHttpFileServer, self).__init__(('', self.http_server_port), YqlHttpRequestHandler,
                                                bind_and_activate=False)
        self.file_contents = {}
        self.file_paths = {}
        # common etag for all resources
        self.etag = None
        self.serve_thread = None

    def start(self):
        """Bind, activate and serve on a background thread."""
        self.allow_reuse_address = True
        self.server_bind()
        self.server_activate()
        self.serve_thread = threading.Thread(target=self.serve_forever)
        self.serve_thread.start()

    def stop(self):
        """Shut down, join the serving thread and release the port."""
        super(YqlHttpFileServer, self).shutdown()
        self.serve_thread.join()
        release_yql_port(self.http_server_port)
        self.http_server_port = None

    def forget_files(self):
        self.register_files({}, {})

    def set_etag(self, newEtag):
        self.etag = newEtag

    def register_new_path(self, key, file_path):
        self.file_paths[key] = file_path
        return self.compose_http_link(key)

    def register_files(self, file_contents, file_paths):
        """Replace the served file sets; return {key: http link}."""
        self.file_contents = file_contents
        self.file_paths = file_paths

        keys = []
        if file_contents:
            keys.extend(file_contents.keys())

        if file_paths:
            keys.extend(file_paths.keys())

        return {k: self.compose_http_link(k) for k in keys}

    def compose_http_link(self, filename):
        return self.compose_http_host() + '/' + filename

    def compose_http_host(self):
        if not self.http_server_port:
            raise Exception('http_server_port is empty. start HTTP server first')

        return 'http://localhost:%d' % self.http_server_port


@pytest.fixture(scope='module')
def yql_http_file_server(request):
    server = YqlHttpFileServer()
    server.start()
    request.addfinalizer(server.stop)
    return server


# === yql/essentials/tests/common/test_framework/yql_ports.py ===
from yatest.common.network import PortManager
import yql_utils

# Lazily created, shared across all helpers in this module.
port_manager = None


def get_yql_port(service='unknown'):
    """Lease one free port, logging which service asked for it."""
    global port_manager

    if port_manager is None:
        port_manager = PortManager()

    port = port_manager.get_port()
    yql_utils.log('get port for service %s: %d' % (service, port))
    return port


def release_yql_port(port):
    if port is None:
        return

    global port_manager
    port_manager.release_port(port)


def get_yql_port_range(service, count):
    """Lease *count* consecutive ports; returns the first port of the range."""
    global port_manager

    if port_manager is None:
        port_manager = PortManager()

    port = port_manager.get_port_range(None, count)
    yql_utils.log('get port range for service %s: start_port: %d, count: %d' % (service, port, count))
    return port


def release_yql_port_range(start_port, count):
    if start_port is None:
        return

    global port_manager
    for port in range(start_port, start_port + count):
        port_manager.release_port(port)
OrderedDict +from functools import partial +import codecs +import decimal +from threading import Lock + +import pytest +import yatest.common +import cyson + +import logging +import getpass + +logger = logging.getLogger(__name__) + +KSV_ATTR = '''{_yql_row_spec={ + Type=[StructType; + [[key;[DataType;String]]; + [subkey;[DataType;String]]; + [value;[DataType;String]]]]}}''' + + +def get_param(name, default=None): + name = 'YQL_' + name.upper() + return yatest.common.get_param(name, os.environ.get(name) or default) + + +def do_custom_query_check(res, sql_query): + custom_check = re.search(r"/\* custom check:(.*)\*/", sql_query) + if not custom_check: + return False + custom_check = custom_check.group(1) + yt_res_yson = res.results + yt_res_yson = cyson.loads(yt_res_yson) if yt_res_yson else cyson.loads("[]") + yt_res_yson = replace_vals(yt_res_yson) + assert eval(custom_check), 'Condition "%(custom_check)s" fails\nResult:\n %(yt_res_yson)s\n' % locals() + return True + + +def get_gateway_cfg_suffix(): + default_suffix = None + return get_param('gateway_config_suffix', default_suffix) or '' + + +def get_gateway_cfg_filename(): + suffix = get_gateway_cfg_suffix() + if suffix == '': + return 'gateways.conf' + else: + return 'gateways-' + suffix + '.conf' + + +def merge_default_gateway_cfg(cfg_dir, gateway_config): + + with open(yql_source_path(os.path.join(cfg_dir, 'gateways.conf'))) as f: + text_format.Merge(f.read(), gateway_config) + + suffix = get_gateway_cfg_suffix() + if suffix: + with open(yql_source_path(os.path.join(cfg_dir, 'gateways-' + suffix + '.conf'))) as f: + text_format.Merge(f.read(), gateway_config) + + +def find_file(path): + arcadia_root = '.' 
+ while '.arcadia.root' not in os.listdir(arcadia_root): + arcadia_root = os.path.join(arcadia_root, '..') + res = os.path.abspath(os.path.join(arcadia_root, path)) + assert os.path.exists(res) + return res + + +output_path_cache = {} + + +def yql_output_path(*args, **kwargs): + if not get_param('LOCAL_BENCH_XX'): + # abspath is needed, because output_path may be relative when test is run directly (without ya make). + return os.path.abspath(yatest.common.output_path(*args, **kwargs)) + + else: + if args and args in output_path_cache: + return output_path_cache[args] + res = os.path.join(tempfile.mkdtemp(prefix='yql_tmp_'), *args) + if args: + output_path_cache[args] = res + return res + + +def yql_binary_path(*args, **kwargs): + if not get_param('LOCAL_BENCH_XX'): + return yatest.common.binary_path(*args, **kwargs) + + else: + return find_file(args[0]) + + +def yql_source_path(*args, **kwargs): + if not get_param('LOCAL_BENCH_XX'): + return yatest.common.source_path(*args, **kwargs) + else: + return find_file(args[0]) + + +def yql_work_path(): + return os.path.abspath('.') + + +YQLExecResult = namedtuple('YQLExecResult', ( + 'std_out', + 'std_err', + 'results', + 'results_file', + 'opt', + 'opt_file', + 'plan', + 'plan_file', + 'program', + 'execution_result', + 'statistics' +)) + +Table = namedtuple('Table', ( + 'name', + 'full_name', + 'content', + 'file', + 'yqlrun_file', + 'attr', + 'format', + 'exists' +)) + + +def new_table(full_name, file_path=None, yqlrun_file=None, content=None, res_dir=None, + attr=None, format_name='yson', def_attr=None, should_exist=False, src_file_alternative=None): + assert '.' 
in full_name, 'expected name like cedar.Input' + name = '.'.join(full_name.split('.')[1:]) + + if res_dir is None: + res_dir = get_yql_dir('table_') + + exists = True + if content is None: + # try read content from files + src_file = file_path or yqlrun_file + if src_file is None: + # nonexistent table, will be output for query + content = '' + exists = False + else: + if os.path.exists(src_file): + with open(src_file, 'rb') as f: + content = f.read() + elif src_file_alternative and os.path.exists(src_file_alternative): + with open(src_file_alternative, 'rb') as f: + content = f.read() + src_file = src_file_alternative + yqlrun_file, src_file_alternative = src_file_alternative, yqlrun_file + else: + content = '' + exists = False + + file_path = os.path.join(res_dir, name + '.txt') + new_yqlrun_file = os.path.join(res_dir, name + '.yqlrun.txt') + + if exists: + with open(file_path, 'wb') as f: + f.write(content) + + # copy or create yqlrun_file in proper dir + if yqlrun_file is not None: + shutil.copyfile(yqlrun_file, new_yqlrun_file) + else: + with open(new_yqlrun_file, 'wb') as f: + f.write(content) + else: + assert not should_exist, locals() + + if attr is None: + # try read from file + attr_file = None + if os.path.exists(file_path + '.attr'): + attr_file = file_path + '.attr' + elif yqlrun_file is not None and os.path.exists(yqlrun_file + '.attr'): + attr_file = yqlrun_file + '.attr' + elif src_file_alternative is not None and os.path.exists(src_file_alternative + '.attr'): + attr_file = src_file_alternative + '.attr' + + if attr_file is not None: + with open(attr_file) as f: + attr = f.read() + + if attr is None: + attr = def_attr + + if attr is not None: + # probably we get it, now write attr file to proper place + attr_file = new_yqlrun_file + '.attr' + with open(attr_file, 'w') as f: + f.write(attr) + + return Table( + name, + full_name, + content, + file_path, + new_yqlrun_file, + attr, + format_name, + exists + ) + + +def ensure_dir_exists(dir): + # 
handle race between isdir and mkdir + if os.path.isdir(dir): + return + + try: + os.mkdir(dir) + except OSError: + if not os.path.isdir(dir): + raise + + +def get_yql_dir(prefix): + yql_dir = yql_output_path('yql') + ensure_dir_exists(yql_dir) + res_dir = tempfile.mkdtemp(prefix=prefix, dir=yql_dir) + os.chmod(res_dir, 0o755) + return res_dir + + +def get_cmd_for_files(arg, files): + cmd = ' '.join( + arg + ' ' + name + '@' + files[name] + for name in files + ) + cmd += ' ' + return cmd + + +def read_res_file(file_path): + if os.path.exists(file_path): + with codecs.open(file_path, encoding="utf-8") as descr: + res = descr.read().strip() + if res == '': + log_res = '<EMPTY>' + else: + log_res = res + else: + res = '' + log_res = '<NOTHING>' + return res, log_res + + +def normalize_yson(y): + from cyson import YsonBoolean, YsonEntity + if isinstance(y, YsonBoolean) or isinstance(y, bool): + return 'true' if y else 'false' + if isinstance(y, YsonEntity) or y is None: + return None + if isinstance(y, list): + return [normalize_yson(i) for i in y] + if isinstance(y, dict): + return {normalize_yson(k): normalize_yson(v) for k, v in six.iteritems(y)} + s = str(y) if not isinstance(y, six.text_type) else y.encode('utf-8', errors='xmlcharrefreplace') + return s + + +volatile_attrs = {'DataSize', 'ModifyTime', 'Id', 'Revision'} +current_user = getpass.getuser() + + +def _replace_vals_impl(y): + if isinstance(y, list): + return [_replace_vals_impl(i) for i in y] + if isinstance(y, dict): + return {_replace_vals_impl(k): _replace_vals_impl(v) for k, v in six.iteritems(y) if k not in volatile_attrs} + if isinstance(y, str): + s = y.replace('tmp/yql/' + current_user + '/', 'tmp/') + s = re.sub(r'tmp/[0-9a-f]+-[0-9a-f]+-[0-9a-f]+-[0-9a-f]+', 'tmp/<temp_table_guid>', s) + return s + return y + + +def replace_vals(y): + y = normalize_yson(y) + y = _replace_vals_impl(y) + return y + + +def patch_yson_vals(y, patcher): + if isinstance(y, list): + return [patch_yson_vals(i, patcher) 
for i in y] + if isinstance(y, dict): + return {patch_yson_vals(k, patcher): patch_yson_vals(v, patcher) for k, v in six.iteritems(y)} + if isinstance(y, str): + return patcher(y) + return y + + +floatRe = re.compile(r'^-?\d*\.\d+$') +floatERe = re.compile(r'^-?(\d*\.)?\d+e[\+\-]?\d+$', re.IGNORECASE) +specFloatRe = re.compile(r'^(-?inf|nan)$', re.IGNORECASE) + + +def fix_double(x): + if floatRe.match(x) and len(x.replace('.', '').replace('-', '')) > 10: + # Emulate the same double precision as C++ code has + decimal.getcontext().rounding = decimal.ROUND_HALF_DOWN + decimal.getcontext().prec = 10 + return str(decimal.Decimal(0) + decimal.Decimal(x)).rstrip('0') + if floatERe.match(x): + # Emulate the same double precision as C++ code has + decimal.getcontext().rounding = decimal.ROUND_HALF_DOWN + decimal.getcontext().prec = 10 + return str(decimal.Decimal(0) + decimal.Decimal(x)).lower() + if specFloatRe.match(x): + return x.lower() + return x + + +def remove_volatile_ast_parts(ast): + return re.sub(r"\(KiClusterConfig '\('\(.*\) '\"\d\" '\"\d\" '\"\d\"\)\)", "(KiClusterConfig)", ast) + + +def prepare_program(program, program_file, yql_dir, ext='yql'): + assert not (program is None and program_file is None), 'Needs program or program_file' + + if program is None: + with codecs.open(program_file, encoding='utf-8') as program_file_descr: + program = program_file_descr.read() + + program_file = os.path.join(yql_dir, 'program.' 
def get_program_cfg(suite, case, DATA_PATH):
    """Load the per-case config (<case>.cfg, falling back to default.cfg).

    Returns a list of whitespace-split token tuples; when no cfg exists, tries
    to pick up a <case>.in / default.in input table instead. Skips the test if
    the cfg restricts it to another OS.
    """
    ret = []
    config = os.path.join(DATA_PATH, suite if suite else '', case + '.cfg')
    if not os.path.exists(config):
        config = os.path.join(DATA_PATH, suite if suite else '', 'default.cfg')

    if os.path.exists(config):
        # 'with' guarantees the handle is closed (the original leaked it).
        with open(config, 'r') as cfg_file:
            for line in cfg_file:
                if line.strip():
                    ret.append(tuple(line.split()))
    else:
        in_filename = case + '.in'
        in_path = os.path.join(DATA_PATH, in_filename)
        default_filename = 'default.in'
        default_path = os.path.join(DATA_PATH, default_filename)
        for filepath in [in_path, in_filename, default_path, default_filename]:
            if os.path.exists(filepath):
                try:
                    shutil.copy2(filepath, in_path)
                except shutil.Error:
                    # source and destination may be the same file
                    pass
                ret.append(('in', 'yamr.plato.Input', in_path))
                break

    if not is_os_supported(ret):
        pytest.skip('%s not supported here' % sys.platform)

    return ret


def find_user_file(suite, path, DATA_PATH):
    """Resolve *path* relative to the suite data dir, falling back to binary path."""
    source_path = os.path.join(DATA_PATH, suite, path)
    if os.path.exists(source_path):
        return source_path
    else:
        try:
            return yql_binary_path(path)
        except Exception:
            raise Exception('Can not find file ' + path)


def get_input_tables(suite, cfg, DATA_PATH, def_attr=None):
    """Build the list of input Table objects from 'in' lines of the cfg."""
    in_tables = []
    for item in cfg:
        if item[0] in ('in', 'out'):
            io, table_name, file_name = item
            if io == 'in':
                in_tables.append(new_table(
                    full_name=table_name.replace('yamr.', '').replace('yt.', ''),
                    yqlrun_file=os.path.join(DATA_PATH, suite if suite else '', file_name),
                    src_file_alternative=os.path.join(yql_work_path(), suite if suite else '', file_name),
                    def_attr=def_attr,
                    should_exist=True
                ))
    return in_tables


def get_tables(suite, cfg, DATA_PATH, def_attr=None):
    """Split cfg 'in'/'out' declarations into (in_tables, out_tables).

    3-token lines default to yson format; 4-token lines carry an explicit format.
    """
    in_tables = []
    out_tables = []
    suite_dir = os.path.join(DATA_PATH, suite)
    res_dir = get_yql_dir('table_')

    for splitted in cfg:
        if splitted[0] == 'udf' and yatest.common.context.sanitize == 'undefined':
            pytest.skip("udf under ubsan")

        if len(splitted) == 4:
            type_name, table, file_name, format_name = splitted
        elif len(splitted) == 3:
            type_name, table, file_name = splitted
            format_name = 'yson'
        else:
            continue
        yqlrun_file = os.path.join(suite_dir, file_name)
        if type_name == 'in':
            in_tables.append(new_table(
                full_name='plato.' + table if '.' not in table else table,
                yqlrun_file=yqlrun_file,
                format_name=format_name,
                def_attr=def_attr,
                res_dir=res_dir
            ))
        if type_name == 'out':
            out_tables.append(new_table(
                full_name='plato.' + table if '.' not in table else table,
                yqlrun_file=yqlrun_file if os.path.exists(yqlrun_file) else None,
                res_dir=res_dir
            ))
    return in_tables, out_tables


def get_supported_providers(cfg):
    """Providers listed on a 'providers' cfg line, or the full default tuple."""
    providers = 'yt', 'kikimr', 'dq', 'hybrid'
    for item in cfg:
        if item[0] == 'providers':
            providers = [i.strip() for i in ''.join(item[1:]).split(',')]
    return providers


def is_os_supported(cfg):
    """True unless an 'os' cfg line excludes the current platform."""
    for item in cfg:
        if item[0] == 'os':
            return any(sys.platform.startswith(_os) for _os in item[1].split(','))
    return True


def _cfg_has_flag(cfg, flag):
    """True iff some cfg line's first token equals *flag* (shared by the is_* family)."""
    for item in cfg:
        if item[0] == flag:
            return True
    return False


def is_xfail(cfg):
    return _cfg_has_flag(cfg, 'xfail')


def is_skip_forceblocks(cfg):
    return _cfg_has_flag(cfg, 'skip_forceblocks')


def is_canonize_peephole(cfg):
    return _cfg_has_flag(cfg, 'canonize_peephole')


def is_peephole_use_blocks(cfg):
    return _cfg_has_flag(cfg, 'peephole_use_blocks')


def is_canonize_lineage(cfg):
    return _cfg_has_flag(cfg, 'canonize_lineage')


def is_canonize_yt(cfg):
    return _cfg_has_flag(cfg, 'canonize_yt')


def is_with_final_result_issues(cfg):
    return _cfg_has_flag(cfg, 'with_final_result_issues')


def skip_test_if_required(cfg):
    """pytest.skip with the given reason when a 'skip_test' cfg line is present."""
    for item in cfg:
        if item[0] == 'skip_test':
            pytest.skip(item[1])


def get_pragmas(cfg):
    """Collect 'pragma ...' cfg lines re-joined as single strings."""
    pragmas = []
    for item in cfg:
        if item[0] == 'pragma':
            pragmas.append(' '.join(item))
    return pragmas


def execute(
    klass=None,
    program=None,
    program_file=None,
    files=None,
    urls=None,
    run_sql=False,
    verbose=False,
    check_error=True,
    input_tables=None,
    output_tables=None,
    pretty_plan=True,
    parameters=None,
):
    '''
    Executes YQL/SQL

    :param klass: KiKiMRForYQL if instance (default: YQLRun)
    :param program: string with YQL or SQL program
    :param program_file: file with YQL or SQL program (optional, if :param program: is None)
    :param files: dict like {'name': '/path'} with extra files
    :param urls: dict like {'name': url} with extra files urls
    :param run_sql: execute sql instead of yql
    :param verbose: log all results and diagnostics
    :param check_error: fail on non-zero exit code
    :param input_tables: list of Table (will be written if not exist)
    :param output_tables: list of Table (will be returned)
    :param pretty_plan: whether to use pretty printing for plan or not
    :param parameters: query parameters as dict like {name: json_value}
    :return: YQLExecResult
    '''
    # None-defaults instead of mutable {} / [] defaults (shared-state hazard).
    if parameters is None:
        parameters = {}
    if input_tables is None:
        input_tables = []
    else:
        assert isinstance(input_tables, list)
    if output_tables is None:
        output_tables = []

    klass.write_tables(input_tables + output_tables)

    res = klass.yql_exec(
        program=program,
        program_file=program_file,
        files=files,
        urls=urls,
        run_sql=run_sql,
        verbose=verbose,
        check_error=check_error,
        tables=(output_tables + input_tables),
        pretty_plan=pretty_plan,
        parameters=parameters
    )

    try:
        res_tables = klass.get_tables(output_tables)
    except Exception:
        # best-effort: table readback failures only matter when check_error is set
        if check_error:
            raise
        res_tables = {}

    return res, res_tables


execute_sql = partial(execute, run_sql=True)


def log(s):
    """Log to stderr when the STDERR param is set (interactive runs), else to logger."""
    if get_param('STDERR'):
        print(s, file=sys.stderr)
    else:
        logger.debug(s)
def tmpdir_module(request):
    """Plain helper behind the fixture: one fresh temp dir per test module."""
    return tempfile.mkdtemp(prefix='kikimr_test_')


@pytest.fixture(name='tmpdir_module', scope='module')
def tmpdir_module_fixture(request):
    return tmpdir_module(request)


def escape_backslash(s):
    """Double every backslash (for embedding Windows paths in config text)."""
    return s.replace('\\', '\\\\')


def get_default_mount_point_config_content():
    """Mount-point config mapping /lib onto the source-tree yql library dir."""
    return '''
    MountPoints {
        RootAlias: '/lib'
        MountPoint: '%s'
        Library: true
    }
    ''' % (
        escape_backslash(yql_source_path('yql/essentials/mount/lib'))
    )


def get_mount_config_file(content=None):
    """Write (once) and return the shared mount.cfg path."""
    config = yql_output_path('mount.cfg')
    if not os.path.exists(config):
        with open(config, 'w') as f:
            f.write(content or get_default_mount_point_config_content())
    return config


def run_command(program, cmd, tmpdir_module=None, stdin=None,
                check_exit_code=True, env=None, stdout=None):
    """Run *program* with *cmd* args via yatest.common.execute.

    *stdin* may be a string (spooled through a temp file), an open stream, or None.
    Fixed: the temp stdin stream opened here was never closed (fd leak).
    """
    if tmpdir_module is None:
        tmpdir_module = tempfile.mkdtemp()

    stdin_stream = None
    opened_stdin = False  # only close streams we created ourselves
    if isinstance(stdin, six.string_types):
        with tempfile.NamedTemporaryFile(
            prefix='stdin_',
            dir=tmpdir_module,
            delete=False
        ) as stdin_file:
            stdin_file.write(stdin.encode() if isinstance(stdin, str) else stdin)
        stdin_stream = open(stdin_file.name)
        opened_stdin = True
    elif isinstance(stdin, io.IOBase):
        stdin_stream = stdin
    elif stdin is not None:
        assert 0, 'Strange stdin ' + repr(stdin)

    if isinstance(cmd, six.string_types):
        cmd = cmd.split()
    else:
        cmd = [str(c) for c in cmd]
    log(' '.join('\'%s\'' % c if ' ' in c else c for c in cmd))
    cmd = [program] + cmd

    stderr_stream = None
    stdout_stream = stdout if stdout else None

    try:
        res = yatest.common.execute(
            cmd,
            cwd=tmpdir_module,
            stdin=stdin_stream,
            stdout=stdout_stream,
            stderr=stderr_stream,
            check_exit_code=check_exit_code,
            env=env,
            wait=True
        )
    finally:
        if opened_stdin:
            stdin_stream.close()

    if res.std_err:
        log(res.std_err)
    if res.std_out:
        log(res.std_out)
    return res


def yson_to_csv(yson_content, columns=None, with_header=True, strict=False):
    """Render a YSON list of maps as ';'-separated CSV text.

    NOTE(review): 'string_escape' is a Python-2-only codec; this path will raise
    on Python 3 — confirm whether py3 callers ever reach it.
    """
    import cyson as yson
    if columns:
        headers = sorted(columns)
    else:
        headers = set()
        for item in yson.loads(yson_content):
            headers.update(six.iterkeys(item))
        headers = sorted(headers)
    csv_content = []
    if with_header:
        csv_content.append(';'.join(headers))
    for item in yson.loads(yson_content):
        if strict and sorted(six.iterkeys(item)) != headers:
            return None
        csv_content.append(';'.join([str(item[h]).replace('YsonEntity', '').encode('string_escape') if h in item else '' for h in headers]))
    return '\n'.join(csv_content)


udfs_lock = Lock()


def get_udfs_path(extra_paths=None):
    """Merge every known UDF build/bin dir into one symlink farm and return it.

    Thread-safe via udfs_lock; broken symlinks left by previous runs are pruned.
    """
    essentials_udfs_build_path = yatest.common.build_path('yql/essentials/udfs')
    udfs_build_path = yatest.common.build_path('yql/udfs')
    ydb_udfs_build_path = yatest.common.build_path('contrib/ydb/library/yql/udfs')
    contrib_ydb_udfs_build_path = yatest.common.build_path('contrib/ydb/library/yql/udfs')
    rthub_udfs_build_path = yatest.common.build_path('robot/rthub/yql/udfs')
    kwyt_udfs_build_path = yatest.common.build_path('robot/kwyt/yql/udfs')

    try:
        udfs_bin_path = yatest.common.binary_path('yql/udfs')
    except Exception:
        udfs_bin_path = None

    try:
        udfs_project_path = yql_binary_path('yql/library/test_framework/udfs_deps')
    except Exception:
        udfs_project_path = None

    try:
        ydb_udfs_project_path = yql_binary_path('yql/essentials/tests/common/test_framework/udfs_deps')
    except Exception:
        ydb_udfs_project_path = None

    merged_udfs_path = yql_output_path('yql_udfs')
    with udfs_lock:
        if not os.path.isdir(merged_udfs_path):
            os.mkdir(merged_udfs_path)

        udfs_paths = [
            udfs_project_path,
            ydb_udfs_project_path,
            udfs_bin_path,
            essentials_udfs_build_path,
            udfs_build_path,
            ydb_udfs_build_path,
            contrib_ydb_udfs_build_path,
            rthub_udfs_build_path,
            kwyt_udfs_build_path
        ]
        if extra_paths is not None:
            udfs_paths += extra_paths

        log('process search UDF in: %s, %s, %s, %s' % (udfs_project_path, ydb_udfs_project_path, udfs_bin_path, udfs_build_path))
        for _udfs_path in udfs_paths:
            if _udfs_path:
                for dirpath, dnames, fnames in os.walk(_udfs_path):
                    for f in fnames:
                        if f.endswith('.so'):
                            f = os.path.join(dirpath, f)
                            if not os.path.exists(f) and os.path.lexists(f):  # seems like broken symlink
                                try:
                                    os.unlink(f)
                                except OSError:
                                    pass
                            link_name = os.path.join(merged_udfs_path, os.path.basename(f))
                            if not os.path.exists(link_name):
                                os.symlink(f, link_name)
                                log('Added UDF: ' + f)
        return merged_udfs_path
def get_test_prefix():
    """Stable per-test temp prefix derived from the md5 of the test name.

    Fixed: hashlib.md5 requires bytes on Python 3; the test name is encoded first.
    """
    test_name = yatest.common.context.test_name
    if not isinstance(test_name, bytes):
        test_name = test_name.encode('utf-8')
    return 'yql_tmp_' + hashlib.md5(test_name).hexdigest()


def normalize_plan_ids(plan, no_detailed=False):
    """Renumber plan node ids deterministically (inputs first, by name) so plans canonize."""
    remapOps = {}

    # Input nodes get the smallest ids, in name order, for stability.
    for node in sorted(filter(lambda n: n["type"] == "in", plan["Basic"]["nodes"]), key=lambda n: n.get("name")):
        if node["id"] not in remapOps:
            remapOps[node["id"]] = len(remapOps) + 1

    for node in plan["Basic"]["nodes"]:
        if node["id"] not in remapOps:
            remapOps[node["id"]] = len(remapOps) + 1

    def subst_basic(y):
        # Rewrite ids and sort links/nodes for a canonical ordering.
        if isinstance(y, list):
            return [subst_basic(i) for i in y]
        if isinstance(y, dict):
            res = {}
            for k, v in six.iteritems(y):
                if k in {'source', 'target', 'id'}:
                    res[k] = remapOps.get(v)
                elif k == "links":
                    res[k] = sorted(subst_basic(v), key=lambda x: (x["source"], x["target"]))
                elif k == "nodes":
                    res[k] = sorted(subst_basic(v), key=lambda x: x["id"])
                else:
                    res[k] = subst_basic(v)
            return res
        return y

    # Sort and normalize input ids
    def subst_detailed(y):
        if isinstance(y, list):
            return [subst_detailed(i) for i in y]
        if isinstance(y, dict):
            res = {}
            for k, v in six.iteritems(y):
                if k == "DependsOn":
                    res[k] = sorted([remapOps.get(i) for i in v])
                elif k == "Providers":
                    res[k] = v
                elif k in {'OperationRoot', 'Id'}:
                    res[k] = remapOps.get(v)
                else:
                    res[k] = subst_detailed(v)
            return res
        return y

    if no_detailed:
        return {"Basic": subst_basic(plan["Basic"])}
    return {"Basic": subst_basic(plan["Basic"]), "Detailed": subst_detailed(plan["Detailed"])}


def normalized_plan_stats(plan):
    """Aggregate operation stats under provider-neutral (Yt*) operation names."""
    renameMap = {
        "MrLMap!": "YtMap!",
        "MrMapReduce!": "YtMapReduce!",
        "MrLReduce!": "YtMapReduce!",
        "MrOrderedReduce!": "YtReduce!",
        "MrSort!": "YtSort!",
        "MrCopy!": "YtCopy!",
        "YtMerge!": "YtCopy!",
        "MrFill!": "YtFill!",
        "MrDrop!": "YtDropTable!",
        "YtTouch!": None,       # None => drop the op from the stats entirely
        "MrReadTable!": None,
        "YtReadTable!": None,
        "MrPublish!": "YtPublish!",
        "MrReadTableScheme!": "YtReadTableScheme!",
    }

    normalizedStat = defaultdict(int)

    for op, stat in six.iteritems(plan["Detailed"]["OperationStats"]):
        renamedOp = renameMap.get(op, op)
        if renamedOp is not None:
            normalizedStat[renamedOp] += stat

    return normalizedStat


def normalize_table_yson(y):
    """Drop Void/entity values and order map keys (reverse) for canonical table dumps."""
    from cyson import YsonEntity
    if isinstance(y, list):
        return [normalize_table_yson(i) for i in y]
    if isinstance(y, dict):
        normDict = OrderedDict()
        for k, v in sorted(six.iteritems(y), key=lambda x: x[0], reverse=True):
            if k == "_other":
                normDict[normalize_table_yson(k)] = sorted(normalize_table_yson(v))
            elif v != "Void" and v is not None and not isinstance(v, YsonEntity):
                normDict[normalize_table_yson(k)] = normalize_table_yson(v)
        return normDict
    return y


def dump_table_yson(res_yson, sort=True):
    """Pretty-print a table's rows as canonical YSON (optionally row-sorted)."""
    rows = normalize_table_yson(cyson.loads('[' + res_yson + ']'))
    if sort:
        rows = sorted(rows)
    return cyson.dumps(rows, format="pretty")


def normalize_source_code_path(s):
    """Mask volatile pieces of source paths in diagnostics (contrib/ prefix, line numbers)."""
    # remove contrib/
    s = re.sub(r'\b(contrib/)(ydb/library/yql.*)', r'\2', s)
    # replace line number in source code with 'xxx'
    s = re.sub(r'\b(yql/[\w/]+(?:\.cpp|\.h)):(?:\d+)', r'\1:xxx', s)
    return re.sub(r'(/lib/yql/[\w/]+(?:\.yql|\.sql)):(?:\d+):(?:\d+)', r'\1:xxx:yyy', s)


def do_get_files(suite, config, DATA_PATH, config_key):
    """Copy cfg-declared attachments into a fresh dir; returns {name: local_path}."""
    files = dict()
    suite_dir = os.path.join(DATA_PATH, suite)
    res_dir = None
    for line in config:
        if line[0] == config_key:
            _, name, path = line
            userpath = find_user_file(suite, path, DATA_PATH)
            relpath = os.path.relpath(userpath, suite_dir)
            if os.path.exists(os.path.join('cwd', relpath)):
                path = relpath
            else:
                path = userpath

            # Lazily create the target dir only when the first file shows up.
            if not res_dir:
                res_dir = get_yql_dir('file_')

            new_path = os.path.join(res_dir, os.path.basename(path))
            shutil.copyfile(path, new_path)

            files[name] = new_path

    return files
def get_files(suite, config, DATA_PATH):
    """Attachments declared with 'file' cfg lines."""
    return do_get_files(suite, config, DATA_PATH, 'file')


def get_http_files(suite, config, DATA_PATH):
    """Attachments declared with 'http_file' cfg lines."""
    return do_get_files(suite, config, DATA_PATH, 'http_file')


def get_yt_files(suite, config, DATA_PATH):
    """Attachments declared with 'yt_file' cfg lines."""
    return do_get_files(suite, config, DATA_PATH, 'yt_file')


def get_syntax_version(program):
    """Syntax version: explicit in-program pragma wins, then SYNTAX_VERSION param, then 1."""
    syntax_version_param = get_param('SYNTAX_VERSION')
    default_syntax_version = 1
    if 'syntax version 0' in program:
        return 0
    elif 'syntax version 1' in program:
        return 1
    elif syntax_version_param:
        return int(syntax_version_param)
    else:
        return default_syntax_version


def ansi_lexer_enabled(program):
    """True when the program text mentions the ansi_lexer pragma."""
    return 'ansi_lexer' in program


def pytest_get_current_part(path):
    """Return (current_part_index, total_parts) from a .../partN/... test path."""
    folder = os.path.dirname(path)
    folder_name = os.path.basename(folder)
    assert folder_name.startswith('part'), "Current folder is {}".format(folder_name)
    current = int(folder_name[len('part'):])

    parent = os.path.dirname(folder)
    maxpart = max([int(part[len('part'):]) if part.startswith('part') else -1 for part in os.listdir(parent)])
    assert maxpart > 0, "Cannot find parts in {}".format(parent)
    return (current, 1 + maxpart)


def normalize_result(res, sort):
    """Parse an exec result, mask volatile values, optionally sort Data, drop Refs."""
    res = cyson.loads(res) if res else cyson.loads("[]")
    res = replace_vals(res)
    for r in res:
        for data in r['Write']:
            if sort and 'Data' in data:
                data['Data'] = sorted(data['Data'])
            if 'Ref' in data:
                data['Ref'] = []
                data['Truncated'] = True
            if 'Data' in data and len(data['Data']) == 0:
                del data['Data']
    return res


def stable_write(writer, node):
    """Write a YSON node with attributes and map keys in sorted (stable) order."""
    if hasattr(node, 'attributes'):
        writer.begin_attributes()
        for k in sorted(node.attributes.keys()):
            writer.key(k)
            stable_write(writer, node.attributes[k])
        writer.end_attributes()
    if isinstance(node, list):
        writer.begin_list()
        for r in node:
            stable_write(writer, r)
        writer.end_list()
        return
    if isinstance(node, dict):
        writer.begin_map()
        for k in sorted(node.keys()):
            writer.key(k)
            stable_write(writer, node[k])
        writer.end_map()
        return
    writer.write(node)


def stable_result_file(res):
    """Rewrite a results file in canonical (sorted, pretty) form and return its text."""
    path = res.results_file
    assert os.path.exists(path)
    with open(path) as f:
        res = f.read()
    res = cyson.loads(res)
    res = replace_vals(res)
    for r in res:
        for data in r['Write']:
            if 'Unordered' in r and 'Data' in data:
                data['Data'] = sorted(data['Data'])
    with open(path, 'w') as f:
        writer = cyson.Writer(stream=cyson.OutputStream.from_file(f), format='pretty', mode='node')
        writer.begin_stream()
        stable_write(writer, res)
        writer.end_stream()
    with open(path) as f:
        return f.read()


def stable_table_file(table):
    """Sort an unsorted table file's rows in place (pretty YSON) and return its text."""
    path = table.file
    assert os.path.exists(path)
    assert table.attr is not None
    is_sorted = False
    for column in cyson.loads(table.attr)['schema']:
        if 'sort_order' in column:
            is_sorted = True
            break
    if not is_sorted:
        with open(path) as f:
            r = cyson.Reader(cyson.InputStream.from_file(f), mode='list_fragment')
            lst = sorted(list(r.list_fragments()))
        with open(path, 'w') as f:
            writer = cyson.Writer(stream=cyson.OutputStream.from_file(f), format='pretty', mode='list_fragment')
            writer.begin_stream()
            for r in lst:
                stable_write(writer, r)
            writer.end_stream()
    with open(path) as f:
        return f.read()


class LoggingDowngrade(object):
    """Temporarily raise the level of the given loggers (default: CRITICAL = silence).

    Original levels are captured at construction and restored on exit.
    """

    def __init__(self, loggers, level=logging.CRITICAL):
        # (name, original effective level) pairs for restoration on __exit__.
        self.loggers = [(name, logging.getLogger(name).getEffectiveLevel()) for name in loggers]
        self.level = level

    def __enter__(self):
        # (removed dead self.prev_levels list that was built but never read)
        for name, _ in self.loggers:
            logging.getLogger(name).setLevel(self.level)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        for name, level in self.loggers:
            logging.getLogger(name).setLevel(level)
        # NOTE(review): returning True suppresses ANY exception raised inside the
        # 'with' body; preserved as-is, but confirm this silencing is intentional.
        return True
log.setLevel(self.level) + return self + + def __exit__(self, exc_type, exc_value, tb): + for name, level in self.loggers: + log = logging.getLogger(name) + log.setLevel(level) + return True diff --git a/yql/essentials/tests/common/test_framework/yqlrun.py b/yql/essentials/tests/common/test_framework/yqlrun.py new file mode 100644 index 0000000000..b96641912a --- /dev/null +++ b/yql/essentials/tests/common/test_framework/yqlrun.py @@ -0,0 +1,346 @@ +import os +import shutil +import yatest.common +import yql_utils +import cyson as yson +import yql.essentials.providers.common.proto.gateways_config_pb2 as gateways_config_pb2 +import yql.essentials.core.file_storage.proto.file_storage_pb2 as file_storage_pb2 + +import six + +from google.protobuf import text_format + +ARCADIA_PREFIX = 'arcadia/' +ARCADIA_TESTS_DATA_PREFIX = 'arcadia_tests_data/' + +VAR_CHAR_PREFIX = '$' +FIX_DIR_PREFIXES = { + 'SOURCE': yatest.common.source_path, + 'BUILD': yatest.common.build_path, + 'TEST_SOURCE': yatest.common.test_source_path, + 'DATA': yatest.common.data_path, + 'BINARY': yatest.common.binary_path, +} + + +class YQLRun(object): + + def __init__(self, udfs_dir=None, prov='yt', use_sql2yql=False, keep_temp=True, binary=None, gateway_config=None, fs_config=None, extra_args=[], cfg_dir=None, support_udfs=True): + if binary is None: + self.yqlrun_binary = yql_utils.yql_binary_path(os.getenv('YQL_YQLRUN_PATH') or 'contrib/ydb/library/yql/tools/yqlrun/yqlrun') + else: + self.yqlrun_binary = binary + self.extra_args = extra_args + + try: + self.sql2yql_binary = yql_utils.yql_binary_path(os.getenv('YQL_SQL2YQL_PATH') or 'yql/essentials/tools/sql2yql/sql2yql') + except BaseException: + self.sql2yql_binary = None + + try: + self.udf_resolver_binary = yql_utils.yql_binary_path(os.getenv('YQL_UDFRESOLVER_PATH') or 'yql/essentials/tools/udf_resolver/udf_resolver') + except Exception: + self.udf_resolver_binary = None + + if support_udfs: + if udfs_dir is None: + self.udfs_path = 
yql_utils.get_udfs_path() + else: + self.udfs_path = udfs_dir + else: + self.udfs_path = None + + res_dir = yql_utils.get_yql_dir(prefix='yqlrun_') + self.res_dir = res_dir + self.tables = {} + self.prov = prov + self.use_sql2yql = use_sql2yql + self.keep_temp = keep_temp + + self.gateway_config = gateways_config_pb2.TGatewaysConfig() + if gateway_config is not None: + text_format.Merge(gateway_config, self.gateway_config) + + yql_utils.merge_default_gateway_cfg(cfg_dir or 'yql/essentials/cfg/tests', self.gateway_config) + + self.fs_config = file_storage_pb2.TFileStorageConfig() + + with open(yql_utils.yql_source_path(os.path.join(cfg_dir or 'yql/essentials/cfg/tests', 'fs.conf'))) as f: + text_format.Merge(f.read(), self.fs_config) + + if fs_config is not None: + text_format.Merge(fs_config, self.fs_config) + + if yql_utils.get_param('USE_NATIVE_YT_TYPES'): + attr = self.gateway_config.Yt.DefaultSettings.add() + attr.Name = 'UseNativeYtTypes' + attr.Value = 'true' + + if yql_utils.get_param('SQL_FLAGS'): + flags = yql_utils.get_param('SQL_FLAGS').split(',') + self.gateway_config.SqlCore.TranslationFlags.extend(flags) + + def yql_exec(self, program=None, program_file=None, files=None, urls=None, + run_sql=False, verbose=False, check_error=True, tables=None, pretty_plan=True, + wait=True, parameters={}, extra_env={}, require_udf_resolver=False, scan_udfs=True): + del pretty_plan + + res_dir = self.res_dir + + def res_file_path(name): + return os.path.join(res_dir, name) + + opt_file = res_file_path('opt.yql') + results_file = res_file_path('results.txt') + plan_file = res_file_path('plan.txt') + err_file = res_file_path('err.txt') + + udfs_dir = self.udfs_path + prov = self.prov + + program, program_file = yql_utils.prepare_program(program, program_file, res_dir, + ext='sql' if run_sql else 'yql') + + syntax_version = yql_utils.get_syntax_version(program) + ansi_lexer = yql_utils.ansi_lexer_enabled(program) + + if run_sql and self.use_sql2yql: + orig_sql = 
program_file + '.orig_sql' + shutil.copy2(program_file, orig_sql) + cmd = [ + self.sql2yql_binary, + orig_sql, + '--yql', + '--output=' + program_file, + '--syntax-version=%d' % syntax_version + ] + if ansi_lexer: + cmd.append('--ansi-lexer') + env = {'YQL_DETERMINISTIC_MODE': '1'} + env.update(extra_env) + for var in [ + 'LLVM_PROFILE_FILE', + 'GO_COVERAGE_PREFIX', + 'PYTHON_COVERAGE_PREFIX', + 'NLG_COVERAGE_FILENAME', + 'YQL_EXPORT_PG_FUNCTIONS_DIR', + 'YQL_ALLOW_ALL_PG_FUNCTIONS', + ]: + if var in os.environ: + env[var] = os.environ[var] + yatest.common.process.execute(cmd, cwd=res_dir, env=env) + + with open(program_file) as f: + yql_program = f.read() + with open(program_file, 'w') as f: + f.write(yql_program) + + gateways_cfg_file = res_file_path('gateways.conf') + with open(gateways_cfg_file, 'w') as f: + f.write(str(self.gateway_config)) + + fs_cfg_file = res_file_path('fs.conf') + with open(fs_cfg_file, 'w') as f: + f.write(str(self.fs_config)) + + cmd = self.yqlrun_binary + ' ' + + if yql_utils.get_param('TRACE_OPT'): + cmd += '--trace-opt ' + + cmd += '-L ' \ + '--program=%(program_file)s ' \ + '--expr-file=%(opt_file)s ' \ + '--result-file=%(results_file)s ' \ + '--plan-file=%(plan_file)s ' \ + '--err-file=%(err_file)s ' \ + '--gateways=%(prov)s ' \ + '--syntax-version=%(syntax_version)d ' \ + '--tmp-dir=%(res_dir)s ' \ + '--gateways-cfg=%(gateways_cfg_file)s ' \ + '--fs-cfg=%(fs_cfg_file)s ' % locals() + + if self.udfs_path is not None: + cmd += '--udfs-dir=%(udfs_dir)s ' % locals() + + if ansi_lexer: + cmd += '--ansi-lexer ' + + if self.keep_temp: + cmd += '--keep-temp ' + + if self.extra_args: + cmd += " ".join(self.extra_args) + " " + + cmd += '--mounts=' + yql_utils.get_mount_config_file() + ' ' + cmd += '--validate-result-format ' + + if files: + for f in files: + if files[f].startswith(ARCADIA_PREFIX): # how does it work with folders? and does it? 
+ files[f] = yatest.common.source_path(files[f][len(ARCADIA_PREFIX):]) + continue + if files[f].startswith(ARCADIA_TESTS_DATA_PREFIX): + files[f] = yatest.common.data_path(files[f][len(ARCADIA_TESTS_DATA_PREFIX):]) + continue + + if files[f].startswith(VAR_CHAR_PREFIX): + for prefix, func in six.iteritems(FIX_DIR_PREFIXES): + if files[f].startswith(VAR_CHAR_PREFIX + prefix): + real_path = func(files[f][len(prefix) + 2:]) # $ + prefix + / + break + else: + raise Exception("unknown prefix in file path %s" % (files[f],)) + copy_dest = os.path.join(res_dir, f) + if not os.path.exists(os.path.dirname(copy_dest)): + os.makedirs(os.path.dirname(copy_dest)) + shutil.copy2( + real_path, + copy_dest, + ) + files[f] = f + continue + + if not files[f].startswith('/'): # why do we check files[f] instead of f here? + path_to_copy = os.path.join( + yatest.common.work_path(), + files[f] + ) + if '/' in files[f]: + copy_dest = os.path.join( + res_dir, + os.path.dirname(files[f]) + ) + if not os.path.exists(copy_dest): + os.makedirs(copy_dest) + else: + copy_dest = res_dir + files[f] = os.path.basename(files[f]) + shutil.copy2(path_to_copy, copy_dest) + else: + shutil.copy2(files[f], res_dir) + files[f] = os.path.basename(files[f]) + cmd += yql_utils.get_cmd_for_files('--file', files) + + if urls: + cmd += yql_utils.get_cmd_for_files('--url', urls) + + optimize_only = False + if tables: + for table in tables: + self.tables[table.full_name] = table + if table.format != 'yson': + optimize_only = True + for name in self.tables: + cmd += '--table=yt.%s@%s ' % (name, self.tables[name].yqlrun_file) + + if "--lineage" not in self.extra_args: + if optimize_only: + cmd += '-O ' + else: + cmd += '--run ' + + if yql_utils.get_param('UDF_RESOLVER') or require_udf_resolver: + assert self.udf_resolver_binary, "Missing udf_resolver binary" + cmd += '--udf-resolver=' + self.udf_resolver_binary + ' ' + if scan_udfs: + cmd += '--scan-udfs ' + if not yatest.common.context.sanitize: + cmd += 
'--udf-resolver-filter-syscalls ' + + if run_sql and not self.use_sql2yql: + cmd += '--sql ' + + if parameters: + parameters_file = res_file_path('params.yson') + with open(parameters_file, 'w') as f: + f.write(six.ensure_str(yson.dumps(parameters))) + cmd += '--params-file=%s ' % parameters_file + + if verbose: + yql_utils.log('prov is ' + self.prov) + + env = {'YQL_DETERMINISTIC_MODE': '1'} + env.update(extra_env) + for var in [ + 'LLVM_PROFILE_FILE', + 'GO_COVERAGE_PREFIX', + 'PYTHON_COVERAGE_PREFIX', + 'NLG_COVERAGE_FILENAME', + 'YQL_EXPORT_PG_FUNCTIONS_DIR', + 'YQL_ALLOW_ALL_PG_FUNCTIONS', + ]: + if var in os.environ: + env[var] = os.environ[var] + if yql_utils.get_param('STDERR'): + debug_udfs_dir = os.path.join(os.path.abspath('.'), '..', '..', '..') + env_setters = ";".join("{}={}".format(k, v) for k, v in six.iteritems(env)) + yql_utils.log('GDB launch command:') + yql_utils.log('(cd "%s" && %s ya tool gdb --args %s)' % (res_dir, env_setters, cmd.replace(udfs_dir, debug_udfs_dir))) + + proc_result = yatest.common.process.execute(cmd.strip().split(), check_exit_code=False, cwd=res_dir, env=env) + if proc_result.exit_code != 0 and check_error: + with open(err_file, 'r') as f: + err_file_text = f.read() + assert 0, \ + 'Command\n%(command)s\n finished with exit code %(code)d, stderr:\n\n%(stderr)s\n\nerror file:\n%(err_file)s' % { + 'command': cmd, + 'code': proc_result.exit_code, + 'stderr': proc_result.std_err, + 'err_file': err_file_text + } + + if os.path.exists(results_file) and os.stat(results_file).st_size == 0: + os.unlink(results_file) # kikimr yql-exec compatibility + + results, log_results = yql_utils.read_res_file(results_file) + plan, log_plan = yql_utils.read_res_file(plan_file) + opt, log_opt = yql_utils.read_res_file(opt_file) + err, log_err = yql_utils.read_res_file(err_file) + + if verbose: + yql_utils.log('PROGRAM:') + yql_utils.log(program) + yql_utils.log('OPT:') + yql_utils.log(log_opt) + yql_utils.log('PLAN:') + yql_utils.log(log_plan) 
+ yql_utils.log('RESULTS:') + yql_utils.log(log_results) + yql_utils.log('ERROR:') + yql_utils.log(log_err) + + return yql_utils.YQLExecResult( + proc_result.std_out, + yql_utils.normalize_source_code_path(err.replace(res_dir, '<tmp_path>')), + results, + results_file, + opt, + opt_file, + plan, + plan_file, + program, + proc_result, + None + ) + + def create_empty_tables(self, tables): + pass + + def write_tables(self, tables): + pass + + def get_tables(self, tables): + res = {} + for table in tables: + # recreate table after yql program was executed + res[table.full_name] = yql_utils.new_table( + table.full_name, + yqlrun_file=self.tables[table.full_name].yqlrun_file, + res_dir=self.res_dir + ) + + yql_utils.log('YQLRun table ' + table.full_name) + yql_utils.log(res[table.full_name].content) + + return res diff --git a/yql/essentials/tests/common/udf_test/test.py b/yql/essentials/tests/common/udf_test/test.py new file mode 100644 index 0000000000..218b05b4bd --- /dev/null +++ b/yql/essentials/tests/common/udf_test/test.py @@ -0,0 +1,111 @@ +import os +import os.path +import glob +import codecs +import shutil + +import pytest + +import yql_utils +from yqlrun import YQLRun + +import yatest.common + +project_path = yatest.common.context.project_path +SOURCE_PATH = yql_utils.yql_source_path((project_path + '/cases').replace('\\', '/')) +DATA_PATH = yatest.common.output_path('cases') +ASTDIFF_PATH = yql_utils.yql_binary_path(os.getenv('YQL_ASTDIFF_PATH') or 'yql/essentials/tools/astdiff/astdiff') + + +def pytest_generate_tests(metafunc): + if os.path.exists(SOURCE_PATH): + shutil.copytree(SOURCE_PATH, DATA_PATH) + cases = sorted([os.path.basename(sql_query)[:-4] for sql_query in glob.glob(DATA_PATH + '/*.sql')]) + + else: + cases = [] + metafunc.parametrize(['case'], [(case, ) for case in cases]) + + +def test(case): + program_file = os.path.join(DATA_PATH, case + '.sql') + + with codecs.open(program_file, encoding='utf-8') as f: + program = f.readlines() + + header = 
program[0] + canonize_ast = False + + if header.startswith('--ignore'): + pytest.skip(header) + elif header.startswith('--sanitizer ignore') and yatest.common.context.sanitize is not None: + pytest.skip(header) + elif header.startswith('--sanitizer ignore address') and yatest.common.context.sanitize == 'address': + pytest.skip(header) + elif header.startswith('--sanitizer ignore memory') and yatest.common.context.sanitize == 'memory': + pytest.skip(header) + elif header.startswith('--sanitizer ignore thread') and yatest.common.context.sanitize == 'thread': + pytest.skip(header) + elif header.startswith('--sanitizer ignore undefined') and yatest.common.context.sanitize == 'undefined': + pytest.skip(header) + elif header.startswith('--canonize ast'): + canonize_ast = True + + program = '\n'.join(['use plato;'] + program) + + cfg = yql_utils.get_program_cfg(None, case, DATA_PATH) + files = {} + diff_tool = None + scan_udfs = False + for item in cfg: + if item[0] == 'file': + files[item[1]] = item[2] + if item[0] == 'diff_tool': + diff_tool = item[1:] + if item[0] == 'scan_udfs': + scan_udfs = True + + in_tables = yql_utils.get_input_tables(None, cfg, DATA_PATH, def_attr=yql_utils.KSV_ATTR) + + udfs_dir = yql_utils.get_udfs_path([ + yatest.common.build_path(os.path.join(yatest.common.context.project_path, "..")) + ]) + + xfail = yql_utils.is_xfail(cfg) + if yql_utils.get_param('TARGET_PLATFORM') and xfail: + pytest.skip('xfail is not supported on non-default target platform') + + extra_env = dict(os.environ) + extra_env["YQL_UDF_RESOLVER"] = "1" + extra_env["YQL_ARCADIA_BINARY_PATH"] = os.path.expandvars(yatest.common.build_path('.')) + extra_env["YQL_ARCADIA_SOURCE_PATH"] = os.path.expandvars(yatest.common.source_path('.')) + extra_env["Y_NO_AVX_IN_DOT_PRODUCT"] = "1" + + # this breaks tests using V0 syntax + if "YA_TEST_RUNNER" in extra_env: + del extra_env["YA_TEST_RUNNER"] + + yqlrun_res = YQLRun(udfs_dir=udfs_dir, prov='yt', use_sql2yql=False, 
cfg_dir=os.getenv('YQL_CONFIG_DIR') or 'yql/essentials/cfg/udf_test').yql_exec( + program=program, + run_sql=True, + tables=in_tables, + files=files, + check_error=not xfail, + extra_env=extra_env, + require_udf_resolver=True, + scan_udfs=scan_udfs + ) + + if xfail: + assert yqlrun_res.execution_result.exit_code != 0 + + results_path = os.path.join(yql_utils.yql_output_path(), case + '.results.txt') + with open(results_path, 'w') as f: + f.write(yqlrun_res.results) + + to_canonize = [yqlrun_res.std_err] if xfail else [yatest.common.canonical_file(yqlrun_res.results_file, local=True, diff_tool=diff_tool)] + + if canonize_ast: + to_canonize += [yatest.common.canonical_file(yqlrun_res.opt_file, local=True, diff_tool=ASTDIFF_PATH)] + + return to_canonize diff --git a/yql/essentials/tests/common/udf_test/ya.make b/yql/essentials/tests/common/udf_test/ya.make new file mode 100644 index 0000000000..37570be0ab --- /dev/null +++ b/yql/essentials/tests/common/udf_test/ya.make @@ -0,0 +1,9 @@ +PY23_LIBRARY() + +TEST_SRCS(test.py) + +PEERDIR( + yql/essentials/tests/common/test_framework +) + +END() diff --git a/yql/essentials/tests/common/ya.make b/yql/essentials/tests/common/ya.make new file mode 100644 index 0000000000..1ac429bbb1 --- /dev/null +++ b/yql/essentials/tests/common/ya.make @@ -0,0 +1,5 @@ +RECURSE( + test_framework + udf_test +) + |