diff --git a/lib/requests_oauthlib/__init__.py b/lib/requests_oauthlib/__init__.py index 8b320523..a4e03a4e 100644 --- a/lib/requests_oauthlib/__init__.py +++ b/lib/requests_oauthlib/__init__.py @@ -1,22 +1,19 @@ +import logging + from .oauth1_auth import OAuth1 from .oauth1_session import OAuth1Session from .oauth2_auth import OAuth2 from .oauth2_session import OAuth2Session, TokenUpdated -__version__ = '0.6.1' +__version__ = "1.3.0" import requests -if requests.__version__ < '2.0.0': - msg = ('You are using requests version %s, which is older than ' - 'requests-oauthlib expects, please upgrade to 2.0.0 or later.') + +if requests.__version__ < "2.0.0": + msg = ( + "You are using requests version %s, which is older than " + "requests-oauthlib expects, please upgrade to 2.0.0 or later." + ) raise Warning(msg % requests.__version__) -import logging -try: # Python 2.7+ - from logging import NullHandler -except ImportError: - class NullHandler(logging.Handler): - def emit(self, record): - pass - -logging.getLogger('requests_oauthlib').addHandler(NullHandler()) +logging.getLogger("requests_oauthlib").addHandler(logging.NullHandler()) diff --git a/lib/requests_oauthlib/compliance_fixes/__init__.py b/lib/requests_oauthlib/compliance_fixes/__init__.py index 46eacb8b..02fa5120 100644 --- a/lib/requests_oauthlib/compliance_fixes/__init__.py +++ b/lib/requests_oauthlib/compliance_fixes/__init__.py @@ -1,7 +1,10 @@ from __future__ import absolute_import from .facebook import facebook_compliance_fix +from .fitbit import fitbit_compliance_fix from .linkedin import linkedin_compliance_fix from .slack import slack_compliance_fix +from .instagram import instagram_compliance_fix from .mailchimp import mailchimp_compliance_fix from .weibo import weibo_compliance_fix +from .plentymarkets import plentymarkets_compliance_fix diff --git a/lib/requests_oauthlib/compliance_fixes/douban.py b/lib/requests_oauthlib/compliance_fixes/douban.py index 2e45b3b9..ecc57b08 100644 --- a/lib/requests_oauthlib/compliance_fixes/douban.py +++ b/lib/requests_oauthlib/compliance_fixes/douban.py @@ -4,15 +4,14 @@ from oauthlib.common import to_unicode def douban_compliance_fix(session): - def fix_token_type(r): token = json.loads(r.text) - token.setdefault('token_type', 'Bearer') + token.setdefault("token_type", "Bearer") fixed_token = json.dumps(token) - r._content = to_unicode(fixed_token).encode('utf-8') + r._content = to_unicode(fixed_token).encode("utf-8") return r - session._client_default_token_placement = 'query' - session.register_compliance_hook('access_token_response', fix_token_type) + session._client_default_token_placement = "query" + session.register_compliance_hook("access_token_response", fix_token_type) return session diff --git a/lib/requests_oauthlib/compliance_fixes/facebook.py b/lib/requests_oauthlib/compliance_fixes/facebook.py index 07181c39..90e79212 100644 --- a/lib/requests_oauthlib/compliance_fixes/facebook.py +++ b/lib/requests_oauthlib/compliance_fixes/facebook.py @@ -1,4 +1,5 @@ from json import dumps + try: from urlparse import parse_qsl except ImportError: @@ -8,26 +9,25 @@ from oauthlib.common import to_unicode def facebook_compliance_fix(session): - def _compliance_fix(r): # if Facebook claims to be sending us json, let's trust them. - if 'application/json' in r.headers.get('content-type', {}): + if "application/json" in r.headers.get("content-type", {}): return r # Facebook returns a content-type of text/plain when sending their # x-www-form-urlencoded responses, along with a 200. If not, let's # assume we're getting JSON and bail on the fix. - if 'text/plain' in r.headers.get('content-type', {}) and r.status_code == 200: + if "text/plain" in r.headers.get("content-type", {}) and r.status_code == 200: token = dict(parse_qsl(r.text, keep_blank_values=True)) else: return r - expires = token.get('expires') + expires = token.get("expires") if expires is not None: - token['expires_in'] = expires - token['token_type'] = 'Bearer' - r._content = to_unicode(dumps(token)).encode('UTF-8') + token["expires_in"] = expires + token["token_type"] = "Bearer" + r._content = to_unicode(dumps(token)).encode("UTF-8") return r - session.register_compliance_hook('access_token_response', _compliance_fix) + session.register_compliance_hook("access_token_response", _compliance_fix) return session diff --git a/lib/requests_oauthlib/compliance_fixes/fitbit.py b/lib/requests_oauthlib/compliance_fixes/fitbit.py new file mode 100644 index 00000000..7e627024 --- /dev/null +++ b/lib/requests_oauthlib/compliance_fixes/fitbit.py @@ -0,0 +1,25 @@ +""" +The Fitbit API breaks from the OAuth2 RFC standard by returning an "errors" +object list, rather than a single "error" string. This puts hooks in place so +that oauthlib can process an error in the results from access token and refresh +token responses. This is necessary to prevent getting the generic red herring +MissingTokenError. +""" + +from json import loads, dumps + +from oauthlib.common import to_unicode + + +def fitbit_compliance_fix(session): + def _missing_error(r): + token = loads(r.text) + if "errors" in token: + # Set the error to the first one we have + token["error"] = token["errors"][0]["errorType"] + r._content = to_unicode(dumps(token)).encode("UTF-8") + return r + + session.register_compliance_hook("access_token_response", _missing_error) + session.register_compliance_hook("refresh_token_response", _missing_error) + return session diff --git a/lib/requests_oauthlib/compliance_fixes/instagram.py b/lib/requests_oauthlib/compliance_fixes/instagram.py new file mode 100644 index 00000000..4e07fe08 --- /dev/null +++ b/lib/requests_oauthlib/compliance_fixes/instagram.py @@ -0,0 +1,26 @@ +try: + from urlparse import urlparse, parse_qs +except ImportError: + from urllib.parse import urlparse, parse_qs + +from oauthlib.common import add_params_to_uri + + +def instagram_compliance_fix(session): + def _non_compliant_param_name(url, headers, data): + # If the user has already specified the token in the URL + # then there's nothing to do. + # If the specified token is different from ``session.access_token``, + # we assume the user intends to override the access token. + url_query = dict(parse_qs(urlparse(url).query)) + token = url_query.get("access_token") + if token: + # Nothing to do, just return. + return url, headers, data + + token = [("access_token", session.access_token)] + url = add_params_to_uri(url, token) + return url, headers, data + + session.register_compliance_hook("protected_request", _non_compliant_param_name) + return session diff --git a/lib/requests_oauthlib/compliance_fixes/linkedin.py b/lib/requests_oauthlib/compliance_fixes/linkedin.py index e697ced9..cd5b4ace 100644 --- a/lib/requests_oauthlib/compliance_fixes/linkedin.py +++ b/lib/requests_oauthlib/compliance_fixes/linkedin.py @@ -4,21 +4,18 @@ from oauthlib.common import add_params_to_uri, to_unicode def linkedin_compliance_fix(session): - def _missing_token_type(r): token = loads(r.text) - token['token_type'] = 'Bearer' - r._content = to_unicode(dumps(token)).encode('UTF-8') + token["token_type"] = "Bearer" + r._content = to_unicode(dumps(token)).encode("UTF-8") return r def _non_compliant_param_name(url, headers, data): - token = [('oauth2_access_token', session.access_token)] + token = [("oauth2_access_token", session.access_token)] url = add_params_to_uri(url, token) return url, headers, data - session._client.default_token_placement = 'query' - session.register_compliance_hook('access_token_response', - _missing_token_type) - session.register_compliance_hook('protected_request', - _non_compliant_param_name) + session._client.default_token_placement = "query" + session.register_compliance_hook("access_token_response", _missing_token_type) + session.register_compliance_hook("protected_request", _non_compliant_param_name) return session diff --git a/lib/requests_oauthlib/compliance_fixes/mailchimp.py b/lib/requests_oauthlib/compliance_fixes/mailchimp.py index ee9bc942..c69ce9fd 100644 --- a/lib/requests_oauthlib/compliance_fixes/mailchimp.py +++ b/lib/requests_oauthlib/compliance_fixes/mailchimp.py @@ -2,21 +2,22 @@ import json from oauthlib.common import to_unicode + def mailchimp_compliance_fix(session): def _null_scope(r): token = json.loads(r.text) - if 'scope' in token and token['scope'] is None: - token.pop('scope') - r._content = to_unicode(json.dumps(token)).encode('utf-8') + if "scope" in token and token["scope"] is None: + token.pop("scope") + r._content = to_unicode(json.dumps(token)).encode("utf-8") return r def _non_zero_expiration(r): token = json.loads(r.text) - if 'expires_in' in token and token['expires_in'] == 0: - token['expires_in'] = 3600 - r._content = to_unicode(json.dumps(token)).encode('utf-8') + if "expires_in" in token and token["expires_in"] == 0: + token["expires_in"] = 3600 + r._content = to_unicode(json.dumps(token)).encode("utf-8") return r - session.register_compliance_hook('access_token_response', _null_scope) - session.register_compliance_hook('access_token_response', _non_zero_expiration) + session.register_compliance_hook("access_token_response", _null_scope) + session.register_compliance_hook("access_token_response", _non_zero_expiration) return session diff --git a/lib/requests_oauthlib/compliance_fixes/plentymarkets.py b/lib/requests_oauthlib/compliance_fixes/plentymarkets.py new file mode 100644 index 00000000..9f605f05 --- /dev/null +++ b/lib/requests_oauthlib/compliance_fixes/plentymarkets.py @@ -0,0 +1,29 @@ +from json import dumps, loads +import re + +from oauthlib.common import to_unicode + + +def plentymarkets_compliance_fix(session): + def _to_snake_case(n): + return re.sub("(.)([A-Z][a-z]+)", r"\1_\2", n).lower() + + def _compliance_fix(r): + # Plenty returns the Token in CamelCase instead of _ + if ( + "application/json" in r.headers.get("content-type", {}) + and r.status_code == 200 + ): + token = loads(r.text) + else: + return r + + fixed_token = {} + for k, v in token.items(): + fixed_token[_to_snake_case(k)] = v + + r._content = to_unicode(dumps(fixed_token)).encode("UTF-8") + return r + + session.register_compliance_hook("access_token_response", _compliance_fix) + return session diff --git a/lib/requests_oauthlib/compliance_fixes/slack.py b/lib/requests_oauthlib/compliance_fixes/slack.py index ab2d9382..3f574b03 100644 --- a/lib/requests_oauthlib/compliance_fixes/slack.py +++ b/lib/requests_oauthlib/compliance_fixes/slack.py @@ -29,9 +29,9 @@ def slack_compliance_fix(session): # ``data`` is something other than a dict: maybe a stream, # maybe a file object, maybe something else. We can't easily # modify it, so we'll set the token by modifying the URL instead. - token = [('token', session.access_token)] + token = [("token", session.access_token)] url = add_params_to_uri(url, token) return url, headers, data - session.register_compliance_hook('protected_request', _non_compliant_param_name) + session.register_compliance_hook("protected_request", _non_compliant_param_name) return session diff --git a/lib/requests_oauthlib/compliance_fixes/weibo.py b/lib/requests_oauthlib/compliance_fixes/weibo.py index 28aca327..6733abeb 100644 --- a/lib/requests_oauthlib/compliance_fixes/weibo.py +++ b/lib/requests_oauthlib/compliance_fixes/weibo.py @@ -4,14 +4,12 @@ from oauthlib.common import to_unicode def weibo_compliance_fix(session): - def _missing_token_type(r): token = loads(r.text) - token['token_type'] = 'Bearer' - r._content = to_unicode(dumps(token)).encode('UTF-8') + token["token_type"] = "Bearer" + r._content = to_unicode(dumps(token)).encode("UTF-8") return r - session._client.default_token_placement = 'query' - session.register_compliance_hook('access_token_response', - _missing_token_type) + session._client.default_token_placement = "query" + session.register_compliance_hook("access_token_response", _missing_token_type) return session diff --git a/lib/requests_oauthlib/oauth1_auth.py b/lib/requests_oauthlib/oauth1_auth.py index 46263841..cfbbd590 100644 --- a/lib/requests_oauthlib/oauth1_auth.py +++ b/lib/requests_oauthlib/oauth1_auth.py @@ -10,8 +10,8 @@ from requests.compat import is_py3 from requests.utils import to_native_string from requests.auth import AuthBase -CONTENT_TYPE_FORM_URLENCODED = 'application/x-www-form-urlencoded' -CONTENT_TYPE_MULTI_PART = 'multipart/form-data' +CONTENT_TYPE_FORM_URLENCODED = "application/x-www-form-urlencoded" +CONTENT_TYPE_MULTI_PART = "multipart/form-data" if is_py3: unicode = str @@ -26,18 +26,22 @@ class OAuth1(AuthBase): client_class = Client - def __init__(self, client_key, - client_secret=None, - resource_owner_key=None, - resource_owner_secret=None, - callback_uri=None, - signature_method=SIGNATURE_HMAC, - signature_type=SIGNATURE_TYPE_AUTH_HEADER, - rsa_key=None, verifier=None, - decoding='utf-8', - client_class=None, - force_include_body=False, - **kwargs): + def __init__( + self, + client_key, + client_secret=None, + resource_owner_key=None, + resource_owner_secret=None, + callback_uri=None, + signature_method=SIGNATURE_HMAC, + signature_type=SIGNATURE_TYPE_AUTH_HEADER, + rsa_key=None, + verifier=None, + decoding="utf-8", + client_class=None, + force_include_body=False, + **kwargs + ): try: signature_type = signature_type.upper() @@ -48,9 +52,19 @@ class OAuth1(AuthBase): self.force_include_body = force_include_body - self.client = client_class(client_key, client_secret, resource_owner_key, - resource_owner_secret, callback_uri, signature_method, - signature_type, rsa_key, verifier, decoding=decoding, **kwargs) + self.client = client_class( + client_key, + client_secret, + resource_owner_key, + resource_owner_secret, + callback_uri, + signature_method, + signature_type, + rsa_key, + verifier, + decoding=decoding, + **kwargs + ) def __call__(self, r): """Add OAuth parameters to the request. @@ -60,36 +74,44 @@ class OAuth1(AuthBase): """ # Overwriting url is safe here as request will not modify it past # this point. - log.debug('Signing request %s using client %s', r, self.client) + log.debug("Signing request %s using client %s", r, self.client) - content_type = r.headers.get('Content-Type', '') - if (not content_type and extract_params(r.body) - or self.client.signature_type == SIGNATURE_TYPE_BODY): + content_type = r.headers.get("Content-Type", "") + if ( + not content_type + and extract_params(r.body) + or self.client.signature_type == SIGNATURE_TYPE_BODY + ): content_type = CONTENT_TYPE_FORM_URLENCODED if not isinstance(content_type, unicode): - content_type = content_type.decode('utf-8') + content_type = content_type.decode("utf-8") - is_form_encoded = (CONTENT_TYPE_FORM_URLENCODED in content_type) + is_form_encoded = CONTENT_TYPE_FORM_URLENCODED in content_type - log.debug('Including body in call to sign: %s', - is_form_encoded or self.force_include_body) + log.debug( + "Including body in call to sign: %s", + is_form_encoded or self.force_include_body, + ) if is_form_encoded: - r.headers['Content-Type'] = CONTENT_TYPE_FORM_URLENCODED + r.headers["Content-Type"] = CONTENT_TYPE_FORM_URLENCODED r.url, headers, r.body = self.client.sign( - unicode(r.url), unicode(r.method), r.body or '', r.headers) + unicode(r.url), unicode(r.method), r.body or "", r.headers + ) elif self.force_include_body: # To allow custom clients to work on non form encoded bodies. r.url, headers, r.body = self.client.sign( - unicode(r.url), unicode(r.method), r.body or '', r.headers) + unicode(r.url), unicode(r.method), r.body or "", r.headers + ) else: # Omit body data in the signing of non form-encoded requests r.url, headers, _ = self.client.sign( - unicode(r.url), unicode(r.method), None, r.headers) + unicode(r.url), unicode(r.method), None, r.headers + ) r.prepare_headers(headers) r.url = to_native_string(r.url) - log.debug('Updated url: %s', r.url) - log.debug('Updated headers: %s', headers) - log.debug('Updated body: %r', r.body) + log.debug("Updated url: %s", r.url) + log.debug("Updated headers: %s", headers) + log.debug("Updated body: %r", r.body) return r diff --git a/lib/requests_oauthlib/oauth1_session.py b/lib/requests_oauthlib/oauth1_session.py index ad7b9069..aa17f28f 100644 --- a/lib/requests_oauthlib/oauth1_session.py +++ b/lib/requests_oauthlib/oauth1_session.py @@ -9,17 +9,11 @@ import logging from oauthlib.common import add_params_to_uri from oauthlib.common import urldecode as _urldecode -from oauthlib.oauth1 import ( - SIGNATURE_HMAC, SIGNATURE_RSA, SIGNATURE_TYPE_AUTH_HEADER -) +from oauthlib.oauth1 import SIGNATURE_HMAC, SIGNATURE_RSA, SIGNATURE_TYPE_AUTH_HEADER import requests from . import OAuth1 -import sys -if sys.version > "3": - unicode = str - log = logging.getLogger(__name__) @@ -28,13 +22,13 @@ def urldecode(body): """Parse query or json to python dictionary""" try: return _urldecode(body) - except: + except Exception: import json + return json.loads(body) class TokenRequestDenied(ValueError): - def __init__(self, message, response): super(TokenRequestDenied, self).__init__(message) self.response = response @@ -110,18 +104,21 @@ class OAuth1Session(requests.Session): """ - def __init__(self, client_key, - client_secret=None, - resource_owner_key=None, - resource_owner_secret=None, - callback_uri=None, - signature_method=SIGNATURE_HMAC, - signature_type=SIGNATURE_TYPE_AUTH_HEADER, - rsa_key=None, - verifier=None, - client_class=None, - force_include_body=False, - **kwargs): + def __init__( + self, + client_key, + client_secret=None, + resource_owner_key=None, + resource_owner_secret=None, + callback_uri=None, + signature_method=SIGNATURE_HMAC, + signature_type=SIGNATURE_TYPE_AUTH_HEADER, + rsa_key=None, + verifier=None, + client_class=None, + force_include_body=False, + **kwargs + ): """Construct the OAuth 1 session. :param client_key: A client specific identifier. @@ -158,20 +155,42 @@ class OAuth1Session(requests.Session): :param **kwargs: Additional keyword arguments passed to `OAuth1` """ super(OAuth1Session, self).__init__() - self._client = OAuth1(client_key, - client_secret=client_secret, - resource_owner_key=resource_owner_key, - resource_owner_secret=resource_owner_secret, - callback_uri=callback_uri, - signature_method=signature_method, - signature_type=signature_type, - rsa_key=rsa_key, - verifier=verifier, - client_class=client_class, - force_include_body=force_include_body, - **kwargs) + self._client = OAuth1( + client_key, + client_secret=client_secret, + resource_owner_key=resource_owner_key, + resource_owner_secret=resource_owner_secret, + callback_uri=callback_uri, + signature_method=signature_method, + signature_type=signature_type, + rsa_key=rsa_key, + verifier=verifier, + client_class=client_class, + force_include_body=force_include_body, + **kwargs + ) self.auth = self._client + @property + def token(self): + oauth_token = self._client.client.resource_owner_key + oauth_token_secret = self._client.client.resource_owner_secret + oauth_verifier = self._client.client.verifier + + token_dict = {} + if oauth_token: + token_dict["oauth_token"] = oauth_token + if oauth_token_secret: + token_dict["oauth_token_secret"] = oauth_token_secret + if oauth_verifier: + token_dict["oauth_verifier"] = oauth_verifier + + return token_dict + + @token.setter + def token(self, value): + self._populate_attributes(value) + @property def authorized(self): """Boolean that indicates whether this session has an OAuth token @@ -187,9 +206,9 @@ class OAuth1Session(requests.Session): else: # other methods of authentication use all three pieces return ( - bool(self._client.client.client_secret) and - bool(self._client.client.resource_owner_key) and - bool(self._client.client.resource_owner_secret) + bool(self._client.client.client_secret) + and bool(self._client.client.resource_owner_key) + and bool(self._client.client.resource_owner_secret) ) def authorization_url(self, url, request_token=None, **kwargs): @@ -234,12 +253,12 @@ class OAuth1Session(requests.Session): >>> oauth_session.authorization_url(authorization_url) 'https://api.twitter.com/oauth/authorize?oauth_token=sdf0o9823sjdfsdf&oauth_callback=https%3A%2F%2F127.0.0.1%2Fcallback' """ - kwargs['oauth_token'] = request_token or self._client.client.resource_owner_key - log.debug('Adding parameters %s to url %s', kwargs, url) + kwargs["oauth_token"] = request_token or self._client.client.resource_owner_key + log.debug("Adding parameters %s to url %s", kwargs, url) return add_params_to_uri(url, kwargs.items()) def fetch_request_token(self, url, realm=None, **request_kwargs): - """Fetch a request token. + r"""Fetch a request token. This is the first step in the OAuth 1 workflow. A request token is obtained by making a signed post request to url. The token is then @@ -264,9 +283,9 @@ class OAuth1Session(requests.Session): 'oauth_token_secret': '2kjshdfp92i34asdasd', } """ - self._client.client.realm = ' '.join(realm) if realm else None + self._client.client.realm = " ".join(realm) if realm else None token = self._fetch_token(url, **request_kwargs) - log.debug('Resetting callback_uri and realm (not needed in next phase).') + log.debug("Resetting callback_uri and realm (not needed in next phase).") self._client.client.callback_uri = None self._client.client.realm = None return token @@ -299,10 +318,10 @@ class OAuth1Session(requests.Session): """ if verifier: self._client.client.verifier = verifier - if not getattr(self._client.client, 'verifier', None): - raise VerifierMissing('No client verifier has been set.') + if not getattr(self._client.client, "verifier", None): + raise VerifierMissing("No client verifier has been set.") token = self._fetch_token(url, **request_kwargs) - log.debug('Resetting verifier attribute, should not be used anymore.') + log.debug("Resetting verifier attribute, should not be used anymore.") self._client.client.verifier = None return token @@ -322,28 +341,27 @@ class OAuth1Session(requests.Session): 'oauth_verifier: 'w34o8967345', } """ - log.debug('Parsing token from query part of url %s', url) + log.debug("Parsing token from query part of url %s", url) token = dict(urldecode(urlparse(url).query)) - log.debug('Updating internal client token attribute.') + log.debug("Updating internal client token attribute.") self._populate_attributes(token) + self.token = token return token def _populate_attributes(self, token): - if 'oauth_token' in token: - self._client.client.resource_owner_key = token['oauth_token'] + if "oauth_token" in token: + self._client.client.resource_owner_key = token["oauth_token"] else: raise TokenMissing( - 'Response does not contain a token: {resp}'.format(resp=token), - token, + "Response does not contain a token: {resp}".format(resp=token), token ) - if 'oauth_token_secret' in token: - self._client.client.resource_owner_secret = ( - token['oauth_token_secret']) - if 'oauth_verifier' in token: - self._client.client.verifier = token['oauth_verifier'] + if "oauth_token_secret" in token: + self._client.client.resource_owner_secret = token["oauth_token_secret"] + if "oauth_verifier" in token: + self._client.client.verifier = token["oauth_verifier"] def _fetch_token(self, url, **request_kwargs): - log.debug('Fetching token from %s using client %s', url, self._client.client) + log.debug("Fetching token from %s using client %s", url, self._client.client) r = self.post(url, **request_kwargs) if r.status_code >= 400: @@ -352,17 +370,21 @@ class OAuth1Session(requests.Session): log.debug('Decoding token from response "%s"', r.text) try: - token = dict(urldecode(r.text)) + token = dict(urldecode(r.text.strip())) except ValueError as e: - error = ("Unable to decode token from token response. " - "This is commonly caused by an unsuccessful request where" - " a non urlencoded error message is returned. " - "The decoding error was %s""" % e) + error = ( + "Unable to decode token from token response. " + "This is commonly caused by an unsuccessful request where" + " a non urlencoded error message is returned. " + "The decoding error was %s" + "" % e + ) raise ValueError(error) - log.debug('Obtained token %s', token) - log.debug('Updating internal client attributes from token data.') + log.debug("Obtained token %s", token) + log.debug("Updating internal client attributes from token data.") self._populate_attributes(token) + self.token = token return token def rebuild_auth(self, prepared_request, response): @@ -370,9 +392,9 @@ class OAuth1Session(requests.Session): When being redirected we should always strip Authorization header, since nonce may not be reused as per OAuth spec. """ - if 'Authorization' in prepared_request.headers: + if "Authorization" in prepared_request.headers: # If we get redirected to a new host, we should strip out # any authentication headers. - prepared_request.headers.pop('Authorization', True) + prepared_request.headers.pop("Authorization", True) prepared_request.prepare_auth(self.auth) return diff --git a/lib/requests_oauthlib/oauth2_auth.py b/lib/requests_oauthlib/oauth2_auth.py index 0ce58cc9..b880f72f 100644 --- a/lib/requests_oauthlib/oauth2_auth.py +++ b/lib/requests_oauthlib/oauth2_auth.py @@ -31,6 +31,7 @@ class OAuth2(AuthBase): """ if not is_secure_transport(r.url): raise InsecureTransportError() - r.url, r.headers, r.body = self._client.add_token(r.url, - http_method=r.method, body=r.body, headers=r.headers) + r.url, r.headers, r.body = self._client.add_token( + r.url, http_method=r.method, body=r.body, headers=r.headers + ) return r diff --git a/lib/requests_oauthlib/oauth2_session.py b/lib/requests_oauthlib/oauth2_session.py index b026a7f3..eea4ac6f 100644 --- a/lib/requests_oauthlib/oauth2_session.py +++ b/lib/requests_oauthlib/oauth2_session.py @@ -4,6 +4,7 @@ import logging from oauthlib.common import generate_token, urldecode from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError +from oauthlib.oauth2 import LegacyApplicationClient from oauthlib.oauth2 import TokenExpiredError, is_secure_transport import requests @@ -34,9 +35,19 @@ class OAuth2Session(requests.Session): you are driving a user agent able to obtain URL fragments. """ - def __init__(self, client_id=None, client=None, auto_refresh_url=None, - auto_refresh_kwargs=None, scope=None, redirect_uri=None, token=None, - state=None, token_updater=None, **kwargs): + def __init__( + self, + client_id=None, + client=None, + auto_refresh_url=None, + auto_refresh_kwargs=None, + scope=None, + redirect_uri=None, + token=None, + state=None, + token_updater=None, + **kwargs + ): """Construct a new OAuth 2 client session. :param client_id: Client id obtained during registration @@ -57,7 +68,7 @@ class OAuth2Session(requests.Session): :auto_refresh_kwargs: Extra arguments to pass to the refresh token endpoint. :token_updater: Method with one argument, token, to be used to update - your token databse on automatic token refresh. If not + your token database on automatic token refresh. If not set a TokenUpdated warning will be raised when a token has been refreshed. This warning will carry the token in its token argument. @@ -74,22 +85,26 @@ class OAuth2Session(requests.Session): self.auto_refresh_kwargs = auto_refresh_kwargs or {} self.token_updater = token_updater + # Ensure that requests doesn't do any automatic auth. See #278. + # The default behavior can be re-enabled by setting auth to None. + self.auth = lambda r: r + # Allow customizations for non compliant providers through various # hooks to adjust requests and responses. self.compliance_hook = { - 'access_token_response': set([]), - 'refresh_token_response': set([]), - 'protected_request': set([]), + "access_token_response": set(), + "refresh_token_response": set(), + "protected_request": set(), } def new_state(self): """Generates a state string to be used in authorizations.""" try: self._state = self.state() - log.debug('Generated new state %s.', self._state) + log.debug("Generated new state %s.", self._state) except TypeError: self._state = self.state - log.debug('Re-using previously supplied state %s.', self._state) + log.debug("Re-using previously supplied state %s.", self._state) return self._state @property @@ -111,7 +126,7 @@ class OAuth2Session(requests.Session): @token.setter def token(self, value): self._client.token = value - self._client._populate_attributes(value) + self._client.populate_token_attributes(value) @property def access_token(self): @@ -146,19 +161,42 @@ class OAuth2Session(requests.Session): :return: authorization_url, state """ state = state or self.new_state() - return self._client.prepare_request_uri(url, + return ( + self._client.prepare_request_uri( + url, redirect_uri=self.redirect_uri, scope=self.scope, state=state, - **kwargs), state + **kwargs + ), + state, + ) - def fetch_token(self, token_url, code=None, authorization_response=None, - body='', auth=None, username=None, password=None, method='POST', - timeout=None, headers=None, verify=True, **kwargs): + def fetch_token( + self, + token_url, + code=None, + authorization_response=None, + body="", + auth=None, + username=None, + password=None, + method="POST", + force_querystring=False, + timeout=None, + headers=None, + verify=True, + proxies=None, + include_client_id=None, + client_secret=None, + **kwargs + ): """Generic method for fetching an access token from the token endpoint. If you are using the MobileApplicationClient you will want to use - token_from_fragment instead of fetch_token. + `token_from_fragment` instead of `fetch_token`. + + The current implementation enforces the RFC guidelines. :param token_url: Token endpoint URL, must use HTTPS. :param code: Authorization code (used by WebApplicationClients). @@ -167,15 +205,30 @@ class OAuth2Session(requests.Session): WebApplicationClients instead of code. :param body: Optional application/x-www-form-urlencoded body to add the include in the token request. Prefer kwargs over body. - :param auth: An auth tuple or method as accepted by requests. - :param username: Username used by LegacyApplicationClients. - :param password: Password used by LegacyApplicationClients. + :param auth: An auth tuple or method as accepted by `requests`. + :param username: Username required by LegacyApplicationClients to appear + in the request body. + :param password: Password required by LegacyApplicationClients to appear + in the request body. :param method: The HTTP method used to make the request. Defaults to POST, but may also be GET. Other methods should be added as needed. - :param headers: Dict to default request headers with. + :param force_querystring: If True, force the request body to be sent + in the querystring instead. :param timeout: Timeout of the request in seconds. + :param headers: Dict to default request headers with. :param verify: Verify SSL certificate. + :param proxies: The `proxies` argument is passed onto `requests`. + :param include_client_id: Should the request body include the + `client_id` parameter. Default is `None`, + which will attempt to autodetect. This can be + forced to always include (True) or never + include (False). + :param client_secret: The `client_secret` paired to the `client_id`. + This is generally required unless provided in the + `auth` tuple. If the value is `None`, it will be + omitted from the request, however if the value is + an empty string, an empty string will be sent. :param kwargs: Extra parameters to include in the token request. :return: A token dict """ @@ -183,59 +236,130 @@ class OAuth2Session(requests.Session): raise InsecureTransportError() if not code and authorization_response: - self._client.parse_request_uri_response(authorization_response, - state=self._state) + self._client.parse_request_uri_response( + authorization_response, state=self._state + ) code = self._client.code elif not code and isinstance(self._client, WebApplicationClient): code = self._client.code if not code: - raise ValueError('Please supply either code or ' - 'authorization_code parameters.') + raise ValueError( + "Please supply either code or " "authorization_response parameters." + ) + # Earlier versions of this library build an HTTPBasicAuth header out of + # `username` and `password`. The RFC states, however these attributes + # must be in the request body and not the header. + # If an upstream server is not spec compliant and requires them to + # appear as an Authorization header, supply an explicit `auth` header + # to this function. + # This check will allow for empty strings, but not `None`. + # + # References + # 4.3.2 - Resource Owner Password Credentials Grant + # https://tools.ietf.org/html/rfc6749#section-4.3.2 - body = self._client.prepare_request_body(code=code, body=body, - redirect_uri=self.redirect_uri, username=username, - password=password, **kwargs) - - if (not auth) and username: + if isinstance(self._client, LegacyApplicationClient): + if username is None: + raise ValueError( + "`LegacyApplicationClient` requires both the " + "`username` and `password` parameters." + ) if password is None: - raise ValueError('Username was supplied, but not password.') - auth = requests.auth.HTTPBasicAuth(username, password) + raise ValueError( + "The required parameter `username` was supplied, " + "but `password` was not." + ) + + # merge username and password into kwargs for `prepare_request_body` + if username is not None: + kwargs["username"] = username + if password is not None: + kwargs["password"] = password + + # is an auth explicitly supplied? + if auth is not None: + # if we're dealing with the default of `include_client_id` (None): + # we will assume the `auth` argument is for an RFC compliant server + # and we should not send the `client_id` in the body. + # This approach allows us to still force the client_id by submitting + # `include_client_id=True` along with an `auth` object. + if include_client_id is None: + include_client_id = False + + # otherwise we may need to create an auth header + else: + # since we don't have an auth header, we MAY need to create one + # it is possible that we want to send the `client_id` in the body + # if so, `include_client_id` should be set to True + # otherwise, we will generate an auth header + if include_client_id is not True: + client_id = self.client_id + if client_id: + log.debug( + 'Encoding `client_id` "%s" with `client_secret` ' + "as Basic auth credentials.", + client_id, + ) + client_secret = client_secret if client_secret is not None else "" + auth = requests.auth.HTTPBasicAuth(client_id, client_secret) + + if include_client_id: + # this was pulled out of the params + # it needs to be passed into prepare_request_body + if client_secret is not None: + kwargs["client_secret"] = client_secret + + body = self._client.prepare_request_body( + code=code, + body=body, + redirect_uri=self.redirect_uri, + include_client_id=include_client_id, + **kwargs + ) headers = headers or { - 'Accept': 'application/json', - 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8", } self.token = {} - if method.upper() == 'POST': - r = self.post(token_url, data=dict(urldecode(body)), - timeout=timeout, headers=headers, auth=auth, - verify=verify) - log.debug('Prepared fetch token request body %s', body) - elif method.upper() == 'GET': - # if method is not 'POST', switch body to querystring and GET - r = self.get(token_url, params=dict(urldecode(body)), - timeout=timeout, headers=headers, auth=auth, - verify=verify) - log.debug('Prepared fetch token request querystring %s', body) + request_kwargs = {} + if method.upper() == "POST": + request_kwargs["params" if force_querystring else "data"] = dict( + urldecode(body) + ) + elif method.upper() == "GET": + request_kwargs["params"] = dict(urldecode(body)) else: - raise ValueError('The method kwarg must be POST or GET.') + raise ValueError("The method kwarg must be POST or GET.") - log.debug('Request to fetch token completed with status %s.', - r.status_code) - log.debug('Request headers were %s', r.request.headers) - log.debug('Request body was %s', r.request.body) - log.debug('Response headers were %s and content %s.', - r.headers, r.text) - log.debug('Invoking %d token response hooks.', - len(self.compliance_hook['access_token_response'])) - for hook in self.compliance_hook['access_token_response']: - log.debug('Invoking hook %s.', hook) + r = self.request( + method=method, + url=token_url, + timeout=timeout, + headers=headers, + auth=auth, + verify=verify, + proxies=proxies, + **request_kwargs + ) + + log.debug("Request to fetch token completed with status %s.", r.status_code) + log.debug("Request url was %s", r.request.url) + log.debug("Request headers were %s", r.request.headers) + log.debug("Request body was %s", r.request.body) + log.debug("Response headers were %s and content %s.", r.headers, r.text) + log.debug( + "Invoking %d token response hooks.", + len(self.compliance_hook["access_token_response"]), + ) + for hook in self.compliance_hook["access_token_response"]: + log.debug("Invoking hook %s.", hook) r = hook(r) self._client.parse_request_body_response(r.text, scope=self.scope) self.token = self._client.token - log.debug('Obtained token %s.', self.token) + log.debug("Obtained token %s.", self.token) return self.token def token_from_fragment(self, authorization_response): @@ -244,103 +368,153 @@ class OAuth2Session(requests.Session): :param authorization_response: The full URL of the redirect back to you :return: A token dict """ - self._client.parse_request_uri_response(authorization_response, - state=self._state) + self._client.parse_request_uri_response( + authorization_response, state=self._state + ) self.token = self._client.token return self.token - def refresh_token(self, token_url, refresh_token=None, body='', auth=None, - timeout=None, headers=None, verify=True, **kwargs): + def refresh_token( + self, + token_url, + refresh_token=None, + body="", + auth=None, + timeout=None, + headers=None, + verify=True, + proxies=None, + **kwargs + ): """Fetch a new access token using a refresh token. :param token_url: The token endpoint, must be HTTPS. :param refresh_token: The refresh_token to use. :param body: Optional application/x-www-form-urlencoded body to add the include in the token request. Prefer kwargs over body. - :param auth: An auth tuple or method as accepted by requests. + :param auth: An auth tuple or method as accepted by `requests`. :param timeout: Timeout of the request in seconds. + :param headers: A dict of headers to be used by `requests`. :param verify: Verify SSL certificate. + :param proxies: The `proxies` argument will be passed to `requests`. :param kwargs: Extra parameters to include in the token request. :return: A token dict """ if not token_url: - raise ValueError('No token endpoint set for auto_refresh.') + raise ValueError("No token endpoint set for auto_refresh.") if not is_secure_transport(token_url): raise InsecureTransportError() - refresh_token = refresh_token or self.token.get('refresh_token') + refresh_token = refresh_token or self.token.get("refresh_token") - log.debug('Adding auto refresh key word arguments %s.', - self.auto_refresh_kwargs) + log.debug( + "Adding auto refresh key word arguments %s.", self.auto_refresh_kwargs + ) kwargs.update(self.auto_refresh_kwargs) - body = self._client.prepare_refresh_body(body=body, - refresh_token=refresh_token, scope=self.scope, **kwargs) - log.debug('Prepared refresh token request body %s', body) + body = self._client.prepare_refresh_body( + body=body, refresh_token=refresh_token, scope=self.scope, **kwargs + ) + log.debug("Prepared refresh token request body %s", body) if headers is None: headers = { - 'Accept': 'application/json', - 'Content-Type': ( - 'application/x-www-form-urlencoded;charset=UTF-8' - ), + "Accept": "application/json", + "Content-Type": ("application/x-www-form-urlencoded;charset=UTF-8"), } - r = self.post(token_url, data=dict(urldecode(body)), auth=auth, - timeout=timeout, headers=headers, verify=verify, withhold_token=True) - log.debug('Request to refresh token completed with status %s.', - r.status_code) - log.debug('Response headers were %s and content %s.', - r.headers, r.text) - log.debug('Invoking %d token response hooks.', - len(self.compliance_hook['refresh_token_response'])) - for hook in self.compliance_hook['refresh_token_response']: - log.debug('Invoking hook %s.', hook) + r = self.post( + token_url, + data=dict(urldecode(body)), + auth=auth, + timeout=timeout, + headers=headers, + verify=verify, + withhold_token=True, + proxies=proxies, + ) + log.debug("Request to refresh token completed with status %s.", r.status_code) + log.debug("Response headers were %s and content %s.", r.headers, r.text) + log.debug( + "Invoking %d token response hooks.", + len(self.compliance_hook["refresh_token_response"]), + ) + for hook in self.compliance_hook["refresh_token_response"]: + log.debug("Invoking hook %s.", hook) r = hook(r) self.token = self._client.parse_request_body_response(r.text, scope=self.scope) - if not 'refresh_token' in self.token: - log.debug('No new refresh token given. Re-using old.') - self.token['refresh_token'] = refresh_token + if not "refresh_token" in self.token: + log.debug("No new refresh token given. Re-using old.") + self.token["refresh_token"] = refresh_token return self.token - def request(self, method, url, data=None, headers=None, withhold_token=False, **kwargs): + def request( + self, + method, + url, + data=None, + headers=None, + withhold_token=False, + client_id=None, + client_secret=None, + **kwargs + ): """Intercept all requests and add the OAuth 2 token if present.""" if not is_secure_transport(url): raise InsecureTransportError() if self.token and not withhold_token: - log.debug('Invoking %d protected resource request hooks.', - len(self.compliance_hook['protected_request'])) - for hook in self.compliance_hook['protected_request']: - log.debug('Invoking hook %s.', hook) + log.debug( + "Invoking %d protected resource request hooks.", + len(self.compliance_hook["protected_request"]), + ) + for hook in self.compliance_hook["protected_request"]: + log.debug("Invoking hook %s.", hook) url, headers, data = hook(url, headers, data) - log.debug('Adding token %s to request.', self.token) + log.debug("Adding token %s to request.", self.token) try: - url, headers, data = self._client.add_token(url, - http_method=method, body=data, headers=headers) + url, headers, data = self._client.add_token( + url, http_method=method, body=data, headers=headers + ) # Attempt to retrieve and save new access token if expired except TokenExpiredError: if self.auto_refresh_url: - log.debug('Auto refresh is set, attempting to refresh at %s.', - self.auto_refresh_url) - token = self.refresh_token(self.auto_refresh_url, **kwargs) + log.debug( + "Auto refresh is set, attempting to refresh at %s.", + self.auto_refresh_url, + ) + + # We mustn't pass auth twice. + auth = kwargs.pop("auth", None) + if client_id and client_secret and (auth is None): + log.debug( + 'Encoding client_id "%s" with client_secret as Basic auth credentials.', + client_id, + ) + auth = requests.auth.HTTPBasicAuth(client_id, client_secret) + token = self.refresh_token( + self.auto_refresh_url, auth=auth, **kwargs + ) if self.token_updater: - log.debug('Updating token to %s using %s.', - token, self.token_updater) + log.debug( + "Updating token to %s using %s.", token, self.token_updater + ) self.token_updater(token) - url, headers, data = self._client.add_token(url, - http_method=method, body=data, headers=headers) + url, headers, data = self._client.add_token( + url, http_method=method, body=data, headers=headers + ) else: raise TokenUpdated(token) else: raise - log.debug('Requesting url %s using method %s.', url, method) - log.debug('Supplying headers %s and data %s', headers, data) - log.debug('Passing through key word arguments %s.', kwargs) - return super(OAuth2Session, self).request(method, url, - headers=headers, data=data, **kwargs) + log.debug("Requesting url %s using method %s.", url, method) + log.debug("Supplying headers %s and data %s", headers, data) + log.debug("Passing through key word arguments %s.", kwargs) + return super(OAuth2Session, self).request( + method, url, headers=headers, data=data, **kwargs + ) def register_compliance_hook(self, hook_type, hook): """Register a hook for request/response tweaking. @@ -354,6 +528,7 @@ class OAuth2Session(requests.Session): or open an issue. """ if hook_type not in self.compliance_hook: - raise ValueError('Hook type %s is not in %s.', - hook_type, self.compliance_hook) + raise ValueError( + "Hook type %s is not in %s.", hook_type, self.compliance_hook + ) self.compliance_hook[hook_type].add(hook)