Pairwise summation takes very long time #1918
Hi! I think you should use lax loop primitives instead of Python loops, because of the issue of unrolling those loops. See this issue for an example: #402
Thanks, @jotsif. I think you're right about the underlying issue. To say a bit more, this program is probably taking a long time to compile because those Python loops are getting unrolled, yielding a large XLA program. (I'd expect that, after it's compiled once, it'll be very fast to execute!) One option is to write these loops using loop primitives that get translated directly to XLA loop constructs (rather than being unrolled), keeping the XLA program small and hence compile times short. You can use `fori_loop`:

```python
from jax.lax import fori_loop

@jit
def total_energy(x):
    N = x.shape[0]
    return fori_loop(0, N, lambda i, E:
        fori_loop(0, i, lambda j, E:
            E + potential(np.sqrt(np.sum((x[i] - x[j]) ** 2))),
            E),
        0.)
```
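A minimal timing sketch for a function like this (the small sample array `x` and the Lennard-Jones `potential` below are stand-ins for the thread's actual inputs): calling `block_until_ready()` forces the asynchronously dispatched result to finish, so the first call pays the compile cost and only later calls measure execution.

```python
import time

import jax.numpy as np  # the thread aliases jax.numpy as np
from jax import jit
from jax.lax import fori_loop

@jit
def potential(r):
    # Lennard-Jones potential, as defined later in the thread
    return 4. * (np.power(r, -12) - np.power(r, -6))

@jit
def total_energy(x):
    N = x.shape[0]
    return fori_loop(0, N, lambda i, E:
        fori_loop(0, i, lambda j, E:
            E + potential(np.sqrt(np.sum((x[i] - x[j]) ** 2))),
            E),
        0.)

x = np.linspace(1., 2., 30).reshape(10, 3)
total_energy(x).block_until_ready()   # first call includes tracing + XLA compilation

start = time.perf_counter()
total_energy(x).block_until_ready()   # second call measures execution only
elapsed = time.perf_counter() - start
```

Only the second timing reflects steady-state execution speed.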
(See the docs on asynchronous dispatch for an explanation of why you need `block_until_ready()` when timing.)

That's pretty slow, much slower than CPU execution. I'm not sure what's going on. I'll raise it with our TPU experts next week. However, if this example is similar to the computation you actually care about, you can write it in a vectorized way in terms of NumPy primitives rather than using any explicit loops:

```python
@jit
def potential(r):
    U = 4. * (np.power(r, -12) - np.power(r, -6))
    return U

@jit
def total_energy(x):
    distances = np.sqrt(np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1))
    return np.sum(np.tril(potential(distances), k=-1))
```
Is that style an option for you?
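As a sanity check, the vectorized version can be compared against a plain Python double loop on a small input (a sketch; the array `x` here is made up):

```python
import jax.numpy as np
from jax import jit

@jit
def potential(r):
    return 4. * (np.power(r, -12) - np.power(r, -6))

@jit
def total_energy(x):
    # pairwise distance matrix, then sum the strictly-lower triangle
    distances = np.sqrt(np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1))
    return np.sum(np.tril(potential(distances), k=-1))

def total_energy_loops(x):
    # reference implementation with explicit Python loops
    E = 0.
    for i in range(x.shape[0]):
        for j in range(i):
            E = E + potential(np.sqrt(np.sum((x[i] - x[j]) ** 2)))
    return E

x = np.linspace(1., 2., 15).reshape(5, 3)
# total_energy(x) and total_energy_loops(x) agree to floating-point tolerance
```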
You can also compute Euclidean pairwise distances using a polarization identity, but in my brief tests that seemed less numerically stable here and not faster.
I'm going to close this issue because it's not active, and because I think we covered the relevant stuff. Please re-open (or open a new issue) if that's not the case!
I'm also trying to efficiently calculate pairwise distances with large matrices. The problem now, however, is that JAX seems to always run out of memory.
No matter how I write it, JAX allocates a 16k x 16k x 7k tensor, when (it seems to me) it should only need to allocate a 16k x 16k matrix. (This is almost certainly not the best way to write it!) In the meantime, I shuttle data from JAX to PyTorch and use the `cdist` function there, which works quickly, minus the translation time.
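One way to stay in JAX and keep the peak allocation near N x M instead of N x M x D is to map over rows, so each step only materializes an (M, D) intermediate. A sketch (function name is made up, and this is not necessarily the fastest formulation):

```python
import jax
import jax.numpy as jnp

def cdist_rowwise(x, y):
    # lax.map runs the per-row computation sequentially (via scan), so the
    # (M, D) broadcasted difference should exist for only one row at a time
    def one_row(xi):
        return jnp.sqrt(jnp.sum((xi[None, :] - y) ** 2, axis=-1))
    return jax.lax.map(one_row, x)  # result shape: (N, M)
```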
Is there a JAX-preferred way to run computations that result in a matrix/tensor of outputs? I've run into some of the issues described here. For a toy example, say we want to calculate pairwise distances between two sets of vectors. We start with:

```python
def compute_distances(v1, v2):
    return np.power((v1[:, None, :] - v2[None, :, :]), 2).sum(-1)
```

A problem with this method is that NumPy allocates the full three-dimensional intermediate array before reducing it. One way around this is to pre-allocate the output array, loop over one of the inputs to minimize memory allocation, and use Numba to speed things up:

```python
@numba.jit
def compute_distances(v1, v2):
    rows = v1.shape[0]
    cols = v2.shape[0]
    output = np.zeros((rows, cols), dtype=v1.dtype)
    for row in range(rows):
        distance = np.power(v1[row] - v2, 2).sum(-1)
        output[row] = distance
    return output
```

This method runs the above test in 10 seconds with negligible memory usage. Is there a similar approach for JAX? Below is the standard (as I understand it) JAX approach:

```python
def compute_pair_distance(a, b):
    return jnp.power(a - b, 2).sum(-1)

compute_distances = jax.vmap(jax.vmap(compute_pair_distance, (None, 0), 0), (0, None), 0)
distances = compute_distances(v1, v2)
```

This has similar runtime to Numba - about 10 seconds - but has a huge memory peak of 80 GB. Is there a way to do this sort of computation in JAX where the output array is pre-allocated and filled in? I haven't been able to figure out a way.
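JAX arrays are immutable, so there is no in-place fill of a pre-allocated output, but the peak can be capped by calling the vmapped function on row blocks and concatenating. A sketch (`chunk` is a made-up tuning knob; each call's broadcasted intermediate is chunk x M x D instead of N x M x D):

```python
import jax
import jax.numpy as jnp

def compute_pair_distance(a, b):
    # squared Euclidean distance between two vectors
    return jnp.power(a - b, 2).sum(-1)

compute_distances = jax.vmap(jax.vmap(compute_pair_distance, (None, 0), 0), (0, None), 0)

def compute_distances_chunked(v1, v2, chunk=1024):
    # process `chunk` rows of v1 at a time to bound peak memory
    blocks = [compute_distances(v1[i:i + chunk], v2)
              for i in range(0, v1.shape[0], chunk)]
    return jnp.concatenate(blocks, axis=0)
```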
I am trying to compute a very simple sum on the TPU in Google Colaboratory, but the jitted code goes into a very long (>10 minutes) computational cycle when I define the function and test it. What can be the reason? The notebook can be found here: https://colab.research.google.com/drive/12xuBBs7-P3JiHGclffBPhH5Xf2EdImKn