@@ -14,6 +14,9 @@ def __init__(self, config):
1414 self .__logger = None
1515 self .__is_config_updated = False
1616 self .__bearer_token = None
17+ self .__credentials = None
18+ self .__vault_url = None
19+ self .__is_static_token = None
1720
1821 def set_common_skyflow_credentials (self , credentials ):
1922 self .__common_skyflow_credentials = credentials
@@ -23,16 +26,29 @@ def set_logger(self, log_level, logger):
2326 self .__logger = logger
2427
2528 def initialize_client_configuration (self ):
26- credentials = get_credentials (self .__config .get ("credentials" ), self .__common_skyflow_credentials , logger = self .__logger )
27- token = self .get_bearer_token (credentials )
28- vault_url = get_vault_url (self .__config .get ("cluster_id" ),
29- self .__config .get ("env" ),
30- self .__config .get ("vault_id" ),
31- logger = self .__logger )
32- self .initialize_api_client (vault_url , token )
29+ if self .__api_client is not None and not self .__is_config_updated :
30+ if self .__is_static_token :
31+ return
32+ if self .__bearer_token is not None and not is_expired (self .__bearer_token ):
33+ return
34+
35+ needs_reinit = self .__api_client is None or self .__is_config_updated
36+ if needs_reinit :
37+ self .__credentials = get_credentials (self .__config .get ("credentials" ), self .__common_skyflow_credentials , logger = self .__logger )
38+ self .__vault_url = get_vault_url (self .__config .get ("cluster_id" ),
39+ self .__config .get ("env" ),
40+ self .__config .get ("vault_id" ),
41+ logger = self .__logger )
42+ self .__is_static_token = 'token' in self .__credentials or 'api_key' in self .__credentials
43+ token = self .get_bearer_token (self .__credentials )
44+ if needs_reinit :
45+ self .initialize_api_client (self .__vault_url , token )
3346
3447 def initialize_api_client (self , vault_url , token ):
35- self .__api_client = Skyflow (base_url = vault_url , token = token )
48+ self .__api_client = Skyflow (
49+ base_url = vault_url ,
50+ token = lambda : self .__bearer_token if self .__bearer_token else token ,
51+ )
3652
3753 def get_records_api (self ):
3854 return self .__api_client .records
@@ -63,11 +79,10 @@ def get_bearer_token(self, credentials):
6379 "ctx" : self .__config .get ("ctx" )
6480 }
6581
66- if self .__bearer_token is None or self .__is_config_updated :
82+ if self .__bearer_token is None or self .__is_config_updated or is_expired ( self . __bearer_token ) :
6783 if 'path' in credentials :
68- path = credentials .get ("path" )
6984 self .__bearer_token , _ = generate_bearer_token (
70- path ,
85+ credentials . get ( " path" ) ,
7186 options ,
7287 self .__logger
7388 )
@@ -83,10 +98,6 @@ def get_bearer_token(self, credentials):
8398 else :
8499 log_info (SkyflowMessages .Info .REUSE_BEARER_TOKEN .value , self .__logger )
85100
86- if is_expired (self .__bearer_token ):
87- self .__is_config_updated = True
88- raise SyntaxError (SkyflowMessages .Error .EXPIRED_TOKEN .value , SkyflowMessages .ErrorCodes .INVALID_INPUT .value )
89-
90101 return self .__bearer_token
91102
92103 def update_config (self , config ):
0 commit comments