def make_numpy_array(x):
    """Convert a framework-agnostic Python/NumPy value into a ``numpy.ndarray``.

    Supported inputs:
      * ``numpy.ndarray`` -- returned unchanged (same object, no copy).
      * scalar (anything for which ``np.isscalar`` is true, including ``str``)
        -- wrapped in a one-element, one-dimensional array.
      * ``tuple`` / ``list`` -- converted with ``np.asarray``.
      * ``dict`` -- wrapped with ``np.array``.

    :param x: the value to convert
    :return: a ``numpy.ndarray`` holding ``x``
    :raises TypeError: if ``x`` is none of the supported types
    """
    if isinstance(x, np.ndarray):
        return x
    if np.isscalar(x):
        return np.array([x])
    # Tuples and lists are handled identically, so test for both at once.
    if isinstance(x, (tuple, list)):
        return np.asarray(x)
    if isinstance(x, dict):
        # NumPy does not treat a dict as a sequence: this yields a 0-d object
        # array whose single element is the dict itself.
        return np.array(x)
    raise TypeError("_make_numpy_array does not support the" " type {}".format(str(type(x))))
@staticmethod
def _make_numpy_array(tensor_value):
    """Return ``tensor_value`` as a NumPy array.

    An MXNet ``NDArray`` is converted with its own ``asnumpy()`` method; any
    other value is passed on to ``make_numpy_array``.
    """
    if not isinstance(tensor_value, NDArray):
        return make_numpy_array(tensor_value)
    return tensor_value.asnumpy()
@staticmethod
def _make_numpy_array(tensor_value):
    """Return ``tensor_value`` as a NumPy array.

    A ``torch.Tensor`` is detached from the autograd graph, moved to the CPU
    and converted with ``numpy()``; any other value is passed on to
    ``make_numpy_array``.
    """
    if isinstance(tensor_value, torch.Tensor):
        # ``detach()`` replaces the legacy ``.data`` attribute: both drop the
        # autograd history, but ``.data`` hides in-place modifications from
        # autograd's correctness checks.
        return tensor_value.detach().cpu().numpy()
    return make_numpy_array(tensor_value)
def test_make_numpy_array():
    """Check each input kind ``make_numpy_array`` accepts, and that others are rejected.

    Failures are left to propagate so that pytest reports the real traceback
    instead of an uninformative ``assert False``.
    """
    # ndarray: contents must come back unchanged. The array is built from
    # explicit values because np.ndarray(shape=...) returns uninitialised
    # memory, which may contain NaNs and make the comparison meaningless.
    simple_numpy_array = np.arange(4, dtype=float).reshape((2, 2), order="F")
    x = make_numpy_array(simple_numpy_array)
    assert np.array_equal(x, simple_numpy_array)

    # scalar: wrapped in a one-element array
    simple_scalar = "foo"
    x = make_numpy_array(simple_scalar)
    assert np.array_equal(x, np.array([simple_scalar]))

    # tuple: compared element-wise (x.all() == y.all() only compares two bools)
    simple_tuple = (0.5, 0.7)
    x = make_numpy_array(simple_tuple)
    assert np.array_equal(x, np.array(simple_tuple))

    # list
    simple_list = [0.5, 0.7]
    x = make_numpy_array(simple_list)
    assert np.array_equal(x, np.array(simple_list))

    # dict: the result is an ndarray whose single item is the dict itself
    simple_dict = {"a": 0.5, "b": 0.7}
    x = make_numpy_array(simple_dict)
    assert isinstance(x, np.ndarray)
    assert x.item() == simple_dict

    # unsupported type: must raise TypeError rather than return something
    try:
        make_numpy_array(object())
    except TypeError:
        pass
    else:
        raise AssertionError("make_numpy_array accepted an unsupported type")