File tree Expand file tree Collapse file tree 2 files changed +16
-0
lines changed Expand file tree Collapse file tree 2 files changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -550,6 +550,14 @@ def _increment_step(self):
550550 # Called in the internal AWS codebase to determine
551551 # if a particular tensor value should be saved
552552 def should_save_tensor_or_collection (self , tensor_name : str , collection_name : str ) -> bool :
553+ if self .prepared_collections is False :
554+ # always return false if an attempt to save a
555+ # tensor is made before the collections are prepared.
556+ # this can happen if the fn is called before callbacks are init.
557+ self .logger .warning (
558+ "Tensors cannot be saved with smdebug before callbacks are initialized."
559+ )
560+ return False
553561 if self ._is_collection_being_saved_for_step (collection_name ):
554562 return True
555563 return self .is_tensor_saved_for_step (tensor_name )
Original file line number Diff line number Diff line change @@ -61,3 +61,11 @@ def test_should_save_tensor_with_custom_collection(out_dir):
6161 else :
6262 assert not hook .should_save_tensor_or_collection (layer_name , CollectionKeys .GRADIENTS )
6363 assert not hook .should_save_tensor_or_collection (layer_name , CollectionKeys .LAYERS )
64+
65+
66+ def test_should_save_tensor_behavior_without_prepare_collections (out_dir ):
67+ """Always return false if an attempt to save a tensor is made before the collections are prepared.
68+ This can happen if the fn is called before callbacks are init."""
69+ hook = smd .KerasHook (out_dir , save_config = SaveConfig (save_interval = 3 ), save_all = True )
70+ assert not hook .should_save_tensor_or_collection ("dummy" , CollectionKeys .GRADIENTS )
71+ assert not hook .should_save_tensor_or_collection ("dummy" , CollectionKeys .LAYERS )
You can’t perform that action at this time.
0 commit comments