Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions smdebug/core/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from smdebug.core.hook_utils import get_tensorboard_dir, verify_and_get_out_dir
from smdebug.core.json_config import create_hook_from_json_config
from smdebug.core.logger import get_logger
from smdebug.core.modes import ALLOWED_MODES, ModeKeys
from smdebug.core.modes import ALLOWED_MODE_NAMES, ALLOWED_MODES, ModeKeys
from smdebug.core.reduction_config import ReductionConfig
from smdebug.core.reductions import get_reduction_tensor_name
from smdebug.core.sagemaker_utils import is_sagemaker_job
Expand Down Expand Up @@ -461,7 +461,7 @@ def set_mode(self, mode):
self.mode = mode
else:
raise ValueError(
"Invalid mode {}. Valid modes are {}.".format(mode, ",".join(ALLOWED_MODES))
"Invalid mode {}. Valid modes are {}.".format(mode, ",".join(ALLOWED_MODE_NAMES))
)

if mode not in self.mode_steps:
Expand Down
2 changes: 1 addition & 1 deletion smdebug/core/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ModeKeys(Enum):
GLOBAL = 4


ALLOWED_MODES = [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]
ALLOWED_MODES = [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT, ModeKeys.GLOBAL]
ALLOWED_MODE_NAMES = [x.name for x in ALLOWED_MODES]
MODE_STEP_PLUGIN_NAME = "mode_step"
MODE_PLUGIN_NAME = "mode"
11 changes: 10 additions & 1 deletion tests/xgboost/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

# Third Party
import numpy as np
import pytest
import xgboost

# First Party
from smdebug import SaveConfig
from smdebug import SaveConfig, modes
from smdebug.core.access_layer.utils import has_training_ended
from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR
from smdebug.trials import create_trial
Expand Down Expand Up @@ -220,3 +221,11 @@ def test_hook_tensorboard_dir_created(tmpdir):
hook = Hook(out_dir=out_dir, export_tensorboard=True)
run_xgboost_model(hook=hook)
assert "tensorboard" in os.listdir(out_dir)


def test_setting_mode(tmpdir):
out_dir = os.path.join(tmpdir, str(uuid.uuid4()))
hook = Hook(out_dir=out_dir, export_tensorboard=True)
hook.set_mode(modes.GLOBAL)
with pytest.raises(ValueError):
hook.set_mode("a")