From dddad2a3dc0e36f76e98288f6b50ccd65fdebced Mon Sep 17 00:00:00 2001
From: Anselm Levskaya
Date: Sun, 19 Apr 2020 11:49:15 -0700
Subject: [PATCH] Add top_k jvp and batching rules

---
 docs/jax.lax.rst  |  1 +
 jax/lax/lax.py    | 46 ++++++++++++++++++++++++++++++++++++++++++++--
 jax/test_util.py  |  7 +++++++
 tests/lax_test.py | 37 +++++++++++++++++++++++++++++++++++++
 4 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst
index 9ced6a833fb3..3bc02d43c85f 100644
--- a/docs/jax.lax.rst
+++ b/docs/jax.lax.rst
@@ -123,6 +123,7 @@ Operators
     sub
     tan
     tie_in
+    top_k
     transpose
diff --git a/jax/lax/lax.py b/jax/lax/lax.py
index 1ec5aeeb7f0f..5d274fd30e26 100644
--- a/jax/lax/lax.py
+++ b/jax/lax/lax.py
@@ -1183,7 +1183,8 @@ def sort_key_val(keys: Array, values: Array, dimension: int = -1) -> Array:
   sorted_keys, sorted_values = result
   return sorted_keys, sorted_values
 
-def top_k(operand: Array, k: int) -> Array:
+def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
+  """Returns top ``k`` values and their indices along the last axis of ``operand``."""
   k = int(k)
   if k < 0:
     raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
@@ -4618,12 +4619,53 @@ def _top_k_abstract_eval(operand, *, k):
   return (ShapedArray(shape, operand.dtype),
           ShapedArray(shape, onp.dtype(onp.int32)))
 
+def _top_k_jvp(primals, tangents, *, k):
+  operand, = primals
+  tangent, = tangents
+  primals_out = top_k(operand, k)
+  if tangent is ad_util.zero:
+    tangents_out = (ad_util.zero, ad_util.zero)
+  else:
+    _, k_idxs = primals_out
+    idx_shape = k_idxs.shape
+    rank = len(idx_shape)
+    gather_index_shape = idx_shape + (1,)
+    gather_indices = []
+    for i in range(rank-1):
+      _iota = iota(k_idxs.dtype, idx_shape[i])
+      _iota = tie_in(operand, _iota)
+      _iota = broadcast_in_dim(_iota, gather_index_shape, (i,))
+      gather_indices.append(_iota)
+    gather_indices.append(reshape(k_idxs, gather_index_shape))
+    gather_indices = concatenate(gather_indices, dimension=rank)
+    slice_sizes = (1,) * rank
+    dnums = GatherDimensionNumbers(
+      offset_dims=(),
+      collapsed_slice_dims=tuple(range(rank)),
+      start_index_map=tuple(range(rank)))
+    tangents_out = (gather(tangent, gather_indices, dnums, slice_sizes),
+                    ad_util.zero)
+  return primals_out, tangents_out
+
+def _top_k_batch_rule(batched_args, batch_dims, *, k):
+  operand, = batched_args
+  bdim, = batch_dims
+  if bdim == operand.ndim-1:
+    perm = onp.arange(operand.ndim)
+    perm[bdim-1], perm[bdim] = perm[bdim], perm[bdim-1]
+    top_k_v, top_k_i = top_k(transpose(operand, perm), k=k)
+    return (transpose(top_k_v, perm),
+            transpose(top_k_i, perm)), (bdim, bdim)
+  else:
+    return top_k(operand, k=k), (bdim, bdim)
+
 top_k_p = Primitive('top_k')
 top_k_p.multiple_results = True
 top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
 top_k_p.def_abstract_eval(_top_k_abstract_eval)
 xla.translations[top_k_p] = partial(standard_translate, 'top_k')
-
+ad.primitive_jvps[top_k_p] = _top_k_jvp
+batching.primitive_batchers[top_k_p] = _top_k_batch_rule
 
 def _tie_in_transpose_rule(t):
   return [ad_util.zero, t]
diff --git a/jax/test_util.py b/jax/test_util.py
index 9d6d6778ef16..d3ff1b42ed1d 100644
--- a/jax/test_util.py
+++ b/jax/test_util.py
@@ -634,6 +634,13 @@ def fn(shape, dtype):
     return randint(low, high=high, size=shape, dtype=dtype)
   return fn
 
+def rand_unique_int():
+  randchoice = npr.RandomState(0).choice
+  def fn(shape, dtype):
+    return randchoice(onp.arange(onp.prod(shape), dtype=dtype),
+                      size=shape, replace=False)
+  return fn
+
 def rand_bool():
   rng = npr.RandomState(0)
   def generator(shape, dtype):
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 88c2d758415f..6399089900c3 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -2487,6 +2487,22 @@ def args_maker():
     fun = lambda keys, values: lax.sort_key_val(keys, values, axis)
     check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_shape={}_k={}".format(
+          jtu.format_shape_dtype_string(shape, dtype), k),
+       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
+      for dtype in [onp.float32,]
+      for shape in [(4,), (5, 5), (2, 1, 4)]
+      for k in [1, 3]
+      for rng_factory in [jtu.rand_default]))
+  def testTopKGrad(self, shape, dtype, k, rng_factory):
+    rng = rng_factory()
+    perm_rng = onp.random.RandomState(0)
+    flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
+    values = perm_rng.permutation(flat_values).reshape(shape)
+    fun = lambda vs: lax.top_k(vs, k=k)[0]
+    check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_idxs={}_axes={}".format(
           jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
@@ -3220,6 +3236,27 @@ def testBroadcastShapesReturnsPythonInts(self):
     out_shape = lax.broadcast_shapes(shape1, shape2)
     self.assertTrue(all(type(s) is int for s in out_shape))
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_shape={}_k={}_bdims={}".format(
+          jtu.format_shape_dtype_string(shape, dtype), k, bdims),
+       "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory}
+      for shape in [(4,), (3, 4, 5)]
+      for k in [1, 3]
+      for bdims in all_bdims(shape)
+      # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
+      # The top_k indices for integer arrays with identical entries won't match between
+      # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
+      for dtype, rng_factory in itertools.chain(
+          zip(float_dtypes, itertools.repeat(jtu.rand_default)),
+          zip(int_dtypes, itertools.repeat(jtu.rand_unique_int)))))
+  def testTopK(self, shape, dtype, k, bdims, rng_factory):
+    rng = rng_factory()
+    # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
+    op1 = lambda x: lax.top_k(x, k=k)[0]
+    self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)
+    op2 = lambda x: lax.top_k(x, k=k)[1]
+    self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)
+
   # TODO Concatenate
   # TODO Reverse
   # TODO DynamicSlice
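
Notes on intended behavior (not part of the patch):

With the JVP rule registered, lax.top_k becomes forward- and reverse-differentiable
in its values output. A minimal sketch of what that buys us, assuming a post-patch
jax install; the array x and function f below are illustrative, not from the patch:

    import jax
    import jax.numpy as jnp
    from jax import lax

    x = jnp.array([1., 3., 2., 4.])
    f = lambda v: lax.top_k(v, 2)[0].sum()  # sum of the two largest entries
    print(jax.grad(f)(x))                   # [0., 1., 0., 1.]: gradient lands on the selected entries

The indices output keeps a zero tangent (ad_util.zero), which matches the math:
the indices are piecewise constant in the input.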
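The batching rule likewise makes lax.top_k usable under jax.vmap. Another small
sketch, with illustrative values rather than anything from the patch:

    import jax
    import jax.numpy as jnp
    from jax import lax

    xs = jnp.array([[1., 3., 2.],
                    [6., 4., 5.]])
    vals, idxs = jax.vmap(lambda v: lax.top_k(v, 2))(xs)
    # vals == [[3., 2.], [6., 5.]] and idxs == [[1, 2], [0, 2]]

When the batch dimension is the operand's last axis (the axis top_k selects over),
_top_k_batch_rule swaps the last two axes, applies top_k, and swaps back; any other
batch position passes through unchanged, since top_k only touches the last axis.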
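Finally, _top_k_jvp builds an explicit lax.gather whose index array concatenates
batch iotas with the top-k indices, so the values tangent is just the input tangent
read at the winning positions. The following sketch expresses that same selection
with jnp.take_along_axis; it is an illustrative equivalent, not the code the rule
emits, and x and t are made-up inputs:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.array([[1., 4., 2.],
                   [9., 7., 8.]])
    t = jnp.ones_like(x)          # stand-in tangent for x
    _, idxs = lax.top_k(x, 2)
    val_tangents = jnp.take_along_axis(t, idxs, axis=-1)  # tangent of the values output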