Update requests-oauthlib-1.3.0

This commit is contained in:
JonnyWong16 2021-10-14 23:47:27 -07:00
parent e55576fd80
commit f165d2d080
No known key found for this signature in database
GPG key ID: B1F1F9807184697A
15 changed files with 552 additions and 257 deletions

View file

@ -1,22 +1,19 @@
import logging
from .oauth1_auth import OAuth1 from .oauth1_auth import OAuth1
from .oauth1_session import OAuth1Session from .oauth1_session import OAuth1Session
from .oauth2_auth import OAuth2 from .oauth2_auth import OAuth2
from .oauth2_session import OAuth2Session, TokenUpdated from .oauth2_session import OAuth2Session, TokenUpdated
__version__ = '0.6.1' __version__ = "1.3.0"
import requests import requests
if requests.__version__ < '2.0.0':
msg = ('You are using requests version %s, which is older than ' if requests.__version__ < "2.0.0":
'requests-oauthlib expects, please upgrade to 2.0.0 or later.') msg = (
"You are using requests version %s, which is older than "
"requests-oauthlib expects, please upgrade to 2.0.0 or later."
)
raise Warning(msg % requests.__version__) raise Warning(msg % requests.__version__)
import logging logging.getLogger("requests_oauthlib").addHandler(logging.NullHandler())
try: # Python 2.7+
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
logging.getLogger('requests_oauthlib').addHandler(NullHandler())

View file

@ -1,7 +1,10 @@
from __future__ import absolute_import from __future__ import absolute_import
from .facebook import facebook_compliance_fix from .facebook import facebook_compliance_fix
from .fitbit import fitbit_compliance_fix
from .linkedin import linkedin_compliance_fix from .linkedin import linkedin_compliance_fix
from .slack import slack_compliance_fix from .slack import slack_compliance_fix
from .instagram import instagram_compliance_fix
from .mailchimp import mailchimp_compliance_fix from .mailchimp import mailchimp_compliance_fix
from .weibo import weibo_compliance_fix from .weibo import weibo_compliance_fix
from .plentymarkets import plentymarkets_compliance_fix

View file

@ -4,15 +4,14 @@ from oauthlib.common import to_unicode
def douban_compliance_fix(session): def douban_compliance_fix(session):
def fix_token_type(r): def fix_token_type(r):
token = json.loads(r.text) token = json.loads(r.text)
token.setdefault('token_type', 'Bearer') token.setdefault("token_type", "Bearer")
fixed_token = json.dumps(token) fixed_token = json.dumps(token)
r._content = to_unicode(fixed_token).encode('utf-8') r._content = to_unicode(fixed_token).encode("utf-8")
return r return r
session._client_default_token_placement = 'query' session._client_default_token_placement = "query"
session.register_compliance_hook('access_token_response', fix_token_type) session.register_compliance_hook("access_token_response", fix_token_type)
return session return session

View file

@ -1,4 +1,5 @@
from json import dumps from json import dumps
try: try:
from urlparse import parse_qsl from urlparse import parse_qsl
except ImportError: except ImportError:
@ -8,26 +9,25 @@ from oauthlib.common import to_unicode
def facebook_compliance_fix(session): def facebook_compliance_fix(session):
def _compliance_fix(r): def _compliance_fix(r):
# if Facebook claims to be sending us json, let's trust them. # if Facebook claims to be sending us json, let's trust them.
if 'application/json' in r.headers.get('content-type', {}): if "application/json" in r.headers.get("content-type", {}):
return r return r
# Facebook returns a content-type of text/plain when sending their # Facebook returns a content-type of text/plain when sending their
# x-www-form-urlencoded responses, along with a 200. If not, let's # x-www-form-urlencoded responses, along with a 200. If not, let's
# assume we're getting JSON and bail on the fix. # assume we're getting JSON and bail on the fix.
if 'text/plain' in r.headers.get('content-type', {}) and r.status_code == 200: if "text/plain" in r.headers.get("content-type", {}) and r.status_code == 200:
token = dict(parse_qsl(r.text, keep_blank_values=True)) token = dict(parse_qsl(r.text, keep_blank_values=True))
else: else:
return r return r
expires = token.get('expires') expires = token.get("expires")
if expires is not None: if expires is not None:
token['expires_in'] = expires token["expires_in"] = expires
token['token_type'] = 'Bearer' token["token_type"] = "Bearer"
r._content = to_unicode(dumps(token)).encode('UTF-8') r._content = to_unicode(dumps(token)).encode("UTF-8")
return r return r
session.register_compliance_hook('access_token_response', _compliance_fix) session.register_compliance_hook("access_token_response", _compliance_fix)
return session return session

View file

@ -0,0 +1,25 @@
"""
The Fitbit API breaks from the OAuth2 RFC standard by returning an "errors"
object list, rather than a single "error" string. This puts hooks in place so
that oauthlib can process an error in the results from access token and refresh
token responses. This is necessary to prevent getting the generic red herring
MissingTokenError.
"""
from json import loads, dumps
from oauthlib.common import to_unicode
def fitbit_compliance_fix(session):
def _missing_error(r):
token = loads(r.text)
if "errors" in token:
# Set the error to the first one we have
token["error"] = token["errors"][0]["errorType"]
r._content = to_unicode(dumps(token)).encode("UTF-8")
return r
session.register_compliance_hook("access_token_response", _missing_error)
session.register_compliance_hook("refresh_token_response", _missing_error)
return session

View file

@ -0,0 +1,26 @@
try:
from urlparse import urlparse, parse_qs
except ImportError:
from urllib.parse import urlparse, parse_qs
from oauthlib.common import add_params_to_uri
def instagram_compliance_fix(session):
def _non_compliant_param_name(url, headers, data):
# If the user has already specified the token in the URL
# then there's nothing to do.
# If the specified token is different from ``session.access_token``,
# we assume the user intends to override the access token.
url_query = dict(parse_qs(urlparse(url).query))
token = url_query.get("access_token")
if token:
# Nothing to do, just return.
return url, headers, data
token = [("access_token", session.access_token)]
url = add_params_to_uri(url, token)
return url, headers, data
session.register_compliance_hook("protected_request", _non_compliant_param_name)
return session

