Support placeholders for processing step #155
@@ -14,6 +14,7 @@
import boto3
import logging
from stepfunctions.inputs import Placeholder

logger = logging.getLogger('stepfunctions')
@@ -45,3 +46,28 @@ def get_aws_partition():
    return cur_partition


def merge_dicts(target, source):
    """
    Merges the source dictionary into the target dictionary.
    Values in the target dict are updated with the values of the source dict.

    Args:
        target (dict): Base dictionary into which source is merged
        source (dict): Dictionary used to update target. If the same key is present in both
            dictionaries, source's value will overwrite target's value for the corresponding key
    """
    if isinstance(target, dict) and isinstance(source, dict):
        for key, value in source.items():
            if key in target:
                if isinstance(target[key], dict) and isinstance(source[key], dict):
                    merge_dicts(target[key], source[key])
                elif target[key] == value:
                    pass
                else:
                    logger.info(
                        f"Property: <{key}> with value: <{target[key]}>"
                        f" will be overwritten with provided value: <{value}>")
                    target[key] = source[key]
            else:
                target[key] = source[key]

Review comment: Question: is it typical to modify a dict in place rather than return a merged dict that doesn't manipulate its inputs?

Author reply: This was implemented with the dict.update() function in mind, which updates a dict in place with the contents of another. In our case, we are merging nested dicts as well.

Review comment: Nice to see recursion being used :)
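To make the in-place semantics discussed above concrete, here is a minimal usage sketch; the dictionary contents are invented for illustration and are not part of the PR:

    target = {'ClusterConfig': {'InstanceCount': 1, 'VolumeSizeInGB': 30}}
    source = {'ClusterConfig': {'InstanceCount': 2}}

    merge_dicts(target, source)

    # target is updated in place, like dict.update(), but nested dicts are
    # merged recursively instead of being replaced wholesale; the overwrite
    # of InstanceCount is logged at INFO level.
    assert target == {'ClusterConfig': {'InstanceCount': 2, 'VolumeSizeInGB': 30}}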
@@ -29,6 +29,7 @@
from sagemaker.tuner import HyperparameterTuner
from sagemaker.processing import ProcessingInput, ProcessingOutput

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps import Chain
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
from stepfunctions.workflow import Workflow
@@ -352,3 +353,98 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_client
    # Cleanup
    state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
    # End of Cleanup


def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,
                                           sagemaker_role_arn):
    region = boto3.session.Session().region_name
    input_data = f"s3://sagemaker-sample-data-{region}/processing/census/census-income.csv"

    input_s3 = sagemaker_session.upload_data(
        path=os.path.join(DATA_DIR, 'sklearn_processing'),
        bucket=sagemaker_session.default_bucket(),
        key_prefix='integ-test-data/sklearn_processing/code'
    )

    output_s3 = f"s3://{sagemaker_session.default_bucket()}/integ-test-data/sklearn_processing"

    inputs = [
        ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'),
        ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code',
                        input_name='code'),
    ]

    outputs = [
        ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data',
                         output_name='train_data'),
        ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data',
                         output_name='test_data'),
    ]

    # Build workflow definition
    execution_input = ExecutionInput(schema={
        'image_uri': str,
        'instance_count': int,
        'entrypoint': str,
        'role': str,
        'volume_size_in_gb': int,
        'max_runtime_in_seconds': int,
        'container_arguments': [str],
    })

    parameters = {
        'AppSpecification': {
            'ContainerEntrypoint': execution_input['entrypoint'],
            'ImageUri': execution_input['image_uri']
        },
        'ProcessingResources': {
            'ClusterConfig': {
                'InstanceCount': execution_input['instance_count'],
                'VolumeSizeInGB': execution_input['volume_size_in_gb']
            }
        },
        'RoleArn': execution_input['role'],
        'StoppingCondition': {
            'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
        }
    }

    job_name = generate_job_name()
    processing_step = ProcessingStep('create_processing_job_step',
                                     processor=sklearn_processor_fixture,
                                     job_name=job_name,
                                     inputs=inputs,
                                     outputs=outputs,
                                     container_arguments=execution_input['container_arguments'],
                                     container_entrypoint=execution_input['entrypoint'],
                                     parameters=parameters
                                     )
    workflow_graph = Chain([processing_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-processing-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn
        )

        execution_input = {
            'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
            'instance_count': 1,
            'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
            'role': sagemaker_role_arn,
            'volume_size_in_gb': 30,
            'max_runtime_in_seconds': 500,
            'container_arguments': ['--train-test-split-ratio', '0.2']
        }

        # Execute workflow
        execution = workflow.execute(inputs=execution_input)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("ProcessingJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)

Review comment (on the ExecutionInput schema): Since we're only using these values for test purposes, we use the direct string values for better code readability.
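For readers skimming the PR, the user-facing pattern this test exercises boils down to the following sketch. The processor object, workflow name, and role variable are illustrative stand-ins, not values from this PR:

    from stepfunctions.inputs import ExecutionInput
    from stepfunctions.steps import Chain
    from stepfunctions.steps.sagemaker import ProcessingStep
    from stepfunctions.workflow import Workflow

    # Declare typed placeholders that are resolved at execution time
    execution_input = ExecutionInput(schema={
        'instance_count': int,
        'container_arguments': [str],
    })

    # Placeholders can now be passed to ProcessingStep where concrete
    # values were previously required
    step = ProcessingStep(
        'processing-with-placeholders',
        processor=my_processor,  # hypothetical sagemaker Processor instance
        job_name='my-processing-job',
        container_arguments=execution_input['container_arguments'],
        parameters={
            'ProcessingResources': {
                'ClusterConfig': {'InstanceCount': execution_input['instance_count']}
            }
        }
    )

    # my_sfn_role_arn is a hypothetical Step Functions execution role ARN
    workflow = Workflow('placeholder-demo', definition=Chain([step]), role=my_sfn_role_arn)
    workflow.create()

    # Concrete values are supplied only when the workflow is executed
    workflow.execute(inputs={
        'instance_count': 1,
        'container_arguments': ['--train-test-split-ratio', '0.2'],
    })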
Review comment: Nice to see us embracing PEP 8 in files we touch 🙌

Reply: 🙌🙌🙌