aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/itsdangerous/py2/tests/test_itsdangerous/test_jws.py
blob: 87529485eb13ae9a88ca32d56556397b241a8a3c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from datetime import timedelta
from functools import partial

import pytest
from .test_serializer import TestSerializer
from .test_timed import TestTimedSerializer

from itsdangerous.exc import BadData
from itsdangerous.exc import BadHeader
from itsdangerous.exc import BadPayload
from itsdangerous.exc import BadSignature
from itsdangerous.exc import SignatureExpired
from itsdangerous.jws import JSONWebSignatureSerializer
from itsdangerous.jws import TimedJSONWebSignatureSerializer


class TestJWSSerializer(TestSerializer):
    """Exercise ``JSONWebSignatureSerializer`` on top of ``TestSerializer``.

    The ``serializer_factory`` fixture is overridden so that tests in this
    class build JWS serializers, and JWS-specific tests for algorithm
    selection and malformed tokens are added.
    """

    # NOTE(review): these names presumably refer to tests inherited from
    # ``TestSerializer``; rebinding them to ``None`` looks intended to keep
    # pytest from running them for this class -- confirm against the base.
    test_signer_cls = None
    test_signer_kwargs = None
    test_fallback_signers = None
    test_iter_unsigners = None

    @pytest.fixture()
    def serializer_factory(self):
        """Return a JWS serializer factory with ``secret_key`` pre-bound."""
        return partial(JSONWebSignatureSerializer, secret_key="secret-key")

    @pytest.mark.parametrize("algorithm_name", ("HS256", "HS384", "HS512", "none"))
    def test_algorithm(self, serializer_factory, algorithm_name):
        """A value survives a dumps/loads round trip for each algorithm name."""
        jws = serializer_factory(algorithm_name=algorithm_name)
        token = jws.dumps("value")
        assert jws.loads(token) == "value"

    def test_invalid_algorithm(self, serializer_factory):
        """An unknown algorithm name is rejected when the serializer is built."""
        with pytest.raises(NotImplementedError) as raised:
            serializer_factory(algorithm_name="invalid")

        message = str(raised.value)
        assert "not supported" in message

    def test_algorithm_mismatch(self, serializer_factory, serializer):
        """Loading a token made under another algorithm name raises BadHeader."""
        # Created with the "HS256" name, then given the ``algorithm`` object of
        # the serializer under test before producing the token.
        mismatched = serializer_factory(algorithm_name="HS256")
        mismatched.algorithm = serializer.algorithm
        token = mismatched.dumps("value")

        with pytest.raises(BadHeader) as raised:
            serializer.loads(token)

        message = str(raised.value)
        assert "mismatch" in message

    # Each row is (raw value to sign, expected exception class, substring that
    # must appear in the exception message).  "ew" is the unpadded base64 form
    # of "{" and "W10" is the unpadded base64 form of "[]".
    @pytest.mark.parametrize(
        ("value", "exc_cls", "match"),
        [
            ("ab", BadPayload, '"."'),
            ("a.b", BadHeader, "base64 decode"),
            ("ew.b", BadPayload, "base64 decode"),
            ("ew.ab", BadData, "malformed"),
            ("W10.ab", BadHeader, "JSON object"),
        ],
    )
    def test_load_payload_exceptions(self, serializer, value, exc_cls, match):
        """Validly signed but malformed tokens raise the expected error."""
        # Sign the raw value directly so that the signature check passes and
        # the failure comes from loading the signed content.
        token = serializer.make_signer().sign(value)

        with pytest.raises(exc_cls) as raised:
            serializer.loads(token)

        message = str(raised.value)
        assert match in message


class TestTimedJWSSerializer(TestJWSSerializer, TestTimedSerializer):
    """Exercise ``TimedJSONWebSignatureSerializer``.

    Combines the JWS tests with ``TestTimedSerializer`` and adds tests for
    the ``exp`` and ``iat`` header fields.
    """

    # NOTE(review): presumably inherited tests that do not apply here;
    # rebinding them to ``None`` looks intended to disable them -- confirm
    # against the base classes.
    test_max_age = None
    test_return_payload = None

    @staticmethod
    def _sign_with_header(serializer, header):
        """Return ``"value"`` dumped with ``header`` and signed by ``serializer``."""
        return serializer.make_signer().sign(serializer.dump_payload(header, "value"))

    @pytest.fixture()
    def serializer_factory(self):
        """Return a timed JWS serializer factory with key and expiry pre-bound."""
        defaults = {"secret_key": "secret-key", "expires_in": 10}
        return partial(TimedJSONWebSignatureSerializer, **defaults)

    def test_default_expires_in(self, serializer_factory):
        """Passing ``expires_in=None`` yields ``DEFAULT_EXPIRES_IN``."""
        timed = serializer_factory(expires_in=None)
        assert timed.expires_in == timed.DEFAULT_EXPIRES_IN

    def test_exp(self, serializer, value, ts, freeze):
        """A token loads after one tick and expires after ten more seconds."""
        token = serializer.dumps(value)

        # Still loadable after a single default tick of the frozen clock.
        freeze.tick()
        assert serializer.loads(token) == value

        freeze.tick(timedelta(seconds=10))

        with pytest.raises(SignatureExpired) as raised:
            serializer.loads(token)

        # The expiry error carries the signing date and the payload.
        expired = raised.value
        assert expired.date_signed == ts
        assert expired.payload == value

    def test_return_header(self, serializer, value, ts):
        """``return_header=True`` yields a header holding the issue date."""
        token = serializer.dumps(value)
        payload, header = serializer.loads(token, return_header=True)
        issued_at = serializer.get_issue_date(header)
        expected = (value, ts)
        assert (payload, issued_at) == expected

    def test_missing_exp(self, serializer):
        """A signed token whose header has no ``exp`` raises BadSignature."""
        header = serializer.make_header(None)
        header.pop("exp")
        token = self._sign_with_header(serializer, header)

        with pytest.raises(BadSignature):
            serializer.loads(token)

    @pytest.mark.parametrize("exp", ("invalid", -1))
    def test_invalid_exp(self, serializer, exp):
        """A non-numeric or negative ``exp`` raises BadHeader about IntDate."""
        header = serializer.make_header(None)
        header["exp"] = exp
        token = self._sign_with_header(serializer, header)

        with pytest.raises(BadHeader) as raised:
            serializer.loads(token)

        message = str(raised.value)
        assert "IntDate" in message

    def test_invalid_iat(self, serializer):
        """``get_issue_date`` returns ``None`` when ``iat`` is not usable."""
        header = serializer.make_header(None)
        header["iat"] = "invalid"
        issued_at = serializer.get_issue_date(header)
        assert issued_at is None