View file

@ -4,21 +4,18 @@ from oauthlib.common import add_params_to_uri, to_unicode
def linkedin_compliance_fix(session): def linkedin_compliance_fix(session):
def _missing_token_type(r): def _missing_token_type(r):
token = loads(r.text) token = loads(r.text)
token['token_type'] = 'Bearer' token["token_type"] = "Bearer"
r._content = to_unicode(dumps(token)).encode('UTF-8') r._content = to_unicode(dumps(token)).encode("UTF-8")
return r return r
def _non_compliant_param_name(url, headers, data): def _non_compliant_param_name(url, headers, data):
token = [('oauth2_access_token', session.access_token)] token = [("oauth2_access_token", session.access_token)]
url = add_params_to_uri(url, token) url = add_params_to_uri(url, token)
return url, headers, data return url, headers, data
session._client.default_token_placement = 'query' session._client.default_token_placement = "query"
session.register_compliance_hook('access_token_response', session.register_compliance_hook("access_token_response", _missing_token_type)
_missing_token_type) session.register_compliance_hook("protected_request", _non_compliant_param_name)
session.register_compliance_hook('protected_request',
_non_compliant_param_name)
return session return session

View file

@ -2,21 +2,22 @@ import json
from oauthlib.common import to_unicode from oauthlib.common import to_unicode
def mailchimp_compliance_fix(session): def mailchimp_compliance_fix(session):
def _null_scope(r): def _null_scope(r):
token = json.loads(r.text) token = json.loads(r.text)
if 'scope' in token and token['scope'] is None: if "scope" in token and token["scope"] is None:
token.pop('scope') token.pop("scope")
r._content = to_unicode(json.dumps(token)).encode('utf-8') r._content = to_unicode(json.dumps(token)).encode("utf-8")
return r return r
def _non_zero_expiration(r): def _non_zero_expiration(r):
token = json.loads(r.text) token = json.loads(r.text)
if 'expires_in' in token and token['expires_in'] == 0: if "expires_in" in token and token["expires_in"] == 0:
token['expires_in'] = 3600 token["expires_in"] = 3600
r._content = to_unicode(json.dumps(token)).encode('utf-8') r._content = to_unicode(json.dumps(token)).encode("utf-8")
return r return r
session.register_compliance_hook('access_token_response', _null_scope) session.register_compliance_hook("access_token_response", _null_scope)
session.register_compliance_hook('access_token_response', _non_zero_expiration) session.register_compliance_hook("access_token_response", _non_zero_expiration)
return session return session

View file

@ -0,0 +1,29 @@
from json import dumps, loads
import re
from oauthlib.common import to_unicode
def plentymarkets_compliance_fix(session):
def _to_snake_case(n):
return re.sub("(.)([A-Z][a-z]+)", r"\1_\2", n).lower()
def _compliance_fix(r):
# Plenty returns the Token in CamelCase instead of _
if (
"application/json" in r.headers.get("content-type", {})
and r.status_code == 200
):
token = loads(r.text)
else:
return r
fixed_token = {}
for k, v in token.items():
fixed_token[_to_snake_case(k)] = v
r._content = to_unicode(dumps(fixed_token)).encode("UTF-8")
return r
session.register_compliance_hook("access_token_response", _compliance_fix)
return session

View file

@ -29,9 +29,9 @@ def slack_compliance_fix(session):
# ``data`` is something other than a dict: maybe a stream, # ``data`` is something other than a dict: maybe a stream,
# maybe a file object, maybe something else. We can't easily # maybe a file object, maybe something else. We can't easily
# modify it, so we'll set the token by modifying the URL instead. # modify it, so we'll set the token by modifying the URL instead.
token = [('token', session.access_token)] token = [("token", session.access_token)]
url = add_params_to_uri(url, token) url = add_params_to_uri(url, token)
return url, headers, data return url, headers, data
session.register_compliance_hook('protected_request', _non_compliant_param_name) session.register_compliance_hook("protected_request", _non_compliant_param_name)
return session return session

View file

@ -4,14 +4,12 @@ from oauthlib.common import to_unicode
def weibo_compliance_fix(session): def weibo_compliance_fix(session):
def _missing_token_type(r): def _missing_token_type(r):
token = loads(r.text) token = loads(r.text)
token['token_type'] = 'Bearer' token["token_type"] = "Bearer"
r._content = to_unicode(dumps(token)).encode('UTF-8') r._content = to_unicode(dumps(token)).encode("UTF-8")
return r return r
session._client.default_token_placement = 'query' session._client.default_token_placement = "query"
session.register_compliance_hook('access_token_response', session.register_compliance_hook("access_token_response", _missing_token_type)
_missing_token_type)
return session return session

View file

