Merge pull request #2853 from levskaya/topkjvp
Add top_k jvp and batching rules and tests
levskaya authored Apr 28, 2020
2 parents c0023f4 + dddad2a commit ca4e396
Showing 4 changed files with 89 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/jax.lax.rst
@@ -124,6 +124,7 @@ Operators
sub
tan
tie_in
top_k
transpose


46 changes: 44 additions & 2 deletions jax/lax/lax.py
@@ -1187,7 +1187,8 @@ def sort_key_val(keys: Array, values: Array, dimension: int = -1) -> Array:
  sorted_keys, sorted_values = result
  return sorted_keys, sorted_values

-def top_k(operand: Array, k: int) -> Array:
+def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
+  """Returns top ``k`` values and their indices along the last axis of ``operand``."""
  k = int(k)
  if k < 0:
    raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
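
For orientation, here is a small usage sketch (not part of the diff) of the behavior the corrected annotation describes: lax.top_k returns a pair of arrays, the k largest values and their int32 indices, computed along the last axis.

    import jax.numpy as jnp
    from jax import lax

    x = jnp.array([[9., 3., 6., 4.],
                   [1., 8., 2., 7.]])
    vals, idxs = lax.top_k(x, 2)  # always reduces over the last axis
    # vals: [[9., 6.], [8., 7.]]   idxs: [[0, 2], [1, 3]] (int32)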
@@ -4624,12 +4625,53 @@ def _top_k_abstract_eval(operand, *, k):
  return (ShapedArray(shape, operand.dtype),
          ShapedArray(shape, onp.dtype(onp.int32)))

def _top_k_jvp(primals, tangents, *, k):
  operand, = primals
  tangent, = tangents
  primals_out = top_k(operand, k)
  if tangent is ad_util.zero:
    tangents_out = (ad_util.zero, ad_util.zero)
  else:
    _, k_idxs = primals_out
    idx_shape = k_idxs.shape
    rank = len(idx_shape)
    gather_index_shape = idx_shape + (1,)
    gather_indices = []
    for i in range(rank-1):
      _iota = iota(k_idxs.dtype, idx_shape[i])
      _iota = tie_in(operand, _iota)
      _iota = broadcast_in_dim(_iota, gather_index_shape, (i,))
      gather_indices.append(_iota)
    gather_indices.append(reshape(k_idxs, gather_index_shape))
    gather_indices = concatenate(gather_indices, dimension=rank)
    slice_sizes = (1,) * rank
    dnums = GatherDimensionNumbers(
      offset_dims=(),
      collapsed_slice_dims=tuple(range(rank)),
      start_index_map=tuple(range(rank)))
    tangents_out = (gather(tangent, gather_indices, dnums, slice_sizes),
                    ad_util.zero)
  return primals_out, tangents_out

def _top_k_batch_rule(batched_args, batch_dims, *, k):
  operand, = batched_args
  bdim, = batch_dims
  if bdim == operand.ndim-1:
    perm = onp.arange(operand.ndim)
    perm[bdim-1], perm[bdim] = perm[bdim], perm[bdim-1]
    top_k_v, top_k_i = top_k(transpose(operand, perm), k=k)
    return (transpose(top_k_v, perm),
            transpose(top_k_i, perm)), (bdim, bdim)
  else:
    return top_k(operand, k=k), (bdim, bdim)

top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
xla.translations[top_k_p] = partial(standard_translate, 'top_k')

ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule

def _tie_in_transpose_rule(t):
  return [ad_util.zero, t]
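
The JVP above forwards the input tangent through the values output by gathering it at the winning positions: for each leading axis it builds an iota coordinate, appends the top_k indices as the final coordinate, and feeds the concatenated index array to gather. The integer indices output is non-differentiable, so it gets a zero tangent. A minimal sketch of the resulting behavior, assuming these rules are registered as in this commit:

    import jax
    import jax.numpy as jnp
    from jax import lax

    x = jnp.array([5., 1., 4., 2.])
    t = jnp.array([10., 20., 30., 40.])  # input tangent
    vals, vals_dot = jax.jvp(lambda v: lax.top_k(v, 2)[0], (x,), (t,))
    # vals is [5., 4.] (positions 0 and 2), so vals_dot is [10., 30.]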
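The batching rule relies on top_k always operating over the trailing axis: if the vmapped dimension is anywhere else, the primitive can run directly on the batched operand; if it is the last axis, the rule swaps it with the neighboring axis, runs top_k, and swaps back. A sketch of both paths through vmap, under the same assumption that this commit's rules are in place:

    import jax
    import jax.numpy as jnp
    from jax import lax

    batch = jnp.arange(12.).reshape(3, 4)  # three length-4 examples
    vals, idxs = jax.vmap(lambda v: lax.top_k(v, 2))(batch)
    # vals.shape == idxs.shape == (3, 2)

    # Mapping over axis 1 makes the batch dimension the operand's last
    # axis, exercising the transpose branch; results agree with the above.
    vals2, idxs2 = jax.vmap(lambda v: lax.top_k(v, 2), in_axes=1)(batch.T)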
7 changes: 7 additions & 0 deletions jax/test_util.py
@@ -634,6 +634,13 @@ def fn(shape, dtype):
    return randint(low, high=high, size=shape, dtype=dtype)
  return fn

def rand_unique_int():
  randchoice = npr.RandomState(0).choice
  def fn(shape, dtype):
    return randchoice(onp.arange(onp.prod(shape), dtype=dtype),
                      size=shape, replace=False)
  return fn

def rand_bool():
  rng = npr.RandomState(0)
  def generator(shape, dtype):
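The new helper backs the integer-dtype cases of testTopK below: drawing prod(shape) distinct integers without replacement guarantees there are no ties, so the top_k indices are unambiguous. A usage sketch (the jax.test_util module path is as of this commit):

    import numpy as onp
    import jax.test_util as jtu

    rand = jtu.rand_unique_int()
    arr = rand((2, 3), onp.int32)  # six distinct values drawn from 0..5
    assert len(onp.unique(arr)) == arr.size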
37 changes: 37 additions & 0 deletions tests/lax_test.py
@@ -2484,6 +2484,22 @@ def args_maker():
    fun = lambda keys, values: lax.sort_key_val(keys, values, axis)
    check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
      for dtype in [onp.float32,]
      for shape in [(4,), (5, 5), (2, 1, 4)]
      for k in [1, 3]
      for rng_factory in [jtu.rand_default]))
  def testTopKGrad(self, shape, dtype, k, rng_factory):
    rng = rng_factory()
    perm_rng = onp.random.RandomState(0)
    flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
    values = perm_rng.permutation(flat_values).reshape(shape)
    fun = lambda vs: lax.top_k(vs, k=k)[0]
    check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
@@ -3213,6 +3229,27 @@ def testBroadcastShapesReturnsPythonInts(self):
    out_shape = lax.broadcast_shapes(shape1, shape2)
    self.assertTrue(all(type(s) is int for s in out_shape))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k, bdims),
       "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory}
      for shape in [(4,), (3, 4, 5)]
      for k in [1, 3]
      for bdims in all_bdims(shape)
      # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
      # The top_k indices for integer arrays with identical entries won't match between
      # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
      for dtype, rng_factory in itertools.chain(
          zip(float_dtypes, itertools.repeat(jtu.rand_default)),
          zip(int_dtypes, itertools.repeat(jtu.rand_unique_int)))))
  def testTopK(self, shape, dtype, k, bdims, rng_factory):
    rng = rng_factory()
    # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
    op1 = lambda x: lax.top_k(x, k=k)[0]
    self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)
    op2 = lambda x: lax.top_k(x, k=k)[1]
    self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)

# TODO Concatenate
# TODO Reverse
# TODO DynamicSlice
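A note on why testTopKGrad permutes an arange rather than sampling from rng: with repeated entries, an epsilon perturbation can flip which element wins a top-k slot, so numerical gradient checks become ill-defined at ties. Building the input from distinct values sidesteps this; a sketch of the same construction:

    import numpy as onp

    shape = (5, 5)
    flat = onp.arange(onp.prod(shape), dtype=onp.float32)
    values = onp.random.RandomState(0).permutation(flat).reshape(shape)
    assert len(onp.unique(values)) == values.size  # no ties anywhere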
