From 24481783721994bc083d151e51e50aa304933522 Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 11:23:36 -0700 Subject: [PATCH 01/13] Adding action class Actions added: stop trianing job, email, sms --- smdebug/rules/action/__init__.py | 2 + smdebug/rules/action/action.py | 82 +++++++++++++++ smdebug/rules/action/message_action.py | 101 +++++++++++++++++++ smdebug/rules/action/stop_training_action.py | 58 +++++++++++ smdebug/rules/rule.py | 5 +- 5 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 smdebug/rules/action/__init__.py create mode 100644 smdebug/rules/action/action.py create mode 100644 smdebug/rules/action/message_action.py create mode 100644 smdebug/rules/action/stop_training_action.py diff --git a/smdebug/rules/action/__init__.py b/smdebug/rules/action/__init__.py new file mode 100644 index 000000000..2c0841275 --- /dev/null +++ b/smdebug/rules/action/__init__.py @@ -0,0 +1,2 @@ +# Local +from .action import Actions diff --git a/smdebug/rules/action/action.py b/smdebug/rules/action/action.py new file mode 100644 index 000000000..b5fdf3016 --- /dev/null +++ b/smdebug/rules/action/action.py @@ -0,0 +1,82 @@ +# Standard Library +import json + +# First Party +from smdebug.core.logger import get_logger + +# Local +from .message_action import MessageAction +from .stop_training_action import StopTrainingAction + +ALLOWED_ACTIONS = ["stoptraining", "sms", "email"] + + +class Actions: + def __init__(self, action_str="", rule_name=""): + self._actions = [] + self._logger = get_logger() + if action_str is None: + self._logger.info(f"No action specified. Action str is {action_str}") + return + self._register_actions(action_str) + + def _register_actions(self, action_str="", rule_name=""): + + action_str = action_str.lower() + self._logger.info(f"Action string: {action_str} and rule_name:{rule_name}") + action_json = json.loads(action_str) + actions_list = [] + if isinstance(action_json, dict): + actions_list.append(action_json) + elif isinstance(action_json, list): + actions_list = action_json + else: + self._logger.info( + f"Action string: {action_str}, expected either a list of dict or dict. Skipping action registering" + ) + return + + # action : {name:'StopTraining', 'training_job_prefix':''} + # {name:'sms or email', 'endpoint':''} + for action_dict in actions_list: + if not isinstance(action_dict, dict): + self._logger.info( + f"expected dictionary for action, got {action_dict} . Skipping this action." + ) + continue + if "name" in action_dict: + if action_dict["name"] == "stoptraining": + training_job_prefix = ( + action_dict["training_job_prefix"] + if "training_job_prefix" in action_dict + else None + ) + if training_job_prefix is None: + self._logger.info( + f"Action :{action_dict['name']} requires 'training_job_prefix' key to be specified. " + f"Action_dict is: {action_dict}" + ) + continue + + action = StopTrainingAction(rule_name, training_job_prefix) + self._actions.append(action) + elif action_dict["name"] == "sms" or action_dict["name"] == "email": + endpoint = action_dict["endpoint"] if "endpoint" in action_dict else None + if endpoint is None: + self._logger.info( + f"Action :{action_dict['name']} requires endpoint key parameter. " + ) + continue + + action = MessageAction(rule_name, action_dict["name"], endpoint) + self._actions.append(action) + + else: + self._logger.info( + f"Action :{action_dict['name']} not supported. Allowed action names are: {ALLOWED_ACTIONS}" + ) + + def invoke(self): + self._logger.invoke("Invoking actions") + for action in self._actions: + action.invoke() diff --git a/smdebug/rules/action/message_action.py b/smdebug/rules/action/message_action.py new file mode 100644 index 000000000..837c05dfc --- /dev/null +++ b/smdebug/rules/action/message_action.py @@ -0,0 +1,101 @@ +# Standard Library +import os + +# Third Party +import boto3 + +# First Party +from smdebug.core.logger import get_logger + +# action : +# {name:'sms' or 'email', 'endpoint':'phone or emailid'} + + +class MessageAction: + def __init__(self, rule_name, message_type, message_endpoint): + self._topic_name = "SMDebugRules" + if message_type == "sms" or message_type == "email": + self._protocol = message_type + else: + self._protocol = None + # TODO log unsupported message type + return + self._message_endpoint = message_endpoint + self._logger = get_logger() + self._topic_arn = self._create_sns_topic_if_not_exists() + + self._subscribe_mesgtype_endpoint() + # TODO log debug topic arn , protocol, mesg endpoint + self._logger.info( + f"Registering MessageAction with protocol:{self._protocol} endpoint:{self._message_endpoint} and topic_arn:{self._topic_arn} " + ) + + env_region_name = os.environ["AWS_REGION"] + self._sns_client = boto3.client("sns", region_name=env_region_name) + self._rule_name = rule_name + + def _create_sns_topic_if_not_exists(self): + topic = self._sns_client.create_topic(Name=self.topic_name) + # TODO log info print topic + self._logger.info( + f"topic_name: {self._topic_name} , creating topic returned response:{topic}" + ) + if topic: + return topic.arn + return None + + def _subscribe_mesgtype_endpoint(self): + + response = None + try: + if self._topic_arn and self._protocol and self._message_endpoint: + filter_policy = {} + if self._protocol == "sms": + filter_policy["phone_num"] = [self._message_endpoint] + else: + filter_policy["email"] = [self._message_endpoint] + response = self._sns_client.subscribe( + TopicArn=self._topic_arn, + Protocol=self._protocol, # sms or email + Endpoint=self._message_endpoint, # phone number or email addresss + Attributes={ + "FilterPolicy": str( + filter_policy + ) # FilterPolicy {"phone_num": [ "+16693008439" ]} + }, + ReturnSubscriptionArn=False, # True means always return ARN + ) + except Exception as e: + self._logger.info( + f"Caught exception while subscribing endpoint on topic:{self._topic_arn} exception is: \n {e}" + ) + self._logger.info(f"response for sns subscribe is {response} ") + + def _send_message(self, message): + response = None + try: + if self._protocol == "sms": + msg_attributes = { + "phone_num": {"DataType": "string", "StringValue": self._message_endpoint} + } + else: + msg_attributes = { + "email": {"DataType": "string", "StringValue": self._message_endpoint} + } + response = self._sns_client.publish( + TopicArn=self._topic_arn, + # TargetArn='string', # this or Topic or Phone + # PhoneNumber='string', + Message=message, + Subject="string", + # MessageStructure='json', + MessageAttributes=msg_attributes, + ) + except Exception as e: + self._logger.info( + f"Caught exception while getting publishing message on topic:{self._topic_arn} exception is: \n {e}" + ) + self._logger.info(f"Response of send message:{response}") + + def invoke(self, message=None): + self._send_message(message) diff --git a/smdebug/rules/action/stop_training_action.py b/smdebug/rules/action/stop_training_action.py new file mode 100644 index 000000000..4c5136693 --- /dev/null +++ b/smdebug/rules/action/stop_training_action.py @@ -0,0 +1,58 @@ +# Standard Library +import os + +# Third Party +import boto3 + +# First Party +from smdebug.core.logger import get_logger + + +class StopTrainingAction: + def __init__(self, rule_name, training_job_prefix): + self._training_job_prefix = training_job_prefix + env_region_name = os.environ["AWS_REGION"] + self._sm_client = boto3.client("sns", region_name=env_region_name) + self._logger = get_logger() + self._rule_name = rule_name + self._found_jobs = self._get_sm_tj_jobs_with_prefix(training_job_prefix) + # TODO log debug topic arn , protocol, mesg endpoint + + def _get_sm_tj_jobs_with_prefix(self): + found_jobs = [] + try: + jobs = self._sm_client.list_training_jobs() + if "TrainingJobSummaries" in jobs: + jobs = jobs["TrainingJobSummaries"] + else: + self._logger.info( + f"No TrainingJob summaries found: list_training_jobs output is : {jobs}" + ) + return + for job in jobs: + self._logger.info( + f"TrainingJob name: {job['TrainingJobName']} , status:{job['TrainingJobStatus']}" + ) + if job["TrainingJobName"] is not None and job["TrainingJobName"].startswith( + self._training_job_prefix + ): + found_jobs.append(job["TrainingJobName"]) + self._logger.info(f"found_training job {found_jobs}") + except Exception as e: + self._logger.info( + f"Caught exception while getting list_training_job exception is: \n {e}" + ) + return found_jobs + + def _stop_training_job(self): + if len(self._found_jobs) != 1: + return + self._logger.info(f"Invoking StopTrainingJob action on SM jobname:{self._found_jobs}") + try: + res = self._sm_client.stop_training_job(TrainingJobName=self._found_jobs[0]) + self._logger.info(f"Stop Training job response:{res}") + except Exception as e: + self._logger.info(f"Got exception while stopping training job{self._found_jobs[0]}:{e}") + + def invoke(self, message=None): + self._stop_training_job(self._found_jobs) diff --git a/smdebug/rules/rule.py b/smdebug/rules/rule.py index 062fb492d..74fa42ad6 100644 --- a/smdebug/rules/rule.py +++ b/smdebug/rules/rule.py @@ -5,6 +5,7 @@ from smdebug.analysis.utils import no_refresh from smdebug.core.logger import get_logger from smdebug.exceptions import RuleEvaluationConditionMet +from smdebug.rules.action import Actions # Local from .req_tensors import RequiredTensors @@ -12,7 +13,7 @@ # This is Rule interface class Rule(ABC): - def __init__(self, base_trial, other_trials=None): + def __init__(self, base_trial, other_trials=None, action_str=""): self.base_trial = base_trial self.other_trials = other_trials @@ -24,6 +25,7 @@ def __init__(self, base_trial, other_trials=None): self.logger = get_logger() self.rule_name = self.__class__.__name__ + self._actions = Actions(self.rule_name, action_str) def set_required_tensors(self, step): pass @@ -55,4 +57,5 @@ def invoke(self, step): val = self.invoke_at_step(step) if val: + self._actions.invoke() raise RuleEvaluationConditionMet(self.rule_name, step) From 679b00dfb596988792e3ec7be4c727bbfca910e9 Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 11:36:35 -0700 Subject: [PATCH 02/13] handling empty json string --- smdebug/rules/action/action.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/smdebug/rules/action/action.py b/smdebug/rules/action/action.py index b5fdf3016..2f80c88d1 100644 --- a/smdebug/rules/action/action.py +++ b/smdebug/rules/action/action.py @@ -15,10 +15,11 @@ class Actions: def __init__(self, action_str="", rule_name=""): self._actions = [] self._logger = get_logger() - if action_str is None: + action_str = action_str.strip() if action_str is not None else "" + if action_str == "": self._logger.info(f"No action specified. Action str is {action_str}") return - self._register_actions(action_str) + self._register_actions(action_str, rule_name) def _register_actions(self, action_str="", rule_name=""): From ecd939bf696190f2177a65adf8e4b059b7f352f2 Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 11:46:43 -0700 Subject: [PATCH 03/13] Fixing a bug with parameters passing --- smdebug/rules/rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smdebug/rules/rule.py b/smdebug/rules/rule.py index 74fa42ad6..46ebcdffe 100644 --- a/smdebug/rules/rule.py +++ b/smdebug/rules/rule.py @@ -25,7 +25,7 @@ def __init__(self, base_trial, other_trials=None, action_str=""): self.logger = get_logger() self.rule_name = self.__class__.__name__ - self._actions = Actions(self.rule_name, action_str) + self._actions = Actions(action_str=action_str, rule_name=self.rule_name) def set_required_tensors(self, step): pass From fe8f77d6dde2ffd3037096b44a92bd55a394b24b Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 11:54:36 -0700 Subject: [PATCH 04/13] bug fix --- smdebug/rules/action/action.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smdebug/rules/action/action.py b/smdebug/rules/action/action.py index 2f80c88d1..82f2bfaac 100644 --- a/smdebug/rules/action/action.py +++ b/smdebug/rules/action/action.py @@ -78,6 +78,6 @@ def _register_actions(self, action_str="", rule_name=""): ) def invoke(self): - self._logger.invoke("Invoking actions") + self._logger.info("Invoking actions") for action in self._actions: action.invoke() From 25a69b3f253c6dffe007c530c71b73c9946983ed Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 12:19:15 -0700 Subject: [PATCH 05/13] some logs and client fix --- smdebug/rules/action/message_action.py | 5 ++--- smdebug/rules/action/stop_training_action.py | 7 +++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/smdebug/rules/action/message_action.py b/smdebug/rules/action/message_action.py index 837c05dfc..ad0260fe0 100644 --- a/smdebug/rules/action/message_action.py +++ b/smdebug/rules/action/message_action.py @@ -25,12 +25,11 @@ def __init__(self, rule_name, message_type, message_endpoint): self._topic_arn = self._create_sns_topic_if_not_exists() self._subscribe_mesgtype_endpoint() - # TODO log debug topic arn , protocol, mesg endpoint + env_region_name = os.environ["AWS_REGION"] or "us-east-1" self._logger.info( - f"Registering MessageAction with protocol:{self._protocol} endpoint:{self._message_endpoint} and topic_arn:{self._topic_arn} " + f"Registering MessageAction with protocol:{self._protocol} endpoint:{self._message_endpoint} and topic_arn:{self._topic_arn} region:{env_region_name}" ) - env_region_name = os.environ["AWS_REGION"] self._sns_client = boto3.client("sns", region_name=env_region_name) self._rule_name = rule_name diff --git a/smdebug/rules/action/stop_training_action.py b/smdebug/rules/action/stop_training_action.py index 4c5136693..9c15831c8 100644 --- a/smdebug/rules/action/stop_training_action.py +++ b/smdebug/rules/action/stop_training_action.py @@ -11,9 +11,12 @@ class StopTrainingAction: def __init__(self, rule_name, training_job_prefix): self._training_job_prefix = training_job_prefix - env_region_name = os.environ["AWS_REGION"] - self._sm_client = boto3.client("sns", region_name=env_region_name) + env_region_name = os.environ["AWS_REGION"] or "us-east-1" self._logger = get_logger() + self._logger.info( + f"StopTrainingAction created with training_job_prefix:{training_job_prefix} and region:{env_region_name}" + ) + self._sm_client = boto3.client("sagemaker", region_name=env_region_name) self._rule_name = rule_name self._found_jobs = self._get_sm_tj_jobs_with_prefix(training_job_prefix) # TODO log debug topic arn , protocol, mesg endpoint From d9067aead2fbe516ccf00177091c72b30326af63 Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 12:21:16 -0700 Subject: [PATCH 06/13] some more fixes --- smdebug/rules/action/stop_training_action.py | 1 - 1 file changed, 1 deletion(-) diff --git a/smdebug/rules/action/stop_training_action.py b/smdebug/rules/action/stop_training_action.py index 9c15831c8..f0b375d88 100644 --- a/smdebug/rules/action/stop_training_action.py +++ b/smdebug/rules/action/stop_training_action.py @@ -19,7 +19,6 @@ def __init__(self, rule_name, training_job_prefix): self._sm_client = boto3.client("sagemaker", region_name=env_region_name) self._rule_name = rule_name self._found_jobs = self._get_sm_tj_jobs_with_prefix(training_job_prefix) - # TODO log debug topic arn , protocol, mesg endpoint def _get_sm_tj_jobs_with_prefix(self): found_jobs = [] From dbc15a3b8ee2587b84186215b53e0b841bd822db Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 16:49:17 -0700 Subject: [PATCH 07/13] Adding tests and some fixes --- smdebug/rules/action/message_action.py | 84 ++++++++++++++------ smdebug/rules/action/stop_training_action.py | 6 +- tests/rules/__init__.py | 0 tests/rules/action/__init__.py | 0 4 files changed, 63 insertions(+), 27 deletions(-) create mode 100644 tests/rules/__init__.py create mode 100644 tests/rules/action/__init__.py diff --git a/smdebug/rules/action/message_action.py b/smdebug/rules/action/message_action.py index ad0260fe0..4b5d88777 100644 --- a/smdebug/rules/action/message_action.py +++ b/smdebug/rules/action/message_action.py @@ -1,4 +1,5 @@ # Standard Library +import json import os # Third Party @@ -14,87 +15,122 @@ class MessageAction: def __init__(self, rule_name, message_type, message_endpoint): self._topic_name = "SMDebugRules" + self._logger = get_logger() + if message_type == "sms" or message_type == "email": self._protocol = message_type else: self._protocol = None - # TODO log unsupported message type + self._logger.info( + f"Unsupported message type:{message_type} in MessageAction. Returning" + ) return self._message_endpoint = message_endpoint - self._logger = get_logger() + + # Below 2 is to help in tests + self._last_send_mesg_response = None + self._last_subscription_response = None + + env_region_name = os.getenv("AWS_REGION", "us-east-1") + + self._sns_client = boto3.client("sns", region_name=env_region_name) + self._topic_arn = self._create_sns_topic_if_not_exists() self._subscribe_mesgtype_endpoint() - env_region_name = os.environ["AWS_REGION"] or "us-east-1" self._logger.info( - f"Registering MessageAction with protocol:{self._protocol} endpoint:{self._message_endpoint} and topic_arn:{self._topic_arn} region:{env_region_name}" + f"Registering messageAction with protocol:{self._protocol} endpoint:{self._message_endpoint} and topic_arn:{self._topic_arn} region:{env_region_name}" ) - - self._sns_client = boto3.client("sns", region_name=env_region_name) self._rule_name = rule_name def _create_sns_topic_if_not_exists(self): - topic = self._sns_client.create_topic(Name=self.topic_name) - # TODO log info print topic + topic = self._sns_client.create_topic(Name=self._topic_name) self._logger.info( f"topic_name: {self._topic_name} , creating topic returned response:{topic}" ) if topic: - return topic.arn + return topic["TopicArn"] return None + def _check_subscriptions(self, topic_arn, protocol, endpoint): + next_token = "random" + subs = self._sns_client.list_subscriptions() + sub_array = subs["Subscriptions"] + while next_token is not None: + for sub_dict in sub_array: + proto = sub_dict["Protocol"] + ep = sub_dict["Endpoint"] + topic = sub_dict["TopicArn"] + if proto == protocol and topic == topic_arn and ep == endpoint: + return True + if "NextToken" in subs: + next_token = subs["NextToken"] + subs = self._sns_client.list_subscriptions(NextToken=next_token) + sub_array = subs["Subscriptions"] + continue + else: + next_token = None + return False + def _subscribe_mesgtype_endpoint(self): response = None try: + if self._topic_arn and self._protocol and self._message_endpoint: filter_policy = {} if self._protocol == "sms": filter_policy["phone_num"] = [self._message_endpoint] else: filter_policy["email"] = [self._message_endpoint] - response = self._sns_client.subscribe( - TopicArn=self._topic_arn, - Protocol=self._protocol, # sms or email - Endpoint=self._message_endpoint, # phone number or email addresss - Attributes={ - "FilterPolicy": str( - filter_policy - ) # FilterPolicy {"phone_num": [ "+16693008439" ]} - }, - ReturnSubscriptionArn=False, # True means always return ARN - ) + if not self._check_subscriptions( + self._topic_arn, self._protocol, self._message_endpoint + ): + + response = self._sns_client.subscribe( + TopicArn=self._topic_arn, + Protocol=self._protocol, # sms or email + Endpoint=self._message_endpoint, # phone number or email addresss + Attributes={"FilterPolicy": json.dumps(filter_policy)}, + ReturnSubscriptionArn=False, # True means always return ARN + ) + else: + response = f"Subscription exists for topic:{self._topic_arn}, protocol:{self._protocol}, endpoint:{self._message_endpoint}" except Exception as e: self._logger.info( f"Caught exception while subscribing endpoint on topic:{self._topic_arn} exception is: \n {e}" ) self._logger.info(f"response for sns subscribe is {response} ") + self._last_subscription_response = response def _send_message(self, message): response = None + message = f"SMDebugRule:{self._rule_name} fired. {message}" try: if self._protocol == "sms": msg_attributes = { - "phone_num": {"DataType": "string", "StringValue": self._message_endpoint} + "phone_num": {"DataType": "String", "StringValue": self._message_endpoint} } else: msg_attributes = { - "email": {"DataType": "string", "StringValue": self._message_endpoint} + "email": {"DataType": "String", "StringValue": self._message_endpoint} } response = self._sns_client.publish( TopicArn=self._topic_arn, # TargetArn='string', # this or Topic or Phone # PhoneNumber='string', Message=message, - Subject="string", + Subject=f"SMDebugRule:{self._rule_name} fired", # MessageStructure='json', MessageAttributes=msg_attributes, ) except Exception as e: self._logger.info( - f"Caught exception while getting publishing message on topic:{self._topic_arn} exception is: \n {e}" + f"Caught exception while publishing message on topic:{self._topic_arn} exception is: \n {e}" ) self._logger.info(f"Response of send message:{response}") + self._last_send_mesg_response = response + return response def invoke(self, message=None): self._send_message(message) diff --git a/smdebug/rules/action/stop_training_action.py b/smdebug/rules/action/stop_training_action.py index f0b375d88..53f1b5fdc 100644 --- a/smdebug/rules/action/stop_training_action.py +++ b/smdebug/rules/action/stop_training_action.py @@ -11,14 +11,14 @@ class StopTrainingAction: def __init__(self, rule_name, training_job_prefix): self._training_job_prefix = training_job_prefix - env_region_name = os.environ["AWS_REGION"] or "us-east-1" + env_region_name = os.getenv("AWS_REGION", "us-east-1") self._logger = get_logger() self._logger.info( f"StopTrainingAction created with training_job_prefix:{training_job_prefix} and region:{env_region_name}" ) self._sm_client = boto3.client("sagemaker", region_name=env_region_name) self._rule_name = rule_name - self._found_jobs = self._get_sm_tj_jobs_with_prefix(training_job_prefix) + self._found_jobs = self._get_sm_tj_jobs_with_prefix() def _get_sm_tj_jobs_with_prefix(self): found_jobs = [] @@ -57,4 +57,4 @@ def _stop_training_job(self): self._logger.info(f"Got exception while stopping training job{self._found_jobs[0]}:{e}") def invoke(self, message=None): - self._stop_training_job(self._found_jobs) + self._stop_training_job() diff --git a/tests/rules/__init__.py b/tests/rules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/rules/action/__init__.py b/tests/rules/action/__init__.py new file mode 100644 index 000000000..e69de29bb From 08e3caeaaf071b1a0f115fe572c9931d4b0deb49 Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 16:52:06 -0700 Subject: [PATCH 08/13] Extra logging mesg --- smdebug/rules/action/message_action.py | 1 + 1 file changed, 1 insertion(+) diff --git a/smdebug/rules/action/message_action.py b/smdebug/rules/action/message_action.py index 4b5d88777..3d758c8c8 100644 --- a/smdebug/rules/action/message_action.py +++ b/smdebug/rules/action/message_action.py @@ -62,6 +62,7 @@ def _check_subscriptions(self, topic_arn, protocol, endpoint): ep = sub_dict["Endpoint"] topic = sub_dict["TopicArn"] if proto == protocol and topic == topic_arn and ep == endpoint: + self._logger.info(f"Existing Subscription found: {sub_dict}") return True if "NextToken" in subs: next_token = subs["NextToken"] From d7dcd0e41dd66e73ad9b37f6f982defa26e4e071 Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 16:52:49 -0700 Subject: [PATCH 09/13] Adding test file --- tests/rules/action/test_message_action.py | 74 +++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/rules/action/test_message_action.py diff --git a/tests/rules/action/test_message_action.py b/tests/rules/action/test_message_action.py new file mode 100644 index 000000000..c970e3aeb --- /dev/null +++ b/tests/rules/action/test_message_action.py @@ -0,0 +1,74 @@ +# First Party +from smdebug.rules.action.action import Actions +from smdebug.rules.action.message_action import MessageAction +from smdebug.rules.action.stop_training_action import StopTrainingAction + + +def test_action_stop_training_job(): + action_str = '{"name": "stoptraining" , "training_job_prefix":"training_prefix"}' + action = Actions(action_str=action_str) + action.invoke() + + +def test_action_stop_training_job_invalid_params(): + action_str = '{"name": "stoptraining" , "invalid_job_prefix":"training_prefix"}' + action = Actions(action_str=action_str) + action.invoke() + + +def test_action_sms(): + action_str = '{"name": "sms" , "endpoint":"+11234567890"}' + action = Actions(action_str=action_str, rule_name="test_rule") + action.invoke() + sms_action = action._actions[0] + assert sms_action._last_subscription_response is not None + assert sms_action._last_send_mesg_response is not None + + +def test_action_sms_invalid_params(): + action_str = '{"name": "sms" , "invalid":"+11234567890"}' + action = Actions(action_str=action_str, rule_name="test_rule") + action.invoke() + + +def test_action_email(): + action_str = '{"name": "email" , "endpoint":"abc@abc.com"}' + action = Actions(action_str=action_str, rule_name="test_rule") + action.invoke() + email_action = action._actions[0] + assert email_action._last_subscription_response is not None + assert email_action._last_send_mesg_response is not None + + +def test_action_email_invalid_params(): + action_str = '{"name": "email" , "invalid":"abc@abc.com"}' + action = Actions(action_str=action_str, rule_name="test_rule") + action.invoke() + + +def test_invalid_message_action(): + action_str = '{"name": "invalid" , "invalid":"abc@abc.com"}' + action = Actions(action_str=action_str, rule_name="test_rule") + action.invoke() + + +def test_action_multiple(): + action_str = ( + '[{"name": "stoptraining" , "training_job_prefix":"training_prefix"}, {"name": "email" , ' + '"endpoint":"abc@abc.com"}] ' + ) + action = Actions(action_str=action_str, rule_name="test_rule") + actions = action._actions + assert len(actions) == 2 + stop_action = actions[0] + email_action = actions[1] + assert isinstance(stop_action, StopTrainingAction) == True + assert isinstance(email_action, MessageAction) == True + + assert stop_action._training_job_prefix == "training_prefix" + assert email_action._protocol == "email" + assert email_action._topic_name == "SMDebugRules" + assert email_action._message_endpoint == "abc@abc.com" + assert email_action._rule_name == "test_rule" + + assert email_action._last_subscription_response is not None From 8e8f238c5b5521a11a65beaa8224706e0217dcf5 Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 17:32:28 -0700 Subject: [PATCH 10/13] Adding rules test in CI --- config/tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/config/tests.sh b/config/tests.sh index 664e9cfcf..d4e1048f4 100644 --- a/config/tests.sh +++ b/config/tests.sh @@ -52,6 +52,7 @@ python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} -v -W=ignore -- run_for_framework core run_for_framework profiler +run_for_framework rules if [ "$run_pytest_xgboost" = "enable" ] ; then run_for_framework xgboost From e877df3240f84fa53ec58dcb61bc15c2e1f4e859 Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 17:43:24 -0700 Subject: [PATCH 11/13] some exception handling --- smdebug/rules/action/message_action.py | 58 +++++++++++++++----------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/smdebug/rules/action/message_action.py b/smdebug/rules/action/message_action.py index 3d758c8c8..e08984838 100644 --- a/smdebug/rules/action/message_action.py +++ b/smdebug/rules/action/message_action.py @@ -44,33 +44,43 @@ def __init__(self, rule_name, message_type, message_endpoint): self._rule_name = rule_name def _create_sns_topic_if_not_exists(self): - topic = self._sns_client.create_topic(Name=self._topic_name) - self._logger.info( - f"topic_name: {self._topic_name} , creating topic returned response:{topic}" - ) - if topic: - return topic["TopicArn"] + try: + topic = self._sns_client.create_topic(Name=self._topic_name) + self._logger.info( + f"topic_name: {self._topic_name} , creating topic returned response:{topic}" + ) + if topic: + return topic["TopicArn"] + except Exception as e: + self._logger.info( + f"Caught exception while creating topic:{self._topic_name} exception is: \n {e}" + ) return None def _check_subscriptions(self, topic_arn, protocol, endpoint): - next_token = "random" - subs = self._sns_client.list_subscriptions() - sub_array = subs["Subscriptions"] - while next_token is not None: - for sub_dict in sub_array: - proto = sub_dict["Protocol"] - ep = sub_dict["Endpoint"] - topic = sub_dict["TopicArn"] - if proto == protocol and topic == topic_arn and ep == endpoint: - self._logger.info(f"Existing Subscription found: {sub_dict}") - return True - if "NextToken" in subs: - next_token = subs["NextToken"] - subs = self._sns_client.list_subscriptions(NextToken=next_token) - sub_array = subs["Subscriptions"] - continue - else: - next_token = None + try: + next_token = "random" + subs = self._sns_client.list_subscriptions() + sub_array = subs["Subscriptions"] + while next_token is not None: + for sub_dict in sub_array: + proto = sub_dict["Protocol"] + ep = sub_dict["Endpoint"] + topic = sub_dict["TopicArn"] + if proto == protocol and topic == topic_arn and ep == endpoint: + self._logger.info(f"Existing Subscription found: {sub_dict}") + return True + if "NextToken" in subs: + next_token = subs["NextToken"] + subs = self._sns_client.list_subscriptions(NextToken=next_token) + sub_array = subs["Subscriptions"] + continue + else: + next_token = None + except Exception as e: + self._logger.info( + f"Caught exception while list subscription topic:{self._topic_name} exception is: \n {e}" + ) return False def _subscribe_mesgtype_endpoint(self): From 1dc3e0deca5ba9433237e321bedb0de527adaacc Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Tue, 30 Jun 2020 18:08:50 -0700 Subject: [PATCH 12/13] Adding test/rules to mxnet build which has reqd permission for sns --- config/tests.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/config/tests.sh b/config/tests.sh index d4e1048f4..d08aec9d6 100644 --- a/config/tests.sh +++ b/config/tests.sh @@ -20,6 +20,8 @@ run_for_framework() { python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} --durations=50 --html=$REPORT_DIR/report_$1.html -v -s --self-contained-html --ignore=tests/core/test_paths.py --ignore=tests/core/test_index_utils.py --ignore=tests/core/test_collections.py tests/$1 if [ "$1" = "mxnet" ] ; then python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/zero_code_change/test_mxnet_gluon_integration.py + # we run test/rules once, mxnet build has configured permission for sns to run this test + python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/rules elif [ "$1" = "pytorch" ] ; then python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/zero_code_change/test_pytorch_integration.py python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/zero_code_change/test_pytorch_multiprocessing.py @@ -52,7 +54,6 @@ python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} -v -W=ignore -- run_for_framework core run_for_framework profiler -run_for_framework rules if [ "$run_pytest_xgboost" = "enable" ] ; then run_for_framework xgboost From c0b8f74bcfe47277f5bc8ebc2bfc942a458cf712 Mon Sep 17 00:00:00 2001 From: Vikas Kumar Date: Mon, 6 Jul 2020 13:50:28 -0700 Subject: [PATCH 13/13] addressed review comments --- smdebug/rules/action/action.py | 20 ++++++++++---------- smdebug/rules/action/message_action.py | 2 -- smdebug/rules/rule.py | 2 +- tests/rules/action/test_message_action.py | 16 ++++++++-------- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/smdebug/rules/action/action.py b/smdebug/rules/action/action.py index 82f2bfaac..d7a4cd52a 100644 --- a/smdebug/rules/action/action.py +++ b/smdebug/rules/action/action.py @@ -12,20 +12,20 @@ class Actions: - def __init__(self, action_str="", rule_name=""): + def __init__(self, actions_str="", rule_name=""): self._actions = [] self._logger = get_logger() - action_str = action_str.strip() if action_str is not None else "" - if action_str == "": - self._logger.info(f"No action specified. Action str is {action_str}") + actions_str = actions_str.strip() if actions_str is not None else "" + if actions_str == "": + self._logger.info(f"No action specified. Action str is {actions_str}") return - self._register_actions(action_str, rule_name) + self._register_actions(actions_str, rule_name) - def _register_actions(self, action_str="", rule_name=""): + def _register_actions(self, actions_str="", rule_name=""): - action_str = action_str.lower() - self._logger.info(f"Action string: {action_str} and rule_name:{rule_name}") - action_json = json.loads(action_str) + actions_str = actions_str.lower() + self._logger.info(f"Action string: {actions_str} and rule_name:{rule_name}") + action_json = json.loads(actions_str) actions_list = [] if isinstance(action_json, dict): actions_list.append(action_json) @@ -33,7 +33,7 @@ def _register_actions(self, action_str="", rule_name=""): actions_list = action_json else: self._logger.info( - f"Action string: {action_str}, expected either a list of dict or dict. Skipping action registering" + f"Action string: {actions_str}, expected either a list of dict or dict. Skipping action registering" ) return diff --git a/smdebug/rules/action/message_action.py b/smdebug/rules/action/message_action.py index e08984838..cdf9c39e2 100644 --- a/smdebug/rules/action/message_action.py +++ b/smdebug/rules/action/message_action.py @@ -128,8 +128,6 @@ def _send_message(self, message): } response = self._sns_client.publish( TopicArn=self._topic_arn, - # TargetArn='string', # this or Topic or Phone - # PhoneNumber='string', Message=message, Subject=f"SMDebugRule:{self._rule_name} fired", # MessageStructure='json', diff --git a/smdebug/rules/rule.py b/smdebug/rules/rule.py index 46ebcdffe..1cebb9972 100644 --- a/smdebug/rules/rule.py +++ b/smdebug/rules/rule.py @@ -25,7 +25,7 @@ def __init__(self, base_trial, other_trials=None, action_str=""): self.logger = get_logger() self.rule_name = self.__class__.__name__ - self._actions = Actions(action_str=action_str, rule_name=self.rule_name) + self._actions = Actions(actions_str=action_str, rule_name=self.rule_name) def set_required_tensors(self, step): pass diff --git a/tests/rules/action/test_message_action.py b/tests/rules/action/test_message_action.py index c970e3aeb..b1b0d44a1 100644 --- a/tests/rules/action/test_message_action.py +++ b/tests/rules/action/test_message_action.py @@ -6,19 +6,19 @@ def test_action_stop_training_job(): action_str = '{"name": "stoptraining" , "training_job_prefix":"training_prefix"}' - action = Actions(action_str=action_str) + action = Actions(actions_str=action_str) action.invoke() def test_action_stop_training_job_invalid_params(): action_str = '{"name": "stoptraining" , "invalid_job_prefix":"training_prefix"}' - action = Actions(action_str=action_str) + action = Actions(actions_str=action_str) action.invoke() def test_action_sms(): action_str = '{"name": "sms" , "endpoint":"+11234567890"}' - action = Actions(action_str=action_str, rule_name="test_rule") + action = Actions(actions_str=action_str, rule_name="test_rule") action.invoke() sms_action = action._actions[0] assert sms_action._last_subscription_response is not None @@ -27,13 +27,13 @@ def test_action_sms(): def test_action_sms_invalid_params(): action_str = '{"name": "sms" , "invalid":"+11234567890"}' - action = Actions(action_str=action_str, rule_name="test_rule") + action = Actions(actions_str=action_str, rule_name="test_rule") action.invoke() def test_action_email(): action_str = '{"name": "email" , "endpoint":"abc@abc.com"}' - action = Actions(action_str=action_str, rule_name="test_rule") + action = Actions(actions_str=action_str, rule_name="test_rule") action.invoke() email_action = action._actions[0] assert email_action._last_subscription_response is not None @@ -42,13 +42,13 @@ def test_action_email(): def test_action_email_invalid_params(): action_str = '{"name": "email" , "invalid":"abc@abc.com"}' - action = Actions(action_str=action_str, rule_name="test_rule") + action = Actions(actions_str=action_str, rule_name="test_rule") action.invoke() def test_invalid_message_action(): action_str = '{"name": "invalid" , "invalid":"abc@abc.com"}' - action = Actions(action_str=action_str, rule_name="test_rule") + action = Actions(actions_str=action_str, rule_name="test_rule") action.invoke() @@ -57,7 +57,7 @@ def test_action_multiple(): '[{"name": "stoptraining" , "training_job_prefix":"training_prefix"}, {"name": "email" , ' '"endpoint":"abc@abc.com"}] ' ) - action = Actions(action_str=action_str, rule_name="test_rule") + action = Actions(actions_str=action_str, rule_name="test_rule") actions = action._actions assert len(actions) == 2 stop_action = actions[0]