Skip to content

Commit 599bcf5

Browse files
SK-2777: update client re-init logic
1 parent ad095e4 commit 599bcf5

4 files changed

Lines changed: 36 additions & 34 deletions

File tree

skyflow/error/_skyflow_error.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,4 @@ def __init__(self,
1515
self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value
1616
self.details = details
1717
self.request_id = request_id
18-
log_error(message, http_code, request_id, grpc_code, http_status, details)
1918
super().__init__()

skyflow/utils/_utils.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,26 +30,18 @@
3030
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value
3131

3232
def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None):
33-
dotenv.load_dotenv()
34-
dotenv_path = dotenv.find_dotenv(usecwd=True)
35-
if dotenv_path:
36-
load_dotenv(dotenv_path)
37-
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
3833
if config_level_creds:
3934
return config_level_creds
4035
if common_skyflow_creds:
4136
return common_skyflow_creds
37+
dotenv_path = dotenv.find_dotenv(usecwd=True)
38+
if dotenv_path:
39+
load_dotenv(dotenv_path)
40+
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
4241
if env_skyflow_credentials:
43-
env_skyflow_credentials.strip()
44-
try:
45-
env_creds = env_skyflow_credentials.replace('\n', '\\n')
46-
return {
47-
'credentials_string': env_creds
48-
}
49-
except json.JSONDecodeError:
50-
raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code)
51-
else:
52-
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
42+
env_creds = env_skyflow_credentials.strip().replace('\n', '\\n')
43+
return {'credentials_string': env_creds}
44+
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
5345

5446
def validate_api_key(api_key: str, logger = None) -> bool:
5547
if len(api_key) != 42:

skyflow/utils/validations/_validations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non
122122
)
123123
if is_expired(credentials.get("token"), logger):
124124
raise SkyflowError(
125-
SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id)
126-
if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value,
125+
SkyflowMessages.Error.EXPIRED_TOKEN.value
126+
if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_TOKEN.value,
127127
invalid_input_error_code
128128
)
129129
elif "api_key" in credentials:
@@ -389,7 +389,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest):
389389
if hasattr(request, 'wait_time') and request.wait_time is not None:
390390
if not isinstance(request.wait_time, (int, float)):
391391
raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code)
392-
if request.wait_time < 0 and request.wait_time > 64:
392+
if request.wait_time < 0 or request.wait_time > 64:
393393
raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code)
394394

395395
def validate_insert_request(logger, request):

skyflow/vault/client/client.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ def __init__(self, config):
1414
self.__logger = None
1515
self.__is_config_updated = False
1616
self.__bearer_token = None
17+
self.__credentials = None
18+
self.__vault_url = None
19+
self.__is_static_token = None
1720

1821
def set_common_skyflow_credentials(self, credentials):
1922
self.__common_skyflow_credentials = credentials
@@ -23,16 +26,29 @@ def set_logger(self, log_level, logger):
2326
self.__logger = logger
2427

2528
def initialize_client_configuration(self):
26-
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
27-
token = self.get_bearer_token(credentials)
28-
vault_url = get_vault_url(self.__config.get("cluster_id"),
29-
self.__config.get("env"),
30-
self.__config.get("vault_id"),
31-
logger = self.__logger)
32-
self.initialize_api_client(vault_url, token)
29+
if self.__api_client is not None and not self.__is_config_updated:
30+
if self.__is_static_token:
31+
return
32+
if self.__bearer_token is not None and not is_expired(self.__bearer_token):
33+
return
34+
35+
needs_reinit = self.__api_client is None or self.__is_config_updated
36+
if needs_reinit:
37+
self.__credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger=self.__logger)
38+
self.__vault_url = get_vault_url(self.__config.get("cluster_id"),
39+
self.__config.get("env"),
40+
self.__config.get("vault_id"),
41+
logger=self.__logger)
42+
self.__is_static_token = 'token' in self.__credentials or 'api_key' in self.__credentials
43+
token = self.get_bearer_token(self.__credentials)
44+
if needs_reinit:
45+
self.initialize_api_client(self.__vault_url, token)
3346

3447
def initialize_api_client(self, vault_url, token):
35-
self.__api_client = Skyflow(base_url=vault_url, token=token)
48+
self.__api_client = Skyflow(
49+
base_url=vault_url,
50+
token=lambda: self.__bearer_token if self.__bearer_token else token,
51+
)
3652

3753
def get_records_api(self):
3854
return self.__api_client.records
@@ -63,11 +79,10 @@ def get_bearer_token(self, credentials):
6379
"ctx": self.__config.get("ctx")
6480
}
6581

66-
if self.__bearer_token is None or self.__is_config_updated:
82+
if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token):
6783
if 'path' in credentials:
68-
path = credentials.get("path")
6984
self.__bearer_token, _ = generate_bearer_token(
70-
path,
85+
credentials.get("path"),
7186
options,
7287
self.__logger
7388
)
@@ -83,10 +98,6 @@ def get_bearer_token(self, credentials):
8398
else:
8499
log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)
85100

86-
if is_expired(self.__bearer_token):
87-
self.__is_config_updated = True
88-
raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
89-
90101
return self.__bearer_token
91102

92103
def update_config(self, config):

0 commit comments

Comments
 (0)