Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion smdebug/core/tfevent/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
# Third Party
import numpy as np

# First Party
from smdebug.core.utils import make_numpy_array

# Local
from .proto.summary_pb2 import HistogramProto, Summary
from .util import make_numpy_array

_INVALID_TAG_CHARACTERS = re.compile(r"[^-/\w\.]")

Expand Down
16 changes: 0 additions & 16 deletions smdebug/core/tfevent/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,3 @@ def make_tensor_proto(nparray_data, tag):
sb = bytes(s, encoding="utf-8")
tensor_proto.string_val.append(sb)
return tensor_proto


def make_numpy_array(x):
if isinstance(x, np.ndarray):
return x
elif np.isscalar(x):
return np.array([x])
elif isinstance(x, tuple):
return np.asarray(x, dtype=x.dtype)
elif isinstance(x, list):
return np.asarray(x)
else:
raise TypeError(
"_make_numpy_array only accepts input types of numpy.ndarray, scalar,"
" while received type {}".format(str(type(x)))
)
18 changes: 18 additions & 0 deletions smdebug/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from pathlib import Path
from typing import Dict, List

# Third Party
import numpy as np

# First Party
from smdebug.core.config_constants import (
CLAIM_FILENAME,
Expand All @@ -22,6 +25,21 @@
logger = get_logger()


def make_numpy_array(x):
if isinstance(x, np.ndarray):
return x
elif np.isscalar(x):
return np.array([x])
elif isinstance(x, tuple):
return np.asarray(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this originally had a dtype. add that back here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above.

elif isinstance(x, list):
return np.asarray(x)
elif isinstance(x, dict):
return np.array(x)
else:
raise TypeError("_make_numpy_array does not support the" " type {}".format(str(type(x))))


def ensure_dir(file_path, is_file=True):
if is_file:
directory = os.path.dirname(file_path)
Expand Down
3 changes: 3 additions & 0 deletions smdebug/mxnet/hook.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Third Party
import mxnet as mx
from mxnet.ndarray import NDArray

# First Party
from smdebug.core.collection import DEFAULT_MXNET_COLLECTIONS, CollectionKeys
Expand Down Expand Up @@ -253,4 +254,6 @@ def _get_reduction_of_data(reduction_name, tensor_value, tensor_name, abs):

@staticmethod
def _make_numpy_array(tensor_value):
if isinstance(tensor_value, NDArray):
return tensor_value.asnumpy()
return make_numpy_array(tensor_value)
21 changes: 1 addition & 20 deletions smdebug/mxnet/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Third Party
import mxnet as mx
import numpy as np
from mxnet.ndarray import NDArray

# First Party
from smdebug.core.reduction_config import ALLOWED_NORMS, ALLOWED_REDUCTIONS
from smdebug.core.reductions import get_numpy_reduction
from smdebug.core.utils import make_numpy_array


def get_reduction_of_data(aggregation_name, tensor_data, tensor_name, abs=False):
Expand Down Expand Up @@ -42,22 +42,3 @@ def get_reduction_of_data(aggregation_name, tensor_data, tensor_name, abs=False)
op = f(tensor_data)
return op
raise RuntimeError("Invalid aggregation_name {0} for mx.NDArray".format(aggregation_name))


def make_numpy_array(x):
if isinstance(x, np.ndarray):
return x
elif np.isscalar(x):
return np.array([x])
elif isinstance(x, NDArray):
return x.asnumpy()
elif isinstance(x, tuple):
# todo: fix this, will crash
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this about?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The crash used to occur because of a ref to dtype.
I have removed it and validated the support of tuple in tests/core/test_make_numpy_array.py line 26

return np.asarray(x, dtype=x.dtype)
elif isinstance(x, list):
return np.asarray(x)
else:
raise TypeError(
"_make_numpy_array only accepts input types of numpy.ndarray, scalar,"
" and MXNet NDArray, while received type {}".format(str(type(x)))
)
5 changes: 4 additions & 1 deletion smdebug/pytorch/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys
from smdebug.core.hook import CallbackHook
from smdebug.core.json_config import DEFAULT_WORKER_NAME
from smdebug.core.utils import make_numpy_array
from smdebug.pytorch.collection import CollectionManager
from smdebug.pytorch.singleton_utils import set_hook
from smdebug.pytorch.utils import get_reduction_of_data, make_numpy_array
from smdebug.pytorch.utils import get_reduction_of_data

DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.LOSSES]

Expand Down Expand Up @@ -250,4 +251,6 @@ def _get_reduction_of_data(reduction_name, tensor_value, tensor_name, abs):

@staticmethod
def _make_numpy_array(tensor_value):
if isinstance(tensor_value, torch.Tensor):
return tensor_value.to(torch.device("cpu")).data.numpy()
return make_numpy_array(tensor_value)
18 changes: 0 additions & 18 deletions smdebug/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,3 @@ def get_reduction_of_data(reduction_name, tensor_data, tensor_name, abs=False):
op = f(tensor_data)
return op
raise RuntimeError("Invalid reduction_name {0}".format(reduction_name))


def make_numpy_array(x):
if isinstance(x, np.ndarray):
return x
elif np.isscalar(x):
return np.array([x])
elif isinstance(x, torch.Tensor):
return x.to(torch.device("cpu")).data.numpy()
elif isinstance(x, tuple):
return np.asarray(x, dtype=x.dtype)
elif isinstance(x, list):
return np.asarray(x)
else:
raise TypeError(
"_make_numpy_array only accepts input types of numpy.ndarray, scalar,"
" and Torch Tensor, while received type {}".format(str(type(x)))
)
3 changes: 1 addition & 2 deletions smdebug/tensorflow/base_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from smdebug.core.hook import BaseHook
from smdebug.core.modes import ModeKeys
from smdebug.core.reductions import get_numpy_reduction, get_reduction_tensor_name
from smdebug.core.tfevent.util import make_numpy_array
from smdebug.core.utils import serialize_tf_device
from smdebug.core.utils import make_numpy_array, serialize_tf_device
from smdebug.core.writer import FileWriter

# Local
Expand Down
3 changes: 1 addition & 2 deletions smdebug/tensorflow/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
# First Party
from smdebug.core.collection import CollectionKeys
from smdebug.core.tfevent.proto.summary_pb2 import Summary
from smdebug.core.tfevent.util import make_numpy_array
from smdebug.core.utils import match_inc
from smdebug.core.utils import make_numpy_array, match_inc

# Local
from .base_hook import TensorflowBaseHook
Expand Down
2 changes: 1 addition & 1 deletion smdebug/xgboost/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from smdebug.core.hook import CallbackHook
from smdebug.core.json_config import create_hook_from_json_config
from smdebug.core.save_config import SaveConfig
from smdebug.core.tfevent.util import make_numpy_array
from smdebug.core.utils import make_numpy_array
from smdebug.xgboost.singleton_utils import set_hook

# Local
Expand Down
48 changes: 48 additions & 0 deletions tests/core/test_make_numpy_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Third Party
import numpy as np

# First Party
from smdebug.core.utils import make_numpy_array


def test_make_numpy_array():
simple_numpy_array = np.ndarray(shape=(2, 2), dtype=float, order="F")

# Check support for ndarray
try:
x = make_numpy_array(simple_numpy_array)
assert x.all() == simple_numpy_array.all()
except:
assert False

# Check support for scalar
simple_scalar = "foo"
try:
x = make_numpy_array(simple_scalar)
assert x == np.array([simple_scalar])
except:
assert False

# Check support for tuple
simple_tuple = (0.5, 0.7)
try:
x = make_numpy_array(simple_tuple)
assert x.all() == np.array(simple_tuple).all()
except:
assert False

# Check support for list
simple_list = [0.5, 0.7]
try:
x = make_numpy_array(simple_list)
assert x.all() == np.array(simple_list).all()
except:
assert False

# Check support for dict
simple_dict = {"a": 0.5, "b": 0.7}
try:
x = make_numpy_array(simple_dict)
assert x == np.array(simple_dict)
except:
assert False