diff --git a/smdebug/mxnet/hook.py b/smdebug/mxnet/hook.py index e7f53eb96..e1c85dc5a 100644 --- a/smdebug/mxnet/hook.py +++ b/smdebug/mxnet/hook.py @@ -154,8 +154,9 @@ def forward_hook(self, block, inputs, outputs): # This overwhelms the logs; turn back on if you really need it # logger.debug("Processing the global step {0} for block {1}".format(self.step, block_name)) - # Output input tensor - self._write_inputs(block_name, inputs) + # Output input tensor if it is not a loss block + if isinstance(block, mx.gluon.loss.Loss) is False: + self._write_inputs(block_name, inputs) # Output output tensors self._write_outputs(block_name, outputs) diff --git a/tests/mxnet/test_hook_loss_collection.py b/tests/mxnet/test_hook_loss_collection.py index 95307c316..1a21b64f7 100644 --- a/tests/mxnet/test_hook_loss_collection.py +++ b/tests/mxnet/test_hook_loss_collection.py @@ -33,6 +33,9 @@ def test_loss_collection_default(): loss_val = loss_tensor.value(step_num=1) assert len(loss_val) > 0 + # Assert that we are not logging the inputs to loss block. + input_loss_tensors = tr.tensor_names(regex=".*loss._input*") + assert len(input_loss_tensors) == 0 shutil.rmtree(out_dir)