Skip to content

Commit ae6d1f0

Browse files
committed
add make numpy array test
1 parent cb45e75 commit ae6d1f0

File tree

11 files changed

+83
-61
lines changed

11 files changed

+83
-61
lines changed

smdebug/core/tfevent/summary.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
# Third Party
55
import numpy as np
66

7+
# First Party
8+
from smdebug.core.utils import make_numpy_array
9+
710
# Local
811
from .proto.summary_pb2 import HistogramProto, Summary
9-
from .util import make_numpy_array
1012

1113
_INVALID_TAG_CHARACTERS = re.compile(r"[^-/\w\.]")
1214

smdebug/core/tfevent/util.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,3 @@ def make_tensor_proto(nparray_data, tag):
5252
sb = bytes(s, encoding="utf-8")
5353
tensor_proto.string_val.append(sb)
5454
return tensor_proto
55-
56-
57-
def make_numpy_array(x):
58-
if isinstance(x, np.ndarray):
59-
return x
60-
elif np.isscalar(x):
61-
return np.array([x])
62-
elif isinstance(x, tuple):
63-
return np.asarray(x, dtype=x.dtype)
64-
elif isinstance(x, list):
65-
return np.asarray(x)
66-
else:
67-
raise TypeError(
68-
"_make_numpy_array only accepts input types of numpy.ndarray, scalar,"
69-
" while received type {}".format(str(type(x)))
70-
)

smdebug/core/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from pathlib import Path
99
from typing import Dict, List
1010

11+
# Third Party
12+
import numpy as np
13+
1114
# First Party
1215
from smdebug.core.config_constants import (
1316
CLAIM_FILENAME,
@@ -22,6 +25,24 @@
2225
logger = get_logger()
2326

2427

28+
def make_numpy_array(x):
29+
if isinstance(x, np.ndarray):
30+
return x
31+
elif np.isscalar(x):
32+
return np.array(x)
33+
elif isinstance(x, tuple):
34+
return np.asarray(x)
35+
elif isinstance(x, list):
36+
return np.asarray(x)
37+
elif isinstance(x, dict):
38+
return np.array(x)
39+
else:
40+
raise TypeError(
41+
"_make_numpy_array only accepts input types of numpy.ndarray, scalar,"
42+
" while received type {}".format(str(type(x)))
43+
)
44+
45+
2546
def ensure_dir(file_path, is_file=True):
2647
if is_file:
2748
directory = os.path.dirname(file_path)

smdebug/mxnet/hook.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Third Party
22
import mxnet as mx
3+
from mxnet.ndarray import NDArray
34

45
# First Party
56
from smdebug.core.collection import DEFAULT_MXNET_COLLECTIONS, CollectionKeys
@@ -253,4 +254,6 @@ def _get_reduction_of_data(reduction_name, tensor_value, tensor_name, abs):
253254

254255
@staticmethod
255256
def _make_numpy_array(tensor_value):
257+
if isinstance(tensor_value, NDArray):
258+
return tensor_value.asnumpy()
256259
return make_numpy_array(tensor_value)

smdebug/mxnet/utils.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Third Party
22
import mxnet as mx
33
import numpy as np
4-
from mxnet.ndarray import NDArray
54

65
# First Party
76
from smdebug.core.reduction_config import ALLOWED_NORMS, ALLOWED_REDUCTIONS
87
from smdebug.core.reductions import get_numpy_reduction
8+
from smdebug.core.utils import make_numpy_array
99

1010

1111
def get_reduction_of_data(aggregation_name, tensor_data, tensor_name, abs=False):
@@ -42,22 +42,3 @@ def get_reduction_of_data(aggregation_name, tensor_data, tensor_name, abs=False)
4242
op = f(tensor_data)
4343
return op
4444
raise RuntimeError("Invalid aggregation_name {0} for mx.NDArray".format(aggregation_name))
45-
46-
47-
def make_numpy_array(x):
48-
if isinstance(x, np.ndarray):
49-
return x
50-
elif np.isscalar(x):
51-
return np.array([x])
52-
elif isinstance(x, NDArray):
53-
return x.asnumpy()
54-
elif isinstance(x, tuple):
55-
# todo: fix this, will crash
56-
return np.asarray(x, dtype=x.dtype)
57-
elif isinstance(x, list):
58-
return np.asarray(x)
59-
else:
60-
raise TypeError(
61-
"_make_numpy_array only accepts input types of numpy.ndarray, scalar,"
62-
" and MXNet NDArray, while received type {}".format(str(type(x)))
63-
)

smdebug/pytorch/hook.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys
99
from smdebug.core.hook import CallbackHook
1010
from smdebug.core.json_config import DEFAULT_WORKER_NAME
11+
from smdebug.core.utils import make_numpy_array
1112
from smdebug.pytorch.collection import CollectionManager
1213
from smdebug.pytorch.singleton_utils import set_hook
13-
from smdebug.pytorch.utils import get_reduction_of_data, make_numpy_array
14+
from smdebug.pytorch.utils import get_reduction_of_data
1415

1516
DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.LOSSES]
1617

