diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2024-03-25 09:11:17 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2024-03-25 09:17:48 +0300 |
commit | 4624e4cfd95649270db02616edde8d0ca249b63d (patch) | |
tree | 1c8a43f50533ca759d137f258e42862e8cf5e80f /contrib/python/requests-oauthlib/requests_oauthlib/oauth2_session.py | |
parent | d2d971701bd8377ead5f973c96be81042774bd2a (diff) | |
download | ydb-4624e4cfd95649270db02616edde8d0ca249b63d.tar.gz |
Intermediate changes
Diffstat (limited to 'contrib/python/requests-oauthlib/requests_oauthlib/oauth2_session.py')
-rw-r--r-- | contrib/python/requests-oauthlib/requests_oauthlib/oauth2_session.py | 65 |
1 files changed, 56 insertions, 9 deletions
diff --git a/contrib/python/requests-oauthlib/requests_oauthlib/oauth2_session.py b/contrib/python/requests-oauthlib/requests_oauthlib/oauth2_session.py index db4468089b..93cc4d7bbd 100644 --- a/contrib/python/requests-oauthlib/requests_oauthlib/oauth2_session.py +++ b/contrib/python/requests-oauthlib/requests_oauthlib/oauth2_session.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - import logging from oauthlib.common import generate_token, urldecode @@ -46,6 +44,7 @@ class OAuth2Session(requests.Session): token=None, state=None, token_updater=None, + pkce=None, **kwargs ): """Construct a new OAuth 2 client session. @@ -72,18 +71,23 @@ class OAuth2Session(requests.Session): set a TokenUpdated warning will be raised when a token has been refreshed. This warning will carry the token in its token argument. + :param pkce: Set "S256" or "plain" to enable PKCE. Default is disabled. :param kwargs: Arguments to pass to the Session constructor. """ super(OAuth2Session, self).__init__(**kwargs) self._client = client or WebApplicationClient(client_id, token=token) self.token = token or {} - self.scope = scope + self._scope = scope self.redirect_uri = redirect_uri self.state = state or generate_token self._state = state self.auto_refresh_url = auto_refresh_url self.auto_refresh_kwargs = auto_refresh_kwargs or {} self.token_updater = token_updater + self._pkce = pkce + + if self._pkce not in ["S256", "plain", None]: + raise AttributeError("Wrong value for {}(.., pkce={})".format(self.__class__, self._pkce)) # Ensure that requests doesn't do any automatic auth. See #278. # The default behavior can be re-enabled by setting auth to None. @@ -95,8 +99,24 @@ class OAuth2Session(requests.Session): "access_token_response": set(), "refresh_token_response": set(), "protected_request": set(), + "refresh_token_request": set(), + "access_token_request": set(), } + @property + def scope(self): + """By default the scope from the client is used, except if overridden""" + if self._scope is not None: + return self._scope + elif self._client is not None: + return self._client.scope + else: + return None + + @scope.setter + def scope(self, scope): + self._scope = scope + def new_state(self): """Generates a state string to be used in authorizations.""" try: @@ -161,6 +181,13 @@ class OAuth2Session(requests.Session): :return: authorization_url, state """ state = state or self.new_state() + if self._pkce: + self._code_verifier = self._client.create_code_verifier(43) + kwargs["code_challenge_method"] = self._pkce + kwargs["code_challenge"] = self._client.create_code_challenge( + code_verifier=self._code_verifier, + code_challenge_method=self._pkce + ) return ( self._client.prepare_request_uri( url, @@ -185,7 +212,7 @@ class OAuth2Session(requests.Session): force_querystring=False, timeout=None, headers=None, - verify=True, + verify=None, proxies=None, include_client_id=None, client_secret=None, @@ -252,6 +279,13 @@ class OAuth2Session(requests.Session): "Please supply either code or " "authorization_response parameters." ) + if self._pkce: + if self._code_verifier is None: + raise ValueError( + "Code verifier is not found, authorization URL must be generated before" + ) + kwargs["code_verifier"] = self._code_verifier + # Earlier versions of this library build an HTTPBasicAuth header out of # `username` and `password`. The RFC states, however these attributes # must be in the request body and not the header. @@ -325,7 +359,7 @@ class OAuth2Session(requests.Session): headers = headers or { "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8", + "Content-Type": "application/x-www-form-urlencoded", } self.token = {} request_kwargs = {} @@ -338,6 +372,12 @@ class OAuth2Session(requests.Session): else: raise ValueError("The method kwarg must be POST or GET.") + for hook in self.compliance_hook["access_token_request"]: + log.debug("Invoking access_token_request hook %s.", hook) + token_url, headers, request_kwargs = hook( + token_url, headers, request_kwargs + ) + r = self.request( method=method, url=token_url, @@ -388,7 +428,7 @@ class OAuth2Session(requests.Session): auth=None, timeout=None, headers=None, - verify=True, + verify=None, proxies=None, **kwargs ): @@ -426,9 +466,13 @@ class OAuth2Session(requests.Session): if headers is None: headers = { "Accept": "application/json", - "Content-Type": ("application/x-www-form-urlencoded;charset=UTF-8"), + "Content-Type": ("application/x-www-form-urlencoded"), } + for hook in self.compliance_hook["refresh_token_request"]: + log.debug("Invoking refresh_token_request hook %s.", hook) + token_url, headers, body = hook(token_url, headers, body) + r = self.post( token_url, data=dict(urldecode(body)), @@ -450,7 +494,7 @@ class OAuth2Session(requests.Session): r = hook(r) self.token = self._client.parse_request_body_response(r.text, scope=self.scope) - if not "refresh_token" in self.token: + if "refresh_token" not in self.token: log.debug("No new refresh token given. Re-using old.") self.token["refresh_token"] = refresh_token return self.token @@ -464,6 +508,7 @@ class OAuth2Session(requests.Session): withhold_token=False, client_id=None, client_secret=None, + files=None, **kwargs ): """Intercept all requests and add the OAuth 2 token if present.""" @@ -519,7 +564,7 @@ class OAuth2Session(requests.Session): log.debug("Supplying headers %s and data %s", headers, data) log.debug("Passing through key word arguments %s.", kwargs) return super(OAuth2Session, self).request( - method, url, headers=headers, data=data, **kwargs + method, url, headers=headers, data=data, files=files, **kwargs ) def register_compliance_hook(self, hook_type, hook): @@ -529,6 +574,8 @@ class OAuth2Session(requests.Session): access_token_response invoked before token parsing. refresh_token_response invoked before refresh token parsing. protected_request invoked before making a request. + access_token_request invoked before making a token fetch request. + refresh_token_request invoked before making a refresh request. If you find a new hook is needed please send a GitHub PR request or open an issue. |