diff --git a/config/tests.sh b/config/tests.sh index 664e9cfcf..d08aec9d6 100644 --- a/config/tests.sh +++ b/config/tests.sh @@ -20,6 +20,8 @@ run_for_framework() { python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} --durations=50 --html=$REPORT_DIR/report_$1.html -v -s --self-contained-html --ignore=tests/core/test_paths.py --ignore=tests/core/test_index_utils.py --ignore=tests/core/test_collections.py tests/$1 if [ "$1" = "mxnet" ] ; then python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/zero_code_change/test_mxnet_gluon_integration.py + # we run test/rules once, mxnet build has configured permission for sns to run this test + python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/rules elif [ "$1" = "pytorch" ] ; then python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/zero_code_change/test_pytorch_integration.py python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/zero_code_change/test_pytorch_multiprocessing.py diff --git a/smdebug/rules/action/__init__.py b/smdebug/rules/action/__init__.py new file mode 100644 index 000000000..2c0841275 --- /dev/null +++ b/smdebug/rules/action/__init__.py @@ -0,0 +1,2 @@ +# Local +from .action import Actions diff --git a/smdebug/rules/action/action.py b/smdebug/rules/action/action.py new file mode 100644 index 000000000..d7a4cd52a --- /dev/null +++ b/smdebug/rules/action/action.py @@ -0,0 +1,83 @@ +# Standard Library +import json + +# First Party +from smdebug.core.logger import get_logger + +# Local +from .message_action import MessageAction +from .stop_training_action import StopTrainingAction + +ALLOWED_ACTIONS = ["stoptraining", "sms", "email"] + + +class Actions: + def __init__(self, actions_str="", rule_name=""): + self._actions = [] + self._logger = get_logger() + actions_str = actions_str.strip() if actions_str is not None else "" + if actions_str == "": + self._logger.info(f"No action specified. Action str is {actions_str}") + return + self._register_actions(actions_str, rule_name) + + def _register_actions(self, actions_str="", rule_name=""): + + actions_str = actions_str.lower() + self._logger.info(f"Action string: {actions_str} and rule_name:{rule_name}") + action_json = json.loads(actions_str) + actions_list = [] + if isinstance(action_json, dict): + actions_list.append(action_json) + elif isinstance(action_json, list): + actions_list = action_json + else: + self._logger.info( + f"Action string: {actions_str}, expected either a list of dict or dict. Skipping action registering" + ) + return + + # action : {name:'StopTraining', 'training_job_prefix':''} + # {name:'sms or email', 'endpoint':''} + for action_dict in actions_list: + if not isinstance(action_dict, dict): + self._logger.info( + f"expected dictionary for action, got {action_dict} . Skipping this action." + ) + continue + if "name" in action_dict: + if action_dict["name"] == "stoptraining": + training_job_prefix = ( + action_dict["training_job_prefix"] + if "training_job_prefix" in action_dict + else None + ) + if training_job_prefix is None: + self._logger.info( + f"Action :{action_dict['name']} requires 'training_job_prefix' key to be specified. " + f"Action_dict is: {action_dict}" + ) + continue + + action = StopTrainingAction(rule_name, training_job_prefix) + self._actions.append(action) + elif action_dict["name"] == "sms" or action_dict["name"] == "email": + endpoint = action_dict["endpoint"] if "endpoint" in action_dict else None + if endpoint is None: + self._logger.info( + f"Action :{action_dict['name']} requires endpoint key parameter. " + ) + continue + + action = MessageAction(rule_name, action_dict["name"], endpoint) + self._actions.append(action) + + else: + self._logger.info( + f"Action :{action_dict['name']} not supported. Allowed action names are: {ALLOWED_ACTIONS}" + ) + + def invoke(self): + self._logger.info("Invoking actions") + for action in self._actions: + action.invoke() diff --git a/smdebug/rules/action/message_action.py b/smdebug/rules/action/message_action.py new file mode 100644 index 000000000..cdf9c39e2 --- /dev/null +++ b/smdebug/rules/action/message_action.py @@ -0,0 +1,145 @@ +# Standard Library +import json +import os + +# Third Party +import boto3 + +# First Party +from smdebug.core.logger import get_logger + +# action : +# {name:'sms' or 'email', 'endpoint':'phone or emailid'} + + +class MessageAction: + def __init__(self, rule_name, message_type, message_endpoint): + self._topic_name = "SMDebugRules" + self._logger = get_logger() + + if message_type == "sms" or message_type == "email": + self._protocol = message_type + else: + self._protocol = None + self._logger.info( + f"Unsupported message type:{message_type} in MessageAction. Returning" + ) + return + self._message_endpoint = message_endpoint + + # Below 2 is to help in tests + self._last_send_mesg_response = None + self._last_subscription_response = None + + env_region_name = os.getenv("AWS_REGION", "us-east-1") + + self._sns_client = boto3.client("sns", region_name=env_region_name) + + self._topic_arn = self._create_sns_topic_if_not_exists() + + self._subscribe_mesgtype_endpoint() + self._logger.info( + f"Registering messageAction with protocol:{self._protocol} endpoint:{self._message_endpoint} and topic_arn:{self._topic_arn} region:{env_region_name}" + ) + self._rule_name = rule_name + + def _create_sns_topic_if_not_exists(self): + try: + topic = self._sns_client.create_topic(Name=self._topic_name) + self._logger.info( + f"topic_name: {self._topic_name} , creating topic returned response:{topic}" + ) + if topic: + return topic["TopicArn"] + except Exception as e: + self._logger.info( + f"Caught exception while creating topic:{self._topic_name} exception is: \n {e}" + ) + return None + + def _check_subscriptions(self, topic_arn, protocol, endpoint): + try: + next_token = "random" + subs = self._sns_client.list_subscriptions() + sub_array = subs["Subscriptions"] + while next_token is not None: + for sub_dict in sub_array: + proto = sub_dict["Protocol"] + ep = sub_dict["Endpoint"] + topic = sub_dict["TopicArn"] + if proto == protocol and topic == topic_arn and ep == endpoint: + self._logger.info(f"Existing Subscription found: {sub_dict}") + return True + if "NextToken" in subs: + next_token = subs["NextToken"] + subs = self._sns_client.list_subscriptions(NextToken=next_token) + sub_array = subs["Subscriptions"] + continue + else: + next_token = None + except Exception as e: + self._logger.info( + f"Caught exception while list subscription topic:{self._topic_name} exception is: \n {e}" + ) + return False + + def _subscribe_mesgtype_endpoint(self): + + response = None + try: + + if self._topic_arn and self._protocol and self._message_endpoint: + filter_policy = {} + if self._protocol == "sms": + filter_policy["phone_num"] = [self._message_endpoint] + else: + filter_policy["email"] = [self._message_endpoint] + if not self._check_subscriptions( + self._topic_arn, self._protocol, self._message_endpoint + ): + + response = self._sns_client.subscribe( + TopicArn=self._topic_arn, + Protocol=self._protocol, # sms or email + Endpoint=self._message_endpoint, # phone number or email addresss + Attributes={"FilterPolicy": json.dumps(filter_policy)}, + ReturnSubscriptionArn=False, # True means always return ARN + ) + else: + response = f"Subscription exists for topic:{self._topic_arn}, protocol:{self._protocol}, endpoint:{self._message_endpoint}" + except Exception as e: + self._logger.info( + f"Caught exception while subscribing endpoint on topic:{self._topic_arn} exception is: \n {e}" + ) + self._logger.info(f"response for sns subscribe is {response} ") + self._last_subscription_response = response + + def _send_message(self, message): + response = None + message = f"SMDebugRule:{self._rule_name} fired. {message}" + try: + if self._protocol == "sms": + msg_attributes = { + "phone_num": {"DataType": "String", "StringValue": self._message_endpoint} + } + else: + msg_attributes = { + "email": {"DataType": "String", "StringValue": self._message_endpoint} + } + response = self._sns_client.publish( + TopicArn=self._topic_arn, + Message=message, + Subject=f"SMDebugRule:{self._rule_name} fired", + # MessageStructure='json', + MessageAttributes=msg_attributes, + ) + except Exception as e: + self._logger.info( + f"Caught exception while publishing message on topic:{self._topic_arn} exception is: \n {e}" + ) + self._logger.info(f"Response of send message:{response}") + self._last_send_mesg_response = response + return response + + def invoke(self, message=None): + self._send_message(message) diff --git a/smdebug/rules/action/stop_training_action.py b/smdebug/rules/action/stop_training_action.py new file mode 100644 index 000000000..53f1b5fdc --- /dev/null +++ b/smdebug/rules/action/stop_training_action.py @@ -0,0 +1,60 @@ +# Standard Library +import os + +# Third Party +import boto3 + +# First Party +from smdebug.core.logger import get_logger + + +class StopTrainingAction: + def __init__(self, rule_name, training_job_prefix): + self._training_job_prefix = training_job_prefix + env_region_name = os.getenv("AWS_REGION", "us-east-1") + self._logger = get_logger() + self._logger.info( + f"StopTrainingAction created with training_job_prefix:{training_job_prefix} and region:{env_region_name}" + ) + self._sm_client = boto3.client("sagemaker", region_name=env_region_name) + self._rule_name = rule_name + self._found_jobs = self._get_sm_tj_jobs_with_prefix() + + def _get_sm_tj_jobs_with_prefix(self): + found_jobs = [] + try: + jobs = self._sm_client.list_training_jobs() + if "TrainingJobSummaries" in jobs: + jobs = jobs["TrainingJobSummaries"] + else: + self._logger.info( + f"No TrainingJob summaries found: list_training_jobs output is : {jobs}" + ) + return + for job in jobs: + self._logger.info( + f"TrainingJob name: {job['TrainingJobName']} , status:{job['TrainingJobStatus']}" + ) + if job["TrainingJobName"] is not None and job["TrainingJobName"].startswith( + self._training_job_prefix + ): + found_jobs.append(job["TrainingJobName"]) + self._logger.info(f"found_training job {found_jobs}") + except Exception as e: + self._logger.info( + f"Caught exception while getting list_training_job exception is: \n {e}" + ) + return found_jobs + + def _stop_training_job(self): + if len(self._found_jobs) != 1: + return + self._logger.info(f"Invoking StopTrainingJob action on SM jobname:{self._found_jobs}") + try: + res = self._sm_client.stop_training_job(TrainingJobName=self._found_jobs[0]) + self._logger.info(f"Stop Training job response:{res}") + except Exception as e: + self._logger.info(f"Got exception while stopping training job{self._found_jobs[0]}:{e}") + + def invoke(self, message=None): + self._stop_training_job() diff --git a/smdebug/rules/rule.py b/smdebug/rules/rule.py index 062fb492d..1cebb9972 100644 --- a/smdebug/rules/rule.py +++ b/smdebug/rules/rule.py @@ -5,6 +5,7 @@ from smdebug.analysis.utils import no_refresh from smdebug.core.logger import get_logger from smdebug.exceptions import RuleEvaluationConditionMet +from smdebug.rules.action import Actions # Local from .req_tensors import RequiredTensors @@ -12,7 +13,7 @@ # This is Rule interface class Rule(ABC): - def __init__(self, base_trial, other_trials=None): + def __init__(self, base_trial, other_trials=None, action_str=""): self.base_trial = base_trial self.other_trials = other_trials @@ -24,6 +25,7 @@ def __init__(self, base_trial, other_trials=None): self.logger = get_logger() self.rule_name = self.__class__.__name__ + self._actions = Actions(actions_str=action_str, rule_name=self.rule_name) def set_required_tensors(self, step): pass @@ -55,4 +57,5 @@ def invoke(self, step): val = self.invoke_at_step(step) if val: + self._actions.invoke() raise RuleEvaluationConditionMet(self.rule_name, step) diff --git a/tests/rules/__init__.py b/tests/rules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/rules/action/__init__.py b/tests/rules/action/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/rules/action/test_message_action.py b/tests/rules/action/test_message_action.py new file mode 100644 index 000000000..b1b0d44a1 --- /dev/null +++ b/tests/rules/action/test_message_action.py @@ -0,0 +1,74 @@ +# First Party +from smdebug.rules.action.action import Actions +from smdebug.rules.action.message_action import MessageAction +from smdebug.rules.action.stop_training_action import StopTrainingAction + + +def test_action_stop_training_job(): + action_str = '{"name": "stoptraining" , "training_job_prefix":"training_prefix"}' + action = Actions(actions_str=action_str) + action.invoke() + + +def test_action_stop_training_job_invalid_params(): + action_str = '{"name": "stoptraining" , "invalid_job_prefix":"training_prefix"}' + action = Actions(actions_str=action_str) + action.invoke() + + +def test_action_sms(): + action_str = '{"name": "sms" , "endpoint":"+11234567890"}' + action = Actions(actions_str=action_str, rule_name="test_rule") + action.invoke() + sms_action = action._actions[0] + assert sms_action._last_subscription_response is not None + assert sms_action._last_send_mesg_response is not None + + +def test_action_sms_invalid_params(): + action_str = '{"name": "sms" , "invalid":"+11234567890"}' + action = Actions(actions_str=action_str, rule_name="test_rule") + action.invoke() + + +def test_action_email(): + action_str = '{"name": "email" , "endpoint":"abc@abc.com"}' + action = Actions(actions_str=action_str, rule_name="test_rule") + action.invoke() + email_action = action._actions[0] + assert email_action._last_subscription_response is not None + assert email_action._last_send_mesg_response is not None + + +def test_action_email_invalid_params(): + action_str = '{"name": "email" , "invalid":"abc@abc.com"}' + action = Actions(actions_str=action_str, rule_name="test_rule") + action.invoke() + + +def test_invalid_message_action(): + action_str = '{"name": "invalid" , "invalid":"abc@abc.com"}' + action = Actions(actions_str=action_str, rule_name="test_rule") + action.invoke() + + +def test_action_multiple(): + action_str = ( + '[{"name": "stoptraining" , "training_job_prefix":"training_prefix"}, {"name": "email" , ' + '"endpoint":"abc@abc.com"}] ' + ) + action = Actions(actions_str=action_str, rule_name="test_rule") + actions = action._actions + assert len(actions) == 2 + stop_action = actions[0] + email_action = actions[1] + assert isinstance(stop_action, StopTrainingAction) == True + assert isinstance(email_action, MessageAction) == True + + assert stop_action._training_job_prefix == "training_prefix" + assert email_action._protocol == "email" + assert email_action._topic_name == "SMDebugRules" + assert email_action._message_endpoint == "abc@abc.com" + assert email_action._rule_name == "test_rule" + + assert email_action._last_subscription_response is not None