@ -10,8 +10,8 @@ from requests.compat import is_py3
from requests.utils import to_native_string from requests.utils import to_native_string
from requests.auth import AuthBase from requests.auth import AuthBase
CONTENT_TYPE_FORM_URLENCODED = 'application/x-www-form-urlencoded' CONTENT_TYPE_FORM_URLENCODED = "application/x-www-form-urlencoded"
CONTENT_TYPE_MULTI_PART = 'multipart/form-data' CONTENT_TYPE_MULTI_PART = "multipart/form-data"
if is_py3: if is_py3:
unicode = str unicode = str
@ -26,18 +26,22 @@ class OAuth1(AuthBase):
client_class = Client client_class = Client
def __init__(self, client_key, def __init__(
client_secret=None, self,
resource_owner_key=None, client_key,
resource_owner_secret=None, client_secret=None,
callback_uri=None, resource_owner_key=None,
signature_method=SIGNATURE_HMAC, resource_owner_secret=None,
signature_type=SIGNATURE_TYPE_AUTH_HEADER, callback_uri=None,
rsa_key=None, verifier=None, signature_method=SIGNATURE_HMAC,
decoding='utf-8', signature_type=SIGNATURE_TYPE_AUTH_HEADER,
client_class=None, rsa_key=None,
force_include_body=False, verifier=None,
**kwargs): decoding="utf-8",
client_class=None,
force_include_body=False,
**kwargs
):
try: try:
signature_type = signature_type.upper() signature_type = signature_type.upper()
@ -48,9 +52,19 @@ class OAuth1(AuthBase):
self.force_include_body = force_include_body self.force_include_body = force_include_body
self.client = client_class(client_key, client_secret, resource_owner_key, self.client = client_class(
resource_owner_secret, callback_uri, signature_method, client_key,
signature_type, rsa_key, verifier, decoding=decoding, **kwargs) client_secret,
resource_owner_key,
resource_owner_secret,
callback_uri,
signature_method,
signature_type,
rsa_key,
verifier,
decoding=decoding,
**kwargs
)
def __call__(self, r): def __call__(self, r):
"""Add OAuth parameters to the request. """Add OAuth parameters to the request.
@ -60,36 +74,44 @@ class OAuth1(AuthBase):
""" """
# Overwriting url is safe here as request will not modify it past # Overwriting url is safe here as request will not modify it past
# this point. # this point.
log.debug('Signing request %s using client %s', r, self.client) log.debug("Signing request %s using client %s", r, self.client)
content_type = r.headers.get('Content-Type', '') content_type = r.headers.get("Content-Type", "")
if (not content_type and extract_params(r.body) if (
or self.client.signature_type == SIGNATURE_TYPE_BODY): not content_type
and extract_params(r.body)
or self.client.signature_type == SIGNATURE_TYPE_BODY
):
content_type = CONTENT_TYPE_FORM_URLENCODED content_type = CONTENT_TYPE_FORM_URLENCODED
if not isinstance(content_type, unicode): if not isinstance(content_type, unicode):
content_type = content_type.decode('utf-8') content_type = content_type.decode("utf-8")
is_form_encoded = (CONTENT_TYPE_FORM_URLENCODED in content_type) is_form_encoded = CONTENT_TYPE_FORM_URLENCODED in content_type
log.debug('Including body in call to sign: %s', log.debug(
is_form_encoded or self.force_include_body) "Including body in call to sign: %s",
is_form_encoded or self.force_include_body,
)
if is_form_encoded: if is_form_encoded:
r.headers['Content-Type'] = CONTENT_TYPE_FORM_URLENCODED r.headers["Content-Type"] = CONTENT_TYPE_FORM_URLENCODED
r.url, headers, r.body = self.client.sign( r.url, headers, r.body = self.client.sign(
unicode(r.url), unicode(r.method), r.body or '', r.headers) unicode(r.url), unicode(r.method), r.body or "", r.headers
)
elif self.force_include_body: elif self.force_include_body:
# To allow custom clients to work on non form encoded bodies. # To allow custom clients to work on non form encoded bodies.
r.url, headers, r.body = self.client.sign( r.url, headers, r.body = self.client.sign(
unicode(r.url), unicode(r.method), r.body or '', r.headers) unicode(r.url), unicode(r.method), r.body or "", r.headers
)
else: else:
# Omit body data in the signing of non form-encoded requests # Omit body data in the signing of non form-encoded requests
r.url, headers, _ = self.client.sign( r.url, headers, _ = self.client.sign(
unicode(r.url), unicode(r.method), None, r.headers) unicode(r.url), unicode(r.method), None, r.headers
)
r.prepare_headers(headers) r.prepare_headers(headers)
r.url = to_native_string(r.url) r.url = to_native_string(r.url)
log.debug('Updated url: %s', r.url) log.debug("Updated url: %s", r.url)
log.debug('Updated headers: %s', headers) log.debug("Updated headers: %s", headers)
log.debug('Updated body: %r', r.body) log.debug("Updated body: %r", r.body)
return r return r

View file