@@ -250,4 +251,6 @@ def _get_reduction_of_data(reduction_name, tensor_value, tensor_name, abs):
250251

251252
@staticmethod
252253
def _make_numpy_array(tensor_value):
254+
if isinstance(tensor_value, torch.Tensor):
255+
return tensor_value.to(torch.device("cpu")).data.numpy()
253256
return make_numpy_array(tensor_value)

smdebug/pytorch/utils.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,3 @@ def get_reduction_of_data(reduction_name, tensor_data, tensor_name, abs=False):
3434
op = f(tensor_data)
3535
return op
3636
raise RuntimeError("Invalid reduction_name {0}".format(reduction_name))
37-
38-
39-
def make_numpy_array(x):
40-
if isinstance(x, np.ndarray):
41-
return x
42-
elif np.isscalar(x):
43-
return np.array([x])
44-
elif isinstance(x, torch.Tensor):
45-
return x.to(torch.device("cpu")).data.numpy()
46-
elif isinstance(x, tuple):
47-
return np.asarray(x, dtype=x.dtype)
48-
elif isinstance(x, list):
49-
return np.asarray(x)
50-
else:
51-
raise TypeError(
52-
"_make_numpy_array only accepts input types of numpy.ndarray, scalar,"
53-
" and Torch Tensor, while received type {}".format(str(type(x)))
54-
)

smdebug/tensorflow/base_hook.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from smdebug.core.hook import BaseHook
1515
from smdebug.core.modes import ModeKeys
1616
from smdebug.core.reductions import get_numpy_reduction, get_reduction_tensor_name
17-
from smdebug.core.tfevent.util import make_numpy_array
18-
from smdebug.core.utils import serialize_tf_device
17+
from smdebug.core.utils import make_numpy_array, serialize_tf_device
1918
from smdebug.core.writer import FileWriter
2019

2120
# Local

smdebug/tensorflow/session.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
# First Party
88
from smdebug.core.collection import CollectionKeys
99
from smdebug.core.tfevent.proto.summary_pb2 import Summary
10-
from smdebug.core.tfevent.util import make_numpy_array
11-
from smdebug.core.utils import match_inc
10+
from smdebug.core.utils import make_numpy_array, match_inc
1211

1312
# Local
1413
from .base_hook import TensorflowBaseHook

smdebug/xgboost/hook.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from smdebug.core.hook import CallbackHook
1414
from smdebug.core.json_config import create_hook_from_json_config
1515
from smdebug.core.save_config import SaveConfig
16-
from smdebug.core.tfevent.util import make_numpy_array
16+
from smdebug.core.utils import make_numpy_array
1717
from smdebug.xgboost.singleton_utils import set_hook
1818

1919
# Local

0 commit comments

Comments
 (0)