Skip to content

Commit 06b412d

Browse files
authored
Bugfix: Debugger breaks if should_save_tensor is called before collections are prepared (#372)
1 parent 3183d52 commit 06b412d

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

smdebug/core/hook.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,14 @@ def _increment_step(self):
550550
# Called in the internal AWS codebase to determine
551551
# if a particular tensor value should be saved
552552
def should_save_tensor_or_collection(self, tensor_name: str, collection_name: str) -> bool:
553+
if self.prepared_collections is False:
554+
# always return false if an attempt to save a
555+
# tensor is made before the collections are prepared.
556+
# this can happen if the fn is called before callbacks are init.
557+
self.logger.warning(
558+
"Tensors cannot be saved with smdebug before callbacks are initialized."
559+
)
560+
return False
553561
if self._is_collection_being_saved_for_step(collection_name):
554562
return True
555563
return self.is_tensor_saved_for_step(tensor_name)

tests/tensorflow2/test_should_save_tensor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,11 @@ def test_should_save_tensor_with_custom_collection(out_dir):
6161
else:
6262
assert not hook.should_save_tensor_or_collection(layer_name, CollectionKeys.GRADIENTS)
6363
assert not hook.should_save_tensor_or_collection(layer_name, CollectionKeys.LAYERS)
64+
65+
66+
def test_should_save_tensor_behavior_without_prepare_collections(out_dir):
67+
"""Always return false if an attempt to save a tensor is made before the collections are prepared.
68+
This can happen if the fn is called before callbacks are init."""
69+
hook = smd.KerasHook(out_dir, save_config=SaveConfig(save_interval=3), save_all=True)
70+
assert not hook.should_save_tensor_or_collection("dummy", CollectionKeys.GRADIENTS)
71+
assert not hook.should_save_tensor_or_collection("dummy", CollectionKeys.LAYERS)

0 commit comments

Comments
 (0)