@ -9,17 +9,11 @@ import logging
from oauthlib.common import add_params_to_uri from oauthlib.common import add_params_to_uri
from oauthlib.common import urldecode as _urldecode from oauthlib.common import urldecode as _urldecode
from oauthlib.oauth1 import ( from oauthlib.oauth1 import SIGNATURE_HMAC, SIGNATURE_RSA, SIGNATURE_TYPE_AUTH_HEADER
SIGNATURE_HMAC, SIGNATURE_RSA, SIGNATURE_TYPE_AUTH_HEADER
)
import requests import requests
from . import OAuth1 from . import OAuth1
import sys
if sys.version > "3":
unicode = str
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -28,13 +22,13 @@ def urldecode(body):
"""Parse query or json to python dictionary""" """Parse query or json to python dictionary"""
try: try:
return _urldecode(body) return _urldecode(body)
except: except Exception:
import json import json
return json.loads(body) return json.loads(body)
class TokenRequestDenied(ValueError): class TokenRequestDenied(ValueError):
def __init__(self, message, response): def __init__(self, message, response):
super(TokenRequestDenied, self).__init__(message) super(TokenRequestDenied, self).__init__(message)
self.response = response self.response = response
@ -110,18 +104,21 @@ class OAuth1Session(requests.Session):
<Response [200]> <Response [200]>
""" """
def __init__(self, client_key, def __init__(
client_secret=None, self,
resource_owner_key=None, client_key,
resource_owner_secret=None, client_secret=None,
callback_uri=None, resource_owner_key=None,
signature_method=SIGNATURE_HMAC, resource_owner_secret=None,
signature_type=SIGNATURE_TYPE_AUTH_HEADER, callback_uri=None,
rsa_key=None, signature_method=SIGNATURE_HMAC,
verifier=None, signature_type=SIGNATURE_TYPE_AUTH_HEADER,
client_class=None, rsa_key=None,
force_include_body=False, verifier=None,
**kwargs): client_class=None,
force_include_body=False,
**kwargs
):
"""Construct the OAuth 1 session. """Construct the OAuth 1 session.
:param client_key: A client specific identifier. :param client_key: A client specific identifier.
@ -158,20 +155,42 @@ class OAuth1Session(requests.Session):
:param **kwargs: Additional keyword arguments passed to `OAuth1` :param **kwargs: Additional keyword arguments passed to `OAuth1`
""" """
super(OAuth1Session, self).__init__() super(OAuth1Session, self).__init__()
self._client = OAuth1(client_key, self._client = OAuth1(
client_secret=client_secret, client_key,
resource_owner_key=resource_owner_key, client_secret=client_secret,
resource_owner_secret=resource_owner_secret, resource_owner_key=resource_owner_key,
callback_uri=callback_uri, resource_owner_secret=resource_owner_secret,
signature_method=signature_method, callback_uri=callback_uri,
signature_type=signature_type, signature_method=signature_method,
rsa_key=rsa_key, signature_type=signature_type,
verifier=verifier, rsa_key=rsa_key,
client_class=client_class, verifier=verifier,
force_include_body=force_include_body, client_class=client_class,
**kwargs) force_include_body=force_include_body,
**kwargs
)
self.auth = self._client self.auth = self._client
@property
def token(self):
oauth_token = self._client.client.resource_owner_key
oauth_token_secret = self._client.client.resource_owner_secret
oauth_verifier = self._client.client.verifier
token_dict = {}
if oauth_token:
token_dict["oauth_token"] = oauth_token
if oauth_token_secret:
token_dict["oauth_token_secret"] = oauth_token_secret
if oauth_verifier:
token_dict["oauth_verifier"] = oauth_verifier
return token_dict
@token.setter
def token(self, value):
self._populate_attributes(value)
@property @property
def authorized(self): def authorized(self):
"""Boolean that indicates whether this session has an OAuth token """Boolean that indicates whether this session has an OAuth token
@ -187,9 +206,9 @@ class OAuth1Session(requests.Session):
else: else:
# other methods of authentication use all three pieces # other methods of authentication use all three pieces
return ( return (
bool(self._client.client.client_secret) and bool(self._client.client.client_secret)
bool(self._client.client.resource_owner_key) and and bool(self._client.client.resource_owner_key)
bool(self._client.client.resource_owner_secret) and bool(self._client.client.resource_owner_secret)
) )
def authorization_url(self, url, request_token=None, **kwargs): def authorization_url(self, url, request_token=None, **kwargs):
@ -234,12 +253,12 @@ class OAuth1Session(requests.Session):
>>> oauth_session.authorization_url(authorization_url) >>> oauth_session.authorization_url(authorization_url)
'https://api.twitter.com/oauth/authorize?oauth_token=sdf0o9823sjdfsdf&oauth_callback=https%3A%2F%2F127.0.0.1%2Fcallback' 'https://api.twitter.com/oauth/authorize?oauth_token=sdf0o9823sjdfsdf&oauth_callback=https%3A%2F%2F127.0.0.1%2Fcallback'
""" """
kwargs['oauth_token'] = request_token or self._client.client.resource_owner_key kwargs["oauth_token"] = request_token or self._client.client.resource_owner_key
log.debug('Adding parameters %s to url %s', kwargs, url) log.debug("Adding parameters %s to url %s", kwargs, url)
return add_params_to_uri(url, kwargs.items()) return add_params_to_uri(url, kwargs.items())
def fetch_request_token(self, url, realm=None, **request_kwargs): def fetch_request_token(self, url, realm=None, **request_kwargs):
"""Fetch a request token. r"""Fetch a request token.
This is the first step in the OAuth 1 workflow. A request token is This is the first step in the OAuth 1 workflow. A request token is
obtained by making a signed post request to url. The token is then obtained by making a signed post request to url. The token is then
@ -264,9 +283,9 @@ class OAuth1Session(requests.Session):
'oauth_token_secret': '2kjshdfp92i34asdasd', 'oauth_token_secret': '2kjshdfp92i34asdasd',
} }
""" """
self._client.client.realm = ' '.join(realm) if realm else None self._client.client.realm = " ".join(realm) if realm else None
token = self._fetch_token(url, **request_kwargs) token = self._fetch_token(url, **request_kwargs)
log.debug('Resetting callback_uri and realm (not needed in next phase).') log.debug("Resetting callback_uri and realm (not needed in next phase).")
self._client.client.callback_uri = None self._client.client.callback_uri = None
self._client.client.realm = None self._client.client.realm = None
return token return token
@ -299,10 +318,10 @@ class OAuth1Session(requests.Session):
""" """
if verifier: if verifier:
self._client.client.verifier = verifier self._client.client.verifier = verifier
if not getattr(self._client.client, 'verifier', None): if not getattr(self._client.client, "verifier", None):
raise VerifierMissing('No client verifier has been set.') raise VerifierMissing("No client verifier has been set.")
token = self._fetch_token(url, **request_kwargs) token = self._fetch_token(url, **request_kwargs)
log.debug('Resetting verifier attribute, should not be used anymore.') log.debug("Resetting verifier attribute, should not be used anymore.")
self._client.client.verifier = None self._client.client.verifier = None
return token return token
@ -322,28 +341,27 @@ class OAuth1Session(requests.Session):
'oauth_verifier: 'w34o8967345', 'oauth_verifier: 'w34o8967345',
} }
""" """
log.debug('Parsing token from query part of url %s', url) log.debug("Parsing token from query part of url %s", url)
token = dict(urldecode(urlparse(url).query)) token = dict(urldecode(urlparse(url).query))
log.debug('Updating internal client token attribute.') log.debug("Updating internal client token attribute.")
self._populate_attributes(token) self._populate_attributes(token)
self.token = token
return token return token
def _populate_attributes(self, token): def _populate_attributes(self, token):
if 'oauth_token' in token: if "oauth_token" in token:
self._client.client.resource_owner_key = token['oauth_token'] self._client.client.resource_owner_key = token["oauth_token"]
else: else:
raise TokenMissing( raise TokenMissing(
'Response does not contain a token: {resp}'.format(resp=token), "Response does not contain a token: {resp}".format(resp=token), token
token,
) )
if 'oauth_token_secret' in token: if "oauth_token_secret" in token:
self._client.client.resource_owner_secret = ( self._client.client.resource_owner_secret = token["oauth_token_secret"]
token['oauth_token_secret']) if "oauth_verifier" in token:
if 'oauth_verifier' in token: self._client.client.verifier = token["oauth_verifier"]
self._client.client.verifier = token['oauth_verifier']
def _fetch_token(self, url, **request_kwargs): def _fetch_token(self, url, **request_kwargs):
log.debug('Fetching token from %s using client %s', url, self._client.client) log.debug("Fetching token from %s using client %s", url, self._client.client)
r = self.post(url, **request_kwargs) r = self.post(url, **request_kwargs)
if r.status_code >= 400: if r.status_code >= 400:
@ -352,17 +370,21 @@ class OAuth1Session(requests.Session):
log.debug('Decoding token from response "%s"', r.text) log.debug('Decoding token from response "%s"', r.text)
try: try:
token = dict(urldecode(r.text)) token = dict(urldecode(r.text.strip()))
except ValueError as e: except ValueError as e:
error = ("Unable to decode token from token response. " error = (
"This is commonly caused by an unsuccessful request where" "Unable to decode token from token response. "
" a non urlencoded error message is returned. " "This is commonly caused by an unsuccessful request where"
"The decoding error was %s""" % e) " a non urlencoded error message is returned. "
"The decoding error was %s"
"" % e
)
raise ValueError(error) raise ValueError(error)
log.debug('Obtained token %s', token) log.debug("Obtained token %s", token)
log.debug('Updating internal client attributes from token data.') log.debug("Updating internal client attributes from token data.")
self._populate_attributes(token) self._populate_attributes(token)
self.token = token
return token return token
def rebuild_auth(self, prepared_request, response): def rebuild_auth(self, prepared_request, response):
@ -370,9 +392,9 @@ class OAuth1Session(requests.Session):
When being redirected we should always strip Authorization When being redirected we should always strip Authorization
header, since nonce may not be reused as per OAuth spec. header, since nonce may not be reused as per OAuth spec.
""" """
if 'Authorization' in prepared_request.headers: if "Authorization" in prepared_request.headers:
# If we get redirected to a new host, we should strip out # If we get redirected to a new host, we should strip out
# any authentication headers. # any authentication headers.
prepared_request.headers.pop('Authorization', True) prepared_request.headers.pop("Authorization", True)
prepared_request.prepare_auth(self.auth) prepared_request.prepare_auth(self.auth)
return return

