diff options
author | robot-contrib <robot-contrib@yandex-team.com> | 2023-10-19 17:11:31 +0300 |
---|---|---|
committer | robot-contrib <robot-contrib@yandex-team.com> | 2023-10-19 18:26:04 +0300 |
commit | b9fe236a503791a3a7b37d4ef5f466225218996c (patch) | |
tree | c2f80019399b393ddf0450d0f91fc36478af8bea /contrib/python/traitlets/py3/tests | |
parent | 44dd27d0a2ae37c80d97a95581951d1d272bd7df (diff) | |
download | ydb-b9fe236a503791a3a7b37d4ef5f466225218996c.tar.gz |
Update contrib/python/traitlets/py3 to 5.11.2
Diffstat (limited to 'contrib/python/traitlets/py3/tests')
16 files changed, 6903 insertions, 10 deletions
diff --git a/contrib/python/traitlets/py3/tests/__init__.py b/contrib/python/traitlets/py3/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/contrib/python/traitlets/py3/tests/__init__.py diff --git a/contrib/python/traitlets/py3/tests/_warnings.py b/contrib/python/traitlets/py3/tests/_warnings.py new file mode 100644 index 00000000000..e3c3a0ac6d6 --- /dev/null +++ b/contrib/python/traitlets/py3/tests/_warnings.py @@ -0,0 +1,114 @@ +# From scikit-image: https://github.com/scikit-image/scikit-image/blob/c2f8c4ab123ebe5f7b827bc495625a32bb225c10/skimage/_shared/_warnings.py +# Licensed under modified BSD license + +__all__ = ["all_warnings", "expected_warnings"] + +import inspect +import os +import re +import sys +import warnings +from contextlib import contextmanager +from unittest import mock + + +@contextmanager +def all_warnings(): + """ + Context for use in testing to ensure that all warnings are raised. + Examples + -------- + >>> import warnings + >>> def foo(): + ... warnings.warn(RuntimeWarning("bar")) + + We raise the warning once, while the warning filter is set to "once". + Hereafter, the warning is invisible, even with custom filters: + >>> with warnings.catch_warnings(): + ... warnings.simplefilter('once') + ... foo() + + We can now run ``foo()`` without a warning being raised: + >>> from numpy.testing import assert_warns # doctest: +SKIP + >>> foo() # doctest: +SKIP + + To catch the warning, we call in the help of ``all_warnings``: + >>> with all_warnings(): # doctest: +SKIP + ... assert_warns(RuntimeWarning, foo) + """ + + # Whenever a warning is triggered, Python adds a __warningregistry__ + # member to the *calling* module. The exercize here is to find + # and eradicate all those breadcrumbs that were left lying around. + # + # We proceed by first searching all parent calling frames and explicitly + # clearing their warning registries (necessary for the doctests above to + # pass). Then, we search for all submodules of skimage and clear theirs + # as well (necessary for the skimage test suite to pass). + + frame = inspect.currentframe() + if frame: + for f in inspect.getouterframes(frame): + f[0].f_locals["__warningregistry__"] = {} + del frame + + for _, mod in list(sys.modules.items()): + try: + mod.__warningregistry__.clear() + except AttributeError: + pass + + with warnings.catch_warnings(record=True) as w, mock.patch.dict( + os.environ, {"TRAITLETS_ALL_DEPRECATIONS": "1"} + ): + warnings.simplefilter("always") + yield w + + +@contextmanager +def expected_warnings(matching): + r"""Context for use in testing to catch known warnings matching regexes + + Parameters + ---------- + matching : list of strings or compiled regexes + Regexes for the desired warning to catch + + Examples + -------- + >>> from skimage import data, img_as_ubyte, img_as_float # doctest: +SKIP + >>> with expected_warnings(["precision loss"]): # doctest: +SKIP + ... d = img_as_ubyte(img_as_float(data.coins())) # doctest: +SKIP + + Notes + ----- + Uses `all_warnings` to ensure all warnings are raised. + Upon exiting, it checks the recorded warnings for the desired matching + pattern(s). + Raises a ValueError if any match was not found or an unexpected + warning was raised. + Allows for three types of behaviors: "and", "or", and "optional" matches. + This is done to accomodate different build enviroments or loop conditions + that may produce different warnings. The behaviors can be combined. + If you pass multiple patterns, you get an orderless "and", where all of the + warnings must be raised. + If you use the "|" operator in a pattern, you can catch one of several warnings. + Finally, you can use "|\A\Z" in a pattern to signify it as optional. + """ + with all_warnings() as w: + # enter context + yield w + # exited user context, check the recorded warnings + remaining = [m for m in matching if r"\A\Z" not in m.split("|")] + for warn in w: + found = False + for match in matching: + if re.search(match, str(warn.message)) is not None: + found = True + if match in remaining: + remaining.remove(match) + if not found: + raise ValueError("Unexpected warning: %s" % str(warn.message)) + if len(remaining) > 0: + msg = "No warning raised matching:\n%s" % "\n".join(remaining) + raise ValueError(msg) diff --git a/contrib/python/traitlets/py3/tests/config/__init__.py b/contrib/python/traitlets/py3/tests/config/__init__.py new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/contrib/python/traitlets/py3/tests/config/__init__.py diff --git a/contrib/python/traitlets/py3/tests/config/test_application.py b/contrib/python/traitlets/py3/tests/config/test_application.py new file mode 100644 index 00000000000..610cafc3cd5 --- /dev/null +++ b/contrib/python/traitlets/py3/tests/config/test_application.py @@ -0,0 +1,910 @@ +""" +Tests for traitlets.config.application.Application +""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +import contextlib +import io +import json +import logging +import os +import sys +import typing as t +from io import StringIO +from tempfile import TemporaryDirectory +from unittest import TestCase + +import pytest +from pytest import mark + +from traitlets import Bool, Bytes, Dict, HasTraits, Integer, List, Set, Tuple, Unicode +from traitlets.config.application import Application +from traitlets.config.configurable import Configurable +from traitlets.config.loader import Config, KVArgParseConfigLoader +from traitlets.tests.utils import check_help_all_output, check_help_output, get_output_error_code + +try: + from unittest import mock +except ImportError: + from unittest import mock + +pjoin = os.path.join + + +class Foo(Configurable): + i = Integer( + 0, + help=""" + The integer i. + + Details about i. + """, + ).tag(config=True) + j = Integer(1, help="The integer j.").tag(config=True) + name = Unicode("Brian", help="First name.").tag(config=True) + la = List([]).tag(config=True) + li = List(Integer()).tag(config=True) + fdict = Dict().tag(config=True, multiplicity="+") + + +class Bar(Configurable): + b = Integer(0, help="The integer b.").tag(config=True) + enabled = Bool(True, help="Enable bar.").tag(config=True) + tb = Tuple(()).tag(config=True, multiplicity="*") + aset = Set().tag(config=True, multiplicity="+") + bdict = Dict().tag(config=True) + idict = Dict(value_trait=Integer()).tag(config=True) + key_dict = Dict(per_key_traits={"i": Integer(), "b": Bytes()}).tag(config=True) + + +class MyApp(Application): + name = Unicode("myapp") + running = Bool(False, help="Is the app running?").tag(config=True) + classes = List([Bar, Foo]) # type:ignore + config_file = Unicode("", help="Load this config file").tag(config=True) + + warn_tpyo = Unicode( + "yes the name is wrong on purpose", + config=True, + help="Should print a warning if `MyApp.warn-typo=...` command is passed", + ) + + aliases: t.Dict[t.Any, t.Any] = {} + aliases.update(Application.aliases) + aliases.update( + { + ("fooi", "i"): "Foo.i", + ("j", "fooj"): ("Foo.j", "`j` terse help msg"), + "name": "Foo.name", + "la": "Foo.la", + "li": "Foo.li", + "tb": "Bar.tb", + "D": "Bar.bdict", + "enabled": "Bar.enabled", + "enable": "Bar.enabled", + "log-level": "Application.log_level", + } + ) + + flags: t.Dict[t.Any, t.Any] = {} + flags.update(Application.flags) + flags.update( + { + ("enable", "e"): ({"Bar": {"enabled": True}}, "Set Bar.enabled to True"), + ("d", "disable"): ({"Bar": {"enabled": False}}, "Set Bar.enabled to False"), + "crit": ({"Application": {"log_level": logging.CRITICAL}}, "set level=CRITICAL"), + } + ) + + def init_foo(self): + self.foo = Foo(parent=self) + + def init_bar(self): + self.bar = Bar(parent=self) + + +def class_to_names(classes): + return [klass.__name__ for klass in classes] + + +class TestApplication(TestCase): + def test_log(self): + stream = StringIO() + app = MyApp(log_level=logging.INFO) + handler = logging.StreamHandler(stream) + # trigger reconstruction of the log formatter + app.log_format = "%(message)s" + app.log_datefmt = "%Y-%m-%d %H:%M" + app.log.handlers = [handler] + app.log.info("hello") + assert "hello" in stream.getvalue() + + def test_no_eval_cli_text(self): + app = MyApp() + app.initialize(["--Foo.name=1"]) + app.init_foo() + assert app.foo.name == "1" + + def test_basic(self): + app = MyApp() + self.assertEqual(app.name, "myapp") + self.assertEqual(app.running, False) + self.assertEqual(app.classes, [MyApp, Bar, Foo]) # type:ignore + self.assertEqual(app.config_file, "") + + def test_app_name_set_via_constructor(self): + app = MyApp(name='set_via_constructor') + assert app.name == "set_via_constructor" + + def test_mro_discovery(self): + app = MyApp() + + self.assertSequenceEqual( + class_to_names(app._classes_with_config_traits()), + ["Application", "MyApp", "Bar", "Foo"], + ) + self.assertSequenceEqual( + class_to_names(app._classes_inc_parents()), + [ + "Configurable", + "LoggingConfigurable", + "SingletonConfigurable", + "Application", + "MyApp", + "Bar", + "Foo", + ], + ) + + self.assertSequenceEqual( + class_to_names(app._classes_with_config_traits([Application])), ["Application"] + ) + self.assertSequenceEqual( + class_to_names(app._classes_inc_parents([Application])), + ["Configurable", "LoggingConfigurable", "SingletonConfigurable", "Application"], + ) + + self.assertSequenceEqual(class_to_names(app._classes_with_config_traits([Foo])), ["Foo"]) + self.assertSequenceEqual( + class_to_names(app._classes_inc_parents([Bar])), ["Configurable", "Bar"] + ) + + class MyApp2(Application): # no defined `classes` attr + pass + + self.assertSequenceEqual(class_to_names(app._classes_with_config_traits([Foo])), ["Foo"]) + self.assertSequenceEqual( + class_to_names(app._classes_inc_parents([Bar])), ["Configurable", "Bar"] + ) + + def test_config(self): + app = MyApp() + app.parse_command_line( + [ + "--i=10", + "--Foo.j=10", + "--enable=False", + "--log-level=50", + ] + ) + config = app.config + print(config) + self.assertEqual(config.Foo.i, 10) + self.assertEqual(config.Foo.j, 10) + self.assertEqual(config.Bar.enabled, False) + self.assertEqual(config.MyApp.log_level, 50) + + def test_config_seq_args(self): + app = MyApp() + app.parse_command_line( + "--li 1 --li 3 --la 1 --tb AB 2 --Foo.la=ab --Bar.aset S1 --Bar.aset S2 --Bar.aset S1".split() + ) + assert app.extra_args == ["2"] + config = app.config + assert config.Foo.li == [1, 3] + assert config.Foo.la == ["1", "ab"] + assert config.Bar.tb == ("AB",) + self.assertEqual(config.Bar.aset, {"S1", "S2"}) + app.init_foo() + assert app.foo.li == [1, 3] + assert app.foo.la == ["1", "ab"] + app.init_bar() + self.assertEqual(app.bar.aset, {"S1", "S2"}) + assert app.bar.tb == ("AB",) + + def test_config_dict_args(self): + app = MyApp() + app.parse_command_line( + "--Foo.fdict a=1 --Foo.fdict b=b --Foo.fdict c=3 " + "--Bar.bdict k=1 -D=a=b -D 22=33 " + "--Bar.idict k=1 --Bar.idict b=2 --Bar.idict c=3 ".split() + ) + fdict = {"a": "1", "b": "b", "c": "3"} + bdict = {"k": "1", "a": "b", "22": "33"} + idict = {"k": 1, "b": 2, "c": 3} + config = app.config + assert config.Bar.idict == idict + self.assertDictEqual(config.Foo.fdict, fdict) + self.assertDictEqual(config.Bar.bdict, bdict) + app.init_foo() + self.assertEqual(app.foo.fdict, fdict) + app.init_bar() + assert app.bar.idict == idict + self.assertEqual(app.bar.bdict, bdict) + + def test_config_propagation(self): + app = MyApp() + app.parse_command_line(["--i=10", "--Foo.j=10", "--enable=False", "--log-level=50"]) + app.init_foo() + app.init_bar() + self.assertEqual(app.foo.i, 10) + self.assertEqual(app.foo.j, 10) + self.assertEqual(app.bar.enabled, False) + + def test_cli_priority(self): + """Test that loading config files does not override CLI options""" + name = "config.py" + + class TestApp(Application): + value = Unicode().tag(config=True) + config_file_loaded = Bool().tag(config=True) + aliases = {"v": "TestApp.value"} + + app = TestApp() + with TemporaryDirectory() as td: + config_file = pjoin(td, name) + with open(config_file, "w") as f: + f.writelines( + ["c.TestApp.value = 'config file'\n", "c.TestApp.config_file_loaded = True\n"] + ) + + app.parse_command_line(["--v=cli"]) + assert "value" in app.config.TestApp + assert app.config.TestApp.value == "cli" + assert app.value == "cli" + + app.load_config_file(name, path=[td]) + assert app.config_file_loaded + assert app.config.TestApp.value == "cli" + assert app.value == "cli" + + def test_ipython_cli_priority(self): + # this test is almost entirely redundant with above, + # but we can keep it around in case of subtle issues creeping into + # the exact sequence IPython follows. + name = "config.py" + + class TestApp(Application): + value = Unicode().tag(config=True) + config_file_loaded = Bool().tag(config=True) + aliases = {"v": ("TestApp.value", "some help")} + + app = TestApp() + with TemporaryDirectory() as td: + config_file = pjoin(td, name) + with open(config_file, "w") as f: + f.writelines( + ["c.TestApp.value = 'config file'\n", "c.TestApp.config_file_loaded = True\n"] + ) + # follow IPython's config-loading sequence to ensure CLI priority is preserved + app.parse_command_line(["--v=cli"]) + # this is where IPython makes a mistake: + # it assumes app.config will not be modified, + # and storing a reference is storing a copy + cli_config = app.config + assert "value" in app.config.TestApp + assert app.config.TestApp.value == "cli" + assert app.value == "cli" + app.load_config_file(name, path=[td]) + assert app.config_file_loaded + # enforce cl-opts override config file opts: + # this is where IPython makes a mistake: it assumes + # that cl_config is a different object, but it isn't. + app.update_config(cli_config) + assert app.config.TestApp.value == "cli" + assert app.value == "cli" + + def test_cli_allow_none(self): + class App(Application): + aliases = {"opt": "App.opt"} + opt = Unicode(allow_none=True, config=True) + + app = App() + app.parse_command_line(["--opt=None"]) + assert app.opt is None + + def test_flags(self): + app = MyApp() + app.parse_command_line(["--disable"]) + app.init_bar() + self.assertEqual(app.bar.enabled, False) + + app = MyApp() + app.parse_command_line(["-d"]) + app.init_bar() + self.assertEqual(app.bar.enabled, False) + + app = MyApp() + app.parse_command_line(["--enable"]) + app.init_bar() + self.assertEqual(app.bar.enabled, True) + + app = MyApp() + app.parse_command_line(["-e"]) + app.init_bar() + self.assertEqual(app.bar.enabled, True) + + def test_flags_help_msg(self): + app = MyApp() + stdout = io.StringIO() + with contextlib.redirect_stdout(stdout): + app.print_flag_help() + hmsg = stdout.getvalue() + self.assertRegex(hmsg, "(?<!-)-e, --enable\\b") + self.assertRegex(hmsg, "(?<!-)-d, --disable\\b") + self.assertIn("Equivalent to: [--Bar.enabled=True]", hmsg) + self.assertIn("Equivalent to: [--Bar.enabled=False]", hmsg) + + def test_aliases(self): + app = MyApp() + app.parse_command_line(["--i=5", "--j=10"]) + app.init_foo() + self.assertEqual(app.foo.i, 5) + app.init_foo() + self.assertEqual(app.foo.j, 10) + + app = MyApp() + app.parse_command_line(["-i=5", "-j=10"]) + app.init_foo() + self.assertEqual(app.foo.i, 5) + app.init_foo() + self.assertEqual(app.foo.j, 10) + + app = MyApp() + app.parse_command_line(["--fooi=5", "--fooj=10"]) + app.init_foo() + self.assertEqual(app.foo.i, 5) + app.init_foo() + self.assertEqual(app.foo.j, 10) + + def test_aliases_multiple(self): + # Test multiple > 2 aliases for the same argument + class TestMultiAliasApp(Application): + foo = Integer(config=True) + aliases = {("f", "bar", "qux"): "TestMultiAliasApp.foo"} + + app = TestMultiAliasApp() + app.parse_command_line(["-f", "3"]) + self.assertEqual(app.foo, 3) + + app = TestMultiAliasApp() + app.parse_command_line(["--bar", "4"]) + self.assertEqual(app.foo, 4) + + app = TestMultiAliasApp() + app.parse_command_line(["--qux", "5"]) + self.assertEqual(app.foo, 5) + + def test_aliases_help_msg(self): + app = MyApp() + stdout = io.StringIO() + with contextlib.redirect_stdout(stdout): + app.print_alias_help() + hmsg = stdout.getvalue() + self.assertRegex(hmsg, "(?<!-)-i, --fooi\\b") + self.assertRegex(hmsg, "(?<!-)-j, --fooj\\b") + self.assertIn("Equivalent to: [--Foo.i]", hmsg) + self.assertIn("Equivalent to: [--Foo.j]", hmsg) + self.assertIn("Equivalent to: [--Foo.name]", hmsg) + + def test_alias_unrecognized(self): + """Check ability to override handling for unrecognized aliases""" + + class StrictLoader(KVArgParseConfigLoader): + def _handle_unrecognized_alias(self, arg): + self.parser.error("Unrecognized alias: %s" % arg) + + class StrictApplication(Application): + def _create_loader(self, argv, aliases, flags, classes): + return StrictLoader(argv, aliases, flags, classes=classes, log=self.log) + + app = StrictApplication() + app.initialize(["--log-level=20"]) # recognized alias + assert app.log_level == 20 + + app = StrictApplication() + with pytest.raises(SystemExit, match="2"): + app.initialize(["--unrecognized=20"]) + + # Ideally we would use pytest capsys fixture, but fixtures are incompatible + # with unittest.TestCase-style classes :( + # stderr = capsys.readouterr().err + # assert "Unrecognized alias: unrecognized" in stderr + + def test_flag_clobber(self): + """test that setting flags doesn't clobber existing settings""" + app = MyApp() + app.parse_command_line(["--Bar.b=5", "--disable"]) + app.init_bar() + self.assertEqual(app.bar.enabled, False) + self.assertEqual(app.bar.b, 5) + app.parse_command_line(["--enable", "--Bar.b=10"]) + app.init_bar() + self.assertEqual(app.bar.enabled, True) + self.assertEqual(app.bar.b, 10) + + def test_warn_autocorrect(self): + stream = StringIO() + app = MyApp(log_level=logging.INFO) + app.log.handlers = [logging.StreamHandler(stream)] + + cfg = Config() + cfg.MyApp.warn_typo = "WOOOO" + app.config = cfg + + self.assertIn("warn_typo", stream.getvalue()) + self.assertIn("warn_tpyo", stream.getvalue()) + + def test_flatten_flags(self): + cfg = Config() + cfg.MyApp.log_level = logging.WARN + app = MyApp() + app.update_config(cfg) + self.assertEqual(app.log_level, logging.WARN) + self.assertEqual(app.config.MyApp.log_level, logging.WARN) + app.initialize(["--crit"]) + self.assertEqual(app.log_level, logging.CRITICAL) + # this would be app.config.Application.log_level if it failed: + self.assertEqual(app.config.MyApp.log_level, logging.CRITICAL) + + def test_flatten_aliases(self): + cfg = Config() + cfg.MyApp.log_level = logging.WARN + app = MyApp() + app.update_config(cfg) + self.assertEqual(app.log_level, logging.WARN) + self.assertEqual(app.config.MyApp.log_level, logging.WARN) + app.initialize(["--log-level", "CRITICAL"]) + self.assertEqual(app.log_level, logging.CRITICAL) + # this would be app.config.Application.log_level if it failed: + self.assertEqual(app.config.MyApp.log_level, "CRITICAL") + + def test_extra_args(self): + app = MyApp() + app.parse_command_line(["--Bar.b=5", "extra", "args", "--disable"]) + app.init_bar() + self.assertEqual(app.bar.enabled, False) + self.assertEqual(app.bar.b, 5) + self.assertEqual(app.extra_args, ["extra", "args"]) + + app = MyApp() + app.parse_command_line(["--Bar.b=5", "--", "extra", "--disable", "args"]) + app.init_bar() + self.assertEqual(app.bar.enabled, True) + self.assertEqual(app.bar.b, 5) + self.assertEqual(app.extra_args, ["extra", "--disable", "args"]) + + app = MyApp() + app.parse_command_line(["--disable", "--la", "-", "-", "--Bar.b=1", "--", "-", "extra"]) + self.assertEqual(app.extra_args, ["-", "-", "extra"]) + + def test_unicode_argv(self): + app = MyApp() + app.parse_command_line(["ünîcødé"]) + + def test_document_config_option(self): + app = MyApp() + app.document_config_options() + + def test_generate_config_file(self): + app = MyApp() + assert "The integer b." in app.generate_config_file() + + def test_generate_config_file_classes_to_include(self): + class NotInConfig(HasTraits): + from_hidden = Unicode( + "x", + help="""From hidden class + + Details about from_hidden. + """, + ).tag(config=True) + + class NoTraits(Foo, Bar, NotInConfig): + pass + + app = MyApp() + app.classes.append(NoTraits) # type:ignore + + conf_txt = app.generate_config_file() + print(conf_txt) + self.assertIn("The integer b.", conf_txt) + self.assertIn("# Foo(Configurable)", conf_txt) + self.assertNotIn("# Configurable", conf_txt) + self.assertIn("# NoTraits(Foo, Bar)", conf_txt) + + # inherited traits, parent in class list: + self.assertIn("# c.NoTraits.i", conf_txt) + self.assertIn("# c.NoTraits.j", conf_txt) + self.assertIn("# c.NoTraits.n", conf_txt) + self.assertIn("# See also: Foo.j", conf_txt) + self.assertIn("# See also: Bar.b", conf_txt) + self.assertEqual(conf_txt.count("Details about i."), 1) + + # inherited traits, parent not in class list: + self.assertIn("# c.NoTraits.from_hidden", conf_txt) + self.assertNotIn("# See also: NotInConfig.", conf_txt) + self.assertEqual(conf_txt.count("Details about from_hidden."), 1) + self.assertNotIn("NotInConfig", conf_txt) + + def test_multi_file(self): + app = MyApp() + app.log = logging.getLogger() + name = "config.py" + with TemporaryDirectory("_1") as td1: + with open(pjoin(td1, name), "w") as f1: + f1.write("get_config().MyApp.Bar.b = 1") + with TemporaryDirectory("_2") as td2: + with open(pjoin(td2, name), "w") as f2: + f2.write("get_config().MyApp.Bar.b = 2") + app.load_config_file(name, path=[td2, td1]) + app.init_bar() + self.assertEqual(app.bar.b, 2) + app.load_config_file(name, path=[td1, td2]) + app.init_bar() + self.assertEqual(app.bar.b, 1) + + @mark.skipif(not hasattr(TestCase, "assertLogs"), reason="requires TestCase.assertLogs") + def test_log_collisions(self): + app = MyApp() + app.log = logging.getLogger() + app.log.setLevel(logging.INFO) + name = "config" + with TemporaryDirectory("_1") as td: + with open(pjoin(td, name + ".py"), "w") as f: + f.write("get_config().Bar.b = 1") + with open(pjoin(td, name + ".json"), "w") as f: + json.dump({"Bar": {"b": 2}}, f) + with self.assertLogs(app.log, logging.WARNING) as captured: + app.load_config_file(name, path=[td]) + app.init_bar() + assert app.bar.b == 2 + output = "\n".join(captured.output) + assert "Collision" in output + assert "1 ignored, using 2" in output + assert pjoin(td, name + ".py") in output + assert pjoin(td, name + ".json") in output + + @mark.skipif(not hasattr(TestCase, "assertLogs"), reason="requires TestCase.assertLogs") + def test_log_bad_config(self): + app = MyApp() + app.log = logging.getLogger() + name = "config.py" + with TemporaryDirectory() as td: + with open(pjoin(td, name), "w") as f: + f.write("syntax error()") + with self.assertLogs(app.log, logging.ERROR) as captured: + app.load_config_file(name, path=[td]) + output = "\n".join(captured.output) + self.assertIn("SyntaxError", output) + + def test_raise_on_bad_config(self): + app = MyApp() + app.raise_config_file_errors = True + app.log = logging.getLogger() + name = "config.py" + with TemporaryDirectory() as td: + with open(pjoin(td, name), "w") as f: + f.write("syntax error()") + with self.assertRaises(SyntaxError): + app.load_config_file(name, path=[td]) + + def test_subcommands_instantiation(self): + """Try all ways to specify how to create sub-apps.""" + app = Root.instance() + app.parse_command_line(["sub1"]) + + self.assertIsInstance(app.subapp, Sub1) + # Check parent hierarchy. + self.assertIs(app.subapp.parent, app) + + Root.clear_instance() + Sub1.clear_instance() # Otherwise, replaced spuriously and hierarchy check fails. + app = Root.instance() + + app.parse_command_line(["sub1", "sub2"]) + self.assertIsInstance(app.subapp, Sub1) + self.assertIsInstance(app.subapp.subapp, Sub2) + # Check parent hierarchy. + self.assertIs(app.subapp.parent, app) + self.assertIs(app.subapp.subapp.parent, app.subapp) + + Root.clear_instance() + Sub1.clear_instance() # Otherwise, replaced spuriously and hierarchy check fails. + app = Root.instance() + + app.parse_command_line(["sub1", "sub3"]) + self.assertIsInstance(app.subapp, Sub1) + self.assertIsInstance(app.subapp.subapp, Sub3) + self.assertTrue(app.subapp.subapp.flag) # Set by factory. + # Check parent hierarchy. + self.assertIs(app.subapp.parent, app) + self.assertIs(app.subapp.subapp.parent, app.subapp) # Set by factory. + + Root.clear_instance() + Sub1.clear_instance() + + def test_loaded_config_files(self): + app = MyApp() + app.log = logging.getLogger() + name = "config.py" + with TemporaryDirectory("_1") as td1: + config_file = pjoin(td1, name) + with open(config_file, "w") as f: + f.writelines(["c.MyApp.running = True\n"]) + + app.load_config_file(name, path=[td1]) + self.assertEqual(len(app.loaded_config_files), 1) + self.assertEqual(app.loaded_config_files[0], config_file) + + app.start() + self.assertEqual(app.running, True) + + # emulate an app that allows dynamic updates and update config file + with open(config_file, "w") as f: + f.writelines(["c.MyApp.running = False\n"]) + + # reload and verify update, and that loaded_configs was not increased + app.load_config_file(name, path=[td1]) + self.assertEqual(len(app.loaded_config_files), 1) + self.assertEqual(app.running, False) + + # Attempt to update, ensure error... + with self.assertRaises(AttributeError): + app.loaded_config_files = "/foo" # type:ignore + + # ensure it can't be udpated via append + app.loaded_config_files.append("/bar") + self.assertEqual(len(app.loaded_config_files), 1) + + # repeat to ensure no unexpected changes occurred + app.load_config_file(name, path=[td1]) + self.assertEqual(len(app.loaded_config_files), 1) + self.assertEqual(app.running, False) + + +@mark.skip +def test_cli_multi_scalar(caplog): + class App(Application): + aliases = {"opt": "App.opt"} + opt = Unicode(config=True) + + app = App(log=logging.getLogger()) + with pytest.raises(SystemExit): + app.parse_command_line(["--opt", "1", "--opt", "2"]) + record = caplog.get_records("call")[-1] + message = record.message + + assert "Error loading argument" in message + assert "App.opt=['1', '2']" in message + assert "opt only accepts one value" in message + assert record.levelno == logging.CRITICAL + + +class Root(Application): + subcommands = { + "sub1": ("__tests__.config.test_application.Sub1", "import string"), + } + + +class Sub3(Application): + flag = Bool(False) + + +class Sub2(Application): + pass + + +class Sub1(Application): + subcommands: dict = { # type:ignore + "sub2": (Sub2, "Application class"), + "sub3": (lambda root: Sub3(parent=root, flag=True), "factory"), + } + + +class DeprecatedApp(Application): + override_called = False + parent_called = False + + def _config_changed(self, name, old, new): + self.override_called = True + + def _capture(*args): + self.parent_called = True + + with mock.patch.object(self.log, "debug", _capture): + super()._config_changed(name, old, new) + + +def test_deprecated_notifier(): + app = DeprecatedApp() + assert not app.override_called + assert not app.parent_called + app.config = Config({"A": {"b": "c"}}) + assert app.override_called + assert app.parent_called + + +def test_help_output(): + check_help_output(__name__) + + +def test_help_all_output(): + check_help_all_output(__name__) + + +def test_show_config_cli(): + out, err, ec = get_output_error_code([sys.executable, "-m", __name__, "--show-config"]) + assert ec == 0 + assert "show_config" not in out + + +def test_show_config_json_cli(): + out, err, ec = get_output_error_code([sys.executable, "-m", __name__, "--show-config-json"]) + assert ec == 0 + assert "show_config" not in out + + +def test_show_config(capsys): + cfg = Config() + cfg.MyApp.i = 5 + # don't show empty + cfg.OtherApp + + app = MyApp(config=cfg, show_config=True) + app.start() + out, err = capsys.readouterr() + assert "MyApp" in out + assert "i = 5" in out + assert "OtherApp" not in out + + +def test_show_config_json(capsys): + cfg = Config() + cfg.MyApp.i = 5 + cfg.OtherApp + + app = MyApp(config=cfg, show_config_json=True) + app.start() + out, err = capsys.readouterr() + displayed = json.loads(out) + assert Config(displayed) == cfg + + +def test_deep_alias(): + from traitlets import Int + from traitlets.config import Application, Configurable + + class Foo(Configurable): + val = Int(default_value=5).tag(config=True) + + class Bar(Configurable): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.foo = Foo(parent=self) + + class TestApp(Application): + name = "test" + + aliases = {"val": "Bar.Foo.val"} + classes = [Foo, Bar] + + def initialize(self, *args, **kwargs): + super().initialize(*args, **kwargs) + self.bar = Bar(parent=self) + + app = TestApp() + app.initialize(["--val=10"]) + assert app.bar.foo.val == 10 + assert len(list(app.emit_alias_help())) > 0 + + +def test_logging_config(tmp_path, capsys): + """We should be able to configure additional log handlers.""" + log_file = tmp_path / "log_file" + app = Application( + logging_config={ + "version": 1, + "handlers": { + "file": { + "class": "logging.FileHandler", + "level": "DEBUG", + "filename": str(log_file), + }, + }, + "loggers": { + "Application": { + "level": "DEBUG", + "handlers": ["console", "file"], + }, + }, + } + ) + # the default "console" handler + our new "file" handler + assert len(app.log.handlers) == 2 + + # log a couple of messages + app.log.info("info") + app.log.warning("warn") + + # test that log messages get written to the file + with open(log_file) as log_handle: + assert log_handle.read() == "info\nwarn\n" + + # test that log messages get written to stderr (default console handler) + assert capsys.readouterr().err == "[Application] WARNING | warn\n" + + +def test_get_default_logging_config_pythonw(monkeypatch): + """Ensure logging is correctly disabled for pythonw usage.""" + monkeypatch.setattr("traitlets.config.application.IS_PYTHONW", True) + config = Application().get_default_logging_config() + assert "handlers" not in config + assert "loggers" not in config + + monkeypatch.setattr("traitlets.config.application.IS_PYTHONW", False) + config = Application().get_default_logging_config() + assert "handlers" in config + assert "loggers" in config + + +@pytest.fixture +def caplogconfig(monkeypatch): + """Capture logging config events for DictConfigurator objects. + + This suppresses the event (so the configuration doesn't happen). + + Returns a list of (args, kwargs). + """ + calls = [] + + def _configure(*args, **kwargs): + nonlocal calls + calls.append((args, kwargs)) + + monkeypatch.setattr( + "logging.config.DictConfigurator.configure", + _configure, + ) + + return calls + + +@pytest.mark.skipif(sys.implementation.name == "pypy", reason="Test does not work on pypy") +def test_logging_teardown_on_error(capsys, caplogconfig): + """Ensure we don't try to open logs in order to close them (See #722). + + If you try to configure logging handlers whilst Python is shutting down + you may get traceback. + """ + # create and destroy an app (without configuring logging) + # (ensure that the logs were not configured) + app = Application() + del app + assert len(caplogconfig) == 0 # logging was not configured + out, err = capsys.readouterr() + assert "Traceback" not in err + + # ensure that the logs would have been configured otherwise + # (to prevent this test from yielding false-negatives) + app = Application() + app._logging_configured = True # make it look like logging was configured + del app + assert len(caplogconfig) == 1 # logging was configured + + +if __name__ == "__main__": + # for test_help_output: + MyApp.launch_instance() diff --git a/contrib/python/traitlets/py3/tests/config/test_argcomplete.py b/contrib/python/traitlets/py3/tests/config/test_argcomplete.py new file mode 100644 index 00000000000..52ed6d2bb2c --- /dev/null +++ b/contrib/python/traitlets/py3/tests/config/test_argcomplete.py @@ -0,0 +1,219 @@ +""" +Tests for argcomplete handling by traitlets.config.application.Application +""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +import io +import os +import typing as t + +import pytest + +argcomplete = pytest.importorskip("argcomplete") + +from traitlets import Unicode +from traitlets.config.application import Application +from traitlets.config.configurable import Configurable +from traitlets.config.loader import KVArgParseConfigLoader + + +class ArgcompleteApp(Application): + """Override loader to pass through kwargs for argcomplete testing""" + + argcomplete_kwargs: t.Dict[str, t.Any] + + def __init__(self, *args, **kwargs): + # For subcommands, inherit argcomplete_kwargs from parent app + parent = kwargs.get("parent") + super().__init__(*args, **kwargs) + if parent: + argcomplete_kwargs = getattr(parent, "argcomplete_kwargs", None) + if argcomplete_kwargs: + self.argcomplete_kwargs = argcomplete_kwargs + + def _create_loader(self, argv, aliases, flags, classes): + loader = KVArgParseConfigLoader( + argv, aliases, flags, classes=classes, log=self.log, subcommands=self.subcommands + ) + loader._argcomplete_kwargs = self.argcomplete_kwargs # type: ignore[attr-defined] + return loader + + +class SubApp1(ArgcompleteApp): + pass + + +class SubApp2(ArgcompleteApp): + @classmethod + def get_subapp_instance(cls, app: Application) -> Application: + app.clear_instance() # since Application is singleton, need to clear main app + return cls.instance(parent=app) # type: ignore[no-any-return] + + +class MainApp(ArgcompleteApp): + subcommands = { + "subapp1": (SubApp1, "First subapp"), + "subapp2": (SubApp2.get_subapp_instance, "Second subapp"), + } + + +class CustomError(Exception): + """Helper for exit hook for testing argcomplete""" + + @classmethod + def exit(cls, code): + raise cls(str(code)) + + +class TestArgcomplete: + IFS = "\013" + COMP_WORDBREAKS = " \t\n\"'><=;|&(:" + + @pytest.fixture + def argcomplete_on(self, mocker): + """Mostly borrowed from argcomplete's unit test fixtures + + Set up environment variables to mimic those passed by argcomplete + """ + _old_environ = os.environ + os.environ = os.environ.copy() # type: ignore[assignment] + os.environ["_ARGCOMPLETE"] = "1" + os.environ["_ARC_DEBUG"] = "yes" + os.environ["IFS"] = self.IFS + os.environ["_ARGCOMPLETE_COMP_WORDBREAKS"] = self.COMP_WORDBREAKS + + # argcomplete==2.0.0 always calls fdopen(9, "w") to open a debug stream, + # however this could conflict with file descriptors used by pytest + # and lead to obscure errors. Since we are not looking at debug stream + # in these tests, just mock this fdopen call out. + mocker.patch("os.fdopen") + try: + yield + finally: + os.environ = _old_environ + + def run_completer( + self, + app: ArgcompleteApp, + command: str, + point: t.Union[str, int, None] = None, + **kwargs: t.Any, + ) -> t.List[str]: + """Mostly borrowed from argcomplete's unit tests + + Modified to take an application instead of an ArgumentParser + + Command is the current command being completed and point is the index + into the command where the completion is triggered. + """ + if point is None: + point = str(len(command)) + # Flushing tempfile was leading to CI failures with Bad file descriptor, not sure why. + # Fortunately we can just write to a StringIO instead. + # print("Writing completions to temp file with mode=", write_mode) + # from tempfile import TemporaryFile + # with TemporaryFile(mode=write_mode) as t: + strio = io.StringIO() + os.environ["COMP_LINE"] = command + os.environ["COMP_POINT"] = str(point) + + with pytest.raises(CustomError) as cm: + app.argcomplete_kwargs = dict( + output_stream=strio, exit_method=CustomError.exit, **kwargs + ) + app.initialize() + + if str(cm.value) != "0": + raise RuntimeError(f"Unexpected exit code {cm.value}") + out = strio.getvalue() + return out.split(self.IFS) + + def test_complete_simple_app(self, argcomplete_on): + app = ArgcompleteApp() + expected = [ + '--help', + '--debug', + '--show-config', + '--show-config-json', + '--log-level', + '--Application.', + '--ArgcompleteApp.', + ] + assert set(self.run_completer(app, "app --")) == set(expected) + + # completing class traits + assert set(self.run_completer(app, "app --App")) > { + '--Application.show_config', + '--Application.log_level', + '--Application.log_format', + } + + def test_complete_custom_completers(self, argcomplete_on): + app = ArgcompleteApp() + # test pre-defined completers for Bool/Enum + assert set(self.run_completer(app, "app --Application.log_level=")) > {"DEBUG", "INFO"} + assert set(self.run_completer(app, "app --ArgcompleteApp.show_config ")) == { + "0", + "1", + "true", + "false", + } + + # test custom completer and mid-command completions + class CustomCls(Configurable): + val = Unicode().tag( + config=True, argcompleter=argcomplete.completers.ChoicesCompleter(["foo", "bar"]) + ) + + class CustomApp(ArgcompleteApp): + classes = [CustomCls] + aliases = {("v", "val"): "CustomCls.val"} + + app = CustomApp() + assert self.run_completer(app, "app --val ") == ["foo", "bar"] + assert self.run_completer(app, "app --val=") == ["foo", "bar"] + assert self.run_completer(app, "app -v ") == ["foo", "bar"] + assert self.run_completer(app, "app -v=") == ["foo", "bar"] + assert self.run_completer(app, "app --CustomCls.val ") == ["foo", "bar"] + assert self.run_completer(app, "app --CustomCls.val=") == ["foo", "bar"] + completions = self.run_completer(app, "app --val= abc xyz", point=10) + # fixed in argcomplete >= 2.0 to return latter below + assert completions == ["--val=foo", "--val=bar"] or completions == ["foo", "bar"] + assert self.run_completer(app, "app --val --log-level=", point=10) == ["foo", "bar"] + + def test_complete_subcommands(self, argcomplete_on): + app = MainApp() + assert set(self.run_completer(app, "app ")) >= {"subapp1", "subapp2"} + assert set(self.run_completer(app, "app sub")) == {"subapp1", "subapp2"} + assert set(self.run_completer(app, "app subapp1")) == {"subapp1"} + + def test_complete_subcommands_subapp1(self, argcomplete_on): + # subcommand handling modifies _ARGCOMPLETE env var global state, so + # only can test one completion per unit test + app = MainApp() + try: + assert set(self.run_completer(app, "app subapp1 --Sub")) > { + '--SubApp1.show_config', + '--SubApp1.log_level', + '--SubApp1.log_format', + } + finally: + SubApp1.clear_instance() + + def test_complete_subcommands_subapp2(self, argcomplete_on): + app = MainApp() + try: + assert set(self.run_completer(app, "app subapp2 --")) > { + '--Application.', + '--SubApp2.', + } + finally: + SubApp2.clear_instance() + + def test_complete_subcommands_main(self, argcomplete_on): + app = MainApp() + completions = set(self.run_completer(app, "app --")) + assert completions > {'--Application.', '--MainApp.'} + assert "--SubApp1." not in completions and "--SubApp2." not in completions diff --git a/contrib/python/traitlets/py3/tests/config/test_configurable.py b/contrib/python/traitlets/py3/tests/config/test_configurable.py new file mode 100644 index 00000000000..f6499ea29d1 --- /dev/null +++ b/contrib/python/traitlets/py3/tests/config/test_configurable.py @@ -0,0 +1,711 @@ +"""Tests for traitlets.config.configurable""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +import logging +from unittest import TestCase + +from pytest import mark + +from .._warnings import expected_warnings +from traitlets.config.application import Application +from traitlets.config.configurable import Configurable, LoggingConfigurable, SingletonConfigurable +from traitlets.config.loader import Config +from traitlets.log import get_logger +from traitlets.traitlets import ( + CaselessStrEnum, + Dict, + Enum, + Float, + FuzzyEnum, + Integer, + List, + Set, + Unicode, + validate, +) +from traitlets.utils.warnings import _deprecations_shown + + +class MyConfigurable(Configurable): + a = Integer(1, help="The integer a.").tag(config=True) + b = Float(1.0, help="The integer b.").tag(config=True) + c = Unicode("no config") + + +mc_help = """MyConfigurable(Configurable) options +------------------------------------ +--MyConfigurable.a=<Integer> + The integer a. + Default: 1 +--MyConfigurable.b=<Float> + The integer b. + Default: 1.0""" + +mc_help_inst = """MyConfigurable(Configurable) options +------------------------------------ +--MyConfigurable.a=<Integer> + The integer a. + Current: 5 +--MyConfigurable.b=<Float> + The integer b. + Current: 4.0""" + +# On Python 3, the Integer trait is a synonym for Int +mc_help = mc_help.replace("<Integer>", "<Int>") +mc_help_inst = mc_help_inst.replace("<Integer>", "<Int>") + + +class Foo(Configurable): + a = Integer(0, help="The integer a.").tag(config=True) + b = Unicode("nope").tag(config=True) + flist = List([]).tag(config=True) + fdict = Dict().tag(config=True) + + +class Bar(Foo): + b = Unicode("gotit", help="The string b.").tag(config=False) + c = Float(help="The string c.").tag(config=True) + bset = Set([]).tag(config=True, multiplicity="+") + bset_values = Set([2, 1, 5]).tag(config=True, multiplicity="+") + bdict = Dict().tag(config=True, multiplicity="+") + bdict_values = Dict({1: "a", "0": "b", 5: "c"}).tag(config=True, multiplicity="+") + + +foo_help = """Foo(Configurable) options +------------------------- +--Foo.a=<Int> + The integer a. + Default: 0 +--Foo.b=<Unicode> + Default: 'nope' +--Foo.fdict=<key-1>=<value-1>... + Default: {} +--Foo.flist=<list-item-1>... + Default: []""" + +bar_help = """Bar(Foo) options +---------------- +--Bar.a=<Int> + The integer a. + Default: 0 +--Bar.bdict <key-1>=<value-1>... + Default: {} +--Bar.bdict_values <key-1>=<value-1>... + Default: {1: 'a', '0': 'b', 5: 'c'} +--Bar.bset <set-item-1>... + Default: set() +--Bar.bset_values <set-item-1>... + Default: {1, 2, 5} +--Bar.c=<Float> + The string c. + Default: 0.0 +--Bar.fdict=<key-1>=<value-1>... + Default: {} +--Bar.flist=<list-item-1>... + Default: []""" + + +class TestConfigurable(TestCase): + def test_default(self): + c1 = Configurable() + c2 = Configurable(config=c1.config) + c3 = Configurable(config=c2.config) + self.assertEqual(c1.config, c2.config) + self.assertEqual(c2.config, c3.config) + + def test_custom(self): + config = Config() + config.foo = "foo" + config.bar = "bar" + c1 = Configurable(config=config) + c2 = Configurable(config=c1.config) + c3 = Configurable(config=c2.config) + self.assertEqual(c1.config, config) + self.assertEqual(c2.config, config) + self.assertEqual(c3.config, config) + # Test that copies are not made + self.assertTrue(c1.config is config) + self.assertTrue(c2.config is config) + self.assertTrue(c3.config is config) + self.assertTrue(c1.config is c2.config) + self.assertTrue(c2.config is c3.config) + + def test_inheritance(self): + config = Config() + config.MyConfigurable.a = 2 + config.MyConfigurable.b = 2.0 + c1 = MyConfigurable(config=config) + c2 = MyConfigurable(config=c1.config) + self.assertEqual(c1.a, config.MyConfigurable.a) + self.assertEqual(c1.b, config.MyConfigurable.b) + self.assertEqual(c2.a, config.MyConfigurable.a) + self.assertEqual(c2.b, config.MyConfigurable.b) + + def test_parent(self): + config = Config() + config.Foo.a = 10 + config.Foo.b = "wow" + config.Bar.b = "later" + config.Bar.c = 100.0 + f = Foo(config=config) + with expected_warnings(["`b` not recognized"]): + b = Bar(config=f.config) + self.assertEqual(f.a, 10) + self.assertEqual(f.b, "wow") + self.assertEqual(b.b, "gotit") + self.assertEqual(b.c, 100.0) + + def test_override1(self): + config = Config() + config.MyConfigurable.a = 2 + config.MyConfigurable.b = 2.0 + c = MyConfigurable(a=3, config=config) + self.assertEqual(c.a, 3) + self.assertEqual(c.b, config.MyConfigurable.b) + self.assertEqual(c.c, "no config") + + def test_override2(self): + config = Config() + config.Foo.a = 1 + config.Bar.b = "or" # Up above b is config=False, so this won't do it. + config.Bar.c = 10.0 + with expected_warnings(["`b` not recognized"]): + c = Bar(config=config) + self.assertEqual(c.a, config.Foo.a) + self.assertEqual(c.b, "gotit") + self.assertEqual(c.c, config.Bar.c) + with expected_warnings(["`b` not recognized"]): + c = Bar(a=2, b="and", c=20.0, config=config) + self.assertEqual(c.a, 2) + self.assertEqual(c.b, "and") + self.assertEqual(c.c, 20.0) + + def test_help(self): + self.assertEqual(MyConfigurable.class_get_help(), mc_help) + self.assertEqual(Foo.class_get_help(), foo_help) + self.assertEqual(Bar.class_get_help(), bar_help) + + def test_help_inst(self): + inst = MyConfigurable(a=5, b=4) + self.assertEqual(MyConfigurable.class_get_help(inst), mc_help_inst) + + def test_generated_config_enum_comments(self): + class MyConf(Configurable): + an_enum = Enum("Choice1 choice2".split(), help="Many choices.").tag(config=True) + + help_str = "Many choices." + enum_choices_str = "Choices: any of ['Choice1', 'choice2']" + rst_choices_str = "MyConf.an_enum : any of ``'Choice1'``|``'choice2'``" + or_none_str = "or None" + + cls_help = MyConf.class_get_help() + + self.assertIn(help_str, cls_help) + self.assertIn(enum_choices_str, cls_help) + self.assertNotIn(or_none_str, cls_help) + + cls_cfg = MyConf.class_config_section() + + self.assertIn(help_str, cls_cfg) + self.assertIn(enum_choices_str, cls_cfg) + self.assertNotIn(or_none_str, cls_help) + # Check order of Help-msg <--> Choices sections + self.assertGreater(cls_cfg.index(enum_choices_str), cls_cfg.index(help_str)) + + rst_help = MyConf.class_config_rst_doc() + + self.assertIn(help_str, rst_help) + self.assertIn(rst_choices_str, rst_help) + self.assertNotIn(or_none_str, rst_help) + + class MyConf2(Configurable): + an_enum = Enum( + "Choice1 choice2".split(), + allow_none=True, + default_value="choice2", + help="Many choices.", + ).tag(config=True) + + defaults_str = "Default: 'choice2'" + + cls2_msg = MyConf2.class_get_help() + + self.assertIn(help_str, cls2_msg) + self.assertIn(enum_choices_str, cls2_msg) + self.assertIn(or_none_str, cls2_msg) + self.assertIn(defaults_str, cls2_msg) + # Check order of Default <--> Choices sections + self.assertGreater(cls2_msg.index(defaults_str), cls2_msg.index(enum_choices_str)) + + cls2_cfg = MyConf2.class_config_section() + + self.assertIn(help_str, cls2_cfg) + self.assertIn(enum_choices_str, cls2_cfg) + self.assertIn(or_none_str, cls2_cfg) + self.assertIn(defaults_str, cls2_cfg) + # Check order of Default <--> Choices sections + self.assertGreater(cls2_cfg.index(defaults_str), cls2_cfg.index(enum_choices_str)) + + def test_generated_config_strenum_comments(self): + help_str = "Many choices." + defaults_str = "Default: 'choice2'" + or_none_str = "or None" + + class MyConf3(Configurable): + an_enum = CaselessStrEnum( + "Choice1 choice2".split(), + allow_none=True, + default_value="choice2", + help="Many choices.", + ).tag(config=True) + + enum_choices_str = "Choices: any of ['Choice1', 'choice2'] (case-insensitive)" + + cls3_msg = MyConf3.class_get_help() + + self.assertIn(help_str, cls3_msg) + self.assertIn(enum_choices_str, cls3_msg) + self.assertIn(or_none_str, cls3_msg) + self.assertIn(defaults_str, cls3_msg) + # Check order of Default <--> Choices sections + self.assertGreater(cls3_msg.index(defaults_str), cls3_msg.index(enum_choices_str)) + + cls3_cfg = MyConf3.class_config_section() + + self.assertIn(help_str, cls3_cfg) + self.assertIn(enum_choices_str, cls3_cfg) + self.assertIn(or_none_str, cls3_cfg) + self.assertIn(defaults_str, cls3_cfg) + # Check order of Default <--> Choices sections + self.assertGreater(cls3_cfg.index(defaults_str), cls3_cfg.index(enum_choices_str)) + + class MyConf4(Configurable): + an_enum = FuzzyEnum( + "Choice1 choice2".split(), + allow_none=True, + default_value="choice2", + help="Many choices.", + ).tag(config=True) + + enum_choices_str = "Choices: any case-insensitive prefix of ['Choice1', 'choice2']" + + cls4_msg = MyConf4.class_get_help() + + self.assertIn(help_str, cls4_msg) + self.assertIn(enum_choices_str, cls4_msg) + self.assertIn(or_none_str, cls4_msg) + self.assertIn(defaults_str, cls4_msg) + # Check order of Default <--> Choices sections + self.assertGreater(cls4_msg.index(defaults_str), cls4_msg.index(enum_choices_str)) + + cls4_cfg = MyConf4.class_config_section() + + self.assertIn(help_str, cls4_cfg) + self.assertIn(enum_choices_str, cls4_cfg) + self.assertIn(or_none_str, cls4_cfg) + self.assertIn(defaults_str, cls4_cfg) + # Check order of Default <--> Choices sections + self.assertGreater(cls4_cfg.index(defaults_str), cls4_cfg.index(enum_choices_str)) + + +class TestSingletonConfigurable(TestCase): + def test_instance(self): + class Foo(SingletonConfigurable): + pass + + self.assertEqual(Foo.initialized(), False) + foo = Foo.instance() + self.assertEqual(Foo.initialized(), True) + self.assertEqual(foo, Foo.instance()) + self.assertEqual(SingletonConfigurable._instance, None) + + def test_inheritance(self): + class Bar(SingletonConfigurable): + pass + + class Bam(Bar): + pass + + self.assertEqual(Bar.initialized(), False) + self.assertEqual(Bam.initialized(), False) + bam = Bam.instance() + self.assertEqual(Bar.initialized(), True) + self.assertEqual(Bam.initialized(), True) + self.assertEqual(bam, Bam._instance) + self.assertEqual(bam, Bar._instance) + self.assertEqual(SingletonConfigurable._instance, None) + + +class TestLoggingConfigurable(TestCase): + def test_parent_logger(self): + class Parent(LoggingConfigurable): + pass + + class Child(LoggingConfigurable): + pass + + log = get_logger().getChild("TestLoggingConfigurable") + + parent = Parent(log=log) + child = Child(parent=parent) + self.assertEqual(parent.log, log) + self.assertEqual(child.log, log) + + parent = Parent() + child = Child(parent=parent, log=log) + self.assertEqual(parent.log, get_logger()) + self.assertEqual(child.log, log) + + def test_parent_not_logging_configurable(self): + class Parent(Configurable): + pass + + class Child(LoggingConfigurable): + pass + + parent = Parent() + child = Child(parent=parent) + self.assertEqual(child.log, get_logger()) + + +class MyParent(Configurable): + pass + + +class MyParent2(MyParent): + pass + + +class TestParentConfigurable(TestCase): + def test_parent_config(self): + cfg = Config( + { + "MyParent": { + "MyConfigurable": { + "b": 2.0, + } + } + } + ) + parent = MyParent(config=cfg) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent.MyConfigurable.b) + + def test_parent_inheritance(self): + cfg = Config( + { + "MyParent": { + "MyConfigurable": { + "b": 2.0, + } + } + } + ) + parent = MyParent2(config=cfg) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent.MyConfigurable.b) + + def test_multi_parent(self): + cfg = Config( + { + "MyParent2": { + "MyParent": { + "MyConfigurable": { + "b": 2.0, + } + }, + # this one shouldn't count + "MyConfigurable": { + "b": 3.0, + }, + } + } + ) + parent2 = MyParent2(config=cfg) + parent = MyParent(parent=parent2) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent2.MyParent.MyConfigurable.b) + + def test_parent_priority(self): + cfg = Config( + { + "MyConfigurable": { + "b": 2.0, + }, + "MyParent": { + "MyConfigurable": { + "b": 3.0, + } + }, + "MyParent2": { + "MyConfigurable": { + "b": 4.0, + } + }, + } + ) + parent = MyParent2(config=cfg) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent2.MyConfigurable.b) + + def test_multi_parent_priority(self): + cfg = Config( + { + "MyConfigurable": { + "b": 2.0, + }, + "MyParent": { + "MyConfigurable": { + "b": 3.0, + }, + }, + "MyParent2": { + "MyConfigurable": { + "b": 4.0, + }, + "MyParent": { + "MyConfigurable": { + "b": 5.0, + }, + }, + }, + } + ) + parent2 = MyParent2(config=cfg) + parent = MyParent2(parent=parent2) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent2.MyParent.MyConfigurable.b) + + +class Containers(Configurable): + lis = List().tag(config=True) + + def _lis_default(self): + return [-1] + + s = Set().tag(config=True) + + def _s_default(self): + return {"a"} + + d = Dict().tag(config=True) + + def _d_default(self): + return {"a": "b"} + + +class TestConfigContainers(TestCase): + def test_extend(self): + c = Config() + c.Containers.lis.extend(list(range(5))) + obj = Containers(config=c) + self.assertEqual(obj.lis, list(range(-1, 5))) + + def test_insert(self): + c = Config() + c.Containers.lis.insert(0, "a") + c.Containers.lis.insert(1, "b") + obj = Containers(config=c) + self.assertEqual(obj.lis, ["a", "b", -1]) + + def test_prepend(self): + c = Config() + c.Containers.lis.prepend([1, 2]) + c.Containers.lis.prepend([2, 3]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [2, 3, 1, 2, -1]) + + def test_prepend_extend(self): + c = Config() + c.Containers.lis.prepend([1, 2]) + c.Containers.lis.extend([2, 3]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [1, 2, -1, 2, 3]) + + def test_append_extend(self): + c = Config() + c.Containers.lis.append([1, 2]) + c.Containers.lis.extend([2, 3]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [-1, [1, 2], 2, 3]) + + def test_extend_append(self): + c = Config() + c.Containers.lis.extend([2, 3]) + c.Containers.lis.append([1, 2]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [-1, 2, 3, [1, 2]]) + + def test_insert_extend(self): + c = Config() + c.Containers.lis.insert(0, 1) + c.Containers.lis.extend([2, 3]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [1, -1, 2, 3]) + + def test_set_update(self): + c = Config() + c.Containers.s.update({0, 1, 2}) + c.Containers.s.update({3}) + obj = Containers(config=c) + self.assertEqual(obj.s, {"a", 0, 1, 2, 3}) + + def test_dict_update(self): + c = Config() + c.Containers.d.update({"c": "d"}) + c.Containers.d.update({"e": "f"}) + obj = Containers(config=c) + self.assertEqual(obj.d, {"a": "b", "c": "d", "e": "f"}) + + def test_update_twice(self): + c = Config() + c.MyConfigurable.a = 5 + m = MyConfigurable(config=c) + self.assertEqual(m.a, 5) + + c2 = Config() + c2.MyConfigurable.a = 10 + m.update_config(c2) + self.assertEqual(m.a, 10) + + c2.MyConfigurable.a = 15 + m.update_config(c2) + self.assertEqual(m.a, 15) + + def test_update_self(self): + """update_config with same config object still triggers config_changed""" + c = Config() + c.MyConfigurable.a = 5 + m = MyConfigurable(config=c) + self.assertEqual(m.a, 5) + c.MyConfigurable.a = 10 + m.update_config(c) + self.assertEqual(m.a, 10) + + def test_config_default(self): + class SomeSingleton(SingletonConfigurable): + pass + + class DefaultConfigurable(Configurable): + a = Integer().tag(config=True) + + def _config_default(self): + if SomeSingleton.initialized(): + return SomeSingleton.instance().config + return Config() + + c = Config() + c.DefaultConfigurable.a = 5 + + d1 = DefaultConfigurable() + self.assertEqual(d1.a, 0) + + single = SomeSingleton.instance(config=c) + + d2 = DefaultConfigurable() + self.assertIs(d2.config, single.config) + self.assertEqual(d2.a, 5) + + def test_config_default_deprecated(self): + """Make sure configurables work even with the deprecations in traitlets""" + + class SomeSingleton(SingletonConfigurable): + pass + + # reset deprecation limiter + _deprecations_shown.clear() + with expected_warnings([]): + + class DefaultConfigurable(Configurable): + a = Integer(config=True) + + def _config_default(self): + if SomeSingleton.initialized(): + return SomeSingleton.instance().config + return Config() + + c = Config() + c.DefaultConfigurable.a = 5 + + d1 = DefaultConfigurable() + self.assertEqual(d1.a, 0) + + single = SomeSingleton.instance(config=c) + + d2 = DefaultConfigurable() + self.assertIs(d2.config, single.config) + self.assertEqual(d2.a, 5) + + def test_kwarg_config_priority(self): + # a, c set in kwargs + # a, b set in config + # verify that: + # - kwargs are set before config + # - kwargs have priority over config + class A(Configurable): + a = Unicode("default", config=True) + b = Unicode("default", config=True) + c = Unicode("default", config=True) + c_during_config = Unicode("never") + + @validate("b") + def _record_c(self, proposal): + # setting b from config records c's value at the time + self.c_during_config = self.c + return proposal.value + + cfg = Config() + cfg.A.a = "a-config" + cfg.A.b = "b-config" + obj = A(a="a-kwarg", c="c-kwarg", config=cfg) + assert obj.a == "a-kwarg" + assert obj.b == "b-config" + assert obj.c == "c-kwarg" + assert obj.c_during_config == "c-kwarg" + + +class TestLogger(TestCase): + class A(LoggingConfigurable): + foo = Integer(config=True) + bar = Integer(config=True) + baz = Integer(config=True) + + @mark.skipif(not hasattr(TestCase, "assertLogs"), reason="requires TestCase.assertLogs") + def test_warn_match(self): + logger = logging.getLogger("test_warn_match") + cfg = Config({"A": {"bat": 5}}) + with self.assertLogs(logger, logging.WARNING) as captured: + TestLogger.A(config=cfg, log=logger) + + output = "\n".join(captured.output) + self.assertIn("Did you mean one of: `bar, baz`?", output) + self.assertIn("Config option `bat` not recognized by `A`.", output) + + cfg = Config({"A": {"fool": 5}}) + with self.assertLogs(logger, logging.WARNING) as captured: + TestLogger.A(config=cfg, log=logger) + + output = "\n".join(captured.output) + self.assertIn("Config option `fool` not recognized by `A`.", output) + self.assertIn("Did you mean `foo`?", output) + + cfg = Config({"A": {"totally_wrong": 5}}) + with self.assertLogs(logger, logging.WARNING) as captured: + TestLogger.A(config=cfg, log=logger) + + output = "\n".join(captured.output) + self.assertIn("Config option `totally_wrong` not recognized by `A`.", output) + self.assertNotIn("Did you mean", output) + + +def test_logger_adapter(caplog, capsys): + logger = logging.getLogger("Application") + adapter = logging.LoggerAdapter(logger, {"key": "adapted"}) + + app = Application(log=adapter, log_level=logging.INFO) + app.log_format = "%(key)s %(message)s" + app.log.info("test message") + + assert "adapted test message" in capsys.readouterr().err diff --git a/contrib/python/traitlets/py3/tests/config/test_loader.py b/contrib/python/traitlets/py3/tests/config/test_loader.py new file mode 100644 index 00000000000..3a1f96120f7 --- /dev/null +++ b/contrib/python/traitlets/py3/tests/config/test_loader.py @@ -0,0 +1,753 @@ +"""Tests for traitlets.config.loader""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +import copy +import os +import pickle +from itertools import chain +from tempfile import mkstemp +from unittest import TestCase + +import pytest + +from traitlets import Dict, Integer, List, Tuple, Unicode +from traitlets.config import Configurable +from traitlets.config.loader import ( + ArgParseConfigLoader, + Config, + JSONFileConfigLoader, + KeyValueConfigLoader, + KVArgParseConfigLoader, + LazyConfigValue, + PyFileConfigLoader, +) + +pyfile = """ +c = get_config() +c.a=10 +c.b=20 +c.Foo.Bar.value=10 +c.Foo.Bam.value=list(range(10)) +c.D.C.value='hi there' +""" + +json1file = """ +{ + "version": 1, + "a": 10, + "b": 20, + "Foo": { + "Bam": { + "value": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 ] + }, + "Bar": { + "value": 10 + } + }, + "D": { + "C": { + "value": "hi there" + } + } +} +""" + +# should not load +json2file = """ +{ + "version": 2 +} +""" + +import logging + +log = logging.getLogger("devnull") +log.setLevel(0) + + +class TestFileCL(TestCase): + def _check_conf(self, config): + self.assertEqual(config.a, 10) + self.assertEqual(config.b, 20) + self.assertEqual(config.Foo.Bar.value, 10) + self.assertEqual(config.Foo.Bam.value, list(range(10))) + self.assertEqual(config.D.C.value, "hi there") + + def test_python(self): + fd, fname = mkstemp(".py", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write(pyfile) + f.close() + # Unlink the file + cl = PyFileConfigLoader(fname, log=log) + config = cl.load_config() + self._check_conf(config) + + def test_json(self): + fd, fname = mkstemp(".json", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write(json1file) + f.close() + # Unlink the file + cl = JSONFileConfigLoader(fname, log=log) + config = cl.load_config() + self._check_conf(config) + + def test_context_manager(self): + fd, fname = mkstemp(".json", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write("{}") + f.close() + + cl = JSONFileConfigLoader(fname, log=log) + + value = "context_manager" + + with cl as c: + c.MyAttr.value = value + + self.assertEqual(cl.config.MyAttr.value, value) + + # check that another loader does see the change + _ = JSONFileConfigLoader(fname, log=log) + self.assertEqual(cl.config.MyAttr.value, value) + + def test_json_context_bad_write(self): + fd, fname = mkstemp(".json", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write("{}") + f.close() + + with JSONFileConfigLoader(fname, log=log) as config: + config.A.b = 1 + + with self.assertRaises(TypeError): + with JSONFileConfigLoader(fname, log=log) as config: + config.A.cant_json = lambda x: x + + loader = JSONFileConfigLoader(fname, log=log) + cfg = loader.load_config() + assert cfg.A.b == 1 + assert "cant_json" not in cfg.A + + def test_collision(self): + a = Config() + b = Config() + self.assertEqual(a.collisions(b), {}) + a.A.trait1 = 1 + b.A.trait2 = 2 + self.assertEqual(a.collisions(b), {}) + b.A.trait1 = 1 + self.assertEqual(a.collisions(b), {}) + b.A.trait1 = 0 + self.assertEqual( + a.collisions(b), + { + "A": { + "trait1": "1 ignored, using 0", + } + }, + ) + self.assertEqual( + b.collisions(a), + { + "A": { + "trait1": "0 ignored, using 1", + } + }, + ) + a.A.trait2 = 3 + self.assertEqual( + b.collisions(a), + { + "A": { + "trait1": "0 ignored, using 1", + "trait2": "2 ignored, using 3", + } + }, + ) + + def test_v2raise(self): + fd, fname = mkstemp(".json", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write(json2file) + f.close() + # Unlink the file + cl = JSONFileConfigLoader(fname, log=log) + with self.assertRaises(ValueError): + cl.load_config() + + +def _parse_int_or_str(v): + try: + return int(v) + except Exception: + return str(v) + + +class MyLoader1(ArgParseConfigLoader): + def _add_arguments(self, aliases=None, flags=None, classes=None): + p = self.parser + p.add_argument("-f", "--foo", dest="Global.foo", type=str) + p.add_argument("-b", dest="MyClass.bar", type=int) + p.add_argument("-n", dest="n", action="store_true") + p.add_argument("Global.bam", type=str) + p.add_argument("--list1", action="append", type=_parse_int_or_str) + p.add_argument("--list2", nargs="+", type=int) + + +class MyLoader2(ArgParseConfigLoader): + def _add_arguments(self, aliases=None, flags=None, classes=None): + subparsers = self.parser.add_subparsers(dest="subparser_name") + subparser1 = subparsers.add_parser("1") + subparser1.add_argument("-x", dest="Global.x") + subparser2 = subparsers.add_parser("2") + subparser2.add_argument("y") + + +class TestArgParseCL(TestCase): + def test_basic(self): + cl = MyLoader1() + config = cl.load_config("-f hi -b 10 -n wow".split()) + self.assertEqual(config.Global.foo, "hi") + self.assertEqual(config.MyClass.bar, 10) + self.assertEqual(config.n, True) + self.assertEqual(config.Global.bam, "wow") + config = cl.load_config(["wow"]) + self.assertEqual(list(config.keys()), ["Global"]) + self.assertEqual(list(config.Global.keys()), ["bam"]) + self.assertEqual(config.Global.bam, "wow") + + def test_add_arguments(self): + cl = MyLoader2() + config = cl.load_config("2 frobble".split()) + self.assertEqual(config.subparser_name, "2") + self.assertEqual(config.y, "frobble") + config = cl.load_config("1 -x frobble".split()) + self.assertEqual(config.subparser_name, "1") + self.assertEqual(config.Global.x, "frobble") + + def test_argv(self): + cl = MyLoader1(argv="-f hi -b 10 -n wow".split()) + config = cl.load_config() + self.assertEqual(config.Global.foo, "hi") + self.assertEqual(config.MyClass.bar, 10) + self.assertEqual(config.n, True) + self.assertEqual(config.Global.bam, "wow") + + def test_list_args(self): + cl = MyLoader1() + config = cl.load_config("--list1 1 wow --list2 1 2 3 --list1 B".split()) + self.assertEqual(list(config.Global.keys()), ["bam"]) + self.assertEqual(config.Global.bam, "wow") + self.assertEqual(config.list1, [1, "B"]) + self.assertEqual(config.list2, [1, 2, 3]) + + +class C(Configurable): + str_trait = Unicode(config=True) + int_trait = Integer(config=True) + list_trait = List(config=True) + list_of_ints = List(Integer(), config=True) + dict_trait = Dict(config=True) + dict_of_ints = Dict( + key_trait=Integer(), + value_trait=Integer(), + config=True, + ) + dict_multi = Dict( + key_trait=Unicode(), + per_key_traits={ + "int": Integer(), + "str": Unicode(), + }, + config=True, + ) + + +class TestKeyValueCL(TestCase): + klass = KeyValueConfigLoader + + def test_eval(self): + cl = self.klass(log=log) + config = cl.load_config( + '--C.str_trait=all --C.int_trait=5 --C.list_trait=["hello",5]'.split() + ) + c = C(config=config) + assert c.str_trait == "all" + assert c.int_trait == 5 + assert c.list_trait == ["hello", 5] + + def test_basic(self): + cl = self.klass(log=log) + argv = ["--" + s[2:] for s in pyfile.split("\n") if s.startswith("c.")] + config = cl.load_config(argv) + assert config.a == "10" + assert config.b == "20" + assert config.Foo.Bar.value == "10" + # non-literal expressions are not evaluated + self.assertEqual(config.Foo.Bam.value, "list(range(10))") + self.assertEqual(Unicode().from_string(config.D.C.value), "hi there") + + def test_expanduser(self): + cl = self.klass(log=log) + argv = ["--a=~/1/2/3", "--b=~", "--c=~/", '--d="~/"'] + config = cl.load_config(argv) + u = Unicode() + self.assertEqual(u.from_string(config.a), os.path.expanduser("~/1/2/3")) + self.assertEqual(u.from_string(config.b), os.path.expanduser("~")) + self.assertEqual(u.from_string(config.c), os.path.expanduser("~/")) + self.assertEqual(u.from_string(config.d), "~/") + + def test_extra_args(self): + cl = self.klass(log=log) + config = cl.load_config(["--a=5", "b", "d", "--c=10"]) + self.assertEqual(cl.extra_args, ["b", "d"]) + assert config.a == "5" + assert config.c == "10" + config = cl.load_config(["--", "--a=5", "--c=10"]) + self.assertEqual(cl.extra_args, ["--a=5", "--c=10"]) + + cl = self.klass(log=log) + config = cl.load_config(["extra", "--a=2", "--c=1", "--", "-"]) + self.assertEqual(cl.extra_args, ["extra", "-"]) + + def test_unicode_args(self): + cl = self.klass(log=log) + argv = ["--a=épsîlön"] + config = cl.load_config(argv) + print(config, cl.extra_args) + self.assertEqual(config.a, "épsîlön") + + def test_list_append(self): + cl = self.klass(log=log) + argv = ["--C.list_trait", "x", "--C.list_trait", "y"] + config = cl.load_config(argv) + assert config.C.list_trait == ["x", "y"] + c = C(config=config) + assert c.list_trait == ["x", "y"] + + def test_list_single_item(self): + cl = self.klass(log=log) + argv = ["--C.list_trait", "x"] + config = cl.load_config(argv) + c = C(config=config) + assert c.list_trait == ["x"] + + def test_dict(self): + cl = self.klass(log=log) + argv = ["--C.dict_trait", "x=5", "--C.dict_trait", "y=10"] + config = cl.load_config(argv) + c = C(config=config) + assert c.dict_trait == {"x": "5", "y": "10"} + + def test_dict_key_traits(self): + cl = self.klass(log=log) + argv = ["--C.dict_of_ints", "1=2", "--C.dict_of_ints", "3=4"] + config = cl.load_config(argv) + c = C(config=config) + assert c.dict_of_ints == {1: 2, 3: 4} + + +class CBase(Configurable): + a = List().tag(config=True) + b = List(Integer()).tag(config=True, multiplicity="*") + c = List().tag(config=True, multiplicity="append") + adict = Dict().tag(config=True) + + +class CSub(CBase): + d = Tuple().tag(config=True) + e = Tuple().tag(config=True, multiplicity="+") + bdict = Dict().tag(config=True, multiplicity="*") + + +class TestArgParseKVCL(TestKeyValueCL): + klass = KVArgParseConfigLoader # type:ignore + + def test_no_cast_literals(self): + cl = self.klass(log=log) # type:ignore + # test ipython -c 1 doesn't cast to int + argv = ["-c", "1"] + config = cl.load_config(argv, aliases=dict(c="IPython.command_to_run")) + assert config.IPython.command_to_run == "1" + + def test_int_literals(self): + cl = self.klass(log=log) # type:ignore + # test ipython -c 1 doesn't cast to int + argv = ["-c", "1"] + config = cl.load_config(argv, aliases=dict(c="IPython.command_to_run")) + assert config.IPython.command_to_run == "1" + + def test_unicode_alias(self): + cl = self.klass(log=log) # type:ignore + argv = ["--a=épsîlön"] + config = cl.load_config(argv, aliases=dict(a="A.a")) + print(dict(config)) + print(cl.extra_args) + print(cl.aliases) + self.assertEqual(config.A.a, "épsîlön") + + def test_expanduser2(self): + cl = self.klass(log=log) # type:ignore + argv = ["-a", "~/1/2/3", "--b", "'~/1/2/3'"] + config = cl.load_config(argv, aliases=dict(a="A.a", b="A.b")) + + class A(Configurable): + a = Unicode(config=True) + b = Unicode(config=True) + + a = A(config=config) + self.assertEqual(a.a, os.path.expanduser("~/1/2/3")) + self.assertEqual(a.b, "~/1/2/3") + + def test_eval(self): + cl = self.klass(log=log) # type:ignore + argv = ["-c", "a=5"] + config = cl.load_config(argv, aliases=dict(c="A.c")) + self.assertEqual(config.A.c, "a=5") + + def test_seq_traits(self): + cl = self.klass(log=log, classes=(CBase, CSub)) # type:ignore + aliases = {"a3": "CBase.c", "a5": "CSub.e"} + argv = ( + "--CBase.a A --CBase.a 2 --CBase.b 1 --CBase.b 3 --a3 AA --CBase.c BB " + "--CSub.d 1 --CSub.d BBB --CSub.e 1 --CSub.e=bcd a b c " + ).split() + config = cl.load_config(argv, aliases=aliases) + assert cl.extra_args == ["a", "b", "c"] + assert config.CBase.a == ["A", "2"] + assert config.CBase.b == [1, 3] + self.assertEqual(config.CBase.c, ["AA", "BB"]) + + assert config.CSub.d == ("1", "BBB") + assert config.CSub.e == ("1", "bcd") + + def test_seq_traits_single_empty_string(self): + cl = self.klass(log=log, classes=(CBase,)) # type:ignore + aliases = {"seqopt": "CBase.c"} + argv = ["--seqopt", ""] + config = cl.load_config(argv, aliases=aliases) + self.assertEqual(config.CBase.c, [""]) + + def test_dict_traits(self): + cl = self.klass(log=log, classes=(CBase, CSub)) # type:ignore + aliases = {"D": "CBase.adict", "E": "CSub.bdict"} + argv = ["-D", "k1=v1", "-D=k2=2", "-D", "k3=v 3", "-E", "k=v", "-E", "22=222"] + config = cl.load_config(argv, aliases=aliases) + c = CSub(config=config) + assert c.adict == {"k1": "v1", "k2": "2", "k3": "v 3"} + assert c.bdict == {"k": "v", "22": "222"} + + def test_mixed_seq_positional(self): + aliases = {"c": "Class.trait"} + cl = self.klass(log=log, aliases=aliases) # type:ignore + assignments = [("-c", "1"), ("--Class.trait=2",), ("--c=3",), ("--Class.trait", "4")] + positionals = ["a", "b", "c"] + # test with positionals at any index + for idx in range(len(assignments) + 1): + argv_parts = assignments[:] + argv_parts[idx:idx] = (positionals,) # type:ignore + argv = list(chain(*argv_parts)) + + config = cl.load_config(argv) + assert config.Class.trait == ["1", "2", "3", "4"] + assert cl.extra_args == ["a", "b", "c"] + + def test_split_positional(self): + """Splitting positionals across flags is no longer allowed in traitlets 5""" + cl = self.klass(log=log) # type:ignore + argv = ["a", "--Class.trait=5", "b"] + with pytest.raises(SystemExit): + cl.load_config(argv) + + +class TestConfig(TestCase): + def test_setget(self): + c = Config() + c.a = 10 + self.assertEqual(c.a, 10) + self.assertEqual("b" in c, False) + + def test_auto_section(self): + c = Config() + self.assertNotIn("A", c) + assert not c._has_section("A") + A = c.A + A.foo = "hi there" + self.assertIn("A", c) + assert c._has_section("A") + self.assertEqual(c.A.foo, "hi there") + del c.A + self.assertEqual(c.A, Config()) + + def test_merge_doesnt_exist(self): + c1 = Config() + c2 = Config() + c2.bar = 10 + c2.Foo.bar = 10 + c1.merge(c2) + self.assertEqual(c1.Foo.bar, 10) + self.assertEqual(c1.bar, 10) + c2.Bar.bar = 10 + c1.merge(c2) + self.assertEqual(c1.Bar.bar, 10) + + def test_merge_exists(self): + c1 = Config() + c2 = Config() + c1.Foo.bar = 10 + c1.Foo.bam = 30 + c2.Foo.bar = 20 + c2.Foo.wow = 40 + c1.merge(c2) + self.assertEqual(c1.Foo.bam, 30) + self.assertEqual(c1.Foo.bar, 20) + self.assertEqual(c1.Foo.wow, 40) + c2.Foo.Bam.bam = 10 + c1.merge(c2) + self.assertEqual(c1.Foo.Bam.bam, 10) + + def test_deepcopy(self): + c1 = Config() + c1.Foo.bar = 10 + c1.Foo.bam = 30 + c1.a = "asdf" + c1.b = range(10) + c1.Test.logger = logging.Logger("test") + c1.Test.get_logger = logging.getLogger("test") + c2 = copy.deepcopy(c1) + self.assertEqual(c1, c2) + self.assertTrue(c1 is not c2) + self.assertTrue(c1.Foo is not c2.Foo) + self.assertTrue(c1.Test is not c2.Test) + self.assertTrue(c1.Test.logger is c2.Test.logger) + self.assertTrue(c1.Test.get_logger is c2.Test.get_logger) + + def test_builtin(self): + c1 = Config() + c1.format = "json" + + def test_fromdict(self): + c1 = Config({"Foo": {"bar": 1}}) + self.assertEqual(c1.Foo.__class__, Config) + self.assertEqual(c1.Foo.bar, 1) + + def test_fromdictmerge(self): + c1 = Config() + c2 = Config({"Foo": {"bar": 1}}) + c1.merge(c2) + self.assertEqual(c1.Foo.__class__, Config) + self.assertEqual(c1.Foo.bar, 1) + + def test_fromdictmerge2(self): + c1 = Config({"Foo": {"baz": 2}}) + c2 = Config({"Foo": {"bar": 1}}) + c1.merge(c2) + self.assertEqual(c1.Foo.__class__, Config) + self.assertEqual(c1.Foo.bar, 1) + self.assertEqual(c1.Foo.baz, 2) + self.assertNotIn("baz", c2.Foo) + + def test_contains(self): + c1 = Config({"Foo": {"baz": 2}}) + c2 = Config({"Foo": {"bar": 1}}) + self.assertIn("Foo", c1) + self.assertIn("Foo.baz", c1) + self.assertIn("Foo.bar", c2) + self.assertNotIn("Foo.bar", c1) + + def test_pickle_config(self): + cfg = Config() + cfg.Foo.bar = 1 + pcfg = pickle.dumps(cfg) + cfg2 = pickle.loads(pcfg) + self.assertEqual(cfg2, cfg) + + def test_getattr_section(self): + cfg = Config() + self.assertNotIn("Foo", cfg) + Foo = cfg.Foo + assert isinstance(Foo, Config) + self.assertIn("Foo", cfg) + + def test_getitem_section(self): + cfg = Config() + self.assertNotIn("Foo", cfg) + Foo = cfg["Foo"] + assert isinstance(Foo, Config) + self.assertIn("Foo", cfg) + + def test_getattr_not_section(self): + cfg = Config() + self.assertNotIn("foo", cfg) + foo = cfg.foo + assert isinstance(foo, LazyConfigValue) + self.assertIn("foo", cfg) + + def test_getattr_private_missing(self): + cfg = Config() + self.assertNotIn("_repr_html_", cfg) + with self.assertRaises(AttributeError): + _ = cfg._repr_html_ + self.assertNotIn("_repr_html_", cfg) + self.assertEqual(len(cfg), 0) + + def test_lazy_config_repr(self): + cfg = Config() + cfg.Class.lazy.append(1) + cfg_repr = repr(cfg) + assert "<LazyConfigValue" in cfg_repr + assert "extend" in cfg_repr + assert " [1]}>" in cfg_repr + assert "value=" not in cfg_repr + cfg.Class.lazy.get_value([0]) + repr2 = repr(cfg) + assert repr([0, 1]) in repr2 + assert "value=" in repr2 + + def test_getitem_not_section(self): + cfg = Config() + self.assertNotIn("foo", cfg) + foo = cfg["foo"] + assert isinstance(foo, LazyConfigValue) + self.assertIn("foo", cfg) + + def test_merge_no_copies(self): + c = Config() + c2 = Config() + c2.Foo.trait = [] + c.merge(c2) + c2.Foo.trait.append(1) + self.assertIs(c.Foo, c2.Foo) + self.assertEqual(c.Foo.trait, [1]) + self.assertEqual(c2.Foo.trait, [1]) + + def test_merge_multi_lazy(self): + """ + With multiple config files (systemwide and users), we want compounding. + + If systemwide overwirte and user append, we want both in the right + order. + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait = [1] + c2.Foo.trait.append(2) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait, [1, 2]) + + def test_merge_multi_lazyII(self): + """ + With multiple config files (systemwide and users), we want compounding. + + If both are lazy we still want a lazy config. + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait.append(1) + c2.Foo.trait.append(2) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait._extend, [1, 2]) + + def test_merge_multi_lazy_III(self): + """ + With multiple config files (systemwide and users), we want compounding. + + Prepend should prepend in the right order. + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait = [1] + c2.Foo.trait.prepend([0]) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait, [0, 1]) + + def test_merge_multi_lazy_IV(self): + """ + With multiple config files (systemwide and users), we want compounding. + + Both prepending should be lazy + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait.prepend([1]) + c2.Foo.trait.prepend([0]) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait._prepend, [0, 1]) + + def test_merge_multi_lazy_update_I(self): + """ + With multiple config files (systemwide and users), we want compounding. + + dict update shoudl be in the right order. + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait = {"a": 1, "z": 26} + c2.Foo.trait.update({"a": 0, "b": 1}) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait, {"a": 0, "b": 1, "z": 26}) + + def test_merge_multi_lazy_update_II(self): + """ + With multiple config files (systemwide and users), we want compounding. + + Later dict overwrite lazyness + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait.update({"a": 0, "b": 1}) + c2.Foo.trait = {"a": 1, "z": 26} + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait, {"a": 1, "z": 26}) + + def test_merge_multi_lazy_update_III(self): + """ + With multiple config files (systemwide and users), we want compounding. + + Later dict overwrite lazyness + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait.update({"a": 0, "b": 1}) + c2.Foo.trait.update({"a": 1, "z": 26}) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait._update, {"a": 1, "z": 26, "b": 1}) diff --git a/contrib/python/traitlets/py3/tests/test_traitlets.py b/contrib/python/traitlets/py3/tests/test_traitlets.py new file mode 100644 index 00000000000..62fa726f19b --- /dev/null +++ b/contrib/python/traitlets/py3/tests/test_traitlets.py @@ -0,0 +1,3141 @@ +"""Tests for traitlets.traitlets.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +# +# Adapted from enthought.traits, Copyright (c) Enthought, Inc., +# also under the terms of the Modified BSD License. + +import pickle +import re +import typing as t +from unittest import TestCase + +import pytest + +from traitlets import ( + All, + Any, + BaseDescriptor, + Bool, + Bytes, + Callable, + CBytes, + CFloat, + CInt, + CLong, + Complex, + CRegExp, + CUnicode, + Dict, + DottedObjectName, + Enum, + Float, + ForwardDeclaredInstance, + ForwardDeclaredType, + HasDescriptors, + HasTraits, + Instance, + Int, + Integer, + List, + Long, + MetaHasTraits, + ObjectName, + Set, + TCPAddress, + This, + TraitError, + TraitType, + Tuple, + Type, + Undefined, + Unicode, + Union, + default, + directional_link, + link, + observe, + observe_compat, + traitlets, + validate, +) +from traitlets.utils import cast_unicode + +from ._warnings import expected_warnings + + +def change_dict(*ordered_values): + change_names = ("name", "old", "new", "owner", "type") + return dict(zip(change_names, ordered_values)) + + +# ----------------------------------------------------------------------------- +# Helper classes for testing +# ----------------------------------------------------------------------------- + + +class HasTraitsStub(HasTraits): + def notify_change(self, change): + self._notify_name = change["name"] + self._notify_old = change["old"] + self._notify_new = change["new"] + self._notify_type = change["type"] + + +class CrossValidationStub(HasTraits): + _cross_validation_lock = False + + +# ----------------------------------------------------------------------------- +# Test classes +# ----------------------------------------------------------------------------- + + +class TestTraitType(TestCase): + def test_get_undefined(self): + class A(HasTraits): + a = TraitType + + a = A() + assert a.a is Undefined # type:ignore + + def test_set(self): + class A(HasTraitsStub): + a = TraitType + + a = A() + a.a = 10 # type:ignore + self.assertEqual(a.a, 10) + self.assertEqual(a._notify_name, "a") + self.assertEqual(a._notify_old, Undefined) + self.assertEqual(a._notify_new, 10) + + def test_validate(self): + class MyTT(TraitType[int, int]): + def validate(self, inst, value): + return -1 + + class A(HasTraitsStub): + tt = MyTT + + a = A() + a.tt = 10 # type:ignore + self.assertEqual(a.tt, -1) + + a = A(tt=11) + self.assertEqual(a.tt, -1) + + def test_default_validate(self): + class MyIntTT(TraitType[int, int]): + def validate(self, obj, value): + if isinstance(value, int): + return value + self.error(obj, value) + + class A(HasTraits): + tt = MyIntTT(10) + + a = A() + self.assertEqual(a.tt, 10) + + # Defaults are validated when the HasTraits is instantiated + class B(HasTraits): + tt = MyIntTT("bad default") + + self.assertRaises(TraitError, getattr, B(), "tt") + + def test_info(self): + class A(HasTraits): + tt = TraitType + + a = A() + self.assertEqual(A.tt.info(), "any value") # type:ignore + + def test_error(self): + class A(HasTraits): + tt = TraitType[int, int]() + + a = A() + self.assertRaises(TraitError, A.tt.error, a, 10) + + def test_deprecated_dynamic_initializer(self): + class A(HasTraits): + x = Int(10) + + def _x_default(self): + return 11 + + class B(A): + x = Int(20) + + class C(A): + def _x_default(self): + return 21 + + a = A() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + b = B() + self.assertEqual(b.x, 20) + self.assertEqual(b._trait_values, {"x": 20}) + c = C() + self.assertEqual(c._trait_values, {}) + self.assertEqual(c.x, 21) + self.assertEqual(c._trait_values, {"x": 21}) + # Ensure that the base class remains unmolested when the _default + # initializer gets overridden in a subclass. + a = A() + c = C() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + + def test_deprecated_method_warnings(self): + with expected_warnings([]): + + class ShouldntWarn(HasTraits): + x = Integer() + + @default("x") + def _x_default(self): + return 10 + + @validate("x") + def _x_validate(self, proposal): + return proposal.value + + @observe("x") + def _x_changed(self, change): + pass + + obj = ShouldntWarn() + obj.x = 5 + + assert obj.x == 5 + + with expected_warnings(["@validate", "@observe"]) as w: + + class ShouldWarn(HasTraits): + x = Integer() + + def _x_default(self): + return 10 + + def _x_validate(self, value, _): + return value + + def _x_changed(self): + pass + + obj = ShouldWarn() # type:ignore + obj.x = 5 + + assert obj.x == 5 + + def test_dynamic_initializer(self): + class A(HasTraits): + x = Int(10) + + @default("x") + def _default_x(self): + return 11 + + class B(A): + x = Int(20) + + class C(A): + @default("x") + def _default_x(self): + return 21 + + a = A() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + b = B() + self.assertEqual(b.x, 20) + self.assertEqual(b._trait_values, {"x": 20}) + c = C() + self.assertEqual(c._trait_values, {}) + self.assertEqual(c.x, 21) + self.assertEqual(c._trait_values, {"x": 21}) + # Ensure that the base class remains unmolested when the _default + # initializer gets overridden in a subclass. + a = A() + c = C() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + + def test_tag_metadata(self): + class MyIntTT(TraitType[int, int]): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10).tag(b=3, c=4) + self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4}) + + def test_metadata_localized_instance(self): + class MyIntTT(TraitType[int, int]): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10) + b = MyIntTT(10) + a.metadata["c"] = 3 + # make sure that changing a's metadata didn't change b's metadata + self.assertNotIn("c", b.metadata) + + def test_union_metadata(self): + class Foo(HasTraits): + bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti="b")).tag(ti="a") + + foo = Foo() + # At this point, no value has been set for bar, so value-specific + # is not set. + self.assertEqual(foo.trait_metadata("bar", "ta"), None) + self.assertEqual(foo.trait_metadata("bar", "ti"), "a") + foo.bar = {} + self.assertEqual(foo.trait_metadata("bar", "ta"), 2) + self.assertEqual(foo.trait_metadata("bar", "ti"), "b") + foo.bar = 1 + self.assertEqual(foo.trait_metadata("bar", "ta"), 1) + self.assertEqual(foo.trait_metadata("bar", "ti"), "a") + + def test_union_default_value(self): + class Foo(HasTraits): + bar = Union([Dict(), Int()], default_value=1) + + foo = Foo() + self.assertEqual(foo.bar, 1) + + def test_union_validation_priority(self): + class Foo(HasTraits): + bar = Union([CInt(), Unicode()]) + + foo = Foo() + foo.bar = "1" + # validation in order of the TraitTypes given + self.assertEqual(foo.bar, 1) + + def test_union_trait_default_value(self): + class Foo(HasTraits): + bar = Union([Dict(), Int()]) + + self.assertEqual(Foo().bar, {}) + + def test_deprecated_metadata_access(self): + class MyIntTT(TraitType[int, int]): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10) + with expected_warnings(["use the instance .metadata dictionary directly"] * 2): + a.set_metadata("key", "value") + v = a.get_metadata("key") + self.assertEqual(v, "value") + with expected_warnings(["use the instance .help string directly"] * 2): + a.set_metadata("help", "some help") + v = a.get_metadata("help") + self.assertEqual(v, "some help") + + def test_trait_types_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Int + + def test_trait_types_list_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = List(Int) + + def test_trait_types_tuple_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Tuple(Int) + + def test_trait_types_dict_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Dict(Int) + + +class TestHasDescriptorsMeta(TestCase): + def test_metaclass(self): + self.assertEqual(type(HasTraits), MetaHasTraits) + + class A(HasTraits): + a = Int() + + a = A() + self.assertEqual(type(a.__class__), MetaHasTraits) + self.assertEqual(a.a, 0) + a.a = 10 + self.assertEqual(a.a, 10) + + class B(HasTraits): + b = Int() + + b = B() + self.assertEqual(b.b, 0) + b.b = 10 + self.assertEqual(b.b, 10) + + class C(HasTraits): + c = Int(30) + + c = C() + self.assertEqual(c.c, 30) + c.c = 10 + self.assertEqual(c.c, 10) + + def test_this_class(self): + class A(HasTraits): + t = This["A"]() + tt = This["A"]() + + class B(A): + tt = This["A"]() + ttt = This["A"]() + + self.assertEqual(A.t.this_class, A) + self.assertEqual(B.t.this_class, A) + self.assertEqual(B.tt.this_class, B) + self.assertEqual(B.ttt.this_class, B) + + +class TestHasDescriptors(TestCase): + def test_setup_instance(self): + class FooDescriptor(BaseDescriptor): + def instance_init(self, inst): + foo = inst.foo # instance should have the attr + + class HasFooDescriptors(HasDescriptors): + fd = FooDescriptor() + + def setup_instance(self, *args, **kwargs): + self.foo = kwargs.get("foo", None) + super().setup_instance(*args, **kwargs) + + hfd = HasFooDescriptors(foo="bar") + + +class TestHasTraitsNotify(TestCase): + def setUp(self): + self._notify1 = [] + self._notify2 = [] + + def notify1(self, name, old, new): + self._notify1.append((name, old, new)) + + def notify2(self, name, old, new): + self._notify2.append((name, old, new)) + + def test_notify_all(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.on_trait_change(self.notify1) + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.b = 0.0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in self._notify1) + a.b = 10.0 + self.assertTrue(("b", 0.0, 10.0) in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + self.assertRaises(TraitError, setattr, a, "b", "bad string") + self._notify1 = [] + a.on_trait_change(self.notify1, remove=True) + a.a = 20 + a.b = 20.0 + self.assertEqual(len(self._notify1), 0) + + def test_notify_one(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.on_trait_change(self.notify1, "a") + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + + def test_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + self.assertEqual(b.a, 0) + self.assertEqual(b.b, 0.0) + b.a = 100 + b.b = 100.0 + self.assertEqual(b.a, 100) + self.assertEqual(b.b, 100.0) + + def test_notify_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + b.on_trait_change(self.notify1, "a") + b.on_trait_change(self.notify2, "b") + b.a = 0 + b.b = 0.0 + self.assertEqual(len(self._notify1), 0) + self.assertEqual(len(self._notify2), 0) + b.a = 10 + b.b = 10.0 + self.assertTrue(("a", 0, 10) in self._notify1) + self.assertTrue(("b", 0.0, 10.0) in self._notify2) + + def test_static_notify(self): + class A(HasTraits): + a = Int() + _notify1 = [] + + def _a_changed(self, name, old, new): + self._notify1.append((name, old, new)) + + a = A() + a.a = 0 + # This is broken!!! + self.assertEqual(len(a._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in a._notify1) + + class B(A): + b = Float() + _notify2 = [] + + def _b_changed(self, name, old, new): + self._notify2.append((name, old, new)) + + b = B() + b.a = 10 + b.b = 10.0 + self.assertTrue(("a", 0, 10) in b._notify1) + self.assertTrue(("b", 0.0, 10.0) in b._notify2) + + def test_notify_args(self): + def callback0(): + self.cb = () + + def callback1(name): + self.cb = (name,) # type:ignore + + def callback2(name, new): + self.cb = (name, new) # type:ignore + + def callback3(name, old, new): + self.cb = (name, old, new) # type:ignore + + def callback4(name, old, new, obj): + self.cb = (name, old, new, obj) # type:ignore + + class A(HasTraits): + a = Int() + + a = A() + a.on_trait_change(callback0, "a") + a.a = 10 + self.assertEqual(self.cb, ()) + a.on_trait_change(callback0, "a", remove=True) + + a.on_trait_change(callback1, "a") + a.a = 100 + self.assertEqual(self.cb, ("a",)) + a.on_trait_change(callback1, "a", remove=True) + + a.on_trait_change(callback2, "a") + a.a = 1000 + self.assertEqual(self.cb, ("a", 1000)) + a.on_trait_change(callback2, "a", remove=True) + + a.on_trait_change(callback3, "a") + a.a = 10000 + self.assertEqual(self.cb, ("a", 1000, 10000)) + a.on_trait_change(callback3, "a", remove=True) + + a.on_trait_change(callback4, "a") + a.a = 100000 + self.assertEqual(self.cb, ("a", 10000, 100000, a)) + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) + a.on_trait_change(callback4, "a", remove=True) + + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) + + def test_notify_only_once(self): + class A(HasTraits): + listen_to = ["a"] + + a = Int(0) + b = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.on_trait_change(self.listener1, ["a"]) + + def listener1(self, name, old, new): + self.b += 1 + + class B(A): + c = 0 + d = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.on_trait_change(self.listener2) + + def listener2(self, name, old, new): + self.c += 1 + + def _a_changed(self, name, old, new): + self.d += 1 + + b = B() + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + + +class TestObserveDecorator(TestCase): + def setUp(self): + self._notify1 = [] + self._notify2 = [] + + def notify1(self, change): + self._notify1.append(change) + + def notify2(self, change): + self._notify2.append(change) + + def test_notify_all(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.observe(self.notify1) + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.b = 0.0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in self._notify1) + a.b = 10.0 + change = change_dict("b", 0.0, 10.0, a, "change") + self.assertTrue(change in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + self.assertRaises(TraitError, setattr, a, "b", "bad string") + self._notify1 = [] + a.unobserve(self.notify1) + a.a = 20 + a.b = 20.0 + self.assertEqual(len(self._notify1), 0) + + def test_notify_one(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.observe(self.notify1, "a") + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + + def test_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + self.assertEqual(b.a, 0) + self.assertEqual(b.b, 0.0) + b.a = 100 + b.b = 100.0 + self.assertEqual(b.a, 100) + self.assertEqual(b.b, 100.0) + + def test_notify_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + b.observe(self.notify1, "a") + b.observe(self.notify2, "b") + b.a = 0 + b.b = 0.0 + self.assertEqual(len(self._notify1), 0) + self.assertEqual(len(self._notify2), 0) + b.a = 10 + b.b = 10.0 + change = change_dict("a", 0, 10, b, "change") + self.assertTrue(change in self._notify1) + change = change_dict("b", 0.0, 10.0, b, "change") + self.assertTrue(change in self._notify2) + + def test_static_notify(self): + class A(HasTraits): + a = Int() + b = Int() + _notify1 = [] + _notify_any = [] + + @observe("a") + def _a_changed(self, change): + self._notify1.append(change) + + @observe(All) + def _any_changed(self, change): + self._notify_any.append(change) + + a = A() + a.a = 0 + self.assertEqual(len(a._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in a._notify1) + a.b = 1 + self.assertEqual(len(a._notify_any), 2) + change = change_dict("b", 0, 1, a, "change") + self.assertTrue(change in a._notify_any) + + class B(A): + b = Float() # type:ignore + _notify2 = [] + + @observe("b") + def _b_changed(self, change): + self._notify2.append(change) + + b = B() + b.a = 10 + b.b = 10.0 # type:ignore + change = change_dict("a", 0, 10, b, "change") + self.assertTrue(change in b._notify1) + change = change_dict("b", 0.0, 10.0, b, "change") + self.assertTrue(change in b._notify2) + + def test_notify_args(self): + def callback0(): + self.cb = () + + def callback1(change): + self.cb = change + + class A(HasTraits): + a = Int() + + a = A() + a.on_trait_change(callback0, "a") + a.a = 10 + self.assertEqual(self.cb, ()) + a.unobserve(callback0, "a") + + a.observe(callback1, "a") + a.a = 100 + change = change_dict("a", 10, 100, a, "change") + self.assertEqual(self.cb, change) + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) + a.unobserve(callback1, "a") + + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) + + def test_notify_only_once(self): + class A(HasTraits): + listen_to = ["a"] + + a = Int(0) + b = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.observe(self.listener1, ["a"]) + + def listener1(self, change): + self.b += 1 + + class B(A): + c = 0 + d = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.observe(self.listener2) + + def listener2(self, change): + self.c += 1 + + @observe("a") + def _a_changed(self, change): + self.d += 1 + + b = B() + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + + +class TestHasTraits(TestCase): + def test_trait_names(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertEqual(sorted(a.trait_names()), ["f", "i"]) + self.assertEqual(sorted(A.class_trait_names()), ["f", "i"]) + self.assertTrue(a.has_trait("f")) + self.assertFalse(a.has_trait("g")) + + def test_trait_has_value(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertFalse(a.trait_has_value("f")) + self.assertFalse(a.trait_has_value("g")) + a.i = 1 + a.f + self.assertTrue(a.trait_has_value("i")) + self.assertTrue(a.trait_has_value("f")) + + def test_trait_metadata_deprecated(self): + with expected_warnings([r"metadata should be set using the \.tag\(\) method"]): + + class A(HasTraits): + i = Int(config_key="MY_VALUE") + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") + + def test_trait_metadata(self): + class A(HasTraits): + i = Int().tag(config_key="MY_VALUE") + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") + + def test_trait_metadata_default(self): + class A(HasTraits): + i = Int() + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), None) + self.assertEqual(a.trait_metadata("i", "config_key", "default"), "default") + + def test_traits(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f)) + self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f)) + + def test_traits_metadata(self): + class A(HasTraits): + i = Int().tag(config_key="VALUE1", other_thing="VALUE2") + f = Float().tag(config_key="VALUE3", other_thing="VALUE2") + j = Int(0) + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) + traits = a.traits(config_key="VALUE1", other_thing="VALUE2") + self.assertEqual(traits, dict(i=A.i)) + + # This passes, but it shouldn't because I am replicating a bug in + # traits. + traits = a.traits(config_key=lambda v: True) + self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) + + def test_traits_metadata_deprecated(self): + with expected_warnings([r"metadata should be set using the \.tag\(\) method"] * 2): + + class A(HasTraits): + i = Int(config_key="VALUE1", other_thing="VALUE2") + f = Float(config_key="VALUE3", other_thing="VALUE2") + j = Int(0) + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) + traits = a.traits(config_key="VALUE1", other_thing="VALUE2") + self.assertEqual(traits, dict(i=A.i)) + + # This passes, but it shouldn't because I am replicating a bug in + # traits. + traits = a.traits(config_key=lambda v: True) + self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) + + def test_init(self): + class A(HasTraits): + i = Int() + x = Float() + + a = A(i=1, x=10.0) + self.assertEqual(a.i, 1) + self.assertEqual(a.x, 10.0) + + def test_positional_args(self): + class A(HasTraits): + i = Int(0) + + def __init__(self, i): + super().__init__() + self.i = i + + a = A(5) + self.assertEqual(a.i, 5) + # should raise TypeError if no positional arg given + self.assertRaises(TypeError, A) + + +# ----------------------------------------------------------------------------- +# Tests for specific trait types +# ----------------------------------------------------------------------------- + + +class TestType(TestCase): + def test_default(self): + class B: + pass + + class A(HasTraits): + klass = Type(allow_none=True) + + a = A() + self.assertEqual(a.klass, object) + + a.klass = B + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", 10) + + def test_default_options(self): + class B: + pass + + class C(B): + pass + + class A(HasTraits): + # Different possible combinations of options for default_value + # and klass. default_value=None is only valid with allow_none=True. + k1 = Type() + k2 = Type(None, allow_none=True) + k3 = Type(B) + k4 = Type(klass=B) + k5 = Type(default_value=None, klass=B, allow_none=True) + k6 = Type(default_value=C, klass=B) + + self.assertIs(A.k1.default_value, object) + self.assertIs(A.k1.klass, object) + self.assertIs(A.k2.default_value, None) + self.assertIs(A.k2.klass, object) + self.assertIs(A.k3.default_value, B) + self.assertIs(A.k3.klass, B) + self.assertIs(A.k4.default_value, B) + self.assertIs(A.k4.klass, B) + self.assertIs(A.k5.default_value, None) + self.assertIs(A.k5.klass, B) + self.assertIs(A.k6.default_value, C) + self.assertIs(A.k6.klass, B) + + a = A() + self.assertIs(a.k1, object) + self.assertIs(a.k2, None) + self.assertIs(a.k3, B) + self.assertIs(a.k4, B) + self.assertIs(a.k5, None) + self.assertIs(a.k6, C) + + def test_value(self): + class B: + pass + + class C: + pass + + class A(HasTraits): + klass = Type(B) + + a = A() + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", C) + self.assertRaises(TraitError, setattr, a, "klass", object) + a.klass = B + + def test_allow_none(self): + class B: + pass + + class C(B): + pass + + class A(HasTraits): + klass = Type(B) + + a = A() + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", None) + a.klass = C + self.assertEqual(a.klass, C) + + def test_validate_klass(self): + class A(HasTraits): + klass = Type("no strings allowed") + + self.assertRaises(ImportError, A) + + class A(HasTraits): # type:ignore + klass = Type("rub.adub.Duck") + + self.assertRaises(ImportError, A) + + def test_validate_default(self): + class B: + pass + + class A(HasTraits): + klass = Type("bad default", B) + + self.assertRaises(ImportError, A) + + class C(HasTraits): + klass = Type(None, B) + + self.assertRaises(TraitError, getattr, C(), "klass") + + def test_str_klass(self): + class A(HasTraits): + klass = Type("traitlets.config.Config") + + from traitlets.config import Config + + a = A() + a.klass = Config + self.assertEqual(a.klass, Config) + + self.assertRaises(TraitError, setattr, a, "klass", 10) + + def test_set_str_klass(self): + class A(HasTraits): + klass = Type() + + a = A(klass="traitlets.config.Config") + from traitlets.config import Config + + self.assertEqual(a.klass, Config) + + +class TestInstance(TestCase): + def test_basic(self): + class Foo: + pass + + class Bar(Foo): + pass + + class Bah: + pass + + class A(HasTraits): + inst = Instance(Foo, allow_none=True) + + a = A() + self.assertTrue(a.inst is None) + a.inst = Foo() + self.assertTrue(isinstance(a.inst, Foo)) + a.inst = Bar() + self.assertTrue(isinstance(a.inst, Foo)) + self.assertRaises(TraitError, setattr, a, "inst", Foo) + self.assertRaises(TraitError, setattr, a, "inst", Bar) + self.assertRaises(TraitError, setattr, a, "inst", Bah()) + + def test_default_klass(self): + class Foo: + pass + + class Bar(Foo): + pass + + class Bah: + pass + + class FooInstance(Instance[Foo]): + klass = Foo + + class A(HasTraits): + inst = FooInstance(allow_none=True) + + a = A() + self.assertTrue(a.inst is None) + a.inst = Foo() + self.assertTrue(isinstance(a.inst, Foo)) + a.inst = Bar() + self.assertTrue(isinstance(a.inst, Foo)) + self.assertRaises(TraitError, setattr, a, "inst", Foo) + self.assertRaises(TraitError, setattr, a, "inst", Bar) + self.assertRaises(TraitError, setattr, a, "inst", Bah()) + + def test_unique_default_value(self): + class Foo: + pass + + class A(HasTraits): + inst = Instance(Foo, (), {}) + + a = A() + b = A() + self.assertTrue(a.inst is not b.inst) + + def test_args_kw(self): + class Foo: + def __init__(self, c): + self.c = c + + class Bar: + pass + + class Bah: + def __init__(self, c, d): + self.c = c + self.d = d + + class A(HasTraits): + inst = Instance(Foo, (10,)) + + a = A() + self.assertEqual(a.inst.c, 10) + + class B(HasTraits): + inst = Instance(Bah, args=(10,), kw=dict(d=20)) + + b = B() + self.assertEqual(b.inst.c, 10) + self.assertEqual(b.inst.d, 20) + + class C(HasTraits): + inst = Instance(Foo, allow_none=True) + + c = C() + self.assertTrue(c.inst is None) + + def test_bad_default(self): + class Foo: + pass + + class A(HasTraits): + inst = Instance(Foo) + + a = A() + with self.assertRaises(TraitError): + a.inst + + def test_instance(self): + class Foo: + pass + + def inner(): + class A(HasTraits): + inst = Instance(Foo()) # type:ignore + + self.assertRaises(TraitError, inner) + + +class TestThis(TestCase): + def test_this_class(self): + class Foo(HasTraits): + this = This["Foo"]() + + f = Foo() + self.assertEqual(f.this, None) + g = Foo() + f.this = g + self.assertEqual(f.this, g) + self.assertRaises(TraitError, setattr, f, "this", 10) + + def test_this_inst(self): + class Foo(HasTraits): + this = This["Foo"]() + + f = Foo() + f.this = Foo() + self.assertTrue(isinstance(f.this, Foo)) + + def test_subclass(self): + class Foo(HasTraits): + t = This["Foo"]() + + class Bar(Foo): + pass + + f = Foo() + b = Bar() + f.t = b + b.t = f + self.assertEqual(f.t, b) + self.assertEqual(b.t, f) + + def test_subclass_override(self): + class Foo(HasTraits): + t = This["Foo"]() + + class Bar(Foo): + t = This() + + f = Foo() + b = Bar() + f.t = b + self.assertEqual(f.t, b) + self.assertRaises(TraitError, setattr, b, "t", f) + + def test_this_in_container(self): + class Tree(HasTraits): + value = Unicode() + leaves = List(This()) + + tree = Tree(value="foo", leaves=[Tree(value="bar"), Tree(value="buzz")]) + + with self.assertRaises(TraitError): + tree.leaves = [1, 2] + + +class TraitTestBase(TestCase): + """A best testing class for basic trait types.""" + + def assign(self, value): + self.obj.value = value # type:ignore + + def coerce(self, value): + return value + + def test_good_values(self): + if hasattr(self, "_good_values"): + for value in self._good_values: + self.assign(value) + self.assertEqual(self.obj.value, self.coerce(value)) # type:ignore + + def test_bad_values(self): + if hasattr(self, "_bad_values"): + for value in self._bad_values: + try: + self.assertRaises(TraitError, self.assign, value) + except AssertionError: + assert False, value + + def test_default_value(self): + if hasattr(self, "_default_value"): + self.assertEqual(self._default_value, self.obj.value) # type:ignore + + def test_allow_none(self): + if ( + hasattr(self, "_bad_values") + and hasattr(self, "_good_values") + and None in self._bad_values + ): + trait = self.obj.traits()["value"] # type:ignore + try: + trait.allow_none = True + self._bad_values.remove(None) + # skip coerce. Allow None casts None to None. + self.assign(None) + self.assertEqual(self.obj.value, None) # type:ignore + self.test_good_values() + self.test_bad_values() + finally: + # tear down + trait.allow_none = False + self._bad_values.append(None) + + def tearDown(self): + # restore default value after tests, if set + if hasattr(self, "_default_value"): + self.obj.value = self._default_value # type:ignore + + +class AnyTrait(HasTraits): + value = Any() + + +class AnyTraitTest(TraitTestBase): + obj = AnyTrait() + + _default_value = None + _good_values = [10.0, "ten", [10], {"ten": 10}, (10,), None, 1j] + _bad_values: t.Any = [] + + +class UnionTrait(HasTraits): + value = Union([Type(), Bool()]) + + +class UnionTraitTest(TraitTestBase): + obj = UnionTrait(value="traitlets.config.Config") + _good_values = [int, float, True] + _bad_values = [[], (0,), 1j] + + +class CallableTrait(HasTraits): + value = Callable() + + +class CallableTraitTest(TraitTestBase): + obj = CallableTrait(value=lambda x: type(x)) + _good_values = [int, sorted, lambda x: print(x)] + _bad_values = [[], 1, ""] + + +class OrTrait(HasTraits): + value = Bool() | Unicode() + + +class OrTraitTest(TraitTestBase): + obj = OrTrait() + _good_values = [True, False, "ten"] + _bad_values = [[], (0,), 1j] + + +class IntTrait(HasTraits): + value = Int(99, min=-100) + + +class TestInt(TraitTestBase): + obj = IntTrait() + _default_value = 99 + _good_values = [10, -10] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + 10.1, + -10.1, + "10L", + "-10L", + "10.1", + "-10.1", + "10", + "-10", + -200, + ] + + +class CIntTrait(HasTraits): + value = CInt("5") + + +class TestCInt(TraitTestBase): + obj = CIntTrait() + + _default_value = 5 + _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] + + def coerce(self, n): + return int(n) + + +class MinBoundCIntTrait(HasTraits): + value = CInt("5", min=3) + + +class TestMinBoundCInt(TestCInt): + obj = MinBoundCIntTrait() # type:ignore + + _default_value = 5 + _good_values = [3, 3.0, "3"] + _bad_values = [2.6, 2, -3, -3.0] + + +class LongTrait(HasTraits): + value = Long(99) + + +class TestLong(TraitTestBase): + obj = LongTrait() + + _default_value = 99 + _good_values = [10, -10] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + 10.1, + -10.1, + "10", + "-10", + "10L", + "-10L", + "10.1", + "-10.1", + ] + + +class MinBoundLongTrait(HasTraits): + value = Long(99, min=5) + + +class TestMinBoundLong(TraitTestBase): + obj = MinBoundLongTrait() + + _default_value = 99 + _good_values = [5, 10] + _bad_values = [4, -10] + + +class MaxBoundLongTrait(HasTraits): + value = Long(5, max=10) + + +class TestMaxBoundLong(TraitTestBase): + obj = MaxBoundLongTrait() + + _default_value = 5 + _good_values = [10, -2] + _bad_values = [11, 20] + + +class CLongTrait(HasTraits): + value = CLong("5") + + +class TestCLong(TraitTestBase): + obj = CLongTrait() + + _default_value = 5 + _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] + + def coerce(self, n): + return int(n) + + +class MaxBoundCLongTrait(HasTraits): + value = CLong("5", max=10) + + +class TestMaxBoundCLong(TestCLong): + obj = MaxBoundCLongTrait() # type:ignore + + _default_value = 5 + _good_values = [10, "10", 10.3] + _bad_values = [11.0, "11"] + + +class IntegerTrait(HasTraits): + value = Integer(1) + + +class TestInteger(TestLong): + obj = IntegerTrait() # type:ignore + _default_value = 1 + + def coerce(self, n): + return int(n) + + +class MinBoundIntegerTrait(HasTraits): + value = Integer(5, min=3) + + +class TestMinBoundInteger(TraitTestBase): + obj = MinBoundIntegerTrait() + + _default_value = 5 + _good_values = 3, 20 + _bad_values = [2, -10] + + +class MaxBoundIntegerTrait(HasTraits): + value = Integer(1, max=3) + + +class TestMaxBoundInteger(TraitTestBase): + obj = MaxBoundIntegerTrait() + + _default_value = 1 + _good_values = 3, -2 + _bad_values = [4, 10] + + +class FloatTrait(HasTraits): + value = Float(99.0, max=200.0) + + +class TestFloat(TraitTestBase): + obj = FloatTrait() + + _default_value = 99.0 + _good_values = [10, -10, 10.1, -10.1] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + "10", + "-10", + "10L", + "-10L", + "10.1", + "-10.1", + 201.0, + ] + + +class CFloatTrait(HasTraits): + value = CFloat("99.0", max=200.0) + + +class TestCFloat(TraitTestBase): + obj = CFloatTrait() + + _default_value = 99.0 + _good_values = [10, 10.0, 10.5, "10.0", "10", "-10"] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, 200.1, "200.1"] + + def coerce(self, v): + return float(v) + + +class ComplexTrait(HasTraits): + value = Complex(99.0 - 99.0j) + + +class TestComplex(TraitTestBase): + obj = ComplexTrait() + + _default_value = 99.0 - 99.0j + _good_values = [ + 10, + -10, + 10.1, + -10.1, + 10j, + 10 + 10j, + 10 - 10j, + 10.1j, + 10.1 + 10.1j, + 10.1 - 10.1j, + ] + _bad_values = ["10L", "-10L", "ten", [10], {"ten": 10}, (10,), None] + + +class BytesTrait(HasTraits): + value = Bytes(b"string") + + +class TestBytes(TraitTestBase): + obj = BytesTrait() + + _default_value = b"string" + _good_values = [b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1", b"string"] + _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None, "string"] + + +class UnicodeTrait(HasTraits): + value = Unicode("unicode") + + +class TestUnicode(TraitTestBase): + obj = UnicodeTrait() + + _default_value = "unicode" + _good_values = ["10", "-10", "10L", "-10L", "10.1", "-10.1", "", "string", "€", b"bytestring"] + _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None] + + def coerce(self, v): + return cast_unicode(v) + + +class ObjectNameTrait(HasTraits): + value = ObjectName("abc") + + +class TestObjectName(TraitTestBase): + obj = ObjectNameTrait() + + _default_value = "abc" + _good_values = ["a", "gh", "g9", "g_", "_G", "a345_"] + _bad_values = [ + 1, + "", + "€", + "9g", + "!", + "#abc", + "aj@", + "a.b", + "a()", + "a[0]", + None, + object(), + object, + ] + _good_values.append("þ") # þ=1 is valid in Python 3 (PEP 3131). + + +class DottedObjectNameTrait(HasTraits): + value = DottedObjectName("a.b") + + +class TestDottedObjectName(TraitTestBase): + obj = DottedObjectNameTrait() + + _default_value = "a.b" + _good_values = ["A", "y.t", "y765.__repr__", "os.path.join"] + _bad_values = [1, "abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None] + + _good_values.append("t.þ") + + +class TCPAddressTrait(HasTraits): + value = TCPAddress() + + +class TestTCPAddress(TraitTestBase): + obj = TCPAddressTrait() + + _default_value = ("127.0.0.1", 0) + _good_values = [("localhost", 0), ("192.168.0.1", 1000), ("www.google.com", 80)] + _bad_values = [(0, 0), ("localhost", 10.0), ("localhost", -1), None] + + +class ListTrait(HasTraits): + value = List(Int()) + + +class TestList(TraitTestBase): + obj = ListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[], [1], list(range(10)), (1, 2)] + _bad_values = [10, [1, "a"], "a"] + + def coerce(self, value): + if value is not None: + value = list(value) + return value + + +class Foo: + pass + + +class NoneInstanceListTrait(HasTraits): + value = List(Instance(Foo)) + + +class TestNoneInstanceList(TraitTestBase): + obj = NoneInstanceListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[Foo(), Foo()], []] + _bad_values = [[None], [Foo(), None]] + + +class InstanceListTrait(HasTraits): + value = List(Instance(__name__ + ".Foo")) + + +class TestInstanceList(TraitTestBase): + obj = InstanceListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, Foo) + + _default_value: t.List[t.Any] = [] + _good_values = [[Foo(), Foo()], []] + _bad_values = [ + [ + "1", + 2, + ], + "1", + [Foo], + None, + ] + + +class UnionListTrait(HasTraits): + value = List(Int() | Bool()) + + +class TestUnionListTrait(TraitTestBase): + obj = UnionListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[True, 1], [False, True]] + _bad_values = [[1, "True"], False] + + +class LenListTrait(HasTraits): + value = List(Int(), [0], minlen=1, maxlen=2) + + +class TestLenList(TraitTestBase): + obj = LenListTrait() + + _default_value = [0] + _good_values = [[1], [1, 2], (1, 2)] + _bad_values = [10, [1, "a"], "a", [], list(range(3))] + + def coerce(self, value): + if value is not None: + value = list(value) + return value + + +class TupleTrait(HasTraits): + value = Tuple(Int(allow_none=True), default_value=(1,)) + + +class TestTupleTrait(TraitTestBase): + obj = TupleTrait() + + _default_value = (1,) + _good_values = [(1,), (0,), [1]] + _bad_values = [10, (1, 2), ("a"), (), None] + + def coerce(self, value): + if value is not None: + value = tuple(value) + return value + + def test_invalid_args(self): + self.assertRaises(TypeError, Tuple, 5) + self.assertRaises(TypeError, Tuple, default_value="hello") + t = Tuple(Int(), CBytes(), default_value=(1, 5)) + + +class LooseTupleTrait(HasTraits): + value = Tuple((1, 2, 3)) + + +class TestLooseTupleTrait(TraitTestBase): + obj = LooseTupleTrait() + + _default_value = (1, 2, 3) + _good_values = [(1,), [1], (0,), tuple(range(5)), tuple("hello"), ("a", 5), ()] + _bad_values = [10, "hello", {}, None] + + def coerce(self, value): + if value is not None: + value = tuple(value) + return value + + def test_invalid_args(self): + self.assertRaises(TypeError, Tuple, 5) + self.assertRaises(TypeError, Tuple, default_value="hello") + t = Tuple(Int(), CBytes(), default_value=(1, 5)) + + +class MultiTupleTrait(HasTraits): + value = Tuple(Int(), Bytes(), default_value=[99, b"bottles"]) + + +class TestMultiTuple(TraitTestBase): + obj = MultiTupleTrait() + + _default_value = (99, b"bottles") + _good_values = [(1, b"a"), (2, b"b")] + _bad_values = ((), 10, b"a", (1, b"a", 3), (b"a", 1), (1, "a")) + + +@pytest.mark.parametrize( + "Trait", + ( + List, + Tuple, + Set, + Dict, + Integer, + Unicode, + ), +) +def test_allow_none_default_value(Trait): + class C(HasTraits): + t = Trait(default_value=None, allow_none=True) + + # test default value + c = C() + assert c.t is None + + # and in constructor + c = C(t=None) + assert c.t is None + + +@pytest.mark.parametrize( + "Trait, default_value", + ((List, []), (Tuple, ()), (Set, set()), (Dict, {}), (Integer, 0), (Unicode, "")), +) +def test_default_value(Trait, default_value): + class C(HasTraits): + t = Trait() + + # test default value + c = C() + assert type(c.t) is type(default_value) + assert c.t == default_value + + +@pytest.mark.parametrize( + "Trait, default_value", + ((List, []), (Tuple, ()), (Set, set())), +) +def test_subclass_default_value(Trait, default_value): + """Test deprecated default_value=None behavior for Container subclass traits""" + + class SubclassTrait(Trait): # type:ignore + def __init__(self, default_value=None): + super().__init__(default_value=default_value) + + class C(HasTraits): + t = SubclassTrait() + + # test default value + c = C() + assert type(c.t) is type(default_value) + assert c.t == default_value + + +class CRegExpTrait(HasTraits): + value = CRegExp(r"") + + +class TestCRegExp(TraitTestBase): + def coerce(self, value): + return re.compile(value) + + obj = CRegExpTrait() + + _default_value = re.compile(r"") + _good_values = [r"\d+", re.compile(r"\d+")] + _bad_values = ["(", None, ()] + + +class DictTrait(HasTraits): + value = Dict() + + +def test_dict_assignment(): + d: t.Dict[str, int] = {} + c = DictTrait() + c.value = d + d["a"] = 5 + assert d == c.value + assert c.value is d + + +class UniformlyValueValidatedDictTrait(HasTraits): + value = Dict(value_trait=Unicode(), default_value={"foo": "1"}) + + +class TestInstanceUniformlyValueValidatedDict(TraitTestBase): + obj = UniformlyValueValidatedDictTrait() + + _default_value = {"foo": "1"} + _good_values = [{"foo": "0", "bar": "1"}] + _bad_values = [{"foo": 0, "bar": "1"}] + + +class NonuniformlyValueValidatedDictTrait(HasTraits): + value = Dict(per_key_traits={"foo": Int()}, default_value={"foo": 1}) + + +class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase): + obj = NonuniformlyValueValidatedDictTrait() + + _default_value = {"foo": 1} + _good_values = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": 1}] + _bad_values = [{"foo": "0", "bar": "1"}] + + +class KeyValidatedDictTrait(HasTraits): + value = Dict(key_trait=Unicode(), default_value={"foo": "1"}) + + +class TestInstanceKeyValidatedDict(TraitTestBase): + obj = KeyValidatedDictTrait() + + _default_value = {"foo": "1"} + _good_values = [{"foo": "0", "bar": "1"}] + _bad_values = [{"foo": "0", 0: "1"}] + + +class FullyValidatedDictTrait(HasTraits): + value = Dict( + value_trait=Unicode(), + key_trait=Unicode(), + per_key_traits={"foo": Int()}, + default_value={"foo": 1}, + ) + + +class TestInstanceFullyValidatedDict(TraitTestBase): + obj = FullyValidatedDictTrait() + + _default_value = {"foo": 1} + _good_values = [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}] + _bad_values = [{"foo": 0, "bar": 1}, {"foo": "0", "bar": "1"}, {"foo": 0, 0: "1"}] + + +def test_dict_default_value(): + """Check that the `{}` default value of the Dict traitlet constructor is + actually copied.""" + + class Foo(HasTraits): + d1 = Dict() + d2 = Dict() + + foo = Foo() + assert foo.d1 == {} + assert foo.d2 == {} + assert foo.d1 is not foo.d2 + + +class TestValidationHook(TestCase): + def test_parity_trait(self): + """Verify that the early validation hook is effective""" + + class Parity(HasTraits): + value = Int(0) + parity = Enum(["odd", "even"], default_value="even") + + @validate("value") + def _value_validate(self, proposal): + value = proposal["value"] + if self.parity == "even" and value % 2: + raise TraitError("Expected an even number") + if self.parity == "odd" and (value % 2 == 0): + raise TraitError("Expected an odd number") + return value + + u = Parity() + u.parity = "odd" + u.value = 1 # OK + with self.assertRaises(TraitError): + u.value = 2 # Trait Error + + u.parity = "even" + u.value = 2 # OK + + def test_multiple_validate(self): + """Verify that we can register the same validator to multiple names""" + + class OddEven(HasTraits): + odd = Int(1) + even = Int(0) + + @validate("odd", "even") + def check_valid(self, proposal): + if proposal["trait"].name == "odd" and not proposal["value"] % 2: + raise TraitError("odd should be odd") + if proposal["trait"].name == "even" and proposal["value"] % 2: + raise TraitError("even should be even") + + u = OddEven() + u.odd = 3 # OK + with self.assertRaises(TraitError): + u.odd = 2 # Trait Error + + u.even = 2 # OK + with self.assertRaises(TraitError): + u.even = 3 # Trait Error + + def test_validate_used(self): + """Verify that the validate value is being used""" + + class FixedValue(HasTraits): + value = Int(0) + + @validate("value") + def _value_validate(self, proposal): + return -1 + + u = FixedValue(value=2) + assert u.value == -1 + + u = FixedValue() + u.value = 3 + assert u.value == -1 + + +class TestLink(TestCase): + def test_connect_same(self): + """Verify two traitlets of the same type can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "value")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.value) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.value) + b.value = 6 + self.assertEqual(a.value, b.value) + + def test_link_different(self): + """Verify two traitlets of different types can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "count")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.count) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.count) + b.count = 4 + self.assertEqual(a.value, b.count) + + def test_unlink_link(self): + """Verify two linked traitlets can be unlinked and relinked.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Connect the two classes. + c = link((a, "value"), (b, "value")) + a.value = 4 + c.unlink() + + # Change one of the values to make sure they don't stay in sync. + a.value = 5 + self.assertNotEqual(a.value, b.value) + c.link() + self.assertEqual(a.value, b.value) + a.value += 1 + self.assertEqual(a.value, b.value) + + def test_callbacks(self): + """Verify two linked traitlets have their callbacks called once.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Register callbacks that count. + callback_count = [] + + def a_callback(name, old, new): + callback_count.append("a") + + a.on_trait_change(a_callback, "value") + + def b_callback(name, old, new): + callback_count.append("b") + + b.on_trait_change(b_callback, "count") + + # Connect the two classes. + c = link((a, "value"), (b, "count")) + + # Make sure b's count was set to a's value once. + self.assertEqual("".join(callback_count), "b") + del callback_count[:] + + # Make sure a's value was set to b's count once. + b.count = 5 + self.assertEqual("".join(callback_count), "ba") + del callback_count[:] + + # Make sure b's count was set to a's value once. + a.value = 4 + self.assertEqual("".join(callback_count), "ab") + del callback_count[:] + + def test_tranform(self): + """Test transform link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "value"), transform=(lambda x: 2 * x, lambda x: int(x / 2.0))) + + # Make sure the values are correct at the point of linking. + self.assertEqual(b.value, 2 * a.value) + + # Change one the value of the source and check that it modifies the target. + a.value = 5 + self.assertEqual(b.value, 10) + # Change one the value of the target and check that it modifies the + # source. + b.value = 6 + self.assertEqual(a.value, 3) + + def test_link_broken_at_source(self): + class MyClass(HasTraits): + i = Int() + j = Int() + + @observe("j") + def another_update(self, change): + self.i = change.new * 2 + + mc = MyClass() + l = link((mc, "i"), (mc, "j")) # noqa + self.assertRaises(TraitError, setattr, mc, "i", 2) + + def test_link_broken_at_target(self): + class MyClass(HasTraits): + i = Int() + j = Int() + + @observe("i") + def another_update(self, change): + self.j = change.new * 2 + + mc = MyClass() + l = link((mc, "i"), (mc, "j")) # noqa + self.assertRaises(TraitError, setattr, mc, "j", 2) + + +class TestDirectionalLink(TestCase): + def test_connect_same(self): + """Verify two traitlets of the same type can be linked together using directional_link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "value")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.value) + + # Change one the value of the source and check that it synchronizes the target. + a.value = 5 + self.assertEqual(b.value, 5) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 + self.assertEqual(a.value, 5) + + def test_tranform(self): + """Test transform link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "value"), lambda x: 2 * x) + + # Make sure the values are correct at the point of linking. + self.assertEqual(b.value, 2 * a.value) + + # Change one the value of the source and check that it modifies the target. + a.value = 5 + self.assertEqual(b.value, 10) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 + self.assertEqual(a.value, 5) + + def test_link_different(self): + """Verify two traitlets of different types can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "count")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.count) + + # Change one the value of the source and check that it synchronizes the target. + a.value = 5 + self.assertEqual(b.count, 5) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 # type:ignore + self.assertEqual(a.value, 5) + + def test_unlink_link(self): + """Verify two linked traitlets can be unlinked and relinked.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Connect the two classes. + c = directional_link((a, "value"), (b, "value")) + a.value = 4 + c.unlink() + + # Change one of the values to make sure they don't stay in sync. + a.value = 5 + self.assertNotEqual(a.value, b.value) + c.link() + self.assertEqual(a.value, b.value) + a.value += 1 + self.assertEqual(a.value, b.value) + + +class Pickleable(HasTraits): + i = Int() + + @observe("i") + def _i_changed(self, change): + pass + + @validate("i") + def _i_validate(self, commit): + return commit["value"] + + j = Int() + + def __init__(self): + with self.hold_trait_notifications(): + self.i = 1 + self.on_trait_change(self._i_changed, "i") + + +def test_pickle_hastraits(): + c = Pickleable() + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(c, protocol) + c2 = pickle.loads(p) + assert c2.i == c.i + assert c2.j == c.j + + c.i = 5 + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(c, protocol) + c2 = pickle.loads(p) + assert c2.i == c.i + assert c2.j == c.j + + +def test_hold_trait_notifications(): + changes = [] + + class Test(HasTraits): + a = Integer(0) + b = Integer(0) + + def _a_changed(self, name, old, new): + changes.append((old, new)) + + def _b_validate(self, value, trait): + if value != 0: + raise TraitError("Only 0 is a valid value") + return value + + # Test context manager and nesting + t = Test() + with t.hold_trait_notifications(): + with t.hold_trait_notifications(): + t.a = 1 + assert t.a == 1 + assert changes == [] + t.a = 2 + assert t.a == 2 + with t.hold_trait_notifications(): + t.a = 3 + assert t.a == 3 + assert changes == [] + t.a = 4 + assert t.a == 4 + assert changes == [] + t.a = 4 + assert t.a == 4 + assert changes == [] + + assert changes == [(0, 4)] + # Test roll-back + try: + with t.hold_trait_notifications(): + t.b = 1 # raises a Trait error + except Exception: + pass + assert t.b == 0 + + +class RollBack(HasTraits): + bar = Int() + + def _bar_validate(self, value, trait): + if value: + raise TraitError("foobar") + return value + + +class TestRollback(TestCase): + def test_roll_back(self): + def assign_rollback(): + RollBack(bar=1) + + self.assertRaises(TraitError, assign_rollback) + + +class CacheModification(HasTraits): + foo = Int() + bar = Int() + + def _bar_validate(self, value, trait): + self.foo = value + return value + + def _foo_validate(self, value, trait): + self.bar = value + return value + + +def test_cache_modification(): + CacheModification(foo=1) + CacheModification(bar=1) + + +class OrderTraits(HasTraits): + notified = Dict() + + a = Unicode() + b = Unicode() + c = Unicode() + d = Unicode() + e = Unicode() + f = Unicode() + g = Unicode() + h = Unicode() + i = Unicode() + j = Unicode() + k = Unicode() + l = Unicode() # noqa + + def _notify(self, name, old, new): + """check the value of all traits when each trait change is triggered + + This verifies that the values are not sensitive + to dict ordering when loaded from kwargs + """ + # check the value of the other traits + # when a given trait change notification fires + self.notified[name] = {c: getattr(self, c) for c in "abcdefghijkl"} + + def __init__(self, **kwargs): + self.on_trait_change(self._notify) + super().__init__(**kwargs) + + +def test_notification_order(): + d = {c: c for c in "abcdefghijkl"} + obj = OrderTraits() + assert obj.notified == {} + obj = OrderTraits(**d) + notifications = {c: d for c in "abcdefghijkl"} + assert obj.notified == notifications + + +### +# Traits for Forward Declaration Tests +### +class ForwardDeclaredInstanceTrait(HasTraits): + value = ForwardDeclaredInstance["ForwardDeclaredBar"]("ForwardDeclaredBar", allow_none=True) + + +class ForwardDeclaredTypeTrait(HasTraits): + value = ForwardDeclaredType[t.Any, t.Any]("ForwardDeclaredBar", allow_none=True) + + +class ForwardDeclaredInstanceListTrait(HasTraits): + value = List(ForwardDeclaredInstance("ForwardDeclaredBar")) + + +class ForwardDeclaredTypeListTrait(HasTraits): + value = List(ForwardDeclaredType("ForwardDeclaredBar")) + + +### +# End Traits for Forward Declaration Tests +### + + +### +# Classes for Forward Declaration Tests +### +class ForwardDeclaredBar: + pass + + +class ForwardDeclaredBarSub(ForwardDeclaredBar): + pass + + +### +# End Classes for Forward Declaration Tests +### + + +### +# Forward Declaration Tests +### +class TestForwardDeclaredInstanceTrait(TraitTestBase): + obj = ForwardDeclaredInstanceTrait() + _default_value = None + _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()] + _bad_values = ["foo", 3, ForwardDeclaredBar, ForwardDeclaredBarSub] + + +class TestForwardDeclaredTypeTrait(TraitTestBase): + obj = ForwardDeclaredTypeTrait() + _default_value = None + _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub] + _bad_values = ["foo", 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()] + + +class TestForwardDeclaredInstanceList(TraitTestBase): + obj = ForwardDeclaredInstanceListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) + + _default_value: t.List[t.Any] = [] + _good_values = [ + [ForwardDeclaredBar(), ForwardDeclaredBarSub()], + [], + ] + _bad_values = [ + ForwardDeclaredBar(), + [ForwardDeclaredBar(), 3, None], + "1", + # Note that this is the type, not an instance. + [ForwardDeclaredBar], + [None], + None, + ] + + +class TestForwardDeclaredTypeList(TraitTestBase): + obj = ForwardDeclaredTypeListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) + + _default_value: t.List[t.Any] = [] + _good_values = [ + [ForwardDeclaredBar, ForwardDeclaredBarSub], + [], + ] + _bad_values = [ + ForwardDeclaredBar, + [ForwardDeclaredBar, 3], + "1", + # Note that this is an instance, not the type. + [ForwardDeclaredBar()], + [None], + None, + ] + + +### +# End Forward Declaration Tests +### + + +class TestDynamicTraits(TestCase): + def setUp(self): + self._notify1 = [] + + def notify1(self, name, old, new): + self._notify1.append((name, old, new)) + + @t.no_type_check + def test_notify_all(self): + class A(HasTraits): + pass + + a = A() + self.assertTrue(not hasattr(a, "x")) + self.assertTrue(not hasattr(a, "y")) + + # Dynamically add trait x. + a.add_traits(x=Int()) + self.assertTrue(hasattr(a, "x")) + self.assertTrue(isinstance(a, (A,))) + + # Dynamically add trait y. + a.add_traits(y=Float()) + self.assertTrue(hasattr(a, "y")) + self.assertTrue(isinstance(a, (A,))) + self.assertEqual(a.__class__.__name__, A.__name__) + + # Create a new instance and verify that x and y + # aren't defined. + b = A() + self.assertTrue(not hasattr(b, "x")) + self.assertTrue(not hasattr(b, "y")) + + # Verify that notification works like normal. + a.on_trait_change(self.notify1) + a.x = 0 + self.assertEqual(len(self._notify1), 0) + a.y = 0.0 + self.assertEqual(len(self._notify1), 0) + a.x = 10 + self.assertTrue(("x", 0, 10) in self._notify1) + a.y = 10.0 + self.assertTrue(("y", 0.0, 10.0) in self._notify1) + self.assertRaises(TraitError, setattr, a, "x", "bad string") + self.assertRaises(TraitError, setattr, a, "y", "bad string") + self._notify1 = [] + a.on_trait_change(self.notify1, remove=True) + a.x = 20 + a.y = 20.0 + self.assertEqual(len(self._notify1), 0) + + +def test_enum_no_default(): + class C(HasTraits): + t = Enum(["a", "b"]) + + c = C() + c.t = "a" + assert c.t == "a" + + c = C() + + with pytest.raises(TraitError): + t = c.t + + c = C(t="b") + assert c.t == "b" + + +def test_default_value_repr(): + class C(HasTraits): + t = Type("traitlets.HasTraits") + t2 = Type(HasTraits) + n = Integer(0) + lis = List() + d = Dict() + + assert C.t.default_value_repr() == "'traitlets.HasTraits'" + assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'" + assert C.n.default_value_repr() == "0" + assert C.lis.default_value_repr() == "[]" + assert C.d.default_value_repr() == "{}" + + +class TransitionalClass(HasTraits): + d = Any() + + @default("d") + def _d_default(self): + return TransitionalClass + + parent_super = False + calls_super = Integer(0) + + @default("calls_super") + def _calls_super_default(self): + return -1 + + @observe("calls_super") + @observe_compat + def _calls_super_changed(self, change): + self.parent_super = change + + parent_override = False + overrides = Integer(0) + + @observe("overrides") + @observe_compat + def _overrides_changed(self, change): + self.parent_override = change + + +class SubClass(TransitionalClass): + def _d_default(self): + return SubClass + + subclass_super = False + + def _calls_super_changed(self, name, old, new): + self.subclass_super = True + super()._calls_super_changed(name, old, new) + + subclass_override = False + + def _overrides_changed(self, name, old, new): + self.subclass_override = True + + +def test_subclass_compat(): + obj = SubClass() + obj.calls_super = 5 + assert obj.parent_super + assert obj.subclass_super + obj.overrides = 5 + assert obj.subclass_override + assert not obj.parent_override + assert obj.d is SubClass + + +class DefinesHandler(HasTraits): + parent_called = False + + trait = Integer() + + @observe("trait") + def handler(self, change): + self.parent_called = True + + +class OverridesHandler(DefinesHandler): + child_called = False + + @observe("trait") + def handler(self, change): + self.child_called = True + + +def test_subclass_override_observer(): + obj = OverridesHandler() + obj.trait = 5 + assert obj.child_called + assert not obj.parent_called + + +class DoesntRegisterHandler(DefinesHandler): + child_called = False + + def handler(self, change): + self.child_called = True + + +def test_subclass_override_not_registered(): + """Subclass that overrides observer and doesn't re-register unregisters both""" + obj = DoesntRegisterHandler() + obj.trait = 5 + assert not obj.child_called + assert not obj.parent_called + + +class AddsHandler(DefinesHandler): + child_called = False + + @observe("trait") + def child_handler(self, change): + self.child_called = True + + +def test_subclass_add_observer(): + obj = AddsHandler() + obj.trait = 5 + assert obj.child_called + assert obj.parent_called + + +def test_observe_iterables(): + class C(HasTraits): + i = Integer() + s = Unicode() + + c = C() + recorded = {} + + def record(change): + recorded["change"] = change + + # observe with names=set + c.observe(record, names={"i", "s"}) + c.i = 5 + assert recorded["change"].name == "i" + assert recorded["change"].new == 5 + c.s = "hi" + assert recorded["change"].name == "s" + assert recorded["change"].new == "hi" + + # observe with names=custom container with iter, contains + class MyContainer: + def __init__(self, container): + self.container = container + + def __iter__(self): + return iter(self.container) + + def __contains__(self, key): + return key in self.container + + c.observe(record, names=MyContainer({"i", "s"})) + c.i = 10 + assert recorded["change"].name == "i" + assert recorded["change"].new == 10 + c.s = "ok" + assert recorded["change"].name == "s" + assert recorded["change"].new == "ok" + + +def test_super_args(): + class SuperRecorder: + def __init__(self, *args, **kwargs): + self.super_args = args + self.super_kwargs = kwargs + + class SuperHasTraits(HasTraits, SuperRecorder): + i = Integer() + + obj = SuperHasTraits("a1", "a2", b=10, i=5, c="x") + assert obj.i == 5 + assert not hasattr(obj, "b") + assert not hasattr(obj, "c") + assert obj.super_args == ("a1", "a2") + assert obj.super_kwargs == {"b": 10, "c": "x"} + + +def test_super_bad_args(): + class SuperHasTraits(HasTraits): + a = Integer() + + w = ["Passing unrecognized arguments"] + with expected_warnings(w): + obj = SuperHasTraits(a=1, b=2) + assert obj.a == 1 + assert not hasattr(obj, "b") + + +def test_default_mro(): + """Verify that default values follow mro""" + + class Base(HasTraits): + trait = Unicode("base") + attr = "base" + + class A(Base): + pass + + class B(Base): + trait = Unicode("B") + attr = "B" + + class AB(A, B): + pass + + class BA(B, A): + pass + + assert A().trait == "base" + assert A().attr == "base" + assert BA().trait == "B" + assert BA().attr == "B" + assert AB().trait == "B" + assert AB().attr == "B" + + +def test_cls_self_argument(): + class X(HasTraits): + def __init__(__self, cls, self): # noqa + pass + + x = X(cls=None, self=None) + + +def test_override_default(): + class C(HasTraits): + a = Unicode("hard default") + + def _a_default(self): + return "default method" + + C._a_default = lambda self: "overridden" # type:ignore + c = C() + assert c.a == "overridden" + + +def test_override_default_decorator(): + class C(HasTraits): + a = Unicode("hard default") + + @default("a") + def _a_default(self): + return "default method" + + C._a_default = lambda self: "overridden" # type:ignore + c = C() + assert c.a == "overridden" + + +def test_override_default_instance(): + class C(HasTraits): + a = Unicode("hard default") + + @default("a") + def _a_default(self): + return "default method" + + c = C() + c._a_default = lambda self: "overridden" + assert c.a == "overridden" + + +def test_copy_HasTraits(): + from copy import copy + + class C(HasTraits): + a = Int() + + c = C(a=1) + assert c.a == 1 + + cc = copy(c) + cc.a = 2 + assert cc.a == 2 + assert c.a == 1 + + +def _from_string_test(traittype, s, expected): + """Run a test of trait.from_string""" + if isinstance(traittype, TraitType): + trait = traittype + else: + trait = traittype(allow_none=True) + if isinstance(s, list): + cast = trait.from_string_list # type:ignore + else: + cast = trait.from_string + if type(expected) is type and issubclass(expected, Exception): + with pytest.raises(expected): + value = cast(s) + trait.validate(CrossValidationStub(), value) # type:ignore + else: + value = cast(s) + assert value == expected + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], +) +def test_unicode_from_string(s, expected): + _from_string_test(Unicode, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], +) +def test_cunicode_from_string(s, expected): + _from_string_test(CUnicode, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], +) +def test_bytes_from_string(s, expected): + _from_string_test(Bytes, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], +) +def test_cbytes_from_string(s, expected): + _from_string_test(CBytes, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)], +) +def test_int_from_string(s, expected): + _from_string_test(Integer, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)], +) +def test_float_from_string(s, expected): + _from_string_test(Float, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("x", ValueError), + ("1", 1.0), + ("123.5", 123.5), + ("2.5", 2.5), + ("1+2j", 1 + 2j), + ("None", None), + ], +) +def test_complex_from_string(s, expected): + _from_string_test(Complex, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("true", True), + ("TRUE", True), + ("1", True), + ("0", False), + ("False", False), + ("false", False), + ("1.0", ValueError), + ("None", None), + ], +) +def test_bool_from_string(s, expected): + _from_string_test(Bool, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("{}", {}), + ("1", TraitError), + ("{1: 2}", {1: 2}), + ('{"key": "value"}', {"key": "value"}), + ("x", TraitError), + ("None", None), + ], +) +def test_dict_from_string(s, expected): + _from_string_test(Dict, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("[]", []), + ('[1, 2, "x"]', [1, 2, "x"]), + (["1", "x"], ["1", "x"]), + (["None"], None), + ], +) +def test_list_from_string(s, expected): + _from_string_test(List, s, expected) + + +@pytest.mark.parametrize( + "s, expected, value_trait", + [ + (["1", "2", "3"], [1, 2, 3], Integer()), + (["x"], ValueError, Integer()), + (["1", "x"], ["1", "x"], Unicode()), + (["None"], [None], Unicode(allow_none=True)), + (["None"], ["None"], Unicode(allow_none=False)), + ], +) +def test_list_items_from_string(s, expected, value_trait): + _from_string_test(List(value_trait), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("[]", set()), + ('[1, 2, "x"]', {1, 2, "x"}), + ('{1, 2, "x"}', {1, 2, "x"}), + (["1", "x"], {"1", "x"}), + (["None"], None), + ], +) +def test_set_from_string(s, expected): + _from_string_test(Set, s, expected) + + +@pytest.mark.parametrize( + "s, expected, value_trait", + [ + (["1", "2", "3"], {1, 2, 3}, Integer()), + (["x"], ValueError, Integer()), + (["1", "x"], {"1", "x"}, Unicode()), + (["None"], {None}, Unicode(allow_none=True)), + ], +) +def test_set_items_from_string(s, expected, value_trait): + _from_string_test(Set(value_trait), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("[]", ()), + ("()", ()), + ('[1, 2, "x"]', (1, 2, "x")), + ('(1, 2, "x")', (1, 2, "x")), + (["1", "x"], ("1", "x")), + (["None"], None), + ], +) +def test_tuple_from_string(s, expected): + _from_string_test(Tuple, s, expected) + + +@pytest.mark.parametrize( + "s, expected, value_traits", + [ + (["1", "2", "3"], (1, 2, 3), [Integer(), Integer(), Integer()]), + (["x"], ValueError, [Integer()]), + (["1", "x"], ("1", "x"), [Unicode()]), + (["None"], ("None",), [Unicode(allow_none=False)]), + (["None"], (None,), [Unicode(allow_none=True)]), + ], +) +def test_tuple_items_from_string(s, expected, value_traits): + _from_string_test(Tuple(*value_traits), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("x", "x"), + ("mod.submod", "mod.submod"), + ("not an identifier", TraitError), + ("1", "1"), + ("None", None), + ], +) +def test_object_from_string(s, expected): + _from_string_test(DottedObjectName, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("127.0.0.1:8000", ("127.0.0.1", 8000)), + ("host.tld:80", ("host.tld", 80)), + ("host:notaport", ValueError), + ("127.0.0.1", ValueError), + ("None", None), + ], +) +def test_tcp_from_string(s, expected): + _from_string_test(TCPAddress, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("[]", []), ("{}", "{}")], +) +def test_union_of_list_and_unicode_from_string(s, expected): + _from_string_test(Union([List(), Unicode()]), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("1", 1), ("1.5", 1.5)], +) +def test_union_of_int_and_float_from_string(s, expected): + _from_string_test(Union([Int(), Float()]), s, expected) + + +@pytest.mark.parametrize( + "s, expected, allow_none", + [("[]", [], False), ("{}", {}, False), ("None", TraitError, False), ("None", None, True)], +) +def test_union_of_list_and_dict_from_string(s, expected, allow_none): + _from_string_test(Union([List(), Dict()], allow_none=allow_none), s, expected) + + +def test_all_attribute(): + """Verify all trait types are added to `traitlets.__all__`""" + names = dir(traitlets) + for name in names: + value = getattr(traitlets, name) + if not name.startswith("_") and isinstance(value, type) and issubclass(value, TraitType): + if name not in traitlets.__all__: + raise ValueError(f"{name} not in __all__") + + for name in traitlets.__all__: + if name not in names: + raise ValueError(f"{name} should be removed from __all__") diff --git a/contrib/python/traitlets/py3/tests/test_traitlets_docstring.py b/contrib/python/traitlets/py3/tests/test_traitlets_docstring.py new file mode 100644 index 00000000000..700199108f1 --- /dev/null +++ b/contrib/python/traitlets/py3/tests/test_traitlets_docstring.py @@ -0,0 +1,84 @@ +"""Tests for traitlets.traitlets.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +# +from traitlets import Dict, Instance, Integer, Unicode, Union +from traitlets.config import Configurable + + +def test_handle_docstring(): + class SampleConfigurable(Configurable): + pass + + class TraitTypesSampleConfigurable(Configurable): + """TraitTypesSampleConfigurable docstring""" + + trait_integer = Integer( + help="""trait_integer help text""", + config=True, + ) + trait_integer_nohelp = Integer( + config=True, + ) + trait_integer_noconfig = Integer( + help="""trait_integer_noconfig help text""", + ) + + trait_unicode = Unicode( + help="""trait_unicode help text""", + config=True, + ) + trait_unicode_nohelp = Unicode( + config=True, + ) + trait_unicode_noconfig = Unicode( + help="""trait_unicode_noconfig help text""", + ) + + trait_dict = Dict( + help="""trait_dict help text""", + config=True, + ) + trait_dict_nohelp = Dict( + config=True, + ) + trait_dict_noconfig = Dict( + help="""trait_dict_noconfig help text""", + ) + + trait_instance = Instance( + klass=SampleConfigurable, + help="""trait_instance help text""", + config=True, + ) + trait_instance_nohelp = Instance( + klass=SampleConfigurable, + config=True, + ) + trait_instance_noconfig = Instance( + klass=SampleConfigurable, + help="""trait_instance_noconfig help text""", + ) + + trait_union = Union( + [Integer(), Unicode()], + help="""trait_union help text""", + config=True, + ) + trait_union_nohelp = Union( + [Integer(), Unicode()], + config=True, + ) + trait_union_noconfig = Union( + [Integer(), Unicode()], + help="""trait_union_noconfig help text""", + ) + + base_names = SampleConfigurable().trait_names() + for name in TraitTypesSampleConfigurable().trait_names(): + if name in base_names: + continue + doc = getattr(TraitTypesSampleConfigurable, name).__doc__ + if "nohelp" not in name: + assert doc == f"{name} help text" diff --git a/contrib/python/traitlets/py3/tests/test_traitlets_enum.py b/contrib/python/traitlets/py3/tests/test_traitlets_enum.py new file mode 100644 index 00000000000..c39007e8a05 --- /dev/null +++ b/contrib/python/traitlets/py3/tests/test_traitlets_enum.py @@ -0,0 +1,380 @@ +# pylint: disable=missing-docstring, too-few-public-methods +""" +Test the trait-type ``UseEnum``. +""" + +import enum +import unittest + +from traitlets import CaselessStrEnum, Enum, FuzzyEnum, HasTraits, TraitError, UseEnum + +# ----------------------------------------------------------------------------- +# TEST SUPPORT: +# ----------------------------------------------------------------------------- + + +class Color(enum.Enum): + red = 1 + green = 2 + blue = 3 + yellow = 4 + + +class OtherColor(enum.Enum): + red = 0 + green = 1 + + +class CSColor(enum.Enum): + red = 1 + Green = 2 + BLUE = 3 + YeLLoW = 4 + + +color_choices = "red Green BLUE YeLLoW".split() + + +# ----------------------------------------------------------------------------- +# TESTSUITE: +# ----------------------------------------------------------------------------- +class TestUseEnum(unittest.TestCase): + # pylint: disable=invalid-name + + class Example(HasTraits): + color = UseEnum(Color, help="Color enum") + + def test_assign_enum_value(self): + example = self.Example() + example.color = Color.green + self.assertEqual(example.color, Color.green) + + def test_assign_all_enum_values(self): + # pylint: disable=no-member + enum_values = list(Color.__members__.values()) + for value in enum_values: + self.assertIsInstance(value, Color) + example = self.Example() + example.color = value + self.assertEqual(example.color, value) + self.assertIsInstance(value, Color) + + def test_assign_enum_value__with_other_enum_raises_error(self): + example = self.Example() + with self.assertRaises(TraitError): + example.color = OtherColor.green + + def test_assign_enum_name_1(self): + # -- CONVERT: string => Enum value (item) + example = self.Example() + example.color = "red" + self.assertEqual(example.color, Color.red) + + def test_assign_enum_value_name(self): + # -- CONVERT: string => Enum value (item) + # pylint: disable=no-member + enum_names = [enum_val.name for enum_val in Color.__members__.values()] + for value in enum_names: + self.assertIsInstance(value, str) + example = self.Example() + enum_value = Color.__members__.get(value) + example.color = value + self.assertIs(example.color, enum_value) + self.assertEqual(example.color.name, value) # type:ignore + + def test_assign_scoped_enum_value_name(self): + # -- CONVERT: string => Enum value (item) + scoped_names = ["Color.red", "Color.green", "Color.blue", "Color.yellow"] + for value in scoped_names: + example = self.Example() + example.color = value + self.assertIsInstance(example.color, Color) + self.assertEqual(str(example.color), value) + + def test_assign_bad_enum_value_name__raises_error(self): + # -- CONVERT: string => Enum value (item) + bad_enum_names = ["UNKNOWN_COLOR", "RED", "Green", "blue2"] + for value in bad_enum_names: + example = self.Example() + with self.assertRaises(TraitError): + example.color = value + + def test_assign_enum_value_number_1(self): + # -- CONVERT: number => Enum value (item) + example = self.Example() + example.color = 1 # == Color.red.value + example.color = Color.red.value + self.assertEqual(example.color, Color.red) + + def test_assign_enum_value_number(self): + # -- CONVERT: number => Enum value (item) + # pylint: disable=no-member + enum_numbers = [enum_val.value for enum_val in Color.__members__.values()] + for value in enum_numbers: + self.assertIsInstance(value, int) + example = self.Example() + example.color = value + self.assertIsInstance(example.color, Color) + self.assertEqual(example.color.value, value) # type:ignore + + def test_assign_bad_enum_value_number__raises_error(self): + # -- CONVERT: number => Enum value (item) + bad_numbers = [-1, 0, 5] + for value in bad_numbers: + self.assertIsInstance(value, int) + assert UseEnum(Color).select_by_number(value, None) is None + example = self.Example() + with self.assertRaises(TraitError): + example.color = value + + def test_ctor_without_default_value(self): + # -- IMPLICIT: default_value = Color.red (first enum-value) + class Example2(HasTraits): + color = UseEnum(Color) + + example = Example2() + self.assertEqual(example.color, Color.red) + + def test_ctor_with_default_value_as_enum_value(self): + # -- CONVERT: number => Enum value (item) + class Example2(HasTraits): + color = UseEnum(Color, default_value=Color.green) + + example = Example2() + self.assertEqual(example.color, Color.green) + + def test_ctor_with_default_value_none_and_not_allow_none(self): + # -- IMPLICIT: default_value = Color.red (first enum-value) + class Example2(HasTraits): + color1 = UseEnum(Color, default_value=None, allow_none=False) + color2 = UseEnum(Color, default_value=None) + + example = Example2() + self.assertEqual(example.color1, Color.red) + self.assertEqual(example.color2, Color.red) + + def test_ctor_with_default_value_none_and_allow_none(self): + class Example2(HasTraits): + color1 = UseEnum(Color, default_value=None, allow_none=True) + color2 = UseEnum(Color, allow_none=True) + + example = Example2() + self.assertIs(example.color1, None) + self.assertIs(example.color2, None) + + def test_assign_none_without_allow_none_resets_to_default_value(self): + class Example2(HasTraits): + color1 = UseEnum(Color, allow_none=False) + color2 = UseEnum(Color) + + example = Example2() + example.color1 = None + example.color2 = None + self.assertIs(example.color1, Color.red) + self.assertIs(example.color2, Color.red) + + def test_assign_none_to_enum_or_none(self): + class Example2(HasTraits): + color = UseEnum(Color, allow_none=True) + + example = Example2() + example.color = None + self.assertIs(example.color, None) + + def test_assign_bad_value_with_to_enum_or_none(self): + class Example2(HasTraits): + color = UseEnum(Color, allow_none=True) + + example = Example2() + with self.assertRaises(TraitError): + example.color = "BAD_VALUE" + + def test_info(self): + choices = color_choices + + class Example(HasTraits): + enum1 = Enum(choices, allow_none=False) + enum2 = CaselessStrEnum(choices, allow_none=False) + enum3 = FuzzyEnum(choices, allow_none=False) + enum4 = UseEnum(CSColor, allow_none=False) + + for i in range(1, 5): + attr = "enum%s" % i + enum = getattr(Example, attr) + + enum.allow_none = True + + info = enum.info() + self.assertEqual(len(info.split(", ")), len(choices), info.split(", ")) + self.assertIn("or None", info) + + info = enum.info_rst() + self.assertEqual(len(info.split("|")), len(choices), info.split("|")) + self.assertIn("or `None`", info) + # Check no single `\` exists. + self.assertNotRegex(info, r"\b\\\b") + + enum.allow_none = False + + info = enum.info() + self.assertEqual(len(info.split(", ")), len(choices), info.split(", ")) + self.assertNotIn("None", info) + + info = enum.info_rst() + self.assertEqual(len(info.split("|")), len(choices), info.split("|")) + self.assertNotIn("None", info) + # Check no single `\` exists. + self.assertNotRegex(info, r"\b\\\b") + + +# ----------------------------------------------------------------------------- +# TESTSUITE: +# ----------------------------------------------------------------------------- + + +class TestFuzzyEnum(unittest.TestCase): + # Check mostly `validate()`, Ctor must be checked on generic `Enum` + # or `CaselessStrEnum`. + + def test_search_all_prefixes__overwrite(self): + class FuzzyExample(HasTraits): + color = FuzzyEnum(color_choices, help="Color enum") + + example = FuzzyExample() + for color in color_choices: + for wlen in range(1, len(color)): + value = color[:wlen] + + example.color = value + self.assertEqual(example.color, color) + + example.color = value.upper() + self.assertEqual(example.color, color) + + example.color = value.lower() + self.assertEqual(example.color, color) + + def test_search_all_prefixes__ctor(self): + class FuzzyExample(HasTraits): + color = FuzzyEnum(color_choices, help="Color enum") + + for color in color_choices: + for wlen in range(1, len(color)): + value = color[:wlen] + + example = FuzzyExample() + example.color = value + self.assertEqual(example.color, color) + + example = FuzzyExample() + example.color = value.upper() + self.assertEqual(example.color, color) + + example = FuzzyExample() + example.color = value.lower() + self.assertEqual(example.color, color) + + def test_search_substrings__overwrite(self): + class FuzzyExample(HasTraits): + color = FuzzyEnum(color_choices, help="Color enum", substring_matching=True) + + example = FuzzyExample() + for color in color_choices: + for wlen in range(0, 2): + value = color[wlen:] + + example.color = value + self.assertEqual(example.color, color) + + example.color = value.upper() + self.assertEqual(example.color, color) + + example.color = value.lower() + self.assertEqual(example.color, color) + + def test_search_substrings__ctor(self): + class FuzzyExample(HasTraits): + color = FuzzyEnum(color_choices, help="Color enum", substring_matching=True) + + color = color_choices[-1] # 'YeLLoW' + for end in (-1, len(color)): + for start in range(1, len(color) - 2): + value = color[start:end] + + example = FuzzyExample() + example.color = value + self.assertEqual(example.color, color) + + example = FuzzyExample() + example.color = value.upper() + self.assertEqual(example.color, color) + + def test_assign_other_raises(self): + def new_trait_class(case_sensitive, substring_matching): + class Example(HasTraits): + color = FuzzyEnum( + color_choices, + case_sensitive=case_sensitive, + substring_matching=substring_matching, + ) + + return Example + + example = new_trait_class(case_sensitive=False, substring_matching=False)() + with self.assertRaises(TraitError): + example.color = "" + with self.assertRaises(TraitError): + example.color = "BAD COLOR" + with self.assertRaises(TraitError): + example.color = "ed" + + example = new_trait_class(case_sensitive=True, substring_matching=False)() + with self.assertRaises(TraitError): + example.color = "" + with self.assertRaises(TraitError): + example.color = "Red" # not 'red' + + example = new_trait_class(case_sensitive=True, substring_matching=True)() + with self.assertRaises(TraitError): + example.color = "" + with self.assertRaises(TraitError): + example.color = "BAD COLOR" + with self.assertRaises(TraitError): + example.color = "green" # not 'Green' + with self.assertRaises(TraitError): + example.color = "lue" # not (b)'LUE' + with self.assertRaises(TraitError): + example.color = "lUE" # not (b)'LUE' + + example = new_trait_class(case_sensitive=False, substring_matching=True)() + with self.assertRaises(TraitError): + example.color = "" + with self.assertRaises(TraitError): + example.color = "BAD COLOR" + + def test_ctor_with_default_value(self): + def new_trait_class(default_value, case_sensitive, substring_matching): + class Example(HasTraits): + color = FuzzyEnum( + color_choices, + default_value=default_value, + case_sensitive=case_sensitive, + substring_matching=substring_matching, + ) + + return Example + + for color in color_choices: + example = new_trait_class(color, False, False)() + self.assertEqual(example.color, color) + + example = new_trait_class(color.upper(), False, False)() + self.assertEqual(example.color, color) + + color = color_choices[-1] # 'YeLLoW' + example = new_trait_class(color, True, False)() + self.assertEqual(example.color, color) + + # FIXME: default value not validated! + # with self.assertRaises(TraitError): + # example = new_trait_class(color.lower(), True, False) diff --git a/contrib/python/traitlets/py3/tests/test_typing.py b/contrib/python/traitlets/py3/tests/test_typing.py new file mode 100644 index 00000000000..2b4073ecf72 --- /dev/null +++ b/contrib/python/traitlets/py3/tests/test_typing.py @@ -0,0 +1,395 @@ +from __future__ import annotations + +import typing + +import pytest + +from traitlets import ( + Any, + Bool, + CInt, + Dict, + HasTraits, + Instance, + Int, + List, + Set, + TCPAddress, + Type, + Unicode, + Union, + default, + observe, + validate, +) +from traitlets.config import Config + +if not typing.TYPE_CHECKING: + + def reveal_type(*args, **kwargs): + pass + + +# mypy: disallow-untyped-calls + + +class Foo: + def __init__(self, c): + self.c = c + + +@pytest.mark.mypy_testing +def mypy_decorator_typing(): + class T(HasTraits): + foo = Unicode("").tag(config=True) + + @default("foo") + def _default_foo(self) -> str: + return "hi" + + @observe("foo") + def _foo_observer(self, change: typing.Any) -> bool: + return True + + @validate("foo") + def _foo_validate(self, commit: typing.Any) -> bool: + return True + + t = T() + reveal_type(t.foo) # R: builtins.str + reveal_type(t._foo_observer) # R: Any + reveal_type(t._foo_validate) # R: Any + + +@pytest.mark.mypy_testing +def mypy_config_typing(): + c = Config( + { + "ExtractOutputPreprocessor": {"enabled": True}, + } + ) + reveal_type(c) # R: traitlets.config.loader.Config + + +@pytest.mark.mypy_testing +def mypy_union_typing(): + class T(HasTraits): + style = Union( + [Unicode("default"), Type(klass=object)], + help="Name of the pygments style to use", + default_value="hi", + ).tag(config=True) + + t = T() + reveal_type(Union("foo")) # R: traitlets.traitlets.Union + reveal_type(Union("").tag(sync=True)) # R: traitlets.traitlets.Union + reveal_type(Union(None, allow_none=True)) # R: traitlets.traitlets.Union + reveal_type(Union(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Union + reveal_type(T.style) # R: traitlets.traitlets.Union + reveal_type(t.style) # R: Any + + +@pytest.mark.mypy_testing +def mypy_list_typing(): + class T(HasTraits): + latex_command = List( + ["xelatex", "{filename}", "-quiet"], help="Shell command used to compile latex." + ).tag(config=True) + + t = T() + reveal_type(List("foo")) # R: traitlets.traitlets.List + reveal_type(List("").tag(sync=True)) # R: traitlets.traitlets.List + reveal_type(List(None, allow_none=True)) # R: traitlets.traitlets.List + reveal_type(List(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.List + reveal_type(T.latex_command) # R: traitlets.traitlets.List + reveal_type(t.latex_command) # R: builtins.list[Any] + + +@pytest.mark.mypy_testing +def mypy_dict_typing(): + class T(HasTraits): + foo = Dict({}, help="Shell command used to compile latex.").tag(config=True) + + t = T() + reveal_type(Dict("foo")) # R: traitlets.traitlets.Dict + reveal_type(Dict("").tag(sync=True)) # R: traitlets.traitlets.Dict + reveal_type(Dict(None, allow_none=True)) # R: traitlets.traitlets.Dict + reveal_type(Dict(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Dict + reveal_type(T.foo) # R: traitlets.traitlets.Dict + reveal_type(t.foo) # R: builtins.dict[Any, Any] + + +@pytest.mark.mypy_testing +def mypy_type_typing(): + class KernelSpec: + item = Unicode("foo") + + class KernelSpecManager(HasTraits): + """A manager for kernel specs.""" + + kernel_spec_class = Type( + KernelSpec, + config=True, + help="""The kernel spec class. This is configurable to allow + subclassing of the KernelSpecManager for customized behavior. + """, + ) + other_class = Type("foo.bar.baz") + + t = KernelSpecManager() + reveal_type(t.kernel_spec_class) # R: def () -> tests.test_typing.KernelSpec@124 + reveal_type(t.kernel_spec_class()) # R: tests.test_typing.KernelSpec@124 + reveal_type(t.kernel_spec_class().item) # R: builtins.str + reveal_type(t.other_class) # R: builtins.type + reveal_type(t.other_class()) # R: Any + + +@pytest.mark.mypy_testing +def mypy_unicode_typing(): + class T(HasTraits): + export_format = Unicode( + allow_none=False, + help="""The export format to be used, either one of the built-in formats + or a dotted object name that represents the import path for an + ``Exporter`` class""", + ).tag(config=True) + + t = T() + reveal_type( + Unicode( # R: traitlets.traitlets.Unicode[builtins.str, Union[builtins.str, builtins.bytes]] + "foo" + ) + ) + reveal_type( + Unicode( # R: traitlets.traitlets.Unicode[builtins.str, Union[builtins.str, builtins.bytes]] + "" + ).tag( + sync=True + ) + ) + reveal_type( + Unicode( # R: traitlets.traitlets.Unicode[Union[builtins.str, None], Union[builtins.str, builtins.bytes, None]] + None, allow_none=True + ) + ) + reveal_type( + Unicode( # R: traitlets.traitlets.Unicode[Union[builtins.str, None], Union[builtins.str, builtins.bytes, None]] + None, allow_none=True + ).tag( + sync=True + ) + ) + reveal_type( + T.export_format # R: traitlets.traitlets.Unicode[builtins.str, Union[builtins.str, builtins.bytes]] + ) + reveal_type(t.export_format) # R: builtins.str + + +@pytest.mark.mypy_testing +def mypy_set_typing(): + class T(HasTraits): + remove_cell_tags = Set( + Unicode(), + default_value=[], + help=( + "Tags indicating which cells are to be removed," + "matches tags in ``cell.metadata.tags``." + ), + ).tag(config=True) + + safe_output_keys = Set( + config=True, + default_value={ + "metadata", # Not a mimetype per-se, but expected and safe. + "text/plain", + "text/latex", + "application/json", + "image/png", + "image/jpeg", + }, + help="Cell output mimetypes to render without modification", + ) + + t = T() + reveal_type(Set("foo")) # R: traitlets.traitlets.Set + reveal_type(Set("").tag(sync=True)) # R: traitlets.traitlets.Set + reveal_type(Set(None, allow_none=True)) # R: traitlets.traitlets.Set + reveal_type(Set(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Set + reveal_type(T.remove_cell_tags) # R: traitlets.traitlets.Set + reveal_type(t.remove_cell_tags) # R: builtins.set[Any] + reveal_type(T.safe_output_keys) # R: traitlets.traitlets.Set + reveal_type(t.safe_output_keys) # R: builtins.set[Any] + + +@pytest.mark.mypy_testing +def mypy_any_typing(): + class T(HasTraits): + attributes = Any( + config=True, + default_value={ + "a": ["href", "title"], + "abbr": ["title"], + "acronym": ["title"], + }, + help="Allowed HTML tag attributes", + ) + + t = T() + reveal_type(Any("foo")) # R: traitlets.traitlets.Any + reveal_type(Any("").tag(sync=True)) # R: traitlets.traitlets.Any + reveal_type(Any(None, allow_none=True)) # R: traitlets.traitlets.Any + reveal_type(Any(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Any + reveal_type(T.attributes) # R: traitlets.traitlets.Any + reveal_type(t.attributes) # R: Any + + +@pytest.mark.mypy_testing +def mypy_bool_typing(): + class T(HasTraits): + b = Bool(True).tag(sync=True) + ob = Bool(None, allow_none=True).tag(sync=True) + + t = T() + reveal_type( + Bool(True) # R: traitlets.traitlets.Bool[builtins.bool, Union[builtins.bool, builtins.int]] + ) + reveal_type( + Bool( # R: traitlets.traitlets.Bool[builtins.bool, Union[builtins.bool, builtins.int]] + True + ).tag(sync=True) + ) + reveal_type( + Bool( # R: traitlets.traitlets.Bool[Union[builtins.bool, None], Union[builtins.bool, builtins.int, None]] + None, allow_none=True + ) + ) + reveal_type( + Bool( # R: traitlets.traitlets.Bool[Union[builtins.bool, None], Union[builtins.bool, builtins.int, None]] + None, allow_none=True + ).tag( + sync=True + ) + ) + reveal_type( + T.b # R: traitlets.traitlets.Bool[builtins.bool, Union[builtins.bool, builtins.int]] + ) + reveal_type(t.b) # R: builtins.bool + reveal_type(t.ob) # R: Union[builtins.bool, None] + reveal_type( + T.b # R: traitlets.traitlets.Bool[builtins.bool, Union[builtins.bool, builtins.int]] + ) + reveal_type( + T.ob # R: traitlets.traitlets.Bool[Union[builtins.bool, None], Union[builtins.bool, builtins.int, None]] + ) + # we would expect this to be Optional[Union[bool, int]], but... + t.b = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Union[bool, int]") [assignment] + t.b = None # E: Incompatible types in assignment (expression has type "None", variable has type "Union[bool, int]") [assignment] + + +@pytest.mark.mypy_testing +def mypy_int_typing(): + class T(HasTraits): + i: Int[int, int] = Int(42).tag(sync=True) + oi: Int[int | None, int | None] = Int(42, allow_none=True).tag(sync=True) + + t = T() + reveal_type(Int(True)) # R: traitlets.traitlets.Int[builtins.int, builtins.int] + reveal_type(Int(True).tag(sync=True)) # R: traitlets.traitlets.Int[builtins.int, builtins.int] + reveal_type( + Int( # R: traitlets.traitlets.Int[Union[builtins.int, None], Union[builtins.int, None]] + None, allow_none=True + ) + ) + reveal_type( + Int( # R: traitlets.traitlets.Int[Union[builtins.int, None], Union[builtins.int, None]] + None, allow_none=True + ).tag(sync=True) + ) + reveal_type(T.i) # R: traitlets.traitlets.Int[builtins.int, builtins.int] + reveal_type(t.i) # R: builtins.int + reveal_type(t.oi) # R: Union[builtins.int, None] + reveal_type(T.i) # R: traitlets.traitlets.Int[builtins.int, builtins.int] + reveal_type( + T.oi # R: traitlets.traitlets.Int[Union[builtins.int, None], Union[builtins.int, None]] + ) + t.i = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "int") [assignment] + t.i = None # E: Incompatible types in assignment (expression has type "None", variable has type "int") [assignment] + t.i = 1.2 # E: Incompatible types in assignment (expression has type "float", variable has type "int") [assignment] + + +@pytest.mark.mypy_testing +def mypy_cint_typing(): + class T(HasTraits): + i = CInt(42).tag(sync=True) + oi = CInt(42, allow_none=True).tag(sync=True) + + t = T() + reveal_type(CInt(42)) # R: traitlets.traitlets.CInt[builtins.int, Any] + reveal_type(CInt(42).tag(sync=True)) # R: traitlets.traitlets.CInt[builtins.int, Any] + reveal_type( + CInt(None, allow_none=True) # R: traitlets.traitlets.CInt[Union[builtins.int, None], Any] + ) + reveal_type( + CInt( # R: traitlets.traitlets.CInt[Union[builtins.int, None], Any] + None, allow_none=True + ).tag(sync=True) + ) + reveal_type(T.i) # R: traitlets.traitlets.CInt[builtins.int, Any] + reveal_type(t.i) # R: builtins.int + reveal_type(t.oi) # R: Union[builtins.int, None] + reveal_type(T.i) # R: traitlets.traitlets.CInt[builtins.int, Any] + reveal_type(T.oi) # R: traitlets.traitlets.CInt[Union[builtins.int, None], Any] + + +@pytest.mark.mypy_testing +def mypy_tcp_typing(): + class T(HasTraits): + tcp = TCPAddress() + otcp = TCPAddress(None, allow_none=True) + + t = T() + reveal_type(t.tcp) # R: Tuple[builtins.str, builtins.int] + reveal_type( + T.tcp # R: traitlets.traitlets.TCPAddress[Tuple[builtins.str, builtins.int], Tuple[builtins.str, builtins.int]] + ) + reveal_type( + T.tcp.tag( # R:traitlets.traitlets.TCPAddress[Tuple[builtins.str, builtins.int], Tuple[builtins.str, builtins.int]] + sync=True + ) + ) + reveal_type(t.otcp) # R: Union[Tuple[builtins.str, builtins.int], None] + reveal_type( + T.otcp # R: traitlets.traitlets.TCPAddress[Union[Tuple[builtins.str, builtins.int], None], Union[Tuple[builtins.str, builtins.int], None]] + ) + reveal_type( + T.otcp.tag( # R: traitlets.traitlets.TCPAddress[Union[Tuple[builtins.str, builtins.int], None], Union[Tuple[builtins.str, builtins.int], None]] + sync=True + ) + ) + t.tcp = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Tuple[str, int]") [assignment] + t.otcp = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Optional[Tuple[str, int]]") [assignment] + t.tcp = None # E: Incompatible types in assignment (expression has type "None", variable has type "Tuple[str, int]") [assignment] + + +@pytest.mark.mypy_testing +def mypy_instance_typing(): + class T(HasTraits): + inst = Instance(Foo) + oinst = Instance(Foo, allow_none=True) + oinst_string = Instance("Foo", allow_none=True) + + t = T() + reveal_type(t.inst) # R: tests.test_typing.Foo + reveal_type(T.inst) # R: traitlets.traitlets.Instance[tests.test_typing.Foo] + reveal_type(T.inst.tag(sync=True)) # R: traitlets.traitlets.Instance[tests.test_typing.Foo] + reveal_type(t.oinst) # R: Union[tests.test_typing.Foo, None] + reveal_type(t.oinst_string) # R: Union[Any, None] + reveal_type(T.oinst) # R: traitlets.traitlets.Instance[Union[tests.test_typing.Foo, None]] + reveal_type( + T.oinst.tag( # R: traitlets.traitlets.Instance[Union[tests.test_typing.Foo, None]] + sync=True + ) + ) + t.inst = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Foo") [assignment] + t.oinst = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Optional[Foo]") [assignment] + t.inst = None # E: Incompatible types in assignment (expression has type "None", variable has type "Foo") [assignment] diff --git a/contrib/python/traitlets/py3/tests/utils/__init__.py b/contrib/python/traitlets/py3/tests/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/contrib/python/traitlets/py3/tests/utils/__init__.py diff --git a/contrib/python/traitlets/py3/tests/utils/test_bunch.py b/contrib/python/traitlets/py3/tests/utils/test_bunch.py new file mode 100644 index 00000000000..223124d7d5e --- /dev/null +++ b/contrib/python/traitlets/py3/tests/utils/test_bunch.py @@ -0,0 +1,16 @@ +from traitlets.utils.bunch import Bunch + + +def test_bunch(): + b = Bunch(x=5, y=10) + assert "y" in b + assert "x" in b + assert b.x == 5 + b["a"] = "hi" + assert b.a == "hi" + + +def test_bunch_dir(): + b = Bunch(x=5, y=10) + assert "x" in dir(b) + assert "keys" in dir(b) diff --git a/contrib/python/traitlets/py3/tests/utils/test_decorators.py b/contrib/python/traitlets/py3/tests/utils/test_decorators.py new file mode 100644 index 00000000000..d6bf8414e5a --- /dev/null +++ b/contrib/python/traitlets/py3/tests/utils/test_decorators.py @@ -0,0 +1,137 @@ +from inspect import Parameter, signature +from unittest import TestCase + +from traitlets import HasTraits, Int, Unicode +from traitlets.utils.decorators import signature_has_traits + + +class TestExpandSignature(TestCase): + def test_no_init(self): + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + value = Unicode("Hello") + + parameters = signature(Foo).parameters + parameter_names = list(parameters) + + self.assertIs(parameters["args"].kind, Parameter.VAR_POSITIONAL) + self.assertEqual("args", parameter_names[0]) + + self.assertIs(parameters["number1"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["number2"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["value"].kind, Parameter.KEYWORD_ONLY) + + self.assertIs(parameters["kwargs"].kind, Parameter.VAR_KEYWORD) + self.assertEqual("kwargs", parameter_names[-1]) + + f = Foo(number1=32, value="World") + self.assertEqual(f.number1, 32) + self.assertEqual(f.number2, 0) + self.assertEqual(f.value, "World") + + def test_partial_init(self): + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + value = Unicode("Hello") + + def __init__(self, arg1, **kwargs): + self.arg1 = arg1 + + super().__init__(**kwargs) + + parameters = signature(Foo).parameters + parameter_names = list(parameters) + + self.assertIs(parameters["arg1"].kind, Parameter.POSITIONAL_OR_KEYWORD) + self.assertEqual("arg1", parameter_names[0]) + + self.assertIs(parameters["number1"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["number2"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["value"].kind, Parameter.KEYWORD_ONLY) + + self.assertIs(parameters["kwargs"].kind, Parameter.VAR_KEYWORD) + self.assertEqual("kwargs", parameter_names[-1]) + + f = Foo(1, number1=32, value="World") + self.assertEqual(f.arg1, 1) + self.assertEqual(f.number1, 32) + self.assertEqual(f.number2, 0) + self.assertEqual(f.value, "World") + + def test_duplicate_init(self): + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + + def __init__(self, number1, **kwargs): + self.test = number1 + + super().__init__(number1=number1, **kwargs) + + parameters = signature(Foo).parameters + parameter_names = list(parameters) + + self.assertListEqual(parameter_names, ["number1", "number2", "kwargs"]) + + f = Foo(number1=32, number2=36) + self.assertEqual(f.test, 32) + self.assertEqual(f.number1, 32) + self.assertEqual(f.number2, 36) + + def test_full_init(self): + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + value = Unicode("Hello") + + def __init__(self, arg1, arg2=None, *pos_args, **kw_args): + self.arg1 = arg1 + self.arg2 = arg2 + self.pos_args = pos_args + self.kw_args = kw_args + + super().__init__(*pos_args, **kw_args) + + parameters = signature(Foo).parameters + parameter_names = list(parameters) + + self.assertIs(parameters["arg1"].kind, Parameter.POSITIONAL_OR_KEYWORD) + self.assertEqual("arg1", parameter_names[0]) + + self.assertIs(parameters["arg2"].kind, Parameter.POSITIONAL_OR_KEYWORD) + self.assertEqual("arg2", parameter_names[1]) + + self.assertIs(parameters["pos_args"].kind, Parameter.VAR_POSITIONAL) + self.assertEqual("pos_args", parameter_names[2]) + + self.assertIs(parameters["number1"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["number2"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["value"].kind, Parameter.KEYWORD_ONLY) + + self.assertIs(parameters["kw_args"].kind, Parameter.VAR_KEYWORD) + self.assertEqual("kw_args", parameter_names[-1]) + + f = Foo(1, 3, 45, "hey", number1=32, value="World") + self.assertEqual(f.arg1, 1) + self.assertEqual(f.arg2, 3) + self.assertTupleEqual(f.pos_args, (45, "hey")) + self.assertEqual(f.number1, 32) + self.assertEqual(f.number2, 0) + self.assertEqual(f.value, "World") + + def test_no_kwargs(self): + with self.assertRaises(RuntimeError): + + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + + def __init__(self, arg1, arg2=None): + pass diff --git a/contrib/python/traitlets/py3/tests/utils/test_importstring.py b/contrib/python/traitlets/py3/tests/utils/test_importstring.py new file mode 100644 index 00000000000..8ce28add41e --- /dev/null +++ b/contrib/python/traitlets/py3/tests/utils/test_importstring.py @@ -0,0 +1,26 @@ +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +# +# Adapted from enthought.traits, Copyright (c) Enthought, Inc., +# also under the terms of the Modified BSD License. +"""Tests for traitlets.utils.importstring.""" + +import os +from unittest import TestCase + +from traitlets.utils.importstring import import_item + + +class TestImportItem(TestCase): + def test_import_unicode(self): + self.assertIs(os, import_item("os")) + self.assertIs(os.path, import_item("os.path")) + self.assertIs(os.path.join, import_item("os.path.join")) + + def test_bad_input(self): + class NotAString: + pass + + msg = "import_item accepts strings, not '%s'." % NotAString + with self.assertRaisesRegex(TypeError, msg): + import_item(NotAString()) # type:ignore[arg-type] diff --git a/contrib/python/traitlets/py3/tests/ya.make b/contrib/python/traitlets/py3/tests/ya.make index 6a5cd7cf463..6ffd29993d5 100644 --- a/contrib/python/traitlets/py3/tests/ya.make +++ b/contrib/python/traitlets/py3/tests/ya.make @@ -1,20 +1,27 @@ PY3TEST() PEERDIR( + contrib/python/argcomplete contrib/python/traitlets + contrib/python/pytest-mock ) -SRCDIR(contrib/python/traitlets/py3/traitlets) - TEST_SRCS( - config/tests/test_application.py - config/tests/test_configurable.py - config/tests/test_loader.py - tests/test_traitlets.py - tests/test_traitlets_enum.py - utils/tests/test_bunch.py - utils/tests/test_decorators.py - utils/tests/test_importstring.py + __init__.py + _warnings.py + config/__init__.py + config/test_application.py + config/test_argcomplete.py + config/test_configurable.py + config/test_loader.py + test_traitlets.py + test_traitlets_docstring.py + test_traitlets_enum.py + test_typing.py + utils/__init__.py + utils/test_bunch.py + utils/test_decorators.py + utils/test_importstring.py ) NO_LINT() |