diff options
author | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:24:06 +0300 |
---|---|---|
committer | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:41:34 +0300 |
commit | e0e3e1717e3d33762ce61950504f9637a6e669ed (patch) | |
tree | bca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/traitlets/py3 | |
parent | 38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff) | |
download | ydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz |
add ydb deps
Diffstat (limited to 'contrib/python/traitlets/py3')
38 files changed, 14497 insertions, 0 deletions
diff --git a/contrib/python/traitlets/py3/.dist-info/METADATA b/contrib/python/traitlets/py3/.dist-info/METADATA new file mode 100644 index 0000000000..1c9d322cc3 --- /dev/null +++ b/contrib/python/traitlets/py3/.dist-info/METADATA @@ -0,0 +1,280 @@ +Metadata-Version: 2.1 +Name: traitlets +Version: 5.9.0 +Summary: Traitlets Python configuration system +Project-URL: Homepage, https://github.com/ipython/traitlets +Author-email: IPython Development Team <ipython-dev@python.org> +License: # Licensing terms + + Traitlets is adapted from enthought.traits, Copyright (c) Enthought, Inc., + under the terms of the Modified BSD License. + + This project is licensed under the terms of the Modified BSD License + (also known as New or Revised or 3-Clause BSD), as follows: + + - Copyright (c) 2001-, IPython Development Team + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + + Neither the name of the IPython Development Team nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ## About the IPython Development Team + + The IPython Development Team is the set of all contributors to the IPython project. + This includes all of the IPython subprojects. + + The core team that coordinates development on GitHub can be found here: + https://github.com/jupyter/. + + ## Our Copyright Policy + + IPython uses a shared copyright model. Each contributor maintains copyright + over their contributions to IPython. But, it is important to note that these + contributions are typically only changes to the repositories. Thus, the IPython + source code, in its entirety is not the copyright of any single person or + institution. Instead, it is the collective copyright of the entire IPython + Development Team. If individual contributors want to maintain a record of what + changes/contributions they have specific copyright on, they should indicate + their copyright in the commit message of the change, when they commit the + change to one of the IPython repositories. + + With this in mind, the following banner should be used in any source code file + to indicate the copyright and license terms: + + ``` + # Copyright (c) IPython Development Team. + # Distributed under the terms of the Modified BSD License. + ``` +License-File: COPYING.md +Keywords: Interactive,Interpreter,Shell,Web +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Science/Research +Classifier: Intended Audience :: System Administrators +Classifier: License :: OSI Approved :: BSD License +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Requires-Python: >=3.7 +Provides-Extra: docs +Requires-Dist: myst-parser; extra == 'docs' +Requires-Dist: pydata-sphinx-theme; extra == 'docs' +Requires-Dist: sphinx; extra == 'docs' +Provides-Extra: test +Requires-Dist: argcomplete>=2.0; extra == 'test' +Requires-Dist: pre-commit; extra == 'test' +Requires-Dist: pytest; extra == 'test' +Requires-Dist: pytest-mock; extra == 'test' +Description-Content-Type: text/markdown + +# Traitlets + +[![Tests](https://github.com/ipython/traitlets/actions/workflows/tests.yml/badge.svg)](https://github.com/ipython/traitlets/actions/workflows/tests.yml) +[![Documentation Status](https://readthedocs.org/projects/traitlets/badge/?version=latest)](https://traitlets.readthedocs.io/en/latest/?badge=latest) +[![codecov](https://codecov.io/gh/ipython/traitlets/branch/main/graph/badge.svg?token=HcsbLGEmI1)](https://codecov.io/gh/ipython/traitlets) +[![Tidelift](https://tidelift.com/subscription/pkg/pypi-traitlets)](https://tidelift.com/badges/package/pypi/traitlets) + +| | | +| ------------- | ------------------------------------ | +| **home** | https://github.com/ipython/traitlets | +| **pypi-repo** | https://pypi.org/project/traitlets/ | +| **docs** | https://traitlets.readthedocs.io/ | +| **license** | Modified BSD License | + +Traitlets is a pure Python library enabling: + +- the enforcement of strong typing for attributes of Python objects + (typed attributes are called _"traits"_); +- dynamically calculated default values; +- automatic validation and coercion of trait attributes when attempting a + change; +- registering for receiving notifications when trait values change; +- reading configuring values from files or from command line + arguments - a distinct layer on top of traitlets, so you may use + traitlets without the configuration machinery. + +Its implementation relies on the [descriptor](https://docs.python.org/howto/descriptor.html) +pattern, and it is a lightweight pure-python alternative of the +[_traits_ library](https://docs.enthought.com/traits/). + +Traitlets powers the configuration system of IPython and Jupyter +and the declarative API of IPython interactive widgets. + +## Installation + +For a local installation, make sure you have +[pip installed](https://pip.pypa.io/en/stable/installing/) and run: + +```bash +pip install traitlets +``` + +For a **development installation**, clone this repository, change into the +`traitlets` root directory, and run pip: + +```bash +git clone https://github.com/ipython/traitlets.git +cd traitlets +pip install -e . +``` + +## Running the tests + +```bash +pip install "traitlets[test]" +py.test traitlets +``` + +## Code Styling + +`traitlets` has adopted automatic code formatting so you shouldn't +need to worry too much about your code style. +As long as your code is valid, +the pre-commit hook should take care of how it should look. + +To install `pre-commit` locally, run the following:: + +``` +pip install pre-commit +pre-commit install +``` + +You can invoke the pre-commit hook by hand at any time with:: + +``` +pre-commit run +``` + +which should run any autoformatting on your code +and tell you about any errors it couldn't fix automatically. +You may also install [black integration](https://github.com/psf/black#editor-integration) +into your text editor to format code automatically. + +If you have already committed files before setting up the pre-commit +hook with `pre-commit install`, you can fix everything up using +`pre-commit run --all-files`. You need to make the fixing commit +yourself after that. + +Some of the hooks only run on CI by default, but you can invoke them by +running with the `--hook-stage manual` argument. + +## Usage + +Any class with trait attributes must inherit from `HasTraits`. +For the list of available trait types and their properties, see the +[Trait Types](https://traitlets.readthedocs.io/en/latest/trait_types.html) +section of the documentation. + +### Dynamic default values + +To calculate a default value dynamically, decorate a method of your class with +`@default({traitname})`. This method will be called on the instance, and +should return the default value. In this example, the `_username_default` +method is decorated with `@default('username')`: + +```Python +import getpass +from traitlets import HasTraits, Unicode, default + +class Identity(HasTraits): + username = Unicode() + + @default('username') + def _username_default(self): + return getpass.getuser() +``` + +### Callbacks when a trait attribute changes + +When a trait changes, an application can follow this trait change with +additional actions. + +To do something when a trait attribute is changed, decorate a method with +[`traitlets.observe()`](https://traitlets.readthedocs.io/en/latest/api.html?highlight=observe#traitlets.observe). +The method will be called with a single argument, a dictionary which contains +an owner, new value, old value, name of the changed trait, and the event type. + +In this example, the `_num_changed` method is decorated with `` @observe(`num`) ``: + +```Python +from traitlets import HasTraits, Integer, observe + +class TraitletsExample(HasTraits): + num = Integer(5, help="a number").tag(config=True) + + @observe('num') + def _num_changed(self, change): + print("{name} changed from {old} to {new}".format(**change)) +``` + +and is passed the following dictionary when called: + +```Python +{ + 'owner': object, # The HasTraits instance + 'new': 6, # The new value + 'old': 5, # The old value + 'name': "foo", # The name of the changed trait + 'type': 'change', # The event type of the notification, usually 'change' +} +``` + +### Validation and coercion + +Each trait type (`Int`, `Unicode`, `Dict` etc.) may have its own validation or +coercion logic. In addition, we can register custom cross-validators +that may depend on the state of other attributes. For example: + +```Python +from traitlets import HasTraits, TraitError, Int, Bool, validate + +class Parity(HasTraits): + value = Int() + parity = Int() + + @validate('value') + def _valid_value(self, proposal): + if proposal['value'] % 2 != self.parity: + raise TraitError('value and parity should be consistent') + return proposal['value'] + + @validate('parity') + def _valid_parity(self, proposal): + parity = proposal['value'] + if parity not in [0, 1]: + raise TraitError('parity should be 0 or 1') + if self.value % 2 != parity: + raise TraitError('value and parity should be consistent') + return proposal['value'] + +parity_check = Parity(value=2) + +# Changing required parity and value together while holding cross validation +with parity_check.hold_trait_notifications(): + parity_check.value = 1 + parity_check.parity = 1 +``` + +However, we **recommend** that custom cross-validators don't modify the state +of the HasTraits instance. diff --git a/contrib/python/traitlets/py3/.dist-info/top_level.txt b/contrib/python/traitlets/py3/.dist-info/top_level.txt new file mode 100644 index 0000000000..adfea9c6eb --- /dev/null +++ b/contrib/python/traitlets/py3/.dist-info/top_level.txt @@ -0,0 +1 @@ +traitlets diff --git a/contrib/python/traitlets/py3/COPYING.md b/contrib/python/traitlets/py3/COPYING.md new file mode 100644 index 0000000000..b4325343eb --- /dev/null +++ b/contrib/python/traitlets/py3/COPYING.md @@ -0,0 +1,64 @@ +# Licensing terms + +Traitlets is adapted from enthought.traits, Copyright (c) Enthought, Inc., +under the terms of the Modified BSD License. + +This project is licensed under the terms of the Modified BSD License +(also known as New or Revised or 3-Clause BSD), as follows: + +- Copyright (c) 2001-, IPython Development Team + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, this +list of conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. + +Neither the name of the IPython Development Team nor the names of its +contributors may be used to endorse or promote products derived from this +software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +## About the IPython Development Team + +The IPython Development Team is the set of all contributors to the IPython project. +This includes all of the IPython subprojects. + +The core team that coordinates development on GitHub can be found here: +https://github.com/jupyter/. + +## Our Copyright Policy + +IPython uses a shared copyright model. Each contributor maintains copyright +over their contributions to IPython. But, it is important to note that these +contributions are typically only changes to the repositories. Thus, the IPython +source code, in its entirety is not the copyright of any single person or +institution. Instead, it is the collective copyright of the entire IPython +Development Team. If individual contributors want to maintain a record of what +changes/contributions they have specific copyright on, they should indicate +their copyright in the commit message of the change, when they commit the +change to one of the IPython repositories. + +With this in mind, the following banner should be used in any source code file +to indicate the copyright and license terms: + +``` +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +``` diff --git a/contrib/python/traitlets/py3/README.md b/contrib/python/traitlets/py3/README.md new file mode 100644 index 0000000000..e7b32e2b2a --- /dev/null +++ b/contrib/python/traitlets/py3/README.md @@ -0,0 +1,190 @@ +# Traitlets + +[![Tests](https://github.com/ipython/traitlets/actions/workflows/tests.yml/badge.svg)](https://github.com/ipython/traitlets/actions/workflows/tests.yml) +[![Documentation Status](https://readthedocs.org/projects/traitlets/badge/?version=latest)](https://traitlets.readthedocs.io/en/latest/?badge=latest) +[![codecov](https://codecov.io/gh/ipython/traitlets/branch/main/graph/badge.svg?token=HcsbLGEmI1)](https://codecov.io/gh/ipython/traitlets) +[![Tidelift](https://tidelift.com/subscription/pkg/pypi-traitlets)](https://tidelift.com/badges/package/pypi/traitlets) + +| | | +| ------------- | ------------------------------------ | +| **home** | https://github.com/ipython/traitlets | +| **pypi-repo** | https://pypi.org/project/traitlets/ | +| **docs** | https://traitlets.readthedocs.io/ | +| **license** | Modified BSD License | + +Traitlets is a pure Python library enabling: + +- the enforcement of strong typing for attributes of Python objects + (typed attributes are called _"traits"_); +- dynamically calculated default values; +- automatic validation and coercion of trait attributes when attempting a + change; +- registering for receiving notifications when trait values change; +- reading configuring values from files or from command line + arguments - a distinct layer on top of traitlets, so you may use + traitlets without the configuration machinery. + +Its implementation relies on the [descriptor](https://docs.python.org/howto/descriptor.html) +pattern, and it is a lightweight pure-python alternative of the +[_traits_ library](https://docs.enthought.com/traits/). + +Traitlets powers the configuration system of IPython and Jupyter +and the declarative API of IPython interactive widgets. + +## Installation + +For a local installation, make sure you have +[pip installed](https://pip.pypa.io/en/stable/installing/) and run: + +```bash +pip install traitlets +``` + +For a **development installation**, clone this repository, change into the +`traitlets` root directory, and run pip: + +```bash +git clone https://github.com/ipython/traitlets.git +cd traitlets +pip install -e . +``` + +## Running the tests + +```bash +pip install "traitlets[test]" +py.test traitlets +``` + +## Code Styling + +`traitlets` has adopted automatic code formatting so you shouldn't +need to worry too much about your code style. +As long as your code is valid, +the pre-commit hook should take care of how it should look. + +To install `pre-commit` locally, run the following:: + +``` +pip install pre-commit +pre-commit install +``` + +You can invoke the pre-commit hook by hand at any time with:: + +``` +pre-commit run +``` + +which should run any autoformatting on your code +and tell you about any errors it couldn't fix automatically. +You may also install [black integration](https://github.com/psf/black#editor-integration) +into your text editor to format code automatically. + +If you have already committed files before setting up the pre-commit +hook with `pre-commit install`, you can fix everything up using +`pre-commit run --all-files`. You need to make the fixing commit +yourself after that. + +Some of the hooks only run on CI by default, but you can invoke them by +running with the `--hook-stage manual` argument. + +## Usage + +Any class with trait attributes must inherit from `HasTraits`. +For the list of available trait types and their properties, see the +[Trait Types](https://traitlets.readthedocs.io/en/latest/trait_types.html) +section of the documentation. + +### Dynamic default values + +To calculate a default value dynamically, decorate a method of your class with +`@default({traitname})`. This method will be called on the instance, and +should return the default value. In this example, the `_username_default` +method is decorated with `@default('username')`: + +```Python +import getpass +from traitlets import HasTraits, Unicode, default + +class Identity(HasTraits): + username = Unicode() + + @default('username') + def _username_default(self): + return getpass.getuser() +``` + +### Callbacks when a trait attribute changes + +When a trait changes, an application can follow this trait change with +additional actions. + +To do something when a trait attribute is changed, decorate a method with +[`traitlets.observe()`](https://traitlets.readthedocs.io/en/latest/api.html?highlight=observe#traitlets.observe). +The method will be called with a single argument, a dictionary which contains +an owner, new value, old value, name of the changed trait, and the event type. + +In this example, the `_num_changed` method is decorated with `` @observe(`num`) ``: + +```Python +from traitlets import HasTraits, Integer, observe + +class TraitletsExample(HasTraits): + num = Integer(5, help="a number").tag(config=True) + + @observe('num') + def _num_changed(self, change): + print("{name} changed from {old} to {new}".format(**change)) +``` + +and is passed the following dictionary when called: + +```Python +{ + 'owner': object, # The HasTraits instance + 'new': 6, # The new value + 'old': 5, # The old value + 'name': "foo", # The name of the changed trait + 'type': 'change', # The event type of the notification, usually 'change' +} +``` + +### Validation and coercion + +Each trait type (`Int`, `Unicode`, `Dict` etc.) may have its own validation or +coercion logic. In addition, we can register custom cross-validators +that may depend on the state of other attributes. For example: + +```Python +from traitlets import HasTraits, TraitError, Int, Bool, validate + +class Parity(HasTraits): + value = Int() + parity = Int() + + @validate('value') + def _valid_value(self, proposal): + if proposal['value'] % 2 != self.parity: + raise TraitError('value and parity should be consistent') + return proposal['value'] + + @validate('parity') + def _valid_parity(self, proposal): + parity = proposal['value'] + if parity not in [0, 1]: + raise TraitError('parity should be 0 or 1') + if self.value % 2 != parity: + raise TraitError('value and parity should be consistent') + return proposal['value'] + +parity_check = Parity(value=2) + +# Changing required parity and value together while holding cross validation +with parity_check.hold_trait_notifications(): + parity_check.value = 1 + parity_check.parity = 1 +``` + +However, we **recommend** that custom cross-validators don't modify the state +of the HasTraits instance. diff --git a/contrib/python/traitlets/py3/tests/ya.make b/contrib/python/traitlets/py3/tests/ya.make new file mode 100644 index 0000000000..6a5cd7cf46 --- /dev/null +++ b/contrib/python/traitlets/py3/tests/ya.make @@ -0,0 +1,22 @@ +PY3TEST() + +PEERDIR( + contrib/python/traitlets +) + +SRCDIR(contrib/python/traitlets/py3/traitlets) + +TEST_SRCS( + config/tests/test_application.py + config/tests/test_configurable.py + config/tests/test_loader.py + tests/test_traitlets.py + tests/test_traitlets_enum.py + utils/tests/test_bunch.py + utils/tests/test_decorators.py + utils/tests/test_importstring.py +) + +NO_LINT() + +END() diff --git a/contrib/python/traitlets/py3/traitlets/__init__.py b/contrib/python/traitlets/py3/traitlets/__init__.py new file mode 100644 index 0000000000..be890981f1 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/__init__.py @@ -0,0 +1,32 @@ +"""Traitlets Python configuration system""" +from warnings import warn + +from . import traitlets +from ._version import __version__, version_info +from .traitlets import * +from .utils.bunch import Bunch +from .utils.decorators import signature_has_traits +from .utils.importstring import import_item + +__all__ = [ + "traitlets", + "__version__", + "version_info", + "Bunch", + "signature_has_traits", + "import_item", + "Sentinel", +] + + +class Sentinel(traitlets.Sentinel): # type:ignore[name-defined] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warn( + """ + Sentinel is not a public part of the traitlets API. + It was published by mistake, and may be removed in the future. + """, + DeprecationWarning, + stacklevel=2, + ) diff --git a/contrib/python/traitlets/py3/traitlets/_version.py b/contrib/python/traitlets/py3/traitlets/_version.py new file mode 100644 index 0000000000..6e09af1b91 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/_version.py @@ -0,0 +1,17 @@ +""" +handle the current version info of traitlets. +""" +import re +from typing import List + +# Version string must appear intact for hatch versioning +__version__ = "5.9.0" + +# Build up version_info tuple for backwards compatibility +pattern = r"(?P<major>\d+).(?P<minor>\d+).(?P<patch>\d+)(?P<rest>.*)" +match = re.match(pattern, __version__) +assert match is not None +parts: List[object] = [int(match[part]) for part in ["major", "minor", "patch"]] +if match["rest"]: + parts.append(match["rest"]) +version_info = tuple(parts) diff --git a/contrib/python/traitlets/py3/traitlets/config/__init__.py b/contrib/python/traitlets/py3/traitlets/config/__init__.py new file mode 100644 index 0000000000..699b12b80a --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +from .application import * +from .configurable import * +from .loader import Config + +__all__ = [ # noqa + "Config", + "Application", + "ApplicationError", + "LevelFormatter", + "configurable", + "ConfigurableError", + "MultipleInstanceError", + "LoggingConfigurable", + "SingletonConfigurable", +] diff --git a/contrib/python/traitlets/py3/traitlets/config/application.py b/contrib/python/traitlets/py3/traitlets/config/application.py new file mode 100644 index 0000000000..3cffa6b008 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/application.py @@ -0,0 +1,1097 @@ +"""A base class for a configurable application.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + + +import functools +import json +import logging +import os +import pprint +import re +import sys +import typing as t +from collections import OrderedDict, defaultdict +from contextlib import suppress +from copy import deepcopy +from logging.config import dictConfig +from textwrap import dedent +from typing import Any, Callable, TypeVar, cast + +from traitlets.config.configurable import Configurable, SingletonConfigurable +from traitlets.config.loader import ( + ArgumentError, + Config, + ConfigFileNotFound, + JSONFileConfigLoader, + KVArgParseConfigLoader, + PyFileConfigLoader, +) +from traitlets.traitlets import ( + Bool, + Dict, + Enum, + Instance, + List, + TraitError, + Unicode, + default, + observe, + observe_compat, +) +from traitlets.utils.nested_update import nested_update +from traitlets.utils.text import indent, wrap_paragraphs + +from ..utils import cast_unicode +from ..utils.importstring import import_item + +# ----------------------------------------------------------------------------- +# Descriptions for the various sections +# ----------------------------------------------------------------------------- +# merge flags&aliases into options +option_description = """ +The options below are convenience aliases to configurable class-options, +as listed in the "Equivalent to" description-line of the aliases. +To see all configurable class-options for some <cmd>, use: + <cmd> --help-all +""".strip() # trim newlines of front and back + +keyvalue_description = """ +The command-line option below sets the respective configurable class-parameter: + --Class.parameter=value +This line is evaluated in Python, so simple expressions are allowed. +For instance, to set `C.a=[0,1,2]`, you may type this: + --C.a='range(3)' +""".strip() # trim newlines of front and back + +# sys.argv can be missing, for example when python is embedded. See the docs +# for details: http://docs.python.org/2/c-api/intro.html#embedding-python +if not hasattr(sys, "argv"): + sys.argv = [""] + +subcommand_description = """ +Subcommands are launched as `{app} cmd [args]`. For information on using +subcommand 'cmd', do: `{app} cmd -h`. +""" +# get running program name + +# ----------------------------------------------------------------------------- +# Application class +# ----------------------------------------------------------------------------- + + +_envvar = os.environ.get("TRAITLETS_APPLICATION_RAISE_CONFIG_FILE_ERROR", "") +if _envvar.lower() in {"1", "true"}: + TRAITLETS_APPLICATION_RAISE_CONFIG_FILE_ERROR = True +elif _envvar.lower() in {"0", "false", ""}: + TRAITLETS_APPLICATION_RAISE_CONFIG_FILE_ERROR = False +else: + raise ValueError( + "Unsupported value for environment variable: 'TRAITLETS_APPLICATION_RAISE_CONFIG_FILE_ERROR' is set to '%s' which is none of {'0', '1', 'false', 'true', ''}." + % _envvar + ) + + +IS_PYTHONW = sys.executable and sys.executable.endswith("pythonw.exe") + +T = TypeVar("T", bound=Callable[..., Any]) + + +def catch_config_error(method: T) -> T: + """Method decorator for catching invalid config (Trait/ArgumentErrors) during init. + + On a TraitError (generally caused by bad config), this will print the trait's + message, and exit the app. + + For use on init methods, to prevent invoking excepthook on invalid input. + """ + + @functools.wraps(method) + def inner(app, *args, **kwargs): + try: + return method(app, *args, **kwargs) + except (TraitError, ArgumentError) as e: + app.log.fatal("Bad config encountered during initialization: %s", e) + app.log.debug("Config at the time: %s", app.config) + app.exit(1) + + return cast(T, inner) + + +class ApplicationError(Exception): + pass + + +class LevelFormatter(logging.Formatter): + """Formatter with additional `highlevel` record + + This field is empty if log level is less than highlevel_limit, + otherwise it is formatted with self.highlevel_format. + + Useful for adding 'WARNING' to warning messages, + without adding 'INFO' to info, etc. + """ + + highlevel_limit = logging.WARN + highlevel_format = " %(levelname)s |" + + def format(self, record): + if record.levelno >= self.highlevel_limit: + record.highlevel = self.highlevel_format % record.__dict__ + else: + record.highlevel = "" + return super().format(record) + + +class Application(SingletonConfigurable): + """A singleton application with full configuration support.""" + + # The name of the application, will usually match the name of the command + # line application + name: t.Union[str, Unicode] = Unicode("application") + + # The description of the application that is printed at the beginning + # of the help. + description: t.Union[str, Unicode] = Unicode("This is an application.") + # default section descriptions + option_description: t.Union[str, Unicode] = Unicode(option_description) + keyvalue_description: t.Union[str, Unicode] = Unicode(keyvalue_description) + subcommand_description: t.Union[str, Unicode] = Unicode(subcommand_description) + + python_config_loader_class = PyFileConfigLoader + json_config_loader_class = JSONFileConfigLoader + + # The usage and example string that goes at the end of the help string. + examples: t.Union[str, Unicode] = Unicode() + + # A sequence of Configurable subclasses whose config=True attributes will + # be exposed at the command line. + classes: t.List[t.Type[t.Any]] = [] + + def _classes_inc_parents(self, classes=None): + """Iterate through configurable classes, including configurable parents + + :param classes: + The list of classes to iterate; if not set, uses :attr:`classes`. + + Children should always be after parents, and each class should only be + yielded once. + """ + if classes is None: + classes = self.classes + + seen = set() + for c in classes: + # We want to sort parents before children, so we reverse the MRO + for parent in reversed(c.mro()): + if issubclass(parent, Configurable) and (parent not in seen): + seen.add(parent) + yield parent + + # The version string of this application. + version: t.Union[str, Unicode] = Unicode("0.0") + + # the argv used to initialize the application + argv: t.Union[t.List[str], List] = List() + + # Whether failing to load config files should prevent startup + raise_config_file_errors: t.Union[bool, Bool] = Bool( + TRAITLETS_APPLICATION_RAISE_CONFIG_FILE_ERROR + ) + + # The log level for the application + log_level: t.Union[str, int, Enum] = Enum( + (0, 10, 20, 30, 40, 50, "DEBUG", "INFO", "WARN", "ERROR", "CRITICAL"), + default_value=logging.WARN, + help="Set the log level by value or name.", + ).tag(config=True) + + _log_formatter_cls = LevelFormatter + + log_datefmt: t.Union[str, Unicode] = Unicode( + "%Y-%m-%d %H:%M:%S", help="The date format used by logging formatters for %(asctime)s" + ).tag(config=True) + + log_format: t.Union[str, Unicode] = Unicode( + "[%(name)s]%(highlevel)s %(message)s", + help="The Logging format template", + ).tag(config=True) + + def get_default_logging_config(self): + """Return the base logging configuration. + + The default is to log to stderr using a StreamHandler, if no default + handler already exists. + + The log handler level starts at logging.WARN, but this can be adjusted + by setting the ``log_level`` attribute. + + The ``logging_config`` trait is merged into this allowing for finer + control of logging. + + """ + config: t.Dict[str, t.Any] = { + "version": 1, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "console", + "level": logging.getLevelName(self.log_level), + "stream": "ext://sys.stderr", + }, + }, + "formatters": { + "console": { + "class": ( + f"{self._log_formatter_cls.__module__}" + f".{self._log_formatter_cls.__name__}" + ), + "format": self.log_format, + "datefmt": self.log_datefmt, + }, + }, + "loggers": { + self.__class__.__name__: { + "level": "DEBUG", + "handlers": ["console"], + } + }, + "disable_existing_loggers": False, + } + + if IS_PYTHONW: + # disable logging + # (this should really go to a file, but file-logging is only + # hooked up in parallel applications) + del config["handlers"] + del config["loggers"] + + return config + + @observe("log_datefmt", "log_format", "log_level", "logging_config") + def _observe_logging_change(self, change): + # convert log level strings to ints + log_level = self.log_level + if isinstance(log_level, str): + self.log_level = getattr(logging, log_level) + self._configure_logging() + + @observe("log", type="default") + def _observe_logging_default(self, change): + self._configure_logging() + + def _configure_logging(self): + config = self.get_default_logging_config() + nested_update(config, self.logging_config or {}) + dictConfig(config) + # make a note that we have configured logging + self._logging_configured = True + + @default("log") + def _log_default(self): + """Start logging for this application.""" + log = logging.getLogger(self.__class__.__name__) + log.propagate = False + _log = log # copied from Logger.hasHandlers() (new in Python 3.2) + while _log: + if _log.handlers: + return log + if not _log.propagate: + break + else: + _log = _log.parent # type:ignore[assignment] + return log + + logging_config = Dict( + help=""" + Configure additional log handlers. + + The default stderr logs handler is configured by the + log_level, log_datefmt and log_format settings. + + This configuration can be used to configure additional handlers + (e.g. to output the log to a file) or for finer control over the + default handlers. + + If provided this should be a logging configuration dictionary, for + more information see: + https://docs.python.org/3/library/logging.config.html#logging-config-dictschema + + This dictionary is merged with the base logging configuration which + defines the following: + + * A logging formatter intended for interactive use called + ``console``. + * A logging handler that writes to stderr called + ``console`` which uses the formatter ``console``. + * A logger with the name of this application set to ``DEBUG`` + level. + + This example adds a new handler that writes to a file: + + .. code-block:: python + + c.Application.logging_config = { + 'handlers': { + 'file': { + 'class': 'logging.FileHandler', + 'level': 'DEBUG', + 'filename': '<path/to/file>', + } + }, + 'loggers': { + '<application-name>': { + 'level': 'DEBUG', + # NOTE: if you don't list the default "console" + # handler here then it will be disabled + 'handlers': ['console', 'file'], + }, + } + } + + """, + ).tag(config=True) + + #: the alias map for configurables + #: Keys might strings or tuples for additional options; single-letter alias accessed like `-v`. + #: Values might be like "Class.trait" strings of two-tuples: (Class.trait, help-text), + # or just the "Class.trait" string, in which case the help text is inferred from the + # corresponding trait + aliases: t.Dict[t.Union[str, t.Tuple[str, ...]], t.Union[str, t.Tuple[str, str]]] = { + "log-level": "Application.log_level" + } + + # flags for loading Configurables or store_const style flags + # flags are loaded from this dict by '--key' flags + # this must be a dict of two-tuples, the first element being the Config/dict + # and the second being the help string for the flag + flags: t.Dict[ + t.Union[str, t.Tuple[str, ...]], t.Tuple[t.Union[t.Dict[str, t.Any], Config], str] + ] = { + "debug": ( + { + "Application": { + "log_level": logging.DEBUG, + }, + }, + "Set log-level to debug, for the most verbose logging.", + ), + "show-config": ( + { + "Application": { + "show_config": True, + }, + }, + "Show the application's configuration (human-readable format)", + ), + "show-config-json": ( + { + "Application": { + "show_config_json": True, + }, + }, + "Show the application's configuration (json format)", + ), + } + + # subcommands for launching other applications + # if this is not empty, this will be a parent Application + # this must be a dict of two-tuples, + # the first element being the application class/import string + # and the second being the help string for the subcommand + subcommands: t.Union[t.Dict[str, t.Tuple[t.Any, str]], Dict] = Dict() + # parse_command_line will initialize a subapp, if requested + subapp = Instance("traitlets.config.application.Application", allow_none=True) + + # extra command-line arguments that don't set config values + extra_args: t.Union[t.List[str], List] = List(Unicode()) + + cli_config = Instance( + Config, + (), + {}, + help="""The subset of our configuration that came from the command-line + + We re-load this configuration after loading config files, + to ensure that it maintains highest priority. + """, + ) + + _loaded_config_files = List() + + show_config: t.Union[bool, Bool] = Bool( + help="Instead of starting the Application, dump configuration to stdout" + ).tag(config=True) + + show_config_json: t.Union[bool, Bool] = Bool( + help="Instead of starting the Application, dump configuration to stdout (as JSON)" + ).tag(config=True) + + @observe("show_config_json") + def _show_config_json_changed(self, change): + self.show_config = change.new + + @observe("show_config") + def _show_config_changed(self, change): + if change.new: + self._save_start = self.start + self.start = self.start_show_config # type:ignore[assignment] + + def __init__(self, **kwargs): + SingletonConfigurable.__init__(self, **kwargs) + # Ensure my class is in self.classes, so my attributes appear in command line + # options and config files. + cls = self.__class__ + if cls not in self.classes: + if self.classes is cls.classes: + # class attr, assign instead of insert + self.classes = [cls] + self.classes + else: + self.classes.insert(0, self.__class__) + + @observe("config") + @observe_compat + def _config_changed(self, change): + super()._config_changed(change) + self.log.debug("Config changed: %r", change.new) + + @catch_config_error + def initialize(self, argv=None): + """Do the basic steps to configure me. + + Override in subclasses. + """ + self.parse_command_line(argv) + + def start(self): + """Start the app mainloop. + + Override in subclasses. + """ + if self.subapp is not None: + return self.subapp.start() + + def start_show_config(self): + """start function used when show_config is True""" + config = self.config.copy() + # exclude show_config flags from displayed config + for cls in self.__class__.mro(): + if cls.__name__ in config: + cls_config = config[cls.__name__] + cls_config.pop("show_config", None) + cls_config.pop("show_config_json", None) + + if self.show_config_json: + json.dump(config, sys.stdout, indent=1, sort_keys=True, default=repr) + # add trailing newline + sys.stdout.write("\n") + return + + if self._loaded_config_files: + print("Loaded config files:") + for f in self._loaded_config_files: + print(" " + f) + print() + + for classname in sorted(config): + class_config = config[classname] + if not class_config: + continue + print(classname) + pformat_kwargs: t.Dict[str, t.Any] = dict(indent=4, compact=True) + + for traitname in sorted(class_config): + value = class_config[traitname] + print( + " .{} = {}".format( + traitname, + pprint.pformat(value, **pformat_kwargs), + ) + ) + + def print_alias_help(self): + """Print the alias parts of the help.""" + print("\n".join(self.emit_alias_help())) + + def emit_alias_help(self): + """Yield the lines for alias part of the help.""" + if not self.aliases: + return + + classdict = {} + for cls in self.classes: + # include all parents (up to, but excluding Configurable) in available names + for c in cls.mro()[:-3]: + classdict[c.__name__] = c + + fhelp: t.Optional[str] + for alias, longname in self.aliases.items(): + try: + if isinstance(longname, tuple): + longname, fhelp = longname + else: + fhelp = None + classname, traitname = longname.split(".")[-2:] + longname = classname + "." + traitname + cls = classdict[classname] + + trait = cls.class_traits(config=True)[traitname] + fhelp = cls.class_get_trait_help(trait, helptext=fhelp).splitlines() + + if not isinstance(alias, tuple): + alias = (alias,) + alias = sorted(alias, key=len) # type:ignore[assignment] + alias = ", ".join(("--%s" if len(m) > 1 else "-%s") % m for m in alias) + + # reformat first line + assert fhelp is not None + fhelp[0] = fhelp[0].replace("--" + longname, alias) # type:ignore + yield from fhelp + yield indent("Equivalent to: [--%s]" % longname) + except Exception as ex: + self.log.error("Failed collecting help-message for alias %r, due to: %s", alias, ex) + raise + + def print_flag_help(self): + """Print the flag part of the help.""" + print("\n".join(self.emit_flag_help())) + + def emit_flag_help(self): + """Yield the lines for the flag part of the help.""" + if not self.flags: + return + + for flags, (cfg, fhelp) in self.flags.items(): + try: + if not isinstance(flags, tuple): + flags = (flags,) + flags = sorted(flags, key=len) # type:ignore[assignment] + flags = ", ".join(("--%s" if len(m) > 1 else "-%s") % m for m in flags) + yield flags + yield indent(dedent(fhelp.strip())) + cfg_list = " ".join( + f"--{clname}.{prop}={val}" + for clname, props_dict in cfg.items() + for prop, val in props_dict.items() + ) + cfg_txt = "Equivalent to: [%s]" % cfg_list + yield indent(dedent(cfg_txt)) + except Exception as ex: + self.log.error("Failed collecting help-message for flag %r, due to: %s", flags, ex) + raise + + def print_options(self): + """Print the options part of the help.""" + print("\n".join(self.emit_options_help())) + + def emit_options_help(self): + """Yield the lines for the options part of the help.""" + if not self.flags and not self.aliases: + return + header = "Options" + yield header + yield "=" * len(header) + for p in wrap_paragraphs(self.option_description): + yield p + yield "" + + yield from self.emit_flag_help() + yield from self.emit_alias_help() + yield "" + + def print_subcommands(self): + """Print the subcommand part of the help.""" + print("\n".join(self.emit_subcommands_help())) + + def emit_subcommands_help(self): + """Yield the lines for the subcommand part of the help.""" + if not self.subcommands: + return + + header = "Subcommands" + yield header + yield "=" * len(header) + for p in wrap_paragraphs(self.subcommand_description.format(app=self.name)): + yield p + yield "" + for subc, (_, help) in self.subcommands.items(): + yield subc + if help: + yield indent(dedent(help.strip())) + yield "" + + def emit_help_epilogue(self, classes): + """Yield the very bottom lines of the help message. + + If classes=False (the default), print `--help-all` msg. + """ + if not classes: + yield "To see all available configurables, use `--help-all`." + yield "" + + def print_help(self, classes=False): + """Print the help for each Configurable class in self.classes. + + If classes=False (the default), only flags and aliases are printed. + """ + print("\n".join(self.emit_help(classes=classes))) + + def emit_help(self, classes=False): + """Yield the help-lines for each Configurable class in self.classes. + + If classes=False (the default), only flags and aliases are printed. + """ + yield from self.emit_description() + yield from self.emit_subcommands_help() + yield from self.emit_options_help() + + if classes: + help_classes = self._classes_with_config_traits() + if help_classes: + yield "Class options" + yield "=============" + for p in wrap_paragraphs(self.keyvalue_description): + yield p + yield "" + + for cls in help_classes: + yield cls.class_get_help() + yield "" + yield from self.emit_examples() + + yield from self.emit_help_epilogue(classes) + + def document_config_options(self): + """Generate rST format documentation for the config options this application + + Returns a multiline string. + """ + return "\n".join(c.class_config_rst_doc() for c in self._classes_inc_parents()) + + def print_description(self): + """Print the application description.""" + print("\n".join(self.emit_description())) + + def emit_description(self): + """Yield lines with the application description.""" + for p in wrap_paragraphs(self.description or self.__doc__ or ""): + yield p + yield "" + + def print_examples(self): + """Print usage and examples (see `emit_examples()`).""" + print("\n".join(self.emit_examples())) + + def emit_examples(self): + """Yield lines with the usage and examples. + + This usage string goes at the end of the command line help string + and should contain examples of the application's usage. + """ + if self.examples: + yield "Examples" + yield "--------" + yield "" + yield indent(dedent(self.examples.strip())) + yield "" + + def print_version(self): + """Print the version string.""" + print(self.version) + + @catch_config_error + def initialize_subcommand(self, subc, argv=None): + """Initialize a subcommand with argv.""" + val = self.subcommands.get(subc) + assert val is not None + subapp, _ = val + + if isinstance(subapp, str): + subapp = import_item(subapp) + + # Cannot issubclass() on a non-type (SOhttp://stackoverflow.com/questions/8692430) + if isinstance(subapp, type) and issubclass(subapp, Application): + # Clear existing instances before... + self.__class__.clear_instance() + # instantiating subapp... + self.subapp = subapp.instance(parent=self) + elif callable(subapp): + # or ask factory to create it... + self.subapp = subapp(self) # type:ignore[call-arg] + else: + raise AssertionError("Invalid mappings for subcommand '%s'!" % subc) + + # ... and finally initialize subapp. + self.subapp.initialize(argv) + + def flatten_flags(self): + """Flatten flags and aliases for loaders, so cl-args override as expected. + + This prevents issues such as an alias pointing to InteractiveShell, + but a config file setting the same trait in TerminalInteraciveShell + getting inappropriate priority over the command-line arg. + Also, loaders expect ``(key: longname)`` and not ``key: (longname, help)`` items. + + Only aliases with exactly one descendent in the class list + will be promoted. + + """ + # build a tree of classes in our list that inherit from a particular + # it will be a dict by parent classname of classes in our list + # that are descendents + mro_tree = defaultdict(list) + for cls in self.classes: + clsname = cls.__name__ + for parent in cls.mro()[1:-3]: + # exclude cls itself and Configurable,HasTraits,object + mro_tree[parent.__name__].append(clsname) + # flatten aliases, which have the form: + # { 'alias' : 'Class.trait' } + aliases: t.Dict[str, str] = {} + for alias, longname in self.aliases.items(): + if isinstance(longname, tuple): + longname, _ = longname + cls, trait = longname.split(".", 1) # type:ignore + children = mro_tree[cls] # type:ignore[index] + if len(children) == 1: + # exactly one descendent, promote alias + cls = children[0] # type:ignore[assignment] + if not isinstance(aliases, tuple): + alias = (alias,) # type:ignore[assignment] + for al in alias: + aliases[al] = ".".join([cls, trait]) # type:ignore[list-item] + + # flatten flags, which are of the form: + # { 'key' : ({'Cls' : {'trait' : value}}, 'help')} + flags = {} + for key, (flagdict, help) in self.flags.items(): + newflag: t.Dict[t.Any, t.Any] = {} + for cls, subdict in flagdict.items(): # type:ignore + children = mro_tree[cls] # type:ignore[index] + # exactly one descendent, promote flag section + if len(children) == 1: + cls = children[0] # type:ignore[assignment] + + if cls in newflag: + newflag[cls].update(subdict) + else: + newflag[cls] = subdict + + if not isinstance(key, tuple): + key = (key,) + for k in key: + flags[k] = (newflag, help) + return flags, aliases + + def _create_loader(self, argv, aliases, flags, classes): + return KVArgParseConfigLoader( + argv, aliases, flags, classes=classes, log=self.log, subcommands=self.subcommands + ) + + @classmethod + def _get_sys_argv(cls, check_argcomplete: bool = False) -> t.List[str]: + """Get `sys.argv` or equivalent from `argcomplete` + + `argcomplete`'s strategy is to call the python script with no arguments, + so ``len(sys.argv) == 1``, and run until the `ArgumentParser` is constructed + and determine what completions are available. + + On the other hand, `traitlet`'s subcommand-handling strategy is to check + ``sys.argv[1]`` and see if it matches a subcommand, and if so then dynamically + load the subcommand app and initialize it with ``sys.argv[1:]``. + + This helper method helps to take the current tokens for `argcomplete` and pass + them through as `argv`. + """ + if check_argcomplete and "_ARGCOMPLETE" in os.environ: + try: + from traitlets.config.argcomplete_config import get_argcomplete_cwords + + cwords = get_argcomplete_cwords() + assert cwords is not None + return cwords + except (ImportError, ModuleNotFoundError): + pass + return sys.argv + + @classmethod + def _handle_argcomplete_for_subcommand(cls): + """Helper for `argcomplete` to recognize `traitlets` subcommands + + `argcomplete` does not know that `traitlets` has already consumed subcommands, + as it only "sees" the final `argparse.ArgumentParser` that is constructed. + (Indeed `KVArgParseConfigLoader` does not get passed subcommands at all currently.) + We explicitly manipulate the environment variables used internally by `argcomplete` + to get it to skip over the subcommand tokens. + """ + if "_ARGCOMPLETE" not in os.environ: + return + + try: + from traitlets.config.argcomplete_config import increment_argcomplete_index + + increment_argcomplete_index() + except (ImportError, ModuleNotFoundError): + pass + + @catch_config_error + def parse_command_line(self, argv=None): + """Parse the command line arguments.""" + assert not isinstance(argv, str) + if argv is None: + argv = self._get_sys_argv(check_argcomplete=bool(self.subcommands))[1:] + self.argv = [cast_unicode(arg) for arg in argv] + + if argv and argv[0] == "help": + # turn `ipython help notebook` into `ipython notebook -h` + argv = argv[1:] + ["-h"] + + if self.subcommands and len(argv) > 0: + # we have subcommands, and one may have been specified + subc, subargv = argv[0], argv[1:] + if re.match(r"^\w(\-?\w)*$", subc) and subc in self.subcommands: + # it's a subcommand, and *not* a flag or class parameter + self._handle_argcomplete_for_subcommand() + return self.initialize_subcommand(subc, subargv) + + # Arguments after a '--' argument are for the script IPython may be + # about to run, not IPython iteslf. For arguments parsed here (help and + # version), we want to only search the arguments up to the first + # occurrence of '--', which we're calling interpreted_argv. + try: + interpreted_argv = argv[: argv.index("--")] + except ValueError: + interpreted_argv = argv + + if any(x in interpreted_argv for x in ("-h", "--help-all", "--help")): + self.print_help("--help-all" in interpreted_argv) + self.exit(0) + + if "--version" in interpreted_argv or "-V" in interpreted_argv: + self.print_version() + self.exit(0) + + # flatten flags&aliases, so cl-args get appropriate priority: + flags, aliases = self.flatten_flags() + classes = tuple(self._classes_with_config_traits()) + loader = self._create_loader(argv, aliases, flags, classes=classes) + try: + self.cli_config = deepcopy(loader.load_config()) + except SystemExit: + # traitlets 5: no longer print help output on error + # help output is huge, and comes after the error + raise + self.update_config(self.cli_config) + # store unparsed args in extra_args + self.extra_args = loader.extra_args + + @classmethod + def _load_config_files(cls, basefilename, path=None, log=None, raise_config_file_errors=False): + """Load config files (py,json) by filename and path. + + yield each config object in turn. + """ + + if not isinstance(path, list): + path = [path] + for current in reversed(path): + # path list is in descending priority order, so load files backwards: + pyloader = cls.python_config_loader_class(basefilename + ".py", path=current, log=log) + if log: + log.debug("Looking for %s in %s", basefilename, current or os.getcwd()) + jsonloader = cls.json_config_loader_class(basefilename + ".json", path=current, log=log) + loaded: t.List[t.Any] = [] + filenames: t.List[str] = [] + for loader in [pyloader, jsonloader]: + config = None + try: + config = loader.load_config() + except ConfigFileNotFound: + pass + except Exception: + # try to get the full filename, but it will be empty in the + # unlikely event that the error raised before filefind finished + filename = loader.full_filename or basefilename + # problem while running the file + if raise_config_file_errors: + raise + if log: + log.error("Exception while loading config file %s", filename, exc_info=True) + else: + if log: + log.debug("Loaded config file: %s", loader.full_filename) + if config: + for filename, earlier_config in zip(filenames, loaded): + collisions = earlier_config.collisions(config) + if collisions and log: + log.warning( + "Collisions detected in {0} and {1} config files." + " {1} has higher priority: {2}".format( + filename, + loader.full_filename, + json.dumps(collisions, indent=2), + ) + ) + yield (config, loader.full_filename) + loaded.append(config) + filenames.append(loader.full_filename) + + @property + def loaded_config_files(self): + """Currently loaded configuration files""" + return self._loaded_config_files[:] + + @catch_config_error + def load_config_file(self, filename, path=None): + """Load config files by filename and path.""" + filename, ext = os.path.splitext(filename) + new_config = Config() + for (config, fname) in self._load_config_files( + filename, + path=path, + log=self.log, + raise_config_file_errors=self.raise_config_file_errors, + ): + new_config.merge(config) + if ( + fname not in self._loaded_config_files + ): # only add to list of loaded files if not previously loaded + self._loaded_config_files.append(fname) + # add self.cli_config to preserve CLI config priority + new_config.merge(self.cli_config) + self.update_config(new_config) + + def _classes_with_config_traits(self, classes=None): + """ + Yields only classes with configurable traits, and their subclasses. + + :param classes: + The list of classes to iterate; if not set, uses :attr:`classes`. + + Thus, produced sample config-file will contain all classes + on which a trait-value may be overridden: + + - either on the class owning the trait, + - or on its subclasses, even if those subclasses do not define + any traits themselves. + """ + if classes is None: + classes = self.classes + + cls_to_config = OrderedDict( + (cls, bool(cls.class_own_traits(config=True))) + for cls in self._classes_inc_parents(classes) + ) + + def is_any_parent_included(cls): + return any(b in cls_to_config and cls_to_config[b] for b in cls.__bases__) + + # Mark "empty" classes for inclusion if their parents own-traits, + # and loop until no more classes gets marked. + # + while True: + to_incl_orig = cls_to_config.copy() + cls_to_config = OrderedDict( + (cls, inc_yes or is_any_parent_included(cls)) + for cls, inc_yes in cls_to_config.items() + ) + if cls_to_config == to_incl_orig: + break + for cl, inc_yes in cls_to_config.items(): + if inc_yes: + yield cl + + def generate_config_file(self, classes=None): + """generate default config file from Configurables""" + lines = ["# Configuration file for %s." % self.name] + lines.append("") + lines.append("c = get_config() #" + "noqa") + lines.append("") + classes = self.classes if classes is None else classes + config_classes = list(self._classes_with_config_traits(classes)) + for cls in config_classes: + lines.append(cls.class_config_section(config_classes)) + return "\n".join(lines) + + def close_handlers(self): + if getattr(self, "_logging_configured", False): + # don't attempt to close handlers unless they have been opened + # (note accessing self.log.handlers will create handlers if they + # have not yet been initialised) + for handler in self.log.handlers: + with suppress(Exception): + handler.close() + self._logging_configured = False + + def exit(self, exit_status=0): + self.log.debug("Exiting application: %s" % self.name) + self.close_handlers() + sys.exit(exit_status) + + def __del__(self): + self.close_handlers() + + @classmethod + def launch_instance(cls, argv=None, **kwargs): + """Launch a global instance of this Application + + If a global instance already exists, this reinitializes and starts it + """ + app = cls.instance(**kwargs) + app.initialize(argv) + app.start() + + +# ----------------------------------------------------------------------------- +# utility functions, for convenience +# ----------------------------------------------------------------------------- + +default_aliases = Application.aliases +default_flags = Application.flags + + +def boolean_flag(name, configurable, set_help="", unset_help=""): + """Helper for building basic --trait, --no-trait flags. + + Parameters + ---------- + name : str + The name of the flag. + configurable : str + The 'Class.trait' string of the trait to be set/unset with the flag + set_help : unicode + help string for --name flag + unset_help : unicode + help string for --no-name flag + + Returns + ------- + cfg : dict + A dict with two keys: 'name', and 'no-name', for setting and unsetting + the trait, respectively. + """ + # default helpstrings + set_help = set_help or "set %s=True" % configurable + unset_help = unset_help or "set %s=False" % configurable + + cls, trait = configurable.split(".") + + setter = {cls: {trait: True}} + unsetter = {cls: {trait: False}} + return {name: (setter, set_help), "no-" + name: (unsetter, unset_help)} + + +def get_config(): + """Get the config object for the global Application instance, if there is one + + otherwise return an empty config object + """ + if Application.initialized(): + return Application.instance().config + else: + return Config() + + +if __name__ == "__main__": + Application.launch_instance() diff --git a/contrib/python/traitlets/py3/traitlets/config/argcomplete_config.py b/contrib/python/traitlets/py3/traitlets/config/argcomplete_config.py new file mode 100644 index 0000000000..afda7f86d2 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/argcomplete_config.py @@ -0,0 +1,220 @@ +"""Helper utilities for integrating argcomplete with traitlets""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + + +import argparse +import os +import typing as t + +try: + import argcomplete # type: ignore[import] + from argcomplete import CompletionFinder +except ImportError: + # This module and its utility methods are written to not crash even + # if argcomplete is not installed. + class StubModule: + def __getattr__(self, attr): + if not attr.startswith("__"): + raise ModuleNotFoundError("No module named 'argcomplete'") + raise AttributeError(f"argcomplete stub module has no attribute '{attr}'") + + argcomplete = StubModule() + CompletionFinder = object + + +def get_argcomplete_cwords() -> t.Optional[t.List[str]]: + """Get current words prior to completion point + + This is normally done in the `argcomplete.CompletionFinder` constructor, + but is exposed here to allow `traitlets` to follow dynamic code-paths such + as determining whether to evaluate a subcommand. + """ + if "_ARGCOMPLETE" not in os.environ: + return None + + comp_line = os.environ["COMP_LINE"] + comp_point = int(os.environ["COMP_POINT"]) + # argcomplete.debug("splitting COMP_LINE for:", comp_line, comp_point) + comp_words: t.List[str] + try: + ( + cword_prequote, + cword_prefix, + cword_suffix, + comp_words, + last_wordbreak_pos, + ) = argcomplete.split_line(comp_line, comp_point) + except ModuleNotFoundError: + return None + + # _ARGCOMPLETE is set by the shell script to tell us where comp_words + # should start, based on what we're completing. + # 1: <script> [args] + # 2: python <script> [args] + # 3: python -m <module> [args] + start = int(os.environ["_ARGCOMPLETE"]) - 1 + comp_words = comp_words[start:] + + # argcomplete.debug("prequote=", cword_prequote, "prefix=", cword_prefix, "suffix=", cword_suffix, "words=", comp_words, "last=", last_wordbreak_pos) + return comp_words + + +def increment_argcomplete_index(): + """Assumes ``$_ARGCOMPLETE`` is set and `argcomplete` is importable + + Increment the index pointed to by ``$_ARGCOMPLETE``, which is used to + determine which word `argcomplete` should start evaluating the command-line. + This may be useful to "inform" `argcomplete` that we have already evaluated + the first word as a subcommand. + """ + try: + os.environ["_ARGCOMPLETE"] = str(int(os.environ["_ARGCOMPLETE"]) + 1) + except Exception: + try: + argcomplete.debug("Unable to increment $_ARGCOMPLETE", os.environ["_ARGCOMPLETE"]) + except (KeyError, ModuleNotFoundError): + pass + + +class ExtendedCompletionFinder(CompletionFinder): + """An extension of CompletionFinder which dynamically completes class-trait based options + + This finder adds a few functionalities: + + 1. When completing options, it will add ``--Class.`` to the list of completions, for each + class in `Application.classes` that could complete the current option. + 2. If it detects that we are currently trying to complete an option related to ``--Class.``, + it will add the corresponding config traits of Class to the `ArgumentParser` instance, + so that the traits' completers can be used. + 3. If there are any subcommands, they are added as completions for the first word + + Note that we are avoiding adding all config traits of all classes to the `ArgumentParser`, + which would be easier but would add more runtime overhead and would also make completions + appear more spammy. + + These changes do require using the internals of `argcomplete.CompletionFinder`. + """ + + _parser: argparse.ArgumentParser + config_classes: t.List[t.Any] = [] # Configurables + subcommands: t.List[str] = [] + + def match_class_completions(self, cword_prefix: str) -> t.List[t.Tuple[t.Any, str]]: + """Match the word to be completed against our Configurable classes + + Check if cword_prefix could potentially match against --{class}. for any class + in Application.classes. + """ + class_completions = [(cls, f"--{cls.__name__}.") for cls in self.config_classes] + matched_completions = class_completions + if "." in cword_prefix: + cword_prefix = cword_prefix[: cword_prefix.index(".") + 1] + matched_completions = [(cls, c) for (cls, c) in class_completions if c == cword_prefix] + elif len(cword_prefix) > 0: + matched_completions = [ + (cls, c) for (cls, c) in class_completions if c.startswith(cword_prefix) + ] + return matched_completions + + def inject_class_to_parser(self, cls): + """Add dummy arguments to our ArgumentParser for the traits of this class + + The argparse-based loader currently does not actually add any class traits to + the constructed ArgumentParser, only the flags & aliaes. In order to work nicely + with argcomplete's completers functionality, this method adds dummy arguments + of the form --Class.trait to the ArgumentParser instance. + + This method should be called selectively to reduce runtime overhead and to avoid + spamming options across all of Application.classes. + """ + try: + for traitname, trait in cls.class_traits(config=True).items(): + completer = trait.metadata.get("argcompleter") or getattr( + trait, "argcompleter", None + ) + multiplicity = trait.metadata.get("multiplicity") + self._parser.add_argument( # type: ignore[attr-defined] + f"--{cls.__name__}.{traitname}", + type=str, + help=trait.help, + nargs=multiplicity, + # metavar=traitname, + ).completer = completer + # argcomplete.debug(f"added --{cls.__name__}.{traitname}") + except AttributeError: + pass + + def _get_completions( + self, comp_words: t.List[str], cword_prefix: str, *args: t.Any + ) -> t.List[str]: + """Overriden to dynamically append --Class.trait arguments if appropriate + + Warning: + This does not (currently) support completions of the form + --Class1.Class2.<...>.trait, although this is valid for traitlets. + Part of the reason is that we don't currently have a way to identify + which classes may be used with Class1 as a parent. + + Warning: + This is an internal method in CompletionFinder and so the API might + be subject to drift. + """ + # Try to identify if we are completing something related to --Class. for + # a known Class, if we are then add the Class config traits to our ArgumentParser. + prefix_chars = self._parser.prefix_chars + is_option = len(cword_prefix) > 0 and cword_prefix[0] in prefix_chars + if is_option: + # If we are currently completing an option, check if it could + # match with any of the --Class. completions. If there's exactly + # one matched class, then expand out the --Class.trait options. + matched_completions = self.match_class_completions(cword_prefix) + if len(matched_completions) == 1: + matched_cls = matched_completions[0][0] + self.inject_class_to_parser(matched_cls) + elif len(comp_words) > 0 and "." in comp_words[-1] and not is_option: + # If not an option, perform a hacky check to see if we are completing + # an argument for an already present --Class.trait option. Search backwards + # for last option (based on last word starting with prefix_chars), and see + # if it is of the form --Class.trait. Note that if multiplicity="+", these + # arguments might conflict with positional arguments. + for prev_word in comp_words[::-1]: + if len(prev_word) > 0 and prev_word[0] in prefix_chars: + matched_completions = self.match_class_completions(prev_word) + if matched_completions: + matched_cls = matched_completions[0][0] + self.inject_class_to_parser(matched_cls) + break + + completions: t.List[str] + completions = super()._get_completions(comp_words, cword_prefix, *args) + + # For subcommand-handling: it is difficult to get this to work + # using argparse subparsers, because the ArgumentParser accepts + # arbitrary extra_args, which ends up masking subparsers. + # Instead, check if comp_words only consists of the script, + # if so check if any subcommands start with cword_prefix. + if self.subcommands and len(comp_words) == 1: + argcomplete.debug("Adding subcommands for", cword_prefix) + completions.extend(subc for subc in self.subcommands if subc.startswith(cword_prefix)) + + return completions + + def _get_option_completions( + self, parser: argparse.ArgumentParser, cword_prefix: str + ) -> t.List[str]: + """Overriden to add --Class. completions when appropriate""" + completions: t.List[str] + completions = super()._get_option_completions(parser, cword_prefix) + if cword_prefix.endswith("."): + return completions + + matched_completions = self.match_class_completions(cword_prefix) + if len(matched_completions) > 1: + completions.extend(opt for cls, opt in matched_completions) + # If there is exactly one match, we would expect it to have aleady + # been handled by the options dynamically added in _get_completions(). + # However, maybe there's an edge cases missed here, for example if the + # matched class has no configurable traits. + return completions diff --git a/contrib/python/traitlets/py3/traitlets/config/configurable.py b/contrib/python/traitlets/py3/traitlets/config/configurable.py new file mode 100644 index 0000000000..effa8e429d --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/configurable.py @@ -0,0 +1,568 @@ +"""A base class for objects that are configurable.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + + +import logging +import warnings +from copy import deepcopy +from textwrap import dedent + +from traitlets.traitlets import ( + Any, + Container, + Dict, + HasTraits, + Instance, + default, + observe, + observe_compat, + validate, +) +from traitlets.utils.text import indent, wrap_paragraphs + +from .loader import Config, DeferredConfig, LazyConfigValue, _is_section_key + +# ----------------------------------------------------------------------------- +# Helper classes for Configurables +# ----------------------------------------------------------------------------- + + +class ConfigurableError(Exception): + pass + + +class MultipleInstanceError(ConfigurableError): + pass + + +# ----------------------------------------------------------------------------- +# Configurable implementation +# ----------------------------------------------------------------------------- + + +class Configurable(HasTraits): + + config = Instance(Config, (), {}) + parent = Instance("traitlets.config.configurable.Configurable", allow_none=True) + + def __init__(self, **kwargs): + """Create a configurable given a config config. + + Parameters + ---------- + config : Config + If this is empty, default values are used. If config is a + :class:`Config` instance, it will be used to configure the + instance. + parent : Configurable instance, optional + The parent Configurable instance of this object. + + Notes + ----- + Subclasses of Configurable must call the :meth:`__init__` method of + :class:`Configurable` *before* doing anything else and using + :func:`super`:: + + class MyConfigurable(Configurable): + def __init__(self, config=None): + super(MyConfigurable, self).__init__(config=config) + # Then any other code you need to finish initialization. + + This ensures that instances will be configured properly. + """ + parent = kwargs.pop("parent", None) + if parent is not None: + # config is implied from parent + if kwargs.get("config", None) is None: + kwargs["config"] = parent.config + self.parent = parent + + config = kwargs.pop("config", None) + + # load kwarg traits, other than config + super().__init__(**kwargs) + + # record traits set by config + config_override_names = set() + + def notice_config_override(change): + """Record traits set by both config and kwargs. + + They will need to be overridden again after loading config. + """ + if change.name in kwargs: + config_override_names.add(change.name) + + self.observe(notice_config_override) + + # load config + if config is not None: + # We used to deepcopy, but for now we are trying to just save + # by reference. This *could* have side effects as all components + # will share config. In fact, I did find such a side effect in + # _config_changed below. If a config attribute value was a mutable type + # all instances of a component were getting the same copy, effectively + # making that a class attribute. + # self.config = deepcopy(config) + self.config = config + else: + # allow _config_default to return something + self._load_config(self.config) + self.unobserve(notice_config_override) + + for name in config_override_names: + setattr(self, name, kwargs[name]) + + # ------------------------------------------------------------------------- + # Static trait notifiations + # ------------------------------------------------------------------------- + + @classmethod + def section_names(cls): + """return section names as a list""" + return [ + c.__name__ + for c in reversed(cls.__mro__) + if issubclass(c, Configurable) and issubclass(cls, c) + ] + + def _find_my_config(self, cfg): + """extract my config from a global Config object + + will construct a Config object of only the config values that apply to me + based on my mro(), as well as those of my parent(s) if they exist. + + If I am Bar and my parent is Foo, and their parent is Tim, + this will return merge following config sections, in this order:: + + [Bar, Foo.Bar, Tim.Foo.Bar] + + With the last item being the highest priority. + """ + cfgs = [cfg] + if self.parent: + cfgs.append(self.parent._find_my_config(cfg)) + my_config = Config() + for c in cfgs: + for sname in self.section_names(): + # Don't do a blind getattr as that would cause the config to + # dynamically create the section with name Class.__name__. + if c._has_section(sname): + my_config.merge(c[sname]) + return my_config + + def _load_config(self, cfg, section_names=None, traits=None): + """load traits from a Config object""" + + if traits is None: + traits = self.traits(config=True) + if section_names is None: + section_names = self.section_names() + + my_config = self._find_my_config(cfg) + + # hold trait notifications until after all config has been loaded + with self.hold_trait_notifications(): + for name, config_value in my_config.items(): + if name in traits: + if isinstance(config_value, LazyConfigValue): + # ConfigValue is a wrapper for using append / update on containers + # without having to copy the initial value + initial = getattr(self, name) + config_value = config_value.get_value(initial) + elif isinstance(config_value, DeferredConfig): + # DeferredConfig tends to come from CLI/environment variables + config_value = config_value.get_value(traits[name]) + # We have to do a deepcopy here if we don't deepcopy the entire + # config object. If we don't, a mutable config_value will be + # shared by all instances, effectively making it a class attribute. + setattr(self, name, deepcopy(config_value)) + elif not _is_section_key(name) and not isinstance(config_value, Config): + from difflib import get_close_matches + + if isinstance(self, LoggingConfigurable): + warn = self.log.warning + else: + warn = lambda msg: warnings.warn(msg, stacklevel=9) # noqa[E371] + matches = get_close_matches(name, traits) + msg = "Config option `{option}` not recognized by `{klass}`.".format( + option=name, klass=self.__class__.__name__ + ) + + if len(matches) == 1: + msg += f" Did you mean `{matches[0]}`?" + elif len(matches) >= 1: + msg += " Did you mean one of: `{matches}`?".format( + matches=", ".join(sorted(matches)) + ) + warn(msg) + + @observe("config") + @observe_compat + def _config_changed(self, change): + """Update all the class traits having ``config=True`` in metadata. + + For any class trait with a ``config`` metadata attribute that is + ``True``, we update the trait with the value of the corresponding + config entry. + """ + # Get all traits with a config metadata entry that is True + traits = self.traits(config=True) + + # We auto-load config section for this class as well as any parent + # classes that are Configurable subclasses. This starts with Configurable + # and works down the mro loading the config for each section. + section_names = self.section_names() + self._load_config(change.new, traits=traits, section_names=section_names) + + def update_config(self, config): + """Update config and load the new values""" + # traitlets prior to 4.2 created a copy of self.config in order to trigger change events. + # Some projects (IPython < 5) relied upon one side effect of this, + # that self.config prior to update_config was not modified in-place. + # For backward-compatibility, we must ensure that self.config + # is a new object and not modified in-place, + # but config consumers should not rely on this behavior. + self.config = deepcopy(self.config) + # load config + self._load_config(config) + # merge it into self.config + self.config.merge(config) + # TODO: trigger change event if/when dict-update change events take place + # DO NOT trigger full trait-change + + @classmethod + def class_get_help(cls, inst=None): + """Get the help string for this class in ReST format. + + If `inst` is given, its current trait values will be used in place of + class defaults. + """ + assert inst is None or isinstance(inst, cls) + final_help = [] + base_classes = ", ".join(p.__name__ for p in cls.__bases__) + final_help.append(f"{cls.__name__}({base_classes}) options") + final_help.append(len(final_help[0]) * "-") + for _, v in sorted(cls.class_traits(config=True).items()): + help = cls.class_get_trait_help(v, inst) + final_help.append(help) + return "\n".join(final_help) + + @classmethod + def class_get_trait_help(cls, trait, inst=None, helptext=None): + """Get the helptext string for a single trait. + + :param inst: + If given, its current trait values will be used in place of + the class default. + :param helptext: + If not given, uses the `help` attribute of the current trait. + """ + assert inst is None or isinstance(inst, cls) + lines = [] + header = f"--{cls.__name__}.{trait.name}" + if isinstance(trait, (Container, Dict)): + multiplicity = trait.metadata.get("multiplicity", "append") + if isinstance(trait, Dict): + sample_value = "<key-1>=<value-1>" + else: + sample_value = "<%s-item-1>" % trait.__class__.__name__.lower() + if multiplicity == "append": + header = f"{header}={sample_value}..." + else: + header = f"{header} {sample_value}..." + else: + header = f"{header}=<{trait.__class__.__name__}>" + # header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__) + lines.append(header) + + if helptext is None: + helptext = trait.help + if helptext != "": + helptext = "\n".join(wrap_paragraphs(helptext, 76)) + lines.append(indent(helptext)) + + if "Enum" in trait.__class__.__name__: + # include Enum choices + lines.append(indent("Choices: %s" % trait.info())) + + if inst is not None: + lines.append(indent(f"Current: {getattr(inst, trait.name)!r}")) + else: + try: + dvr = trait.default_value_repr() + except Exception: + dvr = None # ignore defaults we can't construct + if dvr is not None: + if len(dvr) > 64: + dvr = dvr[:61] + "..." + lines.append(indent("Default: %s" % dvr)) + + return "\n".join(lines) + + @classmethod + def class_print_help(cls, inst=None): + """Get the help string for a single trait and print it.""" + print(cls.class_get_help(inst)) + + @classmethod + def _defining_class(cls, trait, classes): + """Get the class that defines a trait + + For reducing redundant help output in config files. + Returns the current class if: + - the trait is defined on this class, or + - the class where it is defined would not be in the config file + + Parameters + ---------- + trait : Trait + The trait to look for + classes : list + The list of other classes to consider for redundancy. + Will return `cls` even if it is not defined on `cls` + if the defining class is not in `classes`. + """ + defining_cls = cls + for parent in cls.mro(): + if ( + issubclass(parent, Configurable) + and parent in classes + and parent.class_own_traits(config=True).get(trait.name, None) is trait + ): + defining_cls = parent + return defining_cls + + @classmethod + def class_config_section(cls, classes=None): + """Get the config section for this class. + + Parameters + ---------- + classes : list, optional + The list of other classes in the config file. + Used to reduce redundant information. + """ + + def c(s): + """return a commented, wrapped block.""" + s = "\n\n".join(wrap_paragraphs(s, 78)) + + return "## " + s.replace("\n", "\n# ") + + # section header + breaker = "#" + "-" * 78 + parent_classes = ", ".join(p.__name__ for p in cls.__bases__ if issubclass(p, Configurable)) + + s = f"# {cls.__name__}({parent_classes}) configuration" + lines = [breaker, s, breaker] + # get the description trait + desc = cls.class_traits().get("description") + if desc: + desc = desc.default_value + if not desc: + # no description from trait, use __doc__ + desc = getattr(cls, "__doc__", "") + if desc: + lines.append(c(desc)) + lines.append("") + + for name, trait in sorted(cls.class_traits(config=True).items()): + default_repr = trait.default_value_repr() + + if classes: + defining_class = cls._defining_class(trait, classes) + else: + defining_class = cls + if defining_class is cls: + # cls owns the trait, show full help + if trait.help: + lines.append(c(trait.help)) + if "Enum" in type(trait).__name__: + # include Enum choices + lines.append("# Choices: %s" % trait.info()) + lines.append("# Default: %s" % default_repr) + else: + # Trait appears multiple times and isn't defined here. + # Truncate help to first line + "See also Original.trait" + if trait.help: + lines.append(c(trait.help.split("\n", 1)[0])) + lines.append(f"# See also: {defining_class.__name__}.{name}") + + lines.append(f"# c.{cls.__name__}.{name} = {default_repr}") + lines.append("") + return "\n".join(lines) + + @classmethod + def class_config_rst_doc(cls): + """Generate rST documentation for this class' config options. + + Excludes traits defined on parent classes. + """ + lines = [] + classname = cls.__name__ + for _, trait in sorted(cls.class_traits(config=True).items()): + ttype = trait.__class__.__name__ + + termline = classname + "." + trait.name + + # Choices or type + if "Enum" in ttype: + # include Enum choices + termline += " : " + trait.info_rst() + else: + termline += " : " + ttype + lines.append(termline) + + # Default value + try: + dvr = trait.default_value_repr() + except Exception: + dvr = None # ignore defaults we can't construct + if dvr is not None: + if len(dvr) > 64: + dvr = dvr[:61] + "..." + # Double up backslashes, so they get to the rendered docs + dvr = dvr.replace("\\n", "\\\\n") + lines.append(indent("Default: ``%s``" % dvr)) + lines.append("") + + help = trait.help or "No description" + lines.append(indent(dedent(help))) + + # Blank line + lines.append("") + + return "\n".join(lines) + + +class LoggingConfigurable(Configurable): + """A parent class for Configurables that log. + + Subclasses have a log trait, and the default behavior + is to get the logger from the currently running Application. + """ + + log = Any(help="Logger or LoggerAdapter instance") + + @validate("log") + def _validate_log(self, proposal): + if not isinstance(proposal.value, (logging.Logger, logging.LoggerAdapter)): + # warn about unsupported type, but be lenient to allow for duck typing + warnings.warn( + f"{self.__class__.__name__}.log should be a Logger or LoggerAdapter," + f" got {proposal.value}." + ) + return proposal.value + + @default("log") + def _log_default(self): + if isinstance(self.parent, LoggingConfigurable): + return self.parent.log + from traitlets import log + + return log.get_logger() + + def _get_log_handler(self): + """Return the default Handler + + Returns None if none can be found + + Deprecated, this now returns the first log handler which may or may + not be the default one. + """ + logger = self.log + if isinstance(logger, logging.LoggerAdapter): + logger = logger.logger + if not getattr(logger, "handlers", None): + # no handlers attribute or empty handlers list + return None + return logger.handlers[0] + + +class SingletonConfigurable(LoggingConfigurable): + """A configurable that only allows one instance. + + This class is for classes that should only have one instance of itself + or *any* subclass. To create and retrieve such a class use the + :meth:`SingletonConfigurable.instance` method. + """ + + _instance = None + + @classmethod + def _walk_mro(cls): + """Walk the cls.mro() for parent classes that are also singletons + + For use in instance() + """ + + for subclass in cls.mro(): + if ( + issubclass(cls, subclass) + and issubclass(subclass, SingletonConfigurable) + and subclass != SingletonConfigurable + ): + yield subclass + + @classmethod + def clear_instance(cls): + """unset _instance for this class and singleton parents.""" + if not cls.initialized(): + return + for subclass in cls._walk_mro(): + if isinstance(subclass._instance, cls): + # only clear instances that are instances + # of the calling class + subclass._instance = None + + @classmethod + def instance(cls, *args, **kwargs): + """Returns a global instance of this class. + + This method create a new instance if none have previously been created + and returns a previously created instance is one already exists. + + The arguments and keyword arguments passed to this method are passed + on to the :meth:`__init__` method of the class upon instantiation. + + Examples + -------- + Create a singleton class using instance, and retrieve it:: + + >>> from traitlets.config.configurable import SingletonConfigurable + >>> class Foo(SingletonConfigurable): pass + >>> foo = Foo.instance() + >>> foo == Foo.instance() + True + + Create a subclass that is retrived using the base class instance:: + + >>> class Bar(SingletonConfigurable): pass + >>> class Bam(Bar): pass + >>> bam = Bam.instance() + >>> bam == Bar.instance() + True + """ + # Create and save the instance + if cls._instance is None: + inst = cls(*args, **kwargs) + # Now make sure that the instance will also be returned by + # parent classes' _instance attribute. + for subclass in cls._walk_mro(): + subclass._instance = inst + + if isinstance(cls._instance, cls): + return cls._instance + else: + raise MultipleInstanceError( + "An incompatible sibling of '%s' is already instantiated" + " as singleton: %s" % (cls.__name__, type(cls._instance).__name__) + ) + + @classmethod + def initialized(cls): + """Has an instance been created?""" + return hasattr(cls, "_instance") and cls._instance is not None diff --git a/contrib/python/traitlets/py3/traitlets/config/loader.py b/contrib/python/traitlets/py3/traitlets/config/loader.py new file mode 100644 index 0000000000..c1834a9bc4 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/loader.py @@ -0,0 +1,1167 @@ +"""A simple configuration system.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +import argparse +import copy +import functools +import json +import os +import re +import sys +import typing as t +import warnings + +from traitlets.traitlets import Any, Container, Dict, HasTraits, List, Undefined + +from ..utils import cast_unicode, filefind + +# ----------------------------------------------------------------------------- +# Exceptions +# ----------------------------------------------------------------------------- + + +class ConfigError(Exception): + pass + + +class ConfigLoaderError(ConfigError): + pass + + +class ConfigFileNotFound(ConfigError): # noqa + pass + + +class ArgumentError(ConfigLoaderError): + pass + + +# ----------------------------------------------------------------------------- +# Argparse fix +# ----------------------------------------------------------------------------- + +# Unfortunately argparse by default prints help messages to stderr instead of +# stdout. This makes it annoying to capture long help screens at the command +# line, since one must know how to pipe stderr, which many users don't know how +# to do. So we override the print_help method with one that defaults to +# stdout and use our class instead. + + +class _Sentinel: + def __repr__(self): + return "<Sentinel deprecated>" + + def __str__(self): + return "<deprecated>" + + +_deprecated = _Sentinel() + + +class ArgumentParser(argparse.ArgumentParser): + """Simple argparse subclass that prints help to stdout by default.""" + + def print_help(self, file=None): + if file is None: + file = sys.stdout + return super().print_help(file) + + print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__ + + +# ----------------------------------------------------------------------------- +# Config class for holding config information +# ----------------------------------------------------------------------------- + + +def execfile(fname, glob): + with open(fname, "rb") as f: + exec(compile(f.read(), fname, "exec"), glob, glob) # noqa + + +class LazyConfigValue(HasTraits): + """Proxy object for exposing methods on configurable containers + + These methods allow appending/extending/updating + to add to non-empty defaults instead of clobbering them. + + Exposes: + + - append, extend, insert on lists + - update on dicts + - update, add on sets + """ + + _value = None + + # list methods + _extend = List() + _prepend = List() + _inserts = List() + + def append(self, obj): + """Append an item to a List""" + self._extend.append(obj) + + def extend(self, other): + """Extend a list""" + self._extend.extend(other) + + def prepend(self, other): + """like list.extend, but for the front""" + self._prepend[:0] = other + + def merge_into(self, other): + """ + Merge with another earlier LazyConfigValue or an earlier container. + This is useful when having global system-wide configuration files. + + Self is expected to have higher precedence. + + Parameters + ---------- + other : LazyConfigValue or container + + Returns + ------- + LazyConfigValue + if ``other`` is also lazy, a reified container otherwise. + """ + if isinstance(other, LazyConfigValue): + other._extend.extend(self._extend) + self._extend = other._extend + + self._prepend.extend(other._prepend) + + other._inserts.extend(self._inserts) + self._inserts = other._inserts + + if self._update: + other.update(self._update) + self._update = other._update + return self + else: + # other is a container, reify now. + return self.get_value(other) + + def insert(self, index, other): + if not isinstance(index, int): + raise TypeError("An integer is required") + self._inserts.append((index, other)) + + # dict methods + # update is used for both dict and set + _update = Any() + + def update(self, other): + """Update either a set or dict""" + if self._update is None: + if isinstance(other, dict): + self._update = {} + else: + self._update = set() + self._update.update(other) + + # set methods + def add(self, obj): + """Add an item to a set""" + self.update({obj}) + + def get_value(self, initial): + """construct the value from the initial one + + after applying any insert / extend / update changes + """ + if self._value is not None: + return self._value + value = copy.deepcopy(initial) + if isinstance(value, list): + for idx, obj in self._inserts: + value.insert(idx, obj) + value[:0] = self._prepend + value.extend(self._extend) + + elif isinstance(value, dict): + if self._update: + value.update(self._update) + elif isinstance(value, set): + if self._update: + value.update(self._update) + self._value = value + return value + + def to_dict(self): + """return JSONable dict form of my data + + Currently update as dict or set, extend, prepend as lists, and inserts as list of tuples. + """ + d = {} + if self._update: + d["update"] = self._update + if self._extend: + d["extend"] = self._extend + if self._prepend: + d["prepend"] = self._prepend + elif self._inserts: + d["inserts"] = self._inserts + return d + + def __repr__(self): + if self._value is not None: + return f"<{self.__class__.__name__} value={self._value!r}>" + else: + return f"<{self.__class__.__name__} {self.to_dict()!r}>" + + +def _is_section_key(key): + """Is a Config key a section name (does it start with a capital)?""" + if key and key[0].upper() == key[0] and not key.startswith("_"): + return True + else: + return False + + +class Config(dict): # type:ignore[type-arg] + """An attribute-based dict that can do smart merges. + + Accessing a field on a config object for the first time populates the key + with either a nested Config object for keys starting with capitals + or :class:`.LazyConfigValue` for lowercase keys, + allowing quick assignments such as:: + + c = Config() + c.Class.int_trait = 5 + c.Class.list_trait.append("x") + + """ + + def __init__(self, *args, **kwds): + dict.__init__(self, *args, **kwds) + self._ensure_subconfig() + + def _ensure_subconfig(self): + """ensure that sub-dicts that should be Config objects are + + casts dicts that are under section keys to Config objects, + which is necessary for constructing Config objects from dict literals. + """ + for key in self: + obj = self[key] + if _is_section_key(key) and isinstance(obj, dict) and not isinstance(obj, Config): + setattr(self, key, Config(obj)) + + def _merge(self, other): + """deprecated alias, use Config.merge()""" + self.merge(other) + + def merge(self, other): + """merge another config object into this one""" + to_update = {} + for k, v in other.items(): + if k not in self: + to_update[k] = v + else: # I have this key + if isinstance(v, Config) and isinstance(self[k], Config): + # Recursively merge common sub Configs + self[k].merge(v) + elif isinstance(v, LazyConfigValue): + self[k] = v.merge_into(self[k]) + else: + # Plain updates for non-Configs + to_update[k] = v + + self.update(to_update) + + def collisions(self, other: "Config") -> t.Dict[str, t.Any]: + """Check for collisions between two config objects. + + Returns a dict of the form {"Class": {"trait": "collision message"}}`, + indicating which values have been ignored. + + An empty dict indicates no collisions. + """ + collisions: t.Dict[str, t.Any] = {} + for section in self: + if section not in other: + continue + mine = self[section] + theirs = other[section] + for key in mine: + if key in theirs and mine[key] != theirs[key]: + collisions.setdefault(section, {}) + collisions[section][key] = f"{mine[key]!r} ignored, using {theirs[key]!r}" + return collisions + + def __contains__(self, key): + # allow nested contains of the form `"Section.key" in config` + if "." in key: + first, remainder = key.split(".", 1) + if first not in self: + return False + return remainder in self[first] + + return super().__contains__(key) + + # .has_key is deprecated for dictionaries. + has_key = __contains__ + + def _has_section(self, key): + return _is_section_key(key) and key in self + + def copy(self): + return type(self)(dict.copy(self)) + + def __copy__(self): + return self.copy() + + def __deepcopy__(self, memo): + new_config = type(self)() + for key, value in self.items(): + if isinstance(value, (Config, LazyConfigValue)): + # deep copy config objects + value = copy.deepcopy(value, memo) + elif type(value) in {dict, list, set, tuple}: + # shallow copy plain container traits + value = copy.copy(value) + new_config[key] = value + return new_config + + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + if _is_section_key(key): + c = Config() + dict.__setitem__(self, key, c) + return c + elif not key.startswith("_"): + # undefined, create lazy value, used for container methods + v = LazyConfigValue() + dict.__setitem__(self, key, v) + return v + else: + raise KeyError + + def __setitem__(self, key, value): + if _is_section_key(key): + if not isinstance(value, Config): + raise ValueError( + "values whose keys begin with an uppercase " + "char must be Config instances: %r, %r" % (key, value) + ) + dict.__setitem__(self, key, value) + + def __getattr__(self, key): + if key.startswith("__"): + return dict.__getattr__(self, key) # type:ignore[attr-defined] + try: + return self.__getitem__(key) + except KeyError as e: + raise AttributeError(e) from e + + def __setattr__(self, key, value): + if key.startswith("__"): + return dict.__setattr__(self, key, value) + try: + self.__setitem__(key, value) + except KeyError as e: + raise AttributeError(e) from e + + def __delattr__(self, key): + if key.startswith("__"): + return dict.__delattr__(self, key) + try: + dict.__delitem__(self, key) + except KeyError as e: + raise AttributeError(e) from e + + +class DeferredConfig: + """Class for deferred-evaluation of config from CLI""" + + pass + + def get_value(self, trait): + raise NotImplementedError("Implement in subclasses") + + def _super_repr(self): + # explicitly call super on direct parent + return super(self.__class__, self).__repr__() + + +class DeferredConfigString(str, DeferredConfig): + """Config value for loading config from a string + + Interpretation is deferred until it is loaded into the trait. + + Subclass of str for backward compatibility. + + This class is only used for values that are not listed + in the configurable classes. + + When config is loaded, `trait.from_string` will be used. + + If an error is raised in `.from_string`, + the original string is returned. + + .. versionadded:: 5.0 + """ + + def get_value(self, trait): + """Get the value stored in this string""" + s = str(self) + try: + return trait.from_string(s) + except Exception: + # exception casting from string, + # let the original string lie. + # this will raise a more informative error when config is loaded. + return s + + def __repr__(self): + return f"{self.__class__.__name__}({self._super_repr()})" + + +class DeferredConfigList(list, DeferredConfig): # type:ignore[type-arg] + """Config value for loading config from a list of strings + + Interpretation is deferred until it is loaded into the trait. + + This class is only used for values that are not listed + in the configurable classes. + + When config is loaded, `trait.from_string_list` will be used. + + If an error is raised in `.from_string_list`, + the original string list is returned. + + .. versionadded:: 5.0 + """ + + def get_value(self, trait): + """Get the value stored in this string""" + if hasattr(trait, "from_string_list"): + src = list(self) + cast = trait.from_string_list + else: + # only allow one item + if len(self) > 1: + raise ValueError( + f"{trait.name} only accepts one value, got {len(self)}: {list(self)}" + ) + src = self[0] + cast = trait.from_string + + try: + return cast(src) + except Exception: + # exception casting from string, + # let the original value lie. + # this will raise a more informative error when config is loaded. + return src + + def __repr__(self): + return f"{self.__class__.__name__}({self._super_repr()})" + + +# ----------------------------------------------------------------------------- +# Config loading classes +# ----------------------------------------------------------------------------- + + +class ConfigLoader: + """A object for loading configurations from just about anywhere. + + The resulting configuration is packaged as a :class:`Config`. + + Notes + ----- + A :class:`ConfigLoader` does one thing: load a config from a source + (file, command line arguments) and returns the data as a :class:`Config` object. + There are lots of things that :class:`ConfigLoader` does not do. It does + not implement complex logic for finding config files. It does not handle + default values or merge multiple configs. These things need to be + handled elsewhere. + """ + + def _log_default(self): + from traitlets.log import get_logger + + return get_logger() + + def __init__(self, log=None): + """A base class for config loaders. + + log : instance of :class:`logging.Logger` to use. + By default logger of :meth:`traitlets.config.application.Application.instance()` + will be used + + Examples + -------- + >>> cl = ConfigLoader() + >>> config = cl.load_config() + >>> config + {} + """ + self.clear() + if log is None: + self.log = self._log_default() + self.log.debug("Using default logger") + else: + self.log = log + + def clear(self): + self.config = Config() + + def load_config(self): + """Load a config from somewhere, return a :class:`Config` instance. + + Usually, this will cause self.config to be set and then returned. + However, in most cases, :meth:`ConfigLoader.clear` should be called + to erase any previous state. + """ + self.clear() + return self.config + + +class FileConfigLoader(ConfigLoader): + """A base class for file based configurations. + + As we add more file based config loaders, the common logic should go + here. + """ + + def __init__(self, filename, path=None, **kw): + """Build a config loader for a filename and path. + + Parameters + ---------- + filename : str + The file name of the config file. + path : str, list, tuple + The path to search for the config file on, or a sequence of + paths to try in order. + """ + super().__init__(**kw) + self.filename = filename + self.path = path + self.full_filename = "" + + def _find_file(self): + """Try to find the file by searching the paths.""" + self.full_filename = filefind(self.filename, self.path) + + +class JSONFileConfigLoader(FileConfigLoader): + """A JSON file loader for config + + Can also act as a context manager that rewrite the configuration file to disk on exit. + + Example:: + + with JSONFileConfigLoader('myapp.json','/home/jupyter/configurations/') as c: + c.MyNewConfigurable.new_value = 'Updated' + + """ + + def load_config(self): + """Load the config from a file and return it as a Config object.""" + self.clear() + try: + self._find_file() + except OSError as e: + raise ConfigFileNotFound(str(e)) from e + dct = self._read_file_as_dict() + self.config = self._convert_to_config(dct) + return self.config + + def _read_file_as_dict(self): + with open(self.full_filename) as f: + return json.load(f) + + def _convert_to_config(self, dictionary): + if "version" in dictionary: + version = dictionary.pop("version") + else: + version = 1 + + if version == 1: + return Config(dictionary) + else: + raise ValueError(f"Unknown version of JSON config file: {version}") + + def __enter__(self): + self.load_config() + return self.config + + def __exit__(self, exc_type, exc_value, traceback): + """ + Exit the context manager but do not handle any errors. + + In case of any error, we do not want to write the potentially broken + configuration to disk. + """ + self.config.version = 1 + json_config = json.dumps(self.config, indent=2) + with open(self.full_filename, "w") as f: + f.write(json_config) + + +class PyFileConfigLoader(FileConfigLoader): + """A config loader for pure python files. + + This is responsible for locating a Python config file by filename and + path, then executing it to construct a Config object. + """ + + def load_config(self): + """Load the config from a file and return it as a Config object.""" + self.clear() + try: + self._find_file() + except OSError as e: + raise ConfigFileNotFound(str(e)) from e + self._read_file_as_dict() + return self.config + + def load_subconfig(self, fname, path=None): + """Injected into config file namespace as load_subconfig""" + if path is None: + path = self.path + + loader = self.__class__(fname, path) + try: + sub_config = loader.load_config() + except ConfigFileNotFound: + # Pass silently if the sub config is not there, + # treat it as an empty config file. + pass + else: + self.config.merge(sub_config) + + def _read_file_as_dict(self): + """Load the config file into self.config, with recursive loading.""" + + def get_config(): + """Unnecessary now, but a deprecation warning is more trouble than it's worth.""" + return self.config + + namespace = dict( + c=self.config, + load_subconfig=self.load_subconfig, + get_config=get_config, + __file__=self.full_filename, + ) + conf_filename = self.full_filename + with open(conf_filename, "rb") as f: + exec(compile(f.read(), conf_filename, "exec"), namespace, namespace) # noqa + + +class CommandLineConfigLoader(ConfigLoader): + """A config loader for command line arguments. + + As we add more command line based loaders, the common logic should go + here. + """ + + def _exec_config_str(self, lhs, rhs, trait=None): + """execute self.config.<lhs> = <rhs> + + * expands ~ with expanduser + * interprets value with trait if available + """ + value = rhs + if isinstance(value, DeferredConfig): + if trait: + # trait available, reify config immediately + value = value.get_value(trait) + elif isinstance(rhs, DeferredConfigList) and len(rhs) == 1: + # single item, make it a deferred str + value = DeferredConfigString(os.path.expanduser(rhs[0])) + else: + if trait: + value = trait.from_string(value) + else: + value = DeferredConfigString(value) + + *path, key = lhs.split(".") + section = self.config + for part in path: + section = section[part] + section[key] = value + return + + def _load_flag(self, cfg): + """update self.config from a flag, which can be a dict or Config""" + if isinstance(cfg, (dict, Config)): + # don't clobber whole config sections, update + # each section from config: + for sec, c in cfg.items(): + self.config[sec].update(c) + else: + raise TypeError("Invalid flag: %r" % cfg) + + +# match --Class.trait keys for argparse +# matches: +# --Class.trait +# --x +# -x + +class_trait_opt_pattern = re.compile(r"^\-?\-[A-Za-z][\w]*(\.[\w]+)*$") + +_DOT_REPLACEMENT = "__DOT__" +_DASH_REPLACEMENT = "__DASH__" + + +class _KVAction(argparse.Action): + """Custom argparse action for handling --Class.trait=x + + Always + """ + + def __call__(self, parser, namespace, values, option_string=None): + if isinstance(values, str): + values = [values] + values = ["-" if v is _DASH_REPLACEMENT else v for v in values] + items = getattr(namespace, self.dest, None) + if items is None: + items = DeferredConfigList() + else: + items = DeferredConfigList(items) + items.extend(values) + setattr(namespace, self.dest, items) + + +class _DefaultOptionDict(dict): # type:ignore[type-arg] + """Like the default options dict + + but acts as if all --Class.trait options are predefined + """ + + def _add_kv_action(self, key): + self[key] = _KVAction( + option_strings=[key], + dest=key.lstrip("-").replace(".", _DOT_REPLACEMENT), + # use metavar for display purposes + metavar=key.lstrip("-"), + ) + + def __contains__(self, key): + if "=" in key: + return False + if super().__contains__(key): + return True + + if key.startswith("-") and class_trait_opt_pattern.match(key): + self._add_kv_action(key) + return True + return False + + def __getitem__(self, key): + if key in self: + return super().__getitem__(key) + else: + raise KeyError(key) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + +class _KVArgParser(argparse.ArgumentParser): + """subclass of ArgumentParser where any --Class.trait option is implicitly defined""" + + def parse_known_args(self, args=None, namespace=None): + # must be done immediately prior to parsing because if we do it in init, + # registration of explicit actions via parser.add_option will fail during setup + for container in (self, self._optionals): + container._option_string_actions = _DefaultOptionDict(container._option_string_actions) + return super().parse_known_args(args, namespace) + + +# type aliases +Flags = t.Union[str, t.Tuple[str, ...]] +SubcommandsDict = t.Dict[str, t.Any] + + +class ArgParseConfigLoader(CommandLineConfigLoader): + """A loader that uses the argparse module to load from the command line.""" + + parser_class = ArgumentParser + + def __init__( + self, + argv: t.Optional[t.List[str]] = None, + aliases: t.Optional[t.Dict[Flags, str]] = None, + flags: t.Optional[t.Dict[Flags, str]] = None, + log: t.Any = None, + classes: t.Optional[t.List[t.Type[t.Any]]] = None, + subcommands: t.Optional[SubcommandsDict] = None, + *parser_args: t.Any, + **parser_kw: t.Any, + ) -> None: + """Create a config loader for use with argparse. + + Parameters + ---------- + classes : optional, list + The classes to scan for *container* config-traits and decide + for their "multiplicity" when adding them as *argparse* arguments. + argv : optional, list + If given, used to read command-line arguments from, otherwise + sys.argv[1:] is used. + *parser_args : tuple + A tuple of positional arguments that will be passed to the + constructor of :class:`argparse.ArgumentParser`. + **parser_kw : dict + A tuple of keyword arguments that will be passed to the + constructor of :class:`argparse.ArgumentParser`. + aliases : dict of str to str + Dict of aliases to full traitlets names for CLI parsing + flags : dict of str to str + Dict of flags to full traitlets names for CLI parsing + log + Passed to `ConfigLoader` + + Returns + ------- + config : Config + The resulting Config object. + """ + classes = classes or [] + super(CommandLineConfigLoader, self).__init__(log=log) + self.clear() + if argv is None: + argv = sys.argv[1:] + self.argv = argv + self.aliases = aliases or {} + self.flags = flags or {} + self.classes = classes + self.subcommands = subcommands # only used for argcomplete currently + + self.parser_args = parser_args + self.version = parser_kw.pop("version", None) + kwargs = dict(argument_default=argparse.SUPPRESS) + kwargs.update(parser_kw) + self.parser_kw = kwargs + + def load_config(self, argv=None, aliases=None, flags=_deprecated, classes=None): + """Parse command line arguments and return as a Config object. + + Parameters + ---------- + argv : optional, list + If given, a list with the structure of sys.argv[1:] to parse + arguments from. If not given, the instance's self.argv attribute + (given at construction time) is used. + flags + Deprecated in traitlets 5.0, instanciate the config loader with the flags. + + """ + + if flags is not _deprecated: + warnings.warn( + "The `flag` argument to load_config is deprecated since Traitlets " + f"5.0 and will be ignored, pass flags the `{type(self)}` constructor.", + DeprecationWarning, + stacklevel=2, + ) + + self.clear() + if argv is None: + argv = self.argv + if aliases is not None: + self.aliases = aliases + if classes is not None: + self.classes = classes + self._create_parser() + self._argcomplete(self.classes, self.subcommands) + self._parse_args(argv) + self._convert_to_config() + return self.config + + def get_extra_args(self): + if hasattr(self, "extra_args"): + return self.extra_args + else: + return [] + + def _create_parser(self): + self.parser = self.parser_class( + *self.parser_args, **self.parser_kw # type:ignore[arg-type] + ) + self._add_arguments(self.aliases, self.flags, self.classes) + + def _add_arguments(self, aliases, flags, classes): + raise NotImplementedError("subclasses must implement _add_arguments") + + def _argcomplete( + self, classes: t.List[t.Any], subcommands: t.Optional[SubcommandsDict] + ) -> None: + """If argcomplete is enabled, allow triggering command-line autocompletion""" + pass + + def _parse_args(self, args): + """self.parser->self.parsed_data""" + uargs = [cast_unicode(a) for a in args] + + unpacked_aliases: t.Dict[str, str] = {} + if self.aliases: + unpacked_aliases = {} + for alias, alias_target in self.aliases.items(): + if alias in self.flags: + continue + if not isinstance(alias, tuple): + alias = (alias,) + for al in alias: + if len(al) == 1: + unpacked_aliases["-" + al] = "--" + alias_target + unpacked_aliases["--" + al] = "--" + alias_target + + def _replace(arg): + if arg == "-": + return _DASH_REPLACEMENT + for k, v in unpacked_aliases.items(): + if arg == k: + return v + if arg.startswith(k + "="): + return v + "=" + arg[len(k) + 1 :] + return arg + + if "--" in uargs: + idx = uargs.index("--") + extra_args = uargs[idx + 1 :] + to_parse = uargs[:idx] + else: + extra_args = [] + to_parse = uargs + to_parse = [_replace(a) for a in to_parse] + + self.parsed_data = self.parser.parse_args(to_parse) + self.extra_args = extra_args + + def _convert_to_config(self): + """self.parsed_data->self.config""" + for k, v in vars(self.parsed_data).items(): + *path, key = k.split(".") + section = self.config + for p in path: + section = section[p] + setattr(section, key, v) + + +class _FlagAction(argparse.Action): + """ArgParse action to handle a flag""" + + def __init__(self, *args, **kwargs): + self.flag = kwargs.pop("flag") + self.alias = kwargs.pop("alias", None) + kwargs["const"] = Undefined + if not self.alias: + kwargs["nargs"] = 0 + super().__init__(*args, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + if self.nargs == 0 or values is Undefined: + if not hasattr(namespace, "_flags"): + namespace._flags = [] + namespace._flags.append(self.flag) + else: + setattr(namespace, self.alias, values) + + +class KVArgParseConfigLoader(ArgParseConfigLoader): + """A config loader that loads aliases and flags with argparse, + + as well as arbitrary --Class.trait value + """ + + parser_class = _KVArgParser # type:ignore[assignment] + + def _add_arguments(self, aliases, flags, classes): + alias_flags: t.Dict[str, t.Any] = {} + argparse_kwds: t.Dict[str, t.Any] + paa = self.parser.add_argument + self.parser.set_defaults(_flags=[]) + paa("extra_args", nargs="*") + + # An index of all container traits collected:: + # + # { <traitname>: (<trait>, <argparse-kwds>) } + # + # Used to add the correct type into the `config` tree. + # Used also for aliases, not to re-collect them. + self.argparse_traits = argparse_traits = {} + for cls in classes: + for traitname, trait in cls.class_traits(config=True).items(): + argname = f"{cls.__name__}.{traitname}" + argparse_kwds = {"type": str} + if isinstance(trait, (Container, Dict)): + multiplicity = trait.metadata.get("multiplicity", "append") + if multiplicity == "append": + argparse_kwds["action"] = multiplicity + else: + argparse_kwds["nargs"] = multiplicity + argparse_traits[argname] = (trait, argparse_kwds) + + for keys, (value, fhelp) in flags.items(): + if not isinstance(keys, tuple): + keys = (keys,) + for key in keys: + if key in aliases: + alias_flags[aliases[key]] = value + continue + keys = ("-" + key, "--" + key) if len(key) == 1 else ("--" + key,) + paa(*keys, action=_FlagAction, flag=value, help=fhelp) + + for keys, traitname in aliases.items(): + if not isinstance(keys, tuple): + keys = (keys,) + + for key in keys: + argparse_kwds = { + "type": str, + "dest": traitname.replace(".", _DOT_REPLACEMENT), + "metavar": traitname, + } + argcompleter = None + if traitname in argparse_traits: + trait, kwds = argparse_traits[traitname] + argparse_kwds.update(kwds) + if "action" in argparse_kwds and traitname in alias_flags: + # flag sets 'action', so can't have flag & alias with custom action + # on the same name + raise ArgumentError( + "The alias `%s` for the 'append' sequence " + "config-trait `%s` cannot be also a flag!'" % (key, traitname) + ) + # For argcomplete, check if any either an argcompleter metadata tag or method + # is available. If so, it should be a callable which takes the command-line key + # string as an argument and other kwargs passed by argcomplete, + # and returns the a list of string completions. + argcompleter = trait.metadata.get("argcompleter") or getattr( + trait, "argcompleter", None + ) + if traitname in alias_flags: + # alias and flag. + # when called with 0 args: flag + # when called with >= 1: alias + argparse_kwds.setdefault("nargs", "?") + argparse_kwds["action"] = _FlagAction + argparse_kwds["flag"] = alias_flags[traitname] + argparse_kwds["alias"] = traitname + keys = ("-" + key, "--" + key) if len(key) == 1 else ("--" + key,) + action = paa(*keys, **argparse_kwds) + if argcompleter is not None: + # argcomplete's completers are callables returning list of completion strings + action.completer = functools.partial(argcompleter, key=key) # type: ignore + + def _convert_to_config(self): + """self.parsed_data->self.config, parse unrecognized extra args via KVLoader.""" + extra_args = self.extra_args + + for lhs, rhs in vars(self.parsed_data).items(): + if lhs == "extra_args": + self.extra_args = ["-" if a == _DASH_REPLACEMENT else a for a in rhs] + extra_args + continue + elif lhs == "_flags": + # _flags will be handled later + continue + + lhs = lhs.replace(_DOT_REPLACEMENT, ".") + if "." not in lhs: + self._handle_unrecognized_alias(lhs) + trait = None + + if isinstance(rhs, list): + rhs = DeferredConfigList(rhs) + elif isinstance(rhs, str): + rhs = DeferredConfigString(rhs) + + trait = self.argparse_traits.get(lhs) + if trait: + trait = trait[0] + + # eval the KV assignment + try: + self._exec_config_str(lhs, rhs, trait) + except Exception as e: + # cast deferred to nicer repr for the error + # DeferredList->list, etc + if isinstance(rhs, DeferredConfig): + rhs = rhs._super_repr() + raise ArgumentError(f"Error loading argument {lhs}={rhs}, {e}") from e + + for subc in self.parsed_data._flags: + self._load_flag(subc) + + def _handle_unrecognized_alias(self, arg: str) -> None: + """Handling for unrecognized alias arguments + + Probably a mistyped alias. By default just log a warning, + but users can override this to raise an error instead, e.g. + self.parser.error("Unrecognized alias: '%s'" % arg) + """ + self.log.warning("Unrecognized alias: '%s', it will have no effect.", arg) + + def _argcomplete( + self, classes: t.List[t.Any], subcommands: t.Optional[SubcommandsDict] + ) -> None: + """If argcomplete is enabled, allow triggering command-line autocompletion""" + try: + import argcomplete # type: ignore[import] # noqa + except ImportError: + return + + from . import argcomplete_config + + finder = argcomplete_config.ExtendedCompletionFinder() + finder.config_classes = classes + finder.subcommands = list(subcommands or []) + # for ease of testing, pass through self._argcomplete_kwargs if set + finder(self.parser, **getattr(self, "_argcomplete_kwargs", {})) + + +class KeyValueConfigLoader(KVArgParseConfigLoader): + """Deprecated in traitlets 5.0 + + Use KVArgParseConfigLoader + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + "KeyValueConfigLoader is deprecated since Traitlets 5.0." + " Use KVArgParseConfigLoader instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + + +def load_pyconfig_files(config_files, path): + """Load multiple Python config files, merging each of them in turn. + + Parameters + ---------- + config_files : list of str + List of config files names to load and merge into the config. + path : unicode + The full path to the location of the config files. + """ + config = Config() + for cf in config_files: + loader = PyFileConfigLoader(cf, path=path) + try: + next_config = loader.load_config() + except ConfigFileNotFound: + pass + except Exception: + raise + else: + config.merge(next_config) + return config diff --git a/contrib/python/traitlets/py3/traitlets/config/manager.py b/contrib/python/traitlets/py3/traitlets/config/manager.py new file mode 100644 index 0000000000..728cd2f22c --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/manager.py @@ -0,0 +1,82 @@ +"""Manager to read and modify config data in JSON files. +""" +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +import errno +import json +import os + +from traitlets.config import LoggingConfigurable +from traitlets.traitlets import Unicode + + +def recursive_update(target, new): + """Recursively update one dictionary using another. + + None values will delete their keys. + """ + for k, v in new.items(): + if isinstance(v, dict): + if k not in target: + target[k] = {} + recursive_update(target[k], v) + if not target[k]: + # Prune empty subdicts + del target[k] + + elif v is None: + target.pop(k, None) + + else: + target[k] = v + + +class BaseJSONConfigManager(LoggingConfigurable): + """General JSON config manager + + Deals with persisting/storing config in a json file + """ + + config_dir = Unicode(".") + + def ensure_config_dir_exists(self): + try: + os.makedirs(self.config_dir, 0o755) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + def file_name(self, section_name): + return os.path.join(self.config_dir, section_name + ".json") + + def get(self, section_name): + """Retrieve the config data for the specified section. + + Returns the data as a dictionary, or an empty dictionary if the file + doesn't exist. + """ + filename = self.file_name(section_name) + if os.path.isfile(filename): + with open(filename, encoding="utf-8") as f: + return json.load(f) + else: + return {} + + def set(self, section_name, data): + """Store the given config data.""" + filename = self.file_name(section_name) + self.ensure_config_dir_exists() + + f = open(filename, "w", encoding="utf-8") + with f: + json.dump(data, f, indent=2) + + def update(self, section_name, new_data): + """Modify the config section by recursively updating it with new_data. + + Returns the modified config data as a dictionary. + """ + data = self.get(section_name) + recursive_update(data, new_data) + self.set(section_name, data) + return data diff --git a/contrib/python/traitlets/py3/traitlets/config/sphinxdoc.py b/contrib/python/traitlets/py3/traitlets/config/sphinxdoc.py new file mode 100644 index 0000000000..92c2d64d67 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/sphinxdoc.py @@ -0,0 +1,161 @@ +"""Machinery for documenting traitlets config options with Sphinx. + +This includes: + +- A Sphinx extension defining directives and roles for config options. +- A function to generate an rst file given an Application instance. + +To make this documentation, first set this module as an extension in Sphinx's +conf.py:: + + extensions = [ + # ... + 'traitlets.config.sphinxdoc', + ] + +Autogenerate the config documentation by running code like this before +Sphinx builds:: + + from traitlets.config.sphinxdoc import write_doc + from myapp import MyApplication + + writedoc('config/options.rst', # File to write + 'MyApp config options', # Title + MyApplication() + ) + +The generated rST syntax looks like this:: + + .. configtrait:: Application.log_datefmt + + Description goes here. + + Cross reference like this: :configtrait:`Application.log_datefmt`. +""" +from collections import defaultdict +from textwrap import dedent + +from traitlets import Undefined +from traitlets.utils.text import indent + + +def setup(app): + """Registers the Sphinx extension. + + You shouldn't need to call this directly; configure Sphinx to use this + module instead. + """ + app.add_object_type("configtrait", "configtrait", objname="Config option") + metadata = {"parallel_read_safe": True, "parallel_write_safe": True} + return metadata + + +def interesting_default_value(dv): + if (dv is None) or (dv is Undefined): + return False + if isinstance(dv, (str, list, tuple, dict, set)): + return bool(dv) + return True + + +def format_aliases(aliases): + fmted = [] + for a in aliases: + dashes = "-" if len(a) == 1 else "--" + fmted.append(f"``{dashes}{a}``") + return ", ".join(fmted) + + +def class_config_rst_doc(cls, trait_aliases): + """Generate rST documentation for this class' config options. + + Excludes traits defined on parent classes. + """ + lines = [] + classname = cls.__name__ + for _, trait in sorted(cls.class_traits(config=True).items()): + ttype = trait.__class__.__name__ + + fullname = classname + "." + trait.name + lines += [".. configtrait:: " + fullname, ""] + + help = trait.help.rstrip() or "No description" + lines.append(indent(dedent(help)) + "\n") + + # Choices or type + if "Enum" in ttype: + # include Enum choices + lines.append(indent(":options: " + ", ".join("``%r``" % x for x in trait.values))) + else: + lines.append(indent(":trait type: " + ttype)) + + # Default value + # Ignore boring default values like None, [] or '' + if interesting_default_value(trait.default_value): + try: + dvr = trait.default_value_repr() + except Exception: + dvr = None # ignore defaults we can't construct + if dvr is not None: + if len(dvr) > 64: + dvr = dvr[:61] + "..." + # Double up backslashes, so they get to the rendered docs + dvr = dvr.replace("\\n", "\\\\n") + lines.append(indent(":default: ``%s``" % dvr)) + + # Command line aliases + if trait_aliases[fullname]: + fmt_aliases = format_aliases(trait_aliases[fullname]) + lines.append(indent(":CLI option: " + fmt_aliases)) + + # Blank line + lines.append("") + + return "\n".join(lines) + + +def reverse_aliases(app): + """Produce a mapping of trait names to lists of command line aliases.""" + res = defaultdict(list) + for alias, trait in app.aliases.items(): + res[trait].append(alias) + + # Flags also often act as aliases for a boolean trait. + # Treat flags which set one trait to True as aliases. + for flag, (cfg, _) in app.flags.items(): + if len(cfg) == 1: + classname = list(cfg)[0] + cls_cfg = cfg[classname] + if len(cls_cfg) == 1: + traitname = list(cls_cfg)[0] + if cls_cfg[traitname] is True: + res[classname + "." + traitname].append(flag) + + return res + + +def write_doc(path, title, app, preamble=None): + """Write a rst file documenting config options for a traitlets application. + + Parameters + ---------- + path : str + The file to be written + title : str + The human-readable title of the document + app : traitlets.config.Application + An instance of the application class to be documented + preamble : str + Extra text to add just after the title (optional) + """ + trait_aliases = reverse_aliases(app) + with open(path, "w") as f: + f.write(title + "\n") + f.write(("=" * len(title)) + "\n") + f.write("\n") + if preamble is not None: + f.write(preamble + "\n\n") + + for c in app._classes_inc_parents(): + f.write(class_config_rst_doc(c, trait_aliases)) + f.write("\n") diff --git a/contrib/python/traitlets/py3/traitlets/config/tests/test_application.py b/contrib/python/traitlets/py3/traitlets/config/tests/test_application.py new file mode 100644 index 0000000000..62585aa29c --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/tests/test_application.py @@ -0,0 +1,914 @@ +""" +Tests for traitlets.config.application.Application +""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +import contextlib +import io +import json +import logging +import os +import sys +import typing as t +from io import StringIO +from tempfile import TemporaryDirectory +from unittest import TestCase + +import pytest +from pytest import mark + +from traitlets import Bool, Bytes, Dict, HasTraits, Integer, List, Set, Tuple, Unicode +from traitlets.config.application import Application +from traitlets.config.configurable import Configurable +from traitlets.config.loader import Config, KVArgParseConfigLoader +from traitlets.tests.utils import check_help_all_output, check_help_output, get_output_error_code + +try: + from unittest import mock +except ImportError: + from unittest import mock + +pjoin = os.path.join + + +class Foo(Configurable): + + i = Integer( + 0, + help=""" + The integer i. + + Details about i. + """, + ).tag(config=True) + j = Integer(1, help="The integer j.").tag(config=True) + name = Unicode("Brian", help="First name.").tag(config=True) + la = List([]).tag(config=True) + li = List(Integer()).tag(config=True) + fdict = Dict().tag(config=True, multiplicity="+") + + +class Bar(Configurable): + + b = Integer(0, help="The integer b.").tag(config=True) + enabled = Bool(True, help="Enable bar.").tag(config=True) + tb = Tuple(()).tag(config=True, multiplicity="*") + aset = Set().tag(config=True, multiplicity="+") + bdict = Dict().tag(config=True) + idict = Dict(value_trait=Integer()).tag(config=True) + key_dict = Dict(per_key_traits={"i": Integer(), "b": Bytes()}).tag(config=True) + + +class MyApp(Application): + + name = Unicode("myapp") + running = Bool(False, help="Is the app running?").tag(config=True) + classes = List([Bar, Foo]) # type:ignore + config_file = Unicode("", help="Load this config file").tag(config=True) + + warn_tpyo = Unicode( + "yes the name is wrong on purpose", + config=True, + help="Should print a warning if `MyApp.warn-typo=...` command is passed", + ) + + aliases: t.Dict[t.Any, t.Any] = {} + aliases.update(Application.aliases) + aliases.update( + { + ("fooi", "i"): "Foo.i", + ("j", "fooj"): ("Foo.j", "`j` terse help msg"), + "name": "Foo.name", + "la": "Foo.la", + "li": "Foo.li", + "tb": "Bar.tb", + "D": "Bar.bdict", + "enabled": "Bar.enabled", + "enable": "Bar.enabled", + "log-level": "Application.log_level", + } + ) + + flags: t.Dict[t.Any, t.Any] = {} + flags.update(Application.flags) + flags.update( + { + ("enable", "e"): ({"Bar": {"enabled": True}}, "Set Bar.enabled to True"), + ("d", "disable"): ({"Bar": {"enabled": False}}, "Set Bar.enabled to False"), + "crit": ({"Application": {"log_level": logging.CRITICAL}}, "set level=CRITICAL"), + } + ) + + def init_foo(self): + self.foo = Foo(parent=self) + + def init_bar(self): + self.bar = Bar(parent=self) + + +def class_to_names(classes): + return [klass.__name__ for klass in classes] + + +class TestApplication(TestCase): + def test_log(self): + stream = StringIO() + app = MyApp(log_level=logging.INFO) + handler = logging.StreamHandler(stream) + # trigger reconstruction of the log formatter + app.log_format = "%(message)s" + app.log_datefmt = "%Y-%m-%d %H:%M" + app.log.handlers = [handler] + app.log.info("hello") + assert "hello" in stream.getvalue() + + def test_no_eval_cli_text(self): + app = MyApp() + app.initialize(["--Foo.name=1"]) + app.init_foo() + assert app.foo.name == "1" + + def test_basic(self): + app = MyApp() + self.assertEqual(app.name, "myapp") + self.assertEqual(app.running, False) + self.assertEqual(app.classes, [MyApp, Bar, Foo]) # type:ignore + self.assertEqual(app.config_file, "") + + def test_app_name_set_via_constructor(self): + app = MyApp(name='set_via_constructor') + assert app.name == "set_via_constructor" + + def test_mro_discovery(self): + app = MyApp() + + self.assertSequenceEqual( + class_to_names(app._classes_with_config_traits()), + ["Application", "MyApp", "Bar", "Foo"], + ) + self.assertSequenceEqual( + class_to_names(app._classes_inc_parents()), + [ + "Configurable", + "LoggingConfigurable", + "SingletonConfigurable", + "Application", + "MyApp", + "Bar", + "Foo", + ], + ) + + self.assertSequenceEqual( + class_to_names(app._classes_with_config_traits([Application])), ["Application"] + ) + self.assertSequenceEqual( + class_to_names(app._classes_inc_parents([Application])), + ["Configurable", "LoggingConfigurable", "SingletonConfigurable", "Application"], + ) + + self.assertSequenceEqual(class_to_names(app._classes_with_config_traits([Foo])), ["Foo"]) + self.assertSequenceEqual( + class_to_names(app._classes_inc_parents([Bar])), ["Configurable", "Bar"] + ) + + class MyApp2(Application): # no defined `classes` attr + pass + + self.assertSequenceEqual(class_to_names(app._classes_with_config_traits([Foo])), ["Foo"]) + self.assertSequenceEqual( + class_to_names(app._classes_inc_parents([Bar])), ["Configurable", "Bar"] + ) + + def test_config(self): + app = MyApp() + app.parse_command_line( + [ + "--i=10", + "--Foo.j=10", + "--enable=False", + "--log-level=50", + ] + ) + config = app.config + print(config) + self.assertEqual(config.Foo.i, 10) + self.assertEqual(config.Foo.j, 10) + self.assertEqual(config.Bar.enabled, False) + self.assertEqual(config.MyApp.log_level, 50) + + def test_config_seq_args(self): + app = MyApp() + app.parse_command_line( + "--li 1 --li 3 --la 1 --tb AB 2 --Foo.la=ab --Bar.aset S1 --Bar.aset S2 --Bar.aset S1".split() + ) + assert app.extra_args == ["2"] + config = app.config + assert config.Foo.li == [1, 3] + assert config.Foo.la == ["1", "ab"] + assert config.Bar.tb == ("AB",) + self.assertEqual(config.Bar.aset, {"S1", "S2"}) + app.init_foo() + assert app.foo.li == [1, 3] + assert app.foo.la == ["1", "ab"] + app.init_bar() + self.assertEqual(app.bar.aset, {"S1", "S2"}) + assert app.bar.tb == ("AB",) + + def test_config_dict_args(self): + app = MyApp() + app.parse_command_line( + "--Foo.fdict a=1 --Foo.fdict b=b --Foo.fdict c=3 " + "--Bar.bdict k=1 -D=a=b -D 22=33 " + "--Bar.idict k=1 --Bar.idict b=2 --Bar.idict c=3 ".split() + ) + fdict = {"a": "1", "b": "b", "c": "3"} + bdict = {"k": "1", "a": "b", "22": "33"} + idict = {"k": 1, "b": 2, "c": 3} + config = app.config + assert config.Bar.idict == idict + self.assertDictEqual(config.Foo.fdict, fdict) + self.assertDictEqual(config.Bar.bdict, bdict) + app.init_foo() + self.assertEqual(app.foo.fdict, fdict) + app.init_bar() + assert app.bar.idict == idict + self.assertEqual(app.bar.bdict, bdict) + + def test_config_propagation(self): + app = MyApp() + app.parse_command_line(["--i=10", "--Foo.j=10", "--enable=False", "--log-level=50"]) + app.init_foo() + app.init_bar() + self.assertEqual(app.foo.i, 10) + self.assertEqual(app.foo.j, 10) + self.assertEqual(app.bar.enabled, False) + + def test_cli_priority(self): + """Test that loading config files does not override CLI options""" + name = "config.py" + + class TestApp(Application): + value = Unicode().tag(config=True) + config_file_loaded = Bool().tag(config=True) + aliases = {"v": "TestApp.value"} + + app = TestApp() + with TemporaryDirectory() as td: + config_file = pjoin(td, name) + with open(config_file, "w") as f: + f.writelines( + ["c.TestApp.value = 'config file'\n", "c.TestApp.config_file_loaded = True\n"] + ) + + app.parse_command_line(["--v=cli"]) + assert "value" in app.config.TestApp + assert app.config.TestApp.value == "cli" + assert app.value == "cli" + + app.load_config_file(name, path=[td]) + assert app.config_file_loaded + assert app.config.TestApp.value == "cli" + assert app.value == "cli" + + def test_ipython_cli_priority(self): + # this test is almost entirely redundant with above, + # but we can keep it around in case of subtle issues creeping into + # the exact sequence IPython follows. + name = "config.py" + + class TestApp(Application): + value = Unicode().tag(config=True) + config_file_loaded = Bool().tag(config=True) + aliases = {"v": ("TestApp.value", "some help")} + + app = TestApp() + with TemporaryDirectory() as td: + config_file = pjoin(td, name) + with open(config_file, "w") as f: + f.writelines( + ["c.TestApp.value = 'config file'\n", "c.TestApp.config_file_loaded = True\n"] + ) + # follow IPython's config-loading sequence to ensure CLI priority is preserved + app.parse_command_line(["--v=cli"]) + # this is where IPython makes a mistake: + # it assumes app.config will not be modified, + # and storing a reference is storing a copy + cli_config = app.config + assert "value" in app.config.TestApp + assert app.config.TestApp.value == "cli" + assert app.value == "cli" + app.load_config_file(name, path=[td]) + assert app.config_file_loaded + # enforce cl-opts override config file opts: + # this is where IPython makes a mistake: it assumes + # that cl_config is a different object, but it isn't. + app.update_config(cli_config) + assert app.config.TestApp.value == "cli" + assert app.value == "cli" + + def test_cli_allow_none(self): + class App(Application): + aliases = {"opt": "App.opt"} + opt = Unicode(allow_none=True, config=True) + + app = App() + app.parse_command_line(["--opt=None"]) + assert app.opt is None + + def test_flags(self): + app = MyApp() + app.parse_command_line(["--disable"]) + app.init_bar() + self.assertEqual(app.bar.enabled, False) + + app = MyApp() + app.parse_command_line(["-d"]) + app.init_bar() + self.assertEqual(app.bar.enabled, False) + + app = MyApp() + app.parse_command_line(["--enable"]) + app.init_bar() + self.assertEqual(app.bar.enabled, True) + + app = MyApp() + app.parse_command_line(["-e"]) + app.init_bar() + self.assertEqual(app.bar.enabled, True) + + def test_flags_help_msg(self): + app = MyApp() + stdout = io.StringIO() + with contextlib.redirect_stdout(stdout): + app.print_flag_help() + hmsg = stdout.getvalue() + self.assertRegex(hmsg, "(?<!-)-e, --enable\\b") + self.assertRegex(hmsg, "(?<!-)-d, --disable\\b") + self.assertIn("Equivalent to: [--Bar.enabled=True]", hmsg) + self.assertIn("Equivalent to: [--Bar.enabled=False]", hmsg) + + def test_aliases(self): + app = MyApp() + app.parse_command_line(["--i=5", "--j=10"]) + app.init_foo() + self.assertEqual(app.foo.i, 5) + app.init_foo() + self.assertEqual(app.foo.j, 10) + + app = MyApp() + app.parse_command_line(["-i=5", "-j=10"]) + app.init_foo() + self.assertEqual(app.foo.i, 5) + app.init_foo() + self.assertEqual(app.foo.j, 10) + + app = MyApp() + app.parse_command_line(["--fooi=5", "--fooj=10"]) + app.init_foo() + self.assertEqual(app.foo.i, 5) + app.init_foo() + self.assertEqual(app.foo.j, 10) + + def test_aliases_multiple(self): + # Test multiple > 2 aliases for the same argument + class TestMultiAliasApp(Application): + foo = Integer(config=True) + aliases = {("f", "bar", "qux"): "TestMultiAliasApp.foo"} + + app = TestMultiAliasApp() + app.parse_command_line(["-f", "3"]) + self.assertEqual(app.foo, 3) + + app = TestMultiAliasApp() + app.parse_command_line(["--bar", "4"]) + self.assertEqual(app.foo, 4) + + app = TestMultiAliasApp() + app.parse_command_line(["--qux", "5"]) + self.assertEqual(app.foo, 5) + + def test_aliases_help_msg(self): + app = MyApp() + stdout = io.StringIO() + with contextlib.redirect_stdout(stdout): + app.print_alias_help() + hmsg = stdout.getvalue() + self.assertRegex(hmsg, "(?<!-)-i, --fooi\\b") + self.assertRegex(hmsg, "(?<!-)-j, --fooj\\b") + self.assertIn("Equivalent to: [--Foo.i]", hmsg) + self.assertIn("Equivalent to: [--Foo.j]", hmsg) + self.assertIn("Equivalent to: [--Foo.name]", hmsg) + + def test_alias_unrecognized(self): + """Check ability to override handling for unrecognized aliases""" + + class StrictLoader(KVArgParseConfigLoader): + def _handle_unrecognized_alias(self, arg): + self.parser.error("Unrecognized alias: %s" % arg) + + class StrictApplication(Application): + def _create_loader(self, argv, aliases, flags, classes): + return StrictLoader(argv, aliases, flags, classes=classes, log=self.log) + + app = StrictApplication() + app.initialize(["--log-level=20"]) # recognized alias + assert app.log_level == 20 + + app = StrictApplication() + with pytest.raises(SystemExit, match="2"): + app.initialize(["--unrecognized=20"]) + + # Ideally we would use pytest capsys fixture, but fixtures are incompatible + # with unittest.TestCase-style classes :( + # stderr = capsys.readouterr().err + # assert "Unrecognized alias: unrecognized" in stderr + + def test_flag_clobber(self): + """test that setting flags doesn't clobber existing settings""" + app = MyApp() + app.parse_command_line(["--Bar.b=5", "--disable"]) + app.init_bar() + self.assertEqual(app.bar.enabled, False) + self.assertEqual(app.bar.b, 5) + app.parse_command_line(["--enable", "--Bar.b=10"]) + app.init_bar() + self.assertEqual(app.bar.enabled, True) + self.assertEqual(app.bar.b, 10) + + def test_warn_autocorrect(self): + stream = StringIO() + app = MyApp(log_level=logging.INFO) + app.log.handlers = [logging.StreamHandler(stream)] + + cfg = Config() + cfg.MyApp.warn_typo = "WOOOO" + app.config = cfg + + self.assertIn("warn_typo", stream.getvalue()) + self.assertIn("warn_tpyo", stream.getvalue()) + + def test_flatten_flags(self): + cfg = Config() + cfg.MyApp.log_level = logging.WARN + app = MyApp() + app.update_config(cfg) + self.assertEqual(app.log_level, logging.WARN) + self.assertEqual(app.config.MyApp.log_level, logging.WARN) + app.initialize(["--crit"]) + self.assertEqual(app.log_level, logging.CRITICAL) + # this would be app.config.Application.log_level if it failed: + self.assertEqual(app.config.MyApp.log_level, logging.CRITICAL) + + def test_flatten_aliases(self): + cfg = Config() + cfg.MyApp.log_level = logging.WARN + app = MyApp() + app.update_config(cfg) + self.assertEqual(app.log_level, logging.WARN) + self.assertEqual(app.config.MyApp.log_level, logging.WARN) + app.initialize(["--log-level", "CRITICAL"]) + self.assertEqual(app.log_level, logging.CRITICAL) + # this would be app.config.Application.log_level if it failed: + self.assertEqual(app.config.MyApp.log_level, "CRITICAL") + + def test_extra_args(self): + + app = MyApp() + app.parse_command_line(["--Bar.b=5", "extra", "args", "--disable"]) + app.init_bar() + self.assertEqual(app.bar.enabled, False) + self.assertEqual(app.bar.b, 5) + self.assertEqual(app.extra_args, ["extra", "args"]) + + app = MyApp() + app.parse_command_line(["--Bar.b=5", "--", "extra", "--disable", "args"]) + app.init_bar() + self.assertEqual(app.bar.enabled, True) + self.assertEqual(app.bar.b, 5) + self.assertEqual(app.extra_args, ["extra", "--disable", "args"]) + + app = MyApp() + app.parse_command_line(["--disable", "--la", "-", "-", "--Bar.b=1", "--", "-", "extra"]) + self.assertEqual(app.extra_args, ["-", "-", "extra"]) + + def test_unicode_argv(self): + app = MyApp() + app.parse_command_line(["ünîcødé"]) + + def test_document_config_option(self): + app = MyApp() + app.document_config_options() + + def test_generate_config_file(self): + app = MyApp() + assert "The integer b." in app.generate_config_file() + + def test_generate_config_file_classes_to_include(self): + class NotInConfig(HasTraits): + from_hidden = Unicode( + "x", + help="""From hidden class + + Details about from_hidden. + """, + ).tag(config=True) + + class NoTraits(Foo, Bar, NotInConfig): + pass + + app = MyApp() + app.classes.append(NoTraits) # type:ignore + + conf_txt = app.generate_config_file() + print(conf_txt) + self.assertIn("The integer b.", conf_txt) + self.assertIn("# Foo(Configurable)", conf_txt) + self.assertNotIn("# Configurable", conf_txt) + self.assertIn("# NoTraits(Foo, Bar)", conf_txt) + + # inherited traits, parent in class list: + self.assertIn("# c.NoTraits.i", conf_txt) + self.assertIn("# c.NoTraits.j", conf_txt) + self.assertIn("# c.NoTraits.n", conf_txt) + self.assertIn("# See also: Foo.j", conf_txt) + self.assertIn("# See also: Bar.b", conf_txt) + self.assertEqual(conf_txt.count("Details about i."), 1) + + # inherited traits, parent not in class list: + self.assertIn("# c.NoTraits.from_hidden", conf_txt) + self.assertNotIn("# See also: NotInConfig.", conf_txt) + self.assertEqual(conf_txt.count("Details about from_hidden."), 1) + self.assertNotIn("NotInConfig", conf_txt) + + def test_multi_file(self): + app = MyApp() + app.log = logging.getLogger() + name = "config.py" + with TemporaryDirectory("_1") as td1: + with open(pjoin(td1, name), "w") as f1: + f1.write("get_config().MyApp.Bar.b = 1") + with TemporaryDirectory("_2") as td2: + with open(pjoin(td2, name), "w") as f2: + f2.write("get_config().MyApp.Bar.b = 2") + app.load_config_file(name, path=[td2, td1]) + app.init_bar() + self.assertEqual(app.bar.b, 2) + app.load_config_file(name, path=[td1, td2]) + app.init_bar() + self.assertEqual(app.bar.b, 1) + + @mark.skipif(not hasattr(TestCase, "assertLogs"), reason="requires TestCase.assertLogs") + def test_log_collisions(self): + app = MyApp() + app.log = logging.getLogger() + app.log.setLevel(logging.INFO) + name = "config" + with TemporaryDirectory("_1") as td: + with open(pjoin(td, name + ".py"), "w") as f: + f.write("get_config().Bar.b = 1") + with open(pjoin(td, name + ".json"), "w") as f: + json.dump({"Bar": {"b": 2}}, f) + with self.assertLogs(app.log, logging.WARNING) as captured: + app.load_config_file(name, path=[td]) + app.init_bar() + assert app.bar.b == 2 + output = "\n".join(captured.output) + assert "Collision" in output + assert "1 ignored, using 2" in output + assert pjoin(td, name + ".py") in output + assert pjoin(td, name + ".json") in output + + @mark.skipif(not hasattr(TestCase, "assertLogs"), reason="requires TestCase.assertLogs") + def test_log_bad_config(self): + app = MyApp() + app.log = logging.getLogger() + name = "config.py" + with TemporaryDirectory() as td: + with open(pjoin(td, name), "w") as f: + f.write("syntax error()") + with self.assertLogs(app.log, logging.ERROR) as captured: + app.load_config_file(name, path=[td]) + output = "\n".join(captured.output) + self.assertIn("SyntaxError", output) + + def test_raise_on_bad_config(self): + app = MyApp() + app.raise_config_file_errors = True + app.log = logging.getLogger() + name = "config.py" + with TemporaryDirectory() as td: + with open(pjoin(td, name), "w") as f: + f.write("syntax error()") + with self.assertRaises(SyntaxError): + app.load_config_file(name, path=[td]) + + def test_subcommands_instanciation(self): + """Try all ways to specify how to create sub-apps.""" + app = Root.instance() + app.parse_command_line(["sub1"]) + + self.assertIsInstance(app.subapp, Sub1) + # Check parent hierarchy. + self.assertIs(app.subapp.parent, app) + + Root.clear_instance() + Sub1.clear_instance() # Otherwise, replaced spuriously and hierarchy check fails. + app = Root.instance() + + app.parse_command_line(["sub1", "sub2"]) + self.assertIsInstance(app.subapp, Sub1) + self.assertIsInstance(app.subapp.subapp, Sub2) + # Check parent hierarchy. + self.assertIs(app.subapp.parent, app) + self.assertIs(app.subapp.subapp.parent, app.subapp) + + Root.clear_instance() + Sub1.clear_instance() # Otherwise, replaced spuriously and hierarchy check fails. + app = Root.instance() + + app.parse_command_line(["sub1", "sub3"]) + self.assertIsInstance(app.subapp, Sub1) + self.assertIsInstance(app.subapp.subapp, Sub3) + self.assertTrue(app.subapp.subapp.flag) # Set by factory. + # Check parent hierarchy. + self.assertIs(app.subapp.parent, app) + self.assertIs(app.subapp.subapp.parent, app.subapp) # Set by factory. + + Root.clear_instance() + Sub1.clear_instance() + + def test_loaded_config_files(self): + app = MyApp() + app.log = logging.getLogger() + name = "config.py" + with TemporaryDirectory("_1") as td1: + config_file = pjoin(td1, name) + with open(config_file, "w") as f: + f.writelines(["c.MyApp.running = True\n"]) + + app.load_config_file(name, path=[td1]) + self.assertEqual(len(app.loaded_config_files), 1) + self.assertEqual(app.loaded_config_files[0], config_file) + + app.start() + self.assertEqual(app.running, True) + + # emulate an app that allows dynamic updates and update config file + with open(config_file, "w") as f: + f.writelines(["c.MyApp.running = False\n"]) + + # reload and verify update, and that loaded_configs was not increased + app.load_config_file(name, path=[td1]) + self.assertEqual(len(app.loaded_config_files), 1) + self.assertEqual(app.running, False) + + # Attempt to update, ensure error... + with self.assertRaises(AttributeError): + app.loaded_config_files = "/foo" # type:ignore + + # ensure it can't be udpated via append + app.loaded_config_files.append("/bar") + self.assertEqual(len(app.loaded_config_files), 1) + + # repeat to ensure no unexpected changes occurred + app.load_config_file(name, path=[td1]) + self.assertEqual(len(app.loaded_config_files), 1) + self.assertEqual(app.running, False) + + +@mark.skip +def test_cli_multi_scalar(caplog): + class App(Application): + aliases = {"opt": "App.opt"} + opt = Unicode(config=True) + + app = App(log=logging.getLogger()) + with pytest.raises(SystemExit): + app.parse_command_line(["--opt", "1", "--opt", "2"]) + record = caplog.get_records("call")[-1] + message = record.message + + assert "Error loading argument" in message + assert "App.opt=['1', '2']" in message + assert "opt only accepts one value" in message + assert record.levelno == logging.CRITICAL + + +class Root(Application): + subcommands = { + "sub1": ("__tests__.config.tests.test_application.Sub1", "import string"), + } + + +class Sub3(Application): + flag = Bool(False) + + +class Sub2(Application): + pass + + +class Sub1(Application): + subcommands: dict = { # type:ignore + "sub2": (Sub2, "Application class"), + "sub3": (lambda root: Sub3(parent=root, flag=True), "factory"), + } + + +class DeprecatedApp(Application): + override_called = False + parent_called = False + + def _config_changed(self, name, old, new): + self.override_called = True + + def _capture(*args): + self.parent_called = True + + with mock.patch.object(self.log, "debug", _capture): + super()._config_changed(name, old, new) + + +def test_deprecated_notifier(): + app = DeprecatedApp() + assert not app.override_called + assert not app.parent_called + app.config = Config({"A": {"b": "c"}}) + assert app.override_called + assert app.parent_called + + +def test_help_output(): + check_help_output(__name__) + + +def test_help_all_output(): + check_help_all_output(__name__) + + +def test_show_config_cli(): + out, err, ec = get_output_error_code([sys.executable, "-m", __name__, "--show-config"]) + assert ec == 0 + assert "show_config" not in out + + +def test_show_config_json_cli(): + out, err, ec = get_output_error_code([sys.executable, "-m", __name__, "--show-config-json"]) + assert ec == 0 + assert "show_config" not in out + + +def test_show_config(capsys): + cfg = Config() + cfg.MyApp.i = 5 + # don't show empty + cfg.OtherApp + + app = MyApp(config=cfg, show_config=True) + app.start() + out, err = capsys.readouterr() + assert "MyApp" in out + assert "i = 5" in out + assert "OtherApp" not in out + + +def test_show_config_json(capsys): + cfg = Config() + cfg.MyApp.i = 5 + cfg.OtherApp + + app = MyApp(config=cfg, show_config_json=True) + app.start() + out, err = capsys.readouterr() + displayed = json.loads(out) + assert Config(displayed) == cfg + + +def test_deep_alias(): + from traitlets import Int + from traitlets.config import Application, Configurable + + class Foo(Configurable): + val = Int(default_value=5).tag(config=True) + + class Bar(Configurable): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.foo = Foo(parent=self) + + class TestApp(Application): + name = "test" + + aliases = {"val": "Bar.Foo.val"} + classes = [Foo, Bar] + + def initialize(self, *args, **kwargs): + super().initialize(*args, **kwargs) + self.bar = Bar(parent=self) + + app = TestApp() + app.initialize(["--val=10"]) + assert app.bar.foo.val == 10 + assert len(list(app.emit_alias_help())) > 0 + + +def test_logging_config(tmp_path, capsys): + """We should be able to configure additional log handlers.""" + log_file = tmp_path / "log_file" + app = Application( + logging_config={ + "version": 1, + "handlers": { + "file": { + "class": "logging.FileHandler", + "level": "DEBUG", + "filename": str(log_file), + }, + }, + "loggers": { + "Application": { + "level": "DEBUG", + "handlers": ["console", "file"], + }, + }, + } + ) + # the default "console" handler + our new "file" handler + assert len(app.log.handlers) == 2 + + # log a couple of messages + app.log.info("info") + app.log.warning("warn") + + # test that log messages get written to the file + with open(log_file) as log_handle: + assert log_handle.read() == "info\nwarn\n" + + # test that log messages get written to stderr (default console handler) + assert capsys.readouterr().err == "[Application] WARNING | warn\n" + + +def test_get_default_logging_config_pythonw(monkeypatch): + """Ensure logging is correctly disabled for pythonw usage.""" + monkeypatch.setattr("traitlets.config.application.IS_PYTHONW", True) + config = Application().get_default_logging_config() + assert "handlers" not in config + assert "loggers" not in config + + monkeypatch.setattr("traitlets.config.application.IS_PYTHONW", False) + config = Application().get_default_logging_config() + assert "handlers" in config + assert "loggers" in config + + +@pytest.fixture +def caplogconfig(monkeypatch): + """Capture logging config events for DictConfigurator objects. + + This suppresses the event (so the configuration doesn't happen). + + Returns a list of (args, kwargs). + """ + calls = [] + + def _configure(*args, **kwargs): + nonlocal calls + calls.append((args, kwargs)) + + monkeypatch.setattr( + "logging.config.DictConfigurator.configure", + _configure, + ) + + return calls + + +@pytest.mark.skipif(sys.implementation.name == "pypy", reason="Test does not work on pypy") +def test_logging_teardown_on_error(capsys, caplogconfig): + """Ensure we don't try to open logs in order to close them (See #722). + + If you try to configure logging handlers whilst Python is shutting down + you may get traceback. + """ + # create and destroy an app (without configuring logging) + # (ensure that the logs were not configured) + app = Application() + del app + assert len(caplogconfig) == 0 # logging was not configured + out, err = capsys.readouterr() + assert "Traceback" not in err + + # ensure that the logs would have been configured otherwise + # (to prevent this test from yielding false-negatives) + app = Application() + app._logging_configured = True # make it look like logging was configured + del app + assert len(caplogconfig) == 1 # logging was configured + + +if __name__ == "__main__": + # for test_help_output: + MyApp.launch_instance() diff --git a/contrib/python/traitlets/py3/traitlets/config/tests/test_configurable.py b/contrib/python/traitlets/py3/traitlets/config/tests/test_configurable.py new file mode 100644 index 0000000000..1769976601 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/tests/test_configurable.py @@ -0,0 +1,712 @@ +"""Tests for traitlets.config.configurable""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +import logging +from unittest import TestCase + +from pytest import mark + +from traitlets.config.application import Application +from traitlets.config.configurable import Configurable, LoggingConfigurable, SingletonConfigurable +from traitlets.config.loader import Config +from traitlets.log import get_logger +from traitlets.traitlets import ( + CaselessStrEnum, + Dict, + Enum, + Float, + FuzzyEnum, + Integer, + List, + Set, + Unicode, + _deprecations_shown, + validate, +) + +from traitlets.tests._warnings import expected_warnings + + +class MyConfigurable(Configurable): + a = Integer(1, help="The integer a.").tag(config=True) + b = Float(1.0, help="The integer b.").tag(config=True) + c = Unicode("no config") + + +mc_help = """MyConfigurable(Configurable) options +------------------------------------ +--MyConfigurable.a=<Integer> + The integer a. + Default: 1 +--MyConfigurable.b=<Float> + The integer b. + Default: 1.0""" + +mc_help_inst = """MyConfigurable(Configurable) options +------------------------------------ +--MyConfigurable.a=<Integer> + The integer a. + Current: 5 +--MyConfigurable.b=<Float> + The integer b. + Current: 4.0""" + +# On Python 3, the Integer trait is a synonym for Int +mc_help = mc_help.replace("<Integer>", "<Int>") +mc_help_inst = mc_help_inst.replace("<Integer>", "<Int>") + + +class Foo(Configurable): + a = Integer(0, help="The integer a.").tag(config=True) + b = Unicode("nope").tag(config=True) + flist = List([]).tag(config=True) + fdict = Dict().tag(config=True) + + +class Bar(Foo): + b = Unicode("gotit", help="The string b.").tag(config=False) + c = Float(help="The string c.").tag(config=True) + bset = Set([]).tag(config=True, multiplicity="+") + bset_values = Set([2, 1, 5]).tag(config=True, multiplicity="+") + bdict = Dict().tag(config=True, multiplicity="+") + bdict_values = Dict({1: "a", "0": "b", 5: "c"}).tag(config=True, multiplicity="+") + + +foo_help = """Foo(Configurable) options +------------------------- +--Foo.a=<Int> + The integer a. + Default: 0 +--Foo.b=<Unicode> + Default: 'nope' +--Foo.fdict=<key-1>=<value-1>... + Default: {} +--Foo.flist=<list-item-1>... + Default: []""" + +bar_help = """Bar(Foo) options +---------------- +--Bar.a=<Int> + The integer a. + Default: 0 +--Bar.bdict <key-1>=<value-1>... + Default: {} +--Bar.bdict_values <key-1>=<value-1>... + Default: {1: 'a', '0': 'b', 5: 'c'} +--Bar.bset <set-item-1>... + Default: set() +--Bar.bset_values <set-item-1>... + Default: {1, 2, 5} +--Bar.c=<Float> + The string c. + Default: 0.0 +--Bar.fdict=<key-1>=<value-1>... + Default: {} +--Bar.flist=<list-item-1>... + Default: []""" + + +class TestConfigurable(TestCase): + def test_default(self): + c1 = Configurable() + c2 = Configurable(config=c1.config) + c3 = Configurable(config=c2.config) + self.assertEqual(c1.config, c2.config) + self.assertEqual(c2.config, c3.config) + + def test_custom(self): + config = Config() + config.foo = "foo" + config.bar = "bar" + c1 = Configurable(config=config) + c2 = Configurable(config=c1.config) + c3 = Configurable(config=c2.config) + self.assertEqual(c1.config, config) + self.assertEqual(c2.config, config) + self.assertEqual(c3.config, config) + # Test that copies are not made + self.assertTrue(c1.config is config) + self.assertTrue(c2.config is config) + self.assertTrue(c3.config is config) + self.assertTrue(c1.config is c2.config) + self.assertTrue(c2.config is c3.config) + + def test_inheritance(self): + config = Config() + config.MyConfigurable.a = 2 + config.MyConfigurable.b = 2.0 + c1 = MyConfigurable(config=config) + c2 = MyConfigurable(config=c1.config) + self.assertEqual(c1.a, config.MyConfigurable.a) + self.assertEqual(c1.b, config.MyConfigurable.b) + self.assertEqual(c2.a, config.MyConfigurable.a) + self.assertEqual(c2.b, config.MyConfigurable.b) + + def test_parent(self): + config = Config() + config.Foo.a = 10 + config.Foo.b = "wow" + config.Bar.b = "later" + config.Bar.c = 100.0 + f = Foo(config=config) + with expected_warnings(["`b` not recognized"]): + b = Bar(config=f.config) + self.assertEqual(f.a, 10) + self.assertEqual(f.b, "wow") + self.assertEqual(b.b, "gotit") + self.assertEqual(b.c, 100.0) + + def test_override1(self): + config = Config() + config.MyConfigurable.a = 2 + config.MyConfigurable.b = 2.0 + c = MyConfigurable(a=3, config=config) + self.assertEqual(c.a, 3) + self.assertEqual(c.b, config.MyConfigurable.b) + self.assertEqual(c.c, "no config") + + def test_override2(self): + config = Config() + config.Foo.a = 1 + config.Bar.b = "or" # Up above b is config=False, so this won't do it. + config.Bar.c = 10.0 + with expected_warnings(["`b` not recognized"]): + c = Bar(config=config) + self.assertEqual(c.a, config.Foo.a) + self.assertEqual(c.b, "gotit") + self.assertEqual(c.c, config.Bar.c) + with expected_warnings(["`b` not recognized"]): + c = Bar(a=2, b="and", c=20.0, config=config) + self.assertEqual(c.a, 2) + self.assertEqual(c.b, "and") + self.assertEqual(c.c, 20.0) + + def test_help(self): + self.assertEqual(MyConfigurable.class_get_help(), mc_help) + self.assertEqual(Foo.class_get_help(), foo_help) + self.assertEqual(Bar.class_get_help(), bar_help) + + def test_help_inst(self): + inst = MyConfigurable(a=5, b=4) + self.assertEqual(MyConfigurable.class_get_help(inst), mc_help_inst) + + def test_generated_config_enum_comments(self): + class MyConf(Configurable): + an_enum = Enum("Choice1 choice2".split(), help="Many choices.").tag(config=True) + + help_str = "Many choices." + enum_choices_str = "Choices: any of ['Choice1', 'choice2']" + rst_choices_str = "MyConf.an_enum : any of ``'Choice1'``|``'choice2'``" + or_none_str = "or None" + + cls_help = MyConf.class_get_help() + + self.assertIn(help_str, cls_help) + self.assertIn(enum_choices_str, cls_help) + self.assertNotIn(or_none_str, cls_help) + + cls_cfg = MyConf.class_config_section() + + self.assertIn(help_str, cls_cfg) + self.assertIn(enum_choices_str, cls_cfg) + self.assertNotIn(or_none_str, cls_help) + # Check order of Help-msg <--> Choices sections + self.assertGreater(cls_cfg.index(enum_choices_str), cls_cfg.index(help_str)) + + rst_help = MyConf.class_config_rst_doc() + + self.assertIn(help_str, rst_help) + self.assertIn(rst_choices_str, rst_help) + self.assertNotIn(or_none_str, rst_help) + + class MyConf2(Configurable): + an_enum = Enum( + "Choice1 choice2".split(), + allow_none=True, + default_value="choice2", + help="Many choices.", + ).tag(config=True) + + defaults_str = "Default: 'choice2'" + + cls2_msg = MyConf2.class_get_help() + + self.assertIn(help_str, cls2_msg) + self.assertIn(enum_choices_str, cls2_msg) + self.assertIn(or_none_str, cls2_msg) + self.assertIn(defaults_str, cls2_msg) + # Check order of Default <--> Choices sections + self.assertGreater(cls2_msg.index(defaults_str), cls2_msg.index(enum_choices_str)) + + cls2_cfg = MyConf2.class_config_section() + + self.assertIn(help_str, cls2_cfg) + self.assertIn(enum_choices_str, cls2_cfg) + self.assertIn(or_none_str, cls2_cfg) + self.assertIn(defaults_str, cls2_cfg) + # Check order of Default <--> Choices sections + self.assertGreater(cls2_cfg.index(defaults_str), cls2_cfg.index(enum_choices_str)) + + def test_generated_config_strenum_comments(self): + help_str = "Many choices." + defaults_str = "Default: 'choice2'" + or_none_str = "or None" + + class MyConf3(Configurable): + an_enum = CaselessStrEnum( + "Choice1 choice2".split(), + allow_none=True, + default_value="choice2", + help="Many choices.", + ).tag(config=True) + + enum_choices_str = "Choices: any of ['Choice1', 'choice2'] (case-insensitive)" + + cls3_msg = MyConf3.class_get_help() + + self.assertIn(help_str, cls3_msg) + self.assertIn(enum_choices_str, cls3_msg) + self.assertIn(or_none_str, cls3_msg) + self.assertIn(defaults_str, cls3_msg) + # Check order of Default <--> Choices sections + self.assertGreater(cls3_msg.index(defaults_str), cls3_msg.index(enum_choices_str)) + + cls3_cfg = MyConf3.class_config_section() + + self.assertIn(help_str, cls3_cfg) + self.assertIn(enum_choices_str, cls3_cfg) + self.assertIn(or_none_str, cls3_cfg) + self.assertIn(defaults_str, cls3_cfg) + # Check order of Default <--> Choices sections + self.assertGreater(cls3_cfg.index(defaults_str), cls3_cfg.index(enum_choices_str)) + + class MyConf4(Configurable): + an_enum = FuzzyEnum( + "Choice1 choice2".split(), + allow_none=True, + default_value="choice2", + help="Many choices.", + ).tag(config=True) + + enum_choices_str = "Choices: any case-insensitive prefix of ['Choice1', 'choice2']" + + cls4_msg = MyConf4.class_get_help() + + self.assertIn(help_str, cls4_msg) + self.assertIn(enum_choices_str, cls4_msg) + self.assertIn(or_none_str, cls4_msg) + self.assertIn(defaults_str, cls4_msg) + # Check order of Default <--> Choices sections + self.assertGreater(cls4_msg.index(defaults_str), cls4_msg.index(enum_choices_str)) + + cls4_cfg = MyConf4.class_config_section() + + self.assertIn(help_str, cls4_cfg) + self.assertIn(enum_choices_str, cls4_cfg) + self.assertIn(or_none_str, cls4_cfg) + self.assertIn(defaults_str, cls4_cfg) + # Check order of Default <--> Choices sections + self.assertGreater(cls4_cfg.index(defaults_str), cls4_cfg.index(enum_choices_str)) + + +class TestSingletonConfigurable(TestCase): + def test_instance(self): + class Foo(SingletonConfigurable): + pass + + self.assertEqual(Foo.initialized(), False) + foo = Foo.instance() + self.assertEqual(Foo.initialized(), True) + self.assertEqual(foo, Foo.instance()) + self.assertEqual(SingletonConfigurable._instance, None) + + def test_inheritance(self): + class Bar(SingletonConfigurable): + pass + + class Bam(Bar): + pass + + self.assertEqual(Bar.initialized(), False) + self.assertEqual(Bam.initialized(), False) + bam = Bam.instance() + self.assertEqual(Bar.initialized(), True) + self.assertEqual(Bam.initialized(), True) + self.assertEqual(bam, Bam._instance) + self.assertEqual(bam, Bar._instance) + self.assertEqual(SingletonConfigurable._instance, None) + + +class TestLoggingConfigurable(TestCase): + def test_parent_logger(self): + class Parent(LoggingConfigurable): + pass + + class Child(LoggingConfigurable): + pass + + log = get_logger().getChild("TestLoggingConfigurable") + + parent = Parent(log=log) + child = Child(parent=parent) + self.assertEqual(parent.log, log) + self.assertEqual(child.log, log) + + parent = Parent() + child = Child(parent=parent, log=log) + self.assertEqual(parent.log, get_logger()) + self.assertEqual(child.log, log) + + def test_parent_not_logging_configurable(self): + class Parent(Configurable): + pass + + class Child(LoggingConfigurable): + pass + + parent = Parent() + child = Child(parent=parent) + self.assertEqual(child.log, get_logger()) + + +class MyParent(Configurable): + pass + + +class MyParent2(MyParent): + pass + + +class TestParentConfigurable(TestCase): + def test_parent_config(self): + cfg = Config( + { + "MyParent": { + "MyConfigurable": { + "b": 2.0, + } + } + } + ) + parent = MyParent(config=cfg) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent.MyConfigurable.b) + + def test_parent_inheritance(self): + cfg = Config( + { + "MyParent": { + "MyConfigurable": { + "b": 2.0, + } + } + } + ) + parent = MyParent2(config=cfg) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent.MyConfigurable.b) + + def test_multi_parent(self): + cfg = Config( + { + "MyParent2": { + "MyParent": { + "MyConfigurable": { + "b": 2.0, + } + }, + # this one shouldn't count + "MyConfigurable": { + "b": 3.0, + }, + } + } + ) + parent2 = MyParent2(config=cfg) + parent = MyParent(parent=parent2) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent2.MyParent.MyConfigurable.b) + + def test_parent_priority(self): + cfg = Config( + { + "MyConfigurable": { + "b": 2.0, + }, + "MyParent": { + "MyConfigurable": { + "b": 3.0, + } + }, + "MyParent2": { + "MyConfigurable": { + "b": 4.0, + } + }, + } + ) + parent = MyParent2(config=cfg) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent2.MyConfigurable.b) + + def test_multi_parent_priority(self): + cfg = Config( + { + "MyConfigurable": { + "b": 2.0, + }, + "MyParent": { + "MyConfigurable": { + "b": 3.0, + }, + }, + "MyParent2": { + "MyConfigurable": { + "b": 4.0, + }, + "MyParent": { + "MyConfigurable": { + "b": 5.0, + }, + }, + }, + } + ) + parent2 = MyParent2(config=cfg) + parent = MyParent2(parent=parent2) + myc = MyConfigurable(parent=parent) + self.assertEqual(myc.b, parent.config.MyParent2.MyParent.MyConfigurable.b) + + +class Containers(Configurable): + lis = List().tag(config=True) + + def _lis_default(self): + return [-1] + + s = Set().tag(config=True) + + def _s_default(self): + return {"a"} + + d = Dict().tag(config=True) + + def _d_default(self): + return {"a": "b"} + + +class TestConfigContainers(TestCase): + def test_extend(self): + c = Config() + c.Containers.lis.extend(list(range(5))) + obj = Containers(config=c) + self.assertEqual(obj.lis, list(range(-1, 5))) + + def test_insert(self): + c = Config() + c.Containers.lis.insert(0, "a") + c.Containers.lis.insert(1, "b") + obj = Containers(config=c) + self.assertEqual(obj.lis, ["a", "b", -1]) + + def test_prepend(self): + c = Config() + c.Containers.lis.prepend([1, 2]) + c.Containers.lis.prepend([2, 3]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [2, 3, 1, 2, -1]) + + def test_prepend_extend(self): + c = Config() + c.Containers.lis.prepend([1, 2]) + c.Containers.lis.extend([2, 3]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [1, 2, -1, 2, 3]) + + def test_append_extend(self): + c = Config() + c.Containers.lis.append([1, 2]) + c.Containers.lis.extend([2, 3]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [-1, [1, 2], 2, 3]) + + def test_extend_append(self): + c = Config() + c.Containers.lis.extend([2, 3]) + c.Containers.lis.append([1, 2]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [-1, 2, 3, [1, 2]]) + + def test_insert_extend(self): + c = Config() + c.Containers.lis.insert(0, 1) + c.Containers.lis.extend([2, 3]) + obj = Containers(config=c) + self.assertEqual(obj.lis, [1, -1, 2, 3]) + + def test_set_update(self): + c = Config() + c.Containers.s.update({0, 1, 2}) + c.Containers.s.update({3}) + obj = Containers(config=c) + self.assertEqual(obj.s, {"a", 0, 1, 2, 3}) + + def test_dict_update(self): + c = Config() + c.Containers.d.update({"c": "d"}) + c.Containers.d.update({"e": "f"}) + obj = Containers(config=c) + self.assertEqual(obj.d, {"a": "b", "c": "d", "e": "f"}) + + def test_update_twice(self): + c = Config() + c.MyConfigurable.a = 5 + m = MyConfigurable(config=c) + self.assertEqual(m.a, 5) + + c2 = Config() + c2.MyConfigurable.a = 10 + m.update_config(c2) + self.assertEqual(m.a, 10) + + c2.MyConfigurable.a = 15 + m.update_config(c2) + self.assertEqual(m.a, 15) + + def test_update_self(self): + """update_config with same config object still triggers config_changed""" + c = Config() + c.MyConfigurable.a = 5 + m = MyConfigurable(config=c) + self.assertEqual(m.a, 5) + c.MyConfigurable.a = 10 + m.update_config(c) + self.assertEqual(m.a, 10) + + def test_config_default(self): + class SomeSingleton(SingletonConfigurable): + pass + + class DefaultConfigurable(Configurable): + a = Integer().tag(config=True) + + def _config_default(self): + if SomeSingleton.initialized(): + return SomeSingleton.instance().config + return Config() + + c = Config() + c.DefaultConfigurable.a = 5 + + d1 = DefaultConfigurable() + self.assertEqual(d1.a, 0) + + single = SomeSingleton.instance(config=c) + + d2 = DefaultConfigurable() + self.assertIs(d2.config, single.config) + self.assertEqual(d2.a, 5) + + def test_config_default_deprecated(self): + """Make sure configurables work even with the deprecations in traitlets""" + + class SomeSingleton(SingletonConfigurable): + pass + + # reset deprecation limiter + _deprecations_shown.clear() + with expected_warnings([]): + + class DefaultConfigurable(Configurable): + a = Integer(config=True) + + def _config_default(self): + if SomeSingleton.initialized(): + return SomeSingleton.instance().config + return Config() + + c = Config() + c.DefaultConfigurable.a = 5 + + d1 = DefaultConfigurable() + self.assertEqual(d1.a, 0) + + single = SomeSingleton.instance(config=c) + + d2 = DefaultConfigurable() + self.assertIs(d2.config, single.config) + self.assertEqual(d2.a, 5) + + def test_kwarg_config_priority(self): + # a, c set in kwargs + # a, b set in config + # verify that: + # - kwargs are set before config + # - kwargs have priority over config + class A(Configurable): + a = Unicode("default", config=True) + b = Unicode("default", config=True) + c = Unicode("default", config=True) + c_during_config = Unicode("never") + + @validate("b") + def _record_c(self, proposal): + # setting b from config records c's value at the time + self.c_during_config = self.c + return proposal.value + + cfg = Config() + cfg.A.a = "a-config" + cfg.A.b = "b-config" + obj = A(a="a-kwarg", c="c-kwarg", config=cfg) + assert obj.a == "a-kwarg" + assert obj.b == "b-config" + assert obj.c == "c-kwarg" + assert obj.c_during_config == "c-kwarg" + + +class TestLogger(TestCase): + class A(LoggingConfigurable): + foo = Integer(config=True) + bar = Integer(config=True) + baz = Integer(config=True) + + @mark.skipif(not hasattr(TestCase, "assertLogs"), reason="requires TestCase.assertLogs") + def test_warn_match(self): + logger = logging.getLogger("test_warn_match") + cfg = Config({"A": {"bat": 5}}) + with self.assertLogs(logger, logging.WARNING) as captured: + TestLogger.A(config=cfg, log=logger) + + output = "\n".join(captured.output) + self.assertIn("Did you mean one of: `bar, baz`?", output) + self.assertIn("Config option `bat` not recognized by `A`.", output) + + cfg = Config({"A": {"fool": 5}}) + with self.assertLogs(logger, logging.WARNING) as captured: + TestLogger.A(config=cfg, log=logger) + + output = "\n".join(captured.output) + self.assertIn("Config option `fool` not recognized by `A`.", output) + self.assertIn("Did you mean `foo`?", output) + + cfg = Config({"A": {"totally_wrong": 5}}) + with self.assertLogs(logger, logging.WARNING) as captured: + TestLogger.A(config=cfg, log=logger) + + output = "\n".join(captured.output) + self.assertIn("Config option `totally_wrong` not recognized by `A`.", output) + self.assertNotIn("Did you mean", output) + + +def test_logger_adapter(caplog, capsys): + logger = logging.getLogger("Application") + adapter = logging.LoggerAdapter(logger, {"key": "adapted"}) + + app = Application(log=adapter, log_level=logging.INFO) + app.log_format = "%(key)s %(message)s" + app.log.info("test message") + + assert "adapted test message" in capsys.readouterr().err diff --git a/contrib/python/traitlets/py3/traitlets/config/tests/test_loader.py b/contrib/python/traitlets/py3/traitlets/config/tests/test_loader.py new file mode 100644 index 0000000000..7355544c77 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/config/tests/test_loader.py @@ -0,0 +1,754 @@ +"""Tests for traitlets.config.loader""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +import copy +import os +import pickle +from itertools import chain +from tempfile import mkstemp +from unittest import TestCase + +import pytest + +from traitlets import Dict, Integer, List, Tuple, Unicode +from traitlets.config import Configurable +from traitlets.config.loader import ( + ArgParseConfigLoader, + Config, + JSONFileConfigLoader, + KeyValueConfigLoader, + KVArgParseConfigLoader, + LazyConfigValue, + PyFileConfigLoader, +) + +pyfile = """ +c = get_config() +c.a=10 +c.b=20 +c.Foo.Bar.value=10 +c.Foo.Bam.value=list(range(10)) +c.D.C.value='hi there' +""" + +json1file = """ +{ + "version": 1, + "a": 10, + "b": 20, + "Foo": { + "Bam": { + "value": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 ] + }, + "Bar": { + "value": 10 + } + }, + "D": { + "C": { + "value": "hi there" + } + } +} +""" + +# should not load +json2file = """ +{ + "version": 2 +} +""" + +import logging + +log = logging.getLogger("devnull") +log.setLevel(0) + + +class TestFileCL(TestCase): + def _check_conf(self, config): + self.assertEqual(config.a, 10) + self.assertEqual(config.b, 20) + self.assertEqual(config.Foo.Bar.value, 10) + self.assertEqual(config.Foo.Bam.value, list(range(10))) + self.assertEqual(config.D.C.value, "hi there") + + def test_python(self): + fd, fname = mkstemp(".py", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write(pyfile) + f.close() + # Unlink the file + cl = PyFileConfigLoader(fname, log=log) + config = cl.load_config() + self._check_conf(config) + + def test_json(self): + fd, fname = mkstemp(".json", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write(json1file) + f.close() + # Unlink the file + cl = JSONFileConfigLoader(fname, log=log) + config = cl.load_config() + self._check_conf(config) + + def test_context_manager(self): + + fd, fname = mkstemp(".json", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write("{}") + f.close() + + cl = JSONFileConfigLoader(fname, log=log) + + value = "context_manager" + + with cl as c: + c.MyAttr.value = value + + self.assertEqual(cl.config.MyAttr.value, value) + + # check that another loader does see the change + _ = JSONFileConfigLoader(fname, log=log) + self.assertEqual(cl.config.MyAttr.value, value) + + def test_json_context_bad_write(self): + fd, fname = mkstemp(".json", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write("{}") + f.close() + + with JSONFileConfigLoader(fname, log=log) as config: + config.A.b = 1 + + with self.assertRaises(TypeError): + with JSONFileConfigLoader(fname, log=log) as config: + config.A.cant_json = lambda x: x + + loader = JSONFileConfigLoader(fname, log=log) + cfg = loader.load_config() + assert cfg.A.b == 1 + assert "cant_json" not in cfg.A + + def test_collision(self): + a = Config() + b = Config() + self.assertEqual(a.collisions(b), {}) + a.A.trait1 = 1 + b.A.trait2 = 2 + self.assertEqual(a.collisions(b), {}) + b.A.trait1 = 1 + self.assertEqual(a.collisions(b), {}) + b.A.trait1 = 0 + self.assertEqual( + a.collisions(b), + { + "A": { + "trait1": "1 ignored, using 0", + } + }, + ) + self.assertEqual( + b.collisions(a), + { + "A": { + "trait1": "0 ignored, using 1", + } + }, + ) + a.A.trait2 = 3 + self.assertEqual( + b.collisions(a), + { + "A": { + "trait1": "0 ignored, using 1", + "trait2": "2 ignored, using 3", + } + }, + ) + + def test_v2raise(self): + fd, fname = mkstemp(".json", prefix="μnïcø∂e") + f = os.fdopen(fd, "w") + f.write(json2file) + f.close() + # Unlink the file + cl = JSONFileConfigLoader(fname, log=log) + with self.assertRaises(ValueError): + cl.load_config() + + +def _parse_int_or_str(v): + try: + return int(v) + except Exception: + return str(v) + + +class MyLoader1(ArgParseConfigLoader): + def _add_arguments(self, aliases=None, flags=None, classes=None): + p = self.parser + p.add_argument("-f", "--foo", dest="Global.foo", type=str) + p.add_argument("-b", dest="MyClass.bar", type=int) + p.add_argument("-n", dest="n", action="store_true") + p.add_argument("Global.bam", type=str) + p.add_argument("--list1", action="append", type=_parse_int_or_str) + p.add_argument("--list2", nargs="+", type=int) + + +class MyLoader2(ArgParseConfigLoader): + def _add_arguments(self, aliases=None, flags=None, classes=None): + subparsers = self.parser.add_subparsers(dest="subparser_name") + subparser1 = subparsers.add_parser("1") + subparser1.add_argument("-x", dest="Global.x") + subparser2 = subparsers.add_parser("2") + subparser2.add_argument("y") + + +class TestArgParseCL(TestCase): + def test_basic(self): + cl = MyLoader1() + config = cl.load_config("-f hi -b 10 -n wow".split()) + self.assertEqual(config.Global.foo, "hi") + self.assertEqual(config.MyClass.bar, 10) + self.assertEqual(config.n, True) + self.assertEqual(config.Global.bam, "wow") + config = cl.load_config(["wow"]) + self.assertEqual(list(config.keys()), ["Global"]) + self.assertEqual(list(config.Global.keys()), ["bam"]) + self.assertEqual(config.Global.bam, "wow") + + def test_add_arguments(self): + cl = MyLoader2() + config = cl.load_config("2 frobble".split()) + self.assertEqual(config.subparser_name, "2") + self.assertEqual(config.y, "frobble") + config = cl.load_config("1 -x frobble".split()) + self.assertEqual(config.subparser_name, "1") + self.assertEqual(config.Global.x, "frobble") + + def test_argv(self): + cl = MyLoader1(argv="-f hi -b 10 -n wow".split()) + config = cl.load_config() + self.assertEqual(config.Global.foo, "hi") + self.assertEqual(config.MyClass.bar, 10) + self.assertEqual(config.n, True) + self.assertEqual(config.Global.bam, "wow") + + def test_list_args(self): + cl = MyLoader1() + config = cl.load_config("--list1 1 wow --list2 1 2 3 --list1 B".split()) + self.assertEqual(list(config.Global.keys()), ["bam"]) + self.assertEqual(config.Global.bam, "wow") + self.assertEqual(config.list1, [1, "B"]) + self.assertEqual(config.list2, [1, 2, 3]) + + +class C(Configurable): + str_trait = Unicode(config=True) + int_trait = Integer(config=True) + list_trait = List(config=True) + list_of_ints = List(Integer(), config=True) + dict_trait = Dict(config=True) + dict_of_ints = Dict( + key_trait=Integer(), + value_trait=Integer(), + config=True, + ) + dict_multi = Dict( + key_trait=Unicode(), + per_key_traits={ + "int": Integer(), + "str": Unicode(), + }, + config=True, + ) + + +class TestKeyValueCL(TestCase): + klass = KeyValueConfigLoader + + def test_eval(self): + cl = self.klass(log=log) + config = cl.load_config( + '--C.str_trait=all --C.int_trait=5 --C.list_trait=["hello",5]'.split() + ) + c = C(config=config) + assert c.str_trait == "all" + assert c.int_trait == 5 + assert c.list_trait == ["hello", 5] + + def test_basic(self): + cl = self.klass(log=log) + argv = ["--" + s[2:] for s in pyfile.split("\n") if s.startswith("c.")] + config = cl.load_config(argv) + assert config.a == "10" + assert config.b == "20" + assert config.Foo.Bar.value == "10" + # non-literal expressions are not evaluated + self.assertEqual(config.Foo.Bam.value, "list(range(10))") + self.assertEqual(Unicode().from_string(config.D.C.value), "hi there") + + def test_expanduser(self): + cl = self.klass(log=log) + argv = ["--a=~/1/2/3", "--b=~", "--c=~/", '--d="~/"'] + config = cl.load_config(argv) + u = Unicode() + self.assertEqual(u.from_string(config.a), os.path.expanduser("~/1/2/3")) + self.assertEqual(u.from_string(config.b), os.path.expanduser("~")) + self.assertEqual(u.from_string(config.c), os.path.expanduser("~/")) + self.assertEqual(u.from_string(config.d), "~/") + + def test_extra_args(self): + cl = self.klass(log=log) + config = cl.load_config(["--a=5", "b", "d", "--c=10"]) + self.assertEqual(cl.extra_args, ["b", "d"]) + assert config.a == "5" + assert config.c == "10" + config = cl.load_config(["--", "--a=5", "--c=10"]) + self.assertEqual(cl.extra_args, ["--a=5", "--c=10"]) + + cl = self.klass(log=log) + config = cl.load_config(["extra", "--a=2", "--c=1", "--", "-"]) + self.assertEqual(cl.extra_args, ["extra", "-"]) + + def test_unicode_args(self): + cl = self.klass(log=log) + argv = ["--a=épsîlön"] + config = cl.load_config(argv) + print(config, cl.extra_args) + self.assertEqual(config.a, "épsîlön") + + def test_list_append(self): + cl = self.klass(log=log) + argv = ["--C.list_trait", "x", "--C.list_trait", "y"] + config = cl.load_config(argv) + assert config.C.list_trait == ["x", "y"] + c = C(config=config) + assert c.list_trait == ["x", "y"] + + def test_list_single_item(self): + cl = self.klass(log=log) + argv = ["--C.list_trait", "x"] + config = cl.load_config(argv) + c = C(config=config) + assert c.list_trait == ["x"] + + def test_dict(self): + cl = self.klass(log=log) + argv = ["--C.dict_trait", "x=5", "--C.dict_trait", "y=10"] + config = cl.load_config(argv) + c = C(config=config) + assert c.dict_trait == {"x": "5", "y": "10"} + + def test_dict_key_traits(self): + cl = self.klass(log=log) + argv = ["--C.dict_of_ints", "1=2", "--C.dict_of_ints", "3=4"] + config = cl.load_config(argv) + c = C(config=config) + assert c.dict_of_ints == {1: 2, 3: 4} + + +class CBase(Configurable): + a = List().tag(config=True) + b = List(Integer()).tag(config=True, multiplicity="*") + c = List().tag(config=True, multiplicity="append") + adict = Dict().tag(config=True) + + +class CSub(CBase): + d = Tuple().tag(config=True) + e = Tuple().tag(config=True, multiplicity="+") + bdict = Dict().tag(config=True, multiplicity="*") + + +class TestArgParseKVCL(TestKeyValueCL): + klass = KVArgParseConfigLoader # type:ignore + + def test_no_cast_literals(self): + cl = self.klass(log=log) # type:ignore + # test ipython -c 1 doesn't cast to int + argv = ["-c", "1"] + config = cl.load_config(argv, aliases=dict(c="IPython.command_to_run")) + assert config.IPython.command_to_run == "1" + + def test_int_literals(self): + cl = self.klass(log=log) # type:ignore + # test ipython -c 1 doesn't cast to int + argv = ["-c", "1"] + config = cl.load_config(argv, aliases=dict(c="IPython.command_to_run")) + assert config.IPython.command_to_run == "1" + + def test_unicode_alias(self): + cl = self.klass(log=log) # type:ignore + argv = ["--a=épsîlön"] + config = cl.load_config(argv, aliases=dict(a="A.a")) + print(dict(config)) + print(cl.extra_args) + print(cl.aliases) + self.assertEqual(config.A.a, "épsîlön") + + def test_expanduser2(self): + cl = self.klass(log=log) # type:ignore + argv = ["-a", "~/1/2/3", "--b", "'~/1/2/3'"] + config = cl.load_config(argv, aliases=dict(a="A.a", b="A.b")) + + class A(Configurable): + a = Unicode(config=True) + b = Unicode(config=True) + + a = A(config=config) + self.assertEqual(a.a, os.path.expanduser("~/1/2/3")) + self.assertEqual(a.b, "~/1/2/3") + + def test_eval(self): + cl = self.klass(log=log) # type:ignore + argv = ["-c", "a=5"] + config = cl.load_config(argv, aliases=dict(c="A.c")) + self.assertEqual(config.A.c, "a=5") + + def test_seq_traits(self): + cl = self.klass(log=log, classes=(CBase, CSub)) # type:ignore + aliases = {"a3": "CBase.c", "a5": "CSub.e"} + argv = ( + "--CBase.a A --CBase.a 2 --CBase.b 1 --CBase.b 3 --a3 AA --CBase.c BB " + "--CSub.d 1 --CSub.d BBB --CSub.e 1 --CSub.e=bcd a b c " + ).split() + config = cl.load_config(argv, aliases=aliases) + assert cl.extra_args == ["a", "b", "c"] + assert config.CBase.a == ["A", "2"] + assert config.CBase.b == [1, 3] + self.assertEqual(config.CBase.c, ["AA", "BB"]) + + assert config.CSub.d == ("1", "BBB") + assert config.CSub.e == ("1", "bcd") + + def test_seq_traits_single_empty_string(self): + cl = self.klass(log=log, classes=(CBase,)) # type:ignore + aliases = {"seqopt": "CBase.c"} + argv = ["--seqopt", ""] + config = cl.load_config(argv, aliases=aliases) + self.assertEqual(config.CBase.c, [""]) + + def test_dict_traits(self): + cl = self.klass(log=log, classes=(CBase, CSub)) # type:ignore + aliases = {"D": "CBase.adict", "E": "CSub.bdict"} + argv = ["-D", "k1=v1", "-D=k2=2", "-D", "k3=v 3", "-E", "k=v", "-E", "22=222"] + config = cl.load_config(argv, aliases=aliases) + c = CSub(config=config) + assert c.adict == {"k1": "v1", "k2": "2", "k3": "v 3"} + assert c.bdict == {"k": "v", "22": "222"} + + def test_mixed_seq_positional(self): + aliases = {"c": "Class.trait"} + cl = self.klass(log=log, aliases=aliases) # type:ignore + assignments = [("-c", "1"), ("--Class.trait=2",), ("--c=3",), ("--Class.trait", "4")] + positionals = ["a", "b", "c"] + # test with positionals at any index + for idx in range(len(assignments) + 1): + argv_parts = assignments[:] + argv_parts[idx:idx] = (positionals,) # type:ignore + argv = list(chain(*argv_parts)) + + config = cl.load_config(argv) + assert config.Class.trait == ["1", "2", "3", "4"] + assert cl.extra_args == ["a", "b", "c"] + + def test_split_positional(self): + """Splitting positionals across flags is no longer allowed in traitlets 5""" + cl = self.klass(log=log) # type:ignore + argv = ["a", "--Class.trait=5", "b"] + with pytest.raises(SystemExit): + cl.load_config(argv) + + +class TestConfig(TestCase): + def test_setget(self): + c = Config() + c.a = 10 + self.assertEqual(c.a, 10) + self.assertEqual("b" in c, False) + + def test_auto_section(self): + c = Config() + self.assertNotIn("A", c) + assert not c._has_section("A") + A = c.A + A.foo = "hi there" + self.assertIn("A", c) + assert c._has_section("A") + self.assertEqual(c.A.foo, "hi there") + del c.A + self.assertEqual(c.A, Config()) + + def test_merge_doesnt_exist(self): + c1 = Config() + c2 = Config() + c2.bar = 10 + c2.Foo.bar = 10 + c1.merge(c2) + self.assertEqual(c1.Foo.bar, 10) + self.assertEqual(c1.bar, 10) + c2.Bar.bar = 10 + c1.merge(c2) + self.assertEqual(c1.Bar.bar, 10) + + def test_merge_exists(self): + c1 = Config() + c2 = Config() + c1.Foo.bar = 10 + c1.Foo.bam = 30 + c2.Foo.bar = 20 + c2.Foo.wow = 40 + c1.merge(c2) + self.assertEqual(c1.Foo.bam, 30) + self.assertEqual(c1.Foo.bar, 20) + self.assertEqual(c1.Foo.wow, 40) + c2.Foo.Bam.bam = 10 + c1.merge(c2) + self.assertEqual(c1.Foo.Bam.bam, 10) + + def test_deepcopy(self): + c1 = Config() + c1.Foo.bar = 10 + c1.Foo.bam = 30 + c1.a = "asdf" + c1.b = range(10) + c1.Test.logger = logging.Logger("test") + c1.Test.get_logger = logging.getLogger("test") + c2 = copy.deepcopy(c1) + self.assertEqual(c1, c2) + self.assertTrue(c1 is not c2) + self.assertTrue(c1.Foo is not c2.Foo) + self.assertTrue(c1.Test is not c2.Test) + self.assertTrue(c1.Test.logger is c2.Test.logger) + self.assertTrue(c1.Test.get_logger is c2.Test.get_logger) + + def test_builtin(self): + c1 = Config() + c1.format = "json" + + def test_fromdict(self): + c1 = Config({"Foo": {"bar": 1}}) + self.assertEqual(c1.Foo.__class__, Config) + self.assertEqual(c1.Foo.bar, 1) + + def test_fromdictmerge(self): + c1 = Config() + c2 = Config({"Foo": {"bar": 1}}) + c1.merge(c2) + self.assertEqual(c1.Foo.__class__, Config) + self.assertEqual(c1.Foo.bar, 1) + + def test_fromdictmerge2(self): + c1 = Config({"Foo": {"baz": 2}}) + c2 = Config({"Foo": {"bar": 1}}) + c1.merge(c2) + self.assertEqual(c1.Foo.__class__, Config) + self.assertEqual(c1.Foo.bar, 1) + self.assertEqual(c1.Foo.baz, 2) + self.assertNotIn("baz", c2.Foo) + + def test_contains(self): + c1 = Config({"Foo": {"baz": 2}}) + c2 = Config({"Foo": {"bar": 1}}) + self.assertIn("Foo", c1) + self.assertIn("Foo.baz", c1) + self.assertIn("Foo.bar", c2) + self.assertNotIn("Foo.bar", c1) + + def test_pickle_config(self): + cfg = Config() + cfg.Foo.bar = 1 + pcfg = pickle.dumps(cfg) + cfg2 = pickle.loads(pcfg) + self.assertEqual(cfg2, cfg) + + def test_getattr_section(self): + cfg = Config() + self.assertNotIn("Foo", cfg) + Foo = cfg.Foo + assert isinstance(Foo, Config) + self.assertIn("Foo", cfg) + + def test_getitem_section(self): + cfg = Config() + self.assertNotIn("Foo", cfg) + Foo = cfg["Foo"] + assert isinstance(Foo, Config) + self.assertIn("Foo", cfg) + + def test_getattr_not_section(self): + cfg = Config() + self.assertNotIn("foo", cfg) + foo = cfg.foo + assert isinstance(foo, LazyConfigValue) + self.assertIn("foo", cfg) + + def test_getattr_private_missing(self): + cfg = Config() + self.assertNotIn("_repr_html_", cfg) + with self.assertRaises(AttributeError): + _ = cfg._repr_html_ + self.assertNotIn("_repr_html_", cfg) + self.assertEqual(len(cfg), 0) + + def test_lazy_config_repr(self): + cfg = Config() + cfg.Class.lazy.append(1) + cfg_repr = repr(cfg) + assert "<LazyConfigValue" in cfg_repr + assert "extend" in cfg_repr + assert " [1]}>" in cfg_repr + assert "value=" not in cfg_repr + cfg.Class.lazy.get_value([0]) + repr2 = repr(cfg) + assert repr([0, 1]) in repr2 + assert "value=" in repr2 + + def test_getitem_not_section(self): + cfg = Config() + self.assertNotIn("foo", cfg) + foo = cfg["foo"] + assert isinstance(foo, LazyConfigValue) + self.assertIn("foo", cfg) + + def test_merge_no_copies(self): + c = Config() + c2 = Config() + c2.Foo.trait = [] + c.merge(c2) + c2.Foo.trait.append(1) + self.assertIs(c.Foo, c2.Foo) + self.assertEqual(c.Foo.trait, [1]) + self.assertEqual(c2.Foo.trait, [1]) + + def test_merge_multi_lazy(self): + """ + With multiple config files (systemwide and users), we want compounding. + + If systemwide overwirte and user append, we want both in the right + order. + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait = [1] + c2.Foo.trait.append(2) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait, [1, 2]) + + def test_merge_multi_lazyII(self): + """ + With multiple config files (systemwide and users), we want compounding. + + If both are lazy we still want a lazy config. + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait.append(1) + c2.Foo.trait.append(2) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait._extend, [1, 2]) + + def test_merge_multi_lazy_III(self): + """ + With multiple config files (systemwide and users), we want compounding. + + Prepend should prepend in the right order. + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait = [1] + c2.Foo.trait.prepend([0]) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait, [0, 1]) + + def test_merge_multi_lazy_IV(self): + """ + With multiple config files (systemwide and users), we want compounding. + + Both prepending should be lazy + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait.prepend([1]) + c2.Foo.trait.prepend([0]) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait._prepend, [0, 1]) + + def test_merge_multi_lazy_update_I(self): + """ + With multiple config files (systemwide and users), we want compounding. + + dict update shoudl be in the right order. + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait = {"a": 1, "z": 26} + c2.Foo.trait.update({"a": 0, "b": 1}) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait, {"a": 0, "b": 1, "z": 26}) + + def test_merge_multi_lazy_update_II(self): + """ + With multiple config files (systemwide and users), we want compounding. + + Later dict overwrite lazyness + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait.update({"a": 0, "b": 1}) + c2.Foo.trait = {"a": 1, "z": 26} + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait, {"a": 1, "z": 26}) + + def test_merge_multi_lazy_update_III(self): + """ + With multiple config files (systemwide and users), we want compounding. + + Later dict overwrite lazyness + """ + c1 = Config() + c2 = Config() + + c1.Foo.trait.update({"a": 0, "b": 1}) + c2.Foo.trait.update({"a": 1, "z": 26}) + + c = Config() + c.merge(c1) + c.merge(c2) + + self.assertEqual(c.Foo.trait._update, {"a": 1, "z": 26, "b": 1}) diff --git a/contrib/python/traitlets/py3/traitlets/log.py b/contrib/python/traitlets/py3/traitlets/log.py new file mode 100644 index 0000000000..016529fcac --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/log.py @@ -0,0 +1,29 @@ +"""Grab the global logger instance.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + +import logging + +_logger = None + + +def get_logger(): + """Grab the global logger instance. + + If a global Application is instantiated, grab its logger. + Otherwise, grab the root logger. + """ + global _logger + + if _logger is None: + from .config import Application + + if Application.initialized(): + _logger = Application.instance().log + else: + _logger = logging.getLogger("traitlets") + # Add a NullHandler to silence warnings about not being + # initialized, per best practice for libraries. + _logger.addHandler(logging.NullHandler()) + return _logger diff --git a/contrib/python/traitlets/py3/traitlets/py.typed b/contrib/python/traitlets/py3/traitlets/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/py.typed diff --git a/contrib/python/traitlets/py3/traitlets/tests/__init__.py b/contrib/python/traitlets/py3/traitlets/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/tests/__init__.py diff --git a/contrib/python/traitlets/py3/traitlets/tests/_warnings.py b/contrib/python/traitlets/py3/traitlets/tests/_warnings.py new file mode 100644 index 0000000000..e3c3a0ac6d --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/tests/_warnings.py @@ -0,0 +1,114 @@ +# From scikit-image: https://github.com/scikit-image/scikit-image/blob/c2f8c4ab123ebe5f7b827bc495625a32bb225c10/skimage/_shared/_warnings.py +# Licensed under modified BSD license + +__all__ = ["all_warnings", "expected_warnings"] + +import inspect +import os +import re +import sys +import warnings +from contextlib import contextmanager +from unittest import mock + + +@contextmanager +def all_warnings(): + """ + Context for use in testing to ensure that all warnings are raised. + Examples + -------- + >>> import warnings + >>> def foo(): + ... warnings.warn(RuntimeWarning("bar")) + + We raise the warning once, while the warning filter is set to "once". + Hereafter, the warning is invisible, even with custom filters: + >>> with warnings.catch_warnings(): + ... warnings.simplefilter('once') + ... foo() + + We can now run ``foo()`` without a warning being raised: + >>> from numpy.testing import assert_warns # doctest: +SKIP + >>> foo() # doctest: +SKIP + + To catch the warning, we call in the help of ``all_warnings``: + >>> with all_warnings(): # doctest: +SKIP + ... assert_warns(RuntimeWarning, foo) + """ + + # Whenever a warning is triggered, Python adds a __warningregistry__ + # member to the *calling* module. The exercize here is to find + # and eradicate all those breadcrumbs that were left lying around. + # + # We proceed by first searching all parent calling frames and explicitly + # clearing their warning registries (necessary for the doctests above to + # pass). Then, we search for all submodules of skimage and clear theirs + # as well (necessary for the skimage test suite to pass). + + frame = inspect.currentframe() + if frame: + for f in inspect.getouterframes(frame): + f[0].f_locals["__warningregistry__"] = {} + del frame + + for _, mod in list(sys.modules.items()): + try: + mod.__warningregistry__.clear() + except AttributeError: + pass + + with warnings.catch_warnings(record=True) as w, mock.patch.dict( + os.environ, {"TRAITLETS_ALL_DEPRECATIONS": "1"} + ): + warnings.simplefilter("always") + yield w + + +@contextmanager +def expected_warnings(matching): + r"""Context for use in testing to catch known warnings matching regexes + + Parameters + ---------- + matching : list of strings or compiled regexes + Regexes for the desired warning to catch + + Examples + -------- + >>> from skimage import data, img_as_ubyte, img_as_float # doctest: +SKIP + >>> with expected_warnings(["precision loss"]): # doctest: +SKIP + ... d = img_as_ubyte(img_as_float(data.coins())) # doctest: +SKIP + + Notes + ----- + Uses `all_warnings` to ensure all warnings are raised. + Upon exiting, it checks the recorded warnings for the desired matching + pattern(s). + Raises a ValueError if any match was not found or an unexpected + warning was raised. + Allows for three types of behaviors: "and", "or", and "optional" matches. + This is done to accomodate different build enviroments or loop conditions + that may produce different warnings. The behaviors can be combined. + If you pass multiple patterns, you get an orderless "and", where all of the + warnings must be raised. + If you use the "|" operator in a pattern, you can catch one of several warnings. + Finally, you can use "|\A\Z" in a pattern to signify it as optional. + """ + with all_warnings() as w: + # enter context + yield w + # exited user context, check the recorded warnings + remaining = [m for m in matching if r"\A\Z" not in m.split("|")] + for warn in w: + found = False + for match in matching: + if re.search(match, str(warn.message)) is not None: + found = True + if match in remaining: + remaining.remove(match) + if not found: + raise ValueError("Unexpected warning: %s" % str(warn.message)) + if len(remaining) > 0: + msg = "No warning raised matching:\n%s" % "\n".join(remaining) + raise ValueError(msg) diff --git a/contrib/python/traitlets/py3/traitlets/tests/test_traitlets.py b/contrib/python/traitlets/py3/traitlets/tests/test_traitlets.py new file mode 100644 index 0000000000..c99e9d2341 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/tests/test_traitlets.py @@ -0,0 +1,3203 @@ +"""Tests for traitlets.traitlets.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +# +# Adapted from enthought.traits, Copyright (c) Enthought, Inc., +# also under the terms of the Modified BSD License. + +import pickle +import re +import typing as t +from unittest import TestCase + +import pytest + +from traitlets import ( + All, + Any, + BaseDescriptor, + Bool, + Bytes, + Callable, + CBytes, + CFloat, + CInt, + CLong, + Complex, + CRegExp, + CUnicode, + Dict, + DottedObjectName, + Enum, + Float, + ForwardDeclaredInstance, + ForwardDeclaredType, + HasDescriptors, + HasTraits, + Instance, + Int, + Integer, + List, + Long, + MetaHasTraits, + ObjectName, + Set, + TCPAddress, + This, + TraitError, + TraitType, + Tuple, + Type, + Undefined, + Unicode, + Union, + default, + directional_link, + link, + observe, + observe_compat, + traitlets, + validate, +) +from traitlets.utils import cast_unicode + +from traitlets.tests._warnings import expected_warnings + + +def change_dict(*ordered_values): + change_names = ("name", "old", "new", "owner", "type") + return dict(zip(change_names, ordered_values)) + + +# ----------------------------------------------------------------------------- +# Helper classes for testing +# ----------------------------------------------------------------------------- + + +class HasTraitsStub(HasTraits): + def notify_change(self, change): + self._notify_name = change["name"] + self._notify_old = change["old"] + self._notify_new = change["new"] + self._notify_type = change["type"] + + +class CrossValidationStub(HasTraits): + _cross_validation_lock = False + + +# ----------------------------------------------------------------------------- +# Test classes +# ----------------------------------------------------------------------------- + + +class TestTraitType(TestCase): + def test_get_undefined(self): + class A(HasTraits): + a = TraitType + + a = A() + assert a.a is Undefined # type:ignore + + def test_set(self): + class A(HasTraitsStub): + a = TraitType + + a = A() + a.a = 10 # type:ignore + self.assertEqual(a.a, 10) + self.assertEqual(a._notify_name, "a") + self.assertEqual(a._notify_old, Undefined) + self.assertEqual(a._notify_new, 10) + + def test_validate(self): + class MyTT(TraitType): + def validate(self, inst, value): + return -1 + + class A(HasTraitsStub): + tt = MyTT + + a = A() + a.tt = 10 # type:ignore + self.assertEqual(a.tt, -1) + + a = A(tt=11) + self.assertEqual(a.tt, -1) + + def test_default_validate(self): + class MyIntTT(TraitType): + def validate(self, obj, value): + if isinstance(value, int): + return value + self.error(obj, value) + + class A(HasTraits): + tt = MyIntTT(10) + + a = A() + self.assertEqual(a.tt, 10) + + # Defaults are validated when the HasTraits is instantiated + class B(HasTraits): + tt = MyIntTT("bad default") + + self.assertRaises(TraitError, getattr, B(), "tt") + + def test_info(self): + class A(HasTraits): + tt = TraitType + + a = A() + self.assertEqual(A.tt.info(), "any value") # type:ignore + + def test_error(self): + class A(HasTraits): + tt = TraitType() + + a = A() + self.assertRaises(TraitError, A.tt.error, a, 10) + + def test_deprecated_dynamic_initializer(self): + class A(HasTraits): + x = Int(10) + + def _x_default(self): + return 11 + + class B(A): + x = Int(20) + + class C(A): + def _x_default(self): + return 21 + + a = A() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + b = B() + self.assertEqual(b.x, 20) + self.assertEqual(b._trait_values, {"x": 20}) + c = C() + self.assertEqual(c._trait_values, {}) + self.assertEqual(c.x, 21) + self.assertEqual(c._trait_values, {"x": 21}) + # Ensure that the base class remains unmolested when the _default + # initializer gets overridden in a subclass. + a = A() + c = C() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + + def test_deprecated_method_warnings(self): + + with expected_warnings([]): + + class ShouldntWarn(HasTraits): + x = Integer() + + @default("x") + def _x_default(self): + return 10 + + @validate("x") + def _x_validate(self, proposal): + return proposal.value + + @observe("x") + def _x_changed(self, change): + pass + + obj = ShouldntWarn() + obj.x = 5 + + assert obj.x == 5 + + with expected_warnings(["@validate", "@observe"]) as w: + + class ShouldWarn(HasTraits): + x = Integer() + + def _x_default(self): + return 10 + + def _x_validate(self, value, _): + return value + + def _x_changed(self): + pass + + obj = ShouldWarn() # type:ignore + obj.x = 5 + + assert obj.x == 5 + + def test_dynamic_initializer(self): + class A(HasTraits): + x = Int(10) + + @default("x") + def _default_x(self): + return 11 + + class B(A): + x = Int(20) + + class C(A): + @default("x") + def _default_x(self): + return 21 + + a = A() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + b = B() + self.assertEqual(b.x, 20) + self.assertEqual(b._trait_values, {"x": 20}) + c = C() + self.assertEqual(c._trait_values, {}) + self.assertEqual(c.x, 21) + self.assertEqual(c._trait_values, {"x": 21}) + # Ensure that the base class remains unmolested when the _default + # initializer gets overridden in a subclass. + a = A() + c = C() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + + def test_tag_metadata(self): + class MyIntTT(TraitType): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10).tag(b=3, c=4) + self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4}) + + def test_metadata_localized_instance(self): + class MyIntTT(TraitType): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10) + b = MyIntTT(10) + a.metadata["c"] = 3 + # make sure that changing a's metadata didn't change b's metadata + self.assertNotIn("c", b.metadata) + + def test_union_metadata(self): + class Foo(HasTraits): + bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti="b")).tag(ti="a") + + foo = Foo() + # At this point, no value has been set for bar, so value-specific + # is not set. + self.assertEqual(foo.trait_metadata("bar", "ta"), None) + self.assertEqual(foo.trait_metadata("bar", "ti"), "a") + foo.bar = {} + self.assertEqual(foo.trait_metadata("bar", "ta"), 2) + self.assertEqual(foo.trait_metadata("bar", "ti"), "b") + foo.bar = 1 + self.assertEqual(foo.trait_metadata("bar", "ta"), 1) + self.assertEqual(foo.trait_metadata("bar", "ti"), "a") + + def test_union_default_value(self): + class Foo(HasTraits): + bar = Union([Dict(), Int()], default_value=1) + + foo = Foo() + self.assertEqual(foo.bar, 1) + + def test_union_validation_priority(self): + class Foo(HasTraits): + bar = Union([CInt(), Unicode()]) + + foo = Foo() + foo.bar = "1" + # validation in order of the TraitTypes given + self.assertEqual(foo.bar, 1) + + def test_union_trait_default_value(self): + class Foo(HasTraits): + bar = Union([Dict(), Int()]) + + self.assertEqual(Foo().bar, {}) + + def test_deprecated_metadata_access(self): + class MyIntTT(TraitType): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10) + with expected_warnings(["use the instance .metadata dictionary directly"] * 2): + a.set_metadata("key", "value") + v = a.get_metadata("key") + self.assertEqual(v, "value") + with expected_warnings(["use the instance .help string directly"] * 2): + a.set_metadata("help", "some help") + v = a.get_metadata("help") + self.assertEqual(v, "some help") + + def test_trait_types_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Int + + def test_trait_types_list_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = List(Int) + + def test_trait_types_tuple_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Tuple(Int) + + def test_trait_types_dict_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Dict(Int) + + +class TestHasDescriptorsMeta(TestCase): + def test_metaclass(self): + self.assertEqual(type(HasTraits), MetaHasTraits) + + class A(HasTraits): + a = Int() + + a = A() + self.assertEqual(type(a.__class__), MetaHasTraits) + self.assertEqual(a.a, 0) + a.a = 10 + self.assertEqual(a.a, 10) + + class B(HasTraits): + b = Int() + + b = B() + self.assertEqual(b.b, 0) + b.b = 10 + self.assertEqual(b.b, 10) + + class C(HasTraits): + c = Int(30) + + c = C() + self.assertEqual(c.c, 30) + c.c = 10 + self.assertEqual(c.c, 10) + + def test_this_class(self): + class A(HasTraits): + t = This() + tt = This() + + class B(A): + tt = This() + ttt = This() + + self.assertEqual(A.t.this_class, A) + self.assertEqual(B.t.this_class, A) + self.assertEqual(B.tt.this_class, B) + self.assertEqual(B.ttt.this_class, B) + + +class TestHasDescriptors(TestCase): + def test_setup_instance(self): + class FooDescriptor(BaseDescriptor): + def instance_init(self, inst): + foo = inst.foo # instance should have the attr + + class HasFooDescriptors(HasDescriptors): + + fd = FooDescriptor() + + def setup_instance(self, *args, **kwargs): + self.foo = kwargs.get("foo", None) + super().setup_instance(*args, **kwargs) + + hfd = HasFooDescriptors(foo="bar") + + +class TestHasTraitsNotify(TestCase): + def setUp(self): + self._notify1 = [] + self._notify2 = [] + + def notify1(self, name, old, new): + self._notify1.append((name, old, new)) + + def notify2(self, name, old, new): + self._notify2.append((name, old, new)) + + def test_notify_all(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.on_trait_change(self.notify1) + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.b = 0.0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in self._notify1) + a.b = 10.0 + self.assertTrue(("b", 0.0, 10.0) in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + self.assertRaises(TraitError, setattr, a, "b", "bad string") + self._notify1 = [] + a.on_trait_change(self.notify1, remove=True) + a.a = 20 + a.b = 20.0 + self.assertEqual(len(self._notify1), 0) + + def test_notify_one(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.on_trait_change(self.notify1, "a") + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + + def test_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + self.assertEqual(b.a, 0) + self.assertEqual(b.b, 0.0) + b.a = 100 + b.b = 100.0 + self.assertEqual(b.a, 100) + self.assertEqual(b.b, 100.0) + + def test_notify_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + b.on_trait_change(self.notify1, "a") + b.on_trait_change(self.notify2, "b") + b.a = 0 + b.b = 0.0 + self.assertEqual(len(self._notify1), 0) + self.assertEqual(len(self._notify2), 0) + b.a = 10 + b.b = 10.0 + self.assertTrue(("a", 0, 10) in self._notify1) + self.assertTrue(("b", 0.0, 10.0) in self._notify2) + + def test_static_notify(self): + class A(HasTraits): + a = Int() + _notify1 = [] + + def _a_changed(self, name, old, new): + self._notify1.append((name, old, new)) + + a = A() + a.a = 0 + # This is broken!!! + self.assertEqual(len(a._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in a._notify1) + + class B(A): + b = Float() + _notify2 = [] + + def _b_changed(self, name, old, new): + self._notify2.append((name, old, new)) + + b = B() + b.a = 10 + b.b = 10.0 + self.assertTrue(("a", 0, 10) in b._notify1) + self.assertTrue(("b", 0.0, 10.0) in b._notify2) + + def test_notify_args(self): + def callback0(): + self.cb = () + + def callback1(name): + self.cb = (name,) # type:ignore + + def callback2(name, new): + self.cb = (name, new) # type:ignore + + def callback3(name, old, new): + self.cb = (name, old, new) # type:ignore + + def callback4(name, old, new, obj): + self.cb = (name, old, new, obj) # type:ignore + + class A(HasTraits): + a = Int() + + a = A() + a.on_trait_change(callback0, "a") + a.a = 10 + self.assertEqual(self.cb, ()) + a.on_trait_change(callback0, "a", remove=True) + + a.on_trait_change(callback1, "a") + a.a = 100 + self.assertEqual(self.cb, ("a",)) + a.on_trait_change(callback1, "a", remove=True) + + a.on_trait_change(callback2, "a") + a.a = 1000 + self.assertEqual(self.cb, ("a", 1000)) + a.on_trait_change(callback2, "a", remove=True) + + a.on_trait_change(callback3, "a") + a.a = 10000 + self.assertEqual(self.cb, ("a", 1000, 10000)) + a.on_trait_change(callback3, "a", remove=True) + + a.on_trait_change(callback4, "a") + a.a = 100000 + self.assertEqual(self.cb, ("a", 10000, 100000, a)) + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) + a.on_trait_change(callback4, "a", remove=True) + + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) + + def test_notify_only_once(self): + class A(HasTraits): + listen_to = ["a"] + + a = Int(0) + b = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.on_trait_change(self.listener1, ["a"]) + + def listener1(self, name, old, new): + self.b += 1 + + class B(A): + + c = 0 + d = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.on_trait_change(self.listener2) + + def listener2(self, name, old, new): + self.c += 1 + + def _a_changed(self, name, old, new): + self.d += 1 + + b = B() + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + + +class TestObserveDecorator(TestCase): + def setUp(self): + self._notify1 = [] + self._notify2 = [] + + def notify1(self, change): + self._notify1.append(change) + + def notify2(self, change): + self._notify2.append(change) + + def test_notify_all(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.observe(self.notify1) + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.b = 0.0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in self._notify1) + a.b = 10.0 + change = change_dict("b", 0.0, 10.0, a, "change") + self.assertTrue(change in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + self.assertRaises(TraitError, setattr, a, "b", "bad string") + self._notify1 = [] + a.unobserve(self.notify1) + a.a = 20 + a.b = 20.0 + self.assertEqual(len(self._notify1), 0) + + def test_notify_one(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.observe(self.notify1, "a") + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + + def test_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + self.assertEqual(b.a, 0) + self.assertEqual(b.b, 0.0) + b.a = 100 + b.b = 100.0 + self.assertEqual(b.a, 100) + self.assertEqual(b.b, 100.0) + + def test_notify_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + b.observe(self.notify1, "a") + b.observe(self.notify2, "b") + b.a = 0 + b.b = 0.0 + self.assertEqual(len(self._notify1), 0) + self.assertEqual(len(self._notify2), 0) + b.a = 10 + b.b = 10.0 + change = change_dict("a", 0, 10, b, "change") + self.assertTrue(change in self._notify1) + change = change_dict("b", 0.0, 10.0, b, "change") + self.assertTrue(change in self._notify2) + + def test_static_notify(self): + class A(HasTraits): + a = Int() + b = Int() + _notify1 = [] + _notify_any = [] + + @observe("a") + def _a_changed(self, change): + self._notify1.append(change) + + @observe(All) + def _any_changed(self, change): + self._notify_any.append(change) + + a = A() + a.a = 0 + self.assertEqual(len(a._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in a._notify1) + a.b = 1 + self.assertEqual(len(a._notify_any), 2) + change = change_dict("b", 0, 1, a, "change") + self.assertTrue(change in a._notify_any) + + class B(A): + b = Float() # type:ignore + _notify2 = [] + + @observe("b") + def _b_changed(self, change): + self._notify2.append(change) + + b = B() + b.a = 10 + b.b = 10.0 # type:ignore + change = change_dict("a", 0, 10, b, "change") + self.assertTrue(change in b._notify1) + change = change_dict("b", 0.0, 10.0, b, "change") + self.assertTrue(change in b._notify2) + + def test_notify_args(self): + def callback0(): + self.cb = () + + def callback1(change): + self.cb = change + + class A(HasTraits): + a = Int() + + a = A() + a.on_trait_change(callback0, "a") + a.a = 10 + self.assertEqual(self.cb, ()) + a.unobserve(callback0, "a") + + a.observe(callback1, "a") + a.a = 100 + change = change_dict("a", 10, 100, a, "change") + self.assertEqual(self.cb, change) + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) + a.unobserve(callback1, "a") + + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) + + def test_notify_only_once(self): + class A(HasTraits): + listen_to = ["a"] + + a = Int(0) + b = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.observe(self.listener1, ["a"]) + + def listener1(self, change): + self.b += 1 + + class B(A): + + c = 0 + d = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.observe(self.listener2) + + def listener2(self, change): + self.c += 1 + + @observe("a") + def _a_changed(self, change): + self.d += 1 + + b = B() + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + + +class TestHasTraits(TestCase): + def test_trait_names(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertEqual(sorted(a.trait_names()), ["f", "i"]) + self.assertEqual(sorted(A.class_trait_names()), ["f", "i"]) + self.assertTrue(a.has_trait("f")) + self.assertFalse(a.has_trait("g")) + + def test_trait_has_value(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertFalse(a.trait_has_value("f")) + self.assertFalse(a.trait_has_value("g")) + a.i = 1 + a.f + self.assertTrue(a.trait_has_value("i")) + self.assertTrue(a.trait_has_value("f")) + + def test_trait_metadata_deprecated(self): + with expected_warnings([r"metadata should be set using the \.tag\(\) method"]): + + class A(HasTraits): + i = Int(config_key="MY_VALUE") + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") + + def test_trait_metadata(self): + class A(HasTraits): + i = Int().tag(config_key="MY_VALUE") + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") + + def test_trait_metadata_default(self): + class A(HasTraits): + i = Int() + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), None) + self.assertEqual(a.trait_metadata("i", "config_key", "default"), "default") + + def test_traits(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f)) + self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f)) + + def test_traits_metadata(self): + class A(HasTraits): + i = Int().tag(config_key="VALUE1", other_thing="VALUE2") + f = Float().tag(config_key="VALUE3", other_thing="VALUE2") + j = Int(0) + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) + traits = a.traits(config_key="VALUE1", other_thing="VALUE2") + self.assertEqual(traits, dict(i=A.i)) + + # This passes, but it shouldn't because I am replicating a bug in + # traits. + traits = a.traits(config_key=lambda v: True) + self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) + + def test_traits_metadata_deprecated(self): + with expected_warnings([r"metadata should be set using the \.tag\(\) method"] * 2): + + class A(HasTraits): + i = Int(config_key="VALUE1", other_thing="VALUE2") + f = Float(config_key="VALUE3", other_thing="VALUE2") + j = Int(0) + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) + traits = a.traits(config_key="VALUE1", other_thing="VALUE2") + self.assertEqual(traits, dict(i=A.i)) + + # This passes, but it shouldn't because I am replicating a bug in + # traits. + traits = a.traits(config_key=lambda v: True) + self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) + + def test_init(self): + class A(HasTraits): + i = Int() + x = Float() + + a = A(i=1, x=10.0) + self.assertEqual(a.i, 1) + self.assertEqual(a.x, 10.0) + + def test_positional_args(self): + class A(HasTraits): + i = Int(0) + + def __init__(self, i): + super().__init__() + self.i = i + + a = A(5) + self.assertEqual(a.i, 5) + # should raise TypeError if no positional arg given + self.assertRaises(TypeError, A) + + +# ----------------------------------------------------------------------------- +# Tests for specific trait types +# ----------------------------------------------------------------------------- + + +class TestType(TestCase): + def test_default(self): + class B: + pass + + class A(HasTraits): + klass = Type(allow_none=True) + + a = A() + self.assertEqual(a.klass, object) + + a.klass = B + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", 10) + + def test_default_options(self): + class B: + pass + + class C(B): + pass + + class A(HasTraits): + # Different possible combinations of options for default_value + # and klass. default_value=None is only valid with allow_none=True. + k1 = Type() + k2 = Type(None, allow_none=True) + k3 = Type(B) + k4 = Type(klass=B) + k5 = Type(default_value=None, klass=B, allow_none=True) + k6 = Type(default_value=C, klass=B) + + self.assertIs(A.k1.default_value, object) + self.assertIs(A.k1.klass, object) + self.assertIs(A.k2.default_value, None) + self.assertIs(A.k2.klass, object) + self.assertIs(A.k3.default_value, B) + self.assertIs(A.k3.klass, B) + self.assertIs(A.k4.default_value, B) + self.assertIs(A.k4.klass, B) + self.assertIs(A.k5.default_value, None) + self.assertIs(A.k5.klass, B) + self.assertIs(A.k6.default_value, C) + self.assertIs(A.k6.klass, B) + + a = A() + self.assertIs(a.k1, object) + self.assertIs(a.k2, None) + self.assertIs(a.k3, B) + self.assertIs(a.k4, B) + self.assertIs(a.k5, None) + self.assertIs(a.k6, C) + + def test_value(self): + class B: + pass + + class C: + pass + + class A(HasTraits): + klass = Type(B) + + a = A() + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", C) + self.assertRaises(TraitError, setattr, a, "klass", object) + a.klass = B + + def test_allow_none(self): + class B: + pass + + class C(B): + pass + + class A(HasTraits): + klass = Type(B) + + a = A() + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", None) + a.klass = C + self.assertEqual(a.klass, C) + + def test_validate_klass(self): + class A(HasTraits): + klass = Type("no strings allowed") + + self.assertRaises(ImportError, A) + + class A(HasTraits): # type:ignore + klass = Type("rub.adub.Duck") + + self.assertRaises(ImportError, A) + + def test_validate_default(self): + class B: + pass + + class A(HasTraits): + klass = Type("bad default", B) + + self.assertRaises(ImportError, A) + + class C(HasTraits): + klass = Type(None, B) + + self.assertRaises(TraitError, getattr, C(), "klass") + + def test_str_klass(self): + class A(HasTraits): + klass = Type("traitlets.config.Config") + + from traitlets.config import Config + + a = A() + a.klass = Config + self.assertEqual(a.klass, Config) + + self.assertRaises(TraitError, setattr, a, "klass", 10) + + def test_set_str_klass(self): + class A(HasTraits): + klass = Type() + + a = A(klass="traitlets.config.Config") + from traitlets.config import Config + + self.assertEqual(a.klass, Config) + + +class TestInstance(TestCase): + def test_basic(self): + class Foo: + pass + + class Bar(Foo): + pass + + class Bah: + pass + + class A(HasTraits): + inst = Instance(Foo, allow_none=True) + + a = A() + self.assertTrue(a.inst is None) + a.inst = Foo() + self.assertTrue(isinstance(a.inst, Foo)) + a.inst = Bar() + self.assertTrue(isinstance(a.inst, Foo)) + self.assertRaises(TraitError, setattr, a, "inst", Foo) + self.assertRaises(TraitError, setattr, a, "inst", Bar) + self.assertRaises(TraitError, setattr, a, "inst", Bah()) + + def test_default_klass(self): + class Foo: + pass + + class Bar(Foo): + pass + + class Bah: + pass + + class FooInstance(Instance): + klass = Foo + + class A(HasTraits): + inst = FooInstance(allow_none=True) + + a = A() + self.assertTrue(a.inst is None) + a.inst = Foo() + self.assertTrue(isinstance(a.inst, Foo)) + a.inst = Bar() + self.assertTrue(isinstance(a.inst, Foo)) + self.assertRaises(TraitError, setattr, a, "inst", Foo) + self.assertRaises(TraitError, setattr, a, "inst", Bar) + self.assertRaises(TraitError, setattr, a, "inst", Bah()) + + def test_unique_default_value(self): + class Foo: + pass + + class A(HasTraits): + inst = Instance(Foo, (), {}) + + a = A() + b = A() + self.assertTrue(a.inst is not b.inst) + + def test_args_kw(self): + class Foo: + def __init__(self, c): + self.c = c + + class Bar: + pass + + class Bah: + def __init__(self, c, d): + self.c = c + self.d = d + + class A(HasTraits): + inst = Instance(Foo, (10,)) + + a = A() + self.assertEqual(a.inst.c, 10) + + class B(HasTraits): + inst = Instance(Bah, args=(10,), kw=dict(d=20)) + + b = B() + self.assertEqual(b.inst.c, 10) + self.assertEqual(b.inst.d, 20) + + class C(HasTraits): + inst = Instance(Foo, allow_none=True) + + c = C() + self.assertTrue(c.inst is None) + + def test_bad_default(self): + class Foo: + pass + + class A(HasTraits): + inst = Instance(Foo) + + a = A() + with self.assertRaises(TraitError): + a.inst + + def test_instance(self): + class Foo: + pass + + def inner(): + class A(HasTraits): + inst = Instance(Foo()) + + self.assertRaises(TraitError, inner) + + +class TestThis(TestCase): + def test_this_class(self): + class Foo(HasTraits): + this = This() + + f = Foo() + self.assertEqual(f.this, None) + g = Foo() + f.this = g + self.assertEqual(f.this, g) + self.assertRaises(TraitError, setattr, f, "this", 10) + + def test_this_inst(self): + class Foo(HasTraits): + this = This() + + f = Foo() + f.this = Foo() + self.assertTrue(isinstance(f.this, Foo)) + + def test_subclass(self): + class Foo(HasTraits): + t = This() + + class Bar(Foo): + pass + + f = Foo() + b = Bar() + f.t = b + b.t = f + self.assertEqual(f.t, b) + self.assertEqual(b.t, f) + + def test_subclass_override(self): + class Foo(HasTraits): + t = This() + + class Bar(Foo): + t = This() + + f = Foo() + b = Bar() + f.t = b + self.assertEqual(f.t, b) + self.assertRaises(TraitError, setattr, b, "t", f) + + def test_this_in_container(self): + class Tree(HasTraits): + value = Unicode() + leaves = List(This()) + + tree = Tree(value="foo", leaves=[Tree(value="bar"), Tree(value="buzz")]) + + with self.assertRaises(TraitError): + tree.leaves = [1, 2] + + +class TraitTestBase(TestCase): + """A best testing class for basic trait types.""" + + def assign(self, value): + self.obj.value = value # type:ignore + + def coerce(self, value): + return value + + def test_good_values(self): + if hasattr(self, "_good_values"): + for value in self._good_values: + self.assign(value) + self.assertEqual(self.obj.value, self.coerce(value)) # type:ignore + + def test_bad_values(self): + if hasattr(self, "_bad_values"): + for value in self._bad_values: + try: + self.assertRaises(TraitError, self.assign, value) + except AssertionError: + assert False, value + + def test_default_value(self): + if hasattr(self, "_default_value"): + self.assertEqual(self._default_value, self.obj.value) # type:ignore + + def test_allow_none(self): + if ( + hasattr(self, "_bad_values") + and hasattr(self, "_good_values") + and None in self._bad_values + ): + trait = self.obj.traits()["value"] # type:ignore + try: + trait.allow_none = True + self._bad_values.remove(None) + # skip coerce. Allow None casts None to None. + self.assign(None) + self.assertEqual(self.obj.value, None) # type:ignore + self.test_good_values() + self.test_bad_values() + finally: + # tear down + trait.allow_none = False + self._bad_values.append(None) + + def tearDown(self): + # restore default value after tests, if set + if hasattr(self, "_default_value"): + self.obj.value = self._default_value # type:ignore + + +class AnyTrait(HasTraits): + + value = Any() + + +class AnyTraitTest(TraitTestBase): + + obj = AnyTrait() + + _default_value = None + _good_values = [10.0, "ten", [10], {"ten": 10}, (10,), None, 1j] + _bad_values: t.Any = [] + + +class UnionTrait(HasTraits): + + value = Union([Type(), Bool()]) + + +class UnionTraitTest(TraitTestBase): + + obj = UnionTrait(value="traitlets.config.Config") + _good_values = [int, float, True] + _bad_values = [[], (0,), 1j] + + +class CallableTrait(HasTraits): + + value = Callable() + + +class CallableTraitTest(TraitTestBase): + + obj = CallableTrait(value=lambda x: type(x)) + _good_values = [int, sorted, lambda x: print(x)] + _bad_values = [[], 1, ""] + + +class OrTrait(HasTraits): + + value = Bool() | Unicode() + + +class OrTraitTest(TraitTestBase): + + obj = OrTrait() + _good_values = [True, False, "ten"] + _bad_values = [[], (0,), 1j] + + +class IntTrait(HasTraits): + + value = Int(99, min=-100) + + +class TestInt(TraitTestBase): + + obj = IntTrait() + _default_value = 99 + _good_values = [10, -10] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + 10.1, + -10.1, + "10L", + "-10L", + "10.1", + "-10.1", + "10", + "-10", + -200, + ] + + +class CIntTrait(HasTraits): + value = CInt("5") + + +class TestCInt(TraitTestBase): + obj = CIntTrait() + + _default_value = 5 + _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] + + def coerce(self, n): + return int(n) + + +class MinBoundCIntTrait(HasTraits): + value = CInt("5", min=3) + + +class TestMinBoundCInt(TestCInt): + obj = MinBoundCIntTrait() # type:ignore + + _default_value = 5 + _good_values = [3, 3.0, "3"] + _bad_values = [2.6, 2, -3, -3.0] + + +class LongTrait(HasTraits): + + value = Long(99) + + +class TestLong(TraitTestBase): + + obj = LongTrait() + + _default_value = 99 + _good_values = [10, -10] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + 10.1, + -10.1, + "10", + "-10", + "10L", + "-10L", + "10.1", + "-10.1", + ] + + +class MinBoundLongTrait(HasTraits): + value = Long(99, min=5) + + +class TestMinBoundLong(TraitTestBase): + obj = MinBoundLongTrait() + + _default_value = 99 + _good_values = [5, 10] + _bad_values = [4, -10] + + +class MaxBoundLongTrait(HasTraits): + value = Long(5, max=10) + + +class TestMaxBoundLong(TraitTestBase): + obj = MaxBoundLongTrait() + + _default_value = 5 + _good_values = [10, -2] + _bad_values = [11, 20] + + +class CLongTrait(HasTraits): + value = CLong("5") + + +class TestCLong(TraitTestBase): + obj = CLongTrait() + + _default_value = 5 + _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] + + def coerce(self, n): + return int(n) + + +class MaxBoundCLongTrait(HasTraits): + value = CLong("5", max=10) + + +class TestMaxBoundCLong(TestCLong): + obj = MaxBoundCLongTrait() # type:ignore + + _default_value = 5 + _good_values = [10, "10", 10.3] + _bad_values = [11.0, "11"] + + +class IntegerTrait(HasTraits): + value = Integer(1) + + +class TestInteger(TestLong): + obj = IntegerTrait() # type:ignore + _default_value = 1 + + def coerce(self, n): + return int(n) + + +class MinBoundIntegerTrait(HasTraits): + value = Integer(5, min=3) + + +class TestMinBoundInteger(TraitTestBase): + obj = MinBoundIntegerTrait() + + _default_value = 5 + _good_values = 3, 20 + _bad_values = [2, -10] + + +class MaxBoundIntegerTrait(HasTraits): + value = Integer(1, max=3) + + +class TestMaxBoundInteger(TraitTestBase): + obj = MaxBoundIntegerTrait() + + _default_value = 1 + _good_values = 3, -2 + _bad_values = [4, 10] + + +class FloatTrait(HasTraits): + + value = Float(99.0, max=200.0) + + +class TestFloat(TraitTestBase): + + obj = FloatTrait() + + _default_value = 99.0 + _good_values = [10, -10, 10.1, -10.1] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + "10", + "-10", + "10L", + "-10L", + "10.1", + "-10.1", + 201.0, + ] + + +class CFloatTrait(HasTraits): + + value = CFloat("99.0", max=200.0) + + +class TestCFloat(TraitTestBase): + + obj = CFloatTrait() + + _default_value = 99.0 + _good_values = [10, 10.0, 10.5, "10.0", "10", "-10"] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, 200.1, "200.1"] + + def coerce(self, v): + return float(v) + + +class ComplexTrait(HasTraits): + + value = Complex(99.0 - 99.0j) + + +class TestComplex(TraitTestBase): + + obj = ComplexTrait() + + _default_value = 99.0 - 99.0j + _good_values = [ + 10, + -10, + 10.1, + -10.1, + 10j, + 10 + 10j, + 10 - 10j, + 10.1j, + 10.1 + 10.1j, + 10.1 - 10.1j, + ] + _bad_values = ["10L", "-10L", "ten", [10], {"ten": 10}, (10,), None] + + +class BytesTrait(HasTraits): + + value = Bytes(b"string") + + +class TestBytes(TraitTestBase): + + obj = BytesTrait() + + _default_value = b"string" + _good_values = [b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1", b"string"] + _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None, "string"] + + +class UnicodeTrait(HasTraits): + + value = Unicode("unicode") + + +class TestUnicode(TraitTestBase): + + obj = UnicodeTrait() + + _default_value = "unicode" + _good_values = ["10", "-10", "10L", "-10L", "10.1", "-10.1", "", "string", "€", b"bytestring"] + _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None] + + def coerce(self, v): + return cast_unicode(v) + + +class ObjectNameTrait(HasTraits): + value = ObjectName("abc") + + +class TestObjectName(TraitTestBase): + obj = ObjectNameTrait() + + _default_value = "abc" + _good_values = ["a", "gh", "g9", "g_", "_G", "a345_"] + _bad_values = [ + 1, + "", + "€", + "9g", + "!", + "#abc", + "aj@", + "a.b", + "a()", + "a[0]", + None, + object(), + object, + ] + _good_values.append("þ") # þ=1 is valid in Python 3 (PEP 3131). + + +class DottedObjectNameTrait(HasTraits): + value = DottedObjectName("a.b") + + +class TestDottedObjectName(TraitTestBase): + obj = DottedObjectNameTrait() + + _default_value = "a.b" + _good_values = ["A", "y.t", "y765.__repr__", "os.path.join"] + _bad_values = [1, "abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None] + + _good_values.append("t.þ") + + +class TCPAddressTrait(HasTraits): + value = TCPAddress() + + +class TestTCPAddress(TraitTestBase): + + obj = TCPAddressTrait() + + _default_value = ("127.0.0.1", 0) + _good_values = [("localhost", 0), ("192.168.0.1", 1000), ("www.google.com", 80)] + _bad_values = [(0, 0), ("localhost", 10.0), ("localhost", -1), None] + + +class ListTrait(HasTraits): + + value = List(Int()) + + +class TestList(TraitTestBase): + + obj = ListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[], [1], list(range(10)), (1, 2)] + _bad_values = [10, [1, "a"], "a"] + + def coerce(self, value): + if value is not None: + value = list(value) + return value + + +class Foo: + pass + + +class NoneInstanceListTrait(HasTraits): + + value = List(Instance(Foo)) + + +class TestNoneInstanceList(TraitTestBase): + + obj = NoneInstanceListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[Foo(), Foo()], []] + _bad_values = [[None], [Foo(), None]] + + +class InstanceListTrait(HasTraits): + + value = List(Instance(__name__ + ".Foo")) + + +class TestInstanceList(TraitTestBase): + + obj = InstanceListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, Foo) + + _default_value: t.List[t.Any] = [] + _good_values = [[Foo(), Foo()], []] + _bad_values = [ + [ + "1", + 2, + ], + "1", + [Foo], + None, + ] + + +class UnionListTrait(HasTraits): + + value = List(Int() | Bool()) + + +class TestUnionListTrait(TraitTestBase): + + obj = UnionListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[True, 1], [False, True]] + _bad_values = [[1, "True"], False] + + +class LenListTrait(HasTraits): + + value = List(Int(), [0], minlen=1, maxlen=2) + + +class TestLenList(TraitTestBase): + + obj = LenListTrait() + + _default_value = [0] + _good_values = [[1], [1, 2], (1, 2)] + _bad_values = [10, [1, "a"], "a", [], list(range(3))] + + def coerce(self, value): + if value is not None: + value = list(value) + return value + + +class TupleTrait(HasTraits): + + value = Tuple(Int(allow_none=True), default_value=(1,)) + + +class TestTupleTrait(TraitTestBase): + + obj = TupleTrait() + + _default_value = (1,) + _good_values = [(1,), (0,), [1]] + _bad_values = [10, (1, 2), ("a"), (), None] + + def coerce(self, value): + if value is not None: + value = tuple(value) + return value + + def test_invalid_args(self): + self.assertRaises(TypeError, Tuple, 5) + self.assertRaises(TypeError, Tuple, default_value="hello") + t = Tuple(Int(), CBytes(), default_value=(1, 5)) + + +class LooseTupleTrait(HasTraits): + + value = Tuple((1, 2, 3)) + + +class TestLooseTupleTrait(TraitTestBase): + + obj = LooseTupleTrait() + + _default_value = (1, 2, 3) + _good_values = [(1,), [1], (0,), tuple(range(5)), tuple("hello"), ("a", 5), ()] + _bad_values = [10, "hello", {}, None] + + def coerce(self, value): + if value is not None: + value = tuple(value) + return value + + def test_invalid_args(self): + self.assertRaises(TypeError, Tuple, 5) + self.assertRaises(TypeError, Tuple, default_value="hello") + t = Tuple(Int(), CBytes(), default_value=(1, 5)) + + +class MultiTupleTrait(HasTraits): + + value = Tuple(Int(), Bytes(), default_value=[99, b"bottles"]) + + +class TestMultiTuple(TraitTestBase): + + obj = MultiTupleTrait() + + _default_value = (99, b"bottles") + _good_values = [(1, b"a"), (2, b"b")] + _bad_values = ((), 10, b"a", (1, b"a", 3), (b"a", 1), (1, "a")) + + +@pytest.mark.parametrize( + "Trait", + ( + List, + Tuple, + Set, + Dict, + Integer, + Unicode, + ), +) +def test_allow_none_default_value(Trait): + class C(HasTraits): + t = Trait(default_value=None, allow_none=True) + + # test default value + c = C() + assert c.t is None + + # and in constructor + c = C(t=None) + assert c.t is None + + +@pytest.mark.parametrize( + "Trait, default_value", + ((List, []), (Tuple, ()), (Set, set()), (Dict, {}), (Integer, 0), (Unicode, "")), +) +def test_default_value(Trait, default_value): + class C(HasTraits): + t = Trait() + + # test default value + c = C() + assert type(c.t) is type(default_value) + assert c.t == default_value + + +@pytest.mark.parametrize( + "Trait, default_value", + ((List, []), (Tuple, ()), (Set, set())), +) +def test_subclass_default_value(Trait, default_value): + """Test deprecated default_value=None behavior for Container subclass traits""" + + class SubclassTrait(Trait): # type:ignore + def __init__(self, default_value=None): + super().__init__(default_value=default_value) + + class C(HasTraits): + t = SubclassTrait() + + # test default value + c = C() + assert type(c.t) is type(default_value) + assert c.t == default_value + + +class CRegExpTrait(HasTraits): + + value = CRegExp(r"") + + +class TestCRegExp(TraitTestBase): + def coerce(self, value): + return re.compile(value) + + obj = CRegExpTrait() + + _default_value = re.compile(r"") + _good_values = [r"\d+", re.compile(r"\d+")] + _bad_values = ["(", None, ()] + + +class DictTrait(HasTraits): + value = Dict() + + +def test_dict_assignment(): + d: t.Dict[str, int] = {} + c = DictTrait() + c.value = d + d["a"] = 5 + assert d == c.value + assert c.value is d + + +class UniformlyValueValidatedDictTrait(HasTraits): + + value = Dict(value_trait=Unicode(), default_value={"foo": "1"}) + + +class TestInstanceUniformlyValueValidatedDict(TraitTestBase): + + obj = UniformlyValueValidatedDictTrait() + + _default_value = {"foo": "1"} + _good_values = [{"foo": "0", "bar": "1"}] + _bad_values = [{"foo": 0, "bar": "1"}] + + +class NonuniformlyValueValidatedDictTrait(HasTraits): + + value = Dict(per_key_traits={"foo": Int()}, default_value={"foo": 1}) + + +class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase): + + obj = NonuniformlyValueValidatedDictTrait() + + _default_value = {"foo": 1} + _good_values = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": 1}] + _bad_values = [{"foo": "0", "bar": "1"}] + + +class KeyValidatedDictTrait(HasTraits): + + value = Dict(key_trait=Unicode(), default_value={"foo": "1"}) + + +class TestInstanceKeyValidatedDict(TraitTestBase): + + obj = KeyValidatedDictTrait() + + _default_value = {"foo": "1"} + _good_values = [{"foo": "0", "bar": "1"}] + _bad_values = [{"foo": "0", 0: "1"}] + + +class FullyValidatedDictTrait(HasTraits): + + value = Dict( + value_trait=Unicode(), + key_trait=Unicode(), + per_key_traits={"foo": Int()}, + default_value={"foo": 1}, + ) + + +class TestInstanceFullyValidatedDict(TraitTestBase): + + obj = FullyValidatedDictTrait() + + _default_value = {"foo": 1} + _good_values = [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}] + _bad_values = [{"foo": 0, "bar": 1}, {"foo": "0", "bar": "1"}, {"foo": 0, 0: "1"}] + + +def test_dict_default_value(): + """Check that the `{}` default value of the Dict traitlet constructor is + actually copied.""" + + class Foo(HasTraits): + d1 = Dict() + d2 = Dict() + + foo = Foo() + assert foo.d1 == {} + assert foo.d2 == {} + assert foo.d1 is not foo.d2 + + +class TestValidationHook(TestCase): + def test_parity_trait(self): + """Verify that the early validation hook is effective""" + + class Parity(HasTraits): + + value = Int(0) + parity = Enum(["odd", "even"], default_value="even") + + @validate("value") + def _value_validate(self, proposal): + value = proposal["value"] + if self.parity == "even" and value % 2: + raise TraitError("Expected an even number") + if self.parity == "odd" and (value % 2 == 0): + raise TraitError("Expected an odd number") + return value + + u = Parity() + u.parity = "odd" + u.value = 1 # OK + with self.assertRaises(TraitError): + u.value = 2 # Trait Error + + u.parity = "even" + u.value = 2 # OK + + def test_multiple_validate(self): + """Verify that we can register the same validator to multiple names""" + + class OddEven(HasTraits): + + odd = Int(1) + even = Int(0) + + @validate("odd", "even") + def check_valid(self, proposal): + if proposal["trait"].name == "odd" and not proposal["value"] % 2: + raise TraitError("odd should be odd") + if proposal["trait"].name == "even" and proposal["value"] % 2: + raise TraitError("even should be even") + + u = OddEven() + u.odd = 3 # OK + with self.assertRaises(TraitError): + u.odd = 2 # Trait Error + + u.even = 2 # OK + with self.assertRaises(TraitError): + u.even = 3 # Trait Error + + def test_validate_used(self): + """Verify that the validate value is being used""" + + class FixedValue(HasTraits): + value = Int(0) + + @validate("value") + def _value_validate(self, proposal): + return -1 + + u = FixedValue(value=2) + assert u.value == -1 + + u = FixedValue() + u.value = 3 + assert u.value == -1 + + +class TestLink(TestCase): + def test_connect_same(self): + """Verify two traitlets of the same type can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "value")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.value) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.value) + b.value = 6 + self.assertEqual(a.value, b.value) + + def test_link_different(self): + """Verify two traitlets of different types can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "count")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.count) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.count) + b.count = 4 + self.assertEqual(a.value, b.count) + + def test_unlink_link(self): + """Verify two linked traitlets can be unlinked and relinked.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Connect the two classes. + c = link((a, "value"), (b, "value")) + a.value = 4 + c.unlink() + + # Change one of the values to make sure they don't stay in sync. + a.value = 5 + self.assertNotEqual(a.value, b.value) + c.link() + self.assertEqual(a.value, b.value) + a.value += 1 + self.assertEqual(a.value, b.value) + + def test_callbacks(self): + """Verify two linked traitlets have their callbacks called once.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Register callbacks that count. + callback_count = [] + + def a_callback(name, old, new): + callback_count.append("a") + + a.on_trait_change(a_callback, "value") + + def b_callback(name, old, new): + callback_count.append("b") + + b.on_trait_change(b_callback, "count") + + # Connect the two classes. + c = link((a, "value"), (b, "count")) + + # Make sure b's count was set to a's value once. + self.assertEqual("".join(callback_count), "b") + del callback_count[:] + + # Make sure a's value was set to b's count once. + b.count = 5 + self.assertEqual("".join(callback_count), "ba") + del callback_count[:] + + # Make sure b's count was set to a's value once. + a.value = 4 + self.assertEqual("".join(callback_count), "ab") + del callback_count[:] + + def test_tranform(self): + """Test transform link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "value"), transform=(lambda x: 2 * x, lambda x: int(x / 2.0))) + + # Make sure the values are correct at the point of linking. + self.assertEqual(b.value, 2 * a.value) + + # Change one the value of the source and check that it modifies the target. + a.value = 5 + self.assertEqual(b.value, 10) + # Change one the value of the target and check that it modifies the + # source. + b.value = 6 + self.assertEqual(a.value, 3) + + def test_link_broken_at_source(self): + class MyClass(HasTraits): + i = Int() + j = Int() + + @observe("j") + def another_update(self, change): + self.i = change.new * 2 + + mc = MyClass() + l = link((mc, "i"), (mc, "j")) # noqa + self.assertRaises(TraitError, setattr, mc, "i", 2) + + def test_link_broken_at_target(self): + class MyClass(HasTraits): + i = Int() + j = Int() + + @observe("i") + def another_update(self, change): + self.j = change.new * 2 + + mc = MyClass() + l = link((mc, "i"), (mc, "j")) # noqa + self.assertRaises(TraitError, setattr, mc, "j", 2) + + +class TestDirectionalLink(TestCase): + def test_connect_same(self): + """Verify two traitlets of the same type can be linked together using directional_link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "value")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.value) + + # Change one the value of the source and check that it synchronizes the target. + a.value = 5 + self.assertEqual(b.value, 5) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 + self.assertEqual(a.value, 5) + + def test_tranform(self): + """Test transform link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "value"), lambda x: 2 * x) + + # Make sure the values are correct at the point of linking. + self.assertEqual(b.value, 2 * a.value) + + # Change one the value of the source and check that it modifies the target. + a.value = 5 + self.assertEqual(b.value, 10) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 + self.assertEqual(a.value, 5) + + def test_link_different(self): + """Verify two traitlets of different types can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "count")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.count) + + # Change one the value of the source and check that it synchronizes the target. + a.value = 5 + self.assertEqual(b.count, 5) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 # type:ignore + self.assertEqual(a.value, 5) + + def test_unlink_link(self): + """Verify two linked traitlets can be unlinked and relinked.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Connect the two classes. + c = directional_link((a, "value"), (b, "value")) + a.value = 4 + c.unlink() + + # Change one of the values to make sure they don't stay in sync. + a.value = 5 + self.assertNotEqual(a.value, b.value) + c.link() + self.assertEqual(a.value, b.value) + a.value += 1 + self.assertEqual(a.value, b.value) + + +class Pickleable(HasTraits): + + i = Int() + + @observe("i") + def _i_changed(self, change): + pass + + @validate("i") + def _i_validate(self, commit): + return commit["value"] + + j = Int() + + def __init__(self): + with self.hold_trait_notifications(): + self.i = 1 + self.on_trait_change(self._i_changed, "i") + + +def test_pickle_hastraits(): + c = Pickleable() + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(c, protocol) + c2 = pickle.loads(p) + assert c2.i == c.i + assert c2.j == c.j + + c.i = 5 + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(c, protocol) + c2 = pickle.loads(p) + assert c2.i == c.i + assert c2.j == c.j + + +def test_hold_trait_notifications(): + changes = [] + + class Test(HasTraits): + a = Integer(0) + b = Integer(0) + + def _a_changed(self, name, old, new): + changes.append((old, new)) + + def _b_validate(self, value, trait): + if value != 0: + raise TraitError("Only 0 is a valid value") + return value + + # Test context manager and nesting + t = Test() + with t.hold_trait_notifications(): + with t.hold_trait_notifications(): + t.a = 1 + assert t.a == 1 + assert changes == [] + t.a = 2 + assert t.a == 2 + with t.hold_trait_notifications(): + t.a = 3 + assert t.a == 3 + assert changes == [] + t.a = 4 + assert t.a == 4 + assert changes == [] + t.a = 4 + assert t.a == 4 + assert changes == [] + + assert changes == [(0, 4)] + # Test roll-back + try: + with t.hold_trait_notifications(): + t.b = 1 # raises a Trait error + except Exception: + pass + assert t.b == 0 + + +class RollBack(HasTraits): + bar = Int() + + def _bar_validate(self, value, trait): + if value: + raise TraitError("foobar") + return value + + +class TestRollback(TestCase): + def test_roll_back(self): + def assign_rollback(): + RollBack(bar=1) + + self.assertRaises(TraitError, assign_rollback) + + +class CacheModification(HasTraits): + foo = Int() + bar = Int() + + def _bar_validate(self, value, trait): + self.foo = value + return value + + def _foo_validate(self, value, trait): + self.bar = value + return value + + +def test_cache_modification(): + CacheModification(foo=1) + CacheModification(bar=1) + + +class OrderTraits(HasTraits): + notified = Dict() + + a = Unicode() + b = Unicode() + c = Unicode() + d = Unicode() + e = Unicode() + f = Unicode() + g = Unicode() + h = Unicode() + i = Unicode() + j = Unicode() + k = Unicode() + l = Unicode() # noqa + + def _notify(self, name, old, new): + """check the value of all traits when each trait change is triggered + + This verifies that the values are not sensitive + to dict ordering when loaded from kwargs + """ + # check the value of the other traits + # when a given trait change notification fires + self.notified[name] = {c: getattr(self, c) for c in "abcdefghijkl"} + + def __init__(self, **kwargs): + self.on_trait_change(self._notify) + super().__init__(**kwargs) + + +def test_notification_order(): + d = {c: c for c in "abcdefghijkl"} + obj = OrderTraits() + assert obj.notified == {} + obj = OrderTraits(**d) + notifications = {c: d for c in "abcdefghijkl"} + assert obj.notified == notifications + + +### +# Traits for Forward Declaration Tests +### +class ForwardDeclaredInstanceTrait(HasTraits): + + value = ForwardDeclaredInstance("ForwardDeclaredBar", allow_none=True) + + +class ForwardDeclaredTypeTrait(HasTraits): + + value = ForwardDeclaredType("ForwardDeclaredBar", allow_none=True) + + +class ForwardDeclaredInstanceListTrait(HasTraits): + + value = List(ForwardDeclaredInstance("ForwardDeclaredBar")) + + +class ForwardDeclaredTypeListTrait(HasTraits): + + value = List(ForwardDeclaredType("ForwardDeclaredBar")) + + +### +# End Traits for Forward Declaration Tests +### + +### +# Classes for Forward Declaration Tests +### +class ForwardDeclaredBar: + pass + + +class ForwardDeclaredBarSub(ForwardDeclaredBar): + pass + + +### +# End Classes for Forward Declaration Tests +### + +### +# Forward Declaration Tests +### +class TestForwardDeclaredInstanceTrait(TraitTestBase): + + obj = ForwardDeclaredInstanceTrait() + _default_value = None + _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()] + _bad_values = ["foo", 3, ForwardDeclaredBar, ForwardDeclaredBarSub] + + +class TestForwardDeclaredTypeTrait(TraitTestBase): + + obj = ForwardDeclaredTypeTrait() + _default_value = None + _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub] + _bad_values = ["foo", 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()] + + +class TestForwardDeclaredInstanceList(TraitTestBase): + + obj = ForwardDeclaredInstanceListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) + + _default_value: t.List[t.Any] = [] + _good_values = [ + [ForwardDeclaredBar(), ForwardDeclaredBarSub()], + [], + ] + _bad_values = [ + ForwardDeclaredBar(), + [ForwardDeclaredBar(), 3, None], + "1", + # Note that this is the type, not an instance. + [ForwardDeclaredBar], + [None], + None, + ] + + +class TestForwardDeclaredTypeList(TraitTestBase): + + obj = ForwardDeclaredTypeListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) + + _default_value: t.List[t.Any] = [] + _good_values = [ + [ForwardDeclaredBar, ForwardDeclaredBarSub], + [], + ] + _bad_values = [ + ForwardDeclaredBar, + [ForwardDeclaredBar, 3], + "1", + # Note that this is an instance, not the type. + [ForwardDeclaredBar()], + [None], + None, + ] + + +### +# End Forward Declaration Tests +### + + +class TestDynamicTraits(TestCase): + def setUp(self): + self._notify1 = [] + + def notify1(self, name, old, new): + self._notify1.append((name, old, new)) + + @t.no_type_check + def test_notify_all(self): + class A(HasTraits): + pass + + a = A() + self.assertTrue(not hasattr(a, "x")) + self.assertTrue(not hasattr(a, "y")) + + # Dynamically add trait x. + a.add_traits(x=Int()) + self.assertTrue(hasattr(a, "x")) + self.assertTrue(isinstance(a, (A,))) + + # Dynamically add trait y. + a.add_traits(y=Float()) + self.assertTrue(hasattr(a, "y")) + self.assertTrue(isinstance(a, (A,))) + self.assertEqual(a.__class__.__name__, A.__name__) + + # Create a new instance and verify that x and y + # aren't defined. + b = A() + self.assertTrue(not hasattr(b, "x")) + self.assertTrue(not hasattr(b, "y")) + + # Verify that notification works like normal. + a.on_trait_change(self.notify1) + a.x = 0 + self.assertEqual(len(self._notify1), 0) + a.y = 0.0 + self.assertEqual(len(self._notify1), 0) + a.x = 10 + self.assertTrue(("x", 0, 10) in self._notify1) + a.y = 10.0 + self.assertTrue(("y", 0.0, 10.0) in self._notify1) + self.assertRaises(TraitError, setattr, a, "x", "bad string") + self.assertRaises(TraitError, setattr, a, "y", "bad string") + self._notify1 = [] + a.on_trait_change(self.notify1, remove=True) + a.x = 20 + a.y = 20.0 + self.assertEqual(len(self._notify1), 0) + + +def test_enum_no_default(): + class C(HasTraits): + t = Enum(["a", "b"]) + + c = C() + c.t = "a" + assert c.t == "a" + + c = C() + + with pytest.raises(TraitError): + t = c.t + + c = C(t="b") + assert c.t == "b" + + +def test_default_value_repr(): + class C(HasTraits): + t = Type("traitlets.HasTraits") + t2 = Type(HasTraits) + n = Integer(0) + lis = List() + d = Dict() + + assert C.t.default_value_repr() == "'traitlets.HasTraits'" + assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'" + assert C.n.default_value_repr() == "0" + assert C.lis.default_value_repr() == "[]" + assert C.d.default_value_repr() == "{}" + + +class TransitionalClass(HasTraits): + + d = Any() + + @default("d") + def _d_default(self): + return TransitionalClass + + parent_super = False + calls_super = Integer(0) + + @default("calls_super") + def _calls_super_default(self): + return -1 + + @observe("calls_super") + @observe_compat + def _calls_super_changed(self, change): + self.parent_super = change + + parent_override = False + overrides = Integer(0) + + @observe("overrides") + @observe_compat + def _overrides_changed(self, change): + self.parent_override = change + + +class SubClass(TransitionalClass): + def _d_default(self): + return SubClass + + subclass_super = False + + def _calls_super_changed(self, name, old, new): + self.subclass_super = True + super()._calls_super_changed(name, old, new) + + subclass_override = False + + def _overrides_changed(self, name, old, new): + self.subclass_override = True + + +def test_subclass_compat(): + obj = SubClass() + obj.calls_super = 5 + assert obj.parent_super + assert obj.subclass_super + obj.overrides = 5 + assert obj.subclass_override + assert not obj.parent_override + assert obj.d is SubClass + + +class DefinesHandler(HasTraits): + parent_called = False + + trait = Integer() + + @observe("trait") + def handler(self, change): + self.parent_called = True + + +class OverridesHandler(DefinesHandler): + child_called = False + + @observe("trait") + def handler(self, change): + self.child_called = True + + +def test_subclass_override_observer(): + obj = OverridesHandler() + obj.trait = 5 + assert obj.child_called + assert not obj.parent_called + + +class DoesntRegisterHandler(DefinesHandler): + child_called = False + + def handler(self, change): + self.child_called = True + + +def test_subclass_override_not_registered(): + """Subclass that overrides observer and doesn't re-register unregisters both""" + obj = DoesntRegisterHandler() + obj.trait = 5 + assert not obj.child_called + assert not obj.parent_called + + +class AddsHandler(DefinesHandler): + child_called = False + + @observe("trait") + def child_handler(self, change): + self.child_called = True + + +def test_subclass_add_observer(): + obj = AddsHandler() + obj.trait = 5 + assert obj.child_called + assert obj.parent_called + + +def test_observe_iterables(): + class C(HasTraits): + i = Integer() + s = Unicode() + + c = C() + recorded = {} + + def record(change): + recorded["change"] = change + + # observe with names=set + c.observe(record, names={"i", "s"}) + c.i = 5 + assert recorded["change"].name == "i" + assert recorded["change"].new == 5 + c.s = "hi" + assert recorded["change"].name == "s" + assert recorded["change"].new == "hi" + + # observe with names=custom container with iter, contains + class MyContainer: + def __init__(self, container): + self.container = container + + def __iter__(self): + return iter(self.container) + + def __contains__(self, key): + return key in self.container + + c.observe(record, names=MyContainer({"i", "s"})) + c.i = 10 + assert recorded["change"].name == "i" + assert recorded["change"].new == 10 + c.s = "ok" + assert recorded["change"].name == "s" + assert recorded["change"].new == "ok" + + +def test_super_args(): + class SuperRecorder: + def __init__(self, *args, **kwargs): + self.super_args = args + self.super_kwargs = kwargs + + class SuperHasTraits(HasTraits, SuperRecorder): + i = Integer() + + obj = SuperHasTraits("a1", "a2", b=10, i=5, c="x") + assert obj.i == 5 + assert not hasattr(obj, "b") + assert not hasattr(obj, "c") + assert obj.super_args == ("a1", "a2") + assert obj.super_kwargs == {"b": 10, "c": "x"} + + +def test_super_bad_args(): + class SuperHasTraits(HasTraits): + a = Integer() + + w = ["Passing unrecognized arguments"] + with expected_warnings(w): + obj = SuperHasTraits(a=1, b=2) + assert obj.a == 1 + assert not hasattr(obj, "b") + + +def test_default_mro(): + """Verify that default values follow mro""" + + class Base(HasTraits): + trait = Unicode("base") + attr = "base" + + class A(Base): + pass + + class B(Base): + trait = Unicode("B") + attr = "B" + + class AB(A, B): + pass + + class BA(B, A): + pass + + assert A().trait == "base" + assert A().attr == "base" + assert BA().trait == "B" + assert BA().attr == "B" + assert AB().trait == "B" + assert AB().attr == "B" + + +def test_cls_self_argument(): + class X(HasTraits): + def __init__(__self, cls, self): # noqa + pass + + x = X(cls=None, self=None) + + +def test_override_default(): + class C(HasTraits): + a = Unicode("hard default") + + def _a_default(self): + return "default method" + + C._a_default = lambda self: "overridden" # type:ignore + c = C() + assert c.a == "overridden" + + +def test_override_default_decorator(): + class C(HasTraits): + a = Unicode("hard default") + + @default("a") + def _a_default(self): + return "default method" + + C._a_default = lambda self: "overridden" # type:ignore + c = C() + assert c.a == "overridden" + + +def test_override_default_instance(): + class C(HasTraits): + a = Unicode("hard default") + + @default("a") + def _a_default(self): + return "default method" + + c = C() + c._a_default = lambda self: "overridden" + assert c.a == "overridden" + + +def test_copy_HasTraits(): + from copy import copy + + class C(HasTraits): + a = Int() + + c = C(a=1) + assert c.a == 1 + + cc = copy(c) + cc.a = 2 + assert cc.a == 2 + assert c.a == 1 + + +def _from_string_test(traittype, s, expected): + """Run a test of trait.from_string""" + if isinstance(traittype, TraitType): + trait = traittype + else: + trait = traittype(allow_none=True) + if isinstance(s, list): + cast = trait.from_string_list # type:ignore + else: + cast = trait.from_string + if type(expected) is type and issubclass(expected, Exception): + with pytest.raises(expected): + value = cast(s) + trait.validate(CrossValidationStub(), value) # type:ignore + else: + value = cast(s) + assert value == expected + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], +) +def test_unicode_from_string(s, expected): + _from_string_test(Unicode, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], +) +def test_cunicode_from_string(s, expected): + _from_string_test(CUnicode, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], +) +def test_bytes_from_string(s, expected): + _from_string_test(Bytes, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], +) +def test_cbytes_from_string(s, expected): + _from_string_test(CBytes, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)], +) +def test_int_from_string(s, expected): + _from_string_test(Integer, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)], +) +def test_float_from_string(s, expected): + _from_string_test(Float, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("x", ValueError), + ("1", 1.0), + ("123.5", 123.5), + ("2.5", 2.5), + ("1+2j", 1 + 2j), + ("None", None), + ], +) +def test_complex_from_string(s, expected): + _from_string_test(Complex, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("true", True), + ("TRUE", True), + ("1", True), + ("0", False), + ("False", False), + ("false", False), + ("1.0", ValueError), + ("None", None), + ], +) +def test_bool_from_string(s, expected): + _from_string_test(Bool, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("{}", {}), + ("1", TraitError), + ("{1: 2}", {1: 2}), + ('{"key": "value"}', {"key": "value"}), + ("x", TraitError), + ("None", None), + ], +) +def test_dict_from_string(s, expected): + _from_string_test(Dict, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("[]", []), + ('[1, 2, "x"]', [1, 2, "x"]), + (["1", "x"], ["1", "x"]), + (["None"], None), + ], +) +def test_list_from_string(s, expected): + _from_string_test(List, s, expected) + + +@pytest.mark.parametrize( + "s, expected, value_trait", + [ + (["1", "2", "3"], [1, 2, 3], Integer()), + (["x"], ValueError, Integer()), + (["1", "x"], ["1", "x"], Unicode()), + (["None"], [None], Unicode(allow_none=True)), + (["None"], ["None"], Unicode(allow_none=False)), + ], +) +def test_list_items_from_string(s, expected, value_trait): + _from_string_test(List(value_trait), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("[]", set()), + ('[1, 2, "x"]', {1, 2, "x"}), + ('{1, 2, "x"}', {1, 2, "x"}), + (["1", "x"], {"1", "x"}), + (["None"], None), + ], +) +def test_set_from_string(s, expected): + _from_string_test(Set, s, expected) + + +@pytest.mark.parametrize( + "s, expected, value_trait", + [ + (["1", "2", "3"], {1, 2, 3}, Integer()), + (["x"], ValueError, Integer()), + (["1", "x"], {"1", "x"}, Unicode()), + (["None"], {None}, Unicode(allow_none=True)), + ], +) +def test_set_items_from_string(s, expected, value_trait): + _from_string_test(Set(value_trait), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("[]", ()), + ("()", ()), + ('[1, 2, "x"]', (1, 2, "x")), + ('(1, 2, "x")', (1, 2, "x")), + (["1", "x"], ("1", "x")), + (["None"], None), + ], +) +def test_tuple_from_string(s, expected): + _from_string_test(Tuple, s, expected) + + +@pytest.mark.parametrize( + "s, expected, value_traits", + [ + (["1", "2", "3"], (1, 2, 3), [Integer(), Integer(), Integer()]), + (["x"], ValueError, [Integer()]), + (["1", "x"], ("1", "x"), [Unicode()]), + (["None"], ("None",), [Unicode(allow_none=False)]), + (["None"], (None,), [Unicode(allow_none=True)]), + ], +) +def test_tuple_items_from_string(s, expected, value_traits): + _from_string_test(Tuple(*value_traits), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("x", "x"), + ("mod.submod", "mod.submod"), + ("not an identifier", TraitError), + ("1", "1"), + ("None", None), + ], +) +def test_object_from_string(s, expected): + _from_string_test(DottedObjectName, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("127.0.0.1:8000", ("127.0.0.1", 8000)), + ("host.tld:80", ("host.tld", 80)), + ("host:notaport", ValueError), + ("127.0.0.1", ValueError), + ("None", None), + ], +) +def test_tcp_from_string(s, expected): + _from_string_test(TCPAddress, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("[]", []), ("{}", "{}")], +) +def test_union_of_list_and_unicode_from_string(s, expected): + _from_string_test(Union([List(), Unicode()]), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("1", 1), ("1.5", 1.5)], +) +def test_union_of_int_and_float_from_string(s, expected): + _from_string_test(Union([Int(), Float()]), s, expected) + + +@pytest.mark.parametrize( + "s, expected, allow_none", + [("[]", [], False), ("{}", {}, False), ("None", TraitError, False), ("None", None, True)], +) +def test_union_of_list_and_dict_from_string(s, expected, allow_none): + _from_string_test(Union([List(), Dict()], allow_none=allow_none), s, expected) + + +def test_all_attribute(): + """Verify all trait types are added to `traitlets.__all__`""" + names = dir(traitlets) + for name in names: + value = getattr(traitlets, name) + if not name.startswith("_") and isinstance(value, type) and issubclass(value, TraitType): + if name not in traitlets.__all__: + raise ValueError(f"{name} not in __all__") + + for name in traitlets.__all__: + if name not in names: + raise ValueError(f"{name} should be removed from __all__") diff --git a/contrib/python/traitlets/py3/traitlets/tests/test_traitlets_enum.py b/contrib/python/traitlets/py3/traitlets/tests/test_traitlets_enum.py new file mode 100644 index 0000000000..c39007e8a0 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/tests/test_traitlets_enum.py @@ -0,0 +1,380 @@ +# pylint: disable=missing-docstring, too-few-public-methods +""" +Test the trait-type ``UseEnum``. +""" + +import enum +import unittest + +from traitlets import CaselessStrEnum, Enum, FuzzyEnum, HasTraits, TraitError, UseEnum + +# ----------------------------------------------------------------------------- +# TEST SUPPORT: +# ----------------------------------------------------------------------------- + + +class Color(enum.Enum): + red = 1 + green = 2 + blue = 3 + yellow = 4 + + +class OtherColor(enum.Enum): + red = 0 + green = 1 + + +class CSColor(enum.Enum): + red = 1 + Green = 2 + BLUE = 3 + YeLLoW = 4 + + +color_choices = "red Green BLUE YeLLoW".split() + + +# ----------------------------------------------------------------------------- +# TESTSUITE: +# ----------------------------------------------------------------------------- +class TestUseEnum(unittest.TestCase): + # pylint: disable=invalid-name + + class Example(HasTraits): + color = UseEnum(Color, help="Color enum") + + def test_assign_enum_value(self): + example = self.Example() + example.color = Color.green + self.assertEqual(example.color, Color.green) + + def test_assign_all_enum_values(self): + # pylint: disable=no-member + enum_values = list(Color.__members__.values()) + for value in enum_values: + self.assertIsInstance(value, Color) + example = self.Example() + example.color = value + self.assertEqual(example.color, value) + self.assertIsInstance(value, Color) + + def test_assign_enum_value__with_other_enum_raises_error(self): + example = self.Example() + with self.assertRaises(TraitError): + example.color = OtherColor.green + + def test_assign_enum_name_1(self): + # -- CONVERT: string => Enum value (item) + example = self.Example() + example.color = "red" + self.assertEqual(example.color, Color.red) + + def test_assign_enum_value_name(self): + # -- CONVERT: string => Enum value (item) + # pylint: disable=no-member + enum_names = [enum_val.name for enum_val in Color.__members__.values()] + for value in enum_names: + self.assertIsInstance(value, str) + example = self.Example() + enum_value = Color.__members__.get(value) + example.color = value + self.assertIs(example.color, enum_value) + self.assertEqual(example.color.name, value) # type:ignore + + def test_assign_scoped_enum_value_name(self): + # -- CONVERT: string => Enum value (item) + scoped_names = ["Color.red", "Color.green", "Color.blue", "Color.yellow"] + for value in scoped_names: + example = self.Example() + example.color = value + self.assertIsInstance(example.color, Color) + self.assertEqual(str(example.color), value) + + def test_assign_bad_enum_value_name__raises_error(self): + # -- CONVERT: string => Enum value (item) + bad_enum_names = ["UNKNOWN_COLOR", "RED", "Green", "blue2"] + for value in bad_enum_names: + example = self.Example() + with self.assertRaises(TraitError): + example.color = value + + def test_assign_enum_value_number_1(self): + # -- CONVERT: number => Enum value (item) + example = self.Example() + example.color = 1 # == Color.red.value + example.color = Color.red.value + self.assertEqual(example.color, Color.red) + + def test_assign_enum_value_number(self): + # -- CONVERT: number => Enum value (item) + # pylint: disable=no-member + enum_numbers = [enum_val.value for enum_val in Color.__members__.values()] + for value in enum_numbers: + self.assertIsInstance(value, int) + example = self.Example() + example.color = value + self.assertIsInstance(example.color, Color) + self.assertEqual(example.color.value, value) # type:ignore + + def test_assign_bad_enum_value_number__raises_error(self): + # -- CONVERT: number => Enum value (item) + bad_numbers = [-1, 0, 5] + for value in bad_numbers: + self.assertIsInstance(value, int) + assert UseEnum(Color).select_by_number(value, None) is None + example = self.Example() + with self.assertRaises(TraitError): + example.color = value + + def test_ctor_without_default_value(self): + # -- IMPLICIT: default_value = Color.red (first enum-value) + class Example2(HasTraits): + color = UseEnum(Color) + + example = Example2() + self.assertEqual(example.color, Color.red) + + def test_ctor_with_default_value_as_enum_value(self): + # -- CONVERT: number => Enum value (item) + class Example2(HasTraits): + color = UseEnum(Color, default_value=Color.green) + + example = Example2() + self.assertEqual(example.color, Color.green) + + def test_ctor_with_default_value_none_and_not_allow_none(self): + # -- IMPLICIT: default_value = Color.red (first enum-value) + class Example2(HasTraits): + color1 = UseEnum(Color, default_value=None, allow_none=False) + color2 = UseEnum(Color, default_value=None) + + example = Example2() + self.assertEqual(example.color1, Color.red) + self.assertEqual(example.color2, Color.red) + + def test_ctor_with_default_value_none_and_allow_none(self): + class Example2(HasTraits): + color1 = UseEnum(Color, default_value=None, allow_none=True) + color2 = UseEnum(Color, allow_none=True) + + example = Example2() + self.assertIs(example.color1, None) + self.assertIs(example.color2, None) + + def test_assign_none_without_allow_none_resets_to_default_value(self): + class Example2(HasTraits): + color1 = UseEnum(Color, allow_none=False) + color2 = UseEnum(Color) + + example = Example2() + example.color1 = None + example.color2 = None + self.assertIs(example.color1, Color.red) + self.assertIs(example.color2, Color.red) + + def test_assign_none_to_enum_or_none(self): + class Example2(HasTraits): + color = UseEnum(Color, allow_none=True) + + example = Example2() + example.color = None + self.assertIs(example.color, None) + + def test_assign_bad_value_with_to_enum_or_none(self): + class Example2(HasTraits): + color = UseEnum(Color, allow_none=True) + + example = Example2() + with self.assertRaises(TraitError): + example.color = "BAD_VALUE" + + def test_info(self): + choices = color_choices + + class Example(HasTraits): + enum1 = Enum(choices, allow_none=False) + enum2 = CaselessStrEnum(choices, allow_none=False) + enum3 = FuzzyEnum(choices, allow_none=False) + enum4 = UseEnum(CSColor, allow_none=False) + + for i in range(1, 5): + attr = "enum%s" % i + enum = getattr(Example, attr) + + enum.allow_none = True + + info = enum.info() + self.assertEqual(len(info.split(", ")), len(choices), info.split(", ")) + self.assertIn("or None", info) + + info = enum.info_rst() + self.assertEqual(len(info.split("|")), len(choices), info.split("|")) + self.assertIn("or `None`", info) + # Check no single `\` exists. + self.assertNotRegex(info, r"\b\\\b") + + enum.allow_none = False + + info = enum.info() + self.assertEqual(len(info.split(", ")), len(choices), info.split(", ")) + self.assertNotIn("None", info) + + info = enum.info_rst() + self.assertEqual(len(info.split("|")), len(choices), info.split("|")) + self.assertNotIn("None", info) + # Check no single `\` exists. + self.assertNotRegex(info, r"\b\\\b") + + +# ----------------------------------------------------------------------------- +# TESTSUITE: +# ----------------------------------------------------------------------------- + + +class TestFuzzyEnum(unittest.TestCase): + # Check mostly `validate()`, Ctor must be checked on generic `Enum` + # or `CaselessStrEnum`. + + def test_search_all_prefixes__overwrite(self): + class FuzzyExample(HasTraits): + color = FuzzyEnum(color_choices, help="Color enum") + + example = FuzzyExample() + for color in color_choices: + for wlen in range(1, len(color)): + value = color[:wlen] + + example.color = value + self.assertEqual(example.color, color) + + example.color = value.upper() + self.assertEqual(example.color, color) + + example.color = value.lower() + self.assertEqual(example.color, color) + + def test_search_all_prefixes__ctor(self): + class FuzzyExample(HasTraits): + color = FuzzyEnum(color_choices, help="Color enum") + + for color in color_choices: + for wlen in range(1, len(color)): + value = color[:wlen] + + example = FuzzyExample() + example.color = value + self.assertEqual(example.color, color) + + example = FuzzyExample() + example.color = value.upper() + self.assertEqual(example.color, color) + + example = FuzzyExample() + example.color = value.lower() + self.assertEqual(example.color, color) + + def test_search_substrings__overwrite(self): + class FuzzyExample(HasTraits): + color = FuzzyEnum(color_choices, help="Color enum", substring_matching=True) + + example = FuzzyExample() + for color in color_choices: + for wlen in range(0, 2): + value = color[wlen:] + + example.color = value + self.assertEqual(example.color, color) + + example.color = value.upper() + self.assertEqual(example.color, color) + + example.color = value.lower() + self.assertEqual(example.color, color) + + def test_search_substrings__ctor(self): + class FuzzyExample(HasTraits): + color = FuzzyEnum(color_choices, help="Color enum", substring_matching=True) + + color = color_choices[-1] # 'YeLLoW' + for end in (-1, len(color)): + for start in range(1, len(color) - 2): + value = color[start:end] + + example = FuzzyExample() + example.color = value + self.assertEqual(example.color, color) + + example = FuzzyExample() + example.color = value.upper() + self.assertEqual(example.color, color) + + def test_assign_other_raises(self): + def new_trait_class(case_sensitive, substring_matching): + class Example(HasTraits): + color = FuzzyEnum( + color_choices, + case_sensitive=case_sensitive, + substring_matching=substring_matching, + ) + + return Example + + example = new_trait_class(case_sensitive=False, substring_matching=False)() + with self.assertRaises(TraitError): + example.color = "" + with self.assertRaises(TraitError): + example.color = "BAD COLOR" + with self.assertRaises(TraitError): + example.color = "ed" + + example = new_trait_class(case_sensitive=True, substring_matching=False)() + with self.assertRaises(TraitError): + example.color = "" + with self.assertRaises(TraitError): + example.color = "Red" # not 'red' + + example = new_trait_class(case_sensitive=True, substring_matching=True)() + with self.assertRaises(TraitError): + example.color = "" + with self.assertRaises(TraitError): + example.color = "BAD COLOR" + with self.assertRaises(TraitError): + example.color = "green" # not 'Green' + with self.assertRaises(TraitError): + example.color = "lue" # not (b)'LUE' + with self.assertRaises(TraitError): + example.color = "lUE" # not (b)'LUE' + + example = new_trait_class(case_sensitive=False, substring_matching=True)() + with self.assertRaises(TraitError): + example.color = "" + with self.assertRaises(TraitError): + example.color = "BAD COLOR" + + def test_ctor_with_default_value(self): + def new_trait_class(default_value, case_sensitive, substring_matching): + class Example(HasTraits): + color = FuzzyEnum( + color_choices, + default_value=default_value, + case_sensitive=case_sensitive, + substring_matching=substring_matching, + ) + + return Example + + for color in color_choices: + example = new_trait_class(color, False, False)() + self.assertEqual(example.color, color) + + example = new_trait_class(color.upper(), False, False)() + self.assertEqual(example.color, color) + + color = color_choices[-1] # 'YeLLoW' + example = new_trait_class(color, True, False)() + self.assertEqual(example.color, color) + + # FIXME: default value not validated! + # with self.assertRaises(TraitError): + # example = new_trait_class(color.lower(), True, False) diff --git a/contrib/python/traitlets/py3/traitlets/tests/utils.py b/contrib/python/traitlets/py3/traitlets/tests/utils.py new file mode 100644 index 0000000000..50360a372f --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/tests/utils.py @@ -0,0 +1,42 @@ +import sys +from subprocess import PIPE, Popen +import os + + +def get_output_error_code(cmd): + """Get stdout, stderr, and exit code from running a command""" + env = os.environ.copy() + env["Y_PYTHON_ENTRY_POINT"] = ":main" + p = Popen(cmd, stdout=PIPE, stderr=PIPE, env=env) + out, err = p.communicate() + out = out.decode("utf8", "replace") # type:ignore + err = err.decode("utf8", "replace") # type:ignore + return out, err, p.returncode + + +def check_help_output(pkg, subcommand=None): + """test that `python -m PKG [subcommand] -h` works""" + cmd = [sys.executable, "-m", pkg] + if subcommand: + cmd.extend(subcommand) + cmd.append("-h") + out, err, rc = get_output_error_code(cmd) + assert rc == 0, err + assert "Traceback" not in err + assert "Options" in out + assert "--help-all" in out + return out, err + + +def check_help_all_output(pkg, subcommand=None): + """test that `python -m PKG --help-all` works""" + cmd = [sys.executable, "-m", pkg] + if subcommand: + cmd.extend(subcommand) + cmd.append("--help-all") + out, err, rc = get_output_error_code(cmd) + assert rc == 0, err + assert "Traceback" not in err + assert "Options" in out + assert "Class options" in out + return out, err diff --git a/contrib/python/traitlets/py3/traitlets/traitlets.py b/contrib/python/traitlets/py3/traitlets/traitlets.py new file mode 100644 index 0000000000..7cf6916856 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/traitlets.py @@ -0,0 +1,3639 @@ +""" +A lightweight Traits like module. + +This is designed to provide a lightweight, simple, pure Python version of +many of the capabilities of enthought.traits. This includes: + +* Validation +* Type specification with defaults +* Static and dynamic notification +* Basic predefined types +* An API that is similar to enthought.traits + +We don't support: + +* Delegation +* Automatic GUI generation +* A full set of trait types. Most importantly, we don't provide container + traits (list, dict, tuple) that can trigger notifications if their + contents change. +* API compatibility with enthought.traits + +There are also some important difference in our design: + +* enthought.traits does not validate default values. We do. + +We choose to create this module because we need these capabilities, but +we need them to be pure Python so they work in all Python implementations, +including Jython and IronPython. + +Inheritance diagram: + +.. inheritance-diagram:: traitlets.traitlets + :parts: 3 +""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +# +# Adapted from enthought.traits, Copyright (c) Enthought, Inc., +# also under the terms of the Modified BSD License. + +import contextlib +import enum +import inspect +import os +import re +import sys +import types +import typing as t +from ast import literal_eval +from warnings import warn, warn_explicit + +from .utils.bunch import Bunch +from .utils.descriptions import add_article, class_of, describe, repr_type +from .utils.getargspec import getargspec +from .utils.importstring import import_item +from .utils.sentinel import Sentinel + +SequenceTypes = (list, tuple, set, frozenset) + +# backward compatibility, use to differ between Python 2 and 3. +ClassTypes = (type,) + +# exports: + +__all__ = [ + "All", + "Any", + "BaseDescriptor", + "Bool", + "Bytes", + "CBool", + "CBytes", + "CComplex", + "CFloat", + "CInt", + "CLong", + "CRegExp", + "CUnicode", + "Callable", + "CaselessStrEnum", + "ClassBasedTraitType", + "Complex", + "Container", + "DefaultHandler", + "Dict", + "DottedObjectName", + "Enum", + "EventHandler", + "Float", + "ForwardDeclaredInstance", + "ForwardDeclaredMixin", + "ForwardDeclaredType", + "FuzzyEnum", + "HasDescriptors", + "HasTraits", + "Instance", + "Int", + "Integer", + "List", + "Long", + "MetaHasDescriptors", + "MetaHasTraits", + "ObjectName", + "ObserveHandler", + "Set", + "TCPAddress", + "This", + "TraitError", + "TraitType", + "Tuple", + "Type", + "Unicode", + "Undefined", + "Union", + "UseEnum", + "ValidateHandler", + "default", + "directional_link", + "dlink", + "link", + "observe", + "observe_compat", + "parse_notifier_name", + "validate", +] + +# any TraitType subclass (that doesn't start with _) will be added automatically + +# ----------------------------------------------------------------------------- +# Basic classes +# ----------------------------------------------------------------------------- + + +Undefined = Sentinel( + "Undefined", + "traitlets", + """ +Used in Traitlets to specify that no defaults are set in kwargs +""", +) + +All = Sentinel( + "All", + "traitlets", + """ +Used in Traitlets to listen to all types of notification or to notifications +from all trait attributes. +""", +) + +# Deprecated alias +NoDefaultSpecified = Undefined + + +class TraitError(Exception): + pass + + +# ----------------------------------------------------------------------------- +# Utilities +# ----------------------------------------------------------------------------- + +_name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$") + + +def isidentifier(s): + return s.isidentifier() + + +_deprecations_shown = set() + + +def _should_warn(key): + """Add our own checks for too many deprecation warnings. + + Limit to once per package. + """ + env_flag = os.environ.get("TRAITLETS_ALL_DEPRECATIONS") + if env_flag and env_flag != "0": + return True + + if key not in _deprecations_shown: + _deprecations_shown.add(key) + return True + else: + return False + + +def _deprecated_method(method, cls, method_name, msg): + """Show deprecation warning about a magic method definition. + + Uses warn_explicit to bind warning to method definition instead of triggering code, + which isn't relevant. + """ + warn_msg = "{classname}.{method_name} is deprecated in traitlets 4.1: {msg}".format( + classname=cls.__name__, method_name=method_name, msg=msg + ) + + for parent in inspect.getmro(cls): + if method_name in parent.__dict__: + cls = parent + break + # limit deprecation messages to once per package + package_name = cls.__module__.split(".", 1)[0] + key = (package_name, msg) + if not _should_warn(key): + return + try: + fname = inspect.getsourcefile(method) or "<unknown>" + lineno = inspect.getsourcelines(method)[1] or 0 + except (OSError, TypeError) as e: + # Failed to inspect for some reason + warn(warn_msg + ("\n(inspection failed) %s" % e), DeprecationWarning) + else: + warn_explicit(warn_msg, DeprecationWarning, fname, lineno) + + +def _safe_literal_eval(s): + """Safely evaluate an expression + + Returns original string if eval fails. + + Use only where types are ambiguous. + """ + try: + return literal_eval(s) + except (NameError, SyntaxError, ValueError): + return s + + +def is_trait(t): + """Returns whether the given value is an instance or subclass of TraitType.""" + return isinstance(t, TraitType) or (isinstance(t, type) and issubclass(t, TraitType)) + + +def parse_notifier_name(names): + """Convert the name argument to a list of names. + + Examples + -------- + >>> parse_notifier_name([]) + [traitlets.All] + >>> parse_notifier_name("a") + ['a'] + >>> parse_notifier_name(["a", "b"]) + ['a', 'b'] + >>> parse_notifier_name(All) + [traitlets.All] + """ + if names is All or isinstance(names, str): + return [names] + else: + if not names or All in names: + return [All] + for n in names: + if not isinstance(n, str): + raise TypeError("names must be strings, not %r" % n) + return names + + +class _SimpleTest: + def __init__(self, value): + self.value = value + + def __call__(self, test): + return test == self.value + + def __repr__(self): + return "<SimpleTest(%r)" % self.value + + def __str__(self): + return self.__repr__() + + +def getmembers(object, predicate=None): + """A safe version of inspect.getmembers that handles missing attributes. + + This is useful when there are descriptor based attributes that for + some reason raise AttributeError even though they exist. This happens + in zope.inteface with the __provides__ attribute. + """ + results = [] + for key in dir(object): + try: + value = getattr(object, key) + except AttributeError: + pass + else: + if not predicate or predicate(value): + results.append((key, value)) + results.sort() + return results + + +def _validate_link(*tuples): + """Validate arguments for traitlet link functions""" + for tup in tuples: + if not len(tup) == 2: + raise TypeError( + "Each linked traitlet must be specified as (HasTraits, 'trait_name'), not %r" % t + ) + obj, trait_name = tup + if not isinstance(obj, HasTraits): + raise TypeError("Each object must be HasTraits, not %r" % type(obj)) + if trait_name not in obj.traits(): + raise TypeError(f"{obj!r} has no trait {trait_name!r}") + + +class link: + """Link traits from different objects together so they remain in sync. + + Parameters + ---------- + source : (object / attribute name) pair + target : (object / attribute name) pair + transform: iterable with two callables (optional) + Data transformation between source and target and target and source. + + Examples + -------- + >>> class X(HasTraits): + ... value = Int() + + >>> src = X(value=1) + >>> tgt = X(value=42) + >>> c = link((src, "value"), (tgt, "value")) + + Setting source updates target objects: + >>> src.value = 5 + >>> tgt.value + 5 + """ + + updating = False + + def __init__(self, source, target, transform=None): + _validate_link(source, target) + self.source, self.target = source, target + self._transform, self._transform_inv = transform if transform else (lambda x: x,) * 2 + + self.link() + + def link(self): + try: + setattr( + self.target[0], + self.target[1], + self._transform(getattr(self.source[0], self.source[1])), + ) + + finally: + self.source[0].observe(self._update_target, names=self.source[1]) + self.target[0].observe(self._update_source, names=self.target[1]) + + @contextlib.contextmanager + def _busy_updating(self): + self.updating = True + try: + yield + finally: + self.updating = False + + def _update_target(self, change): + if self.updating: + return + with self._busy_updating(): + setattr(self.target[0], self.target[1], self._transform(change.new)) + if getattr(self.source[0], self.source[1]) != change.new: + raise TraitError( + "Broken link {}: the source value changed while updating " + "the target.".format(self) + ) + + def _update_source(self, change): + if self.updating: + return + with self._busy_updating(): + setattr(self.source[0], self.source[1], self._transform_inv(change.new)) + if getattr(self.target[0], self.target[1]) != change.new: + raise TraitError( + "Broken link {}: the target value changed while updating " + "the source.".format(self) + ) + + def unlink(self): + self.source[0].unobserve(self._update_target, names=self.source[1]) + self.target[0].unobserve(self._update_source, names=self.target[1]) + + +class directional_link: + """Link the trait of a source object with traits of target objects. + + Parameters + ---------- + source : (object, attribute name) pair + target : (object, attribute name) pair + transform: callable (optional) + Data transformation between source and target. + + Examples + -------- + >>> class X(HasTraits): + ... value = Int() + + >>> src = X(value=1) + >>> tgt = X(value=42) + >>> c = directional_link((src, "value"), (tgt, "value")) + + Setting source updates target objects: + >>> src.value = 5 + >>> tgt.value + 5 + + Setting target does not update source object: + >>> tgt.value = 6 + >>> src.value + 5 + + """ + + updating = False + + def __init__(self, source, target, transform=None): + self._transform = transform if transform else lambda x: x + _validate_link(source, target) + self.source, self.target = source, target + self.link() + + def link(self): + try: + setattr( + self.target[0], + self.target[1], + self._transform(getattr(self.source[0], self.source[1])), + ) + finally: + self.source[0].observe(self._update, names=self.source[1]) + + @contextlib.contextmanager + def _busy_updating(self): + self.updating = True + try: + yield + finally: + self.updating = False + + def _update(self, change): + if self.updating: + return + with self._busy_updating(): + setattr(self.target[0], self.target[1], self._transform(change.new)) + + def unlink(self): + self.source[0].unobserve(self._update, names=self.source[1]) + + +dlink = directional_link + + +# ----------------------------------------------------------------------------- +# Base Descriptor Class +# ----------------------------------------------------------------------------- + + +class BaseDescriptor: + """Base descriptor class + + Notes + ----- + This implements Python's descriptor protocol. + + This class is the base class for all such descriptors. The + only magic we use is a custom metaclass for the main :class:`HasTraits` + class that does the following: + + 1. Sets the :attr:`name` attribute of every :class:`BaseDescriptor` + instance in the class dict to the name of the attribute. + 2. Sets the :attr:`this_class` attribute of every :class:`BaseDescriptor` + instance in the class dict to the *class* that declared the trait. + This is used by the :class:`This` trait to allow subclasses to + accept superclasses for :class:`This` values. + """ + + name: t.Optional[str] = None + this_class: t.Optional[t.Type[t.Any]] = None + + def class_init(self, cls, name): + """Part of the initialization which may depend on the underlying + HasDescriptors class. + + It is typically overloaded for specific types. + + This method is called by :meth:`MetaHasDescriptors.__init__` + passing the class (`cls`) and `name` under which the descriptor + has been assigned. + """ + self.this_class = cls + self.name = name + + def subclass_init(self, cls): + # Instead of HasDescriptors.setup_instance calling + # every instance_init, we opt in by default. + # This gives descriptors a change to opt out for + # performance reasons. + # Because most traits do not need instance_init, + # and it will otherwise be called for every HasTrait instance + # beging created, this otherwise gives a significant performance + # pentalty. Most TypeTraits in traitlets opt out. + cls._instance_inits.append(self.instance_init) + + def instance_init(self, obj): + """Part of the initialization which may depend on the underlying + HasDescriptors instance. + + It is typically overloaded for specific types. + + This method is called by :meth:`HasTraits.__new__` and in the + :meth:`BaseDescriptor.instance_init` method of descriptors holding + other descriptors. + """ + pass + + +class TraitType(BaseDescriptor): + """A base class for all trait types.""" + + metadata: t.Dict[str, t.Any] = {} + allow_none = False + read_only = False + info_text = "any value" + default_value: t.Optional[t.Any] = Undefined + + def __init__( + self, + default_value=Undefined, + allow_none=False, + read_only=None, + help=None, + config=None, + **kwargs, + ): + """Declare a traitlet. + + If *allow_none* is True, None is a valid value in addition to any + values that are normally valid. The default is up to the subclass. + For most trait types, the default value for ``allow_none`` is False. + + If *read_only* is True, attempts to directly modify a trait attribute raises a TraitError. + + Extra metadata can be associated with the traitlet using the .tag() convenience method + or by using the traitlet instance's .metadata dictionary. + """ + if default_value is not Undefined: + self.default_value = default_value + if allow_none: + self.allow_none = allow_none + if read_only is not None: + self.read_only = read_only + self.help = help if help is not None else "" + if self.help: + # define __doc__ so that inspectors like autodoc find traits + self.__doc__ = self.help + + if len(kwargs) > 0: + stacklevel = 1 + f = inspect.currentframe() + # count supers to determine stacklevel for warning + assert f is not None + while f.f_code.co_name == "__init__": + stacklevel += 1 + f = f.f_back + assert f is not None + mod = f.f_globals.get("__name__") or "" + pkg = mod.split(".", 1)[0] + key = tuple(["metadata-tag", pkg] + sorted(kwargs)) + if _should_warn(key): + warn( + "metadata %s was set from the constructor. " + "With traitlets 4.1, metadata should be set using the .tag() method, " + "e.g., Int().tag(key1='value1', key2='value2')" % (kwargs,), + DeprecationWarning, + stacklevel=stacklevel, + ) + if len(self.metadata) > 0: + self.metadata = self.metadata.copy() + self.metadata.update(kwargs) + else: + self.metadata = kwargs + else: + self.metadata = self.metadata.copy() + if config is not None: + self.metadata["config"] = config + + # We add help to the metadata during a deprecation period so that + # code that looks for the help string there can find it. + if help is not None: + self.metadata["help"] = help + + def from_string(self, s): + """Get a value from a config string + + such as an environment variable or CLI arguments. + + Traits can override this method to define their own + parsing of config strings. + + .. seealso:: item_from_string + + .. versionadded:: 5.0 + """ + if self.allow_none and s == "None": + return None + return s + + def default(self, obj=None): + """The default generator for this trait + + Notes + ----- + This method is registered to HasTraits classes during ``class_init`` + in the same way that dynamic defaults defined by ``@default`` are. + """ + if self.default_value is not Undefined: + return self.default_value + elif hasattr(self, "make_dynamic_default"): + return self.make_dynamic_default() + else: + # Undefined will raise in TraitType.get + return self.default_value + + def get_default_value(self): + """DEPRECATED: Retrieve the static default value for this trait. + Use self.default_value instead + """ + warn( + "get_default_value is deprecated in traitlets 4.0: use the .default_value attribute", + DeprecationWarning, + stacklevel=2, + ) + return self.default_value + + def init_default_value(self, obj): + """DEPRECATED: Set the static default value for the trait type.""" + warn( + "init_default_value is deprecated in traitlets 4.0, and may be removed in the future", + DeprecationWarning, + stacklevel=2, + ) + value = self._validate(obj, self.default_value) + obj._trait_values[self.name] = value + return value + + def get(self, obj, cls=None): + try: + value = obj._trait_values[self.name] + except KeyError: + # Check for a dynamic initializer. + default = obj.trait_defaults(self.name) + if default is Undefined: + warn( + "Explicit using of Undefined as the default value " + "is deprecated in traitlets 5.0, and may cause " + "exceptions in the future.", + DeprecationWarning, + stacklevel=2, + ) + # Using a context manager has a large runtime overhead, so we + # write out the obj.cross_validation_lock call here. + _cross_validation_lock = obj._cross_validation_lock + try: + obj._cross_validation_lock = True + value = self._validate(obj, default) + finally: + obj._cross_validation_lock = _cross_validation_lock + obj._trait_values[self.name] = value + obj._notify_observers( + Bunch( + name=self.name, + value=value, + owner=obj, + type="default", + ) + ) + return value + except Exception as e: + # This should never be reached. + raise TraitError("Unexpected error in TraitType: default value not set properly") from e + else: + return value + + def __get__(self, obj, cls=None): + """Get the value of the trait by self.name for the instance. + + Default values are instantiated when :meth:`HasTraits.__new__` + is called. Thus by the time this method gets called either the + default value or a user defined value (they called :meth:`__set__`) + is in the :class:`HasTraits` instance. + """ + if obj is None: + return self + else: + return self.get(obj, cls) + + def set(self, obj, value): + new_value = self._validate(obj, value) + try: + old_value = obj._trait_values[self.name] + except KeyError: + old_value = self.default_value + + obj._trait_values[self.name] = new_value + try: + silent = bool(old_value == new_value) + except Exception: + # if there is an error in comparing, default to notify + silent = False + if silent is not True: + # we explicitly compare silent to True just in case the equality + # comparison above returns something other than True/False + obj._notify_trait(self.name, old_value, new_value) + + def __set__(self, obj, value): + """Set the value of the trait by self.name for the instance. + + Values pass through a validation stage where errors are raised when + impropper types, or types that cannot be coerced, are encountered. + """ + if self.read_only: + raise TraitError('The "%s" trait is read-only.' % self.name) + else: + self.set(obj, value) + + def _validate(self, obj, value): + if value is None and self.allow_none: + return value + if hasattr(self, "validate"): + value = self.validate(obj, value) + if obj._cross_validation_lock is False: + value = self._cross_validate(obj, value) + return value + + def _cross_validate(self, obj, value): + if self.name in obj._trait_validators: + proposal = Bunch({"trait": self, "value": value, "owner": obj}) + value = obj._trait_validators[self.name](obj, proposal) + elif hasattr(obj, "_%s_validate" % self.name): + meth_name = "_%s_validate" % self.name + cross_validate = getattr(obj, meth_name) + _deprecated_method( + cross_validate, + obj.__class__, + meth_name, + "use @validate decorator instead.", + ) + value = cross_validate(value, self) + return value + + def __or__(self, other): + if isinstance(other, Union): + return Union([self] + other.trait_types) + else: + return Union([self, other]) + + def info(self): + return self.info_text + + def error(self, obj, value, error=None, info=None): + """Raise a TraitError + + Parameters + ---------- + obj : HasTraits or None + The instance which owns the trait. If not + object is given, then an object agnostic + error will be raised. + value : any + The value that caused the error. + error : Exception (default: None) + An error that was raised by a child trait. + The arguments of this exception should be + of the form ``(value, info, *traits)``. + Where the ``value`` and ``info`` are the + problem value, and string describing the + expected value. The ``traits`` are a series + of :class:`TraitType` instances that are + "children" of this one (the first being + the deepest). + info : str (default: None) + A description of the expected value. By + default this is infered from this trait's + ``info`` method. + """ + if error is not None: + # handle nested error + error.args += (self,) + if self.name is not None: + # this is the root trait that must format the final message + chain = " of ".join(describe("a", t) for t in error.args[2:]) + if obj is not None: + error.args = ( + "The '%s' trait of %s instance contains %s which " + "expected %s, not %s." + % ( + self.name, + describe("an", obj), + chain, + error.args[1], + describe("the", error.args[0]), + ), + ) + else: + error.args = ( + "The '%s' trait contains %s which " + "expected %s, not %s." + % ( + self.name, + chain, + error.args[1], + describe("the", error.args[0]), + ), + ) + raise error + else: + # this trait caused an error + if self.name is None: + # this is not the root trait + raise TraitError(value, info or self.info(), self) + else: + # this is the root trait + if obj is not None: + e = "The '{}' trait of {} instance expected {}, not {}.".format( + self.name, + class_of(obj), + self.info(), + describe("the", value), + ) + else: + e = "The '{}' trait expected {}, not {}.".format( + self.name, + self.info(), + describe("the", value), + ) + raise TraitError(e) + + def get_metadata(self, key, default=None): + """DEPRECATED: Get a metadata value. + + Use .metadata[key] or .metadata.get(key, default) instead. + """ + if key == "help": + msg = "use the instance .help string directly, like x.help" + else: + msg = "use the instance .metadata dictionary directly, like x.metadata[key] or x.metadata.get(key, default)" + warn("Deprecated in traitlets 4.1, " + msg, DeprecationWarning, stacklevel=2) + return self.metadata.get(key, default) + + def set_metadata(self, key, value): + """DEPRECATED: Set a metadata key/value. + + Use .metadata[key] = value instead. + """ + if key == "help": + msg = "use the instance .help string directly, like x.help = value" + else: + msg = "use the instance .metadata dictionary directly, like x.metadata[key] = value" + warn("Deprecated in traitlets 4.1, " + msg, DeprecationWarning, stacklevel=2) + self.metadata[key] = value + + def tag(self, **metadata): + """Sets metadata and returns self. + + This allows convenient metadata tagging when initializing the trait, such as: + + Examples + -------- + >>> Int(0).tag(config=True, sync=True) + <traitlets.traitlets.Int object at ...> + + """ + maybe_constructor_keywords = set(metadata.keys()).intersection( + {"help", "allow_none", "read_only", "default_value"} + ) + if maybe_constructor_keywords: + warn( + "The following attributes are set in using `tag`, but seem to be constructor keywords arguments: %s " + % maybe_constructor_keywords, + UserWarning, + stacklevel=2, + ) + + self.metadata.update(metadata) + return self + + def default_value_repr(self): + return repr(self.default_value) + + +# ----------------------------------------------------------------------------- +# The HasTraits implementation +# ----------------------------------------------------------------------------- + + +class _CallbackWrapper: + """An object adapting a on_trait_change callback into an observe callback. + + The comparison operator __eq__ is implemented to enable removal of wrapped + callbacks. + """ + + def __init__(self, cb): + self.cb = cb + # Bound methods have an additional 'self' argument. + offset = -1 if isinstance(self.cb, types.MethodType) else 0 + self.nargs = len(getargspec(cb)[0]) + offset + if self.nargs > 4: + raise TraitError("a trait changed callback must have 0-4 arguments.") + + def __eq__(self, other): + # The wrapper is equal to the wrapped element + if isinstance(other, _CallbackWrapper): + return self.cb == other.cb + else: + return self.cb == other + + def __call__(self, change): + # The wrapper is callable + if self.nargs == 0: + self.cb() + elif self.nargs == 1: + self.cb(change.name) + elif self.nargs == 2: + self.cb(change.name, change.new) + elif self.nargs == 3: + self.cb(change.name, change.old, change.new) + elif self.nargs == 4: + self.cb(change.name, change.old, change.new, change.owner) + + +def _callback_wrapper(cb): + if isinstance(cb, _CallbackWrapper): + return cb + else: + return _CallbackWrapper(cb) + + +class MetaHasDescriptors(type): + """A metaclass for HasDescriptors. + + This metaclass makes sure that any TraitType class attributes are + instantiated and sets their name attribute. + """ + + def __new__(mcls, name, bases, classdict): # noqa + """Create the HasDescriptors class.""" + for k, v in classdict.items(): + # ---------------------------------------------------------------- + # Support of deprecated behavior allowing for TraitType types + # to be used instead of TraitType instances. + if inspect.isclass(v) and issubclass(v, TraitType): + warn( + "Traits should be given as instances, not types (for example, `Int()`, not `Int`)." + " Passing types is deprecated in traitlets 4.1.", + DeprecationWarning, + stacklevel=2, + ) + classdict[k] = v() + # ---------------------------------------------------------------- + + return super().__new__(mcls, name, bases, classdict) + + def __init__(cls, name, bases, classdict): + """Finish initializing the HasDescriptors class.""" + super().__init__(name, bases, classdict) + cls.setup_class(classdict) + + def setup_class(cls, classdict): + """Setup descriptor instance on the class + + This sets the :attr:`this_class` and :attr:`name` attributes of each + BaseDescriptor in the class dict of the newly created ``cls`` before + calling their :attr:`class_init` method. + """ + cls._descriptors = [] + cls._instance_inits = [] + for k, v in classdict.items(): + if isinstance(v, BaseDescriptor): + v.class_init(cls, k) + + for _, v in getmembers(cls): + if isinstance(v, BaseDescriptor): + v.subclass_init(cls) + cls._descriptors.append(v) + + +class MetaHasTraits(MetaHasDescriptors): + """A metaclass for HasTraits.""" + + def setup_class(cls, classdict): # noqa + # for only the current class + cls._trait_default_generators = {} + # also looking at base classes + cls._all_trait_default_generators = {} + cls._traits = {} + cls._static_immutable_initial_values = {} + + super().setup_class(classdict) + + mro = cls.mro() + + for name in dir(cls): + # Some descriptors raise AttributeError like zope.interface's + # __provides__ attributes even though they exist. This causes + # AttributeErrors even though they are listed in dir(cls). + try: + value = getattr(cls, name) + except AttributeError: + continue + if isinstance(value, TraitType): + cls._traits[name] = value + trait = value + default_method_name = "_%s_default" % name + mro_trait = mro + try: + mro_trait = mro[: mro.index(trait.this_class) + 1] # type:ignore[arg-type] + except ValueError: + # this_class not in mro + pass + for c in mro_trait: + if default_method_name in c.__dict__: + cls._all_trait_default_generators[name] = c.__dict__[default_method_name] + break + if name in c.__dict__.get("_trait_default_generators", {}): + cls._all_trait_default_generators[name] = c._trait_default_generators[name] # type: ignore[attr-defined] + break + else: + # We don't have a dynamic default generator using @default etc. + # Now if the default value is not dynamic and immutable (string, number) + # and does not require any validation, we keep them in a dict + # of initial values to speed up instance creation. + # This is a very specific optimization, but a very common scenario in + # for instance ipywidgets. + none_ok = trait.default_value is None and trait.allow_none + if ( + type(trait) in [CInt, Int] + and trait.min is None # type: ignore[attr-defined] + and trait.max is None # type: ignore[attr-defined] + and (isinstance(trait.default_value, int) or none_ok) + ): + cls._static_immutable_initial_values[name] = trait.default_value + elif ( + type(trait) in [CFloat, Float] + and trait.min is None # type: ignore[attr-defined] + and trait.max is None # type: ignore[attr-defined] + and (isinstance(trait.default_value, float) or none_ok) + ): + cls._static_immutable_initial_values[name] = trait.default_value + elif type(trait) in [CBool, Bool] and ( + isinstance(trait.default_value, bool) or none_ok + ): + cls._static_immutable_initial_values[name] = trait.default_value + elif type(trait) in [CUnicode, Unicode] and ( + isinstance(trait.default_value, str) or none_ok + ): + cls._static_immutable_initial_values[name] = trait.default_value + elif type(trait) == Any and ( + isinstance(trait.default_value, (str, int, float, bool)) or none_ok + ): + cls._static_immutable_initial_values[name] = trait.default_value + elif type(trait) == Union and trait.default_value is None: + cls._static_immutable_initial_values[name] = None + elif ( + isinstance(trait, Instance) + and trait.default_args is None + and trait.default_kwargs is None + and trait.allow_none + ): + cls._static_immutable_initial_values[name] = None + + # we always add it, because a class may change when we call add_trait + # and then the instance may not have all the _static_immutable_initial_values + cls._all_trait_default_generators[name] = trait.default + + +def observe(*names: t.Union[Sentinel, str], type: str = "change") -> "ObserveHandler": + """A decorator which can be used to observe Traits on a class. + + The handler passed to the decorator will be called with one ``change`` + dict argument. The change dictionary at least holds a 'type' key and a + 'name' key, corresponding respectively to the type of notification and the + name of the attribute that triggered the notification. + + Other keys may be passed depending on the value of 'type'. In the case + where type is 'change', we also have the following keys: + * ``owner`` : the HasTraits instance + * ``old`` : the old value of the modified trait attribute + * ``new`` : the new value of the modified trait attribute + * ``name`` : the name of the modified trait attribute. + + Parameters + ---------- + *names + The str names of the Traits to observe on the object. + type : str, kwarg-only + The type of event to observe (e.g. 'change') + """ + if not names: + raise TypeError("Please specify at least one trait name to observe.") + for name in names: + if name is not All and not isinstance(name, str): + raise TypeError("trait names to observe must be strings or All, not %r" % name) + return ObserveHandler(names, type=type) + + +def observe_compat(func): + """Backward-compatibility shim decorator for observers + + Use with: + + @observe('name') + @observe_compat + def _foo_changed(self, change): + ... + + With this, `super()._foo_changed(self, name, old, new)` in subclasses will still work. + Allows adoption of new observer API without breaking subclasses that override and super. + """ + + def compatible_observer(self, change_or_name, old=Undefined, new=Undefined): + if isinstance(change_or_name, dict): + change = change_or_name + else: + clsname = self.__class__.__name__ + warn( + "A parent of %s._%s_changed has adopted the new (traitlets 4.1) @observe(change) API" + % (clsname, change_or_name), + DeprecationWarning, + ) + change = Bunch( + type="change", + old=old, + new=new, + name=change_or_name, + owner=self, + ) + return func(self, change) + + return compatible_observer + + +def validate(*names: t.Union[Sentinel, str]) -> "ValidateHandler": + """A decorator to register cross validator of HasTraits object's state + when a Trait is set. + + The handler passed to the decorator must have one ``proposal`` dict argument. + The proposal dictionary must hold the following keys: + + * ``owner`` : the HasTraits instance + * ``value`` : the proposed value for the modified trait attribute + * ``trait`` : the TraitType instance associated with the attribute + + Parameters + ---------- + *names + The str names of the Traits to validate. + + Notes + ----- + Since the owner has access to the ``HasTraits`` instance via the 'owner' key, + the registered cross validator could potentially make changes to attributes + of the ``HasTraits`` instance. However, we recommend not to do so. The reason + is that the cross-validation of attributes may run in arbitrary order when + exiting the ``hold_trait_notifications`` context, and such changes may not + commute. + """ + if not names: + raise TypeError("Please specify at least one trait name to validate.") + for name in names: + if name is not All and not isinstance(name, str): + raise TypeError("trait names to validate must be strings or All, not %r" % name) + return ValidateHandler(names) + + +def default(name: str) -> "DefaultHandler": + """A decorator which assigns a dynamic default for a Trait on a HasTraits object. + + Parameters + ---------- + name + The str name of the Trait on the object whose default should be generated. + + Notes + ----- + Unlike observers and validators which are properties of the HasTraits + instance, default value generators are class-level properties. + + Besides, default generators are only invoked if they are registered in + subclasses of `this_type`. + + :: + + class A(HasTraits): + bar = Int() + + @default('bar') + def get_bar_default(self): + return 11 + + class B(A): + bar = Float() # This trait ignores the default generator defined in + # the base class A + + class C(B): + + @default('bar') + def some_other_default(self): # This default generator should not be + return 3.0 # ignored since it is defined in a + # class derived from B.a.this_class. + """ + if not isinstance(name, str): + raise TypeError("Trait name must be a string or All, not %r" % name) + return DefaultHandler(name) + + +class EventHandler(BaseDescriptor): + def _init_call(self, func): + self.func = func + return self + + def __call__(self, *args, **kwargs): + """Pass `*args` and `**kwargs` to the handler's function if it exists.""" + if hasattr(self, "func"): + return self.func(*args, **kwargs) + else: + return self._init_call(*args, **kwargs) + + def __get__(self, inst, cls=None): + if inst is None: + return self + return types.MethodType(self.func, inst) + + +class ObserveHandler(EventHandler): + def __init__(self, names, type): + self.trait_names = names + self.type = type + + def instance_init(self, inst): + inst.observe(self, self.trait_names, type=self.type) + + +class ValidateHandler(EventHandler): + def __init__(self, names): + self.trait_names = names + + def instance_init(self, inst): + inst._register_validator(self, self.trait_names) + + +class DefaultHandler(EventHandler): + def __init__(self, name): + self.trait_name = name + + def class_init(self, cls, name): + super().class_init(cls, name) + cls._trait_default_generators[self.trait_name] = self + + +class HasDescriptors(metaclass=MetaHasDescriptors): + """The base class for all classes that have descriptors.""" + + def __new__(*args: t.Any, **kwargs: t.Any) -> t.Any: + # Pass cls as args[0] to allow "cls" as keyword argument + cls = args[0] + args = args[1:] + + # This is needed because object.__new__ only accepts + # the cls argument. + new_meth = super(HasDescriptors, cls).__new__ + if new_meth is object.__new__: + inst = new_meth(cls) + else: + inst = new_meth(cls, *args, **kwargs) + inst.setup_instance(*args, **kwargs) + return inst + + def setup_instance(*args, **kwargs): + """ + This is called **before** self.__init__ is called. + """ + # Pass self as args[0] to allow "self" as keyword argument + self = args[0] + args = args[1:] + + self._cross_validation_lock = False # type:ignore[attr-defined] + cls = self.__class__ + # Let descriptors performance initialization when a HasDescriptor + # instance is created. This allows registration of observers and + # default creations or other bookkeepings. + # Note that descriptors can opt-out of this behavior by overriding + # subclass_init. + for init in cls._instance_inits: + init(self) + + +class HasTraits(HasDescriptors, metaclass=MetaHasTraits): + _trait_values: t.Dict[str, t.Any] + _static_immutable_initial_values: t.Dict[str, t.Any] + _trait_notifiers: t.Dict[str, t.Any] + _trait_validators: t.Dict[str, t.Any] + _cross_validation_lock: bool + _traits: t.Dict[str, t.Any] + _all_trait_default_generators: t.Dict[str, t.Any] + + def setup_instance(*args, **kwargs): + # Pass self as args[0] to allow "self" as keyword argument + self = args[0] + args = args[1:] + + # although we'd prefer to set only the initial values not present + # in kwargs, we will overwrite them in `__init__`, and simply making + # a copy of a dict is faster than checking for each key. + self._trait_values = self._static_immutable_initial_values.copy() + self._trait_notifiers = {} + self._trait_validators = {} + self._cross_validation_lock = False + super(HasTraits, self).setup_instance(*args, **kwargs) + + def __init__(self, *args, **kwargs): + # Allow trait values to be set using keyword arguments. + # We need to use setattr for this to trigger validation and + # notifications. + super_args = args + super_kwargs = {} + + if kwargs: + # this is a simplified (and faster) version of + # the hold_trait_notifications(self) context manager + def ignore(*_ignore_args): + pass + + self.notify_change = ignore # type:ignore[assignment] + self._cross_validation_lock = True + changes = {} + for key, value in kwargs.items(): + if self.has_trait(key): + setattr(self, key, value) + changes[key] = Bunch( + name=key, + old=None, + new=value, + owner=self, + type="change", + ) + else: + # passthrough args that don't set traits to super + super_kwargs[key] = value + # notify and cross validate all trait changes that were set in kwargs + changed = set(kwargs) & set(self._traits) + for key in changed: + value = self._traits[key]._cross_validate(self, getattr(self, key)) + self.set_trait(key, value) + changes[key]['new'] = value + self._cross_validation_lock = False + # Restore method retrieval from class + del self.notify_change + for key in changed: + self.notify_change(changes[key]) + + try: + super().__init__(*super_args, **super_kwargs) + except TypeError as e: + arg_s_list = [repr(arg) for arg in super_args] + for k, v in super_kwargs.items(): + arg_s_list.append(f"{k}={v!r}") + arg_s = ", ".join(arg_s_list) + warn( + "Passing unrecognized arguments to super({classname}).__init__({arg_s}).\n" + "{error}\n" + "This is deprecated in traitlets 4.2." + "This error will be raised in a future release of traitlets.".format( + arg_s=arg_s, + classname=self.__class__.__name__, + error=e, + ), + DeprecationWarning, + stacklevel=2, + ) + + def __getstate__(self): + d = self.__dict__.copy() + # event handlers stored on an instance are + # expected to be reinstantiated during a + # recall of instance_init during __setstate__ + d["_trait_notifiers"] = {} + d["_trait_validators"] = {} + d["_trait_values"] = self._trait_values.copy() + d["_cross_validation_lock"] = False # FIXME: raise if cloning locked! + + return d + + def __setstate__(self, state): + self.__dict__ = state.copy() + + # event handlers are reassigned to self + cls = self.__class__ + for key in dir(cls): + # Some descriptors raise AttributeError like zope.interface's + # __provides__ attributes even though they exist. This causes + # AttributeErrors even though they are listed in dir(cls). + try: + value = getattr(cls, key) + except AttributeError: + pass + else: + if isinstance(value, EventHandler): + value.instance_init(self) + + @property + @contextlib.contextmanager + def cross_validation_lock(self): + """ + A contextmanager for running a block with our cross validation lock set + to True. + + At the end of the block, the lock's value is restored to its value + prior to entering the block. + """ + if self._cross_validation_lock: + yield + return + else: + try: + self._cross_validation_lock = True + yield + finally: + self._cross_validation_lock = False + + @contextlib.contextmanager + def hold_trait_notifications(self): + """Context manager for bundling trait change notifications and cross + validation. + + Use this when doing multiple trait assignments (init, config), to avoid + race conditions in trait notifiers requesting other trait values. + All trait notifications will fire after all values have been assigned. + """ + if self._cross_validation_lock: + yield + return + else: + cache: t.Dict[str, t.Any] = {} + + def compress(past_changes, change): + """Merges the provided change with the last if possible.""" + if past_changes is None: + return [change] + else: + if past_changes[-1]["type"] == "change" and change.type == "change": + past_changes[-1]["new"] = change.new + else: + # In case of changes other than 'change', append the notification. + past_changes.append(change) + return past_changes + + def hold(change): + name = change.name + cache[name] = compress(cache.get(name), change) + + try: + # Replace notify_change with `hold`, caching and compressing + # notifications, disable cross validation and yield. + self.notify_change = hold # type:ignore[assignment] + self._cross_validation_lock = True + yield + # Cross validate final values when context is released. + for name in list(cache.keys()): + trait = getattr(self.__class__, name) + value = trait._cross_validate(self, getattr(self, name)) + self.set_trait(name, value) + except TraitError as e: + # Roll back in case of TraitError during final cross validation. + self.notify_change = lambda x: None # type:ignore[assignment] + for name, changes in cache.items(): + for change in changes[::-1]: + # TODO: Separate in a rollback function per notification type. + if change.type == "change": + if change.old is not Undefined: + self.set_trait(name, change.old) + else: + self._trait_values.pop(name) + cache = {} + raise e + finally: + self._cross_validation_lock = False + # Restore method retrieval from class + del self.notify_change + + # trigger delayed notifications + for changes in cache.values(): + for change in changes: + self.notify_change(change) + + def _notify_trait(self, name, old_value, new_value): + self.notify_change( + Bunch( + name=name, + old=old_value, + new=new_value, + owner=self, + type="change", + ) + ) + + def notify_change(self, change): + """Notify observers of a change event""" + return self._notify_observers(change) + + def _notify_observers(self, event): + """Notify observers of any event""" + if not isinstance(event, Bunch): + # cast to bunch if given a dict + event = Bunch(event) + name, type = event['name'], event['type'] + + callables = [] + if name in self._trait_notifiers: + callables.extend(self._trait_notifiers.get(name, {}).get(type, [])) + callables.extend(self._trait_notifiers.get(name, {}).get(All, [])) + if All in self._trait_notifiers: # type:ignore[comparison-overlap] + callables.extend( + self._trait_notifiers.get(All, {}).get(type, []) # type:ignore[call-overload] + ) + callables.extend( + self._trait_notifiers.get(All, {}).get(All, []) # type:ignore[call-overload] + ) + + # Now static ones + magic_name = "_%s_changed" % name + if event['type'] == "change" and hasattr(self, magic_name): + class_value = getattr(self.__class__, magic_name) + if not isinstance(class_value, ObserveHandler): + _deprecated_method( + class_value, + self.__class__, + magic_name, + "use @observe and @unobserve instead.", + ) + cb = getattr(self, magic_name) + # Only append the magic method if it was not manually registered + if cb not in callables: + callables.append(_callback_wrapper(cb)) + + # Call them all now + # Traits catches and logs errors here. I allow them to raise + for c in callables: + # Bound methods have an additional 'self' argument. + + if isinstance(c, _CallbackWrapper): + c = c.__call__ + elif isinstance(c, EventHandler) and c.name is not None: + c = getattr(self, c.name) + + c(event) + + def _add_notifiers(self, handler, name, type): + if name not in self._trait_notifiers: + nlist: t.List[t.Any] = [] + self._trait_notifiers[name] = {type: nlist} + else: + if type not in self._trait_notifiers[name]: + nlist = [] + self._trait_notifiers[name][type] = nlist + else: + nlist = self._trait_notifiers[name][type] + if handler not in nlist: + nlist.append(handler) + + def _remove_notifiers(self, handler, name, type): + try: + if handler is None: + del self._trait_notifiers[name][type] + else: + self._trait_notifiers[name][type].remove(handler) + except KeyError: + pass + + def on_trait_change(self, handler=None, name=None, remove=False): + """DEPRECATED: Setup a handler to be called when a trait changes. + + This is used to setup dynamic notifications of trait changes. + + Static handlers can be created by creating methods on a HasTraits + subclass with the naming convention '_[traitname]_changed'. Thus, + to create static handler for the trait 'a', create the method + _a_changed(self, name, old, new) (fewer arguments can be used, see + below). + + If `remove` is True and `handler` is not specified, all change + handlers for the specified name are uninstalled. + + Parameters + ---------- + handler : callable, None + A callable that is called when a trait changes. Its + signature can be handler(), handler(name), handler(name, new), + handler(name, old, new), or handler(name, old, new, self). + name : list, str, None + If None, the handler will apply to all traits. If a list + of str, handler will apply to all names in the list. If a + str, the handler will apply just to that name. + remove : bool + If False (the default), then install the handler. If True + then unintall it. + """ + warn( + "on_trait_change is deprecated in traitlets 4.1: use observe instead", + DeprecationWarning, + stacklevel=2, + ) + if name is None: + name = All + if remove: + self.unobserve(_callback_wrapper(handler), names=name) + else: + self.observe(_callback_wrapper(handler), names=name) + + def observe(self, handler, names=All, type="change"): + """Setup a handler to be called when a trait changes. + + This is used to setup dynamic notifications of trait changes. + + Parameters + ---------- + handler : callable + A callable that is called when a trait changes. Its + signature should be ``handler(change)``, where ``change`` is a + dictionary. The change dictionary at least holds a 'type' key. + * ``type``: the type of notification. + Other keys may be passed depending on the value of 'type'. In the + case where type is 'change', we also have the following keys: + * ``owner`` : the HasTraits instance + * ``old`` : the old value of the modified trait attribute + * ``new`` : the new value of the modified trait attribute + * ``name`` : the name of the modified trait attribute. + names : list, str, All + If names is All, the handler will apply to all traits. If a list + of str, handler will apply to all names in the list. If a + str, the handler will apply just to that name. + type : str, All (default: 'change') + The type of notification to filter by. If equal to All, then all + notifications are passed to the observe handler. + """ + names = parse_notifier_name(names) + for n in names: + self._add_notifiers(handler, n, type) + + def unobserve(self, handler, names=All, type="change"): + """Remove a trait change handler. + + This is used to unregister handlers to trait change notifications. + + Parameters + ---------- + handler : callable + The callable called when a trait attribute changes. + names : list, str, All (default: All) + The names of the traits for which the specified handler should be + uninstalled. If names is All, the specified handler is uninstalled + from the list of notifiers corresponding to all changes. + type : str or All (default: 'change') + The type of notification to filter by. If All, the specified handler + is uninstalled from the list of notifiers corresponding to all types. + """ + names = parse_notifier_name(names) + for n in names: + self._remove_notifiers(handler, n, type) + + def unobserve_all(self, name=All): + """Remove trait change handlers of any type for the specified name. + If name is not specified, removes all trait notifiers.""" + if name is All: + self._trait_notifiers: t.Dict[str, t.Any] = {} + else: + try: + del self._trait_notifiers[name] + except KeyError: + pass + + def _register_validator(self, handler, names): + """Setup a handler to be called when a trait should be cross validated. + + This is used to setup dynamic notifications for cross-validation. + + If a validator is already registered for any of the provided names, a + TraitError is raised and no new validator is registered. + + Parameters + ---------- + handler : callable + A callable that is called when the given trait is cross-validated. + Its signature is handler(proposal), where proposal is a Bunch (dictionary with attribute access) + with the following attributes/keys: + * ``owner`` : the HasTraits instance + * ``value`` : the proposed value for the modified trait attribute + * ``trait`` : the TraitType instance associated with the attribute + names : List of strings + The names of the traits that should be cross-validated + """ + for name in names: + magic_name = "_%s_validate" % name + if hasattr(self, magic_name): + class_value = getattr(self.__class__, magic_name) + if not isinstance(class_value, ValidateHandler): + _deprecated_method( + class_value, + self.__class__, + magic_name, + "use @validate decorator instead.", + ) + for name in names: + self._trait_validators[name] = handler + + def add_traits(self, **traits): + """Dynamically add trait attributes to the HasTraits instance.""" + cls = self.__class__ + attrs = {"__module__": cls.__module__} + if hasattr(cls, "__qualname__"): + # __qualname__ introduced in Python 3.3 (see PEP 3155) + attrs["__qualname__"] = cls.__qualname__ + attrs.update(traits) + self.__class__ = type(cls.__name__, (cls,), attrs) + for trait in traits.values(): + trait.instance_init(self) + + def set_trait(self, name, value): + """Forcibly sets trait attribute, including read-only attributes.""" + cls = self.__class__ + if not self.has_trait(name): + raise TraitError(f"Class {cls.__name__} does not have a trait named {name}") + else: + getattr(cls, name).set(self, value) + + @classmethod + def class_trait_names(cls, **metadata): + """Get a list of all the names of this class' traits. + + This method is just like the :meth:`trait_names` method, + but is unbound. + """ + return list(cls.class_traits(**metadata)) + + @classmethod + def class_traits(cls, **metadata): + """Get a ``dict`` of all the traits of this class. The dictionary + is keyed on the name and the values are the TraitType objects. + + This method is just like the :meth:`traits` method, but is unbound. + + The TraitTypes returned don't know anything about the values + that the various HasTrait's instances are holding. + + The metadata kwargs allow functions to be passed in which + filter traits based on metadata values. The functions should + take a single value as an argument and return a boolean. If + any function returns False, then the trait is not included in + the output. If a metadata key doesn't exist, None will be passed + to the function. + """ + traits = cls._traits.copy() + + if len(metadata) == 0: + return traits + + result = {} + for name, trait in traits.items(): + for meta_name, meta_eval in metadata.items(): + if not callable(meta_eval): + meta_eval = _SimpleTest(meta_eval) + if not meta_eval(trait.metadata.get(meta_name, None)): + break + else: + result[name] = trait + + return result + + @classmethod + def class_own_traits(cls, **metadata): + """Get a dict of all the traitlets defined on this class, not a parent. + + Works like `class_traits`, except for excluding traits from parents. + """ + sup = super(cls, cls) + return { + n: t + for (n, t) in cls.class_traits(**metadata).items() + if getattr(sup, n, None) is not t + } + + def has_trait(self, name): + """Returns True if the object has a trait with the specified name.""" + return name in self._traits + + def trait_has_value(self, name): + """Returns True if the specified trait has a value. + + This will return false even if ``getattr`` would return a + dynamically generated default value. These default values + will be recognized as existing only after they have been + generated. + + Example + + .. code-block:: python + + class MyClass(HasTraits): + i = Int() + + mc = MyClass() + assert not mc.trait_has_value("i") + mc.i # generates a default value + assert mc.trait_has_value("i") + """ + return name in self._trait_values + + def trait_values(self, **metadata): + """A ``dict`` of trait names and their values. + + The metadata kwargs allow functions to be passed in which + filter traits based on metadata values. The functions should + take a single value as an argument and return a boolean. If + any function returns False, then the trait is not included in + the output. If a metadata key doesn't exist, None will be passed + to the function. + + Returns + ------- + A ``dict`` of trait names and their values. + + Notes + ----- + Trait values are retrieved via ``getattr``, any exceptions raised + by traits or the operations they may trigger will result in the + absence of a trait value in the result ``dict``. + """ + return {name: getattr(self, name) for name in self.trait_names(**metadata)} + + def _get_trait_default_generator(self, name): + """Return default generator for a given trait + + Walk the MRO to resolve the correct default generator according to inheritance. + """ + method_name = "_%s_default" % name + if method_name in self.__dict__: + return getattr(self, method_name) + if method_name in self.__class__.__dict__: + return getattr(self.__class__, method_name) + return self._all_trait_default_generators[name] + + def trait_defaults(self, *names, **metadata): + """Return a trait's default value or a dictionary of them + + Notes + ----- + Dynamically generated default values may + depend on the current state of the object.""" + for n in names: + if not self.has_trait(n): + raise TraitError(f"'{n}' is not a trait of '{type(self).__name__}' instances") + + if len(names) == 1 and len(metadata) == 0: + return self._get_trait_default_generator(names[0])(self) + + trait_names = self.trait_names(**metadata) + trait_names.extend(names) + + defaults = {} + for n in trait_names: + defaults[n] = self._get_trait_default_generator(n)(self) + return defaults + + def trait_names(self, **metadata): + """Get a list of all the names of this class' traits.""" + return list(self.traits(**metadata)) + + def traits(self, **metadata): + """Get a ``dict`` of all the traits of this class. The dictionary + is keyed on the name and the values are the TraitType objects. + + The TraitTypes returned don't know anything about the values + that the various HasTrait's instances are holding. + + The metadata kwargs allow functions to be passed in which + filter traits based on metadata values. The functions should + take a single value as an argument and return a boolean. If + any function returns False, then the trait is not included in + the output. If a metadata key doesn't exist, None will be passed + to the function. + """ + traits = self._traits.copy() + + if len(metadata) == 0: + return traits + + result = {} + for name, trait in traits.items(): + for meta_name, meta_eval in metadata.items(): + if not callable(meta_eval): + meta_eval = _SimpleTest(meta_eval) + if not meta_eval(trait.metadata.get(meta_name, None)): + break + else: + result[name] = trait + + return result + + def trait_metadata(self, traitname, key, default=None): + """Get metadata values for trait by key.""" + try: + trait = getattr(self.__class__, traitname) + except AttributeError as e: + raise TraitError( + f"Class {self.__class__.__name__} does not have a trait named {traitname}" + ) from e + metadata_name = "_" + traitname + "_metadata" + if hasattr(self, metadata_name) and key in getattr(self, metadata_name): + return getattr(self, metadata_name).get(key, default) + else: + return trait.metadata.get(key, default) + + @classmethod + def class_own_trait_events(cls, name): + """Get a dict of all event handlers defined on this class, not a parent. + + Works like ``event_handlers``, except for excluding traits from parents. + """ + sup = super(cls, cls) + return { + n: e + for (n, e) in cls.events(name).items() # type:ignore[attr-defined] + if getattr(sup, n, None) is not e + } + + @classmethod + def trait_events(cls, name=None): + """Get a ``dict`` of all the event handlers of this class. + + Parameters + ---------- + name : str (default: None) + The name of a trait of this class. If name is ``None`` then all + the event handlers of this class will be returned instead. + + Returns + ------- + The event handlers associated with a trait name, or all event handlers. + """ + events = {} + for k, v in getmembers(cls): + if isinstance(v, EventHandler): + if name is None: + events[k] = v + elif name in v.trait_names: # type:ignore[attr-defined] + events[k] = v + elif hasattr(v, "tags"): + if cls.trait_names(**v.tags): + events[k] = v + return events + + +# ----------------------------------------------------------------------------- +# Actual TraitTypes implementations/subclasses +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# TraitTypes subclasses for handling classes and instances of classes +# ----------------------------------------------------------------------------- + + +class ClassBasedTraitType(TraitType): + """ + A trait with error reporting and string -> type resolution for Type, + Instance and This. + """ + + def _resolve_string(self, string): + """ + Resolve a string supplied for a type into an actual object. + """ + return import_item(string) + + +class Type(ClassBasedTraitType): + """A trait whose value must be a subclass of a specified class.""" + + def __init__(self, default_value=Undefined, klass=None, **kwargs): + """Construct a Type trait + + A Type trait specifies that its values must be subclasses of + a particular class. + + If only ``default_value`` is given, it is used for the ``klass`` as + well. If neither are given, both default to ``object``. + + Parameters + ---------- + default_value : class, str or None + The default value must be a subclass of klass. If an str, + the str must be a fully specified class name, like 'foo.bar.Bah'. + The string is resolved into real class, when the parent + :class:`HasTraits` class is instantiated. + klass : class, str [ default object ] + Values of this trait must be a subclass of klass. The klass + may be specified in a string like: 'foo.bar.MyClass'. + The string is resolved into real class, when the parent + :class:`HasTraits` class is instantiated. + allow_none : bool [ default False ] + Indicates whether None is allowed as an assignable value. + **kwargs + extra kwargs passed to `ClassBasedTraitType` + """ + if default_value is Undefined: + new_default_value = object if (klass is None) else klass + else: + new_default_value = default_value + + if klass is None: + if (default_value is None) or (default_value is Undefined): + klass = object + else: + klass = default_value + + if not (inspect.isclass(klass) or isinstance(klass, str)): + raise TraitError("A Type trait must specify a class.") + + self.klass = klass + + super().__init__(new_default_value, **kwargs) + + def validate(self, obj, value): + """Validates that the value is a valid object instance.""" + if isinstance(value, str): + try: + value = self._resolve_string(value) + except ImportError as e: + raise TraitError( + "The '%s' trait of %s instance must be a type, but " + "%r could not be imported" % (self.name, obj, value) + ) from e + try: + if issubclass(value, self.klass): # type:ignore[arg-type] + return value + except Exception: + pass + + self.error(obj, value) + + def info(self): + """Returns a description of the trait.""" + if isinstance(self.klass, str): + klass = self.klass + else: + klass = self.klass.__module__ + "." + self.klass.__name__ + result = "a subclass of '%s'" % klass + if self.allow_none: + return result + " or None" + return result + + def instance_init(self, obj): + # we can't do this in subclass_init because that + # might be called before all imports are done. + self._resolve_classes() + + def _resolve_classes(self): + if isinstance(self.klass, str): + self.klass = self._resolve_string(self.klass) + if isinstance(self.default_value, str): + self.default_value = self._resolve_string(self.default_value) + + def default_value_repr(self): + value = self.default_value + assert value is not None + if isinstance(value, str): + return repr(value) + else: + return repr(f"{value.__module__}.{value.__name__}") + + +class Instance(ClassBasedTraitType): + """A trait whose value must be an instance of a specified class. + + The value can also be an instance of a subclass of the specified class. + + Subclasses can declare default classes by overriding the klass attribute + """ + + klass = None + + def __init__(self, klass=None, args=None, kw=None, **kwargs): + """Construct an Instance trait. + + This trait allows values that are instances of a particular + class or its subclasses. Our implementation is quite different + from that of enthough.traits as we don't allow instances to be used + for klass and we handle the ``args`` and ``kw`` arguments differently. + + Parameters + ---------- + klass : class, str + The class that forms the basis for the trait. Class names + can also be specified as strings, like 'foo.bar.Bar'. + args : tuple + Positional arguments for generating the default value. + kw : dict + Keyword arguments for generating the default value. + allow_none : bool [ default False ] + Indicates whether None is allowed as a value. + **kwargs + Extra kwargs passed to `ClassBasedTraitType` + + Notes + ----- + If both ``args`` and ``kw`` are None, then the default value is None. + If ``args`` is a tuple and ``kw`` is a dict, then the default is + created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is + None, the None is replaced by ``()`` or ``{}``, respectively. + """ + if klass is None: + klass = self.klass + + if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, str)): + self.klass = klass + else: + raise TraitError("The klass attribute must be a class not: %r" % klass) + + if (kw is not None) and not isinstance(kw, dict): + raise TraitError("The 'kw' argument must be a dict or None.") + if (args is not None) and not isinstance(args, tuple): + raise TraitError("The 'args' argument must be a tuple or None.") + + self.default_args = args + self.default_kwargs = kw + + super().__init__(**kwargs) + + def validate(self, obj, value): + assert self.klass is not None + if isinstance(value, self.klass): # type:ignore[arg-type] + return value + else: + self.error(obj, value) + + def info(self): + if isinstance(self.klass, str): + result = add_article(self.klass) + else: + result = describe("a", self.klass) + if self.allow_none: + result += " or None" + return result + + def instance_init(self, obj): + # we can't do this in subclass_init because that + # might be called before all imports are done. + self._resolve_classes() + + def _resolve_classes(self): + if isinstance(self.klass, str): + self.klass = self._resolve_string(self.klass) + + def make_dynamic_default(self): + if (self.default_args is None) and (self.default_kwargs is None): + return None + assert self.klass is not None + return self.klass( + *(self.default_args or ()), **(self.default_kwargs or {}) + ) # type:ignore[operator] + + def default_value_repr(self): + return repr(self.make_dynamic_default()) + + def from_string(self, s): + return _safe_literal_eval(s) + + +class ForwardDeclaredMixin: + """ + Mixin for forward-declared versions of Instance and Type. + """ + + def _resolve_string(self, string): + """ + Find the specified class name by looking for it in the module in which + our this_class attribute was defined. + """ + modname = self.this_class.__module__ # type:ignore[attr-defined] + return import_item(".".join([modname, string])) + + +class ForwardDeclaredType(ForwardDeclaredMixin, Type): + """ + Forward-declared version of Type. + """ + + pass + + +class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance): + """ + Forward-declared version of Instance. + """ + + pass + + +class This(ClassBasedTraitType): + """A trait for instances of the class containing this trait. + + Because how how and when class bodies are executed, the ``This`` + trait can only have a default value of None. This, and because we + always validate default values, ``allow_none`` is *always* true. + """ + + info_text = "an instance of the same type as the receiver or None" + + def __init__(self, **kwargs): + super().__init__(None, **kwargs) + + def validate(self, obj, value): + # What if value is a superclass of obj.__class__? This is + # complicated if it was the superclass that defined the This + # trait. + assert self.this_class is not None + if isinstance(value, self.this_class) or (value is None): + return value + else: + self.error(obj, value) + + +class Union(TraitType): + """A trait type representing a Union type.""" + + def __init__(self, trait_types, **kwargs): + """Construct a Union trait. + + This trait allows values that are allowed by at least one of the + specified trait types. A Union traitlet cannot have metadata on + its own, besides the metadata of the listed types. + + Parameters + ---------- + trait_types : sequence + The list of trait types of length at least 1. + **kwargs + Extra kwargs passed to `TraitType` + + Notes + ----- + Union([Float(), Bool(), Int()]) attempts to validate the provided values + with the validation function of Float, then Bool, and finally Int. + + Parsing from string is ambiguous for container types which accept other + collection-like literals (e.g. List accepting both `[]` and `()` + precludes Union from ever parsing ``Union([List(), Tuple()])`` as a tuple; + you can modify behaviour of too permissive container traits by overriding + ``_literal_from_string_pairs`` in subclasses. + Similarly, parsing unions of numeric types is only unambiguous if + types are provided in order of increasing permissiveness, e.g. + ``Union([Int(), Float()])`` (since floats accept integer-looking values). + """ + self.trait_types = list(trait_types) + self.info_text = " or ".join([tt.info() for tt in self.trait_types]) + super().__init__(**kwargs) + + def default(self, obj=None): + default = super().default(obj) + for trait in self.trait_types: + if default is Undefined: + default = trait.default(obj) + else: + break + return default + + def class_init(self, cls, name): + for trait_type in reversed(self.trait_types): + trait_type.class_init(cls, None) + super().class_init(cls, name) + + def subclass_init(self, cls): + for trait_type in reversed(self.trait_types): + trait_type.subclass_init(cls) + # explicitly not calling super().subclass_init(cls) + # to opt out of instance_init + + def validate(self, obj, value): + with obj.cross_validation_lock: + for trait_type in self.trait_types: + try: + v = trait_type._validate(obj, value) + # In the case of an element trait, the name is None + if self.name is not None: + setattr(obj, "_" + self.name + "_metadata", trait_type.metadata) + return v + except TraitError: + continue + self.error(obj, value) + + def __or__(self, other): + if isinstance(other, Union): + return Union(self.trait_types + other.trait_types) + else: + return Union(self.trait_types + [other]) + + def from_string(self, s): + for trait_type in self.trait_types: + try: + v = trait_type.from_string(s) + return trait_type.validate(None, v) + except (TraitError, ValueError): + continue + return super().from_string(s) + + +# ----------------------------------------------------------------------------- +# Basic TraitTypes implementations/subclasses +# ----------------------------------------------------------------------------- + + +class Any(TraitType): + """A trait which allows any value.""" + + default_value: t.Optional[t.Any] = None + allow_none = True + info_text = "any value" + + def subclass_init(self, cls): + pass # fully opt out of instance_init + + +def _validate_bounds(trait, obj, value): + """ + Validate that a number to be applied to a trait is between bounds. + + If value is not between min_bound and max_bound, this raises a + TraitError with an error message appropriate for this trait. + """ + if trait.min is not None and value < trait.min: + raise TraitError( + "The value of the '{name}' trait of {klass} instance should " + "not be less than {min_bound}, but a value of {value} was " + "specified".format( + name=trait.name, klass=class_of(obj), value=value, min_bound=trait.min + ) + ) + if trait.max is not None and value > trait.max: + raise TraitError( + "The value of the '{name}' trait of {klass} instance should " + "not be greater than {max_bound}, but a value of {value} was " + "specified".format( + name=trait.name, klass=class_of(obj), value=value, max_bound=trait.max + ) + ) + return value + + +class Int(TraitType): + """An int trait.""" + + default_value = 0 + info_text = "an int" + + def __init__(self, default_value=Undefined, allow_none=False, **kwargs): + self.min = kwargs.pop("min", None) + self.max = kwargs.pop("max", None) + super().__init__(default_value=default_value, allow_none=allow_none, **kwargs) + + def validate(self, obj, value): + if not isinstance(value, int): + self.error(obj, value) + return _validate_bounds(self, obj, value) + + def from_string(self, s): + if self.allow_none and s == "None": + return None + return int(s) + + def subclass_init(self, cls): + pass # fully opt out of instance_init + + +class CInt(Int): + """A casting version of the int trait.""" + + def validate(self, obj, value): + try: + value = int(value) + except Exception: + self.error(obj, value) + return _validate_bounds(self, obj, value) + + +Long, CLong = Int, CInt +Integer = Int + + +class Float(TraitType): + """A float trait.""" + + default_value = 0.0 + info_text = "a float" + + def __init__(self, default_value=Undefined, allow_none=False, **kwargs): + self.min = kwargs.pop("min", -float("inf")) + self.max = kwargs.pop("max", float("inf")) + super().__init__(default_value=default_value, allow_none=allow_none, **kwargs) + + def validate(self, obj, value): + if isinstance(value, int): + value = float(value) + if not isinstance(value, float): + self.error(obj, value) + return _validate_bounds(self, obj, value) + + def from_string(self, s): + if self.allow_none and s == "None": + return None + return float(s) + + def subclass_init(self, cls): + pass # fully opt out of instance_init + + +class CFloat(Float): + """A casting version of the float trait.""" + + def validate(self, obj, value): + try: + value = float(value) + except Exception: + self.error(obj, value) + return _validate_bounds(self, obj, value) + + +class Complex(TraitType): + """A trait for complex numbers.""" + + default_value = 0.0 + 0.0j + info_text = "a complex number" + + def validate(self, obj, value): + if isinstance(value, complex): + return value + if isinstance(value, (float, int)): + return complex(value) + self.error(obj, value) + + def from_string(self, s): + if self.allow_none and s == "None": + return None + return complex(s) + + def subclass_init(self, cls): + pass # fully opt out of instance_init + + +class CComplex(Complex): + """A casting version of the complex number trait.""" + + def validate(self, obj, value): + try: + return complex(value) + except Exception: + self.error(obj, value) + + +# We should always be explicit about whether we're using bytes or unicode, both +# for Python 3 conversion and for reliable unicode behaviour on Python 2. So +# we don't have a Str type. +class Bytes(TraitType): + """A trait for byte strings.""" + + default_value = b"" + info_text = "a bytes object" + + def validate(self, obj, value): + if isinstance(value, bytes): + return value + self.error(obj, value) + + def from_string(self, s): + if self.allow_none and s == "None": + return None + if len(s) >= 3: + # handle deprecated b"string" + for quote in ('"', "'"): + if s[:2] == f"b{quote}" and s[-1] == quote: + old_s = s + s = s[2:-1] + warn( + "Supporting extra quotes around Bytes is deprecated in traitlets 5.0. " + "Use %r instead of %r." % (s, old_s), + FutureWarning, + ) + break + return s.encode("utf8") + + def subclass_init(self, cls): + pass # fully opt out of instance_init + + +class CBytes(Bytes): + """A casting version of the byte string trait.""" + + def validate(self, obj, value): + try: + return bytes(value) + except Exception: + self.error(obj, value) + + +class Unicode(TraitType): + """A trait for unicode strings.""" + + default_value = "" + info_text = "a unicode string" + + def validate(self, obj, value): + if isinstance(value, str): + return value + if isinstance(value, bytes): + try: + return value.decode("ascii", "strict") + except UnicodeDecodeError as e: + msg = "Could not decode {!r} for unicode trait '{}' of {} instance." + raise TraitError(msg.format(value, self.name, class_of(obj))) from e + self.error(obj, value) + + def from_string(self, s): + if self.allow_none and s == "None": + return None + s = os.path.expanduser(s) + if len(s) >= 2: + # handle deprecated "1" + for c in ('"', "'"): + if s[0] == s[-1] == c: + old_s = s + s = s[1:-1] + warn( + "Supporting extra quotes around strings is deprecated in traitlets 5.0. " + "You can use %r instead of %r if you require traitlets >=5." % (s, old_s), + FutureWarning, + ) + return s + + def subclass_init(self, cls): + pass # fully opt out of instance_init + + +class CUnicode(Unicode): + """A casting version of the unicode trait.""" + + def validate(self, obj, value): + try: + return str(value) + except Exception: + self.error(obj, value) + + +class ObjectName(TraitType): + """A string holding a valid object name in this version of Python. + + This does not check that the name exists in any scope.""" + + info_text = "a valid object identifier in Python" + + coerce_str = staticmethod(lambda _, s: s) # type:ignore[no-any-return] + + def validate(self, obj, value): + value = self.coerce_str(obj, value) + + if isinstance(value, str) and isidentifier(value): + return value + self.error(obj, value) + + def from_string(self, s): + if self.allow_none and s == "None": + return None + return s + + +class DottedObjectName(ObjectName): + """A string holding a valid dotted object name in Python, such as A.b3._c""" + + def validate(self, obj, value): + value = self.coerce_str(obj, value) + + if isinstance(value, str) and all(isidentifier(a) for a in value.split(".")): + return value + self.error(obj, value) + + +class Bool(TraitType): + """A boolean (True, False) trait.""" + + default_value = False + info_text = "a boolean" + + def validate(self, obj, value): + if isinstance(value, bool): + return value + elif isinstance(value, int): + if value == 1: + return True + elif value == 0: + return False + self.error(obj, value) + + def from_string(self, s): + if self.allow_none and s == "None": + return None + s = s.lower() + if s in {"true", "1"}: + return True + elif s in {"false", "0"}: + return False + else: + raise ValueError("%r is not 1, 0, true, or false") + + def subclass_init(self, cls): + pass # fully opt out of instance_init + + def argcompleter(self, **kwargs): + """Completion hints for argcomplete""" + completions = ["true", "1", "false", "0"] + if self.allow_none: + completions.append("None") + return completions + + +class CBool(Bool): + """A casting version of the boolean trait.""" + + def validate(self, obj, value): + try: + return bool(value) + except Exception: + self.error(obj, value) + + +class Enum(TraitType): + """An enum whose value must be in a given sequence.""" + + def __init__(self, values, default_value=Undefined, **kwargs): + self.values = values + if kwargs.get("allow_none", False) and default_value is Undefined: + default_value = None + super().__init__(default_value, **kwargs) + + def validate(self, obj, value): + if value in self.values: + return value + self.error(obj, value) + + def _choices_str(self, as_rst=False): + """Returns a description of the trait choices (not none).""" + choices = self.values + if as_rst: + choices = "|".join("``%r``" % x for x in choices) + else: + choices = repr(list(choices)) + return choices + + def _info(self, as_rst=False): + """Returns a description of the trait.""" + none = " or %s" % ("`None`" if as_rst else "None") if self.allow_none else "" + return f"any of {self._choices_str(as_rst)}{none}" + + def info(self): + return self._info(as_rst=False) + + def info_rst(self): + return self._info(as_rst=True) + + def from_string(self, s): + try: + return self.validate(None, s) + except TraitError: + return _safe_literal_eval(s) + + def subclass_init(self, cls): + pass # fully opt out of instance_init + + def argcompleter(self, **kwargs): + """Completion hints for argcomplete""" + return [str(v) for v in self.values] + + +class CaselessStrEnum(Enum): + """An enum of strings where the case should be ignored.""" + + def __init__(self, values, default_value=Undefined, **kwargs): + super().__init__(values, default_value=default_value, **kwargs) + + def validate(self, obj, value): + if not isinstance(value, str): + self.error(obj, value) + + for v in self.values: + if v.lower() == value.lower(): + return v + self.error(obj, value) + + def _info(self, as_rst=False): + """Returns a description of the trait.""" + none = " or %s" % ("`None`" if as_rst else "None") if self.allow_none else "" + return f"any of {self._choices_str(as_rst)} (case-insensitive){none}" + + def info(self): + return self._info(as_rst=False) + + def info_rst(self): + return self._info(as_rst=True) + + +class FuzzyEnum(Enum): + """An case-ignoring enum matching choices by unique prefixes/substrings.""" + + case_sensitive = False + #: If True, choices match anywhere in the string, otherwise match prefixes. + substring_matching = False + + def __init__( + self, + values, + default_value=Undefined, + case_sensitive=False, + substring_matching=False, + **kwargs, + ): + self.case_sensitive = case_sensitive + self.substring_matching = substring_matching + super().__init__(values, default_value=default_value, **kwargs) + + def validate(self, obj, value): + if not isinstance(value, str): + self.error(obj, value) + + conv_func = (lambda c: c) if self.case_sensitive else lambda c: c.lower() + substring_matching = self.substring_matching + match_func = ( + (lambda v, c: v in c) + if substring_matching + else (lambda v, c: c.startswith(v)) # type:ignore[no-any-return] + ) + value = conv_func(value) + choices = self.values + matches = [match_func(value, conv_func(c)) for c in choices] + if sum(matches) == 1: + for v, m in zip(choices, matches): + if m: + return v + + self.error(obj, value) + + def _info(self, as_rst=False): + """Returns a description of the trait.""" + none = " or %s" % ("`None`" if as_rst else "None") if self.allow_none else "" + case = "sensitive" if self.case_sensitive else "insensitive" + substr = "substring" if self.substring_matching else "prefix" + return f"any case-{case} {substr} of {self._choices_str(as_rst)}{none}" + + def info(self): + return self._info(as_rst=False) + + def info_rst(self): + return self._info(as_rst=True) + + +class Container(Instance): + """An instance of a container (list, set, etc.) + + To be subclassed by overriding klass. + """ + + klass: t.Optional[t.Union[str, t.Type[t.Any]]] = None + _cast_types: t.Any = () + _valid_defaults = SequenceTypes + _trait = None + _literal_from_string_pairs: t.Any = ("[]", "()") + + def __init__(self, trait=None, default_value=Undefined, **kwargs): + """Create a container trait type from a list, set, or tuple. + + The default value is created by doing ``List(default_value)``, + which creates a copy of the ``default_value``. + + ``trait`` can be specified, which restricts the type of elements + in the container to that TraitType. + + If only one arg is given and it is not a Trait, it is taken as + ``default_value``: + + ``c = List([1, 2, 3])`` + + Parameters + ---------- + trait : TraitType [ optional ] + the type for restricting the contents of the Container. If unspecified, + types are not checked. + default_value : SequenceType [ optional ] + The default value for the Trait. Must be list/tuple/set, and + will be cast to the container type. + allow_none : bool [ default False ] + Whether to allow the value to be None + **kwargs : any + further keys for extensions to the Trait (e.g. config) + + """ + + # allow List([values]): + if trait is not None and default_value is Undefined and not is_trait(trait): + default_value = trait + trait = None + + if default_value is None and not kwargs.get("allow_none", False): + # improve backward-compatibility for possible subclasses + # specifying default_value=None as default, + # keeping 'unspecified' behavior (i.e. empty container) + warn( + f"Specifying {self.__class__.__name__}(default_value=None)" + " for no default is deprecated in traitlets 5.0.5." + " Use default_value=Undefined", + DeprecationWarning, + stacklevel=2, + ) + default_value = Undefined + + if default_value is Undefined: + args: t.Any = () + elif default_value is None: + # default_value back on kwargs for super() to handle + args = () + kwargs["default_value"] = None + elif isinstance(default_value, self._valid_defaults): + args = (default_value,) + else: + raise TypeError(f"default value of {self.__class__.__name__} was {default_value}") + + if is_trait(trait): + if isinstance(trait, type): + warn( + "Traits should be given as instances, not types (for example, `Int()`, not `Int`)." + " Passing types is deprecated in traitlets 4.1.", + DeprecationWarning, + stacklevel=3, + ) + self._trait = trait() if isinstance(trait, type) else trait + elif trait is not None: + raise TypeError("`trait` must be a Trait or None, got %s" % repr_type(trait)) + + super().__init__(klass=self.klass, args=args, **kwargs) + + def validate(self, obj, value): + if isinstance(value, self._cast_types): + assert self.klass is not None + value = self.klass(value) # type:ignore[operator] + value = super().validate(obj, value) + if value is None: + return value + + value = self.validate_elements(obj, value) + + return value + + def validate_elements(self, obj, value): + validated = [] + if self._trait is None or isinstance(self._trait, Any): + return value + for v in value: + try: + v = self._trait._validate(obj, v) + except TraitError as error: + self.error(obj, v, error) + else: + validated.append(v) + assert self.klass is not None + return self.klass(validated) # type:ignore[operator] + + def class_init(self, cls, name): + if isinstance(self._trait, TraitType): + self._trait.class_init(cls, None) + super().class_init(cls, name) + + def subclass_init(self, cls): + if isinstance(self._trait, TraitType): + self._trait.subclass_init(cls) + # explicitly not calling super().subclass_init(cls) + # to opt out of instance_init + + def from_string(self, s): + """Load value from a single string""" + if not isinstance(s, str): + raise TraitError(f"Expected string, got {s!r}") + try: + test = literal_eval(s) + except Exception: + test = None + return self.validate(None, test) + + def from_string_list(self, s_list): + """Return the value from a list of config strings + + This is where we parse CLI configuration + """ + assert self.klass is not None + if len(s_list) == 1: + # check for deprecated --Class.trait="['a', 'b', 'c']" + r = s_list[0] + if r == "None" and self.allow_none: + return None + if len(r) >= 2 and any( + r.startswith(start) and r.endswith(end) + for start, end in self._literal_from_string_pairs + ): + if self.this_class: + clsname = self.this_class.__name__ + "." + else: + clsname = "" + assert self.name is not None + warn( + "--{0}={1} for containers is deprecated in traitlets 5.0. " + "You can pass `--{0} item` ... multiple times to add items to a list.".format( + clsname + self.name, r + ), + FutureWarning, + ) + return self.klass(literal_eval(r)) # type:ignore[operator] + sig = inspect.signature(self.item_from_string) + if "index" in sig.parameters: + item_from_string = self.item_from_string + else: + # backward-compat: allow item_from_string to ignore index arg + item_from_string = lambda s, index=None: self.item_from_string(s) # noqa[E371] + + return self.klass( + [item_from_string(s, index=idx) for idx, s in enumerate(s_list)] + ) # type:ignore[operator] + + def item_from_string(self, s, index=None): + """Cast a single item from a string + + Evaluated when parsing CLI configuration from a string + """ + if self._trait: + return self._trait.from_string(s) + else: + return s + + +class List(Container): + """An instance of a Python list.""" + + klass = list + _cast_types: t.Any = (tuple,) + + def __init__( + self, + trait=None, + default_value=Undefined, + minlen=0, + maxlen=sys.maxsize, + **kwargs, + ): + """Create a List trait type from a list, set, or tuple. + + The default value is created by doing ``list(default_value)``, + which creates a copy of the ``default_value``. + + ``trait`` can be specified, which restricts the type of elements + in the container to that TraitType. + + If only one arg is given and it is not a Trait, it is taken as + ``default_value``: + + ``c = List([1, 2, 3])`` + + Parameters + ---------- + trait : TraitType [ optional ] + the type for restricting the contents of the Container. + If unspecified, types are not checked. + default_value : SequenceType [ optional ] + The default value for the Trait. Must be list/tuple/set, and + will be cast to the container type. + minlen : Int [ default 0 ] + The minimum length of the input list + maxlen : Int [ default sys.maxsize ] + The maximum length of the input list + """ + self._minlen = minlen + self._maxlen = maxlen + super().__init__(trait=trait, default_value=default_value, **kwargs) + + def length_error(self, obj, value): + e = ( + "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." + % (self.name, class_of(obj), self._minlen, self._maxlen, value) + ) + raise TraitError(e) + + def validate_elements(self, obj, value): + length = len(value) + if length < self._minlen or length > self._maxlen: + self.length_error(obj, value) + + return super().validate_elements(obj, value) + + def set(self, obj, value): + if isinstance(value, str): + return super().set(obj, [value]) + else: + return super().set(obj, value) + + +class Set(List): + """An instance of a Python set.""" + + klass = set # type:ignore[assignment] + _cast_types = (tuple, list) + + _literal_from_string_pairs = ("[]", "()", "{}") + + # Redefine __init__ just to make the docstring more accurate. + def __init__( + self, + trait=None, + default_value=Undefined, + minlen=0, + maxlen=sys.maxsize, + **kwargs, + ): + """Create a Set trait type from a list, set, or tuple. + + The default value is created by doing ``set(default_value)``, + which creates a copy of the ``default_value``. + + ``trait`` can be specified, which restricts the type of elements + in the container to that TraitType. + + If only one arg is given and it is not a Trait, it is taken as + ``default_value``: + + ``c = Set({1, 2, 3})`` + + Parameters + ---------- + trait : TraitType [ optional ] + the type for restricting the contents of the Container. + If unspecified, types are not checked. + default_value : SequenceType [ optional ] + The default value for the Trait. Must be list/tuple/set, and + will be cast to the container type. + minlen : Int [ default 0 ] + The minimum length of the input list + maxlen : Int [ default sys.maxsize ] + The maximum length of the input list + """ + super().__init__(trait, default_value, minlen, maxlen, **kwargs) + + def default_value_repr(self): + # Ensure default value is sorted for a reproducible build + list_repr = repr(sorted(self.make_dynamic_default())) + if list_repr == "[]": + return "set()" + return "{" + list_repr[1:-1] + "}" + + +class Tuple(Container): + """An instance of a Python tuple.""" + + klass = tuple + _cast_types = (list,) + + def __init__(self, *traits, **kwargs): + """Create a tuple from a list, set, or tuple. + + Create a fixed-type tuple with Traits: + + ``t = Tuple(Int(), Str(), CStr())`` + + would be length 3, with Int,Str,CStr for each element. + + If only one arg is given and it is not a Trait, it is taken as + default_value: + + ``t = Tuple((1, 2, 3))`` + + Otherwise, ``default_value`` *must* be specified by keyword. + + Parameters + ---------- + *traits : TraitTypes [ optional ] + the types for restricting the contents of the Tuple. If unspecified, + types are not checked. If specified, then each positional argument + corresponds to an element of the tuple. Tuples defined with traits + are of fixed length. + default_value : SequenceType [ optional ] + The default value for the Tuple. Must be list/tuple/set, and + will be cast to a tuple. If ``traits`` are specified, + ``default_value`` must conform to the shape and type they specify. + **kwargs + Other kwargs passed to `Container` + """ + default_value = kwargs.pop("default_value", Undefined) + # allow Tuple((values,)): + if len(traits) == 1 and default_value is Undefined and not is_trait(traits[0]): + default_value = traits[0] + traits = () + + if default_value is None and not kwargs.get("allow_none", False): + # improve backward-compatibility for possible subclasses + # specifying default_value=None as default, + # keeping 'unspecified' behavior (i.e. empty container) + warn( + f"Specifying {self.__class__.__name__}(default_value=None)" + " for no default is deprecated in traitlets 5.0.5." + " Use default_value=Undefined", + DeprecationWarning, + stacklevel=2, + ) + default_value = Undefined + + if default_value is Undefined: + args: t.Any = () + elif default_value is None: + # default_value back on kwargs for super() to handle + args = () + kwargs["default_value"] = None + elif isinstance(default_value, self._valid_defaults): + args = (default_value,) + else: + raise TypeError(f"default value of {self.__class__.__name__} was {default_value}") + + self._traits = [] + for trait in traits: + if isinstance(trait, type): + warn( + "Traits should be given as instances, not types (for example, `Int()`, not `Int`)" + " Passing types is deprecated in traitlets 4.1.", + DeprecationWarning, + stacklevel=2, + ) + trait = trait() + self._traits.append(trait) + + if self._traits and (default_value is None or default_value is Undefined): + # don't allow default to be an empty container if length is specified + args = None + super(Container, self).__init__(klass=self.klass, args=args, **kwargs) + + def item_from_string(self, s, index): + """Cast a single item from a string + + Evaluated when parsing CLI configuration from a string + """ + if not self._traits or index >= len(self._traits): + # return s instead of raising index error + # length errors will be raised later on validation + return s + return self._traits[index].from_string(s) + + def validate_elements(self, obj, value): + if not self._traits: + # nothing to validate + return value + if len(value) != len(self._traits): + e = ( + "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." + % (self.name, class_of(obj), len(self._traits), repr_type(value)) + ) + raise TraitError(e) + + validated = [] + for trait, v in zip(self._traits, value): + try: + v = trait._validate(obj, v) + except TraitError as error: + self.error(obj, v, error) + else: + validated.append(v) + return tuple(validated) + + def class_init(self, cls, name): + for trait in self._traits: + if isinstance(trait, TraitType): + trait.class_init(cls, None) + super(Container, self).class_init(cls, name) + + def subclass_init(self, cls): + for trait in self._traits: + if isinstance(trait, TraitType): + trait.subclass_init(cls) + # explicitly not calling super().subclass_init(cls) + # to opt out of instance_init + + +class Dict(Instance): + """An instance of a Python dict. + + One or more traits can be passed to the constructor + to validate the keys and/or values of the dict. + If you need more detailed validation, + you may use a custom validator method. + + .. versionchanged:: 5.0 + Added key_trait for validating dict keys. + + .. versionchanged:: 5.0 + Deprecated ambiguous ``trait``, ``traits`` args in favor of ``value_trait``, ``per_key_traits``. + """ + + _value_trait = None + _key_trait = None + + def __init__( + self, + value_trait=None, + per_key_traits=None, + key_trait=None, + default_value=Undefined, + **kwargs, + ): + """Create a dict trait type from a Python dict. + + The default value is created by doing ``dict(default_value)``, + which creates a copy of the ``default_value``. + + Parameters + ---------- + value_trait : TraitType [ optional ] + The specified trait type to check and use to restrict the values of + the dict. If unspecified, values are not checked. + per_key_traits : Dictionary of {keys:trait types} [ optional, keyword-only ] + A Python dictionary containing the types that are valid for + restricting the values of the dict on a per-key basis. + Each value in this dict should be a Trait for validating + key_trait : TraitType [ optional, keyword-only ] + The type for restricting the keys of the dict. If + unspecified, the types of the keys are not checked. + default_value : SequenceType [ optional, keyword-only ] + The default value for the Dict. Must be dict, tuple, or None, and + will be cast to a dict if not None. If any key or value traits are specified, + the `default_value` must conform to the constraints. + + Examples + -------- + a dict whose values must be text + >>> d = Dict(Unicode()) + + d2['n'] must be an integer + d2['s'] must be text + >>> d2 = Dict(per_key_traits={"n": Integer(), "s": Unicode()}) + + d3's keys must be text + d3's values must be integers + >>> d3 = Dict(value_trait=Integer(), key_trait=Unicode()) + + """ + + # handle deprecated keywords + trait = kwargs.pop("trait", None) + if trait is not None: + if value_trait is not None: + raise TypeError( + "Found a value for both `value_trait` and its deprecated alias `trait`." + ) + value_trait = trait + warn( + "Keyword `trait` is deprecated in traitlets 5.0, use `value_trait` instead", + DeprecationWarning, + stacklevel=2, + ) + traits = kwargs.pop("traits", None) + if traits is not None: + if per_key_traits is not None: + raise TypeError( + "Found a value for both `per_key_traits` and its deprecated alias `traits`." + ) + per_key_traits = traits + warn( + "Keyword `traits` is deprecated in traitlets 5.0, use `per_key_traits` instead", + DeprecationWarning, + stacklevel=2, + ) + + # Handling positional arguments + if default_value is Undefined and value_trait is not None: + if not is_trait(value_trait): + default_value = value_trait + value_trait = None + + if key_trait is None and per_key_traits is not None: + if is_trait(per_key_traits): + key_trait = per_key_traits + per_key_traits = None + + # Handling default value + if default_value is Undefined: + default_value = {} + if default_value is None: + args: t.Any = None + elif isinstance(default_value, dict): + args = (default_value,) + elif isinstance(default_value, SequenceTypes): + args = (default_value,) + else: + raise TypeError("default value of Dict was %s" % default_value) + + # Case where a type of TraitType is provided rather than an instance + if is_trait(value_trait): + if isinstance(value_trait, type): + warn( + "Traits should be given as instances, not types (for example, `Int()`, not `Int`)" + " Passing types is deprecated in traitlets 4.1.", + DeprecationWarning, + stacklevel=2, + ) + value_trait = value_trait() + self._value_trait = value_trait + elif value_trait is not None: + raise TypeError( + "`value_trait` must be a Trait or None, got %s" % repr_type(value_trait) + ) + + if is_trait(key_trait): + if isinstance(key_trait, type): + warn( + "Traits should be given as instances, not types (for example, `Int()`, not `Int`)" + " Passing types is deprecated in traitlets 4.1.", + DeprecationWarning, + stacklevel=2, + ) + key_trait = key_trait() + self._key_trait = key_trait + elif key_trait is not None: + raise TypeError("`key_trait` must be a Trait or None, got %s" % repr_type(key_trait)) + + self._per_key_traits = per_key_traits + + super().__init__(klass=dict, args=args, **kwargs) + + def element_error(self, obj, element, validator, side="Values"): + e = ( + side + + " of the '%s' trait of %s instance must be %s, but a value of %s was specified." + % (self.name, class_of(obj), validator.info(), repr_type(element)) + ) + raise TraitError(e) + + def validate(self, obj, value): + value = super().validate(obj, value) + if value is None: + return value + value = self.validate_elements(obj, value) + return value + + def validate_elements(self, obj, value): + per_key_override = self._per_key_traits or {} + key_trait = self._key_trait + value_trait = self._value_trait + if not (key_trait or value_trait or per_key_override): + return value + + validated = {} + for key in value: + v = value[key] + if key_trait: + try: + key = key_trait._validate(obj, key) + except TraitError: + self.element_error(obj, key, key_trait, "Keys") + active_value_trait = per_key_override.get(key, value_trait) + if active_value_trait: + try: + v = active_value_trait._validate(obj, v) + except TraitError: + self.element_error(obj, v, active_value_trait, "Values") + validated[key] = v + + return self.klass(validated) # type:ignore + + def class_init(self, cls, name): + if isinstance(self._value_trait, TraitType): + self._value_trait.class_init(cls, None) + if isinstance(self._key_trait, TraitType): + self._key_trait.class_init(cls, None) + if self._per_key_traits is not None: + for trait in self._per_key_traits.values(): + trait.class_init(cls, None) + super().class_init(cls, name) + + def subclass_init(self, cls): + if isinstance(self._value_trait, TraitType): + self._value_trait.subclass_init(cls) + if isinstance(self._key_trait, TraitType): + self._key_trait.subclass_init(cls) + if self._per_key_traits is not None: + for trait in self._per_key_traits.values(): + trait.subclass_init(cls) + # explicitly not calling super().subclass_init(cls) + # to opt out of instance_init + + def from_string(self, s): + """Load value from a single string""" + if not isinstance(s, str): + raise TypeError(f"from_string expects a string, got {repr(s)} of type {type(s)}") + try: + return self.from_string_list([s]) + except Exception: + test = _safe_literal_eval(s) + if isinstance(test, dict): + return test + raise + + def from_string_list(self, s_list): + """Return a dict from a list of config strings. + + This is where we parse CLI configuration. + + Each item should have the form ``"key=value"``. + + item parsing is done in :meth:`.item_from_string`. + """ + if len(s_list) == 1 and s_list[0] == "None" and self.allow_none: + return None + if len(s_list) == 1 and s_list[0].startswith("{") and s_list[0].endswith("}"): + warn( + "--{0}={1} for dict-traits is deprecated in traitlets 5.0. " + "You can pass --{0} <key=value> ... multiple times to add items to a dict.".format( + self.name, + s_list[0], + ), + FutureWarning, + ) + + return literal_eval(s_list[0]) + + combined = {} + for d in [self.item_from_string(s) for s in s_list]: + combined.update(d) + return combined + + def item_from_string(self, s): + """Cast a single-key dict from a string. + + Evaluated when parsing CLI configuration from a string. + + Dicts expect strings of the form key=value. + + Returns a one-key dictionary, + which will be merged in :meth:`.from_string_list`. + """ + + if "=" not in s: + raise TraitError( + "'%s' options must have the form 'key=value', got %s" + % ( + self.__class__.__name__, + repr(s), + ) + ) + key, value = s.split("=", 1) + + # cast key with key trait, if defined + if self._key_trait: + key = self._key_trait.from_string(key) + + # cast value with value trait, if defined (per-key or global) + value_trait = (self._per_key_traits or {}).get(key, self._value_trait) + if value_trait: + value = value_trait.from_string(value) + return {key: value} + + +class TCPAddress(TraitType): + """A trait for an (ip, port) tuple. + + This allows for both IPv4 IP addresses as well as hostnames. + """ + + default_value = ("127.0.0.1", 0) + info_text = "an (ip, port) tuple" + + def validate(self, obj, value): + if isinstance(value, tuple): + if len(value) == 2: + if isinstance(value[0], str) and isinstance(value[1], int): + port = value[1] + if port >= 0 and port <= 65535: + return value + self.error(obj, value) + + def from_string(self, s): + if self.allow_none and s == "None": + return None + if ":" not in s: + raise ValueError("Require `ip:port`, got %r" % s) + ip, port = s.split(":", 1) + port = int(port) + return (ip, port) + + +class CRegExp(TraitType): + """A casting compiled regular expression trait. + + Accepts both strings and compiled regular expressions. The resulting + attribute will be a compiled regular expression.""" + + info_text = "a regular expression" + + def validate(self, obj, value): + try: + return re.compile(value) + except Exception: + self.error(obj, value) + + +class UseEnum(TraitType): + """Use a Enum class as model for the data type description. + Note that if no default-value is provided, the first enum-value is used + as default-value. + + .. sourcecode:: python + + # -- SINCE: Python 3.4 (or install backport: pip install enum34) + import enum + from traitlets import HasTraits, UseEnum + + class Color(enum.Enum): + red = 1 # -- IMPLICIT: default_value + blue = 2 + green = 3 + + class MyEntity(HasTraits): + color = UseEnum(Color, default_value=Color.blue) + + entity = MyEntity(color=Color.red) + entity.color = Color.green # USE: Enum-value (preferred) + entity.color = "green" # USE: name (as string) + entity.color = "Color.green" # USE: scoped-name (as string) + entity.color = 3 # USE: number (as int) + assert entity.color is Color.green + """ + + default_value: t.Optional[enum.Enum] = None + info_text = "Trait type adapter to a Enum class" + + def __init__(self, enum_class, default_value=None, **kwargs): + assert issubclass(enum_class, enum.Enum), "REQUIRE: enum.Enum, but was: %r" % enum_class + allow_none = kwargs.get("allow_none", False) + if default_value is None and not allow_none: + default_value = list(enum_class.__members__.values())[0] + super().__init__(default_value=default_value, **kwargs) + self.enum_class = enum_class + self.name_prefix = enum_class.__name__ + "." + + def select_by_number(self, value, default=Undefined): + """Selects enum-value by using its number-constant.""" + assert isinstance(value, int) + enum_members = self.enum_class.__members__ + for enum_item in enum_members.values(): + if enum_item.value == value: + return enum_item + # -- NOT FOUND: + return default + + def select_by_name(self, value, default=Undefined): + """Selects enum-value by using its name or scoped-name.""" + assert isinstance(value, str) + if value.startswith(self.name_prefix): + # -- SUPPORT SCOPED-NAMES, like: "Color.red" => "red" + value = value.replace(self.name_prefix, "", 1) + return self.enum_class.__members__.get(value, default) + + def validate(self, obj, value): + if isinstance(value, self.enum_class): + return value + elif isinstance(value, int): + # -- CONVERT: number => enum_value (item) + value2 = self.select_by_number(value) + if value2 is not Undefined: + return value2 + elif isinstance(value, str): + # -- CONVERT: name or scoped_name (as string) => enum_value (item) + value2 = self.select_by_name(value) + if value2 is not Undefined: + return value2 + elif value is None: + if self.allow_none: + return None + else: + return self.default_value + self.error(obj, value) + + def _choices_str(self, as_rst=False): + """Returns a description of the trait choices (not none).""" + choices = self.enum_class.__members__.keys() + if as_rst: + return "|".join("``%r``" % x for x in choices) + else: + return repr(list(choices)) # Listify because py3.4- prints odict-class + + def _info(self, as_rst=False): + """Returns a description of the trait.""" + none = " or %s" % ("`None`" if as_rst else "None") if self.allow_none else "" + return f"any of {self._choices_str(as_rst)}{none}" + + def info(self): + return self._info(as_rst=False) + + def info_rst(self): + return self._info(as_rst=True) + + +class Callable(TraitType): + """A trait which is callable. + + Notes + ----- + Classes are callable, as are instances + with a __call__() method.""" + + info_text = "a callable" + + def validate(self, obj, value): + if callable(value): + return value + else: + self.error(obj, value) diff --git a/contrib/python/traitlets/py3/traitlets/utils/__init__.py b/contrib/python/traitlets/py3/traitlets/utils/__init__.py new file mode 100644 index 0000000000..dfec4ee322 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/__init__.py @@ -0,0 +1,88 @@ +import os +import pathlib + + +# vestigal things from IPython_genutils. +def cast_unicode(s, encoding="utf-8"): + if isinstance(s, bytes): + return s.decode(encoding, "replace") + return s + + +def filefind(filename, path_dirs=None): + """Find a file by looking through a sequence of paths. + + This iterates through a sequence of paths looking for a file and returns + the full, absolute path of the first occurence of the file. If no set of + path dirs is given, the filename is tested as is, after running through + :func:`expandvars` and :func:`expanduser`. Thus a simple call:: + + filefind('myfile.txt') + + will find the file in the current working dir, but:: + + filefind('~/myfile.txt') + + Will find the file in the users home directory. This function does not + automatically try any paths, such as the cwd or the user's home directory. + + Parameters + ---------- + filename : str + The filename to look for. + path_dirs : str, None or sequence of str + The sequence of paths to look for the file in. If None, the filename + need to be absolute or be in the cwd. If a string, the string is + put into a sequence and the searched. If a sequence, walk through + each element and join with ``filename``, calling :func:`expandvars` + and :func:`expanduser` before testing for existence. + + Returns + ------- + Raises :exc:`IOError` or returns absolute path to file. + """ + + # If paths are quoted, abspath gets confused, strip them... + filename = filename.strip('"').strip("'") + # If the input is an absolute path, just check it exists + if os.path.isabs(filename) and os.path.isfile(filename): + return filename + + if path_dirs is None: + path_dirs = ("",) + elif isinstance(path_dirs, str): + path_dirs = (path_dirs,) + elif isinstance(path_dirs, pathlib.Path): + path_dirs = (str(path_dirs),) + + for path in path_dirs: + if path == ".": + path = os.getcwd() + testname = expand_path(os.path.join(path, filename)) + if os.path.isfile(testname): + return os.path.abspath(testname) + + raise OSError(f"File {filename!r} does not exist in any of the search paths: {path_dirs!r}") + + +def expand_path(s): + """Expand $VARS and ~names in a string, like a shell + + :Examples: + + In [2]: os.environ['FOO']='test' + + In [3]: expand_path('variable FOO is $FOO') + Out[3]: 'variable FOO is test' + """ + # This is a pretty subtle hack. When expand user is given a UNC path + # on Windows (\\server\share$\%username%), os.path.expandvars, removes + # the $ to get (\\server\share\%username%). I think it considered $ + # alone an empty var. But, we need the $ to remains there (it indicates + # a hidden share). + if os.name == "nt": + s = s.replace("$\\", "IPYTHON_TEMP") + s = os.path.expandvars(os.path.expanduser(s)) + if os.name == "nt": + s = s.replace("IPYTHON_TEMP", "$\\") + return s diff --git a/contrib/python/traitlets/py3/traitlets/utils/bunch.py b/contrib/python/traitlets/py3/traitlets/utils/bunch.py new file mode 100644 index 0000000000..6b3fffeb12 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/bunch.py @@ -0,0 +1,26 @@ +"""Yet another implementation of bunch + +attribute-access of items on a dict. +""" + +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + + +class Bunch(dict): # type:ignore[type-arg] + """A dict with attribute-access""" + + def __getattr__(self, key): + try: + return self.__getitem__(key) + except KeyError as e: + raise AttributeError(key) from e + + def __setattr__(self, key, value): + self.__setitem__(key, value) + + def __dir__(self): + # py2-compat: can't use super because dict doesn't have __dir__ + names = dir({}) + names.extend(self.keys()) + return names diff --git a/contrib/python/traitlets/py3/traitlets/utils/decorators.py b/contrib/python/traitlets/py3/traitlets/utils/decorators.py new file mode 100644 index 0000000000..a59e8167b0 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/decorators.py @@ -0,0 +1,87 @@ +"""Useful decorators for Traitlets users.""" + +import copy +from inspect import Parameter, Signature, signature +from typing import Type, TypeVar + +from ..traitlets import HasTraits, Undefined + + +def _get_default(value): + """Get default argument value, given the trait default value.""" + return Parameter.empty if value == Undefined else value + + +T = TypeVar("T", bound=HasTraits) + + +def signature_has_traits(cls: Type[T]) -> Type[T]: + """Return a decorated class with a constructor signature that contain Trait names as kwargs.""" + traits = [ + (name, _get_default(value.default_value)) + for name, value in cls.class_traits().items() + if not name.startswith("_") + ] + + # Taking the __init__ signature, as the cls signature is not initialized yet + old_signature = signature(cls.__init__) + old_parameter_names = list(old_signature.parameters) + + old_positional_parameters = [] + old_var_positional_parameter = None # This won't be None if the old signature contains *args + old_keyword_only_parameters = [] + old_var_keyword_parameter = None # This won't be None if the old signature contains **kwargs + + for parameter_name in old_signature.parameters: + # Copy the parameter + parameter = copy.copy(old_signature.parameters[parameter_name]) + + if ( + parameter.kind is Parameter.POSITIONAL_ONLY + or parameter.kind is Parameter.POSITIONAL_OR_KEYWORD + ): + old_positional_parameters.append(parameter) + + elif parameter.kind is Parameter.VAR_POSITIONAL: + old_var_positional_parameter = parameter + + elif parameter.kind is Parameter.KEYWORD_ONLY: + old_keyword_only_parameters.append(parameter) + + elif parameter.kind is Parameter.VAR_KEYWORD: + old_var_keyword_parameter = parameter + + # Unfortunately, if the old signature does not contain **kwargs, we can't do anything, + # because it can't accept traits as keyword arguments + if old_var_keyword_parameter is None: + raise RuntimeError( + "The {} constructor does not take **kwargs, which means that the signature can not be expanded with trait names".format( + cls + ) + ) + + new_parameters = [] + + # Append the old positional parameters (except `self` which is the first parameter) + new_parameters += old_positional_parameters[1:] + + # Append *args if the old signature had it + if old_var_positional_parameter is not None: + new_parameters.append(old_var_positional_parameter) + + # Append the old keyword only parameters + new_parameters += old_keyword_only_parameters + + # Append trait names as keyword only parameters in the signature + new_parameters += [ + Parameter(name, kind=Parameter.KEYWORD_ONLY, default=default) + for name, default in traits + if name not in old_parameter_names + ] + + # Append **kwargs + new_parameters.append(old_var_keyword_parameter) + + cls.__signature__ = Signature(new_parameters) # type:ignore[attr-defined] + + return cls diff --git a/contrib/python/traitlets/py3/traitlets/utils/descriptions.py b/contrib/python/traitlets/py3/traitlets/utils/descriptions.py new file mode 100644 index 0000000000..232eb0e728 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/descriptions.py @@ -0,0 +1,174 @@ +import inspect +import re +import types + + +def describe(article, value, name=None, verbose=False, capital=False): + """Return string that describes a value + + Parameters + ---------- + article : str or None + A definite or indefinite article. If the article is + indefinite (i.e. "a" or "an") the appropriate one + will be infered. Thus, the arguments of ``describe`` + can themselves represent what the resulting string + will actually look like. If None, then no article + will be prepended to the result. For non-articled + description, values that are instances are treated + definitely, while classes are handled indefinitely. + value : any + The value which will be named. + name : str or None (default: None) + Only applies when ``article`` is "the" - this + ``name`` is a definite reference to the value. + By default one will be infered from the value's + type and repr methods. + verbose : bool (default: False) + Whether the name should be concise or verbose. When + possible, verbose names include the module, and/or + class name where an object was defined. + capital : bool (default: False) + Whether the first letter of the article should + be capitalized or not. By default it is not. + + Examples + -------- + Indefinite description: + + >>> describe("a", object()) + 'an object' + >>> describe("a", object) + 'an object' + >>> describe("a", type(object)) + 'a type' + + Definite description: + + >>> describe("the", object()) + "the object at '...'" + >>> describe("the", object) + 'the object object' + >>> describe("the", type(object)) + 'the type type' + + Definitely named description: + + >>> describe("the", object(), "I made") + 'the object I made' + >>> describe("the", object, "I will use") + 'the object I will use' + """ + if isinstance(article, str): + article = article.lower() + + if not inspect.isclass(value): + typename = type(value).__name__ + else: + typename = value.__name__ + if verbose: + typename = _prefix(value) + typename + + if article == "the" or (article is None and not inspect.isclass(value)): + if name is not None: + result = f"{typename} {name}" + if article is not None: + return add_article(result, True, capital) + else: + return result + else: + tick_wrap = False + if inspect.isclass(value): + name = value.__name__ + elif isinstance(value, types.FunctionType): + name = value.__name__ + tick_wrap = True + elif isinstance(value, types.MethodType): + name = value.__func__.__name__ + tick_wrap = True + elif type(value).__repr__ in ( + object.__repr__, + type.__repr__, + ): # type:ignore[comparison-overlap] + name = "at '%s'" % hex(id(value)) + verbose = False + else: + name = repr(value) + verbose = False + if verbose: + name = _prefix(value) + name + if tick_wrap: + name = name.join("''") + return describe(article, value, name=name, verbose=verbose, capital=capital) + elif article in ("a", "an") or article is None: + if article is None: + return typename + return add_article(typename, False, capital) + else: + raise ValueError( + "The 'article' argument should be 'the', 'a', 'an', or None not %r" % article + ) + + +def _prefix(value): + if isinstance(value, types.MethodType): + name = describe(None, value.__self__, verbose=True) + "." + else: + module = inspect.getmodule(value) + if module is not None and module.__name__ != "builtins": + name = module.__name__ + "." + else: + name = "" + return name + + +def class_of(value): + """Returns a string of the value's type with an indefinite article. + + For example 'an Image' or 'a PlotValue'. + """ + if inspect.isclass(value): + return add_article(value.__name__) + else: + return class_of(type(value)) + + +def add_article(name, definite=False, capital=False): + """Returns the string with a prepended article. + + The input does not need to begin with a charater. + + Parameters + ---------- + name : str + Name to which to prepend an article + definite : bool (default: False) + Whether the article is definite or not. + Indefinite articles being 'a' and 'an', + while 'the' is definite. + capital : bool (default: False) + Whether the added article should have + its first letter capitalized or not. + """ + if definite: + result = "the " + name + else: + first_letters = re.compile(r"[\W_]+").sub("", name) + if first_letters[:1].lower() in "aeiou": + result = "an " + name + else: + result = "a " + name + if capital: + return result[0].upper() + result[1:] + else: + return result + + +def repr_type(obj): + """Return a string representation of a value and its type for readable + + error messages. + """ + the_type = type(obj) + msg = f"{obj!r} {the_type!r}" + return msg diff --git a/contrib/python/traitlets/py3/traitlets/utils/getargspec.py b/contrib/python/traitlets/py3/traitlets/utils/getargspec.py new file mode 100644 index 0000000000..e2b1f235c8 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/getargspec.py @@ -0,0 +1,49 @@ +""" + getargspec excerpted from: + + sphinx.util.inspect + ~~~~~~~~~~~~~~~~~~~ + Helpers for inspecting Python modules. + :copyright: Copyright 2007-2015 by the Sphinx team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import inspect +from functools import partial + +# Unmodified from sphinx below this line + + +def getargspec(func): + """Like inspect.getargspec but supports functools.partial as well.""" + if inspect.ismethod(func): + func = func.__func__ + if type(func) is partial: + orig_func = func.func + argspec = getargspec(orig_func) + args = list(argspec[0]) + defaults = list(argspec[3] or ()) + kwoargs = list(argspec[4]) + kwodefs = dict(argspec[5] or {}) + if func.args: + args = args[len(func.args) :] + for arg in func.keywords or (): + try: + i = args.index(arg) - len(args) + del args[i] + try: + del defaults[i] + except IndexError: + pass + except ValueError: # must be a kwonly arg + i = kwoargs.index(arg) + del kwoargs[i] + del kwodefs[arg] + return inspect.FullArgSpec( + args, argspec[1], argspec[2], tuple(defaults), kwoargs, kwodefs, argspec[6] + ) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if not inspect.isfunction(func): + raise TypeError("%r is not a Python function" % func) + return inspect.getfullargspec(func) diff --git a/contrib/python/traitlets/py3/traitlets/utils/importstring.py b/contrib/python/traitlets/py3/traitlets/utils/importstring.py new file mode 100644 index 0000000000..7ac1e9aba7 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/importstring.py @@ -0,0 +1,38 @@ +""" +A simple utility to import something by its string name. +""" +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + + +def import_item(name): + """Import and return ``bar`` given the string ``foo.bar``. + + Calling ``bar = import_item("foo.bar")`` is the functional equivalent of + executing the code ``from foo import bar``. + + Parameters + ---------- + name : string + The fully qualified name of the module/package being imported. + + Returns + ------- + mod : module object + The module that was imported. + """ + if not isinstance(name, str): + raise TypeError("import_item accepts strings, not '%s'." % type(name)) + parts = name.rsplit(".", 1) + if len(parts) == 2: + # called with 'foo.bar....' + package, obj = parts + module = __import__(package, fromlist=[obj]) + try: + pak = getattr(module, obj) + except AttributeError as e: + raise ImportError("No module named %s" % obj) from e + return pak + else: + # called with un-dotted string + return __import__(parts[0]) diff --git a/contrib/python/traitlets/py3/traitlets/utils/nested_update.py b/contrib/python/traitlets/py3/traitlets/utils/nested_update.py new file mode 100644 index 0000000000..7f09e171a3 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/nested_update.py @@ -0,0 +1,38 @@ +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + + +def nested_update(this, that): + """Merge two nested dictionaries. + + Effectively a recursive ``dict.update``. + + Examples + -------- + Merge two flat dictionaries: + >>> nested_update( + ... {'a': 1, 'b': 2}, + ... {'b': 3, 'c': 4} + ... ) + {'a': 1, 'b': 3, 'c': 4} + + Merge two nested dictionaries: + >>> nested_update( + ... {'x': {'a': 1, 'b': 2}, 'y': 5, 'z': 6}, + ... {'x': {'b': 3, 'c': 4}, 'z': 7, '0': 8}, + ... ) + {'x': {'a': 1, 'b': 3, 'c': 4}, 'y': 5, 'z': 7, '0': 8} + + """ + for key, value in this.items(): + if isinstance(value, dict): + if key in that and isinstance(that[key], dict): + nested_update(this[key], that[key]) + elif key in that: + this[key] = that[key] + + for key, value in that.items(): + if key not in this: + this[key] = value + + return this diff --git a/contrib/python/traitlets/py3/traitlets/utils/sentinel.py b/contrib/python/traitlets/py3/traitlets/utils/sentinel.py new file mode 100644 index 0000000000..75e000f81b --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/sentinel.py @@ -0,0 +1,21 @@ +"""Sentinel class for constants with useful reprs""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. + + +class Sentinel: + def __init__(self, name, module, docstring=None): + self.name = name + self.module = module + if docstring: + self.__doc__ = docstring + + def __repr__(self): + return str(self.module) + "." + self.name + + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self diff --git a/contrib/python/traitlets/py3/traitlets/utils/tests/test_bunch.py b/contrib/python/traitlets/py3/traitlets/utils/tests/test_bunch.py new file mode 100644 index 0000000000..223124d7d5 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/tests/test_bunch.py @@ -0,0 +1,16 @@ +from traitlets.utils.bunch import Bunch + + +def test_bunch(): + b = Bunch(x=5, y=10) + assert "y" in b + assert "x" in b + assert b.x == 5 + b["a"] = "hi" + assert b.a == "hi" + + +def test_bunch_dir(): + b = Bunch(x=5, y=10) + assert "x" in dir(b) + assert "keys" in dir(b) diff --git a/contrib/python/traitlets/py3/traitlets/utils/tests/test_decorators.py b/contrib/python/traitlets/py3/traitlets/utils/tests/test_decorators.py new file mode 100644 index 0000000000..5410c20137 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/tests/test_decorators.py @@ -0,0 +1,137 @@ +from inspect import Parameter, signature +from unittest import TestCase + +from traitlets.traitlets import HasTraits, Int, Unicode +from traitlets.utils.decorators import signature_has_traits + + +class TestExpandSignature(TestCase): + def test_no_init(self): + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + value = Unicode("Hello") + + parameters = signature(Foo).parameters + parameter_names = list(parameters) + + self.assertIs(parameters["args"].kind, Parameter.VAR_POSITIONAL) + self.assertEqual("args", parameter_names[0]) + + self.assertIs(parameters["number1"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["number2"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["value"].kind, Parameter.KEYWORD_ONLY) + + self.assertIs(parameters["kwargs"].kind, Parameter.VAR_KEYWORD) + self.assertEqual("kwargs", parameter_names[-1]) + + f = Foo(number1=32, value="World") + self.assertEqual(f.number1, 32) + self.assertEqual(f.number2, 0) + self.assertEqual(f.value, "World") + + def test_partial_init(self): + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + value = Unicode("Hello") + + def __init__(self, arg1, **kwargs): + self.arg1 = arg1 + + super().__init__(**kwargs) + + parameters = signature(Foo).parameters + parameter_names = list(parameters) + + self.assertIs(parameters["arg1"].kind, Parameter.POSITIONAL_OR_KEYWORD) + self.assertEqual("arg1", parameter_names[0]) + + self.assertIs(parameters["number1"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["number2"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["value"].kind, Parameter.KEYWORD_ONLY) + + self.assertIs(parameters["kwargs"].kind, Parameter.VAR_KEYWORD) + self.assertEqual("kwargs", parameter_names[-1]) + + f = Foo(1, number1=32, value="World") + self.assertEqual(f.arg1, 1) + self.assertEqual(f.number1, 32) + self.assertEqual(f.number2, 0) + self.assertEqual(f.value, "World") + + def test_duplicate_init(self): + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + + def __init__(self, number1, **kwargs): + self.test = number1 + + super().__init__(number1=number1, **kwargs) + + parameters = signature(Foo).parameters + parameter_names = list(parameters) + + self.assertListEqual(parameter_names, ["number1", "number2", "kwargs"]) + + f = Foo(number1=32, number2=36) + self.assertEqual(f.test, 32) + self.assertEqual(f.number1, 32) + self.assertEqual(f.number2, 36) + + def test_full_init(self): + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + value = Unicode("Hello") + + def __init__(self, arg1, arg2=None, *pos_args, **kw_args): + self.arg1 = arg1 + self.arg2 = arg2 + self.pos_args = pos_args + self.kw_args = kw_args + + super().__init__(*pos_args, **kw_args) + + parameters = signature(Foo).parameters + parameter_names = list(parameters) + + self.assertIs(parameters["arg1"].kind, Parameter.POSITIONAL_OR_KEYWORD) + self.assertEqual("arg1", parameter_names[0]) + + self.assertIs(parameters["arg2"].kind, Parameter.POSITIONAL_OR_KEYWORD) + self.assertEqual("arg2", parameter_names[1]) + + self.assertIs(parameters["pos_args"].kind, Parameter.VAR_POSITIONAL) + self.assertEqual("pos_args", parameter_names[2]) + + self.assertIs(parameters["number1"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["number2"].kind, Parameter.KEYWORD_ONLY) + self.assertIs(parameters["value"].kind, Parameter.KEYWORD_ONLY) + + self.assertIs(parameters["kw_args"].kind, Parameter.VAR_KEYWORD) + self.assertEqual("kw_args", parameter_names[-1]) + + f = Foo(1, 3, 45, "hey", number1=32, value="World") + self.assertEqual(f.arg1, 1) + self.assertEqual(f.arg2, 3) + self.assertTupleEqual(f.pos_args, (45, "hey")) + self.assertEqual(f.number1, 32) + self.assertEqual(f.number2, 0) + self.assertEqual(f.value, "World") + + def test_no_kwargs(self): + with self.assertRaises(RuntimeError): + + @signature_has_traits + class Foo(HasTraits): + number1 = Int() + number2 = Int() + + def __init__(self, arg1, arg2=None): + pass diff --git a/contrib/python/traitlets/py3/traitlets/utils/tests/test_importstring.py b/contrib/python/traitlets/py3/traitlets/utils/tests/test_importstring.py new file mode 100644 index 0000000000..a3a74c3214 --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/tests/test_importstring.py @@ -0,0 +1,26 @@ +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +# +# Adapted from enthought.traits, Copyright (c) Enthought, Inc., +# also under the terms of the Modified BSD License. +"""Tests for traitlets.utils.importstring.""" + +import os +from unittest import TestCase + +from traitlets.utils.importstring import import_item + + +class TestImportItem(TestCase): + def test_import_unicode(self): + self.assertIs(os, import_item("os")) + self.assertIs(os.path, import_item("os.path")) + self.assertIs(os.path.join, import_item("os.path.join")) + + def test_bad_input(self): + class NotAString: + pass + + msg = "import_item accepts strings, not '%s'." % NotAString + with self.assertRaisesRegex(TypeError, msg): + import_item(NotAString()) diff --git a/contrib/python/traitlets/py3/traitlets/utils/text.py b/contrib/python/traitlets/py3/traitlets/utils/text.py new file mode 100644 index 0000000000..c7d49edece --- /dev/null +++ b/contrib/python/traitlets/py3/traitlets/utils/text.py @@ -0,0 +1,40 @@ +""" +Utilities imported from ipython_genutils +""" + +import re +import textwrap +from textwrap import dedent +from textwrap import indent as _indent +from typing import List + + +def indent(val): + res = _indent(val, " ") + return res + + +def wrap_paragraphs(text: str, ncols: int = 80) -> List[str]: + """Wrap multiple paragraphs to fit a specified width. + + This is equivalent to textwrap.wrap, but with support for multiple + paragraphs, as separated by empty lines. + + Returns + ------- + + list of complete paragraphs, wrapped to fill `ncols` columns. + """ + paragraph_re = re.compile(r"\n(\s*\n)+", re.MULTILINE) + text = dedent(text).strip() + paragraphs = paragraph_re.split(text)[::2] # every other entry is space + out_ps = [] + indent_re = re.compile(r"\n\s+", re.MULTILINE) + for p in paragraphs: + # presume indentation that survives dedent is meaningful formatting, + # so don't fill unless text is flush. + if indent_re.search(p) is None: + # wrap paragraph + p = textwrap.fill(p, ncols) + out_ps.append(p) + return out_ps diff --git a/contrib/python/traitlets/py3/ya.make b/contrib/python/traitlets/py3/ya.make new file mode 100644 index 0000000000..fc5ec80cbc --- /dev/null +++ b/contrib/python/traitlets/py3/ya.make @@ -0,0 +1,51 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +PROVIDES(python_traitlets) + +VERSION(5.9.0) + +LICENSE(BSD-3-Clause) + +NO_LINT() + +PY_SRCS( + TOP_LEVEL + traitlets/__init__.py + traitlets/_version.py + traitlets/config/__init__.py + traitlets/config/application.py + traitlets/config/argcomplete_config.py + traitlets/config/configurable.py + traitlets/config/loader.py + traitlets/config/manager.py + traitlets/config/sphinxdoc.py + traitlets/log.py + traitlets/tests/__init__.py + traitlets/tests/_warnings.py + traitlets/tests/utils.py + traitlets/traitlets.py + traitlets/utils/__init__.py + traitlets/utils/bunch.py + traitlets/utils/decorators.py + traitlets/utils/descriptions.py + traitlets/utils/getargspec.py + traitlets/utils/importstring.py + traitlets/utils/nested_update.py + traitlets/utils/sentinel.py + traitlets/utils/text.py +) + +RESOURCE_FILES( + PREFIX contrib/python/traitlets/py3/ + .dist-info/METADATA + .dist-info/top_level.txt + traitlets/py.typed +) + +END() + +RECURSE_FOR_TESTS( + tests +) |