Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions sagemaker-train/src/sagemaker/train/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,13 +1444,15 @@ def _build_training_job_definition(self, inputs):

# Build stopping condition
stopping_condition = StoppingCondition()
if (
model_trainer.stopping_condition
and model_trainer.stopping_condition.max_runtime_in_seconds
):
stopping_condition.max_runtime_in_seconds = (
model_trainer.stopping_condition.max_runtime_in_seconds
)
if model_trainer.stopping_condition:
if model_trainer.stopping_condition.max_runtime_in_seconds:
stopping_condition.max_runtime_in_seconds = (
model_trainer.stopping_condition.max_runtime_in_seconds
)
if model_trainer.stopping_condition.max_wait_time_in_seconds:
stopping_condition.max_wait_time_in_seconds = (
model_trainer.stopping_condition.max_wait_time_in_seconds
)

definition = HyperParameterTrainingJobDefinition(
algorithm_specification=algorithm_spec,
Expand All @@ -1460,6 +1462,7 @@ def _build_training_job_definition(self, inputs):
resource_config=resource_config,
stopping_condition=stopping_condition,
static_hyper_parameters=self.static_hyperparameters or {},
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
)

return definition
23 changes: 22 additions & 1 deletion sagemaker-train/tests/unit/train/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@
# ---------------------------------------------------------------------------


def _create_mock_model_trainer(with_internal_channels=False):
def _create_mock_model_trainer(with_internal_channels=False, with_spot_training=False):
"""Create a mock ModelTrainer with common attributes.

Args:
with_internal_channels: If True, adds internal channels (code, sm_drivers)
to input_data_config for testing channel inclusion in tuning jobs.
with_spot_training: If True, sets spot parameters (enable_managed_spot_training,
max_wait_time_in_seconds)
"""
trainer = MagicMock()
trainer.sagemaker_session = MagicMock()
Expand All @@ -67,6 +69,9 @@ def _create_mock_model_trainer(with_internal_channels=False):
_create_channel("code", "s3://bucket/code"),
_create_channel("sm_drivers", "s3://bucket/drivers"),
]
if with_spot_training:
trainer.compute.enable_managed_spot_training = True
trainer.stopping_condition.max_wait_time_in_seconds = 3600
return trainer


Expand Down Expand Up @@ -574,3 +579,19 @@ def test_build_training_job_definition_includes_internal_channels(self):
assert "train" in channel_names, "User 'train' channel should be included"
assert "validation" in channel_names, "User 'validation' channel should be included"
assert len(channel_names) == 4, "Should have exactly 4 channels"

def test_build_training_job_definition_includes_spot_params(self):
    """Test that _build_training_job_definition propagates spot parameters.

    The mock trainer is created with ``with_spot_training=True``, which sets
    ``compute.enable_managed_spot_training = True`` and
    ``stopping_condition.max_wait_time_in_seconds = 3600``. Both values must
    appear unchanged on the resulting training job definition.
    """
    tuner = HyperparameterTuner(
        model_trainer=_create_mock_model_trainer(with_spot_training=True),
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )

    # Build training job definition
    definition = tuner._build_training_job_definition(None)

    # Verify managed spot training enabled
    assert definition.enable_managed_spot_training is True, "Spot should be enabled"
    # Pin the exact value configured on the mock rather than only its type, so
    # a dropped or substituted wait time is caught.
    assert (
        definition.stopping_condition.max_wait_time_in_seconds == 3600
    ), "Max wait time should match the trainer's value"