aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/attrs/py3/attr/_compat.py
blob: 46b05ca453773da7f2972f023c4a4f447b44e824 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# SPDX-License-Identifier: MIT

import inspect
import platform
import sys
import threading

from collections.abc import Mapping, Sequence  # noqa: F401
from typing import _GenericAlias


# Interpreter feature flags, evaluated once at import time.
#
# PYPY tells whether the running implementation is PyPy; each PY* flag is
# True when the running interpreter is at least the named 3.x release.
PYPY = platform.python_implementation() == "PyPy"
PY_3_8_PLUS = sys.version_info >= (3, 8)
PY_3_9_PLUS = sys.version_info >= (3, 9)
PY310 = sys.version_info >= (3, 10)
PY_3_12_PLUS = sys.version_info >= (3, 12)


# ``typing.Protocol`` is imported from the standard library on Python 3.8
# and newer. Older interpreters try the ``typing_extensions`` backport and,
# if that is not installed, fall back to plain ``object``.
if sys.version_info >= (3, 8):
    from typing import Protocol  # noqa: F401
else:
    try:
        from typing_extensions import Protocol
    except ImportError:  # pragma: no cover
        Protocol = object


class _AnnotationExtractor:
    """
    Look up the type annotations of a callable's signature.

    Every getter answers None when the requested annotation is absent or
    when the callable's signature could not be inspected at all.
    """

    __slots__ = ["sig"]

    def __init__(self, callable):
        """
        Inspect *callable* and remember its signature.

        ``sig`` is set to None when `inspect.signature` raises ValueError or
        TypeError for the given object.
        """
        try:
            signature = inspect.signature(callable)
        except (ValueError, TypeError):  # inspect failed
            signature = None

        self.sig = signature

    def get_first_param_type(self):
        """
        Return the annotation of the first parameter, or None if there is no
        signature, no parameter, or no annotation on that parameter.
        """
        if not self.sig:
            return None

        # Only the leading parameter matters, so don't materialise the rest.
        first = next(iter(self.sig.parameters.values()), None)
        if first is None or first.annotation is inspect.Parameter.empty:
            return None

        return first.annotation

    def get_return_type(self):
        """
        Return the return annotation, or None if there is no signature or
        the return value is not annotated.
        """
        if not self.sig:
            return None

        annotation = self.sig.return_annotation
        if annotation is inspect.Signature.empty:
            return None

        return annotation


# Thread-local global used to track attrs instances that are currently being
# repr'd. It is needed because there is no other (thread-safe) way to pass
# information about the instances that are already being repr'd through the
# call stack, and that information is what prevents infinite recursion.
#
# For instance, if an instance contains a dict which contains that instance,
# the dict's repr() call needs to know that the outer instance is already
# being repr'd.
#
# This lives here rather than in _make.py so that the functions in _make.py
# don't have a direct reference to the thread-local in their globals dict.
# If they have such a reference, it breaks cloudpickle.
repr_context = threading.local()


def get_generic_base(cl):
    """
    Return the generic base of a subscripted generic class (``A[str]``).

    The origin is returned only when the class of *cl* is exactly
    ``typing._GenericAlias``; any other object yields None.
    """
    return cl.__origin__ if cl.__class__ is _GenericAlias else None