23 changes: 23 additions & 0 deletions smdebug/core/hook.py
@@ -44,6 +44,7 @@
match_inc,
remove_claim_file,
size_and_shape,
validate_custom_tensor_value,
)
from smdebug.core.writer import FileWriter
from smdebug.exceptions import InvalidCollectionConfiguration
@@ -880,6 +881,7 @@ def __init__(
)
self.exported_collections = False
self.data_type_name = data_type_name
self.custom_tensors_to_save = dict()

def _cleanup(self):
if not self.exported_collections:
@@ -905,6 +907,23 @@ def _write(self, module_name, var, suffix, idx):
)
return idx

def save_tensor(self, tensor_name, tensor_value, collections_to_write=CollectionKeys.DEFAULT):
    if validate_custom_tensor_value(tensor_value, self._make_numpy_array) is False:
        self.logger.warning("The tensor value could not be converted into a numpy value")
        return
    if isinstance(collections_to_write, str):
        collections_to_write = [collections_to_write]
    # Store the full list of collections: assigning inside a per-collection
    # loop would overwrite the entry and keep only the last collection.
    self.custom_tensors_to_save[tensor_name] = (tensor_value, collections_to_write)

def _save_custom_tensors_post_step(self):
    for tensor_name in self.custom_tensors_to_save:
        tensor_value, collection_names = self.custom_tensors_to_save[tensor_name]
        for collection_name in collection_names:
            c = self.collection_manager.get(collection_name, create=True)
            c.add_tensor_name(tensor_name)
            self._write_raw_tensor(tensor_name, tensor_value, [c])
    self.custom_tensors_to_save.clear()

def _write_inputs(self, name, inputs):
tensor_name = name + CallbackHook.INPUT_TENSOR_SUFFIX
idx = self.written_tensor_name_for_step.get(tensor_name, 0)
@@ -922,3 +941,7 @@ def _write_outputs(self, name, outputs):
@abstractmethod
def _export_model(self):
pass

@staticmethod
def _make_numpy_array(tensor_value):
    # Stub: overridden by each framework hook to convert a native tensor
    # (e.g. an mx.nd.NDArray or torch.Tensor) into a numpy array.
    pass
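Taken together, the new base-hook API stages values by name and flushes them at the next step boundary: save_tensor records the value and its target collections in custom_tensors_to_save, and each framework hook calls _save_custom_tensors_post_step to write and clear the staged entries. A minimal usage sketch, assuming the PyTorch hook and the import paths used in the tests below (the out_dir path is illustrative):

    import torch
    from smdebug.core.collection import CollectionKeys
    from smdebug.pytorch import SaveConfig
    from smdebug.pytorch.hook import Hook

    hook = Hook(out_dir="/tmp/smdebug_demo", save_config=SaveConfig(save_steps=[0, 1]))
    # Stage a custom value under the default collection; it is written out
    # the next time a forward/backward hook fires and flushes the staged dict
    # (which requires the hook to be registered on a module that runs a step).
    hook.save_tensor("my_custom_tensor", torch.tensor([1.0, -1.0]), CollectionKeys.DEFAULT)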
1 change: 1 addition & 0 deletions smdebug/mxnet/hook.py
@@ -154,6 +154,7 @@ def forward_pre_hook(self, block, inputs):
self.exported_collections = True

self.last_block = block
self._save_custom_tensors_post_step()

# This hook is invoked by trainer after running the forward pass.
def forward_hook(self, block, inputs, outputs):
2 changes: 2 additions & 0 deletions smdebug/mxnet/utils.py
@@ -54,6 +54,8 @@ def make_numpy_array(x):
elif isinstance(x, tuple):
# todo: fix this, will crash
return np.asarray(x, dtype=x.dtype)
elif isinstance(x, list):
return np.asarray(x)
else:
raise TypeError(
"_make_numpy_array only accepts input types of numpy.ndarray, scalar,"
2 changes: 2 additions & 0 deletions smdebug/pytorch/hook.py
@@ -162,6 +162,7 @@ def forward_hook(self, module, inputs, outputs):

# Write the output tensors
self._write_outputs(module_name, outputs)
self._save_custom_tensors_post_step()
self.last_saved_step = self.step

def backward_hook(self, tname):
@@ -172,6 +173,7 @@ def back(grad):
if grad is not None:
# self.logger.debug(f"Processing the backward step " f"{self.step} for {tname}")
self._save_for_tensor(self.GRADIENT_PREFIX + tname, grad)
self._save_custom_tensors_post_step()

return back

2 changes: 2 additions & 0 deletions smdebug/pytorch/utils.py
@@ -45,6 +45,8 @@ def make_numpy_array(x):
return x.to(torch.device("cpu")).data.numpy()
elif isinstance(x, tuple):
return np.asarray(x, dtype=x.dtype)
elif isinstance(x, list):
return np.asarray(x)
else:
raise TypeError(
"_make_numpy_array only accepts input types of numpy.ndarray, scalar,"
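Both make_numpy_array changes rely on np.asarray inferring a dtype from the list elements (unlike the tuple branch, a list carries no dtype attribute). A quick standalone check of that behavior:

    import numpy as np

    # Nested Python lists convert directly; the dtype is inferred from the
    # elements rather than read from the input, since lists have no dtype.
    arr = np.asarray([[1.0, -1.0], [1.0, -1.0]])
    assert arr.shape == (2, 2)
    assert arr.dtype == np.float64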
1 change: 1 addition & 0 deletions smdebug/tensorflow/constants.py
@@ -1,2 +1,3 @@
SMDEBUG_GRADIENTS_KEY = "smdebug_gradients"
SMDEBUG_LAYER_OUTPUTS_KEY = "smdebug_layer_outputs"
SMDEBUG_PREFIX = "smdebug_"
16 changes: 7 additions & 9 deletions smdebug/tensorflow/keras.py
@@ -15,7 +15,7 @@
# Local
from .base_hook import TensorflowBaseHook
from .collection import CollectionKeys
from .constants import SMDEBUG_GRADIENTS_KEY, SMDEBUG_LAYER_OUTPUTS_KEY
from .constants import SMDEBUG_GRADIENTS_KEY, SMDEBUG_LAYER_OUTPUTS_KEY, SMDEBUG_PREFIX
from .tensor_ref import TensorRef, get_tf_names
from .utils import (
ModelInput,
@@ -391,7 +391,6 @@ def save_tensor(self, tensor_name, tensor_value, collections_to_write="default")
if validate_custom_tensor_value(tensor_value, self._make_numpy_array) is False:
self.logger.warning("The tensor value could not be converted into a numpy value")
return

if isinstance(collections_to_write, str):
collections_to_write = [collections_to_write]

@@ -403,11 +402,10 @@ def _save_custom_tensors_post_step(self):
# that the user has saved with the save_tensor api
for tensor_name in self.custom_tensors_to_save:
tensor_value, collection_names = self.custom_tensors_to_save[tensor_name]
self._save_tensor(tensor_name, tensor_value, collection_names)
# Clear saved custom tensors
self._save_tensor_to_file(tensor_name, tensor_value, collection_names)
self.custom_tensors_to_save.clear()

def _save_tensor(self, tensor_name, tensor_value, collections):
def _save_tensor_to_file(self, tensor_name, tensor_value, collections):
if isinstance(collections, set) is False:
collections = {collections}
# Since this function modifies the set, there is a possibility
@@ -442,7 +440,7 @@ def save_smdebug_logs(self, logs):
for key in logs:
tensors_to_save = []
collections_to_write = set()
if "smdebug_" in key:
if SMDEBUG_PREFIX in key:
# Save Model Outputs
if key in ModelOutputs:
export_name = get_model_output_export_name(key)
@@ -520,7 +518,7 @@ def _save_layer_input_and_outputs(self, grad_tape=False):
if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
else set()
)
self._save_tensor(export_name, tensor.numpy(), input_collection)
self._save_tensor_to_file(export_name, tensor.numpy(), input_collection)
# Save Output
tensor = self.saved_layers[layer_name].layer_output
export_name = get_export_name_for_keras(layer_name, tensor_type="output", tensor=tensor)
@@ -530,7 +528,7 @@
if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
else set()
)
self._save_tensor(export_name, tensor.numpy(), output_collection)
self._save_tensor_to_file(export_name, tensor.numpy(), output_collection)

def _save_tensors_post_step(self, batch, logs):
# some tensors available as value from within hook are saved here
@@ -733,7 +731,7 @@ def _save_layer_values(self, layer_outputs, collection, model=None, inputs=None)
export_name = get_export_name_for_keras(l.name, tensor_suffix)
tensors_to_save.append((export_name, o))
for t_name, t_value in tensors_to_save:
self._save_tensor(t_name, t_value, collections_to_write)
self._save_tensor_to_file(t_name, t_value, collections_to_write)

def save_layer_outputs(self, layer_outputs, model=None):
self._save_layer_values(layer_outputs, self.get_collection(CollectionKeys.LAYERS), model)
9 changes: 9 additions & 0 deletions tests/mxnet/mnist_gluon_model.py
@@ -28,6 +28,7 @@ def run_mnist_gluon_model(
make_input_zero=False,
normalize_mean=0.13,
normalize_std=0.31,
save_custom_tensor=False,
):
batch_size = 4
if make_input_zero:
@@ -103,6 +104,8 @@
eval_acc_name = "loss_acc"

# Start the training.
if save_custom_tensor:
hook.save_tensor("custom_tensor_1", mx.nd.array([1, 2, 3]))
for epoch in range(1):
train_loss, train_acc, valid_acc = 0.0, 0.0, 0.0
tic = time.time()
@@ -111,6 +114,8 @@

i = 0
for data, label in train_data:
if save_custom_tensor:
hook.save_tensor("custom_tensor_2", mx.nd.array([1, 2, 3]))
data = data.as_in_context(mx.cpu(0))
# forward + backward
with autograd.record():
@@ -124,6 +129,10 @@
train_acc += acc(output, label)
# hook.save_scalar(train_loss_name, train_loss)
# hook.save_scalar(train_acc_name, train_acc)
if save_custom_tensor:
# This tensor will not be added to default collections since
# collections have already been exported
hook.save_tensor("custom_tensor_3", mx.nd.array([1, 2, 3]))
i += 1
if num_steps_train is not None and i >= num_steps_train:
break
32 changes: 32 additions & 0 deletions tests/mxnet/test_custom_tensor.py
@@ -0,0 +1,32 @@
# Standard Library
import shutil
from datetime import datetime

# First Party
from smdebug import SaveConfig
from smdebug.core.collection import CollectionKeys
from smdebug.mxnet.hook import Hook as t_hook
from smdebug.trials import create_trial

# Local
from .mnist_gluon_model import run_mnist_gluon_model


def test_hook():
save_config = SaveConfig(save_steps=[0, 1, 2, 3])
run_id = "trial_" + datetime.now().strftime("%Y%m%d-%H%M%S%f")
out_dir = "/tmp/newlogsRunTest/" + run_id
hook = t_hook(out_dir=out_dir, save_config=save_config)
run_mnist_gluon_model(
hook=hook,
num_steps_train=10,
num_steps_eval=10,
register_to_loss_block=True,
save_custom_tensor=True,
)
trial = create_trial(out_dir)
custom_tensors = trial.tensor_names(collection=CollectionKeys.DEFAULT)
all_tensors = trial.tensor_names()
# custom_tensor_1 and custom_tensor_2 land in the default collection;
# custom_tensor_3 does not, because the collections had already been
# exported when it was saved (see the comment in mnist_gluon_model.py).
assert len(custom_tensors) == 2
assert len(all_tensors) == 4
shutil.rmtree(out_dir)
3 changes: 1 addition & 2 deletions tests/mxnet/test_hook_all_zero.py
@@ -36,8 +36,7 @@ def test_hook_all_zero(hook=None, out_dir=None):
assert tr
assert len(tr.steps()) == 4

tnames = tr.tensor_names(regex="conv._input")
tname = tr.tensor_names(regex="conv._input")[0]
tname = tr.tensor_names(regex="conv.+_input")[0]
conv_tensor_value = tr.tensor(tname).value(step_num=0)
is_zero = np.all(conv_tensor_value == 0)
assert is_zero == True
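The tightened regex matters because a bare "." consumes exactly one character, so a layer name with a multi-character index between "conv" and "_input" would not match the old pattern. A small illustration (the layer names here are hypothetical):

    import re

    assert re.search(r"conv._input", "conv0_input")        # one character: both patterns match
    assert not re.search(r"conv._input", "conv2d0_input")  # old pattern misses multi-character names
    assert re.search(r"conv.+_input", "conv2d0_input")     # .+ matches one or more characters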
35 changes: 35 additions & 0 deletions tests/pytorch/test_save_custom_tensor.py
@@ -0,0 +1,35 @@
# Standard Library
import shutil
from datetime import datetime

# Third Party
import torch
import torch.optim as optim

# First Party
from smdebug.core.collection import CollectionKeys
from smdebug.pytorch import SaveConfig
from smdebug.pytorch.hook import Hook as t_hook
from smdebug.trials import create_trial

# Local
from .utils import Net, train


def test_hook():
run_id = "trial_" + datetime.now().strftime("%Y%m%d-%H%M%S%f")
out_dir = "/tmp/" + run_id
hook = t_hook(
out_dir=out_dir,
save_config=SaveConfig(save_steps=[0, 1, 2, 3]),
include_collections=["relu_activations"],
)

model = Net().to(torch.device("cpu"))
hook.register_module(model)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
train(model, hook, torch.device("cpu"), optimizer, num_steps=10, save_custom_tensor=True)
trial = create_trial(out_dir)
custom_tensors = trial.tensor_names(collection=CollectionKeys.DEFAULT)
assert len(custom_tensors) == 4  # custom_tensor_0 through custom_tensor_3, staged in utils.train below
shutil.rmtree(out_dir)
13 changes: 11 additions & 2 deletions tests/pytorch/utils.py
@@ -35,12 +35,19 @@ def forward(self, x):
return F.log_softmax(x, dim=1)


def train(model, hook, device, optimizer, num_steps=500, set_modes=False):
def train(model, hook, device, optimizer, num_steps=500, set_modes=False, save_custom_tensor=False):
if save_custom_tensor:
hook.save_tensor("custom_tensor_0", torch.tensor([[1.0, -1.0], [1.0, -1.0]]))

if set_modes:
hook.set_mode(modes.TRAIN)

if save_custom_tensor:
hook.save_tensor("custom_tensor_1", torch.tensor([[1.0, -1.0], [1.0, -1.0]]))

model.train()
# for batch_idx, (data, target) in enumerate(train_loader):
if save_custom_tensor:
hook.save_tensor("custom_tensor_2", torch.tensor([[1.0, -1.0], [1.0, -1.0]]))
for i in range(num_steps):
batch_size = 32
data, target = torch.rand(batch_size, 1, 28, 28), torch.rand(batch_size).long()
@@ -49,6 +56,8 @@ def train(model, hook, device, optimizer, num_steps=500, set_modes=False):
output = model(Variable(data, requires_grad=True))
loss = F.nll_loss(output, target)
hook.record_tensor_value("nll_loss", tensor_value=loss)
if save_custom_tensor:
hook.save_tensor("custom_tensor_3", torch.tensor([[1.0, -1.0], [1.0, -1.0]]))
loss.backward()
optimizer.step()
