Skip to content

Commit

Permalink
Merge pull request #84 from yzhu0/modifyTCUpdateParam
Browse files Browse the repository at this point in the history
Enable Removal of Parameters and Artifacts under TrialComponent update
  • Loading branch information
yzhu0 authored Jul 1, 2020
2 parents a5e8f78 + 0ee91c9 commit 9af182d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
9 changes: 9 additions & 0 deletions src/smexperiments/trial_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class TrialComponent(_base_types.Record):
input_artiacts (dict): Dictionary of input artifacts.
output_artiacts (dict): Dictionary of output artifacts.
metrics (obj): Aggregated metrics for the job.
parameters_to_remove (list): The hyperparameters to remove from the component.
input_artifacts_to_remove (list): The input artifacts to remove from the component.
output_artifacts_to_remove (list): The output artifacts to remove from the component.
"""

trial_component_name = None
Expand All @@ -57,6 +60,9 @@ class TrialComponent(_base_types.Record):
input_artifacts = None
output_artifacts = None
metrics = None
parameters_to_remove = None
input_artifacts_to_remove = None
output_artifacts_to_remove = None

_boto_load_method = "describe_trial_component"
_boto_create_method = "create_trial_component"
Expand All @@ -81,6 +87,9 @@ class TrialComponent(_base_types.Record):
"parameters",
"input_artifacts",
"output_artifacts",
"parameters_to_remove",
"input_artifacts_to_remove",
"output_artifacts_to_remove",
]
_boto_delete_members = ["trial_component_name"]

Expand Down
20 changes: 15 additions & 5 deletions tests/integ/test_trial_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,17 @@ def test_save(trial_component_obj, sagemaker_boto_client):
trial_component_obj.end_time = datetime.datetime.now(datetime.timezone.utc)
trial_component_obj.parameters = {"foo": "bar", "whizz": 100.1}
trial_component_obj.input_artifacts = {
"snizz": api_types.TrialComponentArtifact(value="s3:/foo/bar", media_type="text/plain")
"snizz": api_types.TrialComponentArtifact(value="s3:/foo/bar", media_type="text/plain"),
"snizz1": api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2"),
}
trial_component_obj.output_artifacts = {
"fly": api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow")
"fly": api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow"),
"fly2": api_types.TrialComponentArtifact(value="s3:/sky/far2", media_type="away/tomorrow2"),
}
trial_component_obj.parameters_to_remove = ["foo"]
trial_component_obj.input_artifacts_to_remove = ["snizz"]
trial_component_obj.output_artifacts_to_remove = ["fly2"]

trial_component_obj.save()

loaded = trial_component.TrialComponent.load(
Expand All @@ -47,9 +53,13 @@ def test_save(trial_component_obj, sagemaker_boto_client):
assert trial_component_obj.start_time - loaded.start_time < datetime.timedelta(seconds=1)
assert trial_component_obj.end_time - loaded.end_time < datetime.timedelta(seconds=1)

assert trial_component_obj.parameters == loaded.parameters
assert trial_component_obj.input_artifacts == loaded.input_artifacts
assert trial_component_obj.output_artifacts == loaded.output_artifacts
assert loaded.parameters == {"whizz": 100.1}
assert loaded.input_artifacts == {
"snizz1": api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2")
}
assert loaded.output_artifacts == {
"fly": api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow")
}


def test_load(trial_component_obj, sagemaker_boto_client):
Expand Down
18 changes: 16 additions & 2 deletions tests/unit/test_trial_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,24 @@ def test_search(sagemaker_boto_client):


def test_save(sagemaker_boto_client):
obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", display_name="bar")
obj = trial_component.TrialComponent(
sagemaker_boto_client,
trial_component_name="foo",
display_name="bar",
parameters_to_remove=["E"],
input_artifacts_to_remove=["F"],
output_artifacts_to_remove=["G"],
)
sagemaker_boto_client.update_trial_component.return_value = {}
obj.save()
sagemaker_boto_client.update_trial_component.assert_called_with(TrialComponentName="foo", DisplayName="bar")

sagemaker_boto_client.update_trial_component.assert_called_with(
TrialComponentName="foo",
DisplayName="bar",
ParametersToRemove=["E"],
InputArtifactsToRemove=["F"],
OutputArtifactsToRemove=["G"],
)


def test_delete(sagemaker_boto_client):
Expand Down

0 comments on commit 9af182d

Please sign in to comment.