From 4827dd9c56fb9ef23ae93e12a4ad01edeb7b000b Mon Sep 17 00:00:00 2001 From: Anurag Bandyopadhyay Date: Tue, 28 Apr 2026 20:33:54 +0530 Subject: [PATCH 1/7] fix: add token expiry buffer to prevent expired token usage --- openfga_sdk/oauth2.py | 18 +++++++++++---- openfga_sdk/sync/oauth2.py | 18 +++++++++++---- test/oauth2_test.py | 47 ++++++++++++++++++++++++++++++++++++-- test/sync/oauth2_test.py | 47 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 116 insertions(+), 14 deletions(-) diff --git a/openfga_sdk/oauth2.py b/openfga_sdk/oauth2.py index 27eac5a..ee92b59 100644 --- a/openfga_sdk/oauth2.py +++ b/openfga_sdk/oauth2.py @@ -9,7 +9,11 @@ import urllib3 from openfga_sdk.configuration import Configuration -from openfga_sdk.constants import USER_AGENT +from openfga_sdk.constants import ( + TOKEN_EXPIRY_JITTER_IN_SEC, + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, + USER_AGENT, +) from openfga_sdk.credentials import Credentials from openfga_sdk.exceptions import AuthenticationError from openfga_sdk.telemetry.attributes import TelemetryAttributes @@ -36,6 +40,7 @@ def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials self._access_token = None self._access_expiry_time = None + self._access_token_expiry_buffer = 0 self._telemetry = Telemetry() if configuration is None: @@ -45,13 +50,12 @@ def __init__(self, credentials: Credentials, configuration=None): def _token_valid(self): """ - Return whether token is valid + Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens) """ if self._access_token is None or self._access_expiry_time is None: return False - if self._access_expiry_time < datetime.now(): - return False - return True + remaining = (self._access_expiry_time - datetime.now()).total_seconds() + return remaining > self._access_token_expiry_buffer async def _obtain_token(self, client): """ @@ -140,6 +144,10 @@ async def _obtain_token(self, client): seconds=int(api_response.get("expires_in")) ) self._access_token = api_response.get("access_token") + self._access_token_expiry_buffer = ( + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC + + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC + ) self._telemetry.metrics.credentialsRequest( attributes={ TelemetryAttributes.fga_client_request_client_id: configuration.client_id diff --git a/openfga_sdk/sync/oauth2.py b/openfga_sdk/sync/oauth2.py index 0f5bc09..c2dc231 100644 --- a/openfga_sdk/sync/oauth2.py +++ b/openfga_sdk/sync/oauth2.py @@ -9,7 +9,11 @@ import urllib3 from openfga_sdk.configuration import Configuration -from openfga_sdk.constants import USER_AGENT +from openfga_sdk.constants import ( + TOKEN_EXPIRY_JITTER_IN_SEC, + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, + USER_AGENT, +) from openfga_sdk.credentials import Credentials from openfga_sdk.exceptions import AuthenticationError from openfga_sdk.telemetry.attributes import TelemetryAttributes @@ -36,6 +40,7 @@ def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials self._access_token = None self._access_expiry_time = None + self._access_token_expiry_buffer = 0 self._telemetry = Telemetry() if configuration is None: @@ -45,13 +50,12 @@ def __init__(self, credentials: Credentials, configuration=None): def _token_valid(self): """ - Return whether token is valid + Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens) """ if self._access_token is None or self._access_expiry_time is None: return False - if self._access_expiry_time < datetime.now(): - return False - return True + remaining = (self._access_expiry_time - datetime.now()).total_seconds() + return remaining > self._access_token_expiry_buffer def _obtain_token(self, client): """ @@ -140,6 +144,10 @@ def _obtain_token(self, client): seconds=int(api_response.get("expires_in")) ) self._access_token = api_response.get("access_token") + self._access_token_expiry_buffer = ( + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC + + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC + ) self._telemetry.metrics.credentialsRequest( attributes={ TelemetryAttributes.fga_client_request_client_id: configuration.client_id diff --git a/test/oauth2_test.py b/test/oauth2_test.py index 48b5030..45884a7 100644 --- a/test/oauth2_test.py +++ b/test/oauth2_test.py @@ -6,7 +6,7 @@ from openfga_sdk import rest from openfga_sdk.configuration import Configuration -from openfga_sdk.constants import USER_AGENT +from openfga_sdk.constants import TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, USER_AGENT from openfga_sdk.credentials import CredentialConfiguration, Credentials from openfga_sdk.exceptions import AuthenticationError from openfga_sdk.oauth2 import OAuth2Client @@ -34,7 +34,7 @@ async def test_get_authentication_valid_client_credentials(self): """ client = OAuth2Client(None) client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() + timedelta(seconds=60) + client._access_expiry_time = datetime.now() + timedelta(seconds=3600) auth_header = await client.get_authentication_header(None) self.assertEqual(auth_header, {"Authorization": "Bearer XYZ123"}) @@ -651,6 +651,49 @@ async def test_get_authentication_without_audience(self, mock_request): ) await rest_client.close() + @patch.object(rest.RESTClientObject, "request") + @patch("openfga_sdk.oauth2.random") + async def test_get_authentication_refreshes_near_expiry_token( + self, mock_random, mock_request + ): + """ + Token close to expiry (within buffer window) should trigger a proactive refresh + """ + mock_random.random.return_value = 0 + short_lived_secs = max(1, TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC - 1) + + mock_request.side_effect = [ + mock_response( + f'{{"expires_in": {short_lived_secs}, "access_token": "short-lived-token"}}', + 200, + ), + mock_response( + '{"expires_in": 3600, "access_token": "refreshed-token"}', + 200, + ), + ] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + header1 = await client.get_authentication_header(rest_client) + header2 = await client.get_authentication_header(rest_client) + + self.assertEqual(header1, {"Authorization": "Bearer short-lived-token"}) + self.assertEqual(header2, {"Authorization": "Bearer refreshed-token"}) + self.assertEqual(mock_request.call_count, 2) + + await rest_client.close() + @patch.object(rest.RESTClientObject, "request") async def test_get_authentication_with_scopes_no_audience(self, mock_request): """ diff --git a/test/sync/oauth2_test.py b/test/sync/oauth2_test.py index d0dc387..5708802 100644 --- a/test/sync/oauth2_test.py +++ b/test/sync/oauth2_test.py @@ -5,7 +5,7 @@ import urllib3 from openfga_sdk.configuration import Configuration -from openfga_sdk.constants import USER_AGENT +from openfga_sdk.constants import TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, USER_AGENT from openfga_sdk.credentials import CredentialConfiguration, Credentials from openfga_sdk.exceptions import AuthenticationError from openfga_sdk.sync import rest @@ -34,7 +34,7 @@ def test_get_authentication_valid_client_credentials(self): """ client = OAuth2Client(None) client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() + timedelta(seconds=60) + client._access_expiry_time = datetime.now() + timedelta(seconds=3600) auth_header = client.get_authentication_header(None) self.assertEqual(auth_header, {"Authorization": "Bearer XYZ123"}) @@ -427,6 +427,49 @@ def test_get_authentication_without_audience(self, mock_request): ) rest_client.close() + @patch.object(rest.RESTClientObject, "request") + @patch("openfga_sdk.sync.oauth2.random") + def test_get_authentication_refreshes_near_expiry_token( + self, mock_random, mock_request + ): + """ + Token close to expiry (within buffer window) should trigger a proactive refresh + """ + mock_random.random.return_value = 0 + short_lived_secs = max(1, TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC - 1) + + mock_request.side_effect = [ + mock_response( + f'{{"expires_in": {short_lived_secs}, "access_token": "short-lived-token"}}', + 200, + ), + mock_response( + '{"expires_in": 3600, "access_token": "refreshed-token"}', + 200, + ), + ] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + header1 = client.get_authentication_header(rest_client) + header2 = client.get_authentication_header(rest_client) + + self.assertEqual(header1, {"Authorization": "Bearer short-lived-token"}) + self.assertEqual(header2, {"Authorization": "Bearer refreshed-token"}) + self.assertEqual(mock_request.call_count, 2) + + rest_client.close() + @patch.object(rest.RESTClientObject, "request") def test_get_authentication_with_scopes_no_audience(self, mock_request): """ From 7c654392a54c558839d76d52cb15876b0c0ca0f1 Mon Sep 17 00:00:00 2001 From: Anurag Bandyopadhyay Date: Tue, 28 Apr 2026 21:12:42 +0530 Subject: [PATCH 2/7] feat: add lock to prevent concurrent token refreshes --- openfga_sdk/oauth2.py | 11 +++++---- openfga_sdk/sync/oauth2.py | 12 ++++++---- test/oauth2_test.py | 37 ++++++++++++++++++++++++++++++ test/sync/oauth2_test.py | 46 +++++++++++++++++++++++++++++++++++++- 4 files changed, 97 insertions(+), 9 deletions(-) diff --git a/openfga_sdk/oauth2.py b/openfga_sdk/oauth2.py index ee92b59..065bb01 100644 --- a/openfga_sdk/oauth2.py +++ b/openfga_sdk/oauth2.py @@ -41,6 +41,7 @@ def __init__(self, credentials: Credentials, configuration=None): self._access_token = None self._access_expiry_time = None self._access_token_expiry_buffer = 0 + self._lock = asyncio.Lock() self._telemetry = Telemetry() if configuration is None: @@ -80,7 +81,9 @@ async def _obtain_token(self, client): # Add scope parameter if scopes are configured if configuration.scopes is not None: if isinstance(configuration.scopes, list): - scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip()) + scope_str = " ".join( + s.strip() for s in configuration.scopes if s and s.strip() + ) else: scope_str = ( configuration.scopes.strip() @@ -162,8 +165,8 @@ async def get_authentication_header(self, client): """ If configured, return the header for authentication """ - # check to see token is valid if not self._token_valid(): - # In this case, the token is not valid, we need to get the refresh the token - await self._obtain_token(client) + async with self._lock: + if not self._token_valid(): + await self._obtain_token(client) return {"Authorization": f"Bearer {self._access_token}"} diff --git a/openfga_sdk/sync/oauth2.py b/openfga_sdk/sync/oauth2.py index c2dc231..704c0c6 100644 --- a/openfga_sdk/sync/oauth2.py +++ b/openfga_sdk/sync/oauth2.py @@ -2,6 +2,7 @@ import math import random import sys +import threading import time from datetime import datetime, timedelta @@ -41,6 +42,7 @@ def __init__(self, credentials: Credentials, configuration=None): self._access_token = None self._access_expiry_time = None self._access_token_expiry_buffer = 0 + self._lock = threading.Lock() self._telemetry = Telemetry() if configuration is None: @@ -80,7 +82,9 @@ def _obtain_token(self, client): # Add scope parameter if scopes are configured if configuration.scopes is not None: if isinstance(configuration.scopes, list): - scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip()) + scope_str = " ".join( + s.strip() for s in configuration.scopes if s and s.strip() + ) else: scope_str = ( configuration.scopes.strip() @@ -162,8 +166,8 @@ def get_authentication_header(self, client): """ If configured, return the header for authentication """ - # check to see token is valid if not self._token_valid(): - # In this case, the token is not valid, we need to get the refresh the token - self._obtain_token(client) + with self._lock: + if not self._token_valid(): + self._obtain_token(client) return {"Authorization": f"Bearer {self._access_token}"} diff --git a/test/oauth2_test.py b/test/oauth2_test.py index 45884a7..e78a092 100644 --- a/test/oauth2_test.py +++ b/test/oauth2_test.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timedelta from unittest import IsolatedAsyncioTestCase from unittest.mock import patch @@ -694,6 +695,42 @@ async def test_get_authentication_refreshes_near_expiry_token( await rest_client.close() + async def test_concurrent_requests_only_fetch_token_once(self): + """ + Multiple concurrent requests while the token is invalid should result in + only one token fetch — subsequent coroutines wait on the lock and reuse + the token obtained by the first. + """ + obtain_calls = [] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + oauth_client = OAuth2Client(credentials) + + async def mock_obtain_token(client): + obtain_calls.append(1) + await asyncio.sleep(0) # yield so other coroutines reach the lock + oauth_client._access_token = "concurrent-token" + oauth_client._access_expiry_time = datetime.now() + timedelta(seconds=3600) + oauth_client._access_token_expiry_buffer = 300 + + with patch.object(oauth_client, "_obtain_token", side_effect=mock_obtain_token): + results = await asyncio.gather( + *[oauth_client.get_authentication_header(None) for _ in range(5)] + ) + + self.assertEqual(len(obtain_calls), 1) + self.assertTrue( + all(r == {"Authorization": "Bearer concurrent-token"} for r in results) + ) + @patch.object(rest.RESTClientObject, "request") async def test_get_authentication_with_scopes_no_audience(self, mock_request): """ diff --git a/test/sync/oauth2_test.py b/test/sync/oauth2_test.py index 5708802..0ab435a 100644 --- a/test/sync/oauth2_test.py +++ b/test/sync/oauth2_test.py @@ -1,3 +1,5 @@ +import threading +import time from datetime import datetime, timedelta from unittest import IsolatedAsyncioTestCase from unittest.mock import patch @@ -470,6 +472,49 @@ def test_get_authentication_refreshes_near_expiry_token( rest_client.close() + def test_concurrent_requests_only_fetch_token_once(self): + """ + Multiple concurrent threads while the token is invalid should result in + only one token fetch — subsequent threads wait on the lock and reuse + the token obtained by the first. + """ + obtain_calls = [] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + oauth_client = OAuth2Client(credentials) + + def mock_obtain_token(client): + obtain_calls.append(1) + time.sleep(0.05) # hold the lock briefly so other threads queue up + oauth_client._access_token = "concurrent-token" + oauth_client._access_expiry_time = datetime.now() + timedelta(seconds=3600) + oauth_client._access_token_expiry_buffer = 300 + + results = [] + + def call(): + results.append(oauth_client.get_authentication_header(None)) + + with patch.object(oauth_client, "_obtain_token", side_effect=mock_obtain_token): + threads = [threading.Thread(target=call) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(obtain_calls), 1) + self.assertTrue( + all(r == {"Authorization": "Bearer concurrent-token"} for r in results) + ) + @patch.object(rest.RESTClientObject, "request") def test_get_authentication_with_scopes_no_audience(self, mock_request): """ @@ -520,4 +565,3 @@ def test_get_authentication_with_scopes_no_audience(self, mock_request): }, ) rest_client.close() - From 7e1cd8bc5755f2993ed1dd216f117ec68ee2118f Mon Sep 17 00:00:00 2001 From: Anurag Bandyopadhyay Date: Tue, 28 Apr 2026 21:16:59 +0530 Subject: [PATCH 3/7] fix: ruff check --- test/oauth2_test.py | 1 + test/sync/oauth2_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/test/oauth2_test.py b/test/oauth2_test.py index e78a092..2cc28e8 100644 --- a/test/oauth2_test.py +++ b/test/oauth2_test.py @@ -1,4 +1,5 @@ import asyncio + from datetime import datetime, timedelta from unittest import IsolatedAsyncioTestCase from unittest.mock import patch diff --git a/test/sync/oauth2_test.py b/test/sync/oauth2_test.py index 0ab435a..ebdba61 100644 --- a/test/sync/oauth2_test.py +++ b/test/sync/oauth2_test.py @@ -1,5 +1,6 @@ import threading import time + from datetime import datetime, timedelta from unittest import IsolatedAsyncioTestCase from unittest.mock import patch From e9c445eab8f8e23ef7de4de383b43610234a9edf Mon Sep 17 00:00:00 2001 From: Anurag Bandyopadhyay Date: Tue, 28 Apr 2026 21:36:36 +0530 Subject: [PATCH 4/7] fix: address comment --- test/sync/oauth2_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/sync/oauth2_test.py b/test/sync/oauth2_test.py index ebdba61..30a04c4 100644 --- a/test/sync/oauth2_test.py +++ b/test/sync/oauth2_test.py @@ -511,6 +511,7 @@ def call(): for t in threads: t.join() + self.assertEqual(len(results), 5) self.assertEqual(len(obtain_calls), 1) self.assertTrue( all(r == {"Authorization": "Bearer concurrent-token"} for r in results) From b48115c77fab5825f54749033d8c47fcc66883d9 Mon Sep 17 00:00:00 2001 From: Anurag Bandyopadhyay Date: Tue, 28 Apr 2026 22:03:11 +0530 Subject: [PATCH 5/7] fix: use atomic token state to eliminate torn-read race condition Replace three separate mutable fields with a single frozen dataclass assigned atomically, ensuring concurrent threads always see a complete token state snapshot. --- openfga_sdk/oauth2.py | 37 +++++++++++++++++---------- openfga_sdk/sync/oauth2.py | 37 +++++++++++++++++---------- test/oauth2_test.py | 52 ++++++++++++++++++++++---------------- test/sync/oauth2_test.py | 36 ++++++++++++++++---------- 4 files changed, 98 insertions(+), 64 deletions(-) diff --git a/openfga_sdk/oauth2.py b/openfga_sdk/oauth2.py index 065bb01..9d1a72c 100644 --- a/openfga_sdk/oauth2.py +++ b/openfga_sdk/oauth2.py @@ -4,6 +4,7 @@ import random import sys +from dataclasses import dataclass from datetime import datetime, timedelta import urllib3 @@ -20,6 +21,13 @@ from openfga_sdk.telemetry.telemetry import Telemetry +@dataclass(frozen=True) +class _TokenState: + access_token: str + expiry_time: datetime + expiry_buffer: float + + def jitter(loop_count, min_wait_in_ms): """ Generate a random jitter value for exponential backoff @@ -38,9 +46,7 @@ def jitter(loop_count, min_wait_in_ms): class OAuth2Client: def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials - self._access_token = None - self._access_expiry_time = None - self._access_token_expiry_buffer = 0 + self._token_state: _TokenState | None = None self._lock = asyncio.Lock() self._telemetry = Telemetry() @@ -53,10 +59,11 @@ def _token_valid(self): """ Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens) """ - if self._access_token is None or self._access_expiry_time is None: + state = self._token_state # atomic snapshot — either old or new, never torn + if state is None: return False - remaining = (self._access_expiry_time - datetime.now()).total_seconds() - return remaining > self._access_token_expiry_buffer + remaining = (state.expiry_time - datetime.now()).total_seconds() + return remaining > state.expiry_buffer async def _obtain_token(self, client): """ @@ -143,13 +150,15 @@ async def _obtain_token(self, client): raise AuthenticationError(http_resp=raw_response) if api_response.get("expires_in") and api_response.get("access_token"): - self._access_expiry_time = datetime.now() + timedelta( - seconds=int(api_response.get("expires_in")) - ) - self._access_token = api_response.get("access_token") - self._access_token_expiry_buffer = ( - TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC - + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC + self._token_state = _TokenState( + access_token=api_response.get("access_token"), + expiry_time=datetime.now() + timedelta( + seconds=int(api_response.get("expires_in")) + ), + expiry_buffer=( + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC + + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC + ), ) self._telemetry.metrics.credentialsRequest( attributes={ @@ -169,4 +178,4 @@ async def get_authentication_header(self, client): async with self._lock: if not self._token_valid(): await self._obtain_token(client) - return {"Authorization": f"Bearer {self._access_token}"} + return {"Authorization": f"Bearer {self._token_state.access_token}"} diff --git a/openfga_sdk/sync/oauth2.py b/openfga_sdk/sync/oauth2.py index 704c0c6..5b29679 100644 --- a/openfga_sdk/sync/oauth2.py +++ b/openfga_sdk/sync/oauth2.py @@ -5,6 +5,7 @@ import threading import time +from dataclasses import dataclass from datetime import datetime, timedelta import urllib3 @@ -21,6 +22,13 @@ from openfga_sdk.telemetry.telemetry import Telemetry +@dataclass(frozen=True) +class _TokenState: + access_token: str + expiry_time: datetime + expiry_buffer: float + + def jitter(loop_count, min_wait_in_ms): """ Generate a random jitter value for exponential backoff @@ -39,9 +47,7 @@ def jitter(loop_count, min_wait_in_ms): class OAuth2Client: def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials - self._access_token = None - self._access_expiry_time = None - self._access_token_expiry_buffer = 0 + self._token_state: _TokenState | None = None self._lock = threading.Lock() self._telemetry = Telemetry() @@ -54,10 +60,11 @@ def _token_valid(self): """ Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens) """ - if self._access_token is None or self._access_expiry_time is None: + state = self._token_state # atomic snapshot — either old or new, never torn + if state is None: return False - remaining = (self._access_expiry_time - datetime.now()).total_seconds() - return remaining > self._access_token_expiry_buffer + remaining = (state.expiry_time - datetime.now()).total_seconds() + return remaining > state.expiry_buffer def _obtain_token(self, client): """ @@ -144,13 +151,15 @@ def _obtain_token(self, client): raise AuthenticationError(http_resp=raw_response) if api_response.get("expires_in") and api_response.get("access_token"): - self._access_expiry_time = datetime.now() + timedelta( - seconds=int(api_response.get("expires_in")) - ) - self._access_token = api_response.get("access_token") - self._access_token_expiry_buffer = ( - TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC - + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC + self._token_state = _TokenState( + access_token=api_response.get("access_token"), + expiry_time=datetime.now() + timedelta( + seconds=int(api_response.get("expires_in")) + ), + expiry_buffer=( + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC + + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC + ), ) self._telemetry.metrics.credentialsRequest( attributes={ @@ -170,4 +179,4 @@ def get_authentication_header(self, client): with self._lock: if not self._token_valid(): self._obtain_token(client) - return {"Authorization": f"Bearer {self._access_token}"} + return {"Authorization": f"Bearer {self._token_state.access_token}"} diff --git a/test/oauth2_test.py b/test/oauth2_test.py index 2cc28e8..b4e061f 100644 --- a/test/oauth2_test.py +++ b/test/oauth2_test.py @@ -11,7 +11,7 @@ from openfga_sdk.constants import TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, USER_AGENT from openfga_sdk.credentials import CredentialConfiguration, Credentials from openfga_sdk.exceptions import AuthenticationError -from openfga_sdk.oauth2 import OAuth2Client +from openfga_sdk.oauth2 import OAuth2Client, _TokenState # Helper function to construct mock response @@ -35,8 +35,11 @@ async def test_get_authentication_valid_client_credentials(self): Test getting authentication header when method is client credentials """ client = OAuth2Client(None) - client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() + timedelta(seconds=3600) + client._token_state = _TokenState( + access_token="XYZ123", + expiry_time=datetime.now() + timedelta(seconds=3600), + expiry_buffer=0, + ) auth_header = await client.get_authentication_header(None) self.assertEqual(auth_header, {"Authorization": "Bearer XYZ123"}) @@ -67,9 +70,9 @@ async def test_get_authentication_obtain_client_credentials(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -151,8 +154,11 @@ async def test_get_authentication_obtain_with_expired_client_credentials_failed( rest_client = rest.RESTClientObject(Configuration()) client = OAuth2Client(credentials) - client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() - timedelta(seconds=240) + client._token_state = _TokenState( + access_token="XYZ123", + expiry_time=datetime.now() - timedelta(seconds=240), + expiry_buffer=0, + ) with self.assertRaises(AuthenticationError): await client.get_authentication_header(rest_client) @@ -293,9 +299,9 @@ async def test_get_authentication_keep_full_url(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -348,9 +354,9 @@ async def test_get_authentication_add_scheme(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -403,9 +409,9 @@ async def test_get_authentication_add_path(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -458,9 +464,9 @@ async def test_get_authentication_add_scheme_and_path(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -516,9 +522,9 @@ async def test_get_authentication_obtain_client_credentials_with_scopes_list( client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -575,9 +581,9 @@ async def test_get_authentication_obtain_client_credentials_with_scopes_string( client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -718,9 +724,11 @@ async def test_concurrent_requests_only_fetch_token_once(self): async def mock_obtain_token(client): obtain_calls.append(1) await asyncio.sleep(0) # yield so other coroutines reach the lock - oauth_client._access_token = "concurrent-token" - oauth_client._access_expiry_time = datetime.now() + timedelta(seconds=3600) - oauth_client._access_token_expiry_buffer = 300 + oauth_client._token_state = _TokenState( + access_token="concurrent-token", + expiry_time=datetime.now() + timedelta(seconds=3600), + expiry_buffer=300, + ) with patch.object(oauth_client, "_obtain_token", side_effect=mock_obtain_token): results = await asyncio.gather( diff --git a/test/sync/oauth2_test.py b/test/sync/oauth2_test.py index 30a04c4..3d964b0 100644 --- a/test/sync/oauth2_test.py +++ b/test/sync/oauth2_test.py @@ -12,7 +12,7 @@ from openfga_sdk.credentials import CredentialConfiguration, Credentials from openfga_sdk.exceptions import AuthenticationError from openfga_sdk.sync import rest -from openfga_sdk.sync.oauth2 import OAuth2Client +from openfga_sdk.sync.oauth2 import OAuth2Client, _TokenState # Helper function to construct mock response @@ -36,8 +36,11 @@ def test_get_authentication_valid_client_credentials(self): Test getting authentication header when method is client credentials """ client = OAuth2Client(None) - client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() + timedelta(seconds=3600) + client._token_state = _TokenState( + access_token="XYZ123", + expiry_time=datetime.now() + timedelta(seconds=3600), + expiry_buffer=0, + ) auth_header = client.get_authentication_header(None) self.assertEqual(auth_header, {"Authorization": "Bearer XYZ123"}) @@ -68,9 +71,9 @@ def test_get_authentication_obtain_client_credentials(self, mock_request): client = OAuth2Client(credentials) auth_header = client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -126,9 +129,9 @@ def test_get_authentication_obtain_client_credentials_with_scopes_list( client = OAuth2Client(credentials) auth_header = client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -185,9 +188,9 @@ def test_get_authentication_obtain_client_credentials_with_scopes_string( client = OAuth2Client(credentials) auth_header = client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -268,8 +271,11 @@ async def test_get_authentication_obtain_with_expired_client_credentials_failed( rest_client = rest.RESTClientObject(Configuration()) client = OAuth2Client(credentials) - client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() - timedelta(seconds=240) + client._token_state = _TokenState( + access_token="XYZ123", + expiry_time=datetime.now() - timedelta(seconds=240), + expiry_buffer=0, + ) with self.assertRaises(AuthenticationError): client.get_authentication_header(rest_client) @@ -495,9 +501,11 @@ def test_concurrent_requests_only_fetch_token_once(self): def mock_obtain_token(client): obtain_calls.append(1) time.sleep(0.05) # hold the lock briefly so other threads queue up - oauth_client._access_token = "concurrent-token" - oauth_client._access_expiry_time = datetime.now() + timedelta(seconds=3600) - oauth_client._access_token_expiry_buffer = 300 + oauth_client._token_state = _TokenState( + access_token="concurrent-token", + expiry_time=datetime.now() + timedelta(seconds=3600), + expiry_buffer=300, + ) results = [] From aebd088795db4abe0ea4167a91614883f3e41d02 Mon Sep 17 00:00:00 2001 From: Anurag Bandyopadhyay Date: Tue, 28 Apr 2026 22:03:54 +0530 Subject: [PATCH 6/7] fix: ruff? --- openfga_sdk/oauth2.py | 5 ++--- openfga_sdk/sync/oauth2.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/openfga_sdk/oauth2.py b/openfga_sdk/oauth2.py index 9d1a72c..0af778a 100644 --- a/openfga_sdk/oauth2.py +++ b/openfga_sdk/oauth2.py @@ -152,9 +152,8 @@ async def _obtain_token(self, client): if api_response.get("expires_in") and api_response.get("access_token"): self._token_state = _TokenState( access_token=api_response.get("access_token"), - expiry_time=datetime.now() + timedelta( - seconds=int(api_response.get("expires_in")) - ), + expiry_time=datetime.now() + + timedelta(seconds=int(api_response.get("expires_in"))), expiry_buffer=( TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC diff --git a/openfga_sdk/sync/oauth2.py b/openfga_sdk/sync/oauth2.py index 5b29679..7e42437 100644 --- a/openfga_sdk/sync/oauth2.py +++ b/openfga_sdk/sync/oauth2.py @@ -153,9 +153,8 @@ def _obtain_token(self, client): if api_response.get("expires_in") and api_response.get("access_token"): self._token_state = _TokenState( access_token=api_response.get("access_token"), - expiry_time=datetime.now() + timedelta( - seconds=int(api_response.get("expires_in")) - ), + expiry_time=datetime.now() + + timedelta(seconds=int(api_response.get("expires_in"))), expiry_buffer=( TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC From 858134f02aeb0179129f526b954aae40c524f82f Mon Sep 17 00:00:00 2001 From: SoulPancake Date: Sat, 2 May 2026 00:26:44 +0530 Subject: [PATCH 7/7] feat: restructure and extract common class and func --- openfga_sdk/oauth2.py | 26 +------------------------- openfga_sdk/oauth2_common.py | 28 ++++++++++++++++++++++++++++ openfga_sdk/sync/oauth2.py | 26 +------------------------- test/oauth2_test.py | 3 ++- test/sync/oauth2_test.py | 3 ++- 5 files changed, 34 insertions(+), 52 deletions(-) create mode 100644 openfga_sdk/oauth2_common.py diff --git a/openfga_sdk/oauth2.py b/openfga_sdk/oauth2.py index 0af778a..de00961 100644 --- a/openfga_sdk/oauth2.py +++ b/openfga_sdk/oauth2.py @@ -1,10 +1,7 @@ import asyncio import json -import math import random -import sys -from dataclasses import dataclass from datetime import datetime, timedelta import urllib3 @@ -17,32 +14,11 @@ ) from openfga_sdk.credentials import Credentials from openfga_sdk.exceptions import AuthenticationError +from openfga_sdk.oauth2_common import _TokenState, jitter from openfga_sdk.telemetry.attributes import TelemetryAttributes from openfga_sdk.telemetry.telemetry import Telemetry -@dataclass(frozen=True) -class _TokenState: - access_token: str - expiry_time: datetime - expiry_buffer: float - - -def jitter(loop_count, min_wait_in_ms): - """ - Generate a random jitter value for exponential backoff - """ - minimum = math.ceil(2**loop_count * min_wait_in_ms) - maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms) - jitter = random.randrange(minimum, maximum) / 1000 - - # If running in pytest, set jitter to 0 to speed up tests - if "pytest" in sys.modules: - jitter = 0 - - return jitter - - class OAuth2Client: def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials diff --git a/openfga_sdk/oauth2_common.py b/openfga_sdk/oauth2_common.py new file mode 100644 index 0000000..71562a6 --- /dev/null +++ b/openfga_sdk/oauth2_common.py @@ -0,0 +1,28 @@ +import math +import random +import sys + +from dataclasses import dataclass +from datetime import datetime + + +@dataclass(frozen=True) +class _TokenState: + access_token: str + expiry_time: datetime + expiry_buffer: float + + +def jitter(loop_count, min_wait_in_ms): + """ + Generate a random jitter value for exponential backoff + """ + minimum = math.ceil(2**loop_count * min_wait_in_ms) + maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms) + jitter = random.randrange(minimum, maximum) / 1000 + + # If running in pytest, set jitter to 0 to speed up tests + if "pytest" in sys.modules: + jitter = 0 + + return jitter diff --git a/openfga_sdk/sync/oauth2.py b/openfga_sdk/sync/oauth2.py index 7e42437..d23cc93 100644 --- a/openfga_sdk/sync/oauth2.py +++ b/openfga_sdk/sync/oauth2.py @@ -1,11 +1,8 @@ import json -import math import random -import sys import threading import time -from dataclasses import dataclass from datetime import datetime, timedelta import urllib3 @@ -18,32 +15,11 @@ ) from openfga_sdk.credentials import Credentials from openfga_sdk.exceptions import AuthenticationError +from openfga_sdk.oauth2_common import _TokenState, jitter from openfga_sdk.telemetry.attributes import TelemetryAttributes from openfga_sdk.telemetry.telemetry import Telemetry -@dataclass(frozen=True) -class _TokenState: - access_token: str - expiry_time: datetime - expiry_buffer: float - - -def jitter(loop_count, min_wait_in_ms): - """ - Generate a random jitter value for exponential backoff - """ - minimum = math.ceil(2**loop_count * min_wait_in_ms) - maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms) - jitter = random.randrange(minimum, maximum) / 1000 - - # If running in pytest, set jitter to 0 to speed up tests - if "pytest" in sys.modules: - jitter = 0 - - return jitter - - class OAuth2Client: def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials diff --git a/test/oauth2_test.py b/test/oauth2_test.py index b4e061f..ffa3e5a 100644 --- a/test/oauth2_test.py +++ b/test/oauth2_test.py @@ -11,7 +11,8 @@ from openfga_sdk.constants import TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, USER_AGENT from openfga_sdk.credentials import CredentialConfiguration, Credentials from openfga_sdk.exceptions import AuthenticationError -from openfga_sdk.oauth2 import OAuth2Client, _TokenState +from openfga_sdk.oauth2 import OAuth2Client +from openfga_sdk.oauth2_common import _TokenState # Helper function to construct mock response diff --git a/test/sync/oauth2_test.py b/test/sync/oauth2_test.py index 3d964b0..40109fa 100644 --- a/test/sync/oauth2_test.py +++ b/test/sync/oauth2_test.py @@ -11,8 +11,9 @@ from openfga_sdk.constants import TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, USER_AGENT from openfga_sdk.credentials import CredentialConfiguration, Credentials from openfga_sdk.exceptions import AuthenticationError +from openfga_sdk.oauth2_common import _TokenState from openfga_sdk.sync import rest -from openfga_sdk.sync.oauth2 import OAuth2Client, _TokenState +from openfga_sdk.sync.oauth2 import OAuth2Client # Helper function to construct mock response