Skip to content

Commit

Permalink
Cleanup: fix type issues in lax_numpy.py (#3816)
Browse files Browse the repository at this point in the history
These changes are basically a no-op wth the current default types, but fixes issues if/when those types are changed to 32-bit in the future.
  • Loading branch information
jakevdp authored Jul 22, 2020
1 parent 04a6238 commit fddb28d
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@ def _make_scalar_type(np_scalar_type):

_canonicalize_axis = lax._canonicalize_axis

_DEFAULT_TYPEMAP = {
np.bool_: bool_,
np.int_: int_,
np.float_: float_,
np.complex_: complex_
}

def _np_array(obj, dtype=None, **kwargs):
"""Return a properly-typed numpy array.
Expand All @@ -212,16 +219,10 @@ def _np_array(obj, dtype=None, **kwargs):
uses Jax's default dtypes.
"""
arr = np.array(obj, dtype=dtype, **kwargs)
typemap = {
np.bool_: bool_,
np.int_: int_,
np.float_: float_,
np.complex_: complex_
}
obj_dtype = getattr(obj, 'dtype', None)
arr_dtype = np.dtype(arr.dtype)
if dtype is None and obj_dtype is None and arr_dtype in typemap:
arr = arr.astype(typemap[arr_dtype])
arr_dtype = np.dtype(arr.dtype).type
if dtype is None and obj_dtype is None and arr_dtype in _DEFAULT_TYPEMAP:
arr = arr.astype(_DEFAULT_TYPEMAP[arr_dtype])
return arr

_np_asarray = partial(_np_array, copy=False)
Expand Down Expand Up @@ -1860,7 +1861,8 @@ def nanvar(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):

normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims)
normalizer = normalizer - ddof
normalizer_mask = lax.le(normalizer, 0)
zero = lax.full_like(normalizer, 0, shape=())
normalizer_mask = lax.le(normalizer, zero)

result = nansum(centered, axis, keepdims=keepdims)
result = where(normalizer_mask, nan, result)
Expand Down Expand Up @@ -2358,9 +2360,11 @@ def arange(start, stop=None, step=None, dtype=None):
dtype = dtype or _dtype(start)
return lax.iota(dtype, np.ceil(start)) # avoids materializing
else:
start = None if start is None else require(start, msg("start"))
start = require(start, msg("start"))
stop = None if stop is None else require(stop, msg("stop"))
step = None if step is None else require(step, msg("step"))
if dtype is None:
dtype = _dtype(start, *filter(lambda x: x is not None, [stop, step]))
return array(np.arange(start, stop=stop, step=step, dtype=dtype))


Expand Down Expand Up @@ -4248,7 +4252,7 @@ def searchsorted(a, v, side='left', sorter=None):
@_wraps(np.digitize)
def digitize(x, bins, right=False):
if len(bins) == 0:
return zeros(x, dtype=int32)
return zeros(x, dtype=int_)
side = 'right' if not right else 'left'
return where(
bins[-1] >= bins[0],
Expand Down

0 comments on commit fddb28d

Please sign in to comment.