diff --git a/smdebug/core/hook.py b/smdebug/core/hook.py index cb53a947e..8c352d110 100644 --- a/smdebug/core/hook.py +++ b/smdebug/core/hook.py @@ -550,6 +550,14 @@ def _increment_step(self): # Called in the internal AWS codebase to determine # if a particular tensor value should be saved def should_save_tensor_or_collection(self, tensor_name: str, collection_name: str) -> bool: + if self.prepared_collections is False: + # always return false if an attempt to save a + # tensor is made before the collections are prepared. + # this can happen if the fn is called before callbacks are init. + self.logger.warning( + "Tensors cannot be saved with smdebug before callbacks are initialized." + ) + return False if self._is_collection_being_saved_for_step(collection_name): return True return self.is_tensor_saved_for_step(tensor_name) diff --git a/tests/tensorflow2/test_should_save_tensor.py b/tests/tensorflow2/test_should_save_tensor.py index 0e3e4ba1d..dc0281c14 100644 --- a/tests/tensorflow2/test_should_save_tensor.py +++ b/tests/tensorflow2/test_should_save_tensor.py @@ -61,3 +61,11 @@ def test_should_save_tensor_with_custom_collection(out_dir): else: assert not hook.should_save_tensor_or_collection(layer_name, CollectionKeys.GRADIENTS) assert not hook.should_save_tensor_or_collection(layer_name, CollectionKeys.LAYERS) + + +def test_should_save_tensor_behavior_without_prepare_collections(out_dir): + """Always return false if an attempt to save a tensor is made before the collections are prepared. + This can happen if the fn is called before callbacks are init.""" + hook = smd.KerasHook(out_dir, save_config=SaveConfig(save_interval=3), save_all=True) + assert not hook.should_save_tensor_or_collection("dummy", CollectionKeys.GRADIENTS) + assert not hook.should_save_tensor_or_collection("dummy", CollectionKeys.LAYERS)