Adding action class #285
Merged
Commits
13 commits
2448178 Adding action class (Vikas-kum)
679b00d handling empty json string (Vikas-kum)
ecd939b Fixing a bug with parameters passing (Vikas-kum)
fe8f77d bug fix (Vikas-kum)
25a69b3 some logs and client fix (Vikas-kum)
d9067ae some more fixes (Vikas-kum)
dbc15a3 Adding tests and some fixes (Vikas-kum)
08e3cae Extra logging mesg (Vikas-kum)
d7dcd0e Adding test file (Vikas-kum)
8e8f238 Adding rules test in CI (Vikas-kum)
e877df3 some exception handling (Vikas-kum)
1dc3e0d Adding test/rules to mxnet build which has reqd permission for sns (Vikas-kum)
c0b8f74 addressed review comments (Vikas-kum)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Local | ||
from .action import Actions |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Standard Library | ||
import json | ||
|
||
# First Party | ||
from smdebug.core.logger import get_logger | ||
|
||
# Local | ||
from .message_action import MessageAction | ||
from .stop_training_action import StopTrainingAction | ||
|
||
ALLOWED_ACTIONS = ["stoptraining", "sms", "email"] | ||
|
||
|
||
class Actions: | ||
def __init__(self, actions_str="", rule_name=""): | ||
self._actions = [] | ||
self._logger = get_logger() | ||
actions_str = actions_str.strip() if actions_str is not None else "" | ||
if actions_str == "": | ||
self._logger.info(f"No action specified. Action str is {actions_str}") | ||
return | ||
self._register_actions(actions_str, rule_name) | ||
|
||
def _register_actions(self, actions_str="", rule_name=""): | ||
|
||
actions_str = actions_str.lower() | ||
self._logger.info(f"Action string: {actions_str} and rule_name:{rule_name}") | ||
action_json = json.loads(actions_str) | ||
actions_list = [] | ||
if isinstance(action_json, dict): | ||
actions_list.append(action_json) | ||
elif isinstance(action_json, list): | ||
actions_list = action_json | ||
else: | ||
self._logger.info( | ||
f"Action string: {actions_str}, expected either a dict or a list of dicts. Skipping action registration" | ||
) | ||
return | ||
|
||
# action : {name:'StopTraining', 'training_job_prefix':''} | ||
# {name:'sms or email', 'endpoint':''} | ||
for action_dict in actions_list: | ||
if not isinstance(action_dict, dict): | ||
self._logger.info( | ||
f"expected dictionary for action, got {action_dict} . Skipping this action." | ||
) | ||
continue | ||
if "name" in action_dict: | ||
if action_dict["name"] == "stoptraining": | ||
training_job_prefix = ( | ||
action_dict["training_job_prefix"] | ||
if "training_job_prefix" in action_dict | ||
else None | ||
) | ||
if training_job_prefix is None: | ||
self._logger.info( | ||
f"Action :{action_dict['name']} requires 'training_job_prefix' key to be specified. " | ||
f"Action_dict is: {action_dict}" | ||
) | ||
continue | ||
|
||
action = StopTrainingAction(rule_name, training_job_prefix) | ||
self._actions.append(action) | ||
elif action_dict["name"] == "sms" or action_dict["name"] == "email": | ||
endpoint = action_dict["endpoint"] if "endpoint" in action_dict else None | ||
if endpoint is None: | ||
self._logger.info( | ||
f"Action :{action_dict['name']} requires endpoint key parameter. " | ||
) | ||
continue | ||
|
||
action = MessageAction(rule_name, action_dict["name"], endpoint) | ||
self._actions.append(action) | ||
|
||
else: | ||
self._logger.info( | ||
f"Action :{action_dict['name']} not supported. Allowed action names are: {ALLOWED_ACTIONS}" | ||
) | ||
|
||
def invoke(self): | ||
self._logger.info("Invoking actions") | ||
for action in self._actions: | ||
action.invoke() |
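For reference, a minimal sketch of the action-string format that `Actions` accepts, written as a standalone function so it runs without smdebug or AWS credentials. `parse_actions` and `REQUIRED_KEY` are hypothetical names that are not part of this PR, and catching `json.JSONDecodeError` is an assumption: the diff above calls `json.loads` without a guard. As in the diff, the whole string is lowercased before parsing, so parameter values such as a job prefix are lowercased too.

```python
import json

# Required parameter key for each action name, as read from the diff above.
REQUIRED_KEY = {
    "stoptraining": "training_job_prefix",
    "sms": "endpoint",
    "email": "endpoint",
}


def parse_actions(actions_str):
    """Hypothetical helper: return (name, parameter) pairs for valid actions."""
    # Mirrors the diff: strip, then lowercase the entire string (values included).
    actions_str = (actions_str or "").strip().lower()
    if not actions_str:
        return []
    try:
        parsed = json.loads(actions_str)
    except json.JSONDecodeError:
        # Assumption: malformed JSON is skipped rather than raised.
        return []
    # A single dict is treated as a one-element list; anything else is ignored.
    if isinstance(parsed, dict):
        parsed = [parsed]
    elif not isinstance(parsed, list):
        return []
    valid = []
    for entry in parsed:
        if not isinstance(entry, dict):
            continue
        name = entry.get("name")
        if not isinstance(name, str):
            continue
        key = REQUIRED_KEY.get(name)
        # Skip unsupported names and entries missing their required parameter.
        if key is None or entry.get(key) is None:
            continue
        valid.append((name, entry[key]))
    return valid


print(parse_actions('{"name": "StopTraining", "training_job_prefix": "My-Job"}'))
print(parse_actions('[{"name": "sms"}, {"name": "email", "endpoint": "a@b.com"}, 7]'))
```

The first call shows the effect of lowercasing on a case-sensitive value; the second shows that an entry without its required key and a non-dict entry are both dropped.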
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Standard Library | ||
import json | ||
import os | ||
|
||
# Third Party | ||
import boto3 | ||
|
||
# First Party | ||
from smdebug.core.logger import get_logger | ||
|
||
# action : | ||
# {name:'sms' or 'email', 'endpoint':'phone or emailid'} | ||
|
||
|
||
class MessageAction: | ||
def __init__(self, rule_name, message_type, message_endpoint): | ||
self._topic_name = "SMDebugRules" | ||
self._logger = get_logger() | ||
|
||
if message_type == "sms" or message_type == "email": | ||
self._protocol = message_type | ||
else: | ||
self._protocol = None | ||
self._logger.info( | ||
f"Unsupported message type:{message_type} in MessageAction. Returning" | ||
) | ||
return | ||
self._message_endpoint = message_endpoint | ||
|
||
# The two attributes below exist to help in tests | ||
self._last_send_mesg_response = None | ||
self._last_subscription_response = None | ||
|
||
env_region_name = os.getenv("AWS_REGION", "us-east-1") | ||
|
||
self._sns_client = boto3.client("sns", region_name=env_region_name) | ||
|
||
self._topic_arn = self._create_sns_topic_if_not_exists() | ||
|
||
self._subscribe_mesgtype_endpoint() | ||
self._logger.info( | ||
f"Registering messageAction with protocol:{self._protocol} endpoint:{self._message_endpoint} and topic_arn:{self._topic_arn} region:{env_region_name}" | ||
) | ||
self._rule_name = rule_name | ||
|
||
def _create_sns_topic_if_not_exists(self): | ||
try: | ||
topic = self._sns_client.create_topic(Name=self._topic_name) | ||
self._logger.info( | ||
f"topic_name: {self._topic_name} , creating topic returned response:{topic}" | ||
) | ||
if topic: | ||
return topic["TopicArn"] | ||
except Exception as e: | ||
self._logger.info( | ||
f"Caught exception while creating topic:{self._topic_name} exception is: \n {e}" | ||
) | ||
return None | ||
|
||
def _check_subscriptions(self, topic_arn, protocol, endpoint): | ||
try: | ||
next_token = "random" | ||
subs = self._sns_client.list_subscriptions() | ||
sub_array = subs["Subscriptions"] | ||
while next_token is not None: | ||
for sub_dict in sub_array: | ||
proto = sub_dict["Protocol"] | ||
ep = sub_dict["Endpoint"] | ||
topic = sub_dict["TopicArn"] | ||
if proto == protocol and topic == topic_arn and ep == endpoint: | ||
self._logger.info(f"Existing Subscription found: {sub_dict}") | ||
return True | ||
if "NextToken" in subs: | ||
next_token = subs["NextToken"] | ||
subs = self._sns_client.list_subscriptions(NextToken=next_token) | ||
sub_array = subs["Subscriptions"] | ||
continue | ||
else: | ||
next_token = None | ||
except Exception as e: | ||
self._logger.info( | ||
f"Caught exception while listing subscriptions for topic:{self._topic_name} exception is: \n {e}" | ||
) | ||
return False | ||
|
||
def _subscribe_mesgtype_endpoint(self): | ||
|
||
response = None | ||
try: | ||
|
||
if self._topic_arn and self._protocol and self._message_endpoint: | ||
filter_policy = {} | ||
if self._protocol == "sms": | ||
filter_policy["phone_num"] = [self._message_endpoint] | ||
else: | ||
filter_policy["email"] = [self._message_endpoint] | ||
if not self._check_subscriptions( | ||
self._topic_arn, self._protocol, self._message_endpoint | ||
): | ||
|
||
response = self._sns_client.subscribe( | ||
TopicArn=self._topic_arn, | ||
Protocol=self._protocol, # sms or email | ||
Endpoint=self._message_endpoint, # phone number or email address | ||
Attributes={"FilterPolicy": json.dumps(filter_policy)}, | ||
ReturnSubscriptionArn=False, # True means always return ARN | ||
) | ||
else: | ||
response = f"Subscription exists for topic:{self._topic_arn}, protocol:{self._protocol}, endpoint:{self._message_endpoint}" | ||
except Exception as e: | ||
self._logger.info( | ||
f"Caught exception while subscribing endpoint on topic:{self._topic_arn} exception is: \n {e}" | ||
) | ||
self._logger.info(f"response for sns subscribe is {response} ") | ||
self._last_subscription_response = response | ||
|
||
def _send_message(self, message): | ||
response = None | ||
message = f"SMDebugRule:{self._rule_name} fired. {message}" | ||
try: | ||
if self._protocol == "sms": | ||
msg_attributes = { | ||
"phone_num": {"DataType": "String", "StringValue": self._message_endpoint} | ||
} | ||
else: | ||
msg_attributes = { | ||
"email": {"DataType": "String", "StringValue": self._message_endpoint} | ||
} | ||
response = self._sns_client.publish( | ||
TopicArn=self._topic_arn, | ||
Message=message, | ||
Subject=f"SMDebugRule:{self._rule_name} fired", | ||
# MessageStructure='json', | ||
MessageAttributes=msg_attributes, | ||
) | ||
except Exception as e: | ||
self._logger.info( | ||
f"Caught exception while publishing message on topic:{self._topic_arn} exception is: \n {e}" | ||
) | ||
self._logger.info(f"Response of send message:{response}") | ||
self._last_send_mesg_response = response | ||
return response | ||
|
||
def invoke(self, message=None): | ||
self._send_message(message) |
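The `NextToken` loop in `_check_subscriptions` can be written more compactly. The sketch below is one possible shape, not the PR's implementation: `subscription_exists` and `FakeSNS` are hypothetical names, and `FakeSNS` is a stand-in that only imitates the paged response shape of the SNS `list_subscriptions` call so the example runs without AWS access.

```python
def subscription_exists(sns_client, topic_arn, protocol, endpoint):
    """Return True if a matching subscription is found on any page."""
    kwargs = {}
    while True:
        page = sns_client.list_subscriptions(**kwargs)
        for sub in page.get("Subscriptions", []):
            if (
                sub.get("TopicArn") == topic_arn
                and sub.get("Protocol") == protocol
                and sub.get("Endpoint") == endpoint
            ):
                return True
        # A missing NextToken means this was the last page.
        token = page.get("NextToken")
        if not token:
            return False
        kwargs = {"NextToken": token}


class FakeSNS:
    """Hypothetical stand-in for a boto3 SNS client, for illustration only."""

    def __init__(self, pages):
        self._pages = pages
        self.calls = 0

    def list_subscriptions(self, NextToken=None):
        self.calls += 1
        index = int(NextToken) if NextToken else 0
        page = {"Subscriptions": self._pages[index]}
        if index + 1 < len(self._pages):
            page["NextToken"] = str(index + 1)
        return page


pages = [
    [{"TopicArn": "arn:t", "Protocol": "sms", "Endpoint": "+15550100"}],
    [{"TopicArn": "arn:t", "Protocol": "email", "Endpoint": "a@b.com"}],
]
client = FakeSNS(pages)
print(subscription_exists(client, "arn:t", "email", "a@b.com"))
```

Fetching each page at the top of the loop removes the need for the placeholder `next_token = "random"` and the duplicated `list_subscriptions` call.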
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Standard Library | ||
import os | ||
|
||
# Third Party | ||
import boto3 | ||
|
||
# First Party | ||
from smdebug.core.logger import get_logger | ||
|
||
|
||
class StopTrainingAction: | ||
def __init__(self, rule_name, training_job_prefix): | ||
self._training_job_prefix = training_job_prefix | ||
env_region_name = os.getenv("AWS_REGION", "us-east-1") | ||
self._logger = get_logger() | ||
self._logger.info( | ||
f"StopTrainingAction created with training_job_prefix:{training_job_prefix} and region:{env_region_name}" | ||
) | ||
self._sm_client = boto3.client("sagemaker", region_name=env_region_name) | ||
self._rule_name = rule_name | ||
self._found_jobs = self._get_sm_tj_jobs_with_prefix() | ||
|
||
def _get_sm_tj_jobs_with_prefix(self): | ||
found_jobs = [] | ||
try: | ||
jobs = self._sm_client.list_training_jobs() | ||
if "TrainingJobSummaries" in jobs: | ||
jobs = jobs["TrainingJobSummaries"] | ||
else: | ||
self._logger.info( | ||
f"No TrainingJob summaries found: list_training_jobs output is : {jobs}" | ||
) | ||
return found_jobs | ||
for job in jobs: | ||
self._logger.info( | ||
f"TrainingJob name: {job['TrainingJobName']} , status:{job['TrainingJobStatus']}" | ||
) | ||
if job["TrainingJobName"] is not None and job["TrainingJobName"].startswith( | ||
self._training_job_prefix | ||
): | ||
found_jobs.append(job["TrainingJobName"]) | ||
self._logger.info(f"found_training job {found_jobs}") | ||
except Exception as e: | ||
self._logger.info( | ||
f"Caught exception while getting list_training_job exception is: \n {e}" | ||
) | ||
return found_jobs | ||
|
||
def _stop_training_job(self): | ||
if len(self._found_jobs) != 1: | ||
return | ||
self._logger.info(f"Invoking StopTrainingJob action on SM jobname:{self._found_jobs}") | ||
try: | ||
res = self._sm_client.stop_training_job(TrainingJobName=self._found_jobs[0]) | ||
self._logger.info(f"Stop Training job response:{res}") | ||
|
||
except Exception as e: | ||
self._logger.info(f"Got exception while stopping training job {self._found_jobs[0]}: {e}") | ||
|
||
def invoke(self, message=None): | ||
self._stop_training_job() |
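A minimal sketch of the prefix lookup and the "exactly one match" guard used by `StopTrainingAction`, under two stated assumptions: the lookup always returns a list (so `len()` is safe for the caller), and it follows `NextToken` across pages, which the diff above does not do. `find_jobs_with_prefix`, `job_to_stop`, and `FakeSageMaker` are hypothetical names; the fake client only imitates the response shape of `list_training_jobs` so the example runs without AWS access.

```python
def find_jobs_with_prefix(sm_client, prefix):
    """Return names of training jobs starting with prefix; always a list."""
    found = []
    kwargs = {}
    while True:
        page = sm_client.list_training_jobs(**kwargs)
        for job in page.get("TrainingJobSummaries", []):
            name = job.get("TrainingJobName")
            if name and name.startswith(prefix):
                found.append(name)
        # A missing NextToken means this was the last page.
        token = page.get("NextToken")
        if not token:
            return found
        kwargs = {"NextToken": token}


def job_to_stop(found_jobs):
    """Mirror the diff's guard: act only when exactly one job matched."""
    return found_jobs[0] if len(found_jobs) == 1 else None


class FakeSageMaker:
    """Hypothetical stand-in for a boto3 SageMaker client, for illustration only."""

    def __init__(self, pages):
        self._pages = pages

    def list_training_jobs(self, NextToken=None):
        index = int(NextToken) if NextToken else 0
        page = {"TrainingJobSummaries": self._pages[index]}
        if index + 1 < len(self._pages):
            page["NextToken"] = str(index + 1)
        return page


client = FakeSageMaker(
    [
        [{"TrainingJobName": "mnist-2020"}, {"TrainingJobName": "resnet-a"}],
        [{"TrainingJobName": "resnet-b"}],
    ]
)
print(job_to_stop(find_jobs_with_prefix(client, "mnist")))
print(job_to_stop(find_jobs_with_prefix(client, "resnet")))
```

With this data the `mnist` prefix matches one job, so a name is returned; the `resnet` prefix matches two jobs on different pages, so the guard declines to pick one.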
Empty file.
Empty file.