From e96f15770218bb52afb050c7d624bb9444411b6f Mon Sep 17 00:00:00 2001 From: Todd Stephenson <71510807+toddstep@users.noreply.github.com> Date: Thu, 26 Feb 2026 16:01:44 +0000 Subject: [PATCH] fix: fix HyperparameterTuner to launch training jobs with provided spot parameters HyperparameterTuner._build_training_job_definition() was not copying parameters needed for managed spot training: - enable_managed_spot_training - max_wait_time_in_seconds This caused training jobs to launch with on-demend instances. - Include the additional parameters in the job definition - Add a unit test --- sagemaker-train/src/sagemaker/train/tuner.py | 17 ++++++++------ .../tests/unit/train/test_tuner.py | 23 ++++++++++++++++++- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py index f99f055f0b..dc18a6a71e 100644 --- a/sagemaker-train/src/sagemaker/train/tuner.py +++ b/sagemaker-train/src/sagemaker/train/tuner.py @@ -1444,13 +1444,15 @@ def _build_training_job_definition(self, inputs): # Build stopping condition stopping_condition = StoppingCondition() - if ( - model_trainer.stopping_condition - and model_trainer.stopping_condition.max_runtime_in_seconds - ): - stopping_condition.max_runtime_in_seconds = ( - model_trainer.stopping_condition.max_runtime_in_seconds - ) + if model_trainer.stopping_condition: + if model_trainer.stopping_condition.max_runtime_in_seconds: + stopping_condition.max_runtime_in_seconds = ( + model_trainer.stopping_condition.max_runtime_in_seconds + ) + if model_trainer.stopping_condition.max_wait_time_in_seconds: + stopping_condition.max_wait_time_in_seconds = ( + model_trainer.stopping_condition.max_wait_time_in_seconds + ) definition = HyperParameterTrainingJobDefinition( algorithm_specification=algorithm_spec, @@ -1460,6 +1462,7 @@ def _build_training_job_definition(self, inputs): resource_config=resource_config, stopping_condition=stopping_condition, static_hyper_parameters=self.static_hyperparameters or {}, + enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training, ) return definition diff --git a/sagemaker-train/tests/unit/train/test_tuner.py b/sagemaker-train/tests/unit/train/test_tuner.py index c0255eac47..6238119dfb 100644 --- a/sagemaker-train/tests/unit/train/test_tuner.py +++ b/sagemaker-train/tests/unit/train/test_tuner.py @@ -39,12 +39,14 @@ # --------------------------------------------------------------------------- -def _create_mock_model_trainer(with_internal_channels=False): +def _create_mock_model_trainer(with_internal_channels=False, with_spot_training=False): """Create a mock ModelTrainer with common attributes. Args: with_internal_channels: If True, adds internal channels (code, sm_drivers) to input_data_config for testing channel inclusion in tuning jobs. + with_spot_training: If True, sets spot parameters (enable_managed_spot_training, + max_wait_time_in_seconds) """ trainer = MagicMock() trainer.sagemaker_session = MagicMock() @@ -67,6 +69,9 @@ def _create_mock_model_trainer(with_internal_channels=False): _create_channel("code", "s3://bucket/code"), _create_channel("sm_drivers", "s3://bucket/drivers"), ] + if with_spot_training: + trainer.compute.enable_managed_spot_training = True + trainer.stopping_condition.max_wait_time_in_seconds = 3600 return trainer @@ -574,3 +579,19 @@ def test_build_training_job_definition_includes_internal_channels(self): assert "train" in channel_names, "User 'train' channel should be included" assert "validation" in channel_names, "User 'validation' channel should be included" assert len(channel_names) == 4, "Should have exactly 4 channels" + + def test_build_training_job_definition_includes_spot_params(self): + """Test that _build_training_job_definition includes spot parameters. + """ + tuner = HyperparameterTuner( + model_trainer=_create_mock_model_trainer(with_spot_training=True), + objective_metric_name="accuracy", + hyperparameter_ranges=_create_single_hp_range(), + ) + + # Build training job definition + definition = tuner._build_training_job_definition(None) + + # Verify managed spot training enabled + assert definition.enable_managed_spot_training is True, "Spot should be enabled" + assert isinstance(definition.stopping_condition.max_wait_time_in_seconds, int), "Max wait time should be set"