diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md new file mode 100644 index 0000000..6f68eb1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -0,0 +1,51 @@ +--- +name: "\U0001F41B Bug Report" +about: Report a bug +title: "short issue description" +labels: bug, needs-triage +--- + + + + + + +### Reproduction Steps + + + +### What did you expect to happen? + + + +### What actually happened? + + + + +### Environment + + - **AWS Step Functions Data Science Python SDK version :** + - **Python Version:** + +### Other + + + + + + +--- + +This is :bug: Bug Report \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/doc.md b/.github/ISSUE_TEMPLATE/doc.md new file mode 100644 index 0000000..931e8ac --- /dev/null +++ b/.github/ISSUE_TEMPLATE/doc.md @@ -0,0 +1,28 @@ +--- +name: "📕 Documentation Issue" +about: Issue in the reference documentation +title: "short issue description" +labels: feature-request, documentation, needs-triage +--- + + + + + + + + + + + + + +--- + +This is a 📕 documentation issue diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md new file mode 100644 index 0000000..5570a1b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.md @@ -0,0 +1,46 @@ +--- +name: "\U0001F680 Feature Request" +about: Request a new feature +title: "short issue description" +labels: feature-request, needs-triage +--- + + + + + + + +### Use Case + + + + + + + +### Proposed Solution + + + + + + + +### Other + + + + + + + +* [ ] :wave: I may be able to implement this feature request +* [ ] :warning: This feature might incur a breaking change + +--- + +This is a :rocket: Feature Request diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..3f06f79 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,43 @@ +### Description + +Please include a summary of the change being made. + +Fixes #(issue) + +### Why is the change necessary? + +What capability does it enable? What problem does it solve? + +### Solution + +Please include an overview of the solution. Discuss trade-offs made, caveats, alternatives, etc. + +### Testing + +How was this change tested? + +---- + +### Pull Request Checklist + +Please check all boxes (including N/A items) + +#### Testing + +- [ ] Unit tests added +- [ ] Integration test added +- [ ] Manual testing - why was it necessary? could it be automated? + +#### Documentation + +- [ ] __docs__: All relevant [docs](https://github.com/aws/aws-step-functions-data-science-sdk-python/tree/main/doc) updated +- [ ] __docstrings__: All public APIs documented + +### Title and description + +- [ ] __Change type__: Title is prefixed with change type: and follows [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) +- [ ] __References__: Indicate issues fixed via: `Fixes #xxx` + +---- + +By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license. diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 28edd35..8ee82a1 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -185,36 +185,42 @@ def __merge_hyperparameters(self, training_step_hyperparameters, estimator_hyper merged_hyperparameters[key] = value return merged_hyperparameters + class TransformStep(Task): """ Creates a Task State to execute a `SageMaker Transform Job `_. 
""" - def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, input_filter=None, output_filter=None, join_source=None, **kwargs): + def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, + compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, + input_filter=None, output_filter=None, join_source=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. transformer (sagemaker.transformer.Transformer): The SageMaker transformer to use in the TransformStep. job_name (str or Placeholder): Specify a transform job name. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution. model_name (str or Placeholder): Specify a model name for the transform job to use. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution. - data (str): Input data location in S3. - data_type (str): What the S3 location defines (default: 'S3Prefix'). + data (str or Placeholder): Input data location in S3. + data_type (str or Placeholder): What the S3 location defines (default: 'S3Prefix'). Valid values: * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will be used as inputs for the transform job. * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object to use as an input for the transform job. - content_type (str): MIME type of the input data (default: None). - compression_type (str): Compression type of the input data, if compressed (default: None). Valid values: 'Gzip', None. - split_type (str): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'. - experiment_config (dict, optional): Specify the experiment config for the transform. (Default: None) + content_type (str or Placeholder): MIME type of the input data (default: None). + compression_type (str or Placeholder): Compression type of the input data, if compressed (default: None). Valid values: 'Gzip', None. + split_type (str or Placeholder): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'. + experiment_config (dict or Placeholder, optional): Specify the experiment config for the transform. (Default: None) wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the transform job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the transform job and proceed to the next step. (default: True) - tags (list[dict], optional): `List to tags `_ to associate with the resource. - input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for inference. If you omit the field, it gets the value ‘$’, representing the entire input. For CSV data, each row is taken as a JSON array, so only index-based JSONPaths can be applied, e.g. $[0], $[1:]. CSV data should follow the RFC format. See Supported JSONPath Operators for a table of supported JSONPath operators. 
For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.features” (default: None). - output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.prediction” (default: None). - join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None. + tags (list[dict] or Placeholder, optional): `List of tags `_ to associate with the resource. + input_filter (str or Placeholder): A JSONPath to select a portion of the input to pass to the algorithm container for inference. If you omit the field, it gets the value ‘$’, representing the entire input. For CSV data, each row is taken as a JSON array, so only index-based JSONPaths can be applied, e.g. $[0], $[1:]. CSV data should follow the RFC format. See Supported JSONPath Operators for a table of supported JSONPath operators. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.features” (default: None). + output_filter (str or Placeholder): A JSONPath to select a portion of the joined/original output to return as the output. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.prediction” (default: None). + join_source (str or Placeholder): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None. + parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateTransformJob`_. + You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders`_.
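To make the new `parameters` merge behavior concrete, here is a minimal usage sketch (not taken from this diff) in the spirit of the integration test further down. The bucket, model name, and instance settings are illustrative placeholders, and the `Transformer` would normally come from a trained estimator:

```python
from sagemaker.transformer import Transformer

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps import TransformStep

# Illustrative transformer; model name, bucket, and instance settings are placeholders.
transformer = Transformer(
    model_name='my-model',
    instance_count=1,
    instance_type='ml.m5.xlarge',
    output_path='s3://my-bucket/transform-output',
)

execution_input = ExecutionInput(schema={
    'job_name': str,
    'model_name': str,
    'instance_type': str,
})

transform_step = TransformStep(
    'Transform',
    transformer=transformer,
    job_name=execution_input['job_name'],
    model_name=execution_input['model_name'],
    data='s3://my-bucket/transform-input',
    content_type='text/csv',
    # Keys in `parameters` are merged into the generated CreateTransformJob request
    # payload and override the values derived from the arguments above.
    parameters={
        'TransformResources': {
            'InstanceType': execution_input['instance_type'],
        },
        'MaxPayloadInMB': 5,
    },
)
```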
+ """ if wait_for_completion: """ @@ -233,7 +239,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= SageMakerApi.CreateTransformJob) if isinstance(job_name, str): - parameters = transform_config( + transform_parameters = transform_config( transformer=transformer, data=data, data_type=data_type, @@ -246,7 +252,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= join_source=join_source ) else: - parameters = transform_config( + transform_parameters = transform_config( transformer=transformer, data=data, data_type=data_type, @@ -259,17 +265,21 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= ) if isinstance(job_name, Placeholder): - parameters['TransformJobName'] = job_name + transform_parameters['TransformJobName'] = job_name - parameters['ModelName'] = model_name + transform_parameters['ModelName'] = model_name if experiment_config is not None: - parameters['ExperimentConfig'] = experiment_config + transform_parameters['ExperimentConfig'] = experiment_config if tags: - parameters['Tags'] = tags_dict_to_kv_list(tags) + transform_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags) - kwargs[Field.Parameters.value] = parameters + if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict): + # Update transform_parameters with input parameters + merge_dicts(transform_parameters, kwargs[Field.Parameters.value]) + + kwargs[Field.Parameters.value] = transform_parameters super(TransformStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/states.py b/src/stepfunctions/steps/states.py index 9669a73..8396e69 100644 --- a/src/stepfunctions/steps/states.py +++ b/src/stepfunctions/steps/states.py @@ -254,27 +254,29 @@ def accept(self, visitor): def add_retry(self, retry): """ - Add a Retry block to the tail end of the list of retriers for the state. + Add a retrier or a list of retriers to the tail end of the list of retriers for the state. + See `Error handling in Step Functions `_ for more details. Args: - retry (Retry): Retry block to add. + retry (Retry or list(Retry)): A retrier or list of retriers to add. """ if Field.Retry in self.allowed_fields(): - self.retries.append(retry) + self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry) else: - raise ValueError("{state_type} state does not support retry field. ".format(state_type=type(self).__name__)) + raise ValueError(f"{type(self).__name__} state does not support retry field. ") def add_catch(self, catch): """ - Add a Catch block to the tail end of the list of catchers for the state. + Add a catcher or a list of catchers to the tail end of the list of catchers for the state. + See `Error handling in Step Functions `_ for more details. Args: - catch (Catch): Catch block to add. + catch (Catch or list(Catch): catcher or list of catchers to add. """ if Field.Catch in self.allowed_fields(): - self.catches.append(catch) + self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch) else: - raise ValueError("{state_type} state does not support catch field. ".format(state_type=type(self).__name__)) + raise ValueError(f"{type(self).__name__} state does not support catch field. 
") def to_dict(self): result = super(State, self).to_dict() @@ -487,10 +489,12 @@ class Parallel(State): A Parallel state causes the interpreter to execute each branch as concurrently as possible, and wait until each branch terminates (reaches a terminal state) before processing the next state in the Chain. """ - def __init__(self, state_id, **kwargs): + def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. + retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions `_ for more details. + catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions `_ for more details. comment (str, optional): Human-readable comment or description. (default: None) input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$') parameters (dict, optional): The value of this field becomes the effective input for the state. @@ -500,6 +504,12 @@ def __init__(self, state_id, **kwargs): super(Parallel, self).__init__(state_id, 'Parallel', **kwargs) self.branches = [] + if retry: + self.add_retry(retry) + + if catch: + self.add_catch(catch) + def allowed_fields(self): return [ Field.Comment, @@ -536,11 +546,13 @@ class Map(State): A Map state can accept an input with a list of items, execute a state or chain for each item in the list, and return a list, with all corresponding results of each execution, as its output. """ - def __init__(self, state_id, **kwargs): + def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. iterator (State or Chain): State or chain to execute for each of the items in `items_path`. + retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions `_ for more details. + catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions `_ for more details. items_path (str, optional): Path in the input for items to iterate over. (default: '$') max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0) comment (str, optional): Human-readable comment or description. (default: None) @@ -551,6 +563,12 @@ def __init__(self, state_id, **kwargs): """ super(Map, self).__init__(state_id, 'Map', **kwargs) + if retry: + self.add_retry(retry) + + if catch: + self.add_catch(catch) + def attach_iterator(self, iterator): """ Attach `State` or `Chain` as iterator to the Map state, that will execute for each of the items in `items_path`. If an iterator was attached previously with the Map state, it will be replaced. @@ -586,10 +604,12 @@ class Task(State): Task State causes the interpreter to execute the work identified by the state’s `resource` field. """ - def __init__(self, state_id, **kwargs): + def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. 
State names **must be** unique within the scope of the whole state machine. + retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions `_ for more details. + catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions `_ for more details. resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI. timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. @@ -608,6 +628,12 @@ def __init__(self, state_id, **kwargs): if self.heartbeat_seconds is not None and self.heartbeat_seconds_path is not None: raise ValueError("Only one of 'heartbeat_seconds' or 'heartbeat_seconds_path' can be provided.") + if retry: + self.add_retry(retry) + + if catch: + self.add_catch(catch) + def allowed_fields(self): return [ Field.Comment, diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py index b415f1f..2b497ed 100644 --- a/tests/integ/test_sagemaker_steps.py +++ b/tests/integ/test_sagemaker_steps.py @@ -179,6 +179,96 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn): state_machine_delete_wait(sfn_client, workflow.state_machine_arn) # End of Cleanup + +def test_transform_step_with_placeholder(trained_estimator, sfn_client, sfn_role_arn): + # Create transformer from supplied estimator + job_name = generate_job_name() + pca_transformer = trained_estimator.transformer(instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE) + + # Create a model step to save the model + model_step = ModelStep('create_model_step', model=trained_estimator.create_model(), model_name=job_name) + model_step.add_retry(SAGEMAKER_RETRY_STRATEGY) + + # Upload data for transformation to S3 + data_path = os.path.join(DATA_DIR, "one_p_mnist") + transform_input_path = os.path.join(data_path, "transform_input.csv") + transform_input_key_prefix = "integ-test-data/one_p_mnist/transform" + transform_input = pca_transformer.sagemaker_session.upload_data( + path=transform_input_path, key_prefix=transform_input_key_prefix + ) + + execution_input = ExecutionInput(schema={ + 'data': str, + 'content_type': str, + 'split_type': str, + 'job_name': str, + 'model_name': str, + 'instance_count': int, + 'instance_type': str, + 'strategy': str, + 'max_concurrent_transforms': int, + 'max_payload': int, + }) + + parameters = { + 'BatchStrategy': execution_input['strategy'], + 'TransformInput': { + 'SplitType': execution_input['split_type'], + }, + 'TransformResources': { + 'InstanceCount': execution_input['instance_count'], + 'InstanceType': execution_input['instance_type'], + }, + 'MaxConcurrentTransforms': execution_input['max_concurrent_transforms'], + 'MaxPayloadInMB': execution_input['max_payload'] + } + + # Build workflow definition + transform_step = TransformStep( + 'create_transform_job_step', + pca_transformer, + job_name=execution_input['job_name'], + model_name=execution_input['model_name'], + data=execution_input['data'], + 
content_type=execution_input['content_type'], + parameters=parameters + ) + transform_step.add_retry(SAGEMAKER_RETRY_STRATEGY) + workflow_graph = Chain([model_step, transform_step]) + + with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): + # Create workflow and check definition + workflow = create_workflow_and_check_definition( + workflow_graph=workflow_graph, + workflow_name=unique_name_from_base("integ-test-transform-step-workflow"), + sfn_client=sfn_client, + sfn_role_arn=sfn_role_arn + ) + + execution_input = { + 'job_name': job_name, + 'model_name': job_name, + 'data': transform_input, + 'content_type': "text/csv", + 'instance_count': INSTANCE_COUNT, + 'instance_type': INSTANCE_TYPE, + 'split_type': 'Line', + 'strategy': 'SingleRecord', + 'max_concurrent_transforms': 2, + 'max_payload': 5 + } + + # Execute workflow + execution = workflow.execute(inputs=execution_input) + execution_output = execution.get_output(wait=True) + + # Check workflow output + assert execution_output.get("TransformJobStatus") == "Completed" + + # Cleanup + state_machine_delete_wait(sfn_client, workflow.state_machine_arn) + + def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn): # Setup: Create model for trained estimator in SageMaker model = trained_estimator.create_model() diff --git a/tests/integ/test_state_machine_definition.py b/tests/integ/test_state_machine_definition.py index d21e59b..4881b75 100644 --- a/tests/integ/test_state_machine_definition.py +++ b/tests/integ/test_state_machine_definition.py @@ -422,18 +422,38 @@ def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_para def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters): catch_state_name = "TaskWithCatchState" - custom_error = "CustomError" task_failed_error = "States.TaskFailed" - all_fail_error = "States.ALL" - custom_error_state_name = "Custom Error End" - task_failed_state_name = "Task Failed End" - all_error_state_name = "Catch All End" + timeout_error = "States.Timeout" + task_failed_state_name = "Catch Task Failed End" + timeout_state_name = "Catch Timeout End" catch_state_result = "Catch Result" task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" - # change the parameters to cause task state to fail + # Provide invalid TrainingImage to cause States.TaskFailed error training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" + task = steps.Task( + catch_state_name, + parameters=training_job_parameters, + resource=task_resource, + catch=steps.Catch( + error_equals=[timeout_error], + next_step=steps.Pass(timeout_state_name, result=catch_state_result) + ) + ) + task.add_catch( + steps.Catch( + error_equals=[task_failed_error], + next_step=steps.Pass(task_failed_state_name, result=catch_state_result) + ) + ) + + workflow = Workflow( + unique_name_from_base('Test_Catch_Workflow'), + definition=task, + role=sfn_role_arn + ) + asl_state_machine_definition = { "StartAt": catch_state_name, "States": { @@ -445,80 +465,61 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par "Catch": [ { "ErrorEquals": [ - all_fail_error + timeout_error ], - "Next": all_error_state_name + "Next": timeout_state_name + }, + { + "ErrorEquals": [ + task_failed_error + ], + "Next": task_failed_state_name } ] }, - all_error_state_name: { + task_failed_state_name: { "Type": "Pass", "Result": catch_state_result, "End": True - } + }, + timeout_state_name: { + "Type": "Pass", + "Result": 
catch_state_result, + "End": True + }, } } - task = steps.Task( - catch_state_name, - parameters=training_job_parameters, - resource=task_resource - ) - task.add_catch( - steps.Catch( - error_equals=[all_fail_error], - next_step=steps.Pass(all_error_state_name, result=catch_state_result) - ) - ) - - workflow = Workflow( - unique_name_from_base('Test_Catch_Workflow'), - definition=task, - role=sfn_role_arn - ) workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result) def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters): retry_state_name = "RetryStateName" - all_fail_error = "Starts.ALL" + task_failed_error = "States.TaskFailed" + timeout_error = "States.Timeout" interval_seconds = 1 max_attempts = 2 backoff_rate = 2 task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" - # change the parameters to cause task state to fail + # Provide invalid TrainingImage to cause States.TaskFailed error training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" - asl_state_machine_definition = { - "StartAt": retry_state_name, - "States": { - retry_state_name: { - "Resource": task_resource, - "Parameters": training_job_parameters, - "Type": "Task", - "End": True, - "Retry": [ - { - "ErrorEquals": [all_fail_error], - "IntervalSeconds": interval_seconds, - "MaxAttempts": max_attempts, - "BackoffRate": backoff_rate - } - ] - } - } - } - task = steps.Task( retry_state_name, parameters=training_job_parameters, - resource=task_resource + resource=task_resource, + retry=steps.Retry( + error_equals=[timeout_error], + interval_seconds=interval_seconds, + max_attempts=max_attempts, + backoff_rate=backoff_rate + ) ) task.add_retry( steps.Retry( - error_equals=[all_fail_error], + error_equals=[task_failed_error], interval_seconds=interval_seconds, max_attempts=max_attempts, backoff_rate=backoff_rate @@ -531,4 +532,30 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par role=sfn_role_arn ) - workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None) \ No newline at end of file + asl_state_machine_definition = { + "StartAt": retry_state_name, + "States": { + retry_state_name: { + "Resource": task_resource, + "Parameters": training_job_parameters, + "Type": "Task", + "End": True, + "Retry": [ + { + "ErrorEquals": [timeout_error], + "IntervalSeconds": interval_seconds, + "MaxAttempts": max_attempts, + "BackoffRate": backoff_rate + }, + { + "ErrorEquals": [task_failed_error], + "IntervalSeconds": interval_seconds, + "MaxAttempts": max_attempts, + "BackoffRate": backoff_rate + } + ] + } + } + } + + workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None) diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index de95858..c18756e 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -902,6 +902,117 @@ def test_transform_step_creation(pca_transformer): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_transform_step_creation_with_placeholder(pca_transformer): + execution_input = ExecutionInput(schema={ + 'data': str, + 'data_type': str, + 'content_type': str, + 'compression_type': str, + 'split_type': str, + 'input_filter': str, + 'output_filter': str, + 'join_source': str, + 'job_name': str, + 'model_name': str, + 'instance_count': int, + 'strategy': str, + 'assemble_with': str, + 'output_path': str, + 'output_kms_key': str, + 
'accept': str, + 'max_concurrent_transforms': int, + 'max_payload': int, + 'tags': [{str: str}], + 'env': str, + 'volume_kms_key': str, + 'experiment_config': str, + }) + + step_input = StepInput(schema={ + 'instance_type': str + }) + + parameters = { + 'BatchStrategy': execution_input['strategy'], + 'TransformOutput': { + 'Accept': execution_input['accept'], + 'AssembleWith': execution_input['assemble_with'], + 'KmsKeyId': execution_input['output_kms_key'], + 'S3OutputPath': execution_input['output_path'] + }, + 'TransformResources': { + 'InstanceCount': execution_input['instance_count'], + 'InstanceType': step_input['instance_type'], + 'VolumeKmsKeyId': execution_input['volume_kms_key'] + }, + 'ExperimentConfig': execution_input['experiment_config'], + 'Tags': execution_input['tags'], + 'Environment': execution_input['env'], + 'MaxConcurrentTransforms': execution_input['max_concurrent_transforms'], + 'MaxPayloadInMB': execution_input['max_payload'], + } + + step = TransformStep('Inference', + transformer=pca_transformer, + data=execution_input['data'], + data_type=execution_input['data_type'], + content_type=execution_input['content_type'], + compression_type=execution_input['compression_type'], + split_type=execution_input['split_type'], + job_name=execution_input['job_name'], + model_name=execution_input['model_name'], + experiment_config={ + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Transform' + }, + tags=execution_input['tags'], + join_source=execution_input['join_source'], + output_filter=execution_input['output_filter'], + input_filter=execution_input['input_filter'], + parameters=parameters + ) + + assert step.to_dict()['Parameters'] == { + 'BatchStrategy.$': "$$.Execution.Input['strategy']", + 'ModelName.$': "$$.Execution.Input['model_name']", + 'TransformInput': { + 'CompressionType.$': "$$.Execution.Input['compression_type']", + 'ContentType.$': "$$.Execution.Input['content_type']", + 'DataSource': { + 'S3DataSource': { + 'S3DataType.$': "$$.Execution.Input['data_type']", + 'S3Uri.$': "$$.Execution.Input['data']" + } + }, + 'SplitType.$': "$$.Execution.Input['split_type']" + }, + 'TransformOutput': { + 'Accept.$': "$$.Execution.Input['accept']", + 'AssembleWith.$': "$$.Execution.Input['assemble_with']", + 'KmsKeyId.$': "$$.Execution.Input['output_kms_key']", + 'S3OutputPath.$': "$$.Execution.Input['output_path']" + }, + 'TransformJobName.$': "$$.Execution.Input['job_name']", + 'TransformResources': { + 'InstanceCount.$': "$$.Execution.Input['instance_count']", + 'InstanceType.$': "$['instance_type']", + 'VolumeKmsKeyId.$': "$$.Execution.Input['volume_kms_key']" + }, + 'ExperimentConfig.$': "$$.Execution.Input['experiment_config']", + 'DataProcessing': { + 'InputFilter.$': "$$.Execution.Input['input_filter']", + 'OutputFilter.$': "$$.Execution.Input['output_filter']", + 'JoinSource.$': "$$.Execution.Input['join_source']", + }, + 'Tags.$': "$$.Execution.Input['tags']", + 'Environment.$': "$$.Execution.Input['env']", + 'MaxConcurrentTransforms.$': "$$.Execution.Input['max_concurrent_transforms']", + 'MaxPayloadInMB.$': "$$.Execution.Input['max_payload']" + } + + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_get_expected_model(pca_estimator): diff --git a/tests/unit/test_steps.py b/tests/unit/test_steps.py index 5c86279..3d34ee8 100644 --- a/tests/unit/test_steps.py +++ b/tests/unit/test_steps.py @@ -469,4 +469,126 @@ def 
test_default_paths_not_converted_to_null(): assert '"OutputPath": null' not in task_state.to_json() - +RETRY = Retry(error_equals=['ErrorA', 'ErrorB'], interval_seconds=1, max_attempts=2, backoff_rate=2) +RETRIES = [RETRY, Retry(error_equals=['ErrorC'], interval_seconds=5)] +EXPECTED_RETRY = [{'ErrorEquals': ['ErrorA', 'ErrorB'], 'IntervalSeconds': 1, 'BackoffRate': 2, 'MaxAttempts': 2}] +EXPECTED_RETRIES = EXPECTED_RETRY + [{'ErrorEquals': ['ErrorC'], 'IntervalSeconds': 5}] + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES), +]) +def test_parallel_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Parallel('Parallel', retry=retry) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES), +]) +def test_parallel_state_add_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Parallel('Parallel') + step.add_retry(retry) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES), +]) +def test_map_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Map('Map', retry=retry, iterator=Pass('Iterator')) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRIES, EXPECTED_RETRIES), + (RETRY, EXPECTED_RETRY), +]) +def test_map_state_add_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Map('Map', iterator=Pass('Iterator')) + step.add_retry(retry) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES) +]) +def test_task_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Task('Task', retry=retry) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES) +]) +def test_task_state_add_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Task('Task') + step.add_retry(retry) + assert step.to_dict()['Retry'] == expected_retry + + +CATCH = Catch(error_equals=['States.ALL'], next_step=Pass('End State')) +CATCHES = [CATCH, Catch(error_equals=['States.TaskFailed'], next_step=Pass('Next State'))] +EXPECTED_CATCH = [{'ErrorEquals': ['States.ALL'], 'Next': 'End State'}] +EXPECTED_CATCHES = EXPECTED_CATCH + [{'ErrorEquals': ['States.TaskFailed'], 'Next': 'Next State'}] + + +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) +]) +def test_parallel_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Parallel('Parallel', catch=catch) + assert step.to_dict()['Catch'] == expected_catch + +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) +]) +def test_parallel_state_add_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Parallel('Parallel') + step.add_catch(catch) + assert step.to_dict()['Catch'] == expected_catch + + +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) +]) +def test_map_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Map('Map', catch=catch, 
iterator=Pass('Iterator')) + assert step.to_dict()['Catch'] == expected_catch + + +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) +]) +def test_map_state_add_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Map('Map', iterator=Pass('Iterator')) + step.add_catch(catch) + assert step.to_dict()['Catch'] == expected_catch + + +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) +]) +def test_task_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Task('Task', catch=catch) + assert step.to_dict()['Catch'] == expected_catch + + +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) +]) +def test_task_state_add_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Task('Task') + step.add_catch(catch) + assert step.to_dict()['Catch'] == expected_catch diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index b11398e..31b51fd 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -272,7 +272,7 @@ def test_list_workflows(client): def test_cloudformation_export_with_simple_definition(workflow): cfn_template = workflow.get_cloudformation_template() - cfn_template = yaml.load(cfn_template) + cfn_template = yaml.safe_load(cfn_template) assert 'StateMachineComponent' in cfn_template['Resources'] assert workflow.role == cfn_template['Resources']['StateMachineComponent']['Properties']['RoleArn'] assert cfn_template['Description'] == "CloudFormation template for AWS Step Functions - State Machine" @@ -300,7 +300,7 @@ def test_cloudformation_export_with_sagemaker_execution_role(workflow): } }) cfn_template = workflow.get_cloudformation_template(description="CloudFormation template with Sagemaker role") - cfn_template = yaml.load(cfn_template) + cfn_template = yaml.safe_load(cfn_template) assert json.dumps(workflow.definition.to_dict(), indent=2) == cfn_template['Resources']['StateMachineComponent']['Properties']['DefinitionString'] assert workflow.role == cfn_template['Resources']['StateMachineComponent']['Properties']['RoleArn'] assert cfn_template['Description'] == "CloudFormation template with Sagemaker role"
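As a usage sketch of the `retry`/`catch` constructor arguments and list support added to `states.py` above (mirroring the parametrized unit tests), assuming an illustrative Lambda resource ARN and state names:

```python
from stepfunctions.steps import Catch, Pass, Retry, Task

task = Task(
    'ProcessData',
    resource='arn:aws:lambda:us-east-1:123456789012:function:process-data',  # illustrative ARN
    # A single retrier/catcher or a list of them can now be passed at construction time.
    retry=[
        Retry(error_equals=['States.Timeout'], interval_seconds=5, max_attempts=2, backoff_rate=2),
        Retry(error_equals=['States.TaskFailed'], interval_seconds=10, max_attempts=3),
    ],
    catch=Catch(error_equals=['States.TaskFailed'], next_step=Pass('HandleTaskFailure')),
)

# add_retry / add_catch accept either a single block or a list as well.
task.add_catch(Catch(error_equals=['States.ALL'], next_step=Pass('HandleAnyError')))

print(task.to_dict()['Retry'])  # two retriers, in the order they were added
print(task.to_dict()['Catch'])  # two catchers; the 'States.ALL' catcher stays last
```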