Skip to content

Commit

Permalink
Merge branch 'main' into support-placeholders-for-tuning-step
Browse files Browse the repository at this point in the history
  • Loading branch information
ca-nguyen authored Oct 28, 2021
2 parents 0ce5235 + 2091850 commit eaf25dc
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 19 deletions.
33 changes: 19 additions & 14 deletions src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
experiment_config (dict or Placeholder, optional): Specify the experiment config for the training. (Default: None)
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
output_data_config_path (str or Placeholder, optional): S3 location for saving the training result (model
artifacts and output files). If specified, it overrides the `output_path` property of `estimator`.
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateTrainingJob<https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html>`_. (Default: None)
Expand Down Expand Up @@ -220,7 +220,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
split_type (str or Placeholder): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
experiment_config (dict or Placeholder, optional): Specify the experiment config for the transform. (Default: None)
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the transform job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the transform job and proceed to the next step. (default: True)
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
input_filter (str or Placeholder): A JSONPath to select a portion of the input to pass to the algorithm container for inference. If you omit the field, it gets the value ‘$’, representing the entire input. For CSV data, each row is taken as a JSON array, so only index-based JSONPaths can be applied, e.g. $[0], $[1:]. CSV data should follow the RFC format. See Supported JSONPath Operators for a table of supported JSONPath operators. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.features” (default: None).
output_filter (str or Placeholder): A JSONPath to select a portion of the joined/original output to return as the output. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.prediction” (default: None).
join_source (str or Placeholder): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None.
Expand Down Expand Up @@ -302,14 +302,16 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here.
model_name (str or Placeholder, optional): Specify a model name, this is required for creating the model. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict] or Placeholders, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateModel<https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_. (Default: None)
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders<https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
"""
if isinstance(model, FrameworkModel):
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
model_parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
if model_name:
parameters['ModelName'] = model_name
model_parameters['ModelName'] = model_name
elif isinstance(model, Model):
parameters = {
model_parameters = {
'ExecutionRoleArn': model.role,
'ModelName': model_name or model.name,
'PrimaryContainer': {
Expand All @@ -321,13 +323,17 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
else:
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))

if 'S3Operations' in parameters:
del parameters['S3Operations']
if 'S3Operations' in model_parameters:
del model_parameters['S3Operations']

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)
model_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)

kwargs[Field.Parameters.value] = parameters
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
# Update model parameters with input parameters
merge_dicts(model_parameters, kwargs[Field.Parameters.value])

kwargs[Field.Parameters.value] = model_parameters

"""
Example resource arn: arn:aws:states:::sagemaker:createModel
Expand Down Expand Up @@ -357,7 +363,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
data_capture_config (sagemaker.model_monitor.DataCaptureConfig, optional): Specifies
configuration related to Endpoint data capture for use with
Amazon SageMaker Model Monitoring. Default: None.
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
"""
parameters = {
'EndpointConfigName': endpoint_config_name,
Expand Down Expand Up @@ -399,9 +405,8 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
endpoint_name (str or Placeholder): The name of the endpoint to create. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
endpoint_config_name (str or Placeholder): The name of the endpoint configuration to use for the endpoint. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
update (bool, optional): Boolean flag set to `True` if endpoint must to be updated. Set to `False` if new endpoint must be created. (default: False)
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
"""

parameters = {
Expand Down Expand Up @@ -528,7 +533,7 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
The KmsKeyId is applied to all outputs.
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateProcessingJob<https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html>`_.
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders<https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
Expand Down
63 changes: 58 additions & 5 deletions tests/integ/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,59 @@ def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_a
delete_sagemaker_model(model_name, sagemaker_session)



def test_model_step_with_placeholders(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn):
# Build workflow definition
execution_input = ExecutionInput(schema={
'ModelName': str,
'Mode': str,
'Tags': list
})

parameters = {
'PrimaryContainer': {
'Mode': execution_input['Mode']
},
'Tags': execution_input['Tags']
}

model_step = ModelStep('create_model_step', model=trained_estimator.create_model(),
model_name=execution_input['ModelName'], parameters=parameters)
model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
workflow_graph = Chain([model_step])

with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
# Create workflow and check definition
workflow = create_workflow_and_check_definition(
workflow_graph=workflow_graph,
workflow_name=unique_name_from_base("integ-test-model-step-workflow"),
sfn_client=sfn_client,
sfn_role_arn=sfn_role_arn
)

inputs = {
'ModelName': generate_job_name(),
'Mode': 'SingleModel',
'Tags': [{
'Key': 'Environment',
'Value': 'test'
}]
}

# Execute workflow
execution = workflow.execute(inputs=inputs)
execution_output = execution.get_output(wait=True)

# Check workflow output
assert execution_output.get("ModelArn") is not None
assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
model_name = get_resource_name_from_arn(execution_output.get("ModelArn")).split("/")[1]
delete_sagemaker_model(model_name, sagemaker_session)


def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
# Create transformer from previously created estimator
job_name = generate_job_name()
Expand Down Expand Up @@ -349,7 +402,7 @@ def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session,
# Execute workflow
execution = workflow.execute()
execution_output = execution.get_output(wait=True)

# Check workflow output
assert execution_output.get("EndpointConfigArn") is not None
assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200
Expand Down Expand Up @@ -390,7 +443,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
# Execute workflow
execution = workflow.execute()
execution_output = execution.get_output(wait=True)

# Check workflow output
endpoint_arn = execution_output.get("EndpointArn")
assert execution_output.get("EndpointArn") is not None
Expand Down Expand Up @@ -428,7 +481,7 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
max_jobs=2,
max_parallel_jobs=2,
)

# Build workflow definition
tuning_step = TuningStep('Tuning', tuner=tuner, job_name=job_name, data=record_set_for_hyperparameter_tuning)
tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
Expand All @@ -446,7 +499,7 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
# Execute workflow
execution = workflow.execute()
execution_output = execution.get_output(wait=True)

# Check workflow output
assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"

Expand Down Expand Up @@ -586,7 +639,7 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
sfn_client=sfn_client,
sfn_role_arn=sfn_role_arn
)

# Execute workflow
execution = workflow.execute()
execution_output = execution.get_output(wait=True)
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,36 @@ def test_model_step_creation(pca_model):
}


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_model_step_creation_with_placeholders(pca_model):
execution_input = ExecutionInput(schema={
'Environment': str,
'Tags': str
})

step_input = StepInput(schema={
'ModelName': str
})

parameters = {
'PrimaryContainer': {
'Environment': execution_input['Environment']
}
}
step = ModelStep('Create model', model=pca_model, model_name=step_input['ModelName'], tags=execution_input['Tags'],
parameters=parameters)
assert step.to_dict()['Parameters'] == {
'ExecutionRoleArn': EXECUTION_ROLE,
'ModelName.$': "$['ModelName']",
'PrimaryContainer': {
'Environment.$': "$$.Execution.Input['Environment']",
'Image': pca_model.image_uri,
'ModelDataUrl': pca_model.model_data
},
'Tags.$': "$$.Execution.Input['Tags']"
}


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_model_step_creation_with_env(pca_model_with_env):
step = ModelStep('Create model', model=pca_model_with_env, model_name='pca-model', tags=DEFAULT_TAGS)
Expand Down

0 comments on commit eaf25dc

Please sign in to comment.