diff options
author | alexv-smirnov <alex@ydb.tech> | 2023-12-01 12:02:50 +0300 |
---|---|---|
committer | alexv-smirnov <alex@ydb.tech> | 2023-12-01 13:28:10 +0300 |
commit | 0e578a4c44d4abd539d9838347b9ebafaca41dfb (patch) | |
tree | a0c1969c37f818c830ebeff9c077eacf30be6ef8 /contrib/python/oauthlib/tests | |
parent | 84f2d3d4cc985e63217cff149bd2e6d67ae6fe22 (diff) | |
download | ydb-0e578a4c44d4abd539d9838347b9ebafaca41dfb.tar.gz |
Change "ya.make"
Diffstat (limited to 'contrib/python/oauthlib/tests')
71 files changed, 8963 insertions, 0 deletions
diff --git a/contrib/python/oauthlib/tests/__init__.py b/contrib/python/oauthlib/tests/__init__.py new file mode 100644 index 0000000000..f33236b5ee --- /dev/null +++ b/contrib/python/oauthlib/tests/__init__.py @@ -0,0 +1,3 @@ +import oauthlib + +oauthlib.set_debug(True) diff --git a/contrib/python/oauthlib/tests/oauth1/__init__.py b/contrib/python/oauthlib/tests/oauth1/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/__init__.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/__init__.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_access_token.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_access_token.py new file mode 100644 index 0000000000..57d8117531 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_access_token.py @@ -0,0 +1,91 @@ +from unittest.mock import ANY, MagicMock + +from oauthlib.oauth1 import RequestValidator +from oauthlib.oauth1.rfc5849 import Client +from oauthlib.oauth1.rfc5849.endpoints import AccessTokenEndpoint + +from tests.unittest import TestCase + + +class AccessTokenEndpointTest(TestCase): + + def setUp(self): + self.validator = MagicMock(wraps=RequestValidator()) + self.validator.check_client_key.return_value = True + self.validator.check_request_token.return_value = True + self.validator.check_verifier.return_value = True + self.validator.allowed_signature_methods = ['HMAC-SHA1'] + self.validator.get_client_secret.return_value = 'bar' + self.validator.get_request_token_secret.return_value = 'secret' + self.validator.get_realms.return_value = ['foo'] + self.validator.timestamp_lifetime = 600 + self.validator.validate_client_key.return_value = True + self.validator.validate_request_token.return_value = True + self.validator.validate_verifier.return_value = True + self.validator.validate_timestamp_and_nonce.return_value = True + self.validator.invalidate_request_token.return_value = True + self.validator.dummy_client = 'dummy' + self.validator.dummy_secret = 'dummy' + self.validator.dummy_request_token = 'dummy' + self.validator.save_access_token = MagicMock() + self.endpoint = AccessTokenEndpoint(self.validator) + self.client = Client('foo', + client_secret='bar', + resource_owner_key='token', + resource_owner_secret='secret', + verifier='verfier') + self.uri, self.headers, self.body = self.client.sign( + 'https://i.b/access_token') + + def test_check_request_token(self): + self.validator.check_request_token.return_value = False + h, b, s = self.endpoint.create_access_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_check_verifier(self): + self.validator.check_verifier.return_value = False + h, b, s = self.endpoint.create_access_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_validate_client_key(self): + self.validator.validate_client_key.return_value = False + h, b, s = self.endpoint.create_access_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 401) + + def test_validate_request_token(self): + self.validator.validate_request_token.return_value = False + h, b, s = self.endpoint.create_access_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 401) + + def test_validate_verifier(self): + self.validator.validate_verifier.return_value = False + h, b, s = self.endpoint.create_access_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 401) + + def test_validate_signature(self): + client = Client('foo', + resource_owner_key='token', + resource_owner_secret='secret', + verifier='verfier') + _, headers, _ = client.sign(self.uri + '/extra') + h, b, s = self.endpoint.create_access_token_response( + self.uri, headers=headers) + self.assertEqual(s, 401) + + def test_valid_request(self): + h, b, s = self.endpoint.create_access_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 200) + self.assertIn('oauth_token', b) + self.validator.validate_timestamp_and_nonce.assert_called_once_with( + self.client.client_key, ANY, ANY, ANY, + request_token=self.client.resource_owner_key) + self.validator.invalidate_request_token.assert_called_once_with( + self.client.client_key, self.client.resource_owner_key, ANY) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_authorization.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_authorization.py new file mode 100644 index 0000000000..a9b2fc0c9f --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_authorization.py @@ -0,0 +1,54 @@ +from unittest.mock import MagicMock + +from oauthlib.oauth1 import RequestValidator +from oauthlib.oauth1.rfc5849 import errors +from oauthlib.oauth1.rfc5849.endpoints import AuthorizationEndpoint + +from tests.unittest import TestCase + + +class AuthorizationEndpointTest(TestCase): + + def setUp(self): + self.validator = MagicMock(wraps=RequestValidator()) + self.validator.verify_request_token.return_value = True + self.validator.verify_realms.return_value = True + self.validator.get_realms.return_value = ['test'] + self.validator.save_verifier = MagicMock() + self.endpoint = AuthorizationEndpoint(self.validator) + self.uri = 'https://i.b/authorize?oauth_token=foo' + + def test_get_realms_and_credentials(self): + realms, credentials = self.endpoint.get_realms_and_credentials(self.uri) + self.assertEqual(realms, ['test']) + + def test_verify_token(self): + self.validator.verify_request_token.return_value = False + self.assertRaises(errors.InvalidClientError, + self.endpoint.get_realms_and_credentials, self.uri) + self.assertRaises(errors.InvalidClientError, + self.endpoint.create_authorization_response, self.uri) + + def test_verify_realms(self): + self.validator.verify_realms.return_value = False + self.assertRaises(errors.InvalidRequestError, + self.endpoint.create_authorization_response, + self.uri, + realms=['bar']) + + def test_create_authorization_response(self): + self.validator.get_redirect_uri.return_value = 'https://c.b/cb' + h, b, s = self.endpoint.create_authorization_response(self.uri) + self.assertEqual(s, 302) + self.assertIn('Location', h) + location = h['Location'] + self.assertTrue(location.startswith('https://c.b/cb')) + self.assertIn('oauth_verifier', location) + + def test_create_authorization_response_oob(self): + self.validator.get_redirect_uri.return_value = 'oob' + h, b, s = self.endpoint.create_authorization_response(self.uri) + self.assertEqual(s, 200) + self.assertNotIn('Location', h) + self.assertIn('oauth_verifier', b) + self.assertIn('oauth_token', b) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_base.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_base.py new file mode 100644 index 0000000000..e87f359baa --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_base.py @@ -0,0 +1,406 @@ +from re import sub +from unittest.mock import MagicMock + +from oauthlib.common import CaseInsensitiveDict, safe_string_equals +from oauthlib.oauth1 import Client, RequestValidator +from oauthlib.oauth1.rfc5849 import ( + SIGNATURE_HMAC, SIGNATURE_PLAINTEXT, SIGNATURE_RSA, errors, +) +from oauthlib.oauth1.rfc5849.endpoints import ( + BaseEndpoint, RequestTokenEndpoint, +) + +from tests.unittest import TestCase + +URLENCODED = {"Content-Type": "application/x-www-form-urlencoded"} + + +class BaseEndpointTest(TestCase): + + def setUp(self): + self.validator = MagicMock(spec=RequestValidator) + self.validator.allowed_signature_methods = ['HMAC-SHA1'] + self.validator.timestamp_lifetime = 600 + self.endpoint = RequestTokenEndpoint(self.validator) + self.client = Client('foo', callback_uri='https://c.b/cb') + self.uri, self.headers, self.body = self.client.sign( + 'https://i.b/request_token') + + def test_ssl_enforcement(self): + uri, headers, _ = self.client.sign('http://i.b/request_token') + h, b, s = self.endpoint.create_request_token_response( + uri, headers=headers) + self.assertEqual(s, 400) + self.assertIn('insecure_transport_protocol', b) + + def test_missing_parameters(self): + h, b, s = self.endpoint.create_request_token_response(self.uri) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_signature_methods(self): + headers = {} + headers['Authorization'] = self.headers['Authorization'].replace( + 'HMAC', 'RSA') + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=headers) + self.assertEqual(s, 400) + self.assertIn('invalid_signature_method', b) + + def test_invalid_version(self): + headers = {} + headers['Authorization'] = self.headers['Authorization'].replace( + '1.0', '2.0') + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_expired_timestamp(self): + headers = {} + for pattern in ('12345678901', '4567890123', '123456789K'): + headers['Authorization'] = sub(r'timestamp="\d*k?"', + 'timestamp="%s"' % pattern, + self.headers['Authorization']) + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_client_key_check(self): + self.validator.check_client_key.return_value = False + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_noncecheck(self): + self.validator.check_nonce.return_value = False + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_enforce_ssl(self): + """Ensure SSL is enforced by default.""" + v = RequestValidator() + e = BaseEndpoint(v) + c = Client('foo') + u, h, b = c.sign('http://example.com') + r = e._create_request(u, 'GET', b, h) + self.assertRaises(errors.InsecureTransportError, + e._check_transport_security, r) + + def test_multiple_source_params(self): + """Check for duplicate params""" + v = RequestValidator() + e = BaseEndpoint(v) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/?oauth_signature_method=HMAC-SHA1', + 'GET', 'oauth_version=foo', URLENCODED) + headers = {'Authorization': 'OAuth oauth_signature="foo"'} + headers.update(URLENCODED) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/?oauth_signature_method=HMAC-SHA1', + 'GET', + 'oauth_version=foo', + headers) + headers = {'Authorization': 'OAuth oauth_signature_method="foo"'} + headers.update(URLENCODED) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/', + 'GET', + 'oauth_signature=foo', + headers) + + def test_duplicate_params(self): + """Ensure params are only supplied once""" + v = RequestValidator() + e = BaseEndpoint(v) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/?oauth_version=a&oauth_version=b', + 'GET', None, URLENCODED) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/', 'GET', 'oauth_version=a&oauth_version=b', + URLENCODED) + + def test_mandated_params(self): + """Ensure all mandatory params are present.""" + v = RequestValidator() + e = BaseEndpoint(v) + r = e._create_request('https://a.b/', 'GET', + 'oauth_signature=a&oauth_consumer_key=b&oauth_nonce', + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + def test_oauth_version(self): + """OAuth version must be 1.0 if present.""" + v = RequestValidator() + e = BaseEndpoint(v) + r = e._create_request('https://a.b/', 'GET', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_timestamp=a&oauth_signature_method=RSA-SHA1&' + 'oauth_version=2.0'), + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + def test_oauth_timestamp(self): + """Check for a valid UNIX timestamp.""" + v = RequestValidator() + e = BaseEndpoint(v) + + # Invalid timestamp length, must be 10 + r = e._create_request('https://a.b/', 'GET', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&' + 'oauth_timestamp=123456789'), + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + # Invalid timestamp age, must be younger than 10 minutes + r = e._create_request('https://a.b/', 'GET', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&' + 'oauth_timestamp=1234567890'), + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + # Timestamp must be an integer + r = e._create_request('https://a.b/', 'GET', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&' + 'oauth_timestamp=123456789a'), + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + def test_case_insensitive_headers(self): + """Ensure headers are case-insensitive""" + v = RequestValidator() + e = BaseEndpoint(v) + r = e._create_request('https://a.b', 'POST', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&' + 'oauth_timestamp=123456789a'), + URLENCODED) + self.assertIsInstance(r.headers, CaseInsensitiveDict) + + def test_signature_method_validation(self): + """Ensure valid signature method is used.""" + + body = ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=%s&' + 'oauth_timestamp=1234567890') + + uri = 'https://example.com/' + + class HMACValidator(RequestValidator): + + @property + def allowed_signature_methods(self): + return (SIGNATURE_HMAC,) + + v = HMACValidator() + e = BaseEndpoint(v) + r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + + class RSAValidator(RequestValidator): + + @property + def allowed_signature_methods(self): + return (SIGNATURE_RSA,) + + v = RSAValidator() + e = BaseEndpoint(v) + r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + + class PlainValidator(RequestValidator): + + @property + def allowed_signature_methods(self): + return (SIGNATURE_PLAINTEXT,) + + v = PlainValidator() + e = BaseEndpoint(v) + r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + + +class ClientValidator(RequestValidator): + clients = ['foo'] + nonces = [('foo', 'once', '1234567891', 'fez')] + owners = {'foo': ['abcdefghijklmnopqrstuvxyz', 'fez']} + assigned_realms = {('foo', 'abcdefghijklmnopqrstuvxyz'): 'photos'} + verifiers = {('foo', 'fez'): 'shibboleth'} + + @property + def client_key_length(self): + return 1, 30 + + @property + def request_token_length(self): + return 1, 30 + + @property + def access_token_length(self): + return 1, 30 + + @property + def nonce_length(self): + return 2, 30 + + @property + def verifier_length(self): + return 2, 30 + + @property + def realms(self): + return ['photos'] + + @property + def timestamp_lifetime(self): + # Disabled check to allow hardcoded verification signatures + return 1000000000 + + @property + def dummy_client(self): + return 'dummy' + + @property + def dummy_request_token(self): + return 'dumbo' + + @property + def dummy_access_token(self): + return 'dumbo' + + def validate_timestamp_and_nonce(self, client_key, timestamp, nonce, + request, request_token=None, access_token=None): + resource_owner_key = request_token if request_token else access_token + return not (client_key, nonce, timestamp, resource_owner_key) in self.nonces + + def validate_client_key(self, client_key): + return client_key in self.clients + + def validate_access_token(self, client_key, access_token, request): + return (self.owners.get(client_key) and + access_token in self.owners.get(client_key)) + + def validate_request_token(self, client_key, request_token, request): + return (self.owners.get(client_key) and + request_token in self.owners.get(client_key)) + + def validate_requested_realm(self, client_key, realm, request): + return True + + def validate_realm(self, client_key, access_token, request, uri=None, + required_realm=None): + return (client_key, access_token) in self.assigned_realms + + def validate_verifier(self, client_key, request_token, verifier, + request): + return ((client_key, request_token) in self.verifiers and + safe_string_equals(verifier, self.verifiers.get( + (client_key, request_token)))) + + def validate_redirect_uri(self, client_key, redirect_uri, request): + return redirect_uri.startswith('http://client.example.com/') + + def get_client_secret(self, client_key, request): + return 'super secret' + + def get_access_token_secret(self, client_key, access_token, request): + return 'even more secret' + + def get_request_token_secret(self, client_key, request_token, request): + return 'even more secret' + + def get_rsa_key(self, client_key, request): + return ("-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNA" + "DCBiQKBgQDVLQCATX8iK+aZuGVdkGb6uiar\nLi/jqFwL1dYj0JLIsdQc" + "KaMWtPC06K0+vI+RRZcjKc6sNB9/7kJcKN9Ekc9BUxyT\n/D09Cz47cmC" + "YsUoiW7G8NSqbE4wPiVpGkJRzFAxaCWwOSSQ+lpC9vwxnvVQfOoZ1\nnp" + "mWbCdA0iTxsMahwQIDAQAB\n-----END PUBLIC KEY-----") + + +class SignatureVerificationTest(TestCase): + + def setUp(self): + v = ClientValidator() + self.e = BaseEndpoint(v) + + self.uri = 'https://example.com/' + self.sig = ('oauth_signature=%s&' + 'oauth_timestamp=1234567890&' + 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&' + 'oauth_version=1.0&' + 'oauth_signature_method=%s&' + 'oauth_token=abcdefghijklmnopqrstuvxyz&' + 'oauth_consumer_key=foo') + + def test_signature_too_short(self): + short_sig = ('oauth_signature=fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY&' + 'oauth_timestamp=1234567890&' + 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&' + 'oauth_version=1.0&oauth_signature_method=HMAC-SHA1&' + 'oauth_token=abcdefghijklmnopqrstuvxyz&' + 'oauth_consumer_key=foo') + r = self.e._create_request(self.uri, 'GET', short_sig, URLENCODED) + self.assertFalse(self.e._check_signature(r)) + + plain = ('oauth_signature=correctlengthbutthewrongcontent1111&' + 'oauth_timestamp=1234567890&' + 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&' + 'oauth_version=1.0&oauth_signature_method=PLAINTEXT&' + 'oauth_token=abcdefghijklmnopqrstuvxyz&' + 'oauth_consumer_key=foo') + r = self.e._create_request(self.uri, 'GET', plain, URLENCODED) + self.assertFalse(self.e._check_signature(r)) + + def test_hmac_signature(self): + hmac_sig = "fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY%3D" + sig = self.sig % (hmac_sig, "HMAC-SHA1") + r = self.e._create_request(self.uri, 'GET', sig, URLENCODED) + self.assertTrue(self.e._check_signature(r)) + + def test_rsa_signature(self): + rsa_sig = ("fxFvCx33oKlR9wDquJ%2FPsndFzJphyBa3RFPPIKi3flqK%2BJ7yIrMVbH" + "YTM%2FLHPc7NChWz4F4%2FzRA%2BDN1k08xgYGSBoWJUOW6VvOQ6fbYhMA" + "FkOGYbuGDbje487XMzsAcv6ZjqZHCROSCk5vofgLk2SN7RZ3OrgrFzf4in" + "xetClqA%3D") + sig = self.sig % (rsa_sig, "RSA-SHA1") + r = self.e._create_request(self.uri, 'GET', sig, URLENCODED) + self.assertTrue(self.e._check_signature(r)) + + def test_plaintext_signature(self): + plain_sig = "super%252520secret%26even%252520more%252520secret" + sig = self.sig % (plain_sig, "PLAINTEXT") + r = self.e._create_request(self.uri, 'GET', sig, URLENCODED) + self.assertTrue(self.e._check_signature(r)) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_request_token.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_request_token.py new file mode 100644 index 0000000000..879cad2f48 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_request_token.py @@ -0,0 +1,90 @@ +from unittest.mock import ANY, MagicMock + +from oauthlib.oauth1 import RequestValidator +from oauthlib.oauth1.rfc5849 import Client +from oauthlib.oauth1.rfc5849.endpoints import RequestTokenEndpoint + +from tests.unittest import TestCase + + +class RequestTokenEndpointTest(TestCase): + + def setUp(self): + self.validator = MagicMock(wraps=RequestValidator()) + self.validator.check_client_key.return_value = True + self.validator.allowed_signature_methods = ['HMAC-SHA1'] + self.validator.get_client_secret.return_value = 'bar' + self.validator.get_default_realms.return_value = ['foo'] + self.validator.timestamp_lifetime = 600 + self.validator.check_realms.return_value = True + self.validator.validate_client_key.return_value = True + self.validator.validate_requested_realms.return_value = True + self.validator.validate_redirect_uri.return_value = True + self.validator.validate_timestamp_and_nonce.return_value = True + self.validator.dummy_client = 'dummy' + self.validator.dummy_secret = 'dummy' + self.validator.save_request_token = MagicMock() + self.endpoint = RequestTokenEndpoint(self.validator) + self.client = Client('foo', client_secret='bar', realm='foo', + callback_uri='https://c.b/cb') + self.uri, self.headers, self.body = self.client.sign( + 'https://i.b/request_token') + + def test_check_redirect_uri(self): + client = Client('foo') + uri, headers, _ = client.sign(self.uri) + h, b, s = self.endpoint.create_request_token_response( + uri, headers=headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_check_realms(self): + self.validator.check_realms.return_value = False + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_validate_client_key(self): + self.validator.validate_client_key.return_value = False + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 401) + + def test_validate_realms(self): + self.validator.validate_requested_realms.return_value = False + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 401) + + def test_validate_redirect_uri(self): + self.validator.validate_redirect_uri.return_value = False + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 401) + + def test_validate_signature(self): + client = Client('foo', callback_uri='https://c.b/cb') + _, headers, _ = client.sign(self.uri + '/extra') + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=headers) + self.assertEqual(s, 401) + + def test_valid_request(self): + h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 200) + self.assertIn('oauth_token', b) + self.validator.validate_timestamp_and_nonce.assert_called_once_with( + self.client.client_key, ANY, ANY, ANY, + request_token=self.client.resource_owner_key) + + def test_uri_provided_realm(self): + client = Client('foo', callback_uri='https://c.b/cb', + client_secret='bar') + uri = self.uri + '?realm=foo' + _, headers, _ = client.sign(uri) + h, b, s = self.endpoint.create_request_token_response( + uri, headers=headers) + self.assertEqual(s, 200) + self.assertIn('oauth_token', b) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_resource.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_resource.py new file mode 100644 index 0000000000..416216f737 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_resource.py @@ -0,0 +1,102 @@ +from unittest.mock import ANY, MagicMock + +from oauthlib.oauth1 import RequestValidator +from oauthlib.oauth1.rfc5849 import Client +from oauthlib.oauth1.rfc5849.endpoints import ResourceEndpoint + +from tests.unittest import TestCase + + +class ResourceEndpointTest(TestCase): + + def setUp(self): + self.validator = MagicMock(wraps=RequestValidator()) + self.validator.check_client_key.return_value = True + self.validator.check_access_token.return_value = True + self.validator.allowed_signature_methods = ['HMAC-SHA1'] + self.validator.get_client_secret.return_value = 'bar' + self.validator.get_access_token_secret.return_value = 'secret' + self.validator.timestamp_lifetime = 600 + self.validator.validate_client_key.return_value = True + self.validator.validate_access_token.return_value = True + self.validator.validate_timestamp_and_nonce.return_value = True + self.validator.validate_realms.return_value = True + self.validator.dummy_client = 'dummy' + self.validator.dummy_secret = 'dummy' + self.validator.dummy_access_token = 'dummy' + self.endpoint = ResourceEndpoint(self.validator) + self.client = Client('foo', + client_secret='bar', + resource_owner_key='token', + resource_owner_secret='secret') + self.uri, self.headers, self.body = self.client.sign( + 'https://i.b/protected_resource') + + def test_missing_parameters(self): + self.validator.check_access_token.return_value = False + v, r = self.endpoint.validate_protected_resource_request( + self.uri) + self.assertFalse(v) + + def test_check_access_token(self): + self.validator.check_access_token.return_value = False + v, r = self.endpoint.validate_protected_resource_request( + self.uri, headers=self.headers) + self.assertFalse(v) + + def test_validate_client_key(self): + self.validator.validate_client_key.return_value = False + v, r = self.endpoint.validate_protected_resource_request( + self.uri, headers=self.headers) + self.assertFalse(v) + # the validator log should have `False` values + self.assertFalse(r.validator_log['client']) + self.assertTrue(r.validator_log['realm']) + self.assertTrue(r.validator_log['resource_owner']) + self.assertTrue(r.validator_log['signature']) + + def test_validate_access_token(self): + self.validator.validate_access_token.return_value = False + v, r = self.endpoint.validate_protected_resource_request( + self.uri, headers=self.headers) + self.assertFalse(v) + # the validator log should have `False` values + self.assertTrue(r.validator_log['client']) + self.assertTrue(r.validator_log['realm']) + self.assertFalse(r.validator_log['resource_owner']) + self.assertTrue(r.validator_log['signature']) + + def test_validate_realms(self): + self.validator.validate_realms.return_value = False + v, r = self.endpoint.validate_protected_resource_request( + self.uri, headers=self.headers) + self.assertFalse(v) + # the validator log should have `False` values + self.assertTrue(r.validator_log['client']) + self.assertFalse(r.validator_log['realm']) + self.assertTrue(r.validator_log['resource_owner']) + self.assertTrue(r.validator_log['signature']) + + def test_validate_signature(self): + client = Client('foo', + resource_owner_key='token', + resource_owner_secret='secret') + _, headers, _ = client.sign(self.uri + '/extra') + v, r = self.endpoint.validate_protected_resource_request( + self.uri, headers=headers) + self.assertFalse(v) + # the validator log should have `False` values + self.assertTrue(r.validator_log['client']) + self.assertTrue(r.validator_log['realm']) + self.assertTrue(r.validator_log['resource_owner']) + self.assertFalse(r.validator_log['signature']) + + def test_valid_request(self): + v, r = self.endpoint.validate_protected_resource_request( + self.uri, headers=self.headers) + self.assertTrue(v) + self.validator.validate_timestamp_and_nonce.assert_called_once_with( + self.client.client_key, ANY, ANY, ANY, + access_token=self.client.resource_owner_key) + # everything in the validator_log should be `True` + self.assertTrue(all(r.validator_log.items())) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_signature_only.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_signature_only.py new file mode 100644 index 0000000000..16585bd580 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/endpoints/test_signature_only.py @@ -0,0 +1,50 @@ +from unittest.mock import ANY, MagicMock + +from oauthlib.oauth1 import RequestValidator +from oauthlib.oauth1.rfc5849 import Client +from oauthlib.oauth1.rfc5849.endpoints import SignatureOnlyEndpoint + +from tests.unittest import TestCase + + +class SignatureOnlyEndpointTest(TestCase): + + def setUp(self): + self.validator = MagicMock(wraps=RequestValidator()) + self.validator.check_client_key.return_value = True + self.validator.allowed_signature_methods = ['HMAC-SHA1'] + self.validator.get_client_secret.return_value = 'bar' + self.validator.timestamp_lifetime = 600 + self.validator.validate_client_key.return_value = True + self.validator.validate_timestamp_and_nonce.return_value = True + self.validator.dummy_client = 'dummy' + self.validator.dummy_secret = 'dummy' + self.endpoint = SignatureOnlyEndpoint(self.validator) + self.client = Client('foo', client_secret='bar') + self.uri, self.headers, self.body = self.client.sign( + 'https://i.b/protected_resource') + + def test_missing_parameters(self): + v, r = self.endpoint.validate_request( + self.uri) + self.assertFalse(v) + + def test_validate_client_key(self): + self.validator.validate_client_key.return_value = False + v, r = self.endpoint.validate_request( + self.uri, headers=self.headers) + self.assertFalse(v) + + def test_validate_signature(self): + client = Client('foo') + _, headers, _ = client.sign(self.uri + '/extra') + v, r = self.endpoint.validate_request( + self.uri, headers=headers) + self.assertFalse(v) + + def test_valid_request(self): + v, r = self.endpoint.validate_request( + self.uri, headers=self.headers) + self.assertTrue(v) + self.validator.validate_timestamp_and_nonce.assert_called_once_with( + self.client.client_key, ANY, ANY, ANY) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/test_client.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_client.py new file mode 100644 index 0000000000..f7c997f509 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_client.py @@ -0,0 +1,269 @@ +# -*- coding: utf-8 -*- +from oauthlib.common import Request +from oauthlib.oauth1 import ( + SIGNATURE_HMAC_SHA1, SIGNATURE_HMAC_SHA256, SIGNATURE_PLAINTEXT, + SIGNATURE_RSA, SIGNATURE_TYPE_BODY, SIGNATURE_TYPE_QUERY, +) +from oauthlib.oauth1.rfc5849 import Client + +from tests.unittest import TestCase + + +class ClientRealmTests(TestCase): + + def test_client_no_realm(self): + client = Client("client-key") + uri, header, body = client.sign("http://example-uri") + self.assertTrue( + header["Authorization"].startswith('OAuth oauth_nonce=')) + + def test_client_realm_sign_with_default_realm(self): + client = Client("client-key", realm="moo-realm") + self.assertEqual(client.realm, "moo-realm") + uri, header, body = client.sign("http://example-uri") + self.assertTrue( + header["Authorization"].startswith('OAuth realm="moo-realm",')) + + def test_client_realm_sign_with_additional_realm(self): + client = Client("client-key", realm="moo-realm") + uri, header, body = client.sign("http://example-uri", realm="baa-realm") + self.assertTrue( + header["Authorization"].startswith('OAuth realm="baa-realm",')) + # make sure sign() does not override the default realm + self.assertEqual(client.realm, "moo-realm") + + +class ClientConstructorTests(TestCase): + + def test_convert_to_unicode_resource_owner(self): + client = Client('client-key', + resource_owner_key=b'owner key') + self.assertNotIsInstance(client.resource_owner_key, bytes) + self.assertEqual(client.resource_owner_key, 'owner key') + + def test_give_explicit_timestamp(self): + client = Client('client-key', timestamp='1') + params = dict(client.get_oauth_params(Request('http://example.com'))) + self.assertEqual(params['oauth_timestamp'], '1') + + def test_give_explicit_nonce(self): + client = Client('client-key', nonce='1') + params = dict(client.get_oauth_params(Request('http://example.com'))) + self.assertEqual(params['oauth_nonce'], '1') + + def test_decoding(self): + client = Client('client_key', decoding='utf-8') + uri, headers, body = client.sign('http://a.b/path?query', + http_method='POST', body='a=b', + headers={'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertIsInstance(uri, bytes) + self.assertIsInstance(body, bytes) + for k, v in headers.items(): + self.assertIsInstance(k, bytes) + self.assertIsInstance(v, bytes) + + def test_hmac_sha1(self): + client = Client('client_key') + # instance is using the correct signer method + self.assertEqual(Client.SIGNATURE_METHODS[SIGNATURE_HMAC_SHA1], + client.SIGNATURE_METHODS[client.signature_method]) + + def test_hmac_sha256(self): + client = Client('client_key', signature_method=SIGNATURE_HMAC_SHA256) + # instance is using the correct signer method + self.assertEqual(Client.SIGNATURE_METHODS[SIGNATURE_HMAC_SHA256], + client.SIGNATURE_METHODS[client.signature_method]) + + def test_rsa(self): + client = Client('client_key', signature_method=SIGNATURE_RSA) + # instance is using the correct signer method + self.assertEqual(Client.SIGNATURE_METHODS[SIGNATURE_RSA], + client.SIGNATURE_METHODS[client.signature_method]) + # don't need an RSA key to instantiate + self.assertIsNone(client.rsa_key) + + +class SignatureMethodTest(TestCase): + + def test_hmac_sha1_method(self): + client = Client('client_key', timestamp='1234567890', nonce='abc') + u, h, b = client.sign('http://example.com') + correct = ('OAuth oauth_nonce="abc", oauth_timestamp="1234567890", ' + 'oauth_version="1.0", oauth_signature_method="HMAC-SHA1", ' + 'oauth_consumer_key="client_key", ' + 'oauth_signature="hH5BWYVqo7QI4EmPBUUe9owRUUQ%3D"') + self.assertEqual(h['Authorization'], correct) + + def test_hmac_sha256_method(self): + client = Client('client_key', signature_method=SIGNATURE_HMAC_SHA256, + timestamp='1234567890', nonce='abc') + u, h, b = client.sign('http://example.com') + correct = ('OAuth oauth_nonce="abc", oauth_timestamp="1234567890", ' + 'oauth_version="1.0", oauth_signature_method="HMAC-SHA256", ' + 'oauth_consumer_key="client_key", ' + 'oauth_signature="JzgJWBxX664OiMW3WE4MEjtYwOjI%2FpaUWHqtdHe68Es%3D"') + self.assertEqual(h['Authorization'], correct) + + def test_rsa_method(self): + private_key = ( + "-----BEGIN RSA PRIVATE KEY-----\nMIICXgIBAAKBgQDk1/bxy" + "S8Q8jiheHeYYp/4rEKJopeQRRKKpZI4s5i+UPwVpupG\nAlwXWfzXw" + "SMaKPAoKJNdu7tqKRniqst5uoHXw98gj0x7zamu0Ck1LtQ4c7pFMVa" + "h\n5IYGhBi2E9ycNS329W27nJPWNCbESTu7snVlG8V8mfvGGg3xNjT" + "MO7IdrwIDAQAB\nAoGBAOQ2KuH8S5+OrsL4K+wfjoCi6MfxCUyqVU9" + "GxocdM1m30WyWRFMEz2nKJ8fR\np3vTD4w8yplTOhcoXdQZl0kRoaD" + "zrcYkm2VvJtQRrX7dKFT8dR8D/Tr7dNQLOXfC\nDY6xveQczE7qt7V" + "k7lp4FqmxBsaaEuokt78pOOjywZoInjZhAkEA9wz3zoZNT0/i\nrf6" + "qv2qTIeieUB035N3dyw6f1BGSWYaXSuerDCD/J1qZbAPKKhyHZbVaw" + "Ft3UMhe\n542UftBaxQJBAO0iJy1I8GQjGnS7B3yvyH3CcLYGy296+" + "XO/2xKp/d/ty1OIeovx\nC60pLNwuFNF3z9d2GVQAdoQ89hUkOtjZL" + "eMCQQD0JO6oPHUeUjYT+T7ImAv7UKVT\nSuy30sKjLzqoGw1kR+wv7" + "C5PeDRvscs4wa4CW9s6mjSrMDkDrmCLuJDtmf55AkEA\nkmaMg2PNr" + "jUR51F0zOEFycaaqXbGcFwe1/xx9zLmHzMDXd4bsnwt9kk+fe0hQzV" + "S\nJzatanQit3+feev1PN3QewJAWv4RZeavEUhKv+kLe95Yd0su7lT" + "LVduVgh4v5yLT\nGa6FHdjGPcfajt+nrpB1n8UQBEH9ZxniokR/IPv" + "dMlxqXA==\n-----END RSA PRIVATE KEY-----" + ) + client = Client('client_key', signature_method=SIGNATURE_RSA, + rsa_key=private_key, timestamp='1234567890', nonce='abc') + u, h, b = client.sign('http://example.com') + correct = ('OAuth oauth_nonce="abc", oauth_timestamp="1234567890", ' + 'oauth_version="1.0", oauth_signature_method="RSA-SHA1", ' + 'oauth_consumer_key="client_key", ' + 'oauth_signature="ktvzkUhtrIawBcq21DRJrAyysTc3E1Zq5GdGu8EzH' + 'OtbeaCmOBDLGHAcqlm92mj7xp5E1Z6i2vbExPimYAJL7FzkLnkRE5YEJR4' + 'rNtIgAf1OZbYsIUmmBO%2BCLuStuu5Lg3tAluwC7XkkgoXCBaRKT1mUXzP' + 'HJILzZ8iFOvS6w5E%3D"') + self.assertEqual(h['Authorization'], correct) + + def test_plaintext_method(self): + client = Client('client_key', + signature_method=SIGNATURE_PLAINTEXT, + timestamp='1234567890', + nonce='abc', + client_secret='foo', + resource_owner_secret='bar') + u, h, b = client.sign('http://example.com') + correct = ('OAuth oauth_nonce="abc", oauth_timestamp="1234567890", ' + 'oauth_version="1.0", oauth_signature_method="PLAINTEXT", ' + 'oauth_consumer_key="client_key", ' + 'oauth_signature="foo%26bar"') + self.assertEqual(h['Authorization'], correct) + + def test_invalid_method(self): + client = Client('client_key', signature_method='invalid') + self.assertRaises(ValueError, client.sign, 'http://example.com') + + def test_rsa_no_key(self): + client = Client('client_key', signature_method=SIGNATURE_RSA) + self.assertRaises(ValueError, client.sign, 'http://example.com') + + def test_register_method(self): + Client.register_signature_method('PIZZA', + lambda base_string, client: 'PIZZA') + + self.assertIn('PIZZA', Client.SIGNATURE_METHODS) + + client = Client('client_key', signature_method='PIZZA', + timestamp='1234567890', nonce='abc') + + u, h, b = client.sign('http://example.com') + + self.assertEqual(h['Authorization'], ( + 'OAuth oauth_nonce="abc", oauth_timestamp="1234567890", ' + 'oauth_version="1.0", oauth_signature_method="PIZZA", ' + 'oauth_consumer_key="client_key", ' + 'oauth_signature="PIZZA"' + )) + + +class SignatureTypeTest(TestCase): + + def test_params_in_body(self): + client = Client('client_key', signature_type=SIGNATURE_TYPE_BODY, + timestamp='1378988215', nonce='14205877133089081931378988215') + _, h, b = client.sign('http://i.b/path', http_method='POST', body='a=b', + headers={'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertEqual(h['Content-Type'], 'application/x-www-form-urlencoded') + correct = ('a=b&oauth_nonce=14205877133089081931378988215&' + 'oauth_timestamp=1378988215&' + 'oauth_version=1.0&' + 'oauth_signature_method=HMAC-SHA1&' + 'oauth_consumer_key=client_key&' + 'oauth_signature=2JAQomgbShqoscqKWBiYQZwWq94%3D') + self.assertEqual(b, correct) + + def test_params_in_query(self): + client = Client('client_key', signature_type=SIGNATURE_TYPE_QUERY, + timestamp='1378988215', nonce='14205877133089081931378988215') + u, _, _ = client.sign('http://i.b/path', http_method='POST') + correct = ('http://i.b/path?oauth_nonce=14205877133089081931378988215&' + 'oauth_timestamp=1378988215&' + 'oauth_version=1.0&' + 'oauth_signature_method=HMAC-SHA1&' + 'oauth_consumer_key=client_key&' + 'oauth_signature=08G5Snvw%2BgDAzBF%2BCmT5KqlrPKo%3D') + self.assertEqual(u, correct) + + def test_invalid_signature_type(self): + client = Client('client_key', signature_type='invalid') + self.assertRaises(ValueError, client.sign, 'http://i.b/path') + + +class SigningTest(TestCase): + + def test_case_insensitive_headers(self): + client = Client('client_key') + # Uppercase + _, h, _ = client.sign('http://i.b/path', http_method='POST', body='', + headers={'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertEqual(h['Content-Type'], 'application/x-www-form-urlencoded') + + # Lowercase + _, h, _ = client.sign('http://i.b/path', http_method='POST', body='', + headers={'content-type': 'application/x-www-form-urlencoded'}) + self.assertEqual(h['content-type'], 'application/x-www-form-urlencoded') + + # Capitalized + _, h, _ = client.sign('http://i.b/path', http_method='POST', body='', + headers={'Content-type': 'application/x-www-form-urlencoded'}) + self.assertEqual(h['Content-type'], 'application/x-www-form-urlencoded') + + # Random + _, h, _ = client.sign('http://i.b/path', http_method='POST', body='', + headers={'conTent-tYpe': 'application/x-www-form-urlencoded'}) + self.assertEqual(h['conTent-tYpe'], 'application/x-www-form-urlencoded') + + def test_sign_no_body(self): + client = Client('client_key', decoding='utf-8') + self.assertRaises(ValueError, client.sign, 'http://i.b/path', + http_method='POST', body=None, + headers={'Content-Type': 'application/x-www-form-urlencoded'}) + + def test_sign_body(self): + client = Client('client_key') + _, h, b = client.sign('http://i.b/path', http_method='POST', body='', + headers={'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertEqual(h['Content-Type'], 'application/x-www-form-urlencoded') + + def test_sign_get_with_body(self): + client = Client('client_key') + for method in ('GET', 'HEAD'): + self.assertRaises(ValueError, client.sign, 'http://a.b/path?query', + http_method=method, body='a=b', + headers={ + 'Content-Type': 'application/x-www-form-urlencoded' + }) + + def test_sign_unicode(self): + client = Client('client_key', nonce='abc', timestamp='abc') + _, h, b = client.sign('http://i.b/path', http_method='POST', + body='status=%E5%95%A6%E5%95%A6', + headers={'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertEqual(b, 'status=%E5%95%A6%E5%95%A6') + self.assertIn('oauth_signature="yrtSqp88m%2Fc5UDaucI8BXK4oEtk%3D"', h['Authorization']) + _, h, b = client.sign('http://i.b/path', http_method='POST', + body='status=%C3%A6%C3%A5%C3%B8', + headers={'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertEqual(b, 'status=%C3%A6%C3%A5%C3%B8') + self.assertIn('oauth_signature="oG5t3Eg%2FXO5FfQgUUlTtUeeZzvk%3D"', h['Authorization']) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/test_parameters.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_parameters.py new file mode 100644 index 0000000000..92b95c1167 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_parameters.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +from oauthlib.common import urlencode +from oauthlib.oauth1.rfc5849.parameters import ( + _append_params, prepare_form_encoded_body, prepare_headers, + prepare_request_uri_query, +) + +from tests.unittest import TestCase + + +class ParameterTests(TestCase): + auth_only_params = [ + ('oauth_consumer_key', "9djdj82h48djs9d2"), + ('oauth_token', "kkk9d7dh3k39sjv7"), + ('oauth_signature_method', "HMAC-SHA1"), + ('oauth_timestamp', "137131201"), + ('oauth_nonce', "7d8f3e4a"), + ('oauth_signature', "bYT5CMsGcbgUdFHObYMEfcx6bsw=") + ] + auth_and_data = list(auth_only_params) + auth_and_data.append(('data_param_foo', 'foo')) + auth_and_data.append(('data_param_1', '1')) + realm = 'testrealm' + norealm_authorization_header = ' '.join(( + 'OAuth', + 'oauth_consumer_key="9djdj82h48djs9d2",', + 'oauth_token="kkk9d7dh3k39sjv7",', + 'oauth_signature_method="HMAC-SHA1",', + 'oauth_timestamp="137131201",', + 'oauth_nonce="7d8f3e4a",', + 'oauth_signature="bYT5CMsGcbgUdFHObYMEfcx6bsw%3D"', + )) + withrealm_authorization_header = ' '.join(( + 'OAuth', + 'realm="testrealm",', + 'oauth_consumer_key="9djdj82h48djs9d2",', + 'oauth_token="kkk9d7dh3k39sjv7",', + 'oauth_signature_method="HMAC-SHA1",', + 'oauth_timestamp="137131201",', + 'oauth_nonce="7d8f3e4a",', + 'oauth_signature="bYT5CMsGcbgUdFHObYMEfcx6bsw%3D"', + )) + + def test_append_params(self): + unordered_1 = [ + ('oauth_foo', 'foo'), + ('lala', 123), + ('oauth_baz', 'baz'), + ('oauth_bar', 'bar'), ] + unordered_2 = [ + ('teehee', 456), + ('oauth_quux', 'quux'), ] + expected = [ + ('teehee', 456), + ('lala', 123), + ('oauth_quux', 'quux'), + ('oauth_foo', 'foo'), + ('oauth_baz', 'baz'), + ('oauth_bar', 'bar'), ] + self.assertEqual(_append_params(unordered_1, unordered_2), expected) + + def test_prepare_headers(self): + self.assertEqual( + prepare_headers(self.auth_only_params, {}), + {'Authorization': self.norealm_authorization_header}) + self.assertEqual( + prepare_headers(self.auth_only_params, {}, realm=self.realm), + {'Authorization': self.withrealm_authorization_header}) + + def test_prepare_headers_ignore_data(self): + self.assertEqual( + prepare_headers(self.auth_and_data, {}), + {'Authorization': self.norealm_authorization_header}) + self.assertEqual( + prepare_headers(self.auth_and_data, {}, realm=self.realm), + {'Authorization': self.withrealm_authorization_header}) + + def test_prepare_form_encoded_body(self): + existing_body = '' + form_encoded_body = 'data_param_foo=foo&data_param_1=1&oauth_consumer_key=9djdj82h48djs9d2&oauth_token=kkk9d7dh3k39sjv7&oauth_signature_method=HMAC-SHA1&oauth_timestamp=137131201&oauth_nonce=7d8f3e4a&oauth_signature=bYT5CMsGcbgUdFHObYMEfcx6bsw%3D' + self.assertEqual( + urlencode(prepare_form_encoded_body(self.auth_and_data, existing_body)), + form_encoded_body) + + def test_prepare_request_uri_query(self): + url = 'http://notarealdomain.com/foo/bar/baz?some=args&go=here' + request_uri_query = 'http://notarealdomain.com/foo/bar/baz?some=args&go=here&data_param_foo=foo&data_param_1=1&oauth_consumer_key=9djdj82h48djs9d2&oauth_token=kkk9d7dh3k39sjv7&oauth_signature_method=HMAC-SHA1&oauth_timestamp=137131201&oauth_nonce=7d8f3e4a&oauth_signature=bYT5CMsGcbgUdFHObYMEfcx6bsw%3D' + self.assertEqual( + prepare_request_uri_query(self.auth_and_data, url), + request_uri_query) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/test_request_validator.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_request_validator.py new file mode 100644 index 0000000000..8d34415040 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_request_validator.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +from oauthlib.oauth1 import RequestValidator + +from tests.unittest import TestCase + + +class RequestValidatorTests(TestCase): + + def test_not_implemented(self): + v = RequestValidator() + self.assertRaises(NotImplementedError, v.get_client_secret, None, None) + self.assertRaises(NotImplementedError, v.get_request_token_secret, + None, None, None) + self.assertRaises(NotImplementedError, v.get_access_token_secret, + None, None, None) + self.assertRaises(NotImplementedError, lambda: v.dummy_client) + self.assertRaises(NotImplementedError, lambda: v.dummy_request_token) + self.assertRaises(NotImplementedError, lambda: v.dummy_access_token) + self.assertRaises(NotImplementedError, v.get_rsa_key, None, None) + self.assertRaises(NotImplementedError, v.get_default_realms, None, None) + self.assertRaises(NotImplementedError, v.get_realms, None, None) + self.assertRaises(NotImplementedError, v.get_redirect_uri, None, None) + self.assertRaises(NotImplementedError, v.validate_client_key, None, None) + self.assertRaises(NotImplementedError, v.validate_access_token, + None, None, None) + self.assertRaises(NotImplementedError, v.validate_request_token, + None, None, None) + self.assertRaises(NotImplementedError, v.verify_request_token, + None, None) + self.assertRaises(NotImplementedError, v.verify_realms, + None, None, None) + self.assertRaises(NotImplementedError, v.validate_timestamp_and_nonce, + None, None, None, None) + self.assertRaises(NotImplementedError, v.validate_redirect_uri, + None, None, None) + self.assertRaises(NotImplementedError, v.validate_realms, + None, None, None, None, None) + self.assertRaises(NotImplementedError, v.validate_requested_realms, + None, None, None) + self.assertRaises(NotImplementedError, v.validate_verifier, + None, None, None, None) + self.assertRaises(NotImplementedError, v.save_access_token, None, None) + self.assertRaises(NotImplementedError, v.save_request_token, None, None) + self.assertRaises(NotImplementedError, v.save_verifier, + None, None, None) + + def test_check_length(self): + v = RequestValidator() + + for method in (v.check_client_key, v.check_request_token, + v.check_access_token, v.check_nonce, v.check_verifier): + for not_valid in ('tooshort', 'invalid?characters!', + 'thisclientkeyisalittlebittoolong'): + self.assertFalse(method(not_valid)) + for valid in ('itsjustaboutlongenough',): + self.assertTrue(method(valid)) + + def test_check_realms(self): + v = RequestValidator() + self.assertFalse(v.check_realms(['foo'])) + + class FooRealmValidator(RequestValidator): + @property + def realms(self): + return ['foo'] + + v = FooRealmValidator() + self.assertTrue(v.check_realms(['foo'])) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/test_signatures.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_signatures.py new file mode 100644 index 0000000000..2d4735eafd --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_signatures.py @@ -0,0 +1,896 @@ +# -*- coding: utf-8 -*- +from oauthlib.oauth1.rfc5849.signature import ( + base_string_uri, collect_parameters, normalize_parameters, + sign_hmac_sha1_with_client, sign_hmac_sha256_with_client, + sign_hmac_sha512_with_client, sign_plaintext_with_client, + sign_rsa_sha1_with_client, sign_rsa_sha256_with_client, + sign_rsa_sha512_with_client, signature_base_string, verify_hmac_sha1, + verify_hmac_sha256, verify_hmac_sha512, verify_plaintext, verify_rsa_sha1, + verify_rsa_sha256, verify_rsa_sha512, +) + +from tests.unittest import TestCase + +# ################################################################ + +class MockRequest: + """ + Mock of a request used by the verify_* functions. + """ + + def __init__(self, + method: str, + uri_str: str, + params: list, + signature: str): + """ + The params is a list of (name, value) tuples. It is not a dictionary, + because there can be multiple parameters with the same name. + """ + self.uri = uri_str + self.http_method = method + self.params = params + self.signature = signature + + +# ################################################################ + +class MockClient: + """ + Mock of client credentials used by the sign_*_with_client functions. + + For HMAC, set the client_secret and resource_owner_secret. + + For RSA, set the rsa_key to either a PEM formatted PKCS #1 public key or + PEM formatted PKCS #1 private key. + """ + def __init__(self, + client_secret: str = None, + resource_owner_secret: str = None, + rsa_key: str = None): + self.client_secret = client_secret + self.resource_owner_secret = resource_owner_secret + self.rsa_key = rsa_key # used for private or public key: a poor design! + + +# ################################################################ + +class SignatureTests(TestCase): + """ + Unit tests for the oauthlib/oauth1/rfc5849/signature.py module. + + The tests in this class are organised into sections, to test the + functions relating to: + + - Signature base string calculation + - HMAC-based signature methods + - RSA-based signature methods + - PLAINTEXT signature method + + Each section is separated by a comment beginning with "====". + + Those comments have been formatted to remain visible when the code is + collapsed using PyCharm's code folding feature. That is, those section + heading comments do not have any other comment lines around it, so they + don't get collapsed when the contents of the class is collapsed. While + there is a "Sequential comments" option in the code folding configuration, + by default they are folded. + + They all use some/all of the example test vector, defined in the first + section below. + """ + + # ==== Example test vector ======================================= + + eg_signature_base_string =\ + 'POST&http%3A%2F%2Fexample.com%2Frequest&a2%3Dr%2520b%26a3%3D2%2520q' \ + '%26a3%3Da%26b5%3D%253D%25253D%26c%2540%3D%26c2%3D%26oauth_consumer_' \ + 'key%3D9djdj82h48djs9d2%26oauth_nonce%3D7d8f3e4a%26oauth_signature_m' \ + 'ethod%3DHMAC-SHA1%26oauth_timestamp%3D137131201%26oauth_token%3Dkkk' \ + '9d7dh3k39sjv7' + + # The _signature base string_ above is copied from the end of + # RFC 5849 section 3.4.1.1. + # + # It corresponds to the three values below. + # + # The _normalized parameters_ below is copied from the end of + # RFC 5849 section 3.4.1.3.2. + + eg_http_method = 'POST' + + eg_base_string_uri = 'http://example.com/request' + + eg_normalized_parameters =\ + 'a2=r%20b&a3=2%20q&a3=a&b5=%3D%253D&c%40=&c2=&oauth_consumer_key=9dj' \ + 'dj82h48djs9d2&oauth_nonce=7d8f3e4a&oauth_signature_method=HMAC-SHA1' \ + '&oauth_timestamp=137131201&oauth_token=kkk9d7dh3k39sjv7' + + # The above _normalized parameters_ corresponds to the parameters below. + # + # The parameters below is copied from the table at the end of + # RFC 5849 section 3.4.1.3.1. + + eg_params = [ + ('b5', '=%3D'), + ('a3', 'a'), + ('c@', ''), + ('a2', 'r b'), + ('oauth_consumer_key', '9djdj82h48djs9d2'), + ('oauth_token', 'kkk9d7dh3k39sjv7'), + ('oauth_signature_method', 'HMAC-SHA1'), + ('oauth_timestamp', '137131201'), + ('oauth_nonce', '7d8f3e4a'), + ('c2', ''), + ('a3', '2 q'), + ] + + # The above parameters correspond to parameters from the three values below. + # + # These come from RFC 5849 section 3.4.1.3.1. + + eg_uri_query = 'b5=%3D%253D&a3=a&c%40=&a2=r%20b' + + eg_body = 'c2&a3=2+q' + + eg_authorization_header =\ + 'OAuth realm="Example", oauth_consumer_key="9djdj82h48djs9d2",' \ + ' oauth_token="kkk9d7dh3k39sjv7", oauth_signature_method="HMAC-SHA1",' \ + ' oauth_timestamp="137131201", oauth_nonce="7d8f3e4a",' \ + ' oauth_signature="djosJKDKJSD8743243%2Fjdk33klY%3D"' + + # ==== Signature base string calculating function tests ========== + + def test_signature_base_string(self): + """ + Test the ``signature_base_string`` function. + """ + + # Example from RFC 5849 + + self.assertEqual( + self.eg_signature_base_string, + signature_base_string( + self.eg_http_method, + self.eg_base_string_uri, + self.eg_normalized_parameters)) + + # Test method is always uppercase in the signature base string + + for test_method in ['POST', 'Post', 'pOST', 'poST', 'posT', 'post']: + self.assertEqual( + self.eg_signature_base_string, + signature_base_string( + test_method, + self.eg_base_string_uri, + self.eg_normalized_parameters)) + + def test_base_string_uri(self): + """ + Test the ``base_string_uri`` function. + """ + + # ---------------- + # Examples from the OAuth 1.0a specification: RFC 5849. + + # First example from RFC 5849 section 3.4.1.2. + # + # GET /r%20v/X?id=123 HTTP/1.1 + # Host: EXAMPLE.COM:80 + # + # Note: there is a space between "r" and "v" + + self.assertEqual( + 'http://example.com/r%20v/X', + base_string_uri('http://EXAMPLE.COM:80/r v/X?id=123')) + + # Second example from RFC 5849 section 3.4.1.2. + # + # GET /?q=1 HTTP/1.1 + # Host: www.example.net:8080 + + self.assertEqual( + 'https://www.example.net:8080/', + base_string_uri('https://www.example.net:8080/?q=1')) + + # ---------------- + # Scheme: will always be in lowercase + + for uri in [ + 'foobar://www.example.com', + 'FOOBAR://www.example.com', + 'Foobar://www.example.com', + 'FooBar://www.example.com', + 'fOObAR://www.example.com', + ]: + self.assertEqual('foobar://www.example.com/', base_string_uri(uri)) + + # ---------------- + # Host: will always be in lowercase + + for uri in [ + 'http://www.example.com', + 'http://WWW.EXAMPLE.COM', + 'http://www.EXAMPLE.com', + 'http://wWW.eXAMPLE.cOM', + ]: + self.assertEqual('http://www.example.com/', base_string_uri(uri)) + + # base_string_uri has an optional host parameter that can be used to + # override the URI's netloc (or used as the host if there is no netloc) + # The "netloc" refers to the "hostname[:port]" part of the URI. + + self.assertEqual( + 'http://actual.example.com/', + base_string_uri('http://IGNORE.example.com', 'ACTUAL.example.com')) + + self.assertEqual( + 'http://override.example.com/path', + base_string_uri('http:///path', 'OVERRIDE.example.com')) + + # ---------------- + # Host: valid host allows for IPv4 and IPv6 + + self.assertEqual( + 'https://192.168.0.1/', + base_string_uri('https://192.168.0.1') + ) + self.assertEqual( + 'https://192.168.0.1:13000/', + base_string_uri('https://192.168.0.1:13000') + ) + self.assertEqual( + 'https://[123:db8:fd00:1000::5]:13000/', + base_string_uri('https://[123:db8:fd00:1000::5]:13000') + ) + self.assertEqual( + 'https://[123:db8:fd00:1000::5]/', + base_string_uri('https://[123:db8:fd00:1000::5]') + ) + + # ---------------- + # Port: default ports always excluded; non-default ports always included + + self.assertEqual( + "http://www.example.com/", + base_string_uri("http://www.example.com:80/")) # default port + + self.assertEqual( + "https://www.example.com/", + base_string_uri("https://www.example.com:443/")) # default port + + self.assertEqual( + "https://www.example.com:999/", + base_string_uri("https://www.example.com:999/")) # non-default port + + self.assertEqual( + "http://www.example.com:443/", + base_string_uri("HTTP://www.example.com:443/")) # non-default port + + self.assertEqual( + "https://www.example.com:80/", + base_string_uri("HTTPS://www.example.com:80/")) # non-default port + + self.assertEqual( + "http://www.example.com/", + base_string_uri("http://www.example.com:/")) # colon but no number + + # ---------------- + # Paths + + self.assertEqual( + 'http://www.example.com/', + base_string_uri('http://www.example.com')) # no slash + + self.assertEqual( + 'http://www.example.com/', + base_string_uri('http://www.example.com/')) # with slash + + self.assertEqual( + 'http://www.example.com:8080/', + base_string_uri('http://www.example.com:8080')) # no slash + + self.assertEqual( + 'http://www.example.com:8080/', + base_string_uri('http://www.example.com:8080/')) # with slash + + self.assertEqual( + 'http://www.example.com/foo/bar', + base_string_uri('http://www.example.com/foo/bar')) # no slash + self.assertEqual( + 'http://www.example.com/foo/bar/', + base_string_uri('http://www.example.com/foo/bar/')) # with slash + + # ---------------- + # Query parameters & fragment IDs do not appear in the base string URI + + self.assertEqual( + 'https://www.example.com/path', + base_string_uri('https://www.example.com/path?foo=bar')) + + self.assertEqual( + 'https://www.example.com/path', + base_string_uri('https://www.example.com/path#fragment')) + + # ---------------- + # Percent encoding + # + # RFC 5849 does not specify what characters are percent encoded, but in + # one of its examples it shows spaces being percent encoded. + # So it is assumed that spaces must be encoded, but we don't know what + # other characters are encoded or not. + + self.assertEqual( + 'https://www.example.com/hello%20world', + base_string_uri('https://www.example.com/hello world')) + + self.assertEqual( + 'https://www.hello%20world.com/', + base_string_uri('https://www.hello world.com/')) + + # ---------------- + # Errors detected + + # base_string_uri expects a string + self.assertRaises(ValueError, base_string_uri, None) + self.assertRaises(ValueError, base_string_uri, 42) + self.assertRaises(ValueError, base_string_uri, b'http://example.com') + + # Missing scheme is an error + self.assertRaises(ValueError, base_string_uri, '') + self.assertRaises(ValueError, base_string_uri, ' ') # single space + self.assertRaises(ValueError, base_string_uri, 'http') + self.assertRaises(ValueError, base_string_uri, 'example.com') + + # Missing host is an error + self.assertRaises(ValueError, base_string_uri, 'http:') + self.assertRaises(ValueError, base_string_uri, 'http://') + self.assertRaises(ValueError, base_string_uri, 'http://:8080') + + # Port is not a valid TCP/IP port number + self.assertRaises(ValueError, base_string_uri, 'http://eg.com:0') + self.assertRaises(ValueError, base_string_uri, 'http://eg.com:-1') + self.assertRaises(ValueError, base_string_uri, 'http://eg.com:65536') + self.assertRaises(ValueError, base_string_uri, 'http://eg.com:3.14') + self.assertRaises(ValueError, base_string_uri, 'http://eg.com:BAD') + self.assertRaises(ValueError, base_string_uri, 'http://eg.com:NaN') + self.assertRaises(ValueError, base_string_uri, 'http://eg.com: ') + self.assertRaises(ValueError, base_string_uri, 'http://eg.com:42:42') + + def test_collect_parameters(self): + """ + Test the ``collect_parameters`` function. + """ + + # ---------------- + # Examples from the OAuth 1.0a specification: RFC 5849. + + params = collect_parameters( + self.eg_uri_query, + self.eg_body, + {'Authorization': self.eg_authorization_header}) + + # Check params contains the same pairs as control_params, ignoring order + self.assertEqual(sorted(self.eg_params), sorted(params)) + + # ---------------- + # Examples with no parameters + + self.assertEqual([], collect_parameters('', '', {})) + + self.assertEqual([], collect_parameters(None, None, None)) + + self.assertEqual([], collect_parameters()) + + self.assertEqual([], collect_parameters(headers={'foo': 'bar'})) + + # ---------------- + # Test effect of exclude_oauth_signature" + + no_sig = collect_parameters( + headers={'authorization': self.eg_authorization_header}) + with_sig = collect_parameters( + headers={'authorization': self.eg_authorization_header}, + exclude_oauth_signature=False) + + self.assertEqual(sorted(no_sig + [('oauth_signature', + 'djosJKDKJSD8743243/jdk33klY=')]), + sorted(with_sig)) + + # ---------------- + # Test effect of "with_realm" as well as header name case insensitivity + + no_realm = collect_parameters( + headers={'authorization': self.eg_authorization_header}, + with_realm=False) + with_realm = collect_parameters( + headers={'AUTHORIZATION': self.eg_authorization_header}, + with_realm=True) + + self.assertEqual(sorted(no_realm + [('realm', 'Example')]), + sorted(with_realm)) + + def test_normalize_parameters(self): + """ + Test the ``normalize_parameters`` function. + """ + + # headers = {'Authorization': self.authorization_header} + # parameters = collect_parameters( + # uri_query=self.uri_query, body=self.body, headers=headers) + # normalized = normalize_parameters(parameters) + # + # # Unicode everywhere and always + # self.assertIsInstance(normalized, str) + # + # # Lets see if things are in order + # # check to see that querystring keys come in alphanumeric order: + # querystring_keys = ['a2', 'a3', 'b5', 'oauth_consumer_key', + # 'oauth_nonce', 'oauth_signature_method', + # 'oauth_timestamp', 'oauth_token'] + # index = -1 # start at -1 because the 'a2' key starts at index 0 + # for key in querystring_keys: + # self.assertGreater(normalized.index(key), index) + # index = normalized.index(key) + + # ---------------- + # Example from the OAuth 1.0a specification: RFC 5849. + # Params from end of section 3.4.1.3.1. and the expected + # normalized parameters from the end of section 3.4.1.3.2. + + self.assertEqual(self.eg_normalized_parameters, + normalize_parameters(self.eg_params)) + + # ==== HMAC-based signature method tests ========================= + + hmac_client = MockClient( + client_secret='ECrDNoq1VYzzzzzzzzzyAK7TwZNtPnkqatqZZZZ', + resource_owner_secret='just-a-string asdasd') + + # The following expected signatures were calculated by putting the value of + # the eg_signature_base_string in a file ("base-str.txt") and running: + # + # echo -n `cat base-str.txt` | openssl dgst -hmac KEY -sha1 -binary| base64 + # + # Where the KEY is the concatenation of the client_secret, an ampersand and + # the resource_owner_secret. But those values need to be encoded properly, + # so the spaces in the resource_owner_secret must be represented as '%20'. + # + # Note: the "echo -n" is needed to remove the last newline character, which + # most text editors will add. + + expected_signature_hmac_sha1 = \ + 'wsdNmjGB7lvis0UJuPAmjvX/PXw=' + + expected_signature_hmac_sha256 = \ + 'wdfdHUKXHbOnOGZP8WFAWMSAmWzN3EVBWWgXGlC/Eo4=' + + expected_signature_hmac_sha512 = \ + 'u/vlyZFDxOWOZ9UUXwRBJHvq8/T4jCA74ocRmn2ECnjUBTAeJiZIRU8hDTjS88Tz' \ + '1fGONffMpdZxUkUTW3k1kg==' + + def test_sign_hmac_sha1_with_client(self): + """ + Test sign and verify with HMAC-SHA1. + """ + self.assertEqual( + self.expected_signature_hmac_sha1, + sign_hmac_sha1_with_client(self.eg_signature_base_string, + self.hmac_client)) + self.assertTrue(verify_hmac_sha1( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + self.expected_signature_hmac_sha1), + self.hmac_client.client_secret, + self.hmac_client.resource_owner_secret)) + + def test_sign_hmac_sha256_with_client(self): + """ + Test sign and verify with HMAC-SHA256. + """ + self.assertEqual( + self.expected_signature_hmac_sha256, + sign_hmac_sha256_with_client(self.eg_signature_base_string, + self.hmac_client)) + self.assertTrue(verify_hmac_sha256( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + self.expected_signature_hmac_sha256), + self.hmac_client.client_secret, + self.hmac_client.resource_owner_secret)) + + def test_sign_hmac_sha512_with_client(self): + """ + Test sign and verify with HMAC-SHA512. + """ + self.assertEqual( + self.expected_signature_hmac_sha512, + sign_hmac_sha512_with_client(self.eg_signature_base_string, + self.hmac_client)) + self.assertTrue(verify_hmac_sha512( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + self.expected_signature_hmac_sha512), + self.hmac_client.client_secret, + self.hmac_client.resource_owner_secret)) + + def test_hmac_false_positives(self): + """ + Test verify_hmac-* functions will correctly detect invalid signatures. + """ + + _ros = self.hmac_client.resource_owner_secret + + for functions in [ + (sign_hmac_sha1_with_client, verify_hmac_sha1), + (sign_hmac_sha256_with_client, verify_hmac_sha256), + (sign_hmac_sha512_with_client, verify_hmac_sha512), + ]: + signing_function = functions[0] + verify_function = functions[1] + + good_signature = \ + signing_function( + self.eg_signature_base_string, + self.hmac_client) + + bad_signature_on_different_value = \ + signing_function( + 'not the signature base string', + self.hmac_client) + + bad_signature_produced_by_different_client_secret = \ + signing_function( + self.eg_signature_base_string, + MockClient(client_secret='wrong-secret', + resource_owner_secret=_ros)) + bad_signature_produced_by_different_resource_owner_secret = \ + signing_function( + self.eg_signature_base_string, + MockClient(client_secret=self.hmac_client.client_secret, + resource_owner_secret='wrong-secret')) + + bad_signature_produced_with_no_resource_owner_secret = \ + signing_function( + self.eg_signature_base_string, + MockClient(client_secret=self.hmac_client.client_secret)) + bad_signature_produced_with_no_client_secret = \ + signing_function( + self.eg_signature_base_string, + MockClient(resource_owner_secret=_ros)) + + self.assertTrue(verify_function( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + good_signature), + self.hmac_client.client_secret, + self.hmac_client.resource_owner_secret)) + + for bad_signature in [ + '', + 'ZG9uJ3QgdHJ1c3QgbWUK', # random base64 encoded value + 'altérer', # value with a non-ASCII character in it + bad_signature_on_different_value, + bad_signature_produced_by_different_client_secret, + bad_signature_produced_by_different_resource_owner_secret, + bad_signature_produced_with_no_resource_owner_secret, + bad_signature_produced_with_no_client_secret, + ]: + self.assertFalse(verify_function( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + bad_signature), + self.hmac_client.client_secret, + self.hmac_client.resource_owner_secret)) + + # ==== RSA-based signature methods tests ========================= + + rsa_private_client = MockClient(rsa_key=''' +-----BEGIN RSA PRIVATE KEY----- +MIICXgIBAAKBgQDk1/bxyS8Q8jiheHeYYp/4rEKJopeQRRKKpZI4s5i+UPwVpupG +AlwXWfzXwSMaKPAoKJNdu7tqKRniqst5uoHXw98gj0x7zamu0Ck1LtQ4c7pFMVah +5IYGhBi2E9ycNS329W27nJPWNCbESTu7snVlG8V8mfvGGg3xNjTMO7IdrwIDAQAB +AoGBAOQ2KuH8S5+OrsL4K+wfjoCi6MfxCUyqVU9GxocdM1m30WyWRFMEz2nKJ8fR +p3vTD4w8yplTOhcoXdQZl0kRoaDzrcYkm2VvJtQRrX7dKFT8dR8D/Tr7dNQLOXfC +DY6xveQczE7qt7Vk7lp4FqmxBsaaEuokt78pOOjywZoInjZhAkEA9wz3zoZNT0/i +rf6qv2qTIeieUB035N3dyw6f1BGSWYaXSuerDCD/J1qZbAPKKhyHZbVawFt3UMhe +542UftBaxQJBAO0iJy1I8GQjGnS7B3yvyH3CcLYGy296+XO/2xKp/d/ty1OIeovx +C60pLNwuFNF3z9d2GVQAdoQ89hUkOtjZLeMCQQD0JO6oPHUeUjYT+T7ImAv7UKVT +Suy30sKjLzqoGw1kR+wv7C5PeDRvscs4wa4CW9s6mjSrMDkDrmCLuJDtmf55AkEA +kmaMg2PNrjUR51F0zOEFycaaqXbGcFwe1/xx9zLmHzMDXd4bsnwt9kk+fe0hQzVS +JzatanQit3+feev1PN3QewJAWv4RZeavEUhKv+kLe95Yd0su7lTLVduVgh4v5yLT +Ga6FHdjGPcfajt+nrpB1n8UQBEH9ZxniokR/IPvdMlxqXA== +-----END RSA PRIVATE KEY----- +''') + + rsa_public_client = MockClient(rsa_key=''' +-----BEGIN RSA PUBLIC KEY----- +MIGJAoGBAOTX9vHJLxDyOKF4d5hin/isQomil5BFEoqlkjizmL5Q/BWm6kYCXBdZ +/NfBIxoo8Cgok127u2opGeKqy3m6gdfD3yCPTHvNqa7QKTUu1DhzukUxVqHkhgaE +GLYT3Jw1Lfb1bbuck9Y0JsRJO7uydWUbxXyZ+8YaDfE2NMw7sh2vAgMBAAE= +-----END RSA PUBLIC KEY----- +''') + + # The above private key was generated using: + # $ openssl genrsa -out example.pvt 1024 + # $ chmod 600 example.pvt + # Public key was extract from it using: + # $ ssh-keygen -e -m pem -f example.pvt + # PEM encoding requires the key to be concatenated with linebreaks. + + # The following expected signatures were calculated by putting the private + # key in a file (test.pvt) and the value of sig_base_str_rsa in another file + # ("base-str.txt") and running: + # + # echo -n `cat base-str.txt` | openssl dgst -sha1 -sign test.pvt| base64 + # + # Note: the "echo -n" is needed to remove the last newline character, which + # most text editors will add. + + expected_signature_rsa_sha1 = \ + 'mFY2KOEnlYWsTvUA+5kxuBIcvBYXu+ljw9ttVJQxKduMueGSVPCB1tK1PlqVLK738' \ + 'HK0t19ecBJfb6rMxUwrriw+MlBO+jpojkZIWccw1J4cAb4qu4M81DbpUAq4j/1w/Q' \ + 'yTR4TWCODlEfN7Zfgy8+pf+TjiXfIwRC1jEWbuL1E=' + + expected_signature_rsa_sha256 = \ + 'jqKl6m0WS69tiVJV8ZQ6aQEfJqISoZkiPBXRv6Al2+iFSaDpfeXjYm+Hbx6m1azR' \ + 'drZ/35PM3cvuid3LwW/siAkzb0xQcGnTyAPH8YcGWzmnKGY7LsB7fkqThchNxvRK' \ + '/N7s9M1WMnfZZ+1dQbbwtTs1TG1+iexUcV7r3M7Heec=' + + expected_signature_rsa_sha512 = \ + 'jL1CnjlsNd25qoZVHZ2oJft47IRYTjpF5CvCUjL3LY0NTnbEeVhE4amWXUFBe9GL' \ + 'DWdUh/79ZWNOrCirBFIP26cHLApjYdt4ZG7EVK0/GubS2v8wT1QPRsog8zyiMZkm' \ + 'g4JXdWCGXG8YRvRJTg+QKhXuXwS6TcMNakrgzgFIVhA=' + + def test_sign_rsa_sha1_with_client(self): + """ + Test sign and verify with RSA-SHA1. + """ + self.assertEqual( + self.expected_signature_rsa_sha1, + sign_rsa_sha1_with_client(self.eg_signature_base_string, + self.rsa_private_client)) + self.assertTrue(verify_rsa_sha1( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + self.expected_signature_rsa_sha1), + self.rsa_public_client.rsa_key)) + + def test_sign_rsa_sha256_with_client(self): + """ + Test sign and verify with RSA-SHA256. + """ + self.assertEqual( + self.expected_signature_rsa_sha256, + sign_rsa_sha256_with_client(self.eg_signature_base_string, + self.rsa_private_client)) + self.assertTrue(verify_rsa_sha256( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + self.expected_signature_rsa_sha256), + self.rsa_public_client.rsa_key)) + + def test_sign_rsa_sha512_with_client(self): + """ + Test sign and verify with RSA-SHA512. + """ + self.assertEqual( + self.expected_signature_rsa_sha512, + sign_rsa_sha512_with_client(self.eg_signature_base_string, + self.rsa_private_client)) + self.assertTrue(verify_rsa_sha512( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + self.expected_signature_rsa_sha512), + self.rsa_public_client.rsa_key)) + + def test_rsa_false_positives(self): + """ + Test verify_rsa-* functions will correctly detect invalid signatures. + """ + + another_client = MockClient(rsa_key=''' +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQDZcD/1OZNJJ6Y3QZM16Z+O7fkD9kTIQuT2BfpAOUvDfxzYhVC9 +TNmSDHCQhr+ClutyolBk5jTE1/FXFUuHoPsTrkI7KQFXPP834D4gnSY9jrAiUJHe +DVF6wXNuS7H4Ueh16YPjUxgLLRh/nn/JSEj98gsw+7DP01OWMfWS99S7eQIDAQAB +AoGBALsQZRXVyK7BG7CiC8HwEcNnXDpaXmZjlpNKJTenk1THQMvONd4GBZAuf5D3 +PD9fE4R1u/ByVKecmBaxTV+L0TRQfD8K/nbQe0SKRQIkLI2ymLJKC/eyw5iTKT0E ++BS6wYpVd+mfcqgvpHOYpUmz9X8k/eOa7uslFmvt+sDb5ZcBAkEA+++SRqqUxFEG +s/ZWAKw9p5YgkeVUOYVUwyAeZ97heySrjVzg1nZ6v6kv7iOPi9KOEpaIGPW7x1K/ +uQuSt4YEqQJBANzyNqZTTPpv7b/R8ABFy0YMwPVNt3b1GOU1Xxl6iuhH2WcHuueo +UB13JHoZCMZ7hsEqieEz6uteUjdRzRPKclECQFNhVK4iop3emzNQYeJTHwyp+RmQ +JrHq2MTDioyiDUouNsDQbnFMQQ/RtNVB265Q/0hTnbN1ELLFRkK9+87VghECQQC9 +hacLFPk6+TffCp3sHfI3rEj4Iin1iFhKhHWGzW7JwJfjoOXaQK44GDLZ6Q918g+t +MmgDHR2tt8KeYTSgfU+BAkBcaVF91EQ7VXhvyABNYjeYP7lU7orOgdWMa/zbLXSU +4vLsK1WOmwPY9zsXpPkilqszqcru4gzlG462cSbEdAW9 +-----END RSA PRIVATE KEY----- +''') + + for functions in [ + (sign_rsa_sha1_with_client, verify_rsa_sha1), + (sign_rsa_sha256_with_client, verify_rsa_sha256), + (sign_rsa_sha512_with_client, verify_rsa_sha512), + ]: + signing_function = functions[0] + verify_function = functions[1] + + good_signature = \ + signing_function(self.eg_signature_base_string, + self.rsa_private_client) + + bad_signature_on_different_value = \ + signing_function('wrong value signed', self.rsa_private_client) + + bad_signature_produced_by_different_private_key = \ + signing_function(self.eg_signature_base_string, another_client) + + self.assertTrue(verify_function( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + good_signature), + self.rsa_public_client.rsa_key)) + + for bad_signature in [ + '', + 'ZG9uJ3QgdHJ1c3QgbWUK', # random base64 encoded value + 'altérer', # value with a non-ASCII character in it + bad_signature_on_different_value, + bad_signature_produced_by_different_private_key, + ]: + self.assertFalse(verify_function( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + bad_signature), + self.rsa_public_client.rsa_key)) + + def test_rsa_bad_keys(self): + """ + Testing RSA sign and verify with bad key values produces errors. + + This test is useful for coverage tests, since it runs the code branches + that deal with error situations. + """ + + # Signing needs a private key + + for bad_value in [None, '', 'foobar']: + self.assertRaises(ValueError, + sign_rsa_sha1_with_client, + self.eg_signature_base_string, + MockClient(rsa_key=bad_value)) + + self.assertRaises(AttributeError, + sign_rsa_sha1_with_client, + self.eg_signature_base_string, + self.rsa_public_client) # public key doesn't sign + + # Verify needs a public key + + for bad_value in [None, '', 'foobar', self.rsa_private_client.rsa_key]: + self.assertRaises(TypeError, + verify_rsa_sha1, + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + self.expected_signature_rsa_sha1), + MockClient(rsa_key=bad_value)) + + # For completeness, this text could repeat the above for RSA-SHA256 and + # RSA-SHA512 signing and verification functions. + + def test_rsa_jwt_algorithm_cache(self): + # Tests cache of RSAAlgorithm objects is implemented correctly. + + # This is difficult to test, since the cache is internal. + # + # Running this test with coverage will show the cache-hit branch of code + # being executed by two signing operations with the same hash algorithm. + + self.test_sign_rsa_sha1_with_client() # creates cache entry + self.test_sign_rsa_sha1_with_client() # reuses cache entry + + # Some possible bugs will be detected if multiple signing operations + # with different hash algorithms produce the wrong results (e.g. if the + # cache incorrectly returned the previously used algorithm, instead + # of the one that is needed). + + self.test_sign_rsa_sha256_with_client() + self.test_sign_rsa_sha256_with_client() + self.test_sign_rsa_sha1_with_client() + self.test_sign_rsa_sha256_with_client() + self.test_sign_rsa_sha512_with_client() + + # ==== PLAINTEXT signature method tests ========================== + + plaintext_client = hmac_client # for convenience, use the same HMAC secrets + + expected_signature_plaintext = ( + 'ECrDNoq1VYzzzzzzzzzyAK7TwZNtPnkqatqZZZZ' + '&' + 'just-a-string%20%20%20%20asdasd') + + def test_sign_plaintext_with_client(self): + # With PLAINTEXT, the "signature" is always the same: regardless of the + # contents of the request. It is the concatenation of the encoded + # client_secret, an ampersand, and the encoded resource_owner_secret. + # + # That is why the spaces in the resource owner secret are "%20". + + self.assertEqual(self.expected_signature_plaintext, + sign_plaintext_with_client(None, # request is ignored + self.plaintext_client)) + self.assertTrue(verify_plaintext( + MockRequest('PUT', + 'http://example.com/some-other-path', + [('description', 'request is ignored in PLAINTEXT')], + self.expected_signature_plaintext), + self.plaintext_client.client_secret, + self.plaintext_client.resource_owner_secret)) + + def test_plaintext_false_positives(self): + """ + Test verify_plaintext function will correctly detect invalid signatures. + """ + + _ros = self.plaintext_client.resource_owner_secret + + good_signature = \ + sign_plaintext_with_client( + self.eg_signature_base_string, + self.plaintext_client) + + bad_signature_produced_by_different_client_secret = \ + sign_plaintext_with_client( + self.eg_signature_base_string, + MockClient(client_secret='wrong-secret', + resource_owner_secret=_ros)) + bad_signature_produced_by_different_resource_owner_secret = \ + sign_plaintext_with_client( + self.eg_signature_base_string, + MockClient(client_secret=self.plaintext_client.client_secret, + resource_owner_secret='wrong-secret')) + + bad_signature_produced_with_no_resource_owner_secret = \ + sign_plaintext_with_client( + self.eg_signature_base_string, + MockClient(client_secret=self.plaintext_client.client_secret)) + bad_signature_produced_with_no_client_secret = \ + sign_plaintext_with_client( + self.eg_signature_base_string, + MockClient(resource_owner_secret=_ros)) + + self.assertTrue(verify_plaintext( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + good_signature), + self.plaintext_client.client_secret, + self.plaintext_client.resource_owner_secret)) + + for bad_signature in [ + '', + 'ZG9uJ3QgdHJ1c3QgbWUK', # random base64 encoded value + 'altérer', # value with a non-ASCII character in it + bad_signature_produced_by_different_client_secret, + bad_signature_produced_by_different_resource_owner_secret, + bad_signature_produced_with_no_resource_owner_secret, + bad_signature_produced_with_no_client_secret, + ]: + self.assertFalse(verify_plaintext( + MockRequest('POST', + 'http://example.com/request', + self.eg_params, + bad_signature), + self.plaintext_client.client_secret, + self.plaintext_client.resource_owner_secret)) diff --git a/contrib/python/oauthlib/tests/oauth1/rfc5849/test_utils.py b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_utils.py new file mode 100644 index 0000000000..013c71a910 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth1/rfc5849/test_utils.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +from oauthlib.oauth1.rfc5849.utils import * + +from tests.unittest import TestCase + + +class UtilsTests(TestCase): + + sample_params_list = [ + ("notoauth", "shouldnotbehere"), + ("oauth_consumer_key", "9djdj82h48djs9d2"), + ("oauth_token", "kkk9d7dh3k39sjv7"), + ("notoautheither", "shouldnotbehere") + ] + + sample_params_dict = { + "notoauth": "shouldnotbehere", + "oauth_consumer_key": "9djdj82h48djs9d2", + "oauth_token": "kkk9d7dh3k39sjv7", + "notoautheither": "shouldnotbehere" + } + + sample_params_unicode_list = [ + ("notoauth", "shouldnotbehere"), + ("oauth_consumer_key", "9djdj82h48djs9d2"), + ("oauth_token", "kkk9d7dh3k39sjv7"), + ("notoautheither", "shouldnotbehere") + ] + + sample_params_unicode_dict = { + "notoauth": "shouldnotbehere", + "oauth_consumer_key": "9djdj82h48djs9d2", + "oauth_token": "kkk9d7dh3k39sjv7", + "notoautheither": "shouldnotbehere" + } + + authorization_header = """OAuth realm="Example", + oauth_consumer_key="9djdj82h48djs9d2", + oauth_token="kkk9d7dh3k39sjv7", + oauth_signature_method="HMAC-SHA1", + oauth_timestamp="137131201", + oauth_nonce="7d8f3e4a", + oauth_signature="djosJKDKJSD8743243%2Fjdk33klY%3D" """.strip() + bad_authorization_headers = ( + "OAuth", + "OAuth oauth_nonce=", + "Negotiate b2F1dGhsaWI=", + "OA", + ) + + def test_filter_params(self): + + # The following is an isolated test function used to test the filter_params decorator. + @filter_params + def special_test_function(params, realm=None): + """ I am a special test function """ + return 'OAuth ' + ','.join(['='.join([k, v]) for k, v in params]) + + # check that the docstring got through + self.assertEqual(special_test_function.__doc__, " I am a special test function ") + + # Check that the decorator filtering works as per design. + # Any param that does not start with 'oauth' + # should not be present in the filtered params + filtered_params = special_test_function(self.sample_params_list) + self.assertNotIn("notoauth", filtered_params) + self.assertIn("oauth_consumer_key", filtered_params) + self.assertIn("oauth_token", filtered_params) + self.assertNotIn("notoautheither", filtered_params) + + def test_filter_oauth_params(self): + + # try with list + # try with list + # try with list + self.assertEqual(len(self.sample_params_list), 4) + + # Any param that does not start with 'oauth' + # should not be present in the filtered params + filtered_params = filter_oauth_params(self.sample_params_list) + self.assertEqual(len(filtered_params), 2) + + self.assertTrue(filtered_params[0][0].startswith('oauth')) + self.assertTrue(filtered_params[1][0].startswith('oauth')) + + # try with dict + # try with dict + # try with dict + self.assertEqual(len(self.sample_params_dict), 4) + + # Any param that does not start with 'oauth' + # should not be present in the filtered params + filtered_params = filter_oauth_params(self.sample_params_dict) + self.assertEqual(len(filtered_params), 2) + + self.assertTrue(filtered_params[0][0].startswith('oauth')) + self.assertTrue(filtered_params[1][0].startswith('oauth')) + + def test_escape(self): + self.assertRaises(ValueError, escape, b"I am a string type. Not a unicode type.") + self.assertEqual(escape("I am a unicode type."), "I%20am%20a%20unicode%20type.") + self.assertIsInstance(escape("I am a unicode type."), str) + + def test_unescape(self): + self.assertRaises(ValueError, unescape, b"I am a string type. Not a unicode type.") + self.assertEqual(unescape("I%20am%20a%20unicode%20type."), 'I am a unicode type.') + self.assertIsInstance(unescape("I%20am%20a%20unicode%20type."), str) + + def test_parse_authorization_header(self): + # make us some headers + authorization_headers = parse_authorization_header(self.authorization_header) + + # is it a list? + self.assertIsInstance(authorization_headers, list) + + # are the internal items tuples? + for header in authorization_headers: + self.assertIsInstance(header, tuple) + + # are the internal components of each tuple unicode? + for k, v in authorization_headers: + self.assertIsInstance(k, str) + self.assertIsInstance(v, str) + + # let's check the parsed headers created + correct_headers = [ + ("oauth_nonce", "7d8f3e4a"), + ("oauth_timestamp", "137131201"), + ("oauth_consumer_key", "9djdj82h48djs9d2"), + ('oauth_signature', 'djosJKDKJSD8743243%2Fjdk33klY%3D'), + ('oauth_signature_method', 'HMAC-SHA1'), + ('oauth_token', 'kkk9d7dh3k39sjv7'), + ('realm', 'Example')] + self.assertEqual(sorted(authorization_headers), sorted(correct_headers)) + + # Check against malformed headers. + for header in self.bad_authorization_headers: + self.assertRaises(ValueError, parse_authorization_header, header) diff --git a/contrib/python/oauthlib/tests/oauth2/__init__.py b/contrib/python/oauthlib/tests/oauth2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/__init__.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/__init__.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_backend_application.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_backend_application.py new file mode 100644 index 0000000000..c1489ac7c6 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_backend_application.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +import os +from unittest.mock import patch + +from oauthlib import signals +from oauthlib.oauth2 import BackendApplicationClient + +from tests.unittest import TestCase + + +@patch('time.time', new=lambda: 1000) +class BackendApplicationClientTest(TestCase): + + client_id = "someclientid" + client_secret = 'someclientsecret' + scope = ["/profile"] + kwargs = { + "some": "providers", + "require": "extra arguments" + } + + body = "not=empty" + + body_up = "not=empty&grant_type=client_credentials" + body_kwargs = body_up + "&some=providers&require=extra+arguments" + + token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",' + ' "token_type":"example",' + ' "expires_in":3600,' + ' "scope":"/profile",' + ' "example_parameter":"example_value"}') + token = { + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_in": 3600, + "expires_at": 4600, + "scope": ["/profile"], + "example_parameter": "example_value" + } + + def test_request_body(self): + client = BackendApplicationClient(self.client_id) + + # Basic, no extra arguments + body = client.prepare_request_body(body=self.body) + self.assertFormBodyEqual(body, self.body_up) + + rclient = BackendApplicationClient(self.client_id) + body = rclient.prepare_request_body(body=self.body) + self.assertFormBodyEqual(body, self.body_up) + + # With extra parameters + body = client.prepare_request_body(body=self.body, **self.kwargs) + self.assertFormBodyEqual(body, self.body_kwargs) + + def test_parse_token_response(self): + client = BackendApplicationClient(self.client_id) + + # Parse code and state + response = client.parse_request_body_response(self.token_json, scope=self.scope) + self.assertEqual(response, self.token) + self.assertEqual(client.access_token, response.get("access_token")) + self.assertEqual(client.refresh_token, response.get("refresh_token")) + self.assertEqual(client.token_type, response.get("token_type")) + + # Mismatching state + self.assertRaises(Warning, client.parse_request_body_response, self.token_json, scope="invalid") + os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '3' + token = client.parse_request_body_response(self.token_json, scope="invalid") + self.assertTrue(token.scope_changed) + + scope_changes_recorded = [] + def record_scope_change(sender, message, old, new): + scope_changes_recorded.append((message, old, new)) + + signals.scope_changed.connect(record_scope_change) + try: + client.parse_request_body_response(self.token_json, scope="invalid") + self.assertEqual(len(scope_changes_recorded), 1) + message, old, new = scope_changes_recorded[0] + self.assertEqual(message, 'Scope has changed from "invalid" to "/profile".') + self.assertEqual(old, ['invalid']) + self.assertEqual(new, ['/profile']) + finally: + signals.scope_changed.disconnect(record_scope_change) + del os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_base.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_base.py new file mode 100644 index 0000000000..70a22834c3 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_base.py @@ -0,0 +1,355 @@ +# -*- coding: utf-8 -*- +import datetime + +from oauthlib import common +from oauthlib.oauth2 import Client, InsecureTransportError, TokenExpiredError +from oauthlib.oauth2.rfc6749 import utils +from oauthlib.oauth2.rfc6749.clients import AUTH_HEADER, BODY, URI_QUERY + +from tests.unittest import TestCase + + +class ClientTest(TestCase): + + client_id = "someclientid" + uri = "https://example.com/path?query=world" + body = "not=empty" + headers = {} + access_token = "token" + mac_key = "secret" + + bearer_query = uri + "&access_token=" + access_token + bearer_header = { + "Authorization": "Bearer " + access_token + } + bearer_body = body + "&access_token=" + access_token + + mac_00_header = { + "Authorization": 'MAC id="' + access_token + '", nonce="0:abc123",' + + ' bodyhash="Yqyso8r3hR5Nm1ZFv+6AvNHrxjE=",' + + ' mac="0X6aACoBY0G6xgGZVJ1IeE8dF9k="' + } + mac_01_header = { + "Authorization": 'MAC id="' + access_token + '", ts="123456789",' + + ' nonce="abc123", mac="Xuk+9oqaaKyhitkgh1CD0xrI6+s="' + } + + def test_add_bearer_token(self): + """Test a number of bearer token placements""" + + # Invalid token type + client = Client(self.client_id, token_type="invalid") + self.assertRaises(ValueError, client.add_token, self.uri) + + # Case-insensitive token type + client = Client(self.client_id, access_token=self.access_token, token_type="bEAreR") + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers) + self.assertURLEqual(uri, self.uri) + self.assertFormBodyEqual(body, self.body) + self.assertEqual(headers, self.bearer_header) + + # Non-HTTPS + insecure_uri = 'http://example.com/path?query=world' + client = Client(self.client_id, access_token=self.access_token, token_type="Bearer") + self.assertRaises(InsecureTransportError, client.add_token, insecure_uri, + body=self.body, + headers=self.headers) + + # Missing access token + client = Client(self.client_id) + self.assertRaises(ValueError, client.add_token, self.uri) + + # Expired token + expired = 523549800 + expired_token = { + 'expires_at': expired, + } + client = Client(self.client_id, token=expired_token, access_token=self.access_token, token_type="Bearer") + self.assertRaises(TokenExpiredError, client.add_token, self.uri, + body=self.body, headers=self.headers) + + # The default token placement, bearer in auth header + client = Client(self.client_id, access_token=self.access_token) + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers) + self.assertURLEqual(uri, self.uri) + self.assertFormBodyEqual(body, self.body) + self.assertEqual(headers, self.bearer_header) + + # Setting default placements of tokens + client = Client(self.client_id, access_token=self.access_token, + default_token_placement=AUTH_HEADER) + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers) + self.assertURLEqual(uri, self.uri) + self.assertFormBodyEqual(body, self.body) + self.assertEqual(headers, self.bearer_header) + + client = Client(self.client_id, access_token=self.access_token, + default_token_placement=URI_QUERY) + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers) + self.assertURLEqual(uri, self.bearer_query) + self.assertFormBodyEqual(body, self.body) + self.assertEqual(headers, self.headers) + + client = Client(self.client_id, access_token=self.access_token, + default_token_placement=BODY) + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers) + self.assertURLEqual(uri, self.uri) + self.assertFormBodyEqual(body, self.bearer_body) + self.assertEqual(headers, self.headers) + + # Asking for specific placement in the add_token method + client = Client(self.client_id, access_token=self.access_token) + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers, token_placement=AUTH_HEADER) + self.assertURLEqual(uri, self.uri) + self.assertFormBodyEqual(body, self.body) + self.assertEqual(headers, self.bearer_header) + + client = Client(self.client_id, access_token=self.access_token) + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers, token_placement=URI_QUERY) + self.assertURLEqual(uri, self.bearer_query) + self.assertFormBodyEqual(body, self.body) + self.assertEqual(headers, self.headers) + + client = Client(self.client_id, access_token=self.access_token) + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers, token_placement=BODY) + self.assertURLEqual(uri, self.uri) + self.assertFormBodyEqual(body, self.bearer_body) + self.assertEqual(headers, self.headers) + + # Invalid token placement + client = Client(self.client_id, access_token=self.access_token) + self.assertRaises(ValueError, client.add_token, self.uri, body=self.body, + headers=self.headers, token_placement="invalid") + + client = Client(self.client_id, access_token=self.access_token, + default_token_placement="invalid") + self.assertRaises(ValueError, client.add_token, self.uri, body=self.body, + headers=self.headers) + + def test_add_mac_token(self): + # Missing access token + client = Client(self.client_id, token_type="MAC") + self.assertRaises(ValueError, client.add_token, self.uri) + + # Invalid hash algorithm + client = Client(self.client_id, token_type="MAC", + access_token=self.access_token, mac_key=self.mac_key, + mac_algorithm="hmac-sha-2") + self.assertRaises(ValueError, client.add_token, self.uri) + + orig_generate_timestamp = common.generate_timestamp + orig_generate_nonce = common.generate_nonce + orig_generate_age = utils.generate_age + self.addCleanup(setattr, common, 'generage_timestamp', orig_generate_timestamp) + self.addCleanup(setattr, common, 'generage_nonce', orig_generate_nonce) + self.addCleanup(setattr, utils, 'generate_age', orig_generate_age) + common.generate_timestamp = lambda: '123456789' + common.generate_nonce = lambda: 'abc123' + utils.generate_age = lambda *args: 0 + + # Add the Authorization header (draft 00) + client = Client(self.client_id, token_type="MAC", + access_token=self.access_token, mac_key=self.mac_key, + mac_algorithm="hmac-sha-1") + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers, issue_time=datetime.datetime.now()) + self.assertEqual(uri, self.uri) + self.assertEqual(body, self.body) + self.assertEqual(headers, self.mac_00_header) + # Non-HTTPS + insecure_uri = 'http://example.com/path?query=world' + self.assertRaises(InsecureTransportError, client.add_token, insecure_uri, + body=self.body, + headers=self.headers, + issue_time=datetime.datetime.now()) + # Expired Token + expired = 523549800 + expired_token = { + 'expires_at': expired, + } + client = Client(self.client_id, token=expired_token, token_type="MAC", + access_token=self.access_token, mac_key=self.mac_key, + mac_algorithm="hmac-sha-1") + self.assertRaises(TokenExpiredError, client.add_token, self.uri, + body=self.body, + headers=self.headers, + issue_time=datetime.datetime.now()) + + # Add the Authorization header (draft 01) + client = Client(self.client_id, token_type="MAC", + access_token=self.access_token, mac_key=self.mac_key, + mac_algorithm="hmac-sha-1") + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers, draft=1) + self.assertEqual(uri, self.uri) + self.assertEqual(body, self.body) + self.assertEqual(headers, self.mac_01_header) + # Non-HTTPS + insecure_uri = 'http://example.com/path?query=world' + self.assertRaises(InsecureTransportError, client.add_token, insecure_uri, + body=self.body, + headers=self.headers, + draft=1) + # Expired Token + expired = 523549800 + expired_token = { + 'expires_at': expired, + } + client = Client(self.client_id, token=expired_token, token_type="MAC", + access_token=self.access_token, mac_key=self.mac_key, + mac_algorithm="hmac-sha-1") + self.assertRaises(TokenExpiredError, client.add_token, self.uri, + body=self.body, + headers=self.headers, + draft=1) + + def test_revocation_request(self): + client = Client(self.client_id) + + url = 'https://example.com/revoke' + token = 'foobar' + + # Valid request + u, h, b = client.prepare_token_revocation_request(url, token) + self.assertEqual(u, url) + self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertEqual(b, 'token=%s&token_type_hint=access_token' % token) + + # Non-HTTPS revocation endpoint + self.assertRaises(InsecureTransportError, + client.prepare_token_revocation_request, + 'http://example.com/revoke', token) + + + u, h, b = client.prepare_token_revocation_request( + url, token, token_type_hint='refresh_token') + self.assertEqual(u, url) + self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertEqual(b, 'token=%s&token_type_hint=refresh_token' % token) + + # JSONP + u, h, b = client.prepare_token_revocation_request( + url, token, callback='hello.world') + self.assertURLEqual(u, url + '?callback=hello.world&token=%s&token_type_hint=access_token' % token) + self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertEqual(b, '') + + def test_prepare_authorization_request(self): + redirect_url = 'https://example.com/callback/' + scopes = 'read' + auth_url = 'https://example.com/authorize/' + state = 'fake_state' + + client = Client(self.client_id, redirect_url=redirect_url, scope=scopes, state=state) + + # Non-HTTPS + self.assertRaises(InsecureTransportError, + client.prepare_authorization_request, 'http://example.com/authorize/') + + # NotImplementedError + self.assertRaises(NotImplementedError, client.prepare_authorization_request, auth_url) + + def test_prepare_token_request(self): + redirect_url = 'https://example.com/callback/' + scopes = 'read' + token_url = 'https://example.com/token/' + state = 'fake_state' + + client = Client(self.client_id, scope=scopes, state=state) + + # Non-HTTPS + self.assertRaises(InsecureTransportError, + client.prepare_token_request, 'http://example.com/token/') + + # NotImplementedError + self.assertRaises(NotImplementedError, client.prepare_token_request, token_url) + + def test_prepare_refresh_token_request(self): + client = Client(self.client_id) + + url = 'https://example.com/revoke' + token = 'foobar' + scope = 'extra_scope' + + u, h, b = client.prepare_refresh_token_request(url, token) + self.assertEqual(u, url) + self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertFormBodyEqual(b, 'grant_type=refresh_token&refresh_token=%s' % token) + + # Non-HTTPS revocation endpoint + self.assertRaises(InsecureTransportError, + client.prepare_refresh_token_request, + 'http://example.com/revoke', token) + + # provide extra scope + u, h, b = client.prepare_refresh_token_request(url, token, scope=scope) + self.assertEqual(u, url) + self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertFormBodyEqual(b, 'grant_type=refresh_token&scope={}&refresh_token={}'.format(scope, token)) + + # provide scope while init + client = Client(self.client_id, scope=scope) + u, h, b = client.prepare_refresh_token_request(url, token, scope=scope) + self.assertEqual(u, url) + self.assertEqual(h, {'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertFormBodyEqual(b, 'grant_type=refresh_token&scope={}&refresh_token={}'.format(scope, token)) + + def test_parse_token_response_invalid_expires_at(self): + token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",' + ' "token_type":"example",' + ' "expires_at":"2006-01-02T15:04:05Z",' + ' "scope":"/profile",' + ' "example_parameter":"example_value"}') + token = { + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_at": "2006-01-02T15:04:05Z", + "scope": ["/profile"], + "example_parameter": "example_value" + } + + client = Client(self.client_id) + + # Parse code and state + response = client.parse_request_body_response(token_json, scope=["/profile"]) + self.assertEqual(response, token) + self.assertEqual(None, client._expires_at) + self.assertEqual(client.access_token, response.get("access_token")) + self.assertEqual(client.refresh_token, response.get("refresh_token")) + self.assertEqual(client.token_type, response.get("token_type")) + + + def test_create_code_verifier_min_length(self): + client = Client(self.client_id) + length = 43 + code_verifier = client.create_code_verifier(length=length) + self.assertEqual(client.code_verifier, code_verifier) + + def test_create_code_verifier_max_length(self): + client = Client(self.client_id) + length = 128 + code_verifier = client.create_code_verifier(length=length) + self.assertEqual(client.code_verifier, code_verifier) + + def test_create_code_challenge_plain(self): + client = Client(self.client_id) + code_verifier = client.create_code_verifier(length=128) + code_challenge_plain = client.create_code_challenge(code_verifier=code_verifier) + + # if no code_challenge_method specified, code_challenge = code_verifier + self.assertEqual(code_challenge_plain, client.code_verifier) + self.assertEqual(client.code_challenge_method, "plain") + + def test_create_code_challenge_s256(self): + client = Client(self.client_id) + code_verifier = client.create_code_verifier(length=128) + code_challenge_s256 = client.create_code_challenge(code_verifier=code_verifier, code_challenge_method='S256') + self.assertEqual(code_challenge_s256, client.code_challenge) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_legacy_application.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_legacy_application.py new file mode 100644 index 0000000000..b5a18194b7 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_legacy_application.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +import os +import urllib.parse as urlparse +from unittest.mock import patch + +from oauthlib import signals +from oauthlib.oauth2 import LegacyApplicationClient + +from tests.unittest import TestCase + + +@patch('time.time', new=lambda: 1000) +class LegacyApplicationClientTest(TestCase): + + client_id = "someclientid" + client_secret = 'someclientsecret' + scope = ["/profile"] + kwargs = { + "some": "providers", + "require": "extra arguments" + } + + username = "user_username" + password = "user_password" + body = "not=empty" + + body_up = "not=empty&grant_type=password&username={}&password={}".format(username, password) + body_kwargs = body_up + "&some=providers&require=extra+arguments" + + token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",' + ' "token_type":"example",' + ' "expires_in":3600,' + ' "scope":"/profile",' + ' "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter":"example_value"}') + token = { + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_in": 3600, + "expires_at": 4600, + "scope": scope, + "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", + "example_parameter": "example_value" + } + + def test_request_body(self): + client = LegacyApplicationClient(self.client_id) + + # Basic, no extra arguments + body = client.prepare_request_body(self.username, self.password, + body=self.body) + self.assertFormBodyEqual(body, self.body_up) + + # With extra parameters + body = client.prepare_request_body(self.username, self.password, + body=self.body, **self.kwargs) + self.assertFormBodyEqual(body, self.body_kwargs) + + def test_parse_token_response(self): + client = LegacyApplicationClient(self.client_id) + + # Parse code and state + response = client.parse_request_body_response(self.token_json, scope=self.scope) + self.assertEqual(response, self.token) + self.assertEqual(client.access_token, response.get("access_token")) + self.assertEqual(client.refresh_token, response.get("refresh_token")) + self.assertEqual(client.token_type, response.get("token_type")) + + # Mismatching state + self.assertRaises(Warning, client.parse_request_body_response, self.token_json, scope="invalid") + os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '5' + token = client.parse_request_body_response(self.token_json, scope="invalid") + self.assertTrue(token.scope_changed) + + scope_changes_recorded = [] + def record_scope_change(sender, message, old, new): + scope_changes_recorded.append((message, old, new)) + + signals.scope_changed.connect(record_scope_change) + try: + client.parse_request_body_response(self.token_json, scope="invalid") + self.assertEqual(len(scope_changes_recorded), 1) + message, old, new = scope_changes_recorded[0] + self.assertEqual(message, 'Scope has changed from "invalid" to "/profile".') + self.assertEqual(old, ['invalid']) + self.assertEqual(new, ['/profile']) + finally: + signals.scope_changed.disconnect(record_scope_change) + del os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] + + def test_prepare_request_body(self): + """ + see issue #585 + https://github.com/oauthlib/oauthlib/issues/585 + """ + client = LegacyApplicationClient(self.client_id) + + # scenario 1, default behavior to not include `client_id` + r1 = client.prepare_request_body(username=self.username, password=self.password) + self.assertIn(r1, ('grant_type=password&username={}&password={}'.format(self.username, self.password), + 'grant_type=password&password={}&username={}'.format(self.password, self.username), + )) + + # scenario 2, include `client_id` in the body + r2 = client.prepare_request_body(username=self.username, password=self.password, include_client_id=True) + r2_params = dict(urlparse.parse_qsl(r2, keep_blank_values=True)) + self.assertEqual(len(r2_params.keys()), 4) + self.assertEqual(r2_params['grant_type'], 'password') + self.assertEqual(r2_params['username'], self.username) + self.assertEqual(r2_params['password'], self.password) + self.assertEqual(r2_params['client_id'], self.client_id) + + # scenario 3, include `client_id` + `client_secret` in the body + r3 = client.prepare_request_body(username=self.username, password=self.password, include_client_id=True, client_secret=self.client_secret) + r3_params = dict(urlparse.parse_qsl(r3, keep_blank_values=True)) + self.assertEqual(len(r3_params.keys()), 5) + self.assertEqual(r3_params['grant_type'], 'password') + self.assertEqual(r3_params['username'], self.username) + self.assertEqual(r3_params['password'], self.password) + self.assertEqual(r3_params['client_id'], self.client_id) + self.assertEqual(r3_params['client_secret'], self.client_secret) + + # scenario 4, `client_secret` is an empty string + r4 = client.prepare_request_body(username=self.username, password=self.password, include_client_id=True, client_secret='') + r4_params = dict(urlparse.parse_qsl(r4, keep_blank_values=True)) + self.assertEqual(len(r4_params.keys()), 5) + self.assertEqual(r4_params['grant_type'], 'password') + self.assertEqual(r4_params['username'], self.username) + self.assertEqual(r4_params['password'], self.password) + self.assertEqual(r4_params['client_id'], self.client_id) + self.assertEqual(r4_params['client_secret'], '') + + # scenario 4b`,` client_secret is `None` + r4b = client.prepare_request_body(username=self.username, password=self.password, include_client_id=True, client_secret=None) + r4b_params = dict(urlparse.parse_qsl(r4b, keep_blank_values=True)) + self.assertEqual(len(r4b_params.keys()), 4) + self.assertEqual(r4b_params['grant_type'], 'password') + self.assertEqual(r4b_params['username'], self.username) + self.assertEqual(r4b_params['password'], self.password) + self.assertEqual(r4b_params['client_id'], self.client_id) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_mobile_application.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_mobile_application.py new file mode 100644 index 0000000000..c40950c978 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_mobile_application.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +import os +from unittest.mock import patch + +from oauthlib import signals +from oauthlib.oauth2 import MobileApplicationClient + +from tests.unittest import TestCase + + +@patch('time.time', new=lambda: 1000) +class MobileApplicationClientTest(TestCase): + + client_id = "someclientid" + uri = "https://example.com/path?query=world" + uri_id = uri + "&response_type=token&client_id=" + client_id + uri_redirect = uri_id + "&redirect_uri=http%3A%2F%2Fmy.page.com%2Fcallback" + redirect_uri = "http://my.page.com/callback" + scope = ["/profile"] + state = "xyz" + uri_scope = uri_id + "&scope=%2Fprofile" + uri_state = uri_id + "&state=" + state + kwargs = { + "some": "providers", + "require": "extra arguments" + } + uri_kwargs = uri_id + "&some=providers&require=extra+arguments" + + code = "zzzzaaaa" + + response_uri = ('https://client.example.com/cb?#' + 'access_token=2YotnFZFEjr1zCsicMWpAA&' + 'token_type=example&' + 'expires_in=3600&' + 'scope=%2Fprofile&' + 'example_parameter=example_value') + token = { + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_in": 3600, + "expires_at": 4600, + "scope": scope, + "example_parameter": "example_value" + } + + def test_implicit_token_uri(self): + client = MobileApplicationClient(self.client_id) + + # Basic, no extra arguments + uri = client.prepare_request_uri(self.uri) + self.assertURLEqual(uri, self.uri_id) + + # With redirection uri + uri = client.prepare_request_uri(self.uri, redirect_uri=self.redirect_uri) + self.assertURLEqual(uri, self.uri_redirect) + + # With scope + uri = client.prepare_request_uri(self.uri, scope=self.scope) + self.assertURLEqual(uri, self.uri_scope) + + # With state + uri = client.prepare_request_uri(self.uri, state=self.state) + self.assertURLEqual(uri, self.uri_state) + + # With extra parameters through kwargs + uri = client.prepare_request_uri(self.uri, **self.kwargs) + self.assertURLEqual(uri, self.uri_kwargs) + + def test_populate_attributes(self): + + client = MobileApplicationClient(self.client_id) + + response_uri = (self.response_uri + "&code=EVIL-CODE") + + client.parse_request_uri_response(response_uri, scope=self.scope) + + # We must not accidentally pick up any further security + # credentials at this point. + self.assertIsNone(client.code) + + def test_parse_token_response(self): + client = MobileApplicationClient(self.client_id) + + # Parse code and state + response = client.parse_request_uri_response(self.response_uri, scope=self.scope) + self.assertEqual(response, self.token) + self.assertEqual(client.access_token, response.get("access_token")) + self.assertEqual(client.refresh_token, response.get("refresh_token")) + self.assertEqual(client.token_type, response.get("token_type")) + + # Mismatching scope + self.assertRaises(Warning, client.parse_request_uri_response, self.response_uri, scope="invalid") + os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '4' + token = client.parse_request_uri_response(self.response_uri, scope='invalid') + self.assertTrue(token.scope_changed) + + scope_changes_recorded = [] + def record_scope_change(sender, message, old, new): + scope_changes_recorded.append((message, old, new)) + + signals.scope_changed.connect(record_scope_change) + try: + client.parse_request_uri_response(self.response_uri, scope="invalid") + self.assertEqual(len(scope_changes_recorded), 1) + message, old, new = scope_changes_recorded[0] + self.assertEqual(message, 'Scope has changed from "invalid" to "/profile".') + self.assertEqual(old, ['invalid']) + self.assertEqual(new, ['/profile']) + finally: + signals.scope_changed.disconnect(record_scope_change) + del os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_service_application.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_service_application.py new file mode 100644 index 0000000000..b97d8554ed --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_service_application.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +import os +from time import time +from unittest.mock import patch + +import jwt + +from oauthlib.common import Request +from oauthlib.oauth2 import ServiceApplicationClient + +from tests.unittest import TestCase + + +class ServiceApplicationClientTest(TestCase): + + gt = ServiceApplicationClient.grant_type + + private_key = """ +-----BEGIN RSA PRIVATE KEY----- +MIICXgIBAAKBgQDk1/bxyS8Q8jiheHeYYp/4rEKJopeQRRKKpZI4s5i+UPwVpupG +AlwXWfzXwSMaKPAoKJNdu7tqKRniqst5uoHXw98gj0x7zamu0Ck1LtQ4c7pFMVah +5IYGhBi2E9ycNS329W27nJPWNCbESTu7snVlG8V8mfvGGg3xNjTMO7IdrwIDAQAB +AoGBAOQ2KuH8S5+OrsL4K+wfjoCi6MfxCUyqVU9GxocdM1m30WyWRFMEz2nKJ8fR +p3vTD4w8yplTOhcoXdQZl0kRoaDzrcYkm2VvJtQRrX7dKFT8dR8D/Tr7dNQLOXfC +DY6xveQczE7qt7Vk7lp4FqmxBsaaEuokt78pOOjywZoInjZhAkEA9wz3zoZNT0/i +rf6qv2qTIeieUB035N3dyw6f1BGSWYaXSuerDCD/J1qZbAPKKhyHZbVawFt3UMhe +542UftBaxQJBAO0iJy1I8GQjGnS7B3yvyH3CcLYGy296+XO/2xKp/d/ty1OIeovx +C60pLNwuFNF3z9d2GVQAdoQ89hUkOtjZLeMCQQD0JO6oPHUeUjYT+T7ImAv7UKVT +Suy30sKjLzqoGw1kR+wv7C5PeDRvscs4wa4CW9s6mjSrMDkDrmCLuJDtmf55AkEA +kmaMg2PNrjUR51F0zOEFycaaqXbGcFwe1/xx9zLmHzMDXd4bsnwt9kk+fe0hQzVS +JzatanQit3+feev1PN3QewJAWv4RZeavEUhKv+kLe95Yd0su7lTLVduVgh4v5yLT +Ga6FHdjGPcfajt+nrpB1n8UQBEH9ZxniokR/IPvdMlxqXA== +-----END RSA PRIVATE KEY----- +""" + + public_key = """ +-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDk1/bxyS8Q8jiheHeYYp/4rEKJ +opeQRRKKpZI4s5i+UPwVpupGAlwXWfzXwSMaKPAoKJNdu7tqKRniqst5uoHXw98g +j0x7zamu0Ck1LtQ4c7pFMVah5IYGhBi2E9ycNS329W27nJPWNCbESTu7snVlG8V8 +mfvGGg3xNjTMO7IdrwIDAQAB +-----END PUBLIC KEY----- +""" + + subject = 'resource-owner@provider.com' + + issuer = 'the-client@provider.com' + + audience = 'https://provider.com/token' + + client_id = "someclientid" + scope = ["/profile"] + kwargs = { + "some": "providers", + "require": "extra arguments" + } + + body = "isnot=empty" + + body_up = "not=empty&grant_type=%s" % gt + body_kwargs = body_up + "&some=providers&require=extra+arguments" + + token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",' + ' "token_type":"example",' + ' "expires_in":3600,' + ' "scope":"/profile",' + ' "example_parameter":"example_value"}') + token = { + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_in": 3600, + "scope": ["/profile"], + "example_parameter": "example_value" + } + + @patch('time.time') + def test_request_body(self, t): + t.return_value = time() + self.token['expires_at'] = self.token['expires_in'] + t.return_value + + client = ServiceApplicationClient( + self.client_id, private_key=self.private_key) + + # Basic with min required params + body = client.prepare_request_body(issuer=self.issuer, + subject=self.subject, + audience=self.audience, + body=self.body) + r = Request('https://a.b', body=body) + self.assertEqual(r.isnot, 'empty') + self.assertEqual(r.grant_type, ServiceApplicationClient.grant_type) + + claim = jwt.decode(r.assertion, self.public_key, audience=self.audience, algorithms=['RS256']) + + self.assertEqual(claim['iss'], self.issuer) + # audience verification is handled during decode now + self.assertEqual(claim['sub'], self.subject) + self.assertEqual(claim['iat'], int(t.return_value)) + self.assertNotIn('nbf', claim) + self.assertNotIn('jti', claim) + + # Missing issuer parameter + self.assertRaises(ValueError, client.prepare_request_body, + issuer=None, subject=self.subject, audience=self.audience, body=self.body) + + # Missing subject parameter + self.assertRaises(ValueError, client.prepare_request_body, + issuer=self.issuer, subject=None, audience=self.audience, body=self.body) + + # Missing audience parameter + self.assertRaises(ValueError, client.prepare_request_body, + issuer=self.issuer, subject=self.subject, audience=None, body=self.body) + + # Optional kwargs + not_before = time() - 3600 + jwt_id = '8zd15df4s35f43sd' + body = client.prepare_request_body(issuer=self.issuer, + subject=self.subject, + audience=self.audience, + body=self.body, + not_before=not_before, + jwt_id=jwt_id) + + r = Request('https://a.b', body=body) + self.assertEqual(r.isnot, 'empty') + self.assertEqual(r.grant_type, ServiceApplicationClient.grant_type) + + claim = jwt.decode(r.assertion, self.public_key, audience=self.audience, algorithms=['RS256']) + + self.assertEqual(claim['iss'], self.issuer) + # audience verification is handled during decode now + self.assertEqual(claim['sub'], self.subject) + self.assertEqual(claim['iat'], int(t.return_value)) + self.assertEqual(claim['nbf'], not_before) + self.assertEqual(claim['jti'], jwt_id) + + @patch('time.time') + def test_request_body_no_initial_private_key(self, t): + t.return_value = time() + self.token['expires_at'] = self.token['expires_in'] + t.return_value + + client = ServiceApplicationClient( + self.client_id, private_key=None) + + # Basic with private key provided + body = client.prepare_request_body(issuer=self.issuer, + subject=self.subject, + audience=self.audience, + body=self.body, + private_key=self.private_key) + r = Request('https://a.b', body=body) + self.assertEqual(r.isnot, 'empty') + self.assertEqual(r.grant_type, ServiceApplicationClient.grant_type) + + claim = jwt.decode(r.assertion, self.public_key, audience=self.audience, algorithms=['RS256']) + + self.assertEqual(claim['iss'], self.issuer) + # audience verification is handled during decode now + self.assertEqual(claim['sub'], self.subject) + self.assertEqual(claim['iat'], int(t.return_value)) + + # No private key provided + self.assertRaises(ValueError, client.prepare_request_body, + issuer=self.issuer, subject=self.subject, audience=self.audience, body=self.body) + + @patch('time.time') + def test_parse_token_response(self, t): + t.return_value = time() + self.token['expires_at'] = self.token['expires_in'] + t.return_value + + client = ServiceApplicationClient(self.client_id) + + # Parse code and state + response = client.parse_request_body_response(self.token_json, scope=self.scope) + self.assertEqual(response, self.token) + self.assertEqual(client.access_token, response.get("access_token")) + self.assertEqual(client.refresh_token, response.get("refresh_token")) + self.assertEqual(client.token_type, response.get("token_type")) + + # Mismatching state + self.assertRaises(Warning, client.parse_request_body_response, self.token_json, scope="invalid") + os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '2' + token = client.parse_request_body_response(self.token_json, scope="invalid") + self.assertTrue(token.scope_changed) + del os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_web_application.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_web_application.py new file mode 100644 index 0000000000..7a71121512 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/clients/test_web_application.py @@ -0,0 +1,269 @@ +# -*- coding: utf-8 -*- +import os +import urllib.parse as urlparse +import warnings +from unittest.mock import patch + +from oauthlib import common, signals +from oauthlib.oauth2 import ( + BackendApplicationClient, Client, LegacyApplicationClient, + MobileApplicationClient, WebApplicationClient, +) +from oauthlib.oauth2.rfc6749 import errors, utils +from oauthlib.oauth2.rfc6749.clients import AUTH_HEADER, BODY, URI_QUERY + +from tests.unittest import TestCase + + +@patch('time.time', new=lambda: 1000) +class WebApplicationClientTest(TestCase): + + client_id = "someclientid" + client_secret = 'someclientsecret' + uri = "https://example.com/path?query=world" + uri_id = uri + "&response_type=code&client_id=" + client_id + uri_redirect = uri_id + "&redirect_uri=http%3A%2F%2Fmy.page.com%2Fcallback" + redirect_uri = "http://my.page.com/callback" + code_verifier = "code_verifier" + scope = ["/profile"] + state = "xyz" + code_challenge = "code_challenge" + code_challenge_method = "S256" + uri_scope = uri_id + "&scope=%2Fprofile" + uri_state = uri_id + "&state=" + state + uri_code_challenge = uri_id + "&code_challenge=" + code_challenge + "&code_challenge_method=" + code_challenge_method + uri_code_challenge_method = uri_id + "&code_challenge=" + code_challenge + "&code_challenge_method=plain" + kwargs = { + "some": "providers", + "require": "extra arguments" + } + uri_kwargs = uri_id + "&some=providers&require=extra+arguments" + uri_authorize_code = uri_redirect + "&scope=%2Fprofile&state=" + state + + code = "zzzzaaaa" + body = "not=empty" + + body_code = "not=empty&grant_type=authorization_code&code={}&client_id={}".format(code, client_id) + body_redirect = body_code + "&redirect_uri=http%3A%2F%2Fmy.page.com%2Fcallback" + body_code_verifier = body_code + "&code_verifier=code_verifier" + body_kwargs = body_code + "&some=providers&require=extra+arguments" + + response_uri = "https://client.example.com/cb?code=zzzzaaaa&state=xyz" + response = {"code": "zzzzaaaa", "state": "xyz"} + + token_json = ('{ "access_token":"2YotnFZFEjr1zCsicMWpAA",' + ' "token_type":"example",' + ' "expires_in":3600,' + ' "scope":"/profile",' + ' "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter":"example_value"}') + token = { + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_in": 3600, + "expires_at": 4600, + "scope": scope, + "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", + "example_parameter": "example_value" + } + + def test_auth_grant_uri(self): + client = WebApplicationClient(self.client_id) + + # Basic, no extra arguments + uri = client.prepare_request_uri(self.uri) + self.assertURLEqual(uri, self.uri_id) + + # With redirection uri + uri = client.prepare_request_uri(self.uri, redirect_uri=self.redirect_uri) + self.assertURLEqual(uri, self.uri_redirect) + + # With scope + uri = client.prepare_request_uri(self.uri, scope=self.scope) + self.assertURLEqual(uri, self.uri_scope) + + # With state + uri = client.prepare_request_uri(self.uri, state=self.state) + self.assertURLEqual(uri, self.uri_state) + + # with code_challenge and code_challenge_method + uri = client.prepare_request_uri(self.uri, code_challenge=self.code_challenge, code_challenge_method=self.code_challenge_method) + self.assertURLEqual(uri, self.uri_code_challenge) + + # with no code_challenge_method + uri = client.prepare_request_uri(self.uri, code_challenge=self.code_challenge) + self.assertURLEqual(uri, self.uri_code_challenge_method) + + # With extra parameters through kwargs + uri = client.prepare_request_uri(self.uri, **self.kwargs) + self.assertURLEqual(uri, self.uri_kwargs) + + def test_request_body(self): + client = WebApplicationClient(self.client_id, code=self.code) + + # Basic, no extra arguments + body = client.prepare_request_body(body=self.body) + self.assertFormBodyEqual(body, self.body_code) + + rclient = WebApplicationClient(self.client_id) + body = rclient.prepare_request_body(code=self.code, body=self.body) + self.assertFormBodyEqual(body, self.body_code) + + # With redirection uri + body = client.prepare_request_body(body=self.body, redirect_uri=self.redirect_uri) + self.assertFormBodyEqual(body, self.body_redirect) + + # With code verifier + body = client.prepare_request_body(body=self.body, code_verifier=self.code_verifier) + self.assertFormBodyEqual(body, self.body_code_verifier) + + # With extra parameters + body = client.prepare_request_body(body=self.body, **self.kwargs) + self.assertFormBodyEqual(body, self.body_kwargs) + + def test_parse_grant_uri_response(self): + client = WebApplicationClient(self.client_id) + + # Parse code and state + response = client.parse_request_uri_response(self.response_uri, state=self.state) + self.assertEqual(response, self.response) + self.assertEqual(client.code, self.code) + + # Mismatching state + self.assertRaises(errors.MismatchingStateError, + client.parse_request_uri_response, + self.response_uri, + state="invalid") + + def test_populate_attributes(self): + + client = WebApplicationClient(self.client_id) + + response_uri = (self.response_uri + + "&access_token=EVIL-TOKEN" + "&refresh_token=EVIL-TOKEN" + "&mac_key=EVIL-KEY") + + client.parse_request_uri_response(response_uri, self.state) + + self.assertEqual(client.code, self.code) + + # We must not accidentally pick up any further security + # credentials at this point. + self.assertIsNone(client.access_token) + self.assertIsNone(client.refresh_token) + self.assertIsNone(client.mac_key) + + def test_parse_token_response(self): + client = WebApplicationClient(self.client_id) + + # Parse code and state + response = client.parse_request_body_response(self.token_json, scope=self.scope) + self.assertEqual(response, self.token) + self.assertEqual(client.access_token, response.get("access_token")) + self.assertEqual(client.refresh_token, response.get("refresh_token")) + self.assertEqual(client.token_type, response.get("token_type")) + + # Mismatching state + self.assertRaises(Warning, client.parse_request_body_response, self.token_json, scope="invalid") + os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '1' + token = client.parse_request_body_response(self.token_json, scope="invalid") + self.assertTrue(token.scope_changed) + + scope_changes_recorded = [] + def record_scope_change(sender, message, old, new): + scope_changes_recorded.append((message, old, new)) + + signals.scope_changed.connect(record_scope_change) + try: + client.parse_request_body_response(self.token_json, scope="invalid") + self.assertEqual(len(scope_changes_recorded), 1) + message, old, new = scope_changes_recorded[0] + self.assertEqual(message, 'Scope has changed from "invalid" to "/profile".') + self.assertEqual(old, ['invalid']) + self.assertEqual(new, ['/profile']) + finally: + signals.scope_changed.disconnect(record_scope_change) + del os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] + + def test_prepare_authorization_requeset(self): + client = WebApplicationClient(self.client_id) + + url, header, body = client.prepare_authorization_request( + self.uri, redirect_url=self.redirect_uri, state=self.state, scope=self.scope) + self.assertURLEqual(url, self.uri_authorize_code) + # verify default header and body only + self.assertEqual(header, {'Content-Type': 'application/x-www-form-urlencoded'}) + self.assertEqual(body, '') + + def test_prepare_request_body(self): + """ + see issue #585 + https://github.com/oauthlib/oauthlib/issues/585 + + `prepare_request_body` should support the following scenarios: + 1. Include client_id alone in the body (default) + 2. Include client_id and client_secret in auth and not include them in the body (RFC preferred solution) + 3. Include client_id and client_secret in the body (RFC alternative solution) + 4. Include client_id in the body and an empty string for client_secret. + """ + client = WebApplicationClient(self.client_id) + + # scenario 1, default behavior to include `client_id` + r1 = client.prepare_request_body() + self.assertEqual(r1, 'grant_type=authorization_code&client_id=%s' % self.client_id) + + r1b = client.prepare_request_body(include_client_id=True) + self.assertEqual(r1b, 'grant_type=authorization_code&client_id=%s' % self.client_id) + + # scenario 2, do not include `client_id` in the body, so it can be sent in auth. + r2 = client.prepare_request_body(include_client_id=False) + self.assertEqual(r2, 'grant_type=authorization_code') + + # scenario 3, Include client_id and client_secret in the body (RFC alternative solution) + # the order of kwargs being appended is not guaranteed. for brevity, check the 2 permutations instead of sorting + r3 = client.prepare_request_body(client_secret=self.client_secret) + r3_params = dict(urlparse.parse_qsl(r3, keep_blank_values=True)) + self.assertEqual(len(r3_params.keys()), 3) + self.assertEqual(r3_params['grant_type'], 'authorization_code') + self.assertEqual(r3_params['client_id'], self.client_id) + self.assertEqual(r3_params['client_secret'], self.client_secret) + + r3b = client.prepare_request_body(include_client_id=True, client_secret=self.client_secret) + r3b_params = dict(urlparse.parse_qsl(r3b, keep_blank_values=True)) + self.assertEqual(len(r3b_params.keys()), 3) + self.assertEqual(r3b_params['grant_type'], 'authorization_code') + self.assertEqual(r3b_params['client_id'], self.client_id) + self.assertEqual(r3b_params['client_secret'], self.client_secret) + + # scenario 4, `client_secret` is an empty string + r4 = client.prepare_request_body(include_client_id=True, client_secret='') + r4_params = dict(urlparse.parse_qsl(r4, keep_blank_values=True)) + self.assertEqual(len(r4_params.keys()), 3) + self.assertEqual(r4_params['grant_type'], 'authorization_code') + self.assertEqual(r4_params['client_id'], self.client_id) + self.assertEqual(r4_params['client_secret'], '') + + # scenario 4b, `client_secret` is `None` + r4b = client.prepare_request_body(include_client_id=True, client_secret=None) + r4b_params = dict(urlparse.parse_qsl(r4b, keep_blank_values=True)) + self.assertEqual(len(r4b_params.keys()), 2) + self.assertEqual(r4b_params['grant_type'], 'authorization_code') + self.assertEqual(r4b_params['client_id'], self.client_id) + + # scenario Warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") # catch all + + # warning1 - raise a DeprecationWarning if a `client_id` is submitted + rWarnings1 = client.prepare_request_body(client_id=self.client_id) + self.assertEqual(len(w), 1) + self.assertIsInstance(w[0].message, DeprecationWarning) + + # testing the exact warning message in Python2&Python3 is a pain + + # scenario Exceptions + # exception1 - raise a ValueError if the a different `client_id` is submitted + with self.assertRaises(ValueError) as cm: + client.prepare_request_body(client_id='different_client_id') + # testing the exact exception message in Python2&Python3 is a pain diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/__init__.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py new file mode 100644 index 0000000000..b1af6c3306 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +from oauthlib.oauth2 import ( + FatalClientError, OAuth2Error, RequestValidator, Server, +) +from oauthlib.oauth2.rfc6749 import ( + BaseEndpoint, catch_errors_and_unavailability, +) + +from tests.unittest import TestCase + + +class BaseEndpointTest(TestCase): + + def test_default_config(self): + endpoint = BaseEndpoint() + self.assertFalse(endpoint.catch_errors) + self.assertTrue(endpoint.available) + endpoint.catch_errors = True + self.assertTrue(endpoint.catch_errors) + endpoint.available = False + self.assertFalse(endpoint.available) + + def test_error_catching(self): + validator = RequestValidator() + server = Server(validator) + server.catch_errors = True + h, b, s = server.create_token_response( + 'https://example.com', body='grant_type=authorization_code&code=abc' + ) + self.assertIn("server_error", b) + self.assertEqual(s, 500) + + def test_unavailability(self): + validator = RequestValidator() + server = Server(validator) + server.available = False + h, b, s = server.create_authorization_response('https://example.com') + self.assertIn("temporarily_unavailable", b) + self.assertEqual(s, 503) + + def test_wrapper(self): + + class TestServer(Server): + + @catch_errors_and_unavailability + def throw_error(self, uri): + raise ValueError() + + @catch_errors_and_unavailability + def throw_oauth_error(self, uri): + raise OAuth2Error() + + @catch_errors_and_unavailability + def throw_fatal_oauth_error(self, uri): + raise FatalClientError() + + validator = RequestValidator() + server = TestServer(validator) + + server.catch_errors = True + h, b, s = server.throw_error('a') + self.assertIn("server_error", b) + self.assertEqual(s, 500) + + server.available = False + h, b, s = server.throw_error('a') + self.assertIn("temporarily_unavailable", b) + self.assertEqual(s, 503) + + server.available = True + self.assertRaises(OAuth2Error, server.throw_oauth_error, 'a') + self.assertRaises(FatalClientError, server.throw_fatal_oauth_error, 'a') + server.catch_errors = False + self.assertRaises(OAuth2Error, server.throw_oauth_error, 'a') + self.assertRaises(FatalClientError, server.throw_fatal_oauth_error, 'a') diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_client_authentication.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_client_authentication.py new file mode 100644 index 0000000000..0659ee0d25 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_client_authentication.py @@ -0,0 +1,162 @@ +"""Client authentication tests across all endpoints. + +Client authentication in OAuth2 serve two purposes, to authenticate +confidential clients and to ensure public clients are in fact public. The +latter is achieved with authenticate_client_id and the former with +authenticate_client. + +We make sure authentication is done by requiring a client object to be set +on the request object with a client_id parameter. The client_id attribute +prevents this check from being circumvented with a client form parameter. +""" +import json +from unittest import mock + +from oauthlib.oauth2 import ( + BackendApplicationServer, LegacyApplicationServer, MobileApplicationServer, + RequestValidator, WebApplicationServer, +) + +from tests.unittest import TestCase + +from .test_utils import get_fragment_credentials + + +class ClientAuthenticationTest(TestCase): + + def inspect_client(self, request, refresh_token=False): + if not request.client or not request.client.client_id: + raise ValueError() + return 'abc' + + def setUp(self): + self.validator = mock.MagicMock(spec=RequestValidator) + self.validator.is_pkce_required.return_value = False + self.validator.get_code_challenge.return_value = None + self.validator.get_default_redirect_uri.return_value = 'http://i.b./path' + self.web = WebApplicationServer(self.validator, + token_generator=self.inspect_client) + self.mobile = MobileApplicationServer(self.validator, + token_generator=self.inspect_client) + self.legacy = LegacyApplicationServer(self.validator, + token_generator=self.inspect_client) + self.backend = BackendApplicationServer(self.validator, + token_generator=self.inspect_client) + self.token_uri = 'http://example.com/path' + self.auth_uri = 'http://example.com/path?client_id=abc&response_type=token' + # should be base64 but no added value in this unittest + self.basicauth_client_creds = {"Authorization": "john:doe"} + self.basicauth_client_id = {"Authorization": "john:"} + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def set_client_id(self, client_id, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def basicauth_authenticate_client(self, request): + assert "Authorization" in request.headers + assert "john:doe" in request.headers["Authorization"] + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def test_client_id_authentication(self): + token_uri = 'http://example.com/path' + + # authorization code grant + self.validator.authenticate_client.return_value = False + self.validator.authenticate_client_id.return_value = False + _, body, _ = self.web.create_token_response(token_uri, + body='grant_type=authorization_code&code=mock') + self.assertEqual(json.loads(body)['error'], 'invalid_client') + + self.validator.authenticate_client_id.return_value = True + self.validator.authenticate_client.side_effect = self.set_client + _, body, _ = self.web.create_token_response(token_uri, + body='grant_type=authorization_code&code=mock') + self.assertIn('access_token', json.loads(body)) + + # implicit grant + auth_uri = 'http://example.com/path?client_id=abc&response_type=token' + self.assertRaises(ValueError, self.mobile.create_authorization_response, + auth_uri, scopes=['random']) + + self.validator.validate_client_id.side_effect = self.set_client_id + h, _, s = self.mobile.create_authorization_response(auth_uri, scopes=['random']) + self.assertEqual(302, s) + self.assertIn('Location', h) + self.assertIn('access_token', get_fragment_credentials(h['Location'])) + + def test_basicauth_web(self): + self.validator.authenticate_client.side_effect = self.basicauth_authenticate_client + _, body, _ = self.web.create_token_response( + self.token_uri, + body='grant_type=authorization_code&code=mock', + headers=self.basicauth_client_creds + ) + self.assertIn('access_token', json.loads(body)) + + def test_basicauth_legacy(self): + self.validator.authenticate_client.side_effect = self.basicauth_authenticate_client + _, body, _ = self.legacy.create_token_response( + self.token_uri, + body='grant_type=password&username=abc&password=secret', + headers=self.basicauth_client_creds + ) + self.assertIn('access_token', json.loads(body)) + + def test_basicauth_backend(self): + self.validator.authenticate_client.side_effect = self.basicauth_authenticate_client + _, body, _ = self.backend.create_token_response( + self.token_uri, + body='grant_type=client_credentials', + headers=self.basicauth_client_creds + ) + self.assertIn('access_token', json.loads(body)) + + def test_basicauth_revoke(self): + self.validator.authenticate_client.side_effect = self.basicauth_authenticate_client + + # legacy or any other uses the same RevocationEndpoint + _, body, status = self.legacy.create_revocation_response( + self.token_uri, + body='token=foobar', + headers=self.basicauth_client_creds + ) + self.assertEqual(status, 200, body) + + def test_basicauth_introspect(self): + self.validator.authenticate_client.side_effect = self.basicauth_authenticate_client + + # legacy or any other uses the same IntrospectEndpoint + _, body, status = self.legacy.create_introspect_response( + self.token_uri, + body='token=foobar', + headers=self.basicauth_client_creds + ) + self.assertEqual(status, 200, body) + + def test_custom_authentication(self): + token_uri = 'http://example.com/path' + + # authorization code grant + self.assertRaises(NotImplementedError, + self.web.create_token_response, token_uri, + body='grant_type=authorization_code&code=mock') + + # password grant + self.validator.authenticate_client.return_value = True + self.assertRaises(NotImplementedError, + self.legacy.create_token_response, token_uri, + body='grant_type=password&username=abc&password=secret') + + # client credentials grant + self.validator.authenticate_client.return_value = True + self.assertRaises(NotImplementedError, + self.backend.create_token_response, token_uri, + body='grant_type=client_credentials') diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_credentials_preservation.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_credentials_preservation.py new file mode 100644 index 0000000000..32c770ccb7 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_credentials_preservation.py @@ -0,0 +1,128 @@ +"""Ensure credentials are preserved through the authorization. + +The Authorization Code Grant will need to preserve state as well as redirect +uri and the Implicit Grant will need to preserve state. +""" +import json +from unittest import mock + +from oauthlib.oauth2 import ( + MobileApplicationServer, RequestValidator, WebApplicationServer, +) +from oauthlib.oauth2.rfc6749 import errors + +from tests.unittest import TestCase + +from .test_utils import get_fragment_credentials, get_query_credentials + + +class PreservationTest(TestCase): + + DEFAULT_REDIRECT_URI = 'http://i.b./path' + + def setUp(self): + self.validator = mock.MagicMock(spec=RequestValidator) + self.validator.get_default_redirect_uri.return_value = self.DEFAULT_REDIRECT_URI + self.validator.get_code_challenge.return_value = None + self.validator.authenticate_client.side_effect = self.set_client + self.web = WebApplicationServer(self.validator) + self.mobile = MobileApplicationServer(self.validator) + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def test_state_preservation(self): + auth_uri = 'http://example.com/path?state=xyz&client_id=abc&response_type=' + + # authorization grant + h, _, s = self.web.create_authorization_response( + auth_uri + 'code', scopes=['random']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertEqual(get_query_credentials(h['Location'])['state'][0], 'xyz') + + # implicit grant + h, _, s = self.mobile.create_authorization_response( + auth_uri + 'token', scopes=['random']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertEqual(get_fragment_credentials(h['Location'])['state'][0], 'xyz') + + def test_redirect_uri_preservation(self): + auth_uri = 'http://example.com/path?redirect_uri=http%3A%2F%2Fi.b%2Fpath&client_id=abc' + redirect_uri = 'http://i.b/path' + token_uri = 'http://example.com/path' + + # authorization grant + h, _, s = self.web.create_authorization_response( + auth_uri + '&response_type=code', scopes=['random']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertTrue(h['Location'].startswith(redirect_uri)) + + # confirm_redirect_uri should return false if the redirect uri + # was given in the authorization but not in the token request. + self.validator.confirm_redirect_uri.return_value = False + code = get_query_credentials(h['Location'])['code'][0] + _, body, _ = self.web.create_token_response(token_uri, + body='grant_type=authorization_code&code=%s' % code) + self.assertEqual(json.loads(body)['error'], 'invalid_request') + + # implicit grant + h, _, s = self.mobile.create_authorization_response( + auth_uri + '&response_type=token', scopes=['random']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertTrue(h['Location'].startswith(redirect_uri)) + + def test_invalid_redirect_uri(self): + auth_uri = 'http://example.com/path?redirect_uri=http%3A%2F%2Fi.b%2Fpath&client_id=abc' + self.validator.validate_redirect_uri.return_value = False + + # authorization grant + self.assertRaises(errors.MismatchingRedirectURIError, + self.web.create_authorization_response, + auth_uri + '&response_type=code', scopes=['random']) + + # implicit grant + self.assertRaises(errors.MismatchingRedirectURIError, + self.mobile.create_authorization_response, + auth_uri + '&response_type=token', scopes=['random']) + + def test_default_uri(self): + auth_uri = 'http://example.com/path?state=xyz&client_id=abc' + + self.validator.get_default_redirect_uri.return_value = None + + # authorization grant + self.assertRaises(errors.MissingRedirectURIError, + self.web.create_authorization_response, + auth_uri + '&response_type=code', scopes=['random']) + + # implicit grant + self.assertRaises(errors.MissingRedirectURIError, + self.mobile.create_authorization_response, + auth_uri + '&response_type=token', scopes=['random']) + + def test_default_uri_in_token(self): + auth_uri = 'http://example.com/path?state=xyz&client_id=abc' + token_uri = 'http://example.com/path' + + # authorization grant + h, _, s = self.web.create_authorization_response( + auth_uri + '&response_type=code', scopes=['random']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertTrue(h['Location'].startswith(self.DEFAULT_REDIRECT_URI)) + + # confirm_redirect_uri should return true if the redirect uri + # was not given in the authorization AND not in the token request. + self.validator.confirm_redirect_uri.return_value = True + code = get_query_credentials(h['Location'])['code'][0] + self.validator.validate_code.return_value = True + _, body, s = self.web.create_token_response(token_uri, + body='grant_type=authorization_code&code=%s' % code) + self.assertEqual(s, 200) + self.assertEqual(self.validator.confirm_redirect_uri.call_args[0][2], self.DEFAULT_REDIRECT_URI) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_error_responses.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_error_responses.py new file mode 100644 index 0000000000..f61595e213 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_error_responses.py @@ -0,0 +1,491 @@ +"""Ensure the correct error responses are returned for all defined error types. +""" +import json +from unittest import mock + +from oauthlib.common import urlencode +from oauthlib.oauth2 import ( + BackendApplicationServer, LegacyApplicationServer, MobileApplicationServer, + RequestValidator, WebApplicationServer, +) +from oauthlib.oauth2.rfc6749 import errors + +from tests.unittest import TestCase + + +class ErrorResponseTest(TestCase): + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def setUp(self): + self.validator = mock.MagicMock(spec=RequestValidator) + self.validator.get_default_redirect_uri.return_value = None + self.validator.get_code_challenge.return_value = None + self.web = WebApplicationServer(self.validator) + self.mobile = MobileApplicationServer(self.validator) + self.legacy = LegacyApplicationServer(self.validator) + self.backend = BackendApplicationServer(self.validator) + + def test_invalid_redirect_uri(self): + uri = 'https://example.com/authorize?response_type={0}&client_id=foo&redirect_uri=wrong' + + # Authorization code grant + self.assertRaises(errors.InvalidRedirectURIError, + self.web.validate_authorization_request, uri.format('code')) + self.assertRaises(errors.InvalidRedirectURIError, + self.web.create_authorization_response, uri.format('code'), scopes=['foo']) + + # Implicit grant + self.assertRaises(errors.InvalidRedirectURIError, + self.mobile.validate_authorization_request, uri.format('token')) + self.assertRaises(errors.InvalidRedirectURIError, + self.mobile.create_authorization_response, uri.format('token'), scopes=['foo']) + + def test_invalid_default_redirect_uri(self): + uri = 'https://example.com/authorize?response_type={0}&client_id=foo' + self.validator.get_default_redirect_uri.return_value = "wrong" + + # Authorization code grant + self.assertRaises(errors.InvalidRedirectURIError, + self.web.validate_authorization_request, uri.format('code')) + self.assertRaises(errors.InvalidRedirectURIError, + self.web.create_authorization_response, uri.format('code'), scopes=['foo']) + + # Implicit grant + self.assertRaises(errors.InvalidRedirectURIError, + self.mobile.validate_authorization_request, uri.format('token')) + self.assertRaises(errors.InvalidRedirectURIError, + self.mobile.create_authorization_response, uri.format('token'), scopes=['foo']) + + def test_missing_redirect_uri(self): + uri = 'https://example.com/authorize?response_type={0}&client_id=foo' + + # Authorization code grant + self.assertRaises(errors.MissingRedirectURIError, + self.web.validate_authorization_request, uri.format('code')) + self.assertRaises(errors.MissingRedirectURIError, + self.web.create_authorization_response, uri.format('code'), scopes=['foo']) + + # Implicit grant + self.assertRaises(errors.MissingRedirectURIError, + self.mobile.validate_authorization_request, uri.format('token')) + self.assertRaises(errors.MissingRedirectURIError, + self.mobile.create_authorization_response, uri.format('token'), scopes=['foo']) + + def test_mismatching_redirect_uri(self): + uri = 'https://example.com/authorize?response_type={0}&client_id=foo&redirect_uri=https%3A%2F%2Fi.b%2Fback' + + # Authorization code grant + self.validator.validate_redirect_uri.return_value = False + self.assertRaises(errors.MismatchingRedirectURIError, + self.web.validate_authorization_request, uri.format('code')) + self.assertRaises(errors.MismatchingRedirectURIError, + self.web.create_authorization_response, uri.format('code'), scopes=['foo']) + + # Implicit grant + self.assertRaises(errors.MismatchingRedirectURIError, + self.mobile.validate_authorization_request, uri.format('token')) + self.assertRaises(errors.MismatchingRedirectURIError, + self.mobile.create_authorization_response, uri.format('token'), scopes=['foo']) + + def test_missing_client_id(self): + uri = 'https://example.com/authorize?response_type={0}&redirect_uri=https%3A%2F%2Fi.b%2Fback' + + # Authorization code grant + self.validator.validate_redirect_uri.return_value = False + self.assertRaises(errors.MissingClientIdError, + self.web.validate_authorization_request, uri.format('code')) + self.assertRaises(errors.MissingClientIdError, + self.web.create_authorization_response, uri.format('code'), scopes=['foo']) + + # Implicit grant + self.assertRaises(errors.MissingClientIdError, + self.mobile.validate_authorization_request, uri.format('token')) + self.assertRaises(errors.MissingClientIdError, + self.mobile.create_authorization_response, uri.format('token'), scopes=['foo']) + + def test_invalid_client_id(self): + uri = 'https://example.com/authorize?response_type={0}&client_id=foo&redirect_uri=https%3A%2F%2Fi.b%2Fback' + + # Authorization code grant + self.validator.validate_client_id.return_value = False + self.assertRaises(errors.InvalidClientIdError, + self.web.validate_authorization_request, uri.format('code')) + self.assertRaises(errors.InvalidClientIdError, + self.web.create_authorization_response, uri.format('code'), scopes=['foo']) + + # Implicit grant + self.assertRaises(errors.InvalidClientIdError, + self.mobile.validate_authorization_request, uri.format('token')) + self.assertRaises(errors.InvalidClientIdError, + self.mobile.create_authorization_response, uri.format('token'), scopes=['foo']) + + def test_empty_parameter(self): + uri = 'https://example.com/authorize?client_id=foo&redirect_uri=https%3A%2F%2Fi.b%2Fback&response_type=code&' + + # Authorization code grant + self.assertRaises(errors.InvalidRequestFatalError, + self.web.validate_authorization_request, uri) + + # Implicit grant + self.assertRaises(errors.InvalidRequestFatalError, + self.mobile.validate_authorization_request, uri) + + def test_invalid_request(self): + self.validator.get_default_redirect_uri.return_value = 'https://i.b/cb' + token_uri = 'https://i.b/token' + + invalid_bodies = [ + # duplicate params + 'grant_type=authorization_code&client_id=nope&client_id=nope&code=foo' + ] + for body in invalid_bodies: + _, body, _ = self.web.create_token_response(token_uri, + body=body) + self.assertEqual('invalid_request', json.loads(body)['error']) + + # Password credentials grant + invalid_bodies = [ + # duplicate params + 'grant_type=password&username=foo&username=bar&password=baz' + # missing username + 'grant_type=password&password=baz' + # missing password + 'grant_type=password&username=foo' + ] + self.validator.authenticate_client.side_effect = self.set_client + for body in invalid_bodies: + _, body, _ = self.legacy.create_token_response(token_uri, + body=body) + self.assertEqual('invalid_request', json.loads(body)['error']) + + # Client credentials grant + invalid_bodies = [ + # duplicate params + 'grant_type=client_credentials&scope=foo&scope=bar' + ] + for body in invalid_bodies: + _, body, _ = self.backend.create_token_response(token_uri, + body=body) + self.assertEqual('invalid_request', json.loads(body)['error']) + + def test_invalid_request_duplicate_params(self): + self.validator.get_default_redirect_uri.return_value = 'https://i.b/cb' + uri = 'https://i.b/auth?client_id=foo&client_id=bar&response_type={0}' + description = 'Duplicate client_id parameter.' + + # Authorization code + self.assertRaisesRegex(errors.InvalidRequestFatalError, + description, + self.web.validate_authorization_request, + uri.format('code')) + self.assertRaisesRegex(errors.InvalidRequestFatalError, + description, + self.web.create_authorization_response, + uri.format('code'), scopes=['foo']) + + # Implicit grant + self.assertRaisesRegex(errors.InvalidRequestFatalError, + description, + self.mobile.validate_authorization_request, + uri.format('token')) + self.assertRaisesRegex(errors.InvalidRequestFatalError, + description, + self.mobile.create_authorization_response, + uri.format('token'), scopes=['foo']) + + def test_invalid_request_missing_response_type(self): + + self.validator.get_default_redirect_uri.return_value = 'https://i.b/cb' + + uri = 'https://i.b/auth?client_id=foo' + + # Authorization code + self.assertRaises(errors.MissingResponseTypeError, + self.web.validate_authorization_request, + uri.format('code')) + h, _, s = self.web.create_authorization_response(uri, scopes=['foo']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertIn('error=invalid_request', h['Location']) + + # Implicit grant + self.assertRaises(errors.MissingResponseTypeError, + self.mobile.validate_authorization_request, + uri.format('token')) + h, _, s = self.mobile.create_authorization_response(uri, scopes=['foo']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertIn('error=invalid_request', h['Location']) + + def test_unauthorized_client(self): + self.validator.get_default_redirect_uri.return_value = 'https://i.b/cb' + self.validator.validate_grant_type.return_value = False + self.validator.validate_response_type.return_value = False + self.validator.authenticate_client.side_effect = self.set_client + token_uri = 'https://i.b/token' + + # Authorization code grant + self.assertRaises(errors.UnauthorizedClientError, + self.web.validate_authorization_request, + 'https://i.b/auth?response_type=code&client_id=foo') + _, body, _ = self.web.create_token_response(token_uri, + body='grant_type=authorization_code&code=foo') + self.assertEqual('unauthorized_client', json.loads(body)['error']) + + # Implicit grant + self.assertRaises(errors.UnauthorizedClientError, + self.mobile.validate_authorization_request, + 'https://i.b/auth?response_type=token&client_id=foo') + + # Password credentials grant + _, body, _ = self.legacy.create_token_response(token_uri, + body='grant_type=password&username=foo&password=bar') + self.assertEqual('unauthorized_client', json.loads(body)['error']) + + # Client credentials grant + _, body, _ = self.backend.create_token_response(token_uri, + body='grant_type=client_credentials') + self.assertEqual('unauthorized_client', json.loads(body)['error']) + + def test_access_denied(self): + self.validator.authenticate_client.side_effect = self.set_client + self.validator.get_default_redirect_uri.return_value = 'https://i.b/cb' + self.validator.confirm_redirect_uri.return_value = False + token_uri = 'https://i.b/token' + # Authorization code grant + _, body, _ = self.web.create_token_response(token_uri, + body='grant_type=authorization_code&code=foo') + self.assertEqual('invalid_request', json.loads(body)['error']) + + def test_access_denied_no_default_redirecturi(self): + self.validator.authenticate_client.side_effect = self.set_client + self.validator.get_default_redirect_uri.return_value = None + token_uri = 'https://i.b/token' + # Authorization code grant + _, body, _ = self.web.create_token_response(token_uri, + body='grant_type=authorization_code&code=foo') + self.assertEqual('invalid_request', json.loads(body)['error']) + + def test_unsupported_response_type(self): + self.validator.get_default_redirect_uri.return_value = 'https://i.b/cb' + + # Authorization code grant + self.assertRaises(errors.UnsupportedResponseTypeError, + self.web.validate_authorization_request, + 'https://i.b/auth?response_type=foo&client_id=foo') + + # Implicit grant + self.assertRaises(errors.UnsupportedResponseTypeError, + self.mobile.validate_authorization_request, + 'https://i.b/auth?response_type=foo&client_id=foo') + + def test_invalid_scope(self): + self.validator.get_default_redirect_uri.return_value = 'https://i.b/cb' + self.validator.validate_scopes.return_value = False + self.validator.authenticate_client.side_effect = self.set_client + + # Authorization code grant + self.assertRaises(errors.InvalidScopeError, + self.web.validate_authorization_request, + 'https://i.b/auth?response_type=code&client_id=foo') + + # Implicit grant + self.assertRaises(errors.InvalidScopeError, + self.mobile.validate_authorization_request, + 'https://i.b/auth?response_type=token&client_id=foo') + + # Password credentials grant + _, body, _ = self.legacy.create_token_response( + 'https://i.b/token', + body='grant_type=password&username=foo&password=bar') + self.assertEqual('invalid_scope', json.loads(body)['error']) + + # Client credentials grant + _, body, _ = self.backend.create_token_response( + 'https://i.b/token', + body='grant_type=client_credentials') + self.assertEqual('invalid_scope', json.loads(body)['error']) + + def test_server_error(self): + def raise_error(*args, **kwargs): + raise ValueError() + + self.validator.validate_client_id.side_effect = raise_error + self.validator.authenticate_client.side_effect = raise_error + self.validator.get_default_redirect_uri.return_value = 'https://i.b/cb' + + # Authorization code grant + self.web.catch_errors = True + _, _, s = self.web.create_authorization_response( + 'https://i.b/auth?client_id=foo&response_type=code', + scopes=['foo']) + self.assertEqual(s, 500) + _, _, s = self.web.create_token_response( + 'https://i.b/token', + body='grant_type=authorization_code&code=foo', + scopes=['foo']) + self.assertEqual(s, 500) + + # Implicit grant + self.mobile.catch_errors = True + _, _, s = self.mobile.create_authorization_response( + 'https://i.b/auth?client_id=foo&response_type=token', + scopes=['foo']) + self.assertEqual(s, 500) + + # Password credentials grant + self.legacy.catch_errors = True + _, _, s = self.legacy.create_token_response( + 'https://i.b/token', + body='grant_type=password&username=foo&password=foo') + self.assertEqual(s, 500) + + # Client credentials grant + self.backend.catch_errors = True + _, _, s = self.backend.create_token_response( + 'https://i.b/token', + body='grant_type=client_credentials') + self.assertEqual(s, 500) + + def test_temporarily_unavailable(self): + # Authorization code grant + self.web.available = False + _, _, s = self.web.create_authorization_response( + 'https://i.b/auth?client_id=foo&response_type=code', + scopes=['foo']) + self.assertEqual(s, 503) + _, _, s = self.web.create_token_response( + 'https://i.b/token', + body='grant_type=authorization_code&code=foo', + scopes=['foo']) + self.assertEqual(s, 503) + + # Implicit grant + self.mobile.available = False + _, _, s = self.mobile.create_authorization_response( + 'https://i.b/auth?client_id=foo&response_type=token', + scopes=['foo']) + self.assertEqual(s, 503) + + # Password credentials grant + self.legacy.available = False + _, _, s = self.legacy.create_token_response( + 'https://i.b/token', + body='grant_type=password&username=foo&password=foo') + self.assertEqual(s, 503) + + # Client credentials grant + self.backend.available = False + _, _, s = self.backend.create_token_response( + 'https://i.b/token', + body='grant_type=client_credentials') + self.assertEqual(s, 503) + + def test_invalid_client(self): + self.validator.authenticate_client.return_value = False + self.validator.authenticate_client_id.return_value = False + + # Authorization code grant + _, body, _ = self.web.create_token_response('https://i.b/token', + body='grant_type=authorization_code&code=foo') + self.assertEqual('invalid_client', json.loads(body)['error']) + + # Password credentials grant + _, body, _ = self.legacy.create_token_response('https://i.b/token', + body='grant_type=password&username=foo&password=bar') + self.assertEqual('invalid_client', json.loads(body)['error']) + + # Client credentials grant + _, body, _ = self.legacy.create_token_response('https://i.b/token', + body='grant_type=client_credentials') + self.assertEqual('invalid_client', json.loads(body)['error']) + + def test_invalid_grant(self): + self.validator.authenticate_client.side_effect = self.set_client + + # Authorization code grant + self.validator.validate_code.return_value = False + _, body, _ = self.web.create_token_response('https://i.b/token', + body='grant_type=authorization_code&code=foo') + self.assertEqual('invalid_grant', json.loads(body)['error']) + + # Password credentials grant + self.validator.validate_user.return_value = False + _, body, _ = self.legacy.create_token_response('https://i.b/token', + body='grant_type=password&username=foo&password=bar') + self.assertEqual('invalid_grant', json.loads(body)['error']) + + def test_unsupported_grant_type(self): + self.validator.authenticate_client.side_effect = self.set_client + + # Authorization code grant + _, body, _ = self.web.create_token_response('https://i.b/token', + body='grant_type=bar&code=foo') + self.assertEqual('unsupported_grant_type', json.loads(body)['error']) + + # Password credentials grant + _, body, _ = self.legacy.create_token_response('https://i.b/token', + body='grant_type=bar&username=foo&password=bar') + self.assertEqual('unsupported_grant_type', json.loads(body)['error']) + + # Client credentials grant + _, body, _ = self.backend.create_token_response('https://i.b/token', + body='grant_type=bar') + self.assertEqual('unsupported_grant_type', json.loads(body)['error']) + + def test_invalid_request_method(self): + test_methods = ['GET', 'pUt', 'dEleTe', 'paTcH'] + test_methods = test_methods + [x.lower() for x in test_methods] + [x.upper() for x in test_methods] + for method in test_methods: + self.validator.authenticate_client.side_effect = self.set_client + + uri = "http://i/b/token/" + try: + _, body, s = self.web.create_token_response(uri, + body='grant_type=access_token&code=123', http_method=method) + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('Unsupported request method', ire.description) + + try: + _, body, s = self.legacy.create_token_response(uri, + body='grant_type=access_token&code=123', http_method=method) + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('Unsupported request method', ire.description) + + try: + _, body, s = self.backend.create_token_response(uri, + body='grant_type=access_token&code=123', http_method=method) + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('Unsupported request method', ire.description) + + def test_invalid_post_request(self): + self.validator.authenticate_client.side_effect = self.set_client + for param in ['token', 'secret', 'code', 'foo']: + uri = 'https://i/b/token?' + urlencode([(param, 'secret')]) + try: + _, body, s = self.web.create_token_response(uri, + body='grant_type=access_token&code=123') + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('URL query parameters are not allowed', ire.description) + + try: + _, body, s = self.legacy.create_token_response(uri, + body='grant_type=access_token&code=123') + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('URL query parameters are not allowed', ire.description) + + try: + _, body, s = self.backend.create_token_response(uri, + body='grant_type=access_token&code=123') + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('URL query parameters are not allowed', ire.description) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_extra_credentials.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_extra_credentials.py new file mode 100644 index 0000000000..97aaf86dff --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_extra_credentials.py @@ -0,0 +1,69 @@ +"""Ensure extra credentials can be supplied for inclusion in tokens. +""" +from unittest import mock + +from oauthlib.oauth2 import ( + BackendApplicationServer, LegacyApplicationServer, MobileApplicationServer, + RequestValidator, WebApplicationServer, +) + +from tests.unittest import TestCase + + +class ExtraCredentialsTest(TestCase): + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def setUp(self): + self.validator = mock.MagicMock(spec=RequestValidator) + self.validator.get_default_redirect_uri.return_value = 'https://i.b/cb' + self.web = WebApplicationServer(self.validator) + self.mobile = MobileApplicationServer(self.validator) + self.legacy = LegacyApplicationServer(self.validator) + self.backend = BackendApplicationServer(self.validator) + + def test_post_authorization_request(self): + def save_code(client_id, token, request): + self.assertEqual('creds', request.extra) + + def save_token(token, request): + self.assertEqual('creds', request.extra) + + # Authorization code grant + self.validator.save_authorization_code.side_effect = save_code + self.web.create_authorization_response( + 'https://i.b/auth?client_id=foo&response_type=code', + scopes=['foo'], + credentials={'extra': 'creds'}) + + # Implicit grant + self.validator.save_bearer_token.side_effect = save_token + self.mobile.create_authorization_response( + 'https://i.b/auth?client_id=foo&response_type=token', + scopes=['foo'], + credentials={'extra': 'creds'}) + + def test_token_request(self): + def save_token(token, request): + self.assertIn('extra', token) + + self.validator.save_bearer_token.side_effect = save_token + self.validator.authenticate_client.side_effect = self.set_client + + # Authorization code grant + self.web.create_token_response('https://i.b/token', + body='grant_type=authorization_code&code=foo', + credentials={'extra': 'creds'}) + + # Password credentials grant + self.legacy.create_token_response('https://i.b/token', + body='grant_type=password&username=foo&password=bar', + credentials={'extra': 'creds'}) + + # Client credentials grant + self.backend.create_token_response('https://i.b/token', + body='grant_type=client_credentials', + credentials={'extra': 'creds'}) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py new file mode 100644 index 0000000000..6d3d119a3b --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +from json import loads +from unittest.mock import MagicMock + +from oauthlib.common import urlencode +from oauthlib.oauth2 import IntrospectEndpoint, RequestValidator + +from tests.unittest import TestCase + + +class IntrospectEndpointTest(TestCase): + + def setUp(self): + self.validator = MagicMock(wraps=RequestValidator()) + self.validator.client_authentication_required.return_value = True + self.validator.authenticate_client.return_value = True + self.validator.validate_bearer_token.return_value = True + self.validator.introspect_token.return_value = {} + self.endpoint = IntrospectEndpoint(self.validator) + + self.uri = 'should_not_matter' + self.headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + } + self.resp_h = { + 'Cache-Control': 'no-store', + 'Content-Type': 'application/json', + 'Pragma': 'no-cache' + } + self.resp_b = { + "active": True + } + + def test_introspect_token(self): + for token_type in ('access_token', 'refresh_token', 'invalid'): + body = urlencode([('token', 'foo'), + ('token_type_hint', token_type)]) + h, b, s = self.endpoint.create_introspect_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b), self.resp_b) + self.assertEqual(s, 200) + + def test_introspect_token_nohint(self): + # don't specify token_type_hint + body = urlencode([('token', 'foo')]) + h, b, s = self.endpoint.create_introspect_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b), self.resp_b) + self.assertEqual(s, 200) + + def test_introspect_token_false(self): + self.validator.introspect_token.return_value = None + body = urlencode([('token', 'foo')]) + h, b, s = self.endpoint.create_introspect_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b), {"active": False}) + self.assertEqual(s, 200) + + def test_introspect_token_claims(self): + self.validator.introspect_token.return_value = {"foo": "bar"} + body = urlencode([('token', 'foo')]) + h, b, s = self.endpoint.create_introspect_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b), {"active": True, "foo": "bar"}) + self.assertEqual(s, 200) + + def test_introspect_token_claims_spoof_active(self): + self.validator.introspect_token.return_value = {"foo": "bar", "active": False} + body = urlencode([('token', 'foo')]) + h, b, s = self.endpoint.create_introspect_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b), {"active": True, "foo": "bar"}) + self.assertEqual(s, 200) + + def test_introspect_token_client_authentication_failed(self): + self.validator.authenticate_client.return_value = False + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = self.endpoint.create_introspect_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, { + 'Content-Type': 'application/json', + 'Cache-Control': 'no-store', + 'Pragma': 'no-cache', + "WWW-Authenticate": 'Bearer error="invalid_client"' + }) + self.assertEqual(loads(b)['error'], 'invalid_client') + self.assertEqual(s, 401) + + def test_introspect_token_public_client_authentication(self): + self.validator.client_authentication_required.return_value = False + self.validator.authenticate_client_id.return_value = True + for token_type in ('access_token', 'refresh_token', 'invalid'): + body = urlencode([('token', 'foo'), + ('token_type_hint', token_type)]) + h, b, s = self.endpoint.create_introspect_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b), self.resp_b) + self.assertEqual(s, 200) + + def test_introspect_token_public_client_authentication_failed(self): + self.validator.client_authentication_required.return_value = False + self.validator.authenticate_client_id.return_value = False + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = self.endpoint.create_introspect_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, { + 'Content-Type': 'application/json', + 'Cache-Control': 'no-store', + 'Pragma': 'no-cache', + "WWW-Authenticate": 'Bearer error="invalid_client"' + }) + self.assertEqual(loads(b)['error'], 'invalid_client') + self.assertEqual(s, 401) + + def test_introspect_unsupported_token(self): + endpoint = IntrospectEndpoint(self.validator, + supported_token_types=['access_token']) + body = urlencode([('token', 'foo'), + ('token_type_hint', 'refresh_token')]) + h, b, s = endpoint.create_introspect_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'unsupported_token_type') + self.assertEqual(s, 400) + + h, b, s = endpoint.create_introspect_response(self.uri, + headers=self.headers, body='') + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertEqual(s, 400) + + def test_introspect_invalid_request_method(self): + endpoint = IntrospectEndpoint(self.validator, + supported_token_types=['access_token']) + test_methods = ['GET', 'pUt', 'dEleTe', 'paTcH'] + test_methods = test_methods + [x.lower() for x in test_methods] + [x.upper() for x in test_methods] + for method in test_methods: + body = urlencode([('token', 'foo'), + ('token_type_hint', 'refresh_token')]) + h, b, s = endpoint.create_introspect_response(self.uri, + http_method = method, headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn('Unsupported request method', loads(b)['error_description']) + self.assertEqual(s, 400) + + def test_introspect_bad_post_request(self): + endpoint = IntrospectEndpoint(self.validator, + supported_token_types=['access_token']) + for param in ['token', 'secret', 'code', 'foo']: + uri = 'http://some.endpoint?' + urlencode([(param, 'secret')]) + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = endpoint.create_introspect_response( + uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn('query parameters are not allowed', loads(b)['error_description']) + self.assertEqual(s, 400) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_metadata.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_metadata.py new file mode 100644 index 0000000000..1f5b912100 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_metadata.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +import json + +from oauthlib.oauth2 import MetadataEndpoint, Server, TokenEndpoint + +from tests.unittest import TestCase + + +class MetadataEndpointTest(TestCase): + def setUp(self): + self.metadata = { + "issuer": 'https://foo.bar' + } + + def test_openid_oauth2_preconfigured(self): + default_claims = { + "issuer": 'https://foo.bar', + "authorization_endpoint": "https://foo.bar/authorize", + "revocation_endpoint": "https://foo.bar/revoke", + "introspection_endpoint": "https://foo.bar/introspect", + "token_endpoint": "https://foo.bar/token" + } + from oauthlib.oauth2 import Server as OAuth2Server + from oauthlib.openid import Server as OpenIDServer + + endpoint = OAuth2Server(None) + metadata = MetadataEndpoint([endpoint], default_claims) + oauth2_claims = metadata.claims + + endpoint = OpenIDServer(None) + metadata = MetadataEndpoint([endpoint], default_claims) + openid_claims = metadata.claims + + # Pure OAuth2 Authorization Metadata are similar with OpenID but + # response_type not! (OIDC contains "id_token" and hybrid flows) + del oauth2_claims['response_types_supported'] + del openid_claims['response_types_supported'] + + self.maxDiff = None + self.assertEqual(openid_claims, oauth2_claims) + + def test_create_metadata_response(self): + endpoint = TokenEndpoint(None, None, grant_types={"password": None}) + metadata = MetadataEndpoint([endpoint], { + "issuer": 'https://foo.bar', + "token_endpoint": "https://foo.bar/token" + }) + headers, body, status = metadata.create_metadata_response('/', 'GET') + assert headers == { + 'Content-Type': 'application/json', + 'Access-Control-Allow-Origin': '*', + } + claims = json.loads(body) + assert claims['issuer'] == 'https://foo.bar' + + def test_token_endpoint(self): + endpoint = TokenEndpoint(None, None, grant_types={"password": None}) + metadata = MetadataEndpoint([endpoint], { + "issuer": 'https://foo.bar', + "token_endpoint": "https://foo.bar/token" + }) + self.assertIn("grant_types_supported", metadata.claims) + self.assertEqual(metadata.claims["grant_types_supported"], ["password"]) + + def test_token_endpoint_overridden(self): + endpoint = TokenEndpoint(None, None, grant_types={"password": None}) + metadata = MetadataEndpoint([endpoint], { + "issuer": 'https://foo.bar', + "token_endpoint": "https://foo.bar/token", + "grant_types_supported": ["pass_word_special_provider"] + }) + self.assertIn("grant_types_supported", metadata.claims) + self.assertEqual(metadata.claims["grant_types_supported"], ["pass_word_special_provider"]) + + def test_mandatory_fields(self): + metadata = MetadataEndpoint([], self.metadata) + self.assertIn("issuer", metadata.claims) + self.assertEqual(metadata.claims["issuer"], 'https://foo.bar') + + def test_server_metadata(self): + endpoint = Server(None) + metadata = MetadataEndpoint([endpoint], { + "issuer": 'https://foo.bar', + "authorization_endpoint": "https://foo.bar/authorize", + "introspection_endpoint": "https://foo.bar/introspect", + "revocation_endpoint": "https://foo.bar/revoke", + "token_endpoint": "https://foo.bar/token", + "jwks_uri": "https://foo.bar/certs", + "scopes_supported": ["email", "profile"] + }) + expected_claims = { + "issuer": "https://foo.bar", + "authorization_endpoint": "https://foo.bar/authorize", + "introspection_endpoint": "https://foo.bar/introspect", + "revocation_endpoint": "https://foo.bar/revoke", + "token_endpoint": "https://foo.bar/token", + "jwks_uri": "https://foo.bar/certs", + "scopes_supported": ["email", "profile"], + "grant_types_supported": [ + "authorization_code", + "password", + "client_credentials", + "refresh_token", + "implicit" + ], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic" + ], + "response_types_supported": [ + "code", + "token" + ], + "response_modes_supported": [ + "query", + "fragment" + ], + "code_challenge_methods_supported": [ + "plain", + "S256" + ], + "revocation_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic" + ], + "introspection_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic" + ] + } + + def sort_list(claims): + for k in claims.keys(): + claims[k] = sorted(claims[k]) + + sort_list(metadata.claims) + sort_list(expected_claims) + self.assertEqual(sorted(metadata.claims.items()), sorted(expected_claims.items())) + + def test_metadata_validate_issuer(self): + with self.assertRaises(ValueError): + endpoint = TokenEndpoint( + None, None, grant_types={"password": None}, + ) + metadata = MetadataEndpoint([endpoint], { + "issuer": 'http://foo.bar', + "token_endpoint": "https://foo.bar/token", + }) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_resource_owner_association.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_resource_owner_association.py new file mode 100644 index 0000000000..04533888e9 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_resource_owner_association.py @@ -0,0 +1,108 @@ +"""Ensure all tokens are associated with a resource owner. +""" +import json +from unittest import mock + +from oauthlib.oauth2 import ( + BackendApplicationServer, LegacyApplicationServer, MobileApplicationServer, + RequestValidator, WebApplicationServer, +) + +from tests.unittest import TestCase + +from .test_utils import get_fragment_credentials, get_query_credentials + + +class ResourceOwnerAssociationTest(TestCase): + + auth_uri = 'http://example.com/path?client_id=abc' + token_uri = 'http://example.com/path' + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def set_user(self, client_id, code, client, request): + request.user = 'test' + return True + + def set_user_from_username(self, username, password, client, request): + request.user = 'test' + return True + + def set_user_from_credentials(self, request): + request.user = 'test' + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def inspect_client(self, request, refresh_token=False): + if not request.user: + raise ValueError() + return 'abc' + + def setUp(self): + self.validator = mock.MagicMock(spec=RequestValidator) + self.validator.get_default_redirect_uri.return_value = 'http://i.b./path' + self.validator.get_code_challenge.return_value = None + self.validator.authenticate_client.side_effect = self.set_client + self.web = WebApplicationServer(self.validator, + token_generator=self.inspect_client) + self.mobile = MobileApplicationServer(self.validator, + token_generator=self.inspect_client) + self.legacy = LegacyApplicationServer(self.validator, + token_generator=self.inspect_client) + self.backend = BackendApplicationServer(self.validator, + token_generator=self.inspect_client) + + def test_web_application(self): + # TODO: code generator + intercept test + h, _, s = self.web.create_authorization_response( + self.auth_uri + '&response_type=code', + credentials={'user': 'test'}, scopes=['random']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + code = get_query_credentials(h['Location'])['code'][0] + self.assertRaises(ValueError, + self.web.create_token_response, self.token_uri, + body='grant_type=authorization_code&code=%s' % code) + + self.validator.validate_code.side_effect = self.set_user + _, body, _ = self.web.create_token_response(self.token_uri, + body='grant_type=authorization_code&code=%s' % code) + self.assertEqual(json.loads(body)['access_token'], 'abc') + + def test_mobile_application(self): + self.assertRaises(ValueError, + self.mobile.create_authorization_response, + self.auth_uri + '&response_type=token') + + h, _, s = self.mobile.create_authorization_response( + self.auth_uri + '&response_type=token', + credentials={'user': 'test'}, scopes=['random']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertEqual(get_fragment_credentials(h['Location'])['access_token'][0], 'abc') + + def test_legacy_application(self): + body = 'grant_type=password&username=abc&password=secret' + self.assertRaises(ValueError, + self.legacy.create_token_response, + self.token_uri, body=body) + + self.validator.validate_user.side_effect = self.set_user_from_username + _, body, _ = self.legacy.create_token_response( + self.token_uri, body=body) + self.assertEqual(json.loads(body)['access_token'], 'abc') + + def test_backend_application(self): + body = 'grant_type=client_credentials' + self.assertRaises(ValueError, + self.backend.create_token_response, + self.token_uri, body=body) + + self.validator.authenticate_client.side_effect = self.set_user_from_credentials + _, body, _ = self.backend.create_token_response( + self.token_uri, body=body) + self.assertEqual(json.loads(body)['access_token'], 'abc') diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py new file mode 100644 index 0000000000..338dbd91fa --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +from json import loads +from unittest.mock import MagicMock + +from oauthlib.common import urlencode +from oauthlib.oauth2 import RequestValidator, RevocationEndpoint + +from tests.unittest import TestCase + + +class RevocationEndpointTest(TestCase): + + def setUp(self): + self.validator = MagicMock(wraps=RequestValidator()) + self.validator.client_authentication_required.return_value = True + self.validator.authenticate_client.return_value = True + self.validator.revoke_token.return_value = True + self.endpoint = RevocationEndpoint(self.validator) + + self.uri = 'https://example.com/revoke_token' + self.headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + } + self.resp_h = { + 'Cache-Control': 'no-store', + 'Content-Type': 'application/json', + 'Pragma': 'no-cache' + } + + def test_revoke_token(self): + for token_type in ('access_token', 'refresh_token', 'invalid'): + body = urlencode([('token', 'foo'), + ('token_type_hint', token_type)]) + h, b, s = self.endpoint.create_revocation_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, {}) + self.assertEqual(b, '') + self.assertEqual(s, 200) + + # don't specify token_type_hint + body = urlencode([('token', 'foo')]) + h, b, s = self.endpoint.create_revocation_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, {}) + self.assertEqual(b, '') + self.assertEqual(s, 200) + + def test_revoke_token_client_authentication_failed(self): + self.validator.authenticate_client.return_value = False + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = self.endpoint.create_revocation_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, { + 'Content-Type': 'application/json', + 'Cache-Control': 'no-store', + 'Pragma': 'no-cache', + "WWW-Authenticate": 'Bearer error="invalid_client"' + }) + self.assertEqual(loads(b)['error'], 'invalid_client') + self.assertEqual(s, 401) + + def test_revoke_token_public_client_authentication(self): + self.validator.client_authentication_required.return_value = False + self.validator.authenticate_client_id.return_value = True + for token_type in ('access_token', 'refresh_token', 'invalid'): + body = urlencode([('token', 'foo'), + ('token_type_hint', token_type)]) + h, b, s = self.endpoint.create_revocation_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, {}) + self.assertEqual(b, '') + self.assertEqual(s, 200) + + def test_revoke_token_public_client_authentication_failed(self): + self.validator.client_authentication_required.return_value = False + self.validator.authenticate_client_id.return_value = False + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = self.endpoint.create_revocation_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, { + 'Content-Type': 'application/json', + 'Cache-Control': 'no-store', + 'Pragma': 'no-cache', + "WWW-Authenticate": 'Bearer error="invalid_client"' + }) + self.assertEqual(loads(b)['error'], 'invalid_client') + self.assertEqual(s, 401) + + def test_revoke_with_callback(self): + endpoint = RevocationEndpoint(self.validator, enable_jsonp=True) + callback = 'package.hello_world' + for token_type in ('access_token', 'refresh_token', 'invalid'): + body = urlencode([('token', 'foo'), + ('token_type_hint', token_type), + ('callback', callback)]) + h, b, s = endpoint.create_revocation_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, {}) + self.assertEqual(b, callback + '();') + self.assertEqual(s, 200) + + def test_revoke_unsupported_token(self): + endpoint = RevocationEndpoint(self.validator, + supported_token_types=['access_token']) + body = urlencode([('token', 'foo'), + ('token_type_hint', 'refresh_token')]) + h, b, s = endpoint.create_revocation_response(self.uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'unsupported_token_type') + self.assertEqual(s, 400) + + h, b, s = endpoint.create_revocation_response(self.uri, + headers=self.headers, body='') + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertEqual(s, 400) + + def test_revoke_invalid_request_method(self): + endpoint = RevocationEndpoint(self.validator, + supported_token_types=['access_token']) + test_methods = ['GET', 'pUt', 'dEleTe', 'paTcH'] + test_methods = test_methods + [x.lower() for x in test_methods] + [x.upper() for x in test_methods] + for method in test_methods: + body = urlencode([('token', 'foo'), + ('token_type_hint', 'refresh_token')]) + h, b, s = endpoint.create_revocation_response(self.uri, + http_method = method, headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn('Unsupported request method', loads(b)['error_description']) + self.assertEqual(s, 400) + + def test_revoke_bad_post_request(self): + endpoint = RevocationEndpoint(self.validator, + supported_token_types=['access_token']) + for param in ['token', 'secret', 'code', 'foo']: + uri = 'http://some.endpoint?' + urlencode([(param, 'secret')]) + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = endpoint.create_revocation_response(uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn('query parameters are not allowed', loads(b)['error_description']) + self.assertEqual(s, 400) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_scope_handling.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_scope_handling.py new file mode 100644 index 0000000000..4c87d9c7c8 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_scope_handling.py @@ -0,0 +1,193 @@ +"""Ensure scope is preserved across authorization. + +Fairly trivial in all grants except the Authorization Code Grant where scope +need to be persisted temporarily in an authorization code. +""" +import json +from unittest import mock + +from oauthlib.oauth2 import ( + BackendApplicationServer, LegacyApplicationServer, MobileApplicationServer, + RequestValidator, Server, WebApplicationServer, +) + +from tests.unittest import TestCase + +from .test_utils import get_fragment_credentials, get_query_credentials + + +class TestScopeHandling(TestCase): + + DEFAULT_REDIRECT_URI = 'http://i.b./path' + + def set_scopes(self, scopes): + def set_request_scopes(client_id, code, client, request): + request.scopes = scopes + return True + return set_request_scopes + + def set_user(self, request): + request.user = 'foo' + request.client_id = 'bar' + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def setUp(self): + self.validator = mock.MagicMock(spec=RequestValidator) + self.validator.get_default_redirect_uri.return_value = TestScopeHandling.DEFAULT_REDIRECT_URI + self.validator.get_code_challenge.return_value = None + self.validator.authenticate_client.side_effect = self.set_client + self.server = Server(self.validator) + self.web = WebApplicationServer(self.validator) + self.mobile = MobileApplicationServer(self.validator) + self.legacy = LegacyApplicationServer(self.validator) + self.backend = BackendApplicationServer(self.validator) + + def test_scope_extraction(self): + scopes = ( + ('images', ['images']), + ('images+videos', ['images', 'videos']), + ('images+videos+openid', ['images', 'videos', 'openid']), + ('http%3A%2f%2fa.b%2fvideos', ['http://a.b/videos']), + ('http%3A%2f%2fa.b%2fvideos+pics', ['http://a.b/videos', 'pics']), + ('pics+http%3A%2f%2fa.b%2fvideos', ['pics', 'http://a.b/videos']), + ('http%3A%2f%2fa.b%2fvideos+https%3A%2f%2fc.d%2Fsecret', ['http://a.b/videos', 'https://c.d/secret']), + ) + + uri = 'http://example.com/path?client_id=abc&scope=%s&response_type=%s' + for scope, correct_scopes in scopes: + scopes, _ = self.web.validate_authorization_request( + uri % (scope, 'code')) + self.assertCountEqual(scopes, correct_scopes) + scopes, _ = self.mobile.validate_authorization_request( + uri % (scope, 'token')) + self.assertCountEqual(scopes, correct_scopes) + scopes, _ = self.server.validate_authorization_request( + uri % (scope, 'code')) + self.assertCountEqual(scopes, correct_scopes) + + def test_scope_preservation(self): + scope = 'pics+http%3A%2f%2fa.b%2fvideos' + decoded_scope = 'pics http://a.b/videos' + auth_uri = 'http://example.com/path?client_id=abc&response_type=' + token_uri = 'http://example.com/path' + + # authorization grant + for backend_server_type in ['web', 'server']: + h, _, s = getattr(self, backend_server_type).create_authorization_response( + auth_uri + 'code', scopes=decoded_scope.split(' ')) + self.validator.validate_code.side_effect = self.set_scopes(decoded_scope.split(' ')) + self.assertEqual(s, 302) + self.assertIn('Location', h) + code = get_query_credentials(h['Location'])['code'][0] + _, body, _ = getattr(self, backend_server_type).create_token_response(token_uri, + body='client_id=me&redirect_uri=http://back.to/me&grant_type=authorization_code&code=%s' % code) + self.assertEqual(json.loads(body)['scope'], decoded_scope) + + # implicit grant + for backend_server_type in ['mobile', 'server']: + h, _, s = getattr(self, backend_server_type).create_authorization_response( + auth_uri + 'token', scopes=decoded_scope.split(' ')) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertEqual(get_fragment_credentials(h['Location'])['scope'][0], decoded_scope) + + # resource owner password credentials grant + for backend_server_type in ['legacy', 'server']: + body = 'grant_type=password&username=abc&password=secret&scope=%s' + + _, body, _ = getattr(self, backend_server_type).create_token_response(token_uri, + body=body % scope) + self.assertEqual(json.loads(body)['scope'], decoded_scope) + + # client credentials grant + for backend_server_type in ['backend', 'server']: + body = 'grant_type=client_credentials&scope=%s' + self.validator.authenticate_client.side_effect = self.set_user + _, body, _ = getattr(self, backend_server_type).create_token_response(token_uri, + body=body % scope) + self.assertEqual(json.loads(body)['scope'], decoded_scope) + + def test_scope_changed(self): + scope = 'pics+http%3A%2f%2fa.b%2fvideos' + scopes = ['images', 'http://a.b/videos'] + decoded_scope = 'images http://a.b/videos' + auth_uri = 'http://example.com/path?client_id=abc&response_type=' + token_uri = 'http://example.com/path' + + # authorization grant + h, _, s = self.web.create_authorization_response( + auth_uri + 'code', scopes=scopes) + self.assertEqual(s, 302) + self.assertIn('Location', h) + code = get_query_credentials(h['Location'])['code'][0] + self.validator.validate_code.side_effect = self.set_scopes(scopes) + _, body, _ = self.web.create_token_response(token_uri, + body='grant_type=authorization_code&code=%s' % code) + self.assertEqual(json.loads(body)['scope'], decoded_scope) + + # implicit grant + self.validator.validate_scopes.side_effect = self.set_scopes(scopes) + h, _, s = self.mobile.create_authorization_response( + auth_uri + 'token', scopes=scopes) + self.assertEqual(s, 302) + self.assertIn('Location', h) + self.assertEqual(get_fragment_credentials(h['Location'])['scope'][0], decoded_scope) + + # resource owner password credentials grant + self.validator.validate_scopes.side_effect = self.set_scopes(scopes) + body = 'grant_type=password&username=abc&password=secret&scope=%s' + _, body, _ = self.legacy.create_token_response(token_uri, + body=body % scope) + self.assertEqual(json.loads(body)['scope'], decoded_scope) + + # client credentials grant + self.validator.validate_scopes.side_effect = self.set_scopes(scopes) + self.validator.authenticate_client.side_effect = self.set_user + body = 'grant_type=client_credentials&scope=%s' + _, body, _ = self.backend.create_token_response(token_uri, + body=body % scope) + + self.assertEqual(json.loads(body)['scope'], decoded_scope) + + def test_invalid_scope(self): + scope = 'pics+http%3A%2f%2fa.b%2fvideos' + auth_uri = 'http://example.com/path?client_id=abc&response_type=' + token_uri = 'http://example.com/path' + + self.validator.validate_scopes.return_value = False + + # authorization grant + h, _, s = self.web.create_authorization_response( + auth_uri + 'code', scopes=['invalid']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + error = get_query_credentials(h['Location'])['error'][0] + self.assertEqual(error, 'invalid_scope') + + # implicit grant + h, _, s = self.mobile.create_authorization_response( + auth_uri + 'token', scopes=['invalid']) + self.assertEqual(s, 302) + self.assertIn('Location', h) + error = get_fragment_credentials(h['Location'])['error'][0] + self.assertEqual(error, 'invalid_scope') + + # resource owner password credentials grant + body = 'grant_type=password&username=abc&password=secret&scope=%s' + _, body, _ = self.legacy.create_token_response(token_uri, + body=body % scope) + self.assertEqual(json.loads(body)['error'], 'invalid_scope') + + # client credentials grant + self.validator.authenticate_client.side_effect = self.set_user + body = 'grant_type=client_credentials&scope=%s' + _, body, _ = self.backend.create_token_response(token_uri, + body=body % scope) + self.assertEqual(json.loads(body)['error'], 'invalid_scope') diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_utils.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_utils.py new file mode 100644 index 0000000000..5eae1956f4 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/endpoints/test_utils.py @@ -0,0 +1,11 @@ +import urllib.parse as urlparse + + +def get_query_credentials(uri): + return urlparse.parse_qs(urlparse.urlparse(uri).query, + keep_blank_values=True) + + +def get_fragment_credentials(uri): + return urlparse.parse_qs(urlparse.urlparse(uri).fragment, + keep_blank_values=True) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/__init__.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_authorization_code.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_authorization_code.py new file mode 100644 index 0000000000..77e1a81b46 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_authorization_code.py @@ -0,0 +1,382 @@ +# -*- coding: utf-8 -*- +import json +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749 import errors +from oauthlib.oauth2.rfc6749.grant_types import ( + AuthorizationCodeGrant, authorization_code, +) +from oauthlib.oauth2.rfc6749.tokens import BearerToken + +from tests.unittest import TestCase + + +class AuthorizationCodeGrantTest(TestCase): + + def setUp(self): + self.request = Request('http://a.b/path') + self.request.scopes = ('hello', 'world') + self.request.expires_in = 1800 + self.request.client = 'batman' + self.request.client_id = 'abcdef' + self.request.code = '1234' + self.request.response_type = 'code' + self.request.grant_type = 'authorization_code' + self.request.redirect_uri = 'https://a.b/cb' + + self.mock_validator = mock.MagicMock() + self.mock_validator.is_pkce_required.return_value = False + self.mock_validator.get_code_challenge.return_value = None + self.mock_validator.is_origin_allowed.return_value = False + self.mock_validator.authenticate_client.side_effect = self.set_client + self.auth = AuthorizationCodeGrant(request_validator=self.mock_validator) + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def setup_validators(self): + self.authval1, self.authval2 = mock.Mock(), mock.Mock() + self.authval1.return_value = {} + self.authval2.return_value = {} + self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() + self.tknval1.return_value = None + self.tknval2.return_value = None + self.auth.custom_validators.pre_token.append(self.tknval1) + self.auth.custom_validators.post_token.append(self.tknval2) + self.auth.custom_validators.pre_auth.append(self.authval1) + self.auth.custom_validators.post_auth.append(self.authval2) + + def test_custom_auth_validators(self): + self.setup_validators() + + bearer = BearerToken(self.mock_validator) + self.auth.create_authorization_response(self.request, bearer) + self.assertTrue(self.authval1.called) + self.assertTrue(self.authval2.called) + self.assertFalse(self.tknval1.called) + self.assertFalse(self.tknval2.called) + + def test_custom_token_validators(self): + self.setup_validators() + + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(self.tknval1.called) + self.assertTrue(self.tknval2.called) + self.assertFalse(self.authval1.called) + self.assertFalse(self.authval2.called) + + def test_create_authorization_grant(self): + bearer = BearerToken(self.mock_validator) + self.request.response_mode = 'query' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + grant = dict(Request(h['Location']).uri_query_params) + self.assertIn('code', grant) + self.assertTrue(self.mock_validator.validate_redirect_uri.called) + self.assertTrue(self.mock_validator.validate_response_type.called) + self.assertTrue(self.mock_validator.validate_scopes.called) + + def test_create_authorization_grant_no_scopes(self): + bearer = BearerToken(self.mock_validator) + self.request.response_mode = 'query' + self.request.scopes = [] + self.auth.create_authorization_response(self.request, bearer) + + def test_create_authorization_grant_state(self): + self.request.state = 'abc' + self.request.redirect_uri = None + self.request.response_mode = 'query' + self.mock_validator.get_default_redirect_uri.return_value = 'https://a.b/cb' + bearer = BearerToken(self.mock_validator) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + grant = dict(Request(h['Location']).uri_query_params) + self.assertIn('code', grant) + self.assertIn('state', grant) + self.assertFalse(self.mock_validator.validate_redirect_uri.called) + self.assertTrue(self.mock_validator.get_default_redirect_uri.called) + self.assertTrue(self.mock_validator.validate_response_type.called) + self.assertTrue(self.mock_validator.validate_scopes.called) + + @mock.patch('oauthlib.common.generate_token') + def test_create_authorization_response(self, generate_token): + generate_token.return_value = 'abc' + bearer = BearerToken(self.mock_validator) + self.request.response_mode = 'query' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], 'https://a.b/cb?code=abc') + self.request.response_mode = 'fragment' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], 'https://a.b/cb#code=abc') + + def test_create_token_response(self): + bearer = BearerToken(self.mock_validator) + + h, token, s = self.auth.create_token_response(self.request, bearer) + token = json.loads(token) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('refresh_token', token) + self.assertIn('expires_in', token) + self.assertIn('scope', token) + self.assertTrue(self.mock_validator.client_authentication_required.called) + self.assertTrue(self.mock_validator.authenticate_client.called) + self.assertTrue(self.mock_validator.validate_code.called) + self.assertTrue(self.mock_validator.confirm_redirect_uri.called) + self.assertTrue(self.mock_validator.validate_grant_type.called) + self.assertTrue(self.mock_validator.invalidate_authorization_code.called) + + def test_create_token_response_without_refresh_token(self): + self.auth.refresh_token = False # Not to issue refresh token. + + bearer = BearerToken(self.mock_validator) + h, token, s = self.auth.create_token_response(self.request, bearer) + token = json.loads(token) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertNotIn('refresh_token', token) + self.assertIn('expires_in', token) + self.assertIn('scope', token) + self.assertTrue(self.mock_validator.client_authentication_required.called) + self.assertTrue(self.mock_validator.authenticate_client.called) + self.assertTrue(self.mock_validator.validate_code.called) + self.assertTrue(self.mock_validator.confirm_redirect_uri.called) + self.assertTrue(self.mock_validator.validate_grant_type.called) + self.assertTrue(self.mock_validator.invalidate_authorization_code.called) + + def test_invalid_request(self): + del self.request.code + self.assertRaises(errors.InvalidRequestError, self.auth.validate_token_request, + self.request) + + def test_invalid_request_duplicates(self): + request = mock.MagicMock(wraps=self.request) + request.grant_type = 'authorization_code' + request.duplicate_params = ['client_id'] + self.assertRaises(errors.InvalidRequestError, self.auth.validate_token_request, + request) + + def test_authentication_required(self): + """ + ensure client_authentication_required() is properly called + """ + self.auth.validate_token_request(self.request) + self.mock_validator.client_authentication_required.assert_called_once_with(self.request) + + def test_authenticate_client(self): + self.mock_validator.authenticate_client.side_effect = None + self.mock_validator.authenticate_client.return_value = False + self.assertRaises(errors.InvalidClientError, self.auth.validate_token_request, + self.request) + + def test_client_id_missing(self): + self.mock_validator.authenticate_client.side_effect = None + request = mock.MagicMock(wraps=self.request) + request.grant_type = 'authorization_code' + del request.client.client_id + self.assertRaises(NotImplementedError, self.auth.validate_token_request, + request) + + def test_invalid_grant(self): + self.request.client = 'batman' + self.mock_validator.authenticate_client = self.set_client + self.mock_validator.validate_code.return_value = False + self.assertRaises(errors.InvalidGrantError, + self.auth.validate_token_request, self.request) + + def test_invalid_grant_type(self): + self.request.grant_type = 'foo' + self.assertRaises(errors.UnsupportedGrantTypeError, + self.auth.validate_token_request, self.request) + + def test_authenticate_client_id(self): + self.mock_validator.client_authentication_required.return_value = False + self.mock_validator.authenticate_client_id.return_value = False + self.request.state = 'abc' + self.assertRaises(errors.InvalidClientError, + self.auth.validate_token_request, self.request) + + def test_invalid_redirect_uri(self): + self.mock_validator.confirm_redirect_uri.return_value = False + self.assertRaises(errors.MismatchingRedirectURIError, + self.auth.validate_token_request, self.request) + + # PKCE validate_authorization_request + def test_pkce_challenge_missing(self): + self.mock_validator.is_pkce_required.return_value = True + self.assertRaises(errors.MissingCodeChallengeError, + self.auth.validate_authorization_request, self.request) + + def test_pkce_default_method(self): + for required in [True, False]: + self.mock_validator.is_pkce_required.return_value = required + self.request.code_challenge = "present" + _, ri = self.auth.validate_authorization_request(self.request) + self.assertIn("code_challenge", ri) + self.assertIn("code_challenge_method", ri) + self.assertEqual(ri["code_challenge"], "present") + self.assertEqual(ri["code_challenge_method"], "plain") + + def test_pkce_wrong_method(self): + for required in [True, False]: + self.mock_validator.is_pkce_required.return_value = required + self.request.code_challenge = "present" + self.request.code_challenge_method = "foobar" + self.assertRaises(errors.UnsupportedCodeChallengeMethodError, + self.auth.validate_authorization_request, self.request) + + # PKCE validate_token_request + def test_pkce_verifier_missing(self): + self.mock_validator.is_pkce_required.return_value = True + self.assertRaises(errors.MissingCodeVerifierError, + self.auth.validate_token_request, self.request) + + # PKCE validate_token_request + def test_pkce_required_verifier_missing_challenge_missing(self): + self.mock_validator.is_pkce_required.return_value = True + self.request.code_verifier = None + self.mock_validator.get_code_challenge.return_value = None + self.assertRaises(errors.MissingCodeVerifierError, + self.auth.validate_token_request, self.request) + + def test_pkce_required_verifier_missing_challenge_valid(self): + self.mock_validator.is_pkce_required.return_value = True + self.request.code_verifier = None + self.mock_validator.get_code_challenge.return_value = "foo" + self.assertRaises(errors.MissingCodeVerifierError, + self.auth.validate_token_request, self.request) + + def test_pkce_required_verifier_valid_challenge_missing(self): + self.mock_validator.is_pkce_required.return_value = True + self.request.code_verifier = "foobar" + self.mock_validator.get_code_challenge.return_value = None + self.assertRaises(errors.InvalidGrantError, + self.auth.validate_token_request, self.request) + + def test_pkce_required_verifier_valid_challenge_valid_method_valid(self): + self.mock_validator.is_pkce_required.return_value = True + self.request.code_verifier = "foobar" + self.mock_validator.get_code_challenge.return_value = "foobar" + self.mock_validator.get_code_challenge_method.return_value = "plain" + self.auth.validate_token_request(self.request) + + def test_pkce_required_verifier_invalid_challenge_valid_method_valid(self): + self.mock_validator.is_pkce_required.return_value = True + self.request.code_verifier = "foobar" + self.mock_validator.get_code_challenge.return_value = "raboof" + self.mock_validator.get_code_challenge_method.return_value = "plain" + self.assertRaises(errors.InvalidGrantError, + self.auth.validate_token_request, self.request) + + def test_pkce_required_verifier_valid_challenge_valid_method_wrong(self): + self.mock_validator.is_pkce_required.return_value = True + self.request.code_verifier = "present" + self.mock_validator.get_code_challenge.return_value = "foobar" + self.mock_validator.get_code_challenge_method.return_value = "cryptic_method" + self.assertRaises(errors.ServerError, + self.auth.validate_token_request, self.request) + + def test_pkce_verifier_valid_challenge_valid_method_missing(self): + self.mock_validator.is_pkce_required.return_value = True + self.request.code_verifier = "present" + self.mock_validator.get_code_challenge.return_value = "foobar" + self.mock_validator.get_code_challenge_method.return_value = None + self.assertRaises(errors.InvalidGrantError, + self.auth.validate_token_request, self.request) + + def test_pkce_optional_verifier_valid_challenge_missing(self): + self.mock_validator.is_pkce_required.return_value = False + self.request.code_verifier = "present" + self.mock_validator.get_code_challenge.return_value = None + self.auth.validate_token_request(self.request) + + def test_pkce_optional_verifier_missing_challenge_valid(self): + self.mock_validator.is_pkce_required.return_value = False + self.request.code_verifier = None + self.mock_validator.get_code_challenge.return_value = "foobar" + self.assertRaises(errors.MissingCodeVerifierError, + self.auth.validate_token_request, self.request) + + # PKCE functions + def test_wrong_code_challenge_method_plain(self): + self.assertFalse(authorization_code.code_challenge_method_plain("foo", "bar")) + + def test_correct_code_challenge_method_plain(self): + self.assertTrue(authorization_code.code_challenge_method_plain("foo", "foo")) + + def test_wrong_code_challenge_method_s256(self): + self.assertFalse(authorization_code.code_challenge_method_s256("foo", "bar")) + + def test_correct_code_challenge_method_s256(self): + # "abcd" as verifier gives a '+' to base64 + self.assertTrue( + authorization_code.code_challenge_method_s256("abcd", + "iNQmb9TmM40TuEX88olXnSCciXgjuSF9o-Fhk28DFYk") + ) + # "/" as verifier gives a '/' and '+' to base64 + self.assertTrue( + authorization_code.code_challenge_method_s256("/", + "il7asoJjJEMhngUeSt4tHVu8Zxx4EFG_FDeJfL3-oPE") + ) + # Example from PKCE RFCE + self.assertTrue( + authorization_code.code_challenge_method_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", + "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM") + ) + + def test_code_modifier_called(self): + bearer = BearerToken(self.mock_validator) + code_modifier = mock.MagicMock(wraps=lambda grant, *a: grant) + self.auth.register_code_modifier(code_modifier) + self.auth.create_authorization_response(self.request, bearer) + code_modifier.assert_called_once() + + def test_hybrid_token_save(self): + bearer = BearerToken(self.mock_validator) + self.auth.register_code_modifier( + lambda grant, *a: dict(list(grant.items()) + [('access_token', 1)]) + ) + self.auth.create_authorization_response(self.request, bearer) + self.mock_validator.save_token.assert_called_once() + + # CORS + + def test_create_cors_headers(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'https://foo.bar' + self.mock_validator.is_origin_allowed.return_value = True + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertEqual( + headers['Access-Control-Allow-Origin'], 'https://foo.bar' + ) + self.mock_validator.is_origin_allowed.assert_called_once_with( + 'abcdef', 'https://foo.bar', self.request + ) + + def test_create_cors_headers_no_origin(self): + bearer = BearerToken(self.mock_validator) + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_not_called() + + def test_create_cors_headers_insecure_origin(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'http://foo.bar' + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_not_called() + + def test_create_cors_headers_invalid_origin(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'https://foo.bar' + self.mock_validator.is_origin_allowed.return_value = False + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_called_once_with( + 'abcdef', 'https://foo.bar', self.request + ) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_client_credentials.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_client_credentials.py new file mode 100644 index 0000000000..e9559c7931 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_client_credentials.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +import json +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749.grant_types import ClientCredentialsGrant +from oauthlib.oauth2.rfc6749.tokens import BearerToken + +from tests.unittest import TestCase + + +class ClientCredentialsGrantTest(TestCase): + + def setUp(self): + mock_client = mock.MagicMock() + mock_client.user.return_value = 'mocked user' + self.request = Request('http://a.b/path') + self.request.grant_type = 'client_credentials' + self.request.client = mock_client + self.request.scopes = ('mocked', 'scopes') + self.mock_validator = mock.MagicMock() + self.auth = ClientCredentialsGrant( + request_validator=self.mock_validator) + + def test_custom_auth_validators_unsupported(self): + authval1, authval2 = mock.Mock(), mock.Mock() + expected = ('ClientCredentialsGrant does not support authorization ' + 'validators. Use token validators instead.') + with self.assertRaises(ValueError) as caught: + ClientCredentialsGrant(self.mock_validator, pre_auth=[authval1]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(ValueError) as caught: + ClientCredentialsGrant(self.mock_validator, post_auth=[authval2]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval1) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval2) + + def test_custom_token_validators(self): + tknval1, tknval2 = mock.Mock(), mock.Mock() + self.auth.custom_validators.pre_token.append(tknval1) + self.auth.custom_validators.post_token.append(tknval2) + + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(tknval1.called) + self.assertTrue(tknval2.called) + + def test_create_token_response(self): + bearer = BearerToken(self.mock_validator) + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('token_type', token) + self.assertIn('expires_in', token) + self.assertIn('Content-Type', headers) + self.assertEqual(headers['Content-Type'], 'application/json') + + def test_error_response(self): + bearer = BearerToken(self.mock_validator) + self.mock_validator.authenticate_client.return_value = False + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + self.assertEqual(self.mock_validator.save_token.call_count, 0) + error_msg = json.loads(body) + self.assertIn('error', error_msg) + self.assertEqual(error_msg['error'], 'invalid_client') + self.assertIn('Content-Type', headers) + self.assertEqual(headers['Content-Type'], 'application/json') + + def test_validate_token_response(self): + # wrong grant type, scope + pass diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_implicit.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_implicit.py new file mode 100644 index 0000000000..1fb71a1dc9 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_implicit.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749.grant_types import ImplicitGrant +from oauthlib.oauth2.rfc6749.tokens import BearerToken + +from tests.unittest import TestCase + + +class ImplicitGrantTest(TestCase): + + def setUp(self): + mock_client = mock.MagicMock() + mock_client.user.return_value = 'mocked user' + self.request = Request('http://a.b/path') + self.request.scopes = ('hello', 'world') + self.request.client = mock_client + self.request.client_id = 'abcdef' + self.request.response_type = 'token' + self.request.state = 'xyz' + self.request.redirect_uri = 'https://b.c/p' + + self.mock_validator = mock.MagicMock() + self.auth = ImplicitGrant(request_validator=self.mock_validator) + + @mock.patch('oauthlib.common.generate_token') + def test_create_token_response(self, generate_token): + generate_token.return_value = '1234' + bearer = BearerToken(self.mock_validator, expires_in=1800) + h, b, s = self.auth.create_token_response(self.request, bearer) + correct_uri = 'https://b.c/p#access_token=1234&token_type=Bearer&expires_in=1800&state=xyz&scope=hello+world' + self.assertEqual(s, 302) + self.assertURLEqual(h['Location'], correct_uri, parse_fragment=True) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + + correct_uri = 'https://b.c/p?access_token=1234&token_type=Bearer&expires_in=1800&state=xyz&scope=hello+world' + self.request.response_mode = 'query' + h, b, s = self.auth.create_token_response(self.request, bearer) + self.assertURLEqual(h['Location'], correct_uri) + + def test_custom_validators(self): + self.authval1, self.authval2 = mock.Mock(), mock.Mock() + self.tknval1, self.tknval2 = mock.Mock(), mock.Mock() + for val in (self.authval1, self.authval2): + val.return_value = {} + for val in (self.tknval1, self.tknval2): + val.return_value = None + self.auth.custom_validators.pre_token.append(self.tknval1) + self.auth.custom_validators.post_token.append(self.tknval2) + self.auth.custom_validators.pre_auth.append(self.authval1) + self.auth.custom_validators.post_auth.append(self.authval2) + + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(self.tknval1.called) + self.assertTrue(self.tknval2.called) + self.assertTrue(self.authval1.called) + self.assertTrue(self.authval2.called) + + def test_error_response(self): + pass diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_refresh_token.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_refresh_token.py new file mode 100644 index 0000000000..581f2a4d6a --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_refresh_token.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +import json +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749 import errors +from oauthlib.oauth2.rfc6749.grant_types import RefreshTokenGrant +from oauthlib.oauth2.rfc6749.tokens import BearerToken + +from tests.unittest import TestCase + + +class RefreshTokenGrantTest(TestCase): + + def setUp(self): + mock_client = mock.MagicMock() + mock_client.user.return_value = 'mocked user' + self.request = Request('http://a.b/path') + self.request.grant_type = 'refresh_token' + self.request.refresh_token = 'lsdkfhj230' + self.request.client_id = 'abcdef' + self.request.client = mock_client + self.request.scope = 'foo' + self.mock_validator = mock.MagicMock() + self.auth = RefreshTokenGrant( + request_validator=self.mock_validator) + + def test_create_token_response(self): + self.mock_validator.get_original_scopes.return_value = ['foo', 'bar'] + bearer = BearerToken(self.mock_validator) + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('token_type', token) + self.assertIn('expires_in', token) + self.assertEqual(token['scope'], 'foo') + + def test_custom_auth_validators_unsupported(self): + authval1, authval2 = mock.Mock(), mock.Mock() + expected = ('RefreshTokenGrant does not support authorization ' + 'validators. Use token validators instead.') + with self.assertRaises(ValueError) as caught: + RefreshTokenGrant(self.mock_validator, pre_auth=[authval1]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(ValueError) as caught: + RefreshTokenGrant(self.mock_validator, post_auth=[authval2]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval1) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval2) + + def test_custom_token_validators(self): + tknval1, tknval2 = mock.Mock(), mock.Mock() + self.auth.custom_validators.pre_token.append(tknval1) + self.auth.custom_validators.post_token.append(tknval2) + + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(tknval1.called) + self.assertTrue(tknval2.called) + + def test_create_token_inherit_scope(self): + self.request.scope = None + self.mock_validator.get_original_scopes.return_value = ['foo', 'bar'] + bearer = BearerToken(self.mock_validator) + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('token_type', token) + self.assertIn('expires_in', token) + self.assertEqual(token['scope'], 'foo bar') + + def test_create_token_within_original_scope(self): + self.mock_validator.get_original_scopes.return_value = ['baz'] + self.mock_validator.is_within_original_scope.return_value = True + bearer = BearerToken(self.mock_validator) + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('token_type', token) + self.assertIn('expires_in', token) + self.assertEqual(token['scope'], 'foo') + + def test_invalid_scope(self): + self.mock_validator.get_original_scopes.return_value = ['baz'] + self.mock_validator.is_within_original_scope.return_value = False + bearer = BearerToken(self.mock_validator) + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 0) + self.assertEqual(token['error'], 'invalid_scope') + self.assertEqual(status_code, 400) + + def test_invalid_token(self): + self.mock_validator.validate_refresh_token.return_value = False + bearer = BearerToken(self.mock_validator) + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 0) + self.assertEqual(token['error'], 'invalid_grant') + self.assertEqual(status_code, 400) + + def test_invalid_client(self): + self.mock_validator.authenticate_client.return_value = False + bearer = BearerToken(self.mock_validator) + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 0) + self.assertEqual(token['error'], 'invalid_client') + self.assertEqual(status_code, 401) + + def test_authentication_required(self): + """ + ensure client_authentication_required() is properly called + """ + self.mock_validator.authenticate_client.return_value = False + self.mock_validator.authenticate_client_id.return_value = False + self.request.code = 'waffles' + self.assertRaises(errors.InvalidClientError, self.auth.validate_token_request, + self.request) + self.mock_validator.client_authentication_required.assert_called_once_with(self.request) + + def test_invalid_grant_type(self): + self.request.grant_type = 'wrong_type' + self.assertRaises(errors.UnsupportedGrantTypeError, + self.auth.validate_token_request, self.request) + + def test_authenticate_client_id(self): + self.mock_validator.client_authentication_required.return_value = False + self.request.refresh_token = mock.MagicMock() + self.mock_validator.authenticate_client_id.return_value = False + self.assertRaises(errors.InvalidClientError, + self.auth.validate_token_request, self.request) + + def test_invalid_refresh_token(self): + # invalid refresh token + self.mock_validator.authenticate_client_id.return_value = True + self.mock_validator.validate_refresh_token.return_value = False + self.assertRaises(errors.InvalidGrantError, + self.auth.validate_token_request, self.request) + # no token provided + del self.request.refresh_token + self.assertRaises(errors.InvalidRequestError, + self.auth.validate_token_request, self.request) + + def test_invalid_scope_original_scopes_empty(self): + self.mock_validator.validate_refresh_token.return_value = True + self.mock_validator.is_within_original_scope.return_value = False + self.assertRaises(errors.InvalidScopeError, + self.auth.validate_token_request, self.request) + + def test_valid_token_request(self): + self.request.scope = 'foo bar' + self.mock_validator.get_original_scopes = mock.Mock() + self.mock_validator.get_original_scopes.return_value = 'foo bar baz' + self.auth.validate_token_request(self.request) + self.assertEqual(self.request.scopes, self.request.scope.split()) + # all ok but without request.scope + del self.request.scope + self.auth.validate_token_request(self.request) + self.assertEqual(self.request.scopes, 'foo bar baz'.split()) + + # CORS + + def test_create_cors_headers(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'https://foo.bar' + self.mock_validator.is_origin_allowed.return_value = True + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertEqual( + headers['Access-Control-Allow-Origin'], 'https://foo.bar' + ) + self.mock_validator.is_origin_allowed.assert_called_once_with( + 'abcdef', 'https://foo.bar', self.request + ) + + def test_create_cors_headers_no_origin(self): + bearer = BearerToken(self.mock_validator) + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_not_called() + + def test_create_cors_headers_insecure_origin(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'http://foo.bar' + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_not_called() + + def test_create_cors_headers_invalid_origin(self): + bearer = BearerToken(self.mock_validator) + self.request.headers['origin'] = 'https://foo.bar' + self.mock_validator.is_origin_allowed.return_value = False + + headers = self.auth.create_token_response(self.request, bearer)[0] + self.assertNotIn('Access-Control-Allow-Origin', headers) + self.mock_validator.is_origin_allowed.assert_called_once_with( + 'abcdef', 'https://foo.bar', self.request + ) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py new file mode 100644 index 0000000000..294e27be35 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/grant_types/test_resource_owner_password.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +import json +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749 import errors +from oauthlib.oauth2.rfc6749.grant_types import ( + ResourceOwnerPasswordCredentialsGrant, +) +from oauthlib.oauth2.rfc6749.tokens import BearerToken + +from tests.unittest import TestCase + + +class ResourceOwnerPasswordCredentialsGrantTest(TestCase): + + def setUp(self): + mock_client = mock.MagicMock() + mock_client.user.return_value = 'mocked user' + self.request = Request('http://a.b/path') + self.request.grant_type = 'password' + self.request.username = 'john' + self.request.password = 'doe' + self.request.client = mock_client + self.request.scopes = ('mocked', 'scopes') + self.mock_validator = mock.MagicMock() + self.auth = ResourceOwnerPasswordCredentialsGrant( + request_validator=self.mock_validator) + + def set_client(self, request, *args, **kwargs): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def test_create_token_response(self): + bearer = BearerToken(self.mock_validator) + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('token_type', token) + self.assertIn('expires_in', token) + self.assertIn('refresh_token', token) + # ensure client_authentication_required() is properly called + self.mock_validator.client_authentication_required.assert_called_once_with(self.request) + # fail client authentication + self.mock_validator.reset_mock() + self.mock_validator.validate_user.return_value = True + self.mock_validator.authenticate_client.return_value = False + status_code = self.auth.create_token_response(self.request, bearer)[2] + self.assertEqual(status_code, 401) + self.assertEqual(self.mock_validator.save_token.call_count, 0) + + # mock client_authentication_required() returning False then fail + self.mock_validator.reset_mock() + self.mock_validator.client_authentication_required.return_value = False + self.mock_validator.authenticate_client_id.return_value = False + status_code = self.auth.create_token_response(self.request, bearer)[2] + self.assertEqual(status_code, 401) + self.assertEqual(self.mock_validator.save_token.call_count, 0) + + def test_create_token_response_without_refresh_token(self): + # self.auth.refresh_token = False so we don't generate a refresh token + self.auth = ResourceOwnerPasswordCredentialsGrant( + request_validator=self.mock_validator, refresh_token=False) + bearer = BearerToken(self.mock_validator) + headers, body, status_code = self.auth.create_token_response( + self.request, bearer) + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('token_type', token) + self.assertIn('expires_in', token) + # ensure no refresh token is generated + self.assertNotIn('refresh_token', token) + # ensure client_authentication_required() is properly called + self.mock_validator.client_authentication_required.assert_called_once_with(self.request) + # fail client authentication + self.mock_validator.reset_mock() + self.mock_validator.validate_user.return_value = True + self.mock_validator.authenticate_client.return_value = False + status_code = self.auth.create_token_response(self.request, bearer)[2] + self.assertEqual(status_code, 401) + self.assertEqual(self.mock_validator.save_token.call_count, 0) + # mock client_authentication_required() returning False then fail + self.mock_validator.reset_mock() + self.mock_validator.client_authentication_required.return_value = False + self.mock_validator.authenticate_client_id.return_value = False + status_code = self.auth.create_token_response(self.request, bearer)[2] + self.assertEqual(status_code, 401) + self.assertEqual(self.mock_validator.save_token.call_count, 0) + + def test_custom_auth_validators_unsupported(self): + authval1, authval2 = mock.Mock(), mock.Mock() + expected = ('ResourceOwnerPasswordCredentialsGrant does not ' + 'support authorization validators. Use token ' + 'validators instead.') + with self.assertRaises(ValueError) as caught: + ResourceOwnerPasswordCredentialsGrant(self.mock_validator, + pre_auth=[authval1]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(ValueError) as caught: + ResourceOwnerPasswordCredentialsGrant(self.mock_validator, + post_auth=[authval2]) + self.assertEqual(caught.exception.args[0], expected) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval1) + with self.assertRaises(AttributeError): + self.auth.custom_validators.pre_auth.append(authval2) + + def test_custom_token_validators(self): + tknval1, tknval2 = mock.Mock(), mock.Mock() + self.auth.custom_validators.pre_token.append(tknval1) + self.auth.custom_validators.post_token.append(tknval2) + + bearer = BearerToken(self.mock_validator) + self.auth.create_token_response(self.request, bearer) + self.assertTrue(tknval1.called) + self.assertTrue(tknval2.called) + + def test_error_response(self): + pass + + def test_scopes(self): + pass + + def test_invalid_request_missing_params(self): + del self.request.grant_type + self.assertRaises(errors.InvalidRequestError, self.auth.validate_token_request, + self.request) + + def test_invalid_request_duplicates(self): + request = mock.MagicMock(wraps=self.request) + request.duplicate_params = ['scope'] + self.assertRaises(errors.InvalidRequestError, self.auth.validate_token_request, + request) + + def test_invalid_grant_type(self): + self.request.grant_type = 'foo' + self.assertRaises(errors.UnsupportedGrantTypeError, + self.auth.validate_token_request, self.request) + + def test_invalid_user(self): + self.mock_validator.validate_user.return_value = False + self.assertRaises(errors.InvalidGrantError, self.auth.validate_token_request, + self.request) + + def test_client_id_missing(self): + del self.request.client.client_id + self.assertRaises(NotImplementedError, self.auth.validate_token_request, + self.request) + + def test_valid_token_request(self): + self.mock_validator.validate_grant_type.return_value = True + self.auth.validate_token_request(self.request) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/test_parameters.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_parameters.py new file mode 100644 index 0000000000..cd8c9e952b --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_parameters.py @@ -0,0 +1,304 @@ +from unittest.mock import patch + +from oauthlib import signals +from oauthlib.oauth2.rfc6749.errors import * +from oauthlib.oauth2.rfc6749.parameters import * + +from tests.unittest import TestCase + + +@patch('time.time', new=lambda: 1000) +class ParameterTests(TestCase): + + state = 'xyz' + auth_base = { + 'uri': 'https://server.example.com/authorize', + 'client_id': 's6BhdRkqt3', + 'redirect_uri': 'https://client.example.com/cb', + 'state': state, + 'scope': 'photos' + } + list_scope = ['list', 'of', 'scopes'] + + auth_grant = {'response_type': 'code'} + auth_grant_pkce = {'response_type': 'code', 'code_challenge': "code_challenge", + 'code_challenge_method': 'code_challenge_method'} + auth_grant_list_scope = {} + auth_implicit = {'response_type': 'token', 'extra': 'extra'} + auth_implicit_list_scope = {} + + def setUp(self): + self.auth_grant.update(self.auth_base) + self.auth_grant_pkce.update(self.auth_base) + self.auth_implicit.update(self.auth_base) + self.auth_grant_list_scope.update(self.auth_grant) + self.auth_grant_list_scope['scope'] = self.list_scope + self.auth_implicit_list_scope.update(self.auth_implicit) + self.auth_implicit_list_scope['scope'] = self.list_scope + + auth_base_uri = ('https://server.example.com/authorize?response_type={0}' + '&client_id=s6BhdRkqt3&redirect_uri=https%3A%2F%2F' + 'client.example.com%2Fcb&scope={1}&state={2}{3}') + + auth_base_uri_pkce = ('https://server.example.com/authorize?response_type={0}' + '&client_id=s6BhdRkqt3&redirect_uri=https%3A%2F%2F' + 'client.example.com%2Fcb&scope={1}&state={2}{3}&code_challenge={4}' + '&code_challenge_method={5}') + + auth_grant_uri = auth_base_uri.format('code', 'photos', state, '') + auth_grant_uri_pkce = auth_base_uri_pkce.format('code', 'photos', state, '', 'code_challenge', + 'code_challenge_method') + auth_grant_uri_list_scope = auth_base_uri.format('code', 'list+of+scopes', state, '') + auth_implicit_uri = auth_base_uri.format('token', 'photos', state, '&extra=extra') + auth_implicit_uri_list_scope = auth_base_uri.format('token', 'list+of+scopes', state, '&extra=extra') + + grant_body = { + 'grant_type': 'authorization_code', + 'code': 'SplxlOBeZQQYbYS6WxSbIA', + 'redirect_uri': 'https://client.example.com/cb' + } + grant_body_pkce = { + 'grant_type': 'authorization_code', + 'code': 'SplxlOBeZQQYbYS6WxSbIA', + 'redirect_uri': 'https://client.example.com/cb', + 'code_verifier': 'code_verifier' + } + grant_body_scope = {'scope': 'photos'} + grant_body_list_scope = {'scope': list_scope} + auth_grant_body = ('grant_type=authorization_code&' + 'code=SplxlOBeZQQYbYS6WxSbIA&' + 'redirect_uri=https%3A%2F%2Fclient.example.com%2Fcb') + auth_grant_body_pkce = ('grant_type=authorization_code&' + 'code=SplxlOBeZQQYbYS6WxSbIA&' + 'redirect_uri=https%3A%2F%2Fclient.example.com%2Fcb' + '&code_verifier=code_verifier') + auth_grant_body_scope = auth_grant_body + '&scope=photos' + auth_grant_body_list_scope = auth_grant_body + '&scope=list+of+scopes' + + pwd_body = { + 'grant_type': 'password', + 'username': 'johndoe', + 'password': 'A3ddj3w' + } + password_body = 'grant_type=password&username=johndoe&password=A3ddj3w' + + cred_grant = {'grant_type': 'client_credentials'} + cred_body = 'grant_type=client_credentials' + + grant_response = 'https://client.example.com/cb?code=SplxlOBeZQQYbYS6WxSbIA&state=xyz' + grant_dict = {'code': 'SplxlOBeZQQYbYS6WxSbIA', 'state': state} + + error_nocode = 'https://client.example.com/cb?state=xyz' + error_nostate = 'https://client.example.com/cb?code=SplxlOBeZQQYbYS6WxSbIA' + error_wrongstate = 'https://client.example.com/cb?code=SplxlOBeZQQYbYS6WxSbIA&state=abc' + error_denied = 'https://client.example.com/cb?error=access_denied&state=xyz' + error_invalid = 'https://client.example.com/cb?error=invalid_request&state=xyz' + + implicit_base = 'https://example.com/cb#access_token=2YotnFZFEjr1zCsicMWpAA&scope=abc&' + implicit_response = implicit_base + 'state={}&token_type=example&expires_in=3600'.format(state) + implicit_notype = implicit_base + 'state={}&expires_in=3600'.format(state) + implicit_wrongstate = implicit_base + 'state={}&token_type=exampleexpires_in=3600'.format('invalid') + implicit_nostate = implicit_base + 'token_type=example&expires_in=3600' + implicit_notoken = 'https://example.com/cb#state=xyz&token_type=example&expires_in=3600' + + implicit_dict = { + 'access_token': '2YotnFZFEjr1zCsicMWpAA', + 'state': state, + 'token_type': 'example', + 'expires_in': 3600, + 'expires_at': 4600, + 'scope': ['abc'] + } + + json_response = ('{ "access_token": "2YotnFZFEjr1zCsicMWpAA",' + ' "token_type": "example",' + ' "expires_in": 3600,' + ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter": "example_value",' + ' "scope":"abc def"}') + json_response_noscope = ('{ "access_token": "2YotnFZFEjr1zCsicMWpAA",' + ' "token_type": "example",' + ' "expires_in": 3600,' + ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter": "example_value" }') + json_response_noexpire = ('{ "access_token": "2YotnFZFEjr1zCsicMWpAA",' + ' "token_type": "example",' + ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter": "example_value"}') + json_response_expirenull = ('{ "access_token": "2YotnFZFEjr1zCsicMWpAA",' + ' "token_type": "example",' + ' "expires_in": null,' + ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter": "example_value"}') + + json_custom_error = '{ "error": "incorrect_client_credentials" }' + json_error = '{ "error": "access_denied" }' + + json_notoken = ('{ "token_type": "example",' + ' "expires_in": 3600,' + ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter": "example_value" }') + + json_notype = ('{ "access_token": "2YotnFZFEjr1zCsicMWpAA",' + ' "expires_in": 3600,' + ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter": "example_value" }') + + json_dict = { + 'access_token': '2YotnFZFEjr1zCsicMWpAA', + 'token_type': 'example', + 'expires_in': 3600, + 'expires_at': 4600, + 'refresh_token': 'tGzv3JOkF0XG5Qx2TlKWIA', + 'example_parameter': 'example_value', + 'scope': ['abc', 'def'] + } + + json_noscope_dict = { + 'access_token': '2YotnFZFEjr1zCsicMWpAA', + 'token_type': 'example', + 'expires_in': 3600, + 'expires_at': 4600, + 'refresh_token': 'tGzv3JOkF0XG5Qx2TlKWIA', + 'example_parameter': 'example_value' + } + + json_noexpire_dict = { + 'access_token': '2YotnFZFEjr1zCsicMWpAA', + 'token_type': 'example', + 'refresh_token': 'tGzv3JOkF0XG5Qx2TlKWIA', + 'example_parameter': 'example_value' + } + + json_notype_dict = { + 'access_token': '2YotnFZFEjr1zCsicMWpAA', + 'expires_in': 3600, + 'expires_at': 4600, + 'refresh_token': 'tGzv3JOkF0XG5Qx2TlKWIA', + 'example_parameter': 'example_value', + } + + url_encoded_response = ('access_token=2YotnFZFEjr1zCsicMWpAA' + '&token_type=example' + '&expires_in=3600' + '&refresh_token=tGzv3JOkF0XG5Qx2TlKWIA' + '&example_parameter=example_value' + '&scope=abc def') + + url_encoded_error = 'error=access_denied' + + url_encoded_notoken = ('token_type=example' + '&expires_in=3600' + '&refresh_token=tGzv3JOkF0XG5Qx2TlKWIA' + '&example_parameter=example_value') + + + def test_prepare_grant_uri(self): + """Verify correct authorization URI construction.""" + self.assertURLEqual(prepare_grant_uri(**self.auth_grant), self.auth_grant_uri) + self.assertURLEqual(prepare_grant_uri(**self.auth_grant_list_scope), self.auth_grant_uri_list_scope) + self.assertURLEqual(prepare_grant_uri(**self.auth_implicit), self.auth_implicit_uri) + self.assertURLEqual(prepare_grant_uri(**self.auth_implicit_list_scope), self.auth_implicit_uri_list_scope) + self.assertURLEqual(prepare_grant_uri(**self.auth_grant_pkce), self.auth_grant_uri_pkce) + + def test_prepare_token_request(self): + """Verify correct access token request body construction.""" + self.assertFormBodyEqual(prepare_token_request(**self.grant_body), self.auth_grant_body) + self.assertFormBodyEqual(prepare_token_request(**self.pwd_body), self.password_body) + self.assertFormBodyEqual(prepare_token_request(**self.cred_grant), self.cred_body) + self.assertFormBodyEqual(prepare_token_request(**self.grant_body_pkce), self.auth_grant_body_pkce) + + def test_grant_response(self): + """Verify correct parameter parsing and validation for auth code responses.""" + params = parse_authorization_code_response(self.grant_response) + self.assertEqual(params, self.grant_dict) + params = parse_authorization_code_response(self.grant_response, state=self.state) + self.assertEqual(params, self.grant_dict) + + self.assertRaises(MissingCodeError, parse_authorization_code_response, + self.error_nocode) + self.assertRaises(AccessDeniedError, parse_authorization_code_response, + self.error_denied) + self.assertRaises(InvalidRequestFatalError, parse_authorization_code_response, + self.error_invalid) + self.assertRaises(MismatchingStateError, parse_authorization_code_response, + self.error_nostate, state=self.state) + self.assertRaises(MismatchingStateError, parse_authorization_code_response, + self.error_wrongstate, state=self.state) + + def test_implicit_token_response(self): + """Verify correct parameter parsing and validation for implicit responses.""" + self.assertEqual(parse_implicit_response(self.implicit_response), + self.implicit_dict) + self.assertRaises(MissingTokenError, parse_implicit_response, + self.implicit_notoken) + self.assertRaises(ValueError, parse_implicit_response, + self.implicit_nostate, state=self.state) + self.assertRaises(ValueError, parse_implicit_response, + self.implicit_wrongstate, state=self.state) + + def test_custom_json_error(self): + self.assertRaises(CustomOAuth2Error, parse_token_response, self.json_custom_error) + + def test_json_token_response(self): + """Verify correct parameter parsing and validation for token responses. """ + self.assertEqual(parse_token_response(self.json_response), self.json_dict) + self.assertRaises(AccessDeniedError, parse_token_response, self.json_error) + self.assertRaises(MissingTokenError, parse_token_response, self.json_notoken) + + self.assertEqual(parse_token_response(self.json_response_noscope, + scope=['all', 'the', 'scopes']), self.json_noscope_dict) + self.assertEqual(parse_token_response(self.json_response_noexpire), self.json_noexpire_dict) + self.assertEqual(parse_token_response(self.json_response_expirenull), self.json_noexpire_dict) + + scope_changes_recorded = [] + def record_scope_change(sender, message, old, new): + scope_changes_recorded.append((message, old, new)) + + os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '1' + signals.scope_changed.connect(record_scope_change) + try: + parse_token_response(self.json_response, scope='aaa') + self.assertEqual(len(scope_changes_recorded), 1) + message, old, new = scope_changes_recorded[0] + for scope in new + old: + self.assertIn(scope, message) + self.assertEqual(old, ['aaa']) + self.assertEqual(set(new), {'abc', 'def'}) + finally: + signals.scope_changed.disconnect(record_scope_change) + del os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] + + + def test_json_token_notype(self): + """Verify strict token type parsing only when configured. """ + self.assertEqual(parse_token_response(self.json_notype), self.json_notype_dict) + try: + os.environ['OAUTHLIB_STRICT_TOKEN_TYPE'] = '1' + self.assertRaises(MissingTokenTypeError, parse_token_response, self.json_notype) + finally: + del os.environ['OAUTHLIB_STRICT_TOKEN_TYPE'] + + def test_url_encoded_token_response(self): + """Verify fallback parameter parsing and validation for token responses. """ + self.assertEqual(parse_token_response(self.url_encoded_response), self.json_dict) + self.assertRaises(AccessDeniedError, parse_token_response, self.url_encoded_error) + self.assertRaises(MissingTokenError, parse_token_response, self.url_encoded_notoken) + + scope_changes_recorded = [] + def record_scope_change(sender, message, old, new): + scope_changes_recorded.append((message, old, new)) + + os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '1' + signals.scope_changed.connect(record_scope_change) + try: + token = parse_token_response(self.url_encoded_response, scope='aaa') + self.assertEqual(len(scope_changes_recorded), 1) + message, old, new = scope_changes_recorded[0] + for scope in new + old: + self.assertIn(scope, message) + self.assertEqual(old, ['aaa']) + self.assertEqual(set(new), {'abc', 'def'}) + finally: + signals.scope_changed.disconnect(record_scope_change) + del os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/test_request_validator.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_request_validator.py new file mode 100644 index 0000000000..7a8d06b668 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_request_validator.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +from oauthlib.oauth2 import RequestValidator + +from tests.unittest import TestCase + + +class RequestValidatorTest(TestCase): + + def test_method_contracts(self): + v = RequestValidator() + self.assertRaises(NotImplementedError, v.authenticate_client, 'r') + self.assertRaises(NotImplementedError, v.authenticate_client_id, + 'client_id', 'r') + self.assertRaises(NotImplementedError, v.confirm_redirect_uri, + 'client_id', 'code', 'redirect_uri', 'client', 'request') + self.assertRaises(NotImplementedError, v.get_default_redirect_uri, + 'client_id', 'request') + self.assertRaises(NotImplementedError, v.get_default_scopes, + 'client_id', 'request') + self.assertRaises(NotImplementedError, v.get_original_scopes, + 'refresh_token', 'request') + self.assertFalse(v.is_within_original_scope( + ['scope'], 'refresh_token', 'request')) + self.assertRaises(NotImplementedError, v.invalidate_authorization_code, + 'client_id', 'code', 'request') + self.assertRaises(NotImplementedError, v.save_authorization_code, + 'client_id', 'code', 'request') + self.assertRaises(NotImplementedError, v.save_bearer_token, + 'token', 'request') + self.assertRaises(NotImplementedError, v.validate_bearer_token, + 'token', 'scopes', 'request') + self.assertRaises(NotImplementedError, v.validate_client_id, + 'client_id', 'request') + self.assertRaises(NotImplementedError, v.validate_code, + 'client_id', 'code', 'client', 'request') + self.assertRaises(NotImplementedError, v.validate_grant_type, + 'client_id', 'grant_type', 'client', 'request') + self.assertRaises(NotImplementedError, v.validate_redirect_uri, + 'client_id', 'redirect_uri', 'request') + self.assertRaises(NotImplementedError, v.validate_refresh_token, + 'refresh_token', 'client', 'request') + self.assertRaises(NotImplementedError, v.validate_response_type, + 'client_id', 'response_type', 'client', 'request') + self.assertRaises(NotImplementedError, v.validate_scopes, + 'client_id', 'scopes', 'client', 'request') + self.assertRaises(NotImplementedError, v.validate_user, + 'username', 'password', 'client', 'request') + self.assertTrue(v.client_authentication_required('r')) + self.assertFalse( + v.is_origin_allowed('client_id', 'https://foo.bar', 'r') + ) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/test_server.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_server.py new file mode 100644 index 0000000000..94af37e56b --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_server.py @@ -0,0 +1,391 @@ +# -*- coding: utf-8 -*- +import json +from unittest import mock + +from oauthlib import common +from oauthlib.oauth2.rfc6749 import errors, tokens +from oauthlib.oauth2.rfc6749.endpoints import Server +from oauthlib.oauth2.rfc6749.endpoints.authorization import ( + AuthorizationEndpoint, +) +from oauthlib.oauth2.rfc6749.endpoints.resource import ResourceEndpoint +from oauthlib.oauth2.rfc6749.endpoints.token import TokenEndpoint +from oauthlib.oauth2.rfc6749.grant_types import ( + AuthorizationCodeGrant, ClientCredentialsGrant, ImplicitGrant, + ResourceOwnerPasswordCredentialsGrant, +) + +from tests.unittest import TestCase + + +class AuthorizationEndpointTest(TestCase): + + def setUp(self): + self.mock_validator = mock.MagicMock() + self.mock_validator.get_code_challenge.return_value = None + self.addCleanup(setattr, self, 'mock_validator', mock.MagicMock()) + auth_code = AuthorizationCodeGrant( + request_validator=self.mock_validator) + auth_code.save_authorization_code = mock.MagicMock() + implicit = ImplicitGrant( + request_validator=self.mock_validator) + implicit.save_token = mock.MagicMock() + + response_types = { + 'code': auth_code, + 'token': implicit, + 'none': auth_code + } + self.expires_in = 1800 + token = tokens.BearerToken( + self.mock_validator, + expires_in=self.expires_in + ) + self.endpoint = AuthorizationEndpoint( + default_response_type='code', + default_token_type=token, + response_types=response_types + ) + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_authorization_grant(self): + uri = 'http://i.b/l?response_type=code&client_id=me&scope=all+of+them&state=xyz' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me?code=abc&state=xyz') + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_implicit_grant(self): + uri = 'http://i.b/l?response_type=token&client_id=me&scope=all+of+them&state=xyz' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me#access_token=abc&expires_in=' + str(self.expires_in) + '&token_type=Bearer&state=xyz&scope=all+of+them', parse_fragment=True) + + def test_none_grant(self): + uri = 'http://i.b/l?response_type=none&client_id=me&scope=all+of+them&state=xyz' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me?state=xyz', parse_fragment=True) + self.assertIsNone(body) + self.assertEqual(status_code, 302) + + # and without the state parameter + uri = 'http://i.b/l?response_type=none&client_id=me&scope=all+of+them' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me', parse_fragment=True) + self.assertIsNone(body) + self.assertEqual(status_code, 302) + + def test_missing_type(self): + uri = 'http://i.b/l?client_id=me&scope=all+of+them' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + self.mock_validator.validate_request = mock.MagicMock( + side_effect=errors.InvalidRequestError()) + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me?error=invalid_request&error_description=Missing+response_type+parameter.') + + def test_invalid_type(self): + uri = 'http://i.b/l?response_type=invalid&client_id=me&scope=all+of+them' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + self.mock_validator.validate_request = mock.MagicMock( + side_effect=errors.UnsupportedResponseTypeError()) + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me?error=unsupported_response_type') + + +class TokenEndpointTest(TestCase): + + def setUp(self): + def set_user(request): + request.user = mock.MagicMock() + request.client = mock.MagicMock() + request.client.client_id = 'mocked_client_id' + return True + + self.mock_validator = mock.MagicMock() + self.mock_validator.authenticate_client.side_effect = set_user + self.mock_validator.get_code_challenge.return_value = None + self.addCleanup(setattr, self, 'mock_validator', mock.MagicMock()) + auth_code = AuthorizationCodeGrant( + request_validator=self.mock_validator) + password = ResourceOwnerPasswordCredentialsGrant( + request_validator=self.mock_validator) + client = ClientCredentialsGrant( + request_validator=self.mock_validator) + supported_types = { + 'authorization_code': auth_code, + 'password': password, + 'client_credentials': client, + } + self.expires_in = 1800 + token = tokens.BearerToken( + self.mock_validator, + expires_in=self.expires_in + ) + self.endpoint = TokenEndpoint( + 'authorization_code', + default_token_type=token, + grant_types=supported_types + ) + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_authorization_grant(self): + body = 'grant_type=authorization_code&code=abc&scope=all+of+them' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': 'abc', + 'refresh_token': 'abc', + 'scope': 'all of them' + } + self.assertEqual(json.loads(body), token) + + body = 'grant_type=authorization_code&code=abc' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': 'abc', + 'refresh_token': 'abc' + } + self.assertEqual(json.loads(body), token) + + # try with additional custom variables + body = 'grant_type=authorization_code&code=abc&state=foobar' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + self.assertEqual(json.loads(body), token) + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_password_grant(self): + body = 'grant_type=password&username=a&password=hello&scope=all+of+them' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': 'abc', + 'refresh_token': 'abc', + 'scope': 'all of them', + } + self.assertEqual(json.loads(body), token) + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_client_grant(self): + body = 'grant_type=client_credentials&scope=all+of+them' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': 'abc', + 'scope': 'all of them', + } + self.assertEqual(json.loads(body), token) + + def test_missing_type(self): + _, body, _ = self.endpoint.create_token_response('', body='') + token = {'error': 'unsupported_grant_type'} + self.assertEqual(json.loads(body), token) + + def test_invalid_type(self): + body = 'grant_type=invalid' + _, body, _ = self.endpoint.create_token_response('', body=body) + token = {'error': 'unsupported_grant_type'} + self.assertEqual(json.loads(body), token) + + +class SignedTokenEndpointTest(TestCase): + + def setUp(self): + self.expires_in = 1800 + + def set_user(request): + request.user = mock.MagicMock() + request.client = mock.MagicMock() + request.client.client_id = 'mocked_client_id' + return True + + self.mock_validator = mock.MagicMock() + self.mock_validator.get_code_challenge.return_value = None + self.mock_validator.authenticate_client.side_effect = set_user + self.addCleanup(setattr, self, 'mock_validator', mock.MagicMock()) + + self.private_pem = """ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA6TtDhWGwzEOWZP6m/zHoZnAPLABfetvoMPmxPGjFjtDuMRPv +EvI1sbixZBjBtdnc5rTtHUUQ25Am3JzwPRGo5laMGbj1pPyCPxlVi9LK82HQNX0B +YK7tZtVfDHElQA7F4v3j9d3rad4O9/n+lyGIQ0tT7yQcBm2A8FEaP0bZYCLMjwMN +WfaVLE8eXHyv+MfpNNLI9wttLxygKYM48I3NwsFuJgOa/KuodXaAmf8pJnx8t1Wn +nxvaYXFiUn/TxmhM/qhemPa6+0nqq+aWV5eT7xn4K/ghLgNs09v6Yge0pmPl9Oz+ ++bjJ+aKRnAmwCOY8/5U5EilAiUOeBoO9+8OXtwIDAQABAoIBAGFTTbXXMkPK4HN8 +oItVdDlrAanG7hECuz3UtFUVE3upS/xG6TjqweVLwRqYCh2ssDXFwjy4mXRGDzF4 +e/e/6s9Txlrlh/w1MtTJ6ZzTdcViR9RKOczysjZ7S5KRlI3KnGFAuWPcG2SuOWjZ +dZfzcj1Crd/ZHajBAVFHRsCo/ATVNKbTRprFfb27xKpQ2BwH/GG781sLE3ZVNIhs +aRRaED4622kI1E/WXws2qQMqbFKzo0m1tPbLb3Z89WgZJ/tRQwuDype1Vfm7k6oX +xfbp3948qSe/yWKRlMoPkleji/WxPkSIalzWSAi9ziN/0Uzhe65FURgrfHL3XR1A +B8UR+aECgYEA7NPQZV4cAikk02Hv65JgISofqV49P8MbLXk8sdnI1n7Mj10TgzU3 +lyQGDEX4hqvT0bTXe4KAOxQZx9wumu05ejfzhdtSsEm6ptGHyCdmYDQeV0C/pxDX +JNCK8XgMku2370XG0AnyBCT7NGlgtDcNCQufcesF2gEuoKiXg6Zjo7sCgYEA/Bzs +9fWGZZnSsMSBSW2OYbFuhF3Fne0HcxXQHipl0Rujc/9g0nccwqKGizn4fGOE7a8F +usQgJoeGcinL7E9OEP/uQ9VX1C9RNVjIxP1O5/Guw1zjxQQYetOvbPhN2QhD1Ye7 +0TRKrW1BapcjwLpFQlVg1ZeTPOi5lv24W/wX9jUCgYEAkrMSX/hPuTbrTNVZ3L6r +NV/2hN+PaTPeXei/pBuXwOaCqDurnpcUfFcgN/IP5LwDVd+Dq0pHTFFDNv45EFbq +R77o5n3ZVsIVEMiyJ1XgoK8oLDw7e61+15smtjT69Piz+09pu+ytMcwGn4y3Dmsb +dALzHYnL8iLRU0ubrz0ec4kCgYAJiVKRTzNBPptQom49h85d9ac3jJCAE8o3WTjh +Gzt0uHXrWlqgO280EY/DTnMOyXjqwLcXxHlu26uDP/99tdY/IF8z46sJ1KxetzgI +84f7kBHLRAU9m5UNeFpnZdEUB5MBTbwWAsNcYgiabpMkpCcghjg+fBhOsoLqqjhC +CnwhjQKBgQDkv0QTdyBU84TE8J0XY3eLQwXbrvG2yD5A2ntN3PyxGEneX5WTJGMZ +xJxwaFYQiDS3b9E7b8Q5dg8qa5Y1+epdhx3cuQAWPm+AoHKshDfbRve4txBDQAqh +c6MxSWgsa+2Ld5SWSNbGtpPcmEM3Fl5ttMCNCKtNc0UE16oHwaPAIw== +-----END RSA PRIVATE KEY----- + """ + + self.public_pem = """ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA6TtDhWGwzEOWZP6m/zHo +ZnAPLABfetvoMPmxPGjFjtDuMRPvEvI1sbixZBjBtdnc5rTtHUUQ25Am3JzwPRGo +5laMGbj1pPyCPxlVi9LK82HQNX0BYK7tZtVfDHElQA7F4v3j9d3rad4O9/n+lyGI +Q0tT7yQcBm2A8FEaP0bZYCLMjwMNWfaVLE8eXHyv+MfpNNLI9wttLxygKYM48I3N +wsFuJgOa/KuodXaAmf8pJnx8t1WnnxvaYXFiUn/TxmhM/qhemPa6+0nqq+aWV5eT +7xn4K/ghLgNs09v6Yge0pmPl9Oz++bjJ+aKRnAmwCOY8/5U5EilAiUOeBoO9+8OX +twIDAQAB +-----END PUBLIC KEY----- + """ + + signed_token = tokens.signed_token_generator(self.private_pem, + user_id=123) + self.endpoint = Server( + self.mock_validator, + token_expires_in=self.expires_in, + token_generator=signed_token, + refresh_token_generator=tokens.random_token_generator + ) + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_authorization_grant(self): + body = 'client_id=me&redirect_uri=http%3A%2F%2Fback.to%2Fme&grant_type=authorization_code&code=abc&scope=all+of+them' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + body = json.loads(body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': body['access_token'], + 'refresh_token': 'abc', + 'scope': 'all of them' + } + self.assertEqual(body, token) + + body = 'client_id=me&redirect_uri=http%3A%2F%2Fback.to%2Fme&grant_type=authorization_code&code=abc' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + body = json.loads(body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': body['access_token'], + 'refresh_token': 'abc' + } + self.assertEqual(body, token) + + # try with additional custom variables + body = 'client_id=me&redirect_uri=http%3A%2F%2Fback.to%2Fme&grant_type=authorization_code&code=abc&state=foobar' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + body = json.loads(body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': body['access_token'], + 'refresh_token': 'abc' + } + self.assertEqual(body, token) + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_password_grant(self): + body = 'grant_type=password&username=a&password=hello&scope=all+of+them' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + body = json.loads(body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': body['access_token'], + 'refresh_token': 'abc', + 'scope': 'all of them', + } + self.assertEqual(body, token) + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_scopes_and_user_id_stored_in_access_token(self): + body = 'grant_type=password&username=a&password=hello&scope=all+of+them' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + + access_token = json.loads(body)['access_token'] + + claims = common.verify_signed_token(self.public_pem, access_token) + + self.assertEqual(claims['scope'], 'all of them') + self.assertEqual(claims['user_id'], 123) + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_client_grant(self): + body = 'grant_type=client_credentials&scope=all+of+them' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + body = json.loads(body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': body['access_token'], + 'scope': 'all of them', + } + self.assertEqual(body, token) + + def test_missing_type(self): + _, body, _ = self.endpoint.create_token_response('', body='client_id=me&redirect_uri=http%3A%2F%2Fback.to%2Fme&code=abc') + token = {'error': 'unsupported_grant_type'} + self.assertEqual(json.loads(body), token) + + def test_invalid_type(self): + body = 'client_id=me&redirect_uri=http%3A%2F%2Fback.to%2Fme&grant_type=invalid&code=abc' + _, body, _ = self.endpoint.create_token_response('', body=body) + token = {'error': 'unsupported_grant_type'} + self.assertEqual(json.loads(body), token) + + +class ResourceEndpointTest(TestCase): + + def setUp(self): + self.mock_validator = mock.MagicMock() + self.addCleanup(setattr, self, 'mock_validator', mock.MagicMock()) + token = tokens.BearerToken(request_validator=self.mock_validator) + self.endpoint = ResourceEndpoint( + default_token='Bearer', + token_types={'Bearer': token} + ) + + def test_defaults(self): + uri = 'http://a.b/path?some=query' + self.mock_validator.validate_bearer_token.return_value = False + valid, request = self.endpoint.verify_request(uri) + self.assertFalse(valid) + self.assertEqual(request.token_type, 'Bearer') diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/test_tokens.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_tokens.py new file mode 100644 index 0000000000..fa6b1c092c --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_tokens.py @@ -0,0 +1,170 @@ +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749.tokens import ( + BearerToken, prepare_bearer_body, prepare_bearer_headers, + prepare_bearer_uri, prepare_mac_header, +) + +from tests.unittest import TestCase + + +class TokenTest(TestCase): + + # MAC without body/payload or extension + mac_plain = { + 'token': 'h480djs93hd8', + 'uri': 'http://example.com/resource/1?b=1&a=2', + 'key': '489dks293j39', + 'http_method': 'GET', + 'nonce': '264095:dj83hs9s', + 'hash_algorithm': 'hmac-sha-1' + } + auth_plain = { + 'Authorization': 'MAC id="h480djs93hd8", nonce="264095:dj83hs9s",' + ' mac="SLDJd4mg43cjQfElUs3Qub4L6xE="' + } + + # MAC with body/payload, no extension + mac_body = { + 'token': 'jd93dh9dh39D', + 'uri': 'http://example.com/request', + 'key': '8yfrufh348h', + 'http_method': 'POST', + 'nonce': '273156:di3hvdf8', + 'hash_algorithm': 'hmac-sha-1', + 'body': 'hello=world%21' + } + auth_body = { + 'Authorization': 'MAC id="jd93dh9dh39D", nonce="273156:di3hvdf8",' + ' bodyhash="k9kbtCIy0CkI3/FEfpS/oIDjk6k=", mac="W7bdMZbv9UWOTadASIQHagZyirA="' + } + + # MAC with body/payload and extension + mac_both = { + 'token': 'h480djs93hd8', + 'uri': 'http://example.com/request?b5=%3D%253D&a3=a&c%40=&a2=r%20b&c2&a3=2+q', + 'key': '489dks293j39', + 'http_method': 'GET', + 'nonce': '264095:7d8f3e4a', + 'hash_algorithm': 'hmac-sha-1', + 'body': 'Hello World!', + 'ext': 'a,b,c' + } + auth_both = { + 'Authorization': 'MAC id="h480djs93hd8", nonce="264095:7d8f3e4a",' + ' bodyhash="Lve95gjOVATpfV8EL5X4nxwjKHE=", ext="a,b,c",' + ' mac="Z3C2DojEopRDIC88/imW8Ez853g="' + } + + # Bearer + token = 'vF9dft4qmT' + uri = 'http://server.example.com/resource' + bearer_headers = { + 'Authorization': 'Bearer vF9dft4qmT' + } + valid_bearer_header_lowercase = {"Authorization": "bearer vF9dft4qmT"} + fake_bearer_headers = [ + {'Authorization': 'Beaver vF9dft4qmT'}, + {'Authorization': 'BeavervF9dft4qmT'}, + {'Authorization': 'Beaver vF9dft4qmT'}, + {'Authorization': 'BearerF9dft4qmT'}, + {'Authorization': 'Bearer vF9d ft4qmT'}, + ] + valid_header_with_multiple_spaces = {'Authorization': 'Bearer vF9dft4qmT'} + bearer_body = 'access_token=vF9dft4qmT' + bearer_uri = 'http://server.example.com/resource?access_token=vF9dft4qmT' + + def _mocked_validate_bearer_token(self, token, scopes, request): + if not token: + return False + return True + + def test_prepare_mac_header(self): + """Verify mac signatures correctness + + TODO: verify hmac-sha-256 + """ + self.assertEqual(prepare_mac_header(**self.mac_plain), self.auth_plain) + self.assertEqual(prepare_mac_header(**self.mac_body), self.auth_body) + self.assertEqual(prepare_mac_header(**self.mac_both), self.auth_both) + + def test_prepare_bearer_request(self): + """Verify proper addition of bearer tokens to requests. + + They may be represented as query components in body or URI or + in a Bearer authorization header. + """ + self.assertEqual(prepare_bearer_headers(self.token), self.bearer_headers) + self.assertEqual(prepare_bearer_body(self.token), self.bearer_body) + self.assertEqual(prepare_bearer_uri(self.token, uri=self.uri), self.bearer_uri) + + def test_valid_bearer_is_validated(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + request = Request("/", headers=self.bearer_headers) + result = BearerToken(request_validator=request_validator).validate_request( + request + ) + self.assertTrue(result) + + def test_lowercase_bearer_is_validated(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + request = Request("/", headers=self.valid_bearer_header_lowercase) + result = BearerToken(request_validator=request_validator).validate_request( + request + ) + self.assertTrue(result) + + def test_fake_bearer_is_not_validated(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + for fake_header in self.fake_bearer_headers: + request = Request("/", headers=fake_header) + result = BearerToken(request_validator=request_validator).validate_request( + request + ) + + self.assertFalse(result) + + def test_header_with_multispaces_is_validated(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + request = Request("/", headers=self.valid_header_with_multiple_spaces) + result = BearerToken(request_validator=request_validator).validate_request( + request + ) + + self.assertTrue(result) + + def test_estimate_type(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + request = Request("/", headers=self.bearer_headers) + result = BearerToken(request_validator=request_validator).estimate_type(request) + self.assertEqual(result, 9) + + def test_estimate_type_with_fake_header_returns_type_0(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + for fake_header in self.fake_bearer_headers: + request = Request("/", headers=fake_header) + result = BearerToken(request_validator=request_validator).estimate_type( + request + ) + + if ( + fake_header["Authorization"].count(" ") == 2 + and fake_header["Authorization"].split()[0] == "Bearer" + ): + # If we're dealing with the header containing 2 spaces, it will be recognized + # as a Bearer valid header, the token itself will be invalid by the way. + self.assertEqual(result, 9) + else: + self.assertEqual(result, 0) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc6749/test_utils.py b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_utils.py new file mode 100644 index 0000000000..3299591926 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc6749/test_utils.py @@ -0,0 +1,100 @@ +import datetime +import os + +from oauthlib.oauth2.rfc6749.utils import ( + escape, generate_age, host_from_uri, is_secure_transport, list_to_scope, + params_from_uri, scope_to_list, +) + +from tests.unittest import TestCase + + +class ScopeObject: + """ + Fixture for testing list_to_scope()/scope_to_list() with objects other + than regular strings. + """ + def __init__(self, scope): + self.scope = scope + + def __str__(self): + return self.scope + + +class UtilsTests(TestCase): + + def test_escape(self): + """Assert that we are only escaping unicode""" + self.assertRaises(ValueError, escape, b"I am a string type. Not a unicode type.") + self.assertEqual(escape("I am a unicode type."), "I%20am%20a%20unicode%20type.") + + def test_host_from_uri(self): + """Test if hosts and ports are properly extracted from URIs. + + This should be done according to the MAC Authentication spec. + Defaults ports should be provided when none is present in the URI. + """ + self.assertEqual(host_from_uri('http://a.b-c.com:8080'), ('a.b-c.com', '8080')) + self.assertEqual(host_from_uri('https://a.b.com:8080'), ('a.b.com', '8080')) + self.assertEqual(host_from_uri('http://www.example.com'), ('www.example.com', '80')) + self.assertEqual(host_from_uri('https://www.example.com'), ('www.example.com', '443')) + + def test_is_secure_transport(self): + """Test check secure uri.""" + if 'OAUTHLIB_INSECURE_TRANSPORT' in os.environ: + del os.environ['OAUTHLIB_INSECURE_TRANSPORT'] + + self.assertTrue(is_secure_transport('https://example.com')) + self.assertFalse(is_secure_transport('http://example.com')) + + os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = '1' + self.assertTrue(is_secure_transport('http://example.com')) + del os.environ['OAUTHLIB_INSECURE_TRANSPORT'] + + def test_params_from_uri(self): + self.assertEqual(params_from_uri('http://i.b/?foo=bar&g&scope=a+d'), + {'foo': 'bar', 'g': '', 'scope': ['a', 'd']}) + + def test_generate_age(self): + issue_time = datetime.datetime.now() - datetime.timedelta( + days=3, minutes=1, seconds=4) + self.assertGreater(float(generate_age(issue_time)), 259263.0) + + def test_list_to_scope(self): + expected = 'foo bar baz' + + string_list = ['foo', 'bar', 'baz'] + self.assertEqual(list_to_scope(string_list), expected) + + string_tuple = ('foo', 'bar', 'baz') + self.assertEqual(list_to_scope(string_tuple), expected) + + obj_list = [ScopeObject('foo'), ScopeObject('bar'), ScopeObject('baz')] + self.assertEqual(list_to_scope(obj_list), expected) + + set_list = set(string_list) + set_scope = list_to_scope(set_list) + assert len(set_scope.split(' ')) == 3 + for x in string_list: + assert x in set_scope + + self.assertRaises(ValueError, list_to_scope, object()) + + def test_scope_to_list(self): + expected = ['foo', 'bar', 'baz'] + + string_scopes = 'foo bar baz ' + self.assertEqual(scope_to_list(string_scopes), expected) + + string_list_scopes = ['foo', 'bar', 'baz'] + self.assertEqual(scope_to_list(string_list_scopes), expected) + + tuple_list_scopes = ('foo', 'bar', 'baz') + self.assertEqual(scope_to_list(tuple_list_scopes), expected) + + obj_list_scopes = [ScopeObject('foo'), ScopeObject('bar'), ScopeObject('baz')] + self.assertEqual(scope_to_list(obj_list_scopes), expected) + + set_list_scopes = set(string_list_scopes) + set_list = scope_to_list(set_list_scopes) + self.assertEqual(sorted(set_list), sorted(string_list_scopes)) diff --git a/contrib/python/oauthlib/tests/oauth2/rfc8628/__init__.py b/contrib/python/oauthlib/tests/oauth2/rfc8628/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc8628/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth2/rfc8628/clients/__init__.py b/contrib/python/oauthlib/tests/oauth2/rfc8628/clients/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc8628/clients/__init__.py diff --git a/contrib/python/oauthlib/tests/oauth2/rfc8628/clients/test_device.py b/contrib/python/oauthlib/tests/oauth2/rfc8628/clients/test_device.py new file mode 100644 index 0000000000..725dea2a92 --- /dev/null +++ b/contrib/python/oauthlib/tests/oauth2/rfc8628/clients/test_device.py @@ -0,0 +1,63 @@ +import os +from unittest.mock import patch + +from oauthlib import signals +from oauthlib.oauth2 import DeviceClient + +from tests.unittest import TestCase + + +class DeviceClientTest(TestCase): + + client_id = "someclientid" + kwargs = { + "some": "providers", + "require": "extra arguments" + } + + client_secret = "asecret" + + device_code = "somedevicecode" + + scope = ["profile", "email"] + + body = "not=empty" + + body_up = "not=empty&grant_type=urn:ietf:params:oauth:grant-type:device_code" + body_code = body_up + "&device_code=somedevicecode" + body_kwargs = body_code + "&some=providers&require=extra+arguments" + + uri = "https://example.com/path?query=world" + uri_id = uri + "&client_id=" + client_id + uri_grant = uri_id + "&grant_type=urn:ietf:params:oauth:grant-type:device_code" + uri_secret = uri_grant + "&client_secret=asecret" + uri_scope = uri_secret + "&scope=profile+email" + + def test_request_body(self): + client = DeviceClient(self.client_id) + + # Basic, no extra arguments + body = client.prepare_request_body(self.device_code, body=self.body) + self.assertFormBodyEqual(body, self.body_code) + + rclient = DeviceClient(self.client_id) + body = rclient.prepare_request_body(self.device_code, body=self.body) + self.assertFormBodyEqual(body, self.body_code) + + # With extra parameters + body = client.prepare_request_body( + self.device_code, body=self.body, **self.kwargs) + self.assertFormBodyEqual(body, self.body_kwargs) + + def test_request_uri(self): + client = DeviceClient(self.client_id) + + uri = client.prepare_request_uri(self.uri) + self.assertURLEqual(uri, self.uri_grant) + + client = DeviceClient(self.client_id, client_secret=self.client_secret) + uri = client.prepare_request_uri(self.uri) + self.assertURLEqual(uri, self.uri_secret) + + uri = client.prepare_request_uri(self.uri, scope=self.scope) + self.assertURLEqual(uri, self.uri_scope) diff --git a/contrib/python/oauthlib/tests/openid/__init__.py b/contrib/python/oauthlib/tests/openid/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/__init__.py diff --git a/contrib/python/oauthlib/tests/openid/connect/__init__.py b/contrib/python/oauthlib/tests/openid/connect/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/__init__.py diff --git a/contrib/python/oauthlib/tests/openid/connect/core/__init__.py b/contrib/python/oauthlib/tests/openid/connect/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/__init__.py diff --git a/contrib/python/oauthlib/tests/openid/connect/core/endpoints/__init__.py b/contrib/python/oauthlib/tests/openid/connect/core/endpoints/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/endpoints/__init__.py diff --git a/contrib/python/oauthlib/tests/openid/connect/core/endpoints/test_claims_handling.py b/contrib/python/oauthlib/tests/openid/connect/core/endpoints/test_claims_handling.py new file mode 100644 index 0000000000..301ed1aa44 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/endpoints/test_claims_handling.py @@ -0,0 +1,107 @@ +"""Ensure OpenID Connect Authorization Request 'claims' are preserved across authorization. + +The claims parameter is an optional query param for the Authorization Request endpoint + but if it is provided and is valid it needs to be deserialized (from urlencoded JSON) + and persisted with the authorization code itself, then in the subsequent Access Token + request the claims should be transferred (via the oauthlib request) to be persisted + with the Access Token when it is created. +""" +from unittest import mock + +from oauthlib.openid import RequestValidator +from oauthlib.openid.connect.core.endpoints.pre_configured import Server + +from __tests__.oauth2.rfc6749.endpoints.test_utils import get_query_credentials +from tests.unittest import TestCase + + +class TestClaimsHandling(TestCase): + + DEFAULT_REDIRECT_URI = 'http://i.b./path' + + def set_scopes(self, scopes): + def set_request_scopes(client_id, code, client, request): + request.scopes = scopes + return True + return set_request_scopes + + def set_user(self, request): + request.user = 'foo' + request.client_id = 'bar' + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def save_claims_with_code(self, client_id, code, request, *args, **kwargs): + # a real validator would save the claims with the code during save_authorization_code() + self.claims_from_auth_code_request = request.claims + self.scopes = request.scopes.split() + + def retrieve_claims_saved_with_code(self, client_id, code, client, request, *args, **kwargs): + request.claims = self.claims_from_auth_code_request + request.scopes = self.scopes + + return True + + def save_claims_with_bearer_token(self, token, request, *args, **kwargs): + # a real validator would save the claims with the access token during save_bearer_token() + self.claims_saved_with_bearer_token = request.claims + + def setUp(self): + self.validator = mock.MagicMock(spec=RequestValidator) + self.validator.get_code_challenge.return_value = None + self.validator.get_default_redirect_uri.return_value = TestClaimsHandling.DEFAULT_REDIRECT_URI + self.validator.authenticate_client.side_effect = self.set_client + + self.validator.save_authorization_code.side_effect = self.save_claims_with_code + self.validator.validate_code.side_effect = self.retrieve_claims_saved_with_code + self.validator.save_token.side_effect = self.save_claims_with_bearer_token + + self.server = Server(self.validator) + + def test_claims_stored_on_code_creation(self): + + claims = { + "id_token": { + "claim_1": None, + "claim_2": { + "essential": True + } + }, + "userinfo": { + "claim_3": { + "essential": True + }, + "claim_4": None + } + } + + claims_urlquoted = '%7B%22id_token%22%3A%20%7B%22claim_2%22%3A%20%7B%22essential%22%3A%20true%7D%2C%20%22claim_1%22%3A%20null%7D%2C%20%22userinfo%22%3A%20%7B%22claim_4%22%3A%20null%2C%20%22claim_3%22%3A%20%7B%22essential%22%3A%20true%7D%7D%7D' + uri = 'http://example.com/path?client_id=abc&scope=openid+test_scope&response_type=code&claims=%s' + + h, b, s = self.server.create_authorization_response(uri % claims_urlquoted, scopes='openid test_scope') + + self.assertDictEqual(self.claims_from_auth_code_request, claims) + + code = get_query_credentials(h['Location'])['code'][0] + token_uri = 'http://example.com/path' + _, body, _ = self.server.create_token_response( + token_uri, + body='client_id=me&redirect_uri=http://back.to/me&grant_type=authorization_code&code=%s' % code + ) + + self.assertDictEqual(self.claims_saved_with_bearer_token, claims) + + def test_invalid_claims(self): + uri = 'http://example.com/path?client_id=abc&scope=openid+test_scope&response_type=code&claims=this-is-not-json' + + h, b, s = self.server.create_authorization_response(uri, scopes='openid test_scope') + error = get_query_credentials(h['Location'])['error'][0] + error_desc = get_query_credentials(h['Location'])['error_description'][0] + self.assertEqual(error, 'invalid_request') + self.assertEqual(error_desc, "Malformed claims parameter") diff --git a/contrib/python/oauthlib/tests/openid/connect/core/endpoints/test_openid_connect_params_handling.py b/contrib/python/oauthlib/tests/openid/connect/core/endpoints/test_openid_connect_params_handling.py new file mode 100644 index 0000000000..c55136fbf1 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/endpoints/test_openid_connect_params_handling.py @@ -0,0 +1,78 @@ +from unittest import mock +from urllib.parse import urlencode + +from oauthlib.oauth2 import InvalidRequestError +from oauthlib.oauth2.rfc6749.endpoints.authorization import ( + AuthorizationEndpoint, +) +from oauthlib.oauth2.rfc6749.tokens import BearerToken +from oauthlib.openid.connect.core.grant_types import AuthorizationCodeGrant + +from tests.unittest import TestCase + + +class OpenIDConnectEndpointTest(TestCase): + + def setUp(self): + self.mock_validator = mock.MagicMock() + self.mock_validator.authenticate_client.side_effect = self.set_client + grant = AuthorizationCodeGrant(request_validator=self.mock_validator) + bearer = BearerToken(self.mock_validator) + self.endpoint = AuthorizationEndpoint(grant, bearer, + response_types={'code': grant}) + params = { + 'prompt': 'consent', + 'display': 'touch', + 'nonce': 'abcd', + 'state': 'abc', + 'redirect_uri': 'https://a.b/cb', + 'response_type': 'code', + 'client_id': 'abcdef', + 'scope': 'hello openid' + } + self.url = 'http://a.b/path?' + urlencode(params) + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + @mock.patch('oauthlib.common.generate_token') + def test_authorization_endpoint_handles_prompt(self, generate_token): + generate_token.return_value = "MOCK_CODE" + # In the GET view: + scopes, creds = self.endpoint.validate_authorization_request(self.url) + # In the POST view: + creds['scopes'] = scopes + h, b, s = self.endpoint.create_authorization_response(self.url, + credentials=creds) + expected = 'https://a.b/cb?state=abc&code=MOCK_CODE' + self.assertURLEqual(h['Location'], expected) + self.assertIsNone(b) + self.assertEqual(s, 302) + + def test_prompt_none_exclusiveness(self): + """ + Test that prompt=none can't be used with another prompt value. + """ + params = { + 'prompt': 'none consent', + 'state': 'abc', + 'redirect_uri': 'https://a.b/cb', + 'response_type': 'code', + 'client_id': 'abcdef', + 'scope': 'hello openid' + } + url = 'http://a.b/path?' + urlencode(params) + with self.assertRaises(InvalidRequestError): + self.endpoint.validate_authorization_request(url) + + def test_oidc_params_preservation(self): + """ + Test that the nonce parameter is passed through. + """ + scopes, creds = self.endpoint.validate_authorization_request(self.url) + + self.assertEqual(creds['prompt'], {'consent'}) + self.assertEqual(creds['nonce'], 'abcd') + self.assertEqual(creds['display'], 'touch') diff --git a/contrib/python/oauthlib/tests/openid/connect/core/endpoints/test_userinfo_endpoint.py b/contrib/python/oauthlib/tests/openid/connect/core/endpoints/test_userinfo_endpoint.py new file mode 100644 index 0000000000..4833485195 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/endpoints/test_userinfo_endpoint.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +import json +from unittest import mock + +from oauthlib.oauth2.rfc6749 import errors +from oauthlib.openid import RequestValidator, UserInfoEndpoint + +from tests.unittest import TestCase + + +def set_scopes_valid(token, scopes, request): + request.scopes = ["openid", "bar"] + return True + + +class UserInfoEndpointTest(TestCase): + def setUp(self): + self.claims = { + "sub": "john", + "fruit": "banana" + } + # Can't use MagicMock/wraps below. + # Triggers error when endpoint copies to self.bearer.request_validator + self.validator = RequestValidator() + self.validator.validate_bearer_token = mock.Mock() + self.validator.validate_bearer_token.side_effect = set_scopes_valid + self.validator.get_userinfo_claims = mock.Mock() + self.validator.get_userinfo_claims.return_value = self.claims + self.endpoint = UserInfoEndpoint(self.validator) + + self.uri = 'should_not_matter' + self.headers = { + 'Authorization': 'Bearer eyJxx' + } + + def test_userinfo_no_auth(self): + self.endpoint.create_userinfo_response(self.uri) + + def test_userinfo_wrong_auth(self): + self.headers['Authorization'] = 'Basic foifoifoi' + self.endpoint.create_userinfo_response(self.uri, headers=self.headers) + + def test_userinfo_token_expired(self): + self.validator.validate_bearer_token.return_value = False + self.endpoint.create_userinfo_response(self.uri, headers=self.headers) + + def test_userinfo_token_no_openid_scope(self): + def set_scopes_invalid(token, scopes, request): + request.scopes = ["foo", "bar"] + return True + self.validator.validate_bearer_token.side_effect = set_scopes_invalid + with self.assertRaises(errors.InsufficientScopeError) as context: + self.endpoint.create_userinfo_response(self.uri) + + def test_userinfo_json_response(self): + h, b, s = self.endpoint.create_userinfo_response(self.uri) + self.assertEqual(s, 200) + body_json = json.loads(b) + self.assertEqual(self.claims, body_json) + self.assertEqual("application/json", h['Content-Type']) + + def test_userinfo_jwt_response(self): + self.validator.get_userinfo_claims.return_value = "eyJzzzzz" + h, b, s = self.endpoint.create_userinfo_response(self.uri) + self.assertEqual(s, 200) + self.assertEqual(b, "eyJzzzzz") + self.assertEqual("application/jwt", h['Content-Type']) diff --git a/contrib/python/oauthlib/tests/openid/connect/core/grant_types/__init__.py b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/__init__.py diff --git a/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_authorization_code.py b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_authorization_code.py new file mode 100644 index 0000000000..49b03a7f7d --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_authorization_code.py @@ -0,0 +1,200 @@ +# -*- coding: utf-8 -*- +import json +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749.errors import ( + ConsentRequired, InvalidRequestError, LoginRequired, +) +from oauthlib.oauth2.rfc6749.tokens import BearerToken +from oauthlib.openid.connect.core.grant_types.authorization_code import ( + AuthorizationCodeGrant, +) + +from __tests__.oauth2.rfc6749.grant_types.test_authorization_code import ( + AuthorizationCodeGrantTest, +) +from tests.unittest import TestCase + + +def get_id_token_mock(token, token_handler, request): + return "MOCKED_TOKEN" + + +class OpenIDAuthCodeInterferenceTest(AuthorizationCodeGrantTest): + """Test that OpenID don't interfere with normal OAuth 2 flows.""" + + def setUp(self): + super().setUp() + self.auth = AuthorizationCodeGrant(request_validator=self.mock_validator) + + +class OpenIDAuthCodeTest(TestCase): + + def setUp(self): + self.request = Request('http://a.b/path') + self.request.scopes = ('hello', 'openid') + self.request.expires_in = 1800 + self.request.client_id = 'abcdef' + self.request.code = '1234' + self.request.response_type = 'code' + self.request.grant_type = 'authorization_code' + self.request.redirect_uri = 'https://a.b/cb' + self.request.state = 'abc' + self.request.nonce = None + + self.mock_validator = mock.MagicMock() + self.mock_validator.authenticate_client.side_effect = self.set_client + self.mock_validator.get_code_challenge.return_value = None + self.mock_validator.get_id_token.side_effect = get_id_token_mock + self.auth = AuthorizationCodeGrant(request_validator=self.mock_validator) + + self.url_query = 'https://a.b/cb?code=abc&state=abc' + self.url_fragment = 'https://a.b/cb#code=abc&state=abc' + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + @mock.patch('oauthlib.common.generate_token') + def test_authorization(self, generate_token): + + scope, info = self.auth.validate_authorization_request(self.request) + + generate_token.return_value = 'abc' + bearer = BearerToken(self.mock_validator) + self.request.response_mode = 'query' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], self.url_query) + self.assertIsNone(b) + self.assertEqual(s, 302) + + self.request.response_mode = 'fragment' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], self.url_fragment, parse_fragment=True) + self.assertIsNone(b) + self.assertEqual(s, 302) + + @mock.patch('oauthlib.common.generate_token') + def test_no_prompt_authorization(self, generate_token): + generate_token.return_value = 'abc' + self.request.prompt = 'none' + + bearer = BearerToken(self.mock_validator) + + self.request.response_mode = 'query' + self.request.id_token_hint = 'me@email.com' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], self.url_query) + self.assertIsNone(b) + self.assertEqual(s, 302) + + # Test alternative response modes + self.request.response_mode = 'fragment' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], self.url_fragment, parse_fragment=True) + + # Ensure silent authentication and authorization is done + self.mock_validator.validate_silent_login.return_value = False + self.mock_validator.validate_silent_authorization.return_value = True + self.assertRaises(LoginRequired, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=login_required', h['Location']) + + self.mock_validator.validate_silent_login.return_value = True + self.mock_validator.validate_silent_authorization.return_value = False + self.assertRaises(ConsentRequired, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=consent_required', h['Location']) + + # ID token hint must match logged in user + self.mock_validator.validate_silent_authorization.return_value = True + self.mock_validator.validate_user_match.return_value = False + self.assertRaises(LoginRequired, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=login_required', h['Location']) + + def test_none_multi_prompt(self): + bearer = BearerToken(self.mock_validator) + + self.request.prompt = 'none login' + self.assertRaises(InvalidRequestError, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + + self.request.prompt = 'none consent' + self.assertRaises(InvalidRequestError, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + + self.request.prompt = 'none select_account' + self.assertRaises(InvalidRequestError, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + + self.request.prompt = 'consent none login' + self.assertRaises(InvalidRequestError, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + + def set_scopes(self, client_id, code, client, request): + request.scopes = self.request.scopes + request.user = 'bob' + return True + + def test_create_token_response(self): + self.request.response_type = None + self.mock_validator.validate_code.side_effect = self.set_scopes + + bearer = BearerToken(self.mock_validator) + + h, token, s = self.auth.create_token_response(self.request, bearer) + token = json.loads(token) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('refresh_token', token) + self.assertIn('expires_in', token) + self.assertIn('scope', token) + self.assertIn('id_token', token) + self.assertIn('openid', token['scope']) + + self.mock_validator.reset_mock() + + self.request.scopes = ('hello', 'world') + h, token, s = self.auth.create_token_response(self.request, bearer) + token = json.loads(token) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('refresh_token', token) + self.assertIn('expires_in', token) + self.assertIn('scope', token) + self.assertNotIn('id_token', token) + self.assertNotIn('openid', token['scope']) + + @mock.patch('oauthlib.common.generate_token') + def test_optional_nonce(self, generate_token): + generate_token.return_value = 'abc' + self.request.nonce = 'xyz' + scope, info = self.auth.validate_authorization_request(self.request) + + bearer = BearerToken(self.mock_validator) + self.request.response_mode = 'query' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], self.url_query) + self.assertIsNone(b) + self.assertEqual(s, 302) diff --git a/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_base.py b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_base.py new file mode 100644 index 0000000000..a88834b807 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_base.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +import time +from unittest import mock + +from oauthlib.common import Request +from oauthlib.openid.connect.core.grant_types.base import GrantTypeBase + +from tests.unittest import TestCase + + +class GrantBase(GrantTypeBase): + """Class to test GrantTypeBase""" + def __init__(self, request_validator=None, **kwargs): + self.request_validator = request_validator + + +class IDTokenTest(TestCase): + + def setUp(self): + self.request = Request('http://a.b/path') + self.request.scopes = ('hello', 'openid') + self.request.expires_in = 1800 + self.request.client_id = 'abcdef' + self.request.code = '1234' + self.request.response_type = 'id_token' + self.request.grant_type = 'authorization_code' + self.request.redirect_uri = 'https://a.b/cb' + self.request.state = 'abc' + self.request.nonce = None + + self.mock_validator = mock.MagicMock() + self.mock_validator.get_id_token.return_value = None + self.mock_validator.finalize_id_token.return_value = "eyJ.body.signature" + self.token = {} + + self.grant = GrantBase(request_validator=self.mock_validator) + + self.url_query = 'https://a.b/cb?code=abc&state=abc' + self.url_fragment = 'https://a.b/cb#code=abc&state=abc' + + def test_id_token_hash(self): + self.assertEqual(self.grant.id_token_hash( + "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk", + ), "LDktKdoQak3Pk0cnXxCltA", "hash differs from RFC") + + def test_get_id_token_no_openid(self): + self.request.scopes = ('hello') + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertNotIn("id_token", token) + + self.request.scopes = None + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertNotIn("id_token", token) + + self.request.scopes = () + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertNotIn("id_token", token) + + def test_get_id_token(self): + self.mock_validator.get_id_token.return_value = "toto" + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "toto") + + def test_finalize_id_token(self): + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['aud'], 'abcdef') + self.assertGreaterEqual(int(time.time()), id_token['iat']) + + def test_finalize_id_token_with_nonce(self): + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request, "my_nonce") + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['nonce'], 'my_nonce') + + def test_finalize_id_token_with_at_hash(self): + self.token["access_token"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk" + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['at_hash'], 'LDktKdoQak3Pk0cnXxCltA') + + def test_finalize_id_token_with_c_hash(self): + self.token["code"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk" + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['c_hash'], 'LDktKdoQak3Pk0cnXxCltA') + + def test_finalize_id_token_with_c_and_at_hash(self): + self.token["code"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk" + self.token["access_token"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk" + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['at_hash'], 'LDktKdoQak3Pk0cnXxCltA') + self.assertEqual(id_token['c_hash'], 'LDktKdoQak3Pk0cnXxCltA') diff --git a/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_dispatchers.py b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_dispatchers.py new file mode 100644 index 0000000000..ccbada490d --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_dispatchers.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749.grant_types import ( + AuthorizationCodeGrant as OAuth2AuthorizationCodeGrant, + ImplicitGrant as OAuth2ImplicitGrant, +) +from oauthlib.openid.connect.core.grant_types.authorization_code import ( + AuthorizationCodeGrant, +) +from oauthlib.openid.connect.core.grant_types.dispatchers import ( + AuthorizationTokenGrantDispatcher, ImplicitTokenGrantDispatcher, +) +from oauthlib.openid.connect.core.grant_types.implicit import ImplicitGrant + +from tests.unittest import TestCase + + +class ImplicitTokenGrantDispatcherTest(TestCase): + def setUp(self): + self.request = Request('http://a.b/path') + request_validator = mock.MagicMock() + implicit_grant = OAuth2ImplicitGrant(request_validator) + openid_connect_implicit = ImplicitGrant(request_validator) + + self.dispatcher = ImplicitTokenGrantDispatcher( + default_grant=implicit_grant, + oidc_grant=openid_connect_implicit + ) + + def test_create_authorization_response_openid(self): + self.request.scopes = ('hello', 'openid') + self.request.response_type = 'id_token' + handler = self.dispatcher._handler_for_request(self.request) + self.assertIsInstance(handler, ImplicitGrant) + + def test_validate_authorization_request_openid(self): + self.request.scopes = ('hello', 'openid') + self.request.response_type = 'id_token' + handler = self.dispatcher._handler_for_request(self.request) + self.assertIsInstance(handler, ImplicitGrant) + + def test_create_authorization_response_oauth(self): + self.request.scopes = ('hello', 'world') + handler = self.dispatcher._handler_for_request(self.request) + self.assertIsInstance(handler, OAuth2ImplicitGrant) + + def test_validate_authorization_request_oauth(self): + self.request.scopes = ('hello', 'world') + handler = self.dispatcher._handler_for_request(self.request) + self.assertIsInstance(handler, OAuth2ImplicitGrant) + + +class DispatcherTest(TestCase): + def setUp(self): + self.request = Request('http://a.b/path') + self.request.decoded_body = ( + ("client_id", "me"), + ("code", "code"), + ("redirect_url", "https://a.b/cb"), + ) + + self.request_validator = mock.MagicMock() + self.auth_grant = OAuth2AuthorizationCodeGrant(self.request_validator) + self.openid_connect_auth = AuthorizationCodeGrant(self.request_validator) + + +class AuthTokenGrantDispatcherOpenIdTest(DispatcherTest): + + def setUp(self): + super().setUp() + self.request_validator.get_authorization_code_scopes.return_value = ('hello', 'openid') + self.dispatcher = AuthorizationTokenGrantDispatcher( + self.request_validator, + default_grant=self.auth_grant, + oidc_grant=self.openid_connect_auth + ) + + def test_create_token_response_openid(self): + handler = self.dispatcher._handler_for_request(self.request) + self.assertIsInstance(handler, AuthorizationCodeGrant) + self.assertTrue(self.dispatcher.request_validator.get_authorization_code_scopes.called) + + +class AuthTokenGrantDispatcherOpenIdWithoutCodeTest(DispatcherTest): + + def setUp(self): + super().setUp() + self.request.decoded_body = ( + ("client_id", "me"), + ("code", ""), + ("redirect_url", "https://a.b/cb"), + ) + self.request_validator.get_authorization_code_scopes.return_value = ('hello', 'openid') + self.dispatcher = AuthorizationTokenGrantDispatcher( + self.request_validator, + default_grant=self.auth_grant, + oidc_grant=self.openid_connect_auth + ) + + def test_create_token_response_openid_without_code(self): + handler = self.dispatcher._handler_for_request(self.request) + self.assertIsInstance(handler, OAuth2AuthorizationCodeGrant) + self.assertFalse(self.dispatcher.request_validator.get_authorization_code_scopes.called) + + +class AuthTokenGrantDispatcherOAuthTest(DispatcherTest): + + def setUp(self): + super().setUp() + self.request_validator.get_authorization_code_scopes.return_value = ('hello', 'world') + self.dispatcher = AuthorizationTokenGrantDispatcher( + self.request_validator, + default_grant=self.auth_grant, + oidc_grant=self.openid_connect_auth + ) + + def test_create_token_response_oauth(self): + handler = self.dispatcher._handler_for_request(self.request) + self.assertIsInstance(handler, OAuth2AuthorizationCodeGrant) + self.assertTrue(self.dispatcher.request_validator.get_authorization_code_scopes.called) diff --git a/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_hybrid.py b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_hybrid.py new file mode 100644 index 0000000000..111c8c5c4b --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_hybrid.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +from unittest import mock + +from oauthlib.oauth2.rfc6749 import errors +from oauthlib.oauth2.rfc6749.tokens import BearerToken +from oauthlib.openid.connect.core.grant_types.hybrid import HybridGrant + +from __tests__.oauth2.rfc6749.grant_types.test_authorization_code import ( + AuthorizationCodeGrantTest, +) + +from .test_authorization_code import OpenIDAuthCodeTest + + +class OpenIDHybridInterferenceTest(AuthorizationCodeGrantTest): + """Test that OpenID don't interfere with normal OAuth 2 flows.""" + + def setUp(self): + super().setUp() + self.auth = HybridGrant(request_validator=self.mock_validator) + + +class OpenIDHybridCodeTokenTest(OpenIDAuthCodeTest): + + def setUp(self): + super().setUp() + self.request.response_type = 'code token' + self.request.nonce = None + self.auth = HybridGrant(request_validator=self.mock_validator) + self.url_query = 'https://a.b/cb?code=abc&state=abc&token_type=Bearer&expires_in=3600&scope=hello+openid&access_token=abc' + self.url_fragment = 'https://a.b/cb#code=abc&state=abc&token_type=Bearer&expires_in=3600&scope=hello+openid&access_token=abc' + + @mock.patch('oauthlib.common.generate_token') + def test_optional_nonce(self, generate_token): + generate_token.return_value = 'abc' + self.request.nonce = 'xyz' + scope, info = self.auth.validate_authorization_request(self.request) + + bearer = BearerToken(self.mock_validator) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], self.url_fragment, parse_fragment=True) + self.assertIsNone(b) + self.assertEqual(s, 302) + + +class OpenIDHybridCodeIdTokenTest(OpenIDAuthCodeTest): + + def setUp(self): + super().setUp() + self.mock_validator.get_code_challenge.return_value = None + self.request.response_type = 'code id_token' + self.request.nonce = 'zxc' + self.auth = HybridGrant(request_validator=self.mock_validator) + token = 'MOCKED_TOKEN' + self.url_query = 'https://a.b/cb?code=abc&state=abc&id_token=%s' % token + self.url_fragment = 'https://a.b/cb#code=abc&state=abc&id_token=%s' % token + + @mock.patch('oauthlib.common.generate_token') + def test_required_nonce(self, generate_token): + generate_token.return_value = 'abc' + self.request.nonce = None + self.assertRaises(errors.InvalidRequestError, self.auth.validate_authorization_request, self.request) + + bearer = BearerToken(self.mock_validator) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + self.assertIsNone(b) + self.assertEqual(s, 302) + + def test_id_token_contains_nonce(self): + token = {} + self.mock_validator.get_id_token.side_effect = None + self.mock_validator.get_id_token.return_value = None + token = self.auth.add_id_token(token, None, self.request) + assert self.mock_validator.finalize_id_token.call_count == 1 + claims = self.mock_validator.finalize_id_token.call_args[0][0] + assert "nonce" in claims + + +class OpenIDHybridCodeIdTokenTokenTest(OpenIDAuthCodeTest): + + def setUp(self): + super().setUp() + self.mock_validator.get_code_challenge.return_value = None + self.request.response_type = 'code id_token token' + self.request.nonce = 'xyz' + self.auth = HybridGrant(request_validator=self.mock_validator) + token = 'MOCKED_TOKEN' + self.url_query = 'https://a.b/cb?code=abc&state=abc&token_type=Bearer&expires_in=3600&scope=hello+openid&access_token=abc&id_token=%s' % token + self.url_fragment = 'https://a.b/cb#code=abc&state=abc&token_type=Bearer&expires_in=3600&scope=hello+openid&access_token=abc&id_token=%s' % token + + @mock.patch('oauthlib.common.generate_token') + def test_required_nonce(self, generate_token): + generate_token.return_value = 'abc' + self.request.nonce = None + self.assertRaises(errors.InvalidRequestError, self.auth.validate_authorization_request, self.request) + + bearer = BearerToken(self.mock_validator) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + self.assertIsNone(b) + self.assertEqual(s, 302) diff --git a/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_implicit.py b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_implicit.py new file mode 100644 index 0000000000..825093138c --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_implicit.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749 import errors +from oauthlib.oauth2.rfc6749.tokens import BearerToken +from oauthlib.openid.connect.core.grant_types.implicit import ImplicitGrant + +from __tests__.oauth2.rfc6749.grant_types.test_implicit import ImplicitGrantTest +from tests.unittest import TestCase + +from .test_authorization_code import get_id_token_mock + + +class OpenIDImplicitInterferenceTest(ImplicitGrantTest): + """Test that OpenID don't interfere with normal OAuth 2 flows.""" + + def setUp(self): + super().setUp() + self.auth = ImplicitGrant(request_validator=self.mock_validator) + + +class OpenIDImplicitTest(TestCase): + + def setUp(self): + self.request = Request('http://a.b/path') + self.request.scopes = ('hello', 'openid') + self.request.expires_in = 1800 + self.request.client_id = 'abcdef' + self.request.response_type = 'id_token token' + self.request.redirect_uri = 'https://a.b/cb' + self.request.state = 'abc' + self.request.nonce = 'xyz' + + self.mock_validator = mock.MagicMock() + self.mock_validator.get_id_token.side_effect = get_id_token_mock + self.auth = ImplicitGrant(request_validator=self.mock_validator) + + token = 'MOCKED_TOKEN' + self.url_query = 'https://a.b/cb?state=abc&token_type=Bearer&expires_in=3600&scope=hello+openid&access_token=abc&id_token=%s' % token + self.url_fragment = 'https://a.b/cb#state=abc&token_type=Bearer&expires_in=3600&scope=hello+openid&access_token=abc&id_token=%s' % token + + @mock.patch('oauthlib.common.generate_token') + def test_authorization(self, generate_token): + scope, info = self.auth.validate_authorization_request(self.request) + + generate_token.return_value = 'abc' + bearer = BearerToken(self.mock_validator) + + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], self.url_fragment, parse_fragment=True) + self.assertIsNone(b) + self.assertEqual(s, 302) + + self.request.response_type = 'id_token' + token = 'MOCKED_TOKEN' + url = 'https://a.b/cb#state=abc&id_token=%s' % token + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], url, parse_fragment=True) + self.assertIsNone(b) + self.assertEqual(s, 302) + + @mock.patch('oauthlib.common.generate_token') + def test_no_prompt_authorization(self, generate_token): + generate_token.return_value = 'abc' + self.request.prompt = 'none' + + bearer = BearerToken(self.mock_validator) + + self.request.response_mode = 'query' + self.request.id_token_hint = 'me@email.com' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], self.url_query) + self.assertIsNone(b) + self.assertEqual(s, 302) + + # Test alternative response modes + self.request.response_mode = 'fragment' + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertURLEqual(h['Location'], self.url_fragment, parse_fragment=True) + + # Ensure silent authentication and authorization is done + self.mock_validator.validate_silent_login.return_value = False + self.mock_validator.validate_silent_authorization.return_value = True + self.assertRaises(errors.LoginRequired, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=login_required', h['Location']) + + self.mock_validator.validate_silent_login.return_value = True + self.mock_validator.validate_silent_authorization.return_value = False + self.assertRaises(errors.ConsentRequired, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=consent_required', h['Location']) + + # ID token hint must match logged in user + self.mock_validator.validate_silent_authorization.return_value = True + self.mock_validator.validate_user_match.return_value = False + self.assertRaises(errors.LoginRequired, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=login_required', h['Location']) + + def test_none_multi_prompt(self): + bearer = BearerToken(self.mock_validator) + + self.request.prompt = 'none login' + self.assertRaises(errors.InvalidRequestError, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + + self.request.prompt = 'none consent' + self.assertRaises(errors.InvalidRequestError, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + + self.request.prompt = 'none select_account' + self.assertRaises(errors.InvalidRequestError, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + + self.request.prompt = 'consent none login' + self.assertRaises(errors.InvalidRequestError, + self.auth.validate_authorization_request, + self.request) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + + @mock.patch('oauthlib.common.generate_token') + def test_required_nonce(self, generate_token): + generate_token.return_value = 'abc' + self.request.nonce = None + self.assertRaises(errors.InvalidRequestError, self.auth.validate_authorization_request, self.request) + + bearer = BearerToken(self.mock_validator) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + self.assertIsNone(b) + self.assertEqual(s, 302) + + +class OpenIDImplicitNoAccessTokenTest(OpenIDImplicitTest): + def setUp(self): + super().setUp() + self.request.response_type = 'id_token' + token = 'MOCKED_TOKEN' + self.url_query = 'https://a.b/cb?state=abc&id_token=%s' % token + self.url_fragment = 'https://a.b/cb#state=abc&id_token=%s' % token + + @mock.patch('oauthlib.common.generate_token') + def test_required_nonce(self, generate_token): + generate_token.return_value = 'abc' + self.request.nonce = None + self.assertRaises(errors.InvalidRequestError, self.auth.validate_authorization_request, self.request) + + bearer = BearerToken(self.mock_validator) + h, b, s = self.auth.create_authorization_response(self.request, bearer) + self.assertIn('error=invalid_request', h['Location']) + self.assertIsNone(b) + self.assertEqual(s, 302) diff --git a/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_refresh_token.py b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_refresh_token.py new file mode 100644 index 0000000000..2e363fef1a --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/grant_types/test_refresh_token.py @@ -0,0 +1,105 @@ +import json +from unittest import mock + +from oauthlib.common import Request +from oauthlib.oauth2.rfc6749.tokens import BearerToken +from oauthlib.openid.connect.core.grant_types import RefreshTokenGrant + +from __tests__.oauth2.rfc6749.grant_types.test_refresh_token import ( + RefreshTokenGrantTest, +) +from tests.unittest import TestCase + + +def get_id_token_mock(token, token_handler, request): + return "MOCKED_TOKEN" + + +class OpenIDRefreshTokenInterferenceTest(RefreshTokenGrantTest): + """Test that OpenID don't interfere with normal OAuth 2 flows.""" + + def setUp(self): + super().setUp() + self.auth = RefreshTokenGrant(request_validator=self.mock_validator) + + +class OpenIDRefreshTokenTest(TestCase): + + def setUp(self): + self.request = Request('http://a.b/path') + self.request.grant_type = 'refresh_token' + self.request.refresh_token = 'lsdkfhj230' + self.request.scope = ('hello', 'openid') + self.mock_validator = mock.MagicMock() + + self.mock_validator = mock.MagicMock() + self.mock_validator.authenticate_client.side_effect = self.set_client + self.mock_validator.get_id_token.side_effect = get_id_token_mock + self.auth = RefreshTokenGrant(request_validator=self.mock_validator) + + def set_client(self, request): + request.client = mock.MagicMock() + request.client.client_id = 'mocked' + return True + + def test_refresh_id_token(self): + self.mock_validator.get_original_scopes.return_value = [ + 'hello', 'openid' + ] + bearer = BearerToken(self.mock_validator) + + headers, body, status_code = self.auth.create_token_response( + self.request, bearer + ) + + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('refresh_token', token) + self.assertIn('id_token', token) + self.assertIn('token_type', token) + self.assertIn('expires_in', token) + self.assertEqual(token['scope'], 'hello openid') + self.mock_validator.refresh_id_token.assert_called_once_with( + self.request + ) + + def test_refresh_id_token_false(self): + self.mock_validator.refresh_id_token.return_value = False + self.mock_validator.get_original_scopes.return_value = [ + 'hello', 'openid' + ] + bearer = BearerToken(self.mock_validator) + + headers, body, status_code = self.auth.create_token_response( + self.request, bearer + ) + + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('refresh_token', token) + self.assertIn('token_type', token) + self.assertIn('expires_in', token) + self.assertEqual(token['scope'], 'hello openid') + self.assertNotIn('id_token', token) + self.mock_validator.refresh_id_token.assert_called_once_with( + self.request + ) + + def test_refresh_token_without_openid_scope(self): + self.request.scope = "hello" + bearer = BearerToken(self.mock_validator) + + headers, body, status_code = self.auth.create_token_response( + self.request, bearer + ) + + token = json.loads(body) + self.assertEqual(self.mock_validator.save_token.call_count, 1) + self.assertIn('access_token', token) + self.assertIn('refresh_token', token) + self.assertIn('token_type', token) + self.assertIn('expires_in', token) + self.assertNotIn('id_token', token) + self.assertEqual(token['scope'], 'hello') diff --git a/contrib/python/oauthlib/tests/openid/connect/core/test_request_validator.py b/contrib/python/oauthlib/tests/openid/connect/core/test_request_validator.py new file mode 100644 index 0000000000..6a800d41ca --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/test_request_validator.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +from oauthlib.openid import RequestValidator + +from tests.unittest import TestCase + + +class RequestValidatorTest(TestCase): + + def test_method_contracts(self): + v = RequestValidator() + self.assertRaises( + NotImplementedError, + v.get_authorization_code_scopes, + 'client_id', 'code', 'redirect_uri', 'request' + ) + self.assertRaises( + NotImplementedError, + v.get_jwt_bearer_token, + 'token', 'token_handler', 'request' + ) + self.assertRaises( + NotImplementedError, + v.finalize_id_token, + 'id_token', 'token', 'token_handler', 'request' + ) + self.assertRaises( + NotImplementedError, + v.validate_jwt_bearer_token, + 'token', 'scopes', 'request' + ) + self.assertRaises( + NotImplementedError, + v.validate_id_token, + 'token', 'scopes', 'request' + ) + self.assertRaises( + NotImplementedError, + v.validate_silent_authorization, + 'request' + ) + self.assertRaises( + NotImplementedError, + v.validate_silent_login, + 'request' + ) + self.assertRaises( + NotImplementedError, + v.validate_user_match, + 'id_token_hint', 'scopes', 'claims', 'request' + ) diff --git a/contrib/python/oauthlib/tests/openid/connect/core/test_server.py b/contrib/python/oauthlib/tests/openid/connect/core/test_server.py new file mode 100644 index 0000000000..47f0ecc842 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/test_server.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +import json +from unittest import mock + +from oauthlib.oauth2.rfc6749 import errors +from oauthlib.oauth2.rfc6749.endpoints.authorization import ( + AuthorizationEndpoint, +) +from oauthlib.oauth2.rfc6749.endpoints.token import TokenEndpoint +from oauthlib.oauth2.rfc6749.tokens import BearerToken +from oauthlib.openid.connect.core.grant_types.authorization_code import ( + AuthorizationCodeGrant, +) +from oauthlib.openid.connect.core.grant_types.hybrid import HybridGrant +from oauthlib.openid.connect.core.grant_types.implicit import ImplicitGrant + +from tests.unittest import TestCase + + +class AuthorizationEndpointTest(TestCase): + + def setUp(self): + self.mock_validator = mock.MagicMock() + self.mock_validator.get_code_challenge.return_value = None + self.addCleanup(setattr, self, 'mock_validator', mock.MagicMock()) + auth_code = AuthorizationCodeGrant(request_validator=self.mock_validator) + auth_code.save_authorization_code = mock.MagicMock() + implicit = ImplicitGrant( + request_validator=self.mock_validator) + implicit.save_token = mock.MagicMock() + hybrid = HybridGrant(self.mock_validator) + + response_types = { + 'code': auth_code, + 'token': implicit, + 'id_token': implicit, + 'id_token token': implicit, + 'code token': hybrid, + 'code id_token': hybrid, + 'code token id_token': hybrid, + 'none': auth_code + } + self.expires_in = 1800 + token = BearerToken( + self.mock_validator, + expires_in=self.expires_in + ) + self.endpoint = AuthorizationEndpoint( + default_response_type='code', + default_token_type=token, + response_types=response_types + ) + + # TODO: Add hybrid grant test + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_authorization_grant(self): + uri = 'http://i.b/l?response_type=code&client_id=me&scope=all+of+them&state=xyz' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me?code=abc&state=xyz') + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_implicit_grant(self): + uri = 'http://i.b/l?response_type=token&client_id=me&scope=all+of+them&state=xyz' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me#access_token=abc&expires_in=' + str(self.expires_in) + '&token_type=Bearer&state=xyz&scope=all+of+them', parse_fragment=True) + + def test_none_grant(self): + uri = 'http://i.b/l?response_type=none&client_id=me&scope=all+of+them&state=xyz' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me?state=xyz', parse_fragment=True) + self.assertIsNone(body) + self.assertEqual(status_code, 302) + + # and without the state parameter + uri = 'http://i.b/l?response_type=none&client_id=me&scope=all+of+them' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me', parse_fragment=True) + self.assertIsNone(body) + self.assertEqual(status_code, 302) + + def test_missing_type(self): + uri = 'http://i.b/l?client_id=me&scope=all+of+them' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + self.mock_validator.validate_request = mock.MagicMock( + side_effect=errors.InvalidRequestError()) + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me?error=invalid_request&error_description=Missing+response_type+parameter.') + + def test_invalid_type(self): + uri = 'http://i.b/l?response_type=invalid&client_id=me&scope=all+of+them' + uri += '&redirect_uri=http%3A%2F%2Fback.to%2Fme' + self.mock_validator.validate_request = mock.MagicMock( + side_effect=errors.UnsupportedResponseTypeError()) + headers, body, status_code = self.endpoint.create_authorization_response( + uri, scopes=['all', 'of', 'them']) + self.assertIn('Location', headers) + self.assertURLEqual(headers['Location'], 'http://back.to/me?error=unsupported_response_type') + + +class TokenEndpointTest(TestCase): + + def setUp(self): + def set_user(request): + request.user = mock.MagicMock() + request.client = mock.MagicMock() + request.client.client_id = 'mocked_client_id' + return True + + self.mock_validator = mock.MagicMock() + self.mock_validator.authenticate_client.side_effect = set_user + self.mock_validator.get_code_challenge.return_value = None + self.addCleanup(setattr, self, 'mock_validator', mock.MagicMock()) + auth_code = AuthorizationCodeGrant( + request_validator=self.mock_validator) + supported_types = { + 'authorization_code': auth_code, + } + self.expires_in = 1800 + token = BearerToken( + self.mock_validator, + expires_in=self.expires_in + ) + self.endpoint = TokenEndpoint( + 'authorization_code', + default_token_type=token, + grant_types=supported_types + ) + + @mock.patch('oauthlib.common.generate_token', new=lambda: 'abc') + def test_authorization_grant(self): + body = 'grant_type=authorization_code&code=abc&scope=all+of+them' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': 'abc', + 'refresh_token': 'abc', + 'scope': 'all of them' + } + self.assertEqual(json.loads(body), token) + + body = 'grant_type=authorization_code&code=abc' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + token = { + 'token_type': 'Bearer', + 'expires_in': self.expires_in, + 'access_token': 'abc', + 'refresh_token': 'abc' + } + self.assertEqual(json.loads(body), token) + + # ignore useless fields + body = 'grant_type=authorization_code&code=abc&state=foobar' + headers, body, status_code = self.endpoint.create_token_response( + '', body=body) + self.assertEqual(json.loads(body), token) + + def test_missing_type(self): + _, body, _ = self.endpoint.create_token_response('', body='') + token = {'error': 'unsupported_grant_type'} + self.assertEqual(json.loads(body), token) + + def test_invalid_type(self): + body = 'grant_type=invalid' + _, body, _ = self.endpoint.create_token_response('', body=body) + token = {'error': 'unsupported_grant_type'} + self.assertEqual(json.loads(body), token) diff --git a/contrib/python/oauthlib/tests/openid/connect/core/test_tokens.py b/contrib/python/oauthlib/tests/openid/connect/core/test_tokens.py new file mode 100644 index 0000000000..fe90142bb8 --- /dev/null +++ b/contrib/python/oauthlib/tests/openid/connect/core/test_tokens.py @@ -0,0 +1,157 @@ +from unittest import mock + +from oauthlib.openid.connect.core.tokens import JWTToken + +from tests.unittest import TestCase + + +class JWTTokenTestCase(TestCase): + + def test_create_token_callable_expires_in(self): + """ + Test retrieval of the expires in value by calling the callable expires_in property + """ + + expires_in_mock = mock.MagicMock() + request_mock = mock.MagicMock() + + token = JWTToken(expires_in=expires_in_mock, request_validator=mock.MagicMock()) + token.create_token(request=request_mock) + + expires_in_mock.assert_called_once_with(request_mock) + + def test_create_token_non_callable_expires_in(self): + """ + When a non callable expires in is set this should just be set to the request + """ + + expires_in_mock = mock.NonCallableMagicMock() + request_mock = mock.MagicMock() + + token = JWTToken(expires_in=expires_in_mock, request_validator=mock.MagicMock()) + token.create_token(request=request_mock) + + self.assertFalse(expires_in_mock.called) + self.assertEqual(request_mock.expires_in, expires_in_mock) + + def test_create_token_calls_get_id_token(self): + """ + When create_token is called the call should be forwarded to the get_id_token on the token validator + """ + request_mock = mock.MagicMock() + + with mock.patch('oauthlib.openid.RequestValidator', + autospec=True) as RequestValidatorMock: + + request_validator = RequestValidatorMock() + + token = JWTToken(expires_in=mock.MagicMock(), request_validator=request_validator) + token.create_token(request=request_mock) + + request_validator.get_jwt_bearer_token.assert_called_once_with(None, None, request_mock) + + def test_validate_request_token_from_headers(self): + """ + Bearer token get retrieved from headers. + """ + + with mock.patch('oauthlib.common.Request', autospec=True) as RequestMock, \ + mock.patch('oauthlib.openid.RequestValidator', + autospec=True) as RequestValidatorMock: + request_validator_mock = RequestValidatorMock() + + token = JWTToken(request_validator=request_validator_mock) + + request = RequestMock('/uri') + # Scopes is retrieved using the __call__ method which is not picked up correctly by mock.patch + # with autospec=True + request.scopes = mock.MagicMock() + request.headers = { + 'Authorization': 'Bearer some-token-from-header' + } + + token.validate_request(request=request) + + request_validator_mock.validate_jwt_bearer_token.assert_called_once_with('some-token-from-header', + request.scopes, + request) + + def test_validate_request_token_from_headers_basic(self): + """ + Wrong kind of token (Basic) retrieved from headers. Confirm token is not parsed. + """ + + with mock.patch('oauthlib.common.Request', autospec=True) as RequestMock, \ + mock.patch('oauthlib.openid.RequestValidator', + autospec=True) as RequestValidatorMock: + request_validator_mock = RequestValidatorMock() + + token = JWTToken(request_validator=request_validator_mock) + + request = RequestMock('/uri') + # Scopes is retrieved using the __call__ method which is not picked up correctly by mock.patch + # with autospec=True + request.scopes = mock.MagicMock() + request.headers = { + 'Authorization': 'Basic some-token-from-header' + } + + token.validate_request(request=request) + + request_validator_mock.validate_jwt_bearer_token.assert_called_once_with(None, + request.scopes, + request) + + def test_validate_token_from_request(self): + """ + Token get retrieved from request object. + """ + + with mock.patch('oauthlib.common.Request', autospec=True) as RequestMock, \ + mock.patch('oauthlib.openid.RequestValidator', + autospec=True) as RequestValidatorMock: + request_validator_mock = RequestValidatorMock() + + token = JWTToken(request_validator=request_validator_mock) + + request = RequestMock('/uri') + # Scopes is retrieved using the __call__ method which is not picked up correctly by mock.patch + # with autospec=True + request.scopes = mock.MagicMock() + request.access_token = 'some-token-from-request-object' + request.headers = {} + + token.validate_request(request=request) + + request_validator_mock.validate_jwt_bearer_token.assert_called_once_with('some-token-from-request-object', + request.scopes, + request) + + def test_estimate_type(self): + """ + Estimate type results for a jwt token + """ + + def test_token(token, expected_result): + with mock.patch('oauthlib.common.Request', autospec=True) as RequestMock: + jwt_token = JWTToken() + + request = RequestMock('/uri') + # Scopes is retrieved using the __call__ method which is not picked up correctly by mock.patch + # with autospec=True + request.headers = { + 'Authorization': 'Bearer {}'.format(token) + } + + result = jwt_token.estimate_type(request=request) + + self.assertEqual(result, expected_result) + + test_items = ( + ('eyfoo.foo.foo', 10), + ('eyfoo.foo.foo.foo.foo', 10), + ('eyfoobar', 0) + ) + + for token, expected_result in test_items: + test_token(token, expected_result) diff --git a/contrib/python/oauthlib/tests/test_common.py b/contrib/python/oauthlib/tests/test_common.py new file mode 100644 index 0000000000..7f0e35bc9c --- /dev/null +++ b/contrib/python/oauthlib/tests/test_common.py @@ -0,0 +1,243 @@ +# -*- coding: utf-8 -*- +import oauthlib +from oauthlib.common import ( + CaseInsensitiveDict, Request, add_params_to_uri, extract_params, + generate_client_id, generate_nonce, generate_timestamp, generate_token, + urldecode, +) + +from tests.unittest import TestCase + +PARAMS_DICT = {'foo': 'bar', 'baz': '123', } +PARAMS_TWOTUPLE = [('foo', 'bar'), ('baz', '123')] +PARAMS_FORMENCODED = 'foo=bar&baz=123' +URI = 'http://www.someuri.com' + + +class EncodingTest(TestCase): + + def test_urldecode(self): + self.assertCountEqual(urldecode(''), []) + self.assertCountEqual(urldecode('='), [('', '')]) + self.assertCountEqual(urldecode('%20'), [(' ', '')]) + self.assertCountEqual(urldecode('+'), [(' ', '')]) + self.assertCountEqual(urldecode('c2'), [('c2', '')]) + self.assertCountEqual(urldecode('c2='), [('c2', '')]) + self.assertCountEqual(urldecode('foo=bar'), [('foo', 'bar')]) + self.assertCountEqual(urldecode('foo_%20~=.bar-'), + [('foo_ ~', '.bar-')]) + self.assertCountEqual(urldecode('foo=1,2,3'), [('foo', '1,2,3')]) + self.assertCountEqual(urldecode('foo=(1,2,3)'), [('foo', '(1,2,3)')]) + self.assertCountEqual(urldecode('foo=bar.*'), [('foo', 'bar.*')]) + self.assertCountEqual(urldecode('foo=bar@spam'), [('foo', 'bar@spam')]) + self.assertCountEqual(urldecode('foo=bar/baz'), [('foo', 'bar/baz')]) + self.assertCountEqual(urldecode('foo=bar?baz'), [('foo', 'bar?baz')]) + self.assertCountEqual(urldecode('foo=bar\'s'), [('foo', 'bar\'s')]) + self.assertCountEqual(urldecode('foo=$'), [('foo', '$')]) + self.assertRaises(ValueError, urldecode, 'foo bar') + self.assertRaises(ValueError, urldecode, '%R') + self.assertRaises(ValueError, urldecode, '%RA') + self.assertRaises(ValueError, urldecode, '%AR') + self.assertRaises(ValueError, urldecode, '%RR') + + +class ParameterTest(TestCase): + + def test_extract_params_dict(self): + self.assertCountEqual(extract_params(PARAMS_DICT), PARAMS_TWOTUPLE) + + def test_extract_params_twotuple(self): + self.assertCountEqual(extract_params(PARAMS_TWOTUPLE), PARAMS_TWOTUPLE) + + def test_extract_params_formencoded(self): + self.assertCountEqual(extract_params(PARAMS_FORMENCODED), + PARAMS_TWOTUPLE) + + def test_extract_params_blank_string(self): + self.assertCountEqual(extract_params(''), []) + + def test_extract_params_empty_list(self): + self.assertCountEqual(extract_params([]), []) + + def test_extract_non_formencoded_string(self): + self.assertIsNone(extract_params('not a formencoded string')) + + def test_extract_invalid(self): + self.assertIsNone(extract_params(object())) + self.assertIsNone(extract_params([('')])) + + def test_add_params_to_uri(self): + correct = '{}?{}'.format(URI, PARAMS_FORMENCODED) + self.assertURLEqual(add_params_to_uri(URI, PARAMS_DICT), correct) + self.assertURLEqual(add_params_to_uri(URI, PARAMS_TWOTUPLE), correct) + + +class GeneratorTest(TestCase): + + def test_generate_timestamp(self): + timestamp = generate_timestamp() + self.assertIsInstance(timestamp, str) + self.assertTrue(int(timestamp)) + self.assertGreater(int(timestamp), 1331672335) + + def test_generate_nonce(self): + """Ping me (ib-lundgren) when you discover how to test randomness.""" + nonce = generate_nonce() + for i in range(50): + self.assertNotEqual(nonce, generate_nonce()) + + def test_generate_token(self): + token = generate_token() + self.assertEqual(len(token), 30) + + token = generate_token(length=44) + self.assertEqual(len(token), 44) + + token = generate_token(length=6, chars="python") + self.assertEqual(len(token), 6) + for c in token: + self.assertIn(c, "python") + + def test_generate_client_id(self): + client_id = generate_client_id() + self.assertEqual(len(client_id), 30) + + client_id = generate_client_id(length=44) + self.assertEqual(len(client_id), 44) + + client_id = generate_client_id(length=6, chars="python") + self.assertEqual(len(client_id), 6) + for c in client_id: + self.assertIn(c, "python") + + +class RequestTest(TestCase): + + def test_non_unicode_params(self): + r = Request( + b'http://a.b/path?query', + http_method=b'GET', + body=b'you=shall+pass', + headers={ + b'a': b'b', + } + ) + self.assertEqual(r.uri, 'http://a.b/path?query') + self.assertEqual(r.http_method, 'GET') + self.assertEqual(r.body, 'you=shall+pass') + self.assertEqual(r.decoded_body, [('you', 'shall pass')]) + self.assertEqual(r.headers, {'a': 'b'}) + + def test_none_body(self): + r = Request(URI) + self.assertIsNone(r.decoded_body) + + def test_empty_list_body(self): + r = Request(URI, body=[]) + self.assertEqual(r.decoded_body, []) + + def test_empty_dict_body(self): + r = Request(URI, body={}) + self.assertEqual(r.decoded_body, []) + + def test_empty_string_body(self): + r = Request(URI, body='') + self.assertEqual(r.decoded_body, []) + + def test_non_formencoded_string_body(self): + body = 'foo bar' + r = Request(URI, body=body) + self.assertIsNone(r.decoded_body) + + def test_param_free_sequence_body(self): + body = [1, 1, 2, 3, 5, 8, 13] + r = Request(URI, body=body) + self.assertIsNone(r.decoded_body) + + def test_list_body(self): + r = Request(URI, body=PARAMS_TWOTUPLE) + self.assertCountEqual(r.decoded_body, PARAMS_TWOTUPLE) + + def test_dict_body(self): + r = Request(URI, body=PARAMS_DICT) + self.assertCountEqual(r.decoded_body, PARAMS_TWOTUPLE) + + def test_getattr_existing_attribute(self): + r = Request(URI, body='foo bar') + self.assertEqual('foo bar', getattr(r, 'body')) + + def test_getattr_return_default(self): + r = Request(URI, body='') + actual_value = getattr(r, 'does_not_exist', 'foo bar') + self.assertEqual('foo bar', actual_value) + + def test_getattr_raise_attribute_error(self): + r = Request(URI, body='foo bar') + with self.assertRaises(AttributeError): + getattr(r, 'does_not_exist') + + def test_sanitizing_authorization_header(self): + r = Request(URI, headers={'Accept': 'application/json', + 'Authorization': 'Basic Zm9vOmJhcg=='} + ) + self.assertNotIn('Zm9vOmJhcg==', repr(r)) + self.assertIn('<SANITIZED>', repr(r)) + # Double-check we didn't modify the underlying object: + self.assertEqual(r.headers['Authorization'], 'Basic Zm9vOmJhcg==') + + def test_token_body(self): + payload = 'client_id=foo&refresh_token=bar' + r = Request(URI, body=payload) + self.assertNotIn('bar', repr(r)) + self.assertIn('<SANITIZED>', repr(r)) + + payload = 'refresh_token=bar&client_id=foo' + r = Request(URI, body=payload) + self.assertNotIn('bar', repr(r)) + self.assertIn('<SANITIZED>', repr(r)) + + def test_password_body(self): + payload = 'username=foo&password=bar' + r = Request(URI, body=payload) + self.assertNotIn('bar', repr(r)) + self.assertIn('<SANITIZED>', repr(r)) + + payload = 'password=bar&username=foo' + r = Request(URI, body=payload) + self.assertNotIn('bar', repr(r)) + self.assertIn('<SANITIZED>', repr(r)) + + def test_headers_params(self): + r = Request(URI, headers={'token': 'foobar'}, body='token=banana') + self.assertEqual(r.headers['token'], 'foobar') + self.assertEqual(r.token, 'banana') + + def test_sanitized_request_non_debug_mode(self): + """make sure requests are sanitized when in non debug mode. + For the debug mode, the other tests checking sanitization should prove + that debug mode is working. + """ + try: + oauthlib.set_debug(False) + r = Request(URI, headers={'token': 'foobar'}, body='token=banana') + self.assertNotIn('token', repr(r)) + self.assertIn('SANITIZED', repr(r)) + finally: + # set flag back for other tests + oauthlib.set_debug(True) + + +class CaseInsensitiveDictTest(TestCase): + + def test_basic(self): + cid = CaseInsensitiveDict({}) + cid['a'] = 'b' + cid['c'] = 'd' + del cid['c'] + self.assertEqual(cid['A'], 'b') + self.assertEqual(cid['a'], 'b') + + def test_update(self): + cid = CaseInsensitiveDict({}) + cid.update({'KeY': 'value'}) + self.assertEqual(cid['kEy'], 'value') diff --git a/contrib/python/oauthlib/tests/test_uri_validate.py b/contrib/python/oauthlib/tests/test_uri_validate.py new file mode 100644 index 0000000000..6a9f8ea60b --- /dev/null +++ b/contrib/python/oauthlib/tests/test_uri_validate.py @@ -0,0 +1,84 @@ +import unittest +from oauthlib.uri_validate import is_absolute_uri + +from tests.unittest import TestCase + + +class UriValidateTest(TestCase): + + def test_is_absolute_uri(self): + self.assertIsNotNone(is_absolute_uri('schema://example.com/path')) + self.assertIsNotNone(is_absolute_uri('https://example.com/path')) + self.assertIsNotNone(is_absolute_uri('https://example.com')) + self.assertIsNotNone(is_absolute_uri('https://example.com:443/path')) + self.assertIsNotNone(is_absolute_uri('https://example.com:443/')) + self.assertIsNotNone(is_absolute_uri('https://example.com:443')) + self.assertIsNotNone(is_absolute_uri('http://example.com')) + self.assertIsNotNone(is_absolute_uri('http://example.com/path')) + self.assertIsNotNone(is_absolute_uri('http://example.com:80/path')) + + def test_query(self): + self.assertIsNotNone(is_absolute_uri('http://example.com:80/path?foo')) + self.assertIsNotNone(is_absolute_uri('http://example.com:80/path?foo=bar')) + self.assertIsNotNone(is_absolute_uri('http://example.com:80/path?foo=bar&fruit=banana')) + + def test_fragment_forbidden(self): + self.assertIsNone(is_absolute_uri('http://example.com:80/path#foo')) + self.assertIsNone(is_absolute_uri('http://example.com:80/path#foo=bar')) + self.assertIsNone(is_absolute_uri('http://example.com:80/path#foo=bar&fruit=banana')) + + def test_combined_forbidden(self): + self.assertIsNone(is_absolute_uri('http://example.com:80/path?foo#bar')) + self.assertIsNone(is_absolute_uri('http://example.com:80/path?foo&bar#fruit')) + self.assertIsNone(is_absolute_uri('http://example.com:80/path?foo=1&bar#fruit=banana')) + self.assertIsNone(is_absolute_uri('http://example.com:80/path?foo=1&bar=2#fruit=banana&bar=foo')) + + def test_custom_scheme(self): + self.assertIsNotNone(is_absolute_uri('com.example.bundle.id://')) + + def test_ipv6_bracket(self): + self.assertIsNotNone(is_absolute_uri('http://[::1]:38432/path')) + self.assertIsNotNone(is_absolute_uri('http://[::1]/path')) + self.assertIsNotNone(is_absolute_uri('http://[fd01:0001::1]/path')) + self.assertIsNotNone(is_absolute_uri('http://[fd01:1::1]/path')) + self.assertIsNotNone(is_absolute_uri('http://[0123:4567:89ab:cdef:0123:4567:89ab:cdef]/path')) + self.assertIsNotNone(is_absolute_uri('http://[0123:4567:89ab:cdef:0123:4567:89ab:cdef]:8080/path')) + + @unittest.skip("ipv6 edge-cases not supported") + def test_ipv6_edge_cases(self): + self.assertIsNotNone(is_absolute_uri('http://2001:db8::')) + self.assertIsNotNone(is_absolute_uri('http://::1234:5678')) + self.assertIsNotNone(is_absolute_uri('http://2001:db8::1234:5678')) + self.assertIsNotNone(is_absolute_uri('http://2001:db8:3333:4444:5555:6666:7777:8888')) + self.assertIsNotNone(is_absolute_uri('http://2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF')) + self.assertIsNotNone(is_absolute_uri('http://0123:4567:89ab:cdef:0123:4567:89ab:cdef/path')) + self.assertIsNotNone(is_absolute_uri('http://::')) + self.assertIsNotNone(is_absolute_uri('http://2001:0db8:0001:0000:0000:0ab9:C0A8:0102')) + + @unittest.skip("ipv6 dual ipv4 not supported") + def test_ipv6_dual(self): + self.assertIsNotNone(is_absolute_uri('http://2001:db8:3333:4444:5555:6666:1.2.3.4')) + self.assertIsNotNone(is_absolute_uri('http://::11.22.33.44')) + self.assertIsNotNone(is_absolute_uri('http://2001:db8::123.123.123.123')) + self.assertIsNotNone(is_absolute_uri('http://::1234:5678:91.123.4.56')) + self.assertIsNotNone(is_absolute_uri('http://::1234:5678:1.2.3.4')) + self.assertIsNotNone(is_absolute_uri('http://2001:db8::1234:5678:5.6.7.8')) + + def test_ipv4(self): + self.assertIsNotNone(is_absolute_uri('http://127.0.0.1:38432/')) + self.assertIsNotNone(is_absolute_uri('http://127.0.0.1:38432/')) + self.assertIsNotNone(is_absolute_uri('http://127.1:38432/')) + + def test_failures(self): + self.assertIsNone(is_absolute_uri('http://example.com:notaport/path')) + self.assertIsNone(is_absolute_uri('wrong')) + self.assertIsNone(is_absolute_uri('http://[:1]:38432/path')) + self.assertIsNone(is_absolute_uri('http://[abcd:efgh::1]/')) + + def test_recursive_regex(self): + from datetime import datetime + t0 = datetime.now() + is_absolute_uri('http://[::::::::::::::::::::::::::]/path') + t1 = datetime.now() + spent = t1 - t0 + self.assertGreater(0.1, spent.total_seconds(), "possible recursive loop detected") diff --git a/contrib/python/oauthlib/tests/unittest/__init__.py b/contrib/python/oauthlib/tests/unittest/__init__.py new file mode 100644 index 0000000000..f94f35c664 --- /dev/null +++ b/contrib/python/oauthlib/tests/unittest/__init__.py @@ -0,0 +1,32 @@ +import urllib.parse as urlparse +from unittest import TestCase + + +# URL comparison where query param order is insignificant +def url_equals(self, a, b, parse_fragment=False): + parsed_a = urlparse.urlparse(a, allow_fragments=parse_fragment) + parsed_b = urlparse.urlparse(b, allow_fragments=parse_fragment) + query_a = urlparse.parse_qsl(parsed_a.query) + query_b = urlparse.parse_qsl(parsed_b.query) + if parse_fragment: + fragment_a = urlparse.parse_qsl(parsed_a.fragment) + fragment_b = urlparse.parse_qsl(parsed_b.fragment) + self.assertCountEqual(fragment_a, fragment_b) + else: + self.assertEqual(parsed_a.fragment, parsed_b.fragment) + self.assertEqual(parsed_a.scheme, parsed_b.scheme) + self.assertEqual(parsed_a.netloc, parsed_b.netloc) + self.assertEqual(parsed_a.path, parsed_b.path) + self.assertEqual(parsed_a.params, parsed_b.params) + self.assertEqual(parsed_a.username, parsed_b.username) + self.assertEqual(parsed_a.password, parsed_b.password) + self.assertEqual(parsed_a.hostname, parsed_b.hostname) + self.assertEqual(parsed_a.port, parsed_b.port) + self.assertCountEqual(query_a, query_b) + + +TestCase.assertURLEqual = url_equals + +# Form body comparison where order is insignificant +TestCase.assertFormBodyEqual = lambda self, a, b: self.assertCountEqual( + urlparse.parse_qsl(a), urlparse.parse_qsl(b)) diff --git a/contrib/python/oauthlib/tests/ya.make b/contrib/python/oauthlib/tests/ya.make new file mode 100644 index 0000000000..b207e5ea63 --- /dev/null +++ b/contrib/python/oauthlib/tests/ya.make @@ -0,0 +1,88 @@ +PY3TEST() + +PEERDIR( + contrib/python/oauthlib + contrib/python/mock + contrib/python/PyJWT + contrib/python/blinker +) + +PY_SRCS( + NAMESPACE tests + unittest/__init__.py +) + +TEST_SRCS( + __init__.py + oauth1/__init__.py + oauth1/rfc5849/__init__.py + oauth1/rfc5849/endpoints/__init__.py + oauth1/rfc5849/endpoints/test_access_token.py + oauth1/rfc5849/endpoints/test_authorization.py + oauth1/rfc5849/endpoints/test_base.py + oauth1/rfc5849/endpoints/test_request_token.py + oauth1/rfc5849/endpoints/test_resource.py + oauth1/rfc5849/endpoints/test_signature_only.py + oauth1/rfc5849/test_client.py + oauth1/rfc5849/test_parameters.py + oauth1/rfc5849/test_request_validator.py + oauth1/rfc5849/test_signatures.py + oauth1/rfc5849/test_utils.py + oauth2/__init__.py + oauth2/rfc6749/__init__.py + oauth2/rfc6749/clients/__init__.py + oauth2/rfc6749/clients/test_backend_application.py + oauth2/rfc6749/clients/test_base.py + oauth2/rfc6749/clients/test_legacy_application.py + oauth2/rfc6749/clients/test_mobile_application.py + oauth2/rfc6749/clients/test_service_application.py + oauth2/rfc6749/clients/test_web_application.py + oauth2/rfc6749/endpoints/__init__.py + oauth2/rfc6749/endpoints/test_base_endpoint.py + oauth2/rfc6749/endpoints/test_client_authentication.py + oauth2/rfc6749/endpoints/test_credentials_preservation.py + oauth2/rfc6749/endpoints/test_error_responses.py + oauth2/rfc6749/endpoints/test_extra_credentials.py + oauth2/rfc6749/endpoints/test_introspect_endpoint.py + oauth2/rfc6749/endpoints/test_metadata.py + oauth2/rfc6749/endpoints/test_resource_owner_association.py + oauth2/rfc6749/endpoints/test_revocation_endpoint.py + oauth2/rfc6749/endpoints/test_scope_handling.py + oauth2/rfc6749/endpoints/test_utils.py + oauth2/rfc6749/grant_types/__init__.py + oauth2/rfc6749/grant_types/test_authorization_code.py + oauth2/rfc6749/grant_types/test_client_credentials.py + oauth2/rfc6749/grant_types/test_implicit.py + oauth2/rfc6749/grant_types/test_refresh_token.py + oauth2/rfc6749/grant_types/test_resource_owner_password.py + oauth2/rfc6749/test_parameters.py + oauth2/rfc6749/test_request_validator.py + oauth2/rfc6749/test_server.py + oauth2/rfc6749/test_tokens.py + oauth2/rfc6749/test_utils.py + oauth2/rfc8628/__init__.py + oauth2/rfc8628/clients/__init__.py + oauth2/rfc8628/clients/test_device.py + openid/__init__.py + openid/connect/__init__.py + openid/connect/core/__init__.py + openid/connect/core/endpoints/__init__.py + openid/connect/core/endpoints/test_claims_handling.py + openid/connect/core/endpoints/test_openid_connect_params_handling.py + openid/connect/core/endpoints/test_userinfo_endpoint.py + openid/connect/core/grant_types/__init__.py + openid/connect/core/grant_types/test_authorization_code.py + openid/connect/core/grant_types/test_base.py + openid/connect/core/grant_types/test_dispatchers.py + openid/connect/core/grant_types/test_hybrid.py + openid/connect/core/grant_types/test_implicit.py + openid/connect/core/test_request_validator.py + openid/connect/core/test_server.py + openid/connect/core/test_tokens.py + test_common.py + test_uri_validate.py +) + +NO_LINT() + +END() |