Skip to content
2 changes: 2 additions & 0 deletions config/tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ run_for_framework() {
python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} --durations=50 --html=$REPORT_DIR/report_$1.html -v -s --self-contained-html --ignore=tests/core/test_paths.py --ignore=tests/core/test_index_utils.py --ignore=tests/core/test_collections.py tests/$1
if [ "$1" = "mxnet" ] ; then
python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/zero_code_change/test_mxnet_gluon_integration.py
# we run test/rules once, mxnet build has configured permission for sns to run this test
python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/rules
elif [ "$1" = "pytorch" ] ; then
python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/zero_code_change/test_pytorch_integration.py
python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} tests/zero_code_change/test_pytorch_multiprocessing.py
Expand Down
2 changes: 2 additions & 0 deletions smdebug/rules/action/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Local
from .action import Actions
83 changes: 83 additions & 0 deletions smdebug/rules/action/action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Standard Library
import json

# First Party
from smdebug.core.logger import get_logger

# Local
from .message_action import MessageAction
from .stop_training_action import StopTrainingAction

ALLOWED_ACTIONS = ["stoptraining", "sms", "email"]


class Actions:
def __init__(self, actions_str="", rule_name=""):
self._actions = []
self._logger = get_logger()
actions_str = actions_str.strip() if actions_str is not None else ""
if actions_str == "":
self._logger.info(f"No action specified. Action str is {actions_str}")
return
self._register_actions(actions_str, rule_name)

def _register_actions(self, actions_str="", rule_name=""):

actions_str = actions_str.lower()
self._logger.info(f"Action string: {actions_str} and rule_name:{rule_name}")
action_json = json.loads(actions_str)
actions_list = []
if isinstance(action_json, dict):
actions_list.append(action_json)
elif isinstance(action_json, list):
actions_list = action_json
else:
self._logger.info(
f"Action string: {actions_str}, expected either a list of dict or dict. Skipping action registering"
)
return

# action : {name:'StopTraining', 'training_job_prefix':''}
# {name:'sms or email', 'endpoint':''}
for action_dict in actions_list:
if not isinstance(action_dict, dict):
self._logger.info(
f"expected dictionary for action, got {action_dict} . Skipping this action."
)
continue
if "name" in action_dict:
if action_dict["name"] == "stoptraining":
training_job_prefix = (
action_dict["training_job_prefix"]
if "training_job_prefix" in action_dict
else None
)
if training_job_prefix is None:
self._logger.info(
f"Action :{action_dict['name']} requires 'training_job_prefix' key to be specified. "
f"Action_dict is: {action_dict}"
)
continue

action = StopTrainingAction(rule_name, training_job_prefix)
self._actions.append(action)
elif action_dict["name"] == "sms" or action_dict["name"] == "email":
endpoint = action_dict["endpoint"] if "endpoint" in action_dict else None
if endpoint is None:
self._logger.info(
f"Action :{action_dict['name']} requires endpoint key parameter. "
)
continue

action = MessageAction(rule_name, action_dict["name"], endpoint)
self._actions.append(action)

else:
self._logger.info(
f"Action :{action_dict['name']} not supported. Allowed action names are: {ALLOWED_ACTIONS}"
)

def invoke(self):
self._logger.info("Invoking actions")
for action in self._actions:
action.invoke()
145 changes: 145 additions & 0 deletions smdebug/rules/action/message_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Standard Library
import json
import os

# Third Party
import boto3

# First Party
from smdebug.core.logger import get_logger

# action :
# {name:'sms' or 'email', 'endpoint':'phone or emailid'}


class MessageAction:
def __init__(self, rule_name, message_type, message_endpoint):
self._topic_name = "SMDebugRules"
self._logger = get_logger()

if message_type == "sms" or message_type == "email":
self._protocol = message_type
else:
self._protocol = None
self._logger.info(
f"Unsupported message type:{message_type} in MessageAction. Returning"
)
return
self._message_endpoint = message_endpoint

# Below 2 is to help in tests
self._last_send_mesg_response = None
self._last_subscription_response = None

env_region_name = os.getenv("AWS_REGION", "us-east-1")

self._sns_client = boto3.client("sns", region_name=env_region_name)

self._topic_arn = self._create_sns_topic_if_not_exists()

self._subscribe_mesgtype_endpoint()
self._logger.info(
f"Registering messageAction with protocol:{self._protocol} endpoint:{self._message_endpoint} and topic_arn:{self._topic_arn} region:{env_region_name}"
)
self._rule_name = rule_name

def _create_sns_topic_if_not_exists(self):
try:
topic = self._sns_client.create_topic(Name=self._topic_name)
self._logger.info(
f"topic_name: {self._topic_name} , creating topic returned response:{topic}"
)
if topic:
return topic["TopicArn"]
except Exception as e:
self._logger.info(
f"Caught exception while creating topic:{self._topic_name} exception is: \n {e}"
)
return None

def _check_subscriptions(self, topic_arn, protocol, endpoint):
try:
next_token = "random"
subs = self._sns_client.list_subscriptions()
sub_array = subs["Subscriptions"]
while next_token is not None:
for sub_dict in sub_array:
proto = sub_dict["Protocol"]
ep = sub_dict["Endpoint"]
topic = sub_dict["TopicArn"]
if proto == protocol and topic == topic_arn and ep == endpoint:
self._logger.info(f"Existing Subscription found: {sub_dict}")
return True
if "NextToken" in subs:
next_token = subs["NextToken"]
subs = self._sns_client.list_subscriptions(NextToken=next_token)
sub_array = subs["Subscriptions"]
continue
else:
next_token = None
except Exception as e:
self._logger.info(
f"Caught exception while list subscription topic:{self._topic_name} exception is: \n {e}"
)
return False

