From 811d110e87f8044d8e1d270b5e09cae3a3bbdd3f Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Fri, 5 Jun 2020 02:16:33 -0700 Subject: [PATCH] handle eager tensors --- smdebug/core/reductions.py | 13 +++++++------ smdebug/tensorflow/base_hook.py | 2 ++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/smdebug/core/reductions.py b/smdebug/core/reductions.py index c73b2fe9f..3be6c98b5 100644 --- a/smdebug/core/reductions.py +++ b/smdebug/core/reductions.py @@ -29,14 +29,15 @@ def get_basic_numpy_reduction(reduction_name, numpy_data): return getattr(np, reduction_name)(numpy_data) elif reduction_name in ALLOWED_NORMS: if reduction_name in ["l1", "l2"]: - ord = int(reduction_name[1]) + order = int(reduction_name[1]) else: - ord = None + order = None - if abs: - rv = np.linalg.norm(np.absolute(numpy_data), ord=ord) - else: - rv = np.linalg.norm(numpy_data, ord=ord) + if np.isscalar(numpy_data): + # np.linalg.norm expects array-like inputs + # but numpy_data can sometimes be a scalar value + numpy_data = [numpy_data] + rv = np.linalg.norm(numpy_data, ord=order) return rv return None diff --git a/smdebug/tensorflow/base_hook.py b/smdebug/tensorflow/base_hook.py index 188f6c71e..baaa26476 100644 --- a/smdebug/tensorflow/base_hook.py +++ b/smdebug/tensorflow/base_hook.py @@ -527,6 +527,8 @@ def _make_numpy_array(tensor_value): @staticmethod def _get_reduction_of_data(reduction_name, tensor_value, tensor_name, abs): + if hasattr(tensor_value, "numpy"): + tensor_value = tensor_value.numpy() return get_numpy_reduction(reduction_name, tensor_value, abs) def add_to_collection(self, collection_name, variable):