Skip to content

Commit

Permalink
Allow Passing Tags in E/T/TC Create methods
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhu0 committed Jul 9, 2020
1 parent 9af182d commit 7a78b7f
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 7 deletions.
6 changes: 5 additions & 1 deletion src/smexperiments/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ class Experiment(_base_types.Record):
Attributes:
experiment_name (str): The name of the experiment. The name must be unique within an account.
description (str): A description of the experiment.
tags (List[dict[str, str]]): A list of tags to associate with the experiment.
"""

experiment_name = None
description = None
tags = None

_boto_create_method = "create_experiment"
_boto_load_method = "describe_experiment"
Expand Down Expand Up @@ -92,7 +94,7 @@ def load(cls, experiment_name, sagemaker_boto_client=None):
)

@classmethod
def create(cls, experiment_name=None, description=None, sagemaker_boto_client=None):
def create(cls, experiment_name=None, description=None, tags=None, sagemaker_boto_client=None):
"""
Create a new experiment in SageMaker and return an ``Experiment`` object.
Expand All @@ -101,6 +103,7 @@ def create(cls, experiment_name=None, description=None, sagemaker_boto_client=No
experiment_description: (str, optional): Description of the experiment
sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker. If not
supplied, a default boto3 client will be created and used.
tags (List[dict[str, str]]): A list of tags to associate with the experiment.
Returns:
sagemaker.experiments.experiment.Experiment: A SageMaker ``Experiment`` object
Expand All @@ -109,6 +112,7 @@ def create(cls, experiment_name=None, description=None, sagemaker_boto_client=No
cls._boto_create_method,
experiment_name=experiment_name,
description=description,
tags=tags,
sagemaker_boto_client=sagemaker_boto_client,
)

Expand Down
8 changes: 6 additions & 2 deletions src/smexperiments/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ class Trial(_base_types.Record):
Attributes:
trial_name (str): The name of the trial.
experiment_name (str): The name of the trial's experiment.
tags (List[dict[str, str]]): A list of tags to associate with the trial.
"""

trial_name = None
experiment_name = None
tags = None

_boto_create_method = "create_trial"
_boto_load_method = "describe_trial"
Expand Down Expand Up @@ -96,15 +98,16 @@ def load(cls, trial_name, sagemaker_boto_client=None):
)

@classmethod
def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, trial_components=None):
def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, trial_components=None, tags=None):
"""Create a new trial and return a ``Trial`` object.
Args:
experiment_name: (str): Name of the experiment to create this trial in.
trial_name: (str, optional): Name of the Trial. If not specified, an auto-generated name will be used.
sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker.
If not supplied, a default boto3 client will be created and used.
trial_components (list): A list of trial component names, trial components, or trial component trackers
trial_components (list): A list of trial component names, trial components, or trial component trackers.
tags (List[dict[str, str]]): A list of tags to associate with the trial.
Returns:
smexperiments.trial.Trial: A SageMaker ``Trial`` object
Expand All @@ -114,6 +117,7 @@ def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, tr
cls._boto_create_method,
trial_name=trial_name,
experiment_name=experiment_name,
tags=tags,
sagemaker_boto_client=sagemaker_boto_client,
)
if trial_components:
Expand Down
5 changes: 4 additions & 1 deletion src/smexperiments/trial_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TrialComponent(_base_types.Record):
parameters_to_remove (list): The hyperparameters to remove from the component.
input_artifacts_to_remove (list): The input artifacts to remove from the component.
output_artifacts_to_remove (list): The output artifacts to remove from the component.
tags (List[dict[str, str]]): A list of tags to associate with the trial component.
"""

trial_component_name = None
Expand All @@ -63,6 +64,7 @@ class TrialComponent(_base_types.Record):
parameters_to_remove = None
input_artifacts_to_remove = None
output_artifacts_to_remove = None
tags = None

_boto_load_method = "describe_trial_component"
_boto_create_method = "create_trial_component"
Expand Down Expand Up @@ -125,7 +127,7 @@ def load(cls, trial_component_name, sagemaker_boto_client=None):
return trial_component

@classmethod
def create(cls, trial_component_name, display_name=None, sagemaker_boto_client=None):
def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_boto_client=None):
"""Create a trial component and return a ``TrialComponent`` object representing it.
Returns:
Expand All @@ -136,6 +138,7 @@ def create(cls, trial_component_name, display_name=None, sagemaker_boto_client=N
cls._boto_create_method,
trial_component_name=trial_component_name,
display_name=display_name,
tags=tags,
sagemaker_boto_client=sagemaker_boto_client,
)

Expand Down
11 changes: 8 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from smexperiments import experiment, trial, trial_component
from tests.helpers import name, names

TAGS = [{"Key": "some-key", "Value": "some-value"}]


