Support placeholders for processing step #155
Changes from 7 commits
927b24f
003b5e8
a7700a6
6b6443a
7f6ef30
00830f3
c708da7
2ea9e1f
17543ed
36e2ee8
ea40f7c
e499108
4c63229
34bb281
a098c61
06eb069
da99c92
37b2422
c433576
fd640ab
1dfa0e3
6143783
ebc5e22
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
from enum import Enum | ||
from stepfunctions.steps.fields import Field | ||
|
||
# Path to SageMaker placeholder parameters | ||
placeholder_paths = { | ||
# Paths taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html | ||
'ProcessingStep': { | ||
Field.Role.value: ['RoleArn'], | ||
Field.ImageUri.value: ['AppSpecification', 'ImageUri'], | ||
Field.InstanceCount.value: ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], | ||
Field.InstanceType.value: ['ProcessingResources', 'ClusterConfig', 'InstanceType'], | ||
Field.Entrypoint.value: ['AppSpecification', 'ContainerEntrypoint'], | ||
Field.VolumeSizeInGB.value: ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'], | ||
Field.VolumeKMSKey.value: ['ProcessingResources', 'ClusterConfig', 'VolumeKmsKeyId'], | ||
Review comment: curious: I haven't played much with this, but do all nested properties (e.g. VolumeSizeInGB) support placeholders to supply their value? Admittedly, I've only supplied top-level properties and haven't tinkered deep enough.
Review comment: They are supported when passed to parameters. |
||
Field.Env.value: ['Environment'], | ||
Field.Tags.value: ['Tags'], | ||
Review comment: thought: is there a way to read these from the SDK or automate it? Having this hand-rolled can be problematic for a few reasons:
Review comment: I had the same thought: doing this by hand can introduce errors and is not easily maintainable. Doing it this way allowed less code redundancy. Another option would be, for each SageMaker property, to call the function that adds it to the Parameters in the SageMaker code instead of manually getting the path from the API docs.
Review comment: Read @wong-a's proposed solution after posting the previous comment. Will go with that since it removes the need to map each arg with |
||
} | ||
} |
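The path lists above are consumed by walking a nested request dictionary. A minimal sketch of that mechanism follows, assuming a trimmed-down request and plain strings standing in for Placeholder objects; `set_by_path` is a hypothetical helper name, not part of the SDK.

```python
import operator
from functools import reduce

# Hypothetical subset of a CreateProcessingJob request, for illustration only.
parameters = {
    "RoleArn": "arn:aws:iam::123456789012:role/static-role",
    "ProcessingResources": {
        "ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m5.xlarge"}
    },
}


def set_by_path(params, path, value):
    """Walk every key except the last, then assign `value` at the last key."""
    # With a one-element path, path[:-1] is empty and reduce returns `params`.
    parent = reduce(operator.getitem, path[:-1], params)
    parent[path[-1]] = value


# Plain strings stand in for Placeholder objects here.
set_by_path(
    parameters,
    ["ProcessingResources", "ClusterConfig", "InstanceCount"],
    "<instance_count placeholder>",
)
set_by_path(parameters, ["RoleArn"], "<role placeholder>")
```

A missing intermediate key raises `KeyError` during the walk, which is the case the diff wraps in `InvalidPathToPlaceholderParameter`. A missing final key does not raise, because the assignment simply creates it.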
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,10 +59,23 @@ class Field(Enum): | |
HeartbeatSeconds = 'heartbeat_seconds' | ||
HeartbeatSecondsPath = 'heartbeat_seconds_path' | ||
|
||
|
||
# Retry and catch fields | ||
ErrorEquals = 'error_equals' | ||
IntervalSeconds = 'interval_seconds' | ||
MaxAttempts = 'max_attempts' | ||
BackoffRate = 'backoff_rate' | ||
NextStep = 'next_step' | ||
|
||
# Sagemaker step fields | ||
# Processing Step: Processor | ||
Role = 'role' | ||
ImageUri = 'image_uri' | ||
InstanceCount = 'instance_count' | ||
InstanceType = 'instance_type' | ||
Entrypoint = 'entrypoint' | ||
VolumeSizeInGB = 'volume_size_in_gb' | ||
VolumeKMSKey = 'volume_kms_key' | ||
OutputKMSKey = 'output_kms_key' | ||
MaxRuntimeInSeconds = 'max_runtime_in_seconds' | ||
Env = 'env' | ||
Tags = 'tags' | ||
Review comment: thought: is this the right place for storing these properties? Everything else in this file is specific to states and ASL, but this introduces properties specific to a service integration's API. Some properties will also be duplicated across APIs/service integrations (things like role, tags, etc. are probably used in multiple APIs). Another thing to think about: as a customer, would it be more intuitive to have something like
Review comment: Love what you are proposing! Separating the SageMaker property fields from the state- and ASL-specific fields will definitely make it more intuitive to the customer. Changes will be made in the next commit.
Review comment: I agree. This class is for ASL fields, not parameters of specific service integrations. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,10 +13,14 @@ | |
from __future__ import absolute_import | ||
|
||
import logging | ||
import operator | ||
|
||
from enum import Enum | ||
from functools import reduce | ||
|
||
from stepfunctions.exceptions import InvalidPathToPlaceholderParameter | ||
from stepfunctions.inputs import Placeholder | ||
from stepfunctions.steps.constants import placeholder_paths | ||
from stepfunctions.steps.states import Task | ||
from stepfunctions.steps.fields import Field | ||
from stepfunctions.steps.utils import tags_dict_to_kv_list | ||
|
@@ -25,6 +29,7 @@ | |
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config | ||
from sagemaker.model import Model, FrameworkModel | ||
from sagemaker.model_monitor import DataCaptureConfig | ||
from sagemaker.processing import ProcessingJob | ||
|
||
logger = logging.getLogger('stepfunctions.sagemaker') | ||
|
||
|
@@ -41,6 +46,104 @@ class SageMakerApi(Enum): | |
CreateProcessingJob = "createProcessingJob" | ||
|
||
|
||
class SageMakerTask(Task): | ||
|
||
""" | ||
Task State causes the interpreter to execute the work identified by the state’s `resource` field. | ||
""" | ||
|
||
def __init__(self, state_id, step_type, tags, **kwargs): | ||
""" | ||
Args: | ||
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. | ||
resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI. | ||
timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) | ||
timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. | ||
heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. | ||
heartbeat_seconds_path (str, optional): Path specifying the state's heartbeat value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. | ||
comment (str, optional): Human-readable comment or description. (default: None) | ||
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$') | ||
parameters (dict, optional): The value of this field becomes the effective input for the state. | ||
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') | ||
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') | ||
""" | ||
self._replace_sagemaker_placeholders(step_type, kwargs) | ||
if tags: | ||
self.set_tags_config(tags, kwargs[Field.Parameters.value], step_type) | ||
|
||
super(SageMakerTask, self).__init__(state_id, **kwargs) | ||
|
||
|
||
def allowed_fields(self): | ||
sagemaker_fields = [ | ||
# ProcessingStep: Processor | ||
Field.Role, | ||
Field.ImageUri, | ||
Field.InstanceCount, | ||
Field.InstanceType, | ||
Field.Entrypoint, | ||
Field.VolumeSizeInGB, | ||
Field.VolumeKMSKey, | ||
Field.OutputKMSKey, | ||
Field.MaxRuntimeInSeconds, | ||
Field.Env, | ||
Field.Tags, | ||
] | ||
|
||
return super(SageMakerTask, self).allowed_fields() + sagemaker_fields | ||
|
||
|
||
def _replace_sagemaker_placeholders(self, step_type, args): | ||
# Fetch path from type | ||
sagemaker_parameters = args[Field.Parameters.value] | ||
paths = placeholder_paths.get(step_type) | ||
treated_args = [] | ||
|
||
for arg_name, value in args.items(): | ||
if arg_name in [Field.Parameters.value]: | ||
continue | ||
if arg_name in paths.keys(): | ||
path = paths.get(arg_name) | ||
if self._set_placeholder(sagemaker_parameters, path, value, arg_name): | ||
treated_args.append(arg_name) | ||
|
||
SageMakerTask.remove_treated_args(treated_args, args) | ||
|
||
@staticmethod | ||
def get_value_from_path(parameters, path): | ||
value_from_path = reduce(operator.getitem, path, parameters) | ||
return value_from_path | ||
# return reduce(operator.getitem, path, parameters) | ||
|
||
@staticmethod | ||
def _set_placeholder(parameters, path, value, arg_name): | ||
is_set = False | ||
try: | ||
SageMakerTask.get_value_from_path(parameters, path[:-1])[path[-1]] = value | ||
is_set = True | ||
except KeyError as e: | ||
message = f"Invalid path {path} for {arg_name}: {e}" | ||
raise InvalidPathToPlaceholderParameter(message) | ||
return is_set | ||
|
||
@staticmethod | ||
def remove_treated_args(treated_args, args): | ||
for treated_arg in treated_args: | ||
try: | ||
del args[treated_arg] | ||
except KeyError as e: | ||
pass | ||
|
||
def set_tags_config(self, tags, parameters, step_type): | ||
if isinstance(tags, Placeholder): | ||
# Replace with placeholder | ||
path = placeholder_paths.get(step_type).get(Field.Tags.value) | ||
if path: | ||
self._set_placeholder(parameters, path, tags, Field.Tags.value) | ||
else: | ||
parameters['Tags'] = tags_dict_to_kv_list(tags) | ||
|
||
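The two phases of `_replace_sagemaker_placeholders` (write each recognised keyword argument into the request, then drop it from the kwargs) can be sketched in isolation. This is a simplified illustration under assumptions: `PATHS` is a hypothetical two-entry stand-in for `placeholder_paths['ProcessingStep']`, `move_known_kwargs` is a hypothetical name, and strings stand in for Placeholder objects.

```python
import operator
from functools import reduce

# Hypothetical stand-in for placeholder_paths['ProcessingStep'].
PATHS = {
    "role": ["RoleArn"],
    "instance_count": ["ProcessingResources", "ClusterConfig", "InstanceCount"],
}


def move_known_kwargs(kwargs, paths):
    """Copy recognised kwargs into kwargs['parameters'], then remove them."""
    parameters = kwargs["parameters"]
    moved = []
    for name, value in kwargs.items():
        if name == "parameters" or name not in paths:
            continue
        path = paths[name]
        # Only the nested `parameters` dict is mutated while iterating kwargs.
        reduce(operator.getitem, path[:-1], parameters)[path[-1]] = value
        moved.append(name)
    # Deletion happens after the loop so kwargs is not resized mid-iteration.
    for name in moved:
        del kwargs[name]
    return kwargs


kwargs = {
    "parameters": {
        "RoleArn": "static-role",
        "ProcessingResources": {"ClusterConfig": {"InstanceCount": 1}},
    },
    "role": "<role placeholder>",
    "instance_count": "<count placeholder>",
    "comment": "left untouched",
}
move_known_kwargs(kwargs, PATHS)
```

After the call, only `parameters` and `comment` remain in `kwargs`, and the two stand-in values sit at their mapped locations inside `parameters`.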
|
||
class TrainingStep(Task): | ||
|
||
""" | ||
|
@@ -473,13 +576,19 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta | |
super(TuningStep, self).__init__(state_id, **kwargs) | ||
|
||
|
||
-class ProcessingStep(Task): | ||
+class ProcessingStep(SageMakerTask): | ||
Review comment: breaking change? The class constructor used to take in a Task, but that has now been changed. When customers upgrade versions, won't their existing code fail? We cannot make breaking changes, as we need to follow semantic versioning when releasing minor version updates.
Review comment: This changes ProcessingStep's base class, but not the constructor arguments. With this change, instead of calling Task's constructor directly in `__init__()`, we call SageMakerTask's constructor, which in turn calls Task's constructor. Before:
After:
|
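The reply's point, that inserting an intermediate base class leaves the public constructor and `isinstance` checks intact, can be checked with stub classes. The classes below are simplified stand-ins written for this illustration, not the SDK's real implementations; only the inheritance chain and the `super()` call shape follow the diff.

```python
import inspect


class Task:
    """Stub for the SDK's Task state (assumption: reduced to two attributes)."""

    def __init__(self, state_id, **kwargs):
        self.state_id = state_id
        self.fields = kwargs


class SageMakerTask(Task):
    """Stub intermediate class that takes step_type and tags before kwargs."""

    def __init__(self, state_id, step_type, tags, **kwargs):
        self.step_type = step_type
        self.tags = tags
        super().__init__(state_id, **kwargs)


class ProcessingStep(SageMakerTask):
    """Stub step whose caller-facing constructor arguments are unchanged."""

    def __init__(self, state_id, processor, job_name, tags=None, **kwargs):
        kwargs["parameters"] = {"ProcessingJobName": job_name}
        # Same call shape as the diff: class name, then tags, then kwargs.
        super().__init__(state_id, __class__.__name__, tags, **kwargs)


step = ProcessingStep("my-step", processor=None, job_name="job-1")
signature = list(inspect.signature(ProcessingStep.__init__).parameters)
```

Existing callers that construct `ProcessingStep(...)` or test `isinstance(step, Task)` behave as before in this sketch. Code that inspects the direct base class would see a difference.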
||
|
||
""" | ||
Creates a Task State to execute a SageMaker Processing Job. | ||
|
||
The following properties can be passed down as kwargs to the sagemaker.processing.Processor to be used dynamically | ||
in the processing job (compatible with Placeholders): role, image_uri, instance_count, instance_type, | ||
volume_size_in_gb, volume_kms_key, output_kms_key | ||
Review comment: curious: what's the source of truth here? How did we verify that these properties are the ones supported?
Review comment: The ones that are made to be placeholder compatible are:
Since we are replacing the Placeholders with the ExecutionInput when starting the job, all placeholders are replaced by the time the SageMaker job starts. If some args that we configured to hold a placeholder in our state machine were not replaced, this should trigger an error. I can add a test to confirm the behaviour in the next commit. Thanks for bringing this up. This confirms that the documentation is not clear, and we might want to switch to the alternative solution where we would add all Placeholder-compatible properties as optional args in the step constructor, making it clearer to the customer which ones are Placeholder compatible.
Review comment: This docstring is out of date with the new implementation. |
||
""" | ||
|
||
-def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs): | ||
+def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, | ||
+container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, | ||
+tags=None, max_runtime_in_seconds=None, **kwargs): | ||
""" | ||
Args: | ||
ca-nguyen marked this conversation as resolved.
|
||
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. | ||
|
@@ -499,7 +608,8 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp | |
ARN of a KMS key, alias of a KMS key, or alias of a KMS key. | ||
The KmsKeyId is applied to all outputs. | ||
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True) | ||
-tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource. | ||
+tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource. | ||
max_runtime_in_seconds (int or Placeholder): Specifies the maximum runtime in seconds for the processing job | ||
""" | ||
if wait_for_completion: | ||
""" | ||
|
@@ -528,12 +638,12 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp | |
if experiment_config is not None: | ||
parameters['ExperimentConfig'] = experiment_config | ||
|
||
-if tags: | ||
-parameters['Tags'] = tags_dict_to_kv_list(tags) | ||
|
||
if 'S3Operations' in parameters: | ||
del parameters['S3Operations'] | ||
|
||
if max_runtime_in_seconds: | ||
parameters['StoppingCondition'] = ProcessingJob.prepare_stopping_condition(max_runtime_in_seconds) | ||
|
||
kwargs[Field.Parameters.value] = parameters | ||
|
||
-super(ProcessingStep, self).__init__(state_id, **kwargs) | ||
+super(ProcessingStep, self).__init__(state_id, __class__.__name__, tags, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,9 @@ | |
from sagemaker.tuner import HyperparameterTuner | ||
from sagemaker.processing import ProcessingInput, ProcessingOutput | ||
|
||
from stepfunctions.inputs import ExecutionInput | ||
from stepfunctions.steps import Chain | ||
from stepfunctions.steps.fields import Field | ||
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep | ||
from stepfunctions.workflow import Workflow | ||
|
||
|
@@ -352,3 +354,85 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien | |
# Cleanup | ||
state_machine_delete_wait(sfn_client, workflow.state_machine_arn) | ||
# End of Cleanup | ||
|
||
|
||
def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn, | ||
sagemaker_role_arn): | ||
region = boto3.session.Session().region_name | ||
input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region) | ||
Review comment: nit: why not use
Review comment: I agree that using |
||
|
||
input_s3 = sagemaker_session.upload_data( | ||
path=os.path.join(DATA_DIR, 'sklearn_processing'), | ||
bucket=sagemaker_session.default_bucket(), | ||
key_prefix='integ-test-data/sklearn_processing/code' | ||
) | ||
|
||
output_s3 = 's3://' + sagemaker_session.default_bucket() + '/integ-test-data/sklearn_processing' | ||
Review comment: nit: why not use f-strings here instead of concatenation?
Review comment: Agreed - using Same comment: |
||
|
||
inputs = [ | ||
ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'), | ||
ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code', | ||
input_name='code'), | ||
] | ||
|
||
outputs = [ | ||
ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data', | ||
output_name='train_data'), | ||
ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data', | ||
output_name='test_data'), | ||
] | ||
|
||
# Build workflow definition | ||
execution_input = ExecutionInput(schema={ | ||
Field.ImageUri.value: str, | ||
Field.InstanceCount.value: int, | ||
Field.Entrypoint.value: str, | ||
Field.Role.value: str, | ||
Field.VolumeSizeInGB.value: int, | ||
Field.MaxRuntimeInSeconds.value: int | ||
}) | ||
|
||
job_name = generate_job_name() | ||
processing_step = ProcessingStep('create_processing_job_step', | ||
processor=sklearn_processor_fixture, | ||
job_name=job_name, | ||
inputs=inputs, | ||
outputs=outputs, | ||
container_arguments=['--train-test-split-ratio', '0.2'], | ||
container_entrypoint=execution_input[Field.Entrypoint.value], | ||
image_uri=execution_input[Field.ImageUri.value], | ||
instance_count=execution_input[Field.InstanceCount.value], | ||
role=execution_input[Field.Role.value], | ||
volume_size_in_gb=execution_input[Field.VolumeSizeInGB.value], | ||
max_runtime_in_seconds=execution_input[Field.MaxRuntimeInSeconds.value] | ||
) | ||
workflow_graph = Chain([processing_step]) | ||
|
||
with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): | ||
# Create workflow and check definition | ||
Review comment: nit: unnecessary comment, as the method name already expresses this in snake case.
Review comment: Agreed, will be removed with the next commit. |
||
workflow = create_workflow_and_check_definition( | ||
workflow_graph=workflow_graph, | ||
workflow_name=unique_name_from_base("integ-test-processing-step-workflow"), | ||
sfn_client=sfn_client, | ||
sfn_role_arn=sfn_role_arn | ||
) | ||
|
||
execution_input = { | ||
Field.ImageUri.value: '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3', | ||
Field.InstanceCount.value: 1, | ||
Field.Entrypoint.value: ['python3', '/opt/ml/processing/input/code/preprocessor.py'], | ||
Field.Role.value: sagemaker_role_arn, | ||
Field.VolumeSizeInGB.value: 30, | ||
Field.MaxRuntimeInSeconds.value: 500 | ||
} | ||
|
||
# Execute workflow | ||
execution = workflow.execute(inputs=execution_input) | ||
execution_output = execution.get_output(wait=True) | ||
|
||
# Check workflow output | ||
assert execution_output.get("ProcessingJobStatus") == "Completed" | ||
|
||
# Cleanup | ||
state_machine_delete_wait(sfn_client, workflow.state_machine_arn) | ||
# End of Cleanup | ||
Review comment: nit: I think the code is self-explanatory. We can drop this comment 😅
Review comment: You're right! I'll remove the comments :)
Review comment: Did you forget to remove this?
Review comment: Yes, will remove it in the next commit! |
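For the earlier nit about building `output_s3` by concatenation, a small sketch of the f-string form follows. The `bucket` and `region` values are hypothetical stand-ins for `sagemaker_session.default_bucket()` and the boto3 session region. Rewriting the `str.format` call the same way is an assumption about what a consistent style would look like, not something the thread states.

```python
# Hypothetical values standing in for the session lookups used in the test.
bucket = "example-bucket"
region = "us-east-1"

# Concatenation as written in the test, and the f-string equivalent.
concatenated = 's3://' + bucket + '/integ-test-data/sklearn_processing'
output_s3 = f's3://{bucket}/integ-test-data/sklearn_processing'

# str.format as written in the test, and the f-string equivalent.
formatted = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region)
input_data = f's3://sagemaker-sample-data-{region}/processing/census/census-income.csv'
```

Both pairs produce identical strings, so the change is purely stylistic.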
Review comment: question:
Review comment: 1: I took the location in the CreateProcessingJob request from there to save in placeholder_paths for each arg
2: all can