diff options
| author | robot-contrib <[email protected]> | 2023-10-19 17:11:31 +0300 |
|---|---|---|
| committer | robot-contrib <[email protected]> | 2023-10-19 18:26:04 +0300 |
| commit | b9fe236a503791a3a7b37d4ef5f466225218996c (patch) | |
| tree | c2f80019399b393ddf0450d0f91fc36478af8bea /contrib/python/traitlets/py3/tests/test_traitlets.py | |
| parent | 44dd27d0a2ae37c80d97a95581951d1d272bd7df (diff) | |
Update contrib/python/traitlets/py3 to 5.11.2
Diffstat (limited to 'contrib/python/traitlets/py3/tests/test_traitlets.py')
| -rw-r--r-- | contrib/python/traitlets/py3/tests/test_traitlets.py | 3141 |
1 files changed, 3141 insertions, 0 deletions
diff --git a/contrib/python/traitlets/py3/tests/test_traitlets.py b/contrib/python/traitlets/py3/tests/test_traitlets.py new file mode 100644 index 00000000000..62fa726f19b --- /dev/null +++ b/contrib/python/traitlets/py3/tests/test_traitlets.py @@ -0,0 +1,3141 @@ +"""Tests for traitlets.traitlets.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +# +# Adapted from enthought.traits, Copyright (c) Enthought, Inc., +# also under the terms of the Modified BSD License. + +import pickle +import re +import typing as t +from unittest import TestCase + +import pytest + +from traitlets import ( + All, + Any, + BaseDescriptor, + Bool, + Bytes, + Callable, + CBytes, + CFloat, + CInt, + CLong, + Complex, + CRegExp, + CUnicode, + Dict, + DottedObjectName, + Enum, + Float, + ForwardDeclaredInstance, + ForwardDeclaredType, + HasDescriptors, + HasTraits, + Instance, + Int, + Integer, + List, + Long, + MetaHasTraits, + ObjectName, + Set, + TCPAddress, + This, + TraitError, + TraitType, + Tuple, + Type, + Undefined, + Unicode, + Union, + default, + directional_link, + link, + observe, + observe_compat, + traitlets, + validate, +) +from traitlets.utils import cast_unicode + +from ._warnings import expected_warnings + + +def change_dict(*ordered_values): + change_names = ("name", "old", "new", "owner", "type") + return dict(zip(change_names, ordered_values)) + + +# ----------------------------------------------------------------------------- +# Helper classes for testing +# ----------------------------------------------------------------------------- + + +class HasTraitsStub(HasTraits): + def notify_change(self, change): + self._notify_name = change["name"] + self._notify_old = change["old"] + self._notify_new = change["new"] + self._notify_type = change["type"] + + +class CrossValidationStub(HasTraits): + _cross_validation_lock = False + + +# ----------------------------------------------------------------------------- +# Test classes +# ----------------------------------------------------------------------------- + + +class TestTraitType(TestCase): + def test_get_undefined(self): + class A(HasTraits): + a = TraitType + + a = A() + assert a.a is Undefined # type:ignore + + def test_set(self): + class A(HasTraitsStub): + a = TraitType + + a = A() + a.a = 10 # type:ignore + self.assertEqual(a.a, 10) + self.assertEqual(a._notify_name, "a") + self.assertEqual(a._notify_old, Undefined) + self.assertEqual(a._notify_new, 10) + + def test_validate(self): + class MyTT(TraitType[int, int]): + def validate(self, inst, value): + return -1 + + class A(HasTraitsStub): + tt = MyTT + + a = A() + a.tt = 10 # type:ignore + self.assertEqual(a.tt, -1) + + a = A(tt=11) + self.assertEqual(a.tt, -1) + + def test_default_validate(self): + class MyIntTT(TraitType[int, int]): + def validate(self, obj, value): + if isinstance(value, int): + return value + self.error(obj, value) + + class A(HasTraits): + tt = MyIntTT(10) + + a = A() + self.assertEqual(a.tt, 10) + + # Defaults are validated when the HasTraits is instantiated + class B(HasTraits): + tt = MyIntTT("bad default") + + self.assertRaises(TraitError, getattr, B(), "tt") + + def test_info(self): + class A(HasTraits): + tt = TraitType + + a = A() + self.assertEqual(A.tt.info(), "any value") # type:ignore + + def test_error(self): + class A(HasTraits): + tt = TraitType[int, int]() + + a = A() + self.assertRaises(TraitError, A.tt.error, a, 10) + + def test_deprecated_dynamic_initializer(self): + class A(HasTraits): + x = Int(10) + + def _x_default(self): + return 11 + + class B(A): + x = Int(20) + + class C(A): + def _x_default(self): + return 21 + + a = A() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + b = B() + self.assertEqual(b.x, 20) + self.assertEqual(b._trait_values, {"x": 20}) + c = C() + self.assertEqual(c._trait_values, {}) + self.assertEqual(c.x, 21) + self.assertEqual(c._trait_values, {"x": 21}) + # Ensure that the base class remains unmolested when the _default + # initializer gets overridden in a subclass. + a = A() + c = C() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + + def test_deprecated_method_warnings(self): + with expected_warnings([]): + + class ShouldntWarn(HasTraits): + x = Integer() + + @default("x") + def _x_default(self): + return 10 + + @validate("x") + def _x_validate(self, proposal): + return proposal.value + + @observe("x") + def _x_changed(self, change): + pass + + obj = ShouldntWarn() + obj.x = 5 + + assert obj.x == 5 + + with expected_warnings(["@validate", "@observe"]) as w: + + class ShouldWarn(HasTraits): + x = Integer() + + def _x_default(self): + return 10 + + def _x_validate(self, value, _): + return value + + def _x_changed(self): + pass + + obj = ShouldWarn() # type:ignore + obj.x = 5 + + assert obj.x == 5 + + def test_dynamic_initializer(self): + class A(HasTraits): + x = Int(10) + + @default("x") + def _default_x(self): + return 11 + + class B(A): + x = Int(20) + + class C(A): + @default("x") + def _default_x(self): + return 21 + + a = A() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + b = B() + self.assertEqual(b.x, 20) + self.assertEqual(b._trait_values, {"x": 20}) + c = C() + self.assertEqual(c._trait_values, {}) + self.assertEqual(c.x, 21) + self.assertEqual(c._trait_values, {"x": 21}) + # Ensure that the base class remains unmolested when the _default + # initializer gets overridden in a subclass. + a = A() + c = C() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + + def test_tag_metadata(self): + class MyIntTT(TraitType[int, int]): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10).tag(b=3, c=4) + self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4}) + + def test_metadata_localized_instance(self): + class MyIntTT(TraitType[int, int]): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10) + b = MyIntTT(10) + a.metadata["c"] = 3 + # make sure that changing a's metadata didn't change b's metadata + self.assertNotIn("c", b.metadata) + + def test_union_metadata(self): + class Foo(HasTraits): + bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti="b")).tag(ti="a") + + foo = Foo() + # At this point, no value has been set for bar, so value-specific + # is not set. + self.assertEqual(foo.trait_metadata("bar", "ta"), None) + self.assertEqual(foo.trait_metadata("bar", "ti"), "a") + foo.bar = {} + self.assertEqual(foo.trait_metadata("bar", "ta"), 2) + self.assertEqual(foo.trait_metadata("bar", "ti"), "b") + foo.bar = 1 + self.assertEqual(foo.trait_metadata("bar", "ta"), 1) + self.assertEqual(foo.trait_metadata("bar", "ti"), "a") + + def test_union_default_value(self): + class Foo(HasTraits): + bar = Union([Dict(), Int()], default_value=1) + + foo = Foo() + self.assertEqual(foo.bar, 1) + + def test_union_validation_priority(self): + class Foo(HasTraits): + bar = Union([CInt(), Unicode()]) + + foo = Foo() + foo.bar = "1" + # validation in order of the TraitTypes given + self.assertEqual(foo.bar, 1) + + def test_union_trait_default_value(self): + class Foo(HasTraits): + bar = Union([Dict(), Int()]) + + self.assertEqual(Foo().bar, {}) + + def test_deprecated_metadata_access(self): + class MyIntTT(TraitType[int, int]): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10) + with expected_warnings(["use the instance .metadata dictionary directly"] * 2): + a.set_metadata("key", "value") + v = a.get_metadata("key") + self.assertEqual(v, "value") + with expected_warnings(["use the instance .help string directly"] * 2): + a.set_metadata("help", "some help") + v = a.get_metadata("help") + self.assertEqual(v, "some help") + + def test_trait_types_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Int + + def test_trait_types_list_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = List(Int) + + def test_trait_types_tuple_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Tuple(Int) + + def test_trait_types_dict_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Dict(Int) + + +class TestHasDescriptorsMeta(TestCase): + def test_metaclass(self): + self.assertEqual(type(HasTraits), MetaHasTraits) + + class A(HasTraits): + a = Int() + + a = A() + self.assertEqual(type(a.__class__), MetaHasTraits) + self.assertEqual(a.a, 0) + a.a = 10 + self.assertEqual(a.a, 10) + + class B(HasTraits): + b = Int() + + b = B() + self.assertEqual(b.b, 0) + b.b = 10 + self.assertEqual(b.b, 10) + + class C(HasTraits): + c = Int(30) + + c = C() + self.assertEqual(c.c, 30) + c.c = 10 + self.assertEqual(c.c, 10) + + def test_this_class(self): + class A(HasTraits): + t = This["A"]() + tt = This["A"]() + + class B(A): + tt = This["A"]() + ttt = This["A"]() + + self.assertEqual(A.t.this_class, A) + self.assertEqual(B.t.this_class, A) + self.assertEqual(B.tt.this_class, B) + self.assertEqual(B.ttt.this_class, B) + + +class TestHasDescriptors(TestCase): + def test_setup_instance(self): + class FooDescriptor(BaseDescriptor): + def instance_init(self, inst): + foo = inst.foo # instance should have the attr + + class HasFooDescriptors(HasDescriptors): + fd = FooDescriptor() + + def setup_instance(self, *args, **kwargs): + self.foo = kwargs.get("foo", None) + super().setup_instance(*args, **kwargs) + + hfd = HasFooDescriptors(foo="bar") + + +class TestHasTraitsNotify(TestCase): + def setUp(self): + self._notify1 = [] + self._notify2 = [] + + def notify1(self, name, old, new): + self._notify1.append((name, old, new)) + + def notify2(self, name, old, new): + self._notify2.append((name, old, new)) + + def test_notify_all(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.on_trait_change(self.notify1) + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.b = 0.0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in self._notify1) + a.b = 10.0 + self.assertTrue(("b", 0.0, 10.0) in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + self.assertRaises(TraitError, setattr, a, "b", "bad string") + self._notify1 = [] + a.on_trait_change(self.notify1, remove=True) + a.a = 20 + a.b = 20.0 + self.assertEqual(len(self._notify1), 0) + + def test_notify_one(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.on_trait_change(self.notify1, "a") + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + + def test_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + self.assertEqual(b.a, 0) + self.assertEqual(b.b, 0.0) + b.a = 100 + b.b = 100.0 + self.assertEqual(b.a, 100) + self.assertEqual(b.b, 100.0) + + def test_notify_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + b.on_trait_change(self.notify1, "a") + b.on_trait_change(self.notify2, "b") + b.a = 0 + b.b = 0.0 + self.assertEqual(len(self._notify1), 0) + self.assertEqual(len(self._notify2), 0) + b.a = 10 + b.b = 10.0 + self.assertTrue(("a", 0, 10) in self._notify1) + self.assertTrue(("b", 0.0, 10.0) in self._notify2) + + def test_static_notify(self): + class A(HasTraits): + a = Int() + _notify1 = [] + + def _a_changed(self, name, old, new): + self._notify1.append((name, old, new)) + + a = A() + a.a = 0 + # This is broken!!! + self.assertEqual(len(a._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in a._notify1) + + class B(A): + b = Float() + _notify2 = [] + + def _b_changed(self, name, old, new): + self._notify2.append((name, old, new)) + + b = B() + b.a = 10 + b.b = 10.0 + self.assertTrue(("a", 0, 10) in b._notify1) + self.assertTrue(("b", 0.0, 10.0) in b._notify2) + + def test_notify_args(self): + def callback0(): + self.cb = () + + def callback1(name): + self.cb = (name,) # type:ignore + + def callback2(name, new): + self.cb = (name, new) # type:ignore + + def callback3(name, old, new): + self.cb = (name, old, new) # type:ignore + + def callback4(name, old, new, obj): + self.cb = (name, old, new, obj) # type:ignore + + class A(HasTraits): + a = Int() + + a = A() + a.on_trait_change(callback0, "a") + a.a = 10 + self.assertEqual(self.cb, ()) + a.on_trait_change(callback0, "a", remove=True) + + a.on_trait_change(callback1, "a") + a.a = 100 + self.assertEqual(self.cb, ("a",)) + a.on_trait_change(callback1, "a", remove=True) + + a.on_trait_change(callback2, "a") + a.a = 1000 + self.assertEqual(self.cb, ("a", 1000)) + a.on_trait_change(callback2, "a", remove=True) + + a.on_trait_change(callback3, "a") + a.a = 10000 + self.assertEqual(self.cb, ("a", 1000, 10000)) + a.on_trait_change(callback3, "a", remove=True) + + a.on_trait_change(callback4, "a") + a.a = 100000 + self.assertEqual(self.cb, ("a", 10000, 100000, a)) + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) + a.on_trait_change(callback4, "a", remove=True) + + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) + + def test_notify_only_once(self): + class A(HasTraits): + listen_to = ["a"] + + a = Int(0) + b = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.on_trait_change(self.listener1, ["a"]) + + def listener1(self, name, old, new): + self.b += 1 + + class B(A): + c = 0 + d = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.on_trait_change(self.listener2) + + def listener2(self, name, old, new): + self.c += 1 + + def _a_changed(self, name, old, new): + self.d += 1 + + b = B() + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + + +class TestObserveDecorator(TestCase): + def setUp(self): + self._notify1 = [] + self._notify2 = [] + + def notify1(self, change): + self._notify1.append(change) + + def notify2(self, change): + self._notify2.append(change) + + def test_notify_all(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.observe(self.notify1) + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.b = 0.0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in self._notify1) + a.b = 10.0 + change = change_dict("b", 0.0, 10.0, a, "change") + self.assertTrue(change in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + self.assertRaises(TraitError, setattr, a, "b", "bad string") + self._notify1 = [] + a.unobserve(self.notify1) + a.a = 20 + a.b = 20.0 + self.assertEqual(len(self._notify1), 0) + + def test_notify_one(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.observe(self.notify1, "a") + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + + def test_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + self.assertEqual(b.a, 0) + self.assertEqual(b.b, 0.0) + b.a = 100 + b.b = 100.0 + self.assertEqual(b.a, 100) + self.assertEqual(b.b, 100.0) + + def test_notify_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + b.observe(self.notify1, "a") + b.observe(self.notify2, "b") + b.a = 0 + b.b = 0.0 + self.assertEqual(len(self._notify1), 0) + self.assertEqual(len(self._notify2), 0) + b.a = 10 + b.b = 10.0 + change = change_dict("a", 0, 10, b, "change") + self.assertTrue(change in self._notify1) + change = change_dict("b", 0.0, 10.0, b, "change") + self.assertTrue(change in self._notify2) + + def test_static_notify(self): + class A(HasTraits): + a = Int() + b = Int() + _notify1 = [] + _notify_any = [] + + @observe("a") + def _a_changed(self, change): + self._notify1.append(change) + + @observe(All) + def _any_changed(self, change): + self._notify_any.append(change) + + a = A() + a.a = 0 + self.assertEqual(len(a._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in a._notify1) + a.b = 1 + self.assertEqual(len(a._notify_any), 2) + change = change_dict("b", 0, 1, a, "change") + self.assertTrue(change in a._notify_any) + + class B(A): + b = Float() # type:ignore + _notify2 = [] + + @observe("b") + def _b_changed(self, change): + self._notify2.append(change) + + b = B() + b.a = 10 + b.b = 10.0 # type:ignore + change = change_dict("a", 0, 10, b, "change") + self.assertTrue(change in b._notify1) + change = change_dict("b", 0.0, 10.0, b, "change") + self.assertTrue(change in b._notify2) + + def test_notify_args(self): + def callback0(): + self.cb = () + + def callback1(change): + self.cb = change + + class A(HasTraits): + a = Int() + + a = A() + a.on_trait_change(callback0, "a") + a.a = 10 + self.assertEqual(self.cb, ()) + a.unobserve(callback0, "a") + + a.observe(callback1, "a") + a.a = 100 + change = change_dict("a", 10, 100, a, "change") + self.assertEqual(self.cb, change) + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) + a.unobserve(callback1, "a") + + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) + + def test_notify_only_once(self): + class A(HasTraits): + listen_to = ["a"] + + a = Int(0) + b = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.observe(self.listener1, ["a"]) + + def listener1(self, change): + self.b += 1 + + class B(A): + c = 0 + d = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.observe(self.listener2) + + def listener2(self, change): + self.c += 1 + + @observe("a") + def _a_changed(self, change): + self.d += 1 + + b = B() + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + + +class TestHasTraits(TestCase): + def test_trait_names(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertEqual(sorted(a.trait_names()), ["f", "i"]) + self.assertEqual(sorted(A.class_trait_names()), ["f", "i"]) + self.assertTrue(a.has_trait("f")) + self.assertFalse(a.has_trait("g")) + + def test_trait_has_value(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertFalse(a.trait_has_value("f")) + self.assertFalse(a.trait_has_value("g")) + a.i = 1 + a.f + self.assertTrue(a.trait_has_value("i")) + self.assertTrue(a.trait_has_value("f")) + + def test_trait_metadata_deprecated(self): + with expected_warnings([r"metadata should be set using the \.tag\(\) method"]): + + class A(HasTraits): + i = Int(config_key="MY_VALUE") + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") + + def test_trait_metadata(self): + class A(HasTraits): + i = Int().tag(config_key="MY_VALUE") + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") + + def test_trait_metadata_default(self): + class A(HasTraits): + i = Int() + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), None) + self.assertEqual(a.trait_metadata("i", "config_key", "default"), "default") + + def test_traits(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f)) + self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f)) + + def test_traits_metadata(self): + class A(HasTraits): + i = Int().tag(config_key="VALUE1", other_thing="VALUE2") + f = Float().tag(config_key="VALUE3", other_thing="VALUE2") + j = Int(0) + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) + traits = a.traits(config_key="VALUE1", other_thing="VALUE2") + self.assertEqual(traits, dict(i=A.i)) + + # This passes, but it shouldn't because I am replicating a bug in + # traits. + traits = a.traits(config_key=lambda v: True) + self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) + + def test_traits_metadata_deprecated(self): + with expected_warnings([r"metadata should be set using the \.tag\(\) method"] * 2): + + class A(HasTraits): + i = Int(config_key="VALUE1", other_thing="VALUE2") + f = Float(config_key="VALUE3", other_thing="VALUE2") + j = Int(0) + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) + traits = a.traits(config_key="VALUE1", other_thing="VALUE2") + self.assertEqual(traits, dict(i=A.i)) + + # This passes, but it shouldn't because I am replicating a bug in + # traits. + traits = a.traits(config_key=lambda v: True) + self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) + + def test_init(self): + class A(HasTraits): + i = Int() + x = Float() + + a = A(i=1, x=10.0) + self.assertEqual(a.i, 1) + self.assertEqual(a.x, 10.0) + + def test_positional_args(self): + class A(HasTraits): + i = Int(0) + + def __init__(self, i): + super().__init__() + self.i = i + + a = A(5) + self.assertEqual(a.i, 5) + # should raise TypeError if no positional arg given + self.assertRaises(TypeError, A) + + +# ----------------------------------------------------------------------------- +# Tests for specific trait types +# ----------------------------------------------------------------------------- + + +class TestType(TestCase): + def test_default(self): + class B: + pass + + class A(HasTraits): + klass = Type(allow_none=True) + + a = A() + self.assertEqual(a.klass, object) + + a.klass = B + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", 10) + + def test_default_options(self): + class B: + pass + + class C(B): + pass + + class A(HasTraits): + # Different possible combinations of options for default_value + # and klass. default_value=None is only valid with allow_none=True. + k1 = Type() + k2 = Type(None, allow_none=True) + k3 = Type(B) + k4 = Type(klass=B) + k5 = Type(default_value=None, klass=B, allow_none=True) + k6 = Type(default_value=C, klass=B) + + self.assertIs(A.k1.default_value, object) + self.assertIs(A.k1.klass, object) + self.assertIs(A.k2.default_value, None) + self.assertIs(A.k2.klass, object) + self.assertIs(A.k3.default_value, B) + self.assertIs(A.k3.klass, B) + self.assertIs(A.k4.default_value, B) + self.assertIs(A.k4.klass, B) + self.assertIs(A.k5.default_value, None) + self.assertIs(A.k5.klass, B) + self.assertIs(A.k6.default_value, C) + self.assertIs(A.k6.klass, B) + + a = A() + self.assertIs(a.k1, object) + self.assertIs(a.k2, None) + self.assertIs(a.k3, B) + self.assertIs(a.k4, B) + self.assertIs(a.k5, None) + self.assertIs(a.k6, C) + + def test_value(self): + class B: + pass + + class C: + pass + + class A(HasTraits): + klass = Type(B) + + a = A() + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", C) + self.assertRaises(TraitError, setattr, a, "klass", object) + a.klass = B + + def test_allow_none(self): + class B: + pass + + class C(B): + pass + + class A(HasTraits): + klass = Type(B) + + a = A() + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", None) + a.klass = C + self.assertEqual(a.klass, C) + + def test_validate_klass(self): + class A(HasTraits): + klass = Type("no strings allowed") + + self.assertRaises(ImportError, A) + + class A(HasTraits): # type:ignore + klass = Type("rub.adub.Duck") + + self.assertRaises(ImportError, A) + + def test_validate_default(self): + class B: + pass + + class A(HasTraits): + klass = Type("bad default", B) + + self.assertRaises(ImportError, A) + + class C(HasTraits): + klass = Type(None, B) + + self.assertRaises(TraitError, getattr, C(), "klass") + + def test_str_klass(self): + class A(HasTraits): + klass = Type("traitlets.config.Config") + + from traitlets.config import Config + + a = A() + a.klass = Config + self.assertEqual(a.klass, Config) + + self.assertRaises(TraitError, setattr, a, "klass", 10) + + def test_set_str_klass(self): + class A(HasTraits): + klass = Type() + + a = A(klass="traitlets.config.Config") + from traitlets.config import Config + + self.assertEqual(a.klass, Config) + + +class TestInstance(TestCase): + def test_basic(self): + class Foo: + pass + + class Bar(Foo): + pass + + class Bah: + pass + + class A(HasTraits): + inst = Instance(Foo, allow_none=True) + + a = A() + self.assertTrue(a.inst is None) + a.inst = Foo() + self.assertTrue(isinstance(a.inst, Foo)) + a.inst = Bar() + self.assertTrue(isinstance(a.inst, Foo)) + self.assertRaises(TraitError, setattr, a, "inst", Foo) + self.assertRaises(TraitError, setattr, a, "inst", Bar) + self.assertRaises(TraitError, setattr, a, "inst", Bah()) + + def test_default_klass(self): + class Foo: + pass + + class Bar(Foo): + pass + + class Bah: + pass + + class FooInstance(Instance[Foo]): + klass = Foo + + class A(HasTraits): + inst = FooInstance(allow_none=True) + + a = A() + self.assertTrue(a.inst is None) + a.inst = Foo() + self.assertTrue(isinstance(a.inst, Foo)) + a.inst = Bar() + self.assertTrue(isinstance(a.inst, Foo)) + self.assertRaises(TraitError, setattr, a, "inst", Foo) + self.assertRaises(TraitError, setattr, a, "inst", Bar) + self.assertRaises(TraitError, setattr, a, "inst", Bah()) + + def test_unique_default_value(self): + class Foo: + pass + + class A(HasTraits): + inst = Instance(Foo, (), {}) + + a = A() + b = A() + self.assertTrue(a.inst is not b.inst) + + def test_args_kw(self): + class Foo: + def __init__(self, c): + self.c = c + + class Bar: + pass + + class Bah: + def __init__(self, c, d): + self.c = c + self.d = d + + class A(HasTraits): + inst = Instance(Foo, (10,)) + + a = A() + self.assertEqual(a.inst.c, 10) + + class B(HasTraits): + inst = Instance(Bah, args=(10,), kw=dict(d=20)) + + b = B() + self.assertEqual(b.inst.c, 10) + self.assertEqual(b.inst.d, 20) + + class C(HasTraits): + inst = Instance(Foo, allow_none=True) + + c = C() + self.assertTrue(c.inst is None) + + def test_bad_default(self): + class Foo: + pass + + class A(HasTraits): + inst = Instance(Foo) + + a = A() + with self.assertRaises(TraitError): + a.inst + + def test_instance(self): + class Foo: + pass + + def inner(): + class A(HasTraits): + inst = Instance(Foo()) # type:ignore + + self.assertRaises(TraitError, inner) + + +class TestThis(TestCase): + def test_this_class(self): + class Foo(HasTraits): + this = This["Foo"]() + + f = Foo() + self.assertEqual(f.this, None) + g = Foo() + f.this = g + self.assertEqual(f.this, g) + self.assertRaises(TraitError, setattr, f, "this", 10) + + def test_this_inst(self): + class Foo(HasTraits): + this = This["Foo"]() + + f = Foo() + f.this = Foo() + self.assertTrue(isinstance(f.this, Foo)) + + def test_subclass(self): + class Foo(HasTraits): + t = This["Foo"]() + + class Bar(Foo): + pass + + f = Foo() + b = Bar() + f.t = b + b.t = f + self.assertEqual(f.t, b) + self.assertEqual(b.t, f) + + def test_subclass_override(self): + class Foo(HasTraits): + t = This["Foo"]() + + class Bar(Foo): + t = This() + + f = Foo() + b = Bar() + f.t = b + self.assertEqual(f.t, b) + self.assertRaises(TraitError, setattr, b, "t", f) + + def test_this_in_container(self): + class Tree(HasTraits): + value = Unicode() + leaves = List(This()) + + tree = Tree(value="foo", leaves=[Tree(value="bar"), Tree(value="buzz")]) + + with self.assertRaises(TraitError): + tree.leaves = [1, 2] + + +class TraitTestBase(TestCase): + """A best testing class for basic trait types.""" + + def assign(self, value): + self.obj.value = value # type:ignore + + def coerce(self, value): + return value + + def test_good_values(self): + if hasattr(self, "_good_values"): + for value in self._good_values: + self.assign(value) + self.assertEqual(self.obj.value, self.coerce(value)) # type:ignore + + def test_bad_values(self): + if hasattr(self, "_bad_values"): + for value in self._bad_values: + try: + self.assertRaises(TraitError, self.assign, value) + except AssertionError: + assert False, value + + def test_default_value(self): + if hasattr(self, "_default_value"): + self.assertEqual(self._default_value, self.obj.value) # type:ignore + + def test_allow_none(self): + if ( + hasattr(self, "_bad_values") + and hasattr(self, "_good_values") + and None in self._bad_values + ): + trait = self.obj.traits()["value"] # type:ignore + try: + trait.allow_none = True + self._bad_values.remove(None) + # skip coerce. Allow None casts None to None. + self.assign(None) + self.assertEqual(self.obj.value, None) # type:ignore + self.test_good_values() + self.test_bad_values() + finally: + # tear down + trait.allow_none = False + self._bad_values.append(None) + + def tearDown(self): + # restore default value after tests, if set + if hasattr(self, "_default_value"): + self.obj.value = self._default_value # type:ignore + + +class AnyTrait(HasTraits): + value = Any() + + +class AnyTraitTest(TraitTestBase): + obj = AnyTrait() + + _default_value = None + _good_values = [10.0, "ten", [10], {"ten": 10}, (10,), None, 1j] + _bad_values: t.Any = [] + + +class UnionTrait(HasTraits): + value = Union([Type(), Bool()]) + + +class UnionTraitTest(TraitTestBase): + obj = UnionTrait(value="traitlets.config.Config") + _good_values = [int, float, True] + _bad_values = [[], (0,), 1j] + + +class CallableTrait(HasTraits): + value = Callable() + + +class CallableTraitTest(TraitTestBase): + obj = CallableTrait(value=lambda x: type(x)) + _good_values = [int, sorted, lambda x: print(x)] + _bad_values = [[], 1, ""] + + +class OrTrait(HasTraits): + value = Bool() | Unicode() + + +class OrTraitTest(TraitTestBase): + obj = OrTrait() + _good_values = [True, False, "ten"] + _bad_values = [[], (0,), 1j] + + +class IntTrait(HasTraits): + value = Int(99, min=-100) + + +class TestInt(TraitTestBase): + obj = IntTrait() + _default_value = 99 + _good_values = [10, -10] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + 10.1, + -10.1, + "10L", + "-10L", + "10.1", + "-10.1", + "10", + "-10", + -200, + ] + + +class CIntTrait(HasTraits): + value = CInt("5") + + +class TestCInt(TraitTestBase): + obj = CIntTrait() + + _default_value = 5 + _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] + + def coerce(self, n): + return int(n) + + +class MinBoundCIntTrait(HasTraits): + value = CInt("5", min=3) + + +class TestMinBoundCInt(TestCInt): + obj = MinBoundCIntTrait() # type:ignore + + _default_value = 5 + _good_values = [3, 3.0, "3"] + _bad_values = [2.6, 2, -3, -3.0] + + +class LongTrait(HasTraits): + value = Long(99) + + +class TestLong(TraitTestBase): + obj = LongTrait() + + _default_value = 99 + _good_values = [10, -10] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + 10.1, + -10.1, + "10", + "-10", + "10L", + "-10L", + "10.1", + "-10.1", + ] + + +class MinBoundLongTrait(HasTraits): + value = Long(99, min=5) + + +class TestMinBoundLong(TraitTestBase): + obj = MinBoundLongTrait() + + _default_value = 99 + _good_values = [5, 10] + _bad_values = [4, -10] + + +class MaxBoundLongTrait(HasTraits): + value = Long(5, max=10) + + +class TestMaxBoundLong(TraitTestBase): + obj = MaxBoundLongTrait() + + _default_value = 5 + _good_values = [10, -2] + _bad_values = [11, 20] + + +class CLongTrait(HasTraits): + value = CLong("5") + + +class TestCLong(TraitTestBase): + obj = CLongTrait() + + _default_value = 5 + _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] + + def coerce(self, n): + return int(n) + + +class MaxBoundCLongTrait(HasTraits): + value = CLong("5", max=10) + + +class TestMaxBoundCLong(TestCLong): + obj = MaxBoundCLongTrait() # type:ignore + + _default_value = 5 + _good_values = [10, "10", 10.3] + _bad_values = [11.0, "11"] + + +class IntegerTrait(HasTraits): + value = Integer(1) + + +class TestInteger(TestLong): + obj = IntegerTrait() # type:ignore + _default_value = 1 + + def coerce(self, n): + return int(n) + + +class MinBoundIntegerTrait(HasTraits): + value = Integer(5, min=3) + + +class TestMinBoundInteger(TraitTestBase): + obj = MinBoundIntegerTrait() + + _default_value = 5 + _good_values = 3, 20 + _bad_values = [2, -10] + + +class MaxBoundIntegerTrait(HasTraits): + value = Integer(1, max=3) + + +class TestMaxBoundInteger(TraitTestBase): + obj = MaxBoundIntegerTrait() + + _default_value = 1 + _good_values = 3, -2 + _bad_values = [4, 10] + + +class FloatTrait(HasTraits): + value = Float(99.0, max=200.0) + + +class TestFloat(TraitTestBase): + obj = FloatTrait() + + _default_value = 99.0 + _good_values = [10, -10, 10.1, -10.1] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + "10", + "-10", + "10L", + "-10L", + "10.1", + "-10.1", + 201.0, + ] + + +class CFloatTrait(HasTraits): + value = CFloat("99.0", max=200.0) + + +class TestCFloat(TraitTestBase): + obj = CFloatTrait() + + _default_value = 99.0 + _good_values = [10, 10.0, 10.5, "10.0", "10", "-10"] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, 200.1, "200.1"] + + def coerce(self, v): + return float(v) + + +class ComplexTrait(HasTraits): + value = Complex(99.0 - 99.0j) + + +class TestComplex(TraitTestBase): + obj = ComplexTrait() + + _default_value = 99.0 - 99.0j + _good_values = [ + 10, + -10, + 10.1, + -10.1, + 10j, + 10 + 10j, + 10 - 10j, + 10.1j, + 10.1 + 10.1j, + 10.1 - 10.1j, + ] + _bad_values = ["10L", "-10L", "ten", [10], {"ten": 10}, (10,), None] + + +class BytesTrait(HasTraits): + value = Bytes(b"string") + + +class TestBytes(TraitTestBase): + obj = BytesTrait() + + _default_value = b"string" + _good_values = [b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1", b"string"] + _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None, "string"] + + +class UnicodeTrait(HasTraits): + value = Unicode("unicode") + + +class TestUnicode(TraitTestBase): + obj = UnicodeTrait() + + _default_value = "unicode" + _good_values = ["10", "-10", "10L", "-10L", "10.1", "-10.1", "", "string", "€", b"bytestring"] + _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None] + + def coerce(self, v): + return cast_unicode(v) + + +class ObjectNameTrait(HasTraits): + value = ObjectName("abc") + + +class TestObjectName(TraitTestBase): + obj = ObjectNameTrait() + + _default_value = "abc" + _good_values = ["a", "gh", "g9", "g_", "_G", "a345_"] + _bad_values = [ + 1, + "", + "€", + "9g", + "!", + "#abc", + "aj@", + "a.b", + "a()", + "a[0]", + None, + object(), + object, + ] + _good_values.append("þ") # þ=1 is valid in Python 3 (PEP 3131). + + +class DottedObjectNameTrait(HasTraits): + value = DottedObjectName("a.b") + + +class TestDottedObjectName(TraitTestBase): + obj = DottedObjectNameTrait() + + _default_value = "a.b" + _good_values = ["A", "y.t", "y765.__repr__", "os.path.join"] + _bad_values = [1, "abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None] + + _good_values.append("t.þ") + + +class TCPAddressTrait(HasTraits): + value = TCPAddress() + + +class TestTCPAddress(TraitTestBase): + obj = TCPAddressTrait() + + _default_value = ("127.0.0.1", 0) + _good_values = [("localhost", 0), ("192.168.0.1", 1000), ("www.google.com", 80)] + _bad_values = [(0, 0), ("localhost", 10.0), ("localhost", -1), None] + + +class ListTrait(HasTraits): + value = List(Int()) + + +class TestList(TraitTestBase): + obj = ListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[], [1], list(range(10)), (1, 2)] + _bad_values = [10, [1, "a"], "a"] + + def coerce(self, value): + if value is not None: + value = list(value) + return value + + +class Foo: + pass + + +class NoneInstanceListTrait(HasTraits): + value = List(Instance(Foo)) + + +class TestNoneInstanceList(TraitTestBase): + obj = NoneInstanceListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[Foo(), Foo()], []] + _bad_values = [[None], [Foo(), None]] + + +class InstanceListTrait(HasTraits): + value = List(Instance(__name__ + ".Foo")) + + +class TestInstanceList(TraitTestBase): + obj = InstanceListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, Foo) + + _default_value: t.List[t.Any] = [] + _good_values = [[Foo(), Foo()], []] + _bad_values = [ + [ + "1", + 2, + ], + "1", + [Foo], + None, + ] + + +class UnionListTrait(HasTraits): + value = List(Int() | Bool()) + + +class TestUnionListTrait(TraitTestBase): + obj = UnionListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[True, 1], [False, True]] + _bad_values = [[1, "True"], False] + + +class LenListTrait(HasTraits): + value = List(Int(), [0], minlen=1, maxlen=2) + + +class TestLenList(TraitTestBase): + obj = LenListTrait() + + _default_value = [0] + _good_values = [[1], [1, 2], (1, 2)] + _bad_values = [10, [1, "a"], "a", [], list(range(3))] + + def coerce(self, value): + if value is not None: + value = list(value) + return value + + +class TupleTrait(HasTraits): + value = Tuple(Int(allow_none=True), default_value=(1,)) + + +class TestTupleTrait(TraitTestBase): + obj = TupleTrait() + + _default_value = (1,) + _good_values = [(1,), (0,), [1]] + _bad_values = [10, (1, 2), ("a"), (), None] + + def coerce(self, value): + if value is not None: + value = tuple(value) + return value + + def test_invalid_args(self): + self.assertRaises(TypeError, Tuple, 5) + self.assertRaises(TypeError, Tuple, default_value="hello") + t = Tuple(Int(), CBytes(), default_value=(1, 5)) + + +class LooseTupleTrait(HasTraits): + value = Tuple((1, 2, 3)) + + +class TestLooseTupleTrait(TraitTestBase): + obj = LooseTupleTrait() + + _default_value = (1, 2, 3) + _good_values = [(1,), [1], (0,), tuple(range(5)), tuple("hello"), ("a", 5), ()] + _bad_values = [10, "hello", {}, None] + + def coerce(self, value): + if value is not None: + value = tuple(value) + return value + + def test_invalid_args(self): + self.assertRaises(TypeError, Tuple, 5) + self.assertRaises(TypeError, Tuple, default_value="hello") + t = Tuple(Int(), CBytes(), default_value=(1, 5)) + + +class MultiTupleTrait(HasTraits): + value = Tuple(Int(), Bytes(), default_value=[99, b"bottles"]) + + +class TestMultiTuple(TraitTestBase): + obj = MultiTupleTrait() + + _default_value = (99, b"bottles") + _good_values = [(1, b"a"), (2, b"b")] + _bad_values = ((), 10, b"a", (1, b"a", 3), (b"a", 1), (1, "a")) + + + "Trait", + ( + List, + Tuple, + Set, + Dict, + Integer, + Unicode, + ), +) +def test_allow_none_default_value(Trait): + class C(HasTraits): + t = Trait(default_value=None, allow_none=True) + + # test default value + c = C() + assert c.t is None + + # and in constructor + c = C(t=None) + assert c.t is None + + + "Trait, default_value", + ((List, []), (Tuple, ()), (Set, set()), (Dict, {}), (Integer, 0), (Unicode, "")), +) +def test_default_value(Trait, default_value): + class C(HasTraits): + t = Trait() + + # test default value + c = C() + assert type(c.t) is type(default_value) + assert c.t == default_value + + + "Trait, default_value", + ((List, []), (Tuple, ()), (Set, set())), +) +def test_subclass_default_value(Trait, default_value): + """Test deprecated default_value=None behavior for Container subclass traits""" + + class SubclassTrait(Trait): # type:ignore + def __init__(self, default_value=None): + super().__init__(default_value=default_value) + + class C(HasTraits): + t = SubclassTrait() + + # test default value + c = C() + assert type(c.t) is type(default_value) + assert c.t == default_value + + +class CRegExpTrait(HasTraits): + value = CRegExp(r"") + + +class TestCRegExp(TraitTestBase): + def coerce(self, value): + return re.compile(value) + + obj = CRegExpTrait() + + _default_value = re.compile(r"") + _good_values = [r"\d+", re.compile(r"\d+")] + _bad_values = ["(", None, ()] + + +class DictTrait(HasTraits): + value = Dict() + + +def test_dict_assignment(): + d: t.Dict[str, int] = {} + c = DictTrait() + c.value = d + d["a"] = 5 + assert d == c.value + assert c.value is d + + +class UniformlyValueValidatedDictTrait(HasTraits): + value = Dict(value_trait=Unicode(), default_value={"foo": "1"}) + + +class TestInstanceUniformlyValueValidatedDict(TraitTestBase): + obj = UniformlyValueValidatedDictTrait() + + _default_value = {"foo": "1"} + _good_values = [{"foo": "0", "bar": "1"}] + _bad_values = [{"foo": 0, "bar": "1"}] + + +class NonuniformlyValueValidatedDictTrait(HasTraits): + value = Dict(per_key_traits={"foo": Int()}, default_value={"foo": 1}) + + +class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase): + obj = NonuniformlyValueValidatedDictTrait() + + _default_value = {"foo": 1} + _good_values = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": 1}] + _bad_values = [{"foo": "0", "bar": "1"}] + + +class KeyValidatedDictTrait(HasTraits): + value = Dict(key_trait=Unicode(), default_value={"foo": "1"}) + + +class TestInstanceKeyValidatedDict(TraitTestBase): + obj = KeyValidatedDictTrait() + + _default_value = {"foo": "1"} + _good_values = [{"foo": "0", "bar": "1"}] + _bad_values = [{"foo": "0", 0: "1"}] + + +class FullyValidatedDictTrait(HasTraits): + value = Dict( + value_trait=Unicode(), + key_trait=Unicode(), + per_key_traits={"foo": Int()}, + default_value={"foo": 1}, + ) + + +class TestInstanceFullyValidatedDict(TraitTestBase): + obj = FullyValidatedDictTrait() + + _default_value = {"foo": 1} + _good_values = [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}] + _bad_values = [{"foo": 0, "bar": 1}, {"foo": "0", "bar": "1"}, {"foo": 0, 0: "1"}] + + +def test_dict_default_value(): + """Check that the `{}` default value of the Dict traitlet constructor is + actually copied.""" + + class Foo(HasTraits): + d1 = Dict() + d2 = Dict() + + foo = Foo() + assert foo.d1 == {} + assert foo.d2 == {} + assert foo.d1 is not foo.d2 + + +class TestValidationHook(TestCase): + def test_parity_trait(self): + """Verify that the early validation hook is effective""" + + class Parity(HasTraits): + value = Int(0) + parity = Enum(["odd", "even"], default_value="even") + + @validate("value") + def _value_validate(self, proposal): + value = proposal["value"] + if self.parity == "even" and value % 2: + raise TraitError("Expected an even number") + if self.parity == "odd" and (value % 2 == 0): + raise TraitError("Expected an odd number") + return value + + u = Parity() + u.parity = "odd" + u.value = 1 # OK + with self.assertRaises(TraitError): + u.value = 2 # Trait Error + + u.parity = "even" + u.value = 2 # OK + + def test_multiple_validate(self): + """Verify that we can register the same validator to multiple names""" + + class OddEven(HasTraits): + odd = Int(1) + even = Int(0) + + @validate("odd", "even") + def check_valid(self, proposal): + if proposal["trait"].name == "odd" and not proposal["value"] % 2: + raise TraitError("odd should be odd") + if proposal["trait"].name == "even" and proposal["value"] % 2: + raise TraitError("even should be even") + + u = OddEven() + u.odd = 3 # OK + with self.assertRaises(TraitError): + u.odd = 2 # Trait Error + + u.even = 2 # OK + with self.assertRaises(TraitError): + u.even = 3 # Trait Error + + def test_validate_used(self): + """Verify that the validate value is being used""" + + class FixedValue(HasTraits): + value = Int(0) + + @validate("value") + def _value_validate(self, proposal): + return -1 + + u = FixedValue(value=2) + assert u.value == -1 + + u = FixedValue() + u.value = 3 + assert u.value == -1 + + +class TestLink(TestCase): + def test_connect_same(self): + """Verify two traitlets of the same type can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "value")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.value) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.value) + b.value = 6 + self.assertEqual(a.value, b.value) + + def test_link_different(self): + """Verify two traitlets of different types can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "count")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.count) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.count) + b.count = 4 + self.assertEqual(a.value, b.count) + + def test_unlink_link(self): + """Verify two linked traitlets can be unlinked and relinked.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Connect the two classes. + c = link((a, "value"), (b, "value")) + a.value = 4 + c.unlink() + + # Change one of the values to make sure they don't stay in sync. + a.value = 5 + self.assertNotEqual(a.value, b.value) + c.link() + self.assertEqual(a.value, b.value) + a.value += 1 + self.assertEqual(a.value, b.value) + + def test_callbacks(self): + """Verify two linked traitlets have their callbacks called once.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Register callbacks that count. + callback_count = [] + + def a_callback(name, old, new): + callback_count.append("a") + + a.on_trait_change(a_callback, "value") + + def b_callback(name, old, new): + callback_count.append("b") + + b.on_trait_change(b_callback, "count") + + # Connect the two classes. + c = link((a, "value"), (b, "count")) + + # Make sure b's count was set to a's value once. + self.assertEqual("".join(callback_count), "b") + del callback_count[:] + + # Make sure a's value was set to b's count once. + b.count = 5 + self.assertEqual("".join(callback_count), "ba") + del callback_count[:] + + # Make sure b's count was set to a's value once. + a.value = 4 + self.assertEqual("".join(callback_count), "ab") + del callback_count[:] + + def test_tranform(self): + """Test transform link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "value"), transform=(lambda x: 2 * x, lambda x: int(x / 2.0))) + + # Make sure the values are correct at the point of linking. + self.assertEqual(b.value, 2 * a.value) + + # Change one the value of the source and check that it modifies the target. + a.value = 5 + self.assertEqual(b.value, 10) + # Change one the value of the target and check that it modifies the + # source. + b.value = 6 + self.assertEqual(a.value, 3) + + def test_link_broken_at_source(self): + class MyClass(HasTraits): + i = Int() + j = Int() + + @observe("j") + def another_update(self, change): + self.i = change.new * 2 + + mc = MyClass() + l = link((mc, "i"), (mc, "j")) # noqa + self.assertRaises(TraitError, setattr, mc, "i", 2) + + def test_link_broken_at_target(self): + class MyClass(HasTraits): + i = Int() + j = Int() + + @observe("i") + def another_update(self, change): + self.j = change.new * 2 + + mc = MyClass() + l = link((mc, "i"), (mc, "j")) # noqa + self.assertRaises(TraitError, setattr, mc, "j", 2) + + +class TestDirectionalLink(TestCase): + def test_connect_same(self): + """Verify two traitlets of the same type can be linked together using directional_link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "value")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.value) + + # Change one the value of the source and check that it synchronizes the target. + a.value = 5 + self.assertEqual(b.value, 5) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 + self.assertEqual(a.value, 5) + + def test_tranform(self): + """Test transform link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "value"), lambda x: 2 * x) + + # Make sure the values are correct at the point of linking. + self.assertEqual(b.value, 2 * a.value) + + # Change one the value of the source and check that it modifies the target. + a.value = 5 + self.assertEqual(b.value, 10) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 + self.assertEqual(a.value, 5) + + def test_link_different(self): + """Verify two traitlets of different types can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "count")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.count) + + # Change one the value of the source and check that it synchronizes the target. + a.value = 5 + self.assertEqual(b.count, 5) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 # type:ignore + self.assertEqual(a.value, 5) + + def test_unlink_link(self): + """Verify two linked traitlets can be unlinked and relinked.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Connect the two classes. + c = directional_link((a, "value"), (b, "value")) + a.value = 4 + c.unlink() + + # Change one of the values to make sure they don't stay in sync. + a.value = 5 + self.assertNotEqual(a.value, b.value) + c.link() + self.assertEqual(a.value, b.value) + a.value += 1 + self.assertEqual(a.value, b.value) + + +class Pickleable(HasTraits): + i = Int() + + @observe("i") + def _i_changed(self, change): + pass + + @validate("i") + def _i_validate(self, commit): + return commit["value"] + + j = Int() + + def __init__(self): + with self.hold_trait_notifications(): + self.i = 1 + self.on_trait_change(self._i_changed, "i") + + +def test_pickle_hastraits(): + c = Pickleable() + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(c, protocol) + c2 = pickle.loads(p) + assert c2.i == c.i + assert c2.j == c.j + + c.i = 5 + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(c, protocol) + c2 = pickle.loads(p) + assert c2.i == c.i + assert c2.j == c.j + + +def test_hold_trait_notifications(): + changes = [] + + class Test(HasTraits): + a = Integer(0) + b = Integer(0) + + def _a_changed(self, name, old, new): + changes.append((old, new)) + + def _b_validate(self, value, trait): + if value != 0: + raise TraitError("Only 0 is a valid value") + return value + + # Test context manager and nesting + t = Test() + with t.hold_trait_notifications(): + with t.hold_trait_notifications(): + t.a = 1 + assert t.a == 1 + assert changes == [] + t.a = 2 + assert t.a == 2 + with t.hold_trait_notifications(): + t.a = 3 + assert t.a == 3 + assert changes == [] + t.a = 4 + assert t.a == 4 + assert changes == [] + t.a = 4 + assert t.a == 4 + assert changes == [] + + assert changes == [(0, 4)] + # Test roll-back + try: + with t.hold_trait_notifications(): + t.b = 1 # raises a Trait error + except Exception: + pass + assert t.b == 0 + + +class RollBack(HasTraits): + bar = Int() + + def _bar_validate(self, value, trait): + if value: + raise TraitError("foobar") + return value + + +class TestRollback(TestCase): + def test_roll_back(self): + def assign_rollback(): + RollBack(bar=1) + + self.assertRaises(TraitError, assign_rollback) + + +class CacheModification(HasTraits): + foo = Int() + bar = Int() + + def _bar_validate(self, value, trait): + self.foo = value + return value + + def _foo_validate(self, value, trait): + self.bar = value + return value + + +def test_cache_modification(): + CacheModification(foo=1) + CacheModification(bar=1) + + +class OrderTraits(HasTraits): + notified = Dict() + + a = Unicode() + b = Unicode() + c = Unicode() + d = Unicode() + e = Unicode() + f = Unicode() + g = Unicode() + h = Unicode() + i = Unicode() + j = Unicode() + k = Unicode() + l = Unicode() # noqa + + def _notify(self, name, old, new): + """check the value of all traits when each trait change is triggered + + This verifies that the values are not sensitive + to dict ordering when loaded from kwargs + """ + # check the value of the other traits + # when a given trait change notification fires + self.notified[name] = {c: getattr(self, c) for c in "abcdefghijkl"} + + def __init__(self, **kwargs): + self.on_trait_change(self._notify) + super().__init__(**kwargs) + + +def test_notification_order(): + d = {c: c for c in "abcdefghijkl"} + obj = OrderTraits() + assert obj.notified == {} + obj = OrderTraits(**d) + notifications = {c: d for c in "abcdefghijkl"} + assert obj.notified == notifications + + +### +# Traits for Forward Declaration Tests +### +class ForwardDeclaredInstanceTrait(HasTraits): + value = ForwardDeclaredInstance["ForwardDeclaredBar"]("ForwardDeclaredBar", allow_none=True) + + +class ForwardDeclaredTypeTrait(HasTraits): + value = ForwardDeclaredType[t.Any, t.Any]("ForwardDeclaredBar", allow_none=True) + + +class ForwardDeclaredInstanceListTrait(HasTraits): + value = List(ForwardDeclaredInstance("ForwardDeclaredBar")) + + +class ForwardDeclaredTypeListTrait(HasTraits): + value = List(ForwardDeclaredType("ForwardDeclaredBar")) + + +### +# End Traits for Forward Declaration Tests +### + + +### +# Classes for Forward Declaration Tests +### +class ForwardDeclaredBar: + pass + + +class ForwardDeclaredBarSub(ForwardDeclaredBar): + pass + + +### +# End Classes for Forward Declaration Tests +### + + +### +# Forward Declaration Tests +### +class TestForwardDeclaredInstanceTrait(TraitTestBase): + obj = ForwardDeclaredInstanceTrait() + _default_value = None + _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()] + _bad_values = ["foo", 3, ForwardDeclaredBar, ForwardDeclaredBarSub] + + +class TestForwardDeclaredTypeTrait(TraitTestBase): + obj = ForwardDeclaredTypeTrait() + _default_value = None + _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub] + _bad_values = ["foo", 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()] + + +class TestForwardDeclaredInstanceList(TraitTestBase): + obj = ForwardDeclaredInstanceListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) + + _default_value: t.List[t.Any] = [] + _good_values = [ + [ForwardDeclaredBar(), ForwardDeclaredBarSub()], + [], + ] + _bad_values = [ + ForwardDeclaredBar(), + [ForwardDeclaredBar(), 3, None], + "1", + # Note that this is the type, not an instance. + [ForwardDeclaredBar], + [None], + None, + ] + + +class TestForwardDeclaredTypeList(TraitTestBase): + obj = ForwardDeclaredTypeListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) + + _default_value: t.List[t.Any] = [] + _good_values = [ + [ForwardDeclaredBar, ForwardDeclaredBarSub], + [], + ] + _bad_values = [ + ForwardDeclaredBar, + [ForwardDeclaredBar, 3], + "1", + # Note that this is an instance, not the type. + [ForwardDeclaredBar()], + [None], + None, + ] + + +### +# End Forward Declaration Tests +### + + +class TestDynamicTraits(TestCase): + def setUp(self): + self._notify1 = [] + + def notify1(self, name, old, new): + self._notify1.append((name, old, new)) + + @t.no_type_check + def test_notify_all(self): + class A(HasTraits): + pass + + a = A() + self.assertTrue(not hasattr(a, "x")) + self.assertTrue(not hasattr(a, "y")) + + # Dynamically add trait x. + a.add_traits(x=Int()) + self.assertTrue(hasattr(a, "x")) + self.assertTrue(isinstance(a, (A,))) + + # Dynamically add trait y. + a.add_traits(y=Float()) + self.assertTrue(hasattr(a, "y")) + self.assertTrue(isinstance(a, (A,))) + self.assertEqual(a.__class__.__name__, A.__name__) + + # Create a new instance and verify that x and y + # aren't defined. + b = A() + self.assertTrue(not hasattr(b, "x")) + self.assertTrue(not hasattr(b, "y")) + + # Verify that notification works like normal. + a.on_trait_change(self.notify1) + a.x = 0 + self.assertEqual(len(self._notify1), 0) + a.y = 0.0 + self.assertEqual(len(self._notify1), 0) + a.x = 10 + self.assertTrue(("x", 0, 10) in self._notify1) + a.y = 10.0 + self.assertTrue(("y", 0.0, 10.0) in self._notify1) + self.assertRaises(TraitError, setattr, a, "x", "bad string") + self.assertRaises(TraitError, setattr, a, "y", "bad string") + self._notify1 = [] + a.on_trait_change(self.notify1, remove=True) + a.x = 20 + a.y = 20.0 + self.assertEqual(len(self._notify1), 0) + + +def test_enum_no_default(): + class C(HasTraits): + t = Enum(["a", "b"]) + + c = C() + c.t = "a" + assert c.t == "a" + + c = C() + + with pytest.raises(TraitError): + t = c.t + + c = C(t="b") + assert c.t == "b" + + +def test_default_value_repr(): + class C(HasTraits): + t = Type("traitlets.HasTraits") + t2 = Type(HasTraits) + n = Integer(0) + lis = List() + d = Dict() + + assert C.t.default_value_repr() == "'traitlets.HasTraits'" + assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'" + assert C.n.default_value_repr() == "0" + assert C.lis.default_value_repr() == "[]" + assert C.d.default_value_repr() == "{}" + + +class TransitionalClass(HasTraits): + d = Any() + + @default("d") + def _d_default(self): + return TransitionalClass + + parent_super = False + calls_super = Integer(0) + + @default("calls_super") + def _calls_super_default(self): + return -1 + + @observe("calls_super") + @observe_compat + def _calls_super_changed(self, change): + self.parent_super = change + + parent_override = False + overrides = Integer(0) + + @observe("overrides") + @observe_compat + def _overrides_changed(self, change): + self.parent_override = change + + +class SubClass(TransitionalClass): + def _d_default(self): + return SubClass + + subclass_super = False + + def _calls_super_changed(self, name, old, new): + self.subclass_super = True + super()._calls_super_changed(name, old, new) + + subclass_override = False + + def _overrides_changed(self, name, old, new): + self.subclass_override = True + + +def test_subclass_compat(): + obj = SubClass() + obj.calls_super = 5 + assert obj.parent_super + assert obj.subclass_super + obj.overrides = 5 + assert obj.subclass_override + assert not obj.parent_override + assert obj.d is SubClass + + +class DefinesHandler(HasTraits): + parent_called = False + + trait = Integer() + + @observe("trait") + def handler(self, change): + self.parent_called = True + + +class OverridesHandler(DefinesHandler): + child_called = False + + @observe("trait") + def handler(self, change): + self.child_called = True + + +def test_subclass_override_observer(): + obj = OverridesHandler() + obj.trait = 5 + assert obj.child_called + assert not obj.parent_called + + +class DoesntRegisterHandler(DefinesHandler): + child_called = False + + def handler(self, change): + self.child_called = True + + +def test_subclass_override_not_registered(): + """Subclass that overrides observer and doesn't re-register unregisters both""" + obj = DoesntRegisterHandler() + obj.trait = 5 + assert not obj.child_called + assert not obj.parent_called + + +class AddsHandler(DefinesHandler): + child_called = False + + @observe("trait") + def child_handler(self, change): + self.child_called = True + + +def test_subclass_add_observer(): + obj = AddsHandler() + obj.trait = 5 + assert obj.child_called + assert obj.parent_called + + +def test_observe_iterables(): + class C(HasTraits): + i = Integer() + s = Unicode() + + c = C() + recorded = {} + + def record(change): + recorded["change"] = change + + # observe with names=set + c.observe(record, names={"i", "s"}) + c.i = 5 + assert recorded["change"].name == "i" + assert recorded["change"].new == 5 + c.s = "hi" + assert recorded["change"].name == "s" + assert recorded["change"].new == "hi" + + # observe with names=custom container with iter, contains + class MyContainer: + def __init__(self, container): + self.container = container + + def __iter__(self): + return iter(self.container) + + def __contains__(self, key): + return key in self.container + + c.observe(record, names=MyContainer({"i", "s"})) + c.i = 10 + assert recorded["change"].name == "i" + assert recorded["change"].new == 10 + c.s = "ok" + assert recorded["change"].name == "s" + assert recorded["change"].new == "ok" + + +def test_super_args(): + class SuperRecorder: + def __init__(self, *args, **kwargs): + self.super_args = args + self.super_kwargs = kwargs + + class SuperHasTraits(HasTraits, SuperRecorder): + i = Integer() + + obj = SuperHasTraits("a1", "a2", b=10, i=5, c="x") + assert obj.i == 5 + assert not hasattr(obj, "b") + assert not hasattr(obj, "c") + assert obj.super_args == ("a1", "a2") + assert obj.super_kwargs == {"b": 10, "c": "x"} + + +def test_super_bad_args(): + class SuperHasTraits(HasTraits): + a = Integer() + + w = ["Passing unrecognized arguments"] + with expected_warnings(w): + obj = SuperHasTraits(a=1, b=2) + assert obj.a == 1 + assert not hasattr(obj, "b") + + +def test_default_mro(): + """Verify that default values follow mro""" + + class Base(HasTraits): + trait = Unicode("base") + attr = "base" + + class A(Base): + pass + + class B(Base): + trait = Unicode("B") + attr = "B" + + class AB(A, B): + pass + + class BA(B, A): + pass + + assert A().trait == "base" + assert A().attr == "base" + assert BA().trait == "B" + assert BA().attr == "B" + assert AB().trait == "B" + assert AB().attr == "B" + + +def test_cls_self_argument(): + class X(HasTraits): + def __init__(__self, cls, self): # noqa + pass + + x = X(cls=None, self=None) + + +def test_override_default(): + class C(HasTraits): + a = Unicode("hard default") + + def _a_default(self): + return "default method" + + C._a_default = lambda self: "overridden" # type:ignore + c = C() + assert c.a == "overridden" + + +def test_override_default_decorator(): + class C(HasTraits): + a = Unicode("hard default") + + @default("a") + def _a_default(self): + return "default method" + + C._a_default = lambda self: "overridden" # type:ignore + c = C() + assert c.a == "overridden" + + +def test_override_default_instance(): + class C(HasTraits): + a = Unicode("hard default") + + @default("a") + def _a_default(self): + return "default method" + + c = C() + c._a_default = lambda self: "overridden" + assert c.a == "overridden" + + +def test_copy_HasTraits(): + from copy import copy + + class C(HasTraits): + a = Int() + + c = C(a=1) + assert c.a == 1 + + cc = copy(c) + cc.a = 2 + assert cc.a == 2 + assert c.a == 1 + + +def _from_string_test(traittype, s, expected): + """Run a test of trait.from_string""" + if isinstance(traittype, TraitType): + trait = traittype + else: + trait = traittype(allow_none=True) + if isinstance(s, list): + cast = trait.from_string_list # type:ignore + else: + cast = trait.from_string + if type(expected) is type and issubclass(expected, Exception): + with pytest.raises(expected): + value = cast(s) + trait.validate(CrossValidationStub(), value) # type:ignore + else: + value = cast(s) + assert value == expected + + + "s, expected", + [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], +) +def test_unicode_from_string(s, expected): + _from_string_test(Unicode, s, expected) + + + "s, expected", + [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], +) +def test_cunicode_from_string(s, expected): + _from_string_test(CUnicode, s, expected) + + + "s, expected", + [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], +) +def test_bytes_from_string(s, expected): + _from_string_test(Bytes, s, expected) + + + "s, expected", + [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], +) +def test_cbytes_from_string(s, expected): + _from_string_test(CBytes, s, expected) + + + "s, expected", + [("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)], +) +def test_int_from_string(s, expected): + _from_string_test(Integer, s, expected) + + + "s, expected", + [("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)], +) +def test_float_from_string(s, expected): + _from_string_test(Float, s, expected) + + + "s, expected", + [ + ("x", ValueError), + ("1", 1.0), + ("123.5", 123.5), + ("2.5", 2.5), + ("1+2j", 1 + 2j), + ("None", None), + ], +) +def test_complex_from_string(s, expected): + _from_string_test(Complex, s, expected) + + + "s, expected", + [ + ("true", True), + ("TRUE", True), + ("1", True), + ("0", False), + ("False", False), + ("false", False), + ("1.0", ValueError), + ("None", None), + ], +) +def test_bool_from_string(s, expected): + _from_string_test(Bool, s, expected) + + + "s, expected", + [ + ("{}", {}), + ("1", TraitError), + ("{1: 2}", {1: 2}), + ('{"key": "value"}', {"key": "value"}), + ("x", TraitError), + ("None", None), + ], +) +def test_dict_from_string(s, expected): + _from_string_test(Dict, s, expected) + + + "s, expected", + [ + ("[]", []), + ('[1, 2, "x"]', [1, 2, "x"]), + (["1", "x"], ["1", "x"]), + (["None"], None), + ], +) +def test_list_from_string(s, expected): + _from_string_test(List, s, expected) + + + "s, expected, value_trait", + [ + (["1", "2", "3"], [1, 2, 3], Integer()), + (["x"], ValueError, Integer()), + (["1", "x"], ["1", "x"], Unicode()), + (["None"], [None], Unicode(allow_none=True)), + (["None"], ["None"], Unicode(allow_none=False)), + ], +) +def test_list_items_from_string(s, expected, value_trait): + _from_string_test(List(value_trait), s, expected) + + + "s, expected", + [ + ("[]", set()), + ('[1, 2, "x"]', {1, 2, "x"}), + ('{1, 2, "x"}', {1, 2, "x"}), + (["1", "x"], {"1", "x"}), + (["None"], None), + ], +) +def test_set_from_string(s, expected): + _from_string_test(Set, s, expected) + + + "s, expected, value_trait", + [ + (["1", "2", "3"], {1, 2, 3}, Integer()), + (["x"], ValueError, Integer()), + (["1", "x"], {"1", "x"}, Unicode()), + (["None"], {None}, Unicode(allow_none=True)), + ], +) +def test_set_items_from_string(s, expected, value_trait): + _from_string_test(Set(value_trait), s, expected) + + + "s, expected", + [ + ("[]", ()), + ("()", ()), + ('[1, 2, "x"]', (1, 2, "x")), + ('(1, 2, "x")', (1, 2, "x")), + (["1", "x"], ("1", "x")), + (["None"], None), + ], +) +def test_tuple_from_string(s, expected): + _from_string_test(Tuple, s, expected) + + + "s, expected, value_traits", + [ + (["1", "2", "3"], (1, 2, 3), [Integer(), Integer(), Integer()]), + (["x"], ValueError, [Integer()]), + (["1", "x"], ("1", "x"), [Unicode()]), + (["None"], ("None",), [Unicode(allow_none=False)]), + (["None"], (None,), [Unicode(allow_none=True)]), + ], +) +def test_tuple_items_from_string(s, expected, value_traits): + _from_string_test(Tuple(*value_traits), s, expected) + + + "s, expected", + [ + ("x", "x"), + ("mod.submod", "mod.submod"), + ("not an identifier", TraitError), + ("1", "1"), + ("None", None), + ], +) +def test_object_from_string(s, expected): + _from_string_test(DottedObjectName, s, expected) + + + "s, expected", + [ + ("127.0.0.1:8000", ("127.0.0.1", 8000)), + ("host.tld:80", ("host.tld", 80)), + ("host:notaport", ValueError), + ("127.0.0.1", ValueError), + ("None", None), + ], +) +def test_tcp_from_string(s, expected): + _from_string_test(TCPAddress, s, expected) + + + "s, expected", + [("[]", []), ("{}", "{}")], +) +def test_union_of_list_and_unicode_from_string(s, expected): + _from_string_test(Union([List(), Unicode()]), s, expected) + + + "s, expected", + [("1", 1), ("1.5", 1.5)], +) +def test_union_of_int_and_float_from_string(s, expected): + _from_string_test(Union([Int(), Float()]), s, expected) + + + "s, expected, allow_none", + [("[]", [], False), ("{}", {}, False), ("None", TraitError, False), ("None", None, True)], +) +def test_union_of_list_and_dict_from_string(s, expected, allow_none): + _from_string_test(Union([List(), Dict()], allow_none=allow_none), s, expected) + + +def test_all_attribute(): + """Verify all trait types are added to `traitlets.__all__`""" + names = dir(traitlets) + for name in names: + value = getattr(traitlets, name) + if not name.startswith("_") and isinstance(value, type) and issubclass(value, TraitType): + if name not in traitlets.__all__: + raise ValueError(f"{name} not in __all__") + + for name in traitlets.__all__: + if name not in names: + raise ValueError(f"{name} should be removed from __all__") |