View file

@ -31,6 +31,7 @@ class OAuth2(AuthBase):
""" """
if not is_secure_transport(r.url): if not is_secure_transport(r.url):
raise InsecureTransportError() raise InsecureTransportError()
r.url, r.headers, r.body = self._client.add_token(r.url, r.url, r.headers, r.body = self._client.add_token(
http_method=r.method, body=r.body, headers=r.headers) r.url, http_method=r.method, body=r.body, headers=r.headers
)
return r return r

View file

@ -4,6 +4,7 @@ import logging
from oauthlib.common import generate_token, urldecode from oauthlib.common import generate_token, urldecode
from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError
from oauthlib.oauth2 import LegacyApplicationClient
from oauthlib.oauth2 import TokenExpiredError, is_secure_transport from oauthlib.oauth2 import TokenExpiredError, is_secure_transport
import requests import requests
@ -34,9 +35,19 @@ class OAuth2Session(requests.Session):
you are driving a user agent able to obtain URL fragments. you are driving a user agent able to obtain URL fragments.
""" """
def __init__(self, client_id=None, client=None, auto_refresh_url=None, def __init__(
auto_refresh_kwargs=None, scope=None, redirect_uri=None, token=None, self,
state=None, token_updater=None, **kwargs): client_id=None,
client=None,
auto_refresh_url=None,
auto_refresh_kwargs=None,
scope=None,
redirect_uri=None,
token=None,
state=None,
token_updater=None,
**kwargs
):
"""Construct a new OAuth 2 client session. """Construct a new OAuth 2 client session.
:param client_id: Client id obtained during registration :param client_id: Client id obtained during registration
@ -57,7 +68,7 @@ class OAuth2Session(requests.Session):
:auto_refresh_kwargs: Extra arguments to pass to the refresh token :auto_refresh_kwargs: Extra arguments to pass to the refresh token
endpoint. endpoint.
:token_updater: Method with one argument, token, to be used to update :token_updater: Method with one argument, token, to be used to update
your token databse on automatic token refresh. If not your token database on automatic token refresh. If not
set a TokenUpdated warning will be raised when a token set a TokenUpdated warning will be raised when a token
has been refreshed. This warning will carry the token has been refreshed. This warning will carry the token
in its token argument. in its token argument.
@ -74,22 +85,26 @@ class OAuth2Session(requests.Session):
self.auto_refresh_kwargs = auto_refresh_kwargs or {} self.auto_refresh_kwargs = auto_refresh_kwargs or {}
self.token_updater = token_updater self.token_updater = token_updater
# Ensure that requests doesn't do any automatic auth. See #278.
# The default behavior can be re-enabled by setting auth to None.
self.auth = lambda r: r
# Allow customizations for non compliant providers through various # Allow customizations for non compliant providers through various
# hooks to adjust requests and responses. # hooks to adjust requests and responses.
self.compliance_hook = { self.compliance_hook = {
'access_token_response': set([]), "access_token_response": set(),
'refresh_token_response': set([]), "refresh_token_response": set(),
'protected_request': set([]), "protected_request": set(),
} }
def new_state(self): def new_state(self):
"""Generates a state string to be used in authorizations.""" """Generates a state string to be used in authorizations."""
try: try:
self._state = self.state() self._state = self.state()
log.debug('Generated new state %s.', self._state) log.debug("Generated new state %s.", self._state)
except TypeError: except TypeError:
self._state = self.state self._state = self.state
log.debug('Re-using previously supplied state %s.', self._state) log.debug("Re-using previously supplied state %s.", self._state)
return self._state return self._state
@property @property
@ -111,7 +126,7 @@ class OAuth2Session(requests.Session):
@token.setter @token.setter
def token(self, value): def token(self, value):
self._client.token = value self._client.token = value
self._client._populate_attributes(value) self._client.populate_token_attributes(value)
@property @property
def access_token(self): def access_token(self):
@ -146,19 +161,42 @@ class OAuth2Session(requests.Session):
:return: authorization_url, state :return: authorization_url, state
""" """
state = state or self.new_state() state = state or self.new_state()
return self._client.prepare_request_uri(url, return (
self._client.prepare_request_uri(
url,
redirect_uri=self.redirect_uri, redirect_uri=self.redirect_uri,
scope=self.scope, scope=self.scope,
state=state, state=state,
**kwargs), state **kwargs
),
state,
)
def fetch_token(self, token_url, code=None, authorization_response=None, def fetch_token(
body='', auth=None, username=None, password=None, method='POST', self,
timeout=None, headers=None, verify=True, **kwargs): token_url,
code=None,
authorization_response=None,
body="",
auth=None,
username=None,
password=None,
method="POST",
force_querystring=False,
timeout=None,
headers=None,
verify=True,
proxies=None,
include_client_id=None,
client_secret=None,
**kwargs
):
"""Generic method for fetching an access token from the token endpoint. """Generic method for fetching an access token from the token endpoint.
If you are using the MobileApplicationClient you will want to use If you are using the MobileApplicationClient you will want to use
token_from_fragment instead of fetch_token. `token_from_fragment` instead of `fetch_token`.
The current implementation enforces the RFC guidelines.
:param token_url: Token endpoint URL, must use HTTPS. :param token_url: Token endpoint URL, must use HTTPS.
:param code: Authorization code (used by WebApplicationClients). :param code: Authorization code (used by WebApplicationClients).
@ -167,15 +205,30 @@ class OAuth2Session(requests.Session):
WebApplicationClients instead of code. WebApplicationClients instead of code.
:param body: Optional application/x-www-form-urlencoded body to add the :param body: Optional application/x-www-form-urlencoded body to add the
include in the token request. Prefer kwargs over body. include in the token request. Prefer kwargs over body.
:param auth: An auth tuple or method as accepted by requests. :param auth: An auth tuple or method as accepted by `requests`.
:param username: Username used by LegacyApplicationClients. :param username: Username required by LegacyApplicationClients to appear
:param password: Password used by LegacyApplicationClients. in the request body.
:param password: Password required by LegacyApplicationClients to appear
in the request body.
:param method: The HTTP method used to make the request. Defaults :param method: The HTTP method used to make the request. Defaults
to POST, but may also be GET. Other methods should to POST, but may also be GET. Other methods should
be added as needed. be added as needed.
:param headers: Dict to default request headers with. :param force_querystring: If True, force the request body to be sent
in the querystring instead.
:param timeout: Timeout of the request in seconds. :param timeout: Timeout of the request in seconds.
:param headers: Dict to default request headers with.
:param verify: Verify SSL certificate. :param verify: Verify SSL certificate.
:param proxies: The `proxies` argument is passed onto `requests`.
:param include_client_id: Should the request body include the
`client_id` parameter. Default is `None`,
which will attempt to autodetect. This can be
forced to always include (True) or never
include (False).
:param client_secret: The `client_secret` paired to the `client_id`.
This is generally required unless provided in the
`auth` tuple. If the value is `None`, it will be
omitted from the request, however if the value is
an empty string, an empty string will be sent.
:param kwargs: Extra parameters to include in the token request. :param kwargs: Extra parameters to include in the token request.
:return: A token dict :return: A token dict
""" """
@ -183,59 +236,130 @@ class OAuth2Session(requests.Session):
raise InsecureTransportError() raise InsecureTransportError()
if not code and authorization_response: if not code and authorization_response:
self._client.parse_request_uri_response(authorization_response, self._client.parse_request_uri_response(
state=self._state) authorization_response, state=self._state
)
code = self._client.code code = self._client.code
elif not code and isinstance(self._client, WebApplicationClient): elif not code and isinstance(self._client, WebApplicationClient):
code = self._client.code code = self._client.code
if not code: if not code:
raise ValueError('Please supply either code or ' raise ValueError(
'authorization_code parameters.') "Please supply either code or " "authorization_response parameters."
)
# Earlier versions of this library build an HTTPBasicAuth header out of
# `username` and `password`. The RFC states, however these attributes
# must be in the request body and not the header.
# If an upstream server is not spec compliant and requires them to
# appear as an Authorization header, supply an explicit `auth` header
# to this function.
# This check will allow for empty strings, but not `None`.
#
# References
# 4.3.2 - Resource Owner Password Credentials Grant
# https://tools.ietf.org/html/rfc6749#section-4.3.2
body = self._client.prepare_request_body(code=code, body=body, if isinstance(self._client, LegacyApplicationClient):
redirect_uri=self.redirect_uri, username=username, if username is None:
password=password, **kwargs) raise ValueError(
"`LegacyApplicationClient` requires both the "
if (not auth) and username: "`username` and `password` parameters."
)
if password is None: if password is None:
raise ValueError('Username was supplied, but not password.') raise ValueError(
auth = requests.auth.HTTPBasicAuth(username, password) "The required parameter `username` was supplied, "
"but `password` was not."
)
# merge username and password into kwargs for `prepare_request_body`
if username is not None:
kwargs["username"] = username
if password is not None:
kwargs["password"] = password
# is an auth explicitly supplied?
if auth is not None:
# if we're dealing with the default of `include_client_id` (None):
# we will assume the `auth` argument is for an RFC compliant server
# and we should not send the `client_id` in the body.
# This approach allows us to still force the client_id by submitting
# `include_client_id=True` along with an `auth` object.
if include_client_id is None:
include_client_id = False
# otherwise we may need to create an auth header
else:
# since we don't have an auth header, we MAY need to create one
# it is possible that we want to send the `client_id` in the body
# if so, `include_client_id` should be set to True
# otherwise, we will generate an auth header
if include_client_id is not True:
client_id = self.client_id
if client_id:
log.debug(
'Encoding `client_id` "%s" with `client_secret` '
"as Basic auth credentials.",
client_id,
)
client_secret = client_secret if client_secret is not None else ""
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
if include_client_id:
# this was pulled out of the params
# it needs to be passed into prepare_request_body
if client_secret is not None:
kwargs["client_secret"] = client_secret
body = self._client.prepare_request_body(
code=code,
body=body,
redirect_uri=self.redirect_uri,
include_client_id=include_client_id,
**kwargs
)
headers = headers or { headers = headers or {
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8",
} }
self.token = {} self.token = {}
if method.upper() == 'POST': request_kwargs = {}
r = self.post(token_url, data=dict(urldecode(body)), if method.upper() == "POST":
timeout=timeout, headers=headers, auth=auth, request_kwargs["params" if force_querystring else "data"] = dict(
verify=verify) urldecode(body)
log.debug('Prepared fetch token request body %s', body) )
elif method.upper() == 'GET': elif method.upper() == "GET":
# if method is not 'POST', switch body to querystring and GET request_kwargs["params"] = dict(urldecode(body))
r = self.get(token_url, params=dict(urldecode(body)),
timeout=timeout, headers=headers, auth=auth,
verify=verify)
log.debug('Prepared fetch token request querystring %s', body)
else: else:
raise ValueError('The method kwarg must be POST or GET.') raise ValueError("The method kwarg must be POST or GET.")
log.debug('Request to fetch token completed with status %s.', r = self.request(
r.status_code) method=method,
log.debug('Request headers were %s', r.request.headers) url=token_url,
log.debug('Request body was %s', r.request.body) timeout=timeout,
log.debug('Response headers were %s and content %s.', headers=headers,
r.headers, r.text) auth=auth,
log.debug('Invoking %d token response hooks.', verify=verify,
len(self.compliance_hook['access_token_response'])) proxies=proxies,
for hook in self.compliance_hook['access_token_response']: **request_kwargs
log.debug('Invoking hook %s.', hook) )
log.debug("Request to fetch token completed with status %s.", r.status_code)
log.debug("Request url was %s", r.request.url)
log.debug("Request headers were %s", r.request.headers)
log.debug("Request body was %s", r.request.body)
log.debug("Response headers were %s and content %s.", r.headers, r.text)
log.debug(
"Invoking %d token response hooks.",
len(self.compliance_hook["access_token_response"]),
)
for hook in self.compliance_hook["access_token_response"]:
log.debug("Invoking hook %s.", hook)
r = hook(r) r = hook(r)
self._client.parse_request_body_response(r.text, scope=self.scope) self._client.parse_request_body_response(r.text, scope=self.scope)
self.token = self._client.token self.token = self._client.token
log.debug('Obtained token %s.', self.token) log.debug("Obtained token %s.", self.token)
return self.token return self.token
def token_from_fragment(self, authorization_response): def token_from_fragment(self, authorization_response):
@ -244,103 +368,153 @@ class OAuth2Session(requests.Session):
:param authorization_response: The full URL of the redirect back to you :param authorization_response: The full URL of the redirect back to you
:return: A token dict :return: A token dict
""" """
self._client.parse_request_uri_response(authorization_response, self._client.parse_request_uri_response(
state=self._state) authorization_response, state=self._state
)
self.token = self._client.token self.token = self._client.token
return self.token return self.token
def refresh_token(self, token_url, refresh_token=None, body='', auth=None, def refresh_token(
timeout=None, headers=None, verify=True, **kwargs): self,
token_url,
refresh_token=None,
body="",
auth=None,
timeout=None,
headers=None,
verify=True,
proxies=None,
**kwargs
):
"""Fetch a new access token using a refresh token. """Fetch a new access token using a refresh token.
:param token_url: The token endpoint, must be HTTPS. :param token_url: The token endpoint, must be HTTPS.
:param refresh_token: The refresh_token to use. :param refresh_token: The refresh_token to use.
:param body: Optional application/x-www-form-urlencoded body to add the :param body: Optional application/x-www-form-urlencoded body to add the
include in the token request. Prefer kwargs over body. include in the token request. Prefer kwargs over body.
:param auth: An auth tuple or method as accepted by requests. :param auth: An auth tuple or method as accepted by `requests`.
:param timeout: Timeout of the request in seconds. :param timeout: Timeout of the request in seconds.
:param headers: A dict of headers to be used by `requests`.
:param verify: Verify SSL certificate. :param verify: Verify SSL certificate.
:param proxies: The `proxies` argument will be passed to `requests`.
:param kwargs: Extra parameters to include in the token request. :param kwargs: Extra parameters to include in the token request.
:return: A token dict :return: A token dict
""" """
if not token_url: if not token_url:
raise ValueError('No token endpoint set for auto_refresh.') raise ValueError("No token endpoint set for auto_refresh.")
if not is_secure_transport(token_url): if not is_secure_transport(token_url):
raise InsecureTransportError() raise InsecureTransportError()
refresh_token = refresh_token or self.token.get('refresh_token') refresh_token = refresh_token or self.token.get("refresh_token")
log.debug('Adding auto refresh key word arguments %s.', log.debug(
self.auto_refresh_kwargs) "Adding auto refresh key word arguments %s.", self.auto_refresh_kwargs
)
kwargs.update(self.auto_refresh_kwargs) kwargs.update(self.auto_refresh_kwargs)
body = self._client.prepare_refresh_body(body=body, body = self._client.prepare_refresh_body(
refresh_token=refresh_token, scope=self.scope, **kwargs) body=body, refresh_token=refresh_token, scope=self.scope, **kwargs
log.debug('Prepared refresh token request body %s', body) )
log.debug("Prepared refresh token request body %s", body)
if headers is None: if headers is None:
headers = { headers = {
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': ( "Content-Type": ("application/x-www-form-urlencoded;charset=UTF-8"),
'application/x-www-form-urlencoded;charset=UTF-8'
),
} }
r = self.post(token_url, data=dict(urldecode(body)), auth=auth, r = self.post(
timeout=timeout, headers=headers, verify=verify, withhold_token=True) token_url,
log.debug('Request to refresh token completed with status %s.', data=dict(urldecode(body)),
r.status_code) auth=auth,
log.debug('Response headers were %s and content %s.', timeout=timeout,
r.headers, r.text) headers=headers,
log.debug('Invoking %d token response hooks.', verify=verify,
len(self.compliance_hook['refresh_token_response'])) withhold_token=True,
for hook in self.compliance_hook['refresh_token_response']: proxies=proxies,
log.debug('Invoking hook %s.', hook) )
log.debug("Request to refresh token completed with status %s.", r.status_code)
log.debug("Response headers were %s and content %s.", r.headers, r.text)
log.debug(
"Invoking %d token response hooks.",
len(self.compliance_hook["refresh_token_response"]),
)
for hook in self.compliance_hook["refresh_token_response"]:
log.debug("Invoking hook %s.", hook)
r = hook(r) r = hook(r)
self.token = self._client.parse_request_body_response(r.text, scope=self.scope) self.token = self._client.parse_request_body_response(r.text, scope=self.scope)
if not 'refresh_token' in self.token: if not "refresh_token" in self.token:
log.debug('No new refresh token given. Re-using old.') log.debug("No new refresh token given. Re-using old.")
self.token['refresh_token'] = refresh_token self.token["refresh_token"] = refresh_token
return self.token return self.token
def request(self, method, url, data=None, headers=None, withhold_token=False, **kwargs): def request(
self,
method,
url,
data=None,
headers=None,
withhold_token=False,
client_id=None,
client_secret=None,
**kwargs
):
"""Intercept all requests and add the OAuth 2 token if present.""" """Intercept all requests and add the OAuth 2 token if present."""
if not is_secure_transport(url): if not is_secure_transport(url):
raise InsecureTransportError() raise InsecureTransportError()
if self.token and not withhold_token: if self.token and not withhold_token:
log.debug('Invoking %d protected resource request hooks.', log.debug(
len(self.compliance_hook['protected_request'])) "Invoking %d protected resource request hooks.",
for hook in self.compliance_hook['protected_request']: len(self.compliance_hook["protected_request"]),
log.debug('Invoking hook %s.', hook) )
for hook in self.compliance_hook["protected_request"]:
log.debug("Invoking hook %s.", hook)
url, headers, data = hook(url, headers, data) url, headers, data = hook(url, headers, data)
log.debug('Adding token %s to request.', self.token) log.debug("Adding token %s to request.", self.token)
try: try:
url, headers, data = self._client.add_token(url, url, headers, data = self._client.add_token(
http_method=method, body=data, headers=headers) url, http_method=method, body=data, headers=headers
)
# Attempt to retrieve and save new access token if expired # Attempt to retrieve and save new access token if expired
except TokenExpiredError: except TokenExpiredError:
if self.auto_refresh_url: if self.auto_refresh_url:
log.debug('Auto refresh is set, attempting to refresh at %s.', log.debug(
self.auto_refresh_url) "Auto refresh is set, attempting to refresh at %s.",
token = self.refresh_token(self.auto_refresh_url, **kwargs) self.auto_refresh_url,
)
# We mustn't pass auth twice.
auth = kwargs.pop("auth", None)
if client_id and client_secret and (auth is None):
log.debug(
'Encoding client_id "%s" with client_secret as Basic auth credentials.',
client_id,
)
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
token = self.refresh_token(
self.auto_refresh_url, auth=auth, **kwargs
)
if self.token_updater: if self.token_updater:
log.debug('Updating token to %s using %s.', log.debug(
token, self.token_updater) "Updating token to %s using %s.", token, self.token_updater
)
self.token_updater(token) self.token_updater(token)
url, headers, data = self._client.add_token(url, url, headers, data = self._client.add_token(
http_method=method, body=data, headers=headers) url, http_method=method, body=data, headers=headers
)
else: else:
raise TokenUpdated(token) raise TokenUpdated(token)
else: else:
raise raise
log.debug('Requesting url %s using method %s.', url, method) log.debug("Requesting url %s using method %s.", url, method)
log.debug('Supplying headers %s and data %s', headers, data) log.debug("Supplying headers %s and data %s", headers, data)
log.debug('Passing through key word arguments %s.', kwargs) log.debug("Passing through key word arguments %s.", kwargs)
return super(OAuth2Session, self).request(method, url, return super(OAuth2Session, self).request(
headers=headers, data=data, **kwargs) method, url, headers=headers, data=data, **kwargs
)
def register_compliance_hook(self, hook_type, hook): def register_compliance_hook(self, hook_type, hook):
"""Register a hook for request/response tweaking. """Register a hook for request/response tweaking.
@ -354,6 +528,7 @@ class OAuth2Session(requests.Session):
or open an issue. or open an issue.
""" """
if hook_type not in self.compliance_hook: if hook_type not in self.compliance_hook:
raise ValueError('Hook type %s is not in %s.', raise ValueError(
hook_type, self.compliance_hook) "Hook type %s is not in %s.", hook_type, self.compliance_hook
)
self.compliance_hook[hook_type].add(hook) self.compliance_hook[hook_type].add(hook)