def _subscribe_mesgtype_endpoint(self):

response = None
try:

if self._topic_arn and self._protocol and self._message_endpoint:
filter_policy = {}
if self._protocol == "sms":
filter_policy["phone_num"] = [self._message_endpoint]
else:
filter_policy["email"] = [self._message_endpoint]
if not self._check_subscriptions(
self._topic_arn, self._protocol, self._message_endpoint
):

response = self._sns_client.subscribe(
TopicArn=self._topic_arn,
Protocol=self._protocol, # sms or email
Endpoint=self._message_endpoint, # phone number or email addresss
Attributes={"FilterPolicy": json.dumps(filter_policy)},
ReturnSubscriptionArn=False, # True means always return ARN
)
else:
response = f"Subscription exists for topic:{self._topic_arn}, protocol:{self._protocol}, endpoint:{self._message_endpoint}"
except Exception as e:
self._logger.info(
f"Caught exception while subscribing endpoint on topic:{self._topic_arn} exception is: \n {e}"
)
self._logger.info(f"response for sns subscribe is {response} ")
self._last_subscription_response = response

def _send_message(self, message):
response = None
message = f"SMDebugRule:{self._rule_name} fired. {message}"
try:
if self._protocol == "sms":
msg_attributes = {
"phone_num": {"DataType": "String", "StringValue": self._message_endpoint}
}
else:
msg_attributes = {
"email": {"DataType": "String", "StringValue": self._message_endpoint}
}
response = self._sns_client.publish(
TopicArn=self._topic_arn,
Message=message,
Subject=f"SMDebugRule:{self._rule_name} fired",
# MessageStructure='json',
MessageAttributes=msg_attributes,
)
except Exception as e:
self._logger.info(
f"Caught exception while publishing message on topic:{self._topic_arn} exception is: \n {e}"
)
self._logger.info(f"Response of send message:{response}")
self._last_send_mesg_response = response
return response

def invoke(self, message=None):
self._send_message(message)
60 changes: 60 additions & 0 deletions smdebug/rules/action/stop_training_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Standard Library
import os

# Third Party
import boto3

# First Party
from smdebug.core.logger import get_logger


class StopTrainingAction:
def __init__(self, rule_name, training_job_prefix):
self._training_job_prefix = training_job_prefix
env_region_name = os.getenv("AWS_REGION", "us-east-1")
self._logger = get_logger()
self._logger.info(
f"StopTrainingAction created with training_job_prefix:{training_job_prefix} and region:{env_region_name}"
)
self._sm_client = boto3.client("sagemaker", region_name=env_region_name)
self._rule_name = rule_name
self._found_jobs = self._get_sm_tj_jobs_with_prefix()

def _get_sm_tj_jobs_with_prefix(self):
found_jobs = []
try:
jobs = self._sm_client.list_training_jobs()
if "TrainingJobSummaries" in jobs:
jobs = jobs["TrainingJobSummaries"]
else:
self._logger.info(
f"No TrainingJob summaries found: list_training_jobs output is : {jobs}"
)
return
for job in jobs:
self._logger.info(
f"TrainingJob name: {job['TrainingJobName']} , status:{job['TrainingJobStatus']}"
)
if job["TrainingJobName"] is not None and job["TrainingJobName"].startswith(
self._training_job_prefix
):
found_jobs.append(job["TrainingJobName"])
self._logger.info(f"found_training job {found_jobs}")
except Exception as e:
self._logger.info(
f"Caught exception while getting list_training_job exception is: \n {e}"
)
return found_jobs

def _stop_training_job(self):
if len(self._found_jobs) != 1:
return
self._logger.info(f"Invoking StopTrainingJob action on SM jobname:{self._found_jobs}")
try:
res = self._sm_client.stop_training_job(TrainingJobName=self._found_jobs[0])
self._logger.info(f"Stop Training job response:{res}")
except Exception as e:
self._logger.info(f"Got exception while stopping training job{self._found_jobs[0]}:{e}")

def invoke(self, message=None):
self._stop_training_job()
5 changes: 4 additions & 1 deletion smdebug/rules/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from smdebug.analysis.utils import no_refresh
from smdebug.core.logger import get_logger
from smdebug.exceptions import RuleEvaluationConditionMet
from smdebug.rules.action import Actions

# Local
from .req_tensors import RequiredTensors


# This is Rule interface
class Rule(ABC):
def __init__(self, base_trial, other_trials=None):
def __init__(self, base_trial, other_trials=None, action_str=""):
self.base_trial = base_trial
self.other_trials = other_trials

Expand All @@ -24,6 +25,7 @@ def __init__(self, base_trial, other_trials=None):

self.logger = get_logger()
self.rule_name = self.__class__.__name__
self._actions = Actions(actions_str=action_str, rule_name=self.rule_name)

def set_required_tensors(self, step):
pass
Expand Down Expand Up @@ -55,4 +57,5 @@ def invoke(self, step):
val = self.invoke_at_step(step)

if val:
self._actions.invoke()
raise RuleEvaluationConditionMet(self.rule_name, step)
Empty file added tests/rules/__init__.py
Empty file.
Empty file added tests/rules/action/__init__.py
Empty file.
Loading