Skip to content

Commit

Permalink
Wrap np.tril and np.triu in jit.
Browse files Browse the repository at this point in the history
Avoids materializing potentially large triangular mask constants.
  • Loading branch information
hawkinsp committed Nov 18, 2019
1 parent d323431 commit 7f7c21f
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,37 +1898,45 @@ def repeat(a, repeats, axis=None):

return reshape(ret, new_shape)

def _triangular_mask(n, m, k, tie_in=None):
x = arange(n, dtype=int32)
y = arange(m, dtype=int32)
if tie_in is not None:
x = lax.tie_in(tie_in, x)
y = lax.tie_in(tie_in, y)
return lax.ge(
(lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,)) +
lax.convert_element_type(k, int32)),
lax.broadcast(y, [n]))

@_wraps(onp.tri)
def tri(N, M=None, k=0, dtype=None):
lax._check_user_dtype_supported(dtype, "tri")
M = M if M is not None else N
dtype = dtype or float32
x = arange(N, dtype=int32)
y = arange(M, dtype=int32)
mask = lax.ge(
(lax.broadcast_in_dim(x, shape=(N, M), broadcast_dimensions=(0,)) +
int32(k)),
lax.broadcast(y, [N]))
return lax.convert_element_type(mask, dtype)

mask = _triangular_mask(N, M if M is not None else N, k)
return lax.convert_element_type(mask, dtype or float32)

@_wraps(onp.tril)
def tril(m, k=0):
@jit
def _tril(m, k=0):
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.tril must be at least 2D")
mask = tri(*m_shape[-2:], k=k, dtype=bool)
mask = _triangular_mask(m_shape[-2], m_shape[-1], k, tie_in=m)
return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))

@_wraps(onp.tril)
def tril(m, k=0): return _tril(m, k)

@_wraps(onp.triu, update_doc=False)
def triu(m, k=0):
@jit
def _triu(m, k):
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.triu must be at least 2D")
mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
mask = _triangular_mask(m_shape[-2], m_shape[-1], k-1, tie_in=m)
return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)

@_wraps(onp.triu, update_doc=False)
def triu(m, k=0): return _triu(m, k)


@_wraps(onp.trace)
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
Expand Down

0 comments on commit 7f7c21f

Please sign in to comment.