Skip to content

Commit

Permalink
Merge branch 'main' into rm-__array__
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Nov 2, 2024
2 parents d630ee5 + dec8c22 commit 5485345
Showing 1 changed file with 46 additions and 8 deletions.
54 changes: 46 additions & 8 deletions array_api_strict/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._array_object import Array
from ._flags import requires_api_version
from ._creation_functions import asarray
from ._data_type_functions import broadcast_to, iinfo

from typing import Optional, Union

Expand Down Expand Up @@ -325,14 +326,51 @@ def clip(
if min is not None and max is not None and np.any(min > max):
raise ValueError("min must be less than or equal to max")

result = np.clip(x._array, min, max)
# Note: NumPy applies type promotion, but the standard specifies the
# return dtype should be the same as x
if result.dtype != x.dtype._np_dtype:
# TODO: I'm not completely sure this always gives the correct thing
# for integer dtypes. See https://github.com/numpy/numpy/issues/24976
result = result.astype(x.dtype._np_dtype)
return Array._new(result, device=x.device)
# np.clip does type promotion but the array API clip requires that the
# output have the same dtype as x. We do this instead of just downcasting
# the result of xp.clip() to handle some corner cases better (e.g.,
# avoiding uint64 -> float64 promotion).

# Note: cases where min or max overflow (integer) or round (float) in the
# wrong direction when downcasting to x.dtype are unspecified. This code
# just does whatever NumPy does when it downcasts in the assignment, but
# other behavior could be preferred, especially for integers. For example,
# this code produces:

# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
# -128

# but an answer of 0 might be preferred. See
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.

# At least handle the case of Python integers correctly (see
# https://github.com/numpy/numpy/pull/26892).
if type(min) is int and min <= iinfo(x.dtype).min:
min = None
if type(max) is int and max >= iinfo(x.dtype).max:
max = None

def _isscalar(a):
return isinstance(a, (int, float, type(None)))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape

result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)

out = asarray(broadcast_to(x, result_shape), copy=True)._array
device = x.device
x = x._array

if min is not None:
a = np.broadcast_to(np.asarray(min), result_shape)
ia = (out < a) | np.isnan(a)

out[ia] = a[ia]
if max is not None:
b = np.broadcast_to(np.asarray(max), result_shape)
ib = (out > b) | np.isnan(b)
out[ib] = b[ib]
return Array._new(out, device=device)

def conj(x: Array, /) -> Array:
"""
Expand Down

0 comments on commit 5485345

Please sign in to comment.