-
Notifications
You must be signed in to change notification settings - Fork 83
Refactor Make Numpy Array #329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
# Third Party | ||
import mxnet as mx | ||
import numpy as np | ||
from mxnet.ndarray import NDArray | ||
|
||
# First Party | ||
from smdebug.core.reduction_config import ALLOWED_NORMS, ALLOWED_REDUCTIONS | ||
from smdebug.core.reductions import get_numpy_reduction | ||
from smdebug.core.utils import make_numpy_array | ||
|
||
|
||
def get_reduction_of_data(aggregation_name, tensor_data, tensor_name, abs=False): | ||
|
@@ -42,22 +42,3 @@ def get_reduction_of_data(aggregation_name, tensor_data, tensor_name, abs=False) | |
op = f(tensor_data) | ||
return op | ||
raise RuntimeError("Invalid aggregation_name {0} for mx.NDArray".format(aggregation_name)) | ||
|
||
|
||
def make_numpy_array(x): | ||
if isinstance(x, np.ndarray): | ||
return x | ||
elif np.isscalar(x): | ||
return np.array([x]) | ||
elif isinstance(x, NDArray): | ||
return x.asnumpy() | ||
elif isinstance(x, tuple): | ||
# todo: fix this, will crash | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is this about? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The crash used to occur because of a ref to dtype. |
||
return np.asarray(x, dtype=x.dtype) | ||
elif isinstance(x, list): | ||
return np.asarray(x) | ||
else: | ||
raise TypeError( | ||
"_make_numpy_array only accepts input types of numpy.ndarray, scalar," | ||
" and MXNet NDArray, while received type {}".format(str(type(x))) | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Third Party | ||
import numpy as np | ||
|
||
# First Party | ||
from smdebug.core.utils import make_numpy_array | ||
|
||
|
||
def test_make_numpy_array(): | ||
simple_numpy_array = np.ndarray(shape=(2, 2), dtype=float, order="F") | ||
|
||
# Check support for ndarray | ||
try: | ||
x = make_numpy_array(simple_numpy_array) | ||
assert x.all() == simple_numpy_array.all() | ||
except: | ||
assert False | ||
|
||
# Check support for scalar | ||
simple_scalar = "foo" | ||
try: | ||
x = make_numpy_array(simple_scalar) | ||
assert x == np.array([simple_scalar]) | ||
except: | ||
assert False | ||
|
||
# Check support for tuple | ||
simple_tuple = (0.5, 0.7) | ||
try: | ||
x = make_numpy_array(simple_tuple) | ||
assert x.all() == np.array(simple_tuple).all() | ||
except: | ||
assert False | ||
|
||
# Check support for list | ||
simple_list = [0.5, 0.7] | ||
try: | ||
x = make_numpy_array(simple_list) | ||
assert x.all() == np.array(simple_list).all() | ||
except: | ||
assert False | ||
|
||
# Check support for dict | ||
simple_dict = {"a": 0.5, "b": 0.7} | ||
try: | ||
x = make_numpy_array(simple_dict) | ||
assert x == np.array(simple_dict) | ||
except: | ||
assert False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this originally had a dtype. add that back here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comment above.