Review summary (leading text before "diff --git" is not part of the patch):

  * Adds a private helper, _triangular_mask(n, m, k, tie_in=None), holding the
    index comparison that tri() previously computed inline: a row-index
    arange(n) broadcast to (n, m), offset by k, compared with lax.ge against a
    column-index arange(m).  When tie_in is given, both aranges are passed
    through lax.tie_in(tie_in, ...).
  * tri() now calls the helper; the offset is converted with
    lax.convert_element_type(k, int32) instead of int32(k).
  * tril()/triu() become thin wrappers (still decorated with _wraps) around
    jit-decorated private functions _tril()/_triu(), which build the mask via
    the helper with tie_in=m rather than by calling tri(..., dtype=bool).

  NOTE(review): this patch was recovered from a copy in which every line had
  been collapsed onto a single line.  Line breaks, the 2-space indentation and
  the placement of blank context lines are reconstructed; they are consistent
  with the hunk header counts (37 old / 45 new lines) but should be confirmed
  against the target file before applying.
  NOTE(review): _tril declares k=0 while _triu declares k with no default.  The
  public wrappers always pass k, so this is only a signature inconsistency.

diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py
index b7f3eaca6948..636bd5bcd901 100644
--- a/jax/numpy/lax_numpy.py
+++ b/jax/numpy/lax_numpy.py
@@ -1898,37 +1898,45 @@ def repeat(a, repeats, axis=None):
   return reshape(ret, new_shape)
 
 
+def _triangular_mask(n, m, k, tie_in=None):
+  x = arange(n, dtype=int32)
+  y = arange(m, dtype=int32)
+  if tie_in is not None:
+    x = lax.tie_in(tie_in, x)
+    y = lax.tie_in(tie_in, y)
+  return lax.ge(
+    (lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,)) +
+     lax.convert_element_type(k, int32)),
+    lax.broadcast(y, [n]))
+
 @_wraps(onp.tri)
 def tri(N, M=None, k=0, dtype=None):
   lax._check_user_dtype_supported(dtype, "tri")
-  M = M if M is not None else N
-  dtype = dtype or float32
-  x = arange(N, dtype=int32)
-  y = arange(M, dtype=int32)
-  mask = lax.ge(
-    (lax.broadcast_in_dim(x, shape=(N, M), broadcast_dimensions=(0,)) +
-     int32(k)),
-    lax.broadcast(y, [N]))
-  return lax.convert_element_type(mask, dtype)
-
+  mask = _triangular_mask(N, M if M is not None else N, k)
+  return lax.convert_element_type(mask, dtype or float32)
 
-@_wraps(onp.tril)
-def tril(m, k=0):
+@jit
+def _tril(m, k=0):
   m_shape = shape(m)
   if len(m_shape) < 2:
     raise ValueError("Argument to jax.numpy.tril must be at least 2D")
-  mask = tri(*m_shape[-2:], k=k, dtype=bool)
+  mask = _triangular_mask(m_shape[-2], m_shape[-1], k, tie_in=m)
   return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))
 
+@_wraps(onp.tril)
+def tril(m, k=0): return _tril(m, k)
 
-@_wraps(onp.triu, update_doc=False)
-def triu(m, k=0):
+@jit
+def _triu(m, k):
   m_shape = shape(m)
   if len(m_shape) < 2:
     raise ValueError("Argument to jax.numpy.triu must be at least 2D")
-  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
+  mask = _triangular_mask(m_shape[-2], m_shape[-1], k-1, tie_in=m)
   return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)
 
+@_wraps(onp.triu, update_doc=False)
+def triu(m, k=0): return _triu(m, k)
+
 
 @_wraps(onp.trace)
 def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):