diff --git a/smdebug/core/config_constants.py b/smdebug/core/config_constants.py
index 4eabf447d..d8a212f6b 100644
--- a/smdebug/core/config_constants.py
+++ b/smdebug/core/config_constants.py
@@ -41,3 +41,5 @@ CALLABLE_CACHE_ENV_VAR = "SMDEBUG_KERAS_CALLABLE_CACHE_TYPE"
 DEFAULT_CALLABLE_CACHE = "CACHE_PER_MODE"
+
+DEFAULT_SAVED_COLLECTIONS = ["losses"]
diff --git a/smdebug/core/hook.py b/smdebug/core/hook.py
index 10577fed6..ef3171638 100644
--- a/smdebug/core/hook.py
+++ b/smdebug/core/hook.py
@@ -22,6 +22,7 @@
 )
 from smdebug.core.collection_manager import CollectionManager
 from smdebug.core.config_constants import (
+    DEFAULT_SAVED_COLLECTIONS,
     DEFAULT_WORKER_NAME,
     LATEST_GLOBAL_STEP_SAVED,
     LATEST_GLOBAL_STEP_SEEN,
@@ -343,6 +344,13 @@ def _get_collections_to_save_for_step(self) -> Set["Collection"]:
             )
         return self._collections_to_save_for_step

+    def is_tensor_saved_for_step(self, tensor_name):
+        collections_to_save = self._get_collections_to_save_for_step()
+        for c in collections_to_save:
+            if match_inc(tensor_name, c.include_regex):
+                return True
+        return False
+
     def _get_collections_with_tensor(self, tensor_name) -> Set["Collection"]:
         self._assert_prep()
         # for tf this will be prepopulated in check_and_add_tensor
@@ -364,6 +372,14 @@ def _get_collections_with_tensor(self, tensor_name) -> Set["Collection"]:
     def _get_default_collections(self):
         pass

+    def has_default_hook_configuration(self):
+        # Used in the internal framework forks to determine if the hook
+        # is using the default hook configuration
+        collections_being_saved = [x.name for x in self._collections_to_save]
+        if set(collections_being_saved) == set(DEFAULT_SAVED_COLLECTIONS):
+            return True
+        return False
+
     def _prepare_collections(self):
         """Populate collections_to_save and ensure every collection has
         a save_config and reduction_config."""
@@ -525,6 +541,13 @@ def _increment_step(self):
         self.mode_steps[ModeKeys.GLOBAL] = self.step
         self._collections_to_save_for_step = None

+    # Called in the internal AWS codebase to determine
+    # if a particular tensor value should be saved
+    def should_save_tensor_or_collection(self, tensor_name: str, collection_name: str) -> bool:
+        if self._is_collection_being_saved_for_step(collection_name):
+            return True
+        return self.is_tensor_saved_for_step(tensor_name)
+
     def _write_state(self):
         if self.state_store.is_checkpoint_updated():
             current_state = dict()
diff --git a/smdebug/tensorflow/base_hook.py b/smdebug/tensorflow/base_hook.py
index a8cc9f679..44eb92d66 100644
--- a/smdebug/tensorflow/base_hook.py
+++ b/smdebug/tensorflow/base_hook.py
@@ -20,6 +20,7 @@

 # Local
 from .collection import CollectionKeys, CollectionManager
+from .constants import TF_DEFAULT_SAVED_COLLECTIONS
 from .singleton_utils import set_hook
 from .utils import (
     TFDistributionStrategy,
@@ -217,6 +218,14 @@ def export_collections(self):
         collection_file_name = f"{self.worker}_collections.json"
         self.collection_manager.export(self.out_dir, collection_file_name)

+    def has_default_hook_configuration(self):
+        # Used in AWS TF to determine if the hook
+        # is using the default hook configuration
+        collections_being_saved = [x.name for x in self._collections_to_save]
+        if set(collections_being_saved) == set(TF_DEFAULT_SAVED_COLLECTIONS):
+            return True
+        return False
+
     def _get_custom_and_default_collections(self) -> Tuple[Set["Collection"], Set["Collection"]]:
         if self._custom_collections is None:
             self._custom_collections = set()
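The two helpers added above are meant to be called from outside smdebug rather than from within the hook itself. A minimal sketch of the intended call pattern from a framework fork, assuming the TF hook singleton; the gating logic and the "dense/kernel" tensor name are illustrative, not the fork's actual code:

    import smdebug.tensorflow as smd
    from smdebug.core.collection import CollectionKeys

    hook = smd.get_hook()
    if hook is not None and not hook.has_default_hook_configuration():
        # The user requested more than the default collections, so fall
        # back to the fine-grained per-tensor check before capturing.
        if hook.should_save_tensor_or_collection("dense/kernel", CollectionKeys.LAYERS):
            pass  # capture the value and hand it to the hook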
diff --git a/smdebug/tensorflow/constants.py b/smdebug/tensorflow/constants.py
index 4e52a0114..bde7e52ca 100644
--- a/smdebug/tensorflow/constants.py
+++ b/smdebug/tensorflow/constants.py
@@ -1,3 +1,5 @@
 SMDEBUG_GRADIENTS_KEY = "smdebug_gradients"
 SMDEBUG_LAYER_OUTPUTS_KEY = "smdebug_layer_outputs"
 SMDEBUG_PREFIX = "smdebug_"
+
+TF_DEFAULT_SAVED_COLLECTIONS = ["losses", "metrics", "sm_metrics"]
diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py
index ccea3ba47..7c0ee2c23 100644
--- a/smdebug/tensorflow/keras.py
+++ b/smdebug/tensorflow/keras.py
@@ -30,6 +30,7 @@
     get_model_input_export_name,
     get_model_output_export_name,
     is_keras_optimizer,
+    is_tf_version_2_3_x,
     is_tf_version_2x,
 )

@@ -71,6 +72,14 @@ def __init__(
         )  # stores tensors custom tensors saved by users every step
         self.saved_layers = dict()
         self.has_registered_model = False
+        # supports_tf_logs property was introduced in TF 2.3.0
+        # it indicates to the framework that the callback is not
+        # limited to reading only numpy logs
+        self._supports_tf_logs = True
+        # TF 2.3.0 has a callback ordering bug
+        # this flag indicates to the on_train_batch_begin callback
+        # that the step was already incremented in the on_train_begin callback
+        self.step_incremented_in_on_train_begin = False

     def _is_not_supported(self):
         if self.distribution_strategy is None:
@@ -109,7 +118,8 @@ def register_model(self, model):
         # It attaches a hook to every layer of the model to capture
         # layer values
         self.model = model
-        self._wrap_model_with_input_output_saver()
+        if self.tape is not None:
+            self._wrap_model_with_input_output_saver()
         self.has_registered_model = True

     def _get_matching_collections(
@@ -348,7 +358,10 @@ def _prepare_tensors_available_post_step(self):

             # Add tensor to custom collections
             for custom_coll in custom_collections:
-                if match_inc(tensor_ref.name, custom_coll.include_regex):
+                if (
+                    match_inc(tensor_ref.name, custom_coll.include_regex)
+                    and tensor_ref.tf_obj is not None
+                ):
                     custom_coll.add_for_mode(tensor_ref.tf_obj, self.mode)
                     if custom_coll not in self.tensor_to_collections[tensor_ref.name]:
                         self.tensor_to_collections[tensor_ref.name].add(custom_coll)
@@ -390,6 +403,12 @@ def _save_custom_tensors_post_step(self):
                 self._save_tensor_to_file(tensor_name, tensor_value, collection_names)
             self.custom_tensors_to_save.clear()

+    def should_save_layer(self, layer_name):
+        # Called in AWS TF to determine
+        # if a particular layer value
+        # should be saved
+        return self.should_save_tensor_or_collection(layer_name, CollectionKeys.LAYERS)
+
     def _save_tensor_to_file(self, tensor_name, tensor_value, collections):
         if isinstance(collections, set) is False:
             collections = {collections}
@@ -418,6 +437,31 @@ def _save_tensor_to_file(self, tensor_name, tensor_value, collections):
                 collection.set_tensor_ref(tensor_ref)
         self._save_for_tensor(tensor_name, t, check_before_write=True)

+    def save_gradients_from_logs(self, gradients):
+        if gradients is not None:
+            gradient_collection = self.get_collection(CollectionKeys.GRADIENTS)
+            step_collections = self._get_collections_to_save_for_step()
+            collections_to_write = (
+                {gradient_collection} if gradient_collection in step_collections else set()
+            )
+            if gradients and isinstance(gradients[0], tuple) is False:
+                gradients = zip(self.model.trainable_variables, gradients)
+            for v, g in gradients:
+                if isinstance(v, tf.Tensor):
+                    # Tensor.name is meaningless with eager execution
+                    layer_name = str(v.numpy(), "utf-8")
+                elif isinstance(v, tf.Variable):
+                    layer_name = v.name
+                else:
+                    layer_name = v
+                layer_name = layer_name.split(":")[0]
+                export_name = "gradients/" + layer_name + "Grad"
+                if isinstance(g, IndexedSlices):
+                    # This class is a simple wrapper for a pair of Tensor objects
+                    # See: https://www.tensorflow.org/api_docs/python/tf/IndexedSlices
+                    g = g.values
+                self._save_tensor_to_file(export_name, g, collections_to_write)
+
     def save_smdebug_logs(self, logs):
         if logs is None:
             return
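save_gradients_from_logs accepts either a bare list of gradients, which it zips against model.trainable_variables, or an already-paired list of (variable, gradient) tuples. A sketch of how gradients reach this path from a custom train_step via the Keras logs dict, mirroring the test fixture updated at the end of this diff; the model wiring is illustrative:

    import tensorflow as tf
    from smdebug.tensorflow.constants import SMDEBUG_PREFIX

    class CustomModel(tf.keras.Model):
        def train_step(self, data):
            x, y = data
            with tf.GradientTape() as tape:
                y_pred = self(x, training=True)
                loss = self.compiled_loss(y, y_pred)
            gradients = tape.gradient(loss, self.trainable_variables)
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            self.compiled_metrics.update_state(y, y_pred)
            result_dict = {m.name: m.result() for m in self.metrics}
            # "smdebug_gradients" == SMDEBUG_GRADIENTS_KEY; save_smdebug_logs
            # routes this entry to save_gradients_from_logs
            result_dict.update({f"{SMDEBUG_PREFIX}gradients": gradients})
            return result_dict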
@@ -437,24 +481,10 @@
                 )
             # Save Gradients
             elif key == SMDEBUG_GRADIENTS_KEY:
-                gradients = logs[key]
-                if gradients is not None:
-                    for g, v in zip(gradients, self.model.trainable_variables):
-                        layer_name = v.name
-                        if len(layer_name.split(":")) > 1:
-                            layer_name = layer_name.split(":")[0]
-                        export_name = "gradients/" + layer_name + "Grad"
-                        if isinstance(g, IndexedSlices):
-                            # This class is a simple wrapper for a pair of Tensor objects
-                            # See: https://www.tensorflow.org/api_docs/python/tf/IndexedSlices
-                            g = g.values
-                        tensors_to_save.append((export_name, g))
-                    collections_to_write = {self.get_collection(CollectionKeys.GRADIENTS)}
+                self.save_gradients_from_logs(logs[key])
             # Save Intermediate Layers
             elif key == SMDEBUG_LAYER_OUTPUTS_KEY:
-                layer_outputs = logs[key]
-                self.save_layer_outputs(layer_outputs)
-                self.save_layer_inputs(logs[ModelInput.INPUTS], layer_outputs)
+                self._save_layer_values(logs[key])
             # Save Model Inputs
             elif key in ModelInputs:
                 export_name = get_model_input_export_name()
@@ -489,10 +519,9 @@ def _save_metrics(self, batch, logs, force_save=False):
                     self._add_metric(metric_name=key)
                     self._save_for_tensor(key, logs[key], check_before_write=False)

-    def _save_layer_input_and_outputs(self, grad_tape=False):
-        # Iterates over all the saved layers for input and output values
-        if is_tf_version_2x() is False or (grad_tape is False and self.model.run_eagerly is False):
-            # This function only works when the run_eagerly is True
+    def _save_layer_input_and_outputs(self):
+        # Run only for GradTape
+        if self.tape is None:
             return
         for layer_name in self.saved_layers:
             # Save Input
@@ -520,7 +549,6 @@ def _save_tensors_post_step(self, batch, logs):
         # weights, metrics
         self._save_metrics(batch, logs)
         self.save_smdebug_logs(logs)
-        self._save_layer_input_and_outputs()
         self._save_custom_tensors_post_step()

         if is_tf_version_2x() and tf.executing_eagerly():
@@ -615,6 +643,13 @@ def _on_any_mode_begin(self, mode):
             self.graph = tf.get_default_graph()
         self.set_mode(mode)

+        if self.prepared_collections is False and is_tf_version_2_3_x():
+            # Addresses ordering issues in TF 2.3.0
+            # sets prepared_collections to True here
+            self._prepare_collections()
+            self._increment_step()
+            self.step_incremented_in_on_train_begin = True
+
         # have to clear callable cache if we are not caching per mode
         self.callable_cache.change_mode()
@@ -658,7 +693,12 @@ def _on_any_batch_begin(self, batch, mode, logs=None):
         # Write the gradients of the past step if the writer is still available.
         if self.writer is not None or len(self.writer_map):
             self._close_writers()
-        self._increment_step()
+
+        # Addresses callback ordering bug in TF 2.3.0
+        if self.step_incremented_in_on_train_begin is False:
+            self._increment_step()
+        else:
+            self.step_incremented_in_on_train_begin = False

         if self.prepared_collections is False:
             # sets prepared_collections to True here
@@ -668,7 +708,6 @@ def _on_any_batch_begin(self, batch, mode, logs=None):
         if (is_tf_version_2x() and tf.executing_eagerly()) or self._validate_exec_function(
             self._get_exec_function(mode)
         ):
-            self._wrap_model_with_input_output_saver()
             self._prepare_layers(mode)
             self._prepare_tensors_available_post_step()
             self._prepared_tensors[mode] = True
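Taken together, the hunks above keep the step counter from advancing twice for the first batch on TF 2.3.x: _on_any_mode_begin increments early, and _on_any_batch_begin consumes the flag instead of incrementing again. A distilled, standalone sketch of the pattern, with plain counters standing in for the hook state:

    class StepCounter:
        def __init__(self, tf_2_3_workaround):
            self.step = -1
            self.step_incremented_in_on_train_begin = False
            self.tf_2_3_workaround = tf_2_3_workaround  # is_tf_version_2_3_x()

        def on_train_begin(self):
            if self.tf_2_3_workaround:
                self.step += 1  # the early _increment_step()
                self.step_incremented_in_on_train_begin = True

        def on_train_batch_begin(self):
            if not self.step_incremented_in_on_train_begin:
                self.step += 1
            else:
                # consume the early increment exactly once
                self.step_incremented_in_on_train_begin = False

    counter = StepCounter(tf_2_3_workaround=True)
    counter.on_train_begin()
    for _ in range(3):
        counter.on_train_batch_begin()
    assert counter.step == 2  # three batches -> steps 0, 1, 2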
@@ -698,33 +737,23 @@ def on_test_batch_begin(self, batch, logs=None):

     def on_predict_batch_begin(self, batch, logs=None):
         self._on_any_batch_begin(batch, ModeKeys.PREDICT, logs=logs)

-    def _save_layer_values(self, layer_outputs, collection, model=None, inputs=None):
-        if model is None:
-            if self.model:
-                model = self.model
-            else:
-                return
-        if layer_outputs is not None:
-            tensors_to_save = []
-            step_collections = self._get_collections_to_save_for_step()
-            collections_to_write = {collection} if collection in step_collections else set()
-            tensor_suffix = "output"
-            if inputs is not None:
-                layer_outputs = [inputs] + layer_outputs
-                tensor_suffix = "input"
-            for o, l in zip(layer_outputs, model.layers):
-                export_name = get_export_name_for_keras(l.name, tensor_suffix)
-                tensors_to_save.append((export_name, o))
-            for t_name, t_value in tensors_to_save:
-                self._save_tensor_to_file(t_name, t_value, collections_to_write)
-
-    def save_layer_outputs(self, layer_outputs, model=None):
-        self._save_layer_values(layer_outputs, self.get_collection(CollectionKeys.LAYERS), model)
-
-    def save_layer_inputs(self, x, layer_outputs, model=None):
-        self._save_layer_values(
-            layer_outputs, self.get_collection(CollectionKeys.LAYERS), model, inputs=x
-        )
+    def _save_layer_values(self, logs):
+        if logs is None:
+            return
+        step_collections = self._get_collections_to_save_for_step()
+        layer_collection = self.get_collection(CollectionKeys.LAYERS)
+        collections_to_write = {layer_collection} if layer_collection in step_collections else set()
+        for layer_name, layer_input, layer_output in logs:
+            # Cast layer_name to str since it can also be of type bytes
+            # when run with mirrored strategy
+            if len(layer_input) == 1:
+                # Layer Inputs are flattened and passed as a list into
+                # the next layer. Unpacking it speeds up the _make_numpy fn.
+                layer_input = layer_input[0]
+            layer_input_tensor_name = get_export_name_for_keras(str(layer_name), "input")
+            self._save_tensor_to_file(layer_input_tensor_name, layer_input, collections_to_write)
+            layer_output_tensor_name = get_export_name_for_keras(str(layer_name), "output")
+            self._save_tensor_to_file(layer_output_tensor_name, layer_output, collections_to_write)

     def _write_optimizer_variables(self):
         optimizer_collections = self.collection_manager.get(CollectionKeys.OPTIMIZER_VARIABLES)
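The rewritten _save_layer_values no longer zips outputs positionally against model.layers; it consumes explicit (layer_name, layer_input, layer_output) triples. A sketch of the payload an internal fork might pass under SMDEBUG_LAYER_OUTPUTS_KEY, with illustrative tensor values:

    import tensorflow as tf
    from smdebug.tensorflow.constants import SMDEBUG_LAYER_OUTPUTS_KEY

    # layer_name may arrive as bytes under MirroredStrategy (hence the
    # str(layer_name) casts above); layer_input is a list of input
    # tensors, unwrapped when it holds a single element.
    layer_values = [
        ("dense", [tf.ones((32, 784))], tf.ones((32, 128))),
        ("dense_1", [tf.ones((32, 128))], tf.ones((32, 10))),
    ]
    logs = {SMDEBUG_LAYER_OUTPUTS_KEY: layer_values}
    # hook.save_smdebug_logs(logs) routes this entry to _save_layer_values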
@@ -951,7 +980,7 @@ def run(*args, **kwargs):
                 )
                 self._write_optimizer_variables()
-                self._save_layer_input_and_outputs(grad_tape=True)
+                self._save_layer_input_and_outputs()
                 if not ((isinstance(loss, tf.Tensor)) and hasattr(loss, "numpy")):
                     return grads
                 self._add_metric(metric_name="loss", metric_value=loss)
diff --git a/smdebug/tensorflow/utils.py b/smdebug/tensorflow/utils.py
index 34a713629..90dfa5d29 100644
--- a/smdebug/tensorflow/utils.py
+++ b/smdebug/tensorflow/utils.py
@@ -384,3 +384,7 @@ def get_keras_mode(mode):

 def is_tf_version_2x():
     return version.parse(tf.__version__) >= version.parse("2.0.0")
+
+
+def is_tf_version_2_3_x():
+    return version.parse(tf.__version__) >= version.parse("2.3.0")
diff --git a/tests/tensorflow2/test_keras.py b/tests/tensorflow2/test_keras.py
index 089cc2d9e..ef4457336 100644
--- a/tests/tensorflow2/test_keras.py
+++ b/tests/tensorflow2/test_keras.py
@@ -520,11 +520,7 @@ def test_include_regex(out_dir, tf_eager_mode):

     tr = create_trial_fast_refresh(out_dir)
     tnames = tr.tensor_names(collection="custom_coll")
-
-    if tf_eager_mode:
-        assert len(tnames) == (12 if is_tf_2_2() else 8)
-    else:
-        assert len(tnames) == 8
+    assert len(tnames) == 12
     for tname in tnames:
         assert tr.tensor(tname).value(0) is not None
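Note that both version helpers are floor checks: despite its name, is_tf_version_2_3_x returns True on TF 2.4 and later as well, just as the updated is_tf_2_2 now matches any release from 2.2.0 up. If an exact-minor check were ever needed, a hypothetical variant (not part of this change) could compare the release tuple instead:

    import tensorflow as tf
    from packaging import version

    def is_exactly_tf_2_3():
        # Version.release is the numeric tuple, e.g. "2.3.1" -> (2, 3, 1)
        return version.parse(tf.__version__).release[:2] == (2, 3)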
diff --git a/tests/tensorflow2/test_should_save_tensor.py b/tests/tensorflow2/test_should_save_tensor.py
new file mode 100644
index 000000000..6f09da2f1
--- /dev/null
+++ b/tests/tensorflow2/test_should_save_tensor.py
@@ -0,0 +1,59 @@
+# Third Party
+import tensorflow as tf
+
+# First Party
+import smdebug.tensorflow as smd
+from smdebug.core.collection import CollectionKeys
+from smdebug.tensorflow import SaveConfig
+from smdebug.tensorflow.constants import TF_DEFAULT_SAVED_COLLECTIONS
+
+model = tf.keras.models.Sequential(
+    [
+        tf.keras.layers.Flatten(input_shape=(28, 28)),
+        tf.keras.layers.Dense(128, activation="relu"),
+        tf.keras.layers.Dropout(0.2),
+        tf.keras.layers.Dense(10, activation="softmax"),
+    ]
+)
+
+
+def helper_create_hook(out_dir, collections, include_regex=None):
+    hook = smd.KerasHook(
+        out_dir, save_config=SaveConfig(save_interval=3), include_collections=collections
+    )
+
+    if include_regex:
+        for collection in collections:
+            hook.get_collection(collection).include(include_regex)
+
+    hook.register_model(model)
+    hook.on_train_begin()
+    return hook
+
+
+def test_should_save_tensor_with_default_collections(out_dir):
+    hook = helper_create_hook(out_dir, TF_DEFAULT_SAVED_COLLECTIONS)
+    for layer in model.layers:
+        layer_name = layer.name
+        assert not hook.should_save_tensor_or_collection(layer_name, CollectionKeys.GRADIENTS)
+        assert not hook.should_save_tensor_or_collection(layer_name, CollectionKeys.LAYERS)
+
+
+def test_should_save_tensor_with_tf_collection(out_dir):
+    hook = helper_create_hook(out_dir, [CollectionKeys.GRADIENTS, CollectionKeys.LAYERS])
+    for layer in model.layers:
+        layer_name = layer.name
+        assert hook.should_save_tensor_or_collection(layer_name, CollectionKeys.GRADIENTS)
+        assert hook.should_save_tensor_or_collection(layer_name, CollectionKeys.LAYERS)
+
+
+def test_should_save_tensor_with_custom_collection(out_dir):
+    hook = helper_create_hook(out_dir, ["custom_coll"], include_regex="dense")
+    for layer in model.layers:
+        layer_name = layer.name
+        if "dense" in layer_name:
+            assert hook.should_save_tensor_or_collection(layer_name, CollectionKeys.GRADIENTS)
+            assert hook.should_save_tensor_or_collection(layer_name, CollectionKeys.LAYERS)
+        else:
+            assert not hook.should_save_tensor_or_collection(layer_name, CollectionKeys.GRADIENTS)
+            assert not hook.should_save_tensor_or_collection(layer_name, CollectionKeys.LAYERS)
diff --git a/tests/tensorflow2/utils.py b/tests/tensorflow2/utils.py
index 591504d34..0cfec9217 100644
--- a/tests/tensorflow2/utils.py
+++ b/tests/tensorflow2/utils.py
@@ -12,7 +12,7 @@ def is_tf_2_2():
     number of tensor_names emitted by 1.
     :return: bool
     """
-    if version.parse(tf.__version__) == version.parse("2.2.0"):
+    if version.parse(tf.__version__) >= version.parse("2.2.0"):
        return True
    return False
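These tests call should_save_tensor_or_collection directly; for layer values the Keras hook also exposes the thinner should_save_layer wrapper added earlier in this diff. A sketch of how a per-layer capture could be gated on it; the wrapper function itself is illustrative, not the fork's actual code:

    import smdebug.tensorflow as smd

    hook = smd.get_hook()

    def maybe_capture(layer, layer_input, layer_output, captured):
        # Skip the capture when neither the "layers" collection nor a
        # matching include_regex wants this layer on the current step.
        if hook is not None and hook.should_save_layer(layer.name):
            captured.append((layer.name, layer_input, layer_output))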
diff --git a/tests/zero_code_change/test_mxnet_gluon_integration.py b/tests/zero_code_change/test_mxnet_gluon_integration.py
index e14e72cdf..07a18ba86 100644
--- a/tests/zero_code_change/test_mxnet_gluon_integration.py
+++ b/tests/zero_code_change/test_mxnet_gluon_integration.py
@@ -117,6 +117,9 @@ def validate():
     from smdebug.mxnet import get_hook

     hook = get_hook()
+    # Check if the hook was executed with the default
+    # hook configuration
+    assert hook.has_default_hook_configuration()
     out_dir = hook.out_dir
     print("Created the trial with out_dir {0}".format(out_dir))
     tr = create_trial(out_dir)
diff --git a/tests/zero_code_change/test_pytorch_integration.py b/tests/zero_code_change/test_pytorch_integration.py
index 63bef2f97..eb6d06536 100644
--- a/tests/zero_code_change/test_pytorch_integration.py
+++ b/tests/zero_code_change/test_pytorch_integration.py
@@ -64,6 +64,9 @@ def test_pytorch(script_mode, use_loss_module):

         hook = smd.get_hook()
         print(f"hook = {hook}")
+        # Check if the hook was executed with the default
+        # hook configuration
+        assert hook.has_default_hook_configuration()

         from smdebug.trials import create_trial
diff --git a/tests/zero_code_change/test_tensorflow2_gradtape_integration.py b/tests/zero_code_change/test_tensorflow2_gradtape_integration.py
index d44b851a0..414ec4587 100644
--- a/tests/zero_code_change/test_tensorflow2_gradtape_integration.py
+++ b/tests/zero_code_change/test_tensorflow2_gradtape_integration.py
@@ -78,6 +78,8 @@ def helper_test_keras_v2_gradienttape(
             train_acc_metric.reset_states()
         hook = smd.get_hook()
         assert hook
+        if default:
+            assert hook.has_default_hook_configuration()
         hook.close()
         # Check that hook created and tensors saved
         trial = smd.create_trial(path=sim.out_dir)
diff --git a/tests/zero_code_change/test_tensorflow2_integration.py b/tests/zero_code_change/test_tensorflow2_integration.py
index bc0a4d13e..9fddbdc32 100644
--- a/tests/zero_code_change/test_tensorflow2_integration.py
+++ b/tests/zero_code_change/test_tensorflow2_integration.py
@@ -46,7 +46,7 @@ def train_step(self, data):
         self.compiled_metrics.update_state(y, y_pred, sample_weight)
         result_dict = {m.name: m.result() for m in self.metrics}
         result_dict.update({f"{SMDEBUG_PREFIX}y": y})
-        result_dict.update({f"{SMDEBUG_PREFIX}gradients": y})
+        result_dict.update({f"{SMDEBUG_PREFIX}gradients": gradients})

         # to pass gradients and labels to the hook, add logs with the prefix SMDEBUG_
         # For examples:
@@ -117,6 +117,9 @@ def helper_test_keras_v2(script_mode: bool = False, eager_mode: bool = True):

         hook = smd.get_hook()
         assert hook
+        # Check if the hook was executed with the default
+        # hook configuration
+        assert hook.has_default_hook_configuration()
         hook.close()
         # Check that hook created and tensors saved
         trial = smd.create_trial(path=sim.out_dir)
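For completeness, a sketch of the end-to-end verification these integration tests build on: assert the default configuration on the live hook, then confirm via a trial that the default collections actually produced tensors. The losses check is illustrative; exact tensor counts depend on the model:

    import smdebug.tensorflow as smd
    from smdebug.trials import create_trial

    hook = smd.get_hook()
    assert hook is not None and hook.has_default_hook_configuration()
    hook.close()

    trial = create_trial(hook.out_dir)
    # The default TF configuration saves losses, metrics and sm_metrics only
    assert len(trial.tensor_names(collection="losses")) > 0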