Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: array.dtype.type will never be in ak.types.numpytype._dtype_to_primitive_dict #1841

Merged
merged 11 commits into from
Oct 27, 2022
40 changes: 23 additions & 17 deletions src/awkward/operations/ak_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import builtins
import numbers
from datetime import datetime, timedelta

import awkward as ak

Expand Down Expand Up @@ -64,15 +66,18 @@ def _impl(array):
if array is None:
return ak.types.UnknownType()

elif isinstance(
array,
tuple(x.type for x in ak.types.numpytype._dtype_to_primitive_dict),
elif isinstance(array, np.dtype):
return ak.types.NumpyType(ak.types.numpytype.dtype_to_primitive(array))

elif (
isinstance(array, np.generic)
or isinstance(array, builtins.type)
and issubclass(array, np.generic)
):
return ak.types.NumpyType(
ak.types.numpytype._dtype_to_primitive_dict[array.dtype]
)
primitive = ak.types.numpytype.dtype_to_primitive(np.dtype(array))
return ak.types.NumpyType(primitive)

elif isinstance(array, (bool, np.bool_)):
elif isinstance(array, bool): # np.bool_ in np.generic (above)
return ak.types.NumpyType("bool")

elif isinstance(array, numbers.Integral):
Expand All @@ -81,6 +86,15 @@ def _impl(array):
elif isinstance(array, numbers.Real):
return ak.types.NumpyType("float64")

elif isinstance(array, numbers.Complex):
return ak.types.NumpyType("complex128")

elif isinstance(array, datetime): # np.datetime64 in np.generic (above)
return ak.types.NumpyType("datetime64")

elif isinstance(array, timedelta): # np.timedelta64 in np.generic (above)
return ak.types.NumpyType("timedelta")

elif isinstance(
array,
(
Expand All @@ -95,16 +109,8 @@ def _impl(array):
if len(array.shape) == 0:
return _impl(array.reshape((1,))[0])
else:
try:
out = ak.types.numpytype._dtype_to_primitive_dict[array.dtype.type]
except KeyError as err:
raise ak._errors.wrap_error(
TypeError(
"numpy array type is unrecognized by awkward: %r"
% array.dtype.type
)
) from err
out = ak.types.NumpyType(out)
primitive = ak.types.numpytype.dtype_to_primitive(array.dtype)
out = ak.types.NumpyType(primitive)
for x in array.shape[-1:0:-1]:
out = ak.types.RegularType(out, x)
return ak.types.ArrayType(out, array.shape[0])
Expand Down
23 changes: 23 additions & 0 deletions tests/test_1840-ak_type-to-handle-ndarray-dtype-and-nptypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import numpy as np # noqa: F401
import pytest # noqa: F401

import awkward as ak # noqa: F401


def test_array():
array = np.random.random(size=512).astype(dtype=np.float64)
assert ak.type(array) == ak.types.ArrayType(ak.types.NumpyType("float64"), 512)


def test_dtype():
assert ak.type(np.dtype(np.float64)) == ak.types.NumpyType("float64")


def test_type():
assert ak.type(np.float64) == ak.types.NumpyType("float64")


def test_type_instance():
assert ak.type(np.float64(10.0)) == ak.types.NumpyType("float64")