Skip to content

Commit

Permalink
Merge pull request #66 from asmeurer/repeat-fix
Browse files Browse the repository at this point in the history
Fix issue with repeat()
  • Loading branch information
asmeurer authored Oct 15, 2024
2 parents 05c8b0f + 6d780a8 commit 019d935
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions array_api_strict/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from ._array_object import Array
from ._creation_functions import asarray
from ._data_type_functions import result_type
from ._dtypes import _integer_dtypes
from ._data_type_functions import astype, result_type
from ._dtypes import _integer_dtypes, int64, uint64
from ._flags import requires_api_version, get_array_api_strict_flags

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -94,7 +94,13 @@ def repeat(
else:
raise TypeError("repeats must be an int or array")

return Array._new(np.repeat(x._array, repeats, axis=axis))
if repeats.dtype == uint64:
# NumPy does not allow uint64 because can't be cast down to x.dtype
# with 'safe' casting. However, repeats values larger than 2**63 are
# infeasable, and even if they are present by mistake, this will
# lead to underflow and an error.
repeats = astype(repeats, int64)
return Array._new(np.repeat(x._array, repeats._array, axis=axis))

# Note: the optional argument is called 'shape', not 'newshape'
def reshape(x: Array,
Expand Down

0 comments on commit 019d935

Please sign in to comment.