diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 9530478..a3ddd47 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -69,9 +69,10 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of :class:`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data. - hyperparameters (dict, optional): Parameters used for training. - Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator. + hyperparameters: Parameters used for training. + * (dict, optional) - Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator. If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.) + * (Placeholder, optional) - The TrainingStep will use the hyperparameters specified by the Placeholder's value instead of the hyperparameters specified in the estimator. mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator. experiment_config (dict, optional): Specify the experiment config for the training. (Default: None) wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True) @@ -127,8 +128,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri if hyperparameters is not None: - if estimator.hyperparameters() is not None: - hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters()) + if not isinstance(hyperparameters, Placeholder): + if estimator.hyperparameters() is not None: + hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters()) parameters['HyperParameters'] = hyperparameters if experiment_config is not None: diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index 664a498..02c6083 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -275,6 +275,7 @@ def test_training_step_creation_with_placeholders(pca_estimator): execution_input = ExecutionInput(schema={ 'Data': str, 'OutputPath': str, + 'HyperParameters': str }) step_input = StepInput(schema={ @@ -292,6 +293,7 @@ def test_training_step_creation_with_placeholders(pca_estimator): 'TrialComponentDisplayName': 'Training' }, tags=DEFAULT_TAGS, + hyperparameters=execution_input['HyperParameters'] ) assert step.to_dict() == { 'Type': 'Task', @@ -312,13 +314,7 @@ def test_training_step_creation_with_placeholders(pca_estimator): 'VolumeSizeInGB': 30 }, 'RoleArn': EXECUTION_ROLE, - 'HyperParameters': { - 'feature_dim': '50000', - 'num_components': '10', - 'subtract_mean': 'True', - 'algorithm_mode': 'randomized', - 'mini_batch_size': '200' - }, + 'HyperParameters.$': "$$.Execution.Input['HyperParameters']", 'InputDataConfig': [ { 'ChannelName': 'training', @@ -344,6 +340,49 @@ def test_training_step_creation_with_placeholders(pca_estimator): } +@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_training_step_creation_with_hyperparameters_containing_placeholders(pca_estimator): + execution_input = ExecutionInput(schema={ + 'Data': str, + 'OutputPath': str, + 'num_components': str, + 'HyperParamA': str, + 'HyperParamB': str, + }) + + step_input = StepInput(schema={ + 'JobName': str, + }) + + step = TrainingStep('Training', + estimator=pca_estimator, + job_name=step_input['JobName'], + data=execution_input['Data'], + output_data_config_path=execution_input['OutputPath'], + experiment_config={ + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Training' + }, + tags=DEFAULT_TAGS, + hyperparameters={ + 'num_components': execution_input['num_components'], # This will overwrite the value that was defined in the pca_estimator + 'HyperParamA': execution_input['HyperParamA'], + 'HyperParamB': execution_input['HyperParamB'] + } + ) + assert step.to_dict()['Parameters']['HyperParameters'] == { + 'HyperParamA.$': "$$.Execution.Input['HyperParamA']", + 'HyperParamB.$': "$$.Execution.Input['HyperParamB']", + 'algorithm_mode': 'randomized', + 'feature_dim': 50000, + 'mini_batch_size': 200, + 'num_components.$': "$$.Execution.Input['num_components']", + 'subtract_mean': True + } + + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):