From 27cbc8cee8259e6d8e23a0ba378a614819cab031 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Thu, 16 Jul 2020 00:53:20 -0700 Subject: [PATCH 01/11] save custom tensors --- .../scripts/tf_save_metrics_gradient_tape.py | 2 +- smdebug/core/hook.py | 16 +++++++++ smdebug/mxnet/hook.py | 1 + smdebug/pytorch/hook.py | 1 + smdebug/tensorflow/keras.py | 20 ++++++----- tests/mxnet/mnist_gluon_model.py | 3 ++ tests/mxnet/test_custom_tensor.py | 30 ++++++++++++++++ tests/pytorch/test_save_custom_tensor.py | 35 +++++++++++++++++++ tests/pytorch/utils.py | 4 ++- tests/tensorflow2/test_keras.py | 6 ++-- 10 files changed, 104 insertions(+), 14 deletions(-) create mode 100644 tests/mxnet/test_custom_tensor.py create mode 100644 tests/pytorch/test_save_custom_tensor.py diff --git a/examples/tensorflow2/scripts/tf_save_metrics_gradient_tape.py b/examples/tensorflow2/scripts/tf_save_metrics_gradient_tape.py index db5284373..08bcf0b53 100644 --- a/examples/tensorflow2/scripts/tf_save_metrics_gradient_tape.py +++ b/examples/tensorflow2/scripts/tf_save_metrics_gradient_tape.py @@ -87,7 +87,7 @@ def helper_keras_gradtape( with hook.wrap_tape(tf.GradientTape(persistent=persistent)) as tape: logits = model(data, training=True) loss_value = cce(labels, logits) - hook.save_custom_tensor("y_labels", labels, "outputs") + hook.save_tensor("y_labels", labels, "outputs") grads = tape.gradient(loss_value, model.variables) # By default, the resources held by a GradientTape are released as diff --git a/smdebug/core/hook.py b/smdebug/core/hook.py index 7e70404da..2dfff1f22 100644 --- a/smdebug/core/hook.py +++ b/smdebug/core/hook.py @@ -880,6 +880,7 @@ def __init__( ) self.exported_collections = False self.data_type_name = data_type_name + self.custom_tensors_to_save = dict() def _cleanup(self): if not self.exported_collections: @@ -905,6 +906,21 @@ def _write(self, module_name, var, suffix, idx): ) return idx + def save_tensor(self, tensor_name, tensor_value, collections_to_write=None): + if collections_to_write is None: + collections_to_write = CollectionKeys.DEFAULT + if isinstance(collections_to_write, str): + collections_to_write = [collections_to_write] + for collection in collections_to_write: + self.custom_tensors_to_save[tensor_name] = (tensor_value, collection) + + def _save_custom_tensors_post_step(self): + for tensor_name in self.custom_tensors_to_save: + tensor_value, collection_names = self.custom_tensors_to_save[tensor_name] + c = self.collection_manager.get(collection_names, create=True) + c.add_tensor_name(tensor_name) + self._write_raw_tensor(tensor_name, tensor_value, [c]) + def _write_inputs(self, name, inputs): tensor_name = name + CallbackHook.INPUT_TENSOR_SUFFIX idx = self.written_tensor_name_for_step.get(tensor_name, 0) diff --git a/smdebug/mxnet/hook.py b/smdebug/mxnet/hook.py index cacdfb98c..aa1fbeca7 100644 --- a/smdebug/mxnet/hook.py +++ b/smdebug/mxnet/hook.py @@ -169,6 +169,7 @@ def forward_hook(self, block, inputs, outputs): # Output output tensors self._write_outputs(block_name, outputs) + self._save_custom_tensors_post_step() self.last_saved_step = self.step diff --git a/smdebug/pytorch/hook.py b/smdebug/pytorch/hook.py index fe8834fda..790473b55 100644 --- a/smdebug/pytorch/hook.py +++ b/smdebug/pytorch/hook.py @@ -162,6 +162,7 @@ def forward_hook(self, module, inputs, outputs): # Output output tensors self._write_outputs(module_name, outputs) + self._save_custom_tensors_post_step() self.last_saved_step = self.step def backward_hook(self, tname): diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py index cd53a93e7..21b2b324d 100644 --- a/smdebug/tensorflow/keras.py +++ b/smdebug/tensorflow/keras.py @@ -385,7 +385,9 @@ def _add_metric(self, metric_name, metric_value: tf.Tensor = None): coll.set_tensor_ref(TensorRef.from_non_graph_var(metric_name)) self.tensor_to_collections[metric_name] = {coll} - def save_custom_tensor(self, tensor_name, tensor_value, collections_to_write): + def save_tensor(self, tensor_name, tensor_value, collections_to_write=None): + if collections_to_write is None: + collections_to_write = CollectionKeys.DEFAULT if isinstance(collections_to_write, str): collections_to_write = [collections_to_write] for collection in collections_to_write: @@ -394,9 +396,9 @@ def save_custom_tensor(self, tensor_name, tensor_value, collections_to_write): def _save_custom_tensors_post_step(self): for tensor_name in self.custom_tensors_to_save: tensor_value, collection_names = self.custom_tensors_to_save[tensor_name] - self._save_tensor(tensor_name, tensor_value, collection_names) + self._save_tensor_to_file(tensor_name, tensor_value, collection_names) - def _save_tensor(self, tensor_name, tensor_value, collections_to_write): + def _save_tensor_to_file(self, tensor_name, tensor_value, collections_to_write): if isinstance(collections_to_write, set) is False: collections_to_write = {collections_to_write} # Since this function modifies the set, there is a possibility @@ -442,7 +444,7 @@ def save_smdebug_logs(self, logs): else set() ) for t_name, t_value in tensors_to_save: - self._save_tensor(t_name, t_value, collections_to_write) + self._save_tensor_to_file(t_name, t_value, collections_to_write) elif key == SMDEBUG_GRADIENTS_KEY: tensors_to_save = [] gradients = logs[key] @@ -457,7 +459,7 @@ def save_smdebug_logs(self, logs): tensors_to_save.append((export_name, g)) collections_to_write = {self.get_collection(CollectionKeys.GRADIENTS)} for t_name, t_value in tensors_to_save: - self._save_tensor(t_name, t_value, collections_to_write) + self._save_tensor_to_file(t_name, t_value, collections_to_write) elif key == SMDEBUG_LAYER_OUTPUTS_KEY: layer_outputs = logs[key] self.save_layer_outputs(layer_outputs) @@ -473,7 +475,7 @@ def save_smdebug_logs(self, logs): else set() ) for t_name, t_value in tensors_to_save: - self._save_tensor(t_name, t_value, collections_to_write) + self._save_tensor_to_file(t_name, t_value, collections_to_write) def _save_metrics(self, batch, logs, force_save=False): # if force_save is True, doesn't check whether collection needs to be saved for steps @@ -510,7 +512,7 @@ def _save_layer_input_and_outputs(self, grad_tape=False): if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS) else set() ) - self._save_tensor(export_name, tensor.numpy(), input_collection) + self._save_tensor_to_file(export_name, tensor.numpy(), input_collection) # Save Output tensor = self.saved_layers[layer_name].layer_output export_name = get_export_name_for_keras(layer_name, tensor_type="output", tensor=tensor) @@ -520,7 +522,7 @@ def _save_layer_input_and_outputs(self, grad_tape=False): if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS) else set() ) - self._save_tensor(export_name, tensor.numpy(), output_collection) + self._save_tensor_to_file(export_name, tensor.numpy(), output_collection) def _save_tensors_post_step(self, batch, logs): # some tensors available as value from within hook are saved here @@ -723,7 +725,7 @@ def _save_layer_values(self, layer_outputs, collection, model=None, inputs=None) export_name = get_export_name_for_keras(l.name, tensor_suffix) tensors_to_save.append((export_name, o)) for t_name, t_value in tensors_to_save: - self._save_tensor(t_name, t_value, collections_to_write) + self._save_tensor_to_file(t_name, t_value, collections_to_write) def save_layer_outputs(self, layer_outputs, model=None): self._save_layer_values(layer_outputs, self.get_collection(CollectionKeys.LAYERS), model) diff --git a/tests/mxnet/mnist_gluon_model.py b/tests/mxnet/mnist_gluon_model.py index daeba9b6e..7d1a42d60 100644 --- a/tests/mxnet/mnist_gluon_model.py +++ b/tests/mxnet/mnist_gluon_model.py @@ -28,6 +28,7 @@ def run_mnist_gluon_model( make_input_zero=False, normalize_mean=0.13, normalize_std=0.31, + save_custom_tensor=False, ): batch_size = 4 if make_input_zero: @@ -103,6 +104,8 @@ def run_mnist_gluon_model( eval_acc_name = "loss_acc" # Start the training. + if save_custom_tensor: + hook.save_tensor("custom_tensor", mx.nd.array([1, 2, 3])) for epoch in range(1): train_loss, train_acc, valid_acc = 0.0, 0.0, 0.0 tic = time.time() diff --git a/tests/mxnet/test_custom_tensor.py b/tests/mxnet/test_custom_tensor.py new file mode 100644 index 000000000..649aa757e --- /dev/null +++ b/tests/mxnet/test_custom_tensor.py @@ -0,0 +1,30 @@ +# Standard Library +import shutil +from datetime import datetime + +# First Party +from smdebug import SaveConfig +from smdebug.core.collection import CollectionKeys +from smdebug.mxnet.hook import Hook as t_hook +from smdebug.trials import create_trial + +# Local +from .mnist_gluon_model import run_mnist_gluon_model + + +def test_hook(): + save_config = SaveConfig(save_steps=[0, 1, 2, 3]) + run_id = "trial_" + datetime.now().strftime("%Y%m%d-%H%M%S%f") + out_dir = "/tmp/newlogsRunTest/" + run_id + hook = t_hook(out_dir=out_dir, save_config=save_config) + run_mnist_gluon_model( + hook=hook, + num_steps_train=10, + num_steps_eval=10, + register_to_loss_block=True, + save_custom_tensor=True, + ) + trial = create_trial(out_dir) + custom_tensors = trial.tensor_names(collection=CollectionKeys.DEFAULT) + assert len(custom_tensors) + shutil.rmtree(out_dir) diff --git a/tests/pytorch/test_save_custom_tensor.py b/tests/pytorch/test_save_custom_tensor.py new file mode 100644 index 000000000..539b20761 --- /dev/null +++ b/tests/pytorch/test_save_custom_tensor.py @@ -0,0 +1,35 @@ +# Standard Library +import shutil +from datetime import datetime + +# Third Party +import torch +import torch.optim as optim + +# First Party +from smdebug.core.collection import CollectionKeys +from smdebug.pytorch import SaveConfig +from smdebug.pytorch.hook import Hook as t_hook +from smdebug.trials import create_trial + +# Local +from .utils import Net, train + + +def test_hook(): + run_id = "trial_" + datetime.now().strftime("%Y%m%d-%H%M%S%f") + out_dir = "/tmp/" + run_id + hook = t_hook( + out_dir=out_dir, + save_config=SaveConfig(save_steps=[0, 1, 2, 3]), + include_collections=["relu_activations"], + ) + + model = Net().to(torch.device("cpu")) + hook.register_module(model) + optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) + train(model, hook, torch.device("cpu"), optimizer, num_steps=10, save_custom_tensor=True) + trial = create_trial(out_dir) + custom_tensors = trial.tensor_names(collection=CollectionKeys.DEFAULT) + assert len(custom_tensors) + shutil.rmtree(out_dir) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 5ac1d399b..cec6c87f7 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -35,12 +35,14 @@ def forward(self, x): return F.log_softmax(x, dim=1) -def train(model, hook, device, optimizer, num_steps=500, set_modes=False): +def train(model, hook, device, optimizer, num_steps=500, set_modes=False, save_custom_tensor=False): if set_modes: hook.set_mode(modes.TRAIN) model.train() # for batch_idx, (data, target) in enumerate(train_loader): + if save_custom_tensor: + hook.save_tensor("custom_tensor", torch.tensor([[1.0, -1.0], [1.0, -1.0]])) for i in range(num_steps): batch_size = 32 data, target = torch.rand(batch_size, 1, 28, 28), torch.rand(batch_size).long() diff --git a/tests/tensorflow2/test_keras.py b/tests/tensorflow2/test_keras.py index ae9a30c50..c353dfd03 100644 --- a/tests/tensorflow2/test_keras.py +++ b/tests/tensorflow2/test_keras.py @@ -714,9 +714,9 @@ def test_save_custom_tensors(out_dir, tf_eager_mode): t1 = tf.constant([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]) t2 = tf.Variable([5 + 4j, 6 + 1j]) t3 = tf.Variable([False, False, False, True]) - hook.save_custom_tensor("custom_tensor_1", t1, include_collections) - hook.save_custom_tensor("custom_tensor_2", t2, include_collections) - hook.save_custom_tensor("custom_tensor_3", t3, include_collections) + hook.save_tensor("custom_tensor_1", t1, include_collections) + hook.save_tensor("custom_tensor_2", t2, include_collections) + hook.save_tensor("custom_tensor_3", t3, include_collections) helper_keras_fit( trial_dir=out_dir, From 7db230dbd12bbd452783b9b6656a337009e13988 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Wed, 22 Jul 2020 21:10:18 -0700 Subject: [PATCH 02/11] PR comments --- smdebug/tensorflow/keras.py | 4 +--- tests/mxnet/test_hook_all_zero.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py index 21b2b324d..b90f1d8a9 100644 --- a/smdebug/tensorflow/keras.py +++ b/smdebug/tensorflow/keras.py @@ -385,9 +385,7 @@ def _add_metric(self, metric_name, metric_value: tf.Tensor = None): coll.set_tensor_ref(TensorRef.from_non_graph_var(metric_name)) self.tensor_to_collections[metric_name] = {coll} - def save_tensor(self, tensor_name, tensor_value, collections_to_write=None): - if collections_to_write is None: - collections_to_write = CollectionKeys.DEFAULT + def save_tensor(self, tensor_name, tensor_value, collections_to_write="default"): if isinstance(collections_to_write, str): collections_to_write = [collections_to_write] for collection in collections_to_write: diff --git a/tests/mxnet/test_hook_all_zero.py b/tests/mxnet/test_hook_all_zero.py index 1d6c0b00a..c9c693d80 100644 --- a/tests/mxnet/test_hook_all_zero.py +++ b/tests/mxnet/test_hook_all_zero.py @@ -36,8 +36,7 @@ def test_hook_all_zero(hook=None, out_dir=None): assert tr assert len(tr.steps()) == 4 - tnames = tr.tensor_names(regex="conv._input") - tname = tr.tensor_names(regex="conv._input")[0] + tname = tr.tensor_names(regex="conv.+_input")[0] conv_tensor_value = tr.tensor(tname).value(step_num=0) is_zero = np.all(conv_tensor_value == 0) assert is_zero == True From 22fe114fcc1986c3baf781ac4df9e585bb7a8c38 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Wed, 22 Jul 2020 23:04:57 -0700 Subject: [PATCH 03/11] pr comments --- smdebug/core/hook.py | 12 +++++++++--- smdebug/core/utils.py | 8 ++++++++ smdebug/mxnet/utils.py | 2 ++ smdebug/pytorch/utils.py | 2 ++ 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/smdebug/core/hook.py b/smdebug/core/hook.py index 2dfff1f22..62d0fd11b 100644 --- a/smdebug/core/hook.py +++ b/smdebug/core/hook.py @@ -44,6 +44,7 @@ match_inc, remove_claim_file, size_and_shape, + validate_custom_tensor_value, ) from smdebug.core.writer import FileWriter from smdebug.exceptions import InvalidCollectionConfiguration @@ -906,9 +907,10 @@ def _write(self, module_name, var, suffix, idx): ) return idx - def save_tensor(self, tensor_name, tensor_value, collections_to_write=None): - if collections_to_write is None: - collections_to_write = CollectionKeys.DEFAULT + def save_tensor(self, tensor_name, tensor_value, collections_to_write=CollectionKeys.DEFAULT): + if validate_custom_tensor_value(tensor_value, self._make_numpy_array) is False: + self.logger.warn("The tensor value could not be converted into a numpy value") + return if isinstance(collections_to_write, str): collections_to_write = [collections_to_write] for collection in collections_to_write: @@ -938,3 +940,7 @@ def _write_outputs(self, name, outputs): @abstractmethod def _export_model(self): pass + + @staticmethod + def _make_numpy_array(tensor_value): + pass diff --git a/smdebug/core/utils.py b/smdebug/core/utils.py index 02ca881c2..af871c527 100644 --- a/smdebug/core/utils.py +++ b/smdebug/core/utils.py @@ -297,6 +297,14 @@ def remove_file_if_exists(file_path): os.remove(file_path) +def validate_custom_tensor_value(tensor_value, make_numpy_fn): + try: + make_numpy_fn(tensor_value) + except TypeError: + return False + return True + + class SagemakerSimulator(object): """ Creates an environment variable pointing to a JSON config file, and creates the config file. diff --git a/smdebug/mxnet/utils.py b/smdebug/mxnet/utils.py index ab27fe1ad..aa228145e 100644 --- a/smdebug/mxnet/utils.py +++ b/smdebug/mxnet/utils.py @@ -54,6 +54,8 @@ def make_numpy_array(x): elif isinstance(x, tuple): # todo: fix this, will crash return np.asarray(x, dtype=x.dtype) + elif isinstance(x, list): + return np.asarray(x) else: raise TypeError( "_make_numpy_array only accepts input types of numpy.ndarray, scalar," diff --git a/smdebug/pytorch/utils.py b/smdebug/pytorch/utils.py index 95359257c..ea0caf949 100644 --- a/smdebug/pytorch/utils.py +++ b/smdebug/pytorch/utils.py @@ -45,6 +45,8 @@ def make_numpy_array(x): return x.to(torch.device("cpu")).data.numpy() elif isinstance(x, tuple): return np.asarray(x, dtype=x.dtype) + elif isinstance(x, list): + return np.asarray(x) else: raise TypeError( "_make_numpy_array only accepts input types of numpy.ndarray, scalar," From 0aab04d2ab1262faa2670f7d38fa3fe98e62d8f9 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Wed, 22 Jul 2020 23:09:07 -0700 Subject: [PATCH 04/11] validate in tf2 --- smdebug/tensorflow/keras.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py index b90f1d8a9..af818d354 100644 --- a/smdebug/tensorflow/keras.py +++ b/smdebug/tensorflow/keras.py @@ -8,7 +8,7 @@ # First Party from smdebug.core.modes import ModeKeys, str_to_mode_keys -from smdebug.core.utils import match_inc +from smdebug.core.utils import match_inc, validate_custom_tensor_value from smdebug.tensorflow.callable_cache import CallableCache from smdebug.tensorflow.utils import InputOutputSaver, get_layer_call_fn @@ -386,6 +386,9 @@ def _add_metric(self, metric_name, metric_value: tf.Tensor = None): self.tensor_to_collections[metric_name] = {coll} def save_tensor(self, tensor_name, tensor_value, collections_to_write="default"): + if validate_custom_tensor_value(tensor_value, self._make_numpy_array) is False: + self.logger.warn("The tensor value could not be converted into a numpy value") + return if isinstance(collections_to_write, str): collections_to_write = [collections_to_write] for collection in collections_to_write: From 2b847b40ab6818b46971cac6a3c4d12a3d9a8b51 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Wed, 22 Jul 2020 23:37:54 -0700 Subject: [PATCH 05/11] retrigger CI From 2b21e02d91c375b7ac95e2930986d42aef0feff8 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Mon, 27 Jul 2020 10:49:19 -0700 Subject: [PATCH 06/11] clear dicts after loop --- smdebug/core/hook.py | 1 + smdebug/tensorflow/keras.py | 1 + 2 files changed, 2 insertions(+) diff --git a/smdebug/core/hook.py b/smdebug/core/hook.py index 62d0fd11b..1885ca09d 100644 --- a/smdebug/core/hook.py +++ b/smdebug/core/hook.py @@ -922,6 +922,7 @@ def _save_custom_tensors_post_step(self): c = self.collection_manager.get(collection_names, create=True) c.add_tensor_name(tensor_name) self._write_raw_tensor(tensor_name, tensor_value, [c]) + self.custom_tensors_to_save.clear() def _write_inputs(self, name, inputs): tensor_name = name + CallbackHook.INPUT_TENSOR_SUFFIX diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py index af818d354..8c5ffeceb 100644 --- a/smdebug/tensorflow/keras.py +++ b/smdebug/tensorflow/keras.py @@ -398,6 +398,7 @@ def _save_custom_tensors_post_step(self): for tensor_name in self.custom_tensors_to_save: tensor_value, collection_names = self.custom_tensors_to_save[tensor_name] self._save_tensor_to_file(tensor_name, tensor_value, collection_names) + self.custom_tensors_to_save.clear() def _save_tensor_to_file(self, tensor_name, tensor_value, collections_to_write): if isinstance(collections_to_write, set) is False: From 69282f78b8020642381e432fcc7a6b372129bac9 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Mon, 27 Jul 2020 12:03:36 -0700 Subject: [PATCH 07/11] nit --- tests/pytorch/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index cec6c87f7..ef03aff04 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -40,7 +40,6 @@ def train(model, hook, device, optimizer, num_steps=500, set_modes=False, save_c hook.set_mode(modes.TRAIN) model.train() - # for batch_idx, (data, target) in enumerate(train_loader): if save_custom_tensor: hook.save_tensor("custom_tensor", torch.tensor([[1.0, -1.0], [1.0, -1.0]])) for i in range(num_steps): From 55120ceb77d598f2445c91967bb9a7cac3527a84 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Mon, 27 Jul 2020 12:17:27 -0700 Subject: [PATCH 08/11] save custom api --- tests/mxnet/test_custom_tensor.py | 2 +- tests/pytorch/utils.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/mxnet/test_custom_tensor.py b/tests/mxnet/test_custom_tensor.py index 649aa757e..d96a410ab 100644 --- a/tests/mxnet/test_custom_tensor.py +++ b/tests/mxnet/test_custom_tensor.py @@ -26,5 +26,5 @@ def test_hook(): ) trial = create_trial(out_dir) custom_tensors = trial.tensor_names(collection=CollectionKeys.DEFAULT) - assert len(custom_tensors) + assert len(custom_tensors) == 4 shutil.rmtree(out_dir) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index ef03aff04..f68d9d146 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -36,12 +36,18 @@ def forward(self, x): def train(model, hook, device, optimizer, num_steps=500, set_modes=False, save_custom_tensor=False): + if save_custom_tensor: + hook.save_tensor("custom_tensor_0", torch.tensor([[1.0, -1.0], [1.0, -1.0]])) + if set_modes: hook.set_mode(modes.TRAIN) + if save_custom_tensor: + hook.save_tensor("custom_tensor_1", torch.tensor([[1.0, -1.0], [1.0, -1.0]])) + model.train() if save_custom_tensor: - hook.save_tensor("custom_tensor", torch.tensor([[1.0, -1.0], [1.0, -1.0]])) + hook.save_tensor("custom_tensor_2", torch.tensor([[1.0, -1.0], [1.0, -1.0]])) for i in range(num_steps): batch_size = 32 data, target = torch.rand(batch_size, 1, 28, 28), torch.rand(batch_size).long() @@ -52,6 +58,8 @@ def train(model, hook, device, optimizer, num_steps=500, set_modes=False, save_c hook.record_tensor_value("nll_loss", tensor_value=loss) loss.backward() optimizer.step() + if save_custom_tensor: + hook.save_tensor("custom_tensor_3", torch.tensor([[1.0, -1.0], [1.0, -1.0]])) def evaluate(model, hook, device, num_steps=100, set_modes=False): From 5383d8e3fb523e3586c5f5d99aac7765c4832389 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Mon, 27 Jul 2020 14:20:51 -0700 Subject: [PATCH 09/11] retrigger CI From 4a801181960a557ad6ec15b1feee42b09f671e7b Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Mon, 27 Jul 2020 15:45:29 -0700 Subject: [PATCH 10/11] save tensors --- smdebug/mxnet/hook.py | 2 +- smdebug/pytorch/hook.py | 1 + tests/mxnet/mnist_gluon_model.py | 8 +++++++- tests/mxnet/test_custom_tensor.py | 4 +++- tests/pytorch/test_save_custom_tensor.py | 2 +- tests/pytorch/utils.py | 4 ++-- 6 files changed, 15 insertions(+), 6 deletions(-) diff --git a/smdebug/mxnet/hook.py b/smdebug/mxnet/hook.py index aa1fbeca7..7234fbf88 100644 --- a/smdebug/mxnet/hook.py +++ b/smdebug/mxnet/hook.py @@ -154,6 +154,7 @@ def forward_pre_hook(self, block, inputs): self.exported_collections = True self.last_block = block + self._save_custom_tensors_post_step() # This hook is invoked by trainer after running the forward pass. def forward_hook(self, block, inputs, outputs): @@ -169,7 +170,6 @@ def forward_hook(self, block, inputs, outputs): # Output output tensors self._write_outputs(block_name, outputs) - self._save_custom_tensors_post_step() self.last_saved_step = self.step diff --git a/smdebug/pytorch/hook.py b/smdebug/pytorch/hook.py index 790473b55..c50debf8a 100644 --- a/smdebug/pytorch/hook.py +++ b/smdebug/pytorch/hook.py @@ -173,6 +173,7 @@ def back(grad): if grad is not None: # self.logger.debug(f"Processing the backward step " f"{self.step} for {tname}") self._save_for_tensor(self.GRADIENT_PREFIX + tname, grad) + self._save_custom_tensors_post_step() return back diff --git a/tests/mxnet/mnist_gluon_model.py b/tests/mxnet/mnist_gluon_model.py index 7d1a42d60..028f0da1a 100644 --- a/tests/mxnet/mnist_gluon_model.py +++ b/tests/mxnet/mnist_gluon_model.py @@ -105,7 +105,7 @@ def run_mnist_gluon_model( # Start the training. if save_custom_tensor: - hook.save_tensor("custom_tensor", mx.nd.array([1, 2, 3])) + hook.save_tensor("custom_tensor_1", mx.nd.array([1, 2, 3])) for epoch in range(1): train_loss, train_acc, valid_acc = 0.0, 0.0, 0.0 tic = time.time() @@ -114,6 +114,8 @@ def run_mnist_gluon_model( i = 0 for data, label in train_data: + if save_custom_tensor: + hook.save_tensor("custom_tensor_2", mx.nd.array([1, 2, 3])) data = data.as_in_context(mx.cpu(0)) # forward + backward with autograd.record(): @@ -127,6 +129,10 @@ def run_mnist_gluon_model( train_acc += acc(output, label) # hook.save_scalar(train_loss_name, train_loss) # hook.save_scalar(train_acc_name, train_acc) + if save_custom_tensor: + # This tensor will not be added to default collections since + # collections have already been exported + hook.save_tensor("custom_tensor_3", mx.nd.array([1, 2, 3])) i += 1 if num_steps_train is not None and i >= num_steps_train: break diff --git a/tests/mxnet/test_custom_tensor.py b/tests/mxnet/test_custom_tensor.py index d96a410ab..0c87f8331 100644 --- a/tests/mxnet/test_custom_tensor.py +++ b/tests/mxnet/test_custom_tensor.py @@ -26,5 +26,7 @@ def test_hook(): ) trial = create_trial(out_dir) custom_tensors = trial.tensor_names(collection=CollectionKeys.DEFAULT) - assert len(custom_tensors) == 4 + all_tensors = trial.tensor_names() + assert len(custom_tensors) == 2 + assert len(all_tensors) == 4 shutil.rmtree(out_dir) diff --git a/tests/pytorch/test_save_custom_tensor.py b/tests/pytorch/test_save_custom_tensor.py index 539b20761..e80131e99 100644 --- a/tests/pytorch/test_save_custom_tensor.py +++ b/tests/pytorch/test_save_custom_tensor.py @@ -31,5 +31,5 @@ def test_hook(): train(model, hook, torch.device("cpu"), optimizer, num_steps=10, save_custom_tensor=True) trial = create_trial(out_dir) custom_tensors = trial.tensor_names(collection=CollectionKeys.DEFAULT) - assert len(custom_tensors) + assert len(custom_tensors) == 3 shutil.rmtree(out_dir) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index f68d9d146..45978cfd1 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -56,10 +56,10 @@ def train(model, hook, device, optimizer, num_steps=500, set_modes=False, save_c output = model(Variable(data, requires_grad=True)) loss = F.nll_loss(output, target) hook.record_tensor_value("nll_loss", tensor_value=loss) + if save_custom_tensor: + hook.save_tensor("custom_tensor_3", torch.tensor([[1.0, -1.0], [1.0, -1.0]])) loss.backward() optimizer.step() - if save_custom_tensor: - hook.save_tensor("custom_tensor_3", torch.tensor([[1.0, -1.0], [1.0, -1.0]])) def evaluate(model, hook, device, num_steps=100, set_modes=False): From da3e711fadc642df2237904c6662eef7af04501b Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Mon, 27 Jul 2020 15:49:42 -0700 Subject: [PATCH 11/11] pytorch assert --- tests/pytorch/test_save_custom_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_save_custom_tensor.py b/tests/pytorch/test_save_custom_tensor.py index e80131e99..c8713ac87 100644 --- a/tests/pytorch/test_save_custom_tensor.py +++ b/tests/pytorch/test_save_custom_tensor.py @@ -31,5 +31,5 @@ def test_hook(): train(model, hook, torch.device("cpu"), optimizer, num_steps=10, save_custom_tensor=True) trial = create_trial(out_dir) custom_tensors = trial.tensor_names(collection=CollectionKeys.DEFAULT) - assert len(custom_tensors) == 3 + assert len(custom_tensors) == 4 shutil.rmtree(out_dir)