Add top_k jvp and batching rules and tests #2853

Merged · 1 commit · Apr 28, 2020
1 change: 1 addition & 0 deletions docs/jax.lax.rst
@@ -123,6 +123,7 @@ Operators
sub
tan
tie_in
top_k
transpose


46 changes: 44 additions & 2 deletions jax/lax/lax.py
@@ -1183,7 +1183,8 @@ def sort_key_val(keys: Array, values: Array, dimension: int = -1) -> Array:
  sorted_keys, sorted_values = result
  return sorted_keys, sorted_values

def top_k(operand: Array, k: int) -> Tuple[Array, Array]:  # was: -> Array
  """Returns top ``k`` values and their indices along the last axis of ``operand``."""
  k = int(k)
  if k < 0:
    raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
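With the corrected annotation, callers unpack a (values, indices) pair. A minimal usage sketch (illustrative only; the array contents and jnp import convention are assumptions, and the outputs are worked by hand):

import jax.numpy as jnp
from jax import lax

values, indices = lax.top_k(jnp.array([5., 2., 9., 1.]), 2)
# values  -> [9., 5.]  (same dtype as the operand)
# indices -> [2, 0]    (int32 positions along the last axis)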
@@ -4618,12 +4619,53 @@ def _top_k_abstract_eval(operand, *, k):
  return (ShapedArray(shape, operand.dtype),
          ShapedArray(shape, onp.dtype(onp.int32)))

def _top_k_jvp(primals, tangents, *, k):
  operand, = primals
  tangent, = tangents
  primals_out = top_k(operand, k)
  if tangent is ad_util.zero:
    tangents_out = (ad_util.zero, ad_util.zero)
  else:
    _, k_idxs = primals_out
    idx_shape = k_idxs.shape
    rank = len(idx_shape)
    gather_index_shape = idx_shape + (1,)
    gather_indices = []
    # Build full gather coordinates: an iota for each leading dimension,
    # then the top_k indices for the last dimension.
    for i in range(rank-1):
      _iota = iota(k_idxs.dtype, idx_shape[i])
      _iota = tie_in(operand, _iota)
      _iota = broadcast_in_dim(_iota, gather_index_shape, (i,))
      gather_indices.append(_iota)
    gather_indices.append(reshape(k_idxs, gather_index_shape))
    gather_indices = concatenate(gather_indices, dimension=rank)
    slice_sizes = (1,) * rank
    dnums = GatherDimensionNumbers(
      offset_dims=(),
      collapsed_slice_dims=tuple(range(rank)),
      start_index_map=tuple(range(rank)))
    # The tangent of the values is the operand tangent gathered at the
    # selected positions; the integer indices carry no tangent.
    tangents_out = (gather(tangent, gather_indices, dnums, slice_sizes),
                    ad_util.zero)
  return primals_out, tangents_out

def _top_k_batch_rule(batched_args, batch_dims, *, k):
  operand, = batched_args
  bdim, = batch_dims
  if bdim == operand.ndim-1:
    # top_k selects along the last axis, so if the batch dimension is last,
    # swap it with its neighbor, run top_k, and undo the swap on both outputs.
    perm = onp.arange(operand.ndim)
    perm[bdim-1], perm[bdim] = perm[bdim], perm[bdim-1]
    top_k_v, top_k_i = top_k(transpose(operand, perm), k=k)
    return (transpose(top_k_v, perm),
            transpose(top_k_i, perm)), (bdim, bdim)
  else:
    return top_k(operand, k=k), (bdim, bdim)

top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
xla.translations[top_k_p] = partial(standard_translate, 'top_k')

ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule

def _tie_in_transpose_rule(t):
  return [ad_util.zero, t]
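Taken together, the two new rules let forward-mode differentiation and vmap pass through lax.top_k. A minimal sketch of both (illustrative only; the jnp import, shapes, and values are assumptions, not part of this diff):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1., 3., 2., 4.],
               [7., 5., 8., 6.]])
vals_fn = lambda v: lax.top_k(v, 2)[0]

# JVP rule: the value tangents are the input tangents gathered at the
# top-k positions; the index output gets a zero tangent.
primals_out, tangents_out = jax.jvp(vals_fn, (x,), (jnp.ones_like(x),))

# Batch rule: vmap over the leading axis agrees with a per-row loop.
batched_vals = jax.vmap(vals_fn)(x)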
7 changes: 7 additions & 0 deletions jax/test_util.py
@@ -634,6 +634,13 @@ def fn(shape, dtype):
    return randint(low, high=high, size=shape, dtype=dtype)
  return fn

def rand_unique_int():
  # Sample without replacement from [0, prod(shape)) so every entry is distinct.
  randchoice = npr.RandomState(0).choice
  def fn(shape, dtype):
    return randchoice(onp.arange(onp.prod(shape), dtype=dtype),
                      size=shape, replace=False)
  return fn

def rand_bool():
  rng = npr.RandomState(0)
  def generator(shape, dtype):
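For context, rand_unique_int() returns a sampler whose output is some arrangement of the integers 0..prod(shape)-1, so no two entries can tie. A hedged sketch of its behavior, assuming the function above is in scope (the assert follows from replace=False):

import numpy as onp

rand = rand_unique_int()
x = rand((2, 3), onp.int32)          # some arrangement of the integers 0..5
assert len(onp.unique(x)) == x.size  # all entries distinct, so top_k has no ties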
37 changes: 37 additions & 0 deletions tests/lax_test.py
@@ -2487,6 +2487,22 @@ def args_maker():
    fun = lambda keys, values: lax.sort_key_val(keys, values, axis)
    check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
      for dtype in [onp.float32,]
      for shape in [(4,), (5, 5), (2, 1, 4)]
      for k in [1, 3]
      for rng_factory in [jtu.rand_default]))
  def testTopKGrad(self, shape, dtype, k, rng_factory):
    rng = rng_factory()
    # Use a permutation of distinct values so the top-k selection, and hence
    # the gradient, is unambiguous under the finite-difference check.
    perm_rng = onp.random.RandomState(0)
    flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
    values = perm_rng.permutation(flat_values).reshape(shape)
    fun = lambda vs: lax.top_k(vs, k=k)[0]
    check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
@@ -3220,6 +3236,27 @@ def testBroadcastShapesReturnsPythonInts(self):
    out_shape = lax.broadcast_shapes(shape1, shape2)
    self.assertTrue(all(type(s) is int for s in out_shape))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k, bdims),
       "shape": shape, "dtype": dtype, "k": k, "bdims": bdims,
       "rng_factory": rng_factory}
      for shape in [(4,), (3, 4, 5)]
      for k in [1, 3]
      for bdims in all_bdims(shape)
      # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug
      # is fixed: the top_k indices for integer arrays with identical entries
      # won't match between the vmap'd version and the manual reference, so only
      # test unique integer arrays for int_dtypes.
      for dtype, rng_factory in itertools.chain(
          zip(float_dtypes, itertools.repeat(jtu.rand_default)),
          zip(int_dtypes, itertools.repeat(jtu.rand_unique_int)))))
  def testTopK(self, shape, dtype, k, bdims, rng_factory):
    rng = rng_factory()
    # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
    op1 = lambda x: lax.top_k(x, k=k)[0]
    self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)
    op2 = lambda x: lax.top_k(x, k=k)[1]
    self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)

  # TODO Concatenate
  # TODO Reverse
  # TODO DynamicSlice
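For intuition about what testTopKGrad exercises: the reverse-mode gradient of the top-k values is an indicator on the selected entries, since the transpose of the gather in the JVP rule is a scatter-add. A hand-worked sketch, not part of the test suite (array values are illustrative):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 4., 2., 3.])
g = jax.grad(lambda v: lax.top_k(v, 2)[0].sum())(x)
# The two largest entries are 4. (index 1) and 3. (index 3), so
# g == [0., 1., 0., 1.]: ones at the selected positions, zeros elsewhere.

This is also why testTopK avoids repeated integer values: with ties, which index top_k reports is implementation-defined, so a vmap'd run and the manual reference loop could both be valid yet disagree.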