Support placeholders for processing step #155

Merged
merged 23 commits into from
Aug 20, 2021
Changes from 7 commits
927b24f
documentation: Add setup instructions to run/debug tests locally
ca-nguyen Jul 16, 2021
003b5e8
Merge branch 'main' into update-contributing
shivlaks Aug 9, 2021
a7700a6
Added sub section for debug setup and linked to run tests instructions
ca-nguyen Aug 10, 2021
6b6443a
Update table
ca-nguyen Aug 12, 2021
7f6ef30
Support placeholders for processor parameters in processingstep
ca-nguyen Aug 12, 2021
00830f3
Added doc
ca-nguyen Aug 12, 2021
c708da7
Removed contibuting changes(included in another pr)
ca-nguyen Aug 12, 2021
2ea9e1f
Merge sagemaker generated parameters with placeholder compatible para…
ca-nguyen Aug 17, 2021
17543ed
documentation: Add setup instructions to run/debug tests locally
ca-nguyen Jul 16, 2021
36e2ee8
Added sub section for debug setup and linked to run tests instructions
ca-nguyen Aug 10, 2021
ea40f7c
Update table
ca-nguyen Aug 12, 2021
e499108
Support placeholders for processor parameters in processingstep
ca-nguyen Aug 12, 2021
4c63229
Added doc
ca-nguyen Aug 12, 2021
34bb281
Removed contibuting changes(included in another pr)
ca-nguyen Aug 12, 2021
a098c61
Merge sagemaker generated parameters with placeholder compatible para…
ca-nguyen Aug 17, 2021
06eb069
Merge branch 'support-placeholders-for-processing-step' of https://gi…
ca-nguyen Aug 17, 2021
da99c92
Using == instead of is()
ca-nguyen Aug 17, 2021
37b2422
Removed unused InvalidPathToPlaceholderParameter exception
ca-nguyen Aug 17, 2021
c433576
Merge branch 'main' into support-placeholders-for-processing-step
ca-nguyen Aug 17, 2021
fd640ab
Added doc and renamed args
ca-nguyen Aug 18, 2021
1dfa0e3
Update src/stepfunctions/steps/sagemaker.py parameters description
ca-nguyen Aug 19, 2021
6143783
Removed dict name args to opt for more generic log message when overw…
ca-nguyen Aug 19, 2021
ebc5e22
Using fstring in test
ca-nguyen Aug 20, 2021
8 changes: 7 additions & 1 deletion src/stepfunctions/exceptions.py
@@ -22,4 +22,10 @@ class MissingRequiredParameter(Exception):


 class DuplicateStatesInChain(Exception):
-    pass
+    pass
+
+
+class InvalidPathToPlaceholderParameter(Exception):
+
+    def __init__(self, message):
+        super(InvalidPathToPlaceholderParameter, self).__init__(message)
30 changes: 30 additions & 0 deletions src/stepfunctions/steps/constants.py
@@ -0,0 +1,30 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from enum import Enum
from stepfunctions.steps.fields import Field

# Path to SageMaker placeholder parameters
placeholder_paths = {
    # Paths taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html
Contributor

question:

  1. what does the comment mean when it says "taken" from the documentation link?
  2. can all properties be represented by placeholders or is it only some?

Contributor Author

1. I took each arg's location in the CreateProcessingJob request from that page and saved it in placeholder_paths.
2. Yes, all of them can.

    'ProcessingStep': {
        Field.Role.value: ['RoleArn'],
        Field.ImageUri.value: ['AppSpecification', 'ImageUri'],
        Field.InstanceCount.value: ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
        Field.InstanceType.value: ['ProcessingResources', 'ClusterConfig', 'InstanceType'],
        Field.Entrypoint.value: ['AppSpecification', 'ContainerEntrypoint'],
        Field.VolumeSizeInGB.value: ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
        Field.VolumeKMSKey.value: ['ProcessingResources', 'ClusterConfig', 'VolumeKmsKeyId'],
Contributor

curious: I haven't played much with this, but do all nested properties (e.g. VolumeSizeInGB) support placeholders to supply their value? Admittedly, I've only supplied top-level properties and haven't tinkered deep enough.

Contributor Author

They do when passed to parameters.
All placeholders in parameters (including nested values) are replaced here
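
The path-based replacement discussed in this thread can be sketched in isolation. This is a minimal illustration, not the SDK's code: `set_by_path` is a hypothetical helper, the `parameters` dictionary holds made-up sample values, and a plain string stands in for a real Placeholder object. Only the key names come from the `ProcessingStep` paths in this diff.

```python
import operator
from functools import reduce


def set_by_path(parameters, path, value):
    """Assign `value` at the nested key `path` inside `parameters` (hypothetical helper)."""
    # Walk every key except the last to reach the parent dictionary.
    parent = reduce(operator.getitem, path[:-1], parameters)
    # Overwrite (or create) the final key on that parent.
    parent[path[-1]] = value


# Sample request fragment; key names follow the ProcessingStep paths in the diff.
parameters = {
    "ProcessingResources": {
        "ClusterConfig": {"InstanceCount": 1, "VolumeSizeInGB": 30}
    }
}

# A plain string stands in for a real Placeholder object here.
placeholder = "$$.Execution.Input['volume_size_in_gb']"
set_by_path(
    parameters,
    ["ProcessingResources", "ClusterConfig", "VolumeSizeInGB"],
    placeholder,
)
```

In this sketch a missing intermediate key raises `KeyError`, which is the case the diff wraps in `InvalidPathToPlaceholderParameter`.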

        Field.Env.value: ['Environment'],
        Field.Tags.value: ['Tags'],
Contributor

thought: is there a way to read these from the SDK or automate it? having this hand-rolled can be problematic for a few reasons:

  1. drift if the API signatures expand
  2. prone to error, since it's reliant on everything being hand rolled

Contributor Author
ca-nguyen Aug 13, 2021

I had the same thought: doing this by hand can introduce errors and is not easily maintainable. Doing it this way allowed for less code redundancy.

Another option would be, for each SageMaker property, to call the function in the sagemaker code that adds it to the Parameters, instead of manually taking the path from the API docs.
That would mean:

  • Each property calls a different function (if one exists) in sagemaker in order to add it to Parameters
  • Some properties have no existing function in sagemaker that adds them to Parameters, so we would still have to handle them by hand using placeholder_paths

Contributor Author

I read @wong-a's proposed solution after posting the previous comment. I will go with that, since it removes the need to map each arg with placeholder_paths.
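
The alternative solution is not shown in this part of the diff. Going by the later commit title about merging SageMaker-generated parameters with placeholder-compatible ones, one way such a merge could look is a recursive dictionary update in which caller-supplied values override generated ones. `merge_parameters` below is a hypothetical helper written for illustration, and the sample dictionaries are invented; none of it is taken from the SDK.

```python
import logging

logger = logging.getLogger(__name__)


def merge_parameters(generated, overrides):
    """Recursively copy `overrides` into `generated`, in place (hypothetical helper)."""
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(generated.get(key), dict):
            # Both sides are dictionaries: descend instead of replacing wholesale.
            merge_parameters(generated[key], value)
        else:
            if key in generated:
                logger.info("Overwriting existing value for key %s", key)
            generated[key] = value
    return generated


# Invented sample data: values a processor might generate, plus caller overrides.
generated = {
    "RoleArn": "generated-role",
    "AppSpecification": {"ImageUri": "generated-image"},
}
overrides = {
    "AppSpecification": {"ImageUri": "$$.Execution.Input['image_uri']"},
    "Tags": [{"Key": "team", "Value": "ml"}],
}
merged = merge_parameters(generated, overrides)
```

With this shape there is no per-argument path table to maintain, which addresses the drift concern raised above, at the cost of requiring callers to supply overrides in the request's own nested structure.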

    }
}
15 changes: 14 additions & 1 deletion src/stepfunctions/steps/fields.py
@@ -59,10 +59,23 @@ class Field(Enum):
    HeartbeatSeconds = 'heartbeat_seconds'
    HeartbeatSecondsPath = 'heartbeat_seconds_path'


    # Retry and catch fields
    ErrorEquals = 'error_equals'
    IntervalSeconds = 'interval_seconds'
    MaxAttempts = 'max_attempts'
    BackoffRate = 'backoff_rate'
    NextStep = 'next_step'

    # Sagemaker step fields
    # Processing Step: Processor
    Role = 'role'
    ImageUri = 'image_uri'
    InstanceCount = 'instance_count'
    InstanceType = 'instance_type'
    Entrypoint = 'entrypoint'
    VolumeSizeInGB = 'volume_size_in_gb'
    VolumeKMSKey = 'volume_kms_key'
    OutputKMSKey = 'output_kms_key'
    MaxRuntimeInSeconds = 'max_runtime_in_seconds'
    Env = 'env'
    Tags = 'tags'
Contributor

thought: is this the right place for storing these properties? everything else in this file is specific to states and ASL, but this introduces properties specific to a service integration's API.

some properties will also be duplicated across APIs/Service Integrations (things like role, tags, etc are probably used in multiple APIs)

another thing to think about:

as a customer, would it be more intuitive to have something like Placeholders.SagemakerProcessingStep.blah than Fields.blah

Contributor Author

Love what you are proposing! Separating the SageMaker property fields and the state and ASL specific fields will definitely make it more intuitive to the customer

Changes will be made in next commit

Contributor

I agree. This class is for ASL fields, not parameters of specific service integrations

124 changes: 117 additions & 7 deletions src/stepfunctions/steps/sagemaker.py
@@ -13,10 +13,14 @@
from __future__ import absolute_import

import logging
import operator

from enum import Enum
from functools import reduce

from stepfunctions.exceptions import InvalidPathToPlaceholderParameter
from stepfunctions.inputs import Placeholder
from stepfunctions.steps.constants import placeholder_paths
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import tags_dict_to_kv_list
@@ -25,6 +29,7 @@
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
from sagemaker.model import Model, FrameworkModel
from sagemaker.model_monitor import DataCaptureConfig
from sagemaker.processing import ProcessingJob

logger = logging.getLogger('stepfunctions.sagemaker')

@@ -41,6 +46,104 @@ class SageMakerApi(Enum):
CreateProcessingJob = "createProcessingJob"


class SageMakerTask(Task):

    """
    Task State causes the interpreter to execute the work identified by the state’s `resource` field.
    """

    def __init__(self, state_id, step_type, tags, **kwargs):
        """
        Args:
            state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
            resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI.
            timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60)
            timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
            heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name.
            heartbeat_seconds_path (str, optional): Path specifying the state's heartbeat value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
            comment (str, optional): Human-readable comment or description. (default: None)
            input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
            parameters (dict, optional): The value of this field becomes the effective input for the state.
            result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
            output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
        """
        self._replace_sagemaker_placeholders(step_type, kwargs)
        if tags:
            self.set_tags_config(tags, kwargs[Field.Parameters.value], step_type)

        super(SageMakerTask, self).__init__(state_id, **kwargs)


    def allowed_fields(self):
        sagemaker_fields = [
            # ProcessingStep: Processor
            Field.Role,
            Field.ImageUri,
            Field.InstanceCount,
            Field.InstanceType,
            Field.Entrypoint,
            Field.VolumeSizeInGB,
            Field.VolumeKMSKey,
            Field.OutputKMSKey,
            Field.MaxRuntimeInSeconds,
            Field.Env,
            Field.Tags,
        ]

        return super(SageMakerTask, self).allowed_fields() + sagemaker_fields


    def _replace_sagemaker_placeholders(self, step_type, args):
        # Fetch path from type
        sagemaker_parameters = args[Field.Parameters.value]
        paths = placeholder_paths.get(step_type)
        treated_args = []

        for arg_name, value in args.items():
            if arg_name in [Field.Parameters.value]:
                continue
            if arg_name in paths.keys():
                path = paths.get(arg_name)
                if self._set_placeholder(sagemaker_parameters, path, value, arg_name):
                    treated_args.append(arg_name)

        SageMakerTask.remove_treated_args(treated_args, args)

    @staticmethod
    def get_value_from_path(parameters, path):
        value_from_path = reduce(operator.getitem, path, parameters)
        return value_from_path
        # return reduce(operator.getitem, path, parameters)

    @staticmethod
    def _set_placeholder(parameters, path, value, arg_name):
        is_set = False
        try:
            SageMakerTask.get_value_from_path(parameters, path[:-1])[path[-1]] = value
            is_set = True
        except KeyError as e:
            message = f"Invalid path {path} for {arg_name}: {e}"
            raise InvalidPathToPlaceholderParameter(message)
        return is_set

    @staticmethod
    def remove_treated_args(treated_args, args):
        for treated_arg in treated_args:
            try:
                del args[treated_arg]
            except KeyError as e:
                pass

    def set_tags_config(self, tags, parameters, step_type):
        if isinstance(tags, Placeholder):
            # Replace with placeholder
            path = placeholder_paths.get(step_type).get(Field.Tags.value)
            if path:
                self._set_placeholder(parameters, path, tags, Field.Tags.value)
        else:
            parameters['Tags'] = tags_dict_to_kv_list(tags)


class TrainingStep(Task):

"""
@@ -473,13 +576,19 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
super(TuningStep, self).__init__(state_id, **kwargs)


-class ProcessingStep(Task):
+class ProcessingStep(SageMakerTask):
Contributor

breaking change??

the class constructor used to take in a Task but now that's been changed. when customers upgrade versions, won't their existing code fail?

we cannot make breaking changes as we need to follow semantic versioning while releasing minor version updates.

Contributor Author

This changes ProcessingStep's base class, but not the constructor arguments. With this change, instead of calling Task's constructor directly in __init__(), we call SageMakerTask's constructor, which in turn calls Task's constructor.

Before:

  • ProcessingStep.__init__()
    • Task.__init__()

After:

  • ProcessingStep.__init__()
    • SageMakerTask.__init__()
      • Task.__init__()
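
The call chain above can be reproduced with toy classes. This is a simplified sketch, not the SDK's classes: the constructors below keep only a few arguments, and the attribute names (`fields`, `step_type`) are invented for illustration. It shows that inserting an intermediate base class leaves the caller-facing signature untouched.

```python
class Task:
    def __init__(self, state_id, **kwargs):
        self.state_id = state_id
        self.fields = kwargs


class SageMakerTask(Task):
    def __init__(self, state_id, step_type, tags, **kwargs):
        # Extra positional arguments are consumed here and never reach Task.
        self.step_type = step_type
        self.tags = tags
        super().__init__(state_id, **kwargs)


class ProcessingStep(SageMakerTask):
    def __init__(self, state_id, job_name, tags=None, **kwargs):
        # The public signature is the same one callers used before the change.
        kwargs["parameters"] = {"ProcessingJobName": job_name}
        super().__init__(state_id, type(self).__name__, tags, **kwargs)


# Existing caller code keeps working unchanged.
step = ProcessingStep("my-step", job_name="my-job")
```

Code that checks `isinstance(step, Task)` also keeps working, since `SageMakerTask` still derives from `Task`.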


    """
    Creates a Task State to execute a SageMaker Processing Job.

    The following properties can be passed down as kwargs to the sagemaker.processing.Processor to be used dynamically
    in the processing job (compatible with Placeholders): role, image_uri, instance_count, instance_type,
    volume_size_in_gb, volume_kms_key, output_kms_key
Contributor

curious: what's the source of truth here? how did we verify that these properties are the ones supported.

Contributor Author

The args that are made placeholder compatible are:

  1. The args that are documented as being placeholder compatible in the Args section (e.g. job_name)
  2. The ones that are included in placeholder_paths (src/stepfunctions/steps/constants.py)

Because the placeholders are replaced with the ExecutionInput values when the execution starts, all of them have been resolved by the time the SageMaker job starts. If an arg that we configured to hold a placeholder in our state machine were not replaced, this should trigger an error.

I can add a test to confirm the behaviour in the next commit

Thanks for bringing this up - this confirms that this documentation is not clear and we might want to switch to the Alternative solution where we would add all Placeholder compatible properties as optional args in the step constructor, making it clearer to the customer which are Placeholder compatible.

Contributor

This docstring is out of date with the new implementation

"""

-    def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs):
+    def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None,
+                 container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True,
+                 tags=None, max_runtime_in_seconds=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -499,7 +608,8 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
The KmsKeyId is applied to all outputs.
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
-            tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
+            tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
+            max_runtime_in_seconds (int or Placeholder): Specifies the maximum runtime in seconds for the processing job
"""
if wait_for_completion:
"""
@@ -528,12 +638,12 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
         if experiment_config is not None:
             parameters['ExperimentConfig'] = experiment_config

-        if tags:
-            parameters['Tags'] = tags_dict_to_kv_list(tags)
-
         if 'S3Operations' in parameters:
             del parameters['S3Operations']

+        if max_runtime_in_seconds:
+            parameters['StoppingCondition'] = ProcessingJob.prepare_stopping_condition(max_runtime_in_seconds)
+
         kwargs[Field.Parameters.value] = parameters

-        super(ProcessingStep, self).__init__(state_id, **kwargs)
+        super(ProcessingStep, self).__init__(state_id, __class__.__name__, tags, **kwargs)
84 changes: 84 additions & 0 deletions tests/integ/test_sagemaker_steps.py
@@ -29,7 +29,9 @@
from sagemaker.tuner import HyperparameterTuner
from sagemaker.processing import ProcessingInput, ProcessingOutput

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps import Chain
from stepfunctions.steps.fields import Field
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
from stepfunctions.workflow import Workflow

@@ -352,3 +354,85 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
# End of Cleanup


def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,
                                           sagemaker_role_arn):
    region = boto3.session.Session().region_name
    input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region)
Contributor

nit: why not use f strings here too instead of format?

Contributor Author

I agree that using f-strings is more readable and efficient. format was used for all the other tests, so I kept it for consistency.
I will change it for this added test, and perhaps we can make the change for the rest of the file in a separate PR.


    input_s3 = sagemaker_session.upload_data(
        path=os.path.join(DATA_DIR, 'sklearn_processing'),
        bucket=sagemaker_session.default_bucket(),
        key_prefix='integ-test-data/sklearn_processing/code'
    )

    output_s3 = 's3://' + sagemaker_session.default_bucket() + '/integ-test-data/sklearn_processing'
Contributor

nit: why not use f strings here instead of concatenation?

Contributor Author

Agreed, using f-strings would be more readable and efficient.

Same comment: format was used for all the other tests, so I kept it for consistency.
I will change it for this added test, and perhaps we can make the change for the rest of the file in a separate PR.
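
For reference, the two suggested rewrites produce the same strings as the originals. The snippet below is a small sketch with sample values: `region` and `bucket` are hard-coded stand-ins for what the test reads from the boto3 session and the SageMaker session.

```python
region = "us-east-1"       # sample value; the test reads this from the boto3 session
bucket = "example-bucket"  # hypothetical bucket name

# str.format versus an f-string for the input data URI
with_format = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region)
with_fstring = f's3://sagemaker-sample-data-{region}/processing/census/census-income.csv'

# Concatenation versus an f-string for the output prefix
with_concat = 's3://' + bucket + '/integ-test-data/sklearn_processing'
with_fstring_prefix = f's3://{bucket}/integ-test-data/sklearn_processing'
```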


    inputs = [
        ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'),
        ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code',
                        input_name='code'),
    ]

    outputs = [
        ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data',
                         output_name='train_data'),
        ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data',
                         output_name='test_data'),
    ]

    # Build workflow definition
    execution_input = ExecutionInput(schema={
        Field.ImageUri.value: str,
        Field.InstanceCount.value: int,
        Field.Entrypoint.value: str,
        Field.Role.value: str,
        Field.VolumeSizeInGB.value: int,
        Field.MaxRuntimeInSeconds.value: int
    })

    job_name = generate_job_name()
    processing_step = ProcessingStep('create_processing_job_step',
                                     processor=sklearn_processor_fixture,
                                     job_name=job_name,
                                     inputs=inputs,
                                     outputs=outputs,
                                     container_arguments=['--train-test-split-ratio', '0.2'],
                                     container_entrypoint=execution_input[Field.Entrypoint.value],
                                     image_uri=execution_input[Field.ImageUri.value],
                                     instance_count=execution_input[Field.InstanceCount.value],
                                     role=execution_input[Field.Role.value],
                                     volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value],
                                     max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value]
                                     )
    workflow_graph = Chain([processing_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
Contributor

nit: unnecessary comment as the method name expresses this in snake case

Contributor Author

Agreed, it will be removed in the next commit

        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-processing-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn
        )

        execution_input = {
            Field.ImageUri.value: '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
            Field.InstanceCount.value: 1,
            Field.Entrypoint.value: ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
            Field.Role.value: sagemaker_role_arn,
            Field.VolumeSizeInGB.value: 30,
            Field.MaxRuntimeInSeconds.value: 500
        }

        # Execute workflow
        execution = workflow.execute(inputs=execution_input)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("ProcessingJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        # End of Cleanup
Contributor

nit: i think the code is self explanatory. we can drop this comment 😅

Contributor Author

You're right! I'll remove the comments :)
They are included in all the other tests, so I will clean those up in another PR

Contributor

did you forget to remove this?

Contributor Author

Yes - will remove it in the next commit!
