author     udovichenko-r <udovichenko-r@yandex-team.com>    2024-11-18 18:55:10 +0300
committer  udovichenko-r <udovichenko-r@yandex-team.com>    2024-11-18 19:27:57 +0300
commit     9034652d9fcda22d641b8b030a757a3942112f5f (patch)
tree       cf418ec574ac74e6a3eef6459eeffd35da75aeff
parent     898c44637b5a3602cc610a0352889be7b897a6a6 (diff)
download   ydb-9034652d9fcda22d641b8b030a757a3942112f5f.tar.gz

YQL-19206 Move contrib/ydb/library/yql/tests/common -> yql/essentials/tests/common
commit_hash:b0bab3353351a5c79f0a64237b103ddba0004fd7
-rw-r--r--  build/conf/project_specific/yql_udf.conf                               6
-rw-r--r--  yql/essentials/parser/pg_wrapper/test/ya.make                          2
-rw-r--r--  yql/essentials/tests/common/test_framework/conftest.py                14
-rw-r--r--  yql/essentials/tests/common/test_framework/solomon_runner.py          40
-rw-r--r--  yql/essentials/tests/common/test_framework/udfs_deps/ya.make          51
-rw-r--r--  yql/essentials/tests/common/test_framework/ya.make                    30
-rw-r--r--  yql/essentials/tests/common/test_framework/yql_http_file_server.py   136
-rw-r--r--  yql/essentials/tests/common/test_framework/yql_ports.py               43
-rw-r--r--  yql/essentials/tests/common/test_framework/yql_utils.py             1043
-rw-r--r--  yql/essentials/tests/common/test_framework/yqlrun.py                 346
-rw-r--r--  yql/essentials/tests/common/udf_test/test.py                         111
-rw-r--r--  yql/essentials/tests/common/udf_test/ya.make                           9
-rw-r--r--  yql/essentials/tests/common/ya.make                                    5
-rw-r--r--  yql/essentials/tests/ya.make                                           6
-rw-r--r--  yql/essentials/ya.make                                                 1
15 files changed, 1839 insertions(+), 4 deletions(-)
diff --git a/build/conf/project_specific/yql_udf.conf b/build/conf/project_specific/yql_udf.conf
index df2ec98480..8f623fbcb8 100644
--- a/build/conf/project_specific/yql_udf.conf
+++ b/build/conf/project_specific/yql_udf.conf
@@ -36,7 +36,7 @@ macro UDF_NO_PROBE() {
module YQL_UDF_TEST: PY3TEST_BIN {
SET_APPEND(_MAKEFILE_INCLUDE_LIKE_DEPS canondata/result.json)
- PEERDIR($YQL_BASE_TEST_DIR/tests/common/udf_test)
+ PEERDIR(yql/essentials/tests/common/udf_test)
DEPENDS(yql/essentials/tools/astdiff)
DEPENDS($YQL_BASE_TEST_DIR/tools/yqlrun)
@@ -53,7 +53,7 @@ module YQL_UDF_TEST: PY3TEST_BIN {
module YQL_UDF_YDB_TEST: PY3TEST_BIN {
SET_APPEND(_MAKEFILE_INCLUDE_LIKE_DEPS canondata/result.json)
- PEERDIR($YQL_BASE_TEST_DIR/tests/common/udf_test)
+ PEERDIR(yql/essentials/tests/common/udf_test)
DEPENDS(yql/essentials/tools/astdiff)
DEPENDS($YQL_BASE_TEST_DIR/tools/yqlrun)
@@ -65,7 +65,7 @@ module YQL_UDF_YDB_TEST: PY3TEST_BIN {
module YQL_UDF_TEST_CONTRIB: PY3TEST_BIN {
SET_APPEND(_MAKEFILE_INCLUDE_LIKE_DEPS canondata/result.json)
- PEERDIR($YQL_BASE_TEST_DIR/tests/common/udf_test)
+ PEERDIR(yql/essentials/tests/common/udf_test)
DEPENDS(yql/essentials/tools/astdiff)
DEPENDS($YQL_BASE_TEST_DIR/tools/yqlrun)
diff --git a/yql/essentials/parser/pg_wrapper/test/ya.make b/yql/essentials/parser/pg_wrapper/test/ya.make
index da4cb9a291..9b29b15f97 100644
--- a/yql/essentials/parser/pg_wrapper/test/ya.make
+++ b/yql/essentials/parser/pg_wrapper/test/ya.make
@@ -28,7 +28,7 @@ DATA(
)
PEERDIR(
- contrib/ydb/library/yql/tests/common/test_framework
+ yql/essentials/tests/common/test_framework
)
DEPENDS(
diff --git a/yql/essentials/tests/common/test_framework/conftest.py b/yql/essentials/tests/common/test_framework/conftest.py
new file mode 100644
index 0000000000..675726de78
--- /dev/null
+++ b/yql/essentials/tests/common/test_framework/conftest.py
@@ -0,0 +1,14 @@
+try:
+ from yql_http_file_server import yql_http_file_server
+except ImportError:
+ yql_http_file_server = None
+
+try:
+ from solomon_runner import solomon
+except ImportError:
+ solomon = None
+
+# bunch of useless statements for linter happiness
+# (otherwise it complains about unused names)
+assert yql_http_file_server is yql_http_file_server
+assert solomon is solomon
diff --git a/yql/essentials/tests/common/test_framework/solomon_runner.py b/yql/essentials/tests/common/test_framework/solomon_runner.py
new file mode 100644
index 0000000000..de6062a9ec
--- /dev/null
+++ b/yql/essentials/tests/common/test_framework/solomon_runner.py
@@ -0,0 +1,40 @@
+import os
+import pytest
+import requests
+
+
+class SolomonWrapper(object):
+ def __init__(self, url, endpoint):
+ self._url = url
+ self._endpoint = endpoint
+ self.table_prefix = ""
+
+ def is_valid(self):
+ return self._url is not None
+
+ def cleanup(self):
+ res = requests.post(self._url + "/cleanup")
+ res.raise_for_status()
+
+ def get_metrics(self):
+ res = requests.get(self._url + "/metrics?project=my_project&cluster=my_cluster&service=my_service")
+ res.raise_for_status()
+ return res.text
+
+ def prepare_program(self, program, program_file, res_dir, lang='sql'):
+ return program, program_file
+
+ @property
+ def url(self):
+ return self._url
+
+ @property
+ def endpoint(self):
+ return self._endpoint
+
+
+@pytest.fixture(scope='module')
+def solomon(request):
+ solomon_url = os.environ.get("SOLOMON_URL")
+ solomon_endpoint = os.environ.get("SOLOMON_ENDPOINT")
+ return SolomonWrapper(solomon_url, solomon_endpoint)
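
For orientation, a minimal sketch of how a test might consume the solomon fixture defined above (the test name and the query step are hypothetical; it assumes the test recipe exports SOLOMON_URL and SOLOMON_ENDPOINT):

    import pytest

    def test_solomon_sink(solomon):
        # The fixture always returns a SolomonWrapper; guard on is_valid()
        # because SOLOMON_URL may be absent outside the recipe environment.
        if not solomon.is_valid():
            pytest.skip('SOLOMON_URL is not set')
        solomon.cleanup()
        # ... run a query that writes metrics to solomon.endpoint here ...
        metrics = solomon.get_metrics()
        assert metrics  # raw text of the pushed metrics
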
diff --git a/yql/essentials/tests/common/test_framework/udfs_deps/ya.make b/yql/essentials/tests/common/test_framework/udfs_deps/ya.make
new file mode 100644
index 0000000000..16b320bc3b
--- /dev/null
+++ b/yql/essentials/tests/common/test_framework/udfs_deps/ya.make
@@ -0,0 +1,51 @@
+SET(
+ UDFS
+ yql/essentials/udfs/common/datetime2
+ yql/essentials/udfs/common/digest
+ yql/essentials/udfs/common/file
+ yql/essentials/udfs/common/hyperloglog
+ yql/essentials/udfs/common/pire
+ yql/essentials/udfs/common/protobuf
+ yql/essentials/udfs/common/re2
+ yql/essentials/udfs/common/set
+ yql/essentials/udfs/common/stat
+ yql/essentials/udfs/common/topfreq
+ yql/essentials/udfs/common/top
+ yql/essentials/udfs/common/string
+ yql/essentials/udfs/common/histogram
+ yql/essentials/udfs/common/json2
+ yql/essentials/udfs/common/yson2
+ yql/essentials/udfs/common/math
+ yql/essentials/udfs/common/url_base
+ yql/essentials/udfs/common/unicode_base
+ yql/essentials/udfs/common/streaming
+ yql/essentials/udfs/examples/callables
+ yql/essentials/udfs/examples/dicts
+ yql/essentials/udfs/examples/dummylog
+ yql/essentials/udfs/examples/lists
+ yql/essentials/udfs/examples/structs
+ yql/essentials/udfs/examples/type_inspection
+ yql/essentials/udfs/logs/dsv
+ yql/essentials/udfs/test/simple
+ yql/essentials/udfs/test/test_import
+)
+
+IF (OS_LINUX AND CLANG)
+ SET(
+ UDFS
+ ${UDFS}
+ yql/essentials/udfs/common/hyperscan
+ )
+ENDIF()
+
+PACKAGE()
+
+IF (SANITIZER_TYPE != "undefined")
+
+PEERDIR(
+ ${UDFS}
+)
+
+ENDIF()
+
+END()
diff --git a/yql/essentials/tests/common/test_framework/ya.make b/yql/essentials/tests/common/test_framework/ya.make
new file mode 100644
index 0000000000..c9b91ff5c8
--- /dev/null
+++ b/yql/essentials/tests/common/test_framework/ya.make
@@ -0,0 +1,30 @@
+PY23_LIBRARY()
+
+PY_SRCS(
+ TOP_LEVEL
+ solomon_runner.py
+ yql_utils.py
+ yql_ports.py
+ yqlrun.py
+ yql_http_file_server.py
+)
+
+PY_SRCS(
+ NAMESPACE ydb_library_yql_test_framework
+ conftest.py
+)
+
+PEERDIR(
+ contrib/python/requests
+ contrib/python/six
+ contrib/python/urllib3
+ library/python/cyson
+ yql/essentials/core/file_storage/proto
+ yql/essentials/providers/common/proto
+)
+
+END()
+
+RECURSE(
+ udfs_deps
+)
diff --git a/yql/essentials/tests/common/test_framework/yql_http_file_server.py b/yql/essentials/tests/common/test_framework/yql_http_file_server.py
new file mode 100644
index 0000000000..ad58588ed1
--- /dev/null
+++ b/yql/essentials/tests/common/test_framework/yql_http_file_server.py
@@ -0,0 +1,136 @@
+import io
+import os
+import pytest
+import threading
+import shutil
+
+import six.moves.BaseHTTPServer as BaseHTTPServer
+import six.moves.socketserver as socketserver
+
+from yql_ports import get_yql_port, release_yql_port
+
+
+# handler is created on each request
+# store state in server
+class YqlHttpRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
+ def get_requested_filename(self):
+ return self.path.lstrip('/')
+
+ def do_GET(self):
+ f = self.send_head(self.get_requested_filename())
+ if f:
+ try:
+ shutil.copyfileobj(f, self.wfile)
+ finally:
+ f.close()
+
+ def do_HEAD(self):
+ f = self.send_head(self.get_requested_filename())
+ if f:
+ f.close()
+
+ def get_file_and_size(self, filename):
+ try:
+ path = self.server.file_paths[filename]
+ f = open(path, 'rb')
+ fs = os.fstat(f.fileno())
+ size = fs[6]
+ return (f, size)
+ except KeyError:
+ try:
+ content = self.server.file_contents[filename]
+ return (io.BytesIO(content), len(content))
+ except KeyError:
+ return (None, 0)
+
+ return (None, 0)
+
+ def send_head(self, filename):
+ (f, size) = self.get_file_and_size(filename)
+
+ if not f:
+ self.send_error(404, "File %s not found" % filename)
+ return None
+
+ if self.server.etag is not None:
+ if_none_match = self.headers.get('If-None-Match', None)
+ if if_none_match == self.server.etag:
+ self.send_response(304)
+ self.end_headers()
+ f.close()
+ return None
+
+ self.send_response(200)
+
+ if self.server.etag is not None:
+ self.send_header("ETag", self.server.etag)
+
+ self.send_header("Content-type", 'application/octet-stream')
+ self.send_header("Content-Length", size)
+ self.end_headers()
+ return f
+
+
+class YqlHttpFileServer(socketserver.TCPServer, object):
+ def __init__(self):
+ self.http_server_port = get_yql_port('YqlHttpFileServer')
+ super(YqlHttpFileServer, self).__init__(('', self.http_server_port), YqlHttpRequestHandler,
+ bind_and_activate=False)
+ self.file_contents = {}
+ self.file_paths = {}
+ # common etag for all resources
+ self.etag = None
+ self.serve_thread = None
+
+ def start(self):
+ self.allow_reuse_address = True
+ self.server_bind()
+ self.server_activate()
+ self.serve_thread = threading.Thread(target=self.serve_forever)
+ self.serve_thread.start()
+
+ def stop(self):
+ super(YqlHttpFileServer, self).shutdown()
+ self.serve_thread.join()
+ release_yql_port(self.http_server_port)
+ self.http_server_port = None
+
+ def forget_files(self):
+ self.register_files({}, {})
+
+ def set_etag(self, newEtag):
+ self.etag = newEtag
+
+ def register_new_path(self, key, file_path):
+ self.file_paths[key] = file_path
+ return self.compose_http_link(key)
+
+ def register_files(self, file_contents, file_paths):
+ self.file_contents = file_contents
+ self.file_paths = file_paths
+
+ keys = []
+ if file_contents:
+ keys.extend(file_contents.keys())
+
+ if file_paths:
+ keys.extend(file_paths.keys())
+
+ return {k: self.compose_http_link(k) for k in keys}
+
+ def compose_http_link(self, filename):
+ return self.compose_http_host() + '/' + filename
+
+ def compose_http_host(self):
+ if not self.http_server_port:
+ raise Exception('http_server_port is empty. start HTTP server first')
+
+ return 'http://localhost:%d' % self.http_server_port
+
+
+@pytest.fixture(scope='module')
+def yql_http_file_server(request):
+ server = YqlHttpFileServer()
+ server.start()
+ request.addfinalizer(server.stop)
+ return server
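
As a rough usage sketch (hypothetical test), the yql_http_file_server fixture serves in-memory blobs or on-disk files and hands back http://localhost:<port>/<key> links that can be passed to the query under test via urls=:

    def test_serving_files(yql_http_file_server):
        # Register in-memory content; the result maps each key to its HTTP link.
        links = yql_http_file_server.register_files({'data.txt': b'hello'}, {})
        assert links['data.txt'].startswith('http://localhost:')
        # A file already on disk can be exposed under a chosen key as well
        # ('/path/to/big.bin' is a made-up path).
        big_url = yql_http_file_server.register_new_path('big.bin', '/path/to/big.bin')
        assert big_url.endswith('/big.bin')
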
diff --git a/yql/essentials/tests/common/test_framework/yql_ports.py b/yql/essentials/tests/common/test_framework/yql_ports.py
new file mode 100644
index 0000000000..d61be1efdf
--- /dev/null
+++ b/yql/essentials/tests/common/test_framework/yql_ports.py
@@ -0,0 +1,43 @@
+from yatest.common.network import PortManager
+import yql_utils
+
+port_manager = None
+
+
+def get_yql_port(service='unknown'):
+ global port_manager
+
+ if port_manager is None:
+ port_manager = PortManager()
+
+ port = port_manager.get_port()
+ yql_utils.log('get port for service %s: %d' % (service, port))
+ return port
+
+
+def release_yql_port(port):
+ if port is None:
+ return
+
+ global port_manager
+ port_manager.release_port(port)
+
+
+def get_yql_port_range(service, count):
+ global port_manager
+
+ if port_manager is None:
+ port_manager = PortManager()
+
+ port = port_manager.get_port_range(None, count)
+ yql_utils.log('get port range for service %s: start_port: %d, count: %d' % (service, port, count))
+ return port
+
+
+def release_yql_port_range(start_port, count):
+ if start_port is None:
+ return
+
+ global port_manager
+ for port in range(start_port, start_port + count):
+ port_manager.release_port(port)
diff --git a/yql/essentials/tests/common/test_framework/yql_utils.py b/yql/essentials/tests/common/test_framework/yql_utils.py
new file mode 100644
index 0000000000..85970dab02
--- /dev/null
+++ b/yql/essentials/tests/common/test_framework/yql_utils.py
@@ -0,0 +1,1043 @@
+from __future__ import print_function
+
+import hashlib
+import io
+import os
+import os.path
+import six
+import sys
+import re
+import tempfile
+import shutil
+
+from google.protobuf import text_format
+from collections import namedtuple, defaultdict, OrderedDict
+from functools import partial
+import codecs
+import decimal
+from threading import Lock
+
+import pytest
+import yatest.common
+import cyson
+
+import logging
+import getpass
+
+logger = logging.getLogger(__name__)
+
+KSV_ATTR = '''{_yql_row_spec={
+ Type=[StructType;
+ [[key;[DataType;String]];
+ [subkey;[DataType;String]];
+ [value;[DataType;String]]]]}}'''
+
+
+def get_param(name, default=None):
+ name = 'YQL_' + name.upper()
+ return yatest.common.get_param(name, os.environ.get(name) or default)
+
+
+def do_custom_query_check(res, sql_query):
+ custom_check = re.search(r"/\* custom check:(.*)\*/", sql_query)
+ if not custom_check:
+ return False
+ custom_check = custom_check.group(1)
+ yt_res_yson = res.results
+ yt_res_yson = cyson.loads(yt_res_yson) if yt_res_yson else cyson.loads("[]")
+ yt_res_yson = replace_vals(yt_res_yson)
+ assert eval(custom_check), 'Condition "%(custom_check)s" fails\nResult:\n %(yt_res_yson)s\n' % locals()
+ return True
+
+
+def get_gateway_cfg_suffix():
+ default_suffix = None
+ return get_param('gateway_config_suffix', default_suffix) or ''
+
+
+def get_gateway_cfg_filename():
+ suffix = get_gateway_cfg_suffix()
+ if suffix == '':
+ return 'gateways.conf'
+ else:
+ return 'gateways-' + suffix + '.conf'
+
+
+def merge_default_gateway_cfg(cfg_dir, gateway_config):
+
+ with open(yql_source_path(os.path.join(cfg_dir, 'gateways.conf'))) as f:
+ text_format.Merge(f.read(), gateway_config)
+
+ suffix = get_gateway_cfg_suffix()
+ if suffix:
+ with open(yql_source_path(os.path.join(cfg_dir, 'gateways-' + suffix + '.conf'))) as f:
+ text_format.Merge(f.read(), gateway_config)
+
+
+def find_file(path):
+ arcadia_root = '.'
+ while '.arcadia.root' not in os.listdir(arcadia_root):
+ arcadia_root = os.path.join(arcadia_root, '..')
+ res = os.path.abspath(os.path.join(arcadia_root, path))
+ assert os.path.exists(res)
+ return res
+
+
+output_path_cache = {}
+
+
+def yql_output_path(*args, **kwargs):
+ if not get_param('LOCAL_BENCH_XX'):
+ # abspath is needed, because output_path may be relative when test is run directly (without ya make).
+ return os.path.abspath(yatest.common.output_path(*args, **kwargs))
+
+ else:
+ if args and args in output_path_cache:
+ return output_path_cache[args]
+ res = os.path.join(tempfile.mkdtemp(prefix='yql_tmp_'), *args)
+ if args:
+ output_path_cache[args] = res
+ return res
+
+
+def yql_binary_path(*args, **kwargs):
+ if not get_param('LOCAL_BENCH_XX'):
+ return yatest.common.binary_path(*args, **kwargs)
+
+ else:
+ return find_file(args[0])
+
+
+def yql_source_path(*args, **kwargs):
+ if not get_param('LOCAL_BENCH_XX'):
+ return yatest.common.source_path(*args, **kwargs)
+ else:
+ return find_file(args[0])
+
+
+def yql_work_path():
+ return os.path.abspath('.')
+
+
+YQLExecResult = namedtuple('YQLExecResult', (
+ 'std_out',
+ 'std_err',
+ 'results',
+ 'results_file',
+ 'opt',
+ 'opt_file',
+ 'plan',
+ 'plan_file',
+ 'program',
+ 'execution_result',
+ 'statistics'
+))
+
+Table = namedtuple('Table', (
+ 'name',
+ 'full_name',
+ 'content',
+ 'file',
+ 'yqlrun_file',
+ 'attr',
+ 'format',
+ 'exists'
+))
+
+
+def new_table(full_name, file_path=None, yqlrun_file=None, content=None, res_dir=None,
+ attr=None, format_name='yson', def_attr=None, should_exist=False, src_file_alternative=None):
+ assert '.' in full_name, 'expected name like cedar.Input'
+ name = '.'.join(full_name.split('.')[1:])
+
+ if res_dir is None:
+ res_dir = get_yql_dir('table_')
+
+ exists = True
+ if content is None:
+ # try read content from files
+ src_file = file_path or yqlrun_file
+ if src_file is None:
+ # nonexistent table, will be output for query
+ content = ''
+ exists = False
+ else:
+ if os.path.exists(src_file):
+ with open(src_file, 'rb') as f:
+ content = f.read()
+ elif src_file_alternative and os.path.exists(src_file_alternative):
+ with open(src_file_alternative, 'rb') as f:
+ content = f.read()
+ src_file = src_file_alternative
+ yqlrun_file, src_file_alternative = src_file_alternative, yqlrun_file
+ else:
+ content = ''
+ exists = False
+
+ file_path = os.path.join(res_dir, name + '.txt')
+ new_yqlrun_file = os.path.join(res_dir, name + '.yqlrun.txt')
+
+ if exists:
+ with open(file_path, 'wb') as f:
+ f.write(content)
+
+ # copy or create yqlrun_file in proper dir
+ if yqlrun_file is not None:
+ shutil.copyfile(yqlrun_file, new_yqlrun_file)
+ else:
+ with open(new_yqlrun_file, 'wb') as f:
+ f.write(content)
+ else:
+ assert not should_exist, locals()
+
+ if attr is None:
+ # try read from file
+ attr_file = None
+ if os.path.exists(file_path + '.attr'):
+ attr_file = file_path + '.attr'
+ elif yqlrun_file is not None and os.path.exists(yqlrun_file + '.attr'):
+ attr_file = yqlrun_file + '.attr'
+ elif src_file_alternative is not None and os.path.exists(src_file_alternative + '.attr'):
+ attr_file = src_file_alternative + '.attr'
+
+ if attr_file is not None:
+ with open(attr_file) as f:
+ attr = f.read()
+
+ if attr is None:
+ attr = def_attr
+
+ if attr is not None:
+ # probably we get it, now write attr file to proper place
+ attr_file = new_yqlrun_file + '.attr'
+ with open(attr_file, 'w') as f:
+ f.write(attr)
+
+ return Table(
+ name,
+ full_name,
+ content,
+ file_path,
+ new_yqlrun_file,
+ attr,
+ format_name,
+ exists
+ )
+
+
+def ensure_dir_exists(dir):
+ # handle race between isdir and mkdir
+ if os.path.isdir(dir):
+ return
+
+ try:
+ os.mkdir(dir)
+ except OSError:
+ if not os.path.isdir(dir):
+ raise
+
+
+def get_yql_dir(prefix):
+ yql_dir = yql_output_path('yql')
+ ensure_dir_exists(yql_dir)
+ res_dir = tempfile.mkdtemp(prefix=prefix, dir=yql_dir)
+ os.chmod(res_dir, 0o755)
+ return res_dir
+
+
+def get_cmd_for_files(arg, files):
+ cmd = ' '.join(
+ arg + ' ' + name + '@' + files[name]
+ for name in files
+ )
+ cmd += ' '
+ return cmd
+
+
+def read_res_file(file_path):
+ if os.path.exists(file_path):
+ with codecs.open(file_path, encoding="utf-8") as descr:
+ res = descr.read().strip()
+ if res == '':
+ log_res = '<EMPTY>'
+ else:
+ log_res = res
+ else:
+ res = ''
+ log_res = '<NOTHING>'
+ return res, log_res
+
+
+def normalize_yson(y):
+ from cyson import YsonBoolean, YsonEntity
+ if isinstance(y, YsonBoolean) or isinstance(y, bool):
+ return 'true' if y else 'false'
+ if isinstance(y, YsonEntity) or y is None:
+ return None
+ if isinstance(y, list):
+ return [normalize_yson(i) for i in y]
+ if isinstance(y, dict):
+ return {normalize_yson(k): normalize_yson(v) for k, v in six.iteritems(y)}
+ s = str(y) if not isinstance(y, six.text_type) else y.encode('utf-8', errors='xmlcharrefreplace')
+ return s
+
+
+volatile_attrs = {'DataSize', 'ModifyTime', 'Id', 'Revision'}
+current_user = getpass.getuser()
+
+
+def _replace_vals_impl(y):
+ if isinstance(y, list):
+ return [_replace_vals_impl(i) for i in y]
+ if isinstance(y, dict):
+ return {_replace_vals_impl(k): _replace_vals_impl(v) for k, v in six.iteritems(y) if k not in volatile_attrs}
+ if isinstance(y, str):
+ s = y.replace('tmp/yql/' + current_user + '/', 'tmp/')
+ s = re.sub(r'tmp/[0-9a-f]+-[0-9a-f]+-[0-9a-f]+-[0-9a-f]+', 'tmp/<temp_table_guid>', s)
+ return s
+ return y
+
+
+def replace_vals(y):
+ y = normalize_yson(y)
+ y = _replace_vals_impl(y)
+ return y
+
+
+def patch_yson_vals(y, patcher):
+ if isinstance(y, list):
+ return [patch_yson_vals(i, patcher) for i in y]
+ if isinstance(y, dict):
+ return {patch_yson_vals(k, patcher): patch_yson_vals(v, patcher) for k, v in six.iteritems(y)}
+ if isinstance(y, str):
+ return patcher(y)
+ return y
+
+
+floatRe = re.compile(r'^-?\d*\.\d+$')
+floatERe = re.compile(r'^-?(\d*\.)?\d+e[\+\-]?\d+$', re.IGNORECASE)
+specFloatRe = re.compile(r'^(-?inf|nan)$', re.IGNORECASE)
+
+
+def fix_double(x):
+ if floatRe.match(x) and len(x.replace('.', '').replace('-', '')) > 10:
+ # Emulate the same double precision as C++ code has
+ decimal.getcontext().rounding = decimal.ROUND_HALF_DOWN
+ decimal.getcontext().prec = 10
+ return str(decimal.Decimal(0) + decimal.Decimal(x)).rstrip('0')
+ if floatERe.match(x):
+ # Emulate the same double precision as C++ code has
+ decimal.getcontext().rounding = decimal.ROUND_HALF_DOWN
+ decimal.getcontext().prec = 10
+ return str(decimal.Decimal(0) + decimal.Decimal(x)).lower()
+ if specFloatRe.match(x):
+ return x.lower()
+ return x
+
+
+def remove_volatile_ast_parts(ast):
+ return re.sub(r"\(KiClusterConfig '\('\(.*\) '\"\d\" '\"\d\" '\"\d\"\)\)", "(KiClusterConfig)", ast)
+
+
+def prepare_program(program, program_file, yql_dir, ext='yql'):
+ assert not (program is None and program_file is None), 'Needs program or program_file'
+
+ if program is None:
+ with codecs.open(program_file, encoding='utf-8') as program_file_descr:
+ program = program_file_descr.read()
+
+ program_file = os.path.join(yql_dir, 'program.' + ext)
+ with codecs.open(program_file, 'w', encoding='utf-8') as program_file_descr:
+ program_file_descr.write(program)
+
+ return program, program_file
+
+
+def get_program_cfg(suite, case, DATA_PATH):
+ ret = []
+ config = os.path.join(DATA_PATH, suite if suite else '', case + '.cfg')
+ if not os.path.exists(config):
+ config = os.path.join(DATA_PATH, suite if suite else '', 'default.cfg')
+
+ if os.path.exists(config):
+ for line in open(config, 'r'):
+ if line.strip():
+ ret.append(tuple(line.split()))
+ else:
+ in_filename = case + '.in'
+ in_path = os.path.join(DATA_PATH, in_filename)
+ default_filename = 'default.in'
+ default_path = os.path.join(DATA_PATH, default_filename)
+ for filepath in [in_path, in_filename, default_path, default_filename]:
+ if os.path.exists(filepath):
+ try:
+ shutil.copy2(filepath, in_path)
+ except shutil.Error:
+ pass
+ ret.append(('in', 'yamr.plato.Input', in_path))
+ break
+
+ if not is_os_supported(ret):
+ pytest.skip('%s not supported here' % sys.platform)
+
+ return ret
+
+
+def find_user_file(suite, path, DATA_PATH):
+ source_path = os.path.join(DATA_PATH, suite, path)
+ if os.path.exists(source_path):
+ return source_path
+ else:
+ try:
+ return yql_binary_path(path)
+ except Exception:
+ raise Exception('Can not find file ' + path)
+
+
+def get_input_tables(suite, cfg, DATA_PATH, def_attr=None):
+ in_tables = []
+ for item in cfg:
+ if item[0] in ('in', 'out'):
+ io, table_name, file_name = item
+ if io == 'in':
+ in_tables.append(new_table(
+ full_name=table_name.replace('yamr.', '').replace('yt.', ''),
+ yqlrun_file=os.path.join(DATA_PATH, suite if suite else '', file_name),
+ src_file_alternative=os.path.join(yql_work_path(), suite if suite else '', file_name),
+ def_attr=def_attr,
+ should_exist=True
+ ))
+ return in_tables
+
+
+def get_tables(suite, cfg, DATA_PATH, def_attr=None):
+ in_tables = []
+ out_tables = []
+ suite_dir = os.path.join(DATA_PATH, suite)
+ res_dir = get_yql_dir('table_')
+
+ for splitted in cfg:
+ if splitted[0] == 'udf' and yatest.common.context.sanitize == 'undefined':
+ pytest.skip("udf under ubsan")
+
+ if len(splitted) == 4:
+ type_name, table, file_name, format_name = splitted
+ elif len(splitted) == 3:
+ type_name, table, file_name = splitted
+ format_name = 'yson'
+ else:
+ continue
+ yqlrun_file = os.path.join(suite_dir, file_name)
+ if type_name == 'in':
+ in_tables.append(new_table(
+ full_name='plato.' + table if '.' not in table else table,
+ yqlrun_file=yqlrun_file,
+ format_name=format_name,
+ def_attr=def_attr,
+ res_dir=res_dir
+ ))
+ if type_name == 'out':
+ out_tables.append(new_table(
+ full_name='plato.' + table if '.' not in table else table,
+ yqlrun_file=yqlrun_file if os.path.exists(yqlrun_file) else None,
+ res_dir=res_dir
+ ))
+ return in_tables, out_tables
+
+
+def get_supported_providers(cfg):
+ providers = 'yt', 'kikimr', 'dq', 'hybrid'
+ for item in cfg:
+ if item[0] == 'providers':
+ providers = [i.strip() for i in ''.join(item[1:]).split(',')]
+ return providers
+
+
+def is_os_supported(cfg):
+ for item in cfg:
+ if item[0] == 'os':
+ return any(sys.platform.startswith(_os) for _os in item[1].split(','))
+ return True
+
+
+def is_xfail(cfg):
+ for item in cfg:
+ if item[0] == 'xfail':
+ return True
+ return False
+
+
+def is_skip_forceblocks(cfg):
+ for item in cfg:
+ if item[0] == 'skip_forceblocks':
+ return True
+ return False
+
+
+def is_canonize_peephole(cfg):
+ for item in cfg:
+ if item[0] == 'canonize_peephole':
+ return True
+ return False
+
+
+def is_peephole_use_blocks(cfg):
+ for item in cfg:
+ if item[0] == 'peephole_use_blocks':
+ return True
+ return False
+
+
+def is_canonize_lineage(cfg):
+ for item in cfg:
+ if item[0] == 'canonize_lineage':
+ return True
+ return False
+
+
+def is_canonize_yt(cfg):
+ for item in cfg:
+ if item[0] == 'canonize_yt':
+ return True
+ return False
+
+
+def is_with_final_result_issues(cfg):
+ for item in cfg:
+ if item[0] == 'with_final_result_issues':
+ return True
+ return False
+
+
+def skip_test_if_required(cfg):
+ for item in cfg:
+ if item[0] == 'skip_test':
+ pytest.skip(item[1])
+
+
+def get_pragmas(cfg):
+ pragmas = []
+ for item in cfg:
+ if item[0] == 'pragma':
+ pragmas.append(' '.join(item))
+ return pragmas
+
+
+def execute(
+ klass=None,
+ program=None,
+ program_file=None,
+ files=None,
+ urls=None,
+ run_sql=False,
+ verbose=False,
+ check_error=True,
+ input_tables=None,
+ output_tables=None,
+ pretty_plan=True,
+ parameters={},
+):
+ '''
+ Executes YQL/SQL
+
+ :param klass: KiKiMRForYQL if instance (default: YQLRun)
+ :param program: string with YQL or SQL program
+ :param program_file: file with YQL or SQL program (optional, if :param program: is None)
+ :param files: dict like {'name': '/path'} with extra files
+ :param urls: dict like {'name': url} with extra files urls
+ :param run_sql: execute sql instead of yql
+ :param verbose: log all results and diagnostics
+ :param check_error: fail on non-zero exit code
+ :param input_tables: list of Table (will be written if not exist)
+ :param output_tables: list of Table (will be returned)
+ :param pretty_plan: whether to use pretty printing for plan or not
+ :param parameters: query parameters as dict like {name: json_value}
+ :return: YQLExecResult
+ '''
+
+ if input_tables is None:
+ input_tables = []
+ else:
+ assert isinstance(input_tables, list)
+ if output_tables is None:
+ output_tables = []
+
+ klass.write_tables(input_tables + output_tables)
+
+ res = klass.yql_exec(
+ program=program,
+ program_file=program_file,
+ files=files,
+ urls=urls,
+ run_sql=run_sql,
+ verbose=verbose,
+ check_error=check_error,
+ tables=(output_tables + input_tables),
+ pretty_plan=pretty_plan,
+ parameters=parameters
+ )
+
+ try:
+ res_tables = klass.get_tables(output_tables)
+ except Exception:
+ if check_error:
+ raise
+ res_tables = {}
+
+ return res, res_tables
+
+
+execute_sql = partial(execute, run_sql=True)
+
+
+def log(s):
+ if get_param('STDERR'):
+ print(s, file=sys.stderr)
+ else:
+ logger.debug(s)
+
+
+def tmpdir_module(request):
+ return tempfile.mkdtemp(prefix='kikimr_test_')
+
+
+@pytest.fixture(name='tmpdir_module', scope='module')
+def tmpdir_module_fixture(request):
+ return tmpdir_module(request)
+
+
+def escape_backslash(s):
+ return s.replace('\\', '\\\\')
+
+
+def get_default_mount_point_config_content():
+ return '''
+ MountPoints {
+ RootAlias: '/lib'
+ MountPoint: '%s'
+ Library: true
+ }
+ ''' % (
+ escape_backslash(yql_source_path('yql/essentials/mount/lib'))
+ )
+
+
+def get_mount_config_file(content=None):
+ config = yql_output_path('mount.cfg')
+ if not os.path.exists(config):
+ with open(config, 'w') as f:
+ f.write(content or get_default_mount_point_config_content())
+ return config
+
+
+def run_command(program, cmd, tmpdir_module=None, stdin=None,
+ check_exit_code=True, env=None, stdout=None):
+ if tmpdir_module is None:
+ tmpdir_module = tempfile.mkdtemp()
+
+ stdin_stream = None
+ if isinstance(stdin, six.string_types):
+ with tempfile.NamedTemporaryFile(
+ prefix='stdin_',
+ dir=tmpdir_module,
+ delete=False
+ ) as stdin_file:
+ stdin_file.write(stdin.encode() if isinstance(stdin, str) else stdin)
+ stdin_stream = open(stdin_file.name)
+ elif isinstance(stdin, io.IOBase):
+ stdin_stream = stdin
+ elif stdin is not None:
+ assert 0, 'Strange stdin ' + repr(stdin)
+
+ if isinstance(cmd, six.string_types):
+ cmd = cmd.split()
+ else:
+ cmd = [str(c) for c in cmd]
+ log(' '.join('\'%s\'' % c if ' ' in c else c for c in cmd))
+ cmd = [program] + cmd
+
+ stderr_stream = None
+ stdout_stream = None
+
+ if stdout:
+ stdout_stream = stdout
+
+ res = yatest.common.execute(
+ cmd,
+ cwd=tmpdir_module,
+ stdin=stdin_stream,
+ stdout=stdout_stream,
+ stderr=stderr_stream,
+ check_exit_code=check_exit_code,
+ env=env,
+ wait=True
+ )
+
+ if res.std_err:
+ log(res.std_err)
+ if res.std_out:
+ log(res.std_out)
+ return res
+
+
+def yson_to_csv(yson_content, columns=None, with_header=True, strict=False):
+ import cyson as yson
+ if columns:
+ headers = sorted(columns)
+ else:
+ headers = set()
+ for item in yson.loads(yson_content):
+ headers.update(six.iterkeys(item))
+ headers = sorted(headers)
+ csv_content = []
+ if with_header:
+ csv_content.append(';'.join(headers))
+ for item in yson.loads(yson_content):
+ if strict and sorted(six.iterkeys(item)) != headers:
+ return None
+ csv_content.append(';'.join([str(item[h]).replace('YsonEntity', '').encode('string_escape') if h in item else '' for h in headers]))
+ return '\n'.join(csv_content)
+
+
+udfs_lock = Lock()
+
+
+def get_udfs_path(extra_paths=None):
+ essentials_udfs_build_path = yatest.common.build_path('yql/essentials/udfs')
+ udfs_build_path = yatest.common.build_path('yql/udfs')
+ ydb_udfs_build_path = yatest.common.build_path('contrib/ydb/library/yql/udfs')
+ contrib_ydb_udfs_build_path = yatest.common.build_path('contrib/ydb/library/yql/udfs')
+ rthub_udfs_build_path = yatest.common.build_path('robot/rthub/yql/udfs')
+ kwyt_udfs_build_path = yatest.common.build_path('robot/kwyt/yql/udfs')
+
+ try:
+ udfs_bin_path = yatest.common.binary_path('yql/udfs')
+ except Exception:
+ udfs_bin_path = None
+
+ try:
+ udfs_project_path = yql_binary_path('yql/library/test_framework/udfs_deps')
+ except Exception:
+ udfs_project_path = None
+
+ try:
+ ydb_udfs_project_path = yql_binary_path('yql/essentials/tests/common/test_framework/udfs_deps')
+ except Exception:
+ ydb_udfs_project_path = None
+
+ merged_udfs_path = yql_output_path('yql_udfs')
+ with udfs_lock:
+ if not os.path.isdir(merged_udfs_path):
+ os.mkdir(merged_udfs_path)
+
+ udfs_paths = [
+ udfs_project_path,
+ ydb_udfs_project_path,
+ udfs_bin_path,
+ essentials_udfs_build_path,
+ udfs_build_path,
+ ydb_udfs_build_path,
+ contrib_ydb_udfs_build_path,
+ rthub_udfs_build_path,
+ kwyt_udfs_build_path
+ ]
+ if extra_paths is not None:
+ udfs_paths += extra_paths
+
+ log('process search UDF in: %s, %s, %s, %s' % (udfs_project_path, ydb_udfs_project_path, udfs_bin_path, udfs_build_path))
+ for _udfs_path in udfs_paths:
+ if _udfs_path:
+ for dirpath, dnames, fnames in os.walk(_udfs_path):
+ for f in fnames:
+ if f.endswith('.so'):
+ f = os.path.join(dirpath, f)
+ if not os.path.exists(f) and os.path.lexists(f): # seems like broken symlink
+ try:
+ os.unlink(f)
+ except OSError:
+ pass
+ link_name = os.path.join(merged_udfs_path, os.path.basename(f))
+ if not os.path.exists(link_name):
+ os.symlink(f, link_name)
+ log('Added UDF: ' + f)
+ return merged_udfs_path
+
+
+def get_test_prefix():
+ return 'yql_tmp_' + hashlib.md5(yatest.common.context.test_name).hexdigest()
+
+
+def normalize_plan_ids(plan, no_detailed=False):
+ remapOps = {}
+
+ for node in sorted(filter(lambda n: n["type"] == "in", plan["Basic"]["nodes"]), key=lambda n: n.get("name")):
+ if node["id"] not in remapOps:
+ remapOps[node["id"]] = len(remapOps) + 1
+
+ for node in plan["Basic"]["nodes"]:
+ if node["id"] not in remapOps:
+ remapOps[node["id"]] = len(remapOps) + 1
+
+ def subst_basic(y):
+ if isinstance(y, list):
+ return [subst_basic(i) for i in y]
+ if isinstance(y, dict):
+ res = {}
+ for k, v in six.iteritems(y):
+ if k in {'source', 'target', 'id'}:
+ res[k] = remapOps.get(v)
+ elif k == "links":
+ res[k] = sorted(subst_basic(v), key=lambda x: (x["source"], x["target"]))
+ elif k == "nodes":
+ res[k] = sorted(subst_basic(v), key=lambda x: x["id"])
+ else:
+ res[k] = subst_basic(v)
+ return res
+ return y
+
+ # Sort and normalize input ids
+ def subst_detailed(y):
+ if isinstance(y, list):
+ return [subst_detailed(i) for i in y]
+ if isinstance(y, dict):
+ res = {}
+ for k, v in six.iteritems(y):
+ if k == "DependsOn":
+ res[k] = sorted([remapOps.get(i) for i in v])
+ elif k == "Providers":
+ res[k] = v
+ elif k in {'OperationRoot', 'Id'}:
+ res[k] = remapOps.get(v)
+ else:
+ res[k] = subst_detailed(v)
+ return res
+ return y
+
+ if no_detailed:
+ return {"Basic": subst_basic(plan["Basic"])}
+ return {"Basic": subst_basic(plan["Basic"]), "Detailed": subst_detailed(plan["Detailed"])}
+
+
+def normalized_plan_stats(plan):
+ renameMap = {
+ "MrLMap!": "YtMap!",
+ "MrMapReduce!": "YtMapReduce!",
+ "MrLReduce!": "YtMapReduce!",
+ "MrOrderedReduce!": "YtReduce!",
+ "MrSort!": "YtSort!",
+ "MrCopy!": "YtCopy!",
+ "YtMerge!": "YtCopy!",
+ "MrFill!": "YtFill!",
+ "MrDrop!": "YtDropTable!",
+ "YtTouch!": None,
+ "MrReadTable!": None,
+ "YtReadTable!": None,
+ "MrPublish!": "YtPublish!",
+ "MrReadTableScheme!": "YtReadTableScheme!",
+ }
+
+ normalizedStat = defaultdict(int)
+
+ for op, stat in six.iteritems(plan["Detailed"]["OperationStats"]):
+ renamedOp = renameMap.get(op, op)
+ if renamedOp is not None:
+ normalizedStat[renamedOp] += stat
+
+ return normalizedStat
+
+
+def normalize_table_yson(y):
+ from cyson import YsonEntity
+ if isinstance(y, list):
+ return [normalize_table_yson(i) for i in y]
+ if isinstance(y, dict):
+ normDict = OrderedDict()
+ for k, v in sorted(six.iteritems(y), key=lambda x: x[0], reverse=True):
+ if k == "_other":
+ normDict[normalize_table_yson(k)] = sorted(normalize_table_yson(v))
+ elif v != "Void" and v is not None and not isinstance(v, YsonEntity):
+ normDict[normalize_table_yson(k)] = normalize_table_yson(v)
+ return normDict
+ return y
+
+
+def dump_table_yson(res_yson, sort=True):
+ rows = normalize_table_yson(cyson.loads('[' + res_yson + ']'))
+ if sort:
+ rows = sorted(rows)
+ return cyson.dumps(rows, format="pretty")
+
+
+def normalize_source_code_path(s):
+ # remove contrib/
+ s = re.sub(r'\b(contrib/)(ydb/library/yql.*)', r'\2', s)
+ # replace line number in source code with 'xxx'
+ s = re.sub(r'\b(yql/[\w/]+(?:\.cpp|\.h)):(?:\d+)', r'\1:xxx', s)
+ return re.sub(r'(/lib/yql/[\w/]+(?:\.yql|\.sql)):(?:\d+):(?:\d+)', r'\1:xxx:yyy', s)
+
+
+def do_get_files(suite, config, DATA_PATH, config_key):
+ files = dict()
+ suite_dir = os.path.join(DATA_PATH, suite)
+ res_dir = None
+ for line in config:
+ if line[0] == config_key:
+ _, name, path = line
+ userpath = find_user_file(suite, path, DATA_PATH)
+ relpath = os.path.relpath(userpath, suite_dir)
+ if os.path.exists(os.path.join('cwd', relpath)):
+ path = relpath
+ else:
+ path = userpath
+
+ if not res_dir:
+ res_dir = get_yql_dir('file_')
+
+ new_path = os.path.join(res_dir, os.path.basename(path))
+ shutil.copyfile(path, new_path)
+
+ files[name] = new_path
+
+ return files
+
+
+def get_files(suite, config, DATA_PATH):
+ return do_get_files(suite, config, DATA_PATH, 'file')
+
+
+def get_http_files(suite, config, DATA_PATH):
+ return do_get_files(suite, config, DATA_PATH, 'http_file')
+
+
+def get_yt_files(suite, config, DATA_PATH):
+ return do_get_files(suite, config, DATA_PATH, 'yt_file')
+
+
+def get_syntax_version(program):
+ syntax_version_param = get_param('SYNTAX_VERSION')
+ default_syntax_version = 1
+ if 'syntax version 0' in program:
+ return 0
+ elif 'syntax version 1' in program:
+ return 1
+ elif syntax_version_param:
+ return int(syntax_version_param)
+ else:
+ return default_syntax_version
+
+
+def ansi_lexer_enabled(program):
+ return 'ansi_lexer' in program
+
+
+def pytest_get_current_part(path):
+ folder = os.path.dirname(path)
+ folder_name = os.path.basename(folder)
+ assert folder_name.startswith('part'), "Current folder is {}".format(folder_name)
+ current = int(folder_name[len('part'):])
+
+ parent = os.path.dirname(folder)
+ maxpart = max([int(part[len('part'):]) if part.startswith('part') else -1 for part in os.listdir(parent)])
+ assert maxpart > 0, "Cannot find parts in {}".format(parent)
+ return (current, 1 + maxpart)
+
+
+def normalize_result(res, sort):
+ res = cyson.loads(res) if res else cyson.loads("[]")
+ res = replace_vals(res)
+ for r in res:
+ for data in r['Write']:
+ if sort and 'Data' in data:
+ data['Data'] = sorted(data['Data'])
+ if 'Ref' in data:
+ data['Ref'] = []
+ data['Truncated'] = True
+ if 'Data' in data and len(data['Data']) == 0:
+ del data['Data']
+ return res
+
+
+def stable_write(writer, node):
+ if hasattr(node, 'attributes'):
+ writer.begin_attributes()
+ for k in sorted(node.attributes.keys()):
+ writer.key(k)
+ stable_write(writer, node.attributes[k])
+ writer.end_attributes()
+ if isinstance(node, list):
+ writer.begin_list()
+ for r in node:
+ stable_write(writer, r)
+ writer.end_list()
+ return
+ if isinstance(node, dict):
+ writer.begin_map()
+ for k in sorted(node.keys()):
+ writer.key(k)
+ stable_write(writer, node[k])
+ writer.end_map()
+ return
+ writer.write(node)
+
+
+def stable_result_file(res):
+ path = res.results_file
+ assert os.path.exists(path)
+ with open(path) as f:
+ res = f.read()
+ res = cyson.loads(res)
+ res = replace_vals(res)
+ for r in res:
+ for data in r['Write']:
+ if 'Unordered' in r and 'Data' in data:
+ data['Data'] = sorted(data['Data'])
+ with open(path, 'w') as f:
+ writer = cyson.Writer(stream=cyson.OutputStream.from_file(f), format='pretty', mode='node')
+ writer.begin_stream()
+ stable_write(writer, res)
+ writer.end_stream()
+ with open(path) as f:
+ return f.read()
+
+
+def stable_table_file(table):
+ path = table.file
+ assert os.path.exists(path)
+ assert table.attr is not None
+ is_sorted = False
+ for column in cyson.loads(table.attr)['schema']:
+ if 'sort_order' in column:
+ is_sorted = True
+ break
+ if not is_sorted:
+ with open(path) as f:
+ r = cyson.Reader(cyson.InputStream.from_file(f), mode='list_fragment')
+ lst = sorted(list(r.list_fragments()))
+ with open(path, 'w') as f:
+ writer = cyson.Writer(stream=cyson.OutputStream.from_file(f), format='pretty', mode='list_fragment')
+ writer.begin_stream()
+ for r in lst:
+ stable_write(writer, r)
+ writer.end_stream()
+ with open(path) as f:
+ return f.read()
+
+
+class LoggingDowngrade(object):
+
+ def __init__(self, loggers, level=logging.CRITICAL):
+ self.loggers = [(name, logging.getLogger(name).getEffectiveLevel()) for name in loggers]
+ self.level = level
+
+ def __enter__(self):
+ self.prev_levels = []
+ for name, _ in self.loggers:
+ log = logging.getLogger(name)
+ log.setLevel(self.level)
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ for name, level in self.loggers:
+ log = logging.getLogger(name)
+ log.setLevel(level)
+ return True
diff --git a/yql/essentials/tests/common/test_framework/yqlrun.py b/yql/essentials/tests/common/test_framework/yqlrun.py
new file mode 100644
index 0000000000..b96641912a
--- /dev/null
+++ b/yql/essentials/tests/common/test_framework/yqlrun.py
@@ -0,0 +1,346 @@
+import os
+import shutil
+import yatest.common
+import yql_utils
+import cyson as yson
+import yql.essentials.providers.common.proto.gateways_config_pb2 as gateways_config_pb2
+import yql.essentials.core.file_storage.proto.file_storage_pb2 as file_storage_pb2
+
+import six
+
+from google.protobuf import text_format
+
+ARCADIA_PREFIX = 'arcadia/'
+ARCADIA_TESTS_DATA_PREFIX = 'arcadia_tests_data/'
+
+VAR_CHAR_PREFIX = '$'
+FIX_DIR_PREFIXES = {
+ 'SOURCE': yatest.common.source_path,
+ 'BUILD': yatest.common.build_path,
+ 'TEST_SOURCE': yatest.common.test_source_path,
+ 'DATA': yatest.common.data_path,
+ 'BINARY': yatest.common.binary_path,
+}
+
+
+class YQLRun(object):
+
+ def __init__(self, udfs_dir=None, prov='yt', use_sql2yql=False, keep_temp=True, binary=None, gateway_config=None, fs_config=None, extra_args=[], cfg_dir=None, support_udfs=True):
+ if binary is None:
+ self.yqlrun_binary = yql_utils.yql_binary_path(os.getenv('YQL_YQLRUN_PATH') or 'contrib/ydb/library/yql/tools/yqlrun/yqlrun')
+ else:
+ self.yqlrun_binary = binary
+ self.extra_args = extra_args
+
+ try:
+ self.sql2yql_binary = yql_utils.yql_binary_path(os.getenv('YQL_SQL2YQL_PATH') or 'yql/essentials/tools/sql2yql/sql2yql')
+ except BaseException:
+ self.sql2yql_binary = None
+
+ try:
+ self.udf_resolver_binary = yql_utils.yql_binary_path(os.getenv('YQL_UDFRESOLVER_PATH') or 'yql/essentials/tools/udf_resolver/udf_resolver')
+ except Exception:
+ self.udf_resolver_binary = None
+
+ if support_udfs:
+ if udfs_dir is None:
+ self.udfs_path = yql_utils.get_udfs_path()
+ else:
+ self.udfs_path = udfs_dir
+ else:
+ self.udfs_path = None
+
+ res_dir = yql_utils.get_yql_dir(prefix='yqlrun_')
+ self.res_dir = res_dir
+ self.tables = {}
+ self.prov = prov
+ self.use_sql2yql = use_sql2yql
+ self.keep_temp = keep_temp
+
+ self.gateway_config = gateways_config_pb2.TGatewaysConfig()
+ if gateway_config is not None:
+ text_format.Merge(gateway_config, self.gateway_config)
+
+ yql_utils.merge_default_gateway_cfg(cfg_dir or 'yql/essentials/cfg/tests', self.gateway_config)
+
+ self.fs_config = file_storage_pb2.TFileStorageConfig()
+
+ with open(yql_utils.yql_source_path(os.path.join(cfg_dir or 'yql/essentials/cfg/tests', 'fs.conf'))) as f:
+ text_format.Merge(f.read(), self.fs_config)
+
+ if fs_config is not None:
+ text_format.Merge(fs_config, self.fs_config)
+
+ if yql_utils.get_param('USE_NATIVE_YT_TYPES'):
+ attr = self.gateway_config.Yt.DefaultSettings.add()
+ attr.Name = 'UseNativeYtTypes'
+ attr.Value = 'true'
+
+ if yql_utils.get_param('SQL_FLAGS'):
+ flags = yql_utils.get_param('SQL_FLAGS').split(',')
+ self.gateway_config.SqlCore.TranslationFlags.extend(flags)
+
+ def yql_exec(self, program=None, program_file=None, files=None, urls=None,
+ run_sql=False, verbose=False, check_error=True, tables=None, pretty_plan=True,
+ wait=True, parameters={}, extra_env={}, require_udf_resolver=False, scan_udfs=True):
+ del pretty_plan
+
+ res_dir = self.res_dir
+
+ def res_file_path(name):
+ return os.path.join(res_dir, name)
+
+ opt_file = res_file_path('opt.yql')
+ results_file = res_file_path('results.txt')
+ plan_file = res_file_path('plan.txt')
+ err_file = res_file_path('err.txt')
+
+ udfs_dir = self.udfs_path
+ prov = self.prov
+
+ program, program_file = yql_utils.prepare_program(program, program_file, res_dir,
+ ext='sql' if run_sql else 'yql')
+
+ syntax_version = yql_utils.get_syntax_version(program)
+ ansi_lexer = yql_utils.ansi_lexer_enabled(program)
+
+ if run_sql and self.use_sql2yql:
+ orig_sql = program_file + '.orig_sql'
+ shutil.copy2(program_file, orig_sql)
+ cmd = [
+ self.sql2yql_binary,
+ orig_sql,
+ '--yql',
+ '--output=' + program_file,
+ '--syntax-version=%d' % syntax_version
+ ]
+ if ansi_lexer:
+ cmd.append('--ansi-lexer')
+ env = {'YQL_DETERMINISTIC_MODE': '1'}
+ env.update(extra_env)
+ for var in [
+ 'LLVM_PROFILE_FILE',
+ 'GO_COVERAGE_PREFIX',
+ 'PYTHON_COVERAGE_PREFIX',
+ 'NLG_COVERAGE_FILENAME',
+ 'YQL_EXPORT_PG_FUNCTIONS_DIR',
+ 'YQL_ALLOW_ALL_PG_FUNCTIONS',
+ ]:
+ if var in os.environ:
+ env[var] = os.environ[var]
+ yatest.common.process.execute(cmd, cwd=res_dir, env=env)
+
+ with open(program_file) as f:
+ yql_program = f.read()
+ with open(program_file, 'w') as f:
+ f.write(yql_program)
+
+ gateways_cfg_file = res_file_path('gateways.conf')
+ with open(gateways_cfg_file, 'w') as f:
+ f.write(str(self.gateway_config))
+
+ fs_cfg_file = res_file_path('fs.conf')
+ with open(fs_cfg_file, 'w') as f:
+ f.write(str(self.fs_config))
+
+ cmd = self.yqlrun_binary + ' '
+
+ if yql_utils.get_param('TRACE_OPT'):
+ cmd += '--trace-opt '
+
+ cmd += '-L ' \
+ '--program=%(program_file)s ' \
+ '--expr-file=%(opt_file)s ' \
+ '--result-file=%(results_file)s ' \
+ '--plan-file=%(plan_file)s ' \
+ '--err-file=%(err_file)s ' \
+ '--gateways=%(prov)s ' \
+ '--syntax-version=%(syntax_version)d ' \
+ '--tmp-dir=%(res_dir)s ' \
+ '--gateways-cfg=%(gateways_cfg_file)s ' \
+ '--fs-cfg=%(fs_cfg_file)s ' % locals()
+
+ if self.udfs_path is not None:
+ cmd += '--udfs-dir=%(udfs_dir)s ' % locals()
+
+ if ansi_lexer:
+ cmd += '--ansi-lexer '
+
+ if self.keep_temp:
+ cmd += '--keep-temp '
+
+ if self.extra_args:
+ cmd += " ".join(self.extra_args) + " "
+
+ cmd += '--mounts=' + yql_utils.get_mount_config_file() + ' '
+ cmd += '--validate-result-format '
+
+ if files:
+ for f in files:
+ if files[f].startswith(ARCADIA_PREFIX): # how does it work with folders? and does it?
+ files[f] = yatest.common.source_path(files[f][len(ARCADIA_PREFIX):])
+ continue
+ if files[f].startswith(ARCADIA_TESTS_DATA_PREFIX):
+ files[f] = yatest.common.data_path(files[f][len(ARCADIA_TESTS_DATA_PREFIX):])
+ continue
+
+ if files[f].startswith(VAR_CHAR_PREFIX):
+ for prefix, func in six.iteritems(FIX_DIR_PREFIXES):
+ if files[f].startswith(VAR_CHAR_PREFIX + prefix):
+ real_path = func(files[f][len(prefix) + 2:]) # $ + prefix + /
+ break
+ else:
+ raise Exception("unknown prefix in file path %s" % (files[f],))
+ copy_dest = os.path.join(res_dir, f)
+ if not os.path.exists(os.path.dirname(copy_dest)):
+ os.makedirs(os.path.dirname(copy_dest))
+ shutil.copy2(
+ real_path,
+ copy_dest,
+ )
+ files[f] = f
+ continue
+
+ if not files[f].startswith('/'): # why do we check files[f] instead of f here?
+ path_to_copy = os.path.join(
+ yatest.common.work_path(),
+ files[f]
+ )
+ if '/' in files[f]:
+ copy_dest = os.path.join(
+ res_dir,
+ os.path.dirname(files[f])
+ )
+ if not os.path.exists(copy_dest):
+ os.makedirs(copy_dest)
+ else:
+ copy_dest = res_dir
+ files[f] = os.path.basename(files[f])
+ shutil.copy2(path_to_copy, copy_dest)
+ else:
+ shutil.copy2(files[f], res_dir)
+ files[f] = os.path.basename(files[f])
+ cmd += yql_utils.get_cmd_for_files('--file', files)
+
+ if urls:
+ cmd += yql_utils.get_cmd_for_files('--url', urls)
+
+ optimize_only = False
+ if tables:
+ for table in tables:
+ self.tables[table.full_name] = table
+ if table.format != 'yson':
+ optimize_only = True
+ for name in self.tables:
+ cmd += '--table=yt.%s@%s ' % (name, self.tables[name].yqlrun_file)
+
+ if "--lineage" not in self.extra_args:
+ if optimize_only:
+ cmd += '-O '
+ else:
+ cmd += '--run '
+
+ if yql_utils.get_param('UDF_RESOLVER') or require_udf_resolver:
+ assert self.udf_resolver_binary, "Missing udf_resolver binary"
+ cmd += '--udf-resolver=' + self.udf_resolver_binary + ' '
+ if scan_udfs:
+ cmd += '--scan-udfs '
+ if not yatest.common.context.sanitize:
+ cmd += '--udf-resolver-filter-syscalls '
+
+ if run_sql and not self.use_sql2yql:
+ cmd += '--sql '
+
+ if parameters:
+ parameters_file = res_file_path('params.yson')
+ with open(parameters_file, 'w') as f:
+ f.write(six.ensure_str(yson.dumps(parameters)))
+ cmd += '--params-file=%s ' % parameters_file
+
+ if verbose:
+ yql_utils.log('prov is ' + self.prov)
+
+ env = {'YQL_DETERMINISTIC_MODE': '1'}
+ env.update(extra_env)
+ for var in [
+ 'LLVM_PROFILE_FILE',
+ 'GO_COVERAGE_PREFIX',
+ 'PYTHON_COVERAGE_PREFIX',
+ 'NLG_COVERAGE_FILENAME',
+ 'YQL_EXPORT_PG_FUNCTIONS_DIR',
+ 'YQL_ALLOW_ALL_PG_FUNCTIONS',
+ ]:
+ if var in os.environ:
+ env[var] = os.environ[var]
+ if yql_utils.get_param('STDERR'):
+ debug_udfs_dir = os.path.join(os.path.abspath('.'), '..', '..', '..')
+ env_setters = ";".join("{}={}".format(k, v) for k, v in six.iteritems(env))
+ yql_utils.log('GDB launch command:')
+ yql_utils.log('(cd "%s" && %s ya tool gdb --args %s)' % (res_dir, env_setters, cmd.replace(udfs_dir, debug_udfs_dir)))
+
+ proc_result = yatest.common.process.execute(cmd.strip().split(), check_exit_code=False, cwd=res_dir, env=env)
+ if proc_result.exit_code != 0 and check_error:
+ with open(err_file, 'r') as f:
+ err_file_text = f.read()
+ assert 0, \
+ 'Command\n%(command)s\n finished with exit code %(code)d, stderr:\n\n%(stderr)s\n\nerror file:\n%(err_file)s' % {
+ 'command': cmd,
+ 'code': proc_result.exit_code,
+ 'stderr': proc_result.std_err,
+ 'err_file': err_file_text
+ }
+
+ if os.path.exists(results_file) and os.stat(results_file).st_size == 0:
+ os.unlink(results_file) # kikimr yql-exec compatibility
+
+ results, log_results = yql_utils.read_res_file(results_file)
+ plan, log_plan = yql_utils.read_res_file(plan_file)
+ opt, log_opt = yql_utils.read_res_file(opt_file)
+ err, log_err = yql_utils.read_res_file(err_file)
+
+ if verbose:
+ yql_utils.log('PROGRAM:')
+ yql_utils.log(program)
+ yql_utils.log('OPT:')
+ yql_utils.log(log_opt)
+ yql_utils.log('PLAN:')
+ yql_utils.log(log_plan)
+ yql_utils.log('RESULTS:')
+ yql_utils.log(log_results)
+ yql_utils.log('ERROR:')
+ yql_utils.log(log_err)
+
+ return yql_utils.YQLExecResult(
+ proc_result.std_out,
+ yql_utils.normalize_source_code_path(err.replace(res_dir, '<tmp_path>')),
+ results,
+ results_file,
+ opt,
+ opt_file,
+ plan,
+ plan_file,
+ program,
+ proc_result,
+ None
+ )
+
+ def create_empty_tables(self, tables):
+ pass
+
+ def write_tables(self, tables):
+ pass
+
+ def get_tables(self, tables):
+ res = {}
+ for table in tables:
+ # recreate table after yql program was executed
+ res[table.full_name] = yql_utils.new_table(
+ table.full_name,
+ yqlrun_file=self.tables[table.full_name].yqlrun_file,
+ res_dir=self.res_dir
+ )
+
+ yql_utils.log('YQLRun table ' + table.full_name)
+ yql_utils.log(res[table.full_name].content)
+
+ return res
diff --git a/yql/essentials/tests/common/udf_test/test.py b/yql/essentials/tests/common/udf_test/test.py
new file mode 100644
index 0000000000..218b05b4bd
--- /dev/null
+++ b/yql/essentials/tests/common/udf_test/test.py
@@ -0,0 +1,111 @@
+import os
+import os.path
+import glob
+import codecs
+import shutil
+
+import pytest
+
+import yql_utils
+from yqlrun import YQLRun
+
+import yatest.common
+
+project_path = yatest.common.context.project_path
+SOURCE_PATH = yql_utils.yql_source_path((project_path + '/cases').replace('\\', '/'))
+DATA_PATH = yatest.common.output_path('cases')
+ASTDIFF_PATH = yql_utils.yql_binary_path(os.getenv('YQL_ASTDIFF_PATH') or 'yql/essentials/tools/astdiff/astdiff')
+
+
+def pytest_generate_tests(metafunc):
+ if os.path.exists(SOURCE_PATH):
+ shutil.copytree(SOURCE_PATH, DATA_PATH)
+ cases = sorted([os.path.basename(sql_query)[:-4] for sql_query in glob.glob(DATA_PATH + '/*.sql')])
+
+ else:
+ cases = []
+ metafunc.parametrize(['case'], [(case, ) for case in cases])
+
+
+def test(case):
+ program_file = os.path.join(DATA_PATH, case + '.sql')
+
+ with codecs.open(program_file, encoding='utf-8') as f:
+ program = f.readlines()
+
+ header = program[0]
+ canonize_ast = False
+
+ if header.startswith('--ignore'):
+ pytest.skip(header)
+ elif header.startswith('--sanitizer ignore') and yatest.common.context.sanitize is not None:
+ pytest.skip(header)
+ elif header.startswith('--sanitizer ignore address') and yatest.common.context.sanitize == 'address':
+ pytest.skip(header)
+ elif header.startswith('--sanitizer ignore memory') and yatest.common.context.sanitize == 'memory':
+ pytest.skip(header)
+ elif header.startswith('--sanitizer ignore thread') and yatest.common.context.sanitize == 'thread':
+ pytest.skip(header)
+ elif header.startswith('--sanitizer ignore undefined') and yatest.common.context.sanitize == 'undefined':
+ pytest.skip(header)
+ elif header.startswith('--canonize ast'):
+ canonize_ast = True
+
+ program = '\n'.join(['use plato;'] + program)
+
+ cfg = yql_utils.get_program_cfg(None, case, DATA_PATH)
+ files = {}
+ diff_tool = None
+ scan_udfs = False
+ for item in cfg:
+ if item[0] == 'file':
+ files[item[1]] = item[2]
+ if item[0] == 'diff_tool':
+ diff_tool = item[1:]
+ if item[0] == 'scan_udfs':
+ scan_udfs = True
+
+ in_tables = yql_utils.get_input_tables(None, cfg, DATA_PATH, def_attr=yql_utils.KSV_ATTR)
+
+ udfs_dir = yql_utils.get_udfs_path([
+ yatest.common.build_path(os.path.join(yatest.common.context.project_path, ".."))
+ ])
+
+ xfail = yql_utils.is_xfail(cfg)
+ if yql_utils.get_param('TARGET_PLATFORM') and xfail:
+ pytest.skip('xfail is not supported on non-default target platform')
+
+ extra_env = dict(os.environ)
+ extra_env["YQL_UDF_RESOLVER"] = "1"
+ extra_env["YQL_ARCADIA_BINARY_PATH"] = os.path.expandvars(yatest.common.build_path('.'))
+ extra_env["YQL_ARCADIA_SOURCE_PATH"] = os.path.expandvars(yatest.common.source_path('.'))
+ extra_env["Y_NO_AVX_IN_DOT_PRODUCT"] = "1"
+
+ # this breaks tests using V0 syntax
+ if "YA_TEST_RUNNER" in extra_env:
+ del extra_env["YA_TEST_RUNNER"]
+
+ yqlrun_res = YQLRun(udfs_dir=udfs_dir, prov='yt', use_sql2yql=False, cfg_dir=os.getenv('YQL_CONFIG_DIR') or 'yql/essentials/cfg/udf_test').yql_exec(
+ program=program,
+ run_sql=True,
+ tables=in_tables,
+ files=files,
+ check_error=not xfail,
+ extra_env=extra_env,
+ require_udf_resolver=True,
+ scan_udfs=scan_udfs
+ )
+
+ if xfail:
+ assert yqlrun_res.execution_result.exit_code != 0
+
+ results_path = os.path.join(yql_utils.yql_output_path(), case + '.results.txt')
+ with open(results_path, 'w') as f:
+ f.write(yqlrun_res.results)
+
+ to_canonize = [yqlrun_res.std_err] if xfail else [yatest.common.canonical_file(yqlrun_res.results_file, local=True, diff_tool=diff_tool)]
+
+ if canonize_ast:
+ to_canonize += [yatest.common.canonical_file(yqlrun_res.opt_file, local=True, diff_tool=ASTDIFF_PATH)]
+
+ return to_canonize
diff --git a/yql/essentials/tests/common/udf_test/ya.make b/yql/essentials/tests/common/udf_test/ya.make
new file mode 100644
index 0000000000..37570be0ab
--- /dev/null
+++ b/yql/essentials/tests/common/udf_test/ya.make
@@ -0,0 +1,9 @@
+PY23_LIBRARY()
+
+TEST_SRCS(test.py)
+
+PEERDIR(
+ yql/essentials/tests/common/test_framework
+)
+
+END()
diff --git a/yql/essentials/tests/common/ya.make b/yql/essentials/tests/common/ya.make
new file mode 100644
index 0000000000..1ac429bbb1
--- /dev/null
+++ b/yql/essentials/tests/common/ya.make
@@ -0,0 +1,5 @@
+RECURSE(
+ test_framework
+ udf_test
+)
+
diff --git a/yql/essentials/tests/ya.make b/yql/essentials/tests/ya.make
new file mode 100644
index 0000000000..d2d08c248b
--- /dev/null
+++ b/yql/essentials/tests/ya.make
@@ -0,0 +1,6 @@
+SUBSCRIBER(g:yql)
+
+RECURSE(
+ common
+)
+
diff --git a/yql/essentials/ya.make b/yql/essentials/ya.make
index c209762b0b..7c34bf60b7 100644
--- a/yql/essentials/ya.make
+++ b/yql/essentials/ya.make
@@ -9,6 +9,7 @@ RECURSE(
providers
public
sql
+ tests
tools
types
udfs