Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce use of lax on static data (e.g. shapes) #2933

Merged
merged 2 commits into from
May 2, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 16 additions & 27 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ def gradient_along_axis(a, h, axis):
return []
axis = [_canonicalize_axis(i, a.ndim) for i in axis]

if min([s for i, s in enumerate(a.shape) if i in axis]) < 2:
if _min([s for i, s in enumerate(a.shape) if i in axis]) < 2:
raise ValueError("Shape of array too small to calculate "
"a numerical gradient, "
"at least 2 elements are required.")
Expand Down Expand Up @@ -1852,7 +1852,7 @@ def _pad(array, pad_width, mode, constant_values):
array = asarray(array)
nd = ndim(array)
pad_width = onp.broadcast_to(onp.asarray(pad_width), (nd, 2))
if any(pad_width < 0):
if onp.any(pad_width < 0):
raise ValueError("index can't contain negative values")

if mode == "constant":
Expand Down Expand Up @@ -2313,52 +2313,41 @@ def _repeat_scalar(a, repeats, axis=None):

@_wraps(onp.repeat)
def repeat(a, repeats, axis=None):
'''
:param repeats: int or array of ints
'''
# use `_repeat_scalar` when possible
if isscalar(repeats):
return _repeat_scalar(a, repeats, axis)
repeats_raveled = ravel(array(repeats)) # make sure it's jax's array type
repeats_raveled = onp.ravel(onp.array(repeats))
if size(repeats_raveled) == 1:
return _repeat_scalar(a, list(repeats_raveled)[0], axis)
return _repeat_scalar(a, repeats_raveled.item(), axis)

if axis is None or isscalar(a):
a = ravel(a)
axis = 0

# repeats must match the dimension along the requested axis
a_shape = list(a.shape)
n = a_shape[axis]
if size(repeats_raveled) != n:
raise ValueError("repeats shape {} does not match the dimension on axis {}".format(
repeats_raveled.shape, n
))
if repeats_raveled.size != a.shape[axis]:
msg = "repeats shape {} does not match the dimension on axis {}"
raise ValueError(msg.format(repeats_raveled.shape, a.shape[axis]))

# calculating the new shape
total = sum(repeats_raveled)
total = repeats_raveled.sum()

new_shape = a_shape[:]
new_shape = list(a.shape)
new_shape[axis] = total

a_flattened = ravel(a)

'''
main algorithm:
first break down raveled input array into list of chunks; each chunk is the unit of repeat
then tile the repeats to have same length as the list of chunks
finally repeat each unit x number of times according to the tiled repeat list
'''
chunks = product(a_shape[:axis+1]).item()
# first break down raveled input array into list of chunks; each chunk is the
# unit of repeat. then tile the repeats to have same length as the list of
# chunks. finally repeat each unit x number of times according to the tiled
# repeat list.
chunks = _prod(a.shape[:axis+1])
a_splitted = split(a_flattened, chunks)
repeats_tiled = tile(repeats_raveled, chunks // len(repeats_raveled))
repeats_tiled = onp.tile(repeats_raveled, chunks // len(repeats_raveled))

ret = array([], dtype=a.dtype)
for i, repeat in enumerate(repeats_tiled):
if not isinstance(repeat, int):
repeat = repeat.item()
if repeat != 0:
ret = concatenate((ret, tile(a_splitted[i], repeat)))
ret = concatenate((ret, tile(a_splitted[i], (repeat,))))

return reshape(ret, new_shape)

Expand Down