def pytest_addoption(parser):
parser.addoption("--boto-model-file", action="store", default=None)
Expand Down Expand Up @@ -93,7 +95,7 @@ def experiment_obj(sagemaker_boto_client):
description = "{}-{}".format("description", str(uuid.uuid4()))
boto3.set_stream_logger("", logging.INFO)
experiment_obj = experiment.Experiment.create(
experiment_name=name(), description=description, sagemaker_boto_client=sagemaker_boto_client
experiment_name=name(), description=description, sagemaker_boto_client=sagemaker_boto_client, tags=TAGS
)
yield experiment_obj
time.sleep(0.5)
Expand All @@ -103,7 +105,10 @@ def experiment_obj(sagemaker_boto_client):
@pytest.fixture
def trial_obj(sagemaker_boto_client, experiment_obj):
trial_obj = trial.Trial.create(
trial_name=name(), experiment_name=experiment_obj.experiment_name, sagemaker_boto_client=sagemaker_boto_client
trial_name=name(),
experiment_name=experiment_obj.experiment_name,
tags=TAGS,
sagemaker_boto_client=sagemaker_boto_client,
)
yield trial_obj
time.sleep(0.5)
Expand All @@ -113,7 +118,7 @@ def trial_obj(sagemaker_boto_client, experiment_obj):
@pytest.fixture
def trial_component_obj(sagemaker_boto_client):
trial_component_obj = trial_component.TrialComponent.create(
trial_component_name=name(), sagemaker_boto_client=sagemaker_boto_client
trial_component_name=name(), sagemaker_boto_client=sagemaker_boto_client, tags=TAGS,
)
yield trial_component_obj
time.sleep(0.5)
Expand Down
8 changes: 8 additions & 0 deletions tests/integ/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ def test_create_delete(experiment_obj):
assert experiment_obj.experiment_name


def test_create_tags(experiment_obj, sagemaker_boto_client):
while True:
actual_tags = sagemaker_boto_client.list_tags(ResourceArn=experiment_obj.experiment_arn)["Tags"]
if actual_tags:
break
assert actual_tags == experiment_obj.tags


def test_save(experiment_obj):
description = name()
experiment_obj.description = description
Expand Down
8 changes: 8 additions & 0 deletions tests/integ/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ def test_create_delete(trial_obj):
assert trial_obj.trial_name


def test_create_tags(trial_obj, sagemaker_boto_client):
while True:
actual_tags = sagemaker_boto_client.list_tags(ResourceArn=trial_obj.trial_arn)["Tags"]
if actual_tags:
break
assert actual_tags == trial_obj.tags


def test_list(trials, sagemaker_boto_client):
slack = datetime.timedelta(minutes=1)
now = datetime.datetime.now(datetime.timezone.utc)
Expand Down
8 changes: 8 additions & 0 deletions tests/integ/test_trial_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def test_create_delete(trial_component_obj):
assert trial_component_obj.trial_component_name


def test_create_tags(trial_component_obj, sagemaker_boto_client):
while True:
actual_tags = sagemaker_boto_client.list_tags(ResourceArn=trial_component_obj.trial_component_arn)["Tags"]
if actual_tags:
break
assert actual_tags == trial_component_obj.tags


def test_save(trial_component_obj, sagemaker_boto_client):
trial_component_obj.display_name = str(uuid.uuid4())
trial_component_obj.status = api_types.TrialComponentStatus(primary_status="InProgress", message="Message")
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,26 @@ def test_load(sagemaker_boto_client):

def test_create(sagemaker_boto_client):
sagemaker_boto_client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
tags = {"Key": "foo", "Value": "bar"}
experiment_obj = experiment.Experiment.create(
experiment_name="name-value", sagemaker_boto_client=sagemaker_boto_client
)
assert experiment_obj.experiment_name == "name-value"
sagemaker_boto_client.create_experiment.assert_called_with(ExperimentName="name-value")


def test_create_with_tags(sagemaker_boto_client):
sagemaker_boto_client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
tags = [{"Key": "foo", "Value": "bar"}]
experiment_obj = experiment.Experiment.create(
experiment_name="name-value", sagemaker_boto_client=sagemaker_boto_client, tags=tags
)
assert experiment_obj.experiment_name == "name-value"
sagemaker_boto_client.create_experiment.assert_called_with(
ExperimentName="name-value", Tags=[{"Key": "foo", "Value": "bar"}]
)


def test_list(sagemaker_boto_client, datetime_obj):
sagemaker_boto_client.list_experiments.return_value = {
"ExperimentSummaries": [
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ def test_create(sagemaker_boto_client):
)


def test_create_with_tags(sagemaker_boto_client):
sagemaker_boto_client.create_trial.return_value = {
"Arn": "arn:aws:1234",
"TrialName": "name-value",
}
tags = [{"Key": "foo", "Value": "bar"}]
trial_obj = trial.Trial.create(
trial_name="name-value",
experiment_name="experiment-name-value",
sagemaker_boto_client=sagemaker_boto_client,
tags=tags,
)
assert trial_obj.trial_name == "name-value"
sagemaker_boto_client.create_trial.assert_called_with(
TrialName="name-value", ExperimentName="experiment-name-value", Tags=[{"Key": "foo", "Value": "bar"}]
)


def test_create_no_name(sagemaker_boto_client):
sagemaker_boto_client.create_trial.return_value = {}
trial.Trial.create(experiment_name="experiment-name-value", sagemaker_boto_client=sagemaker_boto_client)
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/test_trial_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ def test_create(sagemaker_boto_client):
assert "bazz" == obj.trial_component_arn


def test_create_with_tags(sagemaker_boto_client):
sagemaker_boto_client.create_trial_component.return_value = {
"TrialComponentArn": "bazz",
}
tags = [{"Key": "foo", "Value": "bar"}]
obj = trial_component.TrialComponent.create(
trial_component_name="foo", display_name="bar", sagemaker_boto_client=sagemaker_boto_client, tags=tags
)
sagemaker_boto_client.create_trial_component.assert_called_with(
TrialComponentName="foo", DisplayName="bar", Tags=[{"Key": "foo", "Value": "bar"}]
)


def test_load(sagemaker_boto_client):
now = datetime.datetime.now(datetime.timezone.utc)

Expand Down

0 comments on commit 7a78b7f

Please sign in to comment.