diff --git a/panoptes_client/panoptes.py b/panoptes_client/panoptes.py index 1987a91..329fdec 100644 --- a/panoptes_client/panoptes.py +++ b/panoptes_client/panoptes.py @@ -139,6 +139,7 @@ def __init__( self.logged_in = False self.logged_in_user_id = None self.bearer_token = None + self._bearer_token_lock = threading.Lock() self.admin = admin self.username = None self.password = None @@ -522,7 +523,14 @@ def get_csrf_token(self): return self.session.get(url, headers=headers).headers['x-csrf-token'] def get_bearer_token(self): - if not self.valid_bearer_token(): + if self.valid_bearer_token(): + return self.bearer_token + + with self._bearer_token_lock: + # Another thread may have refreshed the token while we waited. + if self.valid_bearer_token(): + return self.bearer_token + grant_type = 'password' if self.client_secret: @@ -545,17 +553,27 @@ def get_bearer_token(self): 'client_id': self.client_id, } - if grant_type == 'client_credentials': - bearer_data['client_secret'] = self.client_secret - bearer_data['url'] = self.redirect_url - - token_response = self.session.post( - self.endpoint + '/oauth/token', - bearer_data - ).json() - - if 'errors' in token_response: - raise PanoptesAPIException(token_response['errors']) + try: + token_response = self._request_bearer_token( + bearer_data, + grant_type, + ) + except PanoptesAPIException: + if grant_type != 'password': + raise + self.bearer_token = None + self.refresh_token = None + self.logged_in = False + if not self.login(): + raise + bearer_data = { + 'grant_type': grant_type, + 'client_id': self.client_id, + } + token_response = self._request_bearer_token( + bearer_data, + grant_type, + ) self.bearer_token = token_response['access_token'] if (self.bearer_token and grant_type == 'client_credentials'): @@ -570,6 +588,32 @@ def get_bearer_token(self): ) return self.bearer_token + def _request_bearer_token(self, bearer_data, grant_type): + if grant_type == 'client_credentials': + bearer_data['client_secret'] = self.client_secret + bearer_data['url'] = self.redirect_url + + token_response = self.session.post( + self.endpoint + '/oauth/token', + bearer_data + ).json() + + if 'errors' in token_response: + raise PanoptesAPIException(token_response['errors']) + + if 'access_token' not in token_response: + raise PanoptesAPIException( + token_response.get( + 'error_description', + token_response.get( + 'error', + 'Authentication failed: no access token returned' + ) + ) + ) + + return token_response + def valid_bearer_token(self): # Return invalid if there is no token if not self.has_bearer_token(): diff --git a/panoptes_client/tests/test_bearer_expiry.py b/panoptes_client/tests/test_bearer_expiry.py index fdd877e..24e9951 100644 --- a/panoptes_client/tests/test_bearer_expiry.py +++ b/panoptes_client/tests/test_bearer_expiry.py @@ -1,14 +1,15 @@ from panoptes_client.panoptes import Panoptes +from panoptes_client.panoptes import PanoptesAPIException import datetime import unittest import sys if sys.version_info <= (3, 0): - from mock import patch + from mock import Mock, patch else: - from unittest.mock import patch + from unittest.mock import Mock, patch class MockDate(datetime.datetime): @@ -92,3 +93,75 @@ def test_has_no_token(self): client = Panoptes() assert client.has_bearer_token() is False + + def test_refresh_token_failure_retries_after_login(self): + MockDate.fake(datetime.datetime(2017, 1, 1, 10, 0, 0)) + + client = Panoptes() + client.valid_bearer_token = Mock(return_value=False) + client.username = 'user' + client.password = 'password' + client.logged_in = True + client.bearer_token = 'expired' + client.refresh_token = 'stale-refresh' + + refresh_response = Mock() + refresh_response.json.return_value = {'error': 'invalid_grant'} + login_response = Mock() + login_response.status_code = 200 + login_response.json.return_value = {'users': [{'id': '1'}]} + token_response = Mock() + token_response.json.return_value = { + 'access_token': 'new-token', + 'expires_in': 3600, + 'refresh_token': 'new-refresh', + } + csrf_response = Mock() + csrf_response.headers = {'x-csrf-token': 'csrf-token'} + client.session.get = Mock(return_value=csrf_response) + client.session.post = Mock(side_effect=[ + refresh_response, + login_response, + token_response, + ]) + + assert client.get_bearer_token() == 'new-token' + assert client.refresh_token == 'new-refresh' + assert client.logged_in is True + assert client.session.post.call_count == 3 + assert client.session.post.call_args_list[0][0][1]['grant_type'] == ( + 'refresh_token' + ) + assert client.session.post.call_args_list[2][0][1]['grant_type'] == ( + 'password' + ) + + def test_missing_access_token_raises_api_exception_after_retry(self): + MockDate.fake(datetime.datetime(2017, 1, 1, 10, 0, 0)) + + client = Panoptes() + client.valid_bearer_token = Mock(return_value=False) + client.username = 'user' + client.password = 'password' + client.logged_in = True + client.bearer_token = 'expired' + client.refresh_token = 'stale-refresh' + + refresh_response = Mock() + refresh_response.json.return_value = {'error': 'invalid_grant'} + login_response = Mock() + login_response.status_code = 200 + login_response.json.return_value = {'users': [{'id': '1'}]} + retry_response = Mock() + retry_response.json.return_value = {'error': 'invalid_grant'} + csrf_response = Mock() + csrf_response.headers = {'x-csrf-token': 'csrf-token'} + client.session.get = Mock(return_value=csrf_response) + client.session.post = Mock(side_effect=[ + refresh_response, + login_response, + retry_response, + ]) + + with self.assertRaises(PanoptesAPIException): + client.get_bearer_token()