From bb1e5402ce1a2c6e972d3a9a5f7496db425f435e Mon Sep 17 00:00:00 2001 From: Rahul Huilgol Date: Wed, 4 Dec 2019 12:57:54 -0800 Subject: [PATCH] Allow setting global mode and add test --- smdebug/core/hook.py | 4 ++-- smdebug/core/modes.py | 2 +- tests/xgboost/test_hook.py | 11 ++++++++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/smdebug/core/hook.py b/smdebug/core/hook.py index b07771079..3427b49ab 100644 --- a/smdebug/core/hook.py +++ b/smdebug/core/hook.py @@ -27,7 +27,7 @@ from smdebug.core.hook_utils import get_tensorboard_dir, verify_and_get_out_dir from smdebug.core.json_config import create_hook_from_json_config from smdebug.core.logger import get_logger -from smdebug.core.modes import ALLOWED_MODES, ModeKeys +from smdebug.core.modes import ALLOWED_MODE_NAMES, ALLOWED_MODES, ModeKeys from smdebug.core.reduction_config import ReductionConfig from smdebug.core.reductions import get_reduction_tensor_name from smdebug.core.sagemaker_utils import is_sagemaker_job @@ -461,7 +461,7 @@ def set_mode(self, mode): self.mode = mode else: raise ValueError( - "Invalid mode {}. Valid modes are {}.".format(mode, ",".join(ALLOWED_MODES)) + "Invalid mode {}. Valid modes are {}.".format(mode, ",".join(ALLOWED_MODE_NAMES)) ) if mode not in self.mode_steps: diff --git a/smdebug/core/modes.py b/smdebug/core/modes.py index 185a336bf..fce21ad1d 100644 --- a/smdebug/core/modes.py +++ b/smdebug/core/modes.py @@ -10,7 +10,7 @@ class ModeKeys(Enum): GLOBAL = 4 -ALLOWED_MODES = [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT] +ALLOWED_MODES = [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT, ModeKeys.GLOBAL] ALLOWED_MODE_NAMES = [x.name for x in ALLOWED_MODES] MODE_STEP_PLUGIN_NAME = "mode_step" MODE_PLUGIN_NAME = "mode" diff --git a/tests/xgboost/test_hook.py b/tests/xgboost/test_hook.py index 649b90546..5918fc1c9 100644 --- a/tests/xgboost/test_hook.py +++ b/tests/xgboost/test_hook.py @@ -4,10 +4,11 @@ # Third Party import numpy as np +import pytest import xgboost # First Party -from smdebug import SaveConfig +from smdebug import SaveConfig, modes from smdebug.core.access_layer.utils import has_training_ended from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR from smdebug.trials import create_trial @@ -220,3 +221,11 @@ def test_hook_tensorboard_dir_created(tmpdir): hook = Hook(out_dir=out_dir, export_tensorboard=True) run_xgboost_model(hook=hook) assert "tensorboard" in os.listdir(out_dir) + + +def test_setting_mode(tmpdir): + out_dir = os.path.join(tmpdir, str(uuid.uuid4())) + hook = Hook(out_dir=out_dir, export_tensorboard=True) + hook.set_mode(modes.GLOBAL) + with pytest.raises(ValueError): + hook.set_mode("a")