diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py index 8871f6727f..c5efd78b4d 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py @@ -15,14 +15,17 @@ import dataclasses import json +import logging import io import sys import hashlib +import hmac import pickle +import secrets -from typing import Any, Callable, Union +from typing import Any, Callable, Union, Optional import cloudpickle from tblib import pickling_support @@ -38,6 +41,8 @@ # Note: do not use os.path.join for s3 uris, fails on windows +logger = logging.getLogger(__name__) + def _get_python_version(): """Returns the current python version.""" @@ -49,6 +54,7 @@ class _MetaData: """Metadata about the serialized data or functions.""" sha256_hash: str + secret_arn: Optional[str] = None # ARN to AWS Secrets Manager secret containing HMAC key version: str = "2023-04-24" python_version: str = _get_python_version() serialization_module: str = "cloudpickle" @@ -66,7 +72,8 @@ def from_json(s): raise DeserializationError("Corrupt metadata file. It is not a valid json file.") sha256_hash = obj.get("sha256_hash") - metadata = _MetaData(sha256_hash=sha256_hash) + secret_arn = obj.get("secret_arn") # May be None for legacy format + metadata = _MetaData(sha256_hash=sha256_hash, secret_arn=secret_arn) metadata.version = obj.get("version") metadata.python_version = obj.get("python_version") metadata.serialization_module = obj.get("serialization_module") @@ -155,16 +162,21 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any: # TODO: use dask serializer in case dask distributed is installed in users' environment. def serialize_func_to_s3( - func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None + func: Callable, + sagemaker_session: Session, + s3_uri: str, + job_name: str, + s3_kms_key: str = None ): """Serializes function and uploads it to S3. Args: + func: function to be serialized and persisted sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - func: function to be serialized and persisted Raises: SerializationError: when fail to serialize function to bytes. """ @@ -173,6 +185,7 @@ def serialize_func_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(func), s3_uri=s3_uri, sagemaker_session=sagemaker_session, + job_name=job_name, s3_kms_key=s3_kms_key, ) @@ -199,23 +212,32 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callabl bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, + buffer=bytes_to_deserialize, + sagemaker_session=sagemaker_session, + secret_arn=metadata.secret_arn, + s3_uri=s3_uri ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_obj_to_s3( - obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None + obj: Any, + sagemaker_session: Session, + s3_uri: str, + job_name: str, + s3_kms_key: str = None ): """Serializes data object and uploads it to S3. Args: + obj: object to be serialized and persisted sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. """ @@ -224,6 +246,7 @@ def serialize_obj_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(obj), s3_uri=s3_uri, sagemaker_session=sagemaker_session, + job_name=job_name, s3_kms_key=s3_kms_key, ) @@ -288,23 +311,32 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any: bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, + buffer=bytes_to_deserialize, + sagemaker_session=sagemaker_session, + secret_arn=metadata.secret_arn, + s3_uri=s3_uri ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) def serialize_exception_to_s3( - exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None + exc: Exception, + sagemaker_session: Session, + s3_uri: str, + job_name: str, + s3_kms_key: str = None ): """Serializes exception with traceback and uploads it to S3. Args: + exc: Exception to be serialized and persisted sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. - exc: Exception to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. """ @@ -314,6 +346,7 @@ def serialize_exception_to_s3( bytes_to_upload=CloudpickleSerializer.serialize(exc), s3_uri=s3_uri, sagemaker_session=sagemaker_session, + job_name=job_name, s3_kms_key=s3_kms_key, ) @@ -322,6 +355,7 @@ def _upload_payload_and_metadata_to_s3( bytes_to_upload: Union[bytes, io.BytesIO], s3_uri: str, sagemaker_session: Session, + job_name: str, s3_kms_key, ): """Uploads serialized payload and metadata to s3. @@ -331,14 +365,22 @@ def _upload_payload_and_metadata_to_s3( s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. sagemaker_session (sagemaker.core.helper.session.Session): The underlying Boto3 session which AWS service calls are delegated to. + job_name (str): Remote function job name for secret management s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. """ _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) - sha256_hash = _compute_hash(bytes_to_upload) + # Get or create HMAC secret in Secrets Manager + secret_arn, hmac_key = _get_or_create_hmac_secret(sagemaker_session, job_name) + + # Compute HMAC-SHA256 hash + sha256_hash = _compute_hmac(bytes_to_upload, hmac_key) + + # Store secret ARN in Parameter Store as trust anchor (Mitigation #3) + _store_secret_arn_in_parameter_store(sagemaker_session, job_name, secret_arn) _upload_bytes_to_s3( - _MetaData(sha256_hash).to_json(), + _MetaData(sha256_hash=sha256_hash, secret_arn=secret_arn).to_json(), f"{s3_uri}/metadata.json", s3_kms_key, sagemaker_session, @@ -365,7 +407,11 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> An bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session) _perform_integrity_check( - expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize + expected_hash_value=metadata.sha256_hash, + buffer=bytes_to_deserialize, + sagemaker_session=sagemaker_session, + secret_arn=metadata.secret_arn, + s3_uri=s3_uri ) return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize) @@ -396,15 +442,252 @@ def _compute_hash(buffer: bytes) -> str: return hashlib.sha256(buffer).hexdigest() -def _perform_integrity_check(expected_hash_value: str, buffer: bytes): +def _get_or_create_hmac_secret(sagemaker_session: Session, job_name: str) -> tuple[str, str]: + """Get or create HMAC key in AWS Secrets Manager. + + Args: + sagemaker_session: SageMaker session + job_name: Remote function job name + + Returns: + Tuple of (secret_arn, hmac_key) + """ + secret_name = f"sagemaker/remote-function/{job_name}/hmac-key" + secrets_client = sagemaker_session.boto_session.client('secretsmanager') + + try: + # Try to retrieve existing secret + response = secrets_client.get_secret_value(SecretId=secret_name) + return response['ARN'], response['SecretString'] + except secrets_client.exceptions.ResourceNotFoundException: + # Create new secret + hmac_key = secrets.token_hex(32) + + response = secrets_client.create_secret( + Name=secret_name, + SecretString=hmac_key, + Description=f"HMAC key for SageMaker remote function job {job_name}", + Tags=[ + {'Key': 'SageMaker:JobName', 'Value': job_name}, + {'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'} + ] + ) + return response['ARN'], hmac_key + + +def _get_hmac_key_from_secret(sagemaker_session: Session, secret_arn: str) -> str: + """Retrieve HMAC key from AWS Secrets Manager. + + Args: + sagemaker_session: SageMaker session + secret_arn: ARN of the secret containing HMAC key + + Returns: + HMAC key string + """ + secrets_client = sagemaker_session.boto_session.client('secretsmanager') + response = secrets_client.get_secret_value(SecretId=secret_arn) + return response['SecretString'] + + +def _compute_hmac(buffer: bytes, hmac_key: str) -> str: + """Compute HMAC-SHA256 hash. + + Args: + buffer: Data to hash + hmac_key: HMAC secret key + + Returns: + HMAC-SHA256 hex digest + """ + return hmac.new(hmac_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest() + + +def _store_secret_arn_in_parameter_store( + sagemaker_session: Session, + job_name: str, + secret_arn: str +): + """Store secret ARN in Parameter Store as trust anchor. + + Args: + sagemaker_session: SageMaker session + job_name: Remote function job name + secret_arn: ARN of the secret to store + """ + ssm_client = sagemaker_session.boto_session.client('ssm') + parameter_name = f"/sagemaker/remote-function/{job_name}/secret-arn" + + ssm_client.put_parameter( + Name=parameter_name, + Value=secret_arn, + Type="String", + Overwrite=True, + Description=f"Secret ARN for SageMaker remote function job {job_name}", + Tags=[ + {'Key': 'SageMaker:JobName', 'Value': job_name}, + {'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'} + ] + ) + + +def _get_secret_arn_from_parameter_store( + sagemaker_session: Session, + job_name: str +) -> str: + """Retrieve secret ARN from Parameter Store. + + Args: + sagemaker_session: SageMaker session + job_name: Remote function job name + + Returns: + Secret ARN string + + Raises: + DeserializationError: If parameter not found + """ + ssm_client = sagemaker_session.boto_session.client('ssm') + parameter_name = f"/sagemaker/remote-function/{job_name}/secret-arn" + + try: + response = ssm_client.get_parameter(Name=parameter_name) + return response['Parameter']['Value'] + except ssm_client.exceptions.ParameterNotFound: + raise DeserializationError( + f"Secret ARN not found in Parameter Store for job {job_name}. " + "This may indicate the job was not properly initialized or artifacts were tampered with." + ) + + +def _extract_job_name_from_s3_uri(s3_uri: str) -> str: + """Extract job name from S3 URI. + + S3 URI format: s3://bucket/path/to/job-name/results + or: s3://bucket/job-name/function + + Args: + s3_uri: S3 URI containing job name + + Returns: + Job name extracted from URI + """ + # Remove s3:// prefix and split by / + parts = s3_uri.replace("s3://", "").split("/") + + # Try to find a part that looks like a job name + # Job names typically contain execution IDs or timestamps + for part in reversed(parts): + if part and part not in ['function', 'arguments', 'results', 'exception', 'payload.pkl', 'metadata.json']: + return part + + # Fallback: use the last meaningful part + return parts[-2] if len(parts) > 1 else parts[0] + + +def _validate_secret_arn( + sagemaker_session: Session, + metadata_secret_arn: str, + job_name: str +): + """Validate secret ARN from metadata against trusted sources. + + Implements two mitigations: + 1. Validate secret is in same AWS account + 2. Validate secret ARN matches Parameter Store (trust anchor) + + Args: + sagemaker_session: SageMaker session + metadata_secret_arn: Secret ARN from S3 metadata (untrusted) + job_name: Remote function job name + + Raises: + DeserializationError: If validation fails + """ + # Mitigation #1: Validate same account + sts_client = sagemaker_session.boto_session.client('sts') + current_account_id = sts_client.get_caller_identity()['Account'] + + # Parse account ID from ARN: arn:aws:secretsmanager:region:ACCOUNT_ID:secret:name + arn_parts = metadata_secret_arn.split(":") + if len(arn_parts) < 5: + raise DeserializationError(f"Invalid secret ARN format: {metadata_secret_arn}") + + metadata_account_id = arn_parts[4] + + if metadata_account_id != current_account_id: + raise DeserializationError( + f"Secret must be in the same AWS account. " + f"Expected account {current_account_id}, but got {metadata_account_id}. " + "This may indicate a cross-account attack attempt." + ) + + # Mitigation #3: Validate against Parameter Store (trust anchor) + expected_secret_arn = _get_secret_arn_from_parameter_store(sagemaker_session, job_name) + + if metadata_secret_arn != expected_secret_arn: + raise DeserializationError( + f"Secret ARN mismatch. Expected: {expected_secret_arn}, " + f"Got: {metadata_secret_arn}. " + "Possible tampering detected - metadata may have been modified." + ) + + +def _perform_integrity_check( + expected_hash_value: str, + buffer: bytes, + sagemaker_session: Optional[Session] = None, + secret_arn: Optional[str] = None, + s3_uri: Optional[str] = None +): """Performs integrity checks for serialized code/arguments uploaded to s3. Verifies whether the hash read from s3 matches the hash calculated during remote function execution. + + Args: + expected_hash_value: Expected hash value from metadata + buffer: Serialized data buffer + sagemaker_session: SageMaker session (required if secret_arn is provided) + secret_arn: ARN of secret containing HMAC key (None for legacy plain SHA-256) + s3_uri: S3 URI for extracting job name (required if secret_arn is provided) """ - actual_hash_value = _compute_hash(buffer=buffer) - if expected_hash_value != actual_hash_value: - raise DeserializationError( - "Integrity check for the serialized function or data failed. " - "Please restrict access to your S3 bucket" + if secret_arn: + # New secure method: HMAC with key from Secrets Manager + if not sagemaker_session: + raise DeserializationError( + "sagemaker_session is required for HMAC integrity check" + ) + + if not s3_uri: + raise DeserializationError( + "s3_uri is required for HMAC integrity check to extract job name" + ) + + # Extract job name from S3 URI + job_name = _extract_job_name_from_s3_uri(s3_uri) + + # Validate secret ARN (Mitigations #1 and #3) + _validate_secret_arn(sagemaker_session, secret_arn, job_name) + + # Now safe to retrieve HMAC key + hmac_key = _get_hmac_key_from_secret(sagemaker_session, secret_arn) + actual_hash_value = _compute_hmac(buffer, hmac_key) + + if not hmac.compare_digest(expected_hash_value, actual_hash_value): + raise DeserializationError( + "HMAC integrity check failed. Serialized data may have been tampered with. " + "Please restrict access to your S3 bucket." + ) + else: + # Legacy method: plain SHA-256 (backward compatibility) + logger.warning( + "Using legacy SHA-256 integrity check without HMAC authentication. " + "This provides weaker security guarantees. Please upgrade to the latest SDK version." ) + actual_hash_value = _compute_hash(buffer) + if expected_hash_value != actual_hash_value: + raise DeserializationError( + "Integrity check for the serialized function or data failed. " + "Please restrict access to your S3 bucket" + ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py index c7ee86f8a7..1a45c378f4 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py @@ -57,6 +57,7 @@ def __init__( s3_base_uri: str, s3_kms_key: str = None, context: Context = Context(), + job_name: str = None, ): """Construct a StoredFunction object. @@ -66,11 +67,13 @@ def __init__( s3_base_uri: the base uri to which serialized artifacts will be uploaded. s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. context: Build or run context of a pipeline step. + job_name: Remote function job name for secret management. """ self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key self.context = context + self.job_name = job_name or os.environ.get("TRAINING_JOB_NAME") # For pipeline steps, function code is at: base/step_name/build_timestamp/ # For results, path is: base/step_name/build_timestamp/execution_id/ @@ -110,6 +113,7 @@ def save(self, func, *args, **kwargs): func=func, sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -123,7 +127,7 @@ def save(self, func, *args, **kwargs): obj=(args, kwargs), sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), - + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -144,6 +148,7 @@ def save_pipeline_step_function(self, serialized_data): s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), sagemaker_session=self.sagemaker_session, + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -156,6 +161,7 @@ def save_pipeline_step_function(self, serialized_data): s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), sagemaker_session=self.sagemaker_session, + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) @@ -203,7 +209,7 @@ def load_and_invoke(self) -> Any: obj=result, sagemaker_session=self.sagemaker_session, s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER), - + job_name=self.job_name, s3_kms_key=self.s3_kms_key, ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/errors.py b/sagemaker-core/src/sagemaker/core/remote_function/errors.py index 3f391570cf..6315c1c527 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/errors.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/errors.py @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg): f.write(failure_msg) -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, job_name=None) -> int: """Handle all exceptions raised during remote function execution. Args: @@ -79,6 +79,7 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: AWS service calls are delegated to. s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + job_name (str): Remote function job name for secret management. Returns : exit_code (int): Exit code to terminate current job. """ @@ -96,6 +97,7 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: exc=error, sagemaker_session=sagemaker_session, s3_uri=s3_path_join(s3_base_uri, "exception"), + job_name=job_name or os.environ.get("TRAINING_JOB_NAME"), s3_kms_key=s3_kms_key, ) diff --git a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py index 2e69f4f116..c43978f687 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/invoke_function.py @@ -108,6 +108,7 @@ def _execute_remote_function( s3_base_uri=s3_base_uri, s3_kms_key=s3_kms_key, context=context, + job_name=os.environ.get("TRAINING_JOB_NAME"), ) if run_in_context: diff --git a/sagemaker-core/src/sagemaker/core/remote_function/job.py b/sagemaker-core/src/sagemaker/core/remote_function/job.py index 6e727d4b9c..b6ac5572b7 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/job.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/job.py @@ -931,6 +931,7 @@ def compile( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, s3_kms_key=job_settings.s3_kms_key, + job_name=job_name, ) stored_function.save(func, *func_args, **func_kwargs) else: @@ -942,6 +943,7 @@ def compile( step_name=step_compilation_context.step_name, func_step_s3_dir=step_compilation_context.pipeline_build_time, ), + job_name=job_name, ) stored_function.save_pipeline_step_function(serialized_data) diff --git a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py index 4810eba2e0..7bd24489e7 100644 --- a/sagemaker-core/tests/unit/remote_function/test_invoke_function.py +++ b/sagemaker-core/tests/unit/remote_function/test_invoke_function.py @@ -188,6 +188,7 @@ def test_executes_without_run_context(self, mock_stored_function_class): s3_base_uri="s3://bucket/path", s3_kms_key="key-123", context=mock_context, + job_name=None, ) mock_stored_func.load_and_invoke.assert_called_once() diff --git a/sagemaker-core/tests/unit/remote_function/test_serialization_security.py b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py new file mode 100644 index 0000000000..0478617ea1 --- /dev/null +++ b/sagemaker-core/tests/unit/remote_function/test_serialization_security.py @@ -0,0 +1,401 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for serialization security (HMAC + Secrets Manager + Parameter Store).""" +from __future__ import absolute_import + +import hashlib +import hmac as hmac_module +import json +from unittest.mock import Mock, patch, MagicMock + +import pytest + +from sagemaker.core.remote_function.core.serialization import ( + _MetaData, + _compute_hash, + _compute_hmac, + _get_or_create_hmac_secret, + _get_hmac_key_from_secret, + _store_secret_arn_in_parameter_store, + _get_secret_arn_from_parameter_store, + _extract_job_name_from_s3_uri, + _validate_secret_arn, + _perform_integrity_check, + _upload_payload_and_metadata_to_s3, + serialize_obj_to_s3, + deserialize_obj_from_s3, + serialize_func_to_s3, + serialize_exception_to_s3, + deserialize_func_from_s3, + deserialize_exception_from_s3, +) +from sagemaker.core.remote_function.errors import DeserializationError + + +MOCK_JOB_NAME = "test-remote-function-job" +MOCK_SECRET_ARN = "arn:aws:secretsmanager:us-west-2:123456789012:secret:sagemaker/remote-function/test-remote-function-job/hmac-key-AbCdEf" +MOCK_HMAC_KEY = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" +MOCK_ACCOUNT_ID = "123456789012" +MOCK_S3_URI = "s3://my-bucket/remote-function/test-remote-function-job/results" + + +def _mock_sagemaker_session(account_id=MOCK_ACCOUNT_ID): + """Create a mock SageMaker session with Secrets Manager, SSM, and STS clients.""" + session = Mock() + + # Mock Secrets Manager client + secrets_client = Mock() + secrets_client.get_secret_value.return_value = { + "ARN": MOCK_SECRET_ARN, + "SecretString": MOCK_HMAC_KEY, + } + secrets_client.create_secret.return_value = { + "ARN": MOCK_SECRET_ARN, + } + secrets_client.exceptions = Mock() + secrets_client.exceptions.ResourceNotFoundException = type( + "ResourceNotFoundException", (Exception,), {} + ) + + # Mock SSM client + ssm_client = Mock() + ssm_client.get_parameter.return_value = { + "Parameter": {"Value": MOCK_SECRET_ARN} + } + ssm_client.exceptions = Mock() + ssm_client.exceptions.ParameterNotFound = type( + "ParameterNotFound", (Exception,), {} + ) + + # Mock STS client + sts_client = Mock() + sts_client.get_caller_identity.return_value = {"Account": account_id} + + def client_factory(service_name): + if service_name == "secretsmanager": + return secrets_client + elif service_name == "ssm": + return ssm_client + elif service_name == "sts": + return sts_client + return Mock() + + session.boto_session.client = client_factory + return session, secrets_client, ssm_client, sts_client + + +class TestMetaData: + """Tests for _MetaData class.""" + + def test_metadata_with_secret_arn(self): + metadata = _MetaData(sha256_hash="abc123", secret_arn=MOCK_SECRET_ARN) + json_bytes = metadata.to_json() + parsed = _MetaData.from_json(json_bytes) + + assert parsed.sha256_hash == "abc123" + assert parsed.secret_arn == MOCK_SECRET_ARN + + def test_metadata_without_secret_arn_legacy(self): + metadata = _MetaData(sha256_hash="abc123") + json_bytes = metadata.to_json() + parsed = _MetaData.from_json(json_bytes) + + assert parsed.sha256_hash == "abc123" + assert parsed.secret_arn is None + + def test_metadata_missing_hash_raises(self): + with pytest.raises(DeserializationError, match="SHA256 hash"): + _MetaData.from_json(json.dumps({"version": "2023-04-24", "serialization_module": "cloudpickle"})) + + def test_metadata_invalid_json_raises(self): + with pytest.raises(DeserializationError, match="not a valid json"): + _MetaData.from_json(b"not json") + + +class TestComputeHmac: + """Tests for HMAC computation.""" + + def test_compute_hmac(self): + data = b"test data" + key = "test-key" + result = _compute_hmac(data, key) + expected = hmac_module.new(key.encode(), msg=data, digestmod=hashlib.sha256).hexdigest() + assert result == expected + + def test_compute_hmac_different_keys_produce_different_hashes(self): + data = b"test data" + hash1 = _compute_hmac(data, "key1") + hash2 = _compute_hmac(data, "key2") + assert hash1 != hash2 + + def test_compute_hash_plain_sha256(self): + data = b"test data" + result = _compute_hash(data) + expected = hashlib.sha256(data).hexdigest() + assert result == expected + + +class TestGetOrCreateHmacSecret: + """Tests for Secrets Manager integration.""" + + def test_get_existing_secret(self): + session, secrets_client, _, _ = _mock_sagemaker_session() + + arn, key = _get_or_create_hmac_secret(session, MOCK_JOB_NAME) + + assert arn == MOCK_SECRET_ARN + assert key == MOCK_HMAC_KEY + secrets_client.get_secret_value.assert_called_once_with( + SecretId=f"sagemaker/remote-function/{MOCK_JOB_NAME}/hmac-key" + ) + + def test_create_new_secret_when_not_found(self): + session, secrets_client, _, _ = _mock_sagemaker_session() + + # Simulate ResourceNotFoundException + secrets_client.get_secret_value.side_effect = ( + secrets_client.exceptions.ResourceNotFoundException("not found") + ) + + arn, key = _get_or_create_hmac_secret(session, MOCK_JOB_NAME) + + assert arn == MOCK_SECRET_ARN + assert len(key) == 64 # secrets.token_hex(32) produces 64 chars + secrets_client.create_secret.assert_called_once() + + +class TestParameterStore: + """Tests for Parameter Store trust anchor.""" + + def test_store_secret_arn(self): + session, _, ssm_client, _ = _mock_sagemaker_session() + + _store_secret_arn_in_parameter_store(session, MOCK_JOB_NAME, MOCK_SECRET_ARN) + + ssm_client.put_parameter.assert_called_once() + call_kwargs = ssm_client.put_parameter.call_args[1] + assert call_kwargs["Name"] == f"/sagemaker/remote-function/{MOCK_JOB_NAME}/secret-arn" + assert call_kwargs["Value"] == MOCK_SECRET_ARN + assert call_kwargs["Overwrite"] is True + + def test_get_secret_arn(self): + session, _, ssm_client, _ = _mock_sagemaker_session() + + result = _get_secret_arn_from_parameter_store(session, MOCK_JOB_NAME) + + assert result == MOCK_SECRET_ARN + ssm_client.get_parameter.assert_called_once_with( + Name=f"/sagemaker/remote-function/{MOCK_JOB_NAME}/secret-arn" + ) + + def test_get_secret_arn_not_found_raises(self): + session, _, ssm_client, _ = _mock_sagemaker_session() + ssm_client.get_parameter.side_effect = ( + ssm_client.exceptions.ParameterNotFound("not found") + ) + + with pytest.raises(DeserializationError, match="Secret ARN not found"): + _get_secret_arn_from_parameter_store(session, MOCK_JOB_NAME) + + +class TestExtractJobName: + """Tests for S3 URI job name extraction.""" + + def test_extract_from_results_uri(self): + result = _extract_job_name_from_s3_uri( + "s3://bucket/remote-function/my-job-123/results" + ) + assert result == "my-job-123" + + def test_extract_from_function_uri(self): + result = _extract_job_name_from_s3_uri( + "s3://bucket/remote-function/my-job-123/function" + ) + assert result == "my-job-123" + + def test_extract_from_exception_uri(self): + result = _extract_job_name_from_s3_uri( + "s3://bucket/remote-function/my-job-123/exception" + ) + assert result == "my-job-123" + + +class TestValidateSecretArn: + """Tests for secret ARN validation (Mitigations #1 and #3).""" + + def test_valid_secret_arn_passes(self): + """Valid ARN in same account matching Parameter Store should pass.""" + session, _, _, _ = _mock_sagemaker_session() + + # Should not raise + _validate_secret_arn(session, MOCK_SECRET_ARN, MOCK_JOB_NAME) + + def test_cross_account_arn_rejected(self): + """Mitigation #1: Secret ARN from different account should be rejected.""" + session, _, _, _ = _mock_sagemaker_session(account_id=MOCK_ACCOUNT_ID) + + attacker_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:evil-secret" + + with pytest.raises(DeserializationError, match="same AWS account"): + _validate_secret_arn(session, attacker_arn, MOCK_JOB_NAME) + + def test_tampered_arn_rejected(self): + """Mitigation #3: ARN not matching Parameter Store should be rejected.""" + session, _, ssm_client, _ = _mock_sagemaker_session() + + # Parameter Store returns the legitimate ARN + ssm_client.get_parameter.return_value = { + "Parameter": {"Value": MOCK_SECRET_ARN} + } + + # Attacker's ARN (same account but different secret) + tampered_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:attacker-created-secret" + + with pytest.raises(DeserializationError, match="Secret ARN mismatch"): + _validate_secret_arn(session, tampered_arn, MOCK_JOB_NAME) + + def test_invalid_arn_format_rejected(self): + """Malformed ARN should be rejected.""" + session, _, _, _ = _mock_sagemaker_session() + + with pytest.raises(DeserializationError, match="Invalid secret ARN format"): + _validate_secret_arn(session, "not-an-arn", MOCK_JOB_NAME) + + +class TestPerformIntegrityCheck: + """Tests for integrity check with HMAC.""" + + def test_hmac_integrity_check_passes(self): + """Valid HMAC should pass integrity check.""" + session, _, _, _ = _mock_sagemaker_session() + + payload = b"test payload" + expected_hmac = _compute_hmac(payload, MOCK_HMAC_KEY) + + # Should not raise + _perform_integrity_check( + expected_hash_value=expected_hmac, + buffer=payload, + sagemaker_session=session, + secret_arn=MOCK_SECRET_ARN, + s3_uri=MOCK_S3_URI, + ) + + def test_hmac_integrity_check_fails_on_tampered_payload(self): + """Tampered payload should fail HMAC check.""" + session, _, _, _ = _mock_sagemaker_session() + + original_payload = b"original payload" + tampered_payload = b"tampered payload" + expected_hmac = _compute_hmac(original_payload, MOCK_HMAC_KEY) + + with pytest.raises(DeserializationError, match="HMAC integrity check failed"): + _perform_integrity_check( + expected_hash_value=expected_hmac, + buffer=tampered_payload, + sagemaker_session=session, + secret_arn=MOCK_SECRET_ARN, + s3_uri=MOCK_S3_URI, + ) + + def test_legacy_sha256_check_passes_with_warning(self): + """Legacy SHA-256 check should pass with warning when no secret_arn.""" + payload = b"test payload" + expected_hash = _compute_hash(payload) + + # Should not raise (legacy path) + _perform_integrity_check( + expected_hash_value=expected_hash, + buffer=payload, + ) + + def test_legacy_sha256_check_fails_on_tampered_payload(self): + """Legacy SHA-256 check should fail on tampered payload.""" + original_payload = b"original payload" + tampered_payload = b"tampered payload" + expected_hash = _compute_hash(original_payload) + + with pytest.raises(DeserializationError, match="Integrity check"): + _perform_integrity_check( + expected_hash_value=expected_hash, + buffer=tampered_payload, + ) + + def test_hmac_check_requires_session(self): + """HMAC check should require sagemaker_session.""" + with pytest.raises(DeserializationError, match="sagemaker_session is required"): + _perform_integrity_check( + expected_hash_value="hash", + buffer=b"data", + secret_arn=MOCK_SECRET_ARN, + ) + + def test_hmac_check_requires_s3_uri(self): + """HMAC check should require s3_uri.""" + session, _, _, _ = _mock_sagemaker_session() + + with pytest.raises(DeserializationError, match="s3_uri is required"): + _perform_integrity_check( + expected_hash_value="hash", + buffer=b"data", + sagemaker_session=session, + secret_arn=MOCK_SECRET_ARN, + ) + + +class TestAttackScenarios: + """Tests simulating actual attack scenarios.""" + + def test_attacker_replaces_payload_and_metadata_plain_hash(self): + """Attacker replaces both files with plain SHA-256 - should fail HMAC check.""" + session, secrets_client, _, _ = _mock_sagemaker_session() + + # Attacker creates malicious payload + malicious_payload = b"malicious code" + + # Attacker computes plain SHA-256 (not HMAC) + plain_hash = hashlib.sha256(malicious_payload).hexdigest() + + # Attacker's HMAC won't match because they don't know the key + with pytest.raises(DeserializationError, match="HMAC integrity check failed"): + _perform_integrity_check( + expected_hash_value=plain_hash, + buffer=malicious_payload, + sagemaker_session=session, + secret_arn=MOCK_SECRET_ARN, + s3_uri=MOCK_S3_URI, + ) + + def test_attacker_points_to_cross_account_secret(self): + """Attacker points to their own secret in different account - should be rejected.""" + session, _, _, _ = _mock_sagemaker_session() + + attacker_secret_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:attacker-secret" + + with pytest.raises(DeserializationError, match="same AWS account"): + _validate_secret_arn(session, attacker_secret_arn, MOCK_JOB_NAME) + + def test_attacker_creates_secret_in_same_account(self): + """Attacker creates secret in same account but ARN doesn't match Parameter Store.""" + session, _, ssm_client, _ = _mock_sagemaker_session() + + # Parameter Store has the legitimate ARN + ssm_client.get_parameter.return_value = { + "Parameter": {"Value": MOCK_SECRET_ARN} + } + + # Attacker's secret in same account + attacker_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:sagemaker/remote-function/evil-job/hmac-key" + + with pytest.raises(DeserializationError, match="Secret ARN mismatch"): + _validate_secret_arn(session, attacker_arn, MOCK_JOB_NAME)