author     alexv-smirnov <alex@ydb.tech>  2023-03-15 19:59:12 +0300
committer  alexv-smirnov <alex@ydb.tech>  2023-03-15 19:59:12 +0300
commit     056bb284ccf8dd6793ec3a54ffa36c4fb2b9ad11 (patch)
tree       4740980126f32e3af7937ba0ca5f83e59baa4ab0 /contrib/tools/cython/Cython/Tests
parent     269126dcced1cc8b53eb4398b4a33e5142f10290 (diff)
download   ydb-056bb284ccf8dd6793ec3a54ffa36c4fb2b9ad11.tar.gz
add library/cpp/actors, ymake build to ydb oss export
Diffstat (limited to 'contrib/tools/cython/Cython/Tests')
-rw-r--r--  contrib/tools/cython/Cython/Tests/TestCodeWriter.py     82
-rw-r--r--  contrib/tools/cython/Cython/Tests/TestCythonUtils.py    11
-rw-r--r--  contrib/tools/cython/Cython/Tests/TestJediTyper.py     225
-rw-r--r--  contrib/tools/cython/Cython/Tests/TestStringIOTree.py   67
-rw-r--r--  contrib/tools/cython/Cython/Tests/__init__.py            1
-rw-r--r--  contrib/tools/cython/Cython/Tests/xmlrunner.py          397
6 files changed, 783 insertions, 0 deletions
diff --git a/contrib/tools/cython/Cython/Tests/TestCodeWriter.py b/contrib/tools/cython/Cython/Tests/TestCodeWriter.py
new file mode 100644
index 0000000000..42e457da20
--- /dev/null
+++ b/contrib/tools/cython/Cython/Tests/TestCodeWriter.py
@@ -0,0 +1,82 @@
+from Cython.TestUtils import CythonTest
+
+class TestCodeWriter(CythonTest):
+    # CythonTest uses the CodeWriter heavily, so do some checking by
+    # roundtripping Cython code through the test framework.
+
+    # Note that this test is dependent upon the normal Cython parser
+    # to generate the input trees to the CodeWriter. This saves *a lot*
+    # of time; better to spend that time writing other tests than perfecting
+    # this one...
+
+    # Whitespace is very significant in this process:
+    #  - always newline on new block (!)
+    #  - indent 4 spaces
+    #  - 1 space around every operator
+
+    def t(self, codestr):
+        self.assertCode(codestr, self.fragment(codestr).root)
+
+    def test_print(self):
+        self.t(u"""
+            print x, y
+            print x + y ** 2
+            print x, y, z,
+        """)
+
+    def test_if(self):
+        self.t(u"if x:\n    pass")
+
+    def test_ifelifelse(self):
+        self.t(u"""
+            if x:
+                pass
+            elif y:
+                pass
+            elif z + 34 ** 34 - 2:
+                pass
+            else:
+                pass
+        """)
+
+    def test_def(self):
+        self.t(u"""
+            def f(x, y, z):
+                pass
+            def f(x = 34, y = 54, z):
+                pass
+        """)
+
+    def test_longness_and_signedness(self):
+        self.t(u"def f(unsigned long long long long long int y):\n    pass")
+
+    def test_signed_short(self):
+        self.t(u"def f(signed short int y):\n    pass")
+
+    def test_typed_args(self):
+        self.t(u"def f(int x, unsigned long int y):\n    pass")
+
+    def test_cdef_var(self):
+        self.t(u"""
+            cdef int hello
+            cdef int hello = 4, x = 3, y, z
+        """)
+
+    def test_for_loop(self):
+        self.t(u"""
+            for x, y, z in f(g(h(34) * 2) + 23):
+                print x, y, z
+            else:
+                print 43
+        """)
+
+    def test_inplace_assignment(self):
+        self.t(u"x += 43")
+
+    def test_attribute(self):
+        self.t(u"a.x")
+
+if __name__ == "__main__":
+    import unittest
+    unittest.main()
+
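The comments at the top of TestCodeWriter describe the roundtrip that assertCode() performs: parse a fragment with the normal Cython parser, serialize the resulting tree through the CodeWriter, and compare the output text with the input. A minimal standalone sketch of that pattern, assuming only the CythonTest helpers already imported above (fragment() and assertCode()):

from Cython.TestUtils import CythonTest

class RoundtripSketch(CythonTest):
    def test_roundtrip(self):
        code = u"if x:\n    pass"
        tree = self.fragment(code).root   # parse with the regular Cython parser
        self.assertCode(code, tree)       # reserialize via the CodeWriter and compare

if __name__ == "__main__":
    import unittest
    unittest.main()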
diff --git a/contrib/tools/cython/Cython/Tests/TestCythonUtils.py b/contrib/tools/cython/Cython/Tests/TestCythonUtils.py
new file mode 100644
index 0000000000..2641900c01
--- /dev/null
+++ b/contrib/tools/cython/Cython/Tests/TestCythonUtils.py
@@ -0,0 +1,11 @@
+import unittest
+
+from ..Utils import build_hex_version
+
+class TestCythonUtils(unittest.TestCase):
+    def test_build_hex_version(self):
+        self.assertEqual('0x001D00A1', build_hex_version('0.29a1'))
+        self.assertEqual('0x001D00A1', build_hex_version('0.29a1'))
+        self.assertEqual('0x001D03C4', build_hex_version('0.29.3rc4'))
+        self.assertEqual('0x001D00F0', build_hex_version('0.29'))
+        self.assertEqual('0x040000F0', build_hex_version('4.0'))
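The expected strings above follow a PY_VERSION_HEX-style layout: one byte each for major, minor and micro, then one nibble for the release level (0xA for alpha, 0xC for release candidate, 0xF for final; presumably 0xB for beta) and one nibble for the release serial. A hedged re-implementation sketch, not the actual Utils.build_hex_version, that reproduces the assertions above:

import re

def hex_version_sketch(version):
    # 'MAJOR.MINOR[.MICRO][{a|b|rc}SERIAL]' -> '0xMMmmuuLS'
    major, minor, micro, level, serial = re.match(
        r'(\d+)\.(\d+)(?:\.(\d+))?(?:(a|b|rc)(\d+))?$', version).groups()
    levels = {'a': 0xA, 'b': 0xB, 'rc': 0xC, None: 0xF}
    return '0x%02X%02X%02X%X%X' % (int(major), int(minor), int(micro or 0),
                                   levels[level], int(serial or 0))

assert hex_version_sketch('0.29a1') == '0x001D00A1'
assert hex_version_sketch('0.29.3rc4') == '0x001D03C4'
assert hex_version_sketch('0.29') == '0x001D00F0'
assert hex_version_sketch('4.0') == '0x040000F0'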
diff --git a/contrib/tools/cython/Cython/Tests/TestJediTyper.py b/contrib/tools/cython/Cython/Tests/TestJediTyper.py
new file mode 100644
index 0000000000..253adef171
--- /dev/null
+++ b/contrib/tools/cython/Cython/Tests/TestJediTyper.py
@@ -0,0 +1,225 @@
+# -*- coding: utf-8 -*-
+# tag: jedi
+
+from __future__ import absolute_import
+
+import sys
+import os.path
+
+from textwrap import dedent
+from contextlib import contextmanager
+from tempfile import NamedTemporaryFile
+
+from Cython.Compiler.ParseTreeTransforms import NormalizeTree, InterpretCompilerDirectives
+from Cython.Compiler import Main, Symtab, Visitor
+from Cython.TestUtils import TransformTest
+
+TOOLS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'Tools'))
+
+
+@contextmanager
+def _tempfile(code):
+    code = dedent(code)
+    if not isinstance(code, bytes):
+        code = code.encode('utf8')
+
+    with NamedTemporaryFile(suffix='.py') as f:
+        f.write(code)
+        f.seek(0)
+        yield f
+
+
+def _test_typing(code, inject=False):
+    sys.path.insert(0, TOOLS_DIR)
+    try:
+        import jedityper
+    finally:
+        sys.path.remove(TOOLS_DIR)
+    lines = []
+    with _tempfile(code) as f:
+        types = jedityper.analyse(f.name)
+        if inject:
+            lines = jedityper.inject_types(f.name, types)
+    return types, lines
+
+
+class DeclarationsFinder(Visitor.VisitorTransform):
+    directives = None
+
+    visit_Node = Visitor.VisitorTransform.recurse_to_children
+
+    def visit_CompilerDirectivesNode(self, node):
+        if not self.directives:
+            self.directives = []
+        self.directives.append(node)
+        self.visitchildren(node)
+        return node
+
+
+class TestJediTyper(TransformTest):
+    def _test(self, code):
+        return _test_typing(code)[0]
+
+    def test_typing_global_int_loop(self):
+        code = '''\
+        for i in range(10):
+            a = i + 1
+        '''
+        types = self._test(code)
+        self.assertIn((None, (1, 0)), types)
+        variables = types.pop((None, (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['int']), 'i': set(['int'])}, variables)
+
+    def test_typing_function_int_loop(self):
+        code = '''\
+        def func(x):
+            for i in range(x):
+                a = i + 1
+            return a
+        '''
+        types = self._test(code)
+        self.assertIn(('func', (1, 0)), types)
+        variables = types.pop(('func', (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['int']), 'i': set(['int'])}, variables)
+
+    def test_conflicting_types_in_function(self):
+        code = '''\
+        def func(a, b):
+            print(a)
+            a = 1
+            b += a
+            a = 'abc'
+            return a, str(b)
+
+        print(func(1.5, 2))
+        '''
+        types = self._test(code)
+        self.assertIn(('func', (1, 0)), types)
+        variables = types.pop(('func', (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['float', 'int', 'str']), 'b': set(['int'])}, variables)
+
+    def _test_typing_function_char_loop(self):
+        code = '''\
+        def func(x):
+            l = []
+            for c in x:
+                l.append(c)
+            return l
+
+        print(func('abcdefg'))
+        '''
+        types = self._test(code)
+        self.assertIn(('func', (1, 0)), types)
+        variables = types.pop(('func', (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['int']), 'i': set(['int'])}, variables)
+
+    def test_typing_global_list(self):
+        code = '''\
+        a = [x for x in range(10)]
+        b = list(range(10))
+        c = a + b
+        d = [0]*10
+        '''
+        types = self._test(code)
+        self.assertIn((None, (1, 0)), types)
+        variables = types.pop((None, (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['list']), 'b': set(['list']), 'c': set(['list']), 'd': set(['list'])}, variables)
+
+    def test_typing_function_list(self):
+        code = '''\
+        def func(x):
+            a = [[], []]
+            b = [0]* 10 + a
+            c = a[0]
+
+        print(func([0]*100))
+        '''
+        types = self._test(code)
+        self.assertIn(('func', (1, 0)), types)
+        variables = types.pop(('func', (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['list']), 'b': set(['list']), 'c': set(['list']), 'x': set(['list'])}, variables)
+
+    def test_typing_global_dict(self):
+        code = '''\
+        a = dict()
+        b = {i: i**2 for i in range(10)}
+        c = a
+        '''
+        types = self._test(code)
+        self.assertIn((None, (1, 0)), types)
+        variables = types.pop((None, (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['dict']), 'b': set(['dict']), 'c': set(['dict'])}, variables)
+
+    def test_typing_function_dict(self):
+        code = '''\
+        def func(x):
+            a = dict()
+            b = {i: i**2 for i in range(10)}
+            c = x
+
+        print(func({1:2, 'x':7}))
+        '''
+        types = self._test(code)
+        self.assertIn(('func', (1, 0)), types)
+        variables = types.pop(('func', (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['dict']), 'b': set(['dict']), 'c': set(['dict']), 'x': set(['dict'])}, variables)
+
+
+    def test_typing_global_set(self):
+        code = '''\
+        a = set()
+        # b = {i for i in range(10)} # jedi does not support set comprehension yet
+        c = a
+        d = {1,2,3}
+        e = a | b
+        '''
+        types = self._test(code)
+        self.assertIn((None, (1, 0)), types)
+        variables = types.pop((None, (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['set']), 'c': set(['set']), 'd': set(['set']), 'e': set(['set'])}, variables)
+
+    def test_typing_function_set(self):
+        code = '''\
+        def func(x):
+            a = set()
+            # b = {i for i in range(10)} # jedi does not support set comprehension yet
+            c = a
+            d = a | b
+
+        print(func({1,2,3}))
+        '''
+        types = self._test(code)
+        self.assertIn(('func', (1, 0)), types)
+        variables = types.pop(('func', (1, 0)))
+        self.assertFalse(types)
+        self.assertEqual({'a': set(['set']), 'c': set(['set']), 'd': set(['set']), 'x': set(['set'])}, variables)
+
+
+class TestTypeInjection(TestJediTyper):
+    """
+    Subtype of TestJediTyper that additionally tests type injection and compilation.
+    """
+    def setUp(self):
+        super(TestTypeInjection, self).setUp()
+        compilation_options = Main.CompilationOptions(Main.default_options)
+        ctx = compilation_options.create_context()
+        transform = InterpretCompilerDirectives(ctx, ctx.compiler_directives)
+        transform.module_scope = Symtab.ModuleScope('__main__', None, ctx)
+        self.declarations_finder = DeclarationsFinder()
+        self.pipeline = [NormalizeTree(None), transform, self.declarations_finder]
+
+    def _test(self, code):
+        types, lines = _test_typing(code, inject=True)
+        tree = self.run_pipeline(self.pipeline, ''.join(lines))
+        directives = self.declarations_finder.directives
+        # TODO: validate directives
+        return types
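For orientation, the helpers above use exactly two entry points of the Tools/jedityper module: analyse(path) returns a mapping keyed by (scope_name, (line, column)) whose values map variable names to sets of inferred type names, and inject_types(path, types) returns the source lines with the inferred types injected. A hedged driver sketch built only on those two calls; example.py is a hypothetical input file, and the jedi package must be installed:

import sys
sys.path.insert(0, TOOLS_DIR)   # TOOLS_DIR as defined at the top of this module
import jedityper

types = jedityper.analyse('example.py')
# e.g. {('func', (1, 0)): {'a': set(['int']), 'i': set(['int'])}}
annotated_lines = jedityper.inject_types('example.py', types)
print(''.join(annotated_lines))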
diff --git a/contrib/tools/cython/Cython/Tests/TestStringIOTree.py b/contrib/tools/cython/Cython/Tests/TestStringIOTree.py
new file mode 100644
index 0000000000..a15f2cd88d
--- /dev/null
+++ b/contrib/tools/cython/Cython/Tests/TestStringIOTree.py
@@ -0,0 +1,67 @@
+import unittest
+
+from Cython import StringIOTree as stringtree
+
+code = """
+cdef int spam                   # line 1
+
+cdef ham():
+    a = 1
+    b = 2
+    c = 3
+    d = 4
+
+def eggs():
+    pass
+
+cpdef bacon():
+    print spam
+    print 'scotch'
+    print 'tea?'
+    print 'or coffee?'          # line 16
+"""
+
+linemap = dict(enumerate(code.splitlines()))
+
+class TestStringIOTree(unittest.TestCase):
+
+    def setUp(self):
+        self.tree = stringtree.StringIOTree()
+
+    def test_markers(self):
+        assert not self.tree.allmarkers()
+
+    def test_insertion(self):
+        self.write_lines((1, 2, 3))
+        line_4_to_6_insertion_point = self.tree.insertion_point()
+        self.write_lines((7, 8))
+        line_9_to_13_insertion_point = self.tree.insertion_point()
+        self.write_lines((14, 15, 16))
+
+        line_4_insertion_point = line_4_to_6_insertion_point.insertion_point()
+        self.write_lines((5, 6), tree=line_4_to_6_insertion_point)
+
+        line_9_to_12_insertion_point = (
+            line_9_to_13_insertion_point.insertion_point())
+        self.write_line(13, tree=line_9_to_13_insertion_point)
+
+        self.write_line(4, tree=line_4_insertion_point)
+        self.write_line(9, tree=line_9_to_12_insertion_point)
+        line_10_insertion_point = line_9_to_12_insertion_point.insertion_point()
+        self.write_line(11, tree=line_9_to_12_insertion_point)
+        self.write_line(10, tree=line_10_insertion_point)
+        self.write_line(12, tree=line_9_to_12_insertion_point)
+
+        self.assertEqual(self.tree.allmarkers(), list(range(1, 17)))
+        self.assertEqual(code.strip(), self.tree.getvalue().strip())
+
+
+    def write_lines(self, linenos, tree=None):
+        for lineno in linenos:
+            self.write_line(lineno, tree=tree)
+
+    def write_line(self, lineno, tree=None):
+        if tree is None:
+            tree = self.tree
+        tree.markers.append(lineno)
+        tree.write(linemap[lineno] + '\n')
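The insertion test above is the same out-of-order writing pattern in miniature: write what is known, reserve an insertion point, keep writing, and fill the reserved slot later. A short sketch using only the StringIOTree methods the test exercises (write, insertion_point, getvalue):

from Cython import StringIOTree as stringtree

buf = stringtree.StringIOTree()
buf.write('line 1\n')
middle = buf.insertion_point()     # reserve a slot between lines 1 and 3
buf.write('line 3\n')
middle.write('line 2\n')           # written last, rendered in the middle
print(buf.getvalue())              # 'line 1\nline 2\nline 3\n'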
diff --git a/contrib/tools/cython/Cython/Tests/__init__.py b/contrib/tools/cython/Cython/Tests/__init__.py
new file mode 100644
index 0000000000..fa81adaff6
--- /dev/null
+++ b/contrib/tools/cython/Cython/Tests/__init__.py
@@ -0,0 +1 @@
+# empty file
diff --git a/contrib/tools/cython/Cython/Tests/xmlrunner.py b/contrib/tools/cython/Cython/Tests/xmlrunner.py
new file mode 100644
index 0000000000..d6838aa22e
--- /dev/null
+++ b/contrib/tools/cython/Cython/Tests/xmlrunner.py
@@ -0,0 +1,397 @@
+# -*- coding: utf-8 -*-
+
+"""unittest-xml-reporting is a PyUnit-based TestRunner that can export test
+results to XML files that can be consumed by a wide range of tools, such as
+build systems, IDEs and Continuous Integration servers.
+
+This module provides the XMLTestRunner class, which is heavily based on the
+default TextTestRunner. This makes the XMLTestRunner very simple to use.
+
+The script below, adapted from the unittest documentation, shows how to use
+XMLTestRunner in a very simple way. In fact, the only difference between this
+script and the original one is the last line:
+
+import random
+import unittest
+import xmlrunner
+
+class TestSequenceFunctions(unittest.TestCase):
+    def setUp(self):
+        self.seq = range(10)
+
+    def test_shuffle(self):
+        # make sure the shuffled sequence does not lose any elements
+        random.shuffle(self.seq)
+        self.seq.sort()
+        self.assertEqual(self.seq, range(10))
+
+    def test_choice(self):
+        element = random.choice(self.seq)
+        self.assertTrue(element in self.seq)
+
+    def test_sample(self):
+        self.assertRaises(ValueError, random.sample, self.seq, 20)
+        for element in random.sample(self.seq, 5):
+            self.assertTrue(element in self.seq)
+
+if __name__ == '__main__':
+    unittest.main(testRunner=xmlrunner.XMLTestRunner(output='test-reports'))
+"""
+
+from __future__ import absolute_import
+
+import os
+import sys
+import time
+from unittest import TestResult, TextTestResult, TextTestRunner
+import xml.dom.minidom
+try:
+    from StringIO import StringIO
+except ImportError:
+    from io import StringIO  # doesn't accept 'str' in Py2
+
+
+class XMLDocument(xml.dom.minidom.Document):
+    def createCDATAOrText(self, data):
+        if ']]>' in data:
+            return self.createTextNode(data)
+        return self.createCDATASection(data)
+
+
+class _TestInfo(object):
+    """This class is used to keep useful information about the execution of a
+    test method.
+    """
+
+    # Possible test outcomes
+    (SUCCESS, FAILURE, ERROR) = range(3)
+
+    def __init__(self, test_result, test_method, outcome=SUCCESS, err=None):
+        "Create a new instance of _TestInfo."
+        self.test_result = test_result
+        self.test_method = test_method
+        self.outcome = outcome
+        self.err = err
+        self.stdout = test_result.stdout and test_result.stdout.getvalue().strip() or ''
+        self.stderr = test_result.stderr and test_result.stderr.getvalue().strip() or ''
+
+    def get_elapsed_time(self):
+        """Return the time that shows how long the test method took to
+        execute.
+        """
+        return self.test_result.stop_time - self.test_result.start_time
+
+    def get_description(self):
+        "Return a text representation of the test method."
+        return self.test_result.getDescription(self.test_method)
+
+    def get_error_info(self):
+        """Return a text representation of an exception thrown by a test
+        method.
+        """
+        if not self.err:
+            return ''
+        return self.test_result._exc_info_to_string(
+            self.err, self.test_method)
+
+
+class _XMLTestResult(TextTestResult):
+    """A test result class that can express test results in an XML report.
+
+    Used by XMLTestRunner.
+    """
+    def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1,
+                 elapsed_times=True):
+        "Create a new instance of _XMLTestResult."
+        TextTestResult.__init__(self, stream, descriptions, verbosity)
+        self.successes = []
+        self.callback = None
+        self.elapsed_times = elapsed_times
+        self.output_patched = False
+
+    def _prepare_callback(self, test_info, target_list, verbose_str,
+                          short_str):
+        """Append a _TestInfo to the given target list and set a callback
+        method to be called by the stopTest method.
+        """
+        target_list.append(test_info)
+        def callback():
+            """This callback prints the test method outcome to the stream,
+            as well as the elapsed time.
+            """
+
+            # Ignore the elapsed times for a more reliable unit testing
+            if not self.elapsed_times:
+                self.start_time = self.stop_time = 0
+
+            if self.showAll:
+                self.stream.writeln('(%.3fs) %s' % \
+                    (test_info.get_elapsed_time(), verbose_str))
+            elif self.dots:
+                self.stream.write(short_str)
+        self.callback = callback
+
+    def _patch_standard_output(self):
+        """Replace the stdout and stderr streams with string-based streams
+        in order to capture the tests' output.
+        """
+        if not self.output_patched:
+            (self.old_stdout, self.old_stderr) = (sys.stdout, sys.stderr)
+            self.output_patched = True
+        (sys.stdout, sys.stderr) = (self.stdout, self.stderr) = \
+            (StringIO(), StringIO())
+
+    def _restore_standard_output(self):
+        "Restore the stdout and stderr streams."
+        (sys.stdout, sys.stderr) = (self.old_stdout, self.old_stderr)
+        self.output_patched = False
+
+    def startTest(self, test):
+        "Called before each test method is executed."
+        self._patch_standard_output()
+        self.start_time = time.time()
+        TestResult.startTest(self, test)
+
+        if self.showAll:
+            self.stream.write(' ' + self.getDescription(test))
+            self.stream.write(" ... ")
+
+    def stopTest(self, test):
+        "Called after each test method is executed."
+        self._restore_standard_output()
+        TextTestResult.stopTest(self, test)
+        self.stop_time = time.time()
+
+        if self.callback and callable(self.callback):
+            self.callback()
+            self.callback = None
+
+    def addSuccess(self, test):
+        "Called when a test executes successfully."
+        self._prepare_callback(_TestInfo(self, test),
+                               self.successes, 'OK', '.')
+
+    def addFailure(self, test, err):
+        "Called when a test method fails."
+        self._prepare_callback(_TestInfo(self, test, _TestInfo.FAILURE, err),
+                               self.failures, 'FAIL', 'F')
+
+    def addError(self, test, err):
+        "Called when a test method raises an error."
+        self._prepare_callback(_TestInfo(self, test, _TestInfo.ERROR, err),
+                               self.errors, 'ERROR', 'E')
+
+    def printErrorList(self, flavour, errors):
+        "Write some information about the FAIL or ERROR to the stream."
+        for test_info in errors:
+            if isinstance(test_info, tuple):
+                test_info, exc_info = test_info
+
+            try:
+                t = test_info.get_elapsed_time()
+            except AttributeError:
+                t = 0
+            try:
+                descr = test_info.get_description()
+            except AttributeError:
+                try:
+                    descr = test_info.getDescription()
+                except AttributeError:
+                    descr = str(test_info)
+            try:
+                err_info = test_info.get_error_info()
+            except AttributeError:
+                err_info = str(test_info)
+
+            self.stream.writeln(self.separator1)
+            self.stream.writeln('%s [%.3fs]: %s' % (flavour, t, descr))
+            self.stream.writeln(self.separator2)
+            self.stream.writeln('%s' % err_info)
+
+    def _get_info_by_testcase(self):
+        """This method organizes test results by TestCase module. This
+        information is used during the report generation, where an XML report
+        will be generated for each TestCase.
+        """
+        tests_by_testcase = {}
+
+        for tests in (self.successes, self.failures, self.errors):
+            for test_info in tests:
+                if not isinstance(test_info, _TestInfo):
+                    print("Unexpected test result type: %r" % (test_info,))
+                    continue
+                testcase = type(test_info.test_method)
+
+                # Ignore module name if it is '__main__'
+                module = testcase.__module__ + '.'
+                if module == '__main__.':
+                    module = ''
+                testcase_name = module + testcase.__name__
+
+                if testcase_name not in tests_by_testcase:
+                    tests_by_testcase[testcase_name] = []
+                tests_by_testcase[testcase_name].append(test_info)
+
+        return tests_by_testcase
+
+    def _report_testsuite(suite_name, tests, xml_document):
+        "Appends the testsuite section to the XML document."
+        testsuite = xml_document.createElement('testsuite')
+        xml_document.appendChild(testsuite)
+
+        testsuite.setAttribute('name', str(suite_name))
+        testsuite.setAttribute('tests', str(len(tests)))
+
+        testsuite.setAttribute('time', '%.3f' %
+            sum([e.get_elapsed_time() for e in tests]))
+
+        failures = len([1 for e in tests if e.outcome == _TestInfo.FAILURE])
+        testsuite.setAttribute('failures', str(failures))
+
+        errors = len([1 for e in tests if e.outcome == _TestInfo.ERROR])
+        testsuite.setAttribute('errors', str(errors))
+
+        return testsuite
+
+    _report_testsuite = staticmethod(_report_testsuite)
+
+    def _report_testcase(suite_name, test_result, xml_testsuite, xml_document):
+        "Appends a testcase section to the XML document."
+        testcase = xml_document.createElement('testcase')
+        xml_testsuite.appendChild(testcase)
+
+        testcase.setAttribute('classname', str(suite_name))
+        testcase.setAttribute('name', test_result.test_method.shortDescription()
+            or getattr(test_result.test_method, '_testMethodName',
+                       str(test_result.test_method)))
+        testcase.setAttribute('time', '%.3f' % test_result.get_elapsed_time())
+
+        if (test_result.outcome != _TestInfo.SUCCESS):
+            elem_name = ('failure', 'error')[test_result.outcome-1]
+            failure = xml_document.createElement(elem_name)
+            testcase.appendChild(failure)
+
+            failure.setAttribute('type', str(test_result.err[0].__name__))
+            failure.setAttribute('message', str(test_result.err[1]))
+
+            error_info = test_result.get_error_info()
+            failureText = xml_document.createCDATAOrText(error_info)
+            failure.appendChild(failureText)
+
+    _report_testcase = staticmethod(_report_testcase)
+
+    def _report_output(test_runner, xml_testsuite, xml_document, stdout, stderr):
+        "Appends the system-out and system-err sections to the XML document."
+        systemout = xml_document.createElement('system-out')
+        xml_testsuite.appendChild(systemout)
+
+        systemout_text = xml_document.createCDATAOrText(stdout)
+        systemout.appendChild(systemout_text)
+
+        systemerr = xml_document.createElement('system-err')
+        xml_testsuite.appendChild(systemerr)
+
+        systemerr_text = xml_document.createCDATAOrText(stderr)
+        systemerr.appendChild(systemerr_text)
+
+    _report_output = staticmethod(_report_output)
+
+    def generate_reports(self, test_runner):
+        "Generates the XML reports for a given XMLTestRunner object."
+        all_results = self._get_info_by_testcase()
+
+        if type(test_runner.output) == str and not \
+                os.path.exists(test_runner.output):
+            os.makedirs(test_runner.output)
+
+        for suite, tests in all_results.items():
+            doc = XMLDocument()
+
+            # Build the XML file
+            testsuite = _XMLTestResult._report_testsuite(suite, tests, doc)
+            stdout, stderr = [], []
+            for test in tests:
+                _XMLTestResult._report_testcase(suite, test, testsuite, doc)
+                if test.stdout:
+                    stdout.extend(['*****************', test.get_description(), test.stdout])
+                if test.stderr:
+                    stderr.extend(['*****************', test.get_description(), test.stderr])
+            _XMLTestResult._report_output(test_runner, testsuite, doc,
+                                          '\n'.join(stdout), '\n'.join(stderr))
+            xml_content = doc.toprettyxml(indent='\t')
+
+            if type(test_runner.output) is str:
+                report_file = open('%s%sTEST-%s.xml' % \
+                    (test_runner.output, os.sep, suite), 'w')
+                try:
+                    report_file.write(xml_content)
+                finally:
+                    report_file.close()
+            else:
+                # Assume that test_runner.output is a stream
+                test_runner.output.write(xml_content)
+
+
+class XMLTestRunner(TextTestRunner):
+    """A test runner class that outputs the results in JUnit-like XML files.
+    """
+    def __init__(self, output='.', stream=None, descriptions=True, verbose=False, elapsed_times=True):
+        "Create a new instance of XMLTestRunner."
+        if stream is None:
+            stream = sys.stderr
+        verbosity = (1, 2)[verbose]
+        TextTestRunner.__init__(self, stream, descriptions, verbosity)
+        self.output = output
+        self.elapsed_times = elapsed_times
+
+    def _make_result(self):
+        """Create the TestResult object which will be used to store
+        information about the executed tests.
+        """
+        return _XMLTestResult(self.stream, self.descriptions, \
+            self.verbosity, self.elapsed_times)
+
+    def run(self, test):
+        "Run the given test case or test suite."
+        # Prepare the test execution
+        result = self._make_result()
+
+        # Print a nice header
+        self.stream.writeln()
+        self.stream.writeln('Running tests...')
+        self.stream.writeln(result.separator2)
+
+        # Execute tests
+        start_time = time.time()
+        test(result)
+        stop_time = time.time()
+        time_taken = stop_time - start_time
+
+        # Generate reports
+        self.stream.writeln()
+        self.stream.writeln('Generating XML reports...')
+        result.generate_reports(self)
+
+        # Print results
+        result.printErrors()
+        self.stream.writeln(result.separator2)
+        run = result.testsRun
+        self.stream.writeln("Ran %d test%s in %.3fs" %
+            (run, run != 1 and "s" or "", time_taken))
+        self.stream.writeln()
+
+        # Error traces
+        if not result.wasSuccessful():
+            self.stream.write("FAILED (")
+            failed, errored = (len(result.failures), len(result.errors))
+            if failed:
+                self.stream.write("failures=%d" % failed)
+            if errored:
+                if failed:
+                    self.stream.write(", ")
+                self.stream.write("errors=%d" % errored)
+            self.stream.writeln(")")
+        else:
+            self.stream.writeln("OK")
+
+        return result
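Beyond the docstring example at the top of the module, generate_reports() also accepts a non-string output: when XMLTestRunner is constructed with a stream-like object instead of a directory name, the XML is written to that object rather than to TEST-<suite>.xml files. A hedged sketch that captures the JUnit-style report in memory; Demo is a throwaway test case used only for illustration:

import unittest
from io import StringIO
from Cython.Tests import xmlrunner

class Demo(unittest.TestCase):
    def test_ok(self):
        self.assertTrue(True)

report = StringIO()
suite = unittest.defaultTestLoader.loadTestsFromTestCase(Demo)
xmlrunner.XMLTestRunner(output=report, verbose=True).run(suite)
print(report.getvalue())   # e.g. <testsuite name="Demo" tests="1" ...>, the '__main__.' prefix is dropped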