aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/traitlets/py3/tests
diff options
context:
space:
mode:
authorrobot-contrib <robot-contrib@yandex-team.com>2023-10-19 17:11:31 +0300
committerrobot-contrib <robot-contrib@yandex-team.com>2023-10-19 18:26:04 +0300
commitb9fe236a503791a3a7b37d4ef5f466225218996c (patch)
treec2f80019399b393ddf0450d0f91fc36478af8bea /contrib/python/traitlets/py3/tests
parent44dd27d0a2ae37c80d97a95581951d1d272bd7df (diff)
downloadydb-b9fe236a503791a3a7b37d4ef5f466225218996c.tar.gz
Update contrib/python/traitlets/py3 to 5.11.2
Diffstat (limited to 'contrib/python/traitlets/py3/tests')
-rw-r--r--contrib/python/traitlets/py3/tests/__init__.py0
-rw-r--r--contrib/python/traitlets/py3/tests/_warnings.py114
-rw-r--r--contrib/python/traitlets/py3/tests/config/__init__.py0
-rw-r--r--contrib/python/traitlets/py3/tests/config/test_application.py910
-rw-r--r--contrib/python/traitlets/py3/tests/config/test_argcomplete.py219
-rw-r--r--contrib/python/traitlets/py3/tests/config/test_configurable.py711
-rw-r--r--contrib/python/traitlets/py3/tests/config/test_loader.py753
-rw-r--r--contrib/python/traitlets/py3/tests/test_traitlets.py3141
-rw-r--r--contrib/python/traitlets/py3/tests/test_traitlets_docstring.py84
-rw-r--r--contrib/python/traitlets/py3/tests/test_traitlets_enum.py380
-rw-r--r--contrib/python/traitlets/py3/tests/test_typing.py395
-rw-r--r--contrib/python/traitlets/py3/tests/utils/__init__.py0
-rw-r--r--contrib/python/traitlets/py3/tests/utils/test_bunch.py16
-rw-r--r--contrib/python/traitlets/py3/tests/utils/test_decorators.py137
-rw-r--r--contrib/python/traitlets/py3/tests/utils/test_importstring.py26
-rw-r--r--contrib/python/traitlets/py3/tests/ya.make27
16 files changed, 6903 insertions, 10 deletions
diff --git a/contrib/python/traitlets/py3/tests/__init__.py b/contrib/python/traitlets/py3/tests/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/__init__.py
diff --git a/contrib/python/traitlets/py3/tests/_warnings.py b/contrib/python/traitlets/py3/tests/_warnings.py
new file mode 100644
index 00000000000..e3c3a0ac6d6
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/_warnings.py
@@ -0,0 +1,114 @@
+# From scikit-image: https://github.com/scikit-image/scikit-image/blob/c2f8c4ab123ebe5f7b827bc495625a32bb225c10/skimage/_shared/_warnings.py
+# Licensed under modified BSD license
+
+__all__ = ["all_warnings", "expected_warnings"]
+
+import inspect
+import os
+import re
+import sys
+import warnings
+from contextlib import contextmanager
+from unittest import mock
+
+
+@contextmanager
+def all_warnings():
+ """
+ Context for use in testing to ensure that all warnings are raised.
+ Examples
+ --------
+ >>> import warnings
+ >>> def foo():
+ ... warnings.warn(RuntimeWarning("bar"))
+
+ We raise the warning once, while the warning filter is set to "once".
+ Hereafter, the warning is invisible, even with custom filters:
+ >>> with warnings.catch_warnings():
+ ... warnings.simplefilter('once')
+ ... foo()
+
+ We can now run ``foo()`` without a warning being raised:
+ >>> from numpy.testing import assert_warns # doctest: +SKIP
+ >>> foo() # doctest: +SKIP
+
+ To catch the warning, we call in the help of ``all_warnings``:
+ >>> with all_warnings(): # doctest: +SKIP
+ ... assert_warns(RuntimeWarning, foo)
+ """
+
+ # Whenever a warning is triggered, Python adds a __warningregistry__
+ # member to the *calling* module. The exercize here is to find
+ # and eradicate all those breadcrumbs that were left lying around.
+ #
+ # We proceed by first searching all parent calling frames and explicitly
+ # clearing their warning registries (necessary for the doctests above to
+ # pass). Then, we search for all submodules of skimage and clear theirs
+ # as well (necessary for the skimage test suite to pass).
+
+ frame = inspect.currentframe()
+ if frame:
+ for f in inspect.getouterframes(frame):
+ f[0].f_locals["__warningregistry__"] = {}
+ del frame
+
+ for _, mod in list(sys.modules.items()):
+ try:
+ mod.__warningregistry__.clear()
+ except AttributeError:
+ pass
+
+ with warnings.catch_warnings(record=True) as w, mock.patch.dict(
+ os.environ, {"TRAITLETS_ALL_DEPRECATIONS": "1"}
+ ):
+ warnings.simplefilter("always")
+ yield w
+
+
+@contextmanager
+def expected_warnings(matching):
+ r"""Context for use in testing to catch known warnings matching regexes
+
+ Parameters
+ ----------
+ matching : list of strings or compiled regexes
+ Regexes for the desired warning to catch
+
+ Examples
+ --------
+ >>> from skimage import data, img_as_ubyte, img_as_float # doctest: +SKIP
+ >>> with expected_warnings(["precision loss"]): # doctest: +SKIP
+ ... d = img_as_ubyte(img_as_float(data.coins())) # doctest: +SKIP
+
+ Notes
+ -----
+ Uses `all_warnings` to ensure all warnings are raised.
+ Upon exiting, it checks the recorded warnings for the desired matching
+ pattern(s).
+ Raises a ValueError if any match was not found or an unexpected
+ warning was raised.
+ Allows for three types of behaviors: "and", "or", and "optional" matches.
+ This is done to accomodate different build enviroments or loop conditions
+ that may produce different warnings. The behaviors can be combined.
+ If you pass multiple patterns, you get an orderless "and", where all of the
+ warnings must be raised.
+ If you use the "|" operator in a pattern, you can catch one of several warnings.
+ Finally, you can use "|\A\Z" in a pattern to signify it as optional.
+ """
+ with all_warnings() as w:
+ # enter context
+ yield w
+ # exited user context, check the recorded warnings
+ remaining = [m for m in matching if r"\A\Z" not in m.split("|")]
+ for warn in w:
+ found = False
+ for match in matching:
+ if re.search(match, str(warn.message)) is not None:
+ found = True
+ if match in remaining:
+ remaining.remove(match)
+ if not found:
+ raise ValueError("Unexpected warning: %s" % str(warn.message))
+ if len(remaining) > 0:
+ msg = "No warning raised matching:\n%s" % "\n".join(remaining)
+ raise ValueError(msg)
diff --git a/contrib/python/traitlets/py3/tests/config/__init__.py b/contrib/python/traitlets/py3/tests/config/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/config/__init__.py
diff --git a/contrib/python/traitlets/py3/tests/config/test_application.py b/contrib/python/traitlets/py3/tests/config/test_application.py
new file mode 100644
index 00000000000..610cafc3cd5
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/config/test_application.py
@@ -0,0 +1,910 @@
+"""
+Tests for traitlets.config.application.Application
+"""
+
+# Copyright (c) IPython Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+import contextlib
+import io
+import json
+import logging
+import os
+import sys
+import typing as t
+from io import StringIO
+from tempfile import TemporaryDirectory
+from unittest import TestCase
+
+import pytest
+from pytest import mark
+
+from traitlets import Bool, Bytes, Dict, HasTraits, Integer, List, Set, Tuple, Unicode
+from traitlets.config.application import Application
+from traitlets.config.configurable import Configurable
+from traitlets.config.loader import Config, KVArgParseConfigLoader
+from traitlets.tests.utils import check_help_all_output, check_help_output, get_output_error_code
+
+try:
+ from unittest import mock
+except ImportError:
+ from unittest import mock
+
+pjoin = os.path.join
+
+
+class Foo(Configurable):
+ i = Integer(
+ 0,
+ help="""
+ The integer i.
+
+ Details about i.
+ """,
+ ).tag(config=True)
+ j = Integer(1, help="The integer j.").tag(config=True)
+ name = Unicode("Brian", help="First name.").tag(config=True)
+ la = List([]).tag(config=True)
+ li = List(Integer()).tag(config=True)
+ fdict = Dict().tag(config=True, multiplicity="+")
+
+
+class Bar(Configurable):
+ b = Integer(0, help="The integer b.").tag(config=True)
+ enabled = Bool(True, help="Enable bar.").tag(config=True)
+ tb = Tuple(()).tag(config=True, multiplicity="*")
+ aset = Set().tag(config=True, multiplicity="+")
+ bdict = Dict().tag(config=True)
+ idict = Dict(value_trait=Integer()).tag(config=True)
+ key_dict = Dict(per_key_traits={"i": Integer(), "b": Bytes()}).tag(config=True)
+
+
+class MyApp(Application):
+ name = Unicode("myapp")
+ running = Bool(False, help="Is the app running?").tag(config=True)
+ classes = List([Bar, Foo]) # type:ignore
+ config_file = Unicode("", help="Load this config file").tag(config=True)
+
+ warn_tpyo = Unicode(
+ "yes the name is wrong on purpose",
+ config=True,
+ help="Should print a warning if `MyApp.warn-typo=...` command is passed",
+ )
+
+ aliases: t.Dict[t.Any, t.Any] = {}
+ aliases.update(Application.aliases)
+ aliases.update(
+ {
+ ("fooi", "i"): "Foo.i",
+ ("j", "fooj"): ("Foo.j", "`j` terse help msg"),
+ "name": "Foo.name",
+ "la": "Foo.la",
+ "li": "Foo.li",
+ "tb": "Bar.tb",
+ "D": "Bar.bdict",
+ "enabled": "Bar.enabled",
+ "enable": "Bar.enabled",
+ "log-level": "Application.log_level",
+ }
+ )
+
+ flags: t.Dict[t.Any, t.Any] = {}
+ flags.update(Application.flags)
+ flags.update(
+ {
+ ("enable", "e"): ({"Bar": {"enabled": True}}, "Set Bar.enabled to True"),
+ ("d", "disable"): ({"Bar": {"enabled": False}}, "Set Bar.enabled to False"),
+ "crit": ({"Application": {"log_level": logging.CRITICAL}}, "set level=CRITICAL"),
+ }
+ )
+
+ def init_foo(self):
+ self.foo = Foo(parent=self)
+
+ def init_bar(self):
+ self.bar = Bar(parent=self)
+
+
+def class_to_names(classes):
+ return [klass.__name__ for klass in classes]
+
+
+class TestApplication(TestCase):
+ def test_log(self):
+ stream = StringIO()
+ app = MyApp(log_level=logging.INFO)
+ handler = logging.StreamHandler(stream)
+ # trigger reconstruction of the log formatter
+ app.log_format = "%(message)s"
+ app.log_datefmt = "%Y-%m-%d %H:%M"
+ app.log.handlers = [handler]
+ app.log.info("hello")
+ assert "hello" in stream.getvalue()
+
+ def test_no_eval_cli_text(self):
+ app = MyApp()
+ app.initialize(["--Foo.name=1"])
+ app.init_foo()
+ assert app.foo.name == "1"
+
+ def test_basic(self):
+ app = MyApp()
+ self.assertEqual(app.name, "myapp")
+ self.assertEqual(app.running, False)
+ self.assertEqual(app.classes, [MyApp, Bar, Foo]) # type:ignore
+ self.assertEqual(app.config_file, "")
+
+ def test_app_name_set_via_constructor(self):
+ app = MyApp(name='set_via_constructor')
+ assert app.name == "set_via_constructor"
+
+ def test_mro_discovery(self):
+ app = MyApp()
+
+ self.assertSequenceEqual(
+ class_to_names(app._classes_with_config_traits()),
+ ["Application", "MyApp", "Bar", "Foo"],
+ )
+ self.assertSequenceEqual(
+ class_to_names(app._classes_inc_parents()),
+ [
+ "Configurable",
+ "LoggingConfigurable",
+ "SingletonConfigurable",
+ "Application",
+ "MyApp",
+ "Bar",
+ "Foo",
+ ],
+ )
+
+ self.assertSequenceEqual(
+ class_to_names(app._classes_with_config_traits([Application])), ["Application"]
+ )
+ self.assertSequenceEqual(
+ class_to_names(app._classes_inc_parents([Application])),
+ ["Configurable", "LoggingConfigurable", "SingletonConfigurable", "Application"],
+ )
+
+ self.assertSequenceEqual(class_to_names(app._classes_with_config_traits([Foo])), ["Foo"])
+ self.assertSequenceEqual(
+ class_to_names(app._classes_inc_parents([Bar])), ["Configurable", "Bar"]
+ )
+
+ class MyApp2(Application): # no defined `classes` attr
+ pass
+
+ self.assertSequenceEqual(class_to_names(app._classes_with_config_traits([Foo])), ["Foo"])
+ self.assertSequenceEqual(
+ class_to_names(app._classes_inc_parents([Bar])), ["Configurable", "Bar"]
+ )
+
+ def test_config(self):
+ app = MyApp()
+ app.parse_command_line(
+ [
+ "--i=10",
+ "--Foo.j=10",
+ "--enable=False",
+ "--log-level=50",
+ ]
+ )
+ config = app.config
+ print(config)
+ self.assertEqual(config.Foo.i, 10)
+ self.assertEqual(config.Foo.j, 10)
+ self.assertEqual(config.Bar.enabled, False)
+ self.assertEqual(config.MyApp.log_level, 50)
+
+ def test_config_seq_args(self):
+ app = MyApp()
+ app.parse_command_line(
+ "--li 1 --li 3 --la 1 --tb AB 2 --Foo.la=ab --Bar.aset S1 --Bar.aset S2 --Bar.aset S1".split()
+ )
+ assert app.extra_args == ["2"]
+ config = app.config
+ assert config.Foo.li == [1, 3]
+ assert config.Foo.la == ["1", "ab"]
+ assert config.Bar.tb == ("AB",)
+ self.assertEqual(config.Bar.aset, {"S1", "S2"})
+ app.init_foo()
+ assert app.foo.li == [1, 3]
+ assert app.foo.la == ["1", "ab"]
+ app.init_bar()
+ self.assertEqual(app.bar.aset, {"S1", "S2"})
+ assert app.bar.tb == ("AB",)
+
+ def test_config_dict_args(self):
+ app = MyApp()
+ app.parse_command_line(
+ "--Foo.fdict a=1 --Foo.fdict b=b --Foo.fdict c=3 "
+ "--Bar.bdict k=1 -D=a=b -D 22=33 "
+ "--Bar.idict k=1 --Bar.idict b=2 --Bar.idict c=3 ".split()
+ )
+ fdict = {"a": "1", "b": "b", "c": "3"}
+ bdict = {"k": "1", "a": "b", "22": "33"}
+ idict = {"k": 1, "b": 2, "c": 3}
+ config = app.config
+ assert config.Bar.idict == idict
+ self.assertDictEqual(config.Foo.fdict, fdict)
+ self.assertDictEqual(config.Bar.bdict, bdict)
+ app.init_foo()
+ self.assertEqual(app.foo.fdict, fdict)
+ app.init_bar()
+ assert app.bar.idict == idict
+ self.assertEqual(app.bar.bdict, bdict)
+
+ def test_config_propagation(self):
+ app = MyApp()
+ app.parse_command_line(["--i=10", "--Foo.j=10", "--enable=False", "--log-level=50"])
+ app.init_foo()
+ app.init_bar()
+ self.assertEqual(app.foo.i, 10)
+ self.assertEqual(app.foo.j, 10)
+ self.assertEqual(app.bar.enabled, False)
+
+ def test_cli_priority(self):
+ """Test that loading config files does not override CLI options"""
+ name = "config.py"
+
+ class TestApp(Application):
+ value = Unicode().tag(config=True)
+ config_file_loaded = Bool().tag(config=True)
+ aliases = {"v": "TestApp.value"}
+
+ app = TestApp()
+ with TemporaryDirectory() as td:
+ config_file = pjoin(td, name)
+ with open(config_file, "w") as f:
+ f.writelines(
+ ["c.TestApp.value = 'config file'\n", "c.TestApp.config_file_loaded = True\n"]
+ )
+
+ app.parse_command_line(["--v=cli"])
+ assert "value" in app.config.TestApp
+ assert app.config.TestApp.value == "cli"
+ assert app.value == "cli"
+
+ app.load_config_file(name, path=[td])
+ assert app.config_file_loaded
+ assert app.config.TestApp.value == "cli"
+ assert app.value == "cli"
+
+ def test_ipython_cli_priority(self):
+ # this test is almost entirely redundant with above,
+ # but we can keep it around in case of subtle issues creeping into
+ # the exact sequence IPython follows.
+ name = "config.py"
+
+ class TestApp(Application):
+ value = Unicode().tag(config=True)
+ config_file_loaded = Bool().tag(config=True)
+ aliases = {"v": ("TestApp.value", "some help")}
+
+ app = TestApp()
+ with TemporaryDirectory() as td:
+ config_file = pjoin(td, name)
+ with open(config_file, "w") as f:
+ f.writelines(
+ ["c.TestApp.value = 'config file'\n", "c.TestApp.config_file_loaded = True\n"]
+ )
+ # follow IPython's config-loading sequence to ensure CLI priority is preserved
+ app.parse_command_line(["--v=cli"])
+ # this is where IPython makes a mistake:
+ # it assumes app.config will not be modified,
+ # and storing a reference is storing a copy
+ cli_config = app.config
+ assert "value" in app.config.TestApp
+ assert app.config.TestApp.value == "cli"
+ assert app.value == "cli"
+ app.load_config_file(name, path=[td])
+ assert app.config_file_loaded
+ # enforce cl-opts override config file opts:
+ # this is where IPython makes a mistake: it assumes
+ # that cl_config is a different object, but it isn't.
+ app.update_config(cli_config)
+ assert app.config.TestApp.value == "cli"
+ assert app.value == "cli"
+
+ def test_cli_allow_none(self):
+ class App(Application):
+ aliases = {"opt": "App.opt"}
+ opt = Unicode(allow_none=True, config=True)
+
+ app = App()
+ app.parse_command_line(["--opt=None"])
+ assert app.opt is None
+
+ def test_flags(self):
+ app = MyApp()
+ app.parse_command_line(["--disable"])
+ app.init_bar()
+ self.assertEqual(app.bar.enabled, False)
+
+ app = MyApp()
+ app.parse_command_line(["-d"])
+ app.init_bar()
+ self.assertEqual(app.bar.enabled, False)
+
+ app = MyApp()
+ app.parse_command_line(["--enable"])
+ app.init_bar()
+ self.assertEqual(app.bar.enabled, True)
+
+ app = MyApp()
+ app.parse_command_line(["-e"])
+ app.init_bar()
+ self.assertEqual(app.bar.enabled, True)
+
+ def test_flags_help_msg(self):
+ app = MyApp()
+ stdout = io.StringIO()
+ with contextlib.redirect_stdout(stdout):
+ app.print_flag_help()
+ hmsg = stdout.getvalue()
+ self.assertRegex(hmsg, "(?<!-)-e, --enable\\b")
+ self.assertRegex(hmsg, "(?<!-)-d, --disable\\b")
+ self.assertIn("Equivalent to: [--Bar.enabled=True]", hmsg)
+ self.assertIn("Equivalent to: [--Bar.enabled=False]", hmsg)
+
+ def test_aliases(self):
+ app = MyApp()
+ app.parse_command_line(["--i=5", "--j=10"])
+ app.init_foo()
+ self.assertEqual(app.foo.i, 5)
+ app.init_foo()
+ self.assertEqual(app.foo.j, 10)
+
+ app = MyApp()
+ app.parse_command_line(["-i=5", "-j=10"])
+ app.init_foo()
+ self.assertEqual(app.foo.i, 5)
+ app.init_foo()
+ self.assertEqual(app.foo.j, 10)
+
+ app = MyApp()
+ app.parse_command_line(["--fooi=5", "--fooj=10"])
+ app.init_foo()
+ self.assertEqual(app.foo.i, 5)
+ app.init_foo()
+ self.assertEqual(app.foo.j, 10)
+
+ def test_aliases_multiple(self):
+ # Test multiple > 2 aliases for the same argument
+ class TestMultiAliasApp(Application):
+ foo = Integer(config=True)
+ aliases = {("f", "bar", "qux"): "TestMultiAliasApp.foo"}
+
+ app = TestMultiAliasApp()
+ app.parse_command_line(["-f", "3"])
+ self.assertEqual(app.foo, 3)
+
+ app = TestMultiAliasApp()
+ app.parse_command_line(["--bar", "4"])
+ self.assertEqual(app.foo, 4)
+
+ app = TestMultiAliasApp()
+ app.parse_command_line(["--qux", "5"])
+ self.assertEqual(app.foo, 5)
+
+ def test_aliases_help_msg(self):
+ app = MyApp()
+ stdout = io.StringIO()
+ with contextlib.redirect_stdout(stdout):
+ app.print_alias_help()
+ hmsg = stdout.getvalue()
+ self.assertRegex(hmsg, "(?<!-)-i, --fooi\\b")
+ self.assertRegex(hmsg, "(?<!-)-j, --fooj\\b")
+ self.assertIn("Equivalent to: [--Foo.i]", hmsg)
+ self.assertIn("Equivalent to: [--Foo.j]", hmsg)
+ self.assertIn("Equivalent to: [--Foo.name]", hmsg)
+
+ def test_alias_unrecognized(self):
+ """Check ability to override handling for unrecognized aliases"""
+
+ class StrictLoader(KVArgParseConfigLoader):
+ def _handle_unrecognized_alias(self, arg):
+ self.parser.error("Unrecognized alias: %s" % arg)
+
+ class StrictApplication(Application):
+ def _create_loader(self, argv, aliases, flags, classes):
+ return StrictLoader(argv, aliases, flags, classes=classes, log=self.log)
+
+ app = StrictApplication()
+ app.initialize(["--log-level=20"]) # recognized alias
+ assert app.log_level == 20
+
+ app = StrictApplication()
+ with pytest.raises(SystemExit, match="2"):
+ app.initialize(["--unrecognized=20"])
+
+ # Ideally we would use pytest capsys fixture, but fixtures are incompatible
+ # with unittest.TestCase-style classes :(
+ # stderr = capsys.readouterr().err
+ # assert "Unrecognized alias: unrecognized" in stderr
+
+ def test_flag_clobber(self):
+ """test that setting flags doesn't clobber existing settings"""
+ app = MyApp()
+ app.parse_command_line(["--Bar.b=5", "--disable"])
+ app.init_bar()
+ self.assertEqual(app.bar.enabled, False)
+ self.assertEqual(app.bar.b, 5)
+ app.parse_command_line(["--enable", "--Bar.b=10"])
+ app.init_bar()
+ self.assertEqual(app.bar.enabled, True)
+ self.assertEqual(app.bar.b, 10)
+
+ def test_warn_autocorrect(self):
+ stream = StringIO()
+ app = MyApp(log_level=logging.INFO)
+ app.log.handlers = [logging.StreamHandler(stream)]
+
+ cfg = Config()
+ cfg.MyApp.warn_typo = "WOOOO"
+ app.config = cfg
+
+ self.assertIn("warn_typo", stream.getvalue())
+ self.assertIn("warn_tpyo", stream.getvalue())
+
+ def test_flatten_flags(self):
+ cfg = Config()
+ cfg.MyApp.log_level = logging.WARN
+ app = MyApp()
+ app.update_config(cfg)
+ self.assertEqual(app.log_level, logging.WARN)
+ self.assertEqual(app.config.MyApp.log_level, logging.WARN)
+ app.initialize(["--crit"])
+ self.assertEqual(app.log_level, logging.CRITICAL)
+ # this would be app.config.Application.log_level if it failed:
+ self.assertEqual(app.config.MyApp.log_level, logging.CRITICAL)
+
+ def test_flatten_aliases(self):
+ cfg = Config()
+ cfg.MyApp.log_level = logging.WARN
+ app = MyApp()
+ app.update_config(cfg)
+ self.assertEqual(app.log_level, logging.WARN)
+ self.assertEqual(app.config.MyApp.log_level, logging.WARN)
+ app.initialize(["--log-level", "CRITICAL"])
+ self.assertEqual(app.log_level, logging.CRITICAL)
+ # this would be app.config.Application.log_level if it failed:
+ self.assertEqual(app.config.MyApp.log_level, "CRITICAL")
+
+ def test_extra_args(self):
+ app = MyApp()
+ app.parse_command_line(["--Bar.b=5", "extra", "args", "--disable"])
+ app.init_bar()
+ self.assertEqual(app.bar.enabled, False)
+ self.assertEqual(app.bar.b, 5)
+ self.assertEqual(app.extra_args, ["extra", "args"])
+
+ app = MyApp()
+ app.parse_command_line(["--Bar.b=5", "--", "extra", "--disable", "args"])
+ app.init_bar()
+ self.assertEqual(app.bar.enabled, True)
+ self.assertEqual(app.bar.b, 5)
+ self.assertEqual(app.extra_args, ["extra", "--disable", "args"])
+
+ app = MyApp()
+ app.parse_command_line(["--disable", "--la", "-", "-", "--Bar.b=1", "--", "-", "extra"])
+ self.assertEqual(app.extra_args, ["-", "-", "extra"])
+
+ def test_unicode_argv(self):
+ app = MyApp()
+ app.parse_command_line(["ünîcødé"])
+
+ def test_document_config_option(self):
+ app = MyApp()
+ app.document_config_options()
+
+ def test_generate_config_file(self):
+ app = MyApp()
+ assert "The integer b." in app.generate_config_file()
+
+ def test_generate_config_file_classes_to_include(self):
+ class NotInConfig(HasTraits):
+ from_hidden = Unicode(
+ "x",
+ help="""From hidden class
+
+ Details about from_hidden.
+ """,
+ ).tag(config=True)
+
+ class NoTraits(Foo, Bar, NotInConfig):
+ pass
+
+ app = MyApp()
+ app.classes.append(NoTraits) # type:ignore
+
+ conf_txt = app.generate_config_file()
+ print(conf_txt)
+ self.assertIn("The integer b.", conf_txt)
+ self.assertIn("# Foo(Configurable)", conf_txt)
+ self.assertNotIn("# Configurable", conf_txt)
+ self.assertIn("# NoTraits(Foo, Bar)", conf_txt)
+
+ # inherited traits, parent in class list:
+ self.assertIn("# c.NoTraits.i", conf_txt)
+ self.assertIn("# c.NoTraits.j", conf_txt)
+ self.assertIn("# c.NoTraits.n", conf_txt)
+ self.assertIn("# See also: Foo.j", conf_txt)
+ self.assertIn("# See also: Bar.b", conf_txt)
+ self.assertEqual(conf_txt.count("Details about i."), 1)
+
+ # inherited traits, parent not in class list:
+ self.assertIn("# c.NoTraits.from_hidden", conf_txt)
+ self.assertNotIn("# See also: NotInConfig.", conf_txt)
+ self.assertEqual(conf_txt.count("Details about from_hidden."), 1)
+ self.assertNotIn("NotInConfig", conf_txt)
+
+ def test_multi_file(self):
+ app = MyApp()
+ app.log = logging.getLogger()
+ name = "config.py"
+ with TemporaryDirectory("_1") as td1:
+ with open(pjoin(td1, name), "w") as f1:
+ f1.write("get_config().MyApp.Bar.b = 1")
+ with TemporaryDirectory("_2") as td2:
+ with open(pjoin(td2, name), "w") as f2:
+ f2.write("get_config().MyApp.Bar.b = 2")
+ app.load_config_file(name, path=[td2, td1])
+ app.init_bar()
+ self.assertEqual(app.bar.b, 2)
+ app.load_config_file(name, path=[td1, td2])
+ app.init_bar()
+ self.assertEqual(app.bar.b, 1)
+
+ @mark.skipif(not hasattr(TestCase, "assertLogs"), reason="requires TestCase.assertLogs")
+ def test_log_collisions(self):
+ app = MyApp()
+ app.log = logging.getLogger()
+ app.log.setLevel(logging.INFO)
+ name = "config"
+ with TemporaryDirectory("_1") as td:
+ with open(pjoin(td, name + ".py"), "w") as f:
+ f.write("get_config().Bar.b = 1")
+ with open(pjoin(td, name + ".json"), "w") as f:
+ json.dump({"Bar": {"b": 2}}, f)
+ with self.assertLogs(app.log, logging.WARNING) as captured:
+ app.load_config_file(name, path=[td])
+ app.init_bar()
+ assert app.bar.b == 2
+ output = "\n".join(captured.output)
+ assert "Collision" in output
+ assert "1 ignored, using 2" in output
+ assert pjoin(td, name + ".py") in output
+ assert pjoin(td, name + ".json") in output
+
+ @mark.skipif(not hasattr(TestCase, "assertLogs"), reason="requires TestCase.assertLogs")
+ def test_log_bad_config(self):
+ app = MyApp()
+ app.log = logging.getLogger()
+ name = "config.py"
+ with TemporaryDirectory() as td:
+ with open(pjoin(td, name), "w") as f:
+ f.write("syntax error()")
+ with self.assertLogs(app.log, logging.ERROR) as captured:
+ app.load_config_file(name, path=[td])
+ output = "\n".join(captured.output)
+ self.assertIn("SyntaxError", output)
+
+ def test_raise_on_bad_config(self):
+ app = MyApp()
+ app.raise_config_file_errors = True
+ app.log = logging.getLogger()
+ name = "config.py"
+ with TemporaryDirectory() as td:
+ with open(pjoin(td, name), "w") as f:
+ f.write("syntax error()")
+ with self.assertRaises(SyntaxError):
+ app.load_config_file(name, path=[td])
+
+ def test_subcommands_instantiation(self):
+ """Try all ways to specify how to create sub-apps."""
+ app = Root.instance()
+ app.parse_command_line(["sub1"])
+
+ self.assertIsInstance(app.subapp, Sub1)
+ # Check parent hierarchy.
+ self.assertIs(app.subapp.parent, app)
+
+ Root.clear_instance()
+ Sub1.clear_instance() # Otherwise, replaced spuriously and hierarchy check fails.
+ app = Root.instance()
+
+ app.parse_command_line(["sub1", "sub2"])
+ self.assertIsInstance(app.subapp, Sub1)
+ self.assertIsInstance(app.subapp.subapp, Sub2)
+ # Check parent hierarchy.
+ self.assertIs(app.subapp.parent, app)
+ self.assertIs(app.subapp.subapp.parent, app.subapp)
+
+ Root.clear_instance()
+ Sub1.clear_instance() # Otherwise, replaced spuriously and hierarchy check fails.
+ app = Root.instance()
+
+ app.parse_command_line(["sub1", "sub3"])
+ self.assertIsInstance(app.subapp, Sub1)
+ self.assertIsInstance(app.subapp.subapp, Sub3)
+ self.assertTrue(app.subapp.subapp.flag) # Set by factory.
+ # Check parent hierarchy.
+ self.assertIs(app.subapp.parent, app)
+ self.assertIs(app.subapp.subapp.parent, app.subapp) # Set by factory.
+
+ Root.clear_instance()
+ Sub1.clear_instance()
+
+ def test_loaded_config_files(self):
+ app = MyApp()
+ app.log = logging.getLogger()
+ name = "config.py"
+ with TemporaryDirectory("_1") as td1:
+ config_file = pjoin(td1, name)
+ with open(config_file, "w") as f:
+ f.writelines(["c.MyApp.running = True\n"])
+
+ app.load_config_file(name, path=[td1])
+ self.assertEqual(len(app.loaded_config_files), 1)
+ self.assertEqual(app.loaded_config_files[0], config_file)
+
+ app.start()
+ self.assertEqual(app.running, True)
+
+ # emulate an app that allows dynamic updates and update config file
+ with open(config_file, "w") as f:
+ f.writelines(["c.MyApp.running = False\n"])
+
+ # reload and verify update, and that loaded_configs was not increased
+ app.load_config_file(name, path=[td1])
+ self.assertEqual(len(app.loaded_config_files), 1)
+ self.assertEqual(app.running, False)
+
+ # Attempt to update, ensure error...
+ with self.assertRaises(AttributeError):
+ app.loaded_config_files = "/foo" # type:ignore
+
+ # ensure it can't be udpated via append
+ app.loaded_config_files.append("/bar")
+ self.assertEqual(len(app.loaded_config_files), 1)
+
+ # repeat to ensure no unexpected changes occurred
+ app.load_config_file(name, path=[td1])
+ self.assertEqual(len(app.loaded_config_files), 1)
+ self.assertEqual(app.running, False)
+
+
+@mark.skip
+def test_cli_multi_scalar(caplog):
+ class App(Application):
+ aliases = {"opt": "App.opt"}
+ opt = Unicode(config=True)
+
+ app = App(log=logging.getLogger())
+ with pytest.raises(SystemExit):
+ app.parse_command_line(["--opt", "1", "--opt", "2"])
+ record = caplog.get_records("call")[-1]
+ message = record.message
+
+ assert "Error loading argument" in message
+ assert "App.opt=['1', '2']" in message
+ assert "opt only accepts one value" in message
+ assert record.levelno == logging.CRITICAL
+
+
+class Root(Application):
+ subcommands = {
+ "sub1": ("__tests__.config.test_application.Sub1", "import string"),
+ }
+
+
+class Sub3(Application):
+ flag = Bool(False)
+
+
+class Sub2(Application):
+ pass
+
+
+class Sub1(Application):
+ subcommands: dict = { # type:ignore
+ "sub2": (Sub2, "Application class"),
+ "sub3": (lambda root: Sub3(parent=root, flag=True), "factory"),
+ }
+
+
+class DeprecatedApp(Application):
+ override_called = False
+ parent_called = False
+
+ def _config_changed(self, name, old, new):
+ self.override_called = True
+
+ def _capture(*args):
+ self.parent_called = True
+
+ with mock.patch.object(self.log, "debug", _capture):
+ super()._config_changed(name, old, new)
+
+
+def test_deprecated_notifier():
+ app = DeprecatedApp()
+ assert not app.override_called
+ assert not app.parent_called
+ app.config = Config({"A": {"b": "c"}})
+ assert app.override_called
+ assert app.parent_called
+
+
+def test_help_output():
+ check_help_output(__name__)
+
+
+def test_help_all_output():
+ check_help_all_output(__name__)
+
+
+def test_show_config_cli():
+ out, err, ec = get_output_error_code([sys.executable, "-m", __name__, "--show-config"])
+ assert ec == 0
+ assert "show_config" not in out
+
+
+def test_show_config_json_cli():
+ out, err, ec = get_output_error_code([sys.executable, "-m", __name__, "--show-config-json"])
+ assert ec == 0
+ assert "show_config" not in out
+
+
+def test_show_config(capsys):
+ cfg = Config()
+ cfg.MyApp.i = 5
+ # don't show empty
+ cfg.OtherApp
+
+ app = MyApp(config=cfg, show_config=True)
+ app.start()
+ out, err = capsys.readouterr()
+ assert "MyApp" in out
+ assert "i = 5" in out
+ assert "OtherApp" not in out
+
+
+def test_show_config_json(capsys):
+ cfg = Config()
+ cfg.MyApp.i = 5
+ cfg.OtherApp
+
+ app = MyApp(config=cfg, show_config_json=True)
+ app.start()
+ out, err = capsys.readouterr()
+ displayed = json.loads(out)
+ assert Config(displayed) == cfg
+
+
+def test_deep_alias():
+ from traitlets import Int
+ from traitlets.config import Application, Configurable
+
+ class Foo(Configurable):
+ val = Int(default_value=5).tag(config=True)
+
+ class Bar(Configurable):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.foo = Foo(parent=self)
+
+ class TestApp(Application):
+ name = "test"
+
+ aliases = {"val": "Bar.Foo.val"}
+ classes = [Foo, Bar]
+
+ def initialize(self, *args, **kwargs):
+ super().initialize(*args, **kwargs)
+ self.bar = Bar(parent=self)
+
+ app = TestApp()
+ app.initialize(["--val=10"])
+ assert app.bar.foo.val == 10
+ assert len(list(app.emit_alias_help())) > 0
+
+
+def test_logging_config(tmp_path, capsys):
+ """We should be able to configure additional log handlers."""
+ log_file = tmp_path / "log_file"
+ app = Application(
+ logging_config={
+ "version": 1,
+ "handlers": {
+ "file": {
+ "class": "logging.FileHandler",
+ "level": "DEBUG",
+ "filename": str(log_file),
+ },
+ },
+ "loggers": {
+ "Application": {
+ "level": "DEBUG",
+ "handlers": ["console", "file"],
+ },
+ },
+ }
+ )
+ # the default "console" handler + our new "file" handler
+ assert len(app.log.handlers) == 2
+
+ # log a couple of messages
+ app.log.info("info")
+ app.log.warning("warn")
+
+ # test that log messages get written to the file
+ with open(log_file) as log_handle:
+ assert log_handle.read() == "info\nwarn\n"
+
+ # test that log messages get written to stderr (default console handler)
+ assert capsys.readouterr().err == "[Application] WARNING | warn\n"
+
+
+def test_get_default_logging_config_pythonw(monkeypatch):
+ """Ensure logging is correctly disabled for pythonw usage."""
+ monkeypatch.setattr("traitlets.config.application.IS_PYTHONW", True)
+ config = Application().get_default_logging_config()
+ assert "handlers" not in config
+ assert "loggers" not in config
+
+ monkeypatch.setattr("traitlets.config.application.IS_PYTHONW", False)
+ config = Application().get_default_logging_config()
+ assert "handlers" in config
+ assert "loggers" in config
+
+
+@pytest.fixture
+def caplogconfig(monkeypatch):
+ """Capture logging config events for DictConfigurator objects.
+
+ This suppresses the event (so the configuration doesn't happen).
+
+ Returns a list of (args, kwargs).
+ """
+ calls = []
+
+ def _configure(*args, **kwargs):
+ nonlocal calls
+ calls.append((args, kwargs))
+
+ monkeypatch.setattr(
+ "logging.config.DictConfigurator.configure",
+ _configure,
+ )
+
+ return calls
+
+
+@pytest.mark.skipif(sys.implementation.name == "pypy", reason="Test does not work on pypy")
+def test_logging_teardown_on_error(capsys, caplogconfig):
+ """Ensure we don't try to open logs in order to close them (See #722).
+
+ If you try to configure logging handlers whilst Python is shutting down
+ you may get traceback.
+ """
+ # create and destroy an app (without configuring logging)
+ # (ensure that the logs were not configured)
+ app = Application()
+ del app
+ assert len(caplogconfig) == 0 # logging was not configured
+ out, err = capsys.readouterr()
+ assert "Traceback" not in err
+
+ # ensure that the logs would have been configured otherwise
+ # (to prevent this test from yielding false-negatives)
+ app = Application()
+ app._logging_configured = True # make it look like logging was configured
+ del app
+ assert len(caplogconfig) == 1 # logging was configured
+
+
+if __name__ == "__main__":
+ # for test_help_output:
+ MyApp.launch_instance()
diff --git a/contrib/python/traitlets/py3/tests/config/test_argcomplete.py b/contrib/python/traitlets/py3/tests/config/test_argcomplete.py
new file mode 100644
index 00000000000..52ed6d2bb2c
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/config/test_argcomplete.py
@@ -0,0 +1,219 @@
+"""
+Tests for argcomplete handling by traitlets.config.application.Application
+"""
+
+# Copyright (c) IPython Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+import io
+import os
+import typing as t
+
+import pytest
+
+argcomplete = pytest.importorskip("argcomplete")
+
+from traitlets import Unicode
+from traitlets.config.application import Application
+from traitlets.config.configurable import Configurable
+from traitlets.config.loader import KVArgParseConfigLoader
+
+
+class ArgcompleteApp(Application):
+ """Override loader to pass through kwargs for argcomplete testing"""
+
+ argcomplete_kwargs: t.Dict[str, t.Any]
+
+ def __init__(self, *args, **kwargs):
+ # For subcommands, inherit argcomplete_kwargs from parent app
+ parent = kwargs.get("parent")
+ super().__init__(*args, **kwargs)
+ if parent:
+ argcomplete_kwargs = getattr(parent, "argcomplete_kwargs", None)
+ if argcomplete_kwargs:
+ self.argcomplete_kwargs = argcomplete_kwargs
+
+ def _create_loader(self, argv, aliases, flags, classes):
+ loader = KVArgParseConfigLoader(
+ argv, aliases, flags, classes=classes, log=self.log, subcommands=self.subcommands
+ )
+ loader._argcomplete_kwargs = self.argcomplete_kwargs # type: ignore[attr-defined]
+ return loader
+
+
+class SubApp1(ArgcompleteApp):
+ pass
+
+
+class SubApp2(ArgcompleteApp):
+ @classmethod
+ def get_subapp_instance(cls, app: Application) -> Application:
+ app.clear_instance() # since Application is singleton, need to clear main app
+ return cls.instance(parent=app) # type: ignore[no-any-return]
+
+
+class MainApp(ArgcompleteApp):
+ subcommands = {
+ "subapp1": (SubApp1, "First subapp"),
+ "subapp2": (SubApp2.get_subapp_instance, "Second subapp"),
+ }
+
+
+class CustomError(Exception):
+ """Helper for exit hook for testing argcomplete"""
+
+ @classmethod
+ def exit(cls, code):
+ raise cls(str(code))
+
+
+class TestArgcomplete:
+ IFS = "\013"
+ COMP_WORDBREAKS = " \t\n\"'><=;|&(:"
+
+ @pytest.fixture
+ def argcomplete_on(self, mocker):
+ """Mostly borrowed from argcomplete's unit test fixtures
+
+ Set up environment variables to mimic those passed by argcomplete
+ """
+ _old_environ = os.environ
+ os.environ = os.environ.copy() # type: ignore[assignment]
+ os.environ["_ARGCOMPLETE"] = "1"
+ os.environ["_ARC_DEBUG"] = "yes"
+ os.environ["IFS"] = self.IFS
+ os.environ["_ARGCOMPLETE_COMP_WORDBREAKS"] = self.COMP_WORDBREAKS
+
+ # argcomplete==2.0.0 always calls fdopen(9, "w") to open a debug stream,
+ # however this could conflict with file descriptors used by pytest
+ # and lead to obscure errors. Since we are not looking at debug stream
+ # in these tests, just mock this fdopen call out.
+ mocker.patch("os.fdopen")
+ try:
+ yield
+ finally:
+ os.environ = _old_environ
+
+ def run_completer(
+ self,
+ app: ArgcompleteApp,
+ command: str,
+ point: t.Union[str, int, None] = None,
+ **kwargs: t.Any,
+ ) -> t.List[str]:
+ """Mostly borrowed from argcomplete's unit tests
+
+ Modified to take an application instead of an ArgumentParser
+
+ Command is the current command being completed and point is the index
+ into the command where the completion is triggered.
+ """
+ if point is None:
+ point = str(len(command))
+ # Flushing tempfile was leading to CI failures with Bad file descriptor, not sure why.
+ # Fortunately we can just write to a StringIO instead.
+ # print("Writing completions to temp file with mode=", write_mode)
+ # from tempfile import TemporaryFile
+ # with TemporaryFile(mode=write_mode) as t:
+ strio = io.StringIO()
+ os.environ["COMP_LINE"] = command
+ os.environ["COMP_POINT"] = str(point)
+
+ with pytest.raises(CustomError) as cm:
+ app.argcomplete_kwargs = dict(
+ output_stream=strio, exit_method=CustomError.exit, **kwargs
+ )
+ app.initialize()
+
+ if str(cm.value) != "0":
+ raise RuntimeError(f"Unexpected exit code {cm.value}")
+ out = strio.getvalue()
+ return out.split(self.IFS)
+
+ def test_complete_simple_app(self, argcomplete_on):
+ app = ArgcompleteApp()
+ expected = [
+ '--help',
+ '--debug',
+ '--show-config',
+ '--show-config-json',
+ '--log-level',
+ '--Application.',
+ '--ArgcompleteApp.',
+ ]
+ assert set(self.run_completer(app, "app --")) == set(expected)
+
+ # completing class traits
+ assert set(self.run_completer(app, "app --App")) > {
+ '--Application.show_config',
+ '--Application.log_level',
+ '--Application.log_format',
+ }
+
+ def test_complete_custom_completers(self, argcomplete_on):
+ app = ArgcompleteApp()
+ # test pre-defined completers for Bool/Enum
+ assert set(self.run_completer(app, "app --Application.log_level=")) > {"DEBUG", "INFO"}
+ assert set(self.run_completer(app, "app --ArgcompleteApp.show_config ")) == {
+ "0",
+ "1",
+ "true",
+ "false",
+ }
+
+ # test custom completer and mid-command completions
+ class CustomCls(Configurable):
+ val = Unicode().tag(
+ config=True, argcompleter=argcomplete.completers.ChoicesCompleter(["foo", "bar"])
+ )
+
+ class CustomApp(ArgcompleteApp):
+ classes = [CustomCls]
+ aliases = {("v", "val"): "CustomCls.val"}
+
+ app = CustomApp()
+ assert self.run_completer(app, "app --val ") == ["foo", "bar"]
+ assert self.run_completer(app, "app --val=") == ["foo", "bar"]
+ assert self.run_completer(app, "app -v ") == ["foo", "bar"]
+ assert self.run_completer(app, "app -v=") == ["foo", "bar"]
+ assert self.run_completer(app, "app --CustomCls.val ") == ["foo", "bar"]
+ assert self.run_completer(app, "app --CustomCls.val=") == ["foo", "bar"]
+ completions = self.run_completer(app, "app --val= abc xyz", point=10)
+ # fixed in argcomplete >= 2.0 to return latter below
+ assert completions == ["--val=foo", "--val=bar"] or completions == ["foo", "bar"]
+ assert self.run_completer(app, "app --val --log-level=", point=10) == ["foo", "bar"]
+
+ def test_complete_subcommands(self, argcomplete_on):
+ app = MainApp()
+ assert set(self.run_completer(app, "app ")) >= {"subapp1", "subapp2"}
+ assert set(self.run_completer(app, "app sub")) == {"subapp1", "subapp2"}
+ assert set(self.run_completer(app, "app subapp1")) == {"subapp1"}
+
+ def test_complete_subcommands_subapp1(self, argcomplete_on):
+ # subcommand handling modifies _ARGCOMPLETE env var global state, so
+ # only can test one completion per unit test
+ app = MainApp()
+ try:
+ assert set(self.run_completer(app, "app subapp1 --Sub")) > {
+ '--SubApp1.show_config',
+ '--SubApp1.log_level',
+ '--SubApp1.log_format',
+ }
+ finally:
+ SubApp1.clear_instance()
+
+ def test_complete_subcommands_subapp2(self, argcomplete_on):
+ app = MainApp()
+ try:
+ assert set(self.run_completer(app, "app subapp2 --")) > {
+ '--Application.',
+ '--SubApp2.',
+ }
+ finally:
+ SubApp2.clear_instance()
+
+ def test_complete_subcommands_main(self, argcomplete_on):
+ app = MainApp()
+ completions = set(self.run_completer(app, "app --"))
+ assert completions > {'--Application.', '--MainApp.'}
+ assert "--SubApp1." not in completions and "--SubApp2." not in completions
diff --git a/contrib/python/traitlets/py3/tests/config/test_configurable.py b/contrib/python/traitlets/py3/tests/config/test_configurable.py
new file mode 100644
index 00000000000..f6499ea29d1
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/config/test_configurable.py
@@ -0,0 +1,711 @@
+"""Tests for traitlets.config.configurable"""
+
+# Copyright (c) IPython Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+import logging
+from unittest import TestCase
+
+from pytest import mark
+
+from .._warnings import expected_warnings
+from traitlets.config.application import Application
+from traitlets.config.configurable import Configurable, LoggingConfigurable, SingletonConfigurable
+from traitlets.config.loader import Config
+from traitlets.log import get_logger
+from traitlets.traitlets import (
+ CaselessStrEnum,
+ Dict,
+ Enum,
+ Float,
+ FuzzyEnum,
+ Integer,
+ List,
+ Set,
+ Unicode,
+ validate,
+)
+from traitlets.utils.warnings import _deprecations_shown
+
+
+class MyConfigurable(Configurable):
+ a = Integer(1, help="The integer a.").tag(config=True)
+ b = Float(1.0, help="The integer b.").tag(config=True)
+ c = Unicode("no config")
+
+
+mc_help = """MyConfigurable(Configurable) options
+------------------------------------
+--MyConfigurable.a=<Integer>
+ The integer a.
+ Default: 1
+--MyConfigurable.b=<Float>
+ The integer b.
+ Default: 1.0"""
+
+mc_help_inst = """MyConfigurable(Configurable) options
+------------------------------------
+--MyConfigurable.a=<Integer>
+ The integer a.
+ Current: 5
+--MyConfigurable.b=<Float>
+ The integer b.
+ Current: 4.0"""
+
+# On Python 3, the Integer trait is a synonym for Int
+mc_help = mc_help.replace("<Integer>", "<Int>")
+mc_help_inst = mc_help_inst.replace("<Integer>", "<Int>")
+
+
+class Foo(Configurable):
+ a = Integer(0, help="The integer a.").tag(config=True)
+ b = Unicode("nope").tag(config=True)
+ flist = List([]).tag(config=True)
+ fdict = Dict().tag(config=True)
+
+
+class Bar(Foo):
+ b = Unicode("gotit", help="The string b.").tag(config=False)
+ c = Float(help="The string c.").tag(config=True)
+ bset = Set([]).tag(config=True, multiplicity="+")
+ bset_values = Set([2, 1, 5]).tag(config=True, multiplicity="+")
+ bdict = Dict().tag(config=True, multiplicity="+")
+ bdict_values = Dict({1: "a", "0": "b", 5: "c"}).tag(config=True, multiplicity="+")
+
+
+foo_help = """Foo(Configurable) options
+-------------------------
+--Foo.a=<Int>
+ The integer a.
+ Default: 0
+--Foo.b=<Unicode>
+ Default: 'nope'
+--Foo.fdict=<key-1>=<value-1>...
+ Default: {}
+--Foo.flist=<list-item-1>...
+ Default: []"""
+
+bar_help = """Bar(Foo) options
+----------------
+--Bar.a=<Int>
+ The integer a.
+ Default: 0
+--Bar.bdict <key-1>=<value-1>...
+ Default: {}
+--Bar.bdict_values <key-1>=<value-1>...
+ Default: {1: 'a', '0': 'b', 5: 'c'}
+--Bar.bset <set-item-1>...
+ Default: set()
+--Bar.bset_values <set-item-1>...
+ Default: {1, 2, 5}
+--Bar.c=<Float>
+ The string c.
+ Default: 0.0
+--Bar.fdict=<key-1>=<value-1>...
+ Default: {}
+--Bar.flist=<list-item-1>...
+ Default: []"""
+
+
+class TestConfigurable(TestCase):
+ def test_default(self):
+ c1 = Configurable()
+ c2 = Configurable(config=c1.config)
+ c3 = Configurable(config=c2.config)
+ self.assertEqual(c1.config, c2.config)
+ self.assertEqual(c2.config, c3.config)
+
+ def test_custom(self):
+ config = Config()
+ config.foo = "foo"
+ config.bar = "bar"
+ c1 = Configurable(config=config)
+ c2 = Configurable(config=c1.config)
+ c3 = Configurable(config=c2.config)
+ self.assertEqual(c1.config, config)
+ self.assertEqual(c2.config, config)
+ self.assertEqual(c3.config, config)
+ # Test that copies are not made
+ self.assertTrue(c1.config is config)
+ self.assertTrue(c2.config is config)
+ self.assertTrue(c3.config is config)
+ self.assertTrue(c1.config is c2.config)
+ self.assertTrue(c2.config is c3.config)
+
+ def test_inheritance(self):
+ config = Config()
+ config.MyConfigurable.a = 2
+ config.MyConfigurable.b = 2.0
+ c1 = MyConfigurable(config=config)
+ c2 = MyConfigurable(config=c1.config)
+ self.assertEqual(c1.a, config.MyConfigurable.a)
+ self.assertEqual(c1.b, config.MyConfigurable.b)
+ self.assertEqual(c2.a, config.MyConfigurable.a)
+ self.assertEqual(c2.b, config.MyConfigurable.b)
+
+ def test_parent(self):
+ config = Config()
+ config.Foo.a = 10
+ config.Foo.b = "wow"
+ config.Bar.b = "later"
+ config.Bar.c = 100.0
+ f = Foo(config=config)
+ with expected_warnings(["`b` not recognized"]):
+ b = Bar(config=f.config)
+ self.assertEqual(f.a, 10)
+ self.assertEqual(f.b, "wow")
+ self.assertEqual(b.b, "gotit")
+ self.assertEqual(b.c, 100.0)
+
+ def test_override1(self):
+ config = Config()
+ config.MyConfigurable.a = 2
+ config.MyConfigurable.b = 2.0
+ c = MyConfigurable(a=3, config=config)
+ self.assertEqual(c.a, 3)
+ self.assertEqual(c.b, config.MyConfigurable.b)
+ self.assertEqual(c.c, "no config")
+
+ def test_override2(self):
+ config = Config()
+ config.Foo.a = 1
+ config.Bar.b = "or" # Up above b is config=False, so this won't do it.
+ config.Bar.c = 10.0
+ with expected_warnings(["`b` not recognized"]):
+ c = Bar(config=config)
+ self.assertEqual(c.a, config.Foo.a)
+ self.assertEqual(c.b, "gotit")
+ self.assertEqual(c.c, config.Bar.c)
+ with expected_warnings(["`b` not recognized"]):
+ c = Bar(a=2, b="and", c=20.0, config=config)
+ self.assertEqual(c.a, 2)
+ self.assertEqual(c.b, "and")
+ self.assertEqual(c.c, 20.0)
+
+ def test_help(self):
+ self.assertEqual(MyConfigurable.class_get_help(), mc_help)
+ self.assertEqual(Foo.class_get_help(), foo_help)
+ self.assertEqual(Bar.class_get_help(), bar_help)
+
+ def test_help_inst(self):
+ inst = MyConfigurable(a=5, b=4)
+ self.assertEqual(MyConfigurable.class_get_help(inst), mc_help_inst)
+
+ def test_generated_config_enum_comments(self):
+ class MyConf(Configurable):
+ an_enum = Enum("Choice1 choice2".split(), help="Many choices.").tag(config=True)
+
+ help_str = "Many choices."
+ enum_choices_str = "Choices: any of ['Choice1', 'choice2']"
+ rst_choices_str = "MyConf.an_enum : any of ``'Choice1'``|``'choice2'``"
+ or_none_str = "or None"
+
+ cls_help = MyConf.class_get_help()
+
+ self.assertIn(help_str, cls_help)
+ self.assertIn(enum_choices_str, cls_help)
+ self.assertNotIn(or_none_str, cls_help)
+
+ cls_cfg = MyConf.class_config_section()
+
+ self.assertIn(help_str, cls_cfg)
+ self.assertIn(enum_choices_str, cls_cfg)
+ self.assertNotIn(or_none_str, cls_help)
+ # Check order of Help-msg <--> Choices sections
+ self.assertGreater(cls_cfg.index(enum_choices_str), cls_cfg.index(help_str))
+
+ rst_help = MyConf.class_config_rst_doc()
+
+ self.assertIn(help_str, rst_help)
+ self.assertIn(rst_choices_str, rst_help)
+ self.assertNotIn(or_none_str, rst_help)
+
+ class MyConf2(Configurable):
+ an_enum = Enum(
+ "Choice1 choice2".split(),
+ allow_none=True,
+ default_value="choice2",
+ help="Many choices.",
+ ).tag(config=True)
+
+ defaults_str = "Default: 'choice2'"
+
+ cls2_msg = MyConf2.class_get_help()
+
+ self.assertIn(help_str, cls2_msg)
+ self.assertIn(enum_choices_str, cls2_msg)
+ self.assertIn(or_none_str, cls2_msg)
+ self.assertIn(defaults_str, cls2_msg)
+ # Check order of Default <--> Choices sections
+ self.assertGreater(cls2_msg.index(defaults_str), cls2_msg.index(enum_choices_str))
+
+ cls2_cfg = MyConf2.class_config_section()
+
+ self.assertIn(help_str, cls2_cfg)
+ self.assertIn(enum_choices_str, cls2_cfg)
+ self.assertIn(or_none_str, cls2_cfg)
+ self.assertIn(defaults_str, cls2_cfg)
+ # Check order of Default <--> Choices sections
+ self.assertGreater(cls2_cfg.index(defaults_str), cls2_cfg.index(enum_choices_str))
+
+ def test_generated_config_strenum_comments(self):
+ help_str = "Many choices."
+ defaults_str = "Default: 'choice2'"
+ or_none_str = "or None"
+
+ class MyConf3(Configurable):
+ an_enum = CaselessStrEnum(
+ "Choice1 choice2".split(),
+ allow_none=True,
+ default_value="choice2",
+ help="Many choices.",
+ ).tag(config=True)
+
+ enum_choices_str = "Choices: any of ['Choice1', 'choice2'] (case-insensitive)"
+
+ cls3_msg = MyConf3.class_get_help()
+
+ self.assertIn(help_str, cls3_msg)
+ self.assertIn(enum_choices_str, cls3_msg)
+ self.assertIn(or_none_str, cls3_msg)
+ self.assertIn(defaults_str, cls3_msg)
+ # Check order of Default <--> Choices sections
+ self.assertGreater(cls3_msg.index(defaults_str), cls3_msg.index(enum_choices_str))
+
+ cls3_cfg = MyConf3.class_config_section()
+
+ self.assertIn(help_str, cls3_cfg)
+ self.assertIn(enum_choices_str, cls3_cfg)
+ self.assertIn(or_none_str, cls3_cfg)
+ self.assertIn(defaults_str, cls3_cfg)
+ # Check order of Default <--> Choices sections
+ self.assertGreater(cls3_cfg.index(defaults_str), cls3_cfg.index(enum_choices_str))
+
+ class MyConf4(Configurable):
+ an_enum = FuzzyEnum(
+ "Choice1 choice2".split(),
+ allow_none=True,
+ default_value="choice2",
+ help="Many choices.",
+ ).tag(config=True)
+
+ enum_choices_str = "Choices: any case-insensitive prefix of ['Choice1', 'choice2']"
+
+ cls4_msg = MyConf4.class_get_help()
+
+ self.assertIn(help_str, cls4_msg)
+ self.assertIn(enum_choices_str, cls4_msg)
+ self.assertIn(or_none_str, cls4_msg)
+ self.assertIn(defaults_str, cls4_msg)
+ # Check order of Default <--> Choices sections
+ self.assertGreater(cls4_msg.index(defaults_str), cls4_msg.index(enum_choices_str))
+
+ cls4_cfg = MyConf4.class_config_section()
+
+ self.assertIn(help_str, cls4_cfg)
+ self.assertIn(enum_choices_str, cls4_cfg)
+ self.assertIn(or_none_str, cls4_cfg)
+ self.assertIn(defaults_str, cls4_cfg)
+ # Check order of Default <--> Choices sections
+ self.assertGreater(cls4_cfg.index(defaults_str), cls4_cfg.index(enum_choices_str))
+
+
+class TestSingletonConfigurable(TestCase):
+ def test_instance(self):
+ class Foo(SingletonConfigurable):
+ pass
+
+ self.assertEqual(Foo.initialized(), False)
+ foo = Foo.instance()
+ self.assertEqual(Foo.initialized(), True)
+ self.assertEqual(foo, Foo.instance())
+ self.assertEqual(SingletonConfigurable._instance, None)
+
+ def test_inheritance(self):
+ class Bar(SingletonConfigurable):
+ pass
+
+ class Bam(Bar):
+ pass
+
+ self.assertEqual(Bar.initialized(), False)
+ self.assertEqual(Bam.initialized(), False)
+ bam = Bam.instance()
+ self.assertEqual(Bar.initialized(), True)
+ self.assertEqual(Bam.initialized(), True)
+ self.assertEqual(bam, Bam._instance)
+ self.assertEqual(bam, Bar._instance)
+ self.assertEqual(SingletonConfigurable._instance, None)
+
+
+class TestLoggingConfigurable(TestCase):
+ def test_parent_logger(self):
+ class Parent(LoggingConfigurable):
+ pass
+
+ class Child(LoggingConfigurable):
+ pass
+
+ log = get_logger().getChild("TestLoggingConfigurable")
+
+ parent = Parent(log=log)
+ child = Child(parent=parent)
+ self.assertEqual(parent.log, log)
+ self.assertEqual(child.log, log)
+
+ parent = Parent()
+ child = Child(parent=parent, log=log)
+ self.assertEqual(parent.log, get_logger())
+ self.assertEqual(child.log, log)
+
+ def test_parent_not_logging_configurable(self):
+ class Parent(Configurable):
+ pass
+
+ class Child(LoggingConfigurable):
+ pass
+
+ parent = Parent()
+ child = Child(parent=parent)
+ self.assertEqual(child.log, get_logger())
+
+
+class MyParent(Configurable):
+ pass
+
+
+class MyParent2(MyParent):
+ pass
+
+
+class TestParentConfigurable(TestCase):
+ def test_parent_config(self):
+ cfg = Config(
+ {
+ "MyParent": {
+ "MyConfigurable": {
+ "b": 2.0,
+ }
+ }
+ }
+ )
+ parent = MyParent(config=cfg)
+ myc = MyConfigurable(parent=parent)
+ self.assertEqual(myc.b, parent.config.MyParent.MyConfigurable.b)
+
+ def test_parent_inheritance(self):
+ cfg = Config(
+ {
+ "MyParent": {
+ "MyConfigurable": {
+ "b": 2.0,
+ }
+ }
+ }
+ )
+ parent = MyParent2(config=cfg)
+ myc = MyConfigurable(parent=parent)
+ self.assertEqual(myc.b, parent.config.MyParent.MyConfigurable.b)
+
+ def test_multi_parent(self):
+ cfg = Config(
+ {
+ "MyParent2": {
+ "MyParent": {
+ "MyConfigurable": {
+ "b": 2.0,
+ }
+ },
+ # this one shouldn't count
+ "MyConfigurable": {
+ "b": 3.0,
+ },
+ }
+ }
+ )
+ parent2 = MyParent2(config=cfg)
+ parent = MyParent(parent=parent2)
+ myc = MyConfigurable(parent=parent)
+ self.assertEqual(myc.b, parent.config.MyParent2.MyParent.MyConfigurable.b)
+
+ def test_parent_priority(self):
+ cfg = Config(
+ {
+ "MyConfigurable": {
+ "b": 2.0,
+ },
+ "MyParent": {
+ "MyConfigurable": {
+ "b": 3.0,
+ }
+ },
+ "MyParent2": {
+ "MyConfigurable": {
+ "b": 4.0,
+ }
+ },
+ }
+ )
+ parent = MyParent2(config=cfg)
+ myc = MyConfigurable(parent=parent)
+ self.assertEqual(myc.b, parent.config.MyParent2.MyConfigurable.b)
+
+ def test_multi_parent_priority(self):
+ cfg = Config(
+ {
+ "MyConfigurable": {
+ "b": 2.0,
+ },
+ "MyParent": {
+ "MyConfigurable": {
+ "b": 3.0,
+ },
+ },
+ "MyParent2": {
+ "MyConfigurable": {
+ "b": 4.0,
+ },
+ "MyParent": {
+ "MyConfigurable": {
+ "b": 5.0,
+ },
+ },
+ },
+ }
+ )
+ parent2 = MyParent2(config=cfg)
+ parent = MyParent2(parent=parent2)
+ myc = MyConfigurable(parent=parent)
+ self.assertEqual(myc.b, parent.config.MyParent2.MyParent.MyConfigurable.b)
+
+
+class Containers(Configurable):
+ lis = List().tag(config=True)
+
+ def _lis_default(self):
+ return [-1]
+
+ s = Set().tag(config=True)
+
+ def _s_default(self):
+ return {"a"}
+
+ d = Dict().tag(config=True)
+
+ def _d_default(self):
+ return {"a": "b"}
+
+
+class TestConfigContainers(TestCase):
+ def test_extend(self):
+ c = Config()
+ c.Containers.lis.extend(list(range(5)))
+ obj = Containers(config=c)
+ self.assertEqual(obj.lis, list(range(-1, 5)))
+
+ def test_insert(self):
+ c = Config()
+ c.Containers.lis.insert(0, "a")
+ c.Containers.lis.insert(1, "b")
+ obj = Containers(config=c)
+ self.assertEqual(obj.lis, ["a", "b", -1])
+
+ def test_prepend(self):
+ c = Config()
+ c.Containers.lis.prepend([1, 2])
+ c.Containers.lis.prepend([2, 3])
+ obj = Containers(config=c)
+ self.assertEqual(obj.lis, [2, 3, 1, 2, -1])
+
+ def test_prepend_extend(self):
+ c = Config()
+ c.Containers.lis.prepend([1, 2])
+ c.Containers.lis.extend([2, 3])
+ obj = Containers(config=c)
+ self.assertEqual(obj.lis, [1, 2, -1, 2, 3])
+
+ def test_append_extend(self):
+ c = Config()
+ c.Containers.lis.append([1, 2])
+ c.Containers.lis.extend([2, 3])
+ obj = Containers(config=c)
+ self.assertEqual(obj.lis, [-1, [1, 2], 2, 3])
+
+ def test_extend_append(self):
+ c = Config()
+ c.Containers.lis.extend([2, 3])
+ c.Containers.lis.append([1, 2])
+ obj = Containers(config=c)
+ self.assertEqual(obj.lis, [-1, 2, 3, [1, 2]])
+
+ def test_insert_extend(self):
+ c = Config()
+ c.Containers.lis.insert(0, 1)
+ c.Containers.lis.extend([2, 3])
+ obj = Containers(config=c)
+ self.assertEqual(obj.lis, [1, -1, 2, 3])
+
+ def test_set_update(self):
+ c = Config()
+ c.Containers.s.update({0, 1, 2})
+ c.Containers.s.update({3})
+ obj = Containers(config=c)
+ self.assertEqual(obj.s, {"a", 0, 1, 2, 3})
+
+ def test_dict_update(self):
+ c = Config()
+ c.Containers.d.update({"c": "d"})
+ c.Containers.d.update({"e": "f"})
+ obj = Containers(config=c)
+ self.assertEqual(obj.d, {"a": "b", "c": "d", "e": "f"})
+
+ def test_update_twice(self):
+ c = Config()
+ c.MyConfigurable.a = 5
+ m = MyConfigurable(config=c)
+ self.assertEqual(m.a, 5)
+
+ c2 = Config()
+ c2.MyConfigurable.a = 10
+ m.update_config(c2)
+ self.assertEqual(m.a, 10)
+
+ c2.MyConfigurable.a = 15
+ m.update_config(c2)
+ self.assertEqual(m.a, 15)
+
+ def test_update_self(self):
+ """update_config with same config object still triggers config_changed"""
+ c = Config()
+ c.MyConfigurable.a = 5
+ m = MyConfigurable(config=c)
+ self.assertEqual(m.a, 5)
+ c.MyConfigurable.a = 10
+ m.update_config(c)
+ self.assertEqual(m.a, 10)
+
+ def test_config_default(self):
+ class SomeSingleton(SingletonConfigurable):
+ pass
+
+ class DefaultConfigurable(Configurable):
+ a = Integer().tag(config=True)
+
+ def _config_default(self):
+ if SomeSingleton.initialized():
+ return SomeSingleton.instance().config
+ return Config()
+
+ c = Config()
+ c.DefaultConfigurable.a = 5
+
+ d1 = DefaultConfigurable()
+ self.assertEqual(d1.a, 0)
+
+ single = SomeSingleton.instance(config=c)
+
+ d2 = DefaultConfigurable()
+ self.assertIs(d2.config, single.config)
+ self.assertEqual(d2.a, 5)
+
+ def test_config_default_deprecated(self):
+ """Make sure configurables work even with the deprecations in traitlets"""
+
+ class SomeSingleton(SingletonConfigurable):
+ pass
+
+ # reset deprecation limiter
+ _deprecations_shown.clear()
+ with expected_warnings([]):
+
+ class DefaultConfigurable(Configurable):
+ a = Integer(config=True)
+
+ def _config_default(self):
+ if SomeSingleton.initialized():
+ return SomeSingleton.instance().config
+ return Config()
+
+ c = Config()
+ c.DefaultConfigurable.a = 5
+
+ d1 = DefaultConfigurable()
+ self.assertEqual(d1.a, 0)
+
+ single = SomeSingleton.instance(config=c)
+
+ d2 = DefaultConfigurable()
+ self.assertIs(d2.config, single.config)
+ self.assertEqual(d2.a, 5)
+
+ def test_kwarg_config_priority(self):
+ # a, c set in kwargs
+ # a, b set in config
+ # verify that:
+ # - kwargs are set before config
+ # - kwargs have priority over config
+ class A(Configurable):
+ a = Unicode("default", config=True)
+ b = Unicode("default", config=True)
+ c = Unicode("default", config=True)
+ c_during_config = Unicode("never")
+
+ @validate("b")
+ def _record_c(self, proposal):
+ # setting b from config records c's value at the time
+ self.c_during_config = self.c
+ return proposal.value
+
+ cfg = Config()
+ cfg.A.a = "a-config"
+ cfg.A.b = "b-config"
+ obj = A(a="a-kwarg", c="c-kwarg", config=cfg)
+ assert obj.a == "a-kwarg"
+ assert obj.b == "b-config"
+ assert obj.c == "c-kwarg"
+ assert obj.c_during_config == "c-kwarg"
+
+
+class TestLogger(TestCase):
+ class A(LoggingConfigurable):
+ foo = Integer(config=True)
+ bar = Integer(config=True)
+ baz = Integer(config=True)
+
+ @mark.skipif(not hasattr(TestCase, "assertLogs"), reason="requires TestCase.assertLogs")
+ def test_warn_match(self):
+ logger = logging.getLogger("test_warn_match")
+ cfg = Config({"A": {"bat": 5}})
+ with self.assertLogs(logger, logging.WARNING) as captured:
+ TestLogger.A(config=cfg, log=logger)
+
+ output = "\n".join(captured.output)
+ self.assertIn("Did you mean one of: `bar, baz`?", output)
+ self.assertIn("Config option `bat` not recognized by `A`.", output)
+
+ cfg = Config({"A": {"fool": 5}})
+ with self.assertLogs(logger, logging.WARNING) as captured:
+ TestLogger.A(config=cfg, log=logger)
+
+ output = "\n".join(captured.output)
+ self.assertIn("Config option `fool` not recognized by `A`.", output)
+ self.assertIn("Did you mean `foo`?", output)
+
+ cfg = Config({"A": {"totally_wrong": 5}})
+ with self.assertLogs(logger, logging.WARNING) as captured:
+ TestLogger.A(config=cfg, log=logger)
+
+ output = "\n".join(captured.output)
+ self.assertIn("Config option `totally_wrong` not recognized by `A`.", output)
+ self.assertNotIn("Did you mean", output)
+
+
+def test_logger_adapter(caplog, capsys):
+ logger = logging.getLogger("Application")
+ adapter = logging.LoggerAdapter(logger, {"key": "adapted"})
+
+ app = Application(log=adapter, log_level=logging.INFO)
+ app.log_format = "%(key)s %(message)s"
+ app.log.info("test message")
+
+ assert "adapted test message" in capsys.readouterr().err
diff --git a/contrib/python/traitlets/py3/tests/config/test_loader.py b/contrib/python/traitlets/py3/tests/config/test_loader.py
new file mode 100644
index 00000000000..3a1f96120f7
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/config/test_loader.py
@@ -0,0 +1,753 @@
+"""Tests for traitlets.config.loader"""
+
+# Copyright (c) IPython Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+import copy
+import os
+import pickle
+from itertools import chain
+from tempfile import mkstemp
+from unittest import TestCase
+
+import pytest
+
+from traitlets import Dict, Integer, List, Tuple, Unicode
+from traitlets.config import Configurable
+from traitlets.config.loader import (
+ ArgParseConfigLoader,
+ Config,
+ JSONFileConfigLoader,
+ KeyValueConfigLoader,
+ KVArgParseConfigLoader,
+ LazyConfigValue,
+ PyFileConfigLoader,
+)
+
+pyfile = """
+c = get_config()
+c.a=10
+c.b=20
+c.Foo.Bar.value=10
+c.Foo.Bam.value=list(range(10))
+c.D.C.value='hi there'
+"""
+
+json1file = """
+{
+ "version": 1,
+ "a": 10,
+ "b": 20,
+ "Foo": {
+ "Bam": {
+ "value": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 ]
+ },
+ "Bar": {
+ "value": 10
+ }
+ },
+ "D": {
+ "C": {
+ "value": "hi there"
+ }
+ }
+}
+"""
+
+# should not load
+json2file = """
+{
+ "version": 2
+}
+"""
+
+import logging
+
+log = logging.getLogger("devnull")
+log.setLevel(0)
+
+
+class TestFileCL(TestCase):
+ def _check_conf(self, config):
+ self.assertEqual(config.a, 10)
+ self.assertEqual(config.b, 20)
+ self.assertEqual(config.Foo.Bar.value, 10)
+ self.assertEqual(config.Foo.Bam.value, list(range(10)))
+ self.assertEqual(config.D.C.value, "hi there")
+
+ def test_python(self):
+ fd, fname = mkstemp(".py", prefix="μnïcø∂e")
+ f = os.fdopen(fd, "w")
+ f.write(pyfile)
+ f.close()
+ # Unlink the file
+ cl = PyFileConfigLoader(fname, log=log)
+ config = cl.load_config()
+ self._check_conf(config)
+
+ def test_json(self):
+ fd, fname = mkstemp(".json", prefix="μnïcø∂e")
+ f = os.fdopen(fd, "w")
+ f.write(json1file)
+ f.close()
+ # Unlink the file
+ cl = JSONFileConfigLoader(fname, log=log)
+ config = cl.load_config()
+ self._check_conf(config)
+
+ def test_context_manager(self):
+ fd, fname = mkstemp(".json", prefix="μnïcø∂e")
+ f = os.fdopen(fd, "w")
+ f.write("{}")
+ f.close()
+
+ cl = JSONFileConfigLoader(fname, log=log)
+
+ value = "context_manager"
+
+ with cl as c:
+ c.MyAttr.value = value
+
+ self.assertEqual(cl.config.MyAttr.value, value)
+
+ # check that another loader does see the change
+ _ = JSONFileConfigLoader(fname, log=log)
+ self.assertEqual(cl.config.MyAttr.value, value)
+
+ def test_json_context_bad_write(self):
+ fd, fname = mkstemp(".json", prefix="μnïcø∂e")
+ f = os.fdopen(fd, "w")
+ f.write("{}")
+ f.close()
+
+ with JSONFileConfigLoader(fname, log=log) as config:
+ config.A.b = 1
+
+ with self.assertRaises(TypeError):
+ with JSONFileConfigLoader(fname, log=log) as config:
+ config.A.cant_json = lambda x: x
+
+ loader = JSONFileConfigLoader(fname, log=log)
+ cfg = loader.load_config()
+ assert cfg.A.b == 1
+ assert "cant_json" not in cfg.A
+
+ def test_collision(self):
+ a = Config()
+ b = Config()
+ self.assertEqual(a.collisions(b), {})
+ a.A.trait1 = 1
+ b.A.trait2 = 2
+ self.assertEqual(a.collisions(b), {})
+ b.A.trait1 = 1
+ self.assertEqual(a.collisions(b), {})
+ b.A.trait1 = 0
+ self.assertEqual(
+ a.collisions(b),
+ {
+ "A": {
+ "trait1": "1 ignored, using 0",
+ }
+ },
+ )
+ self.assertEqual(
+ b.collisions(a),
+ {
+ "A": {
+ "trait1": "0 ignored, using 1",
+ }
+ },
+ )
+ a.A.trait2 = 3
+ self.assertEqual(
+ b.collisions(a),
+ {
+ "A": {
+ "trait1": "0 ignored, using 1",
+ "trait2": "2 ignored, using 3",
+ }
+ },
+ )
+
+ def test_v2raise(self):
+ fd, fname = mkstemp(".json", prefix="μnïcø∂e")
+ f = os.fdopen(fd, "w")
+ f.write(json2file)
+ f.close()
+ # Unlink the file
+ cl = JSONFileConfigLoader(fname, log=log)
+ with self.assertRaises(ValueError):
+ cl.load_config()
+
+
+def _parse_int_or_str(v):
+ try:
+ return int(v)
+ except Exception:
+ return str(v)
+
+
+class MyLoader1(ArgParseConfigLoader):
+ def _add_arguments(self, aliases=None, flags=None, classes=None):
+ p = self.parser
+ p.add_argument("-f", "--foo", dest="Global.foo", type=str)
+ p.add_argument("-b", dest="MyClass.bar", type=int)
+ p.add_argument("-n", dest="n", action="store_true")
+ p.add_argument("Global.bam", type=str)
+ p.add_argument("--list1", action="append", type=_parse_int_or_str)
+ p.add_argument("--list2", nargs="+", type=int)
+
+
+class MyLoader2(ArgParseConfigLoader):
+ def _add_arguments(self, aliases=None, flags=None, classes=None):
+ subparsers = self.parser.add_subparsers(dest="subparser_name")
+ subparser1 = subparsers.add_parser("1")
+ subparser1.add_argument("-x", dest="Global.x")
+ subparser2 = subparsers.add_parser("2")
+ subparser2.add_argument("y")
+
+
+class TestArgParseCL(TestCase):
+ def test_basic(self):
+ cl = MyLoader1()
+ config = cl.load_config("-f hi -b 10 -n wow".split())
+ self.assertEqual(config.Global.foo, "hi")
+ self.assertEqual(config.MyClass.bar, 10)
+ self.assertEqual(config.n, True)
+ self.assertEqual(config.Global.bam, "wow")
+ config = cl.load_config(["wow"])
+ self.assertEqual(list(config.keys()), ["Global"])
+ self.assertEqual(list(config.Global.keys()), ["bam"])
+ self.assertEqual(config.Global.bam, "wow")
+
+ def test_add_arguments(self):
+ cl = MyLoader2()
+ config = cl.load_config("2 frobble".split())
+ self.assertEqual(config.subparser_name, "2")
+ self.assertEqual(config.y, "frobble")
+ config = cl.load_config("1 -x frobble".split())
+ self.assertEqual(config.subparser_name, "1")
+ self.assertEqual(config.Global.x, "frobble")
+
+ def test_argv(self):
+ cl = MyLoader1(argv="-f hi -b 10 -n wow".split())
+ config = cl.load_config()
+ self.assertEqual(config.Global.foo, "hi")
+ self.assertEqual(config.MyClass.bar, 10)
+ self.assertEqual(config.n, True)
+ self.assertEqual(config.Global.bam, "wow")
+
+ def test_list_args(self):
+ cl = MyLoader1()
+ config = cl.load_config("--list1 1 wow --list2 1 2 3 --list1 B".split())
+ self.assertEqual(list(config.Global.keys()), ["bam"])
+ self.assertEqual(config.Global.bam, "wow")
+ self.assertEqual(config.list1, [1, "B"])
+ self.assertEqual(config.list2, [1, 2, 3])
+
+
+class C(Configurable):
+ str_trait = Unicode(config=True)
+ int_trait = Integer(config=True)
+ list_trait = List(config=True)
+ list_of_ints = List(Integer(), config=True)
+ dict_trait = Dict(config=True)
+ dict_of_ints = Dict(
+ key_trait=Integer(),
+ value_trait=Integer(),
+ config=True,
+ )
+ dict_multi = Dict(
+ key_trait=Unicode(),
+ per_key_traits={
+ "int": Integer(),
+ "str": Unicode(),
+ },
+ config=True,
+ )
+
+
+class TestKeyValueCL(TestCase):
+ klass = KeyValueConfigLoader
+
+ def test_eval(self):
+ cl = self.klass(log=log)
+ config = cl.load_config(
+ '--C.str_trait=all --C.int_trait=5 --C.list_trait=["hello",5]'.split()
+ )
+ c = C(config=config)
+ assert c.str_trait == "all"
+ assert c.int_trait == 5
+ assert c.list_trait == ["hello", 5]
+
+ def test_basic(self):
+ cl = self.klass(log=log)
+ argv = ["--" + s[2:] for s in pyfile.split("\n") if s.startswith("c.")]
+ config = cl.load_config(argv)
+ assert config.a == "10"
+ assert config.b == "20"
+ assert config.Foo.Bar.value == "10"
+ # non-literal expressions are not evaluated
+ self.assertEqual(config.Foo.Bam.value, "list(range(10))")
+ self.assertEqual(Unicode().from_string(config.D.C.value), "hi there")
+
+ def test_expanduser(self):
+ cl = self.klass(log=log)
+ argv = ["--a=~/1/2/3", "--b=~", "--c=~/", '--d="~/"']
+ config = cl.load_config(argv)
+ u = Unicode()
+ self.assertEqual(u.from_string(config.a), os.path.expanduser("~/1/2/3"))
+ self.assertEqual(u.from_string(config.b), os.path.expanduser("~"))
+ self.assertEqual(u.from_string(config.c), os.path.expanduser("~/"))
+ self.assertEqual(u.from_string(config.d), "~/")
+
+ def test_extra_args(self):
+ cl = self.klass(log=log)
+ config = cl.load_config(["--a=5", "b", "d", "--c=10"])
+ self.assertEqual(cl.extra_args, ["b", "d"])
+ assert config.a == "5"
+ assert config.c == "10"
+ config = cl.load_config(["--", "--a=5", "--c=10"])
+ self.assertEqual(cl.extra_args, ["--a=5", "--c=10"])
+
+ cl = self.klass(log=log)
+ config = cl.load_config(["extra", "--a=2", "--c=1", "--", "-"])
+ self.assertEqual(cl.extra_args, ["extra", "-"])
+
+ def test_unicode_args(self):
+ cl = self.klass(log=log)
+ argv = ["--a=épsîlön"]
+ config = cl.load_config(argv)
+ print(config, cl.extra_args)
+ self.assertEqual(config.a, "épsîlön")
+
+ def test_list_append(self):
+ cl = self.klass(log=log)
+ argv = ["--C.list_trait", "x", "--C.list_trait", "y"]
+ config = cl.load_config(argv)
+ assert config.C.list_trait == ["x", "y"]
+ c = C(config=config)
+ assert c.list_trait == ["x", "y"]
+
+ def test_list_single_item(self):
+ cl = self.klass(log=log)
+ argv = ["--C.list_trait", "x"]
+ config = cl.load_config(argv)
+ c = C(config=config)
+ assert c.list_trait == ["x"]
+
+ def test_dict(self):
+ cl = self.klass(log=log)
+ argv = ["--C.dict_trait", "x=5", "--C.dict_trait", "y=10"]
+ config = cl.load_config(argv)
+ c = C(config=config)
+ assert c.dict_trait == {"x": "5", "y": "10"}
+
+ def test_dict_key_traits(self):
+ cl = self.klass(log=log)
+ argv = ["--C.dict_of_ints", "1=2", "--C.dict_of_ints", "3=4"]
+ config = cl.load_config(argv)
+ c = C(config=config)
+ assert c.dict_of_ints == {1: 2, 3: 4}
+
+
+class CBase(Configurable):
+ a = List().tag(config=True)
+ b = List(Integer()).tag(config=True, multiplicity="*")
+ c = List().tag(config=True, multiplicity="append")
+ adict = Dict().tag(config=True)
+
+
+class CSub(CBase):
+ d = Tuple().tag(config=True)
+ e = Tuple().tag(config=True, multiplicity="+")
+ bdict = Dict().tag(config=True, multiplicity="*")
+
+
+class TestArgParseKVCL(TestKeyValueCL):
+ klass = KVArgParseConfigLoader # type:ignore
+
+ def test_no_cast_literals(self):
+ cl = self.klass(log=log) # type:ignore
+ # test ipython -c 1 doesn't cast to int
+ argv = ["-c", "1"]
+ config = cl.load_config(argv, aliases=dict(c="IPython.command_to_run"))
+ assert config.IPython.command_to_run == "1"
+
+ def test_int_literals(self):
+ cl = self.klass(log=log) # type:ignore
+ # test ipython -c 1 doesn't cast to int
+ argv = ["-c", "1"]
+ config = cl.load_config(argv, aliases=dict(c="IPython.command_to_run"))
+ assert config.IPython.command_to_run == "1"
+
+ def test_unicode_alias(self):
+ cl = self.klass(log=log) # type:ignore
+ argv = ["--a=épsîlön"]
+ config = cl.load_config(argv, aliases=dict(a="A.a"))
+ print(dict(config))
+ print(cl.extra_args)
+ print(cl.aliases)
+ self.assertEqual(config.A.a, "épsîlön")
+
+ def test_expanduser2(self):
+ cl = self.klass(log=log) # type:ignore
+ argv = ["-a", "~/1/2/3", "--b", "'~/1/2/3'"]
+ config = cl.load_config(argv, aliases=dict(a="A.a", b="A.b"))
+
+ class A(Configurable):
+ a = Unicode(config=True)
+ b = Unicode(config=True)
+
+ a = A(config=config)
+ self.assertEqual(a.a, os.path.expanduser("~/1/2/3"))
+ self.assertEqual(a.b, "~/1/2/3")
+
+ def test_eval(self):
+ cl = self.klass(log=log) # type:ignore
+ argv = ["-c", "a=5"]
+ config = cl.load_config(argv, aliases=dict(c="A.c"))
+ self.assertEqual(config.A.c, "a=5")
+
+ def test_seq_traits(self):
+ cl = self.klass(log=log, classes=(CBase, CSub)) # type:ignore
+ aliases = {"a3": "CBase.c", "a5": "CSub.e"}
+ argv = (
+ "--CBase.a A --CBase.a 2 --CBase.b 1 --CBase.b 3 --a3 AA --CBase.c BB "
+ "--CSub.d 1 --CSub.d BBB --CSub.e 1 --CSub.e=bcd a b c "
+ ).split()
+ config = cl.load_config(argv, aliases=aliases)
+ assert cl.extra_args == ["a", "b", "c"]
+ assert config.CBase.a == ["A", "2"]
+ assert config.CBase.b == [1, 3]
+ self.assertEqual(config.CBase.c, ["AA", "BB"])
+
+ assert config.CSub.d == ("1", "BBB")
+ assert config.CSub.e == ("1", "bcd")
+
+ def test_seq_traits_single_empty_string(self):
+ cl = self.klass(log=log, classes=(CBase,)) # type:ignore
+ aliases = {"seqopt": "CBase.c"}
+ argv = ["--seqopt", ""]
+ config = cl.load_config(argv, aliases=aliases)
+ self.assertEqual(config.CBase.c, [""])
+
+ def test_dict_traits(self):
+ cl = self.klass(log=log, classes=(CBase, CSub)) # type:ignore
+ aliases = {"D": "CBase.adict", "E": "CSub.bdict"}
+ argv = ["-D", "k1=v1", "-D=k2=2", "-D", "k3=v 3", "-E", "k=v", "-E", "22=222"]
+ config = cl.load_config(argv, aliases=aliases)
+ c = CSub(config=config)
+ assert c.adict == {"k1": "v1", "k2": "2", "k3": "v 3"}
+ assert c.bdict == {"k": "v", "22": "222"}
+
+ def test_mixed_seq_positional(self):
+ aliases = {"c": "Class.trait"}
+ cl = self.klass(log=log, aliases=aliases) # type:ignore
+ assignments = [("-c", "1"), ("--Class.trait=2",), ("--c=3",), ("--Class.trait", "4")]
+ positionals = ["a", "b", "c"]
+ # test with positionals at any index
+ for idx in range(len(assignments) + 1):
+ argv_parts = assignments[:]
+ argv_parts[idx:idx] = (positionals,) # type:ignore
+ argv = list(chain(*argv_parts))
+
+ config = cl.load_config(argv)
+ assert config.Class.trait == ["1", "2", "3", "4"]
+ assert cl.extra_args == ["a", "b", "c"]
+
+ def test_split_positional(self):
+ """Splitting positionals across flags is no longer allowed in traitlets 5"""
+ cl = self.klass(log=log) # type:ignore
+ argv = ["a", "--Class.trait=5", "b"]
+ with pytest.raises(SystemExit):
+ cl.load_config(argv)
+
+
+class TestConfig(TestCase):
+ def test_setget(self):
+ c = Config()
+ c.a = 10
+ self.assertEqual(c.a, 10)
+ self.assertEqual("b" in c, False)
+
+ def test_auto_section(self):
+ c = Config()
+ self.assertNotIn("A", c)
+ assert not c._has_section("A")
+ A = c.A
+ A.foo = "hi there"
+ self.assertIn("A", c)
+ assert c._has_section("A")
+ self.assertEqual(c.A.foo, "hi there")
+ del c.A
+ self.assertEqual(c.A, Config())
+
+ def test_merge_doesnt_exist(self):
+ c1 = Config()
+ c2 = Config()
+ c2.bar = 10
+ c2.Foo.bar = 10
+ c1.merge(c2)
+ self.assertEqual(c1.Foo.bar, 10)
+ self.assertEqual(c1.bar, 10)
+ c2.Bar.bar = 10
+ c1.merge(c2)
+ self.assertEqual(c1.Bar.bar, 10)
+
+ def test_merge_exists(self):
+ c1 = Config()
+ c2 = Config()
+ c1.Foo.bar = 10
+ c1.Foo.bam = 30
+ c2.Foo.bar = 20
+ c2.Foo.wow = 40
+ c1.merge(c2)
+ self.assertEqual(c1.Foo.bam, 30)
+ self.assertEqual(c1.Foo.bar, 20)
+ self.assertEqual(c1.Foo.wow, 40)
+ c2.Foo.Bam.bam = 10
+ c1.merge(c2)
+ self.assertEqual(c1.Foo.Bam.bam, 10)
+
+ def test_deepcopy(self):
+ c1 = Config()
+ c1.Foo.bar = 10
+ c1.Foo.bam = 30
+ c1.a = "asdf"
+ c1.b = range(10)
+ c1.Test.logger = logging.Logger("test")
+ c1.Test.get_logger = logging.getLogger("test")
+ c2 = copy.deepcopy(c1)
+ self.assertEqual(c1, c2)
+ self.assertTrue(c1 is not c2)
+ self.assertTrue(c1.Foo is not c2.Foo)
+ self.assertTrue(c1.Test is not c2.Test)
+ self.assertTrue(c1.Test.logger is c2.Test.logger)
+ self.assertTrue(c1.Test.get_logger is c2.Test.get_logger)
+
+ def test_builtin(self):
+ c1 = Config()
+ c1.format = "json"
+
+ def test_fromdict(self):
+ c1 = Config({"Foo": {"bar": 1}})
+ self.assertEqual(c1.Foo.__class__, Config)
+ self.assertEqual(c1.Foo.bar, 1)
+
+ def test_fromdictmerge(self):
+ c1 = Config()
+ c2 = Config({"Foo": {"bar": 1}})
+ c1.merge(c2)
+ self.assertEqual(c1.Foo.__class__, Config)
+ self.assertEqual(c1.Foo.bar, 1)
+
+ def test_fromdictmerge2(self):
+ c1 = Config({"Foo": {"baz": 2}})
+ c2 = Config({"Foo": {"bar": 1}})
+ c1.merge(c2)
+ self.assertEqual(c1.Foo.__class__, Config)
+ self.assertEqual(c1.Foo.bar, 1)
+ self.assertEqual(c1.Foo.baz, 2)
+ self.assertNotIn("baz", c2.Foo)
+
+ def test_contains(self):
+ c1 = Config({"Foo": {"baz": 2}})
+ c2 = Config({"Foo": {"bar": 1}})
+ self.assertIn("Foo", c1)
+ self.assertIn("Foo.baz", c1)
+ self.assertIn("Foo.bar", c2)
+ self.assertNotIn("Foo.bar", c1)
+
+ def test_pickle_config(self):
+ cfg = Config()
+ cfg.Foo.bar = 1
+ pcfg = pickle.dumps(cfg)
+ cfg2 = pickle.loads(pcfg)
+ self.assertEqual(cfg2, cfg)
+
+ def test_getattr_section(self):
+ cfg = Config()
+ self.assertNotIn("Foo", cfg)
+ Foo = cfg.Foo
+ assert isinstance(Foo, Config)
+ self.assertIn("Foo", cfg)
+
+ def test_getitem_section(self):
+ cfg = Config()
+ self.assertNotIn("Foo", cfg)
+ Foo = cfg["Foo"]
+ assert isinstance(Foo, Config)
+ self.assertIn("Foo", cfg)
+
+ def test_getattr_not_section(self):
+ cfg = Config()
+ self.assertNotIn("foo", cfg)
+ foo = cfg.foo
+ assert isinstance(foo, LazyConfigValue)
+ self.assertIn("foo", cfg)
+
+ def test_getattr_private_missing(self):
+ cfg = Config()
+ self.assertNotIn("_repr_html_", cfg)
+ with self.assertRaises(AttributeError):
+ _ = cfg._repr_html_
+ self.assertNotIn("_repr_html_", cfg)
+ self.assertEqual(len(cfg), 0)
+
+ def test_lazy_config_repr(self):
+ cfg = Config()
+ cfg.Class.lazy.append(1)
+ cfg_repr = repr(cfg)
+ assert "<LazyConfigValue" in cfg_repr
+ assert "extend" in cfg_repr
+ assert " [1]}>" in cfg_repr
+ assert "value=" not in cfg_repr
+ cfg.Class.lazy.get_value([0])
+ repr2 = repr(cfg)
+ assert repr([0, 1]) in repr2
+ assert "value=" in repr2
+
+ def test_getitem_not_section(self):
+ cfg = Config()
+ self.assertNotIn("foo", cfg)
+ foo = cfg["foo"]
+ assert isinstance(foo, LazyConfigValue)
+ self.assertIn("foo", cfg)
+
+ def test_merge_no_copies(self):
+ c = Config()
+ c2 = Config()
+ c2.Foo.trait = []
+ c.merge(c2)
+ c2.Foo.trait.append(1)
+ self.assertIs(c.Foo, c2.Foo)
+ self.assertEqual(c.Foo.trait, [1])
+ self.assertEqual(c2.Foo.trait, [1])
+
+ def test_merge_multi_lazy(self):
+ """
+ With multiple config files (systemwide and users), we want compounding.
+
+ If systemwide overwirte and user append, we want both in the right
+ order.
+ """
+ c1 = Config()
+ c2 = Config()
+
+ c1.Foo.trait = [1]
+ c2.Foo.trait.append(2)
+
+ c = Config()
+ c.merge(c1)
+ c.merge(c2)
+
+ self.assertEqual(c.Foo.trait, [1, 2])
+
+ def test_merge_multi_lazyII(self):
+ """
+ With multiple config files (systemwide and users), we want compounding.
+
+ If both are lazy we still want a lazy config.
+ """
+ c1 = Config()
+ c2 = Config()
+
+ c1.Foo.trait.append(1)
+ c2.Foo.trait.append(2)
+
+ c = Config()
+ c.merge(c1)
+ c.merge(c2)
+
+ self.assertEqual(c.Foo.trait._extend, [1, 2])
+
+ def test_merge_multi_lazy_III(self):
+ """
+ With multiple config files (systemwide and users), we want compounding.
+
+ Prepend should prepend in the right order.
+ """
+ c1 = Config()
+ c2 = Config()
+
+ c1.Foo.trait = [1]
+ c2.Foo.trait.prepend([0])
+
+ c = Config()
+ c.merge(c1)
+ c.merge(c2)
+
+ self.assertEqual(c.Foo.trait, [0, 1])
+
+ def test_merge_multi_lazy_IV(self):
+ """
+ With multiple config files (systemwide and users), we want compounding.
+
+ Both prepending should be lazy
+ """
+ c1 = Config()
+ c2 = Config()
+
+ c1.Foo.trait.prepend([1])
+ c2.Foo.trait.prepend([0])
+
+ c = Config()
+ c.merge(c1)
+ c.merge(c2)
+
+ self.assertEqual(c.Foo.trait._prepend, [0, 1])
+
+ def test_merge_multi_lazy_update_I(self):
+ """
+ With multiple config files (systemwide and users), we want compounding.
+
+ dict update shoudl be in the right order.
+ """
+ c1 = Config()
+ c2 = Config()
+
+ c1.Foo.trait = {"a": 1, "z": 26}
+ c2.Foo.trait.update({"a": 0, "b": 1})
+
+ c = Config()
+ c.merge(c1)
+ c.merge(c2)
+
+ self.assertEqual(c.Foo.trait, {"a": 0, "b": 1, "z": 26})
+
+ def test_merge_multi_lazy_update_II(self):
+ """
+ With multiple config files (systemwide and users), we want compounding.
+
+ Later dict overwrite lazyness
+ """
+ c1 = Config()
+ c2 = Config()
+
+ c1.Foo.trait.update({"a": 0, "b": 1})
+ c2.Foo.trait = {"a": 1, "z": 26}
+
+ c = Config()
+ c.merge(c1)
+ c.merge(c2)
+
+ self.assertEqual(c.Foo.trait, {"a": 1, "z": 26})
+
+ def test_merge_multi_lazy_update_III(self):
+ """
+ With multiple config files (systemwide and users), we want compounding.
+
+ Later dict overwrite lazyness
+ """
+ c1 = Config()
+ c2 = Config()
+
+ c1.Foo.trait.update({"a": 0, "b": 1})
+ c2.Foo.trait.update({"a": 1, "z": 26})
+
+ c = Config()
+ c.merge(c1)
+ c.merge(c2)
+
+ self.assertEqual(c.Foo.trait._update, {"a": 1, "z": 26, "b": 1})
diff --git a/contrib/python/traitlets/py3/tests/test_traitlets.py b/contrib/python/traitlets/py3/tests/test_traitlets.py
new file mode 100644
index 00000000000..62fa726f19b
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/test_traitlets.py
@@ -0,0 +1,3141 @@
+"""Tests for traitlets.traitlets."""
+
+# Copyright (c) IPython Development Team.
+# Distributed under the terms of the Modified BSD License.
+#
+# Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
+# also under the terms of the Modified BSD License.
+
+import pickle
+import re
+import typing as t
+from unittest import TestCase
+
+import pytest
+
+from traitlets import (
+ All,
+ Any,
+ BaseDescriptor,
+ Bool,
+ Bytes,
+ Callable,
+ CBytes,
+ CFloat,
+ CInt,
+ CLong,
+ Complex,
+ CRegExp,
+ CUnicode,
+ Dict,
+ DottedObjectName,
+ Enum,
+ Float,
+ ForwardDeclaredInstance,
+ ForwardDeclaredType,
+ HasDescriptors,
+ HasTraits,
+ Instance,
+ Int,
+ Integer,
+ List,
+ Long,
+ MetaHasTraits,
+ ObjectName,
+ Set,
+ TCPAddress,
+ This,
+ TraitError,
+ TraitType,
+ Tuple,
+ Type,
+ Undefined,
+ Unicode,
+ Union,
+ default,
+ directional_link,
+ link,
+ observe,
+ observe_compat,
+ traitlets,
+ validate,
+)
+from traitlets.utils import cast_unicode
+
+from ._warnings import expected_warnings
+
+
+def change_dict(*ordered_values):
+ change_names = ("name", "old", "new", "owner", "type")
+ return dict(zip(change_names, ordered_values))
+
+
+# -----------------------------------------------------------------------------
+# Helper classes for testing
+# -----------------------------------------------------------------------------
+
+
+class HasTraitsStub(HasTraits):
+ def notify_change(self, change):
+ self._notify_name = change["name"]
+ self._notify_old = change["old"]
+ self._notify_new = change["new"]
+ self._notify_type = change["type"]
+
+
+class CrossValidationStub(HasTraits):
+ _cross_validation_lock = False
+
+
+# -----------------------------------------------------------------------------
+# Test classes
+# -----------------------------------------------------------------------------
+
+
+class TestTraitType(TestCase):
+ def test_get_undefined(self):
+ class A(HasTraits):
+ a = TraitType
+
+ a = A()
+ assert a.a is Undefined # type:ignore
+
+ def test_set(self):
+ class A(HasTraitsStub):
+ a = TraitType
+
+ a = A()
+ a.a = 10 # type:ignore
+ self.assertEqual(a.a, 10)
+ self.assertEqual(a._notify_name, "a")
+ self.assertEqual(a._notify_old, Undefined)
+ self.assertEqual(a._notify_new, 10)
+
+ def test_validate(self):
+ class MyTT(TraitType[int, int]):
+ def validate(self, inst, value):
+ return -1
+
+ class A(HasTraitsStub):
+ tt = MyTT
+
+ a = A()
+ a.tt = 10 # type:ignore
+ self.assertEqual(a.tt, -1)
+
+ a = A(tt=11)
+ self.assertEqual(a.tt, -1)
+
+ def test_default_validate(self):
+ class MyIntTT(TraitType[int, int]):
+ def validate(self, obj, value):
+ if isinstance(value, int):
+ return value
+ self.error(obj, value)
+
+ class A(HasTraits):
+ tt = MyIntTT(10)
+
+ a = A()
+ self.assertEqual(a.tt, 10)
+
+ # Defaults are validated when the HasTraits is instantiated
+ class B(HasTraits):
+ tt = MyIntTT("bad default")
+
+ self.assertRaises(TraitError, getattr, B(), "tt")
+
+ def test_info(self):
+ class A(HasTraits):
+ tt = TraitType
+
+ a = A()
+ self.assertEqual(A.tt.info(), "any value") # type:ignore
+
+ def test_error(self):
+ class A(HasTraits):
+ tt = TraitType[int, int]()
+
+ a = A()
+ self.assertRaises(TraitError, A.tt.error, a, 10)
+
+ def test_deprecated_dynamic_initializer(self):
+ class A(HasTraits):
+ x = Int(10)
+
+ def _x_default(self):
+ return 11
+
+ class B(A):
+ x = Int(20)
+
+ class C(A):
+ def _x_default(self):
+ return 21
+
+ a = A()
+ self.assertEqual(a._trait_values, {})
+ self.assertEqual(a.x, 11)
+ self.assertEqual(a._trait_values, {"x": 11})
+ b = B()
+ self.assertEqual(b.x, 20)
+ self.assertEqual(b._trait_values, {"x": 20})
+ c = C()
+ self.assertEqual(c._trait_values, {})
+ self.assertEqual(c.x, 21)
+ self.assertEqual(c._trait_values, {"x": 21})
+ # Ensure that the base class remains unmolested when the _default
+ # initializer gets overridden in a subclass.
+ a = A()
+ c = C()
+ self.assertEqual(a._trait_values, {})
+ self.assertEqual(a.x, 11)
+ self.assertEqual(a._trait_values, {"x": 11})
+
+ def test_deprecated_method_warnings(self):
+ with expected_warnings([]):
+
+ class ShouldntWarn(HasTraits):
+ x = Integer()
+
+ @default("x")
+ def _x_default(self):
+ return 10
+
+ @validate("x")
+ def _x_validate(self, proposal):
+ return proposal.value
+
+ @observe("x")
+ def _x_changed(self, change):
+ pass
+
+ obj = ShouldntWarn()
+ obj.x = 5
+
+ assert obj.x == 5
+
+ with expected_warnings(["@validate", "@observe"]) as w:
+
+ class ShouldWarn(HasTraits):
+ x = Integer()
+
+ def _x_default(self):
+ return 10
+
+ def _x_validate(self, value, _):
+ return value
+
+ def _x_changed(self):
+ pass
+
+ obj = ShouldWarn() # type:ignore
+ obj.x = 5
+
+ assert obj.x == 5
+
+ def test_dynamic_initializer(self):
+ class A(HasTraits):
+ x = Int(10)
+
+ @default("x")
+ def _default_x(self):
+ return 11
+
+ class B(A):
+ x = Int(20)
+
+ class C(A):
+ @default("x")
+ def _default_x(self):
+ return 21
+
+ a = A()
+ self.assertEqual(a._trait_values, {})
+ self.assertEqual(a.x, 11)
+ self.assertEqual(a._trait_values, {"x": 11})
+ b = B()
+ self.assertEqual(b.x, 20)
+ self.assertEqual(b._trait_values, {"x": 20})
+ c = C()
+ self.assertEqual(c._trait_values, {})
+ self.assertEqual(c.x, 21)
+ self.assertEqual(c._trait_values, {"x": 21})
+ # Ensure that the base class remains unmolested when the _default
+ # initializer gets overridden in a subclass.
+ a = A()
+ c = C()
+ self.assertEqual(a._trait_values, {})
+ self.assertEqual(a.x, 11)
+ self.assertEqual(a._trait_values, {"x": 11})
+
+ def test_tag_metadata(self):
+ class MyIntTT(TraitType[int, int]):
+ metadata = {"a": 1, "b": 2}
+
+ a = MyIntTT(10).tag(b=3, c=4)
+ self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4})
+
+ def test_metadata_localized_instance(self):
+ class MyIntTT(TraitType[int, int]):
+ metadata = {"a": 1, "b": 2}
+
+ a = MyIntTT(10)
+ b = MyIntTT(10)
+ a.metadata["c"] = 3
+ # make sure that changing a's metadata didn't change b's metadata
+ self.assertNotIn("c", b.metadata)
+
+ def test_union_metadata(self):
+ class Foo(HasTraits):
+ bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti="b")).tag(ti="a")
+
+ foo = Foo()
+ # At this point, no value has been set for bar, so value-specific
+ # is not set.
+ self.assertEqual(foo.trait_metadata("bar", "ta"), None)
+ self.assertEqual(foo.trait_metadata("bar", "ti"), "a")
+ foo.bar = {}
+ self.assertEqual(foo.trait_metadata("bar", "ta"), 2)
+ self.assertEqual(foo.trait_metadata("bar", "ti"), "b")
+ foo.bar = 1
+ self.assertEqual(foo.trait_metadata("bar", "ta"), 1)
+ self.assertEqual(foo.trait_metadata("bar", "ti"), "a")
+
+ def test_union_default_value(self):
+ class Foo(HasTraits):
+ bar = Union([Dict(), Int()], default_value=1)
+
+ foo = Foo()
+ self.assertEqual(foo.bar, 1)
+
+ def test_union_validation_priority(self):
+ class Foo(HasTraits):
+ bar = Union([CInt(), Unicode()])
+
+ foo = Foo()
+ foo.bar = "1"
+ # validation in order of the TraitTypes given
+ self.assertEqual(foo.bar, 1)
+
+ def test_union_trait_default_value(self):
+ class Foo(HasTraits):
+ bar = Union([Dict(), Int()])
+
+ self.assertEqual(Foo().bar, {})
+
+ def test_deprecated_metadata_access(self):
+ class MyIntTT(TraitType[int, int]):
+ metadata = {"a": 1, "b": 2}
+
+ a = MyIntTT(10)
+ with expected_warnings(["use the instance .metadata dictionary directly"] * 2):
+ a.set_metadata("key", "value")
+ v = a.get_metadata("key")
+ self.assertEqual(v, "value")
+ with expected_warnings(["use the instance .help string directly"] * 2):
+ a.set_metadata("help", "some help")
+ v = a.get_metadata("help")
+ self.assertEqual(v, "some help")
+
+ def test_trait_types_deprecated(self):
+ with expected_warnings(["Traits should be given as instances"]):
+
+ class C(HasTraits):
+ t = Int
+
+ def test_trait_types_list_deprecated(self):
+ with expected_warnings(["Traits should be given as instances"]):
+
+ class C(HasTraits):
+ t = List(Int)
+
+ def test_trait_types_tuple_deprecated(self):
+ with expected_warnings(["Traits should be given as instances"]):
+
+ class C(HasTraits):
+ t = Tuple(Int)
+
+ def test_trait_types_dict_deprecated(self):
+ with expected_warnings(["Traits should be given as instances"]):
+
+ class C(HasTraits):
+ t = Dict(Int)
+
+
+class TestHasDescriptorsMeta(TestCase):
+ def test_metaclass(self):
+ self.assertEqual(type(HasTraits), MetaHasTraits)
+
+ class A(HasTraits):
+ a = Int()
+
+ a = A()
+ self.assertEqual(type(a.__class__), MetaHasTraits)
+ self.assertEqual(a.a, 0)
+ a.a = 10
+ self.assertEqual(a.a, 10)
+
+ class B(HasTraits):
+ b = Int()
+
+ b = B()
+ self.assertEqual(b.b, 0)
+ b.b = 10
+ self.assertEqual(b.b, 10)
+
+ class C(HasTraits):
+ c = Int(30)
+
+ c = C()
+ self.assertEqual(c.c, 30)
+ c.c = 10
+ self.assertEqual(c.c, 10)
+
+ def test_this_class(self):
+ class A(HasTraits):
+ t = This["A"]()
+ tt = This["A"]()
+
+ class B(A):
+ tt = This["A"]()
+ ttt = This["A"]()
+
+ self.assertEqual(A.t.this_class, A)
+ self.assertEqual(B.t.this_class, A)
+ self.assertEqual(B.tt.this_class, B)
+ self.assertEqual(B.ttt.this_class, B)
+
+
+class TestHasDescriptors(TestCase):
+ def test_setup_instance(self):
+ class FooDescriptor(BaseDescriptor):
+ def instance_init(self, inst):
+ foo = inst.foo # instance should have the attr
+
+ class HasFooDescriptors(HasDescriptors):
+ fd = FooDescriptor()
+
+ def setup_instance(self, *args, **kwargs):
+ self.foo = kwargs.get("foo", None)
+ super().setup_instance(*args, **kwargs)
+
+ hfd = HasFooDescriptors(foo="bar")
+
+
+class TestHasTraitsNotify(TestCase):
+ def setUp(self):
+ self._notify1 = []
+ self._notify2 = []
+
+ def notify1(self, name, old, new):
+ self._notify1.append((name, old, new))
+
+ def notify2(self, name, old, new):
+ self._notify2.append((name, old, new))
+
+ def test_notify_all(self):
+ class A(HasTraits):
+ a = Int()
+ b = Float()
+
+ a = A()
+ a.on_trait_change(self.notify1)
+ a.a = 0
+ self.assertEqual(len(self._notify1), 0)
+ a.b = 0.0
+ self.assertEqual(len(self._notify1), 0)
+ a.a = 10
+ self.assertTrue(("a", 0, 10) in self._notify1)
+ a.b = 10.0
+ self.assertTrue(("b", 0.0, 10.0) in self._notify1)
+ self.assertRaises(TraitError, setattr, a, "a", "bad string")
+ self.assertRaises(TraitError, setattr, a, "b", "bad string")
+ self._notify1 = []
+ a.on_trait_change(self.notify1, remove=True)
+ a.a = 20
+ a.b = 20.0
+ self.assertEqual(len(self._notify1), 0)
+
+ def test_notify_one(self):
+ class A(HasTraits):
+ a = Int()
+ b = Float()
+
+ a = A()
+ a.on_trait_change(self.notify1, "a")
+ a.a = 0
+ self.assertEqual(len(self._notify1), 0)
+ a.a = 10
+ self.assertTrue(("a", 0, 10) in self._notify1)
+ self.assertRaises(TraitError, setattr, a, "a", "bad string")
+
+ def test_subclass(self):
+ class A(HasTraits):
+ a = Int()
+
+ class B(A):
+ b = Float()
+
+ b = B()
+ self.assertEqual(b.a, 0)
+ self.assertEqual(b.b, 0.0)
+ b.a = 100
+ b.b = 100.0
+ self.assertEqual(b.a, 100)
+ self.assertEqual(b.b, 100.0)
+
+ def test_notify_subclass(self):
+ class A(HasTraits):
+ a = Int()
+
+ class B(A):
+ b = Float()
+
+ b = B()
+ b.on_trait_change(self.notify1, "a")
+ b.on_trait_change(self.notify2, "b")
+ b.a = 0
+ b.b = 0.0
+ self.assertEqual(len(self._notify1), 0)
+ self.assertEqual(len(self._notify2), 0)
+ b.a = 10
+ b.b = 10.0
+ self.assertTrue(("a", 0, 10) in self._notify1)
+ self.assertTrue(("b", 0.0, 10.0) in self._notify2)
+
+ def test_static_notify(self):
+ class A(HasTraits):
+ a = Int()
+ _notify1 = []
+
+ def _a_changed(self, name, old, new):
+ self._notify1.append((name, old, new))
+
+ a = A()
+ a.a = 0
+ # This is broken!!!
+ self.assertEqual(len(a._notify1), 0)
+ a.a = 10
+ self.assertTrue(("a", 0, 10) in a._notify1)
+
+ class B(A):
+ b = Float()
+ _notify2 = []
+
+ def _b_changed(self, name, old, new):
+ self._notify2.append((name, old, new))
+
+ b = B()
+ b.a = 10
+ b.b = 10.0
+ self.assertTrue(("a", 0, 10) in b._notify1)
+ self.assertTrue(("b", 0.0, 10.0) in b._notify2)
+
+ def test_notify_args(self):
+ def callback0():
+ self.cb = ()
+
+ def callback1(name):
+ self.cb = (name,) # type:ignore
+
+ def callback2(name, new):
+ self.cb = (name, new) # type:ignore
+
+ def callback3(name, old, new):
+ self.cb = (name, old, new) # type:ignore
+
+ def callback4(name, old, new, obj):
+ self.cb = (name, old, new, obj) # type:ignore
+
+ class A(HasTraits):
+ a = Int()
+
+ a = A()
+ a.on_trait_change(callback0, "a")
+ a.a = 10
+ self.assertEqual(self.cb, ())
+ a.on_trait_change(callback0, "a", remove=True)
+
+ a.on_trait_change(callback1, "a")
+ a.a = 100
+ self.assertEqual(self.cb, ("a",))
+ a.on_trait_change(callback1, "a", remove=True)
+
+ a.on_trait_change(callback2, "a")
+ a.a = 1000
+ self.assertEqual(self.cb, ("a", 1000))
+ a.on_trait_change(callback2, "a", remove=True)
+
+ a.on_trait_change(callback3, "a")
+ a.a = 10000
+ self.assertEqual(self.cb, ("a", 1000, 10000))
+ a.on_trait_change(callback3, "a", remove=True)
+
+ a.on_trait_change(callback4, "a")
+ a.a = 100000
+ self.assertEqual(self.cb, ("a", 10000, 100000, a))
+ self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1)
+ a.on_trait_change(callback4, "a", remove=True)
+
+ self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0)
+
+ def test_notify_only_once(self):
+ class A(HasTraits):
+ listen_to = ["a"]
+
+ a = Int(0)
+ b = 0
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.on_trait_change(self.listener1, ["a"])
+
+ def listener1(self, name, old, new):
+ self.b += 1
+
+ class B(A):
+ c = 0
+ d = 0
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.on_trait_change(self.listener2)
+
+ def listener2(self, name, old, new):
+ self.c += 1
+
+ def _a_changed(self, name, old, new):
+ self.d += 1
+
+ b = B()
+ b.a += 1
+ self.assertEqual(b.b, b.c)
+ self.assertEqual(b.b, b.d)
+ b.a += 1
+ self.assertEqual(b.b, b.c)
+ self.assertEqual(b.b, b.d)
+
+
+class TestObserveDecorator(TestCase):
+ def setUp(self):
+ self._notify1 = []
+ self._notify2 = []
+
+ def notify1(self, change):
+ self._notify1.append(change)
+
+ def notify2(self, change):
+ self._notify2.append(change)
+
+ def test_notify_all(self):
+ class A(HasTraits):
+ a = Int()
+ b = Float()
+
+ a = A()
+ a.observe(self.notify1)
+ a.a = 0
+ self.assertEqual(len(self._notify1), 0)
+ a.b = 0.0
+ self.assertEqual(len(self._notify1), 0)
+ a.a = 10
+ change = change_dict("a", 0, 10, a, "change")
+ self.assertTrue(change in self._notify1)
+ a.b = 10.0
+ change = change_dict("b", 0.0, 10.0, a, "change")
+ self.assertTrue(change in self._notify1)
+ self.assertRaises(TraitError, setattr, a, "a", "bad string")
+ self.assertRaises(TraitError, setattr, a, "b", "bad string")
+ self._notify1 = []
+ a.unobserve(self.notify1)
+ a.a = 20
+ a.b = 20.0
+ self.assertEqual(len(self._notify1), 0)
+
+ def test_notify_one(self):
+ class A(HasTraits):
+ a = Int()
+ b = Float()
+
+ a = A()
+ a.observe(self.notify1, "a")
+ a.a = 0
+ self.assertEqual(len(self._notify1), 0)
+ a.a = 10
+ change = change_dict("a", 0, 10, a, "change")
+ self.assertTrue(change in self._notify1)
+ self.assertRaises(TraitError, setattr, a, "a", "bad string")
+
+ def test_subclass(self):
+ class A(HasTraits):
+ a = Int()
+
+ class B(A):
+ b = Float()
+
+ b = B()
+ self.assertEqual(b.a, 0)
+ self.assertEqual(b.b, 0.0)
+ b.a = 100
+ b.b = 100.0
+ self.assertEqual(b.a, 100)
+ self.assertEqual(b.b, 100.0)
+
+ def test_notify_subclass(self):
+ class A(HasTraits):
+ a = Int()
+
+ class B(A):
+ b = Float()
+
+ b = B()
+ b.observe(self.notify1, "a")
+ b.observe(self.notify2, "b")
+ b.a = 0
+ b.b = 0.0
+ self.assertEqual(len(self._notify1), 0)
+ self.assertEqual(len(self._notify2), 0)
+ b.a = 10
+ b.b = 10.0
+ change = change_dict("a", 0, 10, b, "change")
+ self.assertTrue(change in self._notify1)
+ change = change_dict("b", 0.0, 10.0, b, "change")
+ self.assertTrue(change in self._notify2)
+
+ def test_static_notify(self):
+ class A(HasTraits):
+ a = Int()
+ b = Int()
+ _notify1 = []
+ _notify_any = []
+
+ @observe("a")
+ def _a_changed(self, change):
+ self._notify1.append(change)
+
+ @observe(All)
+ def _any_changed(self, change):
+ self._notify_any.append(change)
+
+ a = A()
+ a.a = 0
+ self.assertEqual(len(a._notify1), 0)
+ a.a = 10
+ change = change_dict("a", 0, 10, a, "change")
+ self.assertTrue(change in a._notify1)
+ a.b = 1
+ self.assertEqual(len(a._notify_any), 2)
+ change = change_dict("b", 0, 1, a, "change")
+ self.assertTrue(change in a._notify_any)
+
+ class B(A):
+ b = Float() # type:ignore
+ _notify2 = []
+
+ @observe("b")
+ def _b_changed(self, change):
+ self._notify2.append(change)
+
+ b = B()
+ b.a = 10
+ b.b = 10.0 # type:ignore
+ change = change_dict("a", 0, 10, b, "change")
+ self.assertTrue(change in b._notify1)
+ change = change_dict("b", 0.0, 10.0, b, "change")
+ self.assertTrue(change in b._notify2)
+
+ def test_notify_args(self):
+ def callback0():
+ self.cb = ()
+
+ def callback1(change):
+ self.cb = change
+
+ class A(HasTraits):
+ a = Int()
+
+ a = A()
+ a.on_trait_change(callback0, "a")
+ a.a = 10
+ self.assertEqual(self.cb, ())
+ a.unobserve(callback0, "a")
+
+ a.observe(callback1, "a")
+ a.a = 100
+ change = change_dict("a", 10, 100, a, "change")
+ self.assertEqual(self.cb, change)
+ self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1)
+ a.unobserve(callback1, "a")
+
+ self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0)
+
+ def test_notify_only_once(self):
+ class A(HasTraits):
+ listen_to = ["a"]
+
+ a = Int(0)
+ b = 0
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.observe(self.listener1, ["a"])
+
+ def listener1(self, change):
+ self.b += 1
+
+ class B(A):
+ c = 0
+ d = 0
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.observe(self.listener2)
+
+ def listener2(self, change):
+ self.c += 1
+
+ @observe("a")
+ def _a_changed(self, change):
+ self.d += 1
+
+ b = B()
+ b.a += 1
+ self.assertEqual(b.b, b.c)
+ self.assertEqual(b.b, b.d)
+ b.a += 1
+ self.assertEqual(b.b, b.c)
+ self.assertEqual(b.b, b.d)
+
+
+class TestHasTraits(TestCase):
+ def test_trait_names(self):
+ class A(HasTraits):
+ i = Int()
+ f = Float()
+
+ a = A()
+ self.assertEqual(sorted(a.trait_names()), ["f", "i"])
+ self.assertEqual(sorted(A.class_trait_names()), ["f", "i"])
+ self.assertTrue(a.has_trait("f"))
+ self.assertFalse(a.has_trait("g"))
+
+ def test_trait_has_value(self):
+ class A(HasTraits):
+ i = Int()
+ f = Float()
+
+ a = A()
+ self.assertFalse(a.trait_has_value("f"))
+ self.assertFalse(a.trait_has_value("g"))
+ a.i = 1
+ a.f
+ self.assertTrue(a.trait_has_value("i"))
+ self.assertTrue(a.trait_has_value("f"))
+
+ def test_trait_metadata_deprecated(self):
+ with expected_warnings([r"metadata should be set using the \.tag\(\) method"]):
+
+ class A(HasTraits):
+ i = Int(config_key="MY_VALUE")
+
+ a = A()
+ self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE")
+
+ def test_trait_metadata(self):
+ class A(HasTraits):
+ i = Int().tag(config_key="MY_VALUE")
+
+ a = A()
+ self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE")
+
+ def test_trait_metadata_default(self):
+ class A(HasTraits):
+ i = Int()
+
+ a = A()
+ self.assertEqual(a.trait_metadata("i", "config_key"), None)
+ self.assertEqual(a.trait_metadata("i", "config_key", "default"), "default")
+
+ def test_traits(self):
+ class A(HasTraits):
+ i = Int()
+ f = Float()
+
+ a = A()
+ self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
+ self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
+
+ def test_traits_metadata(self):
+ class A(HasTraits):
+ i = Int().tag(config_key="VALUE1", other_thing="VALUE2")
+ f = Float().tag(config_key="VALUE3", other_thing="VALUE2")
+ j = Int(0)
+
+ a = A()
+ self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
+ traits = a.traits(config_key="VALUE1", other_thing="VALUE2")
+ self.assertEqual(traits, dict(i=A.i))
+
+ # This passes, but it shouldn't because I am replicating a bug in
+ # traits.
+ traits = a.traits(config_key=lambda v: True)
+ self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
+
+ def test_traits_metadata_deprecated(self):
+ with expected_warnings([r"metadata should be set using the \.tag\(\) method"] * 2):
+
+ class A(HasTraits):
+ i = Int(config_key="VALUE1", other_thing="VALUE2")
+ f = Float(config_key="VALUE3", other_thing="VALUE2")
+ j = Int(0)
+
+ a = A()
+ self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
+ traits = a.traits(config_key="VALUE1", other_thing="VALUE2")
+ self.assertEqual(traits, dict(i=A.i))
+
+ # This passes, but it shouldn't because I am replicating a bug in
+ # traits.
+ traits = a.traits(config_key=lambda v: True)
+ self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
+
+ def test_init(self):
+ class A(HasTraits):
+ i = Int()
+ x = Float()
+
+ a = A(i=1, x=10.0)
+ self.assertEqual(a.i, 1)
+ self.assertEqual(a.x, 10.0)
+
+ def test_positional_args(self):
+ class A(HasTraits):
+ i = Int(0)
+
+ def __init__(self, i):
+ super().__init__()
+ self.i = i
+
+ a = A(5)
+ self.assertEqual(a.i, 5)
+ # should raise TypeError if no positional arg given
+ self.assertRaises(TypeError, A)
+
+
+# -----------------------------------------------------------------------------
+# Tests for specific trait types
+# -----------------------------------------------------------------------------
+
+
+class TestType(TestCase):
+ def test_default(self):
+ class B:
+ pass
+
+ class A(HasTraits):
+ klass = Type(allow_none=True)
+
+ a = A()
+ self.assertEqual(a.klass, object)
+
+ a.klass = B
+ self.assertEqual(a.klass, B)
+ self.assertRaises(TraitError, setattr, a, "klass", 10)
+
+ def test_default_options(self):
+ class B:
+ pass
+
+ class C(B):
+ pass
+
+ class A(HasTraits):
+ # Different possible combinations of options for default_value
+ # and klass. default_value=None is only valid with allow_none=True.
+ k1 = Type()
+ k2 = Type(None, allow_none=True)
+ k3 = Type(B)
+ k4 = Type(klass=B)
+ k5 = Type(default_value=None, klass=B, allow_none=True)
+ k6 = Type(default_value=C, klass=B)
+
+ self.assertIs(A.k1.default_value, object)
+ self.assertIs(A.k1.klass, object)
+ self.assertIs(A.k2.default_value, None)
+ self.assertIs(A.k2.klass, object)
+ self.assertIs(A.k3.default_value, B)
+ self.assertIs(A.k3.klass, B)
+ self.assertIs(A.k4.default_value, B)
+ self.assertIs(A.k4.klass, B)
+ self.assertIs(A.k5.default_value, None)
+ self.assertIs(A.k5.klass, B)
+ self.assertIs(A.k6.default_value, C)
+ self.assertIs(A.k6.klass, B)
+
+ a = A()
+ self.assertIs(a.k1, object)
+ self.assertIs(a.k2, None)
+ self.assertIs(a.k3, B)
+ self.assertIs(a.k4, B)
+ self.assertIs(a.k5, None)
+ self.assertIs(a.k6, C)
+
+ def test_value(self):
+ class B:
+ pass
+
+ class C:
+ pass
+
+ class A(HasTraits):
+ klass = Type(B)
+
+ a = A()
+ self.assertEqual(a.klass, B)
+ self.assertRaises(TraitError, setattr, a, "klass", C)
+ self.assertRaises(TraitError, setattr, a, "klass", object)
+ a.klass = B
+
+ def test_allow_none(self):
+ class B:
+ pass
+
+ class C(B):
+ pass
+
+ class A(HasTraits):
+ klass = Type(B)
+
+ a = A()
+ self.assertEqual(a.klass, B)
+ self.assertRaises(TraitError, setattr, a, "klass", None)
+ a.klass = C
+ self.assertEqual(a.klass, C)
+
+ def test_validate_klass(self):
+ class A(HasTraits):
+ klass = Type("no strings allowed")
+
+ self.assertRaises(ImportError, A)
+
+ class A(HasTraits): # type:ignore
+ klass = Type("rub.adub.Duck")
+
+ self.assertRaises(ImportError, A)
+
+ def test_validate_default(self):
+ class B:
+ pass
+
+ class A(HasTraits):
+ klass = Type("bad default", B)
+
+ self.assertRaises(ImportError, A)
+
+ class C(HasTraits):
+ klass = Type(None, B)
+
+ self.assertRaises(TraitError, getattr, C(), "klass")
+
+ def test_str_klass(self):
+ class A(HasTraits):
+ klass = Type("traitlets.config.Config")
+
+ from traitlets.config import Config
+
+ a = A()
+ a.klass = Config
+ self.assertEqual(a.klass, Config)
+
+ self.assertRaises(TraitError, setattr, a, "klass", 10)
+
+ def test_set_str_klass(self):
+ class A(HasTraits):
+ klass = Type()
+
+ a = A(klass="traitlets.config.Config")
+ from traitlets.config import Config
+
+ self.assertEqual(a.klass, Config)
+
+
+class TestInstance(TestCase):
+ def test_basic(self):
+ class Foo:
+ pass
+
+ class Bar(Foo):
+ pass
+
+ class Bah:
+ pass
+
+ class A(HasTraits):
+ inst = Instance(Foo, allow_none=True)
+
+ a = A()
+ self.assertTrue(a.inst is None)
+ a.inst = Foo()
+ self.assertTrue(isinstance(a.inst, Foo))
+ a.inst = Bar()
+ self.assertTrue(isinstance(a.inst, Foo))
+ self.assertRaises(TraitError, setattr, a, "inst", Foo)
+ self.assertRaises(TraitError, setattr, a, "inst", Bar)
+ self.assertRaises(TraitError, setattr, a, "inst", Bah())
+
+ def test_default_klass(self):
+ class Foo:
+ pass
+
+ class Bar(Foo):
+ pass
+
+ class Bah:
+ pass
+
+ class FooInstance(Instance[Foo]):
+ klass = Foo
+
+ class A(HasTraits):
+ inst = FooInstance(allow_none=True)
+
+ a = A()
+ self.assertTrue(a.inst is None)
+ a.inst = Foo()
+ self.assertTrue(isinstance(a.inst, Foo))
+ a.inst = Bar()
+ self.assertTrue(isinstance(a.inst, Foo))
+ self.assertRaises(TraitError, setattr, a, "inst", Foo)
+ self.assertRaises(TraitError, setattr, a, "inst", Bar)
+ self.assertRaises(TraitError, setattr, a, "inst", Bah())
+
+ def test_unique_default_value(self):
+ class Foo:
+ pass
+
+ class A(HasTraits):
+ inst = Instance(Foo, (), {})
+
+ a = A()
+ b = A()
+ self.assertTrue(a.inst is not b.inst)
+
+ def test_args_kw(self):
+ class Foo:
+ def __init__(self, c):
+ self.c = c
+
+ class Bar:
+ pass
+
+ class Bah:
+ def __init__(self, c, d):
+ self.c = c
+ self.d = d
+
+ class A(HasTraits):
+ inst = Instance(Foo, (10,))
+
+ a = A()
+ self.assertEqual(a.inst.c, 10)
+
+ class B(HasTraits):
+ inst = Instance(Bah, args=(10,), kw=dict(d=20))
+
+ b = B()
+ self.assertEqual(b.inst.c, 10)
+ self.assertEqual(b.inst.d, 20)
+
+ class C(HasTraits):
+ inst = Instance(Foo, allow_none=True)
+
+ c = C()
+ self.assertTrue(c.inst is None)
+
+ def test_bad_default(self):
+ class Foo:
+ pass
+
+ class A(HasTraits):
+ inst = Instance(Foo)
+
+ a = A()
+ with self.assertRaises(TraitError):
+ a.inst
+
+ def test_instance(self):
+ class Foo:
+ pass
+
+ def inner():
+ class A(HasTraits):
+ inst = Instance(Foo()) # type:ignore
+
+ self.assertRaises(TraitError, inner)
+
+
+class TestThis(TestCase):
+ def test_this_class(self):
+ class Foo(HasTraits):
+ this = This["Foo"]()
+
+ f = Foo()
+ self.assertEqual(f.this, None)
+ g = Foo()
+ f.this = g
+ self.assertEqual(f.this, g)
+ self.assertRaises(TraitError, setattr, f, "this", 10)
+
+ def test_this_inst(self):
+ class Foo(HasTraits):
+ this = This["Foo"]()
+
+ f = Foo()
+ f.this = Foo()
+ self.assertTrue(isinstance(f.this, Foo))
+
+ def test_subclass(self):
+ class Foo(HasTraits):
+ t = This["Foo"]()
+
+ class Bar(Foo):
+ pass
+
+ f = Foo()
+ b = Bar()
+ f.t = b
+ b.t = f
+ self.assertEqual(f.t, b)
+ self.assertEqual(b.t, f)
+
+ def test_subclass_override(self):
+ class Foo(HasTraits):
+ t = This["Foo"]()
+
+ class Bar(Foo):
+ t = This()
+
+ f = Foo()
+ b = Bar()
+ f.t = b
+ self.assertEqual(f.t, b)
+ self.assertRaises(TraitError, setattr, b, "t", f)
+
+ def test_this_in_container(self):
+ class Tree(HasTraits):
+ value = Unicode()
+ leaves = List(This())
+
+ tree = Tree(value="foo", leaves=[Tree(value="bar"), Tree(value="buzz")])
+
+ with self.assertRaises(TraitError):
+ tree.leaves = [1, 2]
+
+
+class TraitTestBase(TestCase):
+ """A best testing class for basic trait types."""
+
+ def assign(self, value):
+ self.obj.value = value # type:ignore
+
+ def coerce(self, value):
+ return value
+
+ def test_good_values(self):
+ if hasattr(self, "_good_values"):
+ for value in self._good_values:
+ self.assign(value)
+ self.assertEqual(self.obj.value, self.coerce(value)) # type:ignore
+
+ def test_bad_values(self):
+ if hasattr(self, "_bad_values"):
+ for value in self._bad_values:
+ try:
+ self.assertRaises(TraitError, self.assign, value)
+ except AssertionError:
+ assert False, value
+
+ def test_default_value(self):
+ if hasattr(self, "_default_value"):
+ self.assertEqual(self._default_value, self.obj.value) # type:ignore
+
+ def test_allow_none(self):
+ if (
+ hasattr(self, "_bad_values")
+ and hasattr(self, "_good_values")
+ and None in self._bad_values
+ ):
+ trait = self.obj.traits()["value"] # type:ignore
+ try:
+ trait.allow_none = True
+ self._bad_values.remove(None)
+ # skip coerce. Allow None casts None to None.
+ self.assign(None)
+ self.assertEqual(self.obj.value, None) # type:ignore
+ self.test_good_values()
+ self.test_bad_values()
+ finally:
+ # tear down
+ trait.allow_none = False
+ self._bad_values.append(None)
+
+ def tearDown(self):
+ # restore default value after tests, if set
+ if hasattr(self, "_default_value"):
+ self.obj.value = self._default_value # type:ignore
+
+
+class AnyTrait(HasTraits):
+ value = Any()
+
+
+class AnyTraitTest(TraitTestBase):
+ obj = AnyTrait()
+
+ _default_value = None
+ _good_values = [10.0, "ten", [10], {"ten": 10}, (10,), None, 1j]
+ _bad_values: t.Any = []
+
+
+class UnionTrait(HasTraits):
+ value = Union([Type(), Bool()])
+
+
+class UnionTraitTest(TraitTestBase):
+ obj = UnionTrait(value="traitlets.config.Config")
+ _good_values = [int, float, True]
+ _bad_values = [[], (0,), 1j]
+
+
+class CallableTrait(HasTraits):
+ value = Callable()
+
+
+class CallableTraitTest(TraitTestBase):
+ obj = CallableTrait(value=lambda x: type(x))
+ _good_values = [int, sorted, lambda x: print(x)]
+ _bad_values = [[], 1, ""]
+
+
+class OrTrait(HasTraits):
+ value = Bool() | Unicode()
+
+
+class OrTraitTest(TraitTestBase):
+ obj = OrTrait()
+ _good_values = [True, False, "ten"]
+ _bad_values = [[], (0,), 1j]
+
+
+class IntTrait(HasTraits):
+ value = Int(99, min=-100)
+
+
+class TestInt(TraitTestBase):
+ obj = IntTrait()
+ _default_value = 99
+ _good_values = [10, -10]
+ _bad_values = [
+ "ten",
+ [10],
+ {"ten": 10},
+ (10,),
+ None,
+ 1j,
+ 10.1,
+ -10.1,
+ "10L",
+ "-10L",
+ "10.1",
+ "-10.1",
+ "10",
+ "-10",
+ -200,
+ ]
+
+
+class CIntTrait(HasTraits):
+ value = CInt("5")
+
+
+class TestCInt(TraitTestBase):
+ obj = CIntTrait()
+
+ _default_value = 5
+ _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1]
+ _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"]
+
+ def coerce(self, n):
+ return int(n)
+
+
+class MinBoundCIntTrait(HasTraits):
+ value = CInt("5", min=3)
+
+
+class TestMinBoundCInt(TestCInt):
+ obj = MinBoundCIntTrait() # type:ignore
+
+ _default_value = 5
+ _good_values = [3, 3.0, "3"]
+ _bad_values = [2.6, 2, -3, -3.0]
+
+
+class LongTrait(HasTraits):
+ value = Long(99)
+
+
+class TestLong(TraitTestBase):
+ obj = LongTrait()
+
+ _default_value = 99
+ _good_values = [10, -10]
+ _bad_values = [
+ "ten",
+ [10],
+ {"ten": 10},
+ (10,),
+ None,
+ 1j,
+ 10.1,
+ -10.1,
+ "10",
+ "-10",
+ "10L",
+ "-10L",
+ "10.1",
+ "-10.1",
+ ]
+
+
+class MinBoundLongTrait(HasTraits):
+ value = Long(99, min=5)
+
+
+class TestMinBoundLong(TraitTestBase):
+ obj = MinBoundLongTrait()
+
+ _default_value = 99
+ _good_values = [5, 10]
+ _bad_values = [4, -10]
+
+
+class MaxBoundLongTrait(HasTraits):
+ value = Long(5, max=10)
+
+
+class TestMaxBoundLong(TraitTestBase):
+ obj = MaxBoundLongTrait()
+
+ _default_value = 5
+ _good_values = [10, -2]
+ _bad_values = [11, 20]
+
+
+class CLongTrait(HasTraits):
+ value = CLong("5")
+
+
+class TestCLong(TraitTestBase):
+ obj = CLongTrait()
+
+ _default_value = 5
+ _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1]
+ _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"]
+
+ def coerce(self, n):
+ return int(n)
+
+
+class MaxBoundCLongTrait(HasTraits):
+ value = CLong("5", max=10)
+
+
+class TestMaxBoundCLong(TestCLong):
+ obj = MaxBoundCLongTrait() # type:ignore
+
+ _default_value = 5
+ _good_values = [10, "10", 10.3]
+ _bad_values = [11.0, "11"]
+
+
+class IntegerTrait(HasTraits):
+ value = Integer(1)
+
+
+class TestInteger(TestLong):
+ obj = IntegerTrait() # type:ignore
+ _default_value = 1
+
+ def coerce(self, n):
+ return int(n)
+
+
+class MinBoundIntegerTrait(HasTraits):
+ value = Integer(5, min=3)
+
+
+class TestMinBoundInteger(TraitTestBase):
+ obj = MinBoundIntegerTrait()
+
+ _default_value = 5
+ _good_values = 3, 20
+ _bad_values = [2, -10]
+
+
+class MaxBoundIntegerTrait(HasTraits):
+ value = Integer(1, max=3)
+
+
+class TestMaxBoundInteger(TraitTestBase):
+ obj = MaxBoundIntegerTrait()
+
+ _default_value = 1
+ _good_values = 3, -2
+ _bad_values = [4, 10]
+
+
+class FloatTrait(HasTraits):
+ value = Float(99.0, max=200.0)
+
+
+class TestFloat(TraitTestBase):
+ obj = FloatTrait()
+
+ _default_value = 99.0
+ _good_values = [10, -10, 10.1, -10.1]
+ _bad_values = [
+ "ten",
+ [10],
+ {"ten": 10},
+ (10,),
+ None,
+ 1j,
+ "10",
+ "-10",
+ "10L",
+ "-10L",
+ "10.1",
+ "-10.1",
+ 201.0,
+ ]
+
+
+class CFloatTrait(HasTraits):
+ value = CFloat("99.0", max=200.0)
+
+
+class TestCFloat(TraitTestBase):
+ obj = CFloatTrait()
+
+ _default_value = 99.0
+ _good_values = [10, 10.0, 10.5, "10.0", "10", "-10"]
+ _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, 200.1, "200.1"]
+
+ def coerce(self, v):
+ return float(v)
+
+
+class ComplexTrait(HasTraits):
+ value = Complex(99.0 - 99.0j)
+
+
+class TestComplex(TraitTestBase):
+ obj = ComplexTrait()
+
+ _default_value = 99.0 - 99.0j
+ _good_values = [
+ 10,
+ -10,
+ 10.1,
+ -10.1,
+ 10j,
+ 10 + 10j,
+ 10 - 10j,
+ 10.1j,
+ 10.1 + 10.1j,
+ 10.1 - 10.1j,
+ ]
+ _bad_values = ["10L", "-10L", "ten", [10], {"ten": 10}, (10,), None]
+
+
+class BytesTrait(HasTraits):
+ value = Bytes(b"string")
+
+
+class TestBytes(TraitTestBase):
+ obj = BytesTrait()
+
+ _default_value = b"string"
+ _good_values = [b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1", b"string"]
+ _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None, "string"]
+
+
+class UnicodeTrait(HasTraits):
+ value = Unicode("unicode")
+
+
+class TestUnicode(TraitTestBase):
+ obj = UnicodeTrait()
+
+ _default_value = "unicode"
+ _good_values = ["10", "-10", "10L", "-10L", "10.1", "-10.1", "", "string", "€", b"bytestring"]
+ _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None]
+
+ def coerce(self, v):
+ return cast_unicode(v)
+
+
+class ObjectNameTrait(HasTraits):
+ value = ObjectName("abc")
+
+
+class TestObjectName(TraitTestBase):
+ obj = ObjectNameTrait()
+
+ _default_value = "abc"
+ _good_values = ["a", "gh", "g9", "g_", "_G", "a345_"]
+ _bad_values = [
+ 1,
+ "",
+ "€",
+ "9g",
+ "!",
+ "#abc",
+ "aj@",
+ "a.b",
+ "a()",
+ "a[0]",
+ None,
+ object(),
+ object,
+ ]
+ _good_values.append("þ") # þ=1 is valid in Python 3 (PEP 3131).
+
+
+class DottedObjectNameTrait(HasTraits):
+ value = DottedObjectName("a.b")
+
+
+class TestDottedObjectName(TraitTestBase):
+ obj = DottedObjectNameTrait()
+
+ _default_value = "a.b"
+ _good_values = ["A", "y.t", "y765.__repr__", "os.path.join"]
+ _bad_values = [1, "abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
+
+ _good_values.append("t.þ")
+
+
+class TCPAddressTrait(HasTraits):
+ value = TCPAddress()
+
+
+class TestTCPAddress(TraitTestBase):
+ obj = TCPAddressTrait()
+
+ _default_value = ("127.0.0.1", 0)
+ _good_values = [("localhost", 0), ("192.168.0.1", 1000), ("www.google.com", 80)]
+ _bad_values = [(0, 0), ("localhost", 10.0), ("localhost", -1), None]
+
+
+class ListTrait(HasTraits):
+ value = List(Int())
+
+
+class TestList(TraitTestBase):
+ obj = ListTrait()
+
+ _default_value: t.List[t.Any] = []
+ _good_values = [[], [1], list(range(10)), (1, 2)]
+ _bad_values = [10, [1, "a"], "a"]
+
+ def coerce(self, value):
+ if value is not None:
+ value = list(value)
+ return value
+
+
+class Foo:
+ pass
+
+
+class NoneInstanceListTrait(HasTraits):
+ value = List(Instance(Foo))
+
+
+class TestNoneInstanceList(TraitTestBase):
+ obj = NoneInstanceListTrait()
+
+ _default_value: t.List[t.Any] = []
+ _good_values = [[Foo(), Foo()], []]
+ _bad_values = [[None], [Foo(), None]]
+
+
+class InstanceListTrait(HasTraits):
+ value = List(Instance(__name__ + ".Foo"))
+
+
+class TestInstanceList(TraitTestBase):
+ obj = InstanceListTrait()
+
+ def test_klass(self):
+ """Test that the instance klass is properly assigned."""
+ self.assertIs(self.obj.traits()["value"]._trait.klass, Foo)
+
+ _default_value: t.List[t.Any] = []
+ _good_values = [[Foo(), Foo()], []]
+ _bad_values = [
+ [
+ "1",
+ 2,
+ ],
+ "1",
+ [Foo],
+ None,
+ ]
+
+
+class UnionListTrait(HasTraits):
+ value = List(Int() | Bool())
+
+
+class TestUnionListTrait(TraitTestBase):
+ obj = UnionListTrait()
+
+ _default_value: t.List[t.Any] = []
+ _good_values = [[True, 1], [False, True]]
+ _bad_values = [[1, "True"], False]
+
+
+class LenListTrait(HasTraits):
+ value = List(Int(), [0], minlen=1, maxlen=2)
+
+
+class TestLenList(TraitTestBase):
+ obj = LenListTrait()
+
+ _default_value = [0]
+ _good_values = [[1], [1, 2], (1, 2)]
+ _bad_values = [10, [1, "a"], "a", [], list(range(3))]
+
+ def coerce(self, value):
+ if value is not None:
+ value = list(value)
+ return value
+
+
+class TupleTrait(HasTraits):
+ value = Tuple(Int(allow_none=True), default_value=(1,))
+
+
+class TestTupleTrait(TraitTestBase):
+ obj = TupleTrait()
+
+ _default_value = (1,)
+ _good_values = [(1,), (0,), [1]]
+ _bad_values = [10, (1, 2), ("a"), (), None]
+
+ def coerce(self, value):
+ if value is not None:
+ value = tuple(value)
+ return value
+
+ def test_invalid_args(self):
+ self.assertRaises(TypeError, Tuple, 5)
+ self.assertRaises(TypeError, Tuple, default_value="hello")
+ t = Tuple(Int(), CBytes(), default_value=(1, 5))
+
+
+class LooseTupleTrait(HasTraits):
+ value = Tuple((1, 2, 3))
+
+
+class TestLooseTupleTrait(TraitTestBase):
+ obj = LooseTupleTrait()
+
+ _default_value = (1, 2, 3)
+ _good_values = [(1,), [1], (0,), tuple(range(5)), tuple("hello"), ("a", 5), ()]
+ _bad_values = [10, "hello", {}, None]
+
+ def coerce(self, value):
+ if value is not None:
+ value = tuple(value)
+ return value
+
+ def test_invalid_args(self):
+ self.assertRaises(TypeError, Tuple, 5)
+ self.assertRaises(TypeError, Tuple, default_value="hello")
+ t = Tuple(Int(), CBytes(), default_value=(1, 5))
+
+
+class MultiTupleTrait(HasTraits):
+ value = Tuple(Int(), Bytes(), default_value=[99, b"bottles"])
+
+
+class TestMultiTuple(TraitTestBase):
+ obj = MultiTupleTrait()
+
+ _default_value = (99, b"bottles")
+ _good_values = [(1, b"a"), (2, b"b")]
+ _bad_values = ((), 10, b"a", (1, b"a", 3), (b"a", 1), (1, "a"))
+
+
+@pytest.mark.parametrize(
+ "Trait",
+ (
+ List,
+ Tuple,
+ Set,
+ Dict,
+ Integer,
+ Unicode,
+ ),
+)
+def test_allow_none_default_value(Trait):
+ class C(HasTraits):
+ t = Trait(default_value=None, allow_none=True)
+
+ # test default value
+ c = C()
+ assert c.t is None
+
+ # and in constructor
+ c = C(t=None)
+ assert c.t is None
+
+
+@pytest.mark.parametrize(
+ "Trait, default_value",
+ ((List, []), (Tuple, ()), (Set, set()), (Dict, {}), (Integer, 0), (Unicode, "")),
+)
+def test_default_value(Trait, default_value):
+ class C(HasTraits):
+ t = Trait()
+
+ # test default value
+ c = C()
+ assert type(c.t) is type(default_value)
+ assert c.t == default_value
+
+
+@pytest.mark.parametrize(
+ "Trait, default_value",
+ ((List, []), (Tuple, ()), (Set, set())),
+)
+def test_subclass_default_value(Trait, default_value):
+ """Test deprecated default_value=None behavior for Container subclass traits"""
+
+ class SubclassTrait(Trait): # type:ignore
+ def __init__(self, default_value=None):
+ super().__init__(default_value=default_value)
+
+ class C(HasTraits):
+ t = SubclassTrait()
+
+ # test default value
+ c = C()
+ assert type(c.t) is type(default_value)
+ assert c.t == default_value
+
+
+class CRegExpTrait(HasTraits):
+ value = CRegExp(r"")
+
+
+class TestCRegExp(TraitTestBase):
+ def coerce(self, value):
+ return re.compile(value)
+
+ obj = CRegExpTrait()
+
+ _default_value = re.compile(r"")
+ _good_values = [r"\d+", re.compile(r"\d+")]
+ _bad_values = ["(", None, ()]
+
+
+class DictTrait(HasTraits):
+ value = Dict()
+
+
+def test_dict_assignment():
+ d: t.Dict[str, int] = {}
+ c = DictTrait()
+ c.value = d
+ d["a"] = 5
+ assert d == c.value
+ assert c.value is d
+
+
+class UniformlyValueValidatedDictTrait(HasTraits):
+ value = Dict(value_trait=Unicode(), default_value={"foo": "1"})
+
+
+class TestInstanceUniformlyValueValidatedDict(TraitTestBase):
+ obj = UniformlyValueValidatedDictTrait()
+
+ _default_value = {"foo": "1"}
+ _good_values = [{"foo": "0", "bar": "1"}]
+ _bad_values = [{"foo": 0, "bar": "1"}]
+
+
+class NonuniformlyValueValidatedDictTrait(HasTraits):
+ value = Dict(per_key_traits={"foo": Int()}, default_value={"foo": 1})
+
+
+class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase):
+ obj = NonuniformlyValueValidatedDictTrait()
+
+ _default_value = {"foo": 1}
+ _good_values = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": 1}]
+ _bad_values = [{"foo": "0", "bar": "1"}]
+
+
+class KeyValidatedDictTrait(HasTraits):
+ value = Dict(key_trait=Unicode(), default_value={"foo": "1"})
+
+
+class TestInstanceKeyValidatedDict(TraitTestBase):
+ obj = KeyValidatedDictTrait()
+
+ _default_value = {"foo": "1"}
+ _good_values = [{"foo": "0", "bar": "1"}]
+ _bad_values = [{"foo": "0", 0: "1"}]
+
+
+class FullyValidatedDictTrait(HasTraits):
+ value = Dict(
+ value_trait=Unicode(),
+ key_trait=Unicode(),
+ per_key_traits={"foo": Int()},
+ default_value={"foo": 1},
+ )
+
+
+class TestInstanceFullyValidatedDict(TraitTestBase):
+ obj = FullyValidatedDictTrait()
+
+ _default_value = {"foo": 1}
+ _good_values = [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}]
+ _bad_values = [{"foo": 0, "bar": 1}, {"foo": "0", "bar": "1"}, {"foo": 0, 0: "1"}]
+
+
+def test_dict_default_value():
+ """Check that the `{}` default value of the Dict traitlet constructor is
+ actually copied."""
+
+ class Foo(HasTraits):
+ d1 = Dict()
+ d2 = Dict()
+
+ foo = Foo()
+ assert foo.d1 == {}
+ assert foo.d2 == {}
+ assert foo.d1 is not foo.d2
+
+
+class TestValidationHook(TestCase):
+ def test_parity_trait(self):
+ """Verify that the early validation hook is effective"""
+
+ class Parity(HasTraits):
+ value = Int(0)
+ parity = Enum(["odd", "even"], default_value="even")
+
+ @validate("value")
+ def _value_validate(self, proposal):
+ value = proposal["value"]
+ if self.parity == "even" and value % 2:
+ raise TraitError("Expected an even number")
+ if self.parity == "odd" and (value % 2 == 0):
+ raise TraitError("Expected an odd number")
+ return value
+
+ u = Parity()
+ u.parity = "odd"
+ u.value = 1 # OK
+ with self.assertRaises(TraitError):
+ u.value = 2 # Trait Error
+
+ u.parity = "even"
+ u.value = 2 # OK
+
+ def test_multiple_validate(self):
+ """Verify that we can register the same validator to multiple names"""
+
+ class OddEven(HasTraits):
+ odd = Int(1)
+ even = Int(0)
+
+ @validate("odd", "even")
+ def check_valid(self, proposal):
+ if proposal["trait"].name == "odd" and not proposal["value"] % 2:
+ raise TraitError("odd should be odd")
+ if proposal["trait"].name == "even" and proposal["value"] % 2:
+ raise TraitError("even should be even")
+
+ u = OddEven()
+ u.odd = 3 # OK
+ with self.assertRaises(TraitError):
+ u.odd = 2 # Trait Error
+
+ u.even = 2 # OK
+ with self.assertRaises(TraitError):
+ u.even = 3 # Trait Error
+
+ def test_validate_used(self):
+ """Verify that the validate value is being used"""
+
+ class FixedValue(HasTraits):
+ value = Int(0)
+
+ @validate("value")
+ def _value_validate(self, proposal):
+ return -1
+
+ u = FixedValue(value=2)
+ assert u.value == -1
+
+ u = FixedValue()
+ u.value = 3
+ assert u.value == -1
+
+
+class TestLink(TestCase):
+ def test_connect_same(self):
+ """Verify two traitlets of the same type can be linked together using link."""
+
+ # Create two simple classes with Int traitlets.
+ class A(HasTraits):
+ value = Int()
+
+ a = A(value=9)
+ b = A(value=8)
+
+ # Conenct the two classes.
+ c = link((a, "value"), (b, "value"))
+
+ # Make sure the values are the same at the point of linking.
+ self.assertEqual(a.value, b.value)
+
+ # Change one of the values to make sure they stay in sync.
+ a.value = 5
+ self.assertEqual(a.value, b.value)
+ b.value = 6
+ self.assertEqual(a.value, b.value)
+
+ def test_link_different(self):
+ """Verify two traitlets of different types can be linked together using link."""
+
+ # Create two simple classes with Int traitlets.
+ class A(HasTraits):
+ value = Int()
+
+ class B(HasTraits):
+ count = Int()
+
+ a = A(value=9)
+ b = B(count=8)
+
+ # Conenct the two classes.
+ c = link((a, "value"), (b, "count"))
+
+ # Make sure the values are the same at the point of linking.
+ self.assertEqual(a.value, b.count)
+
+ # Change one of the values to make sure they stay in sync.
+ a.value = 5
+ self.assertEqual(a.value, b.count)
+ b.count = 4
+ self.assertEqual(a.value, b.count)
+
+ def test_unlink_link(self):
+ """Verify two linked traitlets can be unlinked and relinked."""
+
+ # Create two simple classes with Int traitlets.
+ class A(HasTraits):
+ value = Int()
+
+ a = A(value=9)
+ b = A(value=8)
+
+ # Connect the two classes.
+ c = link((a, "value"), (b, "value"))
+ a.value = 4
+ c.unlink()
+
+ # Change one of the values to make sure they don't stay in sync.
+ a.value = 5
+ self.assertNotEqual(a.value, b.value)
+ c.link()
+ self.assertEqual(a.value, b.value)
+ a.value += 1
+ self.assertEqual(a.value, b.value)
+
+ def test_callbacks(self):
+ """Verify two linked traitlets have their callbacks called once."""
+
+ # Create two simple classes with Int traitlets.
+ class A(HasTraits):
+ value = Int()
+
+ class B(HasTraits):
+ count = Int()
+
+ a = A(value=9)
+ b = B(count=8)
+
+ # Register callbacks that count.
+ callback_count = []
+
+ def a_callback(name, old, new):
+ callback_count.append("a")
+
+ a.on_trait_change(a_callback, "value")
+
+ def b_callback(name, old, new):
+ callback_count.append("b")
+
+ b.on_trait_change(b_callback, "count")
+
+ # Connect the two classes.
+ c = link((a, "value"), (b, "count"))
+
+ # Make sure b's count was set to a's value once.
+ self.assertEqual("".join(callback_count), "b")
+ del callback_count[:]
+
+ # Make sure a's value was set to b's count once.
+ b.count = 5
+ self.assertEqual("".join(callback_count), "ba")
+ del callback_count[:]
+
+ # Make sure b's count was set to a's value once.
+ a.value = 4
+ self.assertEqual("".join(callback_count), "ab")
+ del callback_count[:]
+
+ def test_tranform(self):
+ """Test transform link."""
+
+ # Create two simple classes with Int traitlets.
+ class A(HasTraits):
+ value = Int()
+
+ a = A(value=9)
+ b = A(value=8)
+
+ # Conenct the two classes.
+ c = link((a, "value"), (b, "value"), transform=(lambda x: 2 * x, lambda x: int(x / 2.0)))
+
+ # Make sure the values are correct at the point of linking.
+ self.assertEqual(b.value, 2 * a.value)
+
+ # Change one the value of the source and check that it modifies the target.
+ a.value = 5
+ self.assertEqual(b.value, 10)
+ # Change one the value of the target and check that it modifies the
+ # source.
+ b.value = 6
+ self.assertEqual(a.value, 3)
+
+ def test_link_broken_at_source(self):
+ class MyClass(HasTraits):
+ i = Int()
+ j = Int()
+
+ @observe("j")
+ def another_update(self, change):
+ self.i = change.new * 2
+
+ mc = MyClass()
+ l = link((mc, "i"), (mc, "j")) # noqa
+ self.assertRaises(TraitError, setattr, mc, "i", 2)
+
+ def test_link_broken_at_target(self):
+ class MyClass(HasTraits):
+ i = Int()
+ j = Int()
+
+ @observe("i")
+ def another_update(self, change):
+ self.j = change.new * 2
+
+ mc = MyClass()
+ l = link((mc, "i"), (mc, "j")) # noqa
+ self.assertRaises(TraitError, setattr, mc, "j", 2)
+
+
+class TestDirectionalLink(TestCase):
+ def test_connect_same(self):
+ """Verify two traitlets of the same type can be linked together using directional_link."""
+
+ # Create two simple classes with Int traitlets.
+ class A(HasTraits):
+ value = Int()
+
+ a = A(value=9)
+ b = A(value=8)
+
+ # Conenct the two classes.
+ c = directional_link((a, "value"), (b, "value"))
+
+ # Make sure the values are the same at the point of linking.
+ self.assertEqual(a.value, b.value)
+
+ # Change one the value of the source and check that it synchronizes the target.
+ a.value = 5
+ self.assertEqual(b.value, 5)
+ # Change one the value of the target and check that it has no impact on the source
+ b.value = 6
+ self.assertEqual(a.value, 5)
+
+ def test_tranform(self):
+ """Test transform link."""
+
+ # Create two simple classes with Int traitlets.
+ class A(HasTraits):
+ value = Int()
+
+ a = A(value=9)
+ b = A(value=8)
+
+ # Conenct the two classes.
+ c = directional_link((a, "value"), (b, "value"), lambda x: 2 * x)
+
+ # Make sure the values are correct at the point of linking.
+ self.assertEqual(b.value, 2 * a.value)
+
+ # Change one the value of the source and check that it modifies the target.
+ a.value = 5
+ self.assertEqual(b.value, 10)
+ # Change one the value of the target and check that it has no impact on the source
+ b.value = 6
+ self.assertEqual(a.value, 5)
+
+ def test_link_different(self):
+ """Verify two traitlets of different types can be linked together using link."""
+
+ # Create two simple classes with Int traitlets.
+ class A(HasTraits):
+ value = Int()
+
+ class B(HasTraits):
+ count = Int()
+
+ a = A(value=9)
+ b = B(count=8)
+
+ # Conenct the two classes.
+ c = directional_link((a, "value"), (b, "count"))
+
+ # Make sure the values are the same at the point of linking.
+ self.assertEqual(a.value, b.count)
+
+ # Change one the value of the source and check that it synchronizes the target.
+ a.value = 5
+ self.assertEqual(b.count, 5)
+ # Change one the value of the target and check that it has no impact on the source
+ b.value = 6 # type:ignore
+ self.assertEqual(a.value, 5)
+
+ def test_unlink_link(self):
+ """Verify two linked traitlets can be unlinked and relinked."""
+
+ # Create two simple classes with Int traitlets.
+ class A(HasTraits):
+ value = Int()
+
+ a = A(value=9)
+ b = A(value=8)
+
+ # Connect the two classes.
+ c = directional_link((a, "value"), (b, "value"))
+ a.value = 4
+ c.unlink()
+
+ # Change one of the values to make sure they don't stay in sync.
+ a.value = 5
+ self.assertNotEqual(a.value, b.value)
+ c.link()
+ self.assertEqual(a.value, b.value)
+ a.value += 1
+ self.assertEqual(a.value, b.value)
+
+
+class Pickleable(HasTraits):
+ i = Int()
+
+ @observe("i")
+ def _i_changed(self, change):
+ pass
+
+ @validate("i")
+ def _i_validate(self, commit):
+ return commit["value"]
+
+ j = Int()
+
+ def __init__(self):
+ with self.hold_trait_notifications():
+ self.i = 1
+ self.on_trait_change(self._i_changed, "i")
+
+
+def test_pickle_hastraits():
+ c = Pickleable()
+ for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+ p = pickle.dumps(c, protocol)
+ c2 = pickle.loads(p)
+ assert c2.i == c.i
+ assert c2.j == c.j
+
+ c.i = 5
+ for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+ p = pickle.dumps(c, protocol)
+ c2 = pickle.loads(p)
+ assert c2.i == c.i
+ assert c2.j == c.j
+
+
+def test_hold_trait_notifications():
+ changes = []
+
+ class Test(HasTraits):
+ a = Integer(0)
+ b = Integer(0)
+
+ def _a_changed(self, name, old, new):
+ changes.append((old, new))
+
+ def _b_validate(self, value, trait):
+ if value != 0:
+ raise TraitError("Only 0 is a valid value")
+ return value
+
+ # Test context manager and nesting
+ t = Test()
+ with t.hold_trait_notifications():
+ with t.hold_trait_notifications():
+ t.a = 1
+ assert t.a == 1
+ assert changes == []
+ t.a = 2
+ assert t.a == 2
+ with t.hold_trait_notifications():
+ t.a = 3
+ assert t.a == 3
+ assert changes == []
+ t.a = 4
+ assert t.a == 4
+ assert changes == []
+ t.a = 4
+ assert t.a == 4
+ assert changes == []
+
+ assert changes == [(0, 4)]
+ # Test roll-back
+ try:
+ with t.hold_trait_notifications():
+ t.b = 1 # raises a Trait error
+ except Exception:
+ pass
+ assert t.b == 0
+
+
+class RollBack(HasTraits):
+ bar = Int()
+
+ def _bar_validate(self, value, trait):
+ if value:
+ raise TraitError("foobar")
+ return value
+
+
+class TestRollback(TestCase):
+ def test_roll_back(self):
+ def assign_rollback():
+ RollBack(bar=1)
+
+ self.assertRaises(TraitError, assign_rollback)
+
+
+class CacheModification(HasTraits):
+ foo = Int()
+ bar = Int()
+
+ def _bar_validate(self, value, trait):
+ self.foo = value
+ return value
+
+ def _foo_validate(self, value, trait):
+ self.bar = value
+ return value
+
+
+def test_cache_modification():
+ CacheModification(foo=1)
+ CacheModification(bar=1)
+
+
+class OrderTraits(HasTraits):
+ notified = Dict()
+
+ a = Unicode()
+ b = Unicode()
+ c = Unicode()
+ d = Unicode()
+ e = Unicode()
+ f = Unicode()
+ g = Unicode()
+ h = Unicode()
+ i = Unicode()
+ j = Unicode()
+ k = Unicode()
+ l = Unicode() # noqa
+
+ def _notify(self, name, old, new):
+ """check the value of all traits when each trait change is triggered
+
+ This verifies that the values are not sensitive
+ to dict ordering when loaded from kwargs
+ """
+ # check the value of the other traits
+ # when a given trait change notification fires
+ self.notified[name] = {c: getattr(self, c) for c in "abcdefghijkl"}
+
+ def __init__(self, **kwargs):
+ self.on_trait_change(self._notify)
+ super().__init__(**kwargs)
+
+
+def test_notification_order():
+ d = {c: c for c in "abcdefghijkl"}
+ obj = OrderTraits()
+ assert obj.notified == {}
+ obj = OrderTraits(**d)
+ notifications = {c: d for c in "abcdefghijkl"}
+ assert obj.notified == notifications
+
+
+###
+# Traits for Forward Declaration Tests
+###
+class ForwardDeclaredInstanceTrait(HasTraits):
+ value = ForwardDeclaredInstance["ForwardDeclaredBar"]("ForwardDeclaredBar", allow_none=True)
+
+
+class ForwardDeclaredTypeTrait(HasTraits):
+ value = ForwardDeclaredType[t.Any, t.Any]("ForwardDeclaredBar", allow_none=True)
+
+
+class ForwardDeclaredInstanceListTrait(HasTraits):
+ value = List(ForwardDeclaredInstance("ForwardDeclaredBar"))
+
+
+class ForwardDeclaredTypeListTrait(HasTraits):
+ value = List(ForwardDeclaredType("ForwardDeclaredBar"))
+
+
+###
+# End Traits for Forward Declaration Tests
+###
+
+
+###
+# Classes for Forward Declaration Tests
+###
+class ForwardDeclaredBar:
+ pass
+
+
+class ForwardDeclaredBarSub(ForwardDeclaredBar):
+ pass
+
+
+###
+# End Classes for Forward Declaration Tests
+###
+
+
+###
+# Forward Declaration Tests
+###
+class TestForwardDeclaredInstanceTrait(TraitTestBase):
+ obj = ForwardDeclaredInstanceTrait()
+ _default_value = None
+ _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
+ _bad_values = ["foo", 3, ForwardDeclaredBar, ForwardDeclaredBarSub]
+
+
+class TestForwardDeclaredTypeTrait(TraitTestBase):
+ obj = ForwardDeclaredTypeTrait()
+ _default_value = None
+ _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
+ _bad_values = ["foo", 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
+
+
+class TestForwardDeclaredInstanceList(TraitTestBase):
+ obj = ForwardDeclaredInstanceListTrait()
+
+ def test_klass(self):
+ """Test that the instance klass is properly assigned."""
+ self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar)
+
+ _default_value: t.List[t.Any] = []
+ _good_values = [
+ [ForwardDeclaredBar(), ForwardDeclaredBarSub()],
+ [],
+ ]
+ _bad_values = [
+ ForwardDeclaredBar(),
+ [ForwardDeclaredBar(), 3, None],
+ "1",
+ # Note that this is the type, not an instance.
+ [ForwardDeclaredBar],
+ [None],
+ None,
+ ]
+
+
+class TestForwardDeclaredTypeList(TraitTestBase):
+ obj = ForwardDeclaredTypeListTrait()
+
+ def test_klass(self):
+ """Test that the instance klass is properly assigned."""
+ self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar)
+
+ _default_value: t.List[t.Any] = []
+ _good_values = [
+ [ForwardDeclaredBar, ForwardDeclaredBarSub],
+ [],
+ ]
+ _bad_values = [
+ ForwardDeclaredBar,
+ [ForwardDeclaredBar, 3],
+ "1",
+ # Note that this is an instance, not the type.
+ [ForwardDeclaredBar()],
+ [None],
+ None,
+ ]
+
+
+###
+# End Forward Declaration Tests
+###
+
+
+class TestDynamicTraits(TestCase):
+ def setUp(self):
+ self._notify1 = []
+
+ def notify1(self, name, old, new):
+ self._notify1.append((name, old, new))
+
+ @t.no_type_check
+ def test_notify_all(self):
+ class A(HasTraits):
+ pass
+
+ a = A()
+ self.assertTrue(not hasattr(a, "x"))
+ self.assertTrue(not hasattr(a, "y"))
+
+ # Dynamically add trait x.
+ a.add_traits(x=Int())
+ self.assertTrue(hasattr(a, "x"))
+ self.assertTrue(isinstance(a, (A,)))
+
+ # Dynamically add trait y.
+ a.add_traits(y=Float())
+ self.assertTrue(hasattr(a, "y"))
+ self.assertTrue(isinstance(a, (A,)))
+ self.assertEqual(a.__class__.__name__, A.__name__)
+
+ # Create a new instance and verify that x and y
+ # aren't defined.
+ b = A()
+ self.assertTrue(not hasattr(b, "x"))
+ self.assertTrue(not hasattr(b, "y"))
+
+ # Verify that notification works like normal.
+ a.on_trait_change(self.notify1)
+ a.x = 0
+ self.assertEqual(len(self._notify1), 0)
+ a.y = 0.0
+ self.assertEqual(len(self._notify1), 0)
+ a.x = 10
+ self.assertTrue(("x", 0, 10) in self._notify1)
+ a.y = 10.0
+ self.assertTrue(("y", 0.0, 10.0) in self._notify1)
+ self.assertRaises(TraitError, setattr, a, "x", "bad string")
+ self.assertRaises(TraitError, setattr, a, "y", "bad string")
+ self._notify1 = []
+ a.on_trait_change(self.notify1, remove=True)
+ a.x = 20
+ a.y = 20.0
+ self.assertEqual(len(self._notify1), 0)
+
+
+def test_enum_no_default():
+ class C(HasTraits):
+ t = Enum(["a", "b"])
+
+ c = C()
+ c.t = "a"
+ assert c.t == "a"
+
+ c = C()
+
+ with pytest.raises(TraitError):
+ t = c.t
+
+ c = C(t="b")
+ assert c.t == "b"
+
+
+def test_default_value_repr():
+ class C(HasTraits):
+ t = Type("traitlets.HasTraits")
+ t2 = Type(HasTraits)
+ n = Integer(0)
+ lis = List()
+ d = Dict()
+
+ assert C.t.default_value_repr() == "'traitlets.HasTraits'"
+ assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'"
+ assert C.n.default_value_repr() == "0"
+ assert C.lis.default_value_repr() == "[]"
+ assert C.d.default_value_repr() == "{}"
+
+
+class TransitionalClass(HasTraits):
+ d = Any()
+
+ @default("d")
+ def _d_default(self):
+ return TransitionalClass
+
+ parent_super = False
+ calls_super = Integer(0)
+
+ @default("calls_super")
+ def _calls_super_default(self):
+ return -1
+
+ @observe("calls_super")
+ @observe_compat
+ def _calls_super_changed(self, change):
+ self.parent_super = change
+
+ parent_override = False
+ overrides = Integer(0)
+
+ @observe("overrides")
+ @observe_compat
+ def _overrides_changed(self, change):
+ self.parent_override = change
+
+
+class SubClass(TransitionalClass):
+ def _d_default(self):
+ return SubClass
+
+ subclass_super = False
+
+ def _calls_super_changed(self, name, old, new):
+ self.subclass_super = True
+ super()._calls_super_changed(name, old, new)
+
+ subclass_override = False
+
+ def _overrides_changed(self, name, old, new):
+ self.subclass_override = True
+
+
+def test_subclass_compat():
+ obj = SubClass()
+ obj.calls_super = 5
+ assert obj.parent_super
+ assert obj.subclass_super
+ obj.overrides = 5
+ assert obj.subclass_override
+ assert not obj.parent_override
+ assert obj.d is SubClass
+
+
+class DefinesHandler(HasTraits):
+ parent_called = False
+
+ trait = Integer()
+
+ @observe("trait")
+ def handler(self, change):
+ self.parent_called = True
+
+
+class OverridesHandler(DefinesHandler):
+ child_called = False
+
+ @observe("trait")
+ def handler(self, change):
+ self.child_called = True
+
+
+def test_subclass_override_observer():
+ obj = OverridesHandler()
+ obj.trait = 5
+ assert obj.child_called
+ assert not obj.parent_called
+
+
+class DoesntRegisterHandler(DefinesHandler):
+ child_called = False
+
+ def handler(self, change):
+ self.child_called = True
+
+
+def test_subclass_override_not_registered():
+ """Subclass that overrides observer and doesn't re-register unregisters both"""
+ obj = DoesntRegisterHandler()
+ obj.trait = 5
+ assert not obj.child_called
+ assert not obj.parent_called
+
+
+class AddsHandler(DefinesHandler):
+ child_called = False
+
+ @observe("trait")
+ def child_handler(self, change):
+ self.child_called = True
+
+
+def test_subclass_add_observer():
+ obj = AddsHandler()
+ obj.trait = 5
+ assert obj.child_called
+ assert obj.parent_called
+
+
+def test_observe_iterables():
+ class C(HasTraits):
+ i = Integer()
+ s = Unicode()
+
+ c = C()
+ recorded = {}
+
+ def record(change):
+ recorded["change"] = change
+
+ # observe with names=set
+ c.observe(record, names={"i", "s"})
+ c.i = 5
+ assert recorded["change"].name == "i"
+ assert recorded["change"].new == 5
+ c.s = "hi"
+ assert recorded["change"].name == "s"
+ assert recorded["change"].new == "hi"
+
+ # observe with names=custom container with iter, contains
+ class MyContainer:
+ def __init__(self, container):
+ self.container = container
+
+ def __iter__(self):
+ return iter(self.container)
+
+ def __contains__(self, key):
+ return key in self.container
+
+ c.observe(record, names=MyContainer({"i", "s"}))
+ c.i = 10
+ assert recorded["change"].name == "i"
+ assert recorded["change"].new == 10
+ c.s = "ok"
+ assert recorded["change"].name == "s"
+ assert recorded["change"].new == "ok"
+
+
+def test_super_args():
+ class SuperRecorder:
+ def __init__(self, *args, **kwargs):
+ self.super_args = args
+ self.super_kwargs = kwargs
+
+ class SuperHasTraits(HasTraits, SuperRecorder):
+ i = Integer()
+
+ obj = SuperHasTraits("a1", "a2", b=10, i=5, c="x")
+ assert obj.i == 5
+ assert not hasattr(obj, "b")
+ assert not hasattr(obj, "c")
+ assert obj.super_args == ("a1", "a2")
+ assert obj.super_kwargs == {"b": 10, "c": "x"}
+
+
+def test_super_bad_args():
+ class SuperHasTraits(HasTraits):
+ a = Integer()
+
+ w = ["Passing unrecognized arguments"]
+ with expected_warnings(w):
+ obj = SuperHasTraits(a=1, b=2)
+ assert obj.a == 1
+ assert not hasattr(obj, "b")
+
+
+def test_default_mro():
+ """Verify that default values follow mro"""
+
+ class Base(HasTraits):
+ trait = Unicode("base")
+ attr = "base"
+
+ class A(Base):
+ pass
+
+ class B(Base):
+ trait = Unicode("B")
+ attr = "B"
+
+ class AB(A, B):
+ pass
+
+ class BA(B, A):
+ pass
+
+ assert A().trait == "base"
+ assert A().attr == "base"
+ assert BA().trait == "B"
+ assert BA().attr == "B"
+ assert AB().trait == "B"
+ assert AB().attr == "B"
+
+
+def test_cls_self_argument():
+ class X(HasTraits):
+ def __init__(__self, cls, self): # noqa
+ pass
+
+ x = X(cls=None, self=None)
+
+
+def test_override_default():
+ class C(HasTraits):
+ a = Unicode("hard default")
+
+ def _a_default(self):
+ return "default method"
+
+ C._a_default = lambda self: "overridden" # type:ignore
+ c = C()
+ assert c.a == "overridden"
+
+
+def test_override_default_decorator():
+ class C(HasTraits):
+ a = Unicode("hard default")
+
+ @default("a")
+ def _a_default(self):
+ return "default method"
+
+ C._a_default = lambda self: "overridden" # type:ignore
+ c = C()
+ assert c.a == "overridden"
+
+
+def test_override_default_instance():
+ class C(HasTraits):
+ a = Unicode("hard default")
+
+ @default("a")
+ def _a_default(self):
+ return "default method"
+
+ c = C()
+ c._a_default = lambda self: "overridden"
+ assert c.a == "overridden"
+
+
+def test_copy_HasTraits():
+ from copy import copy
+
+ class C(HasTraits):
+ a = Int()
+
+ c = C(a=1)
+ assert c.a == 1
+
+ cc = copy(c)
+ cc.a = 2
+ assert cc.a == 2
+ assert c.a == 1
+
+
+def _from_string_test(traittype, s, expected):
+ """Run a test of trait.from_string"""
+ if isinstance(traittype, TraitType):
+ trait = traittype
+ else:
+ trait = traittype(allow_none=True)
+ if isinstance(s, list):
+ cast = trait.from_string_list # type:ignore
+ else:
+ cast = trait.from_string
+ if type(expected) is type and issubclass(expected, Exception):
+ with pytest.raises(expected):
+ value = cast(s)
+ trait.validate(CrossValidationStub(), value) # type:ignore
+ else:
+ value = cast(s)
+ assert value == expected
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
+)
+def test_unicode_from_string(s, expected):
+ _from_string_test(Unicode, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
+)
+def test_cunicode_from_string(s, expected):
+ _from_string_test(CUnicode, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)],
+)
+def test_bytes_from_string(s, expected):
+ _from_string_test(Bytes, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)],
+)
+def test_cbytes_from_string(s, expected):
+ _from_string_test(CBytes, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)],
+)
+def test_int_from_string(s, expected):
+ _from_string_test(Integer, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)],
+)
+def test_float_from_string(s, expected):
+ _from_string_test(Float, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [
+ ("x", ValueError),
+ ("1", 1.0),
+ ("123.5", 123.5),
+ ("2.5", 2.5),
+ ("1+2j", 1 + 2j),
+ ("None", None),
+ ],
+)
+def test_complex_from_string(s, expected):
+ _from_string_test(Complex, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [
+ ("true", True),
+ ("TRUE", True),
+ ("1", True),
+ ("0", False),
+ ("False", False),
+ ("false", False),
+ ("1.0", ValueError),
+ ("None", None),
+ ],
+)
+def test_bool_from_string(s, expected):
+ _from_string_test(Bool, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [
+ ("{}", {}),
+ ("1", TraitError),
+ ("{1: 2}", {1: 2}),
+ ('{"key": "value"}', {"key": "value"}),
+ ("x", TraitError),
+ ("None", None),
+ ],
+)
+def test_dict_from_string(s, expected):
+ _from_string_test(Dict, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [
+ ("[]", []),
+ ('[1, 2, "x"]', [1, 2, "x"]),
+ (["1", "x"], ["1", "x"]),
+ (["None"], None),
+ ],
+)
+def test_list_from_string(s, expected):
+ _from_string_test(List, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected, value_trait",
+ [
+ (["1", "2", "3"], [1, 2, 3], Integer()),
+ (["x"], ValueError, Integer()),
+ (["1", "x"], ["1", "x"], Unicode()),
+ (["None"], [None], Unicode(allow_none=True)),
+ (["None"], ["None"], Unicode(allow_none=False)),
+ ],
+)
+def test_list_items_from_string(s, expected, value_trait):
+ _from_string_test(List(value_trait), s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [
+ ("[]", set()),
+ ('[1, 2, "x"]', {1, 2, "x"}),
+ ('{1, 2, "x"}', {1, 2, "x"}),
+ (["1", "x"], {"1", "x"}),
+ (["None"], None),
+ ],
+)
+def test_set_from_string(s, expected):
+ _from_string_test(Set, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected, value_trait",
+ [
+ (["1", "2", "3"], {1, 2, 3}, Integer()),
+ (["x"], ValueError, Integer()),
+ (["1", "x"], {"1", "x"}, Unicode()),
+ (["None"], {None}, Unicode(allow_none=True)),
+ ],
+)
+def test_set_items_from_string(s, expected, value_trait):
+ _from_string_test(Set(value_trait), s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [
+ ("[]", ()),
+ ("()", ()),
+ ('[1, 2, "x"]', (1, 2, "x")),
+ ('(1, 2, "x")', (1, 2, "x")),
+ (["1", "x"], ("1", "x")),
+ (["None"], None),
+ ],
+)
+def test_tuple_from_string(s, expected):
+ _from_string_test(Tuple, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected, value_traits",
+ [
+ (["1", "2", "3"], (1, 2, 3), [Integer(), Integer(), Integer()]),
+ (["x"], ValueError, [Integer()]),
+ (["1", "x"], ("1", "x"), [Unicode()]),
+ (["None"], ("None",), [Unicode(allow_none=False)]),
+ (["None"], (None,), [Unicode(allow_none=True)]),
+ ],
+)
+def test_tuple_items_from_string(s, expected, value_traits):
+ _from_string_test(Tuple(*value_traits), s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [
+ ("x", "x"),
+ ("mod.submod", "mod.submod"),
+ ("not an identifier", TraitError),
+ ("1", "1"),
+ ("None", None),
+ ],
+)
+def test_object_from_string(s, expected):
+ _from_string_test(DottedObjectName, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [
+ ("127.0.0.1:8000", ("127.0.0.1", 8000)),
+ ("host.tld:80", ("host.tld", 80)),
+ ("host:notaport", ValueError),
+ ("127.0.0.1", ValueError),
+ ("None", None),
+ ],
+)
+def test_tcp_from_string(s, expected):
+ _from_string_test(TCPAddress, s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [("[]", []), ("{}", "{}")],
+)
+def test_union_of_list_and_unicode_from_string(s, expected):
+ _from_string_test(Union([List(), Unicode()]), s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected",
+ [("1", 1), ("1.5", 1.5)],
+)
+def test_union_of_int_and_float_from_string(s, expected):
+ _from_string_test(Union([Int(), Float()]), s, expected)
+
+
+@pytest.mark.parametrize(
+ "s, expected, allow_none",
+ [("[]", [], False), ("{}", {}, False), ("None", TraitError, False), ("None", None, True)],
+)
+def test_union_of_list_and_dict_from_string(s, expected, allow_none):
+ _from_string_test(Union([List(), Dict()], allow_none=allow_none), s, expected)
+
+
+def test_all_attribute():
+ """Verify all trait types are added to `traitlets.__all__`"""
+ names = dir(traitlets)
+ for name in names:
+ value = getattr(traitlets, name)
+ if not name.startswith("_") and isinstance(value, type) and issubclass(value, TraitType):
+ if name not in traitlets.__all__:
+ raise ValueError(f"{name} not in __all__")
+
+ for name in traitlets.__all__:
+ if name not in names:
+ raise ValueError(f"{name} should be removed from __all__")
diff --git a/contrib/python/traitlets/py3/tests/test_traitlets_docstring.py b/contrib/python/traitlets/py3/tests/test_traitlets_docstring.py
new file mode 100644
index 00000000000..700199108f1
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/test_traitlets_docstring.py
@@ -0,0 +1,84 @@
+"""Tests for traitlets.traitlets."""
+
+# Copyright (c) IPython Development Team.
+# Distributed under the terms of the Modified BSD License.
+#
+from traitlets import Dict, Instance, Integer, Unicode, Union
+from traitlets.config import Configurable
+
+
+def test_handle_docstring():
+ class SampleConfigurable(Configurable):
+ pass
+
+ class TraitTypesSampleConfigurable(Configurable):
+ """TraitTypesSampleConfigurable docstring"""
+
+ trait_integer = Integer(
+ help="""trait_integer help text""",
+ config=True,
+ )
+ trait_integer_nohelp = Integer(
+ config=True,
+ )
+ trait_integer_noconfig = Integer(
+ help="""trait_integer_noconfig help text""",
+ )
+
+ trait_unicode = Unicode(
+ help="""trait_unicode help text""",
+ config=True,
+ )
+ trait_unicode_nohelp = Unicode(
+ config=True,
+ )
+ trait_unicode_noconfig = Unicode(
+ help="""trait_unicode_noconfig help text""",
+ )
+
+ trait_dict = Dict(
+ help="""trait_dict help text""",
+ config=True,
+ )
+ trait_dict_nohelp = Dict(
+ config=True,
+ )
+ trait_dict_noconfig = Dict(
+ help="""trait_dict_noconfig help text""",
+ )
+
+ trait_instance = Instance(
+ klass=SampleConfigurable,
+ help="""trait_instance help text""",
+ config=True,
+ )
+ trait_instance_nohelp = Instance(
+ klass=SampleConfigurable,
+ config=True,
+ )
+ trait_instance_noconfig = Instance(
+ klass=SampleConfigurable,
+ help="""trait_instance_noconfig help text""",
+ )
+
+ trait_union = Union(
+ [Integer(), Unicode()],
+ help="""trait_union help text""",
+ config=True,
+ )
+ trait_union_nohelp = Union(
+ [Integer(), Unicode()],
+ config=True,
+ )
+ trait_union_noconfig = Union(
+ [Integer(), Unicode()],
+ help="""trait_union_noconfig help text""",
+ )
+
+ base_names = SampleConfigurable().trait_names()
+ for name in TraitTypesSampleConfigurable().trait_names():
+ if name in base_names:
+ continue
+ doc = getattr(TraitTypesSampleConfigurable, name).__doc__
+ if "nohelp" not in name:
+ assert doc == f"{name} help text"
diff --git a/contrib/python/traitlets/py3/tests/test_traitlets_enum.py b/contrib/python/traitlets/py3/tests/test_traitlets_enum.py
new file mode 100644
index 00000000000..c39007e8a05
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/test_traitlets_enum.py
@@ -0,0 +1,380 @@
+# pylint: disable=missing-docstring, too-few-public-methods
+"""
+Test the trait-type ``UseEnum``.
+"""
+
+import enum
+import unittest
+
+from traitlets import CaselessStrEnum, Enum, FuzzyEnum, HasTraits, TraitError, UseEnum
+
+# -----------------------------------------------------------------------------
+# TEST SUPPORT:
+# -----------------------------------------------------------------------------
+
+
+class Color(enum.Enum):
+ red = 1
+ green = 2
+ blue = 3
+ yellow = 4
+
+
+class OtherColor(enum.Enum):
+ red = 0
+ green = 1
+
+
+class CSColor(enum.Enum):
+ red = 1
+ Green = 2
+ BLUE = 3
+ YeLLoW = 4
+
+
+color_choices = "red Green BLUE YeLLoW".split()
+
+
+# -----------------------------------------------------------------------------
+# TESTSUITE:
+# -----------------------------------------------------------------------------
+class TestUseEnum(unittest.TestCase):
+ # pylint: disable=invalid-name
+
+ class Example(HasTraits):
+ color = UseEnum(Color, help="Color enum")
+
+ def test_assign_enum_value(self):
+ example = self.Example()
+ example.color = Color.green
+ self.assertEqual(example.color, Color.green)
+
+ def test_assign_all_enum_values(self):
+ # pylint: disable=no-member
+ enum_values = list(Color.__members__.values())
+ for value in enum_values:
+ self.assertIsInstance(value, Color)
+ example = self.Example()
+ example.color = value
+ self.assertEqual(example.color, value)
+ self.assertIsInstance(value, Color)
+
+ def test_assign_enum_value__with_other_enum_raises_error(self):
+ example = self.Example()
+ with self.assertRaises(TraitError):
+ example.color = OtherColor.green
+
+ def test_assign_enum_name_1(self):
+ # -- CONVERT: string => Enum value (item)
+ example = self.Example()
+ example.color = "red"
+ self.assertEqual(example.color, Color.red)
+
+ def test_assign_enum_value_name(self):
+ # -- CONVERT: string => Enum value (item)
+ # pylint: disable=no-member
+ enum_names = [enum_val.name for enum_val in Color.__members__.values()]
+ for value in enum_names:
+ self.assertIsInstance(value, str)
+ example = self.Example()
+ enum_value = Color.__members__.get(value)
+ example.color = value
+ self.assertIs(example.color, enum_value)
+ self.assertEqual(example.color.name, value) # type:ignore
+
+ def test_assign_scoped_enum_value_name(self):
+ # -- CONVERT: string => Enum value (item)
+ scoped_names = ["Color.red", "Color.green", "Color.blue", "Color.yellow"]
+ for value in scoped_names:
+ example = self.Example()
+ example.color = value
+ self.assertIsInstance(example.color, Color)
+ self.assertEqual(str(example.color), value)
+
+ def test_assign_bad_enum_value_name__raises_error(self):
+ # -- CONVERT: string => Enum value (item)
+ bad_enum_names = ["UNKNOWN_COLOR", "RED", "Green", "blue2"]
+ for value in bad_enum_names:
+ example = self.Example()
+ with self.assertRaises(TraitError):
+ example.color = value
+
+ def test_assign_enum_value_number_1(self):
+ # -- CONVERT: number => Enum value (item)
+ example = self.Example()
+ example.color = 1 # == Color.red.value
+ example.color = Color.red.value
+ self.assertEqual(example.color, Color.red)
+
+ def test_assign_enum_value_number(self):
+ # -- CONVERT: number => Enum value (item)
+ # pylint: disable=no-member
+ enum_numbers = [enum_val.value for enum_val in Color.__members__.values()]
+ for value in enum_numbers:
+ self.assertIsInstance(value, int)
+ example = self.Example()
+ example.color = value
+ self.assertIsInstance(example.color, Color)
+ self.assertEqual(example.color.value, value) # type:ignore
+
+ def test_assign_bad_enum_value_number__raises_error(self):
+ # -- CONVERT: number => Enum value (item)
+ bad_numbers = [-1, 0, 5]
+ for value in bad_numbers:
+ self.assertIsInstance(value, int)
+ assert UseEnum(Color).select_by_number(value, None) is None
+ example = self.Example()
+ with self.assertRaises(TraitError):
+ example.color = value
+
+ def test_ctor_without_default_value(self):
+ # -- IMPLICIT: default_value = Color.red (first enum-value)
+ class Example2(HasTraits):
+ color = UseEnum(Color)
+
+ example = Example2()
+ self.assertEqual(example.color, Color.red)
+
+ def test_ctor_with_default_value_as_enum_value(self):
+ # -- CONVERT: number => Enum value (item)
+ class Example2(HasTraits):
+ color = UseEnum(Color, default_value=Color.green)
+
+ example = Example2()
+ self.assertEqual(example.color, Color.green)
+
+ def test_ctor_with_default_value_none_and_not_allow_none(self):
+ # -- IMPLICIT: default_value = Color.red (first enum-value)
+ class Example2(HasTraits):
+ color1 = UseEnum(Color, default_value=None, allow_none=False)
+ color2 = UseEnum(Color, default_value=None)
+
+ example = Example2()
+ self.assertEqual(example.color1, Color.red)
+ self.assertEqual(example.color2, Color.red)
+
+ def test_ctor_with_default_value_none_and_allow_none(self):
+ class Example2(HasTraits):
+ color1 = UseEnum(Color, default_value=None, allow_none=True)
+ color2 = UseEnum(Color, allow_none=True)
+
+ example = Example2()
+ self.assertIs(example.color1, None)
+ self.assertIs(example.color2, None)
+
+ def test_assign_none_without_allow_none_resets_to_default_value(self):
+ class Example2(HasTraits):
+ color1 = UseEnum(Color, allow_none=False)
+ color2 = UseEnum(Color)
+
+ example = Example2()
+ example.color1 = None
+ example.color2 = None
+ self.assertIs(example.color1, Color.red)
+ self.assertIs(example.color2, Color.red)
+
+ def test_assign_none_to_enum_or_none(self):
+ class Example2(HasTraits):
+ color = UseEnum(Color, allow_none=True)
+
+ example = Example2()
+ example.color = None
+ self.assertIs(example.color, None)
+
+ def test_assign_bad_value_with_to_enum_or_none(self):
+ class Example2(HasTraits):
+ color = UseEnum(Color, allow_none=True)
+
+ example = Example2()
+ with self.assertRaises(TraitError):
+ example.color = "BAD_VALUE"
+
+ def test_info(self):
+ choices = color_choices
+
+ class Example(HasTraits):
+ enum1 = Enum(choices, allow_none=False)
+ enum2 = CaselessStrEnum(choices, allow_none=False)
+ enum3 = FuzzyEnum(choices, allow_none=False)
+ enum4 = UseEnum(CSColor, allow_none=False)
+
+ for i in range(1, 5):
+ attr = "enum%s" % i
+ enum = getattr(Example, attr)
+
+ enum.allow_none = True
+
+ info = enum.info()
+ self.assertEqual(len(info.split(", ")), len(choices), info.split(", "))
+ self.assertIn("or None", info)
+
+ info = enum.info_rst()
+ self.assertEqual(len(info.split("|")), len(choices), info.split("|"))
+ self.assertIn("or `None`", info)
+ # Check no single `\` exists.
+ self.assertNotRegex(info, r"\b\\\b")
+
+ enum.allow_none = False
+
+ info = enum.info()
+ self.assertEqual(len(info.split(", ")), len(choices), info.split(", "))
+ self.assertNotIn("None", info)
+
+ info = enum.info_rst()
+ self.assertEqual(len(info.split("|")), len(choices), info.split("|"))
+ self.assertNotIn("None", info)
+ # Check no single `\` exists.
+ self.assertNotRegex(info, r"\b\\\b")
+
+
+# -----------------------------------------------------------------------------
+# TESTSUITE:
+# -----------------------------------------------------------------------------
+
+
+class TestFuzzyEnum(unittest.TestCase):
+ # Check mostly `validate()`, Ctor must be checked on generic `Enum`
+ # or `CaselessStrEnum`.
+
+ def test_search_all_prefixes__overwrite(self):
+ class FuzzyExample(HasTraits):
+ color = FuzzyEnum(color_choices, help="Color enum")
+
+ example = FuzzyExample()
+ for color in color_choices:
+ for wlen in range(1, len(color)):
+ value = color[:wlen]
+
+ example.color = value
+ self.assertEqual(example.color, color)
+
+ example.color = value.upper()
+ self.assertEqual(example.color, color)
+
+ example.color = value.lower()
+ self.assertEqual(example.color, color)
+
+ def test_search_all_prefixes__ctor(self):
+ class FuzzyExample(HasTraits):
+ color = FuzzyEnum(color_choices, help="Color enum")
+
+ for color in color_choices:
+ for wlen in range(1, len(color)):
+ value = color[:wlen]
+
+ example = FuzzyExample()
+ example.color = value
+ self.assertEqual(example.color, color)
+
+ example = FuzzyExample()
+ example.color = value.upper()
+ self.assertEqual(example.color, color)
+
+ example = FuzzyExample()
+ example.color = value.lower()
+ self.assertEqual(example.color, color)
+
+ def test_search_substrings__overwrite(self):
+ class FuzzyExample(HasTraits):
+ color = FuzzyEnum(color_choices, help="Color enum", substring_matching=True)
+
+ example = FuzzyExample()
+ for color in color_choices:
+ for wlen in range(0, 2):
+ value = color[wlen:]
+
+ example.color = value
+ self.assertEqual(example.color, color)
+
+ example.color = value.upper()
+ self.assertEqual(example.color, color)
+
+ example.color = value.lower()
+ self.assertEqual(example.color, color)
+
+ def test_search_substrings__ctor(self):
+ class FuzzyExample(HasTraits):
+ color = FuzzyEnum(color_choices, help="Color enum", substring_matching=True)
+
+ color = color_choices[-1] # 'YeLLoW'
+ for end in (-1, len(color)):
+ for start in range(1, len(color) - 2):
+ value = color[start:end]
+
+ example = FuzzyExample()
+ example.color = value
+ self.assertEqual(example.color, color)
+
+ example = FuzzyExample()
+ example.color = value.upper()
+ self.assertEqual(example.color, color)
+
+ def test_assign_other_raises(self):
+ def new_trait_class(case_sensitive, substring_matching):
+ class Example(HasTraits):
+ color = FuzzyEnum(
+ color_choices,
+ case_sensitive=case_sensitive,
+ substring_matching=substring_matching,
+ )
+
+ return Example
+
+ example = new_trait_class(case_sensitive=False, substring_matching=False)()
+ with self.assertRaises(TraitError):
+ example.color = ""
+ with self.assertRaises(TraitError):
+ example.color = "BAD COLOR"
+ with self.assertRaises(TraitError):
+ example.color = "ed"
+
+ example = new_trait_class(case_sensitive=True, substring_matching=False)()
+ with self.assertRaises(TraitError):
+ example.color = ""
+ with self.assertRaises(TraitError):
+ example.color = "Red" # not 'red'
+
+ example = new_trait_class(case_sensitive=True, substring_matching=True)()
+ with self.assertRaises(TraitError):
+ example.color = ""
+ with self.assertRaises(TraitError):
+ example.color = "BAD COLOR"
+ with self.assertRaises(TraitError):
+ example.color = "green" # not 'Green'
+ with self.assertRaises(TraitError):
+ example.color = "lue" # not (b)'LUE'
+ with self.assertRaises(TraitError):
+ example.color = "lUE" # not (b)'LUE'
+
+ example = new_trait_class(case_sensitive=False, substring_matching=True)()
+ with self.assertRaises(TraitError):
+ example.color = ""
+ with self.assertRaises(TraitError):
+ example.color = "BAD COLOR"
+
+ def test_ctor_with_default_value(self):
+ def new_trait_class(default_value, case_sensitive, substring_matching):
+ class Example(HasTraits):
+ color = FuzzyEnum(
+ color_choices,
+ default_value=default_value,
+ case_sensitive=case_sensitive,
+ substring_matching=substring_matching,
+ )
+
+ return Example
+
+ for color in color_choices:
+ example = new_trait_class(color, False, False)()
+ self.assertEqual(example.color, color)
+
+ example = new_trait_class(color.upper(), False, False)()
+ self.assertEqual(example.color, color)
+
+ color = color_choices[-1] # 'YeLLoW'
+ example = new_trait_class(color, True, False)()
+ self.assertEqual(example.color, color)
+
+ # FIXME: default value not validated!
+ # with self.assertRaises(TraitError):
+ # example = new_trait_class(color.lower(), True, False)
diff --git a/contrib/python/traitlets/py3/tests/test_typing.py b/contrib/python/traitlets/py3/tests/test_typing.py
new file mode 100644
index 00000000000..2b4073ecf72
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/test_typing.py
@@ -0,0 +1,395 @@
+from __future__ import annotations
+
+import typing
+
+import pytest
+
+from traitlets import (
+ Any,
+ Bool,
+ CInt,
+ Dict,
+ HasTraits,
+ Instance,
+ Int,
+ List,
+ Set,
+ TCPAddress,
+ Type,
+ Unicode,
+ Union,
+ default,
+ observe,
+ validate,
+)
+from traitlets.config import Config
+
+if not typing.TYPE_CHECKING:
+
+ def reveal_type(*args, **kwargs):
+ pass
+
+
+# mypy: disallow-untyped-calls
+
+
+class Foo:
+ def __init__(self, c):
+ self.c = c
+
+
+@pytest.mark.mypy_testing
+def mypy_decorator_typing():
+ class T(HasTraits):
+ foo = Unicode("").tag(config=True)
+
+ @default("foo")
+ def _default_foo(self) -> str:
+ return "hi"
+
+ @observe("foo")
+ def _foo_observer(self, change: typing.Any) -> bool:
+ return True
+
+ @validate("foo")
+ def _foo_validate(self, commit: typing.Any) -> bool:
+ return True
+
+ t = T()
+ reveal_type(t.foo) # R: builtins.str
+ reveal_type(t._foo_observer) # R: Any
+ reveal_type(t._foo_validate) # R: Any
+
+
+@pytest.mark.mypy_testing
+def mypy_config_typing():
+ c = Config(
+ {
+ "ExtractOutputPreprocessor": {"enabled": True},
+ }
+ )
+ reveal_type(c) # R: traitlets.config.loader.Config
+
+
+@pytest.mark.mypy_testing
+def mypy_union_typing():
+ class T(HasTraits):
+ style = Union(
+ [Unicode("default"), Type(klass=object)],
+ help="Name of the pygments style to use",
+ default_value="hi",
+ ).tag(config=True)
+
+ t = T()
+ reveal_type(Union("foo")) # R: traitlets.traitlets.Union
+ reveal_type(Union("").tag(sync=True)) # R: traitlets.traitlets.Union
+ reveal_type(Union(None, allow_none=True)) # R: traitlets.traitlets.Union
+ reveal_type(Union(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Union
+ reveal_type(T.style) # R: traitlets.traitlets.Union
+ reveal_type(t.style) # R: Any
+
+
+@pytest.mark.mypy_testing
+def mypy_list_typing():
+ class T(HasTraits):
+ latex_command = List(
+ ["xelatex", "{filename}", "-quiet"], help="Shell command used to compile latex."
+ ).tag(config=True)
+
+ t = T()
+ reveal_type(List("foo")) # R: traitlets.traitlets.List
+ reveal_type(List("").tag(sync=True)) # R: traitlets.traitlets.List
+ reveal_type(List(None, allow_none=True)) # R: traitlets.traitlets.List
+ reveal_type(List(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.List
+ reveal_type(T.latex_command) # R: traitlets.traitlets.List
+ reveal_type(t.latex_command) # R: builtins.list[Any]
+
+
+@pytest.mark.mypy_testing
+def mypy_dict_typing():
+ class T(HasTraits):
+ foo = Dict({}, help="Shell command used to compile latex.").tag(config=True)
+
+ t = T()
+ reveal_type(Dict("foo")) # R: traitlets.traitlets.Dict
+ reveal_type(Dict("").tag(sync=True)) # R: traitlets.traitlets.Dict
+ reveal_type(Dict(None, allow_none=True)) # R: traitlets.traitlets.Dict
+ reveal_type(Dict(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Dict
+ reveal_type(T.foo) # R: traitlets.traitlets.Dict
+ reveal_type(t.foo) # R: builtins.dict[Any, Any]
+
+
+@pytest.mark.mypy_testing
+def mypy_type_typing():
+ class KernelSpec:
+ item = Unicode("foo")
+
+ class KernelSpecManager(HasTraits):
+ """A manager for kernel specs."""
+
+ kernel_spec_class = Type(
+ KernelSpec,
+ config=True,
+ help="""The kernel spec class. This is configurable to allow
+ subclassing of the KernelSpecManager for customized behavior.
+ """,
+ )
+ other_class = Type("foo.bar.baz")
+
+ t = KernelSpecManager()
+ reveal_type(t.kernel_spec_class) # R: def () -> tests.test_typing.KernelSpec@124
+ reveal_type(t.kernel_spec_class()) # R: tests.test_typing.KernelSpec@124
+ reveal_type(t.kernel_spec_class().item) # R: builtins.str
+ reveal_type(t.other_class) # R: builtins.type
+ reveal_type(t.other_class()) # R: Any
+
+
+@pytest.mark.mypy_testing
+def mypy_unicode_typing():
+ class T(HasTraits):
+ export_format = Unicode(
+ allow_none=False,
+ help="""The export format to be used, either one of the built-in formats
+ or a dotted object name that represents the import path for an
+ ``Exporter`` class""",
+ ).tag(config=True)
+
+ t = T()
+ reveal_type(
+ Unicode( # R: traitlets.traitlets.Unicode[builtins.str, Union[builtins.str, builtins.bytes]]
+ "foo"
+ )
+ )
+ reveal_type(
+ Unicode( # R: traitlets.traitlets.Unicode[builtins.str, Union[builtins.str, builtins.bytes]]
+ ""
+ ).tag(
+ sync=True
+ )
+ )
+ reveal_type(
+ Unicode( # R: traitlets.traitlets.Unicode[Union[builtins.str, None], Union[builtins.str, builtins.bytes, None]]
+ None, allow_none=True
+ )
+ )
+ reveal_type(
+ Unicode( # R: traitlets.traitlets.Unicode[Union[builtins.str, None], Union[builtins.str, builtins.bytes, None]]
+ None, allow_none=True
+ ).tag(
+ sync=True
+ )
+ )
+ reveal_type(
+ T.export_format # R: traitlets.traitlets.Unicode[builtins.str, Union[builtins.str, builtins.bytes]]
+ )
+ reveal_type(t.export_format) # R: builtins.str
+
+
+@pytest.mark.mypy_testing
+def mypy_set_typing():
+ class T(HasTraits):
+ remove_cell_tags = Set(
+ Unicode(),
+ default_value=[],
+ help=(
+ "Tags indicating which cells are to be removed,"
+ "matches tags in ``cell.metadata.tags``."
+ ),
+ ).tag(config=True)
+
+ safe_output_keys = Set(
+ config=True,
+ default_value={
+ "metadata", # Not a mimetype per-se, but expected and safe.
+ "text/plain",
+ "text/latex",
+ "application/json",
+ "image/png",
+ "image/jpeg",
+ },
+ help="Cell output mimetypes to render without modification",
+ )
+
+ t = T()
+ reveal_type(Set("foo")) # R: traitlets.traitlets.Set
+ reveal_type(Set("").tag(sync=True)) # R: traitlets.traitlets.Set
+ reveal_type(Set(None, allow_none=True)) # R: traitlets.traitlets.Set
+ reveal_type(Set(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Set
+ reveal_type(T.remove_cell_tags) # R: traitlets.traitlets.Set
+ reveal_type(t.remove_cell_tags) # R: builtins.set[Any]
+ reveal_type(T.safe_output_keys) # R: traitlets.traitlets.Set
+ reveal_type(t.safe_output_keys) # R: builtins.set[Any]
+
+
+@pytest.mark.mypy_testing
+def mypy_any_typing():
+ class T(HasTraits):
+ attributes = Any(
+ config=True,
+ default_value={
+ "a": ["href", "title"],
+ "abbr": ["title"],
+ "acronym": ["title"],
+ },
+ help="Allowed HTML tag attributes",
+ )
+
+ t = T()
+ reveal_type(Any("foo")) # R: traitlets.traitlets.Any
+ reveal_type(Any("").tag(sync=True)) # R: traitlets.traitlets.Any
+ reveal_type(Any(None, allow_none=True)) # R: traitlets.traitlets.Any
+ reveal_type(Any(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Any
+ reveal_type(T.attributes) # R: traitlets.traitlets.Any
+ reveal_type(t.attributes) # R: Any
+
+
+@pytest.mark.mypy_testing
+def mypy_bool_typing():
+ class T(HasTraits):
+ b = Bool(True).tag(sync=True)
+ ob = Bool(None, allow_none=True).tag(sync=True)
+
+ t = T()
+ reveal_type(
+ Bool(True) # R: traitlets.traitlets.Bool[builtins.bool, Union[builtins.bool, builtins.int]]
+ )
+ reveal_type(
+ Bool( # R: traitlets.traitlets.Bool[builtins.bool, Union[builtins.bool, builtins.int]]
+ True
+ ).tag(sync=True)
+ )
+ reveal_type(
+ Bool( # R: traitlets.traitlets.Bool[Union[builtins.bool, None], Union[builtins.bool, builtins.int, None]]
+ None, allow_none=True
+ )
+ )
+ reveal_type(
+ Bool( # R: traitlets.traitlets.Bool[Union[builtins.bool, None], Union[builtins.bool, builtins.int, None]]
+ None, allow_none=True
+ ).tag(
+ sync=True
+ )
+ )
+ reveal_type(
+ T.b # R: traitlets.traitlets.Bool[builtins.bool, Union[builtins.bool, builtins.int]]
+ )
+ reveal_type(t.b) # R: builtins.bool
+ reveal_type(t.ob) # R: Union[builtins.bool, None]
+ reveal_type(
+ T.b # R: traitlets.traitlets.Bool[builtins.bool, Union[builtins.bool, builtins.int]]
+ )
+ reveal_type(
+ T.ob # R: traitlets.traitlets.Bool[Union[builtins.bool, None], Union[builtins.bool, builtins.int, None]]
+ )
+ # we would expect this to be Optional[Union[bool, int]], but...
+ t.b = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Union[bool, int]") [assignment]
+ t.b = None # E: Incompatible types in assignment (expression has type "None", variable has type "Union[bool, int]") [assignment]
+
+
+@pytest.mark.mypy_testing
+def mypy_int_typing():
+ class T(HasTraits):
+ i: Int[int, int] = Int(42).tag(sync=True)
+ oi: Int[int | None, int | None] = Int(42, allow_none=True).tag(sync=True)
+
+ t = T()
+ reveal_type(Int(True)) # R: traitlets.traitlets.Int[builtins.int, builtins.int]
+ reveal_type(Int(True).tag(sync=True)) # R: traitlets.traitlets.Int[builtins.int, builtins.int]
+ reveal_type(
+ Int( # R: traitlets.traitlets.Int[Union[builtins.int, None], Union[builtins.int, None]]
+ None, allow_none=True
+ )
+ )
+ reveal_type(
+ Int( # R: traitlets.traitlets.Int[Union[builtins.int, None], Union[builtins.int, None]]
+ None, allow_none=True
+ ).tag(sync=True)
+ )
+ reveal_type(T.i) # R: traitlets.traitlets.Int[builtins.int, builtins.int]
+ reveal_type(t.i) # R: builtins.int
+ reveal_type(t.oi) # R: Union[builtins.int, None]
+ reveal_type(T.i) # R: traitlets.traitlets.Int[builtins.int, builtins.int]
+ reveal_type(
+ T.oi # R: traitlets.traitlets.Int[Union[builtins.int, None], Union[builtins.int, None]]
+ )
+ t.i = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "int") [assignment]
+ t.i = None # E: Incompatible types in assignment (expression has type "None", variable has type "int") [assignment]
+ t.i = 1.2 # E: Incompatible types in assignment (expression has type "float", variable has type "int") [assignment]
+
+
+@pytest.mark.mypy_testing
+def mypy_cint_typing():
+ class T(HasTraits):
+ i = CInt(42).tag(sync=True)
+ oi = CInt(42, allow_none=True).tag(sync=True)
+
+ t = T()
+ reveal_type(CInt(42)) # R: traitlets.traitlets.CInt[builtins.int, Any]
+ reveal_type(CInt(42).tag(sync=True)) # R: traitlets.traitlets.CInt[builtins.int, Any]
+ reveal_type(
+ CInt(None, allow_none=True) # R: traitlets.traitlets.CInt[Union[builtins.int, None], Any]
+ )
+ reveal_type(
+ CInt( # R: traitlets.traitlets.CInt[Union[builtins.int, None], Any]
+ None, allow_none=True
+ ).tag(sync=True)
+ )
+ reveal_type(T.i) # R: traitlets.traitlets.CInt[builtins.int, Any]
+ reveal_type(t.i) # R: builtins.int
+ reveal_type(t.oi) # R: Union[builtins.int, None]
+ reveal_type(T.i) # R: traitlets.traitlets.CInt[builtins.int, Any]
+ reveal_type(T.oi) # R: traitlets.traitlets.CInt[Union[builtins.int, None], Any]
+
+
+@pytest.mark.mypy_testing
+def mypy_tcp_typing():
+ class T(HasTraits):
+ tcp = TCPAddress()
+ otcp = TCPAddress(None, allow_none=True)
+
+ t = T()
+ reveal_type(t.tcp) # R: Tuple[builtins.str, builtins.int]
+ reveal_type(
+ T.tcp # R: traitlets.traitlets.TCPAddress[Tuple[builtins.str, builtins.int], Tuple[builtins.str, builtins.int]]
+ )
+ reveal_type(
+ T.tcp.tag( # R:traitlets.traitlets.TCPAddress[Tuple[builtins.str, builtins.int], Tuple[builtins.str, builtins.int]]
+ sync=True
+ )
+ )
+ reveal_type(t.otcp) # R: Union[Tuple[builtins.str, builtins.int], None]
+ reveal_type(
+ T.otcp # R: traitlets.traitlets.TCPAddress[Union[Tuple[builtins.str, builtins.int], None], Union[Tuple[builtins.str, builtins.int], None]]
+ )
+ reveal_type(
+ T.otcp.tag( # R: traitlets.traitlets.TCPAddress[Union[Tuple[builtins.str, builtins.int], None], Union[Tuple[builtins.str, builtins.int], None]]
+ sync=True
+ )
+ )
+ t.tcp = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Tuple[str, int]") [assignment]
+ t.otcp = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Optional[Tuple[str, int]]") [assignment]
+ t.tcp = None # E: Incompatible types in assignment (expression has type "None", variable has type "Tuple[str, int]") [assignment]
+
+
+@pytest.mark.mypy_testing
+def mypy_instance_typing():
+ class T(HasTraits):
+ inst = Instance(Foo)
+ oinst = Instance(Foo, allow_none=True)
+ oinst_string = Instance("Foo", allow_none=True)
+
+ t = T()
+ reveal_type(t.inst) # R: tests.test_typing.Foo
+ reveal_type(T.inst) # R: traitlets.traitlets.Instance[tests.test_typing.Foo]
+ reveal_type(T.inst.tag(sync=True)) # R: traitlets.traitlets.Instance[tests.test_typing.Foo]
+ reveal_type(t.oinst) # R: Union[tests.test_typing.Foo, None]
+ reveal_type(t.oinst_string) # R: Union[Any, None]
+ reveal_type(T.oinst) # R: traitlets.traitlets.Instance[Union[tests.test_typing.Foo, None]]
+ reveal_type(
+ T.oinst.tag( # R: traitlets.traitlets.Instance[Union[tests.test_typing.Foo, None]]
+ sync=True
+ )
+ )
+ t.inst = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Foo") [assignment]
+ t.oinst = "foo" # E: Incompatible types in assignment (expression has type "str", variable has type "Optional[Foo]") [assignment]
+ t.inst = None # E: Incompatible types in assignment (expression has type "None", variable has type "Foo") [assignment]
diff --git a/contrib/python/traitlets/py3/tests/utils/__init__.py b/contrib/python/traitlets/py3/tests/utils/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/utils/__init__.py
diff --git a/contrib/python/traitlets/py3/tests/utils/test_bunch.py b/contrib/python/traitlets/py3/tests/utils/test_bunch.py
new file mode 100644
index 00000000000..223124d7d5e
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/utils/test_bunch.py
@@ -0,0 +1,16 @@
+from traitlets.utils.bunch import Bunch
+
+
+def test_bunch():
+ b = Bunch(x=5, y=10)
+ assert "y" in b
+ assert "x" in b
+ assert b.x == 5
+ b["a"] = "hi"
+ assert b.a == "hi"
+
+
+def test_bunch_dir():
+ b = Bunch(x=5, y=10)
+ assert "x" in dir(b)
+ assert "keys" in dir(b)
diff --git a/contrib/python/traitlets/py3/tests/utils/test_decorators.py b/contrib/python/traitlets/py3/tests/utils/test_decorators.py
new file mode 100644
index 00000000000..d6bf8414e5a
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/utils/test_decorators.py
@@ -0,0 +1,137 @@
+from inspect import Parameter, signature
+from unittest import TestCase
+
+from traitlets import HasTraits, Int, Unicode
+from traitlets.utils.decorators import signature_has_traits
+
+
+class TestExpandSignature(TestCase):
+ def test_no_init(self):
+ @signature_has_traits
+ class Foo(HasTraits):
+ number1 = Int()
+ number2 = Int()
+ value = Unicode("Hello")
+
+ parameters = signature(Foo).parameters
+ parameter_names = list(parameters)
+
+ self.assertIs(parameters["args"].kind, Parameter.VAR_POSITIONAL)
+ self.assertEqual("args", parameter_names[0])
+
+ self.assertIs(parameters["number1"].kind, Parameter.KEYWORD_ONLY)
+ self.assertIs(parameters["number2"].kind, Parameter.KEYWORD_ONLY)
+ self.assertIs(parameters["value"].kind, Parameter.KEYWORD_ONLY)
+
+ self.assertIs(parameters["kwargs"].kind, Parameter.VAR_KEYWORD)
+ self.assertEqual("kwargs", parameter_names[-1])
+
+ f = Foo(number1=32, value="World")
+ self.assertEqual(f.number1, 32)
+ self.assertEqual(f.number2, 0)
+ self.assertEqual(f.value, "World")
+
+ def test_partial_init(self):
+ @signature_has_traits
+ class Foo(HasTraits):
+ number1 = Int()
+ number2 = Int()
+ value = Unicode("Hello")
+
+ def __init__(self, arg1, **kwargs):
+ self.arg1 = arg1
+
+ super().__init__(**kwargs)
+
+ parameters = signature(Foo).parameters
+ parameter_names = list(parameters)
+
+ self.assertIs(parameters["arg1"].kind, Parameter.POSITIONAL_OR_KEYWORD)
+ self.assertEqual("arg1", parameter_names[0])
+
+ self.assertIs(parameters["number1"].kind, Parameter.KEYWORD_ONLY)
+ self.assertIs(parameters["number2"].kind, Parameter.KEYWORD_ONLY)
+ self.assertIs(parameters["value"].kind, Parameter.KEYWORD_ONLY)
+
+ self.assertIs(parameters["kwargs"].kind, Parameter.VAR_KEYWORD)
+ self.assertEqual("kwargs", parameter_names[-1])
+
+ f = Foo(1, number1=32, value="World")
+ self.assertEqual(f.arg1, 1)
+ self.assertEqual(f.number1, 32)
+ self.assertEqual(f.number2, 0)
+ self.assertEqual(f.value, "World")
+
+ def test_duplicate_init(self):
+ @signature_has_traits
+ class Foo(HasTraits):
+ number1 = Int()
+ number2 = Int()
+
+ def __init__(self, number1, **kwargs):
+ self.test = number1
+
+ super().__init__(number1=number1, **kwargs)
+
+ parameters = signature(Foo).parameters
+ parameter_names = list(parameters)
+
+ self.assertListEqual(parameter_names, ["number1", "number2", "kwargs"])
+
+ f = Foo(number1=32, number2=36)
+ self.assertEqual(f.test, 32)
+ self.assertEqual(f.number1, 32)
+ self.assertEqual(f.number2, 36)
+
+ def test_full_init(self):
+ @signature_has_traits
+ class Foo(HasTraits):
+ number1 = Int()
+ number2 = Int()
+ value = Unicode("Hello")
+
+ def __init__(self, arg1, arg2=None, *pos_args, **kw_args):
+ self.arg1 = arg1
+ self.arg2 = arg2
+ self.pos_args = pos_args
+ self.kw_args = kw_args
+
+ super().__init__(*pos_args, **kw_args)
+
+ parameters = signature(Foo).parameters
+ parameter_names = list(parameters)
+
+ self.assertIs(parameters["arg1"].kind, Parameter.POSITIONAL_OR_KEYWORD)
+ self.assertEqual("arg1", parameter_names[0])
+
+ self.assertIs(parameters["arg2"].kind, Parameter.POSITIONAL_OR_KEYWORD)
+ self.assertEqual("arg2", parameter_names[1])
+
+ self.assertIs(parameters["pos_args"].kind, Parameter.VAR_POSITIONAL)
+ self.assertEqual("pos_args", parameter_names[2])
+
+ self.assertIs(parameters["number1"].kind, Parameter.KEYWORD_ONLY)
+ self.assertIs(parameters["number2"].kind, Parameter.KEYWORD_ONLY)
+ self.assertIs(parameters["value"].kind, Parameter.KEYWORD_ONLY)
+
+ self.assertIs(parameters["kw_args"].kind, Parameter.VAR_KEYWORD)
+ self.assertEqual("kw_args", parameter_names[-1])
+
+ f = Foo(1, 3, 45, "hey", number1=32, value="World")
+ self.assertEqual(f.arg1, 1)
+ self.assertEqual(f.arg2, 3)
+ self.assertTupleEqual(f.pos_args, (45, "hey"))
+ self.assertEqual(f.number1, 32)
+ self.assertEqual(f.number2, 0)
+ self.assertEqual(f.value, "World")
+
+ def test_no_kwargs(self):
+ with self.assertRaises(RuntimeError):
+
+ @signature_has_traits
+ class Foo(HasTraits):
+ number1 = Int()
+ number2 = Int()
+
+ def __init__(self, arg1, arg2=None):
+ pass
diff --git a/contrib/python/traitlets/py3/tests/utils/test_importstring.py b/contrib/python/traitlets/py3/tests/utils/test_importstring.py
new file mode 100644
index 00000000000..8ce28add41e
--- /dev/null
+++ b/contrib/python/traitlets/py3/tests/utils/test_importstring.py
@@ -0,0 +1,26 @@
+# Copyright (c) IPython Development Team.
+# Distributed under the terms of the Modified BSD License.
+#
+# Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
+# also under the terms of the Modified BSD License.
+"""Tests for traitlets.utils.importstring."""
+
+import os
+from unittest import TestCase
+
+from traitlets.utils.importstring import import_item
+
+
+class TestImportItem(TestCase):
+ def test_import_unicode(self):
+ self.assertIs(os, import_item("os"))
+ self.assertIs(os.path, import_item("os.path"))
+ self.assertIs(os.path.join, import_item("os.path.join"))
+
+ def test_bad_input(self):
+ class NotAString:
+ pass
+
+ msg = "import_item accepts strings, not '%s'." % NotAString
+ with self.assertRaisesRegex(TypeError, msg):
+ import_item(NotAString()) # type:ignore[arg-type]
diff --git a/contrib/python/traitlets/py3/tests/ya.make b/contrib/python/traitlets/py3/tests/ya.make
index 6a5cd7cf463..6ffd29993d5 100644
--- a/contrib/python/traitlets/py3/tests/ya.make
+++ b/contrib/python/traitlets/py3/tests/ya.make
@@ -1,20 +1,27 @@
PY3TEST()
PEERDIR(
+ contrib/python/argcomplete
contrib/python/traitlets
+ contrib/python/pytest-mock
)
-SRCDIR(contrib/python/traitlets/py3/traitlets)
-
TEST_SRCS(
- config/tests/test_application.py
- config/tests/test_configurable.py
- config/tests/test_loader.py
- tests/test_traitlets.py
- tests/test_traitlets_enum.py
- utils/tests/test_bunch.py
- utils/tests/test_decorators.py
- utils/tests/test_importstring.py
+ __init__.py
+ _warnings.py
+ config/__init__.py
+ config/test_application.py
+ config/test_argcomplete.py
+ config/test_configurable.py
+ config/test_loader.py
+ test_traitlets.py
+ test_traitlets_docstring.py
+ test_traitlets_enum.py
+ test_typing.py
+ utils/__init__.py
+ utils/test_bunch.py
+ utils/test_decorators.py
+ utils/test_importstring.py
)
NO_LINT()