Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions smdebug/core/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,14 @@ def __init__(
self.worker = None
self.save_all_workers = True if include_workers == "all" else False
self.chief_worker = CONFIG_DEFAULT_WORKER_NAME

if include_collections is None:
include_collections = default_include_collections
self.default_include_collections = default_include_collections
self.include_collections = flatten(include_collections)
else:
include_collections = flatten(include_collections)
self.include_collections = list(
set(include_collections).union(set(default_include_collections))
)

self.save_all = save_all
self.save_config = SaveConfig.parse(save_config)
Expand Down Expand Up @@ -268,7 +272,7 @@ def _get_collections_to_save_for_step(self) -> Set["Collection"]:
step_str = f"for step {self.step}"
else:
step_str = f"for step {self.mode_steps[self.mode]} of mode {self.mode.name}"
self.logger.info(
self.logger.debug(
f"Saving the collections "
f"{', '.join([x.name for x in self._collections_to_save_for_step])} {step_str}"
)
Expand Down Expand Up @@ -583,14 +587,6 @@ def save_scalar(self, name, value, searchable=False):
scalar_obj = ScalarCache(name, val, searchable=True, write_tb=True, write_event=True)
self.scalar_cache.append(scalar_obj)

# def save_tensor(self, name, value):
# # todo: support to add these tensors to any collection.
# # complication here is that we need to export the file again
# # todo: what happens if name is conflicting
# if self.writer is None:
# self._init_writer()
# self._save_raw_tensor(name, value)

def _write_raw_tensor(self, tensor_name, tensor_value, save_collections, tensor_ref=None):
for s_col in save_collections:
reduction_config = s_col.reduction_config
Expand Down
26 changes: 20 additions & 6 deletions smdebug/core/save_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,26 @@ def __init__(
end_step: Union[int, str] = None,
save_steps: List[int] = None,
):
self.save_interval = int(save_interval or DEFAULT_SAVE_CONFIG_INTERVAL)
self.save_steps = save_steps or DEFAULT_SAVE_CONFIG_SAVE_STEPS
self.start_step = int(start_step or DEFAULT_SAVE_CONFIG_START_STEP)
self.end_step = end_step or DEFAULT_SAVE_CONFIG_END_STEP
if self.end_step: # can be None
self.end_step = int(self.end_step)
if save_interval is None:
self.save_interval = DEFAULT_SAVE_CONFIG_INTERVAL
else:
self.save_interval = int(save_interval)

if save_steps is None:
self.save_steps = DEFAULT_SAVE_CONFIG_SAVE_STEPS
else:
self.save_steps = save_steps

if start_step is None:
self.start_step = DEFAULT_SAVE_CONFIG_START_STEP
else:
self.start_step = int(start_step)

if end_step is None:
self.end_step = DEFAULT_SAVE_CONFIG_END_STEP
else:
self.end_step = int(end_step)

## DO NOT REMOVE; please make sure that _check & from_json is updated accordingly.
self._check()

Expand Down
4 changes: 0 additions & 4 deletions smdebug/mxnet/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ def __init__(
save_all=save_all,
include_workers=include_workers,
)
# We would like to collect loss collection
# even if user does not specify any collections
if CollectionKeys.LOSSES not in self.include_collections:
self.include_collections.append(CollectionKeys.LOSSES)
self.last_block = None

self.model = None
Expand Down
4 changes: 0 additions & 4 deletions smdebug/pytorch/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ def __init__(
save_all=save_all,
include_workers=include_workers,
)
# We would like to collect loss collection
# even if user does not specify any collections
if CollectionKeys.LOSSES not in self.include_collections:
self.include_collections.append(CollectionKeys.LOSSES)
# mapping of module objects to their names,
# useful in forward hook for logging input/output of modules
self.module_maps = dict()
Expand Down
1 change: 0 additions & 1 deletion smdebug/tensorflow/base_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
DEFAULT_INCLUDE_COLLECTIONS = [
CollectionKeys.METRICS,
CollectionKeys.LOSSES,
CollectionKeys.SCALARS,
CollectionKeys.SEARCHABLE_SCALARS,
]

Expand Down
8 changes: 8 additions & 0 deletions tests/core/test_save_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Third Party

# First Party
from smdebug import modes
from smdebug.core.save_config import SaveConfig


Expand All @@ -19,3 +20,10 @@ def test_load_empty():
def test_load_none():
    """Round-trip check: a SaveConfig built with only start_step set
    compares equal after being serialized to JSON and parsed back."""
    original = SaveConfig(start_step=100)
    restored = SaveConfig.from_json(original.to_json())
    assert original == restored


def test_end_step():
    """end_step=0 must disable saving entirely: no step in GLOBAL mode
    (including step 0 itself) should be eligible for saving."""
    cfg = SaveConfig(end_step=0)
    for step in (0, 19, 100):
        assert cfg.should_save_step(modes.GLOBAL, step) is False
3 changes: 2 additions & 1 deletion tests/tensorflow/hooks/test_collection_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def test_collection_defaults_json(out_dir, monkeypatch):
assert hook.save_config.get_save_config(ModeKeys.GLOBAL).save_interval == 1
# Check include_collections
assert "weights" in hook.include_collections and "losses" in hook.include_collections
assert len(hook.include_collections) == 2

assert len(hook.include_collections) == 4
# Check collection configurations for losses
assert (
hook.collection_manager.collections["losses"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"S3OutputPath": "s3://kjndjknd_bucket/prefix",
"LocalPath": "/tmp/test",
"HookParameters": {
"end_step": 0
}
}
13 changes: 13 additions & 0 deletions tests/tensorflow/hooks/test_save_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def test_save_config(out_dir):
helper_test_save_config(out_dir, hook)


def test_save_config_disable(out_dir, monkeypatch):
    """When the JSON hook config sets end_step=0, the hook should save
    nothing at all: the resulting trial has zero steps and zero tensors."""
    pre_test_clean_up()
    config_path = "tests/tensorflow/hooks/test_json_configs/test_save_config_disable.json"
    monkeypatch.setenv(CONFIG_FILE_PATH_ENV_STR, config_path)
    hook = SessionHook.hook_from_config()
    simple_model(hook)
    trial = create_trial(out_dir)
    assert len(trial.steps()) == 0
    assert len(trial.tensors()) == 0


def test_save_config_json(out_dir, monkeypatch):
pre_test_clean_up()
monkeypatch.setenv(
Expand Down
5 changes: 5 additions & 0 deletions tests/tensorflow/keras/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def train_model(
reduction_config=reduction_config,
)

if not save_all and include_collections is not None:
for cname in hook.include_collections:
if cname not in include_collections:
hook.get_collection(cname).save_config = SaveConfig(end_step=0)

if create_relu_collection:
hook.get_collection("relu").add_keras_layer(relu_layer, inputs=True, outputs=True)

Expand Down
5 changes: 5 additions & 0 deletions tests/tensorflow/keras/test_keras_mirrored.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def scale(image, label):
include_workers=include_workers,
)

if not save_all and include_collections is not None:
for cname in hook.include_collections:
if cname not in include_collections:
hook.get_collection(cname).save_config = SaveConfig(end_step=0)

if use_keras_optimizer:
opt = tf.keras.optimizers.Adam()
else:
Expand Down