Skip to content

fix: Support placeholders for hyperparameters passed to TrainingStep #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,10 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
where each instance is a different channel of training data.
hyperparameters (dict, optional): Parameters used for training.
Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator.
hyperparameters: Parameters used for training.
* (dict, optional) - Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator.
If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.)
* (Placeholder, optional) - The TrainingStep will use the hyperparameters specified by the Placeholder's value instead of the hyperparameters specified in the estimator.
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
Expand Down Expand Up @@ -127,8 +128,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri

if hyperparameters is not None:
if estimator.hyperparameters() is not None:
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
if not isinstance(hyperparameters, Placeholder):
if estimator.hyperparameters() is not None:
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
parameters['HyperParameters'] = hyperparameters

if experiment_config is not None:
Expand Down
53 changes: 46 additions & 7 deletions tests/unit/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def test_training_step_creation_with_placeholders(pca_estimator):
execution_input = ExecutionInput(schema={
'Data': str,
'OutputPath': str,
'HyperParameters': str
})

step_input = StepInput(schema={
Expand All @@ -292,6 +293,7 @@ def test_training_step_creation_with_placeholders(pca_estimator):
'TrialComponentDisplayName': 'Training'
},
tags=DEFAULT_TAGS,
hyperparameters=execution_input['HyperParameters']
)
assert step.to_dict() == {
'Type': 'Task',
Expand All @@ -312,13 +314,7 @@ def test_training_step_creation_with_placeholders(pca_estimator):
'VolumeSizeInGB': 30
},
'RoleArn': EXECUTION_ROLE,
'HyperParameters': {
'feature_dim': '50000',
'num_components': '10',
'subtract_mean': 'True',
'algorithm_mode': 'randomized',
'mini_batch_size': '200'
},
'HyperParameters.$': "$$.Execution.Input['HyperParameters']",
'InputDataConfig': [
{
'ChannelName': 'training',
Expand All @@ -344,6 +340,49 @@ def test_training_step_creation_with_placeholders(pca_estimator):
}


@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_training_step_creation_with_hyperparameters_containing_placeholders(pca_estimator):
    """A hyperparameters dict whose values are Placeholders should render each
    entry as a JsonPath reference (key suffixed with ``.$``) and still merge
    with the estimator's own hyperparameters, with the step-level entry
    winning on key collisions."""
    workflow_input = ExecutionInput(schema={
        'Data': str,
        'OutputPath': str,
        'num_components': str,
        'HyperParamA': str,
        'HyperParamB': str,
    })

    task_input = StepInput(schema={
        'JobName': str,
    })

    training_step = TrainingStep(
        'Training',
        estimator=pca_estimator,
        job_name=task_input['JobName'],
        data=workflow_input['Data'],
        output_data_config_path=workflow_input['OutputPath'],
        experiment_config={
            'ExperimentName': 'pca_experiment',
            'TrialName': 'pca_trial',
            'TrialComponentDisplayName': 'Training'
        },
        tags=DEFAULT_TAGS,
        hyperparameters={
            # Overrides the 'num_components' value defined on pca_estimator.
            'num_components': workflow_input['num_components'],
            'HyperParamA': workflow_input['HyperParamA'],
            'HyperParamB': workflow_input['HyperParamB']
        }
    )

    rendered_hyperparameters = training_step.to_dict()['Parameters']['HyperParameters']
    assert rendered_hyperparameters == {
        'HyperParamA.$': "$$.Execution.Input['HyperParamA']",
        'HyperParamB.$': "$$.Execution.Input['HyperParamB']",
        'algorithm_mode': 'randomized',
        'feature_dim': 50000,
        'mini_batch_size': 200,
        'num_components.$': "$$.Execution.Input['num_components']",
        'subtract_mean': True
    }


@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):
Expand Down